"""Sum of (max(A) - min(A)) * sum(A) over every sequence A in [1, M]^N, mod 998244353.

Reads ``N M`` from stdin and prints the answer.
"""

MOD = 998244353


def tri(x):
    """Return the triangular number 1 + 2 + ... + x (0 when x == 0)."""
    return x * (x + 1) // 2


def solve(n, m, mod=MOD):
    """Return the sum over A in [1, m]^n of (max(A) - min(A)) * sum(A), modulo ``mod``.

    For a value range [lo, hi] of width w, the total element sum over all w**n
    sequences drawn from it is n * (lo + ... + hi) * w**(n - 1): each of the n
    positions takes each value exactly w**(n - 1) times.

    * Sequences whose max is exactly k are those in [1, k]^n minus those in
      [1, k-1]^n, so their total element sum is
      n * (tri(k) * k**(n-1) - tri(k-1) * (k-1)**(n-1)).
    * Sequences whose min is exactly k are those in [k, m]^n minus those in
      [k+1, m]^n, giving
      n * ((tri(m) - tri(k-1)) * (m-k+1)**(n-1) - (tri(m) - tri(k)) * (m-k)**(n-1)).

    Weighting each group by k and subtracting the min part from the max part
    yields the answer. Runs in O(m log n) time.
    """
    # Every power needed below is i**(n-1) for some 0 <= i <= m, so compute each
    # one once instead of four modular exponentiations per k.
    # pow(0, 0, mod) == 1, which matches the n == 1 case where every k**0 is 1.
    powers = [pow(i, n - 1, mod) for i in range(m + 1)]
    total = tri(m)
    ans = 0
    for k in range(1, m + 1):
        max_part = tri(k) * powers[k] - tri(k - 1) * powers[k - 1]
        min_part = (total - tri(k - 1)) * powers[m - k + 1] - (total - tri(k)) * powers[m - k]
        # Python's % always yields a value in [0, mod), even for a negative sum.
        ans = (ans + k * n * (max_part - min_part)) % mod
    return ans


def main():
    """Read ``N M`` from stdin and print ``solve(N, M)``."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()