#include #include #include #include using namespace std; using namespace atcoder; using ll = long long; using mint = modint998244353; int main() { ll n, m; cin >> n >> m; mint ans = 0; vector f(n + 1, 1); for (ll i = 0; i < n; i++) f[i + 1] = f[i] * (i + 1); for (ll y = 0; y < n; y++) { ll x2 = y * y + m; ll x = sqrt(x2) - 2; while (x * x - y * y < m) x++; if (x * x - y * y != m) continue; if ((n - x - y) % 2) continue; if (n - x - y < 0) break; ll a = abs(n - x + y) / 2; ll b = abs(n - x - y) / 2; mint tmp = f[n] * f[n] / (f[a] * f[n - a] * f[b] * f[n - b]); if (y) tmp *= 4; else tmp *= 2; ans += tmp; } ans /= mint(4).pow(n); cout << ans.val() << "\n"; return 0; }