N,L = map(int,input().split()) P = 998244353 n = (N + L -1) // L print((pow(2,n,P)-1)%P)