結果
問題 | No.3118 Increment or Multiply |
ユーザー |
|
提出日時 | 2025-04-24 14:24:26 |
言語 | Python3 (3.13.1 + numpy 2.2.1 + scipy 1.14.1) |
結果 |
WA
|
実行時間 | - |
コード長 | 1,259 bytes |
コンパイル時間 | 398 ms |
コンパイル使用メモリ | 12,032 KB |
実行使用メモリ | 10,496 KB |
最終ジャッジ日時 | 2025-04-24 14:24:33 |
合計ジャッジ時間 | 6,509 ms |
ジャッジサーバーID (参考情報) |
judge3 / judge1 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 1 |
other | AC * 5 WA * 30 |
ソースコード
#!/usr/bin/env python3 import sys input = sys.stdin.readline MOD = 998244353 INV2 = (MOD + 1) // 2 def solve_case(N: int, A: int) -> int: # A=1 は別処理 if A == 1: return (N % MOD) * ((N-1) % MOD) % MOD * INV2 % MOD N_mod = N % MOD N2 = N_mod * N_mod % MOD total_m = 0 # Σ m * Δ_m total_k = 0 # Σ A^m * S_m Um = N power = 1 # A^m mod MOD m = 0 A_mod = A % MOD # U_{m+1}=0 になるまで while Um > 0: Um_mod = Um % MOD Um1 = Um // A Um1_mod = Um1 % MOD # Δ_m = U_m - U_{m+1} delta = (Um_mod - Um1_mod) % MOD total_m = (total_m + m * delta) % MOD # S_m = (U_m(U_m+1) - U_{m+1}(U_{m+1}+1)) / 2 sum_interval = ( (Um_mod * (Um_mod + 1) - Um1_mod * (Um1_mod + 1)) % MOD * INV2 ) % MOD total_k = (total_k + power * sum_interval) % MOD # 次へ Um = Um1 power = power * A_mod % MOD m += 1 # S = N^2 + Σ(m·Δ_m) - Σ(A^m·S_m) return (N2 + total_m - total_k) % MOD def main(): T = int(input()) for _ in range(T): n, a = map(int, input().split()) print(solve_case(n, a)) if __name__ == "__main__": main()