"""Locate a state of a second-order linear recurrence modulo 998244353.

``solve`` reads ``A B P Q`` and works with V_0 = 2, V_1 = A and
V_{j+1} = A*V_j - B*V_{j-1}.  It prints an index ``n`` such that
(V_n, V_{n-1}) == (P, Q) modulo ``MOD`` (the smallest one found), ``-1`` when
none is found, or ``SENTINEL_ANSWER`` for the special cases noted in
``solve`` and ``solve_matrix_case``.
"""

import math

MOD = 998244353

# Value printed for the special cases handled in ``solve`` and
# ``solve_matrix_case``; its meaning comes from the problem statement, which
# is not part of this file.
SENTINEL_ANSWER = 10**18


def multiply(m1, m2, mod):
    """Return the 2x2 matrix product ``m1 @ m2`` with entries reduced mod ``mod``."""
    (a, b), (c, d) = m1
    (e, f), (g, h) = m2
    return [
        [(a * e + b * g) % mod, (a * f + b * h) % mod],
        [(c * e + d * g) % mod, (c * f + d * h) % mod],
    ]


def matrix_pow(mat, power, mod):
    """Return ``mat ** power`` (2x2, ``power >= 0``) modulo ``mod`` by binary exponentiation."""
    result = [[1, 0], [0, 1]]  # 2x2 identity
    while power > 0:
        if power & 1:
            result = multiply(result, mat, mod)
        mat = multiply(mat, mat, mod)
        power >>= 1
    return result


def matrix_inverse(mat, mod):
    """Return the inverse of the 2x2 matrix ``mat`` modulo ``mod``.

    Returns ``None`` when the determinant is 0 modulo ``mod``.  A non-zero
    determinant that is still not invertible (composite ``mod``) makes
    ``pow`` raise ``ValueError``.
    """
    a, b = mat[0]
    c, d = mat[1]
    det = (a * d - b * c) % mod
    if det == 0:
        return None
    det_inv = pow(det, -1, mod)
    return [
        [(d * det_inv) % mod, (-b * det_inv) % mod],
        [(-c * det_inv) % mod, (a * det_inv) % mod],
    ]


def multiply_matrix_vector(mat, vec, mod):
    """Return ``mat @ vec`` for a 2x2 matrix and a 2-vector, as a tuple reduced mod ``mod``."""
    a = (mat[0][0] * vec[0] + mat[0][1] * vec[1]) % mod
    b = (mat[1][0] * vec[0] + mat[1][1] * vec[1]) % mod
    return (a, b)


def solve_matrix_case(A, B, P, Q, mod, limit=10**10):
    """Find n with (V_n, V_{n-1}) == (P, Q) modulo ``mod`` by baby-step giant-step.

    With M = [[A, -B], [1, 0]], M @ (V_j, V_{j-1}) == (V_{j+1}, V_j), so the
    state (V_{k+1}, V_k) equals M**k @ (A, 2).  This searches for the smallest
    k >= 1 with M**k @ (A, 2) == (P, Q) and returns n = k + 1.

    Args:
        A, B, P, Q: recurrence coefficients and target state.
        mod: modulus for all arithmetic.
        limit: every exponent k up to at least ``limit`` is searched; time and
            memory grow with sqrt(limit).  The default keeps the original
            hard-coded bound of 1e10.

    Returns:
        ``SENTINEL_ANSWER`` if (P, Q) is the start state (A, 2); -1 if M is
        singular modulo ``mod`` or no k in the searched range exists;
        otherwise the index n = k + 1.
    """
    start = (A % mod, 2 % mod)
    target = (P % mod, Q % mod)
    if target == start:
        return SENTINEL_ANSWER

    step = [[A % mod, (-B) % mod], [1, 0]]
    step_inv = matrix_inverse(step, mod)
    if step_inv is None:
        return -1

    # NOTE(review): for prime ``mod`` the orbit of M can be far longer than
    # ``limit`` (up to mod**2 - 1), so a solution beyond the searched range is
    # reported as -1 -- confirm the problem bounds the answer.
    m = math.isqrt(limit) + 1

    # Baby steps: remember the smallest i in [0, m) for each M**i @ start.
    first_index = {}
    current = start
    for i in range(m):
        first_index.setdefault(current, i)
        current = multiply_matrix_vector(step, current, mod)

    # Giant steps: M**(-j*m) @ target == M**i @ start  <=>  k = j*m + i.
    # Scanning j upwards with the smallest stored i yields the smallest k.
    giant = matrix_pow(step_inv, m, mod)
    current = target
    for j in range(m):
        i = first_index.get(current)
        if i is not None:
            return j * m + i + 1
        current = multiply_matrix_vector(giant, current, mod)
    return -1


