"""Count pairs of equal palindromic substrings drawn from two strings.

Reads two lines ``S`` and ``T`` from standard input and prints the number of
pairs (occurrence in ``S``, occurrence in ``T``) that spell the same
non-empty palindrome, i.e. the sum over common palindromes ``P`` of
``occ_S(P) * occ_T(P)``.

Both strings are indexed with a palindromic tree (eertree).  Equal
palindromes are matched by walking the two trees in lockstep, so the result
is exact: there is no hashing and hence no collision risk for any alphabet.
"""

import sys


class _PalNode:
    """One node of a palindromic tree: a distinct palindrome of the string."""

    __slots__ = ("length", "suffix_link", "children", "count")

    def __init__(self, length):
        self.length = length      # -1 for the imaginary root, 0 for the empty root
        self.suffix_link = None   # node of the longest proper palindromic suffix
        self.children = {}        # char -> node for ``char + self + char``
        self.count = 0            # occurrences; final once _build_eertree returns


def _build_eertree(s):
    """Build the palindromic tree of *s* and fill in occurrence counts.

    Returns ``(odd_root, even_root)``, the length -1 and length 0 roots.  Each
    non-root node's ``count`` is the number of positions at which its
    palindrome occurs in *s*.
    """
    odd_root = _PalNode(-1)
    odd_root.suffix_link = odd_root
    even_root = _PalNode(0)
    even_root.suffix_link = odd_root

    def extendable(node, i, ch):
        # True if ``ch + node + ch`` can end at index i.  For the odd root
        # (length -1) j == i, so it always succeeds; the j >= 0 guard stops
        # Python's negative indexing from wrapping around.
        j = i - node.length - 1
        return j >= 0 and s[j] == ch

    nodes = []  # non-root nodes in creation order
    last = even_root
    for i, ch in enumerate(s):
        cur = last
        while not extendable(cur, i, ch):
            cur = cur.suffix_link
        node = cur.children.get(ch)
        if node is None:
            node = _PalNode(cur.length + 2)
            if node.length == 1:
                node.suffix_link = even_root
            else:
                link = cur.suffix_link
                while not extendable(link, i, ch):
                    link = link.suffix_link
                # The longest proper palindromic suffix of a palindrome is
                # also its prefix, so it occurred earlier and already exists.
                node.suffix_link = link.children[ch]
            cur.children[ch] = node
            nodes.append(node)
        # Count the palindrome as the longest palindromic suffix ending at i.
        node.count += 1
        last = node

    # Every occurrence of a node is also an occurrence of its suffix-link
    # target.  Suffix links point to earlier-created nodes, so reverse
    # creation order pushes each node's count only after it is final.
    for node in reversed(nodes):
        node.suffix_link.count += node.count
    return odd_root, even_root


def count_common_palindromes(s, t):
    """Return the number of pairs of equal non-empty palindromic substrings.

    Each pair combines one occurrence (by position) of a palindrome in *s*
    with one occurrence of the same palindrome in *t*.

    >>> count_common_palindromes("ABA", "BAB")
    4
    """
    total = 0
    # A palindrome is identified by its root (odd/even length) and the
    # characters along the path from that root.  Following matching edges in
    # both trees therefore enumerates exactly the palindromes common to both.
    stack = list(zip(_build_eertree(s), _build_eertree(t)))
    while stack:
        node_s, node_t = stack.pop()
        for ch, child_s in node_s.children.items():
            child_t = node_t.children.get(ch)
            if child_t is not None:
                total += child_s.count * child_t.count
                stack.append((child_s, child_t))
    return total


def main():
    """Read S and T (one per line) from stdin and print the pair count."""
    s = sys.stdin.readline().strip()
    t = sys.stdin.readline().strip()
    print(count_common_palindromes(s, t))


if __name__ == "__main__":
    main()