結果

問題 No.2588 Increasing Record
ユーザー 👑 rin204
提出日時 2023-11-24 15:46:48
言語 PyPy3 (7.3.15)
結果 AC
実行時間 1,202 ms / 3,000 ms
コード長 6,932 bytes
コンパイル時間 170 ms
コンパイル使用メモリ 81,828 KB
実行使用メモリ 201,124 KB
最終ジャッジ日時 2023-12-15 23:31:06
合計ジャッジ時間 32,981 ms
ジャッジサーバーID (参考情報) judge14 / judge15
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 38 ms 55,736 KB
testcase_01 AC 39 ms 55,736 KB
testcase_02 AC 40 ms 55,736 KB
testcase_03 AC 36 ms 53,588 KB
testcase_04 AC 36 ms 53,588 KB
testcase_05 AC 36 ms 55,736 KB
testcase_06 AC 36 ms 55,736 KB
testcase_07 AC 36 ms 55,736 KB
testcase_08 AC 38 ms 55,736 KB
testcase_09 AC 39 ms 55,736 KB
testcase_10 AC 38 ms 55,736 KB
testcase_11 AC 41 ms 55,736 KB
testcase_12 AC 340 ms 99,068 KB
testcase_13 AC 347 ms 98,936 KB
testcase_14 AC 369 ms 99,552 KB
testcase_15 AC 419 ms 100,048 KB
testcase_16 AC 661 ms 114,252 KB
testcase_17 AC 801 ms 131,736 KB
testcase_18 AC 972 ms 146,116 KB
testcase_19 AC 1,139 ms 155,596 KB
testcase_20 AC 1,165 ms 154,112 KB
testcase_21 AC 1,169 ms 155,456 KB
testcase_22 AC 1,171 ms 159,928 KB
testcase_23 AC 1,167 ms 160,580 KB
testcase_24 AC 1,130 ms 160,580 KB
testcase_25 AC 868 ms 128,528 KB
testcase_26 AC 1,054 ms 142,040 KB
testcase_27 AC 1,202 ms 159,844 KB
testcase_28 AC 1,119 ms 158,980 KB
testcase_29 AC 1,113 ms 159,868 KB
testcase_30 AC 633 ms 131,012 KB
testcase_31 AC 730 ms 149,424 KB
testcase_32 AC 798 ms 164,108 KB
testcase_33 AC 838 ms 166,072 KB
testcase_34 AC 812 ms 168,972 KB
testcase_35 AC 821 ms 170,008 KB
testcase_36 AC 821 ms 168,744 KB
testcase_37 AC 1,191 ms 158,064 KB
testcase_38 AC 871 ms 131,592 KB
testcase_39 AC 840 ms 167,204 KB
testcase_40 AC 861 ms 183,664 KB
testcase_41 AC 861 ms 192,852 KB
testcase_42 AC 844 ms 194,244 KB
testcase_43 AC 831 ms 201,124 KB
testcase_44 AC 858 ms 145,296 KB
testcase_45 AC 914 ms 146,536 KB
testcase_46 AC 881 ms 154,168 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

# Modulus applied to the stored DP values and to the printed answer.
MOD = 998_244_353


class UnionFind:
    """Disjoint-set union over the elements ``0 .. n-1``.

    Uses union by size and path compression. ``par[x]`` is the parent of
    ``x``, or the negated size of ``x``'s component when ``x`` is a root.
    """

    def __init__(self, n):
        self.n = n
        self.par = [-1] * n  # every element starts as a root of size 1
        self.group_ = n  # current number of components

    def find(self, x):
        """Return the root of ``x``'s component.

        Every node visited on the way is re-pointed directly at the root.
        """
        par = self.par
        root = x
        while par[root] >= 0:
            root = par[root]
        # Second walk over the same path, compressing it onto the root.
        while x != root:
            nxt = par[x]
            par[x] = root
            x = nxt
        return root

    def unite(self, x, y):
        """Merge the components of ``x`` and ``y``.

        Returns True if two components were merged, False if ``x`` and ``y``
        were already in the same one.
        """
        rx = self.find(x)
        ry = self.find(y)
        if rx == ry:
            return False

        # Roots store negated sizes, so the more negative entry is the larger
        # tree; keep the larger one as the new root.
        if self.par[ry] < self.par[rx]:
            rx, ry = ry, rx

        self.par[rx] += self.par[ry]
        self.par[ry] = rx
        self.group_ -= 1
        return True

    def size(self, x):
        """Return the number of elements in ``x``'s component."""
        root = self.find(x)
        return -self.par[root]

    def same(self, x, y):
        """Return True when ``x`` and ``y`` belong to the same component."""
        rx = self.find(x)
        ry = self.find(y)
        return rx == ry

    @property
    def group(self):
        """Number of disjoint components."""
        return self.group_


