結果

問題 No.2313 Product of Subsequence (hard)
ユーザー 遭難者遭難者
提出日時 2023-05-24 10:41:15
言語 Java21
(openjdk 21)
結果
TLE  
実行時間 -
コード長 25,235 bytes
コンパイル時間 3,622 ms
コンパイル使用メモリ 87,168 KB
実行使用メモリ 69,848 KB
最終ジャッジ日時 2023-08-25 13:04:11
合計ジャッジ時間 10,925 ms
ジャッジサーバーID
(参考情報)
judge13 / judge12
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 38 ms
51,352 KB
testcase_01 AC 37 ms
49,704 KB
testcase_02 AC 41 ms
51,340 KB
testcase_03 AC 47 ms
51,116 KB
testcase_04 AC 51 ms
52,280 KB
testcase_05 AC 61 ms
53,584 KB
testcase_06 AC 45 ms
51,404 KB
testcase_07 AC 52 ms
52,380 KB
testcase_08 AC 461 ms
61,616 KB
testcase_09 AC 163 ms
57,340 KB
testcase_10 AC 312 ms
65,232 KB
testcase_11 AC 165 ms
57,556 KB
testcase_12 AC 286 ms
65,516 KB
testcase_13 TLE -
testcase_14 -- -
testcase_15 -- -
testcase_16 -- -
testcase_17 -- -
testcase_18 -- -
testcase_19 -- -
testcase_20 -- -
testcase_21 -- -
testcase_22 -- -
testcase_23 -- -
testcase_24 -- -
testcase_25 -- -
testcase_26 -- -
testcase_27 -- -
testcase_28 -- -
testcase_29 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #

import java.util.*;
import java.io.*;

class Main {

    /**
     * Counts the non-empty subsequences of {@code input[0..n)} whose product is divisible by
     * {@code k}, modulo {@link #mod}.
     *
     * <p>An element only matters through {@code gcd(a_i, k)}, so elements are bucketed by that
     * divisor of {@code k}. A DP over the sorted divisors {@code t[]} tracks
     * {@code gcd(product, k)}. Bucket {@code x} (with {@code c[x]} elements) is folded in by
     * choosing {@code j} of its elements, which moves state {@code s} to
     * {@code gcd(s * t[x]^j, k)}. That state only grows in divisibility, so after at most about
     * log2(k) steps it reaches a fixed point. All remaining choices of {@code j} are then added
     * at once through the suffix sums of binomial coefficients.
     *
     * @param n     number of elements to read from {@code input}
     * @param k     required divisor
     * @param input the elements
     * @return the number of qualifying subsequences modulo 998244353
     */
    private static final long solve(int n, int k, long[] input) {
        if (k == 1) {
            // Every non-empty subsequence qualifies.
            long ans = mp(2, n) - 1;
            if ((ans %= mod) < 0)
                ans += mod;
            return ans;
        }
        // Sorted divisors of k: t[0] == 1 and t[m - 1] == k.
        long[] t = p(k).stream().mapToLong(Long::longValue).sorted().toArray();
        int m = t.length;
        // c[x] = number of elements whose gcd with k equals t[x].
        int[] c = new int[m];
        for (int i = 0; i < n; i++)
            c[Arrays.binarySearch(t, g(input[i], k))]++;
        fa = new long[n + 1];
        fai = new long[n + 1];
        fa[0] = 1L;
        for (int i = 1; i <= n; i++)
            fa[i] = i * fa[i - 1] % mod;
        fai[n] = mp(fa[n], mod - 2);
        for (int i = n; i > 0; i--)
            fai[i - 1] = i * fai[i] % mod;
        long[] d = new long[m];
        d[0] = 1L; // the empty selection has product 1
        int[] step = new int[m];
        for (int x = 1; x < m; x++) {
            final int len = c[x];
            if (len == 0)
                continue; // choosing 0 of 0 elements leaves every state unchanged
            // step[i] = index of gcd(t[i] * t[x], k), i.e. the state after one more element.
            // t[i] * t[x] <= k^2 < 2^62 because k is an int, so the product cannot overflow.
            for (int i = 0; i < m; i++)
                step[i] = Arrays.binarySearch(t, g(t[i] * t[x], k));
            // cc1[j] = C(len, j); cc2[j] = sum of C(len, j') for j' >= j.
            long[] cc1 = new long[len + 1];
            long[] cc2 = new long[len + 1];
            for (int i = 0; i <= len; i++)
                cc1[i] = cc2[i] = comb(len, i);
            for (int i = len; i > 0; i--)
                if ((cc2[i - 1] += cc2[i]) >= mod)
                    cc2[i - 1] -= mod;
            long[] dd = new long[m];
            for (int i = 0; i < m; i++) {
                if (d[i] == 0)
                    continue;
                int cur = i;
                int prev = -1;
                int j = 0;
                // Walk the states reached with j = 0, 1, ... chosen elements until the state
                // stops changing (a fixed point of step[]) or all len elements are used.
                while (j <= len && cur != prev) {
                    dd[cur] = (dd[cur] + d[i] * cc1[j]) % mod;
                    prev = cur;
                    cur = step[cur];
                    j++;
                }
                // Fixed point reached early: every choice of j..len elements lands on cur.
                if (j <= len)
                    dd[cur] = (dd[cur] + d[i] * cc2[j]) % mod;
            }
            d = dd;
        }
        // Elements coprime to k (bucket 0) can be freely included or excluded.
        return d[m - 1] * mp(2, c[0]) % mod;
    }

