"""Sum of (max(A) - min(A)) * sum(A) over all A in {1..m}^n, modulo 998244353.

Reads ``n m`` from stdin and prints the result.
"""

MOD = 998244353


def solve(n, m):
    """Return sum over all A in {1..m}^n of (max(A) - min(A)) * sum(A), mod MOD.

    Counting argument: the sequences whose entries all lie in a run of
    ``count`` consecutive integers lo..hi number count**n, and their element
    sums add up to n * count**n * (lo + hi) / 2.  Differencing two such totals
    isolates the sequences whose maximum (resp. minimum) is exactly k;
    weighting by k and summing over k yields sum(A) * max(A)
    (resp. sum(A) * min(A)).

    Args:
        n: sequence length (non-negative).
        m: largest allowed value (non-negative; m == 0 gives 0).

    Returns:
        The total reduced modulo MOD, in the range [0, MOD).
    """
    inv2 = pow(2, MOD - 2, MOD)  # inverse of 2 via Fermat; MOD is prime
    half_n = n * inv2 % MOD
    # Only bases 0..m are ever indexed below.  Sizing the table by m (instead
    # of a fixed 5*10**5 + 10) avoids IndexError for large m and wasted work
    # for small m.
    pows = [pow(i, n, MOD) for i in range(m + 1)]

    def box_sum(count, lo_plus_hi):
        # Total of sum(A) over all A whose entries range over `count`
        # consecutive values whose first + last equals `lo_plus_hi`.
        return pows[count] * lo_plus_hi % MOD * half_n % MOD

    ans = 0
    for k in range(1, m + 1):
        # sum(A) over A with max(A) == k: values in 1..k minus values in 1..k-1.
        with_max_k = box_sum(k, k + 1) - box_sum(k - 1, k)
        # sum(A) over A with min(A) == k: values in k..m minus values in k+1..m.
        with_min_k = box_sum(m - k + 1, k + m) - box_sum(m - k, k + m + 1)
        # Python's % keeps the accumulator non-negative even if the
        # difference is negative.
        ans = (ans + k * (with_max_k - with_min_k)) % MOD
    return ans


def main():
    """Read ``n m`` from stdin and print solve(n, m)."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()