mod = 998_244_353  # prime modulus; the printed answer is reduced modulo this value


def second_largest_sum(n, k, modulus):
    """Return, modulo *modulus*, the sum over all length-*n* sequences with
    entries in ``1..k`` of the sequence's second-largest element.

    The second-largest element is counted with multiplicity, so a sequence
    whose maximum occurs at least twice contributes that maximum. For each
    value ``x``, the sequences whose second-largest element equals ``x`` are:

    * sequences whose maximum ``x`` occurs at least twice:
      ``x**n - (x-1)**n - n*(x-1)**(n-1)`` of them;
    * sequences with a single largest entry ``y > x`` (``k - x`` values,
      ``n`` positions), where the other ``n - 1`` entries have maximum exactly
      ``x``: ``n * (k - x) * (x**(n-1) - (x-1)**(n-1))`` of them.

    Powers are carried between iterations (``x**e`` at step ``x`` is
    ``(x-1)**e`` at step ``x + 1``). This way each step needs only one modular
    exponentiation.

    Args:
        n: sequence length; assumed ``>= 1``. For ``n == 1`` the result is 0.
        k: largest allowed entry. ``k < 1`` gives 0.
        modulus: positive modulus applied to the result.

    Returns:
        The sum described above, in ``range(modulus)``.
    """
    if k < 1:
        return 0  # empty value range: nothing to sum

    # Values of (x - 1)**(n - 1) and (x - 1)**n for the first step, x = 1.
    prev_pow_n1 = pow(0, n - 1, modulus)
    prev_pow_n = pow(0, n, modulus)
    total = 0
    for x in range(1, k + 1):
        pow_n1 = pow(x, n - 1, modulus)  # x**(n-1)
        pow_n = pow_n1 * x % modulus  # x**n
        # (n-1)-sequences whose maximum is exactly x.
        max_exactly_x = pow_n1 - prev_pow_n1
        # One unique maximum above x; the remaining n-1 entries peak at x.
        below_unique_max = n * (k - x) * max_exactly_x
        # Maximum x appears at least twice: all n-sequences with maximum x
        # minus those containing exactly one x.
        repeated_max = pow_n - prev_pow_n - n * prev_pow_n1
        # Python's % is non-negative for a positive modulus, so negative
        # intermediate differences are normalised here.
        total = (total + x * (below_unique_max + repeated_max)) % modulus
        prev_pow_n1, prev_pow_n = pow_n1, pow_n
    return total


def main():
    """Read ``N K`` from stdin and print ``second_largest_sum(N, K, mod)``."""
    import sys

    n, k = map(int, sys.stdin.readline().split())
    print(second_largest_sum(n, k, mod))


if __name__ == '__main__':
    # Script entry point: solve the single test case read from stdin.
    # Importing this module only defines names and runs nothing.
    main()