import java.util.*;

/**
 * Reads a rooted tree and prints, modulo 1_000_000_007, the sum over every node v of the total
 * distance (in edges) from v to all nodes in v's own subtree.
 *
 * <p>Input format (as read by {@link #main}): N, then N-1 pairs "A B" (1-indexed) meaning A is
 * the parent of B. The root is the node that never appears as a B.
 */
public class Main {
  /** Number of nodes. */
  static int N;
  /** edge.get(v) lists the children of node v (0-indexed). */
  static ArrayList<ArrayList<Integer>> edge;
  /** dp[v] = sum of distances from v to every node in v's subtree, mod {@link #mod}; -1 = not computed. */
  static long[] dp;
  /** cnt[v] = number of nodes in v's subtree, including v itself. */
  static long[] cnt;
  static int mod = 1_000_000_007;

  /**
   * Fills dp and cnt for every not-yet-computed node in the subtree rooted at i.
   *
   * <p>Uses an explicit stack instead of recursion so that deep trees (for example a path of
   * 10^5 nodes) do not overflow the JVM call stack.
   *
   * @param i root of the subtree to evaluate (0-indexed)
   */
  public static void solve(int i) {
    if (dp[i] != -1) return;
    // Pre-order traversal: every node is appended to `order` after its parent.
    ArrayDeque<Integer> stack = new ArrayDeque<Integer>();
    ArrayList<Integer> order = new ArrayList<Integer>();
    stack.push(i);
    while (!stack.isEmpty()) {
      int v = stack.pop();
      order.add(v);
      for (int child : edge.get(v)) {
        if (dp[child] == -1) stack.push(child); // already-computed subtrees are reused as-is
      }
    }
    // Walking the pre-order backwards guarantees each child is finished before its parent.
    for (int k = order.size() - 1; k >= 0; k--) {
      int v = order.get(k);
      long sum = 0;
      long size = 1;
      for (int child : edge.get(v)) {
        // Every node in child's subtree is exactly one edge farther from v than from child.
        sum = (sum + dp[child] + cnt[child]) % mod;
        size += cnt[child];
      }
      dp[v] = sum;
      cnt[v] = size;
    }
  }

  /** Reads the tree from standard input, evaluates it from the root, and prints the answer. */
  public static void main(String[] args) {
    Scanner sc = new Scanner(System.in);
    N = sc.nextInt();
    edge = new ArrayList<ArrayList<Integer>>();
    for (int i = 0; i < N; i++) {
      edge.add(new ArrayList<Integer>());
    }
    boolean[] notRoot = new boolean[N];
    for (int i = 0; i < N - 1; i++) {
      int a = sc.nextInt() - 1;
      int b = sc.nextInt() - 1;
      edge.get(a).add(b);
      notRoot[b] = true;
    }
    int root = -1;
    for (int i = 0; i < N; i++) {
      if (!notRoot[i]) root = i;
    }
    dp = new long[N];
    cnt = new long[N];
    Arrays.fill(dp, -1);
    // root stays -1 only when N == 0 (or every node has a parent, i.e. the input is not a tree).
    if (root != -1) solve(root);
    long ans = 0L;
    for (int i = 0; i < N; i++) {
      ans = (ans + dp[i]) % mod;
    }
    System.out.println(ans);
  }
}