結果
問題 | No.206 数の積集合を求めるクエリ |
ユーザー |
|
提出日時 | 2025-02-09 14:45:08 |
言語 | Java (openjdk 23) |
結果 | AC |
実行時間 | 558 ms / 7,000 ms |
コード長 | 13,901 bytes |
コンパイル時間 | 3,116 ms |
コンパイル使用メモリ | 90,260 KB |
実行使用メモリ | 64,272 KB |
最終ジャッジ日時 | 2025-02-09 14:45:32 |
合計ジャッジ時間 | 13,499 ms |
ジャッジサーバーID (参考情報) | judge2 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 28 |
ソースコード
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.util.Arrays;
import java.util.Locale;
import java.util.NoSuchElementException;
import java.util.function.IntFunction;
import java.util.function.LongFunction;

/**
 * yukicoder No.206: for every shift q in [0, Q), counts the elements x of A for which x - q is an
 * element of B, using one number-theoretic-transform convolution modulo 998244353.
 */
public class Main {
    public static void main(String[] args) {
        Main o = new Main();
        o.solve();
    }

    static final long MOD = 998244353;

    /**
     * Reads "L M N", the L values of A, the M values of B and then Q from standard input, and
     * prints the Q answers one per line.
     *
     * <p>Sets a[x] = 1 for x in A and b[N - y] = 1 for y in B, so coefficient N + q of a * b
     * counts the pairs (x, y) with x - y = q. Each such count is at most N + 1, far below MOD, so
     * the modular coefficient equals the exact count.
     */
    public void solve() {
        FastScanner sc = new FastScanner(System.in);
        int L = sc.nextInt();
        int M = sc.nextInt();
        int N = sc.nextInt();
        long[] a = new long[N + 1];
        long[] b = new long[N + 1];
        for (int i = 0; i < L; i++) {
            a[sc.nextInt()] = 1;
        }
        for (int i = 0; i < M; i++) {
            b[N - sc.nextInt()] = 1;
        }
        int Q = sc.nextInt();
        Convolution con = new Convolution();
        long[] c = con.convolution(a, b, MOD);
        long[] ans = new long[Q];
        for (int q = 0; q < Q; q++) {
            // Every index of a or b is at most N, so no pair reaches an index past c's end.
            ans[q] = N + q < c.length ? c[N + q] : 0;
        }
        print(ans, LF);
    }

    /**
     * Java port of the AtCoder Library convolution (NTT) routines.
     *
     * <p>{@code mod} must be a prime for which {@code mod - 1} is divisible by the padded
     * transform length. Input coefficients may be arbitrary longs; they are reduced modulo
     * {@code mod} before use and the caller's arrays are never modified.
     */
    class Convolution {
        /** Roots of unity and per-stage rotation factors for one modulus. */
        class FftInfo {
            long mod;
            int g;
            int rank2;
            long[] root;   // root[i]^(2^i) == 1
            long[] iroot;  // root[i] * iroot[i] == 1
            long[] rate2;
            long[] irate2;
            long[] rate3;
            long[] irate3;

