結果

問題 No.754 畳み込みの和
ユーザー 37zigen37zigen
提出日時 2018-03-13 02:30:33
言語 Java21
(openjdk 21)
結果
AC  
実行時間 2,048 ms / 5,000 ms
コード長 3,183 bytes
コンパイル時間 1,978 ms
コンパイル使用メモリ 79,544 KB
実行使用メモリ 68,228 KB
最終ジャッジ日時 2024-04-28 07:14:08
合計ジャッジ時間 10,505 ms
ジャッジサーバーID
(参考情報)
judge3 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 2,046 ms
67,108 KB
testcase_01 AC 2,048 ms
68,064 KB
testcase_02 AC 2,006 ms
68,228 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import java.util.Arrays;
import java.util.Scanner;

/**
 * Reads {@code n}, then coefficients {@code a[0..n]} and {@code b[0..n]}, and prints
 * the sum of the first {@code n + 1} coefficients of the product polynomial
 * {@code a * b}, reduced modulo {@link #MOD}.
 *
 * <p>The product is computed with a number-theoretic transform under three different
 * primes; each coefficient is then recombined with Garner's algorithm so that it can
 * be reduced modulo {@link #MOD}, which is not itself NTT-friendly.
 *
 * <p>All modular helpers multiply two residues in a {@code long}, so every modulus
 * passed to them must be below 2^31.5 (about 3.03e9) for the products to stay in range.
 */
public class Main {
	public static void main(String[] args) {
		new Main().run();
	}

	/** Modulus in which the final answer is reported. */
	static final long MOD = 1_000_000_007L;

	/** Primes of the form k * 2^e + 1 under which the transforms are carried out. */
	private static final long[] NTT_PRIMES = { 1012924417L, 998244353L, 163577857L };

	/** {@code NTT_ROOTS[i]} is the base from which roots of unity modulo {@code NTT_PRIMES[i]} are derived. */
	private static final long[] NTT_ROOTS = { 5, 3, 23 };

	/**
	 * Reads the input from standard input, computes the answer and prints it to
	 * standard output.
	 */
	void run() {
		// try-with-resources so the Scanner is released even if parsing fails.
		try (Scanner sc = new Scanner(System.in)) {
			int n = sc.nextInt();
			long[] a = new long[n + 1];
			long[] b = new long[n + 1];
			for (int i = 0; i < a.length; ++i) {
				a[i] = sc.nextLong();
			}
			for (int i = 0; i < b.length; ++i) {
				b[i] = sc.nextLong();
			}

			int primes = NTT_PRIMES.length;
			long[][] c = new long[primes][];
			for (int p = 0; p < primes; ++p) {
				c[p] = mul(a, b, NTT_PRIMES[p], NTT_ROOTS[p]);
			}

			// The inverses used by Garner's algorithm depend only on the moduli,
			// so compute them once instead of once per coefficient.
			long[] gamma = mixedRadixInverses(NTT_PRIMES);
			long[] residues = new long[primes];
			long ans = 0;
			for (int i = 0; i <= n; ++i) {
				for (int p = 0; p < primes; ++p) {
					residues[p] = c[p][i];
				}
				ans = (ans + garnerWithInverses(residues, NTT_PRIMES, gamma)) % MOD;
			}
			System.out.println(ans);
		}
	}

	/**
	 * Solves the simultaneous congruences {@code X = x[i] (mod m[i])} with Garner's
	 * algorithm and returns the smallest non-negative solution reduced modulo {@link #MOD}.
	 *
	 * @param x residues; values outside {@code [0, m[i])}, including negative ones, are accepted
	 * @param m pairwise coprime positive moduli, one per residue
	 * @return the solution in {@code [0, m[0] * ... * m[k-1])}, reduced modulo {@link #MOD};
	 *         0 when both arrays are empty
	 * @throws IllegalArgumentException if the arrays differ in length
	 * @throws ArithmeticException if the moduli are not pairwise coprime
	 */
	long garner(long[] x, long[] m) {
		if (x.length != m.length) {
			throw new IllegalArgumentException(
					"residue/modulus count mismatch: " + x.length + " vs " + m.length);
		}
		return garnerWithInverses(x, m, mixedRadixInverses(m));
	}

