// NOTE(review): this whole file sits on one physical line. Every token after
// `#include` therefore belongs to a single, malformed preprocessing directive,
// and the translation unit cannot compile as written. Several `<...>` spans
// look like they were stripped:
//   - the header name after `#include`;
//   - the element type of `vector fact, fact_inv, inv`;
//   - part of the `rep` macro, which currently reads `for(int i=a;i=b;i--)`.
// The remainder of main() is also cut off after `for(ll i=0;i`.
// Restore the original text before building. TODO: confirm against the
// upstream source.
//
// Contents, in order:
//   ll          Alias for long long.
//   all(v)      Expands to `v.begin(),v.end()`.
//   INF         2e18 (non-const global).
//   MOD         998244353, a prime modulus.
//   fact, fact_inv, inv
//               Tables filled by init_nCk; presumably vector<long long> (confirm).
//   init_nCk(SIZE)
//               Resizes the three tables to SIZE + 5 and seeds indices 0 and 1.
//               Then, for 2 <= i < SIZE + 5, it computes:
//                 fact[i]     = i! mod MOD
//                 inv[i]      = i^{-1} mod MOD, via
//                               inv[i] = MOD - inv[MOD % i] * (MOD / i) % MOD
//                               (this recurrence relies on MOD being prime and i < MOD)
//                 fact_inv[i] = (i!)^{-1} mod MOD
//   nCk(n, k)   C(n, k) mod MOD. Returns 0 when n < k, n < 0, or k < 0.
//               n is not bounds-checked: it must be < SIZE + 5 from the most
//               recent init_nCk call.
//   nPk(n, k)   n! / (n - k)! mod MOD, with the same guards and bound on n.
//   mod_pow(x, n, mod)
//               x^n mod `mod` by binary exponentiation; returns 1 when n <= 0.
//               x is NOT reduced modulo `mod` first, so x * x must fit in a
//               long long. Callers should keep 0 <= x < mod.
//   main        Reads N and K and builds the tables up to K + 100. It then sets:
//                 ans = C(K, 2) * N * 2 mod MOD
//                 x   = K^(MOD-2) mod MOD
//               x is the modular inverse of K when K is not a multiple of MOD
//               (Fermat's little theorem). The rest of main is not visible here.
//               NOTE(review): N is not reduced mod MOD before the multiply, so a
//               large N could overflow; confirm the input bounds. K is also
//               narrowed from long long to the int parameters of init_nCk and nCk.
#include using namespace std; using ll =long long; #define all(v) v.begin(),v.end() #define rep(i,a,b) for(int i=a;i=b;i--) ll INF=2e18; const ll MOD = 998244353; vector fact, fact_inv, inv; /* fact[i] = i!, inv[i] = 1/i, fact_inv[i] = 1/i! (all mod MOD); filled by init_nCk */ void init_nCk(int SIZE) { fact.resize(SIZE + 5); fact_inv.resize(SIZE + 5); inv.resize(SIZE + 5); fact[0] = fact[1] = 1; fact_inv[0] = fact_inv[1] = 1; inv[1] = 1; for (int i = 2; i < SIZE + 5; i++) { fact[i] = fact[i - 1] * i % MOD; /* modular-inverse recurrence; valid because MOD is prime and i < MOD */ inv[i] = MOD - inv[MOD % i] * (MOD / i) % MOD; fact_inv[i] = fact_inv[i - 1] * inv[i] % MOD; } } /* C(n,k) mod MOD; 0 if n < k or n < 0 or k < 0; n must be < SIZE + 5 */ long long nCk(int n, int k) { if((n < k)) return 0; if ((n < 0 || k < 0)) return 0; return fact[n] * (fact_inv[k] * fact_inv[n - k] % MOD) % MOD; } /* n!/(n-k)! mod MOD; same guards and bound as nCk */ long long nPk(int n, int k) { if((n < k)) return 0; if ((n < 0 || k < 0)) return 0; return fact[n] * fact_inv[n - k] % MOD; } /* x^n mod `mod`; x is not reduced first, so keep 0 <= x < mod to avoid overflow in x*x */ ll mod_pow(ll x,ll n,ll mod) { ll res=1; while(n>0) { if(n&1) { res=res*x%mod; } x=x*x%mod; n>>=1; } return res; } int main() { ios::sync_with_stdio(false); cin.tie(0); ll N,K;cin>>N>>K; init_nCk(K+100); ll ans=nCk(K,2)*N*2%MOD; /* K^(MOD-2) = inverse of K mod MOD (Fermat), assuming K % MOD != 0 */ ll x=mod_pow(K,MOD-2,MOD); for(ll i=0;i