import java.io.PrintWriter
import java.util.*

/** Modulus applied to every DP entry. */
private const val MOD = 998244353L

/**
 * Reads two integers `n` and `k` from standard input and writes `dp[n]` to this writer, where
 * `dp[0] = 1` and `dp[i] = dp[i - 1] + (if (i >= k) dp[i - k] else 0)`, reduced modulo [MOD].
 */
fun PrintWriter.solve() {
    val n = nextInt()
    val k = nextInt()
    println(evaluateRecurrence(n, k))
}

/**
 * Evaluates `dp[i] = dp[i - 1] + (if (i >= k) dp[i - k] else 0) (mod MOD)` for `i = 1..n` and returns `dp[n]`.
 *
 * For `k >= 1` this presumably counts the ordered ways to write `n` as a sum of parts `1` and `k`.
 * NOTE(review): confirm that interpretation against the problem statement.
 *
 * Uses a primitive [LongArray], so the table is not boxed the way `Array<Long>` would be.
 *
 * @param n index of the requested entry. A negative value makes the array allocation throw.
 * @param k size of the second step. Assumes `k >= 1` (TODO confirm from constraints). For `k <= 0`,
 *   `dp[i - k]` refers to an entry that has not been recomputed yet.
 * @return `dp[n]` in `0 until MOD`.
 */
private fun evaluateRecurrence(n: Int, k: Int): Long {
    // The fill value supplies the base case dp[0] = 1. Entries 1..n are recomputed in increasing order below.
    val dp = LongArray(n + 1) { 1L }
    for (i in 1..n) {
        val jump = if (i >= k) dp[i - k] else 0L
        // Both operands are already < MOD, so their sum cannot overflow a Long.
        dp[i] = (dp[i - 1] + jump) % MOD
    }
    return dp[n]
}

fun main() {
    // Autoflush is off for speed; the explicit flush below emits the buffered output.
    val writer = PrintWriter(System.out, false)
    writer.solve()
    writer.flush()
}

// region Scanner
private var st = StringTokenizer("")
private val br = System.`in`.bufferedReader()

/**
 * Returns the next whitespace-separated token from standard input, reading further lines as needed.
 *
 * @throws NoSuchElementException if the input ends before another token is found.
 */
fun next(): String {
    while (!st.hasMoreTokens()) {
        // readLine() returns null at end of stream. Fail with a clear message instead of an NPE inside StringTokenizer.
        val line = br.readLine() ?: throw NoSuchElementException("unexpected end of input")
        st = StringTokenizer(line)
    }
    return st.nextToken()
}

fun nextInt() = next().toInt()
fun nextLong() = next().toLong()

/** Reads a raw line directly from the reader. It is `null` at end of input, and any tokens buffered in `st` are bypassed. */
fun nextLine() = br.readLine()
fun nextDouble() = next().toDouble()
// endregion