"""Count the distinct non-empty arrangements of a string's characters.

Reads one line from standard input and prints, modulo ``MOD``, the number of
distinct non-empty sequences that use each character at most as many times
as it occurs in the input. Polynomial products are computed with a
number-theoretic transform (NTT).
"""

import sys
from collections import Counter

# NTT-friendly prime: MOD - 1 = 119 * 2**23, and ROOT is a primitive root.
MOD = 998244353
ROOT = 3

# Largest power-of-two transform length for which MOD has a root of unity.
_MAX_NTT_LEN = 1 << 23


def ntt(a, inverse=False):
    """Transform ``a`` in place with the number-theoretic transform.

    Args:
        a: List of integers whose length is 0 or a power of two
            (at most 2**23). It is modified in place.
        inverse: When true, apply the inverse transform, including the
            final scaling by the modular inverse of ``len(a)``.

    Returns:
        The same list object ``a``, now holding the transformed values.

    Raises:
        ValueError: If ``len(a)`` is not a power of two or exceeds 2**23.
    """
    n = len(a)
    # n & (n - 1) is zero exactly for powers of two (and for n == 0).
    if n & (n - 1):
        raise ValueError("ntt length must be a power of two, got %d" % n)
    if n > _MAX_NTT_LEN:
        raise ValueError("ntt length %d exceeds the maximum of 2**23" % n)

    # Bit-reversal permutation: j tracks the bit-reversed value of i.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j ^= bit
        if i < j:
            a[i], a[j] = a[j], a[i]

    # Iterative butterflies over blocks of size 2, 4, ..., n.
    length = 2
    while length <= n:
        half = length >> 1
        w_len = pow(ROOT, (MOD - 1) // length, MOD)
        if inverse:
            w_len = pow(w_len, MOD - 2, MOD)
        for start in range(0, n, length):
            w = 1
            for idx in range(start, start + half):
                u = a[idx]
                t = a[idx + half] * w % MOD
                a[idx] = (u + t) % MOD
                a[idx + half] = (u - t) % MOD
                w = w * w_len % MOD
        length <<= 1

    if inverse:
        inv_n = pow(n, MOD - 2, MOD)
        for i in range(n):
            a[i] = a[i] * inv_n % MOD
    return a


def convolve(a, b, max_len):
    """Multiply two polynomials modulo ``MOD``, keeping low-order terms.

    Args:
        a: Coefficients of the first polynomial, lowest degree first.
        b: Coefficients of the second polynomial, lowest degree first.
        max_len: Highest degree to keep; at most ``max_len + 1``
            coefficients are returned.

    Returns:
        A new list with the product's coefficients for degrees
        ``0..max_len`` (fewer if the product is shorter), each reduced
        modulo ``MOD``. Returns ``[]`` if either input is empty or
        ``max_len`` is negative. The inputs are not modified.
    """
    limit = max_len + 1
    if not a or not b or limit <= 0:
        return []

    # Terms above max_len cannot influence the kept coefficients, so drop
    # them up front to keep the transforms as small as possible.
    lhs = a[:limit]
    rhs = b[:limit]
    size = len(lhs) + len(rhs) - 1
    size_ntt = 1 << (size - 1).bit_length()

    fa = ntt(lhs + [0] * (size_ntt - len(lhs)))
    fb = ntt(rhs + [0] * (size_ntt - len(rhs)))
    product = ntt([x * y % MOD for x, y in zip(fa, fb)], inverse=True)
    return product[:min(size, limit)]


def count_arrangements(S):
    """Count distinct non-empty arrangements of characters drawn from ``S``.

    Each character may be used at most as many times as it occurs in ``S``.
    The count is the sum over lengths ``t >= 1`` of ``t!`` times the
    ``x**t`` coefficient of the product, over distinct characters with
    multiplicity ``c``, of ``sum(x**k / k! for k in 0..c)``.

    Args:
        S: The input string.

    Returns:
        The count modulo ``MOD``; ``0`` for an empty string.
    """
    total = len(S)
    if total == 0:
        return 0

    # Factorials and inverse factorials up to len(S).
    fact = [1] * (total + 1)
    for i in range(1, total + 1):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * (total + 1)
    inv_fact[total] = pow(fact[total], MOD - 2, MOD)
    for i in range(total - 1, -1, -1):
        inv_fact[i] = inv_fact[i + 1] * (i + 1) % MOD

    # Multiply together one truncated exponential series per character.
    poly = [1]
    for multiplicity in Counter(S).values():
        poly = convolve(poly, inv_fact[:multiplicity + 1], total)

    # Skip the constant term: it corresponds to the empty arrangement.
    return sum(poly[t] * fact[t] for t in range(1, len(poly))) % MOD


def main():
    """Read a string from stdin and print its arrangement count."""
    S = sys.stdin.readline().strip()
    print(count_arrangements(S))


if __name__ == "__main__":
    main()