結果

問題 No.309 シャイな人たち (1)
ユーザー Min_25Min_25
提出日時 2015-12-02 03:13:14
言語 C++11 (gcc 11.4.0)
結果 AC
実行時間 999 ms / 4,000 ms
コード長 3,288 bytes
コンパイル時間 787 ms
コンパイル使用メモリ 63,500 KB
実行使用メモリ 38,820 KB
最終ジャッジ日時 2023-10-12 08:50:02
合計ジャッジ時間 12,269 ms
ジャッジサーバーID (参考情報): judge12 / judge15
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 827 ms 38,488 KB
testcase_01 AC 921 ms 38,556 KB
testcase_02 AC 693 ms 38,544 KB
testcase_03 AC 863 ms 38,684 KB
testcase_04 AC 17 ms 7,584 KB
testcase_05 AC 188 ms 16,116 KB
testcase_06 AC 692 ms 38,480 KB
testcase_07 AC 697 ms 38,612 KB
testcase_08 AC 868 ms 38,548 KB
testcase_09 AC 923 ms 38,820 KB
testcase_10 AC 999 ms 38,556 KB
testcase_11 AC 966 ms 38,604 KB
testcase_12 AC 877 ms 38,484 KB
testcase_13 AC 7 ms 7,576 KB
testcase_14 AC 7 ms 7,604 KB
testcase_15 AC 162 ms 16,052 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <cassert>
#include <cstdio>

#include <algorithm>
#include <iostream>
#include <queue>
#include <utility>
#include <vector>

// Short unsigned aliases used throughout the solver.
typedef unsigned int uint;
typedef unsigned long long uint64;

using namespace std;

// Every table below is statically sized for grids of at most 11 x 11 cells,
// i.e. at most 2^11 distinct column masks per row.
const uint kGridMax = 11;
const uint kMaskCount = 1u << kGridMax;  // 2048

// Per-cell input: P[i][j] is filled with percent / 100, S[i][j] with the raw
// integer read for that cell.
double P[kGridMax][kGridMax];
uint S[kGridMax][kGridMax];

// Two rolling buffers each, indexed by a column mask; main() swaps the
// pointers into them after every grid row.
double dp[2][kMaskCount];
double pdp[2][kMaskCount];

// Per-row scratch indexed by a column mask: its probability, and the
// per-column score contribution used to build s2s.
double ps[kMaskCount];
uint pts[kMaskCount][kGridMax];

// Transition table indexed by a 2-bit-per-column state: 4^11 == 2^22 rows of
// {resulting mask, number of set bits in it}.
uint trans[kMaskCount * kMaskCount][2];

// Column masks expanded so that column x occupies bits [3x, 3x + 3).
uint64 s1s[kMaskCount];
uint64 s2s[kMaskCount];

// conv maps `bits` packed 3-bit fields to `bits` packed 2-bit fields.
constexpr uint bits = 6;
uint64 conv[1u << (3 * bits)];

int main() {
  uint R, C; scanf("%u %u", &R, &C);

  for (uint i = 0; i < R; ++i) {
    for (uint j = 0; j < C; ++j) {
      uint r; scanf("%u", &r);
      P[i][j] = double(r) / 100;
    }
  }

  for (uint i = 0; i < R; ++i) {
    for (uint j = 0; j < C; ++j) {
      scanf("%u", &S[i][j]);
    }
  }

  uint total = 1 << C;

  for (uint s = 0; s < total * total; ++s) {
    uint points[11] = {0};
    uint plus[11] = {0};

    for (uint x = 0; x < C; ++x) {
      points[x] = ((s >> (2 * x)) & 3) + 1;
    }

    bool updated = true;
    while (updated) {
      updated = false;
      for (int x = 0; x < C; ++x) {
        uint pl = 0;
        if (x - 1 >= 0 && points[x - 1] >= 4) pl += 1;
        if (x + 1 < C  && points[x + 1] >= 4) pl += 1;
        if (pl > plus[x]) {
          points[x] += pl - plus[x];
          plus[x] = pl;
          updated = true;
        }
      }
    }

    uint nstate = 0;
    uint cnt = 0;
    for (uint x = 0; x < C; ++x) {
      if (points[x] >= 4) {
        nstate |= 1 << x;
        cnt += 1;
      }
    }
    trans[s][0] = nstate;
    trans[s][1] = cnt;
  }

  double* curr = dp[0], *next = dp[1];
  double* pcurr = pdp[0], *pnext = pdp[1];
  fill(curr, curr + total, .0);
  fill(pcurr, pcurr + total, .0);

  for (uint s1 = 0; s1 < total; ++s1) {
    uint64 t = 0;
    for (uint x = 0; x < C; ++x) {
      if (s1 & (1 << x)) {
        t |= 1ull << (3 * x);
      }
    }
    s1s[s1] = t;
  }

  for (uint s = 0; s < (1u << (3 * bits)); ++s) {
    uint t = s;
    uint c = 0;
    for (uint i = 0; i < bits; ++i) {
      int p = t & 7;
      c |= max(0, min(3, p - 1)) << (2 * i);
      t >>= 3;
    }
    conv[s] = c;
  }

  pcurr[0] = 1.0;
  for (uint y = 0; y < R; ++y) {
    fill(next, next + total, 0);
    fill(pnext, pnext + total, 0);

    for (uint s2 = 0; s2 < total; ++s2) {
      double p = 1.0;

      uint64 s = 0;
      for (uint x = 0; x < C; ++x) {
        if (s2 & (1 << x)) {
          p *= P[y][x];
          pts[s2][x] = 4 - S[y][x];
        } else {
          p *= 1 - P[y][x];
          pts[s2][x] = 0;
        }
        s |= uint64(pts[s2][x]) << (3 * x);
      }
      ps[s2] = p;
      s2s[s2] = s;
    }

    for (uint s1 = 0; s1 < total; ++s1) {
      if (pcurr[s1] == 0) {
        continue;
      }

      for (uint s2 = 0; s2 < total; ++s2) {
        double p = ps[s2];
        if (p == 0) {
          continue;
        }

        uint64 t = s1s[s1] + s2s[s2];
        uint s = conv[t & ((1u << (3 * bits)) - 1)];
        s |= conv[t >> (3 * bits)] << (2 * bits);

        uint nstate = trans[s][0];
        uint cnt = trans[s][1];

        next[nstate] += (curr[s1] + cnt * pcurr[s1]) * p;
        pnext[nstate] += pcurr[s1] * p;
      }
    }
    swap(curr, next);
    swap(pcurr, pnext);
  }
  double ans = 0;
  for (uint i = 0; i < total; ++i) {
    ans += curr[i];
  }
  printf("%.12f\n", ans);
  return 0;
}
0