結果
| 問題 | No.3119 A Little Cheat |
|---|---|
| ユーザー | |
| 提出日時 | 2025-04-19 01:41:46 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | AC |
| 実行時間 | 715 ms / 2,000 ms |
| コード長 | 3,042 bytes |
| コンパイル時間 | 775 ms |
| コンパイル使用メモリ | 82,048 KB |
| 実行使用メモリ | 105,600 KB |
| 最終ジャッジ日時 | 2025-04-19 01:42:16 |
| 合計ジャッジ時間 | 27,796 ms |
| ジャッジサーバーID (参考情報) | judge3 / judge1 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 49 |
ソースコード
import sys
from itertools import product

# yukicoder No.3119 "A Little Cheat".
#
# For a fixed sequence a (values in 1..m) and any b in [1, m]^n let
#     f(a, b) = #{i : a[i] < b[i]}
# and let score(a, b) be the maximum of f over b itself and every sequence
# obtained from b by swapping one adjacent pair.  The program prints the sum
# of score(a, b) over all m**n sequences b, modulo MOD.
#
# Key observation: an adjacent swap only changes two terms of f, and it can
# never raise f by 2 (that would require a[i] < b[i+1] <= a[i+1] < b[i] <= a[i]).
# Hence score(a, b) = f(a, b) + [some adjacent swap raises f by one].
# The sum of f over all b has a closed form (see `answer`); `solve` counts the
# sequences b for which some adjacent swap helps.

MOD = 998244353
mod = MOD  # original lower-case name, kept as an alias


def f(a, b):
    """Return how many positions i satisfy a[i] < b[i].

    Positions are paired with zip (the original read the module-level n).
    """
    return sum(x < y for x, y in zip(a, b))


def score(a, b):
    """Brute-force score: the best f(a, .) over b and all its adjacent swaps.

    Debug helper for `naive`; b may be any sequence and is not modified.
    """
    b = list(b)
    best = f(a, b)
    for i in range(len(b) - 1):
        b[i], b[i + 1] = b[i + 1], b[i]
        best = max(best, f(a, b))
        b[i], b[i + 1] = b[i + 1], b[i]  # undo the swap
    return best


def naive(n, m, a):
    """Exact answer (no modulus) by enumerating all m**n sequences b."""
    return sum(score(a, b) for b in product(range(1, m + 1), repeat=n))


def g(a0, a1, b0, b1):
    """Return True iff swapping b0 and b1 raises this pair's contribution to f.

    The pair contributes [a0 < b0] + [a1 < b1] before the swap and
    [a0 < b1] + [a1 < b0] after it; this quantity is symmetric in the two
    positions, so they are reordered first to make a0 <= a1.
    """
    if a1 < a0:
        a0, a1 = a1, a0
        b0, b1 = b1, b0
    return b0 <= a0 < b1 <= a1 or a0 < b1 <= a1 < b0


def naive2(n, m, a):
    """Count sequences b in [1, m]^n with an improvable adjacent pair, O(n*m^2).

    Reference implementation for `solve` (which returns this count mod MOD).
    dp0[v] / dp1[v] count prefixes ending in value v + 1 that do not / do
    already contain an adjacent pair accepted by `g`.
    """
    dp0 = [1] * m
    dp1 = [0] * m
    a = [x - 1 for x in a]  # 0-based, matching the 0-based value indices j, k
    for i in range(n - 1):
        ndp0 = [0] * m
        ndp1 = [0] * m
        for j in range(m):
            for k in range(m):
                if g(a[i], a[i + 1], j, k):
                    ndp1[k] += dp0[j] + dp1[j]
                else:
                    ndp1[k] += dp1[j]
                    ndp0[k] += dp0[j]
        dp0, dp1 = ndp0, ndp1
    # Use dp1, not ndp1: for n == 1 the loop never runs and ndp1 is unbound.
    return sum(dp1)


def h(a0, a1, a2, k, l, m):
    """One DP transition of `solve`, from b_i (state k) to b_{i+1} (class l).

    a0, a1, a2 are a[i], a[i+1], a[i+2] (a[i+2] may be the sentinel m).
    k: bit 0 = (b_i > a0), bit 1 = (b_i > a1).
    l: bit 0 = (b_{i+1} > a0), bit 1 = (b_{i+1} > a1), bit 2 = (b_{i+1} > a2).
    m: largest allowed value (previously read from a module-level global).

    Returns (improves, count): whether swapping b_i and b_{i+1} raises f
    (the same test as `g`), and how many values b_{i+1} in [1, m] fall in
    class l.
    """
    lo, hi = 0, m  # admissible values v satisfy lo < v <= hi
    for bit, bound in enumerate((a0, a1, a2)):
        if (l >> bit) & 1:
            lo = max(lo, bound)
        else:
            hi = min(hi, bound)
    if a0 <= a1:
        # b_i <= a0 < b_{i+1} <= a1   or   a0 < b_{i+1} <= a1 < b_i
        improves = (l % 4 == 1) and (k == 0 or k == 3)
    else:
        # b_{i+1} <= a1 < b_i <= a0   or   a1 < b_i <= a0 < b_{i+1}
        improves = (k == 2) and (l % 4 == 0 or l % 4 == 3)
    return improves, max(hi - lo, 0)


def solve(n, m, a):
    """Count, mod MOD, the sequences b in [1, m]^n where an adjacent swap raises f.

    DP over positions: dp[j][k] counts prefixes b_0..b_i where j == 1 iff an
    improving adjacent swap already exists inside the prefix, and k encodes
    b_i as bit 0 = (b_i > a[i]), bit 1 = (b_i > a[i+1]).
    The caller's list a is not modified.
    """
    # Sentinel a[n] = m: "b_{n-1} > m" is never true, so the last state is
    # well defined.  Build a new list rather than appending to the caller's.
    a = list(a) + [m]
    dp = [[0] * 4 for _ in range(2)]
    dp[0][0] = min(a[0], a[1])
    dp[0][1] = max(a[1] - a[0], 0)
    dp[0][2] = max(a[0] - a[1], 0)
    dp[0][3] = m - max(a[0], a[1])
    for i in range(n - 1):
        a0, a1, a2 = a[i], a[i + 1], a[i + 2]
        ndp = [[0] * 4 for _ in range(2)]
        for j in range(2):
            for k in range(4):
                ways = dp[j][k]
                if ways == 0:
                    continue
                for l in range(8):
                    improves, count = h(a0, a1, a2, k, l, m)
                    if count:
                        t = j | improves
                        # l // 2 = (b_{i+1} > a[i+1], b_{i+1} > a[i+2]): next k
                        ndp[t][l // 2] = (ndp[t][l // 2] + ways * count) % MOD
        dp = ndp
    return sum(dp[1]) % MOD


def answer(n, m, a):
    """Return the sum of score(a, b) over all b in [1, m]^n, modulo MOD."""
    # Position i satisfies a[i] < b[i] for (m - a[i]) values of b[i], times
    # m**(n-1) choices for the remaining positions.
    base = sum(m - x for x in a) % MOD * pow(m, n - 1, MOD) % MOD
    return (base + solve(n, m, a)) % MOD


def main():
    data = sys.stdin.buffer.read().split()
    n, m = int(data[0]), int(data[1])
    a = [int(x) for x in data[2:2 + n]]
    print(answer(n, m, a))


if __name__ == "__main__":
    main()