結果

問題 No.1552 Simple Dice Game
ユーザー yuruhiyayuruhiya
提出日時 2021-06-24 22:06:51
言語 Crystal
(1.11.2)
結果
AC  
実行時間 1,348 ms / 2,500 ms
コード長 7,960 bytes
コンパイル時間 13,502 ms
コンパイル使用メモリ 295,516 KB
実行使用メモリ 5,376 KB
最終ジャッジ日時 2024-06-25 07:34:12
合計ジャッジ時間 29,216 ms
ジャッジサーバーID
(参考情報)
judge2 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 2 ms
5,248 KB
testcase_01 AC 2 ms
5,376 KB
testcase_02 AC 2 ms
5,376 KB
testcase_03 AC 1,320 ms
5,376 KB
testcase_04 AC 1 ms
5,376 KB
testcase_05 AC 2 ms
5,376 KB
testcase_06 AC 2 ms
5,376 KB
testcase_07 AC 1 ms
5,376 KB
testcase_08 AC 2 ms
5,376 KB
testcase_09 AC 1,030 ms
5,376 KB
testcase_10 AC 1,198 ms
5,376 KB
testcase_11 AC 1,206 ms
5,376 KB
testcase_12 AC 339 ms
5,376 KB
testcase_13 AC 1,014 ms
5,376 KB
testcase_14 AC 588 ms
5,376 KB
testcase_15 AC 742 ms
5,376 KB
testcase_16 AC 102 ms
5,376 KB
testcase_17 AC 165 ms
5,376 KB
testcase_18 AC 923 ms
5,376 KB
testcase_19 AC 1,335 ms
5,376 KB
testcase_20 AC 1,328 ms
5,376 KB
testcase_21 AC 1,335 ms
5,376 KB
testcase_22 AC 1,324 ms
5,376 KB
testcase_23 AC 1,348 ms
5,376 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

# require "/math/Mint"
# require "../atcoder/src/Math"
# ac-library.cr by hakatashi https://github.com/google/ac-library.cr
#
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

