結果

問題 No.3451 Same Numbers
コンテスト
ユーザー 👑 potato167
提出日時 2026-01-27 18:55:22
言語 PyPy3
(7.3.17)
結果
AC  
実行時間 1,336 ms / 2,000 ms
コード長 3,492 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 740 ms
コンパイル使用メモリ 82,332 KB
実行使用メモリ 83,828 KB
最終ジャッジ日時 2026-02-20 20:52:47
合計ジャッジ時間 12,483 ms
ジャッジサーバーID
(参考情報)
judge2 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 37
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

import sys

MOD = 998244353


class Binomial:
    def __init__(self, MAX=0):
        self.fact_vec = [1]
        self.fact_inv_vec = [1]
        self.extend(MAX + 1)

    def extend(self, m=-1):
        n = len(self.fact_vec)
        if m == -1:
            m = n * 2
        if n >= m:
            return

        self.fact_vec.extend([0] * (m - n))
        self.fact_inv_vec.extend([0] * (m - n))

        for i in range(n, m):
            self.fact_vec[i] = (self.fact_vec[i - 1] * i) % MOD

        self.fact_inv_vec[m - 1] = pow(self.fact_vec[m - 1], MOD - 2, MOD)
        for i in range(m - 1, n, -1):
            self.fact_inv_vec[i - 1] = (self.fact_inv_vec[i] * i) % MOD

    def fact(self, i: int) -> int:
        if i < 0:
            return 0
        while len(self.fact_vec) <= i:
            self.extend()
        return self.fact_vec[i]

    def invfact(self, i: int) -> int:
        if i < 0:
            return 0
        while len(self.fact_inv_vec) <= i:
            self.extend()
        return self.fact_inv_vec[i]

    def C(self, a: int, b: int) -> int:
        if a < b or b < 0:
            return 0
        return (self.fact(a) * self.invfact(b) % MOD) * self.invfact(a - b) % MOD

    def inv_int(self, a: int) -> int:
        # C++: if (a < 0) return inv(-a) * T(-1);
        if a < 0:
            return (-self.inv_int(-a)) % MOD
        # C++: if (a == 0) return 1;
        if a == 0:
            return 1
        # C++: fact(a - 1) * invfact(a)
        return (self.fact(a - 1) * self.invfact(a)) % MOD


# O(V^1.5)
def solve4(N: int, M: int, E: int):
    table = Binomial()
    res = [1] * N
    pw = [1] * (M + 1)
    for i in range(0, M + 1):
    	pw[i] = pow(i, E, MOD)

    B = 0
    while B * B < M and B < N - 1:
        B += 1

    res[N - 1] = pw[((M - 1) // N + 1)]

    for k in range(0, B):
        inv = table.inv_int(N - k - 1)

        # x 回中 y 回未満を常に管理する
        x = M
        y = 0
        tmp = 0
        tmp2 = pow(table.inv_int(N - k), x, MOD) * pow(N - k - 1, x, MOD) % MOD

        a = 1
        while True:
            A = M - 1 - k * a
            Bv = a
            if A < Bv:
                break

            while A < x:
                # (x - 1, y - 1) -> (x, y)
                tmp = (tmp + table.C(x - 1, y - 1) * tmp2) % MOD
                x -= 1
                tmp2 = (tmp2 * inv) % MOD
                tmp2 = (tmp2 * (N - k)) % MOD

            while y < Bv:
                tmp = (tmp + table.C(A, y) * tmp2) % MOD
                tmp2 = (tmp2 * inv) % MOD
                y += 1

            res[k] = (res[k] + (1 - tmp) * (pw[a + 1] - pw[a])) % MOD
            a += 1

    for k in range(B, N - 1):
        inv = table.inv_int(N - k - 1)

        # N - 1 - k * a 回中 a 回未満
        a = 1
        while True:
            A = M - 1 - k * a
            Bv = a
            if A < Bv:
                break
            sm = 1
            tmp = pow(table.inv_int(N - k), A, MOD) * pow(N - k - 1, A, MOD) % MOD
            for j in range(0, a):
                sm = (sm - table.C(A, j) * tmp) % MOD
                tmp = (tmp * inv) % MOD
            sm = sm * (pw[a + 1] - pw[a]) % MOD
            res[k] = (res[k] + sm) % MOD
            a += 1

    return res


def main():
    data = sys.stdin.buffer.read().split()
    N = int(data[0])
    M = int(data[1])
    E = int(data[2])

    ans = solve4(N, M, E)
    sys.stdout.write("\n".join(map(str, ans)))


if __name__ == "__main__":
    main()
0