結果

問題 No.1470 Mex Sum
ユーザー CaliPotaCaliPota
提出日時 2021-04-09 21:30:22
言語 Java21
(openjdk 21)
結果
AC  
実行時間 121 ms / 2,000 ms
コード長 18,939 bytes
コンパイル時間 3,873 ms
コンパイル使用メモリ 90,920 KB
実行使用メモリ 52,260 KB
最終ジャッジ日時 2024-06-25 04:15:25
合計ジャッジ時間 9,958 ms
ジャッジサーバーID
(参考情報)
judge2 / judge4
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 54 ms
50,084 KB
testcase_01 AC 55 ms
50,372 KB
testcase_02 AC 78 ms
50,308 KB
testcase_03 AC 79 ms
50,572 KB
testcase_04 AC 79 ms
50,504 KB
testcase_05 AC 67 ms
50,416 KB
testcase_06 AC 79 ms
50,564 KB
testcase_07 AC 81 ms
50,276 KB
testcase_08 AC 77 ms
50,516 KB
testcase_09 AC 69 ms
50,520 KB
testcase_10 AC 79 ms
50,640 KB
testcase_11 AC 78 ms
50,592 KB
testcase_12 AC 121 ms
52,260 KB
testcase_13 AC 114 ms
51,992 KB
testcase_14 AC 115 ms
52,044 KB
testcase_15 AC 119 ms
52,256 KB
testcase_16 AC 117 ms
51,928 KB
testcase_17 AC 118 ms
52,240 KB
testcase_18 AC 117 ms
51,864 KB
testcase_19 AC 117 ms
52,172 KB
testcase_20 AC 111 ms
52,024 KB
testcase_21 AC 110 ms
52,124 KB
testcase_22 AC 113 ms
52,200 KB
testcase_23 AC 114 ms
52,112 KB
testcase_24 AC 117 ms
52,164 KB
testcase_25 AC 116 ms
51,860 KB
testcase_26 AC 114 ms
52,032 KB
testcase_27 AC 112 ms
52,080 KB
testcase_28 AC 114 ms
52,260 KB
testcase_29 AC 117 ms
52,224 KB
testcase_30 AC 113 ms
51,908 KB
testcase_31 AC 115 ms
52,180 KB
testcase_32 AC 115 ms
51,904 KB
testcase_33 AC 112 ms
51,788 KB
testcase_34 AC 114 ms
51,752 KB
testcase_35 AC 118 ms
52,184 KB
testcase_36 AC 118 ms
52,120 KB
testcase_37 AC 105 ms
51,964 KB
testcase_38 AC 106 ms
52,016 KB
testcase_39 AC 105 ms
51,976 KB
testcase_40 AC 109 ms
51,660 KB
testcase_41 AC 113 ms
52,044 KB
testcase_42 AC 109 ms
51,996 KB
testcase_43 AC 112 ms
52,028 KB
testcase_44 AC 110 ms
51,764 KB
testcase_45 AC 109 ms
52,048 KB
testcase_46 AC 112 ms
52,016 KB
testcase_47 AC 112 ms
52,116 KB
testcase_48 AC 112 ms
52,024 KB
testcase_49 AC 111 ms
51,936 KB
testcase_50 AC 110 ms
52,092 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.NoSuchElementException;


//@Japanese
/* PriorityQueueは拡張for文で出すとsortされてない順番で出てくる
 * longのbit演算は1L<<posに注意
 * JOIはMLEが厳しい。最悪shortを使う。
 * ArrayListはオートボクシングが遅いから、最悪自作listを使う
 */

/*略語
 * 2-to, 4-for, 8-from
 * -L -long(型), -P -素数Pを法とした余り
 * a,b -任意の引数, n,m -自然数, p -素数
 * pos -position
 * abs -絶対値
 * min -minimum, max -maximum, ave -average
 * div -divide
 * pow -power(累乗)
 * ceil -ceiling(天井関数)
 * dt -data
 * ln -length
 * sc -scanner
 * INF -INFINITY
 * e97 -10^9+7=1000000007(prime often used)
 */

public class Main {
	/**
	 * Reads {@code n} and then {@code n} integers from standard input and prints one number:
	 * twice the number of unordered pairs, minus the number of pairs containing no 1,
	 * plus the number of pairs formed by a 1 and a 2.
	 */
	public static void main(String[] args) throws Exception {
		FastScanner sc = new FastScanner();
		PrintWriter out = new PrintWriter(System.out);

		final int n = sc.nexI();
		final long[] as = new long[n];
		sc.al(as);

		// Only the multiplicities of the values 1 and 2 influence the result.
		long ones = 0L;
		long twos = 0L;
		for (long value : as) {
			if (value == 1L) {
				ones++;
			} else if (value == 2L) {
				twos++;
			}
		}

		final long allPairs = nC2((long) n);
		final long pairsWithoutOne = nC2((long) (n - ones));
		final long ans = 2L * allPairs - pairsWithoutOne + ones * twos;
		System.out.println(ans);

		out.flush();
	}



