結果
| 問題 | No.2406 Difference of Coordinate Squared | 
| コンテスト | |
| ユーザー |  | 
| 提出日時 | 2023-08-05 01:54:20 | 
| 言語 | C++17 (gcc 13.3.0 + boost 1.87.0) | 
| 結果 | 
                                AC
                                 
                             | 
| 実行時間 | 452 ms / 2,000 ms | 
| コード長 | 20,501 bytes | 
| コンパイル時間 | 1,868 ms | 
| コンパイル使用メモリ | 161,544 KB | 
| 最終ジャッジ日時 | 2025-02-15 23:24:40 | 
| ジャッジサーバーID (参考情報) | judge1 / judge2 | 
(要ログイン)
| ファイルパターン | 結果 | 
|---|---|
| sample | AC * 2 | 
| other | AC * 55 | 
コンパイルメッセージ
main.cpp:106: warning: "debug" redefined
  106 | #define debug(...) void(0)
      | 
main.cpp:101: note: this is the location of the previous definition
  101 | #define debug(x) void(0)
      | 
            
            ソースコード
//#pragma GCC target ("avx")
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
#ifndef ONLINE_JUDGE
#define _GLIBCXX_DEBUG
#endif
#include <iostream>
#include <vector>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <set>
#include <map>
#include <iomanip>
#include <string>
#include <bitset>
#include <functional>
#include <list>
#include <deque>
#include <utility>
#include <numeric>
#include <complex>
#include <cctype>
#include <climits>
#include <cassert>
#include <unordered_map>
#include <unordered_set>
#include <ctime>
#include <random>
#include <fstream>
#include <chrono>
//#include <regex>
//#include <cstdio>
//#include <atcoder/all>
//using namespace atcoder;
using namespace std;
//#define int long long
// ---- scalar aliases ----
using ll = long long;
using ld = long double; // long double can sometimes paper over floating-point error
// ---- long long containers ----
using vll = vector<long long>;
using vvll = vector<vector<long long>>;
using vvvll = vector<vector<vector<long long>>>;
// ---- int containers ----
using vi = vector<int>;
using vvi = vector<vector<int>>;
using vvvi = vector<vector<vector<int>>>;
// ---- floating-point containers ----
using vld = vector<long double>;
using vd = vector<double>;
using vvd = vector<vector<double>>;
// ---- char / string / bool containers ----
using vc = vector<char>;
using vvc = vector<vector<char>>;
using vs = vector<string>;
using vb = vector<bool>;
using vvb = vector<vector<bool>>;
// ---- pairs and vectors of pairs ----
using pii = pair<int, int>;
using pcc = pair<char, char>;
using pll = pair<long long, long long>;
using pdd = pair<double, double>;
using pldld = pair<long double, long double>;
using vpii = vector<pair<int, int>>;
using vpll = vector<pair<long long, long long>>;
 
// chmax: if b is larger than a, store b into a. Returns whether a changed.
template<class T> bool chmax(T& a, const T& b) {
    const bool improved = a < b;
    if (improved) a = b;
    return improved;
}
// chmin: if b is smaller than a, store b into a. Returns whether a changed.
template<class T> bool chmin(T& a, const T& b) {
    const bool improved = b < a;
    if (improved) a = b;
    return improved;
}
 
