結果

問題 No.1079 まお
ユーザー sasa8uyauya
提出日時 2025-02-25 13:06:00
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 1,697 ms / 2,000 ms
コード長 1,703 bytes
コンパイル時間 2,742 ms
コンパイル使用メモリ 81,660 KB
実行使用メモリ 310,904 KB
最終ジャッジ日時 2025-02-25 13:06:32
合計ジャッジ時間 29,944 ms
ジャッジサーバーID
(参考情報)
judge3 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 30
権限があれば一括ダウンロードができます

ソースコード

diff #

n,K=map(int,input().split())
a=list(map(int,input().split()))
z=sorted(set(a))
d={v:i for i,v in enumerate(z)}
g=0

class BIT:
  def __init__(self,n):
    self.n=n
    self.q=[0]*(n+1)
  
  def add(self,p,x):
    p+=1
    while p<=self.n:
      self.q[p]+=x
      p+=p&(-p)
  
  def sum_(self,p):
    a=0
    while p>0:
      a+=self.q[p]
      p-=p&(-p)
    return a
  
  def sum(self,l,r):
    return self.sum_(r)-self.sum_(l)

sts=BIT(len(d)+1)
stc=BIT(len(d)+1)

def solve(l,r):
  global g
  if l==r:
    v=a[l]
    if v+v==K:
      g+=1
    return {v:[(1,v,1)]},{v:[(1,v,1)]}
  w=(r-l+1)//2
  ql,_=solve(l,l+w-1)
  _,qr=solve(l+w,r)
  for lv in ql:
    if K-lv in qr:
      for x,v,c in qr[K-lv]:
        y=d[v]
        sts.add(y,x)
        stc.add(y,1)
      for x,v,c in ql[lv]:
        if c==1:
          y=d[v]
          g+=x*stc.sum(y+1,len(d))+sts.sum(y+1,len(d))
      for x,v,c in qr[K-lv]:
        y=d[v]
        sts.add(y,-x)
        stc.add(y,-1)
  for rv in qr:
    if K-rv in ql:
      for x,v,c in ql[K-rv]:
        y=d[v]
        sts.add(y,x)
        stc.add(y,1)
      for x,v,c in qr[rv]:
        if c==1:
          y=d[v]
          g+=x*stc.sum(y+1,len(d))+sts.sum(y+1,len(d))
      for x,v,c in ql[K-rv]:
        y=d[v]
        sts.add(y,-x)
        stc.add(y,-1)
  ql={}
  mv=10**10
  c=1
  for i in reversed(range(l,r+1)):
    v=a[i]
    if v<mv:
      mv=v
      c=1
    elif v==mv:
      c+=1
    if v not in ql:
      ql[v]=[]
    ql[v]+=[(r-i+1,mv,c)]
  qr={}
  mv=10**10
  c=1
  for i in range(l,r+1):
    v=a[i]
    if v<mv:
      mv=v
      c=1
    elif v==mv:
      c+=1
    if v not in qr:
      qr[v]=[]
    qr[v]+=[(i-l+1,mv,c)]
  return ql,qr

solve(0,n-1)
print(g)
0