結果

問題 No.1677 mæx
ユーザー lam6er
提出日時 2025-03-26 15:48:53
言語 PyPy3 (7.3.15)
結果
TLE  
実行時間 -
コード長 3,495 bytes
コンパイル時間 168 ms
コンパイル使用メモリ 82,744 KB
実行使用メモリ 54,636 KB
最終ジャッジ日時 2025-03-26 15:50:01
合計ジャッジ時間 4,004 ms
ジャッジサーバーID (参考情報) judge5 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other TLE * 1 -- * 17
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
# Every count is reported modulo this value.
MOD = 998_244_353

def evaluate_counts(expr, mod):
    """Return how many ``?`` assignments make *expr* evaluate to 0, 1 and 2.

    *expr* is built from leaves ``0``, ``1``, ``2`` and ``?`` (any of the
    three digits) and binary calls ``max(x,y)``, ``mex(x,y)`` and
    ``m?x(x,y)`` (either ``max`` or ``mex``).  The result is a list
    ``[c0, c1, c2]``.  Each ``ci`` is the number of assignments, reduced
    modulo *mod*, under which the expression evaluates to ``i``.  An
    expression with no terms yields ``[0, 0, 0]``.

    The string is consumed in a single left-to-right pass with two explicit
    stacks.  The cost is therefore linear in ``len(expr)`` whatever the
    nesting depth, with no recursion and no substring copies.
    """
    max_table = ((0, 1, 2), (1, 1, 2), (2, 2, 2))  # max(a, b)
    mex_table = ((1, 2, 1), (2, 0, 0), (1, 0, 0))  # mex({a, b})
    leaf_counts = {
        '0': (1, 0, 0),
        '1': (0, 1, 0),
        '2': (0, 0, 1),
        '?': (1, 1, 1),
    }
    # Second character of "m?x" -> result tables of the function(s) it allows.
    # Any other character allows no function, so that node counts zero ways.
    func_choices = {
        'a': (max_table,),
        'e': (mex_table,),
        '?': (max_table, mex_table),
    }

    values = []   # count vectors of finished sub-expressions
    pending = []  # function choices whose closing ')' has not been seen yet
    n = len(expr)
    i = 0
    while i < n:
        ch = expr[i]
        if ch == 'm':
            kind = expr[i + 1] if i + 1 < n else '?'
            pending.append(func_choices.get(kind, ()))
            i += 4  # skip "m?x(" so its '?' is not mistaken for a leaf
            continue
        if ch in leaf_counts:
            values.append(leaf_counts[ch])
        elif ch == ')':
            # Both arguments are complete: fold them through the function(s).
            right = values.pop()
            left = values.pop()
            combined = [0, 0, 0]
            for table in pending.pop():
                for a in range(3):
                    row = table[a]
                    left_a = left[a]
                    for b in range(3):
                        combined[row[b]] += left_a * right[b]
            values.append(
                (combined[0] % mod, combined[1] % mod, combined[2] % mod)
            )
        # ',' and any other character carry no information.
        i += 1

    if not values:
        return [0, 0, 0]
    # Well-formed input leaves exactly one term; otherwise use the first one.
    return [c % mod for c in values[0]]


def main():
    """Read the expression and the target value K from stdin; print the count.

    The expression is on the first input line and K on the second.  Prints
    the number of ``?`` assignments that make the expression equal K,
    modulo MOD.
    """
    s = sys.stdin.readline().strip()
    K = int(sys.stdin.readline())
    print(evaluate_counts(s, MOD)[K] % MOD)

# Execute the solver only when run as a script, not when imported.
if __name__ == "__main__":
    main()
0