#include #include #include using u32 = unsigned int; using u64 = unsigned long long; using i64 = long long; static u32 mod, r, n2_; i64 exgcd(i64 a, i64 b, i64 &x, i64 &y) { i64 d = a; if (b == 0) x = 1, y = 0; else d = exgcd(b, a % b, y, x), y -= a / b * x; return d; } inline u32 mul(u32 x, u32 y) { unsigned ret = (1ull * x * y + 1ull * (u32(1ull * x * y) * r) * mod) >> 32; return ret < mod ? ret : ret - mod; } inline u32 add(u32 a, u32 b) { u32 res = a + b; return res >= mod ? res - mod : res; } struct u32x8 { u32 v[8]; u32x8() = default; u32x8(u32 val) { for (int i = 0; i < 8; i++) v[i] = val; } u32x8 operator+(const u32x8& other) const { u32x8 result; for (int i = 0; i < 8; i++) result.v[i] = add(v[i], other.v[i]); return result; } u32x8 operator*(const u32x8& other) const { u32x8 result; for (int i = 0; i < 8; i++) result.v[i] = mul(v[i], other.v[i]); return result; } u32x8& operator+=(const u32x8& other) { for (int i = 0; i < 8; i++) v[i] = add(v[i], other.v[i]); return *this; } u32x8& operator*=(const u32x8& other) { for (int i = 0; i < 8; i++) v[i] = mul(v[i], other.v[i]); return *this; } }; inline u32 mon_in(u32 x) { return mul(x, n2_); } inline u32 mon_out(u32 x) { u32 ret = ((x + (u64)((u32)x * r) * mod) >> 32); return ret < mod ? ret : ret - mod; } inline u32 qpow(u32 n, u32 m, u32 p) { if (!m) return 1; u32 ret = qpow(n, m >> 1, p); ret = (u64)ret * ret % p; if (m & 1) return (u64)ret * n % p; else return ret; } u32 solve(int N, u32 p) { const int mv = 8; int n = N - (N & (256 * (1 << mv) - 1)); mod = p; n2_ = -(u64)mod % mod; i64 x, y; exgcd(mod, 1ll << 32, x, y); r = -u32(x); u32 as = mon_in(1); u32 as2 = mon_in(1); u32x8 ans[8]; u32x8 ml[8]; u32x8 ad_val(mon_in(64)); for (int i = 0; i < 8; i++) ans[i] = u32x8(mon_in(1)); for (int i = 0; i < 8; i++) for (int j = 0; j < 8; j++) ml[i].v[j] = mon_in(i * 8 + j + 1); for (unsigned i = 1; i + 63 <= (n >> mv); i += 64) for (int j = 0; j < 8; j++) { ans[j] *= ml[j]; ml[j] += ad_val; } for (int i = 0; i < 8; i++) { u32 odd_prod = mon_in(1); u32 even_prod = mon_in(1); for (int j = 0; j < 8; j++) if (j & 1) odd_prod = mul(odd_prod, ans[i].v[j]); else even_prod = mul(even_prod, ans[i].v[j]); as = mul(as, odd_prod); as2 = mul(as2, even_prod); } for (int j = mv - 1; j >= 0; j--) { as = mul(as, as2); as = mul(as, mon_in(qpow(2, n >> (j + 1), p))); u32x8 inner_ans[8]; u32x8 inner_ml[8]; u32x8 inner_ad_val(mon_in(128)); const unsigned add_ = n >> (j + 1); for (int i = 0; i < 8; i++) inner_ans[i] = u32x8(mon_in(1)); for (int i = 0; i < 8; i++) for (int k = 0; k < 8; k++) inner_ml[i].v[k] = mon_in(add_ + i * 16 + k * 2 + 1); for (unsigned i = add_; i + 127 <= (n >> j); i += 128) for (int k = 0; k < 8; k++) { inner_ans[k] *= inner_ml[k]; inner_ml[k] += inner_ad_val; } for (int i = 0; i < 8; i++) { u32 prod0 = mul(mul(inner_ans[i].v[0], inner_ans[i].v[1]), mul(inner_ans[i].v[2], inner_ans[i].v[3])); u32 prod1 = mul(mul(inner_ans[i].v[4], inner_ans[i].v[5]), mul(inner_ans[i].v[6], inner_ans[i].v[7])); as2 = mul(as2, mul(prod0, prod1)); } } as = mul(as, as2); as = mon_out(as); for (int i = n + 1; i <= N; i++) as = (u64)as * i % p; return as; } int main() { int n, k; scanf("%d%d", &n, &k); int p = 998244353; n %= p; u32 a = solve(n, p); u32 b = solve(n-k, p); u32 c = solve(k, p); u32 invb = qpow(b, p-2, p); u32 invc = qpow(c, p-2, p); i64 ans = (i64)a*invb%p*invc%p; printf("%lld\n", ans); return 0; }