結果

問題 No.430 文字列検索
ユーザー kk2kk2
提出日時 2024-10-13 04:51:25
言語 C++23(gcc13)
(gcc 13.2.0 + boost 1.83.0)
結果
AC  
実行時間 9 ms / 2,000 ms
コード長 20,325 bytes
コンパイル時間 3,370 ms
コンパイル使用メモリ 237,956 KB
実行使用メモリ 8,160 KB
最終ジャッジ日時 2024-11-10 01:13:32
合計ジャッジ時間 3,384 ms
ジャッジサーバーID
(参考情報)
judge2 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 2 ms 5,248 KB
testcase_01 AC 9 ms 8,160 KB
testcase_02 AC 4 ms 5,248 KB
testcase_03 AC 4 ms 5,248 KB
testcase_04 AC 2 ms 5,248 KB
testcase_05 AC 1 ms 5,248 KB
testcase_06 AC 2 ms 5,248 KB
testcase_07 AC 2 ms 5,248 KB
testcase_08 AC 2 ms 5,248 KB
testcase_09 AC 2 ms 5,248 KB
testcase_10 AC 2 ms 5,248 KB
testcase_11 AC 7 ms 5,840 KB
testcase_12 AC 7 ms 6,128 KB
testcase_13 AC 7 ms 6,260 KB
testcase_14 AC 6 ms 5,972 KB
testcase_15 AC 5 ms 5,908 KB
testcase_16 AC 6 ms 5,972 KB
testcase_17 AC 5 ms 5,840 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#define PROBLEM "https://yukicoder.me/problems/no/430" 

#ifndef DATA_STRUCTURE_TRIE_HPP
#define DATA_STRUCTURE_TRIE_HPP 1

#include <cassert>
#include <cstring>
#include <functional>
#include <string>
#include <vector>

namespace kk2 {

/// A single Trie node over an alphabet of `char_size` symbols.
template <int char_size> struct TrieNode {
    int nxt[char_size];       // child index per symbol, -1 when the child does not exist
    int exist;                // number of added strings that continue into a child of this node
    std::vector<int> accept;  // ids of added strings that end exactly at this node

    // memset with -1 sets every byte to 0xFF, i.e. every int to -1.
    TrieNode() : exist(0) { memset(nxt, -1, sizeof(nxt)); }
};

/// Prefix tree over strings whose characters lie in [margin, margin + char_size).
///
/// Every added string gets a sequential id 0, 1, 2, ... in insertion order
/// (duplicates get distinct ids). Node 0 is the root.
template <int char_size, int margin> struct Trie {
    using Node = TrieNode<char_size>;

    std::vector<Node> nodes;
    constexpr static int root = 0;

    Trie() { nodes.emplace_back(); }

    /// Appends a fresh node and returns its index.
    int push_node() {
        nodes.emplace_back();
        return (int)nodes.size() - 1;
    }

    /// Records that the string with id `id` ends at `node`.
    void update_direct(int node, int id) { nodes[node].accept.push_back(id); }

    /// Records that one more added string passes through `node` into a child.
    void update_child(int node) { ++nodes[node].exist; }

    /// Maps a character to its alphabet index; asserts it is inside the alphabet.
    static int to_index(char c) {
        const int d = c - margin;
        assert(0 <= d && d < char_size);
        return d;
    }

    /// Inserts `str`, which must be non-empty. Its id is the number of strings
    /// added before it.
    ///
    /// Iterative walk: the per-node counters end up exactly as with the former
    /// recursive version, but the call stack no longer grows with |str|.
    void add(const std::string &str) {
        assert(!str.empty());
        const int id = nodes[root].exist;
        int now = root;
        for (char c : str) {
            const int d = to_index(c);
            update_child(now);
            if (nodes[now].nxt[d] == -1) {
                // Create the child first: push_node() may reallocate `nodes`,
                // so no reference into it may be formed before the call.
                const int child = push_node();
                nodes[now].nxt[d] = child;
            }
            now = nodes[now].nxt[d];
        }
        update_direct(now, id);
    }

    /// Calls the function `f` with the id of every added string that is a prefix of `str`.
    template <void (*f)(int)> void query(const std::string &str) {
        query(str, [](int idx) { f(idx); });
    }

    /// Calls `f(id)` for every added string that is a prefix of `str`, shortest
    /// prefixes first (ids ending at the same node in insertion order).
    /// Ids are passed by value, so `f` cannot modify the trie's bookkeeping.
    template <class F> void query(const std::string &str, const F &f) {
        int now = root;
        for (char c : str) {
            for (int idx : nodes[now].accept) f(idx);
            now = nodes[now].nxt[to_index(c)];
            if (now == -1) return;
        }
        for (int idx : nodes[now].accept) f(idx);
    }

    /// Number of strings added so far.
    int count() const { return (int)nodes[0].exist; }

    /// Number of nodes (including the root).
    int size() const { return (int)nodes.size(); }

    /// Number of added strings that have the prefix corresponding to `node_idx`:
    /// those ending at the node plus those continuing below it.
    int size(int node_idx) const {
        return (int)nodes[node_idx].accept.size() + nodes[node_idx].exist;
    }
};

} // namespace kk2

