import java.io.BufferedInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.IntUnaryOperator;
import java.util.function.LongUnaryOperator;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static java.lang.Math.min;

/**
 * Entry point: reads {@code n} and {@code k} from standard input and prints
 * {@code 2 * (n-1)! * sum_{d=1}^{n-1} (n - d) * d^k} modulo {@link #mod}.
 */
public class Main {
    static In in = new In();
    static Out out = new Out(false, false);
    static final long inf = 0x1fffffffffffffffL;
    static final int iinf = 0x3fffffff;
    static final double eps = 1e-9;
    static long mod = 998244353;

    /** Reads one test case ({@code n}, then {@code k}) and prints the answer. */
    void solve() {
        long n = in.nextLong();
        int k = in.nextInt();
        out.println(answer(n, k));
    }

    /**
     * Computes {@code 2 * (n-1)! * (n * S_k(n-1) - S_{k+1}(n-1))} modulo {@link #mod}, where
     * {@code S_e(m) = sum_{i=1}^{m} i^e}. This equals
     * {@code 2 * (n-1)! * sum_{d=1}^{n-1} (n - d) * d^k}.
     *
     * @param n number of elements; expected to be at least 1
     * @param k exponent; expected to be non-negative
     * @return the value above modulo {@link #mod}, or 0 when {@code n <= 1} (empty sum)
     */
    long answer(long n, int k) {
        if (n <= 1) {
            // The sum over d in [1, n-1] is empty.
            return 0;
        }
        long a1 = f(n - 1, k);
        long a2 = f(n - 1, k + 1);
        // Reduce n first: n * a1 would overflow a long once n exceeds roughly 9.2e9.
        long temp = (n % mod * a1 % mod + mod - a2) % mod;
        return factorialMod(n - 1) * 2 % mod * temp % mod;
    }

    /**
     * Returns {@code m! mod mod} by direct multiplication.
     * Once {@code m >= mod} the product contains the factor {@code mod}, so the result is 0.
     */
    private static long factorialMod(long m) {
        if (m >= mod) {
            return 0;
        }
        long fact = 1;
        for (long i = 2; i <= m; i++) {
            fact = fact * i % mod;
        }
        return fact;
    }

    /**
     * Returns {@code S_k(n) = sum_{i=1}^{n} i^k} modulo {@link #mod}.
     *
     * <p>The prefix sums at 0..k+1 are computed directly and then extended to {@code n} with
     * {@link #lagrange}; S_k is a polynomial of degree k+1, so k+2 samples determine it.
     */
    long f(long n, int k) {
        long[] y = new long[k + 2];
        for (int i = 1; i < k + 2; i++) {
            y[i] = (y[i - 1] + Comb.pow(i, k)) % mod;
        }
        return lagrange(y, n);
    }

    /**
     * Evaluates, modulo {@link #mod}, the polynomial of degree at most {@code y.length - 1}
     * that takes the value {@code y[i]} at {@code x = i} for every index i.
     *
     * @param y sample values at x = 0, 1, ..., y.length - 1 (each already reduced modulo mod)
     * @param t evaluation point; expected to be non-negative
     * @return P(t) modulo {@link #mod}
     */
    public long lagrange(long[] y, long t) {
        int n = y.length - 1;
        if (t <= n) {
            // t is one of the sample points.
            return y[(int) t];
        }
        // prefix[i] = prod_{j<i} (t - j) and suffix[i] = prod_{j>i} (t - j), so their product
        // is the Lagrange numerator for node i without needing a modular division by (t - i).
        long[] prefix = new long[n + 1];
        long[] suffix = new long[n + 1];
        prefix[0] = 1;
        suffix[n] = 1;
        for (int i = 0; i < n; i++) {
            prefix[i + 1] = prefix[i] * ((t - i) % mod) % mod;
        }
        for (int i = n; i > 0; i--) {
            suffix[i - 1] = suffix[i] * ((t - i) % mod) % mod;
        }
        long ret = 0;
        for (int i = 0; i <= n; i++) {
            // The denominator prod_{j != i} (i - j) equals (-1)^(n-i) * i! * (n-i)!.
            long term = y[i] * prefix[i] % mod * suffix[i] % mod
                    * Comb.invFact(i) % mod * Comb.invFact(n - i) % mod;
            if (((n - i) & 1) == 1) {
                ret = (ret + mod - term) % mod;
            } else {
                ret = (ret + term) % mod;
            }
        }
        return ret;
    }

