# NOTE(review): `mod` is not referenced by main(); kept in case other code imports it.
mod = 998244353


def main():
    """Read a grid from stdin and print the score of a greedy column selection.

    Input format, as parsed below:
        N M
        V_1 ... V_N
        N lines, each with at least M characters; "o" marks a filled cell.

    Cell (i, j) holds V[i] when row i has "o" in column j, and 0 otherwise.
    The loop then runs exactly M rounds. Each round:
      * picks the column with the largest current sum (the first such column
        on ties),
      * adds that sum squared to the answer,
      * clears every row that has a nonzero value in the chosen column.
    The final answer is printed to stdout.
    """
    import sys

    readline = sys.stdin.readline
    N, M = map(int, readline().split())
    V = list(map(int, readline().split()))

    # V[i] is only read for cells marked "o", matching a lazy per-cell lookup.
    S = []
    for i in range(N):
        T = readline().rstrip('\n')
        S.append([V[i] if T[j] == "o" else 0 for j in range(M)])

    # Column sums are computed once and updated incrementally.
    # A cleared row is subtracted a single time, instead of rescanning the
    # whole N*M grid every round (O(N*M + M*M) instead of O(N*M*M)).
    C = [0] * M
    for row in S:
        for j, x in enumerate(row):
            C[j] += x

    ans = 0
    for _ in range(M):
        j_max = C.index(max(C))  # first index of the maximum, as before
        ans += C[j_max] ** 2
        for row in S:
            if row[j_max]:
                for j, x in enumerate(row):
                    C[j] -= x
                row[:] = [0] * M  # zero the row in place
    print(ans)


if __name__ == '__main__':
    main()