結果

問題 No.1555 Constructed Balancing Sequence
ユーザー chineristAC
提出日時 2021-01-06 03:33:35
言語 PyPy3 (7.3.15)
結果 RE (最新) / AC (最初)
実行時間 -
コード長 18,352 bytes
コンパイル時間 359 ms
コンパイル使用メモリ 87,176 KB
実行使用メモリ 352,460 KB
最終ジャッジ日時 2023-09-04 22:00:31
合計ジャッジ時間 7,661 ms
ジャッジサーバーID (参考情報) judge11 / judge12
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 81 ms 71,252 KB
testcase_01 AC 663 ms 114,192 KB
testcase_02 AC 149 ms 78,172 KB
testcase_03 AC 145 ms 78,156 KB
testcase_04 AC 145 ms 78,012 KB
testcase_05 AC 148 ms 77,736 KB
testcase_06 AC 133 ms 77,168 KB
testcase_07 AC 113 ms 76,856 KB
testcase_08 AC 112 ms 77,000 KB
testcase_09 AC 120 ms 77,272 KB
testcase_10 AC 101 ms 76,480 KB
testcase_11 AC 137 ms 77,932 KB
testcase_12 RE -
testcase_13 RE -
testcase_14 RE -
testcase_15 RE -
testcase_16 RE -
testcase_17 TLE -
testcase_18 -- -
testcase_19 -- -
testcase_20 -- -
testcase_21 -- -
testcase_22 -- -
testcase_23 -- -
testcase_24 -- -
testcase_25 -- -
testcase_26 -- -
testcase_27 -- -
testcase_28 -- -
testcase_29 -- -
testcase_30 -- -
testcase_31 -- -
testcase_32 -- -
testcase_33 -- -
testcase_34 -- -
testcase_35 -- -
testcase_36 -- -
testcase_37 -- -
testcase_38 -- -
testcase_39 -- -
testcase_40 -- -
testcase_41 -- -
testcase_42 -- -
testcase_43 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #

