結果
| 問題 |
No.2013 Can we meet?
|
| コンテスト | |
| ユーザー |
Kiri8128
|
| 提出日時 | 2022-04-16 14:13:47 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
AC
|
| 実行時間 | 1,480 ms / 2,500 ms |
| コード長 | 4,411 bytes |
| コンパイル時間 | 159 ms |
| コンパイル使用メモリ | 82,456 KB |
| 実行使用メモリ | 183,372 KB |
| 最終ジャッジ日時 | 2024-07-04 22:09:12 |
| 合計ジャッジ時間 | 20,341 ms |
|
ジャッジサーバーID (参考情報) |
judge4 / judge1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 4 |
| other | AC * 35 |
ソースコード
# NTT constants.  998244353 = 119 * 2**23 + 1 is prime; g = 3 is a quadratic
# non-residue mod p, so W[e] = g ** ((p - 1) >> e) has multiplicative order
# exactly 2**e for 0 <= e <= 23 (W[1] == p - 1).  iW[e] is the inverse of W[e],
# computed from ig = g^-1 mod p.  convolve() reads W[l] / iW[l] per stage.
P = 998244353
p, g = P, 3
ig = 332748118  # 3 * 332748118 == p + 1, i.e. ig is the inverse of g
W, iW = [
    [pow(base, (p - 1) >> e, p) for e in range(24)]
    for base in (g, ig)
]
def convolve(a, b):
    """Return the linear convolution of ``a`` and ``b`` modulo ``p``.

    The result has ``len(a) + len(b) - 1`` entries (``[]`` if either input is
    empty); entry ``i`` is ``sum(a[j] * b[i - j]) % p`` over valid ``j``.
    The input lists are not modified.

    If either input has fewer than 80 entries, the schoolbook product is used.
    Otherwise both inputs are zero-padded to the smallest power of two
    ``n >= len(result)`` and multiplied with a number-theoretic transform.
    The transform uses the root tables ``W`` / ``iW``, which have 24 entries,
    so ``n`` may be at most ``2**23``.
    """

    def fft(f):
        # In-place forward transform: Gentleman-Sande (decimation-in-frequency)
        # butterflies from the largest stage down.  The output is left in the
        # permuted order that ifft() consumes, so no bit-reversal pass is needed.
        for l in range(k, 0, -1):
            d = 1 << l - 1
            # U[j] = W[l] ** j: twiddle factors for this stage.
            U = [1]
            for i in range(d):
                U.append(U[-1] * W[l] % p)
            for i in range(1 << k - l):
                for j in range(d):
                    s = i * 2 * d + j
                    t = s + d
                    f[s], f[t] = (f[s] + f[t]) % p, U[j] * (f[s] - f[t]) % p

    def ifft(f):
        # In-place inverse transform: Cooley-Tukey butterflies from the
        # smallest stage up, using inverse roots.  The caller still has to
        # multiply by 1/n afterwards.
        for l in range(1, k + 1):
            d = 1 << l - 1
            U = [1]
            for i in range(d):
                U.append(U[-1] * iW[l] % p)
            for i in range(1 << k - l):
                for j in range(d):
                    s = i * 2 * d + j
                    t = s + d
                    f[s], f[t] = (f[s] + f[t] * U[j]) % p, (f[s] - f[t] * U[j]) % p

    # An empty factor gives the empty product.  (Without this guard the
    # schoolbook path would return len(other) - 1 spurious zeros.)
    if not a or not b:
        return []
    n0 = len(a) + len(b) - 1
    if len(a) < 80 or len(b) < 80:
        ret = [0] * n0
        if len(a) > len(b):
            a, b = b, a
        for i, aa in enumerate(a):
            for j, bb in enumerate(b):
                ret[i + j] = (ret[i + j] + aa * bb) % p
        return ret
    # Smallest n = 2**k with n >= n0: a cyclic convolution of that length
    # already equals the linear one.  (n0.bit_length() would double n when
    # n0 is itself a power of two.)
    k = (n0 - 1).bit_length()
    n = 1 << k
    a = a + [0] * (n - len(a))  # fresh lists: the caller's inputs stay intact
    b = b + [0] * (n - len(b))
    fft(a)
    fft(b)
    for i in range(n):
        a[i] = a[i] * b[i] % p
    ifft(a)
    invn = pow(n, p - 2, p)
    for i in range(n0):
        a[i] = a[i] * invn % p
    del a[n0:]
    return a