            FftInfo(long mod) {
                this.mod = mod;
                this.g = primitive_root((int) mod);
                this.rank2 = Long.numberOfTrailingZeros(mod - 1);
                this.root = new long[rank2 + 1];
                this.iroot = new long[rank2 + 1];
                this.rate2 = new long[Math.max(0, rank2 - 2 + 1)];
                this.irate2 = new long[Math.max(0, rank2 - 2 + 1)];
                this.rate3 = new long[Math.max(0, rank2 - 3 + 1)];
                this.irate3 = new long[Math.max(0, rank2 - 3 + 1)];

                root[rank2] = mod_pow(g % mod, (mod - 1) >> rank2, mod);
                iroot[rank2] = mod_inv(root[rank2], mod);
                for (int i = rank2 - 1; i >= 0; i--) {
                    root[i] = (root[i + 1] * root[i + 1]) % mod;
                    iroot[i] = (iroot[i + 1] * iroot[i + 1]) % mod;
                }
                {
                    long prod = 1;
                    long iprod = 1;
                    for (int i = 0; i <= rank2 - 2; i++) {
                        rate2[i] = (root[i + 2] * prod) % mod;
                        irate2[i] = (iroot[i + 2] * iprod) % mod;
                        prod = (prod * iroot[i + 2]) % mod;
                        iprod = (iprod * root[i + 2]) % mod;
                    }
                }
                {
                    long prod = 1;
                    long iprod = 1;
                    for (int i = 0; i <= rank2 - 3; i++) {
                        rate3[i] = (root[i + 3] * prod) % mod;
                        irate3[i] = (iroot[i + 3] * iprod) % mod;
                        prod = (prod * iroot[i + 3]) % mod;
                        iprod = (iprod * root[i + 3]) % mod;
                    }
                }
            }
        }

        /** Tables for the most recently used modulus, shared by butterfly and butterfly_inv. */
        private FftInfo cachedInfo;

        /** Returns the tables for {@code mod}, rebuilding them only when the modulus changes. */
        private FftInfo fftInfo(long mod) {
            if (cachedInfo == null || cachedInfo.mod != mod) {
                cachedInfo = new FftInfo(mod);
            }
            return cachedInfo;
        }

        /**
         * Returns the inverse of {@code a} modulo {@code m} via the extended Euclidean algorithm.
         * Assumes gcd(a, m) == 1.
         */
        long mod_inv(long a, long m) {
            long x0 = 1;
            long y0 = 0;
            long x1 = 0;
            long y1 = 1;
            long b = m;
            while (b != 0) {
                long q = a / b;
                long tmp = b;
                b = a % b;
                a = tmp;
                tmp = x1;
                x1 = x0 - q * x1;
                x0 = tmp;
                tmp = y1;
                y1 = y0 - q * y1;
                y0 = tmp;
            }
            return (x0 + m) % m;
        }

        /** Returns a^n mod m by binary exponentiation; expects 0 <= a < m. */
        long mod_pow(long a, long n, long m) {
            long ret = 1L;
            while (n > 0) {
                if ((n & 1L) != 0) {
                    ret = (ret * a) % m;
                }
                a = (a * a) % m;
                n >>= 1;
            }
            return ret;
        }

        /** Returns the smallest power of two that is >= n (1 for n <= 1). */
        int bit_ceil(int n) {
            int x = 1;
            while (x < n) {
                x *= 2;
            }
            return x;
        }

        /** Returns x mod m in [0, m) for positive m, also for negative x. */
        long safe_mod(long x, long m) {
            x %= m;
            if (x < 0) {
                x += m;
            }
            return x;
        }

        /** Returns {g, x} with g = gcd(a, b) and a * x == g (mod b), 0 <= x < b / g. */
        long[] inv_gcd(long a, long b) {
            a = safe_mod(a, b);
            if (a == 0) {
                return new long[] {b, 0};
            }
            long s = b;
            long t = a;
            long m0 = 0;
            long m1 = 1;
            while (t != 0) {
                long u = s / t;
                s -= t * u;
                m0 -= m1 * u;
                long tmp = s;
                s = t;
                t = tmp;
                tmp = m0;
                m0 = m1;
                m1 = tmp;
            }
            if (m0 < 0) {
                m0 += b / s;
            }
            return new long[] {s, m0};
        }

