結果

問題 No.263 Common Palindromes Extra
ユーザー amesyuamesyu
提出日時 2024-10-24 15:44:52
言語 Python3 (3.12.2 + numpy 1.26.4 + scipy 1.12.0)
結果
TLE  
実行時間 -
コード長 4,115 bytes
コンパイル時間 456 ms
コンパイル使用メモリ 13,184 KB
実行使用メモリ 116,604 KB
最終ジャッジ日時 2024-10-24 15:45:00
合計ジャッジ時間 7,743 ms
ジャッジサーバーID (参考情報): judge1 / judge5
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 TLE -
testcase_01 -- -
testcase_02 -- -
testcase_03 -- -
testcase_04 -- -
testcase_05 -- -
testcase_06 -- -
testcase_07 -- -
testcase_08 -- -
testcase_09 -- -
testcase_10 -- -
testcase_11 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import defaultdict
# Allow very deep recursion.
sys.setrecursionlimit(10**7)

# Rolling-hash parameters; the modulus is the Mersenne prime 2**61 - 1.
MOD = (1 << 61) - 1
BASE = 998244353
# Largest substring length whose hash can be queried (len(POWER) - 1).
MAXN = 5 * 10**5

# POWER[k] == BASE**k % MOD for every 0 <= k <= MAXN.
POWER = [1] * (MAXN + 1)
for i in range(1, MAXN + 1):
    POWER[i] = POWER[i - 1] * BASE % MOD

class RollingHash:
    """Polynomial rolling hash over the characters of ``string``.

    Prefix hashes use the module-level multiplier BASE and modulus MOD.
    Substring queries look up BASE powers in the precomputed POWER table,
    so a queried length must not exceed MAXN.
    """

    def __init__(self, string):
        self.string = string
        self.size = len(string)

        # hashed[k] is the hash of string[:k]; hashed[0] is the empty prefix.
        prefix = [0] * (self.size + 1)
        acc = 0
        for k, ch in enumerate(string, start=1):
            acc = (acc * BASE + ord(ch)) % MOD
            prefix[k] = acc
        self.hashed = prefix

    def query(self, l, r):
        """Return the hash of the half-open slice ``string[l:r]``.

        Requires ``0 <= l <= r <= size`` (checked with ``assert``).
        """
        assert 0 <= l <= r <= self.size
        # Remove the contribution of string[:l], shifted up by r - l places.
        return (self.hashed[r] - self.hashed[l] * POWER[r - l]) % MOD

def composite_query(R1, R2, l1, r1, l2, r2):
    """Hash the concatenation ``R1.string[l1:r1] + R2.string[l2:r2]``.

    R1 and R2 are objects exposing ``query(l, r)`` like RollingHash.  The
    left piece's hash is shifted by BASE**(r2 - l2) (taken from POWER), which
    is how a single polynomial prefix hash would combine the two pieces.
    """
    left_hash = R1.query(l1, r1)
    right_hash = R2.query(l2, r2)
    shift = POWER[r2 - l2]
    return (left_hash * shift + right_hash) % MOD

