結果

問題 No.1712 Read and Pile
ユーザー 👑 hos.lyric
提出日時 2024-03-30 13:20:51
言語 C++14
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 429 ms / 2,000 ms
コード長 6,960 bytes
コンパイル時間 1,530 ms
コンパイル使用メモリ 119,820 KB
実行使用メモリ 10,368 KB
最終ジャッジ日時 2024-09-30 17:28:21
合計ジャッジ時間 6,375 ms
ジャッジサーバーID
(参考情報)
judge5 / judge1
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 1 ms
5,248 KB
testcase_01 AC 1 ms
5,248 KB
testcase_02 AC 2 ms
5,248 KB
testcase_03 AC 2 ms
5,248 KB
testcase_04 AC 2 ms
5,248 KB
testcase_05 AC 2 ms
5,248 KB
testcase_06 AC 1 ms
5,248 KB
testcase_07 AC 2 ms
5,248 KB
testcase_08 AC 50 ms
8,576 KB
testcase_09 AC 64 ms
8,832 KB
testcase_10 AC 47 ms
8,448 KB
testcase_11 AC 55 ms
8,832 KB
testcase_12 AC 38 ms
8,320 KB
testcase_13 AC 47 ms
7,936 KB
testcase_14 AC 42 ms
7,808 KB
testcase_15 AC 55 ms
8,576 KB
testcase_16 AC 53 ms
8,192 KB
testcase_17 AC 48 ms
8,320 KB
testcase_18 AC 52 ms
8,192 KB
testcase_19 AC 46 ms
7,552 KB
testcase_20 AC 46 ms
7,936 KB
testcase_21 AC 56 ms
8,320 KB
testcase_22 AC 50 ms
8,576 KB
testcase_23 AC 65 ms
10,368 KB
testcase_24 AC 60 ms
10,240 KB
testcase_25 AC 67 ms
10,240 KB
testcase_26 AC 74 ms
10,368 KB
testcase_27 AC 73 ms
10,368 KB
testcase_28 AC 37 ms
7,424 KB
testcase_29 AC 58 ms
9,216 KB
testcase_30 AC 47 ms
8,448 KB
testcase_31 AC 30 ms
9,984 KB
testcase_32 AC 22 ms
8,192 KB
testcase_33 AC 49 ms
9,216 KB
testcase_34 AC 48 ms
8,320 KB
testcase_35 AC 67 ms
9,344 KB
testcase_36 AC 73 ms
9,856 KB
testcase_37 AC 69 ms
9,088 KB
testcase_38 AC 1 ms
5,248 KB
testcase_39 AC 39 ms
5,248 KB
testcase_40 AC 333 ms
5,248 KB
testcase_41 AC 211 ms
5,248 KB
testcase_42 AC 429 ms
5,248 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <bitset>
#include <complex>
#include <deque>
#include <functional>
#include <iostream>
#include <limits>
#include <map>
#include <numeric>
#include <queue>
#include <random>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

using namespace std;

using Int = long long;

// Debug printers: a pair is written as "(first, second)"; a vector as
// "[a, b, c]", showing at most 256 elements followed by ", ..." if truncated.
template <class T1, class T2> ostream &operator<<(ostream &os, const pair<T1, T2> &a) {
  os << "(" << a.first << ", " << a.second << ")";
  return os;
}
template <class T> ostream &operator<<(ostream &os, const vector<T> &as) {
  const int sz = as.size();
  const int shown = min(sz, 256);
  os << "[";
  for (int i = 0; i < shown; ++i) {
    if (i > 0) os << ", ";
    os << as[i];
  }
  if (sz > shown) os << ", ...";
  return os << "]";
}
// Writes every element of [a, b) to stderr, each followed by a space, then
// ends the line.
template <class T> void pv(T a, T b) {
  while (a != b) {
    cerr << *a << " ";
    ++a;
  }
  cerr << endl;
}
// chmin / chmax: overwrite t with f when f is strictly better (smaller /
// larger); return whether t was changed.
template <class T> bool chmin(T &t, const T &f) {
  if (!(t > f)) return false;
  t = f;
  return true;
}
template <class T> bool chmax(T &t, const T &f) {
  if (!(t < f)) return false;
  t = f;
  return true;
}
// Wraps an SGR parameter string literal in an ANSI escape, e.g. COLOR("33").
#define COLOR(s) ("\x1b[" s "m")

