"""Read two integers ``n`` and ``m`` from stdin and print an answer.

The script works with the modulus 998244353 and prints ``1`` whenever
``n > m`` or ``n == 1``.

NOTE(review): this file arrived with every statement fused onto one physical
line (a syntax error) and with part of the loop body missing.  The statement
boundaries have been restored; the missing part has NOT been guessed -- see
``solve``.
"""
import sys

MOD = 998244353


def read_line():
    """Return the next line of standard input without trailing whitespace."""
    return sys.stdin.readline().rstrip()


def solve(n, m):
    """Return the answer for the pair ``(n, m)``.

    Args:
        n: first integer of the input line.
        m: second integer of the input line.

    Returns:
        ``1`` when ``n > m`` or ``n == 1``.

    Raises:
        NotImplementedError: for every other input, because the code that
            filled the length-``m`` table did not survive in the source.
    """
    if n > m or n == 1:
        return 1

    # NOTE(review): the general case is incomplete.  What survives of it is:
    #
    #     A=[0]*m
    #     for i in range(m):
    #         if imod:
    #             A[i]%=mod
    #     print(A[-1])
    #
    # ``imod`` is not defined anywhere and ``A`` is never given a non-zero
    # value, so text appears to be missing inside ``if imod:`` (presumably
    # the per-index update of ``A[i]`` followed by a comparison against
    # ``mod`` -- TODO confirm against the original script).  Failing loudly
    # here is deliberate: running the surviving fragment as-is would either
    # crash on ``imod`` or silently print 0.
    raise NotImplementedError(
        "table recurrence for n <= m (n != 1) is missing from the source; "
        "restore it from the original script"
    )


def main():
    """Parse ``n m`` from the first input line and print the answer."""
    n, m = map(int, read_line().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()