結果

問題 No.2413 Multiple of 99
ユーザー hiromi_ayase
提出日時 2023-08-12 04:50:36
言語 Java21
(openjdk 21)
結果
TLE  
実行時間 -
コード長 14,409 bytes
コンパイル時間 2,611 ms
コンパイル使用メモリ 91,084 KB
実行使用メモリ 171,332 KB
最終ジャッジ日時 2024-04-29 19:22:35
合計ジャッジ時間 20,797 ms
ジャッジサーバーID
(参考情報)
judge3 / judge4
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 68 ms
38,388 KB
testcase_01 AC 68 ms
38,456 KB
testcase_02 AC 102 ms
39,960 KB
testcase_03 AC 67 ms
38,636 KB
testcase_04 AC 1,565 ms
50,836 KB
testcase_05 AC 3,130 ms
56,632 KB
testcase_06 AC 3,118 ms
56,352 KB
testcase_07 TLE -
testcase_08 -- -
testcase_09 -- -
testcase_10 -- -
testcase_11 -- -
testcase_12 -- -
testcase_13 -- -
testcase_14 -- -
testcase_15 -- -
testcase_16 -- -
testcase_17 -- -
testcase_18 -- -
testcase_19 -- -
testcase_20 -- -
testcase_21 -- -
testcase_22 -- -
testcase_23 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #

import java.util.*;
import java.util.function.BiFunction;
import java.io.*;

@SuppressWarnings("unused")
public class Main {

  /**
   * Polynomial multiplication via the number-theoretic transform (NTT).
   *
   * <p>{@link #convolution} works modulo an NTT-friendly prime; {@link #convolutionLL} combines
   * three NTT primes with Garner's algorithm to work modulo an arbitrary {@code int} modulus.
   */
  private static class Convolution {
    /**
     * Finds a primitive root modulo a prime.
     *
     * @param m A prime number.
     * @return A primitive root of {@code m} (hard-coded for a few well-known NTT primes).
     */
    private static int primitiveRoot(int m) {
      if (m == 2)
        return 1;
      if (m == 167772161)
        return 3;
      if (m == 469762049)
        return 3;
      if (m == 754974721)
        return 11;
      if (m == 998244353)
        return 3;

      // Collect the distinct prime factors of m - 1 (2 is always one of them).
      int[] divs = new int[20];
      divs[0] = 2;
      int cnt = 1;
      int x = (m - 1) / 2;
      while (x % 2 == 0)
        x /= 2;
      for (int i = 3; (long) (i) * i <= x; i += 2) {
        if (x % i == 0) {
          divs[cnt++] = i;
          while (x % i == 0) {
            x /= i;
          }
        }
      }
      if (x > 1) {
        divs[cnt++] = x;
      }
      // g is a primitive root iff g^((m-1)/q) != 1 for every prime factor q of m - 1.
      for (int g = 2;; g++) {
        boolean ok = true;
        for (int i = 0; i < cnt; i++) {
          if (pow(g, (m - 1) / divs[i], m) == 1) {
            ok = false;
            break;
          }
        }
        if (ok)
          return g;
      }
    }

    /**
     * Modular exponentiation by repeated squaring.
     *
     * @param x Base.
     * @param n Non-negative exponent.
     * @param m Modulus.
     * @return {@code x^n mod m} (0 when {@code m == 1}).
     */
    private static long pow(long x, long n, int m) {
      if (m == 1)
        return 0;
      long r = 1;
      long y = x % m;
      while (n > 0) {
        if ((n & 1) != 0)
          r = (r * y) % m;
        y = (y * y) % m;
        n >>= 1;
      }
      return r;
    }

    /**
     * Returns the smallest {@code x} with {@code 2^x >= n}.
     *
     * @param n Value.
     * @return Exponent of the smallest power of two that is at least {@code n}.
     */
    private static int ceilPow2(int n) {
      int x = 0;
      while ((1L << x) < n)
        x++;
      return x;
    }

