import sys

# --- fast-input / output template helpers ---------------------------------
# readline() keeps the trailing newline (and returns no newline on a final
# line that lacks one), so strip only a trailing "\n" rather than blindly
# chopping the last character with [:-1], which would eat a digit.
input = lambda: sys.stdin.readline().rstrip("\n")
ni = lambda: int(input())
na = lambda: list(map(int, input().split()))
yes = lambda: print("yes")
Yes = lambda: print("Yes")
YES = lambda: print("YES")
no = lambda: print("no")
No = lambda: print("No")
NO = lambda: print("NO")
#######################################################################
mod = 998244353


def pow_sum(r: int, n: int) -> int:
    """Return the geometric sum 1 + r + r**2 + ... + r**(n-1) modulo ``mod``.

    Uses the closed form (r**n - 1) / (r - 1), with the division done via
    Fermat's little theorem (``mod`` is prime). When r == 1 the closed form
    divides by zero, so the sum is simply n terms of 1.
    """
    if r == 1:
        return n % mod
    return (pow(r, n, mod) - 1) * pow(r - 1, mod - 2, mod) % mod


def solve(n: int, m: int) -> int:
    """Return sum_{i=1}^{n} sum_{x=1}^{m} x**i * m**(n-i) modulo ``mod``.

    Each term equals m**n * (x/m)**i, so for a fixed x the sum over i is a
    geometric series with ratio r = x/m (a modular fraction):
        r + r**2 + ... + r**n = r * pow_sum(r, n).
    The total is then m**n times the sum of these series over x.

    NOTE(review): assumes m is not a multiple of ``mod``; otherwise m has no
    modular inverse and the result would be wrong. Confirm against the
    problem constraints.
    """
    inv_m = pow(m, mod - 2, mod)  # loop-invariant: compute the inverse once
    total = 0
    for x in range(1, m + 1):
        r = x * inv_m % mod
        total = (total + r * pow_sum(r, n)) % mod
    return total * pow(m, n, mod) % mod


def main() -> None:
    """Read "n m" from stdin and print the answer."""
    n, m = na()
    print(solve(n, m))


if __name__ == "__main__":
    main()