import java.io.*; import java.lang.*; import java.util.*; import java.math.*; public class Main { static Scanner cin = new Scanner(System.in); public static void main(String[] args) { BigInteger n = new BigInteger(cin.next()); BigInteger i = BigInteger.ONE; i=i.add(BigInteger.ONE); BigInteger f = BigInteger.ONE; if(n.equals(BigInteger.ONE)) f=BigInteger.ZERO; for(;; i=i.add(BigInteger.ONE)) { if(i.equals(n)||f.equals(BigInteger.ZERO)) break; if(n.mod(i)==BigInteger.ZERO) f=BigInteger.ZERO; } if(f.equals(BigInteger.ONE)) System.out.println("YES"); else System.out.println("NO"); } }