增加sql的waf功能
This commit is contained in:
102
init.lua
102
init.lua
@@ -1,42 +1,45 @@
|
||||
require 'config'
|
||||
require 'b64'
|
||||
require 'aes'
|
||||
require 'log'
|
||||
require '403'
|
||||
require 'tableXstring'
|
||||
require 'fileio'
|
||||
require 'randomStr'
|
||||
require 'whiteList'
|
||||
require "config"
|
||||
require "b64"
|
||||
require "aes"
|
||||
require "log"
|
||||
require "403"
|
||||
require "tableXstring"
|
||||
require "fileio"
|
||||
require "randomStr"
|
||||
require "whiteList"
|
||||
require "tools"
|
||||
require "waf/waf"
|
||||
|
||||
local optionIsOn = function (options) return options == "on" and true or false end
|
||||
local optionIsOn = function(options)
|
||||
return options == "on" and true or false
|
||||
end
|
||||
ToolsProtect = optionIsOn(toolsProtect)
|
||||
ShiroProtect = optionIsOn(shiroProtect)
|
||||
JsProtect = optionIsOn(jsProtect)
|
||||
JsConfuse = false
|
||||
SensitiveProtect = optionIsOn(sensitiveProtect)
|
||||
|
||||
|
||||
-- cookie加密
|
||||
function reqCookieParse()
|
||||
if ShiroProtect then
|
||||
local userCookieX9 = ngx.var.cookie_x9i7RDYX23
|
||||
if not userCookieX9 then -- 没有cookie
|
||||
log('0-cookie 无cookie', '')
|
||||
ngx.req.set_header('Cookie', '') -- 移除其他cookie
|
||||
elseif #userCookieX9 < 32 then -- 判断cookie长度
|
||||
log('1-cookie 不符合要求', userCookieX9)
|
||||
ngx.say('4')
|
||||
if not userCookieX9 then -- 没有cookie
|
||||
log("0-cookie 无cookie", "")
|
||||
ngx.req.set_header("Cookie", "") -- 移除其他cookie
|
||||
elseif #userCookieX9 < 32 then -- 判断cookie长度
|
||||
log("1-cookie 不符合要求", userCookieX9)
|
||||
ngx.say("4")
|
||||
say_html()
|
||||
else --有cookie
|
||||
else --有cookie
|
||||
local result = xpcall(dencrypT, emptyPrint, userCookieX9, aesKey)
|
||||
if not result then --解密失败
|
||||
log('2-cookie 无法解密', userCookieX9)
|
||||
ngx.say('5')
|
||||
log("2-cookie 无法解密", userCookieX9)
|
||||
ngx.say("5")
|
||||
say_html()
|
||||
else --解密成功
|
||||
else --解密成功
|
||||
local originCookie = StrToTable(dencrypT(userCookieX9, aesKey))
|
||||
ngx.req.set_header('Cookie', transTable(originCookie))
|
||||
log('3-cookie 解密成功', userCookieX9)
|
||||
ngx.req.set_header("Cookie", transTable(originCookie))
|
||||
log("3-cookie 解密成功", userCookieX9)
|
||||
end
|
||||
end
|
||||
end
|
||||
@@ -46,9 +49,9 @@ function respCookieEncrypt()
|
||||
if ShiroProtect then
|
||||
local value = ngx.resp.get_headers()["Set-Cookie"]
|
||||
if value then
|
||||
local encryptedCookie = cookieD.."="..encrypT(TableToStr(value), aesKey)
|
||||
local encryptedCookie = cookieD .. "=" .. encrypT(TableToStr(value), aesKey)
|
||||
ngx.header["Set-Cookie"] = encryptedCookie
|
||||
log('4-cookie 加密成功',encryptedCookie)
|
||||
log("4-cookie 加密成功", encryptedCookie)
|
||||
end
|
||||
end
|
||||
end
|
||||
@@ -58,30 +61,30 @@ function toolsInfoSpider()
|
||||
if ToolsProtect and not whiteExtCheck() then
|
||||
local clientCookieA = ngx.var.cookie_h0yGbdRv
|
||||
local clientCookieB = ngx.var.cookie_kQpFHdoh
|
||||
if not (clientCookieA and clientCookieB) then --没有cookieA进入reload,302至html生成cookie后再请求原地址
|
||||
local ip = 'xxx'
|
||||
local finalPath = 'http://'..ip..'/'..jsPath..'?origin='..encodeBase64(ngx.var.request_uri)
|
||||
log('1-tools 无cookieA/B', '')
|
||||
if not (clientCookieA and clientCookieB) then --没有cookieA进入reload,302至html生成cookie后再请求原地址
|
||||
local ip = "xxx"
|
||||
local finalPath = "http://" .. ip .. "/" .. jsPath .. "?origin=" .. encodeBase64(ngx.var.request_uri)
|
||||
log("1-tools 无cookieA/B", "")
|
||||
ngx.redirect(finalPath, 302)
|
||||
else
|
||||
local result = xpcall(dencrypT, emptyPrint, clientCookieB, clientCookieA)
|
||||
if not result then
|
||||
log('2-tools 解密失败', clientCookieA..', '..clientCookieB)
|
||||
ngx.say('1')
|
||||
log("2-tools 解密失败", clientCookieA .. ", " .. clientCookieB)
|
||||
ngx.say("1")
|
||||
say_html() -- 解密失败
|
||||
else-- 可以解密,提取数据
|
||||
else -- 可以解密,提取数据
|
||||
local result2 = dencrypT(clientCookieB, clientCookieA)
|
||||
if #result2 < 1 then
|
||||
log('3-tools 解密失败', result2)
|
||||
log("3-tools 解密失败", result2)
|
||||
else
|
||||
local srs = split(result2, ',')
|
||||
local _,e = string.find(srs[1], '0')
|
||||
local srs = split(result2, ",")
|
||||
local _, e = string.find(srs[1], "0")
|
||||
if e ~= nil then
|
||||
log('4-tools 工具请求', result2)
|
||||
ngx.say('2')
|
||||
log("4-tools 工具请求", result2)
|
||||
ngx.say("2")
|
||||
say_html()
|
||||
else
|
||||
log('0-tools 工具验证通过, 记录浏览器指纹', '', srs[2])
|
||||
log("0-tools 工具验证通过, 记录浏览器指纹", "", srs[2])
|
||||
end
|
||||
end
|
||||
end
|
||||
@@ -93,7 +96,7 @@ end
|
||||
function jsExtDetect()
|
||||
if JsProtect then
|
||||
local ext = string.match(ngx.var.uri, ".+%.(%w+)$")
|
||||
if ext == 'js' then -- 加入检查,js文件是否存在
|
||||
if ext == "js" then -- 加入检查,js文件是否存在
|
||||
JsConfuse = true
|
||||
end
|
||||
end
|
||||
@@ -102,14 +105,14 @@ end
|
||||
function jsConfuse()
|
||||
if JsConfuse then
|
||||
local originBody = ngx.arg[1]
|
||||
if #originBody > 200 then -- 筛选空js
|
||||
if #originBody > 200 then -- 筛选空js
|
||||
local s = getRandom(8)
|
||||
local path = '/tmp/'..s
|
||||
writefile(path, originBody, 'w+')
|
||||
local t = io.popen('export NODE_PATH=/usr/lib/node_modules && node /gate/node/js_confuse.js '..path)
|
||||
local path = "/tmp/" .. s
|
||||
writefile(path, originBody, "w+")
|
||||
local t = io.popen("export NODE_PATH=/usr/lib/node_modules && node /gate/node/js_confuse.js " .. path)
|
||||
local a = t:read("*all")
|
||||
ngx.arg[1] = a
|
||||
os.execute('rm -f '..path)
|
||||
os.execute("rm -f " .. path)
|
||||
end
|
||||
JsConfuse = false
|
||||
end
|
||||
@@ -122,14 +125,3 @@ function dateReplace()
|
||||
ngx.arg[1] = replaceTelephone
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
56
log.lua
56
log.lua
@@ -1,21 +1,24 @@
|
||||
require 'config'
|
||||
require "config"
|
||||
|
||||
|
||||
local optionIsOn = function (options) return options == "on" and true or false end
|
||||
local optionIsOn = function(options)
|
||||
return options == "on" and true or false
|
||||
end
|
||||
local Attacklog = optionIsOn(attacklog)
|
||||
local logpath = logdir
|
||||
|
||||
local function getClientIp()
|
||||
IP = ngx.var.remote_addr
|
||||
IP = ngx.var.remote_addr
|
||||
if IP == nil then
|
||||
IP = "unknown"
|
||||
IP = "unknown"
|
||||
end
|
||||
return IP
|
||||
end
|
||||
|
||||
local function write(logfile,msg)
|
||||
local fd = io.open(logfile,"ab")
|
||||
if fd == nil then return end
|
||||
local function write(logfile, msg)
|
||||
local fd = io.open(logfile, "ab")
|
||||
if fd == nil then
|
||||
return
|
||||
end
|
||||
fd:write(msg)
|
||||
fd:flush()
|
||||
fd:close()
|
||||
@@ -23,19 +26,38 @@ end
|
||||
|
||||
function log(data, ruletag, fp)
|
||||
if Attacklog then
|
||||
local fingerprint = fp or ''
|
||||
local fingerprint = fp or ""
|
||||
local realIp = getClientIp()
|
||||
local method = ngx.var.request_method
|
||||
local ua = ngx.var.http_user_agent
|
||||
local servername=ngx.var.server_name
|
||||
local servername = ngx.var.server_name
|
||||
local url = ngx.var.request_uri
|
||||
local time=ngx.localtime()
|
||||
if ua then
|
||||
line = realIp.." ["..time.."] \""..method.." "..servername..url.."\" \""..ruletag.."\" \""..ua.."\" \""..data.."\" \""..fingerprint.."\"\n"
|
||||
local time = ngx.localtime()
|
||||
if ua then
|
||||
line =
|
||||
realIp ..
|
||||
" [" ..
|
||||
time ..
|
||||
'] "' ..
|
||||
method ..
|
||||
" " ..
|
||||
servername ..
|
||||
url ..
|
||||
'" "' ..
|
||||
ruletag ..
|
||||
'" "' .. ua .. '" "' .. data .. '" "' .. fingerprint .. '"\n'
|
||||
else
|
||||
line = realIp.." ["..time.."] \""..method.." "..servername..url.."\" \""..ruletag.."\" - \""..data.."\" \""..fingerprint.."\"\n"
|
||||
line =
|
||||
realIp ..
|
||||
" [" ..
|
||||
time ..
|
||||
'] "' ..
|
||||
method ..
|
||||
" " ..
|
||||
servername ..
|
||||
url .. '" "' .. ruletag .. '" - "' .. data .. '" "' .. fingerprint .. '"\n'
|
||||
end
|
||||
local filename = logpath..'/'..servername.."_"..ngx.today().."_sec.log"
|
||||
write(filename,line)
|
||||
local filename = logpath .. "/" .. servername .. "_" .. ngx.today() .. "_sec.log"
|
||||
write(filename, line)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
303
resty/aes.lua
Normal file
303
resty/aes.lua
Normal file
@@ -0,0 +1,303 @@
|
||||
-- Copyright (C) by Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
--local asn1 = require "resty.asn1"
|
||||
local ffi = require "ffi"
|
||||
local ffi_new = ffi.new
|
||||
local ffi_gc = ffi.gc
|
||||
local ffi_str = ffi.string
|
||||
local ffi_copy = ffi.copy
|
||||
local C = ffi.C
|
||||
local setmetatable = setmetatable
|
||||
--local error = error
|
||||
local type = type
|
||||
|
||||
|
||||
local _M = { _VERSION = '0.14' }
|
||||
|
||||
local mt = { __index = _M }
|
||||
|
||||
local EVP_CTRL_AEAD_SET_IVLEN = 0x09
|
||||
local EVP_CTRL_AEAD_GET_TAG = 0x10
|
||||
local EVP_CTRL_AEAD_SET_TAG = 0x11
|
||||
|
||||
ffi.cdef[[
|
||||
typedef struct engine_st ENGINE;
|
||||
|
||||
typedef struct evp_cipher_st EVP_CIPHER;
|
||||
typedef struct evp_cipher_ctx_st EVP_CIPHER_CTX;
|
||||
|
||||
typedef struct env_md_ctx_st EVP_MD_CTX;
|
||||
typedef struct env_md_st EVP_MD;
|
||||
|
||||
const EVP_MD *EVP_md5(void);
|
||||
const EVP_MD *EVP_sha(void);
|
||||
const EVP_MD *EVP_sha1(void);
|
||||
const EVP_MD *EVP_sha224(void);
|
||||
const EVP_MD *EVP_sha256(void);
|
||||
const EVP_MD *EVP_sha384(void);
|
||||
const EVP_MD *EVP_sha512(void);
|
||||
|
||||
const EVP_CIPHER *EVP_aes_128_ecb(void);
|
||||
const EVP_CIPHER *EVP_aes_128_cbc(void);
|
||||
const EVP_CIPHER *EVP_aes_128_cfb1(void);
|
||||
const EVP_CIPHER *EVP_aes_128_cfb8(void);
|
||||
const EVP_CIPHER *EVP_aes_128_cfb128(void);
|
||||
const EVP_CIPHER *EVP_aes_128_ofb(void);
|
||||
const EVP_CIPHER *EVP_aes_128_ctr(void);
|
||||
const EVP_CIPHER *EVP_aes_192_ecb(void);
|
||||
const EVP_CIPHER *EVP_aes_192_cbc(void);
|
||||
const EVP_CIPHER *EVP_aes_192_cfb1(void);
|
||||
const EVP_CIPHER *EVP_aes_192_cfb8(void);
|
||||
const EVP_CIPHER *EVP_aes_192_cfb128(void);
|
||||
const EVP_CIPHER *EVP_aes_192_ofb(void);
|
||||
const EVP_CIPHER *EVP_aes_192_ctr(void);
|
||||
const EVP_CIPHER *EVP_aes_256_ecb(void);
|
||||
const EVP_CIPHER *EVP_aes_256_cbc(void);
|
||||
const EVP_CIPHER *EVP_aes_256_cfb1(void);
|
||||
const EVP_CIPHER *EVP_aes_256_cfb8(void);
|
||||
const EVP_CIPHER *EVP_aes_256_cfb128(void);
|
||||
const EVP_CIPHER *EVP_aes_256_ofb(void);
|
||||
const EVP_CIPHER *EVP_aes_128_gcm(void);
|
||||
const EVP_CIPHER *EVP_aes_192_gcm(void);
|
||||
const EVP_CIPHER *EVP_aes_256_gcm(void);
|
||||
|
||||
EVP_CIPHER_CTX *EVP_CIPHER_CTX_new();
|
||||
void EVP_CIPHER_CTX_free(EVP_CIPHER_CTX *a);
|
||||
int EVP_CIPHER_CTX_block_size(const EVP_CIPHER_CTX *ctx);
|
||||
|
||||
int EVP_EncryptInit_ex(EVP_CIPHER_CTX *ctx,const EVP_CIPHER *cipher,
|
||||
ENGINE *impl, unsigned char *key, const unsigned char *iv);
|
||||
|
||||
int EVP_EncryptUpdate(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl,
|
||||
const unsigned char *in, int inl);
|
||||
|
||||
int EVP_EncryptFinal_ex(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl);
|
||||
|
||||
int EVP_DecryptInit_ex(EVP_CIPHER_CTX *ctx,const EVP_CIPHER *cipher,
|
||||
ENGINE *impl, unsigned char *key, const unsigned char *iv);
|
||||
|
||||
int EVP_DecryptUpdate(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl,
|
||||
const unsigned char *in, int inl);
|
||||
|
||||
int EVP_DecryptFinal_ex(EVP_CIPHER_CTX *ctx, unsigned char *outm, int *outl);
|
||||
|
||||
int EVP_BytesToKey(const EVP_CIPHER *type,const EVP_MD *md,
|
||||
const unsigned char *salt, const unsigned char *data, int datal,
|
||||
int count, unsigned char *key,unsigned char *iv);
|
||||
|
||||
int EVP_CIPHER_CTX_ctrl(EVP_CIPHER_CTX *ctx, int type, int arg, void *ptr);
|
||||
]]
|
||||
|
||||
local hash
|
||||
hash = {
|
||||
md5 = C.EVP_md5(),
|
||||
sha1 = C.EVP_sha1(),
|
||||
sha224 = C.EVP_sha224(),
|
||||
sha256 = C.EVP_sha256(),
|
||||
sha384 = C.EVP_sha384(),
|
||||
sha512 = C.EVP_sha512()
|
||||
}
|
||||
_M.hash = hash
|
||||
|
||||
local EVP_MAX_BLOCK_LENGTH = 32
|
||||
|
||||
local cipher
|
||||
cipher = function (size, _cipher)
|
||||
local _size = size or 128
|
||||
local _cipher = _cipher or "cbc"
|
||||
local func = "EVP_aes_" .. _size .. "_" .. _cipher
|
||||
if C[func] then
|
||||
return { size=_size, cipher=_cipher, method=C[func]()}
|
||||
else
|
||||
return nil
|
||||
end
|
||||
end
|
||||
_M.cipher = cipher
|
||||
|
||||
function _M.new(self, key, salt, _cipher, _hash, hash_rounds, iv_len)
|
||||
local encrypt_ctx = C.EVP_CIPHER_CTX_new()
|
||||
if encrypt_ctx == nil then
|
||||
return nil, "no memory"
|
||||
end
|
||||
|
||||
ffi_gc(encrypt_ctx, C.EVP_CIPHER_CTX_free)
|
||||
|
||||
local decrypt_ctx = C.EVP_CIPHER_CTX_new()
|
||||
if decrypt_ctx == nil then
|
||||
return nil, "no memory"
|
||||
end
|
||||
|
||||
ffi_gc(decrypt_ctx, C.EVP_CIPHER_CTX_free)
|
||||
|
||||
local _cipher = _cipher or cipher()
|
||||
local _hash = _hash or hash.md5
|
||||
local hash_rounds = hash_rounds or 1
|
||||
local _cipherLength = _cipher.size/8
|
||||
local gen_key = ffi_new("unsigned char[?]",_cipherLength)
|
||||
local gen_iv = ffi_new("unsigned char[?]",_cipherLength)
|
||||
iv_len = iv_len or _cipherLength
|
||||
|
||||
if type(_hash) == "table" then
|
||||
if not _hash.iv then
|
||||
return nil, "iv is needed"
|
||||
end
|
||||
|
||||
--[[ Depending on the encryption algorithm, the length of iv will be
|
||||
different. For detailed, please refer to
|
||||
https://www.openssl.org/docs/man1.1.0/man3/EVP_CIPHER_CTX_ctrl.html
|
||||
]]
|
||||
iv_len = #_hash.iv
|
||||
if iv_len > _cipherLength then
|
||||
return nil, "bad iv length"
|
||||
end
|
||||
|
||||
if _hash.method then
|
||||
local tmp_key = _hash.method(key)
|
||||
|
||||
if #tmp_key ~= _cipherLength then
|
||||
return nil, "bad key length"
|
||||
end
|
||||
|
||||
ffi_copy(gen_key, tmp_key, _cipherLength)
|
||||
|
||||
elseif #key ~= _cipherLength then
|
||||
return nil, "bad key length"
|
||||
|
||||
else
|
||||
ffi_copy(gen_key, key, _cipherLength)
|
||||
end
|
||||
|
||||
ffi_copy(gen_iv, _hash.iv, iv_len)
|
||||
|
||||
else
|
||||
if salt and #salt ~= 8 then
|
||||
return nil, "salt must be 8 characters or nil"
|
||||
end
|
||||
|
||||
if C.EVP_BytesToKey(_cipher.method, _hash, salt, key, #key,
|
||||
hash_rounds, gen_key, gen_iv)
|
||||
~= _cipherLength
|
||||
then
|
||||
return nil, "failed to generate key and iv"
|
||||
end
|
||||
end
|
||||
|
||||
if C.EVP_EncryptInit_ex(encrypt_ctx, _cipher.method, nil,
|
||||
nil, nil) == 0 or
|
||||
C.EVP_DecryptInit_ex(decrypt_ctx, _cipher.method, nil,
|
||||
nil, nil) == 0 then
|
||||
return nil, "failed to init ctx"
|
||||
end
|
||||
|
||||
local cipher_name = _cipher.cipher
|
||||
if cipher_name == "gcm"
|
||||
or cipher_name == "ccm"
|
||||
or cipher_name == "ocb" then
|
||||
if C.EVP_CIPHER_CTX_ctrl(encrypt_ctx, EVP_CTRL_AEAD_SET_IVLEN,
|
||||
iv_len, nil) == 0 or
|
||||
C.EVP_CIPHER_CTX_ctrl(decrypt_ctx, EVP_CTRL_AEAD_SET_IVLEN,
|
||||
iv_len, nil) == 0 then
|
||||
return nil, "failed to set IV length"
|
||||
end
|
||||
end
|
||||
|
||||
return setmetatable({
|
||||
_encrypt_ctx = encrypt_ctx,
|
||||
_decrypt_ctx = decrypt_ctx,
|
||||
_cipher = _cipher.cipher,
|
||||
_key = gen_key,
|
||||
_iv = gen_iv
|
||||
}, mt)
|
||||
end
|
||||
|
||||
|
||||
function _M.encrypt(self, s)
|
||||
local typ = type(self)
|
||||
if typ ~= "table" then
|
||||
error("bad argument #1 self: table expected, got " .. typ, 2)
|
||||
end
|
||||
|
||||
local s_len = #s
|
||||
local max_len = s_len + 2 * EVP_MAX_BLOCK_LENGTH
|
||||
local buf = ffi_new("unsigned char[?]", max_len)
|
||||
local out_len = ffi_new("int[1]")
|
||||
local tmp_len = ffi_new("int[1]")
|
||||
local ctx = self._encrypt_ctx
|
||||
|
||||
if C.EVP_EncryptInit_ex(ctx, nil, nil, self._key, self._iv) == 0 then
|
||||
return nil, "EVP_EncryptInit_ex failed"
|
||||
end
|
||||
|
||||
if C.EVP_EncryptUpdate(ctx, buf, out_len, s, s_len) == 0 then
|
||||
return nil, "EVP_EncryptUpdate failed"
|
||||
end
|
||||
|
||||
if self._cipher == "gcm" then
|
||||
local encrypt_data = ffi_str(buf, out_len[0])
|
||||
if C.EVP_EncryptFinal_ex(ctx, buf, out_len) == 0 then
|
||||
return nil, "EVP_DecryptFinal_ex failed"
|
||||
end
|
||||
|
||||
-- FIXME: For OCB mode the taglen must either be 16
|
||||
-- or the value previously set via EVP_CTRL_OCB_SET_TAGLEN.
|
||||
-- so we should extend this api in the future
|
||||
C.EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_GET_TAG, 16, buf);
|
||||
local tag = ffi_str(buf, 16)
|
||||
return {encrypt_data, tag}
|
||||
end
|
||||
|
||||
if C.EVP_EncryptFinal_ex(ctx, buf + out_len[0], tmp_len) == 0 then
|
||||
return nil, "EVP_EncryptFinal_ex failed"
|
||||
end
|
||||
|
||||
return ffi_str(buf, out_len[0] + tmp_len[0])
|
||||
end
|
||||
|
||||
|
||||
function _M.decrypt(self, s, tag)
|
||||
local typ = type(self)
|
||||
if typ ~= "table" then
|
||||
error("bad argument #1 self: table expected, got " .. typ, 2)
|
||||
end
|
||||
|
||||
local s_len = #s
|
||||
local max_len = s_len + 2 * EVP_MAX_BLOCK_LENGTH
|
||||
local buf = ffi_new("unsigned char[?]", max_len)
|
||||
local out_len = ffi_new("int[1]")
|
||||
local tmp_len = ffi_new("int[1]")
|
||||
local ctx = self._decrypt_ctx
|
||||
|
||||
if C.EVP_DecryptInit_ex(ctx, nil, nil, self._key, self._iv) == 0 then
|
||||
return nil, "EVP_DecryptInit_ex failed"
|
||||
end
|
||||
|
||||
if C.EVP_DecryptUpdate(ctx, buf, out_len, s, s_len) == 0 then
|
||||
return nil, "EVP_DecryptUpdate failed"
|
||||
end
|
||||
|
||||
if self._cipher == "gcm" then
|
||||
local plain_txt = ffi_str(buf, out_len[0])
|
||||
if tag ~= nil then
|
||||
local tag_buf = ffi_new("unsigned char[?]", 16)
|
||||
ffi.copy(tag_buf, tag, 16)
|
||||
C.EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_SET_TAG, 16, tag_buf);
|
||||
end
|
||||
|
||||
if C.EVP_DecryptFinal_ex(ctx, buf + out_len[0], tmp_len) == 0 then
|
||||
return nil, "EVP_DecryptFinal_ex failed"
|
||||
end
|
||||
|
||||
return plain_txt
|
||||
end
|
||||
|
||||
if C.EVP_DecryptFinal_ex(ctx, buf + out_len[0], tmp_len) == 0 then
|
||||
return nil, "EVP_DecryptFinal_ex failed"
|
||||
end
|
||||
|
||||
return ffi_str(buf, out_len[0] + tmp_len[0])
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
|
||||
35
resty/core.lua
Normal file
35
resty/core.lua
Normal file
@@ -0,0 +1,35 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
|
||||
local subsystem = ngx.config.subsystem
|
||||
|
||||
|
||||
require "resty.core.var"
|
||||
require "resty.core.worker"
|
||||
require "resty.core.regex"
|
||||
require "resty.core.shdict"
|
||||
require "resty.core.time"
|
||||
require "resty.core.hash"
|
||||
require "resty.core.uri"
|
||||
require "resty.core.exit"
|
||||
require "resty.core.base64"
|
||||
require "resty.core.request"
|
||||
|
||||
|
||||
if subsystem == 'http' then
|
||||
require "resty.core.response"
|
||||
require "resty.core.phase"
|
||||
require "resty.core.ndk"
|
||||
require "resty.core.socket"
|
||||
end
|
||||
|
||||
|
||||
require "resty.core.misc"
|
||||
require "resty.core.ctx"
|
||||
|
||||
|
||||
local base = require "resty.core.base"
|
||||
|
||||
|
||||
return {
|
||||
version = base.version
|
||||
}
|
||||
259
resty/core/base.lua
Normal file
259
resty/core/base.lua
Normal file
@@ -0,0 +1,259 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
local ffi = require 'ffi'
|
||||
local ffi_new = ffi.new
|
||||
local error = error
|
||||
local select = select
|
||||
local ceil = math.ceil
|
||||
local subsystem = ngx.config.subsystem
|
||||
|
||||
|
||||
local str_buf_size = 4096
|
||||
local str_buf
|
||||
local size_ptr
|
||||
local FREE_LIST_REF = 0
|
||||
|
||||
|
||||
if subsystem == 'http' then
|
||||
local ngx_lua_v = ngx.config.ngx_lua_version
|
||||
if not ngx.config
|
||||
or not ngx.config.ngx_lua_version
|
||||
or (ngx_lua_v ~= 10019 and ngx_lua_v ~= 10020)
|
||||
then
|
||||
error("ngx_http_lua_module 0.10.19 or 0.10.20 required")
|
||||
end
|
||||
|
||||
elseif subsystem == 'stream' then
|
||||
if not ngx.config
|
||||
or not ngx.config.ngx_lua_version
|
||||
or ngx.config.ngx_lua_version ~= 10
|
||||
then
|
||||
error("ngx_stream_lua_module 0.0.10 required")
|
||||
end
|
||||
|
||||
else
|
||||
error("ngx_http_lua_module 0.10.20 or "
|
||||
.. "ngx_stream_lua_module 0.0.10 required")
|
||||
end
|
||||
|
||||
|
||||
if string.find(jit.version, " 2.0", 1, true) then
|
||||
ngx.log(ngx.ALERT, "use of lua-resty-core with LuaJIT 2.0 is ",
|
||||
"not recommended; use LuaJIT 2.1+ instead")
|
||||
end
|
||||
|
||||
|
||||
local ok, new_tab = pcall(require, "table.new")
|
||||
if not ok then
|
||||
new_tab = function (narr, nrec) return {} end
|
||||
end
|
||||
|
||||
|
||||
local clear_tab
|
||||
ok, clear_tab = pcall(require, "table.clear")
|
||||
if not ok then
|
||||
local pairs = pairs
|
||||
clear_tab = function (tab)
|
||||
for k, _ in pairs(tab) do
|
||||
tab[k] = nil
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
-- XXX for now LuaJIT 2.1 cannot compile require()
|
||||
-- so we make the fast code path Lua only in our own
|
||||
-- wrapper so that most of the require() calls in hot
|
||||
-- Lua code paths can be JIT compiled.
|
||||
do
|
||||
local orig_require = require
|
||||
local pkg_loaded = package.loaded
|
||||
local function my_require(name)
|
||||
local mod = pkg_loaded[name]
|
||||
if mod then
|
||||
return mod
|
||||
end
|
||||
return orig_require(name)
|
||||
end
|
||||
getfenv(0).require = my_require
|
||||
end
|
||||
|
||||
|
||||
if not pcall(ffi.typeof, "ngx_str_t") then
|
||||
ffi.cdef[[
|
||||
typedef struct {
|
||||
size_t len;
|
||||
const unsigned char *data;
|
||||
} ngx_str_t;
|
||||
]]
|
||||
end
|
||||
|
||||
|
||||
if subsystem == 'http' then
|
||||
if not pcall(ffi.typeof, "ngx_http_request_t") then
|
||||
ffi.cdef[[
|
||||
typedef struct ngx_http_request_s ngx_http_request_t;
|
||||
]]
|
||||
end
|
||||
|
||||
if not pcall(ffi.typeof, "ngx_http_lua_ffi_str_t") then
|
||||
ffi.cdef[[
|
||||
typedef struct {
|
||||
int len;
|
||||
const unsigned char *data;
|
||||
} ngx_http_lua_ffi_str_t;
|
||||
]]
|
||||
end
|
||||
|
||||
elseif subsystem == 'stream' then
|
||||
if not pcall(ffi.typeof, "ngx_stream_lua_request_t") then
|
||||
ffi.cdef[[
|
||||
typedef struct ngx_stream_lua_request_s ngx_stream_lua_request_t;
|
||||
]]
|
||||
end
|
||||
|
||||
if not pcall(ffi.typeof, "ngx_stream_lua_ffi_str_t") then
|
||||
ffi.cdef[[
|
||||
typedef struct {
|
||||
int len;
|
||||
const unsigned char *data;
|
||||
} ngx_stream_lua_ffi_str_t;
|
||||
]]
|
||||
end
|
||||
|
||||
else
|
||||
error("unknown subsystem: " .. subsystem)
|
||||
end
|
||||
|
||||
|
||||
local c_buf_type = ffi.typeof("char[?]")
|
||||
|
||||
|
||||
local _M = new_tab(0, 18)
|
||||
|
||||
|
||||
_M.version = "0.1.22"
|
||||
_M.new_tab = new_tab
|
||||
_M.clear_tab = clear_tab
|
||||
|
||||
|
||||
local errmsg
|
||||
|
||||
|
||||
function _M.get_errmsg_ptr()
|
||||
if not errmsg then
|
||||
errmsg = ffi_new("char *[1]")
|
||||
end
|
||||
return errmsg
|
||||
end
|
||||
|
||||
|
||||
if not ngx then
|
||||
error("no existing ngx. table found")
|
||||
end
|
||||
|
||||
|
||||
function _M.set_string_buf_size(size)
|
||||
if size <= 0 then
|
||||
return
|
||||
end
|
||||
if str_buf then
|
||||
str_buf = nil
|
||||
end
|
||||
str_buf_size = ceil(size)
|
||||
end
|
||||
|
||||
|
||||
function _M.get_string_buf_size()
|
||||
return str_buf_size
|
||||
end
|
||||
|
||||
|
||||
function _M.get_size_ptr()
|
||||
if not size_ptr then
|
||||
size_ptr = ffi_new("size_t[1]")
|
||||
end
|
||||
|
||||
return size_ptr
|
||||
end
|
||||
|
||||
|
||||
function _M.get_string_buf(size, must_alloc)
|
||||
-- ngx.log(ngx.ERR, "str buf size: ", str_buf_size)
|
||||
if size > str_buf_size or must_alloc then
|
||||
return ffi_new(c_buf_type, size)
|
||||
end
|
||||
|
||||
if not str_buf then
|
||||
str_buf = ffi_new(c_buf_type, str_buf_size)
|
||||
end
|
||||
|
||||
return str_buf
|
||||
end
|
||||
|
||||
|
||||
function _M.ref_in_table(tb, key)
|
||||
if key == nil then
|
||||
return -1
|
||||
end
|
||||
local ref = tb[FREE_LIST_REF]
|
||||
if ref and ref ~= 0 then
|
||||
tb[FREE_LIST_REF] = tb[ref]
|
||||
|
||||
else
|
||||
ref = #tb + 1
|
||||
end
|
||||
tb[ref] = key
|
||||
|
||||
-- print("ref key_id returned ", ref)
|
||||
return ref
|
||||
end
|
||||
|
||||
|
||||
function _M.allows_subsystem(...)
|
||||
local total = select("#", ...)
|
||||
|
||||
for i = 1, total do
|
||||
if select(i, ...) == subsystem then
|
||||
return
|
||||
end
|
||||
end
|
||||
|
||||
error("unsupported subsystem: " .. subsystem, 2)
|
||||
end
|
||||
|
||||
|
||||
_M.FFI_OK = 0
|
||||
_M.FFI_NO_REQ_CTX = -100
|
||||
_M.FFI_BAD_CONTEXT = -101
|
||||
_M.FFI_ERROR = -1
|
||||
_M.FFI_AGAIN = -2
|
||||
_M.FFI_BUSY = -3
|
||||
_M.FFI_DONE = -4
|
||||
_M.FFI_DECLINED = -5
|
||||
|
||||
|
||||
do
|
||||
local exdata
|
||||
|
||||
ok, exdata = pcall(require, "thread.exdata")
|
||||
if ok and exdata then
|
||||
function _M.get_request()
|
||||
local r = exdata()
|
||||
if r ~= nil then
|
||||
return r
|
||||
end
|
||||
end
|
||||
|
||||
else
|
||||
local getfenv = getfenv
|
||||
|
||||
function _M.get_request()
|
||||
return getfenv(0).__ngx_req
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
115
resty/core/base64.lua
Normal file
115
resty/core/base64.lua
Normal file
@@ -0,0 +1,115 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
local ffi = require "ffi"
|
||||
local base = require "resty.core.base"
|
||||
|
||||
|
||||
local C = ffi.C
|
||||
local ffi_string = ffi.string
|
||||
local ngx = ngx
|
||||
local type = type
|
||||
local error = error
|
||||
local floor = math.floor
|
||||
local tostring = tostring
|
||||
local get_string_buf = base.get_string_buf
|
||||
local get_size_ptr = base.get_size_ptr
|
||||
local subsystem = ngx.config.subsystem
|
||||
|
||||
|
||||
local ngx_lua_ffi_encode_base64
|
||||
local ngx_lua_ffi_decode_base64
|
||||
|
||||
|
||||
if subsystem == "http" then
|
||||
ffi.cdef[[
|
||||
size_t ngx_http_lua_ffi_encode_base64(const unsigned char *src,
|
||||
size_t len, unsigned char *dst,
|
||||
int no_padding);
|
||||
|
||||
int ngx_http_lua_ffi_decode_base64(const unsigned char *src,
|
||||
size_t len, unsigned char *dst,
|
||||
size_t *dlen);
|
||||
]]
|
||||
|
||||
ngx_lua_ffi_encode_base64 = C.ngx_http_lua_ffi_encode_base64
|
||||
ngx_lua_ffi_decode_base64 = C.ngx_http_lua_ffi_decode_base64
|
||||
|
||||
elseif subsystem == "stream" then
|
||||
ffi.cdef[[
|
||||
size_t ngx_stream_lua_ffi_encode_base64(const unsigned char *src,
|
||||
size_t len, unsigned char *dst,
|
||||
int no_padding);
|
||||
|
||||
int ngx_stream_lua_ffi_decode_base64(const unsigned char *src,
|
||||
size_t len, unsigned char *dst,
|
||||
size_t *dlen);
|
||||
]]
|
||||
|
||||
ngx_lua_ffi_encode_base64 = C.ngx_stream_lua_ffi_encode_base64
|
||||
ngx_lua_ffi_decode_base64 = C.ngx_stream_lua_ffi_decode_base64
|
||||
end
|
||||
|
||||
|
||||
local function base64_encoded_length(len, no_padding)
|
||||
return no_padding and floor((len * 8 + 5) / 6) or
|
||||
floor((len + 2) / 3) * 4
|
||||
end
|
||||
|
||||
|
||||
ngx.encode_base64 = function (s, no_padding)
|
||||
if type(s) ~= 'string' then
|
||||
if not s then
|
||||
s = ''
|
||||
else
|
||||
s = tostring(s)
|
||||
end
|
||||
end
|
||||
|
||||
local slen = #s
|
||||
local no_padding_bool = false;
|
||||
local no_padding_int = 0;
|
||||
|
||||
if no_padding then
|
||||
if no_padding ~= true then
|
||||
local typ = type(no_padding)
|
||||
error("bad no_padding: boolean expected, got " .. typ, 2)
|
||||
end
|
||||
|
||||
no_padding_bool = true
|
||||
no_padding_int = 1;
|
||||
end
|
||||
|
||||
local dlen = base64_encoded_length(slen, no_padding_bool)
|
||||
local dst = get_string_buf(dlen)
|
||||
local r_dlen = ngx_lua_ffi_encode_base64(s, slen, dst, no_padding_int)
|
||||
-- if dlen ~= r_dlen then error("discrepancy in len") end
|
||||
return ffi_string(dst, r_dlen)
|
||||
end
|
||||
|
||||
|
||||
local function base64_decoded_length(len)
|
||||
return floor((len + 3) / 4) * 3
|
||||
end
|
||||
|
||||
|
||||
ngx.decode_base64 = function (s)
|
||||
if type(s) ~= 'string' then
|
||||
error("string argument only", 2)
|
||||
end
|
||||
local slen = #s
|
||||
local dlen = base64_decoded_length(slen)
|
||||
-- print("dlen: ", tonumber(dlen))
|
||||
local dst = get_string_buf(dlen)
|
||||
local pdlen = get_size_ptr()
|
||||
local ok = ngx_lua_ffi_decode_base64(s, slen, dst, pdlen)
|
||||
if ok == 0 then
|
||||
return nil
|
||||
end
|
||||
return ffi_string(dst, pdlen[0])
|
||||
end
|
||||
|
||||
|
||||
return {
|
||||
version = base.version
|
||||
}
|
||||
143
resty/core/ctx.lua
Normal file
143
resty/core/ctx.lua
Normal file
@@ -0,0 +1,143 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
local ffi = require "ffi"
|
||||
local debug = require "debug"
|
||||
local base = require "resty.core.base"
|
||||
local misc = require "resty.core.misc"
|
||||
|
||||
|
||||
local C = ffi.C
|
||||
local register_getter = misc.register_ngx_magic_key_getter
|
||||
local register_setter = misc.register_ngx_magic_key_setter
|
||||
local registry = debug.getregistry()
|
||||
local new_tab = base.new_tab
|
||||
local ref_in_table = base.ref_in_table
|
||||
local get_request = base.get_request
|
||||
local FFI_NO_REQ_CTX = base.FFI_NO_REQ_CTX
|
||||
local FFI_OK = base.FFI_OK
|
||||
local error = error
|
||||
local setmetatable = setmetatable
|
||||
local type = type
|
||||
local subsystem = ngx.config.subsystem
|
||||
|
||||
|
||||
local ngx_lua_ffi_get_ctx_ref
|
||||
local ngx_lua_ffi_set_ctx_ref
|
||||
|
||||
|
||||
if subsystem == "http" then
|
||||
ffi.cdef[[
|
||||
int ngx_http_lua_ffi_get_ctx_ref(ngx_http_request_t *r, int *in_ssl_phase,
|
||||
int *ssl_ctx_ref);
|
||||
int ngx_http_lua_ffi_set_ctx_ref(ngx_http_request_t *r, int ref);
|
||||
]]
|
||||
|
||||
ngx_lua_ffi_get_ctx_ref = C.ngx_http_lua_ffi_get_ctx_ref
|
||||
ngx_lua_ffi_set_ctx_ref = C.ngx_http_lua_ffi_set_ctx_ref
|
||||
|
||||
elseif subsystem == "stream" then
|
||||
ffi.cdef[[
|
||||
int ngx_stream_lua_ffi_get_ctx_ref(ngx_stream_lua_request_t *r,
|
||||
int *in_ssl_phase, int *ssl_ctx_ref);
|
||||
int ngx_stream_lua_ffi_set_ctx_ref(ngx_stream_lua_request_t *r, int ref);
|
||||
]]
|
||||
|
||||
ngx_lua_ffi_get_ctx_ref = C.ngx_stream_lua_ffi_get_ctx_ref
|
||||
ngx_lua_ffi_set_ctx_ref = C.ngx_stream_lua_ffi_set_ctx_ref
|
||||
end
|
||||
|
||||
|
||||
local _M = {
|
||||
_VERSION = base.version
|
||||
}
|
||||
|
||||
|
||||
local get_ctx_table
|
||||
do
|
||||
local in_ssl_phase = ffi.new("int[1]")
|
||||
local ssl_ctx_ref = ffi.new("int[1]")
|
||||
|
||||
function get_ctx_table(ctx)
|
||||
local r = get_request()
|
||||
|
||||
if not r then
|
||||
error("no request found")
|
||||
end
|
||||
|
||||
local ctx_ref = ngx_lua_ffi_get_ctx_ref(r, in_ssl_phase, ssl_ctx_ref)
|
||||
if ctx_ref == FFI_NO_REQ_CTX then
|
||||
error("no request ctx found")
|
||||
end
|
||||
|
||||
local ctxs = registry.ngx_lua_ctx_tables
|
||||
if ctx_ref < 0 then
|
||||
ctx_ref = ssl_ctx_ref[0]
|
||||
if ctx_ref > 0 and ctxs[ctx_ref] then
|
||||
if in_ssl_phase[0] ~= 0 then
|
||||
return ctxs[ctx_ref]
|
||||
end
|
||||
|
||||
if not ctx then
|
||||
ctx = new_tab(0, 4)
|
||||
end
|
||||
|
||||
ctx = setmetatable(ctx, ctxs[ctx_ref])
|
||||
|
||||
else
|
||||
if in_ssl_phase[0] ~= 0 then
|
||||
if not ctx then
|
||||
ctx = new_tab(1, 4)
|
||||
end
|
||||
|
||||
-- to avoid creating another table, we assume the users
|
||||
-- won't overwrite the `__index` key
|
||||
ctx.__index = ctx
|
||||
|
||||
elseif not ctx then
|
||||
ctx = new_tab(0, 4)
|
||||
end
|
||||
end
|
||||
|
||||
ctx_ref = ref_in_table(ctxs, ctx)
|
||||
if ngx_lua_ffi_set_ctx_ref(r, ctx_ref) ~= FFI_OK then
|
||||
return nil
|
||||
end
|
||||
return ctx
|
||||
end
|
||||
return ctxs[ctx_ref]
|
||||
end
|
||||
end
|
||||
register_getter("ctx", get_ctx_table)
|
||||
_M.get_ctx_table = get_ctx_table
|
||||
|
||||
|
||||
local function set_ctx_table(ctx)
|
||||
local ctx_type = type(ctx)
|
||||
if ctx_type ~= "table" then
|
||||
error("ctx should be a table while getting a " .. ctx_type)
|
||||
end
|
||||
|
||||
local r = get_request()
|
||||
|
||||
if not r then
|
||||
error("no request found")
|
||||
end
|
||||
|
||||
local ctx_ref = ngx_lua_ffi_get_ctx_ref(r, nil, nil)
|
||||
if ctx_ref == FFI_NO_REQ_CTX then
|
||||
error("no request ctx found")
|
||||
end
|
||||
|
||||
local ctxs = registry.ngx_lua_ctx_tables
|
||||
if ctx_ref < 0 then
|
||||
ctx_ref = ref_in_table(ctxs, ctx)
|
||||
ngx_lua_ffi_set_ctx_ref(r, ctx_ref)
|
||||
return
|
||||
end
|
||||
ctxs[ctx_ref] = ctx
|
||||
end
|
||||
register_setter("ctx", set_ctx_table)
|
||||
|
||||
|
||||
return _M
|
||||
66
resty/core/exit.lua
Normal file
66
resty/core/exit.lua
Normal file
@@ -0,0 +1,66 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
local ffi = require "ffi"
|
||||
local base = require "resty.core.base"
|
||||
|
||||
|
||||
local C = ffi.C
|
||||
local ffi_string = ffi.string
|
||||
local ngx = ngx
|
||||
local error = error
|
||||
local get_string_buf = base.get_string_buf
|
||||
local get_size_ptr = base.get_size_ptr
|
||||
local get_request = base.get_request
|
||||
local co_yield = coroutine._yield
|
||||
local subsystem = ngx.config.subsystem
|
||||
|
||||
|
||||
local ngx_lua_ffi_exit
|
||||
|
||||
|
||||
if subsystem == "http" then
|
||||
ffi.cdef[[
|
||||
int ngx_http_lua_ffi_exit(ngx_http_request_t *r, int status,
|
||||
unsigned char *err, size_t *errlen);
|
||||
]]
|
||||
|
||||
ngx_lua_ffi_exit = C.ngx_http_lua_ffi_exit
|
||||
|
||||
elseif subsystem == "stream" then
|
||||
ffi.cdef[[
|
||||
int ngx_stream_lua_ffi_exit(ngx_stream_lua_request_t *r, int status,
|
||||
unsigned char *err, size_t *errlen);
|
||||
]]
|
||||
|
||||
ngx_lua_ffi_exit = C.ngx_stream_lua_ffi_exit
|
||||
end
|
||||
|
||||
|
||||
local ERR_BUF_SIZE = 128
|
||||
local FFI_DONE = base.FFI_DONE
|
||||
|
||||
|
||||
ngx.exit = function (rc)
|
||||
local err = get_string_buf(ERR_BUF_SIZE)
|
||||
local errlen = get_size_ptr()
|
||||
local r = get_request()
|
||||
if r == nil then
|
||||
error("no request found")
|
||||
end
|
||||
errlen[0] = ERR_BUF_SIZE
|
||||
rc = ngx_lua_ffi_exit(r, rc, err, errlen)
|
||||
if rc == 0 then
|
||||
-- print("yielding...")
|
||||
return co_yield()
|
||||
end
|
||||
if rc == FFI_DONE then
|
||||
return
|
||||
end
|
||||
error(ffi_string(err, errlen[0]), 2)
|
||||
end
|
||||
|
||||
|
||||
return {
|
||||
version = base.version
|
||||
}
|
||||
154
resty/core/hash.lua
Normal file
154
resty/core/hash.lua
Normal file
@@ -0,0 +1,154 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
local ffi = require "ffi"
|
||||
local base = require "resty.core.base"
|
||||
|
||||
|
||||
local C = ffi.C
|
||||
local ffi_new = ffi.new
|
||||
local ffi_string = ffi.string
|
||||
local ngx = ngx
|
||||
local type = type
|
||||
local error = error
|
||||
local tostring = tostring
|
||||
local subsystem = ngx.config.subsystem
|
||||
|
||||
|
||||
local ngx_lua_ffi_md5
|
||||
local ngx_lua_ffi_md5_bin
|
||||
local ngx_lua_ffi_sha1_bin
|
||||
local ngx_lua_ffi_crc32_long
|
||||
local ngx_lua_ffi_crc32_short
|
||||
|
||||
|
||||
if subsystem == "http" then
|
||||
ffi.cdef[[
|
||||
void ngx_http_lua_ffi_md5_bin(const unsigned char *src, size_t len,
|
||||
unsigned char *dst);
|
||||
|
||||
void ngx_http_lua_ffi_md5(const unsigned char *src, size_t len,
|
||||
unsigned char *dst);
|
||||
|
||||
int ngx_http_lua_ffi_sha1_bin(const unsigned char *src, size_t len,
|
||||
unsigned char *dst);
|
||||
|
||||
unsigned int ngx_http_lua_ffi_crc32_long(const unsigned char *src,
|
||||
size_t len);
|
||||
|
||||
unsigned int ngx_http_lua_ffi_crc32_short(const unsigned char *src,
|
||||
size_t len);
|
||||
]]
|
||||
|
||||
ngx_lua_ffi_md5 = C.ngx_http_lua_ffi_md5
|
||||
ngx_lua_ffi_md5_bin = C.ngx_http_lua_ffi_md5_bin
|
||||
ngx_lua_ffi_sha1_bin = C.ngx_http_lua_ffi_sha1_bin
|
||||
ngx_lua_ffi_crc32_short = C.ngx_http_lua_ffi_crc32_short
|
||||
ngx_lua_ffi_crc32_long = C.ngx_http_lua_ffi_crc32_long
|
||||
|
||||
elseif subsystem == "stream" then
|
||||
ffi.cdef[[
|
||||
void ngx_stream_lua_ffi_md5_bin(const unsigned char *src, size_t len,
|
||||
unsigned char *dst);
|
||||
|
||||
void ngx_stream_lua_ffi_md5(const unsigned char *src, size_t len,
|
||||
unsigned char *dst);
|
||||
|
||||
int ngx_stream_lua_ffi_sha1_bin(const unsigned char *src, size_t len,
|
||||
unsigned char *dst);
|
||||
|
||||
unsigned int ngx_stream_lua_ffi_crc32_long(const unsigned char *src,
|
||||
size_t len);
|
||||
|
||||
unsigned int ngx_stream_lua_ffi_crc32_short(const unsigned char *src,
|
||||
size_t len);
|
||||
]]
|
||||
|
||||
ngx_lua_ffi_md5 = C.ngx_stream_lua_ffi_md5
|
||||
ngx_lua_ffi_md5_bin = C.ngx_stream_lua_ffi_md5_bin
|
||||
ngx_lua_ffi_sha1_bin = C.ngx_stream_lua_ffi_sha1_bin
|
||||
ngx_lua_ffi_crc32_short = C.ngx_stream_lua_ffi_crc32_short
|
||||
ngx_lua_ffi_crc32_long = C.ngx_stream_lua_ffi_crc32_long
|
||||
end
|
||||
|
||||
|
||||
local MD5_DIGEST_LEN = 16
|
||||
local md5_buf = ffi_new("unsigned char[?]", MD5_DIGEST_LEN)
|
||||
|
||||
ngx.md5_bin = function (s)
|
||||
if type(s) ~= 'string' then
|
||||
if not s then
|
||||
s = ''
|
||||
else
|
||||
s = tostring(s)
|
||||
end
|
||||
end
|
||||
ngx_lua_ffi_md5_bin(s, #s, md5_buf)
|
||||
return ffi_string(md5_buf, MD5_DIGEST_LEN)
|
||||
end
|
||||
|
||||
|
||||
local MD5_HEX_DIGEST_LEN = MD5_DIGEST_LEN * 2
|
||||
local md5_hex_buf = ffi_new("unsigned char[?]", MD5_HEX_DIGEST_LEN)
|
||||
|
||||
ngx.md5 = function (s)
|
||||
if type(s) ~= 'string' then
|
||||
if not s then
|
||||
s = ''
|
||||
else
|
||||
s = tostring(s)
|
||||
end
|
||||
end
|
||||
ngx_lua_ffi_md5(s, #s, md5_hex_buf)
|
||||
return ffi_string(md5_hex_buf, MD5_HEX_DIGEST_LEN)
|
||||
end
|
||||
|
||||
|
||||
local SHA_DIGEST_LEN = 20
|
||||
local sha_buf = ffi_new("unsigned char[?]", SHA_DIGEST_LEN)
|
||||
|
||||
ngx.sha1_bin = function (s)
|
||||
if type(s) ~= 'string' then
|
||||
if not s then
|
||||
s = ''
|
||||
else
|
||||
s = tostring(s)
|
||||
end
|
||||
end
|
||||
local ok = ngx_lua_ffi_sha1_bin(s, #s, sha_buf)
|
||||
if ok == 0 then
|
||||
error("SHA-1 support missing in Nginx")
|
||||
end
|
||||
return ffi_string(sha_buf, SHA_DIGEST_LEN)
|
||||
end
|
||||
|
||||
|
||||
ngx.crc32_short = function (s)
|
||||
if type(s) ~= "string" then
|
||||
if not s then
|
||||
s = ""
|
||||
else
|
||||
s = tostring(s)
|
||||
end
|
||||
end
|
||||
|
||||
return ngx_lua_ffi_crc32_short(s, #s)
|
||||
end
|
||||
|
||||
|
||||
ngx.crc32_long = function (s)
|
||||
if type(s) ~= "string" then
|
||||
if not s then
|
||||
s = ""
|
||||
else
|
||||
s = tostring(s)
|
||||
end
|
||||
end
|
||||
|
||||
return ngx_lua_ffi_crc32_long(s, #s)
|
||||
end
|
||||
|
||||
|
||||
return {
|
||||
version = base.version
|
||||
}
|
||||
240
resty/core/misc.lua
Normal file
240
resty/core/misc.lua
Normal file
@@ -0,0 +1,240 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
local base = require "resty.core.base"
|
||||
local ffi = require "ffi"
|
||||
local os = require "os"
|
||||
|
||||
|
||||
local C = ffi.C
|
||||
local ffi_new = ffi.new
|
||||
local ffi_str = ffi.string
|
||||
local ngx = ngx
|
||||
local type = type
|
||||
local error = error
|
||||
local rawget = rawget
|
||||
local rawset = rawset
|
||||
local tonumber = tonumber
|
||||
local setmetatable = setmetatable
|
||||
local FFI_OK = base.FFI_OK
|
||||
local FFI_NO_REQ_CTX = base.FFI_NO_REQ_CTX
|
||||
local FFI_BAD_CONTEXT = base.FFI_BAD_CONTEXT
|
||||
local new_tab = base.new_tab
|
||||
local get_request = base.get_request
|
||||
local get_size_ptr = base.get_size_ptr
|
||||
local get_string_buf = base.get_string_buf
|
||||
local get_string_buf_size = base.get_string_buf_size
|
||||
local subsystem = ngx.config.subsystem
|
||||
|
||||
|
||||
local ngx_lua_ffi_get_resp_status
|
||||
local ngx_lua_ffi_get_conf_env
|
||||
local ngx_magic_key_getters
|
||||
local ngx_magic_key_setters
|
||||
|
||||
|
||||
local _M = new_tab(0, 3)
|
||||
local ngx_mt = new_tab(0, 2)
|
||||
|
||||
|
||||
if subsystem == "http" then
|
||||
ngx_magic_key_getters = new_tab(0, 4)
|
||||
ngx_magic_key_setters = new_tab(0, 2)
|
||||
|
||||
elseif subsystem == "stream" then
|
||||
ngx_magic_key_getters = new_tab(0, 2)
|
||||
ngx_magic_key_setters = new_tab(0, 1)
|
||||
end
|
||||
|
||||
|
||||
local function register_getter(key, func)
|
||||
ngx_magic_key_getters[key] = func
|
||||
end
|
||||
_M.register_ngx_magic_key_getter = register_getter
|
||||
|
||||
|
||||
local function register_setter(key, func)
|
||||
ngx_magic_key_setters[key] = func
|
||||
end
|
||||
_M.register_ngx_magic_key_setter = register_setter
|
||||
|
||||
|
||||
ngx_mt.__index = function (tb, key)
|
||||
local f = ngx_magic_key_getters[key]
|
||||
if f then
|
||||
return f()
|
||||
end
|
||||
return rawget(tb, key)
|
||||
end
|
||||
|
||||
|
||||
ngx_mt.__newindex = function (tb, key, ctx)
|
||||
local f = ngx_magic_key_setters[key]
|
||||
if f then
|
||||
return f(ctx)
|
||||
end
|
||||
return rawset(tb, key, ctx)
|
||||
end
|
||||
|
||||
|
||||
setmetatable(ngx, ngx_mt)
|
||||
|
||||
|
||||
if subsystem == "http" then
|
||||
ffi.cdef[[
|
||||
int ngx_http_lua_ffi_get_resp_status(ngx_http_request_t *r);
|
||||
int ngx_http_lua_ffi_set_resp_status(ngx_http_request_t *r, int r);
|
||||
int ngx_http_lua_ffi_is_subrequest(ngx_http_request_t *r);
|
||||
int ngx_http_lua_ffi_headers_sent(ngx_http_request_t *r);
|
||||
int ngx_http_lua_ffi_get_conf_env(const unsigned char *name,
|
||||
unsigned char **env_buf,
|
||||
size_t *name_len);
|
||||
]]
|
||||
|
||||
|
||||
ngx_lua_ffi_get_resp_status = C.ngx_http_lua_ffi_get_resp_status
|
||||
ngx_lua_ffi_get_conf_env = C.ngx_http_lua_ffi_get_conf_env
|
||||
|
||||
|
||||
-- ngx.status
|
||||
|
||||
|
||||
local function set_status(status)
|
||||
local r = get_request()
|
||||
|
||||
if not r then
|
||||
error("no request found")
|
||||
end
|
||||
|
||||
if type(status) ~= 'number' then
|
||||
status = tonumber(status)
|
||||
end
|
||||
|
||||
local rc = C.ngx_http_lua_ffi_set_resp_status(r, status)
|
||||
|
||||
if rc == FFI_BAD_CONTEXT then
|
||||
error("API disabled in the current context", 2)
|
||||
end
|
||||
|
||||
return
|
||||
end
|
||||
register_setter("status", set_status)
|
||||
|
||||
|
||||
-- ngx.is_subrequest
|
||||
|
||||
|
||||
local function is_subreq()
|
||||
local r = get_request()
|
||||
|
||||
if not r then
|
||||
error("no request found")
|
||||
end
|
||||
|
||||
local rc = C.ngx_http_lua_ffi_is_subrequest(r)
|
||||
|
||||
if rc == FFI_BAD_CONTEXT then
|
||||
error("API disabled in the current context", 2)
|
||||
end
|
||||
|
||||
return rc == 1
|
||||
end
|
||||
register_getter("is_subrequest", is_subreq)
|
||||
|
||||
|
||||
-- ngx.headers_sent
|
||||
|
||||
|
||||
local function headers_sent()
|
||||
local r = get_request()
|
||||
|
||||
if not r then
|
||||
error("no request found")
|
||||
end
|
||||
|
||||
local rc = C.ngx_http_lua_ffi_headers_sent(r)
|
||||
|
||||
if rc == FFI_NO_REQ_CTX then
|
||||
error("no request ctx found")
|
||||
end
|
||||
|
||||
if rc == FFI_BAD_CONTEXT then
|
||||
error("API disabled in the current context", 2)
|
||||
end
|
||||
|
||||
return rc == 1
|
||||
end
|
||||
register_getter("headers_sent", headers_sent)
|
||||
|
||||
elseif subsystem == "stream" then
|
||||
ffi.cdef[[
|
||||
int ngx_stream_lua_ffi_get_resp_status(ngx_stream_lua_request_t *r);
|
||||
int ngx_stream_lua_ffi_get_conf_env(const unsigned char *name,
|
||||
unsigned char **env_buf,
|
||||
size_t *name_len);
|
||||
]]
|
||||
|
||||
ngx_lua_ffi_get_resp_status = C.ngx_stream_lua_ffi_get_resp_status
|
||||
ngx_lua_ffi_get_conf_env = C.ngx_stream_lua_ffi_get_conf_env
|
||||
end
|
||||
|
||||
|
||||
-- ngx.status
|
||||
|
||||
|
||||
local function get_status()
|
||||
local r = get_request()
|
||||
|
||||
if not r then
|
||||
error("no request found")
|
||||
end
|
||||
|
||||
local rc = ngx_lua_ffi_get_resp_status(r)
|
||||
|
||||
if rc == FFI_BAD_CONTEXT then
|
||||
error("API disabled in the current context", 2)
|
||||
end
|
||||
|
||||
return rc
|
||||
end
|
||||
register_getter("status", get_status)
|
||||
|
||||
|
||||
do
|
||||
local _getenv = os.getenv
|
||||
local env_ptr = ffi_new("unsigned char *[1]")
|
||||
|
||||
os.getenv = function (name)
|
||||
local r = get_request()
|
||||
if r then
|
||||
-- past init_by_lua* phase now
|
||||
os.getenv = _getenv
|
||||
env_ptr = nil
|
||||
return os.getenv(name)
|
||||
end
|
||||
|
||||
local size = get_string_buf_size()
|
||||
env_ptr[0] = get_string_buf(size)
|
||||
local name_len_ptr = get_size_ptr()
|
||||
|
||||
local rc = ngx_lua_ffi_get_conf_env(name, env_ptr, name_len_ptr)
|
||||
if rc == FFI_OK then
|
||||
return ffi_str(env_ptr[0] + name_len_ptr[0] + 1)
|
||||
end
|
||||
|
||||
-- FFI_DECLINED
|
||||
|
||||
local value = _getenv(name)
|
||||
if value ~= nil then
|
||||
return value
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
_M._VERSION = base.version
|
||||
|
||||
|
||||
return _M
|
||||
92
resty/core/ndk.lua
Normal file
92
resty/core/ndk.lua
Normal file
@@ -0,0 +1,92 @@
|
||||
-- Copyright (C) by OpenResty Inc.
|
||||
|
||||
|
||||
local ffi = require 'ffi'
|
||||
local base = require "resty.core.base"
|
||||
base.allows_subsystem('http')
|
||||
|
||||
|
||||
local C = ffi.C
|
||||
local ffi_cast = ffi.cast
|
||||
local ffi_new = ffi.new
|
||||
local ffi_str = ffi.string
|
||||
local FFI_OK = base.FFI_OK
|
||||
local new_tab = base.new_tab
|
||||
local get_string_buf = base.get_string_buf
|
||||
local get_request = base.get_request
|
||||
local setmetatable = setmetatable
|
||||
local type = type
|
||||
local tostring = tostring
|
||||
local error = error
|
||||
|
||||
|
||||
local _M = {
|
||||
version = base.version
|
||||
}
|
||||
|
||||
|
||||
ffi.cdef[[
|
||||
typedef void * ndk_set_var_value_pt;
|
||||
|
||||
int ngx_http_lua_ffi_ndk_lookup_directive(const unsigned char *var_data,
|
||||
size_t var_len, ndk_set_var_value_pt *func);
|
||||
int ngx_http_lua_ffi_ndk_set_var_get(ngx_http_request_t *r,
|
||||
ndk_set_var_value_pt func, const unsigned char *arg_data, size_t arg_len,
|
||||
ngx_http_lua_ffi_str_t *value);
|
||||
]]
|
||||
|
||||
|
||||
local func_p = ffi_new("void*[1]")
|
||||
local ffi_str_size = ffi.sizeof("ngx_http_lua_ffi_str_t")
|
||||
local ffi_str_type = ffi.typeof("ngx_http_lua_ffi_str_t*")
|
||||
|
||||
|
||||
local function ndk_set_var_get(self, var)
|
||||
if type(var) ~= "string" then
|
||||
var = tostring(var)
|
||||
end
|
||||
|
||||
if C.ngx_http_lua_ffi_ndk_lookup_directive(var, #var, func_p) ~= FFI_OK then
|
||||
error('ndk.set_var: directive "' .. var
|
||||
.. '" not found or does not use ndk_set_var_value', 2)
|
||||
end
|
||||
|
||||
local func = func_p[0]
|
||||
|
||||
return function (arg)
|
||||
local r = get_request()
|
||||
if not r then
|
||||
error("no request found")
|
||||
end
|
||||
|
||||
if type(arg) ~= "string" then
|
||||
arg = tostring(arg)
|
||||
end
|
||||
|
||||
local buf = get_string_buf(ffi_str_size)
|
||||
local value = ffi_cast(ffi_str_type, buf)
|
||||
local rc = C.ngx_http_lua_ffi_ndk_set_var_get(r, func, arg, #arg, value)
|
||||
if rc ~= FFI_OK then
|
||||
error("calling directive " .. var .. " failed with code " .. rc, 2)
|
||||
end
|
||||
|
||||
return ffi_str(value.data, value.len)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
local function ndk_set_var_set()
|
||||
error("not allowed", 2)
|
||||
end
|
||||
|
||||
|
||||
if ndk then
|
||||
local mt = new_tab(0, 2)
|
||||
mt.__newindex = ndk_set_var_set
|
||||
mt.__index = ndk_set_var_get
|
||||
|
||||
ndk.set_var = setmetatable(new_tab(0, 0), mt)
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
59
resty/core/phase.lua
Normal file
59
resty/core/phase.lua
Normal file
@@ -0,0 +1,59 @@
|
||||
local ffi = require 'ffi'
|
||||
local base = require "resty.core.base"
|
||||
|
||||
local C = ffi.C
|
||||
local FFI_ERROR = base.FFI_ERROR
|
||||
local get_request = base.get_request
|
||||
local error = error
|
||||
local tostring = tostring
|
||||
|
||||
|
||||
ffi.cdef[[
|
||||
int ngx_http_lua_ffi_get_phase(ngx_http_request_t *r, char **err)
|
||||
]]
|
||||
|
||||
|
||||
local errmsg = base.get_errmsg_ptr()
|
||||
local context_names = {
|
||||
[0x0001] = "set",
|
||||
[0x0002] = "rewrite",
|
||||
[0x0004] = "access",
|
||||
[0x0008] = "content",
|
||||
[0x0010] = "log",
|
||||
[0x0020] = "header_filter",
|
||||
[0x0040] = "body_filter",
|
||||
[0x0080] = "timer",
|
||||
[0x0100] = "init_worker",
|
||||
[0x0200] = "balancer",
|
||||
[0x0400] = "ssl_cert",
|
||||
[0x0800] = "ssl_session_store",
|
||||
[0x1000] = "ssl_session_fetch",
|
||||
[0x2000] = "exit_worker",
|
||||
}
|
||||
|
||||
|
||||
function ngx.get_phase()
|
||||
local r = get_request()
|
||||
|
||||
-- if we have no request object, assume we are called from the "init" phase
|
||||
if not r then
|
||||
return "init"
|
||||
end
|
||||
|
||||
local context = C.ngx_http_lua_ffi_get_phase(r, errmsg)
|
||||
if context == FFI_ERROR then -- NGX_ERROR
|
||||
error(errmsg, 2)
|
||||
end
|
||||
|
||||
local phase = context_names[context]
|
||||
if not phase then
|
||||
error("unknown phase: " .. tostring(context))
|
||||
end
|
||||
|
||||
return phase
|
||||
end
|
||||
|
||||
|
||||
return {
|
||||
version = base.version
|
||||
}
|
||||
1213
resty/core/regex.lua
Normal file
1213
resty/core/regex.lua
Normal file
File diff suppressed because it is too large
Load Diff
451
resty/core/request.lua
Normal file
451
resty/core/request.lua
Normal file
@@ -0,0 +1,451 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
local ffi = require 'ffi'
|
||||
local base = require "resty.core.base"
|
||||
local utils = require "resty.core.utils"
|
||||
|
||||
|
||||
local subsystem = ngx.config.subsystem
|
||||
local FFI_BAD_CONTEXT = base.FFI_BAD_CONTEXT
|
||||
local FFI_DECLINED = base.FFI_DECLINED
|
||||
local FFI_OK = base.FFI_OK
|
||||
local new_tab = base.new_tab
|
||||
local C = ffi.C
|
||||
local ffi_cast = ffi.cast
|
||||
local ffi_new = ffi.new
|
||||
local ffi_str = ffi.string
|
||||
local get_string_buf = base.get_string_buf
|
||||
local get_size_ptr = base.get_size_ptr
|
||||
local setmetatable = setmetatable
|
||||
local lower = string.lower
|
||||
local rawget = rawget
|
||||
local ngx = ngx
|
||||
local get_request = base.get_request
|
||||
local type = type
|
||||
local error = error
|
||||
local tostring = tostring
|
||||
local tonumber = tonumber
|
||||
local str_replace_char = utils.str_replace_char
|
||||
|
||||
|
||||
local _M = {
|
||||
version = base.version
|
||||
}
|
||||
|
||||
|
||||
local ngx_lua_ffi_req_start_time
|
||||
|
||||
|
||||
if subsystem == "stream" then
|
||||
ffi.cdef[[
|
||||
double ngx_stream_lua_ffi_req_start_time(ngx_stream_lua_request_t *r);
|
||||
]]
|
||||
|
||||
ngx_lua_ffi_req_start_time = C.ngx_stream_lua_ffi_req_start_time
|
||||
|
||||
elseif subsystem == "http" then
|
||||
ffi.cdef[[
|
||||
double ngx_http_lua_ffi_req_start_time(ngx_http_request_t *r);
|
||||
]]
|
||||
|
||||
ngx_lua_ffi_req_start_time = C.ngx_http_lua_ffi_req_start_time
|
||||
end
|
||||
|
||||
|
||||
function ngx.req.start_time()
|
||||
local r = get_request()
|
||||
if not r then
|
||||
error("no request found")
|
||||
end
|
||||
|
||||
return tonumber(ngx_lua_ffi_req_start_time(r))
|
||||
end
|
||||
|
||||
|
||||
if subsystem == "stream" then
|
||||
return _M
|
||||
end
|
||||
|
||||
|
||||
local errmsg = base.get_errmsg_ptr()
|
||||
local ffi_str_type = ffi.typeof("ngx_http_lua_ffi_str_t*")
|
||||
local ffi_str_size = ffi.sizeof("ngx_http_lua_ffi_str_t")
|
||||
|
||||
|
||||
ffi.cdef[[
|
||||
typedef struct {
|
||||
ngx_http_lua_ffi_str_t key;
|
||||
ngx_http_lua_ffi_str_t value;
|
||||
} ngx_http_lua_ffi_table_elt_t;
|
||||
|
||||
int ngx_http_lua_ffi_req_get_headers_count(ngx_http_request_t *r,
|
||||
int max, int *truncated);
|
||||
|
||||
int ngx_http_lua_ffi_req_get_headers(ngx_http_request_t *r,
|
||||
ngx_http_lua_ffi_table_elt_t *out, int count, int raw);
|
||||
|
||||
int ngx_http_lua_ffi_req_get_uri_args_count(ngx_http_request_t *r,
|
||||
int max, int *truncated);
|
||||
|
||||
size_t ngx_http_lua_ffi_req_get_querystring_len(ngx_http_request_t *r);
|
||||
|
||||
int ngx_http_lua_ffi_req_get_uri_args(ngx_http_request_t *r,
|
||||
unsigned char *buf, ngx_http_lua_ffi_table_elt_t *out, int count);
|
||||
|
||||
int ngx_http_lua_ffi_req_get_method(ngx_http_request_t *r);
|
||||
|
||||
int ngx_http_lua_ffi_req_get_method_name(ngx_http_request_t *r,
|
||||
unsigned char **name, size_t *len);
|
||||
|
||||
int ngx_http_lua_ffi_req_set_method(ngx_http_request_t *r, int method);
|
||||
|
||||
int ngx_http_lua_ffi_req_set_header(ngx_http_request_t *r,
|
||||
const unsigned char *key, size_t key_len, const unsigned char *value,
|
||||
size_t value_len, ngx_http_lua_ffi_str_t *mvals, size_t mvals_len,
|
||||
int override, char **errmsg);
|
||||
]]
|
||||
|
||||
|
||||
local table_elt_type = ffi.typeof("ngx_http_lua_ffi_table_elt_t*")
|
||||
local table_elt_size = ffi.sizeof("ngx_http_lua_ffi_table_elt_t")
|
||||
local truncated = ffi.new("int[1]")
|
||||
|
||||
local req_headers_mt = {
|
||||
__index = function (tb, key)
|
||||
return rawget(tb, (str_replace_char(lower(key), '_', '-')))
|
||||
end
|
||||
}
|
||||
|
||||
|
||||
function ngx.req.get_headers(max_headers, raw)
|
||||
local r = get_request()
|
||||
if not r then
|
||||
error("no request found")
|
||||
end
|
||||
|
||||
if not max_headers then
|
||||
max_headers = -1
|
||||
end
|
||||
|
||||
if not raw then
|
||||
raw = 0
|
||||
else
|
||||
raw = 1
|
||||
end
|
||||
|
||||
local n = C.ngx_http_lua_ffi_req_get_headers_count(r, max_headers,
|
||||
truncated)
|
||||
if n == FFI_BAD_CONTEXT then
|
||||
error("API disabled in the current context", 2)
|
||||
end
|
||||
|
||||
if n == 0 then
|
||||
local headers = {}
|
||||
if raw == 0 then
|
||||
headers = setmetatable(headers, req_headers_mt)
|
||||
end
|
||||
|
||||
return headers
|
||||
end
|
||||
|
||||
local raw_buf = get_string_buf(n * table_elt_size)
|
||||
local buf = ffi_cast(table_elt_type, raw_buf)
|
||||
|
||||
local rc = C.ngx_http_lua_ffi_req_get_headers(r, buf, n, raw)
|
||||
if rc == 0 then
|
||||
local headers = new_tab(0, n)
|
||||
for i = 0, n - 1 do
|
||||
local h = buf[i]
|
||||
|
||||
local key = h.key
|
||||
key = ffi_str(key.data, key.len)
|
||||
|
||||
local value = h.value
|
||||
value = ffi_str(value.data, value.len)
|
||||
|
||||
local existing = headers[key]
|
||||
if existing then
|
||||
if type(existing) == "table" then
|
||||
existing[#existing + 1] = value
|
||||
else
|
||||
headers[key] = {existing, value}
|
||||
end
|
||||
|
||||
else
|
||||
headers[key] = value
|
||||
end
|
||||
end
|
||||
|
||||
if raw == 0 then
|
||||
headers = setmetatable(headers, req_headers_mt)
|
||||
end
|
||||
|
||||
if truncated[0] ~= 0 then
|
||||
return headers, "truncated"
|
||||
end
|
||||
|
||||
return headers
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
|
||||
function ngx.req.get_uri_args(max_args)
|
||||
local r = get_request()
|
||||
if not r then
|
||||
error("no request found")
|
||||
end
|
||||
|
||||
if not max_args then
|
||||
max_args = -1
|
||||
end
|
||||
|
||||
local n = C.ngx_http_lua_ffi_req_get_uri_args_count(r, max_args, truncated)
|
||||
if n == FFI_BAD_CONTEXT then
|
||||
error("API disabled in the current context", 2)
|
||||
end
|
||||
|
||||
if n == 0 then
|
||||
return {}
|
||||
end
|
||||
|
||||
local args_len = C.ngx_http_lua_ffi_req_get_querystring_len(r)
|
||||
|
||||
local strbuf = get_string_buf(args_len + n * table_elt_size)
|
||||
local kvbuf = ffi_cast(table_elt_type, strbuf + args_len)
|
||||
|
||||
local nargs = C.ngx_http_lua_ffi_req_get_uri_args(r, strbuf, kvbuf, n)
|
||||
|
||||
local args = new_tab(0, nargs)
|
||||
for i = 0, nargs - 1 do
|
||||
local arg = kvbuf[i]
|
||||
|
||||
local key = arg.key
|
||||
key = ffi_str(key.data, key.len)
|
||||
|
||||
local value = arg.value
|
||||
local len = value.len
|
||||
if len == -1 then
|
||||
value = true
|
||||
else
|
||||
value = ffi_str(value.data, len)
|
||||
end
|
||||
|
||||
local existing = args[key]
|
||||
if existing then
|
||||
if type(existing) == "table" then
|
||||
existing[#existing + 1] = value
|
||||
else
|
||||
args[key] = {existing, value}
|
||||
end
|
||||
|
||||
else
|
||||
args[key] = value
|
||||
end
|
||||
end
|
||||
|
||||
if truncated[0] ~= 0 then
|
||||
return args, "truncated"
|
||||
end
|
||||
|
||||
return args
|
||||
end
|
||||
|
||||
|
||||
do
|
||||
local methods = {
|
||||
[0x0002] = "GET",
|
||||
[0x0004] = "HEAD",
|
||||
[0x0008] = "POST",
|
||||
[0x0010] = "PUT",
|
||||
[0x0020] = "DELETE",
|
||||
[0x0040] = "MKCOL",
|
||||
[0x0080] = "COPY",
|
||||
[0x0100] = "MOVE",
|
||||
[0x0200] = "OPTIONS",
|
||||
[0x0400] = "PROPFIND",
|
||||
[0x0800] = "PROPPATCH",
|
||||
[0x1000] = "LOCK",
|
||||
[0x2000] = "UNLOCK",
|
||||
[0x4000] = "PATCH",
|
||||
[0x8000] = "TRACE",
|
||||
}
|
||||
|
||||
local namep = ffi_new("unsigned char *[1]")
|
||||
|
||||
function ngx.req.get_method()
|
||||
local r = get_request()
|
||||
if not r then
|
||||
error("no request found")
|
||||
end
|
||||
|
||||
do
|
||||
local id = C.ngx_http_lua_ffi_req_get_method(r)
|
||||
if id == FFI_BAD_CONTEXT then
|
||||
error("API disabled in the current context", 2)
|
||||
end
|
||||
|
||||
local method = methods[id]
|
||||
if method then
|
||||
return method
|
||||
end
|
||||
end
|
||||
|
||||
local sizep = get_size_ptr()
|
||||
local rc = C.ngx_http_lua_ffi_req_get_method_name(r, namep, sizep)
|
||||
if rc ~= 0 then
|
||||
return nil
|
||||
end
|
||||
|
||||
return ffi_str(namep[0], sizep[0])
|
||||
end
|
||||
end -- do
|
||||
|
||||
|
||||
function ngx.req.set_method(method)
|
||||
local r = get_request()
|
||||
if not r then
|
||||
error("no request found")
|
||||
end
|
||||
|
||||
if type(method) ~= "number" then
|
||||
error("bad method number", 2)
|
||||
end
|
||||
|
||||
local rc = C.ngx_http_lua_ffi_req_set_method(r, method)
|
||||
if rc == FFI_OK then
|
||||
return
|
||||
end
|
||||
|
||||
if rc == FFI_BAD_CONTEXT then
|
||||
error("API disabled in the current context", 2)
|
||||
end
|
||||
|
||||
if rc == FFI_DECLINED then
|
||||
error("unsupported HTTP method: " .. method, 2)
|
||||
end
|
||||
|
||||
error("unknown error: " .. rc)
|
||||
end
|
||||
|
||||
|
||||
do
|
||||
local function set_req_header(name, value, override)
|
||||
local r = get_request()
|
||||
if not r then
|
||||
error("no request found", 3)
|
||||
end
|
||||
|
||||
if name == nil then
|
||||
error("bad 'name' argument: string expected, got nil", 3)
|
||||
end
|
||||
|
||||
if type(name) ~= "string" then
|
||||
name = tostring(name)
|
||||
end
|
||||
|
||||
local rc
|
||||
|
||||
if value == nil then
|
||||
if not override then
|
||||
error("bad 'value' argument: string or table expected, got nil",
|
||||
3)
|
||||
end
|
||||
|
||||
rc = C.ngx_http_lua_ffi_req_set_header(r, name, #name, nil, 0, nil,
|
||||
0, 1, errmsg)
|
||||
|
||||
else
|
||||
local sval, sval_len, mvals, mvals_len, buf
|
||||
local value_type = type(value)
|
||||
|
||||
if value_type == "table" then
|
||||
mvals_len = #value
|
||||
if mvals_len == 0 and not override then
|
||||
error("bad 'value' argument: non-empty table expected", 3)
|
||||
end
|
||||
|
||||
buf = get_string_buf(ffi_str_size * mvals_len)
|
||||
mvals = ffi_cast(ffi_str_type, buf)
|
||||
|
||||
for i = 1, mvals_len do
|
||||
local s = value[i]
|
||||
if type(s) ~= "string" then
|
||||
s = tostring(s)
|
||||
value[i] = s
|
||||
end
|
||||
|
||||
local str = mvals[i - 1]
|
||||
str.data = s
|
||||
str.len = #s
|
||||
end
|
||||
|
||||
sval_len = 0
|
||||
|
||||
else
|
||||
if value_type ~= "string" then
|
||||
sval = tostring(value)
|
||||
else
|
||||
sval = value
|
||||
end
|
||||
|
||||
sval_len = #sval
|
||||
mvals_len = 0
|
||||
end
|
||||
|
||||
rc = C.ngx_http_lua_ffi_req_set_header(r, name, #name, sval,
|
||||
sval_len, mvals, mvals_len,
|
||||
override and 1 or 0, errmsg)
|
||||
end
|
||||
|
||||
if rc == FFI_OK or rc == FFI_DECLINED then
|
||||
return
|
||||
end
|
||||
|
||||
if rc == FFI_BAD_CONTEXT then
|
||||
error("API disabled in the current context", 3)
|
||||
end
|
||||
|
||||
-- rc == FFI_ERROR
|
||||
error(ffi_str(errmsg[0]))
|
||||
end
|
||||
|
||||
|
||||
_M.set_req_header = set_req_header
|
||||
|
||||
|
||||
function ngx.req.set_header(name, value)
|
||||
set_req_header(name, value, true) -- override
|
||||
end
|
||||
end -- do
|
||||
|
||||
|
||||
function ngx.req.clear_header(name)
|
||||
local r = get_request()
|
||||
if not r then
|
||||
error("no request found")
|
||||
end
|
||||
|
||||
if type(name) ~= "string" then
|
||||
name = tostring(name)
|
||||
end
|
||||
|
||||
local rc = C.ngx_http_lua_ffi_req_set_header(r, name, #name, nil, 0, nil, 0,
|
||||
1, errmsg)
|
||||
|
||||
if rc == FFI_OK or rc == FFI_DECLINED then
|
||||
return
|
||||
end
|
||||
|
||||
if rc == FFI_BAD_CONTEXT then
|
||||
error("API disabled in the current context", 2)
|
||||
end
|
||||
|
||||
-- rc == FFI_ERROR
|
||||
error(ffi_str(errmsg[0]))
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
183
resty/core/response.lua
Normal file
183
resty/core/response.lua
Normal file
@@ -0,0 +1,183 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
local ffi = require 'ffi'
|
||||
local base = require "resty.core.base"
|
||||
|
||||
|
||||
local C = ffi.C
|
||||
local ffi_cast = ffi.cast
|
||||
local ffi_str = ffi.string
|
||||
local new_tab = base.new_tab
|
||||
local FFI_BAD_CONTEXT = base.FFI_BAD_CONTEXT
|
||||
local FFI_NO_REQ_CTX = base.FFI_NO_REQ_CTX
|
||||
local FFI_DECLINED = base.FFI_DECLINED
|
||||
local get_string_buf = base.get_string_buf
|
||||
local setmetatable = setmetatable
|
||||
local type = type
|
||||
local tostring = tostring
|
||||
local get_request = base.get_request
|
||||
local error = error
|
||||
local ngx = ngx
|
||||
|
||||
|
||||
local _M = {
|
||||
version = base.version
|
||||
}
|
||||
|
||||
|
||||
local MAX_HEADER_VALUES = 100
|
||||
local errmsg = base.get_errmsg_ptr()
|
||||
local ffi_str_type = ffi.typeof("ngx_http_lua_ffi_str_t*")
|
||||
local ffi_str_size = ffi.sizeof("ngx_http_lua_ffi_str_t")
|
||||
|
||||
|
||||
ffi.cdef[[
|
||||
int ngx_http_lua_ffi_set_resp_header(ngx_http_request_t *r,
|
||||
const char *key_data, size_t key_len, int is_nil,
|
||||
const char *sval, size_t sval_len, ngx_http_lua_ffi_str_t *mvals,
|
||||
size_t mvals_len, int override, char **errmsg);
|
||||
|
||||
int ngx_http_lua_ffi_get_resp_header(ngx_http_request_t *r,
|
||||
const unsigned char *key, size_t key_len,
|
||||
unsigned char *key_buf, ngx_http_lua_ffi_str_t *values,
|
||||
int max_nvalues, char **errmsg);
|
||||
]]
|
||||
|
||||
|
||||
local function set_resp_header(tb, key, value, no_override)
|
||||
local r = get_request()
|
||||
if not r then
|
||||
error("no request found")
|
||||
end
|
||||
|
||||
if type(key) ~= "string" then
|
||||
key = tostring(key)
|
||||
end
|
||||
|
||||
local rc
|
||||
if value == nil then
|
||||
if no_override then
|
||||
error("invalid header value", 3)
|
||||
end
|
||||
|
||||
rc = C.ngx_http_lua_ffi_set_resp_header(r, key, #key, true, nil, 0, nil,
|
||||
0, 1, errmsg)
|
||||
else
|
||||
local sval, sval_len, mvals, mvals_len, buf
|
||||
|
||||
if type(value) == "table" then
|
||||
mvals_len = #value
|
||||
if mvals_len == 0 and no_override then
|
||||
return
|
||||
end
|
||||
|
||||
buf = get_string_buf(ffi_str_size * mvals_len)
|
||||
mvals = ffi_cast(ffi_str_type, buf)
|
||||
for i = 1, mvals_len do
|
||||
local s = value[i]
|
||||
if type(s) ~= "string" then
|
||||
s = tostring(s)
|
||||
value[i] = s
|
||||
end
|
||||
local str = mvals[i - 1]
|
||||
str.data = s
|
||||
str.len = #s
|
||||
end
|
||||
|
||||
sval_len = 0
|
||||
|
||||
else
|
||||
if type(value) ~= "string" then
|
||||
sval = tostring(value)
|
||||
else
|
||||
sval = value
|
||||
end
|
||||
sval_len = #sval
|
||||
|
||||
mvals_len = 0
|
||||
end
|
||||
|
||||
local override_int = no_override and 0 or 1
|
||||
rc = C.ngx_http_lua_ffi_set_resp_header(r, key, #key, false, sval,
|
||||
sval_len, mvals, mvals_len,
|
||||
override_int, errmsg)
|
||||
end
|
||||
|
||||
if rc == 0 or rc == FFI_DECLINED then
|
||||
return
|
||||
end
|
||||
|
||||
if rc == FFI_NO_REQ_CTX then
|
||||
error("no request ctx found")
|
||||
end
|
||||
|
||||
if rc == FFI_BAD_CONTEXT then
|
||||
error("API disabled in the current context", 2)
|
||||
end
|
||||
|
||||
-- rc == FFI_ERROR
|
||||
error(ffi_str(errmsg[0]), 2)
|
||||
end
|
||||
|
||||
|
||||
_M.set_resp_header = set_resp_header
|
||||
|
||||
|
||||
local function get_resp_header(tb, key)
|
||||
local r = get_request()
|
||||
if not r then
|
||||
error("no request found")
|
||||
end
|
||||
|
||||
if type(key) ~= "string" then
|
||||
key = tostring(key)
|
||||
end
|
||||
|
||||
local key_len = #key
|
||||
|
||||
local key_buf = get_string_buf(key_len + ffi_str_size * MAX_HEADER_VALUES)
|
||||
local values = ffi_cast(ffi_str_type, key_buf + key_len)
|
||||
local n = C.ngx_http_lua_ffi_get_resp_header(r, key, key_len, key_buf,
|
||||
values, MAX_HEADER_VALUES,
|
||||
errmsg)
|
||||
|
||||
-- print("retval: ", n)
|
||||
|
||||
if n == FFI_BAD_CONTEXT then
|
||||
error("API disabled in the current context", 2)
|
||||
end
|
||||
|
||||
if n == 0 then
|
||||
return nil
|
||||
end
|
||||
|
||||
if n == 1 then
|
||||
local v = values[0]
|
||||
return ffi_str(v.data, v.len)
|
||||
end
|
||||
|
||||
if n > 0 then
|
||||
local ret = new_tab(n, 0)
|
||||
for i = 1, n do
|
||||
local v = values[i - 1]
|
||||
ret[i] = ffi_str(v.data, v.len)
|
||||
end
|
||||
return ret
|
||||
end
|
||||
|
||||
-- n == FFI_ERROR
|
||||
error(ffi_str(errmsg[0]), 2)
|
||||
end
|
||||
|
||||
|
||||
do
|
||||
local mt = new_tab(0, 2)
|
||||
mt.__newindex = set_resp_header
|
||||
mt.__index = get_resp_header
|
||||
|
||||
ngx.header = setmetatable(new_tab(0, 0), mt)
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
638
resty/core/shdict.lua
Normal file
638
resty/core/shdict.lua
Normal file
@@ -0,0 +1,638 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
local ffi = require 'ffi'
|
||||
local base = require "resty.core.base"
|
||||
|
||||
|
||||
local _M = {
|
||||
version = base.version
|
||||
}
|
||||
|
||||
local ngx_shared = ngx.shared
|
||||
if not ngx_shared then
|
||||
return _M
|
||||
end
|
||||
|
||||
|
||||
local ffi_new = ffi.new
|
||||
local ffi_str = ffi.string
|
||||
local C = ffi.C
|
||||
local get_string_buf = base.get_string_buf
|
||||
local get_string_buf_size = base.get_string_buf_size
|
||||
local get_size_ptr = base.get_size_ptr
|
||||
local tonumber = tonumber
|
||||
local tostring = tostring
|
||||
local next = next
|
||||
local type = type
|
||||
local error = error
|
||||
local getmetatable = getmetatable
|
||||
local FFI_DECLINED = base.FFI_DECLINED
|
||||
local subsystem = ngx.config.subsystem
|
||||
|
||||
|
||||
local ngx_lua_ffi_shdict_get
|
||||
local ngx_lua_ffi_shdict_incr
|
||||
local ngx_lua_ffi_shdict_store
|
||||
local ngx_lua_ffi_shdict_flush_all
|
||||
local ngx_lua_ffi_shdict_get_ttl
|
||||
local ngx_lua_ffi_shdict_set_expire
|
||||
local ngx_lua_ffi_shdict_capacity
|
||||
local ngx_lua_ffi_shdict_free_space
|
||||
local ngx_lua_ffi_shdict_udata_to_zone
|
||||
|
||||
|
||||
if subsystem == 'http' then
|
||||
ffi.cdef[[
|
||||
int ngx_http_lua_ffi_shdict_get(void *zone, const unsigned char *key,
|
||||
size_t key_len, int *value_type, unsigned char **str_value_buf,
|
||||
size_t *str_value_len, double *num_value, int *user_flags,
|
||||
int get_stale, int *is_stale, char **errmsg);
|
||||
|
||||
int ngx_http_lua_ffi_shdict_incr(void *zone, const unsigned char *key,
|
||||
size_t key_len, double *value, char **err, int has_init,
|
||||
double init, long init_ttl, int *forcible);
|
||||
|
||||
int ngx_http_lua_ffi_shdict_store(void *zone, int op,
|
||||
const unsigned char *key, size_t key_len, int value_type,
|
||||
const unsigned char *str_value_buf, size_t str_value_len,
|
||||
double num_value, long exptime, int user_flags, char **errmsg,
|
||||
int *forcible);
|
||||
|
||||
int ngx_http_lua_ffi_shdict_flush_all(void *zone);
|
||||
|
||||
long ngx_http_lua_ffi_shdict_get_ttl(void *zone,
|
||||
const unsigned char *key, size_t key_len);
|
||||
|
||||
int ngx_http_lua_ffi_shdict_set_expire(void *zone,
|
||||
const unsigned char *key, size_t key_len, long exptime);
|
||||
|
||||
size_t ngx_http_lua_ffi_shdict_capacity(void *zone);
|
||||
|
||||
void *ngx_http_lua_ffi_shdict_udata_to_zone(void *zone_udata);
|
||||
]]
|
||||
|
||||
ngx_lua_ffi_shdict_get = C.ngx_http_lua_ffi_shdict_get
|
||||
ngx_lua_ffi_shdict_incr = C.ngx_http_lua_ffi_shdict_incr
|
||||
ngx_lua_ffi_shdict_store = C.ngx_http_lua_ffi_shdict_store
|
||||
ngx_lua_ffi_shdict_flush_all = C.ngx_http_lua_ffi_shdict_flush_all
|
||||
ngx_lua_ffi_shdict_get_ttl = C.ngx_http_lua_ffi_shdict_get_ttl
|
||||
ngx_lua_ffi_shdict_set_expire = C.ngx_http_lua_ffi_shdict_set_expire
|
||||
ngx_lua_ffi_shdict_capacity = C.ngx_http_lua_ffi_shdict_capacity
|
||||
ngx_lua_ffi_shdict_udata_to_zone =
|
||||
C.ngx_http_lua_ffi_shdict_udata_to_zone
|
||||
|
||||
if not pcall(function ()
|
||||
return C.ngx_http_lua_ffi_shdict_free_space
|
||||
end)
|
||||
then
|
||||
ffi.cdef[[
|
||||
size_t ngx_http_lua_ffi_shdict_free_space(void *zone);
|
||||
]]
|
||||
end
|
||||
|
||||
pcall(function ()
|
||||
ngx_lua_ffi_shdict_free_space = C.ngx_http_lua_ffi_shdict_free_space
|
||||
end)
|
||||
|
||||
elseif subsystem == 'stream' then
|
||||
|
||||
ffi.cdef[[
|
||||
int ngx_stream_lua_ffi_shdict_get(void *zone, const unsigned char *key,
|
||||
size_t key_len, int *value_type, unsigned char **str_value_buf,
|
||||
size_t *str_value_len, double *num_value, int *user_flags,
|
||||
int get_stale, int *is_stale, char **errmsg);
|
||||
|
||||
int ngx_stream_lua_ffi_shdict_incr(void *zone, const unsigned char *key,
|
||||
size_t key_len, double *value, char **err, int has_init,
|
||||
double init, long init_ttl, int *forcible);
|
||||
|
||||
int ngx_stream_lua_ffi_shdict_store(void *zone, int op,
|
||||
const unsigned char *key, size_t key_len, int value_type,
|
||||
const unsigned char *str_value_buf, size_t str_value_len,
|
||||
double num_value, long exptime, int user_flags, char **errmsg,
|
||||
int *forcible);
|
||||
|
||||
int ngx_stream_lua_ffi_shdict_flush_all(void *zone);
|
||||
|
||||
long ngx_stream_lua_ffi_shdict_get_ttl(void *zone,
|
||||
const unsigned char *key, size_t key_len);
|
||||
|
||||
int ngx_stream_lua_ffi_shdict_set_expire(void *zone,
|
||||
const unsigned char *key, size_t key_len, long exptime);
|
||||
|
||||
size_t ngx_stream_lua_ffi_shdict_capacity(void *zone);
|
||||
|
||||
void *ngx_stream_lua_ffi_shdict_udata_to_zone(void *zone_udata);
|
||||
]]
|
||||
|
||||
ngx_lua_ffi_shdict_get = C.ngx_stream_lua_ffi_shdict_get
|
||||
ngx_lua_ffi_shdict_incr = C.ngx_stream_lua_ffi_shdict_incr
|
||||
ngx_lua_ffi_shdict_store = C.ngx_stream_lua_ffi_shdict_store
|
||||
ngx_lua_ffi_shdict_flush_all = C.ngx_stream_lua_ffi_shdict_flush_all
|
||||
ngx_lua_ffi_shdict_get_ttl = C.ngx_stream_lua_ffi_shdict_get_ttl
|
||||
ngx_lua_ffi_shdict_set_expire = C.ngx_stream_lua_ffi_shdict_set_expire
|
||||
ngx_lua_ffi_shdict_capacity = C.ngx_stream_lua_ffi_shdict_capacity
|
||||
ngx_lua_ffi_shdict_udata_to_zone =
|
||||
C.ngx_stream_lua_ffi_shdict_udata_to_zone
|
||||
|
||||
if not pcall(function ()
|
||||
return C.ngx_stream_lua_ffi_shdict_free_space
|
||||
end)
|
||||
then
|
||||
ffi.cdef[[
|
||||
size_t ngx_stream_lua_ffi_shdict_free_space(void *zone);
|
||||
]]
|
||||
end
|
||||
|
||||
-- ngx_stream_lua is only compatible with NGINX >= 1.13.6, meaning it
|
||||
-- cannot lack support for ngx_stream_lua_ffi_shdict_free_space.
|
||||
ngx_lua_ffi_shdict_free_space = C.ngx_stream_lua_ffi_shdict_free_space
|
||||
|
||||
else
|
||||
error("unknown subsystem: " .. subsystem)
|
||||
end
|
||||
|
||||
if not pcall(function () return C.free end) then
|
||||
ffi.cdef[[
|
||||
void free(void *ptr);
|
||||
]]
|
||||
end
|
||||
|
||||
|
||||
local value_type = ffi_new("int[1]")
|
||||
local user_flags = ffi_new("int[1]")
|
||||
local num_value = ffi_new("double[1]")
|
||||
local is_stale = ffi_new("int[1]")
|
||||
local forcible = ffi_new("int[1]")
|
||||
local str_value_buf = ffi_new("unsigned char *[1]")
|
||||
local errmsg = base.get_errmsg_ptr()
|
||||
|
||||
|
||||
local function check_zone(zone)
|
||||
if not zone or type(zone) ~= "table" then
|
||||
error("bad \"zone\" argument", 3)
|
||||
end
|
||||
|
||||
zone = zone[1]
|
||||
if type(zone) ~= "userdata" then
|
||||
error("bad \"zone\" argument", 3)
|
||||
end
|
||||
|
||||
zone = ngx_lua_ffi_shdict_udata_to_zone(zone)
|
||||
if zone == nil then
|
||||
error("bad \"zone\" argument", 3)
|
||||
end
|
||||
|
||||
return zone
|
||||
end
|
||||
|
||||
|
||||
local function shdict_store(zone, op, key, value, exptime, flags)
|
||||
zone = check_zone(zone)
|
||||
|
||||
if not exptime then
|
||||
exptime = 0
|
||||
elseif exptime < 0 then
|
||||
error('bad "exptime" argument', 2)
|
||||
end
|
||||
|
||||
if not flags then
|
||||
flags = 0
|
||||
end
|
||||
|
||||
if key == nil then
|
||||
return nil, "nil key"
|
||||
end
|
||||
|
||||
if type(key) ~= "string" then
|
||||
key = tostring(key)
|
||||
end
|
||||
|
||||
local key_len = #key
|
||||
if key_len == 0 then
|
||||
return nil, "empty key"
|
||||
end
|
||||
if key_len > 65535 then
|
||||
return nil, "key too long"
|
||||
end
|
||||
|
||||
local str_val_buf
|
||||
local str_val_len = 0
|
||||
local num_val = 0
|
||||
local valtyp = type(value)
|
||||
|
||||
-- print("value type: ", valtyp)
|
||||
-- print("exptime: ", exptime)
|
||||
|
||||
if valtyp == "string" then
|
||||
valtyp = 4 -- LUA_TSTRING
|
||||
str_val_buf = value
|
||||
str_val_len = #value
|
||||
|
||||
elseif valtyp == "number" then
|
||||
valtyp = 3 -- LUA_TNUMBER
|
||||
num_val = value
|
||||
|
||||
elseif value == nil then
|
||||
valtyp = 0 -- LUA_TNIL
|
||||
|
||||
elseif valtyp == "boolean" then
|
||||
valtyp = 1 -- LUA_TBOOLEAN
|
||||
num_val = value and 1 or 0
|
||||
|
||||
else
|
||||
return nil, "bad value type"
|
||||
end
|
||||
|
||||
local rc = ngx_lua_ffi_shdict_store(zone, op, key, key_len,
|
||||
valtyp, str_val_buf,
|
||||
str_val_len, num_val,
|
||||
exptime * 1000, flags, errmsg,
|
||||
forcible)
|
||||
|
||||
-- print("rc == ", rc)
|
||||
|
||||
if rc == 0 then -- NGX_OK
|
||||
return true, nil, forcible[0] == 1
|
||||
end
|
||||
|
||||
-- NGX_DECLINED or NGX_ERROR
|
||||
return false, ffi_str(errmsg[0]), forcible[0] == 1
|
||||
end
|
||||
|
||||
|
||||
local function shdict_set(zone, key, value, exptime, flags)
|
||||
return shdict_store(zone, 0, key, value, exptime, flags)
|
||||
end
|
||||
|
||||
|
||||
local function shdict_safe_set(zone, key, value, exptime, flags)
|
||||
return shdict_store(zone, 0x0004, key, value, exptime, flags)
|
||||
end
|
||||
|
||||
|
||||
local function shdict_add(zone, key, value, exptime, flags)
|
||||
return shdict_store(zone, 0x0001, key, value, exptime, flags)
|
||||
end
|
||||
|
||||
|
||||
local function shdict_safe_add(zone, key, value, exptime, flags)
|
||||
return shdict_store(zone, 0x0005, key, value, exptime, flags)
|
||||
end
|
||||
|
||||
|
||||
local function shdict_replace(zone, key, value, exptime, flags)
|
||||
return shdict_store(zone, 0x0002, key, value, exptime, flags)
|
||||
end
|
||||
|
||||
|
||||
local function shdict_delete(zone, key)
|
||||
return shdict_set(zone, key, nil)
|
||||
end
|
||||
|
||||
|
||||
local function shdict_get(zone, key)
|
||||
zone = check_zone(zone)
|
||||
|
||||
if key == nil then
|
||||
return nil, "nil key"
|
||||
end
|
||||
|
||||
if type(key) ~= "string" then
|
||||
key = tostring(key)
|
||||
end
|
||||
|
||||
local key_len = #key
|
||||
if key_len == 0 then
|
||||
return nil, "empty key"
|
||||
end
|
||||
if key_len > 65535 then
|
||||
return nil, "key too long"
|
||||
end
|
||||
|
||||
local size = get_string_buf_size()
|
||||
local buf = get_string_buf(size)
|
||||
str_value_buf[0] = buf
|
||||
local value_len = get_size_ptr()
|
||||
value_len[0] = size
|
||||
|
||||
local rc = ngx_lua_ffi_shdict_get(zone, key, key_len, value_type,
|
||||
str_value_buf, value_len,
|
||||
num_value, user_flags, 0,
|
||||
is_stale, errmsg)
|
||||
if rc ~= 0 then
|
||||
if errmsg[0] ~= nil then
|
||||
return nil, ffi_str(errmsg[0])
|
||||
end
|
||||
|
||||
error("failed to get the key")
|
||||
end
|
||||
|
||||
local typ = value_type[0]
|
||||
|
||||
if typ == 0 then -- LUA_TNIL
|
||||
return nil
|
||||
end
|
||||
|
||||
local flags = tonumber(user_flags[0])
|
||||
|
||||
local val
|
||||
|
||||
if typ == 4 then -- LUA_TSTRING
|
||||
if str_value_buf[0] ~= buf then
|
||||
-- ngx.say("len: ", tonumber(value_len[0]))
|
||||
buf = str_value_buf[0]
|
||||
val = ffi_str(buf, value_len[0])
|
||||
C.free(buf)
|
||||
else
|
||||
val = ffi_str(buf, value_len[0])
|
||||
end
|
||||
|
||||
elseif typ == 3 then -- LUA_TNUMBER
|
||||
val = tonumber(num_value[0])
|
||||
|
||||
elseif typ == 1 then -- LUA_TBOOLEAN
|
||||
val = (tonumber(buf[0]) ~= 0)
|
||||
|
||||
else
|
||||
error("unknown value type: " .. typ)
|
||||
end
|
||||
|
||||
if flags ~= 0 then
|
||||
return val, flags
|
||||
end
|
||||
|
||||
return val
|
||||
end
|
||||
|
||||
|
||||
local function shdict_get_stale(zone, key)
|
||||
zone = check_zone(zone)
|
||||
|
||||
if key == nil then
|
||||
return nil, "nil key"
|
||||
end
|
||||
|
||||
if type(key) ~= "string" then
|
||||
key = tostring(key)
|
||||
end
|
||||
|
||||
local key_len = #key
|
||||
if key_len == 0 then
|
||||
return nil, "empty key"
|
||||
end
|
||||
if key_len > 65535 then
|
||||
return nil, "key too long"
|
||||
end
|
||||
|
||||
local size = get_string_buf_size()
|
||||
local buf = get_string_buf(size)
|
||||
str_value_buf[0] = buf
|
||||
local value_len = get_size_ptr()
|
||||
value_len[0] = size
|
||||
|
||||
local rc = ngx_lua_ffi_shdict_get(zone, key, key_len, value_type,
|
||||
str_value_buf, value_len,
|
||||
num_value, user_flags, 1,
|
||||
is_stale, errmsg)
|
||||
if rc ~= 0 then
|
||||
if errmsg[0] ~= nil then
|
||||
return nil, ffi_str(errmsg[0])
|
||||
end
|
||||
|
||||
error("failed to get the key")
|
||||
end
|
||||
|
||||
local typ = value_type[0]
|
||||
|
||||
if typ == 0 then -- LUA_TNIL
|
||||
return nil
|
||||
end
|
||||
|
||||
local flags = tonumber(user_flags[0])
|
||||
local val
|
||||
|
||||
if typ == 4 then -- LUA_TSTRING
|
||||
if str_value_buf[0] ~= buf then
|
||||
-- ngx.say("len: ", tonumber(value_len[0]))
|
||||
buf = str_value_buf[0]
|
||||
val = ffi_str(buf, value_len[0])
|
||||
C.free(buf)
|
||||
else
|
||||
val = ffi_str(buf, value_len[0])
|
||||
end
|
||||
|
||||
elseif typ == 3 then -- LUA_TNUMBER
|
||||
val = tonumber(num_value[0])
|
||||
|
||||
elseif typ == 1 then -- LUA_TBOOLEAN
|
||||
val = (tonumber(buf[0]) ~= 0)
|
||||
|
||||
else
|
||||
error("unknown value type: " .. typ)
|
||||
end
|
||||
|
||||
if flags ~= 0 then
|
||||
return val, flags, is_stale[0] == 1
|
||||
end
|
||||
|
||||
return val, nil, is_stale[0] == 1
|
||||
end
|
||||
|
||||
|
||||
local function shdict_incr(zone, key, value, init, init_ttl)
|
||||
zone = check_zone(zone)
|
||||
|
||||
if key == nil then
|
||||
return nil, "nil key"
|
||||
end
|
||||
|
||||
if type(key) ~= "string" then
|
||||
key = tostring(key)
|
||||
end
|
||||
|
||||
local key_len = #key
|
||||
if key_len == 0 then
|
||||
return nil, "empty key"
|
||||
end
|
||||
if key_len > 65535 then
|
||||
return nil, "key too long"
|
||||
end
|
||||
|
||||
if type(value) ~= "number" then
|
||||
value = tonumber(value)
|
||||
end
|
||||
num_value[0] = value
|
||||
|
||||
if init then
|
||||
local typ = type(init)
|
||||
if typ ~= "number" then
|
||||
init = tonumber(init)
|
||||
|
||||
if not init then
|
||||
error("bad init arg: number expected, got " .. typ, 2)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
if init_ttl ~= nil then
|
||||
local typ = type(init_ttl)
|
||||
if typ ~= "number" then
|
||||
init_ttl = tonumber(init_ttl)
|
||||
|
||||
if not init_ttl then
|
||||
error("bad init_ttl arg: number expected, got " .. typ, 2)
|
||||
end
|
||||
end
|
||||
|
||||
if init_ttl < 0 then
|
||||
error('bad "init_ttl" argument', 2)
|
||||
end
|
||||
|
||||
if not init then
|
||||
error('must provide "init" when providing "init_ttl"', 2)
|
||||
end
|
||||
|
||||
else
|
||||
init_ttl = 0
|
||||
end
|
||||
|
||||
local rc = ngx_lua_ffi_shdict_incr(zone, key, key_len, num_value,
|
||||
errmsg, init and 1 or 0,
|
||||
init or 0, init_ttl * 1000,
|
||||
forcible)
|
||||
if rc ~= 0 then -- ~= NGX_OK
|
||||
return nil, ffi_str(errmsg[0])
|
||||
end
|
||||
|
||||
if not init then
|
||||
return tonumber(num_value[0])
|
||||
end
|
||||
|
||||
return tonumber(num_value[0]), nil, forcible[0] == 1
|
||||
end
|
||||
|
||||
|
||||
local function shdict_flush_all(zone)
|
||||
zone = check_zone(zone)
|
||||
|
||||
ngx_lua_ffi_shdict_flush_all(zone)
|
||||
end
|
||||
|
||||
|
||||
local function shdict_ttl(zone, key)
|
||||
zone = check_zone(zone)
|
||||
|
||||
if key == nil then
|
||||
return nil, "nil key"
|
||||
end
|
||||
|
||||
if type(key) ~= "string" then
|
||||
key = tostring(key)
|
||||
end
|
||||
|
||||
local key_len = #key
|
||||
if key_len == 0 then
|
||||
return nil, "empty key"
|
||||
end
|
||||
|
||||
if key_len > 65535 then
|
||||
return nil, "key too long"
|
||||
end
|
||||
|
||||
local rc = ngx_lua_ffi_shdict_get_ttl(zone, key, key_len)
|
||||
|
||||
if rc == FFI_DECLINED then
|
||||
return nil, "not found"
|
||||
end
|
||||
|
||||
return tonumber(rc) / 1000
|
||||
end
|
||||
|
||||
|
||||
local function shdict_expire(zone, key, exptime)
|
||||
zone = check_zone(zone)
|
||||
|
||||
if not exptime then
|
||||
error('bad "exptime" argument', 2)
|
||||
end
|
||||
|
||||
if key == nil then
|
||||
return nil, "nil key"
|
||||
end
|
||||
|
||||
if type(key) ~= "string" then
|
||||
key = tostring(key)
|
||||
end
|
||||
|
||||
local key_len = #key
|
||||
if key_len == 0 then
|
||||
return nil, "empty key"
|
||||
end
|
||||
|
||||
if key_len > 65535 then
|
||||
return nil, "key too long"
|
||||
end
|
||||
|
||||
local rc = ngx_lua_ffi_shdict_set_expire(zone, key, key_len,
|
||||
exptime * 1000)
|
||||
|
||||
if rc == FFI_DECLINED then
|
||||
return nil, "not found"
|
||||
end
|
||||
|
||||
-- NGINX_OK/FFI_OK
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
|
||||
local function shdict_capacity(zone)
|
||||
zone = check_zone(zone)
|
||||
|
||||
return tonumber(ngx_lua_ffi_shdict_capacity(zone))
|
||||
end
|
||||
|
||||
|
||||
local shdict_free_space
|
||||
if ngx_lua_ffi_shdict_free_space then
|
||||
shdict_free_space = function (zone)
|
||||
zone = check_zone(zone)
|
||||
|
||||
return tonumber(ngx_lua_ffi_shdict_free_space(zone))
|
||||
end
|
||||
|
||||
else
|
||||
shdict_free_space = function ()
|
||||
error("'shm:free_space()' not supported in NGINX < 1.11.7", 2)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
local _, dict = next(ngx_shared, nil)
|
||||
if dict then
|
||||
local mt = getmetatable(dict)
|
||||
if mt then
|
||||
mt = mt.__index
|
||||
if mt then
|
||||
mt.get = shdict_get
|
||||
mt.get_stale = shdict_get_stale
|
||||
mt.incr = shdict_incr
|
||||
mt.set = shdict_set
|
||||
mt.safe_set = shdict_safe_set
|
||||
mt.add = shdict_add
|
||||
mt.safe_add = shdict_safe_add
|
||||
mt.replace = shdict_replace
|
||||
mt.delete = shdict_delete
|
||||
mt.flush_all = shdict_flush_all
|
||||
mt.ttl = shdict_ttl
|
||||
mt.expire = shdict_expire
|
||||
mt.capacity = shdict_capacity
|
||||
mt.free_space = shdict_free_space
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
124
resty/core/socket.lua
Normal file
124
resty/core/socket.lua
Normal file
@@ -0,0 +1,124 @@
|
||||
local base = require "resty.core.base"
|
||||
base.allows_subsystem('http')
|
||||
local debug = require 'debug'
|
||||
local ffi = require 'ffi'
|
||||
|
||||
|
||||
local error = error
|
||||
local tonumber = tonumber
|
||||
local registry = debug.getregistry()
|
||||
local ffi_new = ffi.new
|
||||
local ffi_string = ffi.string
|
||||
local C = ffi.C
|
||||
local get_string_buf = base.get_string_buf
|
||||
local get_size_ptr = base.get_size_ptr
|
||||
local tostring = tostring
|
||||
|
||||
|
||||
local option_index = {
|
||||
["keepalive"] = 1,
|
||||
["reuseaddr"] = 2,
|
||||
["tcp-nodelay"] = 3,
|
||||
["sndbuf"] = 4,
|
||||
["rcvbuf"] = 5,
|
||||
}
|
||||
|
||||
|
||||
ffi.cdef[[
|
||||
typedef struct ngx_http_lua_socket_tcp_upstream_s
|
||||
ngx_http_lua_socket_tcp_upstream_t;
|
||||
|
||||
int
|
||||
ngx_http_lua_ffi_socket_tcp_getoption(ngx_http_lua_socket_tcp_upstream_t *u,
|
||||
int opt, int *val, unsigned char *err, size_t *errlen);
|
||||
|
||||
int
|
||||
ngx_http_lua_ffi_socket_tcp_setoption(ngx_http_lua_socket_tcp_upstream_t *u,
|
||||
int opt, int val, unsigned char *err, size_t *errlen);
|
||||
]]
|
||||
|
||||
|
||||
local output_value_buf = ffi_new("int[1]")
|
||||
local FFI_OK = base.FFI_OK
|
||||
local SOCKET_CTX_INDEX = 1
|
||||
local ERR_BUF_SIZE = 4096
|
||||
|
||||
|
||||
local function get_tcp_socket(cosocket)
|
||||
local tcp_socket = cosocket[SOCKET_CTX_INDEX]
|
||||
if not tcp_socket then
|
||||
error("socket is never created nor connected")
|
||||
end
|
||||
|
||||
return tcp_socket
|
||||
end
|
||||
|
||||
|
||||
local function getoption(cosocket, option)
|
||||
local tcp_socket = get_tcp_socket(cosocket)
|
||||
|
||||
if option == nil then
|
||||
return nil, 'missing the "option" argument'
|
||||
end
|
||||
|
||||
if option_index[option] == nil then
|
||||
return nil, "unsupported option " .. tostring(option)
|
||||
end
|
||||
|
||||
local err = get_string_buf(ERR_BUF_SIZE)
|
||||
local errlen = get_size_ptr()
|
||||
errlen[0] = ERR_BUF_SIZE
|
||||
|
||||
local rc = C.ngx_http_lua_ffi_socket_tcp_getoption(tcp_socket,
|
||||
option_index[option],
|
||||
output_value_buf,
|
||||
err,
|
||||
errlen)
|
||||
if rc ~= FFI_OK then
|
||||
return nil, ffi_string(err, errlen[0])
|
||||
end
|
||||
|
||||
return tonumber(output_value_buf[0])
|
||||
end
|
||||
|
||||
|
||||
local function setoption(cosocket, option, value)
|
||||
local tcp_socket = get_tcp_socket(cosocket)
|
||||
|
||||
if option == nil then
|
||||
return nil, 'missing the "option" argument'
|
||||
end
|
||||
|
||||
if value == nil then
|
||||
return nil, 'missing the "value" argument'
|
||||
end
|
||||
|
||||
if option_index[option] == nil then
|
||||
return nil, "unsupported option " .. tostring(option)
|
||||
end
|
||||
|
||||
local err = get_string_buf(ERR_BUF_SIZE)
|
||||
local errlen = get_size_ptr()
|
||||
errlen[0] = ERR_BUF_SIZE
|
||||
|
||||
local rc = C.ngx_http_lua_ffi_socket_tcp_setoption(tcp_socket,
|
||||
option_index[option],
|
||||
value,
|
||||
err,
|
||||
errlen)
|
||||
if rc ~= FFI_OK then
|
||||
return nil, ffi_string(err, errlen[0])
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
|
||||
do
|
||||
local method_table = registry.__tcp_cosocket_mt
|
||||
method_table.getoption = getoption
|
||||
method_table.setoption = setoption
|
||||
end
|
||||
|
||||
|
||||
return { version = base.version }
|
||||
159
resty/core/time.lua
Normal file
159
resty/core/time.lua
Normal file
@@ -0,0 +1,159 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
local ffi = require 'ffi'
|
||||
local base = require "resty.core.base"
|
||||
|
||||
|
||||
local error = error
|
||||
local tonumber = tonumber
|
||||
local type = type
|
||||
local C = ffi.C
|
||||
local ffi_new = ffi.new
|
||||
local ffi_str = ffi.string
|
||||
local time_val = ffi_new("long[1]")
|
||||
local get_string_buf = base.get_string_buf
|
||||
local ngx = ngx
|
||||
local FFI_ERROR = base.FFI_ERROR
|
||||
local subsystem = ngx.config.subsystem
|
||||
|
||||
|
||||
local ngx_lua_ffi_now
|
||||
local ngx_lua_ffi_time
|
||||
local ngx_lua_ffi_today
|
||||
local ngx_lua_ffi_localtime
|
||||
local ngx_lua_ffi_utctime
|
||||
local ngx_lua_ffi_update_time
|
||||
|
||||
|
||||
if subsystem == 'http' then
|
||||
ffi.cdef[[
|
||||
double ngx_http_lua_ffi_now(void);
|
||||
long ngx_http_lua_ffi_time(void);
|
||||
void ngx_http_lua_ffi_today(unsigned char *buf);
|
||||
void ngx_http_lua_ffi_localtime(unsigned char *buf);
|
||||
void ngx_http_lua_ffi_utctime(unsigned char *buf);
|
||||
void ngx_http_lua_ffi_update_time(void);
|
||||
int ngx_http_lua_ffi_cookie_time(unsigned char *buf, long t);
|
||||
void ngx_http_lua_ffi_http_time(unsigned char *buf, long t);
|
||||
void ngx_http_lua_ffi_parse_http_time(const unsigned char *str, size_t len,
|
||||
long *time);
|
||||
]]
|
||||
|
||||
ngx_lua_ffi_now = C.ngx_http_lua_ffi_now
|
||||
ngx_lua_ffi_time = C.ngx_http_lua_ffi_time
|
||||
ngx_lua_ffi_today = C.ngx_http_lua_ffi_today
|
||||
ngx_lua_ffi_localtime = C.ngx_http_lua_ffi_localtime
|
||||
ngx_lua_ffi_utctime = C.ngx_http_lua_ffi_utctime
|
||||
ngx_lua_ffi_update_time = C.ngx_http_lua_ffi_update_time
|
||||
|
||||
elseif subsystem == 'stream' then
|
||||
ffi.cdef[[
|
||||
double ngx_stream_lua_ffi_now(void);
|
||||
long ngx_stream_lua_ffi_time(void);
|
||||
void ngx_stream_lua_ffi_today(unsigned char *buf);
|
||||
void ngx_stream_lua_ffi_localtime(unsigned char *buf);
|
||||
void ngx_stream_lua_ffi_utctime(unsigned char *buf);
|
||||
void ngx_stream_lua_ffi_update_time(void);
|
||||
]]
|
||||
|
||||
ngx_lua_ffi_now = C.ngx_stream_lua_ffi_now
|
||||
ngx_lua_ffi_time = C.ngx_stream_lua_ffi_time
|
||||
ngx_lua_ffi_today = C.ngx_stream_lua_ffi_today
|
||||
ngx_lua_ffi_localtime = C.ngx_stream_lua_ffi_localtime
|
||||
ngx_lua_ffi_utctime = C.ngx_stream_lua_ffi_utctime
|
||||
ngx_lua_ffi_update_time = C.ngx_stream_lua_ffi_update_time
|
||||
end
|
||||
|
||||
|
||||
function ngx.now()
|
||||
return tonumber(ngx_lua_ffi_now())
|
||||
end
|
||||
|
||||
|
||||
function ngx.time()
|
||||
return tonumber(ngx_lua_ffi_time())
|
||||
end
|
||||
|
||||
|
||||
function ngx.update_time()
|
||||
ngx_lua_ffi_update_time()
|
||||
end
|
||||
|
||||
|
||||
function ngx.today()
|
||||
-- the format of today is 2010-11-19
|
||||
local today_buf_size = 10
|
||||
local buf = get_string_buf(today_buf_size)
|
||||
ngx_lua_ffi_today(buf)
|
||||
return ffi_str(buf, today_buf_size)
|
||||
end
|
||||
|
||||
|
||||
function ngx.localtime()
|
||||
-- the format of localtime is 2010-11-19 20:56:31
|
||||
local localtime_buf_size = 19
|
||||
local buf = get_string_buf(localtime_buf_size)
|
||||
ngx_lua_ffi_localtime(buf)
|
||||
return ffi_str(buf, localtime_buf_size)
|
||||
end
|
||||
|
||||
|
||||
function ngx.utctime()
|
||||
-- the format of utctime is 2010-11-19 20:56:31
|
||||
local utctime_buf_size = 19
|
||||
local buf = get_string_buf(utctime_buf_size)
|
||||
ngx_lua_ffi_utctime(buf)
|
||||
return ffi_str(buf, utctime_buf_size)
|
||||
end
|
||||
|
||||
|
||||
if subsystem == 'http' then
|
||||
|
||||
function ngx.cookie_time(sec)
|
||||
if type(sec) ~= "number" then
|
||||
error("number argument only", 2)
|
||||
end
|
||||
|
||||
-- the format of cookie time is Mon, 28-Sep-2038 06:00:00 GMT
|
||||
-- or Mon, 28-Sep-18 06:00:00 GMT
|
||||
local cookie_time_buf_size = 29
|
||||
local buf = get_string_buf(cookie_time_buf_size)
|
||||
local used_size = C.ngx_http_lua_ffi_cookie_time(buf, sec)
|
||||
return ffi_str(buf, used_size)
|
||||
end
|
||||
|
||||
|
||||
function ngx.http_time(sec)
|
||||
if type(sec) ~= "number" then
|
||||
error("number argument only", 2)
|
||||
end
|
||||
|
||||
-- the format of http time is Mon, 28 Sep 1970 06:00:00 GMT
|
||||
local http_time_buf_size = 29
|
||||
local buf = get_string_buf(http_time_buf_size)
|
||||
C.ngx_http_lua_ffi_http_time(buf, sec)
|
||||
return ffi_str(buf, http_time_buf_size)
|
||||
end
|
||||
|
||||
|
||||
function ngx.parse_http_time(time_str)
|
||||
if type(time_str) ~= "string" then
|
||||
error("string argument only", 2)
|
||||
end
|
||||
|
||||
C.ngx_http_lua_ffi_parse_http_time(time_str, #time_str, time_val)
|
||||
|
||||
local res = time_val[0]
|
||||
if res == FFI_ERROR then
|
||||
return nil
|
||||
end
|
||||
|
||||
return tonumber(res)
|
||||
end
|
||||
|
||||
end
|
||||
|
||||
return {
|
||||
version = base.version
|
||||
}
|
||||
115
resty/core/uri.lua
Normal file
115
resty/core/uri.lua
Normal file
@@ -0,0 +1,115 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
local ffi = require "ffi"
|
||||
local base = require "resty.core.base"
|
||||
|
||||
|
||||
local C = ffi.C
|
||||
local ffi_string = ffi.string
|
||||
local ngx = ngx
|
||||
local type = type
|
||||
local error = error
|
||||
local tostring = tostring
|
||||
local get_string_buf = base.get_string_buf
|
||||
local subsystem = ngx.config.subsystem
|
||||
|
||||
|
||||
local ngx_lua_ffi_escape_uri
|
||||
local ngx_lua_ffi_unescape_uri
|
||||
local ngx_lua_ffi_uri_escaped_length
|
||||
|
||||
local NGX_ESCAPE_URI = 0
|
||||
local NGX_ESCAPE_URI_COMPONENT = 2
|
||||
local NGX_ESCAPE_MAIL_AUTH = 6
|
||||
|
||||
|
||||
if subsystem == "http" then
|
||||
ffi.cdef[[
|
||||
size_t ngx_http_lua_ffi_uri_escaped_length(const unsigned char *src,
|
||||
size_t len, int type);
|
||||
|
||||
void ngx_http_lua_ffi_escape_uri(const unsigned char *src, size_t len,
|
||||
unsigned char *dst, int type);
|
||||
|
||||
size_t ngx_http_lua_ffi_unescape_uri(const unsigned char *src,
|
||||
size_t len, unsigned char *dst);
|
||||
]]
|
||||
|
||||
ngx_lua_ffi_escape_uri = C.ngx_http_lua_ffi_escape_uri
|
||||
ngx_lua_ffi_unescape_uri = C.ngx_http_lua_ffi_unescape_uri
|
||||
ngx_lua_ffi_uri_escaped_length = C.ngx_http_lua_ffi_uri_escaped_length
|
||||
|
||||
elseif subsystem == "stream" then
|
||||
ffi.cdef[[
|
||||
size_t ngx_stream_lua_ffi_uri_escaped_length(const unsigned char *src,
|
||||
size_t len, int type);
|
||||
|
||||
void ngx_stream_lua_ffi_escape_uri(const unsigned char *src, size_t len,
|
||||
unsigned char *dst, int type);
|
||||
|
||||
size_t ngx_stream_lua_ffi_unescape_uri(const unsigned char *src,
|
||||
size_t len, unsigned char *dst);
|
||||
]]
|
||||
|
||||
ngx_lua_ffi_escape_uri = C.ngx_stream_lua_ffi_escape_uri
|
||||
ngx_lua_ffi_unescape_uri = C.ngx_stream_lua_ffi_unescape_uri
|
||||
ngx_lua_ffi_uri_escaped_length = C.ngx_stream_lua_ffi_uri_escaped_length
|
||||
end
|
||||
|
||||
|
||||
ngx.escape_uri = function (s, esc_type)
|
||||
if type(s) ~= 'string' then
|
||||
if not s then
|
||||
s = ''
|
||||
|
||||
else
|
||||
s = tostring(s)
|
||||
end
|
||||
end
|
||||
|
||||
if esc_type == nil then
|
||||
esc_type = NGX_ESCAPE_URI_COMPONENT
|
||||
|
||||
else
|
||||
if type(esc_type) ~= 'number' then
|
||||
error("\"type\" is not a number", 3)
|
||||
end
|
||||
|
||||
if esc_type < NGX_ESCAPE_URI or esc_type > NGX_ESCAPE_MAIL_AUTH then
|
||||
error("\"type\" " .. esc_type .. " out of range", 3)
|
||||
end
|
||||
end
|
||||
|
||||
local slen = #s
|
||||
local dlen = ngx_lua_ffi_uri_escaped_length(s, slen, esc_type)
|
||||
|
||||
-- print("dlen: ", tonumber(dlen))
|
||||
if dlen == slen then
|
||||
return s
|
||||
end
|
||||
local dst = get_string_buf(dlen)
|
||||
ngx_lua_ffi_escape_uri(s, slen, dst, esc_type)
|
||||
return ffi_string(dst, dlen)
|
||||
end
|
||||
|
||||
|
||||
ngx.unescape_uri = function (s)
|
||||
if type(s) ~= 'string' then
|
||||
if not s then
|
||||
s = ''
|
||||
else
|
||||
s = tostring(s)
|
||||
end
|
||||
end
|
||||
local slen = #s
|
||||
local dlen = slen
|
||||
local dst = get_string_buf(dlen)
|
||||
dlen = ngx_lua_ffi_unescape_uri(s, slen, dst)
|
||||
return ffi_string(dst, dlen)
|
||||
end
|
||||
|
||||
|
||||
return {
|
||||
version = base.version,
|
||||
}
|
||||
46
resty/core/utils.lua
Normal file
46
resty/core/utils.lua
Normal file
@@ -0,0 +1,46 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
local ffi = require "ffi"
|
||||
local base = require "resty.core.base"
|
||||
|
||||
|
||||
local C = ffi.C
|
||||
local ffi_str = ffi.string
|
||||
local ffi_copy = ffi.copy
|
||||
local byte = string.byte
|
||||
local str_find = string.find
|
||||
local get_string_buf = base.get_string_buf
|
||||
local subsystem = ngx.config.subsystem
|
||||
|
||||
|
||||
local _M = {
|
||||
version = base.version
|
||||
}
|
||||
|
||||
|
||||
if subsystem == "http" then
|
||||
ffi.cdef[[
|
||||
void ngx_http_lua_ffi_str_replace_char(unsigned char *buf, size_t len,
|
||||
const unsigned char find, const unsigned char replace);
|
||||
]]
|
||||
|
||||
|
||||
function _M.str_replace_char(str, find, replace)
|
||||
if not str_find(str, find, nil, true) then
|
||||
return str
|
||||
end
|
||||
|
||||
local len = #str
|
||||
local buf = get_string_buf(len)
|
||||
ffi_copy(buf, str, len)
|
||||
|
||||
C.ngx_http_lua_ffi_str_replace_char(buf, len, byte(find),
|
||||
byte(replace))
|
||||
|
||||
return ffi_str(buf, len)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
160
resty/core/var.lua
Normal file
160
resty/core/var.lua
Normal file
@@ -0,0 +1,160 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
local ffi = require "ffi"
|
||||
local base = require "resty.core.base"
|
||||
|
||||
|
||||
local C = ffi.C
|
||||
local ffi_new = ffi.new
|
||||
local ffi_str = ffi.string
|
||||
local type = type
|
||||
local error = error
|
||||
local tostring = tostring
|
||||
local setmetatable = setmetatable
|
||||
local get_request = base.get_request
|
||||
local get_string_buf = base.get_string_buf
|
||||
local get_size_ptr = base.get_size_ptr
|
||||
local new_tab = base.new_tab
|
||||
local subsystem = ngx.config.subsystem
|
||||
|
||||
|
||||
local ngx_lua_ffi_var_get
|
||||
local ngx_lua_ffi_var_set
|
||||
|
||||
|
||||
local ERR_BUF_SIZE = 256
|
||||
|
||||
|
||||
ngx.var = new_tab(0, 0)
|
||||
|
||||
|
||||
if subsystem == "http" then
|
||||
ffi.cdef[[
|
||||
int ngx_http_lua_ffi_var_get(ngx_http_request_t *r,
|
||||
const char *name_data, size_t name_len, char *lowcase_buf,
|
||||
int capture_id, char **value, size_t *value_len, char **err);
|
||||
|
||||
int ngx_http_lua_ffi_var_set(ngx_http_request_t *r,
|
||||
const unsigned char *name_data, size_t name_len,
|
||||
unsigned char *lowcase_buf, const unsigned char *value,
|
||||
size_t value_len, unsigned char *errbuf, size_t *errlen);
|
||||
]]
|
||||
|
||||
ngx_lua_ffi_var_get = C.ngx_http_lua_ffi_var_get
|
||||
ngx_lua_ffi_var_set = C.ngx_http_lua_ffi_var_set
|
||||
|
||||
elseif subsystem == "stream" then
|
||||
ffi.cdef[[
|
||||
int ngx_stream_lua_ffi_var_get(ngx_stream_lua_request_t *r,
|
||||
const char *name_data, size_t name_len, char *lowcase_buf,
|
||||
int capture_id, char **value, size_t *value_len, char **err);
|
||||
|
||||
int ngx_stream_lua_ffi_var_set(ngx_stream_lua_request_t *r,
|
||||
const unsigned char *name_data, size_t name_len,
|
||||
unsigned char *lowcase_buf, const unsigned char *value,
|
||||
size_t value_len, unsigned char *errbuf, size_t *errlen);
|
||||
]]
|
||||
|
||||
ngx_lua_ffi_var_get = C.ngx_stream_lua_ffi_var_get
|
||||
ngx_lua_ffi_var_set = C.ngx_stream_lua_ffi_var_set
|
||||
end
|
||||
|
||||
|
||||
local value_ptr = ffi_new("unsigned char *[1]")
|
||||
local errmsg = base.get_errmsg_ptr()
|
||||
|
||||
|
||||
local function var_get(self, name)
|
||||
local r = get_request()
|
||||
if not r then
|
||||
error("no request found")
|
||||
end
|
||||
|
||||
local value_len = get_size_ptr()
|
||||
local rc
|
||||
if type(name) == "number" then
|
||||
rc = ngx_lua_ffi_var_get(r, nil, 0, nil, name, value_ptr, value_len,
|
||||
errmsg)
|
||||
|
||||
else
|
||||
if type(name) ~= "string" then
|
||||
error("bad variable name", 2)
|
||||
end
|
||||
|
||||
local name_len = #name
|
||||
local lowcase_buf = get_string_buf(name_len)
|
||||
|
||||
rc = ngx_lua_ffi_var_get(r, name, name_len, lowcase_buf, 0, value_ptr,
|
||||
value_len, errmsg)
|
||||
end
|
||||
|
||||
-- ngx.log(ngx.WARN, "rc = ", rc)
|
||||
|
||||
if rc == 0 then -- NGX_OK
|
||||
return ffi_str(value_ptr[0], value_len[0])
|
||||
end
|
||||
|
||||
if rc == -5 then -- NGX_DECLINED
|
||||
return nil
|
||||
end
|
||||
|
||||
if rc == -1 then -- NGX_ERROR
|
||||
error(ffi_str(errmsg[0]), 2)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
local function var_set(self, name, value)
|
||||
local r = get_request()
|
||||
if not r then
|
||||
error("no request found")
|
||||
end
|
||||
|
||||
if type(name) ~= "string" then
|
||||
error("bad variable name", 2)
|
||||
end
|
||||
local name_len = #name
|
||||
|
||||
local errlen = get_size_ptr()
|
||||
errlen[0] = ERR_BUF_SIZE
|
||||
local lowcase_buf = get_string_buf(name_len + ERR_BUF_SIZE)
|
||||
|
||||
local value_len
|
||||
if value == nil then
|
||||
value_len = 0
|
||||
else
|
||||
if type(value) ~= 'string' then
|
||||
value = tostring(value)
|
||||
end
|
||||
value_len = #value
|
||||
end
|
||||
|
||||
local errbuf = lowcase_buf + name_len
|
||||
local rc = ngx_lua_ffi_var_set(r, name, name_len, lowcase_buf, value,
|
||||
value_len, errbuf, errlen)
|
||||
|
||||
-- ngx.log(ngx.WARN, "rc = ", rc)
|
||||
|
||||
if rc == 0 then -- NGX_OK
|
||||
return
|
||||
end
|
||||
|
||||
if rc == -1 then -- NGX_ERROR
|
||||
error(ffi_str(errbuf, errlen[0]), 2)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
do
|
||||
local mt = new_tab(0, 2)
|
||||
mt.__index = var_get
|
||||
mt.__newindex = var_set
|
||||
|
||||
setmetatable(ngx.var, mt)
|
||||
end
|
||||
|
||||
|
||||
return {
|
||||
version = base.version
|
||||
}
|
||||
77
resty/core/worker.lua
Normal file
77
resty/core/worker.lua
Normal file
@@ -0,0 +1,77 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
local ffi = require "ffi"
|
||||
local base = require "resty.core.base"
|
||||
|
||||
|
||||
local C = ffi.C
|
||||
local new_tab = base.new_tab
|
||||
local subsystem = ngx.config.subsystem
|
||||
|
||||
|
||||
local ngx_lua_ffi_worker_id
|
||||
local ngx_lua_ffi_worker_pid
|
||||
local ngx_lua_ffi_worker_count
|
||||
local ngx_lua_ffi_worker_exiting
|
||||
|
||||
|
||||
ngx.worker = new_tab(0, 4)
|
||||
|
||||
|
||||
if subsystem == "http" then
|
||||
ffi.cdef[[
|
||||
int ngx_http_lua_ffi_worker_id(void);
|
||||
int ngx_http_lua_ffi_worker_pid(void);
|
||||
int ngx_http_lua_ffi_worker_count(void);
|
||||
int ngx_http_lua_ffi_worker_exiting(void);
|
||||
]]
|
||||
|
||||
ngx_lua_ffi_worker_id = C.ngx_http_lua_ffi_worker_id
|
||||
ngx_lua_ffi_worker_pid = C.ngx_http_lua_ffi_worker_pid
|
||||
ngx_lua_ffi_worker_count = C.ngx_http_lua_ffi_worker_count
|
||||
ngx_lua_ffi_worker_exiting = C.ngx_http_lua_ffi_worker_exiting
|
||||
|
||||
elseif subsystem == "stream" then
|
||||
ffi.cdef[[
|
||||
int ngx_stream_lua_ffi_worker_id(void);
|
||||
int ngx_stream_lua_ffi_worker_pid(void);
|
||||
int ngx_stream_lua_ffi_worker_count(void);
|
||||
int ngx_stream_lua_ffi_worker_exiting(void);
|
||||
]]
|
||||
|
||||
ngx_lua_ffi_worker_id = C.ngx_stream_lua_ffi_worker_id
|
||||
ngx_lua_ffi_worker_pid = C.ngx_stream_lua_ffi_worker_pid
|
||||
ngx_lua_ffi_worker_count = C.ngx_stream_lua_ffi_worker_count
|
||||
ngx_lua_ffi_worker_exiting = C.ngx_stream_lua_ffi_worker_exiting
|
||||
end
|
||||
|
||||
|
||||
function ngx.worker.exiting()
|
||||
return ngx_lua_ffi_worker_exiting() ~= 0
|
||||
end
|
||||
|
||||
|
||||
function ngx.worker.pid()
|
||||
return ngx_lua_ffi_worker_pid()
|
||||
end
|
||||
|
||||
|
||||
function ngx.worker.id()
|
||||
local id = ngx_lua_ffi_worker_id()
|
||||
if id < 0 then
|
||||
return nil
|
||||
end
|
||||
|
||||
return id
|
||||
end
|
||||
|
||||
|
||||
function ngx.worker.count()
|
||||
return ngx_lua_ffi_worker_count()
|
||||
end
|
||||
|
||||
|
||||
return {
|
||||
_VERSION = base.version
|
||||
}
|
||||
982
resty/dns/resolver.lua
Normal file
982
resty/dns/resolver.lua
Normal file
@@ -0,0 +1,982 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
-- local socket = require "socket"
|
||||
local bit = require "bit"
|
||||
local udp = ngx.socket.udp
|
||||
local rand = math.random
|
||||
local char = string.char
|
||||
local byte = string.byte
|
||||
local find = string.find
|
||||
local gsub = string.gsub
|
||||
local sub = string.sub
|
||||
local rep = string.rep
|
||||
local format = string.format
|
||||
local band = bit.band
|
||||
local rshift = bit.rshift
|
||||
local lshift = bit.lshift
|
||||
local insert = table.insert
|
||||
local concat = table.concat
|
||||
local re_sub = ngx.re.sub
|
||||
local tcp = ngx.socket.tcp
|
||||
local log = ngx.log
|
||||
local DEBUG = ngx.DEBUG
|
||||
local unpack = unpack
|
||||
local setmetatable = setmetatable
|
||||
local type = type
|
||||
local ipairs = ipairs
|
||||
|
||||
|
||||
local ok, new_tab = pcall(require, "table.new")
|
||||
if not ok then
|
||||
new_tab = function (narr, nrec) return {} end
|
||||
end
|
||||
|
||||
|
||||
local DOT_CHAR = byte(".")
|
||||
local ZERO_CHAR = byte("0")
|
||||
local COLON_CHAR = byte(":")
|
||||
|
||||
local IP6_ARPA = "ip6.arpa"
|
||||
|
||||
local TYPE_A = 1
|
||||
local TYPE_NS = 2
|
||||
local TYPE_CNAME = 5
|
||||
local TYPE_SOA = 6
|
||||
local TYPE_PTR = 12
|
||||
local TYPE_MX = 15
|
||||
local TYPE_TXT = 16
|
||||
local TYPE_AAAA = 28
|
||||
local TYPE_SRV = 33
|
||||
local TYPE_SPF = 99
|
||||
|
||||
local CLASS_IN = 1
|
||||
|
||||
local SECTION_AN = 1
|
||||
local SECTION_NS = 2
|
||||
local SECTION_AR = 3
|
||||
|
||||
|
||||
local _M = {
|
||||
_VERSION = '0.22',
|
||||
TYPE_A = TYPE_A,
|
||||
TYPE_NS = TYPE_NS,
|
||||
TYPE_CNAME = TYPE_CNAME,
|
||||
TYPE_SOA = TYPE_SOA,
|
||||
TYPE_PTR = TYPE_PTR,
|
||||
TYPE_MX = TYPE_MX,
|
||||
TYPE_TXT = TYPE_TXT,
|
||||
TYPE_AAAA = TYPE_AAAA,
|
||||
TYPE_SRV = TYPE_SRV,
|
||||
TYPE_SPF = TYPE_SPF,
|
||||
CLASS_IN = CLASS_IN,
|
||||
SECTION_AN = SECTION_AN,
|
||||
SECTION_NS = SECTION_NS,
|
||||
SECTION_AR = SECTION_AR
|
||||
}
|
||||
|
||||
|
||||
local resolver_errstrs = {
|
||||
"format error", -- 1
|
||||
"server failure", -- 2
|
||||
"name error", -- 3
|
||||
"not implemented", -- 4
|
||||
"refused", -- 5
|
||||
}
|
||||
|
||||
local soa_int32_fields = { "serial", "refresh", "retry", "expire", "minimum" }
|
||||
|
||||
local mt = { __index = _M }
|
||||
|
||||
|
||||
local arpa_tmpl = new_tab(72, 0)
|
||||
|
||||
for i = 1, #IP6_ARPA do
|
||||
arpa_tmpl[64 + i] = byte(IP6_ARPA, i)
|
||||
end
|
||||
|
||||
for i = 2, 64, 2 do
|
||||
arpa_tmpl[i] = DOT_CHAR
|
||||
end
|
||||
|
||||
|
||||
function _M.new(class, opts)
|
||||
if not opts then
|
||||
return nil, "no options table specified"
|
||||
end
|
||||
|
||||
local servers = opts.nameservers
|
||||
if not servers or #servers == 0 then
|
||||
return nil, "no nameservers specified"
|
||||
end
|
||||
|
||||
local timeout = opts.timeout or 2000 -- default 2 sec
|
||||
|
||||
local n = #servers
|
||||
|
||||
local socks = {}
|
||||
|
||||
for i = 1, n do
|
||||
local server = servers[i]
|
||||
local sock, err = udp()
|
||||
if not sock then
|
||||
return nil, "failed to create udp socket: " .. err
|
||||
end
|
||||
|
||||
local host, port
|
||||
if type(server) == 'table' then
|
||||
host = server[1]
|
||||
port = server[2] or 53
|
||||
|
||||
else
|
||||
host = server
|
||||
port = 53
|
||||
servers[i] = {host, port}
|
||||
end
|
||||
|
||||
local ok, err = sock:setpeername(host, port)
|
||||
if not ok then
|
||||
return nil, "failed to set peer name: " .. err
|
||||
end
|
||||
|
||||
sock:settimeout(timeout)
|
||||
|
||||
insert(socks, sock)
|
||||
end
|
||||
|
||||
local tcp_sock, err = tcp()
|
||||
if not tcp_sock then
|
||||
return nil, "failed to create tcp socket: " .. err
|
||||
end
|
||||
|
||||
tcp_sock:settimeout(timeout)
|
||||
|
||||
return setmetatable(
|
||||
{ cur = opts.no_random and 1 or rand(1, n),
|
||||
socks = socks,
|
||||
tcp_sock = tcp_sock,
|
||||
servers = servers,
|
||||
retrans = opts.retrans or 5,
|
||||
no_recurse = opts.no_recurse,
|
||||
}, mt)
|
||||
end
|
||||
|
||||
|
||||
local function pick_sock(self, socks)
|
||||
local cur = self.cur
|
||||
|
||||
if cur == #socks then
|
||||
self.cur = 1
|
||||
else
|
||||
self.cur = cur + 1
|
||||
end
|
||||
|
||||
return socks[cur]
|
||||
end
|
||||
|
||||
|
||||
local function _get_cur_server(self)
|
||||
local cur = self.cur
|
||||
|
||||
local servers = self.servers
|
||||
|
||||
if cur == 1 then
|
||||
return servers[#servers]
|
||||
end
|
||||
|
||||
return servers[cur - 1]
|
||||
end
|
||||
|
||||
|
||||
function _M.set_timeout(self, timeout)
|
||||
local socks = self.socks
|
||||
if not socks then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
for i = 1, #socks do
|
||||
local sock = socks[i]
|
||||
sock:settimeout(timeout)
|
||||
end
|
||||
|
||||
local tcp_sock = self.tcp_sock
|
||||
if not tcp_sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
tcp_sock:settimeout(timeout)
|
||||
end
|
||||
|
||||
|
||||
local function _encode_name(s)
|
||||
return char(#s) .. s
|
||||
end
|
||||
|
||||
|
||||
local function _decode_name(buf, pos)
|
||||
local labels = {}
|
||||
local nptrs = 0
|
||||
local p = pos
|
||||
while nptrs < 128 do
|
||||
local fst = byte(buf, p)
|
||||
|
||||
if not fst then
|
||||
return nil, 'truncated';
|
||||
end
|
||||
|
||||
-- print("fst at ", p, ": ", fst)
|
||||
|
||||
if fst == 0 then
|
||||
if nptrs == 0 then
|
||||
pos = pos + 1
|
||||
end
|
||||
break
|
||||
end
|
||||
|
||||
if band(fst, 0xc0) ~= 0 then
|
||||
-- being a pointer
|
||||
if nptrs == 0 then
|
||||
pos = pos + 2
|
||||
end
|
||||
|
||||
nptrs = nptrs + 1
|
||||
|
||||
local snd = byte(buf, p + 1)
|
||||
if not snd then
|
||||
return nil, 'truncated'
|
||||
end
|
||||
|
||||
p = lshift(band(fst, 0x3f), 8) + snd + 1
|
||||
|
||||
-- print("resolving ptr ", p, ": ", byte(buf, p))
|
||||
|
||||
else
|
||||
-- being a label
|
||||
local label = sub(buf, p + 1, p + fst)
|
||||
insert(labels, label)
|
||||
|
||||
-- print("resolved label ", label)
|
||||
|
||||
p = p + fst + 1
|
||||
|
||||
if nptrs == 0 then
|
||||
pos = p
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return concat(labels, "."), pos
|
||||
end
|
||||
|
||||
|
||||
local function _build_request(qname, id, no_recurse, opts)
|
||||
local qtype
|
||||
|
||||
if opts then
|
||||
qtype = opts.qtype
|
||||
end
|
||||
|
||||
if not qtype then
|
||||
qtype = 1 -- A record
|
||||
end
|
||||
|
||||
local ident_hi = char(rshift(id, 8))
|
||||
local ident_lo = char(band(id, 0xff))
|
||||
|
||||
local flags
|
||||
if no_recurse then
|
||||
-- print("found no recurse")
|
||||
flags = "\0\0"
|
||||
else
|
||||
flags = "\1\0"
|
||||
end
|
||||
|
||||
local nqs = "\0\1"
|
||||
local nan = "\0\0"
|
||||
local nns = "\0\0"
|
||||
local nar = "\0\0"
|
||||
local typ = char(rshift(qtype, 8), band(qtype, 0xff))
|
||||
local class = "\0\1" -- the Internet class
|
||||
|
||||
if byte(qname, 1) == DOT_CHAR then
|
||||
return nil, "bad name"
|
||||
end
|
||||
|
||||
local name = gsub(qname, "([^.]+)%.?", _encode_name) .. '\0'
|
||||
|
||||
return {
|
||||
ident_hi, ident_lo, flags, nqs, nan, nns, nar,
|
||||
name, typ, class
|
||||
}
|
||||
end
|
||||
|
||||
|
||||
local function parse_section(answers, section, buf, start_pos, size,
|
||||
should_skip)
|
||||
local pos = start_pos
|
||||
|
||||
for _ = 1, size do
|
||||
-- print(format("ans %d: qtype:%d qclass:%d", i, qtype, qclass))
|
||||
local ans = {}
|
||||
|
||||
if not should_skip then
|
||||
insert(answers, ans)
|
||||
end
|
||||
|
||||
ans.section = section
|
||||
|
||||
local name
|
||||
name, pos = _decode_name(buf, pos)
|
||||
if not name then
|
||||
return nil, pos
|
||||
end
|
||||
|
||||
ans.name = name
|
||||
|
||||
-- print("name: ", name)
|
||||
|
||||
local type_hi = byte(buf, pos)
|
||||
local type_lo = byte(buf, pos + 1)
|
||||
local typ = lshift(type_hi, 8) + type_lo
|
||||
|
||||
ans.type = typ
|
||||
|
||||
-- print("type: ", typ)
|
||||
|
||||
local class_hi = byte(buf, pos + 2)
|
||||
local class_lo = byte(buf, pos + 3)
|
||||
local class = lshift(class_hi, 8) + class_lo
|
||||
|
||||
ans.class = class
|
||||
|
||||
-- print("class: ", class)
|
||||
|
||||
local byte_1, byte_2, byte_3, byte_4 = byte(buf, pos + 4, pos + 7)
|
||||
|
||||
local ttl = lshift(byte_1, 24) + lshift(byte_2, 16)
|
||||
+ lshift(byte_3, 8) + byte_4
|
||||
|
||||
-- print("ttl: ", ttl)
|
||||
|
||||
ans.ttl = ttl
|
||||
|
||||
local len_hi = byte(buf, pos + 8)
|
||||
local len_lo = byte(buf, pos + 9)
|
||||
local len = lshift(len_hi, 8) + len_lo
|
||||
|
||||
-- print("record len: ", len)
|
||||
|
||||
pos = pos + 10
|
||||
|
||||
if typ == TYPE_A then
|
||||
|
||||
if len ~= 4 then
|
||||
return nil, "bad A record value length: " .. len
|
||||
end
|
||||
|
||||
local addr_bytes = { byte(buf, pos, pos + 3) }
|
||||
local addr = concat(addr_bytes, ".")
|
||||
-- print("ipv4 address: ", addr)
|
||||
|
||||
ans.address = addr
|
||||
|
||||
pos = pos + 4
|
||||
|
||||
elseif typ == TYPE_CNAME then
|
||||
|
||||
local cname, p = _decode_name(buf, pos)
|
||||
if not cname then
|
||||
return nil, pos
|
||||
end
|
||||
|
||||
if p - pos ~= len then
|
||||
return nil, format("bad cname record length: %d ~= %d",
|
||||
p - pos, len)
|
||||
end
|
||||
|
||||
pos = p
|
||||
|
||||
-- print("cname: ", cname)
|
||||
|
||||
ans.cname = cname
|
||||
|
||||
elseif typ == TYPE_AAAA then
|
||||
|
||||
if len ~= 16 then
|
||||
return nil, "bad AAAA record value length: " .. len
|
||||
end
|
||||
|
||||
local addr_bytes = { byte(buf, pos, pos + 15) }
|
||||
local flds = {}
|
||||
for i = 1, 16, 2 do
|
||||
local a = addr_bytes[i]
|
||||
local b = addr_bytes[i + 1]
|
||||
if a == 0 then
|
||||
insert(flds, format("%x", b))
|
||||
|
||||
else
|
||||
insert(flds, format("%x%02x", a, b))
|
||||
end
|
||||
end
|
||||
|
||||
-- we do not compress the IPv6 addresses by default
|
||||
-- due to performance considerations
|
||||
|
||||
ans.address = concat(flds, ":")
|
||||
|
||||
pos = pos + 16
|
||||
|
||||
elseif typ == TYPE_MX then
|
||||
|
||||
-- print("len = ", len)
|
||||
|
||||
if len < 3 then
|
||||
return nil, "bad MX record value length: " .. len
|
||||
end
|
||||
|
||||
local pref_hi = byte(buf, pos)
|
||||
local pref_lo = byte(buf, pos + 1)
|
||||
|
||||
ans.preference = lshift(pref_hi, 8) + pref_lo
|
||||
|
||||
local host, p = _decode_name(buf, pos + 2)
|
||||
if not host then
|
||||
return nil, pos
|
||||
end
|
||||
|
||||
if p - pos ~= len then
|
||||
return nil, format("bad cname record length: %d ~= %d",
|
||||
p - pos, len)
|
||||
end
|
||||
|
||||
ans.exchange = host
|
||||
|
||||
pos = p
|
||||
|
||||
elseif typ == TYPE_SRV then
|
||||
if len < 7 then
|
||||
return nil, "bad SRV record value length: " .. len
|
||||
end
|
||||
|
||||
local prio_hi = byte(buf, pos)
|
||||
local prio_lo = byte(buf, pos + 1)
|
||||
ans.priority = lshift(prio_hi, 8) + prio_lo
|
||||
|
||||
local weight_hi = byte(buf, pos + 2)
|
||||
local weight_lo = byte(buf, pos + 3)
|
||||
ans.weight = lshift(weight_hi, 8) + weight_lo
|
||||
|
||||
local port_hi = byte(buf, pos + 4)
|
||||
local port_lo = byte(buf, pos + 5)
|
||||
ans.port = lshift(port_hi, 8) + port_lo
|
||||
|
||||
local name, p = _decode_name(buf, pos + 6)
|
||||
if not name then
|
||||
return nil, pos
|
||||
end
|
||||
|
||||
if p - pos ~= len then
|
||||
return nil, format("bad srv record length: %d ~= %d",
|
||||
p - pos, len)
|
||||
end
|
||||
|
||||
ans.target = name
|
||||
|
||||
pos = p
|
||||
|
||||
elseif typ == TYPE_NS then
|
||||
|
||||
local name, p = _decode_name(buf, pos)
|
||||
if not name then
|
||||
return nil, pos
|
||||
end
|
||||
|
||||
if p - pos ~= len then
|
||||
return nil, format("bad cname record length: %d ~= %d",
|
||||
p - pos, len)
|
||||
end
|
||||
|
||||
pos = p
|
||||
|
||||
-- print("name: ", name)
|
||||
|
||||
ans.nsdname = name
|
||||
|
||||
elseif typ == TYPE_TXT or typ == TYPE_SPF then
|
||||
|
||||
local key = (typ == TYPE_TXT) and "txt" or "spf"
|
||||
|
||||
local slen = byte(buf, pos)
|
||||
if slen + 1 > len then
|
||||
-- truncate the over-run TXT record data
|
||||
slen = len
|
||||
end
|
||||
|
||||
-- print("slen: ", len)
|
||||
|
||||
local val = sub(buf, pos + 1, pos + slen)
|
||||
local last = pos + len
|
||||
pos = pos + slen + 1
|
||||
|
||||
if pos < last then
|
||||
-- more strings to be processed
|
||||
-- this code path is usually cold, so we do not
|
||||
-- merge the following loop on this code path
|
||||
-- with the processing logic above.
|
||||
|
||||
val = {val}
|
||||
local idx = 2
|
||||
repeat
|
||||
local slen = byte(buf, pos)
|
||||
if pos + slen + 1 > last then
|
||||
-- truncate the over-run TXT record data
|
||||
slen = last - pos - 1
|
||||
end
|
||||
|
||||
val[idx] = sub(buf, pos + 1, pos + slen)
|
||||
idx = idx + 1
|
||||
pos = pos + slen + 1
|
||||
|
||||
until pos >= last
|
||||
end
|
||||
|
||||
ans[key] = val
|
||||
|
||||
elseif typ == TYPE_PTR then
|
||||
|
||||
local name, p = _decode_name(buf, pos)
|
||||
if not name then
|
||||
return nil, pos
|
||||
end
|
||||
|
||||
if p - pos ~= len then
|
||||
return nil, format("bad cname record length: %d ~= %d",
|
||||
p - pos, len)
|
||||
end
|
||||
|
||||
pos = p
|
||||
|
||||
-- print("name: ", name)
|
||||
|
||||
ans.ptrdname = name
|
||||
|
||||
elseif typ == TYPE_SOA then
|
||||
local name, p = _decode_name(buf, pos)
|
||||
if not name then
|
||||
return nil, pos
|
||||
end
|
||||
ans.mname = name
|
||||
|
||||
pos = p
|
||||
name, p = _decode_name(buf, pos)
|
||||
if not name then
|
||||
return nil, pos
|
||||
end
|
||||
ans.rname = name
|
||||
|
||||
for _, field in ipairs(soa_int32_fields) do
|
||||
local byte_1, byte_2, byte_3, byte_4 = byte(buf, p, p + 3)
|
||||
ans[field] = lshift(byte_1, 24) + lshift(byte_2, 16)
|
||||
+ lshift(byte_3, 8) + byte_4
|
||||
p = p + 4
|
||||
end
|
||||
|
||||
pos = p
|
||||
|
||||
else
|
||||
-- for unknown types, just forward the raw value
|
||||
|
||||
ans.rdata = sub(buf, pos, pos + len - 1)
|
||||
pos = pos + len
|
||||
end
|
||||
end
|
||||
|
||||
return pos
|
||||
end
|
||||
|
||||
|
||||
local function parse_response(buf, id, opts)
|
||||
local n = #buf
|
||||
if n < 12 then
|
||||
return nil, 'truncated';
|
||||
end
|
||||
|
||||
-- header layout: ident flags nqs nan nns nar
|
||||
|
||||
local ident_hi = byte(buf, 1)
|
||||
local ident_lo = byte(buf, 2)
|
||||
local ans_id = lshift(ident_hi, 8) + ident_lo
|
||||
|
||||
-- print("id: ", id, ", ans id: ", ans_id)
|
||||
|
||||
if ans_id ~= id then
|
||||
-- identifier mismatch and throw it away
|
||||
log(DEBUG, "id mismatch in the DNS reply: ", ans_id, " ~= ", id)
|
||||
return nil, "id mismatch"
|
||||
end
|
||||
|
||||
local flags_hi = byte(buf, 3)
|
||||
local flags_lo = byte(buf, 4)
|
||||
local flags = lshift(flags_hi, 8) + flags_lo
|
||||
|
||||
-- print(format("flags: 0x%x", flags))
|
||||
|
||||
if band(flags, 0x8000) == 0 then
|
||||
return nil, format("bad QR flag in the DNS response")
|
||||
end
|
||||
|
||||
if band(flags, 0x200) ~= 0 then
|
||||
return nil, "truncated"
|
||||
end
|
||||
|
||||
local code = band(flags, 0xf)
|
||||
|
||||
-- print(format("code: %d", code))
|
||||
|
||||
local nqs_hi = byte(buf, 5)
|
||||
local nqs_lo = byte(buf, 6)
|
||||
local nqs = lshift(nqs_hi, 8) + nqs_lo
|
||||
|
||||
-- print("nqs: ", nqs)
|
||||
|
||||
if nqs ~= 1 then
|
||||
return nil, format("bad number of questions in DNS response: %d", nqs)
|
||||
end
|
||||
|
||||
local nan_hi = byte(buf, 7)
|
||||
local nan_lo = byte(buf, 8)
|
||||
local nan = lshift(nan_hi, 8) + nan_lo
|
||||
|
||||
-- print("nan: ", nan)
|
||||
|
||||
local nns_hi = byte(buf, 9)
|
||||
local nns_lo = byte(buf, 10)
|
||||
local nns = lshift(nns_hi, 8) + nns_lo
|
||||
|
||||
local nar_hi = byte(buf, 11)
|
||||
local nar_lo = byte(buf, 12)
|
||||
local nar = lshift(nar_hi, 8) + nar_lo
|
||||
|
||||
-- skip the question part
|
||||
|
||||
local ans_qname, pos = _decode_name(buf, 13)
|
||||
if not ans_qname then
|
||||
return nil, pos
|
||||
end
|
||||
|
||||
-- print("qname in reply: ", ans_qname)
|
||||
|
||||
-- print("question: ", sub(buf, 13, pos))
|
||||
|
||||
if pos + 3 + nan * 12 > n then
|
||||
-- print(format("%d > %d", pos + 3 + nan * 12, n))
|
||||
return nil, 'truncated';
|
||||
end
|
||||
|
||||
-- question section layout: qname qtype(2) qclass(2)
|
||||
|
||||
--[[
|
||||
local type_hi = byte(buf, pos)
|
||||
local type_lo = byte(buf, pos + 1)
|
||||
local ans_type = lshift(type_hi, 8) + type_lo
|
||||
]]
|
||||
|
||||
-- print("ans qtype: ", ans_type)
|
||||
|
||||
local class_hi = byte(buf, pos + 2)
|
||||
local class_lo = byte(buf, pos + 3)
|
||||
local qclass = lshift(class_hi, 8) + class_lo
|
||||
|
||||
-- print("ans qclass: ", qclass)
|
||||
|
||||
if qclass ~= 1 then
|
||||
return nil, format("unknown query class %d in DNS response", qclass)
|
||||
end
|
||||
|
||||
pos = pos + 4
|
||||
|
||||
local answers = {}
|
||||
|
||||
if code ~= 0 then
|
||||
answers.errcode = code
|
||||
answers.errstr = resolver_errstrs[code] or "unknown"
|
||||
end
|
||||
|
||||
local authority_section, additional_section
|
||||
|
||||
if opts then
|
||||
authority_section = opts.authority_section
|
||||
additional_section = opts.additional_section
|
||||
if opts.qtype == TYPE_SOA then
|
||||
authority_section = true
|
||||
end
|
||||
end
|
||||
|
||||
local err
|
||||
|
||||
pos, err = parse_section(answers, SECTION_AN, buf, pos, nan)
|
||||
|
||||
if not pos then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
if not authority_section and not additional_section then
|
||||
return answers
|
||||
end
|
||||
|
||||
pos, err = parse_section(answers, SECTION_NS, buf, pos, nns,
|
||||
not authority_section)
|
||||
|
||||
if not pos then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
if not additional_section then
|
||||
return answers
|
||||
end
|
||||
|
||||
pos, err = parse_section(answers, SECTION_AR, buf, pos, nar)
|
||||
|
||||
if not pos then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
return answers
|
||||
end
|
||||
|
||||
|
||||
local function _gen_id(self)
|
||||
local id = self._id -- for regression testing
|
||||
if id then
|
||||
return id
|
||||
end
|
||||
return rand(0, 65535) -- two bytes
|
||||
end
|
||||
|
||||
|
||||
local function _tcp_query(self, query, id, opts)
|
||||
local sock = self.tcp_sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
log(DEBUG, "query the TCP server due to reply truncation")
|
||||
|
||||
local server = _get_cur_server(self)
|
||||
|
||||
local ok, err = sock:connect(server[1], server[2])
|
||||
if not ok then
|
||||
return nil, "failed to connect to TCP server "
|
||||
.. concat(server, ":") .. ": " .. err
|
||||
end
|
||||
|
||||
query = concat(query, "")
|
||||
local len = #query
|
||||
|
||||
local len_hi = char(rshift(len, 8))
|
||||
local len_lo = char(band(len, 0xff))
|
||||
|
||||
local bytes, err = sock:send({len_hi, len_lo, query})
|
||||
if not bytes then
|
||||
return nil, "failed to send query to TCP server "
|
||||
.. concat(server, ":") .. ": " .. err
|
||||
end
|
||||
|
||||
local buf, err = sock:receive(2)
|
||||
if not buf then
|
||||
return nil, "failed to receive the reply length field from TCP server "
|
||||
.. concat(server, ":") .. ": " .. err
|
||||
end
|
||||
|
||||
len_hi = byte(buf, 1)
|
||||
len_lo = byte(buf, 2)
|
||||
len = lshift(len_hi, 8) + len_lo
|
||||
|
||||
-- print("tcp message len: ", len)
|
||||
|
||||
buf, err = sock:receive(len)
|
||||
if not buf then
|
||||
return nil, "failed to receive the reply message body from TCP server "
|
||||
.. concat(server, ":") .. ": " .. err
|
||||
end
|
||||
|
||||
local answers, err = parse_response(buf, id, opts)
|
||||
if not answers then
|
||||
return nil, "failed to parse the reply from the TCP server "
|
||||
.. concat(server, ":") .. ": " .. err
|
||||
end
|
||||
|
||||
sock:close()
|
||||
|
||||
return answers
|
||||
end
|
||||
|
||||
|
||||
function _M.tcp_query(self, qname, opts)
|
||||
local socks = self.socks
|
||||
if not socks then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
pick_sock(self, socks)
|
||||
|
||||
local id = _gen_id(self)
|
||||
|
||||
local query, err = _build_request(qname, id, self.no_recurse, opts)
|
||||
if not query then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
return _tcp_query(self, query, id, opts)
|
||||
end
|
||||
|
||||
|
||||
function _M.query(self, qname, opts, tries)
|
||||
local socks = self.socks
|
||||
if not socks then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
local id = _gen_id(self)
|
||||
|
||||
local query, err = _build_request(qname, id, self.no_recurse, opts)
|
||||
if not query then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
-- local cjson = require "cjson"
|
||||
-- print("query: ", cjson.encode(concat(query, "")))
|
||||
|
||||
local retrans = self.retrans
|
||||
if tries then
|
||||
tries[1] = nil
|
||||
end
|
||||
|
||||
-- print("retrans: ", retrans)
|
||||
|
||||
for i = 1, retrans do
|
||||
local sock = pick_sock(self, socks)
|
||||
|
||||
local ok
|
||||
ok, err = sock:send(query)
|
||||
if not ok then
|
||||
local server = _get_cur_server(self)
|
||||
err = "failed to send request to UDP server "
|
||||
.. concat(server, ":") .. ": " .. err
|
||||
|
||||
else
|
||||
local buf
|
||||
|
||||
for _ = 1, 128 do
|
||||
buf, err = sock:receive(4096)
|
||||
if err then
|
||||
local server = _get_cur_server(self)
|
||||
err = "failed to receive reply from UDP server "
|
||||
.. concat(server, ":") .. ": " .. err
|
||||
break
|
||||
end
|
||||
|
||||
if buf then
|
||||
local answers
|
||||
answers, err = parse_response(buf, id, opts)
|
||||
if err == "truncated" then
|
||||
answers, err = _tcp_query(self, query, id, opts)
|
||||
end
|
||||
|
||||
if err and err ~= "id mismatch" then
|
||||
break
|
||||
end
|
||||
|
||||
if answers then
|
||||
return answers, nil, tries
|
||||
end
|
||||
end
|
||||
-- only here in case of an "id mismatch"
|
||||
end
|
||||
end
|
||||
|
||||
if tries then
|
||||
tries[i] = err
|
||||
tries[i + 1] = nil -- ensure termination for user supplied table
|
||||
end
|
||||
end
|
||||
|
||||
return nil, err, tries
|
||||
end
|
||||
|
||||
|
||||
function _M.compress_ipv6_addr(addr)
|
||||
local addr = re_sub(addr, "^(0:)+|(:0)+$|:(0:)+", "::", "jo")
|
||||
if addr == "::0" then
|
||||
addr = "::"
|
||||
end
|
||||
|
||||
return addr
|
||||
end
|
||||
|
||||
|
||||
local function _expand_ipv6_addr(addr)
|
||||
if find(addr, "::", 1, true) then
|
||||
local ncol, addrlen = 8, #addr
|
||||
|
||||
for i = 1, addrlen do
|
||||
if byte(addr, i) == COLON_CHAR then
|
||||
ncol = ncol - 1
|
||||
end
|
||||
end
|
||||
|
||||
if byte(addr, 1) == COLON_CHAR then
|
||||
addr = "0" .. addr
|
||||
end
|
||||
|
||||
if byte(addr, -1) == COLON_CHAR then
|
||||
addr = addr .. "0"
|
||||
end
|
||||
|
||||
addr = re_sub(addr, "::", ":" .. rep("0:", ncol), "jo")
|
||||
end
|
||||
|
||||
return addr
|
||||
end
|
||||
|
||||
|
||||
_M.expand_ipv6_addr = _expand_ipv6_addr
|
||||
|
||||
|
||||
function _M.arpa_str(addr)
|
||||
if find(addr, ":", 1, true) then
|
||||
addr = _expand_ipv6_addr(addr)
|
||||
local idx, hidx, addrlen = 1, 1, #addr
|
||||
|
||||
for i = addrlen, 0, -1 do
|
||||
local s = byte(addr, i)
|
||||
if s == COLON_CHAR or not s then
|
||||
for _ = hidx, 4 do
|
||||
arpa_tmpl[idx] = ZERO_CHAR
|
||||
idx = idx + 2
|
||||
end
|
||||
hidx = 1
|
||||
else
|
||||
arpa_tmpl[idx] = s
|
||||
idx = idx + 2
|
||||
hidx = hidx + 1
|
||||
end
|
||||
end
|
||||
|
||||
addr = char(unpack(arpa_tmpl))
|
||||
else
|
||||
addr = re_sub(addr, [[(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})]],
|
||||
"$4.$3.$2.$1.in-addr.arpa", "ajo")
|
||||
end
|
||||
|
||||
return addr
|
||||
end
|
||||
|
||||
|
||||
function _M.reverse_query(self, addr)
|
||||
return self.query(self, self.arpa_str(addr),
|
||||
{qtype = self.TYPE_PTR})
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
125
resty/limit/conn.lua
Normal file
125
resty/limit/conn.lua
Normal file
@@ -0,0 +1,125 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
--
|
||||
-- This library is an enhanced Lua port of the standard ngx_limit_conn
|
||||
-- module.
|
||||
|
||||
|
||||
local math = require "math"
|
||||
|
||||
|
||||
local setmetatable = setmetatable
|
||||
local floor = math.floor
|
||||
local ngx_shared = ngx.shared
|
||||
local assert = assert
|
||||
|
||||
|
||||
local _M = {
|
||||
_VERSION = '0.07'
|
||||
}
|
||||
|
||||
|
||||
local mt = {
|
||||
__index = _M
|
||||
}
|
||||
|
||||
|
||||
function _M.new(dict_name, max, burst, default_conn_delay)
|
||||
local dict = ngx_shared[dict_name]
|
||||
if not dict then
|
||||
return nil, "shared dict not found"
|
||||
end
|
||||
|
||||
assert(max > 0 and burst >= 0 and default_conn_delay > 0)
|
||||
|
||||
local self = {
|
||||
dict = dict,
|
||||
max = max + 0, -- just to ensure the param is good
|
||||
burst = burst,
|
||||
unit_delay = default_conn_delay,
|
||||
}
|
||||
|
||||
return setmetatable(self, mt)
|
||||
end
|
||||
|
||||
|
||||
function _M.incoming(self, key, commit)
|
||||
local dict = self.dict
|
||||
local max = self.max
|
||||
|
||||
self.committed = false
|
||||
|
||||
local conn, err
|
||||
if commit then
|
||||
conn, err = dict:incr(key, 1, 0)
|
||||
if not conn then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
if conn > max + self.burst then
|
||||
conn, err = dict:incr(key, -1)
|
||||
if not conn then
|
||||
return nil, err
|
||||
end
|
||||
return nil, "rejected"
|
||||
end
|
||||
self.committed = true
|
||||
|
||||
else
|
||||
conn = (dict:get(key) or 0) + 1
|
||||
if conn > max + self.burst then
|
||||
return nil, "rejected"
|
||||
end
|
||||
end
|
||||
|
||||
if conn > max then
|
||||
-- make the exessive connections wait
|
||||
return self.unit_delay * floor((conn - 1) / max), conn
|
||||
end
|
||||
|
||||
-- we return a 0 delay by default
|
||||
return 0, conn
|
||||
end
|
||||
|
||||
|
||||
function _M.is_committed(self)
|
||||
return self.committed
|
||||
end
|
||||
|
||||
|
||||
function _M.leaving(self, key, req_latency)
|
||||
assert(key)
|
||||
local dict = self.dict
|
||||
|
||||
local conn, err = dict:incr(key, -1)
|
||||
if not conn then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
if req_latency then
|
||||
local unit_delay = self.unit_delay
|
||||
self.unit_delay = (req_latency + unit_delay) / 2
|
||||
end
|
||||
|
||||
return conn
|
||||
end
|
||||
|
||||
|
||||
function _M.uncommit(self, key)
|
||||
assert(key)
|
||||
local dict = self.dict
|
||||
|
||||
return dict:incr(key, -1)
|
||||
end
|
||||
|
||||
|
||||
function _M.set_conn(self, conn)
|
||||
self.max = conn
|
||||
end
|
||||
|
||||
|
||||
function _M.set_burst(self, burst)
|
||||
self.burst = burst
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
103
resty/limit/count.lua
Normal file
103
resty/limit/count.lua
Normal file
@@ -0,0 +1,103 @@
|
||||
-- implement GitHub request rate limiting:
|
||||
-- https://developer.github.com/v3/#rate-limiting
|
||||
|
||||
local ngx_shared = ngx.shared
|
||||
local setmetatable = setmetatable
|
||||
local assert = assert
|
||||
|
||||
|
||||
local _M = {
|
||||
_VERSION = '0.07'
|
||||
}
|
||||
|
||||
|
||||
local mt = {
|
||||
__index = _M
|
||||
}
|
||||
|
||||
|
||||
-- the "limit" argument controls number of request allowed in a time window.
|
||||
-- time "window" argument controls the time window in seconds.
|
||||
function _M.new(dict_name, limit, window)
|
||||
local dict = ngx_shared[dict_name]
|
||||
if not dict then
|
||||
return nil, "shared dict not found"
|
||||
end
|
||||
|
||||
assert(limit > 0 and window > 0)
|
||||
|
||||
local self = {
|
||||
dict = dict,
|
||||
limit = limit,
|
||||
window = window,
|
||||
}
|
||||
|
||||
return setmetatable(self, mt)
|
||||
end
|
||||
|
||||
|
||||
function _M.incoming(self, key, commit)
|
||||
local dict = self.dict
|
||||
local limit = self.limit
|
||||
local window = self.window
|
||||
|
||||
local remaining, ok, err
|
||||
|
||||
if commit then
|
||||
remaining, err = dict:incr(key, -1, limit)
|
||||
if not remaining then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
if remaining == limit - 1 then
|
||||
ok, err = dict:expire(key, window)
|
||||
if not ok then
|
||||
if err == "not found" then
|
||||
remaining, err = dict:incr(key, -1, limit)
|
||||
if not remaining then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
ok, err = dict:expire(key, window)
|
||||
if not ok then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
else
|
||||
return nil, err
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
else
|
||||
remaining = (dict:get(key) or limit) - 1
|
||||
end
|
||||
|
||||
if remaining < 0 then
|
||||
return nil, "rejected"
|
||||
end
|
||||
|
||||
return 0, remaining
|
||||
end
|
||||
|
||||
|
||||
-- uncommit remaining and return remaining value
|
||||
function _M.uncommit(self, key)
|
||||
assert(key)
|
||||
local dict = self.dict
|
||||
local limit = self.limit
|
||||
|
||||
local remaining, err = dict:incr(key, 1)
|
||||
if not remaining then
|
||||
if err == "not found" then
|
||||
remaining = limit
|
||||
else
|
||||
return nil, err
|
||||
end
|
||||
end
|
||||
|
||||
return remaining
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
153
resty/limit/req.lua
Normal file
153
resty/limit/req.lua
Normal file
@@ -0,0 +1,153 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
--
|
||||
-- This library is an approximate Lua port of the standard ngx_limit_req
|
||||
-- module.
|
||||
|
||||
|
||||
local ffi = require "ffi"
|
||||
local math = require "math"
|
||||
|
||||
|
||||
local ngx_shared = ngx.shared
|
||||
local ngx_now = ngx.now
|
||||
local setmetatable = setmetatable
|
||||
local ffi_cast = ffi.cast
|
||||
local ffi_str = ffi.string
|
||||
local abs = math.abs
|
||||
local tonumber = tonumber
|
||||
local type = type
|
||||
local assert = assert
|
||||
local max = math.max
|
||||
|
||||
|
||||
-- TODO: we could avoid the tricky FFI cdata when lua_shared_dict supports
|
||||
-- hash-typed values as in redis.
|
||||
ffi.cdef[[
|
||||
struct lua_resty_limit_req_rec {
|
||||
unsigned long excess;
|
||||
uint64_t last; /* time in milliseconds */
|
||||
/* integer value, 1 corresponds to 0.001 r/s */
|
||||
};
|
||||
]]
|
||||
local const_rec_ptr_type = ffi.typeof("const struct lua_resty_limit_req_rec*")
|
||||
local rec_size = ffi.sizeof("struct lua_resty_limit_req_rec")
|
||||
|
||||
-- we can share the cdata here since we only need it temporarily for
|
||||
-- serialization inside the shared dict:
|
||||
local rec_cdata = ffi.new("struct lua_resty_limit_req_rec")
|
||||
|
||||
|
||||
local _M = {
|
||||
_VERSION = '0.07'
|
||||
}
|
||||
|
||||
|
||||
local mt = {
|
||||
__index = _M
|
||||
}
|
||||
|
||||
|
||||
function _M.new(dict_name, rate, burst)
|
||||
local dict = ngx_shared[dict_name]
|
||||
if not dict then
|
||||
return nil, "shared dict not found"
|
||||
end
|
||||
|
||||
assert(rate > 0 and burst >= 0)
|
||||
|
||||
local self = {
|
||||
dict = dict,
|
||||
rate = rate * 1000,
|
||||
burst = burst * 1000,
|
||||
}
|
||||
|
||||
return setmetatable(self, mt)
|
||||
end
|
||||
|
||||
|
||||
-- sees an new incoming event
|
||||
-- the "commit" argument controls whether should we record the event in shm.
|
||||
-- FIXME we have a (small) race-condition window between dict:get() and
|
||||
-- dict:set() across multiple nginx worker processes. The size of the
|
||||
-- window is proportional to the number of workers.
|
||||
function _M.incoming(self, key, commit)
|
||||
local dict = self.dict
|
||||
local rate = self.rate
|
||||
local now = ngx_now() * 1000
|
||||
|
||||
local excess
|
||||
|
||||
-- it's important to anchor the string value for the read-only pointer
|
||||
-- cdata:
|
||||
local v = dict:get(key)
|
||||
if v then
|
||||
if type(v) ~= "string" or #v ~= rec_size then
|
||||
return nil, "shdict abused by other users"
|
||||
end
|
||||
local rec = ffi_cast(const_rec_ptr_type, v)
|
||||
local elapsed = now - tonumber(rec.last)
|
||||
|
||||
-- print("elapsed: ", elapsed, "ms")
|
||||
|
||||
-- we do not handle changing rate values specifically. the excess value
|
||||
-- can get automatically adjusted by the following formula with new rate
|
||||
-- values rather quickly anyway.
|
||||
excess = max(tonumber(rec.excess) - rate * abs(elapsed) / 1000 + 1000,
|
||||
0)
|
||||
|
||||
-- print("excess: ", excess)
|
||||
|
||||
if excess > self.burst then
|
||||
return nil, "rejected"
|
||||
end
|
||||
|
||||
else
|
||||
excess = 0
|
||||
end
|
||||
|
||||
if commit then
|
||||
rec_cdata.excess = excess
|
||||
rec_cdata.last = now
|
||||
dict:set(key, ffi_str(rec_cdata, rec_size))
|
||||
end
|
||||
|
||||
-- return the delay in seconds, as well as excess
|
||||
return excess / rate, excess / 1000
|
||||
end
|
||||
|
||||
|
||||
function _M.uncommit(self, key)
|
||||
assert(key)
|
||||
local dict = self.dict
|
||||
|
||||
local v = dict:get(key)
|
||||
if not v then
|
||||
return nil, "not found"
|
||||
end
|
||||
|
||||
if type(v) ~= "string" or #v ~= rec_size then
|
||||
return nil, "shdict abused by other users"
|
||||
end
|
||||
|
||||
local rec = ffi_cast(const_rec_ptr_type, v)
|
||||
|
||||
local excess = max(tonumber(rec.excess) - 1000, 0)
|
||||
|
||||
rec_cdata.excess = excess
|
||||
rec_cdata.last = rec.last
|
||||
dict:set(key, ffi_str(rec_cdata, rec_size))
|
||||
return true
|
||||
end
|
||||
|
||||
|
||||
function _M.set_rate(self, rate)
|
||||
self.rate = rate * 1000
|
||||
end
|
||||
|
||||
|
||||
function _M.set_burst(self, burst)
|
||||
self.burst = burst * 1000
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
58
resty/limit/traffic.lua
Normal file
58
resty/limit/traffic.lua
Normal file
@@ -0,0 +1,58 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
--
|
||||
-- This is an aggregator for various concrete traffic limiter instances
|
||||
-- (like instances of the resty.limit.req, resty.limit.count and
|
||||
-- resty.limit.conn classes).
|
||||
|
||||
|
||||
local max = math.max
|
||||
|
||||
|
||||
local _M = {
|
||||
_VERSION = '0.07'
|
||||
}
|
||||
|
||||
|
||||
-- the states table is user supplied. each element stores the 2nd return value
|
||||
-- of each limiter if there is no error returned. for resty.limit.req, the state
|
||||
-- is the "excess" value (i.e., the number of excessive requests each second),
|
||||
-- and for resty.limit.conn, the state is the current concurrency level
|
||||
-- (including the current new connection).
|
||||
function _M.combine(limiters, keys, states)
|
||||
local n = #limiters
|
||||
local max_delay = 0
|
||||
for i = 1, n do
|
||||
local lim = limiters[i]
|
||||
local delay, err = lim:incoming(keys[i], i == n)
|
||||
if not delay then
|
||||
return nil, err
|
||||
end
|
||||
if i == n then
|
||||
if states then
|
||||
states[i] = err
|
||||
end
|
||||
max_delay = delay
|
||||
end
|
||||
end
|
||||
for i = 1, n - 1 do
|
||||
local lim = limiters[i]
|
||||
local delay, err = lim:incoming(keys[i], true)
|
||||
if not delay then
|
||||
for j = 1, i - 1 do
|
||||
-- we intentionally ignore any errors returned below.
|
||||
limiters[j]:uncommit(keys[j])
|
||||
end
|
||||
limiters[n]:uncommit(keys[n])
|
||||
return nil, err
|
||||
end
|
||||
if states then
|
||||
states[i] = err
|
||||
end
|
||||
|
||||
max_delay = max(max_delay, delay)
|
||||
end
|
||||
return max_delay
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
221
resty/lock.lua
Normal file
221
resty/lock.lua
Normal file
@@ -0,0 +1,221 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
require "resty.core.shdict" -- enforce this to avoid dead locks
|
||||
|
||||
local ffi = require "ffi"
|
||||
local ffi_new = ffi.new
|
||||
local shared = ngx.shared
|
||||
local sleep = ngx.sleep
|
||||
local log = ngx.log
|
||||
local max = math.max
|
||||
local min = math.min
|
||||
local debug = ngx.config.debug
|
||||
local setmetatable = setmetatable
|
||||
local tonumber = tonumber
|
||||
|
||||
local _M = { _VERSION = '0.08' }
|
||||
local mt = { __index = _M }
|
||||
|
||||
local ERR = ngx.ERR
|
||||
local FREE_LIST_REF = 0
|
||||
|
||||
-- FIXME: we don't need this when we have __gc metamethod support on Lua
|
||||
-- tables.
|
||||
local memo = {}
|
||||
if debug then _M.memo = memo end
|
||||
|
||||
|
||||
local function ref_obj(key)
|
||||
if key == nil then
|
||||
return -1
|
||||
end
|
||||
local ref = memo[FREE_LIST_REF]
|
||||
if ref and ref ~= 0 then
|
||||
memo[FREE_LIST_REF] = memo[ref]
|
||||
|
||||
else
|
||||
ref = #memo + 1
|
||||
end
|
||||
memo[ref] = key
|
||||
|
||||
-- print("ref key_id returned ", ref)
|
||||
return ref
|
||||
end
|
||||
if debug then _M.ref_obj = ref_obj end
|
||||
|
||||
|
||||
local function unref_obj(ref)
|
||||
if ref >= 0 then
|
||||
memo[ref] = memo[FREE_LIST_REF]
|
||||
memo[FREE_LIST_REF] = ref
|
||||
end
|
||||
end
|
||||
if debug then _M.unref_obj = unref_obj end
|
||||
|
||||
|
||||
local function gc_lock(cdata)
|
||||
local dict_id = tonumber(cdata.dict_id)
|
||||
local key_id = tonumber(cdata.key_id)
|
||||
|
||||
-- print("key_id: ", key_id, ", key: ", memo[key_id], "dict: ",
|
||||
-- type(memo[cdata.dict_id]))
|
||||
if key_id > 0 then
|
||||
local key = memo[key_id]
|
||||
unref_obj(key_id)
|
||||
local dict = memo[dict_id]
|
||||
-- print("dict.delete type: ", type(dict.delete))
|
||||
local ok, err = dict:delete(key)
|
||||
if not ok then
|
||||
log(ERR, 'failed to delete key "', key, '": ', err)
|
||||
end
|
||||
cdata.key_id = 0
|
||||
end
|
||||
|
||||
unref_obj(dict_id)
|
||||
end
|
||||
|
||||
|
||||
local ctype = ffi.metatype("struct { int key_id; int dict_id; }",
|
||||
{ __gc = gc_lock })
|
||||
|
||||
|
||||
function _M.new(_, dict_name, opts)
|
||||
local dict = shared[dict_name]
|
||||
if not dict then
|
||||
return nil, "dictionary not found"
|
||||
end
|
||||
local cdata = ffi_new(ctype)
|
||||
cdata.key_id = 0
|
||||
cdata.dict_id = ref_obj(dict)
|
||||
|
||||
local timeout, exptime, step, ratio, max_step
|
||||
if opts then
|
||||
timeout = opts.timeout
|
||||
exptime = opts.exptime
|
||||
step = opts.step
|
||||
ratio = opts.ratio
|
||||
max_step = opts.max_step
|
||||
end
|
||||
|
||||
if not exptime then
|
||||
exptime = 30
|
||||
end
|
||||
|
||||
if timeout then
|
||||
timeout = min(timeout, exptime)
|
||||
|
||||
if step then
|
||||
step = min(step, timeout)
|
||||
end
|
||||
end
|
||||
|
||||
local self = {
|
||||
cdata = cdata,
|
||||
dict = dict,
|
||||
timeout = timeout or 5,
|
||||
exptime = exptime,
|
||||
step = step or 0.001,
|
||||
ratio = ratio or 2,
|
||||
max_step = max_step or 0.5,
|
||||
}
|
||||
return setmetatable(self, mt)
|
||||
end
|
||||
|
||||
|
||||
function _M.lock(self, key)
|
||||
if not key then
|
||||
return nil, "nil key"
|
||||
end
|
||||
|
||||
local dict = self.dict
|
||||
local cdata = self.cdata
|
||||
if cdata.key_id > 0 then
|
||||
return nil, "locked"
|
||||
end
|
||||
local exptime = self.exptime
|
||||
local ok, err = dict:add(key, true, exptime)
|
||||
if ok then
|
||||
cdata.key_id = ref_obj(key)
|
||||
self.key = key
|
||||
return 0
|
||||
end
|
||||
if err ~= "exists" then
|
||||
return nil, err
|
||||
end
|
||||
-- lock held by others
|
||||
local step = self.step
|
||||
local ratio = self.ratio
|
||||
local timeout = self.timeout
|
||||
local max_step = self.max_step
|
||||
local elapsed = 0
|
||||
while timeout > 0 do
|
||||
sleep(step)
|
||||
elapsed = elapsed + step
|
||||
timeout = timeout - step
|
||||
|
||||
local ok, err = dict:add(key, true, exptime)
|
||||
if ok then
|
||||
cdata.key_id = ref_obj(key)
|
||||
self.key = key
|
||||
return elapsed
|
||||
end
|
||||
|
||||
if err ~= "exists" then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
if timeout <= 0 then
|
||||
break
|
||||
end
|
||||
|
||||
step = min(max(0.001, step * ratio), timeout, max_step)
|
||||
end
|
||||
|
||||
return nil, "timeout"
|
||||
end
|
||||
|
||||
|
||||
function _M.unlock(self)
|
||||
local dict = self.dict
|
||||
local cdata = self.cdata
|
||||
local key_id = tonumber(cdata.key_id)
|
||||
if key_id <= 0 then
|
||||
return nil, "unlocked"
|
||||
end
|
||||
|
||||
local key = memo[key_id]
|
||||
unref_obj(key_id)
|
||||
|
||||
local ok, err = dict:delete(key)
|
||||
if not ok then
|
||||
return nil, err
|
||||
end
|
||||
cdata.key_id = 0
|
||||
|
||||
return 1
|
||||
end
|
||||
|
||||
|
||||
function _M.expire(self, time)
|
||||
local dict = self.dict
|
||||
local cdata = self.cdata
|
||||
local key_id = tonumber(cdata.key_id)
|
||||
if key_id <= 0 then
|
||||
return nil, "unlocked"
|
||||
end
|
||||
|
||||
if not time then
|
||||
time = self.exptime
|
||||
end
|
||||
|
||||
local ok, err = dict:replace(self.key, true, time)
|
||||
if not ok then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
340
resty/lrucache.lua
Normal file
340
resty/lrucache.lua
Normal file
@@ -0,0 +1,340 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
local ffi = require "ffi"
|
||||
local ffi_new = ffi.new
|
||||
local ffi_sizeof = ffi.sizeof
|
||||
local ffi_cast = ffi.cast
|
||||
local ffi_fill = ffi.fill
|
||||
local ngx_now = ngx.now
|
||||
local uintptr_t = ffi.typeof("uintptr_t")
|
||||
local setmetatable = setmetatable
|
||||
local tonumber = tonumber
|
||||
local type = type
|
||||
local new_tab
|
||||
do
|
||||
local ok
|
||||
ok, new_tab = pcall(require, "table.new")
|
||||
if not ok then
|
||||
new_tab = function(narr, nrec) return {} end
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
if string.find(jit.version, " 2.0", 1, true) then
|
||||
ngx.log(ngx.ALERT, "use of lua-resty-lrucache with LuaJIT 2.0 is ",
|
||||
"not recommended; use LuaJIT 2.1+ instead")
|
||||
end
|
||||
|
||||
|
||||
local ok, tb_clear = pcall(require, "table.clear")
|
||||
if not ok then
|
||||
local pairs = pairs
|
||||
tb_clear = function (tab)
|
||||
for k, _ in pairs(tab) do
|
||||
tab[k] = nil
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
-- queue data types
|
||||
--
|
||||
-- this queue is a double-ended queue and the first node
|
||||
-- is reserved for the queue itself.
|
||||
-- the implementation is mostly borrowed from nginx's ngx_queue_t data
|
||||
-- structure.
|
||||
|
||||
ffi.cdef[[
|
||||
typedef struct lrucache_queue_s lrucache_queue_t;
|
||||
struct lrucache_queue_s {
|
||||
double expire; /* in seconds */
|
||||
lrucache_queue_t *prev;
|
||||
lrucache_queue_t *next;
|
||||
uint32_t user_flags;
|
||||
};
|
||||
]]
|
||||
|
||||
local queue_arr_type = ffi.typeof("lrucache_queue_t[?]")
|
||||
local queue_type = ffi.typeof("lrucache_queue_t")
|
||||
local NULL = ffi.null
|
||||
|
||||
|
||||
-- queue utility functions
|
||||
|
||||
local function queue_insert_tail(h, x)
|
||||
local last = h[0].prev
|
||||
x.prev = last
|
||||
last.next = x
|
||||
x.next = h
|
||||
h[0].prev = x
|
||||
end
|
||||
|
||||
|
||||
local function queue_init(size)
|
||||
if not size then
|
||||
size = 0
|
||||
end
|
||||
local q = ffi_new(queue_arr_type, size + 1)
|
||||
ffi_fill(q, ffi_sizeof(queue_type, size + 1), 0)
|
||||
|
||||
if size == 0 then
|
||||
q[0].prev = q
|
||||
q[0].next = q
|
||||
|
||||
else
|
||||
local prev = q[0]
|
||||
for i = 1, size do
|
||||
local e = q + i
|
||||
e.user_flags = 0
|
||||
prev.next = e
|
||||
e.prev = prev
|
||||
prev = e
|
||||
end
|
||||
|
||||
local last = q[size]
|
||||
last.next = q
|
||||
q[0].prev = last
|
||||
end
|
||||
|
||||
return q
|
||||
end
|
||||
|
||||
|
||||
local function queue_is_empty(q)
|
||||
-- print("q: ", tostring(q), "q.prev: ", tostring(q), ": ", q == q.prev)
|
||||
return q == q[0].prev
|
||||
end
|
||||
|
||||
|
||||
local function queue_remove(x)
|
||||
local prev = x.prev
|
||||
local next = x.next
|
||||
|
||||
next.prev = prev
|
||||
prev.next = next
|
||||
|
||||
-- for debugging purpose only:
|
||||
x.prev = NULL
|
||||
x.next = NULL
|
||||
end
|
||||
|
||||
|
||||
local function queue_insert_head(h, x)
|
||||
x.next = h[0].next
|
||||
x.next.prev = x
|
||||
x.prev = h
|
||||
h[0].next = x
|
||||
end
|
||||
|
||||
|
||||
local function queue_last(h)
|
||||
return h[0].prev
|
||||
end
|
||||
|
||||
|
||||
local function queue_head(h)
|
||||
return h[0].next
|
||||
end
|
||||
|
||||
|
||||
-- true module stuffs
|
||||
|
||||
local _M = {
|
||||
_VERSION = '0.11'
|
||||
}
|
||||
local mt = { __index = _M }
|
||||
|
||||
|
||||
local function ptr2num(ptr)
|
||||
return tonumber(ffi_cast(uintptr_t, ptr))
|
||||
end
|
||||
|
||||
|
||||
function _M.new(size)
|
||||
if size < 1 then
|
||||
return nil, "size too small"
|
||||
end
|
||||
|
||||
local self = {
|
||||
hasht = {},
|
||||
free_queue = queue_init(size),
|
||||
cache_queue = queue_init(),
|
||||
key2node = {},
|
||||
node2key = {},
|
||||
num_items = 0,
|
||||
max_items = size,
|
||||
}
|
||||
return setmetatable(self, mt)
|
||||
end
|
||||
|
||||
|
||||
function _M.count(self)
|
||||
return self.num_items
|
||||
end
|
||||
|
||||
|
||||
function _M.capacity(self)
|
||||
return self.max_items
|
||||
end
|
||||
|
||||
|
||||
function _M.get(self, key)
|
||||
local hasht = self.hasht
|
||||
local val = hasht[key]
|
||||
if val == nil then
|
||||
return nil
|
||||
end
|
||||
|
||||
local node = self.key2node[key]
|
||||
|
||||
-- print(key, ": moving node ", tostring(node), " to cache queue head")
|
||||
local cache_queue = self.cache_queue
|
||||
queue_remove(node)
|
||||
queue_insert_head(cache_queue, node)
|
||||
|
||||
if node.expire >= 0 and node.expire < ngx_now() then
|
||||
-- print("expired: ", node.expire, " > ", ngx_now())
|
||||
return nil, val, node.user_flags
|
||||
end
|
||||
|
||||
return val, nil, node.user_flags
|
||||
end
|
||||
|
||||
|
||||
function _M.delete(self, key)
|
||||
self.hasht[key] = nil
|
||||
|
||||
local key2node = self.key2node
|
||||
local node = key2node[key]
|
||||
|
||||
if not node then
|
||||
return false
|
||||
end
|
||||
|
||||
key2node[key] = nil
|
||||
self.node2key[ptr2num(node)] = nil
|
||||
|
||||
queue_remove(node)
|
||||
queue_insert_tail(self.free_queue, node)
|
||||
self.num_items = self.num_items - 1
|
||||
return true
|
||||
end
|
||||
|
||||
|
||||
function _M.set(self, key, value, ttl, flags)
|
||||
local hasht = self.hasht
|
||||
hasht[key] = value
|
||||
|
||||
local key2node = self.key2node
|
||||
local node = key2node[key]
|
||||
if not node then
|
||||
local free_queue = self.free_queue
|
||||
local node2key = self.node2key
|
||||
|
||||
if queue_is_empty(free_queue) then
|
||||
-- evict the least recently used key
|
||||
-- assert(not queue_is_empty(self.cache_queue))
|
||||
node = queue_last(self.cache_queue)
|
||||
|
||||
local oldkey = node2key[ptr2num(node)]
|
||||
-- print(key, ": evicting oldkey: ", oldkey, ", oldnode: ",
|
||||
-- tostring(node))
|
||||
if oldkey then
|
||||
hasht[oldkey] = nil
|
||||
key2node[oldkey] = nil
|
||||
end
|
||||
|
||||
else
|
||||
-- take a free queue node
|
||||
node = queue_head(free_queue)
|
||||
-- only add count if we are not evicting
|
||||
self.num_items = self.num_items + 1
|
||||
-- print(key, ": get a new free node: ", tostring(node))
|
||||
end
|
||||
|
||||
node2key[ptr2num(node)] = key
|
||||
key2node[key] = node
|
||||
end
|
||||
|
||||
queue_remove(node)
|
||||
queue_insert_head(self.cache_queue, node)
|
||||
|
||||
if ttl then
|
||||
node.expire = ngx_now() + ttl
|
||||
else
|
||||
node.expire = -1
|
||||
end
|
||||
|
||||
if type(flags) == "number" and flags >= 0 then
|
||||
node.user_flags = flags
|
||||
|
||||
else
|
||||
node.user_flags = 0
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
function _M.get_keys(self, max_count, res)
|
||||
if not max_count or max_count == 0 then
|
||||
max_count = self.num_items
|
||||
end
|
||||
|
||||
if not res then
|
||||
res = new_tab(max_count + 1, 0) -- + 1 for trailing hole
|
||||
end
|
||||
|
||||
local cache_queue = self.cache_queue
|
||||
local node2key = self.node2key
|
||||
|
||||
local i = 0
|
||||
local node = queue_head(cache_queue)
|
||||
|
||||
while node ~= cache_queue do
|
||||
if i >= max_count then
|
||||
break
|
||||
end
|
||||
|
||||
i = i + 1
|
||||
res[i] = node2key[ptr2num(node)]
|
||||
node = node.next
|
||||
end
|
||||
|
||||
res[i + 1] = nil
|
||||
|
||||
return res
|
||||
end
|
||||
|
||||
|
||||
function _M.flush_all(self)
|
||||
tb_clear(self.hasht)
|
||||
tb_clear(self.node2key)
|
||||
tb_clear(self.key2node)
|
||||
|
||||
self.num_items = 0
|
||||
|
||||
local cache_queue = self.cache_queue
|
||||
local free_queue = self.free_queue
|
||||
|
||||
-- splice the cache_queue into free_queue
|
||||
if not queue_is_empty(cache_queue) then
|
||||
local free_head = free_queue[0]
|
||||
local free_last = free_head.prev
|
||||
|
||||
local cache_head = cache_queue[0]
|
||||
local cache_first = cache_head.next
|
||||
local cache_last = cache_head.prev
|
||||
|
||||
free_last.next = cache_first
|
||||
cache_first.prev = free_last
|
||||
|
||||
cache_last.next = free_head
|
||||
free_head.prev = cache_last
|
||||
|
||||
cache_head.next = cache_queue
|
||||
cache_head.prev = cache_queue
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
606
resty/lrucache/pureffi.lua
Normal file
606
resty/lrucache/pureffi.lua
Normal file
@@ -0,0 +1,606 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
-- Copyright (C) Shuxin Yang
|
||||
|
||||
--[[
|
||||
This module implements a key/value cache store. We adopt LRU as our
|
||||
replace/evict policy. Each key/value pair is tagged with a Time-to-Live (TTL);
|
||||
from user's perspective, stale pairs are automatically removed from the cache.
|
||||
|
||||
Why FFI
|
||||
-------
|
||||
In Lua, expression "table[key] = nil" does not *PHYSICALLY* remove the value
|
||||
associated with the key; it just set the value to be nil! So the table will
|
||||
keep growing with large number of the key/nil pairs which will be purged until
|
||||
resize() operator is called.
|
||||
|
||||
This "feature" is terribly ill-suited to what we need. Therefore we have to
|
||||
rely on FFI to build a hash-table where any entry can be physically deleted
|
||||
immediately.
|
||||
|
||||
Under the hood:
|
||||
--------------
|
||||
In concept, we introduce three data structures to implement the cache store:
|
||||
1. key/value vector for storing keys and values.
|
||||
2. a queue to mimic the LRU.
|
||||
3. hash-table for looking up the value for a given key.
|
||||
|
||||
Unfortunately, efficiency and clarity usually come at each other cost. The
|
||||
data strucutres we are using are slightly more complicated than what we
|
||||
described above.
|
||||
|
||||
o. Lua does not have efficient way to store a vector of pair. So, we use
|
||||
two vectors for key/value pair: one for keys and the other for values
|
||||
(_M.key_v and _M.val_v, respectively), and i-th key corresponds to
|
||||
i-th value.
|
||||
|
||||
A key/value pair is identified by the "id" field in a "node" (we shall
|
||||
discuss node later)
|
||||
|
||||
o. The queue is nothing more than a doubly-linked list of "node" linked via
|
||||
lrucache_pureffi_queue_s::{next|prev} fields.
|
||||
|
||||
o. The hash-table has two parts:
|
||||
- the _M.bucket_v[] a vector of bucket, indiced by hash-value, and
|
||||
- a bucket is a singly-linked list of "node" via the
|
||||
lrucache_pureffi_queue_s::conflict field.
|
||||
|
||||
A key must be a string, and the hash value of a key is evaluated by:
|
||||
crc32(key-cast-to-pointer) % size(_M.bucket_v).
|
||||
We mandate size(_M.bucket_v) being a power-of-two in order to avoid
|
||||
expensive modulo operation.
|
||||
|
||||
At the heart of the module is an array of "node" (of type
|
||||
lrucache_pureffi_queue_s). A node:
|
||||
- keeps the meta-data of its corresponding key/value pair
|
||||
(embodied by the "id", and "expire" field);
|
||||
- is a part of LRU queue (embodied by "prev" and "next" fields);
|
||||
- is a part of hash-table (embodied by the "conflict" field).
|
||||
]]
|
||||
|
||||
local ffi = require "ffi"
|
||||
local bit = require "bit"
|
||||
|
||||
|
||||
local ffi_new = ffi.new
|
||||
local ffi_sizeof = ffi.sizeof
|
||||
local ffi_cast = ffi.cast
|
||||
local ffi_fill = ffi.fill
|
||||
local ngx_now = ngx.now
|
||||
local uintptr_t = ffi.typeof("uintptr_t")
|
||||
local c_str_t = ffi.typeof("const char*")
|
||||
local int_t = ffi.typeof("int")
|
||||
local int_array_t = ffi.typeof("int[?]")
|
||||
|
||||
|
||||
local crc_tab = ffi.new("const unsigned int[256]", {
|
||||
0x00000000, 0x77073096, 0xEE0E612C, 0x990951BA, 0x076DC419, 0x706AF48F,
|
||||
0xE963A535, 0x9E6495A3, 0x0EDB8832, 0x79DCB8A4, 0xE0D5E91E, 0x97D2D988,
|
||||
0x09B64C2B, 0x7EB17CBD, 0xE7B82D07, 0x90BF1D91, 0x1DB71064, 0x6AB020F2,
|
||||
0xF3B97148, 0x84BE41DE, 0x1ADAD47D, 0x6DDDE4EB, 0xF4D4B551, 0x83D385C7,
|
||||
0x136C9856, 0x646BA8C0, 0xFD62F97A, 0x8A65C9EC, 0x14015C4F, 0x63066CD9,
|
||||
0xFA0F3D63, 0x8D080DF5, 0x3B6E20C8, 0x4C69105E, 0xD56041E4, 0xA2677172,
|
||||
0x3C03E4D1, 0x4B04D447, 0xD20D85FD, 0xA50AB56B, 0x35B5A8FA, 0x42B2986C,
|
||||
0xDBBBC9D6, 0xACBCF940, 0x32D86CE3, 0x45DF5C75, 0xDCD60DCF, 0xABD13D59,
|
||||
0x26D930AC, 0x51DE003A, 0xC8D75180, 0xBFD06116, 0x21B4F4B5, 0x56B3C423,
|
||||
0xCFBA9599, 0xB8BDA50F, 0x2802B89E, 0x5F058808, 0xC60CD9B2, 0xB10BE924,
|
||||
0x2F6F7C87, 0x58684C11, 0xC1611DAB, 0xB6662D3D, 0x76DC4190, 0x01DB7106,
|
||||
0x98D220BC, 0xEFD5102A, 0x71B18589, 0x06B6B51F, 0x9FBFE4A5, 0xE8B8D433,
|
||||
0x7807C9A2, 0x0F00F934, 0x9609A88E, 0xE10E9818, 0x7F6A0DBB, 0x086D3D2D,
|
||||
0x91646C97, 0xE6635C01, 0x6B6B51F4, 0x1C6C6162, 0x856530D8, 0xF262004E,
|
||||
0x6C0695ED, 0x1B01A57B, 0x8208F4C1, 0xF50FC457, 0x65B0D9C6, 0x12B7E950,
|
||||
0x8BBEB8EA, 0xFCB9887C, 0x62DD1DDF, 0x15DA2D49, 0x8CD37CF3, 0xFBD44C65,
|
||||
0x4DB26158, 0x3AB551CE, 0xA3BC0074, 0xD4BB30E2, 0x4ADFA541, 0x3DD895D7,
|
||||
0xA4D1C46D, 0xD3D6F4FB, 0x4369E96A, 0x346ED9FC, 0xAD678846, 0xDA60B8D0,
|
||||
0x44042D73, 0x33031DE5, 0xAA0A4C5F, 0xDD0D7CC9, 0x5005713C, 0x270241AA,
|
||||
0xBE0B1010, 0xC90C2086, 0x5768B525, 0x206F85B3, 0xB966D409, 0xCE61E49F,
|
||||
0x5EDEF90E, 0x29D9C998, 0xB0D09822, 0xC7D7A8B4, 0x59B33D17, 0x2EB40D81,
|
||||
0xB7BD5C3B, 0xC0BA6CAD, 0xEDB88320, 0x9ABFB3B6, 0x03B6E20C, 0x74B1D29A,
|
||||
0xEAD54739, 0x9DD277AF, 0x04DB2615, 0x73DC1683, 0xE3630B12, 0x94643B84,
|
||||
0x0D6D6A3E, 0x7A6A5AA8, 0xE40ECF0B, 0x9309FF9D, 0x0A00AE27, 0x7D079EB1,
|
||||
0xF00F9344, 0x8708A3D2, 0x1E01F268, 0x6906C2FE, 0xF762575D, 0x806567CB,
|
||||
0x196C3671, 0x6E6B06E7, 0xFED41B76, 0x89D32BE0, 0x10DA7A5A, 0x67DD4ACC,
|
||||
0xF9B9DF6F, 0x8EBEEFF9, 0x17B7BE43, 0x60B08ED5, 0xD6D6A3E8, 0xA1D1937E,
|
||||
0x38D8C2C4, 0x4FDFF252, 0xD1BB67F1, 0xA6BC5767, 0x3FB506DD, 0x48B2364B,
|
||||
0xD80D2BDA, 0xAF0A1B4C, 0x36034AF6, 0x41047A60, 0xDF60EFC3, 0xA867DF55,
|
||||
0x316E8EEF, 0x4669BE79, 0xCB61B38C, 0xBC66831A, 0x256FD2A0, 0x5268E236,
|
||||
0xCC0C7795, 0xBB0B4703, 0x220216B9, 0x5505262F, 0xC5BA3BBE, 0xB2BD0B28,
|
||||
0x2BB45A92, 0x5CB36A04, 0xC2D7FFA7, 0xB5D0CF31, 0x2CD99E8B, 0x5BDEAE1D,
|
||||
0x9B64C2B0, 0xEC63F226, 0x756AA39C, 0x026D930A, 0x9C0906A9, 0xEB0E363F,
|
||||
0x72076785, 0x05005713, 0x95BF4A82, 0xE2B87A14, 0x7BB12BAE, 0x0CB61B38,
|
||||
0x92D28E9B, 0xE5D5BE0D, 0x7CDCEFB7, 0x0BDBDF21, 0x86D3D2D4, 0xF1D4E242,
|
||||
0x68DDB3F8, 0x1FDA836E, 0x81BE16CD, 0xF6B9265B, 0x6FB077E1, 0x18B74777,
|
||||
0x88085AE6, 0xFF0F6A70, 0x66063BCA, 0x11010B5C, 0x8F659EFF, 0xF862AE69,
|
||||
0x616BFFD3, 0x166CCF45, 0xA00AE278, 0xD70DD2EE, 0x4E048354, 0x3903B3C2,
|
||||
0xA7672661, 0xD06016F7, 0x4969474D, 0x3E6E77DB, 0xAED16A4A, 0xD9D65ADC,
|
||||
0x40DF0B66, 0x37D83BF0, 0xA9BCAE53, 0xDEBB9EC5, 0x47B2CF7F, 0x30B5FFE9,
|
||||
0xBDBDF21C, 0xCABAC28A, 0x53B39330, 0x24B4A3A6, 0xBAD03605, 0xCDD70693,
|
||||
0x54DE5729, 0x23D967BF, 0xB3667A2E, 0xC4614AB8, 0x5D681B02, 0x2A6F2B94,
|
||||
0xB40BBE37, 0xC30C8EA1, 0x5A05DF1B, 0x2D02EF8D });
|
||||
|
||||
local setmetatable = setmetatable
|
||||
local tonumber = tonumber
|
||||
local tostring = tostring
|
||||
local type = type
|
||||
|
||||
local brshift = bit.rshift
|
||||
local bxor = bit.bxor
|
||||
local band = bit.band
|
||||
|
||||
local new_tab
|
||||
do
|
||||
local ok
|
||||
ok, new_tab = pcall(require, "table.new")
|
||||
if not ok then
|
||||
new_tab = function(narr, nrec) return {} end
|
||||
end
|
||||
end
|
||||
|
||||
-- queue data types
|
||||
--
|
||||
-- this queue is a double-ended queue and the first node
|
||||
-- is reserved for the queue itself.
|
||||
-- the implementation is mostly borrowed from nginx's ngx_queue_t data
|
||||
-- structure.
|
||||
|
||||
ffi.cdef[[
|
||||
/* A lrucache_pureffi_queue_s node hook together three data structures:
|
||||
* o. the key/value store as embodied by the "id" (which is in essence the
|
||||
* indentifier of key/pair pair) and the "expire" (which is a metadata
|
||||
* of the corresponding key/pair pair).
|
||||
* o. The LRU queue via the prev/next fields.
|
||||
* o. The hash-tabble as embodied by the "conflict" field.
|
||||
*/
|
||||
typedef struct lrucache_pureffi_queue_s lrucache_pureffi_queue_t;
|
||||
struct lrucache_pureffi_queue_s {
|
||||
/* Each node is assigned a unique ID at construction time, and the
|
||||
* ID remain immutatble, regardless the node is in active-list or
|
||||
* free-list. The queue header is assigned ID 0. Since queue-header
|
||||
* is a sentinel node, 0 denodes "invalid ID".
|
||||
*
|
||||
* Intuitively, we can view the "id" as the identifier of key/value
|
||||
* pair.
|
||||
*/
|
||||
int id;
|
||||
|
||||
/* The bucket of the hash-table is implemented as a singly-linked list.
|
||||
* The "conflict" refers to the ID of the next node in the bucket.
|
||||
*/
|
||||
int conflict;
|
||||
|
||||
uint32_t user_flags;
|
||||
|
||||
double expire; /* in seconds */
|
||||
|
||||
lrucache_pureffi_queue_t *prev;
|
||||
lrucache_pureffi_queue_t *next;
|
||||
};
|
||||
]]
|
||||
|
||||
local queue_arr_type = ffi.typeof("lrucache_pureffi_queue_t[?]")
|
||||
--local queue_ptr_type = ffi.typeof("lrucache_pureffi_queue_t*")
|
||||
local queue_type = ffi.typeof("lrucache_pureffi_queue_t")
|
||||
local NULL = ffi.null
|
||||
|
||||
|
||||
--========================================================================
|
||||
--
|
||||
-- Queue utility functions
|
||||
--
|
||||
--========================================================================
|
||||
|
||||
-- Append the element "x" to the given queue "h".
|
||||
local function queue_insert_tail(h, x)
|
||||
local last = h[0].prev
|
||||
x.prev = last
|
||||
last.next = x
|
||||
x.next = h
|
||||
h[0].prev = x
|
||||
end
|
||||
|
||||
|
||||
--[[
|
||||
Allocate a queue with size + 1 elements. Elements are linked together in a
|
||||
circular way, i.e. the last element's "next" points to the first element,
|
||||
while the first element's "prev" element points to the last element.
|
||||
]]
|
||||
local function queue_init(size)
|
||||
if not size then
|
||||
size = 0
|
||||
end
|
||||
local q = ffi_new(queue_arr_type, size + 1)
|
||||
ffi_fill(q, ffi_sizeof(queue_type, size + 1), 0)
|
||||
|
||||
if size == 0 then
|
||||
q[0].prev = q
|
||||
q[0].next = q
|
||||
|
||||
else
|
||||
local prev = q[0]
|
||||
for i = 1, size do
|
||||
local e = q[i]
|
||||
e.id = i
|
||||
e.user_flags = 0
|
||||
prev.next = e
|
||||
e.prev = prev
|
||||
prev = e
|
||||
end
|
||||
|
||||
local last = q[size]
|
||||
last.next = q
|
||||
q[0].prev = last
|
||||
end
|
||||
|
||||
return q
|
||||
end
|
||||
|
||||
|
||||
local function queue_is_empty(q)
|
||||
-- print("q: ", tostring(q), "q.prev: ", tostring(q), ": ", q == q.prev)
|
||||
return q == q[0].prev
|
||||
end
|
||||
|
||||
|
||||
local function queue_remove(x)
|
||||
local prev = x.prev
|
||||
local next = x.next
|
||||
|
||||
next.prev = prev
|
||||
prev.next = next
|
||||
|
||||
-- for debugging purpose only:
|
||||
x.prev = NULL
|
||||
x.next = NULL
|
||||
end
|
||||
|
||||
|
||||
-- Insert the element "x" the to the given queue "h"
|
||||
local function queue_insert_head(h, x)
|
||||
x.next = h[0].next
|
||||
x.next.prev = x
|
||||
x.prev = h
|
||||
h[0].next = x
|
||||
end
|
||||
|
||||
|
||||
local function queue_last(h)
|
||||
return h[0].prev
|
||||
end
|
||||
|
||||
|
||||
local function queue_head(h)
|
||||
return h[0].next
|
||||
end
|
||||
|
||||
|
||||
--========================================================================
|
||||
--
|
||||
-- Miscellaneous Utility Functions
|
||||
--
|
||||
--========================================================================
|
||||
|
||||
local function ptr2num(ptr)
|
||||
return tonumber(ffi_cast(uintptr_t, ptr))
|
||||
end
|
||||
|
||||
|
||||
local function crc32_ptr(ptr)
|
||||
local p = brshift(ptr2num(ptr), 3)
|
||||
local b = band(p, 255)
|
||||
local crc32 = crc_tab[b]
|
||||
|
||||
b = band(brshift(p, 8), 255)
|
||||
crc32 = bxor(brshift(crc32, 8), crc_tab[band(bxor(crc32, b), 255)])
|
||||
|
||||
b = band(brshift(p, 16), 255)
|
||||
crc32 = bxor(brshift(crc32, 8), crc_tab[band(bxor(crc32, b), 255)])
|
||||
|
||||
--b = band(brshift(p, 24), 255)
|
||||
--crc32 = bxor(brshift(crc32, 8), crc_tab[band(bxor(crc32, b), 255)])
|
||||
return crc32
|
||||
end
|
||||
|
||||
|
||||
--========================================================================
|
||||
--
|
||||
-- Implementation of "export" functions
|
||||
--
|
||||
--========================================================================
|
||||
|
||||
local _M = {
|
||||
_VERSION = '0.11'
|
||||
}
|
||||
local mt = { __index = _M }
|
||||
|
||||
|
||||
-- "size" specifies the maximum number of entries in the LRU queue, and the
|
||||
-- "load_factor" designates the 'load factor' of the hash-table we are using
|
||||
-- internally. The default value of load-factor is 0.5 (i.e. 50%); if the
|
||||
-- load-factor is specified, it will be clamped to the range of [0.1, 1](i.e.
|
||||
-- if load-factor is greater than 1, it will be saturated to 1, likewise,
|
||||
-- if load-factor is smaller than 0.1, it will be clamped to 0.1).
|
||||
function _M.new(size, load_factor)
|
||||
if size < 1 then
|
||||
return nil, "size too small"
|
||||
end
|
||||
|
||||
-- Determine bucket size, which must be power of two.
|
||||
local load_f = load_factor
|
||||
if not load_factor then
|
||||
load_f = 0.5
|
||||
elseif load_factor > 1 then
|
||||
load_f = 1
|
||||
elseif load_factor < 0.1 then
|
||||
load_f = 0.1
|
||||
end
|
||||
|
||||
local bs_min = size / load_f
|
||||
-- The bucket_sz *MUST* be a power-of-two. See the hash_string().
|
||||
local bucket_sz = 1
|
||||
repeat
|
||||
bucket_sz = bucket_sz * 2
|
||||
until bucket_sz >= bs_min
|
||||
|
||||
local self = {
|
||||
size = size,
|
||||
bucket_sz = bucket_sz,
|
||||
free_queue = queue_init(size),
|
||||
cache_queue = queue_init(0),
|
||||
node_v = nil,
|
||||
key_v = new_tab(size, 0),
|
||||
val_v = new_tab(size, 0),
|
||||
bucket_v = ffi_new(int_array_t, bucket_sz),
|
||||
num_items = 0,
|
||||
}
|
||||
-- "node_v" is an array of all the nodes used in the LRU queue. Exprpession
|
||||
-- node_v[i] evaluates to the element of ID "i".
|
||||
self.node_v = self.free_queue
|
||||
|
||||
-- Allocate the array-part of the key_v, val_v, bucket_v.
|
||||
--local key_v = self.key_v
|
||||
--local val_v = self.val_v
|
||||
--local bucket_v = self.bucket_v
|
||||
ffi_fill(self.bucket_v, ffi_sizeof(int_t, bucket_sz), 0)
|
||||
|
||||
return setmetatable(self, mt)
|
||||
end
|
||||
|
||||
|
||||
function _M.count(self)
|
||||
return self.num_items
|
||||
end
|
||||
|
||||
|
||||
function _M.capacity(self)
|
||||
return self.size
|
||||
end
|
||||
|
||||
|
||||
local function hash_string(self, str)
|
||||
local c_str = ffi_cast(c_str_t, str)
|
||||
|
||||
local hv = crc32_ptr(c_str)
|
||||
hv = band(hv, self.bucket_sz - 1)
|
||||
-- Hint: bucket is 0-based
|
||||
return hv
|
||||
end
|
||||
|
||||
|
||||
-- Search the node associated with the key in the bucket, if found returns
|
||||
-- the the id of the node, and the id of its previous node in the conflict list.
|
||||
-- The "bucket_hdr_id" is the ID of the first node in the bucket
|
||||
local function _find_node_in_bucket(key, key_v, node_v, bucket_hdr_id)
|
||||
if bucket_hdr_id ~= 0 then
|
||||
local prev = 0
|
||||
local cur = bucket_hdr_id
|
||||
|
||||
while cur ~= 0 and key_v[cur] ~= key do
|
||||
prev = cur
|
||||
cur = node_v[cur].conflict
|
||||
end
|
||||
|
||||
if cur ~= 0 then
|
||||
return cur, prev
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
-- Return the node corresponding to the key/val.
|
||||
local function find_key(self, key)
|
||||
local key_hash = hash_string(self, key)
|
||||
return _find_node_in_bucket(key, self.key_v, self.node_v,
|
||||
self.bucket_v[key_hash])
|
||||
end
|
||||
|
||||
|
||||
--[[ This function tries to
|
||||
1. Remove the given key and the associated value from the key/value store,
|
||||
2. Remove the entry associated with the key from the hash-table.
|
||||
|
||||
NOTE: all queues remain intact.
|
||||
|
||||
If there was a node bound to the key/val, return that node; otherwise,
|
||||
nil is returned.
|
||||
]]
|
||||
local function remove_key(self, key)
|
||||
local key_v = self.key_v
|
||||
local val_v = self.val_v
|
||||
local node_v = self.node_v
|
||||
local bucket_v = self.bucket_v
|
||||
|
||||
local key_hash = hash_string(self, key)
|
||||
local cur, prev =
|
||||
_find_node_in_bucket(key, key_v, node_v, bucket_v[key_hash])
|
||||
|
||||
if cur then
|
||||
-- In an attempt to make key and val dead.
|
||||
key_v[cur] = nil
|
||||
val_v[cur] = nil
|
||||
self.num_items = self.num_items - 1
|
||||
|
||||
-- Remove the node from the hash table
|
||||
local next_node = node_v[cur].conflict
|
||||
if prev ~= 0 then
|
||||
node_v[prev].conflict = next_node
|
||||
else
|
||||
bucket_v[key_hash] = next_node
|
||||
end
|
||||
node_v[cur].conflict = 0
|
||||
|
||||
return cur
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
--[[ Bind the key/val with the given node, and insert the node into the Hashtab.
|
||||
NOTE: this function does not touch any queue
|
||||
]]
|
||||
local function insert_key(self, key, val, node)
|
||||
-- Bind the key/val with the node
|
||||
local node_id = node.id
|
||||
self.key_v[node_id] = key
|
||||
self.val_v[node_id] = val
|
||||
|
||||
-- Insert the node into the hash-table
|
||||
local key_hash = hash_string(self, key)
|
||||
local bucket_v = self.bucket_v
|
||||
node.conflict = bucket_v[key_hash]
|
||||
bucket_v[key_hash] = node_id
|
||||
self.num_items = self.num_items + 1
|
||||
end
|
||||
|
||||
|
||||
function _M.get(self, key)
|
||||
if type(key) ~= "string" then
|
||||
key = tostring(key)
|
||||
end
|
||||
|
||||
local node_id = find_key(self, key)
|
||||
if not node_id then
|
||||
return nil
|
||||
end
|
||||
|
||||
-- print(key, ": moving node ", tostring(node), " to cache queue head")
|
||||
local cache_queue = self.cache_queue
|
||||
local node = self.node_v + node_id
|
||||
queue_remove(node)
|
||||
queue_insert_head(cache_queue, node)
|
||||
|
||||
local expire = node.expire
|
||||
if expire >= 0 and expire < ngx_now() then
|
||||
-- print("expired: ", node.expire, " > ", ngx_now())
|
||||
return nil, self.val_v[node_id], node.user_flags
|
||||
end
|
||||
|
||||
return self.val_v[node_id], nil, node.user_flags
|
||||
end
|
||||
|
||||
|
||||
function _M.delete(self, key)
|
||||
if type(key) ~= "string" then
|
||||
key = tostring(key)
|
||||
end
|
||||
|
||||
local node_id = remove_key(self, key);
|
||||
if not node_id then
|
||||
return false
|
||||
end
|
||||
|
||||
local node = self.node_v + node_id
|
||||
queue_remove(node)
|
||||
queue_insert_tail(self.free_queue, node)
|
||||
return true
|
||||
end
|
||||
|
||||
|
||||
function _M.set(self, key, value, ttl, flags)
|
||||
if type(key) ~= "string" then
|
||||
key = tostring(key)
|
||||
end
|
||||
|
||||
local node_id = find_key(self, key)
|
||||
local node
|
||||
if not node_id then
|
||||
local free_queue = self.free_queue
|
||||
if queue_is_empty(free_queue) then
|
||||
-- evict the least recently used key
|
||||
-- assert(not queue_is_empty(self.cache_queue))
|
||||
node = queue_last(self.cache_queue)
|
||||
remove_key(self, self.key_v[node.id])
|
||||
else
|
||||
-- take a free queue node
|
||||
node = queue_head(free_queue)
|
||||
-- print(key, ": get a new free node: ", tostring(node))
|
||||
end
|
||||
|
||||
-- insert the key
|
||||
insert_key(self, key, value, node)
|
||||
else
|
||||
node = self.node_v + node_id
|
||||
self.val_v[node_id] = value
|
||||
end
|
||||
|
||||
queue_remove(node)
|
||||
queue_insert_head(self.cache_queue, node)
|
||||
|
||||
if ttl then
|
||||
node.expire = ngx_now() + ttl
|
||||
else
|
||||
node.expire = -1
|
||||
end
|
||||
|
||||
if type(flags) == "number" and flags >= 0 then
|
||||
node.user_flags = flags
|
||||
|
||||
else
|
||||
node.user_flags = 0
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
function _M.get_keys(self, max_count, res)
|
||||
if not max_count or max_count == 0 then
|
||||
max_count = self.num_items
|
||||
end
|
||||
|
||||
if not res then
|
||||
res = new_tab(max_count + 1, 0) -- + 1 for trailing hole
|
||||
end
|
||||
|
||||
local cache_queue = self.cache_queue
|
||||
local key_v = self.key_v
|
||||
|
||||
local i = 0
|
||||
local node = queue_head(cache_queue)
|
||||
|
||||
while node ~= cache_queue do
|
||||
if i >= max_count then
|
||||
break
|
||||
end
|
||||
|
||||
i = i + 1
|
||||
res[i] = key_v[node.id]
|
||||
node = node.next
|
||||
end
|
||||
|
||||
res[i + 1] = nil
|
||||
|
||||
return res
|
||||
end
|
||||
|
||||
|
||||
function _M.flush_all(self)
|
||||
local cache_queue = self.cache_queue
|
||||
local key_v = self.key_v
|
||||
|
||||
local node = queue_head(cache_queue)
|
||||
|
||||
while node ~= cache_queue do
|
||||
local key = key_v[node.id]
|
||||
node = node.next
|
||||
_M.delete(self, key)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
72
resty/md5.lua
Normal file
72
resty/md5.lua
Normal file
@@ -0,0 +1,72 @@
|
||||
-- Copyright (C) by Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
local ffi = require "ffi"
|
||||
local ffi_new = ffi.new
|
||||
local ffi_str = ffi.string
|
||||
local C = ffi.C
|
||||
local setmetatable = setmetatable
|
||||
--local error = error
|
||||
|
||||
|
||||
local _M = { _VERSION = '0.14' }
|
||||
|
||||
local mt = { __index = _M }
|
||||
|
||||
|
||||
ffi.cdef[[
|
||||
typedef unsigned long MD5_LONG ;
|
||||
|
||||
enum {
|
||||
MD5_CBLOCK = 64,
|
||||
MD5_LBLOCK = MD5_CBLOCK/4
|
||||
};
|
||||
|
||||
typedef struct MD5state_st
|
||||
{
|
||||
MD5_LONG A,B,C,D;
|
||||
MD5_LONG Nl,Nh;
|
||||
MD5_LONG data[MD5_LBLOCK];
|
||||
unsigned int num;
|
||||
} MD5_CTX;
|
||||
|
||||
int MD5_Init(MD5_CTX *c);
|
||||
int MD5_Update(MD5_CTX *c, const void *data, size_t len);
|
||||
int MD5_Final(unsigned char *md, MD5_CTX *c);
|
||||
]]
|
||||
|
||||
local buf = ffi_new("char[16]")
|
||||
local ctx_ptr_type = ffi.typeof("MD5_CTX[1]")
|
||||
|
||||
|
||||
function _M.new(self)
|
||||
local ctx = ffi_new(ctx_ptr_type)
|
||||
if C.MD5_Init(ctx) == 0 then
|
||||
return nil
|
||||
end
|
||||
|
||||
return setmetatable({ _ctx = ctx }, mt)
|
||||
end
|
||||
|
||||
|
||||
function _M.update(self, s)
|
||||
return C.MD5_Update(self._ctx, s, #s) == 1
|
||||
end
|
||||
|
||||
|
||||
function _M.final(self)
|
||||
if C.MD5_Final(buf, self._ctx) == 1 then
|
||||
return ffi_str(buf, 16)
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
|
||||
function _M.reset(self)
|
||||
return C.MD5_Init(self._ctx) == 1
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
|
||||
744
resty/memcached.lua
Normal file
744
resty/memcached.lua
Normal file
@@ -0,0 +1,744 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh), CloudFlare Inc.
|
||||
|
||||
|
||||
local escape_uri = ngx.escape_uri
|
||||
local unescape_uri = ngx.unescape_uri
|
||||
local match = string.match
|
||||
local tcp = ngx.socket.tcp
|
||||
local strlen = string.len
|
||||
local concat = table.concat
|
||||
local setmetatable = setmetatable
|
||||
local type = type
|
||||
|
||||
|
||||
local _M = {
|
||||
_VERSION = '0.16'
|
||||
}
|
||||
|
||||
|
||||
local mt = { __index = _M }
|
||||
|
||||
|
||||
function _M.new(self, opts)
|
||||
local sock, err = tcp()
|
||||
if not sock then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
local escape_key = escape_uri
|
||||
local unescape_key = unescape_uri
|
||||
|
||||
if opts then
|
||||
local key_transform = opts.key_transform
|
||||
|
||||
if key_transform then
|
||||
escape_key = key_transform[1]
|
||||
unescape_key = key_transform[2]
|
||||
if not escape_key or not unescape_key then
|
||||
return nil, "expecting key_transform = { escape, unescape } table"
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
return setmetatable({
|
||||
sock = sock,
|
||||
escape_key = escape_key,
|
||||
unescape_key = unescape_key,
|
||||
}, mt)
|
||||
end
|
||||
|
||||
|
||||
local function set_timeouts(self, connect, send, read)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
sock:settimeouts(connect, send, read)
|
||||
return 1
|
||||
end
|
||||
_M.set_timeouts = set_timeouts
|
||||
|
||||
function _M.set_timeout(self, timeout)
|
||||
return set_timeouts(self, timeout, timeout, timeout)
|
||||
end
|
||||
|
||||
|
||||
function _M.connect(self, ...)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
return sock:connect(...)
|
||||
end
|
||||
|
||||
|
||||
local function _multi_get(self, keys)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
local nkeys = #keys
|
||||
|
||||
if nkeys == 0 then
|
||||
return {}, nil
|
||||
end
|
||||
|
||||
local escape_key = self.escape_key
|
||||
local cmd = {"get"}
|
||||
local n = 1
|
||||
|
||||
for i = 1, nkeys do
|
||||
cmd[n + 1] = " "
|
||||
cmd[n + 2] = escape_key(keys[i])
|
||||
n = n + 2
|
||||
end
|
||||
cmd[n + 1] = "\r\n"
|
||||
|
||||
-- print("multi get cmd: ", cmd)
|
||||
|
||||
local bytes, err = sock:send(cmd)
|
||||
if not bytes then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
local unescape_key = self.unescape_key
|
||||
local results = {}
|
||||
|
||||
while true do
|
||||
local line, err = sock:receive()
|
||||
if not line then
|
||||
if err == "timeout" then
|
||||
sock:close()
|
||||
end
|
||||
return nil, err
|
||||
end
|
||||
|
||||
if line == 'END' then
|
||||
break
|
||||
end
|
||||
|
||||
local key, flags, len = match(line, '^VALUE (%S+) (%d+) (%d+)$')
|
||||
-- print("key: ", key, "len: ", len, ", flags: ", flags)
|
||||
|
||||
if not key then
|
||||
return nil, line
|
||||
end
|
||||
|
||||
local data, err = sock:receive(len)
|
||||
if not data then
|
||||
if err == "timeout" then
|
||||
sock:close()
|
||||
end
|
||||
return nil, err
|
||||
end
|
||||
|
||||
results[unescape_key(key)] = {data, flags}
|
||||
|
||||
data, err = sock:receive(2) -- discard the trailing CRLF
|
||||
if not data then
|
||||
if err == "timeout" then
|
||||
sock:close()
|
||||
end
|
||||
return nil, err
|
||||
end
|
||||
end
|
||||
|
||||
return results
|
||||
end
|
||||
|
||||
|
||||
function _M.get(self, key)
|
||||
if type(key) == "table" then
|
||||
return _multi_get(self, key)
|
||||
end
|
||||
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, nil, "not initialized"
|
||||
end
|
||||
|
||||
local bytes, err = sock:send("get " .. self.escape_key(key) .. "\r\n")
|
||||
if not bytes then
|
||||
return nil, nil, err
|
||||
end
|
||||
|
||||
local line, err = sock:receive()
|
||||
if not line then
|
||||
if err == "timeout" then
|
||||
sock:close()
|
||||
end
|
||||
return nil, nil, err
|
||||
end
|
||||
|
||||
if line == 'END' then
|
||||
return nil, nil, nil
|
||||
end
|
||||
|
||||
local flags, len = match(line, '^VALUE %S+ (%d+) (%d+)$')
|
||||
if not flags then
|
||||
return nil, nil, line
|
||||
end
|
||||
|
||||
-- print("len: ", len, ", flags: ", flags)
|
||||
|
||||
local data, err = sock:receive(len)
|
||||
if not data then
|
||||
if err == "timeout" then
|
||||
sock:close()
|
||||
end
|
||||
return nil, nil, err
|
||||
end
|
||||
|
||||
line, err = sock:receive(7) -- discard the trailing "\r\nEND\r\n"
|
||||
if not line then
|
||||
if err == "timeout" then
|
||||
sock:close()
|
||||
end
|
||||
return nil, nil, err
|
||||
end
|
||||
|
||||
return data, flags
|
||||
end
|
||||
|
||||
|
||||
local function _multi_gets(self, keys)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
local nkeys = #keys
|
||||
|
||||
if nkeys == 0 then
|
||||
return {}, nil
|
||||
end
|
||||
|
||||
local escape_key = self.escape_key
|
||||
local cmd = {"gets"}
|
||||
local n = 1
|
||||
for i = 1, nkeys do
|
||||
cmd[n + 1] = " "
|
||||
cmd[n + 2] = escape_key(keys[i])
|
||||
n = n + 2
|
||||
end
|
||||
cmd[n + 1] = "\r\n"
|
||||
|
||||
-- print("multi get cmd: ", cmd)
|
||||
|
||||
local bytes, err = sock:send(cmd)
|
||||
if not bytes then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
local unescape_key = self.unescape_key
|
||||
local results = {}
|
||||
|
||||
while true do
|
||||
local line, err = sock:receive()
|
||||
if not line then
|
||||
if err == "timeout" then
|
||||
sock:close()
|
||||
end
|
||||
return nil, err
|
||||
end
|
||||
|
||||
if line == 'END' then
|
||||
break
|
||||
end
|
||||
|
||||
local key, flags, len, cas_uniq =
|
||||
match(line, '^VALUE (%S+) (%d+) (%d+) (%d+)$')
|
||||
|
||||
-- print("key: ", key, "len: ", len, ", flags: ", flags)
|
||||
|
||||
if not key then
|
||||
return nil, line
|
||||
end
|
||||
|
||||
local data, err = sock:receive(len)
|
||||
if not data then
|
||||
if err == "timeout" then
|
||||
sock:close()
|
||||
end
|
||||
return nil, err
|
||||
end
|
||||
|
||||
results[unescape_key(key)] = {data, flags, cas_uniq}
|
||||
|
||||
data, err = sock:receive(2) -- discard the trailing CRLF
|
||||
if not data then
|
||||
if err == "timeout" then
|
||||
sock:close()
|
||||
end
|
||||
return nil, err
|
||||
end
|
||||
end
|
||||
|
||||
return results
|
||||
end
|
||||
|
||||
|
||||
function _M.gets(self, key)
|
||||
if type(key) == "table" then
|
||||
return _multi_gets(self, key)
|
||||
end
|
||||
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, nil, nil, "not initialized"
|
||||
end
|
||||
|
||||
local bytes, err = sock:send("gets " .. self.escape_key(key) .. "\r\n")
|
||||
if not bytes then
|
||||
return nil, nil, nil, err
|
||||
end
|
||||
|
||||
local line, err = sock:receive()
|
||||
if not line then
|
||||
if err == "timeout" then
|
||||
sock:close()
|
||||
end
|
||||
return nil, nil, nil, err
|
||||
end
|
||||
|
||||
if line == 'END' then
|
||||
return nil, nil, nil, nil
|
||||
end
|
||||
|
||||
local flags, len, cas_uniq = match(line, '^VALUE %S+ (%d+) (%d+) (%d+)$')
|
||||
if not flags then
|
||||
return nil, nil, nil, line
|
||||
end
|
||||
|
||||
-- print("len: ", len, ", flags: ", flags)
|
||||
|
||||
local data, err = sock:receive(len)
|
||||
if not data then
|
||||
if err == "timeout" then
|
||||
sock:close()
|
||||
end
|
||||
return nil, nil, nil, err
|
||||
end
|
||||
|
||||
line, err = sock:receive(7) -- discard the trailing "\r\nEND\r\n"
|
||||
if not line then
|
||||
if err == "timeout" then
|
||||
sock:close()
|
||||
end
|
||||
return nil, nil, nil, err
|
||||
end
|
||||
|
||||
return data, flags, cas_uniq
|
||||
end
|
||||
|
||||
|
||||
local function _expand_table(value)
|
||||
local segs = {}
|
||||
local nelems = #value
|
||||
local nsegs = 0
|
||||
for i = 1, nelems do
|
||||
local seg = value[i]
|
||||
nsegs = nsegs + 1
|
||||
if type(seg) == "table" then
|
||||
segs[nsegs] = _expand_table(seg)
|
||||
else
|
||||
segs[nsegs] = seg
|
||||
end
|
||||
end
|
||||
return concat(segs)
|
||||
end
|
||||
|
||||
|
||||
local function _store(self, cmd, key, value, exptime, flags)
|
||||
if not exptime then
|
||||
exptime = 0
|
||||
end
|
||||
|
||||
if not flags then
|
||||
flags = 0
|
||||
end
|
||||
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
if type(value) == "table" then
|
||||
value = _expand_table(value)
|
||||
end
|
||||
|
||||
local req = cmd .. " " .. self.escape_key(key) .. " " .. flags .. " "
|
||||
.. exptime .. " " .. strlen(value) .. "\r\n" .. value
|
||||
.. "\r\n"
|
||||
local bytes, err = sock:send(req)
|
||||
if not bytes then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
local data, err = sock:receive()
|
||||
if not data then
|
||||
if err == "timeout" then
|
||||
sock:close()
|
||||
end
|
||||
return nil, err
|
||||
end
|
||||
|
||||
if data == "STORED" then
|
||||
return 1
|
||||
end
|
||||
|
||||
return nil, data
|
||||
end
|
||||
|
||||
|
||||
function _M.set(self, ...)
|
||||
return _store(self, "set", ...)
|
||||
end
|
||||
|
||||
|
||||
function _M.add(self, ...)
|
||||
return _store(self, "add", ...)
|
||||
end
|
||||
|
||||
|
||||
function _M.replace(self, ...)
|
||||
return _store(self, "replace", ...)
|
||||
end
|
||||
|
||||
|
||||
function _M.append(self, ...)
|
||||
return _store(self, "append", ...)
|
||||
end
|
||||
|
||||
|
||||
function _M.prepend(self, ...)
|
||||
return _store(self, "prepend", ...)
|
||||
end
|
||||
|
||||
|
||||
function _M.cas(self, key, value, cas_uniq, exptime, flags)
|
||||
if not exptime then
|
||||
exptime = 0
|
||||
end
|
||||
|
||||
if not flags then
|
||||
flags = 0
|
||||
end
|
||||
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
local req = "cas " .. self.escape_key(key) .. " " .. flags .. " "
|
||||
.. exptime .. " " .. strlen(value) .. " " .. cas_uniq
|
||||
.. "\r\n" .. value .. "\r\n"
|
||||
|
||||
-- local cjson = require "cjson"
|
||||
-- print("request: ", cjson.encode(req))
|
||||
|
||||
local bytes, err = sock:send(req)
|
||||
if not bytes then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
local line, err = sock:receive()
|
||||
if not line then
|
||||
if err == "timeout" then
|
||||
sock:close()
|
||||
end
|
||||
return nil, err
|
||||
end
|
||||
|
||||
-- print("response: [", line, "]")
|
||||
|
||||
if line == "STORED" then
|
||||
return 1
|
||||
end
|
||||
|
||||
return nil, line
|
||||
end
|
||||
|
||||
|
||||
function _M.delete(self, key)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
key = self.escape_key(key)
|
||||
|
||||
local req = "delete " .. key .. "\r\n"
|
||||
|
||||
local bytes, err = sock:send(req)
|
||||
if not bytes then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
local res, err = sock:receive()
|
||||
if not res then
|
||||
if err == "timeout" then
|
||||
sock:close()
|
||||
end
|
||||
return nil, err
|
||||
end
|
||||
|
||||
if res ~= 'DELETED' then
|
||||
return nil, res
|
||||
end
|
||||
|
||||
return 1
|
||||
end
|
||||
|
||||
|
||||
function _M.set_keepalive(self, ...)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
return sock:setkeepalive(...)
|
||||
end
|
||||
|
||||
|
||||
function _M.get_reused_times(self)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
return sock:getreusedtimes()
|
||||
end
|
||||
|
||||
|
||||
function _M.flush_all(self, time)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
local req
|
||||
if time then
|
||||
req = "flush_all " .. time .. "\r\n"
|
||||
else
|
||||
req = "flush_all\r\n"
|
||||
end
|
||||
|
||||
local bytes, err = sock:send(req)
|
||||
if not bytes then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
local res, err = sock:receive()
|
||||
if not res then
|
||||
if err == "timeout" then
|
||||
sock:close()
|
||||
end
|
||||
return nil, err
|
||||
end
|
||||
|
||||
if res ~= 'OK' then
|
||||
return nil, res
|
||||
end
|
||||
|
||||
return 1
|
||||
end
|
||||
|
||||
|
||||
local function _incr_decr(self, cmd, key, value)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
local req = cmd .. " " .. self.escape_key(key) .. " " .. value .. "\r\n"
|
||||
|
||||
local bytes, err = sock:send(req)
|
||||
if not bytes then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
local line, err = sock:receive()
|
||||
if not line then
|
||||
if err == "timeout" then
|
||||
sock:close()
|
||||
end
|
||||
return nil, err
|
||||
end
|
||||
|
||||
if not match(line, '^%d+$') then
|
||||
return nil, line
|
||||
end
|
||||
|
||||
return line
|
||||
end
|
||||
|
||||
|
||||
function _M.incr(self, key, value)
|
||||
return _incr_decr(self, "incr", key, value)
|
||||
end
|
||||
|
||||
|
||||
function _M.decr(self, key, value)
|
||||
return _incr_decr(self, "decr", key, value)
|
||||
end
|
||||
|
||||
|
||||
function _M.stats(self, args)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
local req
|
||||
if args then
|
||||
req = "stats " .. args .. "\r\n"
|
||||
else
|
||||
req = "stats\r\n"
|
||||
end
|
||||
|
||||
local bytes, err = sock:send(req)
|
||||
if not bytes then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
local lines = {}
|
||||
local n = 0
|
||||
while true do
|
||||
local line, err = sock:receive()
|
||||
if not line then
|
||||
if err == "timeout" then
|
||||
sock:close()
|
||||
end
|
||||
return nil, err
|
||||
end
|
||||
|
||||
if line == 'END' then
|
||||
return lines, nil
|
||||
end
|
||||
|
||||
if not match(line, "ERROR") then
|
||||
n = n + 1
|
||||
lines[n] = line
|
||||
else
|
||||
return nil, line
|
||||
end
|
||||
end
|
||||
|
||||
-- cannot reach here...
|
||||
return lines
|
||||
end
|
||||
|
||||
|
||||
function _M.version(self)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
local bytes, err = sock:send("version\r\n")
|
||||
if not bytes then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
local line, err = sock:receive()
|
||||
if not line then
|
||||
if err == "timeout" then
|
||||
sock:close()
|
||||
end
|
||||
return nil, err
|
||||
end
|
||||
|
||||
local ver = match(line, "^VERSION (.+)$")
|
||||
if not ver then
|
||||
return nil, ver
|
||||
end
|
||||
|
||||
return ver
|
||||
end
|
||||
|
||||
|
||||
function _M.quit(self)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
local bytes, err = sock:send("quit\r\n")
|
||||
if not bytes then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
return 1
|
||||
end
|
||||
|
||||
|
||||
function _M.verbosity(self, level)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
local bytes, err = sock:send("verbosity " .. level .. "\r\n")
|
||||
if not bytes then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
local line, err = sock:receive()
|
||||
if not line then
|
||||
if err == "timeout" then
|
||||
sock:close()
|
||||
end
|
||||
return nil, err
|
||||
end
|
||||
|
||||
if line ~= 'OK' then
|
||||
return nil, line
|
||||
end
|
||||
|
||||
return 1
|
||||
end
|
||||
|
||||
|
||||
function _M.touch(self, key, exptime)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
local bytes, err = sock:send("touch " .. self.escape_key(key) .. " "
|
||||
.. exptime .. "\r\n")
|
||||
if not bytes then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
local line, err = sock:receive()
|
||||
if not line then
|
||||
if err == "timeout" then
|
||||
sock:close()
|
||||
end
|
||||
return nil, err
|
||||
end
|
||||
|
||||
-- moxi server from couchbase returned stored after touching
|
||||
if line == "TOUCHED" or line =="STORED" then
|
||||
return 1
|
||||
end
|
||||
return nil, line
|
||||
end
|
||||
|
||||
|
||||
function _M.close(self)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
return sock:close()
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
1410
resty/mysql.lua
Normal file
1410
resty/mysql.lua
Normal file
File diff suppressed because it is too large
Load Diff
36
resty/random.lua
Normal file
36
resty/random.lua
Normal file
@@ -0,0 +1,36 @@
|
||||
-- Copyright (C) by Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
local ffi = require "ffi"
|
||||
local ffi_new = ffi.new
|
||||
local ffi_str = ffi.string
|
||||
local C = ffi.C
|
||||
--local setmetatable = setmetatable
|
||||
--local error = error
|
||||
|
||||
|
||||
local _M = { _VERSION = '0.14' }
|
||||
|
||||
|
||||
ffi.cdef[[
|
||||
int RAND_bytes(unsigned char *buf, int num);
|
||||
int RAND_pseudo_bytes(unsigned char *buf, int num);
|
||||
]]
|
||||
|
||||
|
||||
function _M.bytes(len, strong)
|
||||
local buf = ffi_new("char[?]", len)
|
||||
if strong then
|
||||
if C.RAND_bytes(buf, len) == 0 then
|
||||
return nil
|
||||
end
|
||||
else
|
||||
C.RAND_pseudo_bytes(buf,len)
|
||||
end
|
||||
|
||||
return ffi_str(buf, len)
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
|
||||
676
resty/redis.lua
Normal file
676
resty/redis.lua
Normal file
@@ -0,0 +1,676 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
local sub = string.sub
|
||||
local byte = string.byte
|
||||
local tab_insert = table.insert
|
||||
local tab_remove = table.remove
|
||||
local tcp = ngx.socket.tcp
|
||||
local null = ngx.null
|
||||
local ipairs = ipairs
|
||||
local type = type
|
||||
local pairs = pairs
|
||||
local unpack = unpack
|
||||
local setmetatable = setmetatable
|
||||
local tonumber = tonumber
|
||||
local tostring = tostring
|
||||
local rawget = rawget
|
||||
local select = select
|
||||
--local error = error
|
||||
|
||||
|
||||
local ok, new_tab = pcall(require, "table.new")
|
||||
if not ok or type(new_tab) ~= "function" then
|
||||
new_tab = function (narr, nrec) return {} end
|
||||
end
|
||||
|
||||
|
||||
local _M = new_tab(0, 55)
|
||||
|
||||
_M._VERSION = '0.29'
|
||||
|
||||
|
||||
local common_cmds = {
|
||||
"get", "set", "mget", "mset",
|
||||
"del", "incr", "decr", -- Strings
|
||||
"llen", "lindex", "lpop", "lpush",
|
||||
"lrange", "linsert", -- Lists
|
||||
"hexists", "hget", "hset", "hmget",
|
||||
--[[ "hmset", ]] "hdel", -- Hashes
|
||||
"smembers", "sismember", "sadd", "srem",
|
||||
"sdiff", "sinter", "sunion", -- Sets
|
||||
"zrange", "zrangebyscore", "zrank", "zadd",
|
||||
"zrem", "zincrby", -- Sorted Sets
|
||||
"auth", "eval", "expire", "script",
|
||||
"sort" -- Others
|
||||
}
|
||||
|
||||
|
||||
local sub_commands = {
|
||||
"subscribe", "psubscribe"
|
||||
}
|
||||
|
||||
|
||||
local unsub_commands = {
|
||||
"unsubscribe", "punsubscribe"
|
||||
}
|
||||
|
||||
|
||||
local mt = { __index = _M }
|
||||
|
||||
|
||||
function _M.new(self)
|
||||
local sock, err = tcp()
|
||||
if not sock then
|
||||
return nil, err
|
||||
end
|
||||
return setmetatable({ _sock = sock,
|
||||
_subscribed = false,
|
||||
_n_channel = {
|
||||
unsubscribe = 0,
|
||||
punsubscribe = 0,
|
||||
},
|
||||
}, mt)
|
||||
end
|
||||
|
||||
|
||||
function _M.set_timeout(self, timeout)
|
||||
local sock = rawget(self, "_sock")
|
||||
if not sock then
|
||||
error("not initialized", 2)
|
||||
return
|
||||
end
|
||||
|
||||
sock:settimeout(timeout)
|
||||
end
|
||||
|
||||
|
||||
function _M.set_timeouts(self, connect_timeout, send_timeout, read_timeout)
|
||||
local sock = rawget(self, "_sock")
|
||||
if not sock then
|
||||
error("not initialized", 2)
|
||||
return
|
||||
end
|
||||
|
||||
sock:settimeouts(connect_timeout, send_timeout, read_timeout)
|
||||
end
|
||||
|
||||
|
||||
function _M.connect(self, host, port_or_opts, opts)
|
||||
local sock = rawget(self, "_sock")
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
local unix
|
||||
|
||||
do
|
||||
local typ = type(host)
|
||||
if typ ~= "string" then
|
||||
error("bad argument #1 host: string expected, got " .. typ, 2)
|
||||
end
|
||||
|
||||
if sub(host, 1, 5) == "unix:" then
|
||||
unix = true
|
||||
end
|
||||
|
||||
if unix then
|
||||
typ = type(port_or_opts)
|
||||
if port_or_opts ~= nil and typ ~= "table" then
|
||||
error("bad argument #2 opts: nil or table expected, got " ..
|
||||
typ, 2)
|
||||
end
|
||||
|
||||
else
|
||||
typ = type(port_or_opts)
|
||||
if typ ~= "number" then
|
||||
port_or_opts = tonumber(port_or_opts)
|
||||
if port_or_opts == nil then
|
||||
error("bad argument #2 port: number expected, got " ..
|
||||
typ, 2)
|
||||
end
|
||||
end
|
||||
|
||||
if opts ~= nil then
|
||||
typ = type(opts)
|
||||
if typ ~= "table" then
|
||||
error("bad argument #3 opts: nil or table expected, got " ..
|
||||
typ, 2)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
end
|
||||
|
||||
self._subscribed = false
|
||||
|
||||
local ok, err
|
||||
|
||||
if unix then
|
||||
-- second argument of sock:connect() cannot be nil
|
||||
if port_or_opts ~= nil then
|
||||
ok, err = sock:connect(host, port_or_opts)
|
||||
opts = port_or_opts
|
||||
else
|
||||
ok, err = sock:connect(host)
|
||||
end
|
||||
else
|
||||
ok, err = sock:connect(host, port_or_opts, opts)
|
||||
end
|
||||
|
||||
if not ok then
|
||||
return ok, err
|
||||
end
|
||||
|
||||
if opts and opts.ssl then
|
||||
ok, err = sock:sslhandshake(false, opts.server_name, opts.ssl_verify)
|
||||
if not ok then
|
||||
return ok, "failed to do ssl handshake: " .. err
|
||||
end
|
||||
end
|
||||
|
||||
return ok, err
|
||||
end
|
||||
|
||||
|
||||
function _M.set_keepalive(self, ...)
|
||||
local sock = rawget(self, "_sock")
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
if rawget(self, "_subscribed") then
|
||||
return nil, "subscribed state"
|
||||
end
|
||||
|
||||
return sock:setkeepalive(...)
|
||||
end
|
||||
|
||||
|
||||
function _M.get_reused_times(self)
|
||||
local sock = rawget(self, "_sock")
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
return sock:getreusedtimes()
|
||||
end
|
||||
|
||||
|
||||
local function close(self)
|
||||
local sock = rawget(self, "_sock")
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
return sock:close()
|
||||
end
|
||||
_M.close = close
|
||||
|
||||
|
||||
local function _read_reply(self, sock)
|
||||
local line, err = sock:receive()
|
||||
if not line then
|
||||
if err == "timeout" and not rawget(self, "_subscribed") then
|
||||
sock:close()
|
||||
end
|
||||
return nil, err
|
||||
end
|
||||
|
||||
local prefix = byte(line)
|
||||
|
||||
if prefix == 36 then -- char '$'
|
||||
-- print("bulk reply")
|
||||
|
||||
local size = tonumber(sub(line, 2))
|
||||
if size < 0 then
|
||||
return null
|
||||
end
|
||||
|
||||
local data, err = sock:receive(size)
|
||||
if not data then
|
||||
if err == "timeout" then
|
||||
sock:close()
|
||||
end
|
||||
return nil, err
|
||||
end
|
||||
|
||||
local dummy, err = sock:receive(2) -- ignore CRLF
|
||||
if not dummy then
|
||||
if err == "timeout" then
|
||||
sock:close()
|
||||
end
|
||||
return nil, err
|
||||
end
|
||||
|
||||
return data
|
||||
|
||||
elseif prefix == 43 then -- char '+'
|
||||
-- print("status reply")
|
||||
|
||||
return sub(line, 2)
|
||||
|
||||
elseif prefix == 42 then -- char '*'
|
||||
local n = tonumber(sub(line, 2))
|
||||
|
||||
-- print("multi-bulk reply: ", n)
|
||||
if n < 0 then
|
||||
return null
|
||||
end
|
||||
|
||||
local vals = new_tab(n, 0)
|
||||
local nvals = 0
|
||||
for i = 1, n do
|
||||
local res, err = _read_reply(self, sock)
|
||||
if res then
|
||||
nvals = nvals + 1
|
||||
vals[nvals] = res
|
||||
|
||||
elseif res == nil then
|
||||
return nil, err
|
||||
|
||||
else
|
||||
-- be a valid redis error value
|
||||
nvals = nvals + 1
|
||||
vals[nvals] = {false, err}
|
||||
end
|
||||
end
|
||||
|
||||
return vals
|
||||
|
||||
elseif prefix == 58 then -- char ':'
|
||||
-- print("integer reply")
|
||||
return tonumber(sub(line, 2))
|
||||
|
||||
elseif prefix == 45 then -- char '-'
|
||||
-- print("error reply: ", n)
|
||||
|
||||
return false, sub(line, 2)
|
||||
|
||||
else
|
||||
-- when `line` is an empty string, `prefix` will be equal to nil.
|
||||
return nil, "unknown prefix: \"" .. tostring(prefix) .. "\""
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
local function _gen_req(args)
|
||||
local nargs = #args
|
||||
|
||||
local req = new_tab(nargs * 5 + 1, 0)
|
||||
req[1] = "*" .. nargs .. "\r\n"
|
||||
local nbits = 2
|
||||
|
||||
for i = 1, nargs do
|
||||
local arg = args[i]
|
||||
if type(arg) ~= "string" then
|
||||
arg = tostring(arg)
|
||||
end
|
||||
|
||||
req[nbits] = "$"
|
||||
req[nbits + 1] = #arg
|
||||
req[nbits + 2] = "\r\n"
|
||||
req[nbits + 3] = arg
|
||||
req[nbits + 4] = "\r\n"
|
||||
|
||||
nbits = nbits + 5
|
||||
end
|
||||
|
||||
-- it is much faster to do string concatenation on the C land
|
||||
-- in real world (large number of strings in the Lua VM)
|
||||
return req
|
||||
end
|
||||
|
||||
|
||||
local function _check_msg(self, res)
|
||||
return rawget(self, "_subscribed") and
|
||||
type(res) == "table" and (res[1] == "message" or res[1] == "pmessage")
|
||||
end
|
||||
|
||||
|
||||
local function _do_cmd(self, ...)
|
||||
local args = {...}
|
||||
|
||||
local sock = rawget(self, "_sock")
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
local req = _gen_req(args)
|
||||
|
||||
local reqs = rawget(self, "_reqs")
|
||||
if reqs then
|
||||
reqs[#reqs + 1] = req
|
||||
return
|
||||
end
|
||||
|
||||
-- print("request: ", table.concat(req))
|
||||
|
||||
local bytes, err = sock:send(req)
|
||||
if not bytes then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
local res, err = _read_reply(self, sock)
|
||||
while _check_msg(self, res) do
|
||||
if rawget(self, "_buffered_msg") == nil then
|
||||
self._buffered_msg = new_tab(1, 0)
|
||||
end
|
||||
|
||||
tab_insert(self._buffered_msg, res)
|
||||
res, err = _read_reply(self, sock)
|
||||
end
|
||||
|
||||
return res, err
|
||||
end
|
||||
|
||||
|
||||
local function _check_unsubscribed(self, res)
|
||||
if type(res) == "table"
|
||||
and (res[1] == "unsubscribe" or res[1] == "punsubscribe")
|
||||
then
|
||||
self._n_channel[res[1]] = self._n_channel[res[1]] - 1
|
||||
|
||||
local buffered_msg = rawget(self, "_buffered_msg")
|
||||
if buffered_msg then
|
||||
-- remove messages of unsubscribed channel
|
||||
local msg_type =
|
||||
(res[1] == "punsubscribe") and "pmessage" or "message"
|
||||
local j = 1
|
||||
for _, msg in ipairs(buffered_msg) do
|
||||
if msg[1] == msg_type and msg[2] ~= res[2] then
|
||||
-- move messages to overwrite the removed ones
|
||||
buffered_msg[j] = msg
|
||||
j = j + 1
|
||||
end
|
||||
end
|
||||
|
||||
-- clear remain messages
|
||||
for i = j, #buffered_msg do
|
||||
buffered_msg[i] = nil
|
||||
end
|
||||
|
||||
if #buffered_msg == 0 then
|
||||
self._buffered_msg = nil
|
||||
end
|
||||
end
|
||||
|
||||
if res[3] == 0 then
|
||||
-- all channels are unsubscribed
|
||||
self._subscribed = false
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
local function _check_subscribed(self, res)
|
||||
if type(res) == "table"
|
||||
and (res[1] == "subscribe" or res[1] == "psubscribe")
|
||||
then
|
||||
if res[1] == "subscribe" then
|
||||
self._n_channel.unsubscribe = self._n_channel.unsubscribe + 1
|
||||
|
||||
elseif res[1] == "psubscribe" then
|
||||
self._n_channel.punsubscribe = self._n_channel.punsubscribe + 1
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
function _M.read_reply(self)
|
||||
local sock = rawget(self, "_sock")
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
if not rawget(self, "_subscribed") then
|
||||
return nil, "not subscribed"
|
||||
end
|
||||
|
||||
local buffered_msg = rawget(self, "_buffered_msg")
|
||||
if buffered_msg then
|
||||
local msg = buffered_msg[1]
|
||||
tab_remove(buffered_msg, 1)
|
||||
|
||||
if #buffered_msg == 0 then
|
||||
self._buffered_msg = nil
|
||||
end
|
||||
|
||||
return msg
|
||||
end
|
||||
|
||||
local res, err = _read_reply(self, sock)
|
||||
_check_unsubscribed(self, res)
|
||||
|
||||
return res, err
|
||||
end
|
||||
|
||||
|
||||
for i = 1, #common_cmds do
|
||||
local cmd = common_cmds[i]
|
||||
|
||||
_M[cmd] =
|
||||
function (self, ...)
|
||||
return _do_cmd(self, cmd, ...)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
local function handle_subscribe_result(self, cmd, nargs, res)
|
||||
local err
|
||||
_check_subscribed(self, res)
|
||||
|
||||
if nargs <= 1 then
|
||||
return res
|
||||
end
|
||||
|
||||
local results = new_tab(nargs, 0)
|
||||
results[1] = res
|
||||
local sock = rawget(self, "_sock")
|
||||
|
||||
for i = 2, nargs do
|
||||
res, err = _read_reply(self, sock)
|
||||
if not res then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
_check_subscribed(self, res)
|
||||
results[i] = res
|
||||
end
|
||||
|
||||
return results
|
||||
end
|
||||
|
||||
for i = 1, #sub_commands do
|
||||
local cmd = sub_commands[i]
|
||||
|
||||
_M[cmd] =
|
||||
function (self, ...)
|
||||
if not rawget(self, "_subscribed") then
|
||||
self._subscribed = true
|
||||
end
|
||||
|
||||
local nargs = select("#", ...)
|
||||
|
||||
local res, err = _do_cmd(self, cmd, ...)
|
||||
if not res then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
return handle_subscribe_result(self, cmd, nargs, res)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
local function handle_unsubscribe_result(self, cmd, nargs, res)
|
||||
local err
|
||||
_check_unsubscribed(self, res)
|
||||
|
||||
if self._n_channel[cmd] == 0 or nargs == 1 then
|
||||
return res
|
||||
end
|
||||
|
||||
local results = new_tab(nargs, 0)
|
||||
results[1] = res
|
||||
local sock = rawget(self, "_sock")
|
||||
local i = 2
|
||||
|
||||
while nargs == 0 or i <= nargs do
|
||||
res, err = _read_reply(self, sock)
|
||||
if not res then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
results[i] = res
|
||||
i = i + 1
|
||||
|
||||
_check_unsubscribed(self, res)
|
||||
if self._n_channel[cmd] == 0 then
|
||||
-- exit the loop for unsubscribe() call
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
return results
|
||||
end
|
||||
|
||||
for i = 1, #unsub_commands do
|
||||
local cmd = unsub_commands[i]
|
||||
|
||||
_M[cmd] =
|
||||
function (self, ...)
|
||||
-- assume all channels are unsubscribed by only one time
|
||||
if not rawget(self, "_subscribed") then
|
||||
return nil, "not subscribed"
|
||||
end
|
||||
|
||||
local nargs = select("#", ...)
|
||||
|
||||
local res, err = _do_cmd(self, cmd, ...)
|
||||
if not res then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
return handle_unsubscribe_result(self, cmd, nargs, res)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
function _M.hmset(self, hashname, ...)
|
||||
if select('#', ...) == 1 then
|
||||
local t = select(1, ...)
|
||||
|
||||
local n = 0
|
||||
for k, v in pairs(t) do
|
||||
n = n + 2
|
||||
end
|
||||
|
||||
local array = new_tab(n, 0)
|
||||
|
||||
local i = 0
|
||||
for k, v in pairs(t) do
|
||||
array[i + 1] = k
|
||||
array[i + 2] = v
|
||||
i = i + 2
|
||||
end
|
||||
-- print("key", hashname)
|
||||
return _do_cmd(self, "hmset", hashname, unpack(array))
|
||||
end
|
||||
|
||||
-- backwards compatibility
|
||||
return _do_cmd(self, "hmset", hashname, ...)
|
||||
end
|
||||
|
||||
|
||||
function _M.init_pipeline(self, n)
|
||||
self._reqs = new_tab(n or 4, 0)
|
||||
end
|
||||
|
||||
|
||||
function _M.cancel_pipeline(self)
|
||||
self._reqs = nil
|
||||
end
|
||||
|
||||
|
||||
function _M.commit_pipeline(self)
|
||||
local reqs = rawget(self, "_reqs")
|
||||
if not reqs then
|
||||
return nil, "no pipeline"
|
||||
end
|
||||
|
||||
self._reqs = nil
|
||||
|
||||
local sock = rawget(self, "_sock")
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
local bytes, err = sock:send(reqs)
|
||||
if not bytes then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
local nvals = 0
|
||||
local nreqs = #reqs
|
||||
local vals = new_tab(nreqs, 0)
|
||||
for i = 1, nreqs do
|
||||
local res, err = _read_reply(self, sock)
|
||||
if res then
|
||||
nvals = nvals + 1
|
||||
vals[nvals] = res
|
||||
|
||||
elseif res == nil then
|
||||
if err == "timeout" then
|
||||
close(self)
|
||||
end
|
||||
return nil, err
|
||||
|
||||
else
|
||||
-- be a valid redis error value
|
||||
nvals = nvals + 1
|
||||
vals[nvals] = {false, err}
|
||||
end
|
||||
end
|
||||
|
||||
return vals
|
||||
end
|
||||
|
||||
|
||||
function _M.array_to_hash(self, t)
|
||||
local n = #t
|
||||
-- print("n = ", n)
|
||||
local h = new_tab(0, n / 2)
|
||||
for i = 1, n, 2 do
|
||||
h[t[i]] = t[i + 1]
|
||||
end
|
||||
return h
|
||||
end
|
||||
|
||||
|
||||
-- this method is deperate since we already do lazy method generation.
|
||||
function _M.add_commands(...)
|
||||
local cmds = {...}
|
||||
for i = 1, #cmds do
|
||||
local cmd = cmds[i]
|
||||
_M[cmd] =
|
||||
function (self, ...)
|
||||
return _do_cmd(self, cmd, ...)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
setmetatable(_M, {__index = function(self, cmd)
|
||||
local method =
|
||||
function (self, ...)
|
||||
return _do_cmd(self, cmd, ...)
|
||||
end
|
||||
|
||||
-- cache the lazily generated method in our
|
||||
-- module table
|
||||
_M[cmd] = method
|
||||
return method
|
||||
end})
|
||||
|
||||
|
||||
return _M
|
||||
19
resty/sha.lua
Normal file
19
resty/sha.lua
Normal file
@@ -0,0 +1,19 @@
|
||||
-- Copyright (C) by Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
local ffi = require "ffi"
|
||||
|
||||
|
||||
local _M = { _VERSION = '0.14' }
|
||||
|
||||
|
||||
ffi.cdef[[
|
||||
typedef unsigned long SHA_LONG;
|
||||
typedef unsigned long long SHA_LONG64;
|
||||
|
||||
enum {
|
||||
SHA_LBLOCK = 16
|
||||
};
|
||||
]];
|
||||
|
||||
return _M
|
||||
69
resty/sha1.lua
Normal file
69
resty/sha1.lua
Normal file
@@ -0,0 +1,69 @@
|
||||
-- Copyright (C) by Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
require "resty.sha"
|
||||
local ffi = require "ffi"
|
||||
local ffi_new = ffi.new
|
||||
local ffi_str = ffi.string
|
||||
local C = ffi.C
|
||||
local setmetatable = setmetatable
|
||||
--local error = error
|
||||
|
||||
|
||||
local _M = { _VERSION = '0.14' }
|
||||
|
||||
|
||||
local mt = { __index = _M }
|
||||
|
||||
|
||||
ffi.cdef[[
|
||||
typedef struct SHAstate_st
|
||||
{
|
||||
SHA_LONG h0,h1,h2,h3,h4;
|
||||
SHA_LONG Nl,Nh;
|
||||
SHA_LONG data[SHA_LBLOCK];
|
||||
unsigned int num;
|
||||
} SHA_CTX;
|
||||
|
||||
int SHA1_Init(SHA_CTX *c);
|
||||
int SHA1_Update(SHA_CTX *c, const void *data, size_t len);
|
||||
int SHA1_Final(unsigned char *md, SHA_CTX *c);
|
||||
]]
|
||||
|
||||
local digest_len = 20
|
||||
|
||||
local buf = ffi_new("char[?]", digest_len)
|
||||
local ctx_ptr_type = ffi.typeof("SHA_CTX[1]")
|
||||
|
||||
|
||||
function _M.new(self)
|
||||
local ctx = ffi_new(ctx_ptr_type)
|
||||
if C.SHA1_Init(ctx) == 0 then
|
||||
return nil
|
||||
end
|
||||
|
||||
return setmetatable({ _ctx = ctx }, mt)
|
||||
end
|
||||
|
||||
|
||||
function _M.update(self, s)
|
||||
return C.SHA1_Update(self._ctx, s, #s) == 1
|
||||
end
|
||||
|
||||
|
||||
function _M.final(self)
|
||||
if C.SHA1_Final(buf, self._ctx) == 1 then
|
||||
return ffi_str(buf, digest_len)
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
|
||||
function _M.reset(self)
|
||||
return C.SHA1_Init(self._ctx) == 1
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
|
||||
60
resty/sha224.lua
Normal file
60
resty/sha224.lua
Normal file
@@ -0,0 +1,60 @@
|
||||
-- Copyright (C) by Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
require "resty.sha256"
|
||||
local ffi = require "ffi"
|
||||
local ffi_new = ffi.new
|
||||
local ffi_str = ffi.string
|
||||
local C = ffi.C
|
||||
local setmetatable = setmetatable
|
||||
--local error = error
|
||||
|
||||
|
||||
local _M = { _VERSION = '0.14' }
|
||||
|
||||
|
||||
local mt = { __index = _M }
|
||||
|
||||
|
||||
ffi.cdef[[
|
||||
int SHA224_Init(SHA256_CTX *c);
|
||||
int SHA224_Update(SHA256_CTX *c, const void *data, size_t len);
|
||||
int SHA224_Final(unsigned char *md, SHA256_CTX *c);
|
||||
]]
|
||||
|
||||
local digest_len = 28
|
||||
|
||||
local buf = ffi_new("char[?]", digest_len)
|
||||
local ctx_ptr_type = ffi.typeof("SHA256_CTX[1]")
|
||||
|
||||
|
||||
function _M.new(self)
|
||||
local ctx = ffi_new(ctx_ptr_type)
|
||||
if C.SHA224_Init(ctx) == 0 then
|
||||
return nil
|
||||
end
|
||||
|
||||
return setmetatable({ _ctx = ctx }, mt)
|
||||
end
|
||||
|
||||
|
||||
function _M.update(self, s)
|
||||
return C.SHA224_Update(self._ctx, s, #s) == 1
|
||||
end
|
||||
|
||||
|
||||
function _M.final(self)
|
||||
if C.SHA224_Final(buf, self._ctx) == 1 then
|
||||
return ffi_str(buf, digest_len)
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
|
||||
function _M.reset(self)
|
||||
return C.SHA224_Init(self._ctx) == 1
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
69
resty/sha256.lua
Normal file
69
resty/sha256.lua
Normal file
@@ -0,0 +1,69 @@
|
||||
-- Copyright (C) by Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
require "resty.sha"
|
||||
local ffi = require "ffi"
|
||||
local ffi_new = ffi.new
|
||||
local ffi_str = ffi.string
|
||||
local C = ffi.C
|
||||
local setmetatable = setmetatable
|
||||
--local error = error
|
||||
|
||||
|
||||
local _M = { _VERSION = '0.14' }
|
||||
|
||||
|
||||
local mt = { __index = _M }
|
||||
|
||||
|
||||
ffi.cdef[[
|
||||
typedef struct SHA256state_st
|
||||
{
|
||||
SHA_LONG h[8];
|
||||
SHA_LONG Nl,Nh;
|
||||
SHA_LONG data[SHA_LBLOCK];
|
||||
unsigned int num,md_len;
|
||||
} SHA256_CTX;
|
||||
|
||||
int SHA256_Init(SHA256_CTX *c);
|
||||
int SHA256_Update(SHA256_CTX *c, const void *data, size_t len);
|
||||
int SHA256_Final(unsigned char *md, SHA256_CTX *c);
|
||||
]]
|
||||
|
||||
local digest_len = 32
|
||||
|
||||
local buf = ffi_new("char[?]", digest_len)
|
||||
local ctx_ptr_type = ffi.typeof("SHA256_CTX[1]")
|
||||
|
||||
|
||||
function _M.new(self)
|
||||
local ctx = ffi_new(ctx_ptr_type)
|
||||
if C.SHA256_Init(ctx) == 0 then
|
||||
return nil
|
||||
end
|
||||
|
||||
return setmetatable({ _ctx = ctx }, mt)
|
||||
end
|
||||
|
||||
|
||||
function _M.update(self, s)
|
||||
return C.SHA256_Update(self._ctx, s, #s) == 1
|
||||
end
|
||||
|
||||
|
||||
function _M.final(self)
|
||||
if C.SHA256_Final(buf, self._ctx) == 1 then
|
||||
return ffi_str(buf, digest_len)
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
|
||||
function _M.reset(self)
|
||||
return C.SHA256_Init(self._ctx) == 1
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
|
||||
60
resty/sha384.lua
Normal file
60
resty/sha384.lua
Normal file
@@ -0,0 +1,60 @@
|
||||
-- Copyright (C) by Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
require "resty.sha512"
|
||||
local ffi = require "ffi"
|
||||
local ffi_new = ffi.new
|
||||
local ffi_str = ffi.string
|
||||
local C = ffi.C
|
||||
local setmetatable = setmetatable
|
||||
--local error = error
|
||||
|
||||
|
||||
local _M = { _VERSION = '0.14' }
|
||||
|
||||
|
||||
local mt = { __index = _M }
|
||||
|
||||
|
||||
ffi.cdef[[
|
||||
int SHA384_Init(SHA512_CTX *c);
|
||||
int SHA384_Update(SHA512_CTX *c, const void *data, size_t len);
|
||||
int SHA384_Final(unsigned char *md, SHA512_CTX *c);
|
||||
]]
|
||||
|
||||
local digest_len = 48
|
||||
|
||||
local buf = ffi_new("char[?]", digest_len)
|
||||
local ctx_ptr_type = ffi.typeof("SHA512_CTX[1]")
|
||||
|
||||
|
||||
function _M.new(self)
|
||||
local ctx = ffi_new(ctx_ptr_type)
|
||||
if C.SHA384_Init(ctx) == 0 then
|
||||
return nil
|
||||
end
|
||||
|
||||
return setmetatable({ _ctx = ctx }, mt)
|
||||
end
|
||||
|
||||
|
||||
function _M.update(self, s)
|
||||
return C.SHA384_Update(self._ctx, s, #s) == 1
|
||||
end
|
||||
|
||||
|
||||
function _M.final(self)
|
||||
if C.SHA384_Final(buf, self._ctx) == 1 then
|
||||
return ffi_str(buf, digest_len)
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
|
||||
function _M.reset(self)
|
||||
return C.SHA384_Init(self._ctx) == 1
|
||||
end
|
||||
|
||||
return _M
|
||||
|
||||
75
resty/sha512.lua
Normal file
75
resty/sha512.lua
Normal file
@@ -0,0 +1,75 @@
|
||||
-- Copyright (C) by Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
require "resty.sha"
|
||||
local ffi = require "ffi"
|
||||
local ffi_new = ffi.new
|
||||
local ffi_str = ffi.string
|
||||
local C = ffi.C
|
||||
local setmetatable = setmetatable
|
||||
--local error = error
|
||||
|
||||
|
||||
local _M = { _VERSION = '0.14' }
|
||||
|
||||
|
||||
local mt = { __index = _M }
|
||||
|
||||
|
||||
ffi.cdef[[
|
||||
enum {
|
||||
SHA512_CBLOCK = SHA_LBLOCK*8
|
||||
};
|
||||
|
||||
typedef struct SHA512state_st
|
||||
{
|
||||
SHA_LONG64 h[8];
|
||||
SHA_LONG64 Nl,Nh;
|
||||
union {
|
||||
SHA_LONG64 d[SHA_LBLOCK];
|
||||
unsigned char p[SHA512_CBLOCK];
|
||||
} u;
|
||||
unsigned int num,md_len;
|
||||
} SHA512_CTX;
|
||||
|
||||
int SHA512_Init(SHA512_CTX *c);
|
||||
int SHA512_Update(SHA512_CTX *c, const void *data, size_t len);
|
||||
int SHA512_Final(unsigned char *md, SHA512_CTX *c);
|
||||
]]
|
||||
|
||||
local digest_len = 64
|
||||
|
||||
local buf = ffi_new("char[?]", digest_len)
|
||||
local ctx_ptr_type = ffi.typeof("SHA512_CTX[1]")
|
||||
|
||||
|
||||
function _M.new(self)
|
||||
local ctx = ffi_new(ctx_ptr_type)
|
||||
if C.SHA512_Init(ctx) == 0 then
|
||||
return nil
|
||||
end
|
||||
|
||||
return setmetatable({ _ctx = ctx }, mt)
|
||||
end
|
||||
|
||||
|
||||
function _M.update(self, s)
|
||||
return C.SHA512_Update(self._ctx, s, #s) == 1
|
||||
end
|
||||
|
||||
|
||||
function _M.final(self)
|
||||
if C.SHA512_Final(buf, self._ctx) == 1 then
|
||||
return ffi_str(buf, digest_len)
|
||||
end
|
||||
|
||||
return nil
|
||||
end
|
||||
|
||||
|
||||
function _M.reset(self)
|
||||
return C.SHA512_Init(self._ctx) == 1
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
40
resty/string.lua
Normal file
40
resty/string.lua
Normal file
@@ -0,0 +1,40 @@
|
||||
-- Copyright (C) by Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
local ffi = require "ffi"
|
||||
local ffi_new = ffi.new
|
||||
local ffi_str = ffi.string
|
||||
local C = ffi.C
|
||||
--local setmetatable = setmetatable
|
||||
--local error = error
|
||||
local tonumber = tonumber
|
||||
|
||||
|
||||
local _M = { _VERSION = '0.14' }
|
||||
|
||||
|
||||
ffi.cdef[[
|
||||
typedef unsigned char u_char;
|
||||
|
||||
u_char * ngx_hex_dump(u_char *dst, const u_char *src, size_t len);
|
||||
|
||||
intptr_t ngx_atoi(const unsigned char *line, size_t n);
|
||||
]]
|
||||
|
||||
local str_type = ffi.typeof("uint8_t[?]")
|
||||
|
||||
|
||||
function _M.to_hex(s)
|
||||
local len = #s
|
||||
local buf_len = len * 2
|
||||
local buf = ffi_new(str_type, buf_len)
|
||||
C.ngx_hex_dump(buf, s, len)
|
||||
return ffi_str(buf, buf_len)
|
||||
end
|
||||
|
||||
function _M.atoi(s)
|
||||
return tonumber(C.ngx_atoi(s, #s))
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
267
resty/upload.lua
Normal file
267
resty/upload.lua
Normal file
@@ -0,0 +1,267 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
-- local sub = string.sub
|
||||
local req_socket = ngx.req.socket
|
||||
local match = string.match
|
||||
local setmetatable = setmetatable
|
||||
local type = type
|
||||
local ngx_var = ngx.var
|
||||
-- local print = print
|
||||
|
||||
|
||||
local _M = { _VERSION = '0.10' }
|
||||
|
||||
|
||||
local CHUNK_SIZE = 4096
|
||||
local MAX_LINE_SIZE = 512
|
||||
|
||||
local STATE_BEGIN = 1
|
||||
local STATE_READING_HEADER = 2
|
||||
local STATE_READING_BODY = 3
|
||||
local STATE_EOF = 4
|
||||
|
||||
|
||||
local mt = { __index = _M }
|
||||
|
||||
local state_handlers
|
||||
|
||||
|
||||
local function get_boundary()
|
||||
local header = ngx_var.content_type
|
||||
if not header then
|
||||
return nil
|
||||
end
|
||||
|
||||
if type(header) == "table" then
|
||||
header = header[1]
|
||||
end
|
||||
|
||||
local m = match(header, ";%s*boundary=\"([^\"]+)\"")
|
||||
if m then
|
||||
return m
|
||||
end
|
||||
|
||||
return match(header, ";%s*boundary=([^\",;]+)")
|
||||
end
|
||||
|
||||
|
||||
function _M.new(self, chunk_size, max_line_size)
|
||||
local boundary = get_boundary()
|
||||
|
||||
-- print("boundary: ", boundary)
|
||||
|
||||
if not boundary then
|
||||
return nil, "no boundary defined in Content-Type"
|
||||
end
|
||||
|
||||
-- print('boundary: "', boundary, '"')
|
||||
|
||||
local sock, err = req_socket()
|
||||
if not sock then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
local read2boundary, err = sock:receiveuntil("--" .. boundary)
|
||||
if not read2boundary then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
local read_line, err = sock:receiveuntil("\r\n")
|
||||
if not read_line then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
return setmetatable({
|
||||
sock = sock,
|
||||
size = chunk_size or CHUNK_SIZE,
|
||||
line_size = max_line_size or MAX_LINE_SIZE,
|
||||
read2boundary = read2boundary,
|
||||
read_line = read_line,
|
||||
boundary = boundary,
|
||||
state = STATE_BEGIN
|
||||
}, mt)
|
||||
end
|
||||
|
||||
|
||||
function _M.set_timeout(self, timeout)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
return sock:settimeout(timeout)
|
||||
end
|
||||
|
||||
|
||||
local function discard_line(self)
|
||||
local read_line = self.read_line
|
||||
|
||||
local line, err = read_line(self.line_size)
|
||||
if not line then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
local dummy, err = read_line(1)
|
||||
if dummy then
|
||||
return nil, "line too long: " .. line .. dummy .. "..."
|
||||
end
|
||||
|
||||
if err then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
return 1
|
||||
end
|
||||
|
||||
|
||||
local function discard_rest(self)
|
||||
local sock = self.sock
|
||||
local size = self.size
|
||||
|
||||
while true do
|
||||
local dummy, err = sock:receive(size)
|
||||
if err and err ~= 'closed' then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
if not dummy then
|
||||
return 1
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
local function read_body_part(self)
|
||||
local read2boundary = self.read2boundary
|
||||
|
||||
local chunk, err = read2boundary(self.size)
|
||||
if err then
|
||||
return nil, nil, err
|
||||
end
|
||||
|
||||
if not chunk then
|
||||
local sock = self.sock
|
||||
|
||||
local data = sock:receive(2)
|
||||
if data == "--" then
|
||||
local ok, err = discard_rest(self)
|
||||
if not ok then
|
||||
return nil, nil, err
|
||||
end
|
||||
|
||||
self.state = STATE_EOF
|
||||
return "part_end"
|
||||
end
|
||||
|
||||
if data ~= "\r\n" then
|
||||
local ok, err = discard_line(self)
|
||||
if not ok then
|
||||
return nil, nil, err
|
||||
end
|
||||
end
|
||||
|
||||
self.state = STATE_READING_HEADER
|
||||
return "part_end"
|
||||
end
|
||||
|
||||
return "body", chunk
|
||||
end
|
||||
|
||||
|
||||
local function read_header(self)
|
||||
local read_line = self.read_line
|
||||
|
||||
local line, err = read_line(self.line_size)
|
||||
if err then
|
||||
return nil, nil, err
|
||||
end
|
||||
|
||||
local dummy, err = read_line(1)
|
||||
if dummy then
|
||||
return nil, nil, "line too long: " .. line .. dummy .. "..."
|
||||
end
|
||||
|
||||
if err then
|
||||
return nil, nil, err
|
||||
end
|
||||
|
||||
-- print("read line: ", line)
|
||||
|
||||
if line == "" then
|
||||
-- after the last header
|
||||
self.state = STATE_READING_BODY
|
||||
return read_body_part(self)
|
||||
end
|
||||
|
||||
local key, value = match(line, "([^: \t]+)%s*:%s*(.+)")
|
||||
if not key then
|
||||
return 'header', line
|
||||
end
|
||||
|
||||
return 'header', {key, value, line}
|
||||
end
|
||||
|
||||
|
||||
local function eof()
|
||||
return "eof", nil
|
||||
end
|
||||
|
||||
|
||||
function _M.read(self)
|
||||
-- local size = self.size
|
||||
|
||||
local handler = state_handlers[self.state]
|
||||
if handler then
|
||||
return handler(self)
|
||||
end
|
||||
|
||||
return nil, nil, "bad state: " .. self.state
|
||||
end
|
||||
|
||||
|
||||
local function read_preamble(self)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, nil, "not initialized"
|
||||
end
|
||||
|
||||
local size = self.size
|
||||
local read2boundary = self.read2boundary
|
||||
|
||||
while true do
|
||||
local preamble = read2boundary(size)
|
||||
if not preamble then
|
||||
break
|
||||
end
|
||||
|
||||
-- discard the preamble data chunk
|
||||
-- print("read preamble: ", preamble)
|
||||
end
|
||||
|
||||
local ok, err = discard_line(self)
|
||||
if not ok then
|
||||
return nil, nil, err
|
||||
end
|
||||
|
||||
local read2boundary, err = sock:receiveuntil("\r\n--" .. self.boundary)
|
||||
if not read2boundary then
|
||||
return nil, nil, err
|
||||
end
|
||||
|
||||
self.read2boundary = read2boundary
|
||||
|
||||
self.state = STATE_READING_HEADER
|
||||
return read_header(self)
|
||||
end
|
||||
|
||||
|
||||
state_handlers = {
|
||||
read_preamble,
|
||||
read_header,
|
||||
read_body_part,
|
||||
eof
|
||||
}
|
||||
|
||||
|
||||
return _M
|
||||
707
resty/upstream/healthcheck.lua
Normal file
707
resty/upstream/healthcheck.lua
Normal file
@@ -0,0 +1,707 @@
|
||||
local stream_sock = ngx.socket.tcp
|
||||
local log = ngx.log
|
||||
local ERR = ngx.ERR
|
||||
local WARN = ngx.WARN
|
||||
local DEBUG = ngx.DEBUG
|
||||
local sub = string.sub
|
||||
local re_find = ngx.re.find
|
||||
local new_timer = ngx.timer.at
|
||||
local shared = ngx.shared
|
||||
local debug_mode = ngx.config.debug
|
||||
local concat = table.concat
|
||||
local tonumber = tonumber
|
||||
local tostring = tostring
|
||||
local ipairs = ipairs
|
||||
local ceil = math.ceil
|
||||
local spawn = ngx.thread.spawn
|
||||
local wait = ngx.thread.wait
|
||||
local pcall = pcall
|
||||
|
||||
local _M = {
|
||||
_VERSION = '0.05'
|
||||
}
|
||||
|
||||
if not ngx.config
|
||||
or not ngx.config.ngx_lua_version
|
||||
or ngx.config.ngx_lua_version < 9005
|
||||
then
|
||||
error("ngx_lua 0.9.5+ required")
|
||||
end
|
||||
|
||||
local ok, upstream = pcall(require, "ngx.upstream")
|
||||
if not ok then
|
||||
error("ngx_upstream_lua module required")
|
||||
end
|
||||
|
||||
local ok, new_tab = pcall(require, "table.new")
|
||||
if not ok or type(new_tab) ~= "function" then
|
||||
new_tab = function (narr, nrec) return {} end
|
||||
end
|
||||
|
||||
local set_peer_down = upstream.set_peer_down
|
||||
local get_primary_peers = upstream.get_primary_peers
|
||||
local get_backup_peers = upstream.get_backup_peers
|
||||
local get_upstreams = upstream.get_upstreams
|
||||
|
||||
local upstream_checker_statuses = {}
|
||||
|
||||
local function warn(...)
|
||||
log(WARN, "healthcheck: ", ...)
|
||||
end
|
||||
|
||||
local function errlog(...)
|
||||
log(ERR, "healthcheck: ", ...)
|
||||
end
|
||||
|
||||
local function debug(...)
|
||||
-- print("debug mode: ", debug_mode)
|
||||
if debug_mode then
|
||||
log(DEBUG, "healthcheck: ", ...)
|
||||
end
|
||||
end
|
||||
|
||||
local function gen_peer_key(prefix, u, is_backup, id)
|
||||
if is_backup then
|
||||
return prefix .. u .. ":b" .. id
|
||||
end
|
||||
return prefix .. u .. ":p" .. id
|
||||
end
|
||||
|
||||
local function set_peer_down_globally(ctx, is_backup, id, value)
|
||||
local u = ctx.upstream
|
||||
local dict = ctx.dict
|
||||
local ok, err = set_peer_down(u, is_backup, id, value)
|
||||
if not ok then
|
||||
errlog("failed to set peer down: ", err)
|
||||
end
|
||||
|
||||
if not ctx.new_version then
|
||||
ctx.new_version = true
|
||||
end
|
||||
|
||||
local key = gen_peer_key("d:", u, is_backup, id)
|
||||
local ok, err = dict:set(key, value)
|
||||
if not ok then
|
||||
errlog("failed to set peer down state: ", err)
|
||||
end
|
||||
end
|
||||
|
||||
local function peer_fail(ctx, is_backup, id, peer)
|
||||
debug("peer ", peer.name, " was checked to be not ok")
|
||||
|
||||
local u = ctx.upstream
|
||||
local dict = ctx.dict
|
||||
|
||||
local key = gen_peer_key("nok:", u, is_backup, id)
|
||||
local fails, err = dict:get(key)
|
||||
if not fails then
|
||||
if err then
|
||||
errlog("failed to get peer nok key: ", err)
|
||||
return
|
||||
end
|
||||
fails = 1
|
||||
|
||||
-- below may have a race condition, but it is fine for our
|
||||
-- purpose here.
|
||||
local ok, err = dict:set(key, 1)
|
||||
if not ok then
|
||||
errlog("failed to set peer nok key: ", err)
|
||||
end
|
||||
else
|
||||
fails = fails + 1
|
||||
local ok, err = dict:incr(key, 1)
|
||||
if not ok then
|
||||
errlog("failed to incr peer nok key: ", err)
|
||||
end
|
||||
end
|
||||
|
||||
if fails == 1 then
|
||||
key = gen_peer_key("ok:", u, is_backup, id)
|
||||
local succ, err = dict:get(key)
|
||||
if not succ or succ == 0 then
|
||||
if err then
|
||||
errlog("failed to get peer ok key: ", err)
|
||||
return
|
||||
end
|
||||
else
|
||||
local ok, err = dict:set(key, 0)
|
||||
if not ok then
|
||||
errlog("failed to set peer ok key: ", err)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
-- print("ctx fall: ", ctx.fall, ", peer down: ", peer.down,
|
||||
-- ", fails: ", fails)
|
||||
|
||||
if not peer.down and fails >= ctx.fall then
|
||||
warn("peer ", peer.name, " is turned down after ", fails,
|
||||
" failure(s)")
|
||||
peer.down = true
|
||||
set_peer_down_globally(ctx, is_backup, id, true)
|
||||
end
|
||||
end
|
||||
|
||||
local function peer_ok(ctx, is_backup, id, peer)
|
||||
debug("peer ", peer.name, " was checked to be ok")
|
||||
|
||||
local u = ctx.upstream
|
||||
local dict = ctx.dict
|
||||
|
||||
local key = gen_peer_key("ok:", u, is_backup, id)
|
||||
local succ, err = dict:get(key)
|
||||
if not succ then
|
||||
if err then
|
||||
errlog("failed to get peer ok key: ", err)
|
||||
return
|
||||
end
|
||||
succ = 1
|
||||
|
||||
-- below may have a race condition, but it is fine for our
|
||||
-- purpose here.
|
||||
local ok, err = dict:set(key, 1)
|
||||
if not ok then
|
||||
errlog("failed to set peer ok key: ", err)
|
||||
end
|
||||
else
|
||||
succ = succ + 1
|
||||
local ok, err = dict:incr(key, 1)
|
||||
if not ok then
|
||||
errlog("failed to incr peer ok key: ", err)
|
||||
end
|
||||
end
|
||||
|
||||
if succ == 1 then
|
||||
key = gen_peer_key("nok:", u, is_backup, id)
|
||||
local fails, err = dict:get(key)
|
||||
if not fails or fails == 0 then
|
||||
if err then
|
||||
errlog("failed to get peer nok key: ", err)
|
||||
return
|
||||
end
|
||||
else
|
||||
local ok, err = dict:set(key, 0)
|
||||
if not ok then
|
||||
errlog("failed to set peer nok key: ", err)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
if peer.down and succ >= ctx.rise then
|
||||
warn("peer ", peer.name, " is turned up after ", succ,
|
||||
" success(es)")
|
||||
peer.down = nil
|
||||
set_peer_down_globally(ctx, is_backup, id, nil)
|
||||
end
|
||||
end
|
||||
|
||||
-- shortcut error function for check_peer()
|
||||
local function peer_error(ctx, is_backup, id, peer, ...)
|
||||
if not peer.down then
|
||||
errlog(...)
|
||||
end
|
||||
peer_fail(ctx, is_backup, id, peer)
|
||||
end
|
||||
|
||||
local function check_peer(ctx, id, peer, is_backup)
|
||||
local ok
|
||||
local name = peer.name
|
||||
local statuses = ctx.statuses
|
||||
local req = ctx.http_req
|
||||
|
||||
local sock, err = stream_sock()
|
||||
if not sock then
|
||||
errlog("failed to create stream socket: ", err)
|
||||
return
|
||||
end
|
||||
|
||||
sock:settimeout(ctx.timeout)
|
||||
|
||||
if peer.host then
|
||||
-- print("peer port: ", peer.port)
|
||||
ok, err = sock:connect(peer.host, peer.port)
|
||||
else
|
||||
ok, err = sock:connect(name)
|
||||
end
|
||||
if not ok then
|
||||
if not peer.down then
|
||||
errlog("failed to connect to ", name, ": ", err)
|
||||
end
|
||||
return peer_fail(ctx, is_backup, id, peer)
|
||||
end
|
||||
|
||||
local bytes, err = sock:send(req)
|
||||
if not bytes then
|
||||
return peer_error(ctx, is_backup, id, peer,
|
||||
"failed to send request to ", name, ": ", err)
|
||||
end
|
||||
|
||||
local status_line, err = sock:receive()
|
||||
if not status_line then
|
||||
peer_error(ctx, is_backup, id, peer,
|
||||
"failed to receive status line from ", name, ": ", err)
|
||||
if err == "timeout" then
|
||||
sock:close() -- timeout errors do not close the socket.
|
||||
end
|
||||
return
|
||||
end
|
||||
|
||||
if statuses then
|
||||
local from, to, err = re_find(status_line,
|
||||
[[^HTTP/\d+\.\d+\s+(\d+)]],
|
||||
"joi", nil, 1)
|
||||
if err then
|
||||
errlog("failed to parse status line: ", err)
|
||||
end
|
||||
|
||||
if not from then
|
||||
peer_error(ctx, is_backup, id, peer,
|
||||
"bad status line from ", name, ": ",
|
||||
status_line)
|
||||
sock:close()
|
||||
return
|
||||
end
|
||||
|
||||
local status = tonumber(sub(status_line, from, to))
|
||||
if not statuses[status] then
|
||||
peer_error(ctx, is_backup, id, peer, "bad status code from ",
|
||||
name, ": ", status)
|
||||
sock:close()
|
||||
return
|
||||
end
|
||||
end
|
||||
|
||||
peer_ok(ctx, is_backup, id, peer)
|
||||
sock:close()
|
||||
end
|
||||
|
||||
local function check_peer_range(ctx, from, to, peers, is_backup)
|
||||
for i = from, to do
|
||||
check_peer(ctx, i - 1, peers[i], is_backup)
|
||||
end
|
||||
end
|
||||
|
||||
local function check_peers(ctx, peers, is_backup)
|
||||
local n = #peers
|
||||
if n == 0 then
|
||||
return
|
||||
end
|
||||
|
||||
local concur = ctx.concurrency
|
||||
if concur <= 1 then
|
||||
for i = 1, n do
|
||||
check_peer(ctx, i - 1, peers[i], is_backup)
|
||||
end
|
||||
else
|
||||
local threads
|
||||
local nthr
|
||||
|
||||
if n <= concur then
|
||||
nthr = n - 1
|
||||
threads = new_tab(nthr, 0)
|
||||
for i = 1, nthr do
|
||||
|
||||
if debug_mode then
|
||||
debug("spawn a thread checking ",
|
||||
is_backup and "backup" or "primary", " peer ", i - 1)
|
||||
end
|
||||
|
||||
threads[i] = spawn(check_peer, ctx, i - 1, peers[i], is_backup)
|
||||
end
|
||||
-- use the current "light thread" to run the last task
|
||||
if debug_mode then
|
||||
debug("check ", is_backup and "backup" or "primary", " peer ",
|
||||
n - 1)
|
||||
end
|
||||
check_peer(ctx, n - 1, peers[n], is_backup)
|
||||
|
||||
else
|
||||
local group_size = ceil(n / concur)
|
||||
nthr = ceil(n / group_size) - 1
|
||||
|
||||
threads = new_tab(nthr, 0)
|
||||
local from = 1
|
||||
local rest = n
|
||||
for i = 1, nthr do
|
||||
local to
|
||||
if rest >= group_size then
|
||||
rest = rest - group_size
|
||||
to = from + group_size - 1
|
||||
else
|
||||
rest = 0
|
||||
to = from + rest - 1
|
||||
end
|
||||
|
||||
if debug_mode then
|
||||
debug("spawn a thread checking ",
|
||||
is_backup and "backup" or "primary", " peers ",
|
||||
from - 1, " to ", to - 1)
|
||||
end
|
||||
|
||||
threads[i] = spawn(check_peer_range, ctx, from, to, peers,
|
||||
is_backup)
|
||||
from = from + group_size
|
||||
if rest == 0 then
|
||||
break
|
||||
end
|
||||
end
|
||||
if rest > 0 then
|
||||
local to = from + rest - 1
|
||||
|
||||
if debug_mode then
|
||||
debug("check ", is_backup and "backup" or "primary",
|
||||
" peers ", from - 1, " to ", to - 1)
|
||||
end
|
||||
|
||||
check_peer_range(ctx, from, to, peers, is_backup)
|
||||
end
|
||||
end
|
||||
|
||||
if nthr and nthr > 0 then
|
||||
for i = 1, nthr do
|
||||
local t = threads[i]
|
||||
if t then
|
||||
wait(t)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local function upgrade_peers_version(ctx, peers, is_backup)
|
||||
local dict = ctx.dict
|
||||
local u = ctx.upstream
|
||||
local n = #peers
|
||||
for i = 1, n do
|
||||
local peer = peers[i]
|
||||
local id = i - 1
|
||||
local key = gen_peer_key("d:", u, is_backup, id)
|
||||
local down = false
|
||||
local res, err = dict:get(key)
|
||||
if not res then
|
||||
if err then
|
||||
errlog("failed to get peer down state: ", err)
|
||||
end
|
||||
else
|
||||
down = true
|
||||
end
|
||||
if (peer.down and not down) or (not peer.down and down) then
|
||||
local ok, err = set_peer_down(u, is_backup, id, down)
|
||||
if not ok then
|
||||
errlog("failed to set peer down: ", err)
|
||||
else
|
||||
-- update our cache too
|
||||
peer.down = down
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local function check_peers_updates(ctx)
|
||||
local dict = ctx.dict
|
||||
local u = ctx.upstream
|
||||
local key = "v:" .. u
|
||||
local ver, err = dict:get(key)
|
||||
if not ver then
|
||||
if err then
|
||||
errlog("failed to get peers version: ", err)
|
||||
return
|
||||
end
|
||||
|
||||
if ctx.version > 0 then
|
||||
ctx.new_version = true
|
||||
end
|
||||
|
||||
elseif ctx.version < ver then
|
||||
debug("upgrading peers version to ", ver)
|
||||
upgrade_peers_version(ctx, ctx.primary_peers, false);
|
||||
upgrade_peers_version(ctx, ctx.backup_peers, true);
|
||||
ctx.version = ver
|
||||
end
|
||||
end
|
||||
|
||||
local function get_lock(ctx)
|
||||
local dict = ctx.dict
|
||||
local key = "l:" .. ctx.upstream
|
||||
|
||||
-- the lock is held for the whole interval to prevent multiple
|
||||
-- worker processes from sending the test request simultaneously.
|
||||
-- here we substract the lock expiration time by 1ms to prevent
|
||||
-- a race condition with the next timer event.
|
||||
local ok, err = dict:add(key, true, ctx.interval - 0.001)
|
||||
if not ok then
|
||||
if err == "exists" then
|
||||
return nil
|
||||
end
|
||||
errlog("failed to add key \"", key, "\": ", err)
|
||||
return nil
|
||||
end
|
||||
return true
|
||||
end
|
||||
|
||||
local function do_check(ctx)
|
||||
debug("healthcheck: run a check cycle")
|
||||
|
||||
check_peers_updates(ctx)
|
||||
|
||||
if get_lock(ctx) then
|
||||
check_peers(ctx, ctx.primary_peers, false)
|
||||
check_peers(ctx, ctx.backup_peers, true)
|
||||
end
|
||||
|
||||
if ctx.new_version then
|
||||
local key = "v:" .. ctx.upstream
|
||||
local dict = ctx.dict
|
||||
|
||||
if debug_mode then
|
||||
debug("publishing peers version ", ctx.version + 1)
|
||||
end
|
||||
|
||||
dict:add(key, 0)
|
||||
local new_ver, err = dict:incr(key, 1)
|
||||
if not new_ver then
|
||||
errlog("failed to publish new peers version: ", err)
|
||||
end
|
||||
|
||||
ctx.version = new_ver
|
||||
ctx.new_version = nil
|
||||
end
|
||||
end
|
||||
|
||||
local function update_upstream_checker_status(upstream, success)
|
||||
local cnt = upstream_checker_statuses[upstream]
|
||||
if not cnt then
|
||||
cnt = 0
|
||||
end
|
||||
|
||||
if success then
|
||||
cnt = cnt + 1
|
||||
else
|
||||
cnt = cnt - 1
|
||||
end
|
||||
|
||||
upstream_checker_statuses[upstream] = cnt
|
||||
end
|
||||
|
||||
local check
|
||||
check = function (premature, ctx)
|
||||
if premature then
|
||||
return
|
||||
end
|
||||
|
||||
local ok, err = pcall(do_check, ctx)
|
||||
if not ok then
|
||||
errlog("failed to run healthcheck cycle: ", err)
|
||||
end
|
||||
|
||||
local ok, err = new_timer(ctx.interval, check, ctx)
|
||||
if not ok then
|
||||
if err ~= "process exiting" then
|
||||
errlog("failed to create timer: ", err)
|
||||
end
|
||||
|
||||
update_upstream_checker_status(ctx.upstream, false)
|
||||
return
|
||||
end
|
||||
end
|
||||
|
||||
local function preprocess_peers(peers)
|
||||
local n = #peers
|
||||
for i = 1, n do
|
||||
local p = peers[i]
|
||||
local name = p.name
|
||||
|
||||
if name then
|
||||
local from, to, err = re_find(name, [[^(.*):\d+$]], "jo", nil, 1)
|
||||
if from then
|
||||
p.host = sub(name, 1, to)
|
||||
p.port = tonumber(sub(name, to + 2))
|
||||
end
|
||||
end
|
||||
end
|
||||
return peers
|
||||
end
|
||||
|
||||
function _M.spawn_checker(opts)
|
||||
local typ = opts.type
|
||||
if not typ then
|
||||
return nil, "\"type\" option required"
|
||||
end
|
||||
|
||||
if typ ~= "http" then
|
||||
return nil, "only \"http\" type is supported right now"
|
||||
end
|
||||
|
||||
local http_req = opts.http_req
|
||||
if not http_req then
|
||||
return nil, "\"http_req\" option required"
|
||||
end
|
||||
|
||||
local timeout = opts.timeout
|
||||
if not timeout then
|
||||
timeout = 1000
|
||||
end
|
||||
|
||||
local interval = opts.interval
|
||||
if not interval then
|
||||
interval = 1
|
||||
|
||||
else
|
||||
interval = interval / 1000
|
||||
if interval < 0.002 then -- minimum 2ms
|
||||
interval = 0.002
|
||||
end
|
||||
end
|
||||
|
||||
local valid_statuses = opts.valid_statuses
|
||||
local statuses
|
||||
if valid_statuses then
|
||||
statuses = new_tab(0, #valid_statuses)
|
||||
for _, status in ipairs(valid_statuses) do
|
||||
-- print("found good status ", status)
|
||||
statuses[status] = true
|
||||
end
|
||||
end
|
||||
|
||||
-- debug("interval: ", interval)
|
||||
|
||||
local concur = opts.concurrency
|
||||
if not concur then
|
||||
concur = 1
|
||||
end
|
||||
|
||||
local fall = opts.fall
|
||||
if not fall then
|
||||
fall = 5
|
||||
end
|
||||
|
||||
local rise = opts.rise
|
||||
if not rise then
|
||||
rise = 2
|
||||
end
|
||||
|
||||
local shm = opts.shm
|
||||
if not shm then
|
||||
return nil, "\"shm\" option required"
|
||||
end
|
||||
|
||||
local dict = shared[shm]
|
||||
if not dict then
|
||||
return nil, "shm \"" .. tostring(shm) .. "\" not found"
|
||||
end
|
||||
|
||||
local u = opts.upstream
|
||||
if not u then
|
||||
return nil, "no upstream specified"
|
||||
end
|
||||
|
||||
local ppeers, err = get_primary_peers(u)
|
||||
if not ppeers then
|
||||
return nil, "failed to get primary peers: " .. err
|
||||
end
|
||||
|
||||
local bpeers, err = get_backup_peers(u)
|
||||
if not bpeers then
|
||||
return nil, "failed to get backup peers: " .. err
|
||||
end
|
||||
|
||||
local ctx = {
|
||||
upstream = u,
|
||||
primary_peers = preprocess_peers(ppeers),
|
||||
backup_peers = preprocess_peers(bpeers),
|
||||
http_req = http_req,
|
||||
timeout = timeout,
|
||||
interval = interval,
|
||||
dict = dict,
|
||||
fall = fall,
|
||||
rise = rise,
|
||||
statuses = statuses,
|
||||
version = 0,
|
||||
concurrency = concur,
|
||||
}
|
||||
|
||||
if debug_mode and opts.no_timer then
|
||||
check(nil, ctx)
|
||||
|
||||
else
|
||||
local ok, err = new_timer(0, check, ctx)
|
||||
if not ok then
|
||||
return nil, "failed to create timer: " .. err
|
||||
end
|
||||
end
|
||||
|
||||
update_upstream_checker_status(u, true)
|
||||
|
||||
return true
|
||||
end
|
||||
|
||||
local function gen_peers_status_info(peers, bits, idx)
|
||||
local npeers = #peers
|
||||
for i = 1, npeers do
|
||||
local peer = peers[i]
|
||||
bits[idx] = " "
|
||||
bits[idx + 1] = peer.name
|
||||
if peer.down then
|
||||
bits[idx + 2] = " DOWN\n"
|
||||
else
|
||||
bits[idx + 2] = " up\n"
|
||||
end
|
||||
idx = idx + 3
|
||||
end
|
||||
return idx
|
||||
end
|
||||
|
||||
function _M.status_page()
|
||||
-- generate an HTML page
|
||||
local us, err = get_upstreams()
|
||||
if not us then
|
||||
return "failed to get upstream names: " .. err
|
||||
end
|
||||
|
||||
local n = #us
|
||||
local bits = new_tab(n * 20, 0)
|
||||
local idx = 1
|
||||
for i = 1, n do
|
||||
if i > 1 then
|
||||
bits[idx] = "\n"
|
||||
idx = idx + 1
|
||||
end
|
||||
|
||||
local u = us[i]
|
||||
|
||||
bits[idx] = "Upstream "
|
||||
bits[idx + 1] = u
|
||||
idx = idx + 2
|
||||
|
||||
local ncheckers = upstream_checker_statuses[u]
|
||||
if not ncheckers or ncheckers == 0 then
|
||||
bits[idx] = " (NO checkers)"
|
||||
idx = idx + 1
|
||||
end
|
||||
|
||||
bits[idx] = "\n Primary Peers\n"
|
||||
idx = idx + 1
|
||||
|
||||
local peers, err = get_primary_peers(u)
|
||||
if not peers then
|
||||
return "failed to get primary peers in upstream " .. u .. ": "
|
||||
.. err
|
||||
end
|
||||
|
||||
idx = gen_peers_status_info(peers, bits, idx)
|
||||
|
||||
bits[idx] = " Backup Peers\n"
|
||||
idx = idx + 1
|
||||
|
||||
peers, err = get_backup_peers(u)
|
||||
if not peers then
|
||||
return "failed to get backup peers in upstream " .. u .. ": "
|
||||
.. err
|
||||
end
|
||||
|
||||
idx = gen_peers_status_info(peers, bits, idx)
|
||||
end
|
||||
return concat(bits)
|
||||
end
|
||||
|
||||
return _M
|
||||
353
resty/websocket/client.lua
Normal file
353
resty/websocket/client.lua
Normal file
@@ -0,0 +1,353 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
-- FIXME: this library is very rough and is currently just for testing
|
||||
-- the websocket server.
|
||||
|
||||
|
||||
local wbproto = require "resty.websocket.protocol"
|
||||
local bit = require "bit"
|
||||
|
||||
|
||||
local _recv_frame = wbproto.recv_frame
|
||||
local _send_frame = wbproto.send_frame
|
||||
local new_tab = wbproto.new_tab
|
||||
local tcp = ngx.socket.tcp
|
||||
local re_match = ngx.re.match
|
||||
local encode_base64 = ngx.encode_base64
|
||||
local concat = table.concat
|
||||
local char = string.char
|
||||
local str_find = string.find
|
||||
local rand = math.random
|
||||
local rshift = bit.rshift
|
||||
local band = bit.band
|
||||
local setmetatable = setmetatable
|
||||
local type = type
|
||||
local debug = ngx.config.debug
|
||||
local ngx_log = ngx.log
|
||||
local ngx_DEBUG = ngx.DEBUG
|
||||
local ssl_support = true
|
||||
|
||||
if not ngx.config
|
||||
or not ngx.config.ngx_lua_version
|
||||
or ngx.config.ngx_lua_version < 9011
|
||||
then
|
||||
ssl_support = false
|
||||
end
|
||||
|
||||
local _M = new_tab(0, 13)
|
||||
_M._VERSION = '0.08'
|
||||
|
||||
|
||||
local mt = { __index = _M }
|
||||
|
||||
|
||||
function _M.new(self, opts)
|
||||
local sock, err = tcp()
|
||||
if not sock then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
local max_payload_len, send_unmasked, timeout
|
||||
if opts then
|
||||
max_payload_len = opts.max_payload_len
|
||||
send_unmasked = opts.send_unmasked
|
||||
timeout = opts.timeout
|
||||
|
||||
if timeout then
|
||||
sock:settimeout(timeout)
|
||||
end
|
||||
end
|
||||
|
||||
return setmetatable({
|
||||
sock = sock,
|
||||
max_payload_len = max_payload_len or 65535,
|
||||
send_unmasked = send_unmasked,
|
||||
}, mt)
|
||||
end
|
||||
|
||||
|
||||
function _M.connect(self, uri, opts)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
local m, err = re_match(uri, [[^(wss?)://([^:/]+)(?::(\d+))?(.*)]], "jo")
|
||||
if not m then
|
||||
if err then
|
||||
return nil, "failed to match the uri: " .. err
|
||||
end
|
||||
|
||||
return nil, "bad websocket uri"
|
||||
end
|
||||
|
||||
local scheme = m[1]
|
||||
local host = m[2]
|
||||
local port = m[3]
|
||||
local path = m[4]
|
||||
|
||||
-- ngx.say("host: ", host)
|
||||
-- ngx.say("port: ", port)
|
||||
|
||||
if not port then
|
||||
port = 80
|
||||
end
|
||||
|
||||
if path == "" then
|
||||
path = "/"
|
||||
end
|
||||
|
||||
local ssl_verify, headers, proto_header, origin_header, sock_opts = false
|
||||
|
||||
if opts then
|
||||
local protos = opts.protocols
|
||||
if protos then
|
||||
if type(protos) == "table" then
|
||||
proto_header = "\r\nSec-WebSocket-Protocol: "
|
||||
.. concat(protos, ",")
|
||||
|
||||
else
|
||||
proto_header = "\r\nSec-WebSocket-Protocol: " .. protos
|
||||
end
|
||||
end
|
||||
|
||||
local origin = opts.origin
|
||||
if origin then
|
||||
origin_header = "\r\nOrigin: " .. origin
|
||||
end
|
||||
|
||||
local pool = opts.pool
|
||||
if pool then
|
||||
sock_opts = { pool = pool }
|
||||
end
|
||||
|
||||
if opts.ssl_verify then
|
||||
if not ssl_support then
|
||||
return nil, "ngx_lua 0.9.11+ required for SSL sockets"
|
||||
end
|
||||
ssl_verify = true
|
||||
end
|
||||
|
||||
if opts.headers then
|
||||
headers = opts.headers
|
||||
if type(headers) ~= "table" then
|
||||
return nil, "custom headers must be a table"
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
local ok, err
|
||||
if sock_opts then
|
||||
ok, err = sock:connect(host, port, sock_opts)
|
||||
else
|
||||
ok, err = sock:connect(host, port)
|
||||
end
|
||||
if not ok then
|
||||
return nil, "failed to connect: " .. err
|
||||
end
|
||||
|
||||
if scheme == "wss" then
|
||||
if not ssl_support then
|
||||
return nil, "ngx_lua 0.9.11+ required for SSL sockets"
|
||||
end
|
||||
ok, err = sock:sslhandshake(false, host, ssl_verify)
|
||||
if not ok then
|
||||
return nil, "ssl handshake failed: " .. err
|
||||
end
|
||||
end
|
||||
|
||||
-- check for connections from pool:
|
||||
|
||||
local count, err = sock:getreusedtimes()
|
||||
if not count then
|
||||
return nil, "failed to get reused times: " .. err
|
||||
end
|
||||
if count > 0 then
|
||||
-- being a reused connection (must have done handshake)
|
||||
return 1
|
||||
end
|
||||
|
||||
local custom_headers
|
||||
if headers then
|
||||
custom_headers = concat(headers, "\r\n")
|
||||
custom_headers = "\r\n" .. custom_headers
|
||||
end
|
||||
|
||||
-- do the websocket handshake:
|
||||
|
||||
local bytes = char(rand(256) - 1, rand(256) - 1, rand(256) - 1,
|
||||
rand(256) - 1, rand(256) - 1, rand(256) - 1,
|
||||
rand(256) - 1, rand(256) - 1, rand(256) - 1,
|
||||
rand(256) - 1, rand(256) - 1, rand(256) - 1,
|
||||
rand(256) - 1, rand(256) - 1, rand(256) - 1,
|
||||
rand(256) - 1)
|
||||
|
||||
local key = encode_base64(bytes)
|
||||
local req = "GET " .. path .. " HTTP/1.1\r\nUpgrade: websocket\r\nHost: "
|
||||
.. host .. ":" .. port
|
||||
.. "\r\nSec-WebSocket-Key: " .. key
|
||||
.. (proto_header or "")
|
||||
.. "\r\nSec-WebSocket-Version: 13"
|
||||
.. (origin_header or "")
|
||||
.. "\r\nConnection: Upgrade"
|
||||
.. (custom_headers or "")
|
||||
.. "\r\n\r\n"
|
||||
|
||||
local bytes, err = sock:send(req)
|
||||
if not bytes then
|
||||
return nil, "failed to send the handshake request: " .. err
|
||||
end
|
||||
|
||||
local header_reader = sock:receiveuntil("\r\n\r\n")
|
||||
-- FIXME: check for too big response headers
|
||||
local header, err, partial = header_reader()
|
||||
if not header then
|
||||
return nil, "failed to receive response header: " .. err
|
||||
end
|
||||
|
||||
-- error("header: " .. header)
|
||||
|
||||
-- FIXME: verify the response headers
|
||||
|
||||
m, err = re_match(header, [[^\s*HTTP/1\.1\s+]], "jo")
|
||||
if not m then
|
||||
return nil, "bad HTTP response status line: " .. header
|
||||
end
|
||||
|
||||
return 1
|
||||
end
|
||||
|
||||
|
||||
function _M.set_timeout(self, time)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, nil, "not initialized yet"
|
||||
end
|
||||
|
||||
return sock:settimeout(time)
|
||||
end
|
||||
|
||||
|
||||
function _M.recv_frame(self)
|
||||
if self.fatal then
|
||||
return nil, nil, "fatal error already happened"
|
||||
end
|
||||
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, nil, "not initialized yet"
|
||||
end
|
||||
|
||||
local data, typ, err = _recv_frame(sock, self.max_payload_len, false)
|
||||
if not data and not str_find(err, ": timeout", 1, true) then
|
||||
self.fatal = true
|
||||
end
|
||||
return data, typ, err
|
||||
end
|
||||
|
||||
|
||||
local function send_frame(self, fin, opcode, payload)
|
||||
if self.fatal then
|
||||
return nil, "fatal error already happened"
|
||||
end
|
||||
|
||||
if self.closed then
|
||||
return nil, "already closed"
|
||||
end
|
||||
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized yet"
|
||||
end
|
||||
|
||||
local bytes, err = _send_frame(sock, fin, opcode, payload,
|
||||
self.max_payload_len,
|
||||
not self.send_unmasked)
|
||||
if not bytes then
|
||||
self.fatal = true
|
||||
end
|
||||
return bytes, err
|
||||
end
|
||||
_M.send_frame = send_frame
|
||||
|
||||
|
||||
function _M.send_text(self, data)
|
||||
return send_frame(self, true, 0x1, data)
|
||||
end
|
||||
|
||||
|
||||
function _M.send_binary(self, data)
|
||||
return send_frame(self, true, 0x2, data)
|
||||
end
|
||||
|
||||
|
||||
local function send_close(self, code, msg)
|
||||
local payload
|
||||
if code then
|
||||
if type(code) ~= "number" or code > 0x7fff then
|
||||
return nil, "bad status code"
|
||||
end
|
||||
payload = char(band(rshift(code, 8), 0xff), band(code, 0xff))
|
||||
.. (msg or "")
|
||||
end
|
||||
|
||||
if debug then
|
||||
ngx_log(ngx_DEBUG, "sending the close frame")
|
||||
end
|
||||
|
||||
local bytes, err = send_frame(self, true, 0x8, payload)
|
||||
|
||||
if not bytes then
|
||||
self.fatal = true
|
||||
end
|
||||
|
||||
self.closed = true
|
||||
|
||||
return bytes, err
|
||||
end
|
||||
_M.send_close = send_close
|
||||
|
||||
|
||||
function _M.send_ping(self, data)
|
||||
return send_frame(self, true, 0x9, data)
|
||||
end
|
||||
|
||||
|
||||
function _M.send_pong(self, data)
|
||||
return send_frame(self, true, 0xa, data)
|
||||
end
|
||||
|
||||
|
||||
function _M.close(self)
|
||||
if self.fatal then
|
||||
return nil, "fatal error already happened"
|
||||
end
|
||||
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
if not self.closed then
|
||||
local bytes, err = send_close(self)
|
||||
if not bytes then
|
||||
return nil, "failed to send close frame: " .. err
|
||||
end
|
||||
end
|
||||
|
||||
return sock:close()
|
||||
end
|
||||
|
||||
|
||||
function _M.set_keepalive(self, ...)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized"
|
||||
end
|
||||
|
||||
return sock:setkeepalive(...)
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
345
resty/websocket/protocol.lua
Normal file
345
resty/websocket/protocol.lua
Normal file
@@ -0,0 +1,345 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
local bit = require "bit"
|
||||
local ffi = require "ffi"
|
||||
|
||||
|
||||
local byte = string.byte
|
||||
local char = string.char
|
||||
local sub = string.sub
|
||||
local band = bit.band
|
||||
local bor = bit.bor
|
||||
local bxor = bit.bxor
|
||||
local lshift = bit.lshift
|
||||
local rshift = bit.rshift
|
||||
--local tohex = bit.tohex
|
||||
local tostring = tostring
|
||||
local concat = table.concat
|
||||
local rand = math.random
|
||||
local type = type
|
||||
local debug = ngx.config.debug
|
||||
local ngx_log = ngx.log
|
||||
local ngx_DEBUG = ngx.DEBUG
|
||||
local ffi_new = ffi.new
|
||||
local ffi_string = ffi.string
|
||||
|
||||
|
||||
local ok, new_tab = pcall(require, "table.new")
|
||||
if not ok then
|
||||
new_tab = function (narr, nrec) return {} end
|
||||
end
|
||||
|
||||
|
||||
local _M = new_tab(0, 5)
|
||||
|
||||
_M.new_tab = new_tab
|
||||
_M._VERSION = '0.08'
|
||||
|
||||
|
||||
local types = {
|
||||
[0x0] = "continuation",
|
||||
[0x1] = "text",
|
||||
[0x2] = "binary",
|
||||
[0x8] = "close",
|
||||
[0x9] = "ping",
|
||||
[0xa] = "pong",
|
||||
}
|
||||
|
||||
local str_buf_size = 4096
|
||||
local str_buf
|
||||
local c_buf_type = ffi.typeof("char[?]")
|
||||
|
||||
|
||||
local function get_string_buf(size)
|
||||
if size > str_buf_size then
|
||||
return ffi_new(c_buf_type, size)
|
||||
end
|
||||
if not str_buf then
|
||||
str_buf = ffi_new(c_buf_type, str_buf_size)
|
||||
end
|
||||
|
||||
return str_buf
|
||||
end
|
||||
|
||||
|
||||
function _M.recv_frame(sock, max_payload_len, force_masking)
|
||||
local data, err = sock:receive(2)
|
||||
if not data then
|
||||
return nil, nil, "failed to receive the first 2 bytes: " .. err
|
||||
end
|
||||
|
||||
local fst, snd = byte(data, 1, 2)
|
||||
|
||||
local fin = band(fst, 0x80) ~= 0
|
||||
-- print("fin: ", fin)
|
||||
|
||||
if band(fst, 0x70) ~= 0 then
|
||||
return nil, nil, "bad RSV1, RSV2, or RSV3 bits"
|
||||
end
|
||||
|
||||
local opcode = band(fst, 0x0f)
|
||||
-- print("opcode: ", tohex(opcode))
|
||||
|
||||
if opcode >= 0x3 and opcode <= 0x7 then
|
||||
return nil, nil, "reserved non-control frames"
|
||||
end
|
||||
|
||||
if opcode >= 0xb and opcode <= 0xf then
|
||||
return nil, nil, "reserved control frames"
|
||||
end
|
||||
|
||||
local mask = band(snd, 0x80) ~= 0
|
||||
|
||||
if debug then
|
||||
ngx_log(ngx_DEBUG, "recv_frame: mask bit: ", mask and 1 or 0)
|
||||
end
|
||||
|
||||
if force_masking and not mask then
|
||||
return nil, nil, "frame unmasked"
|
||||
end
|
||||
|
||||
local payload_len = band(snd, 0x7f)
|
||||
-- print("payload len: ", payload_len)
|
||||
|
||||
if payload_len == 126 then
|
||||
local data, err = sock:receive(2)
|
||||
if not data then
|
||||
return nil, nil, "failed to receive the 2 byte payload length: "
|
||||
.. (err or "unknown")
|
||||
end
|
||||
|
||||
payload_len = bor(lshift(byte(data, 1), 8), byte(data, 2))
|
||||
|
||||
elseif payload_len == 127 then
|
||||
local data, err = sock:receive(8)
|
||||
if not data then
|
||||
return nil, nil, "failed to receive the 8 byte payload length: "
|
||||
.. (err or "unknown")
|
||||
end
|
||||
|
||||
if byte(data, 1) ~= 0
|
||||
or byte(data, 2) ~= 0
|
||||
or byte(data, 3) ~= 0
|
||||
or byte(data, 4) ~= 0
|
||||
then
|
||||
return nil, nil, "payload len too large"
|
||||
end
|
||||
|
||||
local fifth = byte(data, 5)
|
||||
if band(fifth, 0x80) ~= 0 then
|
||||
return nil, nil, "payload len too large"
|
||||
end
|
||||
|
||||
payload_len = bor(lshift(fifth, 24),
|
||||
lshift(byte(data, 6), 16),
|
||||
lshift(byte(data, 7), 8),
|
||||
byte(data, 8))
|
||||
end
|
||||
|
||||
if band(opcode, 0x8) ~= 0 then
|
||||
-- being a control frame
|
||||
if payload_len > 125 then
|
||||
return nil, nil, "too long payload for control frame"
|
||||
end
|
||||
|
||||
if not fin then
|
||||
return nil, nil, "fragmented control frame"
|
||||
end
|
||||
end
|
||||
|
||||
-- print("payload len: ", payload_len, ", max payload len: ",
|
||||
-- max_payload_len)
|
||||
|
||||
if payload_len > max_payload_len then
|
||||
return nil, nil, "exceeding max payload len"
|
||||
end
|
||||
|
||||
local rest
|
||||
if mask then
|
||||
rest = payload_len + 4
|
||||
|
||||
else
|
||||
rest = payload_len
|
||||
end
|
||||
-- print("rest: ", rest)
|
||||
|
||||
local data
|
||||
if rest > 0 then
|
||||
data, err = sock:receive(rest)
|
||||
if not data then
|
||||
return nil, nil, "failed to read masking-len and payload: "
|
||||
.. (err or "unknown")
|
||||
end
|
||||
else
|
||||
data = ""
|
||||
end
|
||||
|
||||
-- print("received rest")
|
||||
|
||||
if opcode == 0x8 then
|
||||
-- being a close frame
|
||||
if payload_len > 0 then
|
||||
if payload_len < 2 then
|
||||
return nil, nil, "close frame with a body must carry a 2-byte"
|
||||
.. " status code"
|
||||
end
|
||||
|
||||
local msg, code
|
||||
if mask then
|
||||
local fst = bxor(byte(data, 4 + 1), byte(data, 1))
|
||||
local snd = bxor(byte(data, 4 + 2), byte(data, 2))
|
||||
code = bor(lshift(fst, 8), snd)
|
||||
|
||||
if payload_len > 2 then
|
||||
-- TODO string.buffer optimizations
|
||||
local bytes = get_string_buf(payload_len - 2)
|
||||
for i = 3, payload_len do
|
||||
bytes[i - 3] = bxor(byte(data, 4 + i),
|
||||
byte(data, (i - 1) % 4 + 1))
|
||||
end
|
||||
msg = ffi_string(bytes, payload_len - 2)
|
||||
|
||||
else
|
||||
msg = ""
|
||||
end
|
||||
|
||||
else
|
||||
local fst = byte(data, 1)
|
||||
local snd = byte(data, 2)
|
||||
code = bor(lshift(fst, 8), snd)
|
||||
|
||||
-- print("parsing unmasked close frame payload: ", payload_len)
|
||||
|
||||
if payload_len > 2 then
|
||||
msg = sub(data, 3)
|
||||
|
||||
else
|
||||
msg = ""
|
||||
end
|
||||
end
|
||||
|
||||
return msg, "close", code
|
||||
end
|
||||
|
||||
return "", "close", nil
|
||||
end
|
||||
|
||||
local msg
|
||||
if mask then
|
||||
-- TODO string.buffer optimizations
|
||||
local bytes = get_string_buf(payload_len)
|
||||
for i = 1, payload_len do
|
||||
bytes[i - 1] = bxor(byte(data, 4 + i),
|
||||
byte(data, (i - 1) % 4 + 1))
|
||||
end
|
||||
msg = ffi_string(bytes, payload_len)
|
||||
|
||||
else
|
||||
msg = data
|
||||
end
|
||||
|
||||
return msg, types[opcode], not fin and "again" or nil
|
||||
end
|
||||
|
||||
|
||||
local function build_frame(fin, opcode, payload_len, payload, masking)
|
||||
-- XXX optimize this when we have string.buffer in LuaJIT 2.1
|
||||
local fst
|
||||
if fin then
|
||||
fst = bor(0x80, opcode)
|
||||
else
|
||||
fst = opcode
|
||||
end
|
||||
|
||||
local snd, extra_len_bytes
|
||||
if payload_len <= 125 then
|
||||
snd = payload_len
|
||||
extra_len_bytes = ""
|
||||
|
||||
elseif payload_len <= 65535 then
|
||||
snd = 126
|
||||
extra_len_bytes = char(band(rshift(payload_len, 8), 0xff),
|
||||
band(payload_len, 0xff))
|
||||
|
||||
else
|
||||
if band(payload_len, 0x7fffffff) < payload_len then
|
||||
return nil, "payload too big"
|
||||
end
|
||||
|
||||
snd = 127
|
||||
-- XXX we only support 31-bit length here
|
||||
extra_len_bytes = char(0, 0, 0, 0, band(rshift(payload_len, 24), 0xff),
|
||||
band(rshift(payload_len, 16), 0xff),
|
||||
band(rshift(payload_len, 8), 0xff),
|
||||
band(payload_len, 0xff))
|
||||
end
|
||||
|
||||
local masking_key
|
||||
if masking then
|
||||
-- set the mask bit
|
||||
snd = bor(snd, 0x80)
|
||||
local key = rand(0xffffffff)
|
||||
masking_key = char(band(rshift(key, 24), 0xff),
|
||||
band(rshift(key, 16), 0xff),
|
||||
band(rshift(key, 8), 0xff),
|
||||
band(key, 0xff))
|
||||
|
||||
-- TODO string.buffer optimizations
|
||||
local bytes = get_string_buf(payload_len)
|
||||
for i = 1, payload_len do
|
||||
bytes[i - 1] = bxor(byte(payload, i),
|
||||
byte(masking_key, (i - 1) % 4 + 1))
|
||||
end
|
||||
payload = ffi_string(bytes, payload_len)
|
||||
|
||||
else
|
||||
masking_key = ""
|
||||
end
|
||||
|
||||
return char(fst, snd) .. extra_len_bytes .. masking_key .. payload
|
||||
end
|
||||
_M.build_frame = build_frame
|
||||
|
||||
|
||||
function _M.send_frame(sock, fin, opcode, payload, max_payload_len, masking)
|
||||
-- ngx.log(ngx.WARN, ngx.var.uri, ": masking: ", masking)
|
||||
|
||||
if not payload then
|
||||
payload = ""
|
||||
|
||||
elseif type(payload) ~= "string" then
|
||||
payload = tostring(payload)
|
||||
end
|
||||
|
||||
local payload_len = #payload
|
||||
|
||||
if payload_len > max_payload_len then
|
||||
return nil, "payload too big"
|
||||
end
|
||||
|
||||
if band(opcode, 0x8) ~= 0 then
|
||||
-- being a control frame
|
||||
if payload_len > 125 then
|
||||
return nil, "too much payload for control frame"
|
||||
end
|
||||
if not fin then
|
||||
return nil, "fragmented control frame"
|
||||
end
|
||||
end
|
||||
|
||||
local frame, err = build_frame(fin, opcode, payload_len, payload,
|
||||
masking)
|
||||
if not frame then
|
||||
return nil, "failed to build frame: " .. err
|
||||
end
|
||||
|
||||
local bytes, err = sock:send(frame)
|
||||
if not bytes then
|
||||
return nil, "failed to send frame: " .. err
|
||||
end
|
||||
return bytes
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
210
resty/websocket/server.lua
Normal file
210
resty/websocket/server.lua
Normal file
@@ -0,0 +1,210 @@
|
||||
-- Copyright (C) Yichun Zhang (agentzh)
|
||||
|
||||
|
||||
local bit = require "bit"
|
||||
local wbproto = require "resty.websocket.protocol"
|
||||
|
||||
local new_tab = wbproto.new_tab
|
||||
local _recv_frame = wbproto.recv_frame
|
||||
local _send_frame = wbproto.send_frame
|
||||
local http_ver = ngx.req.http_version
|
||||
local req_sock = ngx.req.socket
|
||||
local ngx_header = ngx.header
|
||||
local req_headers = ngx.req.get_headers
|
||||
local str_lower = string.lower
|
||||
local char = string.char
|
||||
local str_find = string.find
|
||||
local sha1_bin = ngx.sha1_bin
|
||||
local base64 = ngx.encode_base64
|
||||
local ngx = ngx
|
||||
local read_body = ngx.req.read_body
|
||||
local band = bit.band
|
||||
local rshift = bit.rshift
|
||||
local type = type
|
||||
local setmetatable = setmetatable
|
||||
local tostring = tostring
|
||||
-- local print = print
|
||||
|
||||
|
||||
local _M = new_tab(0, 10)
|
||||
_M._VERSION = '0.08'
|
||||
|
||||
local mt = { __index = _M }
|
||||
|
||||
|
||||
function _M.new(self, opts)
|
||||
if ngx.headers_sent then
|
||||
return nil, "response header already sent"
|
||||
end
|
||||
|
||||
read_body()
|
||||
|
||||
if http_ver() ~= 1.1 then
|
||||
return nil, "bad http version"
|
||||
end
|
||||
|
||||
local headers = req_headers()
|
||||
|
||||
local val = headers.upgrade
|
||||
if type(val) == "table" then
|
||||
val = val[1]
|
||||
end
|
||||
if not val or str_lower(val) ~= "websocket" then
|
||||
return nil, "bad \"upgrade\" request header: " .. tostring(val)
|
||||
end
|
||||
|
||||
val = headers.connection
|
||||
if type(val) == "table" then
|
||||
val = val[1]
|
||||
end
|
||||
if not val or not str_find(str_lower(val), "upgrade", 1, true) then
|
||||
return nil, "bad \"connection\" request header"
|
||||
end
|
||||
|
||||
local key = headers["sec-websocket-key"]
|
||||
if type(key) == "table" then
|
||||
key = key[1]
|
||||
end
|
||||
if not key then
|
||||
return nil, "bad \"sec-websocket-key\" request header"
|
||||
end
|
||||
|
||||
local ver = headers["sec-websocket-version"]
|
||||
if type(ver) == "table" then
|
||||
ver = ver[1]
|
||||
end
|
||||
if not ver or ver ~= "13" then
|
||||
return nil, "bad \"sec-websocket-version\" request header"
|
||||
end
|
||||
|
||||
local protocols = headers["sec-websocket-protocol"]
|
||||
if type(protocols) == "table" then
|
||||
protocols = protocols[1]
|
||||
end
|
||||
|
||||
if protocols then
|
||||
ngx_header["Sec-WebSocket-Protocol"] = protocols
|
||||
end
|
||||
ngx_header["Upgrade"] = "websocket"
|
||||
|
||||
local sha1 = sha1_bin(key .. "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
|
||||
ngx_header["Sec-WebSocket-Accept"] = base64(sha1)
|
||||
|
||||
ngx_header["Content-Type"] = nil
|
||||
|
||||
ngx.status = 101
|
||||
local ok, err = ngx.send_headers()
|
||||
if not ok then
|
||||
return nil, "failed to send response header: " .. (err or "unknonw")
|
||||
end
|
||||
ok, err = ngx.flush(true)
|
||||
if not ok then
|
||||
return nil, "failed to flush response header: " .. (err or "unknown")
|
||||
end
|
||||
|
||||
local sock
|
||||
sock, err = req_sock(true)
|
||||
if not sock then
|
||||
return nil, err
|
||||
end
|
||||
|
||||
local max_payload_len, send_masked, timeout
|
||||
if opts then
|
||||
max_payload_len = opts.max_payload_len
|
||||
send_masked = opts.send_masked
|
||||
timeout = opts.timeout
|
||||
|
||||
if timeout then
|
||||
sock:settimeout(timeout)
|
||||
end
|
||||
end
|
||||
|
||||
return setmetatable({
|
||||
sock = sock,
|
||||
max_payload_len = max_payload_len or 65535,
|
||||
send_masked = send_masked,
|
||||
}, mt)
|
||||
end
|
||||
|
||||
|
||||
function _M.set_timeout(self, time)
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, nil, "not initialized yet"
|
||||
end
|
||||
|
||||
return sock:settimeout(time)
|
||||
end
|
||||
|
||||
|
||||
function _M.recv_frame(self)
|
||||
if self.fatal then
|
||||
return nil, nil, "fatal error already happened"
|
||||
end
|
||||
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, nil, "not initialized yet"
|
||||
end
|
||||
|
||||
local data, typ, err = _recv_frame(sock, self.max_payload_len, true)
|
||||
if not data and not str_find(err, ": timeout", 1, true) then
|
||||
self.fatal = true
|
||||
end
|
||||
return data, typ, err
|
||||
end
|
||||
|
||||
|
||||
local function send_frame(self, fin, opcode, payload)
|
||||
if self.fatal then
|
||||
return nil, "fatal error already happened"
|
||||
end
|
||||
|
||||
local sock = self.sock
|
||||
if not sock then
|
||||
return nil, "not initialized yet"
|
||||
end
|
||||
|
||||
local bytes, err = _send_frame(sock, fin, opcode, payload,
|
||||
self.max_payload_len, self.send_masked)
|
||||
if not bytes then
|
||||
self.fatal = true
|
||||
end
|
||||
return bytes, err
|
||||
end
|
||||
_M.send_frame = send_frame
|
||||
|
||||
|
||||
function _M.send_text(self, data)
|
||||
return send_frame(self, true, 0x1, data)
|
||||
end
|
||||
|
||||
|
||||
function _M.send_binary(self, data)
|
||||
return send_frame(self, true, 0x2, data)
|
||||
end
|
||||
|
||||
|
||||
function _M.send_close(self, code, msg)
|
||||
local payload
|
||||
if code then
|
||||
if type(code) ~= "number" or code > 0x7fff then
|
||||
end
|
||||
payload = char(band(rshift(code, 8), 0xff), band(code, 0xff))
|
||||
.. (msg or "")
|
||||
end
|
||||
return send_frame(self, true, 0x8, payload)
|
||||
end
|
||||
|
||||
|
||||
function _M.send_ping(self, data)
|
||||
return send_frame(self, true, 0x9, data)
|
||||
end
|
||||
|
||||
|
||||
function _M.send_pong(self, data)
|
||||
return send_frame(self, true, 0xa, data)
|
||||
end
|
||||
|
||||
|
||||
return _M
|
||||
@@ -1,2 +1,2 @@
|
||||
jsConfuse()
|
||||
dateReplace()
|
||||
dateReplace()
|
||||
|
||||
33
tools.lua
Normal file
33
tools.lua
Normal file
@@ -0,0 +1,33 @@
|
||||
function i_get_cookie(s_cookie)
|
||||
local cookie = {}
|
||||
|
||||
-- string.gfind is renamed to string.gmatch
|
||||
for item in string.gmatch(s_cookie, "[^;]+") do
|
||||
local _, _, k, v = string.find(item, "^%s*(%S+)%s*=%s*(%S+)%s*")
|
||||
if k ~= nil and v ~= nil then
|
||||
cookie[k] = v
|
||||
end
|
||||
end
|
||||
|
||||
return cookie
|
||||
end
|
||||
|
||||
function get_cookie_table()
|
||||
local raw_cookie = ngx.req.get_headers()["Cookie"]
|
||||
if raw_cookie ~= nil then
|
||||
return i_get_cookie(raw_cookie)
|
||||
end
|
||||
return nil
|
||||
end
|
||||
|
||||
function get_cookie_raw()
|
||||
return ngx.req.get_headers()["Cookie"]
|
||||
end
|
||||
|
||||
function match_string(input_str, rule)
|
||||
if input_str == nil then
|
||||
return false
|
||||
end
|
||||
local from, to, err = ngx.re.find(input_str, rule, "jo")
|
||||
return from ~= nil
|
||||
end
|
||||
11
waf/report.lua
Normal file
11
waf/report.lua
Normal file
@@ -0,0 +1,11 @@
|
||||
local _M = {}
|
||||
function _M.violation(result)
|
||||
full_violation_text = "violation: \n"
|
||||
if result == g_violation_sql_detect then
|
||||
g_result_sql_detect = full_violation_text .. "SQL Injection detected \n"
|
||||
end
|
||||
full_violation_text = full_violation_text .. ngx.var.request_uri .. " \n"
|
||||
log(violation)
|
||||
say_html()
|
||||
end
|
||||
return _M
|
||||
6
waf/rule.lua
Normal file
6
waf/rule.lua
Normal file
@@ -0,0 +1,6 @@
|
||||
local _M = {
|
||||
sql_get = "'|\\b(and|or)\\b.+?(>|<|=|\\bin\\b|\\blike\\b)|\\/\\*.+?\\*\\/|<\\s*script\\b|\\bEXEC\\b|UNION.+?SELECT|UPDATE.+?SET|INSERT\\s+INTO.+?VALUES|(SELECT|DELETE).+?FROM|(CREATE|ALTER|DROP|TRUNCATE)\\s+(TABLE|DATABASE)",
|
||||
sql_post = "\\b(and|or)\\b.{1,6}?(=|>|<|\\bin\\b|\\blike\\b)|\\/\\*.+?\\*\\/|<\\s*script\\b|\\bEXEC\\b|UNION.+?SELECT|UPDATE.+?SET|INSERT\\s+INTO.+?VALUES|(SELECT|DELETE).+?FROM|(CREATE|ALTER|DROP|TRUNCATE)\\s+(TABLE|DATABASE)",
|
||||
sql_cookie = "\\b(and|or)\\b.{1,6}?(=|>|<|\\bin\\b|\\blike\\b)|\\/\\*.+?\\*\\/|<\\s*script\\b|\\bEXEC\\b|UNION.+?SELECT|UPDATE.+?SET|INSERT\\s+INTO.+?VALUES|(SELECT|DELETE).+?FROM|(CREATE|ALTER|DROP|TRUNCATE)\\s+(TABLE|DATABASE)"
|
||||
}
|
||||
return _M
|
||||
37
waf/sql.lua
Normal file
37
waf/sql.lua
Normal file
@@ -0,0 +1,37 @@
|
||||
local _M = {}
|
||||
local ngx_base = require "resty.core.base"
|
||||
local waf_rule = require "waf/rule"
|
||||
local waf_voilation = require "waf/violation_list"
|
||||
local waf_report = require "waf/report"
|
||||
|
||||
function _M.waf_sql_filter_params(arg_tables, filter_rule)
|
||||
if arg_tables then
|
||||
for key, val in pairs(arg_tables) do
|
||||
if match_string(val, filter_rule) then
|
||||
--ngx.say(key .. ":" .. val)
|
||||
waf_report.violation(waf_voilation.sql_detect)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
function _M.waf_sql_filter()
|
||||
--local get_arags = ngx.req.get_uri_args
|
||||
--for k, v in ipairs(get_arags) do
|
||||
-- print(v)
|
||||
--end
|
||||
if ngx_base.get_request() ~= nil then
|
||||
_M.waf_sql_filter_params(ngx.req.get_uri_args(), waf_rule.sql_get)
|
||||
ngx.say(get_cookie_raw())
|
||||
if match_string(get_cookie_raw(), waf_rule.sql_cookie) then
|
||||
waf_report.violation(waf_voilation.sql_detect)
|
||||
end
|
||||
|
||||
local is_post_method = ngx.req.get_method() == "POST"
|
||||
if is_post_method then
|
||||
ngx.req.read_body()
|
||||
_M.waf_sql_filter_params(ngx.req.get_post_args(), waf_rule.sql_post)
|
||||
end
|
||||
end
|
||||
end
|
||||
return _M
|
||||
3
waf/violation_list.lua
Normal file
3
waf/violation_list.lua
Normal file
@@ -0,0 +1,3 @@
|
||||
return {
|
||||
sql_detect = 1
|
||||
}
|
||||
8
waf/waf.lua
Normal file
8
waf/waf.lua
Normal file
@@ -0,0 +1,8 @@
|
||||
local waf_sql = require "waf/sql"
|
||||
|
||||
--以后加规则配置、插件这些.现在不加
|
||||
function waf_dispatch()
|
||||
waf_sql.waf_sql_filter()
|
||||
end
|
||||
|
||||
waf_dispatch()
|
||||
Reference in New Issue
Block a user