"""Read integers n and m and print C(n, m-1) * (2**(n-m+1) - 1) mod 998244353.

Prints 0 when n < m.
"""

import sys

MOD = 998244353


def solve(n: int, m: int, mod: int = MOD) -> int:
    """Return C(n, m-1) * (2**(n-m+1) - 1) reduced modulo *mod*.

    Returns 0 when n < m. For m < 1 the binomial product is empty, so the
    binomial factor is 1 (this mirrors the original loop and is not the
    mathematical C(n, m-1)).

    *mod* is assumed prime so that Fermat inversion is valid. NOTE(review):
    if m - 1 >= mod, the denominator (m-1)! is 0 mod *mod* and the binomial
    factor silently becomes 0. Confirm the problem constraints rule this out.
    """
    if n < m:
        return 0
    # C(n, k) = n*(n-1)*...*(n-k+1) / k!, with k = m - 1. Accumulate the
    # numerator and denominator separately and invert once, instead of
    # doing one modular inverse per loop step.
    num = den = 1
    for i in range(m - 1):
        num = num * (n - i) % mod
        den = den * (i + 1) % mod
    binom = num * pow(den, mod - 2, mod) % mod
    # 2**(n-m+1) - 1. The final % keeps the result in [0, mod).
    return binom * (pow(2, n - m + 1, mod) - 1) % mod


def main() -> None:
    """Read n and m (whitespace-separated, any line layout) and print the answer."""
    tokens = sys.stdin.read().split()
    n, m = int(tokens[0]), int(tokens[1])
    print(solve(n, m))


if __name__ == "__main__":
    main()