def pow(x, n, mod=None):
    """Return ``x`` raised to the non-negative integer power ``n``.

    Uses binary exponentiation (square-and-multiply), so only O(log n)
    multiplications are performed.

    NOTE: this intentionally keeps the original name, which shadows the
    builtin ``pow`` inside this module.

    Args:
        x: Integer base.
        n: Integer exponent; must be >= 0.
        mod: Optional modulus. When given, every intermediate product is
            reduced modulo ``mod`` so the operands never grow beyond
            ``mod**2``; the result equals ``(x ** n) % mod``. When omitted,
            the exact (unreduced) power is returned.

    Returns:
        ``x ** n`` if ``mod`` is None, otherwise ``(x ** n) % mod``.

    Raises:
        ValueError: If ``n`` is negative.
    """
    if n < 0:
        raise ValueError("exponent must be non-negative")

    modular = mod is not None
    result = 1
    if modular:
        x %= mod
        result %= mod  # makes x**0 % 1 come out as 0 rather than 1

    while n:
        if n & 1:
            # The current bit of the exponent is set: fold this power of x in.
            result *= x
            if modular:
                result %= mod
        n >>= 1
        if n:
            # Square only while bits remain; the last squaring would be unused.
            x *= x
            if modular:
                x %= mod
    return result


# Modulus applied to the printed result.
M = 998244353


if __name__ == "__main__":
    # Input: two whitespace-separated integers "a n" on one line.
    a, n = map(int, input().split())
    # Reduce during exponentiation instead of building the full a**n first.
    print(pow(a, n, M))