"""Read two integers ``n`` and ``l`` and print ``(2**ceil(n/l) - 1) mod 998244353``."""

MOD = 998244353


def solve(n, divisor):
    """Return ``(2**m - 1) % MOD`` where ``m = (n + divisor - 1) // divisor``.

    For a positive ``divisor`` the exponent ``m`` is the ceiling of
    ``n / divisor``.

    Args:
        n: The first input integer.
        divisor: The second input integer (named ``l`` in the input line).

    Returns:
        The value ``2**m - 1`` reduced modulo ``MOD``.

    Raises:
        ZeroDivisionError: If ``divisor`` is zero.
    """
    m = (n + divisor - 1) // divisor
    # Three-argument pow reduces modulo MOD at every squaring step, so the
    # intermediate values stay small even when m is huge; evaluating 2**m
    # first would materialise an m-bit integer before the reduction.
    # The trailing "% MOD" keeps the result non-negative if the power
    # happens to be congruent to 0 (e.g. were MOD ever changed to 1 or 2).
    return (pow(2, m, MOD) - 1) % MOD


def main():
    """Read ``n`` and ``l`` from one line of standard input and print the answer."""
    n, divisor = map(int, input().split())
    print(solve(n, divisor))


if __name__ == "__main__":
    main()