"""Print XOR_{i=0}^{M-1} (N << i) modulo 998244353 for N, M read from stdin."""
import sys
import random
import bisect
from collections import deque, defaultdict
import heapq
from itertools import permutations
from math import gcd

input = lambda: sys.stdin.readline().rstrip()
mi = lambda: map(int, input().split())
li = lambda: list(mi())

mod = 998244353


def brute(N, M):
    """Return the XOR of ``N << i`` for ``i`` in ``range(M)`` by direct looping.

    NOTE(review): the original definition of this helper is cut off in the
    source after ``res ^= N<``; returning the unreduced XOR is an assumption.
    Callers that need the residue should apply ``% mod`` themselves.
    """
    res = 0
    for i in range(M):
        res ^= N << i
    return res


def solve(N, M):
    """Return ``XOR_{i=0}^{M-1} (N << i)`` reduced modulo ``mod``.

    Args:
        N: non-negative integer whose shifted copies are XOR-ed together.
        M: number of shifted copies (shifts ``0 .. M-1``).

    Returns:
        The XOR value modulo 998244353.
    """
    if N == 0:
        return 0

    n = N.bit_length()

    # Small M is computed directly.  The closed form below additionally needs
    # M >= n + 1 (otherwise the "upper" and "lower" bit ranges overlap and the
    # exponent M-1-n becomes negative), so M <= n also takes the direct path.
    if M <= 50 or M <= n:
        return brute(N, M) % mod

    # The result has n+M-1 binary digits:
    #   * the top n digits are `upper`  (bit M-1+t = XOR of bits t..n-1 of N),
    #   * the bottom n digits are `lower` (bit k = XOR of bits 0..k of N),
    #   * every remaining digit, 2^n .. 2^(M-2), equals `d`, the parity of N.
    upper = 0
    lower = 0
    for i in range(n):
        upper ^= N >> i
        lower ^= N << i
    lower &= (1 << n) - 1  # keep only the bottom n digits

    d = 0
    for i in range(n):
        d ^= (N >> i) & 1

    # d * (2^n + 2^(n+1) + ... + 2^(M-2)) = d * 2^n * (2^(M-1-n) - 1)
    middle = d * pow(2, n, mod) * (pow(2, M - 1 - n, mod) - 1) % mod
    res = upper * pow(2, M - 1, mod) + lower + middle
    return res % mod


def main():
    """Read ``N M`` from standard input and print the answer."""
    N, M = mi()
    print(solve(N, M))


if __name__ == "__main__":
    main()