    /**
     * Garner's algorithm: reconstructs a value from its residues and reduces it modulo a final
     * modulus.
     *
     * @param c    Residues; {@code c[i]} is the value modulo {@code mods[i]}.
     * @param mods Moduli. The first {@code c.length} entries must be pairwise-coprime primes
     *             (inverses are taken with Fermat's little theorem); {@code mods[c.length]} is the
     *             output modulus.
     * @return The reconstructed value modulo {@code mods[c.length]}.
     */
    private static long garner(long[] c, int[] mods) {
      int n = c.length + 1;
      long[] cnst = new long[n];
      long[] coef = new long[n];
      java.util.Arrays.fill(coef, 1);
      for (int i = 0; i < n - 1; i++) {
        int m1 = mods[i];
        long v = (c[i] - cnst[i] + m1) % m1;
        v = v * pow(coef[i], m1 - 2, m1) % m1;

        for (int j = i + 1; j < n; j++) {
          long m2 = mods[j];
          cnst[j] = (cnst[j] + coef[j] * v) % m2;
          coef[j] = (coef[j] * m1) % m2;
        }
      }
      return cnst[n - 1];
    }

    /**
     * Pre-calculates the twiddle-factor table for the forward NTT.
     *
     * @param mod NTT prime.
     * @param g   Primitive root of {@code mod}.
     * @return Pre-calculation table consumed by {@link #butterfly}.
     */
    private static long[] sumE(int mod, int g) {
      long[] sum_e = new long[30];
      long[] es = new long[30];
      long[] ies = new long[30];
      int cnt2 = Integer.numberOfTrailingZeros(mod - 1);
      long e = pow(g, (mod - 1) >> cnt2, mod);
      long ie = pow(e, mod - 2, mod);
      for (int i = cnt2; i >= 2; i--) {
        es[i - 2] = e;
        ies[i - 2] = ie;
        e = e * e % mod;
        ie = ie * ie % mod;
      }
      long now = 1;
      for (int i = 0; i < cnt2 - 2; i++) {
        sum_e[i] = es[i] * now % mod;
        now = now * ies[i] % mod;
      }
      return sum_e;
    }

    /**
     * Pre-calculates the twiddle-factor table for the inverse NTT.
     *
     * @param mod NTT prime.
     * @param g   Primitive root of {@code mod}.
     * @return Pre-calculation table consumed by {@link #butterflyInv}.
     */
    private static long[] sumIE(int mod, int g) {
      long[] sum_ie = new long[30];
      long[] es = new long[30];
      long[] ies = new long[30];

      int cnt2 = Integer.numberOfTrailingZeros(mod - 1);
      long e = pow(g, (mod - 1) >> cnt2, mod);
      long ie = pow(e, mod - 2, mod);
      for (int i = cnt2; i >= 2; i--) {
        es[i - 2] = e;
        ies[i - 2] = ie;
        e = e * e % mod;
        ie = ie * ie % mod;
      }
      long now = 1;
      for (int i = 0; i < cnt2 - 2; i++) {
        sum_ie[i] = ies[i] * now % mod;
        now = now * es[i] % mod;
      }
      return sum_ie;
    }

    /**
     * In-place inverse NTT. The result is NOT divided by the transform length; callers scale by
     * {@code a.length^{-1}} themselves.
     *
     * @param a     Target array; its length must be a power of two and entries in [0, mod).
     * @param sumIE Pre-calculation table from {@link #sumIE}.
     * @param mod   NTT prime.
     */
    private static void butterflyInv(long[] a, long[] sumIE, int mod) {
      int n = a.length;
      int h = ceilPow2(n);

      for (int ph = h; ph >= 1; ph--) {
        int w = 1 << (ph - 1), p = 1 << (h - ph);
        long inow = 1;
        for (int s = 0; s < w; s++) {
          int offset = s << (h - ph + 1);
          for (int i = 0; i < p; i++) {
            long l = a[i + offset];
            long r = a[i + offset + p];
            a[i + offset] = (l + r) % mod;
            a[i + offset + p] = (mod + l - r) * inow % mod;
          }
          int x = Integer.numberOfTrailingZeros(~s);
          inow = inow * sumIE[x] % mod;
        }
      }
    }

    /**
     * In-place forward NTT.
     *
     * @param a    Target array; its length must be a power of two and entries in [0, mod).
     * @param sumE Pre-calculation table from {@link #sumE}.
     * @param mod  NTT prime.
     */
    private static void butterfly(long[] a, long[] sumE, int mod) {
      int n = a.length;
      int h = ceilPow2(n);

      for (int ph = 1; ph <= h; ph++) {
        int w = 1 << (ph - 1), p = 1 << (h - ph);
        long now = 1;
        for (int s = 0; s < w; s++) {
          int offset = s << (h - ph + 1);
          for (int i = 0; i < p; i++) {
            long l = a[i + offset];
            long r = a[i + offset + p] * now % mod;
            a[i + offset] = (l + r) % mod;
            a[i + offset + p] = (l - r + mod) % mod;
          }
          int x = Integer.numberOfTrailingZeros(~s);
          now = now * sumE[x] % mod;
        }
      }
    }

