import sys

# Two independent (modulus, base) pairs for double polynomial hashing.
MOD1 = 10**18 + 3
BASE1 = 911382629
MOD2 = 10**18 + 7
BASE2 = 35714285


class EertreeNode:
    __slots__ = ('length', 'suffix_link', 'transitions', 'end', 'count')

    def __init__(self, length, suffix_link):
        self.length = length            # length of the palindrome
        self.suffix_link = suffix_link  # node of the longest proper palindromic suffix
        self.transitions = {}           # c -> node for the palindrome c + this + c
        self.end = -1                   # index in s where the first occurrence ends
        self.count = 0                  # occurrence count (final after propagation)


def build_eertree(s):
    # Two roots: the imaginary palindrome of length -1 and the empty palindrome.
    root_neg1 = EertreeNode(-1, None)
    root_0 = EertreeNode(0, root_neg1)
    root_neg1.suffix_link = root_neg1
    nodes = [root_neg1, root_0]
    current = root_0

    for i, c in enumerate(s):
        # Walk suffix links until c can extend the palindrome on both sides.
        # The length -1 root always matches (it compares s[i] with c).
        while True:
            if i - current.length - 1 >= 0 and s[i - current.length - 1] == c:
                break
            current = current.suffix_link

        if c in current.transitions:
            # The palindrome already exists: record one more occurrence.
            current = current.transitions[c]
            current.count += 1
            continue

        new_length = current.length + 2
        new_node = EertreeNode(new_length, None)
        new_node.end = i
        nodes.append(new_node)
        current.transitions[c] = new_node

        if new_length == 1:
            # A single character links to the empty palindrome.
            new_node.suffix_link = root_0
        else:
            # Find the longest proper palindromic suffix of the new palindrome.
            suffix = current.suffix_link
            while True:
                if i - suffix.length - 1 >= 0 and s[i - suffix.length - 1] == c:
                    if c in suffix.transitions:
                        suffix = suffix.transitions[c]
                    else:
                        suffix = root_0
                    break
                suffix = suffix.suffix_link
            new_node.suffix_link = suffix

        new_node.count = 1
        current = new_node

    # Propagate counts from longer palindromes to their palindromic suffixes.
    sorted_nodes = sorted(nodes[2:], key=lambda x: -x.length)
    for node in sorted_nodes:
        if node.suffix_link.length >= 0:
            node.suffix_link.count += node.count
    return nodes


def compute_prefix_hash(s, base, mod):
    # prefix[i] is the polynomial hash of s[:i]; pow_base[i] is base**i % mod.
    n = len(s)
    prefix = [0] * (n + 1)
    pow_base = [1] * (n + 1)
    for i in range(n):
        prefix[i + 1] = (prefix[i] * base + ord(s[i])) % mod
        pow_base[i + 1] = (pow_base[i] * base) % mod
    return prefix, pow_base


def get_hash(prefix, pow_base, mod, a, b):
    # Hash of s[a..b] (inclusive). Python's % is non-negative for a positive
    # modulus, so no extra sign correction is needed.
    len_sub = b - a + 1
    return (prefix[b + 1] - prefix[a] * pow_base[len_sub]) % mod


def process_string(s):
    # Map (length, hash1, hash2) of each distinct palindrome to its occurrence count.
    prefix1, pow1 = compute_prefix_hash(s, BASE1, MOD1)
    prefix2, pow2 = compute_prefix_hash(s, BASE2, MOD2)
    nodes = build_eertree(s)
    hash_dict = {}
    for node in nodes:
        if node.length <= 0:
            continue  # skip the two roots
        end = node.end
        start = end - node.length + 1
        if start < 0:
            continue
        h1 = get_hash(prefix1, pow1, MOD1, start, end)
        h2 = get_hash(prefix2, pow2, MOD2, start, end)
        key = (node.length, h1, h2)
        hash_dict[key] = hash_dict.get(key, 0) + node.count
    return hash_dict


def main():
    S = sys.stdin.readline().strip()
    T = sys.stdin.readline().strip()
    dict_S = process_string(S)
    dict_T = process_string(T)
    # For every palindrome present in both strings, add
    # (occurrences in S) * (occurrences in T).
    result = 0
    for key in dict_S:
        if key in dict_T:
            result += dict_S[key] * dict_T[key]
    print(result)


if __name__ == '__main__':
    main()