MOD = 998244353 N, K = map(int, input().split()) K_ = pow(K, MOD-2, MOD) ans = pow(K_, N, MOD) ans *= N ans %= MOD for i in range(K-1, K+1): ans *= i ans %= MOD print(ans)