import sys
from collections import defaultdict


class Node:
    """One vertex of a palindromic tree (eertree).

    Attributes:
        len:  length of the palindrome this node represents; -1 for the
              imaginary root (node 0) and 0 for the empty-string root (node 1).
        link: index of the node for this palindrome's longest proper
              palindromic suffix.
        next: maps a character ``c`` to the index of the node for the
              palindrome ``c + <this palindrome> + c``.
        cnt:  occurrence counter; it holds the full number of occurrences only
              after ``build_pal_tree`` propagates counts along suffix links.
        h1, h2: polynomial hashes of the palindrome under two moduli.
    """

    __slots__ = ['len', 'link', 'next', 'cnt', 'h1', 'h2']

    def __init__(self):
        self.len = 0
        self.link = 0
        # Keyed by the character itself rather than ord(c) - ord('A'): any
        # alphabet works and no character can alias another character's slot
        # (a negative list index would silently wrap around).
        self.next = {}
        self.cnt = 0
        self.h1 = 0
        self.h2 = 0


def _find_extendable(nodes, s, i, v):
    """Walk suffix links from node *v* to the first palindrome X that can be
    wrapped as ``s[i] + X + s[i]``, i.e. with ``s[i - len(X) - 1] == s[i]``.

    Always terminates: the imaginary root (node 0, length -1) checks
    ``s[i] == s[i]``, which is trivially true.
    """
    c = s[i]
    while True:
        j = i - nodes[v].len - 1
        if j >= 0 and s[j] == c:
            return v
        v = nodes[v].link


def _pow_table(base, mod, size):
    """Return ``[base**k % mod for k in range(size + 1)]``."""
    table = [1] * (size + 1)
    for k in range(1, size + 1):
        table[k] = table[k - 1] * base % mod
    return table


def build_pal_tree(s, base1, base2, mod1, mod2, pow_base1, pow_base2):
    """Build the palindromic tree of *s* and count each distinct palindrome.

    Every distinct non-empty palindromic substring ``P`` is keyed by the pair
    ``(h1, h2)`` where, for each ``(base, mod)`` pair,
    ``h = sum(ord(P[k]) * base**(len(P) - 1 - k)) % mod``.

    Args:
        s: input string (any characters).
        base1, base2: hash bases.
        mod1, mod2: hash moduli.
        pow_base1, pow_base2: tables of ``base**k % mod``; they must have at
            least ``len(s)`` entries.

    Returns:
        ``defaultdict(int)`` mapping ``(h1, h2)`` to the number of occurrences
        of that palindrome in *s*. Palindromes whose hash pairs collide are
        merged.
    """
    # Node 0: imaginary root (length -1); node 1: empty-string root.
    # Both keep the default link 0.
    odd_root = Node()
    odd_root.len = -1
    nodes = [odd_root, Node()]
    last = 1  # node of the longest palindromic suffix of the processed prefix

    for i, c in enumerate(s):
        p = _find_extendable(nodes, s, i, last)
        parent = nodes[p]

        child = parent.next.get(c)
        if child is not None:
            # Palindrome c + X + c already exists: record one more end position.
            last = child
            nodes[child].cnt += 1
            continue

        code = ord(c)
        node = Node()
        node.len = parent.len + 2
        if node.len == 1:
            # Single character: its hash is just the character code, and its
            # longest proper palindromic suffix is the empty string.
            node.h1 = code
            node.h2 = code
            node.link = 1
        else:
            # hash(c + X + c) = c * base**(len(X) + 1) + hash(X) * base + c
            shift = parent.len + 1
            node.h1 = (code * pow_base1[shift] + parent.h1 * base1 + code) % mod1
            node.h2 = (code * pow_base2[shift] + parent.h2 * base2 + code) % mod2
            q = _find_extendable(nodes, s, i, parent.link)
            # The target always exists in a correct eertree; 1 is a safe fallback.
            node.link = nodes[q].next.get(c, 1)

        node.cnt = 1
        nodes.append(node)
        last = len(nodes) - 1
        parent.next[c] = last

    # Every suffix link points to a node created earlier (or to a root). So
    # sweeping in reverse creation order finalizes each node's count before
    # pushing it into its link.
    for k in range(len(nodes) - 1, 1, -1):
        node = nodes[k]
        nodes[node.link].cnt += node.cnt

    # Collect counts for every real palindrome (skip the two roots).
    hash_map = defaultdict(int)
    for node in nodes[2:]:
        hash_map[(node.h1, node.h2)] += node.cnt
    return hash_map


def main():
    """Read strings S and T (one per line) from stdin and print
    sum over palindromes P of occ_S(P) * occ_T(P), i.e. the number of pairs
    of equal palindromic substrings (up to hash collisions)."""
    s = sys.stdin.readline().strip()
    t = sys.stdin.readline().strip()

    base1 = 911382629
    mod1 = 10**18 + 3
    base2 = 3571428571
    mod2 = 10**18 + 7

    # Tables only need len(s) entries; the slack is harmless.
    max_len = max(len(s), len(t)) + 2
    pow_base1 = _pow_table(base1, mod1, max_len)
    pow_base2 = _pow_table(base2, mod2, max_len)

    map_S = build_pal_tree(s, base1, base2, mod1, mod2, pow_base1, pow_base2)
    map_T = build_pal_tree(t, base1, base2, mod1, mod2, pow_base1, pow_base2)

    # Iterate the smaller map; .get avoids inserting keys into the defaultdict.
    small, large = (map_S, map_T) if len(map_S) <= len(map_T) else (map_T, map_S)
    result = sum(cnt * large.get(key, 0) for key, cnt in small.items())
    print(result)


if __name__ == "__main__":
    main()