n, m = map(int, input().split()) D = [0 for _ in range(31)] q = n for i in range(31): q, r = divmod(q, 2) D[i] = r mod = 998244353 ans = 0 res = 0 for i in range(31): res += D[i] if i >= m: res -= D[i - m] res %= 2 ans += res * pow(2, i, mod) ans %= mod if m > 30: ans += res * (pow(2, m, mod) - pow(2, 31, mod)) % mod ans %= mod for i in range(31): res -= D[i] res %= 2 ans += res * pow(2, m + i, mod) ans %= mod else: for i in range(31 - m, 31): res -= D[i] res %= 2 ans += res * pow(2, i + 30, mod) ans %= mod print(ans)