結果
| 問題 | No.1555 Constructed Balancing Sequence |
| コンテスト | |
| ユーザー | |
| 提出日時 | 2021-01-05 18:51:59 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | RE (最新) / AC (最初) |
| 実行時間 | - |
| コード長 | 17,790 bytes |
| コンパイル時間 | 175 ms |
| コンパイル使用メモリ | 81,976 KB |
| 実行使用メモリ | 265,424 KB |
| 最終ジャッジ日時 | 2024-06-22 19:28:26 |
| 合計ジャッジ時間 | 6,131 ms |
| ジャッジサーバーID (参考情報) | judge5 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | AC * 10 RE * 5 TLE * 1 -- * 26 |
ソースコード
def solve_NK3(N,K,A):
    """Count configurations for (N, K, A) with a memoised top-down recursion.

    Earliest of the DP variants in this file; the driver at the bottom of the
    file does not call it.

    Args:
        N: length of ``A``.
        K: bound used by the ``-K <= ... <= K`` range checks below.
        A: input sequence.

    Returns:
        The count modulo 998244353, or 0 as soon as some ``A[i]`` (i >= 1)
        exceeds ``A[0] + ... + A[i-1]``.
    """
    # diff[0] = A[0]; diff[i] = (A[0] + ... + A[i-1]) - A[i] for i >= 1.
    diff = [A[0] for i in range(N)]
    S = A[0]
    for i in range(1,N):
        diff[i] = S - A[i]
        if diff[i] < 0:
            # A[i] is larger than the prefix sum before it: no valid configuration.
            return 0
        S += A[i]
    mod = 998244353
    memo = {}
    def dp(i,l,r,s):
        # Memoised on (i, l, r, s).  Returns 0 for an empty window (l > r);
        # at i == 0 returns 1 iff diff[0] - s lies in both [l, r] and [-K, K].
        # NOTE(review): the recursion goes N levels deep; Python's default
        # recursion limit (~1000) is not raised anywhere in view.
        if (i,l,r,s) in memo:
            return memo[i,l,r,s]
        if l>r:
            return 0
        if not i:
            first = diff[0] - s
            if l<=first<=r and -K<=first<=K:
                return 1
            else:
                return 0
        res = 0
        # Transition 1 (only when diff[i] != 0): narrow the window to [L, R]
        # for level i-1 with s reset to 0.
        if diff[i]:
            L = max((l+s+diff[i]+1)//2,-K+s+diff[i])
            R = min((r+s+diff[i])//2,K+s+diff[i])
            res += dp(i-1,L,R,0)
            res %= mod
        # Transition 2: requires s%2 == diff[i] (so diff[i] is 0 or 1), l even,
        # and r//2 within [-K, K]; sums over new_s = 1 .. 3*K.
        if s%2==diff[i] and (l+1)%2==1 and -K<=r//2<=K:
            for new_s in range(1,3*K+1):
                res += dp(i-1,r//2-new_s,r//2-new_s,s//2+diff[i]+new_s)
                res %= mod
        # Transition 3: diff[i] <= 1 and the halved window collapses to one value.
        if diff[i]<=1 and (l+s-diff[i]+1)//2==(r+s-diff[i])//2:
            m = -K+s-diff[i]-((l+s-diff[i]+1)//2)
            M = K+s-diff[i]-((r+s-diff[i])//2)
            for k in range(max(m,0),min(M,(s-diff[i])//2)+1):
                res += dp(i-1,(r+s-diff[i])//2-k,(r+s-diff[i])//2-k,diff[i]+k)
                res %= mod
        # Range of L (i.e. S) passed to dp[i-1]:
        # when diff[i] >= 2:
        #   -K+diff[i] <= S <= 7*K+diff[i]
        # when diff[i] == 0 or 1:
        #   -K+diff[i] <= S <= 7*K+diff[i]
        #   -2*K <= S <= K
        #   -K <= S <= 7*K
        #   -> -2*K <= S <= 7*K+1
        memo[i,l,r,s] = res
        return memo[i,l,r,s]
    return dp(N-1,-N*K,N*K,0)
def solve_NK2_cum_WRONG(N,K,A):
    """Bottom-up table DP using 2-D cumulative sums (flagged WRONG by its author).

    The ``_WRONG`` suffix suggests the author found this variant to give
    incorrect answers; the driver at the bottom of the file does not call it.

    Args:
        N: length of ``A`` (rebound locally after truncation below).
        K: bound used by the ``-K <= ... <= K`` range checks; also sizes the
            tables (``10*K`` sum offsets by ``6*K+2`` "minus" values).
        A: input sequence.

    Returns:
        The count modulo 998244353, or 0 as soon as some ``A[i]`` (i >= 1)
        exceeds ``A[0] + ... + A[i-1]``.
    """
    # diff[0] = A[0]; diff[i] = (A[0] + ... + A[i-1]) - A[i] for i >= 1.
    diff = [A[0] for i in range(N)]
    S = A[0]
    for i in range(1,N):
        diff[i] = S - A[i]
        if diff[i] < 0:
            return 0
        S += A[i]
    # Truncate after the last index whose diff is 0 and rebind N to the new
    # length; if no diff is 0 the sequence is left unchanged.
    for i in range(N-1,-1,-1):
        if diff[i]==0:
            N = i+1
            diff = diff[:i+1]
            break
    # Sentinel so that diff[i+1] is defined for the last layer.
    diff.append(0)
    mod = 998244353
    # dp[dp_S][minus]; dp_S is an offset index with real_S = dp_S + diff[i+1] - 2*K - 1.
    dp = [[0 for minus in range(6*K+2)] for dp_S in range(10*K)]
    # Layer 0: a state is valid when first = diff[0]-minus equals real_S and lies in [-K, K].
    for dp_S in range(10*K):
        real_S = diff[1] + dp_S - 2*K - 1
        for minus in range(6*K+2):
            first = diff[0] - minus
            if first==real_S and -K<=first<=K:
                dp[dp_S][minus] = 1
    # cum[dp_S][minus] accumulates dp along the anti-diagonal (dp_S-1, minus+1).
    cum = [[dp[dp_S][minus] for minus in range(6*K+2)] for dp_S in range(10*K)]
    for dp_S in range(1,10*K):
        for minus in range(6*K+2):
            if minus<6*K+1:
                cum[dp_S][minus] += cum[dp_S-1][minus+1]
    for i in range(1,N):
        ndp = [[0 for minus in range(6*K+2)] for dp_S in range(10*K)]
        for dp_S in range(10*K):
            real_S = dp_S + diff[i+1] - 2*K - 1
            for minus in range(6*K+2):
                # Transition 1 (diff[i] != 0): single predecessor with minus == 0.
                if diff[i]:
                    L = max((real_S+minus+diff[i]+1)//2,-K+minus+diff[i])
                    R = min((real_S+minus+diff[i])//2,K+minus+diff[i])
                    if L==R:
                        pre_dp_S = L - diff[i] + 2*K + 1
                        if 0<=pre_dp_S<10*K:
                            ndp[dp_S][minus] += dp[pre_dp_S][0]
                            ndp[dp_S][minus] %= mod
                # Transition 2: range of predecessors read from the anti-diagonal sums.
                if minus%2==diff[i] and real_S%2==0 and -K<=real_S//2<=K:
                    pre_dp_S_R = min(10*K-1,real_S//2-1 - diff[i] + 2*K+1)
                    pre_dp_S_L = max(0,real_S//2-3*K - diff[i] + 2*K+1)
                    if pre_dp_S_L<=pre_dp_S_R:
                        ndp[dp_S][minus] += cum[pre_dp_S_R][real_S//2+minus//2+2*K+1-pre_dp_S_R] - cum[pre_dp_S_L][real_S//2+minus//2+2*K+1-pre_dp_S_L] + dp[pre_dp_S_L][real_S//2+minus//2+2*K+1-pre_dp_S_L]
                        ndp[dp_S][minus] %= mod
                # Transition 3 (diff[i] <= 1): another anti-diagonal range.
                if diff[i]<=1 and (real_S+minus-diff[i])%2==0:
                    m = max(0,-K+minus-diff[i]-(real_S+minus-diff[i])//2)
                    M = min((minus-diff[i])//2,K+minus-diff[i]-(real_S+minus-diff[i])//2)
                    L = min(10*K-1,(real_S+minus-diff[i])//2-diff[i]-m+2*K+1)
                    R = max(0,(real_S+minus-diff[i])//2-diff[i]-M+2*K+1)
                    if L>=R:
                        ndp[dp_S][minus] += cum[L][(real_S+minus-diff[i])//2+2*K+1-L] - cum[R][(real_S+minus-diff[i])//2+2*K+1-R] + dp[R][(real_S+minus-diff[i])//2+2*K+1-R]
                        ndp[dp_S][minus] %= mod
        ncum = [[ndp[dp_S][minus] for minus in range(6*K+2)] for dp_S in range(10*K)]
        for dp_S in range(1,10*K):
            for minus in range(6*K+2):
                if minus<6*K+1:
                    ncum[dp_S][minus] += ncum[dp_S-1][minus+1]
        dp,cum = ndp,ncum
    res = 0
    # NOTE(review): dp_S == 0 is skipped in this final sum -- intentional?
    for dp_S in range(1,10*K):
        res += dp[dp_S][0]
        res %= mod
    return res
def solve_NK2_memo(N,K,A):
    """Count configurations for (N, K, A) with a memoised recursion over (i, dp_S, minus).

    The driver at the bottom of the file does not call this variant.

    Args:
        N: length of ``A``.
        K: bound used by the ``-K <= ... <= K`` range checks below.
        A: input sequence.

    Returns:
        The count modulo 998244353, or 0 as soon as some ``A[i]`` (i >= 1)
        exceeds ``A[0] + ... + A[i-1]``.
    """
    # diff[0] = A[0]; diff[i] = (A[0] + ... + A[i-1]) - A[i] for i >= 1.
    diff = [A[0] for i in range(N)]
    S = A[0]
    for i in range(1,N):
        diff[i] = S - A[i]
        if diff[i] < 0:
            return 0
        S += A[i]
    # Sentinel so that diff[i+1] is defined for i == N-1.
    diff.append(0)
    mod = 998244353
    memo = {}
    def dp(i,dp_S,minus):
        # dp_S is an offset index: real_S = dp_S + diff[i+1] - 2*K - 1.
        # NOTE(review): the recursion goes N levels deep; Python's default
        # recursion limit (~1000) is not raised anywhere in view.
        if (i,dp_S,minus) in memo:
            return memo[i,dp_S,minus]
        real_S = dp_S + diff[i+1] - 2 * K - 1
        if not i:
            # Base case: the first value diff[0]-minus must equal real_S and lie in [-K, K].
            first = diff[0] - minus
            if first==real_S and -K<=first<=K:
                memo[i,dp_S,minus] = 1
                return 1
            else:
                memo[i,dp_S,minus] = 0
                return 0
        res = 0
        # Transition 1 (diff[i] != 0): a single predecessor with minus == 0.
        if diff[i]:
            L = max((real_S+minus+diff[i]+1)//2,-K+minus+diff[i])
            R = min((real_S+minus+diff[i])//2,K+minus+diff[i])
            if L==R:
                pre_dp_S = L - diff[i] + 2 * K + 1
                res += dp(i-1,pre_dp_S,0)
                res %= mod
        # Transition 2: sum over new_s = 1 .. 3*K.
        if minus%2==diff[i] and real_S%2==0 and -K<=real_S//2<=K:
            for new_s in range(1,3*K+1):
                L = max((real_S+1-2*new_s)//2,-K-new_s)
                R = min((real_S-2*new_s)//2,K-new_s)
                if L==R:
                    pre_dp_S = L - diff[i] + 2 * K + 1
                    res += dp(i-1,pre_dp_S,minus//2+diff[i]+new_s)
                    res %= mod
        # Transition 3 (diff[i] <= 1): sum over j = 0 .. minus with matching parity.
        if diff[i]<=1:
            for j in range(minus+1):
                if (minus-j)%2==diff[i]:
                    L = max((real_S+j+1)//2,-K+j)
                    R = min((real_S+j)//2,K+j)
                    if L==R:
                        pre_dp_S = L - diff[i] + 2 * K + 1
                        res += dp(i-1,pre_dp_S,(minus-j)//2+diff[i])
                        res %= mod
        memo[i,dp_S,minus] = res
        return res
    res = 0
    # Sum the last layer over every real_S in [-2*K-1, N*K] with minus == 0.
    for real_S in range(-2*K-1,N*K+1):
        dp_S = real_S + 2 * K + 1
        res += dp(N-1,dp_S,0)
        res %= mod
    return res
def solve_NK(N,K,A):
    """Layered DP over reachable keys, with prefix sums over ``minus``.

    The driver at the bottom of the file does not call this variant.

    Layers 0 .. N-2 are stored in ``dp[i][sum]`` (rows of ``6*K+2`` entries
    indexed by ``minus``, with ``sum = dp_S + minus``); the last layer N-1 is
    evaluated directly while accumulating the answer.

    NOTE(review): for N == 1, ``dp`` is an empty list and the key walk starts
    at index N-2 == -1, so ``dp[-1]`` raises IndexError.

    Args:
        N: length of ``A``.
        K: bound used by the ``-K <= ... <= K`` range checks; also sizes rows.
        A: input sequence.

    Returns:
        The count modulo 998244353, or 0 as soon as some ``A[i]`` (i >= 1)
        exceeds ``A[0] + ... + A[i-1]``.
    """
    # diff[0] = A[0]; diff[i] = (A[0] + ... + A[i-1]) - A[i] for i >= 1.
    diff = [A[0] for i in range(N)]
    S = A[0]
    for i in range(1,N):
        diff[i] = S - A[i]
        if diff[i] < 0:
            return 0
        S += A[i]
    # Sentinel so that diff[i+1] is defined for i == N-1.
    diff.append(0)
    dp = [{} for i in range(N-1)]
    # Walk backwards (explicit stack) to allocate rows only for keys that a
    # later layer can read through one of the three transitions.
    stack = [(N-2,dp_S) for dp_S in range(10*K+10)]
    while stack:
        i,j = stack.pop()
        if j in dp[i]:
            continue
        dp[i][j] = [0 for k in range(6*K+2)]
        if not i:
            continue
        if diff[i]:
            stack.append((i-1,(j+diff[i+1]-2*K-1+diff[i])//2-diff[i]+2*K+1))
        if diff[i]<=1:
            for k in range(-2,1):
                stack.append((i-1,(j+diff[i+1]-2*K-1+k)//2+2*K+1))
            stack.append((i-1,(j+diff[i+1]-2*K-1-diff[i])//2+2*K+1))
    # cum[i][sum][minus] = dp[i][sum][0] + ... + dp[i][sum][minus] (mod).
    cum = [{sum:[0 for minus in range(6*K+2)] for sum in dp[i]} for i in range(N-1)]
    mod = 998244353
    # Layer 0: first = diff[0]-minus must equal real_S and lie in [-K, K].
    for minus in range(6*K+2):
        for sum in dp[0]:
            dp_S = sum - minus
            real_S = diff[1] + dp_S - 2*K - 1
            first = diff[0] - minus
            if first==real_S and -K<=first<=K:
                dp[0][sum][minus] = 1
    for sum in dp[0]:
        cum[0][sum][0] = dp[0][sum][0]
        for minus in range(1,6*K+2):
            cum[0][sum][minus] = dp[0][sum][minus] + cum[0][sum][minus-1]
            cum[0][sum][minus] %= mod
    for i in range(1,N-1):
        for minus in range(6*K+2):
            for sum in dp[i]:
                dp_S = sum - minus
                real_S = dp_S + diff[i+1] - 2*K - 1
                # Transition 1 (diff[i] != 0): single predecessor with minus == 0.
                if diff[i]:
                    L = max((real_S+minus+diff[i]+1)//2,-K+minus+diff[i])
                    R = min((real_S+minus+diff[i])//2,K+minus+diff[i])
                    if L==R:
                        pre_dp_S = L - diff[i] + 2*K + 1
                        dp[i][sum][minus] += dp[i-1][pre_dp_S][0]
                        dp[i][sum][minus] %= mod
                # Transition 2: range [pre_minus_L, pre_minus_R] read from cum.
                if minus%2==diff[i] and real_S%2==0 and -K<=real_S//2<=K:
                    pre_minus_L = max(0,minus//2+1+diff[i])
                    pre_minus_R = min(6*K+1,minus//2+3*K+diff[i])
                    if pre_minus_L<=pre_minus_R:
                        pre_sum = real_S//2+minus//2+2*K+1
                        dp[i][sum][minus] += cum[i-1][pre_sum][pre_minus_R] - cum[i-1][pre_sum][pre_minus_L] + dp[i-1][pre_sum][pre_minus_L]
                        dp[i][sum][minus] %= mod
                # Transition 3 (diff[i] <= 1): range [pre_minus_L, pre_minus_R] read from cum.
                if diff[i]<=1 and (real_S+minus-diff[i])%2==0:
                    m = max(0,-K+minus-diff[i]-(real_S+minus-diff[i])//2)
                    M = min((minus-diff[i])//2,K+minus-diff[i]-(real_S+minus-diff[i])//2)
                    pre_minus_L = max(0,m+diff[i])
                    pre_minus_R = min(6*K+1,M+diff[i])
                    if pre_minus_R>=pre_minus_L:
                        pre_sum = (real_S+minus-diff[i])//2+2*K+1
                        dp[i][sum][minus] += cum[i-1][pre_sum][pre_minus_R] - cum[i-1][pre_sum][pre_minus_L] + dp[i-1][pre_sum][pre_minus_L]
                        dp[i][sum][minus] %= mod
        for sum in cum[i]:
            cum[i][sum][0] = dp[i][sum][0]
            for minus in range(1,6*K+2):
                cum[i][sum][minus] = cum[i][sum][minus-1] + dp[i][sum][minus]
                cum[i][sum][minus] %= mod
    res = 0
    # Last layer (i = N-1, minus = 0) evaluated inline for every real_S in [-N*K, N*K].
    for real_S in range(-N*K,N*K+1):
        minus = 0
        i = N - 1
        if diff[i]:
            L = max((real_S+minus+diff[i]+1)//2,-K+minus+diff[i])
            R = min((real_S+minus+diff[i])//2,K+minus+diff[i])
            if L==R:
                pre_dp_S = L - diff[i] + 2*K + 1
                res += dp[i-1][pre_dp_S][0]
                res %= mod
        if minus%2==diff[i] and real_S%2==0 and -K<=real_S//2<=K:
            pre_minus_L = max(0,minus//2+1+diff[i])
            pre_minus_R = min(6*K+1,minus//2+3*K+diff[i])
            if pre_minus_L<=pre_minus_R:
                pre_sum = real_S//2+minus//2+2*K+1
                res += cum[i-1][pre_sum][pre_minus_R] - cum[i-1][pre_sum][pre_minus_L] + dp[i-1][pre_sum][pre_minus_L]
                res %= mod
        if diff[i]<=1 and (real_S+minus-diff[i])%2==0:
            m = max(0,-K+minus-diff[i]-(real_S+minus-diff[i])//2)
            M = min((minus-diff[i])//2,K+minus-diff[i]-(real_S+minus-diff[i])//2)
            pre_minus_L = max(0,m+diff[i])
            pre_minus_R = min(6*K+1,M+diff[i])
            if pre_minus_R>=pre_minus_L:
                pre_sum = (real_S+minus-diff[i])//2+2*K+1
                res += cum[i-1][pre_sum][pre_minus_R] - cum[i-1][pre_sum][pre_minus_L] + dp[i-1][pre_sum][pre_minus_L]
                res %= mod
    return res
def solve_NK_Constant_Good(N,K,A):
    """Layered DP whose rows are stored directly as prefix sums over ``minus``.

    The driver at the bottom of the file does not call this variant.

    ``dp[i][sum][minus]`` holds the prefix sum over 0..minus.  Layers 0..N-2
    have rows of length ``4*K+1``.  Layer N-1 is pre-seeded with rows of
    length 1 for every real_S in [-2*K, N*K].

    NOTE(review): for N == 1, layer 0 is that length-1 seeded layer.  The key
    walk then starts at index -1, and the layer-0 loop indexes
    ``dp[0][sum][minus]`` up to ``4*K``, so this raises IndexError.

    Args:
        N: length of ``A``.
        K: bound used by the ``-K <= ... <= K`` range checks; also sizes rows.
        A: input sequence.

    Returns:
        The count modulo 998244353, or 0 as soon as some ``A[i]`` (i >= 1)
        exceeds ``A[0] + ... + A[i-1]``.
    """
    # diff[0] = A[0]; diff[i] = (A[0] + ... + A[i-1]) - A[i] for i >= 1.
    diff = [A[0] for i in range(N)]
    S = A[0]
    for i in range(1,N):
        diff[i] = S - A[i]
        if diff[i] < 0:
            return 0
        S += A[i]
    # Sentinel so that diff[i+1] is defined for i == N-1.
    diff.append(0)
    dp = [{} for i in range(N-1)] + [{real_S-diff[N]+2*K+1+0:[0] for real_S in range(-2*K,N*K+1)}]
    # Walk backwards from layer N-2 (explicit stack) and allocate rows only for
    # keys that a later layer can read.
    stack = [(N-2,dp_S) for dp_S in range(K-10,3*K+10)]
    while stack:
        i,j = stack.pop()
        if j in dp[i]:
            continue
        dp[i][j] = [0 for k in range(4*K+1)]
        if not i:
            continue
        if diff[i]:
            stack.append((i-1,(j+diff[i+1]-2*K-1+diff[i])//2-diff[i]+2*K+1))
        if diff[i]<=1:
            for k in range(-2,1):
                stack.append((i-1,(j+diff[i+1]-2*K-1+k)//2+2*K+1))
            stack.append((i-1,(j+diff[i+1]-2*K-1-diff[i])//2+2*K+1))
    mod = 998244353
    # Layer 0: mark first == real_S within [-K, K], then fold into prefix sums.
    for minus in range(4*K+1):
        for sum in dp[0]:
            dp_S = sum - minus
            real_S = diff[1] + dp_S - 2*K - 1
            first = diff[0] - minus
            if first==real_S and -K<=first<=K:
                dp[0][sum][minus] = 1
            if minus:
                dp[0][sum][minus] += dp[0][sum][minus-1]
                dp[0][sum][minus] %= mod
    for i in range(1,N):
        for sum in dp[i]:
            for minus in range(len(dp[i][sum])):
                dp_S = sum - minus
                real_S = dp_S + diff[i+1] - 2*K - 1
                # Transition 1 (diff[i] != 0): single predecessor with minus == 0.
                if diff[i]:
                    L = max((real_S+minus+diff[i]+1)//2,-K+minus+diff[i])
                    R = min((real_S+minus+diff[i])//2,K+minus+diff[i])
                    if L==R:
                        pre_dp_S = L - diff[i] + 2*K + 1
                        dp[i][sum][minus] += dp[i-1][pre_dp_S][0]
                        dp[i][sum][minus] %= mod
                # Transition 2: range sum via prefix-sum difference (R) - (L-1).
                if minus%2==diff[i] and real_S%2==0 and -K<=real_S//2<=K:
                    pre_minus_L = max(0,minus//2+1+diff[i])
                    pre_minus_R = min(4*K,minus//2+3*K+diff[i])
                    if pre_minus_L<=pre_minus_R:
                        pre_sum = real_S//2+minus//2+2*K+1
                        dp[i][sum][minus] += dp[i-1][pre_sum][pre_minus_R] - dp[i-1][pre_sum][pre_minus_L-1] * (pre_minus_L>0)
                        dp[i][sum][minus] %= mod
                # Transition 3 (diff[i] <= 1): range sum via prefix-sum difference.
                if diff[i]<=1 and (real_S+minus-diff[i])%2==0:
                    m = max(0,-K+minus-diff[i]-(real_S+minus-diff[i])//2)
                    M = min((minus-diff[i])//2,K+minus-diff[i]-(real_S+minus-diff[i])//2)
                    pre_minus_L = max(0,m+diff[i])
                    pre_minus_R = min(4*K,M+diff[i])
                    if pre_minus_R>=pre_minus_L:
                        pre_sum = (real_S+minus-diff[i])//2+2*K+1
                        dp[i][sum][minus] += dp[i-1][pre_sum][pre_minus_R] - dp[i-1][pre_sum][pre_minus_L-1] * (pre_minus_L>0)
                        dp[i][sum][minus] %= mod
                # Turn this row into prefix sums as it is filled.
                if minus:
                    dp[i][sum][minus] += dp[i][sum][minus-1]
                    dp[i][sum][minus] %= mod
    res = 0
    for dp_S in dp[N-1]:
        res += dp[N-1][dp_S][0]
        res %= mod
    return res
def solve_NK_memory(N, K, A):
    """Count configurations for (N, K, A) modulo 998244353 (memory-lean DP).

    Only the previous layer's table is kept while the next one is built.
    Rows are stored as prefix sums over ``minus`` so that range sums over
    predecessor ``minus`` values are O(1).

    Args:
        N: length of ``A``.
        K: bound used by the ``-K <= ... <= K`` range checks; also sizes rows
            (``3*K+1`` entries per key).
        A: input sequence.

    Returns:
        The count modulo 998244353, or 0 as soon as some ``A[i]`` (i >= 1)
        exceeds ``A[0] + ... + A[i-1]``.
    """
    mod = 998244353
    # diff[0] = A[0]; diff[i] = (A[0] + ... + A[i-1]) - A[i] for i >= 1.
    diff = [A[0]] * N
    prefix = A[0]
    for i in range(1, N):
        diff[i] = prefix - A[i]
        if diff[i] < 0:
            # A[i] exceeds the prefix sum before it: no valid configuration.
            return 0
        prefix += A[i]
    # Sentinel so that diff[i + 1] is defined for the last layer.
    diff.append(0)

    # A key at layer i is dp_S + minus, where real_S = dp_S + diff[i+1] - 2*K - 1.
    # The last layer is seeded with every real_S in [-2*K, N*K] (minus = 0).
    final_keys = {real_S - diff[N] + 2 * K + 1 for real_S in range(-2 * K, N * K + 1)}
    key_sets = [set() for _ in range(N - 1)] + [final_keys]

    # Walk backwards from layer N-2 and collect every key an upper layer can
    # read through one of the three transitions.  With N == 1 there is no
    # earlier layer; seeding the walk at index N-2 == -1 would alias layer 0
    # and then index key_sets[-2], raising IndexError.
    stack = [(N - 2, dp_S) for dp_S in range(K - 10, 3 * K + 10)] if N >= 2 else []
    while stack:
        i, j = stack.pop()
        if j in key_sets[i]:
            continue
        key_sets[i].add(j)
        if not i:
            continue
        x = j + diff[i + 1] - 2 * K - 1
        if diff[i]:
            stack.append((i - 1, (x + diff[i]) // 2 - diff[i] + 2 * K + 1))
        if diff[i] <= 1:
            for k in range(-2, 1):
                stack.append((i - 1, (x + k) // 2 + 2 * K + 1))
            stack.append((i - 1, (x - diff[i]) // 2 + 2 * K + 1))

    width = 3 * K + 1
    # Layer 0: first = diff[0] - minus must equal real_S and lie in [-K, K].
    dp = {key: [0] * width for key in key_sets[0]}
    for minus in range(width):
        for key in dp:
            real_S = diff[1] + (key - minus) - 2 * K - 1
            first = diff[0] - minus
            if first == real_S and -K <= first <= K:
                dp[key][minus] = 1
            if minus:
                # Rows are kept as prefix sums over ``minus``.
                dp[key][minus] += dp[key][minus - 1]
                dp[key][minus] %= mod

    for i in range(1, N):
        # The last layer is only read at minus == 0, so it needs one entry per key.
        row_len = width if i != N - 1 else 1
        next_dp = {key: [0] * row_len for key in key_sets[i]}
        for key, row in next_dp.items():
            for minus in range(row_len):
                real_S = key - minus + diff[i + 1] - 2 * K - 1
                # Transition 1 (diff[i] != 0): single predecessor with minus == 0.
                if diff[i]:
                    L = max((real_S + minus + diff[i] + 1) // 2, -K + minus + diff[i])
                    R = min((real_S + minus + diff[i]) // 2, K + minus + diff[i])
                    if L == R:
                        pre_dp_S = L - diff[i] + 2 * K + 1
                        row[minus] += dp[pre_dp_S][0]
                        row[minus] %= mod
                # Transition 2: predecessor minus in [pre_minus_L, pre_minus_R].
                if minus % 2 == diff[i] and real_S % 2 == 0 and -K <= real_S // 2 <= K:
                    pre_minus_L = max(0, minus // 2 + 1 + diff[i])
                    pre_minus_R = min(3 * K, minus // 2 + 3 * K + diff[i])
                    if pre_minus_L <= pre_minus_R:
                        pre_sum = real_S // 2 + minus // 2 + 2 * K + 1
                        # Prefix-sum difference; the (L-1) term vanishes when L == 0.
                        row[minus] += dp[pre_sum][pre_minus_R] - dp[pre_sum][pre_minus_L - 1] * (pre_minus_L > 0)
                        row[minus] %= mod
                # Transition 3 (diff[i] <= 1): predecessor minus in [pre_minus_L, pre_minus_R].
                if diff[i] <= 1 and (real_S + minus - diff[i]) % 2 == 0:
                    m = max(0, -K + minus - diff[i] - (real_S + minus - diff[i]) // 2)
                    M = min((minus - diff[i]) // 2, K + minus - diff[i] - (real_S + minus - diff[i]) // 2)
                    pre_minus_L = max(0, m + diff[i])
                    pre_minus_R = min(3 * K, M + diff[i])
                    if pre_minus_R >= pre_minus_L:
                        pre_sum = (real_S + minus - diff[i]) // 2 + 2 * K + 1
                        row[minus] += dp[pre_sum][pre_minus_R] - dp[pre_sum][pre_minus_L - 1] * (pre_minus_L > 0)
                        row[minus] %= mod
                if minus:
                    row[minus] += row[minus - 1]
                    row[minus] %= mod
        dp = next_dp

    res = 0
    for row in dp.values():
        res += row[0]
        res %= mod
    return res
# Script entry: the first input line holds "N K", the second holds the
# sequence A; the answer comes from the memory-lean DP variant.
N, K = [int(token) for token in input().split()]
A = [int(token) for token in input().split()]
answer = solve_NK_memory(N, K, A)
print(answer)