// ---- loop helpers (the loop variable is always long long) ----
// Every argument is parenthesized so expressions such as `c ? a : b` or
// `x & 1` expand correctly.
#define rep(i, n) for (long long i = 0; i < (long long)(n); i++)
#define repback(i, n) for (long long i = (n)-1; i >= 0; i--)
#define REP(i, a, b) for (long long i = (a); i < (long long)(b); i++)
#define REPBACK(i, a, b) for (long long i = (a)-1; i >= (long long)(b); i--)
#define all(x) (x).begin(), (x).end()
#define UNIQUE(A) (A).erase(unique(all(A)), (A).end()) // sort A first
// True when (y, x) lies inside an H x W grid.
#define include(y, x, H, W) (0 <= (x) && (x) < (W) && 0 <= (y) && (y) < (H))
// Fully parenthesized: the old `(x) * (x)` made `1.0 / square(2.0)` equal 1.
#define square(x) ((x) * (x))
#define pb push_back
#define eb emplace_back
#define EPS (1e-10)
#define equals(a,b) (fabs((a) - (b)) < EPS)
#ifndef ONLINE_JUDGE
// Local build: debug(...) dispatches on argument count (1 to 5) to debugN.
#define debug1(x) cout << "debug:" << (x) << endl
#define debug2(x, y) cout << "debug:" << (x) << " " << (y) << endl
#define debug3(x, y, z) cout << "debug:" << (x) << " " << (y) << " " << (z) << endl
#define debug4(x, y, z, w) cout << "debug:" << (x) << " " << (y) << " " << (z) << " " << (w) << endl
#define debug5(x, y, z, w, v) cout << "debug:" << (x) << " " << (y) << " " << (z) << " " << (w) << " " << (v) << endl
#define overload5(a, b, c, d, e, f, ...) f
#define debug(...) overload5(__VA_ARGS__, debug5, debug4, debug3, debug2, debug1)(__VA_ARGS__)
#define debug2C(x, y) cout << "debug:" << (x) << " : " << (y) << endl
#define debug3P(x, y, z) cout << "debug:" << (x) << ", " << (y) << " : " << (z) << endl
#define debug3C(x, y, z) cout << "debug:" << (x) << " : " << (y) << ", " << (z) << endl
#define debuga cerr << "a" << endl
#define TIMER_START TIME_START = clock()
#define TIMER_END TIME_END = clock()
#define TIMECHECK cerr << 1000.0 * static_cast<double>(clock() - TIME_START) / CLOCKS_PER_SEC << "ms" << endl
#else
// Judge build: every debug/timing helper compiles to a no-op.
// debug is defined exactly once (the old duplicate `debug(x)` triggered a
// redefinition warning), and debug1 now exists here too, matching the local branch.
#define debug1(x) void(0)
#define debug2(x, y) void(0)
#define debug3(x, y, z) void(0)
#define debug4(x, y, z, w) void(0)
#define debug5(x, y, z, w, v) void(0)
#define debug(...) void(0)
#define debug2C(x, y) void(0)
#define debug3P(x, y, z) void(0)
#define debug3C(x, y, z) void(0)
#define debuga void(0)
#define TIMER_START void(0)
#define TIMER_END void(0)
#define TIMECHECK void(0)
#endif
#define YESNO(bool) if(bool){cout<<"YES"<<'\n';}else{cout<<"NO"<<'\n';}
#define yesno(bool) if(bool){cout<<"yes"<<'\n';}else{cout<<"no"<<'\n';}
#define YesNo(bool) if(bool){cout<<"Yes"<<'\n';}else{cout<<"No"<<'\n';}
#define POSIMPOS(bool) if(bool){cout<<"POSSIBLE"<<'\n';}else{cout<<"IMPOSSIBLE"<<'\n';}
#define PosImpos(bool) if(bool){cout<<"Possible"<<'\n';}else{cout<<"Impossible"<<'\n';}
#define posimpos(bool) if(bool){cout<<"possible"<<'\n';}else{cout<<"impossible"<<'\n';}
#define popcount __builtin_popcountll // the ll builtin handles 64-bit values
// 64-bit Mersenne Twister seeded from the steady clock (the seed differs per run).
mt19937_64 rng(chrono::steady_clock::now().time_since_epoch().count());
// Start/end timestamps written by the TIMER_START / TIMER_END macros.
clock_t TIME_START, TIME_END;
static const double pi = acos(-1.0);
// Integer limits are written as exact integer constants. The previous
// initializers went through pow(), a double computed at run time. The values
// are unchanged, because 10^18 and 10^9 happen to be exact in double.
const long long INFL = 1'000'000'000'000'000'000LL;  // 10^18
const long long INFLMAX = LLONG_MAX;                 // 9223372036854775807
const int INF = 1'000'000'000;                       // 10^9
const int INFMAX = INT_MAX;                          // 2147483647
const int mod1 = 1000000007;
const int mod2 = 998244353;
// 4-neighbourhood (dx, dy) offsets.
const vector<int> dx1 = {1,0,-1,0};
const vector<int> dy1 = {0,1,0,-1};
// 8-neighbourhood (dx, dy) offsets. There are 9 entries: the last repeats the
// first (presumably so that index k+1 stays valid for k = 7).
const vector<int> dx2 = {0, 1, 1, 1, 0, -1, -1, -1, 0};
const vector<int> dy2 = {1, 1, 0, -1, -1, -1, 0, 1, 1};
// Stream-output helpers for debugging standard containers.
// All four printers are declared up front. A call inside one template (e.g.
// printing a pair element of a vector) is a dependent call. Ordinary lookup
// sees only what is declared before the template definition, and
// argument-dependent lookup only searches namespace std for standard types.
// Without these declarations, vector<pair<...>> or map<K, vector<...>> could
// not be printed. Parameters are const references, so temporaries and const
// containers work too.
template<typename T> ostream& operator << (ostream& os, const vector<T>& vec);
template<typename T, typename U> ostream& operator << (ostream& os, const pair<T, U>& pair_var);
template<typename T, typename U> ostream& operator << (ostream& os, const map<T, U>& map_var);
template<typename T> ostream& operator << (ostream& os, const set<T>& set_var);

