結果

問題 No.1239 Multiplication -2
ユーザー lam6er
提出日時 2025-04-09 21:05:09
言語 PyPy3 (7.3.15)
結果
AC  
実行時間 215 ms / 2,000 ms
コード長 3,674 bytes
コンパイル時間 431 ms
コンパイル使用メモリ 82,708 KB
実行使用メモリ 133,292 KB
最終ジャッジ日時 2025-04-09 21:06:41
合計ジャッジ時間 6,183 ms
ジャッジサーバーID (参考情報) judge4 / judge1
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 34
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 998244353

# Per non-zero value: (change in sign parity, change in count of |x| == 2).
# Any other value leaves both counters unchanged, matching the original
# if/elif chain over 1, -1, 2, -2.
_STATE_DELTA = {1: (0, 0), -1: (1, 0), 2: (0, 1), -2: (1, 1)}


def _nonzero_runs(a):
    """Split ``a`` into maximal runs of non-zero values.

    Returns a list of ``(values, start, end)`` tuples, where ``start`` and
    ``end`` are the 1-based inclusive positions of the run inside ``a``.
    """
    runs = []
    start = None
    for pos, val in enumerate(a, 1):
        if val != 0:
            if start is None:
                start = pos
        elif start is not None:
            runs.append((a[start - 1:pos - 1], start, pos - 1))
            start = None
    if start is not None:
        runs.append((a[start - 1:], start, len(a)))
    return runs


def _run_expectation(values, at_start, at_end, inv_2, pow_2, inv_pow2):
    """Weighted count (mod MOD) of pieces inside one zero-free run with product -2.

    The piece ``values[l-1:r]`` (1-based, l <= r) has weight
    ``left(l) * right(r) * inv_2**(r - l)``. ``left`` is 1 when the piece
    starts at the very beginning of the whole array (``at_start`` and
    l == 1), otherwise ``inv_2``. ``right`` is 1 when it ends at the very
    end (``at_end`` and r == len(values)), otherwise ``inv_2``.
    ``inv_2**(r - l)`` is split as ``pow_2[l] * inv_pow2[r]``, so all left
    ends can be pooled in one dictionary keyed by prefix state.
    """
    m = len(values)
    # prefix state (sign parity, number of |x| == 2 so far) -> pooled left weights
    pooled = {}
    parity = twos = 0
    total = 0
    for r, val in enumerate(values, 1):
        # The prefix ending just before index r is the left edge of pieces with l == r.
        left = 1 if at_start and r == 1 else inv_2
        key = (parity, twos)
        pooled[key] = (pooled.get(key, 0) + left * pow_2[r]) % MOD

        d_parity, d_twos = _STATE_DELTA.get(val, (0, 0))
        parity ^= d_parity
        twos += d_twos

        # For values in {±1, ±2} the product is -2 iff the piece holds an odd
        # number of negatives and exactly one element with |x| == 2.
        matched = pooled.get((parity ^ 1, twos - 1), 0)
        if matched:
            right = 1 if at_end and r == m else inv_2
            total = (total + matched * right % MOD * inv_pow2[r]) % MOD
    return total


def _solve(n, a):
    """Return the weighted count of pieces of ``a`` with product -2, mod MOD.

    Piece [l, r] (1-based) has weight
    (1/2 if l > 1) * (1/2 if r < n) * (1/2)**(r - l). That is exactly the
    probability that [l, r] forms one piece when each of the n - 1 gaps is
    cut independently with probability 1/2. By linearity of expectation the
    sum is therefore the expected number of such pieces. Pieces containing a
    0 have product 0, so the work is done run by run over zero-free runs.
    """
    inv_2 = pow(2, MOD - 2, MOD)
    size = n + 2
    pow_2 = [1] * (size + 1)
    inv_pow2 = [1] * (size + 1)
    for i in range(1, size + 1):
        pow_2[i] = pow_2[i - 1] * 2 % MOD
        inv_pow2[i] = inv_pow2[i - 1] * inv_2 % MOD

    total = 0
    for values, start, end in _nonzero_runs(a):
        total += _run_expectation(
            values, start == 1, end == n, inv_2, pow_2, inv_pow2
        )
    return total % MOD


def main():
    """Read ``N`` then ``A_1 .. A_N`` from stdin and print the answer mod MOD."""
    import sys

    tokens = sys.stdin.read().split()
    n = int(tokens[0])
    a = list(map(int, tokens[1:n + 1]))
    print(_solve(n, a))

# Run the solver only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
0