結果
問題 | No.823 Many Shifts Easy |
ユーザー | kusomushi |
提出日時 | 2019-07-09 07:28:21 |
言語 | Java21 (openjdk 21) |
結果 | AC |
実行時間 | 202 ms / 2,000 ms |
コード長 | 6,927 bytes |
コンパイル時間 | 2,393 ms |
コンパイル使用メモリ | 81,180 KB |
実行使用メモリ | 38,828 KB |
最終ジャッジ日時 | 2024-10-11 02:14:45 |
合計ジャッジ時間 | 4,101 ms |
ジャッジサーバーID (参考情報) | judge5 / judge2 |
(要ログイン)
テストケース
テストケース表示入力 | 結果 | 実行時間 実行使用メモリ |
---|---|---|
testcase_00 | AC | 50 ms 37,088 KB |
testcase_01 | AC | 51 ms 36,956 KB |
testcase_02 | AC | 56 ms 37,220 KB |
testcase_03 | AC | 202 ms 38,828 KB |
testcase_04 | AC | 51 ms 37,016 KB |
testcase_05 | AC | 161 ms 38,372 KB |
testcase_06 | AC | 190 ms 38,568 KB |
testcase_07 | AC | 51 ms 37,272 KB |
testcase_08 | AC | 141 ms 38,680 KB |
testcase_09 | AC | 75 ms 37,656 KB |
ソースコード
import java.io.*;
import java.util.Arrays;
import java.util.NoSuchElementException;
import java.util.StringJoiner;
import java.util.StringTokenizer;
import java.util.function.Function;

/**
 * Reads two integers {@code N} and {@code K} from standard input and prints a
 * value computed modulo {@link #MOD} by {@link #solve()}.
 *
 * <p>All arithmetic helpers ({@link #add}, {@link #sub}, {@link #mul},
 * {@link #div}, {@link #pow}) work on residues modulo {@link #MOD}.
 */
public class Main {
    /** Inputs read by {@link #main}; {@link #solve()} reads them directly. */
    static int N, K;

    /**
     * Program entry point: reads {@code N} and {@code K}, prints {@code solve()}.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        FastScanner sc = new FastScanner(System.in);
        N = sc.nextInt();
        K = sc.nextInt();
        System.out.println(solve());
    }

    /**
     * Computes the answer for the current {@link #N} and {@link #K}, modulo {@link #MOD}.
     *
     * <p>The computation has three parts (all in modular arithmetic):
     * <ol>
     *   <li>start with {@code perm(N, K) * (1 + 2 + ... + N)}: every value counted for
     *       every ordered selection of {@code K} out of {@code N};</li>
     *   <li>subtract the share {@code K / N} of that total: the fraction of selections
     *       in which a given value is chosen into the sequence A;</li>
     *   <li>add back {@code a} whenever {@code a + 1} appears to the right of {@code a}
     *       in the sequence. The value {@code N} is excluded because {@code N + 1} can
     *       never appear, hence the sum {@code 1 + ... + (N - 1)}.</li>
     * </ol>
     *
     * <p>Step 3 was originally a loop over the position {@code i} of {@code a} that
     * added {@code arithN1 * perm(N-1, K-1) * right / (N-1)} with
     * {@code right = K - i - 1}. Every factor except {@code right} is loop-invariant,
     * and {@code right} ranges over {@code 1 .. K-1}, so the loop collapses to a single
     * term using {@code K (K - 1) / 2}.
     *
     * @return the answer as a residue in {@code [0, MOD)}
     */
    static int solve() {
        FermatCombination fc = new FermatCombination(N + 1);
        int arithN = div(mul(N, N + 1), 2);    // 1 + 2 + ... + N
        int arithN1 = div(mul(N - 1, N), 2);   // 1 + 2 + ... + (N - 1)
        int sequences = fc.perm(N, K);

        // First, add up everything.
        int ans = mul(sequences, arithN);

        // Then subtract the part belonging to values that are chosen into A;
        // chosenShare is the probability that a given value is chosen.
        int chosenShare = div(K, N);
        ans = sub(ans, mul(chosenShare, mul(sequences, arithN)));

        // If a + 1 occurs to the right of a, add a back.
        // Only positions with at least one slot to their right contribute, i.e. K >= 2.
        if (K >= 2) {
            // Sum of "slots to the right" over all positions: 1 + 2 + ... + (K - 1).
            int rightSlotsTotal = (int) ((long) K * (K - 1) / 2 % MOD);
            // Number of sequences in which a fixed position holds a fixed value a.
            int withFixedValue = fc.perm(N - 1, K - 1);
            // Dividing by (N - 1) turns "slots to the right" into the probability
            // that a + 1 lands in one of them.
            int addBack = mul(mul(arithN1, withFixedValue), div(rightSlotsTotal, N - 1));
            ans = add(ans, addBack);
        }
        return ans;
    }

