結果

問題 No.2959 Dolls' Tea Party
ユーザー gew1fw
提出日時 2025-06-12 21:18:09
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 4,997 bytes
コンパイル時間 212 ms
コンパイル使用メモリ 82,304 KB
実行使用メモリ 129,672 KB
最終ジャッジ日時 2025-06-12 21:18:27
合計ジャッジ時間 7,183 ms
ジャッジサーバーID
(参考情報)
judge1 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 5 TLE * 1 -- * 27
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
MOD = 998244353

def main():
    from collections import defaultdict

    N, K = map(int, sys.stdin.readline().split())
    A = list(map(int, sys.stdin.readline().split()))

    # Precompute factorial and inverse factorial up to 1300
    max_m = 1300
    factorial = [1] * (max_m + 1)
    for i in range(1, max_m + 1):
        factorial[i] = factorial[i - 1] * i % MOD
    inv_fact = [1] * (max_m + 1)
    inv_fact[max_m] = pow(factorial[max_m], MOD - 2, MOD)
    for i in range(max_m - 1, -1, -1):
        inv_fact[i] = inv_fact[i + 1] * (i + 1) % MOD

    # Function to get divisors of K
    def get_divisors(n):
        divisors = []
        for i in range(1, int(n**0.5) + 1):
            if n % i == 0:
                divisors.append(i)
                if i != n // i:
                    divisors.append(n // i)
        divisors.sort()
        return divisors

    divisors = get_divisors(K)
    phi = {}
    for d in divisors:
        m = d
        p = m
        for i in range(2, int(m**0.5) + 1):
            if m % i == 0:
                p = p // i * (i - 1)
                while m % i == 0:
                    m //= i
        if m > 1:
            p = p // m * (m - 1)
        phi[d] = p

    total = 0

    for d in divisors:
        m = K // d
        if m == 0:
            continue

        B = []
        S_count = 0
        T_count = 0
        groups_S = defaultdict(int)
        for a in A:
            bi = a // d
            bi = min(bi, m)
            if bi < m:
                groups_S[bi] += 1
            else:
                T_count += 1

        # Compute P_S(x)
        P = [0] * (m + 1)
        P[0] = 1
        for b, cnt in groups_S.items():
            if cnt == 0:
                continue
            # Generate the polynomial for b: sum_{c=0}^b x^c / c!
            poly = [0] * (b + 1)
            for c in range(b + 1):
                poly[c] = inv_fact[c]
            # Compute poly^cnt using exponentiation by squaring
            current = [1] + [0] * m
            exp = cnt
            while exp > 0:
                if exp % 2 == 1:
                    new_current = [0] * (len(current) + len(poly) - 1)
                    for i in range(len(current)):
                        if current[i] == 0:
                            continue
                        for j in range(len(poly)):
                            if i + j > m:
                                break
                            new_current[i + j] = (new_current[i + j] + current[i] * poly[j]) % MOD
                    current = new_current[:m + 1]
                poly_sq = [0] * (2 * len(poly) - 1)
                for i in range(len(poly)):
                    for j in range(len(poly)):
                        if i + j > m:
                            break
                        poly_sq[i + j] = (poly_sq[i + j] + poly[i] * poly[j]) % MOD
                poly = poly_sq[:m + 1]
                exp //= 2
            # Multiply current into P
            new_P = [0] * (m + 1)
            for i in range(len(P)):
                if P[i] == 0:
                    continue
                for j in range(len(current)):
                    if i + j > m:
                        break
                    new_P[i + j] = (new_P[i + j] + P[i] * current[j]) % MOD
            P = new_P

        # Compute Q(x) = (sum_{c=0}^m x^c / c! )^T_count
        sum_poly = [0] * (m + 1)
        for c in range(m + 1):
            sum_poly[c] = inv_fact[c]
        Q = [0] * (m + 1)
        if T_count == 0:
            Q[0] = 1
        else:
            Q = [1] + [0] * m
            exp = T_count
            current_poly = sum_poly.copy()
            while exp > 0:
                if exp % 2 == 1:
                    new_Q = [0] * (len(Q) + len(current_poly) - 1)
                    for i in range(len(Q)):
                        if Q[i] == 0:
                            continue
                        for j in range(len(current_poly)):
                            if i + j > m:
                                break
                            new_Q[i + j] = (new_Q[i + j] + Q[i] * current_poly[j]) % MOD
                    Q = new_Q[:m + 1]
                new_poly = [0] * (2 * len(current_poly) - 1)
                for i in range(len(current_poly)):
                    for j in range(len(current_poly)):
                        if i + j > m:
                            break
                        new_poly[i + j] = (new_poly[i + j] + current_poly[i] * current_poly[j]) % MOD
                current_poly = new_poly[:m + 1]
                exp //= 2

        # Multiply P and Q, take coefficient of x^m
        res = 0
        for s in range(m + 1):
            if s > len(P) - 1 or (m - s) > len(Q) - 1:
                continue
            res = (res + P[s] * Q[m - s]) % MOD

        res = res * factorial[m] % MOD
        total = (total + res * phi[d]) % MOD

    ans = total * pow(K, MOD - 2, MOD) % MOD
    print(ans)

if __name__ == '__main__':
    main()
0