結果

問題 No.2129 Perfect Binary Tree...?
ユーザー 遭難者
提出日時 2022-10-27 21:26:36
言語 PyPy3 (7.3.15)
結果 TLE
実行時間 -
コード長 7,807 bytes
コンパイル時間 194 ms
コンパイル使用メモリ 82,244 KB
実行使用メモリ 68,992 KB
最終ジャッジ日時 2024-07-08 02:29:18
合計ジャッジ時間 5,108 ms
ジャッジサーバーID (参考情報) judge1 / judge2
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 47 ms 62,476 KB
testcase_01 AC 46 ms 54,144 KB
testcase_02 AC 47 ms 54,656 KB
testcase_03 AC 72 ms 67,968 KB
testcase_04 AC 66 ms 68,992 KB
testcase_05 TLE -
testcase_06 -- -
testcase_07 -- -
testcase_08 -- -
testcase_09 -- -
testcase_10 -- -
testcase_11 -- -
testcase_12 -- -
testcase_13 -- -
testcase_14 -- -
testcase_15 -- -
testcase_16 -- -
testcase_17 -- -
testcase_18 -- -
testcase_19 -- -
testcase_20 -- -
testcase_21 -- -
testcase_22 -- -
testcase_23 -- -
testcase_24 -- -
testcase_25 -- -
testcase_26 -- -
testcase_27 -- -
testcase_28 -- -
testcase_29 -- -
testcase_30 -- -
testcase_31 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #

# from : https://judge.yosupo.jp/submission/55648
import atexit
from operator import mod
import __pypy__
import sys
import os
# Modulus for every arithmetic step in this file; the transform tables below
# are hard-coded for this exact value, so they must change together with it.
MOD = 998244353
# Multiplier used by the radix-4 steps of butterfly (IMAG) and butterfly_inv
# (IIMAG). IMAG equals rate2[1] and IIMAG equals irate2[1]. Presumably a
# primitive 4th root of unity mod MOD and its inverse — TODO confirm.
IMAG = 911660635
IIMAG = 86583718
# Twiddle-factor tables. The butterflies index them with
# (~s & -~s).bit_length(), i.e. 1 + the position of the lowest zero bit of s,
# so index 0 is never read. rate2/irate2 drive the radix-2 steps and
# rate3/irate3 the radix-4 steps (forward / inverse respectively). The values
# presumably come from primitive root 3 of MOD, as in AtCoder Library's
# convolution — TODO confirm before regenerating them.
rate2 = (0, 911660635, 509520358, 369330050, 332049552, 983190778, 123842337, 238493703, 975955924, 603855026, 856644456, 131300601,
         842657263, 730768835, 942482514, 806263778, 151565301, 510815449, 503497456, 743006876, 741047443, 56250497, 867605899, 0)
irate2 = (0, 86583718, 372528824, 373294451, 645684063, 112220581, 692852209, 155456985, 797128860, 90816748, 860285882, 927414960,
          354738543, 109331171, 293255632, 535113200, 308540755, 121186627, 608385704, 438932459, 359477183, 824071951, 103369235, 0)
rate3 = (0, 372528824, 337190230, 454590761, 816400692, 578227951, 180142363, 83780245, 6597683, 70046822, 623238099,
         183021267, 402682409, 631680428, 344509872, 689220186, 365017329, 774342554, 729444058, 102986190, 128751033, 395565204, 0)
irate3 = (0, 509520358, 929031873, 170256584, 839780419, 282974284, 395914482, 444904435, 72135471, 638914820, 66769500,
          771127074, 985925487, 262319669, 262341272, 625870173, 768022760, 859816005, 914661783, 430819711, 272774365, 530924681, 0)