    public static void main(String... args) {
        new Main().solve();
        out.flush();
    }
}

/** Modular combinatorics helpers modulo {@link Main#mod} (memo tables grow lazily). */
class Comb {
    private static final int MEMO_THRESHOLD = 1000000;
    static long mod = Main.mod;
    private static final List<Long> inv = new ArrayList<>();
    private static final List<Long> fact = new ArrayList<>();
    private static final List<Long> invFact = new ArrayList<>();
    private static final Map<Long, List<Long>> pow = new HashMap<>();

    /** Extends the modular-inverse table up to index n using inv[i] = -(mod / i) * inv[mod % i]. */
    private static void buildInvTable(int n) {
        if (inv.isEmpty()) {
            inv.add(null);
            inv.add(1L);
        }
        for (int i = inv.size(); i <= n; i++) {
            inv.add(mod - inv.get((int) (mod % i)) * (mod / i) % mod);
        }
    }

    /** Extends the factorial and inverse-factorial tables up to index n. */
    private static void buildFactTable(int n) {
        if (fact.isEmpty()) {
            fact.add(1L);
            invFact.add(1L);
        }
        for (int i = fact.size(); i <= n; i++) {
            fact.add(fact.get(i - 1) * i % mod);
            invFact.add(inv(fact.get(i)));
        }
    }

    /** Enables memoization of powers of {@code a} for exponents below the memo threshold. */
    public static void setupPowTable(long a) {
        pow.put(a, new ArrayList<>(Collections.singleton(1L)));
    }

    private static void rangeCheck(long n, long r) {
        if (n < r) {
            throw new IllegalArgumentException("n < r");
        }
        if (n < 0) {
            throw new IllegalArgumentException("n < 0");
        }
        if (r < 0) {
            throw new IllegalArgumentException("r < 0");
        }
    }

    static long fact(int n) {
        buildFactTable(n);
        return fact.get(n);
    }

    static long invFact(int n) {
        buildFactTable(n);
        return invFact.get(n);
    }

    private static long comb0(int n, int r) {
        rangeCheck(n, r);
        return fact(n) * invFact(r) % mod * invFact(n - r) % mod;
    }

    /** Binomial coefficient C(n, r) modulo mod; uses a direct product when n is large. */
    static long comb(long n, long r) {
        rangeCheck(n, r);
        if (n < MEMO_THRESHOLD) {
            return comb0((int) n, (int) r);
        }
        r = min(r, n - r);
        long x = 1, y = 1;
        for (long i = 1; i <= r; i++) {
            // Reduce each factor first; n can be far larger than mod and x * n would overflow.
            x = x * ((n - r + i) % mod) % mod;
            y = y * (i % mod) % mod;
        }
        return x * inv(y) % mod;
    }

    private static long perm0(int n, int r) {
        rangeCheck(n, r);
        return fact(n) * invFact(n - r) % mod;
    }

    /** Falling factorial n! / (n - r)! modulo mod. */
    static long perm(long n, long r) {
        rangeCheck(n, r);
        if (n < MEMO_THRESHOLD) {
            return perm0((int) n, (int) r);
        }
        long x = 1;
        for (long i = 1; i <= r; i++) {
            x = x * ((n - r + i) % mod) % mod;
        }
        return x;
    }

    /** Multiset coefficient H(n, r) = C(n + r - 1, r); 1 when r == 0. */
    static long homo(long n, long r) {
        return r == 0 ? 1 : comb(n + r - 1, r);
    }

    private static long inv0(int a) {
        buildInvTable(a);
        return inv.get(a);
    }

