結果

問題 No.1907 DETERMINATION
ユーザー gew1fw
提出日時 2025-06-12 19:48:42
言語 PyPy3 (7.3.15)
結果
TLE  
実行時間 -
コード長 3,130 bytes
コンパイル時間 330 ms
コンパイル使用メモリ 83,032 KB
実行使用メモリ 68,724 KB
最終ジャッジ日時 2025-06-12 19:48:59
合計ジャッジ時間 9,610 ms
ジャッジサーバーID (参考情報) judge5 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 3 TLE * 1 -- * 59
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 998_244_353  # Prime modulus: every coefficient printed by main() is reduced by it.

def readints():
    """Read a single line from standard input and return its integers.

    Tokens are split on whitespace; an empty line yields an empty list.
    """
    import sys
    tokens = sys.stdin.readline().split()
    return [int(tok) for tok in tokens]

def determinant_mod(matrix, mod):
    """Return det(matrix) modulo a prime ``mod``.

    Gaussian elimination with modular inverses (Fermat's little theorem, so
    ``mod`` must be prime).  Entries are reduced modulo ``mod`` *before* any
    pivot test, so a value such as ``mod`` itself is never mistaken for a
    non-zero pivot.  The elimination runs on a copy; ``matrix`` is left
    unchanged.  An empty (0 x 0) matrix has determinant 1.
    """
    a = [[v % mod for v in row] for row in matrix]
    n = len(a)
    det = 1
    for col in range(n):
        pivot = next((r for r in range(col, n) if a[r][col]), None)
        if pivot is None:
            # No non-zero entry on or below the diagonal: singular matrix.
            return 0
        if pivot != col:
            a[col], a[pivot] = a[pivot], a[col]
            det = -det % mod  # a row swap flips the sign
        pivot_row = a[col]
        pivot_value = pivot_row[col]
        det = det * pivot_value % mod
        inv = pow(pivot_value, mod - 2, mod)
        for r in range(col + 1, n):
            factor = a[r][col] * inv % mod
            if factor:  # rows already zero in this column need no work
                a[r] = [(x - factor * y) % mod for x, y in zip(a[r], pivot_row)]
    return det

def characteristic_polynomial_mod(matrix, mod):
    """Return the characteristic polynomial det(x*I - matrix) modulo ``mod``.

    The result is a list of ``n + 1`` coefficients, constant term first, whose
    last entry is 1.  ``mod`` must be prime (inverses use Fermat's little
    theorem).  ``matrix`` is not modified.

    Runs in O(n^3): similarity transforms reduce a copy of the matrix to upper
    Hessenberg form (which keeps the characteristic polynomial), then the
    polynomial is assembled with the standard Hessenberg recurrence.
    """
    n = len(matrix)
    h = [[v % mod for v in row] for row in matrix]

    # Zero everything below the first subdiagonal, one column at a time.
    for col in range(n - 2):
        sub = col + 1
        pivot = next((r for r in range(sub, n) if h[r][col]), None)
        if pivot is None:
            continue
        if pivot != sub:
            # Swap both rows AND columns so the transform stays a similarity.
            h[sub], h[pivot] = h[pivot], h[sub]
            for row in h:
                row[sub], row[pivot] = row[pivot], row[sub]
        inv = pow(h[sub][col], mod - 2, mod)
        for r in range(sub + 1, n):
            u = h[r][col] * inv % mod
            if not u:
                continue
            # E * H * E^-1 with E = I - u * e_r * e_sub^T:
            # row r -= u * row sub, then column sub += u * column r.
            h[r] = [(x - u * y) % mod for x, y in zip(h[r], h[sub])]
            for row in h:
                row[sub] = (row[sub] + u * row[r]) % mod

    # polys[m] is the characteristic polynomial of the leading m x m block.
    polys = [[1]]
    for m in range(1, n + 1):
        diag = h[m - 1][m - 1]
        prev = polys[m - 1]
        cur = [0] + prev  # x * prev
        for d, c in enumerate(prev):
            cur[d] = (cur[d] - diag * c) % mod
        t = 1  # running product of subdiagonal entries h[m-1][m-2] ... h[m-i][m-i-1]
        for i in range(1, m):
            t = t * h[m - i][m - i - 1] % mod
            if not t:
                break  # every later term contains this zero factor as well
            coef = t * h[m - i - 1][m - 1] % mod
            if coef:
                for d, c in enumerate(polys[m - i - 1]):
                    cur[d] = (cur[d] - coef * c) % mod
        polys.append(cur)
    return polys[n]


