結果

問題 No.3525 擬奇平方数
コンテスト
ユーザー lif4635
提出日時 2026-05-01 22:32:36
言語 PyPy3
(7.3.17)
コンパイル:
pypy3 -mpy_compile _filename_
実行:
pypy3 _filename_
結果
RE  
実行時間 -
コード長 8,065 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 181 ms
コンパイル使用メモリ 85,120 KB
実行使用メモリ 102,560 KB
平均クエリ数 3586.00
最終ジャッジ日時 2026-05-01 22:32:56
合計ジャッジ時間 19,948 ms
ジャッジサーバーID
(参考情報)
judge1_0 / judge2_0
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other RE * 60
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

# input
import sys
# input = sys.stdin.readline
II = lambda : int(input())
MI = lambda : map(int, input().split())
LI = lambda : [int(a) for a in input().split()]
SI = lambda : input().rstrip()
LLI = lambda n : [[int(a) for a in input().split()] for _ in range(n)]
LSI = lambda n : [input().rstrip() for _ in range(n)]
MI_1 = lambda : map(lambda x:int(x)-1, input().split())
LI_1 = lambda : [int(a)-1 for a in input().split()]

mod = 998244353
inf = 1001001001001001001
ordalp = lambda s : ord(s)-65 if s.isupper() else ord(s)-97
ordallalp = lambda s : ord(s)-39 if s.isupper() else ord(s)-97
yes = lambda : print("Yes")
no = lambda : print("No")
yn = lambda flag : print("Yes" if flag else "No")

prinf = lambda ans : print(ans if ans < 1000001001001001001 else -1)
alplow = "abcdefghijklmnopqrstuvwxyz"
alpup = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
alpall = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
URDL = {'U':(-1,0), 'R':(0,1), 'D':(1,0), 'L':(0,-1)}
DIR_4 = [[-1,0],[0,1],[1,0],[0,-1]]
DIR_8 = [[-1,0],[-1,1],[0,1],[1,1],[1,0],[1,-1],[0,-1],[-1,-1]]
DIR_BISHOP = [[-1,1],[1,1],[1,-1],[-1,-1]]
prime60 = [2,3,5,7,11,13,17,19,23,29,31,37,41,43,47,53,59]
sys.set_int_max_str_digits(0)
# sys.setrecursionlimit(10**6)
# import pypyjit
# pypyjit.set_param('max_unroll_recursion=-1')

from collections import defaultdict,deque
from heapq import heappop,heappush
from bisect import bisect_left,bisect_right
DD = defaultdict
BSL = bisect_left
BSR = bisect_right


base = [15591, 2018, 166, 7429, 8064, 16045, 10503, 4399, 1949, 1295, 2776, 3620,
    560, 3128, 5212, 2657, 2300, 2021, 4652, 1471, 9336, 4018, 2398, 20462,
    10277, 8028, 2213, 6219, 620, 3763, 4852, 5012, 3185, 1333, 6227, 5298,
    1074, 2391, 5113, 7061, 803, 1269, 3875, 422, 751, 580, 4729, 10239, 746,
    2951, 556, 2206, 3778, 481, 1522, 3476, 481, 2487, 3266, 5633, 488, 3373,
    6441, 3344, 17, 15105, 1490, 4154, 2036, 1882, 1813, 467, 3307, 14042,
    6371, 658, 1005, 903, 737, 1887, 7447, 1888, 2848, 1784, 7559, 3400, 951,
    13969, 4304, 177, 41, 19875, 3110, 13221, 8726, 571, 7043, 6943, 1199, 352,
    6435, 165, 1169, 3315, 978, 233, 3003, 2562, 2994, 10587, 10030, 2377,
    1902, 5354, 4447, 1555, 263, 27027, 2283, 305, 669, 1912, 601, 6186, 429,
    1930, 14873, 1784, 1661, 524, 3577, 236, 2360, 6146, 2850, 55637, 1753,
    4178, 8466, 222, 2579, 2743, 2031, 2226, 2276, 374, 2132, 813, 23788, 1610,
    4422, 5159, 1725, 3597, 3366, 14336, 579, 165, 1375, 10018, 12616, 9816,
    1371, 536, 1867, 10864, 857, 2206, 5788, 434, 8085, 17618, 727, 3639, 1595,
    4944, 2129, 2029, 8195, 8344, 6232, 9183, 8126, 1870, 3296, 7455, 8947,
    25017, 541, 19115, 368, 566, 5674, 411, 522, 1027, 8215, 2050, 6544, 10049,
    614, 774, 2333, 3007, 35201, 4706, 1152, 1785, 1028, 1540, 3743, 493, 4474,
    2521, 26845, 8354, 864, 18915, 5465, 2447, 42, 4511, 1660, 166, 1249, 6259,
    2553, 304, 272, 7286, 73, 6554, 899, 2816, 5197, 13330, 7054, 2818, 3199,
    811, 922, 350, 7514, 4452, 3449, 2663, 4708, 418, 1621, 1171, 3471, 88,
    11345, 412, 1559, 194,]

