"""Print ((K - 1) * K * N) / K**N modulo 998244353.

The first line of input holds the two integers ``N K``. Input is read from
``input.txt`` when that file exists in the working directory, otherwise from
standard input.
"""
import os
import sys

INF = 10**18
MOD = 998244353
# MOD = 10**9 + 7  # alternative modulus; modP requires MOD to be prime


def modP(p, q):
    """Return p / q modulo MOD, computed as p * q^(MOD - 2) mod MOD.

    Uses Fermat's little theorem for the modular inverse, so MOD must be prime
    and q should not be a multiple of MOD (otherwise the "inverse" is 0).
    """
    return (p * pow(q, MOD - 2, MOD)) % MOD


def solve(n, k):
    """Return ((k - 1) * k * n) / k**n modulo MOD.

    Gives the same value as ``modP((k - 1) * k * n, pow(k, n))``, but never
    builds the exact integer k**n.
    """
    # Reducing both operands modulo MOD first does not change the result,
    # because modP only uses them modulo MOD. The three-argument pow keeps
    # k**n small even for very large n.
    numerator = (k - 1) * k % MOD * n % MOD
    denominator = pow(k, n, MOD)
    return modP(numerator, denominator)


def main():
    """Read ``N K`` from input.txt (if present) or stdin and print the answer."""
    if os.path.exists("input.txt"):
        with open("input.txt", "r") as src:
            line = src.readline()
    else:
        line = sys.stdin.readline()
    # split() tolerates a trailing "\n", "\r\n", or no newline at EOF. The old
    # readline()[:-1] dropped the last digit when the newline was missing.
    n, k = map(int, line.split())
    print(solve(n, k))


if __name__ == "__main__":
    main()