結果

問題 No.1677 mæx
ユーザー lam6er
提出日時 2025-04-15 22:00:15
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 194 ms / 2,000 ms
コード長 2,259 bytes
コンパイル時間 293 ms
コンパイル使用メモリ 81,588 KB
実行使用メモリ 82,172 KB
最終ジャッジ日時 2025-04-15 22:01:23
合計ジャッジ時間 4,302 ms
ジャッジサーバーID
(参考情報)
judge1 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 18
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 998244353

def mex(a, b):
    if 0 not in {a, b}:
        return 0
    if 1 not in {a, b}:
        return 1
    return 2

S = input().strip()
K = int(input())

stack = []
i = 0
n = len(S)

while i < n:
    if S[i] == 'm' and i + 3 < n and S[i+3] == '(':
        second_char = S[i+1]
        max_ways = 0
        mex_ways = 0
        if second_char in ['a', '?']:
            max_ways += 1
        if second_char in ['e', '?']:
            mex_ways += 1
        stack.append(('function', max_ways, mex_ways))
        stack.append(('(',))
        i += 4
    elif S[i] in ['0', '1', '2', '?']:
        if S[i] == '0':
            dp = (1, 0, 0)
        elif S[i] == '1':
            dp = (0, 1, 0)
        elif S[i] == '2':
            dp = (0, 0, 1)
        else:
            dp = (1, 1, 1)
        stack.append(('digit', dp))
        i += 1
    elif S[i] == ',':
        stack.append(('comma',))
        i += 1
    elif S[i] == ')':
        args = []
        while stack:
            elem = stack.pop()
            if elem[0] == '(':
                break
            args.append(elem)
        comma_pos = -1
        for j in range(len(args)):
            if args[j][0] == 'comma':
                comma_pos = j
                break
        first_arg = args[comma_pos + 1]
        second_arg = args[0]
        first_dp = first_arg[1]
        second_dp = second_arg[1]
        function_marker = stack.pop()
        max_ways = function_marker[1]
        mex_ways = function_marker[2]
        new_dp = [0, 0, 0]
        if max_ways > 0:
            for a in range(3):
                for b in range(3):
                    val = max(a, b)
                    count = (first_dp[a] * second_dp[b]) % MOD
                    new_dp[val] = (new_dp[val] + count * max_ways) % MOD
        if mex_ways > 0:
            for a in range(3):
                for b in range(3):
                    val = mex(a, b)
                    count = (first_dp[a] * second_dp[b]) % MOD
                    new_dp[val] = (new_dp[val] + count * mex_ways) % MOD
        stack.append(('digit', (new_dp[0], new_dp[1], new_dp[2])))
        i += 1
    else:
        i += 1

if len(stack) != 1 or stack[0][0] != 'digit':
    print(0)
else:
    ans = stack[0][1][K] % MOD
    print(ans)
0