MOD = 998244353


def solve(n: int, k: int, mod: int = MOD) -> int:
    """Return sum_{i=0}^{k-1} (f(i) - f(i + 1)) * (k - i) modulo ``mod``.

    Here f(i) = (k - i)**n + n * i * (k - i)**(n - 1), and f(k) is taken
    to be 0. Combinatorially, f(i) counts length-n sequences over
    {1, ..., k} with at most one entry <= i: either all n entries are > i,
    or exactly one of the n positions holds one of the i small values.
    So f(i) - f(i + 1) counts sequences whose second-smallest entry is
    i + 1. The result is therefore the sum of (k + 1 - second_smallest)
    over all k**n sequences.

    NOTE(review): the problem statement is not visible here. The
    combinatorial reading above holds for n >= 2, which is presumably
    guaranteed by the constraints (TODO confirm). For n == 1 the formula
    gives f(k) = k rather than 0. Like the original script, this function
    still treats f(k) as 0 and so returns k.

    Args:
        n: sequence length (the exponent in f).
        k: alphabet size; k <= 0 yields 0.
        mod: modulus applied to the result (defaults to 998244353).

    Returns:
        The sum reduced modulo ``mod``.
    """
    # f[i] as defined in the docstring. Three-argument pow keeps every
    # power reduced modulo ``mod``.
    f = [
        (pow(k - i, n, mod) + n * i * pow(k - i, n - 1, mod)) % mod
        for i in range(k)
    ]
    ans = 0
    for i in range(k):
        # The last entry has no successor, so f(k) is treated as 0.
        nxt = f[i + 1] if i + 1 < k else 0
        # Python's % maps a negative difference back into [0, mod).
        ans = (ans + (f[i] - nxt) * (k - i)) % mod
    return ans


def main() -> None:
    """Read "n k" from one line of stdin and print solve(n, k)."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()