結果
| 問題 | No.2219 Re:010 | 
| コンテスト | |
| ユーザー |  | 
| 提出日時 | 2023-02-18 14:51:02 | 
| 言語 | C++14 (gcc 13.3.0 + boost 1.87.0) | 
| 結果 | AC | 
| 実行時間 | 21 ms / 2,000 ms | 
| コード長 | 3,007 bytes | 
| コンパイル時間 | 1,590 ms | 
| コンパイル使用メモリ | 169,896 KB | 
| 実行使用メモリ | 16,128 KB | 
| 最終ジャッジ日時 | 2024-07-20 01:22:08 | 
| 合計ジャッジ時間 | 2,928 ms | 
| ジャッジサーバーID (参考情報) | judge3 / judge4 | 
(要ログイン)
| ファイルパターン | 結果 | 
|---|---|
| sample | AC * 3 | 
| other | AC * 21 | 
ソースコード
// Problem: No.2219 Re:010
// Contest: yukicoder
// URL: https://yukicoder.me/problems/no/2219
// Memory Limit: 512 MB
// Time Limit: 2000 ms
#include <bits/stdc++.h>
// Competitive-programming template shorthands.
// Fast I/O: turns off C/C++ stream synchronization and unties cin (and cout).
// Expands to a complete statement (the replacement already ends in ';').
#define fastio ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
// Debug print "expr = value"; unused in this file. It writes to cout, so it
// would corrupt the judged output if left in a submission.
#define dbg(x) cout << #x << " = " << (x) << "\n";
// 64-bit population count via the GCC/Clang builtin; unused in this file.
#define popcount(x) __builtin_popcountll((x))
// Begin/end iterator pair for passing a whole container; unused in this file.
#define all(v) (v).begin(), (v).end()
// Short alias for emplace_back; unused in this file.
#define pb emplace_back
// Pair-member shorthands (p.x / p.y). NOTE(review): these rewrite EVERY later
// token spelled `x` or `y`, not only member accesses; for example the
// parameter `x` of COM::ksm is silently renamed to `first`. That is harmless
// here, but it is a trap for any code that uses x/y as ordinary names.
#define x first
#define y second
using namespace std;
typedef long long ll;
typedef pair<ll, ll> pll;  // unused in this file
// Capacity of every global array. The string is stored 1-indexed and the
// suffix tables read index n + 1, so the input length must be at most N - 2
// (presumably the problem guarantees |S| <= 2*10^5; confirm).
const int N = 2e5 + 7;
// Conventional "infinity" sentinel; unused in this file.
const int inf = 0x3f3f3f3f;
// Prime modulus for every answer; COM::inv relies on it being prime.
const int mod = 998244353;
namespace COM {
    // Modular exponentiation: returns x^n mod `mod`, always in [0, mod).
    // A negative base is normalized first, because C++ `%` keeps the sign of
    // the dividend.
    // A negative exponent is treated as a power of the modular inverse. This
    // relies on `mod` being prime (Fermat: x^(mod-1) == 1 for x != 0). The
    // result is meaningless for x == 0 (mod `mod`), where 0 is returned.
    ll ksm(ll x, ll n) {
        x %= mod;
        if (x < 0) x += mod;
        if (n < 0) {
            // Move the exponent into [1, mod - 1] without changing the result.
            // Without this, `n >>= 1` on a negative value sticks at -1 and the
            // loop below never terminates.
            n = n % (mod - 1) + (mod - 1);
        }
        ll res = 1;
        while (n) {
            if (n & 1) res = (res * x) % mod;
            x = x * x % mod;
            n >>= 1;
        }
        return res;
    }
    // Modular inverse via Fermat's little theorem.
    // Requires `mod` prime and x != 0 (mod `mod`).
    ll inv(ll x) { return ksm(x, mod - 2); }
    // After init(): fac[i] = i! and invfac[i] = (i!)^-1, both mod `mod`,
    // for 0 <= i < N.
    ll fac[N], invfac[N];
    // Fills fac upward, then invfac downward from a single modular inverse,
    // using (i!)^-1 = ((i+1)!)^-1 * (i+1).
    void init() {
        fac[0] = 1;
        for (int i = 1; i < N; ++i) fac[i] = (fac[i - 1] * i) % mod;
        invfac[N - 1] = inv(fac[N - 1]);
        for (int i = N - 2; i >= 0; --i) invfac[i] = (invfac[i + 1] * (i + 1)) % mod;
    }
    // Binomial coefficient C(n, m) mod `mod`. Returns 0 when m < 0 or m > n.
    // Requires init() to have run and n < N (the factorial table size).
    ll C(int n, int m) {
        if (n < m || m < 0) return 0;
        assert(n < N);  // beyond this fac/invfac would be read out of bounds
        return (fac[n] * invfac[m] % mod) * invfac[n - m] % mod;
    }
}  // namespace COM
// Input string, stored 1-indexed in str[1..n] (str[0] is never used).
char str[N];
// l*[i]: counts of '0' / '?' / '1' in the prefix str[1..i].
// r*[i]: counts of '0' / '?' / '1' in the suffix str[i..n].
// Index 0 of l* and index n + 1 of r* act as zero sentinels.
// l1/r1 are filled by solve() but never read.
ll l0[N], lq[N], r0[N], rq[N], l1[N], r1[N];
// Length of the current input string.
int n;
void solve() {
    COM::init();
    cin >> str + 1;
    n = strlen(str + 1);
    for (int i = 1; i <= n; i++) {
        l1[i] = l1[i - 1] + (str[i] == '1');
        l0[i] = l0[i - 1] + (str[i] == '0');
        lq[i] = lq[i - 1] + (str[i] == '?');
    }
    for (int i = n; i >= 1; i--) {
        r1[i] = r1[i + 1] + (str[i] == '1');
        r0[i] = r0[i + 1] + (str[i] == '0');
        rq[i] = rq[i + 1] + (str[i] == '?');
    }
    ll ans = 0, pre = 0, q1 = 0;
    for (int i = 1; i <= n; i++) {
        if (str[i] == '?') q1 = (q1 + l0[i - 1] * r0[i + 1] % mod + pre) % mod;
        if (str[i] == '1') pre += l0[i], pre %= mod;
    }
    pre = 0;
    for (int i = n; i >= 1; i--) {
        if (str[i] == '?') q1 = (q1 + pre) % mod;
        if (str[i] == '1') pre += r0[i], pre %= mod;
    }
    if (lq[n] >= 1) ans = q1 * COM::ksm(2, lq[n] - 1) % mod;
    ll q0 = 0;
    for (int i = n; i >= 1; i--) {
        if (str[i] == '1') q0 = (q0 + (l0[i] * r0[i]) % mod) % mod;
    }
    ans = (ans + q0 * COM::ksm(2, lq[n]) % mod) % mod;
    ans %= mod;
    ll q2 = 0;
    for (int i = 1; i <= n; i++) {
        if (str[i] == '1') {
            q2 = (q2 + lq[i] * rq[i] % mod) % mod;
        }
        if (str[i] == '0') {
            q2 = (q2 + COM::C(lq[i], 2)) % mod;
            q2 = (q2 + COM::C(rq[i], 2)) % mod;
        }
    }
    if (lq[n] >= 2) {
        ans = ans + q2 * COM::ksm(2ll, lq[n] - 2) % mod;
    }
    ans %= mod;
    if (lq[n] >= 3) ans = ans + COM::C(lq[n], 3) * COM::ksm(2, lq[n] - 3) % mod;
    ans %= mod;
    cout << ans << "\n";
}
int main() {
    fastio;
    int t = 1;
    // cin >> t;
    while (t--) {
        solve();
    }
    return 0;
}
            
            
            
        