#include #include #include #include using namespace std; using namespace atcoder; using ll = long long; using mint = modint998244353; int main() { ll n, m; cin >> n >> m; mint ans = 0; vector f(n + 1, 1); for (ll i = 0; i < n; i++) f[i + 1] = f[i] * (i + 1); for (ll y = 0; y < n; y++) { ll x2 = y * y + m; ll x = sqrt(x2) - 2; while (x * x - y * y < m) x++; if (x * x - y * y != m) continue; if ((n - x - y) % 2) continue; ll pm = (n - x - y) / 2; for (ll pmx = 0; pmx <= pm; pmx++) { ll pmy = pm - pmx; vector c = { x + pmx,pmx,y + pmy,pmy }; ll nc = n; mint tmp = 1; for (ll i = 0; i < 4; i++) { tmp *= f[nc] / (f[c[i]] * f[nc - c[i]]); nc -= c[i]; } if (y) ans += tmp * 4; else ans += tmp * 2; } } ans /= mint(4).pow(n); cout << ans.val() << "\n"; return 0; }