"""Sum of binomial coefficients modulo a prime, reduced modulo 998244353.

Reads ``n p`` from stdin (``p`` is assumed to be prime) and prints
``sum(C(n, k) mod p for k in 0..n) mod 998244353``.

Method: by Lucas' theorem, C(n, k) mod p is the product of C(n_i, k_i) over
the base-p digits of n and k, and it is non-zero exactly when every
k_i <= n_i.  Taking discrete logarithms with respect to a primitive root q of
p turns that product into a sum modulo p - 1.  The distribution of
log_q C(n, k) over all such k is therefore the cyclic convolution
(length p - 1) of the per-digit distributions.  Each convolution is carried
out with a number-theoretic transform modulo 998244353.
"""

MOD = 998244353


def find_primitive_root(n):
    """Return the smallest primitive root modulo the prime ``n``.

    ``res`` is a primitive root iff res^((n-1)/f) != 1 (mod n) for every
    prime factor f of n - 1.  Returns -1 if no candidate in [1, n) qualifies.
    """
    # Distinct prime factors of n - 1, by trial division.
    phi = n - 1
    factors = []
    i = 2
    while i * i <= phi:
        if phi % i == 0:
            factors.append(i)
            while phi % i == 0:
                phi //= i
        i += 1
    if phi > 1:
        factors.append(phi)

    for res in range(1, n):
        if all(pow(res, (n - 1) // f, n) != 1 for f in factors):
            return res
    return -1


primitive_root = find_primitive_root(MOD)


def ntt(a):
    """Forward NTT of ``a`` modulo MOD, in place.

    ``len(a)`` must be a power of two dividing MOD - 1, and every entry must
    lie in [0, MOD).  The butterflies are decimation-in-frequency with no
    bit-reversal permutation, so the output is left in bit-reversed order.
    ``intt`` consumes exactly that order.  Hence ntt -> pointwise product ->
    intt yields a cyclic convolution.
    """
    n = len(a)
    m = n
    while m > 1:
        mh = m >> 1
        wm = pow(primitive_root, (MOD - 1) // m, MOD)  # primitive m-th root
        w = 1
        for i in range(mh):
            for j in range(i, n, m):
                k = j + mh
                a0 = a[j]
                a1 = a[k]
                s = a0 + a1
                a[j] = s - MOD if s >= MOD else s
                # Python's % is non-negative even when a0 - a1 < 0.
                a[k] = (a0 - a1) * w % MOD
            w = w * wm % MOD
        m = mh


def intt(a):
    """Inverse of ``ntt``, in place.

    Takes bit-reversed input and produces natural-order output.  The
    butterflies are decimation-in-time with inverse twiddles, followed by
    scaling by len(a)^-1 mod MOD.
    """
    n = len(a)
    m = 2
    while m <= n:
        mh = m >> 1
        # g^(MOD-1 - (MOD-1)/m) is the inverse of the forward m-th root.
        wm = pow(primitive_root, MOD - 1 - (MOD - 1) // m, MOD)
        w = 1
        for i in range(mh):
            for j in range(i, n, m):
                k = j + mh
                a0 = a[j]
                a1 = a[k] * w % MOD
                s = a0 + a1
                a[j] = s - MOD if s >= MOD else s
                d = a0 - a1
                a[k] = d + MOD if d < 0 else d
            w = w * wm % MOD
        m <<= 1
    inv = pow(n, MOD - 2, MOD)
    for i in range(n):
        a[i] = a[i] * inv % MOD


def binomial_residue_sum(n, p):
    """Return ``sum(C(n, k) mod p for k in range(n + 1))`` modulo MOD.

    ``p`` is assumed prime.  The NTT length (the smallest power of two
    >= 2 * (p - 1)) must not exceed 2**23, the largest power of two dividing
    MOD - 1.  ``n`` is assumed non-negative; n == 0 yields 1.
    """
    q = find_primitive_root(p)
    order = p - 1  # size of the multiplicative group mod p

    # index_table[x] = discrete log of x to base q, for x in 1..p-1.
    index_table = [0] * p
    x = 1
    for e in range(order):
        index_table[x] = e
        x = x * q % p

    # log_fact[i] = log_q(i!) mod (p - 1); log_fact[0] = 0.
    log_fact = [0] * p
    for i in range(1, p):
        log_fact[i] = (log_fact[i - 1] + index_table[i]) % order

    def log_binom(top, bottom):
        """Discrete log of C(top, bottom) mod p, for 0 <= bottom <= top < p."""
        return (log_fact[top] - log_fact[bottom] - log_fact[top - bottom]) % order

    # Two sequences supported on [0, p-1) have a linear product of degree
    # <= 2p - 4.  A power-of-two size >= 2(p-1) therefore avoids NTT
    # wrap-around.  The reduction modulo p - 1 is done explicitly below.
    size = 1
    while size < 2 * order:
        size *= 2

    count = [0] * size
    count[0] = 1  # empty product: value 1 (log 0), one way
    while n > 0:
        n, digit = divmod(n, p)
        if digit == 0:
            # Only C(0, 0) = 1 (log 0) is possible.  Convolving with a point
            # mass at 0 is the identity, so the transforms can be skipped.
            continue
        a = [0] * size
        for i in range(digit + 1):
            a[log_binom(digit, i)] += 1
        ntt(count)
        ntt(a)
        for i in range(size):
            count[i] = count[i] * a[i] % MOD
        intt(count)
        # Fold exponents >= p - 1 back into [0, p - 1) (cyclic group order).
        for i in range(size - 1, order - 1, -1):
            j = i - order
            count[j] += count[i]
            if count[j] >= MOD:
                count[j] -= MOD
            count[i] = 0

    # count[e] = number of k with C(n, k) == q^e (mod p); weight by q^e.
    total = 0
    x = 1
    for e in range(order):
        total = (total + count[e] * x) % MOD
        x = x * q % p
    return total


def solve():
    """Read ``n p`` from stdin and print ``binomial_residue_sum(n, p)``."""
    n, p = map(int, input().split())
    print(binomial_residue_sum(n, p))


if __name__ == "__main__":
    solve()