結果

問題 No.3182 recurrence relation’s intersection sum
コンテスト
ユーザー Today03
提出日時 2025-11-07 08:09:49
言語 C++23
(gcc 13.3.0 + boost 1.87.0)
結果
WA  
実行時間 -
コード長 17,357 bytes
コンパイル時間 4,222 ms
コンパイル使用メモリ 309,316 KB
実行使用メモリ 7,848 KB
最終ジャッジ日時 2025-11-07 08:09:56
合計ジャッジ時間 6,081 ms
ジャッジサーバーID
(参考情報)
judge3 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample WA * 3
other WA * 40
権限があれば一括ダウンロードができます

ソースコード

diff #

#include<bits/stdc++.h>
using namespace std;
#define rep(i,n) for(int i=0; i<n; i++)
#define all(x) (x).begin(), (x).end()
bool chmax(auto& a, auto b) { return a<b ? a=b, true : false; }
bool chmin(auto& a, auto b) { return a>b ? a=b, true : false; }
using ll=long long; const int INF=1e9+10; const ll INFL=4e18;

#ifdef DEBUG
#include "./debug.hpp"
#else
#define debug(...)
#define print_line
#endif



/// @brief ModInt
template<ll MOD>
struct ModInt {
    ModInt(ll x=0){ value=(x>=0?x%MOD:MOD-(-x)%MOD); }

    ModInt operator-() const { return ModInt(-value); }
    ModInt operator+() const { return ModInt(*this); }
    ModInt& operator+=(const ModInt& other) {
        value+=other.value;
        if(value>=MOD) value-=MOD;
        return*this;
    }
    ModInt& operator-=(const ModInt& other) {
        value+=MOD-other.value;
        if(value>=MOD) value-=MOD;
        return*this;
    }
    ModInt& operator*=(const ModInt other) {
        value=value*other.value%MOD;
        return*this;
    }
    ModInt& operator/=(ModInt other) {
        (*this)*=other.inv();
        return*this;
    }
    ModInt operator+(const ModInt& other) const { return ModInt(*this)+=other; }
    ModInt operator-(const ModInt& other) const { return ModInt(*this)-=other; }
    ModInt operator*(const ModInt& other) const { return ModInt(*this)*=other; }
    ModInt operator/(const ModInt& other) const { return ModInt(*this)/=other; }
    bool operator==(const ModInt& other) const { return value==other.value; }
    bool operator!=(const ModInt& other) const { return value!=other.value; }
    friend ostream& operator<<(ostream& os, const ModInt& x) { return os<<x.value; }
    friend istream& operator>>(istream& is, ModInt& x) {
        ll v;
        is>>v;
        x=ModInt<MOD>(v);
        return is;
    }

    ModInt pow(ll x) const { 
        ModInt ret(1),mul(value);
        while(x) {
            if(x&1) ret*=mul;
            mul*=mul;
            x>>=1;
        }
        return ret;
    }
    ModInt inv() const { return pow(MOD-2); }
    ll val() {return value; }
    static constexpr ll mod() { return MOD; }

private:
    ll value;
};

using Mod998=ModInt<998244353>;
using Mod107=ModInt<1000000007>;



/// @brief 数論変換
/// @note O(N log(N))
/// @details f(x) = Σ a[i]x^i, w^N = 1 とすると、F(t) = Σ f(w^i)t^i の各係数 mod 998244353 に変換する
void NTT998(vector<Mod998>& a, bool inv=false) {
    static bool flag=false;
    static int divide_max;
    static vector<Mod998> roots, inv_roots, tmp;
    if(!flag) {
        flag=true;
        divide_max=1;
        ll n=998244353-1;
        while(n%2==0) n>>=1,divide_max++;
        roots=vector<Mod998>(divide_max+1);
        inv_roots=vector<Mod998>(divide_max+1);
        roots[0]=inv_roots[0]=1;
        for(int i=1; i<=divide_max; i++) {
            roots[i]=Mod998(3).pow((998244353-1)/(1<<i));
            inv_roots[i]=roots[i].inv();
        }
    }

    tmp=vector<Mod998>(a.size());
    int n=a.size(), mask=n-1, p=0;
    for(int i=n>>1; i>=1; i>>=1) {
        auto& cur=((p&1) ? tmp : a);
        auto& nxt=((p&1) ? a : tmp);
        Mod998 e=(inv ? roots[p+1] : inv_roots[p+1]);
        Mod998 w=1;
        for(int j=0; j<n; j+=i) {
            rep(k,i) nxt[j+k]=cur[((j<<1)&mask)+k]+w*cur[(((j<<1)+i)&mask)+k];
            w*=e;
        }
        p++;
    }
    if(p&1) swap(a,tmp);
    if(inv) {
        Mod998 inv_sz=Mod998(n).inv();
        for(int i=0; i<n; i++) a[i]*=inv_sz;
    }
}

