n, m = map(int, input().split()) n = -(-n // m) mod = 998244353 """#def pow(x, n): ans = 1 while n: if n % 2: ans *= x ans %= mod x *= x n >>= 1 return ans % mod""" print((pow(2, n, mod=mod) - 1))