結果

問題 No.214 素数サイコロと合成数サイコロ (3-Medium)
ユーザー Min_25
提出日時 2015-05-23 05:50:02
言語 Python3 (3.12.2 + numpy 1.26.4 + scipy 1.12.0)
結果
AC  
実行時間 1,203 ms / 3,000 ms
コード長 14,585 bytes
コンパイル時間 128 ms
コンパイル使用メモリ 14,464 KB
実行使用メモリ 13,184 KB
最終ジャッジ日時 2024-07-06 06:02:39
合計ジャッジ時間 5,192 ms
ジャッジサーバーID (参考情報): judge4 / judge5
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
| 入力 | 結果 | 実行時間 | 実行使用メモリ |
|---|---|---|---|
| testcase_00 | AC | 1,189 ms | 13,184 KB |
| testcase_01 | AC | 1,107 ms | 12,800 KB |
| testcase_02 | AC | 1,203 ms | 13,056 KB |
権限があれば一括ダウンロードができます

ソースコード

diff #

def _poly_mul(poly1, poly2):
  ret = [0] * (len(poly1) + len(poly2) - 1)
  for i in range(len(poly2)):
    if poly2[i] == 0:
      continue
    coef = poly2[i]
    for j in range(len(poly1)):
      ret[i + j] += coef * poly1[j]
  return ret

def poly_mul_karatsuba(poly1, poly2, threshold=16):
  """Karatsuba product of two equal-length coefficient lists (no modulus).

  Inputs shorter than `threshold` go to the schoolbook _poly_mul.  Otherwise
  each input is split as low + x**size_h * high, where low holds the first
  size_h = floor(n/2) coefficients and high the remaining size_l = ceil(n/2).
  Three recursive products are formed:
    p1 = low1 * low2
    p2 = high1 * high2
    p3 = (high1 + x**ofs * low1) * (high2 + x**ofs * low2), ofs = size_l - size_h
  and the cross term low1*high2 + high1*low2 is p3 - p2 - x**(2*ofs) * p1,
  shifted down by ofs.

  Assumes len(poly1) == len(poly2) (poly_mul only dispatches here in that
  case) and threshold >= 2: with threshold <= 1 a length-1 input recurses
  on itself forever.  Returns a list of length 2 * len(poly1) - 1.
  """
  size = len(poly1)
  if size >= threshold:
    size_l = (size + 1) // 2
    size_h = size - size_l
    p1 = poly_mul_karatsuba(poly1[:size_h], poly2[:size_h], threshold)
    p2 = poly_mul_karatsuba(poly1[size_h:], poly2[size_h:], threshold)
    # q = high + x**ofs * low; ofs is 0 or 1, so low is right-aligned in high
    q1 = poly1[size_h:]
    q2 = poly2[size_h:]
    ofs = size_l - size_h
    for i in range(size_h):
      q1[ofs + i] += poly1[i]
      q2[ofs + i] += poly2[i]
    p3 = poly_mul_karatsuba(q1, q2, threshold)
    # ret = p1 (2*size_h - 1 entries), one zero, then p2 starting at 2*size_h
    ret = p1
    ret.extend([0])
    ret.extend(p2)
    # p3 -= p2
    for i in range(size_l * 2 - 1):
      p3[i] -= ret[2 * size_h + i]
    # p3 -= x**(2*ofs) * p1
    for i in range(size_h * 2 - 1):
      p3[ofs * 2 + i] -= ret[i]
    # p3 now equals x**ofs * cross; the cross term belongs at x**size_h, so
    # p3[i] lands at size_h - ofs + i == 2*size - 3*size_l + i
    ofs = 2 * size - 3 * size_l
    for i in range(size_l * 2 - 1):
      ret[ofs + i] += p3[i]
    return ret
  else:
    return _poly_mul(poly1, poly2)

def _pack(pack, shamt):
  size = len(pack)
  while size > 1:
    npack = []
    for i in range(0, size - 1, 2):
      npack.append(pack[i] | (pack[i+1] << shamt))
    if size & 1:
      npack.append(pack[-1])
    pack = npack
    size = (size + 1) >> 1
    shamt <<= 1
  return pack[0]

def _pack1(seq, shamt):
  """Pack *seq* ahead of squaring it.

  Returns (packed, top_shift): packed holds seq[i] at bit i * shamt, and
  top_shift is the field width unpack_sequence must start splitting with
  to recover the 2 * len(seq) - 1 coefficients of the square.
  """
  packed = _pack(seq, shamt)
  result_len = 2 * len(seq) - 1
  return packed, shamt * 2 ** ilog2(result_len - 1)

