結果

問題 No.309 シャイな人たち (1)
ユーザー Min_25Min_25
提出日時 2015-12-02 03:45:36
言語 C++11 (gcc 11.4.0)
結果
AC  
実行時間 562 ms / 4,000 ms
コード長 3,128 bytes
コンパイル時間 614 ms
コンパイル使用メモリ 62,936 KB
実行使用メモリ 38,732 KB
最終ジャッジ日時 2024-09-14 07:49:17
合計ジャッジ時間 6,422 ms
ジャッジサーバーID (参考情報): judge6 / judge5
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 410 ms 38,732 KB
testcase_01 AC 502 ms 38,608 KB
testcase_02 AC 274 ms 38,728 KB
testcase_03 AC 442 ms 38,604 KB
testcase_04 AC 12 ms 6,940 KB
testcase_05 AC 97 ms 14,284 KB
testcase_06 AC 279 ms 38,604 KB
testcase_07 AC 279 ms 38,604 KB
testcase_08 AC 444 ms 38,604 KB
testcase_09 AC 501 ms 38,728 KB
testcase_10 AC 562 ms 38,604 KB
testcase_11 AC 527 ms 38,604 KB
testcase_12 AC 446 ms 38,732 KB
testcase_13 AC 6 ms 6,944 KB
testcase_14 AC 6 ms 6,944 KB
testcase_15 AC 69 ms 14,536 KB
権限があれば一括ダウンロードができます
コンパイルメッセージ
main.cpp: In function ‘int main()’:
main.cpp:30:19: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
   30 |   uint R, C; scanf("%u %u", &R, &C);
      |              ~~~~~^~~~~~~~~~~~~~~~~
main.cpp:34:20: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
   34 |       uint r; scanf("%u", &r);
      |               ~~~~~^~~~~~~~~~
main.cpp:41:12: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
   41 |       scanf("%u", &S[i][j]);
      |       ~~~~~^~~~~~~~~~~~~~~~

ソースコード

diff #

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <iostream>
#include <queue>
#include <utility>
#include <vector>

using uint = unsigned;
using uint64 = unsigned long long;

// File-scope so main can call fill / swap / min / max unqualified.
using namespace std;

// Per-cell input, indexed [row][column]. 11 is the largest grid the tables
// below can hold (2048 = 2^11 row masks, 1 << 22 = 4^11 packed row states).
// P[y][x]: input percentage / 100; the probability that bit x is set in a row
// pattern s2 (presumably "this person knows the answer" -- per the problem
// statement, not visible here).
double P[11][11];
// S[y][x]: a cell whose bit is set in s2 starts with 4 - S[y][x] points; a
// column is counted once its points reach 4.
uint S[11][11];

// Rolling DP rows (two buffers, swapped each input row), indexed by the mask
// of columns that reached 4 points in the last processed row.
// dp[*][m]:  sum over outcomes ending in mask m of probability * count of
//            columns that reached 4 points so far.
// pdp[*][m]: probability that the last processed row ended in mask m.
double dp[2][2048];
double pdp[2][2048];

// Per-row scratch, indexed by the cell pattern s2 of the current row:
// ps[s2] is the probability of exactly that pattern; pts[s2][x] is the
// starting points of column x (4 - S or 0).
double ps[2048];
uint pts[2048][11];

// trans[s] for a packed row state s (2 bits per column, field = points - 1,
// so points 1..4): [0] = mask of columns at >= 4 points after left/right
// propagation, [1] = number of such columns. 1 << 22 = 4^11 states.
uint trans[1 << 22][2];

// 3-bit-per-column encodings that are added together column-wise in main:
// s1s[m]:  field value 1 in every column set in the previous row's mask m.
// s2s[s2]: field value pts[s2][x] in column x.
uint64 s1s[2048];
uint64 s2s[2048];

