結果

問題 No.3096 Snake Path
ユーザー autumn09
提出日時 2025-03-05 18:35:16
言語 C++17(gcc12)
(gcc 12.3.0 + boost 1.87.0)
結果
WA  
実行時間 -
コード長 2,608 bytes
コンパイル時間 28,020 ms
コンパイル使用メモリ 359,308 KB
実行使用メモリ 17,336 KB
最終ジャッジ日時 2025-03-05 18:35:51
合計ジャッジ時間 31,228 ms
ジャッジサーバーID
(参考情報)
judge2 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 4 WA * 7 TLE * 1 -- * 23
権限があれば一括ダウンロードができます

ソースコード

diff #

#pragma GCC target("avx2")
#pragma GCC optimize("O3")
#pragma GCC optimize("unroll-loops")

#include <bits/stdc++.h>
#include <atcoder/all>
using namespace std;
using namespace atcoder;
using ll = long long;
using pll = pair<ll, ll>;
using pii = pair<int, int>;
using mint = modint998244353;
constexpr ll mod = 998244353;
using MINT = modint1000000007;
constexpr ll MOD = 1000000007;
int dx[4] = {1, 0, -1, 0};
int dy[4] = {0, 1, 0, -1};

template <typename T>
void print(vector<T> A);

int main()
{
    ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    int N, K;
    cin >> N >> K;
    vector<vector<vector<mint>>> dp(
        N+1, vector<vector<mint>>(3, vector<mint>(K + 1, 0)));
    dp[0][0][0] = 1;
    for (int i = 0; i < N; i++) {
        for (int j = 0; j < 3; j++) {
            for (int k = 0; k <= K; k++) {
                for (int nj = 0; nj < 3; nj++) {
                    int c = 1;
                    if (nj == 2 || i==N-1) {
                        c--;
                    }
                    if (k + c > K) {
                        continue;
                    }
                    dp[i + 1][nj][k + c] += dp[i][j][k];

                    if (j == 0 && nj == 2) {
                        for (int l = i + 2; l <= N; l++) {
                            if (k + (l - i - 1) * 2 > K) {
                                break;
                            }
                            dp[l][nj][k+(l-i-1)*2]+=dp[i][j][k];
                        }
                    }
                    if (j == 2 && nj == 0) {
                        for (int l = i + 2; l <= N; l++) {
                            if (l == N) {
                                if (k + (l - i - 1) * 2 > K) {
                                    break;
                                }
                                dp[l][nj][k + (l - i - 1) * 2] += dp[i][j][k];
                                continue;
                            }
                            if (k + (l - i - 1) * 2+1 > K) {
                                break;
                            }
                            dp[l][nj][k+(l-i-1)*2+1]+=dp[i][j][k];
                        }
                    }
                    
                }
            }
        }
    }

    mint ans = 0;
    int j = 0;
    if (N % 2) {
        j = 2;
    }
    for (int k = 0; k <= K; k++) {
        ans+=dp[N][j][k];
    }
    cout<<ans.val()<<endl;
    
}

template <typename T>
void print(vector<T> A)
{
    for (int i = 0; i < A.size() - 1; i++)
    {
        cout << A[i] << ' ';
    }
    cout << A[A.size() - 1] << endl;
    return;
}
0