結果

問題 No.896 友達以上恋人未満
ユーザー gew1fw
提出日時 2025-06-12 13:18:52
言語 PyPy3 (7.3.15)
結果
MLE  
実行時間 -
コード長 3,085 bytes
コンパイル時間 205 ms
コンパイル使用メモリ 82,612 KB
実行使用メモリ 380,372 KB
最終ジャッジ日時 2025-06-12 13:21:27
合計ジャッジ時間 19,938 ms
ジャッジサーバーID (参考情報) judge3 / judge4
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 1 TLE * 3 MLE * 3
権限があれば一括ダウンロードができます

ソースコード

diff #

def _primes_up_to(limit):
    """Return all primes p with 2 <= p <= limit (sieve of Eratosthenes)."""
    if limit < 2:
        return []
    is_prime = bytearray([1]) * (limit + 1)
    is_prime[0] = is_prime[1] = 0
    p = 2
    while p * p <= limit:
        if is_prime[p]:
            # Multiples below p*p were already cleared by smaller primes.
            is_prime[p * p::p] = bytes(len(range(p * p, limit + 1, p)))
        p += 1
    return [q for q in range(2, limit + 1) if is_prime[q]]


def _sum_over_multiples(values, limit):
    """Transform values in place into sums over multiples.

    After the call, values[d] (1 <= d <= limit) holds the sum of the
    original values[k] over k = d, 2d, 3d, ... <= limit. values[0] is not
    touched. values must have length at least limit + 1.

    One pass per prime p <= limit (about limit * log log limit additions),
    and no temporary lists, unlike summing a slice per divisor.
    """
    for p in _primes_up_to(limit):
        # Descending i: values[i * p] already folds in values[i * p**k]
        # for every k >= 1, so one addition covers the whole p-power chain.
        for i in range(limit // p, 0, -1):
            values[i] += values[i * p]


def _eel_totals(X, Y, N, mulX, addX, mulY, addY, MOD):
    """Return totals (length MOD + 1) with totals[v] = sum of y over eels at x == v.

    The given (X[i], Y[i]) pairs come first. Eels up to index N - 1 are then
    generated by x <- (x*mulX + addX) % MOD and y <- (y*mulY + addY) % MOD,
    starting from the last given pair, or from (0, 0) when none is given.
    The extra slot totals[MOD] stays 0 so that index MOD is addressable by
    the multiple sums.
    """
    totals = [0] * (MOD + 1)
    for x, y in zip(X, Y):
        totals[x] += y
    x, y = (X[-1], Y[-1]) if X else (0, 0)
    for _ in range(len(X), N):
        x = (x * mulX + addX) % MOD
        y = (y * mulY + addY) % MOD
        totals[x] += y
    return totals


def _rabbit_params(A, B, N, mulX, addX, mulY, addY, MOD):
    """Yield (a, b) for every rabbit.

    The given pairs come first. Rabbits up to index N - 1 are then
    generated by a <- (a*mulX + addX + MOD - 1) % MOD + 1, and the same form
    for b with mulY/addY, which keeps generated a, b in [1, MOD].
    Generation starts from the last given pair, or from (1, 1) when none is
    given.
    """
    yield from zip(A, B)
    a, b = (A[-1], B[-1]) if A else (1, 1)
    for _ in range(len(A), N):
        a = (a * mulX + addX + MOD - 1) % MOD + 1
        b = (b * mulY + addY + MOD - 1) % MOD + 1
        yield a, b


def main():
    """Read one test case from stdin and write the answers to stdout.

    Input tokens: M N mulX addX mulY addY MOD, then M values each of
    X, Y, A and B.

    For every rabbit (a, b), the answer is the total y of the eels whose x
    is a multiple of a but not a multiple of a*b.

    Output: the answers of the M given rabbits, one per line, followed by
    the XOR of the answers of all rabbits (given and generated).
    """
    import sys

    tokens = sys.stdin.read().split()
    M, N, mulX, addX, mulY, addY, MOD = map(int, tokens[:7])
    X = list(map(int, tokens[7:7 + M]))
    Y = list(map(int, tokens[7 + M:7 + 2 * M]))
    A = list(map(int, tokens[7 + 2 * M:7 + 3 * M]))
    B = list(map(int, tokens[7 + 3 * M:7 + 4 * M]))
    # Release the raw tokens before the MOD-sized array is allocated.
    del tokens

    # A single MOD + 1 list, transformed in place, keeps peak memory at one
    # int per residue. A second table plus per-divisor slices such as
    # z[1::1] would keep several MOD-sized lists alive at once.
    sums = _eel_totals(X, Y, N, mulX, addX, mulY, addY, MOD)
    _sum_over_multiples(sums, MOD)
    # Now sums[d] (1 <= d <= MOD) is the total y of eels with x in
    # {d, 2d, ...}. Eels at x == 0 never contribute; 0 is a multiple of
    # every a*b, so they would be excluded anyway.

    shown = []
    xor_all = 0
    rabbits = _rabbit_params(A, B, N, mulX, addX, mulY, addY, MOD)
    for j, (a, b) in enumerate(rabbits):
        ab = a * b
        # When ab > MOD no positive multiple of ab lies below MOD, and
        # sums only reaches index MOD.
        current = sums[a] - sums[ab] if ab <= MOD else sums[a]
        if j < M:
            shown.append(current)
        xor_all ^= current

    shown.append(xor_all)
    sys.stdout.write("\n".join(map(str, shown)) + "\n")

if __name__ == "__main__":
    # Run the solver only when executed as a script, not when imported.
    main()
0