結果

問題 No.1300 Sum of Inversions
ユーザー suisensuisen
提出日時 2020-11-27 21:56:31
言語 Java21
(openjdk 21)
結果
RE  
(最新)
AC  
(最初)
実行時間 -
コード長 48,508 bytes
コンパイル時間 4,609 ms
コンパイル使用メモリ 100,764 KB
実行使用メモリ 37,624 KB
最終ジャッジ日時 2024-07-26 12:31:11
合計ジャッジ時間 7,926 ms
ジャッジサーバーID
(参考情報)
judge2 / judge4
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 RE -
testcase_01 RE -
testcase_02 RE -
testcase_03 RE -
testcase_04 RE -
testcase_05 RE -
testcase_06 RE -
testcase_07 RE -
testcase_08 RE -
testcase_09 RE -
testcase_10 RE -
testcase_11 RE -
testcase_12 RE -
testcase_13 RE -
testcase_14 RE -
testcase_15 RE -
testcase_16 RE -
testcase_17 RE -
testcase_18 RE -
testcase_19 RE -
testcase_20 RE -
testcase_21 RE -
testcase_22 RE -
testcase_23 RE -
testcase_24 RE -
testcase_25 RE -
testcase_26 RE -
testcase_27 RE -
testcase_28 RE -
testcase_29 RE -
testcase_30 RE -
testcase_31 RE -
testcase_32 RE -
testcase_33 RE -
testcase_34 RE -
testcase_35 RE -
testcase_36 RE -
権限があれば一括ダウンロードができます

ソースコード

diff #

import java.io.InputStream;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.function.IntBinaryOperator;
import java.util.function.IntPredicate;
import java.util.function.IntToLongFunction;
import java.util.function.IntUnaryOperator;
import java.util.function.LongPredicate;
import java.util.function.LongUnaryOperator;

public class Main {
    /** Entry point: solves one test case from standard input and writes the answer to standard output. */
    public static void main(String[] args) throws Exception {
        ExtendedScanner sc = new ExtendedScanner();
        FastPrintStream pw = new FastPrintStream();
        solve(sc, pw);
        sc.close();
        pw.flush();
        pw.close();
    }

    /** Modular arithmetic used for every sum and product. */
    static final ModArithmetic MA = ModArithmetic998244353.INSTANCE;

    /**
     * Reads {@code n} and {@code a[0..n)} and prints, reduced by {@link #MA}, the sum of
     * {@code a[i] + a[j] + a[k]} over all index triples {@code i < j < k} with
     * {@code a[i] > a[j] > a[k]}.
     *
     * <p>For each middle index {@code j}, let {@code l}/{@code lsum} be the count/sum of earlier
     * elements greater than {@code a[j]}, and {@code r}/{@code rsum} the count/sum of later elements
     * smaller than {@code a[j]}. Then {@code j} contributes
     * {@code lsum * r + a[j] * l * r + rsum * l}. Counts ({@code s1}, {@code s2}) and sums
     * ({@code t1}, {@code t2}) are kept in segment trees indexed by compressed value:
     * {@code s1}/{@code t1} cover elements left of {@code j}, {@code s2}/{@code t2} elements right of it.
     */
    public static void solve(ExtendedScanner sc, FastPrintStream pw) {
        int n = sc.nextInt();
        long[] a = sc.longs(n);
        LongCompress cmp = new LongCompress(a, false);

        // Compressed rank of each element (looked up once), and the value reduced into [0, MOD):
        // MA.add / MA.sub only subtract or add MOD once, so they require operands in that range.
        int[] rank = new int[n];
        long[] am = new long[n];
        for (int j = 0; j < n; j++) {
            rank[j] = cmp.compress(a[j]);
            am[j] = MA.mod(a[j]);
        }

        IntSegmentTree s1 = new IntSegmentTree(n, Integer::sum, 0);
        IntSegmentTree s2 = new IntSegmentTree(n, Integer::sum, 0);
        LongSegmentTree t1 = new LongSegmentTree(n, MA::add, 0);
        LongSegmentTree t2 = new LongSegmentTree(n, MA::add, 0);

        // Initially every element is "to the right".
        for (int j = 0; j < n; j++) {
            int vi = rank[j];
            t2.set(vi, MA.add(t2.get(vi), am[j]));
            s2.set(vi, s2.get(vi) + 1);
        }

        long ans = 0;
        for (int j = 0; j < n; j++) {
            int idx = rank[j];
            long v = am[j];
            // Remove a[j] from the right-hand structures before querying them.
            t2.set(idx, MA.sub(t2.get(idx), v));
            s2.set(idx, s2.get(idx) - 1);
            int l = s1.prod(idx + 1, n);
            int r = s2.prod(0, idx);
            long lsum = t1.prod(idx + 1, n);
            long rsum = t2.prod(0, idx);
            // a[j] now belongs to the left-hand side for later indices.
            t1.set(idx, MA.add(t1.get(idx), v));
            s1.set(idx, s1.get(idx) + 1);
            ans = MA.mod(ans + MA.add(MA.mul(lsum, r), MA.mul(v, l, r), MA.mul(rsum, l)));
        }
        pw.println(ans);
    }
}

/**
 * Segment tree over {@code long} values: point assignment, range product under {@code op}, and
 * monotone binary searches ({@link #maxRight}, {@link #minLeft}).
 *
 * <p>Layout: node 1 is the root, node {@code i} has children {@code 2i} and {@code 2i + 1}, and
 * position {@code p} is stored at leaf {@code data[N + p]}. Every slot is initialised to
 * {@code E}, and padding leaves at positions {@code >= MAX} are never written, so {@code E} must
 * be an identity element of {@code op}. NOTE(review): range products regroup operands, so
 * {@code op} is presumably required to be associative — confirm at call sites.
 */
class LongSegmentTree {
    /** Number of logical positions; valid point indices are {@code [0, MAX)}. */
    final int MAX;

    /** Number of leaves: the smallest power of two that is {@code >= MAX} (at least 1). */
    final int N;
    /** Combining operator. */
    final java.util.function.LongBinaryOperator op;
    /** Identity element of {@code op}; also the value of an empty range. */
    final long E;

    /** Tree nodes; the root is index 1 and leaves occupy indices {@code N .. 2N-1}. */
    final long[] data;

    /**
     * Creates a tree over {@code n} positions, all initially {@code e}.
     *
     * @param n  number of positions
     * @param op combining operator
     * @param e  identity element of {@code op}
     */
    public LongSegmentTree(int n, java.util.function.LongBinaryOperator op, long e) {
        this.MAX = n;
        int k = 1;
        // Round the leaf count up to a power of two.
        while (k < n) k <<= 1;
        this.N = k;
        this.E = e;
        this.op = op;
        this.data = new long[N << 1];
        java.util.Arrays.fill(data, E);
    }

    /**
     * Creates a tree over {@code dat.length} positions holding a copy of {@code dat}.
     *
     * @param dat initial values
     * @param op  combining operator
     * @param e   identity element of {@code op}
     */
    public LongSegmentTree(long[] dat, java.util.function.LongBinaryOperator op, long e) {
        this(dat.length, op, e);
        build(dat);
    }

    /** Copies {@code dat} into the leaves, then recomputes every internal node bottom-up. */
    private void build(long[] dat) {
        int l = dat.length;
        System.arraycopy(dat, 0, data, N, l);
        for (int i = N - 1; i > 0; i--) {
            data[i] = op.applyAsLong(data[i << 1 | 0], data[i << 1 | 1]);
        }
    }

    /**
     * Assigns {@code x} to position {@code p} and recomputes its ancestors.
     *
     * @throws IndexOutOfBoundsException unless {@code 0 <= p < MAX}
     */
    public void set(int p, long x) {
        exclusiveRangeCheck(p);
        data[p += N] = x;
        p >>= 1;
        // Walk up to the root, refreshing each ancestor from its two children.
        while (p > 0) {
            data[p] = op.applyAsLong(data[p << 1 | 0], data[p << 1 | 1]);
            p >>= 1;
        }
    }

    /**
     * Returns the value at position {@code p}.
     *
     * @throws IndexOutOfBoundsException unless {@code 0 <= p < MAX}
     */
    public long get(int p) {
        exclusiveRangeCheck(p);
        return data[p + N];
    }

    /**
     * Returns {@code op} folded left-to-right over positions {@code [l, r)}; {@code E} if empty.
     *
     * @throws IllegalArgumentException  if {@code l > r}
     * @throws IndexOutOfBoundsException if {@code l} or {@code r} lies outside {@code [0, MAX]}
     */
    public long prod(int l, int r) {
        if (l > r) {
            throw new IllegalArgumentException(
                String.format("Invalid range: [%d, %d)", l, r)
            );
        }
        inclusiveRangeCheck(l);
        inclusiveRangeCheck(r);
        // Left and right partial products are kept apart so operand order is preserved.
        long sumLeft = E;
        long sumRight = E;
        l += N; r += N;
        while (l < r) {
            if ((l & 1) == 1) sumLeft = op.applyAsLong(sumLeft, data[l++]);
            if ((r & 1) == 1) sumRight = op.applyAsLong(data[--r], sumRight);
            l >>= 1; r >>= 1;
        }
        return op.applyAsLong(sumLeft, sumRight);
    }

    /** Returns the product over all positions (the root node). */
    public long allProd() {
        return data[1];
    }

