結果

問題 No.3408 1215 Segments
コンテスト
ユーザー kencho
提出日時 2025-11-25 06:13:30
言語 Java
(openjdk 23)
結果
AC  
実行時間 304 ms / 2,500 ms
コード長 17,131 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 4,825 ms
コンパイル使用メモリ 89,884 KB
実行使用メモリ 50,324 KB
最終ジャッジ日時 2025-12-14 23:30:15
合計ジャッジ時間 14,436 ms
ジャッジサーバーID
(参考情報)
judge5 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 46
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

import java.util.Arrays;

public class WriterCode {

    public static final int MOD998 = 998244353;
    public static final int MOD100 = 1000000007;

    /**
     * Longest digit string {@link #solve(long)} will try. A long has at most 19 digits and a valid
     * number always exists one digit longer than the input (see {@link #solve(long)}), so 20 suffices.
     */
    static final int MAX_LENGTH = 20;

    /** Best candidate found for the length currently being searched; reset before each length. */
    static String minForLen = null;

    /** Reads N from standard input and prints {@link #solve(long)} for it. */
    public static void main(String[] args) throws Exception {
        ContestScanner sc = new ContestScanner();
        ContestPrinter cp = new ContestPrinter();
        long n = sc.nextLong();
        cp.println(solve(n));
        cp.close();
    }

    /**
     * Returns the smallest number >= {@code n} whose digit multiset passes {@link #checkAndUpdate},
     * as a decimal string.
     *
     * <p>Lengths are tried in increasing order starting from the digit count of {@code n}; within a
     * length every digit multiset is enumerated and the smallest arrangement is kept. A string of
     * {@code l} eights always passes the check (every segment count equals {@code l}, so all
     * adjustments are 0), so a result is found at length {@code len(n) + 1} at the latest.
     *
     * @param n lower bound; negative values behave like 0 because results are non-negative digit strings
     * @return the smallest valid number >= n
     * @throws IllegalStateException if nothing is found within {@link #MAX_LENGTH} digits (not expected)
     */
    static String solve(long n) {
        // Clamp: a leading '-' would otherwise be used as a digit index in solveRecursive.
        String target = String.valueOf(Math.max(n, 0L));
        int[] counts = new int[10];
        for (int l = target.length(); l <= MAX_LENGTH; l++) {
            minForLen = null;
            dfs(0, l, counts, target, l);
            if (minForLen != null) {
                return minForLen;
            }
        }
        throw new IllegalStateException("no valid number with at most " + MAX_LENGTH + " digits for " + n);
    }

    /**
     * Enumerates every way to distribute {@code remaining} digits over the digit values
     * {@code digit..9}, writing the choice into {@code counts}, and calls {@link #checkAndUpdate}
     * once per complete multiset. {@code counts[digit..9]} are reset to 0 before returning.
     *
     * @param digit     first digit value still to be assigned a count
     * @param remaining number of digits not yet assigned
     * @param counts    scratch array of per-digit counts (length 10), mutated and restored
     * @param targetStr lower bound passed through to {@link #checkAndUpdate}
     * @param L         total length of the multisets being generated
     */
    static void dfs(int digit, int remaining, int[] counts, String targetStr, int L) {
        if (digit == 9) {
            // The last digit value takes whatever is left, so the counts always sum to the initial total.
            counts[9] = remaining;
            checkAndUpdate(counts, targetStr, L);
            counts[9] = 0; // backtrack
            return;
        }

        for (int i = 0; i <= remaining; i++) {
            counts[digit] = i;
            dfs(digit + 1, remaining - i, counts, targetStr, L);
        }
        counts[digit] = 0; // backtrack
    }