    /**
     * Entry point body. When {@code submit} is true it reads {@code n k} and {@code n} values,
     * then prints {@link #solve}. Otherwise it stress-tests {@link #solve} against the brute
     * force {@link #gu} on random small inputs and prints the first mismatch.
     */
    private static final void start() throws IOException {
        boolean submit = true;
        if (submit) {
            int t = 1;
            // t = ni();
            for (int i = 0; i < t; i++) {
                int n = ni(), k = ni();
                long[] a = nl(n);
                long ans = solve(n, k, a);
                ou.println(ans);
            }
        } else {
            var rand = new Random();
            for (int jud = 1; jud <= 10000; jud++) {
                int n = rand.nextInt(10) + 3;
                int k = rand.nextInt(20) + 2;
                var a = new long[n];
                for (int i = 0; i < n; i++)
                    a[i] = rand.nextInt(20) + 1;
                long ans1 = solve(n, k, a);
                long ans2 = gu(n, k, a);
                if (ans1 != ans2) {
                    ou.println(ans1 + " " + ans2);
                    ou.println(n + " " + k);
                    ou.print(a);
                    break;
                }
            }
        }
    }

    private static final int mod = 998244353;
    /** Factorials and inverse factorials modulo {@link #mod}; filled by {@link #solve}. */
    private static long[] fa, fai;

    /** Binomial coefficient C(n, k) mod {@link #mod}; requires {@code fa}/{@code fai} up to n. */
    private static final long comb(int n, int k) {
        return fai[k] * fai[n - k] % mod * fa[n] % mod;
    }

    /**
     * Brute-force reference: enumerates all 2^n - 1 non-empty subsets and counts those whose
     * product is divisible by {@code k}. The count is not reduced modulo {@link #mod}.
     */
    private static final long gu(int n, long k, long[] input) {
        var a = new long[n];
        for (int i = 0; i < n; i++)
            a[i] = input[i] % k;
        long ans = 0;
        int nn = 1 << n;
        for (int i = 1; i < nn; i++) {
            long mul = 1L;
            for (int j = 0; j < n; j++) {
                if (((i >> j) & 1) == 1)
                    mul = mul * a[j] % k;
            }
            if (mul == 0)
                ans++;
        }
        return ans;
    }

    /** Computes {@code a^b mod mod} by recursive binary exponentiation; b must be >= 0. */
    private static final long mp(long a, int b) {
        if (b == 0)
            return 1L;
        if ((b & 1) == 0)
            return mp(a * a % mod, b >> 1);
        return mp(a, b - 1) * a % mod;
    }

    /** Returns all positive divisors of {@code n}, unsorted, by trial division up to sqrt(n). */
    private static final ArrayList<Long> p(long n) {
        var t = new ArrayList<Long>();
        long i = 1;
        while (i * i < n) {
            if (n % i == 0) {
                t.add(i);
                t.add(n / i);
            }
            i++;
        }
        if (i * i == n)
            t.add(i);
        return t;
    }