def is_prime32(x):
    """ 1 <= x < 1<<32 """
    if x == 2: return True
    if x == 3: return True
    if x == 5: return True
    if x == 7: return True
    if not x&1: return False
    if not x%3: return False
    if not x%5: return False
    if not x%7: return False
    if x < 121: return x > 1
    
    h = x
    h = ((h>>16) ^ h) * 0x45d9f3b
    h = ((h>>16) ^ h) * 0x45d9f3b
    h = ((h>>16) ^ h) & 255
    a = base[h]
    
    d, s = x-1, 0
    while not d&1:
        d >>= 1
        s += 1
    
    t = pow(a, d, x)
    if t == 0: return False
    if t == 1: return True
    if t == x-1: return True
    for _ in range(s-1):
        t = t * t % x
        if t == x-1: return True
    return False

def divisors(n:int) -> list[int]:
    divs_small, divs_big = [], []
    i = 1
    while i*i <= n:
        if n % i == 0:
            divs_small.append(i)
            if i != n//i:
                divs_big.append(n//i)
        i += 1
    return divs_small + divs_big[::-1]


def check(n):
    s2 = n ** (1 / 2)
    s3 = n ** (1 / 3)
    ok = []
    for d in divisors(n):
        if s2 - s3 <= d <= s2 + s3:
            ok.append(d)
    return len(ok) > 0

if 0:
    mc = 0
    ch = set()
    for n in range(1, 10 ** 6, 2):
        s2 = n ** (1 / 2)
        s3 = n ** (1 / 3)
        ok = []
        for d in divisors(n):
            if s2 - s3 <= d <= s2 + s3:
                if d <= n // d:
                    ok.append((d, n // d))
        
        # ok についてかんがえる
        
        if len(ok):
            r = int(s2)
            for x, y in ok:
                ch.add((x - r, y - r))
    
    print(len(ch), ch)


def nth_root(x:int, n:int, is_64bit = True) -> int:
    """
    floor(x^(1/n))
    """
    ngs = [-1, -1, 4294967296, 2642246, 65536, 7132, 1626, 566, 256, 139, 85, 57, 41, 31, 24, 20, 16, 14, 12, 11, 10, 9, 8, 7, 7, 6, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
    if x <= 1 or n == 1:
        return x
    if is_64bit:
        if n >= 64: return 1
        ng = ngs[n]
    else:
        ng = x
    
    ok = 0
    while abs(ok - ng) > 1:
        mid = (ok + ng)//2
        if mid**n <= x:
            ok = mid
        else:
            ng = mid
    return ok    

if 0:
    s = set()
    for n in reversed(range(1, 10 ** 6, 2)):
        s2 = nth_root(n, 2)
        s3 = nth_root(n, 2)
        p = (s2, s3)
        if p in s:
            print(n, p)
            assert False
        s.add(p)
"""
素数だが疑似平方数のもの
3, 5, 7 only

合成数だが疑似平方数のもの
たくさんある

1 をまず作る

2, 3 も作る O(1)

1, 2, 3, 4
n, 1, 2, 3
"""

a = {}

def solve():

    idx = 1
    def add(i, op, j):
        nonlocal idx
        print("a", i, op, j)
        
        idx += 1
        
        # if op == "+":
        #     a[idx] = a[i] + a[j]
        # elif op == "-":
        #     a[idx] = a[i] - a[j]
        # elif op == "*":
        #     a[idx] = a[i] * a[j]
        # else:
        #     a[idx] = nth_root(a[j], a[i])
        
        # assert 0 <= a[idx] < 2 ** 64
        return idx

    def ask(i, j):
        print("?", i, "<", j)
        # print(a[i], a[j])
        # return a[i] < a[j]
        return SI() == "T"

    p = {}

    pn = 1
    p[1] = add(1, "r", 1)


    lim = 200
    # a[2] ~ a[lim]
    # a[i] = i - 1 が入っている
    for x in range(2, lim):
        p[x] = add(p[x-1], "+", p[1])

    s = add(p[2], "r", pn)
    # c = add(p[3], "r", pn)

    ch = []
    for x in range(1, lim):
        k = add(s, "+", p[x]) # k = sqrt(n) + x
        k2 = add(k, "*", k) # k * k
        
        # dk が平方数なら ok
        dk = add(k2, "-", pn) # k * k - n
        dks = add(p[2], "r", dk)
        dk2 = add(dks, "*", dks)
        
        # u >= s - c 
        u = add(k, "-", dks)
        """
        整数に丸めるんでした。怒られ
        c >= s - u
        c ** 3, dks
        """
        u2 = add(u, "*", u)
        u3 = add(u2, "*", u)
        nu = add(u, "*", pn)
        nu3 = add(p[3], "*", nu)
        tmp = add(pn, "+", nu3)
        tmp = add(tmp, "+", u3)
        l = add(tmp, "*", tmp)
        
        u23 = add(u2, "*", p[3])
        tmp = add(pn, "+", u23)
        tmp = add(tmp, "*", tmp)
        r = add(tmp, "*", pn)
        
        
        ch.append((dk, dk2, l, r))

    
    s2 = add(s, "*", s)
    s21 = add(s2, "+", p[1])




    if ask(pn, s21):
        print("! Yes")
        return True


    for x, y, l, r in ch:
        # x <= y かつ l >= r が条件
        if not ask(y, x):
            if not ask(l, r):
                print("! Yes")
                return True

    print("! No")
    return False


solve()


if 0:
    ng = []
    for n in range(1, 10 ** 6, 2):
        a.clear()
        a[1] = n
        ans = solve()
        if ans != check(n):
            ng.append(n)
            # print(n, ans, check(n))
            # assert False
            print(n)
        print(n)
    print(ng)
0