結果
| 問題 | No.2115 Making Forest Easy |
| コンテスト | |
| ユーザー | 👑 |
| 提出日時 | 2022-10-29 00:47:30 |
| 言語 | C (gcc 13.3.0) |
| 結果 | AC |
| 実行時間 | 1,529 ms / 2,000 ms |
| コード長 | 2,723 bytes |
| コンパイル時間 | 1,122 ms |
| コンパイル使用メモリ | 34,176 KB |
| 実行使用メモリ | 119,480 KB |
| 最終ジャッジ日時 | 2024-07-06 03:51:37 |
| 合計ジャッジ時間 | 27,443 ms |
| ジャッジサーバーID (参考情報) | judge3 / judge2 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | AC * 50 |
ソースコード
#include <stdio.h>
/* Modulus for all counts: 998244353 = 119 * 2^23 + 1 (prime). */
const int Mod = 119 * (1 << 23) + 1;
/*
 * One direction of an undirected tree edge. The half-edges leaving a vertex
 * form a singly linked list threaded through `next`.
 */
typedef struct Edge edge;
struct Edge {
    edge *next;         /* next half-edge leaving the same vertex, or NULL */
    int v;              /* vertex this half-edge points to */
    int id;             /* own index; main pairs it with its reverse via id ^ 1 */
    long long dp[1001]; /* per-threshold counts (k = 0..1000) for the v side */
};
/*
 * Solve r * y ≡ x (mod z) and return the smallest non-negative solution r.
 *
 * x and y may be any values (including negative or >= z); they are reduced
 * modulo z first. When gcd(y, z) = g > 1 a solution exists only if g | x, and
 * the result is then reduced modulo z / g.
 *
 * Returns -1 if z <= 0 or the congruence has no solution (in particular when
 * y ≡ 0 (mod z) and x is not). The previous recursive version evaluated
 * x % 0 in those cases, which is undefined behavior.
 *
 * Intermediate products are below z * z, so z must stay below about 3e9 for
 * long long arithmetic not to overflow.
 */
long long div_mod(long long x, long long y, long long z)
{
    if (z <= 0) return -1;

    x %= z;
    if (x < 0) x += z;
    y %= z;
    if (y < 0) y += z;

    /* Extended Euclid on (y, z); invariant: old_r ≡ old_s * y (mod z). */
    long long old_r = y, r = z;
    long long old_s = 1, s = 0;
    while (r != 0) {
        long long q = old_r / r, t;
        t = old_r - q * r;
        old_r = r;
        r = t;
        t = old_s - q * s;
        old_s = s;
        s = t;
    }

    long long g = old_r; /* gcd(y, z) */
    if (x % g != 0) return -1;

    /* old_s * (y / g) ≡ 1 (mod z / g), so old_s is the inverse we need. */
    long long zg = z / g;
    long long inv = old_s % zg;
    if (inv < 0) inv += zg;
    return (x / g) % zg * inv % zg;
}
int main()
{
int i, N, M, u, w, A[5001];
edge *adj[5001] = {}, e[10001], *p;
scanf("%d", &N);
for (u = 1; u <= N; u++) scanf("%d", &(A[u]));
for (i = 0; i < N - 1; i++) {
scanf("%d %d", &u, &w);
e[i*2].v = w;
e[i*2+1].v = u;
e[i*2].id = i * 2;
e[i*2+1].id = i * 2 + 1;
e[i*2].next = adj[u];
e[i*2+1].next = adj[w];
adj[u] = &(e[i*2]);
adj[w] = &(e[i*2+1]);
}
int par[5001] = {}, pred[5001], q[5001], head, tail;
q[0] = 1;
par[1] = 1;
pred[1] = -1;
for (head = 0, tail = 1; head < tail; head++) {
u = q[head];
for (p = adj[u]; p != NULL; p = p->next) {
w = p->v;
if (par[w] == 0) {
par[w] = u;
pred[w] = p->id;
q[tail++] = w;
}
}
}
int k;
long long sum;
for (head--; head >= 1; head--) {
u = q[head];
i = pred[u];
if (adj[u]->next == NULL) {
for (k = 0; k < A[u]; k++) e[i].dp[k] = 1;
for (k = A[u]; k <= 1000; k++) e[i].dp[k] = 2;
continue;
}
for (p = adj[u], sum = 1; p != NULL; p = p->next) {
if (p->v == par[u]) continue;
sum = sum * p->dp[1000] % Mod;
}
e[i].dp[0] = sum;
for (k = 1; k < A[u]; k++) e[i].dp[k] = e[i].dp[0];
for (k = A[u]; k <= 1000; k++) {
for (p = adj[u], sum = 1; p != NULL; p = p->next) {
if (p->v == par[u]) continue;
sum = sum * p->dp[k] % Mod;
}
e[i].dp[k] = (e[i].dp[0] + sum) % Mod;
}
}
long long ans = 0;
static long long prod[5001][1001];
for (k = 0; k <= 1000; k++) {
for (p = adj[1], prod[1][k] = 1; p != NULL; p = p->next) prod[1][k] = prod[1][k] * p->dp[k] % Mod;
if (k > 0) ans += (prod[1][k] - prod[1][k-1] + Mod) * ((k >= A[1])? k: A[1]) % Mod;
}
ans += prod[1][0] * A[1] % Mod;
for (head = 1; head < tail; head++) {
w = q[head];
i = pred[w] ^ 1;
u = par[w];
if (adj[u]->next == NULL) {
for (k = 0; k < A[u]; k++) e[i].dp[k] = 1;
for (k = A[u]; k <= 1000; k++) e[i].dp[k] = 2;
} else {
e[i].dp[0] = div_mod(prod[u][1000], e[i^1].dp[1000], Mod);
for (k = 1; k < A[u]; k++) e[i].dp[k] = e[i].dp[0];
for (k = A[u]; k <= 1000; k++) e[i].dp[k] = (e[i].dp[0] + div_mod(prod[u][k], e[i^1].dp[k], Mod)) % Mod;
}
for (k = 0; k <= 1000; k++) {
for (p = adj[w], prod[w][k] = 1; p != NULL; p = p->next) prod[w][k] = prod[w][k] * p->dp[k] % Mod;
if (k > 0) ans += (prod[w][k] - prod[w][k-1] + Mod) * ((k >= A[w])? k: A[w]) % Mod;
}
ans += prod[w][0] * A[w] % Mod;
}
printf("%lld\n", ans % Mod);
fflush(stdout);
return 0;
}