結果

問題 No.3349 AtCoder Janken Train
コンテスト
ユーザー りすりす/TwoSquirrels
提出日時 2025-11-01 02:18:55
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 12,063 bytes
コンパイル時間 246 ms
コンパイル使用メモリ 82,856 KB
実行使用メモリ 360,528 KB
最終ジャッジ日時 2025-11-13 21:10:36
合計ジャッジ時間 3,914 ms
ジャッジサーバーID
(参考情報)
judge2 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2 TLE * 1
other -- * 30
権限があれば一括ダウンロードができます

ソースコード

diff #

# Converted by Gemini 2.5 Pro

import sys
import typing

# -----------------------------------------------------------------
# 依存ライブラリのインライン展開 (ac-library-python)
# -----------------------------------------------------------------

#
# _bit.py
#
def _ceil_pow2(n: int) -> int:
    x = 0
    while (1 << x) < n:
        x += 1
    return x

def _bsf(n: int) -> int:
    x = 0
    while n % 2 == 0:
        x += 1
        n //= 2
    return x

#
# _math.py
#
def _inv_gcd(a: int, b: int) -> typing.Tuple[int, int]:
    a %= b
    if a == 0:
        return (b, 0)

    s = b
    t = a
    m0 = 0
    m1 = 1

    while t:
        u = s // t
        s -= t * u
        m0 -= m1 * u

        s, t = t, s
        m0, m1 = m1, m0

    if m0 < 0:
        m0 += b // s
    return (s, m0)

