結果

問題 No.2013 Can we meet?
ユーザー Kiri8128
提出日時 2022-08-02 23:04:55
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 1,067 ms / 2,500 ms
コード長 13,238 bytes
コンパイル時間 264 ms
コンパイル使用メモリ 86,908 KB
実行使用メモリ 191,312 KB
最終ジャッジ日時 2023-09-30 23:29:37
合計ジャッジ時間 18,410 ms
ジャッジサーバーID
(参考情報)
judge14 / judge15
このコードへのチャレンジ(β)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 100 ms
91,656 KB
testcase_01 AC 100 ms
91,584 KB
testcase_02 AC 101 ms
91,568 KB
testcase_03 AC 101 ms
91,392 KB
testcase_04 AC 100 ms
91,712 KB
testcase_05 AC 99 ms
91,392 KB
testcase_06 AC 100 ms
91,440 KB
testcase_07 AC 100 ms
91,476 KB
testcase_08 AC 100 ms
91,476 KB
testcase_09 AC 100 ms
91,516 KB
testcase_10 AC 100 ms
91,436 KB
testcase_11 AC 101 ms
91,580 KB
testcase_12 AC 102 ms
91,604 KB
testcase_13 AC 99 ms
91,548 KB
testcase_14 AC 106 ms
91,936 KB
testcase_15 AC 152 ms
93,804 KB
testcase_16 AC 153 ms
93,620 KB
testcase_17 AC 156 ms
94,200 KB
testcase_18 AC 155 ms
94,196 KB
testcase_19 AC 151 ms
94,152 KB
testcase_20 AC 154 ms
94,064 KB
testcase_21 AC 103 ms
91,892 KB
testcase_22 AC 152 ms
94,224 KB
testcase_23 AC 149 ms
94,096 KB
testcase_24 AC 1,064 ms
190,660 KB
testcase_25 AC 998 ms
177,536 KB
testcase_26 AC 1,062 ms
191,040 KB
testcase_27 AC 996 ms
177,168 KB
testcase_28 AC 1,057 ms
190,804 KB
testcase_29 AC 1,053 ms
191,312 KB
testcase_30 AC 983 ms
176,132 KB
testcase_31 AC 1,061 ms
190,880 KB
testcase_32 AC 117 ms
101,480 KB
testcase_33 AC 117 ms
101,524 KB
testcase_34 AC 1,064 ms
190,512 KB
testcase_35 AC 990 ms
177,000 KB
testcase_36 AC 1,067 ms
190,876 KB
testcase_37 AC 116 ms
101,780 KB
testcase_38 AC 122 ms
101,828 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

# Prime modulus 998244353 = 119 * 2**23 + 1.  It is bound to both ``P`` and
# ``p`` because the code below refers to it under either name.
P = 998244353
p = 998244353
# ``g`` is the base from which the transform roots are derived and ``ig`` is
# its inverse modulo ``p`` (g * ig % p == 1).
g = 3
ig = 332748118
# W[i] = g ** ((p - 1) >> i) mod p, so W[i] ** (2 ** i) == 1 (mod p);
# iW[i] is the same power of ``ig``, i.e. the modular inverse of W[i].
W = [pow(g, (p - 1) >> level, p) for level in range(24)]
iW = [pow(ig, (p - 1) >> level, p) for level in range(24)]

