結果

問題 No.430 文字列検索
ユーザー HIcoderHIcoder
提出日時 2024-08-25 12:24:25
言語 C++17
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 161 ms / 2,000 ms
コード長 10,103 bytes
コンパイル時間 1,370 ms
コンパイル使用メモリ 139,068 KB
実行使用メモリ 27,264 KB
最終ジャッジ日時 2024-11-10 01:13:00
合計ジャッジ時間 2,450 ms
ジャッジサーバーID
(参考情報)
judge4 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 1 ms
5,248 KB
testcase_01 AC 161 ms
27,264 KB
testcase_02 AC 21 ms
5,248 KB
testcase_03 AC 21 ms
5,248 KB
testcase_04 AC 2 ms
5,248 KB
testcase_05 AC 1 ms
5,248 KB
testcase_06 AC 2 ms
5,248 KB
testcase_07 AC 1 ms
5,248 KB
testcase_08 AC 153 ms
26,880 KB
testcase_09 AC 2 ms
5,248 KB
testcase_10 AC 12 ms
6,016 KB
testcase_11 AC 63 ms
8,192 KB
testcase_12 AC 62 ms
8,192 KB
testcase_13 AC 65 ms
8,320 KB
testcase_14 AC 52 ms
6,656 KB
testcase_15 AC 43 ms
5,888 KB
testcase_16 AC 25 ms
5,248 KB
testcase_17 AC 22 ms
5,248 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include<iostream>
#include<string>
#include<queue>
#include<vector>
#include<cassert>
#include<random>
#include<set>
#include<map>
#include<cassert>
#include<unordered_map>
#include<bitset>
#include<numeric>
#include<algorithm>
using namespace std;
typedef long long ll;
const int inf=1<<30;
const ll INF=1LL<<62;
typedef pair<int,ll> P;
typedef pair<int,P> PP; 
const ll MOD=998244353;

