"""Compute (N ^ (N << 1) ^ ... ^ (N << (M - 1))) mod 998244353.

Reads two integers N and M from one line of standard input and prints the
XOR of N shifted left by 0, 1, ..., M-1 bits, reduced modulo P.
"""

P = 998244353


def _brute_force(n, m):
    """Return the exact XOR of ``n << i`` for ``i`` in ``range(m)``, mod P."""
    acc = 0
    for i in range(m):
        acc ^= n << i
    return acc % P


def solve(n, m):
    """Return (XOR of ``n << i`` for ``0 <= i < m``) modulo P.

    Write n's bits as n_0 .. n_{k-1} (k = bit length of n). Bit j of the
    result is the XOR of n_d over max(0, j-m+1) <= d <= min(k-1, j).
    When m >= k - 1 the result splits into three disjoint bit ranges:

    * j in [0, k-2]   -> prefix parity n_0 ^ ... ^ n_j             (low)
    * j in [k-1, m-1] -> parity of all of n's bits                 (middle)
    * j in [m, m+k-2] -> suffix parity n_{j-m+1} ^ ... ^ n_{k-1}   (high)

    The middle run may be enormous, so it is summed as a geometric series
    modulo P instead of being materialized as an integer.
    """
    if n == 0:
        # Every shifted copy is 0; also avoids k == 0 below (negative shift).
        return 0
    k = n.bit_length()
    if m < k:
        # The three-range split above requires m >= k - 1.  When m is this
        # small (no longer than n itself) the direct computation is cheap,
        # and it stays correct no matter how wide n is.
        return _brute_force(n, m)

    # Low part: bit j (j < k-1) of XOR_{i<k} (n << i) is n_0 ^ ... ^ n_j.
    low = 0
    for i in range(k):
        low ^= n << i
    low &= (1 << (k - 1)) - 1

    # High part: bit s of XOR_{i<k} (n >> i) is n_s ^ ... ^ n_{k-1}; drop
    # s == 0 (that is the full parity, already counted in the middle run)
    # so bit s-1 lines up with result bit m + s - 1.
    high = 0
    for i in range(k):
        high ^= n >> i
    high &= (1 << k) - 1
    high >>= 1

    # Middle run: bits k-1 .. m-1 all equal the parity of n's set bits,
    # i.e. parity * (2^m - 2^(k-1)).
    parity = bin(n).count("1") % 2
    middle = parity * pow(2, k - 1, P) * (pow(2, m + 1 - k, P) - 1)

    return (high * pow(2, m, P) + middle + low) % P


def main():
    """Read ``N M`` from stdin and print the answer."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()