"""Combinatorial counting script (answers modulo 998244353).

Reads two integers ``n`` and ``c0`` from one line of stdin and prints
``solve(n, c0)``. Inside ``solve``, ``c1 = n - c0`` is the complementary
count (presumably the numbers of 0s and 1s in a length-``n`` sequence --
NOTE(review): confirm against the problem statement).
"""

MOD = 998244353  # prime, so modular inverses exist via Fermat's little theorem


def _make_comb(top):
    """Return a function computing ``C(a, b) mod MOD`` for ``a <= top``.

    Factorials and inverse factorials are precomputed once up to ``top``,
    so each query is O(1). Arguments outside ``0 <= b <= a <= top`` give 0.
    """
    fact = [1] * (top + 1)
    for i in range(1, top + 1):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * (top + 1)
    # x^(MOD-2) is x's inverse modulo the prime MOD; walk downwards using
    # 1/(i-1)! = i * (1/i!).
    inv_fact[top] = pow(fact[top], MOD - 2, MOD)
    for i in range(top, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % MOD

    def comb(a, b):
        if not 0 <= b <= a <= top:
            return 0
        return fact[a] * inv_fact[b] % MOD * inv_fact[a - b] % MOD

    return comb


def solve(n, c0):
    """Return the script's answer for ``n`` and ``c0``, reduced modulo MOD.

    Cases:
      * ``c0 <= 0``: 1.
      * ``c0 == 1``: 2 when ``n`` is even, otherwise ``max(n - 2, 0)``.
      * ``c0 >= 2``: with ``c1 = n - c0``, the sum over ``0 <= i <= c1``
        with ``i`` of the same parity as ``c1`` of
        ``C(c0 - 2 + i, i) * (c1 - i + 1)``. This is 0 when ``c1 < 0``.
    """
    if c0 <= 0:
        return 1
    if c0 == 1:
        # Counting i in [1, n-2] with i and n-1-i of equal parity gives every
        # such i when n is odd and none when n is even; the even case instead
        # contributes a flat 2.
        return (2 if n % 2 == 0 else max(n - 2, 0)) % MOD

    c1 = n - c0
    # The largest binomial queried is C(c0 - 2 + c1, c1) = C(n - 2, c1).
    comb = _make_comb(max(n - 2, 0))
    total = 0
    # Weight (c1 - i + 1) = (c1 - i - 1) + 2 combines the two per-index
    # contributions, and at i == c1 it equals 1 (the lone C(n-2, c1) term).
    for i in range(c1 % 2, c1 + 1, 2):
        total = (total + comb(c0 - 2 + i, i) * (c1 - i + 1)) % MOD
    return total


def main():
    """Read ``n c0`` from stdin and print the answer."""
    n, c0 = map(int, input().split())
    print(solve(n, c0))


if __name__ == "__main__":
    main()