結果

問題 No.2770 Coupon Optimization
ユーザー 👑 hahhohahho
提出日時 2024-05-11 00:19:08
言語 PyPy3 (7.3.15)
結果
AC  
実行時間 1,817 ms / 3,000 ms
コード長 4,375 bytes
コンパイル時間 218 ms
コンパイル使用メモリ 82,080 KB
実行使用メモリ 159,424 KB
最終ジャッジ日時 2024-05-11 00:19:33
合計ジャッジ時間 23,446 ms
ジャッジサーバーID (参考情報) judge4 / judge2
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 48 ms 59,380 KB
testcase_01 AC 45 ms 59,668 KB
testcase_02 AC 54 ms 59,140 KB
testcase_03 AC 1,734 ms 157,304 KB
testcase_04 AC 1,728 ms 155,176 KB
testcase_05 AC 84 ms 101,608 KB
testcase_06 AC 1,791 ms 144,172 KB
testcase_07 AC 1,751 ms 144,544 KB
testcase_08 AC 1,816 ms 158,536 KB
testcase_09 AC 498 ms 103,548 KB
testcase_10 AC 925 ms 110,400 KB
testcase_11 AC 902 ms 107,984 KB
testcase_12 AC 927 ms 119,304 KB
testcase_13 AC 908 ms 113,876 KB
testcase_14 AC 1,817 ms 158,128 KB
testcase_15 AC 1,787 ms 159,424 KB
testcase_16 AC 1,817 ms 158,584 KB
testcase_17 AC 1,799 ms 158,324 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

from cmath import rect, pi


def reverse_bits32(x: int):
    """
    Reverse the order of the low 32 bits of x.

    :param x: integer whose bits 0..31 are reversed; bits above 31 are dropped
    :return: integer in [0, 2**32) whose bit i equals bit 31 - i of x
    """
    # Swap neighbouring groups of width 1, 2, 4, 8 and finally 16 bits.
    # Shifting first and masking afterwards selects the same bits as masking
    # with the complementary pattern and then shifting.
    x = ((x & 0x55555555) << 1) | ((x >> 1) & 0x55555555)
    x = ((x & 0x33333333) << 2) | ((x >> 2) & 0x33333333)
    x = ((x & 0x0F0F0F0F) << 4) | ((x >> 4) & 0x0F0F0F0F)
    x = ((x & 0x00FF00FF) << 8) | ((x >> 8) & 0x00FF00FF)
    return ((x >> 16) & 0x0000FFFF) | ((x & 0x0000FFFF) << 16)

def prime_factors(n):
    """
    Generate the prime factorisation of n, with multiplicity.

    :param n: natural number (values below 2 produce nothing)
    :return: generator yielding the prime factors of n, smallest first
    """
    d = 2
    while d * d <= n:
        # Strip every copy of d before moving on; d can only divide here if it
        # is prime, because all smaller primes have already been removed.
        while n % d == 0:
            n //= d
            yield d
        d += 1
    # A remainder above 1 has no divisor up to its square root, so it is prime.
    if n > 1:
        yield n

def totient_factors(n):
    """
    Generate the prime factors (with multiplicity) of Euler's totient phi(n).

    Based on phi(p**k) = p**(k-1) * (p - 1).  The first occurrence of a prime
    p in the factorisation of n contributes the prime factors of p - 1.  Every
    further occurrence of p contributes p itself.

    :param n: natural number
    :return: generator over the prime factors of phi(n) (not necessarily sorted)
    """
    last_prime = None
    for p in prime_factors(n):
        if p != last_prime:
            last_prime = p
            yield from prime_factors(p - 1)
        else:
            yield p


def int_product(iterable):
    """
    Multiply together all items of an iterable.

    :param iterable: iterable of integers
    :return: the product of the items (1 for an empty iterable)
    """
    acc = 1
    for factor in iterable:
        acc *= factor
    return acc

