結果

問題 No.2588 Increasing Record
ユーザー 👑 rin204rin204
提出日時 2023-11-24 15:46:48
言語 PyPy3 (7.3.15)
結果
AC  
実行時間 1,472 ms / 3,000 ms
コード長 6,932 bytes
コンパイル時間 511 ms
コンパイル使用メモリ 82,612 KB
実行使用メモリ 200,792 KB
最終ジャッジ日時 2024-09-27 06:39:27
合計ジャッジ時間 39,669 ms
ジャッジサーバーID (参考情報) judge5 / judge2
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 38 ms 55,132 KB
testcase_01 AC 39 ms 53,980 KB
testcase_02 AC 39 ms 54,092 KB
testcase_03 AC 39 ms 53,760 KB
testcase_04 AC 39 ms 53,620 KB
testcase_05 AC 38 ms 53,548 KB
testcase_06 AC 39 ms 54,204 KB
testcase_07 AC 40 ms 54,512 KB
testcase_08 AC 40 ms 54,952 KB
testcase_09 AC 40 ms 55,740 KB
testcase_10 AC 40 ms 54,564 KB
testcase_11 AC 42 ms 55,824 KB
testcase_12 AC 381 ms 99,892 KB
testcase_13 AC 375 ms 99,376 KB
testcase_14 AC 413 ms 100,224 KB
testcase_15 AC 453 ms 100,468 KB
testcase_16 AC 778 ms 114,316 KB
testcase_17 AC 958 ms 132,120 KB
testcase_18 AC 1,184 ms 146,520 KB
testcase_19 AC 1,396 ms 155,764 KB
testcase_20 AC 1,433 ms 154,768 KB
testcase_21 AC 1,447 ms 156,528 KB
testcase_22 AC 1,452 ms 158,928 KB
testcase_23 AC 1,471 ms 161,236 KB
testcase_24 AC 1,454 ms 160,988 KB
testcase_25 AC 1,034 ms 129,432 KB
testcase_26 AC 1,246 ms 142,712 KB
testcase_27 AC 1,414 ms 160,448 KB
testcase_28 AC 1,389 ms 159,388 KB
testcase_29 AC 1,394 ms 160,492 KB
testcase_30 AC 739 ms 131,496 KB
testcase_31 AC 863 ms 149,844 KB
testcase_32 AC 961 ms 164,532 KB
testcase_33 AC 978 ms 166,304 KB
testcase_34 AC 971 ms 169,500 KB
testcase_35 AC 964 ms 170,368 KB
testcase_36 AC 972 ms 169,156 KB
testcase_37 AC 1,472 ms 158,724 KB
testcase_38 AC 1,052 ms 132,128 KB
testcase_39 AC 1,017 ms 167,752 KB
testcase_40 AC 1,043 ms 184,088 KB
testcase_41 AC 1,031 ms 193,184 KB
testcase_42 AC 1,017 ms 194,672 KB
testcase_43 AC 1,019 ms 200,792 KB
testcase_44 AC 977 ms 144,824 KB
testcase_45 AC 1,013 ms 146,600 KB
testcase_46 AC 1,020 ms 154,588 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 998_244_353  # prime modulus; the final answer is printed modulo MOD


class UnionFind:
    """Disjoint-set forest over the elements ``0 .. n-1``.

    ``par[x]`` is the parent of ``x``; for a root it stores the negated size
    of that root's set. Merging is by size and ``find`` fully compresses the
    walked path.
    """

    def __init__(self, n):
        self.n = n
        self.group_ = n  # number of disjoint sets currently present
        self.par = [-1 for _ in range(n)]

    def find(self, x):
        """Return the root of ``x``'s set, re-pointing the path at it."""
        root = x
        while self.par[root] >= 0:
            root = self.par[root]
        # Second pass: every vertex visited above now points straight at root.
        while x != root:
            nxt = self.par[x]
            self.par[x] = root
            x = nxt
        return root

    def unite(self, x, y):
        """Merge the sets of ``x`` and ``y``.

        Returns True if two distinct sets were merged, False if ``x`` and
        ``y`` were already in the same set.
        """
        big = self.find(x)
        small = self.find(y)
        if big == small:
            return False
        # A root's par value is -size, so a smaller value means a larger set.
        if self.par[small] < self.par[big]:
            big, small = small, big
        self.par[big] += self.par[small]
        self.par[small] = big
        self.group_ -= 1
        return True

    def size(self, x):
        """Return the number of elements in ``x``'s set."""
        root = self.find(x)
        return -self.par[root]

    def same(self, x, y):
        """Return True if ``x`` and ``y`` belong to the same set."""
        rx = self.find(x)
        ry = self.find(y)
        return rx == ry

    @property
    def group(self):
        """Current number of disjoint sets."""
        return self.group_