def butterfly(a):
    """Apply the forward number-theoretic transform to *a* in place (mod MOD).

    Adapted from the yosupo submission cited above the imports. The h levels
    of the transform are processed two at a time with radix-4 steps, with one
    final radix-2 step when h is odd. len(a) is assumed to be a power of two
    (multiply pads its inputs to one). The output order is the one
    butterfly_inv expects — presumably bit-reversed, as in AtCoder Library's
    convolution; TODO confirm. Every entry written is reduced into [0, MOD).
    Returns None.
    """
    n = len(a)
    h = (n - 1).bit_length()  # number of levels: log2(n) for power-of-two n
    le = 0  # levels already processed
    while le < h:
        if h - le == 1:
            # Exactly one level left: radix-2 step on blocks of size 2 * p.
            p = 1 << (h - le - 1)
            rot = 1  # twiddle factor for the current block
            for s in range(1 << le):
                offset = s << (h - le)  # start of the s-th block
                for i in range(p):
                    l = a[i + offset]
                    r = a[i + offset + p] * rot
                    a[i + offset] = (l + r) % MOD
                    a[i + offset + p] = (l - r) % MOD
                # Advance to the next block's twiddle; the table index is
                # 1 + the position of the lowest zero bit of s.
                rot *= rate2[(~s & -~s).bit_length()]
                rot %= MOD
            le += 1
        else:
            # Two or more levels left: radix-4 step on blocks of size 4 * p.
            p = 1 << (h - le - 2)
            rot = 1
            for s in range(1 << le):
                rot2 = rot * rot % MOD  # twiddle for the 3rd quarter
                rot3 = rot2 * rot % MOD  # twiddle for the 4th quarter
                offset = s << (h - le)
                for i in range(p):
                    a0 = a[i + offset]
                    a1 = a[i + offset + p] * rot
                    a2 = a[i + offset + p * 2] * rot2
                    a3 = a[i + offset + p * 3] * rot3
                    # (a1 - a3) * IMAG is shared by the last two outputs.
                    a1na3imag = (a1 - a3) % MOD * IMAG
                    a[i + offset] = (a0 + a2 + a1 + a3) % MOD
                    a[i + offset + p] = (a0 + a2 - a1 - a3) % MOD
                    a[i + offset + p * 2] = (a0 - a2 + a1na3imag) % MOD
                    a[i + offset + p * 3] = (a0 - a2 - a1na3imag) % MOD
                rot *= rate3[(~s & -~s).bit_length()]
                rot %= MOD
            le += 2


def butterfly_inv(a):
    """Undo butterfly on *a* in place, without the 1/len(a) scaling (mod MOD).

    Walks the h levels from the last back to the first, two at a time with
    radix-4 steps, plus one radix-2 step when h is odd. The result is NOT
    divided by len(a); multiply applies pow(z, MOD - 2, MOD) afterwards.
    len(a) is assumed to be a power of two. Every entry written is reduced
    into [0, MOD). Returns None.
    """
    n = len(a)
    h = (n - 1).bit_length()  # number of levels: log2(n) for power-of-two n
    le = h  # levels still to undo
    while le:
        if le == 1:
            # Exactly one level left: inverse radix-2 step.
            p = 1 << (h - le)
            irot = 1  # inverse twiddle for the current block
            for s in range(1 << (le - 1)):
                offset = s << (h - le + 1)  # start of the s-th block
                for i in range(p):
                    l = a[i + offset]
                    r = a[i + offset + p]
                    a[i + offset] = (l + r) % MOD
                    a[i + offset + p] = (l - r) * irot % MOD
                # Next block's inverse twiddle; index is 1 + the position of
                # the lowest zero bit of s.
                irot *= irate2[(~s & -~s).bit_length()]
                irot %= MOD
            le -= 1
        else:
            # Two or more levels left: inverse radix-4 step.
            p = 1 << (h - le)
            irot = 1
            for s in range(1 << (le - 2)):
                irot2 = irot * irot % MOD
                irot3 = irot2 * irot % MOD
                offset = s << (h - le + 2)
                for i in range(p):
                    a0 = a[i + offset]
                    a1 = a[i + offset + p]
                    a2 = a[i + offset + p * 2]
                    a3 = a[i + offset + p * 3]
                    # (a2 - a3) * IIMAG is shared by the 2nd and 4th outputs.
                    a2na3iimag = (a2 - a3) * IIMAG % MOD
                    a[i + offset] = (a0 + a1 + a2 + a3) % MOD
                    a[i + offset + p] = (a0 - a1 + a2na3iimag) * irot % MOD
                    a[i + offset + p * 2] = (a0 + a1 - a2 - a3) * irot2 % MOD
                    a[i + offset + p * 3] = (a0 - a1 -
                                             a2na3iimag) * irot3 % MOD
                irot *= irate3[(~s & -~s).bit_length()]
                irot %= MOD
            le -= 2