class BIT:
    """Fenwick tree (binary indexed tree) over ``n`` slots, 0-indexed externally.

    Internally ``data`` is 1-indexed: ``data[i]`` holds the sum of the
    ``i & -i`` consecutive elements ending at 1-based position ``i``.
    """

    def __init__(self, n):
        self.n = n
        self.data = [0] * (n + 1)
        # Largest power of two <= n (0 for an empty tree); this is the first
        # step size of the binary descent in lower_bound.
        if n == 0:
            self.n0 = 0
        else:
            self.n0 = 1 << (n.bit_length() - 1)

    def sum_(self, i):
        """Return the sum of the first ``i`` elements (positions ``0 .. i-1``)."""
        s = 0
        data = self.data
        while i > 0:
            s += data[i]
            i -= i & -i  # strip the lowest set bit
        return s

    def sum(self, l, r=-1):
        """Return the sum over positions ``[l, r)``.

        When ``r`` is omitted (the ``-1`` sentinel), return the prefix sum of
        the first ``l`` elements instead.
        """
        if r == -1:
            return self.sum_(l)
        return self.sum_(r) - self.sum_(l)

    def get(self, i):
        """Return the single value stored at position ``i``."""
        return self.sum(i, i + 1)

    def add(self, i, x):
        """Add ``x`` to the element at 0-indexed position ``i``.

        Positions ``>= n`` are silently ignored (the update loop never runs).

        Raises:
            IndexError: if ``i`` is negative. Without this check the update
                loop would never terminate, because it reaches internal index
                0 and ``0 & -0 == 0`` never advances it.
        """
        if i < 0:
            raise IndexError(f"BIT.add: negative index {i}")
        i += 1
        data = self.data
        n = self.n
        while i <= n:
            data[i] += x
            i += i & -i

    def lower_bound(self, x):
        """Return the smallest ``k`` in ``0 .. n`` with ``sum_(k) >= x``.

        Returns 0 when ``x <= 0`` and ``n + 1`` when even the total is below
        ``x``. Requires all stored elements to be non-negative so that prefix
        sums are non-decreasing (the binary descent relies on it).
        """
        if x <= 0:
            return 0
        i = 0
        k = self.n0
        while k > 0:
            # Extend the prefix [0, i) by k slots while it stays below x.
            if i + k <= self.n and self.data[i + k] < x:
                x -= self.data[i + k]
                i += k
            k //= 2
        return i + 1