class SemiRelaxedMultiplication:
    """Online ("semi-relaxed") product ``h = f * g`` modulo ``P``.

    ``g`` is known up front.  ``f`` is revealed one coefficient at a time
    through :meth:`append`.  After the n-th call (f[0..n-1] known), ``h[n-1]``
    is final and is returned, so the next coefficient of ``f`` may depend
    on it.

    The terms f[i] * g[0] and f[i] * g[1] are added directly.  The terms with
    j >= 2 are handled in blocks: whenever a power of two ``m >= 2`` divides
    ``n``, the block f[n-m:n] is multiplied by g[m:2m] with ``convolve`` and
    added into h[n:]. Those indices are not yet needed, because i + j >= n.
    Coefficients of ``g`` beyond its length are treated as 0.
    """

    def __init__(self, g):
        self.f = []
        self.g = g  # not copied: later changes to the caller's list are seen
        self.h = [0] * 8  # accumulator for h; grown on demand in calc()
        self.n = 0  # number of coefficients of f received so far

    def calc(self, l, m):
        """Add f[l:l+m] * g[m:2m] into h, starting at index l + m."""
        need = l + 3 * m - 1  # one past the last index the product can touch
        if len(self.h) < need:
            self.h.extend([0] * (need - len(self.h)))
        h = self.h
        co = convolve(self.f[l:l + m], self.g[m:2 * m])
        for i, c in enumerate(co, l + m):
            h[i] = (h[i] + c) % P

    def append(self, a):
        """Feed the next coefficient ``a`` of f; return the completed h[n-1]."""
        self.f.append(a)
        self.n += 1
        n = self.n
        g = self.g
        h = self.h
        # Low-order terms of g are applied immediately: h[n-1] must be final now.
        if len(g) > 0:
            h[n - 1] = (h[n - 1] + a * g[0]) % P
        if len(g) > 1:
            h[n] = (h[n] + a * g[1]) % P
        m = 2
        while n % m == 0:
            self.calc(n - m, m)
            m *= 2
        return self.h[n - 1]
# Factorial table fa[i] = i! mod P and inverse-factorial table
# fainv[i] = (i!)^-1 mod P for 0 <= i <= nn, plus a binomial helper C.
nn = 1001001
fa = [1]
for i in range(1, nn + 1):
    fa.append(fa[-1] * i % P)
# Invert the largest factorial once (Fermat: P is prime), then walk downwards
# with 1/i! = (i + 1) * (1/(i + 1)!).
fainv = [0] * (nn + 1)
fainv[nn] = pow(fa[nn], P - 2, P)
for i in reversed(range(nn)):
    fainv[i] = fainv[i + 1] * (i + 1) % P


def C(a, b):
    """Return binomial(a, b) mod P, or 0 unless 0 <= b <= a."""
    if 0 <= b <= a:
        return fa[a] * fainv[b] % P * fainv[a - b] % P
    return 0