    /** Prime modulus used by every arithmetic helper in this class. */
    static int MOD = 1_000_000_007;

    /**
     * Modular exponentiation by repeated squaring.
     *
     * @param base the base; any int, it is reduced into {@code [0, MOD)} first
     * @param exp  the exponent; values {@code <= 0} yield 1
     * @return {@code base^exp mod MOD}
     */
    static int pow(int base, long exp) {
        if (exp == 0) {
            return 1;
        }
        int ans = 1;
        base %= MOD;
        if (base < 0) {
            // Java's % keeps the sign of the dividend; normalise so mul() stays non-negative.
            base += MOD;
        }
        while (exp > 0) {
            if ((exp & 1) == 1) {
                ans = mul(ans, base);
            }
            base = mul(base, base);
            exp = exp >> 1;
        }
        return ans;
    }

    /**
     * Modular subtraction.
     *
     * @param a minuend, expected in {@code [0, MOD)}
     * @param b subtrahend, expected in {@code [0, MOD)}
     * @return {@code a - b}, shifted up by MOD once if negative
     */
    static int sub(int a, int b) {
        int c = a - b;
        if (c < 0) {
            c += MOD;
        }
        return c;
    }

    /**
     * Modular division via Fermat's little theorem ({@code b^(MOD-2)} as the inverse).
     * When {@code b} is a multiple of MOD the "inverse" is 0 and so is the result.
     *
     * @param a dividend
     * @param b divisor
     * @return {@code a * b^(MOD-2) mod MOD}
     */
    static int div(int a, int b) {
        return mul(a, pow(b, MOD - 2));
    }

    /**
     * Modular addition.
     *
     * @param a first addend, expected in {@code [0, MOD)}
     * @param b second addend, expected in {@code [0, MOD)}
     * @return {@code (a + b) mod MOD}
     */
    static int add(int a, int b) {
        int c = a + b;
        if (c >= MOD) {
            c %= MOD;
        }
        return c;
    }

    /**
     * Modular multiplication; the product is formed in 64 bits so it cannot overflow.
     *
     * @param a first factor
     * @param b second factor
     * @return {@code (a * b) mod MOD} for non-negative products
     */
    static int mul(int a, int b) {
        long c = (long) a * b;
        if (c >= MOD) {
            c %= MOD;
        }
        return (int) c;
    }

    /**
     * Precomputed factorials and inverse factorials modulo {@link Main#MOD} for
     * {@code 0..size}, used to answer permutation/combination counts in O(1).
     */
    static class FermatCombination {
        private final int size;
        private final int[] factorial; // n -> factorial(n)
        private final int[] inverse;   // n -> inverse(factorial(n))

        /**
         * @param size largest {@code n} the tables must support; must be non-negative
         *             (and smaller than MOD for the inverses to exist)
         * @throws IllegalArgumentException if {@code size} is negative
         */
        FermatCombination(int size) {
            if (size < 0) {
                throw new IllegalArgumentException("size must be non-negative: " + size);
            }
            this.size = size;
            factorial = new int[size + 1];
            inverse = new int[size + 1];
            init();
        }

        /**
         * Fills both tables. Factorials go forward; inverse factorials are obtained from a
         * single modular inverse of {@code size!} followed by a backward pass using
         * {@code 1/(i-1)! = (1/i!) * i}, so only one exponentiation is needed.
         */
        private void init() {
            factorial[0] = 1;
            for (int i = 1; i <= size; i++) {
                factorial[i] = mul(factorial[i - 1], i);
            }
            inverse[size] = pow(factorial[size], MOD - 2);
            for (int i = size; i > 0; i--) {
                inverse[i - 1] = mul(inverse[i], i);
            }
        }

        /**
         * Number of ordered selections of {@code k} items out of {@code n}.
         *
         * @return {@code n! / (n-k)! mod MOD}, or 0 when {@code k} is outside {@code [0, n]}
         */
        int perm(int n, int k) {
            if (k < 0 || k > n) {
                return 0;
            }
            return mul(factorial[n], inverse[n - k]);
        }

        /**
         * Number of unordered selections of {@code k} items out of {@code n}.
         *
         * @return {@code C(n, k) mod MOD}, or 0 when {@code k} is outside {@code [0, n]}
         * @throws RuntimeException if {@code n} exceeds the table size
         */
        int comb(int n, int k) {
            if (n > size) {
                throw new RuntimeException("wtf : size=" + size + " n=" + n);
            }
            if (k < 0 || k > n) {
                return 0;
            }
            return mul(mul(factorial[n], inverse[k]), inverse[n - k]);
        }

        /**
         * Combinations with repetition: the number of ways to pick {@code n} items
         * from {@code k} kinds.
         *
         * @return {@code C(k + n - 1, n) mod MOD}; picking nothing always counts as 1
         */
        int hcomb(int k, int n) {
            if (n == 0) {
                return 1;
            }
            return comb(k + n - 1, n);
        }

