結果

問題 No.880 Yet Another Segment Tree Problem
ユーザー あじゃじゃ
提出日時 2025-03-22 13:42:09
言語 PyPy3 (7.3.15)
結果
WA  
実行時間 -
コード長 5,409 bytes
コンパイル時間 419 ms
コンパイル使用メモリ 82,364 KB
実行使用メモリ 282,092 KB
最終ジャッジ日時 2025-03-22 13:43:58
合計ジャッジ時間 11,468 ms
ジャッジサーバーID (参考情報) judge4 / judge3
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 5 WA * 24 TLE * 8
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
import math
from typing import List

# Saturation cap for the running LCM kept in each node; op() stores any LCM
# that reaches this value as BINF itself.
BINF = 2 ** 30

class S:
    """Summary of a contiguous run of array elements (one segment-tree node).

    Fields:
      max      -- largest value in the run
      lcm      -- LCM of the values (op() saturates it at BINF)
      sz       -- number of elements the summary stands for
      sum      -- sum of the values
      fail     -- set by mapping() when a gcd cannot be resolved from this
                  summary alone
      all_same -- True when every element of the run equals ``max``

    ``S(x, sz)`` describes ``sz`` copies of ``x``.  The default call ``S()``
    (exactly x == 0 and sz == 1) instead builds the neutral summary: max 0,
    lcm 1, sum 0, all_same False.
    NOTE(review): that neutral summary reports sz == 1, not 0.
    """
    __slots__ = ('max', 'lcm', 'sz', 'sum', 'fail', 'all_same')

    def __init__(self, x: int = 0, sz: int = 1):
        is_identity = x == 0 and sz == 1
        self.fail = False
        self.all_same = not is_identity
        if is_identity:
            self.max, self.lcm, self.sz, self.sum = 0, 1, 1, 0
        else:
            self.max, self.lcm, self.sz, self.sum = x, x, sz, x * sz

class F:
    """Lazy tag: an optional assignment followed by an optional gcd.

    ``reset`` -- value assigned to every element first (0 means no assignment)
    ``dogcd`` -- value every element is gcd-ed with afterwards (0 means no gcd)
    """
    __slots__ = ('dogcd', 'reset')

    def __init__(self, dogcd: int = 0, reset: int = 0):
        self.reset = reset
        self.dogcd = dogcd

    @staticmethod
    def gcd(g: int) -> 'F':
        """Build a tag that replaces every element a with gcd(a, g)."""
        return F(dogcd=g)

    @staticmethod
    def update(a: int) -> 'F':
        """Build a tag that assigns ``a`` to every element."""
        return F(reset=a)

def e() -> S:
    """Return a fresh neutral node summary (max 0, sum 0, lcm 1)."""
    return S(0, 1)

def id_F() -> F:
    """Return a fresh no-op lazy tag (no assignment, no gcd)."""
    return F(0, 0)

def composition(fnew: F, fold: F) -> F:
    """Return the tag equivalent to applying ``fold`` first and then ``fnew``.

    An assignment in ``fnew`` overrides everything ``fold`` would do, so
    ``fnew`` is returned unchanged.  Otherwise ``fold``'s assignment is kept
    and the gcd arguments merge: gcd(gcd(a, g1), g2) == gcd(a, gcd(g1, g2)),
    and a gcd argument of 0 means "no gcd" because gcd(g, 0) == g.
    """
    if not fnew.reset:
        merged_gcd = math.gcd(fnew.dogcd, fold.dogcd)
        return F(merged_gcd, fold.reset)
    return fnew

def mapping(f: F, x: S) -> S:
    """Apply tag ``f`` to node summary ``x`` and return the new summary.

    The assignment (``f.reset``) is applied first, then the gcd
    (``f.dogcd``).  A gcd is resolved directly only in two cases:
      - the run is uniform, so the single value is gcd-ed; or
      - ``f.dogcd`` is a multiple of the run's LCM, so nothing changes.
    Otherwise the summary is flagged ``fail`` *in place* and returned, and the
    caller must push the tag further down.  A summary that is already flagged
    is returned untouched.
    """
    if x.fail:
        return x
    node = S(f.reset, x.sz) if f.reset else x
    g = f.dogcd
    if not g:
        return node
    if node.all_same:
        return S(math.gcd(g, node.max), node.sz)
    # Saturated LCM is unknown, and a non-multiple may change some element.
    if node.lcm == BINF or g % node.lcm:
        node.fail = True
    return node

def op(l: S, r: S) -> S:
    """Combine the summaries of two adjacent runs, ``l`` immediately left of ``r``.

    A side with ``sz == 0`` is treated as empty and the other side is returned
    as-is.
    NOTE(review): S() reports sz == 1, so this short-circuit does not fire for
    the e() summaries used in this file.
    """
    if not r.sz:
        return l
    if not l.sz:
        return r
    merged = S()
    merged.sz = l.sz + r.sz
    merged.sum = l.sum + r.sum
    merged.max = max(l.max, r.max)
    # LCM saturates at BINF so the numbers stay bounded.
    if l.lcm >= BINF or r.lcm >= BINF:
        merged.lcm = BINF
    else:
        g = math.gcd(l.lcm, r.lcm)
        joint = l.lcm * r.lcm // g if g else 0
        merged.lcm = joint if joint < BINF else BINF
    merged.all_same = l.all_same and r.all_same and l.max == r.max
    merged.fail = l.fail or r.fail
    return merged

