import java.util.*;
import java.lang.*;
import java.io.*;

public class Main {
    /** Modulus applied to the printed answer. */
    static final long DIV = 998244353L;

    /**
     * Reads two integers N and L from standard input and prints
     * (2^ceil(N / L) - 1) mod 998244353.
     */
    public static void main(String[] args) throws java.lang.Exception {
        // Input
        Scanner sc = new Scanner(System.in);
        long N = sc.nextLong();
        long L = sc.nextLong();

        // Output
        System.out.println(solve(N, L));
    }

    /**
     * Computes (2^ceil(n / l) - 1) mod 998244353.
     *
     * <p>The ceiling is computed as {@code n / l} plus one when there is a remainder;
     * this matches a true ceiling for non-negative {@code n} and positive {@code l}
     * (NOTE(review): input constraints are not visible here — confirm they hold).
     * A non-positive exponent yields 2^0 - 1 = 0, as in the original loop.
     *
     * @param n the dividend N
     * @param l the divisor L; must be non-zero
     * @return the answer in the range [0, 998244353)
     * @throws ArithmeticException if {@code l} is zero
     */
    static long solve(long n, long l) {
        long disc = n / l;
        if (n % l != 0) {
            disc++;
        }
        // Exponentiation by squaring: O(log disc), and it never iterates with an
        // int counter, so a disc above Integer.MAX_VALUE cannot overflow into an
        // endless loop.
        long ope = powMod(2L, disc, DIV);
        // Add DIV before reducing so the result stays non-negative in every case.
        return (ope - 1 + DIV) % DIV;
    }

    /**
     * Returns base^exp mod mod, treating exp <= 0 as exponent 0 (result 1 % mod).
     * Requires mod < 2^31 so the intermediate products fit in a long.
     */
    private static long powMod(long base, long exp, long mod) {
        long result = 1 % mod;
        long b = base % mod;
        long e = exp;
        while (e > 0) {
            if ((e & 1L) == 1L) {
                result = result * b % mod;
            }
            b = b * b % mod;
            e >>= 1;
        }
        return result;
    }
}