結果

問題 No.1677 mæx
ユーザー lam6er
提出日時 2025-03-31 17:50:49
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 2,778 bytes
コンパイル時間 193 ms
コンパイル使用メモリ 82,668 KB
実行使用メモリ 422,708 KB
最終ジャッジ日時 2025-03-31 17:51:49
合計ジャッジ時間 4,088 ms
ジャッジサーバーID
(参考情報)
judge1 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other TLE * 1 -- * 17
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
MOD = 998244353

def main():
    S = sys.stdin.readline().strip()
    K = int(sys.stdin.readline())
    
    from functools import lru_cache
    
    @lru_cache(maxsize=None)
    def dp(start, end):
        s_part = S[start:end]
        if len(s_part) == 1:
            if s_part[0] == '0':
                return (1, 0, 0)
            elif s_part[0] == '1':
                return (0, 1, 0)
            elif s_part[0] == '2':
                return (0, 0, 1)
            elif s_part[0] == '?':
                return (1, 1, 1)
            else:
                return (0, 0, 0)
        else:
            if not (len(s_part) >= 5 and s_part.startswith('m') and s_part[2] == 'x' and s_part[3] == '(' and s_part[-1] == ')'):
                return (0, 0, 0)
            c = s_part[1]
            if c not in ['a', 'e', '?']:
                return (0, 0, 0)
            possible_funcs = []
            if c == 'a':
                possible_funcs = ['max']
            elif c == 'e':
                possible_funcs = ['mex']
            else:
                possible_funcs = ['max', 'mex']
            
            inner_str = s_part[4:-1]
            comma_pos = find_split_comma(inner_str)
            if comma_pos == -1:
                return (0, 0, 0)
            
            a_start = start + 4
            a_end = a_start + comma_pos
            b_start = a_end + 1
            b_end = start + 4 + len(inner_str)
            
            a_counts = dp(a_start, a_end)
            b_counts = dp(b_start, b_end)
            
            res = [0, 0, 0]
            for func in possible_funcs:
                func_multiplier = 1
                for a_val in range(3):
                    for b_val in range(3):
                        if func == 'max':
                            val = max(a_val, b_val)
                        else:
                            s = {a_val, b_val}
                            val = 0
                            while val in s:
                                val += 1
                            if val > 2:
                                val = 0
                        contrib = (a_counts[a_val] * b_counts[b_val]) % MOD
                        contrib = (contrib * func_multiplier) % MOD
                        res[val] = (res[val] + contrib) % MOD
            return tuple(res)
    
    def find_split_comma(s):
        balance = 0
        for i, c in enumerate(s):
            if c == '(':
                balance += 1
            elif c == ')':
                balance -= 1
                if balance < 0:
                    return -1
            elif c == ',' and balance == 0:
                return i
        return -1
    
    result = dp(0, len(S))
    print(result[K] % MOD)

if __name__ == '__main__':
    main()
0