/// @brief AとBの畳み込み C[i] = Σ A[j]B[i-j] mod 998244353 を返す
/// @note O(N log(N))
/// @attention |a|+|b| <= 2^23 が必要
vector<Mod998> Convolve998(vector<Mod998> a,vector<Mod998> b) {
    int n=1;
    while(n+1<a.size()+b.size()) n<<=1;

    vector<Mod998> fa(n), fb(n);
    for(int i=0; i<a.size(); i++) fa[i]=a[i];
    for(int i=0; i<b.size(); i++) fb[i]=b[i];

    NTT998(fa); NTT998(fb);
    for(int i=0; i<n; i++) fa[i]*=fb[i];
    NTT998(fa,true);
    
    while(fa.size()+1>a.size()+b.size()) fa.pop_back();
    return fa;
}


using Fps=vector<Mod998>; //< 多項式
using FpsSparse=vector<pair<int,Mod998>>; //< 疎な多項式((次数, 係数)の配列)

/// @brief 形式的冪級数
/// @ref https://potato167.github.io/po167_library
namespace FPS {
    /// @brief 多項式 f, g の和を返す
    Fps Add(const Fps& a, const Fps& b) {
        Fps res(max(a.size(),b.size()));
        for(int i=0; i<res.size(); i++) {
            if(i<a.size()) res[i]+=a[i];
            if(i<b.size()) res[i]+=b[i];
        }
        return res;
    }

    /// @brief 多項式 f, g の差を返す
    Fps Sub(const Fps& f, const Fps& g, int len=-1) {
        if(len==-1) len=max(f.size(),g.size());
        Fps res(len);
        for(int i=0; i<len; i++) {
            if(i<f.size()) res[i]+=f[i];
            if(i<g.size()) res[i]-=g[i];
        }
        return res;
    }

    /// @brief 多項式 f, g の積を返す
    Fps Mul(const Fps& f, const Fps& g, int len=-1) {
        auto fg=Convolve998(f,g);
        if(len!=-1) fg.resize(len+1);
        return fg;
    }

    /// @brief 多項式 f に対し、f*g = 1 なる g を返す
    /// @note O(N log(N))
    /// @ref https://judge.yosupo.jp/problem/inv_of_formal_power_series
    Fps Inv(Fps f, int len=-1) {
        /**
         * 方程式 h(g) = 1/g - f = 0 の解を求める問題に変換する
         * ニュートン法より
         *  g_{n+1} = g_n - h(g_n)/h'(g_n)
         * h'(g) = - 1/g^2 より
         *  g_{n+1} = g_n - (1/g_n - f) / (-1/g_n^2)
         *          = g_n + g_n - f*g_n^2
         *          = g_n * (2 - f*g_n)
         * ここで、f*g_n を上位次数の項と下位次数の項 fg_low, fg_high に分けると、
         *  fg_low + fg_high = f*g_n, fg_low = 1 (mod x^d)
         * よって、
         *  g_{n+1} = g_n * (2 - fg_low - fg_high)
         *          = g_n * (1 - fg_high)
         *          = g_n - g_n*fg_high
         */

        if(len==-1) len=f.size();
        assert(f[0]!=0);
        Fps g={f[0].inv()};
        int s=1;
        while(s<len) {
            Fps nxtg(s*2,0), res(s*2,0);
            g.resize(s*2);
            for(int i=0; i<s*2; i++) {
                if(int(f.size())>i) res[i]=f[i];
                nxtg[i]=g[i];
            }

            //fg_high を計算
            NTT998(g); NTT998(res);
            for(int i=0; i<s*2; i++) res[i]*=g[i];
            NTT998(res,true); for(int i=0; i<s; i++) res[i]=0; //fg_high

            //fg_high * g_n を計算
            NTT998(res); for(int i=0; i<s*2; i++) res[i]*=g[i]; NTT998(res,true);

            for(int i=s; i<s*2; i++) nxtg[i]-=res[i];
            swap(nxtg,g);
            s*=2;
        }
        g.resize(len);
        return g;
    }

