"""Sum of (max(A) - min(A)) * sum(A) over all sequences A in [1, m]^n, mod 998244353.

Input (stdin): one line with two integers ``n m``.
Output (stdout): the sum modulo 998244353.
"""
import sys

MOD = 998244353


def solve(n, m, mod=MOD):
    """Return sum over A in [1, m]^n of (max(A) - min(A)) * sum(A), modulo *mod*.

    The total splits into sum(max(A) * sum(A)) - sum(min(A) * sum(A)):

    * ``prefix_total(v)`` is the total of sum(A) over A in [1, v]^n, so
      ``prefix_total(v) - prefix_total(v - 1)`` covers exactly the sequences
      whose maximum is v.
    * ``suffix_total(v)`` is the total of sum(A) over A in [v, m]^n, so
      ``suffix_total(v) - suffix_total(v + 1)`` covers exactly the sequences
      whose minimum is v.

    Args:
        n: Sequence length (n >= 0).
        m: Largest allowed element value (m >= 0).
        mod: Prime modulus. It must be an odd prime, because the inverse of 2
            is computed with Fermat's little theorem.

    Returns:
        The sum reduced into ``range(mod)``.
    """
    inv2 = pow(2, mod - 2, mod)

    def prefix_total(v):
        # v^n sequences with n positions each; the elements 1..v average (v+1)/2.
        return inv2 * n % mod * (v + 1) % mod * pow(v, n, mod) % mod

    def suffix_total(v):
        # (m-v+1)^n sequences with n positions each; the elements v..m average (m+v)/2.
        return inv2 * n % mod * (m + v) % mod * pow(m - v + 1, n, mod) % mod

    ans = 0
    # Carry the neighbouring values between iterations, so each v costs two
    # modular exponentiations instead of four.
    f_prev = prefix_total(0)
    g_cur = suffix_total(1)
    for v in range(1, m + 1):
        f_cur = prefix_total(v)
        g_next = suffix_total(v + 1)
        # v * (total over max == v) - v * (total over min == v)
        ans = (ans + v * ((f_cur - f_prev) - (g_cur - g_next))) % mod
        f_prev, g_cur = f_cur, g_next
    return ans


def main():
    """Read ``n m`` from stdin and print the answer."""
    n, m = map(int, sys.stdin.readline().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()