// SA-IS (O(N))
template<class Str> struct SuffixArray {
    // data
    Str str;
    //sa[0]は空文字になっていることに注意
    vector<int> sa, rsa, lcp;//勝手に代入される
    int& operator [] (int i) {
        return sa[i];
    }

    // constructor
    SuffixArray(const Str& str_) : str(str_) {
        build_sa();
    }
    void init(const Str& str_) {
        str = str_;
        build_sa();
    }
    void build_sa() {
        int N = (int)str.size();
        vector<int> s;
        for (int i = 0; i < N; ++i) s.push_back(str[i] + 1);
        s.push_back(0);
        sa = sa_is(s);
        rsa.assign(N + 1, 0);
        for (int i = 0; i <= N; ++i) rsa[sa[i]] = i;
        calc_lcp(s);
    }

    // SA-IS
    // upper: # of characters 
    vector<int> sa_is(vector<int> &s, int upper = 256) {
        int N = (int)s.size();
        if (N == 0) return {};
        else if (N == 1) return {0};
        else if (N == 2) {
            if (s[0] < s[1]) return {0, 1};
            else return {1, 0};
        }

        vector<int> isa(N);
        vector<bool> ls(N, false);
        for (int i = N - 2; i >= 0; --i) {
            ls[i] = (s[i] == s[i + 1]) ? ls[i + 1] : (s[i] < s[i + 1]);
        }
        vector<int> sum_l(upper + 1, 0), sum_s(upper + 1, 0);
        for (int i = 0; i < N; ++i) {
            if (!ls[i]) ++sum_s[s[i]];
            else ++sum_l[s[i] + 1];
        }
        for (int i = 0; i <= upper; ++i) {
            sum_s[i] += sum_l[i];
            if (i < upper) sum_l[i + 1] += sum_s[i];
        }

        auto induce = [&](const vector<int> &lms) -> void {
            fill(isa.begin(), isa.end(), -1);
            vector<int> buf(upper + 1);
            copy(sum_s.begin(), sum_s.end(), buf.begin());
            for (auto d: lms) {
                if (d == N) continue;
                isa[buf[s[d]]++] = d;
            }
            copy(sum_l.begin(), sum_l.end(), buf.begin());
            isa[buf[s[N - 1]]++] = N - 1;
            for (int i = 0; i < N; ++i) {
                int v = isa[i];
                if (v >= 1 && !ls[v - 1]) {
                    isa[buf[s[v - 1]]++] = v - 1;
                }
            }
            copy(sum_l.begin(), sum_l.end(), buf.begin());
            for (int i = N - 1; i >= 0; --i) {
                int v = isa[i];
                if (v >= 1 && ls[v - 1]) {
                    isa[--buf[s[v - 1] + 1]] = v - 1;
                }
            }
        };
            
        vector<int> lms, lms_map(N + 1, -1);
        int M = 0;
        for (int i = 1; i < N; ++i) {
            if (!ls[i - 1] && ls[i]) {
                lms_map[i] = M++;
            }
        }
        lms.reserve(M);
        for (int i = 1; i < N; ++i) {
            if (!ls[i - 1] && ls[i]) {
                lms.push_back(i);
            }
        }
        induce(lms);

        if (M) {
            vector<int> lms2;
            lms2.reserve(isa.size());
            for (auto v: isa) {
                if (lms_map[v] != -1) lms2.push_back(v);
            }
            int rec_upper = 0;
            vector<int> rec_s(M);
            rec_s[lms_map[lms2[0]]] = 0;
            for (int i = 1; i < M; ++i) {
                int l = lms2[i - 1], r = lms2[i];
                int nl = (lms_map[l] + 1 < M) ? lms[lms_map[l] + 1] : N;
                int nr = (lms_map[r] + 1 < M) ? lms[lms_map[r] + 1] : N;
                bool same = true;
                if (nl - l != nr - r) same = false;
                else {
                    while (l < nl) {
                        if (s[l] != s[r]) break;
                        ++l, ++r;
                    }
                    if (l == N || s[l] != s[r]) same = false;
                }
                if (!same) ++rec_upper;
                rec_s[lms_map[lms2[i]]] = rec_upper;
            }
            auto rec_sa = sa_is(rec_s, rec_upper);

            vector<int> sorted_lms(M);
            for (int i = 0; i < M; ++i) {
                sorted_lms[i] = lms[rec_sa[i]];
            }
            induce(sorted_lms);
        }
        return isa;
    }

    // prepair lcp
    vector<int> calc_lcp(const vector<int> &s) {
        int N = (int)s.size();
        lcp.assign(N - 1, 0);
        int h = 0;
        for (int i = 0; i < N - 1; ++i) {
            int pi = sa[rsa[i] - 1];
            if (h > 0) --h;
            for (; pi + h < N && i + h < N; ++h) {
                if (s[pi + h] != s[i + h]) break;
            }
            lcp[rsa[i] - 1] = h;
        }
        return lcp;
    }
};


struct z_algorithm{

    std::string s;
    int len;
    std::vector<int> z_array;
    bool is_build;
    z_algorithm(const std::string& s_):s(s_),len(s_.size()),is_build(false){
        z_array=std::vector<int>(len);
        build();
    }

    ~z_algorithm(){
        std::vector<int>().swap(z_array);
    }

    
    void build(){
        is_build = true;
        z_array[0]=len;
        int i=1;
        int j=0;
        while(i<len){
            while(i+j<len && s[i+j]==s[j])j++;

            //一致しなくなる or それ以上伸ばせない場合
            z_array[i]=j;

            if(j==0){
                i++;
                continue;
            }

            //j>0

            int k=1;
            while(k<j && k+z_array[k]<j){
                z_array[i+k]=z_array[k];
                k++;
            }
            i+=k;
            j-=k;
            
        }
    }   
    
    int operator[](int idx){
        assert(0<=idx && idx<len);
        if(!is_build) build();

        return z_array[idx];
    }


    void print(){
        for(int i=0;i<len;i++){
            std::cout<<"z_array["<<i<<"]="<<z_array[i]<<std::endl;
        }
        
    }
    
};



struct Trie{
    struct Node{
        std::map<char,int> to;
        int cnt;
        bool is_end;
        Node():cnt(0),is_end(false){}
    };

