結果

問題 No.2876 Infection
ユーザー kainadekainade
提出日時 2024-09-12 08:58:23
言語 PyPy3 (7.3.15)
結果
AC  
実行時間 1,198 ms / 2,000 ms
コード長 29,713 bytes
コンパイル時間 334 ms
コンパイル使用メモリ 82,528 KB
実行使用メモリ 97,776 KB
最終ジャッジ日時 2024-09-12 08:58:37
合計ジャッジ時間 13,827 ms
ジャッジサーバーID
(参考情報)
judge3 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 83 ms
77,180 KB
testcase_01 AC 81 ms
77,640 KB
testcase_02 AC 83 ms
77,616 KB
testcase_03 AC 81 ms
77,236 KB
testcase_04 AC 82 ms
77,128 KB
testcase_05 AC 82 ms
77,548 KB
testcase_06 AC 110 ms
77,180 KB
testcase_07 AC 1,198 ms
97,592 KB
testcase_08 AC 1,194 ms
97,732 KB
testcase_09 AC 1,176 ms
97,776 KB
testcase_10 AC 323 ms
82,100 KB
testcase_11 AC 304 ms
81,872 KB
testcase_12 AC 280 ms
81,884 KB
testcase_13 AC 407 ms
83,980 KB
testcase_14 AC 1,071 ms
94,908 KB
testcase_15 AC 80 ms
77,608 KB
testcase_16 AC 164 ms
79,488 KB
testcase_17 AC 86 ms
78,168 KB
testcase_18 AC 81 ms
77,088 KB
testcase_19 AC 128 ms
79,504 KB
testcase_20 AC 767 ms
85,696 KB
testcase_21 AC 128 ms
79,172 KB
testcase_22 AC 801 ms
87,276 KB
testcase_23 AC 413 ms
83,764 KB
testcase_24 AC 327 ms
82,180 KB
testcase_25 AC 806 ms
86,592 KB
testcase_26 AC 124 ms
79,244 KB
testcase_27 AC 346 ms
83,028 KB
testcase_28 AC 133 ms
79,748 KB
testcase_29 AC 1,035 ms
92,732 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

def main():
    """Solve yukicoder No.2876 "Infection" and print the answer modulo MOD.

    Reads one line "N x" via MI() and prints one integer. Relies on
    module-level names defined later in this file: MI, MOD, Comb, FFT and
    sum_table.
    """
    N,x=MI()

    # Degenerate percentages are answered directly.
    if x==0:return print(1)
    if x==100:return print(N)


    # nume = 1/100 (mod MOD), so p = x/100 and q = (100-x)/100 as residues.
    nume=pow(100,-1,MOD)
    p,q=x*nume%MOD,(100-x)*nume%MOD

    # Factorial / inverse-factorial tables up to N+100 (comb itself is unused below).
    C=Comb(N+100,MOD)
    comb,fac,finv=C.comb,C.fac,C.finv
    fft=FFT(MOD)

    # powp[k] = p^k mod MOD for k = 0..N+99.
    powp=[1]
    for _ in range(N+99):
        powp.append(powp[-1]*p%MOD)

    # dp is an (N+1) x (N+1) table seeded with dp[1][N-1] = 1.
    # NOTE(review): presumably the first index counts infected people / rounds
    # and the second the people not yet infected — confirm against the statement.
    dp=[[0]*(N+1)for _ in range(N+1)]
    dp[1][N-1]=1
    # subdp[i][k] = dp[i][k] * p^k * k!  (pre-scaled convolution input).
    subdp=[[0]*(N+1)for _ in range(N+1)]
    subdp[1][N-1]=powp[N-1]*fac[N-1]%MOD
    
    # blue[k] = (q/p)^k / k!  (p is invertible because x != 0 here).
    blue=[pow(q*pow(p,-1,MOD)%MOD,k,MOD)*finv[k]%MOD for k in range(N+10)]

    # kf[t] = 1/(N-t)!, so res[N+k] = sum_{j>=k} subdp[i][j] / (j-k)!.
    # Combined with blue[k] this gives the binomial transition
    #   dp[i+1][k] = sum_{j>=k} dp[i][j] * C(j,k) * p^(j-k) * q^k.
    kf=finv[:N+1][::-1]
    for i in range(1,N):
        # Fill row dp[i+1] (only columns k < N-i are produced).
        res=fft.convolution(subdp[i],kf)
        for k in range(N-i):
            dp[i+1][k]=blue[k]*res[N+k]%MOD
        for k in range(N-i):
            subdp[i+1][k]=dp[i+1][k]*powp[k]%MOD*fac[k]%MOD
    # The answer is the sum of every dp entry, reduced once at the end.
    print(sum_table(dp)%MOD)






