結果

問題 No.2506 Sum of Weighted Powers
ユーザー suisensuisen
提出日時 2023-03-31 21:30:05
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 1,267 ms / 2,000 ms
コード長 6,251 bytes
コンパイル時間 1,405 ms
コンパイル使用メモリ 86,344 KB
実行使用メモリ 159,252 KB
最終ジャッジ日時 2023-10-13 18:23:18
合計ジャッジ時間 18,914 ms
ジャッジサーバーID
(参考情報)
judge11 / judge12
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 165 ms
79,784 KB
testcase_01 AC 167 ms
79,928 KB
testcase_02 AC 166 ms
79,792 KB
testcase_03 AC 209 ms
83,240 KB
testcase_04 AC 170 ms
80,688 KB
testcase_05 AC 166 ms
79,780 KB
testcase_06 AC 208 ms
83,208 KB
testcase_07 AC 201 ms
82,892 KB
testcase_08 AC 211 ms
82,992 KB
testcase_09 AC 202 ms
82,840 KB
testcase_10 AC 210 ms
83,320 KB
testcase_11 AC 209 ms
83,464 KB
testcase_12 AC 166 ms
80,004 KB
testcase_13 AC 188 ms
82,792 KB
testcase_14 AC 170 ms
79,924 KB
testcase_15 AC 170 ms
80,744 KB
testcase_16 AC 188 ms
83,020 KB
testcase_17 AC 203 ms
82,856 KB
testcase_18 AC 203 ms
83,280 KB
testcase_19 AC 188 ms
82,304 KB
testcase_20 AC 206 ms
82,888 KB
testcase_21 AC 217 ms
83,276 KB
testcase_22 AC 205 ms
83,288 KB
testcase_23 AC 821 ms
121,096 KB
testcase_24 AC 1,053 ms
149,288 KB
testcase_25 AC 487 ms
109,272 KB
testcase_26 AC 1,131 ms
159,252 KB
testcase_27 AC 295 ms
86,204 KB
testcase_28 AC 1,253 ms
153,524 KB
testcase_29 AC 1,258 ms
153,352 KB
testcase_30 AC 1,267 ms
151,896 KB
testcase_31 AC 1,257 ms
153,988 KB
testcase_32 AC 1,253 ms
153,484 KB
testcase_33 AC 169 ms
79,888 KB
testcase_34 AC 167 ms
79,708 KB
testcase_35 AC 248 ms
120,944 KB
testcase_36 AC 279 ms
131,008 KB
testcase_37 AC 167 ms
80,044 KB
testcase_38 AC 169 ms
79,760 KB
testcase_39 AC 169 ms
80,040 KB
testcase_40 AC 164 ms
79,784 KB
testcase_41 AC 167 ms
79,924 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

from typing import List

def bsf(x):
    """Return the index of the lowest set bit of ``x`` (count of trailing zeros).

    Negative integers are handled with Python's two's-complement semantics,
    so ``bsf(~s)`` yields the number of trailing one bits of ``s``.

    Raises:
        ValueError: if ``x`` is 0, which has no set bit (the previous loop
            implementation never terminated on this input).
    """
    if x == 0:
        raise ValueError("bsf() is undefined for 0")
    # x & -x isolates the lowest set bit; its bit_length is that index + 1.
    return (x & -x).bit_length() - 1

# Modulus used for all arithmetic in this file.
# P - 1 = 225 * 2**22, so every power of two up to 2**22 divides P - 1.
P = 943718401
# Base whose powers form the roots-of-unity tables built by NTT
# (presumably a primitive root modulo P -- TODO confirm).
G = 7
# Number of trailing zero bits of P - 1 (22 for this P).
rank2 = bsf(P - 1)