def solve_NK3(N,K,A):
    """Count configurations for A with a memoised top-down DP (mod 998244353).

    diff[0] = A[0] and, for i >= 1, diff[i] = (A[0] + ... + A[i-1]) - A[i].
    If any diff[i] is negative the function returns 0 immediately.

    The DP state is (i, l, r, s): index i, an interval [l, r] of allowed
    values, and a carry s.  The answer is dp(N-1, -N*K, N*K, 0).
    The name suggests an O(N*K^3) approach -- TODO confirm.

    NOTE(review): dp recurses from i to i-1, so Python recursion depth grows
    with N; large N may exceed the default recursion limit (this block does
    not raise it).
    """
    diff = [A[0] for i in range(N)]
    S = A[0]
    for i in range(1,N):
        diff[i] = S - A[i]
        if diff[i] < 0:
            return 0
        S += A[i]

    mod = 998244353
    memo = {}
    def dp(i,l,r,s):
        # Memoised lookup; only results of the general case (i > 0, l <= r)
        # are stored at the end of this function.
        if (i,l,r,s) in memo:
            return memo[i,l,r,s]

        # Empty interval: nothing to count.
        if l>r:
            return 0

        # Base case: the first value is diff[0] - s; it must lie in [l, r]
        # and in [-K, K].
        if not i:
            first = diff[0] - s
            if l<=first<=r and -K<=first<=K:
                return 1
            else:
                return 0

        res = 0
        # Transition 1 (only when diff[i] != 0): the interval for i-1 is
        # [ceil((l+s+diff[i])/2), floor((r+s+diff[i])/2)] clamped to
        # [-K+s+diff[i], K+s+diff[i]], with the carry reset to 0.
        if diff[i]:
            L = max((l+s+diff[i]+1)//2,-K+s+diff[i])
            R = min((r+s+diff[i])//2,K+s+diff[i])
            res += dp(i-1,L,R,0)
            res %= mod
        # Transition 2: requires s%2 == diff[i] (so diff[i] is 0 or 1), l even
        # and r//2 in [-K, K]; enumerates an extra carry new_s in 1..3K.
        if s%2==diff[i] and (l+1)%2==1 and -K<=r//2<=K:
            for new_s in range(1,3*K+1):
                res += dp(i-1,r//2-new_s,r//2-new_s,s//2+diff[i]+new_s)
                res %= mod
        # Transition 3: requires diff[i] <= 1 and that the halved interval
        # [ceil((l+s-diff[i])/2), floor((r+s-diff[i])/2)] is a single point;
        # enumerates k over the range allowed by the [-K, K] bounds and by
        # (s-diff[i])//2.
        if diff[i]<=1 and (l+s-diff[i]+1)//2==(r+s-diff[i])//2:
            m = -K+s-diff[i]-((l+s-diff[i]+1)//2)
            M = K+s-diff[i]-((r+s-diff[i])//2)
            for k in range(max(m,0),min(M,(s-diff[i])//2)+1):
                res += dp(i-1,(r+s-diff[i])//2-k,(r+s-diff[i])//2-k,diff[i]+k)
                res %= mod

        # Range of L(S) for dp[i-1]:
        # when diff[i] >= 2:
            #-K+diff[i]<=S<=7*K+diff[i]
        # when diff[i] is 0 or 1:
            #-K+diff[i]<=S<=7*K+diff[i]
            #-2*K<=S<=K
            #-K<=S<=7*K
            #->-2*K<=S<=7*K+1


        memo[i,l,r,s] = res
        return memo[i,l,r,s]

    return dp(N-1,-N*K,N*K,0)

def solve_NK2_cum_WRONG(N,K,A):
    """Tabulated DP with anti-diagonal prefix sums -- marked WRONG by its author.

    Not called by the entry point of this file; kept for reference only.
    Do not rely on its output.

    diff[0] = A[0], diff[i] = (A[0] + ... + A[i-1]) - A[i]; a negative diff
    returns 0.  dp[dp_S][minus] is a 10K x (6K+2) table, where
    real_S = dp_S + diff[i+1] - 2K - 1.  cum[dp_S][minus] accumulates along
    anti-diagonals: cum[dp_S][minus] += cum[dp_S-1][minus+1].
    Result is taken modulo 998244353.
    """
    diff = [A[0] for i in range(N)]
    S = A[0]
    for i in range(1,N):
        diff[i] = S - A[i]
        if diff[i] < 0:
            return 0
        S += A[i]

    # Truncate diff (and N) just after its last zero entry; if diff has no
    # zero, nothing is truncated.
    # NOTE(review): the reason for this truncation is not evident here.
    for i in range(N-1,-1,-1):
        if diff[i]==0:
            N = i+1
            diff = diff[:i+1]
            break

    # Sentinel so that diff[i+1] is readable for the last index.
    diff.append(0)

    mod = 998244353

    # Level 0: a state is valid when the first value diff[0] - minus equals
    # real_S and lies in [-K, K].
    dp = [[0 for minus in range(6*K+2)] for dp_S in range(10*K)]
    for dp_S in range(10*K):
        real_S = diff[1] + dp_S - 2*K - 1
        for minus in range(6*K+2):
            first = diff[0] - minus
            if first==real_S and -K<=first<=K:
                dp[dp_S][minus] = 1

    # Anti-diagonal prefix sums of dp.
    cum = [[dp[dp_S][minus] for minus in range(6*K+2)] for dp_S in range(10*K)]
    for dp_S in range(1,10*K):
        for minus in range(6*K+2):
            if minus<6*K+1:
                cum[dp_S][minus] += cum[dp_S-1][minus+1]

    for i in range(1,N):
        ndp = [[0 for minus in range(6*K+2)] for dp_S in range(10*K)]
        for dp_S in range(10*K):
            real_S = dp_S + diff[i+1] - 2*K - 1
            for minus in range(6*K+2):
                # Transition 1 (diff[i] != 0): a single predecessor with
                # minus == 0, when the clamped halved range is one point.
                if diff[i]:
                    L = max((real_S+minus+diff[i]+1)//2,-K+minus+diff[i])
                    R = min((real_S+minus+diff[i])//2,K+minus+diff[i])
                    if L==R:
                        pre_dp_S = L - diff[i] + 2*K + 1
                        if 0<=pre_dp_S<10*K:
                            ndp[dp_S][minus] += dp[pre_dp_S][0]
                            ndp[dp_S][minus] %= mod
                # Transition 2: range sum over predecessors via cum.
                if minus%2==diff[i] and real_S%2==0 and -K<=real_S//2<=K:
                    pre_dp_S_R = min(10*K-1,real_S//2-1 - diff[i] + 2*K+1)
                    pre_dp_S_L = max(0,real_S//2-3*K - diff[i] + 2*K+1)
                    if pre_dp_S_L<=pre_dp_S_R:
                        ndp[dp_S][minus] += cum[pre_dp_S_R][real_S//2+minus//2+2*K+1-pre_dp_S_R] - cum[pre_dp_S_L][real_S//2+minus//2+2*K+1-pre_dp_S_L] + dp[pre_dp_S_L][real_S//2+minus//2+2*K+1-pre_dp_S_L]
                        ndp[dp_S][minus] %= mod

                # Transition 3 (diff[i] <= 1): another range sum via cum.
                if diff[i]<=1 and (real_S+minus-diff[i])%2==0:
                    m = max(0,-K+minus-diff[i]-(real_S+minus-diff[i])//2)
                    M = min((minus-diff[i])//2,K+minus-diff[i]-(real_S+minus-diff[i])//2)
                    L = min(10*K-1,(real_S+minus-diff[i])//2-diff[i]-m+2*K+1)
                    R = max(0,(real_S+minus-diff[i])//2-diff[i]-M+2*K+1)
                    if L>=R:
                        ndp[dp_S][minus] += cum[L][(real_S+minus-diff[i])//2+2*K+1-L] - cum[R][(real_S+minus-diff[i])//2+2*K+1-R] + dp[R][(real_S+minus-diff[i])//2+2*K+1-R]
                        ndp[dp_S][minus] %= mod

        # Rebuild the anti-diagonal prefix sums for the new level.
        ncum = [[ndp[dp_S][minus] for minus in range(6*K+2)] for dp_S in range(10*K)]
        for dp_S in range(1,10*K):
            for minus in range(6*K+2):
                if minus<6*K+1:
                    ncum[dp_S][minus] += ncum[dp_S-1][minus+1]

        dp,cum = ndp,ncum

    # Sum the minus == 0 column (dp_S == 0 is skipped).
    res = 0
    for dp_S in range(1,10*K):
        res += dp[dp_S][0]
        res %= mod
    return res

def solve_NK2_memo(N,K,A):
    """Count configurations for A with a memoised recursive DP (mod 998244353).

    diff[0] = A[0], diff[i] = (A[0] + ... + A[i-1]) - A[i]; a negative diff
    returns 0.  The state is (i, dp_S, minus), where
    real_S = dp_S + diff[i+1] - 2K - 1.  The answer sums dp(N-1, dp_S, 0)
    for real_S in [-2K-1, N*K].

    NOTE(review): dp recurses from i to i-1, so recursion depth grows with N;
    this block does not raise the recursion limit.
    """
    diff = [A[0] for i in range(N)]
    S = A[0]
    for i in range(1,N):
        diff[i] = S - A[i]
        if diff[i] < 0:
            return 0
        S += A[i]

    # Sentinel so that diff[i+1] is readable for i == N-1.
    diff.append(0)

    mod = 998244353
    memo = {}

    def dp(i,dp_S,minus):
        if (i,dp_S,minus) in memo:
            return memo[i,dp_S,minus]

        real_S = dp_S + diff[i+1] - 2 * K - 1
        # Base case: the first value diff[0] - minus must equal real_S and
        # lie in [-K, K].
        if not i:
            first = diff[0] - minus
            if first==real_S and -K<=first<=K:
                memo[i,dp_S,minus] = 1
                return 1
            else:
                memo[i,dp_S,minus] = 0
                return 0

        res = 0
        # Transition 1 (diff[i] != 0): single predecessor with minus == 0,
        # taken only when the clamped halved range is a single point.
        if diff[i]:
            L = max((real_S+minus+diff[i]+1)//2,-K+minus+diff[i])
            R = min((real_S+minus+diff[i])//2,K+minus+diff[i])
            if L==R:
                pre_dp_S = L - diff[i] + 2 * K + 1
                res += dp(i-1,pre_dp_S,0)
                res %= mod
        # Transition 2: enumerate an extra amount new_s in 1..3K.
        if minus%2==diff[i] and real_S%2==0 and -K<=real_S//2<=K:
            for new_s in range(1,3*K+1):
                L = max((real_S+1-2*new_s)//2,-K-new_s)
                R = min((real_S-2*new_s)//2,K-new_s)
                if L==R:
                    pre_dp_S = L - diff[i] + 2 * K + 1
                    res += dp(i-1,pre_dp_S,minus//2+diff[i]+new_s)
                    res %= mod

        # Transition 3 (diff[i] <= 1): enumerate j in 0..minus with matching
        # parity of minus - j.
        if diff[i]<=1:
            for j in range(minus+1):
                if (minus-j)%2==diff[i]:
                    L = max((real_S+j+1)//2,-K+j)
                    R = min((real_S+j)//2,K+j)
                    if L==R:
                        pre_dp_S = L - diff[i] + 2 * K + 1
                        res += dp(i-1,pre_dp_S,(minus-j)//2+diff[i])
                        res %= mod

        memo[i,dp_S,minus] = res
        return res

    res = 0
    for real_S in range(-2*K-1,N*K+1):
        dp_S = real_S + 2 * K + 1
        res += dp(N-1,dp_S,0)
        res %= mod

    return res

def solve_NK(N,K,A):
    """Count configurations for A with a bottom-up DP over reachable states.

    diff[0] = A[0] and diff[i] = (A[0] + ... + A[i-1]) - A[i] for i >= 1;
    any negative diff makes the answer 0.  A zero sentinel is appended so
    that diff[i+1] is readable for the last index.

    dp[i] (for i in 0..N-2) maps a "sum" key (dp_S = sum - minus) to a list
    indexed by minus in 0..6K+1.  A stack-driven search seeded with every key
    in range(10K+10) at level N-2 discovers which keys of lower levels can be
    referenced.  cum[i][sum] holds prefix sums of dp[i][sum] over minus, so
    range sums over minus cost O(1).  The top level (N-1) is evaluated
    directly for real_S in [-N*K, N*K] with minus == 0.

    Returns the count modulo 998244353.
    """
    diff = [A[0] for i in range(N)]
    S = A[0]
    for i in range(1,N):
        diff[i] = S - A[i]
        if diff[i] < 0:
            return 0
        S += A[i]

    if N == 1:
        # With one element there are no levels 0..N-2: dp below would be an
        # empty list and dp[N-2] == dp[-1] raises IndexError.  The answer is
        # the i == 0 base case with minus == 0: the first value A[0] must
        # lie in [-K, K].
        return 1 if -K <= A[0] <= K else 0

    diff.append(0)

    # Discover the reachable sum keys of every level 0..N-2.
    dp = [{} for i in range(N-1)]
    stack = [(N-2,dp_S) for dp_S in range(10*K+10)]
    while stack:
        i,j = stack.pop()
        if j in dp[i]:
            continue
        dp[i][j] = [0 for k in range(6*K+2)]
        if not i:
            continue
        if diff[i]:
            stack.append((i-1,(j+diff[i+1]-2*K-1+diff[i])//2-diff[i]+2*K+1))
        if diff[i]<=1:
            for k in range(-2,1):
                stack.append((i-1,(j+diff[i+1]-2*K-1+k)//2+2*K+1))
            stack.append((i-1,(j+diff[i+1]-2*K-1-diff[i])//2+2*K+1))

    cum = [{sum:[0 for minus in range(6*K+2)] for sum in dp[i]} for i in range(N-1)]

    mod = 998244353

    # Level 0: valid when the first value diff[0] - minus equals real_S and
    # lies in [-K, K].
    for minus in range(6*K+2):
        for sum in dp[0]:
            dp_S = sum - minus
            real_S = diff[1] + dp_S - 2*K - 1
            first = diff[0] - minus
            if first==real_S and -K<=first<=K:
                dp[0][sum][minus] = 1

    for sum in dp[0]:
        cum[0][sum][0] = dp[0][sum][0]
        for minus in range(1,6*K+2):
            cum[0][sum][minus] = dp[0][sum][minus] + cum[0][sum][minus-1]
            cum[0][sum][minus] %= mod

    for i in range(1,N-1):
        for minus in range(6*K+2):
            for sum in dp[i]:
                dp_S = sum - minus
                real_S = dp_S + diff[i+1] - 2*K - 1

                # Transition 1 (diff[i] != 0): single predecessor, minus == 0.
                if diff[i]:
                    L = max((real_S+minus+diff[i]+1)//2,-K+minus+diff[i])
                    R = min((real_S+minus+diff[i])//2,K+minus+diff[i])
                    if L==R:
                        pre_dp_S = L - diff[i] + 2*K + 1
                        dp[i][sum][minus] += dp[i-1][pre_dp_S][0]
                        dp[i][sum][minus] %= mod

                # Transition 2: range sum over predecessor minus values via
                # cum[R] - cum[L] + dp[L] (inclusive range [L, R]).
                if minus%2==diff[i] and real_S%2==0 and -K<=real_S//2<=K:
                    pre_minus_L = max(0,minus//2+1+diff[i])
                    pre_minus_R = min(6*K+1,minus//2+3*K+diff[i])
                    if pre_minus_L<=pre_minus_R:
                        pre_sum = real_S//2+minus//2+2*K+1
                        dp[i][sum][minus] += cum[i-1][pre_sum][pre_minus_R] - cum[i-1][pre_sum][pre_minus_L] + dp[i-1][pre_sum][pre_minus_L]
                        dp[i][sum][minus] %= mod

                # Transition 3 (diff[i] <= 1): another inclusive range sum.
                if diff[i]<=1 and (real_S+minus-diff[i])%2==0:
                    m = max(0,-K+minus-diff[i]-(real_S+minus-diff[i])//2)
                    M = min((minus-diff[i])//2,K+minus-diff[i]-(real_S+minus-diff[i])//2)
                    pre_minus_L = max(0,m+diff[i])
                    pre_minus_R = min(6*K+1,M+diff[i])
                    if pre_minus_R>=pre_minus_L:
                        pre_sum = (real_S+minus-diff[i])//2+2*K+1
                        dp[i][sum][minus] += cum[i-1][pre_sum][pre_minus_R] - cum[i-1][pre_sum][pre_minus_L] + dp[i-1][pre_sum][pre_minus_L]
                        dp[i][sum][minus] %= mod

        for sum in cum[i]:
            cum[i][sum][0] = dp[i][sum][0]
            for minus in range(1,6*K+2):
                cum[i][sum][minus] = cum[i][sum][minus-1] + dp[i][sum][minus]
                cum[i][sum][minus] %= mod


    # Top level (i == N-1) with minus == 0, evaluated with the same
    # transitions directly into the answer.
    res = 0
    for real_S in range(-N*K,N*K+1):
        minus = 0
        i = N - 1
        if diff[i]:
            L = max((real_S+minus+diff[i]+1)//2,-K+minus+diff[i])
            R = min((real_S+minus+diff[i])//2,K+minus+diff[i])
            if L==R:
                pre_dp_S = L - diff[i] + 2*K + 1
                res += dp[i-1][pre_dp_S][0]
                res %= mod

        if minus%2==diff[i] and real_S%2==0 and -K<=real_S//2<=K:
            pre_minus_L = max(0,minus//2+1+diff[i])
            pre_minus_R = min(6*K+1,minus//2+3*K+diff[i])
            if pre_minus_L<=pre_minus_R:
                pre_sum = real_S//2+minus//2+2*K+1
                res += cum[i-1][pre_sum][pre_minus_R] - cum[i-1][pre_sum][pre_minus_L] + dp[i-1][pre_sum][pre_minus_L]
                res %= mod

        if diff[i]<=1 and (real_S+minus-diff[i])%2==0:
            m = max(0,-K+minus-diff[i]-(real_S+minus-diff[i])//2)
            M = min((minus-diff[i])//2,K+minus-diff[i]-(real_S+minus-diff[i])//2)
            pre_minus_L = max(0,m+diff[i])
            pre_minus_R = min(6*K+1,M+diff[i])
            if pre_minus_R>=pre_minus_L:
                pre_sum = (real_S+minus-diff[i])//2+2*K+1
                res += cum[i-1][pre_sum][pre_minus_R] - cum[i-1][pre_sum][pre_minus_L] + dp[i-1][pre_sum][pre_minus_L]
                res %= mod

    return res

def solve_NK_Constant_Good(N,K,A):
    """Count configurations for A with a DP whose rows store prefix sums.

    diff[0] = A[0] and diff[i] = (A[0] + ... + A[i-1]) - A[i] for i >= 1;
    any negative diff makes the answer 0.  A zero sentinel is appended so
    that diff[i+1] is readable for the last index.

    dp[i] maps a "sum" key (dp_S = sum - minus) to a list over minus.  The
    top level dp[N-1] is pre-built with keys real_S - diff[N] + 2K + 1 for
    real_S in [-2K, N*K], each a length-1 list (minus == 0 only).  Lower
    levels are discovered by a stack search seeded with range(K-10, 3K+10)
    at level N-2 and use lists of length 4K+1.  Each row is converted in
    place into prefix sums over minus, so an inclusive range [L, R] is
    row[R] - row[L-1] (or row[R] when L == 0).

    Returns the count modulo 998244353.
    """
    diff = [A[0] for i in range(N)]
    S = A[0]
    for i in range(1,N):
        diff[i] = S - A[i]
        if diff[i] < 0:
            return 0
        S += A[i]

    if N == 1:
        # With one element there is no level N-2: seeding level N-2 == -1
        # would wrap around to the top-level dict and the search would then
        # index level -2, raising IndexError.  The answer is the i == 0 base
        # case with minus == 0: the first value A[0] must lie in [-K, K].
        return 1 if -K <= A[0] <= K else 0

    diff.append(0)

    dp = [{} for i in range(N-1)] + [{real_S-diff[N]+2*K+1+0:[0] for real_S in range(-2*K,N*K+1)}]
    # Discover the reachable sum keys of levels 0..N-2.
    stack = [(N-2,dp_S) for dp_S in range(K-10,3*K+10)]
    while stack:
        i,j = stack.pop()
        if j in dp[i]:
            continue
        dp[i][j] = [0 for k in range(4*K+1)]
        if not i:
            continue
        if diff[i]:
            stack.append((i-1,(j+diff[i+1]-2*K-1+diff[i])//2-diff[i]+2*K+1))
        if diff[i]<=1:
            for k in range(-2,1):
                stack.append((i-1,(j+diff[i+1]-2*K-1+k)//2+2*K+1))
            stack.append((i-1,(j+diff[i+1]-2*K-1-diff[i])//2+2*K+1))

    mod = 998244353

    # Level 0: valid when the first value diff[0] - minus equals real_S and
    # lies in [-K, K]; then accumulate prefix sums over minus in place.
    for minus in range(4*K+1):
        for sum in dp[0]:
            dp_S = sum - minus
            real_S = diff[1] + dp_S - 2*K - 1
            first = diff[0] - minus
            if first==real_S and -K<=first<=K:
                dp[0][sum][minus] = 1
            if minus:
                dp[0][sum][minus] += dp[0][sum][minus-1]
                dp[0][sum][minus] %= mod

    for i in range(1,N):
        for sum in dp[i]:
            for minus in range(len(dp[i][sum])):
                dp_S = sum - minus
                real_S = dp_S + diff[i+1] - 2*K - 1

                # Transition 1 (diff[i] != 0): single predecessor; index 0 of
                # a prefix-sum row is the plain minus == 0 value.
                if diff[i]:
                    L = max((real_S+minus+diff[i]+1)//2,-K+minus+diff[i])
                    R = min((real_S+minus+diff[i])//2,K+minus+diff[i])
                    if L==R:
                        pre_dp_S = L - diff[i] + 2*K + 1
                        dp[i][sum][minus] += dp[i-1][pre_dp_S][0]
                        dp[i][sum][minus] %= mod

                # Transition 2: inclusive range sum over predecessor minus
                # values; the "* (pre_minus_L>0)" factor drops the row[-1]
                # term when the range starts at 0.
                if minus%2==diff[i] and real_S%2==0 and -K<=real_S//2<=K:
                    pre_minus_L = max(0,minus//2+1+diff[i])
                    pre_minus_R = min(4*K,minus//2+3*K+diff[i])
                    if pre_minus_L<=pre_minus_R:
                        pre_sum = real_S//2+minus//2+2*K+1
                        dp[i][sum][minus] += dp[i-1][pre_sum][pre_minus_R] - dp[i-1][pre_sum][pre_minus_L-1] * (pre_minus_L>0)
                        dp[i][sum][minus] %= mod

                # Transition 3 (diff[i] <= 1): another inclusive range sum.
                if diff[i]<=1 and (real_S+minus-diff[i])%2==0:
                    m = max(0,-K+minus-diff[i]-(real_S+minus-diff[i])//2)
                    M = min((minus-diff[i])//2,K+minus-diff[i]-(real_S+minus-diff[i])//2)
                    pre_minus_L = max(0,m+diff[i])
                    pre_minus_R = min(4*K,M+diff[i])
                    if pre_minus_R>=pre_minus_L:
                        pre_sum = (real_S+minus-diff[i])//2+2*K+1
                        dp[i][sum][minus] += dp[i-1][pre_sum][pre_minus_R] - dp[i-1][pre_sum][pre_minus_L-1] * (pre_minus_L>0)
                        dp[i][sum][minus] %= mod

                # Turn the row into prefix sums over minus as we go.
                if minus:
                    dp[i][sum][minus] += dp[i][sum][minus-1]
                    dp[i][sum][minus] %= mod

    res = 0
    for dp_S in dp[N-1]:
        res += dp[N-1][dp_S][0]
        res %= mod
    return res

def solve_NK_memory(N,K,A):
    """Count configurations for A with a memory-reduced prefix-sum DP.

    diff[0] = A[0] and diff[i] = (A[0] + ... + A[i-1]) - A[i] for i >= 1;
    any negative diff makes the answer 0.  A zero sentinel is appended so
    that diff[i+1] is readable for the last index.

    dp_sum_set[i] is the set (later a list) of reachable "sum" keys of level
    i (dp_S = sum - minus), and dp_sum_to_idx[i] maps each key to a dense row
    index.  The top level N-1 holds keys real_S - diff[N] + 2K + 1 for
    real_S in [-2K, N*K]; levels 0..N-2 are discovered by a stack search
    seeded with range(K-10, 3K+10) at level N-2.  Only the previous level's
    table is kept.  Rows hold prefix sums over minus in 0..3K, and the top
    level keeps minus == 0 only.

    Returns the count modulo 998244353.
    """
    diff = [A[0] for i in range(N)]
    S = A[0]
    for i in range(1,N):
        diff[i] = S - A[i]
        if diff[i] < 0:
            return 0
        S += A[i]

    if N == 1:
        # With one element there is no level N-2: seeding level N-2 == -1
        # would wrap around to the top-level set and the search would then
        # index level -2, raising IndexError.  The answer is the i == 0 base
        # case with minus == 0: the first value A[0] must lie in [-K, K].
        return 1 if -K <= A[0] <= K else 0

    diff.append(0)

    dp_sum_set = [set() for i in range(N-1)] + [set([real_S-diff[N]+2*K+1+0 for real_S in range(-2*K,N*K+1)])]
    # Discover the reachable sum keys of levels 0..N-2.
    stack = [(N-2,dp_S) for dp_S in range(K-10,3*K+10)]
    while stack:
        i,j = stack.pop()
        if j in dp_sum_set[i]:
            continue
        dp_sum_set[i].add(j)
        if not i:
            continue
        if diff[i]:
            stack.append((i-1,(j+diff[i+1]-2*K-1+diff[i])//2-diff[i]+2*K+1))
        if diff[i]<=1:
            for k in range(-2,1):
                stack.append((i-1,(j+diff[i+1]-2*K-1+k)//2+2*K+1))
            stack.append((i-1,(j+diff[i+1]-2*K-1-diff[i])//2+2*K+1))

    # Map each level's sum keys to dense row indices.
    dp_sum_to_idx = [{sum:-1 for sum in dp_sum_set[i]} for i in range(N)]
    for i in range(N):
        dp_sum_set[i] = list(dp_sum_set[i])
        for j in range(len(dp_sum_set[i])):
            dp_sum_to_idx[i][dp_sum_set[i][j]] = j

    mod = 998244353

    dp = [[0 for i in range(3*K+1)] for sum in dp_sum_set[0]]

    # Level 0: valid when the first value diff[0] - minus equals real_S and
    # lies in [-K, K]; rows are accumulated into prefix sums over minus.
    for minus in range(3*K+1):
        for sum in dp_sum_set[0]:
            idx = dp_sum_to_idx[0][sum]
            dp_S = sum - minus
            real_S = diff[1] + dp_S - 2*K - 1
            first = diff[0] - minus
            if first==real_S and -K<=first<=K:
                dp[idx][minus] = 1
            if minus:
                dp[idx][minus] += dp[idx][minus-1]
                dp[idx][minus] %= mod

    for i in range(1,N):
        # The top level only needs minus == 0.
        if i!=N-1:
            next_dp = [[0 for i in range(3*K+1)] for sum in dp_sum_set[i]]
        else:
            next_dp = [[0 for i in range(1)] for sum in dp_sum_set[i]]

        for sum in dp_sum_set[i]:
            idx = dp_sum_to_idx[i][sum]
            for minus in range(len(next_dp[idx])):
                dp_S = sum - minus
                real_S = dp_S + diff[i+1] - 2*K - 1

                # Transition 1 (diff[i] != 0): single predecessor; index 0 of
                # a prefix-sum row is the plain minus == 0 value.
                if diff[i]:
                    L = max((real_S+minus+diff[i]+1)//2,-K+minus+diff[i])
                    R = min((real_S+minus+diff[i])//2,K+minus+diff[i])
                    if L==R:
                        pre_dp_S = L - diff[i] + 2*K + 1
                        pre_dp_S = dp_sum_to_idx[i-1][pre_dp_S]
                        next_dp[idx][minus] += dp[pre_dp_S][0]
                        next_dp[idx][minus] %= mod

                # Transition 2: inclusive range sum over predecessor minus
                # values; "* (pre_minus_L>0)" drops the row[-1] term when the
                # range starts at 0.
                if minus%2==diff[i] and real_S%2==0 and -K<=real_S//2<=K:
                    pre_minus_L = max(0,minus//2+1+diff[i])
                    pre_minus_R = min(3*K,minus//2+3*K+diff[i])
                    if pre_minus_L<=pre_minus_R:
                        pre_sum = real_S//2+minus//2+2*K+1
                        pre_sum = dp_sum_to_idx[i-1][pre_sum]
                        next_dp[idx][minus] += dp[pre_sum][pre_minus_R] - dp[pre_sum][pre_minus_L-1] * (pre_minus_L>0)
                        next_dp[idx][minus] %= mod

                # Transition 3 (diff[i] <= 1): another inclusive range sum.
                if diff[i]<=1 and (real_S+minus-diff[i])%2==0:
                    m = max(0,-K+minus-diff[i]-(real_S+minus-diff[i])//2)
                    M = min((minus-diff[i])//2,K+minus-diff[i]-(real_S+minus-diff[i])//2)
                    pre_minus_L = max(0,m+diff[i])
                    pre_minus_R = min(3*K,M+diff[i])
                    if pre_minus_R>=pre_minus_L:
                        pre_sum = (real_S+minus-diff[i])//2+2*K+1
                        pre_sum = dp_sum_to_idx[i-1][pre_sum]
                        next_dp[idx][minus] += dp[pre_sum][pre_minus_R] - dp[pre_sum][pre_minus_L-1] * (pre_minus_L>0)
                        next_dp[idx][minus] %= mod

                # Turn the row into prefix sums over minus as we go.
                if minus:
                    next_dp[idx][minus] += next_dp[idx][minus-1]
                    next_dp[idx][minus] %= mod

        dp = next_dp

    res = 0
    for dp_S in dp_sum_set[N-1]:
        idx = dp_sum_to_idx[N-1][dp_S]
        res += dp[idx][0]
        res %= mod
    return res

# Entry point: read N and K from the first line and the sequence A from the
# second line of stdin, then print the answer from solve_NK_memory (the other
# solve_* variants above are not used).
N,K = map(int,input().split())
A = list(map(int,input().split()))
print(solve_NK_memory(N,K,A))