    /**
     * Scans right from {@code l}. Assuming {@code f} is monotone (true on every range
     * {@code [l, r)} up to some point and false beyond), returns the largest {@code r} in
     * {@code [l, MAX]} with {@code f.test(prod(l, r))}.
     *
     * @throws IndexOutOfBoundsException if {@code l} lies outside {@code [0, MAX]}
     * @throws IllegalArgumentException  if {@code f.test(E)} is false
     */
    public int maxRight(int l, java.util.function.LongPredicate f) {
        inclusiveRangeCheck(l);
        if (!f.test(E)) {
            throw new IllegalArgumentException("Identity element must satisfy the condition.");
        }
        if (l == MAX) return MAX;
        l += N;
        long sum = E;
        do {
            // Climb to the highest node whose range starts exactly at l.
            l >>= Integer.numberOfTrailingZeros(l);
            if (!f.test(op.applyAsLong(sum, data[l]))) {
                // The whole node does not fit: descend to locate the exact boundary.
                while (l < N) {
                    l = l << 1;
                    if (f.test(op.applyAsLong(sum, data[l]))) {
                        // Left child fits: absorb it and continue in the right child.
                        sum = op.applyAsLong(sum, data[l]);
                        l++;
                    }
                }
                return l - N;
            }
            sum = op.applyAsLong(sum, data[l]);
            l++;
            // l becomes a power of two only after the scan has passed the right edge.
        } while ((l & -l) != l);
        return MAX;
    }

    /**
     * Scans left from {@code r}. Assuming {@code f} is monotone (true on every range
     * {@code [l, r)} down to some point and false beyond), returns the smallest {@code l} in
     * {@code [0, r]} with {@code f.test(prod(l, r))}.
     *
     * @throws IndexOutOfBoundsException if {@code r} lies outside {@code [0, MAX]}
     * @throws IllegalArgumentException  if {@code f.test(E)} is false
     */
    public int minLeft(int r, java.util.function.LongPredicate f) {
        inclusiveRangeCheck(r);
        if (!f.test(E)) {
            throw new IllegalArgumentException("Identity element must satisfy the condition.");
        }
        if (r == 0) return 0;
        r += N;
        long sum = E;
        do {
            r--;
            // Climb while r is a right child, to the highest node whose range ends at r.
            while (r > 1 && (r & 1) == 1) r >>= 1;
            if (!f.test(op.applyAsLong(data[r], sum))) {
                // The whole node does not fit: descend to locate the exact boundary.
                while (r < N) {
                    r = r << 1 | 1;
                    if (f.test(op.applyAsLong(data[r], sum))) {
                        // Right child fits: absorb it and continue in the left child.
                        sum = op.applyAsLong(data[r], sum);
                        r--;
                    }
                }
                return r + 1 - N;
            }
            sum = op.applyAsLong(data[r], sum);
        } while ((r & -r) != r);
        return 0;
    }

    /** Rejects point indices outside {@code [0, MAX)}. */
    private void exclusiveRangeCheck(int p) {
        if (p < 0 || p >= MAX) {
            throw new IndexOutOfBoundsException(
                String.format("Index %d out of bounds for the range [%d, %d).", p, 0, MAX)
            );
        }
    }

    /** Rejects range endpoints outside {@code [0, MAX]}. */
    private void inclusiveRangeCheck(int p) {
        if (p < 0 || p > MAX) {
            throw new IndexOutOfBoundsException(
                String.format("Index %d out of bounds for the range [%d, %d].", p, 0, MAX)
            );
        }
    }

    // **************** DEBUG **************** //
    // Debug rendering only; none of the methods above use these.

    /** Spaces added per tree level in {@link #toDetailedString()}. */
    private int indent = 6;

    public void setIndent(int newIndent) {
        this.indent = newIndent;
    }

    @Override
    public String toString() {
        return toSimpleString();
    }

    /** Renders the tree sideways: right subtree above each node, left subtree below. */
    public String toDetailedString() {
        return toDetailedString(1, 0);
    }

    private String toDetailedString(int k, int sp) {
        if (k >= N) return indent(sp) + data[k];
        String s = "";
        s += toDetailedString(k << 1 | 1, sp + indent);
        s += "\n";
        s += indent(sp) + data[k];
        s += "\n";
        s += toDetailedString(k << 1 | 0, sp + indent);
        return s;
    }

    private static String indent(int n) {
        StringBuilder sb = new StringBuilder();
        while (n --> 0) sb.append(' ');
        return sb.toString();
    }

    /** Lists all {@code N} leaves, including padding positions {@code >= MAX}. */
    public String toSimpleString() {
        StringBuilder sb = new StringBuilder();
        sb.append('[');
        for (int i = 0; i < N; i++) {
            sb.append(data[i + N]);
            if (i < N - 1) sb.append(',').append(' ');
        }
        sb.append(']');
        return sb.toString();
    }
}


/**
 * @author https://atcoder.jp/users/suisen
 *
 * A function taking two {@code int} arguments and producing a {@code long}.
 */
@FunctionalInterface
interface IntToLongBiFunction {
    /** Applies this function to {@code x} and {@code y}. */
    public long apply(int x, int y);

    /** Fixes the first argument to {@code x}, yielding {@code y -> apply(x, y)}. */
    public default IntToLongFunction curry(final int x) {
        final IntToLongBiFunction self = this;
        return second -> self.apply(x, second);
    }
}


/**
 * @author https://atcoder.jp/users/suisen
 *
 * {@link FastScanner} plus bulk readers. Arrays are filled in ascending index order and matrices
 * row by row, so tokens are consumed exactly in input order.
 */
final class ExtendedScanner extends FastScanner {
    public ExtendedScanner() {
        super();
    }

    public ExtendedScanner(InputStream in) {
        super(in);
    }

    /** Reads {@code n} ints. */
    public int[] ints(final int n) {
        final int[] res = new int[n];
        for (int i = 0; i < n; i++) {
            res[i] = nextInt();
        }
        return res;
    }

    /** Reads {@code n} ints, mapping each through {@code f}. */
    public int[] ints(final int n, final IntUnaryOperator f) {
        final int[] res = new int[n];
        for (int i = 0; i < n; i++) {
            res[i] = f.applyAsInt(nextInt());
        }
        return res;
    }

    /** Reads an {@code n x m} int matrix, row by row. */
    public int[][] ints(final int n, final int m) {
        final int[][] rows = new int[n][];
        for (int i = 0; i < n; i++) {
            rows[i] = ints(m);
        }
        return rows;
    }

    /** Reads an {@code n x m} int matrix, mapping each value through {@code f}. */
    public int[][] ints(final int n, final int m, final IntUnaryOperator f) {
        final int[][] rows = new int[n][];
        for (int i = 0; i < n; i++) {
            rows[i] = ints(m, f);
        }
        return rows;
    }

    /** Reads {@code n} longs. */
    public long[] longs(final int n) {
        final long[] res = new long[n];
        for (int i = 0; i < n; i++) {
            res[i] = nextLong();
        }
        return res;
    }

    /** Reads {@code n} longs, mapping each through {@code f}. */
    public long[] longs(final int n, final LongUnaryOperator f) {
        final long[] res = new long[n];
        for (int i = 0; i < n; i++) {
            res[i] = f.applyAsLong(nextLong());
        }
        return res;
    }

    /** Reads an {@code n x m} long matrix, row by row. */
    public long[][] longs(final int n, final int m) {
        final long[][] rows = new long[n][];
        for (int i = 0; i < n; i++) {
            rows[i] = longs(m);
        }
        return rows;
    }

    /** Reads an {@code n x m} long matrix, mapping each value through {@code f}. */
    public long[][] longs(final int n, final int m, final LongUnaryOperator f) {
        final long[][] rows = new long[n][];
        for (int i = 0; i < n; i++) {
            rows[i] = longs(m, f);
        }
        return rows;
    }

    /** Reads {@code n} tokens, each as a char array. */
    public char[][] charArrays(final int n) {
        final char[][] rows = new char[n][];
        for (int i = 0; i < n; i++) {
            rows[i] = nextChars();
        }
        return rows;
    }

    /** Reads {@code n} doubles. */
    public double[] doubles(final int n) {
        final double[] res = new double[n];
        for (int i = 0; i < n; i++) {
            res[i] = nextDouble();
        }
        return res;
    }

    /** Reads an {@code n x m} double matrix, row by row. */
    public double[][] doubles(final int n, final int m) {
        final double[][] rows = new double[n][];
        for (int i = 0; i < n; i++) {
            rows[i] = doubles(m);
        }
        return rows;
    }

    /** Reads {@code n} tokens as strings. */
    public String[] strings(final int n) {
        final String[] res = new String[n];
        for (int i = 0; i < n; i++) {
            res[i] = next();
        }
        return res;
    }
}

/**
 * @author https://atcoder.jp/users/suisen
 *
 * Buffered writer of numbers, characters and strings into an {@link java.io.OutputStream}.
 * Output is held in a 16 KiB buffer and only reaches the stream when the buffer fills, or on
 * {@link #flush()} or {@link #close()}. Single characters are narrowed to one byte each, so the
 * class is intended for ASCII output.
 */
class FastPrintStream implements AutoCloseable {
    private static final int INT_MAX_LEN = 11;   // length of "-2147483648"
    private static final int LONG_MAX_LEN = 20;  // length of "-9223372036854775808"

    private int precision = 9;

    private static final int BUF_SIZE = 1 << 14;
    private static final int BUF_SIZE_MINUS_INT_MAX_LEN = BUF_SIZE - INT_MAX_LEN;
    private static final int BUF_SIZE_MINUS_LONG_MAX_LEN = BUF_SIZE - LONG_MAX_LEN;
    private final byte[] buf = new byte[BUF_SIZE];
    private int ptr = 0;
    /** Reflective handle on {@code String.value}, or {@code null} when it cannot be opened. */
    private final java.lang.reflect.Field strField;
    private final java.nio.charset.CharsetEncoder encoder;

    private final java.io.OutputStream out;

    public FastPrintStream(java.io.OutputStream out) {
        this.out = out;
        java.lang.reflect.Field f;
        try {
            f = java.lang.String.class.getDeclaredField("value");
            f.setAccessible(true);
        } catch (NoSuchFieldException | RuntimeException e) {
            // On JDK 16+ java.base does not open java.lang, so setAccessible throws
            // InaccessibleObjectException (a RuntimeException, not a SecurityException).
            // Fall back to String.getBytes instead of failing construction.
            f = null;
        }
        this.strField = f;
        this.encoder = java.nio.charset.StandardCharsets.US_ASCII.newEncoder();
    }