    /**
     * Modular inverse of {@code a} (any long, including negatives).
     *
     * @throws ArithmeticException if {@code a} is divisible by mod
     */
    static long inv(long a) {
        a = Math.floorMod(a, mod);
        if (a == 0) {
            throw new ArithmeticException("0 has no inverse modulo " + mod);
        }
        if (a < MEMO_THRESHOLD) {
            return inv0((int) a);
        }
        // Extended Euclid on (a, mod), tracking only the coefficient of a.
        long b = mod;
        long u = 1, v = 0;
        while (b >= 1) {
            long t = a / b;
            a -= t * b;
            u -= t * v;
            if (a < 1) {
                return (v %= mod) < 0 ? v + mod : v;
            }
            t = b / a;
            b -= t * a;
            v -= t * u;
        }
        return (u %= mod) < 0 ? u + mod : u;
    }

    /** Returns a^b modulo mod (b >= 0); the base may be any long. */
    static long pow(long a, long b) {
        if (pow.containsKey(a) && b < MEMO_THRESHOLD) {
            return powMemo(a, (int) b);
        }
        // Reduce the base so base * base cannot overflow a long.
        long base = Math.floorMod(a, mod);
        long x = 1;
        while (b > 0) {
            if (b % 2 == 1) {
                x = x * base % mod;
            }
            base = base * base % mod;
            b >>= 1;
        }
        return x;
    }

    /** Returns a^b from the memo table registered by {@link #setupPowTable}, extending it as needed. */
    static long powMemo(long a, int b) {
        List<Long> memo = pow.get(a);
        long base = Math.floorMod(a, mod);
        while (memo.size() <= b) {
            memo.add(memo.get(memo.size() - 1) * base % mod);
        }
        return memo.get(b);
    }

    /**
     * Modular square root via Cipolla's algorithm.
     *
     * @return the smaller of the two roots, x itself when x mod mod is 0 or 1, or -1 when no root exists
     */
    static long sqrt(long x) {
        x = Math.floorMod(x, mod);
        if (x < 2) {
            return x;
        }
        long p = (mod - 1) / 2;
        // Euler's criterion: x is a quadratic residue iff x^p == 1.
        if (pow(x, p) != 1) {
            return -1;
        }
        while (true) {
            long a = ThreadLocalRandom.current().nextLong(mod);
            long w = (a * a + mod - x) % mod;
            if (pow(w, p) == 1) {
                continue;
            }
            // Compute (a + sqrt(w))^(p+1) in F_mod[sqrt(w)]; its real part r is a root of x.
            long n = p + 1;
            long k = Long.highestOneBit(n);
            long r = 1;
            long i = 0;
            while (k > 0) {
                long nr = (r * r + i * i % mod * w) % mod;
                long ni = r * i * 2 % mod;
                r = nr;
                i = ni;
                if ((n & k) > 0) {
                    nr = (r * a + i * w) % mod;
                    ni = (i * a + r) % mod;
                    r = nr;
                    i = ni;
                }
                k >>= 1;
            }
            return min(r, mod - r);
        }
    }
}

/** Fast byte-level tokenizer over standard input. */
class In {
    private final BufferedInputStream reader = new BufferedInputStream(System.in);
    private final byte[] buffer = new byte[0x10000];
    private int i = 0;
    private int length = 0;

    /**
     * Returns the next input byte as an unsigned value (0..255).
     * Returns 0 once at end of input; reading again after that throws.
     */
    public int read() {
        if (i == length) {
            i = 0;
            try {
                length = reader.read(buffer);
            } catch (IOException e) {
                // Swallowing this would leave a stale length and re-serve old buffer contents.
                throw new UncheckedIOException(e);
            }
            if (length == -1) {
                return 0;
            }
        }
        if (length <= i) {
            throw new RuntimeException("read past end of input");
        }
        // Mask so non-ASCII bytes are not returned as negative values.
        return buffer[i++] & 0xff;
    }