// vector output: "[a, b, c]"
template<typename T>
ostream& operator << (ostream& os, const vector<T>& vec) {
	os << "[";
	for (size_t i = 0; i < vec.size(); i++) {
		os << vec[i] << (i + 1 == vec.size() ? "" : ", ");
	}
	os << "]";
	return os;
}
// pair output: "(first, second)"
template<typename T, typename U>
ostream& operator << (ostream& os, const pair<T, U>& pair_var) {
	os << "(" << pair_var.first << ", " << pair_var.second << ")";
	return os;
}
// map output: "{(k1, v1), (k2, v2)}"
template<typename T, typename U>
ostream& operator << (ostream& os, const map<T, U>& map_var) {
	os << "{";
	bool first = true;
	for (const auto& kv : map_var) {
		if (!first) os << ", ";
		first = false;
		os << "(" << kv.first << ", " << kv.second << ")";
	}
	os << "}";
	return os;
}
// set output: "{a, b, c}"
template<typename T>
ostream& operator << (ostream& os, const set<T>& set_var) {
	os << "{";
	bool first = true;
	for (const auto& v : set_var) {
		if (!first) os << ", ";
		first = false;
		os << v;
	}
	os << "}";
	return os;
}
// Milliseconds elapsed since TIME_START (set by TIMER_START), truncated to int.
int GetTime(){
    const double elapsedClocks = static_cast<double>(clock() - TIME_START);
    const double elapsedMs = elapsedClocks * 1000.0 / CLOCKS_PER_SEC;
    return static_cast<int>(elapsedMs);
}
// Pseudo-random value in [0, B) drawn from the global rng (expects B > 0).
ll myRand(ll B) {
    const unsigned long long raw = rng();
    return raw % B;
}
// Edge of an adjacency list: destination vertex and its cost.
struct Edge {
    int to;    // destination vertex
    ll cost;   // edge cost
    Edge(int dst = 0, ll w = 0) : to(dst), cost(w) {}
};
// Adjacency-list graph type: one vector<Edge> per vertex.
using Graph = vector<vector<Edge>>;
 
// Greatest common divisor of |a| and |b| (iterative Euclid).
// The result is non-negative, and gcd(0, 0) == 0.
// The old recursive version never terminated when an argument was negative:
// gcd(-4, 6) kept recursing as gcd(-4, 2) forever, because the `a < b` swap
// combined with C++'s truncating % made no progress. Working on absolute
// values fixes that. NOTE: LLONG_MIN has no positive counterpart.
long long gcd(long long a, long long b){
    if (a < 0) a = -a;
    if (b < 0) b = -b;
    while (b != 0) {
        const long long r = a % b;
        a = b;
        b = r;
    }
    return a;
}
// Rational number x/y, kept in lowest terms after every operation, with
// y >= 0 (the sign is carried by x).
// NOTE: the comparison and arithmetic operators cross-multiply, so large
// numerators/denominators can overflow long long.
struct Fraction{
    long long x,y;
    // Reduce to lowest terms and move the sign into the numerator.
    // Uses std::gcd (C++17 <numeric>) so the struct does not depend on any
    // other gcd helper. Fraction(0, 0) divides by zero here (gcd is 0). A zero
    // denominator with a non-zero numerator normalizes to (+-1)/0.
    void reduc(){
        long long minus = 1;
        if(x < 0) minus *= -1;
        if(y < 0) minus *= -1;
        long long g = std::gcd(std::abs(x), std::abs(y));
        x = minus * std::abs(x) / g;
        y = std::abs(y) / g;
    }

