import math

MOD = 998244353


def _range_total(n, lo, hi, mod):
    """Return the total element sum over all sequences in [lo, hi]^n, modulo ``mod``.

    Across the (hi - lo + 1)**n sequences, each of the n positions takes every
    value in [lo, hi] exactly (hi - lo + 1)**(n - 1) times, so the total is
    n * (hi - lo + 1)**(n - 1) * (lo + ... + hi).  An empty range (hi < lo)
    contributes 0.  Assumes n >= 1.
    """
    count = hi - lo + 1
    if count <= 0:
        return 0
    value_sum = (lo + hi) * count // 2  # arithmetic series lo + ... + hi
    return n * pow(count, n - 1, mod) * value_sum % mod


def solve(n, m, mod=MOD):
    """Return sum over A in [1, m]^n of sum(A) * (max(A) - min(A)), modulo ``mod``.

    The result is computed as (sum of sum(A) * max(A)) - (sum of sum(A) * min(A)).
    - Sequences whose max is exactly v are those in [1, v]^n minus those in
      [1, v - 1]^n.
    - Sequences whose min is exactly v are those in [v, m]^n minus those in
      [v + 1, m]^n.
    The difference of the corresponding range totals therefore isolates each
    group.  Assumes n >= 1 and m >= 1.
    """
    weighted_max = 0
    below = 0  # range total for [1, v - 1]; reused so each pow is computed once
    for v in range(1, m + 1):
        upto = _range_total(n, 1, v, mod)
        weighted_max = (weighted_max + (upto - below) * v) % mod
        below = upto

    weighted_min = 0
    above = 0  # range total for [v + 1, m]; walk v downward to reuse it
    for v in range(m, 0, -1):
        from_v = _range_total(n, v, m, mod)
        weighted_min = (weighted_min + (from_v - above) * v) % mod
        above = from_v

    return (weighted_max - weighted_min) % mod


def main():
    """Read "N M" from one input line and print solve(N, M)."""
    N, M = map(int, input().split())
    print(solve(N, M))


if __name__ == '__main__':
    main()