mod = 998244353


def f(x):
    """Return x! modulo ``mod`` (f(0) == 1; any x <= 0 also yields 1).

    The running product is reduced after every multiplication so intermediate
    values stay small instead of growing into huge Python integers.
    """
    ans = 1
    for k in range(1, x + 1):
        ans = ans * k % mod
    return ans


def _count_from_a(n, a, b):
    """Return how many of the first n merged values are taken from ``a``.

    Performs n steps of a standard two-pointer merge over ``a`` and ``b``,
    which must already be sorted ascending. On a tie (``a[i] == b[j]``) the
    element of ``b`` is taken. Presumably each list holds n values (as read
    from input); a shorter list may raise IndexError.
    """
    i = j = 0
    for _ in range(n):
        # i + j < n at every comparison, so with n-element lists neither
        # index can run past the end.
        if a[i] < b[j]:
            i += 1
        else:
            j += 1
    return i


def solve(n, a, b):
    """Return i! * (n - i)! modulo ``mod`` for the given lists.

    ``i`` is the number of the n smallest values (over ``a`` and ``b``
    combined, ties resolved in favour of ``b``) that come from ``a``.
    The caller's lists are not mutated; sorted copies are merged instead.
    """
    i = _count_from_a(n, sorted(a), sorted(b))
    return f(i) * f(n - i) % mod


def main():
    """Read n, then list a, then list b (one line each) and print the answer."""
    n = int(input())
    a = list(map(int, input().split()))
    b = list(map(int, input().split()))
    print(solve(n, a, b))


if __name__ == "__main__":
    main()