結果

問題 No.1552 Simple Dice Game
ユーザー LyricalMaestro
提出日時 2025-01-03 01:44:52
言語 Rust
(1.83.0 + proconio)
結果
AC  
実行時間 484 ms / 2,500 ms
コード長 4,380 bytes
コンパイル時間 12,676 ms
コンパイル使用メモリ 397,236 KB
実行使用メモリ 40,544 KB
最終ジャッジ日時 2025-01-03 01:45:18
合計ジャッジ時間 18,748 ms
ジャッジサーバーID
(参考情報)
judge5 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 20
権限があれば一括ダウンロードができます

ソースコード

diff #

use std::collections::HashMap;
use std::io::{self, BufRead};

const MOD: i64 = 998244353;

/// pow_cache を利用して (a^x) mod MOD を計算する関数
fn my_pow(a: i64, x: i64, pow_cache: &mut HashMap<(i64, i64), i64>) -> i64 {
    // キャッシュに登録済みならそれを返す
    if let Some(&val) = pow_cache.get(&(a, x)) {
        return val;
    }

    // a = 0 のときは 0^x = 0 (x>0 の場合)
    if a == 0 {
        pow_cache.insert((a, x), 0);
        return 0;
    }
    
    // それ以外は (a^x) mod MOD を計算してキャッシュに入れる
    let val = mod_pow(a, x, MOD);
    pow_cache.insert((a, x), val);
    val
}

/// 通常の繰り返し二乗法による (base^exp) mod m
fn mod_pow(base: i64, exp: i64, m: i64) -> i64 {
    if exp < 0 {
        return 0; // 今回の問題設定では exp>=0 のはずなので、念のため
    }
    let mut result = 1i64;
    let mut cur = base % m;
    let mut e = exp;
    while e > 0 {
        if e & 1 == 1 {
            result = (result * cur) % m;
        }
        cur = (cur * cur) % m;
        e >>= 1;
    }
    result
}

/// 1 から N までの和を (N*(N+1)/2) mod MOD で返す
/// Pythonコード中の linear_sum(N, pow_cache) 相当
fn linear_sum(n: i64, pow_cache: &mut HashMap<(i64, i64), i64>) -> i64 {
    // ans = N*(N+1)/2 mod MOD
    // 割り算はフェルマーの小定理を用いて (2^(MOD-2)) mod MOD を掛ける
    let ans = (n % MOD) * ((n + 1) % MOD) % MOD;
    // 2^(MOD-2) mod MOD は逆元
    let inv2 = my_pow(2, MOD - 2, pow_cache);
    (ans * inv2) % MOD
}

/// Python の main 関数相当
fn main() {
    // 標準入力の読み取り
    let stdin = io::stdin();
    let mut lines = stdin.lock().lines();
    let line = lines.next().unwrap().unwrap();
    let mut iter = line.split_whitespace();
    let n = iter.next().unwrap().parse::<i64>().unwrap();
    let m = iter.next().unwrap().parse::<i64>().unwrap();

    let mut pow_cache = HashMap::new();

    // ------------------------
    // answer_min の計算
    // ------------------------
    let mut answer_min = 0i64;
    let max_linear_sum = linear_sum(m, &mut pow_cache);

    for mm in 1..=m {
        // diff_linear_sum = (max_linear_sum - linear_sum(m-1)) mod
        let min_linear_sum = linear_sum(mm - 1, &mut pow_cache);
        let diff_linear_sum = mod_sub(max_linear_sum, min_linear_sum);

        // m0 = (M - m + 1)^(N-1) * N
        let mut m0 = (m - mm + 1).max(0); // 念のため負にならないように
        m0 = my_pow(m0, n - 1, &mut pow_cache);
        m0 = (m0 % MOD) * (n % MOD) % MOD;
        let a = (diff_linear_sum * m0) % MOD;

        // diff_linear_sum2 = (max_linear_sum - linear_sum(m)) mod
        let min_linear_sum2 = linear_sum(mm, &mut pow_cache);
        let diff_linear_sum2 = mod_sub(max_linear_sum, min_linear_sum2);

        // b = diff_linear_sum2 * ((M - m)^(N-1) * N) mod
        let mut m1 = (m - mm).max(0);
        m1 = my_pow(m1, n - 1, &mut pow_cache);
        m1 = (m1 % MOD) * (n % MOD) % MOD;
        let b = (diff_linear_sum2 * m1) % MOD;

        let c = mod_sub(a, b);
        let ans = (c * mm) % MOD;
        answer_min = (answer_min + ans) % MOD;
    }

    // ------------------------
    // answer_max の計算
    // ------------------------
    let mut answer_max = 0i64;
    for mm in 1..=m {
        let max_linear_sum_m = linear_sum(mm, &mut pow_cache);
        // m0 = m^(N-1) * N
        let mut m0 = mm;
        m0 = my_pow(m0, n - 1, &mut pow_cache);
        m0 = (m0 % MOD) * (n % MOD) % MOD;
        let a = (max_linear_sum_m * m0) % MOD;

        let max_linear_sum2 = linear_sum(mm - 1, &mut pow_cache);
        let mut m1 = (mm - 1).max(0);
        m1 = my_pow(m1, n - 1, &mut pow_cache);
        m1 = (m1 % MOD) * (n % MOD) % MOD;
        let b = (max_linear_sum2 * m1) % MOD;

        let c = mod_sub(a, b);
        let ans = (c * mm) % MOD;
        answer_max = (answer_max + ans) % MOD;
    }

    // 結果
    let mut answer = mod_sub(answer_max, answer_min);
    // Python同様、マイナスになったら MOD を足す
    if answer < 0 {
        answer += MOD;
    }
    println!("{}", answer);
}

/// (a - b) mod MOD を返す(負にならないように調整)
/// Python の (a - b) % MOD 相当のヘルパー
fn mod_sub(a: i64, b: i64) -> i64 {
    (a - b).rem_euclid(MOD)
}
0