    public FastPrintStream(java.io.File file) throws java.io.IOException {
        this(new java.io.FileOutputStream(file));
    }

    public FastPrintStream(java.lang.String filename) throws java.io.IOException {
        this(new java.io.File(filename));
    }

    /** Writes to the standard-output file descriptor. */
    public FastPrintStream() {
        this(new java.io.FileOutputStream(java.io.FileDescriptor.out));
    }

    public FastPrintStream println() {
        if (ptr == BUF_SIZE) internalFlush();
        buf[ptr++] = (byte) '\n';
        return this;
    }

    public FastPrintStream println(java.lang.Object o) {
        return print(o).println();
    }

    public FastPrintStream println(java.lang.String s) {
        return print(s).println();
    }

    public FastPrintStream println(char[] s) {
        return print(s).println();
    }

    public FastPrintStream println(char c) {
        return print(c).println();
    }

    public FastPrintStream println(int x) {
        return print(x).println();
    }

    public FastPrintStream println(long x) {
        return print(x).println();
    }

    public FastPrintStream println(double d, int precision) {
        return print(d, precision).println();
    }

    public FastPrintStream println(double d) {
        return print(d).println();
    }

    /** Appends {@code bytes[off, off + len)}; a chunk that does not fit is written through. */
    private FastPrintStream print(byte[] bytes, int off, int len) {
        if (ptr + len > BUF_SIZE) {
            internalFlush();
            try {
                out.write(bytes, off, len);
            } catch (java.io.IOException e) {
                throw new java.io.UncheckedIOException(e);
            }
        } else {
            System.arraycopy(bytes, off, buf, ptr, len);
            ptr += len;
        }
        return this;
    }

    private FastPrintStream print(byte[] bytes) {
        return print(bytes, 0, bytes.length);
    }

    /** Prints {@code String.valueOf(o)}, i.e. {@code "null"} for a null reference. */
    public FastPrintStream print(java.lang.Object o) {
        return print(String.valueOf(o));
    }

    /** Prints {@code s} ({@code "null"} for a null reference). */
    public FastPrintStream print(java.lang.String s) {
        if (s == null) s = "null";
        if (strField != null) {
            try {
                Object value = strField.get(s);
                if (value instanceof char[]) {
                    return print((char[]) value);
                }
                // Compact strings store LATIN1 text one byte per char; a UTF16-coded string uses
                // two bytes per char, and its raw array must not be emitted.
                if (value instanceof byte[] && ((byte[]) value).length == s.length()) {
                    return print((byte[]) value);
                }
            } catch (IllegalAccessException e) {
                // Fall through to the portable path below.
            }
        }
        return print(s.getBytes());
    }

    /** Prints {@code s} as US-ASCII; characters outside ASCII are narrowed to their low byte. */
    public FastPrintStream print(char[] s) {
        try {
            java.nio.ByteBuffer encoded = encoder.encode(java.nio.CharBuffer.wrap(s));
            // Only [position, limit) holds output; do not rely on the backing array's length.
            byte[] bytes = new byte[encoded.remaining()];
            encoded.get(bytes);
            return print(bytes);
        } catch (java.nio.charset.CharacterCodingException e) {
            byte[] bytes = new byte[s.length];
            for (int i = 0; i < s.length; i++) {
                bytes[i] = (byte) s[i];
            }
            return print(bytes);
        }
    }

    public FastPrintStream print(char c) {
        if (ptr == BUF_SIZE) internalFlush();
        buf[ptr++] = (byte) c;
        return this;
    }

    public FastPrintStream print(int x) {
        // Guarantee room for the longest int representation up front.
        if (ptr > BUF_SIZE_MINUS_INT_MAX_LEN) internalFlush();
        if (-10 < x && x < 10) {
            if (x < 0) {
                buf[ptr++] = '-';
                x = -x;
            }
            buf[ptr++] = (byte) ('0' + x);
            return this;
        }
        int d;
        if (x < 0) {
            if (x == Integer.MIN_VALUE) {
                // -x would overflow, so spell it out.
                buf[ptr++] = '-'; buf[ptr++] = '2'; buf[ptr++] = '1'; buf[ptr++] = '4';
                buf[ptr++] = '7'; buf[ptr++] = '4'; buf[ptr++] = '8'; buf[ptr++] = '3';
                buf[ptr++] = '6'; buf[ptr++] = '4'; buf[ptr++] = '8';
                return this;
            }
            d = len(x = -x);
            buf[ptr++] = '-';
        } else {
            d = len(x);
        }
        // Reserve d bytes, then fill digits from the least significant end.
        int j = ptr += d;
        while (x > 0) {
            buf[--j] = (byte) ('0' + (x % 10));
            x /= 10;
        }
        return this;
    }

    public FastPrintStream print(long x) {
        if ((int) x == x) return print((int) x);
        if (ptr > BUF_SIZE_MINUS_LONG_MAX_LEN) internalFlush();
        int d;
        if (x < 0) {
            if (x == Long.MIN_VALUE) {
                buf[ptr++] = '-'; buf[ptr++] = '9'; buf[ptr++] = '2'; buf[ptr++] = '2';
                buf[ptr++] = '3'; buf[ptr++] = '3'; buf[ptr++] = '7'; buf[ptr++] = '2';
                buf[ptr++] = '0'; buf[ptr++] = '3'; buf[ptr++] = '6'; buf[ptr++] = '8';
                buf[ptr++] = '5'; buf[ptr++] = '4'; buf[ptr++] = '7'; buf[ptr++] = '7';
                buf[ptr++] = '5'; buf[ptr++] = '8'; buf[ptr++] = '0'; buf[ptr++] = '8';
                return this;
            }
            d = len(x = -x);
            buf[ptr++] = '-';
        } else {
            d = len(x);
        }
        int j = ptr += d;
        while (x > 0) {
            buf[--j] = (byte) ('0' + (x % 10));
            x /= 10;
        }
        return this;
    }

    /**
     * Prints {@code d} with {@code precision} digits after the point, rounding by adding half a
     * unit in the last place before truncating. NaN and infinities print as
     * {@link Double#toString(double)} does.
     */
    public FastPrintStream print(double d, int precision) {
        if (Double.isNaN(d) || Double.isInfinite(d)) {
            // The digit loop below cannot represent these values.
            return print(Double.toString(d));
        }
        if (d < 0) {
            print('-');
            d = -d;
        }
        d += Math.pow(10, -precision) / 2;
        print((long) d).print('.');
        d -= (long) d;
        for (int i = 0; i < precision; i++) {
            d *= 10;
            print((int) d);
            d -= (int) d;
        }
        return this;
    }

    /** Prints {@code d} using the precision set by {@link #setPrecision(int)} (default 9). */
    public FastPrintStream print(double d) {
        return print(d, precision);
    }

    public void setPrecision(int precision) {
        this.precision = precision;
    }

    /** Writes the buffer to the stream without flushing the stream itself. */
    private void internalFlush() {
        try {
            out.write(buf, 0, ptr);
            ptr = 0;
        } catch (java.io.IOException e) {
            throw new java.io.UncheckedIOException(e);
        }
    }

    /** Writes the buffer and flushes the underlying stream. */
    public void flush() {
        try {
            out.write(buf, 0, ptr);
            out.flush();
            ptr = 0;
        } catch (java.io.IOException e) {
            throw new java.io.UncheckedIOException(e);
        }
    }

    /** Flushes pending output, then closes the underlying stream. */
    public void close() {
        try {
            flush();
        } finally {
            try {
                out.close();
            } catch (java.io.IOException e) {
                throw new java.io.UncheckedIOException(e);
            }
        }
    }

    /** Decimal digit count of a non-negative int. */
    private static int len(int x) {
        return
            x >= 1000000000 ? 10 :
            x >= 100000000  ?  9 :
            x >= 10000000   ?  8 :
            x >= 1000000    ?  7 :
            x >= 100000     ?  6 :
            x >= 10000      ?  5 :
            x >= 1000       ?  4 :
            x >= 100        ?  3 :
            x >= 10         ?  2 : 1;
    }

    /** Decimal digit count of a positive long outside the int range (hence at least 10). */
    private static int len(long x) {
        return
            x >= 1000000000000000000L ? 19 :
            x >= 100000000000000000L  ? 18 :
            x >= 10000000000000000L   ? 17 :
            x >= 1000000000000000L    ? 16 :
            x >= 100000000000000L     ? 15 :
            x >= 10000000000000L      ? 14 :
            x >= 1000000000000L       ? 13 :
            x >= 100000000000L        ? 12 :
            x >= 10000000000L         ? 11 : 10;
    }
}

/**
 * Segment tree over {@code int} values: point assignment, range product under {@code op}, and
 * monotone binary searches ({@link #maxRight}, {@link #minLeft}).
 *
 * <p>Layout: node 1 is the root, node {@code i} has children {@code 2i} and {@code 2i + 1}, and
 * position {@code p} is stored at leaf {@code data[N + p]}. Every slot is initialised to
 * {@code E}, and padding leaves at positions {@code >= MAX} are never written, so {@code E} must
 * be an identity element of {@code op}. NOTE(review): range products regroup operands, so
 * {@code op} is presumably required to be associative — confirm at call sites.
 */
class IntSegmentTree {
    /** Number of logical positions; valid point indices are {@code [0, MAX)}. */
    final int MAX;

    /** Number of leaves: the smallest power of two that is {@code >= MAX} (at least 1). */
    final int N;
    /** Combining operator. */
    final java.util.function.IntBinaryOperator op;
    /** Identity element of {@code op}; also the value of an empty range. */
    final int E;

    /** Tree nodes; the root is index 1 and leaves occupy indices {@code N .. 2N-1}. */
    final int[] data;

