"""Print sum_{k=m}^{n} C(n, k) modulo 998244353 for n, m read from stdin.

Input: two lines, n then m.  For m >= 1 the printed value is the number of
subsets of an n-element set that have at least m elements; when m > n the
answer is 0.  NOTE(review): for m == 0 the result is 2**n - 1 (the empty set
is never counted) -- confirm this matches the intended constraints.

n may be very large: C(n, k) is formed as the falling factorial
n * (n-1) * ... * (n-k+1) times (k!)^-1, so factorial inverses are only
tabulated up to about m.
"""
import sys

sys.setrecursionlimit(5 * 10**5)
input = sys.stdin.readline

try:
    # PyPy-only JIT tuning; the module does not exist under CPython, where the
    # tuning is simply skipped instead of crashing the script.
    import pypyjit

    pypyjit.set_param('max_unroll_recursion=-1')
except ImportError:
    pass

from collections import defaultdict, deque, Counter  # template imports (unused here)
from heapq import heappop, heappush
from bisect import bisect_left, bisect_right
from math import gcd

mod = 998244353


def com(n, mod):
    """Return (fact, factinv): k! and (k!)^-1 modulo ``mod`` for k = 0..n.

    ``mod`` must be a prime larger than ``n`` so that every k! is invertible.
    Inverses of 1..n come from the recurrence
    inv[i] = -(mod // i) * inv[mod % i]  (mod ``mod``).
    """
    fact = [1, 1]
    factinv = [1, 1]
    inv = [0, 1]
    for i in range(2, n + 1):
        fact.append(fact[-1] * i % mod)
        inv.append(-inv[mod % i] * (mod // i) % mod)
        factinv.append(factinv[-1] * inv[-1] % mod)
    return fact, factinv


def solve(n, m):
    """Return (2**n - 1 - sum_{k=1}^{m-1} C(n, k)) mod ``mod``; 0 if m > n.

    For m >= 1 this equals sum_{k=m}^{n} C(n, k) mod ``mod``.
    """
    if m > n:
        return 0
    _, factinv = com(m + 10, mod)
    total = (pow(2, n, mod) - 1) % mod
    # Invariant: at the top of iteration k, falling == n*(n-1)*...*(n-k+1),
    # so falling * (k!)^-1 == C(n, k).  The product must be built from n
    # (the set size), not m, or the loop would subtract C(m, k) instead.
    falling = n % mod
    for k in range(1, m):
        total = (total - falling * factinv[k]) % mod
        falling = falling * (n - k) % mod
    return total


def main():
    """Read n and m (one per line) from stdin and print solve(n, m)."""
    n = int(input())
    m = int(input())
    print(solve(n, m))


if __name__ == "__main__":
    main()