結果

問題 No.271 next_permutation (2)
ユーザー gew1fw
提出日時 2025-06-12 15:52:54
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 4,640 bytes
コンパイル時間 398 ms
コンパイル使用メモリ 82,432 KB
実行使用メモリ 118,912 KB
最終ジャッジ日時 2025-06-12 15:53:09
合計ジャッジ時間 6,686 ms
ジャッジサーバーID
(参考情報)
judge1 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other TLE * 1 -- * 20
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
MOD = 10**9 + 7

def main():
    import sys
    sys.setrecursionlimit(1 << 25)
    n, K = map(int, sys.stdin.readline().split())
    p = list(map(int, sys.stdin.readline().split()))
    
    if K == 0:
        print(0)
        return
    
    # Function to compute factorial modulo MOD and check if it's larger than K
    def compute_fact_mod(max_k):
        fact = 1
        for i in range(1, n+1):
            if fact > max_k:
                return (fact, True)
            fact *= i
            if fact > max_k:
                return (fact, True)
        return (fact % MOD, False)
    
    max_fact, exceeds = compute_fact_mod(K)
    m = max_fact if not exceeds else K + 1  # m is larger than K when exceeds
    
    if m <= K:
        # Compute S_total
        S_total = (max_fact * n % MOD) * (n - 1) % MOD
        S_total = S_total * pow(4, MOD-2, MOD) % MOD
        cycles = K // m
        r = K % m
        sum_total = (cycles * S_total) % MOD
        
        # Now compute sum of first r inv(A_i)
        # Generate the first r permutations and compute their inv
        # To avoid TLE for large n, we need to find a way to compute this sum without generating all permutations
        # But for the purpose of this problem, we'll proceed for small n
        
        # However, for cases where n is large, this approach may not be feasible
        # So, we'll only proceed if r is small or n is small
        
        if r == 0:
            print(sum_total % MOD)
            return
        
        # We'll need to generate r permutations and compute inv for each
        # To compute inv quickly, we can use a Fenwick Tree
        # But generating permutations for large n is time-consuming
        # So, this part is only feasible for small n
        
        from itertools import permutations
        
        # Generate the initial permutation
        current = p.copy()
        sum_r = 0
        
        # Compute inv for current
        def count_inversions(arr):
            from bisect import bisect_right, insort
            inv = 0
            seen = []
            for i in reversed(arr):
                pos = bisect_right(seen, i)
                inv += pos
                insort(seen, i)
            return inv
        
        sum_r += count_inversions(current)
        r -= 1
        if r == 0:
            print((sum_total + sum_r) % MOD)
            return
        
        # Now, generate next permutations
        for _ in range(r):
            # Find next permutation
            k = n - 1
            while k > 0 and current[k-1] >= current[k]:
                k -= 1
            if k == 0:
                # Reset to beginning
                current = list(range(1, n+1))
            else:
                # Find the successor to swap with
                l = n - 1
                while current[l] <= current[k-1]:
                    l -= 1
                current[k-1], current[l] = current[l], current[k-1]
                # Reverse the suffix
                current[k:] = current[k:][::-1]
            # Compute inv
            sum_r += count_inversions(current)
            sum_r %= MOD
        
        total = (sum_total + sum_r) % MOD
        print(total)
    else:
        # Compute sum of first K inv(A_i)
        # Again, only feasible for small n
        current = p.copy()
        sum_k = 0
        
        # Compute inv for current
        def count_inversions(arr):
            from bisect import bisect_right, insort
            inv = 0
            seen = []
            for i in reversed(arr):
                pos = bisect_right(seen, i)
                inv += pos
                insort(seen, i)
            return inv
        
        sum_k += count_inversions(current)
        K -= 1
        if K == 0:
            print(sum_k % MOD)
            return
        
        # Generate next permutations
        for _ in range(K):
            # Find next permutation
            k = n - 1
            while k > 0 and current[k-1] >= current[k]:
                k -= 1
            if k == 0:
                # Reset to beginning
                current = list(range(1, n+1))
            else:
                # Find the successor to swap with
                l = n - 1
                while current[l] <= current[k-1]:
                    l -= 1
                current[k-1], current[l] = current[l], current[k-1]
                # Reverse the suffix
                current[k:] = current[k:][::-1]
            # Compute inv
            sum_k += count_inversions(current)
            sum_k %= MOD
        
        print(sum_k % MOD)

if __name__ == '__main__':
    main()
0