"""Sum of (max(a) - min(a)) * sum(a) over all sequences a in [1, m]^n, mod 998244353.

Reads ``n m`` from one line of standard input and prints the answer.
"""
import sys

MOD = 998244353
# Modular inverse of 2 (MOD is odd). Halving must use this, not ``// 2``: once a
# value has been reduced mod MOD its parity no longer matches the true integer,
# so floor division gives the wrong residue.
INV2 = (MOD + 1) // 2


def solve(n, m):
    """Return sum over a in [1, m]^n of (max(a) - min(a)) * sum(a), modulo MOD.

    Args:
        n: sequence length (non-negative).
        m: largest allowed value (non-negative).

    Returns:
        The answer as an int in ``range(MOD)``.
    """

    def below(x):
        # Total of sum(a) over all a in [1, x]^n:
        #   n * x^(n-1) * x(x+1)/2 == n * x^n * (x+1) / 2.
        return pow(x, n, MOD) * (x + 1) % MOD * INV2 % MOD * n % MOD

    def above(x):
        # Total of sum(a) over all a in [x, m]^n:
        #   n * (m-x+1)^(n-1) * (m-x+1)(m+x)/2 == n * (m-x+1)^n * (m+x) / 2.
        return pow(m - x + 1, n, MOD) * (m + x) % MOD * INV2 % MOD * n % MOD

    ans = 0
    g_prev = below(0)
    h_cur = above(1)
    for k in range(1, m + 1):
        g_cur = below(k)
        h_next = above(k + 1)
        # g_cur - g_prev: total of sum(a) over sequences whose max is exactly k.
        # h_cur - h_next: total of sum(a) over sequences whose min is exactly k.
        # Python's % keeps the accumulator non-negative even if the term is negative.
        ans = (ans + k * (g_cur - g_prev - h_cur + h_next)) % MOD
        # Reuse this step's values for the next k instead of recomputing pow().
        g_prev, h_cur = g_cur, h_next
    return ans


def main():
    """Read ``n m`` from stdin and print ``solve(n, m)``."""
    n, m = map(int, sys.stdin.readline().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()