mod = 998244353
eps = 10**-9  # not referenced by any code in this file; kept so the module's names are unchanged


def solve(n, m):
    """Return sum over sequences of (max - min) * (sum of entries), modulo ``mod``.

    The sum ranges over every sequence of length ``n`` whose entries are
    integers in ``1..m``.  It is evaluated in O(m) modular powers using two
    counting identities:

    * the total of "sum of entries" over all sequences with every entry
      ``<= j`` is ``n * (1 + ... + j) * j**(n - 1)``;
    * the total of "sum of entries" over all sequences with every entry
      ``>= j`` is ``n * (j + ... + m) * (m - j + 1)**(n - 1)``.

    Differencing consecutive values of each quantity isolates the sequences
    whose maximum (respectively minimum) is exactly ``i``.

    Args:
        n: Sequence length.  The formula is meaningful for ``n >= 1``; for
            ``n < 1`` with ``m >= 1`` the built-in ``pow`` raises
            ``ValueError`` because 0 has no modular inverse.
        m: Largest allowed entry.  For ``m < 1`` there is nothing to add up
            and the result is 0.

    Returns:
        The value described above as an int in ``[0, mod)``.
    """
    if m < 1:
        # No admissible entry values, so the accumulation loop would be empty.
        return 0

    # powers[j] = j**(n - 1) mod p.  Every power the formula needs lies in
    # this table, so each one is computed exactly once.
    powers = [pow(base, n - 1, mod) for base in range(m + 1)]

    # tri[j] = 1 + 2 + ... + j
    tri = [j * (j + 1) // 2 for j in range(m + 1)]
    full = tri[m]

    # bounded_above[j]: total of sum(A) over sequences with all entries <= j.
    bounded_above = [n * tri[j] % mod * powers[j] % mod for j in range(m + 1)]

    # bounded_below[j]: total of sum(A) over sequences with all entries >= j,
    # for j in 1..m+1 (index 0 is an unused placeholder).  The entry at m + 1
    # is 0 because no value is >= m + 1.
    bounded_below = [0] + [
        n * (full - tri[j - 1]) % mod * powers[m - j + 1] % mod
        for j in range(1, m + 2)
    ]

    answer = 0
    for i in range(1, m + 1):
        # Total of sum(A) over sequences whose maximum is exactly i.
        with_max_i = bounded_above[i] - bounded_above[i - 1]
        # Total of sum(A) over sequences whose minimum is exactly i.
        with_min_i = bounded_below[i] - bounded_below[i + 1]
        answer = (answer + i * (with_max_i - with_min_i)) % mod
    return answer


def main():
    """Read ``N M`` from the first line of stdin and print ``solve(N, M)``."""
    import sys

    n, m = map(int, sys.stdin.readline().split())
    print(solve(n, m))


if __name__ == '__main__':
    main()