"""Print min(#zeros, #ones) / n modulo 998244353 for a sequence read from stdin."""

MOD = 998244353  # prime modulus, so inverses come from Fermat's little theorem


def solve(n: int, ones: int) -> int:
    """Return ``min(n - ones, ones) / n`` modulo ``MOD``.

    The original script computed ``c0 * c1 / (n * c1)``, where ``c0 <= c1`` are
    the zero/one counts. The larger count ``c1`` cancels, leaving
    ``min(c0, c1) / n``.
    NOTE(review): the cancelled factor may hint at a different intended
    formula. Confirm against the problem statement.

    Args:
        n: Length of the sequence.
        ones: Sum of the sequence, i.e. the number of 1s when it is 0/1-valued.

    Returns:
        The value as an integer in ``[0, MOD)``. For ``n == 0`` this is 0,
        matching the original script: ``pow(0, MOD - 2, MOD)`` is 0 rather than
        an error.
    """
    zeros = n - ones
    smaller = min(zeros, ones)
    # Modular inverse via Fermat. The three-argument pow with MOD - 2 is used
    # instead of pow(n, -1, MOD) so that n == 0 yields 0 as before, not ValueError.
    return smaller * pow(n, MOD - 2, MOD) % MOD


def main() -> None:
    """Read ``n`` and one line of integers from stdin and print the answer."""
    n = int(input())
    ones = sum(map(int, input().split()))
    print(solve(n, ones))


if __name__ == "__main__":
    main()