結果

問題 No.2062 Sum of Subset mod 999630629
ユーザー taiga0629kyopro
提出日時 2022-08-11 08:35:44
言語 PyPy3 (7.3.15)
結果
AC  
実行時間 2,448 ms / 5,000 ms
コード長 15,206 bytes
コンパイル時間 303 ms
コンパイル使用メモリ 82,456 KB
実行使用メモリ 213,944 KB
最終ジャッジ日時 2024-04-22 02:20:24
合計ジャッジ時間 23,143 ms
ジャッジサーバーID (参考情報) judge1 / judge5
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 77 ms 95,076 KB
testcase_01 AC 80 ms 94,468 KB
testcase_02 AC 76 ms 95,004 KB
testcase_03 AC 76 ms 93,872 KB
testcase_04 AC 81 ms 93,984 KB
testcase_05 AC 76 ms 93,776 KB
testcase_06 AC 78 ms 94,492 KB
testcase_07 AC 84 ms 93,960 KB
testcase_08 AC 100 ms 117,004 KB
testcase_09 AC 91 ms 112,584 KB
testcase_10 AC 91 ms 109,956 KB
testcase_11 AC 703 ms 135,960 KB
testcase_12 AC 718 ms 136,540 KB
testcase_13 AC 442 ms 123,408 KB
testcase_14 AC 716 ms 136,188 KB
testcase_15 AC 223 ms 121,832 KB
testcase_16 AC 734 ms 134,352 KB
testcase_17 AC 758 ms 135,824 KB
testcase_18 AC 451 ms 123,048 KB
testcase_19 AC 236 ms 121,820 KB
testcase_20 AC 313 ms 122,296 KB
testcase_21 AC 445 ms 122,424 KB
testcase_22 AC 303 ms 122,268 KB
testcase_23 AC 98 ms 115,620 KB
testcase_24 AC 100 ms 114,896 KB
testcase_25 AC 2,448 ms 213,944 KB
testcase_26 AC 2,409 ms 213,672 KB
testcase_27 AC 2,433 ms 213,600 KB
testcase_28 AC 2,434 ms 213,652 KB
testcase_29 AC 2,423 ms 213,676 KB
testcase_30 AC 1,288 ms 165,320 KB
testcase_31 AC 1,304 ms 165,180 KB
権限があれば一括ダウンロードができます

ソースコード

diff #


