from collections import defaultdict
from bisect import bisect_left
import sys

MOD = 998244353


def solve(values, mod=MOD):
    """Return the modular product described below for the sequence *values*.

    Distinct values are visited in ascending order while tracking ``last``,
    the index of the final occurrence of the most recently processed value
    (initially -1).  For each value ``v``:

    * only the occurrences of ``v`` at index >= ``last`` are considered; if
      there are none, ``v`` is skipped and ``last`` is left unchanged;
    * unless ``v`` is the first value processed, the product is multiplied by
      ``(first considered index - last + 1)``;
    * for each pair of consecutive considered occurrences ``p < q`` the
      product is multiplied by ``(q - p + 1)``;
    * ``last`` becomes the index of ``v``'s final occurrence.

    NOTE(review): the original problem statement is not visible here; this
    docstring describes what the code computes, not why.

    Args:
        values: sequence of hashable, orderable items (ints when read by
            ``main``).
        mod: modulus applied after every multiplication.

    Returns:
        The product modulo *mod* (``1`` when nothing is multiplied, e.g. for
        an empty sequence).
    """
    # Group indices by value; indices are appended in increasing order, so
    # each list is sorted and suitable for bisect.
    positions = defaultdict(list)
    for i, v in enumerate(values):
        positions[v].append(i)

    ans = 1
    last = -1
    # Iterating distinct values is equivalent to iterating sorted(values):
    # a repeated value would find start == len(idxs) - 1 and multiply by 1.
    for v in sorted(positions):
        idxs = positions[v]
        start = bisect_left(idxs, last)
        if start == len(idxs):
            # Every occurrence of v lies before `last`; skip it entirely.
            continue
        if last != -1:
            # Gap from the previous value's last occurrence to v's first
            # occurrence at or after it.
            ans = ans * (idxs[start] - last + 1) % mod
        tail = idxs[start:]
        for p, q in zip(tail, tail[1:]):
            ans = ans * (q - p + 1) % mod
        last = idxs[-1]
    return ans


def main():
    """Read ``n`` and the sequence from stdin and print ``solve``'s result."""
    readline = sys.stdin.readline
    n = int(readline())
    values = list(map(int, readline().split()))
    print(solve(values[:n]))


if __name__ == "__main__":
    main()