結果
問題 | No.214 素数サイコロと合成数サイコロ (3-Medium) |
ユーザー | Min_25 |
提出日時 | 2015-05-23 05:50:45 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 418 ms / 3,000 ms |
コード長 | 14,585 bytes |
コンパイル時間 | 138 ms |
コンパイル使用メモリ | 82,560 KB |
実行使用メモリ | 86,260 KB |
最終ジャッジ日時 | 2024-07-06 06:03:23 |
合計ジャッジ時間 | 2,139 ms |
ジャッジサーバーID (参考情報) | judge3 / judge2 |
(要ログイン)
テストケース
テストケース表示
入力 | 結果 | 実行時間 | 実行使用メモリ |
---|---|---|---|
testcase_00 | AC | 418 ms | 86,260 KB |
testcase_01 | AC | 397 ms | 85,848 KB |
testcase_02 | AC | 410 ms | 85,808 KB |
ソースコード
import sys

# Polynomial arithmetic toolkit followed by the solver for yukicoder No.214.
#
# Polynomials are Python lists of coefficients.  Multiplication and squaring
# are plain convolutions, so they do not care which end holds the constant
# term.  The division / remainder helpers (_poly_rem_mod, poly_div_mod,
# poly_rem_mod, poly_power_rem_mod) treat index 0 as the HIGHEST-degree
# coefficient and require it to be 1 (monic divisor).  poly_inverse and
# poly_inverse_mod treat index 0 as the constant term of a power series.


# ---------------------------------------------------------------------------
# Exact (non-modular) arithmetic
# ---------------------------------------------------------------------------

def _poly_mul(poly1, poly2):
    """Return the convolution of poly1 and poly2 (schoolbook, O(n*m))."""
    ret = [0] * (len(poly1) + len(poly2) - 1)
    for i, coef in enumerate(poly2):
        if coef == 0:
            continue  # contributes nothing; skip the inner loop
        for j, c in enumerate(poly1):
            ret[i + j] += coef * c
    return ret


def poly_mul_karatsuba(poly1, poly2, threshold=16):
    """Return the convolution of two EQUAL-length polynomials (Karatsuba).

    Inputs shorter than ``threshold`` (or than 2, where splitting cannot make
    progress) fall back to ``_poly_mul``.
    """
    size = len(poly1)
    if size < threshold or size < 2:
        return _poly_mul(poly1, poly2)
    # Low part = first size_h coefficients, high part = remaining size_l.
    size_l = (size + 1) // 2
    size_h = size - size_l
    p1 = poly_mul_karatsuba(poly1[:size_h], poly2[:size_h], threshold)
    p2 = poly_mul_karatsuba(poly1[size_h:], poly2[size_h:], threshold)
    # q = high + x^ofs * low (ofs is 0 or 1), so q1 * q2 holds the cross
    # terms shifted by ofs once high*high and low*low are removed.
    q1 = poly1[size_h:]
    q2 = poly2[size_h:]
    ofs = size_l - size_h
    for i in range(size_h):
        q1[ofs + i] += poly1[i]
        q2[ofs + i] += poly2[i]
    p3 = poly_mul_karatsuba(q1, q2, threshold)
    # ret = low*low, one zero, then high*high at offset 2 * size_h.
    ret = p1
    ret.append(0)
    ret.extend(p2)
    for i in range(size_l * 2 - 1):
        p3[i] -= ret[2 * size_h + i]
    for i in range(size_h * 2 - 1):
        p3[ofs * 2 + i] -= ret[i]
    # Cross terms belong at offset size_h; p3 is additionally shifted by ofs.
    ofs = 2 * size - 3 * size_l
    for i in range(size_l * 2 - 1):
        ret[ofs + i] += p3[i]
    return ret


