import java.util.*;
import java.lang.*;
import java.io.*;

/**
 * Reads two integers {@code N} and {@code L} from standard input and prints
 * {@code (2^ceil(N / L) - 1) mod 998244353}.
 */
public class Main {
    /** Modulus applied to the printed result. */
    private static final long MOD = 998244353L;

    /**
     * Program entry point: reads {@code N} and {@code L}, then prints the result.
     *
     * @param args unused
     * @throws java.lang.Exception if reading the input fails
     */
    public static void main(String[] args) throws java.lang.Exception {
        // Input
        long n;
        long l;
        try (Scanner sc = new Scanner(System.in)) {
            n = sc.nextLong();
            l = sc.nextLong();
        }

        // Output
        System.out.println(countOperations(n, l));
    }

    /** Computes {@code (2^ceil(n / l) - 1) mod MOD}. */
    private static long countOperations(long n, long l) {
        // Ceiling division without the overflow risk of (n + l - 1) / l.
        long chunks = n / l;
        if (n % l != 0) {
            chunks++;
        }
        long pow = myPow(2L, chunks, MOD);
        // Adding MOD before reducing keeps the result non-negative for any pow in [0, MOD).
        return (pow - 1 + MOD) % MOD;
    }

    /**
     * Computes {@code x^n mod m} by iterative binary exponentiation.
     *
     * <p>The base is reduced into {@code [0, m)} before any multiplication, so large or
     * negative bases are handled. Intermediate products are plain {@code long}
     * multiplications, so {@code m} must be small enough that {@code (m - 1)^2} fits in a
     * {@code long}.
     *
     * @param x the base
     * @param n the exponent; must be non-negative
     * @param m the modulus; must be positive
     * @return {@code x^n mod m}, in the range {@code [0, m)}
     * @throws IllegalArgumentException if {@code n < 0} or {@code m <= 0}
     */
    public static long myPow(long x, long n, long m) {
        if (n < 0) {
            throw new IllegalArgumentException("exponent must be non-negative: " + n);
        }
        if (m <= 0) {
            throw new IllegalArgumentException("modulus must be positive: " + m);
        }

        long base = Math.floorMod(x, m);
        long result = 1 % m; // yields 0 when m == 1
        long exp = n;
        while (exp > 0) {
            if ((exp & 1L) == 1L) {
                result = result * base % m;
            }
            base = base * base % m;
            exp >>= 1;
        }
        return result;
    }
}