/*
  May this code pass! Be accepted!
   ∧_∧
  (。・ω・。)つ━☆・*。
  ⊂   ノ    ・゜+.
   しーJ   °。+ *´¨)
           .· ´¸.·*´¨) ¸.·*¨)
            (¸.·´ (¸.·'* ☆
*/
// NOTE(review): the header names of the 17 #include lines were missing from
// the text this file was recovered from. The list below covers everything the
// file uses; confirm it against the original.
#include <algorithm>
#include <bitset>
#include <climits>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <functional>
#include <iomanip>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <stack>
#include <string>
#include <utility>
#include <vector>
/* Arbitrary-precision integers: declare values as cpp_int.
#include <boost/multiprecision/cpp_int.hpp>
using namespace boost::multiprecision;
*/
//#pragma GCC target ("avx2")
//#pragma GCC optimization ("O3")
//#pragma GCC optimization ("unroll-loops")
#define rep(i, n) for(int i = 0; i < (n); ++i)
#define rep1(i, n) for(int i = 1; i <= (n); ++i)
#define rep2(i, n) for(int i = 2; i < (n); ++i)
#define repr(i, n) for(int i = n; i >= 0; --i)
#define reprm(i, n) for(int i = n - 1; i >= 0; --i)
#define printynl(a) printf(a ? "yes\n" : "no\n")
#define printyn(a) printf(a ? "Yes\n" : "No\n")
#define printYN(a) printf(a ? "YES\n" : "NO\n")
// NOTE(review): "imposible" is reproduced as found; it looks like a typo for "impossible".
#define printin(a) printf(a ? "possible\n" : "imposible\n")
#define printdb(a) printf("%.50lf\n", a)       // print a floating-point value
#define printdbd(a) printf("%.16lf\n", a)      // print a floating-point value (fewer digits)
#define prints(s) printf("%s\n", s.c_str())    // print a std::string
#define all(x) (x).begin(), (x).end()
#define allsum(a, b, c) ((a + b) * c / 2)      // arithmetic series sum: first term, last term, term count
#define pb push_back
#define priq priority_queue
// NOTE(review): the template arguments of this macro were lost in the source
// text; a min-heap of int is assumed -- confirm against the original.
#define rpriq priq<int, vector<int>, greater<int>>
#define deg_to_rad(deg) (((deg)/360.0)*2.0*PI)
#define rad_to_deg(rad) (((rad)/2.0/PI)*360.0)
#define Please return
#define AC 0
#define addf(T) [](T a, T b){return (a + b);}
#define minf(T) [](T a, T b){return min(a, b);}
#define maxf(T) [](T a, T b){return max(a, b);}

using ll = long long;

constexpr int INF = 1073741823;
constexpr int MINF = -1073741823;
constexpr ll LINF = ll(4661686018427387903);
constexpr ll MOD = 1000000007;
// Qualified call so the long double overload is selected.
const long double PI = std::acos(-1.0L);

using namespace std;

/// Reads one line from stdin into `str` (without the trailing newline).
/// A single newline left over from a previous read is skipped first.
/// Stops at end of input; `str` then holds whatever was read so far.
void scans(string& str) {
    str.clear();
    char c;
    // scanf leaves `c` untouched at EOF, so the return value is the only
    // reliable end-of-input signal.
    if (scanf("%c", &c) != 1) return;
    if (c == '\n' && scanf("%c", &c) != 1) return;  // skip one leading newline
    while (c != '\n') {
        str += c;
        if (scanf("%c", &c) != 1) break;
    }
}

/// Reads the next non-newline character from stdin into `str`.
/// `str` is left unchanged if the input ends first.
void scanc(char& str) {
    char c;
    do {
        if (scanf("%c", &c) != 1) return;
    } while (c == '\n');
    str = c;
}

/// Inverse cotangent, computed as PI/2 - atan(x).
double acot(double x) {
    return PI / 2 - atan(x);
}

/// Greatest common divisor by the Euclidean algorithm; gcd(a, 0) == a.
ll gcd(ll a, ll b) {
    while (b != 0) {
        const ll r = a % b;
        a = b;
        b = r;
    }
    return a;
}

/// Least common multiple. Divides before multiplying to limit overflow.
/// Returns 0 when both arguments are 0 (instead of dividing by zero).
ll lcm(ll number1, ll number2) {
    const ll g = gcd(number1, number2);
    if (g == 0) return 0;
    return number1 / g * number2;
}

