"""Compute sum_{i=1}^{N} (bit_length(i) - 1 + popcount(i)) modulo 998244353.

N is read as the first whitespace-separated token on stdin.  For i >= 1,
``bit_length(i) - 1 + popcount(i)`` also equals the number of steps needed to
reduce i to 0 by halving when even and subtracting 1 when odd.
"""
import sys

MOD = 998244353


def popcount_prefix_sum(n):
    """Return the total number of set bits over all integers in ``[0, n]``.

    Bit ``k`` follows a pattern with period ``2**(k + 1)``: ``2**k`` zeros,
    then ``2**k`` ones.  For each bit, count the ones in the complete periods
    that fit in ``[0, n]``, then the ones in the trailing partial period.

    Raises:
        ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    total = 0
    for k in range(n.bit_length()):
        half = 1 << k
        period = half << 1
        full_periods, offset = divmod(n, period)
        # Offsets half..offset of the partial period have bit k set.
        total += full_periods * half + max(0, offset - half + 1)
    return total


def solve(n, mod=MOD):
    """Return ``sum_{i=1}^{n} (bit_length(i) - 1 + popcount(i)) % mod``.

    Uses closed forms, so it runs in O(log n) big-int operations:

    * ``sum_{i=1}^{n} bit_length(i) = L * (n + 1) - (2**L - 1)`` with
      ``L = n.bit_length()``, because for each ``k < L`` exactly
      ``n - 2**k + 1`` values in ``[1, n]`` have a bit at position ``k`` or
      higher.  Subtracting ``n`` gives ``(L - 1) * n + L + 1 - 2**L``.
    * The popcount sum comes from :func:`popcount_prefix_sum`.

    Returns 0 for ``n == 0`` (empty sum).

    Raises:
        ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    length = n.bit_length()
    shifts = (length - 1) * n + length + 1 - (1 << length)
    # Python's % with a positive modulus is already non-negative,
    # so no sign fix-up is needed.
    return (shifts + popcount_prefix_sum(n)) % mod


def main():
    """Read N from stdin and print ``solve(N)``.

    Prints nothing when stdin contains no tokens.
    """
    data = sys.stdin.read().split()
    if not data:
        return
    print(solve(int(data[0])))


if __name__ == '__main__':
    main()