#endif // DATA_STRUCTURE_TRIE_HPP

#ifndef STRING_AHO_CORASICK_HPP
#define STRING_AHO_CORASICK_HPP 1

#include <algorithm>
#include <queue>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>

// #include "../data_structure/trie.hpp"

namespace kk2 {

/// Aho-Corasick automaton for patterns over characters in
/// [margin, margin + char_size).
///
/// Built on a Trie with char_size + 1 slots per node. The extra slot
/// `nxt[FAIL]` holds the failure link once build() has run.
/// Intended usage: add() every pattern, call build() once, then query with
/// all_match() / each_match() / move() / count(node).
template <int char_size, int margin> struct AhoCorasick : Trie<char_size + 1, margin> {
    using Trie<char_size + 1, margin>::Trie;
    using Trie<char_size + 1, margin>::count;

    constexpr static int FAIL = char_size;
    // correct[v]: number of pattern ids accepted at state v or at any state on
    //             its failure chain, i.e. matches that end while in state v.
    // perm:       every state in BFS order from the root (perm[0] is the root).
    std::vector<int> correct, perm;

    /// Adds a non-empty pattern; its id is the number of patterns added before it.
    /// Every character must map into [0, char_size). Index char_size is the
    /// reserved FAIL slot, and a pattern using it would corrupt the failure links.
    void add(const std::string &str) {
        for ([[maybe_unused]] const char c : str) {
            assert(margin <= c && c < margin + char_size);
        }
        Trie<char_size + 1, margin>::add(str);
    }

    /// Computes failure links, completes the goto function (after this every
    /// nxt[i] with i < char_size is a valid state), and fills `correct`/`perm`.
    void build() {
        const int n = this->size();
        correct.resize(n);
        perm.resize(n);
        for (int v = 0; v < n; ++v) correct[v] = (int)this->nodes[v].accept.size();

        int filled = 0;
        perm[filled++] = this->root;

        std::queue<int> que;
        auto &root_node = this->nodes[this->root];
        // `<=` also visits the FAIL slot, making the root its own failure link.
        for (int i = 0; i <= char_size; ++i) {
            if (root_node.nxt[i] == -1) {
                root_node.nxt[i] = this->root;
            } else {
                this->nodes[root_node.nxt[i]].nxt[FAIL] = this->root;
                que.emplace(root_node.nxt[i]);
            }
        }

        while (!que.empty()) {
            const int v = que.front();
            que.pop();
            perm[filled++] = v;
            auto &node = this->nodes[v];
            const int fail = node.nxt[FAIL];
            // `fail` is shallower than v, so BFS has already finalized it.
            correct[v] += correct[fail];
            for (int i = 0; i < char_size; ++i) {
                if (node.nxt[i] == -1) {
                    node.nxt[i] = this->nodes[fail].nxt[i];
                } else {
                    this->nodes[node.nxt[i]].nxt[FAIL] = this->nodes[fail].nxt[i];
                    que.emplace(node.nxt[i]);
                }
            }
        }
    }

    /// Total number of pattern occurrences in `str`, starting from state `now_`.
    long long all_match(const std::string &str, int now_ = 0) {
        assert((int)correct.size() == this->size() && "build() must be called first");
        std::unordered_map<int, int> visit_cnt;
        for (char c : str) {
            now_ = this->nodes[now_].nxt[c - margin];
            visit_cnt[now_]++;
        }
        long long res{};
        for (const auto &[state, cnt] : visit_cnt) res += (long long)correct[state] * cnt;
        return res;
    }

    /// res[id] = number of occurrences in `str` of the pattern with that id,
    /// starting from state `now_`.
    std::vector<long long> each_match(const std::string &str, int now_ = 0) {
        assert((int)perm.size() == this->size() && "build() must be called first");
        std::vector<int> visit_cnt(this->size());
        for (char c : str) {
            now_ = this->nodes[now_].nxt[c - margin];
            visit_cnt[now_]++;
        }
        std::vector<long long> res(this->count());
        // Reverse BFS order: a state's visits are pushed to its failure link
        // only after every deeper state has contributed to it.
        for (int i = this->size() - 1; i > 0; --i) {
            const int v = perm[i];
            visit_cnt[this->nodes[v].nxt[FAIL]] += visit_cnt[v];
            for (int idx : this->nodes[v].accept) res[idx] += visit_cnt[v];
        }
        return res;
    }

    /// Goto transition from state `now` on character `c` (valid after build()).
    int move(int now, char c) { return this->nodes[now].nxt[c - margin]; }

    /// Number of pattern matches that end when the automaton is in state `node`.
    int count(int node) const { return correct[node]; }
};

} // namespace kk2