/// Isolates the lowest set bit of n (0 when n == 0).
ll LSB(ll n) {
    return (n & (-n));
}

/*----------------------------------------- code starts here -----------------------------------------*/

/// Integer modulo the compile-time constant `mod`; `x` is kept in [0, mod).
template< int mod >
struct ModInt {
    int x;

    ModInt() : x(0) {}

    /// Reduces any long long (including negatives) into [0, mod).
    ModInt(ll y) {
        y %= mod;            // result is in (-mod, mod); no negation, so LLONG_MIN is safe
        if (y < 0) y += mod;
        x = static_cast<int>(y);
    }

    ModInt& operator+=(const ModInt& p) {
        if ((x += p.x) >= mod) x -= mod;
        return *this;
    }

    ModInt& operator-=(const ModInt& p) {
        // Add the complement so the intermediate value never goes negative.
        if ((x += mod - p.x) >= mod) x -= mod;
        return *this;
    }

    ModInt& operator*=(const ModInt& p) {
        x = (int)(1LL * x * p.x % mod);  // widen to 64 bits before multiplying
        return *this;
    }

    ModInt& operator/=(const ModInt& p) {
        *this *= p.inverse();
        return *this;
    }

    ModInt operator-() const { return ModInt(-x); }

    ModInt operator+(const ModInt& p) const { return ModInt(*this) += p; }
    ModInt operator-(const ModInt& p) const { return ModInt(*this) -= p; }
    ModInt operator*(const ModInt& p) const { return ModInt(*this) *= p; }
    ModInt operator/(const ModInt& p) const { return ModInt(*this) /= p; }

    bool operator==(const ModInt& p) const { return x == p.x; }
    bool operator!=(const ModInt& p) const { return x != p.x; }
    bool operator>(const ModInt& p) const { return x > p.x; }
    bool operator<(const ModInt& p) const { return x < p.x; }
    bool operator<=(const ModInt& p) const { return x <= p.x; }
    bool operator>=(const ModInt& p) const { return x >= p.x; }

    /// Multiplicative inverse via the extended Euclidean algorithm.
    /// Only meaningful when x and mod are coprime.
    ModInt inverse() const {
        int a = x, b = mod, u = 1, v = 0, t;
        while (b > 0) {
            t = a / b;
            swap(a -= t * b, b);
            swap(u -= t * v, v);
        }
        return ModInt(u);
    }

    /// Binary exponentiation; returns 1 for n <= 0.
    ModInt pow(int64_t n) const {
        ModInt ret(1), mul(x);
        while (n > 0) {
            if (n & 1) ret *= mul;
            mul *= mul;
            n >>= 1;
        }
        return ret;
    }

    friend ostream& operator<<(ostream& os, const ModInt& p) {
        return os << p.x;
    }

    friend istream& operator>>(istream& is, ModInt& a) {
        ll t;
        is >> t;
        a = ModInt< mod >(t);
        return (is);
    }

    static int get_mod() { return mod; }
};

using modint = ModInt< MOD >;  // MOD = 1e9 + 7
using Modint = ModInt< 998244353 >;

/// With a[0] = 0, a[1] = 1 and a[i] = p * a[i-1] + a[i-2], returns
///     sum over i in [0, n) of (n - 1 - i) * a[i]^2   (mod 1e9 + 7).
/// This is the value the original double loop accumulated (each a[j] was
/// added to b[j] once per later index), evaluated in O(n) instead of O(n^2).
/// Returns 0 for n <= 1.
static modint weighted_square_sum(int n, modint p) {
    if (n <= 0) return modint(0);
    vector<modint> a(n);   // value-initialised, so a[0] == 0
    if (n > 1) a[1] = 1;   // guarded: writing a[1] is out of bounds when n == 1
    for (int i = 2; i < n; ++i) {
        a[i] = p * a[i - 1] + a[i - 2];
    }
    modint total = 0;
    for (int i = 0; i < n; ++i) {
        total += a[i] * a[i] * modint(n - 1 - i);
    }
    return total;
}

/// Reads n and p from stdin and prints the weighted sum followed by a newline.
int main() {
    int n;
    modint p;
    if (!(cin >> n >> p)) return 1;  // malformed or missing input
    cout << weighted_square_sum(n, p) << '\n';
    Please AC;
}