n,k=map(int,input().split()) l=list(map(int,input().split())) print(sum(l)*pow(2,k,998244353))