def convolve(a, b):
    """Return the product of the coefficient lists ``a`` and ``b`` modulo P.

    The result has ``len(a) + len(b) - 1`` entries.  When either input has
    fewer than 50 entries the product is accumulated directly with a double
    loop; otherwise both inputs are zero-padded to a power-of-two length,
    transformed with the module-level root tables ``W`` / ``iW``, multiplied
    pointwise and transformed back.  Neither input list is modified.
    """
    out_len = len(a) + len(b) - 1

    # Short inputs: plain quadratic multiplication, iterating the shorter
    # operand in the outer loop.
    if len(a) < 50 or len(b) < 50:
        acc = [0] * out_len
        if len(a) > len(b):
            outer, inner = b, a
        else:
            outer, inner = a, b
        for i, u in enumerate(outer):
            for j, v in enumerate(inner):
                acc[i + j] = (acc[i + j] + u * v) % P
        return acc

    k = out_len.bit_length()
    size = 1 << k

    def forward(vec):
        # Butterfly passes from the widest span (size / 2) down to span 1.
        # No separate reordering step is performed; ``inverse`` below is
        # written as the matching counterpart.
        for level in range(k, 0, -1):
            half = 1 << (level - 1)
            root = W[level]
            twiddle = [1]
            for _ in range(half - 1):
                twiddle.append(twiddle[-1] * root % p)
            for base in range(0, size, 2 * half):
                for j in range(half):
                    lo = base + j
                    hi = lo + half
                    x = vec[lo]
                    y = vec[hi]
                    vec[lo] = (x + y) % p
                    vec[hi] = twiddle[j] * (x - y) % p

    def inverse(vec):
        # Butterfly passes from span 1 up to span size / 2 using the inverse
        # roots; the 1/size scaling is applied by the caller.
        for level in range(1, k + 1):
            half = 1 << (level - 1)
            root = iW[level]
            twiddle = [1]
            for _ in range(half - 1):
                twiddle.append(twiddle[-1] * root % p)
            for base in range(0, size, 2 * half):
                for j in range(half):
                    lo = base + j
                    hi = lo + half
                    x = vec[lo]
                    y = vec[hi] * twiddle[j]
                    vec[lo] = (x + y) % p
                    vec[hi] = (x - y) % p

    # Padding creates fresh lists, so the caller's inputs stay untouched.
    left = a + [0] * (size - len(a))
    right = b + [0] * (size - len(b))
    forward(left)
    forward(right)
    spectrum = [u * v % p for u, v in zip(left, right)]
    inverse(spectrum)

    scale = pow(size, p - 2, p)
    return [spectrum[i] * scale % p for i in range(out_len)]