def _pack2(seq1, seq2, shamt):
  """Pack two sequences ahead of multiplying them.

  Returns (packed1, packed2, top_shift) where top_shift is the starting
  field width for unpack_sequence over len(seq1) + len(seq2) - 1 results.
  """
  packed1, packed2 = (_pack(s, shamt) for s in (seq1, seq2))
  result_len = len(seq1) + len(seq2) - 1
  return packed1, packed2, shamt * 2 ** ilog2(result_len - 1)

def pack_sequence(seq):
  """Pack integer coefficients with fields wide enough for their square.

  Each coefficient of the square is a sum of at most len(seq) products of
  two max_bits-bit values, which fits in 2 * max_bits + len(seq).bit_length()
  bits.  NOTE(review): _pack merges fields with `|`, so negative
  coefficients would corrupt neighbouring fields; confirm callers only
  pass non-negative ints.
  """
  widest = max(c.bit_length() for c in seq)
  field_bits = 2 * widest + len(seq).bit_length()
  return _pack1(seq, field_bits)

def pack_sequence_mod(seq, mod):
  """Pack residues in [0, mod) with fields wide enough for their square.

  The field must hold len(seq) * (mod - 1)**2, the largest possible
  coefficient of the unreduced square.
  """
  bound = len(seq) * (mod - 1) ** 2
  return _pack1(seq, bound.bit_length())

def pack_sequence2(seq1, seq2):
  """Pack two integer sequences with fields wide enough for their product.

  A product coefficient sums at most min(len(seq1), len(seq2)) terms, each
  below 2**(bits1 + bits2).
  """
  bits1 = max(c.bit_length() for c in seq1)
  bits2 = max(c.bit_length() for c in seq2)
  terms = min(len(seq1), len(seq2))
  return _pack2(seq1, seq2, bits1 + bits2 + terms.bit_length())

def pack_sequence2_mod(seq1, seq2, mod):
  """Pack two residue sequences with fields wide enough for their product."""
  terms = min(len(seq1), len(seq2))
  bound = terms * (mod - 1) ** 2
  return _pack2(seq1, seq2, bound.bit_length())

def unpack_sequence(M, size, shamt):
  """Split the packed integer M back into *size* fields.

  Mirrors _pack: starting from one chunk and field width `shamt`, every
  round splits each chunk into its low and high halves, trims the list to
  that level's length, and halves the width.
  """
  # lengths of every level of the packing tree, full sequence first
  level_sizes = []
  remaining = size
  while remaining > 1:
    level_sizes.append(remaining)
    remaining = (remaining + 1) >> 1

  parts = [M]
  width = shamt
  for target in reversed(level_sizes):
    low_mask = (1 << width) - 1
    split = []
    for chunk in parts:
      split.extend((chunk & low_mask, chunk >> width))
    parts = split[:target]
    width >>= 1
  return parts

def poly_mul_builtin(poly1, poly2):
  """Multiply integer polynomials with one big-integer product (Kronecker substitution)."""
  packed1, packed2, top_shift = pack_sequence2(poly1, poly2)
  result_len = len(poly1) + len(poly2) - 1
  return unpack_sequence(packed1 * packed2, result_len, top_shift)

def poly_mul(poly1, poly2, threshold=16, use_builtin=False):
  """Multiply two coefficient lists (no modulus), choosing an algorithm.

  - use_builtin, len(poly1) >= threshold and int coefficients:
    big-integer packing via poly_mul_builtin;
  - equal lengths: poly_mul_karatsuba (schoolbook below `threshold`);
  - otherwise: schoolbook _poly_mul.

  The packing path merges fields with bitwise `|`, so it is only taken when
  the coefficients are ints.
  """
  # isinstance replaces the Python 2 `t == int or t == long` test; `long`
  # does not exist on Python 3 and raised NameError for non-int coefficients.
  if use_builtin and len(poly1) >= threshold and isinstance(poly1[0], int):
    return poly_mul_builtin(poly1, poly2)
  if len(poly1) == len(poly2):
    return poly_mul_karatsuba(poly1, poly2, threshold)
  return _poly_mul(poly1, poly2)

