N,K=map(int,input().split()) A=list(map(int,input().split())) mod=998244353 if K==0: print(sum(A)%mod) exit() P=sum(A)*pow(N,mod-2,mod)%mod print((sum(A)+P*(pow(2,K,mod)-1)*N)%mod)