import sys
# NTT-friendly prime: MOD - 1 = 998244352 = 119 * 2**23, so roots of unity
# exist for every power-of-two transform length up to 2**23.
MOD = 998244353
# Primitive root modulo MOD; ntt() derives each stage's root of unity as
# ROOT ** ((MOD - 1) // m) for block length m.
ROOT = 3

def ntt(a, inverse=False):
    """Transform ``a`` in place with an iterative radix-2 NTT modulo MOD.

    The list is first permuted into bit-reversed index order, then combined
    stage by stage with Cooley-Tukey butterflies.  When ``inverse`` is true,
    each stage uses the modular inverse of its root of unity and the result
    is multiplied by the inverse of ``len(a)`` modulo MOD.

    Args:
        a: list of integers, transformed in place.  Its length is presumably
            a power of two (``convolve`` pads to one); other lengths are not
            checked here.
        inverse: when true, compute the inverse transform.

    Returns:
        The same list object ``a``.
    """
    size = len(a)

    # Bit-reversal permutation; `rev` tracks the bit-reversed form of `idx`.
    rev = 0
    for idx in range(1, size):
        mask = size >> 1
        while rev >= mask:
            rev -= mask
            mask >>= 1
        rev += mask
        if idx < rev:
            a[idx], a[rev] = a[rev], a[idx]

    # One butterfly stage per doubling of the block length: 2, 4, ..., size.
    for stage in range(1, (size - 1).bit_length() + 1):
        length = 1 << stage
        half = length >> 1
        root = pow(ROOT, (MOD - 1) // length, MOD)
        if inverse:
            root = pow(root, MOD - 2, MOD)
        for start in range(0, size, length):
            twiddle = 1
            for lo_idx in range(start, start + half):
                hi_idx = lo_idx + half
                upper = a[lo_idx]
                lower = a[hi_idx] * twiddle % MOD
                a[lo_idx] = (upper + lower) % MOD
                a[hi_idx] = (upper - lower) % MOD
                twiddle = twiddle * root % MOD

    if inverse:
        scale = pow(size, MOD - 2, MOD)
        # Slice assignment keeps the caller's list object.
        a[:] = [value * scale % MOD for value in a]
    return a

def convolve(a, b, max_len=None):
    """Multiply two polynomials modulo MOD using the NTT.

    Args:
        a: coefficient list, lowest degree first.  Not modified.
        b: coefficient list, lowest degree first.  Not modified.
        max_len: highest degree kept in the result; coefficients of larger
            degree are discarded.  ``None`` keeps the full product.

    Returns:
        A new list holding the product's coefficients up to degree
        ``min(len(a) + len(b) - 2, max_len)``, each reduced modulo MOD.
        Returns ``[]`` when either input is empty or ``max_len`` is negative.

    Note:
        The padded transform length must divide MOD - 1 (see ``ntt``), so
        products needing more than 2**23 coefficients are not supported.
    """
    if max_len is not None:
        if max_len < 0:
            return []
        # Degrees only add under multiplication, so input terms above
        # max_len can never reach a kept output term; drop them before
        # paying for the transform.
        a = a[:max_len + 1]
        b = b[:max_len + 1]

    n = len(a)
    m = len(b)
    if n == 0 or m == 0:
        return []

    size = n + m - 1
    size_ntt = 1 << (size - 1).bit_length()
    fa = ntt(a + [0] * (size_ntt - n))
    fb = ntt(b + [0] * (size_ntt - m))
    product = ntt([x * y % MOD for x, y in zip(fa, fb)], inverse=True)

    keep = size if max_len is None else min(size, max_len + 1)
    # The inverse ntt() already reduces every coefficient modulo MOD.
    return product[:keep]

def _factorial_tables(limit):
    """Return ``(fact, inv_fact)`` with k! and (k!)^-1 modulo MOD for 0 <= k <= limit."""
    fact = [1] * (limit + 1)
    for i in range(1, limit + 1):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * (limit + 1)
    inv_fact[limit] = pow(fact[limit], MOD - 2, MOD)
    # (k-1)!^-1 = k!^-1 * k, so one modular inverse suffices.
    for i in range(limit, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % MOD
    return fact, inv_fact


def _multiply_all(polys, max_len):
    """Return the product of ``polys`` modulo MOD, truncated to degree ``max_len``.

    Adjacent pairs are multiplied level by level, forming a balanced product
    tree.  Each coefficient therefore takes part in O(log len(polys))
    convolutions rather than one per remaining factor, as in a
    left-to-right fold.  An empty list yields the constant polynomial [1].
    """
    if not polys:
        return [1]
    while len(polys) > 1:
        paired = [
            convolve(polys[i], polys[i + 1], max_len)
            for i in range(0, len(polys) - 1, 2)
        ]
        if len(polys) % 2:
            paired.append(polys[-1])
        polys = paired
    return polys[0]


def main():
    """Count distinct non-empty strings buildable from the input's letters.

    Reads one line S from stdin.  Prints, modulo MOD, the number of distinct
    non-empty strings that use each character at most as often as it occurs
    in S.

    A character occurring c times has the exponential generating function
    sum_{k=0..c} x^k / k!.  In the product of these over all characters,
    t! * [x^t] counts the distinct length-t arrangements.  The answer is
    therefore sum_{t>=1} t! * [x^t] of that product.
    """
    from collections import Counter

    s = sys.stdin.readline().strip()
    total = len(s)
    if total == 0:
        print(0)
        return

    fact, inv_fact = _factorial_tables(total)
    egfs = [inv_fact[:c + 1] for c in Counter(s).values()]
    product = _multiply_all(egfs, total)

    # The counts sum to `total`, so the product has at most total + 1
    # coefficients and fact[t] is defined for every index t used here.
    ans = sum(product[t] * fact[t] for t in range(1, len(product))) % MOD
    print(ans)

if __name__ == "__main__":
    # Run the solver only when executed as a script, not when imported.
    main()