"""Count ordered n-tuples of positive integers whose lcm is exactly m.

Reads two integers ``n`` and ``m`` from standard input (on one line or on
separate lines) and prints the count modulo 998244353.
"""
import sys

MOD = 998244353


def count_lcm_tuples(n, m, mod=MOD):
    """Return prod over prime powers p^c exactly dividing m of ((c+1)^n - c^n), mod *mod*.

    Every member of such a tuple divides m, so for each prime p with
    exponent c in m, a member's exponent of p lies in [0, c]. The lcm
    reaches p^c only if at least one member uses exponent c:
    (c + 1)^n total choices minus c^n choices that never reach c. Distinct
    primes are independent, so the per-prime counts multiply.

    For m <= 1 the product is empty and 1 is returned.
    """
    ans = 1
    p = 2
    while p * p <= m:
        if m % p == 0:
            c = 0
            while m % p == 0:
                m //= p
                c += 1
            # Python's % with a positive modulus is already non-negative,
            # so no "+ mod" correction is needed for the subtraction.
            ans = ans * (pow(c + 1, n, mod) - pow(c, n, mod)) % mod
        # Trial divisors: 2, then only odd numbers.
        p += 1 if p == 2 else 2
    if m > 1:
        # Whatever remains is a single prime with exponent 1: 2^n - 1^n.
        ans = ans * (pow(2, n, mod) - 1) % mod
    return ans


def main():
    """Read ``n m`` from stdin and print the answer."""
    # Reading all tokens tolerates n and m being on separate lines.
    n, m = map(int, sys.stdin.read().split()[:2])
    print(count_lcm_tuples(n, m))


if __name__ == "__main__":
    main()