    vector<Node> nodes;
    Trie():nodes(1){}

    void insert(const string &s){
        int now=0;
        for(char c:s){
            if(!nodes[now].to.count(c)){
                nodes[now].to[c]=nodes.size();
                nodes.emplace_back(Node());
            }
            now=nodes[now].to[c];
        }
        nodes[now].cnt++;
        nodes[now].is_end=true;
    }
    ll ans=0;
    int dfs(int v){
        int res=nodes[v].cnt;
        for(auto [c,nv]:nodes[v].to){
            res+=dfs(nv);
        }

        if(v>0)ans+=res*(res-1)/2;
        return res;
    }

};





class RollingHash{

    long long  gcd(long long  x,long long  y){
        return y?gcd(y,x%y):x;
    }

    public:
        
        std::vector<long long> h; 
        std::vector<long long> powBase;
        int len;
        std::string str;
        long long mod;
        long long base;
        //RollingHash(const std::string& str,long long  base_=1000000+7,long long  mod_=1000000000+7):base(base_),mod(mod_){
        RollingHash(const std::string& str,long long  base_=1000000+7,long long  mod_=998244353):base(base_),mod(mod_){

            len=str.size();
            h = std::vector<long long>(len+1,0);
            powBase = std::vector<long long>(len+1,0);


            assert(gcd(base_,mod_)==1);
            h[0]=0;
            for(int i=1;i<=len;i++){
                h[i]=(h[i-1]*base%mod + (int)str[i-1])%mod; 
            }

            powBase[0]=1;
            for(int i=1;i<=len;i++){
                powBase[i]=powBase[i-1]*base%mod;
            }


        }


        ~RollingHash(){
            std::vector<long long>().swap(h);
            std::vector<long long>().swap(powBase);

        }


        //[0,idx]のハッシュ値を取得
        long long  get_hash(int idx){
            assert(0<=idx && idx<len);
            return h[idx+1];
        }

        long long pow(int i){
            assert(0<=i && i<=len);
            return powBase[i];
        }

        //0-indexed [l,r)のハッシュを取り出す
        long long  get_hash(int l,int r){
            assert(0<=l && r<=len);
            assert(l<r);

            int len=r-l;//[l,r)

            //h[r]-h[l]
            return (h[r] - (powBase[len]*h[l])%mod + mod)%mod;
        }

        long long getBase()const{
            return base;
        }

        long long getMod()const{
            return mod;
        }

        std::string getString()const{
            return str;
        }

        //[l,r)にしたい
        std::string getSubString(int l,int r)const{
            int len=r-l;
            return str.substr(l,len);
        }

        //[lからlen文字
        std::string getSubString(int l,unsigned int len)const{
            return str.substr(l,len);
        }



};




int main(){
    string S;
    cin>>S;

    // map<string,int> mp;
    // for(int i=0;i<S.size();i++){
    //     for(int j=1;j+i<=min(static_cast<int>(S.size()),10);j++){
    //         mp[S.substr(i,j)]++;
    //     }
    // }   

    int M;
    cin>>M;
    vector<string> C(M);
    for(int i=0;i<M;i++){
        cin>>C[i];
    }
    
    ll ans=0;
    /*
    //zアルゴリズムは厳しい
    for(int i=0;i<M;i++){
        string tmp=C[i]+'$'+S;

        z_algorithm za(tmp);
        int len=C[i].size();

        int cnt=0;
        for(int j=len+1;j+len<=tmp.size();j++){
            if(za[j]>=len){
                cnt++;
            }
        }
        ans+=cnt;
    }
    cout<<ans<<endl;
    */
    RollingHash rh(S);
    vector<map<ll,ll>> mp(11);

    for(int x=1;x<=10;x++){
        for(int i=0;i+x<=S.size();i++){
            ll hash=rh.get_hash(i,i+x);
            mp[x][hash]++;
        }
    }
    
    for(int i=0;i<M;i++){
        RollingHash rh2(C[i]);
        ans+=mp[C[i].size()][rh2.get_hash(0,C[i].size())];
    }
    cout<<ans<<endl;
    
    
}
0