結果

問題 No.3452 Divide Permutation
コンテスト
ユーザー 👑 potato167
提出日時 2026-02-13 04:49:45
言語 PyPy3 (7.3.17)
結果
TLE  
実行時間 -
コード長 9,935 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 288 ms
コンパイル使用メモリ 82,224 KB
実行使用メモリ 207,008 KB
最終ジャッジ日時 2026-02-20 20:55:09
合計ジャッジ時間 17,492 ms
ジャッジサーバーID (参考情報) judge4 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 8 TLE * 17 -- * 44
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

import sys
# No function in the code shown here is recursive; this raised limit looks like
# template boilerplate (harmless, but not required by anything below).
sys.setrecursionlimit(10**7)

# Modulus for all hash arithmetic (used by op_F and by the p_hash leaves in solve).
MOD = 998244353
# Per-element multiplier of the polynomial hash: each p_hash leaf in solve is
# (x % MOD, BASE), so a range hashes to sum(x_i * BASE**(len-1-i)) mod MOD.
BASE = 10

# --------- Generic segment tree over an int monoid (used with min / max) ---------

class SegTreeInt:
    """Iterative (bottom-up) segment tree over an arbitrary monoid.

    ``op`` must be associative with identity ``E``. Leaves live at
    ``d[size:size + n]``; internal node ``k`` combines children ``2k`` and
    ``2k + 1``; the root is ``d[1]``. In this file it is instantiated with
    range-min and range-max over ints.
    """

    __slots__ = ("n", "size", "op", "E", "d")

    def __init__(self, op, E, v=None, n=0):
        """Build from list ``v``, or from ``n`` identity leaves when ``v`` is None.

        When ``v`` is given, ``n`` is ignored and ``len(v)`` is used.
        """
        self.op = op
        self.E = E
        self.n = n if v is None else len(v)
        # Smallest power of two that is >= n (1 when n == 0).
        size = 1
        while size < self.n:
            size <<= 1
        self.size = size
        tree = [E] * (2 * size)
        if v is not None:
            tree[size:size + self.n] = v
            for k in range(size - 1, 0, -1):
                tree[k] = op(tree[2 * k], tree[2 * k + 1])
        self.d = tree

    def set(self, p, x):
        """Store ``x`` at leaf ``p`` and rebuild the path up to the root."""
        op, tree = self.op, self.d
        k = p + self.size
        tree[k] = x
        k >>= 1
        while k:
            tree[k] = op(tree[2 * k], tree[2 * k + 1])
            k >>= 1

    def get(self, p):
        """Return the value stored at leaf ``p``."""
        return self.d[self.size + p]

    def prod(self, l, r):
        """Return op(v[l], ..., v[r-1]); the identity E when the range is empty."""
        op, tree = self.op, self.d
        left_acc = right_acc = self.E
        lo = l + self.size
        hi = r + self.size
        while lo < hi:
            if lo & 1:
                left_acc = op(left_acc, tree[lo])
                lo += 1
            if hi & 1:
                hi -= 1
                right_acc = op(tree[hi], right_acc)
            lo >>= 1
            hi >>= 1
        return op(left_acc, right_acc)

    def all_prod(self):
        """Return the fold of every leaf (the root node)."""
        return self.d[1]

    def max_right(self, l, pred):
        """Return the largest r such that pred(prod(l, r)) is True.

        As in the AtCoder Library, the descent assumes pred(E) is True and
        that pred, once False, stays False as the range grows.
        """
        if l == self.n:
            return self.n
        op, tree, size = self.op, self.d, self.size
        acc = self.E
        k = l + size
        while True:
            # Climb while k is a left child: its parent starts at the same leaf.
            while not (k & 1):
                k >>= 1
            candidate = op(acc, tree[k])
            if not pred(candidate):
                # Walk down to the first leaf that makes pred fail.
                while k < size:
                    k <<= 1
                    trial = op(acc, tree[k])
                    if pred(trial):
                        acc = trial
                        k += 1
                return k - size
            acc = candidate
            k += 1
            if (k & -k) == k:
                # k became a power of two: the whole suffix was accepted.
                return self.n

    def min_left(self, r, pred):
        """Return the smallest l such that pred(prod(l, r)) is True.

        Mirror image of max_right, with the same assumptions on pred.
        """
        if r == 0:
            return 0
        op, tree, size = self.op, self.d, self.size
        acc = self.E
        k = r + size
        while True:
            k -= 1
            # Climb while k is a right child (never above the root).
            while k > 1 and (k & 1):
                k >>= 1
            candidate = op(tree[k], acc)
            if not pred(candidate):
                # Walk down to the last leaf that makes pred fail.
                while k < size:
                    k = 2 * k + 1
                    trial = op(tree[k], acc)
                    if pred(trial):
                        acc = trial
                        k -= 1
                return k + 1 - size
            acc = candidate
            if (k & -k) == k:
                # k is a power of two: the whole prefix was accepted.
                return 0


