結果

問題 No.1502 Many Simple Additions
ユーザー lam6er
提出日時 2025-04-09 20:57:44
言語 PyPy3 (7.3.15)
結果
AC  
実行時間 385 ms / 2,000 ms
コード長 4,494 bytes
コンパイル時間 236 ms
コンパイル使用メモリ 82,780 KB
実行使用メモリ 149,644 KB
最終ジャッジ日時 2025-04-09 20:59:42
合計ジャッジ時間 6,551 ms
ジャッジサーバーID (参考情報) judge4 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 5
other AC * 39
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import deque
MOD = 10**9 + 7


def _build_graph(n, edges):
    """Return a 1-indexed adjacency list for the constraint graph.

    Each constraint ``A_x + A_y = z`` becomes an undirected edge between
    x and y carrying weight z.  A self-loop (x == y) is appended twice,
    exactly like any other edge.
    """
    adj = [[] for _ in range(n + 1)]
    for x, y, z in edges:
        adj[x].append((y, z))
        adj[y].append((x, z))
    return adj


def _explore(root, adj, coef, off):
    """Parametrise the connected component containing ``root``.

    Every vertex v reached by a BFS from ``root`` is written as
    ``A_v = coef[v] * t + off[v]`` with ``coef[v]`` in {1, -1}, where t is
    the (unknown) value of ``A_root``.  ``coef`` and ``off`` are shared
    per-vertex lists filled in place; ``coef[v] == 0`` means "unvisited".

    Returns ``None`` if the component's constraints contradict each other.
    Otherwise returns ``(nodes, forced)``: ``nodes`` lists the component's
    vertices, and ``forced`` is the single integer t the constraints admit
    (when some edge pins t down) or ``None`` when t is still free.
    """
    coef[root], off[root] = 1, 0
    nodes = [root]
    queue = deque([root])
    forced = None
    while queue:
        v = queue.popleft()
        for w, z in adj[v]:
            if not coef[w]:
                # First time we reach w: A_w = z - A_v.
                coef[w] = -coef[v]
                off[w] = z - off[v]
                nodes.append(w)
                queue.append(w)
                continue
            # Both endpoints already expressed in t, so the edge becomes the
            # linear equation (coef[v] + coef[w]) * t == z - off[v] - off[w].
            c = coef[v] + coef[w]
            rhs = z - off[v] - off[w]
            if c == 0:
                # t cancels out: the equation either always or never holds.
                if rhs != 0:
                    return None
            elif rhs % c:
                # t would have to be a non-integer.
                return None
            else:
                t = rhs // c  # exact division, c is +-2
                if forced is not None and forced != t:
                    return None
                forced = t
    return nodes, forced


def _count_at_most(limit, components):
    """Count solutions with every ``A_i`` in ``[1, limit]``, modulo MOD.

    ``components`` holds ``(pairs, forced)`` tuples: each pair ``(a, b)``
    means ``A = a * t + b`` for that component's parameter t, and
    ``forced`` (if not ``None``) is the only allowed t.  The admissible t
    values of a component form an integer interval, and components are
    independent, so the interval lengths multiply.
    """
    total = 1
    for pairs, forced in components:
        if forced is None:
            lo, hi = -float("inf"), float("inf")
        else:
            lo = hi = forced  # degenerate interval [t, t]
        for a, b in pairs:
            if a == 1:
                # 1 <= t + b <= limit
                lo = max(lo, 1 - b)
                hi = min(hi, limit - b)
            else:
                # 1 <= -t + b <= limit
                lo = max(lo, b - limit)
                hi = min(hi, b - 1)
        # pairs always contains the root, so lo/hi are finite ints here.
        if lo > hi:
            return 0
        total = total * (hi - lo + 1) % MOD
    return total


def _solve(n, k, edges):
    """Return the number of solutions whose maximum is exactly k, mod MOD."""
    adj = _build_graph(n, edges)
    coef = [0] * (n + 1)
    off = [0] * (n + 1)
    components = []
    for root in range(1, n + 1):
        if coef[root]:
            continue
        found = _explore(root, adj, coef, off)
        if found is None:
            return 0
        nodes, forced = found
        components.append(([(coef[v], off[v]) for v in nodes], forced))
    # Values are >= 1, so "max == k" = "all <= k" minus "all <= k - 1".
    return (_count_at_most(k, components) - _count_at_most(k - 1, components)) % MOD


def main():
    """Read ``N M K`` then M triples ``X Y Z`` from stdin and print the answer.

    Each triple is the constraint ``A_X + A_Y = Z``.  The printed value is
    the number of integer sequences A_1..A_N with ``1 <= A_i <= K`` that
    satisfy every constraint and have maximum exactly K, modulo MOD.
    """
    tokens = sys.stdin.read().split()
    n, m, k = map(int, tokens[:3])
    values = list(map(int, tokens[3:3 + 3 * m]))
    edges = zip(values[0::3], values[1::3], values[2::3])
    print(_solve(n, k, edges))


if __name__ == '__main__':
    main()
0