結果
| 問題 |
No.3349 AtCoder Janken Train
|
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2025-11-01 02:18:55 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
TLE
|
| 実行時間 | - |
| コード長 | 12,063 bytes |
| コンパイル時間 | 246 ms |
| コンパイル使用メモリ | 82,856 KB |
| 実行使用メモリ | 360,528 KB |
| 最終ジャッジ日時 | 2025-11-13 21:10:36 |
| 合計ジャッジ時間 | 3,914 ms |
|
ジャッジサーバーID (参考情報) |
judge2 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 TLE * 1 |
| other | -- * 30 |
ソースコード
# Converted by Gemini 2.5 Pro
import sys
import typing
# -----------------------------------------------------------------
# 依存ライブラリのインライン展開 (ac-library-python)
# -----------------------------------------------------------------
#
# _bit.py
#
def _ceil_pow2(n: int) -> int:
x = 0
while (1 << x) < n:
x += 1
return x
def _bsf(n: int) -> int:
x = 0
while n % 2 == 0:
x += 1
n //= 2
return x
#
# _math.py
#
def _inv_gcd(a: int, b: int) -> typing.Tuple[int, int]:
a %= b
if a == 0:
return (b, 0)
s = b
t = a
m0 = 0
m1 = 1
while t:
u = s // t
s -= t * u
m0 -= m1 * u
s, t = t, s
m0, m1 = m1, m0
if m0 < 0:
m0 += b // s
return (s, m0)
def _primitive_root(m: int) -> int:
if m == 2:
return 1
if m == 167772161:
return 3
if m == 469762049:
return 3
if m == 754974721:
return 11
if m == 998244353:
return 3
divs = [2] + [0] * 19
cnt = 1
x = (m - 1) // 2
while x % 2 == 0:
x //= 2
i = 3
while i * i <= x:
if x % i == 0:
divs[cnt] = i
cnt += 1
while x % i == 0:
x //= i
i += 2
if x > 1:
divs[cnt] = x
cnt += 1
g = 2
while True:
for i in range(cnt):
if pow(g, (m - 1) // divs[i], m) == 1:
break
else:
return g
g += 1
#
# modint.py
#
class ModContext:
context: typing.List[int] = []
def __init__(self, mod: int) -> None:
assert 1 <= mod
self.mod = mod
def __enter__(self) -> None:
self.context.append(self.mod)
def __exit__(self, exc_type: typing.Any, exc_value: typing.Any,
traceback: typing.Any) -> None:
self.context.pop()
@classmethod
def get_mod(cls) -> int:
return cls.context[-1]
class Modint:
def __init__(self, v: int = 0) -> None:
self._mod = ModContext.get_mod()
if v == 0:
self._v = 0
else:
self._v = v % self._mod
def mod(self) -> int:
return self._mod
def val(self) -> int:
return self._v
def __iadd__(self, rhs: typing.Union['Modint', int]) -> 'Modint':
if isinstance(rhs, Modint):
self._v += rhs._v
else:
self._v += rhs
if self._v >= self._mod:
self._v -= self._mod
return self
def __isub__(self, rhs: typing.Union['Modint', int]) -> 'Modint':
if isinstance(rhs, Modint):
self._v -= rhs._v
else:
self._v -= rhs
if self._v < 0:
self._v += self._mod
return self
def __imul__(self, rhs: typing.Union['Modint', int]) -> 'Modint':
if isinstance(rhs, Modint):
self._v = self._v * rhs._v % self._mod
else:
self._v = self._v * rhs % self._mod
return self
def __ifloordiv__(self, rhs: typing.Union['Modint', int]) -> 'Modint':
if isinstance(rhs, Modint):
inv = rhs.inv()._v
else:
inv = _inv_gcd(rhs, self._mod)[1] # 修正: atcoder._math._inv_gcd -> _inv_gcd
self._v = self._v * inv % self._mod
return self
def __pos__(self) -> 'Modint':
return self
def __neg__(self) -> 'Modint':
return Modint() - self
def __pow__(self, n: int) -> 'Modint':
assert 0 <= n
return Modint(pow(self._v, n, self._mod))
def inv(self) -> 'Modint':
eg = _inv_gcd(self._v, self._mod) # 修正: atcoder._math._inv_gcd -> _inv_gcd
assert eg[0] == 1
return Modint(eg[1])
def __add__(self, rhs: typing.Union['Modint', int]) -> 'Modint':
if isinstance(rhs, Modint):
result = self._v + rhs._v
if result >= self._mod:
result -= self._mod
return raw(result)
else:
return Modint(self._v + rhs)
def __sub__(self, rhs: typing.Union['Modint', int]) -> 'Modint':
if isinstance(rhs, Modint):
result = self._v - rhs._v
if result < 0:
result += self._mod
return raw(result)
else:
return Modint(self._v - rhs)
def __mul__(self, rhs: typing.Union['Modint', int]) -> 'Modint':
if isinstance(rhs, Modint):
return Modint(self._v * rhs._v)
else:
return Modint(self._v * rhs)
def __floordiv__(self, rhs: typing.Union['Modint', int]) -> 'Modint':
if isinstance(rhs, Modint):
inv = rhs.inv()._v
else:
inv = _inv_gcd(rhs, self._mod)[1] # 修正: atcoder._math._inv_gcd -> _inv_gcd
return Modint(self._v * inv)
def __eq__(self, rhs: typing.Union['Modint', int]) -> bool:
if isinstance(rhs, Modint):
return self._v == rhs._v
else:
return self._v == rhs
def __ne__(self, rhs: typing.Union['Modint', int]) -> bool:
if isinstance(rhs, Modint):
return self._v != rhs._v
else:
return self._v != rhs
def raw(v: int) -> Modint:
x = Modint()
x._v = v
return x
#
# convolution.py
#
_sum_e: typing.Dict[int, typing.List[Modint]] = {}
def _butterfly(a: typing.List[Modint]) -> None:
g = _primitive_root(a[0].mod()) # 修正
n = len(a)
h = _ceil_pow2(n) # 修正
if a[0].mod() not in _sum_e:
es = [Modint(0)] * 30
ies = [Modint(0)] * 30
cnt2 = _bsf(a[0].mod() - 1) # 修正
e = Modint(g) ** ((a[0].mod() - 1) >> cnt2)
ie = e.inv()
for i in range(cnt2, 1, -1):
es[i - 2] = e
ies[i - 2] = ie
e = e * e
ie = ie * ie
sum_e = [Modint(0)] * 30
now = Modint(1)
for i in range(cnt2 - 2):
sum_e[i] = es[i] * now
now *= ies[i]
_sum_e[a[0].mod()] = sum_e
else:
sum_e = _sum_e[a[0].mod()]
for ph in range(1, h + 1):
w = 1 << (ph - 1)
p = 1 << (h - ph)
now = Modint(1)
for s in range(w):
offset = s << (h - ph + 1)
for i in range(p):
left = a[i + offset]
right = a[i + offset + p] * now
a[i + offset] = left + right
a[i + offset + p] = left - right
now *= sum_e[_bsf(~s)] # 修正
_sum_ie: typing.Dict[int, typing.List[Modint]] = {}
def _butterfly_inv(a: typing.List[Modint]) -> None:
g = _primitive_root(a[0].mod()) # 修正
n = len(a)
h = _ceil_pow2(n) # 修正
if a[0].mod() not in _sum_ie:
es = [Modint(0)] * 30
ies = [Modint(0)] * 30
cnt2 = _bsf(a[0].mod() - 1) # 修正
e = Modint(g) ** ((a[0].mod() - 1) >> cnt2)
ie = e.inv()
for i in range(cnt2, 1, -1):
es[i - 2] = e
ies[i - 2] = ie
e = e * e
ie = ie * ie
sum_ie = [Modint(0)] * 30
now = Modint(1)
for i in range(cnt2 - 2):
sum_ie[i] = ies[i] * now
now *= es[i]
_sum_ie[a[0].mod()] = sum_ie
else:
sum_ie = _sum_ie[a[0].mod()]
for ph in range(h, 0, -1):
w = 1 << (ph - 1)
p = 1 << (h - ph)
inow = Modint(1)
for s in range(w):
offset = s << (h - ph + 1)
for i in range(p):
left = a[i + offset]
right = a[i + offset + p]
a[i + offset] = left + right
a[i + offset + p] = Modint(
(a[0].mod() + left.val() - right.val()) * inow.val())
inow *= sum_ie[_bsf(~s)] # 修正
def convolution_mod(a: typing.List[Modint],
b: typing.List[Modint]) -> typing.List[Modint]:
n = len(a)
m = len(b)
if n == 0 or m == 0:
return []
# 畳み込みが小さい場合はナイーブな実装の方が速い
if min(n, m) <= 60:
if n < m:
n, m = m, n
a, b = b, a
ans = [Modint(0) for _ in range(n + m - 1)]
for i in range(n):
for j in range(m):
ans[i + j] += a[i] * b[j]
return ans
z = 1 << _ceil_pow2(n + m - 1) # 修正
#
# !!重要!!
# ac-library-python の _butterfly はリストを *直接* 変更します。
# この関数 (convolution_mod) も引数 a, b を直接変更します。
# 呼び出し元 (main) で a と b に同じリスト (f[t]) を渡すと、
# 予期せぬ動作になるため、呼び出し元でコピー (f[t][:]) を渡す必要があります。
#
# C++ の atcoder::convolution は引数を値渡し (コピー) で受け取るため、
# この問題は発生しません。
#
# ここでは、呼び出し元がコピーを渡すことを前提とせず、
# 安全のために関数内でコピーを作成します。
# (※ 提出コードでは main 側で f[t][:] とコピーを渡すため、
# 元の ac-library-python のコード (a.extend, b.extend) でも
# 問題なく動作します。)
#
# 元の ac-library-python の実装 (インプレース変更):
# a.extend([Modint(0)] * (z - n))
# _butterfly(a)
# b.extend([Modint(0)] * (z - m))
# _butterfly(b)
# for i in range(z):
# a[i] *= b[i]
# _butterfly_inv(a)
# a = a[:n + m - 1]
# ...
# return a
#
# 安全な(コピーを作成する)実装:
a_copy = a + [Modint(0)] * (z - n)
b_copy = b + [Modint(0)] * (z - m)
_butterfly(a_copy)
_butterfly(b_copy)
for i in range(z):
a_copy[i] *= b_copy[i]
_butterfly_inv(a_copy)
result = a_copy[:n + m - 1]
iz = Modint(z).inv()
for i in range(n + m - 1):
result[i] *= iz
return result
# -----------------------------------------------------------------
# メインロジック
# -----------------------------------------------------------------
def main():
MOD = 998244353
# Modint のコンテキストを設定
with ModContext(MOD):
# 高速入力を設定
input = sys.stdin.readline
n, m = map(int, input().split())
# f[t] のリスト
f: typing.List[typing.List[Modint]] = [[] for _ in range(n + 1)]
# f[0](x) = x (Modintオブジェクトとして初期化)
f[0] = [Modint(0), Modint(1)]
for t in range(n):
# f[t+1](x) = (f_t(x))^2 + f_t(x)
# (f_t(x))^2
# !!重要!!: convolution_mod はリストをインプレースで変更する
# 可能性があるため、必ずコピー [:] を渡す。
f_t_squared = convolution_mod(f[t][:], f[t][:])
len_t = len(f[t])
len_sq = len(f_t_squared)
# f[t] の方が長い場合、f_t_squared を拡張
if len_t > len_sq:
f_t_squared.extend([Modint(0)] * (len_t - len_sq))
# f[t+1] = f_t_squared + f[t]
for i in range(len_t):
f_t_squared[i] += f[t][i]
f[t + 1] = f_t_squared
# ([x^{2^n - m}](f_n(x) + 1)) m! (2^n - m)!
k = (1 << n) - m
ans = Modint(0)
if k < len(f[n]):
ans = f[n][k]
# f_n(x) + 1 の「+1」の部分 (x^0 の係数)
# k = 0 (つまり m = 2^n) の場合に +1 する
if k == 0:
ans += 1
# m! を掛ける
fact_m = Modint(1)
for i in range(1, m + 1):
fact_m *= i
# (2^n - m)! = k! を掛ける
fact_k = Modint(1)
for i in range(1, k + 1):
fact_k *= i
ans *= fact_m
ans *= fact_k
print(ans.val())
if __name__ == "__main__":
main()