def _pack(pack, shamt):
    """Pack non-negative ints into one big int, element i at bit i * shamt.

    Elements are merged pairwise, doubling the field width each level.
    Every element must be non-negative and below 2**shamt (fields are
    combined with ``|``).
    """
    size = len(pack)
    while size > 1:
        npack = []
        for i in range(0, size - 1, 2):
            npack.append(pack[i] | (pack[i + 1] << shamt))
        if size & 1:
            npack.append(pack[-1])  # odd element carried to the next level
        pack = npack
        size = (size + 1) >> 1
        shamt <<= 1
    return pack[0]


def _pack1(seq, shamt):
    """Pack ``seq`` for squaring; return (M, top-level split width)."""
    M = _pack(seq, shamt)
    size = len(seq) * 2 - 1
    block_size = 1 << ilog2(size - 1)
    return M, shamt * block_size


def _pack2(seq1, seq2, shamt):
    """Pack two sequences for a product; return (M1, M2, top split width)."""
    M1 = _pack(seq1, shamt)
    M2 = _pack(seq2, shamt)
    size = len(seq1) + len(seq2) - 1
    block_size = 1 << ilog2(size - 1)
    return M1, M2, shamt * block_size


def pack_sequence(seq):
    """Pack non-negative ints so that the square's coefficients never overlap."""
    max_bits = max(c.bit_length() for c in seq)
    size = len(seq)
    shamt = max_bits * 2 + size.bit_length()
    return _pack1(seq, shamt)


def pack_sequence_mod(seq, mod):
    """Pack values in [0, mod) with room for squared coefficients."""
    size = len(seq)
    max_value = (mod - 1) ** 2 * size
    shamt = max_value.bit_length()
    return _pack1(seq, shamt)


def pack_sequence2(seq1, seq2):
    """Pack two non-negative int sequences with room for their product."""
    max_bits_1 = max(c.bit_length() for c in seq1)
    max_bits_2 = max(c.bit_length() for c in seq2)
    size = min(len(seq1), len(seq2))
    shamt = max_bits_1 + max_bits_2 + size.bit_length()
    return _pack2(seq1, seq2, shamt)


def pack_sequence2_mod(seq1, seq2, mod):
    """Pack two sequences of values in [0, mod) with room for their product."""
    size = min(len(seq1), len(seq2))
    max_value = (mod - 1) ** 2 * size
    shamt = max_value.bit_length()
    return _pack2(seq1, seq2, shamt)


def unpack_sequence(M, size, shamt):
    """Split big int ``M`` back into ``size`` coefficient fields.

    ``shamt`` is the top-level split width returned by ``_pack1``/``_pack2``;
    it is halved at every level of the recursive split.
    """
    needed_sizes = []
    s = size
    while s > 1:
        needed_sizes.append(s)
        s = (s + 1) >> 1
    ret = [M]
    for needed_size in reversed(needed_sizes):
        mask = (1 << shamt) - 1
        nret = []
        for c in ret:
            nret.append(c & mask)
            nret.append(c >> shamt)
        ret = nret[:needed_size]
        shamt >>= 1
    return ret


def poly_mul_builtin(poly1, poly2):
    """Multiply non-negative int polynomials via one big-int product."""
    M1, M2, shamt = pack_sequence2(poly1, poly2)
    size = len(poly1) + len(poly2) - 1
    return unpack_sequence(M1 * M2, size, shamt)


def poly_mul(poly1, poly2, threshold=16, use_builtin=False):
    """Return poly1 * poly2 with exact arithmetic.

    With ``use_builtin`` and int coefficients (at least ``threshold`` terms),
    big-int packing is used; coefficients must then be non-negative.
    Otherwise Karatsuba handles equal-length inputs and schoolbook the rest.
    """
    # isinstance(..., int) replaces the Python 2-only ``long`` check, which
    # raised NameError on Python 3 for non-int coefficients.
    if use_builtin and len(poly1) >= threshold and isinstance(poly1[0], int):
        return poly_mul_builtin(poly1, poly2)
    if len(poly1) == len(poly2):
        return poly_mul_karatsuba(poly1, poly2, threshold)
    return _poly_mul(poly1, poly2)


