"""Count ways to split the elements of a permutation into ``m`` groups.

Reads ``n m`` and a 1-indexed permutation ``p`` of length ``n`` from stdin.
It prints, modulo 998244353, the number of partitions of ``{0, ..., n-1}``
into exactly ``m`` unlabeled non-empty groups such that every element ``i``
lies in a different group from ``p[i]``. Equivalently, this is the number of
surjective proper ``m``-colourings of the permutation's cycles, divided by
``m!``.
"""

MOD = 998244353


def cycle_lengths(p):
    """Return the cycle lengths of the 0-indexed permutation ``p``.

    Cycles are reported in order of their smallest element.
    """
    seen = [False] * len(p)
    lengths = []
    for start in range(len(p)):
        if seen[start]:
            continue
        length = 0
        node = start
        while not seen[node]:
            seen[node] = True
            length += 1
            node = p[node]
        lengths.append(length)
    return lengths


def cycle_colorings(length, k):
    """Return the number of proper ``k``-colourings of a cycle, mod ``MOD``.

    Uses the cycle chromatic polynomial ``(k-1)^L + (-1)^L (k-1)``.
    For ``L == 1`` (a fixed point, i.e. a self-loop) this is 0, and for
    ``L == 2`` it is ``k(k-1)``.
    """
    correction = k - 1 if length % 2 == 0 else -(k - 1)
    return (pow(k - 1, length, MOD) + correction) % MOD


def count_partitions(n, m, p):
    """Return the partition count described in the module docstring.

    Args:
        n: Number of elements (length of ``p``).
        m: Exact number of non-empty, unlabeled groups.
        p: 0-indexed permutation of ``range(n)``.

    Returns:
        The count modulo ``MOD``. It is 0 when ``m > n``, since ``n``
        elements cannot fill ``m`` non-empty groups.
    """
    # Previously the factorial tables were sized by n, so m > n crashed.
    if m > n:
        return 0
    lengths = cycle_lengths(p)

    fact = [1] * (m + 1)
    for i in range(1, m + 1):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * (m + 1)
    inv_fact[m] = pow(fact[m], MOD - 2, MOD)  # Fermat inverse; MOD is prime
    for i in range(m, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % MOD

    # Inclusion-exclusion: surjective colourings with exactly m colours =
    #   sum_k (-1)^(m-k) C(m, k) * (colourings using at most k colours).
    # Terms k = 0 and k = 1 are always 0 here, so they are skipped.
    total = 0
    for k in range(2, m + 1):
        ways = 1
        for length in lengths:
            ways = ways * cycle_colorings(length, k) % MOD
        binom = fact[m] * inv_fact[k] % MOD * inv_fact[m - k] % MOD
        term = ways * binom % MOD
        total += -term if (m - k) % 2 else term

    # Colours are interchangeable, so divide by m! to make groups unlabeled.
    return total % MOD * inv_fact[m] % MOD


def main():
    """Read the problem input from stdin and print the answer."""
    n, m = map(int, input().split())
    p = [int(x) - 1 for x in input().split()]
    print(count_partitions(n, m, p))


if __name__ == "__main__":
    main()