"""Count labelings of ``n`` vertices that satisfy pairwise XOR constraints.

Input (stdin): two integers ``n q`` followed by ``q`` triples ``a b c``,
where ``a`` and ``b`` are 1-based vertex numbers.  Each triple demands
``label[a] ^ label[b] == c``.

Output: 2 multiplied once per connected component of the constraint graph,
modulo 998244353, or 0 if any constraint contradicts the others.
NOTE(review): this equals the number of valid 0/1 labelings when every ``c``
is 0 or 1 -- confirm against the problem statement.
"""
import sys
from typing import Iterable, Tuple

MOD = 998244353


def count_assignments(
    n: int,
    edges: Iterable[Tuple[int, int, int]],
    mod: int = MOD,
) -> int:
    """Return 2**(number of components) % mod, or 0 on a contradiction.

    Args:
        n: Number of vertices.
        edges: Triples ``(a, b, c)`` with 1-based ``a`` and ``b``; each one
            requires ``label[a] ^ label[b] == c``.
        mod: Modulus applied to the running product.

    Returns:
        The product described above, or 0 as soon as two constraints
        disagree about some vertex's label relative to its component root.
    """
    adj = [[] for _ in range(n)]
    for a, b, c in edges:
        # Store both directions so the traversal can start from any vertex.
        adj[a - 1].append((b - 1, c))
        adj[b - 1].append((a - 1, c))

    # label[v] is the XOR of the c-values along the path from the root of
    # v's component; -1 marks "not visited yet".
    # NOTE(review): assumes c is non-negative, so a real label never
    # collides with the -1 sentinel.
    label = [-1] * n
    result = 1
    for root in range(n):
        if label[root] != -1:
            continue
        # Every component contributes a factor of 2 (the root's free choice).
        result = result * 2 % mod
        label[root] = 0
        stack = [root]
        while stack:
            cur = stack.pop()
            for nxt, c in adj[cur]:
                expected = label[cur] ^ c
                if label[nxt] == -1:
                    label[nxt] = expected
                    stack.append(nxt)
                elif label[nxt] != expected:
                    # Contradiction: the answer is 0 no matter what follows,
                    # so there is no need to finish the traversal.
                    return 0
    return result


def solve(text: str) -> int:
    """Parse the whole input ``text`` and return the answer for it.

    Tokens are split on any whitespace, so line breaks are not significant.

    Raises:
        ValueError: If fewer than ``q`` complete triples follow ``n q``.
    """
    values = [int(tok) for tok in text.split()]
    n, q = values[0], values[1]
    end = 2 + 3 * q
    if len(values) < end:
        raise ValueError(
            "expected %d constraint triples, input is truncated" % q
        )
    edges = [
        (values[i], values[i + 1], values[i + 2]) for i in range(2, end, 3)
    ]
    return count_assignments(n, edges)


def main() -> None:
    """Read the problem from stdin and print the answer."""
    # A single bulk read is much faster than one input() call per constraint.
    print(solve(sys.stdin.read()))


if __name__ == "__main__":
    main()