結果

問題 No.3182 recurrence relation’s intersection sum
ユーザー sgfc
提出日時 2025-06-13 23:06:55
言語 PyPy3 (7.3.15)
結果
TLE  
実行時間 -
コード長 2,533 bytes
コンパイル時間 407 ms
コンパイル使用メモリ 82,304 KB
実行使用メモリ 265,816 KB
最終ジャッジ日時 2025-06-13 23:07:35
合計ジャッジ時間 36,869 ms
ジャッジサーバーID (参考情報) judge4 / judge5
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 37 TLE * 3
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 998244353


class ModInt:
    """Integer modulo MOD supporting the arithmetic operators this solver uses.

    The stored residue ``v`` is always in ``[0, MOD)``. Operands of the binary
    operators must themselves be ModInt instances.
    """

    __slots__ = ("v",)

    def __init__(self, x):
        # Python's % with a positive modulus already returns a value in
        # [0, MOD) for negative x, so no separate sign branch is needed.
        self.v = x % MOD

    def __add__(self, other):
        return ModInt(self.v + other.v)

    def __sub__(self, other):
        return ModInt(self.v - other.v)

    def __mul__(self, other):
        return ModInt(self.v * other.v)

    def __pow__(self, power):
        return ModInt(pow(self.v, power, MOD))

    def __truediv__(self, other):
        return self * other.inv()

    def inv(self):
        """Return the multiplicative inverse via Fermat's little theorem.

        Relies on MOD (998244353) being prime.

        Raises:
            ZeroDivisionError: if the residue is 0, which has no inverse.
        """
        if self.v == 0:
            raise ZeroDivisionError("ModInt(0) has no multiplicative inverse")
        return ModInt(pow(self.v, MOD - 2, MOD))

    def __int__(self):
        return self.v

    def __repr__(self):
        return str(self.v)

