#include #include #include #include #include #include #include using namespace std; using namespace atcoder; using ll = long long; using mint = modint998244353; using vi = vector; using vvi = vector; using vvvi = vector; using vll = vector; using vvll = vector; using vvvll = vector; using vmi = vector; using vvmi = vector; using vvvmi = vector; #define all(a) (a).begin(), (a).end() #define rep2(i, m, n) for (int i = (m); i < (n); ++i) #define rep(i, n) rep2(i, 0, n) #define drep2(i, m, n) for (int i = (m)-1; i >= (n); --i) #define drep(i, n) drep2(i, n, 0) void solve(){ } mint sum(int n){ return mint(2).pow(n+1) - mint(1); } int main(){ int n, m; cin >> n >> m; if(m < 1000){ bitset<1200> bs(n); bitset<1200> a(0); rep(i, m){ a = a ^ bs; bs = bs << 1; } mint ans = 0, base = mint(1); rep(i, 1200){ ans += mint((int)a[i])*base; base *= mint(2); } cout << ans.val() << endl; return 0; } bitset<32> bs(n); int r = 31; while(!bs[r])r--; if(r == 0){ cout << sum(m).val() << endl; return 0; } mint ans = 0; rep(i, r){ int c = 0, d = 0; rep(j, i+1){ c += bs[j]; d += bs[r-j]; } if(c % 2 != 0)ans += mint(2).pow(i); if(d % 2 != 0)ans += mint(2).pow(m + r -1-i); }int e = bs.count(); if(e % 2 != 0){ ans += sum(m-1) - sum(r-1); } cout << ans.val() << endl; return 0; }