#endif // STRING_AHO_CORASICK_HPP

// #include "../../string/aho_corasick.hpp"

#ifndef TEMPLATE_FASTIO_HPP
#define TEMPLATE_FASTIO_HPP 1

#include <cctype>
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <iostream>
#include <string>

namespace kk2 {

namespace fastio {

#define INPUT_FILE "in.txt"
#define OUTPUT_FILE "out.txt"

/// Buffered reader of whitespace-separated tokens from a C FILE*.
///
/// Input is read INPUT_BUF bytes at a time. A short read is terminated with a
/// '\0' sentinel, which all token readers treat as end of input. The Scanner
/// closes `fp` on destruction unless it is stdin, so it is non-copyable.
struct Scanner {
  private:
    static constexpr size_t INPUT_BUF = 1 << 17;
    size_t pos = INPUT_BUF;
    // Per-instance buffer: a shared static buffer would let two Scanners
    // overwrite each other's unread input.
    char buf[INPUT_BUF];
    FILE *fp;

    // <cctype> classifiers require a value representable as unsigned char (or
    // EOF); passing a negative char (any byte >= 0x80) is undefined behavior.
    static bool is_space(char c) { return std::isspace(static_cast<unsigned char>(c)) != 0; }

    static bool is_digit(char c) { return std::isdigit(static_cast<unsigned char>(c)) != 0; }

  public:
    Scanner() : fp(stdin) {}

    /// Reads from `file`. If it cannot be opened, the input is treated as empty.
    Scanner(const char *file) : fp(fopen(file, "r")) {}

    Scanner(const Scanner &) = delete;
    Scanner &operator=(const Scanner &) = delete;

    ~Scanner() {
        if (fp != nullptr && fp != stdin) fclose(fp);
    }

    /// Returns the current byte without consuming it, refilling the buffer when exhausted.
    char now() {
        if (pos == INPUT_BUF) {
            const size_t len = fp != nullptr ? fread(buf, 1, INPUT_BUF, fp) : 0;
            if (len != INPUT_BUF) buf[len] = '\0';
            pos = 0;
        }
        return buf[pos];
    }

