"""Print sum_{k=M}^{N} C(N, k) modulo 998244353.

Input: two integers N and M, whitespace-separated (one line or two).
"""
import sys

MOD = 998244353


def sum_binomial_tail(n, m, mod=MOD):
    """Return sum_{k=m}^{n} C(n, k) modulo the prime ``mod``.

    Computed as 2**n minus the lower partial sum sum_{k=0}^{m-1} C(n, k).
    Since n may exceed ``mod``, each C(n, k) is reduced with Lucas' theorem:
    for 0 <= k < mod, C(n, k) is congruent to C(n % mod, k), which is 0
    whenever k > n % mod. The lower sum therefore only needs the terms
    k = 0..min(m - 1, n % mod).

    Args:
        n: Upper index of the binomial coefficients (non-negative; may be
            much larger than ``mod``).
        m: First k included in the sum.
        mod: Prime modulus (needed for the Fermat inverse and for Lucas).

    Returns:
        The tail sum reduced into ``range(mod)``: 0 when m > n, and
        2**n % mod when m <= 0 (the lower sum is then empty).

    NOTE(review): the Lucas reduction used here only holds for k < mod, so
    the result is correct only when m - 1 < mod; larger m would also need
    the higher base-``mod`` digits of k.
    """
    if m > n:
        return 0
    total = pow(2, n, mod)
    if m <= 0:
        # Empty lower range: nothing to subtract.
        return total

    base = n % mod
    lower = 1  # C(base, 0)
    term = 1
    for k in range(1, min(m - 1, base) + 1):
        # C(base, k) = C(base, k - 1) * (base - k + 1) / k; 1 <= k < mod,
        # so k is invertible via Fermat's little theorem.
        term = term * (base - k + 1) % mod
        term = term * pow(k, mod - 2, mod) % mod
        lower = (lower + term) % mod
    # Python's % already yields a non-negative result.
    return (total - lower) % mod


def main():
    """Read N and M from stdin and print sum_{k=M}^{N} C(N, k) mod MOD."""
    # Tokenize everything so N and M may share a line or sit on separate ones.
    n, m = map(int, sys.stdin.read().split()[:2])
    print(sum_binomial_tail(n, m))


if __name__ == "__main__":
    main()