結果

問題 No.1853 Many Operations
コンテスト
ユーザー 回転
提出日時 2026-05-24 23:07:53
言語 PyPy3
(7.3.17)
コンパイル:
pypy3 -mpy_compile _filename_
実行:
pypy3 _filename_
結果
TLE  
実行時間 -
コード長 6,502 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 185 ms
コンパイル使用メモリ 85,024 KB
実行使用メモリ 135,548 KB
最終ジャッジ日時 2026-05-24 23:08:17
合計ジャッジ時間 5,622 ms
ジャッジサーバーID
(参考情報)
judge1_0 / judge2_0
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3 TLE * 1
other -- * 26
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

class BMBM:
    """
    Berlekamp-Massey 法 + Bostan-Mori 法
    https://atcoder.jp/contests/tdpc/submissions/75993697
    
    与えられた数列の最初の数項から線形漸化式を自動で推測し、
    その漸化式に従う N 項目を推測する Black Box ライブラリ。
    引数 use_ntt=True を指定することで O(K log K log N) の NTT 乗算に切り替え可能。
    """
    
    @staticmethod
    def convolution(a: list[int], b: list[int], mod: int = 998244353) -> list[int]:
        """
        mod 998244353 における高速フーリエ変換 (NTT) を用いた多項式乗算
        """
        if not a or not b:
            return []
            
        # NTTは 998244353 専用。それ以外のmodの場合は愚直O(K^2)にフォールバック
        if mod != 998244353:
            res = [0] * (len(a) + len(b) - 1)
            for i, va in enumerate(a):
                for j, vb in enumerate(b):
                    res[i+j] = (res[i+j] + va * vb) % mod
            return res

        n = 1
        while n < len(a) + len(b) - 1:
            n <<= 1
            
        A = a + [0] * (n - len(a))
        B = b + [0] * (n - len(b))
        
        def ntt(arr: list[int], inv: bool):
            j = 0
            for i in range(1, n):
                bit = n >> 1
                while j & bit:
                    j ^= bit
                    bit >>= 1
                j ^= bit
                if i < j:
                    arr[i], arr[j] = arr[j], arr[i]
                    
            step = 1
            while step < n:
                zeta = pow(3, (mod - 1) // (2 * step), mod)
                if inv:
                    zeta = pow(zeta, mod - 2, mod)
                
                w = [1] * step
                for i in range(1, step):
                    w[i] = (w[i-1] * zeta) % mod
                    
                for i in range(0, n, 2 * step):
                    for j in range(step):
                        u = arr[i + j]
                        v = (arr[i + j + step] * w[j]) % mod
                        arr[i + j] = (u + v) % mod
                        arr[i + j + step] = (u - v) % mod
                step <<= 1
                
            if inv:
                inv_n = pow(n, mod - 2, mod)
                for i in range(n):
                    arr[i] = (arr[i] * inv_n) % mod

        ntt(A, False)
        ntt(B, False)
        C = [(x * y) % mod for x, y in zip(A, B)]
        ntt(C, True)
        
        return C[:len(a) + len(b) - 1]

    @staticmethod
    def berlekamp_massey(a: list[int], mod: int) -> list[int]:
        """
        数列 a から最短の線形漸化式(分母多項式 Q)を復元する
        """
        c = [1]
        b = [1]
        L = 0
        m = 1
        b_val = 1
        for i in range(len(a)):
            d = 0
            for j in range(L + 1):
                d = (d + c[j] * a[i - j]) % mod
            if d == 0:
                m += 1
            else:
                tmp = c[:]
                c_val = (d * pow(b_val, mod - 2, mod)) % mod
                while len(c) <= len(b) + m:
                    c.append(0)
                for j in range(len(b)):
                    c[j + m] = (c[j + m] - c_val * b[j]) % mod
                if 2 * L <= i:
                    L = i + 1 - L
                    b = tmp
                    b_val = d
                    m = 1
                else:
                    m += 1
        return c[:L + 1]

    @staticmethod
    def bostan_mori(p: list[int], q: list[int], n: int, mod: int, use_ntt: bool = False) -> int:
        """
        母関数 P(x)/Q(x) の N 次の係数を求める
        """
        while n > 0:
            q_neg = [v if i % 2 == 0 else -v for i, v in enumerate(q)]
            
            if use_ntt and mod == 998244353:
                u = BMBM.convolution(p, q_neg, mod)
                v = BMBM.convolution(q, q_neg, mod)
            else:
                # 愚直 O(K^2)
                u = [0] * (len(p) + len(q) - 1)
                for i in range(len(p)):
                    pi = p[i]
                    if not pi: continue
                    for j in range(len(q)):
                        u[i + j] = (u[i + j] + pi * q_neg[j]) % mod
                        
                v = [0] * (len(q) + len(q) - 1)
                for i in range(len(q)):
                    qi = q[i]
                    if not qi: continue
                    for j in range(len(q)):
                        v[i + j] = (v[i + j] + qi * q_neg[j]) % mod
                    
            if n % 2 == 1:
                p = [u[i] for i in range(1, len(u), 2)]
            else:
                p = [u[i] for i in range(0, len(u), 2)]
                
            q = [v[i] for i in range(0, len(v), 2)]
            n //= 2
            
        return p[0] if len(p) > 0 else 0

    @classmethod
    def guess_kth_term(cls, seq: list[int], k: int, mod: int = 998244353, use_ntt: bool = False) -> int:
        """
        数列 seq の k 項目 (0-indexed) を推測して返す
        use_ntt=True で O(K log K log N) に切り替え可能
        """
        # 1. 漸化式(分母多項式 Q)を特定
        q = cls.berlekamp_massey(seq, mod)
        
        # 2. 分子多項式 P(x) = (seq(x) * Q(x)) mod x^{len(Q)-1} を計算
        if use_ntt and mod == 998244353:
            p_full = cls.convolution(seq[:len(q)-1], q, mod)
            p = p_full[:len(q)-1]
        else:
            p = [0] * (len(q) - 1)
            for i in range(len(p)):
                val = 0
                for j in range(i + 1):
                    if j < len(q) and i - j < len(seq):
                        val = (val + seq[i - j] * q[j]) % mod
                p[i] = val
            
        # 3. Bostan-Mori で k 項目へ一気に飛ぶ
        return cls.bostan_mori(p, q, k, mod, use_ntt)

from collections import deque
MOD = 998244353
N = int(input())

L = 3 * 10**5
def naive(n):
    q = deque([(0,0)])
    INF = 10**18
    dist = [INF] * (n+1)
    while(q):
        v,now = q.popleft()

        if(dist[now] <= v):continue
        dist[now] = v

        if(0 <= now+1 <= n):q.append((v+1,now+1))
        if(0 <= now-1 <= n):q.append((v+1,now-1))
        if(0 <= now*2 <= n):q.append((v+1,now*2))
    
    for i in range(n-1):
        dist[i+1] += dist[i]
    return dist

nai = naive(10**6)

if(N <= L):
    print(nai[N])
else:
    print(BMBM.guess_kth_term(nai[L:], N-L, MOD, use_ntt=True))
0