def _primitive_root(m: int) -> int:
    if m == 2:
        return 1
    if m == 167772161:
        return 3
    if m == 469762049:
        return 3
    if m == 754974721:
        return 11
    if m == 998244353:
        return 3

    divs = [2] + [0] * 19
    cnt = 1
    x = (m - 1) // 2
    while x % 2 == 0:
        x //= 2

    i = 3
    while i * i <= x:
        if x % i == 0:
            divs[cnt] = i
            cnt += 1
            while x % i == 0:
                x //= i
        i += 2

    if x > 1:
        divs[cnt] = x
        cnt += 1

    g = 2
    while True:
        for i in range(cnt):
            if pow(g, (m - 1) // divs[i], m) == 1:
                break
        else:
            return g
        g += 1

#
# modint.py
#
class ModContext:
    context: typing.List[int] = []

    def __init__(self, mod: int) -> None:
        assert 1 <= mod
        self.mod = mod

    def __enter__(self) -> None:
        self.context.append(self.mod)

    def __exit__(self, exc_type: typing.Any, exc_value: typing.Any,
                 traceback: typing.Any) -> None:
        self.context.pop()

    @classmethod
    def get_mod(cls) -> int:
        return cls.context[-1]

class Modint:
    def __init__(self, v: int = 0) -> None:
        self._mod = ModContext.get_mod()
        if v == 0:
            self._v = 0
        else:
            self._v = v % self._mod

    def mod(self) -> int:
        return self._mod

    def val(self) -> int:
        return self._v

    def __iadd__(self, rhs: typing.Union['Modint', int]) -> 'Modint':
        if isinstance(rhs, Modint):
            self._v += rhs._v
        else:
            self._v += rhs
        if self._v >= self._mod:
            self._v -= self._mod
        return self

    def __isub__(self, rhs: typing.Union['Modint', int]) -> 'Modint':
        if isinstance(rhs, Modint):
            self._v -= rhs._v
        else:
            self._v -= rhs
        if self._v < 0:
            self._v += self._mod
        return self

    def __imul__(self, rhs: typing.Union['Modint', int]) -> 'Modint':
        if isinstance(rhs, Modint):
            self._v = self._v * rhs._v % self._mod
        else:
            self._v = self._v * rhs % self._mod
        return self

    def __ifloordiv__(self, rhs: typing.Union['Modint', int]) -> 'Modint':
        if isinstance(rhs, Modint):
            inv = rhs.inv()._v
        else:
            inv = _inv_gcd(rhs, self._mod)[1] # 修正: atcoder._math._inv_gcd -> _inv_gcd
        self._v = self._v * inv % self._mod
        return self

    def __pos__(self) -> 'Modint':
        return self

    def __neg__(self) -> 'Modint':
        return Modint() - self

    def __pow__(self, n: int) -> 'Modint':
        assert 0 <= n
        return Modint(pow(self._v, n, self._mod))

    def inv(self) -> 'Modint':
        eg = _inv_gcd(self._v, self._mod) # 修正: atcoder._math._inv_gcd -> _inv_gcd
        assert eg[0] == 1
        return Modint(eg[1])

    def __add__(self, rhs: typing.Union['Modint', int]) -> 'Modint':
        if isinstance(rhs, Modint):
            result = self._v + rhs._v
            if result >= self._mod:
                result -= self._mod
            return raw(result)
        else:
            return Modint(self._v + rhs)

    def __sub__(self, rhs: typing.Union['Modint', int]) -> 'Modint':
        if isinstance(rhs, Modint):
            result = self._v - rhs._v
            if result < 0:
                result += self._mod
            return raw(result)
        else:
            return Modint(self._v - rhs)

    def __mul__(self, rhs: typing.Union['Modint', int]) -> 'Modint':
        if isinstance(rhs, Modint):
            return Modint(self._v * rhs._v)
        else:
            return Modint(self._v * rhs)

    def __floordiv__(self, rhs: typing.Union['Modint', int]) -> 'Modint':
        if isinstance(rhs, Modint):
            inv = rhs.inv()._v
        else:
            inv = _inv_gcd(rhs, self._mod)[1] # 修正: atcoder._math._inv_gcd -> _inv_gcd
        return Modint(self._v * inv)

    def __eq__(self, rhs: typing.Union['Modint', int]) -> bool:
        if isinstance(rhs, Modint):
            return self._v == rhs._v
        else:
            return self._v == rhs

    def __ne__(self, rhs: typing.Union['Modint', int]) -> bool:
        if isinstance(rhs, Modint):
            return self._v != rhs._v
        else:
            return self._v != rhs

def raw(v: int) -> Modint:
    x = Modint()
    x._v = v
    return x

#
# convolution.py
#
_sum_e: typing.Dict[int, typing.List[Modint]] = {}

def _butterfly(a: typing.List[Modint]) -> None:
    g = _primitive_root(a[0].mod()) # 修正
    n = len(a)
    h = _ceil_pow2(n) # 修正

    if a[0].mod() not in _sum_e:
        es = [Modint(0)] * 30
        ies = [Modint(0)] * 30
        cnt2 = _bsf(a[0].mod() - 1) # 修正
        e = Modint(g) ** ((a[0].mod() - 1) >> cnt2)
        ie = e.inv()
        for i in range(cnt2, 1, -1):
            es[i - 2] = e
            ies[i - 2] = ie
            e = e * e
            ie = ie * ie
        sum_e = [Modint(0)] * 30
        now = Modint(1)
        for i in range(cnt2 - 2):
            sum_e[i] = es[i] * now
            now *= ies[i]
        _sum_e[a[0].mod()] = sum_e
    else:
        sum_e = _sum_e[a[0].mod()]

    for ph in range(1, h + 1):
        w = 1 << (ph - 1)
        p = 1 << (h - ph)
        now = Modint(1)
        for s in range(w):
            offset = s << (h - ph + 1)
            for i in range(p):
                left = a[i + offset]
                right = a[i + offset + p] * now
                a[i + offset] = left + right
                a[i + offset + p] = left - right
            now *= sum_e[_bsf(~s)] # 修正

_sum_ie: typing.Dict[int, typing.List[Modint]] = {}

def _butterfly_inv(a: typing.List[Modint]) -> None:
    g = _primitive_root(a[0].mod()) # 修正
    n = len(a)
    h = _ceil_pow2(n) # 修正

    if a[0].mod() not in _sum_ie:
        es = [Modint(0)] * 30
        ies = [Modint(0)] * 30
        cnt2 = _bsf(a[0].mod() - 1) # 修正
        e = Modint(g) ** ((a[0].mod() - 1) >> cnt2)
        ie = e.inv()
        for i in range(cnt2, 1, -1):
            es[i - 2] = e
            ies[i - 2] = ie
            e = e * e
            ie = ie * ie
        sum_ie = [Modint(0)] * 30
        now = Modint(1)
        for i in range(cnt2 - 2):
            sum_ie[i] = ies[i] * now
            now *= es[i]
        _sum_ie[a[0].mod()] = sum_ie
    else:
        sum_ie = _sum_ie[a[0].mod()]

    for ph in range(h, 0, -1):
        w = 1 << (ph - 1)
        p = 1 << (h - ph)
        inow = Modint(1)
        for s in range(w):
            offset = s << (h - ph + 1)
            for i in range(p):
                left = a[i + offset]
                right = a[i + offset + p]
                a[i + offset] = left + right
                a[i + offset + p] = Modint(
                    (a[0].mod() + left.val() - right.val()) * inow.val())
            inow *= sum_ie[_bsf(~s)] # 修正

def convolution_mod(a: typing.List[Modint],
                    b: typing.List[Modint]) -> typing.List[Modint]:
    n = len(a)
    m = len(b)

    if n == 0 or m == 0:
        return []

    # 畳み込みが小さい場合はナイーブな実装の方が速い
    if min(n, m) <= 60:
        if n < m:
            n, m = m, n
            a, b = b, a
        ans = [Modint(0) for _ in range(n + m - 1)]
        for i in range(n):
            for j in range(m):
                ans[i + j] += a[i] * b[j]
        return ans

    z = 1 << _ceil_pow2(n + m - 1) # 修正

    #
    # !!重要!!
    # ac-library-python の _butterfly はリストを *直接* 変更します。
    # この関数 (convolution_mod) も引数 a, b を直接変更します。
    # 呼び出し元 (main) で a と b に同じリスト (f[t]) を渡すと、
    # 予期せぬ動作になるため、呼び出し元でコピー (f[t][:]) を渡す必要があります。
    #
    # C++ の atcoder::convolution は引数を値渡し (コピー) で受け取るため、
    # この問題は発生しません。
    #
    # ここでは、呼び出し元がコピーを渡すことを前提とせず、
    # 安全のために関数内でコピーを作成します。
    # (※ 提出コードでは main 側で f[t][:] とコピーを渡すため、
    #   元の ac-library-python のコード (a.extend, b.extend) でも
    #   問題なく動作します。)
    #
    # 元の ac-library-python の実装 (インプレース変更):
    # a.extend([Modint(0)] * (z - n))
    # _butterfly(a)
    # b.extend([Modint(0)] * (z - m))
    # _butterfly(b)
    # for i in range(z):
    #     a[i] *= b[i]
    # _butterfly_inv(a)
    # a = a[:n + m - 1]
    # ...
    # return a
    #
    # 安全な(コピーを作成する)実装:
    a_copy = a + [Modint(0)] * (z - n)
    b_copy = b + [Modint(0)] * (z - m)
    
    _butterfly(a_copy)
    _butterfly(b_copy)
    
    for i in range(z):
        a_copy[i] *= b_copy[i]
        
    _butterfly_inv(a_copy)
    
    result = a_copy[:n + m - 1]
    
    iz = Modint(z).inv()
    for i in range(n + m - 1):
        result[i] *= iz
    
    return result

# -----------------------------------------------------------------
# メインロジック
# -----------------------------------------------------------------

def main():
    MOD = 998244353
    # Modint のコンテキストを設定
    with ModContext(MOD):
        # 高速入力を設定
        input = sys.stdin.readline
        
        n, m = map(int, input().split())

        # f[t] のリスト
        f: typing.List[typing.List[Modint]] = [[] for _ in range(n + 1)]

        # f[0](x) = x  (Modintオブジェクトとして初期化)
        f[0] = [Modint(0), Modint(1)]

        for t in range(n):
            # f[t+1](x) = (f_t(x))^2 + f_t(x)
            
            # (f_t(x))^2
            # !!重要!!: convolution_mod はリストをインプレースで変更する
            # 可能性があるため、必ずコピー [:] を渡す。
            f_t_squared = convolution_mod(f[t][:], f[t][:])
            
            len_t = len(f[t])
            len_sq = len(f_t_squared)
            
            # f[t] の方が長い場合、f_t_squared を拡張
            if len_t > len_sq:
                f_t_squared.extend([Modint(0)] * (len_t - len_sq))
            
            # f[t+1] = f_t_squared + f[t]
            for i in range(len_t):
                f_t_squared[i] += f[t][i]
            
            f[t + 1] = f_t_squared

        # ([x^{2^n - m}](f_n(x) + 1)) m! (2^n - m)!
        
        k = (1 << n) - m
        
        ans = Modint(0)
        if k < len(f[n]):
            ans = f[n][k]

        # f_n(x) + 1 の「+1」の部分 (x^0 の係数)
        # k = 0 (つまり m = 2^n) の場合に +1 する
        if k == 0:
            ans += 1

        # m! を掛ける
        fact_m = Modint(1)
        for i in range(1, m + 1):
            fact_m *= i
        
        # (2^n - m)! = k! を掛ける
        fact_k = Modint(1)
        for i in range(1, k + 1):
            fact_k *= i
        
        ans *= fact_m
        ans *= fact_k

        print(ans.val())

if __name__ == "__main__":
    main()
0