結果

| 問題 | No.206 数の積集合を求めるクエリ |
|---|---|
| ユーザー | |
| 提出日時 | 2025-02-09 14:45:08 |
| 言語 | Java (openjdk 23) |
| 結果 | AC |
| 実行時間 | 558 ms / 7,000 ms |
| コード長 | 13,901 bytes |
| コンパイル時間 | 3,116 ms |
| コンパイル使用メモリ | 90,260 KB |
| 実行使用メモリ | 64,272 KB |
| 最終ジャッジ日時 | 2025-02-09 14:45:32 |
| 合計ジャッジ時間 | 13,499 ms |
| ジャッジサーバーID (参考情報) | judge2 / judge4 |

(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 28 |
ソースコード
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.util.Arrays;
import java.util.Locale;
import java.util.NoSuchElementException;
import java.util.function.IntFunction;
import java.util.function.LongFunction;

/**
 * Solution for yukicoder No.206.
 *
 * <p>For each shift {@code q} in {@code [0, Q)} it prints the number of pairs {@code (A[i], B[j])}
 * with {@code A[i] - B[j] == q}. The counts come from a single number-theoretic-transform (NTT)
 * convolution modulo {@link #MOD}: the indicator vector of A is convolved with the reversed
 * indicator vector of B.
 */
public class Main {
  public static void main(String[] args) {
    Main o = new Main();
    o.solve();
  }

  static final long MOD = 998244353;

  /**
   * Reads {@code L M N}, then L values of A, then M values of B, then {@code Q}, and prints one
   * count per line for every shift {@code q = 0 .. Q-1}.
   *
   * <p>Every value is used directly as an array index, so it is assumed to lie in {@code [0, N]}.
   * A value repeated inside A (or B) is recorded only once.
   */
  public void solve() {
    FastScanner sc = new FastScanner(System.in);
    int L = sc.nextInt();
    int M = sc.nextInt();
    int N = sc.nextInt();
    // a[v] = 1 iff v occurs in A; b[N - v] = 1 iff v occurs in B. Reversing B makes
    // convolution index N + q collect exactly the pairs with A[i] - B[j] == q.
    long[] a = new long[N + 1];
    long[] b = new long[N + 1];
    for (int i = 0; i < L; i++) {
      a[sc.nextInt()] = 1;
    }
    for (int i = 0; i < M; i++) {
      b[N - sc.nextInt()] = 1;
    }
    int Q = sc.nextInt();
    long[] c = new Convolution().convolution(a, b, MOD);
    long[] ans = new long[Q];
    for (int q = 0; q < Q; q++) {
      // c has 2N + 1 entries; with all values in [0, N] no difference exceeds N,
      // so any shift past the end of c has count 0 rather than being an index error.
      int idx = N + q;
      ans[q] = idx < c.length ? c[idx] : 0;
    }
    print(ans, LF);
  }

  /**
   * NTT-based convolution helpers. The layout (rate2/rate3 tables, radix-4 butterflies,
   * three-prime CRT in {@link #convolution_ll}) mirrors AtCoder Library's convolution.
   */
  class Convolution {
    /** Precomputed roots of unity and per-stage rotation factors for one NTT-friendly modulus. */
    class FftInfo {
      long mod;
      int g;
      int rank2;
      long[] root;
      long[] iroot;
      long[] rate2;
      long[] irate2;
      long[] rate3;
      long[] irate3;

      FftInfo(long mod) {
        this.mod = mod;
        this.g = primitive_root((int) mod);
        // Largest power of two dividing mod - 1: the longest transform length supported.
        this.rank2 = Long.numberOfTrailingZeros(mod - 1);
        this.root = new long[rank2 + 1];
        this.iroot = new long[rank2 + 1];
        this.rate2 = new long[Math.max(0, rank2 - 2 + 1)];
        this.irate2 = new long[Math.max(0, rank2 - 2 + 1)];
        this.rate3 = new long[Math.max(0, rank2 - 3 + 1)];
        this.irate3 = new long[Math.max(0, rank2 - 3 + 1)];
        // root[i] is a primitive 2^i-th root of unity; iroot[i] is its inverse.
        root[rank2] = mod_pow(g % mod, (mod - 1) >> rank2, mod);
        iroot[rank2] = mod_inv(root[rank2], mod);
        for (int i = rank2 - 1; i >= 0; i--) {
          root[i] = (root[i + 1] * root[i + 1]) % mod;
          iroot[i] = (iroot[i + 1] * iroot[i + 1]) % mod;
        }
        long prod = 1;
        long iprod = 1;
        for (int i = 0; i <= rank2 - 2; i++) {
          rate2[i] = (root[i + 2] * prod) % mod;
          irate2[i] = (iroot[i + 2] * iprod) % mod;
          prod = (prod * iroot[i + 2]) % mod;
          iprod = (iprod * root[i + 2]) % mod;
        }
        prod = 1;
        iprod = 1;
        for (int i = 0; i <= rank2 - 3; i++) {
          rate3[i] = (root[i + 3] * prod) % mod;
          irate3[i] = (iroot[i + 3] * iprod) % mod;
          prod = (prod * iroot[i + 3]) % mod;
          iprod = (iprod * root[i + 3]) % mod;
        }
      }
    }

    /** Most recently built tables; reused while consecutive transforms share a modulus. */
    private FftInfo cachedInfo;

    /** Returns the FFT tables for {@code mod}, rebuilding them only when the modulus changes. */
    private FftInfo fftInfo(long mod) {
      if (cachedInfo == null || cachedInfo.mod != mod) {
        cachedInfo = new FftInfo(mod);
      }
      return cachedInfo;
    }

    /**
     * Modular inverse of {@code a} modulo {@code m} by the extended Euclidean algorithm.
     * Assumes {@code gcd(a, m) == 1}; otherwise the result is not an inverse.
     */
    long mod_inv(long a, long m) {
      long x0 = 1;
      long y0 = 0;
      long x1 = 0;
      long y1 = 1;
      long b = m;
      while (b != 0) {
        long q = a / b;
        long tmp = b;
        b = a % b;
        a = tmp;
        tmp = x1;
        x1 = x0 - q * x1;
        x0 = tmp;
        tmp = y1;
        y1 = y0 - q * y1;
        y0 = tmp;
      }
      return (x0 + m) % m;
    }

    /** Computes {@code a^n mod m} by binary exponentiation. */
    long mod_pow(long a, long n, long m) {
      long ret = 1L;
      while (n > 0) {
        if ((n & 1L) != 0) {
          ret = (ret * a) % m;
        }
        a = (a * a) % m;
        n >>= 1;
      }
      return ret;
    }

    /** Smallest power of two that is {@code >= n} (1 for {@code n <= 1}). */
    int bit_ceil(int n) {
      int x = 1;
      while (x < n) {
        x *= 2;
      }
      return x;
    }

    /** Reduces {@code x} into {@code [0, m)} for positive {@code m}, also for negative {@code x}. */
    long safe_mod(long x, long m) {
      x %= m;
      if (x < 0) {
        x += m;
      }
      return x;
    }

    /**
     * Returns {@code {g, x}} with {@code g = gcd(a, b)} and {@code a * x ≡ g (mod b)},
     * {@code 0 <= x < b / g}.
     */
    long[] inv_gcd(long a, long b) {
      a = safe_mod(a, b);
      if (a == 0) {
        return new long[] {b, 0};
      }
      long s = b;
      long t = a;
      long m0 = 0;
      long m1 = 1;
      while (t != 0) {
        long u = s / t;
        s -= t * u;
        m0 -= m1 * u;
        long tmp = s;
        s = t;
        t = tmp;
        tmp = m0;
        m0 = m1;
        m1 = tmp;
      }
      if (m0 < 0) {
        m0 += b / s;
      }
      return new long[] {s, m0};
    }

    /**
     * Returns a primitive root modulo the prime {@code m}. Known NTT primes are answered from a
     * table; otherwise the smallest generator is found by testing against the prime factors of
     * {@code m - 1}.
     */
    int primitive_root(int m) {
      if (m == 2) return 1;
      if (m == 167772161) return 3;
      if (m == 469762049) return 3;
      if (m == 754974721) return 11;
      if (m == 998244353) return 3;
      int[] divs = new int[20];
      divs[0] = 2;
      int cnt = 1;
      int x = (m - 1) / 2;
      while (x % 2 == 0) {
        x /= 2;
      }
      for (int i = 3; ((long) i) * i <= x; i += 2) {
        if (x % i == 0) {
          divs[cnt++] = i;
          while (x % i == 0) {
            x /= i;
          }
        }
      }
      if (x > 1) {
        divs[cnt++] = x;
      }
      for (int g = 2; ; g++) {
        boolean ok = true;
        for (int i = 0; i < cnt; i++) {
          if (mod_pow(g, (m - 1) / divs[i], m) == 1) {
            ok = false;
            break;
          }
        }
        if (ok) {
          return g;
        }
      }
    }

    /**
     * In-place forward NTT of {@code a} (length must be a power of two). Elements are assumed to
     * be in {@code [0, mod)}: the radix-4 step relies on products of two such values plus small
     * multiples of {@code mod * mod} staying below {@code Long.MAX_VALUE}.
     */
    void butterfly(long[] a, long mod) {
      int n = a.length;
      int h = Integer.numberOfTrailingZeros(n);
      FftInfo info = fftInfo(mod);
      int len = 0;
      while (len < h) {
        if (h - len == 1) {
          // Radix-2 stage: only used for the last stage when h is odd.
          int p = 1 << (h - len - 1);
          long rot = 1;
          for (int s = 0; s < (1 << len); s++) {
            int offset = s << (h - len);
            for (int i = 0; i < p; i++) {
              long l = a[i + offset];
              long r = (a[i + offset + p] * rot) % mod;
              a[i + offset] = (l + r) % mod;
              a[i + offset + p] = safe_mod(l - r, mod);
            }
            if (s + 1 != (1 << len)) {
              rot = (rot * info.rate2[Integer.numberOfTrailingZeros(~s)]) % mod;
            }
          }
          len++;
        } else {
          // Radix-4 stage.
          int p = 1 << (h - len - 2);
          long rot = 1;
          long imag = info.root[2];
          long mod2 = mod * mod;
          for (int s = 0; s < (1 << len); s++) {
            long rot2 = (rot * rot) % mod;
            long rot3 = (rot2 * rot) % mod;
            int offset = s << (h - len);
            for (int i = 0; i < p; i++) {
              long a0 = a[i + offset];
              long a1 = a[i + offset + p] * rot;
              long a2 = a[i + offset + 2 * p] * rot2;
              long a3 = a[i + offset + 3 * p] * rot3;
              long a1na3imag = safe_mod(a1 + mod2 - a3, mod) * imag;
              long na2 = mod2 - a2;
              a[i + offset] = (a0 + a2 + a1 + a3) % mod;
              a[i + offset + p] = safe_mod(a0 + a2 + (2 * mod2 - (a1 + a3)), mod);
              a[i + offset + 2 * p] = (a0 + na2 + a1na3imag) % mod;
              a[i + offset + 3 * p] = (a0 + na2 + (mod2 - a1na3imag)) % mod;
            }
            if (s + 1 != (1 << len)) {
              rot = (rot * info.rate3[Integer.numberOfTrailingZeros(~s)]) % mod;
            }
          }
          len += 2;
        }
      }
    }

    /**
     * In-place inverse NTT of {@code a} without the final division by the length (the caller
     * multiplies by {@code 1 / length}). Elements are assumed to be in {@code [0, mod)}.
     */
    void butterfly_inv(long[] a, long mod) {
      int n = a.length;
      int h = Integer.numberOfTrailingZeros(n);
      FftInfo info = fftInfo(mod);
      int len = h;
      while (len != 0) {
        if (len == 1) {
          // Radix-2 stage.
          int p = 1 << (h - len);
          long irot = 1;
          for (int s = 0; s < (1 << (len - 1)); s++) {
            int offset = s << (h - len + 1);
            for (int i = 0; i < p; i++) {
              long l = a[i + offset];
              long r = a[i + offset + p];
              a[i + offset] = (l + r) % mod;
              a[i + offset + p] = (safe_mod(mod + l - r, mod) * irot) % mod;
            }
            if (s + 1 != (1 << (len - 1))) {
              irot = (irot * info.irate2[Integer.numberOfTrailingZeros(~s)]) % mod;
            }
          }
          len--;
        } else {
          // Radix-4 stage.
          int p = 1 << (h - len);
          long irot = 1;
          long iimag = info.iroot[2];
          for (int s = 0; s < (1 << (len - 2)); s++) {
            long irot2 = (irot * irot) % mod;
            long irot3 = (irot2 * irot) % mod;
            int offset = s << (h - len + 2);
            for (int i = 0; i < p; i++) {
              long a0 = a[i + offset];
              long a1 = a[i + offset + p];
              long a2 = a[i + offset + 2 * p];
              long a3 = a[i + offset + 3 * p];
              long a2na3iimag = (safe_mod(mod + a2 - a3, mod) * iimag) % mod;
              a[i + offset] = (a0 + a1 + a2 + a3) % mod;
              a[i + offset + p] = (((a0 + (mod - a1) + a2na3iimag) % mod) * irot) % mod;
              a[i + offset + 2 * p] =
                  (((a0 + a1 + (mod - a2) + (mod - a3)) % mod) * irot2) % mod;
              a[i + offset + 3 * p] =
                  (((a0 + (mod - a1) + (mod - a2na3iimag)) % mod) * irot3) % mod;
            }
            if (s + 1 != (1 << (len - 2))) {
              irot = (irot * info.irate3[Integer.numberOfTrailingZeros(~s)]) % mod;
            }
          }
          len -= 2;
        }
      }
    }

    /**
     * Schoolbook convolution modulo {@code mod}; both inputs must be non-empty and reduced into
     * {@code [0, mod)}. Iterates the shorter array in the inner loop.
     */
    long[] convolution_naive(long[] a, long[] b, long mod) {
      int n = a.length;
      int m = b.length;
      long[] ans = new long[n + m - 1];
      if (n < m) {
        for (int j = 0; j < m; j++) {
          for (int i = 0; i < n; i++) {
            ans[i + j] = (ans[i + j] + a[i] * b[j]) % mod;
          }
        }
      } else {
        for (int i = 0; i < n; i++) {
          for (int j = 0; j < m; j++) {
            ans[i + j] = (ans[i + j] + a[i] * b[j]) % mod;
          }
        }
      }
      return ans;
    }

    /**
     * NTT convolution modulo {@code mod}; inputs must be non-empty and reduced into
     * {@code [0, mod)}. The caller's arrays are not modified (padded copies are transformed).
     */
    long[] convolution_fft(long[] a, long[] b, long mod) {
      int n = a.length;
      int m = b.length;
      int z = bit_ceil(n + m - 1);
      a = Arrays.copyOf(a, z);
      butterfly(a, mod);
      b = Arrays.copyOf(b, z);
      butterfly(b, mod);
      for (int i = 0; i < z; i++) {
        a[i] = (a[i] * b[i]) % mod;
      }
      butterfly_inv(a, mod);
      a = Arrays.copyOf(a, n + m - 1);
      long iz = mod_inv(z % mod, mod);
      for (int i = 0; i < n + m - 1; i++) {
        a[i] = (a[i] * iz) % mod;
      }
      return a;
    }

    /** Returns a copy of {@code v} with every element reduced into {@code [0, mod)}. */
    private long[] reduced(long[] v, long mod) {
      long[] r = new long[v.length];
      for (int i = 0; i < v.length; i++) {
        r[i] = safe_mod(v[i], mod);
      }
      return r;
    }

    /**
     * Convolution of {@code a} and {@code b} modulo the NTT-friendly prime {@code mod}.
     *
     * <p>Inputs may hold any {@code long} values (including negatives); they are reduced into
     * {@code [0, mod)} first, which keeps every intermediate product inside 64 bits. Returns an
     * empty array when either input is empty.
     *
     * @throws IllegalArgumentException if the padded length does not divide {@code mod - 1}
     */
    public long[] convolution(long[] a, long[] b, long mod) {
      int n = a.length;
      int m = b.length;
      if (n == 0 || m == 0) {
        return new long[0];
      }
      int z = bit_ceil(n + m - 1);
      if ((mod - 1) % z != 0) {
        throw new IllegalArgumentException(
            "transform length " + z + " does not divide mod - 1 = " + (mod - 1));
      }
      long[] ra = reduced(a, mod);
      long[] rb = reduced(b, mod);
      if (Math.min(n, m) <= 60) {
        return convolution_naive(ra, rb, mod);
      }
      return convolution_fft(ra, rb, mod);
    }

    /**
     * Exact (non-modular) convolution of arbitrary {@code long} inputs, reconstructed by CRT from
     * three NTT primes. Every output value must fit in a signed 64-bit {@code long}.
     *
     * @throws IllegalArgumentException if {@code a.length + b.length - 1 > 2^24}
     */
    public long[] convolution_ll(long[] a, long[] b) {
      int n = a.length;
      int m = b.length;
      if (n == 0 || m == 0) {
        return new long[0];
      }
      final long MOD1 = 754974721; // 2^24 | MOD1 - 1
      final long MOD2 = 167772161; // 2^25 | MOD2 - 1
      final long MOD3 = 469762049; // 2^26 | MOD3 - 1
      final long M2M3 = MOD2 * MOD3;
      final long M1M3 = MOD1 * MOD3;
      final long M1M2 = MOD1 * MOD2;
      // Wraps modulo 2^64 on purpose; the offset correction below works in that ring.
      final long M1M2M3 = MOD1 * MOD2 * MOD3;
      final long i1 = inv_gcd(MOD2 * MOD3, MOD1)[1];
      final long i2 = inv_gcd(MOD1 * MOD3, MOD2)[1];
      final long i3 = inv_gcd(MOD1 * MOD2, MOD3)[1];
      final int MAX_AB_BIT = 24;
      if (n + m - 1 > (1 << MAX_AB_BIT)) {
        throw new IllegalArgumentException(
            "result length " + (n + m - 1) + " exceeds 2^" + MAX_AB_BIT);
      }
      long[] c1 = convolution(a, b, MOD1);
      long[] c2 = convolution(a, b, MOD2);
      long[] c3 = convolution(a, b, MOD3);
      // Hoisted out of the loop: the table is constant.
      final long[] offset = new long[] {0, 0, M1M2M3, 2 * M1M2M3, 3 * M1M2M3};
      long[] c = new long[n + m - 1];
      for (int i = 0; i < n + m - 1; i++) {
        long x = 0;
        x += ((c1[i] * i1) % MOD1) * M2M3;
        x += ((c2[i] * i2) % MOD2) * M1M3;
        x += ((c3[i] * i3) % MOD3) * M1M2;
        // x equals the true value plus k * M1M2M3 (mod 2^64) for some k in 0..3;
        // the residue mismatch modulo MOD1, taken mod 5, identifies k.
        long diff = c1[i] - safe_mod(x, MOD1);
        if (diff < 0) {
          diff += MOD1;
        }
        x -= offset[(int) safe_mod(diff, 5)];
        c[i] = x;
      }
      return c;
    }
  }

  static final char LF = '\n';
  static final char SPACE = ' ';
  static final String YES = "Yes";
  static final String NO = "No";

  void print(int[] array, char sep) {
    print(array, sep, n -> n, 0, array.length);
  }

  void print(int[] array, char sep, IntFunction<Integer> conv) {
    print(array, sep, conv, 0, array.length);
  }

  /**
   * Prints {@code conv(array[start..end))} joined by {@code sep} on one line. An empty range
   * prints an empty line (previously it threw StringIndexOutOfBoundsException).
   */
  void print(int[] array, char sep, IntFunction<Integer> conv, int start, int end) {
    StringBuilder ans = new StringBuilder();
    for (int i = start; i < end; i++) {
      if (i > start) {
        ans.append(sep);
      }
      ans.append(conv.apply(array[i]));
    }
    System.out.println(ans.toString());
  }

  void print(long[] array, char sep) {
    print(array, sep, n -> n, 0, array.length);
  }

  void print(long[] array, char sep, LongFunction<Long> conv) {
    print(array, sep, conv, 0, array.length);
  }

  /**
   * Prints {@code conv(array[start..end))} joined by {@code sep} on one line. An empty range
   * prints an empty line (previously it threw StringIndexOutOfBoundsException).
   */
  void print(long[] array, char sep, LongFunction<Long> conv, int start, int end) {
    StringBuilder ans = new StringBuilder();
    for (int i = start; i < end; i++) {
      if (i > start) {
        ans.append(sep);
      }
      ans.append(conv.apply(array[i]));
    }
    System.out.println(ans.toString());
  }

  void printYesNo(boolean yesno) {
    System.out.println(yesno ? YES : NO);
  }

  /** Prints {@code val} with {@code digit} fractional digits, always using '.' as the separator. */
  void printDouble(double val, int digit) {
    System.out.println(String.format(Locale.ROOT, "%." + digit + "f", val));
  }

  /** Whitespace-separated token reader over a raw InputStream with a 1 KiB buffer. */
  class FastScanner {
    private final InputStream in;
    private final byte[] buf = new byte[1024];
    private int ptr = 0;
    private int buflen = 0;

    FastScanner(InputStream source) {
      this.in = source;
    }

    /**
     * Returns true if an unread byte is buffered, refilling the buffer when it is exhausted.
     *
     * @throws UncheckedIOException if the underlying stream fails; continuing would otherwise
     *     re-serve the stale buffer contents as new input
     */
    private boolean hasNextByte() {
      if (ptr < buflen) {
        return true;
      }
      ptr = 0;
      try {
        buflen = in.read(buf);
      } catch (IOException e) {
        throw new UncheckedIOException(e);
      }
      return buflen > 0;
    }

    /** Returns the next byte (sign-extended), or -1 at end of input. */
    private int readByte() {
      if (hasNextByte()) {
        return buf[ptr++];
      }
      return -1;
    }

    private boolean isPrintableChar(int c) {
      return 33 <= c && c <= 126;
    }

    private boolean isNumeric(int c) {
      return '0' <= c && c <= '9';
    }

    private void skipToNextPrintableChar() {
      while (hasNextByte() && !isPrintableChar(buf[ptr])) {
        ptr++;
      }
    }

    public boolean hasNext() {
      skipToNextPrintableChar();
      return hasNextByte();
    }

    public String next() {
      if (!hasNext()) {
        throw new NoSuchElementException();
      }
      StringBuilder ret = new StringBuilder();
      int b = readByte();
      while (isPrintableChar(b)) {
        ret.appendCodePoint(b);
        b = readByte();
      }
      return ret.toString();
    }

    /** Parses the next token as a (possibly negative) decimal long; overflow is not checked. */
    public long nextLong() {
      if (!hasNext()) {
        throw new NoSuchElementException();
      }
      long ret = 0;
      int b = readByte();
      boolean negative = false;
      if (b == '-') {
        negative = true;
        if (hasNextByte()) {
          b = readByte();
        }
      }
      if (!isNumeric(b)) {
        throw new NumberFormatException();
      }
      while (true) {
        if (isNumeric(b)) {
          ret = ret * 10 + b - '0';
        } else if (b == -1 || !isPrintableChar(b)) {
          return negative ? -ret : ret;
        } else {
          throw new NumberFormatException();
        }
        b = readByte();
      }
    }

    public int nextInt() {
      return (int) nextLong();
    }

    public double nextDouble() {
      return Double.parseDouble(next());
    }
  }
}