////////////////////////////////////////////////////////////////////////////////
template <unsigned M_> struct ModInt {
  // Modulus. The signed constructors and inv() route through int / long long,
  // so M is presumably meant to stay well below 2^31 (NOTE(review): not enforced).
  static constexpr unsigned M = M_;
  // Canonical representative, always kept in [0, M).
  unsigned x;

  constexpr ModInt() : x(0U) {}
  constexpr ModInt(unsigned x_) : x(x_ % M) {}
  constexpr ModInt(unsigned long long x_) : x(x_ % M) {}
  // Signed inputs are reduced so that negative values land in [0, M).
  constexpr ModInt(int x_) : x(reduceSigned(x_)) {}
  constexpr ModInt(long long x_) : x(reduceSigned(x_)) {}

  ModInt &operator+=(const ModInt &a) {
    x += a.x;
    if (x >= M) x -= M;
    return *this;
  }
  ModInt &operator-=(const ModInt &a) {
    x -= a.x;
    // Unsigned underflow wraps to a value >= M; adding M wraps it back.
    if (x >= M) x += M;
    return *this;
  }
  ModInt &operator*=(const ModInt &a) {
    x = static_cast<unsigned>(static_cast<unsigned long long>(x) * a.x % M);
    return *this;
  }
  ModInt &operator/=(const ModInt &a) { return *this *= a.inv(); }

  // Square-and-multiply; a negative exponent raises the inverse instead.
  ModInt pow(long long e) const {
    if (e < 0) return inv().pow(-e);
    ModInt base = *this;
    ModInt result = 1U;
    while (e) {
      if (e & 1) result *= base;
      base *= base;
      e >>= 1;
    }
    return result;
  }
  // Multiplicative inverse via the extended Euclidean algorithm on (M, x).
  // Asserts gcd(M, x) == 1, so it fails for x == 0.
  ModInt inv() const {
    unsigned r0 = M, r1 = x;
    int s0 = 0, s1 = 1;
    while (r1 != 0U) {
      const unsigned q = r0 / r1;
      const unsigned r2 = r0 - q * r1;
      const int s2 = s0 - static_cast<int>(q) * s1;
      r0 = r1;
      r1 = r2;
      s0 = s1;
      s1 = s2;
    }
    assert(r0 == 1U);
    return ModInt(s0);
  }

  ModInt operator+() const { return *this; }
  ModInt operator-() const {
    ModInt r;
    if (x != 0U) r.x = M - x;
    return r;
  }
  ModInt operator+(const ModInt &a) const { ModInt r = *this; r += a; return r; }
  ModInt operator-(const ModInt &a) const { ModInt r = *this; r -= a; return r; }
  ModInt operator*(const ModInt &a) const { ModInt r = *this; r *= a; return r; }
  ModInt operator/(const ModInt &a) const { ModInt r = *this; r /= a; return r; }
  // Mixed-type forms with a plain value on the left, e.g. `1 - m` or `2 / m`.
  template <class T> friend ModInt operator+(T a, const ModInt &b) { ModInt r(a); r += b; return r; }
  template <class T> friend ModInt operator-(T a, const ModInt &b) { ModInt r(a); r -= b; return r; }
  template <class T> friend ModInt operator*(T a, const ModInt &b) { ModInt r(a); r *= b; return r; }
  template <class T> friend ModInt operator/(T a, const ModInt &b) { ModInt r(a); r /= b; return r; }
  explicit operator bool() const { return x != 0U; }
  bool operator==(const ModInt &a) const { return x == a.x; }
  bool operator!=(const ModInt &a) const { return !(*this == a); }
  friend std::ostream &operator<<(std::ostream &os, const ModInt &a) { return os << a.x; }

 private:
  // Reduces a signed value into [0, M) using arithmetic in the value's own type.
  template <class S> static constexpr unsigned reduceSigned(S v) {
    const S m = static_cast<S>(M);
    const S r = v % m;
    return static_cast<unsigned>((r < 0) ? (r + m) : r);
  }
};
////////////////////////////////////////////////////////////////////////////////

