#pragma GCC target("avx2") #pragma GCC optimize("O3") #pragma GCC optimize("unroll-loops") #include #include // cout, endl, cin #include // string, to_string, stoi #include // vector #include // min, max, swap, sort, reverse, lower_bound, upper_bound #include // pair, make_pair #include // tuple, make_tuple #include // int64_t, int*_t #include // printf #include // map #include // queue, priority_queue #include // set #include // stack #include // deque #include // unordered_map #include // unordered_set #include // bitset #include // isupper, islower, isdigit, toupper, tolower #include #include #include using namespace std; using namespace atcoder; #define rep(i, n) for (int i = 0; i < (int)(n); i++) #define repi(i, a, b) for (int i = (int)(a); i < (int)(b); i++) typedef long long ll; typedef unsigned long long ull; const ll inf=1e18; using graph = vector > ; using P= pair; using vi=vector; using vvi=vector; using vll=vector; using vvll=vector; using vp=vector

; using vvp=vector; using vd=vector; using vvd =vector; //string T="ABCDEFGHIJKLMNOPQRSTUVWXYZ"; //string S="abcdefghijklmnopqrstuvwxyz"; //g++ main.cpp -std=c++17 -I . //cout < bool chmin(T& a, T b) { if (a > b) { a = b; return true; } else return false; } template bool chmax(T& a, T b) { if (a < b) { a = b; return true; } else return false; } // https://youtu.be/L8grWxBlIZ4?t=9858 // https://youtu.be/ERZuLAxZffQ?t=4807 : optimize // https://youtu.be/8uowVvQ_-Mo?t=1329 : division ll mod =998244353; ll sqrt_(ll x) { ll l = 0, r = ll(3e9)+1; while (l+1=0 && y>=0 && x struct Matrix { int h, w; vector> d; Matrix() {} Matrix(int h, int w, T val=0): h(h), w(w), d(h, vector(w,val)) {} Matrix& unit() { // assert(h == w); rep(i,h) d[i][i] = 1; return *this; } const vector& operator[](int i) const { return d[i];} vector& operator[](int i) { return d[i];} Matrix operator*(const Matrix& a) const { // assert(w == a.h); Matrix r(h, a.w); rep(i,h)rep(k,w)rep(j,a.w) { r[i][j] += d[i][k]*a[k][j]; } return r; } Matrix pow(long long t) const { // assert(h == w); if (!t) return Matrix(h,h).unit(); if (t == 1) return *this; Matrix r = pow(t>>1); r = r*r; if (t&1) r = r*(*this); return r; } // https://youtu.be/-j02o6__jgs?t=11273 /* mint only mint det() { assert(h == w); mint res = 1; rep(k,h) { for (int i = k; i < h; ++i) { if (d[i][k] == 0) continue; if (i != k) { swap(d[i],d[k]); res = -res; } } if (d[k][k] == 0) return 0; res *= d[k][k]; mint inv = mint(1)/d[k][k]; rep(j,h) d[k][j] *= inv; for (int i = k+1; i < h; ++i) { mint c = d[i][k]; for (int j = k; j < h; ++j) d[i][j] -= d[k][j]*c; } } return res; } //*/ }; int MOD = 998244353; vll fact, fact_inv, inv; void init_nCk(int n) { fact.resize(n+1); fact_inv.resize(n+1); inv.resize(n+1); fact[0] = fact[1] = 1; fact_inv[0] = fact_inv[1] = 1; inv[1] = 1; for (int i = 2; i<=n; i++) { fact[i] = fact[i - 1] * i % MOD; inv[i] = MOD - inv[MOD % i] * (MOD / i) % MOD; fact_inv[i] = fact_inv[i - 1] * inv[i] % MOD; } } ll nCk(int n, int k) { assert(!(n < k)); assert(!(n < 0 || k < 0)); return fact[n] * (fact_inv[k] * fact_inv[n - k] % MOD) % MOD; } using ve=vector; using vve=vector; using mint=modint998244353; using vm=vector; void solve(int test){ init_nCk(500000); int m,n; cin >> m >> n; vi a(n);rep(i,n)cin >> a[i]; vi cnt(m+1); rep(i,n)cnt[a[i]]++; vp p; { vi v; for(int i=1; i<=m; i++)if(cnt[i])v.push_back(cnt[i]); vi cnt2(n+1); for(auto u:v)cnt2[u]++; for(int i=1; i<=n; i++)if(cnt2[i])p.push_back(P(i,cnt2[i])); } vm kai(n+1,1);for(int i=1; i<=n; i++)kai[i]=kai[i-1]*i; vm kaiinv(n+1,1);for(int i=1; i<=n; i++)kaiinv[i]=kai[i].inv(); vll g(n+1,1); vll g2(n+1,1); mint now=1; mint now2=1; for(int i=0; i<=n; i++){ now=1; for(auto u:p){ int num=u.first; now2=nCk(i,num)*kai[num]; now2=pow_pow(now2.val(),u.second,mod); now*=now2; } now*=kaiinv[i]; g[i]=now.val(); } mint base=1; for(int i=0; i<=n; i++){ now=base; now*=kaiinv[i]; g2[i]=now.val(); base*=(mod-1); } auto conv=convolution(g,g2); mint ans=0; for(int i=1; i<=n; i++){ ans+=kai[i]*conv[i]; } cout << ans.val() << endl; } //g++ main.cpp -std=c++17 -I . int main(){cin.tie(0);ios::sync_with_stdio(false); int t=1; //cin >> t; rep(test,t)solve(test); }