class HLD:
    """Heavy-light decomposition (HLD) of an undirected tree on vertices ``0 .. n-1``.

    After ``build(root)``:

    * ``par[v]`` is the parent of ``v`` (``-1`` for the root) and
      ``depth[v]`` its number of edges from the root.
    * ``size[v]`` is the number of vertices in the subtree of ``v``.
    * ``heavy_child[v]`` is the child with the largest subtree (the first such
      child in adjacency order; ``-1`` if ``v`` has no children).
      ``isheavy[v]`` is True when ``v`` continues its parent's heavy path; the
      root is also marked True.
    * ``path_ind[v]`` is the index of the heavy path containing ``v``, and
      ``path_root[p]`` is the shallowest vertex of heavy path ``p``.
    * ``L[v]`` is the position of ``v`` in a DFS order that always descends
      into the heavy child first, so each heavy path occupies consecutive
      positions. ``R[v]`` is one past the last position used by ``v``'s
      subtree, so that subtree is exactly the position range ``[L[v], R[v])``.
    """

    def __init__(self, n, edges=None):
        """Create an (unbuilt) HLD over ``n`` vertices.

        ``edges`` is an optional adjacency list. When omitted, an empty one is
        created, to be filled via ``add_edge`` / ``read_edges``.
        """
        self.n = n
        if edges is None:
            self.edges = [[] for _ in range(n)]
        else:
            self.edges = edges
            # Note: `edges` is not copied. The caller's list is used directly,
            # so add_edge also modifies it.

        self.size = [-1] * n  # -1 until computed by build()
        self.par = [-1] * n
        self.depth = [-1] * n  # -1 doubles as "not yet visited" in build()
        self.path_ind = [-1] * n
        self.path_root = []
        self.heavy_child = [-1] * n
        self.isheavy = [False] * n
        self.L = [-1] * n
        self.R = [-1] * n

    def add_edge(self, u, v):
        """Add the undirected edge ``u - v``."""
        self.edges[u].append(v)
        self.edges[v].append(u)

    def read_edges(self, indexed=1):
        """Read ``n - 1`` edges from standard input via ``input()``.

        Each line holds ``u v``. ``indexed`` is subtracted from both endpoints
        (the default of 1 converts 1-indexed input to 0-indexed vertices).
        """
        for _ in range(self.n - 1):
            u, v = map(int, input().split())
            u -= indexed
            v -= indexed
            self.add_edge(u, v)

    def build(self, root=0):
        """Root the tree at ``root`` and fill in all decomposition arrays.

        Presumes the edges form a tree. Vertices not reachable from ``root``
        keep their ``-1`` defaults.
        """
        # Pass 1: iterative DFS recording depth and parent. `route` lists
        # vertices in discovery order, so every parent precedes its children.
        self.depth[root] = 0
        st = [root]
        route = [root]
        while st:
            pos = st.pop()
            for npos in self.edges[pos]:
                if self.depth[npos] == -1:
                    self.depth[npos] = self.depth[pos] + 1
                    self.par[npos] = pos
                    st.append(npos)
                    route.append(npos)

        # Pass 2: children before parents. A neighbour whose size is already
        # set is a child; the parent's size is still -1 at this point.
        for pos in route[::-1]:
            self.size[pos] = 1
            ma = -1
            for npos in self.edges[pos]:
                if self.size[npos] != -1:
                    self.size[pos] += self.size[npos]
                    if self.size[npos] > ma:
                        ma = self.size[npos]
                        self.heavy_child[pos] = npos

            if ma != -1:
                self.isheavy[self.heavy_child[pos]] = True

        # The root already starts path 0, so it must not open a new path.
        self.isheavy[root] = True

        # Pass 3: iterative DFS assigning entry positions L and exit positions
        # R. `~v` (a negative number) is v's exit marker. The heavy child is
        # pushed last so it is popped next, giving each heavy path consecutive
        # positions. A vertex that is not heavy starts a new path.
        path = 0
        st = [~root, root]
        self.path_root = [root]
        cc = 0
        while st:
            pos = st.pop()
            if pos >= 0:
                self.L[pos] = cc
                cc += 1
                if not self.isheavy[pos]:
                    path += 1
                    self.path_root.append(pos)

                self.path_ind[pos] = path
                for npos in self.edges[pos]:
                    if npos == self.par[pos] or npos == self.heavy_child[pos]:
                        continue
                    st.append(~npos)
                    st.append(npos)

                if self.heavy_child[pos] != -1:
                    npos = self.heavy_child[pos]
                    st.append(~npos)
                    st.append(npos)

            else:
                # Every vertex of this subtree has been numbered.
                self.R[~pos] = cc

    def get_path(self, u, v):
        """Split the tree path from ``u`` to ``v`` into heavy-path segments.

        Returns a list of vertex pairs ``(a, b)``. Each pair lies on a single
        heavy path, so the positions between ``L[a]`` and ``L[b]`` (inclusive)
        form a contiguous range. Read in order, the pairs trace the path from
        ``u`` to ``v``, including both endpoints. The two vertices of a pair
        are not ordered by ``L``; callers wanting index ranges should sort each
        pair.
        """
        ll = [u]
        rr = [v]
        while self.path_ind[u] != self.path_ind[v]:
            # Climb from whichever side's heavy path has the deeper top vertex.
            if (
                self.depth[self.path_root[self.path_ind[u]]]
                >= self.depth[self.path_root[self.path_ind[v]]]
            ):
                u = self.path_root[self.path_ind[u]]
                ll.append(u)
                u = self.par[u]
                ll.append(u)
            else:
                v = self.path_root[self.path_ind[v]]
                rr.append(v)
                v = self.par[v]
                rr.append(v)

        # ll and rr both have odd length, so the concatenation pairs up evenly.
        ll += rr[::-1]
        res = []
        for i in range(0, len(ll), 2):
            res.append((ll[i], ll[i + 1]))

        return res

    def lca(self, u, v):
        """Return the lowest common ancestor of ``u`` and ``v`` (w.r.t. the build root)."""
        while self.path_ind[u] != self.path_ind[v]:
            # Jump above the heavy path whose top vertex is deeper.
            if (
                self.depth[self.path_root[self.path_ind[u]]]
                >= self.depth[self.path_root[self.path_ind[v]]]
            ):
                u = self.par[self.path_root[self.path_ind[u]]]
            else:
                v = self.par[self.path_root[self.path_ind[v]]]

        # Now on the same heavy path: the shallower vertex is the LCA.
        if self.depth[u] >= self.depth[v]:
            return v
        else:
            return u

    def dist(self, u, v):
        """Return the number of edges on the tree path between ``u`` and ``v``."""
        return self.depth[u] + self.depth[v] - 2 * self.depth[self.lca(u, v)]

    def reorder(self, A, rev=False):
        """Return a new list with ``A`` permuted into HLD position order.

        The result satisfies ``ret[L[i]] == A[i]``. When ``rev`` is true, the
        list is returned reversed.
        """
        ret = [0] * self.n
        for i in range(self.n):
            ret[self.L[i]] = A[i]

        if rev:
            ret = ret[::-1]
        return ret


