import java.util.Scanner;

/**
 * Reads two integers {@code n} and {@code k} from standard input and prints
 * {@code sum_{i=1..k} q(i) mod 998244353}.
 *
 * <p>Here {@code q(i) = k^n - (i-1)^n - (k-i+1) * n * (i-1)^(n-1)} is the number of
 * length-{@code n} sequences over {@code {1..k}} that have at least two entries {@code >= i}:
 * all sequences, minus those with every entry below {@code i}, minus those with exactly one
 * entry at or above {@code i}.
 *
 * <p>Summing {@code q(i)} over {@code i} is the tail-sum form of
 * {@code sum_i i * (q(i) - q(i+1))}, which is what an earlier version computed. That is the
 * sum, over all such sequences, of the largest {@code v} having at least two entries
 * {@code >= v}.
 */
public class Main {
    /** Prime modulus used for all arithmetic. */
    private static final long MOD = 998244353;

    public static void main(String[] args) {
        try (Scanner sc = new Scanner(System.in)) {
            long n = sc.nextLong();
            long k = sc.nextLong();
            System.out.println(solve(n, k, MOD));
        }
    }

    /**
     * Computes {@code sum_{i=1..k} q(i) mod mod}; see the class comment for {@code q}.
     *
     * <p>No per-{@code i} array is kept, so memory is O(1) regardless of how {@code k}
     * compares to {@code n}.
     *
     * @param n   sequence length
     * @param k   largest allowed value
     * @param mod modulus (expected to be below ~3.03e9 so products of two residues fit in a long)
     * @return the sum reduced into {@code [0, mod)}
     */
    private static long solve(long n, long k, long mod) {
        // k^n does not depend on i, so compute it once.
        long total = modPow(k, n, mod);
        long nMod = Math.floorMod(n, mod);
        long answer = 0;
        for (long i = 1; i <= k; i++) {
            // Sequences whose entries are all < i.
            long allBelow = modPow(i - 1, n, mod);
            // Sequences with exactly one entry >= i: choose its position (n) and value (k-i+1).
            // Reduce each factor before multiplying so intermediate products cannot overflow.
            long exactlyOne = (k - i + 1) % mod * nMod % mod * modPow(i - 1, n - 1, mod) % mod;
            // floorMod keeps the count in [0, mod), even though the subtraction may go negative.
            long atLeastTwo = Math.floorMod(total - allBelow - exactlyOne, mod);
            answer = (answer + atLeastTwo) % mod;
        }
        return answer;
    }

    /**
     * Computes {@code a^n mod mod} by square-and-multiply.
     *
     * <p>The base is first reduced into {@code [0, mod)}, so any {@code long} base is accepted
     * without overflowing the squaring step. An exponent {@code n < 1} yields {@code 1 % mod}.
     *
     * @param a   base (may be negative or larger than {@code mod})
     * @param n   exponent
     * @param mod positive modulus, expected below ~3.03e9 so residue products fit in a long
     * @return {@code a^n mod mod} in {@code [0, mod)}
     */
    public static long modPow(long a, long n, long mod) {
        long result = 1 % mod;
        long base = Math.floorMod(a, mod);
        long exp = n;
        while (exp >= 1) {
            if ((exp & 1) == 1) {
                result = result * base % mod;
            }
            base = base * base % mod;
            exp >>= 1;
        }
        return result;
    }
}