# --------- Segment tree over hash pairs F = (value, multiplier) ---------

# Identity of op_F: value 0 with multiplier 1 leaves the other operand unchanged.
E_F = (0, 1)


def op_F(l, r):
    """Combine two hash pairs as "l followed by r".

    Each pair is ``(value, multiplier)``. Appending ``r`` shifts ``l``'s value
    by ``r``'s multiplier and adds ``r``'s value; the multipliers multiply.
    Both components are reduced modulo MOD.
    """
    left_val, left_mul = l
    right_val, right_mul = r
    combined_val = (left_val * right_mul + right_val) % MOD
    combined_mul = (left_mul * right_mul) % MOD
    return (combined_val, combined_mul)

class SegTreeF:
    """Bottom-up segment tree over hash pairs, with the monoid fixed to op_F / E_F.

    The layout matches SegTreeInt: leaves at ``d[size:size + n]``, and node
    ``k`` combines children ``2k`` and ``2k + 1``. The operator is hard-coded,
    so it is not stored per instance (note the smaller ``__slots__``).
    """

    __slots__ = ("n", "size", "d")

    def __init__(self, v=None, n=0):
        """Build from the list of pairs ``v``, or ``n`` E_F leaves when ``v`` is None."""
        self.n = n if v is None else len(v)
        # Smallest power of two that is >= n (1 when n == 0).
        size = 1
        while size < self.n:
            size <<= 1
        self.size = size
        tree = [E_F] * (2 * size)
        if v is not None:
            tree[size:size + self.n] = v
            for k in range(size - 1, 0, -1):
                tree[k] = op_F(tree[2 * k], tree[2 * k + 1])
        self.d = tree

    def set(self, p, x):
        """Store ``x`` at leaf ``p`` and rebuild the path up to the root."""
        tree = self.d
        k = p + self.size
        tree[k] = x
        k >>= 1
        while k:
            tree[k] = op_F(tree[2 * k], tree[2 * k + 1])
            k >>= 1

    def get(self, p):
        """Return the pair stored at leaf ``p``."""
        return self.d[self.size + p]

    def prod(self, l, r):
        """Return the op_F-fold of leaves l .. r-1 (E_F when the range is empty)."""
        tree = self.d
        left_acc = right_acc = E_F
        lo = l + self.size
        hi = r + self.size
        while lo < hi:
            if lo & 1:
                left_acc = op_F(left_acc, tree[lo])
                lo += 1
            if hi & 1:
                hi -= 1
                right_acc = op_F(tree[hi], right_acc)
            lo >>= 1
            hi >>= 1
        return op_F(left_acc, right_acc)

    def all_prod(self):
        """Return the fold of every leaf (the root node)."""
        return self.d[1]

    def max_right(self, l, pred):
        """Return the largest r such that pred(prod(l, r)) is True.

        Assumes pred(E_F) is True and that pred, once False, stays False as
        the range grows.
        """
        if l == self.n:
            return self.n
        tree, size = self.d, self.size
        acc = E_F
        k = l + size
        while True:
            # Climb while k is a left child.
            while not (k & 1):
                k >>= 1
            candidate = op_F(acc, tree[k])
            if not pred(candidate):
                # Walk down to the first leaf that makes pred fail.
                while k < size:
                    k <<= 1
                    trial = op_F(acc, tree[k])
                    if pred(trial):
                        acc = trial
                        k += 1
                return k - size
            acc = candidate
            k += 1
            if (k & -k) == k:
                return self.n

    def min_left(self, r, pred):
        """Return the smallest l such that pred(prod(l, r)) is True.

        Mirror image of max_right, with the same assumptions on pred.
        """
        if r == 0:
            return 0
        tree, size = self.d, self.size
        acc = E_F
        k = r + size
        while True:
            k -= 1
            # Climb while k is a right child (never above the root).
            while k > 1 and (k & 1):
                k >>= 1
            candidate = op_F(tree[k], acc)
            if not pred(candidate):
                # Walk down to the last leaf that makes pred fail.
                while k < size:
                    k = 2 * k + 1
                    trial = op_F(tree[k], acc)
                    if pred(trial):
                        acc = trial
                        k -= 1
                return k + 1 - size
            acc = candidate
            if (k & -k) == k:
                return 0