def setup_comb(max_n):
    """Build factorial and inverse-factorial tables modulo MOD.

    Returns:
        A pair ``(fac, finv)`` of ModInt lists of length ``max_n + 1`` (empty
        if that is not positive), where ``fac[i] = i!`` and
        ``finv[i] = (i!)^-1``.

    Inverses of 2..max_n use the recurrence
    ``inv[i] = (MOD - MOD // i) * inv[MOD % i]``, which needs only entries
    already computed because ``MOD % i < i``.
    """
    length = max_n + 1
    # 0! = 1! = 1 and inv(1) = 1; slicing by min() keeps tiny sizes right.
    seed = [ModInt(1)] * min(length, 2)
    fac, finv, inv = list(seed), list(seed), list(seed)
    for i in range(2, length):
        inv_i = ModInt(MOD - MOD // i) * inv[MOD % i]
        inv.append(inv_i)
        fac.append(fac[-1] * ModInt(i))
        finv.append(finv[-1] * inv_i)
    return fac, finv

def binom(n, r, fac, finv):
    """Return C(n, r) from precomputed factorial tables.

    ``fac`` and ``finv`` are the tables produced by ``setup_comb``; the
    result is ``ModInt(0)`` whenever ``r`` lies outside ``[0, n]``.
    """
    if 0 <= r <= n:
        return fac[n] * finv[r] * finv[n - r]
    return ModInt(0)

def matmul(a, b):
    """Return the matrix product ``a @ b`` modulo MOD.

    Both operands are lists of rows of ModInt, and the result has the same
    form. The multiply-add loop runs on the plain int residues (``.v``), so it
    does not allocate a ModInt per step. Each partial sum stays below
    MOD + (MOD - 1)**2 < 2**63 before reduction, so it also stays a
    machine-sized int under PyPy. Results are wrapped back into ModInt once
    per output entry.
    """
    width = len(b[0])
    b_vals = [[x.v for x in row] for row in b]
    res = []
    for a_row in a:
        acc = [0] * width
        for coef, b_row in zip((x.v for x in a_row), b_vals):
            if coef == 0:
                continue  # zero coefficient contributes nothing
            acc = [(s + coef * y) % MOD for s, y in zip(acc, b_row)]
        res.append([ModInt(s) for s in acc])
    return res

def matpow(mat, power):
    """Return ``mat ** power`` modulo MOD by binary exponentiation.

    Args:
        mat: square matrix as a list of rows of ModInt.
        power: non-negative integer exponent; 0 yields the identity matrix.

    Raises:
        ValueError: if ``power`` is negative (``-1 >> 1 == -1``, so the bit
            loop would otherwise never terminate).
    """
    if power < 0:
        raise ValueError("matpow requires a non-negative exponent")
    size = len(mat)
    res = [[ModInt(1 if i == j else 0) for j in range(size)] for i in range(size)]
    while power:
        if power & 1:
            res = matmul(res, mat)
        power >>= 1
        if power:
            # Skip the squaring after the top bit: its result would be unused.
            mat = matmul(mat, mat)
    return res

def _row_times_matrix(row, mat):
    """Return the row vector ``row * mat`` modulo MOD as a list of ints.

    ``row`` holds plain ints in [0, MOD); ``mat`` is a list of rows of ModInt.
    Costs O(len(row) * width) instead of a full matrix product.
    """
    out = [0] * len(mat[0])
    for coef, mat_row in zip(row, mat):
        if coef == 0:
            continue
        for j, x in enumerate(mat_row):
            out[j] = (out[j] + coef * x.v) % MOD
    return out


def solve(k, l, r):
    """Print the answer for one query ``(k, l, r)``.

    Builds a (k+4)x(k+4) transition matrix M. Let S(n) be the sum of columns
    0, k+1 and k+2 of row k+3 of M**n. The function prints
    ``(S(r + 1) - S(l)) mod MOD``.

    Reading the rows, the state appears to be
    ``[a_n, n^k, n^(k-1), ..., n^0, k^n, running sum of a]``:
    - row 0 gives ``a' = k*a + n^k + k^n``;
    - rows 1..k+1 are binomial expansions of ``(n+1)^e``;
    - row k+2 multiplies by k;
    - row k+3 accumulates ``a``.
    This is inferred from the matrix; confirm it against the problem statement.

    Raises:
        ValueError: if ``l < 0`` or ``r + 1 < 0`` (negative matrix exponents).
    """
    if l < 0 or r + 1 < 0:
        raise ValueError("exponents l and r + 1 must be non-negative")

    fac, finv = setup_comb(k + 10)
    size = k + 4
    m = [[ModInt(0) for _ in range(size)] for _ in range(size)]

    m[0][0] = ModInt(k)
    m[0][1] = ModInt(1)
    m[0][k + 2] = ModInt(1)
    for i in range(1, k + 2):
        top = k - (i - 1)
        for j in range(i, k + 2):
            m[i][j] = binom(top, k + 1 - j, fac, finv)
    m[k + 2][k + 2] = ModInt(k)
    m[k + 3][0] = ModInt(1)
    m[k + 3][k + 3] = ModInt(1)

    # Only row k+3 of M**l and M**(r+1) is needed. Walk the bits of both
    # exponents together so the squarings M, M^2, M^4, ... are computed once
    # and shared. Apply each selected power to a row vector (O(size^2))
    # rather than to a full matrix. Powers of M commute, so the order of the
    # factors does not matter.
    target = k + 3
    row_lo = [0] * size
    row_lo[target] = 1
    row_hi = list(row_lo)
    exp_lo, exp_hi = l, r + 1
    power = m  # M ** (2 ** bit) for the bit currently examined
    while exp_lo or exp_hi:
        if exp_lo & 1:
            row_lo = _row_times_matrix(row_lo, power)
        if exp_hi & 1:
            row_hi = _row_times_matrix(row_hi, power)
        exp_lo >>= 1
        exp_hi >>= 1
        if exp_lo or exp_hi:
            power = matmul(power, power)

    cols = (0, k + 1, k + 2)
    ans = (sum(row_hi[c] for c in cols) - sum(row_lo[c] for c in cols)) % MOD
    print(ans)

if __name__ == "__main__":
    # Input: a single line "k l r".
    k, l, r = map(int, input().split())
    solve(k, l, r)