"""Read two integers ``N L`` from stdin and print ``(2**ceil(N / L) - 1) % 998244353``."""
from sys import stdin

MOD = 998244353


def count_mod(total, chunk, mod=MOD):
    """Return ``(2**q - 1) % mod`` where ``q = (total + chunk - 1) // chunk``.

    For a positive ``chunk`` the quotient ``q`` is ``total / chunk`` rounded
    up to the next integer.

    Args:
        total: The first integer from the input line (``N``).
        chunk: The second integer from the input line (``L``); must be non-zero.
        mod: Modulus applied to the result; defaults to 998244353.

    Returns:
        The value ``(2**q - 1)`` reduced into the range ``[0, mod)``.

    Raises:
        ZeroDivisionError: If ``chunk`` is 0.
    """
    # Integer ceiling division; avoids floating point so huge inputs stay exact.
    quotient = (total + chunk - 1) // chunk
    # Three-argument pow keeps intermediates below ``mod``; the trailing
    # ``% mod`` normalises the result after subtracting 1.
    return (pow(2, quotient, mod) - 1) % mod


def main():
    """Parse ``N L`` from the first line of stdin and print the answer."""
    total, chunk = map(int, stdin.readline().split())
    print(count_mod(total, chunk))


if __name__ == "__main__":
    main()