def primitive_root(mod, phi_factors=None):
    """
    Return the smallest primitive root modulo ``mod``.

    A primitive root g is a unit modulo ``mod`` whose multiplicative order is
    phi(mod).  Equivalently, gcd(g, mod) == 1 and g**(phi // p) != 1 (mod mod)
    for every prime p dividing phi(mod).

    :param mod: modulus (integer >= 2)
    :param phi_factors: optional iterable of the prime factors of phi(mod),
        with multiplicity; computed with ``totient_factors`` when omitted
    :return: the smallest primitive root modulo ``mod``
    :raises ValueError: if no primitive root modulo ``mod`` exists
    """
    from math import gcd, prod

    if phi_factors is None:
        phi_factors = totient_factors(mod)
    # Materialise once.  A generator argument would otherwise be exhausted by
    # the product and leave an empty set of primes behind.
    phi_factors = tuple(phi_factors)
    phi = prod(phi_factors)
    exponents = [phi // p for p in set(phi_factors)]
    # Start at 1 so that mod == 2 (phi == 1, no exponents) returns 1.  For any
    # mod >= 3, phi >= 2, so 1 always fails the order test below.
    for g in range(1, mod):
        # Non-units have no multiplicative order and can never be roots.
        # Without this check, e.g. mod 8 would wrongly return 2.
        if gcd(g, mod) != 1:
            continue
        if all(pow(g, e, mod) != 1 for e in exponents):
            return g
    raise ValueError(f'There is no primitive root for modulo {mod}')


def modinv(x: int, mod: int) -> int:
    """
    Inverse of x in Z/(mod Z), via the extended Euclidean algorithm.

    :param x: integer
    :param mod: integer modulus
    :return: y satisfying x * y % mod == 1
    :raises ValueError: if gcd(x, mod) != 1
    """
    # Invariant: old_coef * x == old_rem (mod mod), and likewise coef * x == rem.
    old_rem, rem = x, mod
    old_coef, coef = 1, 0
    while rem:
        quotient, remainder = divmod(old_rem, rem)
        old_rem, rem = rem, remainder
        old_coef, coef = coef, old_coef - quotient * coef
    # old_rem now holds gcd(x, mod); only a gcd of 1 admits an inverse.
    if old_rem != 1:
        raise ValueError("base is not invertible for the given modulus")
    return old_coef + mod if old_coef < 0 else old_coef


def modpow(x, k, mod):
    """
    x raised to the k-th power in Z/(mod Z), by binary exponentiation.

    :param x: integer base
    :param k: integer exponent; a negative k raises the modular inverse of x
    :param mod: integer modulus
    :return: x ** k % mod (returns 1 when k == 0, without reducing by mod)
    :raises ValueError: (from modinv) if k < 0 and x is not invertible
    """
    if k < 0:
        x, k = modinv(x, mod), -k
    result = 1
    base = x
    while k:
        k, low_bit = divmod(k, 2)
        if low_bit:
            result = result * base % mod
        base = base * base % mod
    return result

def ntt(a: list, mod: int, inverse: bool = False) -> list:
    """
    In-place number-theoretic transform of ``a`` modulo ``mod``.

    The list is zero-extended to m = 2**n, the smallest power of two that is
    >= len(a).  It is then transformed with an iterative radix-2
    decimation-in-time scheme: a bit-reversal permutation followed by
    butterflies.  The twiddle factors are powers of a primitive m-th root of
    unity derived from the smallest primitive root of ``mod``.  With
    ``inverse=True`` the inverse of that root is used and the result is
    multiplied by the modular inverse of m.

    :param a: non-empty list of numbers; mutated in place.  If a[0] is not an
        int, every entry is converted with int().  The list is then
        zero-padded to length m and overwritten with the transform.
    :param mod: modulus.  NOTE(review): presumably an NTT-friendly prime
        (e.g. 998244353); the code only requires that 2**n divides phi(mod)
        and that a primitive root exists -- confirm callers' expectations.
    :param inverse: compute the inverse transform instead of the forward one
    :return: the same list object ``a``
    :raises ValueError: if 2**n does not divide phi(mod), or if
        ``primitive_root`` finds no primitive root for ``mod``
    :raises IndexError: if ``a`` is empty (a[0] is read unconditionally)
    """
    # Only the first element's type is inspected; if it is not an int, every
    # element is coerced with int().
    if type(a[0]) is not int:
        for i, v in enumerate(a):
            a[i] = int(v)
    # Transform length is m = 2**n.
    n = (len(a) - 1).bit_length()
    # Split phi(mod) = 2**d2 * r with r odd.  phi(mod) is re-factorised by
    # trial division on every call.
    d2 = 0
    r = 1
    phi_factors = tuple(totient_factors(mod))
    for p in phi_factors:
        if p == 2:
            d2 += 1
        else:
            r *= p
    # A primitive 2**n-th root of unity exists only if 2**n divides phi(mod).
    if d2 < n:
        raise ValueError(f'Given array is too long: modulo {mod} only support array length up to {2 ** d2}')

    pr = primitive_root(mod, phi_factors)
    if inverse:
        pr = modinv(pr, mod)
    # pr has order phi = 2**d2 * r, so pr**(r * 2**(d2 - n)) has order 2**n.
    # Repeated squaring yields roots of order 2**(n-1), ..., 2.  After the
    # reversal (for n >= 1), pows[k] has order 2**(k+1): it is the twiddle
    # base for butterflies whose halves have width 2**k.
    pows = [modpow(pr, r * 2 ** (d2 - n), mod)]
    for _ in range(n - 1):
        pows.append(pows[-1] ** 2 % mod)
    pows = tuple(reversed(pows))

    m = 2 ** n
    a.extend(0 for _ in  range(m - len(a)))

    # Bit-reversal permutation: reversing all 32 bits and shifting right by
    # (32 - n) reverses just the low n bits of i.  Assumes n <= 32.
    shift = 32 - n
    for i in range(m):
        j = reverse_bits32(i) >> shift
        if i < j:
            a[i], a[j] = a[j], a[i]

    # Butterflies in depth-first order.  When index i is reached and its low t
    # bits are all ones, the aligned block of size 2**t ending at i is
    # complete.  The stages of half-width 1, 2, ..., 2**(t-1) are then run on
    # that block immediately.  Inside the loop, i is reused as a cursor; the
    # for statement rebinds it on the next iteration.
    for i in range(m):
        b = 1
        for w1 in pows:
            # Bit b of i is zero: the enclosing block of size 2*b is not
            # complete yet, so stop here.
            if not i & b:
                break
            # Move the cursor to the start of the lower half of that block.
            i ^= b
            w = 1
            # Butterfly a[i] with its partner a[i | b] for each of the b
            # offsets, using twiddle w = w1**offset.
            while not i & b:
                j = i | b
                s = a[i]
                t = a[j] * w
                a[i] = (s + t) % mod
                a[j] = (s - t) % mod
                w = (w * w1) % mod
                i += 1
            # The cursor stopped at the upper half; clearing bit b puts it
            # back at the start of the completed 2*b block for the next level.
            i ^= b
            b <<= 1

    # The inverse transform also divides by the length m (mod mod).
    if inverse:
        c = modinv(m, mod)
        for i, v in enumerate(a):
            a[i] = (v * c) % mod
    return a


import sys

# Read the input: n and m on the first line, then the list aa (n values) and
# the list bb (m values).  m itself is not needed afterwards, because the
# code works with len(bb).
n, m = map(int, input().split())
aa = list(map(int, input().split()))
bb = list(map(int, input().split()))

# Scale so that every product is an integer: a // 100 times (100 - b).
# NOTE(review): this assumes each a is a multiple of 100 and each b is a
# percentage discount -- confirm against the problem statement.
aa = [a // 100 for a in aa]
bb = [100 - b for b in bb]

aa.sort()
bb.sort()
# Make bb exactly as long as aa.  Missing entries act as multiplier 100 (no
# discount).  Extra entries are dropped, keeping the len(aa) smallest
# multipliers.
if len(bb) < len(aa):
    bb += [100] * (len(aa) - len(bb))
elif len(bb) > len(aa):
    bb = bb[:len(aa)]

# Zero-pad both lists to twice their length so that the cyclic NTT
# convolution does not wrap around into the coefficients read back below.
aa += [0] * len(aa)
bb += [0] * len(bb)

# cc[k] = sum_{i+j=k} aa[i] * bb[j].  With both lists sorted ascending, this
# pairs aa[0..k] with bb[k..0] in opposite order.  The convolution is
# computed modulo two NTT primes and then recombined with the CRT.
p0 = 924844033
fa = ntt(aa[:], p0)
fb = ntt(bb[:], p0)
fc = [(u * v) % p0 for u, v in zip(fa, fb)]
cc0 = ntt(fc, p0, inverse=True)
p1 = 998244353
fa = ntt(aa, p1)
fb = ntt(bb, p1)
fc = [(u * v) % p1 for u, v in zip(fa, fb)]
cc1 = ntt(fc, p1, inverse=True)

# CRT: from x == c0 (mod p0) and x == c1 (mod p1), recover x mod p0 * p1.
# The result is exact only if every true coefficient is below p0 * p1.
k0 = p1 * modinv(p1, p0)
k1 = p0 * modinv(p0, p1)
kk = p0 * p1
# Emit all n answers in a single write instead of one print() call per line.
sys.stdout.write(''.join(f'{(c0 * k0 + c1 * k1) % kk}\n'
                         for c0, c1 in zip(cc0[:n], cc1[:n])))