結果
| 問題 |
No.1978 Permutation Repetition
|
| コンテスト | |
| ユーザー |
lam6er
|
| 提出日時 | 2025-04-09 20:56:09 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 3,544 bytes |
| コンパイル時間 | 325 ms |
| コンパイル使用メモリ | 82,284 KB |
| 実行使用メモリ | 71,284 KB |
| 最終ジャッジ日時 | 2025-04-09 20:57:15 |
| 合計ジャッジ時間 | 5,353 ms |
|
ジャッジサーバーID (参考情報) |
judge2 / judge1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | AC * 13 WA * 31 |
ソースコード
import sys
import math
from math import gcd
from collections import defaultdict
MOD = 10**9 + 7
def main():
sys.setrecursionlimit(1 << 25)
n, m = map(int, sys.stdin.readline().split())
a = list(map(int, sys.stdin.readline().split()))
a = [x for x in a]
# Compute cycle decomposition of A
visited = [False] * (n + 1)
cycles = defaultdict(int)
for i in range(1, n + 1):
if not visited[i]:
current = i
length = 0
while not visited[current]:
visited[current] = True
current = a[current - 1]
length += 1
cycles[length] += 1
# Precompute factorial and inverse factorial up to n
fact = [1] * (n + 1)
for i in range(1, n + 1):
fact[i] = fact[i - 1] * i % MOD
inv_fact = [1] * (n + 1)
inv_fact[n] = pow(fact[n], MOD - 2, MOD)
for i in range(n - 1, -1, -1):
inv_fact[i] = inv_fact[i + 1] * (i + 1) % MOD
def comb(a, b):
if a < 0 or b < 0 or a < b:
return 0
return fact[a] * inv_fact[b] % MOD * inv_fact[a - b] % MOD
def get_divisors(num):
factors = []
temp = num
i = 2
while i * i <= temp:
if temp % i == 0:
cnt = 0
while temp % i == 0:
cnt += 1
temp //= i
factors.append((i, cnt))
i += 1
if temp > 1:
factors.append((temp, 1))
divisors = [1]
for (p, exp) in factors:
current_length = len(divisors)
for e in range(1, exp + 1):
pe = p ** e
for d in divisors[:current_length]:
divisors.append(d * pe)
divisors = list(sorted(set(divisors)))
return divisors
result = 1
for l in cycles:
c = cycles[l]
if c == 0:
continue
# Get all divisors of m
m_divisors = get_divisors(m)
valid_s = []
for s in m_divisors:
if m % s != 0:
continue # should be redundant
t = m // s
if gcd(l, t) == 1:
valid_s.append(s)
valid_s = [s for s in valid_s if c % s == 0]
if not valid_s and c > 0:
print(0)
return
dp = [0] * (c + 1)
dp[0] = 1
for s in valid_s:
group_size = s
max_groups = c // group_size
if group_size == 0:
continue
contribution_per_group = (fact[group_size - 1] * pow(l, group_size, MOD)) % MOD
for j in range(c, -1, -1):
if dp[j] == 0:
continue
max_k = (c - j) // group_size
for k in range(1, max_k + 1):
used = k * group_size
new_j = j + used
if new_j > c:
continue
ways = comb(j + used, used)
ways = ways * pow(contribution_per_group, k, MOD) % MOD
denominator = pow(fact[group_size], k, MOD) * fact[k] % MOD
denominator_inv = pow(denominator, MOD - 2, MOD)
ways = ways * denominator_inv % MOD
ways = ways * fact[used] % MOD
dp[new_j] = (dp[new_j] + dp[j] * ways) % MOD
total = dp[c]
result = result * total % MOD
print(result % MOD)
if __name__ == "__main__":
main()
lam6er