結果
| 問題 | No.3432 popcount & sum (Hard) |
| コンテスト | |
| ユーザー |
mentos_grape
|
| 提出日時 | 2026-01-11 15:37:03 |
| 言語 | Python3 (3.14.2 + numpy 2.4.0 + scipy 1.16.3) |
| 結果 |
AC
|
| 実行時間 | 166 ms / 2,000 ms |
| コード長 | 3,387 bytes |
| 記録 | |
| コンパイル時間 | 736 ms |
| コンパイル使用メモリ | 20,932 KB |
| 実行使用メモリ | 15,364 KB |
| 最終ジャッジ日時 | 2026-01-11 15:37:08 |
| 合計ジャッジ時間 | 4,466 ms |
|
ジャッジサーバーID (参考情報) |
judge1 / judge2 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| other | AC * 16 |
ソースコード
import sys
sys.setrecursionlimit(2000)
def solve():
# 入力
try:
input_data = sys.stdin.read().split()
if not input_data:
return
n = int(input_data[0])
except ValueError:
return
MOD = 998244353
# Nが0の場合は0を出力
if n == 0:
print(0)
return
# 組み合わせ (nCr) の前計算
MAX_BIT = 65
C = [[0] * MAX_BIT for _ in range(MAX_BIT)]
for i in range(MAX_BIT):
C[i][0] = 1
for j in range(1, i + 1):
C[i][j] = (C[i-1][j-1] + C[i-1][j]) % MOD
# リストを使用
counts = [[0] * MAX_BIT for _ in range(MAX_BIT)]
# ビットの長さ
L = n.bit_length()
# 各ビット位置 k について計算
for k in range(L):
current_popcount = 0
# 上位ビットから調査 (L-1 -> k+1)
for i in range(L - 1, k, -1):
if (n >> i) & 1:
free_bits = i - 1
base_popcount = current_popcount + 1 # +1 は kビット目の分
# 自由なビットから r 個選んで 1 にする場合
for r in range(free_bits + 1):
total_c = base_popcount + r
if total_c < MAX_BIT:
counts[k][total_c] += C[free_bits][r]
counts[k][total_c] %= MOD
# nに合わせて1を選ぶ
current_popcount += 1
else:
pass
# ターゲットである k ビット目の処理
if (n >> k) & 1:
current_popcount += 1 # kビット目を1にした分
else:
continue
# 下位ビットの調査 (k-1 -> 0)
for i in range(k - 1, -1, -1):
if (n >> i) & 1:
free_bits = i
base_popcount = current_popcount
for r in range(free_bits + 1):
total_c = base_popcount + r
if total_c < MAX_BIT:
counts[k][total_c] += C[free_bits][r]
counts[k][total_c] %= MOD
# nに合わせて1を選ぶ
current_popcount += 1
else:
pass
# n 自身についての判定
if current_popcount < MAX_BIT:
counts[k][current_popcount] += 1
counts[k][current_popcount] %= MOD
# 総和の計算
total_sum = 0
# 2^k の前計算
pow2 = [1] * (L + 1)
for i in range(1, L + 1):
pow2[i] = (pow2[i-1] * 2) % MOD
for c in range(MAX_BIT):
term_c = 0
for k in range(L):
cnt = counts[k][c]
if cnt > 0:
# (count^2 * 2^k)
val = (cnt * cnt) % MOD
val = (val * pow2[k]) % MOD
term_c = (term_c + val) % MOD
total_sum = (total_sum + term_c) % MOD
# 対角成分 (Sum of a from 0 to n) = n*(n+1)/2
diagonal = (n * (n + 1)) % MOD
diagonal = (diagonal * pow(2, MOD - 2, MOD)) % MOD # divide by 2
# 最終的な答え: (total_sum + diagonal) / 2
ans = (total_sum + diagonal) % MOD
ans = (ans * pow(2, MOD - 2, MOD)) % MOD
print(ans)
if __name__ == '__main__':
solve()
mentos_grape