MOD = 998244353


def factorization(n):
    """Return the prime factorization of ``n`` as ``[[prime, exponent], ...]``.

    Primes appear in increasing order. ``factorization(1)`` returns ``[]``.
    Uses trial division with an exact integer bound (``p * p <= rest``), so
    it works for integers of any size and stops as soon as the unfactored
    remainder has no divisor below its square root.

    Args:
        n: positive integer to factor.

    Returns:
        List of ``[prime, exponent]`` pairs.
    """
    factors = []
    rest = n
    p = 2
    # Integer comparison instead of a float sqrt: exact for big ints and the
    # bound shrinks as ``rest`` is divided down, giving an early exit.
    while p * p <= rest:
        if rest % p == 0:
            exp = 0
            while rest % p == 0:
                exp += 1
                rest //= p
            factors.append([p, exp])
        p += 1
    # Whatever survives trial division up to sqrt(rest) is itself prime.
    if rest != 1:
        factors.append([rest, 1])
    return factors


def count_sequences(n, m, mod=MOD):
    """Return the product over primes of ``(v+1)**n - v**n`` modulo ``mod``.

    For each prime dividing ``m`` with exponent ``v``, ``(v+1)**n`` counts
    exponent tuples in ``[0, v]**n`` and ``v**n`` counts those in
    ``[0, v-1]**n``. Their difference counts tuples whose maximum is exactly
    ``v``. The product is therefore the number of length-``n`` sequences of
    divisors of ``m`` whose LCM equals ``m``, reduced modulo ``mod``.

    Args:
        n: sequence length.
        m: target positive integer.
        mod: modulus for the result (defaults to 998244353).

    Returns:
        The count modulo ``mod`` (``1`` when ``m == 1``).
    """
    ans = 1
    for _, v in factorization(m):
        # Python's % yields a non-negative result even if the difference of
        # the two modular powers is negative.
        ans = ans * (pow(v + 1, n, mod) - pow(v, n, mod)) % mod
    return ans


def main():
    """Read ``N M`` from stdin and print the count modulo 998244353."""
    n, m = map(int, input().split())
    print(count_sequences(n, m))


if __name__ == "__main__":
    main()