        /**
         * Returns a primitive root of {@code m} (expected to be prime). Common NTT primes are
         * answered from a table; otherwise the prime factors of m - 1 are tested against g = 2, 3, ...
         */
        int primitive_root(int m) {
            if (m == 2) return 1;
            if (m == 167772161) return 3;
            if (m == 469762049) return 3;
            if (m == 754974721) return 11;
            if (m == 998244353) return 3;
            int[] divs = new int[20];
            divs[0] = 2;
            int cnt = 1;
            int x = (m - 1) / 2;
            while (x % 2 == 0) {
                x /= 2;
            }
            for (int i = 3; ((long) i) * i <= x; i += 2) {
                if (x % i == 0) {
                    divs[cnt++] = i;
                    while (x % i == 0) {
                        x /= i;
                    }
                }
            }
            if (x > 1) {
                divs[cnt++] = x;
            }
            for (int g = 2;; g++) {
                boolean ok = true;
                for (int i = 0; i < cnt; i++) {
                    if (mod_pow(g, (m - 1) / divs[i], m) == 1) {
                        ok = false;
                        break;
                    }
                }
                if (ok) {
                    return g;
                }
            }
        }

        /**
         * In-place forward NTT of {@code a} (length a power of two, entries in [0, mod)), using
         * radix-4 stages and a final radix-2 stage when log2(length) is odd.
         */
        void butterfly(long[] a, long mod) {
            int n = a.length;
            int h = Integer.numberOfTrailingZeros(n);
            FftInfo info = fftInfo(mod);
            long mod2 = mod * mod;
            int len = 0;
            while (len < h) {
                if (h - len == 1) {
                    int p = 1 << (h - len - 1);
                    long rot = 1;
                    for (int s = 0; s < (1 << len); s++) {
                        int offset = s << (h - len);
                        for (int i = 0; i < p; i++) {
                            long l = a[i + offset];
                            long r = (a[i + offset + p] * rot) % mod;
                            a[i + offset] = (l + r) % mod;
                            a[i + offset + p] = safe_mod(l - r, mod);
                        }
                        if (s + 1 != (1 << len)) {
                            rot = (rot * info.rate2[Integer.numberOfTrailingZeros(~s)]) % mod;
                        }
                    }
                    len++;
                } else {
                    // 4-base: products of residues stay below mod^2, so sums of a few fit a long.
                    int p = 1 << (h - len - 2);
                    long rot = 1;
                    long imag = info.root[2];
                    for (int s = 0; s < (1 << len); s++) {
                        long rot2 = (rot * rot) % mod;
                        long rot3 = (rot2 * rot) % mod;
                        int offset = s << (h - len);
                        for (int i = 0; i < p; i++) {
                            long a0 = a[i + offset];
                            long a1 = a[i + offset + p] * rot;
                            long a2 = a[i + offset + 2 * p] * rot2;
                            long a3 = a[i + offset + 3 * p] * rot3;
                            long a1na3imag = safe_mod(a1 + mod2 - a3, mod) * imag;
                            long na2 = mod2 - a2;
                            a[i + offset] = (a0 + a2 + a1 + a3) % mod;
                            a[i + offset + 1 * p] = safe_mod(a0 + a2 + (2 * mod2 - (a1 + a3)), mod);
                            a[i + offset + 2 * p] = (a0 + na2 + a1na3imag) % mod;
                            a[i + offset + 3 * p] = (a0 + na2 + (mod2 - a1na3imag)) % mod;
                        }
                        if (s + 1 != (1 << len)) {
                            rot = (rot * info.rate3[Integer.numberOfTrailingZeros(~s)]) % mod;
                        }
                    }
                    len += 2;
                }
            }
        }

