結果

問題 No.2409 Strange Werewolves
ユーザー rulerruler
提出日時 2023-08-14 15:19:57
言語 Python3
(3.12.2 + numpy 1.26.4 + scipy 1.12.0)
結果
AC  
実行時間 717 ms / 2,000 ms
コード長 2,709 bytes
コンパイル時間 146 ms
コンパイル使用メモリ 12,928 KB
実行使用メモリ 93,100 KB
最終ジャッジ日時 2024-05-02 03:25:04
合計ジャッジ時間 15,048 ms
ジャッジサーバーID
(参考情報)
judge2 / judge1
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 717 ms
92,572 KB
testcase_01 AC 712 ms
92,552 KB
testcase_02 AC 662 ms
92,684 KB
testcase_03 AC 660 ms
93,044 KB
testcase_04 AC 674 ms
93,096 KB
testcase_05 AC 660 ms
92,528 KB
testcase_06 AC 654 ms
92,664 KB
testcase_07 AC 659 ms
92,612 KB
testcase_08 AC 652 ms
92,748 KB
testcase_09 AC 656 ms
93,096 KB
testcase_10 AC 664 ms
92,584 KB
testcase_11 AC 657 ms
92,688 KB
testcase_12 AC 665 ms
93,100 KB
testcase_13 AC 672 ms
92,836 KB
testcase_14 AC 652 ms
92,564 KB
testcase_15 AC 655 ms
92,928 KB
testcase_16 AC 685 ms
92,768 KB
testcase_17 AC 687 ms
92,680 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

# Interpreter feature flags: each is True only on Python 3 with at least
# the given minor version (3.9 / 3.10 / 3.11).
V = sys.version_info
_39 = V.major == 3 and V.minor >= 9
_310 = V.major == 3 and V.minor >= 10
_311 = V.major == 3 and V.minor >= 11

# Short aliases for the generic container types used in annotations.
# From 3.9 on the builtins themselves are subscriptable; older
# interpreters fall back to the typing module equivalents.
if _39:
    li, tup, dic, st, ty = list, tuple, dict, set, type
else:
    from typing import (
        Dict,
        List,
        Set,
        Tuple,
        Type,
    )

    li, tup, dic, st, ty = List, Tuple, Dict, Set, Type

from sys import stdin

# Binary-mode stdin shortcuts: whole input, one line, all lines.
_raw_in = stdin.buffer
read = _raw_in.read
rl = _raw_in.readline
rls = _raw_in.readlines


def rb():
    """Read one line from stdin and split it on whitespace (bytes tokens)."""
    return rl().split()

from typing import Iterable


def prints(
    a: Iterable[object],
    sep: str = "\n",
) -> None:
    """Print every item of *a* in one ``print`` call, joined by *sep*.

    Each item is converted with ``str`` before joining; the default
    separator puts one item per line.
    """
    text = sep.join(str(item) for item in a)
    print(text)


import numpy as np


def cumprod(
    m: int,
    a: np.ndarray,
) -> np.ndarray:
    """Return the cumulative product of 1-D integer array *a* modulo *m*.

    The data is laid out on a k x k grid (k ~ sqrt(n)) so that only
    O(sqrt(n)) vectorized NumPy operations are needed: first a running
    product along every row, then each row is scaled by the final value
    of the row before it.

    The input array is not modified. The result is an int64 array of
    length ``a.size`` whose entries all lie in ``[0, m)``. Intermediate
    products of two residues must fit in int64, i.e. ``(m - 1) ** 2``
    must be below ``2 ** 63``.
    """
    assert a.ndim == 1
    n = a.size
    k = int(n**0.5 + 1)
    # np.resize returns a new array padded by repeating ``a``; the padding
    # only sits after index n - 1 and is cut off at the end. Widen to int64
    # so that narrower inputs (e.g. int32) cannot overflow in the products.
    a = np.resize(a, (k, k)).astype(np.int64)
    # Reduce once up front so every factor (including the very first
    # element, which is never multiplied) is a residue modulo m.
    a %= m
    # Running product inside each row.
    for i in range(k - 1):
        a[:, i + 1] *= a[:, i]
        a[:, i + 1] %= m
    # Carry the total of all previous rows into the next row.
    for i in range(k - 1):
        a[i + 1] *= a[i, -1]
        a[i + 1] %= m
    return a.ravel()[:n]


def fact(m: int, n: int) -> np.ndarray:
    """Return the table ``[0!, 1!, ..., (n-1)!]`` reduced modulo *m*.

    Built as the modular cumulative product of ``1, 1, 2, ..., n-1``.
    """
    factors = np.arange(n)
    # 0! is 1, so the leading 0 of the range is replaced by 1.
    factors[0] = 1
    table = cumprod(m, factors)
    return table


