N,M = map(int,input().split())
mod = 998244353
ans = 1
for i in range(2,10**5):
	if M%i==0:
		cnt = 0
		while M%i==0: M//=i ; cnt += 1
		ans *= pow(cnt+1,N,mod)-pow(cnt,N,mod)
		ans %= mod
if M!=1: ans *= pow(2,N,mod)-1 ; ans%=mod
print(ans)