	// "Infinity" sentinel for int computations (1e8); doubling it still fits in an int.
	private static final int INF = (int) 1e8;
	// "Infinity" sentinel for long computations (1e17).
	private static final long INFL = (long) 1e17;
	// Modulus 1,000,000,007.
	private static final long e97 = 1000000007L;
	// Modulus 998,244,353.
	private static final long e99 = 998244353L;
	// Short alias for Math.PI.
	private static final double PI = Math.PI;

	/**
	 * Throws {@link AssertionError} when {@code should_true} is false.
	 * This is an explicit throw, so it fires regardless of the JVM's {@code -ea} flag.
	 */
	private static void assertion(boolean should_true) {
		if (should_true) {
			return;
		}
		throw new AssertionError();
	}

	/** Absolute value of an int; {@code Integer.MIN_VALUE} comes back unchanged (negation wraps). */
	private static int abs(int a) {
		if (a >= 0) {
			return a;
		}
		return -a;
	}

	/** Absolute value of a long; {@code Long.MIN_VALUE} comes back unchanged (negation wraps). */
	private static long abs(long a) {
		if (a >= 0) {
			return a;
		}
		return -a;
	}

	/** Absolute value of a double; values that compare {@code >= 0.0} (including -0.0) are returned as-is. */
	private static double abs(double a) {
		if (a >= 0.0) {
			return a;
		}
		return -a;
	}

	/** Smaller of two ints. */
	private static int min(int a, int b) {
		if (a > b) {
			return b;
		}
		return a;
	}

	/** Smaller of two longs. */
	private static long min(long a, long b) {
		if (a > b) {
			return b;
		}
		return a;
	}

	/** Smaller of two doubles; returns {@code a} whenever {@code a > b} is false (e.g. with NaN). */
	private static double min(double a, double b) {
		if (a > b) {
			return b;
		}
		return a;
	}

	/** Larger of two ints. */
	private static int max(int a, int b) {
		if (a > b) {
			return a;
		}
		return b;
	}

	/** Larger of two longs. */
	private static long max(long a, long b) {
		if (a > b) {
			return a;
		}
		return b;
	}

	/** Larger of two doubles; returns {@code b} whenever {@code a > b} is false (e.g. with NaN). */
	private static double max(double a, double b) {
		if (a > b) {
			return a;
		}
		return b;
	}

	/** Square of an int (wraps silently on overflow). */
	private static int pow2(int num2pow) {
		final int squared = num2pow * num2pow;
		return squared;
	}

	/** Square of a long (wraps silently on overflow). */
	private static long pow2(long num2pow) {
		final long squared = num2pow * num2pow;
		return squared;
	}

	/**
	 * Raises {@code num_powered} to {@code index} by repeated multiplication.
	 * After every step the running product must be non-negative, otherwise
	 * {@code assertion} throws. A non-positive {@code index} yields 1.
	 */
	private static int pow(int num_powered, int index) {
		int product = 1;
		int remaining = index;
		while (remaining > 0) {
			product *= num_powered;
			assertion(product >= 0);
			remaining--;
		}
		return product;
	}

	/** Long variant of {@link #pow(int, int)} with the same non-negativity check per step. */
	private static long pow(long num_powered, int index) {
		long product = 1L;
		int remaining = index;
		while (remaining > 0) {
			product *= num_powered;
			assertion(product >= 0L);
			remaining--;
		}
		return product;
	}

	/**
	 * Modular exponentiation by repeated squaring: {@code num_powered ^ index mod p}.
	 *
	 * <p>Kept conventions: a base of exactly 0 yields 0 (even for index 0), an index of 0
	 * yields 1, and a negative index is treated like 0. The base is first reduced into
	 * {@code [0, p)}, so multiples of {@code p} and negative bases are handled instead of
	 * tripping an assertion.
	 *
	 * @param num_powered the base
	 * @param index       the exponent
	 * @param p           the modulus; expected positive and small enough that
	 *                    {@code (p - 1)^2} fits in a long
	 * @return the power reduced modulo {@code p}
	 */
	private static long powP(long num_powered, long index, long p) {
		if (num_powered == 0L)
			return 0L;
		if (index == 0L)
			return 1L;

		// Reduce first so that every product below stays within (p - 1)^2.
		long base = Math.floorMod(num_powered, p);
		long result = 1L % p;
		long exponent = index;
		while (exponent > 0L) {
			if ((exponent & 1L) == 1L) {
				result = result * base % p;
			}
			base = base * base % p;
			exponent >>= 1;
		}
		return result;
	}

	/** Length of the hypotenuse for legs {@code a} and {@code b}, computed as sqrt(a*a + b*b). */
	private static double hypod(double a, double b) {
		final double sumOfSquares = a * a + b * b;
		return Math.sqrt(sumOfSquares);
	}

