結果
問題 | No.117 組み合わせの数 |
ユーザー | 草苺奶昔 |
提出日時 | 2023-05-10 17:15:37 |
言語 | PyPy3 (7.3.15) |
結果 |
AC
|
実行時間 | 434 ms / 5,000 ms |
コード長 | 2,724 bytes |
コンパイル時間 | 330 ms |
コンパイル使用メモリ | 82,432 KB |
実行使用メモリ | 232,876 KB |
最終ジャッジ日時 | 2024-05-05 04:26:08 |
合計ジャッジ時間 | 1,636 ms |
ジャッジサーバーID (参考情報) |
judge1 / judge2 |
(要ログイン)
ソースコード
# 求组合数 # mod为素数且0<=n,k<min(mod,1e7) class Enumeration: __slots__ = ("_fac", "_ifac", "_inv", "_mod") def __init__(self, size: int, mod: int) -> None: self._mod = mod self._fac = [1] self._ifac = [1] self._inv = [1] self._expand(size) def fac(self, k: int) -> int: self._expand(k) return self._fac[k] def ifac(self, k: int) -> int: self._expand(k) return self._ifac[k] def inv(self, k: int) -> int: """模逆元""" self._expand(k) return self._inv[k] def C(self, n: int, k: int) -> int: if n < 0 or k < 0 or n < k: return 0 mod = self._mod return self.fac(n) * self.ifac(k) % mod * self.ifac(n - k) % mod def P(self, n: int, k: int) -> int: if n < 0 or k < 0 or n < k: return 0 mod = self._mod return self.fac(n) * self.ifac(n - k) % mod def H(self, n: int, k: int) -> int: """可重复选取元素的组合数.""" if n == 0: return 1 if k == 0 else 0 return self.C(n + k - 1, k) def put(self, n: int, k: int) -> int: """n个相同的球放入k个不同的盒子(盒子可放任意个球)的方法数.""" return self.C(n + k - 1, n) def catalan(self, n: int) -> int: """卡特兰数.""" return self.C(2 * n, n) * self.inv(n + 1) % self._mod def _expand(self, size: int) -> None: size = min(size, self._mod - 1) if len(self._fac) < size + 1: mod = self._mod preSize = len(self._fac) diff = size + 1 - preSize self._fac += [1] * diff self._ifac += [1] * diff self._inv += [1] * diff for i in range(preSize, size + 1): self._fac[i] = self._fac[i - 1] * i % mod self._ifac[size] = pow(self._fac[size], mod - 2, mod) # !modInv for i in range(size - 1, preSize - 1, -1): self._ifac[i] = self._ifac[i + 1] * (i + 1) % mod for i in range(preSize, size + 1): self._inv[i] = self._ifac[i] * self._fac[i - 1] % mod if __name__ == "__main__": # https://yukicoder.me/problems/no/117 import sys sys.setrecursionlimit(int(1e9)) input = lambda: sys.stdin.readline().rstrip("\r\n") T = int(input()) C = Enumeration(10**6 + 10, 10**9 + 7) for _ in range(T): s = input() op = s[0] inner = s[2:-1] n, k = map(int, inner.split(",")) if op == "C": print(C.C(n, k)) elif op == "P": print(C.P(n, k)) elif op == "H": print(C.H(n, k))