def multiply(s, t):
    """Return the product of coefficient lists *s* and *t*, modulo MOD.

    The result has len(s) + len(t) - 1 coefficients (an empty list when that
    is not positive), each reduced into [0, MOD). If either input has at most
    60 coefficients a schoolbook product is used. Otherwise both inputs are
    zero-padded to a power of two, transformed with butterfly, multiplied
    pointwise, transformed back with butterfly_inv and scaled by the inverse
    of the padded length.
    """
    len_s, len_t = len(s), len(t)
    out_len = len_s + len_t - 1
    if min(len_s, len_t) <= 60:
        acc = [0] * out_len
        for i, si in enumerate(s):
            if i % 8:
                for j, tj in enumerate(t):
                    acc[i + j] += si * tj
            else:
                # Reduce on every 8th row only, so running sums stay bounded
                # without paying for a modulo on every addition.
                for j, tj in enumerate(t):
                    acc[i + j] = (acc[i + j] + si * tj) % MOD
        return [x % MOD for x in acc]
    size = 1 << (out_len - 1).bit_length()
    fs = s.copy() + [0] * (size - len_s)
    ft = t.copy() + [0] * (size - len_t)
    butterfly(fs)
    butterfly(ft)
    prod = [x * y % MOD for x, y in zip(fs, ft)]
    butterfly_inv(prod)
    inv_size = pow(size, MOD - 2, MOD)
    return [x * inv_size % MOD for x in prod[:out_len]]


# ---------------------------------------------------------------------------
# Main program: read n and two node labels u, v; print the answer mod MOD.
#
# NOTE(review): the meaning of the arrays below is reconstructed from the
# formulas, not from the problem statement — confirm before relying on it.
# The path u -> lca -> v (m nodes) is treated as a ring. Path node i owns the
# part of the tree hanging off it, and
#   a[i] = number of nodes in that part (including node i),
#   b[i] = sum of distances from node i to the nodes of that part,
#   c[i] = sum of distances over unordered node pairs inside that part.
# The printed value is
#   sum(c) + sum_{i != j} b[i] * a[j]
#          + sum_{i < j} a[i] * a[j] * min(j - i, m - (j - i))   (mod MOD).
# ---------------------------------------------------------------------------
n = int(input())
nn = n << 1
# two[i] = 2**i mod MOD for 0 <= i <= 2n.
two = [1] * (nn + 1)
for i in range(nn):
    two[i + 1] = two[i] * 2 % MOD
# Closed forms for a perfect binary tree with i levels (2**i - 1 nodes):
#   g[i] = (i - 2) * 2**i + 2 = sum_{t < i} t * 2**t, the sum of node depths;
#   f[i] = (i + 3) * 2**i + (i - 3) * 4**i, which matches the sum of distances
#          over unordered node pairs for small i (f[1] = 0, f[2] = 4, f[3] = 48).
f = [0] * (n + 1)
g = [0] * (n + 1)
for i in range(1, n + 1):
    f[i] = ((i + 3) * two[i] + (i - 3) * two[i << 1]) % MOD
    g[i] = ((i - 2) * two[i] + 2) % MOD
u, v = input(), input()
if u == v:
    print(f[n])
    sys.exit()