    /**
     * Creates a tree over {@code n} positions, all initially {@code e}.
     *
     * @param n  number of positions
     * @param op combining operator
     * @param e  identity element of {@code op}
     */
    public IntSegmentTree(int n, java.util.function.IntBinaryOperator op, int e) {
        this.MAX = n;
        int k = 1;
        // Round the leaf count up to a power of two.
        while (k < n) k <<= 1;
        this.N = k;
        this.E = e;
        this.op = op;
        this.data = new int[N << 1];
        java.util.Arrays.fill(data, E);
    }

    /**
     * Creates a tree over {@code dat.length} positions holding a copy of {@code dat}.
     *
     * @param dat initial values
     * @param op  combining operator
     * @param e   identity element of {@code op}
     */
    public IntSegmentTree(int[] dat, java.util.function.IntBinaryOperator op, int e) {
        this(dat.length, op, e);
        build(dat);
    }

    /** Copies {@code dat} into the leaves, then recomputes every internal node bottom-up. */
    private void build(int[] dat) {
        int l = dat.length;
        System.arraycopy(dat, 0, data, N, l);
        for (int i = N - 1; i > 0; i--) {
            data[i] = op.applyAsInt(data[i << 1 | 0], data[i << 1 | 1]);
        }
    }

    /**
     * Assigns {@code x} to position {@code p} and recomputes its ancestors.
     *
     * @throws IndexOutOfBoundsException unless {@code 0 <= p < MAX}
     */
    public void set(int p, int x) {
        exclusiveRangeCheck(p);
        data[p += N] = x;
        p >>= 1;
        // Walk up to the root, refreshing each ancestor from its two children.
        while (p > 0) {
            data[p] = op.applyAsInt(data[p << 1 | 0], data[p << 1 | 1]);
            p >>= 1;
        }
    }

    /**
     * Returns the value at position {@code p}.
     *
     * @throws IndexOutOfBoundsException unless {@code 0 <= p < MAX}
     */
    public int get(int p) {
        exclusiveRangeCheck(p);
        return data[p + N];
    }

    /**
     * Returns {@code op} folded left-to-right over positions {@code [l, r)}; {@code E} if empty.
     *
     * @throws IllegalArgumentException  if {@code l > r}
     * @throws IndexOutOfBoundsException if {@code l} or {@code r} lies outside {@code [0, MAX]}
     */
    public int prod(int l, int r) {
        if (l > r) {
            throw new IllegalArgumentException(
                String.format("Invalid range: [%d, %d)", l, r)
            );
        }
        inclusiveRangeCheck(l);
        inclusiveRangeCheck(r);
        // Left and right partial products are kept apart so operand order is preserved.
        int sumLeft = E;
        int sumRight = E;
        l += N; r += N;
        while (l < r) {
            if ((l & 1) == 1) sumLeft = op.applyAsInt(sumLeft, data[l++]);
            if ((r & 1) == 1) sumRight = op.applyAsInt(data[--r], sumRight);
            l >>= 1; r >>= 1;
        }
        return op.applyAsInt(sumLeft, sumRight);
    }

    /** Returns the product over all positions (the root node). */
    public int allProd() {
        return data[1];
    }

    /**
     * Scans right from {@code l}. Assuming {@code f} is monotone (true on every range
     * {@code [l, r)} up to some point and false beyond), returns the largest {@code r} in
     * {@code [l, MAX]} with {@code f.test(prod(l, r))}.
     *
     * @throws IndexOutOfBoundsException if {@code l} lies outside {@code [0, MAX]}
     * @throws IllegalArgumentException  if {@code f.test(E)} is false
     */
    public int maxRight(int l, java.util.function.IntPredicate f) {
        inclusiveRangeCheck(l);
        if (!f.test(E)) {
            throw new IllegalArgumentException("Identity element must satisfy the condition.");
        }
        if (l == MAX) return MAX;
        l += N;
        int sum = E;
        do {
            // Climb to the highest node whose range starts exactly at l.
            l >>= Integer.numberOfTrailingZeros(l);
            if (!f.test(op.applyAsInt(sum, data[l]))) {
                // The whole node does not fit: descend to locate the exact boundary.
                while (l < N) {
                    l = l << 1;
                    if (f.test(op.applyAsInt(sum, data[l]))) {
                        // Left child fits: absorb it and continue in the right child.
                        sum = op.applyAsInt(sum, data[l]);
                        l++;
                    }
                }
                return l - N;
            }
            sum = op.applyAsInt(sum, data[l]);
            l++;
            // l becomes a power of two only after the scan has passed the right edge.
        } while ((l & -l) != l);
        return MAX;
    }

    /**
     * Scans left from {@code r}. Assuming {@code f} is monotone (true on every range
     * {@code [l, r)} down to some point and false beyond), returns the smallest {@code l} in
     * {@code [0, r]} with {@code f.test(prod(l, r))}.
     *
     * @throws IndexOutOfBoundsException if {@code r} lies outside {@code [0, MAX]}
     * @throws IllegalArgumentException  if {@code f.test(E)} is false
     */
    public int minLeft(int r, java.util.function.IntPredicate f) {
        inclusiveRangeCheck(r);
        if (!f.test(E)) {
            throw new IllegalArgumentException("Identity element must satisfy the condition.");
        }
        if (r == 0) return 0;
        r += N;
        int sum = E;
        do {
            r--;
            // Climb while r is a right child, to the highest node whose range ends at r.
            while (r > 1 && (r & 1) == 1) r >>= 1;
            if (!f.test(op.applyAsInt(data[r], sum))) {
                // The whole node does not fit: descend to locate the exact boundary.
                while (r < N) {
                    r = r << 1 | 1;
                    if (f.test(op.applyAsInt(data[r], sum))) {
                        // Right child fits: absorb it and continue in the left child.
                        sum = op.applyAsInt(data[r], sum);
                        r--;
                    }
                }
                return r + 1 - N;
            }
            sum = op.applyAsInt(data[r], sum);
        } while ((r & -r) != r);
        return 0;
    }

    /** Rejects point indices outside {@code [0, MAX)}. */
    private void exclusiveRangeCheck(int p) {
        if (p < 0 || p >= MAX) {
            throw new IndexOutOfBoundsException(
                String.format("Index %d out of bounds for the range [%d, %d).", p, 0, MAX)
            );
        }
    }

    /** Rejects range endpoints outside {@code [0, MAX]}. */
    private void inclusiveRangeCheck(int p) {
        if (p < 0 || p > MAX) {
            throw new IndexOutOfBoundsException(
                String.format("Index %d out of bounds for the range [%d, %d].", p, 0, MAX)
            );
        }
    }

    // **************** DEBUG **************** //
    // Debug rendering only; none of the methods above use these.

    /** Spaces added per tree level in {@link #toDetailedString()}. */
    private int indent = 6;

    public void setIndent(int newIndent) {
        this.indent = newIndent;
    }

    @Override
    public String toString() {
        return toSimpleString();
    }

    /** Renders the tree sideways: right subtree above each node, left subtree below. */
    public String toDetailedString() {
        return toDetailedString(1, 0);
    }

    private String toDetailedString(int k, int sp) {
        if (k >= N) return indent(sp) + data[k];
        String s = "";
        s += toDetailedString(k << 1 | 1, sp + indent);
        s += "\n";
        s += indent(sp) + data[k];
        s += "\n";
        s += toDetailedString(k << 1 | 0, sp + indent);
        return s;
    }

    private static String indent(int n) {
        StringBuilder sb = new StringBuilder();
        while (n --> 0) sb.append(' ');
        return sb.toString();
    }

    /** Lists all {@code N} leaves, including padding positions {@code >= MAX}. */
    public String toSimpleString() {
        StringBuilder sb = new StringBuilder();
        sb.append('[');
        for (int i = 0; i < N; i++) {
            sb.append(data[i + N]);
            if (i < N - 1) sb.append(',').append(' ');
        }
        sb.append(']');
        return sb.toString();
    }
}

/**
 * Multiset of ints kept in key order on top of the randomized BST base class. Every inserted key
 * gets its own node (inserted with a {@code null} value); ranks are 0-based.
 */
class IntOrderedMultiSet extends IntRandomizedBinarySearchTree<Object> {
    Node root;

    /** Key at 0-based rank {@code k} inside subtree {@code t}, found by walking down from {@code t}. */
    int kthElement(Node t, int k) {
        Node cur = t;
        int rest = k;
        while (true) {
            final int leftSize = size(cur.l);
            if (rest == leftSize) {
                return cur.key;
            }
            if (rest < leftSize) {
                cur = cur.l;
            } else {
                rest -= leftSize + 1;
                cur = cur.r;
            }
        }
    }

    /** Key at 0-based rank {@code k}; throws IndexOutOfBoundsException unless 0 <= k < size(). */
    public int kthElement(int k) {
        if (0 <= k && k < size()) {
            return kthElement(root, k);
        }
        throw new IndexOutOfBoundsException();
    }

    /** Inserts {@code key} after all existing copies of it; returns the resulting subtree root. */
    Node insertKey(Node t, int key) {
        final int position = leqCount(t, key);
        return insert(t, position, key, null);
    }

    public void insertKey(int key) {
        root = insertKey(root, key);
    }

    /** Removes the last copy (in order) of {@code key} if any; returns the resulting subtree root. */
    Node eraseKey(Node t, int key) {
        final int leq = leqCount(t, key);
        if (leq == ltCount(t, key)) {
            return t;
        }
        return super.erase(t, leq - 1);
    }

    public void eraseKey(int key) {
        root = eraseKey(root, key);
    }

    /** Number of copies of {@code key} in subtree {@code t}. */
    int count(Node t, int key) {
        final int atMost = leqCount(t, key);
        final int below = ltCount(t, key);
        return atMost - below;
    }

    public int count(int key) {
        return count(root, key);
    }