    void skip_space() {
        while (is_space(now())) ++pos;
    }

    uint32_t next_u32() {
        skip_space();
        uint32_t res = 0;
        while (is_digit(now())) {
            res = res * 10 + (now() - '0');
            ++pos;
        }
        return res;
    }

    int32_t next_i32() {
        skip_space();
        if (now() == '-') {
            ++pos;
            return (int32_t)(-next_u32());
        } else return (int32_t)next_u32();
    }

    uint64_t next_u64() {
        skip_space();
        uint64_t res = 0;
        while (is_digit(now())) {
            res = res * 10 + (now() - '0');
            ++pos;
        }
        return res;
    }

    int64_t next_i64() {
        skip_space();
        if (now() == '-') {
            ++pos;
            return (int64_t)(-next_u64());
        } else return (int64_t)next_u64();
    }

    __uint128_t next_u128() {
        skip_space();
        __uint128_t res = 0;
        while (is_digit(now())) {
            res = res * 10 + (now() - '0');
            ++pos;
        }
        return res;
    }

    __int128_t next_i128() {
        skip_space();
        if (now() == '-') {
            ++pos;
            return (__int128_t)(-next_u128());
        } else return (__int128_t)next_u128();
    }

    char next_char() {
        skip_space();
        auto res = now();
        ++pos;
        return res;
    }

    std::string next_string() {
        skip_space();
        std::string res;
        while (true) {
            const char c = now();
            if (is_space(c) || c == '\0') break;
            res.push_back(c);
            ++pos;
        }
        return res;
    }

    Scanner &operator>>(int &x) {
        x = next_i32();
        return *this;
    }

    Scanner &operator>>(unsigned int &x) {
        x = next_u32();
        return *this;
    }

    Scanner &operator>>(long &x) {
        x = next_i64();
        return *this;
    }

    Scanner &operator>>(long long &x) {
        x = next_i64();
        return *this;
    }

    Scanner &operator>>(unsigned long &x) {
        x = next_u64();
        return *this;
    }

    Scanner &operator>>(unsigned long long &x) {
        x = next_u64();
        return *this;
    }

    Scanner &operator>>(__int128_t &x) {
        x = next_i128();
        return *this;
    }

    Scanner &operator>>(__uint128_t &x) {
        x = next_u128();
        return *this;
    }

    Scanner &operator>>(char &x) {
        x = next_char();
        return *this;
    }

    Scanner &operator>>(std::string &x) {
        x = next_string();
        return *this;
    }
};

/// Buffered writer to a C FILE*.
///
/// Output is buffered and written when the buffer fills and on destruction.
/// The Printer closes `fp` on destruction unless it is stdout, so it is
/// non-copyable.
struct Printer {
  private:
    // Digit tables shared by all Printers: helper[i] is i without leading
    // zeros, leading_zero[i] is i zero-padded to three digits.
    // `inline` (C++17) keeps this header free of out-of-class definitions.
    inline static char helper[1000][4];
    inline static char leading_zero[1000][4];
    constexpr static size_t OUTPUT_BUF = 1 << 17;
    // Per-instance buffer: a shared static buffer would interleave the output
    // of two Printers.
    char buf[OUTPUT_BUF];
    size_t pos = 0;
    FILE *fp;

    static constexpr uint32_t pow10_32(uint32_t n) { return n == 0 ? 1 : pow10_32(n - 1) * 10; }

    static constexpr uint64_t pow10_64(uint32_t n) { return n == 0 ? 1 : pow10_64(n - 1) * 10; }

    static constexpr __uint128_t pow10_128(uint32_t n) {
        return n == 0 ? 1 : pow10_128(n - 1) * 10;
    }

    /// a = b / mod; b = b % mod.
    template <class T, class U> static constexpr void div_mod(T &a, U &b, U mod) {
        a = b / mod;
        b -= a * mod;
    }

