Result
Problem | No.309 シャイな人たち (1) ("Shy People (1)") |
User | Min_25 |
Submitted | 2015-12-02 07:20:28 |
Language | PyPy2 (7.3.15) |
Result | AC |
Execution time | 1,033 ms / 4,000 ms |
Code length | 2,701 bytes |
Compile time | 1,640 ms |
Compile memory | 77,624 KB |
Execution memory | 114,580 KB |
Last judged | 2023-10-12 08:52:23 |
Total judge time | 11,299 ms |
Judge server ID (for reference) | judge11 / judge14 |
Test cases
Input | Result | Execution time | Execution memory |
---|---|---|---|
testcase_00 | AC | 741 ms | 114,448 KB |
testcase_01 | AC | 1,033 ms | 114,312 KB |
testcase_02 | AC | 282 ms | 111,524 KB |
testcase_03 | AC | 889 ms | 114,460 KB |
testcase_04 | AC | 153 ms | 80,660 KB |
testcase_05 | AC | 306 ms | 89,824 KB |
testcase_06 | AC | 279 ms | 113,008 KB |
testcase_07 | AC | 334 ms | 114,540 KB |
testcase_08 | AC | 916 ms | 114,432 KB |
testcase_09 | AC | 1,031 ms | 114,580 KB |
testcase_10 | AC | 1,024 ms | 114,412 KB |
testcase_11 | AC | 991 ms | 114,464 KB |
testcase_12 | AC | 876 ms | 114,412 KB |
testcase_13 | AC | 70 ms | 76,752 KB |
testcase_14 | AC | 71 ms | 76,480 KB |
testcase_15 | AC | 150 ms | 87,156 KB |
Source code
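import sys

def prob309():
    # g[k] = sum of A[j] over all j with j & k == 0
    # (reverse the indices, then take superset sums)
    def arith_transform_plus(A, lvn):
        ret = A[::-1]
        n = 1 << lvn
        for lvm in range(lvn, 0, -1):
            m = 1 << lvm
            mh = m >> 1
            for r in range(0, n, m):
                for j in range(0, mh):
                    ret[r + j] += ret[r + mh + j]
        return ret

    # Given g, inclusion-exclusion yields 3^lv entries: for each mask s2 and
    # each submask s1 of s2 (increasing order), the sum of the original values
    # over states t with t & s2 == s1. Block s2 starts at offsets[s2].
    def cumu(A, lv):
        total = 1 << lv
        ret = [0.0] * 3 ** lv
        offsets = [0] * total
        pos = 0
        for i in range(total):
            offsets[i] = pos
            s = ret[pos] = A[i]
            pos += 1
            f = i
            t = 1
            while f:
                r = f & -f
                ofs = offsets[i ^ r]
                for j in range(t):
                    ret[pos + j] = ret[ofs + j] - ret[pos - t + j]
                pos += t
                t <<= 1
                f ^= r
        return ret, offsets

    # index of the only set bit of a power of two
    def ctz(n):
        return n.bit_length() - 1

    # input
    rl = sys.stdin.readline
    R, C = map(int, rl().split())
    rl()
    P = [list(map(lambda n: int(n) / 100., rl().split())) for _ in range(R)]
    rl()
    S = [list(map(lambda n: 4 - int(n), rl().split())) for _ in range(R)]

    # init: pops[i] = popcount(i)
    total = 1 << C
    pops = [0] * total
    for i in range(C):
        t = 1 << i
        for j in range(t):
            pops[j + t] = 1 + pops[j]

    # dp over rows; state = set of raised hands in the last processed row.
    # pcurr: probability of each state; curr: expected number of raised hands
    # so far, summed over the outcomes ending in that state.
    # Row 0 is preceded by a virtual row with no raised hands.
    curr = [0.] * total
    pcurr = [0.] * total
    pcurr[0] = 1.0
    pcurr = arith_transform_plus(pcurr, C)
    cumu_curr, offsets = cumu(curr, C)
    cumu_pcurr, _ = cumu(pcurr, C)

    # O(3^C * C * R)
    for y in range(R):
        next = [0.] * total
        pnext = [0.] * total
        # O(3^C * C)
        for s2 in range(total):
            # probability that exactly the students in s2 know the answer
            p = 1.0
            for x in range(C):
                p *= P[y][x] if (s2 & (1 << x)) else 1.0 - P[y][x]
            # starting points: 4 - S if the student knows the answer, else 0
            points_back = [S[y][x] if (s2 & (1 << x)) else 0 for x in range(C)]
            # s1 = raised hands in front of students in s2, for every submask
            # of s2 from s2 down to 0 (the others start at 0 points and reach
            # at most 3, so only s2 matters); ofs is the matching table entry
            s1 = s2
            ofs = offsets[s2] + (1 << pops[s2]) - 1
            while 1:
                if cumu_pcurr[ofs]:
                    points = points_back[:]
                    # one point for a raised hand directly in front
                    s = s1
                    while s > 0:
                        t = s & -s
                        x = ctz(t)
                        points[x] += 1
                        s ^= t
                    # 4+ points = hand up; it gives a point to the right
                    # neighbour, then (sweeping back) to the left neighbour
                    for x in range(0, C - 1):
                        if points[x] >= 4:
                            points[x + 1] += 1
                    for x in range(C - 1, 0, -1):
                        if points[x] >= 4:
                            points[x - 1] += 1
                    nstate = 0
                    cnt = 0
                    for x in range(C):
                        if points[x] >= 4:
                            nstate |= 1 << x
                            cnt += 1
                    pnext[nstate] += p * cumu_pcurr[ofs]
                    next[nstate] += (cumu_curr[ofs] + cnt * cumu_pcurr[ofs]) * p
                s1 = (s1 - 1) & s2
                ofs -= 1
                if s1 == s2:
                    break
        # (2^C * C)
        curr = arith_transform_plus(next, C)
        pcurr = arith_transform_plus(pnext, C)
        # (3^C)
        cumu_curr, _ = cumu(curr, C)
        cumu_pcurr, _ = cumu(pcurr, C)

    # cumu_curr[0] = sum of curr over all states = expected total
    print("%.12f" % (cumu_curr[0]))

prob309()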
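The O(3^C) step of the solution depends on two helpers. Read from the code, arith_transform_plus maps a distribution A over row states to g[k] = sum of A[j] over all j with j & k == 0. cumu then uses inclusion-exclusion on g to list, for every mask s2 and every submask s1 of s2, the sum of A[t] over states t with t & s2 == s1. The following is a minimal sketch, assuming Python 3, that checks this reading against the direct definition on random inputs. The two helpers are copied from the submission, with the unused variable s dropped. submasks_increasing, naive_table and check are hypothetical names added only for this comparison.

```python
# Brute-force check of the two helpers from the submission (Python 3 sketch).
# arith_transform_plus and cumu are copied from the submission; the other
# functions are hypothetical additions used only for the comparison.
import random


def arith_transform_plus(A, lvn):
    # Reverse indices, then take superset sums:
    # result[k] = sum of A[j] over all j with j & k == 0.
    ret = A[::-1]
    n = 1 << lvn
    for lvm in range(lvn, 0, -1):
        m = 1 << lvm
        mh = m >> 1
        for r in range(0, n, m):
            for j in range(0, mh):
                ret[r + j] += ret[r + mh + j]
    return ret


def cumu(A, lv):
    # Inclusion-exclusion over the set bits of each mask i; block i holds
    # 2^popcount(i) entries, one per submask of i in increasing order.
    total = 1 << lv
    ret = [0.0] * 3 ** lv
    offsets = [0] * total
    pos = 0
    for i in range(total):
        offsets[i] = pos
        ret[pos] = A[i]
        pos += 1
        f = i
        t = 1
        while f:
            r = f & -f
            ofs = offsets[i ^ r]
            for j in range(t):
                ret[pos + j] = ret[ofs + j] - ret[pos - t + j]
            pos += t
            t <<= 1
            f ^= r
    return ret, offsets


def submasks_increasing(s2):
    # Hypothetical helper: all submasks of s2 in increasing numeric order.
    return [s1 for s1 in range(s2 + 1) if s1 & s2 == s1]


def naive_table(A, lv):
    # Hypothetical helper, direct definition in the same layout as cumu:
    # entry (s2, s1) = sum of A[t] over all t with t & s2 == s1.
    out = []
    for s2 in range(1 << lv):
        for s1 in submasks_increasing(s2):
            out.append(sum(A[t] for t in range(1 << lv) if t & s2 == s1))
    return out


def check(lv, seed=0):
    # Compare the fast table with the naive one on random integer-valued input.
    rng = random.Random(seed)
    A = [float(rng.randint(0, 9)) for _ in range(1 << lv)]
    fast, _ = cumu(arith_transform_plus(A, lv), lv)
    slow = naive_table(A, lv)
    return len(fast) == 3 ** lv and all(abs(x - y) < 1e-9 for x, y in zip(fast, slow))


if __name__ == "__main__":
    for lv in range(1, 6):
        print(lv, check(lv))
```

Because the entries inside each block come in increasing submask order, the submission can walk s1 = s2, (s2 - 1) & s2, ... down to 0 while simply decrementing ofs. Integer-valued inputs keep every sum and difference exact, so the 1e-9 tolerance in check is only a safeguard.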