import java.util.Scanner;

/**
 * Reads two integers {@code n} and {@code k} from standard input and prints the
 * value {@code S(n)} of the recurrence below, reduced modulo {@code 998244353}.
 *
 * <pre>
 *   S(0) = 1
 *   S(i) = S(i - 1) + (i >= k ? S(i - k) : 0)   for 1 <= i <= n
 * </pre>
 */
public class Main {

    /** Modulus applied to every running total. */
    private static final int MOD = 998244353;

    /**
     * Program entry point: reads {@code n} then {@code k} (whitespace separated)
     * from standard input and prints a single line containing {@code S(n) mod MOD}.
     *
     * @param args unused
     * @throws Exception kept for signature compatibility; failures surface as
     *     unchecked exceptions (e.g. {@link java.util.InputMismatchException}
     *     on malformed input)
     */
    public static void main(String[] args) throws Exception {
        int n;
        int k;
        // try-with-resources closes the scanner even when parsing fails.
        try (Scanner sc = new Scanner(System.in)) {
            n = sc.nextInt();
            k = sc.nextInt();
        }
        System.out.println(prefixTotal(n, k));
    }

    /**
     * Evaluates the recurrence described on the class for index {@code n}.
     *
     * <p>Only the running totals are stored; the per-index term is consumed
     * immediately, so no second array is needed.
     *
     * <p>NOTE(review): the inputs are not validated. Presumably the problem
     * guarantees {@code n >= 0} and {@code k >= 1} -- confirm against the
     * statement. As written, a negative {@code n} fails on array creation or
     * the first store, {@code k == 0} adds a still-zero slot (result 1), and a
     * negative {@code k} with {@code n >= 1} indexes past the end of the array.
     *
     * @param n index of the total to return
     * @param k offset used by the recurrence's second term
     * @return {@code S(n)} reduced modulo {@link #MOD}
     */
    private static long prefixTotal(int n, int k) {
        long[] sum = new long[n + 1];
        sum[0] = 1;
        for (int i = 1; i <= n; i++) {
            // The look-back term exists only once i has reached k.
            long term = (k <= i) ? sum[i - k] : 0;
            // Both operands are already below MOD, so the addition cannot overflow a long.
            sum[i] = (sum[i - 1] + term) % MOD;
        }
        return sum[n];
    }
}