結果

問題 No.3182 recurrence relation’s intersection sum
ユーザー LyricalMaestro
提出日時 2025-07-12 02:55:42
言語 Rust
(1.83.0 + proconio)
結果
AC  
実行時間 608 ms / 2,000 ms
コード長 3,275 bytes
コンパイル時間 22,374 ms
コンパイル使用メモリ 383,392 KB
実行使用メモリ 7,844 KB
最終ジャッジ日時 2025-07-12 02:56:14
合計ジャッジ時間 32,786 ms
ジャッジサーバーID
(参考情報)
judge3 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 40
権限があれば一括ダウンロードができます

ソースコード

diff #

use std::io::{self, BufRead};

/// Prime modulus (998244353 = 119 * 2^23 + 1); every arithmetic result in this file is reduced by it.
const MOD: u64 = 998_244_353;

/// Lookup tables of factorials and inverse factorials modulo [`MOD`], giving O(1) binomials.
struct CombinationCalculator {
    /// `factorial[i] == i! mod MOD` for `0 <= i <= size`.
    factorial: Vec<u64>,
    /// `inv_factorial[i] == (i!)^-1 mod MOD` for `0 <= i <= size`.
    inv_factorial: Vec<u64>,
}

impl CombinationCalculator {
    /// Precomputes both tables for every `n` in `0..=size`.
    ///
    /// The inverses are only meaningful while `size < MOD` (otherwise `size!` is 0 mod MOD).
    fn new(size: usize) -> Self {
        // Forward pass: running product 0!, 1!, 2!, ...
        let mut factorial = Vec::with_capacity(size + 1);
        let mut running = 1u64;
        factorial.push(running);
        for i in 1..=size as u64 {
            running = running * i % MOD;
            factorial.push(running);
        }

        // Backward pass: one modular inverse at the top, then (i-1)!^-1 = i!^-1 * i.
        let mut inv_factorial = vec![0u64; size + 1];
        let mut acc = modinv(factorial[size]);
        for i in (0..=size).rev() {
            inv_factorial[i] = acc;
            acc = acc * i as u64 % MOD;
        }

        Self {
            factorial,
            inv_factorial,
        }
    }

    /// Binomial coefficient `C(n, r) mod MOD`; 0 when `r > n`.
    fn comb(&self, n: usize, r: usize) -> u64 {
        match n.checked_sub(r) {
            None => 0,
            Some(rest) => {
                let numerator = self.factorial[n];
                numerator * self.inv_factorial[r] % MOD * self.inv_factorial[rest] % MOD
            }
        }
    }
}

/// Multiplicative inverse of `x` modulo the prime [`MOD`], via Fermat's little theorem:
/// `x^(MOD-2) * x ≡ 1`. `x` must not be a multiple of `MOD` (no inverse exists then).
fn modinv(x: u64) -> u64 {
    let fermat_exponent = MOD - 2;
    modpow(x, fermat_exponent)
}

/// Computes `base^exp mod MOD` by square-and-multiply exponentiation.
///
/// `base` is reduced modulo `MOD` first, so any `u64` is accepted. Without that reduction a
/// base of 2^32 or more overflows `base * base`: a panic in debug builds and a silently wrong
/// result in release builds. `modpow(x, 0) == 1` for every `x`.
fn modpow(mut base: u64, mut exp: u64) -> u64 {
    base %= MOD;
    let mut result = 1;
    while exp > 0 {
        if exp & 1 == 1 {
            result = result * base % MOD;
        }
        // Both factors are < MOD here, so the product stays below 2^64.
        base = base * base % MOD;
        exp >>= 1;
    }
    result
}

/// Returns the `x`×`x` matrix product `left * right` modulo [`MOD`].
///
/// Only the top-left `x`×`x` corner of each operand is read; panics if either is smaller.
/// Entries are expected to be already reduced (`< MOD`) so every product fits in `u64`.
fn prod_matrix(x: usize, left: &[Vec<u64>], right: &[Vec<u64>]) -> Vec<Vec<u64>> {
    let right = &right[..x];
    let mut res = vec![vec![0u64; x]; x];
    // i-k-j order: the innermost loop walks `res[i]` and `right[k]` row-wise (cache-friendly),
    // and a zero multiplier `left[i][k]` lets the whole inner row be skipped.
    for (res_row, left_row) in res.iter_mut().zip(&left[..x]) {
        for (&a, right_row) in left_row[..x].iter().zip(right) {
            if a == 0 {
                continue;
            }
            for (cell, &b) in res_row.iter_mut().zip(&right_row[..x]) {
                *cell = (*cell + a * b) % MOD;
            }
        }
    }
    res
}

/// Returns the matrix–vector product `matrix * vec` modulo [`MOD`], using the first `x`
/// rows/columns of `matrix` and the first `x` entries of `vec` (panics if either is shorter).
///
/// Entries are expected to be already reduced (`< MOD`) so every product fits in `u64`.
fn prod_vec(x: usize, matrix: &[Vec<u64>], vec: &[u64]) -> Vec<u64> {
    let vec = &vec[..x];
    matrix[..x]
        .iter()
        .map(|row| {
            row[..x]
                .iter()
                .zip(vec)
                .fold(0, |acc, (&m, &v)| (acc + m * v) % MOD)
        })
        .collect()
}