    /// @brief 多項式 f の積分を返す
    Fps Integral(Fps f) {
        if(f.empty()) return f;
        Fps num_inv((int)f.size()+1);
        num_inv[0]=1; num_inv[1]=1;
        auto m=Mod998::mod();
        for(int i=2; i<=(int)f.size(); i++) num_inv[i]=(-num_inv[m%i])*(Mod998)(m/i);
        f.reserve((int)f.size()+1);
        f.push_back(0);
        for(int i=(int)f.size()-1; i>0; i--) f[i]=f[i-1]*num_inv[i];
        f[0]=0;
        return f;
    }

    /// @brief 多項式 f の微分を返す
    Fps Differential(Fps f) {
        if(f.empty()) return f;
        for(int i=0; i<(int)f.size()-1; i++) f[i]=f[i+1]*(Mod998)(i+1);
        f.pop_back();
        return f;
    }

    /// @brief 多項式 f, g について、`f = gq + r` なる q, r を返す
    pair<Fps,Fps> Div(Fps f, Fps g) {
        int n=f.size(),m=g.size();
        if(n<m) return{{},f};
        Fps r=f;
        reverse(all(f)); reverse(all(g));
        f.resize(n-m+1);
        Fps q=Mul(f,Inv(g,n-m+1));
        q.resize(n-m+1);
        reverse(all(q)); reverse(all(g));
        r=Sub(r,Mul(q,g));
        while(!q.empty()&&q.back()==0) q.pop_back();
        while(!r.empty()&&r.back()==0) r.pop_back();
        return {q,r};
    }

    /// @brief 多項式 f, g の積を返す(ただし、g は疎な多項式として与える)
    Fps MulSparse(Fps f, FpsSparse g) {
        auto itr=find_if(g.begin(),g.end(),[&](auto p) { return p.first==0; });
        Mod998 x0=0;
        if(itr!=g.end()) {
            x0=itr->second;
            g.erase(itr);
        }

        for(int i=(int)f.size()-1; i>=0; i--) {
            for(auto& [d,c]: g) {
                if(i+d>=f.size()) continue;
                f[i+d]+=f[i]*c;
            }
            f[i]*=x0;
        }

        return f;
    }

    /// @brief 多項式 f, g に対し、 f / g を返す(ただし、g は疎な多項式として与える)
    Fps DivSparse(Fps f, FpsSparse g) {
        auto itr=find_if(g.begin(),g.end(),[&](auto p) { return p.first==0; });
        assert(itr!=g.end());
        Mod998 x0_inv=itr->second.inv();
        g.erase(itr);

        for(int i=0; i<f.size(); i++) {
            f[i]*=x0_inv;
            for(auto& [d,c]: g) {
                if(i+d>=f.size()) continue;
                f[i+d]-=f[i]*c;
            }
        }

        return f;
    }

    namespace Internal {
        const int PRIMITIVE_ROOT=3;

        Fps CyclicConvolution(Fps f, Fps g) {
            NTT998(f); NTT998(g);
            for(int i=0; i<(int)f.size(); i++) f[i]*=g[i];
            NTT998(f,true);
            return f;
        }

        //in :DFT(v)(len(v)=z)
        //out:DFT(v)(len(v)=2*z)
        void Extend(Fps& v) {
            int z=v.size();
            Mod998 e=Mod998(PRIMITIVE_ROOT).pow(Mod998::mod()/(2*z));
            auto cp=v;
            NTT998(cp,true);
            Mod998 tmp=1;
            for(int i=0; i<z; i++) {
                cp[i]*=tmp;
                tmp*=e;
            }
            NTT998(cp);
            for(int i=0; i<z; i++) v.push_back(cp[i]);
        }

        //s.t|v|=2^s(no assert)
        void PickEvenOdd(Fps& v, int odd) {
            int z=v.size()/2;
            Mod998 half=(Mod998)(1)/(Mod998)(2);
            if(odd==0) {
                for(int i=0; i<z; i++) v[i]=(v[i*2]+v[i*2+1])*half;
                v.resize(z);
            }else{
                Mod998 e=Mod998(PRIMITIVE_ROOT).pow(Mod998::mod()/(2*z));
                Mod998 ie=Mod998(1)/e;
                Fps es={half};
                while((int)es.size()!=z) {
                    Fps n_es((int)es.size()*2);
                    for(int i=0; i<(int)es.size(); i++) {
                        n_es[i*2]=(es[i]);
                        n_es[i*2+1]=(es[i]*ie);
                    }
                    ie*=ie;
                    swap(n_es,es);
                }
                for(int i=0; i<z; i++) v[i]=(v[i*2]-v[i*2+1])*es[i];
                v.resize(z);
            }
        }
    }