class BIT:
    """Fenwick tree (binary indexed tree) over ``n`` integer cells.

    Public positions are 0-indexed (``add``, ``get``, ``sum``); internally
    ``data`` is 1-indexed and ``data[0]`` is unused.
    """

    def __init__(self, n):
        self.n = n
        self.data = [0] * (n + 1)
        # Largest power of two <= n: the first step of lower_bound's descent.
        self.n0 = 1 << (n.bit_length() - 1) if n else 0

    def sum_(self, i):
        """Return the sum of the first ``i`` cells (positions ``0 .. i-1``)."""
        s = 0
        data = self.data
        while i > 0:
            s += data[i]
            i -= i & -i
        return s

    def sum(self, l, r=-1):
        """Range sum.

        ``sum(i)`` returns the sum of the first ``i`` cells; ``sum(l, r)``
        returns the sum over positions ``[l, r)``.
        """
        if r == -1:
            return self.sum_(l)
        return self.sum_(r) - self.sum_(l)

    def get(self, i):
        """Return the value stored at position ``i``."""
        return self.sum(i, i + 1)

    def add(self, i, x):
        """Add ``x`` to the cell at 0-indexed position ``i``.

        Raises:
            IndexError: if ``i`` is outside ``[0, n)``. (Previously a negative
            index made ``i & -i`` zero and looped forever, and ``i >= n`` was
            silently ignored.)
        """
        if not 0 <= i < self.n:
            raise IndexError(f"BIT index {i} out of range for size {self.n}")
        i += 1
        data = self.data
        n = self.n
        while i <= n:
            data[i] += x
            i += i & -i

    def lower_bound(self, x):
        """Return the smallest ``r`` with ``sum(r) >= x``.

        Returns 0 for ``x <= 0`` and ``n + 1`` if even the total is below
        ``x``. Assumes every cell is non-negative so prefix sums are
        monotone -- the binary descent is meaningless otherwise.
        """
        if x <= 0:
            return 0
        i = 0
        k = self.n0
        while k > 0:
            if i + k <= self.n and self.data[i + k] < x:
                x -= self.data[i + k]
                i += k
            k //= 2
        return i + 1