    /** Returns the next token of printable ASCII characters ('!'..'~'). */
    public String next() {
        StringBuilder builder = new StringBuilder();
        int b = read();
        while (b < '!' || '~' < b) {
            b = read();
        }
        while ('!' <= b && b <= '~') {
            builder.appendCodePoint(b);
            b = read();
        }
        return builder.toString();
    }

    /** Returns the rest of the current line decoded as UTF-8, consuming a CRLF or LF terminator. */
    public String nextLine() {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        int b = read();
        while (b != 0 && b != '\r' && b != '\n') {
            bytes.write(b);
            b = read();
        }
        if (b == '\r') {
            // Assumes CRLF: drop the '\n' that follows.
            read();
        }
        return new String(bytes.toByteArray(), StandardCharsets.UTF_8);
    }

    public int nextInt() {
        long val = nextLong();
        if (val < Integer.MIN_VALUE || Integer.MAX_VALUE < val) {
            throw new NumberFormatException();
        }
        return (int) val;
    }

    /** Parses an optionally negative decimal integer; overflow is not detected. */
    public long nextLong() {
        int b = read();
        while (b < '!' || '~' < b) {
            b = read();
        }
        boolean neg = false;
        if (b == '-') {
            neg = true;
            b = read();
        }
        long n = 0;
        int c = 0;
        while ('0' <= b && b <= '9') {
            n = n * 10 + b - '0';
            b = read();
            c++;
        }
        if (c == 0 || c >= 2 && n == 0) {
            throw new NumberFormatException();
        }
        return neg ? -n : n;
    }

    public double nextDouble() {
        return Double.parseDouble(next());
    }

    public char[] nextCharArray() {
        return next().toCharArray();
    }

    public String[] nextStringArray(int n) {
        String[] s = new String[n];
        for (int i = 0; i < n; i++) {
            s[i] = next();
        }
        return s;
    }

    public char[][] nextCharMatrix(int n, int m) {
        char[][] a = new char[n][m];
        for (int i = 0; i < n; i++) {
            a[i] = next().toCharArray();
        }
        return a;
    }

    public int[] nextIntArray(int n) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++) {
            a[i] = nextInt();
        }
        return a;
    }

    public int[] nextIntArray(int n, IntUnaryOperator op) {
        int[] a = new int[n];
        for (int i = 0; i < n; i++) {
            a[i] = op.applyAsInt(nextInt());
        }
        return a;
    }

    public int[][] nextIntMatrix(int h, int w) {
        int[][] a = new int[h][w];
        for (int i = 0; i < h; i++) {
            a[i] = nextIntArray(w);
        }
        return a;
    }

    public long[] nextLongArray(int n) {
        long[] a = new long[n];
        for (int i = 0; i < n; i++) {
            a[i] = nextLong();
        }
        return a;
    }

    public long[] nextLongArray(int n, LongUnaryOperator op) {
        long[] a = new long[n];
        for (int i = 0; i < n; i++) {
            a[i] = op.applyAsLong(nextLong());
        }
        return a;
    }

    public long[][] nextLongMatrix(int h, int w) {
        long[][] a = new long[h][w];
        for (int i = 0; i < h; i++) {
            a[i] = nextLongArray(w);
        }
        return a;
    }

    /** Reads m edges given as 1-based vertex pairs into a 0-based adjacency list of n vertices. */
    public List<List<Integer>> nextGraph(int n, int m, boolean directed) {
        List<List<Integer>> res = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            res.add(new ArrayList<>());
        }
        for (int i = 0; i < m; i++) {
            int u = nextInt() - 1;
            int v = nextInt() - 1;
            res.get(u).add(v);
            if (!directed) {
                res.get(v).add(u);
            }
        }
        return res;
    }
}

/** Buffered output to stdout plus an optional debug channel on stderr. */
class Out {
    private final PrintWriter out = new PrintWriter(System.out);
    private final PrintWriter err = new PrintWriter(System.err);
    public boolean autoFlush;
    public boolean enableDebug;

    public Out(boolean autoFlush, boolean enableDebug) {
        this.autoFlush = autoFlush;
        this.enableDebug = enableDebug;
    }

