import sys input = sys.stdin.readline MOD = 998244353 N, M = map(int, input().split()) ans = 0 mn = pow(M, N, MOD) for i in range(1, M+1): if i==M: ans += N ans %= MOD else: tmp = i*(pow(i, N, MOD) - mn)%MOD * pow(mn, MOD-2, MOD) % MOD * pow(i-M, MOD-2, MOD) % MOD ans += tmp ans %= MOD ans *= mn ans %= MOD print(ans)