結果

問題 No.3432 popcount & sum (Hard)
コンテスト
ユーザー detteiuu
提出日時 2026-01-29 20:07:24
言語 PyPy3
(7.3.17)
結果
AC  
実行時間 359 ms / 2,000 ms
コード長 1,722 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 322 ms
コンパイル使用メモリ 82,288 KB
実行使用メモリ 97,864 KB
最終ジャッジ日時 2026-01-29 20:07:29
合計ジャッジ時間 5,166 ms
ジャッジサーバーID
(参考情報)
judge2 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 16
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

MOD = 998244353

N = int(input())

L = N.bit_length()

dp = [[[[[0]*(L+1) for _ in range(2)] for _ in range(2)] for _ in range(L+1)] for _ in range(L+1)]
dp2 = [[[[[0]*(L+1) for _ in range(2)] for _ in range(2)] for _ in range(L+1)] for _ in range(L+1)]
dp[0][0][0][0][0] = 1
for i in range(L):
    c = L-1-i
    for j in range(L+1):
        for k in range(L+1):
            for smallA in range(2):
                for smallB in range(2):
                    if dp[j][k][smallA][smallB][i] == 0:
                        continue
                    for l in range(2):
                        if not 1<<c & N and smallA == 0 and l == 1:
                            continue
                        for m in range(2):
                            if not 1<<c & N and smallB == 0 and m == 1:
                                continue
                            dp[j+l][k+m][smallA or N>>c & 1 == 1 and l == 0][smallB or N>>c & 1 == 1 and m == 0][i+1] += dp[j][k][smallA][smallB][i]
                            dp[j+l][k+m][smallA or N>>c & 1 == 1 and l == 0][smallB or N>>c & 1 == 1 and m == 0][i+1] %= MOD
                            add = (1<<c)%MOD if l == m == 1 else 0
                            dp2[j+l][k+m][smallA or N>>c & 1 == 1 and l == 0][smallB or N>>c & 1 == 1 and m == 0][i+1] += (dp2[j][k][smallA][smallB][i]+add*dp[j][k][smallA][smallB][i])%MOD
                            dp2[j+l][k+m][smallA or N>>c & 1 == 1 and l == 0][smallB or N>>c & 1 == 1 and m == 0][i+1] %= MOD

ans = 0
for i in range(L+1):
    for j in range(2):
        for k in range(2):
            ans += dp2[i][i][j][k][-1]
            ans %= MOD
ans += (1+N)*N%MOD*pow(2, -1, MOD)%MOD
ans %= MOD
ans *= pow(2, -1, MOD)
ans %= MOD

print(ans)
0