def factorization(n):
    """Return the prime factorization of ``n`` by trial division.

    The result is a list of ``[prime, exponent]`` pairs in increasing prime
    order, e.g. ``factorization(360) == [[2, 3], [3, 2], [5, 1]]``.
    ``factorization(1)`` returns ``[]``.

    Raises:
        ValueError: if ``n`` is less than 1 (0 and negatives have no prime
            factorization).
    """
    if n < 1:
        raise ValueError("n must be a positive integer")

    arr = []
    temp = n
    i = 2
    # Integer bound instead of a float sqrt: exact for arbitrarily large n,
    # and it shrinks as temp is divided down, so we stop as soon as the
    # remaining cofactor can no longer have a factor <= sqrt(temp).
    while i * i <= temp:
        if temp % i == 0:
            cnt = 0
            while temp % i == 0:
                cnt += 1
                temp //= i
            arr.append([i, cnt])
        i += 1

    # Whatever is left (> 1) has no divisor <= its square root, so it is a
    # prime larger than every factor found above.
    if temp != 1:
        arr.append([temp, 1])

    return arr


N, M = map(int, input().split())
ans = 1
mod = 998244353
# For each prime power p**v in the factorization of M, multiply in
# (v + 1)**N - v**N (mod `mod`). That is the number of length-N exponent
# sequences over {0, ..., v} in which v appears at least once: all
# sequences minus those drawn only from {0, ..., v - 1}. It is also the
# binomial-theorem closed form of sum_{i=1}^{N} C(N, i) * v**(N - i).
# NOTE(review): the product over primes therefore counts N-tuples of
# divisors of M whose lcm is M -- presumably the intended problem; confirm.
for _, v in factorization(M):
    # Reduce into [0, mod) first so the factor is never negative.
    term = (pow(v + 1, N, mod) - pow(v, N, mod)) % mod
    ans = ans * term % mod

print(ans)