from collections import *
from itertools import *
from functools import *
from heapq import *
import sys,math
input = sys.stdin.readline

mod = 998244353  # prime modulus, so Fermat's little theorem yields inverses


def solve(n: int, k: int) -> int:
    """Return n * k * (k - 1) / k**n reduced modulo ``mod``.

    The division is carried out as multiplication by the modular inverse of
    k**n, obtained via Fermat's little theorem: x**(mod - 2) is x**-1 when
    ``mod`` is prime and x is not a multiple of it.

    NOTE(review): if k is a multiple of ``mod`` the inverse does not exist and
    the result is 0 rather than an error. Presumably the problem constraints
    exclude that case.
    """
    numerator = n * k * (k - 1) % mod
    inv_denominator = pow(pow(k, n, mod), mod - 2, mod)
    return numerator * inv_denominator % mod


def main() -> None:
    """Read ``N K`` from a single line of stdin and print the answer."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()