# Prime modulus: the modular inverse of x is x ** (mod - 2) (Fermat's little theorem).
mod = 998244353

# First line: the count n.
n = int(input())
# Second line: integers whose total becomes c1.
# NOTE(review): presumably n flags in {0, 1}, so c1 counts ones and c0 zeros — confirm.
c1 = sum(int(tok) for tok in input().split())
c0 = n - c1

# Normalize the pair so that c0 holds the smaller value and c1 the larger.
c0, c1 = min(c0, c1), max(c0, c1)

# Output c0 * c1 / (n * c1) modulo mod. The division is done via the Fermat inverse.
# When n * c1 is 0 mod `mod` (e.g. n == 0), pow() returns 0 and so does the result.
numerator = c0 * c1
denominator_inv = pow(n * c1, mod - 2, mod)
print(numerator * denominator_inv % mod)