import java.util.Scanner;

/**
 * Reads two integers {@code n} and {@code k} from standard input and prints
 * {@code k * (k - 1) * n / k^n} evaluated in the field of integers modulo 998244353.
 */
public class Main {

    /** Modulus used for all arithmetic. */
    private static final int MOD = 998244353;

    /**
     * Program entry point: reads {@code n} then {@code k} and prints the result of
     * {@link #solve(int, int)}.
     *
     * @param args unused
     * @throws Exception declared for compatibility with the original signature
     */
    public static void main(String[] args) throws Exception {
        int n;
        int k;
        // try-with-resources closes the scanner even if parsing fails.
        try (Scanner sc = new Scanner(System.in)) {
            n = sc.nextInt();
            k = sc.nextInt();
        }
        System.out.println(solve(n, k));
    }

    /**
     * Computes {@code k * (k - 1) * n * inverse(k^n)} modulo {@link #MOD}.
     *
     * <p>The product {@code k * (k - 1)} is reduced before it is multiplied by {@code n};
     * without that reduction the intermediate value can exceed the range of {@code long}
     * for large {@code k} and silently wrap.
     *
     * @param n the exponent and linear factor; expected to be non-negative
     * @param k the base; expected to be non-negative
     * @return the value described above; in {@code [0, MOD)} for non-negative inputs
     */
    static long solve(int n, int k) {
        // k * (k - 1) fits in a long for any int k; reduce it so the next product also fits.
        long pairs = (long) k * (k - 1L) % MOD;
        long numerator = pairs * n % MOD;
        long denominator = power(k, n, MOD);
        long inverse = modinv(denominator, MOD);
        return numerator * inverse % MOD;
    }

    /**
     * Computes a modular inverse with the extended Euclidean algorithm.
     *
     * @param a the value to invert
     * @param m the modulus
     * @return a value {@code u} in {@code [0, m)} with {@code a * u ≡ gcd(a, m) (mod m)};
     *         this is the inverse of {@code a} only when {@code a} and {@code m} are
     *         coprime (for example {@code a == 0} yields {@code 0})
     */
    static long modinv(long a, int m) {
        long b = m;
        long u = 1;
        long v = 0;
        while (b > 0) {
            long quotient = a / b;
            // (a, b) <- (b, a mod b)
            long remainder = a - quotient * b;
            a = b;
            b = remainder;
            // Apply the same step to the Bezout coefficients.
            long next = u - quotient * v;
            u = v;
            v = next;
        }
        u %= m;
        if (u < 0) {
            u += m;
        }
        return u;
    }

    /**
     * Computes {@code x^n mod m} by iterative square-and-multiply.
     *
     * @param x the base; reduced into {@code [0, m)} before use
     * @param n the exponent; values {@code <= 0} yield {@code 1}
     * @param m the modulus; must be positive when {@code n > 0}
     * @return {@code x^n mod m}
     */
    static long power(long x, long n, int m) {
        long result = 1;
        if (n <= 0) {
            return result;
        }
        long base = x % m;
        if (base < 0) {
            base += m;
        }
        for (long e = n; e > 0; e >>= 1) {
            if ((e & 1L) == 1L) {
                result = result * base % m;
            }
            base = base * base % m;
        }
        return result;
    }
}