    /// @brief 多項式 f について、e^f = Σ[k=0~len-1](f(x)^k/k!) を返す
    Fps Exp(Fps f, int len=-1) {
        /**
         * 方程式 h(g) = log(g) - f = 0 の解を求める問題に変換する。
         * ニュートン法より
         *  g_{n+1} = g_n - h(g_n)/h'(g_n)
         * h(g) = log(g) - f, h'(g) = 1/g  より
         *  g_{n+1} = g_n - (log(g_n) - f) / (1/g_n)
         *          = g_n - g_n(log(g_n) - f)
         *          = g_n * (1 - log(g_n) + f)
         * ここで、log(g_n) を上位項と下位項 lg_low, lg_high に分けると、
         *  log(g_n) = lg_low + lg_high, lg_low = f_low
         * よって、
         *  g_{n+1} = g_n(1 - f_low - lg_high + f)
         *          = g_n(1 + (f - log(g))_high)
         *          = g_n + g_n * (f - log(g))_high
         */

        if(len==-1) len=f.size();
        if(len==0) return{};
        if(len==1) return{Mod998(1)};
        assert(!f.empty() && f[0]==0);

        int s=1;
        Fps g={Mod998(1)};

        while(s<len) {
            Fps A=g,B=g;
            A=Differential(A); //A = g'
            B=Inv(B,2*s); //B = 1/g
            A.resize(2*s);
            A=Internal::CyclicConvolution(A,B); //A = g'/g
            A.pop_back();
            A=Integral(A); //A = ∫(g'/g)dx = log(g)

            //A = (f-log(g_n))_high
            for(int i=0; i<s; i++) A[i]=0;
            for(int i=s; i<s*2; i++) A[i]=(i<(int)f.size() ? f[i] : 0)-A[i];

            //g_{n+1} = g_n + g_n * (f-log(g))_high
            g.resize(2*s);
            B=Internal::CyclicConvolution(A,g);
            for(int i=s; i<s*2; i++) g[i]=B[i];
            s*=2;
        }
        g.resize(len);
        return g;
    }

    /// @brief 多項式 f について、log(f) を返す
    Fps Log(Fps f, int len=-1) {
        if(len==-1) len=f.size();
        if(len==0) return{};
        if(len==1) return{Mod998(0)};
        assert(!f.empty()&&f[0]==1);
        Fps res=Convolve998(Differential(f),Inv(f,len));
        res.resize(len-1);
        return Integral(res);
    }

    /// @brief 多項式 f^M を返す
    Fps Pow(Fps f, ll M, int len=-1) {
        if(len==-1) len=f.size();
        Fps res(len,0);
        if(M==0) {
            res[0]=1;
            return res;
        }
        for(int i=0; i<(int)f.size(); i++) {
            if(f[i]==0) continue;
            if(i>(len-1)/M) break;
            Fps g((int)f.size()-i);
            Mod998 v=(Mod998)(1)/(Mod998)(f[i]);
            for(int j=i; j<(int)f.size(); j++) g[j-i]=f[j]*v;
            ll zero=i*M;
            if(i) len-=i*M;
            g=Log(g,len);
            for(Mod998& x:g) x*=M;
            g=Exp(g,len);
            v=(Mod998)(1)/v;
            Mod998 c=1;
            while(M) {
                if(M&1) c=c*v;
                v=v*v;
                M>>=1;
            }
            for(int i=0; i<len; i++) res[i+zero]=g[i]*c;
            return res;
        }
        return res;
    }