class eertree:
    """Hash-keyed palindrome trie that counts palindromic-substring occurrences.

    Despite the name, this is not the suffix-link eertree.  Every distinct
    palindrome is a node keyed by ``(centre_id, half_hash)``:

    * odd palindromes: ``centre_id`` is ``ord(c) - ord('A')`` for the middle
      character ``c``; ``half_hash`` hashes the half to the right of ``c``;
    * even palindromes: ``centre_id`` is 26; ``half_hash`` hashes the right
      half.  Node 0 is the empty even palindrome ``(26, 0)``.

    A child node extends its parent by one character on each side.  Keys
    depend only on the palindrome text and the shared module-level hash
    parameters, so keys produced for two different strings are comparable.

    NOTE(review): centre ids are not range-checked.  Characters outside
    'A'..'Z' give ids outside 0..25, and negative ids silently index from
    the end of the table list.
    """

    def __init__(self, string):
        self.string = string
        self.size = len(string)
        self.r1 = RollingHash(self.string)        # hashes of the string
        self.r2 = RollingHash(self.string[::-1])  # hashes of its reverse

    def __rev(self, l, r):
        """Map the range [l, r) of the string to the same characters in the
        reversed string.  Not used by build()."""
        return self.size - r, self.size - l

    def build(self):
        """Enumerate distinct palindromic substrings and count occurrences.

        Returns:
            (W, H): parallel lists.  For each created node index v, H[v] is
            the ``(centre_id, half_hash)`` key of one distinct palindrome and
            W[v] is the number of times it occurs in the string.  Unused
            trailing slots hold H[v] == 0 and W[v] == 0.  W[0] (the empty
            palindrome) is forced to 0.
        """
        s = self.string
        n = self.size
        fwd = self.r1.hashed   # fwd[k]: hash of s[:k]
        rev = self.r2.hashed   # rev[k]: hash of reversed(s)[:k]
        power = POWER
        mod = MOD
        begin = ord('A')
        even_id = 26

        # At most n distinct non-empty palindromes exist, plus the empty root.
        cap = max(2 * n, 1)
        H = [0] * cap
        W = [0] * cap
        parent = [-1] * cap
        # ptr[centre_id][half_hash] -> node index.  These are plain dicts on
        # purpose: probing a missing key must not insert a placeholder.
        ptr = [{} for _ in range(27)]
        ptr[even_id][0] = 0
        H[0] = (even_id, 0)
        free = 1

        for i in range(n):
            # reversed(s)[n-i : n-i+x] spells reversed(s[i-x:i]).
            rbase = rev[n - i]

            # --- Odd palindromes centred on s[i]; right half is s[i+1:i+1+x].
            cid = ord(s[i]) - begin
            table = ptr[cid]
            lim = min(i, n - i - 1)          # largest radius that fits
            left = fwd[i + 1]
            # Largest radius whose node exists and is a palindrome here.  The
            # predicate is monotone because every node's ancestors exist.
            lo, hi = -1, lim + 1
            while hi - lo > 1:
                x = (lo + hi) >> 1
                h1 = (fwd[i + 1 + x] - left * power[x]) % mod
                if h1 in table and h1 == (rev[n - i + x] - rbase * power[x]) % mod:
                    lo = x
                else:
                    hi = x
            if lo == -1:
                # First time this character is seen: create its root node.
                table[0] = free
                H[free] = (cid, 0)
                free += 1
                lo = 0
            par = table[(fwd[i + 1 + lo] - left * power[lo]) % mod]
            # Extend beyond the known nodes; every extension is a new node.
            for j in range(lo + 1, lim + 1):
                if s[i + j] != s[i - j]:
                    break
                h1 = (fwd[i + 1 + j] - left * power[j]) % mod
                table[h1] = free
                H[free] = (cid, h1)
                parent[free] = par
                par = free
                free += 1
            W[par] += 1  # maximal odd palindrome centred at i

            # --- Even palindromes centred on the gap before s[i]; right half
            # is s[i:i+x].  The radius may reach the end of the string, so the
            # search must cover min(i, n - i).  Capping it at n - i - 1 made
            # the loop below re-create an already existing node.
            table = ptr[even_id]
            lim = min(i, n - i)
            left = fwd[i]
            lo, hi = 0, lim + 1
            while hi - lo > 1:
                x = (lo + hi) >> 1
                h1 = (fwd[i + x] - left * power[x]) % mod
                if h1 in table and h1 == (rev[n - i + x] - rbase * power[x]) % mod:
                    lo = x
                else:
                    hi = x
            par = table[(fwd[i + lo] - left * power[lo]) % mod]
            for j in range(lo + 1, lim + 1):
                if s[i + j - 1] != s[i - j]:
                    break
                h1 = (fwd[i + j] - left * power[j]) % mod
                table[h1] = free
                H[free] = (even_id, h1)
                parent[free] = par
                par = free
                free += 1
            W[par] += 1  # maximal even palindrome at this gap

        # A node always gets a larger index than its parent, so a reverse
        # sweep is a post-order.  It turns "maximal at this centre" counts
        # into occurrence counts without deep recursion.
        for v in range(free - 1, 0, -1):
            p = parent[v]
            if p >= 0:
                W[p] += W[v]
        W[0] = 0
        return W, H
    
# Strip stray whitespace: any non 'A'..'Z' character would map to an
# invalid centre id inside eertree.build().
s = input().strip()
t = input().strip()

palS = eertree(s)
palT = eertree(t)
SW, SH = palS.build()
TW, TH = palT.build()

# Total occurrence count per palindrome key, separately for S and T.
# Zero-weight slots (unused or the empty palindrome) are skipped.
count_s = defaultdict(int)
for key, w in zip(SH, SW):
    if w:
        count_s[key] += w
count_t = defaultdict(int)
for key, w in zip(TH, TW):
    if w:
        count_t[key] += w

# Each palindrome common to both strings contributes occ_S * occ_T pairs.
ans = sum(w * count_t.get(key, 0) for key, w in count_s.items())

print(ans)