from collections import *
from itertools import *
from functools import *
from heapq import *
import sys
import math

MOD = 998244353


def solve(n, k, mod=MOD):
    """Return ``n * k * (k - 1) / k**n`` reduced modulo the prime ``mod``.

    The division is performed by multiplying with the modular inverse of
    ``k**n``. The inverse is obtained via Fermat's little theorem
    (``x**(mod - 2) mod mod``), which is only valid when ``mod`` is prime.

    Args:
        n: Non-negative integer exponent / multiplier.
        k: Integer base.
        mod: Prime modulus (defaults to 998244353).

    Returns:
        The residue in ``range(mod)``.

    NOTE(review): if ``k`` is a multiple of ``mod``, ``k**n`` has no modular
    inverse and the Fermat power silently evaluates to 0, so the result is 0.
    Confirm the input constraints rule this case out.
    """
    numerator = n * k * (k - 1) % mod
    # Fermat inverse of k**n; reduce k**n first so the outer pow stays small.
    inv_denominator = pow(pow(k, n, mod), mod - 2, mod)
    return numerator * inv_denominator % mod


def main():
    """Read ``N K`` from one line of stdin and print ``solve(N, K)``."""
    n, k = map(int, sys.stdin.readline().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()