import sys

# Input can be ~2e5 lines; sys.stdin.readline is much faster than input()
# (notably on PyPy). split() discards the trailing newline.
readline = sys.stdin.readline

n, m = map(int, readline().split())
assert 2 <= n <= 200000
assert 1 <= m <= 200000

# E[v] holds (0-indexed) the smaller endpoint u of every input edge (u, v), u < v.
E = [[] for _ in range(n)]
se = set()  # encoded edges seen so far, used to reject duplicate edges
for _ in range(m):
    u, v = map(int, readline().split())
    assert 1 <= u < v <= n
    # For 1 <= u < v <= n, the key u * n + v is distinct for distinct pairs.
    assert u * n + v not in se
    se.add(u * n + v)

    E[v - 1].append(u - 1)

# Sweep vertices in increasing order, merging v with each smaller neighbour's
# component. ma[r] is the largest vertex of the component rooted at r. When a
# component is first joined to v, the tree edge (ma[r], v) is added to the HLD.
UF = UnionFind(n)
hld = HLD(n)
ma = list(range(n))
for v in range(n):
    for u in E[v]:
        ru = UF.find(u)
        rv = UF.find(v)
        if ru != rv:
            hld.add_edge(ma[ru], v)
            UF.unite(ru, rv)
            ma[UF.find(ru)] = v
assert UF.group == 1  # the input graph must be connected

hld.build()
bit = BIT(n)

# Process vertices in increasing order. For each v, collect the HLD position
# ranges covering the tree paths from every smaller neighbour u to v, and sum
# the BIT over the union of those ranges. Then store (sum mod MOD) + 1 at v's
# position. Vertices larger than v have not been stored yet and contribute 0.
stride = n + 1
for v in range(n):
    # Each half-open range [l, r) is packed as l * (n + 1) + r, so a plain
    # integer sort orders the ranges by (l, r).
    lr = []
    for u in E[v]:
        for a, b in hld.get_path(u, v):
            a = hld.L[a]
            b = hld.L[b]
            if a > b:
                a, b = b, a
            lr.append(a * stride + (b + 1))

    lr.sort()
    tot = 0
    covered = 0  # every position below `covered` has already been summed
    for key in lr:
        l, r = divmod(key, stride)
        l = max(l, covered)
        if l < r:
            tot += bit.sum(l, r)
        covered = max(covered, r)

    bit.add(hld.L[v], tot % MOD + 1)

ans = bit.sum(0, n)
print(ans % MOD)