        /**
         * In-place inverse NTT of {@code a} without the final 1/length scaling (the caller
         * multiplies by the inverse of the length).
         */
        void butterfly_inv(long[] a, long mod) {
            int n = a.length;
            int h = Integer.numberOfTrailingZeros(n);
            FftInfo info = fftInfo(mod);
            int len = h;
            while (len != 0) {
                if (len == 1) {
                    int p = 1 << (h - len);
                    long irot = 1;
                    for (int s = 0; s < (1 << (len - 1)); s++) {
                        int offset = s << (h - len + 1);
                        for (int i = 0; i < p; i++) {
                            long l = a[i + offset];
                            long r = a[i + offset + p];
                            a[i + offset] = (l + r) % mod;
                            a[i + offset + p] = (safe_mod(mod + l - r, mod) * irot) % mod;
                        }
                        if (s + 1 != (1 << (len - 1))) {
                            irot = (irot * info.irate2[Integer.numberOfTrailingZeros(~s)]) % mod;
                        }
                    }
                    len--;
                } else {
                    // 4-base
                    int p = 1 << (h - len);
                    long irot = 1;
                    long iimag = info.iroot[2];
                    for (int s = 0; s < (1 << (len - 2)); s++) {
                        long irot2 = (irot * irot) % mod;
                        long irot3 = (irot2 * irot) % mod;
                        int offset = s << (h - len + 2);
                        for (int i = 0; i < p; i++) {
                            long a0 = a[i + offset + 0 * p];
                            long a1 = a[i + offset + 1 * p];
                            long a2 = a[i + offset + 2 * p];
                            long a3 = a[i + offset + 3 * p];
                            long a2na3iimag = (safe_mod(mod + a2 - a3, mod) * iimag) % mod;
                            a[i + offset] = (a0 + a1 + a2 + a3) % mod;
                            a[i + offset + 1 * p] = (((a0 + (mod - a1) + a2na3iimag) % mod) * irot) % mod;
                            a[i + offset + 2 * p] =
                                    (((a0 + a1 + (mod - a2) + (mod - a3)) % mod) * irot2) % mod;
                            a[i + offset + 3 * p] =
                                    (((a0 + (mod - a1) + (mod - a2na3iimag)) % mod) * irot3) % mod;
                        }
                        if (s + 1 != (1 << (len - 2))) {
                            irot = (irot * info.irate3[Integer.numberOfTrailingZeros(~s)]) % mod;
                        }
                    }
                    len -= 2;
                }
            }
        }

        /** Schoolbook O(n*m) convolution; expects entries already reduced into [0, mod). */
        long[] convolution_naive(long[] a, long[] b, long mod) {
            int n = a.length;
            int m = b.length;
            long[] ans = new long[n + m - 1];
            if (n < m) {
                for (int j = 0; j < m; j++) {
                    for (int i = 0; i < n; i++) {
                        ans[i + j] = (ans[i + j] + a[i] * b[j]) % mod;
                    }
                }
            } else {
                for (int i = 0; i < n; i++) {
                    for (int j = 0; j < m; j++) {
                        ans[i + j] = (ans[i + j] + a[i] * b[j]) % mod;
                    }
                }
            }
            return ans;
        }

        /** NTT-based convolution; expects entries already reduced into [0, mod). */
        long[] convolution_fft(long[] a, long[] b, long mod) {
            int n = a.length;
            int m = b.length;
            int z = bit_ceil(n + m - 1);
            a = Arrays.copyOf(a, z);
            butterfly(a, mod);
            b = Arrays.copyOf(b, z);
            butterfly(b, mod);
            for (int i = 0; i < z; i++) {
                a[i] = (a[i] * b[i]) % mod;
            }
            butterfly_inv(a, mod);
            a = Arrays.copyOf(a, n + m - 1);
            long iz = mod_inv(z % mod, mod);
            for (int i = 0; i < n + m - 1; i++) {
                a[i] = (a[i] * iz) % mod;
            }
            return a;
        }

        /** Returns a copy of {@code v} with every entry reduced into [0, mod). */
        private long[] reduce(long[] v, long mod) {
            long[] r = new long[v.length];
            for (int i = 0; i < v.length; i++) {
                r[i] = safe_mod(v[i], mod);
            }
            return r;
        }

