import java.util.*;

/**
 * Reads two integers {@code n} and {@code l} from standard input and prints
 * {@code (2^ceil(n / l) - 1) mod 998244353}.
 */
public class Main {
    static final int MOD = 998244353;

    public static void main(String[] args) {
        try (Scanner sc = new Scanner(System.in)) {
            long n = sc.nextLong();
            long l = sc.nextLong();
            long count = blockCount(n, l);
            // pow returns a value in [0, MOD), so adding MOD keeps the difference non-negative.
            System.out.println((pow(2, count) - 1 + MOD) % MOD);
        }
    }

    /**
     * Returns {@code ceil(n / l)} for a positive divisor {@code l}.
     *
     * <p>Unlike {@code (n + l - 1) / l}, this form cannot overflow for large {@code n} and
     * {@code l}. Java integer division truncates toward zero, which is already the ceiling for
     * negative quotients; only a positive remainder needs rounding up.
     *
     * @param n dividend
     * @param l divisor, expected to be positive (the result is not a ceiling for negative l)
     * @return the smallest integer k with {@code k * l >= n}
     * @throws ArithmeticException if {@code l} is zero
     */
    static long blockCount(long n, long l) {
        return n / l + (n % l > 0 ? 1 : 0);
    }

    /**
     * Computes {@code x^y mod MOD} by iterative square-and-multiply.
     *
     * @param x base; any long value, reduced into [0, MOD) before use so intermediate products
     *     stay below 2^63
     * @param y exponent, must be non-negative
     * @return {@code x^y mod MOD}, in the range [0, MOD)
     * @throws IllegalArgumentException if {@code y} is negative
     */
    static long pow(long x, long y) {
        if (y < 0) {
            throw new IllegalArgumentException("negative exponent: " + y);
        }
        long base = Math.floorMod(x, (long) MOD);
        long result = 1;
        long exp = y;
        while (exp > 0) {
            if ((exp & 1) == 1) {
                result = result * base % MOD;
            }
            base = base * base % MOD;
            exp >>= 1;
        }
        return result;
    }
}