class HLD:
    """Heavy-light decomposition of a tree on vertices ``0 .. n-1``.

    Add edges (``add_edge``, ``read_edges`` or the ``edges`` argument), call
    ``build(root)``, then query. After ``build``:

    * ``par`` / ``depth`` / ``size`` -- parent (-1 for the root), depth and
      subtree size of every vertex reached from ``root``.
    * ``heavy_child[v]`` -- the child of ``v`` with the largest subtree
      (-1 for a leaf); ``isheavy[v]`` -- True if ``v`` continues its parent's
      heavy path (also True for the root).
    * ``L[v]`` / ``R[v]`` -- DFS entry index of ``v`` and one past the last
      entry index in its subtree, so the subtree of ``v`` is ``[L[v], R[v])``.
      Heavy children are visited first, so every heavy path occupies a
      contiguous run of ``L`` indices.
    * ``path_ind[v]`` -- id of the heavy path containing ``v``;
      ``path_root[p]`` -- topmost vertex of heavy path ``p``.
    """

    def __init__(self, n, edges=None):
        """Create an empty decomposition for ``n`` vertices.

        ``edges``, if given, is an adjacency list (one neighbour list per
        vertex); otherwise an empty one is created.
        """
        self.n = n
        if edges is None:
            self.edges = [[] for _ in range(n)]
        else:
            self.edges = edges
            # Note: `edges` is stored as-is, not copied -- add_edge will
            # mutate the caller's lists.

        # Per-vertex results, filled in by build(). build() treats
        # depth == -1 and size == -1 as "not visited yet".
        self.size = [-1] * n
        self.par = [-1] * n
        self.depth = [-1] * n
        self.path_ind = [-1] * n
        self.path_root = []
        self.heavy_child = [-1] * n
        self.isheavy = [False] * n
        self.L = [-1] * n
        self.R = [-1] * n

    def add_edge(self, u, v):
        """Add the undirected edge ``u``-``v``."""
        self.edges[u].append(v)
        self.edges[v].append(u)

    def read_edges(self, indexed=1):
        """Read ``n - 1`` edges "u v" from stdin, one per line.

        ``indexed`` is subtracted from both endpoints (default 1, i.e. the
        input is 1-indexed).
        """
        for _ in range(self.n - 1):
            u, v = map(int, input().split())
            u -= indexed
            v -= indexed
            self.add_edge(u, v)

    def build(self, root=0):
        """Root the tree at ``root`` and compute all decomposition arrays.

        Relies on the -1 initial markers set in ``__init__``, so it is meant
        to be called once, after all edges have been added.
        """
        # Phase 1: iterative DFS assigning depth and parent. `route` records
        # discovery order, in which every vertex appears after its parent.
        self.depth[root] = 0
        st = [root]
        route = [root]
        while st:
            pos = st.pop()
            for npos in self.edges[pos]:
                if self.depth[npos] == -1:
                    self.depth[npos] = self.depth[pos] + 1
                    self.par[npos] = pos
                    st.append(npos)
                    route.append(npos)

        # Phase 2: subtree sizes, children before parents. While `pos` is
        # processed its parent's size is still -1, so the neighbours with
        # size != -1 are exactly its children. The largest child is heavy.
        for pos in route[::-1]:
            self.size[pos] = 1
            ma = -1
            for npos in self.edges[pos]:
                if self.size[npos] != -1:
                    self.size[pos] += self.size[npos]
                    if self.size[npos] > ma:
                        ma = self.size[npos]
                        self.heavy_child[pos] = npos

            if ma != -1:
                self.isheavy[self.heavy_child[pos]] = True

        # The root starts heavy path 0.
        self.isheavy[root] = True

        # Phase 3: DFS numbering. A pushed `~v` marks leaving v's subtree
        # (sets R[v]). The heavy child is pushed last so it is popped first,
        # which keeps each heavy path contiguous in L. A non-heavy vertex
        # opens a new path.
        path = 0
        st = [~root, root]
        self.path_root = [root]
        cc = 0
        while st:
            pos = st.pop()
            if pos >= 0:
                self.L[pos] = cc
                cc += 1
                if not self.isheavy[pos]:
                    path += 1
                    self.path_root.append(pos)

                self.path_ind[pos] = path
                for npos in self.edges[pos]:
                    if npos == self.par[pos] or npos == self.heavy_child[pos]:
                        continue
                    st.append(~npos)
                    st.append(npos)

                if self.heavy_child[pos] != -1:
                    npos = self.heavy_child[pos]
                    st.append(~npos)
                    st.append(npos)

            else:
                self.R[~pos] = cc

    def get_path(self, u, v):
        """Split the ``u``-``v`` path into segments lying on single heavy paths.

        Returns a list of ``(a, b)`` vertex pairs in walk order from ``u`` to
        ``v``. ``a`` and ``b`` lie on the same heavy path, and ``a`` is the
        endpoint reached first when walking from ``u``. Consecutive pairs are
        joined by a light edge. For ``u == v`` the result is ``[(u, u)]``.
        """
        # ll collects segment endpoints climbing from u, rr those climbing
        # from v; each climb appends (path top, parent of path top).
        ll = [u]
        rr = [v]
        while self.path_ind[u] != self.path_ind[v]:
            if (
                self.depth[self.path_root[self.path_ind[u]]]
                >= self.depth[self.path_root[self.path_ind[v]]]
            ):
                u = self.path_root[self.path_ind[u]]
                ll.append(u)
                u = self.par[u]
                ll.append(u)
            else:
                v = self.path_root[self.path_ind[v]]
                rr.append(v)
                v = self.par[v]
                rr.append(v)

        # ll and rr both have odd length, so the spliced list pairs up evenly.
        ll += rr[::-1]
        res = []
        for i in range(0, len(ll), 2):
            res.append((ll[i], ll[i + 1]))

        return res

    def lca(self, u, v):
        """Return the lowest common ancestor of ``u`` and ``v`` (w.r.t. the build root)."""
        while self.path_ind[u] != self.path_ind[v]:
            if (
                self.depth[self.path_root[self.path_ind[u]]]
                >= self.depth[self.path_root[self.path_ind[v]]]
            ):
                u = self.par[self.path_root[self.path_ind[u]]]
            else:
                v = self.par[self.path_root[self.path_ind[v]]]

        # Same heavy path now: the shallower vertex is the LCA.
        if self.depth[u] >= self.depth[v]:
            return v
        else:
            return u

    def dist(self, u, v):
        """Return the number of edges on the ``u``-``v`` path."""
        return self.depth[u] + self.depth[v] - 2 * self.depth[self.lca(u, v)]

    def reorder(self, A, rev=False):
        """Return ``A`` permuted into DFS order (``ret[L[i]] = A[i]``), reversed if ``rev``."""
        ret = [0] * self.n
        for i in range(self.n):
            ret[self.L[i]] = A[i]

        if rev:
            ret = ret[::-1]
        return ret


