結果

問題 No.1262 グラフを作ろう!
ユーザー zkou
提出日時 2020-10-16 22:57:37
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 2,920 bytes
コンパイル時間 992 ms
コンパイル使用メモリ 82,304 KB
実行使用メモリ 212,036 KB
最終ジャッジ日時 2024-07-20 23:25:42
合計ジャッジ時間 12,739 ms
ジャッジサーバーID
(参考情報)
judge2 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other TLE * 1 -- * 95
権限があれば一括ダウンロードができます

ソースコード

diff #

from functools import lru_cache

# 与えられた N 以下の素数を列挙して、その N 以下の素数の素因数分解を高速に与えるアルゴリズム。
#
# 計算量:
#     構築: O(N) 
#     素因数分解: n <= N なる n に対し O(log n)、N < n <= N**2 なる n に対し、O(N / log N)
#
#
# 参考: https://cp-algorithms.com/algebra/prime-sieve-linear.html (n <= N に対して O(log n) の部分)

N = 10 ** 6 + 10  # 10 ** 7 以下だとよさそう。10 ** 7 でPyPy3で構築に500ms程度。
lp = [0] * (N+1)
pr = []  # N 以下の素数のリスト
for i in range(2, N+1):
    if lp[i] == 0:
        lp[i] = i
        pr.append(i)
    for j, p in enumerate(pr):
        if p > lp[i] or i * p > N:
            break
        lp[i * p] = p


def fac_small(n):
    """
    引数 n の素因数分解をして、素因数を昇順に格納したリストを返す。
    ex. fac_small(60) == [2, 2, 3, 5]
    与えられた n について 1 <= n <= N を仮定する。n == 1 では空リストを返す。
    計算量は O(log n) 
    """
    # assert 1 <= n <= N
    ret = []
    while n > 1:
        ret.append(lp[n])
        n //= lp[n]
    return ret


def fac(n):
    """
    引数 n の素因数分解をして、素因数を昇順に格納したリストを返す。
    ex. fac_small(60) == [2, 2, 3, 5]
    与えられた n について 1 <= n <= N**2 を仮定する。n == 1 では空リストを返す。
    計算量は、1 <= n <= N で O(log n)、N < n <= N ** 2 で O(N / log N) 程度。
    """
    # assert 1 <= n <= N**2
    if 1 <= n <= N:
        return fac_small(n)
    else:
        sqr = int(N ** 0.5) + 10
        ret = []
        for p in pr:
            while n % p == 0:
                n //= p
                ret.append(p)
            if n == 1:
                return ret
            elif n <= N:
                return ret + fac_small(n)
            if p > sqr:
                break
        ret.append(n)
        return ret

from collections import Counter
from itertools import product

def enumerate_factors(ls):
    c = Counter(ls)
    iterators = [[k**i for i in range(v + 1)] for k, v in c.items()]
    ret = []
    for tup in product(*iterators):
        t = 1
        for x in tup:
            t *= x
        ret.append(t)
    return ret

@lru_cache(100)
def solve(A):
    ls = fac_small(A)
    facs = enumerate_factors(ls)
    facs.sort()
    ps = list(set(ls))
    dic = {f: A // f for f in facs}
    for p in ps:
        for f in facs:  
            if A % (f * p) == 0:
                dic[f] -= dic[f * p]
    # print(dic)
    return sum((f * c for f, c in dic.items()))

N, M = map(int, input().split())
As = list(map(int, input().split()))
ans = 0
for A in As:
    ans += solve(A) - A
    # if solve(A) != sum((gcd(i, A) for i in range(1, A + 1))):
    #     print(A, solve(A), sum((gcd(i, A) for i in range(1, A + 1))))

print(ans)
0