def solve(p_in):
    """Return the N hash values reported for one test case.

    ``p_in`` is the input sequence with 1-indexed values. ``inv`` below is
    indexed by ``value - 1``, so the code appears to assume that ``p_in`` is a
    permutation of 1..N (TODO confirm against the problem constraints).

    Every value written to ``res`` (except padding copies) is the first
    component of ``ans_seg.all_prod()``. That is the hash of the current
    segments of ``p``, concatenated in increasing order of each segment's first
    value. The hash of a sequence is sum(x_i * BASE**(len-1-i)) mod MOD over
    its 1-indexed values.

    ``res[0]`` is the hash with a single segment covering all of ``p``.
    Presumably ``res[k]`` is the answer after k divisions. NOTE(review):
    confirm against the problem statement. Slots never reached by the main
    loop are padded with the previous value.
    """
    p = p_in[:]  # copy; values are still 1-indexed here (shifted to 0-indexed below)
    N = len(p)

    # Build p_hash in O(N) by passing all leaves to the constructor, instead
    # of N separate O(log N) set() calls. The same pass shifts p to 0-indexed values.
    ph = [None] * N
    for i, x in enumerate(p):
        ph[i] = (x % MOD, BASE)
        p[i] = x - 1
    p_hash = SegTreeF(ph)

    # inv[v] = position of 0-indexed value v in p.
    inv = [0] * N
    for i, x in enumerate(p):
        inv[x] = i

    # Integer segment trees:
    #   p_max / p_min - range max / min over positions of p.
    #   cut_min       - indexed by value; holds v at index v for registered
    #                   candidates (see set_calc_min), 1 << 30 otherwise.
    #   mex           - indexed by a segment's first value; holds a split
    #                   position computed by sp_mex, or -1.
    def op_max(a, b): return a if a >= b else b
    def op_min(a, b): return a if a <= b else b

    p_max = SegTreeInt(op_max, -1, v=p)
    p_min = SegTreeInt(op_min, 1 << 30, v=p)
    cut_min = SegTreeInt(op_min, 1 << 30, n=N)
    mex = SegTreeInt(op_max, -1, n=N)

    # Segment bookkeeping. Both trees hold E_F everywhere except at segment starts:
    #   perm_seg[l]    = hash of segment p[l:r] stored at its start position l.
    #   ans_seg[p[l]]  = the same hash stored at the segment's first value, so
    #                    all_prod() concatenates segments ordered by first value.
    ans_seg = SegTreeF(n=N)
    perm_seg = SegTreeF(n=N)

    # Predicates for max_right / min_left.
    # pred_F is True when the accumulated multiplier is 1, meaning the range
    # holds only E_F entries (no segment start). This relies on BASE**len % MOD
    # never being 1 for the segment lengths involved.
    def pred_F(x):
        return x[1] == 1

    # One-element mutable cell holding the threshold that pred_g compares against.
    target = [0]
    def pred_g(x):
        return x < target[0]

    # Helpers.
    # seg_set(l, r): record segment p[l:r] (hash stored at perm_seg[l] and
    # ans_seg[p[l]]). With l == r the stored hash is E_F, which clears the entry.
    def seg_set(l, r):
        seg = p_hash.prod(l, r)
        perm_seg.set(l, seg)
        ans_seg.set(p[l], seg)

    # Start with one segment covering the whole permutation.
    seg_set(0, N)

    # sp_mex(ind): ind must be a segment start.
    #   m  = smallest segment-first value greater than p[ind] (N if none).
    #   r  = start of the next segment after ind (N if none).
    #   nr = first position >= ind whose value is >= m (N if none).
    # Stores nr in mex[p[ind]] if it lies inside the segment [ind, r), else -1.
    def sp_mex(ind):
        m = ans_seg.max_right(p[ind] + 1, pred_F)
        r = perm_seg.max_right(ind + 1, pred_F)
        target[0] = m
        nr = p_max.max_right(ind, pred_g)
        mex.set(p[ind], nr if nr < r else -1)

    # calc_cut_min(ind): returns the minimum value in the rest of the segment
    # starting at ind (positions ind+1 .. r-1) if it is smaller than p[ind],
    # otherwise -1 (also -1 for a length-1 segment).
    def calc_cut_min(ind):
        r = perm_seg.max_right(ind + 1, pred_F)
        res = p_min.prod(ind + 1, r)
        return -1 if p[ind] < res else res

    # Register the cut_min candidate of the segment starting at ind, if any.
    def set_calc_min(ind):
        tmp = calc_cut_min(ind)
        if tmp != -1:
            cut_min.set(tmp, tmp)

    # cut(i): if position i is not already a segment start, split the segment
    # [a, c) containing it into [a, i) and [i, c), refreshing the mex and
    # cut_min bookkeeping of the affected segments. Returns whether a split happened.
    def cut(i):
        if perm_seg.get(i)[1] == 1:
            a = perm_seg.min_left(i, pred_F) - 1  # start of the containing segment
            c = perm_seg.max_right(i + 1, pred_F)  # start of the following segment

            # Unregister the old segment's cut_min candidate before splitting.
            tmp = calc_cut_min(a)
            if tmp != -1:
                cut_min.set(tmp, 1 << 30)

            seg_set(a, i)
            seg_set(i, c)

            sp_mex(a)
            sp_mex(i)

            # Also refresh the segment whose first value is the largest
            # segment-first value below p[i] (tmp2 - 1), if one exists.
            tmp2 = ans_seg.min_left(p[i], pred_F)
            if tmp2:
                sp_mex(inv[tmp2 - 1])

            set_calc_min(a)
            set_calc_min(i)
            return True
        return False

    res = [0] * N
    # ind is the next slot of res to fill. (The first assignment is dead.)
    # NOTE(review): an iteration with n >= 1 may write two entries (the probe
    # below plus the final record). Nothing here checks ind < N, so the code
    # relies on the algorithm's invariants to keep it in range.
    ind = 0
    res[0] = ans_seg.all_prod()[0]
    ind = 1

    for n in range(N):
        if n == 0:
            # Value 0 not at the front: cut right before it.
            if p[0] != 0:
                cut(inv[0])
                res[ind] = ans_seg.all_prod()[0]
                ind += 1
            continue

        # n is not at the front and n-1 sits immediately before n: skip this n.
        if p[0] != n and inv[n - 1] + 1 == inv[n]:
            continue

        # back_cut: n-1 is not the last element and the element right after
        # it is greater than n. The split is applied at inv[n - 1] + 1 below.
        back_cut = 0
        if p[N - 1] != n - 1 and p[inv[n - 1] + 1] > n:
            back_cut = 1

        # front_cut: n is not the first element and the element right before
        # it is greater than n. The split is applied at inv[n] below.
        front_cut = 0
        if p[0] != n and p[inv[n] - 1] > n:
            front_cut = 1

        if back_cut + front_cut == 0:
            continue

        if back_cut + front_cut == 2:
            # Two splits are coming. First record an intermediate answer with
            # a single temporary split taken from one of two candidate sources,
            # then undo it:
            #   tmp - smallest value v with mex[v] >= 0 (N if none);
            #   cm  - smallest registered cut_min value (1 << 30 if none).
            target[0] = 0
            tmp = N
            if mex.all_prod() >= 0:
                tmp = mex.max_right(0, pred_g)

            cm = cut_min.all_prod()
            if tmp > cm:
                # Split the segment [l, r) that contains value cm at cm's position.
                tmp = cm
                m = inv[tmp]
                l = perm_seg.min_left(m, pred_F) - 1
                r = perm_seg.max_right(m, pred_F)
                seg_set(l, m)
                seg_set(m, r)
                res[ind] = ans_seg.all_prod()[0]
                ind += 1
                # Undo: restore [l, r) and clear the temporary start at m.
                seg_set(l, r)
                seg_set(m, m)
            elif tmp == N:
                # Neither source has a candidate: repeat the previous answer.
                res[ind] = res[ind - 1]
                ind += 1
            else:
                # Split the segment starting with value tmp at its stored mex position.
                l = inv[tmp]
                r = perm_seg.max_right(l + 1, pred_F)
                m = mex.get(tmp)
                seg_set(l, m)
                seg_set(m, r)
                res[ind] = ans_seg.all_prod()[0]
                ind += 1
                # Undo: restore [l, r) and clear the temporary start at m.
                seg_set(l, r)
                seg_set(m, m)

        if back_cut:
            cut(inv[n - 1] + 1)
        if front_cut:
            cut(inv[n])

        res[ind] = ans_seg.all_prod()[0]
        ind += 1

    # Pad the remaining slots with the last recorded value.
    while ind < N:
        res[ind] = res[ind - 1]
        ind += 1
    return res


def main():
    """Read every test case from stdin and write one answer line per case.

    Input is whitespace-separated: T, then for each case N followed by N
    integers. Each output line holds the values returned by ``solve``,
    separated by spaces. Lines are joined with newlines (no trailing newline).
    """
    tokens = iter(sys.stdin.buffer.read().split())
    case_count = int(next(tokens))
    rendered = []
    for _ in range(case_count):
        length = int(next(tokens))
        perm = [int(next(tokens)) for _ in range(length)]
        answer = solve(perm)
        rendered.append(" ".join(map(str, answer)))
    sys.stdout.write("\n".join(rendered))


# Entry point: run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()