N, M = map(int, input().split()) mod = 998244353 ans = 0 for i in range(1, M + 1): ans += i * (i*(i+1)//2 * pow(i, N - 1, mod) * N) ans -= i * (i*(i-1)//2 * pow(i - 1, N - 1, mod) * N) ans += i * ((M*(M+1)//2-i*(i+1)//2) * pow(M - i, N - 1, mod) * N) ans -= i * ((M*(M+1)//2-i*(i-1)//2) * pow(M - i + 1, N - 1, mod) * N) ans %= mod print(ans)