結果

問題 No.611 Day of the Mountain
ユーザー lam6er
提出日時 2025-04-09 20:56:57
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 3,806 bytes
コンパイル時間 178 ms
コンパイル使用メモリ 82,352 KB
実行使用メモリ 94,292 KB
最終ジャッジ日時 2025-04-09 20:58:41
合計ジャッジ時間 6,606 ms
ジャッジサーバーID
(参考情報)
judge2 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample -- * 3
other TLE * 1 -- * 8
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

# Modulus the answer count is reported under.
MOD = 201712111


def _cheapest_path_sums(cost):
    """Return forward and backward minimum path sums for a cost grid.

    ``fwd[i][j]`` is the cheapest right/down path sum from the top-left cell
    to ``(i, j)``; ``bwd[i][j]`` is the cheapest from ``(i, j)`` to the
    bottom-right cell.  Both include the cost of ``(i, j)`` itself.
    """
    h, w = len(cost), len(cost[0])
    inf = float('inf')

    fwd = [[0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            if i or j:
                best = min(fwd[i - 1][j] if i else inf,
                           fwd[i][j - 1] if j else inf)
            else:
                best = 0
            fwd[i][j] = best + cost[i][j]

    bwd = [[0] * w for _ in range(h)]
    for i in reversed(range(h)):
        for j in reversed(range(w)):
            if i + 1 < h or j + 1 < w:
                best = min(bwd[i + 1][j] if i + 1 < h else inf,
                           bwd[i][j + 1] if j + 1 < w else inf)
            else:
                best = 0
            bwd[i][j] = best + cost[i][j]

    return fwd, bwd


def _solve(rows):
    """Return ``(T, X)`` for a grid given as a list of equal-length strings.

    ``T`` is the cheapest right/down path sum when every ``'?'`` is 1.
    ``X`` is the number of ways (mod ``MOD``) to replace each ``'?'`` with a
    digit 1..9 such that the cheapest path sum is still ``T``.

    The cheapest sum stays ``T`` exactly when some path that is optimal in
    the all-ones grid has every ``'?'`` on it set to 1.  Instead of
    enumerating those paths (exponentially many), a row-major sweep keeps a
    bitmask with one bit per column: the bit is set when the most recently
    processed cell of that column is reachable from the start along an
    optimal prefix made only of fixed digits and ``'?'`` cells set to 1.
    Worst case this is O(H * W * 2**min(H, W)) dictionary updates; only
    masks that actually occur are stored.
    """
    h, w = len(rows), len(rows[0])
    # Right/down paths are symmetric under transposition, so put the short
    # side on the bitmask axis.
    if w > h:
        rows = [''.join(col) for col in zip(*rows)]
        h, w = w, h

    unknown = [[ch == '?' for ch in row] for row in rows]
    cost = [[1 if ch == '?' else int(ch) for ch in row] for row in rows]
    fwd, bwd = _cheapest_path_sums(cost)
    target = fwd[h - 1][w - 1]

    # mask -> number of assignments (mod MOD) of the '?' cells seen so far.
    states = {0: 1}
    for i in range(h):
        for j in range(w):
            bit = 1 << j
            keep = ~bit
            free = unknown[i][j]
            nxt = {}

            if fwd[i][j] + bwd[i][j] - cost[i][j] != target:
                # Not on any optimal path: its value never matters, so a
                # '?' here contributes all nine digits and the bit is off.
                factor = 9 if free else 1
                for mask, ways in states.items():
                    key = mask & keep
                    nxt[key] = (nxt.get(key, 0) + ways * factor) % MOD
                states = nxt
                continue

            # A neighbour feeds this cell only if stepping from it keeps the
            # prefix sum optimal.  Bit j still describes cell (i-1, j) and
            # bit j-1 already describes cell (i, j-1).
            need = fwd[i][j] - cost[i][j]
            feeders = 0
            if i and fwd[i - 1][j] == need:
                feeders |= bit
            if j and fwd[i][j - 1] == need:
                feeders |= bit >> 1
            is_start = not (i or j)

            for mask, ways in states.items():
                off = mask & keep
                live = is_start or mask & feeders
                if not free:
                    key = off | bit if live else off
                    nxt[key] = (nxt.get(key, 0) + ways) % MOD
                elif live:
                    # Digit 1 keeps the path alive; digits 2..9 break it.
                    on = off | bit
                    nxt[on] = (nxt.get(on, 0) + ways) % MOD
                    nxt[off] = (nxt.get(off, 0) + 8 * ways) % MOD
                else:
                    nxt[off] = (nxt.get(off, 0) + 9 * ways) % MOD
            states = nxt

    # After the sweep the last column's bit describes the bottom-right cell.
    goal = 1 << (w - 1)
    count = sum(ways for mask, ways in states.items() if mask & goal) % MOD
    return target, count


def main():
    """Read ``H W`` and ``H`` grid rows from stdin; print ``T`` then ``X``.

    Each row is a whitespace-free token of digits and ``'?'``.  ``T`` is the
    minimum right/down path sum with every ``'?'`` set to 1, and ``X`` is the
    number of 1..9 fillings of the ``'?'`` cells that keep the minimum path
    sum equal to ``T``, modulo ``MOD``.
    """
    tokens = sys.stdin.read().split()
    H = int(tokens[0])
    W = int(tokens[1])
    rows = [tok[:W] for tok in tokens[2:2 + H]]

    T, X = _solve(rows)
    print(T)
    print(X)


if __name__ == '__main__':
    main()
0