結果

問題 No.1302 Random Tree Score
ユーザー sasa8uyauyasasa8uyauya
提出日時 2024-09-19 21:08:09
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 2,282 ms / 3,000 ms
コード長 1,836 bytes
コンパイル時間 196 ms
コンパイル使用メモリ 82,048 KB
実行使用メモリ 250,908 KB
最終ジャッジ日時 2024-09-19 21:08:54
合計ジャッジ時間 39,756 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 2,282 ms
248,616 KB
testcase_01 AC 2,184 ms
249,004 KB
testcase_02 AC 2,153 ms
250,780 KB
testcase_03 AC 2,105 ms
250,620 KB
testcase_04 AC 2,119 ms
250,528 KB
testcase_05 AC 2,173 ms
250,876 KB
testcase_06 AC 2,127 ms
250,432 KB
testcase_07 AC 2,152 ms
250,908 KB
testcase_08 AC 2,136 ms
250,620 KB
testcase_09 AC 2,162 ms
250,620 KB
testcase_10 AC 2,192 ms
250,744 KB
testcase_11 AC 2,133 ms
249,620 KB
testcase_12 AC 2,156 ms
250,876 KB
testcase_13 AC 2,169 ms
249,300 KB
testcase_14 AC 2,132 ms
250,488 KB
testcase_15 AC 2,120 ms
250,616 KB
testcase_16 AC 2,160 ms
248,620 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

mod = 998244353
R = 3
Rinv = 332748118
W = [pow(R, (mod-1)>>i, mod) for i in range(24)]
Winv = [pow(Rinv, (mod-1)>>i, mod) for i in range(24)]


def fft(k, f):
	for l in range(k, 0, -1):
		d = 1<<l-1
		U = [1]
		for i in range(d):
			U.append(U[-1]*W[l]%mod)
		for i in range(1<<k-l):
			for j in range(d):
				s = i*2*d+j
				f[s], f[s+d] = (f[s]+f[s+d])%mod, U[j]*(f[s]-f[s+d])%mod


def fftinv(k, f):
	for l in range(1, k+1):
		d = 1<<l-1
		for i in range(1<<k-l):
			u = 1
			for j in range(i*2*d, (i*2+1)*d):
				f[j+d] *= u
				f[j], f[j+d] = (f[j]+f[j+d])%mod, (f[j]-f[j+d])%mod
				u *= Winv[l]
				u %= mod


def convolution(a, b):
	le = len(a)+len(b)-1
	k = le.bit_length()
	n = 1<<k
	a = a+[0]*(n-len(a))
	b = b+[0]*(n-len(b))
	fft(k, a)
	fft(k, b)
	for i in range(n):
		a[i] *= b[i]
		a[i] %= mod
	fftinv(k, a)
	ninv = pow(n, mod-2, mod)
	for i in range(le):
		a[i] *= ninv
		a[i] %= mod
	return a[:le]


def FPSinv(H):
  I=[pow(H[0],M-2,M)]
  l=1
  while l<len(H):
    I+=[0]*l
    nI=convolution(H[:l*2],convolution(I,I)[:l*2])[:l*2]
    for i in range(l*2):
      nI[i]=(2*I[i]-nI[i])%M
    I=nI
    l*=2
  return I[:len(H)]

def FPSlog(H):
  H1=[H[i]*i%M for i in range(1,len(H))]+[0]
  H2=FPSinv(H)
  I=convolution(H1,H2)
  I=[0]+[I[i]*fb[i+1]*fa[i]%M for i in range(len(H)-1)]
  return I

def FPSexp(H):
  I=[1]
  l=1
  while l<len(H):
    I+=[0]*l
    I2=FPSlog(I)[:l*2]
    I3=H[:l*2]
    I3[0]+=1
    for i in range(l*2):
      I3[i]-=I2[i]
    nI=convolution(I,I3)[:l*2]
    I=nI
    l*=2
  return I[:len(H)]

n=int(input())

L=1<<17
M=998244353

fa=[1,1]
fb=[1,1]
for i in range(2,L+1):
  fa+=[fa[-1]*i%M]
  fb+=[fb[-1]*(M//i)*fb[M%i]*fa[M%i-1]*(-1)%M]

q1=[(i+1)*fb[i]%M for i in range(L)]
q2=FPSlog(q1)
for i in range(L):
  q2[i]*=n
  q2[i]%=M
q3=FPSexp(q2)

a=q3[n-2]
print(a*fa[n-2]*pow(pow(n,n-2,M),M-2,M)%M)
0