# NTT routines below are adapted from: https://judge.yosupo.jp/submission/55648
import atexit
from operator import mod
import sys
import os

try:
    import __pypy__  # PyPy-only module; not referenced by the code below
except ImportError:  # running under CPython: keep the script usable
    __pypy__ = None

MOD = 998244353
IMAG = 911660635   # a square root of -1 modulo MOD
IIMAG = 86583718   # inverse of IMAG modulo MOD
# Twiddle-factor tables for the radix-2 / radix-4 passes.  They are indexed by
# (number of trailing one bits of s) + 1, so slot 0 is never read.
rate2 = (0, 911660635, 509520358, 369330050, 332049552, 983190778, 123842337,
         238493703, 975955924, 603855026, 856644456, 131300601, 842657263,
         730768835, 942482514, 806263778, 151565301, 510815449, 503497456,
         743006876, 741047443, 56250497, 867605899, 0)
irate2 = (0, 86583718, 372528824, 373294451, 645684063, 112220581, 692852209,
          155456985, 797128860, 90816748, 860285882, 927414960, 354738543,
          109331171, 293255632, 535113200, 308540755, 121186627, 608385704,
          438932459, 359477183, 824071951, 103369235, 0)
rate3 = (0, 372528824, 337190230, 454590761, 816400692, 578227951, 180142363,
         83780245, 6597683, 70046822, 623238099, 183021267, 402682409,
         631680428, 344509872, 689220186, 365017329, 774342554, 729444058,
         102986190, 128751033, 395565204, 0)
irate3 = (0, 509520358, 929031873, 170256584, 839780419, 282974284, 395914482,
          444904435, 72135471, 638914820, 66769500, 771127074, 985925487,
          262319669, 262341272, 625870173, 768022760, 859816005, 914661783,
          430819711, 272774365, 530924681, 0)


def butterfly(a):
    """Forward number-theoretic transform of ``a`` modulo MOD, in place.

    ``len(a)`` is expected to be a power of two (``multiply`` pads to one).
    The result is not in natural index order; it is only meant to be
    multiplied pointwise with another transformed array of the same length
    and then passed to ``butterfly_inv``.  Levels are processed two at a time
    (radix-4), with a single radix-2 pass when an odd level remains.
    """
    n = len(a)
    h = (n - 1).bit_length()
    le = 0
    while le < h:
        if h - le == 1:
            # Radix-2 pass for the one remaining level.
            p = 1 << (h - le - 1)
            rot = 1
            for s in range(1 << le):
                offset = s << (h - le)
                for i in range(p):
                    l = a[i + offset]
                    r = a[i + offset + p] * rot
                    a[i + offset] = (l + r) % MOD
                    a[i + offset + p] = (l - r) % MOD
                # ~s & -~s isolates the lowest zero bit of s, so bit_length()
                # is (trailing one bits of s) + 1.
                rot *= rate2[(~s & -~s).bit_length()]
                rot %= MOD
            le += 1
        else:
            # Radix-4 pass: two levels at once.
            p = 1 << (h - le - 2)
            rot = 1
            for s in range(1 << le):
                rot2 = rot * rot % MOD
                rot3 = rot2 * rot % MOD
                offset = s << (h - le)
                for i in range(p):
                    a0 = a[i + offset]
                    a1 = a[i + offset + p] * rot
                    a2 = a[i + offset + p * 2] * rot2
                    a3 = a[i + offset + p * 3] * rot3
                    a1na3imag = (a1 - a3) % MOD * IMAG
                    a[i + offset] = (a0 + a2 + a1 + a3) % MOD
                    a[i + offset + p] = (a0 + a2 - a1 - a3) % MOD
                    a[i + offset + p * 2] = (a0 - a2 + a1na3imag) % MOD
                    a[i + offset + p * 3] = (a0 - a2 - a1na3imag) % MOD
                rot *= rate3[(~s & -~s).bit_length()]
                rot %= MOD
            le += 2


