結果

問題 No.2990 Interval XOR
ユーザー Vladimir Novikov
提出日時 2025-08-16 23:18:30
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 1,165 ms / 2,000 ms
コード長 9,269 bytes
コンパイル時間 5,334 ms
コンパイル使用メモリ 212,356 KB
実行使用メモリ 57,884 KB
最終ジャッジ日時 2025-08-16 23:18:55
合計ジャッジ時間 16,438 ms
ジャッジサーバーID
(参考情報)
judge1 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 37
権限があれば一括ダウンロードができます
コンパイルメッセージ
main.cpp: In function ‘std::vector<static_modular_int<998244353> > forward(std::vector<std::pair<int, int> >, int)’:
main.cpp:246:25: warning: range-based ‘for’ loops with initializer only available with ‘-std=c++20’ or ‘-std=gnu++20’ [-Wc++20-extensions]
  246 |         for (int i = 0; auto c : segs) {
      |                         ^~~~
main.cpp: In function ‘void solve()’:
main.cpp:329:21: warning: range-based ‘for’ loops with initializer only available with ‘-std=c++20’ or ‘-std=gnu++20’ [-Wc++20-extensions]
  329 |     for (int i = 0; auto c : res) {
      |                     ^~~~

ソースコード

diff #

#include <bits/stdc++.h>
using namespace std;
using ll = long long;

#define all(a) begin(a), end(a)
#define rall(a) rbegin(a), rend(a)
#define len(a) (int)((a).size())

/*
 * Modular integer with a compile-time modulus.
 *
 * ! WARNING: MOD must be prime if you use division or .inv().
 * ! WARNING: 2 * (MOD - 1) must not exceed INT_MAX (enforced by a static_assert).
 * Use .value to get the stored value; it is always kept in [0, mod).
 */

// Reduces an integral `value` into [0, mod).
// Values already in [-mod, 2 * mod) are fixed with a single add/subtract;
// only values outside that window pay for the `%`.
template<typename T>
int normalize(T value, int mod) {
    if (value < -mod || value >= 2 * mod) value %= mod;
    if (value < 0) value += mod;
    if (value >= mod) value -= mod;
    return value;
}

template<int mod>
struct static_modular_int {
    static_assert(mod - 2 <= std::numeric_limits<int>::max() - mod, "2(mod - 1) <= INT_MAX");
    using mint = static_modular_int<mod>;

    // Residue in [0, mod). Public so callers can read it directly.
    int value;

    // Copy construction / assignment are the implicit ones (Rule of Zero).
    static_modular_int() : value(0) {}

    // Implicit conversion from any integral type, reduced modulo `mod`.
    template<typename T, typename U = std::enable_if_t<std::is_integral<T>::value>>
    static_modular_int(T value) : value(normalize(value, mod)) {}

    static constexpr int get_mod() {
        return mod;
    }

    // Binary exponentiation: (*this)^degree. Returns 1 when degree <= 0.
    template<typename T>
    mint power(T degree) const {
        mint prod = 1, a = *this;
        for (; degree > 0; degree >>= 1, a *= a)
            if (degree & 1)
                prod *= a;

        return prod;
    }

    // Multiplicative inverse via Fermat's little theorem (requires prime mod).
    // Yields 0 when value == 0.
    mint inv() const {
        return power(mod - 2);
    }

    mint& operator+=(const mint &x) {
        value += x.value;
        if (value >= mod) value -= mod;
        return *this;
    }

    mint& operator-=(const mint &x) {
        value -= x.value;
        if (value < 0) value += mod;
        return *this;
    }

    mint& operator*=(const mint &x) {
        // Widen before multiplying: the product of two residues overflows int.
        value = int64_t(value) * x.value % mod;
        return *this;
    }

    mint& operator/=(const mint &x) {
        return *this *= x.inv();
    }

    friend mint operator+(const mint &x, const mint &y) {
        return mint(x) += y;
    }

    friend mint operator-(const mint &x, const mint &y) {
        return mint(x) -= y;
    }

    friend mint operator*(const mint &x, const mint &y) {
        return mint(x) *= y;
    }

    friend mint operator/(const mint &x, const mint &y) {
        return mint(x) /= y;
    }

    mint& operator++() {
        ++value;
        if (value == mod) value = 0;
        return *this;
    }

    mint& operator--() {
        --value;
        if (value == -1) value = mod - 1;
        return *this;
    }

    mint operator++(int) {
        mint prev = *this;
        value++;
        if (value == mod) value = 0;
        return prev;
    }

    mint operator--(int) {
        mint prev = *this;
        value--;
        if (value == -1) value = mod - 1;
        return prev;
    }

    mint operator-() const {
        return mint(0) - *this;
    }

    // Non-member (friend) comparisons so an integral left operand converts too,
    // e.g. `5 == x`, just like the arithmetic operators above.
    friend bool operator==(const mint &x, const mint &y) {
        return x.value == y.value;
    }

    friend bool operator!=(const mint &x, const mint &y) {
        return x.value != y.value;
    }

    // Orders by stored residue (useful for sorting / ordered containers).
    friend bool operator<(const mint &x, const mint &y) {
        return x.value < y.value;
    }

    // Explicit conversion to another type, yielding the residue in [0, mod).
    // const-qualified so const mints and const references can be converted.
    template<typename T>
    explicit operator T() const {
        return value;
    }

    // Reads one whitespace-delimited token of decimal digits (optionally with a
    // leading '-') and reduces it digit by digit, so arbitrarily long numbers work.
    friend std::istream& operator>>(std::istream &in, mint &x) {
        std::string s;
        in >> s;
        x = 0;
        bool neg = s[0] == '-';
        for (const auto c : s)
            if (c != '-')
                x = x * 10 + (c - '0');

        if (neg)
            x *= -1;

        return in;
    }

    friend std::ostream& operator<<(std::ostream &out, const mint &x) {
        return out << x.value;
    }

    // Primitive root modulo `mod` (assumes prime mod). Hard-coded for three
    // common primes; otherwise finds the smallest r >= 2 with
    // r^((mod-1)/p) != 1 for every prime factor p of mod-1, cached in a
    // function-local static.
    static int primitive_root() {
        if constexpr (mod == 1'000'000'007)
            return 5;
        if constexpr (mod == 998'244'353)
            return 3;
        if constexpr (mod == 786433)
            return 10;

        static int root = -1;
        if (root != -1)
            return root;

        // Distinct prime factors of mod - 1 by trial division.
        std::vector<int> primes;
        int value = mod - 1;
        for (int i = 2; i * i <= value; i++)
            if (value % i == 0) {
                primes.push_back(i);
                while (value % i == 0)
                    value /= i;
            }

        if (value != 1)
            primes.push_back(value);

        for (int r = 2;; r++) {
            bool ok = true;
            for (auto p : primes)
                if ((mint(r).power((mod - 1) / p)).value == 1) {
                    ok = false;
                    break;
                }

            if (ok)
                return root = r;
        }
    }
};

// Active modulus for all arithmetic below; the alternative used by other
// problems is 1'000'000'007.
constexpr int MOD = 998244353;
using mint = static_modular_int<MOD>;

// XOR Walsh–Hadamard transform without the 1/n normalization: for a vector
// whose length is a power of two, entry k of the result equals
// sum_j (-1)^popcount(j & k) * vec[j]. Works on (and returns) a copy.
// Applying it twice multiplies every entry by len(vec).
vector<mint> adamar(vector<mint> vec) {
    const ll n = len(vec);
    for (ll half = 1; half < n; half *= 2) {
        // Butterfly between lo (bit `half` clear) and hi = lo + half (bit set).
        for (ll block = 0; block < n; block += 2 * half) {
            for (ll lo = block; lo < block + half; ++lo) {
                const ll hi = lo + half;
                if (hi >= n) break;
                const mint a = vec[lo];
                const mint b = vec[hi];
                vec[lo] = a + b;
                vec[hi] = a - b;
            }
        }
    }
    return vec;
}

// Builds the pointwise product, over all segments, of the segments' XOR
// (Walsh–Hadamard) spectra:
//
//     vec[k] = prod_i  sum_{a = l_i}^{r_i} (-1)^popcount(a & k),   0 <= k < 2^m,
//
// so that adamar(vec) / 2^m is presumably the distribution of XORs of one value
// chosen from each segment. NOTE(review): this reading is reconstructed from the
// code and from how solve() post-processes the result; confirm against the problem.
//
// Entries k != 0 are grouped by their lowest set bit `bit`. Each outer iteration
// writes every index that is a multiple of `bit`, so the last write to k comes
// from the iteration whose `bit` is k's lowest set bit. vec[0] is set at the end
// as the product of the segment lengths.
//
// segs: inclusive ranges [first, second]; presumably 0 <= first <= second < 2^m
//       (not validated here; TODO confirm).
// m:    bit width; the result has 2^m entries.
vector<mint> forward(const vector<pair<int, int>>& segs, int m) {
    // Up to two pieces per segment (slots 0/1 and 2/3). If a piece straddles the
    // `bit` boundary, its even slot holds the lower half and its odd slot the
    // upper half; otherwise only the even slot is used.
    // p = start with bits below `bit` cleared, c = element count.
    struct tup {
        array<ll, 4> p, c;

        tup() { p = {0, 0, 0, 0}; c = {0, 0, 0, 0}; }
    };
    // Per segment: two (start, signed count) terms, one per piece.
    struct pr {
        array<ll, 2> p, c;

        pr() { p = {0, 0}; c = {0, 0}; }
    };
    vector<mint> vec(1 << m);
    for (ll bit = 1; bit < len(vec); bit *= 2) {
        ll mask = (bit * 2 - 1), high = bit;
        // Records [l, r] in slots ind / ind + 1 of `p`. Callers only pass ranges
        // that lie inside one aligned block of size 2 * high. So for l <= r that
        // differ in the `high` bit, l is in the lower half and `mid` is the first
        // value of the upper half.
        auto push = [&](ll l, ll r, ll ind, tup& p) {
            if ((l & high) == (r & high)) {
                p.p[ind] = ((l | (high - 1)) ^ (high - 1));
                p.c[ind] = (r - l + 1);
                return;
            }
            ll mid = ((r | (high - 1)) ^ (high - 1));
            p.p[ind + 1] = mid;
            p.c[ind + 1] = r - mid + 1;

            p.p[ind] = ((l | (high - 1)) ^ (high - 1));
            p.c[ind] = mid - l;
        };

        vector<tup> tups(len(segs));

        // Keep only the partial aligned blocks of size 2 * bit at the ends of each
        // segment. A full aligned block in between pairs every a (bit clear) with
        // a | bit. Their signs (-1)^popcount(a & k) are opposite whenever k has
        // `bit` set and no lower bits, so full blocks contribute nothing.
        // (Plain index loop: the range-for-with-initializer form is C++20-only.)
        for (int i = 0; i < len(segs); ++i) {
            const auto& c = segs[i];
            if (c.second - (c.second & mask) == c.first - (c.first & mask)) {
                push(c.first, c.second, 0, tups[i]);
            } else {
                if ((c.first & mask) != 0) push(c.first, (c.first | mask), 0, tups[i]);
                if (((c.second + 1) & mask) != 0) push((c.second | mask) ^ mask, c.second, 2, tups[i]);
            }
        }

        // The two halves of a piece differ only in `bit`. So for k whose lowest set
        // bit is `bit`, they collapse into one signed count,
        // (count with bit clear) - (count with bit set), at the piece start.
        vector<pr> pairs(len(segs));
        for (int i = 0; i < len(segs); ++i) {
            pairs[i].p[0] = tups[i].p[0];
            pairs[i].c[0] = tups[i].c[0] - tups[i].c[1];
            pairs[i].p[1] = tups[i].p[2];
            pairs[i].c[1] = tups[i].c[2] - tups[i].c[3];
        }

        // For every k that is a multiple of bt, computes
        //   prod_i (c0_i * s(p0_i, k) + c1_i * s(p1_i, k)),   s(x, k) = (-1)^popcount(x & k)
        //   = prod_i s(p0_i, k) * (s(p0_i ^ p1_i, k) == 1 ? c0_i + c1_i : c0_i - c1_i).
        // sum / diff bucket the (c0 + c1) / (c0 - c1) factors by d = p0 ^ p1, and
        // par counts segments by p0. The butterfly then turns slot [0] of each
        // bucket into, for every k, slot [p] = product (sum, diff) or total (par)
        // over the buckets d with popcount(d & k) % 2 == p. Only indices that are
        // multiples of bt take part; every p value is a multiple of bt.
        auto do_calc = [](const vector<pr>& pairs, ll bt, ll m) {
            vector<mint> res(1 << m);
            vector<array<mint, 2>> sum(1 << m, {1, 1}), diff(1 << m, {1, 1}), par(1 << m, {0, 0});
            for (const auto& c : pairs) {
                sum[c.p[0] ^ c.p[1]][0] *= (c.c[0] + c.c[1]);
                diff[c.p[0] ^ c.p[1]][0] *= (c.c[0] - c.c[1]);
                par[c.p[0]][0] += 1;
            }
            for (ll bit = 1; bit < len(res); bit *= 2) {
                for (int i = 0; i < len(res); i += bt) {
                    if (i & bit) {
                        array<mint, 2> old_sum1 = sum[i], old_sum0 = sum[i - bit], old_diff1 = diff[i], old_diff0 = diff[i - bit], old_par1 = par[i], old_par0 = par[i - bit];
                        sum[i][0] = old_sum1[1] * old_sum0[0];
                        sum[i][1] = old_sum1[0] * old_sum0[1];
                        sum[i - bit][0] = old_sum1[0] * old_sum0[0];
                        sum[i - bit][1] = old_sum1[1] * old_sum0[1];

                        diff[i][0] = old_diff1[1] * old_diff0[0];
                        diff[i][1] = old_diff1[0] * old_diff0[1];
                        diff[i - bit][0] = old_diff1[0] * old_diff0[0];
                        diff[i - bit][1] = old_diff1[1] * old_diff0[1];

                        par[i][0] = old_par1[1] + old_par0[0];
                        par[i][1] = old_par1[0] + old_par0[1];
                        par[i - bit][0] = old_par1[0] + old_par0[0];
                        par[i - bit][1] = old_par1[1] + old_par0[1];
                    }
                }
            }

            // Buckets with sign +1 contribute (c0 + c1) and buckets with sign -1
            // contribute (c0 - c1). The global sign prod_i s(p0_i, k) is -1 iff an
            // odd number of segments have popcount(p0_i & k) odd.
            for (int i = 0; i < len(res); i += bt) {
                res[i] = sum[i][0] * diff[i][1];
                if (par[i][1].value % 2) res[i] *= mint(-1);
            }

            return res;
        };

        vector<mint> res = do_calc(pairs, bit, m);
        for (int i = 0; i < len(vec); i += bit) {
            vec[i] = res[i];
        }
    }

    // k = 0: every sign is +1, so each factor is just the segment length.
    vec[0] = 1;
    for (const auto& c : segs) vec[0] *= (c.second - c.first + 1);

    return vec;
}

void solve() {
    int n, m;
    cin >> n >> m;
    swap(n, m);
    vector<pair<int, int>> segs(n);
    for (auto &c : segs) cin >> c.first >> c.second;

    vector<mint> to = forward(segs, m);
    vector<mint> res = adamar(to);
    for (auto& c : res) c /= len(res);

    // for (auto c : res) cerr << c << " ";
    // cerr << endl;

    ll ans = 0;
    for (int i = 0; auto c : res) {
        cout << c << "\n";
        ans ^= ll(mint(2).power(i) * c);
        ++i;
    }

    //cout << ans << endl;
}

int main() {
    // All I/O goes through cin/cout, so decouple them from C stdio and untie cin.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int test_count = 1;
    // cin >> test_count;  // enable for multi-test input
    for (int test = 0; test < test_count; ++test) {
        solve();
    }

    return 0;
}
0