結果

問題 No.1595 The Final Digit
ユーザー ansainansain
提出日時 2021-07-09 21:36:27
言語 Python3
(3.12.2 + numpy 1.26.4 + scipy 1.12.0)
結果
AC  
実行時間 548 ms / 2,000 ms
コード長 2,931 bytes
コンパイル時間 135 ms
コンパイル使用メモリ 12,800 KB
実行使用メモリ 44,468 KB
最終ジャッジ日時 2024-07-01 15:27:25
合計ジャッジ時間 13,756 ms
ジャッジサーバーID
(参考情報)
judge3 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 527 ms
44,220 KB
testcase_01 AC 523 ms
44,212 KB
testcase_02 AC 525 ms
44,468 KB
testcase_03 AC 532 ms
43,956 KB
testcase_04 AC 522 ms
44,464 KB
testcase_05 AC 545 ms
44,460 KB
testcase_06 AC 532 ms
44,212 KB
testcase_07 AC 525 ms
44,464 KB
testcase_08 AC 530 ms
43,824 KB
testcase_09 AC 526 ms
44,336 KB
testcase_10 AC 537 ms
44,088 KB
testcase_11 AC 526 ms
44,220 KB
testcase_12 AC 533 ms
44,464 KB
testcase_13 AC 548 ms
44,464 KB
testcase_14 AC 543 ms
44,332 KB
testcase_15 AC 534 ms
44,468 KB
testcase_16 AC 531 ms
43,828 KB
testcase_17 AC 541 ms
44,332 KB
testcase_18 AC 543 ms
44,460 KB
testcase_19 AC 539 ms
44,204 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import numpy as np
import sys
from collections import defaultdict, Counter, deque
from itertools import permutations, combinations, product, combinations_with_replacement, groupby, accumulate
import operator
from math import sqrt, gcd, factorial
# from math import isqrt, prod,comb  # python3.8用(notpypy)
#from bisect import bisect_left,bisect_right
#from functools import lru_cache,reduce
#from heapq import heappush,heappop,heapify,heappushpop,heapreplace
#import numpy as np
#import networkx as nx
#from networkx.utils import UnionFind
#from numba import njit, b1, i1, i4, i8, f8
#from scipy.sparse import csr_matrix
#from scipy.sparse.csgraph import shortest_path, floyd_warshall, dijkstra, bellman_ford, johnson, NegativeCycleError
# numba例 @njit(i1(i4[:], i8[:, :]),cache=True) 引数i4配列、i8 2次元配列,戻り値i1
def input(): return sys.stdin.readline().rstrip()
def divceil(n, k): return 1+(n-1)//k  # n/kの切り上げを返す
def yn(hantei, yes='Yes', no='No'): print(yes if hantei else no)


def dot(a, b, n, mod=10**9+7):
    """
    正方行列積a@bの10**9+7modをオーバーフロー回避するために分割して掛け算する
    """
    ans = np.zeros_like(a, dtype=np.int64)
    newn = ((n//8+1)*8)
    a2 = np.zeros((newn, newn), dtype=np.int64)
    a2[0:n, 0:n] += a
    b2 = np.zeros((newn, newn), dtype=np.int64)
    b2[0:n, 0:n] += b
    ans = np.zeros_like(a2, dtype=np.int64)
    for k in range(newn//8):
        for i in range(newn//8):
            for j in range(newn//8):
                ans[i*8:i*8+8, j*8:j*8+8] += a2[i*8:i*8+8,
                                                k*8:k*8+8] @ b2[k*8:k*8+8, j*8:j*8+8]
                ans[i*8:i*8+8, j*8:j*8+8] %= mod
    return ans[0:n, 0:n]


def dot2(mat1, mat2, MOD):
    """
    行列積a@bの10**9+7modをオーバーフロー回避するために上下15bitで分割して掛け算する 
    https://ikatakos.com/pot/programming_algorithm/python_tips/avoid_overflow
    """
    mask = (1 << 15) - 1
    mat1h = mat1 >> 15
    mat1l = mat1 & mask
    mat2h = mat2 >> 15
    mat2l = mat2 & mask
    mathh = mat1h @ mat2h % MOD
    matll = mat1l @ mat2l % MOD
    mathl = (mathh + matll - (mat1h - mat1l) @ (mat2h - mat2l)) % MOD
    res = (mathh << 30) + (mathl << 15) + matll
    res %= MOD
    return res


def matrix_multiplication(a, n, k, mod=10**9+7):  # n次正方行列a^k
    ans = np.eye(n, dtype=np.int64)
    while k:
        k, i = divmod(k, 2)
        if i:
            ans = dot2(ans, a, mod)
        a = dot2(a, a, mod)
    return ans


def main():
    mod = 10**9+7
    mod2 = 998244353
    p, q, r, k = map(int, input().split())
    mat = np.zeros((3, 3), dtype=np.int64)
    mat[1][0] = 1
    mat[2][1] = 1
    mat[0][2] = 1
    mat[1][2] = 1
    mat[2][2] = 1
    mat=matrix_multiplication(mat,3,k-1,10)
    print(dot2(np.array([p,q,r],dtype=np.int64),mat,10)[0])


if __name__ == '__main__':
    main()
0