    /** Greatest common divisor by the iterative Euclidean algorithm; g(x, 0) == x. */
    private static final long g(long x, long y) {
        while (y != 0) {
            long r = x % y;
            x = y;
            y = r;
        }
        return x;
    }

    public static void main(String[] args) throws IOException {
        start();
        ou.flush();
    }

    private static final int ni() throws IOException {
        return sc.nextInt();
    }

    private static final int[] ni(int n) throws IOException {
        return sc.nextIntArray(n);
    }

    private static final long nl() throws IOException {
        return sc.nextLong();
    }

    private static final long[] nl(int n) throws IOException {
        return sc.nextLongArray(n);
    }

    private static final String ns() throws IOException {
        return sc.next();
    }

    private static final char[] nc() throws IOException {
        return sc.nextCharArray();
    }

    private static final double nd() throws IOException {
        return sc.nextDouble();
    }

    private static final ContestInputStream sc = new ContestInputStream();
    private static final ContestOutputStream ou = new ContestOutputStream();
}

/**
 * Buffered reader for whitespace-separated ASCII tokens on {@code System.in}.
 *
 * <p>Numbers are parsed directly from the byte buffer. A token ends at the first byte whose
 * signed value is {@code <= 0x20}; because Java bytes are signed, that includes every byte
 * {@code >= 0x80}. A CR immediately followed by LF is consumed as a single terminator. Token
 * text is collected in a fixed {@code 1 << 20} character scratch buffer.
 */
final class ContestInputStream extends FilterInputStream {
    protected final byte[] buf;
    protected int pos = 0;
    protected int lim = 0;
    private final char[] cbuf;

    public ContestInputStream() {
        super(System.in);
        this.buf = new byte[1 << 13];
        this.cbuf = new char[1 << 20];
    }

    /** Ensures an unread byte is buffered (refilling if needed); false at end of input. */
    boolean hasRemaining() throws IOException {
        if (pos < lim)
            return true;
        lim = in.read(buf);
        pos = 0;
        return lim > 0;
    }

    /** Returns the count of buffered unread bytes, refilling first if empty; {@code <= 0} at EOF. */
    final int remaining() throws IOException {
        if (pos >= lim) {
            lim = in.read(buf);
            pos = 0;
        }
        return lim - pos;
    }

    /** Skips bytes {@code <= 0x20}; returns true if a non-blank byte is now at {@code buf[pos]}. */
    private boolean skipBlanks() throws IOException {
        while (hasRemaining()) {
            if (buf[pos] > 0x20)
                return true;
            pos++;
        }
        return false;
    }

    @Override
    public final int available() throws IOException {
        if (pos < lim)
            return lim - pos + in.available();
        return in.available();
    }

    /** Skips up to {@code n} bytes, preferring already-buffered bytes; non-positive n skips nothing. */
    @Override
    public final long skip(long n) throws IOException {
        if (n <= 0)
            return 0L; // a negative n previously moved pos backwards
        if (pos < lim) {
            int rem = lim - pos;
            if (n < rem) {
                pos += n;
                return n;
            }
            pos = lim;
            return rem;
        }
        return in.skip(n);
    }

    /** Returns the next byte as an unsigned value 0-255, or -1 at end of input. */
    @Override
    public final int read() throws IOException {
        if (hasRemaining())
            return buf[pos++] & 0xff; // InputStream contract: never return a negative byte
        return -1;
    }

    @Override
    public final int read(byte[] b, int off, int len) throws IOException {
        Objects.checkFromIndexSize(off, len, b.length);
        if (pos < lim) {
            int rem = Math.min(lim - pos, len);
            System.arraycopy(buf, pos, b, off, rem);
            pos += rem;
            return rem;
        }
        return in.read(b, off, len);
    }

    /** Reads the next token into a freshly allocated array (empty at end of input). */
    public final char[] readToken() throws IOException {
        return Arrays.copyOf(cbuf, readToken(cbuf, 0));
    }

    /**
     * Reads bytes up to the next blank into {@code cbuf[off..]} and consumes the terminator,
     * including the LF of a CRLF pair. Leading blanks are not skipped.
     *
     * @return the number of characters stored
     */
    public final int readToken(char[] cbuf, int off) throws IOException {
        int cpos = off;
        int rem;
        byte b;
        l: while ((rem = remaining()) > 0) {
            for (int i = 0; i < rem; i++) {
                b = buf[pos + i];
                if (b <= 0x20) {
                    pos += i + 1;
                    cpos += i;
                    if (b == 0x0d && hasRemaining() && buf[pos] == 0x0a)
                        pos++;
                    break l;
                }
                cbuf[cpos + i] = (char) b;
            }
            pos += rem;
            cpos += rem;
        }
        return cpos - off;
    }

