結果
| 問題 | No.3315 FPS Game |
|---|---|
| コンテスト | |
| ユーザー | kidodesu |
| 提出日時 | 2025-04-13 11:57:51 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | AC |
| 実行時間 | 1,047 ms / 3,250 ms |
| コード長 | 2,751 bytes |
| コンパイル時間 | 508 ms |
| コンパイル使用メモリ | 82,228 KB |
| 実行使用メモリ | 142,368 KB |
| 最終ジャッジ日時 | 2025-05-02 20:02:27 |
| 合計ジャッジ時間 | 13,991 ms |
| ジャッジサーバーID (参考情報) | judge3 / judge5 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | AC * 25 |
ソースコード
mod = 998244353
N = 10 ** 6 + 2  # table size: factorials are available for 0 .. N-1

# Forward pass: F[i] = i! (mod mod).
F = [1] * N
for i in range(2, N):
    F[i] = i * F[i - 1] % mod

# Backward pass: E[i] = (i!)^(-1) (mod mod).  A single modular inverse is
# taken (of the largest factorial, via pow(x, -1, m)); every smaller one
# follows from (i-1)!^(-1) = i * (i!)^(-1).
E = [1] * N
E[N - 1] = pow(F[N - 1], -1, mod)
for i in reversed(range(1, N)):
    E[i - 1] = i * E[i] % mod
def comb(a, b):
    """Return the binomial coefficient C(a, b) modulo `mod`.

    Uses the precomputed factorial table F and inverse-factorial table E,
    so `a` must be a valid index into them.  Returns 0 unless 0 <= b <= a.
    """
    if 0 <= b <= a:
        return F[a] * E[b] * E[a - b] % mod
    return 0
# NTT parameters for the modulus K * 2**M + 1 = 119 * 2**23 + 1 = 998244353.
K = 119  # odd cofactor of mod - 1 (not referenced elsewhere in the visible code)
M = 23   # NTT supports transform lengths up to 2**M
W = 31   # NTT needs this to be a primitive 2**M-th root of unity modulo mod
class NTT:
    """Number-theoretic transform (NTT) and polynomial multiplication.

    All arithmetic is modulo the module-level `mod`.  The module-level `W`
    must be a primitive 2**M-th root of unity modulo `mod`, so transforms
    of length up to 2**M are supported.
    """

    def __init__(self):
        # ws[j] = W ** (2 ** (M - j)) mod `mod`, for j = 0 .. M.  When W has
        # order 2**M this is a primitive 2**j-th root of unity, i.e. the
        # root needed by a transform of length 2**j.
        self.ws = [pow(W, 2**i, mod) for i in range(M, -1, -1)]
        # iws[j] is the modular inverse of ws[j] (valid since mod is prime).
        self.iws = [pow(w, mod - 2, mod) for w in self.ws]

    def polymul_ntt(self, f, g):
        """Return the coefficients of the product f * g, reduced modulo `mod`.

        f and g are coefficient lists, lowest degree first; the caller's
        lists are not modified.  The result has length
        len(f) + len(g) - 1, or is empty when either input is empty
        (the zero polynomial).

        Raises ValueError if the product needs a transform longer than
        2**M, which the precomputed roots cannot handle.
        """
        nf = len(f)
        ng = len(g)
        if nf == 0 or ng == 0:
            # Without this, m <= 0 below and the slice f[:m] would return
            # a bogus list of zeros.
            return []
        m = nf + ng - 1
        n = 1 << (m - 1).bit_length()  # smallest power of two >= m
        if n > 1 << M:
            # ntt() would otherwise slice past the end of self.ws and
            # silently use the wrong roots.
            raise ValueError(
                "product length %d exceeds maximum NTT length %d" % (m, 1 << M)
            )
        f = [x % mod for x in f] + [0] * (n - nf)
        g = [x % mod for x in g] + [0] * (n - ng)
        self.ntt(f)
        self.ntt(g)
        for i in range(n):
            f[i] = f[i] * g[i] % mod
        self.intt(f)
        return f[:m]

    def ntt(self, A):
        """Transform A in place (forward NTT).

        len(A) must be a power of two no larger than 2**M.  Uses
        decimation-in-frequency butterflies with no bit-reversal step, so
        the output order is only meaningful as input to `intt`; pointwise
        products (as in `polymul_ntt`) do not depend on that order.
        """
        n = len(A)
        if n <= 1:
            return
        k = n.bit_length() - 1  # n == 2**k
        r = 1 << (k - 1)  # half-width of the butterfly blocks in this stage
        for w in self.ws[k:0:-1]:
            # w is a primitive (2*r)-th root of unity for this stage.
            for l in range(0, n, 2 * r):
                wi = 1
                for i in range(r):
                    A[l+i], A[l+i+r] = (A[l+i]+A[l+i+r]) % mod, (A[l+i]-A[l+i+r])*wi % mod
                    wi = wi * w % mod
            r = r // 2

    def intt(self, A):
        """Invert `ntt` in place, including the final division by len(A).

        Uses decimation-in-time butterflies with the inverse roots, taking
        input in the order `ntt` produces it.
        """
        n = len(A)
        if n <= 1:
            return
        k = n.bit_length() - 1  # n == 2**k, computed the same way as in ntt
        r = 1
        for w in self.iws[1:k+1]:
            for l in range(0, n, 2 * r):
                wi = 1
                for i in range(r):
                    A[l+i], A[l+i+r] = (A[l+i]+A[l+i+r]*wi) % mod, (A[l+i]-A[l+i+r]*wi) % mod
                    wi = wi * w % mod
            r = r * 2
        ni = pow(n, mod - 2, mod)  # 1/n modulo mod
        for i in range(n):
            A[i] = A[i] * ni % mod
