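/*
 * Submission record copied from the online-judge page (Japanese labels kept verbatim).
 * Summary: problem No.3349 "AtCoder Janken Train", C++23 (gcc 13.3.0 + boost 1.87.0),
 * verdict WA -- sample: 2 AC / 1 WA, other: 13 AC / 17 WA.
 *
 * 結果
 *
 * 問題 No.3349 AtCoder Janken Train
 * コンテスト
 * ユーザー りすりす/TwoSquirrels
 * 提出日時 2025-11-01 03:25:09
 * 言語 C++23
 * (gcc 13.3.0 + boost 1.87.0)
 * 結果
 * WA
 * 実行時間 -
 * コード長 3,780 bytes
 * コンパイル時間 2,882 ms
 * コンパイル使用メモリ 286,712 KB
 * 実行使用メモリ 13,500 KB
 * 最終ジャッジ日時 2025-11-13 21:11:00
 * 合計ジャッジ時間 4,903 ms
 * ジャッジサーバーID
 * (参考情報)
 * judge2 / judge4
 * このコードへのチャレンジ
 * (要ログイン)
 * ファイルパターン 結果
 * sample AC * 2 WA * 1
 * other AC * 13 WA * 17
 * 権限があれば一括ダウンロードができます
 *
 * ソースコード
 *
 * diff #
 */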

// Generated by GPT-5 mini

#include <bits/stdc++.h>
using namespace std;

// Prime modulus for all arithmetic in this file. 998244353 = 119 * 2^23 + 1,
// so NTT lengths up to 2^23 are supported (with 3 as the primitive root).
const int MOD = 998244353;
using ll = long long;

// Computes a^e mod MOD by binary exponentiation.
//
// a : base; any value (negative or >= MOD included) is first reduced into
//     [0, MOD), which keeps every product below 2^63.
// e : exponent, expected >= 0 (a negative exponent is treated as 0 and yields
//     1 rather than looping forever). Defaults to MOD-2, so modpow(a) is the
//     modular inverse of a by Fermat's little theorem when a % MOD != 0.
// Returns a value in [0, MOD).
ll modpow(ll a, ll e=MOD-2) {
    a %= MOD;
    if (a < 0) a += MOD;  // C++ '%' keeps the sign of the dividend
    ll r = 1;
    while (e > 0) {
        if (e & 1) r = r * a % MOD;
        a = a * a % MOD;
        e >>= 1;
    }
    return r;
}

// Iterative in-place number-theoretic transform modulo MOD, with 3 as the
// primitive root.
//
// a      : coefficients, each assumed already reduced into [0, MOD).
//          a.size() must be a power of two and at most 2^23 (the largest power
//          of two dividing MOD - 1).
// invert : false = forward transform; true = inverse transform, including
//          the 1/n scaling.
//
// The bit-reversal table and twiddle table are function-local statics that
// are rebuilt or extended in place without locking, so concurrent calls are
// not safe.
void ntt(vector<int>& a, bool invert) {
    int n = a.size();
    // Lengths 0 and 1 are the identity in both directions. Returning early also
    // avoids __builtin_ctz(0) and the shift by (k - 1) == -1 in the
    // bit-reversal loop below, both of which are undefined behavior.
    if (n <= 1) return;

    static vector<int> rev;
    // roots[len + j] = w^j, where w is a primitive (2*len)-th root of unity
    // and 0 <= j < len; entry 0 is unused.
    static vector<int> roots{0,1};
    if ((int)rev.size() != n) {
        int k = __builtin_ctz(n);
        rev.assign(n,0);
        for (int i = 0; i < n; ++i)
            rev[i] = (rev[i>>1]>>1) | ((i&1) << (k-1));
    }
    // Permute into bit-reversed order.
    for (int i = 0; i < n; ++i)
        if (i < rev[i])
            swap(a[i], a[rev[i]]);

    // Extend the twiddle table up to length n, one doubling per iteration.
    if ((int)roots.size() < n) {
        int k = __builtin_ctz(roots.size());
        roots.resize(n);
        while ((1<<k) < n) {
            // z is 2^(k+1)-th primitive root
            ll z = modpow(3, (MOD-1) >> (k+1));
            for (int i = 1<<(k-1); i < (1<<k); ++i)
                roots[2*i] = roots[i];
            for (int i = 1<<(k-1); i < (1<<k); ++i)
                roots[2*i+1] = (ll)roots[2*i] * z % MOD;
            ++k;
        }
    }

    // Radix-2 butterflies. u and v are both in [0, MOD) with MOD < 2^30, so
    // u + v stays below 2^31 and fits in an int.
    for (int len = 1; len < n; len <<= 1) {
        for (int i = 0; i < n; i += 2*len) {
            for (int j = 0; j < len; ++j) {
                int u = a[i+j];
                int v = (ll)a[i+j+len] * roots[len+j] % MOD;
                a[i+j] = u+v < MOD ? u+v : u+v-MOD;
                a[i+j+len] = u-v >= 0 ? u-v : u-v+MOD;
            }
        }
    }

    if (invert) {
        // The inverse DFT equals the forward DFT with output indices 1..n-1
        // reversed, then scaled by 1/n.
        reverse(a.begin()+1, a.end());
        ll inv_n = modpow(n);
        for (int i = 0; i < n; ++i)
            a[i] = (ll)a[i] * inv_n % MOD;
    }
}

vector<int> convolution(const vector<int>& a, const vector<int>& b) {
    if (a.empty() || b.empty()) return {};
    int n = a.size(), m = b.size();
    int need = n + m - 1;
    int sz = 1;
    while (sz < need) sz <<= 1;
    vector<int> fa(sz), fb(sz);
    for (int i = 0; i < n; ++i) fa[i] = a[i];
    for (int i = 0; i < m; ++i) fb[i] = b[i];
    ntt(fa, false);
    ntt(fb, false);
    for (int i = 0; i < sz; ++i) fa[i] = (ll)fa[i] * fb[i] % MOD;
    ntt(fa, true);
    fa.resize(need);
    return fa;
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int N;
    long long M;
    if (!(cin >> N >> M)) return 0;
    int L = 1<<N;
    if (M < 0 || M > L) {
        cout << 0 << "\n";
        return 0;
    }

    // g[k][w] for k from 0..N, but we keep only previous level
    vector<int> g_prev(2); // k=0, sizes upto 1
    g_prev[0] = 1; // 0 W -> "C"
    g_prev[1] = 1; // 1 W -> "W"

    for (int k = 1; k <= N; ++k) {
        // convolution h = g_prev * g_prev
        vector<int> h = convolution(g_prev, g_prev); // size up to 2*2^{k-1}+1
        int maxw = (1<<k);
        vector<int> g_cur(maxw+1);
        // g_cur[0] = h[0] (since allow(0,0)=1)
        g_cur[0] = h.size() > 0 ? h[0] % MOD : 0;
        // for w>0: g_cur[w] = h[w] - g_prev[0] * g_prev[w] (g_prev[w] zero if out of range)
        ll g0 = g_prev.size() > 0 ? g_prev[0] : 0;
        for (int w = 1; w <= maxw; ++w) {
            ll val = 0;
            if (w < (int)h.size()) val = h[w];
            ll sub = 0;
            if (w < (int)g_prev.size()) sub = (ll)g0 * g_prev[w] % MOD;
            val = (val - sub) % MOD;
            if (val < 0) val += MOD;
            g_cur[w] = (int)val;
        }
        g_prev.swap(g_cur);
    }

    // g_prev is g[N]
    int patterns = g_prev[(int)M];

    // multiply by M! * (L-M)! mod
    vector<ll> fact(L+1);
    fact[0] = 1;
    for (int i = 1; i <= L; ++i) fact[i] = fact[i-1] * i % MOD;
    ll ans = patterns;
    ans = ans * fact[M] % MOD;
    ans = ans * fact[L - M] % MOD;
    cout << ans % MOD << "\n";
    return 0;
}

// 0  -- stray trailing text after the program (likely page residue); not part of the source.