    Fraction(long long x = 0, long long y = 1): x(x), y(y) {reduc();}
    bool operator<(const Fraction& right) const {return x*right.y < y*right.x;}
    bool operator<=(const Fraction& right) const {return x*right.y <= y*right.x;}
    bool operator>(const Fraction& right) const {return x*right.y > y*right.x;}
    bool operator>=(const Fraction& right) const {return x*right.y >= y*right.x;}
    // Both sides are reduced, so equality is component-wise.
    bool operator==(const Fraction& right) const {return x == right.x && y == right.y;}
    Fraction operator-() const {return Fraction(-x, y);}
    Fraction& operator+=(const Fraction& v){
        x = x*v.y + y*v.x;
        y *= v.y;
        reduc();
        return *this;
    }
    Fraction operator+(const Fraction& v) const {return Fraction(*this) += v;}
    Fraction& operator-=(const Fraction& v){
        x = x*v.y - y*v.x;
        y *= v.y;
        reduc();
        return *this;
    }
    Fraction operator-(const Fraction& v) const {return Fraction(*this) -= v;}
    Fraction& operator*=(const Fraction& v){
        x *= v.x;
        y *= v.y;
        reduc();
        return *this;
    }
    Fraction operator*(const Fraction& v) const {return Fraction(*this) *= v;}
    Fraction& operator/=(const Fraction& v){
        x *= v.y;
        y *= v.x;
        reduc();
        return *this;
    }
    Fraction operator/(const Fraction& v) const {return Fraction(*this) /= v;}
    // Reciprocal y/x (re-normalized by the constructor).
    Fraction inv() const {return Fraction(y,x);}
    // this^t by binary exponentiation; a negative t uses the reciprocal.
    Fraction pow(long long t) const {
        if(t < 0) return inv().pow(-t);
        Fraction a(1, 1), d = *this;
        while(t){
            // Multiply before squaring. The old order squared first and so
            // returned this^(2t), e.g. (1/2).pow(1) == 1/4.
            if(t & 1) a *= d;
            d *= d;
            t >>= 1;
        }
        return a;
    }
    friend ostream& operator << (ostream& os, const Fraction& v){ return os << v.x << '/' << v.y;}
};
 
// 2-D point / vector with double coordinates.
class Point{
    public:
        double x, y;

        Point(double x = 0, double y = 0): x(x), y(y) {}

        // Component-wise arithmetic. These and norm()/abs() are const, so they
        // also work on const Points (e.g. ones received as const Point&).
        Point operator + (Point p) const {return Point(x + p.x, y + p.y); }
        Point operator - (Point p) const {return Point(x - p.x, y - p.y); }
        Point operator * (double a) const {return Point(a * x, a * y); }
        Point operator / (double a) const {return Point(x / a, y / a); }

        // Squared Euclidean length.
        double norm() const {return x * x + y * y; }
        // Euclidean length.
        double abs() const {return sqrt(norm()); }

        // Approximate equality: both coordinates within EPS.
        bool operator == (const Point &p) const{
            return fabs(x - p.x) < EPS && fabs(y - p.y) < EPS;
        }

        // Lexicographic order by (x, y) using exact comparison.
        // NOTE: unlike operator==, this does not use EPS.
        bool operator < (const Point &p) const{
            return x != p.x ? x < p.x : y < p.y;
        }
};
// Dot product a . b.
double dot(Point a, Point b){return a.x * b.x + a.y * b.y;}
// z-component of the 2-D cross product a x b (positive when b is counter-clockwise from a).
double cross(Point a, Point b){return a.x * b.y - a.y * b.x;}
 
// Reduce x modulo P into [0, P) (for P > 0), write the result back into x,
// and return it.
ll MOD(ll &x, const ll P){
    const ll r = x % P;
    return x = (r < 0 ? r + P : r);
}
 
// x^n mod `mod` by binary exponentiation (n >= 0, mod >= 1).
// The result is always in [0, mod). A negative base is normalized first, and
// mod == 1 yields 0. The old code returned a negative value for negative x and
// 1 for mod == 1.
// NOTE: intermediate products are < mod^2, so mod must stay below about 3.03e9
// to avoid long long overflow.
long long mpow(long long x, long long n, long long mod){
    x %= mod;
    if (x < 0) x += mod;
    long long ret = 1 % mod;
    while(n > 0){
        if(n & 1) ret = ret * x % mod;
        x = x * x % mod;
        n >>= 1;
    }
    return ret;
}
// x^n with plain long long arithmetic (no modulus), for n >= 0.
// The base is squared only while higher exponent bits remain, so no
// intermediate exceeds |x|^n. Previously the last iteration squared x
// unconditionally: lpow(10, 18) computed 10^32 internally (signed overflow,
// i.e. undefined behaviour) even though the answer fits.
long long lpow(long long x, long long n){
    long long ret = 1;
    while(n > 0){
        if(n & 1) ret = ret * x;
        n >>= 1;
        if(n > 0) x = x * x;
    }
    return ret;
}
// Smallest k >= 0 such that 2^k >= n (0 for n <= 1).
int ceil_pow2(int n) {
    int shift = 0;
    for (unsigned int power = 1U; power < static_cast<unsigned int>(n); power <<= 1) {
        ++shift;
    }
    return shift;
}
 
