mod = 998244353


def _solve(n, k):
    """Return the answer for one ``n k`` query, reduced modulo ``mod``.

    For n >= 2 this is the sum, over every tuple in [1, k]^n, of the tuple's
    second-largest element (a repeated maximum counts as its own second
    largest). For n == 1 both summands vanish and the result is 0.

    The sum is grouped by the value x of the second-largest element:
      * exactly one element is greater than x (n positions, k - x values)
        and the remaining n - 1 elements have maximum exactly x;
      * or the maximum is x and occurs at least twice.
    """
    total = 0
    # Powers of (x - 1) carried over from the previous iteration, so each x
    # needs only two modular exponentiations. At x == 1 these are powers of
    # 0; pow(0, 0, mod) == 1 keeps the n == 1 case consistent.
    below_n = pow(0, n, mod)
    below_n1 = pow(0, n - 1, mod)
    for x in range(1, k + 1):
        at_n = pow(x, n, mod)
        at_n1 = pow(x, n - 1, mod)
        # One element above x; the other n - 1 elements have max exactly x.
        single_above = n * (k - x) % mod * (at_n1 - below_n1) % mod
        # Max is x and appears at least twice:
        # (#tuples with max x) - (#tuples with exactly one x and the rest < x).
        repeated_max = (at_n - below_n - n * below_n1) % mod
        # Python's % keeps the differences above non-negative.
        total = (total + x * (single_above + repeated_max)) % mod
        below_n, below_n1 = at_n, at_n1
    return total


def main():
    """Read ``N K`` from the first line of stdin and print the answer."""
    import sys

    n, k = map(int, sys.stdin.readline().split())
    print(_solve(n, k))


if __name__ == '__main__':
    main()