import java.util.Scanner;

/**
 * Reads two non-negative decimal integers {@code n} and {@code m} from standard input (each may be
 * far too long for any primitive type, so they are handled as digit strings) and prints the last
 * decimal digit of {@code n^m}.
 */
public class Main {

  public static void main(String[] args) {
    try (Scanner sc = new Scanner(System.in)) {
      String n = sc.next();
      String m = sc.next();
      System.out.println(lastDigitOfPower(n, m));
    }
  }

  /**
   * Returns the last decimal digit of {@code n^m}.
   *
   * <p>Only the last digit {@code d} of {@code n} affects the result. For exponents {@code m >= 1}
   * the sequence {@code d^m mod 10} is periodic with a period dividing 4, so
   * {@code d^m ≡ d^((m mod 4) + 4) (mod 10)}; adding 4 keeps the reduced exponent {@code >= 1}.
   * An exponent of zero (including {@code 0^0}) yields 1.
   *
   * @param n decimal digits of the base, without sign
   * @param m decimal digits of the exponent, without sign; leading zeros are allowed
   * @return the last digit of {@code n^m}, in {@code [0, 9]}
   * @throws IllegalArgumentException if {@code n} or {@code m} is empty
   */
  public static int lastDigitOfPower(String n, String m) {
    if (n.isEmpty() || m.isEmpty()) {
      throw new IllegalArgumentException("n and m must be non-empty digit strings");
    }
    int lastDigit = n.charAt(n.length() - 1) - '0';

    // Stream m's digits: reduce it mod 4 and remember whether every digit is zero (m == 0).
    int mMod4 = 0;
    boolean mIsZero = true;
    for (int i = 0; i < m.length(); i++) {
      int digit = m.charAt(i) - '0';
      mMod4 = (mMod4 * 10 + digit) % 4;
      if (digit != 0) {
        mIsZero = false;
      }
    }
    if (mIsZero) {
      return 1;
    }

    int exponent = mMod4 + 4;
    int result = 1;
    for (int i = 0; i < exponent; i++) {
      result = result * lastDigit % 10;
    }
    return result;
  }
}