結果
問題 | No.1437 01 Sort |
ユーザー |
|
提出日時 | 2021-03-23 22:27:42 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 610 ms / 2,000 ms |
コード長 | 3,266 bytes |
コンパイル時間 | 138 ms |
コンパイル使用メモリ | 81,968 KB |
実行使用メモリ | 108,748 KB |
最終ジャッジ日時 | 2024-11-26 02:27:40 |
合計ジャッジ時間 | 8,183 ms |
ジャッジサーバーID (参考情報) | judge1 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 24 |
ソースコード
# yukicoder No.1437 "01 Sort".
# `solve` is the fast solution used by `main`; `slow` is a brute-force BFS
# over the same moves, and `test` cross-checks the two on small inputs.
from functools import lru_cache

# `test` is a development self-check with a required argument; it is kept out
# of the star-import surface so collectors matching `test*` names skip it.
__all__ = ["solve", "slow", "main"]


def solve(N, S):
    """Return the minimum number of moves needed to sort binary string S.

    Args:
        N: length of S.
        S: string of '0' / '1' characters.

    Returns:
        0 if S is already sorted, otherwise the minimum number of moves.
        The move model is the one implemented by ``slow``.
    """
    if list(S) == sorted(S):
        return 0

    # Work on S doubled so every rotation of S is a contiguous window.
    # Prefix counts: the number of '0's in S2[a:b] is
    # zero_cnt[b] - zero_cnt[a] (and likewise for one_cnt).
    S2 = S * 2
    zero_cnt = [0]
    one_cnt = [0]
    for ch in S2:
        zero_cnt.append(zero_cnt[-1] + int(ch == '0'))
        one_cnt.append(one_cnt[-1] + int(ch == '1'))

    @lru_cache(maxsize=10)
    def calc(l, t, r):
        # Cost of one candidate: l and r are positions of '1's in S2 and
        # t is a split point with l <= t <= r.
        left0 = zero_cnt[t + 1] - zero_cnt[l]     # '0's in S2[l..t]
        right1 = one_cnt[r + 1] - one_cnt[t + 1]  # '1's in S2[t+1..r]
        end = t + right1
        # Offset (mod N) that brings index `end` to index N - 1.
        rem = (N - 1 - end) % N
        assert l <= t <= end <= r
        if end < N:
            if right1:
                # rskip is 1 iff S2[t+1..N] contains a '1'.
                rskip = int(bool(one_cnt[N + 1] - one_cnt[t + 1]))
                right1 = right1 + 1 - rskip
            loop = max(left0, right1)
        elif t < N:
            rskip = int(bool(one_cnt[N + 1] - one_cnt[t + 1]))
            loop = max(left0 - 1, right1 - rskip)
        else:
            # lskip is 1 iff S2[l+1..N] contains a '0'.
            lskip = int(bool(zero_cnt[N + 1] - zero_cnt[l + 1]))
            loop = max(left0 - lskip, right1)
        loop = max(0, loop)
        # presumably `loop` counts full passes of N moves each — TODO confirm
        return loop * N + rem

    ones = [i for i, ch in enumerate(S2) if ch == '1']
    k = one_cnt[N]  # number of '1's in S (>= 1 here: all-'0' S is sorted)
    answer = 10 ** 18
    t = -1
    for i in range(k):
        # Window of S2 spanning exactly k ones, starting at the i-th one.
        l = ones[i]
        r = ones[i + k - 1]
        # The split point carries over from the previous window; this relies
        # on the best t being non-decreasing in l (assumption — TODO confirm).
        t = max(l, t)
        # Walk t rightwards over plateaus of calc while the value keeps
        # dropping; stop at the first strict increase or after r.
        while True:
            c = calc(l, t, r)
            nt = t + 1
            while nt <= r and c == calc(l, nt, r):
                nt += 1
            if nt == r + 1 or c < calc(l, nt, r):
                break
            t = nt
        answer = min(answer, c)
    return answer


def slow(N, S):
    """Brute-force reference: breadth-first search over the two moves.

    From a state X (a list of characters) one move yields either
    ``[X[-1]] + X[:-1]`` (rotate right by one) or ``[X[0], X[-1]] + X[1:-1]``
    (move the last character to index 1). Returns the BFS depth at which
    ``sorted(S)`` is first reached.

    As a debugging aid it prints the sequence of states from S to the
    sorted string, then each state rotated left by (step index mod N).

    Only practical for small N.

    Raises:
        ValueError: if the sorted arrangement is unreachable.
    """
    parent = {}
    sorted_S = sorted(S)
    q = [(0, list(S))]
    while q:
        nq = []
        for t, Sp in q:
            if Sp == sorted_S:
                trace = []
                cur = ''.join(Sp)
                while cur != S:
                    trace.append(cur)
                    cur = parent[cur]
                trace.append(S)
                print(trace[::-1])
                for i, state in enumerate(reversed(trace)):
                    k = i % N
                    print(state[k:] + state[:k])
                return t
            cur = ''.join(Sp)
            for nxt in ([Sp[-1]] + Sp[:-1], [Sp[0], Sp[-1]] + Sp[1:-1]):
                key = ''.join(nxt)
                # Enqueue a state only the first time it is generated.
                # Re-expanding duplicates cannot reach anything new, so the
                # depth and the parent chain are unchanged.
                if key not in parent:
                    parent[key] = cur
                    nq.append((t + 1, nxt))
        q = nq
    raise ValueError("sorted arrangement is unreachable")


def main():
    """Read N and S from stdin and print the answer from ``solve``."""
    N = int(input())
    S = input().strip()
    print(solve(N, S))


def test(t):
    """Cross-check ``solve`` against ``slow``.

    Uses the binary representations of 0..t-1 and their bit complements.
    Raises AssertionError on the first mismatch.
    """
    swap01 = str.maketrans('01', '10')
    for r in range(t):
        S = bin(r)[2:]
        N = len(S)
        ac = slow(N, S)
        wa = solve(N, S)
        assert ac == wa, (S, ac, wa)
        S = S.translate(swap01)
        ac = slow(N, S)
        wa = solve(N, S)
        assert ac == wa, (S, ac, wa)


if __name__ == "__main__":
    main()
    # test(1000)  # cross-check against the brute force