Result

| Field | Value |
|---|---|
| Problem | No.1263 ご注文は数学ですか? ("Is the Order Math?") |
| Submitted | 2025-05-14 13:18:16 |
| Language | PyPy3 (7.3.15) |
| Result | AC |
| Execution time | 39 ms / 2,000 ms |
| Code length | 6,161 bytes |
| Compile time | 305 ms |
| Compile memory | 82,448 KB |
| Execution memory | 52,480 KB |
| Last judged | 2025-05-14 13:19:06 |
| Total judge time | 1,169 ms |
| Judge server ID (for reference) | judge2 / judge1 |


| File pattern | Result |
|---|---|
| sample | AC × 1 |
| other | AC × 7 |
Source code
```python
import sys

# Set higher recursion depth if needed, although likely unnecessary for x<=8
# sys.setrecursionlimit(2000)


def solve():
    # Read input integer x
    x = int(sys.stdin.readline())
    # Define the modulus
    MOD = 1000000007

    # Precompute factorials modulo MOD up to x
    # fact[i] will store i! mod MOD
    fact = [1] * (x + 1)
    for i in range(2, x + 1):
        fact[i] = (fact[i-1] * i) % MOD

    # Modular exponentiation function: computes (base^exp) % mod
    # Uses binary exponentiation (also known as exponentiation by squaring)
    def mod_pow(base, exp, mod):
        res = 1
        base %= mod
        while exp > 0:
            # If exponent is odd, multiply result with base
            if exp % 2 == 1:
                res = (res * base) % mod
            # Square the base and halve the exponent
            base = (base * base) % mod
            exp //= 2
        return res

    # Modular inverse function using Fermat's Little Theorem: computes n^(-1) % mod
    # This requires mod to be prime and n not divisible by mod.
    # MOD = 10^9 + 7 is prime, and every z_lambda is a product of integers j and m_j!
    # built only from numbers <= x < MOD, so the product of all z_lambda is never divisible by MOD.
    def mod_inverse(n, mod):
        # Check if n is congruent to 0 mod MOD. If so, inverse doesn't exist.
        # Based on problem constraints (x>=2) and properties of z_lambda, n should not be 0.
        if n % mod == 0:
            raise ValueError("Modular inverse does not exist for zero")
        # Fermat's Little Theorem: n^(mod-2) is congruent to n^(-1) modulo a prime mod
        return mod_pow(n, mod - 2, mod)

    # List to store all partitions of x
    partitions = []
    # Temporary list to build a partition during recursion
    current_partition_list = []

    # Recursive function to generate partitions of 'target' integer
    # Partitions are generated in non-increasing order of parts.
    # 'max_val' restricts the maximum value of parts that can be added.
    def generate_partitions(target, max_val):
        # Base case: If target becomes 0, we have found a valid partition.
        if target == 0:
            # Add a copy of the current partition state to the list of partitions.
            partitions.append(list(current_partition_list))
            return
        # Recursive step: Try adding parts from min(target, max_val) down to 1.
        # This ensures parts are added in non-increasing order.
        for i in range(min(target, max_val), 0, -1):
            # Add part 'i' to the current partition
            current_partition_list.append(i)
            # Recursively call to find partitions for the remaining value 'target - i'
            # The new max_val is 'i' to maintain non-increasing order.
            generate_partitions(target - i, i)
            # Backtrack: remove the last added part to explore other possibilities.
            current_partition_list.pop()

    # Start the partition generation process for the input integer x
    generate_partitions(x, x)
    # Get the total number of partitions, p(x)
    partition_count = len(partitions)

    # Initialize accumulator for the exponent of the overall sign factor (-1)
    total_sign_exponent = 0
    # Initialize accumulator for the product of all z_lambda values, modulo MOD
    # Start with 1 because it's the identity element for multiplication.
    product_z_lambda = 1

    # Iterate through each generated partition lambda
    for p in partitions:
        # Calculate the length (number of parts) of the partition lambda, denoted as ell(lambda)
        ell = len(p)
        # The sign component for this partition's corresponding f_lambda is (-1)^(x - ell(lambda)).
        # Sum the exponents (x - ell(lambda)) to find the total exponent for the overall sign.
        total_sign_exponent += (x - ell)

        # Calculate z_lambda for the current partition p.
        # First, count occurrences of each part size j. Store in dictionary 'counts' where counts[j] = m_j.
        counts = {}
        for part in p:
            counts[part] = counts.get(part, 0) + 1

        # Calculate z_lambda = product over j >= 1 of (j^(m_j) * m_j!) mod MOD
        current_z_lambda = 1
        for j, m_j in counts.items():
            # Calculate j^(m_j) mod MOD
            term_pow = mod_pow(j, m_j, MOD)
            # Get m_j! mod MOD from precomputed factorials
            term_fact = fact[m_j]
            # Calculate the term (j^m_j * m_j!) mod MOD
            term = (term_pow * term_fact) % MOD
            # Accumulate the product for z_lambda
            current_z_lambda = (current_z_lambda * term) % MOD

        # Accumulate the product of all z_lambda values across all partitions
        product_z_lambda = (product_z_lambda * current_z_lambda) % MOD

    # Determine the overall sign factor based on the parity of total_sign_exponent
    sign = 1
    if total_sign_exponent % 2 != 0:
        # If the exponent is odd, the sign is -1. Represent -1 as MOD-1 in modular arithmetic.
        sign = MOD - 1

    # Calculate x! mod MOD using the precomputed value
    xf = fact[x]
    # Calculate (x! ^ p(x)) mod MOD using modular exponentiation
    xf_pow_p = mod_pow(xf, partition_count, MOD)
    # Calculate the modular inverse of (product of z_lambda) mod MOD
    inv_prod_z = mod_inverse(product_z_lambda, MOD)

    # Calculate the final value: sign * (x!^p(x)) * (product_z_lambda)^(-1) mod MOD
    # Reduce modulo MOD at each step to keep intermediate numbers small
    # (Python integers do not overflow, but smaller numbers are faster).
    final_value = (sign * xf_pow_p) % MOD
    final_value = (final_value * inv_prod_z) % MOD

    # With a positive modulus, Python's % always returns a value in [0, MOD-1],
    # so this check is only a defensive no-op here; it would matter when porting
    # to a language such as C++, where % can return a negative result.
    if final_value < 0:
        final_value += MOD

    # Print the final computed value
    print(final_value)


# Execute the main function
solve()
```
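The closed form computed above has a combinatorial reading: for a partition λ of x, x!/z_λ is the number of permutations in S_x with cycle type λ, and (-1)^(x - ℓ(λ)) is the sign of each such permutation. So sign · (x!)^p(x) · (∏ z_λ)^(-1) equals the product, over all cycle types λ, of that sign times the class size. The sketch below is not part of the submission; it cross-checks that reading for small x by enumerating permutations with `itertools.permutations`. The helper names (`partitions`, `z`, `closed_form`, `cycle_type`, `brute_force`) are hypothetical, and the sketch assumes Python 3.8+ for `math.prod` and the built-in `pow(den, -1, MOD)` modular inverse, which stands in for the submission's Fermat-based `mod_inverse`.

```python
from collections import Counter
from itertools import permutations
from math import factorial, prod

MOD = 1_000_000_007


def partitions(n, max_part=None):
    """Yield the partitions of n as non-increasing tuples (hypothetical helper)."""
    if max_part is None:
        max_part = n
    if n == 0:
        yield ()
        return
    for first in range(min(n, max_part), 0, -1):
        for rest in partitions(n - first, first):
            yield (first,) + rest


def z(lam):
    """z_lambda = prod over part sizes j of j^(m_j) * m_j!, m_j = multiplicity of j."""
    return prod(j ** m * factorial(m) for j, m in Counter(lam).items())


def closed_form(x):
    """sign * (x!)^p(x) * (prod of z_lambda)^(-1) mod MOD, mirroring the submission."""
    parts = list(partitions(x))
    sign_exp = sum(x - len(lam) for lam in parts)
    num = pow(factorial(x), len(parts), MOD)
    den = prod(z(lam) for lam in parts) % MOD
    value = num * pow(den, -1, MOD) % MOD  # modular inverse via pow (Python 3.8+)
    return (MOD - value) % MOD if sign_exp % 2 else value


def cycle_type(perm):
    """Cycle type of a permutation of range(len(perm)), as a non-increasing tuple."""
    seen = [False] * len(perm)
    lengths = []
    for start in range(len(perm)):
        if seen[start]:
            continue
        length, i = 0, start
        while not seen[i]:
            seen[i] = True
            i = perm[i]
            length += 1
        lengths.append(length)
    return tuple(sorted(lengths, reverse=True))


def brute_force(x):
    """Product over cycle types of sign * (number of permutations of that type), mod MOD."""
    class_sizes = Counter(cycle_type(p) for p in permutations(range(x)))
    value = 1
    for lam, size in class_sizes.items():
        sign = -1 if (x - len(lam)) % 2 else 1
        value *= sign * size
    return value % MOD
```

For example, `closed_form(3)` and `brute_force(3)` both give `1000000001`, i.e. -6 mod MOD: the classes of S_3 have sizes 1, 3 and 2 with signs +, - and +. The brute force enumerates all x! permutations, so it is only meant as a check for small x.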