        /**
         * Returns the convolution of {@code a} and {@code b} modulo {@code mod}, of length
         * a.length + b.length - 1 (empty if either input is empty).
         *
         * @throws IllegalArgumentException if mod - 1 is not divisible by the padded length
         */
        public long[] convolution(long[] a, long[] b, long mod) {
            int n = a.length;
            int m = b.length;
            if (n == 0 || m == 0) return new long[0];
            int z = bit_ceil(n + m - 1);
            if ((mod - 1) % z != 0) {
                throw new IllegalArgumentException(
                        "modulus " + mod + " does not support transform length " + z);
            }
            // The kernels multiply raw entries and assume residues in [0, mod); negative or
            // huge inputs would otherwise overflow or leave negative coefficients.
            long[] ra = reduce(a, mod);
            long[] rb = reduce(b, mod);
            if (Math.min(n, m) <= 60) return convolution_naive(ra, rb, mod);
            return convolution_fft(ra, rb, mod);
        }

        /**
         * Returns the convolution of {@code a} and {@code b} over the integers as longs, by
         * convolving modulo three NTT primes and combining the residues with CRT wrapped to 64
         * bits. NOTE(review): like the AtCoder Library original, results are only exact when every
         * true coefficient fits in a long.
         *
         * @throws IllegalArgumentException if a.length + b.length - 1 exceeds 2^24
         */
        public long[] convolution_ll(long[] a, long[] b) {
            int n = a.length;
            int m = b.length;
            if (n == 0 || m == 0) return new long[0];
            final long MOD1 = 754974721; // 2^24 divides MOD1 - 1
            final long MOD2 = 167772161; // 2^25 divides MOD2 - 1
            final long MOD3 = 469762049; // 2^26 divides MOD3 - 1
            final long M2M3 = MOD2 * MOD3;
            final long M1M3 = MOD1 * MOD3;
            final long M1M2 = MOD1 * MOD2;
            final long M1M2M3 = MOD1 * MOD2 * MOD3; // intentionally wraps modulo 2^64
            final long i1 = inv_gcd(MOD2 * MOD3, MOD1)[1];
            final long i2 = inv_gcd(MOD1 * MOD3, MOD2)[1];
            final long i3 = inv_gcd(MOD1 * MOD2, MOD3)[1];
            final int MAX_AB_BIT = 24;
            if (n + m - 1 > (1 << MAX_AB_BIT)) {
                throw new IllegalArgumentException(
                        "result length " + (n + m - 1) + " exceeds 2^" + MAX_AB_BIT);
            }
            long[] c1 = convolution(a, b, MOD1);
            long[] c2 = convolution(a, b, MOD2);
            long[] c3 = convolution(a, b, MOD3);
            // Correction selected by (diff mod 5); see the AtCoder Library derivation.
            final long[] offset = new long[] {0, 0, M1M2M3, 2 * M1M2M3, 3 * M1M2M3};
            long[] c = new long[n + m - 1];
            for (int i = 0; i < n + m - 1; i++) {
                long x = 0;
                x += ((c1[i] * i1) % MOD1) * M2M3;
                x += ((c2[i] * i2) % MOD2) * M1M3;
                x += ((c3[i] * i3) % MOD3) * M1M2;
                long diff = c1[i] - safe_mod(x, MOD1);
                if (diff < 0) {
                    diff += MOD1;
                }
                x -= offset[(int) safe_mod(diff, 5)];
                c[i] = x;
            }
            return c;
        }
    }

    static final char LF = '\n';
    static final char SPACE = ' ';
    static final String YES = "Yes";
    static final String NO = "No";

    /** Prints all of {@code array} on one line, separated by {@code sep}. */
    void print(int[] array, char sep) {
        print(array, sep, n -> n, 0, array.length);
    }

    /** Prints all of {@code array} mapped through {@code conv}, separated by {@code sep}. */
    void print(int[] array, char sep, IntFunction<Integer> conv) {
        print(array, sep, conv, 0, array.length);
    }