class NTT:
    """Number-theoretic transform and convolution modulo the module constant ``P``.

    All tables are built once, at class-definition time, from the module-level
    ``P``, ``G`` and ``rank2``.  The public entry point is :meth:`convolution`;
    :meth:`butterfly` and :meth:`butterfly_inv` are the in-place transforms it
    is built on.
    """

    class __RootInitializer:
        """Builders for the power tables of ``G`` and of its modular inverse."""

        @staticmethod
        def root():
            """Return ``[G ** ((P - 1) >> i) % P for i in range(rank2 + 1)]``."""
            return [pow(G, (P - 1) >> i, P) for i in range(0, rank2 + 1)]

        @staticmethod
        def iroot():
            """Return the same table as ``root()`` built from ``G ** (P - 2) % P``."""
            return [pow(pow(G, P - 2, P), (P - 1) >> i, P) for i in range(0, rank2 + 1)]

    # root[i] / iroot[i]: G (resp. its inverse) raised to (P - 1) >> i, mod P.
    root = __RootInitializer.root()
    iroot = __RootInitializer.iroot()

    class __RateInitializer:
        """Builder for the incremental twiddle-factor tables."""

        @staticmethod
        def rates(root: List[int], iroot: List[int]):
            """Return ``(rate2, irate2, rate3, irate3)`` derived from the root tables.

            The butterflies multiply their running twiddle factor by
            ``rate*[bsf(~s)]`` when moving from block ``s`` to block ``s + 1``;
            ``rate2``/``irate2`` serve the radix-2 passes and
            ``rate3``/``irate3`` the radix-4 passes.
            """
            rate2 = [0] * max(0, rank2 - 1)
            irate2 = [0] * max(0, rank2 - 1)
            prod = iprod = 1
            for i in range(rank2 - 1):
                rate2[i] = root[i + 2] * prod % P
                irate2[i] = iroot[i + 2] * iprod % P
                prod = prod * iroot[i + 2] % P
                iprod = iprod * root[i + 2] % P

            rate3 = [0] * max(0, rank2 - 2)
            irate3 = [0] * max(0, rank2 - 2)
            prod = iprod = 1
            for i in range(rank2 - 2):
                rate3[i] = root[i + 3] * prod % P
                irate3[i] = iroot[i + 3] * iprod % P
                prod = prod * iroot[i + 3] % P
                iprod = iprod * root[i + 3] % P
            return rate2, irate2, rate3, irate3

    # Reuse the tables computed above instead of rebuilding them.
    rate2, irate2, rate3, irate3 = __RateInitializer.rates(root, iroot)

    @staticmethod
    def butterfly(a: List[int]) -> None:
        """Apply the forward transform to ``a`` in place.

        ``h`` is taken as the number of trailing zero bits of ``len(a)``, so
        the list length is expected to be a power of two.  Passes are radix-4
        while at least two levels remain and radix-2 for a final odd level.
        """
        n = len(a)
        h = bsf(n)
        l = 0  # number of levels already processed
        while l < h:
            if h - l == 1:
                # Single remaining level: radix-2 pass.
                p = 1 << (h - l - 1)
                rot = 1
                for s in range(1 << l):
                    offset = s << (h - l)
                    for i in range(p):
                        u = a[i + offset]
                        v = a[i + offset + p] * rot
                        a[i + offset] = (u + v) % P
                        a[i + offset + p] = (u - v) % P
                    if s + 1 != 1 << l:
                        rot = rot * NTT.rate2[bsf(~s)] % P
                l += 1
            else:
                # Two levels at once: radix-4 pass.
                p = 1 << (h - l - 2)
                rot, imag = 1, NTT.root[2]
                for s in range(1 << l):
                    rot2 = rot * rot % P
                    rot3 = rot2 * rot % P
                    offset = s << (h - l)
                    for i in range(p):
                        # Products are left unreduced; the final "% P" on each
                        # store brings the results back into range.
                        a0 = a[i + offset]
                        a1 = a[i + offset + p] * rot
                        a2 = a[i + offset + 2 * p] * rot2
                        a3 = a[i + offset + 3 * p] * rot3
                        a1na3imag = (a1 - a3) % P * imag
                        a[i + offset] = (a0 + a2 + a1 + a3) % P
                        a[i + offset + 1 * p] = (a0 + a2 - a1 - a3) % P
                        a[i + offset + 2 * p] = (a0 - a2 + a1na3imag) % P
                        a[i + offset + 3 * p] = (a0 - a2 - a1na3imag) % P
                    if s + 1 != (1 << l):
                        rot = rot * NTT.rate3[bsf(~s)] % P
                l += 2

    @staticmethod
    def butterfly_inv(a: List[int]) -> None:
        """Apply the inverse transform to ``a`` in place, without normalising.

        No division by ``len(a)`` happens here; :meth:`convolution` folds the
        factor ``len(a) ** -1`` into its pointwise product instead.
        """
        n = len(a)
        h = bsf(n)

        l = h  # number of levels still to undo
        while l:
            if l == 1:
                # Single remaining level: radix-2 pass.
                p = 1 << (h - l)
                irot = 1
                for s in range(1 << (l - 1)):
                    offset = s << (h - l + 1)
                    for i in range(p):
                        u = a[i + offset]
                        v = a[i + offset + p]
                        a[i + offset] = (u + v) % P
                        a[i + offset + p] = ((u - v) * irot) % P
                    if s + 1 != 1 << (l - 1):
                        irot = irot * NTT.irate2[bsf(~s)] % P
                l -= 1
            else:
                # Two levels at once: radix-4 pass.
                p = 1 << (h - l)
                irot = 1
                iimag = NTT.iroot[2]
                for s in range(1 << (l - 2)):
                    irot2 = irot * irot % P
                    irot3 = irot2 * irot % P
                    offset = s << (h - l + 2)
                    for i in range(p):
                        a0 = a[i + offset]
                        a1 = a[i + offset + p]
                        a2 = a[i + offset + 2 * p]
                        a3 = a[i + offset + 3 * p]

                        a2na3iimag = (a2 - a3) * iimag % P

                        a[i + offset] = (a0 + a1 + a2 + a3) % P
                        a[i + offset + p] = ((a0 - a1 + a2na3iimag) * irot) % P
                        a[i + offset + 2 * p] = ((a0 + a1 - a2 - a3) * irot2) % P
                        a[i + offset + 3 * p] = ((a0 - a1 - a2na3iimag) * irot3) % P
                    if s + 1 != 1 << (l - 2):
                        irot = irot * NTT.irate3[bsf(~s)] % P
                l -= 2

    @staticmethod
    def convolution(a, b):
        """Return the convolution of ``a`` and ``b`` modulo ``P``.

        The result has ``len(a) + len(b) - 1`` entries, or is ``[]`` when
        either input is empty.  The inputs are never modified, so the same
        list may be passed for both arguments.

        Raises:
            ValueError: if the required power-of-two transform length does not
                divide ``P - 1`` (the root tables cannot support it).
        """
        n = len(a)
        m = len(b)
        if not a or not b:
            return []
        if min(n, m) <= 40:
            # Schoolbook multiplication is used when one operand is short.
            if n < m:
                n, m = m, n
                a, b = b, a
            res = [0] * (n + m - 1)
            for i in range(n):
                for j in range(m):
                    res[i + j] += a[i] * b[j]
                    res[i + j] %= P
            return res

        # Smallest power of two that is >= n + m - 1.
        z = 1 << (n + m - 2).bit_length()
        if (P - 1) % z:
            raise ValueError("convolution length %d is too large for modulus %d" % (z, P))

        iz = pow(z, P - 2, P)
        # Work on zero-padded copies: the butterflies overwrite their argument.
        fa = list(a) + [0] * (z - n)
        fb = list(b) + [0] * (z - m)
        NTT.butterfly(fa)
        NTT.butterfly(fb)
        # Scale by z ** -1 here because butterfly_inv does not normalise.
        c = [fa[i] * fb[i] % P * iz % P for i in range(z)]
        NTT.butterfly_inv(c)
        return c[:n + m - 1]

