結果

問題 No.3452 Divide Permutation
コンテスト
ユーザー 👑 potato167
提出日時 2026-02-13 04:45:34
言語 PyPy3 (7.3.17)
結果
TLE  
実行時間 -
コード長 7,508 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 645 ms
コンパイル使用メモリ 82,208 KB
実行使用メモリ 56,048 KB
最終ジャッジ日時 2026-02-20 20:54:10
合計ジャッジ時間 7,594 ms
ジャッジサーバーID
(参考情報)
judge5 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other TLE * 1 -- * 68
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

import sys
# NOTE(review): no function visible in this file recurses, so this raised
# limit looks unnecessary -- confirm before removing.
sys.setrecursionlimit(10**7)

# Modulus under which both components of every F are reduced.
MOD = 998244353
# Per-element multiplier: solve() builds each leaf hash as F(value, BASE).
BASE = 10


# -------- F (rolling hash monoid) --------

class F:
    """Element of the rolling-hash monoid: a pair ``(v, b)`` reduced mod MOD.

    ``v`` is the hash value of a sequence and ``b`` is the multiplier that a
    hash standing to the left of it must be scaled by when the two are
    concatenated (see ``op_F``).  The defaults ``(0, 1)`` form the identity.
    """

    __slots__ = ("v", "b")

    def __init__(self, v=0, b=1):
        # Reduce on construction so the stored fields are always in [0, MOD).
        self.v, self.b = v % MOD, b % MOD


def op_F(l: F, r: F) -> F:
    """Combine two hash elements, with ``l`` standing to the left of ``r``.

    The left value is scaled by the right multiplier before the right value
    is added, and the multipliers are multiplied together; everything is
    taken modulo MOD.  Neither argument is modified.
    """
    combined_value = (l.v * r.b + r.v) % MOD
    combined_base = (l.b * r.b) % MOD
    return F(combined_value, combined_base)


def e_F() -> F:
    """Return a fresh identity element of ``op_F``: value 0, multiplier 1."""
    return F(v=0, b=1)


# -------- integer monoids --------

def op_max(a, b):
    """Return the larger of ``a`` and ``b``; ``a`` wins when they compare equal."""
    if a >= b:
        return a
    return b


def e_max():
    """Return -1, used as the identity for ``op_max``.

    This only acts as an identity while every stored value is >= -1.
    """
    identity = -1
    return identity


def op_min(a, b):
    """Return the smaller of ``a`` and ``b``; ``a`` wins when they compare equal."""
    if a <= b:
        return a
    return b


def e_min():
    """Return 2**30, used as the identity for ``op_min``.

    This only acts as an identity while every stored value is <= 2**30.
    """
    return 2 ** 30


# -------- SegTree (ACL compatible) --------

