#!/usr/bin/env pypy3
"""Compute (n * 2**0) XOR (n * 2**1) XOR ... XOR (n * 2**(m-1)) modulo 998244353.

Reads two integers ``n m`` from stdin and prints the result.
"""
from collections import Counter, defaultdict, deque
import itertools
import re
import math
from functools import reduce
import operator
import bisect
from heapq import *
import functools
import sys

MOD = 998244353


def solve(n: int, m: int, mod: int = MOD) -> int:
    """Return ``XOR(n << i for i in range(m)) % mod``.

    Let L = number of binary digits of n.  Bit k of the XOR is the parity of
    n's bits in positions max(0, k - m + 1) .. min(k, L - 1).  When m is large
    relative to L those windows fall into three regimes, each expressible with
    prefix parities of n's bits, which lets huge m be handled in O(L).

    Args:
        n: non-negative integer being shifted.
        m: number of shifted copies to XOR together (0 gives 0).
        mod: modulus applied to the result.

    Raises:
        ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError("n must be non-negative")
    bits = bin(n)[2:]  # MSB first; "0" for n == 0
    length = len(bits)

    if m < 2 * length + 5:
        # Small m: the direct XOR touches at most ~2L+5 shifted copies.
        acc = 0
        for i in range(m):
            acc ^= n << i
        return acc % mod

    # low_parity[i]: parity of n's bits 0..i (counted from the LSB).
    low_parity = list(
        itertools.accumulate((int(c) for c in reversed(bits)), operator.xor)
    )
    # high_parity[i]: parity of the top i+1 bits of n (counted from the MSB).
    high_parity = list(itertools.accumulate((int(c) for c in bits), operator.xor))

    ans = 0
    for i in range(length - 1):
        # Result bit i (< L-1): window is bits 0..i.
        ans += low_parity[i] * pow(2, i, mod)
        # Result bit m+L-2-i (>= m): window is the top i+1 bits.
        ans += high_parity[i] * pow(2, m + length - 2 - i, mod)
    # Result bits L-1 .. m-1: window covers all of n, so each bit equals the
    # total parity; their sum is parity * (2**m - 2**(L-1)).
    total_parity = low_parity[-1]
    ans += total_parity * (pow(2, m, mod) - pow(2, length - 1, mod))
    return ans % mod  # Python's % yields a non-negative result


def main() -> None:
    """Read ``n m`` from stdin and print ``solve(n, m)``."""
    n, m = map(int, sys.stdin.readline().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()