結果
問題 | No.1241 Eternal Tours |
ユーザー |
![]() |
提出日時 | 2020-09-10 01:54:20 |
言語 | Python3 (3.13.1 + numpy 2.2.1 + scipy 1.14.1) |
結果 |
AC
|
実行時間 | 3,070 ms / 6,000 ms |
コード長 | 2,129 bytes |
コンパイル時間 | 328 ms |
コンパイル使用メモリ | 12,800 KB |
実行使用メモリ | 107,284 KB |
最終ジャッジ日時 | 2024-12-17 17:50:30 |
合計ジャッジ時間 | 70,314 ms |
ジャッジサーバーID (参考情報) |
judge5 / judge2 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 4 |
other | AC * 40 |
ソースコード
#!/usr/bin/env python3 import numpy as np class NTT(): def __init__(self, D, MOD, root): self.md = MOD self.w = np.array([1], np.int64) self.iw = np.array([1], np.int64) while len(self.w) < 1 << (D - 1): dw = pow(root, (self.md - 1) // (len(self.w) * 4), self.md) dwinv = pow(dw, -1, self.md) self.w = np.r_[self.w, self.w * dw] % self.md self.iw = np.r_[self.iw, self.iw * dwinv] % self.md def ntt(self, mat): in_shape = mat.shape n = in_shape[-1] m = n // 2 while m: mat = mat.reshape(-1, n // (m * 2), 2, m) w_use = self.w[:n // (m * 2)].reshape(1, -1, 1) y = mat[:, :, 1] * w_use % self.md mat = np.stack((mat[:, :, 0] + y, mat[:, :, 0] + self.md - y), 2) % self.md m //= 2 return mat.reshape(in_shape) def intt(self, mat): in_shape = mat.shape n = in_shape[-1] m = 1 while m < n: mat = mat.reshape(-1, n // (m * 2), 2, m) iw_use = self.iw[:n // (m * 2)].reshape(1, -1, 1) mat = np.stack((mat[:, :, 0] + mat[:, :, 1], (mat[:, :, 0] + self.md - mat[:, :, 1]) * iw_use), 2) % self.md m *= 2 n_inv = pow(n, -1, self.md) return mat.reshape(in_shape) * n_inv % self.md X, Y, T, a, b, c, d = list(map(int, input().split())) md = 998244353 T = (T - 1) % (md - 1) + 1 H, W = 1 << (X + 1), 1 << (Y + 1) dp = np.zeros((H, W), np.int64) trans = np.zeros((H, W), np.int64) dp[a, b] = dp[-a, -b] = 1 dp[a, -b] = dp[-a, b] = md - 1 trans[0, 0] = trans[0, 1] = trans[0, -1] = trans[1, 0] = trans[-1, 0] = 1 ntt = NTT(18, md, 3) dp = ntt.ntt(dp) trans = ntt.ntt(trans) dp = dp.T trans = trans.T dp = ntt.ntt(dp) trans = ntt.ntt(trans) def matpow(x, n, mod): ret, tmp = np.ones(x.shape, np.int64), x % mod while n: if n % 2: ret = ret * tmp % mod tmp = tmp * tmp % mod n //= 2 return ret dp = dp * matpow(trans, T, md) % md dp = ntt.intt(dp) dp = dp.T dp = ntt.intt(dp) print(dp[c][d])