import sys

# Fast-I/O template helpers from the original script.  The solution below
# does not use them; they are kept only so the module still exposes the same
# names as before.
input = lambda: sys.stdin.readline().rstrip()
write = lambda x: sys.stdout.write(x + "\n")
debug = lambda x: sys.stderr.write(x + "\n")
writef = lambda x: print("{:.12f}".format(x))

# All products are reduced modulo this prime.
M = 998244353


def count_similar_permutations(p) -> int:
    """Return the window-based product for permutation *p*, modulo ``M``.

    Values are visited in increasing order ``0, 1, ..., n-1`` while tracking
    the smallest index window ``[lo, hi]`` that contains the positions of all
    values seen so far.  When the next value ``v`` already lies inside that
    window, the window holds ``hi - lo + 1`` slots of which ``v`` are taken by
    smaller values, so the answer is multiplied by ``(hi - lo + 1) - v``.
    Otherwise the window grows to include the position of ``v`` and the
    answer is unchanged.

    NOTE(review): this looks like the standard solution for counting
    permutations that have the same MEX on every subarray as *p*
    (presumably Codeforces 1699C) - confirm against the task statement.

    Args:
        p: A sequence containing each integer ``0..len(p)-1`` exactly once.

    Returns:
        The product described above, reduced modulo ``M``.  An empty
        sequence yields ``1``.

    Raises:
        ValueError: If *p* is not a permutation of ``0..len(p)-1``.
    """
    n = len(p)

    # position[v] is the index at which value v occurs in p.  -1 marks a
    # value not seen yet, which lets us detect duplicates; the range check
    # stops negative values from silently wrapping around the list.
    position = [-1] * n
    for idx, value in enumerate(p):
        if not 0 <= value < n or position[value] != -1:
            raise ValueError(
                "p must be a permutation of 0..{}".format(n - 1)
            )
        position[value] = idx

    if n == 0:
        return 1

    result = 1
    lo = hi = position[0]
    for value in range(1, n):
        pos = position[value]
        if lo <= pos <= hi:
            # The window is unchanged, so `value` may sit in any slot of the
            # window that is not occupied by one of the smaller values.
            result = result * ((hi - lo + 1) - value) % M
        else:
            lo = min(lo, pos)
            hi = max(hi, pos)
    return result


def main() -> None:
    """Read ``n`` and the permutation from stdin and print the result."""
    n = int(sys.stdin.readline())
    p = list(map(int, sys.stdin.readline().split()))
    if len(p) != n:
        raise ValueError(
            "expected {} values, got {}".format(n, len(p))
        )
    print(count_similar_permutations(p))


if __name__ == "__main__":
    main()