# https://github.com/shakayami/ACL-for-python/blob/master/convolution.py
class FFT():
    """Number-theoretic-transform (NTT) convolution modulo a prime.

    Port of the AtCoder Library convolution (URL in the comment above).
    All arithmetic is modulo ``self.mod``; the modulus is expected to be an
    NTT-friendly prime such as 998244353 (assumption — nothing checks it).
    """
    def primitive_root_constexpr(self,m):
        """Return a primitive root modulo m.

        Known NTT primes are answered from a table. Otherwise the distinct
        prime factors of m-1 are collected in ``divs`` and g = 2, 3, ... is
        tried until g^((m-1)/d) != 1 (mod m) for every prime factor d — the
        standard test, valid when m is prime.
        """
        if m==2:return 1
        if m==167772161:return 3
        if m==469762049:return 3
        if m==754974721:return 11
        if m==998244353:return 3
        divs=[0]*20
        divs[0]=2
        cnt=1
        # Remove all factors of 2 from (m-1)/2; 2 is already recorded in divs[0].
        x=(m-1)//2
        while(x%2==0):x//=2
        # Trial division by odd i records the remaining distinct prime factors.
        i=3
        while(i*i<=x):
            if (x%i==0):
                divs[cnt]=i
                cnt+=1
                while(x%i==0):
                    x//=i
            i+=2
        if x>1:
            divs[cnt]=x
            cnt+=1
        g=2
        while(1):
            ok=True
            for i in range(cnt):
                if pow(g,(m-1)//divs[i],m)==1:
                    ok=False
                    break
            if ok:
                return g
            g+=1
    def bsf(self,x):
        """Return the number of trailing zero bits of x (x must be non-zero: 0 never terminates)."""
        res=0
        while(x%2==0):
            res+=1
            x//=2
        return res
    # Class-level defaults; __init__ overwrites all of them per instance.
    rank2=0
    root=[]
    iroot=[]
    rate2=[]
    irate2=[]
    rate3=[]
    irate3=[]
    
    def __init__(self,MOD):
        """Precompute root-of-unity tables for modulus MOD.

        rank2 is the exponent of 2 in MOD-1. root[i] has multiplicative order
        2^i (root[rank2] = g^((MOD-1)>>rank2), then repeated squaring) and
        iroot[i] is its inverse. rate2/irate2 and rate3/irate3 are the twiddle
        increments used by the radix-2 and radix-4 passes.
        """
        self.mod=MOD
        self.g=self.primitive_root_constexpr(self.mod)
        self.rank2=self.bsf(self.mod-1)
        self.root=[0 for i in range(self.rank2+1)]
        self.iroot=[0 for i in range(self.rank2+1)]
        self.rate2=[0 for i in range(self.rank2)]
        self.irate2=[0 for i in range(self.rank2)]
        self.rate3=[0 for i in range(self.rank2-1)]
        self.irate3=[0 for i in range(self.rank2-1)]
        self.root[self.rank2]=pow(self.g,(self.mod-1)>>self.rank2,self.mod)
        self.iroot[self.rank2]=pow(self.root[self.rank2],self.mod-2,self.mod)
        for i in range(self.rank2-1,-1,-1):
            self.root[i]=(self.root[i+1]**2)%self.mod
            self.iroot[i]=(self.iroot[i+1]**2)%self.mod
        prod=1;iprod=1
        for i in range(self.rank2-1):
            self.rate2[i]=(self.root[i+2]*prod)%self.mod
            self.irate2[i]=(self.iroot[i+2]*iprod)%self.mod
            prod=(prod*self.iroot[i+2])%self.mod
            iprod=(iprod*self.root[i+2])%self.mod
        prod=1;iprod=1
        for i in range(self.rank2-2):
            self.rate3[i]=(self.root[i+3]*prod)%self.mod
            self.irate3[i]=(self.iroot[i+3]*iprod)%self.mod
            prod=(prod*self.iroot[i+3])%self.mod
            iprod=(iprod*self.root[i+3])%self.mod
    def butterfly(self,a):
        """In-place forward NTT of list a.

        len(a) is assumed to be a power of two (convolution pads to one).
        Levels are processed two at a time with radix-4 butterflies, plus one
        radix-2 level when an odd number of levels remains. The element order
        it leaves behind is the one butterfly_inv consumes.
        """
        n=len(a)
        h=(n-1).bit_length()
        
        LEN=0
        while(LEN<h):
            if (h-LEN==1):
                # Single remaining level: radix-2 pass.
                p=1<<(h-LEN-1)
                rot=1
                for s in range(1<<LEN):
                    offset=s<<(h-LEN)
                    for i in range(p):
                        l=a[i+offset]
                        r=a[i+offset+p]*rot
                        a[i+offset]=(l+r)%self.mod
                        a[i+offset+p]=(l-r)%self.mod
                    # (~s & -~s).bit_length()-1 is the index of the lowest 0 bit of s.
                    rot*=self.rate2[(~s & -~s).bit_length()-1]
                    rot%=self.mod
                LEN+=1
            else:
                # Two levels at once: radix-4 pass; imag = root[2] has order 4.
                p=1<<(h-LEN-2)
                rot=1
                imag=self.root[2]
                for s in range(1<<LEN):
                    rot2=(rot*rot)%self.mod
                    rot3=(rot2*rot)%self.mod
                    offset=s<<(h-LEN)
                    for i in range(p):
                        a0=a[i+offset]
                        a1=a[i+offset+p]*rot
                        a2=a[i+offset+2*p]*rot2
                        a3=a[i+offset+3*p]*rot3
                        a1na3imag=(a1-a3)%self.mod*imag
                        a[i+offset]=(a0+a2+a1+a3)%self.mod
                        a[i+offset+p]=(a0+a2-a1-a3)%self.mod
                        a[i+offset+2*p]=(a0-a2+a1na3imag)%self.mod
                        a[i+offset+3*p]=(a0-a2-a1na3imag)%self.mod
                    rot*=self.rate3[(~s & -~s).bit_length()-1]
                    rot%=self.mod
                LEN+=2
                
    def butterfly_inv(self,a):
        """In-place inverse NTT of list a, undoing butterfly.

        The final division by len(a) is NOT done here; convolution multiplies
        by the inverse of the length afterwards.
        """
        n=len(a)
        h=(n-1).bit_length()
        LEN=h
        while(LEN):
            if (LEN==1):
                # Single remaining level: inverse radix-2 pass.
                p=1<<(h-LEN)
                irot=1
                for s in range(1<<(LEN-1)):
                    offset=s<<(h-LEN+1)
                    for i in range(p):
                        l=a[i+offset]
                        r=a[i+offset+p]
                        a[i+offset]=(l+r)%self.mod
                        a[i+offset+p]=(l-r)*irot%self.mod
                    irot*=self.irate2[(~s & -~s).bit_length()-1]
                    irot%=self.mod
                LEN-=1
            else:
                # Two levels at once: inverse radix-4 pass.
                p=1<<(h-LEN)
                irot=1
                iimag=self.iroot[2]
                for s in range(1<<(LEN-2)):
                    irot2=(irot*irot)%self.mod
                    irot3=(irot*irot2)%self.mod
                    offset=s<<(h-LEN+2)
                    for i in range(p):
                        a0=a[i+offset]
                        a1=a[i+offset+p]
                        a2=a[i+offset+2*p]
                        a3=a[i+offset+3*p]
                        a2na3iimag=(a2-a3)*iimag%self.mod
                        a[i+offset]=(a0+a1+a2+a3)%self.mod
                        a[i+offset+p]=(a0-a1+a2na3iimag)*irot%self.mod
                        a[i+offset+2*p]=(a0+a1-a2-a3)*irot2%self.mod
                        a[i+offset+3*p]=(a0-a1-a2na3iimag)*irot3%self.mod
                    irot*=self.irate3[(~s & -~s).bit_length()-1]
                    irot%=self.mod
                LEN-=2
    def convolution(self,a,b):
        """Return the cyclic-free convolution of lists a and b modulo self.mod.

        The result has length len(a)+len(b)-1 ([] if either input is empty).
        Inputs are not modified. Uses the naive O(n*m) product when the
        shorter input has at most 40 elements, otherwise zero-pads both to a
        power of two z >= n+m-1 and multiplies pointwise in NTT space.
        """
        n=len(a);m=len(b)
        if not(a) or not(b):
            return []
        if min(n,m)<=40:
            res=[0]*(n+m-1)
            for i in range(n):
                for j in range(m):
                    res[i+j]+=a[i]*b[j]
                    res[i+j]%=self.mod
            return res
        z=1<<((n+m-2).bit_length())
        # Padding builds new lists, so the caller's a and b stay untouched.
        a=a+[0]*(z-n)
        b=b+[0]*(z-m)
        self.butterfly(a)
        self.butterfly(b)
        c=[(a[i]*b[i])%self.mod for i in range(z)]
        self.butterfly_inv(c)
        # Scale by 1/z to finish the inverse transform.
        iz=pow(z,self.mod-2,self.mod)
        for i in range(n+m-1):
            c[i]=(c[i]*iz)%self.mod
        return c[:n+m-1]

# user config
############
# Truthy enables DB() debug printing (it writes to stdout); 0 silences it.
DEBUG_MODE=1
############


# import
import sys
import itertools
import bisect
import math
from collections import *
from functools import cache
from heapq import *

# alias
# Short names for frequently used library callables.
DD = defaultdict
BSL = bisect.bisect_left
BSR = bisect.bisect_right

# config
# Shadow the builtin input() with the faster sys.stdin.readline; unlike the
# builtin, returned lines keep their trailing newline.
input = sys.stdin.readline
# Allow deep recursion for recursive solutions.
sys.setrecursionlimit(10**7)

# input
def II():
    """Read one line and return it as an int."""
    return int(input())


def IS():
    """Read one line and return it without its line terminator.

    Stripping (rather than slicing off the last character) keeps the final
    line intact when the input does not end with a newline, and also removes
    a trailing carriage return from CRLF input.
    """
    return input().rstrip("\r\n")


def MI():
    """Read one line of whitespace-separated ints and return them as a map iterator."""
    return map(int, input().split())


def LM():
    """Read one line of whitespace-separated ints and return them as a list."""
    return list(MI())


def LL(n):
    """Read n lines, each parsed with LM()."""
    return [LM() for _ in range(n)]


def MI_1():
    """Like MI(), but every value is decremented by 1 (1-based -> 0-based)."""
    return map(lambda x: int(x) - 1, input().split())


def LM_1():
    """Like LM(), but every value is decremented by 1."""
    return list(MI_1())


def LL_1(n):
    """Read n lines, each parsed with LM_1()."""
    return [LM_1() for _ in range(n)]


def ALPHABET_TO_NUM(string, upper=False):
    """Map each letter of ``string`` to its 0-based alphabet index.

    Indices are relative to "A" when ``upper`` is true, otherwise to "a".
    """
    base = ord("A") if upper else ord("a")
    return [ord(ch) - base for ch in string]


# functions
def DB(*args,**kwargs):
    """Debug print that is active only while the global DEBUG_MODE is truthy.

    Positional arguments are printed space-separated on one line (keyword
    arguments are then ignored); otherwise each keyword argument is printed
    on its own line as ``name : value``. Always returns None.
    """
    if DEBUG_MODE:
        if args:
            print(*args)
        else:
            for name, value in kwargs.items():
                print(f"{name} : {value}")

def bit_count(num):
    """Count the 1 bits among the lowest num.bit_length() bits of num.

    For non-negative num this is the usual popcount.
    """
    return sum((num >> pos) & 1 for pos in range(num.bit_length()))

def popcount64(n):
    """Count set bits of n with the SWAR (parallel bit-sum) reduction.

    Intended for 0 <= n < 2**64: fast up to 63 bits, still correct at 64.
    Each step adds neighbouring fields of width 1, 2, 4, 8, 16, 32 bits.
    """
    steps = (
        (1, 0x5555555555555555),
        (2, 0x3333333333333333),
        (4, 0x0f0f0f0f0f0f0f0f),
        (8, 0x00ff00ff00ff00ff),
        (16, 0x0000ffff0000ffff),
        (32, 0x00000000ffffffff),
    )
    for shift, mask in steps:
        n = (n & mask) + ((n >> shift) & mask)
    return n

def argmax(*args):
    """Return the index of the first maximum.

    Accepts a single iterable (``argmax([3, 1, 3])``) or the values
    themselves (``argmax(3, 1, 3)``). Any iterable works, including
    generators and other objects without ``.index``. Raises ValueError on
    empty input, like ``max``.
    """
    if len(args) == 1 and hasattr(args[0], '__iter__'):
        # Materialise so one-shot iterators can be scanned twice and indexed.
        lst = list(args[0])
    else:
        lst = args
    return lst.index(max(lst))

def argmin(*args):
    """Return the index of the first minimum.

    Accepts a single iterable or the values themselves, exactly like
    argmax. Raises ValueError on empty input, like ``min``.
    """
    if len(args) == 1 and hasattr(args[0], '__iter__'):
        lst = list(args[0])
    else:
        lst = args
    return lst.index(min(lst))

def prefix_op(lst, op=lambda x,y:x+y, e=0):
    """Return running folds from the left, starting with the identity e.

    Result has len(lst)+1 entries: [e, op(e, lst[0]), op(op(e, lst[0]), lst[1]), ...].
    """
    return list(itertools.accumulate(lst, op, initial=e))

def suffix_op(lst, op=lambda x,y:x+y, e=0):
    """Return running folds from the right, ending with the identity e.

    Result has len(lst)+1 entries; result[len(lst)] = e and
    result[i] = op(lst[i], result[i+1]).
    """
    acc = e
    folded = [acc]
    for item in reversed(lst):
        acc = op(item, acc)
        folded.append(acc)
    folded.reverse()
    return folded

def sigma_LinearFunc(coeff1, coeff0, left, right, MOD=None):
    """
    Return the sum of coeff1*x + coeff0 over the integers x in [left, right].

    An empty range (right == left - 1) sums to 0. When MOD is given, the
    result is reduced modulo MOD. The arithmetic-series term
    (left+right)*(right-left+1)/2 is evaluated exactly before reducing, so a
    non-prime MOD also works. For odd MOD, left/right may be passed already
    reduced modulo MOD, even if that makes left > right.
    """
    count = right - left + 1
    # (left+right) and (right-left+1) have opposite parity, so halving is exact.
    half = (left + right) * count // 2
    if MOD:
        return ((coeff0 % MOD * (count % MOD) % MOD) + (coeff1 % MOD * (half % MOD) % MOD)) % MOD
    # Halve before multiplying by coeff1 so non-integer coefficients are not floored.
    return coeff0 * count + coeff1 * half

def find_divisors(n):
    """Return the positive divisors of n in increasing order, in O(sqrt(n)).

    For n < 1 an empty list is returned.
    """
    if n < 1:
        return []
    small, large = [], []
    for d in range(1, math.isqrt(n) + 1):
        if n % d == 0:
            small.append(d)
            # Skip the partner when d*d == n so a square root is not listed twice.
            if d != n // d:
                large.append(n // d)
    return small + large[::-1]

def prime_factorization(n):
    """Return the prime factorization of n as a list of (prime, exponent).

    Primes appear in increasing order and n == 1 gives []. Callers should
    pass n >= 1; any n <= 0 is returned as [(n, 1)].
    """
    ret = []
    i = 2
    # Bound by what is LEFT of n: once i*i exceeds it, any remainder > 1 is
    # prime, so there is no need to scan up to sqrt(original n).
    while i * i <= n:
        if n % i == 0:
            cnt = 0
            while n % i == 0:
                n //= i
                cnt += 1
            ret.append((i, cnt))
        i += 1
    if n != 1:
        ret.append((n, 1))
    return ret



""" 矩形の二次元配列を扱う諸関数 """
def copy_table(table):
    """Return a copy of a rectangular 2-D table as a new list of new lists.

    The width is taken from the first row. Each row becomes a fresh list, so
    a table of strings comes back as a table of single characters. An empty
    table yields [].
    """
    if not table:
        return []
    width = len(table[0])
    return [[row[j] for j in range(width)] for row in table]

def sum_table(table, MOD=None):
    """Return the sum of every cell of a rectangular 2-D table.

    The width is taken from the first row. When MOD is truthy, the running
    total is reduced modulo MOD after each row. An empty table sums to 0.
    """
    if not table:
        return 0
    width = len(table[0])
    total = 0
    for row in table:
        for j in range(width):
            total += row[j]
        if MOD:
            total %= MOD
    return total

def transpose_table(table):
    """Return the transpose of a rectangular 2-D table as a list of lists.

    The width is taken from the first row. An empty table yields [].
    """
    if not table:
        return []
    height, width = len(table), len(table[0])
    return [[table[i][j] for i in range(height)] for j in range(width)]

def convert_table_to_bit(table, letter1="#", rev=False):
    """Encode each row of a 2-D character table as an int bitmask.

    A cell equal to ``letter1`` sets a bit. With rev=False the leftmost
    column is the most significant bit, so the row reads like a binary
    number. With rev=True, column w maps to bit w. The width is taken from
    the first row. An empty table yields [].
    """
    if not table:
        return []
    width = len(table[0])
    res = []
    for row in table:
        bits = 0
        for w in range(width):
            col = w if rev else width - w - 1
            if row[col] == letter1:
                bits |= 1 << w
        res.append(bits)
    return res

def rotate_table_cc(S):
    """Rotate a 2-D table 90 degrees counter-clockwise; rows come back as tuples."""
    columns = list(zip(*S))
    columns.reverse()
    return columns
def rotate_table_c(S):
    """Rotate a 2-D table 90 degrees clockwise; rows come back as tuples."""
    return list(zip(*reversed(S)))

def mul_matrix(A, B, mod=None):
    """Return the matrix product A @ B as a list of lists.

    A is N x K and B is K x M. If B's first element is not iterable, B is
    treated as a column vector and the result is N x 1. When ``mod`` is not
    None, every entry of the result is reduced modulo ``mod``.
    """
    N = len(A)
    K = len(A[0])
    if not hasattr(B[0], "__iter__"):
        B = [[e] for e in B]
    M = len(B[0])
    # Read B column-wise once instead of indexing B[k][j] for every product.
    cols = [[B[k][j] for k in range(K)] for j in range(M)]
    res = []
    for i in range(N):
        row = A[i]
        out = []
        for col in cols:
            acc = sum(row[k] * col[k] for k in range(K))
            # One final reduction is congruent to reducing after every step.
            out.append(acc if mod is None else acc % mod)
        res.append(out)
    return res


def pow_matrix(mat, exp, mod=None):
    """Return mat ** exp for a square matrix by binary exponentiation.

    exp <= 0 returns the identity. ``mod`` (optional) is forwarded to
    mul_matrix, so entries are reduced modulo ``mod`` after every product.
    """
    N = len(mat)
    res = [[1 if i == j else 0 for i in range(N)] for j in range(N)]
    while exp > 0:
        if exp % 2 == 1:
            res = mul_matrix(res, mat, mod)
        exp //= 2
        # Skip the squaring after the last bit; its result would be unused.
        if exp > 0:
            mat = mul_matrix(mat, mat, mod)
    return res

def compress(lst):
    """Coordinate-compress: replace each value by its rank among the distinct values."""
    ordered = sorted(set(lst))
    rank = dict(zip(ordered, range(len(ordered))))
    return [rank[v] for v in lst]

def highDimCompress(lst):
    """Compress each axis of a list of points, e.g. (x, y) or (x, y, z), independently.

    Returns a list of tuples, one per input point.
    """
    ranked_axes = [compress(axis) for axis in zip(*lst)]
    return list(zip(*ranked_axes))



#classes

# https://github.com/tatyam-prime/SortedSet/blob/main/SortedSet.py
from bisect import bisect_left, bisect_right
from typing import Generic, Iterable, Iterator, List, Tuple, TypeVar, Optional
T = TypeVar('T')
class SortedSet(Generic[T]):
    """Sorted set stored as a list of sorted, non-empty buckets (tatyam's SortedSet).

    Most operations cost O(sqrt(N)). Several methods read a[0] / a[-1] of
    every bucket, so buckets must never be empty: __init__ builds non-empty
    buckets and _pop drops a bucket as soon as it empties.
    """
    BUCKET_RATIO = 16
    SPLIT_RATIO = 24
    
    def __init__(self, a: Iterable[T] = ()) -> None:
        "Make a new SortedSet from iterable. / O(N) if sorted and unique / O(N log N)"
        a = list(a)
        n = len(a)
        if any(a[i] > a[i + 1] for i in range(n - 1)):
            a.sort()
        if any(a[i] >= a[i + 1] for i in range(n - 1)):
            a, b = [], a
            for x in b:
                if not a or a[-1] != x:
                    a.append(x)
        # Size and bucket layout must use the de-duplicated length: the raw
        # length over-counts and can produce empty buckets.
        n = self.size = len(a)
        bucket_size = int(math.ceil(math.sqrt(n / self.BUCKET_RATIO)))
        self.a = [a[n * i // bucket_size : n * (i + 1) // bucket_size] for i in range(bucket_size)]

    def __iter__(self) -> Iterator[T]:
        for i in self.a:
            for j in i: yield j

    def __reversed__(self) -> Iterator[T]:
        for i in reversed(self.a):
            for j in reversed(i): yield j
    
    def __eq__(self, other) -> bool:
        return list(self) == list(other)
    
    def __len__(self) -> int:
        return self.size
    
    def __repr__(self) -> str:
        return "SortedSet" + str(self.a)
    
    def __str__(self) -> str:
        s = str(list(self))
        return "{" + s[1 : len(s) - 1] + "}"

    def _position(self, x: T) -> Tuple[List[T], int, int]:
        "return the bucket, index of the bucket and position in which x should be. self must not be empty."
        for i, a in enumerate(self.a):
            if x <= a[-1]: break
        return (a, i, bisect_left(a, x))

    def __contains__(self, x: T) -> bool:
        if self.size == 0: return False
        a, _, i = self._position(x)
        return i != len(a) and a[i] == x

    def add(self, x: T) -> bool:
        "Add an element and return True if added. / O(√N)"
        if self.size == 0:
            self.a = [[x]]
            self.size = 1
            return True
        a, b, i = self._position(x)
        if i != len(a) and a[i] == x: return False
        a.insert(i, x)
        self.size += 1
        # Split an oversized bucket in half to keep operations O(sqrt(N)).
        if len(a) > len(self.a) * self.SPLIT_RATIO:
            mid = len(a) >> 1
            self.a[b:b+1] = [a[:mid], a[mid:]]
        return True
    
    def _pop(self, a: List[T], b: int, i: int) -> T:
        # Remove a[i] from bucket number b; drop the bucket if it becomes empty.
        ans = a.pop(i)
        self.size -= 1
        if not a: del self.a[b]
        return ans

    def discard(self, x: T) -> bool:
        "Remove an element and return True if removed. / O(√N)"
        if self.size == 0: return False
        a, b, i = self._position(x)
        if i == len(a) or a[i] != x: return False
        self._pop(a, b, i)
        return True
    
    def lt(self, x: T) -> Optional[T]:
        "Find the largest element < x, or None if it doesn't exist."
        for a in reversed(self.a):
            if a[0] < x:
                return a[bisect_left(a, x) - 1]

    def le(self, x: T) -> Optional[T]:
        "Find the largest element <= x, or None if it doesn't exist."
        for a in reversed(self.a):
            if a[0] <= x:
                return a[bisect_right(a, x) - 1]

    def gt(self, x: T) -> Optional[T]:
        "Find the smallest element > x, or None if it doesn't exist."
        for a in self.a:
            if a[-1] > x:
                return a[bisect_right(a, x)]

    def ge(self, x: T) -> Optional[T]:
        "Find the smallest element >= x, or None if it doesn't exist."
        for a in self.a:
            if a[-1] >= x:
                return a[bisect_left(a, x)]
    
    def __getitem__(self, i: int) -> T:
        "Return the i-th element."
        if i < 0:
            for a in reversed(self.a):
                i += len(a)
                if i >= 0: return a[i]
        else:
            for a in self.a:
                if i < len(a): return a[i]
                i -= len(a)
        raise IndexError
    
    def pop(self, i: int = -1) -> T:
        "Pop and return the i-th element."
        if i < 0:
            for b, a in enumerate(reversed(self.a)):
                i += len(a)
                # ~b is the same bucket counted from the end of self.a.
                if i >= 0: return self._pop(a, ~b, i)
        else:
            for b, a in enumerate(self.a):
                if i < len(a): return self._pop(a, b, i)
                i -= len(a)
        raise IndexError
    
    def index(self, x: T) -> int:
        "Count the number of elements < x."
        ans = 0
        for a in self.a:
            if a[-1] >= x:
                return ans + bisect_left(a, x)
            ans += len(a)
        return ans

    def index_right(self, x: T) -> int:
        "Count the number of elements <= x."
        ans = 0
        for a in self.a:
            if a[-1] > x:
                return ans + bisect_right(a, x)
            ans += len(a)
        return ans


    # Helpers for using a SortedSet of (num, cnt) pairs as a multiset: each
    # distinct num is stored once as (num, cnt). They assume cnt >= 1, so
    # that (num, cnt) sorts after the probe (num, 0).
    def exist(self, x):
        """Return True if a pair (x, cnt) is stored."""
        ret = self.gt((x,0))
        if ret is None:
            return False
        elif ret[0] == x:
            return True
        else:
            return False

    def increment(self, x):
        """Add one occurrence of x."""
        if not self.exist(x):
            self.add((x,1))
        else:
            num, cnt = self.gt((x,0))
            self.discard((x,cnt))
            self.add((x,cnt+1))


    def decrement(self, x):
        """Remove one occurrence of x (no-op if x is absent)."""
        if not self.exist(x):
            return
        num, cnt = self.gt((x,0))
        if cnt == 1:
            self.discard((x,cnt))
        else:
            self.discard((x,cnt))
            self.add((x,cnt-1))

    def multi_add(self, x, y):
        """Add y occurrences of x."""
        if not self.exist(x):
            self.add((x,y))
        else:
            num, cnt = self.gt((x,0))
            self.discard((x,cnt))
            self.add((x,cnt+y))

    def multi_sub(self, x, y):
        """Remove y occurrences of x, dropping x entirely when cnt <= y."""
        if not self.exist(x):
            return
        num, cnt = self.gt((x,0))
        if cnt <= y:
            self.discard((x,cnt))
        else:
            self.discard((x,cnt))
            self.add((x,cnt-y))


# https://github.com/tatyam-prime/SortedSet/blob/main/SortedMultiset.py
T = TypeVar('T')
class SortedMultiset(Generic[T]):
    """Sorted multiset stored as a list of sorted, non-empty buckets (tatyam's SortedMultiset).

    Duplicates are kept. Most operations cost O(sqrt(N)).
    """
    BUCKET_RATIO = 16
    SPLIT_RATIO = 24
    
    def __init__(self, a: Iterable[T] = ()) -> None:
        "Make a new SortedMultiset from iterable. / O(N) if sorted / O(N log N)"
        a = list(a)
        n = self.size = len(a)
        if any(a[i] > a[i + 1] for i in range(n - 1)):
            a.sort()
        num_bucket = int(math.ceil(math.sqrt(n / self.BUCKET_RATIO)))
        self.a = [a[n * i // num_bucket : n * (i + 1) // num_bucket] for i in range(num_bucket)]

    def __iter__(self) -> Iterator[T]:
        for i in self.a:
            for j in i: yield j

    def __reversed__(self) -> Iterator[T]:
        for i in reversed(self.a):
            for j in reversed(i): yield j
    
    def __eq__(self, other) -> bool:
        return list(self) == list(other)
    
    def __len__(self) -> int:
        return self.size
    
    def __repr__(self) -> str:
        return "SortedMultiset" + str(self.a)
    
    def __str__(self) -> str:
        s = str(list(self))
        return "{" + s[1 : len(s) - 1] + "}"

    def _position(self, x: T) -> Tuple[List[T], int, int]:
        "return the bucket, index of the bucket and position in which x should be. self must not be empty."
        for i, a in enumerate(self.a):
            if x <= a[-1]: break
        return (a, i, bisect_left(a, x))

    def __contains__(self, x: T) -> bool:
        if self.size == 0: return False
        a, _, i = self._position(x)
        return i != len(a) and a[i] == x

    def count(self, x: T) -> int:
        "Count the number of x."
        return self.index_right(x) - self.index(x)

    def add(self, x: T) -> None:
        "Add an element. / O(√N)"
        if self.size == 0:
            self.a = [[x]]
            self.size = 1
            return
        a, b, i = self._position(x)
        a.insert(i, x)
        self.size += 1
        # Split an oversized bucket in half to keep operations O(sqrt(N)).
        if len(a) > len(self.a) * self.SPLIT_RATIO:
            mid = len(a) >> 1
            self.a[b:b+1] = [a[:mid], a[mid:]]
    
    def _pop(self, a: List[T], b: int, i: int) -> T:
        # Remove a[i] from bucket number b; drop the bucket if it becomes empty.
        ans = a.pop(i)
        self.size -= 1
        if not a: del self.a[b]
        return ans

    def discard(self, x: T) -> bool:
        "Remove an element and return True if removed. / O(√N)"
        if self.size == 0: return False
        a, b, i = self._position(x)
        if i == len(a) or a[i] != x: return False
        self._pop(a, b, i)
        return True

    def lt(self, x: T) -> Optional[T]:
        "Find the largest element < x, or None if it doesn't exist."
        for a in reversed(self.a):
            if a[0] < x:
                return a[bisect_left(a, x) - 1]

    def le(self, x: T) -> Optional[T]:
        "Find the largest element <= x, or None if it doesn't exist."
        for a in reversed(self.a):
            if a[0] <= x:
                return a[bisect_right(a, x) - 1]

    def gt(self, x: T) -> Optional[T]:
        "Find the smallest element > x, or None if it doesn't exist."
        for a in self.a:
            if a[-1] > x:
                return a[bisect_right(a, x)]

    def ge(self, x: T) -> Optional[T]:
        "Find the smallest element >= x, or None if it doesn't exist."
        for a in self.a:
            if a[-1] >= x:
                return a[bisect_left(a, x)]
    
    def __getitem__(self, i: int) -> T:
        "Return the i-th element."
        if i < 0:
            for a in reversed(self.a):
                i += len(a)
                if i >= 0: return a[i]
        else:
            for a in self.a:
                if i < len(a): return a[i]
                i -= len(a)
        raise IndexError
    
    def pop(self, i: int = -1) -> T:
        "Pop and return the i-th element."
        if i < 0:
            for b, a in enumerate(reversed(self.a)):
                i += len(a)
                # ~b is the same bucket counted from the end of self.a.
                if i >= 0: return self._pop(a, ~b, i)
        else:
            for b, a in enumerate(self.a):
                if i < len(a): return self._pop(a, b, i)
                i -= len(a)
        raise IndexError

    def index(self, x: T) -> int:
        "Count the number of elements < x."
        ans = 0
        for a in self.a:
            if a[-1] >= x:
                return ans + bisect_left(a, x)
            ans += len(a)
        return ans

    def index_right(self, x: T) -> int:
        "Count the number of elements <= x."
        ans = 0
        for a in self.a:
            if a[-1] > x:
                return ans + bisect_right(a, x)
            ans += len(a)
        return ans


class Comb:
    """Binomial coefficients modulo a prime via precomputed factorial tables."""

    def __init__(self,table_len,mod):
        """
        Precompute tables so that comb(a, b) works for 0 <= a <= table_len.

        Only valid when mod is a prime larger than table_len; the inverse
        recurrence below depends on it.

        Attributes (each list has max(table_len, 1) + 1 entries):
            fac[i]  = i! mod mod
            inv[i]  = i^-1 mod mod
            finv[i] = (i!)^-1 mod mod
        """
        self.fac = [1,1]
        self.inv = [1,1]
        self.finv = [1,1]
        self.mod = mod
        for i in range(2,table_len+1):
            self.fac.append(self.fac[i-1]*i%mod)
            # From mod = (mod//i)*i + mod%i:  i^-1 = -(mod//i) * (mod%i)^-1.
            self.inv.append(-self.inv[mod%i]*(mod//i)%mod)
            self.finv.append(self.finv[i-1]*self.inv[i]%mod)

    def comb(self,a,b):
        """Return C(a, b) mod self.mod; 0 when b < 0 or b > a (which covers a < 0)."""
        # Without this guard a negative index would silently wrap around.
        if b < 0 or b > a:
            return 0
        return self.fac[a]*self.finv[b]*self.finv[a-b]%self.mod


class GridBFS:
    """Breadth-first-search helpers on a 2-D grid.

    ``table`` is a list of rows (2-D list or list of strings). H is the row
    count and W the length of the first row. A cell is a wall when
    ``cell in self.wall`` (default "#"; see set_wall_string).
    """

    def __init__(self, table):
        self.table = table
        self.H = len(table)
        self.W = len(table[0])
        self.wall = "#"

    def find(self, c):
        """Return the first (h, w) in row-major order whose cell equals c, or None."""
        for h in range(self.H):
            for w in range(self.W):
                if self.table[h][w] == c:
                    return (h, w)
        return None

    def set_wall_string(self, string):
        """Treat every character of ``string`` (e.g. "#!^") as a wall. Default is "#"."""
        self.wall = string

    def island(self, transition=((-1, 0), (0, 1), (1, 0), (0, -1))):
        """Label connected components of non-wall cells.

        Returns (island_id, island_size), both H x W lists, also stored on
        self. island_id[h][w] is the component number (0, 1, ... in
        row-major discovery order) and island_size[h][w] is that
        component's cell count. Wall cells hold -1 in both. ``transition``
        lists the (dh, dw) moves that count as adjacency (default:
        4-neighbourhood).
        """
        H, W = self.H, self.W
        self.island_id = [[-1] * W for _ in range(H)]
        # H rows of width W, the same shape as island_id.
        self.island_size = [[-1] * W for _ in range(H)]

        crr_id = 0
        id2size = dict()
        for sh in range(H):
            for sw in range(W):
                if self.table[sh][sw] in self.wall:
                    continue
                if self.island_id[sh][sw] != -1:
                    continue
                # Flood-fill a new component starting at (sh, sw).
                deq = deque()
                deq.append((sh, sw))
                crr_size = 1
                self.island_id[sh][sw] = crr_id
                while deq:
                    h, w = deq.popleft()
                    for dh, dw in transition:
                        nh, nw = h + dh, w + dw
                        if (not 0 <= nh < H) or (not 0 <= nw < W):
                            continue
                        if self.table[nh][nw] in self.wall:
                            continue
                        if self.island_id[nh][nw] == -1:
                            self.island_id[nh][nw] = crr_id
                            deq.append((nh, nw))
                            crr_size += 1

                id2size[crr_id] = crr_size
                crr_id += 1

        for h in range(H):
            for w in range(W):
                if self.table[h][w] in self.wall:
                    continue
                self.island_size[h][w] = id2size[self.island_id[h][w]]

        return self.island_id, self.island_size


    def distance(self, start, goal=None, transition=((-1, 0), (0, 1), (1, 0), (0, -1))):
        """Unweighted shortest-path distances from ``start`` by BFS.

        With ``goal`` given, return the distance from start to goal, or -1
        if goal is unreachable. Without ``goal``, return the H x W distance
        table, where unreachable cells (including walls) hold INF = 1<<60.
        ``start`` and ``goal`` are (h, w) pairs; lists are accepted too.
        ``transition`` lists the allowed (dh, dw) moves (default: up, right,
        down, left). The start cell itself is not checked against the wall set.
        """
        H, W, tab, wall = self.H, self.W, self.table, self.wall

        INF = 1 << 60
        # Normalise so list-valued coordinates compare equal to (nh, nw) tuples.
        start = tuple(start)
        if goal is not None:
            goal = tuple(goal)

        if start == goal:
            return 0

        deq = deque()
        deq.append(start)
        dist = [[INF] * W for _ in range(H)]
        dist[start[0]][start[1]] = 0

        while deq:
            h, w = deq.popleft()
            for dh, dw in transition:
                nh = h + dh
                nw = w + dw
                # Outside the grid.
                if (not 0 <= nh < H) or (not 0 <= nw < W):
                    continue

                # Wall cell.
                if tab[nh][nw] in wall:
                    continue

                new_dist = dist[h][w] + 1

                # First time BFS reaches the goal is the shortest distance.
                if goal is not None and (nh, nw) == goal:
                    return new_dist

                if dist[nh][nw] > new_dist:
                    dist[nh][nw] = new_dist
                    deq.append((nh, nw))

        # A goal was requested but never reached.
        if goal is not None:
            return -1

        return dist


class nth_root:
    """Exact integer n-th roots: calc(x, n) is the largest r with r**n <= x."""

    def __init__(self):
        # For 2 <= n <= 63, ngs[n] is the smallest integer r with r**n >= 2**64,
        # so floor(x ** (1/n)) < ngs[n] whenever x < 2**64. Entries 0 and 1 are unused.
        self.ngs = [-1, -1, 4294967296, 2642246, 65536, 7132, 1626, 566, 256, 139, 85, 57, 41, 31, 24, 20, 16, 14, 12, 11, 10, 9, 8, 7, 7, 6, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
        # Pre-calculated by this code:
        # self.ngs = [-1,-1]
        # for n in range(2,64):
        #     ok,ng=2**33,0
        #     while abs(ok-ng)>1:
        #         mid=ok+ng>>1
        #         if mid**n>=2**64:
        #             ok=mid
        #         else:
        #             ng=mid
        #     self.ngs.append(ok)

    def calc(self, x, n, is_x_within_64bit=True):
        """Return the largest integer r with r**n <= x.

        x <= 1 or n == 1 returns x unchanged. Otherwise n must be >= 1
        (ValueError). When is_x_within_64bit is true and x really is below
        2**64, the precomputed bound ngs[n] narrows the binary search.
        Otherwise a bound derived from x.bit_length() is used, so a larger x
        passed with the default flag still gets the correct root.
        """
        if x <= 1 or n == 1: return x
        if n < 1:
            raise ValueError("n must be a positive integer")
        if is_x_within_64bit and x < 1 << 64:
            # 2**64 > x, so every n-th root with n >= 64 is 1.
            if n >= 64: return 1
            ng = self.ngs[n]
        else:
            # x < 2**b  implies  root < 2**(b/n) <= 2**(b//n + 1).
            ng = 1 << (x.bit_length() // n + 1)

        # Invariant: ok**n <= x < ng**n.
        ok = 0
        while ng - ok > 1:
            mid = (ok + ng) >> 1
            if mid ** n <= x:
                ok = mid
            else:
                ng = mid
        return ok

# Frequently used constants. (A `global` statement at module level is a
# no-op, so none is needed for these module-level names.)

# Grid moves as [row delta, column delta], listed clockwise from "up".
DIRECTION_4 = [[-1, 0], [0, 1], [1, 0], [0, -1]]
DIRECTION_8 = [[-1, 0], [-1, 1], [0, 1], [1, 1], [1, 0], [1, -1], [0, -1], [-1, -1]]
DIRECTION_DIAGONAL = [[-1, 1], [1, 1], [1, -1], [-1, -1]]
# Letter -> (row delta, column delta), rows growing downward.
DIRECTION_URDL_TABLE = {'U': (-1, 0), 'R': (0, 1), 'D': (1, 0), 'L': (0, -1)}
# Letter -> (x delta, y delta) on a Cartesian plane, y growing upward.
DIRECTION_URDL_COORD_PLANE = {'U': (0, 1), 'R': (1, 0), 'D': (0, -1), 'L': (-1, 0)}

MOD = 998244353
INF = 1 << 60
LOWER_ALP = "abcdefghijklmnopqrstuvwxyz"
UPPER_ALP = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
ALL_ALP = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"

# Run the solver only when executed as a script, not when imported.
if __name__ == "__main__":
    main()