import sys

# Bulk-buffered reads: per-line input() is slow for up to ~2e5 edge lines.
readline = sys.stdin.buffer.readline

# Input validation uses explicit raises rather than `assert`, which `python -O`
# strips.
n, m = map(int, readline().split())
if not (2 <= n <= 200000 and 1 <= m <= 200000):
    raise ValueError(f"constraints violated: n={n}, m={m}")

# E[v] lists the smaller endpoint u of every edge (u, v), u < v (0-indexed).
E = [[] for _ in range(n)]
se = set()
for _ in range(m):
    u, v = map(int, readline().split())
    if not 1 <= u < v <= n:
        raise ValueError(f"invalid edge ({u}, {v})")
    key = u * n + v  # unique per pair because 1 <= v <= n
    if key in se:
        raise ValueError(f"duplicate edge ({u}, {v})")
    se.add(key)

    E[v - 1].append(u - 1)

# Process vertices in increasing order. ma[root] is the current largest vertex
# of each union-find component. When v absorbs a component through an edge
# (u, v), a tree edge is added between that component's largest vertex and v.
UF = UnionFind(n)
hld = HLD(n)
ma = list(range(n))
for v in range(n):
    for u in E[v]:
        ru = UF.find(u)
        if ru != UF.find(v):
            hld.add_edge(ma[ru], v)
            UF.unite(u, v)
            ma[UF.find(v)] = v
if UF.group != 1:
    raise ValueError("graph is not connected")

hld.build()
bit = BIT(n)

# For each v (increasing order): collect the L-index intervals covering the
# tree paths from v to each smaller neighbour u, then sum the BIT over their
# union. Vertices > v have not been added yet and contribute 0. Store
# tot % MOD + 1 at L[v].
base = n + 1  # intervals packed as l * base + r (r <= n < base) so they sort as ints
for v in range(n):
    lr = []
    for u in E[v]:
        for uu, vv in hld.get_path(u, v):
            uu = hld.L[uu]
            vv = hld.L[vv]
            if uu > vv:
                uu, vv = vv, uu
            lr.append(uu * base + (vv + 1))

    lr.sort()
    tot = 0
    mar = 0  # right end of the union covered so far
    for tmp in lr:
        l, r = divmod(tmp, base)
        l = max(l, mar)  # skip the part already counted
        if l < r:
            tot += bit.sum(l, r)
        mar = max(mar, r)

    bit.add(hld.L[v], tot % MOD + 1)

ans = bit.sum(0, n)
print(ans % MOD)