ntt = NTT()
from collections import deque
import sys

# Fast line reader: input() is a bottleneck when reading up to ~10**6 lines.
readline = sys.stdin.readline

n, s, t = map(int, readline().split())
# s and t are used below as EDGE indices (edge i is node n+i), so they
# presumably number the input edges -- confirm against the problem statement.
s, t = s - 1, t - 1

# Subdivide the tree: original vertex v keeps index v, and the i-th input
# edge becomes an extra node n+i adjacent to both of its endpoints.
# A distinct name is used so the global N (factorial-table size) survives.
num_nodes = 2 * n - 1
node = [[] for _ in range(num_nodes)]
V = [0] * num_nodes  # degree of each node in the subdivided graph
for i in range(n - 1):
    u, v = [int(x) - 1 for x in readline().split()]
    node[u].append(n + i)
    node[v].append(n + i)
    node[n + i].append(u)
    node[n + i].append(v)
    V[u] += 1
    V[v] += 1
    V[n + i] += 2

# BFS distances from the node of edge s.
D = [-1] * num_nodes
D[n + s] = 0
dq = deque([n + s])
while dq:
    now = dq.popleft()
    for nxt in node[now]:
        if D[nxt] == -1:
            D[nxt] = D[now] + 1
            dq.append(nxt)

# Walk from the node of edge t back to the node of edge s, always stepping
# to the neighbour closer to s (the BFS parent, since this is a tree).
#   cnt: edge nodes visited (edge t counted, edge s not).
#   A:   a leading 0, then for each original vertex on the way with degree
#        > 2, its degree minus 2 (the number of its edges off the path).
now = n + t
A = [0]
cnt = 0
while now != n + s:
    if now < n:
        if V[now] > 2:
            A.append(V[now] - 2)
    else:
        cnt += 1
    for nxt in node[now]:
        if D[nxt] < D[now]:
            now = nxt
            break
def merge(l, r):
    """Multiply together the polynomials built from A[l], ..., A[r].

    Each entry a contributes sum_{i=0}^{a} C(a, i) * i! * x^i (so a == 0
    contributes the constant 1).  Coefficients are reduced modulo `mod`,
    and the index range is split in half at each level so the NTT
    multiplies partial products of similar numbers of factors.
    """
    if l != r:
        half = (l + r) // 2
        left = merge(l, half)
        right = merge(half + 1, r)
        return ntt.polymul_ntt(left, right)
    a = A[l]
    # C(a, i) * i! counts the ordered selections of i items out of a.
    return [comb(a, i) * F[i] % mod for i in range(a + 1)]
# Output: `cnt` zeros, then the coefficients of the merged product, then
# zeros bringing the total to n entries (no padding if already >= n).
b = merge(0, len(A) - 1)
tail = n - cnt - len(b)
print(*([0] * cnt + b + [0] * tail))
kidodesu