"""Print ``2 ** ceil(n / l) - 1`` modulo 998244353 for two integers read from stdin."""

MOD = 998244353


def solve(n: int, divisor: int, mod: int = MOD) -> int:
    """Return ``2 ** ceil(n / divisor) - 1`` reduced modulo *mod*.

    Args:
        n: Dividend of the ceiling division.
        divisor: Divisor of the ceiling division. Assumed to be a positive
            integer; ``divisor == 0`` raises ``ZeroDivisionError``.
        mod: Modulus for the result; defaults to 998244353.

    Returns:
        An integer in ``[0, mod)``.
    """
    # Integer ceiling division without floats; exact for positive divisors.
    m = (n + divisor - 1) // divisor
    # Three-argument pow performs fast modular exponentiation. Reducing after
    # the subtraction keeps the result in [0, mod) even if pow(...) is 0
    # (impossible for the default prime modulus, possible for e.g. mod=2).
    return (pow(2, m, mod) - 1) % mod


def main() -> None:
    """Read ``n`` and ``l`` from one line of stdin and print the answer."""
    n, divisor = map(int, input().split())
    print(solve(n, divisor))


if __name__ == "__main__":
    main()