結果

問題 No.3315 FPS Game
コンテスト
ユーザー kidodesu
提出日時 2025-04-13 12:01:25
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 1,520 ms / 3,250 ms
コード長 2,838 bytes
コンパイル時間 581 ms
コンパイル使用メモリ 82,356 KB
実行使用メモリ 153,204 KB
最終ジャッジ日時 2025-06-19 05:21:36
合計ジャッジ時間 21,095 ms
ジャッジサーバーID
(参考情報)
judge5 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 25
権限があれば一括ダウンロードができます

ソースコード

diff #

mod = 998244353
N = 10 ** 6 + 2
F = [1] * N
E = [1] * N
for i in range(2, N):
    F[i] = F[i-1]*i%mod
E[-1] = pow(F[-1], -1, mod)
for i in range(N-1, 0, -1):
    E[i-1] = E[i]*i%mod

def comb(a, b):
    if b < 0:
        return 0
    if a < b:
        return 0
    return F[a] * E[b] * E[a-b] % mod


K,M,W = 119, 23, 31
class NTT:
    def __init__(self):
        self.ws = [pow(W,2**i,mod) for i in range(M,-1,-1)]
        self.iws = [pow(w,mod-2,mod) for w in self.ws]
    def polymul_ntt(self,f,g):
        nf = len(f)
        ng = len(g)
        m = nf+ng-1
        n = 2**(m-1).bit_length()
        f = [x % mod for x in f]+[0]*(n-nf)
        g = [x % mod for x in g]+[0]*(n-ng)
        self.ntt(f)
        self.ntt(g)
        for i in range(n):
            f[i] = f[i]*g[i]%mod
        self.intt(f)
        return f[:m]
    def ntt(self, A):
        if len(A) == 1: return
        n = len(A)
        k = n.bit_length()-1
        r = 1<<(k-1)
        for w in self.ws[k:0:-1]:
            for l in range(0,n,2*r):
                wi = 1
                for i in range(r):
                    A[l+i],A[l+i+r] = (A[l+i]+A[l+i+r])%mod,(A[l+i]-A[l+i+r])*wi%mod
                    wi = wi*w%mod
            r = r//2
    def intt(self, A):
        if len(A) == 1: return
        n = len(A)
        k = (n-1).bit_length()
        r = 1
        for w in self.iws[1:k+1]:
            for l in range(0,n,2*r):
                wi = 1
                for i in range(r):
                    A[l+i],A[l+i+r] = (A[l+i]+A[l+i+r]*wi)%mod,(A[l+i]-A[l+i+r]*wi)%mod
                    wi = wi*w%mod
            r = r*2
        ni = pow(n, mod-2, mod)
        for i in range(n):
            A[i] = A[i]*ni%mod
ntt = NTT()

from collections import deque

n, s, t = map(int, input().split())
s, t = s-1, t-1
N = 2*n-1
node = [[] for _ in range(N)]

V = [0] * N
for i in range(n-1):
    u, v = [int(x)-1 for x in input().split()]
    node[u].append(n+i)
    node[v].append(n+i)
    node[n+i].append(u)
    node[n+i].append(v)
    V[u] += 1
    V[v] += 1
    V[n+i] += 2

D = [-1] * N
D[n+s] = 0
dq = deque([n+s])
while dq:
    now = dq.popleft()
    for nxt in node[now]:
        if D[nxt] == -1:
            D[nxt] = D[now] + 1
            dq.append(nxt)

now = n+t
A = [0]
cnt = 0
while now != n+s:
    if 0 <= now < n and V[now] > 2:
        A.append(V[now]-2)
    if n <= now:
        cnt += 1
    for nxt in node[now]:
        if D[nxt] < D[now]:
            now = nxt
            break

from heapq import *
hq = []
for i in range(len(A)):
    a = A[i]
    heappush(hq, (a+1, [(comb(a, i) * F[i]) % mod for i in range(a+1)]))

while len(hq) > 1:
    a0, b0 = heappop(hq)
    a1, b1 = heappop(hq)
    a2 = a0+a1-1
    b2 = ntt.polymul_ntt(b0, b1)
    heappush(hq, (a2, b2))

b = list(heappop(hq)[1])
print(* [0] * cnt + b + [0] * (n - len(b) - cnt))

0