結果

問題 No.1318 ABCD quadruplets
ユーザー qwewe
提出日時 2025-05-14 13:20:23
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 4,126 bytes
コンパイル時間 261 ms
コンパイル使用メモリ 82,788 KB
実行使用メモリ 101,560 KB
最終ジャッジ日時 2025-05-14 13:21:21
合計ジャッジ時間 8,234 ms
ジャッジサーバーID
(参考情報)
judge5 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 13 TLE * 1 -- * 16
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

# Rebind the name `input` to sys.stdin.readline so each read goes straight to
# the buffered stdin stream. Unlike the builtin, readline keeps the trailing
# newline; callers here only use .split(), which discards it.
input = sys.stdin.readline 

def _count_quadruplets(n_max, m_max):
    """Return ``[f(0), f(1), ..., f(n_max)]``.

    ``f(n)`` is the number of ordered quadruplets ``(a, b, c, d)`` with
    ``0 <= a, b, c, d <= m_max`` whose value ``(S*S + Q) / 2`` equals ``n``,
    where ``S = a + b + c + d`` and ``Q = a*a + b*b + c*c + d*d``.

    The quadruplet is split into two ordered pairs.  Every pair is reduced to
    its class ``(s, q) = (a + b, a*a + b*b)``; for two classes the doubled
    value is ``(s1 + s2)**2 + q1 + q2 = own1 + own2 + 2*s1*s2`` with
    ``own = s*s + q``.  All terms are non-negative, which justifies the
    pruning steps below.
    """
    from math import isqrt

    limit = 2 * n_max  # the doubled value must not exceed this

    # A coordinate x contributes at least x*x to the value, so coordinates
    # above isqrt(n_max) can never appear in a counted quadruplet.
    top = min(m_max, isqrt(max(n_max, 0)))

    # pair_counts[(s, q)] = number of ordered pairs (a, b) in that class.
    pair_counts = {}
    for a in range(top + 1):
        for b in range(top + 1):
            s = a + b
            q = a * a + b * b
            # s*s + q grows strictly with b, and adding a second pair never
            # lowers the value, so larger b is hopeless for this a.
            if s * s + q > limit:
                break
            key = (s, q)
            pair_counts[key] = pair_counts.get(key, 0) + 1

    # Sort classes by their own contribution so the inner loop can stop at
    # the first partner that is already too large on its own.
    classes = sorted(
        (s * s + q, s, cnt) for (s, q), cnt in pair_counts.items()
    )
    size = len(classes)

    ans = [0] * (n_max + 1)
    for i in range(size):
        own1, s1, cnt1 = classes[i]

        # Both pairs drawn from the same class.
        doubled = 2 * own1 + 2 * s1 * s1
        if doubled <= limit:
            # S*S + Q is always even, so the halving is exact.
            ans[doubled // 2] += cnt1 * cnt1

        room = limit - own1
        twice_s1 = 2 * s1
        for j in range(i + 1, size):
            own2, s2, cnt2 = classes[j]
            if own2 > room:
                # Classes are sorted by `own` and the cross term is
                # non-negative, so every later partner overshoots as well.
                break
            doubled = own1 + own2 + twice_s1 * s2
            if doubled <= limit:
                # Factor 2: the unordered class pair {i, j} stands for both
                # (i, j) and (j, i).
                ans[doubled // 2] += 2 * cnt1 * cnt2

    return ans


def solve():
    """Read ``N M`` from standard input and print f(0), ..., f(N).

    One value is written per line, followed by a final newline.
    """
    N, M = map(int, input().split())
    ans = _count_quadruplets(N, M)
    sys.stdout.write("\n".join(map(str, ans)) + "\n")

# Script entry point: run the solver once at module load. It reads "N M"
# from stdin and writes the N+1 result lines to stdout.
solve()
0