結果
| 問題 |
No.271 next_permutation (2)
|
| コンテスト | |
| ユーザー |
lam6er
|
| 提出日時 | 2025-04-16 00:16:14 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 2,674 bytes |
| コンパイル時間 | 318 ms |
| コンパイル使用メモリ | 82,644 KB |
| 実行使用メモリ | 548,544 KB |
| 最終ジャッジ日時 | 2025-04-16 00:17:58 |
| 合計ジャッジ時間 | 4,792 ms |
|
ジャッジサーバーID (参考情報) |
judge2 / judge1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| other | AC * 8 WA * 9 MLE * 1 -- * 3 |
ソースコード
# Prime modulus applied to every printed total.
MOD = 1_000_000_007
def is_last_permutation(p, n):
    """Return True if the first *n* items of *p* are ``n, n-1, ..., 1``.

    That strictly descending sequence is the lexicographically greatest
    permutation of ``1..n``.

    Args:
        p: Indexable sequence holding at least *n* items.
        n: Number of leading items to check.

    Returns:
        True when ``p[i] == n - i`` for every ``i`` in ``range(n)``
        (vacuously True when ``n == 0``), otherwise False.

    Raises:
        IndexError: If *p* is shorter than *n* and no mismatch is found
            before the end of *p* is reached.
    """
    return all(p[i] == n - i for i in range(n))
def next_permutation(arr):
    """Advance *arr* in place to its next lexicographic arrangement.

    Args:
        arr: Mutable list of mutually comparable items.

    Returns:
        True if *arr* was advanced. False if *arr* is already
        non-increasing (including empty and single-element lists), in
        which case *arr* is left unmodified.
    """
    n = len(arr)

    # Find the rightmost position whose item is smaller than its successor.
    i = n - 2
    while i >= 0 and arr[i] >= arr[i + 1]:
        i -= 1
    # ``i < 0`` rather than ``i == -1``: for an empty list ``i`` starts at
    # -2, and falling through would index into the empty list below.
    if i < 0:
        return False

    # Find the rightmost item strictly greater than the pivot; it exists
    # because arr[i + 1] > arr[i].
    j = n - 1
    while arr[j] <= arr[i]:
        j -= 1

    arr[i], arr[j] = arr[j], arr[i]
    # The tail is non-increasing after the swap; reversing it yields the
    # smallest possible suffix.
    arr[i + 1:] = arr[i + 1:][::-1]
    return True
def inversion_number(arr):
    """Count pairs ``i < j`` with ``arr[i] > arr[j]``.

    Args:
        arr: Sequence of mutually comparable items; equal items do not
            form an inversion.

    Returns:
        The number of inversions as an int (0 for empty or sorted input).
    """
    from bisect import bisect_left, insort

    # Walk right-to-left while keeping the already-visited suffix sorted.
    # The insertion point of ``value`` (leftmost, so equal items are not
    # counted) is the number of items to its right that are strictly
    # smaller, i.e. the inversions that start at ``value``.
    suffix_sorted = []
    inversions = 0
    for value in reversed(arr):
        inversions += bisect_left(suffix_sorted, value)
        insort(suffix_sorted, value)
    return inversions
def main():
    """Read ``N K p_1 .. p_N`` from stdin and print one total modulo MOD.

    The total accumulates the inversion numbers of K successive
    permutations, starting at the input permutation and wrapping from
    the descending permutation back to ``1..N``. Three strategies are used:

    * the input is already the descending permutation: closed form that
      alternates ``N*(N-1)/2`` and ``0``;
    * ``N <= 20``: explicit simulation with cycle detection;
    * otherwise: ``K * (N*(N-1)//4)``.

    NOTE(review): the problem statement is not visible from this file, so
    the branches are preserved as written. Confirm each one against the
    intended specification before relying on the output.
    """
    import sys

    # Named ``tokens`` so the built-in ``input`` is not shadowed.
    tokens = sys.stdin.read().split()
    N = int(tokens[0])
    K = int(tokens[1])
    p = list(map(int, tokens[2:2 + N]))

    if K == 0:
        print(0)
        return

    if is_last_permutation(p, N):
        # NOTE(review): this treats the walk as a 2-cycle (descending
        # permutation, then one with zero inversions). That disagrees with
        # the simulation branch below for N >= 3 -- verify.
        inv_p = N * (N - 1) // 2
        full_cycles, remainder = divmod(K, 2)
        total = (full_cycles * inv_p) % MOD
        if remainder:
            total = (total + inv_p) % MOD
        print(total)
        return

    if N > 20:
        # NOTE(review): presumably an "average inversion count times K"
        # estimate; the floor division truncates when N*(N-1)/2 is odd
        # -- verify.
        term = (N * (N - 1) // 4) % MOD
        print((K % MOD) * term % MOD)
        return

    # Simulation: record the inversion number of every permutation visited
    # and remember at which step each permutation was first seen.
    # NOTE(review): every visited permutation is kept in memory and the loop
    # runs up to K iterations unless a repeat is found first, so large K
    # with N near 20 is not feasible here.
    current = p.copy()
    sum_inv = 0
    history = []
    first_seen = {}
    steps = 0
    while steps < K:
        key = tuple(current)
        if key in first_seen:
            # A repeat closes a cycle: account for the remaining steps
            # arithmetically instead of simulating them.
            start = first_seen[key]
            cycle_length = len(history) - start
            cycles, remainder = divmod(K - steps, cycle_length)
            sum_inv += sum(history[start:]) * cycles
            sum_inv += sum(history[start:start + remainder])
            sum_inv %= MOD
            break
        first_seen[key] = len(history)
        inv = inversion_number(current)
        sum_inv = (sum_inv + inv) % MOD
        history.append(inv)
        steps += 1
        if not next_permutation(current):
            # Wrap around from the descending permutation to 1..N.
            current = list(range(1, N + 1))
    print(sum_inv % MOD)
# Run the solver only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
lam6er