結果

問題 No.1864 Shortest Paths Counting
ユーザー LyricalMaestro
提出日時 2024-09-08 16:33:41
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 409 ms / 2,000 ms
コード長 3,259 bytes
コンパイル時間 354 ms
コンパイル使用メモリ 82,352 KB
実行使用メモリ 158,644 KB
最終ジャッジ日時 2024-09-08 16:33:51
合計ジャッジ時間 8,837 ms
ジャッジサーバーID
(参考情報)
judge4 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 23
権限があれば一括ダウンロードができます

ソースコード

diff #

## https://yukicoder.me/problems/no/1864

MOD = 998244353

class SegmentTree:
    """
    非再帰版セグメント木。
    更新は「加法」、取得は「最大値」のもの限定。
    """

    def __init__(self, init_array):
        n = 1
        while n < len(init_array):
            n *= 2
        
        self.size = n
        self.array = [0] * (2 * self.size)
        for i, a in enumerate(init_array):
            self.array[self.size + i] = a
        
        end_index = self.size
        start_index = end_index // 2
        while start_index >= 1:
            for i in range(start_index, end_index):
                self.array[i] = self.array[2 * i] + self.array[2 * i + 1]
                self.array[i] %= MOD
            end_index = start_index
            start_index = end_index // 2

    def add(self, x, a):
        index = self.size + x
        self.array[index] += a
        self.array[index] %= MOD
        while index > 1:
            index //= 2
            self.array[index] = self.array[2 * index] + self.array[2 * index + 1]
            self.array[index] %= MOD

    def get_sum(self, l, r):
        L = self.size + l; R = self.size + r

        # 2. 区間[l, r)の最大値を求める
        s = 0
        while L < R:
            if R & 1:
                R -= 1
                s += self.array[R]
                s %= MOD
            if L & 1:
                s += self.array[L]
                s %= MOD
                L += 1                
            L >>= 1; R >>= 1
        return s


def main():
    N = int(input())
    xy = []
    for _ in range(N):
        x, y = map(int, input().split())
        xy.append((x, y))

    # チェビシェフ変換 -> マンハッタン距離 * 2 の世界で考える
    uv = []
    for x, y in xy:
        u = x + y
        v = x - y
        uv.append([u, v])
    
    # u0 < uN, v0 < vNの世界で考えても問題ないように標準化する
    if uv[0][0] > uv[-1][0]:
        for i in range(N):
            uv[i][0] *= -1
    if uv[0][1] > uv[-1][1]:
        for i in range(N):
            uv[i][1] *= -1
    
    # u0 <= uN, v0 <= vNの中にいる点だけ取り出す
    uv2 = []
    for u, v in uv:
        if uv[0][0] <= u <= uv[-1][0] and uv[0][1] <= v <= uv[-1][1]:
            uv2.append((u, v))

    # uベースでまとめる
    u_map = {}
    for u, v in uv2:
        if u not in u_map:
            u_map[u] = []
        u_map[u].append(v)

    # vベースでの座標圧縮
    v_set = set()
    for _, v in uv2:
        v_set.add(v)
    v_list = list(v_set)
    v_list.sort()
    v_map = {}
    for i, v in enumerate(v_list):
        v_map[v] = i
    v_max = len(v_list)
    
    # 平面捜査による解法で実施
    seg_tree = SegmentTree([0] * v_max)
    u_array = [(u, v_array) for u, v_array in u_map.items()]
    u_array.sort(key=lambda x : x[0])
    ans = -1
    for u, v_array in u_array:
        v_array.sort()
        for v in v_array:
            if (u, v) == uv2[0]:
                seg_tree.add(0, 1)
            else:
                x = seg_tree.get_sum(0, v_map[v] + 1)
                ans = x
                seg_tree.add(v_map[v], x)
        
    print(ans)


    





    



if __name__ == '__main__':
    main()
0