"""Count ordered n-tuples of positive integers whose LCM is exactly m.

Reads ``n m`` from stdin and prints the count modulo 998244353.

Write m = prod p_i^e_i. A tuple has LCM m iff every entry divides m and, for
each prime p_i, at least one entry carries the full exponent e_i. Exponents of
different primes are chosen independently, so the answer is
prod_i ((e_i + 1)^n - e_i^n): all exponent choices in [0, e_i] minus those that
never reach e_i.
"""

from collections import defaultdict

MOD = 998244353


def _factorize(m):
    """Return {prime: exponent} for an integer m >= 1 (empty for m == 1).

    Plain trial division: strip factors of 2, then odd candidates up to
    sqrt(m); any remainder greater than 1 is a single leftover prime.
    The caller must guarantee m >= 1, otherwise the first loop never ends.
    """
    exponents = defaultdict(int)
    while m % 2 == 0:
        m //= 2
        exponents[2] += 1
    f = 3
    while f * f <= m:
        while m % f == 0:
            m //= f
            exponents[f] += 1
        f += 2
    if m > 1:
        # No divisor <= sqrt(remainder) exists, so the remainder is prime.
        exponents[m] += 1
    return dict(exponents)


def count_lcm_tuples(n, m, mod=MOD):
    """Return the number of ordered n-tuples of positive ints with LCM m, modulo ``mod``.

    Args:
        n: tuple length, n >= 0 (the empty tuple has LCM 1).
        m: target LCM, m >= 1.
        mod: modulus applied to the result; defaults to 998244353.

    Returns:
        prod over prime powers p^e || m of ((e + 1)^n - e^n), reduced mod ``mod``.

    Raises:
        ValueError: if m < 1 (factorizing 0 would loop forever) or n < 0
            (three-argument pow would compute modular inverses instead).
    """
    if m < 1:
        raise ValueError(f"m must be a positive integer, got {m}")
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    ans = 1 % mod
    for e in _factorize(m).values():
        # Exponent tuples within [0, e] minus those confined to [0, e - 1].
        ans = ans * (pow(e + 1, n, mod) - pow(e, n, mod)) % mod
    return ans


def main():
    """Read ``n m`` from stdin and print the answer modulo 998244353."""
    n, m = map(int, input().split())
    print(count_lcm_tuples(n, m))


if __name__ == "__main__":
    main()