import sys

MOD = 998244353


def solve(n, l, mod=MOD):
    """Return ``(2 ** ceil(n / l) - 1) % mod``.

    ``n`` is first reduced to ``ceil(n / l)`` using the integer-only idiom
    ``(n + l - 1) // l``, which avoids float rounding on large inputs. The
    count ``2 ** k - 1`` is then taken modulo ``mod``.

    NOTE(review): the combinatorial meaning of the result (e.g. "non-empty
    subsets of ceil(n/l) groups") is inferred from the formula only; the
    problem statement is not visible here.

    Args:
        n: Non-negative integer read as the first input value.
        l: Positive integer divisor read as the second input value
           (``l == 0`` raises ``ZeroDivisionError``, as in the original).
        mod: Modulus for the result; defaults to ``MOD`` (998244353).

    Returns:
        An integer in ``[0, mod)``.
    """
    groups = (n + l - 1) // l  # ceil(n / l) without floating point
    # pow() with three arguments does fast modular exponentiation. The outer
    # "% mod" keeps the result non-negative if pow() returned 0.
    return (pow(2, groups, mod) - 1) % mod


def main():
    """Read ``N L`` from stdin and print ``solve(N, L)``."""
    # Reading whitespace-separated tokens (rather than one line) tolerates
    # N and L being on separate lines; it avoids shadowing builtin input().
    n, l = map(int, sys.stdin.read().split()[:2])
    print(solve(n, l))


if __name__ == "__main__":
    main()