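#yuki1186
import sys

n, m = map(int, input().split())
mod = 998244353  # modulus (998244353 is prime)

# Trivial case: the answer is 1 when n > m or n == 1.
if n > m or n == 1:
    print(1)
    sys.exit()

l = [0] * m  # table with one entry per value 0..m-1
for i in range(m):
    if i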