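P = 998244353  # prime modulus

for _ in range(int(input())):  # one iteration per test case
    N, K = map(int, input().split())
    # Sum of k^N for k = 2 .. K+1, each term computed with modular pow, total reduced mod P.
    print(sum(pow(k, N, P) for k in range(2, K + 2)) % P)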