結果

問題 No.1561 connect x connect
ユーザー Kiri8128
提出日時 2021-06-25 23:49:06
言語 PyPy3 (7.3.15)
結果
TLE  
実行時間 -
コード長 5,671 bytes
コンパイル時間 371 ms
コンパイル使用メモリ 87,152 KB
実行使用メモリ 90,732 KB
最終ジャッジ日時 2023-09-07 15:09:08
合計ジャッジ時間 4,116 ms
ジャッジサーバーID (参考情報) judge12 / judge14
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 74 ms 75,984 KB
testcase_01 AC 216 ms 78,212 KB
testcase_02 TLE -
testcase_03 -- -
testcase_04 -- -
testcase_05 -- -
testcase_06 -- -
testcase_07 -- -
testcase_08 -- -
testcase_09 -- -
testcase_10 -- -
testcase_11 -- -
testcase_12 -- -
testcase_13 -- -
testcase_14 -- -
testcase_15 -- -
testcase_16 -- -
testcase_17 -- -
testcase_18 -- -
testcase_19 -- -
testcase_20 -- -
testcase_21 -- -
testcase_22 -- -
testcase_23 -- -
testcase_24 -- -
testcase_25 -- -
testcase_26 -- -
testcase_27 -- -
testcase_28 -- -
testcase_29 -- -
testcase_30 -- -
testcase_31 -- -
testcase_32 -- -
testcase_33 -- -
testcase_34 -- -
testcase_35 -- -
testcase_36 -- -
testcase_37 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #

def chk(i):
    """Pruning hook for calc(): return 1 to reject the labeling prefix ending at i.

    Currently a stub that accepts every prefix, so calc() enumerates every
    labeling it generates.

    Args:
        i: index of the last position of the prefix being checked.

    Returns:
        0 (accept). A non-zero return would make calc() skip this prefix.
    """
    # NOTE: calc keeps its labels in a local list Z, so any real pruning
    # rule would need that list passed in rather than read as a global.
    return 0

def calc(LL):
    """Enumerate every group labeling of one set of occupied columns.

    LL lists occupied column indices. The module-level caller builds it in
    ascending order. Two occupied columns that are horizontally adjacent
    (LL[i] == LL[i-1] + 1) are forced to share a label. Any other column may
    reuse an existing label or open exactly one new label, so labels appear
    in first-use order and each grouping is produced once.

    Returns:
        A list of length-N lists (N is the global board width). Entry c is 0
        for an unoccupied column and 1 + (group id) for an occupied one.
        Returns [] when LL is empty.

    Uses the globals N and chk. A non-zero chk(i) rejects the prefix ending
    at position i.
    """
    if not LL:
        return []
    K = len(LL)
    Z = [0] * K  # Z[i]: group id currently assigned to LL[i]
    R = [0] * K  # R[i]: exclusive upper bound; Z[i] ranges over [Z[i], R[i])
    Z[0], R[0] = 0, 1  # the first column always takes group 0
    RE = []
    i = 0  # backtracking cursor
    while i >= 0:
        ng = 0
        # Extend the current prefix forward, giving each new position its
        # smallest allowed id.
        while i < K - 1:
            i += 1
            if LL[i] == LL[i-1] + 1:
                # Adjacent to the previous column: must share its id.
                Z[i] = Z[i-1]
                R[i] = Z[i-1] + 1
            else:
                # Any id already in use, or exactly one new id.
                Z[i] = 0
                R[i] = max(Z[:i]) + 2
            if chk(i):
                ng = 1
                break

        if not ng:
            # Emit the full labeling. Use a separate index so the
            # backtracking cursor i is left untouched.
            re = [0] * N
            for j in range(K):
                re[LL[j]] = Z[j] + 1
            RE.append(re)

        # Advance to the next candidate, backtracking over positions whose
        # range is exhausted (or that chk rejects).
        Z[i] += 1
        while Z[i] >= R[i] or chk(i):
            if Z[i] < R[i]:
                Z[i] += 1
            while Z[i] >= R[i]:
                i -= 1
                if i < 0:
                    break
                Z[i] += 1
            if i < 0:
                break
    return RE

class UnionFind:
    """Disjoint-set forest over the integers 0 .. n-1.

    PA[x] is the parent of x, or minus the size of x's set when x is a root.
    root() compresses the path it walks, and unite() attaches the smaller
    tree under the larger one.
    """

    def __init__(self, n):
        self.n = n
        self.PA = [-1] * n  # every element starts as a singleton root

    def root(self, a):
        """Return the representative of a's set, flattening the walked path."""
        trail = []
        node = a
        while self.PA[node] >= 0:
            trail.append(node)
            node = self.PA[node]
        for visited in trail:
            self.PA[visited] = node
        return node

    def unite(self, a, b):
        """Merge the sets containing a and b (no-op if already merged)."""
        big = self.root(a)
        small = self.root(b)
        if big == small:
            return
        # Roots store -size; keep the larger tree as root, preferring the
        # root of `a` on ties.
        if self.PA[small] < self.PA[big]:
            big, small = small, big
        self.PA[big] += self.PA[small]
        self.PA[small] = big

    def size(self, a):
        """Return the number of elements in a's set."""
        return -self.PA[self.root(a)]

    def _buckets(self):
        """Return a list indexed by root id holding that root's members."""
        buckets = [[] for _ in range(self.n)]
        for x in range(self.n):
            buckets[self.root(x)].append(x)
        return buckets

    def groups(self):
        """Return the non-empty groups, ordered by root id."""
        return [bucket for bucket in self._buckets() if bucket]

    def groups_index(self):
        """Return (groups, index) where index[root] is that root's group position.

        Entries of index for non-root elements stay -1.
        """
        index = [-1] * self.n
        members = []
        for r, bucket in enumerate(self._buckets()):
            if bucket:
                index[r] = len(members)
                members.append(bucket)
        return members, index

    def group_size(self):
        """Return the size of each non-empty group, ordered by root id."""
        return [len(bucket) for bucket in self._buckets() if bucket]

    def same(self, i, j):
        """Return 1 if i and j are in the same set, else 0."""
        return int(self.root(i) == self.root(j))

