def solve(n, m, mod=998244353):
    """Return sum_{i=1..m} sum_{k=1..n} i**k * m**(n-k), reduced modulo ``mod``.

    The value is evaluated as m**n * sum_{i=1..m} sum_{k=1..n} (i/m)**k.
    For each i < m the inner sum is the geometric series r + r**2 + ... + r**n
    with r = i/m (mod p), taken in closed form as r * (1 - r**n) / (1 - r).
    The i == m term has r == 1 and contributes exactly n.

    Args:
        n: number of terms in each geometric series (n >= 0).
        m: upper bound of the outer sum (m >= 1).
        mod: modulus; assumed prime, because inverses are taken via Fermat's
            little theorem. m is assumed not to be a multiple of ``mod``.

    Returns:
        The sum above as an int in [0, mod).
    """
    inv_m = pow(m, mod - 2, mod)
    total = 0
    for i in range(1, m):
        r = i * inv_m % mod
        if r == 1:
            # i ≡ m (mod p): the closed form would divide by zero (and
            # pow(0, mod-2, mod) silently yields 0). The series is n ones.
            total += n
        else:
            geom = r * (1 - pow(r, n, mod)) % mod
            total += geom * pow((1 - r) % mod, mod - 2, mod)
        total %= mod
    # The i == m term: ratio 1, so the series is n ones.
    total = (total + n) % mod
    return pow(m, n, mod) * total % mod


def main():
    """Read "n m" from stdin and print solve(n, m)."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()