def poly_square_builtin(poly):
  """Square an integer polynomial with a single big-integer squaring."""
  packed, top_shift = pack_sequence(poly)
  return unpack_sequence(packed ** 2, 2 * len(poly) - 1, top_shift)

def _poly_square(poly):
  size = len(poly)
  ret = [0] * (size * 2 - 1)
  for i in range(size):
    ret[2 * i] = poly[i] * poly[i]
  for i in range(size):
    coef = 2 * poly[i]
    for j in range(i + 1, size):
      ret[i + j] += coef * poly[j]
  return ret

def poly_square_karatsuba(poly, threshold=16):
  """Karatsuba square of a coefficient list (no modulus).

  Same split as poly_mul_karatsuba: poly = low + x**size_h * high with
  size_h = floor(n/2), size_l = ceil(n/2).  Recursively squares low (p1),
  high (p2) and S = high + x**ofs * low (p3, ofs = size_l - size_h), then
  recovers the cross term 2*low*high from p3 - p2 - x**(2*ofs) * p1.
  Inputs shorter than `threshold` use _poly_square.  Requires
  threshold >= 2 (a length-1 input otherwise recurses forever).
  Returns a list of length 2 * len(poly) - 1.
  """
  size = len(poly)
  if size >= threshold:
    size_l = (size + 1) // 2
    size_h = size - size_l
    p1 = poly_square_karatsuba(poly[:size_h], threshold)
    p2 = poly_square_karatsuba(poly[size_h:], threshold)
    # S = high + x**ofs * low (ofs is 0 or 1)
    S = poly[size_h:]
    ofs = size_l - size_h
    for i in range(size_h):
      S[ofs + i] += poly[i]
    p3 = poly_square_karatsuba(S, threshold)
    # ret = p1, one zero, then p2 starting at index 2*size_h
    ret = p1
    ret.extend([0])
    ret.extend(p2)
    # p3 -= p2
    for i in range(size_l * 2 - 1):
      p3[i] -= ret[2 * size_h + i]
    # p3 -= x**(2*ofs) * p1
    for i in range(size_h * 2 - 1):
      p3[ofs * 2 + i] -= ret[i]
    # add x**ofs * cross at x**size_h: offset size_h - ofs == 2*size - 3*size_l
    ofs = 2 * size - 3 * size_l
    for i in range(size_l * 2 - 1):
      ret[ofs + i] += p3[i]
    return ret
  else:
    return _poly_square(poly)

def poly_square(poly, threshold=16, use_builtin=False):
  """Square a coefficient list (no modulus), choosing an algorithm.

  - use_builtin, len(poly) >= threshold and int coefficients:
    big-integer packing via poly_square_builtin;
  - len(poly) >= threshold: poly_square_karatsuba using the same threshold;
  - otherwise: schoolbook _poly_square.
  """
  # isinstance replaces the Python 2 `long` check (NameError on Python 3).
  if use_builtin and len(poly) >= threshold and isinstance(poly[0], int):
    return poly_square_builtin(poly)
  if len(poly) >= threshold:
    # Forward the caller's threshold, as poly_mul does for Karatsuba.  Clamp
    # to 2 because the Karatsuba recursion never terminates for length-1
    # inputs when its threshold is <= 1.
    return poly_square_karatsuba(poly, max(threshold, 2))
  return _poly_square(poly)

def poly_pow(poly, e, threshold=16):
  """Return poly**e for a non-negative integer e (no modulus).

  Left-to-right binary exponentiation: bits of e are scanned from the most
  significant; the accumulator is squared before every bit except the first
  and multiplied by poly on set bits.
  """
  if e == 0:
    return [1]
  top = e.bit_length() - 1
  acc = [1]
  for shift in range(top, -1, -1):
    if shift != top:
      acc = poly_square(acc, threshold, False)
    if (e >> shift) & 1:
      acc = poly_mul(acc, poly, threshold, False)
  return acc

