結果

問題 No.3579 区間積逆像
コンテスト
ユーザー tassei903
提出日時 2026-07-03 22:21:57
言語 PyPy3
(7.3.17)
コンパイル:
pypy3 -mpy_compile _filename_
実行:
pypy3 _filename_
結果
WA  
実行時間 -
コード長 6,791 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 143 ms
コンパイル使用メモリ 85,760 KB
実行使用メモリ 86,016 KB
最終ジャッジ日時 2026-07-03 22:22:06
合計ジャッジ時間 5,268 ms
ジャッジサーバーID
(参考情報)
judge2_0 / judge1_0
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 12 WA * 2 TLE * 1 -- * 15
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

import sys
input = lambda :sys.stdin.readline()[:-1]
ni = lambda :int(input())
na = lambda :list(map(int,input().split()))
yes = lambda :print("yes");Yes = lambda :print("Yes");YES = lambda : print("YES")
no = lambda :print("no");No = lambda :print("No");NO = lambda : print("NO")
#######################################################################
"""
x = (n - 1) / 2

2 * x == (n - 1) (mod b)

"""

def divisors(n):
    lower_divisors , upper_divisors = [], []
    i = 1
    while i*i <= n:
        if n % i == 0:
            lower_divisors.append(i)
            if i != n // i:
                upper_divisors.append(n//i)
        i += 1
    return lower_divisors + upper_divisors[::-1]

import sys
input = sys.stdin.readline

def gcd(x, y):
    while y:
        x, y = y, x % y
    return x
def solveDiscreteLogarithm(x, y, m):
    if y >= m or y < 0:
        return -1
    if x == 0:
        if m == 1:
            return 0
        if y == 1:
            return 0
        if y == 0:
            return 1
        return -1
    p = 3
    tmp = x - 1
    cnt = 0
    primes = []
    counts = []
    ps = 0
    while tmp & 1:
        tmp >>= 1
        cnt += 1
    if cnt:
        primes.append(2)
        counts.append(cnt)
        ps += 1
    tmp += 1
    while tmp != 1:
        cnt = 0
        while tmp % p == 0:
            tmp //= p
            cnt += 1
        if cnt:
            primes.append(p)
            counts.append(cnt)
            ps += 1
        p += 2
        if tmp != 1 and p * p > x:
            primes.append(tmp)
            counts.append(1)
            ps += 1
            break
    tail = 0
    mp = m
    for i in range(ps):
        f = 0
        while mp % primes[i] == 0:
            mp //= primes[i]
            f += 1
        if tail < (f + counts[i] - 1) // counts[i]:
            tail = (f + counts[i] - 1) // counts[i]
    z = 1
    for i in range(tail):
        if z == y:
            return i
        z = z * x % m
    if y % gcd(z, m):
        return -1
    p = 3
    u = mp
    tmp = mp - 1
    if tmp & 1:
        u >>= 1
        while tmp & 1:
            tmp >>= 1
    tmp += 1
    while tmp != 1:
        if tmp % p == 0:
            u //= p
            u *= p - 1
            while tmp % p == 0:
                tmp //= p
        p += 2
        if tmp != 1 and p * p > mp:
            u //= tmp
            u *= tmp - 1
            break
    p = 1
    loop = u
    while p * p <= u:
        if u % p == 0:
            if z * pow(x, p, m) % m == z:
                loop = p
                break
            ip = u // p
            if z * pow(x, ip, m) % m == z:
                loop = ip
        p += 1
    l, r = 0, loop+1
    sq = (loop+1) >> 1
    while r - l > 1:
        if sq * sq <= loop:
            l = sq
        else:
            r = sq
        sq = (l + r) >> 1
    if sq * sq < loop:
        sq += 1
    e = pow(x, loop, m)
    b = pow(pow(x, loop-1, m), sq, m)
    d = {}
    f = z
    for i in range(sq):
        d[f] = i
        f = f * x % m
    g = y
    for i in range(sq):
        if g in d:
            return i*sq+d[g]+tail
        g = g * b % m
    return -1

# https://judge.yosupo.jp/submission/265090

from math import isqrt
from random import randint

def gcd(x, y):
    """ x < y """
    while y:
        x, y = y, x%y
    return x

def is_prime(num):
    """ 1 <= x < 1<<64 """
    if num < 4: return num > 1
    if not num&1: return False
    
    d, s = num-1, 0
    while not d&1:
        d >>= 1
        s += 1
        
    tests = (2,7,61) if num < 4759123141 else (2,325,9375,28178,450775,9780504,1795265022)
        
    for test in tests:
        if test >= num: return True
        t = pow(test, d, num)
        if 1 < t < num-1:
            for _ in range(s-1):
                t = t*t%num
                if t == num-1: break
            else:
                return False
    return True

def find_prime(n):
    b = n.bit_length() - 1
    b = (b >> 2) << 2
    m = (1 << (b >> 3)) << 1
    while True:
        c = randint(1, n - 1)
        y = 0
        g = q = r = 1
        while g == 1:
            x = y
            for _ in range(r):
                y = (y * y + c) % n
            k = 0
            while k < r and g == 1:
                ys = y
                for _ in range(min(m, r - k)):
                    y = (y * y + c) % n
                    q = q * abs(x - y) % n
                g = gcd(q, n)
                k += m
            r <<= 1
        if g == n:
            g = 1
            y = ys
            while g == 1:
                y = (y * y + c) % n
                g = gcd(abs(x - y), n)
        if g == n:
            continue
        if is_prime(g):
            return g
        elif is_prime(n // g):
            return n // g
        else:
            n = g

def primefact(n):
    result = dict()
    for p in range(2, 500):
        if p * p > n:
            break
        c = 0
        while n%p == 0:
            n //= p
            c += 1
        if c:
            result[p] = c
    
    while n > 1 and not is_prime(n):
        p = find_prime(n)
        c = 0
        while n % p == 0:
            n //= p
            c += 1
        result[p] = c
    if n > 1: result[n] = 1
    return result

def primitive_root(p):
    """ p : prime """
    if p == 2: return 1
    
    r = p - 1
    tests = []
    for q in range(2, 500):
        if q * q > r:
            break
        if r % q == 0:
            while r % q == 0:
                r //= q
            tests.append((p - 1) // q)
    
    while r > 1 and not is_prime(r):
        q = find_prime(r)
        while r % q == 0:
            r //= q
        tests.append((p - 1) // q)
    if r > 1: tests.append((p - 1) // r)
    
    res = 2
    while True:
        for test in tests:
            if pow(res, test, p) == 1:
                break
        else:
            return res
        res = randint(3, p - 2)
    
from collections import defaultdict
def calc(a, c, p):
    n = len(a)
    x = 0
    ans = 0
    d = defaultdict(int)
    d[x] += 1
    for i in range(n):
        x = (x + a[i]) % (p - 1)
        ans += d[(c - x) % (p - 1)]
        d[x] += 1
    return ans

n, p, c = na()
a = na()

if c == 0:
    f = []
    for i in range(n):
        if a[i] == 0:
            f.append(i)
    ans = 0
    last = -1
    for i in f:
        ans += (i - last) * (n - i)
        last = i
    print(ans)
else:
    r = primitive_root(p)
    
    a = [solveDiscreteLogarithm(r, a[i], p) if a[i] != 0 else -1 for i in range(n)]
    c = solveDiscreteLogarithm(r, c, p)
    i = 0
    ans = 0
    while i < n:
        if a[i] == -1:
            i += 1
            continue
        j = i
        while j < n and a[j] != -1:
            j += 1
        
        
        ans += calc(a[i:j], c, p)
        i = j
    print(ans)
        
0