"""For each test case ``n k``, print (2^n + 3^n + ... + (k+1)^n) mod 998244353.

Input format: the first token is the number of test cases ``t``, followed by
``t`` pairs of integers ``n k``. Output: one result per line, in input order.
"""
import sys

mod = 998244353


def power_sum(n, k, modulus=mod):
    """Return ``(sum of i**n for i in 2..k+1) % modulus``.

    Each term uses three-argument ``pow`` so very large exponents stay cheap.
    For ``k <= 0`` the range is empty and the result is 0.
    """
    return sum(pow(i, n, modulus) for i in range(2, k + 2)) % modulus


def main(stdin=None, stdout=None):
    """Read all test cases from *stdin* and write one answer per line to *stdout*.

    The streams default to ``sys.stdin`` / ``sys.stdout``; they are parameters
    so the driver can be exercised without touching the real console.
    """
    stdin = sys.stdin if stdin is None else stdin
    stdout = sys.stdout if stdout is None else stdout

    # Read everything at once: per-line input() is slow for many test cases.
    tokens = iter(stdin.read().split())
    first = next(tokens, None)
    if first is None:  # empty input: nothing to answer
        return

    results = []
    for _ in range(int(first)):
        n = int(next(tokens))
        k = int(next(tokens))
        results.append(str(power_sum(n, k)))

    # A single write instead of one print per case; t == 0 prints nothing.
    if results:
        stdout.write("\n".join(results) + "\n")


if __name__ == "__main__":
    main()