def poly_inverse(poly, size):
  """Return the first *size* coefficients of the power series 1 / poly.

  Coefficients are read lowest-index-first and poly[0] must be 1.  Newton
  iteration: the target degrees size-1, (size-1)>>1, ... are processed from
  smallest to largest; each round extends the known inverse to that degree,
  roughly doubling the precision.  Returns [] when size <= 0.
  """
  assert(poly[0] == 1)
  if size <= 0:
    # A negative starting degree would never reach 0 under `>>= 1`
    # (-1 >> 1 == -1), so the target loop below would not terminate.
    return []

  degs = []
  deg = size - 1
  while deg:
    degs.append(deg)
    deg >>= 1

  poly2 = poly[:]
  if len(poly2) < size:
    poly2.extend([0] * (size - len(poly2)))

  inv = [1]
  for t in degs[::-1]:
    added = t + 1 - len(inv)
    # coefficients of poly * inv beyond the prefix that already equals 1, 0, ...
    tmp = poly_mul(poly2[:t + 1], inv)[len(inv):]
    # Newton correction: next `added` coefficients are -(error * inv)
    tmp = poly_mul(tmp[:added], inv[:added])
    inv.extend([-v for v in tmp[:added]])
  return inv

def poly_mul_mod_ntt(poly1, poly2, mod):
  """Product of poly1 and poly2 modulo *mod* via three NTTs plus CRT.

  The convolution is computed modulo three primes of the form k * 2**23 + 1,
  and each coefficient is rebuilt with Garner's method before being reduced
  modulo *mod*.  The reconstruction is exact only while every true
  coefficient is below p1 * p2 * p3 (roughly 7.9e26).

  NOTE(review): `_ntt_convolve` and `inv_mod` are not defined in this file
  as shown.  poly_mul_mod only takes this path when both inputs have at
  least ntt_threshold coefficients, so confirm these helpers exist before
  relying on it.
  """
  primes = (880803841, 897581057, 998244353)
  roots = (273508579, 872686320, 15311432)

  n1, n2 = len(poly1), len(poly2)
  out_len = n1 + n2 - 1
  fft_len = 2 << ilog2(2 * max(n1, n2) - 1)

  padded1 = poly1[:] + [0] * (fft_len - n1)
  padded2 = poly2[:] + [0] * (fft_len - n2)

  # one convolution per prime; fresh copies in case the helper works in place
  r1, r2, r3 = [_ntt_convolve(padded1[:], padded2[:], out_len, p, z)
                for p, z in zip(primes, roots)]
  p1, p2, p3 = primes

  # Garner step 1: merge residues mod p1 and p2 into a value mod p1 * p2
  inv_p1 = inv_mod(p1, p2)
  lifted = [r1[i] + (r2[i] - r1[i]) * inv_p1 % p2 * p1 for i in range(out_len)]

  # Garner step 2: fold in p3, then reduce into the caller's modulus
  p12 = p1 * p2
  inv_p12 = inv_mod(p12, p3)
  return [(x + (r3[i] - x) % p3 * inv_p12 % p3 * (p12 % mod)) % mod
          for i, x in enumerate(lifted)]

def poly_square_mod_ntt(poly1, mod):
  """Square of poly1 modulo *mod* via three self-convolution NTTs plus CRT.

  Same three primes and Garner reconstruction as poly_mul_mod_ntt.

  NOTE(review): `_ntt_convolve_self` and `inv_mod` are not defined in this
  file as shown.  poly_square_mod only takes this path for inputs with at
  least ntt_threshold coefficients; confirm the helpers exist.
  """
  primes = (880803841, 897581057, 998244353)
  roots = (273508579, 872686320, 15311432)

  n = len(poly1)
  out_len = 2 * n - 1
  fft_len = 2 << ilog2(n * 2 - 1)

  padded = poly1[:] + [0] * (fft_len - n)

  r1, r2, r3 = [_ntt_convolve_self(padded[:], out_len, p, z)
                for p, z in zip(primes, roots)]
  p1, p2, p3 = primes

  # Garner step 1: value mod p1 * p2
  inv_p1 = inv_mod(p1, p2)
  lifted = [r1[i] + (r2[i] - r1[i]) * inv_p1 % p2 * p1 for i in range(out_len)]

  # Garner step 2: fold in p3 and reduce modulo *mod*
  p12 = p1 * p2
  inv_p12 = inv_mod(p12, p3)
  return [(x + (r3[i] - x) % p3 * inv_p12 % p3 * (p12 % mod)) % mod
          for i, x in enumerate(lifted)]

def poly_mul_mod_builtin(poly1, poly2, mod):
  """Multiply residue polynomials with one big-integer product, then reduce each coefficient."""
  packed1, packed2, top_shift = pack_sequence2_mod(poly1, poly2, mod)
  result_len = len(poly1) + len(poly2) - 1
  coeffs = unpack_sequence(packed1 * packed2, result_len, top_shift)
  return [int(c % mod) for c in coeffs]

