結果

問題 No.2062 Sum of Subset mod 999630629
ユーザー 👑 rin204
提出日時 2022-10-17 23:18:17
言語 PyPy3 (7.3.15)
結果
RE  
実行時間 -
コード長 14,056 bytes
コンパイル時間 219 ms
コンパイル使用メモリ 81,904 KB
実行使用メモリ 93,524 KB
最終ジャッジ日時 2024-06-28 12:12:36
合計ジャッジ時間 4,333 ms
ジャッジサーバーID
(参考情報)
judge2 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 10 RE * 19
権限があれば一括ダウンロードができます

ソースコード

diff #
プレゼンテーションモードにする

import sys

MOD = 998244353   # modulus of the final answer
MOD2 = 999630629  # modulus applied to each individual subset sum
n = int(input())
A = list(map(int, input().split()))
# Each element belongs to exactly 2^(n-1) of the 2^n subsets, so the sum of
# all *unreduced* subset sums is sum(A) * 2^(n-1).
times = pow(2, n - 1, MOD)
total = sum(A)
ans = total * times % MOD
# x = how far the full set's sum exceeds MOD2.  If it is negative, no subset
# sum reaches MOD2, the "mod MOD2" never changes anything and ans is final.
x = total - MOD2
if x < 0:
    print(ans)
    # exit() comes from the site module and is meant for the interactive
    # shell; sys.exit() is the supported way to stop a program.
    sys.exit()
