結果

問題 No.1552 Simple Dice Game
ユーザー LyricalMaestroLyricalMaestro
提出日時 2025-01-03 01:44:52
言語 Rust
(1.83.0 + proconio)
結果
AC  
実行時間 484 ms / 2,500 ms
コード長 4,380 bytes
コンパイル時間 12,676 ms
コンパイル使用メモリ 397,236 KB
実行使用メモリ 40,544 KB
最終ジャッジ日時 2025-01-03 01:45:18
合計ジャッジ時間 18,748 ms
ジャッジサーバーID
(参考情報)
judge5 / judge4
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 1 ms
5,248 KB
testcase_01 AC 1 ms
5,248 KB
testcase_02 AC 1 ms
5,248 KB
testcase_03 AC 447 ms
40,452 KB
testcase_04 AC 1 ms
5,248 KB
testcase_05 AC 1 ms
5,248 KB
testcase_06 AC 1 ms
5,248 KB
testcase_07 AC 1 ms
5,248 KB
testcase_08 AC 1 ms
5,248 KB
testcase_09 AC 368 ms
21,272 KB
testcase_10 AC 412 ms
40,544 KB
testcase_11 AC 409 ms
40,424 KB
testcase_12 AC 106 ms
11,632 KB
testcase_13 AC 357 ms
21,276 KB
testcase_14 AC 176 ms
11,568 KB
testcase_15 AC 237 ms
21,368 KB
testcase_16 AC 30 ms
5,248 KB
testcase_17 AC 48 ms
6,852 KB
testcase_18 AC 295 ms
21,308 KB
testcase_19 AC 484 ms
40,484 KB
testcase_20 AC 441 ms
40,540 KB
testcase_21 AC 484 ms
40,544 KB
testcase_22 AC 437 ms
40,420 KB
testcase_23 AC 446 ms
40,452 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

use std::collections::HashMap;
use std::io::{self, BufRead};

const MOD: i64 = 998244353;

/// pow_cache を利用して (a^x) mod MOD を計算する関数
fn my_pow(a: i64, x: i64, pow_cache: &mut HashMap<(i64, i64), i64>) -> i64 {
    // キャッシュに登録済みならそれを返す
    if let Some(&val) = pow_cache.get(&(a, x)) {
        return val;
    }

    // a = 0 のときは 0^x = 0 (x>0 の場合)
    if a == 0 {
        pow_cache.insert((a, x), 0);
        return 0;
    }
    
    // それ以外は (a^x) mod MOD を計算してキャッシュに入れる
    let val = mod_pow(a, x, MOD);
    pow_cache.insert((a, x), val);
    val
}

/// 通常の繰り返し二乗法による (base^exp) mod m
fn mod_pow(base: i64, exp: i64, m: i64) -> i64 {
    if exp < 0 {
        return 0; // 今回の問題設定では exp>=0 のはずなので、念のため
    }
    let mut result = 1i64;
    let mut cur = base % m;
    let mut e = exp;
    while e > 0 {
        if e & 1 == 1 {
            result = (result * cur) % m;
        }
        cur = (cur * cur) % m;
        e >>= 1;
    }
    result
}

/// 1 から N までの和を (N*(N+1)/2) mod MOD で返す
/// Pythonコード中の linear_sum(N, pow_cache) 相当
fn linear_sum(n: i64, pow_cache: &mut HashMap<(i64, i64), i64>) -> i64 {
    // ans = N*(N+1)/2 mod MOD
    // 割り算はフェルマーの小定理を用いて (2^(MOD-2)) mod MOD を掛ける
    let ans = (n % MOD) * ((n + 1) % MOD) % MOD;
    // 2^(MOD-2) mod MOD は逆元
    let inv2 = my_pow(2, MOD - 2, pow_cache);
    (ans * inv2) % MOD
}

/// Python の main 関数相当
fn main() {
    // 標準入力の読み取り
    let stdin = io::stdin();
    let mut lines = stdin.lock().lines();
    let line = lines.next().unwrap().unwrap();
    let mut iter = line.split_whitespace();
    let n = iter.next().unwrap().parse::<i64>().unwrap();
    let m = iter.next().unwrap().parse::<i64>().unwrap();

    let mut pow_cache = HashMap::new();

    // ------------------------
    // answer_min の計算
    // ------------------------
    let mut answer_min = 0i64;
    let max_linear_sum = linear_sum(m, &mut pow_cache);

    for mm in 1..=m {
        // diff_linear_sum = (max_linear_sum - linear_sum(m-1)) mod
        let min_linear_sum = linear_sum(mm - 1, &mut pow_cache);
        let diff_linear_sum = mod_sub(max_linear_sum, min_linear_sum);

        // m0 = (M - m + 1)^(N-1) * N
        let mut m0 = (m - mm + 1).max(0); // 念のため負にならないように
        m0 = my_pow(m0, n - 1, &mut pow_cache);
        m0 = (m0 % MOD) * (n % MOD) % MOD;
        let a = (diff_linear_sum * m0) % MOD;

        // diff_linear_sum2 = (max_linear_sum - linear_sum(m)) mod
        let min_linear_sum2 = linear_sum(mm, &mut pow_cache);
        let diff_linear_sum2 = mod_sub(max_linear_sum, min_linear_sum2);

        // b = diff_linear_sum2 * ((M - m)^(N-1) * N) mod
        let mut m1 = (m - mm).max(0);
        m1 = my_pow(m1, n - 1, &mut pow_cache);
        m1 = (m1 % MOD) * (n % MOD) % MOD;
        let b = (diff_linear_sum2 * m1) % MOD;

        let c = mod_sub(a, b);
        let ans = (c * mm) % MOD;
        answer_min = (answer_min + ans) % MOD;
    }

    // ------------------------
    // answer_max の計算
    // ------------------------
    let mut answer_max = 0i64;
    for mm in 1..=m {
        let max_linear_sum_m = linear_sum(mm, &mut pow_cache);
        // m0 = m^(N-1) * N
        let mut m0 = mm;
        m0 = my_pow(m0, n - 1, &mut pow_cache);
        m0 = (m0 % MOD) * (n % MOD) % MOD;
        let a = (max_linear_sum_m * m0) % MOD;

        let max_linear_sum2 = linear_sum(mm - 1, &mut pow_cache);
        let mut m1 = (mm - 1).max(0);
        m1 = my_pow(m1, n - 1, &mut pow_cache);
        m1 = (m1 % MOD) * (n % MOD) % MOD;
        let b = (max_linear_sum2 * m1) % MOD;

        let c = mod_sub(a, b);
        let ans = (c * mm) % MOD;
        answer_max = (answer_max + ans) % MOD;
    }

    // 結果
    let mut answer = mod_sub(answer_max, answer_min);
    // Python同様、マイナスになったら MOD を足す
    if answer < 0 {
        answer += MOD;
    }
    println!("{}", answer);
}

/// (a - b) mod MOD を返す(負にならないように調整)
/// Python の (a - b) % MOD 相当のヘルパー
fn mod_sub(a: i64, b: i64) -> i64 {
    (a - b).rem_euclid(MOD)
}
0