// Reads n (sequence length) and m (value bound) and prints, modulo 998244353,
// the sum over every sequence of length n with entries in [1, m] of
// (max - min) * (sum of entries).
//
// Derivation: over all k^n sequences drawn from a set of k values whose sum is
// S, the entries add up to n * k^(n-1) * S in total. Differencing that quantity
// over the value ranges [1, i] vs [1, i-1] isolates sequences whose maximum is
// exactly i; differencing over [i, m] vs [i+1, m] isolates sequences whose
// minimum is exactly i. The common factor n is applied once at the end.
#include <cstdio>
#include <vector>

using LL = long long;

namespace {

constexpr LL kMod = 998244353;

// Returns x^p mod kMod by binary exponentiation. Any p <= 0 returns 1, so a
// non-positive exponent can never loop forever (right-shifting -1 stays -1).
LL pow_mod(LL x, LL p) {
    x %= kMod;
    if (x < 0) x += kMod;
    LL result = 1;
    while (p > 0) {
        if (p & 1) result = result * x % kMod;
        x = x * x % kMod;
        p >>= 1;
    }
    return result;
}

// Computes the answer for sequence length n and value bound m.
// Returns a value in [0, kMod); returns 0 when n <= 0 or m <= 0.
LL solve(LL n, int m) {
    if (n <= 0 || m <= 0) return 0;

    // prefix[i] = 1 + 2 + ... + i (mod kMod); power[i] = i^(n-1) (mod kMod).
    // power[0] stays 0: every use of it is multiplied by a zero range sum
    // (prefix[0] at i == 1, prefix[m] - prefix[m] at i == m).
    std::vector<LL> prefix(m + 1, 0);
    std::vector<LL> power(m + 1, 0);
    for (int i = 1; i <= m; ++i) {
        prefix[i] = (prefix[i - 1] + i) % kMod;
        power[i] = pow_mod(i, n - 1);
    }

    LL ans = 0;
    for (int i = 1; i <= m; ++i) {
        // Entry totals (without the factor n) for sequences with
        // max <= i and max <= i - 1; their difference is "max == i".
        const LL max_le_i = prefix[i] * power[i] % kMod;
        const LL max_le_prev = prefix[i - 1] * power[i - 1] % kMod;
        const LL max_term = (max_le_i - max_le_prev + kMod) % kMod * i % kMod;

        // Entry totals for sequences with min >= i and min >= i + 1; their
        // difference is "min == i". Both products are reduced before the
        // subtraction so every intermediate stays non-negative and the
        // running total never leaves [0, kMod).
        const LL min_ge_i =
            (prefix[m] - prefix[i - 1] + kMod) % kMod * power[m - i + 1] % kMod;
        const LL min_ge_next =
            (prefix[m] - prefix[i] + kMod) % kMod * power[m - i] % kMod;
        const LL min_term = (min_ge_i - min_ge_next + kMod) % kMod * i % kMod;

        // Contribution of i is i * (#max==i total) - i * (#min==i total).
        ans = (ans + max_term - min_term + kMod) % kMod;
    }
    return n % kMod * ans % kMod;
}

}  // namespace

int main() {
    LL n = 0;
    int m = 0;
    if (std::scanf("%lld %d", &n, &m) != 2) return 1;
    std::printf("%lld\n", solve(n, m));
    return 0;
}