結果

問題 No.215 素数サイコロと合成数サイコロ (3-Hard)
ユーザー Min_25Min_25
提出日時 2018-08-04 13:57:45
言語 Python3
(3.12.2 + numpy 1.26.4 + scipy 1.12.0)
結果
RE  
(最新)
AC  
(最初)
実行時間 -
コード長 2,497 bytes
コンパイル時間 86 ms
コンパイル使用メモリ 12,332 KB
実行使用メモリ 42,568 KB
最終ジャッジ日時 2023-10-19 21:49:37
合計ジャッジ時間 2,140 ms
ジャッジサーバーID
(参考情報)
judge15 / judge13
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 RE -
testcase_01 RE -
権限があれば一括ダウンロードができます

ソースコード

diff #

import numpy as np
from numpy.fft import fft, ifft

def trans(f, size, shift):
  mask = (1 << shift) - 1
  ff = np.zeros(size, dtype=np.complex128)
  ff[:len(f)] = ((f & mask) + (f >> shift) * 1j) * 0.5
  ff = fft(ff)
  ffrc = np.concatenate((ff[0:1], ff[-1:0:-1])).conj()
  ffr, ffi = ff + ffrc, ff - ffrc
  return ffr, ffi

def itrans(f, s, mod, shift):
  a = ifft(f[0])[:s]
  lo = a.real.round().astype(np.int64)
  mid = a.imag.round().astype(np.int64)
  hi = ifft(f[1]).real.round().astype(np.int64)[:s]
  ret = (lo + ((mid % mod) << shift) + ((hi % mod) << (2 * shift))) % mod
  return ret

def fmul(r1, i1, r2, i2):
  return (r1 * (r2 + i2) + i1 * r2, -i1 * i2)

def poly_mul_mod(f, g, mod, shift=15):
  s = len(f) + len(g) - 1
  size = 1 << ((2 * s - 1).bit_length() - 1)
  ffr, ffi = trans(f, size, shift)
  fgr, fgi = trans(g, size, shift)
  return itrans(fmul(ffr, ffi, fgr, fgi), s, mod, shift)

def calc(n, f, g, mod, shift=15):
  sf, sg = len(f), len(g)
  size = 1 << ((2 * (2 * sg - 1) - 1).bit_length() - 1)
  sh = size // 2
  ffr, ffi = trans(f, size, shift)
  fgr, fgi = trans(g, size, shift)
  fmgr = np.concatenate((fgr[sh:], fgr[:sh]))
  fmgi = np.concatenate((fgi[sh:], fgi[:sh]))
  lo, hi = fmul(ffr, ffi, fmgr, fmgi)
  if not n & 1:
    lo, hi = (lo[:sh] + lo[sh:]) * 0.5, (hi[:sh] + hi[sh:]) * 0.5
  else:
    a = np.arange(sh) * (2 * np.pi / size)
    vs = (np.cos(a) + 1j * np.sin(a)) * 0.5
    lo, hi = (lo[:sh] - lo[sh:]) * vs, (hi[:sh] - hi[sh:]) * vs
  numer = itrans((lo, hi), (sg + sf - (n & 1)) // 2, mod, shift)
  denom = itrans(fmul(fgr[:sh], fgi[:sh], fgr[sh:], fgi[sh:]), sg, mod, shift)
  return numer, denom

def nth(n, numer, denom, mod):
  while n > 0:
    numer, denom = calc(n, numer, denom, mod)
    n >>= 1
  return numer[0]

def solve():
  import sys

  Ps = np.array([2, 3, 5, 7, 11, 13], dtype=np.int)
  Cs = np.array([4, 6, 8, 9, 10, 12], dtype=np.int)
  mod = 10 ** 9 + 7

  def gene(ds, T):
    dp = np.zeros((T + 1, ds[-1] * T + 1), dtype=np.int)
    dp[0, 0] = 1
    o = ds[0]
    for di in range(6):
      d = ds[di]
      for t in range(T):
        dp[t+1, d+o*t:d*(t+1)+1] = \
          (dp[t+1, d+o*t:d*(t+1)+1] + dp[t, o*t:d*t+1]) % mod
    return dp[T, :]

  for line in sys.stdin:
    N, P, C = map(int, line.split())
    denom = poly_mul_mod(gene(Ps, P), gene(Cs, C), mod)
    denom = (mod - denom) % mod
    denom[0] = 1
    numer = np.cumsum(denom, dtype=np.int64) % mod
    print(nth(N + len(denom) - 1, numer, denom, mod))

solve()
0