結果

問題 No.1627 三角形の成立
ユーザー lam6er
提出日時 2025-04-09 21:02:54
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 2,190 bytes
コンパイル時間 203 ms
コンパイル使用メモリ 82,568 KB
実行使用メモリ 79,620 KB
最終ジャッジ日時 2025-04-09 21:04:54
合計ジャッジ時間 3,081 ms
ジャッジサーバーID
(参考情報)
judge2 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 6 WA * 16
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 10**9 + 7
import sys
from math import isqrt

def compute_mobius(max_n):
    max_n += 1
    mobius = [1] * max_n
    is_prime = [True] * max_n
    for p in range(2, max_n):
        if is_prime[p]:
            for multiple in range(p, max_n, p):
                is_prime[multiple] = False if multiple != p else is_prime[multiple]
                mobius[multiple] *= -1
            p_square = p * p
            for multiple in range(p_square, max_n, p_square):
                mobius[multiple] = 0
    return mobius

def comb3(x):
    if x < 3:
        return 0
    return x * (x-1) * (x-2) // 6 % MOD

def main():
    n, m = map(int, sys.stdin.readline().split())
    total = n * m
    ans = comb3(total)
    if ans < 0:
        ans += MOD
    ans = (ans - n * comb3(m)) % MOD
    ans = (ans - m * comb3(n)) % MOD
    max_d = min(n-1, m-1)
    if max_d < 1:
        print(ans % MOD)
        return
    
    max_mn = max(n-1, m-1)
    mobius = compute_mobius(max_mn)
    
    diag = 0
    for d in range(1, max_d + 1):
        A = (n - 1) // d
        B = (m - 1) // d
        if A == 0 or B == 0:
            continue
        
        current = 0
        max_g = min(A, B)
        for g in range(1, max_g + 1):
            mu = mobius[g]
            if mu == 0:
                continue
            a = A // g
            b = B // g
            if a == 0 or b == 0:
                continue
            term1 = (a * b) % MOD
            term2 = (g * a * b * (b + 1) // 2) % MOD
            term3 = (g * b * a * (a + 1) // 2) % MOD
            term4 = (g * g * a * (a + 1) // 2 % MOD) * (b * (b + 1) // 2 % MOD) % MOD
            
            mn_term = (n * m) % MOD
            part1 = mn_term * term1 % MOD
            part2 = (d * m % MOD) * term2 % MOD
            part3 = (d * n % MOD) * term3 % MOD
            part4 = (d * d % MOD) * term4 % MOD
            total_term = (part1 - part2 - part3 + part4) * mu % MOD
            current = (current + total_term) % MOD
        
        current = current * 2 * (d - 1) % MOD
        diag = (diag + current) % MOD
    
    ans = (ans - diag) % MOD
    if ans < 0:
        ans += MOD
    print(ans)

if __name__ == "__main__":
    main()
0