"""Modular expected-value computation driven by a small counting DP.

Reads two integers ``N K`` from standard input and prints a single value
modulo 998244353.  The original problem statement is not part of this
file, so the combinatorial reading given in the comments below is derived
from the recurrence itself.
"""

MOD = 998244353


def solve(N, K):
    """Return the weighted DP sum divided by ``N**K``, modulo ``MOD``.

    The DP state ``ways[j]`` after ``i`` steps follows the recurrence
    ``j -> j - 1`` with weight ``j`` and ``j -> j + 1`` with weight
    ``N - j``, starting from ``ways[0] = 1``.  This can be read as counting
    length-``i`` sequences of picks among ``N`` items after which exactly
    ``j`` items have been picked an odd number of times.

    After ``K`` steps each state ``j`` contributes
    ``((K - j) // 2 + N) * ways[j]``; the total is multiplied by the
    modular inverse of ``N**K``.

    Args:
        N: Number of items (size of the state space is ``N + 1``).
        K: Number of steps.

    Returns:
        The result as an integer in ``[0, MOD)``.
    """
    # Only the previous row is ever read, so keep a single rolling row
    # instead of the full (K + 1) x (N + 1) table.
    ways = [0] * (N + 1)
    ways[0] = 1

    for _ in range(K):
        nxt = [0] * (N + 1)
        for j, w in enumerate(ways):
            if not w:
                # A zero entry contributes nothing to either neighbour.
                continue
            if j > 0:
                nxt[j - 1] = (nxt[j - 1] + w * j) % MOD
            if j < N:
                nxt[j + 1] = (nxt[j + 1] + w * (N - j)) % MOD
        ways = nxt

    # Inverse of N**K via Fermat's little theorem (MOD is prime).
    inv_total = pow(pow(N, K, MOD), MOD - 2, MOD)

    total = 0
    for j, w in enumerate(ways):
        # For j > K the floor division goes negative, but w is 0 there.
        total = (total + ((K - j) // 2 + N) * w) % MOD

    return total * inv_total % MOD


def main():
    """Read ``N K`` from standard input and print the answer."""
    N, K = map(int, input().split())
    print(solve(N, K))


if __name__ == "__main__":
    main()