    /**
     * Returns a copy of {@code a} with every entry reduced into {@code [0, mod)}.
     *
     * @param a   Source array (not modified).
     * @param mod Modulus.
     * @return Reduced copy.
     */
    private static long[] reduce(long[] a, int mod) {
      long[] r = new long[a.length];
      for (int i = 0; i < a.length; i++)
        r[i] = Math.floorMod(a[i], (long) mod);
      return r;
    }

    /**
     * Convolution modulo an NTT-friendly prime.
     *
     * @param a   Target array 1 (not modified; any long values, reduced modulo {@code mod}).
     * @param b   Target array 2 (not modified; any long values, reduced modulo {@code mod}).
     * @param mod NTT prime.
     * @return Array of length {@code a.length + b.length - 1} (empty if either input is empty)
     *         holding the product coefficients in [0, mod).
     */
    public static long[] convolution(long[] a, long[] b, int mod) {
      int n = a.length;
      int m = b.length;
      if (n == 0 || m == 0)
        return new long[0];

      int z = 1 << ceilPow2(n + m - 1);
      // Zero-padded working copies. Inputs are reduced into [0, mod) first so that negative or
      // oversized coefficients cannot overflow the multiplications inside the butterflies.
      long[] fa = new long[z];
      long[] fb = new long[z];
      for (int i = 0; i < n; i++)
        fa[i] = Math.floorMod(a[i], (long) mod);
      for (int i = 0; i < m; i++)
        fb[i] = Math.floorMod(b[i], (long) mod);

      int g = primitiveRoot(mod);
      long[] sume = sumE(mod, g);
      long[] sumie = sumIE(mod, g);

      butterfly(fa, sume, mod);
      butterfly(fb, sume, mod);
      for (int i = 0; i < z; i++) {
        fa[i] = fa[i] * fb[i] % mod;
      }
      butterflyInv(fa, sumie, mod);

      // butterflyInv does not divide by the transform length, so scale by z^{-1} here.
      long[] result = java.util.Arrays.copyOf(fa, n + m - 1);
      long iz = pow(z, mod - 2, mod);
      for (int i = 0; i < result.length; i++)
        result[i] = result[i] * iz % mod;
      return result;
    }

    /**
     * Convolution modulo an arbitrary {@code int} modulus, using three NTT primes and Garner's
     * algorithm.
     *
     * <p>Inputs are first reduced into {@code [0, mod)}, so negative coefficients are handled. The
     * reconstruction is exact while {@code min(a.length, b.length) * (mod - 1)^2} stays below
     * {@code 754974721 * 167772161 * 469762049} (about 5.9e25).
     *
     * @param a   Target array 1 (not modified).
     * @param b   Target array 2 (not modified).
     * @param mod Any positive modulus.
     * @return Product coefficients modulo {@code mod} (empty if either input is empty).
     */
    public static long[] convolutionLL(long[] a, long[] b, int mod) {
      int n = a.length;
      int m = b.length;
      if (n == 0 || m == 0)
        return new long[0];

      long[] ra = reduce(a, mod);
      long[] rb = reduce(b, mod);

      int mod1 = 754974721;
      int mod2 = 167772161;
      int mod3 = 469762049;

      long[] c1 = convolution(ra, rb, mod1);
      long[] c2 = convolution(ra, rb, mod2);
      long[] c3 = convolution(ra, rb, mod3);

      int retSize = c1.length;
      long[] ret = new long[retSize];
      int[] mods = { mod1, mod2, mod3, mod };
      for (int i = 0; i < retSize; ++i) {
        ret[i] = garner(new long[] { c1[i], c2[i], c3[i] }, mods);
      }
      return ret;
    }

    /**
     * Naive convolution. (Complexity is O(N*M)!!)
     *
     * @param a   Target array 1 (not modified).
     * @param b   Target array 2 (not modified).
     * @param mod Modulus.
     * @return Product coefficients in [0, mod); empty if either input is empty, matching
     *         {@link #convolution}.
     */
    public static long[] convolutionNaive(long[] a, long[] b, int mod) {
      int n = a.length;
      int m = b.length;
      if (n == 0 || m == 0)
        return new long[0];
      long[] ra = reduce(a, mod);
      long[] rb = reduce(b, mod);
      long[] ret = new long[n + m - 1];
      for (int i = 0; i < n; i++) {
        for (int j = 0; j < m; j++) {
          ret[i + j] = (ret[i + j] + ra[i] * rb[j]) % mod;
        }
      }
      return ret;
    }
  }

