import java.util.Arrays;
import java.util.Scanner;

/**
 * Reads N and M and prints ((2 + sqrt3)^(2^N) + (2 - sqrt3)^(2^N) - 2) mod M.
 *
 * <p>The powers are evaluated in Z_M[sqrt3]: elements a + b*sqrt3 with sqrt3^2 = 3. The exponent
 * 2^N is first reduced modulo a bound on the order of 2 +/- sqrt3. That bound is M - 1 when 3 is
 * a quadratic residue mod M, and M + 1 otherwise (the norm (2 + sqrt3)(2 - sqrt3) is 1).
 *
 * <p>NOTE(review): the exponent reduction assumes M is prime. Products of two residues must also
 * fit in a long, which requires M below roughly 3e9. Confirm both against the problem constraints.
 */
class Main {
    long N;
    long M;

    /** An element a + b*sqrt3 of Z_M[sqrt3], normalized so that 0 <= a, b < M. */
    class S {
        long a, b;

        public S(long a_, long b_) {
            // floorMod (unlike %) never leaves a negative residue, e.g. the -1 in 2 - sqrt3.
            a = Math.floorMod(a_, M);
            b = Math.floorMod(b_, M);
        }

        /** Ring product: (a + b*r)(c + d*r) = (ac + 3bd) + (ad + bc)*r, where r^2 = 3. */
        S mul(S s) {
            return new S(a * s.a % M + 3 * b % M * s.b % M, a * s.b % M + b * s.a % M);
        }

        /** Component-wise sum. */
        S add(S s) {
            return new S(a + s.a, b + s.b);
        }

        /** Returns this^n by binary exponentiation; n is expected to be >= 0. */
        S pow(long n) {
            S ret = new S(1, 0);
            S mul = new S(a, b);
            for (; n > 0; n >>= 1, mul = mul.mul(mul)) {
                if (n % 2 == 1) {
                    ret = ret.mul(mul);
                }
            }
            return ret;
        }
    }

    /**
     * Returns ((2 + sqrt3)^(2^n) + (2 - sqrt3)^(2^n) - 2) mod m, in [0, m).
     *
     * <p>Side effect: stores n and m in the fields N and M, which {@link S} reads.
     */
    long solve(long n, long m) {
        N = n;
        M = m;
        S s1 = new S(2, 1);
        S s2 = new S(2, -1);
        // f(3) == 1 means 3 is a QR mod M (Euler's criterion), so the order bound is M - 1;
        // otherwise use M + 1. For M = 3, f(3) == 0 and the M + 1 branch is taken. There
        // sqrt3^2 == 0, so the a-component is just 2^k mod 3, whose period 2 divides 4.
        long orderBound = f(3) == 1 ? M - 1 : M + 1;
        long e = pow(2, N, orderBound);
        S ret = s1.pow(e).add(s2.pow(e));
        return Math.floorMod(ret.a - 2, M);
    }

    /** Reads N and M from stdin and prints {@link #solve}. */
    void run() {
        try (Scanner sc = new Scanner(System.in)) {
            long n = sc.nextLong();
            long m = sc.nextLong();
            System.out.println(solve(n, m));
        }
    }

    /** Euler's criterion: returns a^((M-1)/2) mod M. */
    long f(long a) {
        return pow(a, (M - 1) / 2, M);
    }

    /** Returns a^n mod mod by binary exponentiation; n >= 0, mod >= 1. */
    long pow(long a, long n, long mod) {
        long ret = 1 % mod; // reduce the seed so that mod == 1 yields 0 rather than 1
        for (; n > 0; n >>= 1, a = a * a % mod) {
            if (n % 2 == 1) {
                ret = ret * a % mod;
            }
        }
        return ret;
    }

    /** Debug helper: prints the arguments with Arrays.deepToString (not called in this file). */
    void tr(Object... objects) {
        System.out.println(Arrays.deepToString(objects));
    }

    public static void main(String[] args) {
        new Main().run();
    }
}