    /** Prints the arguments to stderr, space-separated, when debugging is enabled. */
    public void debug(Object... args) {
        if (!enableDebug) {
            return;
        }
        if (args == null || args.getClass() != Object[].class) {
            args = new Object[] {args};
        }
        err.println(Arrays.stream(args).map(obj -> format(obj, true)).collect(Collectors.joining(" ")));
        err.flush();
    }

    /** Renders a value for debug output; 2-D arrays become one row per line when allowed. */
    private String format(Object obj, boolean canMultiline) {
        if (obj == null) return "null";
        Class<?> clazz = obj.getClass();
        if (clazz == Double.class) return String.format("%.10f", obj);
        if (clazz == int[].class) return Arrays.toString((int[]) obj);
        if (clazz == long[].class) return Arrays.toString((long[]) obj);
        if (clazz == char[].class) return String.valueOf((char[]) obj);
        if (clazz == boolean[].class) {
            boolean[] flags = (boolean[]) obj;
            return IntStream.range(0, flags.length).mapToObj(i -> flags[i] ? "1" : "0").collect(Collectors.joining());
        }
        if (clazz == double[].class) {
            return Arrays.toString(Arrays.stream((double[]) obj).mapToObj(a -> format(a, false)).toArray());
        }
        if (canMultiline && clazz.isArray() && clazz.getComponentType().isArray()) {
            return Arrays.stream((Object[]) obj).map(a -> format(a, false)).collect(Collectors.joining("\n"));
        }
        if (clazz == Object[].class) {
            return Arrays.toString(Arrays.stream((Object[]) obj).map(a -> format(a, false)).toArray());
        }
        if (obj instanceof Object[]) return Arrays.toString((Object[]) obj);
        // byte[], short[] and float[] cannot be cast to Object[]; format them reflectively.
        if (clazz.isArray()) return primitiveArrayToString(obj);
        return String.valueOf(obj);
    }

    /** Formats any primitive array in the same "[a, b, c]" shape as Arrays.toString. */
    private static String primitiveArrayToString(Object array) {
        StringJoiner joiner = new StringJoiner(", ", "[", "]");
        int length = java.lang.reflect.Array.getLength(array);
        for (int j = 0; j < length; j++) {
            joiner.add(String.valueOf(java.lang.reflect.Array.get(array, j)));
        }
        return joiner.toString();
    }

    public void println(Object... args) {
        if (args == null || args.getClass() != Object[].class) {
            args = new Object[] {args};
        }
        out.println(Arrays.stream(args)
                .map(obj -> obj instanceof Double ? String.format("%.10f", obj) : String.valueOf(obj))
                .collect(Collectors.joining(" ")));
        if (autoFlush) {
            out.flush();
        }
    }

    public void println(char a) {
        out.println(a);
        if (autoFlush) {
            out.flush();
        }
    }

    public void println(int a) {
        out.println(a);
        if (autoFlush) {
            out.flush();
        }
    }

    public void println(long a) {
        out.println(a);
        if (autoFlush) {
            out.flush();
        }
    }

    public void println(double a) {
        out.println(String.format("%.10f", a));
        if (autoFlush) {
            out.flush();
        }
    }

    public void println(String s) {
        out.println(s);
        if (autoFlush) {
            out.flush();
        }
    }

    public void println(char[] s) {
        out.println(String.valueOf(s));
        if (autoFlush) {
            out.flush();
        }
    }

    public void println(int[] a) {
        StringJoiner joiner = new StringJoiner(" ");
        for (int i : a) {
            joiner.add(Integer.toString(i));
        }
        out.println(joiner);
        if (autoFlush) {
            out.flush();
        }
    }

    public void println(long[] a) {
        StringJoiner joiner = new StringJoiner(" ");
        for (long i : a) {
            joiner.add(Long.toString(i));
        }
        out.println(joiner);
        if (autoFlush) {
            out.flush();
        }
    }

    public void flush() {
        err.flush();
        out.flush();
    }
}