MOD = 998244353  # prime modulus; primality is what makes pow(x, MOD - 2, MOD) an inverse


def _solve(n, k):
    """Return (k^n - S) / 2 modulo MOD.

    S = sum over j in 0..k-1 of C(k, j) * (k - j)^n * (-2)^j, all taken
    modulo MOD. The j == k term is deliberately treated as zero, even for
    n == 0 where pow(0, 0, MOD) would be 1.

    Args:
        n: Exponent applied to (k - j) and to k.
        k: Upper bound of the sum; must be non-negative.

    Returns:
        The value described above, in the range [0, MOD). Returns 0 when
        k == 0.

    Raises:
        ValueError: If k is negative.
    """
    if k < 0:
        raise ValueError("K must be non-negative")
    if k == 0:
        return 0

    minus_two = MOD - 2  # -2 reduced modulo MOD
    binom = 1            # running C(k, j), starts at C(k, 0)
    neg_two_pow = 1      # running (-2)^j, starts at (-2)^0
    total = 0

    # Stop at j = k - 1: the j == k term contributes nothing by definition.
    for j in range(k):
        if j:
            # C(k, j) = C(k, j - 1) * (k - j + 1) / j, with the division done
            # through the Fermat inverse of j.
            binom = binom * (k - j + 1) % MOD
            binom = binom * pow(j, MOD - 2, MOD) % MOD
            neg_two_pow = neg_two_pow * minus_two % MOD
        term = binom * pow(k - j, n, MOD) % MOD
        term = term * neg_two_pow % MOD
        total = (total + term) % MOD

    difference = (pow(k, n, MOD) - total) % MOD
    half = pow(2, MOD - 2, MOD)  # modular inverse of 2
    return difference * half % MOD


def main():
    """Read N and K from standard input and print the result.

    The first two whitespace-separated tokens on stdin are parsed as the
    integers N and K; any further tokens are ignored. Indexing raises
    IndexError when fewer than two tokens are supplied, and int() raises
    ValueError when a token is not an integer.
    """
    import sys

    tokens = sys.stdin.read().split()
    n = int(tokens[0])
    k = int(tokens[1])
    print(_solve(n, k))


if __name__ == '__main__':
    main()