結果

問題 No.2171 OR Assignment
ユーザー 👑 tute7627
提出日時 2022-11-21 23:00:34
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 3,379 ms / 3,500 ms
コード長 6,814 bytes
コンパイル時間 292 ms
コンパイル使用メモリ 82,260 KB
実行使用メモリ 130,024 KB
最終ジャッジ日時 2024-11-18 06:00:15
合計ジャッジ時間 44,814 ms
ジャッジサーバーID
(参考情報)
judge3 / judge4
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 69 ms
68,480 KB
testcase_01 AC 72 ms
68,608 KB
testcase_02 AC 72 ms
68,480 KB
testcase_03 AC 75 ms
68,352 KB
testcase_04 AC 73 ms
68,352 KB
testcase_05 AC 76 ms
68,352 KB
testcase_06 AC 75 ms
68,352 KB
testcase_07 AC 75 ms
68,480 KB
testcase_08 AC 73 ms
68,736 KB
testcase_09 AC 75 ms
68,992 KB
testcase_10 AC 71 ms
68,608 KB
testcase_11 AC 71 ms
68,480 KB
testcase_12 AC 798 ms
124,288 KB
testcase_13 AC 804 ms
124,416 KB
testcase_14 AC 795 ms
124,884 KB
testcase_15 AC 313 ms
126,028 KB
testcase_16 AC 406 ms
126,380 KB
testcase_17 AC 352 ms
124,928 KB
testcase_18 AC 1,748 ms
129,808 KB
testcase_19 AC 1,804 ms
125,972 KB
testcase_20 AC 2,740 ms
126,064 KB
testcase_21 AC 2,852 ms
126,280 KB
testcase_22 AC 3,105 ms
126,176 KB
testcase_23 AC 2,355 ms
129,448 KB
testcase_24 AC 3,085 ms
129,848 KB
testcase_25 AC 3,379 ms
128,592 KB
testcase_26 AC 3,255 ms
127,628 KB
testcase_27 AC 2,958 ms
127,628 KB
testcase_28 AC 3,010 ms
127,996 KB
testcase_29 AC 2,874 ms
129,724 KB
testcase_30 AC 2,396 ms
130,024 KB
testcase_31 AC 3,062 ms
129,816 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import typing

def _is_prime(n: int) -> bool:
    '''
    Reference:
    M. Forisek and J. Jancina,
    Fast Primality Testing for Integers That Fit into a Machine Word
    '''

    if n <= 1:
        return False
    if n == 2 or n == 7 or n == 61:
        return True
    if n % 2 == 0:
        return False

    d = n - 1
    while d % 2 == 0:
        d //= 2

    for a in (2, 7, 61):
        t = d
        y = pow(a, t, n)
        while t != n - 1 and y != 1 and y != n - 1:
            y = y * y % n
            t <<= 1
        if y != n - 1 and t % 2 == 0:
            return False
    return True


def _inv_gcd(a: int, b: int) -> typing.Tuple[int, int]:
    a %= b
    if a == 0:
        return (b, 0)

    # Contracts:
    # [1] s - m0 * a = 0 (mod b)
    # [2] t - m1 * a = 0 (mod b)
    # [3] s * |m1| + t * |m0| <= b
    s = b
    t = a
    m0 = 0
    m1 = 1

    while t:
        u = s // t
        s -= t * u
        m0 -= m1 * u  # |m1 * u| <= |m1| * s <= b

        # [3]:
        # (s - t * u) * |m1| + t * |m0 - m1 * u|
        # <= s * |m1| - t * u * |m1| + t * (|m0| + |m1| * u)
        # = s * |m1| + t * |m0| <= b

        s, t = t, s
        m0, m1 = m1, m0

    # by [3]: |m0| <= b/g
    # by g != b: |m0| < b/g
    if m0 < 0:
        m0 += b // s

    return (s, m0)