/// Returns `S(r) mod MOD` for the sequence advanced by the transition matrix built below,
/// starting from `S(0) = 1`.
///
/// Reading the transition off the matrix, with `P(n) = 1^k + 2^k + ... + n^k`:
///
/// `S(n + 1) = k·S(n) + P(n) + (k^(n+1) - 1)/(k - 1) + 1`
///
/// This is what one gets for `S(n) = a_0 + ... + a_n` with `a_0 = 1` and
/// `a_{m+1} = k·a_m + m^k + k^m` (presumably the problem's recurrence — confirm against the
/// statement). Requires `k >= 1` (`k - 1` is computed on `usize`), and `comb` must have been
/// built with size at least `k`.
fn solve(k: usize, mut r: u64, comb: &CombinationCalculator) -> u64 {
    if k == 1 {
        // The matrix needs inv(k - 1) = inv(0), so use the closed form instead. For k = 1 the
        // recurrence gives S(r) = r(r+1)(r+2)/6 + (r + 1). That is an exact integer, so
        // evaluating it with a modular inverse of 6 is sound. Reduce r first: for r >= 2^32 the
        // product r * (r + 1) would overflow u64.
        let rm = r % MOD;
        let tetrahedral = rm * (rm + 1) % MOD * (rm + 2) % MOD * modinv(6) % MOD;
        return (tetrahedral + rm + 1) % MOD;
    }

    // State layout (size k + 5) after n steps:
    //   [0]           S(n)
    //   [1]           P(n)
    //   [2 + t]       n^(k - t) for t = 0..=k   (so [k + 2] holds n^0 = 1)
    //   [k + 3]       k^(n + 1)
    //   [k + 4]       the constant 1
    let l0 = k + 5;
    let geo = k + 3;
    let one = k + 4;
    let mut s = vec![vec![0u64; l0]; l0];

    // Row 0: S(n+1) = k·S(n) + P(n) + inv(k-1)·k^(n+1) + (1 - inv(k-1))·1.
    let inv_km1 = modinv((k - 1) as u64);
    s[0][0] = k as u64;
    s[0][1] = 1;
    s[0][geo] = inv_km1;
    s[0][one] = (1 + MOD - inv_km1) % MOD;

    // Rows 2..=k+2: binomial expansion (n+1)^m = Σ_u C(m, u)·n^(m-u), with m = k - (i - 2).
    for i in 2..=k + 2 {
        for j in i..=k + 2 {
            s[i][j] = comb.comb(k - (i - 2), j - i);
        }
    }

    // Row 1: P(n+1) = P(n) + (n+1)^k, and (n+1)^k is exactly what row 2 produces.
    s[1][1] = 1;
    for j in 2..=k + 2 {
        s[1][j] = s[2][j];
    }

    s[geo][geo] = k as u64;
    s[one][one] = 1;

    // Initial state at n = 0: S = 1, P = 0, 0^k..0^1 = 0, 0^0 = 1, k^1 = k, constant 1.
    let mut state = vec![0u64; l0];
    state[0] = 1;
    state[k + 2] = 1;
    state[geo] = k as u64;
    state[one] = 1;

    // Apply s^r by binary exponentiation. Powers of `s` commute, so the order of application
    // is irrelevant. The loop stops before the last (unused) squaring.
    loop {
        if r & 1 == 1 {
            state = prod_vec(l0, &s, &state);
        }
        r >>= 1;
        if r == 0 {
            break;
        }
        s = prod_matrix(l0, &s, &s);
    }

    state[0]
}

/// Reads `K L R` from stdin and prints `solve(K, R) - solve(K, L - 1)` modulo `MOD`
/// (the second term is taken as 0 when `L == 0`).
fn main() {
    // Read all of stdin so the three values may be on one line or split across several.
    let mut input = String::new();
    let mut stdin = io::stdin().lock();
    while stdin.read_line(&mut input).expect("failed to read stdin") > 0 {}

    // Parse L and R straight into u64: going through usize would fail for large values on
    // 32-bit targets.
    let mut tokens = input.split_ascii_whitespace();
    let k: usize = tokens
        .next()
        .expect("missing K")
        .parse()
        .expect("K must be a non-negative integer");
    let l: u64 = tokens
        .next()
        .expect("missing L")
        .parse()
        .expect("L must be a non-negative integer");
    let r: u64 = tokens
        .next()
        .expect("missing R")
        .parse()
        .expect("R must be a non-negative integer");

    // solve only queries C(n, _) with n <= k.
    let comb = CombinationCalculator::new(k);

    let upper = solve(k, r, &comb);
    let lower = match l.checked_sub(1) {
        Some(prev) => solve(k, prev, &comb),
        None => 0,
    };

    // Both values are already reduced, so adding MOD keeps the subtraction non-negative.
    println!("{}", (upper + MOD - lower) % MOD);
}
0