結果

問題 No.1677 mæx
ユーザー lam6er
提出日時 2025-04-15 22:04:27
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 711 ms / 2,000 ms
コード長 3,179 bytes
コンパイル時間 346 ms
コンパイル使用メモリ 81,768 KB
実行使用メモリ 179,832 KB
最終ジャッジ日時 2025-04-15 22:06:24
合計ジャッジ時間 10,675 ms
ジャッジサーバーID
(参考情報)
judge2 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 18
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 998244353

def solve():
    import sys
    sys.setrecursionlimit(1 << 25)
    S = sys.stdin.readline().strip()
    K = int(sys.stdin.readline())
    memo = {}

    def parse(pos):
        if pos >= len(S):
            return None, pos
        if pos in memo:
            return memo[pos]
        c = S[pos]
        if c in ['0', '1', '2', '?']:
            dp = [0] * 3
            if c == '?':
                dp[0] = 1
                dp[1] = 1
                dp[2] = 1
            else:
                dp[int(c)] = 1
            memo[pos] = (dp, pos + 1)
            return (dp, pos + 1)
        elif c == 'm':
            if pos + 3 > len(S):
                memo[pos] = (None, pos)
                return (None, pos)
            func_part = S[pos:pos+3]
            total_dp = [0] * 3
            total_end = None
            for func in ['max', 'mex']:
                ways = 1
                for i in range(3):
                    current_char = func_part[i]
                    required_char = func[i]
                    if current_char != '?' and current_char != required_char:
                        ways = 0
                        break
                if ways == 0:
                    continue
                if pos + 3 >= len(S) or S[pos+3] != '(':
                    continue
                a_start = pos + 4
                a_dp, a_end = parse(a_start)
                if a_dp is None:
                    continue
                if a_end >= len(S) or S[a_end] != ',':
                    continue
                b_start = a_end + 1
                b_dp, b_end = parse(b_start)
                if b_dp is None:
                    continue
                if b_end >= len(S) or S[b_end] != ')':
                    continue
                func_dp = [0] * 4
                for a in range(3):
                    if a_dp[a] == 0:
                        continue
                    for b in range(3):
                        if b_dp[b] == 0:
                            continue
                        if func == 'max':
                            res = max(a, b)
                        else:
                            s_vals = {a, b}
                            res = 0
                            while res in s_vals:
                                res += 1
                        cnt = (a_dp[a] * b_dp[b]) % MOD
                        cnt = (cnt * ways) % MOD
                        if res < 3:
                            func_dp[res] = (func_dp[res] + cnt) % MOD
                for i in range(3):
                    total_dp[i] = (total_dp[i] + func_dp[i]) % MOD
                current_end = b_end + 1
                if total_end is None or current_end > total_end:
                    total_end = current_end
            if total_end is None:
                memo[pos] = (None, pos)
                return (None, pos)
            memo[pos] = (total_dp, total_end)
            return (total_dp, total_end)
        else:
            memo[pos] = (None, pos)
            return (None, pos)
    
    root_dp, end = parse(0)
    if root_dp is None or end != len(S):
        print(0)
    else:
        print(root_dp[K] % MOD)

solve()
0