結果

問題 No.3333 Consecutive Power Sum (Large)
コンテスト
ユーザー wasd314
提出日時 2025-10-28 21:46:38
言語 PyPy3 (7.3.15)
結果
RE  
実行時間 -
コード長 7,333 bytes
コンパイル時間 335 ms
コンパイル使用メモリ 82,908 KB
実行使用メモリ 170,076 KB
最終ジャッジ日時 2025-11-02 21:26:50
合計ジャッジ時間 17,291 ms
ジャッジサーバーID (参考情報) judge2 / judge3
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
sample AC * 1 RE * 2
other AC * 1 RE * 62
権限があれば一括ダウンロードができます

ソースコード

diff #

#! /usr/bin/env pypy

from math import gcd
from collections import Counter

def power_sum(e: int, l: int, r: int | None = None) -> int:
    """e-th power sum of [l, r) or [0, l)"""
    if r is not None:
        if 1 <= e <= 5:
            return power_sum(e, r) - power_sum(e, l)
        return sum(i**e for i in range(l, r))
    elif e == 1:
        return l * (l - 1) // 2
    elif e == 2:
        return l * (l - 1) * (2 * l - 1) // 6
    elif e == 3:
        return l * (l - 1) // 2 * l * (l - 1) // 2
    elif e == 4:
        return l * (l - 1) * (2 * l - 1) * (3 * l * l - 3 * l - 1) // 30
    elif e == 5:
        return (l - 1) * l // 2 * (l - 1) * l // 2 * (2 * (l - 1) * l - 1) // 3
    else:
        return 0

def min_true(l: int, r: int, pred):
    """
    Return the smallest i in [l, r) with pred(i), for a monotone pred
    (False ... False True ... True).

    pred(r) itself is never evaluated; when no index in [l, r) satisfies
    pred, r is returned.
    """
    if pred(l):
        return l
    lo, hi = l, r  # invariant: pred(lo) is False; hi is r or a known-True index
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if pred(mid):
            hi = mid
        else:
            lo = mid
    return hi

