結果
| 項目 | 内容 |
|---|---|
| 問題 | No.309 シャイな人たち (1) |
| コンテスト | |
| ユーザー | Min_25 |
| 提出日時 | 2015-12-02 07:20:28 |
| 言語 | PyPy2 (7.3.15) |
| 結果 | AC |
| 実行時間 | 1,039 ms / 4,000 ms |
| コード長 | 2,701 bytes |
| 記録 | |
| コンパイル時間 | 1,447 ms |
| コンパイル使用メモリ | 76,456 KB |
| 実行使用メモリ | 113,276 KB |
| 最終ジャッジ日時 | 2024-09-14 07:50:52 |
| 合計ジャッジ時間 | 11,768 ms |
| ジャッジサーバーID (参考情報) | judge5 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 13 |
ソースコード
import sys
# Kept Python 2 compatible (the submission ran under PyPy2): no f-strings,
# no annotations.


def arith_transform_plus(A, lvn):
    """Return the "disjoint-sum" transform of A over lvn-bit masks.

    For every mask i the result holds
    ``sum(A[m] for m in range(1 << lvn) if m & i == 0)``.
    Reversing A maps each index m to its complement within lvn bits; the
    butterfly loop is then a superset-sum transform, which turns "sum over
    supersets of i" of the reversed array into "sum over masks disjoint
    from i" of the original.

    A is not modified. Expects len(A) == 2**lvn. O(lvn * 2**lvn).
    """
    ret = A[::-1]
    n = 1 << lvn
    for lvm in range(lvn, 0, -1):
        m = 1 << lvm
        mh = m >> 1
        for r in range(0, n, m):
            for j in range(mh):
                # Fold the entry with this bit set into the entry without it.
                ret[r + j] += ret[r + mh + j]
    return ret


def cumu(A, lv):
    """Expand a disjoint-sum table into exact-intersection tables.

    Expects A to be ``arith_transform_plus(D, lv)`` for some table D, i.e.
    A[i] is the sum of D[m] over masks m with ``m & i == 0``.

    For every mask i, ``2**popcount(i)`` consecutive values are stored
    starting at ``offsets[i]``. Entry e holds the sum of D[m] over all m
    with ``m & i == pattern(e)``, where bit k of e selects the k-th lowest
    set bit of i. Entry 0 is therefore ``pattern == 0`` and the last entry
    is ``pattern == i``.

    Returns (ret, offsets), with len(ret) == 3**lv.
    """
    total = 1 << lv
    ret = [0.0] * 3 ** lv
    offsets = [0] * total
    pos = 0
    for i in range(total):
        start = pos
        offsets[i] = start
        ret[start] = A[i]  # pattern 0: no bit of i present
        pos += 1
        rest = i
        t = 1  # entries written so far for mask i (pos == start + t)
        while rest:
            r = rest & -rest  # lowest not-yet-processed set bit of i
            base = offsets[i ^ r]
            for j in range(t):
                # Mask i^r does not constrain bit r, mask i requires r absent;
                # their difference is the mass with bit r present.
                ret[pos + j] = ret[base + j] - ret[start + j]
            pos += t
            t <<= 1
            rest ^= r
    return ret, offsets


def ctz(n):
    """Return the number of trailing zero bits of a positive int (-1 for 0)."""
    return (n & -n).bit_length() - 1


def expected_raised_hands(R, C, P, S):
    """Return the expected number of raised hands in an R x C grid.

    P[y][x] is the percentage (0-100) that person (y, x) knows the answer.
    S[y][x] is that person's shyness. A person who knows starts with
    ``4 - S[y][x]`` points and gains one point for each raised neighbour:
    front (row y-1, same column), left and right. They raise their hand
    once they reach 4 points. People who do not know start at 0 and can
    collect at most 3 points, so they never raise.

    Runs a row-by-row DP over the bitmask of raised hands in the previous
    row: O(R * C * 3**C) time, O(3**C) memory.
    """
    full = 1 << C
    # pops[m] = popcount(m)
    pops = [0] * full
    for i in range(C):
        t = 1 << i
        for j in range(t):
            pops[j + t] = 1 + pops[j]

    prob = [[v / 100.0 for v in row] for row in P]
    start_points = [[4 - s for s in row] for row in S]

    # curr:  expected number of raised hands so far, split by raised-mask
    #        of the previous row (kept in transformed form);
    # pcurr: probability of each raised-mask of the previous row.
    curr = [0.0] * full
    pcurr = [0.0] * full
    pcurr[0] = 1.0  # before the first row nobody has raised a hand
    pcurr = arith_transform_plus(pcurr, C)
    cumu_curr, offsets = cumu(curr, C)
    cumu_pcurr, _ = cumu(pcurr, C)

    for y in range(R):
        nxt = [0.0] * full
        pnxt = [0.0] * full
        row_p = prob[y]
        row_pts = start_points[y]
        for s2 in range(full):
            # s2: set of people in row y who know the answer.
            p = 1.0
            for x in range(C):
                p *= row_p[x] if (s2 & (1 << x)) else 1.0 - row_p[x]
            points_back = [row_pts[x] if (s2 & (1 << x)) else 0
                           for x in range(C)]
            # Enumerate s1 = (raised mask of row y-1) & s2 over all submasks
            # of s2 in descending order. The matching cumu entry is the last
            # one of s2's group, decremented in lockstep.
            s1 = s2
            ofs = offsets[s2] + (1 << pops[s2]) - 1
            while True:
                if cumu_pcurr[ofs]:
                    points = points_back[:]
                    s = s1
                    while s:
                        t = s & -s
                        points[ctz(t)] += 1  # raised person in front
                        s ^= t
                    # One left-to-right then one right-to-left sweep suffices:
                    # anyone raised in the second sweep already has a raised
                    # right neighbour, so their rightward push is irrelevant.
                    for x in range(C - 1):
                        if points[x] >= 4:
                            points[x + 1] += 1
                    for x in range(C - 1, 0, -1):
                        if points[x] >= 4:
                            points[x - 1] += 1
                    nstate = 0
                    cnt = 0
                    for x in range(C):
                        if points[x] >= 4:
                            nstate |= 1 << x
                            cnt += 1
                    pnxt[nstate] += p * cumu_pcurr[ofs]
                    nxt[nstate] += (cumu_curr[ofs] + cnt * cumu_pcurr[ofs]) * p
                s1 = (s1 - 1) & s2
                ofs -= 1
                if s1 == s2:  # wrapped around past the empty submask
                    break
        curr = arith_transform_plus(nxt, C)
        pcurr = arith_transform_plus(pnxt, C)
        cumu_curr, _ = cumu(curr, C)
        cumu_pcurr, _ = cumu(pcurr, C)
    # Entry 0 is the disjoint-sum at mask 0, i.e. the total over all masks.
    return cumu_curr[0]


def prob309():
    """Read one test case from stdin and print the answer with 12 decimals.

    Input is parsed as whitespace-separated tokens: R, C, then R*C
    percentages, then R*C shyness levels. Blank separator lines between the
    sections are therefore accepted but not required.
    """
    data = sys.stdin.read().split()
    R, C = int(data[0]), int(data[1])
    vals = [int(tok) for tok in data[2:2 + 2 * R * C]]
    P = [vals[y * C:(y + 1) * C] for y in range(R)]
    base = R * C
    S = [vals[base + y * C:base + (y + 1) * C] for y in range(R)]
    print("%.12f" % (expected_raised_hands(R, C, P, S)))


if __name__ == "__main__":
    prob309()
Min_25