def poly_square_builtin(poly):
    """Square a non-negative int polynomial via one big-int squaring."""
    M, shamt = pack_sequence(poly)
    size = len(poly) * 2 - 1
    return unpack_sequence(M ** 2, size, shamt)


def _poly_square(poly):
    """Return poly * poly (schoolbook, using symmetry of the cross terms)."""
    size = len(poly)
    ret = [0] * (size * 2 - 1)
    for i in range(size):
        ret[2 * i] = poly[i] * poly[i]
    for i in range(size):
        coef = 2 * poly[i]  # each cross term i != j appears twice
        for j in range(i + 1, size):
            ret[i + j] += coef * poly[j]
    return ret


def poly_square_karatsuba(poly, threshold=16):
    """Return poly * poly using Karatsuba squaring (falls back below threshold)."""
    size = len(poly)
    if size < threshold or size < 2:
        return _poly_square(poly)
    size_l = (size + 1) // 2
    size_h = size - size_l
    p1 = poly_square_karatsuba(poly[:size_h], threshold)
    p2 = poly_square_karatsuba(poly[size_h:], threshold)
    # S = high + x^ofs * low, same layout as in poly_mul_karatsuba.
    S = poly[size_h:]
    ofs = size_l - size_h
    for i in range(size_h):
        S[ofs + i] += poly[i]
    p3 = poly_square_karatsuba(S, threshold)
    ret = p1
    ret.append(0)
    ret.extend(p2)
    for i in range(size_l * 2 - 1):
        p3[i] -= ret[2 * size_h + i]
    for i in range(size_h * 2 - 1):
        p3[ofs * 2 + i] -= ret[i]
    ofs = 2 * size - 3 * size_l
    for i in range(size_l * 2 - 1):
        ret[ofs + i] += p3[i]
    return ret


def poly_square(poly, threshold=16, use_builtin=False):
    """Return poly * poly with exact arithmetic (see poly_mul for dispatch)."""
    if use_builtin and len(poly) >= threshold and isinstance(poly[0], int):
        return poly_square_builtin(poly)
    if len(poly) >= threshold:
        # Forward the caller's threshold (it used to be silently dropped).
        return poly_square_karatsuba(poly, threshold)
    return _poly_square(poly)


def poly_pow(poly, e, threshold=16):
    """Return poly ** e (exact) by left-to-right binary exponentiation."""
    ret = [1]
    if e == 0:
        return ret
    mask = 1 << (e.bit_length() - 1)
    while mask:
        if e & mask:
            ret = poly_mul(ret, poly, threshold, False)
        mask >>= 1
        if not mask:
            break
        ret = poly_square(ret, threshold, False)
    return ret


def poly_inverse(poly, size):
    """Return the first ``size`` coefficients of the power series 1 / poly.

    Index 0 is the constant term and must equal 1.  Uses Newton iteration,
    roughly doubling the number of known coefficients per step.
    """
    assert poly[0] == 1
    if size <= 0:
        return []  # `while deg:` below never terminates for deg < 0
    degs = []
    deg = size - 1
    while deg:
        degs.append(deg)
        deg >>= 1
    poly2 = poly[:]
    if len(poly2) < size:
        poly2.extend([0] * (size - len(poly2)))
    inv = [1]
    for t in reversed(degs):
        added = t + 1 - len(inv)
        # poly * inv == 1 + (error) mod x^(t+1); correct inv by -inv * error.
        tmp = poly_mul(poly2[:t + 1], inv)[len(inv):]
        tmp = poly_mul(tmp[:added], inv[:added])
        inv.extend([-v for v in tmp[:added]])
    return inv


# ---------------------------------------------------------------------------
# Number-theoretic transform (three primes + CRT)
# ---------------------------------------------------------------------------