# Reads n, x and three integer sequences A, B, C (indices 0..n are used), then
# prints
#     sum over i, j >= 0 with i + j <= n of A[i+j] * B[i] * C[j] * x**(i*j*(i+j))
# modulo P.
n, x = map(int, input().split())
A = list(map(int, input().split()))
B = list(map(int, input().split()))
C = list(map(int, input().split()))

# Only x modulo P matters.  Reducing first also routes a non-zero multiple of
# P into the zero branch instead of attempting to invert 0.
x %= P

if x == 0:
    # Only terms with exponent 0 (i == 0 or j == 0) survive, with 0**0 taken
    # as 1; the second sum starts at 1 so the (0, 0) term is counted once.
    ans = sum(A[i] * B[i] % P * C[0] % P for i in range(n + 1)) + sum(A[i] * B[0] % P * C[i] % P for i in range(1, n + 1))
    print(ans % P)
else:
    def t(k):
        """Return (k - 1) * k * (k + 1) / 3, an integer for every integer k."""
        return (k - 1) * k * (k + 1) // 3

    # t(i + j) - t(i) - t(j) == i * j * (i + j), so the exponent splits into a
    # factor depending on i + j and factors depending on i and j alone; the
    # double sum then becomes a single convolution.
    inv_x = pow(x, P - 2, P)
    inv_pows = [pow(inv_x, t(i), P) for i in range(n + 1)]

    # Named so that the module-level constant G is not overwritten.
    weighted_b = [B[i] * inv_pows[i] % P for i in range(n + 1)]
    weighted_c = [C[i] * inv_pows[i] % P for i in range(n + 1)]

    H = NTT.convolution(weighted_b, weighted_c)

    ans = sum(A[i] * pow(x, t(i), P) % P * H[i] % P for i in range(n + 1))

    print(ans % P)
0