結果

問題 No.1978 Permutation Repetition
ユーザー lam6er
提出日時 2025-04-09 20:56:09
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 3,544 bytes
コンパイル時間 325 ms
コンパイル使用メモリ 82,284 KB
実行使用メモリ 71,284 KB
最終ジャッジ日時 2025-04-09 20:57:15
合計ジャッジ時間 5,353 ms
ジャッジサーバーID
(参考情報)
judge2 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 13 WA * 31
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
import math
from math import gcd
from collections import defaultdict

MOD = 10**9 + 7

def main():
    sys.setrecursionlimit(1 << 25)
    n, m = map(int, sys.stdin.readline().split())
    a = list(map(int, sys.stdin.readline().split()))
    a = [x for x in a]

    # Compute cycle decomposition of A
    visited = [False] * (n + 1)
    cycles = defaultdict(int)
    for i in range(1, n + 1):
        if not visited[i]:
            current = i
            length = 0
            while not visited[current]:
                visited[current] = True
                current = a[current - 1]
                length += 1
            cycles[length] += 1

    # Precompute factorial and inverse factorial up to n
    fact = [1] * (n + 1)
    for i in range(1, n + 1):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * (n + 1)
    inv_fact[n] = pow(fact[n], MOD - 2, MOD)
    for i in range(n - 1, -1, -1):
        inv_fact[i] = inv_fact[i + 1] * (i + 1) % MOD

    def comb(a, b):
        if a < 0 or b < 0 or a < b:
            return 0
        return fact[a] * inv_fact[b] % MOD * inv_fact[a - b] % MOD

    def get_divisors(num):
        factors = []
        temp = num
        i = 2
        while i * i <= temp:
            if temp % i == 0:
                cnt = 0
                while temp % i == 0:
                    cnt += 1
                    temp //= i
                factors.append((i, cnt))
            i += 1
        if temp > 1:
            factors.append((temp, 1))
        divisors = [1]
        for (p, exp) in factors:
            current_length = len(divisors)
            for e in range(1, exp + 1):
                pe = p ** e
                for d in divisors[:current_length]:
                    divisors.append(d * pe)
        divisors = list(sorted(set(divisors)))
        return divisors

    result = 1

    for l in cycles:
        c = cycles[l]
        if c == 0:
            continue

        # Get all divisors of m
        m_divisors = get_divisors(m)
        valid_s = []
        for s in m_divisors:
            if m % s != 0:
                continue  # should be redundant
            t = m // s
            if gcd(l, t) == 1:
                valid_s.append(s)
        valid_s = [s for s in valid_s if c % s == 0]
        if not valid_s and c > 0:
            print(0)
            return

        dp = [0] * (c + 1)
        dp[0] = 1

        for s in valid_s:
            group_size = s
            max_groups = c // group_size
            if group_size == 0:
                continue

            contribution_per_group = (fact[group_size - 1] * pow(l, group_size, MOD)) % MOD

            for j in range(c, -1, -1):
                if dp[j] == 0:
                    continue
                max_k = (c - j) // group_size
                for k in range(1, max_k + 1):
                    used = k * group_size
                    new_j = j + used
                    if new_j > c:
                        continue
                    ways = comb(j + used, used)
                    ways = ways * pow(contribution_per_group, k, MOD) % MOD

                    denominator = pow(fact[group_size], k, MOD) * fact[k] % MOD
                    denominator_inv = pow(denominator, MOD - 2, MOD)
                    ways = ways * denominator_inv % MOD
                    ways = ways * fact[used] % MOD

                    dp[new_j] = (dp[new_j] + dp[j] * ways) % MOD

        total = dp[c]
        result = result * total % MOD

    print(result % MOD)

if __name__ == "__main__":
    main()
0