"""Count N-tuples of positive integers whose least common multiple is M.

Reads ``N M`` from standard input and prints the count modulo 998244353.

For each prime power ``p**d`` exactly dividing M, every tuple element's
exponent of ``p`` lies in ``[0, d]`` and at least one must equal ``d``.
That gives ``(d + 1)**N - d**N`` choices per prime, independently, so the
answer is the product of those factors over M's factorization.
"""

MOD = 998244353


def _factorize(m):
    """Yield ``(prime, exponent)`` pairs for the factorization of ``m >= 1``.

    Uses trial division bounded by ``p * p <= m``: once no factor up to the
    square root of the remaining cofactor exists, that cofactor is prime.
    This stays correct for any ``m``, unlike a fixed divisor bound.
    """
    p = 2
    while p * p <= m:
        if m % p == 0:
            exponent = 0
            while m % p == 0:
                m //= p
                exponent += 1
            yield p, exponent
        # After 2, only odd candidates can be prime.
        p += 1 if p == 2 else 2
    if m > 1:
        # The remaining cofactor has no divisor <= its square root: it is prime.
        yield m, 1


def solve(n, m, mod=MOD):
    """Return the number of n-tuples of positive integers with lcm ``m``, mod ``mod``.

    Raises ValueError if ``m < 1`` (no positive integers have such an lcm,
    and the factorization loop would not terminate on 0).
    """
    if m < 1:
        raise ValueError("m must be a positive integer")
    result = 1
    for _, d in _factorize(m):
        # Exponent vectors in [0, d]^n whose maximum is exactly d.
        result = result * (pow(d + 1, n, mod) - pow(d, n, mod)) % mod
    return result


def main():
    """Read ``N M`` from stdin and print the answer."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()