結果

問題 No.309 シャイな人たち (1)
ユーザー qwewe
提出日時 2025-04-24 12:32:41
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 2,586 bytes
コンパイル時間 200 ms
コンパイル使用メモリ 82,644 KB
実行使用メモリ 77,184 KB
最終ジャッジ日時 2025-04-24 12:33:55
合計ジャッジ時間 1,913 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 4 WA * 9
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

def main():
    input = sys.stdin.read().split()
    idx = 0
    R = int(input[idx]); idx += 1
    C = int(input[idx]); idx += 1

    P = []
    for _ in range(R):
        row = list(map(int, input[idx:idx+C]))
        idx += C
        P.append(row)
    S = []
    for _ in range(R):
        row = list(map(int, input[idx:idx+C]))
        idx += C
        S.append(row)

    # Precompute neighbors for each cell
    neighbors = [[[] for _ in range(C)] for _ in range(R)]
    for i in range(R):
        for j in range(C):
            # Front (i-1, j)
            if i - 1 >= 0:
                neighbors[i][j].append((i-1, j))
            # Left (i, j-1)
            if j - 1 >= 0:
                neighbors[i][j].append((i, j-1))
            # Right (i, j+1)
            if j + 1 < C:
                neighbors[i][j].append((i, j+1))

    # Initialize expected values
    E = [[0.0 for _ in range(C)] for _ in range(R)]
    epsilon = 1e-12
    max_iter = 100000

    for _ in range(max_iter):
        new_E = [[0.0 for _ in range(C)] for _ in range(R)]
        max_diff = 0.0
        for i in range(R):
            for j in range(C):
                s = S[i][j]
                p = P[i][j] / 100.0
                if s == 0:
                    new_val = p
                else:
                    required = s
                    neighbor_list = neighbors[i][j]
                    n = len(neighbor_list)
                    if n < required:
                        new_val = 0.0
                    else:
                        probs = [E[x][y] for (x, y) in neighbor_list]
                        total = 0.0
                        for mask in range(1 << n):
                            cnt = bin(mask).count('1')
                            if cnt >= required:
                                prob = 1.0
                                for k in range(n):
                                    if (mask >> k) & 1:
                                        prob *= probs[k]
                                    else:
                                        prob *= (1 - probs[k])
                                total += prob
                        new_val = p * total
                new_E[i][j] = new_val
                max_diff = max(max_diff, abs(new_E[i][j] - E[i][j]))
        # Update E for next iteration
        for i in range(R):
            for j in range(C):
                E[i][j] = new_E[i][j]
        if max_diff < epsilon:
            break

    total = sum(sum(row) for row in E)
    print("{0:.12f}".format(total))

if __name__ == '__main__':
    main()
0