結果

問題 No.1195 数え上げを愛したい(文字列編)
ユーザー Navier_BoltzmannNavier_Boltzmann
提出日時 2023-11-17 10:11:41
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 1,903 ms / 3,000 ms
コード長 19,080 bytes
コンパイル時間 283 ms
コンパイル使用メモリ 81,828 KB
実行使用メモリ 261,752 KB
最終ジャッジ日時 2023-11-17 10:12:12
合計ジャッジ時間 30,089 ms
ジャッジサーバーID
(参考情報)
judge11 / judge15
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 1,903 ms
253,460 KB
testcase_01 AC 1,867 ms
252,976 KB
testcase_02 AC 1,854 ms
253,184 KB
testcase_03 AC 336 ms
99,316 KB
testcase_04 AC 409 ms
104,632 KB
testcase_05 AC 355 ms
102,504 KB
testcase_06 AC 49 ms
59,848 KB
testcase_07 AC 49 ms
59,848 KB
testcase_08 AC 404 ms
104,224 KB
testcase_09 AC 1,832 ms
261,752 KB
testcase_10 AC 1,046 ms
183,988 KB
testcase_11 AC 1,670 ms
242,996 KB
testcase_12 AC 1,566 ms
237,580 KB
testcase_13 AC 1,345 ms
210,872 KB
testcase_14 AC 924 ms
160,324 KB
testcase_15 AC 1,044 ms
182,036 KB
testcase_16 AC 913 ms
162,756 KB
testcase_17 AC 449 ms
108,756 KB
testcase_18 AC 1,577 ms
235,692 KB
testcase_19 AC 1,567 ms
237,820 KB
testcase_20 AC 1,346 ms
211,572 KB
testcase_21 AC 1,668 ms
245,100 KB
testcase_22 AC 1,274 ms
203,004 KB
testcase_23 AC 51 ms
59,848 KB
testcase_24 AC 51 ms
59,848 KB
testcase_25 AC 52 ms
61,952 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

from collections import *
from itertools import *
from functools import *
from heapq import *
import sys,math
input = sys.stdin.readline

#https://judge.yosupo.jp/submission/84100
#Convolution_998244353

MOD = 998244353
IMAG = 911660635
IIMAG = 86583718
rate2 = (911660635, 509520358, 369330050, 332049552, 983190778, 123842337, 238493703, 975955924, 603855026, 856644456, 131300601,
              842657263, 730768835, 942482514, 806263778, 151565301, 510815449, 503497456, 743006876, 741047443, 56250497, 867605899)
irate2 = (86583718, 372528824, 373294451, 645684063, 112220581, 692852209, 155456985, 797128860, 90816748, 860285882, 927414960,
               354738543, 109331171, 293255632, 535113200, 308540755, 121186627, 608385704, 438932459, 359477183, 824071951, 103369235)
rate3 = (372528824, 337190230, 454590761, 816400692, 578227951, 180142363, 83780245, 6597683, 70046822, 623238099,
              183021267, 402682409, 631680428, 344509872, 689220186, 365017329, 774342554, 729444058, 102986190, 128751033, 395565204)
irate3 = (509520358, 929031873, 170256584, 839780419, 282974284, 395914482, 444904435, 72135471, 638914820, 66769500,
               771127074, 985925487, 262319669, 262341272, 625870173, 768022760, 859816005, 914661783, 430819711, 272774365, 530924681)

def _butterfly_(a):
    n = len(a)
    h = (n - 1).bit_length()
    len_ = 0
    while len_ < h:
        if h - len_ == 1:
            p = 1 << (h - len_ - 1)
            rot = 1
            for s in range(1 << len_):
                offset = s << (h - len_)
                for i in range(p):
                    l = a[i + offset]
                    r = a[i + offset + p] * rot % MOD
                    a[i + offset] = (l + r) % MOD
                    a[i + offset + p] = (l - r) % MOD
                if s + 1 != 1 << len_:
                    rot *= rate2[(~s & -~s).bit_length() - 1]
                    rot %= MOD
            len_ += 1
        else:
            p = 1 << (h - len_ - 2)
            rot = 1
            for s in range(1 << len_):
                rot2 = rot * rot % MOD
                rot3 = rot2 * rot % MOD
                offset = s << (h - len_)
                for i in range(p):
                    a0 = a[i + offset]
                    a1 = a[i + offset + p] * rot
                    a2 = a[i + offset + p * 2] * rot2
                    a3 = a[i + offset + p * 3] * rot3
                    a1na3imag = (a1 - a3) % MOD * IMAG
                    a[i + offset] = (a0 + a2 + a1 + a3) % MOD
                    a[i + offset + p] = (a0 + a2 - a1 - a3) % MOD
                    a[i + offset + p * 2] = (a0 - a2 + a1na3imag) % MOD
                    a[i + offset + p * 3] = (a0 - a2 - a1na3imag) % MOD
                if s + 1 != 1 << len_:
                    rot *= rate3[(~s & -~s).bit_length() - 1]
                    rot %= MOD
            len_ += 2