# Each prime is c * 2**23 + 1.  The roots are expected to have order 2**23;
# _ntt_root verifies that and falls back to searching for a generator.
_NTT_PRIMES = (880803841, 897581057, 998244353)
_NTT_ROOTS = (273508579, 872686320, 15311432)


def inv_mod(a, m):
    """Return the inverse of ``a`` modulo ``m`` (extended Euclid).

    Raises ValueError when gcd(a, m) != 1.
    """
    r0, r1 = a % m, m
    x0, x1 = 1, 0
    while r1:
        q = r0 // r1
        r0, r1 = r1, r0 - q * r1
        x0, x1 = x1, x0 - q * x1
    if r0 != 1:
        raise ValueError("inverse does not exist")
    return x0 % m


def _ntt_root(p, n, z=None):
    """Return a primitive ``n``-th root of unity modulo the prime ``p``.

    ``z`` is used when it has order exactly 2**23; otherwise a generator of
    the multiplicative group is searched for.  ``n`` must divide ``p - 1``.
    """
    phi = p - 1
    if phi % n:
        raise ValueError("NTT length does not divide p - 1")
    top = 1 << 23
    if z is not None and top % n == 0 and pow(z, top >> 1, p) == p - 1:
        return pow(z, top // n, p)
    factors = []
    m = phi
    d = 2
    while d * d <= m:
        if m % d == 0:
            factors.append(d)
            while m % d == 0:
                m //= d
        d += 1
    if m > 1:
        factors.append(m)
    g = 2
    while any(pow(g, phi // q, p) == 1 for q in factors):
        g += 1
    return pow(g, phi // n, p)


def _ntt(a, p, root):
    """Transform ``a`` in place (iterative radix-2 Cooley-Tukey).

    ``len(a)`` must be a power of two and ``root`` a primitive ``len(a)``-th
    root of unity modulo ``p``.
    """
    n = len(a)
    # Bit-reversal permutation so the butterflies can run in place.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j ^= bit
        if i < j:
            a[i], a[j] = a[j], a[i]
    length = 2
    while length <= n:
        half = length >> 1
        w_step = pow(root, n // length, p)
        twiddles = [1] * half
        for k in range(1, half):
            twiddles[k] = twiddles[k - 1] * w_step % p
        for start in range(0, n, length):
            for k in range(half):
                u = a[start + k]
                v = a[start + k + half] * twiddles[k] % p
                a[start + k] = (u + v) % p
                a[start + k + half] = (u - v) % p
        length <<= 1


def _ntt_convolve(A, B, size, p, z):
    """Return the first ``size`` terms of A * B modulo ``p``.

    ``A`` and ``B`` must have the same power-of-two length, at least ``size``.
    The inputs are not modified.
    """
    n = len(A)
    root = _ntt_root(p, n, z)
    fa = [x % p for x in A]
    fb = [x % p for x in B]
    _ntt(fa, p, root)
    _ntt(fb, p, root)
    prod = [x * y % p for x, y in zip(fa, fb)]
    _ntt(prod, p, inv_mod(root, p))
    n_inv = inv_mod(n, p)
    return [prod[i] * n_inv % p for i in range(size)]


def _ntt_convolve_self(A, size, p, z):
    """Return the first ``size`` terms of A * A modulo ``p`` (see _ntt_convolve)."""
    n = len(A)
    root = _ntt_root(p, n, z)
    fa = [x % p for x in A]
    _ntt(fa, p, root)
    fa = [x * x % p for x in fa]
    _ntt(fa, p, inv_mod(root, p))
    n_inv = inv_mod(n, p)
    return [fa[i] * n_inv % p for i in range(size)]


def _crt_combine(A1, A2, A3, size, mod):
    """Merge residues modulo the three _NTT_PRIMES into values modulo ``mod``.

    Exact only when the true values are below the product of the primes.
    ``A1`` is overwritten.
    """
    p1, p2, p3 = _NTT_PRIMES
    inv = inv_mod(p1, p2)
    for i in range(size):
        k = (A2[i] - A1[i]) * inv % p2
        A1[i] += k * p1  # now the exact value modulo p1 * p2
    p12 = p1 * p2
    inv = inv_mod(p12, p3)
    p12_mod = p12 % mod
    for i in range(size):
        k = (A3[i] - A1[i]) % p3 * inv % p3
        A1[i] = (A1[i] + k * p12_mod) % mod
    return A1[:size]


def poly_mul_mod_ntt(poly1, poly2, mod):
    """Return poly1 * poly2 modulo ``mod`` via three NTTs and CRT.

    Coefficients are expected in [0, mod).
    """
    s1 = len(poly1)
    s2 = len(poly2)
    ntt_size = 2 << ilog2(max(s1, s2) * 2 - 1)
    size = s1 + s2 - 1
    A = poly1 + [0] * (ntt_size - s1)
    B = poly2 + [0] * (ntt_size - s2)
    r1, r2, r3 = [_ntt_convolve(A, B, size, p, z)
                  for p, z in zip(_NTT_PRIMES, _NTT_ROOTS)]
    return _crt_combine(r1, r2, r3, size, mod)


def poly_square_mod_ntt(poly1, mod):
    """Return poly1 * poly1 modulo ``mod`` via three NTTs and CRT."""
    s1 = len(poly1)
    ntt_size = 2 << ilog2(s1 * 2 - 1)
    size = 2 * s1 - 1
    A = poly1 + [0] * (ntt_size - s1)
    r1, r2, r3 = [_ntt_convolve_self(A, size, p, z)
                  for p, z in zip(_NTT_PRIMES, _NTT_ROOTS)]
    return _crt_combine(r1, r2, r3, size, mod)


# ---------------------------------------------------------------------------
# Modular arithmetic
# ---------------------------------------------------------------------------

def poly_mul_mod_builtin(poly1, poly2, mod):
    """Return poly1 * poly2 modulo ``mod`` via big-int packing (coeffs in [0, mod))."""
    M1, M2, shamt = pack_sequence2_mod(poly1, poly2, mod)
    size = len(poly1) + len(poly2) - 1
    seq = unpack_sequence(M1 * M2, size, shamt)
    return [x % mod for x in seq]


def poly_square_mod_builtin(poly, mod):
    """Return poly * poly modulo ``mod`` via big-int packing (coeffs in [0, mod))."""
    M, shamt = pack_sequence_mod(poly, mod)
    size = len(poly) * 2 - 1
    seq = unpack_sequence(M ** 2, size, shamt)
    return [x % mod for x in seq]


def poly_add_mod(poly1, ofs1, poly2, ofs2, size, mod):
    """In place: poly1[ofs1 + i] += poly2[ofs2 + i] (mod ``mod``) for i < size."""
    diff = ofs2 - ofs1
    for i in range(ofs1, ofs1 + size):
        poly1[i] = (poly1[i] + poly2[i + diff]) % mod


def poly_sub_mod(poly1, ofs1, poly2, ofs2, size, mod):
    """In place: poly1[ofs1 + i] -= poly2[ofs2 + i] (mod ``mod``) for i < size."""
    diff = ofs2 - ofs1
    for i in range(ofs1, ofs1 + size):
        poly1[i] = (poly1[i] - poly2[i + diff]) % mod


def poly_mul_mod_karatsuba(poly1, poly2, mod, threshold=128):
    """Karatsuba product of two EQUAL-length polynomials modulo ``mod``."""
    size = len(poly1)
    if size < threshold or size < 2:
        return _poly_mul_mod(poly1, poly2, mod)
    size_l = (size + 1) // 2
    size_h = size - size_l
    p1 = poly_mul_mod_karatsuba(poly1[:size_h], poly2[:size_h], mod, threshold)
    p2 = poly_mul_mod_karatsuba(poly1[size_h:], poly2[size_h:], mod, threshold)
    q1 = poly1[size_h:]
    q2 = poly2[size_h:]
    ofs = size_l - size_h
    poly_add_mod(q1, ofs, poly1, 0, size_h, mod)
    poly_add_mod(q2, ofs, poly2, 0, size_h, mod)
    p3 = poly_mul_mod_karatsuba(q1, q2, mod, threshold)
    ret = p1
    ret.append(0)
    ret.extend(p2)
    poly_sub_mod(p3, 0, ret, 2 * size_h, size_l * 2 - 1, mod)
    poly_sub_mod(p3, ofs * 2, ret, 0, size_h * 2 - 1, mod)
    ofs = 2 * size - 3 * size_l
    poly_add_mod(ret, ofs, p3, 0, size_l * 2 - 1, mod)
    return ret


def _poly_mul_mod(poly1, poly2, mod):
    """Schoolbook product modulo ``mod``."""
    ret = [0] * (len(poly1) + len(poly2) - 1)
    for i, coef in enumerate(poly2):
        if coef == 0:
            continue
        for j, c in enumerate(poly1):
            ret[i + j] = (ret[i + j] + coef * c) % mod
    return ret


def poly_mul_mod(poly1, poly2, mod, threshold=128, ntt_threshold=65536):
    """Return poly1 * poly2 modulo ``mod``, choosing NTT, schoolbook or packing."""
    size1 = len(poly1)
    size2 = len(poly2)
    if size1 >= ntt_threshold and size2 >= ntt_threshold and mod <= 2 * 10 ** 9:
        return poly_mul_mod_ntt(poly1, poly2, mod)
    if size1 <= threshold and size2 <= threshold:
        return _poly_mul_mod(poly1, poly2, mod)
    return poly_mul_mod_builtin(poly1, poly2, mod)


def _poly_square_mod(poly, mod):
    """Schoolbook square modulo ``mod``."""
    size = len(poly)
    ret = [0] * (size * 2 - 1)
    for i in range(size):
        ret[2 * i] = poly[i] * poly[i] % mod
    for i in range(size):
        coef = 2 * poly[i]
        for j in range(i + 1, size):
            ret[i + j] = (ret[i + j] + coef * poly[j]) % mod
    return ret


def poly_square_mod_karatsuba(poly, mod, threshold=64):
    """Karatsuba square modulo ``mod``."""
    size = len(poly)
    if size < threshold or size < 2:
        return _poly_square_mod(poly, mod)
    size_l = (size + 1) // 2
    size_h = size - size_l
    p1 = poly_square_mod_karatsuba(poly[:size_h], mod, threshold)
    p2 = poly_square_mod_karatsuba(poly[size_h:], mod, threshold)
    S = poly[size_h:]
    ofs = size_l - size_h
    poly_add_mod(S, ofs, poly, 0, size_h, mod)
    p3 = poly_square_mod_karatsuba(S, mod, threshold)
    ret = p1
    ret.append(0)
    ret.extend(p2)
    poly_sub_mod(p3, 0, ret, 2 * size_h, size_l * 2 - 1, mod)
    poly_sub_mod(p3, ofs * 2, ret, 0, size_h * 2 - 1, mod)
    ofs = 2 * size - 3 * size_l
    poly_add_mod(ret, ofs, p3, 0, size_l * 2 - 1, mod)
    return ret


def poly_square_mod(poly, mod, threshold=128, k_threshold=64, ntt_threshold=65536):
    """Return poly * poly modulo ``mod``, choosing NTT, packing, Karatsuba or schoolbook."""
    size = len(poly)
    if size >= ntt_threshold and mod <= 2 * 10 ** 9:
        return poly_square_mod_ntt(poly, mod)
    if size >= threshold:
        return poly_square_mod_builtin(poly, mod)
    if size >= k_threshold:
        return poly_square_mod_karatsuba(poly, mod, k_threshold)
    return _poly_square_mod(poly, mod)


def poly_pow_mod(poly, e, mod):
    """Return poly ** e modulo ``mod`` (binary exponentiation, no reduction by a divisor)."""
    ret = [1]
    if e == 0:
        return ret
    mask = 1 << (e.bit_length() - 1)
    while mask:
        if e & mask:
            ret = poly_mul_mod(ret, poly, mod)
        mask >>= 1
        if not mask:
            break
        ret = poly_square_mod(ret, mod)
    return ret


def _poly_rem_mod(poly1, poly2, mod):
    """Return poly1 mod poly2 by long division (highest degree first, poly2 monic).

    The result has ``len(poly2) - 1`` coefficients, or is a copy of ``poly1``
    when it is already shorter than ``poly2``.
    """
    if len(poly1) < len(poly2):
        return poly1[:]
    ret = poly1[:]
    dif = len(poly1) - len(poly2) + 1
    assert poly2[0] == 1
    for i in range(dif):
        if ret[i] == 0:
            continue
        coef = ret[i] % mod
        for j in range(1, len(poly2)):
            ret[i + j] = (ret[i + j] - coef * poly2[j]) % mod
        ret[i] = coef
    return ret[dif:]


def poly_inverse_mod(poly, size, mod):
    """Return the first ``size`` coefficients of 1 / poly modulo ``mod``.

    Index 0 is the constant term of the series and must equal 1.  Returns
    ``[]`` for ``size <= 0`` (previously an infinite loop).
    """
    assert poly[0] == 1
    if size <= 0:
        return []
    degs = []
    deg = size - 1
    while deg:
        degs.append(deg)
        deg >>= 1
    poly2 = poly[:]
    if len(poly2) < size:
        poly2.extend([0] * (size - len(poly2)))
    inv = [1]
    for t in reversed(degs):
        added = t + 1 - len(inv)
        # Newton step: new coefficients are -(inv * error) where
        # poly * inv == 1 + error (mod x^(t+1)).
        tmp = poly_mul_mod(poly2[:t + 1], inv, mod)[len(inv):]
        tmp = poly_mul_mod(tmp[:added], inv[:added], mod)
        inv.extend([-v % mod for v in tmp[:added]])
    return inv


def poly_div_mod(poly1, poly2, mod, inverse=None):
    """Return the quotient of poly1 / poly2 modulo ``mod`` (highest degree first).

    ``poly2`` must be monic.  ``inverse``, if given, is the series inverse of
    ``poly2`` with at least ``len(poly1) - len(poly2) + 1`` terms; otherwise
    it is computed here.
    """
    assert len(poly1) >= len(poly2)
    assert poly2[0] == 1
    needed_size = len(poly1) - len(poly2) + 1
    if not inverse:
        inverse = poly_inverse_mod(poly2, needed_size, mod)
    assert len(inverse) >= needed_size
    ret = poly_mul_mod(poly1[:needed_size], inverse[:needed_size], mod)
    return ret[:needed_size]


def poly_rem_mod(poly1, poly2, mod, inverse=None):
    """Return poly1 mod poly2 modulo ``mod`` (highest degree first, poly2 monic).

    Small cases use long division; larger ones use the series-inverse
    quotient.  ``inverse`` is an optional precomputed inverse of ``poly2``.
    """
    size1 = len(poly1)
    size2 = len(poly2)
    if size1 < size2:
        return poly1[:]
    needed_size = size1 - size2 + 1
    if size2 < 10 or needed_size < 10:
        return _poly_rem_mod(poly1, poly2, mod)
    if not inverse:
        inverse = poly_inverse_mod(poly2, needed_size, mod)
    poly_q = poly_div_mod(poly1, poly2, mod, inverse)
    poly_q2 = poly_mul_mod(poly_q, poly2, mod)
    return [(poly1[i] - poly_q2[i]) % mod for i in range(needed_size, size1)]


def poly_power_rem_mod(e, poly_divisor, mod, threshold=32):
    """Return x^e mod poly_divisor, with coefficients reduced modulo ``mod``.

    Polynomials are highest-degree-first.  Assumes ``poly_divisor`` is monic
    with degree > 0 and ``mod`` > 1.  The result has at most
    ``len(poly_divisor) - 1`` coefficients.  For divisors with at least
    ``threshold`` terms the series inverse is computed once and reused.
    """
    if e == 0:
        return [1]
    inverse = None
    if len(poly_divisor) >= threshold:
        inverse = poly_inverse_mod(poly_divisor, len(poly_divisor), mod)
    ret = [1]
    # Left-to-right binary exponentiation; appending 0 multiplies by x.
    mask = 1 << (e.bit_length() - 1)
    while mask:
        if e & mask:
            ret.append(0)
        mask >>= 1
        if not mask:
            break
        ret = poly_square_mod(ret, mod)
        ret = poly_rem_mod(ret, poly_divisor, mod, inverse)
    if len(ret) >= len(poly_divisor):
        ret = poly_rem_mod(ret, poly_divisor, mod, inverse)
    return ret


# ---------------------------------------------------------------------------
# Problem-specific code
# ---------------------------------------------------------------------------

def pat(dice, P, mod):
    """Count multisets of ``P`` faces from ``dice`` by their total, modulo ``mod``.

    ``dice`` must be sorted ascending.  Returns ``cnt`` of length
    ``P * dice[-1] + 1`` where ``cnt[s]`` is the number of ways to choose
    ``P`` faces (with repetition, order ignored) whose sum is ``s``.
    """
    width = P * dice[-1] + 1
    dp = [[0] * width for _ in range(P + 1)]
    dp[0][0] = 1
    lo = dice[0]
    for d in dice:
        for i in range(P):
            row = dp[i]
            nxt = dp[i + 1]
            # Using faces up to d, i faces sum to between lo*i and d*i.
            for k in range(lo * i, d * i + 1):
                if row[k]:
                    nxt[k + d] = (nxt[k + d] + row[k]) % mod
    return dp[-1]


def ilog2(n):
    """Return floor(log2(n)) for n > 0, and 0 for n <= 0."""
    if n <= 0:
        return 0
    return n.bit_length() - 1


def solve():
    """Read ``N P C`` from stdin and print the answer modulo 10**9 + 7."""
    N, P, C = map(int, sys.stdin.readline().split())
    mod = 10 ** 9 + 7
    # Number of dice multisets per total, for prime and for composite dice.
    Ps = pat([2, 3, 5, 7, 11, 13], P, mod)
    Cs = pat([4, 6, 8, 9, 10, 12], C, mod)
    poly = poly_mul_mod(Ps, Cs, mod)
    # poly becomes [1, -c_1, ..., -c_Max] (highest degree first), presumably
    # the characteristic polynomial of the recurrence being evaluated.
    poly[0] = 1
    for i in range(1, len(poly)):
        poly[i] = -poly[i] % mod
    Max = 13 * P + 12 * C
    E = max(0, N - Max)
    inv = poly_inverse_mod(poly, Max, mod)
    # sums[i] = c_1 + ... + c_i (mod p).
    sums = [0] * len(poly)
    for i in range(1, len(poly)):
        sums[i] = (sums[i - 1] - poly[i]) % mod
    # Start from x^E mod poly and advance one power of x per iteration.
    poly_rem = poly_power_rem_mod(E, poly, mod)
    ans = 0
    for e in range(E, N):
        total = 0
        for i in range(len(poly_rem)):
            total = (total + poly_rem[-1 - i] * inv[i]) % mod
        ans = (ans + total * (sums[Max] - sums[N - e - 1])) % mod
        poly_rem.append(0)  # multiply by x
        poly_rem = poly_rem_mod(poly_rem, poly, mod)
    print(ans)


if __name__ == "__main__":
    solve()