結果

問題 No.260 世界のなんとか3
ユーザー simansiman
提出日時 2023-07-03 05:46:24
言語 C++17(clang)
(17.0.6 + boost 1.83.0)
結果
WA  
実行時間 -
コード長 3,641 bytes
コンパイル時間 7,109 ms
コンパイル使用メモリ 103,568 KB
実行使用メモリ 11,132 KB
最終ジャッジ日時 2023-09-24 02:26:53
合計ジャッジ時間 10,401 ms
ジャッジサーバーID
(参考情報)
judge12 / judge14
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 1 ms
4,380 KB
testcase_01 AC 2 ms
4,376 KB
testcase_02 AC 2 ms
4,380 KB
testcase_03 WA -
testcase_04 WA -
testcase_05 WA -
testcase_06 WA -
testcase_07 WA -
testcase_08 WA -
testcase_09 WA -
testcase_10 WA -
testcase_11 WA -
testcase_12 WA -
testcase_13 WA -
testcase_14 WA -
testcase_15 WA -
testcase_16 WA -
testcase_17 WA -
testcase_18 WA -
testcase_19 WA -
testcase_20 WA -
testcase_21 WA -
testcase_22 WA -
testcase_23 WA -
testcase_24 WA -
testcase_25 WA -
testcase_26 WA -
testcase_27 AC 1 ms
4,380 KB
testcase_28 AC 39 ms
11,052 KB
testcase_29 AC 55 ms
11,040 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <cassert>
#include <cmath>
#include <algorithm>
#include <iostream>
#include <iomanip>
#include <climits>
#include <map>
#include <queue>
#include <set>
#include <cstring>
#include <vector>

using namespace std;
typedef long long ll;

const ll MOD = 1000000007;

ll mod_pow(ll x, ll n, ll mod = MOD) {
  ll res = 1;

  while (n > 0) {
    if (n & 1) {
      res = res * x % mod;
    }

    x = x * x % mod;
    n >>= 1;
  }

  return res;
}

ll f(string S) {
  int len = S.size();
  ll dp1[len][2][3][8];
  ll dp2[len][2][3][8];
  memset(dp1, 0, sizeof(dp1));
  memset(dp2, 0, sizeof(dp2));

  for (int i = 0; i < len; ++i) {
    ll base3 = mod_pow(10, len - i - 1, 3);
    ll base8 = mod_pow(10, len - i - 1, 8);
    int d = S[i] - '0';

    if (i == 0) {
      for (int v = d; v >= 1; --v) {
        int m3 = (v * base3) % 3;
        int m8 = (v * base8) % 8;

        if (v == d) {
          if (v != 3) {
            dp1[i][0][m3][m8] += 1;
          } else {
            dp1[i][1][m3][m8] += 1;
          }
        } else {
          if (v != 3) {
            dp2[i][0][m3][m8] += 1;
          } else {
            dp2[i][1][m3][m8] += 1;
          }
        }
      }
    } else {
      for (int u = 1; u <= 9; ++u) {
        int n_m3 = (u * base3) % 3;
        int n_m8 = (u * base8) % 8;

        if (u != 3) {
          dp2[i][0][n_m3][n_m8] += 1;
        } else {
          dp2[i][1][n_m3][n_m8] += 1;
        }
      }

      for (int b_m3 = 0; b_m3 < 3; ++b_m3) {
        for (int b_m8 = 0; b_m8 < 8; ++b_m8) {
          for (int has_3 = 0; has_3 < 2; ++has_3) {
            {
              int n_m3 = (d * base3 + b_m3) % 3;
              int n_m8 = (d * base8 + b_m8) % 8;

              if (d != 3) {
                dp1[i][has_3][n_m3][n_m8] += dp1[i - 1][has_3][b_m3][b_m8];
              } else {
                dp1[i][1][n_m3][n_m8] += dp1[i - 1][has_3][b_m3][b_m8];
              }
            }

            for (int u = 0; u < d; ++u) {
              int n_m3 = (u * base3 + b_m3) % 3;
              int n_m8 = (u * base8 + b_m8) % 8;

              if (u != 3) {
                dp2[i][has_3][n_m3][n_m8] += dp1[i - 1][has_3][b_m3][b_m8];
              } else {
                dp2[i][1][n_m3][n_m8] += dp1[i - 1][has_3][b_m3][b_m8];
              }
            }

            for (int u = 0; u <= 9; ++u) {
              int n_m3 = (u * base3 + b_m3) % 3;
              int n_m8 = (u * base8 + b_m8) % 8;

              if (u != 3) {
                dp2[i][has_3][n_m3][n_m8] += dp2[i - 1][has_3][b_m3][b_m8];
              } else {
                dp2[i][1][n_m3][n_m8] += dp2[i - 1][has_3][b_m3][b_m8];
              }
            }
          }
        }
      }
    }
  }

  ll cnt = 0;

  for (int m3 = 0; m3 < 3; ++m3) {
    for (int m8 = 0; m8 < 8; ++m8) {
      for (int has_3 = 0; has_3 < 2; ++has_3) {
        if ((m3 == 0 || has_3) && m8 != 0) {
          cnt += dp1[len - 1][has_3][m3][m8];
          cnt += dp2[len - 1][has_3][m3][m8];
          cnt %= MOD;
        }
      }
    }
  }

  return cnt;
}

string str_dec(string str) {
  if (str == "1") {
    return "0";
  }

  reverse(str.begin(), str.end());
  int len = str.size();

  for (int i = 0; i < len; ++i) {
    if (str[i] != '0') {
      str[i]--;
      break;
    }
    str[i] = '9';
  }

  if (str.back() == '0') {
    str.resize(len - 1);
  }

  reverse(str.begin(), str.end());
  return str;
}

int main() {
  string A, B;
  cin >> A >> B;

  ll cnt1 = f(B);
  ll cnt2 = f(str_dec(A));

  // cerr << f("300") << endl;
  // fprintf(stderr, "(%lld, %lld)\n", cnt1, cnt2);
  cout << (cnt1 - cnt2 + MOD) % MOD << endl;

  return 0;
}
0