結果

問題 No.309 シャイな人たち (1)
ユーザー lam6er
提出日時 2025-04-09 21:02:53
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 2,604 bytes
コンパイル時間 179 ms
コンパイル使用メモリ 82,100 KB
実行使用メモリ 76,964 KB
最終ジャッジ日時 2025-04-09 21:04:20
合計ジャッジ時間 1,862 ms
ジャッジサーバーID
(参考情報)
judge2 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 4 WA * 9
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

def main():
    input = sys.stdin.read().split()
    idx = 0
    R = int(input[idx])
    idx += 1
    C = int(input[idx])
    idx += 1

    P = []
    for _ in range(R):
        row = list(map(int, input[idx:idx+C]))
        idx += C
        P.append(row)
    S = []
    for _ in range(R):
        row = list(map(int, input[idx:idx+C]))
        idx += C
        S.append(row)
    
    # Precompute neighbors for each cell
    neighbors = [[[] for _ in range(C)] for __ in range(R)]
    for i in range(R):
        for j in range(C):
            current_neighbors = []
            if i > 0:
                current_neighbors.append((i-1, j))
            if j > 0:
                current_neighbors.append((i, j-1))
            if j < C-1:
                current_neighbors.append((i, j+1))
            neighbors[i][j] = current_neighbors
    
    e_prev = [[0.0 for _ in range(C)] for __ in range(R)]
    e_new = [[0.0 for _ in range(C)] for __ in range(R)]
    max_iterations = 10000
    tolerance = 1e-12
    iteration = 0

    while True:
        max_diff = 0.0
        for i in range(R):
            for j in range(C):
                p = P[i][j] / 100.0
                s = S[i][j]
                nb_cells = neighbors[i][j]
                k = len(nb_cells)
                if k == 0:
                    if s <= 0:
                        pr = 1.0
                    else:
                        pr = 0.0
                else:
                    pr = 0.0
                    e_nb = [e_prev[x][y] for (x, y) in nb_cells]
                    for mask in range(1 << k):
                        current_prob = 1.0
                        sum_s = 0
                        for bit in range(k):
                            if (mask >> bit) & 1:
                                sum_s += 1
                                current_prob *= e_nb[bit]
                            else:
                                current_prob *= (1 - e_nb[bit])
                        if sum_s >= s:
                            pr += current_prob
                new_val = p * pr
                e_new[i][j] = new_val
                max_diff = max(max_diff, abs(new_val - e_prev[i][j]))
        if max_diff < tolerance:
            break
        e_prev, e_new = e_new, e_prev  # Swap for next iteration
        iteration += 1
        if iteration > max_iterations:
            break  # Prevent infinite loop
    
    # Sum up all e_prev values (since after swap, e_prev holds the latest)
    total = sum(sum(row) for row in e_prev)
    print("{0:.10f}".format(total))

if __name__ == "__main__":
    main()
0