    /**
     * Tests whether the digit multiset {@code cnt} (with {@code L} digits in total) is valid and, if
     * so, offers its smallest arrangement that is >= {@code targetStr} to {@link #minForLen}.
     *
     * <p>Segments are indexed 0=top, 1=top-right, 2=bottom-right, 3=bottom, 4=bottom-left,
     * 5=top-left, 6=middle; {@code sK} counts the digits whose standard glyph uses segment K. The
     * values {@code one/zero/six/seven/nine} adjust those totals. Judging by which segments they add
     * or remove, they appear to count digits drawn with an alternative glyph: 1 on the left side,
     * 0 as a lower-case "o", 6 without its top bar, 7 with a top-left stroke, and 9 without its
     * bottom bar (NOTE(review): confirm against the problem statement). The multiset is accepted
     * when these adjustments make every segment total equal {@code base}.
     */
    static void checkAndUpdate(int[] cnt, String targetStr, int L) {
        // base = (s2 + s4) / 2 must be exact. s2 = L - cnt[2] and s4 = cnt[0] + cnt[2] + cnt[6] + cnt[8],
        // so s2 + s4 = L + cnt[0] + cnt[6] + cnt[8].
        if ((L + cnt[0] + cnt[6] + cnt[8]) % 2 != 0)
            return;

        // segcnt[0] (top): 0, 2, 3, 5, 6, 7, 8, 9
        int s0 = cnt[0] + cnt[2] + cnt[3] + cnt[5] + cnt[6] + cnt[7] + cnt[8] + cnt[9];
        // segcnt[1] (top-right): 0, 1, 2, 3, 4, 7, 8, 9
        int s1 = cnt[0] + cnt[1] + cnt[2] + cnt[3] + cnt[4] + cnt[7] + cnt[8] + cnt[9];
        // segcnt[2] (bottom-right): all except 2
        int s2 = L - cnt[2];
        // segcnt[3] (bottom): 0, 2, 3, 5, 6, 8, 9
        int s3 = cnt[0] + cnt[2] + cnt[3] + cnt[5] + cnt[6] + cnt[8] + cnt[9];
        // segcnt[4] (bottom-left): 0, 2, 6, 8
        int s4 = cnt[0] + cnt[2] + cnt[6] + cnt[8];
        // segcnt[5] (top-left): 0, 4, 5, 6, 8, 9
        int s5 = cnt[0] + cnt[4] + cnt[5] + cnt[6] + cnt[8] + cnt[9];
        // segcnt[6] (middle): 2, 3, 4, 5, 6, 8, 9
        int s6 = cnt[2] + cnt[3] + cnt[4] + cnt[5] + cnt[6] + cnt[8] + cnt[9];

        // Solve the equations below one unknown at a time.
        int base = (s2 + s4) / 2;
        int nine = s3 - base;
        int one = base - s4;
        int zero = (s1 - one) - base;
        int seven = base - (s5 - zero + one);
        int six = (s0 - zero) - base;

        // Each adjustment must be achievable with the digits actually present.
        if (one < 0 || zero < 0 || seven < 0 || six < 0 || nine < 0)
            return;
        if (one > cnt[1] || zero > cnt[0] || seven > cnt[7] || six > cnt[6] || nine > cnt[9])
            return;

        // The first six equations hold by construction of the values above and act as sanity
        // checks; the middle-segment equation is the only independent constraint.
        if (s0 - zero - six != base)
            return;
        if (s1 - one - zero != base)
            return;
        if (s2 - one != base)
            return;
        if (s3 - nine != base)
            return;
        if (s4 + one != base)
            return;
        if (s5 - zero + one + seven != base)
            return;
        if (s6 + zero != base)
            return;

        String candidate = findSmallest(cnt, targetStr);
        if (candidate == null) {
            return;
        }
        if (minForLen == null
                || candidate.length() < minForLen.length()
                || (candidate.length() == minForLen.length() && candidate.compareTo(minForLen) < 0)) {
            minForLen = candidate;
        }
    }