    /** Number of keys {@code <= key} in subtree {@code t}. */
    int leqCount(Node t, int key) {
        int acc = 0;
        Node cur = t;
        while (cur != null) {
            if (key < cur.key) {
                cur = cur.l;
            } else {
                acc += size(cur.l) + 1;
                cur = cur.r;
            }
        }
        return acc;
    }

    public int leqCount(int key) {
        return leqCount(root, key);
    }

    /** Number of keys {@code < key} in subtree {@code t}. */
    int ltCount(Node t, int key) {
        int acc = 0;
        Node cur = t;
        while (cur != null) {
            if (key <= cur.key) {
                cur = cur.l;
            } else {
                acc += size(cur.l) + 1;
                cur = cur.r;
            }
        }
        return acc;
    }

    public int ltCount(int key) {
        return ltCount(root, key);
    }

    public int size() {
        return size(root);
    }

    public void clear() {
        this.root = null;
    }
}


/**
 * Coordinate compression for {@code long} values. Values are ranked in ascending order,
 * starting from 1 when {@code oneIndexed} is set and from 0 otherwise.
 * NOTE(review): duplicates are presumably removed by {@code LongArrayFactory.unique} — confirm.
 */
class LongCompress {
    private final int n;
    private final java.util.HashMap<Long, Integer> compressMap;
    private final int[] compressed;
    private final long[] sorted;

    public LongCompress(long[] a, boolean oneIndexed) {
        this.n = a.length;
        this.sorted = sort(LongArrayFactory.unique(a));
        final int base = oneIndexed ? 1 : 0;
        this.compressMap = new java.util.HashMap<>();
        for (int rank = 0; rank < sorted.length; rank++) {
            compressMap.put(sorted[rank], rank + base);
        }
        this.compressed = IntArrayFactory.setAll(n, i -> compressMap.get(a[i]));
    }

    public LongCompress(java.util.Collection<Long> collection, boolean oneIndexed) {
        this(LongArrayFactory.toArray(collection), oneIndexed);
    }

    /** Number of distinct ranks. */
    public int compressedSize() {
        return sorted.length;
    }

    /** Rank of every element of the original input, in input order. */
    public int[] compressed() {
        return compressed;
    }

    /** Value stored at 0-based position {@code i} of the sorted table. */
    public long restore(int i) {
        return sorted[i];
    }

    /** Rank of {@code x}; unboxing throws NullPointerException if {@code x} was not in the input. */
    public int compress(long x) {
        final Integer rank = compressMap.get(x);
        return rank.intValue();
    }

    /**
     * Sorts {@code a} ascending in place with an LSD radix sort over eight byte-wide digits,
     * ping-ponging between {@code a} and a scratch array. After the even number of passes the
     * result is back in {@code a}, which is returned.
     */
    private static long[] sort(long[] a) {
        long[] from = a;
        long[] to = new long[a.length];
        for (int shift = 0; shift < Long.SIZE; shift += Byte.SIZE) {
            final int[] start = new int[0x101];
            for (long v : from) {
                start[digit(v, shift) + 1]++;
            }
            for (int d = 0; d < 0x100; d++) {
                start[d + 1] += start[d];
            }
            for (long v : from) {
                to[start[digit(v, shift)]++] = v;
            }
            final long[] tmp = from;
            from = to;
            to = tmp;
        }
        return from;
    }

    /** Byte of {@code v} at {@code shift}; the top byte's sign bit is flipped so negatives sort first. */
    private static int digit(long v, int shift) {
        final int d = (int) (v >>> shift) & 0xff;
        return shift == 56 ? d ^ 0x80 : d;
    }
}

/**
 * @author https://atcoder.jp/users/suisen
 *
 * Whitespace-delimited token reader over a raw byte stream with its own 16 KiB input buffer.
 * Bytes outside the printable ASCII range (33..126) act as separators.
 */
class FastScanner implements AutoCloseable {
    /** {@code Long.MAX_VALUE / 10}: a magnitude at or above this needs an overflow check. */
    private static final long LONG_OVERFLOW_THRESHOLD = Long.MAX_VALUE / 10;

    private final ByteBuffer tokenBuf = new ByteBuffer();
    private final java.io.InputStream in;
    private final byte[] rawBuf = new byte[1 << 14];
    private int ptr = 0;
    private int buflen = 0;

    public FastScanner(java.io.InputStream in) {
        this.in = in;
    }

    /** Reads from the standard-input file descriptor. */
    public FastScanner() {
        this(new java.io.FileInputStream(java.io.FileDescriptor.in));
    }

    /** Next raw byte; reaching end of input throws UncheckedIOException wrapping EOFException. */
    private final int readByte() {
        if (ptr < buflen) return rawBuf[ptr++];
        ptr = 0;
        try {
            buflen = in.read(rawBuf);
            if (buflen > 0) {
                return rawBuf[ptr++];
            } else {
                throw new java.io.EOFException();
            }
        } catch (java.io.IOException e) {
            throw new java.io.UncheckedIOException(e);
        }
    }

    /** Next raw byte, or -1 at end of input. */
    private final int readByteUnsafe() {
        if (ptr < buflen) return rawBuf[ptr++];
        ptr = 0;
        try {
            buflen = in.read(rawBuf);
            if (buflen > 0) {
                return rawBuf[ptr++];
            } else {
                return -1;
            }
        } catch (java.io.IOException e) {
            throw new java.io.UncheckedIOException(e);
        }
    }

    /** Skips separators and returns the first printable byte; end of input is an error. */
    private final int skipUnprintableChars() {
        int b = readByte();
        while (b <= 32 || b >= 127) b = readByte();
        return b;
    }

    /** Loads the next token into {@code tokenBuf}, consuming the separator that ends it. */
    private final void loadToken() {
        tokenBuf.clear();
        for (int b = skipUnprintableChars(); 32 < b && b < 127; b = readByteUnsafe()) {
            tokenBuf.append(b);
        }
    }

    /** Whether another token exists; consumes leading separators but not the token itself. */
    public final boolean hasNext() {
        for (int b = readByteUnsafe(); b <= 32 || b >= 127; b = readByteUnsafe()) {
            if (b == -1) return false;
        }
        --ptr;
        return true;
    }

    public final String next() {
        loadToken();
        return new String(tokenBuf.getRawBuf(), 0, tokenBuf.size());
    }

    /** Reads up to (not including) the next '\n' or end of input. */
    public final String nextLine() {
        tokenBuf.clear();
        for (int b = readByte(); b != '\n'; b = readByteUnsafe()) {
            if (b == -1) break;
            tokenBuf.append(b);
        }
        return new String(tokenBuf.getRawBuf(), 0, tokenBuf.size());
    }

    public final char nextChar() {
        return (char) skipUnprintableChars();
    }

    public final char[] nextChars() {
        loadToken();
        return tokenBuf.toCharArray();
    }

    /**
     * Parses the next token as a decimal {@code long} in
     * {@code [Long.MIN_VALUE, Long.MAX_VALUE]} (-9223372036854775808 .. 9223372036854775807).
     *
     * @throws NumberFormatException if the token is not an optionally-signed digit string
     * @throws ArithmeticException   if the value does not fit in a {@code long}
     */
    public final long nextLong() {
        long n = 0;
        boolean isNegative = false;
        int b = skipUnprintableChars();
        if (b == '-') {
            isNegative = true;
            b = readByteUnsafe();
        }
        if (b < '0' || '9' < b) throw new NumberFormatException();
        while ('0' <= b && b <= '9') {
            if (n >= LONG_OVERFLOW_THRESHOLD) {
                // One more digit would reach or exceed the long range: it must be the last one.
                if (n > LONG_OVERFLOW_THRESHOLD) {
                    throw new ArithmeticException("long overflow");
                }
                if (isNegative) {
                    // The largest negative magnitude ends in 8.
                    if (b >= '9') {
                        throw new ArithmeticException("long overflow");
                    }
                    // Build the negative value directly; -(n * 10 + d) overflows for Long.MIN_VALUE.
                    n = -n * 10 - (b - '0');
                    b = readByteUnsafe();
                    if ('0' <= b && b <= '9') {
                        throw new ArithmeticException("long overflow");
                    } else if (b <= 32 || b >= 127) {
                        return n;
                    } else {
                        throw new NumberFormatException();
                    }
                } else {
                    // The largest positive value ends in 7.
                    if (b >= '8') {
                        throw new ArithmeticException("long overflow");
                    }
                    n = n * 10 + b - '0';
                    b = readByteUnsafe();
                    if ('0' <= b && b <= '9') {
                        throw new ArithmeticException("long overflow");
                    } else if (b <= 32 || b >= 127) {
                        return n;
                    } else {
                        throw new NumberFormatException();
                    }
                }
            }
            n = n * 10 + b - '0';
            b = readByteUnsafe();
        }
        if (b <= 32 || b >= 127) return isNegative ? -n : n;
        throw new NumberFormatException();
    }

    /** Parses the next token as an {@code int}; ArithmeticException if it does not fit. */
    public final int nextInt() {
        long value = nextLong();
        if ((int) value != value) {
            throw new ArithmeticException("int overflow");
        }
        return (int) value;
    }

    public final double nextDouble() {
        return Double.parseDouble(next());
    }

    public final void close() {
        try {
            in.close();
        } catch (java.io.IOException e) {
            throw new java.io.UncheckedIOException(e);
        }
    }

    /** Growable byte array used to accumulate one token (doubles its capacity when full). */
    private static final class ByteBuffer {
        private static final int DEFAULT_BUF_SIZE = 1 << 12;
        private byte[] buf;
        private int ptr = 0;
        private ByteBuffer(int capacity) {
            this.buf = new byte[capacity];
        }
        private ByteBuffer() {
            this(DEFAULT_BUF_SIZE);
        }
        private ByteBuffer append(int b) {
            if (ptr == buf.length) {
                int newLength = buf.length << 1;
                byte[] newBuf = new byte[newLength];
                System.arraycopy(buf, 0, newBuf, 0, buf.length);
                buf = newBuf;
            }
            buf[ptr++] = (byte) b;
            return this;
        }
        private char[] toCharArray() {
            char[] chs = new char[ptr];
            for (int i = 0; i < ptr; i++) {
                chs[i] = (char) buf[i];
            }
            return chs;
        }
        private byte[] getRawBuf() {
            return buf;
        }
        private int size() {
            return ptr;
        }
        private void clear() {
            ptr = 0;
        }
    }
}


