MOD = 998244353 n, m = map(int, input().split()) if n == 0: print(0) exit() if m <= 100: tot = 0 for _ in range(m): tot ^= n n <<= 1 print(tot % MOD) else: le = n.bit_length() tot = 0 for _ in range(le): tot ^= n n <<= 1 ans = 0 max_ = le + m - 2 for i in range(le - 1): ans += (tot >> i & 1) * pow(2, i, MOD) ans %= MOD i = le - 1 ans -= (tot >> i & 1) * pow(2, i, MOD) ans %= MOD i = le ans += (tot >> (i - 1) & 1) * pow(2, max_ - (2 * le - 2 - i), MOD) ans %= MOD for i in range(le, 2 * le - 1): ans += (tot >> i & 1) * pow(2, max_ - (2 * le - 2 - i), MOD) ans %= MOD print(ans)