"""Sum of (max(A) - min(A)) * sum(A) over all A in [1, m]^n, modulo 998244353.

Reads ``n m`` from standard input and prints the result.
"""
import sys

MOD = 998244353


def solve(n, m, mod=MOD):
    """Return sum over A in [1, m]^n of (max(A) - min(A)) * sum(A), modulo *mod*.

    Let G(x) be the total of all entries over every sequence in [1, x]^n,
    i.e. n * x^(n-1) * x(x+1)/2 = n * x^n * (x+1)/2, and let H(x) be the same
    total over [x, m]^n, i.e. n * (m-x+1)^n * (m+x)/2.  Then G(k) - G(k-1) is
    the entry total over sequences whose maximum is exactly k, and
    H(k) - H(k+1) the entry total over sequences whose minimum is exactly k, so

        sum_k k*(G(k) - G(k-1)) - sum_k k*(H(k) - H(k+1))

    equals the requested sum.

    Args:
        n: sequence length.
        m: largest allowed entry (entries range over 1..m).
        mod: modulus; must be an odd prime (2 is inverted via Fermat).

    Returns:
        The sum reduced into range(mod).
    """
    inv2 = pow(2, mod - 2, mod)  # modular inverse of 2 (mod is prime)
    # powers[i] = i^n mod p, for i in 0..m (powers[0] is 0 for n >= 1).
    powers = [pow(i, n, mod) for i in range(m + 1)]

    def total_upto(x):
        # Sum of all entries over sequences in [1, x]^n.
        return powers[x] * (x + 1) % mod * inv2 % mod * n % mod

    def total_from(x):
        # Sum of all entries over sequences in [x, m]^n (x may be m + 1 -> 0).
        return powers[m - x + 1] * (m + x) % mod * inv2 % mod * n % mod

    ans = 0
    for k in range(1, m + 1):
        ans += k * (total_upto(k) - total_upto(k - 1))   # max(A) == k
        ans -= k * (total_from(k) - total_from(k + 1))   # min(A) == k
        ans %= mod  # Python's % keeps the result non-negative
    return ans


def main(stream=None):
    """Read ``n m`` from *stream* (default: stdin) and print ``solve(n, m)``.

    Tokenizing the whole stream, instead of slicing off the last character of
    a line, is robust to input that lacks a trailing newline.
    """
    if stream is None:
        stream = sys.stdin
    n, m = map(int, stream.read().split()[:2])
    print(solve(n, m))


if __name__ == "__main__":
    main()