"""Read two integers ``N M`` from stdin and print a closed-form sum mod 998244353.

For every k in 1..M the script adds ``fix_max(k)`` and subtracts
``fix_min(k)``.  Read combinatorially, over all length-N sequences with
entries in 1..M this is the sum of ``(max - min) * (sum of entries)``.
"""

MOD = 998244353


def tri(x: int) -> int:
    """Return the triangular number 1 + 2 + ... + x (0 when x == 0)."""
    return x * (x + 1) // 2


def solve(n: int, m: int) -> int:
    """Return sum_{k=1..m} (fix_max(k) - fix_min(k)) modulo ``MOD``.

    Args:
        n: Sequence length; used as a multiplier and as the exponent
            ``n - 1``.  Assumes n >= 1 so that the exponent is
            non-negative.
        m: Largest allowed value; k ranges over 1..m.

    Returns:
        The result reduced into the range ``0 <= result < MOD``.
    """
    # pow_table[k] == k ** (n - 1) mod MOD.  Index 0 is needed because the
    # k == 1 and k == m terms refer to pow_table[0].
    pow_table = [pow(k, n - 1, MOD) for k in range(m + 1)]
    tri_m = tri(m)  # loop-invariant, so computed once

    def fix_max(k: int) -> int:
        # n * tri(k) * k**(n-1) is the total of all entries over sequences
        # whose values are all <= k.  Subtracting the same quantity for
        # k - 1 keeps only the sequences whose maximum is exactly k; the
        # leading factor k weights them by that maximum.
        exactly_k = tri(k) * pow_table[k] - tri(k - 1) * pow_table[k - 1]
        return k * n * exactly_k % MOD

    def fix_min(k: int) -> int:
        # Mirror image of fix_max: values restricted to k..m (m - k + 1
        # choices, summing to tri(m) - tri(k - 1)) minus values restricted
        # to k+1..m leaves the sequences whose minimum is exactly k.
        at_least_k = (tri_m - tri(k - 1)) * pow_table[m - k + 1]
        above_k = (tri_m - tri(k)) * pow_table[m - k]
        return k * n * (at_least_k - above_k) % MOD

    # Python's % always yields a non-negative result, so one final
    # reduction is enough even though individual differences can be negative.
    return sum(fix_max(k) - fix_min(k) for k in range(1, m + 1)) % MOD


def main() -> None:
    """Parse ``N M`` from one line of stdin and print the answer."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()