MOD = 998244353


def solve(n, k, mod=MOD):
    """Return sum(ways[0..n]) % mod for the DP below.

    ways[0] = 1 and, for 1 <= i <= n, ways[i] = sum(ways[j] for 0 <= j <= i - k).
    For k >= 1, ways[i] therefore counts the increasing chains
    0 = p0 < p1 < ... < pm = i in which every step is at least k, and the
    result counts all such chains ending anywhere in [0, n].

    Args:
        n: largest position considered (n >= 0).
        k: minimum step length (assumed >= 1; other values are not validated).
        mod: modulus applied to every stored value.

    Returns:
        The sum above, reduced into range(mod).
    """
    # dp starts as a difference array and is converted into actual values in
    # place: by the time the loop reaches index i, every contribution to dp[i]
    # has been added, so dp[i] == ways[i].  One spare slot keeps the dp[1]
    # write valid when n == 0.
    dp = [0] * (n + 2)
    dp[0] = 1
    dp[1] = -1  # cancel dp[0]'s prefix carry so the start value counts only at 0
    for i in range(n):
        if i + k <= n:
            # ways[i] contributes to every position >= i + k.
            dp[i + k] = (dp[i + k] + dp[i]) % mod
        # Carry the running prefix forward so dp[i + 1] becomes ways[i + 1].
        dp[i + 1] = (dp[i + 1] + dp[i]) % mod
    # Each dp[0..n] is already in [0, mod), so one final reduction suffices.
    return sum(dp[: n + 1]) % mod


def main():
    """Read ``n k`` from one line of stdin and print solve(n, k)."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()