	/**
	 * Returns the smallest {@code digit >= 0} with {@code num2know <= 2^digit}
	 * (0 for every input {@code <= 1}).
	 *
	 * <p>Computed from the leading-zero count instead of a shifting loop: the loop never
	 * terminated for inputs above 2^62 because {@code 1L << digit} wraps around.
	 */
	private static int getDigit2(long num2know) {
		if (num2know <= 1L)
			return 0;
		// ceil(log2(x)) equals the bit length of (x - 1) for x >= 2.
		return Long.SIZE - Long.numberOfLeadingZeros(num2know - 1L);
	}

	/**
	 * Counts the decimal digits of {@code num2know}: the result {@code digit} satisfies
	 * {@code num2know < 10^digit}. Zero and negative inputs yield 0.
	 *
	 * <p>The number itself is divided down instead of multiplying a comparison value up,
	 * because that comparison value overflowed for inputs of 10^18 and above.
	 */
	private static int getDigit10(long num2know) {
		int digit = 0;
		for (long rest = num2know; rest > 0L; rest /= 10L) {
			digit++;
		}
		return digit;
	}

	/** Division rounded up, computed as (numerator + denominator - 1) / denominator. */
	private static int divceil(int numerator, int denominator) {
		final int shifted = numerator + denominator - 1;
		return shifted / denominator;
	}

	/** Long variant of {@link #divceil(int, int)}. */
	private static long divceil(long numerator, long denominator) {
		final long shifted = numerator + denominator - 1L;
		return shifted / denominator;
	}

	/**
	 * Computes n! by multiplying n, n-1, ..., 1 in that order.
	 * After each multiplication the product must be non-negative, otherwise
	 * {@code assertion} throws. {@code n <= 0} yields 1.
	 */
	private static long factorial(int n) {
		long product = 1L;
		long factor = n;
		while (factor > 0) {
			product *= factor;
			assertion(product >= 0L);
			factor--;
		}
		return product;
	}

	/** Computes n! modulo {@code p}, reducing after every multiplication; {@code n <= 0} yields 1. */
	private static long facP(int n, long p) {
		long product = 1L;
		long factor = n;
		while (factor > 0) {
			product = (product * factor) % p;
			factor--;
		}
		return product;
	}

	/**
	 * Least common multiple of two positive longs.
	 * Divides by the gcd before multiplying to keep the intermediate value small.
	 */
	private static long lcm(long m, long n) {
		assertion((m > 0L) && (n > 0L));
		final long reduced = m / gcd(m, n);
		return reduced * n;
	}

	/**
	 * Greatest common divisor of two non-negative longs (Euclid's algorithm, iterative form).
	 * {@code gcd(x, 0)} is {@code x}; negative arguments make {@code assertion} throw.
	 */
	private static long gcd(long m, long n) {
		assertion((m >= 0L) && (n >= 0L));
		long larger = (m < n) ? n : m;
		long smaller = (m < n) ? m : n;
		while (smaller != 0) {
			final long rest = larger % smaller;
			larger = smaller;
			smaller = rest;
		}
		return larger;
	}

	/**
	 * Trial-division primality test.
	 *
	 * <p>Everything below 2 (0, 1 and negatives) is reported as non-prime. The loop bound
	 * is checked with integer arithmetic ({@code i <= n / i}) rather than a floating-point
	 * square root, which is inexact for large longs and was recomputed every iteration.
	 *
	 * @param n2check the number to test
	 * @return true iff {@code n2check} is prime
	 */
	private static boolean is_prime(long n2check) {
		if (n2check < 2L)
			return false;
		for (long i = 2L; i <= n2check / i; i++) {
			if (n2check % i == 0)
				return false;
		}
		return true;
	}

	/**
	 * Computes the multiplicative inverse of {@code n} modulo {@code p}, i.e. the
	 * {@code y} in {@code [0, p)} with {@code y * n ≡ 1 (mod p)}.
	 *
	 * <p>Preconditions (enforced through {@code assertion}): {@code n > 0},
	 * {@code p > 1} and {@code gcd(n, p) == 1}.
	 */
	private static long modinv(long n, long p) {
		assertion((n > 0L) && (p > 1L) && (gcd(n, p) == 1L));
		n %= p;

		// Caution: n must be coprime to p, and p must be > 1.

		// yn≡1(mod p)
		// <-> xp+yn=1; (n<p)
		// ...(sx+ty)a+(ux+vy)b=1 (|sv-tu|=1)
		// ...(sx+ty)a+(ux+vy)=1
		// <- sx+ty=0, ux+vy=1

		// Euclidean steps on (a, b) = (p, n); (s, t) and (u, v) carry the
		// coefficient rows from the derivation above.
		long a = p, b = n, s = 1, t = 0, u = 0, v = 1;
		while (b > 1) {
			long quo = a / b;
			long rem = a % b;
			a = b;
			b = rem;
			long s2 = s * quo + u, t2 = t * quo + v;
			u = s;
			v = t;
			s = s2;
			t = t2;
		}
		// The determinant of the coefficient matrix is asserted to be +1 or -1;
		// dividing by it fixes the sign of s.
		long det = s * v - t * u;
		assertion(abs(det) == 1);
		s /= det;
		// Normalize the result into [0, p).
		s %= p;
		if (s < 0L)
			s += p;
		return s;
	}

