"""Read ``n`` and a permutation ``a`` of ``0..n-1`` from stdin; print a count mod 998244353.

NOTE(review): the counting rule (see ``solve``) looks like it counts the
permutations ``b`` that agree with ``a`` on the MEX of every subarray —
confirm against the original problem statement.
"""
from sys import stdin

MOD = 998244353


def solve(n, a, mod=MOD):
    """Return the count computed by this script for permutation *a*, modulo *mod*.

    Values are visited in increasing order while tracking the smallest index
    window ``[lo, hi]`` that contains the positions of every value seen so far
    (it starts at the position of ``0``).

    - A value whose position lies outside the window extends the window and
      contributes a factor of 1.
    - A value ``i`` whose position lies strictly inside the window contributes
      the number of window slots not already taken by ``0..i-1``, which is
      ``(hi - lo + 1) - i``, because all of ``0..i-1`` lie inside the window.

    The result is the product of these factors. This is equivalent to the
    grouped falling-factorial / modular-inverse formulation this replaces, but
    it needs no factorial table and no inverses.

    Args:
        n: Length of the permutation.
        a: Sequence whose first ``n`` entries are a permutation of ``0..n-1``.
        mod: Modulus applied to the result (default 998244353).

    Returns:
        The product described above, reduced modulo *mod*. Returns ``1 % mod``
        when ``n <= 2``.

    Raises:
        ValueError: if fewer than ``n`` values are supplied in *a*.
    """
    if n <= 2:
        return 1 % mod
    if len(a) < n:
        raise ValueError(f"expected {n} values, got {len(a)}")

    # pos[v] is the index at which value v occurs in a.
    pos = [0] * n
    for idx, value in enumerate(a[:n]):
        pos[value] = idx

    lo = hi = pos[0]
    ans = 1
    for value in range(1, n):
        where = pos[value]
        if where < lo:
            lo = where  # extends the window on the left: factor 1
        elif where > hi:
            hi = where  # extends the window on the right: factor 1
        else:
            # Strictly inside the window: 0..value-1 already occupy `value`
            # of its hi - lo + 1 slots, and the rest remain free.
            ans = ans * (hi - lo + 1 - value) % mod
    return ans


def main():
    """Read ``n`` followed by the permutation from stdin and print the answer."""
    n, *a = map(int, stdin.read().split())
    print(solve(n, a))


if __name__ == "__main__":
    main()