def _butterfly_inv_(a):
    n = len(a)
    h = (n - 1).bit_length()
    len_ = h
    while len_:
        if len_ == 1:
            p = 1 << (h - len_)
            irot = 1
            for s in range(1 << (len_ - 1)):
                offset = s << (h - len_ + 1)
                for i in range(p):
                    l = a[i + offset]
                    r = a[i + offset + p]
                    a[i + offset] = (l + r) % MOD
                    a[i + offset + p] = (l - r) * irot % MOD
                if s + 1 != (1 << (len_ - 1)):
                    irot *= irate2[(~s & -~s).bit_length() - 1]
                    irot %= MOD
            len_ -= 1
        else:
            p = 1 << (h - len_)
            irot = 1
            for s in range(1 << (len_ - 2)):
                irot2 = irot * irot % MOD
                irot3 = irot2 * irot % MOD
                offset = s << (h - len_ + 2)
                for i in range(p):
                    a0 = a[i + offset]
                    a1 = a[i + offset + p]
                    a2 = a[i + offset + p * 2]
                    a3 = a[i + offset + p * 3]
                    a2na3iimag = (a2 - a3) * IIMAG % MOD
                    a[i + offset] = (a0 + a1 + a2 + a3) % MOD
                    a[i + offset + p] = (a0 - a1 + a2na3iimag) * irot % MOD
                    a[i + offset + p * 2] = (a0 + a1 - a2 - a3) * irot2 % MOD
                    a[i + offset + p * 3] = (a0 - a1 - a2na3iimag) * irot3 % MOD
                if s + 1 != (1 << (len_ - 2)):
                    irot *= irate3[(~s & -~s).bit_length() - 1]
                    irot %= MOD
            len_ -= 2
    inv = pow(n, MOD - 2, MOD)
    for i in range(n):
        a[i] *= inv
        a[i] %= MOD

def build_exp(n, b):
    exp = [0] * (n + 1)
    exp[0] = 1
    for i in range(n):
        exp[i + 1] = exp[i] * b % MOD
    return exp

def build_factorial(n):
    fct = [0] * (n + 1)
    inv = [0] * (n + 1)
    fct[0] = inv[0] = 1
    for i in range(n):
        fct[i + 1] = fct[i] * (i + 1) % MOD
    inv[n] = pow(fct[n], MOD - 2, MOD)
    for i in range(n)[::-1]:
        inv[i] = inv[i + 1] * (i + 1) % MOD
    return fct, inv