    /// Fills the digit tables (idempotent).
    static void init() {
        for (size_t i = 0; i < 1000; ++i) {
            leading_zero[i][0] = i / 100 + '0';
            leading_zero[i][1] = i / 10 % 10 + '0';
            leading_zero[i][2] = i % 10 + '0';
            leading_zero[i][3] = '\0';

            size_t j = 0;
            if (i >= 100) helper[i][j++] = i / 100 + '0';
            if (i >= 10) helper[i][j++] = i / 10 % 10 + '0';
            helper[i][j++] = i % 10 + '0';
            helper[i][j] = '\0';
        }
    }

  public:
    Printer() : fp(stdout) { init(); }

    /// Writes to `file`. If it cannot be opened, output is discarded.
    Printer(const char *file) : fp(fopen(file, "w")) { init(); }

    Printer(const Printer &) = delete;
    Printer &operator=(const Printer &) = delete;

    ~Printer() {
        write();
        if (fp != nullptr && fp != stdout) fclose(fp);
    }

    /// Writes the buffered bytes to `fp` and empties the buffer.
    void write() {
        if (fp != nullptr) fwrite(buf, 1, pos, fp);
        pos = 0;
    }

    void put_char(char c) {
        if (pos == OUTPUT_BUF) write();
        buf[pos++] = c;
    }

    void put_cstr(const char *s) {
        while (*s) put_char(*(s++));
    }

    void put_u32(uint32_t x) {
        uint32_t y;
        if (x >= pow10_32(9)) {
            div_mod(y, x, pow10_32(9));
            put_cstr(helper[y]);
            div_mod(y, x, pow10_32(6));
            put_cstr(leading_zero[y]);
            div_mod(y, x, pow10_32(3));
            put_cstr(leading_zero[y]);
            put_cstr(leading_zero[x]);
        } else if (x >= pow10_32(6)) {
            div_mod(y, x, pow10_32(6));
            put_cstr(helper[y]);
            div_mod(y, x, pow10_32(3));
            put_cstr(leading_zero[y]);
            put_cstr(leading_zero[x]);
        } else if (x >= pow10_32(3)) {
            div_mod(y, x, pow10_32(3));
            put_cstr(helper[y]);
            put_cstr(leading_zero[x]);
        } else put_cstr(helper[x]);
    }

    void put_i32(int32_t x) {
        if (x < 0) {
            put_char('-');
            // Negate in unsigned arithmetic: `-x` overflows (UB) for INT32_MIN.
            put_u32(0u - static_cast<uint32_t>(x));
        } else put_u32(static_cast<uint32_t>(x));
    }

    void put_u64(uint64_t x) {
        uint64_t y;
        if (x >= pow10_64(18)) {
            div_mod(y, x, pow10_64(18));
            put_cstr(helper[y]);
            div_mod(y, x, pow10_64(15));
            put_cstr(leading_zero[y]);
            div_mod(y, x, pow10_64(12));
            put_cstr(leading_zero[y]);
            div_mod(y, x, pow10_64(9));
            put_cstr(leading_zero[y]);
            div_mod(y, x, pow10_64(6));
            put_cstr(leading_zero[y]);
            div_mod(y, x, pow10_64(3));
            put_cstr(leading_zero[y]);
            put_cstr(leading_zero[x]);
        } else if (x >= pow10_64(9)) {
            div_mod(y, x, pow10_64(9));
            put_u32(uint32_t(y));
            div_mod(y, x, pow10_64(6));
            put_cstr(leading_zero[y]);
            div_mod(y, x, pow10_64(3));
            put_cstr(leading_zero[y]);
            put_cstr(leading_zero[x]);
        } else put_u32(uint32_t(x));
    }

    void put_i64(int64_t x) {
        if (x < 0) {
            put_char('-');
            // Negate in unsigned arithmetic: `-x` overflows (UB) for INT64_MIN.
            put_u64(uint64_t(0) - static_cast<uint64_t>(x));
        } else put_u64(static_cast<uint64_t>(x));
    }

