結果
問題 |
No.271 next_permutation (2)
|
ユーザー |
|
提出日時 | 2025-05-24 16:05:04 |
言語 | C++23 (gcc 13.3.0 + boost 1.87.0) |
結果 |
AC
|
実行時間 | 720 ms / 2,000 ms |
コード長 | 4,439 bytes |
コンパイル時間 | 3,867 ms |
コンパイル使用メモリ | 295,372 KB |
実行使用メモリ | 8,664 KB |
最終ジャッジ日時 | 2025-05-24 16:05:12 |
合計ジャッジ時間 | 7,892 ms |
ジャッジサーバーID (参考情報) |
judge3 / judge2 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
other | AC * 21 |
ソースコード
#include <bits/stdc++.h> using namespace std; #define int long long #define rep(i, j, k) for(int i = (j); i <= (k); i++) #define per(i, j, k) for(int i = (j); i >= (k); i--) #define pb emplace_back #define fi first #define se second using vi = vector<int>; using pi = pair<int, int>; template<typename T0, typename T1> bool chmin(T0 &x, const T1 &y){ if(y < x){x = y; return true;} return false; } template<typename T0, typename T1> bool chmax(T0 &x, const T1 &y){ if(x < y){x = y; return true;} return false; } template<typename T> void debug(char *s, T x){ cerr << s <<" = "<< x <<endl; } template<typename T, typename ...Ar> void debug(char *s, T x, Ar... y){ int dep = 0; while(!(*s == ',' && dep == 0)){ if(*s == '(') dep++; if(*s == ')') dep--; cerr << *s++; } cerr <<" = "<< x <<","; debug(s + 1, y...); } #define gdb(...) debug((char*)#__VA_ARGS__, __VA_ARGS__) using u32 = uint32_t; using u64 = uint64_t; constexpr int mod = 1e9 + 7; struct mint{ u32 x; mint(): x(0){} mint(int _x){ _x %= mod; if(_x < 0) _x += mod; x = _x; } u32 val()const { return x; } mint qpow(int y = mod - 2)const { assert(y >= 0); mint x = *this, res = 1; while(y){ if(y % 2) res = res * x; x = x * x; y /= 2; } return res; } mint& operator += (const mint &B){ if((x += B.x) >= mod) x -= mod; return *this; } mint& operator -= (const mint &B){ if((x -= B.x) >= mod) x += mod; return *this; } mint& operator *= (const mint &B){ x = (u64)x * B.x % mod; return *this; } mint& operator /= (const mint &B){ return *this *= B.qpow(); } friend mint operator + (const mint &A, const mint &B){ return mint(A) += B; } friend mint operator - (const mint &A, const mint &B){ return mint(A) -= B; } friend mint operator * (const mint &A, const mint &B){ return mint(A) *= B; } friend mint operator / (const mint &A, const mint &B){ return mint(A) /= B; } mint operator-(){ return mint() - *this; } }; struct BIT{ vector<mint> s; BIT(int n){ s.resize(n); } void add(int x, mint w){ x++; while(x <= (int)s.size()){ s[x - 1] += w; x += x & -x; } } mint sum(int x){ mint res = 0; while(x){ res += s[x - 1]; x -= x & - x; } return res; } }; constexpr int inf = 4e18; void solve(){ int n, k; cin >> n >> k; mint sum = mint(k) * n * (n - 1) / 2; mint ans = 0; if(n <= 20){ int prd = 1; rep(i, 1, n){ prd *= i; } ans += mint(n * (n - 1) / 2) * mint(prd / 2) * mint(k / prd); k %= prd; if(k == 0){ cout << ans.val() <<'\n'; return; } } vi P(n); rep(i, 0, n - 1){ cin >> P[i]; } BIT T(n); vector<mint> pinv(n + 1), psum(n + 1); rep(i, 0, n - 1){ pinv[i + 1] = pinv[i] + T.sum(P[i]); psum[i + 1] = psum[i] + P[i]; T.add(P[i], 1); } auto slv = [&](vi A, vi B){ BIT T(n); mint cnt0 = 0, cnt1 = 0; int fac = 1; for(int x:A){ T.add(x, 1); cnt0 += T.sum(x); } for(int x:B){ cnt0 += T.sum(x); } cnt1 = mint(B.size() * (B.size() - 1)) / 4; rep(i, 1, (int)B.size()){ fac *= i; } k -= fac; ans += (cnt0 + cnt1) * fac; }; ans += pinv[n]; k--; set<int> st; int fac = 1; int i = n - 1; while(i >= 0){ st.insert(P[i]); auto it = st.upper_bound(P[i]); while(it != st.end() && k >= fac){ P[i] = *it; auto itx = st.begin(); rep(j, i + 1, n - 1){ if(it == itx) itx++; P[j] = *itx++; } slv(vi{P.begin(), P.begin() + i + 1}, vi{P.begin() + i + 1, P.end()}); it++; } if(it != st.end()){ P[i] = *it; auto itx = st.begin(); rep(j, i + 1, n - 1){ if(it == itx) itx++; P[j] = *itx++; } break; } if(i == 0){ sort(P.begin(), P.end()); } fac = min((__int128)inf, (__int128)fac * (n - i)); i--; } chmax(i, n - 21); gdb(k); fac = 1; per(j, n - 1, i + 1){ fac *= n - j; } st.clear(); rep(j, i + 1, n - 1){ st.insert(P[j]); } while((++i) < n){ fac /= n - i; auto it = st.begin(); while(it != st.end() && k >= fac){ P[i] = *it; auto itx = st.begin(); rep(j, i + 1, n - 1){ if(it == itx) itx++; P[j] = *itx++; } slv(vi{P.begin(), P.begin() + i + 1}, vi{P.begin() + i + 1, P.end()}); it++; } assert(it != st.end()); P[i] = *it; st.erase(it); } gdb(k); ans = sum - ans; cout << ans.val() <<'\n'; } signed main(){ #ifdef LOCAL freopen(".in", "r", stdin); freopen(".out", "w", stdout); #endif ios::sync_with_stdio(0); cin.tie(0); solve(); }