// Modulus for all arithmetic; 998244353 is prime, so every nonzero residue
// has an inverse.
constexpr unsigned MO = 998244353;
using Mint = ModInt<MO>;


// Fenwick-tree helpers over a 0-indexed vector whose size is the number of
// positions. Position pos is expected in [0, bit.size()).

// Adds val at position pos.
template <class T> void bAdd(vector<T> &bit, int pos, const T &val) {
  const int n = bit.size();
  while (pos < n) {
    bit[pos] += val;
    pos |= pos + 1;
  }
}
// Returns the sum over positions [0, pos).
template <class T> T bSum(const vector<T> &bit, int pos) {
  T acc = 0;
  while (pos > 0) {
    acc += bit[pos - 1];
    pos &= pos - 1;
  }
  return acc;
}
// Returns the sum over positions [pos0, pos1).
template <class T> T bSum(const vector<T> &bit, int pos0, int pos1) {
  const T upper = bSum(bit, pos1);
  const T lower = bSum(bit, pos0);
  return upper - lower;
}


// Problem input, filled by main(): N items and M queries. A[i] is -1 for a
// wildcard query, otherwise an item index (expected in [0, N) after main()'s
// decrement, since it is used to index app below).
int N, M;
vector<int> A;

// Reference evaluator. For each non-wildcard query i on item a it adds, over
// every other item b, the term labelled "Pr[b above a]" below. For a wildcard
// query it adds the average of that per-item sum over all N items. Returns the
// total plus M. Worst case O(M * N^2 * log M). fast() delegates here for N <= 2,
// and main() calls it under LOCAL for cross-checking.
Mint slow() {
  // sums[i] = number of wildcard entries (A[j] == -1) with j < i.
  vector<int> sums(M + 1, 0);
  for (int i = 0; i < M; ++i) sums[i + 1] = sums[i] + ((!~A[i]) ? 1 : 0);
  // Wildcard count in [l, r), with l clamped up to 0 and r down to M.
  auto get = [&](int l, int r) -> int {
    chmax(l, 0);
    chmin(r, M);
    return (l <= r) ? (sums[r] - sums[l]) : 0;
  };
  
  // app[a]: index of the latest query that read item a. The initial values
  // ~a = -1 - a are distinct and negative, i.e. earlier than every query.
  vector<int> app(N);
  for (int a = 0; a < N; ++a) app[a] = ~a;
  
  // Per-wildcard factor: exponents below count wildcard queries via get().
  const Mint prob = 1 - 2 / Mint(N);
  Mint ans = 0;
  for (int i = 0; i < M; ++i) {
    auto calc = [&](int a) -> Mint {
      Mint ret = 0;
      for (int b = 0; b < N; ++b) if (a != b) {
        // Pr[b above a]
        if (app[a] < app[b]) {
          ret += (1 + prob.pow(get(app[b], i))) / 2;
        } else {
          ret += (1 - prob.pow(get(app[a], i))) / 2;
        }
      }
      return ret;
    };
    if (~A[i]) {
      ans += calc(A[i]);
      app[A[i]] = i;
    } else {
      Mint sum = 0;
      for (int a = 0; a < N; ++a) sum += calc(a);
#ifdef LOCAL
      // Debug trace only. fast() routes every N <= 2 case here, so an
      // unguarded trace would write to and flush stderr once per wildcard
      // query in judged runs.
      cerr << "[slow] sum = " << sum << endl;
#endif
      ans += sum / N;
    }
  }
  ans += M;
  return ans;
}