module AtCoder
  # Implements [ACL's Math library](https://atcoder.github.io/ac-library/master/document_en/math.html)
  module Math
    # Runs the extended Euclidean algorithm on `a.abs` and `b.abs`.
    #
    # Returns a pair `{g, x}`: `g` is the greatest common divisor of the two
    # magnitudes and `x` is a Bezout coefficient for `a` (sign-corrected when
    # `a` is negative), so that `a * x` is congruent to `g` modulo `b`.
    # The coefficient for `b` is not needed by any caller and is not computed.
    def self.extended_gcd(a, b)
      last_remainder, remainder = a.abs, b.abs
      x, last_x = 0_i64, 1_i64
      while remainder != 0
        new_last_remainder = remainder
        quotient, remainder = last_remainder.divmod(remainder)
        last_remainder = new_last_remainder
        x, last_x = last_x - quotient * x, x
      end

      return last_remainder, last_x * (a < 0 ? -1 : 1)
    end

    # Implements atcoder::inv_mod(value, modulo).
    #
    # Returns the multiplicative inverse of `value` modulo `modulo`, reduced
    # with `%` (non-negative for a positive `modulo`).
    # Raises `ArgumentError` when `value` and `modulo` are not coprime.
    def self.inv_mod(value, modulo)
      gcd, inv = extended_gcd(value, modulo)
      if gcd != 1
        raise ArgumentError.new("#{value} and #{modulo} are not coprime")
      end
      inv % modulo
    end

    # Simplified AtCoder::Math.pow_mod with support of Int64
    #
    # Computes `base ** exponent` modulo `modulo` by binary exponentiation.
    # A negative `exponent` is handled by first replacing `base` with its
    # modular inverse (so `inv_mod` may raise `ArgumentError`).
    # A `base` equal to 0 is returned unchanged for any non-zero exponent.
    def self.pow_mod(base, exponent, modulo)
      if exponent == 0
        # `1 % modulo` keeps the result inside `0...modulo`; in particular the
        # answer is 0 (not 1) when `modulo` is 1.
        return (base.class.zero + 1) % modulo
      end
      if base == 0
        return base
      end
      b = exponent > 0 ? base : inv_mod(base, modulo)
      e = exponent.abs
      ret = 1_i64
      while e > 0
        if e % 2 == 1
          ret = mul_mod(ret, b, modulo)
        end
        b = mul_mod(b, b, modulo)
        e //= 2
      end
      ret
    end

    # Calculates a * b % mod without overflowing Int64 in the intermediate product.
    #
    # `mod` is expected to be positive. Both operands are reduced modulo `mod`
    # up front, so negative or oversized operands are accepted; the result is
    # in `0...mod`.
    @[AlwaysInline]
    def self.mul_mod(a : Int64, b : Int64, mod : Int64)
      # Normalise into 0...mod: this keeps the fast path's product below 2**62
      # and guarantees the unsigned conversions below never see a negative value.
      a %= mod
      b %= mod

      if mod < Int32::MAX
        return a * b % mod
      end

      # Wide path: schoolbook multiplication on 32-bit halves, producing a
      # 128-bit product that is reduced once at the end.

      # 31-bit width
      a_high = (a >> 32).to_u64
      # 32-bit width
      a_low = (a & 0xFFFFFFFF).to_u64
      # 31-bit width
      b_high = (b >> 32).to_u64
      # 32-bit width
      b_low = (b & 0xFFFFFFFF).to_u64

      # 31-bit + 32-bit + 1-bit = 64-bit
      c = a_high * b_low + b_high * a_low
      c_high = c >> 32
      c_low = c & 0xFFFFFFFF

      # 31-bit + 31-bit
      res_high = a_high * b_high + c_high
      # 32-bit + 32-bit
      res_low = a_low * b_low
      res_low_high = res_low >> 32
      res_low_low = res_low & 0xFFFFFFFF

      # Overflow: carry out of the middle 32-bit column into the high word
      if res_low_high + c_low >= 0x100000000
        res_high += 1
      end

      res_low = (((res_low_high + c_low) & 0xFFFFFFFF) << 32) | res_low_low

      (((res_high.to_i128 << 64) | res_low) % mod).to_i64
    end

    # Fallback for operands that are not all Int64: multiplies after widening
    # `a` to Int64 and converts the reduced product back to the type of `mod`.
    @[AlwaysInline]
    def self.mul_mod(a, b, mod)
      typeof(mod).new(a.to_i64 * b % mod)
    end

    # Implements atcoder::crt(remainders, modulos).
    #
    # Merges the congruences `x = remainders[i] (mod modulos[i])` one by one.
    # Returns `{answer, total_modulo}` with `answer` reduced modulo
    # `total_modulo`, or `{0, 0}` as soon as a congruence contradicts the ones
    # merged so far.
    # Raises `ArgumentError` when the two collections differ in size.
    def self.crt(remainders, modulos)
      raise ArgumentError.new unless remainders.size == modulos.size

      total_modulo = 1_i64
      answer = 0_i64

      remainders.zip(modulos).each do |(remainder, modulo)|
        gcd, p = extended_gcd(total_modulo, modulo)
        # Solvable only if the required adjustment is a multiple of the gcd.
        if (remainder - answer) % gcd != 0
          return 0_i64, 0_i64
        end
        # Number of `total_modulo` steps needed to also satisfy this congruence.
        tmp = (remainder - answer) // gcd * p % (modulo // gcd)
        answer += total_modulo * tmp
        total_modulo *= modulo // gcd
      end

      return answer % total_modulo, total_modulo
    end

    # Implements atcoder::floor_sum(n, m, a, b).
    #
    # All arguments are converted to Int64. Negative `a` and `b` are first
    # shifted to their non-negative residues modulo `m`, with the removed
    # multiples of `m` subtracted from the result, before delegating to
    # `floor_sum_unsigned`.
    def self.floor_sum(n, m, a, b)
      n, m, a, b = n.to_i64, m.to_i64, a.to_i64, b.to_i64
      res = 0_i64

      if a < 0
        a2 = a % m
        res -= n * (n - 1) // 2 * ((a2 - a) // m)
        a = a2
      end

      if b < 0
        b2 = b % m
        res -= n * ((b2 - b) // m)
        b = b2
      end

      res + floor_sum_unsigned(n, m, a, b)
    end

    # Iterative core of `floor_sum` for non-negative `a` and `b`.
    #
    # Each round strips whole multiples of `m` from `a` and `b`, stops once
    # `a * n + b` is below `m`, and otherwise continues with the roles of `a`
    # and `m` exchanged.
    private def self.floor_sum_unsigned(n, m, a, b)
      res = 0_i64

      loop do
        if a >= m
          res += n * (n - 1) // 2 * (a // m)
          a = a % m
        end

        if b >= m
          res += n * (b // m)
          b = b % m
        end

        y_max = a * n + b
        break if y_max < m

        n = y_max // m
        b = y_max % m
        m, a = a, m
      end

      res
    end
  end
end

# Defines a modular-integer struct called `name` whose value is an Int64
# residue modulo `mod`, and reopens `Int` and `String` to add a conversion
# method plus mixed-type arithmetic.
#
# The conversion method is named `to_` followed by the lower-cased struct name
# with `mint` or `modint` replaced by `m` (for example `Mint` gives `to_m`).
macro static_modint(name, mod)
  struct {{name}}
    MOD = Int64.new({{mod}})

    # Additive identity (residue 0).
    def self.zero
      new
    end

    # Wraps `value` without reducing it; the caller must pass a residue that
    # is already in range.
    def self.raw(value : Int64)
      result = new
      result.value = value
      result
    end

    getter value : Int64

    def initialize
      @value = 0i64
    end

    # Converts `value` with `to_i64` and reduces it with `%`.
    def initialize(value)
      @value = value.to_i64 % MOD
    end

    def initialize(m : self)
      @value = m.value
    end

    protected def value=(value : Int64)
      @value = value
    end

    def ==(m : self)
      value == m.value
    end

    # Compares the stored residue with `m` directly; `m` is not reduced first.
    def ==(m)
      value == m
    end

    def + : self
      self
    end

    def - : self
      self.class.raw(value != 0 ? MOD &- value : 0i64)
    end

    def +(v)
      self + self.class.new(v)
    end

    def +(m : self)
      x = value &+ m.value
      x &-= MOD if x >= MOD
      self.class.raw(x)
    end

    def -(v)
      self - self.class.new(v)
    end

    def -(m : self)
      x = value &- m.value
      x &+= MOD if x < 0
      self.class.raw(x)
    end

    def *(v)
      self * self.class.new(v)
    end

    def *(m : self)
      self.class.new(value &* m.value)
    end

    def /(v)
      self / self.class.new(v)
    end

    # Modular division: multiplies by a coefficient obtained from the extended
    # Euclidean algorithm run on `m` and MOD.
    # Raises `DivisionByZeroError` when `m` is the zero residue.
    def /(m : self)
      raise DivisionByZeroError.new if m.value == 0
      a, b, u, v = m.to_i64, MOD, 1i64, 0i64
      while b != 0
        t = a // b
        a &-= t &* b
        a, b = b, a
        u &-= t &* v
        u, v = v, u
      end
      self.class.new(value &* u)
    end

    def //(v)
      self / v
    end

    # Binary exponentiation.
    #
    # A negative `exponent` raises the modular inverse of `self` to
    # `exponent.abs` (so the zero residue raises `DivisionByZeroError`);
    # previously a negative exponent silently produced 1.
    def **(exponent : Int)
      t = exponent < 0 ? self.class.raw(1i64) / self : self
      e = exponent.abs
      res = self.class.raw(1i64)
      while e > 0
        res *= t if e & 1 == 1
        t *= t
        e >>= 1
      end
      res
    end

    # Ordering comparisons are deliberately unsupported and always raise.
    {% for op in %w[< <= > >=] %}
      def {{op.id}}(other)
        raise NotImplementedError.new({{op}})
      end
    {% end %}

    # Modular inverse computed by `AtCoder::Math.inv_mod`.
    def inv
      self.class.raw AtCoder::Math.inv_mod(value, MOD)
    end

    # Next residue, wrapping from MOD - 1 to 0.
    def succ
      self.class.raw(value != MOD &- 1 ? value &+ 1 : 0i64)
    end

    # Previous residue, wrapping from 0 to MOD - 1.
    def pred
      self.class.raw(value != 0 ? value &- 1 : MOD &- 1)
    end

    def abs
      self
    end

    def to_i64 : Int64
      value
    end

    delegate to_s, to: @value
    delegate inspect, to: @value
  end

  {% to = ("to_" + name.stringify.downcase.gsub(/mint|modint/, "m")).id %}

  struct Int
    # Arithmetic with a plain integer on the left converts it to the modular
    # type first.
    {% for op in %w[+ - * / //] %}
      def {{op.id}}(value : {{name}})
        {{to}} {{op.id}} value
      end
    {% end %}

    # Ordering against a modular value is unsupported and always raises.
    {% for op in %w[< <= > >=] %}
      def {{op.id}}(m : {{name}})
        raise NotImplementedError.new({{op}})
      end
    {% end %}

    def {{to}} : {{name}}
      {{name}}.new(self)
    end
  end

  class String
    def {{to}} : {{name}}
      {{name}}.new(self)
    end
  end
end

static_modint(Mint, 10**9 + 7)
static_modint(Mint2, 998244353)

# Reads `n` and `m` from one input line and prints, modulo 998244353,
# (total weighted by the largest value) - (total weighted by the smallest value).
# NOTE(review): presumably this sums `sum * (max - min)` over every length-n
# sequence with entries in 1..m - confirm against the problem statement.
n, m = read_line.split.map(&.to_i64)

# Contribution of all sequences whose largest value is `x`.
max_weighted = (1..m).sum do |x|
  top = Mint2.new(x)
  # Number of (position, sequence) pairs where that position holds `x` and
  # every entry is at most `x`.
  at_top = (top ** (n - 1)) * n
  # With x == 1 the only sequence is all ones; multiplying by 1 twice is a no-op.
  next at_top if x == 1

  # Pairs whose position holds one particular value below `x`.
  below = (((top ** n) - (top.pred ** n)) * n - at_top) // (x - 1)
  # 1 + 2 + ... + (x - 1), computed in the modular type.
  lower_value_sum = top * top.pred // 2
  top * lower_value_sum * below + top * top * at_top
end

# Contribution of all sequences whose smallest value is `x`.
min_weighted = (1..m).sum do |x|
  # How many values lie in x..m.
  choices = m + 1 - x
  low = Mint2.new(x)
  span = Mint2.new(choices)
  at_low = (span ** (n - 1)) * n
  next at_low * m * m if x == m

  above = (((span ** n) - (span.pred ** n)) * n - at_low) // (choices - 1)
  # (x + 1) + ... + m as a plain Int64; it is reduced when multiplied below.
  higher_value_sum = (x.to_i64 + 1..m.to_i64).sum
  low * higher_value_sum * above + low * low * at_low
end

puts max_weighted - min_weighted
0