結果
問題 | No.3202 Periodic Alternating Subsequence |
ユーザー | |
提出日時 | 2025-07-11 22:25:54 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 1,799 ms / 2,000 ms |
コード長 | 4,764 bytes |
コンパイル時間 | 360 ms |
コンパイル使用メモリ | 82,512 KB |
実行使用メモリ | 78,036 KB |
最終ジャッジ日時 | 2025-09-04 10:43:18 |
合計ジャッジ時間 | 34,824 ms |
ジャッジサーバーID (参考情報) | judge2 / judge5 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 2 |
other | AC * 24 |
ソースコード
def solve(S, K):
    """Sum of squared lengths of alternating subsequences of ``S * K``.

    Consider the string T obtained by repeating S (made of '0'/'1') K times.
    The function adds up ``len(sub) ** 2`` over every non-empty subsequence
    ``sub`` of T (chosen by index positions) whose consecutive picked
    characters differ. The result is taken modulo the module-level ``MOD``.

    Args:
        S: the period, containing only the characters '0' and '1'.
        K: number of repetitions (must be >= 0).

    Raises:
        ValueError: if S contains a character other than '0' or '1'.
    """
    # One DP state per (accumulator, last picked symbol):
    #   accumulator "C" = number of subsequences,
    #               "S" = sum of their lengths,
    #               "Q" = sum of their squared lengths;
    #   last symbol "X" = '0', "Y" = '1', "Z" = nothing picked yet.
    d = {}
    for kind in "C", "S", "Q":
        for last in "X", "Y", "Z":
            d[kind, last] = len(d)
    n = len(d)

    # Row-vector transition matrix per input symbol. The identity part models
    # skipping the character; the extra entries model picking it.
    MX = Matrix.identity(n)
    MY = Matrix.identity(n)
    Ms = {"X": MX, "Y": MY}
    for last in "X", "Y", "Z":
        for nlast in "X", "Y":
            if last == nlast:
                continue  # same symbol twice in a row is not alternating
            M = Ms[nlast]
            # Appending one element maps length l -> l + 1, so:
            #   C' += C
            #   S' += S + C
            #   Q' += Q + 2*S + C      because (l + 1)^2 = l^2 + 2l + 1
            M[d["C", last]][d["C", nlast]] += 1
            M[d["C", last]][d["S", nlast]] += 1
            M[d["S", last]][d["S", nlast]] += 1
            M[d["C", last]][d["Q", nlast]] += 1
            M[d["S", last]][d["Q", nlast]] += 2
            M[d["Q", last]][d["Q", nlast]] += 1

    # Transition for one full period of S, applied left to right.
    by_char = {"0": MX, "1": MY}
    period = Matrix.identity(n)
    for ch in S:
        try:
            step = by_char[ch]
        except KeyError:
            raise ValueError(
                f"unexpected character {ch!r} in S; expected '0' or '1'"
            ) from None
        period = period @ step

    # Start from the single empty selection with nothing picked yet.
    V = Matrix.row_vector([0] * n)
    V.set_at(d["C", "Z"], 1)
    V = V @ pow(period, K)
    # The "Z" states only hold the empty selection, so the answer is Q over X/Y.
    res = V.at(d["Q", "X"]) + V.at(d["Q", "Y"])
    return res % MOD


