N,K = map(int,input().split())
P = 998244353


ans = 0
dat = [0] * (K + 1)

for i in range(K):
    dat[K - i] = pow(K,N,P) - (K - i - 1 + (i + 1) * N) * pow(K - i -1,N - 1,P)
S = 0
"""
for i in range(K,0,-1):
    dat[i] -= S
    dat[i] %= P
    S += dat[i]
    S %= P
    ans += dat[i] * i
    ans %= P
"""
for i in range(1,K):
    dat[i] -= dat[i+1]
    ans += dat[i] * i
    ans %= P
ans += dat[K] * K
ans %= P
print(ans)