結果

問題 No.1321 塗るめた
ユーザー vwxyzvwxyz
提出日時 2023-04-27 02:12:20
言語 PyPy3 (7.3.15)
結果 AC
実行時間 1,501 ms / 2,000 ms
コード長 16,204 bytes
コンパイル時間 192 ms
コンパイル使用メモリ 81,792 KB
実行使用メモリ 151,572 KB
最終ジャッジ日時 2024-04-28 01:46:47
合計ジャッジ時間 31,741 ms
ジャッジサーバーID (参考情報) judge3 / judge4
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 38 ms 55,936 KB
testcase_01 AC 145 ms 78,940 KB
testcase_02 AC 38 ms 56,704 KB
testcase_03 AC 36 ms 56,576 KB
testcase_04 AC 36 ms 56,192 KB
testcase_05 AC 35 ms 56,704 KB
testcase_06 AC 34 ms 56,448 KB
testcase_07 AC 39 ms 57,472 KB
testcase_08 AC 47 ms 66,432 KB
testcase_09 AC 56 ms 70,784 KB
testcase_10 AC 39 ms 55,808 KB
testcase_11 AC 58 ms 72,192 KB
testcase_12 AC 745 ms 113,620 KB
testcase_13 AC 266 ms 85,908 KB
testcase_14 AC 423 ms 92,932 KB
testcase_15 AC 767 ms 112,784 KB
testcase_16 AC 801 ms 115,388 KB
testcase_17 AC 127 ms 78,352 KB
testcase_18 AC 1,491 ms 148,528 KB
testcase_19 AC 445 ms 94,768 KB
testcase_20 AC 755 ms 110,892 KB
testcase_21 AC 788 ms 116,240 KB
testcase_22 AC 1,438 ms 150,100 KB
testcase_23 AC 1,455 ms 148,268 KB
testcase_24 AC 1,439 ms 147,604 KB
testcase_25 AC 1,463 ms 149,200 KB
testcase_26 AC 1,501 ms 151,572 KB
testcase_27 AC 1,391 ms 145,652 KB
testcase_28 AC 1,386 ms 145,992 KB
testcase_29 AC 1,388 ms 146,364 KB
testcase_30 AC 1,498 ms 151,220 KB
testcase_31 AC 764 ms 116,200 KB
testcase_32 AC 737 ms 112,208 KB
testcase_33 AC 749 ms 110,096 KB
testcase_34 AC 802 ms 115,728 KB
testcase_35 AC 757 ms 112,236 KB
testcase_36 AC 53 ms 72,320 KB
testcase_37 AC 775 ms 115,616 KB
testcase_38 AC 792 ms 116,012 KB
testcase_39 AC 782 ms 115,752 KB
testcase_40 AC 776 ms 116,136 KB
testcase_41 AC 776 ms 115,872 KB
testcase_42 AC 36 ms 55,680 KB
testcase_43 AC 734 ms 112,076 KB
testcase_44 AC 741 ms 110,100 KB
testcase_45 AC 796 ms 115,468 KB
testcase_46 AC 272 ms 84,416 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import math
import sys
# Bind the stdin line reader once so input parsing is a plain local-name call.
readline = sys.stdin.readline

