結果

問題 No.309 シャイな人たち (1)
ユーザー gew1fw
提出日時 2025-06-12 20:22:24
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 2,821 bytes
コンパイル時間 345 ms
コンパイル使用メモリ 82,016 KB
実行使用メモリ 76,808 KB
最終ジャッジ日時 2025-06-12 20:23:14
合計ジャッジ時間 2,861 ms
ジャッジサーバーID
(参考情報)
judge2 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 4 WA * 9
権限があれば一括ダウンロードができます

ソースコード

diff #

def main():
    import sys
    input = sys.stdin.read().split()
    idx = 0

    R = int(input[idx])
    idx += 1
    C = int(input[idx])
    idx += 1

    P = []
    for _ in range(R):
        row = list(map(int, input[idx:idx + C]))
        idx += C
        P.append([p / 100.0 for p in row])

    S = []
    for _ in range(R):
        row = list(map(int, input[idx:idx + C]))
        idx += C
        S.append(row)

    # Precompute neighbors for each cell
    neighbors = [[[] for _ in range(C)] for _ in range(R)]
    for i in range(R):
        for j in range(C):
            # Front: i-1, j
            if i > 0:
                neighbors[i][j].append((i - 1, j))
            # Left: i, j-1
            if j > 0:
                neighbors[i][j].append((i, j - 1))
            # Right: i, j+1
            if j < C - 1:
                neighbors[i][j].append((i, j + 1))

    # Initialize x
    x = [[0.0 for _ in range(C)] for _ in range(R)]
    threshold = 1e-12
    max_iter = 1000000  # Prevent infinite loop

    for _ in range(max_iter):
        x_new = [row[:] for row in x]
        max_diff = 0.0

        # Iterate over all cells
        for i in range(R):
            for j in range(C):
                p = P[i][j]
                s = S[i][j]
                n_list = neighbors[i][j]
                m = len(n_list)

                # Compute Q_ij
                required_q = s
                q = 0.0
                for mask in range(0, 1 << m):
                    count = bin(mask).count('1')
                    prob = 1.0
                    for k in range(m):
                        ni, nj = n_list[k]
                        if (mask >> k) & 1:
                            prob *= x[ni][nj]
                        else:
                            prob *= (1 - x[ni][nj])
                    if count >= required_q:
                        q += prob

                # Compute R_ij
                required_r = 4
                r = 0.0
                for mask in range(0, 1 << m):
                    count = bin(mask).count('1')
                    prob = 1.0
                    for k in range(m):
                        ni, nj = n_list[k]
                        if (mask >> k) & 1:
                            prob *= x[ni][nj]
                        else:
                            prob *= (1 - x[ni][nj])
                    if count >= required_r:
                        r += prob

                new_val = p * q + (1 - p) * r
                x_new[i][j] = new_val
                max_diff = max(max_diff, abs(new_val - x[i][j]))

        if max_diff < threshold:
            break

        x = x_new

    expected = 0.0
    for i in range(R):
        for j in range(C):
            expected += x[i][j]

    print("{0:.15f}".format(expected))

if __name__ == "__main__":
    main()
0