結果

問題 No.2280 FizzBuzz Difference
ユーザー chineristAC
提出日時 2023-04-21 22:25:50
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 146 ms / 2,000 ms
コード長 5,285 bytes
コンパイル時間 274 ms
コンパイル使用メモリ 82,176 KB
実行使用メモリ 80,312 KB
最終ジャッジ日時 2024-04-24 08:32:54
合計ジャッジ時間 1,917 ms
ジャッジサーバーID
(参考情報)
judge5 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 66 ms
69,248 KB
testcase_01 AC 142 ms
80,312 KB
testcase_02 AC 136 ms
80,000 KB
testcase_03 AC 117 ms
78,848 KB
testcase_04 AC 142 ms
80,000 KB
testcase_05 AC 141 ms
80,128 KB
testcase_06 AC 146 ms
79,744 KB
testcase_07 AC 140 ms
79,744 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys,random,bisect
from collections import deque,defaultdict
from heapq import heapify,heappop,heappush
from itertools import permutations
from math import gcd,log

from math import sqrt, ceil
from bisect import bisect_left, bisect_right
from typing import Iterable


def input():
    """Return the next line of standard input with trailing whitespace removed.

    Intentionally shadows the built-in ``input`` so reads go through
    ``sys.stdin.readline``.
    """
    return sys.stdin.readline().rstrip()


def mi():
    """Return a lazy ``map`` of the whitespace-separated integers on the next line."""
    return map(int, input().split())


def li():
    """Return the integers on the next input line as a list."""
    return list(mi())

class SegmentTree:
    """Bottom-up segment tree over an associative binary operation.

    The tree is stored in a flat list of length ``2 * num`` where ``num`` is
    the number of leaves (a power of two); node ``i`` has children ``2 * i``
    and ``2 * i + 1``, and leaf ``k`` lives at index ``num + k``. Unused
    leaves hold the identity element.
    """

    def __init__(self, init_val, segfunc, ide_ele):
        """Build the tree.

        init_val: indexable sequence of initial leaf values.
        segfunc:  binary function used to combine two values.
        ide_ele:  identity element for ``segfunc``.
        """
        count = len(init_val)
        self.segfunc = segfunc
        self.ide_ele = ide_ele
        self.size = count
        self.num = 1 << (count - 1).bit_length()
        leaves = self.num
        nodes = [ide_ele] * (2 * leaves)
        self.tree = nodes
        # Copy the initial values into the leaf layer.
        nodes[leaves:leaves + count] = [init_val[i] for i in range(count)]
        # Fill the internal nodes from the deepest level up to the root.
        for node in range(leaves - 1, 0, -1):
            nodes[node] = self.segfunc(nodes[2 * node], nodes[2 * node + 1])

    def update(self, k, x):
        """Set leaf ``k`` to ``x`` and recompute every ancestor."""
        nodes = self.tree
        node = k + self.num
        nodes[node] = x
        while node > 1:
            node >>= 1
            nodes[node] = self.segfunc(nodes[2 * node], nodes[2 * node + 1])

    def query(self, l, r):
        """Return the combination of leaves in the half-open range ``[l, r)``.

        Values are combined in left-to-right order, so ``segfunc`` does not
        need to be commutative. An empty range yields the identity element.
        """
        # A right bound equal to the logical size is widened to the padded
        # size; the extra leaves only hold the identity element.
        if r == self.size:
            r = self.num

        nodes = self.tree
        combine = self.segfunc
        lo = l + self.num
        hi = r + self.num
        acc = self.ide_ele
        right_parts = []  # right-side pieces, collected right-to-left

        while lo < hi:
            if lo & 1:
                acc = combine(acc, nodes[lo])
                lo += 1
            if hi & 1:
                right_parts.append(nodes[hi - 1])
            lo >>= 1
            hi >>= 1

        for piece in reversed(right_parts):
            acc = combine(acc, piece)
        return acc

def _inv_gcd(a,b):
    """Return ``(g, x)`` with ``g = gcd(a, b)`` and ``a * x = g (mod b)``.

    ``a`` is first reduced modulo ``b``; when that leaves 0 the result is
    ``(b, 0)``. Otherwise ``x`` is normalised to be non-negative by adding
    ``b // g`` when the Euclidean loop leaves it negative.
    """
    a %= b
    if a == 0:
        return (b, 0)

    # Extended Euclid on (b, a), tracking only the coefficient of ``a``.
    # Loop invariants (all modulo b):
    #   g   == coef      * a
    #   rem == next_coef * a
    g, rem = b, a
    coef, next_coef = 0, 1

    while rem:
        q, r = divmod(g, rem)
        g, rem = rem, r
        coef, next_coef = next_coef, coef - next_coef * q

    # |coef| < b / g here, so a single correction makes it non-negative.
    if coef < 0:
        coef += b // g

    return (g, coef)
 