    void put_u128(__uint128_t x) {
        __uint128_t y;
        if (x >= pow10_128(36)) {
            div_mod(y, x, pow10_128(36));
            put_cstr(helper[y]);
            div_mod(y, x, pow10_128(33));
            put_cstr(leading_zero[y]);
            div_mod(y, x, pow10_128(30));
            put_cstr(leading_zero[y]);
            div_mod(y, x, pow10_128(27));
            put_cstr(leading_zero[y]);
            div_mod(y, x, pow10_128(24));
            put_cstr(leading_zero[y]);
            div_mod(y, x, pow10_128(21));
            put_cstr(leading_zero[y]);
            div_mod(y, x, pow10_128(18));
            put_cstr(leading_zero[y]);
            div_mod(y, x, pow10_128(15));
            put_cstr(leading_zero[y]);
            div_mod(y, x, pow10_128(12));
            put_cstr(leading_zero[y]);
            div_mod(y, x, pow10_128(9));
            put_cstr(leading_zero[y]);
            div_mod(y, x, pow10_128(6));
            put_cstr(leading_zero[y]);
            div_mod(y, x, pow10_128(3));
            put_cstr(leading_zero[y]);
            put_cstr(leading_zero[x]);
        } else if (x >= pow10_128(18)) {
            div_mod(y, x, pow10_128(18));
            put_u64(uint64_t(y));
            div_mod(y, x, pow10_128(15));
            put_cstr(leading_zero[y]);
            div_mod(y, x, pow10_128(12));
            put_cstr(leading_zero[y]);
            div_mod(y, x, pow10_128(9));
            put_cstr(leading_zero[y]);
            div_mod(y, x, pow10_128(6));
            put_cstr(leading_zero[y]);
            div_mod(y, x, pow10_128(3));
            put_cstr(leading_zero[y]);
            put_cstr(leading_zero[x]);
        } else put_u64(uint64_t(x));
    }

    void put_i128(__int128_t x) {
        if (x < 0) {
            put_char('-');
            // Negate in unsigned arithmetic: `-x` overflows (UB) for the minimum value.
            put_u128(__uint128_t(0) - static_cast<__uint128_t>(x));
        } else put_u128(static_cast<__uint128_t>(x));
    }

    Printer &operator<<(int x) {
        put_i32(x);
        return *this;
    }

    Printer &operator<<(unsigned int x) {
        put_u32(x);
        return *this;
    }

    Printer &operator<<(long x) {
        put_i64(x);
        return *this;
    }

    Printer &operator<<(long long x) {
        put_i64(x);
        return *this;
    }

    Printer &operator<<(unsigned long x) {
        put_u64(x);
        return *this;
    }

    Printer &operator<<(unsigned long long x) {
        put_u64(x);
        return *this;
    }

    Printer &operator<<(__int128_t x) {
        put_i128(x);
        return *this;
    }

    Printer &operator<<(__uint128_t x) {
        put_u128(x);
        return *this;
    }

    Printer &operator<<(char x) {
        put_char(x);
        return *this;
    }

    Printer &operator<<(const std::string &x) {
        for (char c : x) put_char(c);
        return *this;
    }

    Printer &operator<<(const char *x) {
        put_cstr(x);
        return *this;
    }
};

} // namespace fastio

} // namespace kk2

// Global I/O objects `kin` / `kout`:
//  - INTERACTIVE or USE_STDIO: macros for std::cin / std::cout (untied from
//    each other, unsynchronized with C stdio). `endl` is NOT redefined here,
//    so code using `endl` under `using namespace std` gets std::endl, which flushes.
//  - KK2: fast Scanner/Printer over INPUT_FILE / OUTPUT_FILE.
//  - otherwise: fast Scanner/Printer over stdin / stdout.
// In both fast modes `endl` is a macro for '\n'.
// std::cin / std::cout / std::ios require <iostream> to be included before this point.
#if defined(INTERACTIVE) || defined(USE_STDIO)
struct IoSetUp {
    IoSetUp() {
        std::cin.tie(nullptr);
        std::ios::sync_with_stdio(false);
    }
} iosetup;
#define kin std::cin
#define kout std::cout
#else
#if defined(KK2)
kk2::fastio::Scanner kin(INPUT_FILE);
kk2::fastio::Printer kout(OUTPUT_FILE);
#else
kk2::fastio::Scanner kin;
kk2::fastio::Printer kout;
#endif
#define endl '\n'
#endif