# Prime modulus used by the top-level computation (FPS and MOD carry their own).
mod = 998244353
class FPS:
    """Formal power series with coefficients modulo 998244353.

    ``Func[i]`` is the coefficient of ``x**i``.  Products use schoolbook
    multiplication when either operand has at most 40 terms and an NTT
    (``butterfly`` / ``butterfly_inv``) otherwise.  Arithmetic/shift
    operators and ``inv``/``log``/``exp``/``sqrt``/``powfps`` build new
    ``FPS`` objects and leave their operands' coefficient lists untouched.
    """
    # Multipliers that advance the per-block root ``now`` in ``butterfly``.
    sum_e = (
    911660635, 509520358, 369330050, 332049552, 983190778, 123842337, 238493703, 975955924, 603855026, 856644456,
    131300601, 842657263, 730768835, 942482514, 806263778, 151565301, 510815449, 503497456, 743006876, 741047443,
    56250497)
    # Matching multipliers that advance ``inow`` in ``butterfly_inv``.
    sum_ie = (
    86583718, 372528824, 373294451, 645684063, 112220581, 692852209, 155456985, 797128860, 90816748, 860285882,
    927414960, 354738543, 109331171, 293255632, 535113200, 308540755, 121186627, 608385704, 438932459, 359477183,
    824071951)
    mod = 998244353
    # Class-level placeholder only; every instance replaces it in __init__.
    Func = [0]

    def __init__(self, L):
        """Store the coefficients of ``L`` reduced modulo ``mod``."""
        self.Func = [x % self.mod for x in L]

    def butterfly(self, a):
        """Forward NTT of ``a`` in place; returns ``a``.

        ``len(a)`` is expected to be a power of two (callers zero-pad).
        """
        n = len(a)
        h = (n - 1).bit_length()
        for ph in range(1, h + 1):
            w = 1 << (ph - 1)
            p = 1 << (h - ph)
            now = 1
            for s in range(w):
                offset = s << (h - ph + 1)
                for i in range(p):
                    l = a[i + offset]
                    r = a[i + offset + p] * now
                    r %= self.mod
                    a[i + offset] = l + r
                    a[i + offset] %= self.mod
                    a[i + offset + p] = l - r
                    a[i + offset + p] %= self.mod
                now *= self.sum_e[(~s & -~s).bit_length() - 1]
                now %= self.mod
        return a

    def butterfly_inv(self, a):
        """Inverse of ``butterfly`` in place, WITHOUT the 1/len(a) scaling.

        Callers multiply by the inverse of the length themselves.
        """
        n = len(a)
        h = (n - 1).bit_length()
        for ph in range(h, 0, -1):
            w = 1 << (ph - 1)
            p = 1 << (h - ph)
            inow = 1
            for s in range(w):
                offset = s << (h - ph + 1)
                for i in range(p):
                    l = a[i + offset]
                    r = a[i + offset + p]
                    a[i + offset] = l + r
                    a[i + offset] %= self.mod
                    a[i + offset + p] = (l - r) * inow
                    a[i + offset + p] %= self.mod
                inow *= self.sum_ie[(~s & -~s).bit_length() - 1]
                inow %= self.mod
        return a

    def __mul__(self, other):
        """Return ``self * other`` where ``other`` is an int scalar or an FPS.

        The series product has ``len(self) + len(other) - 1`` terms (empty if
        either operand is empty).
        """
        if isinstance(other, int):
            return FPS([(x * other) % self.mod for x in self.Func])
        a = self.Func
        b = other.Func
        n = len(a)
        m = len(b)
        if not a or not b:
            return FPS([])
        if min(n, m) <= 40:
            # Short operand: schoolbook product beats the NTT set-up cost.
            if n < m:
                n, m = m, n
                a, b = b, a
            res = [0] * (n + m - 1)
            for i in range(n):
                for j in range(m):
                    res[i + j] += a[i] * b[j]
                    res[i + j] %= self.mod
            return FPS(res)
        # Pad both operands to a power of two that holds the full product.
        z = 1 << ((n + m - 2).bit_length())
        a = a + [0] * (z - n)
        b = b + [0] * (z - m)
        a = self.butterfly(a)
        b = self.butterfly(b)
        c = [0] * z
        for i in range(z):
            c[i] = (a[i] * b[i]) % self.mod
        self.butterfly_inv(c)
        iz = pow(z, self.mod - 2, self.mod)
        for i in range(n + m - 1):
            c[i] = (c[i] * iz) % self.mod
        return FPS(c[:n + m - 1])

    def __imul__(self, other):
        self = self * other
        return self

    def __add__(self, other):
        """Coefficient-wise sum; the result is as long as the longer operand."""
        res = [0 for i in range(max(len(self.Func), len(other.Func)))]
        for i, x in enumerate(self.Func):
            res[i] += x
            res[i] %= self.mod
        for i, x in enumerate(other.Func):
            res[i] += x
            res[i] %= self.mod
        return FPS(res)

    def __iadd__(self, other):
        self = (self + other)
        return self

    def __sub__(self, other):
        """Coefficient-wise difference; the result is as long as the longer operand."""
        res = [0 for i in range(max(len(self.Func), len(other.Func)))]
        for i, x in enumerate(self.Func):
            res[i] += x
            res[i] %= self.mod
        for i, x in enumerate(other.Func):
            res[i] -= x
            res[i] %= self.mod
        return FPS(res)

    def __isub__(self, other):
        self = self - other
        return self

    def inv(self, d=-1):
        """Return ``1 / self`` modulo ``x**d`` (default ``len(self)``).

        Uses Newton iteration, doubling the number of known terms each round.
        Requires a non-zero constant term.
        """
        n = len(self.Func)
        assert n != 0 and self.Func[0] != 0
        if d == -1: d = n
        assert d > 0
        res = [pow(self.Func[0], self.mod - 2, self.mod)]
        while (len(res) < d):
            m = len(res)
            f = [self.Func[i] for i in range(min(n, 2 * m))]
            r = res[:]

            if len(f) < 2 * m:
                f += [0] * (2 * m - len(f))
            elif len(f) > 2 * m:
                f = f[:2 * m]
            if len(r) < 2 * m:
                r += [0] * (2 * m - len(r))
            elif len(r) > 2 * m:
                r = r[:2 * m]
            f = self.butterfly(f)
            r = self.butterfly(r)
            for i in range(2 * m):
                f[i] *= r[i]
                f[i] %= self.mod
            f = self.butterfly_inv(f)
            f = f[m:]
            if len(f) < 2 * m:
                f += [0] * (2 * m - len(f))
            elif len(f) > 2 * m:
                f = f[:2 * m]
            f = self.butterfly(f)
            for i in range(2 * m):
                f[i] *= r[i]
                f[i] %= self.mod
            f = self.butterfly_inv(f)
            # -1/(2m)^2: undoes both unscaled inverse transforms and negates.
            iz = pow(2 * m, self.mod - 2, self.mod)
            iz *= -iz
            iz %= self.mod
            for i in range(m):
                f[i] *= iz
                f[i] %= self.mod
            res += f[:m]
        return FPS(res[:d])

    def __truediv__(self, other):
        """Divide by an int scalar, or multiply by the series inverse of an FPS."""
        if isinstance(other, int):
            invother = pow(other, self.mod - 2, self.mod)
            ret = [(x * invother) % self.mod for x in self.Func]
            return FPS(ret)
        assert (other.Func[0] != 0)
        return self * (other.inv())

    def __itruediv__(self, other):
        self = self / other
        return self

    def __lshift__(self, d):
        """Return ``x**d * self`` truncated to ``len(self)`` terms.

        ``self`` is not modified.
        """
        n = len(self.Func)
        return FPS(([0] * d + self.Func)[:n])

    def __ilshift__(self, d):
        self = self << d
        return self

    def __rshift__(self, d):
        """Drop the lowest ``d`` coefficients and zero-pad back to ``len(self)``.

        ``self`` is not modified.
        """
        n = len(self.Func)
        shifted = self.Func[min(n, d):]
        shifted += [0] * (n - len(shifted))
        return FPS(shifted)

    def __irshift__(self, d):
        self = self >> d
        return self

    def __str__(self):
        return f'FPS({self.Func})'

    def diff(self):
        """Return the formal derivative (one coefficient shorter, never negative length)."""
        n = len(self.Func)
        ret = [0 for i in range(max(0, n - 1))]
        for i in range(1, n):
            ret[i - 1] = (self.Func[i] * i) % self.mod
        return FPS(ret)

    def integral(self):
        """Return the antiderivative with zero constant term (one coefficient longer)."""
        n = len(self.Func)
        ret = [0 for i in range(n + 1)]
        for i in range(n):
            ret[i + 1] = self.Func[i] * pow(i + 1, self.mod - 2, self.mod) % self.mod
        return FPS(ret)

    def log(self, deg=-1):
        """Return ``log(self)`` modulo ``x**deg`` (default ``len(self)``).

        Requires constant term 1.  Computed as the integral of ``self' / self``.
        """
        assert self.Func[0] == 1
        n = len(self.Func)
        if deg == -1: deg = n
        if deg <= 0:
            return FPS([])
        # Only the first deg-1 terms of f'/f are exact and needed; truncating
        # keeps callers such as exp() from processing meaningless high terms.
        quotient = (self.diff() * self.inv(deg)).resize(deg - 1)
        return quotient.integral()

    def mod_sqrt(self, a):
        """Return a square root of ``a`` modulo ``mod`` (Tonelli-Shanks), or -1 if none exists."""
        p = self.mod
        assert 0 <= a and a < p
        if a < 2: return a
        # Euler's criterion: a is a quadratic residue iff a^((p-1)/2) == 1.
        if pow(a, (p - 1) // 2, p) != 1: return -1
        b = 1
        one = 1
        # Find a quadratic non-residue b.
        while (pow(b, (p - 1) >> 1, p) == 1):
            b += one
        # Write p - 1 = m * 2^e with m odd.
        m = p - 1
        e = 0
        while (m % 2 == 0):
            m >>= 1
            e += 1
        x = pow(a, (m - 1) >> 1, p)
        y = (a * x * x) % p
        x *= a
        x %= p
        z = pow(b, m, p)
        while (y != 1):
            j = 0
            t = y
            while (t != one):
                j += 1
                t *= t
                t %= p
            z = pow(z, 1 << (e - j - 1), p)
            x *= z
            x %= p
            z *= z
            z %= p
            y *= z
            y %= p
            e = j
        return x

    def sqrt(self, deg=-1):
        """Return a square root of ``self`` modulo ``x**deg`` (default ``len(self)``).

        Returns ``FPS([])`` when no square root exists.  ``self`` is not
        modified.
        """
        n = len(self.Func)
        if deg == -1: deg = n
        if n == 0: return FPS([0 for i in range(deg)])
        if self.Func[0] == 0:
            # Factor out the lowest non-zero term x^i; i must be even.
            for i in range(1, n):
                if self.Func[i] != 0:
                    if i & 1: return FPS([])
                    if deg - i // 2 <= 0: break
                    ret = (self >> i).sqrt(deg - i // 2)
                    if len(ret.Func) == 0: return FPS([])
                    ret = ret << (i // 2)
                    if len(ret.Func) < deg:
                        ret.Func += [0] * (deg - len(ret.Func))
                    return ret
            return FPS([0] * deg)
        sqr = self.mod_sqrt(self.Func[0])
        if sqr == -1: return FPS([])
        assert sqr * sqr % self.mod == self.Func[0]
        ret = FPS([sqr])
        inv2 = (self.mod + 1) // 2
        i = 1
        # Newton iteration: ret <- (ret + self / ret) / 2, doubling precision.
        while (i < deg):
            ret = (ret + FPS(self.Func[:i << 1]) * ret.inv(i << 1)) * inv2
            i <<= 1
        return FPS(ret.Func[:deg])

    def resize(self, deg):
        """Return a copy truncated or zero-padded to ``deg`` terms (``self`` if already that long)."""
        if len(self.Func) < deg:
            return FPS(self.Func + [0] * (deg - len(self.Func)))
        elif len(self.Func) > deg:
            return FPS(self.Func[:deg])
        else:
            return self

    def exp(self, deg=-1):
        """Return ``exp(self)`` modulo ``x**deg`` (default ``len(self)``).

        Requires a zero constant term.  Works on a private copy, so ``self``
        is not modified, and ``deg`` may exceed ``len(self)``.
        """
        n = len(self.Func)
        assert n > 0 and self.Func[0] == 0
        if deg == -1: deg = n
        assert deg >= 0
        if deg == 0:
            return FPS([])
        # F[:m] holds exp(self) mod x^m; F[m:] still holds the input
        # coefficients that later rounds consume.  exp(h) = 1 + h1*x + O(x^2),
        # so F[:2] is already correct once F[0] is set to 1.
        F = self.Func[:deg] + [0] * (deg - n)
        F[0] = 1
        g = [1]
        g_fft = [1, 1]
        h_drv = FPS(F).diff()
        m = 2
        while (m < deg):
            f_fft = F[:m] + [0] * m
            self.butterfly(f_fft)

            # step 2.a
            _g = [f_fft[i] * g_fft[i] % self.mod for i in range(m)]
            self.butterfly_inv(_g)
            _g = _g[m // 2:m] + [0] * (m // 2)
            self.butterfly(_g)
            for i in range(m):
                _g[i] *= g_fft[i]
                _g[i] %= self.mod
            self.butterfly_inv(_g)
            tmp = pow(-m * m, self.mod - 2, self.mod)
            for i in range(m):
                _g[i] *= tmp
                _g[i] %= self.mod
            g += _g[:m // 2]
            # step 2.b--2.d
            t = FPS(F[:m]).diff()
            r = h_drv.Func[:m - 1] + [0]
            self.butterfly(r)
            for i in range(m):
                r[i] *= f_fft[i]
                r[i] %= self.mod
            self.butterfly_inv(r)
            tmp = pow(-m, self.mod - 2, self.mod)
            for i in range(m):
                r[i] *= tmp
                r[i] %= self.mod
            t = (t + FPS(r)).Func
            t = [t[-1]] + t
            t.pop()
            # step 2.e
            if (2 * m < deg):
                if len(t) < 2 * m:
                    t += [0] * (2 * m - len(t))
                elif len(t) > 2 * m:
                    t = t[:2 * m]
                self.butterfly(t)
                g_fft = g[:]
                if len(g_fft) < 2 * m:
                    g_fft += [0] * (2 * m - len(g_fft))
                elif len(g_fft) > 2 * m:
                    g_fft = g_fft[:2 * m]
                self.butterfly(g_fft)
                for i in range(2 * m):
                    t[i] *= g_fft[i]
                    t[i] %= self.mod
                self.butterfly_inv(t)
                tmp = pow(2 * m, self.mod - 2, self.mod)
                t = t[:m]
                for i in range(m):
                    t[i] *= tmp
                    t[i] %= self.mod
            else:
                g1 = g[m // 2:]
                s1 = t[m // 2:]
                t = t[:m // 2]
                g1 += [0] * (m - len(g1))
                s1 += [0] * (m - len(s1))
                t += [0] * (m - len(t))

                self.butterfly(g1)
                self.butterfly(t)
                self.butterfly(s1)
                for i in range(m):
                    s1[i] = (g_fft[i] * s1[i] + g1[i] * t[i]) % self.mod
                for i in range(m):
                    t[i] *= g_fft[i]
                    t[i] %= self.mod
                self.butterfly_inv(t)
                self.butterfly_inv(s1)
                for i in range(m // 2):
                    t[i + m // 2] += s1[i]
                    t[i + m // 2] %= self.mod
                tmp = pow(m, self.mod - 2, self.mod)
                for i in range(m):
                    t[i] *= tmp
                    t[i] %= self.mod
            # step 2.f
            v = F[m:min(deg, 2 * m)] + [0] * (2 * m - min(deg, 2 * m))
            t = [0] * (m - 1) + t
            t = FPS(t).integral().Func
            for i in range(m):
                v[i] -= t[m + i]
                v[i] %= self.mod
            # step 2.g
            if len(v) < 2 * m:
                v += [0] * (2 * m - len(v))
            else:
                v = v[:2 * m]
            self.butterfly(v)
            for i in range(2 * m):
                v[i] *= f_fft[i]
                v[i] %= self.mod
            self.butterfly_inv(v)
            v = v[:m]
            tmp = pow(2 * m, self.mod - 2, self.mod)
            for i in range(m):
                v[i] *= tmp
                v[i] %= self.mod
            # step 2.h
            for i in range(min(deg - m, m)):
                F[m + i] = v[i]
            m *= 2
        return FPS(F)

    def powfps(self, k, deg=-1):
        """Return ``self**k`` modulo ``x**deg`` (default ``len(self)``).

        ``k`` is a non-negative integer (``0**0`` is taken as 1).  The series
        is written as ``c * x**l * (1 + ...)`` and the unit part is raised to
        the k-th power as ``exp(k * log(...))``.
        """
        if deg == -1: deg = len(self.Func)
        if deg <= 0:
            return FPS([])
        if k == 0:
            return FPS([1] + [0] * (deg - 1))
        a = self.Func[:deg]
        l = 0
        while (l < len(a) and not a[l]):
            l += 1
        if l == len(a) or l * k >= deg:
            return FPS([0] * deg)
        # Terms remaining after the leading factor x**(l*k).
        rest = deg - l * k
        ic = pow(a[l], self.mod - 2, self.mod)
        pc = pow(a[l], k, self.mod)
        b = [(x * ic) % self.mod for x in a[l:l + rest]]
        b += [0] * (rest - len(b))
        b = FPS(b).log()
        b *= k
        b = b.exp()
        b *= pc
        return FPS([0] * (l * k) + b.Func[:rest])

def Extended_Euclid(n, m):
    """Return ``(x, y)`` with ``n*x + m*y == g``, where ``g`` is the last
    non-zero remainder of the Euclidean algorithm on ``(n, m)``.

    Iterative, so deep remainder chains cannot hit the recursion limit.
    """
    quotients = []
    a, b = n, m
    while b:
        quotients.append(a // b)
        a, b = b, a % b
    # Base case: a*x + 0*y == a, with the sign of x matching a.
    x, y = (1, 0) if a >= 0 else (-1, 0)
    # Unwind: from (b, a % b) coefficients, rebuild (a, b) coefficients.
    for q in reversed(quotients):
        x, y = y, x - q * y
    return x, y

class MOD:
    """Factorial and binomial helper modulo ``p`` or, when ``e`` is given, ``p**e``.

    With ``e`` set, ``Build_Fact`` stores factorials with every factor of ``p``
    removed and records the removed exponent in ``cnt``; ``Fact`` and ``Comb``
    multiply the p-part back in.
    """

    def __init__(self, p, e=None):
        self.p = p
        self.e = e
        if self.e is None:
            self.mod = self.p
        else:
            self.mod = self.p ** self.e

    def Pow(self, a, n):
        """Return ``a**n`` modulo ``self.mod``; negative ``n`` inverts ``a`` first."""
        a %= self.mod
        if n >= 0:
            return pow(a, n, self.mod)
        assert math.gcd(a, self.mod) == 1
        x = Extended_Euclid(a, self.mod)[0]
        return pow(x, -n, self.mod)

    def Build_Fact(self, N):
        """Precompute factorials and inverse factorials for 0..N."""
        assert N >= 0
        self.factorial = [1]
        if self.e is None:
            for i in range(1, N + 1):
                self.factorial.append(self.factorial[-1] * i % self.mod)
        else:
            # cnt[i] = exponent of p in i!; factorial[i] keeps only the p-free part.
            self.cnt = [0] * (N + 1)
            for i in range(1, N + 1):
                self.cnt[i] = self.cnt[i - 1]
                ii = i
                while ii % self.p == 0:
                    ii //= self.p
                    self.cnt[i] += 1
                self.factorial.append(self.factorial[-1] * ii % self.mod)
        self.factorial_inve = [None] * (N + 1)
        self.factorial_inve[-1] = self.Pow(self.factorial[-1], -1)
        # Walk down: 1/i! = 1/(i+1)! * (p-free part of i+1).
        for i in range(N - 1, -1, -1):
            ii = i + 1
            while ii % self.p == 0:
                ii //= self.p
            self.factorial_inve[i] = (self.factorial_inve[i + 1] * ii) % self.mod

    def Fact(self, N):
        """Return ``N!`` modulo ``self.mod`` (0 for negative ``N``)."""
        if N < 0:
            return 0
        retu = self.factorial[N]
        if self.e is not None and self.cnt[N]:
            retu *= pow(self.p, self.cnt[N], self.mod) % self.mod
            retu %= self.mod
        return retu

    def Fact_Inve(self, N):
        """Return ``1 / N!`` modulo ``self.mod``.

        Returns 0 for negative ``N`` (consistent with ``Fact``) instead of
        silently indexing the table from its end, and ``None`` when ``N!`` is
        divisible by ``p`` in prime-power mode (no inverse exists).
        """
        if N < 0:
            return 0
        if self.e is not None and self.cnt[N]:
            return None
        return self.factorial_inve[N]

    def Comb(self, N, K, divisible_count=False):
        """Return ``C(N, K)`` modulo ``self.mod`` (0 outside ``0 <= K <= N``).

        In prime-power mode with ``divisible_count=True``, return the p-free
        part and the exponent of ``p`` separately as a tuple.
        """
        if K < 0 or K > N:
            return 0
        retu = self.factorial[N] * self.factorial_inve[K] % self.mod * self.factorial_inve[N - K] % self.mod
        if self.e is not None:
            cnt = self.cnt[N] - self.cnt[N - K] - self.cnt[K]
            if divisible_count:
                return retu, cnt
            else:
                retu *= pow(self.p, cnt, self.mod)
                retu %= self.mod
        return retu

N, M, K = map(int, readline().split())
mod = 998244353
MD = MOD(mod)
MD.Build_Fact(N)
if K > N:
    # The sum over n in [K, N] is empty; also the factorial tables (built
    # only up to N) cannot be indexed at K, so answer 0 directly.
    print(0)
else:
    # P(x) = sum_{i>=0} x^i / (i+1)!  =  (e^x - 1) / x
    P = FPS([MD.Fact_Inve(i + 1) for i in range(N - K + 1)])
    # P^K = ((e^x - 1) / x)^K, so P^K[n - K] = [x^n] (e^x - 1)^K.
    P = P.powfps(K, N - K + 1)
    ans = 0
    power = 1  # M^(N-n) modulo mod, built up as n runs down from N
    for n in range(N, K - 1, -1):
        ans += P.Func[n - K] * power % mod * MD.Fact_Inve(N - n) % mod
        power = power * M % mod
    # C(M, K) as M(M-1)...(M-K+1) / K!: M may exceed N, the factorial table
    # size, so MD.Comb(M, K) could index past the end of the table.
    comb = MD.Fact_Inve(K)
    for i in range(K):
        comb = comb * ((M - i) % mod) % mod
    ans *= comb * MD.Fact(N) % mod
    ans %= mod
    print(ans)
0