  /**
   * Truncated formal power series arithmetic modulo {@code mod}, built on a caller-supplied
   * convolution. Every returned series has the length of the (first) input series.
   */
  static class FPS {
    private final int mod;
    private final BiFunction<long[], long[], long[]> conv;

    /**
     * @param mod  Modulus for all coefficients.
     * @param conv Full (untruncated) polynomial product modulo {@code mod}.
     */
    public FPS(int mod, BiFunction<long[], long[], long[]> conv) {
      this.mod = mod;
      this.conv = conv;
    }

    /** Returns {@code a^{-1} mod mod} via the extended Euclidean algorithm. */
    private long inv(long a) {
      long b = mod;
      long p = 1, q = 0;
      while (b > 0) {
        long c = a / b;
        long d;
        d = a;
        a = b;
        b = d % b;
        d = p;
        p = q;
        q = d - c * q;
      }
      return p < 0 ? p + mod : p;
    }

    /** Returns {@code f * g} truncated to {@code f.length} terms. */
    public long[] mul(long[] f, long[] g) {
      return Arrays.copyOf(conv.apply(f, g), f.length);
    }

    /** Returns {@code f + g} coefficient-wise; {@code g} must be at least as long as {@code f}. */
    public long[] add(long[] f, long[] g) {
      int k = f.length;
      long[] ret = new long[k];
      for (int i = 0; i < k; i++) {
        ret[i] = (f[i] + g[i]) % mod;
      }
      return ret;
    }

    /** Returns {@code f - g} coefficient-wise; {@code g} must be at least as long as {@code f}. */
    public long[] sub(long[] f, long[] g) {
      int k = f.length;
      long[] ret = new long[k];
      for (int i = 0; i < k; i++) {
        ret[i] = (f[i] - g[i] + mod) % mod;
      }
      return ret;
    }

    /**
     * Newton iteration driver: starting from the constant {@code g0}, repeatedly doubles the
     * precision by applying {@code rec(f mod x^n, g mod x^n)}.
     */
    private long[] limit(long[] f, long g0, BiFunction<long[], long[], long[]> rec) {
      int k = f.length;
      long[] g = { g0 };

      for (int m = 0; (1 << m) <= k; m++) {
        int n = 1 << (m + 1);
        long[] fn = new long[n];
        long[] gn = new long[n];

        for (int j = 0; j < n; j++) {
          fn[j] = j < f.length ? f[j] : 0;
          gn[j] = j < g.length ? g[j] : 0;
        }
        g = Arrays.copyOf(rec.apply(fn, gn), n);
      }
      return Arrays.copyOf(g, k);
    }

    /** Returns {@code 1 / f}; requires {@code f[0]} to be invertible. */
    public long[] inv(long[] f) {
      BiFunction<long[], long[], long[]> rec = (fn, gn) -> {
        // g <- g * (2 - f * g)
        long[] h = mul(fn, gn);
        for (int i = 0; i < fn.length; i++) {
          h[i] = ((i == 0 ? 2 : 0) + mod - h[i]) % mod;
        }
        return mul(gn, h);
      };
      return limit(f, inv(f[0]), rec);
    }

    /** Returns {@code exp(f)}; requires {@code f[0] == 0}. */
    public long[] exp(long[] f) {
      assert (f[0] == 0);
      BiFunction<long[], long[], long[]> rec = (fn, gn) -> {
        // g <- g * (1 + f - log g)
        int n = fn.length;
        long[] h = log(gn);
        for (int i = 0; i < n; i++) {
          h[i] = (fn[i] + (i == 0 ? 1 : 0) + mod - h[i]) % mod;
        }
        return mul(gn, Arrays.copyOf(h, gn.length));
      };
      long g0 = 1;
      return limit(f, g0, rec);
    }

    /** Returns the antiderivative of {@code f} with zero constant term, same length as f. */
    public long[] integral(long[] f) {
      int k = f.length;
      long[] ret = new long[k];
      for (int i = 0; i < k - 1; i++) {
        ret[i + 1] = f[i] * inv(i + 1) % mod;
      }
      return ret;
    }

