結果

問題 No.1138 No Bingo!
ユーザー fastmathfastmath
提出日時 2020-08-02 17:56:54
言語 C++17(clang)
(17.0.6 + boost 1.87.0)
結果
AC  
実行時間 340 ms / 3,000 ms
コード長 12,645 bytes
コンパイル時間 4,953 ms
コンパイル使用メモリ 167,564 KB
実行使用メモリ 33,156 KB
最終ジャッジ日時 2024-07-19 03:43:38
合計ジャッジ時間 9,195 ms
ジャッジサーバーID
(参考情報)
judge5 / judge1
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 7 ms
6,400 KB
testcase_01 AC 6 ms
6,528 KB
testcase_02 AC 6 ms
6,528 KB
testcase_03 AC 104 ms
19,328 KB
testcase_04 AC 202 ms
32,200 KB
testcase_05 AC 6 ms
6,528 KB
testcase_06 AC 6 ms
6,528 KB
testcase_07 AC 6 ms
6,528 KB
testcase_08 AC 7 ms
6,528 KB
testcase_09 AC 106 ms
19,996 KB
testcase_10 AC 186 ms
31,144 KB
testcase_11 AC 199 ms
32,032 KB
testcase_12 AC 60 ms
13,480 KB
testcase_13 AC 24 ms
7,972 KB
testcase_14 AC 195 ms
31,816 KB
testcase_15 AC 157 ms
19,484 KB
testcase_16 AC 195 ms
31,164 KB
testcase_17 AC 340 ms
32,740 KB
testcase_18 AC 331 ms
32,548 KB
testcase_19 AC 192 ms
21,060 KB
testcase_20 AC 333 ms
33,156 KB
testcase_21 AC 183 ms
30,616 KB
testcase_22 AC 178 ms
30,116 KB
testcase_23 AC 14 ms
7,120 KB
testcase_24 AC 181 ms
20,736 KB
testcase_25 AC 86 ms
17,760 KB
testcase_26 AC 63 ms
13,332 KB
testcase_27 AC 99 ms
19,276 KB
testcase_28 AC 56 ms
13,392 KB
testcase_29 AC 56 ms
13,288 KB
権限があれば一括ダウンロードができます
コンパイルメッセージ
main.cpp:182:70: warning: operator '<<' has lower precedence than '-'; '-' will be evaluated first [-Wshift-op-parentheses]
  182 |     if (power(root, 1 << max_base) == 1 && power(root, 1 << max_base - 1) != 1) {
      |                                                          ~~ ~~~~~~~~~^~~
main.cpp:182:70: note: place parentheses around the '-' expression to silence this warning
  182 |     if (power(root, 1 << max_base) == 1 && power(root, 1 << max_base - 1) != 1) {
      |                                                                      ^  
      |                                                             (           )
main.cpp:199:50: warning: operator '<<' has lower precedence than '-'; '-' will be evaluated first [-Wshift-op-parentheses]
  199 |     rev[i] = rev[i >> 1] >> 1 | (i & 1) << nbase - 1;
      |                                         ~~ ~~~~~~^~~
main.cpp:199:50: note: place parentheses around the '-' expression to silence this warning
  199 |     rev[i] = rev[i >> 1] >> 1 | (i & 1) << nbase - 1;
      |                                                  ^  
      |                                            (        )
main.cpp:203:43: warning: operator '<<' has lower precedence than '-'; '-' will be evaluated first [-Wshift-op-parentheses]
  203 |     int z = power(root, 1 << max_base - 1 - base);
      |                           ~~ ~~~~~~~~~~~~~^~~~~~
main.cpp:203:43: note: place parentheses around the '-' expression to silence this warning
  203 |     int z = power(root, 1 << max_base - 1 - base);
      |                                           ^     
      |                              (                  )
main.cpp:204:28: warning: operator '<<' has lower precedence than '-'; '-' will be evaluated first [-Wshift-op-parentheses]
  204 |     for (int i = 1 << base - 1; i < 1 << base; ++i) {
      |                    ~~ ~~~~~^~~
main.cpp:204:28: note: place parentheses around the '-' expression to silence this warning
  204 |     for (

ソースコード

diff #

#include<bits/stdc++.h>
using namespace std;
#define int long long
#define ii pair <int, int>
#define app push_back
#define all(a) a.begin(), a.end()
#define bp __builtin_popcountll
#define ll long long
#define mp make_pair
#define f first
#define s second
#define Time (double)clock()/CLOCKS_PER_SEC
#define debug(x) std::cout << #x << ": " << x << '\n';

// NOTE: the rest of this file relies on `#define int long long`.
// N   - size of the precomputed factorial tables.
// MOD - prime modulus used by the scalar helpers and by struct M.
const int N = 200007, MOD = 998244353;

// Reduces n into the canonical range [0, MOD), also for negative n.
int mod(int n) {
    const int remainder = n % MOD;
    return remainder < 0 ? remainder + MOD : remainder;
}
// Binary exponentiation: returns a^p reduced by mod().
// For p <= 0 the loop body never runs and the result is 1.
int fp(int a, int p) {
    int result = 1;
    int square = a;  // a^(2^k) for the bit currently inspected
    while (p > 0) {
        if (p & 1) {
            result = mod(result * square);
        }
        square = mod(square * square);
        p >>= 1;
    }
    return result;
}
// Modular division a / b, using b^(MOD-2) as the inverse of b.
int dv(int a, int b) {
    const int b_inverse = fp(b, MOD - 2);
    return mod(a * b_inverse);
}

// Residue modulo MOD with overloaded arithmetic.
// Constructors and every operator route their result through mod(), so values
// produced by this type stay in [0, MOD). The field x is public and is read
// directly by callers.
//
// Binary operators are const-qualified so they also work on const operands,
// and compound assignments return *this so they can be chained.
struct M {
    ll x;

    // Wraps an arbitrary (possibly negative) integer.
    M (int x_) {
        x = mod(x_);
    }
    // Zero residue.
    M () {
        x = 0;
    }

    M operator + (M y) const {
        int ans = x + y.x;
        if (ans >= MOD)
            ans -= MOD;
        return M(ans);
    }
    M operator - (M y) const {
        int ans = x - y.x;
        if (ans < 0)
            ans += MOD;
        return M(ans);
    }
    M operator * (M y) const {
        return M(x * y.x % MOD);
    }
    // Division multiplies by y^(MOD-2); dividing by the zero residue yields 0.
    M operator / (M y) const {
        return M(x * fp(y.x, MOD - 2) % MOD);
    }

    // Mixed forms: the integer operand is wrapped into M first.
    M operator + (int y) const {
        return (*this) + M(y);
    }
    M operator - (int y) const {
        return (*this) - M(y);
    }
    M operator * (int y) const {
        return (*this) * M(y);
    }
    M operator / (int y) const {
        return (*this) / M(y);
    }

    // Exponentiation (not xor): x^p via fp().
    M operator ^ (int p) const {
        return M(fp(x, p));
    }

    M& operator += (M y) {
        return *this = *this + y;
    }
    M& operator -= (M y) {
        return *this = *this - y;
    }
    M& operator *= (M y) {
        return *this = *this * y;
    }
    M& operator /= (M y) {
        return *this = *this / y;
    }
    M& operator += (int y) {
        return *this = *this + y;
    }
    M& operator -= (int y) {
        return *this = *this - y;
    }
    M& operator *= (int y) {
        return *this = *this * y;
    }
    M& operator /= (int y) {
        return *this = *this / y;
    }
    M& operator ^= (int p) {
        return *this = *this ^ p;
    }
};

// f[i]   = i!       (mod MOD)
// inv[i] = 1 / i!   (mod MOD)
M f[N], inv[N];

// Fills both tables: factorials going up, inverse factorials going down from
// the single modular inverse of the last factorial.
void prec() {
    f[0] = M(1);
    for (int i = 1; i < N; ++i) {
        f[i] = f[i - 1] * M(i);
    }
    inv[N - 1] = f[N - 1] ^ (MOD - 2);
    for (int i = N - 1; i > 0; --i) {
        inv[i - 1] = inv[i] * M(i);
    }
}
// Binomial coefficient n choose k from the precomputed tables; 0 when n < k.
// Requires prec() to have been called.
M C(int n, int k) {
    return n < k ? M(0) : f[n] * inv[k] * inv[n - k];
}

// Modulus used by everything in namespace faq (same value as MOD above).
const int md = 998244353;
 
namespace faq{
 
// x := x + y, followed by one conditional subtraction of md.
inline void add(int &x, int y) {
  const int total = x + y;
  x = total >= md ? total - md : total;
}
 
// x := x - y, followed by one conditional addition of md.
inline void sub(int &x, int y) {
  const int diff = x - y;
  x = diff < 0 ? diff + md : diff;
}
 
// Product of x and y reduced modulo md, computed in 64-bit.
inline int mul(int x, int y) {
  const long long product = static_cast<long long>(x) * y;
  return product % md;
}
 
// Square-and-multiply: x^y modulo md for y >= 0.
inline int power(int x, int y) {
  int res = 1;
  while (y != 0) {
    if (y & 1) {
      res = mul(res, x);
    }
    y >>= 1;
    x = mul(x, x);
  }
  return res;
}
 
// Modular inverse of a modulo md via the extended Euclidean algorithm.
// The argument is first normalised into [0, md); for a == 0 (mod md) the
// loop does not run and 0 is returned.
inline int inv(int a) {
  int cur = a % md;
  if (cur < 0) {
    cur += md;
  }
  int prev = md;
  // coef_prev / coef_cur track the coefficient of the original a in
  // prev / cur respectively (modulo md).
  int coef_prev = 0;
  int coef_cur = 1;
  while (cur != 0) {
    const int q = prev / cur;
    const int next = prev - q * cur;
    const int coef_next = coef_prev - q * coef_cur;
    prev = cur;
    cur = next;
    coef_prev = coef_cur;
    coef_cur = coef_next;
  }
  return coef_prev < 0 ? coef_prev + md : coef_prev;
}
 
namespace ntt {
// Shared NTT state.
//   base     - tables below are currently valid for transforms of size 2^base
//   root     - element whose multiplicative order is exactly 2^max_base
//   max_base - exponent of 2 in md - 1 (-1 until init() has run)
int base = 1, root = -1, max_base = -1;
// rev: bit-reversal permutation; roots: twiddle factors, level by level.
vector<int> rev = {0, 1}, roots = {0, 1};
 
// Computes max_base and finds the smallest root >= 2 with
// root^(2^max_base) == 1 and root^(2^(max_base-1)) != 1.
void init() {
  max_base = 0;
  for (int odd_part = md - 1; (odd_part & 1) == 0; odd_part >>= 1) {
    ++max_base;
  }
  const int full_order = 1 << max_base;
  const int half_order = 1 << (max_base - 1);
  for (root = 2; ; ++root) {
    if (power(root, full_order) != 1) {
      continue;
    }
    if (power(root, half_order) != 1) {
      break;
    }
  }
}
 
// Grows the rev / roots tables so transforms of size 2^nbase are possible.
// Does nothing when the tables are already large enough; asserts that the
// modulus supports the requested size.
void ensure_base(int nbase) {
  if (max_base == -1) {
    init();
  }
  if (nbase <= base) {
    return;
  }
  assert(nbase <= max_base);
  const int table_size = 1 << nbase;
  rev.resize(table_size);
  for (int i = 0; i < table_size; ++i) {
    // Reversal of i = reversal of i/2 shifted down, plus the low bit of i
    // moved to the top position.
    const int high_bit = (i & 1) << (nbase - 1);
    rev[i] = (rev[i >> 1] >> 1) | high_bit;
  }
  roots.resize(table_size);
  for (; base < nbase; ++base) {
    const int z = power(root, 1 << (max_base - 1 - base));
    const int level_begin = 1 << (base - 1);
    const int level_end = 1 << base;
    for (int i = level_begin; i < level_end; ++i) {
      roots[2 * i] = roots[i];
      roots[2 * i + 1] = mul(roots[i], z);
    }
  }
}
 
// In-place forward transform of a.
// log2 of the size is taken with __builtin_ctz, so a.size() is presumably
// required to be a power of two (callers in this file always pass one).
void dft(vector<int> &a) {
  const int n = a.size();
  const int zeros = __builtin_ctz(n);
  ensure_base(zeros);
  // rev is stored for 2^base entries; drop the surplus low bits.
  const int shift = base - zeros;
  for (int i = 0; i < n; ++i) {
    const int partner = rev[i] >> shift;
    if (i < partner) {
      swap(a[i], a[partner]);
    }
  }
  for (int half = 1; half < n; half <<= 1) {
    const int span = half << 1;
    for (int start = 0; start < n; start += span) {
      for (int k = 0; k < half; ++k) {
        int &lo = a[start + k];
        int &hi = a[start + k + half];
        const int x = lo;
        const int y = mul(hi, roots[half + k]);
        lo = (x + y) % md;
        hi = (x + md - y) % md;
      }
    }
  }
}
 
// Product of polynomials a and b (coefficient vectors, lowest degree first)
// modulo md, computed with the NTT. The result has a.size() + b.size() - 1
// coefficients.
//
// An empty operand yields an empty result. Without the guard, two empty
// inputs made `need` equal to -1 and a.resize(need) threw std::length_error.
vector<int> multiply(vector<int> a, vector<int> b) {
  if (a.empty() || b.empty()) {
    return vector<int>();
  }
  const int need = a.size() + b.size() - 1;
  int nbase = 0;
  while ((1 << nbase) < need) {
    ++nbase;
  }
  ensure_base(nbase);
  const int sz = 1 << nbase;
  a.resize(sz);
  b.resize(sz);
  // Squaring: reuse the single transform instead of doing it twice.
  const bool equal = a == b;
  dft(a);
  if (equal) {
    b = a;
  } else {
    dft(b);
  }
  const int inv_sz = inv(sz);
  for (int i = 0; i < sz; ++i) {
    a[i] = mul(mul(a[i], b[i]), inv_sz);
  }
  // Inverse transform = forward transform of the sequence with indices
  // 1..sz-1 reversed; the 1/sz scaling was folded in above.
  reverse(a.begin() + 1, a.end());
  dft(a);
  a.resize(need);
  return a;
}
 
// Power-series inverse: returns b with a * b == 1 modulo x^n, n = a.size().
// Recursively solves for the first ceil(n/2) coefficients, then applies one
// Newton step b <- b * (2 - a * b) in the transform domain.
// a[0] must be invertible modulo md; inv() returns 0 for a zero constant term.
//
// An empty input now returns an empty vector; it used to recurse forever
// because n == 0 gave m == 0 and never reached the n == 1 base case.
vector<int> inverse(vector<int> a) {
  const int n = a.size();
  if (n == 0) {
    return vector<int>();
  }
  if (n == 1) {
    return vector<int>(1, inv(a[0]));
  }
  const int m = (n + 1) >> 1;
  vector<int> b = inverse(vector<int>(a.begin(), a.begin() + m));
  const int need = n << 1;
  int nbase = 0;
  while ((1 << nbase) < need) {
    ++nbase;
  }
  ensure_base(nbase);
  const int sz = 1 << nbase;
  a.resize(sz);
  b.resize(sz);
  dft(a);
  dft(b);
  const int inv_sz = inv(sz);
  for (int i = 0; i < sz; ++i) {
    // (2 - a*b) * b, pre-scaled by 1/sz for the inverse transform below.
    a[i] = mul(mul(md + 2 - mul(a[i], b[i]), b[i]), inv_sz);
  }
  reverse(a.begin() + 1, a.end());
  dft(a);
  a.resize(n);
  return a;
}
}
 
using ntt::multiply;
using ntt::inverse;
 
// Coefficient-wise a += b modulo md; a grows to b's length when shorter.
vector<int>& operator += (vector<int> &a, const vector<int> &b) {
  a.resize(max(a.size(), b.size()));
  for (size_t idx = 0; idx < b.size(); ++idx) {
    add(a[idx], b[idx]);
  }
  return a;
}
 
// Polynomial sum a + b as a new vector.
vector<int> operator + (const vector<int> &a, const vector<int> &b) {
  vector<int> sum(a);
  sum += b;
  return sum;
}
 
// Coefficient-wise a -= b modulo md; a grows to b's length when shorter.
vector<int>& operator -= (vector<int> &a, const vector<int> &b) {
  const size_t rhs_len = b.size();
  if (a.size() < rhs_len) {
    a.resize(rhs_len);
  }
  for (size_t idx = 0; idx != rhs_len; ++idx) {
    sub(a[idx], b[idx]);
  }
  return a;
}
 
// Polynomial difference a - b as a new vector.
vector<int> operator - (const vector<int> &a, const vector<int> &b) {
  vector<int> diff(a);
  diff -= b;
  return diff;
}
 
// Polynomial product a *= b modulo md.
// Uses the schoolbook O(|a|*|b|) loop when either operand has fewer than 128
// coefficients and the NTT-based multiply() otherwise.
//
// Edge cases handled explicitly:
//  * an empty operand gives an empty product (previously
//    a.size() + b.size() - 1 underflowed for two empty operands, and one
//    empty operand produced a non-empty vector of zeros);
//  * `a *= a` in the schoolbook branch reads from a saved copy, because
//    a.assign() would otherwise zero b through the alias.
vector<int>& operator *= (vector<int> &a, const vector<int> &b) {
  if (a.empty() || b.empty()) {
    a.clear();
    return a;
  }
  if (min(a.size(), b.size()) < 128) {
    const vector<int> lhs = a;
    const vector<int> &rhs = (&a == &b) ? lhs : b;
    a.assign(lhs.size() + rhs.size() - 1, 0);
    for (size_t i = 0; i < lhs.size(); ++i) {
      for (size_t j = 0; j < rhs.size(); ++j) {
        add(a[i + j], mul(lhs[i], rhs[j]));
      }
    }
  } else {
    a = multiply(a, b);
  }
  return a;
}
 
// Polynomial product a * b as a new vector.
vector<int> operator * (const vector<int> &a, const vector<int> &b) {
  vector<int> product(a);
  product *= b;
  return product;
}
 
// Polynomial quotient a /= b (remainder discarded).
// Works on reversed coefficient lists: the reversed quotient equals
// rev(a) * rev(b)^-1 truncated to n - m + 1 terms.
// When a is shorter than b the quotient is empty.
vector<int>& operator /= (vector<int> &a, const vector<int> &b) {
  const int n = a.size();
  const int m = b.size();
  if (n < m) {
    a.clear();
    return a;
  }
  const int quotient_len = n - m + 1;
  // Copy b reversed before a is modified (a and b may be the same object).
  vector<int> divisor_rev(b.rbegin(), b.rend());
  divisor_rev.resize(quotient_len);
  reverse(a.begin(), a.end());
  a *= inverse(divisor_rev);
  a.erase(a.begin() + quotient_len, a.end());
  reverse(a.begin(), a.end());
  return a;
}
 
// Polynomial quotient a / b as a new vector.
vector<int> operator / (const vector<int> &a, const vector<int> &b) {
  vector<int> quotient(a);
  quotient /= b;
  return quotient;
}
 
// Polynomial remainder a %= b, computed as a - (a / b) * b on the low
// b.size() - 1 coefficients. a is left untouched when it is already shorter
// than b.
vector<int>& operator %= (vector<int> &a, const vector<int> &b) {
  const int n = a.size();
  const int m = b.size();
  if (n < m) {
    return a;
  }
  const vector<int> multiple = (a / b) * b;
  a.resize(m - 1);
  for (int i = 0; i + 1 < m; ++i) {
    sub(a[i], multiple[i]);
  }
  return a;
}
 
// Polynomial remainder a % b as a new vector.
vector<int> operator % (const vector<int> &a, const vector<int> &b) {
  vector<int> remainder(a);
  remainder %= b;
  return remainder;
}
 
// Formal derivative of a; the result has one coefficient fewer.
// An empty input returns an empty vector (previously vector<int>(n - 1) was
// built with n - 1 == -1, which throws std::length_error).
vector<int> derivative(const vector<int> &a) {
  const int n = a.size();
  if (n == 0) {
    return vector<int>();
  }
  vector<int> b(n - 1);
  for (int i = 1; i < n; ++i) {
    b[i - 1] = mul(a[i], i);
  }
  return b;
}
 
// Formal antiderivative of a with zero constant term; the result has one
// coefficient more. Inverses of 1..n are built with the recurrence
// inv[i] = -(md / i) * inv[md % i].
vector<int> primitive(const vector<int> &a) {
  const int n = a.size();
  vector<int> invs(n + 1);
  vector<int> b(n + 1);
  for (int i = 1; i <= n; ++i) {
    if (i == 1) {
      invs[i] = 1;
    } else {
      invs[i] = mul(md - md / i, invs[md % i]);
    }
    b[i] = mul(a[i - 1], invs[i]);
  }
  return b;
}
 
// Power-series logarithm: integral of a' / a, truncated to a.size() terms.
// Presumably requires a[0] == 1 so the logarithm has zero constant term.
vector<int> logarithm(const vector<int> &a) {
  vector<int> integrand = derivative(a);
  integrand *= inverse(a);
  vector<int> b = primitive(integrand);
  b.resize(a.size());
  return b;
}
 
// Power-series exponential by Newton doubling:
//   b <- b * (1 + a - log b), doubling the number of known terms each round.
// Presumably requires a[0] == 0. Returns a.size() coefficients.
vector<int> exponent(const vector<int> &a) {
  vector<int> b(1, 1);
  while (b.size() < a.size()) {
    const size_t half = b.size();
    const size_t full = half << 1;
    // c = 1 + a, truncated to the precision reached after this round.
    vector<int> c(a.begin(), a.begin() + min(a.size(), full));
    add(c[0], 1);
    const vector<int> old_b = b;
    b.resize(full);
    c -= logarithm(b);
    c *= old_b;
    // Only the upper half is new; the lower half of b is already correct.
    for (size_t i = half; i < full; ++i) {
      b[i] = c[i];
    }
  }
  b.resize(a.size());
  return b;
}
 
// Power series a^m truncated to a.size() coefficients.
// With p the index of the first non-zero coefficient, writes
// a = a[p] * x^p * c (c has constant term 1) and computes
// a^m = a[p]^m * x^(m*p) * exp(m * log c).
//  * all-zero a: result is 0, except 0^0 == 1 in the constant term;
//  * m * p >= n: every term is shifted out of range, result is 0.
// An empty input returns an empty vector (previously b[0] was written on an
// empty vector).
vector<int> power(const vector<int> &a, int m) {
  const int n = a.size();
  vector<int> b(n);
  if (n == 0) {
    return b;
  }
  int p = -1;
  for (int i = 0; i < n; ++i) {
    if (a[i]) {
      p = i;
      break;
    }
  }
  if (p == -1) {
    b[0] = !m;
    return b;
  }
  if ((long long) m * p >= n) {
    return b;
  }
  const int mu = power(a[p], m);
  const int di = inv(a[p]);
  const int offset = m * p;    // degree shift of the result
  const int len = n - offset;  // coefficients that remain in range
  vector<int> c(len);
  for (int i = 0; i < len; ++i) {
    c[i] = mul(a[i + p], di);
  }
  c = logarithm(c);
  for (int i = 0; i < len; ++i) {
    c[i] = mul(c[i], m);
  }
  c = exponent(c);
  for (int i = 0; i < len; ++i) {
    b[i + offset] = mul(c[i], mu);
  }
  return b;
}
 
// Power-series square root by Newton doubling, truncated to a.size() terms.
// The iteration starts from b = 1, so a[0] is presumably required to be 1.
// Each round fills the new upper half of b with (a / b) / 2; the lower half
// is already final.
// (The per-round copy `old_b` of the original was never read and is gone.)
vector<int> sqrt(const vector<int> &a) {
  vector<int> b(1, 1);
  const int half_inverse = (md + 1) >> 1;  // modular inverse of 2
  while (b.size() < a.size()) {
    const size_t known = b.size();
    const size_t target = known << 1;
    vector<int> c(a.begin(), a.begin() + min(a.size(), target));
    b.resize(target);
    c *= inverse(b);
    for (size_t i = known; i < target; ++i) {
      b[i] = mul(c[i], half_inverse);
    }
  }
  b.resize(a.size());
  return b;
}
 
// Product of all[l..r] (inclusive) by divide and conquer.
// An empty range (l > r) yields an empty vector.
vector<int> multiply_all(int l, int r, vector<vector<int>> &all) {
  if (l > r) {
    return vector<int>();
  }
  if (l == r) {
    return all[l];
  }
  const int middle = (l + r) >> 1;
  vector<int> left_part = multiply_all(l, middle, all);
  const vector<int> right_part = multiply_all(middle + 1, r, all);
  left_part *= right_part;
  return left_part;
}
 
// Multipoint evaluation: returns y with y[i] = f(x[i]) modulo md.
// up[] is a product tree over the linear factors (X - x[i]), stored heap-style
// with leaves at n..2n-1; down[] holds f reduced modulo each node.
// The remainder at leaf i is the constant f(x[i]).
//
// An empty remainder represents the zero polynomial, which happens when f
// itself is empty, and evaluates to 0 (previously down[i + n][0] was read
// out of bounds in that case).
vector<int> evaluate(const vector<int> &f, const vector<int> &x) {
  const int n = x.size();
  if (!n) {
    return vector<int>();
  }
  vector<vector<int>> up(n * 2);
  for (int i = 0; i < n; ++i) {
    up[i + n] = vector<int>{(md - x[i]) % md, 1};
  }
  for (int i = n - 1; i; --i) {
    up[i] = up[i << 1] * up[i << 1 | 1];
  }
  vector<vector<int>> down(n * 2);
  down[1] = f % up[1];
  for (int i = 2; i < n * 2; ++i) {
    down[i] = down[i >> 1] % up[i];
  }
  vector<int> y(n);
  for (int i = 0; i < n; ++i) {
    const vector<int> &leaf = down[i + n];
    y[i] = leaf.empty() ? 0 : leaf[0];
  }
  return y;
}
 
// Lagrange interpolation: returns the polynomial built from the points
// (x[i], y[i]) modulo md.
// With P = prod (X - x[i]), the weight of point i is y[i] / P'(x[i]); the
// weighted sum of P / (X - x[i]) is then assembled bottom-up over the same
// product tree.
// Assumes the x[i] are pairwise distinct and that y has at least x.size()
// entries (not checked here).
//
// No points returns an empty polynomial (previously up[1] was indexed in an
// empty vector).
vector<int> interpolate(const vector<int> &x, const vector<int> &y) {
  const int n = x.size();
  if (n == 0) {
    return vector<int>();
  }
  vector<vector<int>> up(n * 2);
  for (int i = 0; i < n; ++i) {
    up[i + n] = vector<int>{(md - x[i]) % md, 1};
  }
  for (int i = n - 1; i; --i) {
    up[i] = up[i << 1] * up[i << 1 | 1];
  }
  vector<int> a = evaluate(derivative(up[1]), x);
  for (int i = 0; i < n; ++i) {
    a[i] = mul(y[i], inv(a[i]));
  }
  vector<vector<int>> down(n * 2);
  for (int i = 0; i < n; ++i) {
    down[i + n] = vector<int>(1, a[i]);
  }
  for (int i = n - 1; i; --i) {
    // Each child's partial sum is multiplied by the sibling's product.
    down[i] = down[i << 1] * up[i << 1 | 1] + down[i << 1 | 1] * up[i << 1];
  }
  return down[1];
}
}

signed main() {
    #ifdef HOME
    freopen("input.txt", "r", stdin);
    #else
    #define endl '\n'
    ios_base::sync_with_stdio(0); cin.tie(0);
    #endif
    int n;
    cin >> n;
    vector <int> p(n);
    for (int i = 0; i < n; ++i)
        p[i] = i;
    prec();
    M b = M(0);
    for (int i = 0; i <= n; ++i) {
        if (i & 1)
            b -= f[n]/f[i];
        else
            b += f[n]/f[i];
    }   
    M ans = f[n] - b * 2;
    for (int mid = 0; mid <= (n&1); ++mid) {
        /*
        for (int p = 0; p <= n/2; ++p) {
            for (int i = 0; 2 * p + i <= n - (n&1); ++i) {

                // f[n - (n&1) -2*p] * inv[i] * inv[n - (n&1) -2*p- i]

                M t = C(n/2,p)*C(n - (n&1) -2*p,i)*fp(2,p+i)*f[n-2*p-i-mid];
                if ((p+i+mid) & 1)
                    ans -= t;
                else
                    ans += t;
            }   
        }
        */          

        vector <M> a(n+1), b(n+1);

        for (int p = 0; p <= n/2; ++p) {
            a[p * 2] = C(n/2,p) * f[n - (n&1) - 2 * p] * fp(2,p) * fp(-1, p);
        }   
        for (int i = 0; i <= n - (n&1); ++i) {
            b[i] = inv[i] * fp(2, i) * fp(-1, i);
        }   
        
        vector <int> a1,b1;
        for (auto e : a)
            a1.app(e.x);
        for (auto e : b)
            b1.app(e.x);

        auto res = faq::multiply(a1, b1);

        /*
        for (int p = 0; p <= n; ++p) {
            for (int i = 0; p + i <= n - (n & 1); ++i) {
                res[p + i] += a[p] * b[i];
            }   
        }
        */

        for (int i = 0; i <= n - (n & 1); ++i) {
            ans += (inv[n - (n&1) - i] * f[n - mid - i] * fp(-1, mid)) * res[i];
        }   
    }
    cout << ans.x << endl;
}
0