from sys import stdin
# Input: n, followed by the values of the sequence a (all whitespace-separated).
n, *a = map(int, stdin.read().split())
# Prime modulus (998244353); all counts below are reported modulo p, and the
# modular inverses later rely on p being prime (Fermat's little theorem).
p = 998244353
if n <= 2:
    # Answered directly so the code below can assume both values 0 and 1 exist.
    print(1)
    # `exit()` is a site-module helper meant for the interactive shell (and is
    # missing under `python -S`); raising SystemExit is the program-safe form.
    raise SystemExit

def powmod(n,pow,mod):
    """Return ``n ** pow % mod`` using right-to-left binary exponentiation.

    Exponent bits are consumed from least significant upward: whenever the
    current bit is set, the running base is folded into the accumulator, and
    the base is squared every step.  If ``pow <= 0`` the loop never runs and
    the result is exactly 1 (not reduced modulo ``mod``).

    NOTE: the parameter name ``pow`` shadows the builtin inside this function;
    it is kept so that keyword callers continue to work.
    """
    acc, base, exp = 1, n, pow
    while exp > 0:
        if exp & 1:
            acc = acc * base % mod
        base = base * base % mod
        exp >>= 1
    return acc

# kaijo[k] = k! mod p, for 0 <= k < n.
kaijo = [1] * n
for i in range(1, n):
    kaijo[i] = kaijo[i - 1] * i % p

# ichi[v] = index at which value v occurs among the first n entries of a
# (a is presumably a permutation of 0..n-1 — assumed, not checked here).
ichi = [0] * n
for i in range(n):
    ichi[a[i]] = i

# [left, right] is the span of positions occupied by the values 0 and 1.
left, right = min(ichi[0], ichi[1]), max(ichi[0], ichi[1])

# Each entry (gap, v) records a value v whose position extends the current span,
# together with `gap`, the number of cells newly enclosed besides v's own cell.
# The first entry is the cells strictly between the positions of 0 and 1.
anslist = [(right - left - 1, 1)]
for i in range(2, n):
    if ichi[i] < left:
        anslist.append((left - ichi[i] - 1, i))
        left = ichi[i]
    elif ichi[i] > right:
        anslist.append((ichi[i] - right - 1, i))
        right = ichi[i]

last = 0      # most recent span-extending value processed
remained = 0  # enclosed cells not yet assigned to any value
ans = 1
# Loop variable is `gap`, not `range`: rebinding the builtin `range` at module
# scope would break any later use of range() in this file.
for gap, ind in anslist:
    # Values last+1 .. ind-1 did not extend the span, so they lie inside it.
    inside = ind - last - 1
    # Multiply by remained! / (remained - inside)!, the number of ways to give
    # those `inside` values distinct cells among the `remained` free ones.
    ans = ans * kaijo[remained] % p * powmod(kaijo[remained - inside], p - 2, p) % p
    remained += gap - inside
    last = ind

# Values after the last span-extending one fill the remaining cells in any order.
ans = ans * kaijo[remained] % p
print(ans)