結果

問題 No.1978 Permutation Repetition
ユーザー lam6er
提出日時 2025-04-09 20:56:07
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 3,852 bytes
コンパイル時間 597 ms
コンパイル使用メモリ 82,364 KB
実行使用メモリ 76,980 KB
最終ジャッジ日時 2025-04-09 20:56:12
合計ジャッジ時間 4,330 ms
ジャッジサーバーID
(参考情報)
judge1 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 32 WA * 12
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 10**9 + 7

def main():
    import sys
    sys.setrecursionlimit(1 << 25)
    n, m = map(int, sys.stdin.readline().split())
    A = list(map(int, sys.stdin.readline().split()))
    A = [x-1 for x in A]  # convert to 0-based

    # Decompose A into cycles
    visited = [False] * n
    cycles = {}
    for i in range(n):
        if not visited[i]:
            current = i
            cycle = []
            while not visited[current]:
                visited[current] = True
                cycle.append(current)
                current = A[current]
            l = len(cycle)
            if l not in cycles:
                cycles[l] = 0
            cycles[l] += 1

    # Precompute factorial and inverse factorial up to max possible k_l (n)
    max_fact = 10**3
    fact = [1] * (max_fact + 1)
    for i in range(1, max_fact + 1):
        fact[i] = fact[i-1] * i % MOD
    inv_fact = [1] * (max_fact + 1)
    inv_fact[max_fact] = pow(fact[max_fact], MOD-2, MOD)
    for i in range(max_fact-1, -1, -1):
        inv_fact[i] = inv_fact[i+1] * (i+1) % MOD

    # Function to factorize m and generate all divisors
    def factor(m):
        factors = {}
        i = 2
        while i*i <= m:
            while m % i == 0:
                factors[i] = factors.get(i, 0) + 1
                m //= i
            i += 1
        if m > 1:
            factors[m] = 1
        return factors

    def get_divisors(factors):
        divisors = [1]
        for p, exp in factors.items():
            temp = []
            for e in range(1, exp+1):
                pe = p ** e
                for d in divisors:
                    temp.append(d * pe)
            divisors += temp
        return divisors

    factors = factor(m)
    divisors = get_divisors(factors)
    divisors = list(set(divisors))  # Deduplicate
    divisors.sort()

    total_answer = 1

    for l, kl in cycles.items():
        # Find allowed d's: d must divide m, and gcd(l, m/d) ==1
        allowed_ds = []
        for d in divisors:
            m_div_d = m // d
            if (l == 0 or m_div_d == 0):
                continue
            if (gcd(l, m_div_d) == 1):
                allowed_ds.append(d)
        if not allowed_ds:
            if kl == 0:
                continue
            else:
                print(0)
                return

        # Now, use dynamic programming to compute the possible combinations
        dp = [0] * (kl + 1)
        dp[0] = 1
        for d in allowed_ds:
            m_div_d = m // d
            if gcd(l, m_div_d) != 1:
                continue  # skip invalid d (shouldn't happen)
            # Compute a = (l^(d-1) * (d-1)! ) / d
            numerator = pow(l, d-1, MOD) if d >=1 else 0
            if d ==0:
                continue
            factorial_d_minus_1 = fact[d-1] if d-1 <= max_fact else 1
            a = numerator * factorial_d_minus_1 % MOD
            inv_d = pow(d, MOD-2, MOD)
            a = a * inv_d % MOD
            t_max = kl // d
            # Precompute a_t for t from 0 to t_max: a^t / t!
            a_t = [1]*(t_max+1)
            for t in range(1, t_max+1):
                a_t[t] = a_t[t-1] * a % MOD
                inv_t = pow(t, MOD-2, MOD)
                a_t[t] = a_t[t] * inv_t % MOD
            # Update dp (backward)
            for j in range(kl, -1, -1):
                if dp[j] ==0:
                    continue
                for t in range(1, t_max+1):
                    next_j = j + d * t
                    if next_j > kl:
                        break
                    dp[next_j] = (dp[next_j] + dp[j] * a_t[t]) % MOD
        # After processing all d, multiply by kl!
        res = dp[kl] * fact[kl] % MOD
        total_answer = total_answer * res % MOD

    print(total_answer)

def gcd(a, b):
    while b:
        a, b = b, a % b
    return a

if __name__ == "__main__":
    main()
0