結果

問題 No.562 超高速一人かるた small
ユーザー lam6er
提出日時 2025-03-20 20:51:45
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 89 ms / 3,000 ms
コード長 2,959 bytes
コンパイル時間 159 ms
コンパイル使用メモリ 82,380 KB
実行使用メモリ 76,760 KB
最終ジャッジ日時 2025-03-20 20:51:56
合計ジャッジ時間 2,411 ms
ジャッジサーバーID
(参考情報)
judge4 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 21
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
MOD = 10**9 + 7
from bisect import bisect_left, bisect_right

def main():
    N = int(sys.stdin.readline())
    S = [sys.stdin.readline().strip() for _ in range(N)]
    
    # Compute pairwise l(x, y) for x != y
    l_matrix = [[0] * N for _ in range(N)]
    for x in range(N):
        for y in range(N):
            if x == y:
                continue
            s_x = S[x]
            s_y = S[y]
            pos = 0
            while pos < len(s_x) and pos < len(s_y) and s_x[pos] == s_y[pos]:
                pos += 1
            l_matrix[x][y] = pos + 1
    
    # Precompute combinations C(n, k)
    max_comb = 20
    comb = [[0] * (max_comb + 1) for _ in range(max_comb + 1)]
    comb[0][0] = 1
    for n in range(1, max_comb + 1):
        comb[n][0] = 1
        for k in range(1, n + 1):
            comb[n][k] = (comb[n-1][k] + comb[n-1][k-1]) % MOD
    
    # Precompute factorials and inverse factorials
    fact = [1] * (21)
    for i in range(1, 21):
        fact[i] = fact[i-1] * i % MOD
    inv_fact = [1] * (21)
    inv_fact[20] = pow(fact[20], MOD - 2, MOD)
    for i in range(19, -1, -1):
        inv_fact[i] = inv_fact[i+1] * (i+1) % MOD
    
    def perm(a, b):
        if a < 0 or b < 0 or a < b:
            return 0
        return fact[a] * inv_fact[a - b] % MOD
    
    # Precompute sum_for_s[x][s]
    sum_for_s = [[0] * (N + 2) for _ in range(N)]
    for x in range(N):
        l_list = []
        for y in range(N):
            if y != x:
                l_list.append(l_matrix[x][y])
        sorted_l = sorted(l_list)
        M = len(l_list)
        for s in range(1, N + 1):
            if s == 1:
                sum_for_s[x][s] = 1
                continue
            k = s - 1
            if k > M:
                sum_for_s[x][s] = 0
                continue
            unique = sorted(list(set(sorted_l)))
            current_sum = 0
            for l in unique:
                cnt_le = bisect_right(sorted_l, l)
                cnt_lt = bisect_left(sorted_l, l)
                if cnt_le < k:
                    c_le = 0
                else:
                    c_le = comb[cnt_le][k]
                if cnt_lt < k:
                    c_lt = 0
                else:
                    c_lt = comb[cnt_lt][k]
                delta = (c_le - c_lt) % MOD
                current_sum = (current_sum + l * delta) % MOD
            sum_for_s[x][s] = current_sum
    
    # Process each K from 1 to N
    for K in range(1, N + 1):
        ans = 0
        for m in range(1, K + 1):
            s_size = N - (m - 1)
            if s_size < 1 or s_size > N:
                continue
            total = 0
            for x in range(N):
                total = (total + sum_for_s[x][s_size]) % MOD
            factor = (fact[m-1] * perm(N - m, K - m)) % MOD
            contribution = (factor * total) % MOD
            ans = (ans + contribution) % MOD
        print(ans % MOD)

if __name__ == "__main__":
    main()
0