    public final int readToken(char[] cbuf) throws IOException {
        return readToken(cbuf, 0);
    }

    /** Reads the next token as a String (empty at end of input). */
    public final String next() throws IOException {
        return String.valueOf(cbuf, 0, readToken(cbuf, 0));
    }

    public final char[] nextCharArray() throws IOException {
        return readToken();
    }

    /** Parses the next optionally negative decimal int after any blanks; 0 at end of input. */
    public final int nextInt() throws IOException {
        if (!skipBlanks())
            return 0;
        int value = 0;
        byte b = buf[pos++];
        if (b == 0x2d) {
            // Accumulate negatively so the minimum value parses without overflow.
            while (hasRemaining() && (b = buf[pos++]) > 0x20)
                value = (value << 3) + (value << 1) - b + 0x30;
        } else {
            do {
                value = (value << 3) + (value << 1) + b - 0x30;
            } while (hasRemaining() && (b = buf[pos++]) > 0x20);
        }
        if (b == 0x0d && hasRemaining() && buf[pos] == 0x0a)
            pos++;
        return value;
    }

    /** Parses the next optionally negative decimal long after any blanks; 0 at end of input. */
    public final long nextLong() throws IOException {
        if (!skipBlanks())
            return 0L;
        long value = 0L;
        byte b = buf[pos++];
        if (b == 0x2d) {
            while (hasRemaining() && (b = buf[pos++]) > 0x20)
                value = (value << 3) + (value << 1) - b + 0x30;
        } else {
            do {
                value = (value << 3) + (value << 1) + b - 0x30;
            } while (hasRemaining() && (b = buf[pos++]) > 0x20);
        }
        if (b == 0x0d && hasRemaining() && buf[pos] == 0x0a)
            pos++;
        return value;
    }

    /** Reads one byte as a char, then consumes the following byte (and an LF after a CR). */
    public final char nextChar() throws IOException {
        if (!hasRemaining())
            throw new EOFException();
        final char c = (char) buf[pos++];
        if (hasRemaining() && buf[pos++] == 0x0d && hasRemaining() && buf[pos] == 0x0a)
            pos++;
        return c;
    }

    public final float nextFloat() throws IOException {
        return Float.parseFloat(next());
    }

    public final double nextDouble() throws IOException {
        return Double.parseDouble(next());
    }

    /** Reads a token and maps each character to {@code ch == ok}. */
    public final boolean[] nextBooleanArray(char ok) throws IOException {
        char[] s = readToken();
        int n = s.length;
        boolean[] t = new boolean[n];
        for (int i = 0; i < n; i++)
            t[i] = s[i] == ok;
        return t;
    }

    /** Reads {@code h} tokens as boolean rows; {@code w} is unused, row length follows the token. */
    public final boolean[][] nextBooleanMatrix(int h, int w, char ok) throws IOException {
        boolean[][] s = new boolean[h][];
        for (int i = 0; i < h; i++) {
            char[] t = readToken();
            int n = t.length;
            s[i] = new boolean[n];
            for (int j = 0; j < n; j++)
                s[i][j] = t[j] == ok;
        }
        return s;
    }

    public final String[] nextStringArray(int len) throws IOException {
        String[] arr = new String[len];
        for (int i = 0; i < len; i++)
            arr[i] = next();
        return arr;
    }

    public final int[] nextIntArray(int len) throws IOException {
        int[] arr = new int[len];
        for (int i = 0; i < len; i++)
            arr[i] = nextInt();
        return arr;
    }

    public final int[] nextIntArray(int len, java.util.function.IntUnaryOperator map) throws IOException {
        int[] arr = new int[len];
        for (int i = 0; i < len; i++)
            arr[i] = map.applyAsInt(nextInt());
        return arr;
    }

    public final long[] nextLongArray(int len, java.util.function.LongUnaryOperator map) throws IOException {
        long[] arr = new long[len];
        for (int i = 0; i < len; i++)
            arr[i] = map.applyAsLong(nextLong());
        return arr;
    }

