結果

問題 No.1907 DETERMINATION
ユーザー lam6er
提出日時 2025-03-31 17:35:58
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 2,986 bytes
コンパイル時間 213 ms
コンパイル使用メモリ 82,372 KB
実行使用メモリ 69,936 KB
最終ジャッジ日時 2025-03-31 17:36:45
合計ジャッジ時間 10,417 ms
ジャッジサーバーID
(参考情報)
judge4 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 3 TLE * 1 -- * 59
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
mod = 998244353

def readints():
    return list(map(int, sys.stdin.readline().split()))

def main():
    import sys
    sys.setrecursionlimit(1 << 25)
    N = int(sys.stdin.readline())
    M0 = []
    for _ in range(N):
        M0.append(list(map(int, sys.stdin.readline().split())))
    M1 = []
    for _ in range(N):
        M1.append(list(map(int, sys.stdin.readline().split())))
    
    # List of evaluation points x = 0, 1, 2, ..., N
    xs = list(range(N+1))
    ys = []
    
    # Precompute all determinants for x in 0..N
    for x in xs:
        # Compute matrix M = M0 + x*M1
        mat = [ [ (M0[i][j] + x * M1[i][j]) % mod for j in range(N)] for i in range(N)]
        det = 1
        sign = 1
        mat = [row[:] for row in mat]  # make a copy
        
        for i in range(N):
            # Find pivot in column i
            pivot = -1
            for j in range(i, N):
                if mat[j][i] != 0:
                    pivot = j
                    break
            if pivot == -1:
                det = 0
                break
            if pivot != i:
                # Swap rows
                mat[i], mat[pivot] = mat[pivot], mat[i]
                sign *= -1
            
            pivot_val = mat[i][i]
            det = (det * pivot_val) % mod
            inv_pivot = pow(pivot_val, mod-2, mod)
            
            for j in range(i+1, N):
                factor = (mat[j][i] * inv_pivot) % mod
                # Subtract factor * row i from row j
                for k in range(i, N):
                    mat[j][k] = (mat[j][k] - factor * mat[i][k]) % mod
        if det == 0:
            final_det = 0
        else:
            final_det = (det * sign) % mod
        ys.append(final_det)
    
    # Now interpolate to find coefficients of the polynomial
    # Using Lagrange interpolation
    a = [0]*(N+1)
    for i in range(N+1):
        xi = xs[i]
        yi = ys[i]
        
        # Compute the Lagrange basis polynomial L_i
        L = [0]*(N+1)
        L[0] = 1
        denom = 1
        for j in range(N+1):
            if j == i:
                continue
            # multiply by (x - xs[j]) / (xi - xj)
            # to construct the numerator and denominator
            denom = (denom * (xi - xs[j])) % mod
            # Multiply L by (x - xs[j]) 
            # as a polynomial, shifting degrees
            new_L = [0]*(len(L)+1)
            for k in range(len(L)):
                new_L[k] = (new_L[k] - xs[j] * L[k]) % mod
                new_L[k+1] = (new_L[k+1] + L[k]) % mod
            L = new_L[:N+1]  # truncate if necessary
        
        # Compute inv_denom = 1/denom
        inv_denom = pow(denom, mod-2, mod)
        # Multiply L by yi * inv_denom and add to a
        for k in range(N+1):
            term = (L[k] * yi) % mod
            term = (term * inv_denom) % mod
            a[k] = (a[k] + term) % mod
    
    for coeff in a:
        print(coeff % mod)

if __name__ == "__main__":
    main()
0