/**
 * Arithmetic modulo {@code Const.MOD998244353} (presumably 998244353). The class has no mutable
 * state and is used through {@link #INSTANCE}.
 */
final class ModArithmetic998244353 extends ModArithmetic {
    public static final ModArithmetic INSTANCE = new ModArithmetic998244353();

    private ModArithmetic998244353() {
    }

    public static final long MOD = Const.MOD998244353;

    public long getMod() {
        return MOD;
    }

    /** Reduces any {@code a} into {@code [0, MOD)}. */
    public long mod(long a) {
        final long r = a % MOD;
        if (r < 0) {
            return r + MOD;
        }
        return r;
    }

    /** {@code a + b} with one conditional subtraction; exact only when {@code a + b < 2 * MOD}. */
    public long add(long a, long b) {
        final long s = a + b;
        return s < MOD ? s : s - MOD;
    }

    /** {@code a - b} with one conditional addition; exact only when {@code a - b >= -MOD}. */
    public long sub(long a, long b) {
        final long d = a - b;
        return d >= 0 ? d : d + MOD;
    }

    /** {@code (a * b) % MOD}; operands are expected to be residues so the product fits a long. */
    public long mul(long a, long b) {
        final long p = a * b;
        return p % MOD;
    }
}

/**
 * Mutable pair of a primitive {@code int} key and a value of type {@code V}.
 *
 * <p>Two entries are equal when their keys match and their values are equal (or both
 * {@code null}); the hash code is {@code key ^ value.hashCode()} ({@code 0} for a null value).
 * Both depend on the mutable public fields, so an entry must not be mutated while it is stored
 * in a hash-based collection.
 */
class IntEntry<V> {
    public int key;
    public V val;

    public IntEntry(int key, V val) {
        this.key = key;
        this.val = val;
    }

    public int getKey() {
        return key;
    }

    public V getValue() {
        return val;
    }

    /** Replaces the value and returns the one it replaced. */
    public V setValue(V val) {
        final V replaced = this.val;
        this.val = val;
        return replaced;
    }

    @Override
    public boolean equals(Object o) {
        if (o instanceof IntEntry) {
            final IntEntry<?> other = (IntEntry<?>) o;
            if (key != other.getKey()) {
                return false;
            }
            return val != null ? val.equals(other.val) : other.val == null;
        }
        return false;
    }

    @Override
    public int hashCode() {
        return key ^ (val != null ? val.hashCode() : 0);
    }

    @Override
    public String toString() {
        return new StringBuilder().append(key).append('=').append(val).toString();
    }
}

/**
 * Small pseudo-random generator implementing Marsaglia's xorshift128.
 *
 * <p>Every instance starts from the same fixed seed and therefore yields the same sequence, so it
 * must not be used where unpredictability matters.
 */
final class Random {
    private static final double DOUBLE_UNIT = 0x1.0p-53;
    private int x = 123456789;
    private int y = 362436069;
    private int z = 521288629;
    private int w = 88675123;

    /** Returns the next 32 bits of output as a (possibly negative) {@code int}. */
    public int nextInt() {
        int t = x ^ (x << 11);
        x = y;
        y = z;
        z = w;
        // The reference algorithm works on uint32_t, so right shifts must be logical (>>>);
        // arithmetic >> would smear the sign bit and produce a different, unanalysed generator.
        return w = (w ^ (w >>> 19)) ^ (t ^ (t >>> 8));
    }

    /** Returns 64 pseudo-random bits built from two consecutive {@link #nextInt()} calls. */
    public long nextLong() {
        return ((long) (nextInt()) << 32) + nextInt();
    }

    /**
     * Returns a pseudo-random value in {@code [0, bound)}.
     *
     * @throws IllegalArgumentException if {@code bound <= 0}
     */
    public int nextInt(int bound) {
        if (bound <= 0) {
            throw new IllegalArgumentException("bound must be positive: " + bound);
        }
        // floorMod, not %: nextInt() is signed, and % would yield values in (-bound, bound).
        return Math.floorMod(nextInt(), bound);
    }

    /** Returns a pseudo-random boolean taken from the lowest output bit. */
    public boolean nextBoolean() {
        return (nextInt() & 1) == 0;
    }

    /** Returns a pseudo-random double in {@code [0, 1)} built from 53 random bits. */
    public double nextDouble() {
        return (((long) (next(26)) << 27) + next(27)) * DOUBLE_UNIT;
    }

    /** Returns the low {@code bits} bits of the next output. */
    private int next(int bits) {
        int mask = bits == 32 ? -1 : (1 << bits) - 1;
        return nextInt() & mask;
    }
}



/**
 * @author https://atcoder.jp/users/suisen
 *
 * Factory and transformation helpers that return {@code long} arrays. The helpers are
 * non-destructive: they never modify the arrays passed to them.
 */
final class LongArrayFactory {
    private LongArrayFactory() {}

    /** Returns a new array of length {@code n} whose elements all equal {@code init}. */
    public static long[] filled(int n, long init) {
        long[] ret = new long[n];
        Arrays.fill(ret, init);
        return ret;
    }

    /** Returns a new {@code n x m} array whose elements all equal {@code init}. */
    public static long[][] filled(int n, int m, long init) {
        long[][] ret = new long[n][m];
        for (long[] row : ret) Arrays.fill(row, init);
        return ret;
    }

    /** Returns a new array of length {@code n} with {@code a[i] = f(i)}. */
    public static long[] setAll(int n, IntToLongFunction f) {
        long[] a = new long[n];
        Arrays.setAll(a, f);
        return a;
    }

    /** Returns a new {@code n x m} array with {@code a[i][j] = f(i, j)}, built row by row. */
    public static long[][] setAll(int n, int m, IntToLongBiFunction f) {
        // Only the outer array is allocated here: every row is created by setAll(m, ...) below,
        // so pre-allocating rows (new long[n][m]) would just produce garbage.
        long[][] a = new long[n][];
        for (int i = 0; i < n; i++) {
            a[i] = setAll(m, f.curry(i));
        }
        return a;
    }

    /** Copies {@code collection} in iteration order, converting each element with {@link Number#longValue()}. */
    public static long[] toArray(Collection<? extends Number> collection) {
        long[] ret = new long[collection.size()];
        int i = 0;
        for (Number e : collection) ret[i++] = e.longValue();
        return ret;
    }

    /** Returns the distinct values of {@code a} in order of first occurrence. */
    public static long[] unique(long[] a) {
        HashSet<Long> set = new HashSet<>();
        for (long e : a) set.add(e);
        long[] b = new long[set.size()];
        int j = 0;
        // remove() returns true only the first time a value is seen.
        for (long e : a) if (set.remove(e)) b[j++] = e;
        return b;
    }

    /**
     * Returns the transpose {@code t[i][j] = a[j][i]}. The column count is taken from
     * {@code a[0]}, so all rows are assumed to have the same length. An empty input yields an
     * empty result.
     */
    public static long[][] transpose(long[][] a) {
        if (a.length == 0) return new long[0][];
        return setAll(a[0].length, a.length, (i, j) -> a[j][i]);
    }

    /** Returns a new array with {@code f} applied to every element of {@code a}. */
    public static long[] map(long[] a, LongUnaryOperator f) {
        return setAll(a.length, i -> f.applyAsLong(a[i]));
    }

    /**
     * Returns the elements of {@code a} accepted by {@code p}, in order. {@code p} is evaluated
     * twice per element (count pass, then copy pass), so it should be deterministic.
     */
    public static long[] filter(long[] a, LongPredicate p) {
        int m = 0;
        for (long e : a) if (p.test(e)) m++;
        long[] res = new long[m];
        int j = 0;
        for (long e : a) if (p.test(e)) res[j++] = e;
        return res;
    }
}


/**
 * @author https://atcoder.jp/users/suisen
 *
 * Factory and transformation helpers that return {@code int} arrays. The helpers never modify
 * the arrays passed to them.
 */
final class IntArrayFactory {
    private IntArrayFactory() {}

    /** Returns a new array of length {@code n} whose elements all equal {@code init}. */
    public static int[] filled(int n, int init) {
        int[] ret = new int[n];
        Arrays.fill(ret, init);
        return ret;
    }

    /** Returns a new {@code n x m} array whose elements all equal {@code init}. */
    public static int[][] filled(int n, int m, int init) {
        int[][] ret = new int[n][m];
        for (int[] row : ret) Arrays.fill(row, init);
        return ret;
    }

    /** Returns a new array of length {@code n} with {@code a[i] = f(i)}. */
    public static int[] setAll(int n, IntUnaryOperator f) {
        int[] a = new int[n];
        Arrays.setAll(a, f);
        return a;
    }

    /** Returns a new {@code n x m} array with {@code a[i][j] = f(i, j)}, built row by row. */
    public static int[][] setAll(int n, int m, IntBinaryOperator f) {
        // Only the outer array is allocated here: every row is created by setAll(m, ...) below,
        // so pre-allocating rows (new int[n][m]) would just produce garbage.
        int[][] a = new int[n][];
        for (int i = 0; i < n; i++) {
            final int row = i;
            a[i] = setAll(m, j -> f.applyAsInt(row, j));
        }
        return a;
    }

    /** Copies {@code collection} in iteration order, converting each element with {@link Number#intValue()}. */
    public static int[] toArray(Collection<? extends Number> collection) {
        int[] ret = new int[collection.size()];
        int i = 0;
        for (Number e : collection) ret[i++] = e.intValue();
        return ret;
    }

