"""Sum of (max - min) * (sum of elements) over all integer sequences.

Reads ``N M`` from stdin and prints, modulo 998244353, the total of
``(max(a) - min(a)) * sum(a)`` over every sequence ``a`` of length ``N``
whose entries are integers in ``1..M``.
"""

mod = 998244353
eps = 10**-9  # unused; kept so any external reference to the name still works


def solve(n, m, modulus=mod):
    """Return sum((max(a) - min(a)) * sum(a)) over a in [1..m]^n, mod *modulus*.

    By symmetry every position contributes equally to ``sum(a)``, so the
    total equals ``n`` times the same quantity with ``sum(a)`` replaced by
    the first element ``a[0]``.  That quantity splits into
    ``sum(max * a[0]) - sum(min * a[0])``; each part is computed by grouping
    the sequences on the exact value of their maximum (resp. minimum).

    Args:
        n: sequence length.
        m: largest allowed entry (entries range over 1..m).
        modulus: modulus for the result (defaults to 998244353).

    Returns:
        The total reduced into ``range(modulus)``; 0 when ``n <= 0`` or
        ``m <= 0``.
    """
    if n <= 0 or m <= 0:
        # No non-empty sequences to sum over; also, pow(0, n - 1, modulus)
        # below would raise for n == 0 (0 has no modular inverse).
        return 0

    total = m * (m + 1) // 2  # 1 + 2 + ... + m
    # powers[k] = k ** (n - 1): ways to fill the other n - 1 positions
    # when each may take any of k values.
    powers = [pow(k, n - 1, modulus) for k in range(m + 1)]

    def at_most(i):
        # Sum of a[0] over sequences whose entries all lie in 1..i.
        return (i * (i + 1) // 2) % modulus * powers[i] % modulus

    def at_least(i):
        # Sum of a[0] over sequences whose entries all lie in i..m.
        if i > m:
            return 0
        return (total - (i - 1) * i // 2) % modulus * powers[m - i + 1] % modulus

    acc = 0
    for i in range(1, m + 1):
        # Sum of a[0] over sequences whose maximum is exactly i ...
        max_is_i = at_most(i) - at_most(i - 1)
        # ... and over sequences whose minimum is exactly i.
        min_is_i = at_least(i) - at_least(i + 1)
        acc = (acc + i * (max_is_i - min_is_i)) % modulus
    return acc * n % modulus


def main():
    """Read ``N M`` from stdin and print ``solve(N, M)``."""
    import sys

    n, m = map(int, sys.stdin.readline().split())
    print(solve(n, m))


if __name__ == '__main__':
    main()