#include <cstdint>
#include <iostream>
#include <vector>

namespace {

// Modulus used for every arithmetic result in this program.
constexpr std::uint64_t kMod = 998244353;

// Returns base^exp mod kMod by binary exponentiation.
// exp <= 0 yields 1 (callers only rely on exp >= 0 being meaningful).
std::uint64_t pow_mod(std::uint64_t base, long long exp) {
    std::uint64_t result = 1;
    base %= kMod;
    while (exp > 0) {
        if (exp & 1) {
            result = result * base % kMod;  // both factors < 2^30, product fits in 64 bits
        }
        base = base * base % kMod;
        exp >>= 1;
    }
    return result;
}

// Computes, modulo 998244353, the sum over all length-N sequences A with
// entries in [1, M] of sum(A) * (max(A) - min(A)).
//
// X[i] = sum(terms of A; A = (N integers between 1 and i (incl.)))
//      = N * i^(N-1) * (1 + 2 + ... + i)
// Y[i] = sum(terms of A; A = (N integers between M - i + 1 and M (incl.)))
//      = N * i^N * (M + 1) - X[i]
// X[i] - X[i-1] is the term-sum over sequences whose maximum is exactly i, so
// sum_i i * (X[i] - X[i-1]) = sum_A sum(A) * max(A). With Y[M+1] = 0,
// sum_i i * (Y[i] - Y[i+1]) telescopes to sum_i Y[i] = sum_A sum(A) * min(A).
//
// Returns 0 when N <= 0 or M <= 0 (every X[i], Y[i] is 0 / no sequences).
std::uint64_t solve(long long N, int M) {
    if (N <= 0 || M <= 0) {
        return 0;
    }
    const std::uint64_t n_mod = static_cast<std::uint64_t>(N % static_cast<long long>(kMod));
    const std::uint64_t m_plus_1 = (static_cast<std::uint64_t>(M) + 1) % kMod;

    // Index 0 and M+1 stay 0; they are the telescoping boundary terms.
    const std::size_t size = static_cast<std::size_t>(M) + 2;
    std::vector<std::uint64_t> X(size, 0);
    std::vector<std::uint64_t> Y(size, 0);

    std::uint64_t s = 0;  // s = 1 + 2 + ... + i (mod kMod)
    for (int i = 1; i <= M; ++i) {
        const std::uint64_t iv = static_cast<std::uint64_t>(i) % kMod;
        s = (s + iv) % kMod;
        const std::uint64_t pow_n_minus_1 = pow_mod(iv, N - 1);  // i^(N-1)
        const std::uint64_t pow_n = pow_n_minus_1 * iv % kMod;    // i^N, reusing i^(N-1)
        X[i] = pow_n_minus_1 * n_mod % kMod * s % kMod;
        const std::uint64_t all_terms = pow_n * n_mod % kMod * m_plus_1 % kMod;
        Y[i] = (all_terms + kMod - X[i]) % kMod;
    }

    std::uint64_t ans = 0;
    for (int i = 1; i <= M; ++i) {
        const std::uint64_t iv = static_cast<std::uint64_t>(i) % kMod;
        const std::uint64_t max_part = iv * ((X[i] + kMod - X[i - 1]) % kMod) % kMod;
        const std::uint64_t min_part = iv * ((Y[i] + kMod - Y[i + 1]) % kMod) % kMod;
        ans = (ans + max_part) % kMod;
        ans = (ans + kMod - min_part) % kMod;
    }
    return ans;
}

}  // namespace

// Reads N (sequence length) and M (upper bound of entries) from stdin and
// prints solve(N, M). Exits with status 1 without output if input is malformed.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    long long N = 0;
    int M = 0;
    if (!(std::cin >> N >> M)) {
        return 1;  // avoid computing on unread (previously uninitialized) values
    }
    std::cout << solve(N, M) << "\n";
    return 0;
}