#include #include #include using namespace std; using ll = long long; #include using mint = atcoder::modint998244353; int main(){ cin.tie(nullptr); ios::sync_with_stdio(false); ll n,m; cin>>n>>m; mint ans = 0; mint inv = mint(2).inv(); mint sum = 0; for(int i = 1;i<=m;i++){ mint p = mint(i) * mint(i-1) * inv; p *= n; p *= i; p *= mint(i).pow(n-1) - mint(i-1).pow(n-1); ans += p; p = n; p *= i; p *= i; p *= mint(i).pow(n-1); ans += p; //cout<=1;i--){ mint p = sum; p *= n; p *= i; p *= mint(m-i+1).pow(n-1) - mint(m-i).pow(n-1); ans -= p; p = n; p *= i; p *= mint(m-i+1).pow(n-1); p *= i; ans -= p; sum += i; } cout<