// Computes the binomial coefficient C(n, m) modulo the prime 998244353 via
// Lucas's theorem. Reads two integers n and m from stdin and prints the result.
#include <cstdio>

using ll = long long;

const ll maxn = 2000005;   // factorial tables cover indices 0..maxn
const ll mod = 998244353;  // prime modulus (Fermat inverse and Lucas rely on primality)

ll s[maxn + 5];  // s[i] = i! mod mod, filled by init()
ll p[maxn + 5];  // p[i] = (i!)^{-1} mod mod, filled by init()

// Returns a^b mod mod by binary exponentiation (returns 1 for b <= 0).
ll qpow(ll a, ll b) {
    ll ans = 1;
    ll cur = a % mod;  // reduce the base first so cur * cur cannot overflow
    while (b > 0) {
        if (b & 1) ans = ans * cur % mod;
        cur = cur * cur % mod;
        b >>= 1;
    }
    return ans;
}

// Modular inverse by Fermat's little theorem; a must not be a multiple of mod.
ll inv(ll a) {
    return qpow(a, mod - 2);
}

// Fills the factorial table s[] and the inverse-factorial table p[].
// Only one modular exponentiation is needed: p[i-1] = p[i] * i.
void init() {
    s[0] = 1;
    for (ll i = 1; i <= maxn; i++) s[i] = s[i - 1] * i % mod;
    p[maxn] = inv(s[maxn]);  // maxn < mod, so s[maxn] is invertible
    for (ll i = maxn; i > 0; i--) p[i - 1] = p[i] * i % mod;
}

// C(n, m) mod mod for a single Lucas digit, i.e. 0 <= n < mod (guaranteed by
// luc). Returns 0 when m < 0 or m > n. For n <= maxn the tables are used;
// otherwise the falling factorial and m! are multiplied out, which costs
// O(min(m, n - m)) multiplications. Because n < mod, m! is never divisible
// by mod and its inverse exists.
ll cal(ll n, ll m) {
    if (m < 0 || m > n) return 0;
    if (n <= maxn) return s[n] * p[m] % mod * p[n - m] % mod;
    if (m > n / 2) m = n - m;  // C(n, m) == C(n, n - m): take the shorter product
    ll num = 1;  // n * (n-1) * ... * (n-m+1)
    ll den = 1;  // m!
    for (ll i = 1; i <= m; i++) {
        num = num * ((n - i + 1) % mod) % mod;
        den = den * i % mod;
    }
    return num * inv(den) % mod;
}

// Lucas's theorem: C(n, m) mod p equals the product, over the base-p digits
// n_i and m_i, of C(n_i, m_i) mod p. Loops until m's digits are exhausted.
ll luc(ll n, ll m) {
    ll res = 1;
    while (m != 0) {
        res = res * cal(n % mod, m % mod) % mod;
        n /= mod;
        m /= mod;
    }
    return res;
}

int main() {
    ll n, m;
    if (std::scanf("%lld%lld", &n, &m) != 2) return 1;  // reject malformed input
    init();
    // Symmetry C(n, m) == C(n, n - m). If m > n this makes m negative, and
    // luc/cal then yield 0, as required.
    if (m > n / 2) m = n - m;
    std::printf("%lld\n", luc(n, m));
    return 0;
}