    /**
     * Returns the smallest arrangement of the digits in {@code counts} that is >= {@code targetStr}
     * and has no leading zero, or {@code null} if none exists. {@code counts} is restored before
     * returning.
     *
     * @param counts    per-digit counts (length 10)
     * @param targetStr decimal lower bound without sign; not longer than the digit total
     */
    static String findSmallest(int[] counts, String targetStr) {
        int totalDigits = 0;
        for (int c : counts)
            totalDigits += c;

        if (totalDigits > targetStr.length()) {
            // Any arrangement is longer than the bound, so use the smallest non-zero leading digit
            // followed by the remaining digits in ascending order.
            for (int d = 1; d <= 9; d++) {
                if (counts[d] > 0) {
                    StringBuilder sb = new StringBuilder();
                    sb.append(d);
                    counts[d]--;
                    appendAscending(sb, counts);
                    counts[d]++; // backtrack
                    return sb.toString();
                }
            }
            return null;
        }
        // Same length: the first digit is at least targetStr's first digit, so no leading zero.
        return solveRecursive(0, true, counts, targetStr);
    }

    /**
     * Builds the suffix starting at position {@code idx} of the smallest arrangement of
     * {@code counts} that is >= {@code targetStr}, or returns {@code null} if none exists.
     *
     * @param tight whether the prefix built so far equals {@code targetStr}'s prefix
     */
    static String solveRecursive(int idx, boolean tight, int[] counts, String targetStr) {
        if (idx == targetStr.length()) {
            return "";
        }

        int start = tight ? (targetStr.charAt(idx) - '0') : 0;
        for (int d = start; d <= 9; d++) {
            if (counts[d] > 0) {
                counts[d]--;
                boolean nextTight = tight && (d == start);

                // Once strictly above the bound, the rest is best filled in ascending order.
                if (!nextTight) {
                    StringBuilder sb = new StringBuilder();
                    sb.append(d);
                    appendAscending(sb, counts);
                    counts[d]++; // backtrack
                    return sb.toString();
                }

                String res = solveRecursive(idx + 1, nextTight, counts, targetStr);
                counts[d]++; // backtrack
                if (res != null) {
                    return d + res;
                }
            }
        }
        return null;
    }

    /** Appends every digit remaining in {@code counts} to {@code sb} in ascending order. */
    private static void appendAscending(StringBuilder sb, int[] counts) {
        for (int k = 0; k <= 9; k++) {
            for (int j = 0; j < counts[k]; j++) {
                sb.append(k);
            }
        }
    }

    /** Fast whitespace-separated token reader over a byte stream. */
    static class ContestScanner {
        private final java.io.InputStream in;
        private final byte[] buffer = new byte[1024];
        private int ptr = 0;
        private int buflen = 0;

        private static final long LONG_MAX_TENTHS = 922337203685477580L;
        private static final int LONG_MAX_LAST_DIGIT = 7;
        private static final int LONG_MIN_LAST_DIGIT = 8;

        public ContestScanner(java.io.InputStream in) {
            this.in = in;
        }

        public ContestScanner() {
            this(System.in);
        }

        /** Refills the buffer when exhausted; returns false at end of input or after a read error. */
        private boolean hasNextByte() {
            if (ptr < buflen) {
                return true;
            }
            ptr = 0;
            try {
                buflen = in.read(buffer);
            } catch (java.io.IOException e) {
                e.printStackTrace();
                // Without this, buflen keeps its old value and the stale buffer would be re-read.
                buflen = -1;
            }
            return buflen > 0;
        }

        private int readByte() {
            if (hasNextByte())
                return buffer[ptr++];
            else
                return -1;
        }

        private static boolean isPrintableChar(int c) {
            return 33 <= c && c <= 126;
        }

        public boolean hasNext() {
            while (hasNextByte() && !isPrintableChar(buffer[ptr]))
                ptr++;
            return hasNextByte();
        }

        public String next() {
            if (!hasNext())
                throw new java.util.NoSuchElementException();
            StringBuilder sb = new StringBuilder();
            int b = readByte();
            while (isPrintableChar(b)) {
                sb.appendCodePoint(b);
                b = readByte();
            }
            return sb.toString();
        }

