結果

問題 No.911 ラッキーソート
ユーザー 👑 obakyan
提出日時 2020-04-18 17:26:07
言語 Lua
(LuaJIT 2.1.1734355927)
結果
AC  
実行時間 1,597 ms / 2,000 ms
コード長 3,922 bytes
コンパイル時間 167 ms
コンパイル使用メモリ 6,944 KB
実行使用メモリ 146,236 KB
最終ジャッジ日時 2024-10-04 00:00:30
合計ジャッジ時間 29,333 ms
ジャッジサーバーID
(参考情報)
judge3 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 46
権限があれば一括ダウンロードができます

ソースコード

diff #

local bls, brs, band = bit.lshift, bit.rshift, bit.band
local mfl, mce = math.floor, math.ceil

--- Parse a string of ASCII decimal digits into a LuaJIT int64 cdata value.
-- Doubles cannot represent every integer above 2^53, so the value is
-- accumulated digit by digit using 64-bit integer arithmetic.
-- @param str string of decimal digits (an empty string yields 0LL)
-- @return int64 cdata
local function lltonumber(str)
  local ASCII_ZERO = 48 -- byte value of the character "0"
  local value = 0LL
  for pos = 1, #str do
    local digit = str:byte(pos) - ASCII_ZERO
    value = value * 10LL + digit
  end
  return value
end
local n, l, r = io.read():match("(%d+) (%d+) (%d+)")
n = tonumber(n)
l = lltonumber(l)
r = lltonumber(r)
-- t[i] is a mask of the values allowed for bit i of x (i = 1 is the LSB):
-- 1 or 3: can use "0", 2 or 3: can use "1". Every bit starts as 3 (free).
local t = {}
local countlim = 64
for i = 1, countlim do
  t[i] = 3
end

--- Decompose a non-negative decimal string into `countlim` bits, LSB first.
-- LuaJIT's bit.* functions wrap their arguments to 32 bits, so they must not
-- be used to halve values >= 2^32 (a 10-digit input can reach 9,999,999,999).
-- Halving with math.floor(x / 2) is exact for any double below 2^53 instead.
-- @param s string of decimal digits
-- @return array of 0/1 numbers, index 1 = least significant bit
local function decimal_to_bits(s)
  local bits = {}
  if #s < 11 then
    -- at most 10 digits: the whole value is exact in a double
    local a = tonumber(s)
    for j = 1, countlim do
      bits[j] = a % 2
      a = mfl(a / 2)
    end
  else
    -- Split as value = top * 10^k + bottom so both parts stay exact.
    -- When top is odd, halving carries 10^k into bottom; since
    -- bottom < 10^k, (bottom + 10^k) / 2 is still below 10^k.
    local scale = 10 ^ (#s - 9)
    local top = tonumber(s:sub(1, 9))
    local bottom = tonumber(s:sub(10))
    for j = 1, countlim do
      -- 10^k is even (k >= 2 here), so the value's parity is bottom's
      bits[j] = bottom % 2
      if top % 2 == 1 then
        bottom = bottom + scale
      end
      top = mfl(top / 2)
      bottom = mfl(bottom / 2)
    end
  end
  return bits
end

local astr_all = io.read()
local allbits = {}
for astr in astr_all:gmatch("%d+") do
  allbits[#allbits + 1] = decimal_to_bits(astr)
end
-- Decide whether some x makes (A_1 xor x), ..., (A_n xor x) ordered, and
-- narrow the per-bit masks in t along the way.
-- Work items (bit_pos, lo, hi) describe runs A[lo..hi] whose elements agree
-- on every bit above bit_pos. They are processed depth-first from an
-- explicit stack. Within a run, bit bit_pos must be monotone: the same
-- everywhere, or one block of one value followed by a block of the other.
-- A 0->1 change forces x's bit to 0 (mask 1); a 1->0 change forces 1 (mask 2).
local valid = true
local stack_bit, stack_lo, stack_hi = {countlim}, {1}, {n}
local sp = 1
local function push(bit_pos, lo, hi)
  sp = sp + 1
  stack_bit[sp], stack_lo[sp], stack_hi[sp] = bit_pos, lo, hi
end
while sp > 0 do
  local bit_pos, lo, hi = stack_bit[sp], stack_lo[sp], stack_hi[sp]
  sp = sp - 1
  local lower = bit_pos - 1
  if lo == hi then
    -- a single element can never break the order; just descend
    if lower >= 1 then push(lower, lo, hi) end
  else
    local first, last = allbits[lo][bit_pos], allbits[hi][bit_pos]
    if first == last then
      -- ends agree, so every element of the run must agree with them
      for k = lo + 1, hi - 1 do
        if allbits[k][bit_pos] ~= first then
          valid = false
          break
        end
      end
      if not valid then break end
      if lower >= 1 then push(lower, lo, hi) end
    else
      -- ends differ: exactly one change point is allowed
      local split = nil
      for k = lo + 1, hi do
        if allbits[k][bit_pos] ~= allbits[k - 1][bit_pos] then
          if split ~= nil then
            valid = false
            break
          end
          split = k
        end
      end
      if not valid then break end
      local keep = (first == 0) and 1 or 2
      t[bit_pos] = band(t[bit_pos], keep)
      if t[bit_pos] == 0 then
        valid = false
        break
      end
      if lower >= 1 then
        push(lower, lo, split - 1)
        push(lower, split, hi)
      end
    end
  end
end
-- print(valid)
-- print(table.concat(t, " "))
--- Count the integers x with 0 <= x <= max whose bits are all permitted by t.
-- t[i] (i = 1 is the least significant bit, up to i = countlim) is a mask:
-- 1 permits a "0" at that bit, 2 permits a "1", 3 permits either.
-- Digit DP over the bits from most to least significant, tracking
--   tight: prefixes chosen so far that equal max's prefix, and
--   loose: prefixes already strictly below max's prefix.
-- The DP starts from the empty prefix (tight = 1), so the topmost bit
-- t[countlim] is checked like every other bit.
-- @param max non-negative int64 cdata upper bound (inclusive)
-- @return int64 cdata count
local function getCount(max)
  -- bits of max, least significant first
  local maxbits = {}
  for i = 1, countlim do
    maxbits[i] = max % 2LL == 1LL and 1 or 0
    max = max / 2LL
  end
  local tight, loose = 1LL, 0LL
  for i = countlim, 1, -1 do
    local can0 = t[i] % 2 == 1
    local can1 = t[i] >= 2
    local choices = (can0 and 1 or 0) + (can1 and 1 or 0)
    -- once strictly below max, any permitted bit keeps the prefix below
    local next_loose = loose * choices
    -- from tight: matching max's bit stays tight; a 0 under max's 1 drops below
    local next_tight = 0LL
    if maxbits[i] == 1 then
      if can1 then next_tight = tight end
      if can0 then next_loose = next_loose + tight end
    elseif can0 then
      next_tight = tight
    end
    tight, loose = next_tight, next_loose
  end
  return loose + tight
end
-- The answer is (#valid x in [0, r]) - (#valid x in [0, l - 1]).
if not valid then
  print(0)
else
  local answer = getCount(r)
  if l > 0 then
    answer = answer - getCount(l - 1)
  end
  -- tostring() of an int64 cdata appends "LL"; strip it before printing
  local text = tostring(answer):gsub("LL", "")
  print(text)
end
0