n, l = map(int, input().split())
print(pow(2, (n + l - 1) // l, 998244353) - 1)  # 2^ceil(n/l) - 1, modulo 998244353