from collections import defaultdict

MOD = 998244353


def solve(values, mod=MOD):
    """Return the accumulated run total over the distinct values of *values*.

    The distinct values are visited from largest to smallest. For each
    adjacent pair ``(v, w)`` in that order:

    * if ``w == v - 1`` the multiplicity of ``v`` is added to a running
      sum;
    * otherwise the running sum is reset to zero.

    After every pair the running sum is added to the result. Both the
    running sum and the result are reduced modulo *mod*.

    Args:
        values: Iterable of integers; duplicates are allowed.
        mod: Modulus applied to the running sum and to the result.

    Returns:
        The accumulated total, in ``range(mod)``. Inputs with fewer than
        two distinct values yield ``0`` because there is no adjacent pair.
    """
    counts = defaultdict(int)
    for value in values:
        counts[value] += 1

    # Keys are distinct, so sorting the (value, count) pairs in reverse
    # orders them strictly by value, largest first.
    ordered = sorted(counts.items(), reverse=True)

    run = 0
    total = 0
    for (value, count), (next_value, _) in zip(ordered, ordered[1:]):
        if value - 1 == next_value:
            run = (run + count) % mod
        else:
            # A gap between neighbouring distinct values breaks the run.
            run = 0
        total = (total + run) % mod
    return total


def main():
    """Read ``n`` and a line of integers from stdin, then print the result."""
    n = int(input())
    values = list(map(int, input().split()))
    # Only the first n values take part in the computation.
    print(solve(values[:n]))


if __name__ == "__main__":
    main()