// Decimal (long long) -> binary string, most significant bit first with no
// leading zeros. Returns "0" for n == 0 and asserts on negative input.
string toBinary(long long n)
{
    if(n == 0) return "0";

    assert(n > 0);
    string ret;
    while (n != 0){
        // Parenthesized on purpose: `n & 1 == 1` parses as `n & (1 == 1)`,
        // which only produced the right answer by coincidence.
        ret += ((n & 1) == 1 ? '1' : '0');
        n >>= 1;
    }
    reverse(ret.begin(), ret.end());

    return ret;
}
 
// 2進数(string) → 10進数(long long)への変換
ll toDecimal(string S){
    ll ret = 0;
    for(int i = 0; i < S.size(); i++){
        ret *= 2LL;
        if(S[i] == '1') ret += 1;
    }
 
    return ret;
}
 
// Least common multiple, computed as a / gcd(a, b) * b. Dividing first keeps
// the intermediate small. Asserts that a and b are not both zero.
ll lcm(ll a, ll b){
    const ll g = gcd(a, b);
    assert(g != 0);
    return a / g * b;
}
// Extended Euclidean algorithm.
// Stores in (x, y) a solution of a*x + b*y = gcd(a, b) and returns gcd(a, b).
long long extGCD(long long a, long long b, long long &x, long long &y) {
    if (b != 0) {
        const long long q = a / b;
        // Recurse on (b, a - q*b) == (b, a % b). The output slots are swapped
        // so that the returned coefficients map back to (a, b) below.
        const long long g = extGCD(b, a - q * b, y, x);
        y -= q * x;
        return g;
    }
    x = 1;
    y = 0;
    return a;
}
// Matrix product: A is r x m, B is m x c; returns the r x c matrix A * B.
// The operands are const references, so temporaries and const matrices can be
// multiplied. An empty A yields an empty result instead of reading A[0].
// Each entry still accumulates A[i][k] * B[k][j] in increasing k. The loops
// now run in i-k-j order, so B is read row by row.
template<typename T>
vector<vector<T> > matrix_prod(const vector<vector<T> > &A, const vector<vector<T> > &B){
    if (A.empty()) return {};
    assert(A[0].size() == B.size());
    const size_t r = A.size();
    const size_t m = B.size();
    const size_t c = B.empty() ? 0 : B[0].size();
    vector<vector<T> > ret(r, vector<T>(c, 0));
    for (size_t i = 0; i < r; i++) {
        for (size_t k = 0; k < m; k++) {
            const T& aik = A[i][k];
            for (size_t j = 0; j < c; j++) {
                ret[i][j] += aik * B[k][j];
            }
        }
    }
    return ret;
}
// 2-D prefix sums over a W x H grid of 0-indexed cells.
// Usage: add() point values, call build() once, then query() rectangles.
template< class T >
struct CumulativeSum2D {
  // data[i][j] ends up holding the sum over cells [0, i) x [0, j) after build().
  vector< vector< T > > data;
  CumulativeSum2D(int W, int H) : data(W + 1, vector< T >(H + 1, 0)) {}
  // Add z to cell (x, y). The write is skipped when the shifted index
  // (x + 1, y + 1) falls outside data.
  void add(int x, int y, T z) {
    const int row = x + 1;
    const int col = y + 1;
    if (static_cast<size_t>(row) < data.size() && static_cast<size_t>(col) < data[0].size()) {
      data[row][col] += z;
    }
  }
  // Turn the point values into inclusive-exclusive prefix sums, in place.
  void build() {
    for (size_t i = 1; i < data.size(); ++i)
      for (size_t j = 1; j < data[i].size(); ++j)
        data[i][j] += data[i][j - 1] + data[i - 1][j] - data[i - 1][j - 1];
  }
  // Sum over the half-open rectangle [sx, gx) x [sy, gy).
  T query(int sx, int sy, int gx, int gy) const {
    const T whole = data[gx][gy];
    return whole - data[sx][gy] - data[gx][sy] + data[sx][sy];
  }
};
// Binomial coefficients C(n, k) modulo MOD for 0 <= n <= MAX (the constructor
// argument), via precomputed factorials and inverse factorials.
// NOTE: the inv[MOD % i] recurrence assumes MOD is prime and the table size
// does not exceed MOD.
struct Combination{
    int MAX;                    // table size: constructor argument + 1, at least 2
    int MOD;                    // modulus
    vector<long long> fac,finv,inv;  // i!, (i!)^-1 and i^-1, all mod MOD
    Combination(int MAX, int MOD) : MAX(max(MAX + 1, 2)), MOD(MOD){
        // COMinit always writes index 1, so keep at least two slots. The old
        // code sized the tables MAX + 1 and wrote out of bounds when MAX == 0.
        fac.resize(this->MAX);
        finv.resize(this->MAX);
        inv.resize(this->MAX);
        COMinit();
    }
    // Fill the factorial / inverse tables.
    void COMinit() {
        fac[0] = fac[1] = 1;
        finv[0] = finv[1] = 1;
        inv[1] = 1;
        for (int i = 2; i < MAX; i++){
            fac[i] = fac[i - 1] * i % MOD;
            inv[i] = MOD - inv[MOD%i] * (MOD / i) % MOD;
            finv[i] = finv[i - 1] * inv[i] % MOD;
        }
    }
    // C(n, k) mod MOD. Returns 0 when k > n or either argument is negative.
    long long COM(int n, int k){
        if (n < k) return 0;
        if (n < 0 || k < 0) return 0;
        assert(n < MAX);  // n must be covered by the precomputed tables
        return fac[n] * (finv[k] * finv[n - k] % MOD) % MOD;
    }
};
// Segment tree (adapted from the AtCoder Library, ACL).
// The template arguments must form a monoid: element type S, an associative
// binary operation S op(S a, S b), and its identity element S e().
// Leaves are stored at d[size + i] for i in [0, _n); internal node k holds
// op(d[2k], d[2k+1]), and node 1 is the root.
template <class S, S (*op)(S, S), S (*e)()> struct segtree {
  public:
    // Empty tree (n = 0).
    segtree() : segtree(0) {}
    explicit segtree(int n) : segtree(std::vector<S>(n, e())) {} // int n: sequence a of length n, every element initialised to e()
    explicit segtree(const std::vector<S>& v) : _n(int(v.size())) { // vector<S> v: sequence a of length n = v.size(), initialised from v
        // size = smallest power of two >= _n (ceil_pow2 returns that exponent).
        log = ceil_pow2(_n);
        size = 1 << log;
        d = std::vector<S>(2 * size, e());
        for (int i = 0; i < _n; i++) d[size + i] = v[i];
        // Build internal nodes bottom-up.
        for (int i = size - 1; i >= 1; i--) {
            update(i);
        }
    }
    // a[p] = x (point update), then recompute every ancestor of leaf p.
    void set(int p, S x) {
        assert(0 <= p && p < _n);
        p += size;
        d[p] = x;
        for (int i = 1; i <= log; i++) update(p >> i);
    }
    // Returns a[p] (point query).
    S get(int p) const {
        assert(0 <= p && p < _n);
        return d[p + size];
    }
    // Computes op(a[l], ..., a[r-1]) over the half-open range [l, r); returns e() when l == r.
    // Left and right partial products are kept separately so that op need not be commutative.
    S prod(int l, int r) const {
        assert(0 <= l && l <= r && r <= _n);
        S sml = e(), smr = e();
        l += size;
        r += size;
        while (l < r) {
            if (l & 1) sml = op(sml, d[l++]);
            if (r & 1) smr = op(d[--r], smr);
            l >>= 1;
            r >>= 1;
        }
        return op(sml, smr);
    }
    // op(a[0], ..., a[n-1]) (whole-sequence product), read from the root node.
    S all_prod() const { return d[1]; }
    // Binary search on the segment tree.
    // NOTE(review): per the ACL documentation, max_right(l, f) returns an r such
    // that f(op(a[l], ..., a[r-1])) is true and either r == n or
    // f(op(a[l], ..., a[r])) is false, assuming f is monotone. f(e()) must be
    // true (asserted below). Confirm against the ACL docs before relying on it.
    template <bool (*f)(S)> int max_right(int l) const {
        return max_right(l, [](S x) { return f(x); });
    }
    template <class F> int max_right(int l, F f) const {
        assert(0 <= l && l <= _n);
        assert(f(e()));
        if (l == _n) return _n;
        l += size;
        S sm = e();
        do {
            while (l % 2 == 0) l >>= 1;
            if (!f(op(sm, d[l]))) {
                // Descend into the subtree where f first becomes false.
                while (l < size) {
                    l = (2 * l);
                    if (f(op(sm, d[l]))) {
                        sm = op(sm, d[l]);
                        l++;
                    }
                }
                return l - size;
            }
            sm = op(sm, d[l]);
            l++;
        } while ((l & -l) != l);
        return _n;
    }
    // Mirror image of max_right: searches leftwards from r.
    // NOTE(review): per ACL, returns an l such that f(op(a[l], ..., a[r-1])) is
    // true and either l == 0 or f(op(a[l-1], ..., a[r-1])) is false, assuming f
    // is monotone. f(e()) must be true (asserted below).
    template <bool (*f)(S)> int min_left(int r) const {
        return min_left(r, [](S x) { return f(x); });
    }
    template <class F> int min_left(int r, F f) const {
        assert(0 <= r && r <= _n);
        assert(f(e()));
        if (r == 0) return 0;
        r += size;
        S sm = e();
        do {
            r--;
            while (r > 1 && (r % 2)) r >>= 1;
            if (!f(op(d[r], sm))) {
                // Descend into the subtree where f first becomes false.
                while (r < size) {
                    r = (2 * r + 1);
                    if (f(op(d[r], sm))) {
                        sm = op(d[r], sm);
                        r--;
                    }
                }
                return r + 1 - size;
            }
            sm = op(d[r], sm);
        } while ((r & -r) != r);
        return 0;
    }
  private:
    // _n: sequence length; size: number of leaves (a power of two); log: log2(size).
    int _n, size, log;
    // Node storage; the root is d[1].
    std::vector<S> d;
    // Recompute node k from its two children.
    void update(int k) { d[k] = op(d[2 * k], d[2 * k + 1]); }
};
// Binary Indexed Tree (Fenwick tree) over positions 1..n_ (1-indexed).
// Supports point add and prefix sum, each walking O(log n) nodes.
template <typename T>
struct BIT {
    int n;          // size of bit: number of positions + 1 (index 0 unused)
    vector<T> bit;  // tree storage; declared after n, so n is already set when bit(n, 0) runs
    BIT(int n_) : n(n_ + 1), bit(n, 0) {}
    // A_idx += x, for 1 <= idx <= n_.
    // idx must be positive: with idx == 0 the old loop spun forever, because
    // 0 & -0 == 0 and idx never advanced.
    void add(int idx, T x) {
        assert(idx > 0);
        while(idx < n){ // n is already n_ + 1, hence the strict <
            bit[idx] += x;
            idx += (idx & -idx);
        }
    }
    // A_1 + ... + A_idx (0 when idx <= 0).
    T sum(int idx) {
        T ret(0);
        while(idx > 0){
            ret += bit[idx];
            idx -= (idx & -idx);
        }
        return ret;
    }
    // Smallest x with A_1 + ... + A_x >= w (requires every A_i >= 0).
    // Returns 0 when w <= 0. If the total sum is below w, the result is
    // n_ + 1, one past the last position.
    int lower_bound(T w) {
        if (w <= 0) return 0;
        int x = 0, r = 1;
        while (r < n) r = r << 1;
        for (int len = r; len > 0; len = len >> 1) {
            if (x + len < n && bit[x + len] < w) {
                w -= bit[x + len];
                x += len;
            }
        }
        return x + 1;
    }
};
// Disjoint-set union (Union-Find) with union by size and path compression.
struct unionfind{
    // par[v] == -1 marks a root. siz[v] is meaningful only when v is a root.
    vector<int> par, siz;