    /// @brief `[x^k](P/Q)` を返す
    Mod998 BostanMori(ll k, Fps P, Fps Q) {
        assert(!Q.empty()&&Q[0]!=0);
        int z=1;
        while(z<(int)max(P.size(),Q.size())) z*=2;
        P.resize(z*2,0); Q.resize(z*2,0);
        NTT998(P); NTT998(Q);
        //fast
        while(k) {
            //Q(-x)
            Fps Q_n(z*2);
            for(int i=0; i<z; i++) {
                Q_n[i*2]=Q[i*2+1];
                Q_n[i*2+1]=Q[i*2];
            }
            for(int i=0; i<z*2; i++) {
                P[i]*=Q_n[i];
                Q[i]*=Q_n[i];
            }
            Internal::PickEvenOdd(P,k&1);
            Internal::PickEvenOdd(Q,0);
            k/=2;
            if(k==0) break;
            Internal::Extend(P);
            Internal::Extend(Q);
        }
        Mod998 SP=0,SQ=0;
        for(int i=0; i<z; i++) SP+=P[i],SQ+=Q[i];
        return SP/SQ;
    }

    //0=a[i]*c[0]+a[i-1]*c[1]+a[i-2]*c[2]+...+a[i-d]*c[d]
    //a.size()+1==c.size()
    //c[0]=-1?
    //return a[k]
    Mod998 KthLinear(ll k, Fps a, Fps c) {
        int d=a.size();
        assert(d+1==int(c.size()));
        Fps P=Convolve998(a,c);
        P.resize(d);
        return BostanMori(k,P,c);
    }
}


vector<Mod998> BerlekampMassey(const vector<Mod998>& s) {
    int n=s.size();
    vector<Mod998> c={1}, b={1};
    int l=0, m=1;
    Mod998 bb=1;
    for(int i=0; i<n; i++) {
        Mod998 d=0;
        for(int j=0; j<=l; j++) d+=c[j]*s[i-j];
        if(d==Mod998(0)) m++;
        else if(2*l<=i) {
            auto t=c; auto coef=d*bb.inv();
            c.resize(max((int)b.size()+m,(int)c.size()),0);
            for(int j=0; j<b.size(); j++) c[j+m]-=coef*b[j];
            l=i+1-l; b=t; bb=d; m=1;
        } else {
            auto coef=d*bb.inv();
            c.resize(max((int)b.size()+m,(int)c.size()),0);
            for(int j=0; j<b.size(); j++) c[j+m]-=coef*b[j];
            m++;
        }
    }
    c.erase(c.begin());
    for(auto& x: c) x=-x;
    return c;

    // const int N = (int)f.size(); vector<Mod998> b, c;
    // b.reserve(N + 1); c.reserve(N + 1); b.push_back(Mod998(1)); c.push_back(Mod998(1));
    // Mod998 y = Mod998(1);
    // for (int ed = 1; ed <= N; ed++) {
    //     int l = int(c.size()), m = int(b.size()); Mod998 x = 0;
    //     for (int i = 0; i < l; i++) x += c[i] * f[ed - l + i];
    //     b.emplace_back(Mod998(0)); m++;
    //     if (x == Mod998(0)) continue;
    //     Mod998 freq = x / y;
    //     if (l < m) {
    //         auto tmp = c; c.insert(begin(c), m - l, Mod998(0));
    //         for (int i = 0; i < m; i++) c[m - 1 - i] -= freq * b[m - 1 - i];
    //         b = tmp; y = x;
    //     } else {
    //         for (int i = 0; i < m; i++) c[l - 1 - i] -= freq * b[m - 1 - i];
    //     }
    // }
    // reverse(begin(c), end(c));

    // return c;
}

Mod998 LinearRecurrence(vector<Mod998> a, vector<Mod998> c, ll k) {
    int n=c.size();
    if(n==0) return 0;
    vector<Mod998> dnm(n+1,1);
    rep(i,n) dnm[i+1]=-c[i];
    a.resize(n);
    auto num=Convolve998(dnm,a); num.resize(n);
    return FPS::BostanMori(k,num,dnm);
}

Mod998 BMBM(const vector<Mod998> &s, ll n) {
  auto bm=BerlekampMassey(s);
  return LinearRecurrence(s,bm,n);
}

//----------------------------------------------------------

void solve() {
    ll K,L,R; cin>>K>>L>>R;
    ll N=500;
    vector<Mod998> A(N+1);
    A[0]=1; rep(i,N) A[i+1]=A[i]*K+Mod998(i).pow(K)+Mod998(K).pow(i);
    A.insert(A.begin(),0); rep(i,N+1) A[i+1]+=A[i];
    cout<<BMBM(A,R+1)-BMBM(A,L)<<endl;
}

int main() {
    ios::sync_with_stdio(false); cin.tie(nullptr);
    //cout<<fixed<<setprecision(15);
    int T=1; //cin>>T;
    while(T--) solve();
}
0