P = 998244353 N, M = map(int, input().split()) ans = 0 s = 0 for i in range(1, M + 1): ans += i * i * (i + 1) // 2 * pow(i, N - 1, P) s += M - i + 1 ans -= s * (M - i + 1) * pow(i, N - 1, P) if i < M: ans -= i * (i + 1) * (i + 1) // 2 * pow(i, N - 1, P) ans += s * (M - i + 0) * pow(i, N - 1, P) ans %= P ans *= N ans %= P print(ans)