class Matrix:
    """Dense integer matrix stored as a list of row lists.

    Products are reduced modulo the class attribute ``MOD``, which is shared
    by all instances and changed with ``set_mod``. ``MOD = None`` disables
    the reduction.
    """

    MOD = 998_244_353

    def __init__(self, matrix: list[list], *, _copy=True):
        """Wrap ``matrix``, a non-empty rectangular list of rows.

        Each row list is copied unless ``_copy=False``. In that case the
        caller hands ``matrix`` over and must not mutate it afterwards.
        """
        self.shape: tuple[int, int]
        self.mat: list[list]
        self.shape = len(matrix), len(matrix[0])
        assert all(len(row) == len(matrix[0]) for row in matrix)
        if _copy:
            self.mat = [row.copy() for row in matrix]
        else:
            self.mat = matrix

    @classmethod
    def set_mod(cls, MOD: int):
        """Set the modulus used by every Matrix product (None = no reduction)."""
        cls.MOD = MOD

    def copy(self) -> "Matrix":
        """Return a copy whose row lists are independent of ``self``."""
        cls = self.__class__
        return cls(self.mat, _copy=True)

    @classmethod
    def full(cls, shape: tuple[int, int], fill_value) -> "Matrix":
        """Return an H x W matrix filled with ``fill_value``."""
        H, W = shape
        matrix = [[fill_value] * W for _ in range(H)]
        return cls(matrix, _copy=False)

    @classmethod
    def zeros(cls, shape: tuple[int, int]) -> "Matrix":
        """Return an all-zero matrix of the given shape."""
        return cls.full(shape, 0)

    @classmethod
    def identity(cls, N: int) -> "Matrix":
        """Return the N x N identity matrix."""
        matrix = [[0] * N for _ in range(N)]
        for i in range(N):
            matrix[i][i] = 1
        return cls(matrix, _copy=False)

    @classmethod
    def column_vector(cls, seq: list[int]) -> "Matrix":
        """Return an N x 1 matrix built from ``seq``."""
        return cls([[x] for x in seq], _copy=False)

    @classmethod
    def row_vector(cls, seq: list[int]) -> "Matrix":
        """Return a 1 x N matrix built from ``seq``."""
        return cls([list(seq)], _copy=False)

    def __getitem__(self, key):
        """``m[i, j]`` returns an entry; ``m[i]`` returns row i (a live list)."""
        if isinstance(key, tuple):
            i, j = key
            return self.mat[i][j]
        return self.mat[key]

    def __setitem__(self, key, value):
        """Assign one entry via ``m[i, j] = value``."""
        if isinstance(key, tuple):
            i, j = key
            self.mat[i][j] = value
        else:
            raise ValueError("Matrix item assignment requires an (i, j) index")

    def at(self, i):
        """Return element ``i`` of a row or column vector."""
        if self.shape[0] == 1:
            return self.mat[0][i]
        elif self.shape[1] == 1:
            return self.mat[i][0]
        else:
            raise IndexError("at() requires a row or column vector")

    def set_at(self, i, x):
        """Set element ``i`` of a row or column vector to ``x``."""
        if self.shape[0] == 1:
            self.mat[0][i] = x
        elif self.shape[1] == 1:
            self.mat[i][0] = x
        else:
            raise IndexError("set_at() requires a row or column vector")

    @classmethod
    def matmul(cls, A: list[list], B: list[list], H, W, K) -> list[list]:
        """Multiply raw row lists: (H x W) ``A`` times (W x K) ``B``.

        Returns a new H x K list of rows. Each entry is reduced modulo
        ``cls.MOD`` unless that is None.
        """
        mod = cls.MOD
        # Only non-zero entries of B contribute. The per-symbol transition
        # matrices built by solve() are mostly zeros, so skipping them cuts
        # the inner loop substantially.
        B_nonzero = [[(k, b) for k, b in enumerate(row) if b] for row in B]
        res = []
        for i in range(H):
            nrow = [0] * K
            for j, a in enumerate(A[i]):
                if a == 0:
                    continue
                for k, b in B_nonzero[j]:
                    nrow[k] += a * b
            if mod is not None:
                # Reducing once per entry yields the same residue as reducing
                # after every term.
                nrow = [x % mod for x in nrow]
            res.append(nrow)
        return res

    def __matmul__(self, other):
        """Return the matrix product ``self @ other``."""
        H, W = self.shape
        W2, K = other.shape
        assert W == W2
        C = self.matmul(self.mat, other.mat, H, W, K)
        return self.__class__(C, _copy=False)

    def __pow__(self, exp: int) -> "Matrix":
        """Return ``self ** exp`` for a square matrix and ``exp >= 0``."""
        assert exp >= 0
        assert self.shape[0] == self.shape[1]
        N, _ = self.shape
        A = self.mat
        X = self.identity(N).mat
        # Left-to-right square-and-multiply. A commutes with its own powers,
        # so multiplying A on the left gives the same result as on the right.
        for pos in range(exp.bit_length() - 1, -1, -1):
            X = self.matmul(X, X, N, N, N)
            if exp >> pos & 1:
                X = self.matmul(A, X, N, N, N)
        return self.__class__(X, _copy=False)

    pow = __pow__

    def __array__(self, dtype=None, copy=None):
        """NumPy array protocol: always builds a fresh array."""
        if copy is False:
            raise ValueError("Matrix cannot be converted to an array without copying")
        import numpy as np
        return np.array(self.mat, dtype=dtype)

    def __repr__(self):
        try:
            import numpy as np
        except ImportError:
            # numpy is only a debugging nicety; do not require it.
            return f"{type(self).__name__}({self.mat!r})"
        return str(np.array(self.mat))


MOD = 1_000_000_007
Matrix.set_mod(MOD)


def main():
    """Read S and K from stdin and print the answer."""
    S = input().strip()  # tolerate trailing whitespace / CRLF line endings
    K = int(input())
    print(solve(S, K))


if __name__ == "__main__":
    main()