// conv maps `bits` (6) packed 3-bit fields p to packed 2-bit fields
// clamp(p - 1, 0, 3), i.e. the trans index encoding; main applies it to the
// low and high halves of an 11-column value separately.
const uint bits = 6;
uint64 conv[1 << (3 * bits)];

int main() {
  uint R, C; scanf("%u %u", &R, &C);

  for (uint i = 0; i < R; ++i) {
    for (uint j = 0; j < C; ++j) {
      uint r; scanf("%u", &r);
      P[i][j] = double(r) / 100;
    }
  }

  for (uint i = 0; i < R; ++i) {
    for (uint j = 0; j < C; ++j) {
      scanf("%u", &S[i][j]);
    }
  }

  uint total = 1 << C;

  for (uint s = 0; s < total * total; ++s) {
    uint points[11] = {0};
    uint plus[11] = {0};

    for (uint x = 0; x < C; ++x) {
      points[x] = ((s >> (2 * x)) & 3) + 1;
    }

    for (uint x = 0; x < C - 1; ++x) {
      if (points[x] >= 4) {
        points[x + 1] += 1;
      }
    }
    for (int x = C - 1; x > 0; --x) {
      if (points[x] >= 4) {
        points[x - 1] += 1;
      }
    }
    uint nstate = 0;
    uint cnt = 0;
    for (uint x = 0; x < C; ++x) {
      if (points[x] >= 4) {
        nstate |= 1 << x;
        cnt += 1;
      }
    }
    trans[s][0] = nstate;
    trans[s][1] = cnt;
  }

  double* curr = dp[0], *next = dp[1];
  double* pcurr = pdp[0], *pnext = pdp[1];
  fill(curr, curr + total, .0);
  fill(pcurr, pcurr + total, .0);

  for (uint s1 = 0; s1 < total; ++s1) {
    uint64 t = 0;
    for (uint x = 0; x < C; ++x) {
      if (s1 & (1 << x)) {
        t |= 1ull << (3 * x);
      }
    }
    s1s[s1] = t;
  }

  for (uint s = 0; s < (1u << (3 * bits)); ++s) {
    uint t = s;
    uint c = 0;
    for (uint i = 0; i < bits; ++i) {
      int p = t & 7;
      c |= max(0, min(3, p - 1)) << (2 * i);
      t >>= 3;
    }
    conv[s] = c;
  }

  pcurr[0] = 1.0;
  for (uint y = 0; y < R; ++y) {
    fill(next, next + total, 0);
    fill(pnext, pnext + total, 0);

    for (uint s2 = 0; s2 < total; ++s2) {
      double p = 1.0;

      uint64 s = 0;
      for (uint x = 0; x < C; ++x) {
        if (s2 & (1 << x)) {
          p *= P[y][x];
          pts[s2][x] = 4 - S[y][x];
        } else {
          p *= 1 - P[y][x];
          pts[s2][x] = 0;
        }
        s |= uint64(pts[s2][x]) << (3 * x);
      }
      ps[s2] = p;
      s2s[s2] = s;
    }

    for (uint s1 = 0; s1 < total; ++s1) {
      if (pcurr[s1] == 0) {
        continue;
      }

      for (uint s2 = 0; s2 < total; ++s2) {
        double p = ps[s2];
        if (p == 0) {
          continue;
        }

        uint64 t = s1s[s1] + s2s[s2];
        uint s = conv[t & ((1u << (3 * bits)) - 1)];
        s |= conv[t >> (3 * bits)] << (2 * bits);

        uint nstate = trans[s][0];
        uint cnt = trans[s][1];

        next[nstate] += (curr[s1] + cnt * pcurr[s1]) * p;
        pnext[nstate] += pcurr[s1] * p;
      }
    }
    swap(curr, next);
    swap(pcurr, pnext);
  }
  double ans = 0;
  for (uint i = 0; i < total; ++i) {
    ans += curr[i];
  }
  printf("%.12f\n", ans);
  return 0;
}
0