def solve(n: int, k: int, mod: int = 998244353) -> int:
    """Return ``(2 ** ceil(n / k) - 1) % mod``.

    Args:
        n: Dividend of the ceiling division that forms the exponent.
        k: Divisor of the ceiling division.
           NOTE(review): presumably k >= 1 per the problem constraints.
           k == 0 raises ZeroDivisionError; a negative k gives a negative
           exponent, and three-argument ``pow`` then takes a modular inverse.
        mod: Modulus applied to the result (defaults to 998244353).

    Returns:
        The value ``2 ** ceil(n / k) - 1`` reduced into ``[0, mod)``.
    """
    # ceil(n / k) in pure integer arithmetic; avoids float rounding for huge n.
    exponent = (n + k - 1) // k
    # The trailing `% mod` keeps the result non-negative when pow(...) is 0,
    # which can happen for a general (e.g. even) modulus.
    return (pow(2, exponent, mod) - 1) % mod


def main() -> None:
    """Read ``n`` and ``k`` from one line of stdin and print the answer."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()