class SegTree:
    """Bottom-up segment tree over a monoid given by ``op`` and ``e``.

    ``op(a, b)`` combines two elements (left operand first) and ``e()``
    returns the identity.  Nodes live in a flat list ``d`` of length
    ``2 * size``: the root is ``d[1]``, the children of node ``k`` are
    ``d[2k]`` and ``d[2k + 1]``, and leaf ``i`` is ``d[size + i]``.
    ``size`` is the smallest power of two that is >= ``n`` (at least 1) and
    ``log`` is its exponent.
    """

    __slots__ = ("n", "size", "log", "op", "e", "d")

    def __init__(self, op, e, v=None, n=0):
        """Build from the sequence ``v``, or ``n`` identity leaves if ``v`` is None."""
        self.op = op
        self.e = e
        length = n if v is None else len(v)
        self.n = length

        # Exponent of the smallest power of two that covers all leaves.
        log = 0
        while (1 << log) < length:
            log += 1
        size = 1 << log
        self.size = size
        self.log = log

        tree = [e() for _ in range(2 * size)]
        if v is not None:
            for offset, item in enumerate(v):
                tree[size + offset] = item
            # Fill internal nodes from the deepest level up to the root.
            for node in reversed(range(1, size)):
                tree[node] = op(tree[2 * node], tree[2 * node + 1])
        self.d = tree

    def set(self, p, x):
        """Store ``x`` at leaf ``p`` and recompute every ancestor."""
        tree, combine = self.d, self.op
        node = p + self.size
        tree[node] = x
        node >>= 1
        while node:
            tree[node] = combine(tree[2 * node], tree[2 * node + 1])
            node >>= 1

    def get(self, p):
        """Return the element stored at leaf ``p``."""
        return self.d[self.size + p]

    def prod(self, l, r):
        """Return the ordered product of leaves ``l`` .. ``r - 1`` (identity if empty)."""
        tree, combine = self.d, self.op
        left_acc = self.e()
        right_acc = self.e()
        lo = l + self.size
        hi = r + self.size
        while lo < hi:
            if lo & 1:
                # lo is a right child: take it and step past it.
                left_acc = combine(left_acc, tree[lo])
                lo += 1
            if hi & 1:
                # hi - 1 is a left child lying inside the range: take it.
                hi -= 1
                right_acc = combine(tree[hi], right_acc)
            lo >>= 1
            hi >>= 1
        return combine(left_acc, right_acc)

    def all_prod(self):
        """Return the product of all leaves (the root node)."""
        return self.d[1]

    def max_right(self, l, pred):
        """Search rightwards from ``l`` for the first point where ``pred`` fails.

        Returns an index ``r`` such that ``pred`` accepted the product of
        leaves ``l`` .. ``r - 1`` and rejected it once leaf ``r`` was
        included, or ``n`` if it never rejected.  ``pred`` is only ever
        called on accumulated products, never on the bare identity.
        """
        if l == self.n:
            return self.n
        tree, combine, size = self.d, self.op, self.size
        acc = self.e()
        node = l + size
        while True:
            # Climb while node is a left child, to cover the largest block.
            while not node & 1:
                node >>= 1
            candidate = combine(acc, tree[node])
            if not pred(candidate):
                # The answer is inside this block: walk down to the leaf.
                while node < size:
                    node *= 2
                    candidate = combine(acc, tree[node])
                    if pred(candidate):
                        acc = candidate
                        node += 1
                return node - size
            acc = candidate
            node += 1
            if (node & -node) == node:
                # node is a power of two: the right edge has been passed.
                return self.n

    def min_left(self, r, pred):
        """Search leftwards from ``r`` for the first point where ``pred`` fails.

        Mirror image of ``max_right``: returns an index ``l`` such that
        ``pred`` accepted the product of leaves ``l`` .. ``r - 1`` and
        rejected it once leaf ``l - 1`` was included, or 0 if it never
        rejected.
        """
        if r == 0:
            return 0
        tree, combine, size = self.d, self.op, self.size
        acc = self.e()
        node = r + size
        while True:
            node -= 1
            # Climb while node is a right child (never above the root).
            while node > 1 and node & 1:
                node >>= 1
            candidate = combine(tree[node], acc)
            if not pred(candidate):
                # The answer is inside this block: walk down to the leaf.
                while node < size:
                    node = 2 * node + 1
                    candidate = combine(tree[node], acc)
                    if pred(candidate):
                        acc = candidate
                        node -= 1
                return node + 1 - size
            acc = candidate
            if (node & -node) == node:
                # node is a power of two: the left edge has been reached.
                return 0


# -------- main solve --------

