結果

問題 No.214 素数サイコロと合成数サイコロ (3-Medium)
ユーザー akakimidoriakakimidori
提出日時 2021-12-09 04:59:30
言語 Rust
(1.77.0)
結果
AC  
実行時間 275 ms / 3,000 ms
コード長 3,905 bytes
コンパイル時間 5,081 ms
コンパイル使用メモリ 156,288 KB
実行使用メモリ 4,380 KB
最終ジャッジ日時 2023-09-23 15:55:42
合計ジャッジ時間 3,445 ms
ジャッジサーバーID
(参考情報)
judge13 / judge11
このコードへのチャレンジ(β)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 268 ms
4,376 KB
testcase_01 AC 252 ms
4,380 KB
testcase_02 AC 275 ms
4,380 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

pub trait SemiRing: Clone {
    fn zero() -> Self;
    fn one() -> Self;
    fn add(&self, rhs: &Self) -> Self;
    fn mul(&self, rhs: &Self) -> Self;
}

pub struct Kitamasa<T> {
    c: Vec<T>,
    tmp: std::cell::RefCell<Vec<T>>,
}

impl<T: SemiRing> Kitamasa<T> {
    pub fn new(c: Vec<T>) -> Self {
        assert!(!c.is_empty());
        Self {
            c: c,
            tmp: std::cell::RefCell::new(vec![]),
        }
    }
    pub fn normalize(&self, d: &mut Vec<T>) {
        let n = self.c.len();
        for i in (n..d.len()).rev() {
            let v = d.pop().unwrap();
            for (d, c) in d[(i - n)..].iter_mut().zip(&self.c) {
                *d = d.add(&v.mul(c));
            }
        }
    }
    pub fn next(&self, d: &mut Vec<T>) {
        d.insert(0, T::zero());
        self.normalize(d);
    }
    pub fn twice(&self, d: &mut Vec<T>) {
        assert!(!d.is_empty());
        let mut tmp = self.tmp.borrow_mut();
        tmp.clear();
        tmp.resize(d.len() * 2 - 1, T::zero());
        for (i, a) in d.iter().enumerate() {
            for (tmp, b) in tmp[i..].iter_mut().zip(d.iter()) {
                *tmp = a.mul(b).add(tmp);
            }
        }
        std::mem::swap(&mut *tmp, d);
        drop(tmp);
        self.normalize(d);
    }
    pub fn kth_coefficient(&self, k: usize) -> Vec<T> {
        let mut t = vec![T::one()];
        if k > 0 {
            let p = (k + 1).next_power_of_two().trailing_zeros() - 1;
            for i in (0..=p).rev() {
                self.twice(&mut t);
                if k >> i & 1 == 1 {
                    self.next(&mut t);
                }
            }
        }
        t.resize(self.c.len(), T::zero());
        t
    }
}

const MOD: u32 = 1_000_000_007;

impl SemiRing for u32 {
    fn zero() -> Self {
        0
    }
    fn one() -> Self {
        1
    }
    fn add(&self, rhs: &Self) -> Self {
        let mut v = *self + *rhs;
        if v >= MOD {
            v -= MOD;
        }
        v
    }
    fn mul(&self, rhs: &Self) -> Self {
        (*self as u64 * *rhs as u64 % MOD as u64) as u32
    }
}

fn read() -> (usize, usize, usize) {
    let mut s = String::new();
    std::io::stdin().read_line(&mut s).unwrap();
    let a = s.trim().split_whitespace().flat_map(|s| s.parse()).collect::<Vec<_>>();
    (a[0], a[1], a[2])
}

fn main() {
    let (n, p, c) = read();
    let prime = [2, 3, 5, 7, 11, 13];
    let composite = [4, 6, 8, 9, 10, 12];
    let len = p * 13 + 12 * c;
    let mut dp = vec![vec![0u32; len + 1]; p + 1];
    dp[0][0] = 1;
    for &a in prime.iter() {
        for i in 0..p {
            let s = std::mem::take(&mut dp[i]);
            for (dp, s) in dp[i + 1][a..].iter_mut().zip(&s) {
                *dp = dp.add(s);
            }
            dp[i] = s;
        }
    }
    let res = dp.pop().unwrap();
    let mut dp = vec![vec![0u32; len + 1]; c + 1];
    dp[0] = res;
    for &a in composite.iter() {
        for i in 0..c {
            let s = std::mem::take(&mut dp[i]);
            for (dp, s) in dp[i + 1][a..].iter_mut().zip(&s) {
                *dp = dp.add(s);
            }
            dp[i] = s;
        }
    }
    let mut c = dp.pop().unwrap();
    c.remove(0);
    let trans = c.clone();
    c.reverse();
    let solver = Kitamasa::new(c);
    let mut ini = vec![0; len];
    ini[0] = 1;
    for i in 0..len {
        let v = ini[i];
        for (ini, c) in ini[(i + 1)..].iter_mut().zip(&trans) {
            *ini = ini.add(&v.mul(c));
        }
    }
    let p = n.saturating_sub(len);
    let mut a = solver.kth_coefficient(p);
    let mut ans = 0;
    for i in p..n {
        let v = a.iter().zip(&ini).fold(0, |s, a| a.0.mul(a.1).add(&s));
        for (j, trans) in trans.iter().enumerate() {
            if i + j + 1 >= n {
                ans = trans.mul(&v).add(&ans);
            }
        }
        solver.next(&mut a);
    }
    println!("{}", ans);
}
0