def butterfly_inv(a):
    """Inverse of ``butterfly``, in place, *without* the 1/len(a) scaling.

    ``multiply`` applies that scaling after calling this function.
    """
    n = len(a)
    h = (n - 1).bit_length()
    le = h
    while le:
        if le == 1:
            # Radix-2 pass for the one remaining level.
            p = 1 << (h - le)
            irot = 1
            for s in range(1 << (le - 1)):
                offset = s << (h - le + 1)
                for i in range(p):
                    l = a[i + offset]
                    r = a[i + offset + p]
                    a[i + offset] = (l + r) % MOD
                    a[i + offset + p] = (l - r) * irot % MOD
                irot *= irate2[(~s & -~s).bit_length()]
                irot %= MOD
            le -= 1
        else:
            # Radix-4 pass: two levels at once.
            p = 1 << (h - le)
            irot = 1
            for s in range(1 << (le - 2)):
                irot2 = irot * irot % MOD
                irot3 = irot2 * irot % MOD
                offset = s << (h - le + 2)
                for i in range(p):
                    a0 = a[i + offset]
                    a1 = a[i + offset + p]
                    a2 = a[i + offset + p * 2]
                    a3 = a[i + offset + p * 3]
                    a2na3iimag = (a2 - a3) * IIMAG % MOD
                    a[i + offset] = (a0 + a1 + a2 + a3) % MOD
                    a[i + offset + p] = (a0 - a1 + a2na3iimag) * irot % MOD
                    a[i + offset + p * 2] = (a0 + a1 - a2 - a3) * irot2 % MOD
                    a[i + offset + p * 3] = (a0 - a1 - a2na3iimag) * irot3 % MOD
                irot *= irate3[(~s & -~s).bit_length()]
                irot %= MOD
            le -= 2


def multiply(s, t):
    """Return the convolution of lists ``s`` and ``t`` modulo MOD.

    The result has ``len(s) + len(t) - 1`` coefficients.  When the shorter
    input has at most 60 entries a schoolbook product is used; otherwise both
    inputs are copied, zero-padded to the next power of two and multiplied via
    ``butterfly`` / ``butterfly_inv``.  ``s`` and ``t`` are not modified.
    """
    n = len(s)
    m = len(t)
    if min(n, m) <= 60:
        a = [0] * (n + m - 1)
        for i in range(n):
            if i % 8 == 0:
                # Reduce only on every 8th row: presumably to bound integer
                # growth while skipping most of the modulo operations.
                for j in range(m):
                    a[i + j] += s[i] * t[j]
                    a[i + j] %= MOD
            else:
                for j in range(m):
                    a[i + j] += s[i] * t[j]
        return [x % MOD for x in a]
    a = s.copy()
    b = t.copy()
    z = 1 << (n + m - 2).bit_length()  # smallest power of two >= n + m - 1
    a += [0] * (z - n)
    b += [0] * (z - m)
    butterfly(a)
    butterfly(b)
    for i in range(z):
        a[i] *= b[i]
        a[i] %= MOD
    butterfly_inv(a)
    a = a[:n + m - 1]
    iz = pow(z, MOD - 2, MOD)  # 1/z mod MOD (MOD is prime)
    return [v * iz % MOD for v in a]


n = int(input())
# two[k] = 2**k mod MOD for 0 <= k <= 2n (two[2i] = 4**i is used by f).
nn = n << 1
two = [1] * (nn + 1)
for i in range(nn):
    two[i + 1] = two[i] << 1
    if two[i + 1] > MOD:
        two[i + 1] -= MOD
# For a perfect binary tree with i levels (2**i - 1 nodes):
#   g[i] = sum of node depths        = (i - 2) * 2**i + 2
#   f[i] = sum of distances over all unordered node pairs
f = [0] * (n + 1)
g = [0] * (n + 1)
for i in range(1, n + 1):
    f[i] = ((i + 3) * two[i] + (i - 3) * two[i << 1]) % MOD
    g[i] = ((i - 2) * two[i] + 2) % MOD
# u and v are handled as strings whose common-prefix length is the (1-based)
# depth of their LCA -- presumably root-to-node paths written in binary;
# TODO confirm against the problem statement.
u, v = input().strip(), input().strip()
if u == v:
    print(f[n])
    sys.exit()
lca = 0
for cu, cv in zip(u, v):
    if cu != cv:
        break
    lca += 1
# Number of vertices on the path u -> lca -> v.
m = len(u) + len(v) - 2 * lca + 1
p = [0] * m  # per-path-vertex depths, filled in by the code that follows
# ---------------------------------------------------------------------------
# Removing the edges of the cycle u -> lca -> v -> u (closed by the extra u-v
# edge) leaves one tree hanging off every cycle vertex.  For cycle vertex i:
#   a[i] = number of nodes in its hanging tree (the vertex included),
#   b[i] = sum of distances from the vertex to the nodes of that tree,
#   c[i] = sum of pairwise distances inside that tree,
# all modulo MOD.  two[k] - 1, g[k] and f[k] are the node count, depth sum and
# pairwise-distance sum of a perfect binary tree with k levels.
# ---------------------------------------------------------------------------

