#include <cmath>
#include <cstdlib>
#include <iostream>
#include <vector>

using ll = long long;

namespace {

// Prime modulus of the printed answer (formerly atcoder::modint998244353).
constexpr ll kMod = 998244353;

/// Returns base^exp mod kMod via binary exponentiation. Requires exp >= 0.
ll mod_pow(ll base, ll exp) {
    base %= kMod;
    if (base < 0) base += kMod;
    ll result = 1;
    while (exp > 0) {
        if (exp & 1) result = result * base % kMod;
        base = base * base % kMod;
        exp >>= 1;
    }
    return result;
}

/// Returns floor(sqrt(v)) exactly for small non-negative v.
/// The floating-point estimate is corrected in both directions.
ll isqrt(ll v) {
    ll r = static_cast<ll>(std::sqrt(static_cast<long double>(v)));
    while (r > 0 && r * r > v) --r;
    while ((r + 1) * (r + 1) <= v) ++r;
    return r;
}

/// Computes, modulo kMod, the probability that an n-step walk ends at a
/// point (X, Y) with X^2 - Y^2 == |m|. Each step is one of the four unit
/// moves, each with weight 1/4.
///
/// Counting uses the rotation u = X + Y, v = X - Y: every unit step changes
/// u and v by +-1 independently. So the number of walks ending at (x, y)
/// with x >= y >= 0 is C(n, (n-x+y)/2) * C(n, (n-x-y)/2). Each such point
/// stands for its sign variants (+-x, +-y).
///
/// Assumes 0 <= n < kMod, so all needed factorials are invertible.
ll solve(ll n, ll m) {
    // X^2 - Y^2 = -k mirrors Y^2 - X^2 = k under (X, Y) -> (Y, X).
    m = std::llabs(m);

    // |X| + |Y| <= n implies X^2 - Y^2 <= n^2, so larger m is impossible.
    // This check also keeps y*y + m and x*x below far from overflow.
    if (m > n * n) return 0;

    std::vector<ll> fact(n + 1, 1);
    std::vector<ll> inv_fact(n + 1, 1);
    for (ll i = 1; i <= n; ++i) fact[i] = fact[i - 1] * i % kMod;
    inv_fact[n] = mod_pow(fact[n], kMod - 2);  // Fermat inverse (kMod is prime)
    for (ll i = n; i >= 1; --i) inv_fact[i - 1] = inv_fact[i] * i % kMod;

    auto binom_n = [&](ll k) {
        return fact[n] * inv_fact[k] % kMod * inv_fact[n - k] % kMod;
    };

    ll total = 0;
    // Solutions satisfy x >= y >= 0 and x + y <= n, hence 2*y <= n.
    for (ll y = 0; 2 * y <= n; ++y) {
        const ll x = isqrt(y * y + m);
        // x + y strictly increases with y, so no later y can fit either.
        if (x + y > n) break;
        if (x * x - y * y != m) continue;
        if ((n - x - y) % 2 != 0) continue;  // endpoint parity must match n

        const ll a = (n - x + y) / 2;
        const ll b = (n - x - y) / 2;
        ll ways = binom_n(a) * binom_n(b) % kMod;
        if (y != 0) ways = ways * 2 % kMod;  // +-y
        if (x != 0) ways = ways * 2 % kMod;  // +-x
        total = (total + ways) % kMod;
    }
    // Divide by the 4^n equally likely walks.
    return total * mod_pow(mod_pow(4, n), kMod - 2) % kMod;
}

}  // namespace

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
    ll n = 0;
    ll m = 0;
    std::cin >> n >> m;
    std::cout << solve(n, m) << "\n";
    return 0;
}