"""Read ``N M`` from stdin and print a product over the prime powers of M.

For every prime power ``p**e`` in the factorization of ``M`` the script
multiplies together

    sum_{i=1..N} C(N, i) * e**(N - i)   ==   (e + 1)**N - e**N

and prints the product modulo 998244353.

NOTE(review): this looks like the count of length-N sequences of positive
integers whose LCM is M -- confirm against the problem statement.
"""

MOD = 998244353


def factorization(n: int) -> list:
    """Return the prime factorization of *n* as ``[[prime, exponent], ...]``.

    Primes appear in increasing order.  ``factorization(1)`` returns ``[]``.
    Intended for positive integers.

    Trial division uses integer arithmetic only, so it is exact for
    arbitrarily large *n* (no float square root), and the search stops as
    soon as the candidate divisor exceeds the square root of what is left.
    """
    factors = []
    remaining = n
    divisor = 2
    while divisor * divisor <= remaining:
        if remaining % divisor == 0:
            exponent = 0
            while remaining % divisor == 0:
                exponent += 1
                remaining //= divisor
            factors.append([divisor, exponent])
        divisor += 1
    # Whatever is left after stripping all divisors up to its square root
    # is itself a single prime factor.
    if remaining != 1:
        factors.append([remaining, 1])
    return factors


def _term_for_exponent(exponent: int, n: int) -> int:
    """Return ``sum_{i=1..n} C(n, i) * exponent**(n - i)`` modulo ``MOD``.

    The binomial theorem collapses the sum to
    ``(exponent + 1)**n - exponent**n``, so two modular exponentiations
    replace the quadratic loop that rebuilt each binomial coefficient.
    """
    if n <= 0:
        # The sum over i = 1..n is empty.
        return 0
    return (pow(exponent + 1, n, MOD) - pow(exponent, n, MOD)) % MOD


def solve(n: int, m: int) -> int:
    """Return the product of the per-prime terms of *m*, modulo ``MOD``.

    When *m* has no prime factors (``m == 1``) the empty product ``1`` is
    returned.
    """
    result = 1
    for _, exponent in factorization(m):
        result = result * _term_for_exponent(exponent, n) % MOD
    return result


def main() -> None:
    """Read ``N M`` from standard input and print the answer."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()