結果

問題 No.309 シャイな人たち (1)
ユーザー Min_25Min_25
提出日時 2015-12-02 07:20:28
言語 PyPy2
(7.3.15)
結果
AC  
実行時間 1,039 ms / 4,000 ms
コード長 2,701 bytes
コンパイル時間 1,447 ms
コンパイル使用メモリ 76,456 KB
実行使用メモリ 113,276 KB
最終ジャッジ日時 2024-09-14 07:50:52
合計ジャッジ時間 11,768 ms
ジャッジサーバーID
(参考情報)
judge5 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 750 ms
113,128 KB
testcase_01 AC 1,039 ms
112,760 KB
testcase_02 AC 292 ms
112,636 KB
testcase_03 AC 896 ms
112,876 KB
testcase_04 AC 155 ms
78,976 KB
testcase_05 AC 314 ms
88,532 KB
testcase_06 AC 285 ms
109,900 KB
testcase_07 AC 339 ms
113,156 KB
testcase_08 AC 921 ms
112,900 KB
testcase_09 AC 1,031 ms
113,004 KB
testcase_10 AC 1,033 ms
113,276 KB
testcase_11 AC 993 ms
112,756 KB
testcase_12 AC 886 ms
112,760 KB
testcase_13 AC 77 ms
75,520 KB
testcase_14 AC 78 ms
75,264 KB
testcase_15 AC 157 ms
85,100 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

def prob309():
  def arith_transform_plus(A, lvn):
    ret = A[::-1]
    n = 1 << lvn
    for lvm in range(lvn, 0, -1):
      m = 1 << lvm
      mh = m >> 1
      for r in range(0, n, m):
        for j in range(0, mh):
          ret[r + j] += ret[r + mh + j]
    return ret

  def cumu(A, lv):
    total = 1 << lv
    ret = [0.0] * 3 ** lv
    offsets = [0] * total
    pos = 0
    for i in range(total):
      offsets[i] = pos
      s = ret[pos] = A[i]
      pos += 1
      f = i
      t = 1
      while f:
        r = f & -f
        ofs = offsets[i ^ r]
        for j in range(t):
          ret[pos + j] = ret[ofs + j] - ret[pos - t + j]
        pos += t
        t <<= 1
        f ^= r
    return ret, offsets

  def ctz(n):
    return n.bit_length() - 1

  # input
  rl = sys.stdin.readline

  R, C = map(int, rl().split())
  rl()

  P = [list(map(lambda n: int(n) / 100., rl().split())) for _ in range(R)]
  rl()

  S = [list(map(lambda n: 4 - int(n), rl().split())) for _ in range(R)]

  # init
  total = 1 << C

  pops = [0] * total
  for i in range(C):
    t = 1 << i
    for j in range(t):
      pops[j + t] = 1 + pops[j]

  # dp
  curr = [0.] * total
  pcurr = [0.] * total

  pcurr[0] = 1.0
  pcurr = arith_transform_plus(pcurr, C)

  cumu_curr, offsets = cumu(curr, C)
  cumu_pcurr, _ = cumu(pcurr, C)

  # O(3^C * C * R)
  for y in range(R):
    next = [0.] * total
    pnext = [0.] * total

    # O(3^C * C)
    for s2 in range(total):
      p = 1.0
      for x in range(C):
        p *= P[y][x] if (s2 & (1 << x)) else 1.0 - P[y][x]

      points_back = [S[y][x] if (s2 & (1 << x)) else 0 for x in range(C)]

      s1 = s2
      
      ofs = offsets[s2] + (1 << pops[s2]) - 1

      while 1:
        if cumu_pcurr[ofs]:
          points = points_back[:]
          s = s1
          while s > 0:
            t = s & -s
            x = ctz(t)
            points[x] += 1
            s ^= t

          for x in range(0, C - 1):
            if points[x] >= 4:
              points[x + 1] += 1

          for x in range(C - 1, 0, -1):
            if points[x] >= 4:
              points[x - 1] += 1

          nstate = 0
          cnt = 0
          for x in range(C):
            if points[x] >= 4:
              nstate |= 1 << x
              cnt += 1

          pnext[nstate] += p * cumu_pcurr[ofs]
          next[nstate] += (cumu_curr[ofs] + cnt * cumu_pcurr[ofs]) * p

        s1 = (s1 - 1) & s2
        ofs -= 1
        if s1 == s2:
          break

    # (2^C * C)
    curr = arith_transform_plus(next, C)
    pcurr = arith_transform_plus(pnext, C)

    # (3^C)
    cumu_curr, _ = cumu(curr, C)
    cumu_pcurr, _ = cumu(pcurr, C)

  print("%.12f" % (cumu_curr[0]))

prob309()
0