結果
問題 | No.1907 DETERMINATION |
ユーザー |
![]() |
提出日時 | 2025-06-12 19:48:42 |
言語 | PyPy3 (7.3.15) |
結果 |
TLE
|
実行時間 | - |
コード長 | 3,130 bytes |
コンパイル時間 | 330 ms |
コンパイル使用メモリ | 83,032 KB |
実行使用メモリ | 68,724 KB |
最終ジャッジ日時 | 2025-06-12 19:48:59 |
合計ジャッジ時間 | 9,610 ms |
ジャッジサーバーID (参考情報) |
judge5 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 4 |
other | AC * 3 TLE * 1 -- * 59 |
ソースコード
MOD = 998244353 def readints(): import sys return list(map(int, sys.stdin.readline().split())) def determinant_mod(matrix, mod): n = len(matrix) det = 1 for i in range(n): max_row = i for j in range(i, n): if matrix[j][i] != 0: max_row = j break if matrix[max_row][i] == 0: return 0 if max_row != i: matrix[i], matrix[max_row] = matrix[max_row], matrix[i] det = (-det) % mod inv = pow(matrix[i][i], mod-2, mod) for j in range(i+1, n): factor = (matrix[j][i] * inv) % mod for k in range(i, n): matrix[j][k] = (matrix[j][k] - factor * matrix[i][k]) % mod for i in range(n): det = (det * matrix[i][i]) % mod return det def main(): import sys input = sys.stdin.read().split() ptr = 0 N = int(input[ptr]) ptr +=1 M0 = [] for _ in range(N): row = list(map(int, input[ptr:ptr+N])) ptr += N M0.append(row) M1 = [] for _ in range(N): row = list(map(int, input[ptr:ptr+N])) ptr += N M1.append(row) # Choose evaluation points x = 0, 1, ..., N x_eval = list(range(N+1)) f_eval = [] for xi in x_eval: A = [] for i in range(N): row = [] for j in range(N): a = (M0[i][j] + xi * M1[i][j]) % MOD row.append(a) A.append(row) det = determinant_mod([r[:] for r in A], MOD) f_eval.append(det) # Now interpolate using lagrange interpolation n = N a = [0] * (n+1) # Vandermonde matrix approach # We solve V * a = f, where V[i][j] = x_eval[i]^j # To find a, we can use Gaussian elimination # This is O(n^3), which is feasible for n=400 # Build the Vandermonde matrix V = [] for xi in x_eval: row = [] for j in range(n+1): row.append(pow(xi, j, MOD)) V.append(row) # Create the augmented matrix aug = [row[:] for row in V] for i in range(n+1): aug[i].append(f_eval[i]) # Perform Gaussian elimination for col in range(n+1): pivot = -1 for row in range(col, n+1): if aug[row][col] != 0: pivot = row break if pivot == -1: continue # all zeros, which shouldn't happen aug[col], aug[pivot] = aug[pivot], aug[col] inv = pow(aug[col][col], MOD-2, MOD) for j in range(col, n+2): aug[col][j] = (aug[col][j] * inv) % MOD for row in range(n+1): if row != col and aug[row][col] != 0: factor = aug[row][col] for j in range(col, n+2): aug[row][j] = (aug[row][j] - factor * aug[col][j]) % MOD # Extract the coefficients for j in range(n+1): a[j] = aug[j][n+1] % MOD # Output for k in range(n+1): print(a[k] % MOD) if __name__ == "__main__": main()