def tables(
    m: int,
    n: int,
) -> tup[(np.ndarray,) * 3]:
    """Build factorial, inverse-factorial and inverse tables modulo *m*.

    Returns ``(f, fi, inv)``, each of length *n*, where ``f[i]`` is
    ``i!``, ``fi[i]`` is the modular inverse of ``i!`` and ``inv[i]`` is
    the modular inverse of ``i`` (``inv[0]`` is set to 0). Requires
    ``n <= m``.
    """
    assert n <= m
    f = fact(m, n)
    # Start from 1/(n-1)! and multiply by n-1, n-2, ..., 1 in turn: the
    # running product walks down through 1/(n-2)!, ..., 1/0!. Reversing
    # it yields the inverse factorials in index order.
    descending = np.arange(n, 0, -1)
    descending[0] = pow(int(f[-1]), -1, m)
    fi = cumprod(m, descending)[::-1]
    # 1/i == (1/i!) * (i-1)!  for i >= 1.
    inv = np.zeros(n, dtype=fi.dtype)
    inv[1:] = fi[1:] * f[:-1]
    return f, fi, inv % m


class Comb:
    """Permutation / combination counts modulo ``m`` from lookup tables.

    ``f`` holds factorials, ``fi`` inverse factorials and ``inv``
    inverses, all modulo ``m`` and all produced by ``tables(m, n)``.

    The methods are annotated with ``int`` but only use indexing and
    element-wise arithmetic, so array-like integer arguments are accepted
    as well (as noted by the original author).
    """

    m: int
    f: np.ndarray
    fi: np.ndarray
    inv: np.ndarray

    def __init__(
        self,
        m: int,
        n: int,
    ) -> None:
        """Store the modulus and build the three tables of length *n*."""
        self.m = m
        f, fi, inv = tables(m, n)
        self.f = f
        self.fi = fi
        self.inv = inv

    def p(self, n: int, k: int) -> int:
        """Number of k-permutations of n, ``n! / (n-k)!``, modulo ``m``.

        Yields 0 when ``k`` is outside ``[0, n]``. For ``k > n`` the table
        index ``n - k`` is negative and wraps around; that value is
        discarded by the validity mask.
        """
        valid = (k >= 0) & (n >= k)
        falling = self.f[n] * self.fi[n - k] % self.m
        return falling * valid

    def c(self, n: int, k: int) -> int:
        """Binomial coefficient ``C(n, k)`` modulo ``m`` (0 if undefined)."""
        return self.p(n, k) * self.fi[k] % self.m

    def h(self, n: int, k: int) -> int:
        """Multiset coefficient, computed as ``C(n - 1 + k, k)``."""
        top = n - 1 + k
        return self.c(top, k)

    def ip(self, n: int, k: int) -> int:
        """Modular inverse of ``p(n, k)``.

        Where ``k`` is outside ``[0, n]`` the value is filled with 0
        instead of being treated as an error.
        """
        valid = (k >= 0) & (n >= k)
        inv_falling = self.fi[n] * self.f[n - k] % self.m
        return inv_falling * valid

    def ic(self, n: int, k: int) -> int:
        """Modular inverse of ``c(n, k)`` (0 where ``c`` is undefined)."""
        return self.ip(n, k) * self.f[k] % self.m


def solve() -> None:
    """Read ``x y z w`` from one stdin line and print the answer.

    After raising ``z`` and ``w`` to at least 1, the printed value is
    ``(x + y - z - w)! * C(x, z) * C(y, w)`` modulo 998244353.
    """
    mod = 998_244_353
    x, y, z, w = (int(tok) for tok in rb())
    # Tables cover arguments 0 .. 2**20 - 1.
    comb = Comb(mod, 1 << 20)
    z, w = max(z, 1), max(w, 1)
    # Reduce after every multiplication so each product stays a product
    # of two residues.
    ans = comb.f[x + y - z - w] * comb.c(x, z) % mod
    ans = ans * comb.c(y, w) % mod
    print(ans)


def main() -> None:
    """Run ``solve`` once per test case (the input holds a single case)."""
    cases = 1
    # cases = int(rl())  # enable for multi-case input
    remaining = cases
    while remaining > 0:
        solve()
        remaining -= 1


# Only run when executed as a script, so importing this module (e.g. to
# reuse the helpers) does not block on reading stdin.
if __name__ == "__main__":
    main()
0