    /** Returns the distinct values of {@code a} in order of first occurrence. */
    public static int[] unique(int[] a) {
        HashSet<Integer> set = new HashSet<>();
        for (int e : a) set.add(e);
        int[] b = new int[set.size()];
        int j = 0;
        // remove() returns true only the first time a value is seen.
        for (int e : a) if (set.remove(e)) b[j++] = e;
        return b;
    }

    /**
     * Returns the transpose {@code t[i][j] = a[j][i]}. The column count is taken from
     * {@code a[0]}, so all rows are assumed to have the same length. An empty input yields an
     * empty result.
     */
    public static int[][] transpose(int[][] a) {
        if (a.length == 0) return new int[0][];
        return setAll(a[0].length, a.length, (i, j) -> a[j][i]);
    }

    /** Returns a new array with {@code f} applied to every element of {@code a}. */
    public static int[] map(int[] a, IntUnaryOperator f) {
        return setAll(a.length, i -> f.applyAsInt(a[i]));
    }

    /**
     * Returns {@code cnt} with {@code cnt[v]} = number of occurrences of {@code v} in {@code a}.
     * Every element must lie in {@code [0, max]}; otherwise an
     * {@link ArrayIndexOutOfBoundsException} is thrown.
     */
    public static int[] histogram(int[] a, int max) {
        int[] ret = new int[max + 1];
        for (int e : a) ret[e]++;
        return ret;
    }

    /**
     * Returns the elements of {@code a} accepted by {@code p}, in order. {@code p} is evaluated
     * twice per element (count pass, then copy pass), so it should be deterministic.
     */
    public static int[] filter(int[] a, IntPredicate p) {
        int m = 0;
        for (int e : a) if (p.test(e)) m++;
        int[] res = new int[m];
        int j = 0;
        for (int e : a) if (p.test(e)) res[j++] = e;
        return res;
    }

    /** Returns, in increasing order, every {@code i} in {@code [l, r)} accepted by {@code p} (evaluated twice per index). */
    public static int[] filterOfRange(int l, int r, IntPredicate p) {
        int m = 0;
        for (int i = l; i < r; i++) if (p.test(i)) m++;
        int[] res = new int[m];
        for (int i = l, j = 0; i < r; i++) if (p.test(i)) res[j++] = i;
        return res;
    }

    /** Equivalent to {@code filterOfRange(0, n, p)}. */
    public static int[] filterOfRange(int n, IntPredicate p) {
        return filterOfRange(0, n, p);
    }
}

/**
 * Shared constants: "infinity" sentinels, floating-point tolerances, prime moduli and grid
 * direction vectors.
 */
class Const {
    // "Infinity" sentinels. LINF + LINF and IINF + IINF still fit in their types.
    public static final long   LINF   = 1L << 59;
    public static final int    IINF   = (1 << 30) - 1;
    public static final double DINF   = 1e150;

    // Floating-point comparison tolerances, from tightest to loosest.
    public static final double SMALL  = 1e-12;
    public static final double MEDIUM = 1e-9;
    public static final double LARGE  = 1e-6;

    // Prime moduli.
    public static final long MOD1000000007 = 1_000_000_007L;
    public static final long MOD998244353  =   998_244_353L;
    public static final long MOD754974721  =   754_974_721L;
    public static final long MOD167772161  =   167_772_161L;
    public static final long MOD469762049  =   469_762_049L;

    // Eight neighbours: the four axis directions first, then the four diagonals.
    public static final int[] dx8 = { 1, 0, -1,  0, 1, -1, -1,  1 };
    public static final int[] dy8 = { 0, 1,  0, -1, 1,  1, -1, -1 };
    // Four axis-aligned neighbours, in the same order as the first half of dx8/dy8.
    public static final int[] dx4 = { 1, 0, -1,  0 };
    public static final int[] dy4 = { 0, 1,  0, -1 };
}


/**
 * Randomized binary search tree whose elements are ordered by position (0-based index), not by
 * key: {@link #insert} and {@link #erase} address elements by index {@code k}, and each element
 * carries an int key and a value through {@link IntEntry}.
 *
 * <p>The tree object stores no root; every operation takes a root and returns the new root, with
 * {@code null} denoting the empty tree.
 */
class IntRandomizedBinarySearchTree<V> {
    // Output slots of split(): splitLeft receives the first k elements, splitRight the rest.
    // Each split() call overwrites both, so read them before splitting again.
    private Node splitLeft, splitRight;
    private Random rnd = new Random();
    /**
     * Concatenates the sequence {@code l} followed by {@code r} and returns the new root.
     * The root of {@code l} becomes the overall root when
     * {@code rnd.nextInt(l.size + r.size) < l.size}, otherwise the root of {@code r} does.
     * NOTE(review): this is size-proportional (the RBST balancing rule) only if
     * {@code Random.nextInt(int)} returns values in {@code [0, bound)}; verify that guarantee.
     */
    Node merge(Node l, Node r) {
        if (l == null) return r;
        if (r == null) return l;
        if (rnd.nextInt(l.size + r.size) < l.size) {
            l.r = merge(l.r, r);
            return l.update();
        } else {
            r.l = merge(l, r.l);
            return r.update();
        }
    }
    /**
     * Splits the tree rooted at {@code x} into its first {@code k} elements ({@code splitLeft})
     * and the remaining elements ({@code splitRight}). Results are returned through those fields.
     *
     * @throws IndexOutOfBoundsException if {@code k} is outside {@code [0, size(x)]}
     */
    void split(Node x, int k) {
        if (k < 0 || k > size(x)) {
            throw new IndexOutOfBoundsException(
                String.format("index %d is out of bounds for the length of %d", k, size(x))
            );
        }
        if (x == null) {
            splitLeft = null;
            splitRight = null;
        } else if (k <= size(x.l)) {
            // The cut lies inside the left subtree: x and its right subtree go to the right part.
            split(x.l, k);
            x.l = splitRight;
            splitRight = x.update();
        } else {
            // The cut lies inside the right subtree: x and its left subtree go to the left part.
            split(x.r, k - size(x.l) - 1);
            x.r = splitLeft;
            splitLeft = x.update();
        }
    }
    /**
     * Inserts a new element ({@code key}, {@code val}) so that it ends up at index {@code k} and
     * returns the new root.
     *
     * @throws IndexOutOfBoundsException if {@code k} is outside {@code [0, size(t)]}
     */
    Node insert(Node t, int k, int key, V val) {
        split(t, k);
        return merge(merge(splitLeft, new Node(key, val)), splitRight);
    }
    /**
     * Removes the element at index {@code k} and returns the new root.
     *
     * @throws IndexOutOfBoundsException if {@code k} is outside {@code [0, size(t) - 1]}
     */
    Node erase(Node t, int k) {
        split(t, k);
        Node l = splitLeft;
        // Peel off the single element at index k; the resulting splitLeft (the removed node) is dropped.
        split(splitRight, 1);
        return merge(l, splitRight);
    }
    /** Returns the number of elements in the subtree rooted at {@code nd} (0 for {@code null}). */
    int size(Node nd) {return nd == null ? 0 : nd.size;}
    /** Tree node: an {@link IntEntry} plus child links and the cached size of its subtree. */
    class Node extends IntEntry<V> {
        Node l, r;
        int size;
        private Node(int key, V val) {super(key, val); this.size = 1;}
        // Recomputes the cached subtree size after the children changed; returns this for chaining.
        private Node update() {
            size = size(l) + size(r) + 1;
            return this;
        }
    }
}

/**
 * @author https://atcoder.jp/users/suisen
 *
 * Arithmetic modulo {@link #getMod()}.
 *
 * <p>Subclasses implement {@link #getMod()}, {@link #mod(long)} and the two-argument
 * {@code add}/{@code sub}/{@code mul}; every other operation is derived from those. Give the
 * two-argument primitives operands already reduced into {@code [0, getMod())} unless the
 * subclass documents otherwise. Methods below that reduce their inputs with {@link #mod(long)}
 * first say so.
 */
abstract class ModArithmetic {
    /** Returns the modulus. */
    public abstract long getMod();
    /**
     * Returns {@code a} reduced into {@code [0, getMod())}. Must accept any {@code long},
     * including negative values, because {@link #inv} and {@link #pow} pass unreduced inputs.
     */
    public abstract long mod(long a);
    /** Returns {@code (a + b) mod getMod()}. */
    public abstract long add(long a, long b);
    /** Returns {@code (a - b) mod getMod()}. */
    public abstract long sub(long a, long b);
    /** Returns {@code (a * b) mod getMod()}. */
    public abstract long mul(long a, long b);

    /**
     * Returns {@code a * b^-1} modulo {@link #getMod()}.
     *
     * @throws ArithmeticException if {@code b} has no modular inverse
     */
    public final long div(long a, long b) {return mul(a, inv(b));}

    /**
     * Returns the multiplicative inverse of {@code a} (reduced first with {@link #mod(long)}),
     * computed with the extended Euclidean algorithm.
     *
     * @throws ArithmeticException if {@code gcd(a, getMod()) != 1}, in particular when
     *         {@code a} is a multiple of the modulus
     */
    public final long inv(long a) {
        a = mod(a);
        long b = getMod();
        // Extended Euclid: keeps a == u * a0 and b == v * a0 (mod getMod()) for the reduced input a0.
        long u = 1, v = 0;
        while (b >= 1) {
            long t = a / b;
            a -= t * b;
            long tmp1 = a; a = b; b = tmp1;
            u -= t * v;
            long tmp2 = u; u = v; v = tmp2;
        }
        // a now holds gcd(a0, getMod()).
        if (a != 1) throw new ArithmeticException("divide by zero");
        return mod(u);
    }

    /** Returns {@code a * b + c}, using {@link #mul} and {@link #add}. */
    public final long fma(long a, long b, long c) {return add(mul(a, b), c);}

