import sys def input(): return sys.stdin.readline().strip() def mapint(): return map(int, input().split()) sys.setrecursionlimit(10**9) N, L = mapint() mod = 998244353 N = -(-N//L) print((pow(2, N, mod)-1)%mod)