    /** Returns the derivative of {@code f}, padded with a trailing zero to f's length. */
    public long[] differential(long[] f) {
      int k = f.length;
      long[] ret = new long[k];
      for (int i = 1; i < k; i++) {
        ret[i - 1] = f[i] * i % mod;
      }
      return ret;
    }

    /** Returns {@code log(f)}; requires {@code f[0] == 1}. */
    public long[] log(long[] f) {
      assert (f[0] == 1);
      return integral(mul(differential(f), inv(f)));
    }

    /**
     * Returns {@code f^n} via {@code exp(n * log f)}; requires {@code f[0] == 1}.
     *
     * <p>Only {@code n mod mod} enters the computation, so the exponent is reduced first. This
     * avoids {@code long} overflow for large {@code n} and gives the series inverse power for
     * negative {@code n}.
     */
    public long[] pow(long[] f, long n) {
      long[] log = log(f);
      long e = Math.floorMod(n, (long) mod);
      for (int i = 0; i < f.length; i++) {
        log[i] = log[i] * e % mod;
      }
      return exp(log);
    }

    /** Returns {@code f^n} (n >= 0) by binary exponentiation with truncated products. */
    public long[] powNotLog(long[] f, long n) {
      long[] ans = new long[f.length];
      if (ans.length == 0)
        return ans;
      ans[0] = 1;
      long[] base = f;
      while (n > 0) {
        if ((n & 1) == 1)
          ans = mul(ans, base);
        n >>= 1;
        // Square the base (not the accumulator) for the next bit.
        if (n > 0)
          base = mul(base, base);
      }
      return ans;
    }
  }

  /** Reads N and K from stdin and prints the answer for yukicoder No.2413. */
  private static void solve() {
    int n = ni();
    long k = nl();
    out.println(sumOverMultiplesOf99(n, k));
  }

  /**
   * Returns, modulo 998244353, the sum of {@code S(X)^k} over all integers {@code 0 <= X < 10^n}
   * divisible by 99, where {@code S(X)} is the decimal digit sum of X.
   *
   * <p>Why it works: 99 = 9 * 11, X = S(X) (mod 9), and X = A - B (mod 11), where A and B are the
   * digit sums over alternate positions (ceil(n/2) and floor(n/2) digits). So X is a multiple of
   * 99 iff A = B (mod 11) and A + B = 0 (mod 9).
   *
   * @param n Number of digits (leading zeros allowed).
   * @param k Non-negative exponent.
   * @return The sum modulo {@link #mod}.
   */
  static long sumOverMultiplesOf99(int n, long k) {
    int max = n * 9 + 1;
    int oddDigits = (n + 1) / 2;
    int evenDigits = n / 2;
    // Exact-degree generating functions of the digit sums of each half.
    long[] odd = f(oddDigits, 9 * oddDigits + 1);
    long[] even = f(evenDigits, 9 * evenDigits + 1);
    long[] ret = sameResidueConvolution(odd, even, max);

    long ans = 0;
    for (int s = 0; s < max; s += 9) {
      ans = (ans + pow(s, k, mod) * ret[s]) % mod;
    }
    return ans;
  }

  /**
   * Returns {@code ret[s] = sum of odd[a] * even[b]} over {@code a + b = s} with
   * {@code a = b (mod 11)}, for {@code 0 <= s < max}.
   *
   * <p>Both indices share a residue r, so writing a = 11t1 + r and b = 11t2 + r gives
   * s = 11(t1 + t2) + 2r. Each residue class is therefore convolved in compressed (stride-11)
   * form. That is about 11x less work than convolving zero-masked full-length arrays.
   */
  private static long[] sameResidueConvolution(long[] odd, long[] even, int max) {
    long[] ret = new long[max];
    for (int r = 0; r < 11; r++) {
      long[] co = residueClass(odd, r);
      long[] ce = residueClass(even, r);
      if (co.length == 0 || ce.length == 0)
        continue;
      long[] cur = Convolution.convolution(co, ce, mod);
      for (int t = 0; t < cur.length; t++) {
        int s = 11 * t + 2 * r;
        if (s >= max)
          break;
        ret[s] = (ret[s] + cur[t]) % mod;
      }
    }
    return ret;
  }

  /** Returns {@code a[r], a[r + 11], a[r + 22], ...} (empty when {@code r >= a.length}). */
  private static long[] residueClass(long[] a, int r) {
    int len = r < a.length ? (a.length - 1 - r) / 11 + 1 : 0;
    long[] part = new long[len];
    for (int t = 0; t < len; t++)
      part[t] = a[11 * t + r];
    return part;
  }