def discrete_log(a, b, p):
    """Return the smallest x >= 0 with ``a**x ≡ b (mod p)``, or -1 if none exists.

    Works for any modulus p >= 1, including composite p and bases sharing a
    factor with p (extended baby-step giant-step).
    """
    if p == 1:
        return 0
    a %= p
    b %= p
    if b == 1:
        return 0

    # Peel off factors shared by a and p.  Invariant: x = offset + y where
    # coef * a**y ≡ b (mod p) for the current (coef, b, p).
    offset = 0
    coef = 1
    while True:
        g = math.gcd(a, p)
        if g == 1:
            break
        if b == coef:  # y == 0 solves the current equation
            return offset
        if b % g:
            return -1
        b //= g
        p //= g
        offset += 1
        coef = coef * (a // g) % p

    if p == 1:
        return offset

    # Now gcd(a, p) == 1 and coef is invertible, so solve a**y ≡ target.
    a %= p
    target = b * pow(coef, -1, p) % p
    m = math.isqrt(p) + 1
    first_exp = {}
    power = 1
    for j in range(m):
        first_exp.setdefault(power, j)  # keep the smallest j for minimality
        power = power * a % p
    stride_inv = pow(power, -1, p)  # power == a**m here
    gamma = target
    for i in range(m):
        j = first_exp.get(gamma)
        if j is not None:
            return offset + i * m + j
        gamma = gamma * stride_inv % p
    return -1


def solve_geometric_case(A, P, Q, mod):
    """Handle B == 0, where V_j = A**j for every j >= 1.

    Returns the smallest n >= 2 with (P, Q) == (A**n, A**(n-1)) modulo
    ``mod``, or -1 if there is none.  The start state (A, 2) at n = 1 is
    never matched here.

    Raises:
        ValueError: if ``A`` is not invertible modulo ``mod``.
    """
    A %= mod
    P %= mod
    Q %= mod
    if Q == 0 or A * Q % mod != P:
        return -1
    # We need the smallest x >= 1 with A**x == Q.  discrete_log returns the
    # smallest exponent >= 0, which is 0 when Q == 1 and would hide the real
    # answer ord(A); solving A**y == Q / A instead yields y = x - 1 >= 0.
    y = discrete_log(A, Q * pow(A, -1, mod) % mod, mod)
    return -1 if y == -1 else y + 2


def solve():
    """Read ``A B P Q`` from one line of stdin and print the answer.

    Prints ``SENTINEL_ANSWER`` for the special cases below, the index found
    by the case helpers, or -1 when none is found.
    """
    A, B, P, Q = (int(token) % MOD for token in input().split())
    if B == 0 and A == 0:
        # NOTE(review): the all-zero state first appears at n = 2, and the
        # start state (0, 2) prints -1 here while the B != 0 branch prints the
        # sentinel for its start state; kept as in the original -- confirm
        # against the problem statement.
        print(SENTINEL_ANSWER if P == 0 and Q == 0 else -1)
    elif B == 0:
        print(solve_geometric_case(A, P, Q, MOD))
    else:
        print(solve_matrix_case(A, B, P, Q, MOD))


if __name__ == "__main__":
    solve()