def crt(r,m):
    """Merge the congruences ``x = r[i] (mod m[i])`` into a single one.

    Returns ``(rem, mod)`` where ``mod`` is the lcm of all moduli and
    ``0 <= rem < mod`` is the solution, or ``(0, 0)`` when the system has no
    solution. An empty system yields ``(0, 1)``.
    """
    assert len(r) == len(m)

    # Running solution: x = rem (mod mod), with 0 <= rem < mod.
    rem, mod = 0, 1
    for i in range(len(r)):
        cur_mod = m[i]
        assert 1 <= cur_mod
        cur_rem = r[i] % cur_mod

        # Keep the larger modulus in (rem, mod).
        if mod < cur_mod:
            rem, cur_rem = cur_rem, rem
            mod, cur_mod = cur_mod, mod

        # The smaller modulus divides the larger one: only a consistency
        # check is needed, the running solution does not change.
        if mod % cur_mod == 0:
            if rem % cur_mod != cur_rem:
                return (0, 0)
            continue

        # General case. With mod = u0 * g and cur_mod = u1 * g we need
        #   (rem + x * mod) % cur_mod == cur_rem
        #   => x * u0 * g == cur_rem - rem   (mod u1 * g)
        #   => x == (cur_rem - rem) / g * inv(u0)   (mod u1)
        # where inv is the inverse of u0 modulo u1 (0 <= inv < u1).
        g, inv = _inv_gcd(mod, cur_mod)
        u1 = cur_mod // g

        delta = cur_rem - rem
        if delta % g:
            return (0, 0)

        x = delta // g % u1 * inv % u1

        # |rem| + |mod * x| < mod + mod * (u1 - 1) = lcm(mod, cur_mod)
        rem += x * mod
        mod *= u1  # now lcm of the two moduli
        if rem < 0:
            rem += mod

    return (rem, mod)

def solve(M,A,B,K):
    """Count adjacent pairs of "multiples of A or B" in 1..M that differ by K.

    This is the fast counterpart of ``brute`` in this file. Everything is
    first divided by ``g = gcd(A, B)`` so that A and B become coprime; a gap
    that is not a multiple of ``g`` cannot occur.
    """
    g = gcd(A, B)
    M //= g
    if K % g:
        return 0
    K //= g
    A //= g
    B //= g

    if A < K:
        return 0

    if A == K:
        # Gaps of exactly A are consecutive multiples of A with no multiple
        # of B in between, plus the term M // (A * B).
        last_a = M // A
        b_between = (A * last_a) // B - A // B
        return (last_a - 1) - b_between + M // (A * B)

    def count_pattern(p, q):
        """Count solutions of ``q * y - p * x = K`` that are adjacent pairs.

        The smallest solution comes from the CRT; all others are
        ``y = y0 + p * n``, ``x = x0 + q * n`` for ``n >= 0``.
        """
        r, _ = crt([0, K], [q, p])
        y0 = r // q
        x0 = (r - K) // p
        if y0 > M:
            return 0

        # Largest n for which both members of the pair stay within M.
        max_n = min((M // q - y0) // p, (M // p - x0) // q)

        # Rank of each member in the merged sequence (multiples of p plus
        # multiples of q, minus the common ones); adjacent means ranks differ by 1.
        rank_y = (q * y0) // p + y0 - y0 // p
        rank_x = x0 + (p * x0) // q - x0 // q
        if rank_y == rank_x + 1:
            return max_n + 1
        return 0

    # Pattern B*y - A*x = K, then pattern A*y - B*x = K.
    return count_pattern(A, B) + count_pattern(B, A)





def brute(M,A,B,K):
    """Reference answer by direct enumeration.

    Walks 1..M, keeps the numbers divisible by A or B, and counts how many
    consecutive kept numbers differ by exactly K.
    """
    count = 0
    previous = None
    for n in range(1, M + 1):
        if n % A and n % B:
            continue  # divisible by neither A nor B
        if previous is not None and n - previous == K:
            count += 1
        previous = n
    return count

# Randomised stress test: compares solve() with brute() on small random
# inputs (A < B, K < A). The loop condition is the constant False, so this
# body never runs; change it to True to enable the check locally.
while False:
    A = random.randint(2,20)
    B = random.randint(A+1,21)
    M = random.randint(5000,10000)
    K = random.randint(1,A-1)

    # Uncomment to replay one fixed case instead of the random one.
    #M,A,B,K = 60,7,10,3
    print(M,A,B,K,solve(M,A,B,K),brute(M,A,B,K))
    assert solve(M,A,B,K) == brute(M,A,B,K)

# Driver: the first input line is the number of test cases; each following
# line holds the four integers M, A, B, K, and one answer is printed per case.
for _ in range(int(input())):
    M,A,B,K = mi()
    print(solve(M,A,B,K))
0