#include using namespace std; using ll=long long; #define ERASE(vec,s,e) vec.erase(vec.begin()+s, vec.begin()+e) #define ALL(v) (v).begin(), (v).end() #define CIN(v,n) for(int i=0; i<(n); i++) cin >> v[i] const int MOD=998244353; long long modpow(long long a,long long n,long long mod){ long long res=1; a%=mod; while(n>0){ if(n&1)res=res*a%mod; a=a*a%mod; n>>=1; } return res; } long long modinv(long long a,long long mod){ return modpow(a,mod-2,mod); } int main(){ const int MOD=998244353; vector factorial={1,808258749,117153405,761699708,573994984,62402409,511621808,242726978,887890124,875880304,0}; int N,K; cin>>N>>K; K=min(K,N-K); N%=MOD; if(N