def poly_square_mod_builtin(poly, mod):
  """Square a residue polynomial with one big-integer squaring, then reduce each coefficient."""
  packed, top_shift = pack_sequence_mod(poly, mod)
  coeffs = unpack_sequence(packed ** 2, 2 * len(poly) - 1, top_shift)
  return [int(c % mod) for c in coeffs]

def poly_add_mod(poly1, ofs1, poly2, ofs2, size, mod):
  """In place: poly1[ofs1 + k] = (poly1[ofs1 + k] + poly2[ofs2 + k]) % mod for 0 <= k < size."""
  for k in range(size):
    dst = ofs1 + k
    poly1[dst] = (poly1[dst] + poly2[ofs2 + k]) % mod

def poly_sub_mod(poly1, ofs1, poly2, ofs2, size, mod):
  """In place: poly1[ofs1 + k] = (poly1[ofs1 + k] - poly2[ofs2 + k]) % mod for 0 <= k < size."""
  for k in range(size):
    dst = ofs1 + k
    poly1[dst] = (poly1[dst] - poly2[ofs2 + k]) % mod

def poly_mul_mod_karatsuba(poly1, poly2, mod, threshold=128):
  """Karatsuba product of two equal-length residue lists, modulo *mod*.

  Same split and recombination as poly_mul_karatsuba, with every addition
  and subtraction done through poly_add_mod / poly_sub_mod.  Inputs shorter
  than `threshold` use _poly_mul_mod.  Assumes len(poly1) == len(poly2).
  NOTE(review): poly_mul_mod in this file does not dispatch here.
  Returns a list of length 2 * len(poly1) - 1.
  """
  size = len(poly1)
  if size >= threshold:
    size_l = (size + 1) // 2
    size_h = size - size_l
    p1 = poly_mul_mod_karatsuba(poly1[:size_h], poly2[:size_h], mod, threshold)
    p2 = poly_mul_mod_karatsuba(poly1[size_h:], poly2[size_h:], mod, threshold)
    # q = high + x**ofs * low (ofs is 0 or 1)
    q1 = poly1[size_h:]
    q2 = poly2[size_h:]
    ofs = size_l - size_h
    poly_add_mod(q1, ofs, poly1, 0, size_h, mod)
    poly_add_mod(q2, ofs, poly2, 0, size_h, mod)
    p3 = poly_mul_mod_karatsuba(q1, q2, mod, threshold)
    # ret = p1, one zero, then p2 starting at index 2*size_h
    ret = p1
    ret.extend([0])
    ret.extend(p2)
    # p3 -= p2, then p3 -= x**(2*ofs) * p1
    poly_sub_mod(p3, 0, ret, 2 * size_h, size_l * 2 - 1, mod)
    poly_sub_mod(p3, ofs * 2, ret, 0, size_h * 2 - 1, mod)
    # add x**ofs * cross at x**size_h (offset size_h - ofs == 2*size - 3*size_l)
    ofs = 2 * size - 3 * size_l
    poly_add_mod(ret, ofs, p3, 0, size_l * 2 - 1, mod)
    return ret
  else:
    return _poly_mul_mod(poly1, poly2, mod)

def _poly_mul_mod(poly1, poly2, mod):
  ret = [0] * (len(poly1) + len(poly2) - 1)
  for i in range(len(poly2)):
    if poly2[i] == 0:
      continue
    coef = poly2[i]
    for j in range(len(poly1)):
      ret[i + j] = (ret[i + j] + coef * poly1[j]) % mod
  return ret

def poly_mul_mod(poly1, poly2, mod, threshold=128, ntt_threshold=65536):
  """Multiply two residue polynomials modulo *mod*, choosing an algorithm.

  - both lengths >= ntt_threshold and mod <= 2 * 10**9: poly_mul_mod_ntt;
  - both lengths <= threshold: schoolbook _poly_mul_mod;
  - otherwise: big-integer packing via poly_mul_mod_builtin.
  """
  n1, n2 = len(poly1), len(poly2)
  both_huge = n1 >= ntt_threshold and n2 >= ntt_threshold
  if both_huge and mod <= 2 * 10 ** 9:
    return poly_mul_mod_ntt(poly1, poly2, mod)
  if max(n1, n2) <= threshold:
    return _poly_mul_mod(poly1, poly2, mod)
  return poly_mul_mod_builtin(poly1, poly2, mod)