    /**
     * Prints array[start..end) mapped through {@code conv}, separated by {@code sep}, followed by
     * a line terminator. An empty range prints an empty line.
     */
    void print(int[] array, char sep, IntFunction<Integer> conv, int start, int end) {
        StringBuilder ans = new StringBuilder();
        for (int i = start; i < end; i++) {
            if (i > start) {
                ans.append(sep);
            }
            ans.append(conv.apply(array[i]));
        }
        System.out.println(ans.toString());
    }

    /** Prints all of {@code array} on one line, separated by {@code sep}. */
    void print(long[] array, char sep) {
        print(array, sep, n -> n, 0, array.length);
    }

    /** Prints all of {@code array} mapped through {@code conv}, separated by {@code sep}. */
    void print(long[] array, char sep, LongFunction<Long> conv) {
        print(array, sep, conv, 0, array.length);
    }

    /**
     * Prints array[start..end) mapped through {@code conv}, separated by {@code sep}, followed by
     * a line terminator. An empty range prints an empty line.
     */
    void print(long[] array, char sep, LongFunction<Long> conv, int start, int end) {
        StringBuilder ans = new StringBuilder();
        for (int i = start; i < end; i++) {
            if (i > start) {
                ans.append(sep);
            }
            ans.append(conv.apply(array[i]));
        }
        System.out.println(ans.toString());
    }

    void printYesNo(boolean yesno) {
        System.out.println(yesno ? YES : NO);
    }

    /** Prints {@code val} with {@code digit} fractional digits, always using '.' as separator. */
    void printDouble(double val, int digit) {
        System.out.println(String.format(Locale.ROOT, "%." + digit + "f", val));
    }

    /** Whitespace-separated token reader over a byte stream with a 1 KiB buffer. */
    class FastScanner {
        private final InputStream in;
        private final byte[] buf = new byte[1024];
        private int ptr = 0;
        private int buflen = 0;

        FastScanner(InputStream source) {
            this.in = source;
        }

        /** Refills the buffer when exhausted; returns false at end of stream. */
        private boolean hasNextByte() {
            if (ptr < buflen) {
                return true;
            }
            ptr = 0;
            try {
                buflen = in.read(buf);
            } catch (IOException e) {
                // Propagate instead of silently re-reading stale bytes left in the buffer.
                throw new UncheckedIOException(e);
            }
            return buflen > 0;
        }

        private int readByte() {
            return hasNextByte() ? buf[ptr++] : -1;
        }

        private boolean isPrintableChar(int c) {
            return 33 <= c && c <= 126;
        }

        private boolean isNumeric(int c) {
            return '0' <= c && c <= '9';
        }

        private void skipToNextPrintableChar() {
            while (hasNextByte() && !isPrintableChar(buf[ptr])) {
                ptr++;
            }
        }

        public boolean hasNext() {
            skipToNextPrintableChar();
            return hasNextByte();
        }

        public String next() {
            if (!hasNext()) throw new NoSuchElementException();
            StringBuilder ret = new StringBuilder();
            int b = readByte();
            while (isPrintableChar(b)) {
                ret.appendCodePoint(b);
                b = readByte();
            }
            return ret.toString();
        }

        /**
         * Parses the next token as a (possibly negative) decimal long.
         *
         * @throws NumberFormatException if the token contains a non-digit character
         */
        public long nextLong() {
            if (!hasNext()) throw new NoSuchElementException();
            long ret = 0;
            int b = readByte();
            boolean negative = false;
            if (b == '-') {
                negative = true;
                if (hasNextByte()) {
                    b = readByte();
                }
            }
            if (!isNumeric(b)) throw new NumberFormatException();
            while (true) {
                if (isNumeric(b)) {
                    ret = ret * 10 + b - '0';
                } else if (b == -1 || !isPrintableChar(b)) {
                    return negative ? -ret : ret;
                } else {
                    throw new NumberFormatException();
                }
                b = readByte();
            }
        }

        public int nextInt() {
            return (int) nextLong();
        }

        public double nextDouble() {
            return Double.parseDouble(next());
        }
    }
}