	/**
	 * Smallest element of {@code dt4min}.
	 *
	 * <p>An empty array yields the sentinel {@code INF}. The scan starts from the first
	 * element rather than from the sentinel, so arrays whose elements all exceed
	 * {@code INF} are no longer reported as {@code INF}.
	 */
	private static int minAll(int[] dt4min) {
		if (dt4min.length == 0)
			return INF;
		int min = dt4min[0];
		for (int element : dt4min) {
			if (element < min)
				min = element;
		}
		return min;
	}

	/**
	 * Smallest element of {@code dt4min}; an empty array yields the sentinel {@code INFL}.
	 * Elements larger than {@code INFL} are handled correctly.
	 */
	private static long minAll(long[] dt4min) {
		if (dt4min.length == 0)
			return INFL;
		long min = dt4min[0];
		for (long element : dt4min) {
			if (element < min)
				min = element;
		}
		return min;
	}

	/**
	 * Largest element of {@code dt4max}.
	 *
	 * <p>Fixes an inverted comparison: the previous code replaced the running maximum only
	 * when an element was <em>smaller</em>, so it returned {@code -INF} for ordinary input.
	 * An empty array still yields the sentinel {@code -INF}.
	 */
	private static int maxAll(int[] dt4max) {
		if (dt4max.length == 0)
			return -INF;
		int max = dt4max[0];
		for (int element : dt4max) {
			if (element > max)
				max = element;
		}
		return max;
	}

	/**
	 * Largest element of {@code dt4max}; an empty array yields the sentinel {@code -INFL}.
	 * Same comparison fix as the int overload.
	 */
	private static long maxAll(long[] dt4max) {
		if (dt4max.length == 0)
			return -INFL;
		long max = dt4max[0];
		for (long element : dt4max) {
			if (element > max)
				max = element;
		}
		return max;
	}

	/** Sum of all elements, accumulated in an int (wraps on overflow). */
	private static int sumAll(int[] dt4sum) {
		int total = 0;
		for (int i = 0; i < dt4sum.length; i++) {
			total += dt4sum[i];
		}
		return total;
	}

	/** Sum of all elements, accumulated in a long. */
	private static long sumAll(long[] dt4sum) {
		long total = 0L;
		for (int i = 0; i < dt4sum.length; i++) {
			total += dt4sum[i];
		}
		return total;
	}

	/** Sum of all list elements, accumulated in an int (wraps on overflow). */
	private static int sumAll(ArrayList<Integer> dt4sum) {
		int total = 0;
		for (int value : dt4sum) {
			total += value;
		}
		return total;
	}
	/*
	 * private static long sumAll(ArrayList<Long> dt4sum) { long sum_of_dt = 0L;
	 * for(long element: dt4sum) { sum_of_dt += element; } return sum_of_dt; }
	 */

	/** Returns a new array holding the elements of {@code as} in reverse order; the input is untouched. */
	private static int[] reverse(int[] as) {
		final int ln = as.length;
		final int[] reversed = new int[ln];
		for (int src = 0, dst = ln - 1; src < ln; src++, dst--) {
			reversed[dst] = as[src];
		}
		return reversed;
	}

	/**
	 * Reverses, in place, the slice of {@code as} from index {@code S_include} (inclusive)
	 * to {@code Gnot_include} (exclusive). Runs in time linear in the slice length and
	 * uses a temporary copy of the slice.
	 */
	private static void reverseSub(int[] as, int S_include, int Gnot_include) {
		final int ln = Gnot_include - S_include;
		final int[] snapshot = new int[ln];
		for (int k = 0; k < ln; k++) {
			snapshot[k] = as[S_include + k];
		}
		int src = ln - 1;
		for (int dst = S_include; dst < Gnot_include; dst++) {
			as[dst] = snapshot[src];
			src--;
		}
	}

	/** True iff {@code 0 <= y < height} and {@code 0 <= x < width}. */
	private static boolean is_in_area(int y, int x, int height, int width) {
		return y >= 0 && x >= 0 && y < height && x < width;
	}

	/** True iff {@code 0 <= v.y < height} and {@code 0 <= v.x < width}. */
	private static boolean is_in_area(Vector v, int height, int width) {
		return v.y >= 0 && v.x >= 0 && v.y < height && v.x < width;
	}

