結果
問題 | No.2917 二重木 |
ユーザー | PNJ |
提出日時 | 2024-10-04 22:29:51 |
言語 | PyPy3 (7.3.15) |
結果 |
WA
|
実行時間 | - |
コード長 | 10,554 bytes |
コンパイル時間 | 259 ms |
コンパイル使用メモリ | 82,224 KB |
実行使用メモリ | 94,300 KB |
最終ジャッジ日時 | 2024-10-04 22:30:25 |
合計ジャッジ時間 | 29,056 ms |
ジャッジサーバーID (参考情報) |
judge1 / judge3 |
(要ログイン)
テストケース
テストケース表示入力 | 結果 | 実行時間 実行使用メモリ |
---|---|---|
testcase_00 | AC | 39 ms
56,008 KB |
testcase_01 | AC | 40 ms
56,916 KB |
testcase_02 | AC | 41 ms
56,596 KB |
testcase_03 | AC | 39 ms
55,732 KB |
testcase_04 | AC | 39 ms
57,036 KB |
testcase_05 | AC | 38 ms
56,960 KB |
testcase_06 | WA | - |
testcase_07 | WA | - |
testcase_08 | WA | - |
testcase_09 | WA | - |
testcase_10 | WA | - |
testcase_11 | WA | - |
testcase_12 | AC | 46 ms
64,952 KB |
testcase_13 | AC | 49 ms
66,536 KB |
testcase_14 | WA | - |
testcase_15 | AC | 38 ms
56,960 KB |
testcase_16 | AC | 38 ms
56,772 KB |
testcase_17 | AC | 37 ms
55,784 KB |
testcase_18 | AC | 35 ms
56,616 KB |
testcase_19 | AC | 44 ms
56,716 KB |
testcase_20 | WA | - |
testcase_21 | WA | - |
testcase_22 | WA | - |
testcase_23 | AC | 40 ms
55,428 KB |
testcase_24 | WA | - |
testcase_25 | WA | - |
testcase_26 | WA | - |
testcase_27 | AC | 37 ms
55,468 KB |
testcase_28 | WA | - |
testcase_29 | TLE | - |
testcase_30 | TLE | - |
testcase_31 | TLE | - |
testcase_32 | TLE | - |
testcase_33 | TLE | - |
testcase_34 | WA | - |
ソースコード
N,P = map(int,input().split())if N == P:print(pow(N,N - 2,P))exit()mod = 998244353Mod,MOd,MOD = 1045430273,1051721729,1053818881n = Nfact = [1 for i in range(n+1)]for i in range(1,n+1):fact[i] = fact[i-1] * i % Pfact_inv = [1 for i in range(n+1)]fact_inv[-1] = pow(fact[-1],P-2,P)for i in range(n,0,-1):fact_inv[i-1] = fact_inv[i]*i % Pdef binom(n,r):res = fact[n] * (fact_inv[n - r] * fact_inv[r] % P) % Preturn resNTT_friend = [120586241,167772161,469762049,754974721,880803841,924844033,943718401,998244353,1045430273,1051721729,1053818881]NTT_dict = {}for i in range(len(NTT_friend)):NTT_dict[NTT_friend[i]] = iNTT_info = [[20,74066978],[25,17],[26,30],[24,362],[23,211],[21,44009197],[22,663003469],[23,31],[20,363],[20,330],[20,2789]]def popcount(n):c=(n&0x5555555555555555)+((n>>1)&0x5555555555555555)c=(c&0x3333333333333333)+((c>>2)&0x3333333333333333)c=(c&0x0f0f0f0f0f0f0f0f)+((c>>4)&0x0f0f0f0f0f0f0f0f)c=(c&0x00ff00ff00ff00ff)+((c>>8)&0x00ff00ff00ff00ff)c=(c&0x0000ffff0000ffff)+((c>>16)&0x0000ffff0000ffff)c=(c&0x00000000ffffffff)+((c>>32)&0x00000000ffffffff)return cdef topbit(n):h = n.bit_length()h -= 1return hdef prepared_fft(mod = 998244353):rank2 = NTT_info[NTT_dict[mod]][0]root,iroot = [0] * 30,[0] * 30rate2,irate2= [0] * 30,[0] * 30rate3,irate3= [0] * 30,[0] * 30root[rank2] = NTT_info[NTT_dict[mod]][1]iroot[rank2] = pow(root[rank2],mod - 2,mod)for i in range(rank2-1,-1,-1):root[i] = root[i+1] * root[i+1] % modiroot[i] = iroot[i+1] * iroot[i+1] % modprod,iprod = 1,1for i in range(rank2-1):rate2[i] = root[i + 2] * prod % modirate2[i] = iroot[i + 2] * iprod % modprod = prod * iroot[i + 2] % modiprod = iprod * root[i + 2] % modprod,iprod = 1,1for i in range(rank2-2):rate3[i] = root[i + 3] * prod % modirate3[i] = iroot[i + 3] * iprod % modprod = prod * iroot[i + 3] % modiprod = iprod * root[i + 3] % modreturn root,iroot,rate2,irate2,rate3,irate3root,iroot,rate2,irate2,rate3,irate3 = prepared_fft()def ntt(a):n = len(a)h = topbit(n)assert (n == 1 << h)le = 0while le < h:if h - le == 1:p = 1 << (h - le - 1)rot = 1for s in range(1 << le):offset = s << (h - le)for i in range(p):l = a[i + offset]r = a[i + offset + p] * rot % moda[i + offset] = (l + r) % moda[i + offset + p] = (l - r) % modrot = rot * rate2[topbit(~s & -~s)] % modle += 1else:p = 1 << (h - le - 2)rot,imag = 1,root[2]for s in range(1 << le):rot2 = rot * rot % modrot3 = rot2 * rot % modoffset = s << (h - le)for i in range(p):a0 = a[i + offset]a1 = a[i + offset + p] * rota2 = a[i + offset + p * 2] * rot2a3 = a[i + offset + p * 3] * rot3a1na3imag = (a1 - a3) % mod * imaga[i + offset] = (a0 + a2 + a1 + a3) % moda[i + offset + p] = (a0 + a2 - a1 - a3) % moda[i + offset + p * 2] = (a0 - a2 + a1na3imag) % moda[i + offset + p * 3] = (a0 - a2 - a1na3imag) % modrot = rot * rate3[topbit(~s & -~s)] % modle += 2def intt(a):n = len(a)h = topbit(n)assert (n == 1 << h)coef = pow(n,mod - 2,mod)for i in range(n):a[i] = a[i] * coef % modle = hwhile le:if le == 1:p = 1 << (h - le)irot = 1for s in range(1 << (le - 1)):offset = s << (h - le + 1)for i in range(p):l = a[i + offset]r = a[i + offset + p]a[i + offset] = (l + r) % moda[i + offset + p] = (l - r) * irot % modirot = irot * irate2[topbit(~s & -~s)] % modle -= 1else:p = 1 << (h - le)irot,iimag = 1,iroot[2]for s in range(1 << (le - 2)):irot2 = irot * irot % modirot3 = irot2 * irot % modoffset = s << (h - le + 2)for i in range(p):a0 = a[i + offset]a1 = a[i + offset + p]a2 = a[i + offset + p * 2]a3 = a[i + offset + p * 3]a2na3iimag = (a2 - a3) * iimag % moda[i + offset] = (a0 + a1 + a2 + a3) % moda[i + offset + p] = (a0 - a1 + a2na3iimag) * irot % moda[i + offset + p * 2] = (a0 + a1 - a2 - a3) * irot2 % moda[i + offset + p * 3] = (a0 - a1 - a2na3iimag) * irot3 % modirot *= irate3[topbit(~s & -~s)]irot %= modle -= 2def convolute_naive(a,b):res = [0] * (len(a) + len(b) - 1)for i in range(len(a)):for j in range(len(b)):res[i+j] = (res[i+j] + a[i] * b[j] % mod) % modreturn resdef convolute(a,b):s = a[:]t = b[:]n = len(s)m = len(t)if min(n,m) <= 60:return convolute_naive(s,t)le = 1while le < n + m - 1:le *= 2s += [0] * (le - n)t += [0] * (le - m)ntt(s)ntt(t)for i in range(le):s[i] = s[i] * t[i] % modintt(s)s = s[:n + m - 1]return sdef fps_inv(f,deg = -1):assert (f[0] != 0)if deg == -1:deg = len(f)res = [0] * degres[0] = pow(f[0],mod-2,mod)d = 1while d < deg:a = [0] * (d << 1)tmp = min(len(f),d << 1)a[:tmp] = f[:tmp]b = [0] * (d << 1)b[:d] = res[:d]ntt(a)ntt(b)for i in range(d << 1):a[i] = a[i] * b[i] % modintt(a)a[:d] = [0] * dntt(a)for i in range(d << 1):a[i] = a[i] * b[i] % modintt(a)for j in range(d,min(d << 1,deg)):if a[j]:res[j] = mod - a[j]else:res[j] = 0d <<= 1return resdef fps_div(f,g):n,m = len(f),len(g)if n < m:return [],frev_f = f[:]rev_f = rev_f[::-1]rev_g = g[:]rev_g = rev_g[::-1]rev_q = convolute(rev_f,fps_inv(rev_g,n-m+1))[:n-m+1]q = rev_q[:]q = q[::-1]p = convolute(g,q)r = f[:]for i in range(min(len(p),len(r))):r[i] -= p[i]r[i] %= modwhile len(r):if r[-1] != 0:breakr.pop()return q,rdef fps_add(f,g):n = max(len(f),len(g))res = [0] * nfor i in range(len(f)):res[i] = f[i]for i in range(len(g)):res[i] = (res[i] + g[i]) % modreturn resdef fps_diff(f):if len(f) <= 1:return [0]res = []for i in range(1,len(f)):res.append(i * f[i] % mod)return resdef fps_integrate(f):n = len(f)res = [0] * (n + 1)for i in range(n):res[i+1] = pow(i + 1,mod-2,mod) * f[i] % modreturn resdef fps_log(f,deg = -1):assert (f[0] == 1)if deg == -1:deg = len(f)res = convolute(fps_diff(f),fps_inv(f,deg))res = fps_integrate(res)return res[:deg]def fps_exp(f,deg = -1):assert (f[0] == 0)if deg == -1:deg = len(f)res = [1,0]if len(f) > 1:res[1] = f[1]g = [1]p = []q = [1,1]m = 2while m < deg:y = res + [0]*mntt(y)p = q[:]z = [y[i] * p[i] for i in range(len(p))]intt(z)z[:m >> 1] = [0] * (m >> 1)ntt(z)for i in range(len(p)):z[i] = z[i] * (-p[i]) % modintt(z)g[m >> 1:] = z[m >> 1:]q = g + [0] * mntt(q)tmp = min(len(f),m)x = f[:tmp] + [0] * (m - tmp)x = fps_diff(x)x.append(0)ntt(x)for i in range(len(x)):x[i] = x[i] * y[i] % modintt(x)for i in range(len(res)):if i == 0:continuex[i-1] -= res[i] * i % modx += [0] * mfor i in range(m-1):x[m+i],x[i] = x[i],0ntt(x)for i in range(len(q)):x[i] = x[i] * q[i] % modintt(x)x.pop()x = fps_integrate(x)x[:m] = [0] * mfor i in range(m,min(len(f),m << 1)):x[i] += f[i]ntt(x)for i in range(len(y)):x[i] = x[i] * y[i] % modintt(x)res[m:] = x[m:]m <<= 1return res[:deg]def fps_pow(f,k,deg = -1):if deg == -1:deg = len(f)if k == 0:return [1] + [0] * (deg - 1)while len(f) < deg:f.append(0)p = 0while p < deg:if f[p]:breakp += 1if p * k >= deg:return [0] * dega = f[p]g = [0 for _ in range(deg - p)]a_inv = pow(a,mod-2,mod)for i in range(deg - p):g[i] = f[i + p] * a_inv % modg = fps_log(g)for i in range(deg-p):g[i] = g[i] * k % modg = fps_exp(g)a = pow(a,k,mod)res = [0] * degfor i in range(deg):j = i + p * kif j >= deg:breakres[j] = g[i] * a % modreturn resdef mod_inv(a,mod):if mod == 1:return 0a %= modb,s,t = mod,1,0while True:if a == 1:return st -= (b // a) * sb %= aif b == 1:return t + mods -= (a // b) * ta %= bdef gcd_inv(a,mod):a %= modb,s,t = mod,1,0while True:if a == 0:return (b,t + mod)t -= (b // a) * sb %= aif b == 0:return (a,s)s -= (a // b) * ta %= b# (0,0)のとき存在しない.def garner(Rem,Mod):assert (len(Rem) == len(Mod))r,m = 0,1for i in range(len(Rem)):assert (Mod[i])Rem[i] %= Mod[i]m1,r1 = Mod[i],Rem[i]if m < m1:m,m1,r,r1 = m1,m,r1,rif m % m1 == 0:if r % m1 != r1:return (0,0)g,im = gcd_inv(m,m1)y = abs(r1 - r)if y % g:return (0,0)u1 = m1 // gy = y // g % u1if (r > r1 and y != 0):y = u1 - yx = y * im % u1r += x * mm *= u1return (r,m)# Modの中身が互いに素じゃないとダメdef Garner(Rem,Mod,mod):assert (len(Rem) == len(Mod))Rem.append(0)Mod.append(mod)n = len(Mod)coffs = [1] * nconstants = [0] * nfor i in range(n - 1):v = (Rem[i] - constants[i]) * mod_inv(coffs[i],Mod[i]) % Mod[i]for j in range(i + 1,n):constants[j] = (constants[j] + coffs[j] * v) % Mod[j]coffs[j] = (coffs[j] * Mod[i]) % Mod[j]return constants[-1]f = [1]for i in range(1,N):c = pow(i + 1,i - 1,P) * fact_inv[i] % Pf.append(c)ans = 0g = [0] * Ng[0] = 1for n in range(1,N + 1):res = binom(N,n) * pow(n,n - 2,P) % Proot,iroot,rate2,irate2,rate3,irate3 = prepared_fft(mod)h = convolute(f,g)[:N]root,iroot,rate2,irate2,rate3,irate3 = prepared_fft(Mod)hh = convolute(f,g)[:N]root,iroot,rate2,irate2,rate3,irate3 = prepared_fft(MOd)hhh = convolute(f,g)[:N]root,iroot,rate2,irate2,rate3,irate3 = prepared_fft(MOD)hhhh = convolute(f,g)[:N]for i in range(N):g[i] = Garner([h[i],hh[i],hhh[i],hhhh[i]],[mod,Mod,MOd,MOD],P)res = res * fact[N - n] % Pres = res * g[N - n] % Pans += resans %= Pprint(ans)