	/** Returns {@code gamma} with {@code gamma[i] = (m[0] * ... * m[i-1])^-1 mod m[i]}. */
	private long[] mixedRadixInverses(long[] m) {
		long[] gamma = new long[m.length];
		for (int i = 0; i < m.length; i++) {
			long prod = 1;
			for (int j = 0; j < i; j++) {
				prod = prod * m[j] % m[i];
			}
			gamma[i] = inv(prod, m[i]);
		}
		return gamma;
	}

	/** Garner recombination using inverses precomputed by {@link #mixedRadixInverses}. */
	private long garnerWithInverses(long[] x, long[] m, long[] gamma) {
		int n = x.length;
		// v holds the mixed-radix digits: X = v[0] + v[1]*m[0] + v[2]*m[0]*m[1] + ...
		long[] v = new long[n];
		for (int i = 0; i < n; i++) {
			// Value of the digits found so far, evaluated modulo m[i] by Horner's rule.
			long partial = 0;
			for (int j = i - 1; j >= 0; j--) {
				partial = (partial * m[j] + v[j]) % m[i];
			}
			// floorMod keeps negative residues from leaking into the digits.
			long digit = (Math.floorMod(x[i], m[i]) - partial) * gamma[i] % m[i];
			v[i] = digit < 0 ? digit + m[i] : digit;
		}
		long ret = 0;
		for (int i = n - 1; i >= 0; i--) {
			ret = (ret * m[i] + v[i]) % MOD;
		}
		return ret;
	}

	/**
	 * Computes the modular inverse of {@code a} with the extended Euclidean algorithm.
	 *
	 * @param a   value to invert; any {@code long}, it is reduced modulo {@code mod} first
	 * @param mod positive modulus
	 * @return the inverse of {@code a} in {@code [0, mod)}
	 * @throws ArithmeticException if {@code mod <= 0} or {@code a} is not invertible modulo {@code mod}
	 */
	public static long inv(long a, long mod) {
		if (mod <= 0) {
			throw new ArithmeticException("modulus must be positive: " + mod);
		}
		// Without this reduction a negative a would end the loop early with a wrong result.
		long r0 = Math.floorMod(a, mod);
		long r1 = mod;
		long s0 = 1;
		long s1 = 0;
		// Invariant: s0 * a = r0 (mod mod) and s1 * a = r1 (mod mod).
		while (r1 > 0) {
			long q = r0 / r1;
			long r2 = r0 % r1;
			r0 = r1;
			r1 = r2;
			long s2 = s0 - q * s1;
			s0 = s1;
			s1 = s2;
		}
		if (r0 != 1) {
			throw new ArithmeticException(a + " has no inverse modulo " + mod);
		}
		return s0 < 0 ? s0 + mod : s0;
	}

	/**
	 * Multiplies two polynomials modulo {@code MODULO} with a number-theoretic transform.
	 * The argument arrays are not modified.
	 *
	 * @param a      coefficients of the first polynomial, lowest degree first
	 * @param b      coefficients of the second polynomial, lowest degree first
	 * @param MODULO prime modulus; the transform length must divide {@code MODULO - 1}
	 * @param root   primitive root of {@code MODULO}
	 * @return product coefficients in {@code [0, MODULO)}, zero-padded to the transform
	 *         length {@code Integer.highestOneBit(a.length + b.length) << 1}
	 * @throws IllegalArgumentException if the transform length does not divide {@code MODULO - 1}
	 */
	long[] mul(long[] a, long[] b, long MODULO, long root) {
		int n = Integer.highestOneBit(a.length + b.length) << 1;
		if (n == 0) {
			// Both inputs are empty; there is nothing to transform.
			return new long[0];
		}
		// Copy into zero-padded buffers, normalising every coefficient into [0, MODULO).
		long[] fa = new long[n];
		long[] fb = new long[n];
		for (int i = 0; i < a.length; ++i) {
			fa[i] = Math.floorMod(a[i], MODULO);
		}
		for (int i = 0; i < b.length; ++i) {
			fb[i] = Math.floorMod(b[i], MODULO);
		}
		fft(fa, false, MODULO, root);
		fft(fb, false, MODULO, root);
		for (int i = 0; i < n; ++i) {
			fa[i] = fa[i] * fb[i] % MODULO;
		}
		fft(fa, true, MODULO, root);
		// The inverse transform leaves every value scaled by n.
		long invN = inv(n, MODULO);
		for (int i = 0; i < n; ++i) {
			fa[i] = fa[i] * invN % MODULO;
		}
		return fa;
	}

