結果
問題 | No.2530 Yellow Cards |
ユーザー | StanMarsh |
提出日時 | 2023-11-07 03:10:50 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 1,695 ms / 2,000 ms |
コード長 | 8,803 bytes |
コンパイル時間 | 295 ms |
コンパイル使用メモリ | 82,432 KB |
実行使用メモリ | 265,216 KB |
最終ジャッジ日時 | 2024-09-25 23:14:31 |
合計ジャッジ時間 | 18,575 ms |
ジャッジサーバーID (参考情報) | judge4 / judge1 |
(要ログイン)
テストケース
テストケース表示 / 入力 | 結果 | 実行時間 | 実行使用メモリ |
---|---|---|---|
testcase_00 | AC | 181 ms | 90,368 KB |
testcase_01 | AC | 166 ms | 90,368 KB |
testcase_02 | AC | 178 ms | 91,392 KB |
testcase_03 | AC | 201 ms | 91,520 KB |
testcase_04 | AC | 162 ms | 90,368 KB |
testcase_05 | AC | 176 ms | 91,392 KB |
testcase_06 | AC | 178 ms | 91,392 KB |
testcase_07 | AC | 1,558 ms | 264,960 KB |
testcase_08 | AC | 1,569 ms | 264,576 KB |
testcase_09 | AC | 1,555 ms | 265,088 KB |
testcase_10 | AC | 1,695 ms | 264,576 KB |
testcase_11 | AC | 1,564 ms | 264,704 KB |
testcase_12 | AC | 1,561 ms | 265,216 KB |
testcase_13 | AC | 1,566 ms | 264,960 KB |
testcase_14 | AC | 1,113 ms | 186,752 KB |
testcase_15 | AC | 777 ms | 134,272 KB |
testcase_16 | AC | 782 ms | 148,864 KB |
testcase_17 | AC | 1,204 ms | 207,744 KB |
testcase_18 | AC | 559 ms | 129,792 KB |
testcase_19 | AC | 282 ms | 96,256 KB |
testcase_20 | AC | 320 ms | 93,568 KB |
ソースコード
from random import getrandbits, randrange
from string import ascii_lowercase, ascii_uppercase
import math
import os
import sys
from math import ceil, floor, sqrt, pi, factorial, gcd, log, log10, log2, inf, cos, sin
from copy import deepcopy, copy
from collections import Counter, deque, defaultdict
from heapq import heapify, heappop, heappush
from itertools import (
    accumulate,
    product,
    combinations,
    combinations_with_replacement,
    permutations,
)
from bisect import bisect, bisect_left, bisect_right
from functools import lru_cache, reduce
from decimal import Decimal, getcontext
from types import GeneratorType
from typing import List, Tuple, Optional

# Rebinds the name imported from math to the same value.
inf = float("inf")