def calc(n, x1, y1, x2, y2, a, b, L):
    """Return sum_{t=1..n} pp[t] * L[t-1] modulo P.

    Let alpha = a / (2(a+b)) and beta = b / (2(a+b)), both taken mod P
    (so 2 * (alpha + beta) == 1).  Consider a walk of independent unit steps,
    each being +x or -x with probability alpha and +y or -y with probability
    beta.  The formulas below build these series:

    * qq[t]: the probability that after 2t steps the walk is at displacement
      (x, y) = (|x1 - x2|, |y1 - y2|).  This is a multinomial sum over
      k (extra -x steps) and l (extra -y steps).
    * ss[t]: the same probability for displacement (0, 0); ss[0] == 1.
    * pp = qq / ss as power series, truncated to n + 1 terms.  By the
      first-passage (renewal) decomposition qq = pp * ss, pp[t] is the
      probability that (x, y) is reached for the first time after 2t steps.

    NOTE(review): 2t steps presumably model two walkers that each move once
    per time unit.  On that reading, pp[t] is the probability that they first
    meet at time t, and L[t-1] is the value attached to that event.
    TODO: confirm against the problem statement.

    Returns 0 when x + y is odd (parity) or x + y > 2n (unreachable).
    L must have at least n entries.  All arithmetic is modulo P.
    """
    x = abs(x1 - x2)
    y = abs(y1 - y2)
    if x + y > 2 * n:
        return 0
    if (x + y) % 2:
        return 0
    # k (and l) range over 0..m-1, so that o + k + l can reach n.
    m = n - (x + y) // 2 + 1
    iv = pow(2 * (a + b), P - 2, P)
    alpha = a * iv % P
    beta = b * iv % P
    # poa[i] = alpha**i, pob[i] = beta**i (only indices up to 2n are read below).
    poa = [1]
    pob = [1]
    for i in range(n * 4 + 1):
        poa.append(poa[-1] * alpha % P)
    for i in range(n * 4 + 1):
        pob.append(pob[-1] * beta % P)
    assert (alpha + beta) * 2 % P == 1
    # qq[o + i] = (2(o+i))! * sum_{k+l=i} alpha^(x+2k) / ((x+k)! k!)
    #                                   * beta^(y+2l) / ((y+l)! l!),
    # where o = (x + y) / 2 is the earliest time the displacement is reachable.
    tmp1 = [fainv[x+k] * fainv[k] % P * poa[x+2*k] % P for k in range(m)]
    tmp2 = [fainv[y+l] * fainv[l] % P * pob[y+2*l] % P for l in range(m)]
    o = (x + y) // 2
    # (the comprehension's `a` is local to it; the parameter `a` is unaffected)
    qq = ([0] * o + [fa[(o+i)*2] * a % P for i, a in enumerate(convolve(tmp1, tmp2))])[:n+1]
    # ss[t]: the same sum for displacement (0, 0); ss has 2n + 1 entries, ss[0] == 1.
    tmp1 = [fainv[k] * fainv[k] % P * poa[2*k] % P for k in range(n + 1)]
    tmp2 = [fainv[l] * fainv[l] % P * pob[2*l] % P for l in range(n + 1)]
    ss = [fa[i*2] * a % P for i, a in enumerate(convolve(tmp1, tmp2))]
    ss1 = ss[1:2+n]  # ss[1..n+1]
    # Online solve for rr: rr[0] = 0 and, for t >= 1,
    #   rr[t] = ss[t] - sum_{i<t} rr[i] * ss[t-i].
    # srm.append(rr[t-1]) returns sum_{i<t} rr[i] * ss[t-i].  The recurrence
    # gives rr * ss == ss - 1, i.e. 1 - rr == 1 / ss as power series.
    # `a` and `b` are reused as scratch names here; the weights are no longer needed.
    srm = SemiRelaxedMultiplication(ss1)
    a = 0
    rr = [a]
    for b in ss1:
        a = (b - srm.append(a)) % P
        rr.append(a)
    # pp = qq * (1 - rr) = qq / ss, truncated to len(qq) == n + 1 by zip().
    qqrr = convolve(qq, rr)
    pp = [(a - b) % P for a, b in zip(qq, qqrr)]
    ans = 0
    for i in range(n):
        ans = (ans + pp[i+1] * L[i]) % P
    return ans
# Read the input: N, the two start points, the weights a and b, and the list A
# that calc() indexes as A[0..N-1].
N = int(input())
x1, y1, x2, y2 = map(int, input().split())
a, b = map(int, input().split())
A = [int(tok) for tok in input().split()]
# Validate the input constraints *before* computing, so malformed input fails
# fast instead of after an answer has already been printed.  Explicit raises
# (rather than assert) keep the checks active under `python -O`.
if not 1 <= N <= 10 ** 5:
    raise ValueError("N out of range")
if not all(0 <= v <= 10 ** 9 for v in (x1, y1, x2, y2)):
    raise ValueError("coordinate out of range")
if (x1, y1) == (x2, y2):
    raise ValueError("start points must differ")
if not (1 <= a <= 10 ** 6 and 1 <= b <= 10 ** 6):
    raise ValueError("a or b out of range")
if len(A) < N:
    raise ValueError("A must contain at least N values")
if not all(1 <= v <= 10 ** 9 for v in A):
    raise ValueError("A value out of range")
print(calc(N, x1, y1, x2, y2, a, b, A))
Kiri8128