	/**
	 * In-place iterative number-theoretic transform of {@code a}.
	 *
	 * @param a      values to transform; the length must be a power of two
	 * @param inv    {@code true} for the inverse transform (the result is NOT divided by the length)
	 * @param MODULO prime modulus whose {@code MODULO - 1} is divisible by {@code a.length}
	 * @param root   primitive root of {@code MODULO}
	 * @return the same array {@code a}, now holding the transform
	 * @throws IllegalArgumentException if the length is not a power of two or does not
	 *                                  divide {@code MODULO - 1}
	 */
	long[] fft(long[] a, boolean inv, long MODULO, long root) {
		int n = a.length;
		if ((n & (n - 1)) != 0) {
			throw new IllegalArgumentException("transform length must be a power of two: " + n);
		}
		if (n > 1 && (MODULO - 1) % n != 0) {
			throw new IllegalArgumentException(
					"transform length " + n + " does not divide " + (MODULO - 1));
		}
		// Bit-reversal permutation: rev is kept equal to the bit-reversed value of i.
		for (int i = 1, rev = 0; i < n; ++i) {
			int bit = n >> 1;
			while ((rev & bit) != 0) {
				rev ^= bit;
				bit >>= 1;
			}
			rev ^= bit;
			if (rev > i) {
				long t = a[i];
				a[i] = a[rev];
				a[rev] = t;
			}
		}
		// Butterfly passes over blocks of size 2 * half.
		for (int half = 1; half < n; half *= 2) {
			// Primitive (2 * half)-th root of unity, or its inverse for the inverse transform.
			long w = pow(root, (MODULO - 1) / (2 * half), MODULO);
			if (inv) {
				w = inv(w, MODULO);
			}
			for (int start = 0; start < n; start += 2 * half) {
				long wn = 1;
				for (int k = 0; k < half; ++k) {
					long u = a[start + k];
					long v = a[start + k + half] * wn % MODULO;
					a[start + k] = (u + v) % MODULO;
					a[start + k + half] = (u - v + MODULO) % MODULO;
					wn = wn * w % MODULO;
				}
			}
		}
		return a;
	}

	/**
	 * Computes {@code a^n mod MODULO} by binary exponentiation.
	 *
	 * @param a      base; any {@code long}, it is reduced modulo {@code MODULO} first
	 * @param n      exponent; for {@code n <= 0} the result is 1
	 * @param MODULO positive modulus
	 * @return the power, in {@code [0, MODULO)} whenever {@code n > 0}
	 */
	long pow(long a, long n, long MODULO) {
		// Reduce the base so the first squaring cannot overflow for large or negative a.
		long base = Math.floorMod(a, MODULO);
		long ret = 1;
		while (n > 0) {
			if ((n & 1) == 1) {
				ret = ret * base % MODULO;
			}
			base = base * base % MODULO;
			n >>= 1;
		}
		return ret;
	}

	/** Debug helper: prints the arguments with {@link Arrays#deepToString(Object[])}. */
	static void tr(Object... objects) {
		System.out.println(Arrays.deepToString(objects));
	}
}
0