結果

問題 No.3343 Distance Sum of Large Tree
コンテスト
ユーザー Kude
提出日時 2025-11-13 22:28:04
言語 C++23
(gcc 13.3.0 + boost 1.87.0)
結果
WA  
実行時間 -
コード長 1,722 bytes
コンパイル時間 3,839 ms
コンパイル使用メモリ 315,504 KB
実行使用メモリ 14,464 KB
最終ジャッジ日時 2025-11-13 22:28:10
合計ジャッジ時間 6,440 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 2 WA * 28
権限があれば一括ダウンロードができます

ソースコード

diff #

#include<bits/stdc++.h>
namespace {
#pragma GCC diagnostic ignored "-Wunused-function"
#include<atcoder/all>
#pragma GCC diagnostic warning "-Wunused-function"
using namespace std;
using namespace atcoder;
#define rep(i,n) for(int i = 0; i < (int)(n); i++)
#define rrep(i,n) for(int i = (int)(n) - 1; i >= 0; i--)
#define all(x) begin(x), end(x)
#define rall(x) rbegin(x), rend(x)
template<class T> bool chmax(T& a, const T& b) { if (a < b) { a = b; return true; } else return false; }
template<class T> bool chmin(T& a, const T& b) { if (b < a) { a = b; return true; } else return false; }
using ll = long long;
using P = pair<int,int>;
using VI = vector<int>;
using VVI = vector<VI>;
using VL = vector<ll>;
using VVL = vector<VL>;
using mint = modint998244353;

} int main() {
  ios::sync_with_stdio(false);
  cin.tie(0);
  int n;
  cin >> n;
  VI a(n), b(n), c(n), p(n);
  rep(i, n) cin >> a[i];
  rep(i, n) if (i) cin >> b[i], b[i]--;
  rep(i, n) if (i) cin >> c[i], c[i]--;
  rep(i, n) if (i) cin >> p[i], p[i]--;
  ll tot = accumulate(all(a), 0LL);
  vector<vector<pair<int, ll>>> ch(n);
  mint ans;
  mint inv6 = mint(6).inv();
  rrep(i, n) {
    ll rest = tot - a[i];
    for (auto [j, cnt] : ch[i]) tot -= cnt;
    if (i) {
      ans += mint(tot - rest) * rest;
      ch[i].emplace_back(b[i], rest);
    }
    ch[i].emplace_back(0, 0);
    ch[i].emplace_back(a[i]-1, 0);
    ranges::sort(ch[i], {}, &pair<int, ll>::first);
    mint A;
    mint T = tot;
    rep(idx, ssize(ch[i]) - 1) {
      A += ch[i][idx].second;
      int l = ch[i][idx].first, r = ch[i][idx+1].first;
      ans += inv6 * (l-r) * (6*A*A+6*A*(l+r-T+1)+2*l*l+l*(2*r-3*T+3)+(r+1)*(2*r-3*T+1));
    }
  }
  ans *= 2;
  cout << ans.val() << '\n';
}
0