"""Print the binomial coefficient C(N, K) mod 998244353 for "N K" read from stdin."""
import sys

MOD = 998244353  # prime; modular inverses below use Fermat's little theorem

# Factorials are checkpointed every STEP integers, so m! (mod MOD) for any
# 0 <= m < MOD needs at most STEP - 1 extra multiplications.
STEP = 10**8

# FACT_CHECKPOINTS[i] is used as (i * STEP)! mod MOD.  The values are
# precomputed offline and are not re-derived in this file.  Ten entries cover
# every m < 10 * STEP, which includes every m < MOD.
FACT_CHECKPOINTS = (
    1,
    808258749,
    117153405,
    761699708,
    573994984,
    62402409,
    511621808,
    242726978,
    887890124,
    875880304,
)


def _factorial_mod(m):
    """Return m! mod MOD for 0 <= m < MOD, starting from the nearest checkpoint."""
    base = m // STEP
    acc = FACT_CHECKPOINTS[base]
    for i in range(base * STEP + 1, m + 1):
        acc = acc * i % MOD
    return acc


def _binom_small(n, k):
    """Return C(n, k) mod MOD for 0 <= k <= n < MOD."""
    k = min(k, n - k)
    # Two routes:
    # - falling factorial n*(n-1)*...*(n-k+1) / k!: about 2k multiplications;
    # - checkpoint route n! / (k! (n-k)!): cost is how far n, k and n-k lie
    #   past their checkpoints.
    # Take whichever is cheaper.
    checkpoint_cost = n % STEP + k % STEP + (n - k) % STEP
    if 2 * k <= checkpoint_cost:
        num = 1
        for i in range(n - k + 1, n + 1):
            num = num * i % MOD
        den = 1
        for i in range(2, k + 1):
            den = den * i % MOD
    else:
        num = _factorial_mod(n)
        den = _factorial_mod(k) * _factorial_mod(n - k) % MOD
    # Every factor is in [1, MOD), so den is nonzero modulo the prime MOD
    # and can be inverted as den**(MOD-2).
    return num * pow(den, MOD - 2, MOD) % MOD


def binom_mod(n, k):
    """Return the binomial coefficient C(n, k) modulo MOD.

    Returns 0 when k < 0 or k > n.  Arguments of any size are reduced with
    Lucas's theorem: C(n, k) is the product of C(n_i, k_i) over the base-MOD
    digits of n and k.  Each factor has arguments below MOD and is computed
    by ``_binom_small``.
    """
    if k < 0 or k > n:
        return 0
    result = 1
    # Once k's remaining digits are all zero, every further factor is C(n_i, 0) = 1.
    while k:
        n, n_digit = divmod(n, MOD)
        k, k_digit = divmod(k, MOD)
        if k_digit > n_digit:
            return 0
        result = result * _binom_small(n_digit, k_digit) % MOD
    return result


def main():
    """Read "N K" from standard input and print C(N, K) mod MOD."""
    n, k = map(int, sys.stdin.readline().split())
    print(binom_mod(n, k))


if __name__ == "__main__":
    main()