""" 1887 まぁ、ない場合を計算すればヨシ dp[i][v] = i番目まで見て、i番目に置いたのがvの場合の数…? 同じ数字はまとめておくことにする。 推移を考える。 まず、区間の長さを決めると、その数字は使用不可になる。 また、直前の数字は使用不可となる。 貰うdpで考えるか~ 貰う時、lastindexを指定したとする。 直前の区間長に等しい色の場合に関しては、最初から全部除いておいていい。 自分の選ぶ色に関しては、 色cを選んだとき、 直前同じ色の時だけ、さらに除く。 ただし、あらかじめ除かれている場合はもう気にしなくてよい。 dp[i][c] に推移できないのは dp[i-c以下][?] か dp[?][c] である。 indsumは大きいほうから累積和しておく。 dp[?][c] の方は… 2 3 3*3 = 9通り 23 32 11 33 → 11 がまずいのか """ import sys from sys import stdin def ss(x1,y1,x2,y2): if x1 > x2 or y1 > y2: return 0 x1 = max(x1,0) ret = dp[x2][y2] if x1-1 >= 0: ret -= dp[x1-1][y2] if y1-1 >= 0: ret -= dp[x2][y1-1] if x1-1 >= 0 and y1-1 >= 0: ret += dp[x1-1][y1-1] return ret N,M = map(int,stdin.readline().split()) dp = [[0] * (M+1) for i in range(N+1)] dp[0] = [1] * (M+1) mod = 998244353 for i in range(1,N+1): for c in range(1,M+1): dp[i][c] = ( ss(i-c+1,0,i-1,M) - ss(i-c+1,c,i-1,c) ) % mod for c in range(M): dp[i][c+1] += dp[i][c] dp[i][c+1] %= mod for c in range(M+1): dp[i][c] += dp[i-1][c] dp[i][c] %= mod #print (dp) ansrev = (ss(N,0,N,M) % mod) ans = pow(M,N,mod) - ansrev print (ans % mod)