def _poly_square_mod(poly, mod):
  size = len(poly)
  ret = [0] * (size * 2 - 1)
  for i in range(size):
    ret[2 * i] = poly[i] * poly[i] % mod
  for i in range(size):
    coef = 2 * poly[i]
    for j in range(i + 1, size):
      ret[i + j] = (ret[i + j] + coef * poly[j]) % mod
  return ret

def poly_square_mod_karatsuba(poly, mod, threshold=64):
  """Karatsuba square of a residue list, modulo *mod*.

  Same split and recombination as poly_square_karatsuba, with additions and
  subtractions done through poly_add_mod / poly_sub_mod.  Inputs shorter
  than `threshold` use _poly_square_mod.  Returns 2 * len(poly) - 1 entries.
  """
  size = len(poly)
  if size >= threshold:
    size_l = (size + 1) // 2
    size_h = size - size_l
    p1 = poly_square_mod_karatsuba(poly[:size_h], mod, threshold)
    p2 = poly_square_mod_karatsuba(poly[size_h:], mod, threshold)
    # S = high + x**ofs * low (ofs is 0 or 1)
    S = poly[size_h:]
    ofs = size_l - size_h
    poly_add_mod(S, ofs, poly, 0, size_h, mod)
    p3 = poly_square_mod_karatsuba(S, mod, threshold)
    # ret = p1, one zero, then p2 starting at index 2*size_h
    ret = p1
    ret.extend([0])
    ret.extend(p2)
    # p3 -= p2, then p3 -= x**(2*ofs) * p1
    poly_sub_mod(p3, 0, ret, 2 * size_h, size_l * 2 - 1, mod)
    poly_sub_mod(p3, ofs * 2, ret, 0, size_h * 2 - 1, mod)
    # add x**ofs * cross at x**size_h (offset size_h - ofs == 2*size - 3*size_l)
    ofs = 2 * size - 3 * size_l
    poly_add_mod(ret, ofs, p3, 0, size_l * 2 - 1, mod)
    return ret
  else:
    return _poly_square_mod(poly, mod)

def poly_square_mod(poly, mod, threshold=128, k_threshold=64, ntt_threshold=65536):
  """Square a residue polynomial modulo *mod*, choosing an algorithm by length.

  - length >= ntt_threshold and mod <= 2 * 10**9: poly_square_mod_ntt;
  - length >= threshold: big-integer packing (poly_square_mod_builtin);
  - length >= k_threshold: poly_square_mod_karatsuba;
  - otherwise: schoolbook _poly_square_mod.
  """
  n = len(poly)
  if n >= ntt_threshold and mod <= 2 * 10 ** 9:
    return poly_square_mod_ntt(poly, mod)
  if n >= threshold:
    return poly_square_mod_builtin(poly, mod)
  if n >= k_threshold:
    return poly_square_mod_karatsuba(poly, mod)
  return _poly_square_mod(poly, mod)

def poly_pow_mod(poly, e, mod):
  """Return poly**e with coefficients reduced modulo *mod* (no polynomial modulus).

  Left-to-right binary exponentiation over the bits of e.
  """
  if e == 0:
    return [1]
  top = e.bit_length() - 1
  acc = [1]
  for shift in range(top, -1, -1):
    if shift != top:
      acc = poly_square_mod(acc, mod)
    if (e >> shift) & 1:
      acc = poly_mul_mod(acc, poly, mod)
  return acc

def _poly_rem_mod(poly1, poly2, mod):
  if len(poly1) < len(poly2):
    return poly1[:]

  ret = poly1[:]
  dif = len(poly1) - len(poly2) + 1

  assert(poly2[0] == 1)
  for i in range(dif):
    if ret[i] == 0:
      continue
    coef = ret[i] % mod
    for j in range(1, len(poly2)):
      ret[i + j] = (ret[i + j] - coef * poly2[j]) % mod
    ret[i] = coef

  return ret[dif:]