def linear_pencil_det_mod(m0, m1, mod):
    """Return the coefficients of det(m0 + x*m1) modulo a prime ``mod``.

    ``m0`` and ``m1`` are square matrices of the same size n.  The result has
    n + 1 coefficients, constant term first.  Neither input is modified.
    Runs in O(n^3).

    Method: row operations applied to both matrices turn ``m1`` into the
    identity; their effect on the determinant is accumulated in ``scale``.
    When a column of ``m1`` has no usable pivot, column operations clear it
    and the matching column of ``m0`` is moved into ``m1``.  That makes the
    column equal to x times the old one, multiplying the determinant by x;
    this is counted in ``shift``.  At the end

        det(m0 + x*m1) = scale * det(a + x*I) / x**shift,

    and det(a + x*I) is the characteristic polynomial of -a.
    """
    n = len(m0)
    a = [[v % mod for v in row] for row in m0]
    b = [[v % mod for v in row] for row in m1]
    scale = 1
    shift = 0
    p = 0
    while p < n:
        pivot = next((r for r in range(p, n) if b[r][p]), None)
        if pivot is None:
            shift += 1
            if shift > n:
                # x**shift would have to divide a polynomial of degree <= n,
                # so the determinant is identically zero.
                return [0] * (n + 1)
            # Columns 0..p-1 of b are the unit vectors e_0..e_{p-1}, so column
            # operations can clear b[0..p-1][p]; the determinant is unchanged.
            for i in range(p):
                c = b[i][p]
                if c:
                    b[i][p] = 0
                    for row in a:
                        row[p] = (row[p] - c * row[i]) % mod
            # Column p of b is now zero: move a's column p into b.
            for row_a, row_b in zip(a, b):
                row_a[p], row_b[p] = 0, row_a[p]
            continue  # retry the same column
        if pivot != p:
            a[p], a[pivot] = a[pivot], a[p]
            b[p], b[pivot] = b[pivot], b[p]
            scale = -scale % mod
        pivot_value = b[p][p]
        scale = scale * pivot_value % mod
        inv = pow(pivot_value, mod - 2, mod)
        a[p] = [v * inv % mod for v in a[p]]
        b[p] = [v * inv % mod for v in b[p]]
        row_ap, row_bp = a[p], b[p]
        # Clear column p of b in every other row (above and below the pivot).
        for r in range(n):
            if r == p:
                continue
            c = b[r][p]
            if c:
                a[r] = [(x - c * y) % mod for x, y in zip(a[r], row_ap)]
                b[r] = [(x - c * y) % mod for x, y in zip(b[r], row_bp)]
        p += 1

    negated = [[-v % mod for v in row] for row in a]
    char_poly = characteristic_polynomial_mod(negated, mod)  # det(x*I + a)
    return [
        scale * char_poly[k + shift] % mod if k + shift <= n else 0
        for k in range(n + 1)
    ]


def main():
    """Read N, M0 and M1 from stdin and print det(M0 + x*M1) mod MOD.

    Input: N, then the N*N entries of M0, then the N*N entries of M1, all
    whitespace-separated.  Output: the N + 1 coefficients, constant term
    first, one per line.
    """
    import sys
    tokens = sys.stdin.buffer.read().split()
    n = int(tokens[0])
    values = list(map(int, tokens[1:1 + 2 * n * n]))
    m0 = [values[i * n:(i + 1) * n] for i in range(n)]
    m1 = [values[(n + i) * n:(n + i + 1) * n] for i in range(n)]
    coeffs = linear_pencil_det_mod(m0, m1, MOD)
    sys.stdout.write("\n".join(map(str, coeffs)) + "\n")


if __name__ == "__main__":
    main()
0