結果

問題 No.309 シャイな人たち (1)
ユーザー gew1fw
提出日時 2025-06-12 20:12:05
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 3,633 bytes
コンパイル時間 199 ms
コンパイル使用メモリ 82,040 KB
実行使用メモリ 77,032 KB
最終ジャッジ日時 2025-06-12 20:15:57
合計ジャッジ時間 2,126 ms
ジャッジサーバーID
(参考情報)
judge2 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 4 WA * 9
権限があれば一括ダウンロードができます

ソースコード

diff #

def main():
    import sys
    input = sys.stdin.read
    data = input().split()
    idx = 0
    R = int(data[idx])
    idx += 1
    C = int(data[idx])
    idx += 1

    P = []
    for _ in range(R):
        row = list(map(int, data[idx:idx+C]))
        P.append(row)
        idx += C

    S = []
    for _ in range(R):
        row = list(map(int, data[idx:idx+C]))
        S.append(row)
        idx += C

    # Precompute neighbors for each cell (i, j)
    neighbors = [[[] for _ in range(C)] for __ in range(R)]
    for i in range(R):
        for j in range(C):
            # Check front (i-1, j)
            if i > 0:
                neighbors[i][j].append((i-1, j))
            # Check left (i, j-1)
            if j > 0:
                neighbors[i][j].append((i, j-1))
            # Check right (i, j+1)
            if j < C - 1:
                neighbors[i][j].append((i, j+1))

    # Initialize p
    p = [[0.5 for _ in range(C)] for __ in range(R)]
    threshold = 1e-9
    max_iterations = 100000

    for _ in range(max_iterations):
        new_p = [[0.0 for _ in range(C)] for __ in range(R)]
        for i in range(R):
            for j in range(C):
                P_ij = P[i][j]
                S_ij = S[i][j]
                p_know = P_ij / 100.0
                p_not_know = 1.0 - p_know
                a = 4 - S_ij
                # Get the neighbor cells
                neighbor_cells = neighbors[i][j]
                m = len(neighbor_cells)
                # List of p values for each neighbor
                neighbor_p = [p[i_][j_] for (i_, j_) in neighbor_cells]

                # Compute P_sum_ge_t
                p_sum_ge_t = 0.0
                if a < 4:
                    t = 4 - a
                    for mask in range(0, 1 << m):
                        s = bin(mask).count('1')
                        if s < t:
                            continue
                        prob = 1.0
                        for k in range(m):
                            if (mask >> k) & 1:
                                prob *= neighbor_p[k]
                            else:
                                prob *= (1.0 - neighbor_p[k])
                        p_sum_ge_t += prob

                # Compute P_sum_ge_4
                p_sum_ge_4 = 0.0
                if m >= 4:
                    for mask in range(0, 1 << m):
                        s = bin(mask).count('1')
                        if s >= 4:
                            prob = 1.0
                            for k in range(m):
                                if (mask >> k) & 1:
                                    prob *= neighbor_p[k]
                                else:
                                    prob *= (1.0 - neighbor_p[k])
                            p_sum_ge_4 += prob

                # Compute case1 and case2
                if a >= 4:
                    case1 = p_know * 1.0
                else:
                    case1 = p_know * p_sum_ge_t
                case2 = p_not_know * p_sum_ge_4
                new_p[i][j] = case1 + case2

        # Check for convergence
        max_change = 0.0
        for i in range(R):
            for j in range(C):
                change = abs(new_p[i][j] - p[i][j])
                if change > max_change:
                    max_change = change
        if max_change < threshold:
            break
        # Update p
        p = new_p

    # Compute the result
    result = 0.0
    for i in range(R):
        for j in range(C):
            result += p[i][j]

    # Output with sufficient precision
    print("{0:.10f}".format(result))

if __name__ == "__main__":
    main()
0