    /**
     * Returns {@code a^b} modulo {@link #getMod()} by binary exponentiation; {@code a} is reduced
     * first. Returns 1 when {@code b <= 0}.
     */
    public final long pow(long a, long b) {
        long pow = 1;
        for (a = mod(a); b > 0; b >>= 1, a = mul(a, a)) {
            if ((b & 1) == 1) pow = mul(pow, a);
        }
        return pow;
    }

    // Multi-operand sums: the raw (unreduced) sum is taken first and reduced once, so it must
    // not overflow a long.
    public final long add(long a, long b, long c) {return mod(a + b + c);}
    public final long add(long a, long b, long c, long d) {return mod(a + b + c + d);}
    public final long add(long a, long b, long c, long d, long e) {return mod(a + b + c + d + e);}
    public final long add(long a, long b, long c, long d, long e, long f) {return mod(a + b + c + d + e + f);}
    public final long add(long a, long b, long c, long d, long e, long f, long g) {return mod(a + b + c + d + e + f + g);}
    public final long add(long a, long b, long c, long d, long e, long f, long g, long h) {return mod(a + b + c + d + e + f + g + h);}
    public final long add(long... xs) {long s = 0; for (long x : xs) s += x; return mod(s);}
    // Multi-operand products, evaluated as a chain of two-argument mul calls.
    public final long mul(long a, long b, long c) {return mul(a, mul(b, c));}
    public final long mul(long a, long b, long c, long d) {return mul(a, mul(b, mul(c, d)));}
    public final long mul(long a, long b, long c, long d, long e) {return mul(a, mul(b, mul(c, mul(d, e))));}
    public final long mul(long a, long b, long c, long d, long e, long f) {return mul(a, mul(b, mul(c, mul(d, mul(e, f)))));}
    public final long mul(long a, long b, long c, long d, long e, long f, long g) {return mul(a, mul(b, mul(c, mul(d, mul(e, mul(f, g))))));}
    public final long mul(long a, long b, long c, long d, long e, long f, long g, long h) {return mul(a, mul(b, mul(c, mul(d, mul(e, mul(f, mul(g, h)))))));}
    public final long mul(long... xs) {long s = 1; for (long x : xs) s = mul(s, x); return s;}

    /**
     * Returns a square root of {@code a} modulo {@link #getMod()}, or empty when {@code a} is a
     * quadratic non-residue. {@code a} is reduced first. Assumes the modulus is prime (Euler's
     * criterion is used to test for residues).
     */
    public final java.util.OptionalLong sqrt(long a) {
        a = mod(a);
        if (a == 0) return java.util.OptionalLong.of(0);
        if (a == 1) return java.util.OptionalLong.of(1);
        long p = getMod();
        // Euler's criterion: a is a quadratic residue iff a^((p-1)/2) == 1.
        if (pow(a, (p - 1) >> 1) != 1) {
            return java.util.OptionalLong.empty();
        }
        // p == 3 (mod 4): a^((p+1)/4) is a root.
        if ((p & 3) == 3) {
            return java.util.OptionalLong.of(pow(a, (p + 1) >> 2));
        }
        // p == 5 (mod 8): closed-form root, corrected by 2^((p-1)/4) when a^((p-1)/4) != 1.
        if ((p & 7) == 5) {
            if (pow(a, (p - 1) >> 2) == 1) {
                return java.util.OptionalLong.of(pow(a, (p + 3) >> 3));
            } else {
                return java.util.OptionalLong.of(mul(pow(2, (p - 1) >> 2), pow(a, (p + 3) >> 3)));
            }
        }
        // General case, Tonelli-Shanks: write p - 1 = Q * 2^S with Q odd.
        long S = 0, Q = p - 1;
        while ((Q & 1) == 0) {
            ++S;
            Q >>= 1;
        }
        // Find a quadratic non-residue z.
        long z = 1;
        while (pow(z, (p - 1) >> 1) != p - 1) ++z;
        long c = pow(z, Q), R = pow(a, (Q + 1) / 2), t = pow(a, Q), M = S;
        while (t != 1) {
            // Least i with t^(2^i) == 1.
            long cur = t;
            int i;
            for (i = 1; i < M; i++) {
                cur = mul(cur, cur);
                if (cur == 1) break;
            }
            long b = pow(c, 1L << (M - i - 1));
            c = mul(b, b); R = mul(R, b); t = mul(t, b, b); M = i;
        }
        return java.util.OptionalLong.of(R);
    }

    // ---- array operations ----

    /**
     * Returns {@code invs} with {@code invs[i] = i^-1} for {@code 1 <= i <= n} ({@code invs[0]}
     * is left 0), via the recurrence {@code inv(i) = -(p / i) * inv(p % i)}. Assumes a prime
     * modulus larger than {@code n}.
     */
    public final long[] rangeInv(int n) {
        final long MOD = getMod();
        long[] invs = new long[n + 1];
        if (n == 0) return invs;
        invs[1] = 1;
        for (int i = 2; i <= n; i++) {
            invs[i] = mul(MOD - MOD / i, invs[(int) (MOD % i)]);
        }
        return invs;
    }

    /**
     * Returns the element-wise inverses of {@code a} using prefix/suffix products and a single
     * {@link #inv} call.
     *
     * @throws ArithmeticException if any element is not invertible
     */
    public final long[] arrayInv(long[] a) {
        int n = a.length;
        long[] l = new long[n + 1];
        long[] r = new long[n + 1];
        l[0] = r[n] = 1;
        for (int i = 0; i < n; i++) l[i + 1] = mul(l[i], a[i    ]);
        for (int i = n; i > 0; i--) r[i - 1] = mul(r[i], a[i - 1]);
        long invAll = inv(l[n]);
        long[] invs = new long[n];
        for (int i = 0; i < n; i++) {
            invs[i] = mul(l[i], r[i + 1], invAll);
        }
        return invs;
    }

    /** Returns {@code fac} with {@code fac[i] = i!} for {@code 0 <= i <= n}. */
    public final long[] factorial(int n) {
        long[] ret = new long[n + 1];
        ret[0] = 1;
        for (int i = 1; i <= n; i++) ret[i] = mul(ret[i - 1], i);
        return ret;
    }

    /** Returns {@code facInv} with {@code facInv[i] = (i!)^-1} for {@code 0 <= i <= n}, using one inversion. */
    public final long[] factorialInv(int n) {
        long facN = 1;
        for (int i = 2; i <= n; i++) facN = mul(facN, i);
        long[] invs = new long[n + 1];
        invs[n] = inv(facN);
        for (int i = n; i > 0; i--) invs[i - 1] = mul(invs[i], i);
        return invs;
    }

    /** Returns {@code pows} with {@code pows[i] = a^i} for {@code 0 <= i <= n}; {@code a} is reduced first. */
    public final long[] rangePower(long a, int n) {
        a = mod(a);
        long[] pows = new long[n + 1];
        pows[0] = 1;
        for (int i = 1; i <= n; i++) pows[i] = mul(pows[i - 1], a);
        return pows;
    }

    /**
     * Returns {@code invs} with {@code invs[i] = a^-i} for {@code 0 <= i <= n}; {@code a} is
     * reduced first.
     *
     * @throws ArithmeticException if {@code a^n} is not invertible
     */
    public final long[] rangePowerInv(long a, int n) {
        a = mod(a);
        long[] invs = new long[n + 1];
        invs[n] = inv(pow(a, n));
        for (int i = n; i > 0; i--) invs[i - 1] = mul(invs[i], a);
        return invs;
    }

    // ---- combinatorial operations ----

    /** Returns Pascal's triangle as a jagged array: {@code comb[i][j] = C(i, j)} for {@code 0 <= j <= i <= n}. */
    public final long[][] combTable(int n) {
        long[][] comb = new long[n + 1][];
        for (int i = 0; i <= n; i++) {
            comb[i] = new long[i + 1];
            comb[i][0] = comb[i][i] = 1;
            for (int j = 1; j < i; j++) {
                comb[i][j] = add(comb[i - 1][j - 1], comb[i - 1][j]);
            }
        }
        return comb;
    }

    /** Returns {@code C(n, r)} from tables built by {@link #factorial}/{@link #factorialInv}; 0 when {@code r} is outside {@code [0, n]}. */
    public final long comb(int n, int r, long[] fac, long[] facInv) {
        return r < 0 || r > n ? 0 : mul(fac[n], facInv[r], facInv[n - r]);
    }

    /**
     * Returns {@code C(n, r)} using {@code min(r, n - r)} multiplications and one inversion.
     * {@code n} may exceed the modulus; 0 when {@code r} is outside {@code [0, n]}.
     *
     * @throws ArithmeticException if {@code min(r, n - r)!} is not invertible (for a prime
     *         modulus: when {@code min(r, n - r) >= getMod()})
     */
    public final long naiveComb(long n, long r) {
        if (r < 0 || r > n) return 0;
        r = Math.min(r, n - r);
        long num = 1, den = 1;
        // long counter: r is a long, and an int counter would wrap around before reaching it.
        for (long i = 0; i < r; i++) {
            // Reduce each factor first: n - i may be far above the modulus and overflow mul().
            num = mul(num, mod(n - i));
            den = mul(den, mod(i + 1));
        }
        return div(num, den);
    }

    /** Returns {@code P(n, r) = n! / (n - r)!} from precomputed tables; 0 when {@code r} is outside {@code [0, n]}. */
    public final long perm(int n, int r, long[] fac, long[] facInv) {
        return r < 0 || r > n ? 0 : mul(fac[n], facInv[n - r]);
    }

    /** Returns {@code P(n, r) = n (n-1) ... (n-r+1)} using {@code r} multiplications; {@code n} may exceed the modulus. */
    public final long naivePerm(long n, long r) {
        if (r < 0 || r > n) return 0;
        long res = 1;
        // Reduce each factor first: n - i may be far above the modulus and overflow mul().
        for (long i = 0; i < r; i++) res = mul(res, mod(n - i));
        return res;
    }
}
0