結果

問題 No.754 畳み込みの和
ユーザー mkawa2mkawa2
提出日時 2021-12-03 11:49:14
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 888 ms / 5,000 ms
コード長 2,839 bytes
コンパイル時間 227 ms
コンパイル使用メモリ 82,176 KB
実行使用メモリ 105,664 KB
最終ジャッジ日時 2024-07-05 15:53:36
合計ジャッジ時間 4,404 ms
ジャッジサーバーID
(参考情報)
judge1 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 883 ms
105,536 KB
testcase_01 AC 888 ms
105,664 KB
testcase_02 AC 882 ms
105,544 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

sys.setrecursionlimit(200005)
int1 = lambda x: int(x)-1
p2D = lambda x: print(*x, sep="\n")
def II(): return int(sys.stdin.readline())
def LI(): return list(map(int, sys.stdin.readline().split()))
def LLI(rows_number): return [LI() for _ in range(rows_number)]
def LI1(): return list(map(int1, sys.stdin.readline().split()))
def LLI1(rows_number): return [LI1() for _ in range(rows_number)]
def SI(): return sys.stdin.readline().rstrip()
dij = [(0, 1), (-1, 0), (0, -1), (1, 0)]
# dij = [(0, 1), (-1, 0), (0, -1), (1, 0), (1, 1), (1, -1), (-1, 1), (-1, -1)]
# inf = 18446744073709551615
inf = 4294967295
md = 10**9+7
# md = 998244353

def arbitrary_mod_convolve(a, b, mod):
    MOD1 = lambda: 167772161
    MOD2 = lambda: 469762049
    MOD3 = lambda: 1224736769
    ROOT1 = lambda: 3

    def _ntt(a, h, MOD, ROOT):
        roots = [pow(ROOT(), (MOD()-1) >> i, MOD()) for i in range(h+1)]
        for i in range(h):
            m = 1 << (h-i-1)
            for j in range(1 << i):
                w = 1
                j *= 2*m
                for k in range(m):
                    a[j+k], a[j+k+m] = (a[j+k]+a[j+k+m])%MOD(), (a[j+k]-a[j+k+m])*w%MOD()
                    w *= roots[h-i]
                    w %= MOD()

    def _intt(a, h, MOD, ROOT):
        roots = [pow(ROOT(), (MOD()-1) >> i, MOD()) for i in range(h+1)]
        iroots = [pow(r, MOD()-2, MOD()) for r in roots]
        for i in range(h):
            m = 1 << i
            for j in range(1 << (h-i-1)):
                w = 1
                j *= 2*m
                for k in range(m):
                    a[j+k], a[j+k+m] = (a[j+k]+a[j+k+m]*w)%MOD(), (a[j+k]-a[j+k+m]*w)%MOD()
                    w *= iroots[i+1]
                    w %= MOD()
        inv = pow(1 << h, MOD()-2, MOD())
        for i in range(1 << h):
            a[i] *= inv
            a[i] %= MOD()

    def ntt_convolve(a, b, MOD, ROOT):
        n = 1 << (len(a)+len(b)-1).bit_length()
        h = n.bit_length()-1
        a = list(a)+[0]*(n-len(a))
        b = list(b)+[0]*(n-len(b))

        _ntt(a, h, MOD, ROOT), _ntt(b, h, MOD, ROOT)
        a = [va*vb%MOD() for va, vb in zip(a, b)]
        _intt(a, h, MOD, ROOT)
        return a

    x = ntt_convolve(a, b, MOD1, ROOT1)
    y = ntt_convolve(a, b, MOD2, ROOT1)
    z = ntt_convolve(a, b, MOD3, ROOT1)

    inv1_2 = pow(MOD1(), MOD2()-2, MOD2())
    inv12_3 = pow(MOD1()*MOD2(), MOD3()-2, MOD3())
    mod12 = MOD1()*MOD2()%mod

    res = [0]*len(x)
    for i in range(len(x)):
        v1 = (y[i]-x[i])*inv1_2%MOD2()
        v2 = (z[i]-(x[i]+MOD1()*v1)%MOD3())*inv12_3%MOD3()
        res[i] = (x[i]+MOD1()*v1+mod12*v2)%mod
    return res[:len(a)+len(b)-1]

n = II()
a = [II() for _ in range(n+1)]
b = [II() for _ in range(n+1)]
c = arbitrary_mod_convolve(a, b, md)[:n+1]
ans = 0
for k in c:
    ans += k
    ans %= md
print(ans)
0