def encode(L, base=None):
    """Pack a sequence of integers into a single int key.

    The digits are read most-significant first: for L = [d0, d1, ..., dk]
    the result is d0*base**k + d1*base**(k-1) + ... + dk.

    Args:
        L: iterable of integer digits (may include -1, as in the finished
            state [-1] * N).
        base: radix; defaults to the global board width N, which is what
            every caller in this file uses.

    Returns:
        The packed integer (0 for an empty L).
    """
    if base is None:
        base = N
    re = 0
    for l in L:
        re = re * base + l
    return re

def make_move():
    """Build the one-row transition matrix over the profile states in TP.

    X[k][kk] counts the row bitmasks (over the N columns) that take state
    TP[k] to state TP[kk]. A state lists, per column, 0 for an empty
    previous-row cell or a 1-based group label for an occupied one. Equal
    labels are treated as already connected. The last state ([-1] * N) is
    absorbing.

    Reads the globals N, cnt, TP and DD, and uses UnionFind and encode.
    """
    X = [[0] * cnt for _ in range(cnt)]
    for k, tp in enumerate(TP):
        ma = max(tp)  # number of distinct labels (labels are 1..ma)
        if ma < 0:
            # Absorbing finished state.
            X[k][k] = 1
            continue
        # UU[g] = columns whose previous-row cell carries label g + 1.
        UU = [[] for _ in range(ma)]
        for c, a in enumerate(tp):
            if a:
                UU[a-1].append(c)

        for mask in range(1 << N):
            if mask == 0:
                # Empty next row: the all-empty state stays empty, and a
                # single group moves to the finished state (last index).
                # With two or more groups no transition is added.
                if ma == 0:
                    X[k][k] = 1
                elif ma == 1:
                    X[k][-1] = 1
                continue

            # Node c (0 <= c < N) stands for column c, shared by the
            # previous-row cell and the new-row cell in that column (they are
            # vertically adjacent when both are occupied). Node N is a
            # sentinel used below.
            uf = UnionFind(N + 1)
            for uu in UU:
                for j in range(len(uu) - 1):
                    uf.unite(uu[j], uu[j+1])

            A = [mask >> j & 1 for j in range(N)]
            for j in range(N - 1):
                if A[j] and A[j+1]:
                    uf.unite(j, j + 1)

            # Relabel the new row's occupied columns by connectivity, in
            # first-appearance order, so the result can be looked up in DD.
            B = [0] * N
            mm = 0
            for j in range(N):
                if A[j]:
                    for jj in range(j):
                        if A[jj] and uf.same(j, jj):
                            B[j] = B[jj]
                            break
                    else:
                        mm += 1
                        B[j] = mm

            # Reject the mask if any previous-row group does not reach an
            # occupied cell of the new row.
            for j in range(N):
                if A[j]:
                    uf.unite(j, N)
            for c, a in enumerate(tp):
                if a and uf.same(c, N) == 0:
                    break
            else:
                kk = DD[encode(B)]
                X[k][kk] += 1
    return X

P = 10 ** 9 + 7
N, M = map(int, input().split())

# Enumerate every profile state: TP[0] is the all-empty profile, then one
# entry per (set of occupied columns, grouping produced by calc), and the
# final entry [-1] * N is the absorbing finished state.
TP = [[0] * N]
for i in range(1 << N):
    L = [j for j in range(N) if i >> j & 1]  # occupied columns, ascending
    TP.extend(calc(L))
TP.append([-1] * N)
DD = {encode(tp): i for i, tp in enumerate(TP)}  # encoded state -> index
cnt = len(TP)

X = make_move()

def mmult(A, B, modulus=None):
    """Return the matrix product A @ B with every entry reduced modulo `modulus`.

    Args:
        A: n x m matrix as a list of lists of ints.
        B: m x l matrix as a list of lists of ints.
        modulus: reduction modulus; defaults to the module-level `mod`.

    Returns:
        A new n x l list of lists with entries in [0, modulus).
    """
    if modulus is None:
        modulus = mod
    # Iterate B column-wise so each entry is one dot product, reduced once.
    # Python's % always yields a value in [0, modulus), so this equals
    # reducing after every multiply-add.
    columns = list(zip(*B))
    return [
        [sum(a * b for a, b in zip(row, col)) % modulus for col in columns]
        for row in A
    ]
def mpow(A, n):
    """Return the square matrix A raised to the n-th power, using mmult.

    Uses iterative square-and-multiply, so O(log n) products are formed
    without recursion. n == 0 gives the identity matrix, and n == 1 returns
    A itself.

    Raises:
        ValueError: if n is negative.
    """
    if n < 0:
        raise ValueError("matrix exponent must be non-negative")
    if n == 0:
        return [[1 if i == j else 0 for j in range(len(A))] for i in range(len(A))]
    result = None
    base = A
    while True:
        if n & 1:
            # Powers of the same matrix commute, so the order is irrelevant.
            result = base if result is None else mmult(result, base)
        n >>= 1
        if not n:
            return result
        base = mmult(base, base)

mod = 10 ** 9 + 7  # modulus read by mmult

# Start in the all-empty state (index 0). The M rows take M steps, plus one
# step for the empty-row transition into the finished state (last index).
# The finished state loops to itself, so earlier finishes are counted too.
print(mpow(X, M + 1)[0][-1])