class FFT:
    """Number-theoretic transform (NTT) convolution modulo ``FFT.MOD``.

    The structure (butterfly / butterfly_inv / rate2 / rate3 /
    primitive_root_constexpr) looks like a port of the AtCoder Library
    convolution -- NOTE(review): confirm provenance.  The modulus is stored on
    the class, not the instance, so every instance uses the modulus passed to
    the most recent constructor call.
    """

    def __init__(self, MOD=998244353):
        FFT.MOD = MOD
        self.make_info(MOD)

    def make_info(self, MOD):
        """Precompute the root-of-unity tables used by the butterflies.

        rank2 is the number of trailing zero bits of MOD - 1, i.e. the largest
        power-of-two transform length is 2**rank2.  root[i] is a 2**i-th root
        of unity (root[rank2] = g**((MOD-1) >> rank2), then repeated
        squaring) and iroot[i] is its inverse.  rate2/irate2 and rate3/irate3
        are the per-block twiddle step factors of the radix-2 and radix-4
        stages; index 0 is unused because the butterflies look them up with a
        1-based index.
        """
        g = self.primitive_root_constexpr()
        m = MOD - 1
        rank2 = (m & -m).bit_length() - 1
        root = [0] * (rank2 + 1)
        iroot = [0] * (rank2 + 1)
        rate2 = [0] * (rank2 + 1)
        irate2 = [0] * (rank2 + 1)
        rate3 = [0] * (rank2)
        irate3 = [0] * (rank2)
        root[rank2] = pow(g, (MOD - 1) >> rank2, MOD)
        iroot[rank2] = pow(root[rank2], MOD - 2, MOD)
        for i in range(rank2 - 1, -1, -1):
            root[i] = root[i + 1] * root[i + 1] % MOD
            iroot[i] = iroot[i + 1] * iroot[i + 1] % MOD
        prod = 1
        iprod = 1
        for i in range(1, rank2):
            rate2[i] = root[i + 1] * prod % MOD
            irate2[i] = iroot[i + 1] * iprod % MOD
            prod = prod * iroot[i + 1] % MOD
            iprod = iprod * root[i + 1] % MOD
        prod = 1
        iprod = 1
        for i in range(1, rank2 - 1):
            rate3[i] = root[i + 2] * prod % MOD
            irate3[i] = iroot[i + 2] * iprod % MOD
            prod = prod * iroot[i + 2] % MOD
            iprod = iprod * root[i + 2] % MOD
        # rate2[1] == root[2]: a primitive 4th root of unity (and its inverse),
        # used by the radix-4 butterflies.
        self.IMAG = rate2[1]
        self.IIMAG = irate2[1]
        self.rate2 = rate2
        self.irate2 = irate2
        self.rate3 = rate3
        self.irate3 = irate3

    def primitive_root_constexpr(self):
        """Return a primitive root modulo FFT.MOD.

        Known primes are answered from a table.  Otherwise the distinct prime
        factors of MOD - 1 are collected (2 first, then odd factors by trial
        division), and the smallest g >= 2 with g**((MOD-1)/p) != 1 for every
        such prime p is returned.  ``divs`` has room for 20 distinct prime
        factors.
        """
        if FFT.MOD == 998244353:
            return 3
        elif FFT.MOD == 200003:
            return 2
        elif FFT.MOD == 167772161:
            return 3
        elif FFT.MOD == 469762049:
            return 3
        elif FFT.MOD == 754974721:
            return 11
        divs = [0] * 20
        divs[0] = 2
        cnt = 1
        x = (FFT.MOD - 1) // 2
        while x % 2 == 0:
            x //= 2
        i = 3
        while i * i <= x:
            if x % i == 0:
                divs[cnt] = i
                cnt += 1
                while x % i == 0:
                    x //= i
            i += 2
        if x > 1:
            divs[cnt] = x
            cnt += 1
        g = 2
        while 1:
            ok = True
            for i in range(cnt):
                if pow(g, (FFT.MOD - 1) // divs[i], FFT.MOD) == 1:
                    ok = False
                    break
            if ok:
                return g
            g += 1

    def butterfly(self, A):
        """In-place forward NTT of A (values reduced mod FFT.MOD).

        len(A) is expected to be a power of two; convolve pads to one.  Two
        layers are processed per pass (radix-4).  When the layer count h is
        odd, a single radix-2 layer finishes the transform.
        """
        n = len(A)
        h = (n - 1).bit_length()
        le = 0
        while le < h:
            if h - le == 1:
                # Exactly one layer left: radix-2 step.
                p = 1 << (h - le - 1)
                rot = 1
                for s in range(1 << le):
                    offset = s << (h - le)
                    for i in range(p):
                        l = A[i + offset]
                        r = A[i + offset + p] * rot
                        A[i + offset] = (l + r) % FFT.MOD
                        A[i + offset + p] = (l - r) % FFT.MOD
                    # (~s & -~s) isolates the lowest zero bit of s; its
                    # bit_length() is the 1-based index of the next twiddle step.
                    rot *= self.rate2[(~s & -~s).bit_length()]
                    rot %= FFT.MOD
                le += 1
            else:
                # Two layers at once: radix-4 step.
                p = 1 << (h - le - 2)
                rot = 1
                for s in range(1 << le):
                    rot2 = rot * rot % FFT.MOD
                    rot3 = rot2 * rot % FFT.MOD
                    offset = s << (h - le)
                    for i in range(p):
                        a0 = A[i + offset]
                        a1 = A[i + offset + p] * rot
                        a2 = A[i + offset + p * 2] * rot2
                        a3 = A[i + offset + p * 3] * rot3
                        a1na3imag = (a1 - a3) % FFT.MOD * self.IMAG
                        A[i + offset] = (a0 + a2 + a1 + a3) % FFT.MOD
                        A[i + offset + p] = (a0 + a2 - a1 - a3) % FFT.MOD
                        A[i + offset + p * 2] = (a0 - a2 + a1na3imag) % FFT.MOD
                        A[i + offset + p * 3] = (a0 - a2 - a1na3imag) % FFT.MOD
                    rot *= self.rate3[(~s & -~s).bit_length()]
                    rot %= FFT.MOD
                le += 2

    def butterfly_inv(self, A):
        """In-place inverse of ``butterfly``, without the 1/len(A) scaling.

        The caller (convolve) multiplies by the inverse of the length
        afterwards.  Layers are undone two at a time (radix-4), with a final
        radix-2 layer when the layer count is odd.
        """
        n = len(A)
        h = (n - 1).bit_length()
        le = h
        while le:
            if le == 1:
                p = 1 << (h - le)
                irot = 1
                for s in range(1 << (le - 1)):
                    offset = s << (h - le + 1)
                    for i in range(p):
                        l = A[i + offset]
                        r = A[i + offset + p]
                        A[i + offset] = (l + r) % FFT.MOD
                        A[i + offset + p] = (l - r) * irot % FFT.MOD
                    irot *= self.irate2[(~s & -~s).bit_length()]
                    irot %= FFT.MOD
                le -= 1
            else:
                p = 1 << (h - le)
                irot = 1
                for s in range(1 << (le - 2)):
                    irot2 = irot * irot % FFT.MOD
                    irot3 = irot2 * irot % FFT.MOD
                    offset = s << (h - le + 2)
                    for i in range(p):
                        a0 = A[i + offset]
                        a1 = A[i + offset + p]
                        a2 = A[i + offset + p * 2]
                        a3 = A[i + offset + p * 3]
                        a2na3iimag = (a2 - a3) * self.IIMAG % FFT.MOD
                        A[i + offset] = (a0 + a1 + a2 + a3) % FFT.MOD
                        A[i + offset + p] = (a0 - a1 + a2na3iimag) * irot % FFT.MOD
                        A[i + offset + p * 2] = (a0 + a1 - a2 - a3) * irot2 % FFT.MOD
                        A[i + offset + p * 3] = (a0 - a1 - a2na3iimag) * irot3 % FFT.MOD
                    irot *= self.irate3[(~s & -~s).bit_length()]
                    irot %= FFT.MOD
                le -= 2

    def convolve(self, A, B):
        """Return the product of polynomials A and B mod FFT.MOD.

        The result has length len(A) + len(B) - 1.  If either input has at
        most 60 coefficients, schoolbook multiplication is used, reducing the
        accumulators on every 8th row to bound integer size.  Otherwise both
        inputs are copied, zero-padded to a power of two, transformed,
        multiplied pointwise and transformed back.  The caller's lists are not
        modified.
        """
        n = len(A)
        m = len(B)
        if min(n, m) <= 60:
            C = [0] * (n + m - 1)
            for i in range(n):
                if i % 8 == 0:
                    for j in range(m):
                        C[i + j] += A[i] * B[j]
                        C[i + j] %= FFT.MOD
                else:
                    for j in range(m):
                        C[i + j] += A[i] * B[j]
            return [c % FFT.MOD for c in C]
        A = A[:]
        B = B[:]
        # Smallest power of two that holds all n + m - 1 result coefficients.
        z = 1 << (n + m - 2).bit_length()
        A += [0] * (z - n)
        B += [0] * (z - m)
        self.butterfly(A)
        self.butterfly(B)
        for i in range(z):
            A[i] *= B[i]
            A[i] %= FFT.MOD
        self.butterfly_inv(A)
        A = A[:n + m - 1]
        # Undo the missing 1/z scaling of butterfly_inv.
        iz = pow(z, FFT.MOD - 2, FFT.MOD)
        return [a * iz % FFT.MOD for a in A]
class FPS:
    """Formal power series with coefficients modulo ``FPS.MOD``.

    Coefficients live in ``self.f`` (lowest degree first).  ``FPS.MOD`` and
    the shared ``FPS.fft`` are set by the first constructor call only.  Later
    ``MOD`` arguments are ignored.

    NOTE: ``+``, ``-`` and their in-place forms MUTATE an operand and return
    it.  Do not reuse the left operand (or, for ``+``, the longer operand)
    after the operation.
    """

    fact = [1]     # shared factorial table mod FPS.MOD, grown by extend_fact
    invfact = [1]  # shared inverse-factorial table, same length as fact
    MOD = None     # set by the first FPS(...) construction

    def __init__(self, data, MOD=998244353):
        if FPS.MOD is None:
            FPS.MOD = MOD
            FPS.fft = FFT(MOD)
        # An int becomes the constant series [data]; a list is copied.
        if type(data) == int:
            self.f = [data]
        else:
            self.f = data[:]

    def __len__(self):
        return len(self.f)

    def __getitem__(self, i):
        return self.f[i]

    def __add__(self, other):
        """Add in place into the longer operand and return that operand.

        Assumes coefficients are already reduced to [0, MOD).
        """
        if len(self) < len(other):
            other, self = self, other
        for i in range(len(other)):
            self.f[i] += other[i]
            if self.f[i] >= FPS.MOD:
                self.f[i] -= FPS.MOD
        return self

    def __iadd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        """Subtract ``other`` from ``self`` in place (padding self with zeros
        if it is shorter) and return ``self``.

        Assumes coefficients are already reduced to [0, MOD).
        """
        self.f += [0] * (len(other) - len(self))
        for i in range(len(other)):
            self.f[i] -= other[i]
            if self.f[i] < 0:
                self.f[i] += FPS.MOD
        return self

    def __isub__(self, other):
        return self.__sub__(other)

    def __mul__(self, other):
        """Return a new series: scalar multiple if ``other`` is an int,
        otherwise the full (untruncated) convolution."""
        if type(other) == int:
            f = [other * x % FPS.MOD for x in self.f]
            return FPS(f)
        f = FPS.fft.convolve(self.f[:], other.f[:])
        return FPS(f)

    def __imul__(self, other):
        """In-place version of ``__mul__``; replaces ``self.f``."""
        if type(other) == int:
            self.f = [other * x % FPS.MOD for x in self.f]
            return self
        self.f = FPS.fft.convolve(self.f, other.f[:])
        return self

    def inv(self, deg=None):
        """Return 1/self truncated to ``deg`` terms (default len(self)).

        Uses Newton iteration g <- 2g - f*g^2, doubling the precision l each
        round.  Requires self[0] to be invertible mod FPS.MOD.
        """
        if deg is None:
            deg = len(self)
        g = FPS(pow(self[0], FPS.MOD - 2, FPS.MOD))
        l = 1
        while l < deg:
            tmp = g * 2
            l *= 2
            tmp2 = FPS(self.f[:l]) * (g * g)
            g = tmp - tmp2
            del g.f[l:]
        del g.f[deg:]
        return g

    def differential(self):
        """Return the formal derivative (length len(self) - 1)."""
        return FPS([x * i % FPS.MOD for i, x in enumerate(self.f[1:], 1)])

    def extend_fact(self, l):
        """Ensure FPS.fact / FPS.invfact cover indices 0..l."""
        l1 = len(FPS.fact)
        l += 1
        if l1 <= l:
            FPS.fact += [0] * (l - l1)
            FPS.invfact += [0] * (l - l1)
            for i in range(l1, l):
                FPS.fact[i] = FPS.fact[i - 1] * i % FPS.MOD
            # Invert the largest factorial once, then walk down:
            # 1/(i-1)! = i * (1/i!).
            FPS.invfact[l - 1] = pow(FPS.fact[l - 1], FPS.MOD - 2, FPS.MOD)
            for i in range(l - 1, l1, -1):
                FPS.invfact[i - 1] = FPS.invfact[i] * i % FPS.MOD

    def integral(self):
        """Return the formal integral with zero constant term.

        Coefficient i is divided by i + 1 via fact[i] * invfact[i + 1].
        """
        self.extend_fact(len(self))
        return FPS([0] + [x * (FPS.fact[i] * FPS.invfact[i + 1] % FPS.MOD) % FPS.MOD for i, x in enumerate(self.f)])

    def log(self, deg=None):
        """Return integral(f' / f) truncated to ``deg`` terms.

        The result's constant term is 0, so this equals log(f) when
        f[0] == 1.  Requires f[0] to be invertible (see inv).
        """
        if deg is None:
            deg = len(self)
        tmp = self.differential() * self.inv(deg=deg)
        del tmp.f[deg:]
        tmp = tmp.integral()
        del tmp.f[deg:]
        return tmp

    def exp(self, deg=None):
        """Return exp(self) truncated to ``deg`` terms.

        Uses Newton iteration g <- g * (1 - log(g) + f), doubling l each round
        until l >= 2 * deg.  Presumably expects self[0] == 0 (TODO confirm),
        since the iteration starts from g = 1.
        """
        if deg is None:
            deg = len(self)
        g = FPS(1)
        l = 1
        while l < deg * 2:
            l *= 2
            log = FPS(1) - g.log(deg=l) + FPS(self.f[:l])
            del log.f[l:]
            g *= log
            del g.f[l:]
        del g.f[deg:]
        return g

    def __pow__(self, k, deg=None):
        """Return self**k truncated to ``deg`` terms (default len(self)).

        Factors out the lowest non-zero term a*z^p, computes
        exp(k * log(f / (a z^p))), then shifts by p*k and scales by a^k.
        The scan for the lowest non-zero term indexes self[0..deg-1], so
        ``deg`` must not exceed len(self).
        """
        if k == 0:
            if deg is None:
                ret = [0] * len(self)
            else:
                ret = [0] * deg
            ret[0] = 1
            return FPS(ret)
        if deg is None:
            deg = len(self)
        i = 0
        p = None
        for i in range(deg):
            if self[i] != 0:
                a = self[i]
                p = i
                break
        if p is None:
            # All inspected coefficients are zero.
            if deg is not None:
                return FPS([0] * deg)
            else:
                return FPS(0)
        elif deg is not None and p * k >= deg:
            # The leading power z^(p*k) already exceeds the requested length.
            return FPS([0] * deg)
        inv = pow(a, FPS.MOD - 2, FPS.MOD)
        tmp = FPS([x * inv % FPS.MOD for x in self.f[p:]])
        tmp = tmp.log(deg=deg)
        if deg is not None:
            del tmp.f[deg:]
        tmp *= k
        tmp = tmp.exp(deg=deg)
        tmp = [0] * (p * k) + tmp.f[:deg - p * k]
        times = pow(a, k, FPS.MOD)
        return FPS([x * times % FPS.MOD for x in tmp])

    def __ipow__(self, k):
        return self.__pow__(k)

    def cipolla(self, a):
        """Return a square root of ``a`` mod FPS.MOD, or -1 if none exists.

        Cipolla's algorithm: find b such that b^2 - a is a non-residue
        (Euler's criterion), then compute (b + sqrt(b^2 - a))^((MOD+1)/2) in
        the quadratic extension.  ``pow_`` recurses about log2(MOD) levels
        deep.
        """
        if FPS.MOD == 2:
            return a
        elif a == 0:
            return 0
        elif pow(a, (FPS.MOD - 1) // 2, FPS.MOD) != 1:
            return -1
        b = 0
        while pow((b * b + FPS.MOD - a) % FPS.MOD, (FPS.MOD - 1) // 2, FPS.MOD) == 1:
            b += 1
        base = b * b + FPS.MOD - a
        # Product of (a0 + b0*w) and (a1 + b1*w), where w^2 = base.
        def multi(a0, b0, a1, b1):
            return (a0 * a1 + (b0 * b1 % FPS.MOD) * base) % FPS.MOD, (a0 * b1 + b0 * a1) % FPS.MOD
        # (a + b*w)^n by recursive squaring.
        def pow_(a, b, n):
            if n == 0:
                return 1, 0
            a_, b_ = pow_(*multi(a, b, a, b), n // 2)
            if n % 2 == 1:
                a_, b_ = multi(a_, b_, a, b)
            return a_, b_
        return pow_(b, 1, (FPS.MOD + 1) // 2)[0]

    def sqrt(self, deg=None):
        """Return a square root of self truncated to ``deg`` terms.

        Returns FPS([]) when no square root exists: the lowest non-zero term
        has odd degree, or its coefficient is a non-residue.  Leading zeros
        are handled by recursing on the shifted series.  Otherwise Newton
        iteration g <- (g + f/g) / 2 is used.
        """
        if deg is None:
            deg = len(self)
        if len(self) == 0:
            return FPS([0] * deg)
        if self[0] == 0:
            for i in range(1, len(self)):
                if self[i] != 0:
                    if i & 1:
                        return FPS([])
                    if deg <= i // 2:
                        break
                    ret = FPS(self.f[i:]).sqrt(deg - i // 2)
                    if len(ret) == 0:
                        return FPS([])
                    ret.f = [0] * (i // 2) + ret.f
                    if len(ret) < deg:
                        ret.f += [0] * (deg - len(ret))
                    return ret
            return FPS([0] * deg)
        sq = self.cipolla(self[0])
        if sq == -1:
            return FPS([])
        inv2 = (FPS.MOD + 1) // 2
        g = FPS([sq])
        l = 1
        while l < deg:
            l *= 2
            tmp = FPS(self.f[:l]) * g.inv(deg=l)
            g += tmp
            g *= inv2
        del g.f[deg:]
        return g

    def taylorshift(self, a):
        """Return f(x + a) as a new series of the same length.

        Computed as a single convolution of the reversed i! * f_i with the
        coefficients a^i / i!.
        """
        deg = len(self)
        f = self.f[:]
        self.extend_fact(deg)
        for i in range(deg):
            f[i] *= FPS.fact[i]
            f[i] %= FPS.MOD
        f = f[::-1]
        g = [0] * deg
        g[0] = 1
        for i in range(1, deg):
            # g[i] = a^i / i!  (since (i-1)! / i! = 1 / i)
            g[i] = (g[i - 1] * a % FPS.MOD) * (FPS.fact[i - 1] * FPS.invfact[i] % FPS.MOD) % FPS.MOD
        f = FPS.fft.convolve(f, g)
        del f[deg:]
        f = f[::-1]
        for i in range(deg):
            f[i] *= FPS.invfact[i]
            f[i] %= FPS.MOD
        return FPS(f)
# A subset S has sum(S) >= MOD2 exactly when its complement has sum <= x
# (x = sum(A) - MOD2).  Each such subset was counted MOD2 too high in the
# naive total, so we count subsets with sum <= x and subtract MOD2 per subset.
# NOTE(review): this relies on sum(A) < 2 * MOD2, so that a subset sum is
# reduced by MOD2 at most once -- confirm against the problem constraints.
# The input has only two lines (n and A).  The old extra input() read here
# raised EOFError, which caused the runtime errors.
T = x
zero_count = 0
cnt = [0] * (T + 1)  # cnt[v] = number of elements equal to v, for 1 <= v <= T
for a in A:
    if a == 0:
        # log(1 + z^0) has a non-zero constant term, so it cannot go into the
        # exp/log trick.  Each zero doubles the number of qualifying subsets.
        zero_count += 1
    elif a <= T:
        cnt[a] += 1
# inv[k] = k^{-1} mod MOD for 1 <= k <= T.  The list has at least 2 entries so
# that inv[1] exists when T == 0.
inv = [0] * max(T + 1, 2)
inv[1] = 1
for i in range(2, T + 1):
    inv[i] = -inv[MOD % i] * (MOD // i) % MOD
# F = sum over values v of cnt[v] * log(1 + z^v), truncated to degree T,
# using log(1 + y) = y - y^2/2 + y^3/3 - ...
F = [0] * (T + 1)
for i, c in enumerate(cnt):
    if i == 0 or c == 0:
        continue
    pm = 1
    for j in range(i, T + 1, i):
        F[j] = (F[j] + pm * c * inv[j // i]) % MOD
        pm = -pm
# exp(F) = prod over non-zero elements of (1 + z^a), truncated to degree T.
# Its coefficient sum counts the subsets (of non-zero elements) whose sum is
# at most T.
F = FPS(F).exp()
tot = sum(F.f) % MOD * pow(2, zero_count, MOD) % MOD
ans -= tot * MOD2
print(ans % MOD)
הההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההה
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
0