def poly_inverse_mod(poly, size, mod):
  """Return the first *size* coefficients of the power series 1 / poly, modulo *mod*.

  Coefficients are read lowest-index-first and poly[0] must be 1.  Newton
  iteration over target degrees size-1, (size-1)>>1, ..., processed from
  smallest to largest, roughly doubling the known precision each round.
  Returns [] when size <= 0.
  """
  assert(poly[0] == 1)
  if size <= 0:
    # A negative starting degree would never reach 0 under `>>= 1`
    # (-1 >> 1 == -1), so the target loop below would not terminate.
    return []

  degs = []
  deg = size - 1
  while deg:
    degs.append(deg)
    deg >>= 1

  poly2 = poly[:]
  if len(poly2) < size:
    poly2.extend([0] * (size - len(poly2)))

  inv = [1]
  for t in degs[::-1]:
    added = t + 1 - len(inv)
    # coefficients of poly * inv past the prefix that already equals 1, 0, ...
    tmp = poly_mul_mod(poly2[:t + 1], inv, mod)[len(inv):]
    # Newton correction: next `added` coefficients are -(error * inv) mod p
    tmp = poly_mul_mod(tmp[:added], inv[:added], mod)
    inv.extend([-v % mod for v in tmp[:added]])
  return inv

def poly_div_mod(poly1, poly2, mod, inverse=None):
  """Quotient of poly1 by monic poly2, modulo *mod*.

  Lists hold the highest-degree coefficient first.  Read lowest-index-first,
  each list is the reversed polynomial, and the quotient is the first
  len(poly1) - len(poly2) + 1 coefficients of poly1 * (1 / poly2).

  inverse: optional precomputed poly_inverse_mod(poly2, k, mod) with k at
  least the quotient length; computed here when None or empty.
  """
  assert(len(poly1) >= len(poly2))
  assert(poly2[0] == 1)
  needed_size = len(poly1) - len(poly2) + 1
  # None sentinel instead of a shared mutable `[]` default; an explicitly
  # passed empty list is still treated as "not provided".
  if inverse is None or len(inverse) == 0:
    inverse = poly_inverse_mod(poly2, needed_size, mod)
  assert(len(inverse) >= needed_size)
  ret = poly_mul_mod(poly1[:needed_size], inverse[:needed_size], mod)
  return ret[:needed_size]

def poly_rem_mod(poly1, poly2, mod, inverse=None):
  """Remainder of poly1 by monic poly2, modulo *mod*.

  Lists hold the highest-degree coefficient first.  Returns a copy of poly1
  when it is shorter than poly2, otherwise the len(poly2) - 1 remainder
  coefficients.  Short divisors or short quotients (fewer than 10 entries)
  use schoolbook _poly_rem_mod.  Otherwise the quotient comes from
  poly_div_mod and the remainder is poly1 - quotient * poly2.

  inverse: optional precomputed power-series inverse of poly2, reused
  across calls; computed here when None or empty.
  """
  size1 = len(poly1)
  size2 = len(poly2)
  if size1 < size2:
    return poly1[:]

  needed_size = size1 - size2 + 1
  if size2 < 10 or needed_size < 10:
    return _poly_rem_mod(poly1, poly2, mod)

  # None sentinel instead of a shared mutable `[]` default
  if inverse is None or len(inverse) == 0:
    inverse = poly_inverse_mod(poly2, needed_size, mod)

  poly_q = poly_div_mod(poly1, poly2, mod, inverse)
  poly_q2 = poly_mul_mod(poly_q, poly2, mod)
  # the leading needed_size entries cancel; keep only the low-degree tail
  return [(poly1[i] - poly_q2[i]) % mod for i in range(needed_size, size1)]

def poly_power_rem_mod(e, poly_divisor, mod, threshold=32):
  """
  Return x^e % poly_divisor (modulo mod)

  Polynomials are lists with the highest-degree coefficient first.  The
  result is the remainder in that order, and may have fewer than
  len(poly_divisor) - 1 entries when e is small.

  assume:
  - deg(poly_divisor) > 0
  - poly_divisor[0] == 1 (required by the remainder helpers, which assert it)
  - mod > 1
  """
  if e == 0:
    return [1]

  top = e.bit_length() - 1

  # For large divisors, compute the power-series inverse once so every
  # poly_rem_mod call below can reuse it.
  if len(poly_divisor) >= threshold:
    inverse = poly_inverse_mod(poly_divisor, len(poly_divisor), mod)
  else:
    inverse = []

  acc = [1]
  for shift in range(top, -1, -1):
    if shift != top:
      acc = poly_square_mod(acc, mod)
      acc = poly_rem_mod(acc, poly_divisor, mod, inverse)
    if (e >> shift) & 1:
      # multiply by x: with highest-degree-first lists, append a zero
      acc.append(0)

  if len(acc) >= len(poly_divisor):
    acc = poly_rem_mod(acc, poly_divisor, mod, inverse)
  return acc

