# Modulus applied to the running value after every XOR step.
MOD = 998244353


def solve(n: int, m: int) -> int:
    """Fold ``n * 2**i`` for ``i`` in ``range(m)`` into a running XOR.

    Starting from 0, each step XORs ``n << i`` into the accumulator and then
    immediately reduces the accumulator modulo ``MOD``; the reduced value is
    what the next step XORs into.

    Args:
        n: The integer whose shifted copies are combined.
        m: Number of shifts to combine (``i`` runs over ``0 .. m - 1``).

    Returns:
        The accumulator after ``m`` steps; 0 when ``m <= 0`` because the
        loop body never runs.
    """
    acc = 0
    for shift in range(m):
        # n << shift equals n * pow(2, shift) for every non-negative shift.
        acc ^= n << shift
        # NOTE(review): `%` does not distribute over XOR, so reducing here
        # changes the bits seen by later iterations. Kept exactly as the
        # original computes it -- confirm against the problem statement.
        acc %= MOD
    return acc


def main() -> None:
    """Read ``N M`` from one line of stdin and print ``solve(N, M)``."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()