	/**
	 * Number of unordered pairs among {@code n} items, {@code n * (n - 1) / 2}.
	 *
	 * <p>The product is formed in a long, so the result is correct whenever it fits in an
	 * int (the old int product overflowed from n = 46342 onwards).
	 */
	private static int nC2(int n) {
		return (int) (((long) n * (n - 1)) / 2L);
	}

	/**
	 * Number of unordered pairs among {@code n} items, {@code n * (n - 1) / 2}.
	 *
	 * <p>The even one of the two consecutive factors is halved before multiplying, so the
	 * intermediate value never exceeds the result itself.
	 */
	private static long nC2(long n) {
		if (n % 2L == 0L) {
			return (n / 2L) * (n - 1L);
		}
		return n * ((n - 1L) / 2L);
	}

	/** Int bit mask with only bit {@code pos} set; {@code pos} must be below 31. */
	private static int iflag(int pos) {
		assertion(pos < 31);
		final int mask = 1 << pos;
		return mask;
	}

	/** Long bit mask with only bit {@code pos} set; {@code pos} must be below 63. */
	private static long flag(int pos) {
		assertion(pos < 63);
		final long mask = 1L << (long) pos;
		return mask;
	}

	/**
	 * Tests whether bit {@code pos} of {@code bit} is set.
	 *
	 * <p>Compares the masked value against zero with {@code !=}; the earlier {@code > 0}
	 * test returned false for the sign bit (pos 31) because the masked value is negative.
	 */
	private static boolean isFlaged(int bit, int pos) {
		return (bit & (1 << pos)) != 0;
	}

	/**
	 * Tests whether bit {@code pos} of {@code bit} is set, including the sign bit (pos 63).
	 */
	private static boolean isFlaged(long bit, int pos) {
		return (bit & (1L << (long) pos)) != 0L;
	}

	/** Returns {@code bit} with bit {@code pos} cleared. */
	private static int deflag(int bit, int pos) {
		final int mask = 1 << pos;
		return bit & ~mask;
	}

	/**
	 * Number of set bits in {@code bit}.
	 *
	 * <p>Delegates to {@link Integer#bitCount(int)}. Unlike the previous hand-written
	 * loop over positions 0..30 this also counts the sign bit; results for non-negative
	 * inputs are unchanged.
	 */
	private static int countFlaged(int bit) {
		return Integer.bitCount(bit);
	}

	/**
	 * Number of set bits in {@code bit}, via {@link Long#bitCount(long)}; the sign bit is
	 * included. Results for non-negative inputs are unchanged.
	 */
	private static int countFlaged(long bit) {
		return Long.bitCount(bit);
	}

	// Offset tables for grid neighbours; index k pairs Xdir*[k] with Ydir*[k].
	// The 4-element tables hold the four axis-aligned unit steps.
	private static int[] Xdir4 = { 1, 0, 0, -1 };
	private static int[] Ydir4 = { 0, 1, -1, 0 };
	// The 8-element tables additionally hold the four diagonal steps.
	private static int[] Xdir8 = { 1, 1, 1, 0, 0, -1, -1, -1 };
	private static int[] Ydir8 = { 1, 0, -1, 1, -1, 1, 0, -1 };

	/**
	 * Binary search for an exact match.
	 *
	 * @param dt     array to search; expected to be sorted in ascending order
	 * @param target value to look for
	 * @return an index {@code i} with {@code dt[i] == target}, or -1 if there is none
	 */
	public static int biSearch(int[] dt, int target) {
		int left = 0, right = dt.length - 1;
		while (left <= right) {
			// left + (right - left) / 2 cannot overflow, unlike (left + right) / 2.
			int mid = left + (right - left) / 2;
			if (dt[mid] == target)
				return mid;
			if (dt[mid] < target)
				left = mid + 1;
			else
				right = mid - 1;
		}
		return -1;
	}

	/**
	 * Finds the largest index whose element is {@code <= target}.
	 *
	 * @param dt     array to search; expected to be sorted in ascending order
	 * @param target upper bound for the element value
	 * @return the largest index {@code i} with {@code dt[i] <= target}, or -1 when every
	 *         element is greater than {@code target} (or the array is empty)
	 */
	public static int biSearchMax(long[] dt, long target) {
		int left = -1, right = dt.length;
		while ((right - left) > 1) {
			// Overflow-safe midpoint; identical to (left + right) / 2 when that fits.
			int mid = left + (right - left) / 2;
			if (dt[mid] <= target)
				left = mid;
			else
				right = mid;
		}
		return left;
	}

	/**
	 * Finds the smallest index whose element is strictly greater than {@code target}.
	 *
	 * @param dt     array to search; expected to be sorted in ascending order
	 * @param target exclusive lower bound for the element value
	 * @return the smallest index {@code i} with {@code dt[i] > target}, or
	 *         {@code dt.length} when no element is greater than {@code target}
	 */
	public static int biSearchMin(long[] dt, long target) {
		int left = -1, right = dt.length;
		while ((right - left) > 1) {
			// Overflow-safe midpoint; identical to (left + right) / 2 when that fits.
			int mid = left + (right - left) / 2;
			if (dt[mid] <= target)
				left = mid;
			else
				right = mid;
		}
		return right;
	}