#endif // TEMPLATE_FASTIO_HPP

#ifndef TEMPLATE
#define TEMPLATE 1

#pragma GCC optimize("O3,unroll-loops")

// #include <bits/stdc++.h>
#include <algorithm>
#include <array>
#include <bitset>
#include <cassert>
#include <chrono>
#include <cmath>
#include <cstring>
#include <deque>
#include <fstream>
#include <functional>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <limits>
#include <map>
#include <numeric>
#include <optional>
#include <queue>
#include <random>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

using u32 = unsigned int;
using i64 = long long;
using u64 = unsigned long long;
using i128 = __int128_t;
using u128 = __uint128_t;

using pi = std::pair<int, int>;
using pl = std::pair<i64, i64>;
using pil = std::pair<int, i64>;
using pli = std::pair<i64, int>;

// Shorthands for nested vectors.
template <class T> using vc = std::vector<T>;
template <class T> using vvc = std::vector<vc<T>>;
template <class T> using vvvc = std::vector<vvc<T>>;
template <class T> using vvvvc = std::vector<vvvc<T>>;

// Max-heap and min-heap priority queues.
template <class T> using pq = std::priority_queue<T>;
template <class T> using pqi = std::priority_queue<T, std::vector<T>, std::greater<T>>;

// "Infinity" sentinels. The integer values are chosen so that the sum of two
// of them still fits in the type (e.g. 2 * infty<int> < 2^31 - 1).
template <class T> constexpr T infty = 0;
template <> constexpr int infty<int> = (1 << 30) - 123;
template <> constexpr i64 infty<i64> = (1ll << 62) - (1ll << 31);
template <> constexpr i128 infty<i128> = (i128(1) << 126) - (i128(1) << 63);
template <> constexpr u32 infty<u32> = infty<int>;
template <> constexpr u64 infty<u64> = infty<i64>;
template <> constexpr u128 infty<u128> = infty<i128>;
template <> constexpr double infty<double> = infty<i64>;
template <> constexpr long double infty<long double> = infty<i64>;

constexpr int mod = 998244353;
constexpr int modu = 1'000'000'007;
// The `L` suffix matters: an unsuffixed literal is a double, which would give
// this long double constant only double precision.
constexpr long double PI = 3.14159265358979323846264338327950288L;

namespace kk2 {

/// Builds a nested std::vector of T with the given extents, every element
/// value-initialized: make_vector<T>(a, b, c) is a x b x c.
template <class T, class... Sizes> auto make_vector(int first, Sizes... sizes) {
    if constexpr (sizeof...(sizes) == 0) {
        return std::vector<T>(first);
    } else {
        // T has to be passed on explicitly: it cannot be deduced from the size
        // arguments, so a recursive call without <T> fails to compile for
        // two or more dimensions.
        auto inner = make_vector<T>(sizes...);
        return std::vector<decltype(inner)>(first, inner);
    }
}

/// Assigns T(x) to every element of a flat vector.
template <class T, class U> void fill_all(std::vector<T> &v, const U &x) {
    std::fill(std::begin(v), std::end(v), T(x));
}

/// Recursively assigns x to every innermost element of a nested vector.
template <class T, class U> void fill_all(std::vector<std::vector<T>> &v, const U &x) {
    for (auto &u : v) fill_all(u, x);
}

} // namespace kk2

/// Raises `a` to `b` when `a < b`. Returns true iff `a` was changed.
template <class T, class S> inline bool chmax(T &a, const S &b) {
    if (!(a < b)) return false;
    a = b;
    return true;
}

