from collections import defaultdict

mod = 998244353


def factorize(m):
    """Return the prime factorization of *m* as a ``{prime: exponent}`` dict.

    Uses trial division: the factor 2 is stripped first, then odd candidates
    up to sqrt(remaining m). Whatever is left above 1 is a single prime.
    ``factorize(1)`` returns ``{}``.

    Raises:
        ValueError: if ``m < 1``. The original loop never terminated for
            ``m == 0`` (0 is divisible by 2 forever) and gave meaningless
            results for negative ``m``.
    """
    if m < 1:
        raise ValueError(f"m must be a positive integer, got {m}")
    factors = {}
    cnt = 0
    while m % 2 == 0:
        m //= 2
        cnt += 1
    if cnt:
        factors[2] = cnt
    f = 3
    while f * f <= m:
        cnt = 0
        while m % f == 0:
            m //= f
            cnt += 1
        if cnt:
            factors[f] = cnt
        f += 2  # only odd candidates remain once 2 has been removed
    if m > 1:
        # No factor <= sqrt(m) divides the remainder, so it is prime.
        factors[m] = 1
    return factors


def count_sequences(n, m):
    """Return prod over p^e || m of ((e+1)^n - e^n), modulo ``mod``.

    For a prime p with exponent e in m, there are (e+1)^n ways to pick an
    exponent in [0, e] for each of n positions. Subtracting the e^n choices
    that never reach e leaves the tuples whose maximum is exactly e.
    Multiplying over the primes therefore counts length-n sequences of
    positive integers whose lcm is exactly m. The result is 1 when m == 1.

    Raises:
        ValueError: if ``m < 1`` (propagated from :func:`factorize`).
    """
    ans = 1
    for e in factorize(m).values():
        # The subtraction can go negative before reduction; Python's %
        # always returns a non-negative result, so this is safe.
        ans = ans * ((pow(e + 1, n, mod) - pow(e, n, mod)) % mod) % mod
    return ans


def main():
    """Read ``n m`` from one line of stdin and print the count."""
    n, m = map(int, input().split())
    print(count_sequences(n, m))


if __name__ == "__main__":
    main()