        /**
         * Reads a signed decimal long.
         *
         * @throws java.util.NoSuchElementException if no token remains
         * @throws NumberFormatException if the token is not a decimal integer
         * @throws ArithmeticException if the value does not fit in a long
         */
        public long nextLong() {
            if (!hasNext())
                throw new java.util.NoSuchElementException();
            long n = 0;
            boolean minus = false;
            int b = readByte();
            if (b == '-') {
                minus = true;
                b = readByte();
            }
            if (b < '0' || '9' < b) {
                throw new NumberFormatException();
            }
            while (true) {
                if ('0' <= b && b <= '9') {
                    int digit = b - '0';
                    if (n >= LONG_MAX_TENTHS) {
                        // Another digit overflows unless it is exactly the final digit of
                        // Long.MAX_VALUE / Long.MIN_VALUE.
                        int lastDigitLimit = minus ? LONG_MIN_LAST_DIGIT : LONG_MAX_LAST_DIGIT;
                        if (n == LONG_MAX_TENTHS && digit <= lastDigitLimit) {
                            // Negate before adding the digit so Long.MIN_VALUE is representable.
                            long value = minus ? -n * 10 - digit : n * 10 + digit;
                            b = readByte();
                            if (!isPrintableChar(b)) {
                                return value;
                            }
                            if (b < '0' || '9' < b) {
                                throw new NumberFormatException(
                                        String.format("%d%s... is not number", value, Character.toString(b)));
                            }
                            // A further digit: report the real prefix (value already carries its sign).
                            throw new ArithmeticException(
                                    String.format("%d%s... overflows long.", value, Character.toString(b)));
                        }
                        throw new ArithmeticException(
                                String.format("%s%d%d... overflows long.", minus ? "-" : "", n, digit));
                    }
                    n = n * 10 + digit;
                } else if (b == -1 || !isPrintableChar(b)) {
                    return minus ? -n : n;
                } else {
                    throw new NumberFormatException();
                }
                b = readByte();
            }
        }

        public int nextInt() {
            long nl = nextLong();
            if (nl < Integer.MIN_VALUE || nl > Integer.MAX_VALUE)
                throw new NumberFormatException();
            return (int) nl;
        }

        public double nextDouble() {
            return Double.parseDouble(next());
        }

        public long[] nextLongArray(int length) {
            long[] array = new long[length];
            for (int i = 0; i < length; i++)
                array[i] = this.nextLong();
            return array;
        }

        public long[] nextLongArray(int length, java.util.function.LongUnaryOperator map) {
            long[] array = new long[length];
            for (int i = 0; i < length; i++)
                array[i] = map.applyAsLong(this.nextLong());
            return array;
        }

        public int[] nextIntArray(int length) {
            int[] array = new int[length];
            for (int i = 0; i < length; i++)
                array[i] = this.nextInt();
            return array;
        }

        /** Reads {@code length} rows of {@code width} ints and returns them column-major. */
        public int[][] nextIntArrayMulti(int length, int width) {
            int[][] arrays = new int[width][length];
            for (int i = 0; i < length; i++) {
                for (int j = 0; j < width; j++)
                    arrays[j][i] = this.nextInt();
            }
            return arrays;
        }

        public int[] nextIntArray(int length, java.util.function.IntUnaryOperator map) {
            int[] array = new int[length];
            for (int i = 0; i < length; i++)
                array[i] = map.applyAsInt(this.nextInt());
            return array;
        }

        public double[] nextDoubleArray(int length) {
            double[] array = new double[length];
            for (int i = 0; i < length; i++)
                array[i] = this.nextDouble();
            return array;
        }

        public double[] nextDoubleArray(int length, java.util.function.DoubleUnaryOperator map) {
            double[] array = new double[length];
            for (int i = 0; i < length; i++)
                array[i] = map.applyAsDouble(this.nextDouble());
            return array;
        }

        public long[][] nextLongMatrix(int height, int width) {
            long[][] mat = new long[height][width];
            for (int h = 0; h < height; h++)
                for (int w = 0; w < width; w++) {
                    mat[h][w] = this.nextLong();
                }
            return mat;
        }

