#include <bits/stdc++.h>
using namespace std;
using ll = long long;

/// Modulus applied to every value of the recurrence computed by f.
ll const mod = 998244353;

/// Read by main(): n is the step of the recurrence in f, m is the index whose
/// value main prints.
int n, m;

/// Returns f(i) modulo `mod` for the linear recurrence (using the global n):
///   f(k) = 1                     for k <= n - 1
///   f(k) = f(k - 1) + f(k - n)   for k >= n
///
/// Computed bottom-up instead of recursively. A memoized recursion needs about
/// i - n + 1 stack frames, which overflows the stack for large i. This version
/// keeps only the n most recent values: O(n) memory, O(i) time, constant
/// stack depth.
///
/// Assumes n >= 1 (n == 0 makes the recurrence ill-defined) — TODO confirm the
/// input constraints guarantee this.
ll f(int i){
  if (i <= n - 1) return 1;  // base case; also covers any negative i

  // window[k % n] holds f(k) for the n most recent k. Just before f(k) is
  // computed, slot k % n still holds f(k - n) and slot (k - 1) % n holds
  // f(k - 1), so f(k) overwrites exactly the value it no longer needs.
  std::vector<ll> window(n, 1);  // f(0) .. f(n - 1) are all 1

  // ll counter so that i == INT_MAX cannot overflow the loop variable.
  for (ll k = n; k <= i; ++k) {
    const ll prev = window[(k - 1) % n];  // copy first: aliases slot when n == 1
    ll& slot = window[k % n];
    slot = (prev + slot) % mod;           // both terms < mod, so no overflow
  }
  return window[i % n];
}


int main(){
  std::cin >> n >> m;

  // m < n: index m lies inside f's base case, so the answer is simply 1.
  // n == 1: also answered with 1 directly, without consulting f.
  // NOTE(review): f's recurrence would give 2^m mod 998244353 for n == 1 —
  // confirm that 1 is what the problem expects here.
  const bool answer_is_one = (m < n) || (n == 1);

  if (answer_is_one) {
    std::cout << 1 << std::endl;
  } else {
    std::cout << f(m) << std::endl;
  }
  return 0;
}