        /**
         * Evaluates {@code n! / (n - g*k)! / (g!)^k / k!} modulo MOD, i.e.
         * {@code C(n, g) * C(n-g, g) * ... / k!}.
         */
        int group(int n, int g, int k) {
            int ret = factorial[n];                 // n!
            ret = mul(ret, inverse[n - g * k]);     // 1 / (n-gk)!
            ret = mul(ret, pow(inverse[g], k));     // 1 / (g! ^ k)
            ret = mul(ret, inverse[k]);             // 1 / k!
            return ret;
        }
    }

    /** Whitespace-token reader on top of a {@link BufferedReader}. */
    @SuppressWarnings("unused")
    static class FastScanner {
        private final BufferedReader reader;
        private StringTokenizer tokenizer;

        FastScanner(InputStream in) {
            reader = new BufferedReader(new InputStreamReader(in));
            tokenizer = null;
        }

        /**
         * Returns the next whitespace-delimited token, reading further lines as needed.
         * Blank and whitespace-only lines are skipped.
         *
         * @throws NoSuchElementException if the input ends before a token is found
         * @throws RuntimeException       wrapping any {@link IOException} from the reader
         */
        String next() {
            // Loop (not a single if): a line may contain no tokens at all.
            while (tokenizer == null || !tokenizer.hasMoreTokens()) {
                String line;
                try {
                    line = reader.readLine();
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
                if (line == null) {
                    throw new NoSuchElementException("end of input reached while reading a token");
                }
                tokenizer = new StringTokenizer(line);
            }
            return tokenizer.nextToken();
        }

        /**
         * Returns the rest of the current line if tokens remain on it, otherwise the
         * next line from the reader ({@code null} at end of input).
         */
        String nextLine() {
            if (tokenizer == null || !tokenizer.hasMoreTokens()) {
                try {
                    return reader.readLine();
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
            }
            return tokenizer.nextToken("\n");
        }

        long nextLong() {
            return Long.parseLong(next());
        }

        int nextInt() {
            return Integer.parseInt(next());
        }

        /** Reads {@code n} ints. */
        int[] nextIntArray(int n) {
            int[] a = new int[n];
            for (int i = 0; i < n; i++) {
                a[i] = nextInt();
            }
            return a;
        }

        /** Reads {@code n} ints, adding {@code delta} to each. */
        int[] nextIntArray(int n, int delta) {
            int[] a = new int[n];
            for (int i = 0; i < n; i++) {
                a[i] = nextInt() + delta;
            }
            return a;
        }

        /** Reads {@code n} longs. */
        long[] nextLongArray(int n) {
            long[] a = new long[n];
            for (int i = 0; i < n; i++) {
                a[i] = nextLong();
            }
            return a;
        }
    }

    /** Prints {@code f(a)} for each element on its own line, then flushes stdout. */
    static <A> void writeLines(A[] as, Function<A, String> f) {
        PrintWriter pw = new PrintWriter(System.out);
        for (A a : as) {
            pw.println(f.apply(a));
        }
        pw.flush();
    }

    /** Prints each int on its own line, then flushes stdout. */
    static void writeLines(int[] as) {
        PrintWriter pw = new PrintWriter(System.out);
        for (int a : as) {
            pw.println(a);
        }
        pw.flush();
    }

    /** Prints each long on its own line, then flushes stdout. */
    static void writeLines(long[] as) {
        PrintWriter pw = new PrintWriter(System.out);
        for (long a : as) {
            pw.println(a);
        }
        pw.flush();
    }

    /** Largest argument, or {@code Integer.MIN_VALUE} when called with none. */
    static int max(int... as) {
        int max = Integer.MIN_VALUE;
        for (int a : as) {
            max = Math.max(a, max);
        }
        return max;
    }

    /** Smallest argument, or {@code Integer.MAX_VALUE} when called with none. */
    static int min(int... as) {
        int min = Integer.MAX_VALUE;
        for (int a : as) {
            min = Math.min(a, min);
        }
        return min;
    }

    /**
     * Writes the arguments to stderr separated by single spaces. {@code int[]},
     * {@code long[]}, {@code double[]} and {@code Object[]} are expanded with
     * {@link Arrays#toString}; a {@code null} argument is printed as {@code "null"}.
     */
    static void debug(Object... args) {
        StringJoiner j = new StringJoiner(" ");
        for (Object arg : args) {
            if (arg instanceof int[]) {
                j.add(Arrays.toString((int[]) arg));
            } else if (arg instanceof long[]) {
                j.add(Arrays.toString((long[]) arg));
            } else if (arg instanceof double[]) {
                j.add(Arrays.toString((double[]) arg));
            } else if (arg instanceof Object[]) {
                j.add(Arrays.toString((Object[]) arg));
            } else {
                // String.valueOf tolerates null, unlike arg.toString().
                j.add(String.valueOf(arg));
            }
        }
        System.err.println(j.toString());
    }
}