"""Union-find component-sum product.

Reads from stdin: "n m", then n integers a_1..a_n, then m pairs "u v"
(1-based vertex indices). Joins the endpoints of every pair, then prints
the product, over every vertex i, of the sum of `a` over i's connected
component, all modulo M.
"""
import sys

M = 998244353


def root(x, parent):
    """Return the representative of x's set, compressing the path to it.

    Iterative rather than recursive, so long parent chains cannot hit
    Python's recursion limit. Every vertex on the walked path is re-pointed
    directly at the representative.
    """
    path = [x]
    while parent[path[-1]] != path[-1]:
        path.append(parent[path[-1]])
    top = path[-1]
    for p in path:
        parent[p] = top
    return top


def union(x, y, parent):
    """Merge the sets containing x and y.

    The smaller of the two root indices becomes the root of the merged set.
    Does nothing if x and y are already in the same set.
    """
    rx = root(x, parent)
    ry = root(y, parent)
    if rx == ry:
        return
    if rx > ry:
        rx, ry = ry, rx
    parent[ry] = rx


def solve(n, a, edges):
    """Return prod over vertices i of (sum of a over i's component) mod M.

    Args:
        n: number of vertices, indexed 0..n-1.
        a: sequence of n integer weights.
        edges: iterable of (u, v) pairs with 0-based vertex indices.
    """
    parent = list(range(n))
    for u, v in edges:
        union(u, v, parent)

    # Accumulate each component's weight sum on its representative.
    comp_sum = [0] * n
    for i in range(n):
        ri = root(i, parent)
        comp_sum[ri] = (comp_sum[ri] + a[i]) % M

    # One factor per vertex (not per component), matching the original loop.
    result = 1
    for i in range(n):
        result = result * comp_sum[root(i, parent)] % M
    return result


def main():
    """Parse stdin in one read (fast for large m) and print the answer."""
    data = sys.stdin.buffer.read().split()
    n, m = int(data[0]), int(data[1])
    a = [int(t) for t in data[2:2 + n]]
    pairs = iter(data[2 + n:2 + n + 2 * m])
    # zip over the same iterator pairs consecutive tokens; convert to 0-based.
    edges = [(int(u) - 1, int(v) - 1) for u, v in zip(pairs, pairs)]
    print(solve(n, a, edges))


if __name__ == "__main__":
    main()