        public int[][] nextIntMatrix(int height, int width) {
            int[][] mat = new int[height][width];
            for (int h = 0; h < height; h++)
                for (int w = 0; w < width; w++) {
                    mat[h][w] = this.nextInt();
                }
            return mat;
        }

        public double[][] nextDoubleMatrix(int height, int width) {
            double[][] mat = new double[height][width];
            for (int h = 0; h < height; h++)
                for (int w = 0; w < width; w++) {
                    mat[h][w] = this.nextDouble();
                }
            return mat;
        }

        public char[][] nextCharMatrix(int height, int width) {
            char[][] mat = new char[height][width];
            for (int h = 0; h < height; h++) {
                String s = this.next();
                for (int w = 0; w < width; w++) {
                    mat[h][w] = s.charAt(w);
                }
            }
            return mat;
        }
    }

    /** PrintWriter that prints floating-point values in fixed-point notation with 20 decimals. */
    static class ContestPrinter extends java.io.PrintWriter {
        public ContestPrinter(java.io.PrintStream stream) {
            super(stream);
        }

        public ContestPrinter() {
            super(System.out);
        }

        /** Formats {@code x} with {@code n} fractional digits, rounding half up; NaN/Infinity as Double.toString. */
        private static String dtos(double x, int n) {
            if (Double.isNaN(x) || Double.isInfinite(x)) {
                // The fixed-point path below would print "0.000..." or a saturated long here.
                return Double.toString(x);
            }
            StringBuilder sb = new StringBuilder();
            if (x < 0) {
                sb.append('-');
                x = -x;
            }
            x += Math.pow(10, -n) / 2;
            sb.append((long) x);
            sb.append(".");
            x -= (long) x;
            for (int i = 0; i < n; i++) {
                x *= 10;
                sb.append((int) x);
                x -= (int) x;
            }
            return sb.toString();
        }

        @Override
        public void print(float f) {
            super.print(dtos(f, 20));
        }

        @Override
        public void println(float f) {
            super.println(dtos(f, 20));
        }

        @Override
        public void print(double d) {
            super.print(dtos(d, 20));
        }

        @Override
        public void println(double d) {
            super.println(dtos(d, 20));
        }

        /** Prints the elements separated by {@code separator}, then a line break (empty line for an empty array). */
        public void printArray(int[] array, String separator) {
            for (int i = 0; i < array.length; i++) {
                if (i > 0) {
                    super.print(separator);
                }
                super.print(array[i]);
            }
            super.println();
        }

        public void printArray(int[] array) {
            this.printArray(array, " ");
        }

        /** Like {@link #printArray(int[], String)} but prints {@code map} applied to each element. */
        public void printArray(int[] array, String separator, java.util.function.IntUnaryOperator map) {
            for (int i = 0; i < array.length; i++) {
                if (i > 0) {
                    super.print(separator);
                }
                super.print(map.applyAsInt(array[i]));
            }
            super.println();
        }

        public void printArray(int[] array, java.util.function.IntUnaryOperator map) {
            this.printArray(array, " ", map);
        }

        /** Prints the elements separated by {@code separator}, then a line break (empty line for an empty array). */
        public void printArray(long[] array, String separator) {
            for (int i = 0; i < array.length; i++) {
                if (i > 0) {
                    super.print(separator);
                }
                super.print(array[i]);
            }
            super.println();
        }

        public void printArray(long[] array) {
            this.printArray(array, " ");
        }

        /** Like {@link #printArray(long[], String)} but prints {@code map} applied to each element. */
        public void printArray(long[] array, String separator, java.util.function.LongUnaryOperator map) {
            for (int i = 0; i < array.length; i++) {
                if (i > 0) {
                    super.print(separator);
                }
                super.print(map.applyAsLong(array[i]));
            }
            super.println();
        }

        public void printArray(long[] array, java.util.function.LongUnaryOperator map) {
            this.printArray(array, " ", map);
        }

    }
}
0