    public final int[][] nextIntMatrix(int h, int w) throws IOException {
        int[][] arr = new int[h][w];
        for (int i = 0; i < h; i++)
            for (int j = 0; j < w; j++)
                arr[i][j] = nextInt();
        return arr;
    }

    public final int[][] nextIntMatrix(int h, int w, java.util.function.IntUnaryOperator map) throws IOException {
        int[][] arr = new int[h][w];
        for (int i = 0; i < h; i++)
            for (int j = 0; j < w; j++)
                arr[i][j] = map.applyAsInt(nextInt());
        return arr;
    }

    public final long[] nextLongArray(int len) throws IOException {
        long[] arr = new long[len];
        for (int i = 0; i < len; i++)
            arr[i] = nextLong();
        return arr;
    }

    public final long[][] nextLongMatrix(int h, int w) throws IOException {
        long[][] arr = new long[h][w];
        for (int i = 0; i < h; i++)
            for (int j = 0; j < w; j++)
                arr[i][j] = nextLong();
        return arr;
    }

    public final float[] nextFloatArray(int len) throws IOException {
        float[] arr = new float[len];
        for (int i = 0; i < len; i++)
            arr[i] = nextFloat();
        return arr;
    }

    public final double[] nextDoubleArray(int len) throws IOException {
        double[] arr = new double[len];
        for (int i = 0; i < len; i++)
            arr[i] = nextDouble();
        return arr;
    }

    /** Reads {@code h} tokens as char rows; {@code w} is unused. */
    public final char[][] nextCharMatrix(int h, int w) throws IOException {
        char[][] arr = new char[h][];
        for (int i = 0; i < h; i++)
            arr[i] = readToken();
        return arr;
    }

    /** Reads and discards one token. */
    public final void nextThrow() throws IOException {
        next();
    }

    /** Reads and discards {@code n} tokens. */
    public final void nextThrow(int n) throws IOException {
        for (int i = 0; i < n; i++)
            nextThrow();
    }
}

/**
 * Buffered writer over {@code System.out} for ASCII output.
 *
 * <p>Characters are narrowed to single bytes, so only ASCII text is written faithfully. Output
 * is accumulated in an 8 KiB buffer and written through on {@link #flush()}, or whenever the
 * buffer fills.
 */
final class ContestOutputStream extends FilterOutputStream implements Appendable {
    protected final byte[] buf;
    protected int pos = 0;

    public ContestOutputStream() {
        super(System.out);
        this.buf = new byte[1 << 13];
    }

    /** Writes the buffered bytes to the underlying stream and flushes it. */
    @Override
    public void flush() throws IOException {
        out.write(buf, 0, pos);
        pos = 0;
        out.flush();
    }

    void put(byte b) throws IOException {
        if (pos >= buf.length)
            flush();
        buf[pos++] = b;
    }

    /** Returns the free space in the buffer, flushing first if it is full (always > 0). */
    int remaining() throws IOException {
        if (pos >= buf.length)
            flush();
        return buf.length - pos;
    }

    @Override
    public void write(int b) throws IOException {
        put((byte) b);
    }

    @Override
    public void write(byte[] b, int off, int len) throws IOException {
        int o = off;
        int l = len;
        while (l > 0) {
            int rem = Math.min(remaining(), l);
            System.arraycopy(b, o, buf, pos, rem);
            pos += rem;
            o += rem;
            l -= rem;
        }
    }

    @Override
    public ContestOutputStream append(char c) throws IOException {
        put((byte) c);
        return this;
    }

    /** Appends {@code csq[start, end)}; a null {@code csq} is treated as "null" per {@link Appendable}. */
    @Override
    public ContestOutputStream append(CharSequence csq, int start, int end) throws IOException {
        if (csq == null)
            csq = "null";
        int off = start;
        int len = end - start;
        while (len > 0) {
            int rem = Math.min(remaining(), len);
            for (int i = 0; i < rem; i++)
                buf[pos + i] = (byte) csq.charAt(off + i);
            pos += rem;
            off += rem;
            len -= rem;
        }
        return this;
    }

    /** Appends {@code csq}; a null {@code csq} appends "null" per {@link Appendable}. */
    @Override
    public ContestOutputStream append(CharSequence csq) throws IOException {
        if (csq == null)
            csq = "null";
        return append(csq, 0, csq.length());
    }