def ceil_div(a, b):
    """Return ceil(a / b) for integers a and b (b != 0), for any signs.

    ``-(-a // b)`` is exact for every sign combination; the former
    ``(a + b - 1) // b`` was wrong whenever ``b < 0`` (e.g. 5 / -2).
    """
    return -(-a // b)


def isqrt(num):
    """Return floor(sqrt(num)) for a non-negative number.

    Integers use ``math.isqrt``, which is exact and does not overflow on
    values too large to convert to float. Other numeric types fall back to
    a float estimate that is corrected by stepping up/down.

    Raises:
        ValueError: if ``num`` is negative.
    """
    if isinstance(num, int):
        return math.isqrt(num)
    res = int(sqrt(num))
    while res * res > num:
        res -= 1
    while (res + 1) * (res + 1) <= num:
        res += 1
    return res


def int1(s):
    """Parse ``s`` as an int and convert it from 1-based to 0-based."""
    return int(s) - 1


def bootstrap(f, stack=[]):
    """Decorator that runs a generator-based recursive function iteratively.

    The decorated function ``yield``s recursive calls and ``yield``s its
    final (non-generator) result; an explicit stack replaces the Python call
    stack so deep recursion does not hit the recursion limit.

    ``stack`` is intentionally a shared mutable default: while a top-level
    call is running it is non-empty, which tells nested calls to hand back
    their raw generator instead of driving it.
    """

    def wrapped(*args, **kwargs):
        if stack:
            # Nested call: let the outer driver loop run this generator.
            return f(*args, **kwargs)
        try:
            to = f(*args, **kwargs)
            while True:
                if type(to) is GeneratorType:
                    stack.append(to)
                    to = next(to)
                else:
                    stack.pop()
                    if not stack:
                        break
                    to = stack[-1].send(to)
            return to
        except BaseException:
            # A failure would otherwise leave stale generators on the shared
            # stack and make every later top-level call return a generator.
            stack.clear()
            raise

    return wrapped


input = lambda: sys.stdin.readline().rstrip("\r\n")
print = lambda *args, end="\n", sep=" ": sys.stdout.write(
    sep.join(map(str, args)) + end
)


def II():
    """Read one line from stdin and parse it as a single int."""
    return int(input())


def MII(base=0):
    """Read one line of whitespace-separated ints, each shifted down by ``base``."""
    return map(lambda s: int(s) - base, input().split())


def LII(base=0):
    """Like ``MII`` but returns a list."""
    return list(MII(base))


def NA():
    """Read a count line ``n`` followed by a line of ints; return ``(n, a)``."""
    n = II()
    a = LII()
    return n, a


def read_graph(n, m, base=0, directed=False, return_edges=False):
    """Read ``m`` edge lines ``a b`` into an adjacency list of ``n`` vertices.

    Vertex ids are shifted down by ``base``. Unless ``directed``, each edge
    is added in both directions. Returns ``g`` or, with ``return_edges``,
    ``(g, edges)``.
    """
    g = [[] for _ in range(n)]
    edges = []
    for _ in range(m):
        a, b = MII(base)
        if return_edges:
            edges.append((a, b))
        g[a].append(b)
        if not directed:
            g[b].append(a)
    if return_edges:
        return g, edges
    return g


def read_graph_with_weight(n, m, base=0, directed=False, return_edges=False):
    """Read ``m`` weighted edge lines ``a b w`` into an adjacency list.

    Only the endpoints are shifted by ``base``; the weight is kept as read.
    Adjacency entries are ``(neighbor, w)``. Returns ``g`` or, with
    ``return_edges``, ``(g, edges)``.
    """
    g = [[] for _ in range(n)]
    edges = []
    for _ in range(m):
        a, b, w = MII()
        a, b = a - base, b - base
        if return_edges:
            edges.append((a, b, w))
        g[a].append((b, w))
        if not directed:
            g[b].append((a, w))
    if return_edges:
        return g, edges
    return g


def iterate_tokens():
    """Yield every whitespace-separated token from stdin, line by line."""
    for line in sys.stdin:
        for word in line.split():
            yield word


tokens = None


def NI():
    """Return the next stdin token as an int (token stream created lazily)."""
    global tokens
    if tokens is None:
        tokens = iterate_tokens()
    return int(next(tokens))


def LNI(n):
    """Return the next ``n`` stdin tokens as a list of ints."""
    return [NI() for _ in range(n)]


def yes(res):
    """Print ``Yes`` if ``res`` is truthy, else ``No``."""
    print("Yes" if res else "No")


def YES(res):
    """Print ``YES`` if ``res`` is truthy, else ``NO``."""
    print("YES" if res else "NO")


def pairwise(a):
    """Yield consecutive pairs ``(a[i], a[i + 1])`` of an indexable sequence."""
    n = len(a)
    for i in range(n - 1):
        yield a[i], a[i + 1]


def factorial(n):
    """Return ``n!`` for an integer ``n >= 0`` (``factorial(0) == 1``).

    The initializer ``1`` makes the empty product (n == 0) well defined;
    without it ``reduce`` raised ``TypeError`` on an empty range.

    Raises:
        ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError("factorial() not defined for negative values")
    return reduce(lambda x, y: x * y, range(1, n + 1), 1)


def cmin(dp, i, x):
    """Set ``dp[i] = x`` if ``x`` is smaller than the current value."""
    if x < dp[i]:
        dp[i] = x


def cmax(dp, i, x):
    """Set ``dp[i] = x`` if ``x`` is larger than the current value."""
    if x > dp[i]:
        dp[i] = x


def alp_a_to_i(s):
    """Map a lowercase letter to its 0-based alphabet index."""
    return ord(s) - ord("a")


def alp_A_to_i(s):
    """Map an uppercase letter to its 0-based alphabet index."""
    return ord(s) - ord("A")


def alp_i_to_a(i):
    """Map a 0-based alphabet index to its lowercase letter."""
    return chr(ord("a") + i)


def alp_i_to_A(i):
    """Map a 0-based alphabet index to its uppercase letter."""
    return chr(ord("A") + i)


d4 = [(1, 0), (0, 1), (-1, 0), (0, -1)]
d8 = [(1, 0), (1, 1), (0, 1), (-1, 1), (-1, 0), (-1, -1), (0, -1), (1, -1)]


def ranges(n, m):
    """Generate every grid coordinate ``(i, j)`` with ``i < n`` and ``j < m``."""
    return ((i, j) for i in range(n) for j in range(m))


def valid(i, j, n, m):
    """Return True if ``(i, j)`` lies inside an ``n`` x ``m`` grid."""
    return 0 <= i < n and 0 <= j < m


def ninj(i, j, n, m):
    """Return the in-grid 4-neighbours of ``(i, j)``, in ``d4`` order."""
    return [(i + di, j + dj) for di, dj in d4 if valid(i + di, j + dj, n, m)]


def gen(x, *args):
    """Build a nested list of shape ``args`` filled with ``x``.

    Works for any number of dimensions (previously only 1-4; more returned
    ``None``). Outer levels are independent lists; the innermost level is
    ``[x] * size``, so a mutable ``x`` is shared within each innermost list.
    Returns ``None`` when no dimensions are given, as before.
    """
    if not args:
        return None
    if len(args) == 1:
        return [x] * args[0]
    return [gen(x, *args[1:]) for _ in range(args[0])]


list2d = lambda a, b, v: [[v] * b for _ in range(a)]
list3d = lambda a, b, c, v: [[[v] * c for _ in range(b)] for _ in range(a)]


class Debug:
    """Factory for a debug-print function.

    A real ``icecream.ic`` is returned only when ``debug`` is truthy AND a
    ``.cph`` file exists next to this script (presumably a marker for the
    author's local setup — confirm); otherwise a no-op callable is returned.
    """

    def __init__(self, debug=False):
        self.debug = debug
        cur_path = os.path.dirname(os.path.abspath(__file__))
        self.local = os.path.exists(cur_path + "/.cph")

    def get_ic(self):
        """Return ``icecream.ic`` when enabled locally, else a no-op lambda."""
        if self.debug and self.local:
            from icecream import ic

            return ic
        else:
            return lambda *args, **kwargs: ...
from typing import Union
from functools import lru_cache


class ModInt998244353:
    """Residue modulo the prime 998244353 with operator overloading.

    Every arithmetic operator accepts another ``ModInt998244353`` or a plain
    ``int`` and returns a new instance whose ``val`` is normalised into
    ``[0, 998244353)``. Comparisons compare ``val`` against ``int(other)``.
    """

    __slots__ = ["val"]

    @staticmethod
    @lru_cache(maxsize=None)
    def _inv(a: int) -> int:
        """Return the inverse of ``a`` modulo 998244353 (Fermat: a^(p-2)).

        Raises:
            ZeroDivisionError: if ``a`` is a multiple of the modulus. The
                old square-and-multiply loop silently returned 0 here, so
                ``x / 0`` evaluated to 0 instead of failing.
        """
        if a % 998244353 == 0:
            raise ZeroDivisionError("division by zero modulo 998244353")
        return pow(a, 998244351, 998244353)

    @staticmethod
    def _raw(other: Union[int, "ModInt998244353"]) -> int:
        """Unwrap ``other`` to a plain int (ints pass through unchanged)."""
        return other if isinstance(other, int) else other.val

    @classmethod
    def get_mod(cls) -> int:
        """Return the modulus, 998244353."""
        return 998244353

    def __init__(self, val: int) -> None:
        # Skip the modulo when val is already in canonical range.
        self.val = val if 0 <= val and val < 998244353 else val % 998244353

    def __add__(self, other: Union[int, "ModInt998244353"]) -> "ModInt998244353":
        return ModInt998244353(self.val + self._raw(other))

    def __sub__(self, other: Union[int, "ModInt998244353"]) -> "ModInt998244353":
        return ModInt998244353(self.val - self._raw(other))

    def __mul__(self, other: Union[int, "ModInt998244353"]) -> "ModInt998244353":
        return ModInt998244353(self.val * self._raw(other))

    def __pow__(self, other: Union[int, "ModInt998244353"]) -> "ModInt998244353":
        return ModInt998244353(pow(self.val, self._raw(other), 998244353))

    def __truediv__(self, other: Union[int, "ModInt998244353"]) -> "ModInt998244353":
        return ModInt998244353(self.val * self._inv(self._raw(other)))

    __iadd__ = __add__
    __isub__ = __sub__
    __imul__ = __mul__
    __ipow__ = __pow__
    __itruediv__ = __truediv__

    def __radd__(self, other: Union[int, "ModInt998244353"]) -> "ModInt998244353":
        return ModInt998244353(self._raw(other) + self.val)

    def __rsub__(self, other: Union[int, "ModInt998244353"]) -> "ModInt998244353":
        return ModInt998244353(self._raw(other) - self.val)

    def __rmul__(self, other: Union[int, "ModInt998244353"]) -> "ModInt998244353":
        return ModInt998244353(self._raw(other) * self.val)

    def __rpow__(self, other: Union[int, "ModInt998244353"]) -> "ModInt998244353":
        return ModInt998244353(pow(self._raw(other), self.val, 998244353))

    def __rtruediv__(self, other: Union[int, "ModInt998244353"]) -> "ModInt998244353":
        return ModInt998244353(self._raw(other) * self._inv(self.val))

    def __eq__(self, other: Union[int, "ModInt998244353"]):
        return self.val == int(other)

    def __hash__(self):
        # Defining __eq__ alone made instances unhashable; hashing val keeps
        # hash(mint(x)) == hash(x) consistent with mint(x) == x.
        return hash(self.val)

    def __lt__(self, other: Union[int, "ModInt998244353"]):
        return self.val < int(other)

    def __le__(self, other: Union[int, "ModInt998244353"]):
        return self.val <= int(other)

    def __gt__(self, other: Union[int, "ModInt998244353"]):
        return self.val > int(other)

    def __ge__(self, other: Union[int, "ModInt998244353"]):
        return self.val >= int(other)

    def __ne__(self, other: Union[int, "ModInt998244353"]):
        return self.val != int(other)

    def __neg__(self):
        return ModInt998244353(-self.val)

    def __pos__(self):
        return ModInt998244353(self.val)

    def __int__(self):
        return self.val

    def __str__(self):
        return str(self.val)

    def __repr__(self):
        return f"{self}"


mint = ModInt998244353
mod = 998244353


def solve(n: int, k: int) -> ModInt998244353:
    """Return ``n + E[(k - i) / 2]`` modulo 998244353.

    Over all ``n**k`` equally weighted sequences of ``k`` picks among ``n``
    indices, ``dp[i]`` counts the sequences after which exactly ``i``
    indices were picked an odd number of times. A step turns an even index
    odd (``n - i`` choices) or an odd one even (``i`` choices). Since ``i``
    changes by one per step, ``k - i`` is even whenever ``dp[i] != 0``, so
    the ``// 2`` is exact. Presumably this is the expectation asked for by
    yukicoder No.2530 — confirm against the statement.

    The DP uses plain ints reduced mod p instead of allocating a ModInt per
    operation; results are identical.

    NOTE(review): the expected odd count satisfies E' = E * (1 - 2/n) + 1,
    i.e. E = n/2 * (1 - (1 - 2/n)**k), which would allow O(log k) instead
    of O(n * k).

    Raises:
        ZeroDivisionError: if ``n**k`` is 0 modulo 998244353 (e.g. n == 0
            with k > 0).
    """
    dp = [0] * (n + 1)
    dp[0] = 1
    for _ in range(k):
        ndp = [0] * (n + 1)
        for i in range(n):
            ndp[i + 1] = (ndp[i + 1] + dp[i] * (n - i)) % mod
        for i in range(1, n + 1):
            ndp[i - 1] = (ndp[i - 1] + dp[i] * i) % mod
        dp = ndp
    total = 0
    for i, ways in enumerate(dp):
        pairs = (k - i) // 2
        total = (total + (n + pairs) * ways) % mod
    return mint(total) / mint(pow(n, k, mod))


def main() -> None:
    """Read ``n k`` from stdin and print ``solve(n, k)``."""
    n, k = MII()
    print(solve(n, k))


if __name__ == "__main__":
    main()