P = 998244353 N, M = map(int, input().split()) ans = 0 s = 0 for i in range(1, M + 1): po = pow(i, N - 1, P) ans += i * (i + 1) // 2 % P * i % P * po s += M - i + 1 s %= P ans -= s * (M - i + 1) % P * po if i < M: ans -= i * (i + 1) // 2 % P * (i + 1) % P * po ans += s * (M - i + 0) % P * po ans %= P ans *= N ans %= P print(ans)