MOD = 998244353


def factorize(n):
    """Return the prime factorization of ``n`` as ``[(prime, exponent), ...]``.

    Primes appear in increasing order. ``factorize(1)`` returns ``[]``.

    Args:
        n: Integer to factor; must be >= 1.

    Returns:
        List of ``(p, e)`` pairs with ``p`` prime and ``e >= 1`` such that
        the product of ``p ** e`` equals ``n``.

    Raises:
        ValueError: if ``n < 1``.
    """
    if n < 1:
        raise ValueError(f"factorize() requires n >= 1, got {n}")
    factors = []
    p = 2
    # Integer bound instead of float sqrt: exact for arbitrarily large n,
    # and it shrinks as n is divided down, so the scan stops early.
    while p * p <= n:
        if n % p == 0:
            # Composite p can never divide here: its smaller prime factors
            # were already divided out of n on earlier iterations.
            e = 0
            while n % p == 0:
                n //= p
                e += 1
            factors.append((p, e))
        p += 1
    # Whatever remains has no factor <= its square root, so it is 1 or prime.
    if n != 1:
        factors.append((n, 1))
    return factors


def solve(n, m, mod=MOD):
    """Return the product over primes p**e || m of ((e+1)**n - e**n), mod ``mod``.

    For each prime power, (e+1)**n - e**n counts the exponent vectors in
    [0, e]**n whose maximum is exactly e. The product therefore counts
    n-tuples of divisors of m whose lcm equals m.

    Args:
        n: Tuple length; assumed >= 0.
        m: Positive integer to factor.
        mod: Modulus applied to the result.

    Returns:
        The count reduced into ``range(mod)``.
    """
    ans = 1
    for _, e in factorize(m):
        # Three-argument pow keeps every intermediate below mod, so huge n
        # runs in O(log n) instead of building enormous integers.
        ans = ans * (pow(e + 1, n, mod) - pow(e, n, mod)) % mod
    return ans


def main():
    """Read ``n m`` from stdin and print ``solve(n, m)``."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()