// Same quantity as slow(), but each per-item sum over the other items b is
// answered with two Fenwick trees keyed by N + app[b] instead of an O(N) loop.
// Every per-query term is accumulated doubled and halved once at the end.
Mint fast() {
  // For N == 2, prob = 1 - 2/N is 0 and prob.inv() below would fail its
  // assertion, so small N is delegated to the brute-force evaluator.
  if (N <= 2) return slow();
  
  // sums[i] = number of wildcard entries (A[j] == -1) with j < i.
  vector<int> sums(M + 1, 0);
  for (int i = 0; i < M; ++i) sums[i + 1] = sums[i] + ((!~A[i]) ? 1 : 0);
  
  // app[a]: index of the latest query that read item a; initial values
  // ~a = -1 - a lie in [-N, -1], later values in [0, M).
  vector<int> app(N);
  for (int a = 0; a < N; ++a) app[a] = ~a;
  
  // pw[k] = prob^k and invPw[k] = prob^(-k) for k in [0, M].
  const Mint prob = 1 - 2 / Mint(N);
  const Mint invProb = prob.inv();
  vector<Mint> pw(M + 1), invPw(M + 1);
  pw[0] = 1;
  invPw[0] = 1;
  for (int i = 1; i <= M; ++i) {
    pw[i] = pw[i - 1] * prob;
    invPw[i] = invPw[i - 1] * invProb;
  }
  
  // \sum 1, \sum prob^(-sums[app[b]])
  // Both trees are indexed by N + app[b], which lies in [0, N + M).
  // bit0 counts items. bit1 sums prob^(-sums[max(app[b], 0)]), so a bit1 range
  // sum times pw[sums[i]] gives the sum of prob^(sums[i] - sums[max(app[b], 0)]).
  vector<Mint> bit0(N + M, 0), bit1(N + M, 0);
  // Adds val (+1 or -1) for item b at its current position app[b].
  auto add = [&](int b, Mint val) -> void {
    bAdd(bit0, N + app[b], val);
    bAdd(bit1, N + app[b], val * invPw[sums[max(app[b], 0)]]);
  };
  for (int b = 0; b < N; ++b) add(b, +1);
  
  Mint ans = 0;
  for (int i = 0; i < M; ++i) {
    // Twice slow()'s per-item sum for item a at query i. The "/ 2" is deferred
    // to the final `ans /= 2`. Item a itself is excluded because all app
    // values are distinct and both ranges skip N + app[a].
    auto calc = [&](int a) -> Mint {
      Mint ret = 0;
      /*
        app[a] < app[b]
        ret += (1 + prob.pow(get(app[b], i))) / 2;
      */
      ret += bSum(bit0, N + app[a] + 1, N + M);
      ret += bSum(bit1, N + app[a] + 1, N + M) * pw[sums[i]];
      /*
        app[a] > app[b]
        ret += (1 - prob.pow(get(app[a], i))) / 2;
      */
      ret += bSum(bit0, N + app[a]) * (1 - pw[sums[i] - sums[max(app[a], 0)]]);
      return ret;
    };
    if (~A[i]) {
      ans += calc(A[i]);
      // Remove the item at its old position, move it to time i, and reinsert.
      // app[A[i]] must change between the two add() calls.
      add(A[i], -1);
      app[A[i]] = i;
      add(A[i], +1);
    } else {
      // 1/2 on average
      // In slow(), the terms for (a, b) and (b, a) always sum to 1, so the
      // average over all N items is (N - 1) / 2; N - 1 is that value doubled.
      ans += Mint(N - 1);
    }
  }
  // Undo the doubling, then add M (one per query), as slow() does.
  ans /= 2;
  ans += M;
  return ans;
}

// Processes test cases until input runs out. Each case is "N M" followed by M
// integers. -1 is kept as the wildcard marker and every other value is
// decremented (presumably 1-based input to a 0-based item index). Prints
// fast() for each case.
int main() {
  // Require both fields. scanf returns 0 on a malformed token (or 1 if only N
  // was read), and `~scanf(...)` would treat either as success and loop on.
  while (scanf("%d%d", &N, &M) == 2) {
    A.resize(M);
    for (int i = 0; i < M; ++i) {
      // Stop on truncated or malformed input instead of using stale values.
      if (scanf("%d", &A[i]) != 1) return 1;
      if (A[i] != -1) --A[i];
    }
    
    const Mint ans = fast();
    printf("%u\n", ans.x);
#ifdef LOCAL
    // Cross-check against the brute-force evaluator.
    const Mint slw = slow();
    cerr << "slw = " << slw << endl;
#endif
  }
  return 0;
}
0