	/** Sets every element of {@code target} to {@code reset}. */
	private static void fill(boolean[] target, boolean reset) {
		java.util.Arrays.fill(target, reset);
	}

	/** Sets every element of {@code target} to {@code reset}. */
	private static void fill(int[] target, int reset) {
		java.util.Arrays.fill(target, reset);
	}

	/** Sets every element of {@code target} to {@code reset}. */
	private static void fill(long[] target, long reset) {
		java.util.Arrays.fill(target, reset);
	}

	/** Sets every element of {@code target} to {@code reset}. */
	private static void fill(char[] target, char reset) {
		java.util.Arrays.fill(target, reset);
	}

	/** Sets every element of {@code target} to {@code reset}. */
	private static void fill(double[] target, double reset) {
		java.util.Arrays.fill(target, reset);
	}

	/** Sets every cell of the (possibly ragged) 2-D array to {@code reset}. */
	private static void fill(boolean[][] target, boolean reset) {
		for (boolean[] row : target) {
			java.util.Arrays.fill(row, reset);
		}
	}

	/** Sets every cell of the (possibly ragged) 2-D array to {@code reset}. */
	private static void fill(int[][] target, int reset) {
		for (int[] row : target) {
			java.util.Arrays.fill(row, reset);
		}
	}

	/** Sets every cell of the (possibly ragged) 2-D array to {@code reset}. */
	private static void fill(long[][] target, long reset) {
		for (long[] row : target) {
			java.util.Arrays.fill(row, reset);
		}
	}

	/** Sets every cell of the (possibly ragged) 2-D array to {@code reset}. */
	private static void fill(char[][] target, char reset) {
		for (char[] row : target) {
			java.util.Arrays.fill(row, reset);
		}
	}

	/** Sets every cell of the (possibly ragged) 2-D array to {@code reset}. */
	private static void fill(double[][] target, double reset) {
		for (double[] row : target) {
			java.util.Arrays.fill(row, reset);
		}
	}

	/** Sets every cell of the 3-D array to {@code reset}. */
	private static void fill(int[][][] target, int reset) {
		for (int[][] plane : target) {
			for (int[] row : plane) {
				java.util.Arrays.fill(row, reset);
			}
		}
	}

	/** Sets every cell of the 3-D array to {@code reset}. */
	private static void fill(long[][][] target, long reset) {
		for (long[][] plane : target) {
			for (long[] row : plane) {
				java.util.Arrays.fill(row, reset);
			}
		}
	}

	/**
	 * Prints the low bits of {@code bit} to standard output, least significant first:
	 * "O" for a set bit and "." for a clear one, followed by a line break.
	 * The number of positions printed is {@code getDigit2(bit)}.
	 */
	private static void showBit(int bit) {
		final int width = getDigit2(bit);
		for (int pos = 0; pos < width; pos++) {
			System.out.print(isFlaged(bit, pos) ? "O" : ".");
		}
		System.out.println();
	}

	/** Long variant of {@link #showBit(int)}. */
	private static void showBit(long bit) {
		final int width = getDigit2(bit);
		for (int pos = 0; pos < width; pos++) {
			System.out.print(isFlaged(bit, pos) ? "O" : ".");
		}
		System.out.println();
	}

	/**
	 * Debug print of a boolean grid: one line per row with "O" for true and "." for false,
	 * then "&lt;-" plus the comment (when non-empty) and " :" plus the row index.
	 */
	static void show2(boolean[][] dt, String cmnt) {
		for (int row = 0; row < dt.length; row++) {
			for (boolean cell : dt[row]) {
				System.out.print(cell ? "O" : ".");
			}
			if (!cmnt.isEmpty())
				System.out.print("<-" + cmnt);
			System.out.println(" :" + row);
		}
	}

	/**
	 * Debug print of an int grid: each cell followed by a comma, then "&lt;-" plus the
	 * comment (when non-empty) and " :" plus the row index.
	 */
	static void show2(int[][] dt, String cmnt) {
		for (int row = 0; row < dt.length; row++) {
			for (int cell : dt[row]) {
				System.out.print(cell + ",");
			}
			if (!cmnt.isEmpty())
				System.out.print("<-" + cmnt);
			System.out.println(" :" + row);
		}
	}

	/** Long variant of {@link #show2(int[][], String)}. */
	static void show2(long[][] dt, String cmnt) {
		for (int row = 0; row < dt.length; row++) {
			for (long cell : dt[row]) {
				System.out.print(cell + ",");
			}
			if (!cmnt.isEmpty())
				System.out.print("<-" + cmnt);
			System.out.println(" :" + row);
		}
	}

