結果
| 問題 | 
                            No.2127 Mod, Sum, Sum, Mod
                             | 
                    
| コンテスト | |
| ユーザー | 
                             lam6er
                         | 
                    
| 提出日時 | 2025-03-20 18:52:38 | 
| 言語 | PyPy3  (7.3.15)  | 
                    
| 結果 | 
                             
                                AC
                                 
                             
                            
                         | 
                    
| 実行時間 | 61 ms / 2,000 ms | 
| コード長 | 1,445 bytes | 
| コンパイル時間 | 171 ms | 
| コンパイル使用メモリ | 82,456 KB | 
| 実行使用メモリ | 75,816 KB | 
| 最終ジャッジ日時 | 2025-03-20 18:53:38 | 
| 合計ジャッジ時間 | 2,454 ms | 
| 
                            ジャッジサーバーID (参考情報)  | 
                        judge3 / judge2 | 
(要ログイン)
| ファイルパターン | 結果 | 
|---|---|
| sample | AC * 3 | 
| other | AC * 27 | 
ソースコード
MOD = 998244353
n, m = map(int, input().split())
# Compute part1: M * sum(i=1..n) i
part1 = (m % MOD) * (n % MOD) % MOD
part1 = part1 * ((n + 1) % MOD) % MOD
part1 = part1 * pow(2, MOD-2, MOD) % MOD  # Division by 2 using modular inverse
k = min(m, n)
part2 = 0
L = 1
while L <= k:
    q = n // L
    if q == 0:
        break
    R = n // q
    R = min(R, k)
    if L > R:
        L += 1
        continue
    
    # Compute c1 and c2 with mod
    c1 = (- (q * q + q) // 2) % MOD
    c2 = (q * (n + 1)) % MOD
    
    # Calculate sum of j^2 from L to R mod MOD
    def sum_j_sq(a, b):
        sum_b = b * (b + 1) % (6 * MOD)
        sum_b = sum_b * (2 * b + 1) % (6 * MOD)
        sum_b = sum_b // 6
        
        sum_a_minus1 = (a - 1) * a % (6 * MOD)
        sum_a_minus1 = sum_a_minus1 * (2 * (a - 1) + 1) % (6 * MOD)
        sum_a_minus1 = sum_a_minus1 // 6
        
        return (sum_b - sum_a_minus1) % MOD
    
    sum_j_sq_val = sum_j_sq(L, R)
    
    # Calculate sum of j from L to R mod MOD
    def sum_j(a, b):
        sum_b = b * (b + 1) // 2
        sum_a_minus1 = (a - 1) * a // 2
        return (sum_b - sum_a_minus1) % MOD
    
    sum_j_val = sum_j(L, R)
    
    # Calculate contribution
    contribution = (c1 * sum_j_sq_val) % MOD
    contribution = (contribution + (c2 * sum_j_val)) % MOD
    part2 = (part2 + contribution) % MOD
    
    L = R + 1
ans = (part1 - part2) % MOD
print(ans if ans >= 0 else ans + MOD)
            
            
            
        
            
lam6er