def _primitive_root(m: int) -> int:
    if m == 2:
        return 1
    if m == 167772161:
        return 3
    if m == 469762049:
        return 3
    if m == 754974721:
        return 11
    if m == 998244353:
        return 3

    divs = [2] + [0] * 19
    cnt = 1
    x = (m - 1) // 2
    while x % 2 == 0:
        x //= 2

    i = 3
    while i * i <= x:
        if x % i == 0:
            divs[cnt] = i
            cnt += 1
            while x % i == 0:
                x //= i
        i += 2

    if x > 1:
        divs[cnt] = x
        cnt += 1

    g = 2
    while True:
        for i in range(cnt):
            if pow(g, (m - 1) // divs[i], m) == 1:
                break
        else:
            return g
        g += 1

class ModContext:
    '''
    Context manager that selects the modulus used by newly built Modint
    values.

    Entering pushes this instance's modulus onto a stack shared by all
    instances; leaving pops the top entry.  get_mod() reports the top of
    that stack, so nested ``with`` blocks restore the outer modulus on
    exit.
    '''

    # Stack of active moduli.  It lives on the class, so every instance
    # appends to and pops from the same list.
    context: typing.List[int] = []

    def __init__(self, mod: int) -> None:
        '''Remember mod (must be >= 1); nothing is pushed until entry.'''
        assert mod >= 1

        self.mod = mod

    def __enter__(self) -> None:
        '''Push this modulus; the ``as`` target receives None.'''
        stack = self.context
        stack.append(self.mod)

    def __exit__(self, exc_type: typing.Any, exc_value: typing.Any,
                 traceback: typing.Any) -> None:
        '''Pop the top modulus; exceptions are not suppressed.'''
        self.context.pop()

    @classmethod
    def get_mod(cls) -> int:
        '''Return the innermost active modulus (IndexError if none).'''
        stack = cls.context
        return stack[-1]


class Modint:
    '''
    Integer modulo the modulus that is active in ModContext at
    construction time.

    The residue is kept in ``_v`` and is expected to satisfy
    0 <= _v < _mod.  Arithmetic accepts either another Modint or a plain
    int on the right-hand side.  Division is spelled ``//`` and means
    multiplication by the modular inverse.

    Because __eq__ is defined without __hash__, instances are not
    hashable.
    '''

    def __init__(self, v: int = 0) -> None:
        '''Capture the current modulus and store v reduced modulo it.'''
        self._mod = ModContext.get_mod()
        if v == 0:
            # Fast path for the very common zero value.
            self._v = 0
        else:
            self._v = v % self._mod

    def mod(self) -> int:
        '''Return the modulus this value was created under.'''
        return self._mod

    def val(self) -> int:
        '''Return the stored residue as a plain int.'''
        return self._v

    def __iadd__(self, rhs: typing.Union['Modint', int]) -> 'Modint':
        '''In-place addition; returns self.'''
        if isinstance(rhs, Modint):
            # Both operands are already reduced, so a single conditional
            # subtraction is enough.
            self._v += rhs._v
            if self._v >= self._mod:
                self._v -= self._mod
        else:
            # A plain int may be negative or arbitrarily large; reduce
            # fully so that _v stays inside [0, _mod).
            self._v = (self._v + rhs) % self._mod
        return self

    def __isub__(self, rhs: typing.Union['Modint', int]) -> 'Modint':
        '''In-place subtraction; returns self.'''
        if isinstance(rhs, Modint):
            self._v -= rhs._v
            if self._v < 0:
                self._v += self._mod
        else:
            # Same reasoning as in __iadd__: the int is not pre-reduced.
            self._v = (self._v - rhs) % self._mod
        return self

    def __imul__(self, rhs: typing.Union['Modint', int]) -> 'Modint':
        '''In-place multiplication; returns self.'''
        if isinstance(rhs, Modint):
            self._v = self._v * rhs._v % self._mod
        else:
            self._v = self._v * rhs % self._mod
        return self

    def __ifloordiv__(self, rhs: typing.Union['Modint', int]) -> 'Modint':
        '''
        In-place modular division; returns self.

        A Modint divisor must be invertible (inv() asserts this).  For a
        plain int divisor the coefficient returned by _inv_gcd is used
        without checking that the gcd is 1.
        '''
        if isinstance(rhs, Modint):
            inv = rhs.inv()._v
        else:
            inv = _inv_gcd(rhs, self._mod)[1]
        self._v = self._v * inv % self._mod
        return self

    def __pos__(self) -> 'Modint':
        '''Unary plus; returns this same object.'''
        return self

    def __neg__(self) -> 'Modint':
        '''Return the additive inverse as a new Modint.'''
        return Modint() - self

    def __pow__(self, n: int) -> 'Modint':
        '''Return self raised to the non-negative power n.'''
        assert 0 <= n

        return Modint(pow(self._v, n, self._mod))

    def inv(self) -> 'Modint':
        '''Return the multiplicative inverse; asserts that it exists.'''
        eg = _inv_gcd(self._v, self._mod)

        # An inverse exists only when the value is coprime to the modulus.
        assert eg[0] == 1

        return Modint(eg[1])

    def __add__(self, rhs: typing.Union['Modint', int]) -> 'Modint':
        '''Return self + rhs as a new Modint.'''
        if isinstance(rhs, Modint):
            result = self._v + rhs._v
            if result >= self._mod:
                result -= self._mod
            # Already reduced, so skip the % in the constructor.
            return raw(result)
        else:
            return Modint(self._v + rhs)

    def __sub__(self, rhs: typing.Union['Modint', int]) -> 'Modint':
        '''Return self - rhs as a new Modint.'''
        if isinstance(rhs, Modint):
            result = self._v - rhs._v
            if result < 0:
                result += self._mod
            # Already reduced, so skip the % in the constructor.
            return raw(result)
        else:
            return Modint(self._v - rhs)

    def __mul__(self, rhs: typing.Union['Modint', int]) -> 'Modint':
        '''Return self * rhs as a new Modint.'''
        if isinstance(rhs, Modint):
            return Modint(self._v * rhs._v)
        else:
            return Modint(self._v * rhs)

    def __floordiv__(self, rhs: typing.Union['Modint', int]) -> 'Modint':
        '''
        Return self times the modular inverse of rhs as a new Modint.

        See __ifloordiv__ for how the two divisor kinds are handled.
        '''
        if isinstance(rhs, Modint):
            inv = rhs.inv()._v
        else:
            inv = _inv_gcd(rhs, self._mod)[1]
        return Modint(self._v * inv)

    def __eq__(self, rhs: typing.Union['Modint', int]) -> bool:  # type: ignore
        '''Compare residues; a plain int is compared without reduction.'''
        if isinstance(rhs, Modint):
            return self._v == rhs._v
        else:
            return self._v == rhs

    def __ne__(self, rhs: typing.Union['Modint', int]) -> bool:  # type: ignore
        '''Negation of __eq__ with the same operand handling.'''
        if isinstance(rhs, Modint):
            return self._v != rhs._v
        else:
            return self._v != rhs


def raw(v: int) -> Modint:
    '''
    Wrap v in a Modint without reducing it.

    The value is written straight into the new object, bypassing the
    ``%`` done by the constructor, so the caller is responsible for
    passing a residue that is already in range for the current modulus.
    '''
    wrapped = Modint(0)
    wrapped._v = v
    return wrapped

class FenwickTree:
    '''
    Binary indexed tree over Modint values with point add and range sum.

    Positions are 0-based in the public methods; ranges are half-open.

    Reference: https://en.wikipedia.org/wiki/Fenwick_tree
    '''

    def __init__(self, n: int = 0) -> None:
        '''Create a tree of n positions, each starting at Modint(0).'''
        self._n = n
        # Each slot needs its own object because add() updates in place.
        self.data = [Modint(0) for _ in range(n)]

    def add(self, p: int, x: typing.Any) -> None:
        '''Add x to position p (0 <= p < n).'''
        assert 0 <= p < self._n

        size = self._n
        cells = self.data
        node = p
        while node < size:
            cells[node] += x
            # Setting the lowest clear bit moves to the next node whose
            # range covers position p (0-based form of p += p & -p).
            node |= node + 1

    def sum(self, left: int, right: int) -> typing.Any:
        '''Return the sum over positions left <= i < right.'''
        assert 0 <= left <= right <= self._n

        upper = self._sum(right)
        lower = self._sum(left)
        return upper - lower

    def _sum(self, r: int) -> typing.Any:
        '''Return the sum over the first r positions as a new Modint.'''
        total = Modint(0)
        count = r
        while count > 0:
            total += self.data[count - 1]
            # Clearing the lowest set bit steps to the preceding block.
            count &= count - 1

        return total

# Solver entry point: reads N and the N integers of A from standard input
# and prints a single count modulo 998244353.
with ModContext(998244353):

    N = int(input())
    A = list(map(int, input().split()))

    # dp is a Fenwick tree of Modint values indexed 0..N-1; it starts
    # with a single 1 at index 0.
    dp = FenwickTree(N)
    dp.add(0, Modint(1))

    # prev[j] is the largest index seen so far whose value has bit j set,
    # or -1 when no value has had that bit yet.
    prev = [-1] * 30

    for i in range(N):
        for j in range(30):
            if A[i] & (1 << j):
                prev[j] = i

        # Visit the distinct "last seen" indices from largest to smallest.
        # Each step writes at index j + 1 but reads only indices <= j, so
        # going downwards means no read sees a write from this same round.
        idx = sorted(set(prev), reverse=True)
        for j in idx:
            if j == i:
                continue
            dp.add(j + 1, dp.sum(0, j + 1))

    print(dp.sum(0, N).val())
0