N, K = map(int, input().split()) print(pow(2, -(-N // K), 998244353) - 1)