結果
| 項目 | 内容 |
|---|---|
| 問題 | No.2891 Mint |
| コンテスト | |
| ユーザー | tonegawa |
| 提出日時 | 2024-09-13 22:42:17 |
| 言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
| 結果 | AC |
| 実行時間 | 50 ms / 2,000 ms |
| コード長 | 6,588 bytes |
| コンパイル時間 | 759 ms |
| コンパイル使用メモリ | 92,788 KB |
| 最終ジャッジ日時 | 2025-02-24 08:00:15 |
| ジャッジサーバーID (参考情報) | judge2 / judge5 |

(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 54 |
ソースコード
#include <iostream>
#include <cmath>
#include <cassert>
#include <algorithm>
#include <tuple>
using namespace std;
#include <type_traits>
// Returns x reduced modulo m into the range [0, m).
// The built-in % keeps the sign of x, so a negative remainder is shifted up by m.
// @param m `1 <= m`
constexpr long long safe_mod(long long x, long long m){
    const long long rem = x % m;
    return rem < 0 ? rem + m : rem;
}
// Computes x^n mod m by square-and-multiply; usable in constant expressions.
// Returns 0 when m == 1 (every value is congruent to 0).
// Operands stay below m < 2^31, so each product fits in unsigned long long.
// @param n `0 <= n`
// @param m `1 <= m`
constexpr long long pow_mod_constexpr(long long x, long long n, int m) {
    if (m == 1) return 0;
    const unsigned int um = static_cast<unsigned int>(m);
    unsigned long long base = safe_mod(x, m);
    unsigned long long acc = 1;
    for (; n != 0; n >>= 1) {
        if (n & 1) acc = acc * base % um;
        base = base * base % um;
    }
    return acc;
}
// Deterministic Miller-Rabin primality test for an int, usable at compile time.
// Uses the witnesses {2, 7, 61}, which are known to give exact answers for
// every n < 4,759,123,141 (so for all non-negative int values).
constexpr bool miller_rabin32_constexpr(int n) {
    if (n <= 1) return false;
    // If n equals a witness, that witness is 0 (mod n) and would wrongly report
    // n as composite; 2 would also be rejected by the even check below.
    if (n == 2 || n == 7 || n == 61) return true;
    if (n % 2 == 0) return false;
    // Strip the factors of two: n - 1 = d * 2^s with d odd.
    long long d = n - 1;
    while ((d & 1) == 0) d >>= 1;
    constexpr long long witnesses[3] = {2, 7, 61};
    for (int k = 0; k < 3; ++k) {
        long long e = d;
        long long y = pow_mod_constexpr(witnesses[k], e, n);
        // Square y = a^e until it hits n - 1, hits 1, or e reaches n - 1.
        for (; e != n - 1 && y != 1 && y != n - 1; e <<= 1) {
            y = y * y % n;
        }
        // Passing requires y == n - 1, or y == 1 with no squaring done (e still odd).
        if (y != n - 1 && e % 2 == 0) return false;
    }
    return true;
}
// Compile-time constant: true iff N is prime, as decided by miller_rabin32_constexpr.
template <int N>
constexpr bool miller_rabin32{miller_rabin32_constexpr(N)};
// Binary (Stein's) GCD of |_a| and |_b|. The result is non-negative, and gcd(0, 0) == 0.
// Requires -10^18 <= _a, _b <= 10^18 so that taking the absolute value cannot overflow.
long long gcd(long long _a, long long _b) {
    long long a = std::abs(_a);
    long long b = std::abs(_b);
    // Both values are non-negative here, so the sum is simply the non-zero operand (or 0).
    if (a == 0 || b == 0) return a + b;
    // Power of two shared by both operands; it is restored at the end.
    const int common_twos = __builtin_ctzll(a | b);
    a >>= __builtin_ctzll(a);
    // Invariant: a is odd. Make b odd, keep a <= b, then subtract.
    while (b != 0) {
        b >>= __builtin_ctzll(b);
        if (b < a) std::swap(a, b);
        b -= a;
    }
    return a << common_twos;
}
// Least common multiple of |a| and |b|. The result is at most |a| * |b|, which is
// why it is returned as __int128_t.
// Requires -10^18 <= a, b <= 10^18. Negative inputs are allowed; the result is
// always non-negative. If either argument is 0, the result is 0.
__int128_t lcm(long long a, long long b) {
    const long long abs_a = std::abs(a);
    const long long abs_b = std::abs(b);
    const long long g = gcd(abs_a, abs_b);
    // g == 0 only when both inputs are 0.
    return g == 0 ? __int128_t(0) : __int128_t(abs_a) * abs_b / g;
}
// Extended Euclidean algorithm.
// Returns {x, y, g} such that a*x + b*y == g == gcd(a, b), with g >= 0 for any
// signs of a and b. extgcd(0, 0) returns g == 0.
// Requires |a|, |b| <= 10^18 (the same range as gcd/lcm in this file).
// Every pair (x + k*(b/g), y - k*(a/g)), for any integer k, is also a solution
// (here a and b are the original arguments).
std::tuple<long long, long long, long long> extgcd(long long a, long long b) {
    // Let a0 and b0 be the original arguments. Loop invariants:
    //   a == u*a0 + v*b0   and   b == x*a0 + y*b0
    long long u = 1, v = 0;  // coefficients of the current a
    long long x = 0, y = 1;  // coefficients of the current b
    while (a != 0) {
        const long long q = b / a;
        // Replace (a, b) with (b - q*a, a), and update both coefficient pairs the same way.
        x -= q * u;
        std::swap(x, u);
        y -= q * v;
        std::swap(y, v);
        b -= q * a;
        std::swap(b, a);
    }
    // Truncating division leaves b equal to -gcd for some negative inputs,
    // e.g. (-12, -18) or (0, -5). Flip all three signs to honour g >= 0.
    if (b < 0) {
        x = -x;
        y = -y;
        b = -b;
    }
    return {x, y, b};
}
// Extended Euclid specialised for modular inverses; usable in constant expressions.
// @param b `1 <= b`
// @return pair(g, x) s.t. g = gcd(a, b), xa = g (mod b), 0 <= x < b/g
constexpr std::pair<long long, long long> inv_gcd(long long a, long long b) {
    a = safe_mod(a, b);
    if (a == 0) return {b, 0};
    // Invariants (mod b): s == m0 * a and t == m1 * a.
    long long s = b, t = a;
    long long m0 = 0, m1 = 1;
    while (t != 0) {
        const long long q = s / t;
        // One Euclid step: (s, t) <- (t, s - q*t); coefficients follow the same rule.
        // Written without std::swap, which is not constexpr until C++20.
        const long long next_t = s - t * q;
        const long long next_m1 = m0 - m1 * q;
        s = t;
        m0 = m1;
        t = next_t;
        m1 = next_m1;
    }
    // s is now gcd(a, b). Shift m0 into the documented range [0, b/g).
    if (m0 < 0) m0 += b / s;
    return {s, m0};
}
// Integer modulo a compile-time modulus m (1 <= m <= INT_MAX).
// The value is always kept in canonical form, 0 <= val() < m, as an unsigned int.
// inv() uses Fermat's little theorem when miller_rabin32<m> reports m as prime,
// and inv_gcd (extended Euclid) otherwise.
template <int m, std::enable_if_t<(1 <= m)>* = nullptr>
struct modint32_static {
    using mint = modint32_static;
public:
    // The modulus.
    static constexpr int mod() { return m; }

    // Wraps v without reducing it; the caller must guarantee 0 <= v < m.
    static mint raw(int v) {
        mint res;
        res._v = v;
        return res;
    }

    // Zero.
    modint32_static(): _v(0) {}

    // Reduces any integer-like value into [0, m). Negative values wrap around.
    template <class T>
    modint32_static(T v) {
        const long long rem = v % static_cast<long long>(umod());
        _v = rem < 0 ? rem + umod() : rem;
    }

    // The canonical representative in [0, m).
    unsigned int val() const { return _v; }

    mint& operator ++ () {
        _v = (_v + 1 == umod()) ? 0 : _v + 1;
        return *this;
    }
    mint& operator -- () {
        _v = (_v == 0 ? umod() : _v) - 1;
        return *this;
    }
    mint operator ++ (int) {
        mint old = *this;
        ++*this;
        return old;
    }
    mint operator -- (int) {
        mint old = *this;
        --*this;
        return old;
    }

    mint& operator += (const mint& rhs) {
        // Both operands are < m <= 2^31 - 1, so the sum cannot wrap an unsigned int.
        const unsigned int sum = _v + rhs._v;
        _v = sum >= umod() ? sum - umod() : sum;
        return *this;
    }
    mint& operator -= (const mint& rhs) {
        // If rhs._v > _v the unsigned difference wraps to a huge value (>= m);
        // adding m wraps it back into [0, m).
        const unsigned int diff = _v - rhs._v;
        _v = diff >= umod() ? diff + umod() : diff;
        return *this;
    }
    mint& operator *= (const mint& rhs) {
        const unsigned long long prod = static_cast<unsigned long long>(_v) * rhs._v;
        _v = static_cast<unsigned int>(prod % umod());
        return *this;
    }
    mint& operator /= (const mint& rhs) {
        *this *= rhs.inv();
        return *this;
    }
    mint operator + () const { return *this; }
    mint operator - () const {
        mint res;  // zero
        res -= *this;
        return res;
    }

    // this^n by square-and-multiply.
    // @param n `0 <= n` (asserted)
    mint pow(long long n) const {
        assert(0 <= n);
        mint acc = 1;
        mint base = *this;
        for (; n != 0; n >>= 1) {
            if (n & 1) acc *= base;
            base *= base;
        }
        return acc;
    }

    // Multiplicative inverse. Asserts that one exists: val() != 0 for prime m,
    // gcd(val(), m) == 1 otherwise.
    mint inv() const {
        if (!prime) {
            const auto g_x = inv_gcd(_v, m);
            assert(g_x.first == 1);
            return g_x.second;
        }
        assert(_v);
        return pow(umod() - 2);
    }

    friend mint operator + (const mint& lhs, const mint& rhs) { mint res = lhs; res += rhs; return res; }
    friend mint operator - (const mint& lhs, const mint& rhs) { mint res = lhs; res -= rhs; return res; }
    friend mint operator * (const mint& lhs, const mint& rhs) { mint res = lhs; res *= rhs; return res; }
    friend mint operator / (const mint& lhs, const mint& rhs) { mint res = lhs; res /= rhs; return res; }
    friend bool operator == (const mint& lhs, const mint& rhs) { return lhs._v == rhs._v; }
    friend bool operator != (const mint& lhs, const mint& rhs) { return !(lhs == rhs); }
private:
    unsigned int _v;  // canonical value in [0, m)
    static constexpr unsigned int umod() { return m; }
    static constexpr bool prime = miller_rabin32<m>;
};
// Stream insertion: writes the canonical value val() of `a` (in [0, m)).
template <int m>
std::ostream& operator<<(std::ostream& dest, const modint32_static<m>& a) {
    return dest << a.val();
}
// Aliases for the moduli 10^9 + 7 and 998244353 (both prime).
using modint1000000007 = modint32_static<1000000007>;
using modint998244353 = modint32_static<998244353>;
// Returns x%1 + x%2 + ... + x%n.
// @param 0 <= x, n (asserted)
// Uses x % i = x - i * floor(x / i) and groups the indices i by the value of
// floor(x / i), so it runs in O(sqrt(x)) time. For x, n <= 10^18 every
// intermediate value stays well inside __int128_t.
__int128_t mod_sum(__int128_t x, __int128_t n){
    assert(0 <= x && 0 <= n);
    if (x == 0) return 0;
    __int128_t ans = x * n;
    // For i > x, floor(x / i) == 0: those terms are just x and need no correction.
    if (n > x) n = x;
    // sq = floor(sqrt(x)), made exact with integer arithmetic. The floating-point
    // estimate can be off by one (for large x, or on platforms where long double
    // is only as precise as double). That would make the two phases below
    // overlap and subtract some indices twice.
    __int128_t sq = std::sqrt(static_cast<long double>(x));
    while (sq * sq > x) --sq;
    while ((sq + 1) * (sq + 1) <= x) ++sq;
    // Phase 1: for each quotient q in [1, sq], the indices with floor(x / i) == q
    // form the range [x / (q + 1) + 1, x / q]; clip it to i <= n.
    // The counter is __int128_t so it cannot overflow the way an int could.
    for (__int128_t q = 1; q <= sq; ++q) {
        const __int128_t l = x / (q + 1) + 1;
        const __int128_t r = std::min(n + 1, x / q + 1);  // exclusive bound
        if (l < r) {
            const __int128_t sum_lr = (r * (r - 1) - l * (l - 1)) / 2;  // l + ... + (r - 1)
            ans -= q * sum_lr;
        }
    }
    // Phase 2: the remaining indices i in [1, floor(x / (sq + 1))] have
    // floor(x / i) > sq and are handled one by one. floor(x / (sq + 1)) equals
    // sq - 1 when x < sq * (sq + 1) (equivalently x / sq == sq), and sq otherwise.
    const __int128_t last_small = (x / sq == sq) ? sq - 1 : sq;
    for (__int128_t i = std::min(n, last_small); i >= 1; --i) {
        ans -= i * (x / i);
    }
    return ans;
}
int main(){
    // Reads two integers: the number of terms first, then the dividend.
    // Prints (dividend % 1 + dividend % 2 + ... + dividend % terms) mod 998244353.
    long long terms, dividend;
    std::cin >> terms >> dividend;
    // mod_sum returns a non-negative value, so % already yields the canonical residue.
    const __int128_t total = mod_sum(dividend, terms) % 998244353;
    std::cout << static_cast<long long>(total) << '\n';
}
tonegawa