#!/usr/bin/env PyPy3
"""Read ``n k`` from stdin and print a closed-form counting sum modulo ``mod``.

For every value ``i`` in ``1..k`` the script adds ``i`` times the number of
length-``n`` sequences over ``1..k`` whose second-largest element (counted
with multiplicity) equals ``i``.  See :func:`solve` for the exact formula.
"""
from collections import Counter, defaultdict, deque
import itertools
import re
import math
from functools import reduce
import operator
import bisect
import heapq
import functools
import sys

# Prime modulus under which the answer is reported.
mod = 998244353


def solve(n: int, k: int) -> int:
    """Return the sum over ``i = 1..k`` of ``i * count(i)`` modulo ``mod``.

    ``count(i)`` is the sum of two disjoint cases:

    * ``n * (k - i) * (i**(n-1) - (i-1)**(n-1))`` -- exactly one of the ``n``
      positions holds a value greater than ``i`` and the remaining ``n - 1``
      positions have maximum exactly ``i``;
    * ``i**n - (i-1)**n - n * (i-1)**(n-1)`` -- the maximum is exactly ``i``
      and it occurs at least twice (all sequences with maximum ``i`` minus
      those in which ``i`` appears exactly once).

    Args:
        n: Sequence length; must be at least 1.
        k: Largest allowed value.  ``k <= 0`` yields ``0`` (empty sum).

    Returns:
        The sum reduced into ``range(mod)``.

    Raises:
        ValueError: If ``n < 1``; the formula needs the non-negative
            exponent ``n - 1``.
    """
    if n < 1:
        raise ValueError("n must be at least 1")

    total = 0
    # Invariant at the top of each iteration: prev_pow == (i-1)**(n-1) % mod.
    # pow(0, 0, mod) is 1, which is the value the formula needs when n == 1.
    prev_pow = pow(0, n - 1, mod)
    for i in range(1, k + 1):
        cur_pow = pow(i, n - 1, mod)  # i**(n-1); reused as prev_pow next round

        # Case 1: exactly one element strictly above i.
        one_above = i * n * (k - i) % mod * (cur_pow - prev_pow) % mod

        # Case 2: maximum is i and appears at least twice.
        # i**n == i * i**(n-1) and (i-1)**n == (i-1) * (i-1)**(n-1), so no
        # further modular exponentiation is required.
        tied_top = (i * cur_pow - (i - 1) * prev_pow - n * prev_pow) % mod

        total = (total + one_above + i * tied_top) % mod
        prev_pow = cur_pow
    return total


def main() -> None:
    """Parse ``n`` and ``k`` from the first stdin line and print the answer."""
    n, k = map(int, sys.stdin.readline().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()