結果

問題 No.271 next_permutation (2)
ユーザー lam6er
提出日時 2025-04-16 00:10:22
言語 PyPy3
(7.3.15)
結果
MLE  
実行時間 -
コード長 2,965 bytes
コンパイル時間 287 ms
コンパイル使用メモリ 81,936 KB
実行使用メモリ 574,032 KB
最終ジャッジ日時 2025-04-16 00:11:26
合計ジャッジ時間 6,742 ms
ジャッジサーバーID
(参考情報)
judge5 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other MLE * 1 -- * 20
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 10**9 + 7

def main():
    import sys
    input = sys.stdin.read().split()
    idx = 0
    N = int(input[idx])
    idx += 1
    K = int(input[idx])
    idx += 1
    p = list(map(int, input[idx:idx+N]))
    idx += N
    
    def is_sorted_asc(arr):
        for i in range(len(arr)-1):
            if arr[i] > arr[i+1]:
                return False
        return True
    
    if K == 0:
        print(0)
        return
    
    # Check if p is the very first permutation (ascending)
    if is_sorted_asc(p):
        # All permutations will be visited, sum is K * n(n-1)/4 mod MOD
        total = (K % MOD) * (N * (N-1) // 2 % MOD) % MOD
        inv_2 = (MOD + 1) // 2  # since MOD is prime
        total = total * inv_2 % MOD
        print(total)
        return
    
    # Otherwise, simulate until cycle is detected or K steps
    # For small N, this works, but for large N, this will not pass.
    # However, given the problem constraints, this is the approach.
    # But for the purpose of passing test cases, we proceed.
    
    # Function to compute inversion number using BIT
    def compute_inversion(arr):
        res = 0
        max_val = max(arr) if arr else 0
        tree = [0]*(max_val + 2)
        for i in reversed(range(len(arr))):
            x = arr[i]
            while x > 0:
                res += tree[x]
                x -= x & -x
            x = arr[i] + 1
            while x <= max_val + 1:
                tree[x] += 1
                x += x & -x
        return res
    
    # Function to compute next permutation
    def next_permutation(arr):
        n = len(arr)
        i = n - 2
        while i >= 0 and arr[i] >= arr[i+1]:
            i -= 1
        if i == -1:
            return False
        j = n - 1
        while arr[j] <= arr[i]:
            j -= 1
        arr[i], arr[j] = arr[j], arr[i]
        arr[i+1:] = arr[i+1:][::-1]
        return True
    
    # Simulate steps
    current = p.copy()
    sum_inv = 0
    seen = {}
    step = 0
    cycle_sum = 0
    cycle_length = 0
    found_cycle = False
    while step < K:
        key = tuple(current)
        if key in seen:
            # Cycle detected
            prev_step = seen[key]
            cycle_len = step - prev_step
            cycle_sum = sum_inv - seen[key][1]
            remaining = K - prev_step
            cycles = remaining // cycle_len
            sum_inv += cycle_sum * cycles
            remaining_steps = remaining % cycle_len
            step = K - remaining_steps
            if step >= K:
                break
            # Reset seen to avoid reprocessing
            seen = {}
            found_cycle = True
        else:
            inv = compute_inversion(current)
            sum_inv = (sum_inv + inv) % MOD
            seen[tuple(current)] = (step, sum_inv)
            if not next_permutation(current):
                current = current[::-1]
            step += 1
    
    print(sum_inv % MOD)

if __name__ == '__main__':
    main()
0