n=int(input()) a=list(map(int,input().split())) q=[0]*n M=998244353 d={} c={} for i in range(n): if a[i]-1 in d: q[i]=(d[a[i]-1]+c[a[i]-1])%M if a[i] not in d: d[a[i]]=0 c[a[i]]=0 d[a[i]]+=q[i] c[a[i]]+=1 print(sum(q)%M)