MOD = 998244353


def solve(n: int, m: int, mod: int = MOD) -> int:
    """Return ``(n << 0) ^ (n << 1) ^ ... ^ (n << (m - 1))`` modulo *mod*.

    Each set bit ``i`` of *n* appears at bit positions ``i .. i + m - 1`` of
    the shifted copies. A position is set in the XOR exactly when an odd
    number of these runs cover it. The runs are swept in order of their
    boundaries while tracking coverage parity. Every maximal odd-covered
    stretch ``[lo, hi)`` contributes ``2**lo + ... + 2**(hi-1)``, which
    equals ``2**hi - 2**lo``. The work is O(b log b) for b = bit length of
    *n*, independent of *m*, so *m* may be astronomically large.

    Args:
        n: Non-negative integer whose shifted copies are XORed.
        m: Number of shifted copies (shifts 0 .. m-1). ``m <= 0`` means no
            copies, so the result is 0.
        mod: Modulus applied to the result.

    Returns:
        The XOR value reduced modulo *mod*, in ``[0, mod)``.

    Raises:
        ValueError: If *n* is negative.
    """
    if n < 0:
        raise ValueError("n must be non-negative")
    if n == 0 or m <= 0:
        return 0

    # Each boundary toggles the coverage parity: a run starts at bit i and
    # stops (exclusive) at i + m. Scan the full bit length, because Python
    # ints are unbounded and a fixed 60-bit window would drop high bits.
    boundaries = []
    for i in range(n.bit_length()):
        if (n >> i) & 1:
            boundaries.append(i)
            boundaries.append(i + m)
    boundaries.sort()

    result = 0
    prev = 0
    odd = False
    for pos in boundaries:
        if odd and prev < pos:
            # Bits prev .. pos-1 are all set: 2^prev + ... + 2^(pos-1).
            result += pow(2, pos, mod) - pow(2, prev, mod)
        odd = not odd
        prev = pos
    # Intermediate differences may be negative; normalise once at the end.
    return result % mod


def main() -> None:
    """Read ``N M`` from one line of stdin and print the answer."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()