結果
| 問題 |
No.1552 Simple Dice Game
|
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2025-01-03 01:44:52 |
| 言語 | Rust (1.83.0 + proconio) |
| 結果 |
AC
|
| 実行時間 | 484 ms / 2,500 ms |
| コード長 | 4,380 bytes |
| コンパイル時間 | 12,676 ms |
| コンパイル使用メモリ | 397,236 KB |
| 実行使用メモリ | 40,544 KB |
| 最終ジャッジ日時 | 2025-01-03 01:45:18 |
| 合計ジャッジ時間 | 18,748 ms |
|
ジャッジサーバーID (参考情報) |
judge5 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 4 |
| other | AC * 20 |
ソースコード
use std::collections::HashMap;
use std::io::{self, BufRead};
const MOD: i64 = 998244353;
/// pow_cache を利用して (a^x) mod MOD を計算する関数
fn my_pow(a: i64, x: i64, pow_cache: &mut HashMap<(i64, i64), i64>) -> i64 {
// キャッシュに登録済みならそれを返す
if let Some(&val) = pow_cache.get(&(a, x)) {
return val;
}
// a = 0 のときは 0^x = 0 (x>0 の場合)
if a == 0 {
pow_cache.insert((a, x), 0);
return 0;
}
// それ以外は (a^x) mod MOD を計算してキャッシュに入れる
let val = mod_pow(a, x, MOD);
pow_cache.insert((a, x), val);
val
}
/// 通常の繰り返し二乗法による (base^exp) mod m
fn mod_pow(base: i64, exp: i64, m: i64) -> i64 {
if exp < 0 {
return 0; // 今回の問題設定では exp>=0 のはずなので、念のため
}
let mut result = 1i64;
let mut cur = base % m;
let mut e = exp;
while e > 0 {
if e & 1 == 1 {
result = (result * cur) % m;
}
cur = (cur * cur) % m;
e >>= 1;
}
result
}
/// 1 から N までの和を (N*(N+1)/2) mod MOD で返す
/// Pythonコード中の linear_sum(N, pow_cache) 相当
fn linear_sum(n: i64, pow_cache: &mut HashMap<(i64, i64), i64>) -> i64 {
// ans = N*(N+1)/2 mod MOD
// 割り算はフェルマーの小定理を用いて (2^(MOD-2)) mod MOD を掛ける
let ans = (n % MOD) * ((n + 1) % MOD) % MOD;
// 2^(MOD-2) mod MOD は逆元
let inv2 = my_pow(2, MOD - 2, pow_cache);
(ans * inv2) % MOD
}
/// Python の main 関数相当
fn main() {
// 標準入力の読み取り
let stdin = io::stdin();
let mut lines = stdin.lock().lines();
let line = lines.next().unwrap().unwrap();
let mut iter = line.split_whitespace();
let n = iter.next().unwrap().parse::<i64>().unwrap();
let m = iter.next().unwrap().parse::<i64>().unwrap();
let mut pow_cache = HashMap::new();
// ------------------------
// answer_min の計算
// ------------------------
let mut answer_min = 0i64;
let max_linear_sum = linear_sum(m, &mut pow_cache);
for mm in 1..=m {
// diff_linear_sum = (max_linear_sum - linear_sum(m-1)) mod
let min_linear_sum = linear_sum(mm - 1, &mut pow_cache);
let diff_linear_sum = mod_sub(max_linear_sum, min_linear_sum);
// m0 = (M - m + 1)^(N-1) * N
let mut m0 = (m - mm + 1).max(0); // 念のため負にならないように
m0 = my_pow(m0, n - 1, &mut pow_cache);
m0 = (m0 % MOD) * (n % MOD) % MOD;
let a = (diff_linear_sum * m0) % MOD;
// diff_linear_sum2 = (max_linear_sum - linear_sum(m)) mod
let min_linear_sum2 = linear_sum(mm, &mut pow_cache);
let diff_linear_sum2 = mod_sub(max_linear_sum, min_linear_sum2);
// b = diff_linear_sum2 * ((M - m)^(N-1) * N) mod
let mut m1 = (m - mm).max(0);
m1 = my_pow(m1, n - 1, &mut pow_cache);
m1 = (m1 % MOD) * (n % MOD) % MOD;
let b = (diff_linear_sum2 * m1) % MOD;
let c = mod_sub(a, b);
let ans = (c * mm) % MOD;
answer_min = (answer_min + ans) % MOD;
}
// ------------------------
// answer_max の計算
// ------------------------
let mut answer_max = 0i64;
for mm in 1..=m {
let max_linear_sum_m = linear_sum(mm, &mut pow_cache);
// m0 = m^(N-1) * N
let mut m0 = mm;
m0 = my_pow(m0, n - 1, &mut pow_cache);
m0 = (m0 % MOD) * (n % MOD) % MOD;
let a = (max_linear_sum_m * m0) % MOD;
let max_linear_sum2 = linear_sum(mm - 1, &mut pow_cache);
let mut m1 = (mm - 1).max(0);
m1 = my_pow(m1, n - 1, &mut pow_cache);
m1 = (m1 % MOD) * (n % MOD) % MOD;
let b = (max_linear_sum2 * m1) % MOD;
let c = mod_sub(a, b);
let ans = (c * mm) % MOD;
answer_max = (answer_max + ans) % MOD;
}
// 結果
let mut answer = mod_sub(answer_max, answer_min);
// Python同様、マイナスになったら MOD を足す
if answer < 0 {
answer += MOD;
}
println!("{}", answer);
}
/// (a - b) mod MOD を返す(負にならないように調整)
/// Python の (a - b) % MOD 相当のヘルパー
fn mod_sub(a: i64, b: i64) -> i64 {
(a - b).rem_euclid(MOD)
}