結果

問題 No.1504 ヌメロニム
ユーザー lam6er
提出日時 2025-03-31 18:01:29
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 4,004 bytes
コンパイル時間 188 ms
コンパイル使用メモリ 82,060 KB
実行使用メモリ 279,448 KB
最終ジャッジ日時 2025-03-31 18:02:19
合計ジャッジ時間 5,648 ms
ジャッジサーバーID
(参考情報)
judge3 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 23 WA * 1 TLE * 1 -- * 36
権限があれば一括ダウンロードができます

ソースコード

diff #
プレゼンテーションモードにする

import sys
MOD = 998244353
primitive_root = 3
def main():
sys.setrecursionlimit(1 << 25)
N = int(sys.stdin.readline())
S = sys.stdin.readline().strip()
# Collect i's and n's positions
i_pos = []
n_pos = []
for i, c in enumerate(S):
if c == 'i':
i_pos.append(i)
elif c == 'n':
n_pos.append(i)
# Compute cnt[m]: number of i-n pairs with exactly m characters between
cnt = {}
max_m = 0
for i in i_pos:
for n in n_pos:
if i >= n:
continue
m = n - i - 1
cnt[m] = cnt.get(m, 0) + 1
if m > max_m:
max_m = m
# Handle the case where cnt is empty (no i-n pairs)
if not cnt:
print(0)
return
# Precompute factorial and inv_factorial
size = max_m if max_m >= 0 else 0
fact = [1] * (size + 1)
for i in range(1, size + 1):
fact[i] = fact[i-1] * i % MOD
inv_fact = [1] * (size + 1)
inv_fact[size] = pow(fact[size], MOD-2, MOD)
for i in range(size-1, -1, -1):
inv_fact[i] = inv_fact[i+1] * (i+1) % MOD
# Prepare array A and B
len_A = max_m + 1
A = [0] * len_A
for m in cnt:
if m >= len_A:
continue # theoretically not possible
A[m] = cnt[m] * fact[m] % MOD
B = [0] * (max_m + 1)
for m in range(max_m + 1):
if m <= size:
B[m] = inv_fact[m]
else:
B[m] = 0 # inv_fact[m] is zero beyond size, but since we precomputed up to size, this is redundant
# Create B_rev by reversing B
B_rev = B[::-1]
# Convolve A and B_rev
len_A = len(A)
len_B_rev = len(B_rev)
len_conv = len_A + len_B_rev - 1
n = 1
while n < len_conv:
n <<= 1
# Pad with zeros
A_pad = A + [0] * (n - len_A)
B_rev_pad = B_rev + [0] * (n - len_B_rev)
# NTT functions (taken from a standard implementation)
def ntt_transform(a, invert):
n = len(a)
j = 0
for i in range(1, n):
bit = n >> 1
while j >= bit:
j -= bit
bit >>= 1
j += bit
if i < j:
a[i], a[j] = a[j], a[i]
log_n = (n).bit_length() - 1
for length in range(1, log_n + 1):
clen = 1 << length
ang = pow(primitive_root, (MOD - 1) // clen, MOD)
if invert:
ang = pow(ang, MOD - 2, MOD)
for i in range(0, n, clen):
w = 1
for j_ in range(clen // 2):
u = a[i + j_]
v = a[i + j_ + clen // 2] * w % MOD
a[i + j_] = (u + v) % MOD
a[i + j_ + clen // 2] = (u - v) % MOD
w = w * ang % MOD
if invert:
inv_n = pow(n, MOD - 2, MOD)
for i in range(n):
a[i] = a[i] * inv_n % MOD
return a
# Perform NTT on A and B_rev
A_ntt = A_pad.copy()
ntt_transform(A_ntt, False)
B_rev_ntt = B_rev_pad.copy()
ntt_transform(B_rev_ntt, False)
# Multiply point-wise
C_ntt = [ (a * b) % MOD for a, b in zip(A_ntt, B_rev_ntt) ]
# Inverse NTT
ntt_transform(C_ntt, True)
# Extract the first len_conv elements
conv = C_ntt[:len_conv]
# Now compute X_k for each k
X = [0] * N
len_B = len(B)
for k in range(N - 1):
if k > max_m:
X_k_val = 0
else:
# The position in convolution is k + (len_B - 1)
pos = k + (len_B - 1)
if pos >= len(conv):
c = 0
else:
c = conv[pos] % MOD
X_k_val = c * inv_fact[k] % MOD
X[k] = X_k_val
# Compute XOR for X_0 to X_{N-2}
xor_result = 0
for k in range(N - 1):
xor_result ^= X[k]
print(xor_result % MOD)
if __name__ == '__main__':
main()
הההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההה
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
0