from sys import stdin
import sys
import math

mod = 998244353  # prime modulus; all arithmetic below is done modulo this value


def inverse(x):
    """Return the modular inverse of ``x`` modulo the prime ``mod``.

    Relies on Fermat's little theorem (x**(mod-2) is x**-1 modulo a prime),
    so ``x`` must not be a multiple of ``mod``.
    """
    return pow(x, mod - 2, mod)


half = inverse(2)  # multiplying by this divides by 2 modulo ``mod``


def solve(N, M):
    """Compute the script's answer for the input pair (N, M), modulo ``mod``.

    With p(K) = K**N mod ``mod``, it forms, for K = 0 .. M-1,

        dp[K] = p(K) + (p(K+1) - p(K)) * (M - K)                  (mod ``mod``)

    and returns

        (sum_{K=1}^{M-1} K * (dp[K] - dp[K-1])) * (M + 1) * N / 2  (mod ``mod``)

    where the division by 2 is a multiplication by ``half``. Returns 0 when
    M <= 1, because the sum is then empty.
    """
    dp = []
    # p(0); note pow(0, 0, mod) == 1, so N == 0 behaves like the original.
    cur = pow(0, N, mod)
    for K in range(M):
        nxt = pow(K + 1, N, mod)  # p(K+1), reused as p(K) on the next iteration
        dp.append((cur + (nxt - cur) * (M - K)) % mod)
        cur = nxt

    # The K == 0 term is multiplied by 0, so the sum can start at K == 1.
    # Reducing every step keeps the accumulator small; Python's % is always
    # non-negative for a positive modulus, so negative differences are fine.
    ans = 0
    for K in range(1, M):
        ans = (ans + (dp[K] - dp[K - 1]) * K) % mod
    return ans * (M + 1) * N * half % mod


def main():
    """Read "N M" from the first line of stdin and print ``solve(N, M)``."""
    N, M = map(int, stdin.readline().split())
    print(solve(N, M))


if __name__ == "__main__":
    main()