# p[i] is the 0-based depth of the i-th vertex on the path u -> lca -> v;
# lca_cnt is the position of the LCA on that path.
lca_cnt = len(u) - lca
p = [len(u) - 1 - i for i in range(lca_cnt + 1)]
p += [lca + j for j in range(m - 1 - lca_cnt)]

a = [0] * m
b = [0] * m
c = [0] * m

# Path endpoints: the whole subtree below them hangs off the cycle.
# (If the LCA is an endpoint it is overwritten below.)
for i in (0, m - 1):
    levels = n - p[i]
    a[i] = two[levels] - 1
    b[i] = g[levels]
    c[i] = f[levels]

# Inner path vertices other than the LCA: the vertex itself plus the subtree
# of its single off-path child.
for i in range(1, m - 1):
    if i == lca_cnt:
        continue
    levels = n - 1 - p[i]
    a[i] = two[levels]
    b[i] = (g[levels] + two[levels] - 1) % MOD
    c[i] = (f[levels] + g[levels] + two[levels] - 1) % MOD

# The LCA's hanging tree: the LCA, each ancestor together with its off-path
# subtree, and -- when the LCA is itself an endpoint (one of u, v is an
# ancestor of the other) -- the subtree of its other child.
depth = p[lca_cnt]
lca_is_endpoint = lca_cnt == 0 or lca_cnt == m - 1

size_l = 1
dist_l = 0
for k in range(1, depth + 1):
    # Off-path subtree rooted at depth k (n - k levels); its parent, the
    # ancestor at depth k - 1, is depth + 1 - k edges above the LCA.
    sub = two[n - k] - 1
    size_l += sub + 1
    dist_l += g[n - k] + sub * (depth + 2 - k) + (depth + 1 - k)
if lca_is_endpoint:
    levels = n - depth - 1
    size_l += two[levels] - 1
    dist_l += g[levels] + two[levels] - 1
size_l %= MOD
dist_l %= MOD

# Pairwise distances inside the LCA's tree, summed edge by edge.
pair_l = 0
for k in range(1, depth + 1):
    sub = two[n - k] - 1
    # Nodes on the root side of the path edge between depths k - 1 and k.
    above = two[n] - two[n - k]
    # f: pairs inside the off-path subtree; (g + sub) * (size - sub): pairs
    # with exactly one end in it (its internal edges plus the edge to its
    # parent); above * (size - above): pairs separated by the path edge.
    pair_l += (f[n - k] + (g[n - k] + sub) * (size_l - sub)
               + above * (size_l - above))
if lca_is_endpoint:
    levels = n - depth - 1
    sub = two[levels] - 1
    pair_l += f[levels] + (g[levels] + sub) * (size_l - sub)

a[lca_cnt] = size_l
b[lca_cnt] = dist_l
c[lca_cnt] = pair_l % MOD

# Pairs inside one hanging tree contribute c[i].  For x in tree i and y in
# tree j != i the distance is dist(x, i) + cyc(i, j) + dist(j, y); summing
# the first and last terms over unordered pairs gives sum_{i != j} a[i]*b[j].
sum_a = sum(a) % MOD
sum_b = sum(b) % MOD
sum_c = sum(c) % MOD
sum_ab = sum(x * y for x, y in zip(a, b)) % MOD
ans = (sum_c + sum_a * sum_b - sum_ab) % MOD

# Cycle term: sum over i < j of a[i] * a[j] * min(j - i, m - (j - i)).
# With tt = a * reversed(a): tt[m - 1 - j] sums a[i] * a[i + j] (index gap j)
# and tt[2m - 1 - j] sums a[i] * a[i + m - j] (gap m - j).  Both gaps have
# cyclic distance min(j, m - j), and over j = 1..m-1 every unordered pair is
# counted exactly twice -- hence the final halving.
tt = multiply(a, a[::-1])
cycle_sum = 0
for j in range(1, m):
    cycle_sum += min(j, m - j) * (tt[m - 1 - j] + tt[2 * m - 1 - j])
cycle_sum %= MOD
inv2 = (MOD + 1) // 2  # modular inverse of 2 (MOD is odd)
ans = (ans + cycle_sum * inv2) % MOD
print(ans)