def Tonelli_Shanks(n, p = P):
    """Return a square root of ``n`` modulo the prime ``p``, or -1 if none exists.

    Of the two roots ``r`` and ``p - r`` the smaller one is returned.  When
    ``n`` is congruent to 0 the result is 0.

    Args:
        n: The value whose square root is wanted (reduced modulo ``p``).
        p: The modulus, expected to be prime.  Defaults to the module-level P.

    Returns:
        An integer ``r`` with ``r * r % p == n % p``, or -1 when ``n`` is a
        quadratic non-residue modulo ``p``.
    """
    # Euler's criterion.  Three-argument pow() yields a value in [0, p), so a
    # non-residue shows up as p - 1 and never as -1.  p == 2 is excluded
    # because there p - 1 == 1 is also what every odd n produces.
    if p != 2 and pow(n, (p - 1) // 2, p) == p - 1:
        return -1

    # Shortcut available when p is congruent to 3 modulo 4.
    if p % 4 == 3:
        root = pow(n, (p + 1) // 4, p)
        return min(root, p - root)

    # Write p - 1 = q * 2**s with q odd.
    q = p - 1
    s = 0
    while q % 2 == 0:
        q //= 2
        s += 1

    # First z whose Euler value is not 1 (a non-residue for prime p > 2).
    for z in range(1, p):
        if pow(z, (p - 1) // 2, p) != 1:
            break

    m = s
    c = pow(z, q, p)
    t = pow(n, q, p)
    r = pow(n, (q + 1) // 2, p)
    while True:
        if t == 0:
            return 0
        if t == 1:
            return min(r, p - r)

        # Least i >= 1 with t ** (2 ** i) == 1, found by repeated squaring.
        i = 0
        power = t
        while power != 1 and i < m:
            power = power * power % p
            i += 1
        if i >= m:
            # The order of t does not drop below 2**m: no root exists.
            return -1

        b = pow(c, 1 << (m - i - 1), p)
        m = i
        c = b * b % p
        t = t * c % p
        r = r * b % p

class fps():
    """Truncated power series / polynomial with integer coefficients modulo P.

    Coefficients live in the list ``self.f`` (index i holds the coefficient
    of x**i) and ``self.len`` caches its length.  Arithmetic relies on the
    module-level modulus ``P``, on ``convolve`` for series products, and on
    the module-level tables ``fa`` / ``fainv`` for the factorial-based
    helpers.  Truncation is done by slicing (``series[:k]``).
    """
    def __init__(self, a, m = 10**6):
        """Build a series from an int (constant series) or a coefficient list.

        A non-empty list is stored by reference, not copied.  A falsy
        non-int argument (e.g. an empty list) yields the zero series ``[0]``.
        The parameter ``m`` is not used by this method.
        """
        if type(a) == int:
            self.len = 1
            self.f = [a]
        elif a:
            self.len = len(a)
            self.f = a
        else:
            self.len = 1
            self.f = [0]
    
    def __neg__(self):
        """Return the series with every coefficient negated modulo P."""
        l = [0] * self.len
        for i, a in enumerate(self.f):
            l[i] = P - a if a else 0
        return self.__class__(l)
        
    def __add__(self, other):
        """Return ``self + other``; an int ``other`` is treated as a constant series."""
        if type(other) == int:
            return self + self.__class__([other])
        # Copy the longer operand and add the shorter one into it.  The single
        # conditional subtraction assumes both coefficients are already
        # reduced into [0, P) -- TODO confirm for all callers.
        if self.len > other.len:
            l = self.f[:]
            for i, a in enumerate(other.f):
                l[i] += a
                if l[i] >= P:
                    l[i] -= P
        else:
            l = other.f[:]
            for i, a in enumerate(self.f):
                l[i] += a
                if l[i] >= P:
                    l[i] -= P
        return self.__class__(l)
    
    def __radd__(self, other):
        """Support ``other + self`` by delegating to ``__add__``."""
        return self + other
    
    def __sub__(self, other):
        """Return ``self - other``; an int ``other`` is treated as a constant series."""
        if type(other) == int:
            return self - self.__class__([other])
        # Pad self up to other's length when other is longer.
        l = self.f[:] + [0] * (other.len - self.len)
        for i, a in enumerate(other.f):
            l[i] -= a
            if l[i] < 0:
                l[i] += P
        return self.__class__(l)
    def __rsub__(self, other):
        """Support ``other - self`` by wrapping ``other`` as a one-term series."""
        return self.__class__([other]) - self
    
    def __mul__(self, other):
        """Return ``self * other``: scalar multiple for an int, ``convolve`` otherwise."""
        if type(other) == int:
            l = self.f[:]
            for i in range(self.len):
                l[i] = l[i] * other % P
            return self.__class__(l)
        else:
            return self.__class__(convolve(self.f, other.f))

    def __rmul__(self, other):
        """Support ``other * self``; ``other`` is used as a scalar multiplier."""
        l = self.f[:]
        for i in range(self.len):
            l[i] = l[i] * other % P
        return self.__class__(l)
    
    def inv(self, deg = -1):
        """Return the first ``deg`` coefficients of the multiplicative inverse.

        ``deg`` defaults to the length of this series.  The constant term must
        be non-zero (asserted).  The precision is doubled on every pass using
        ``ret <- ret * (2 - ret * self)``.
        """
        f = self.f[:]
        assert f[0]
        n = self.len
        if deg < 0: deg = n
        # Bare ``__class__`` is the implicit reference to the enclosing class
        # that Python makes available inside methods.
        ret = __class__([pow(self.f[0], P - 2, P)])
        i = 1
        while i < deg:
            ret = (ret * (2 - ret * self[:i*2]))[:i*2]
            i <<= 1
        return ret[:deg]
    
    def __truediv__(self, other, deg = -1):
        """Return ``self / other``.

        For an int ``other`` every coefficient is multiplied by
        ``pow(other, P - 2, P)``.  Otherwise the result is
        ``self * other.inv(deg)`` truncated to ``deg`` coefficients, where
        ``deg`` defaults to the longer of the two lengths.
        """
        if type(other) == int:
            iv = pow(other, P - 2, P)
            l = self.f[:]
            for i in range(self.len):
                l[i] = l[i] * iv % P
            return self.__class__(l)
        else:
            if deg < 0: deg = max(self.len, other.len)
            return (self * other.inv(deg))[:deg]
    
    def sqrt(self):
        """Return a square root of this series, or ``None`` if none is found.

        Leading zero coefficients are stripped first; an odd count of them
        means there is no square root.  The constant term's root comes from
        ``Tonelli_Shanks`` and is then refined with ``a <- (a + self / a) / 2``.
        """
        if self.f[0] == 0:
            for k, a in enumerate(self.f):
                if a: break
            else:
                # Every coefficient is zero.
                return self.__class__([0] * self.len)
            if k & 1: return None
            # Drop the k leading zeros, pad to keep the overall length, and
            # put k // 2 zeros back in front of the root.
            sq = self.__class__(self.f[k:] + [0] * (k//2)).sqrt()
            if not sq: return None
            return fps([0] * (k//2) + sq.f)
        ts = Tonelli_Shanks(self.f[0])
        if ts < 0: return None
        deg = self.len
        a = self.__class__([ts])
        i = 1
        while i < deg:
            a += self[:i*2].__truediv__(a)
            a /= 2
            i <<= 1
        return a
    
    def f2e(self):
        """Return a copy with coefficient i multiplied by ``fainv[i]``."""
        f = self.f[:]
        for i, a in enumerate(f):
            f[i] = a * fainv[i] % P
        return self.__class__(f)
    
    def e2f(self):
        """Return a copy with coefficient i multiplied by ``fa[i]``."""
        f = self.f[:]
        for i, a in enumerate(f):
            f[i] = a * fa[i] % P
        return self.__class__(f)
    
    def differentiate(self):
        """Return the formal derivative, zero-padded to keep the same length."""
        f = self.f[:]
        for i, a in enumerate(f):
            f[i] = a * i % P
        f = f[1:] + [0]
        return self.__class__(f)
    
    def integrate(self):
        """Return the formal integral with constant term 0, same length.

        Coefficient i is multiplied by ``fainv[i+1] * fa[i]`` and moved up one
        position; the last coefficient is dropped to keep the length.
        """
        f = self.f[:]
        for i, a in enumerate(f):
            f[i] = a * fainv[i+1] % P * fa[i] % P
        f = [0] + f[:-1]
        return self.__class__(f)
    
    def log(self, deg = -1):
        """Return the integral of ``self' / self`` (division truncated to ``deg``)."""
        return (self.differentiate().__truediv__(self, deg)).integrate()
    
    def exp(self, deg = -1):
        """Return the first ``deg`` coefficients of the exponential of this series.

        The constant term must be 0 (asserted).  ``deg`` defaults to the
        length of this series.  Precision doubles on every pass using
        ``a <- a * (self + 1 - log(a))``.
        """
        assert self.f[0] == 0
        if deg < 0: deg = self.len
        a = self.__class__([1])
        i = 1
        while i < deg:
            a = (a * (self[:i*2] + 1 - a.log(i * 2)))[:i*2]
            i <<= 1
        return a[:deg]
    
    def __pow__(self, n, deg = -1):
        """Return ``self ** n`` truncated to ``deg`` coefficients.

        ``deg`` defaults to the length of this series (with three-argument
        ``pow`` the third argument arrives here as ``deg``).  A zero constant
        term is handled by factoring out the leading power of x (requires
        ``n >= 0``); a constant term other than 1 is divided out and restored
        afterwards; the remaining case is computed as ``exp(n * log(self))``.
        """
        if deg < 0: deg = self.len
        if self.f[0] == 0:
            assert n >= 0
            for i, a in enumerate(self.f):
                if a:
                    if i * n >= deg: return self.__class__([0] * deg)
                    return self.__class__([0] * (i * n) + pow(self.__class__(self.f[i:]), n, deg - i * n).f)
            else:
                # All coefficients are zero.
                return self.__class__([0] * deg)
        if self.f[0] != 1:
            a = self.f[0]
            return pow(self / a, n, deg) * pow(a, n, P)
        return (self.log(deg) * n).exp(deg)
    
    def taylor_shift(self, c):
        """Return a series built from this one and the powers of ``c``.

        NOTE(review): this looks like the usual Taylor shift producing the
        coefficients of f(x + c) -- confirm.  It multiplies ``self.e2f()`` by
        the reversed list of ``c**i * fainv[i]``, keeps the part from index
        ``deg - 1`` onward and applies ``f2e``.
        """
        deg = self.len
        L = []
        a = 1
        for i in range(deg):
            L.append(a * fainv[i] % P)
            a = a * c % P
        L = L[::-1]
        return (self.e2f() * self.__class__(L))[deg-1:].f2e()
    
    def composite(self, other, deg = -1):
        """Compose this series with ``other``, truncated to ``deg`` coefficients.

        ``other`` must have constant term 0 (asserted).  ``deg`` defaults to
        ``(self.len - 1) * (other.len - 1) + 1``.  ``other`` is split at index
        ``k`` into ``p = other[:k]`` and ``q = other[k:]``; note that the local
        ``p`` shadows the module-level modulus alias inside this method.
        NOTE(review): the structure resembles a Brent-Kung style composition
        -- confirm before relying on that description.
        """
        assert other.f[0] == 0
        if other.len == 1: return self[:1]
        if deg < 0: deg = (self.len - 1) * (other.len - 1) + 1
        n = other.len
        k = int((n / n.bit_length()) ** 0.5) + 1
        p = other[:k]
        q = other[k:]
        def calc():
            """Evaluate self at ``p`` by repeatedly folding groups of four terms."""
            f = self.f + [0] * ((-self.len) % 4)
            pp = p
            while 1:
                pp2 = (pp * pp)[:deg]
                pp3 = (pp2 * pp)[:deg]
                g = []
                for i in range(0, len(f), 4):
                    g.append(f[i] + (f[i+1] * pp)[:deg] + (f[i+2] * pp2)[:deg] + (f[i+3] * pp3)[:deg])
                if len(g) <= 1:
                    break
                # Pad to a multiple of four and raise pp to its fourth power.
                f = g + [0] * ((-len(g)) % 4)
                pp = (pp3 * pp)[:deg]
            return g[0]
        
        # Low part is zero: only powers of q contribute.
        if p.iszero():
            ff = self[:]
            re = ff[0]
            qq = 1
            for i in range(k, deg, k):
                ff = ff.differentiate()
                qq = (qq * q)[:deg-i]
                re += (ff[0] * fainv[i//k] * qq).shift(i)
            return re
        
        fp = calc()
        re = fp[:]
        pd = p.differentiate()
        z = pd.leadingzeroes()
        pdi = pd[z:].inv(deg)
        qq = 1
        for i in range(k, deg, k):
            fp = (pdi[:deg-i+z] * fp[:deg-i+1+z].differentiate())[:deg-i+z][z:]
            qq = (qq * q)[:deg-i]
            re += ((fp * qq)[:deg-i] * fainv[i//k]).shift(i)
        return re
    
    def at(self, v):
        """Evaluate the series at ``v`` modulo P using Horner's scheme."""
        f = self.f
        s = 0
        for a in f[::-1]:
            s = (s * v + a) % P
        return s
    
    def shift(self, k):
        """Return a new series with ``k`` zero coefficients prepended."""
        return self.__class__([0] * k + self.f)
    
    def iszero(self):
        """Return True when the coefficients sum to 0.

        NOTE(review): this equals "all coefficients are zero" only if every
        coefficient is non-negative -- confirm.
        """
        return sum(self.f) == 0
    
    def leadingzeroes(self):
        """Return the index of the first non-zero coefficient, or ``self.len`` if none."""
        for i, a in enumerate(self.f):
            if a: return i
        return self.len
    
    def __getitem__(self, s):
        """Return ``self.f[s]`` wrapped in a new series.

        A slice gives the corresponding sub-series; an int index gives a
        one-term series because the constructor accepts an int.
        """
        return self.__class__(self.f[s])
    
    def to_frac(self, a):
        """Return a human-readable form of the residue ``a``.

        Residues within 10000 of 0 (from either side modulo P) are returned
        as signed ints.  Otherwise the first denominator ``i`` in 1..10000
        for which ``a * i`` or ``-a * i`` reduces to at most 10000 gives a
        string ``"num/i"``.  If none is found ``a`` is returned unchanged.
        """
        if 0 <= a <= 10000: return a
        if -10000 <= a - P < 0: return a - P
        for i in range(1, 10001):
            if i and a * i % P <= 10000:
                return str(a * i % P) + "/" + str(i)
            if i and -a * i % P <= 10000:
                return str(-(-a * i % P)) + "/" + str(i)
        return a
    
    def __str__(self):
        """Return the coefficients formatted with ``to_frac`` and joined by ", "."""
        l = []
        for a in self.f:
            l.append(str(self.to_frac(a)))
        return ", ".join(l)

class SemiRelaxedMultiplication():
    """Incrementally computed product ``h = f * g``.

    ``g`` is supplied in full at construction time, while the coefficients of
    ``f`` are fed in one at a time through ``append``.  Each call to
    ``append`` returns the newest coefficient of ``h``.
    """

    def __init__(self, g):
        """Store ``g`` and start with an empty ``f``."""
        self.f = []
        self.g = g  # note: kept by reference, not copied
        self.h = [0] * 8
        self.n = 0

    def calc(self, l, m):
        """Add the product of ``f[l:l+m]`` and ``g[m:2*m]`` into ``h`` at offset ``l + m``."""
        # Grow h so every index written below exists (no-op when long enough).
        required = l + 3 * m - 1
        self.h += [0] * (required - len(self.h))
        block = convolve(self.f[l:l+m], self.g[m:2*m])
        acc = self.h
        start = l + m
        for offset, value in enumerate(block):
            pos = start + offset
            acc[pos] = (acc[pos] + value) % p

    def append(self, a):
        """Feed the next coefficient of ``f`` and return ``h[n-1]`` for the new length ``n``."""
        self.f.append(a)
        self.n += 1
        n = self.n
        acc = self.h
        given = self.g
        # Contributions of the new coefficient against g[0] and g[1].
        acc[n-1] = (acc[n-1] + a * given[0]) % P
        acc[n] = (acc[n] + a * given[1]) % P
        # For every power of two dividing n, fold in the matching block.
        width = 2
        while n % width == 0:
            self.calc(n - width, width)
            width *= 2
        return acc[n-1]

def r(a):
    """Return a readable form of ``a`` for debugging output.

    Values with absolute value at most 10000 are returned unchanged.
    Otherwise the first denominator ``d`` in 1..10000 for which ``a * d`` or
    ``-a * d`` reduces modulo P to at most 10000 yields the string
    ``"num/d"``.  If no such denominator exists ``a`` is returned unchanged.
    """
    if abs(a) <= 10000:
        return a
    for denom in range(1, 10001):
        positive = a * denom % P
        if positive <= 10000:
            return "/".join((str(positive), str(denom)))
        negative = -a * denom % P
        if negative <= 10000:
            return "/".join((str(-negative), str(denom)))
    return a

# Largest index available in the factorial tables below.
nn = 1001001

# fa[i] = i! modulo P, filled from the bottom up.
fa = [1] * (nn + 1)
for i in range(1, nn + 1):
    fa[i] = fa[i - 1] * i % P

# fainv[i] = inverse of i! modulo P: invert the top entry once via
# pow(x, P - 2, P), then walk downwards.
fainv = [1] * (nn + 1)
fainv[nn] = pow(fa[nn], P - 2, P)
for i in range(nn - 1, -1, -1):
    fainv[i] = fainv[i + 1] * (i + 1) % P


def C(a, b):
    """Return the binomial coefficient ``a`` choose ``b`` modulo P (0 unless 0 <= b <= a)."""
    if 0 <= b <= a:
        return fa[a] * fainv[b] % P * fainv[a - b] % P
    return 0

def calc(n, x1, y1, x2, y2, a, b, L):
    """Return the weighted sum ``sum(pp[i+1] * L[i] for i in range(n))`` modulo P.

    ``pp`` is derived from two coefficient sequences built with the
    factorial tables and powers of ``alpha = a / (2 * (a + b))`` and
    ``beta = b / (2 * (a + b))`` (all modulo P):

    * ``qq`` depends on the offsets ``x = |x1 - x2|`` and ``y = |y1 - y2|``,
    * ``ss`` is the same construction with both offsets equal to 0,

    and ``pp = qq - qq * (1 - 1 / ss)`` as truncated power series.
    NOTE(review): this reads like a first-passage computation for a
    two-dimensional walk with step weights alpha and beta -- confirm against
    the problem statement.

    Args:
        n: Number of terms of ``L`` that are used.
        x1, y1, x2, y2: Coordinates of the two points.
        a, b: Positive weights; ``2 * (a + b)`` must be invertible modulo P.
        L: Sequence of at least ``n`` multipliers.

    Returns:
        The sum described above, or 0 when ``x + y`` is odd or exceeds
        ``2 * n``.
    """
    x = abs(x1 - x2)
    y = abs(y1 - y2)
    # The offset must be reachable: even total distance, at most 2 * n.
    if x + y > 2 * n or (x + y) % 2:
        return 0

    o = (x + y) // 2
    m = n - o + 1
    iv = pow(2 * (a + b), P - 2, P)
    alpha = a * iv % P
    beta = b * iv % P
    assert (alpha + beta) * 2 % P == 1

    # Power tables for alpha and beta.
    poa = [1]
    pob = [1]
    for _ in range(n * 4 + 1):
        poa.append(poa[-1] * alpha % P)
        pob.append(pob[-1] * beta % P)

    # qq: sequence for the offset (x, y); its first o entries are zero.
    tmp1 = [fainv[x+k] * fainv[k] % P * poa[x+2*k] % P for k in range(m)]
    tmp2 = [fainv[y+l] * fainv[l] % P * pob[y+2*l] % P for l in range(m)]
    qq = ([0] * o + [fa[(o+i)*2] * c % P for i, c in enumerate(convolve(tmp1, tmp2))])[:n+1]

    # ss: the same construction with zero offset.
    tmp1 = [fainv[k] * fainv[k] % P * poa[2*k] % P for k in range(n + 1)]
    tmp2 = [fainv[l] * fainv[l] % P * pob[2*l] % P for l in range(n + 1)]
    ss = [fa[i*2] * c % P for i, c in enumerate(convolve(tmp1, tmp2))][:n+2]

    # rr = 1 - 1 / ss as a truncated power series.
    f_s = fps(ss)
    f_r = 1 - fps(1) / f_s
    rr = f_r.f

    qqrr = convolve(qq, rr)
    pp = [(u - v) % P for u, v in zip(qq, qqrr)]

    ans = 0
    for i in range(n):
        ans = (ans + pp[i+1] * L[i]) % P

    return ans


# Read the instance from standard input and print the answer.
N = int(input())
x1, y1, x2, y2 = [int(token) for token in input().split()]
a, b = [int(token) for token in input().split()]

A = list(map(int, input().split()))
print(calc(N, x1, y1, x2, y2, a, b, A))


# Check: validate the input constraints (this runs after the answer is printed).
assert 1 <= N <= 10 ** 5
assert 0 <= x1 <= 10 ** 9
assert 0 <= y1 <= 10 ** 9
assert 0 <= x2 <= 10 ** 9
assert 0 <= y2 <= 10 ** 9
assert (x1, y1) != (x2, y2)
assert 1 <= a <= 10 ** 6
assert 1 <= b <= 10 ** 6
for aa in A:
    assert 1 <= aa <= 10 ** 9
0