	/**
	 * Prints the deque's elements front to back without separators, then a blank line.
	 * Note that this empties the deque.
	 */
	static void show2(ArrayDeque<Long> dt) {
		while (!dt.isEmpty()) {
			long head = dt.removeFirst();
			System.out.print(head);
		}
		System.out.println("\n");
	}

	/** Prints every list element followed by a comma, then a blank line. */
	static void show2(List<Object> dt) {
		for (Object item : dt) {
			System.out.print(item + ",");
		}
		System.out.println("\n");
	}

	/** Prints each element on its own line to standard output and flushes. */
	private static void prtlnas(int[] array) {
		PrintWriter out = new PrintWriter(System.out);
		for (int value : array) {
			out.println(value);
		}
		out.flush();
	}

	/** Prints each element on its own line to standard output and flushes. */
	private static void prtlnas(long[] array) {
		PrintWriter out = new PrintWriter(System.out);
		for (long value : array) {
			out.println(value);
		}
		out.flush();
	}

	/** Prints each element on its own line to standard output and flushes. */
	private static void prtlnas(ArrayList<Object> array) {
		PrintWriter out = new PrintWriter(System.out);
		for (Object value : array) {
			out.println(value);
		}
		out.flush();
	}

	/**
	 * Prints the elements on one line, separated by single spaces, then a line break,
	 * and flushes. An empty array now produces an empty line instead of throwing
	 * {@link ArrayIndexOutOfBoundsException}.
	 */
	private static void prtspas(int[] array) {
		PrintWriter out = new PrintWriter(System.out);
		for (int i = 0; i < array.length; i++) {
			if (i > 0)
				out.print(' ');
			out.print(array[i]);
		}
		out.println();
		out.flush();
	}

	/** Long variant of {@link #prtspas(int[])}; an empty array prints an empty line. */
	private static void prtspas(long[] array) {
		PrintWriter out = new PrintWriter(System.out);
		for (int i = 0; i < array.length; i++) {
			if (i > 0)
				out.print(' ');
			out.print(array[i]);
		}
		out.println();
		out.flush();
	}

	/** Double variant of {@link #prtspas(int[])}; an empty array prints an empty line. */
	private static void prtspas(double[] array) {
		PrintWriter out = new PrintWriter(System.out);
		for (int i = 0; i < array.length; i++) {
			if (i > 0)
				out.print(' ');
			out.print(array[i]);
		}
		out.println();
		out.flush();
	}

	/**
	 * List variant of {@link #prtspas(int[])}; an empty list prints an empty line instead
	 * of throwing {@link IndexOutOfBoundsException}.
	 */
	private static void prtspas(ArrayList<Integer> array) {
		PrintWriter out = new PrintWriter(System.out);
		for (int i = 0; i < array.size(); i++) {
			if (i > 0)
				out.print(' ');
			out.print(array.get(i));
		}
		out.println();
		out.flush();
	}

	/** Mutable pair of int coordinates. */
	static class Vector {
		int x;
		int y;

		public Vector(int sx, int sy) {
			x = sx;
			y = sy;
		}

		/**
		 * Coordinate-wise equality with another vector.
		 * This is an overload taking a Vector, not an override of Object.equals.
		 */
		public boolean equals(Vector v) {
			if (this.x != v.x) {
				return false;
			}
			return this.y == v.y;
		}

		/** Prints "x, y" on one line to standard output. */
		public void show2() {
			System.out.println(this.x+", "+this.y);
		}

		/**
		 * Squared Euclidean distance between two vectors.
		 * Each non-zero coordinate difference d must satisfy d <= INF / d,
		 * otherwise assertion throws.
		 */
		public static int dist2(Vector a, Vector b){
			final int dx = abs(a.x - b.x);
			final int dy = abs(a.y - b.y);
			if (dx > 0) {
				assertion(dx <= (INF / dx));
			}
			if (dy > 0) {
				assertion(dy <= (INF / dy));
			}
			final int squaredX = dx * dx;
			final int squaredY = dy * dy;
			return squaredX + squaredY;
		}
	}

	/**
	 * Orders vectors by {@code x}, breaking ties by {@code y}.
	 *
	 * <p>Uses {@link Integer#compare(int, int)} instead of subtraction: a difference such
	 * as {@code a.x - b.x} overflows for coordinates far apart and then reports the wrong
	 * sign. Callers should rely only on the sign of the result.
	 */
	static class CompVector implements Comparator<Vector>{
		@Override
		public int compare(Vector a, Vector b) {
			if (a.x != b.x) {
				return Integer.compare(a.x, b.x);
			}
			return Integer.compare(a.y, b.y);
		}
	}