    public ContestOutputStream append(char[] arr, int off, int len) throws IOException {
        int o = off;
        int l = len;
        while (l > 0) {
            int rem = Math.min(remaining(), l);
            for (int i = 0; i < rem; i++)
                buf[pos + i] = (byte) arr[o + i];
            pos += rem;
            o += rem;
            l -= rem;
        }
        return this;
    }

    /** Prints the characters followed by a line separator. */
    public ContestOutputStream print(char[] arr) throws IOException {
        return append(arr, 0, arr.length).newLine();
    }

    /** Prints "o" for true and "." for false. */
    public ContestOutputStream print(boolean value) throws IOException {
        if (value)
            return append("o");
        return append(".");
    }

    public ContestOutputStream println(boolean value) throws IOException {
        if (value)
            return append("o\n");
        return append(".\n");
    }

    /** Prints each row as o/. characters followed by a line separator; rows may differ in length. */
    public ContestOutputStream print(boolean[][] value) throws IOException {
        for (boolean[] row : value) {
            for (boolean cell : row)
                print(cell);
            newLine();
        }
        return this;
    }

    public ContestOutputStream print(int value) throws IOException {
        return append(String.valueOf(value));
    }

    public ContestOutputStream println(int value) throws IOException {
        return append(String.valueOf(value)).newLine();
    }

    public ContestOutputStream print(long value) throws IOException {
        return append(String.valueOf(value));
    }

    public ContestOutputStream println(long value) throws IOException {
        return append(String.valueOf(value)).newLine();
    }

    public ContestOutputStream print(float value) throws IOException {
        return append(String.valueOf(value));
    }

    public ContestOutputStream println(float value) throws IOException {
        return append(String.valueOf(value)).newLine();
    }

    /**
     * Prints {@code x} in fixed-point notation with {@code n} fractional digits, rounded half up.
     * NaN and infinities are printed via {@link String#valueOf(double)}. Digits beyond a double's
     * precision (about 17 significant digits) are not meaningful.
     */
    private ContestOutputStream dtos(double x, int n) throws IOException {
        if (Double.isNaN(x) || Double.isInfinite(x))
            return append(String.valueOf(x)); // (long) NaN would print as 0
        if (x < 0) {
            append('-');
            x = -x;
        }
        x += Math.pow(10, -n) / 2;
        long longx = (long) x;
        print(longx);
        append('.');
        x -= longx;
        for (int i = 0; i < n; i++) {
            x *= 10;
            int intx = (int) x;
            print(intx);
            x -= intx;
        }
        return this;
    }

    public ContestOutputStream print(double value) throws IOException {
        return dtos(value, 20);
    }

    public ContestOutputStream println(double value) throws IOException {
        return dtos(value, 20).newLine();
    }

    public ContestOutputStream print(char value) throws IOException {
        return append(value);
    }

    public ContestOutputStream println(char value) throws IOException {
        return append(value).newLine();
    }

    public ContestOutputStream print(String value) throws IOException {
        return append(value);
    }

    public ContestOutputStream println(String value) throws IOException {
        return append(String.valueOf(value)).newLine();
    }

    public ContestOutputStream print(Object value) throws IOException {
        return append(String.valueOf(value));
    }

    public ContestOutputStream println(Object value) throws IOException {
        return append(String.valueOf(value)).newLine();
    }

    public ContestOutputStream printYN(boolean yes) throws IOException {
        if (yes)
            return println("Yes");
        return println("No");
    }

    public ContestOutputStream printAB(boolean yes) throws IOException {
        if (yes)
            return println("Alice");
        return println("Bob");
    }

    /** Prints the elements separated by single spaces, without a trailing line separator. */
    public ContestOutputStream print(CharSequence[] arr) throws IOException {
        if (arr.length > 0) {
            append(arr[0]);
            for (int i = 1; i < arr.length; i++)
                append('\u0020').append(arr[i]);
        }
        return this;
    }

    public ContestOutputStream print(int[] arr) throws IOException {
        if (arr.length > 0) {
            print(arr[0]);
            for (int i = 1; i < arr.length; i++)
                append('\u0020').print(arr[i]);
        }
        newLine();
        return this;
    }

