"""Count subsequences whose consecutive values increase by exactly one.

Reads ``n`` and then ``n`` integers from stdin. Prints, modulo 998244353, the
number of subsequences of length >= 2 in which every element equals the
previous element plus one (e.g. ``1, 2, 3``).
"""

MOD = 998244353


def count_chains(values, mod=MOD):
    """Return the number of +1-step subsequences of length >= 2, modulo *mod*.

    A qualifying subsequence picks indices i1 < i2 < ... < ik (k >= 2) with
    values[i(j+1)] == values[ij] + 1 for every j.

    Args:
        values: Iterable of integers, processed in order.
        mod: Modulus applied to all counts (default 998244353).

    Returns:
        The count of qualifying subsequences reduced modulo *mod*.

    >>> count_chains([1, 2, 3])
    3
    """
    # extend_into[v] = number of chains (length >= 1) seen so far that end on
    # value v - 1, i.e. chains that the next occurrence of v can extend.
    extend_into = {}
    total = 0
    for v in values:
        # Chains of length >= 2 ending at this element: extend every earlier
        # chain that ended on v - 1.
        ending_here = extend_into.get(v, 0)
        total = (total + ending_here) % mod
        # This element terminates `ending_here` longer chains plus the
        # singleton chain [v]; all of them can be extended by a later v + 1.
        extend_into[v + 1] = (extend_into.get(v + 1, 0) + ending_here + 1) % mod
    return total


def main():
    """Read n and the array from stdin, then print the chain count."""
    n = int(input())
    values = list(map(int, input().split()))
    if len(values) < n:
        raise ValueError(f"expected {n} integers, got {len(values)}")
    # Like the original, only the first n numbers are considered.
    print(count_chains(values[:n]))


if __name__ == "__main__":
    main()