"""Read N and L from stdin and print (2**ceil(N / L) - 1) modulo 998244353."""

MOD = 998244353


def solve(n: int, divisor: int, mod: int = MOD) -> int:
    """Return ``2**c - 1`` reduced modulo *mod*, where ``c`` is ``n / divisor`` rounded up.

    The rounding matches the original script exactly: ``c = n // divisor``,
    plus one when the remainder ``n % divisor`` is positive (ceiling
    division when ``divisor > 0``).

    Args:
        n: Dividend (the script's ``N``).
        divisor: Divisor (the script's ``L``); must be non-zero.
        mod: Modulus applied to the result; defaults to 998244353.

    Returns:
        ``(2**c - 1) % mod``, always in ``[0, mod)``.

    Raises:
        ZeroDivisionError: if ``divisor`` is 0.
    """
    quotient, remainder = divmod(n, divisor)
    count = quotient + (remainder > 0)
    # The trailing ``% mod`` keeps the value non-negative even for a modulus
    # where pow() can yield 0 (e.g. mod=1). For the default prime modulus,
    # pow(2, c, MOD) is never 0, so this changes nothing.
    return (pow(2, count, mod) - 1) % mod


def main() -> None:
    """Read ``N L`` from one line of stdin and print ``solve(N, L)``."""
    n, divisor = map(int, input().split())
    print(solve(n, divisor))


if __name__ == "__main__":
    main()