import sys input = sys.stdin.readline mod=998244353 FACT=[1] for i in range(1,5*10**6+1): FACT.append(FACT[-1]*i%mod) FACT_INV=[pow(FACT[-1],mod-2,mod)] for i in range(5*10**6,0,-1): FACT_INV.append(FACT_INV[-1]*i%mod) FACT_INV.reverse() def Combi(a,b): if 0<=b<=a: return FACT[a]*FACT_INV[b]%mod*FACT_INV[a-b]%mod else: return 0 N,M=map(int,input().split()) ANS=0 ALL=Combi(N+M,N)-N y=M k=0 while True: y-=N k+=1 if y<0: break ANS+=Combi(N+y,y) if N==1: ANS-=M print(ANS*pow(ALL,mod-2,mod)%mod)