N,K=map(int,input().split())
A=list(map(int,input().split()))
mod=998244353
if K==0:
  print(sum(A)%mod)
  exit()
P=sum(A)*pow(N,mod-2,mod)%mod
print((sum(A)+P*(pow(2,K,mod)-1)*N)%mod)