結果

問題 No.287 場合の数
ユーザー strangerxxx
提出日時 2022-12-20 18:31:58
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 408 ms / 5,000 ms
コード長 8,365 bytes
コンパイル時間 163 ms
コンパイル使用メモリ 82,304 KB
実行使用メモリ 81,032 KB
最終ジャッジ日時 2024-04-29 02:53:41
合計ジャッジ時間 9,742 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 324 ms
79,852 KB
testcase_01 AC 62 ms
65,280 KB
testcase_02 AC 106 ms
76,496 KB
testcase_03 AC 401 ms
80,148 KB
testcase_04 AC 405 ms
80,144 KB
testcase_05 AC 308 ms
79,724 KB
testcase_06 AC 296 ms
79,644 KB
testcase_07 AC 316 ms
79,740 KB
testcase_08 AC 315 ms
79,708 KB
testcase_09 AC 313 ms
79,864 KB
testcase_10 AC 398 ms
80,148 KB
testcase_11 AC 301 ms
79,620 KB
testcase_12 AC 408 ms
80,008 KB
testcase_13 AC 321 ms
79,132 KB
testcase_14 AC 326 ms
79,732 KB
testcase_15 AC 323 ms
79,724 KB
testcase_16 AC 308 ms
79,476 KB
testcase_17 AC 393 ms
80,016 KB
testcase_18 AC 307 ms
79,740 KB
testcase_19 AC 394 ms
81,032 KB
testcase_20 AC 308 ms
79,468 KB
testcase_21 AC 315 ms
79,856 KB
testcase_22 AC 311 ms
79,724 KB
testcase_23 AC 308 ms
79,580 KB
testcase_24 AC 314 ms
79,860 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

# Moduli used by convolution_64bit: one convolution is taken per modulus and
# the residues are recombined with crt().  Each value has the form
# k * 2**e + 1 (e = 23 for MOD1, e = 22 for the others); the NTT code also
# inverts with pow(x, mod - 2, mod), which assumes each modulus is prime.
MOD1 = 998244353
MOD2 = 985661441
MOD3 = 943718401
MOD4 = 935329793
MOD5 = 918552577


# Zero-argument accessors.  The NTT routines take the modulus as a callable
# (they call ``mod()``), so each constant gets a matching getter.
def mod1():
    return MOD1


def mod2():
    return MOD2


def mod3():
    return MOD3


def mod4():
    return MOD4


def mod5():
    return MOD5

