from math import comb, perm  # NOTE(review): `perm` is unused in this file.


def count_patterns(n, m, mod=998244353):
    """Return the balanced multinomial coefficient of ``m`` over ``n`` groups, modulo ``mod``.

    With ``q, r = m // n, m % n``, this multiplies binomial coefficients while
    carving ``m`` items into an ordered sequence of ``n`` groups. The first ``r``
    groups take ``q + 1`` items each, and the remaining ``n - r`` groups take ``q``
    items each. The product equals ``m! / ((q+1)!**r * q!**(n-r))``.

    NOTE(review): which ``r`` groups receive the extra item is fixed here, not
    chosen, so there is no ``comb(n, r)`` factor. Confirm this against the
    problem statement.

    Args:
        n: number of groups; must be positive.
        m: number of items to distribute.
        mod: modulus applied to the result (defaults to 998244353).

    Returns:
        The product described above, reduced modulo ``mod``.

    Raises:
        ZeroDivisionError: if ``n`` is 0 (from ``m // n``).
    """
    minimum = m // n
    plus1_number = m % n

    patterns = 1
    current_sum = m
    # Groups that get one extra item: choose minimum+1 of the remaining items.
    for _ in range(plus1_number):
        patterns *= comb(current_sum, minimum + 1)
        current_sum -= minimum + 1
        # Reduce each step so `patterns` stays small; the comb values are exact big ints.
        patterns %= mod
    # Remaining groups: choose `minimum` of the remaining items each.
    for _ in range(n - plus1_number):
        patterns *= comb(current_sum, minimum)
        current_sum -= minimum
        patterns %= mod
    # Sanity check: r*(q+1) + (n-r)*q == m, so every item is assigned.
    assert current_sum == 0
    return patterns


def main():
    """Read ``N M`` from one line of stdin and print ``count_patterns(N, M)``."""
    N, M = map(int, input().split())
    print(count_patterns(N, M))


if __name__ == "__main__":
    main()