# lca = length of the common prefix of the labels u and v.
lca = 0
for cu, cv in zip(u, v):
    if cu != cv:
        break
    lca += 1

# p[i] = depth of the i-th node on the path u -> lca -> v. A label of length L
# is taken to be at depth L - 1 (inferred from the code — TODO confirm).
# Depths fall from len(u) - 1 to lca - 1 at index lca_cnt, then rise to
# len(v) - 1.
m = len(u) + len(v) - 2 * lca + 1
lca_cnt = len(u) - lca
p = [len(u) - 1 - i for i in range(lca_cnt + 1)]
p += [lca - 1 + i for i in range(1, m - lca_cnt)]

a = [0] * m
b = [0] * m
c = [0] * m

# Endpoints u and v own their whole subtree (n - depth levels).
for i in (0, m - 1):
    h = n - p[i]
    a[i] = two[h] - 1
    b[i] = g[h]
    c[i] = f[h]

# The LCA owns its ancestors and their off-path child subtrees. If u or v is
# itself the LCA, it also owns its own off-path child subtree.
# k must run over 1..d; the former `while k <= d` loops never advanced k, so
# they never terminated whenever the LCA was below the root (the TLE).
d = p[lca_cnt]
lca_at_end = lca_cnt == 0 or lca_cnt == m - 1
hd = n - d - 1  # levels of the LCA's own off-path subtree (used if lca_at_end)

sa = 1
for k in range(1, d + 1):
    sa += two[n - k]
if lca_at_end:
    sa += two[hd] - 1
sa %= MOD

sb = 0
for k in range(1, d + 1):
    # Summing k over 1..d equals summing the ancestor distances d + 1 - k.
    sb += g[n - k] + (two[n - k] - 1) * (d + 2 - k) % MOD + k
if lca_at_end:
    sb += g[hd] + two[hd] - 1
sb %= MOD

sc = 0
for k in range(1, d + 1):
    sub = two[n - k] - 1  # size of the subtree with n - k levels
    rest = sa - sub  # LCA-part nodes outside that subtree
    above = two[n] - two[n - k]  # nodes outside a depth-k node's subtree
    sc += f[n - k]
    sc += g[n - k] * rest % MOD
    sc += sub * rest % MOD
    sc += above * (sa - above) % MOD
    sc %= MOD
if lca_at_end:
    sub = two[hd] - 1
    rest = sa - sub
    sc += f[hd]
    sc += g[hd] * rest % MOD
    sc += sub * rest % MOD
    sc %= MOD

a[lca_cnt], b[lca_cnt], c[lca_cnt] = sa, sb, sc

# Interior path nodes: the node itself plus its off-path child subtree.
for i in range(1, m - 1):
    if i == lca_cnt:
        continue
    h = n - 1 - p[i]
    a[i] = two[h]
    b[i] = (g[h] + two[h] - 1) % MOD
    c[i] = (f[h] + g[h] + two[h] - 1) % MOD

sum_a = sum(a) % MOD
sum_b = sum(b) % MOD
sum_c = sum(c) % MOD
sum_ab = sum(x * y % MOD for x, y in zip(a, b)) % MOD
ans = (sum_c + sum_a * sum_b - sum_ab) % MOD

# tt[m - 1 - j] = sum_i a[i] * a[i + j] and tt[2m - 1 - j] = sum_i a[i] * a[i + j - m].
# For each j in 1..m-1 the two terms cover index gaps j and m - j, which share
# the weight min(j, m - j), so every unordered pair is counted exactly twice.
tt = multiply(a, a[::-1])
cyc = 0
for j in range(1, m):
    cyc += min(j, m - j) * (tt[m - 1 - j] + tt[2 * m - 1 - j]) % MOD
cyc %= MOD
# Halve modulo the odd prime MOD: make the representative even, then divide.
if cyc % 2:
    cyc += MOD
cyc //= 2
print((ans + cyc) % MOD)