class FPS:
    """Formal power series with coefficients modulo 998244353.

    Coefficients are stored lowest degree first in ``self.Func``.  Large
    products use an NTT (butterfly / butterfly_inv).  The arithmetic
    operators and the series methods (inv, log, exp, sqrt, powfps, shifts)
    return new ``FPS`` objects and do not modify the receiver.
    """

    # Precomputed step multipliers for the twiddle factors used by
    # butterfly (sum_e) and butterfly_inv (sum_ie) with mod 998244353.
    # NOTE(review): these look like the AtCoder Library tables for this
    # modulus -- confirm before changing `mod`.
    sum_e = (
        911660635, 509520358, 369330050, 332049552, 983190778, 123842337, 238493703, 975955924, 603855026, 856644456,
        131300601, 842657263, 730768835, 942482514, 806263778, 151565301, 510815449, 503497456, 743006876, 741047443,
        56250497)
    sum_ie = (
        86583718, 372528824, 373294451, 645684063, 112220581, 692852209, 155456985, 797128860, 90816748, 860285882,
        927414960, 354738543, 109331171, 293255632, 535113200, 308540755, 121186627, 608385704, 438932459, 359477183,
        824071951)
    mod = 998244353
    # Class-level default only; every instance overwrites it in __init__.
    Func = [0]

    def __init__(self, L):
        """Build a series from the coefficient list ``L`` (reduced mod ``mod``)."""
        self.Func = [x % self.mod for x in L]

    def butterfly(self, a):
        """Forward NTT of ``a`` in place (and return it).

        ``len(a)`` must be a power of two.  The output order is the one that
        butterfly_inv expects, so pointwise products between two transformed
        arrays are valid.
        """
        n = len(a)
        h = (n - 1).bit_length()
        for ph in range(1, h + 1):
            w = 1 << (ph - 1)
            p = 1 << (h - ph)
            now = 1
            for s in range(w):
                offset = s << (h - ph + 1)
                for i in range(p):
                    l = a[i + offset]
                    r = a[i + offset + p] * now
                    r %= self.mod
                    a[i + offset] = l + r
                    a[i + offset] %= self.mod
                    a[i + offset + p] = l - r
                    a[i + offset + p] %= self.mod
                now *= self.sum_e[(~s & -~s).bit_length() - 1]
                now %= self.mod
        return a

    def butterfly_inv(self, a):
        """Inverse NTT of ``a`` in place (and return it), WITHOUT the 1/n scale.

        Callers multiply the result by ``pow(len(a), mod - 2, mod)`` themselves.
        """
        n = len(a)
        h = (n - 1).bit_length()
        for ph in range(h, 0, -1):
            w = 1 << (ph - 1)
            p = 1 << (h - ph)
            inow = 1
            for s in range(w):
                offset = s << (h - ph + 1)
                for i in range(p):
                    l = a[i + offset]
                    r = a[i + offset + p]
                    a[i + offset] = l + r
                    a[i + offset] %= self.mod
                    a[i + offset + p] = (l - r) * inow
                    a[i + offset + p] %= self.mod
                inow *= self.sum_ie[(~s & -~s).bit_length() - 1]
                inow %= self.mod
        return a

    def __mul__(self, other):
        """Return self * other, where ``other`` is an int scalar or an FPS.

        The full product (length n + m - 1) is returned.  When the shorter
        operand has at most 40 terms, schoolbook multiplication is used;
        otherwise an NTT of the next power of two >= n + m - 1.
        """
        if isinstance(other, int):
            return FPS([(x * other) % self.mod for x in self.Func])
        a = self.Func
        b = other.Func
        n = len(a)
        m = len(b)
        if not a or not b:
            return FPS([])
        if min(n, m) <= 40:
            if n < m:
                n, m = m, n
                a, b = b, a
            res = [0] * (n + m - 1)
            for i in range(n):
                for j in range(m):
                    res[i + j] += a[i] * b[j]
                    res[i + j] %= self.mod
            return FPS(res)
        z = 1 << ((n + m - 2).bit_length())
        a = a + [0] * (z - n)
        b = b + [0] * (z - m)
        a = self.butterfly(a)
        b = self.butterfly(b)
        c = [0] * z
        for i in range(z):
            c[i] = (a[i] * b[i]) % self.mod
        self.butterfly_inv(c)
        iz = pow(z, self.mod - 2, self.mod)
        for i in range(n + m - 1):
            c[i] = (c[i] * iz) % self.mod
        return FPS(c[:n + m - 1])

    def __imul__(self, other):
        self = self * other
        return self

    def __add__(self, other):
        """Coefficient-wise sum; the result has the longer operand's length."""
        res = [0 for i in range(max(len(self.Func), len(other.Func)))]
        for i, x in enumerate(self.Func):
            res[i] += x
            res[i] %= self.mod
        for i, x in enumerate(other.Func):
            res[i] += x
            res[i] %= self.mod
        return FPS(res)

    def __iadd__(self, other):
        self = (self + other)
        return self

    def __sub__(self, other):
        """Coefficient-wise difference; the result has the longer operand's length."""
        res = [0 for i in range(max(len(self.Func), len(other.Func)))]
        for i, x in enumerate(self.Func):
            res[i] += x
            res[i] %= self.mod
        for i, x in enumerate(other.Func):
            res[i] -= x
            res[i] %= self.mod
        return FPS(res)

    def __isub__(self, other):
        self = self - other
        return self

    def inv(self, d=-1):
        """Return 1 / self truncated to ``d`` coefficients (default: len(self)).

        Requires a non-zero constant term.  Doubles the number of correct
        coefficients per round (Newton iteration) using NTTs of size 2m.
        """
        n = len(self.Func)
        assert n != 0 and self.Func[0] != 0
        if d == -1: d = n
        assert d > 0
        res = [pow(self.Func[0], self.mod - 2, self.mod)]
        while (len(res) < d):
            m = len(res)
            f = [self.Func[i] for i in range(min(n, 2 * m))]
            r = res[:]

            if len(f) < 2 * m:
                f += [0] * (2 * m - len(f))
            elif len(f) > 2 * m:
                f = f[:2 * m]
            if len(r) < 2 * m:
                r += [0] * (2 * m - len(r))
            elif len(r) > 2 * m:
                r = r[:2 * m]
            f = self.butterfly(f)
            r = self.butterfly(r)
            for i in range(2 * m):
                f[i] *= r[i]
                f[i] %= self.mod
            f = self.butterfly_inv(f)
            f = f[m:]
            if len(f) < 2 * m:
                f += [0] * (2 * m - len(f))
            elif len(f) > 2 * m:
                f = f[:2 * m]
            f = self.butterfly(f)
            for i in range(2 * m):
                f[i] *= r[i]
                f[i] %= self.mod
            f = self.butterfly_inv(f)
            # -(1/2m)^2: both inverse transforms above skipped their 1/(2m)
            # scale, and the new block is the negated product.
            iz = pow(2 * m, self.mod - 2, self.mod)
            iz *= -iz
            iz %= self.mod
            for i in range(m):
                f[i] *= iz
                f[i] %= self.mod
            res += f[:m]
        return FPS(res[:d])

    def __truediv__(self, other):
        """Divide by an int scalar (modular inverse) or by an invertible FPS."""
        if isinstance(other, int):
            if other % self.mod == 0:
                raise ZeroDivisionError("FPS division by a multiple of mod")
            invother = pow(other, self.mod - 2, self.mod)
            ret = [(x * invother) % self.mod for x in self.Func]
            return FPS(ret)
        assert (other.Func[0] != 0)
        return self * (other.inv())

    def __itruediv__(self, other):
        self = self / other
        return self

    def __lshift__(self, d):
        """Multiply by x**d, keeping len(self) coefficients (top terms drop off).

        The receiver is left unchanged.
        """
        n = len(self.Func)
        return FPS(([0] * d + self.Func)[:n])

    def __ilshift__(self, d):
        self = self << d
        return self

    def __rshift__(self, d):
        """Drop the lowest d coefficients and zero-pad back to len(self).

        The receiver is left unchanged.
        """
        n = len(self.Func)
        shifted = self.Func[min(n, d):]
        return FPS(shifted + [0] * (n - len(shifted)))

    def __irshift__(self, d):
        self = self >> d
        return self

    def __str__(self):
        return f'FPS({self.Func})'

    # NOTE: equality and ordering compare only the NUMBER of coefficients,
    # not their values -- presumably so series can be ordered by size (e.g.
    # in a heap).  Do not use == to compare two series for equality.
    def __eq__(self,other):
        return len(self.Func)==len(other.Func)
    def __lt__(self, other):
        return len(self.Func)<len(other.Func)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __le__(self, other):
        return self.__lt__(other) or self.__eq__(other)

    def __gt__(self, other):
        return not self.__le__(other)

    def __ge__(self, other):
        return not self.__lt__(other)



    def diff(self):
        """Return the formal derivative (length len(self) - 1, min 0)."""
        n = len(self.Func)
        ret = [0 for i in range(max(0, n - 1))]
        for i in range(1, n):
            ret[i - 1] = (self.Func[i] * i) % self.mod
        return FPS(ret)

    def integral(self):
        """Return the formal antiderivative with constant term 0 (length len(self) + 1)."""
        n = len(self.Func)
        ret = [0 for i in range(n + 1)]
        for i in range(n):
            ret[i + 1] = self.Func[i] * pow(i + 1, self.mod - 2, self.mod) % self.mod
        return FPS(ret)

    def log(self, deg=-1):
        """Return log(self) truncated to ``deg`` coefficients (default: len(self)).

        Requires constant term 1.  Computed as the integral of self' / self,
        with every intermediate truncated so the result has exactly ``deg``
        coefficients.
        """
        assert self.Func and self.Func[0] == 1
        n = len(self.Func)
        if deg == -1: deg = n
        if deg <= 0:
            return FPS([])
        f = self.resize(deg)
        return (f.diff() * f.inv(deg)).resize(deg - 1).integral()

    def mod_sqrt(self, a):
        """Return a square root of ``a`` modulo ``mod``, or -1 if none exists.

        Uses Euler's criterion to reject non-residues, then Tonelli-Shanks.
        """
        p = self.mod
        assert 0 <= a and a < p
        if a < 2: return a
        if pow(a, (p - 1) // 2, p) != 1: return -1
        b = 1
        one = 1
        # Find a quadratic non-residue b.
        while (pow(b, (p - 1) >> 1, p) == 1):
            b += one
        m = p - 1
        e = 0
        while (m % 2 == 0):
            m >>= 1
            e += 1
        x = pow(a, (m - 1) >> 1, p)
        y = (a * x * x) % p
        x *= a
        x %= p
        z = pow(b, m, p)
        while (y != 1):
            j = 0
            t = y
            while (t != one):
                j += 1
                t *= t
                t %= p
            z = pow(z, 1 << (e - j - 1), p)
            x *= z
            x %= p
            z *= z
            z %= p
            y *= z
            y %= p
            e = j
        return x

    def sqrt(self, deg=-1):
        """Return a square root of self truncated to ``deg`` coefficients.

        Defaults to len(self) coefficients.  Returns ``FPS([])`` when no
        square root exists (odd lowest non-zero degree, or a leading
        coefficient that is not a quadratic residue).
        """
        n = len(self.Func)
        if deg == -1: deg = n
        if n == 0: return FPS([0 for i in range(deg)])
        if self.Func[0] == 0:
            for i in range(1, n):
                if self.Func[i] != 0:
                    if i & 1: return FPS([])
                    if deg - i // 2 <= 0: break
                    ret = (self >> i).sqrt(deg - i // 2)
                    if len(ret.Func) == 0: return FPS([])
                    # Prepend x**(i//2) explicitly: `ret << (i // 2)` keeps
                    # len(ret) and would drop its top i//2 coefficients.
                    return FPS([0] * (i // 2) + ret.Func).resize(deg)
            return FPS([0] * deg)
        sqr = self.mod_sqrt(self.Func[0])
        if sqr == -1: return FPS([])
        assert sqr * sqr % self.mod == self.Func[0]
        ret = FPS([sqr])
        inv2 = (self.mod + 1) // 2
        i = 1
        while (i < deg):
            # Newton step ret <- (ret + self / ret) / 2, kept to 2i terms.
            nxt = (ret + FPS(self.Func[:i << 1]) * ret.inv(i << 1)) * inv2
            ret = FPS(nxt.Func[:i << 1])
            i <<= 1
        return FPS(ret.Func[:deg])

    def resize(self, deg):
        """Return the series truncated or zero-padded to ``deg`` coefficients.

        Returns ``self`` itself (not a copy) when the length already matches.
        """
        if len(self.Func) < deg:
            return FPS(self.Func + [0] * (deg - len(self.Func)))
        elif len(self.Func) > deg:
            return FPS(self.Func[:deg])
        else:
            return self

    def exp(self, deg=-1):
        """Return exp(self) truncated to ``deg`` coefficients (default: len(self)).

        Requires a non-empty series with constant term 0.  The receiver is
        not modified; the work is done on a copy that is zero-padded or
        truncated to exactly ``deg`` coefficients.
        """
        n = len(self.Func)
        assert n > 0 and self.Func[0] == 0
        if deg == -1: deg = n
        assert deg >= 0
        if deg == 0:
            return FPS([])
        # F[:m] holds exp(self) mod x**m; F[m:] still holds the input
        # coefficients, which step 2.f reads before step 2.h overwrites them.
        F = self.Func[:deg] + [0] * (deg - n)
        F[0] = 1
        g = [1]
        g_fft = [1, 1]
        h_drv = FPS(F).diff()
        m = 2
        while (m < deg):
            f_fft = F[:m] + [0] * m
            self.butterfly(f_fft)

            # step 2.a
            _g = [f_fft[i] * g_fft[i] % self.mod for i in range(m)]
            self.butterfly_inv(_g)
            _g = _g[m // 2:m] + [0] * (m // 2)
            self.butterfly(_g)
            for i in range(m):
                _g[i] *= g_fft[i]
                _g[i] %= self.mod
            self.butterfly_inv(_g)
            tmp = pow(-m * m, self.mod - 2, self.mod)
            for i in range(m):
                _g[i] *= tmp
                _g[i] %= self.mod
            g += _g[:m // 2]
            # step 2.b--2.d
            t = FPS(F[:m]).diff()
            r = h_drv.Func[:m - 1] + [0]
            self.butterfly(r)
            for i in range(m):
                r[i] *= f_fft[i]
                r[i] %= self.mod
            self.butterfly_inv(r)
            tmp = pow(-m, self.mod - 2, self.mod)
            for i in range(m):
                r[i] *= tmp
                r[i] %= self.mod
            t = (t + FPS(r)).Func
            t = [t[-1]] + t
            t.pop()
            # step 2.e
            if (2 * m < deg):
                if len(t) < 2 * m:
                    t += [0] * (2 * m - len(t))
                elif len(t) > 2 * m:
                    t = t[:2 * m]
                self.butterfly(t)
                g_fft = g[:]
                if len(g_fft) < 2 * m:
                    g_fft += [0] * (2 * m - len(g_fft))
                elif len(g_fft) > 2 * m:
                    g_fft = g_fft[:2 * m]
                self.butterfly(g_fft)
                for i in range(2 * m):
                    t[i] *= g_fft[i]
                    t[i] %= self.mod
                self.butterfly_inv(t)
                tmp = pow(2 * m, self.mod - 2, self.mod)
                t = t[:m]
                for i in range(m):
                    t[i] *= tmp
                    t[i] %= self.mod
            else:
                g1 = g[m // 2:]
                s1 = t[m // 2:]
                t = t[:m // 2]
                g1 += [0] * (m - len(g1))
                s1 += [0] * (m - len(s1))
                t += [0] * (m - len(t))

                self.butterfly(g1)
                self.butterfly(t)
                self.butterfly(s1)
                for i in range(m):
                    s1[i] = (g_fft[i] * s1[i] + g1[i] * t[i]) % self.mod
                for i in range(m):
                    t[i] *= g_fft[i]
                    t[i] %= self.mod
                self.butterfly_inv(t)
                self.butterfly_inv(s1)
                for i in range(m // 2):
                    t[i + m // 2] += s1[i]
                    t[i + m // 2] %= self.mod
                tmp = pow(m, self.mod - 2, self.mod)
                for i in range(m):
                    t[i] *= tmp
                    t[i] %= self.mod
            # step 2.f
            v = F[m:min(deg, 2 * m)] + [0] * (2 * m - min(deg, 2 * m))
            t = [0] * (m - 1) + t
            t = FPS(t).integral().Func
            for i in range(m):
                v[i] -= t[m + i]
                v[i] %= self.mod
            # step 2.g
            if len(v) < 2 * m:
                v += [0] * (2 * m - len(v))
            else:
                v = v[:2 * m]
            self.butterfly(v)
            for i in range(2 * m):
                v[i] *= f_fft[i]
                v[i] %= self.mod
            self.butterfly_inv(v)
            v = v[:m]
            tmp = pow(2 * m, self.mod - 2, self.mod)
            for i in range(m):
                v[i] *= tmp
                v[i] %= self.mod
            # step 2.h
            for i in range(min(deg - m, m)):
                F[m + i] = v[i]
            m *= 2
        return FPS(F)

    def powfps(self, k, deg=-1):
        """Return self**k truncated to ``deg`` coefficients (default: len(self)).

        ``k`` is expected to be a non-negative integer; 0**0 is taken as 1.
        The lowest non-zero term c * x**l is factored out so the rest has
        constant term 1, and the result is c**k * x**(l*k) * exp(k * log(rest)).
        """
        a = self.Func
        n = len(a)
        if deg == -1: deg = n
        if deg <= 0:
            return FPS([])
        if k == 0:
            return FPS([1] + [0] * (deg - 1))
        l = 0
        while (l < n and not a[l]):
            l += 1
        if l == n or l * k >= deg:
            return FPS([0] * deg)
        ic = pow(a[l], self.mod - 2, self.mod)
        pc = pow(a[l], k, self.mod)
        # Only deg - l*k coefficients of the normalised rest are needed.
        rest = FPS([(x * ic) % self.mod for x in a[l:]]).resize(deg - l * k)
        res = (rest.log() * k).exp() * pc
        return FPS([0] * (l * k) + res.Func)


mod = 998244353

# ---------------------------------------------------------------------------
# Factorial tables modulo `mod`, covering arguments 0..cnb_max:
#   kai[i]  = i!        (mod `mod`)
#   rkai[i] = (i!)^-1   (mod `mod`)
# ---------------------------------------------------------------------------
cnb_max = 2 * 10**6 + 10

kai = [1] * (cnb_max + 1)
rkai = [1] * (cnb_max + 1)
for i in range(1, cnb_max + 1):
    kai[i] = kai[i - 1] * i % mod

# Invert only the largest factorial (Fermat), then walk downwards using
# 1/(j-1)! = j * (1/j!).
rkai[cnb_max] = pow(kai[cnb_max], mod - 2, mod)
for i in range(cnb_max):
    top = cnb_max - i
    rkai[top - 1] = rkai[top] * top % mod

def cnb(x,y):
    """Return the binomial coefficient C(x, y) modulo ``mod``.

    Arguments outside the valid range (y > x, x < 0 or y < 0) yield 0.
    ``x`` must not exceed ``cnb_max`` so the factorial tables cover it.
    """
    if y > x or x < 0 or y < 0:
        return 0
    numerator = kai[x]
    return numerator * rkai[y] % mod * rkai[x - y] % mod


def inv(n):
    """Return the multiplicative inverse of ``n`` modulo ``mod``.

    For 1 <= n <= cnb_max the value is read off the factorial tables as
    (n-1)! * (1/n!).  Any other n (n <= 0 or n > cnb_max) falls back to
    Fermat's little theorem.  A bare table lookup would silently return a
    wrong value for n <= 0, because negative list indices wrap around, and
    would raise IndexError past cnb_max.

    Raises:
        ZeroDivisionError: if ``n`` is a multiple of ``mod`` (no inverse).
    """
    if 1 <= n <= cnb_max:
        return kai[n - 1] * rkai[n] % mod
    if n % mod == 0:
        raise ZeroDivisionError("inv(n): n is divisible by mod and has no inverse")
    return pow(n, mod - 2, mod)

##################################


def sol2(n,A,p):
    """Return sum over all subsets S of A of (sum(S) mod p), modulo ``mod``.

    Assumes 0 <= A_i and that every subset sum is below 2p, i.e.
    sum(A) < 2p.  This is presumably guaranteed by the problem constraints
    and is not checked here.  Under that assumption,
        answer = sum(A) * 2**(n-1) - p * #{S : sum(S) >= p}.
    A subset has sum >= p exactly when its complement has sum <= k, where
    k = sum(A) - p.  The count is therefore sum_{i<=k} [x^i] prod_a (1 + x^a).
    That product is evaluated as exp(sum_a log(1 + x^a)), using
    log(1 + x^e) = sum_j (-1)^(j+1) x^(e*j) / j.

    Args:
        n: number of elements (len(A)).
        A: the elements.
        p: modulus applied to each subset sum.
    """
    total = sum(A)
    ans = total * pow(2, n - 1, mod) % mod
    k = total - p
    cnt = 0
    if k >= 0:
        num = [0] * (k + 1)
        for x in A:
            if x <= k:
                num[x] += 1
        # fx = sum over e >= 1 of num[e] * log(1 + x^e), truncated to degree k.
        fx = [0] * (k + 1)
        for e in range(1, k + 1):
            c = num[e]
            if c == 0:
                continue  # contributes nothing; skip the O(k/e) inner loop
            for j in range(1, k // e + 1):
                term = c * inv(j)
                if j % 2 == 1:
                    fx[e * j] = (fx[e * j] + term) % mod
                else:
                    fx[e * j] = (fx[e * j] - term) % mod
        coeffs = FPS(fx).exp().Func
        cnt = sum(coeffs[:k + 1]) % mod
        # Each zero-valued element contributes the constant factor
        # (1 + x^0) = 2, which cannot go through exp (it needs a zero
        # constant term), so apply it separately.
        cnt = cnt * pow(2, num[0], mod) % mod
    return (ans - p * cnt) % mod

# Read N and the array A, then print the answer for the subset-sum modulus.
p = 999630629
n = int(input())
a = [int(tok) for tok in input().split()]
print(sol2(n, a, p))


0