def test_miller_rabin(n: int, bases: list):
    nn = n - 1
    e = (nn & -nn).bit_length() - 1
    o = n >> e
    # assert n == (o << e | 1)
    for b in bases:
        x = pow(b, o, n)
        if x == 1 or x == n - 1:
            continue
        for _ in range(e - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break
        if x != n - 1:
            return False
    return True

def is_prime(n: int):
    """
    Decide whether n is prime.

    Trial division by the primes up to 37 settles every n below 41**2.
    Larger n go through test_miller_rabin with a base set chosen by the size
    of n (each set is a known deterministic witness set below its bound);
    beyond the last bound the first 15 primes are used as bases, which is a
    strong-probable-prime test.
    """
    if n < 2:
        return False
    for p in (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37):
        if n % p == 0:
            # divisible by p: prime only when n is p itself
            return n == p
    # a composite below 41**2 would have a prime factor <= 37
    if n < 41**2:
        return True
    # (exclusive upper bound on n, Miller-Rabin bases), smallest bound first
    schedule = (
        (2047, [2]),
        (90_80191, [31, 73]),
        (47591_23141, [2, 7, 61]),
        (112_20046_69633, [2, 13, 23, 16_62803]),
        (3_77057_95821_54547, [2, 8_80937, 25_70940, 6103_86380, 41307_85767]),
        (18446_74407_37095_51616, [2, 325, 9375, 28178, 450775, 9780504, 17952_65022]),
        (3186_65857_83403_11511_67461, [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37]),
        (33170_44064_67988_73859_61981, [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41]),
    )
    for bound, bases in schedule:
        if n < bound:
            return test_miller_rabin(n, bases)
    return test_miller_rabin(n, [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47])

def factorize(n: int) -> list[int]:
    """
    Return the prime factors of n as a list, with multiplicity (not sorted).

    n must be >= 1 (checked with assert); factorize(1) == [].
    The primes up to 37 are stripped by trial division; whatever remains is
    split recursively with Pollard's rho (batched gcd, Floyd cycle detection).
    """
    assert n >= 1
    if n == 1:
        return []
    if is_prime(n):
        return [n]
    ans = []
    for p in [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37]:
        while n % p == 0:
            n //= p
            ans.append(p)

    def dfs(nn: int):
        # Append the prime factors of nn to ans. nn is n with the small
        # primes removed, or a divisor of it, so it has no prime factor <= 37.
        if nn == 1:
            return
        if is_prime(nn):
            ans.append(nn)
            return

        # Number of rho steps whose |x - y| values are multiplied together
        # before one gcd: 2 ** (bit_length // 8), since `//` binds tighter than `<<`.
        factor_round = 1 << nn.bit_length() // 8

        def find_factor():
            # Pollard's rho with f(x) = x^2 + c, retrying with c = 1, 2, ...
            # until a divisor strictly between 1 and nn is found.
            c = 0

            def f(x):
                # Reads the current c from the enclosing scope, so each
                # increment of c below switches to a new polynomial.
                return (x * x + c) % nn

            while True:
                c += 1
                x, y = c, c
                d = 1

                # Check once per round (one gcd per factor_round steps)
                checkpoint = x, y
                while d == 1:
                    combined = 1
                    for _ in range(factor_round):
                        # Floyd's cycle detection: y advances two steps per x step
                        x, y = f(x), f(f(y))
                        combined = combined * abs(x - y) % nn
                    d = gcd(combined, nn)
                    if d == 1:
                        # Nothing found in this round
                        checkpoint = x, y
                    elif d != nn:
                        # Non-trivial divisor
                        return d

                # The batched product collapsed to a multiple of nn: replay the
                # round from the checkpoint one step at a time. If that also only
                # yields nn, fall through and retry with the next c.
                x, y = checkpoint
                d = 1
                while d == 1:
                    x, y = f(x), f(f(y))
                    d = gcd(abs(x - y), nn)
                    if d != 1 and d != nn:
                        return d
        d = find_factor()
        dfs(d)
        dfs(nn // d)
    dfs(n)
    return ans

def list_divisors(pe: Counter, pred):
    if not pred(1):
        return []
    ds = [1]
    for p, e in pe.items():
        for i in range(len(ds)):
            d = ds[i]
            for _ in range(e):
                d *= p
                if not pred(d):
                    break
                ds.append(d)
    return ds

def make_merged_freq(primes_n: list, d: int):
    """
    Prime-factor multiplicities (prime -> count) of n * d, where primes_n is
    the list of n's prime factors with multiplicity.
    """
    freq = Counter(primes_n)
    freq.update(factorize(d))  # adds 1 per prime factor of d
    return freq

def r13_pe(n: int, e: int, primes_n: list):
    """
    Enumerate the e = 1 solutions: runs l, l+1, ..., r of positive integers
    whose sum is n. The e argument is not used.

    A run of w terms starting at l sums to w * (2l + w - 1) / 2, so
    w * w2 == 2n with w2 = 2l + w - 1. A run exists exactly when w < w2 and
    w, w2 have opposite parity, and then l = (w2 - w + 1) / 2. Apart from
    factorization, the cost is the number of divisors of 2n below sqrt(2n).
    Returns (1, l, r) tuples sorted by l.
    """
    two_n = 2 * n

    def below_sqrt(d):
        return d * d < two_n

    runs = []
    for width in list_divisors(make_merged_freq(primes_n, 2), below_sqrt):
        cofactor = two_n // width
        if width <= cofactor and width % 2 != cofactor % 2:
            start = (cofactor - width + 1) // 2
            runs.append((1, start, start + width - 1))
    return sorted(runs, key=lambda t: t[1])

# enough_denom[e] is the multiplier D_e used by re1_bisect_div_2: a width w
# is only searched when it divides D_e * n (a necessary condition for w
# consecutive e-th powers to sum to n), so most widths are skipped cheaply.
# NOTE(review): the even-index entries appear to be the denominators of the
# Bernoulli numbers B_e, with 2 at every odd index; the table is finite, so
# e must be a valid index into it.
enough_denom = [1, 2, 6, 2, 30, 2, 42, 2, 30, 2, 66, 2, 2730, 2, 6, 2, 510, 2, 798, 2, 330, 2, 138, 2, 2730, 2, 6, 2, 870, 2, 14322, 2, 510, 2, 6, 2, 1919190, 2, 6, 2, 13530, 2, 1806, 2, 690, 2, 282, 2, 46410, 2, 66, 2, 1590, 2, 798, 2, 870, 2, 354, 2, 56786730, 2, 6, 2, 510, 2, 64722, 2, 30, 2, 4686, 2, 140100870, 2, 6, 2, 30, 2, 3318, 2, 230010, 2, 498, 2, 3404310, 2, 6, 2, 61410, 2, 272118, 2, 1410, 2, 6, 2, 4501770, 2, 6, 2]

def re1_bisect_div_2(n: int, e: int):
    """
    Enumerate the solutions for a (small) exponent e: every run
    l^e + (l+1)^e + ... + r^e == n with l >= 1, as (e, l, r) sorted by l.

    Only widths w that divide enough_denom[e] * n are tried, and for each such
    width the start l is found by binary search. Apart from factorization this
    takes Θ(W_e + d(D_e N, W_e) log N) time, where W_e is the width bound below.
    e must be a valid index into enough_denom.
    """
    dn = enough_denom[e] * n

    # w_e = smallest w with 1^e + ... + w^e > n. A run of w terms starting at
    # l >= 1 is at least that large, so every valid width is below w_e.
    w_e = 1
    while power_sum(e, 1, w_e + 1) <= n:
        w_e <<= 1
    w_e = min_true(0, w_e, lambda w: power_sum(e, 1, w + 1) > n)

    ans = []
    # Widths start at 1: a zero-width run is meaningless and `dn % 0` would raise.
    for w in range(1, w_e):
        if dn % w:
            continue

        def pred(l):
            return power_sum(e, l, l + w) >= n

        # Exponential search for an upper bound on the start, then binary
        # search. pred(0) is False because w < w_e, so the result is >= 1.
        r = 1
        while not pred(r):
            r <<= 1
        l = min_true(0, r, pred)
        if power_sum(e, l, l + w) == n:
            ans.append((e, l, l + w - 1))
    ans.sort(key=lambda t: t[1])
    return ans

def re0_two_pointer(n: int, e: int):
    """
    Enumerate the solutions for a (large) exponent e with a sliding window.

    Returns (e, l, r) for every 1 <= l <= r with l^e + ... + r^e == n, in
    increasing l. Runs in Θ(n^{1/e}) time: only bases whose e-th power is at
    most n are ever considered.
    """
    # powers[i] == i**e for every i with i**e <= n; index 0 is never read.
    powers = []
    for base in range(n + 1):
        value = base**e
        if value > n:
            break
        powers.append(value)
    limit = len(powers)

    found = []
    hi = 1      # exclusive right end of the window
    window = 0  # sum of powers[lo:hi]
    for lo in range(1, limit):
        while hi < limit and window + powers[hi] <= n:
            window += powers[hi]
            hi += 1
        # window <= n, and either hi == limit or adding powers[hi] overshoots n
        if window == n:
            found.append((e, lo, hi - 1))
        window -= powers[lo]
    return found

def solve(n: int):
    """
    Print every way to write n as l^e + (l+1)^e + ... + r^e.

    Output: the number of solutions, then one "e l r" line per solution,
    grouped by increasing e and, within each e, by increasing l.
    """
    primes_n = factorize(n)
    ans = r13_pe(n, 1, primes_n)
    # For n >= 2 a run contains some base b >= 2 with b^e <= n,
    # hence e <= log2(n), i.e. e < n.bit_length().
    for e in range(2, n.bit_length()):
        if e <= 3:
            # Small exponents: divisor-filtered widths plus binary search.
            # (The previous call named an undefined function `re1_bisect_div`
            # with an extra argument, raising NameError.)
            ans.extend(re1_bisect_div_2(n, e))
        else:
            # Large exponents: few bases qualify, so a sliding window is cheap.
            ans.extend(re0_two_pointer(n, e))
    print(len(ans))
    print("\n".join(f"{e} {l} {r}" for e, l, r in ans))

if __name__ == "__main__":
    # Input: a single integer n on one line.
    solve(int(input()))
0