def primitive_root(m):
    """Return the smallest g >= 2 that passes the primitive-root test modulo m.

    A handful of moduli are answered from a fixed table.  Otherwise the
    distinct prime factors q of m - 1 are collected by trial division and
    candidates g = 2, 3, ... are tried until pow(g, (m - 1) // q, m) != 1
    holds for every q.  Using m - 1 as the group order presumes m is prime.
    """
    known_roots = {
        2: 1,
        167772161: 3,
        469762049: 3,
        754974721: 11,
        998244353: 3,
    }
    if m in known_roots:
        return known_roots[m]

    order = m - 1
    # 2 is always recorded; the remaining factors come from the odd part.
    prime_factors = [2]
    rest = order // 2
    while rest % 2 == 0:
        rest //= 2
    candidate = 3
    while candidate * candidate <= rest:
        if rest % candidate == 0:
            prime_factors.append(candidate)
            while rest % candidate == 0:
                rest //= candidate
        candidate += 2
    if rest > 1:
        # Whatever is left after trial division is itself a prime factor.
        prime_factors.append(rest)

    g = 2
    # g is rejected as soon as one of the reduced powers collapses to 1.
    while any(pow(g, order // q, m) == 1 for q in prime_factors):
        g += 1
    return g

def popcount(x):
    """Return the number of set bits among the low 32 bits of x.

    Bits at position 32 and above are ignored; a negative x is treated
    through its two's-complement low word (so popcount(-1) is 32).
    """
    low_word = x & 0xFFFFFFFF
    return bin(low_word).count("1")

def tzcount(x):
    """Return the number of trailing zero bits of x, counted via popcount.

    For x > 0 the expression below is a mask of exactly the bits that sit
    under the lowest set bit of x, so its popcount is the trailing-zero count.
    The result for x == 0 is whatever popcount(-1) yields.
    """
    below_lowest_set_bit = (x - 1) & ~x
    return popcount(below_lowest_set_bit)

def build_ntt(mod):
    """Precompute the twiddle tables used by butterfly / butterfly_inv.

    Args:
        mod: zero-argument callable returning the modulus.  Inverses are
            taken with pow(x, m - 2, m), which assumes the modulus is prime.

    Returns:
        A tuple (root, iroot, rate2, irate2, rate3, irate3) where
        * root[rank2] = g ** ((m - 1) >> rank2) for g = primitive_root(m) and
          rank2 = tzcount(m - 1); each lower entry is the square of the next,
        * iroot mirrors root starting from the inverse of root[rank2],
        * rate2 / irate2 (length max(0, rank2 - 1)) and rate3 / irate3
          (length max(0, rank2 - 2)) are the per-block rotation multipliers
          consumed by the radix-2 and radix-4 butterfly steps.
    """
    # The modulus never changes during the build, so fetch it once instead
    # of calling mod() on every arithmetic line.
    m = mod()
    g = primitive_root(m)
    rank2 = tzcount(m - 1)

    root = [0] * (rank2 + 1)
    iroot = [0] * (rank2 + 1)
    root[rank2] = pow(g, (m - 1) >> rank2, m)
    iroot[rank2] = pow(root[rank2], m - 2, m)
    for i in reversed(range(rank2)):
        root[i] = root[i + 1] * root[i + 1] % m
        iroot[i] = iroot[i + 1] * iroot[i + 1] % m

    rate2 = [0] * max(0, rank2 - 1)
    irate2 = [0] * max(0, rank2 - 1)
    prod = 1
    iprod = 1
    for i in range(rank2 - 1):
        rate2[i] = root[i + 2] * prod % m
        irate2[i] = iroot[i + 2] * iprod % m
        prod = prod * iroot[i + 2] % m
        iprod = iprod * root[i + 2] % m

    rate3 = [0] * max(0, rank2 - 2)
    irate3 = [0] * max(0, rank2 - 2)
    prod = 1
    iprod = 1
    for i in range(rank2 - 2):
        rate3[i] = root[i + 3] * prod % m
        irate3[i] = iroot[i + 3] * iprod % m
        prod = prod * iroot[i + 3] % m
        iprod = iprod * root[i + 3] % m

    return root, iroot, rate2, irate2, rate3, irate3

def butterfly(a, mod, rate2, irate2, rate3, irate3, imag, iimag):
    """Apply the forward transform to ``a`` in place (nothing is returned).

    The list is processed in passes: a radix-4 step whenever at least two
    levels remain, and a single radix-2 step when exactly one level is left.

    Args:
        a: list of residues, overwritten with the transformed values.
            Assumes len(a) is a power of two (convolution() pads to one);
            other lengths index past the end of the list.
        mod: zero-argument callable returning the modulus.
        rate2, rate3: rotation tables from build_ntt for the radix-2 and
            radix-4 steps.
        imag: root[2] from build_ntt, used in the radix-4 step.
        irate2, irate3, iimag: accepted so the signature matches
            butterfly_inv; not read by the forward transform.
    """
    # mod() is loop-invariant; evaluate it once rather than several times
    # per inner-loop iteration.
    m = mod()
    n = len(a)
    h = (n - 1).bit_length()
    len_ = 0
    while len_ < h:
        blocks = 1 << len_
        if h - len_ == 1:
            # Radix-2 pass over the last remaining level.
            p = 1 << (h - len_ - 1)
            rot = 1
            for s in range(blocks):
                offset = s << (h - len_)
                for i in range(offset, offset + p):
                    lo = a[i]
                    hi = a[i + p] * rot % m
                    a[i] = (lo + hi) % m
                    a[i + p] = (lo - hi) % m
                if s + 1 != blocks:
                    # Index = position of the lowest zero bit of s.
                    rot = rot * rate2[(~s & (s + 1)).bit_length() - 1] % m
            len_ += 1
        else:
            # Radix-4 pass: consumes two levels at once.
            p = 1 << (h - len_ - 2)
            rot = 1
            for s in range(blocks):
                rot2 = rot * rot % m
                rot3 = rot2 * rot % m
                offset = s << (h - len_)
                for i in range(offset, offset + p):
                    a0 = a[i]
                    a1 = a[i + p] * rot
                    a2 = a[i + 2 * p] * rot2
                    a3 = a[i + 3 * p] * rot3
                    a1na3imag = (a1 - a3) % m * imag
                    a[i] = (a0 + a2 + a1 + a3) % m
                    a[i + p] = (a0 + a2 - a1 - a3) % m
                    a[i + 2 * p] = (a0 - a2 + a1na3imag) % m
                    a[i + 3 * p] = (a0 - a2 - a1na3imag) % m
                if s + 1 != blocks:
                    rot = rot * rate3[(~s & (s + 1)).bit_length() - 1] % m
            len_ += 2

def butterfly_inv(a, mod, rate2, irate2, rate3, irate3, imag, iimag):
    """Apply the inverse transform to ``a`` in place (nothing is returned).

    Levels are undone from the finest to the coarsest: radix-4 steps while at
    least two levels remain, then one radix-2 step if a single level is left.
    The result is NOT divided by len(a); convolution() applies that scaling.

    Args:
        a: list of residues, overwritten with the result.  Assumes len(a) is
            a power of two (convolution() pads to one).
        mod: zero-argument callable returning the modulus.
        irate2, irate3: inverse rotation tables from build_ntt.
        iimag: iroot[2] from build_ntt, used in the radix-4 step.
        rate2, rate3, imag: accepted so the signature matches butterfly;
            not read by the inverse transform.
    """
    # mod() is loop-invariant; evaluate it once rather than several times
    # per inner-loop iteration.
    m = mod()
    n = len(a)
    h = (n - 1).bit_length()
    len_ = h
    while len_:
        p = 1 << (h - len_)
        irot = 1
        if len_ == 1:
            # Radix-2 pass for the single remaining level.
            blocks = 1 << (len_ - 1)
            for s in range(blocks):
                offset = s << (h - len_ + 1)
                for i in range(offset, offset + p):
                    lo = a[i]
                    hi = a[i + p]
                    a[i] = (lo + hi) % m
                    a[i + p] = (lo - hi) * irot % m
                if s + 1 != blocks:
                    # Index = position of the lowest zero bit of s.
                    irot = irot * irate2[(~s & (s + 1)).bit_length() - 1] % m
            len_ -= 1
        else:
            # Radix-4 pass: undoes two levels at once.
            blocks = 1 << (len_ - 2)
            for s in range(blocks):
                irot2 = irot * irot % m
                irot3 = irot2 * irot % m
                offset = s << (h - len_ + 2)
                for i in range(offset, offset + p):
                    a0 = a[i]
                    a1 = a[i + p]
                    a2 = a[i + 2 * p]
                    a3 = a[i + 3 * p]
                    a2na3iimag = (a2 - a3) * iimag % m
                    a[i] = (a0 + a1 + a2 + a3) % m
                    a[i + p] = (a0 - a1 + a2na3iimag) * irot % m
                    a[i + 2 * p] = (a0 + a1 - a2 - a3) * irot2 % m
                    a[i + 3 * p] = (a0 - a1 - a2na3iimag) * irot3 % m
                if s + 1 != blocks:
                    irot = irot * irate3[(~s & (s + 1)).bit_length() - 1] % m
            len_ -= 2

def convolution(a, b, mod):
    """Return the convolution of ``a`` and ``b`` modulo ``mod()``.

    Args:
        a, b: sequences of integers (coefficient lists).
        mod: zero-argument callable returning the modulus.  The transform
            path inverts with pow(z, p - 2, p), which assumes a prime
            modulus whose tables from build_ntt cover the padded length.

    Returns:
        A new list of length len(a) + len(b) - 1, or [] if either input is
        empty.  The inputs are left unmodified.
    """
    n = len(a)
    m = len(b)
    if not n or not m:
        return []
    p = mod()

    if min(n, m) <= 100:
        # Schoolbook product when one side is short.  No NTT tables are
        # needed, so they are not built on this path.
        if n < m:
            n, m = m, n
            a, b = b, a
        res = [0] * (n + m - 1)
        for i, ai in enumerate(a):
            for j, bj in enumerate(b):
                res[i + j] = (res[i + j] + ai * bj) % p
        return res

    root, iroot, rate2, irate2, rate3, irate3 = build_ntt(mod)
    imag = root[2]
    iimag = iroot[2]

    # Smallest power of two that can hold all n + m - 1 output terms.
    z = 1 << (n + m - 2).bit_length()
    # Pad private copies: the butterflies overwrite their argument, and the
    # caller's lists must not be extended or clobbered.
    fa = list(a) + [0] * (z - n)
    fb = list(b) + [0] * (z - m)
    butterfly(fa, mod, rate2, irate2, rate3, irate3, imag, iimag)
    butterfly(fb, mod, rate2, irate2, rate3, irate3, imag, iimag)
    for i in range(z):
        fa[i] = fa[i] * fb[i] % p
    butterfly_inv(fa, mod, rate2, irate2, rate3, irate3, imag, iimag)

    # butterfly_inv leaves a factor of z in every entry; divide it out.
    iz = pow(z, p - 2, p)
    return [v * iz % p for v in fa[:n + m - 1]]

def inv_gcd(a, b):
    """Extended Euclid: return (g, x) with g = gcd(a, b) and a * x = g (mod b).

    ``a`` is first reduced modulo ``b``.  When that residue is 0 the result
    is (b, 0); otherwise x is normalised into the range [0, b // g).
    """
    a %= b
    if a == 0:
        return b, 0

    # Invariant kept by the loop: s = m0 * a and t = m1 * a (mod b).
    s, t = b, a
    m0, m1 = 0, 1
    while t:
        quotient, remainder = divmod(s, t)
        s, t = t, remainder
        m0, m1 = m1, m0 - m1 * quotient

    if m0 < 0:
        m0 += b // s
    return s, m0

def gcd(x, y):
    """Euclidean algorithm: replace (x, y) by (y, x % y) until y is 0.

    Returns the final x.  No sign normalisation is applied, so the sign of
    the result follows Python's ``%`` semantics for negative arguments.
    """
    while y != 0:
        remainder = x % y
        x = y
        y = remainder
    return x

def crt(r, m):
    """Chinese remainder theorem over parallel residue / modulus sequences.

    Folds the congruences x = r[i] (mod m[i]) into one running congruence
    x = r0 (mod m0) and returns (r0, m0).  Returns (0, 0) as soon as two
    congruences contradict each other; with no congruences the result is
    (0, 1).  Asserts that both sequences have equal length and that every
    modulus is at least 1.
    """
    assert len(r) == len(m)
    r0, m0 = 0, 1
    for residue, modulus in zip(r, m):
        assert 1 <= modulus
        r1, m1 = residue % modulus, modulus
        # Keep the larger modulus on the accumulated side.
        if m0 < m1:
            r0, r1 = r1, r0
            m0, m1 = m1, m0
        if m0 % m1 == 0:
            # The new congruence is implied by, or contradicts, the old one.
            if r0 % m1 != r1:
                return 0, 0
            continue
        g, im = inv_gcd(m0, m1)
        diff = r1 - r0
        if diff % g:
            return 0, 0
        u1 = m1 // g
        x = diff // g * im % u1
        r0 += x * m0
        m0 *= u1
        if r0 < 0:
            r0 += m0
    return r0, m0

def convolution_64bit(a, b):
    """Convolve two integer sequences and return each term reduced to 64 bits.

    The product is computed separately modulo MOD1..MOD5 with convolution(),
    each output position is recombined with crt(), and the recombined value
    is masked to its low 64 bits.

    Args:
        a, b: sequences of integers.

    Returns:
        A list of length len(a) + len(b) - 1, or [] if either input is empty
        (matching what convolution() returns for empty input).
    """
    n = len(a)
    m = len(b)
    if not n or not m:
        # Without this guard a single empty operand produced a list of
        # zeros of length n + m - 1 instead of an empty result.
        return []

    mask = (1 << 64) - 1
    mods = (MOD1, MOD2, MOD3, MOD4, MOD5)
    mod_getters = (mod1, mod2, mod3, mod4, mod5)

    # One residue convolution per modulus; inputs are reduced into fresh
    # lists for each call.
    per_mod = [
        convolution([v % p for v in a], [v % p for v in b], getter)
        for p, getter in zip(mods, mod_getters)
    ]

    # zip(*per_mod) yields, for every output index, the five residues.
    return [crt(residues, mods)[0] & mask for residues in zip(*per_mod)]

# Driver: read n and print entry 6n of the 8-fold product of the sequence
# x = (1 repeated n + 1 times, then zeros), as computed by convolution_64bit.
n = int(input())
x = [1] * (n + 1)
x.extend([0] * (5 * n))
ans = list(x)
# ans starts as x, so seven multiplications by x give the eighth power.
for _ in range(7):
    ans = convolution_64bit(ans, list(x))
    # Only indices up to 6n are needed; drop the rest before the next round.
    del ans[6 * n + 1:]
print(ans[6 * n])
0