"""Count triples (x, y, z) in [0, B)^3 with x**N + y**N == z**N (mod B).

Run as a script: reads ``N B`` from stdin and prints the exact count.
"""

MOD = 998244353
G = 3


def ntt(a, invert=False):
    """Apply the number-theoretic transform modulo MOD to ``a`` in place.

    ``len(a)`` must be a power of two dividing ``MOD - 1`` (at most 2**23 for
    MOD = 998244353).  With ``invert=True`` the inverse transform is applied,
    including the final multiplication by ``len(a)**-1 mod MOD``.
    Returns None; the result is left in ``a``.
    """
    n = len(a)
    # Bit-reversal permutation so the iterative butterflies can run in place.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j ^= bit
        if i < j:
            a[i], a[j] = a[j], a[i]
    length = 2
    while length <= n:
        # length-th root of unity, relying on G being a primitive root of MOD
        # (3 is one for 998244353); the inverse transform uses its inverse.
        wlen = pow(G, (MOD - 1) // length, MOD)
        if invert:
            wlen = pow(wlen, MOD - 2, MOD)
        half = length // 2
        for i in range(0, n, length):
            w = 1
            for j in range(i, i + half):
                u = a[j]
                v = a[j + half] * w % MOD
                a[j] = (u + v) % MOD
                a[j + half] = (u - v) % MOD
                w = w * wlen % MOD
        length *= 2
    if invert:
        inv_n = pow(n, MOD - 2, MOD)
        for i in range(n):
            a[i] = a[i] * inv_n % MOD


def convolution(a, b):
    """Return the linear convolution of ``a`` and ``b`` with entries reduced mod MOD.

    The result has ``len(a) + len(b) - 1`` entries.  It equals the true integer
    convolution only when every true coefficient is below MOD; use
    ``exact_convolution`` when that is not guaranteed.
    """
    size = 1
    while size < len(a) + len(b) - 1:
        size *= 2
    fa = a[:] + [0] * (size - len(a))
    fb = b[:] + [0] * (size - len(b))
    ntt(fa)
    ntt(fb)
    for i in range(size):
        fa[i] = fa[i] * fb[i] % MOD
    ntt(fa, invert=True)
    return fa[:len(a) + len(b) - 1]


def exact_convolution(a, b):
    """Return the exact linear convolution of two lists of non-negative ints.

    Uses Kronecker substitution:
    1. Pack each list into one big integer, giving every coefficient a fixed
       number of bytes.
    2. Multiply the two integers with Python's arbitrary-precision arithmetic.
    3. Slice the product back into coefficients.

    The byte width is large enough that no output coefficient can carry into
    its neighbour, so the result is exact at any magnitude.  Returns [] if
    either input is empty.

    Raises:
        ValueError: if any entry is negative.
    """
    if not a or not b:
        return []
    if min(a) < 0 or min(b) < 0:
        raise ValueError("exact_convolution requires non-negative integers")
    out_len = len(a) + len(b) - 1
    # Each output coefficient is a sum of at most min(len) products, and can
    # never exceed sum(a) * sum(b); take the tighter of the two bounds.
    bound = min(max(a) * max(b) * min(len(a), len(b)), sum(a) * sum(b))
    width = max(1, (bound.bit_length() + 7) // 8)
    packed_a = int.from_bytes(
        b"".join(v.to_bytes(width, "little") for v in a), "little"
    )
    packed_b = int.from_bytes(
        b"".join(v.to_bytes(width, "little") for v in b), "little"
    )
    raw = (packed_a * packed_b).to_bytes(width * out_len, "little")
    return [
        int.from_bytes(raw[i:i + width], "little")
        for i in range(0, len(raw), width)
    ]


def count_triples(n, b):
    """Return how many (x, y, z) in [0, b)^3 satisfy x**n + y**n ≡ z**n (mod b).

    Assumes ``b >= 1`` and ``n >= 0``.
    """
    # cnt[r] = number of x in [0, b) with x**n ≡ r (mod b).
    cnt = [0] * b
    for x in range(b):
        cnt[pow(x, n, b)] += 1
    # conv[s] = number of ordered pairs (x, y) whose residues sum to s.
    # Exact arithmetic is required: entries can exceed MOD for large b
    # (e.g. b=65536, n=16 gives conv[0] == 2**30), and a single-prime NTT
    # would silently wrap them.
    conv = exact_convolution(cnt, cnt)
    total = 0
    for z_res in range(b):
        # Residue sums lie in [0, 2b - 2], so s ≡ z_res (mod b) exactly when
        # s == z_res or s == z_res + b.
        pairs = conv[z_res]
        if z_res + b < len(conv):
            pairs += conv[z_res + b]
        total += pairs * cnt[z_res]
    return total


def main():
    """Read ``N B`` from stdin and print the number of valid triples."""
    n, b = map(int, input().split())
    print(count_triples(n, b))


if __name__ == "__main__":
    main()