class SegtreeBeats:
    """Lazy segment tree with a "beats"-style fallback for range-gcd updates.

    Node summaries are ``S`` objects combined with ``op``.  Pending updates
    are ``F`` tags, applied with ``mapping`` and merged with ``composition``.
    When ``mapping`` flags a node ``fail`` (the gcd cannot be resolved from
    its summary alone), ``_all_apply`` pushes the tag to the children and
    rebuilds the node, recursing until every part can be resolved.

    The layout mirrors the AtCoder Library ``lazy_segtree``:
      - leaves sit at ``size .. size + n - 1``;
      - node ``k`` has children ``2k`` and ``2k + 1``;
      - all ranges are half-open ``[l, r)``.
    """

    def __init__(self, arr: List[int]):
        self._n = len(arr)
        self.log = 0
        self.size = 1
        while self.size < self._n:
            self.size *= 2
            self.log += 1
        # Every slot gets its own e() object; padding leaves (index >= n) keep it.
        self.data = [e() for _ in range(2 * self.size)]
        self.lazy = [id_F() for _ in range(self.size)]
        for i, value in enumerate(arr):
            self.data[self.size + i] = S(value, 1)
        for k in range(self.size - 1, 0, -1):
            self._update(k)

    def _update(self, k: int):
        """Recompute node ``k`` from its two children."""
        self.data[k] = op(self.data[2 * k], self.data[2 * k + 1])

    def _all_apply(self, k: int, f: F):
        """Apply tag ``f`` to the whole subtree rooted at node ``k``."""
        self.data[k] = mapping(f, self.data[k])
        if k < self.size:
            self.lazy[k] = composition(f, self.lazy[k])
            if self.data[k].fail:
                # The summary could not absorb the gcd: hand the combined tag
                # to the children (which may recurse further) and rebuild k.
                self._push(k)
                self._update(k)

    def _push(self, k: int):
        """Propagate node ``k``'s pending tag to both children, then clear it."""
        self._all_apply(2 * k, self.lazy[k])
        self._all_apply(2 * k + 1, self.lazy[k])
        self.lazy[k] = id_F()

    def prod(self, l: int, r: int) -> S:
        """Return the combined summary of positions ``[l, r)``.

        Only ``max`` and ``sum`` of the result are meaningful to callers.
        """
        if l == r:
            return e()
        l += self.size
        r += self.size
        # Push pending tags down along both boundary paths before reading.
        for i in range(self.log, 0, -1):
            if ((l >> i) << i) != l:
                self._push(l >> i)
            if ((r >> i) << i) != r:
                self._push((r - 1) >> i)
        sml = e()
        smr = e()
        while l < r:
            if l & 1:
                sml = op(sml, self.data[l])
                l += 1
            if r & 1:
                r -= 1
                smr = op(self.data[r], smr)
            l //= 2
            r //= 2
        return op(sml, smr)

    def apply_point(self, p: int, f: F):
        """Apply tag ``f`` to the single position ``p``."""
        p += self.size
        for i in range(self.log, 0, -1):
            self._push(p >> i)
        self.data[p] = mapping(f, self.data[p])
        for i in range(1, self.log + 1):
            self._update(p >> i)

    def apply_range(self, l: int, r: int, f: F):
        """Apply tag ``f`` to every position in ``[l, r)``."""
        if l == r:
            return
        l0 = l + self.size
        r0 = r + self.size
        # Push pending tags down along both boundary paths first.
        for i in range(self.log, 0, -1):
            if ((l0 >> i) << i) != l0:
                self._push(l0 >> i)
            if ((r0 >> i) << i) != r0:
                self._push((r0 - 1) >> i)
        l1, r1 = l0, r0
        while l0 < r0:
            if l0 & 1:
                self._all_apply(l0, f)
                l0 += 1
            if r0 & 1:
                r0 -= 1
                self._all_apply(r0, f)
            l0 //= 2
            r0 //= 2
        # Rebuild only the ancestors that a boundary cuts.  For the right
        # boundary, the test is whether r1 itself is aligned, the same test
        # the push loop uses.  Testing r1 - 1 instead did two wrong things:
        #   - rebuilt fully covered nodes from un-pushed children, discarding
        #     the tag _all_apply had just stored on them;
        #   - skipped ancestors that really were cut.
        for i in range(1, self.log + 1):
            if ((l1 >> i) << i) != l1:
                self._update(l1 >> i)
            if ((r1 >> i) << i) != r1:
                self._update((r1 - 1) >> i)

def main():
    """Read N, Q, the array and Q queries from stdin, and answer them.

    Query kinds (l and r are 1-based and inclusive in the input):
      1 l r x -- assign x on [l, r]
      2 l r x -- replace each a_i with gcd(a_i, x) on [l, r]
      3 l r   -- print the maximum on [l, r]
      4 l r   -- print the sum on [l, r]
    All answers are written at once, separated by newlines.
    """
    tokens = iter(sys.stdin.read().split())

    def nxt():
        return int(next(tokens))

    n = nxt()
    query_count = nxt()
    values = [nxt() for _ in range(n)]
    seg = SegtreeBeats(values)
    answers = []
    for _ in range(query_count):
        kind = nxt()
        left = nxt() - 1  # convert to a 0-based, half-open [left, right)
        right = nxt()
        if kind > 2:
            summary = seg.prod(left, right)
            if kind == 3:
                answers.append(str(summary.max))
            elif kind == 4:
                answers.append(str(summary.sum))
            continue
        x = nxt()
        tag = F.update(x) if kind == 1 else F.gcd(x)
        seg.apply_range(left, right, tag)
    sys.stdout.write("\n".join(answers))

# Run the solver only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
0