#!/usr/bin/env pypy3

from collections import Counter, defaultdict, deque
import itertools
import re
import math
from functools import reduce
import operator
import bisect
from heapq import *
import functools
mod = 998244353  # prime modulus for all arithmetic below

import sys
input = sys.stdin.readline  # faster line reads than the builtin input()
def nCk(n, k):
    """Return the binomial coefficient C(n, k) reduced modulo ``mod``.

    Any ``k`` outside ``0..n`` gives 0. The value is computed as the
    falling factorial n * (n-1) * ... * (n-k+1) times the modular inverse
    of k!. The inverse is obtained via Fermat's little theorem
    (``pow(x, mod - 2, mod)``), which is valid only because ``mod`` is prime.

    NOTE: once min(k, n - k) >= mod, k! is 0 modulo ``mod``. Its "inverse"
    is then 0, so the function returns 0 even where the true residue (by
    Lucas' theorem) is nonzero. This only matters for n >= 2 * mod.
    """
    if not 0 <= k <= n:
        return 0
    # Symmetry C(n, k) == C(n, n - k): loop over the shorter side.
    k = min(k, n - k)
    numerator = 1
    for factor in range(n - k + 1, n + 1):
        numerator = numerator * factor % mod
    denominator = 1  # accumulates k! modulo mod
    for factor in range(2, k + 1):
        denominator = denominator * factor % mod
    return numerator * pow(denominator, mod - 2, mod) % mod
 


# Read the integers a and b from one line of stdin and print
# C(a + b - 2, a - 1) modulo `mod`.
# Presumably this counts monotone paths across an a x b grid; confirm
# against the problem statement.
a, b = (int(token) for token in input().split())
print(nCk(a + b - 2, a - 1))