def solve(p_in):
    """Compute the list of N hash values reported for one input sequence.

    The problem statement is not part of this file, so this describes only
    what the code does.  Positions 0..N-1 are kept partitioned into
    contiguous segments.  Each segment is recorded at its start position in
    ``perm_seg`` and at the (0-based) value of its first element in
    ``ans_seg``; consequently ``ans_seg.all_prod()`` is the concatenated hash
    of all segments ordered by the value of their first element.  The
    function starts from a single segment, cuts segments while iterating
    ``n`` over 0..N-1, and records ``ans_seg.all_prod().v`` along the way.

    Args:
        p_in: sequence of ints; it is copied, not modified.  Presumably a
            permutation of 1..N -- ``inv`` is only a true inverse in that
            case (TODO confirm against the problem statement).

    Returns:
        A list of N ints.  Slots not written by the main loop are filled
        with the last recorded value.

    Raises:
        IndexError: for an empty ``p_in`` (the first segment cannot be set).
    """
    p = p_in[:]
    N = len(p)

    # rolling hash of original values: leaf i is F(value, BASE), so the
    # product over [l, r) has multiplier BASE**(r - l) % MOD.
    p_hash = SegTree(op_F, e_F, n=N)
    for i in range(N):
        p_hash.set(i, F(p[i], BASE))
        # From here on p holds the values shifted down by one.
        p[i] -= 1

    # inv[value] = position of that (shifted) value.
    inv = [0] * N
    for i in range(N):
        inv[p[i]] = i

    # Range max / min over the shifted values, indexed by position.
    p_max = SegTree(op_max, e_max, v=p)
    p_min = SegTree(op_min, e_min, v=p)
    # Both indexed by value; they start as all-identity (1 << 30 and -1).
    cut_min = SegTree(op_min, e_min, n=N)
    mex = SegTree(op_max, e_max, n=N)

    # Segment hashes: ans_seg is indexed by the value of the segment's first
    # element, perm_seg by the segment's start position.
    ans_seg = SegTree(op_F, e_F, n=N)
    perm_seg = SegTree(op_F, e_F, n=N)

    def pred_F(x):
        # True while the accumulated product still looks like the identity
        # (multiplier 1), i.e. no written leaf has been absorbed yet.
        # NOTE(review): relies on BASE**k % MOD != 1 for every segment
        # length k that can occur -- confirm.
        return x.b == 1

    # One-element list so the nested functions can change the threshold
    # used by pred_g.
    target = [0]

    def pred_g(x):
        return x < target[0]

    def sp_mex(ind):
        # Recompute the mex entry for the segment starting at position ind.
        # m: first value index above p[ind] with a written ans_seg leaf (or N).
        m = ans_seg.max_right(p[ind] + 1, pred_F)
        # r: first position after ind with a written perm_seg leaf (or N).
        r = perm_seg.max_right(ind + 1, pred_F)
        target[0] = m
        # nr: first position >= ind whose value is >= m (or N).
        nr = p_max.max_right(ind, pred_g)
        # Keep nr only if it falls before r; otherwise store -1.
        mex.set(p[ind], nr if nr < r else -1)

    def calc_cut_min(ind):
        # Minimum value strictly after ind and before the next written
        # perm_seg leaf, or -1 when p[ind] is smaller than that minimum
        # (which includes the empty range, whose minimum is 1 << 30).
        r = perm_seg.max_right(ind + 1, pred_F)
        res = p_min.prod(ind + 1, r)
        if p[ind] < res:
            return -1
        return res

    def set_calc_min(ind):
        # Register calc_cut_min(ind) in cut_min (at its own index) unless -1.
        tmp = calc_cut_min(ind)
        if tmp != -1:
            cut_min.set(tmp, tmp)

    def seg_set(l, r):
        # Record [l, r) as a segment starting at l.  With l == r the product
        # is the identity, so this clears both leaves instead.
        perm_seg.set(l, p_hash.prod(l, r))
        ans_seg.set(p[l], perm_seg.get(l))

    # Start with the whole array as one segment.
    seg_set(0, N)

    def cut(i):
        # Split the segment containing position i so that a new segment
        # starts at i.  Returns False if the perm_seg leaf at i is already
        # written (multiplier != 1), True after performing the split.
        if perm_seg.get(i).b == 1:
            # a: nearest written perm_seg leaf to the left of i;
            # c: nearest one to the right of i (or N).
            # NOTE(review): a would be -1 if no leaf left of i were written;
            # presumably leaf 0 is always written and i >= 1 -- confirm.
            a = perm_seg.min_left(i, pred_F) - 1
            c = perm_seg.max_right(i + 1, pred_F)

            # Drop the cut_min entry computed for the segment before the split.
            tmp = calc_cut_min(a)
            if tmp != -1:
                cut_min.set(tmp, e_min())

            seg_set(a, i)
            seg_set(i, c)

            # Refresh mex for both halves.
            sp_mex(a)
            sp_mex(i)

            # tmp2 - 1 is the nearest written ans_seg index below p[i]
            # (tmp2 == 0 when there is none); refresh that segment's mex too.
            tmp2 = ans_seg.min_left(p[i], pred_F)
            if tmp2:
                sp_mex(inv[tmp2 - 1])

            # Re-register cut_min for both halves.
            set_calc_min(a)
            set_calc_min(i)
            return True
        return False

    # res[0] is the hash of the single initial segment; ind is the next slot
    # to write (the assignment to 0 is immediately superseded by ind = 1).
    res = [0] * N
    ind = 0
    res[0] = ans_seg.all_prod().v
    ind = 1

    for n in range(N):
        if n == 0:
            # If the array does not start with value 0, cut in front of it
            # and record the result.
            if p[0] != 0:
                cut(inv[0])
                res[ind] = ans_seg.all_prod().v
                ind += 1
            continue

        # Nothing to do when n sits directly after n - 1 and n is not the
        # first element.
        if p[0] != n and inv[n - 1] + 1 == inv[n]:
            continue

        # back_cut: n - 1 is not the last element and is followed by a value > n.
        back_cut = 0
        if p[N - 1] != n - 1 and p[inv[n - 1] + 1] > n:
            back_cut = 1

        # front_cut: n is not the first element and is preceded by a value > n.
        front_cut = 0
        if p[0] != n and p[inv[n] - 1] > n:
            front_cut = 1

        if back_cut + front_cut == 0:
            continue

        if back_cut + front_cut == 2:
            # Both cuts apply: first emit one extra entry, obtained from a
            # temporary split that is undone right after it is recorded.
            # tmp: first value index whose mex entry is >= 0 (or N).
            target[0] = 0
            tmp = mex.max_right(0, pred_g)

            cm = cut_min.all_prod()
            if tmp > cm:
                # Candidate taken from cut_min: split at the position of value cm.
                tmp = cm
                m = inv[tmp]
                l = perm_seg.min_left(m, pred_F) - 1
                r = perm_seg.max_right(m, pred_F)
                seg_set(l, m)
                seg_set(m, r)
                res[ind] = ans_seg.all_prod().v
                ind += 1
                # Undo: restore [l, r) and clear the leaf at m.
                seg_set(l, r)
                seg_set(m, m)
            elif tmp == N:
                # No candidate at all: repeat the previous value.
                res[ind] = res[ind - 1]
                ind += 1
            else:
                # Candidate taken from mex: the segment starting with value
                # tmp is split at the position stored in mex.
                l = inv[tmp]
                r = perm_seg.max_right(l + 1, pred_F)
                m = mex.get(tmp)
                seg_set(l, m)
                seg_set(m, r)
                res[ind] = ans_seg.all_prod().v
                ind += 1
                # Undo: restore [l, r) and clear the leaf at m.
                seg_set(l, r)
                seg_set(m, m)

        # Apply the permanent cuts and record the resulting hash.
        if back_cut:
            cut(inv[n - 1] + 1)
        if front_cut:
            cut(inv[n])

        res[ind] = ans_seg.all_prod().v
        ind += 1

    # Pad the unwritten tail with the last recorded value.
    while ind < N:
        res[ind] = res[ind - 1]
        ind += 1

    return res


def main():
    """Read every test case from stdin and print one answer line per case.

    Input is whitespace-separated tokens: the number of cases, then for each
    case a length followed by that many integers.  Each answer is the list
    returned by ``solve``, printed space-separated.
    """
    read_token = iter(sys.stdin.buffer.read().split()).__next__
    case_count = int(read_token())
    lines = []
    for _ in range(case_count):
        length = int(read_token())
        perm = [int(read_token()) for _ in range(length)]
        lines.append(" ".join(str(value) for value in solve(perm)))
    print("\n".join(lines))


# Run the solver only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
0