""" """ import sys from sys import stdin from collections import deque import heapq def popcnt(X): ret = 0 while X: ret += X % 2 X //= 2 return ret % 2 N,M = map(int,stdin.readline().split()) NR = N.bit_length() mod = 998244353 if M <= 70: ans = 0 for i in range(M): ans ^= N N *= 2 print (ans % mod) else: minbit = 0 NN = N for i in range(70): minbit ^= NN NN *= 2 maxbit = 0 NN = N for i in range(70): maxbit ^= NN NN //= 2 ans = 0 cnt = popcnt(N) if cnt == 0: pass else: ans += ( pow(2,M-NR+1,mod)-1 ) * pow(2,NR-1,mod) ans %= mod for i in range(NR-1): ans += minbit & (2**i) for i in range(1,NR): ans += (maxbit & (2**i)) * pow(2,M-1,mod) ans %= mod print (ans)