	/**
	 * Byte-level token reader over {@code System.in}.
	 *
	 * <p>Tokens are runs of printable ASCII characters (codes 33..126); everything else
	 * acts as a separator. The stream reference is captured when the scanner is created.
	 */
	static class FastScanner {
		private final InputStream in = System.in;
		private final byte[] buffer = new byte[1024];
		private int ptr = 0;
		private int buflen = 0;

		/**
		 * Makes sure at least one unread byte is buffered, refilling from the stream if
		 * needed.
		 *
		 * @return false when the stream is exhausted
		 * @throws java.io.UncheckedIOException if reading fails; previously the error was
		 *         only printed and the stale buffer content was parsed again
		 */
		private boolean hasNextByte() {
			if (ptr < buflen) {
				return true;
			}
			ptr = 0;
			try {
				buflen = in.read(buffer);
			} catch (IOException e) {
				buflen = 0;
				throw new java.io.UncheckedIOException(e);
			}
			return buflen > 0;
		}

		/** Returns the next raw byte, or -1 at end of input. */
		private int readByte() {
			if (hasNextByte())
				return buffer[ptr++];
			else
				return -1;
		}

		/** True for ASCII codes 33..126, i.e. visible characters excluding space. */
		private static boolean isPrintableChar(int c) {
			return (33 <= c) && (c <= 126);
		}

		/** Skips separators and reports whether another token is available. */
		public boolean hasNext() {
			while (hasNextByte() && !isPrintableChar(buffer[ptr]))
				ptr++;
			return hasNextByte();
		}

		/**
		 * Reads the next token as a string.
		 *
		 * @throws NoSuchElementException if the input is exhausted
		 */
		public String next() {
			if (!hasNext())
				throw new NoSuchElementException();
			StringBuilder sb = new StringBuilder();
			int b = readByte();
			while (isPrintableChar(b)) {
				sb.appendCodePoint(b);
				b = readByte();
			}
			return sb.toString();
		}

		/**
		 * Reads the next token as a decimal long with an optional leading '-'.
		 * Parsing stops at end of input, at a separator, or at a ':' character.
		 * Overflow is not detected.
		 *
		 * @throws NoSuchElementException if the input is exhausted
		 * @throws NumberFormatException  if the token is not a decimal integer
		 */
		public long nexL() {
			if (!hasNext())
				throw new NoSuchElementException();
			long n = 0;
			boolean minus = false;
			int b = readByte();
			if (b == '-') {
				minus = true;
				b = readByte();
			}
			if (b < '0' || '9' < b) {
				throw new NumberFormatException();
			}
			while (true) {
				if ('0' <= b && b <= '9') {
					n *= 10;
					n += b - '0';
				} else if (b == -1 || !isPrintableChar(b) || b == ':') {
					return minus ? -n : n;
				} else {
					throw new NumberFormatException();
				}
				b = readByte();
			}
		}

		/**
		 * Reads the next token as an int.
		 *
		 * @throws NumberFormatException if the value does not fit in an int
		 */
		public int nexI() {
			long nl = nexL();
			if (nl < Integer.MIN_VALUE || nl > Integer.MAX_VALUE) {
				throw new NumberFormatException();
			}
			return (int) nl;
		}

		/** Reads the next token and parses it with {@link Double#parseDouble(String)}. */
		public double nexD() {
			return Double.parseDouble(next());
		}

		/**
		 * Fills one or more int arrays in interleaved order: for each index i, one value
		 * is read for every array in turn. The first array's length decides how many
		 * indices are read.
		 */
		public void ai(int[]... array) {
			for (int i = 0; i < array[0].length; i++) {
				for (int j = 0; j < array.length; j++) {
					array[j][i] = nexI();
				}
			}
			return;
		}

		/** Long variant of {@link #ai(int[][])}: interleaved fill of one or more arrays. */
		public void al(long[]... array) {
			for (int i = 0; i < array[0].length; i++) {
				for (int j = 0; j < array.length; j++) {
					array[j][i] = nexL();
				}
			}
			return;
		}

		/** Fills the array with ints read from input, each decremented by one. */
		public void aimin1(int[] array) {
			for (int i = 0; i < array.length; i++) {
				array[i] = nexI() - 1;
			}
			return;
		}

		/** Fills the array with doubles read from input. */
		public void aD(double[] array) {
			for (int i = 0; i < array.length; i++) {
				array[i] = nexD();
			}
			return;
		}

		/** Fills a rectangular int matrix row by row; the first row's length is used for every row. */
		public void ai2d(int[][] array) {
			for (int i = 0; i < array.length; i++) {
				for (int j = 0; j < array[0].length; j++) {
					array[i][j] = nexI();
				}
			}
			return;
		}

		/** Fills a rectangular long matrix row by row; the first row's length is used for every row. */
		public void al2d(long[][] array) {
			for (int i = 0; i < array.length; i++) {
				for (int j = 0; j < array[0].length; j++) {
					array[i][j] = nexL();
				}
			}
			return;
		}
	}
}
// END OF THE CODE
0