/// Lowers `a` to `b` when `a > b`. Returns true iff `a` was changed.
template <class T, class S> inline bool chmin(T &a, const S &b) {
    if (!(a > b)) return false;
    a = b;
    return true;
}

// Loop macros (all loop variables are i64):
//   rep(n)         i.e. `_` = 0, 1, ..., n-1   (the variable is named `_`)
//   rep(i, n)      i = 0, 1, ..., n-1
//   rep(i, a, b)   i = a, a+1, ..., b-1
// The end bound is cast to i64 and re-evaluated on every iteration.
#define rep1(a) for (i64 _ = 0; _ < (i64)(a); ++_)
#define rep2(i, a) for (i64 i = 0; i < (i64)(a); ++i)
#define rep3(i, a, b) for (i64 i = (a); i < (i64)(b); ++i)
// Descending loops:
//   repi(i, n)     i = n-1, ..., 0
//   repi(i, a, b)  i = a-1, ..., b   (b is re-evaluated on every iteration)
//   repi(n)        same as rep(n)
#define repi2(i, a) for (i64 i = (a) - 1; i >= 0; --i)
#define repi3(i, a, b) for (i64 i = (a) - 1; i >= (i64)(b); --i)
// Selects the 4th argument, so appending the three candidates after
// __VA_ARGS__ picks the macro matching 3, 2 or 1 user arguments.
#define overload3(a, b, c, d, ...) d
#define rep(...) overload3(__VA_ARGS__, rep3, rep2, rep1)(__VA_ARGS__)
#define repi(...) overload3(__VA_ARGS__, repi3, repi2, rep1)(__VA_ARGS__)

// Shorthands for pair members and whole-container iterator ranges.
#define fi first
#define se second
#define all(p) std::begin(p), std::end(p)

// #include "fastio.hpp"

/// Writes a pair as "first second".
template <class OStream, class T, class U>
OStream &operator<<(OStream &os, const std::pair<T, U> &p) {
    const auto &[lhs, rhs] = p;
    os << lhs << ' ' << rhs;
    return os;
}

/// Reads a pair as two consecutive values: first, then second.
template <class IStream, class T, class U> IStream &operator>>(IStream &is, std::pair<T, U> &p) {
    auto &[lhs, rhs] = p;
    is >> lhs >> rhs;
    return is;
}

/// Writes the elements separated by single spaces, with no trailing space.
template <class OStream, class T> OStream &operator<<(OStream &os, const std::vector<T> &v) {
    for (auto it = v.begin(); it != v.end(); ++it) {
        const bool is_last = (it + 1 == v.end());
        os << *it << (is_last ? "" : " ");
    }
    return os;
}

/// Reads v.size() values into v, in order.
template <class IStream, class T> IStream &operator>>(IStream &is, std::vector<T> &v) {
    std::for_each(v.begin(), v.end(), [&is](T &x) { is >> x; });
    return is;
}

// Answer-line helpers. Each writes one line to `kout`: the spelling in its
// name when `b` is true, the opposite spelling otherwise.

void Yes(bool b = 1) {
    if (b) kout << "Yes\n";
    else kout << "No\n";
}

void No(bool b = 1) { Yes(!b); }

void YES(bool b = 1) {
    if (b) kout << "YES\n";
    else kout << "NO\n";
}

void NO(bool b = 1) { YES(!b); }

void yes(bool b = 1) {
    if (b) kout << "yes\n";
    else kout << "no\n";
}

void no(bool b = 1) { yes(!b); }

#endif // TEMPLATE

// #include "../../template/template.hpp"

using namespace std;

int main() {
    string s;
    kin >> s;
    int m;
    kin >> m;
    vc<string> c(m);
    kin >> c;

    kk2::AhoCorasick<26, 'A'> ac;
    for (auto &x : c) ac.add(x);
    ac.build();
    auto each = ac.each_match(s);
    i64 res = accumulate(all(each), 0LL);
    kout << res << endl;

    return 0;
}

// converted!!
0