"""Compute a balanced multinomial coefficient modulo 998244353.

Input (stdin): two integers ``n m``.
Output (stdout): m! / ((q!)^(n-r) * ((q+1)!)^r) mod 998244353,
where ``q, r = divmod(m, n)`` -- i.e. the multinomial coefficient for
``r`` parts of size ``q + 1`` and ``n - r`` parts of size ``q``.
"""
import sys

MOD = 998244353


def _factorial_mod(k: int, mod: int) -> int:
    """Return k! modulo *mod* (1 when k is 0 or 1)."""
    result = 1
    for i in range(2, k + 1):
        result = result * i % mod
    return result


def count_balanced_assignments(n: int, m: int, mod: int = MOD) -> int:
    """Return m! / (q!^(n-r) * (q+1)!^r) modulo *mod*, with q, r = divmod(m, n).

    The denominator covers ``r`` parts of size ``q + 1`` and ``n - r`` parts
    of size ``q``; those sizes sum to exactly ``m``.

    Args:
        n: number of parts; must be positive.
        m: total size being split; must be non-negative.
        mod: modulus. Must be prime, because the division is done with a
            Fermat inverse; results are only meaningful when ``m < mod``.

    Returns:
        The coefficient reduced modulo *mod*.

    Raises:
        ValueError: if ``n < 1`` or ``m < 0``.
    """
    if n < 1:
        raise ValueError(f"n must be positive, got {n}")
    if m < 0:
        raise ValueError(f"m must be non-negative, got {m}")

    q, r = divmod(m, n)
    fact_q = _factorial_mod(q, mod)
    fact_q_plus_1 = fact_q * (q + 1) % mod  # (q+1)! derived from q!

    numerator = _factorial_mod(m, mod)
    denominator = pow(fact_q, n - r, mod) * pow(fact_q_plus_1, r, mod) % mod
    # mod is prime, so denominator^(mod-2) is the modular inverse (Fermat).
    return numerator * pow(denominator, mod - 2, mod) % mod


def solve(data: str) -> int:
    """Parse ``"n m"`` from *data* and return the answer.

    Whitespace-splitting tolerates a missing trailing newline, extra spaces,
    and the two numbers appearing on separate lines.
    """
    n, m = map(int, data.split()[:2])
    return count_balanced_assignments(n, m)


def main() -> None:
    """Read the problem input from stdin and print the answer."""
    print(solve(sys.stdin.read()))


if __name__ == "__main__":
    main()