import sys

MOD = 998244353


def solve(n, k, mod=MOD):
    """Return ``n * k * (k**n)**-1`` modulo ``mod``.

    The modular inverse is computed with Fermat's little theorem, so ``mod``
    is expected to be prime.  If ``k**n`` is divisible by ``mod`` no inverse
    exists; the Fermat power then evaluates to 0 and so does the result
    (this matches the original script's behaviour).

    Args:
        n: Non-negative integer exponent (also a multiplicative factor).
        k: Integer base (also a multiplicative factor).
        mod: Prime modulus; defaults to 998244353.

    Returns:
        The value ``n * k / k**n`` reduced modulo ``mod``.
    """
    denominator = pow(k, n, mod)
    # Fermat's little theorem: a**(p-2) is a's inverse mod prime p (a % p != 0).
    inverse = pow(denominator, mod - 2, mod)
    return n * k * inverse % mod


def _tune_runtime():
    """Apply the interpreter tweaks the original template set up.

    Neither tweak matters to ``solve`` (it does not recurse), but both are
    harmless.  ``pypyjit`` only exists on PyPy, so its import is guarded.
    An unguarded import would crash on CPython.
    """
    sys.setrecursionlimit(10**5)
    try:
        import pypyjit
    except ImportError:
        pass  # Not running on PyPy; nothing to tune.
    else:
        pypyjit.set_param('max_unroll_recursion=-1')


def main():
    """Read ``n k`` from one stdin line and print ``solve(n, k)``."""
    _tune_runtime()
    n, k = map(int, sys.stdin.readline().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()