Result
Problem | No.309 シャイな人たち (1) [Shy People (1)] |
User | Min_25 |
Submitted at | 2015-12-02 07:20:28 |
Language | PyPy2 (7.3.15) |
Result | AC |
Execution time (max / limit) | 1,039 ms / 4,000 ms |
Code length | 2,701 bytes |
Compile time | 1,447 ms |
Memory used at compile time | 76,456 KB |
Memory used at run time | 113,276 KB |
Last judged at | 2024-09-14 07:50:52 |
Total judging time | 11,768 ms |
Judge server ID (for reference) | judge5 / judge3 |
Test cases
Input | Result | Execution time | Memory used |
---|---|---|---|
testcase_00 | AC | 750 ms | 113,128 KB |
testcase_01 | AC | 1,039 ms | 112,760 KB |
testcase_02 | AC | 292 ms | 112,636 KB |
testcase_03 | AC | 896 ms | 112,876 KB |
testcase_04 | AC | 155 ms | 78,976 KB |
testcase_05 | AC | 314 ms | 88,532 KB |
testcase_06 | AC | 285 ms | 109,900 KB |
testcase_07 | AC | 339 ms | 113,156 KB |
testcase_08 | AC | 921 ms | 112,900 KB |
testcase_09 | AC | 1,031 ms | 113,004 KB |
testcase_10 | AC | 1,033 ms | 113,276 KB |
testcase_11 | AC | 993 ms | 112,756 KB |
testcase_12 | AC | 886 ms | 112,760 KB |
testcase_13 | AC | 77 ms | 75,520 KB |
testcase_14 | AC | 78 ms | 75,264 KB |
testcase_15 | AC | 157 ms | 85,100 KB |
Source code
import sys


def prob309():
    def arith_transform_plus(A, lvn):
        # Superset-sum transform of the reversed array:
        # afterwards ret[k] == sum of A[t] over every t that is a submask of ~k.
        ret = A[::-1]
        n = 1 << lvn
        for lvm in range(lvn, 0, -1):
            m = 1 << lvm
            mh = m >> 1
            for r in range(0, n, m):
                for j in range(0, mh):
                    ret[r + j] += ret[r + mh + j]
        return ret

    def cumu(A, lv):
        # For every mask i, a block of 2**popcount(i) values starts at
        # offsets[i] (3**lv values in total). If A = arith_transform_plus(X),
        # entry u of block i (u a submask of i, indexed by compressing the
        # set bits of i, lowest first) is the sum of X[t] over t with t & i == u.
        total = 1 << lv
        ret = [0.0] * 3 ** lv
        offsets = [0] * total
        pos = 0
        for i in range(total):
            offsets[i] = pos
            ret[pos] = A[i]
            pos += 1
            f = i
            t = 1
            while f:
                r = f & -f  # lowest remaining set bit of i
                ofs = offsets[i ^ r]
                for j in range(t):
                    ret[pos + j] = ret[ofs + j] - ret[pos - t + j]
                pos += t
                t <<= 1
                f ^= r
        return ret, offsets

    def ctz(n):
        # Index of the set bit; only valid when n is a power of two.
        return n.bit_length() - 1

    # input
    rl = sys.stdin.readline
    R, C = map(int, rl().split())
    rl()
    P = [list(map(lambda n: int(n) / 100., rl().split())) for _ in range(R)]
    rl()
    S = [list(map(lambda n: 4 - int(n), rl().split())) for _ in range(R)]

    # init: pops[m] = popcount(m)
    total = 1 << C
    pops = [0] * total
    for i in range(C):
        t = 1 << i
        for j in range(t):
            pops[j + t] = 1 + pops[j]

    # dp
    curr = [0.] * total
    pcurr = [0.] * total
    pcurr[0] = 1.0
    pcurr = arith_transform_plus(pcurr, C)
    cumu_curr, offsets = cumu(curr, C)
    cumu_pcurr, _ = cumu(pcurr, C)

    # O(3^C * C * R)
    for y in range(R):
        nxt = [0.] * total
        pnext = [0.] * total
        # O(3^C * C)
        for s2 in range(total):
            # probability of the bit pattern s2 in row y
            p = 1.0
            for x in range(C):
                p *= P[y][x] if (s2 & (1 << x)) else 1.0 - P[y][x]
            points_back = [S[y][x] if (s2 & (1 << x)) else 0 for x in range(C)]
            # visit the submasks s1 of s2 in decreasing order while ofs walks
            # block s2 of the cumu tables from its last entry downwards;
            # cumu_pcurr[ofs] is the summed probability of previous-row
            # states t with t & s2 == s1
            s1 = s2
            ofs = offsets[s2] + (1 << pops[s2]) - 1
            while 1:
                if cumu_pcurr[ofs]:
                    points = points_back[:]
                    s = s1
                    while s > 0:
                        t = s & -s
                        x = ctz(t)
                        points[x] += 1
                        s ^= t
                    # propagate left to right, then right to left
                    for x in range(0, C - 1):
                        if points[x] >= 4:
                            points[x + 1] += 1
                    for x in range(C - 1, 0, -1):
                        if points[x] >= 4:
                            points[x - 1] += 1
                    nstate = 0
                    cnt = 0
                    for x in range(C):
                        if points[x] >= 4:
                            nstate |= 1 << x
                            cnt += 1
                    pnext[nstate] += p * cumu_pcurr[ofs]
                    nxt[nstate] += (cumu_curr[ofs] + cnt * cumu_pcurr[ofs]) * p
                s1 = (s1 - 1) & s2
                ofs -= 1
                if s1 == s2:
                    break
        # O(2^C * C)
        curr = arith_transform_plus(nxt, C)
        pcurr = arith_transform_plus(pnext, C)
        # O(3^C)
        cumu_curr, _ = cumu(curr, C)
        cumu_pcurr, _ = cumu(pcurr, C)
    print("%.12f" % (cumu_curr[0]))


prob309()
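The submission carries no explanation of its two helper tables. As a reading aid, the following brute-force sketch (written for this note, not part of the submission) restates what they appear to hold, as read from `arith_transform_plus` and `cumu` above: the transform gives, for each mask k, the sum of the input over submasks of the complement of k, and block i of the 3^C table splits the total by the value of `t & i`. The names `complement_subset_sums`, `scatter`, `brute_cumu_block` and `submasks_descending` are hypothetical, and the brute force is exponential, so it is only meant for tiny bit counts.

```python
def complement_subset_sums(A, n_bits):
    """Hypothetical restatement of arith_transform_plus:
    out[k] = sum of A[t] over every t that is a submask of ~k.
    Bits are processed low to high here; the order does not change the result."""
    full = (1 << n_bits) - 1
    out = A[::-1]  # out[k] starts as A[k ^ full]
    for b in range(n_bits):
        bit = 1 << b
        for k in range(1 << n_bits):
            if not k & bit:
                out[k] += out[k | bit]  # superset-sum step for this bit
    return out


def scatter(idx, i):
    """Map a compressed index onto the set bits of i (lowest set bit first)."""
    u = 0
    k = 0
    for b in range(i.bit_length()):
        if i >> b & 1:
            if idx >> k & 1:
                u |= 1 << b
            k += 1
    return u


def brute_cumu_block(X, n_bits, i):
    """Block i of the 3**n table by brute force: entry idx is the sum of X[t]
    over all masks t with t & i == scatter(idx, i)."""
    size = 1 << bin(i).count("1")
    return [sum(X[t] for t in range(1 << n_bits) if t & i == scatter(idx, i))
            for idx in range(size)]


def submasks_descending(i):
    """Submasks of i in the order the DP loop visits them (s1 = (s1 - 1) & i)."""
    out = []
    s = i
    while True:
        out.append(s)
        if s == 0:
            break
        s = (s - 1) & i
    return out


if __name__ == "__main__":
    n = 3
    X = [float(v) for v in range(1, (1 << n) + 1)]  # arbitrary sample values
    A = complement_subset_sums(X, n)
    for i in range(1 << n):
        block = brute_cumu_block(X, n, i)
        # entry 0 (u = 0) sums the masks disjoint from i, which is A[i]
        assert block[0] == A[i]
        # descending submask order matches descending compressed index,
        # which is why the DP can simply decrement ofs alongside s1
        assert submasks_descending(i) == [scatter(idx, i) for idx in reversed(range(len(block)))]
    # the blocks together hold 3**n values
    assert sum(len(brute_cumu_block(X, n, i)) for i in range(1 << n)) == 3 ** n
    print("ok")
```

Running the file prints `ok`. The last check is the source of the 3^C factor in the complexity comments: summing 2^popcount(i) over all masks i gives 3^C.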