def pat(dice, P, mod):
  """Count multisets of exactly P faces drawn from *dice*, indexed by their sum.

  Returns a list of length P * dice[-1] + 1 whose entry s is the number of
  such multisets summing to s, modulo mod.  Faces are processed one at a
  time; because row used+1 is built from a row that may already include the
  current face, each face can repeat.  The scanned sum range uses dice[0] as
  the minimum face and the current face as the maximum, so *dice* is assumed
  to be sorted ascending.
  """
  width = P * dice[-1] + 1
  ways = [[0] * width for _ in range(P + 1)]
  ways[0][0] = 1
  low_face = dice[0]
  for face in dice:
    for used in range(P):
      row, nxt = ways[used], ways[used + 1]
      for total in range(low_face * used, face * used + 1):
        cnt = row[total]
        if cnt:
          nxt[total + face] = (nxt[total + face] + cnt) % mod
  return ways[-1]

def ilog2(n):
  """Return floor(log2(n)) for positive ints, and 0 for n <= 0."""
  return n.bit_length() - 1 if n > 0 else 0

import sys

def solve():
  """Read "N P C" from stdin and print the result modulo 10**9 + 7.

  What the code computes (the mapping onto problem No.214 is a
  NOTE(review): confirm against the statement):
    c_s    = coefficient s of pat(primes, P) * pat(composites, C), i.e. the
             number of (prime-face multiset, composite-face multiset) pairs
             whose faces sum to s, for 1 <= s <= Max = 13 * P + 12 * C;
    f(n)   = coefficients of 1 / (1 - sum_s c_s x^s), i.e.
             f(0) = 1, f(n) = sum_s c_s * f(n - s);
    output = sum over e in [max(0, N - Max), N) of
             f(e) * (c_{N-e} + ... + c_Max).
  For large e, f(e) is obtained from x^e reduced modulo the characteristic
  polynomial, combined with f(0..Max-1).
  """
  N, P, C = map(int, sys.stdin.readline().split())

  Ps = [2, 3, 5, 7, 11, 13]
  Cs = [4, 6, 8, 9, 10, 12]

  mod = 10 ** 9 + 7

  # Ps / Cs are rebound from face lists to per-sum multiset counts
  # (lengths 13 * P + 1 and 12 * C + 1).
  Ps = pat(Ps, P, mod)
  Cs = pat(Cs, C, mod)

  # poly := [1, -c_1, ..., -c_Max], length Max + 1 (index 0 is overwritten).
  # Read lowest-index-first it is the series 1 - sum c_s x^s.  Read
  # highest-degree-first, as the poly_*rem* helpers do, it is the monic
  # x^Max - c_1 x^(Max-1) - ... - c_Max.
  poly = poly_mul_mod(Ps, Cs, mod)
  poly[0] = 1
  for i in range(1, len(poly)):
    poly[i] = -poly[i] % mod

  Max = 13 * P + 12 * C
  # a single turn moves at most Max, so only starts in [N - Max, N) can reach >= N
  E = max(0, N - Max)
  # inv[i] = f(i) for 0 <= i < Max
  inv = poly_inverse_mod(poly, Max, mod)

  # sums[i] = c_1 + ... + c_i (mod)
  sums = [0] * len(poly)
  for i in range(1, len(poly)):
    sums[i] = (sums[i-1] + -poly[i]) % mod
  # x^E mod the characteristic polynomial, highest degree first
  poly_rem = poly_power_rem_mod(E, poly, mod)

  ans = 0
  for e in range(E, N):
    # f(e) = sum_i [x^i](x^e mod charpoly) * f(i); poly_rem[-1 - i] is [x^i]
    total = 0
    for i in range(len(poly_rem)):
      total = (total + poly_rem[-1 - i] * inv[i]) % mod
    # weight by the number of single turns of length N - e .. Max
    ans = (ans + total * (sums[Max] - sums[N - e - 1])) % mod
    # x^e -> x^(e+1): multiply by x (append a zero), then reduce
    poly_rem.extend([0])
    poly_rem = poly_rem_mod(poly_rem, poly, mod)
  print(ans)

# Run only when executed as a script, so importing this module does not
# block on stdin.  The stray no-op expression `0` that followed has been dropped.
if __name__ == "__main__":
  solve()