    // n singleton sets {0}, {1}, ..., {n-1}.
    unionfind(int n) : par(n, -1), siz(n, 1) {}

    // Representative of x's set. Every node on the path to it is re-pointed
    // directly at the representative.
    int root(int x) {
        int top = x;
        while (par[top] != -1) top = par[top];
        while (x != top) {
            const int next = par[x];
            par[x] = top;
            x = next;
        }
        return top;
    }

    // Whether x and y belong to the same set.
    bool issame(int x, int y){ return root(x) == root(y); }

    // Merge the sets containing x and y; the larger set's root stays on top.
    // Returns false when they were already in the same set.
    bool unite(int x, int y){
        int rx = root(x), ry = root(y);
        if (rx == ry) return false;
        if (siz[rx] < siz[ry]) swap(rx, ry);
        par[ry] = rx;
        siz[rx] += siz[ry];
        return true;
    }

    // Number of elements in x's set.
    int size(int x){ return siz[root(x)]; }
};
// ############################
// #                          #
// #    C O D E  S T A R T    #
// #                          #
// ############################
// Reads N and M, and prints modulo 998244353 the probability that X^2 - Y^2 == M
// after N steps of the walk described by the move table below. Each step is one
// of four equally likely moves; the denominator is 4^N.
// Every step changes u = X + Y by +-1 and v = X - Y by +-1, independently.
void solve() {
    ll N; cin >> N;
    ll M; cin >> M;
    // Swapping the roles of X and Y maps the four moves onto themselves and
    // negates X^2 - Y^2, so only |M| matters. The old code also kept an unused
    // `neg` flag here.
    if(M < 0) M = -M;
    Combination C(2*N+1, mod2);
    if(M == 0) { // corner case
        ll ans = 0;
        if(N % 2 == 0) {
            // (X+Y)(X-Y) == 0 iff u == 0 or v == 0. Each of u, v is an
            // independent +-1 walk of N steps, so g = P(u == 0) = C(N, N/2) / 2^N.
            ll g = C.COM(N,N/2) * mpow(mpow(2, N, mod2), mod2-2, mod2) % mod2;
            ans = 2*g;
            ans -= g*g % mod2; // inclusion-exclusion: u == 0 and v == 0 was counted twice
            MOD(ans, mod2);
        }
        // For odd N both u and v are odd, so (X+Y)(X-Y) can never be 0.
        cout << ans << '\n';
        return;
    }
    // Divisor pairs (x, y) of M with x <= y and x * y == M.
    // Each pair describes endpoints with |X| = (x+y)/2 and |Y| = (y-x)/2.
    vpll cands;
    REP(i,1,M+1) {
        if(i*i > M) break;
        if(M % i == 0) {
            cands.pb({i,M/i});
        }
    }
    ll all = mpow(4,N,mod2); // total number of walks, 4^N
    ll ans = 0;
    // The four moves in terms of u = X+Y and v = X-Y:
    // 1 : (X+Y)++, (X-Y)++;   (X += 1)
    // 2 : (X+Y)--, (X-Y)--;   (X -= 1)
    // 3 : (X+Y)++, (X-Y)--, diff -= 2;   (Y += 1)
    // 4 : (X+Y)--, (X-Y)++, diff += 2;   (Y -= 1)
    for(auto [x,y]:cands) {
        if(y > N) continue;            // |X| + |Y| == y cannot exceed N
        if((y-x) % 2 == 1) continue;   // X and Y must be integers
        // After N steps, |X| + |Y| (== y) has the same parity as N. Without this
        // guard, the binomial sum below truncated (left - |X|) / 2 and counted
        // walks that cannot exist (e.g. N = 2, M = 1 printed 1/4 instead of 0).
        if((N - y) % 2 != 0) continue;
        ll now = 0;
        ll diff = (y-x) / 2;           // |Y|; |X| is y - diff
        // Count the walks ending at (X, Y) = (y - diff, diff). i is the number of
        // vertical steps:
        //   C(N, i)                     chooses which steps are vertical,
        //   C(i, (i + |Y|) / 2)         chooses the upward ones,
        //   C(left, (left + |X|) / 2)   chooses the rightward horizontal ones.
        for(ll i = diff; i <= N && N - i >= abs(y - diff); i+=2) {
            ll left = N - i;
            now += (C.COM(N,i) * C.COM(i,diff + (i-diff)/2) % mod2) * C.COM(left, abs(y-diff) + (left-abs(y-diff))/2) % mod2;
            now %= mod2;
        }
        // Sign variants of the endpoint: (+-X, +-Y) gives 4 when Y != 0 (x != y).
        // When x == y, Y == 0 and only +-X remain.
        if(x != y) {
            now *= 4;
            now %= mod2;
        }
        else {
            now *= 2;
            now %= mod2;
        }
        ans += now;
        ans %= mod2;
    }
    // Probability = favourable walks / 4^N, with 4^N inverted by Fermat's little theorem.
    cout << ans * mpow(all, mod2-2, mod2) % mod2 << '\n';
}
// Entry point: fast stream setup, then solve() once per test case.
// There is a single case here; uncomment the read below for multi-case input.
signed main() {
    cin.tie(nullptr);
    ios_base::sync_with_stdio(false);
    TIMER_START;
    //cout << fixed << setprecision(15);

    int testCases = 1;
    //cin >> testCases;
    for (; testCases > 0; --testCases) {
        solve();
    }

    TIMER_END;
    TIMECHECK;
    return 0;
}
            
            
            
        