n,k=map(int,input().split());mod=998244353 a=list(map(int,input().split())) b=sum(a)%mod c=pow(2,k,mod) print(b*c%mod)