  /**
   * Modular exponentiation by repeated squaring.
   *
   * @param x Base (any sign; reduced into [0, m) first).
   * @param n Non-negative exponent.
   * @param m Modulus, {@code m >= 1}; intermediate products must fit in a long (m below ~3e9).
   * @return {@code x^n mod m} in {@code [0, m)}.
   */
  public static long pow(long x, long n, long m) {
    assert (n >= 0 && m >= 1);
    long base = Math.floorMod(x, m);
    long ans = 1 % m;
    while (n > 0) {
      if ((n & 1) == 1)
        ans = ans * base % m;
      base = base * base % m;
      n >>= 1;
    }
    return ans;
  }

  private static final int mod = 998244353;

  /**
   * Returns the first {@code max} coefficients of {@code (1 + x + ... + x^9)^n} modulo
   * {@link #mod}.
   *
   * <p>With P = 1 + x + ... + x^9 and f = P^n, differentiating gives f' P = n P' f. Comparing
   * coefficients yields, for s >= 1,
   * {@code s * f[s] = sum_{j=1..min(9,s)} ((n + 1) * j - s) * f[s - j]}.
   * This is O(9 * max), compared with the much heavier log/exp power-series route. Requires
   * {@code max <= mod} so that every s is invertible.
   */
  private static long[] f(int n, int max) {
    long[] g = new long[max];
    if (max == 0)
      return g;
    g[0] = 1;
    long[] inv = modInverses(max);
    long np1 = (n + 1L) % mod;
    for (int s = 1; s < max; s++) {
      long acc = 0;
      int lim = Math.min(9, s);
      for (int j = 1; j <= lim; j++) {
        long c = Math.floorMod(np1 * j - s, (long) mod);
        acc = (acc + c * g[s - j]) % mod;
      }
      g[s] = acc * inv[s] % mod;
    }
    return g;
  }

  /** Returns {@code inv[i] = i^{-1} mod mod} for {@code 1 <= i < size} (inv[0] is unused). */
  private static long[] modInverses(int size) {
    long[] inv = new long[Math.max(size, 2)];
    inv[1] = 1;
    for (int i = 2; i < size; i++) {
      inv[i] = (mod - (mod / i) * inv[mod % i] % mod) % mod;
    }
    return inv;
  }

  public static void main(String[] args) {
    // Run on a thread created with an explicit 64,000,000-byte stack-size request.
    new Thread(null, new Runnable() {
      @Override
      public void run() {
        solve();
        out.flush();
      }
    }, "", 64000000).start();
  }

  private static final PrintWriter out = new PrintWriter(System.out);
  private static StringTokenizer tokenizer = null;
  private static final BufferedReader reader =
      new BufferedReader(new InputStreamReader(System.in), 32768);

  /** Returns the next whitespace-separated token from stdin. */
  public static String next() {
    while (tokenizer == null || !tokenizer.hasMoreTokens()) {
      try {
        tokenizer = new java.util.StringTokenizer(reader.readLine());
      } catch (Exception e) {
        throw new RuntimeException(e);
      }
    }
    return tokenizer.nextToken();
  }

  private static double nd() {
    return Double.parseDouble(next());
  }

  private static long nl() {
    return Long.parseLong(next());
  }

  private static int[] na(int n) {
    int[] a = new int[n];
    for (int i = 0; i < n; i++)
      a[i] = ni();
    return a;
  }

  private static char[] ns() {
    return next().toCharArray();
  }

  private static long[] nal(int n) {
    long[] a = new long[n];
    for (int i = 0; i < n; i++)
      a[i] = nl();
    return a;
  }

  /** Reads an n-by-m table row by row. */
  private static int[][] ntable(int n, int m) {
    int[][] table = new int[n][m];
    for (int i = 0; i < n; i++) {
      for (int j = 0; j < m; j++) {
        table[i][j] = ni();
      }
    }
    return table;
  }

  /** Reads n rows of m values and stores them transposed (m lists of length n). */
  private static int[][] nlist(int n, int m) {
    int[][] table = new int[m][n];
    for (int i = 0; i < n; i++) {
      for (int j = 0; j < m; j++) {
        table[j][i] = ni();
      }
    }
    return table;
  }

  private static int ni() {
    return Integer.parseInt(next());
  }
}
0