    public ContestOutputStream print(int[] arr, int length) throws IOException {
        if (length > 0)
            print(arr[0]);
        for (int i = 1; i < length; i++)
            append('\u0020').print(arr[i]);
        newLine();
        return this;
    }

    public ContestOutputStream println(int[] arr) throws IOException {
        for (int i : arr)
            print(i).newLine();
        return this;
    }

    public ContestOutputStream println(int[] arr, int length) throws IOException {
        for (int i = 0; i < length; i++)
            println(arr[i]);
        return this;
    }

    public ContestOutputStream print(boolean[] arr) throws IOException {
        if (arr.length > 0) {
            print(arr[0]);
            for (int i = 1; i < arr.length; i++)
                append('\u0020').print(arr[i]);
        }
        newLine();
        return this;
    }

    public ContestOutputStream print(long[] arr) throws IOException {
        if (arr.length > 0) {
            print(arr[0]);
            for (int i = 1; i < arr.length; i++)
                append('\u0020').print(arr[i]);
        }
        newLine();
        return this;
    }

    public ContestOutputStream print(long[] arr, int length) throws IOException {
        if (length > 0)
            print(arr[0]);
        for (int i = 1; i < length; i++)
            append('\u0020').print(arr[i]);
        newLine();
        return this;
    }

    public ContestOutputStream println(long[] arr, int length) throws IOException {
        for (int i = 0; i < length; i++)
            println(arr[i]);
        return this;
    }

    public ContestOutputStream println(long[] arr) throws IOException {
        for (long i : arr)
            print(i).newLine();
        return this;
    }

    /** Prints the elements separated by single spaces, without a trailing line separator. */
    public ContestOutputStream print(float[] arr) throws IOException {
        if (arr.length > 0) {
            print(arr[0]);
            for (int i = 1; i < arr.length; i++)
                append('\u0020').print(arr[i]);
        }
        return this;
    }

    public ContestOutputStream println(float[] arr) throws IOException {
        for (float i : arr)
            print(i).newLine();
        return this;
    }

    public ContestOutputStream print(double[] arr) throws IOException {
        if (arr.length > 0) {
            print(arr[0]);
            for (int i = 1; i < arr.length; i++)
                append('\u0020').print(arr[i]);
        }
        return newLine();
    }

    public ContestOutputStream println(double[] arr) throws IOException {
        for (double i : arr)
            print(i).newLine();
        return this;
    }

    public ContestOutputStream print(Object[] arr) throws IOException {
        if (arr.length > 0) {
            print(arr[0]);
            for (int i = 1; i < arr.length; i++)
                append('\u0020').print(arr[i]);
        }
        return newLine();
    }

    public ContestOutputStream print(java.util.ArrayList<?> arr) throws IOException {
        if (!arr.isEmpty()) {
            final int n = arr.size();
            print(arr.get(0));
            for (int i = 1; i < n; i++)
                print(" ").print(arr.get(i));
        }
        return newLine();
    }

    public ContestOutputStream println(java.util.ArrayList<?> arr) throws IOException {
        final int n = arr.size();
        for (int i = 0; i < n; i++)
            println(arr.get(i));
        return this;
    }

    /** Appends the platform line separator ({@link System#lineSeparator()}). */
    public ContestOutputStream newLine() throws IOException {
        return append(System.lineSeparator());
    }

    public ContestOutputStream endl() throws IOException {
        newLine().flush();
        return this;
    }

    public ContestOutputStream print(int[][] arr) throws IOException {
        for (int[] i : arr)
            print(i);
        return this;
    }

    public ContestOutputStream print(long[][] arr) throws IOException {
        for (long[] i : arr)
            print(i);
        return this;
    }

    public ContestOutputStream print(char[][] arr) throws IOException {
        for (char[] i : arr)
            print(i);
        return this;
    }

    public ContestOutputStream print(Object[][] arr) throws IOException {
        for (Object[] i : arr)
            print(i);
        return this;
    }

    public ContestOutputStream println() throws IOException {
        return newLine();
    }

    /** Prints all values back-to-back (no separators) followed by a line separator. */
    public ContestOutputStream println(Object... arr) throws IOException {
        for (Object i : arr)
            print(i);
        return newLine();
    }

    public ContestOutputStream printToChar(int c) throws IOException {
        return print((char) c);
    }
}
0