def sqrt_mod(n):
    if n == 0: return 0
    if n == 1: return 1
    h = (MOD - 1) // 2
    if pow(n, h, MOD) != 1: return -1
    q, s = MOD - 1, 0
    while not q & 1:
        q >>= 1
        s += 1
    z = 1
    while pow(z, h, MOD) != MOD - 1:
        z += 1
    m, c, t, r = s, pow(z, q, MOD), pow(n, q, MOD), pow(n, (q + 1) // 2, MOD)
    while t != 1:
        k = 1
        while pow(t, 1 << k, MOD) != 1:
            k += 1
        x = pow(c, pow(2, m - k - 1, MOD - 1), MOD)
        m = k
        c = (x * x) % MOD
        t = (t * c) % MOD
        r = (r * x) % MOD
    if r * r % MOD != n: return -1
    return r

class FormalPowerSeries():
    def __init__(self, arr=None):
        if arr is None: arr = []
        self.arr = [v % MOD for v in arr]

    def __len__(self):
        return len(self.arr)

    def __getitem__(self, key):
        if isinstance(key, slice):
            return FormalPowerSeries(self.arr[key])
        else:
            assert key >= 0
            if key >= len(self):
                return 0
            else:
                return self.arr[key]

    def __setitem__(self, key, val):
        assert key >= 0
        if key >= len(self):
            self.arr += [0] * (key - len(self) + 1)
        self.arr[key] = val % MOD

    def __str__(self):
        return ' '.join(map(str, self.arr))

    def resize(self, sz):
        assert sz >= 0
        if len(self) >= sz:
            return self[:sz]
        else:
            return FormalPowerSeries(self.arr + [0] * (sz - len(self)))

    def shrink(self):
        while self.arr and not self.arr[-1]:
            self.arr.pop()

    def times(self, k):
        if k:
            return FormalPowerSeries([v * k for v in self.arr])
        else:
            return FormalPowerSeries([])

    def __pos__(self):
        return self

    def __neg__(self):
        return self.times(-1)

    def __add__(self, other):
        if other.__class__ == FormalPowerSeries:
            n = len(self)
            m = len(other)
            arr = [self[i] + other[i] for i in range(min(n, m))]
            if n >= m:
                arr += self.arr[m:]
            else:
                arr += other.arr[n:]
            return FormalPowerSeries(arr)
        else:
            return self + FormalPowerSeries([other])

    def __iadd__(self, other):
        if other.__class__ == FormalPowerSeries:
            n = len(self)
            m = len(other)
            for i in range(min(n, m)):
                self.arr[i] += other[i]
                self.arr[i] %= MOD
            if n < m:
                self.arr += other.arr[n:]
        else:
            self.arr[0] += other
            self.arr[0] %= MOD
        return self

    def __radd__(self, other):
        return self + other

    def __sub__(self, other):
        return self + (-other)

    def __isub__(self, other):
        self += -other
        return self

    def __rsub__(self, other):
        return (-self) + other

    def __mul__(self, other):
        if other.__class__ == FormalPowerSeries:
            f = self.arr.copy()
            g = other.arr.copy()
            n = len(f)
            m = len(g)
            if not n or not m: return FormalPowerSeries()
            if min(n, m) <= 50:
                if n < m: f, n, g, m = g, m, f, n
                arr = [0] * (n + m - 1)
                for i in range(n):
                    for j in range(m):
                        arr[i + j] += f[i] * g[j]
                        arr[i + j] %= MOD
                return FormalPowerSeries(arr)
            z = 1 << (n + m - 2).bit_length()
            f += [0] * (z - n)
            g += [0] * (z - m)
            _butterfly_(f)
            _butterfly_(g)
            for i in range(z):
                f[i] *= g[i]
                f[i] %= MOD
            _butterfly_inv_(f)
            f = f[:n + m - 1]
            return FormalPowerSeries(f)
        else:
            return self.times(other)

    def __imul__(self, other):
        if other.__class__ == FormalPowerSeries:
            f = self.arr.copy()
            g = other.arr.copy()
            n = len(f)
            m = len(g)
            if not n or not m: return FormalPowerSeries()
            if min(n, m) <= 50:
                if n < m: f, n, g, m = g, m, f, n
                arr = [0] * (n + m - 1)
                for i in range(n):
                    for j in range(m):
                        arr[i + j] += f[i] * g[j]
                        arr[i + j] %= MOD
                self.arr = arr
                return self
            z = 1 << (n + m - 2).bit_length()
            f += [0] * (z - n)
            g += [0] * (z - m)
            _butterfly_(f)
            _butterfly_(g)
            for i in range(z):
                f[i] *= g[i]
                f[i] %= MOD
            _butterfly_inv_(f)
            self.arr = f[:n + m - 1]
            return self
        else:
            n = len(self)
            for i in range(n):
                self.arr[i] *= other
                self.arr[i] %= MOD
            return self

    def __rmul__(self, other):
        return self.times(other)

    def __pow__(self, k): #exp書いたら修正
        n = len(self)
        tmp = FormalPowerSeries(self.arr)
        res = FormalPowerSeries([1])
        while k:
            if k & 1:
                res *= tmp
                res = res.resize(n)
            tmp *= tmp
            tmp = tmp.resize(n)
            k >>= 1
        return res

    def square(self):
        f = self.arr.copy()
        n = len(f)
        if not n: return FormalPowerSeries()
        if n <= 50:
            arr = [0] * (2 * n - 1)
            for i in range(n):
                for j in range(n):
                    arr[i + j] += f[i] * f[j]
                    arr[i + j] %= MOD
            return FormalPowerSeries(arr)
        z = 1 << (2 * n - 2).bit_length()
        f += [0] * (z - n)
        _butterfly_(f)
        for i in range(z):
            f[i] *= f[i]
            f[i] %= MOD
        _butterfly_inv_(f)
        f = f[:2 * n - 1]
        return FormalPowerSeries(f)

    def __lshift__(self, key):
        assert key >= 0
        return FormalPowerSeries([0] * key + self.arr)

    def __rshift__(self, key):
        assert key >= 0
        return self[key:]

    def __invert__(self):
        assert self[0] != 0
        n = len(self)
        r = pow(self[0], MOD - 2, MOD)
        m = 1
        res = FormalPowerSeries([r])
        while m < n:
            f = [0] * (2 * m)
            g = [0] * (2 * m)
            for i in range(2 * m):
                f[i] = self[i]
            for i in range(m):
                g[i] = res[i]
            _butterfly_(f)
            _butterfly_(g)
            for i in range(2 * m):
                f[i] *= g[i]
                f[i] %= MOD
            _butterfly_inv_(f)
            for i in range(m):
                f[i] = 0
            _butterfly_(f)
            for i in range(2 * m):
                f[i] *= g[i]
                f[i] %= MOD
            _butterfly_inv_(f)
            for i in range(m, 2 * m):
                res[i] -= f[i]
            m <<= 1
        return res.resize(n)

    def __truediv__(self, other):
        if other.__class__ == FormalPowerSeries:
            n = max(len(self), len(other))
            return (self * ~other).resize(n)
        else:
            return self * pow(other, MOD - 2, MOD)

    def __rtruediv__(self, other):
        return other * ~self

    def differentiate(self):
        n = len(self)
        arr = [0] * n
        for i in range(1, n):
            arr[i - 1] = self[i] * i % MOD
        return FormalPowerSeries(arr)

    def integrate(self):
        n = len(self)
        arr = [0] * n
        inv = [1] * n
        for i in range(2, n):
            inv[i] = MOD - inv[MOD % i] * (MOD // i) % MOD
        for i in range(n - 1):
            arr[i + 1] = self[i] * inv[i + 1] % MOD
        return FormalPowerSeries(arr)

    def log(self):
        assert self[0] == 1
        n = len(self)
        return (self.differentiate() / self).integrate()

    def exp(self):
        assert self[0] == 0
        n = len(self)
        res = FormalPowerSeries([1])
        g = FormalPowerSeries([1])
        q = self.differentiate()
        m = 1
        while m < n:
            g = g * 2 - res * g.square().resize(m)
            res = res.resize(2 * m)
            m *= 2
            w = q.resize(m) + (g * (res.differentiate() - (res * q.resize(m)).resize(m))).resize(m)
            res = res + (res * (self.resize(m) - w.integrate())).resize(m)
        return res.resize(n)

    def __floordiv__(self, other):
        if other.__class__ == FormalPowerSeries:
            n = len(self)
            m = len(other)
            if n < m: return FormalPowerSeries([])
            l = n - m + 1
            if m <= 100:
                arr = [0] * l
                inv = pow(other[m - 1], MOD - 2, MOD)
                tmp = self[::-1]
                for i in range(l):
                    arr[i] = tmp[i] * inv % MOD
                    for j in range(m):
                        tmp[i + j] -= other[m - j - 1] * arr[i]
                        tmp[i + j] %= MOD
                return FormalPowerSeries(arr[::-1])
            res = (self[~l:][::-1] * ~(other[::-1].resize(l))).resize(l)[::-1]
            return res
        else:
            return self * pow(other, MOD - 2, MOD)

    def __rfloordiv__(self, other):
        return other * ~self

    def __mod__(self, other):
        if other.__class__ == FormalPowerSeries:
            n = len(self)
            m = len(other)
            if n < m: return FormalPowerSeries(self.arr)
            res = self[:m - 1] - ((self // other) * other)[:m - 1]
            res.shrink()
            return res
        else:
            return 0

    def divmod(self, other):
        if other.__class__ == FormalPowerSeries:
            div = self // other
            n = len(self)
            m = len(other)
            if n < m:
                mod = FormalPowerSeries(self.arr)
            else:
                mod = self[:m - 1] - ((self // other) * other)[:m - 1]
                mod.shrink()
        else:
            div = self // other
            mod = 0
        return div, mod

    def __matmul__(self, other):
        assert self.__class__ == other.__class__ == FormalPowerSeries
        assert other[0] == 0
        assert len(self) == len(other)
        n = len(self)
        #fをkブロックに分割する。dはブロック内の要素数。k >= dになるように。
        k = int((n - 1)**0.5 + 1)
        d = (n + k - 1) // k
        powg = [FormalPowerSeries([1])]
        for i in range(k):
            powg.append((powg[i] * other).resize(n))
        fi = [FormalPowerSeries([0] * n) for _ in range(k)]
        for i in range(k):
            for j in range(d):
                if i * d + j >= n: break
                for t in range(n):
                    if t >= len(powg[j]): break
                    fi[i][t] += powg[j][t] * self[i * d + j]
                    fi[i][t] %= MOD
        res = FormalPowerSeries([0] * n)
        gd = FormalPowerSeries([1])
        for i in range(k):
            fi[i] *= gd
            fi[i] = fi[i].resize(n)
            res += fi[i]
            gd *= powg[d]
            gd = gd.resize(n)
        return res

    def multipoint_evaluation(self, xs):
        n = len(xs)
        sz = 1 << (n - 1).bit_length()
        g = [FormalPowerSeries([1]) for _ in range(2 * sz)]
        for i in range(n):
            g[i + sz] = FormalPowerSeries([-xs[i], 1])
        for i in range(1, sz)[::-1]:
            g[i] = g[2 * i] * g[2 * i + 1]
        g[1] = self % g[1]
        for i in range(2, 2 * sz):
            g[i] = g[i >> 1] % g[i]
        res = [g[i + sz][0] for i in range(n)]
        return res

def polynomial_interpolation(xs, ys):
    assert len(xs) == len(ys)
    n = len(xs)
    sz = 1 << (n - 1).bit_length()
    f = [FormalPowerSeries([1]) for _ in range(2 * sz)]
    for i in range(n):
        f[i + sz] = FormalPowerSeries([-xs[i], 1])
    for i in range(1, sz)[::-1]:
        f[i] = f[2 * i] * f[2 * i + 1]
    g = [FormalPowerSeries([0])] * (2 * sz)
    g[1] = f[1].differentiate() % f[1]
    for i in range(2, n + sz):
        g[i] = g[i >> 1] % f[i]
    for i in range(n):
        g[i + sz] = FormalPowerSeries([ys[i] * pow(g[i + sz][0], MOD - 2, MOD) % MOD])
    for i in range(1, sz)[::-1]:
        g[i] = g[2 * i] * f[2 * i + 1] + g[2 * i + 1] * f[2 * i]
    return g[1][:n]

def berlekamp_massey(arr):
    if arr.__class__ == FormalPowerSeries:
        arr = arr.arr
    n = len(arr)
    b = [1]
    c = [1]
    l, m, p = 0, 0, 1
    for i in range(n):
        m += 1
        d = arr[i]
        for j in range(1, l + 1):
            d += c[j] * arr[i - j]
            d %= MOD
        if d == 0: continue
        t = c.copy()
        q = d * pow(p, MOD - 2, MOD) % MOD
        if len(c) < len(b) + m:
            c += [0] * (len(b) + m - len(c))
        for j in range(len(b)):
            c[j + m] -= q * b[j]
            c[j + m] %= MOD
        if 2 * l <= i:
            b = t
            l, m, p = i + 1 - l, 0, d
    return c

def linear_recurrence(arr, coeff, k):
    if arr.__class__ == FormalPowerSeries:
        arr = arr.arr
    d = len(arr)
    f = FormalPowerSeries(arr)
    q = FormalPowerSeries(coeff)
    p = (f * q).resize(d)
    while k:
        r = [-q[i] if i & 1 else q[i] for i in range(len(q))] + [0] * (d + 1 - len(q))
        r = FormalPowerSeries(r)
        p *= r
        q *= r
        p = p[(k & 1)::2]
        q = q[::2]
        k >>= 1
    return p[0] % MOD
    
mod = 998244353
S = list(input())[:-1]

C = Counter(S)
F = FormalPowerSeries([1])
finv = [1]*(len(S)+1)
f = [1]*(len(S) + 1)
for i in range(2,len(S)+1):
    finv[i] = finv[i-1]*pow(i,mod-2,mod)%mod
    f[i] = f[i-1]*i%mod
for v in C.values():
    
    G = [finv[i] for i in range(v+1)]
    F = F*FormalPowerSeries(G)

ans = 0
for i in range(1,len(S)+1):
    
    ans += F[i]*f[i]
    ans %= mod
print(ans)
    
    
0