Compare commits

1 Commits

Author SHA1 Message Date
琴心
fcace799df 增加sql的waf功能 2022-03-03 16:16:53 +08:00
54 changed files with 12617 additions and 73 deletions

102
init.lua
View File

@@ -1,42 +1,45 @@
require 'config'
require 'b64'
require 'aes'
require 'log'
require '403'
require 'tableXstring'
require 'fileio'
require 'randomStr'
require 'whiteList'
require "config"
require "b64"
require "aes"
require "log"
require "403"
require "tableXstring"
require "fileio"
require "randomStr"
require "whiteList"
require "tools"
require "waf/waf"
local optionIsOn = function (options) return options == "on" and true or false end
local optionIsOn = function(options)
return options == "on" and true or false
end
ToolsProtect = optionIsOn(toolsProtect)
ShiroProtect = optionIsOn(shiroProtect)
JsProtect = optionIsOn(jsProtect)
JsConfuse = false
SensitiveProtect = optionIsOn(sensitiveProtect)
-- cookie加密
function reqCookieParse()
if ShiroProtect then
local userCookieX9 = ngx.var.cookie_x9i7RDYX23
if not userCookieX9 then -- 没有cookie
log('0-cookie 无cookie', '')
ngx.req.set_header('Cookie', '') -- 移除其他cookie
elseif #userCookieX9 < 32 then -- 判断cookie长度
log('1-cookie 不符合要求', userCookieX9)
ngx.say('4')
if not userCookieX9 then -- 没有cookie
log("0-cookie 无cookie", "")
ngx.req.set_header("Cookie", "") -- 移除其他cookie
elseif #userCookieX9 < 32 then -- 判断cookie长度
log("1-cookie 不符合要求", userCookieX9)
ngx.say("4")
say_html()
else --有cookie
else --有cookie
local result = xpcall(dencrypT, emptyPrint, userCookieX9, aesKey)
if not result then --解密失败
log('2-cookie 无法解密', userCookieX9)
ngx.say('5')
log("2-cookie 无法解密", userCookieX9)
ngx.say("5")
say_html()
else --解密成功
else --解密成功
local originCookie = StrToTable(dencrypT(userCookieX9, aesKey))
ngx.req.set_header('Cookie', transTable(originCookie))
log('3-cookie 解密成功', userCookieX9)
ngx.req.set_header("Cookie", transTable(originCookie))
log("3-cookie 解密成功", userCookieX9)
end
end
end
@@ -46,9 +49,9 @@ function respCookieEncrypt()
if ShiroProtect then
local value = ngx.resp.get_headers()["Set-Cookie"]
if value then
local encryptedCookie = cookieD.."="..encrypT(TableToStr(value), aesKey)
local encryptedCookie = cookieD .. "=" .. encrypT(TableToStr(value), aesKey)
ngx.header["Set-Cookie"] = encryptedCookie
log('4-cookie 加密成功',encryptedCookie)
log("4-cookie 加密成功", encryptedCookie)
end
end
end
@@ -58,30 +61,30 @@ function toolsInfoSpider()
if ToolsProtect and not whiteExtCheck() then
local clientCookieA = ngx.var.cookie_h0yGbdRv
local clientCookieB = ngx.var.cookie_kQpFHdoh
if not (clientCookieA and clientCookieB) then --没有cookieA进入reload302至html生成cookie后再请求原地址
local ip = 'xxx'
local finalPath = 'http://'..ip..'/'..jsPath..'?origin='..encodeBase64(ngx.var.request_uri)
log('1-tools 无cookieA/B', '')
if not (clientCookieA and clientCookieB) then --没有cookieA进入reload302至html生成cookie后再请求原地址
local ip = "xxx"
local finalPath = "http://" .. ip .. "/" .. jsPath .. "?origin=" .. encodeBase64(ngx.var.request_uri)
log("1-tools 无cookieA/B", "")
ngx.redirect(finalPath, 302)
else
local result = xpcall(dencrypT, emptyPrint, clientCookieB, clientCookieA)
if not result then
log('2-tools 解密失败', clientCookieA..', '..clientCookieB)
ngx.say('1')
log("2-tools 解密失败", clientCookieA .. ", " .. clientCookieB)
ngx.say("1")
say_html() -- 解密失败
else-- 可以解密,提取数据
else -- 可以解密,提取数据
local result2 = dencrypT(clientCookieB, clientCookieA)
if #result2 < 1 then
log('3-tools 解密失败', result2)
log("3-tools 解密失败", result2)
else
local srs = split(result2, ',')
local _,e = string.find(srs[1], '0')
local srs = split(result2, ",")
local _, e = string.find(srs[1], "0")
if e ~= nil then
log('4-tools 工具请求', result2)
ngx.say('2')
log("4-tools 工具请求", result2)
ngx.say("2")
say_html()
else
log('0-tools 工具验证通过, 记录浏览器指纹', '', srs[2])
log("0-tools 工具验证通过, 记录浏览器指纹", "", srs[2])
end
end
end
@@ -93,7 +96,7 @@ end
function jsExtDetect()
if JsProtect then
local ext = string.match(ngx.var.uri, ".+%.(%w+)$")
if ext == 'js' then -- 加入检查js文件是否存在
if ext == "js" then -- 加入检查js文件是否存在
JsConfuse = true
end
end
@@ -102,14 +105,14 @@ end
function jsConfuse()
if JsConfuse then
local originBody = ngx.arg[1]
if #originBody > 200 then -- 筛选空js
if #originBody > 200 then -- 筛选空js
local s = getRandom(8)
local path = '/tmp/'..s
writefile(path, originBody, 'w+')
local t = io.popen('export NODE_PATH=/usr/lib/node_modules && node /gate/node/js_confuse.js '..path)
local path = "/tmp/" .. s
writefile(path, originBody, "w+")
local t = io.popen("export NODE_PATH=/usr/lib/node_modules && node /gate/node/js_confuse.js " .. path)
local a = t:read("*all")
ngx.arg[1] = a
os.execute('rm -f '..path)
os.execute("rm -f " .. path)
end
JsConfuse = false
end
@@ -122,14 +125,3 @@ function dateReplace()
ngx.arg[1] = replaceTelephone
end
end

56
log.lua
View File

@@ -1,21 +1,24 @@
require 'config'
require "config"
local optionIsOn = function (options) return options == "on" and true or false end
local optionIsOn = function(options)
return options == "on" and true or false
end
local Attacklog = optionIsOn(attacklog)
local logpath = logdir
local function getClientIp()
IP = ngx.var.remote_addr
IP = ngx.var.remote_addr
if IP == nil then
IP = "unknown"
IP = "unknown"
end
return IP
end
local function write(logfile,msg)
local fd = io.open(logfile,"ab")
if fd == nil then return end
local function write(logfile, msg)
local fd = io.open(logfile, "ab")
if fd == nil then
return
end
fd:write(msg)
fd:flush()
fd:close()
@@ -23,19 +26,38 @@ end
function log(data, ruletag, fp)
if Attacklog then
local fingerprint = fp or ''
local fingerprint = fp or ""
local realIp = getClientIp()
local method = ngx.var.request_method
local ua = ngx.var.http_user_agent
local servername=ngx.var.server_name
local servername = ngx.var.server_name
local url = ngx.var.request_uri
local time=ngx.localtime()
if ua then
line = realIp.." ["..time.."] \""..method.." "..servername..url.."\" \""..ruletag.."\" \""..ua.."\" \""..data.."\" \""..fingerprint.."\"\n"
local time = ngx.localtime()
if ua then
line =
realIp ..
" [" ..
time ..
'] "' ..
method ..
" " ..
servername ..
url ..
'" "' ..
ruletag ..
'" "' .. ua .. '" "' .. data .. '" "' .. fingerprint .. '"\n'
else
line = realIp.." ["..time.."] \""..method.." "..servername..url.."\" \""..ruletag.."\" - \""..data.."\" \""..fingerprint.."\"\n"
line =
realIp ..
" [" ..
time ..
'] "' ..
method ..
" " ..
servername ..
url .. '" "' .. ruletag .. '" - "' .. data .. '" "' .. fingerprint .. '"\n'
end
local filename = logpath..'/'..servername.."_"..ngx.today().."_sec.log"
write(filename,line)
local filename = logpath .. "/" .. servername .. "_" .. ngx.today() .. "_sec.log"
write(filename, line)
end
end
end

303
resty/aes.lua Normal file
View File

@@ -0,0 +1,303 @@
-- Copyright (C) by Yichun Zhang (agentzh)
--local asn1 = require "resty.asn1"
local ffi = require "ffi"
local ffi_new = ffi.new
local ffi_gc = ffi.gc
local ffi_str = ffi.string
local ffi_copy = ffi.copy
local C = ffi.C
local setmetatable = setmetatable
--local error = error
local type = type
local _M = { _VERSION = '0.14' }
local mt = { __index = _M }
local EVP_CTRL_AEAD_SET_IVLEN = 0x09
local EVP_CTRL_AEAD_GET_TAG = 0x10
local EVP_CTRL_AEAD_SET_TAG = 0x11
ffi.cdef[[
typedef struct engine_st ENGINE;
typedef struct evp_cipher_st EVP_CIPHER;
typedef struct evp_cipher_ctx_st EVP_CIPHER_CTX;
typedef struct env_md_ctx_st EVP_MD_CTX;
typedef struct env_md_st EVP_MD;
const EVP_MD *EVP_md5(void);
const EVP_MD *EVP_sha(void);
const EVP_MD *EVP_sha1(void);
const EVP_MD *EVP_sha224(void);
const EVP_MD *EVP_sha256(void);
const EVP_MD *EVP_sha384(void);
const EVP_MD *EVP_sha512(void);
const EVP_CIPHER *EVP_aes_128_ecb(void);
const EVP_CIPHER *EVP_aes_128_cbc(void);
const EVP_CIPHER *EVP_aes_128_cfb1(void);
const EVP_CIPHER *EVP_aes_128_cfb8(void);
const EVP_CIPHER *EVP_aes_128_cfb128(void);
const EVP_CIPHER *EVP_aes_128_ofb(void);
const EVP_CIPHER *EVP_aes_128_ctr(void);
const EVP_CIPHER *EVP_aes_192_ecb(void);
const EVP_CIPHER *EVP_aes_192_cbc(void);
const EVP_CIPHER *EVP_aes_192_cfb1(void);
const EVP_CIPHER *EVP_aes_192_cfb8(void);
const EVP_CIPHER *EVP_aes_192_cfb128(void);
const EVP_CIPHER *EVP_aes_192_ofb(void);
const EVP_CIPHER *EVP_aes_192_ctr(void);
const EVP_CIPHER *EVP_aes_256_ecb(void);
const EVP_CIPHER *EVP_aes_256_cbc(void);
const EVP_CIPHER *EVP_aes_256_cfb1(void);
const EVP_CIPHER *EVP_aes_256_cfb8(void);
const EVP_CIPHER *EVP_aes_256_cfb128(void);
const EVP_CIPHER *EVP_aes_256_ofb(void);
const EVP_CIPHER *EVP_aes_128_gcm(void);
const EVP_CIPHER *EVP_aes_192_gcm(void);
const EVP_CIPHER *EVP_aes_256_gcm(void);
EVP_CIPHER_CTX *EVP_CIPHER_CTX_new();
void EVP_CIPHER_CTX_free(EVP_CIPHER_CTX *a);
int EVP_CIPHER_CTX_block_size(const EVP_CIPHER_CTX *ctx);
int EVP_EncryptInit_ex(EVP_CIPHER_CTX *ctx,const EVP_CIPHER *cipher,
ENGINE *impl, unsigned char *key, const unsigned char *iv);
int EVP_EncryptUpdate(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl,
const unsigned char *in, int inl);
int EVP_EncryptFinal_ex(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl);
int EVP_DecryptInit_ex(EVP_CIPHER_CTX *ctx,const EVP_CIPHER *cipher,
ENGINE *impl, unsigned char *key, const unsigned char *iv);
int EVP_DecryptUpdate(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl,
const unsigned char *in, int inl);
int EVP_DecryptFinal_ex(EVP_CIPHER_CTX *ctx, unsigned char *outm, int *outl);
int EVP_BytesToKey(const EVP_CIPHER *type,const EVP_MD *md,
const unsigned char *salt, const unsigned char *data, int datal,
int count, unsigned char *key,unsigned char *iv);
int EVP_CIPHER_CTX_ctrl(EVP_CIPHER_CTX *ctx, int type, int arg, void *ptr);
]]
local hash
hash = {
md5 = C.EVP_md5(),
sha1 = C.EVP_sha1(),
sha224 = C.EVP_sha224(),
sha256 = C.EVP_sha256(),
sha384 = C.EVP_sha384(),
sha512 = C.EVP_sha512()
}
_M.hash = hash
local EVP_MAX_BLOCK_LENGTH = 32
local cipher
cipher = function (size, _cipher)
local _size = size or 128
local _cipher = _cipher or "cbc"
local func = "EVP_aes_" .. _size .. "_" .. _cipher
if C[func] then
return { size=_size, cipher=_cipher, method=C[func]()}
else
return nil
end
end
_M.cipher = cipher
function _M.new(self, key, salt, _cipher, _hash, hash_rounds, iv_len)
local encrypt_ctx = C.EVP_CIPHER_CTX_new()
if encrypt_ctx == nil then
return nil, "no memory"
end
ffi_gc(encrypt_ctx, C.EVP_CIPHER_CTX_free)
local decrypt_ctx = C.EVP_CIPHER_CTX_new()
if decrypt_ctx == nil then
return nil, "no memory"
end
ffi_gc(decrypt_ctx, C.EVP_CIPHER_CTX_free)
local _cipher = _cipher or cipher()
local _hash = _hash or hash.md5
local hash_rounds = hash_rounds or 1
local _cipherLength = _cipher.size/8
local gen_key = ffi_new("unsigned char[?]",_cipherLength)
local gen_iv = ffi_new("unsigned char[?]",_cipherLength)
iv_len = iv_len or _cipherLength
if type(_hash) == "table" then
if not _hash.iv then
return nil, "iv is needed"
end
--[[ Depending on the encryption algorithm, the length of iv will be
different. For detailed, please refer to
https://www.openssl.org/docs/man1.1.0/man3/EVP_CIPHER_CTX_ctrl.html
]]
iv_len = #_hash.iv
if iv_len > _cipherLength then
return nil, "bad iv length"
end
if _hash.method then
local tmp_key = _hash.method(key)
if #tmp_key ~= _cipherLength then
return nil, "bad key length"
end
ffi_copy(gen_key, tmp_key, _cipherLength)
elseif #key ~= _cipherLength then
return nil, "bad key length"
else
ffi_copy(gen_key, key, _cipherLength)
end
ffi_copy(gen_iv, _hash.iv, iv_len)
else
if salt and #salt ~= 8 then
return nil, "salt must be 8 characters or nil"
end
if C.EVP_BytesToKey(_cipher.method, _hash, salt, key, #key,
hash_rounds, gen_key, gen_iv)
~= _cipherLength
then
return nil, "failed to generate key and iv"
end
end
if C.EVP_EncryptInit_ex(encrypt_ctx, _cipher.method, nil,
nil, nil) == 0 or
C.EVP_DecryptInit_ex(decrypt_ctx, _cipher.method, nil,
nil, nil) == 0 then
return nil, "failed to init ctx"
end
local cipher_name = _cipher.cipher
if cipher_name == "gcm"
or cipher_name == "ccm"
or cipher_name == "ocb" then
if C.EVP_CIPHER_CTX_ctrl(encrypt_ctx, EVP_CTRL_AEAD_SET_IVLEN,
iv_len, nil) == 0 or
C.EVP_CIPHER_CTX_ctrl(decrypt_ctx, EVP_CTRL_AEAD_SET_IVLEN,
iv_len, nil) == 0 then
return nil, "failed to set IV length"
end
end
return setmetatable({
_encrypt_ctx = encrypt_ctx,
_decrypt_ctx = decrypt_ctx,
_cipher = _cipher.cipher,
_key = gen_key,
_iv = gen_iv
}, mt)
end
function _M.encrypt(self, s)
local typ = type(self)
if typ ~= "table" then
error("bad argument #1 self: table expected, got " .. typ, 2)
end
local s_len = #s
local max_len = s_len + 2 * EVP_MAX_BLOCK_LENGTH
local buf = ffi_new("unsigned char[?]", max_len)
local out_len = ffi_new("int[1]")
local tmp_len = ffi_new("int[1]")
local ctx = self._encrypt_ctx
if C.EVP_EncryptInit_ex(ctx, nil, nil, self._key, self._iv) == 0 then
return nil, "EVP_EncryptInit_ex failed"
end
if C.EVP_EncryptUpdate(ctx, buf, out_len, s, s_len) == 0 then
return nil, "EVP_EncryptUpdate failed"
end
if self._cipher == "gcm" then
local encrypt_data = ffi_str(buf, out_len[0])
if C.EVP_EncryptFinal_ex(ctx, buf, out_len) == 0 then
return nil, "EVP_DecryptFinal_ex failed"
end
-- FIXME: For OCB mode the taglen must either be 16
-- or the value previously set via EVP_CTRL_OCB_SET_TAGLEN.
-- so we should extend this api in the future
C.EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_GET_TAG, 16, buf);
local tag = ffi_str(buf, 16)
return {encrypt_data, tag}
end
if C.EVP_EncryptFinal_ex(ctx, buf + out_len[0], tmp_len) == 0 then
return nil, "EVP_EncryptFinal_ex failed"
end
return ffi_str(buf, out_len[0] + tmp_len[0])
end
function _M.decrypt(self, s, tag)
local typ = type(self)
if typ ~= "table" then
error("bad argument #1 self: table expected, got " .. typ, 2)
end
local s_len = #s
local max_len = s_len + 2 * EVP_MAX_BLOCK_LENGTH
local buf = ffi_new("unsigned char[?]", max_len)
local out_len = ffi_new("int[1]")
local tmp_len = ffi_new("int[1]")
local ctx = self._decrypt_ctx
if C.EVP_DecryptInit_ex(ctx, nil, nil, self._key, self._iv) == 0 then
return nil, "EVP_DecryptInit_ex failed"
end
if C.EVP_DecryptUpdate(ctx, buf, out_len, s, s_len) == 0 then
return nil, "EVP_DecryptUpdate failed"
end
if self._cipher == "gcm" then
local plain_txt = ffi_str(buf, out_len[0])
if tag ~= nil then
local tag_buf = ffi_new("unsigned char[?]", 16)
ffi.copy(tag_buf, tag, 16)
C.EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_SET_TAG, 16, tag_buf);
end
if C.EVP_DecryptFinal_ex(ctx, buf + out_len[0], tmp_len) == 0 then
return nil, "EVP_DecryptFinal_ex failed"
end
return plain_txt
end
if C.EVP_DecryptFinal_ex(ctx, buf + out_len[0], tmp_len) == 0 then
return nil, "EVP_DecryptFinal_ex failed"
end
return ffi_str(buf, out_len[0] + tmp_len[0])
end
return _M

35
resty/core.lua Normal file
View File

@@ -0,0 +1,35 @@
-- Copyright (C) Yichun Zhang (agentzh)
local subsystem = ngx.config.subsystem
require "resty.core.var"
require "resty.core.worker"
require "resty.core.regex"
require "resty.core.shdict"
require "resty.core.time"
require "resty.core.hash"
require "resty.core.uri"
require "resty.core.exit"
require "resty.core.base64"
require "resty.core.request"
if subsystem == 'http' then
require "resty.core.response"
require "resty.core.phase"
require "resty.core.ndk"
require "resty.core.socket"
end
require "resty.core.misc"
require "resty.core.ctx"
local base = require "resty.core.base"
return {
version = base.version
}

259
resty/core/base.lua Normal file
View File

@@ -0,0 +1,259 @@
-- Copyright (C) Yichun Zhang (agentzh)
local ffi = require 'ffi'
local ffi_new = ffi.new
local error = error
local select = select
local ceil = math.ceil
local subsystem = ngx.config.subsystem
local str_buf_size = 4096
local str_buf
local size_ptr
local FREE_LIST_REF = 0
if subsystem == 'http' then
local ngx_lua_v = ngx.config.ngx_lua_version
if not ngx.config
or not ngx.config.ngx_lua_version
or (ngx_lua_v ~= 10019 and ngx_lua_v ~= 10020)
then
error("ngx_http_lua_module 0.10.19 or 0.10.20 required")
end
elseif subsystem == 'stream' then
if not ngx.config
or not ngx.config.ngx_lua_version
or ngx.config.ngx_lua_version ~= 10
then
error("ngx_stream_lua_module 0.0.10 required")
end
else
error("ngx_http_lua_module 0.10.20 or "
.. "ngx_stream_lua_module 0.0.10 required")
end
if string.find(jit.version, " 2.0", 1, true) then
ngx.log(ngx.ALERT, "use of lua-resty-core with LuaJIT 2.0 is ",
"not recommended; use LuaJIT 2.1+ instead")
end
local ok, new_tab = pcall(require, "table.new")
if not ok then
new_tab = function (narr, nrec) return {} end
end
local clear_tab
ok, clear_tab = pcall(require, "table.clear")
if not ok then
local pairs = pairs
clear_tab = function (tab)
for k, _ in pairs(tab) do
tab[k] = nil
end
end
end
-- XXX for now LuaJIT 2.1 cannot compile require()
-- so we make the fast code path Lua only in our own
-- wrapper so that most of the require() calls in hot
-- Lua code paths can be JIT compiled.
do
local orig_require = require
local pkg_loaded = package.loaded
local function my_require(name)
local mod = pkg_loaded[name]
if mod then
return mod
end
return orig_require(name)
end
getfenv(0).require = my_require
end
if not pcall(ffi.typeof, "ngx_str_t") then
ffi.cdef[[
typedef struct {
size_t len;
const unsigned char *data;
} ngx_str_t;
]]
end
if subsystem == 'http' then
if not pcall(ffi.typeof, "ngx_http_request_t") then
ffi.cdef[[
typedef struct ngx_http_request_s ngx_http_request_t;
]]
end
if not pcall(ffi.typeof, "ngx_http_lua_ffi_str_t") then
ffi.cdef[[
typedef struct {
int len;
const unsigned char *data;
} ngx_http_lua_ffi_str_t;
]]
end
elseif subsystem == 'stream' then
if not pcall(ffi.typeof, "ngx_stream_lua_request_t") then
ffi.cdef[[
typedef struct ngx_stream_lua_request_s ngx_stream_lua_request_t;
]]
end
if not pcall(ffi.typeof, "ngx_stream_lua_ffi_str_t") then
ffi.cdef[[
typedef struct {
int len;
const unsigned char *data;
} ngx_stream_lua_ffi_str_t;
]]
end
else
error("unknown subsystem: " .. subsystem)
end
local c_buf_type = ffi.typeof("char[?]")
local _M = new_tab(0, 18)
_M.version = "0.1.22"
_M.new_tab = new_tab
_M.clear_tab = clear_tab
local errmsg
function _M.get_errmsg_ptr()
if not errmsg then
errmsg = ffi_new("char *[1]")
end
return errmsg
end
if not ngx then
error("no existing ngx. table found")
end
function _M.set_string_buf_size(size)
if size <= 0 then
return
end
if str_buf then
str_buf = nil
end
str_buf_size = ceil(size)
end
function _M.get_string_buf_size()
return str_buf_size
end
function _M.get_size_ptr()
if not size_ptr then
size_ptr = ffi_new("size_t[1]")
end
return size_ptr
end
function _M.get_string_buf(size, must_alloc)
-- ngx.log(ngx.ERR, "str buf size: ", str_buf_size)
if size > str_buf_size or must_alloc then
return ffi_new(c_buf_type, size)
end
if not str_buf then
str_buf = ffi_new(c_buf_type, str_buf_size)
end
return str_buf
end
function _M.ref_in_table(tb, key)
if key == nil then
return -1
end
local ref = tb[FREE_LIST_REF]
if ref and ref ~= 0 then
tb[FREE_LIST_REF] = tb[ref]
else
ref = #tb + 1
end
tb[ref] = key
-- print("ref key_id returned ", ref)
return ref
end
function _M.allows_subsystem(...)
local total = select("#", ...)
for i = 1, total do
if select(i, ...) == subsystem then
return
end
end
error("unsupported subsystem: " .. subsystem, 2)
end
_M.FFI_OK = 0
_M.FFI_NO_REQ_CTX = -100
_M.FFI_BAD_CONTEXT = -101
_M.FFI_ERROR = -1
_M.FFI_AGAIN = -2
_M.FFI_BUSY = -3
_M.FFI_DONE = -4
_M.FFI_DECLINED = -5
do
local exdata
ok, exdata = pcall(require, "thread.exdata")
if ok and exdata then
function _M.get_request()
local r = exdata()
if r ~= nil then
return r
end
end
else
local getfenv = getfenv
function _M.get_request()
return getfenv(0).__ngx_req
end
end
end
return _M

115
resty/core/base64.lua Normal file
View File

@@ -0,0 +1,115 @@
-- Copyright (C) Yichun Zhang (agentzh)
local ffi = require "ffi"
local base = require "resty.core.base"
local C = ffi.C
local ffi_string = ffi.string
local ngx = ngx
local type = type
local error = error
local floor = math.floor
local tostring = tostring
local get_string_buf = base.get_string_buf
local get_size_ptr = base.get_size_ptr
local subsystem = ngx.config.subsystem
local ngx_lua_ffi_encode_base64
local ngx_lua_ffi_decode_base64
if subsystem == "http" then
ffi.cdef[[
size_t ngx_http_lua_ffi_encode_base64(const unsigned char *src,
size_t len, unsigned char *dst,
int no_padding);
int ngx_http_lua_ffi_decode_base64(const unsigned char *src,
size_t len, unsigned char *dst,
size_t *dlen);
]]
ngx_lua_ffi_encode_base64 = C.ngx_http_lua_ffi_encode_base64
ngx_lua_ffi_decode_base64 = C.ngx_http_lua_ffi_decode_base64
elseif subsystem == "stream" then
ffi.cdef[[
size_t ngx_stream_lua_ffi_encode_base64(const unsigned char *src,
size_t len, unsigned char *dst,
int no_padding);
int ngx_stream_lua_ffi_decode_base64(const unsigned char *src,
size_t len, unsigned char *dst,
size_t *dlen);
]]
ngx_lua_ffi_encode_base64 = C.ngx_stream_lua_ffi_encode_base64
ngx_lua_ffi_decode_base64 = C.ngx_stream_lua_ffi_decode_base64
end
local function base64_encoded_length(len, no_padding)
return no_padding and floor((len * 8 + 5) / 6) or
floor((len + 2) / 3) * 4
end
ngx.encode_base64 = function (s, no_padding)
if type(s) ~= 'string' then
if not s then
s = ''
else
s = tostring(s)
end
end
local slen = #s
local no_padding_bool = false;
local no_padding_int = 0;
if no_padding then
if no_padding ~= true then
local typ = type(no_padding)
error("bad no_padding: boolean expected, got " .. typ, 2)
end
no_padding_bool = true
no_padding_int = 1;
end
local dlen = base64_encoded_length(slen, no_padding_bool)
local dst = get_string_buf(dlen)
local r_dlen = ngx_lua_ffi_encode_base64(s, slen, dst, no_padding_int)
-- if dlen ~= r_dlen then error("discrepancy in len") end
return ffi_string(dst, r_dlen)
end
local function base64_decoded_length(len)
return floor((len + 3) / 4) * 3
end
ngx.decode_base64 = function (s)
if type(s) ~= 'string' then
error("string argument only", 2)
end
local slen = #s
local dlen = base64_decoded_length(slen)
-- print("dlen: ", tonumber(dlen))
local dst = get_string_buf(dlen)
local pdlen = get_size_ptr()
local ok = ngx_lua_ffi_decode_base64(s, slen, dst, pdlen)
if ok == 0 then
return nil
end
return ffi_string(dst, pdlen[0])
end
return {
version = base.version
}

143
resty/core/ctx.lua Normal file
View File

@@ -0,0 +1,143 @@
-- Copyright (C) Yichun Zhang (agentzh)
local ffi = require "ffi"
local debug = require "debug"
local base = require "resty.core.base"
local misc = require "resty.core.misc"
local C = ffi.C
local register_getter = misc.register_ngx_magic_key_getter
local register_setter = misc.register_ngx_magic_key_setter
local registry = debug.getregistry()
local new_tab = base.new_tab
local ref_in_table = base.ref_in_table
local get_request = base.get_request
local FFI_NO_REQ_CTX = base.FFI_NO_REQ_CTX
local FFI_OK = base.FFI_OK
local error = error
local setmetatable = setmetatable
local type = type
local subsystem = ngx.config.subsystem
local ngx_lua_ffi_get_ctx_ref
local ngx_lua_ffi_set_ctx_ref
if subsystem == "http" then
ffi.cdef[[
int ngx_http_lua_ffi_get_ctx_ref(ngx_http_request_t *r, int *in_ssl_phase,
int *ssl_ctx_ref);
int ngx_http_lua_ffi_set_ctx_ref(ngx_http_request_t *r, int ref);
]]
ngx_lua_ffi_get_ctx_ref = C.ngx_http_lua_ffi_get_ctx_ref
ngx_lua_ffi_set_ctx_ref = C.ngx_http_lua_ffi_set_ctx_ref
elseif subsystem == "stream" then
ffi.cdef[[
int ngx_stream_lua_ffi_get_ctx_ref(ngx_stream_lua_request_t *r,
int *in_ssl_phase, int *ssl_ctx_ref);
int ngx_stream_lua_ffi_set_ctx_ref(ngx_stream_lua_request_t *r, int ref);
]]
ngx_lua_ffi_get_ctx_ref = C.ngx_stream_lua_ffi_get_ctx_ref
ngx_lua_ffi_set_ctx_ref = C.ngx_stream_lua_ffi_set_ctx_ref
end
local _M = {
_VERSION = base.version
}
local get_ctx_table
do
local in_ssl_phase = ffi.new("int[1]")
local ssl_ctx_ref = ffi.new("int[1]")
function get_ctx_table(ctx)
local r = get_request()
if not r then
error("no request found")
end
local ctx_ref = ngx_lua_ffi_get_ctx_ref(r, in_ssl_phase, ssl_ctx_ref)
if ctx_ref == FFI_NO_REQ_CTX then
error("no request ctx found")
end
local ctxs = registry.ngx_lua_ctx_tables
if ctx_ref < 0 then
ctx_ref = ssl_ctx_ref[0]
if ctx_ref > 0 and ctxs[ctx_ref] then
if in_ssl_phase[0] ~= 0 then
return ctxs[ctx_ref]
end
if not ctx then
ctx = new_tab(0, 4)
end
ctx = setmetatable(ctx, ctxs[ctx_ref])
else
if in_ssl_phase[0] ~= 0 then
if not ctx then
ctx = new_tab(1, 4)
end
-- to avoid creating another table, we assume the users
-- won't overwrite the `__index` key
ctx.__index = ctx
elseif not ctx then
ctx = new_tab(0, 4)
end
end
ctx_ref = ref_in_table(ctxs, ctx)
if ngx_lua_ffi_set_ctx_ref(r, ctx_ref) ~= FFI_OK then
return nil
end
return ctx
end
return ctxs[ctx_ref]
end
end
register_getter("ctx", get_ctx_table)
_M.get_ctx_table = get_ctx_table
local function set_ctx_table(ctx)
local ctx_type = type(ctx)
if ctx_type ~= "table" then
error("ctx should be a table while getting a " .. ctx_type)
end
local r = get_request()
if not r then
error("no request found")
end
local ctx_ref = ngx_lua_ffi_get_ctx_ref(r, nil, nil)
if ctx_ref == FFI_NO_REQ_CTX then
error("no request ctx found")
end
local ctxs = registry.ngx_lua_ctx_tables
if ctx_ref < 0 then
ctx_ref = ref_in_table(ctxs, ctx)
ngx_lua_ffi_set_ctx_ref(r, ctx_ref)
return
end
ctxs[ctx_ref] = ctx
end
register_setter("ctx", set_ctx_table)
return _M

66
resty/core/exit.lua Normal file
View File

@@ -0,0 +1,66 @@
-- Copyright (C) Yichun Zhang (agentzh)
local ffi = require "ffi"
local base = require "resty.core.base"
local C = ffi.C
local ffi_string = ffi.string
local ngx = ngx
local error = error
local get_string_buf = base.get_string_buf
local get_size_ptr = base.get_size_ptr
local get_request = base.get_request
local co_yield = coroutine._yield
local subsystem = ngx.config.subsystem
local ngx_lua_ffi_exit
if subsystem == "http" then
ffi.cdef[[
int ngx_http_lua_ffi_exit(ngx_http_request_t *r, int status,
unsigned char *err, size_t *errlen);
]]
ngx_lua_ffi_exit = C.ngx_http_lua_ffi_exit
elseif subsystem == "stream" then
ffi.cdef[[
int ngx_stream_lua_ffi_exit(ngx_stream_lua_request_t *r, int status,
unsigned char *err, size_t *errlen);
]]
ngx_lua_ffi_exit = C.ngx_stream_lua_ffi_exit
end
local ERR_BUF_SIZE = 128
local FFI_DONE = base.FFI_DONE
ngx.exit = function (rc)
local err = get_string_buf(ERR_BUF_SIZE)
local errlen = get_size_ptr()
local r = get_request()
if r == nil then
error("no request found")
end
errlen[0] = ERR_BUF_SIZE
rc = ngx_lua_ffi_exit(r, rc, err, errlen)
if rc == 0 then
-- print("yielding...")
return co_yield()
end
if rc == FFI_DONE then
return
end
error(ffi_string(err, errlen[0]), 2)
end
return {
version = base.version
}

154
resty/core/hash.lua Normal file
View File

@@ -0,0 +1,154 @@
-- Copyright (C) Yichun Zhang (agentzh)
local ffi = require "ffi"
local base = require "resty.core.base"
local C = ffi.C
local ffi_new = ffi.new
local ffi_string = ffi.string
local ngx = ngx
local type = type
local error = error
local tostring = tostring
local subsystem = ngx.config.subsystem
local ngx_lua_ffi_md5
local ngx_lua_ffi_md5_bin
local ngx_lua_ffi_sha1_bin
local ngx_lua_ffi_crc32_long
local ngx_lua_ffi_crc32_short
if subsystem == "http" then
ffi.cdef[[
void ngx_http_lua_ffi_md5_bin(const unsigned char *src, size_t len,
unsigned char *dst);
void ngx_http_lua_ffi_md5(const unsigned char *src, size_t len,
unsigned char *dst);
int ngx_http_lua_ffi_sha1_bin(const unsigned char *src, size_t len,
unsigned char *dst);
unsigned int ngx_http_lua_ffi_crc32_long(const unsigned char *src,
size_t len);
unsigned int ngx_http_lua_ffi_crc32_short(const unsigned char *src,
size_t len);
]]
ngx_lua_ffi_md5 = C.ngx_http_lua_ffi_md5
ngx_lua_ffi_md5_bin = C.ngx_http_lua_ffi_md5_bin
ngx_lua_ffi_sha1_bin = C.ngx_http_lua_ffi_sha1_bin
ngx_lua_ffi_crc32_short = C.ngx_http_lua_ffi_crc32_short
ngx_lua_ffi_crc32_long = C.ngx_http_lua_ffi_crc32_long
elseif subsystem == "stream" then
ffi.cdef[[
void ngx_stream_lua_ffi_md5_bin(const unsigned char *src, size_t len,
unsigned char *dst);
void ngx_stream_lua_ffi_md5(const unsigned char *src, size_t len,
unsigned char *dst);
int ngx_stream_lua_ffi_sha1_bin(const unsigned char *src, size_t len,
unsigned char *dst);
unsigned int ngx_stream_lua_ffi_crc32_long(const unsigned char *src,
size_t len);
unsigned int ngx_stream_lua_ffi_crc32_short(const unsigned char *src,
size_t len);
]]
ngx_lua_ffi_md5 = C.ngx_stream_lua_ffi_md5
ngx_lua_ffi_md5_bin = C.ngx_stream_lua_ffi_md5_bin
ngx_lua_ffi_sha1_bin = C.ngx_stream_lua_ffi_sha1_bin
ngx_lua_ffi_crc32_short = C.ngx_stream_lua_ffi_crc32_short
ngx_lua_ffi_crc32_long = C.ngx_stream_lua_ffi_crc32_long
end
local MD5_DIGEST_LEN = 16
local md5_buf = ffi_new("unsigned char[?]", MD5_DIGEST_LEN)
ngx.md5_bin = function (s)
if type(s) ~= 'string' then
if not s then
s = ''
else
s = tostring(s)
end
end
ngx_lua_ffi_md5_bin(s, #s, md5_buf)
return ffi_string(md5_buf, MD5_DIGEST_LEN)
end
local MD5_HEX_DIGEST_LEN = MD5_DIGEST_LEN * 2
local md5_hex_buf = ffi_new("unsigned char[?]", MD5_HEX_DIGEST_LEN)
ngx.md5 = function (s)
if type(s) ~= 'string' then
if not s then
s = ''
else
s = tostring(s)
end
end
ngx_lua_ffi_md5(s, #s, md5_hex_buf)
return ffi_string(md5_hex_buf, MD5_HEX_DIGEST_LEN)
end
local SHA_DIGEST_LEN = 20
local sha_buf = ffi_new("unsigned char[?]", SHA_DIGEST_LEN)
ngx.sha1_bin = function (s)
if type(s) ~= 'string' then
if not s then
s = ''
else
s = tostring(s)
end
end
local ok = ngx_lua_ffi_sha1_bin(s, #s, sha_buf)
if ok == 0 then
error("SHA-1 support missing in Nginx")
end
return ffi_string(sha_buf, SHA_DIGEST_LEN)
end
ngx.crc32_short = function (s)
if type(s) ~= "string" then
if not s then
s = ""
else
s = tostring(s)
end
end
return ngx_lua_ffi_crc32_short(s, #s)
end
ngx.crc32_long = function (s)
if type(s) ~= "string" then
if not s then
s = ""
else
s = tostring(s)
end
end
return ngx_lua_ffi_crc32_long(s, #s)
end
return {
version = base.version
}

240
resty/core/misc.lua Normal file
View File

@@ -0,0 +1,240 @@
-- Copyright (C) Yichun Zhang (agentzh)
local base = require "resty.core.base"
local ffi = require "ffi"
local os = require "os"
local C = ffi.C
local ffi_new = ffi.new
local ffi_str = ffi.string
local ngx = ngx
local type = type
local error = error
local rawget = rawget
local rawset = rawset
local tonumber = tonumber
local setmetatable = setmetatable
local FFI_OK = base.FFI_OK
local FFI_NO_REQ_CTX = base.FFI_NO_REQ_CTX
local FFI_BAD_CONTEXT = base.FFI_BAD_CONTEXT
local new_tab = base.new_tab
local get_request = base.get_request
local get_size_ptr = base.get_size_ptr
local get_string_buf = base.get_string_buf
local get_string_buf_size = base.get_string_buf_size
local subsystem = ngx.config.subsystem
local ngx_lua_ffi_get_resp_status
local ngx_lua_ffi_get_conf_env
local ngx_magic_key_getters
local ngx_magic_key_setters
local _M = new_tab(0, 3)
local ngx_mt = new_tab(0, 2)
if subsystem == "http" then
ngx_magic_key_getters = new_tab(0, 4)
ngx_magic_key_setters = new_tab(0, 2)
elseif subsystem == "stream" then
ngx_magic_key_getters = new_tab(0, 2)
ngx_magic_key_setters = new_tab(0, 1)
end
local function register_getter(key, func)
ngx_magic_key_getters[key] = func
end
_M.register_ngx_magic_key_getter = register_getter
local function register_setter(key, func)
ngx_magic_key_setters[key] = func
end
_M.register_ngx_magic_key_setter = register_setter
ngx_mt.__index = function (tb, key)
local f = ngx_magic_key_getters[key]
if f then
return f()
end
return rawget(tb, key)
end
ngx_mt.__newindex = function (tb, key, ctx)
local f = ngx_magic_key_setters[key]
if f then
return f(ctx)
end
return rawset(tb, key, ctx)
end
setmetatable(ngx, ngx_mt)
if subsystem == "http" then
ffi.cdef[[
int ngx_http_lua_ffi_get_resp_status(ngx_http_request_t *r);
int ngx_http_lua_ffi_set_resp_status(ngx_http_request_t *r, int r);
int ngx_http_lua_ffi_is_subrequest(ngx_http_request_t *r);
int ngx_http_lua_ffi_headers_sent(ngx_http_request_t *r);
int ngx_http_lua_ffi_get_conf_env(const unsigned char *name,
unsigned char **env_buf,
size_t *name_len);
]]
ngx_lua_ffi_get_resp_status = C.ngx_http_lua_ffi_get_resp_status
ngx_lua_ffi_get_conf_env = C.ngx_http_lua_ffi_get_conf_env
-- ngx.status
local function set_status(status)
local r = get_request()
if not r then
error("no request found")
end
if type(status) ~= 'number' then
status = tonumber(status)
end
local rc = C.ngx_http_lua_ffi_set_resp_status(r, status)
if rc == FFI_BAD_CONTEXT then
error("API disabled in the current context", 2)
end
return
end
register_setter("status", set_status)
-- ngx.is_subrequest
local function is_subreq()
local r = get_request()
if not r then
error("no request found")
end
local rc = C.ngx_http_lua_ffi_is_subrequest(r)
if rc == FFI_BAD_CONTEXT then
error("API disabled in the current context", 2)
end
return rc == 1
end
register_getter("is_subrequest", is_subreq)
-- ngx.headers_sent
local function headers_sent()
local r = get_request()
if not r then
error("no request found")
end
local rc = C.ngx_http_lua_ffi_headers_sent(r)
if rc == FFI_NO_REQ_CTX then
error("no request ctx found")
end
if rc == FFI_BAD_CONTEXT then
error("API disabled in the current context", 2)
end
return rc == 1
end
register_getter("headers_sent", headers_sent)
elseif subsystem == "stream" then
ffi.cdef[[
int ngx_stream_lua_ffi_get_resp_status(ngx_stream_lua_request_t *r);
int ngx_stream_lua_ffi_get_conf_env(const unsigned char *name,
unsigned char **env_buf,
size_t *name_len);
]]
ngx_lua_ffi_get_resp_status = C.ngx_stream_lua_ffi_get_resp_status
ngx_lua_ffi_get_conf_env = C.ngx_stream_lua_ffi_get_conf_env
end
-- ngx.status
local function get_status()
local r = get_request()
if not r then
error("no request found")
end
local rc = ngx_lua_ffi_get_resp_status(r)
if rc == FFI_BAD_CONTEXT then
error("API disabled in the current context", 2)
end
return rc
end
register_getter("status", get_status)
do
local _getenv = os.getenv
local env_ptr = ffi_new("unsigned char *[1]")
os.getenv = function (name)
local r = get_request()
if r then
-- past init_by_lua* phase now
os.getenv = _getenv
env_ptr = nil
return os.getenv(name)
end
local size = get_string_buf_size()
env_ptr[0] = get_string_buf(size)
local name_len_ptr = get_size_ptr()
local rc = ngx_lua_ffi_get_conf_env(name, env_ptr, name_len_ptr)
if rc == FFI_OK then
return ffi_str(env_ptr[0] + name_len_ptr[0] + 1)
end
-- FFI_DECLINED
local value = _getenv(name)
if value ~= nil then
return value
end
return nil
end
end
_M._VERSION = base.version
return _M

92
resty/core/ndk.lua Normal file
View File

@@ -0,0 +1,92 @@
-- Copyright (C) by OpenResty Inc.
local ffi = require 'ffi'
local base = require "resty.core.base"
base.allows_subsystem('http')
local C = ffi.C
local ffi_cast = ffi.cast
local ffi_new = ffi.new
local ffi_str = ffi.string
local FFI_OK = base.FFI_OK
local new_tab = base.new_tab
local get_string_buf = base.get_string_buf
local get_request = base.get_request
local setmetatable = setmetatable
local type = type
local tostring = tostring
local error = error
local _M = {
version = base.version
}
ffi.cdef[[
typedef void * ndk_set_var_value_pt;
int ngx_http_lua_ffi_ndk_lookup_directive(const unsigned char *var_data,
size_t var_len, ndk_set_var_value_pt *func);
int ngx_http_lua_ffi_ndk_set_var_get(ngx_http_request_t *r,
ndk_set_var_value_pt func, const unsigned char *arg_data, size_t arg_len,
ngx_http_lua_ffi_str_t *value);
]]
local func_p = ffi_new("void*[1]")
local ffi_str_size = ffi.sizeof("ngx_http_lua_ffi_str_t")
local ffi_str_type = ffi.typeof("ngx_http_lua_ffi_str_t*")
local function ndk_set_var_get(self, var)
if type(var) ~= "string" then
var = tostring(var)
end
if C.ngx_http_lua_ffi_ndk_lookup_directive(var, #var, func_p) ~= FFI_OK then
error('ndk.set_var: directive "' .. var
.. '" not found or does not use ndk_set_var_value', 2)
end
local func = func_p[0]
return function (arg)
local r = get_request()
if not r then
error("no request found")
end
if type(arg) ~= "string" then
arg = tostring(arg)
end
local buf = get_string_buf(ffi_str_size)
local value = ffi_cast(ffi_str_type, buf)
local rc = C.ngx_http_lua_ffi_ndk_set_var_get(r, func, arg, #arg, value)
if rc ~= FFI_OK then
error("calling directive " .. var .. " failed with code " .. rc, 2)
end
return ffi_str(value.data, value.len)
end
end
local function ndk_set_var_set()
error("not allowed", 2)
end
if ndk then
local mt = new_tab(0, 2)
mt.__newindex = ndk_set_var_set
mt.__index = ndk_set_var_get
ndk.set_var = setmetatable(new_tab(0, 0), mt)
end
return _M

59
resty/core/phase.lua Normal file
View File

@@ -0,0 +1,59 @@
local ffi = require 'ffi'
local base = require "resty.core.base"
local C = ffi.C
local FFI_ERROR = base.FFI_ERROR
local get_request = base.get_request
local error = error
local tostring = tostring
ffi.cdef[[
int ngx_http_lua_ffi_get_phase(ngx_http_request_t *r, char **err)
]]
local errmsg = base.get_errmsg_ptr()
local context_names = {
[0x0001] = "set",
[0x0002] = "rewrite",
[0x0004] = "access",
[0x0008] = "content",
[0x0010] = "log",
[0x0020] = "header_filter",
[0x0040] = "body_filter",
[0x0080] = "timer",
[0x0100] = "init_worker",
[0x0200] = "balancer",
[0x0400] = "ssl_cert",
[0x0800] = "ssl_session_store",
[0x1000] = "ssl_session_fetch",
[0x2000] = "exit_worker",
}
function ngx.get_phase()
local r = get_request()
-- if we have no request object, assume we are called from the "init" phase
if not r then
return "init"
end
local context = C.ngx_http_lua_ffi_get_phase(r, errmsg)
if context == FFI_ERROR then -- NGX_ERROR
error(errmsg, 2)
end
local phase = context_names[context]
if not phase then
error("unknown phase: " .. tostring(context))
end
return phase
end
return {
version = base.version
}

1213
resty/core/regex.lua Normal file

File diff suppressed because it is too large Load Diff

451
resty/core/request.lua Normal file
View File

@@ -0,0 +1,451 @@
-- Copyright (C) Yichun Zhang (agentzh)
local ffi = require 'ffi'
local base = require "resty.core.base"
local utils = require "resty.core.utils"
local subsystem = ngx.config.subsystem
local FFI_BAD_CONTEXT = base.FFI_BAD_CONTEXT
local FFI_DECLINED = base.FFI_DECLINED
local FFI_OK = base.FFI_OK
local new_tab = base.new_tab
local C = ffi.C
local ffi_cast = ffi.cast
local ffi_new = ffi.new
local ffi_str = ffi.string
local get_string_buf = base.get_string_buf
local get_size_ptr = base.get_size_ptr
local setmetatable = setmetatable
local lower = string.lower
local rawget = rawget
local ngx = ngx
local get_request = base.get_request
local type = type
local error = error
local tostring = tostring
local tonumber = tonumber
local str_replace_char = utils.str_replace_char
local _M = {
version = base.version
}
local ngx_lua_ffi_req_start_time
if subsystem == "stream" then
ffi.cdef[[
double ngx_stream_lua_ffi_req_start_time(ngx_stream_lua_request_t *r);
]]
ngx_lua_ffi_req_start_time = C.ngx_stream_lua_ffi_req_start_time
elseif subsystem == "http" then
ffi.cdef[[
double ngx_http_lua_ffi_req_start_time(ngx_http_request_t *r);
]]
ngx_lua_ffi_req_start_time = C.ngx_http_lua_ffi_req_start_time
end
function ngx.req.start_time()
local r = get_request()
if not r then
error("no request found")
end
return tonumber(ngx_lua_ffi_req_start_time(r))
end
if subsystem == "stream" then
return _M
end
local errmsg = base.get_errmsg_ptr()
local ffi_str_type = ffi.typeof("ngx_http_lua_ffi_str_t*")
local ffi_str_size = ffi.sizeof("ngx_http_lua_ffi_str_t")
ffi.cdef[[
typedef struct {
ngx_http_lua_ffi_str_t key;
ngx_http_lua_ffi_str_t value;
} ngx_http_lua_ffi_table_elt_t;
int ngx_http_lua_ffi_req_get_headers_count(ngx_http_request_t *r,
int max, int *truncated);
int ngx_http_lua_ffi_req_get_headers(ngx_http_request_t *r,
ngx_http_lua_ffi_table_elt_t *out, int count, int raw);
int ngx_http_lua_ffi_req_get_uri_args_count(ngx_http_request_t *r,
int max, int *truncated);
size_t ngx_http_lua_ffi_req_get_querystring_len(ngx_http_request_t *r);
int ngx_http_lua_ffi_req_get_uri_args(ngx_http_request_t *r,
unsigned char *buf, ngx_http_lua_ffi_table_elt_t *out, int count);
int ngx_http_lua_ffi_req_get_method(ngx_http_request_t *r);
int ngx_http_lua_ffi_req_get_method_name(ngx_http_request_t *r,
unsigned char **name, size_t *len);
int ngx_http_lua_ffi_req_set_method(ngx_http_request_t *r, int method);
int ngx_http_lua_ffi_req_set_header(ngx_http_request_t *r,
const unsigned char *key, size_t key_len, const unsigned char *value,
size_t value_len, ngx_http_lua_ffi_str_t *mvals, size_t mvals_len,
int override, char **errmsg);
]]
local table_elt_type = ffi.typeof("ngx_http_lua_ffi_table_elt_t*")
local table_elt_size = ffi.sizeof("ngx_http_lua_ffi_table_elt_t")
local truncated = ffi.new("int[1]")
local req_headers_mt = {
__index = function (tb, key)
return rawget(tb, (str_replace_char(lower(key), '_', '-')))
end
}
function ngx.req.get_headers(max_headers, raw)
local r = get_request()
if not r then
error("no request found")
end
if not max_headers then
max_headers = -1
end
if not raw then
raw = 0
else
raw = 1
end
local n = C.ngx_http_lua_ffi_req_get_headers_count(r, max_headers,
truncated)
if n == FFI_BAD_CONTEXT then
error("API disabled in the current context", 2)
end
if n == 0 then
local headers = {}
if raw == 0 then
headers = setmetatable(headers, req_headers_mt)
end
return headers
end
local raw_buf = get_string_buf(n * table_elt_size)
local buf = ffi_cast(table_elt_type, raw_buf)
local rc = C.ngx_http_lua_ffi_req_get_headers(r, buf, n, raw)
if rc == 0 then
local headers = new_tab(0, n)
for i = 0, n - 1 do
local h = buf[i]
local key = h.key
key = ffi_str(key.data, key.len)
local value = h.value
value = ffi_str(value.data, value.len)
local existing = headers[key]
if existing then
if type(existing) == "table" then
existing[#existing + 1] = value
else
headers[key] = {existing, value}
end
else
headers[key] = value
end
end
if raw == 0 then
headers = setmetatable(headers, req_headers_mt)
end
if truncated[0] ~= 0 then
return headers, "truncated"
end
return headers
end
return nil
end
function ngx.req.get_uri_args(max_args)
local r = get_request()
if not r then
error("no request found")
end
if not max_args then
max_args = -1
end
local n = C.ngx_http_lua_ffi_req_get_uri_args_count(r, max_args, truncated)
if n == FFI_BAD_CONTEXT then
error("API disabled in the current context", 2)
end
if n == 0 then
return {}
end
local args_len = C.ngx_http_lua_ffi_req_get_querystring_len(r)
local strbuf = get_string_buf(args_len + n * table_elt_size)
local kvbuf = ffi_cast(table_elt_type, strbuf + args_len)
local nargs = C.ngx_http_lua_ffi_req_get_uri_args(r, strbuf, kvbuf, n)
local args = new_tab(0, nargs)
for i = 0, nargs - 1 do
local arg = kvbuf[i]
local key = arg.key
key = ffi_str(key.data, key.len)
local value = arg.value
local len = value.len
if len == -1 then
value = true
else
value = ffi_str(value.data, len)
end
local existing = args[key]
if existing then
if type(existing) == "table" then
existing[#existing + 1] = value
else
args[key] = {existing, value}
end
else
args[key] = value
end
end
if truncated[0] ~= 0 then
return args, "truncated"
end
return args
end
do
local methods = {
[0x0002] = "GET",
[0x0004] = "HEAD",
[0x0008] = "POST",
[0x0010] = "PUT",
[0x0020] = "DELETE",
[0x0040] = "MKCOL",
[0x0080] = "COPY",
[0x0100] = "MOVE",
[0x0200] = "OPTIONS",
[0x0400] = "PROPFIND",
[0x0800] = "PROPPATCH",
[0x1000] = "LOCK",
[0x2000] = "UNLOCK",
[0x4000] = "PATCH",
[0x8000] = "TRACE",
}
local namep = ffi_new("unsigned char *[1]")
function ngx.req.get_method()
local r = get_request()
if not r then
error("no request found")
end
do
local id = C.ngx_http_lua_ffi_req_get_method(r)
if id == FFI_BAD_CONTEXT then
error("API disabled in the current context", 2)
end
local method = methods[id]
if method then
return method
end
end
local sizep = get_size_ptr()
local rc = C.ngx_http_lua_ffi_req_get_method_name(r, namep, sizep)
if rc ~= 0 then
return nil
end
return ffi_str(namep[0], sizep[0])
end
end -- do
function ngx.req.set_method(method)
local r = get_request()
if not r then
error("no request found")
end
if type(method) ~= "number" then
error("bad method number", 2)
end
local rc = C.ngx_http_lua_ffi_req_set_method(r, method)
if rc == FFI_OK then
return
end
if rc == FFI_BAD_CONTEXT then
error("API disabled in the current context", 2)
end
if rc == FFI_DECLINED then
error("unsupported HTTP method: " .. method, 2)
end
error("unknown error: " .. rc)
end
do
local function set_req_header(name, value, override)
local r = get_request()
if not r then
error("no request found", 3)
end
if name == nil then
error("bad 'name' argument: string expected, got nil", 3)
end
if type(name) ~= "string" then
name = tostring(name)
end
local rc
if value == nil then
if not override then
error("bad 'value' argument: string or table expected, got nil",
3)
end
rc = C.ngx_http_lua_ffi_req_set_header(r, name, #name, nil, 0, nil,
0, 1, errmsg)
else
local sval, sval_len, mvals, mvals_len, buf
local value_type = type(value)
if value_type == "table" then
mvals_len = #value
if mvals_len == 0 and not override then
error("bad 'value' argument: non-empty table expected", 3)
end
buf = get_string_buf(ffi_str_size * mvals_len)
mvals = ffi_cast(ffi_str_type, buf)
for i = 1, mvals_len do
local s = value[i]
if type(s) ~= "string" then
s = tostring(s)
value[i] = s
end
local str = mvals[i - 1]
str.data = s
str.len = #s
end
sval_len = 0
else
if value_type ~= "string" then
sval = tostring(value)
else
sval = value
end
sval_len = #sval
mvals_len = 0
end
rc = C.ngx_http_lua_ffi_req_set_header(r, name, #name, sval,
sval_len, mvals, mvals_len,
override and 1 or 0, errmsg)
end
if rc == FFI_OK or rc == FFI_DECLINED then
return
end
if rc == FFI_BAD_CONTEXT then
error("API disabled in the current context", 3)
end
-- rc == FFI_ERROR
error(ffi_str(errmsg[0]))
end
_M.set_req_header = set_req_header
function ngx.req.set_header(name, value)
set_req_header(name, value, true) -- override
end
end -- do
function ngx.req.clear_header(name)
local r = get_request()
if not r then
error("no request found")
end
if type(name) ~= "string" then
name = tostring(name)
end
local rc = C.ngx_http_lua_ffi_req_set_header(r, name, #name, nil, 0, nil, 0,
1, errmsg)
if rc == FFI_OK or rc == FFI_DECLINED then
return
end
if rc == FFI_BAD_CONTEXT then
error("API disabled in the current context", 2)
end
-- rc == FFI_ERROR
error(ffi_str(errmsg[0]))
end
return _M

183
resty/core/response.lua Normal file
View File

@@ -0,0 +1,183 @@
-- Copyright (C) Yichun Zhang (agentzh)
local ffi = require 'ffi'
local base = require "resty.core.base"
local C = ffi.C
local ffi_cast = ffi.cast
local ffi_str = ffi.string
local new_tab = base.new_tab
local FFI_BAD_CONTEXT = base.FFI_BAD_CONTEXT
local FFI_NO_REQ_CTX = base.FFI_NO_REQ_CTX
local FFI_DECLINED = base.FFI_DECLINED
local get_string_buf = base.get_string_buf
local setmetatable = setmetatable
local type = type
local tostring = tostring
local get_request = base.get_request
local error = error
local ngx = ngx
local _M = {
version = base.version
}
local MAX_HEADER_VALUES = 100
local errmsg = base.get_errmsg_ptr()
local ffi_str_type = ffi.typeof("ngx_http_lua_ffi_str_t*")
local ffi_str_size = ffi.sizeof("ngx_http_lua_ffi_str_t")
ffi.cdef[[
int ngx_http_lua_ffi_set_resp_header(ngx_http_request_t *r,
const char *key_data, size_t key_len, int is_nil,
const char *sval, size_t sval_len, ngx_http_lua_ffi_str_t *mvals,
size_t mvals_len, int override, char **errmsg);
int ngx_http_lua_ffi_get_resp_header(ngx_http_request_t *r,
const unsigned char *key, size_t key_len,
unsigned char *key_buf, ngx_http_lua_ffi_str_t *values,
int max_nvalues, char **errmsg);
]]
local function set_resp_header(tb, key, value, no_override)
local r = get_request()
if not r then
error("no request found")
end
if type(key) ~= "string" then
key = tostring(key)
end
local rc
if value == nil then
if no_override then
error("invalid header value", 3)
end
rc = C.ngx_http_lua_ffi_set_resp_header(r, key, #key, true, nil, 0, nil,
0, 1, errmsg)
else
local sval, sval_len, mvals, mvals_len, buf
if type(value) == "table" then
mvals_len = #value
if mvals_len == 0 and no_override then
return
end
buf = get_string_buf(ffi_str_size * mvals_len)
mvals = ffi_cast(ffi_str_type, buf)
for i = 1, mvals_len do
local s = value[i]
if type(s) ~= "string" then
s = tostring(s)
value[i] = s
end
local str = mvals[i - 1]
str.data = s
str.len = #s
end
sval_len = 0
else
if type(value) ~= "string" then
sval = tostring(value)
else
sval = value
end
sval_len = #sval
mvals_len = 0
end
local override_int = no_override and 0 or 1
rc = C.ngx_http_lua_ffi_set_resp_header(r, key, #key, false, sval,
sval_len, mvals, mvals_len,
override_int, errmsg)
end
if rc == 0 or rc == FFI_DECLINED then
return
end
if rc == FFI_NO_REQ_CTX then
error("no request ctx found")
end
if rc == FFI_BAD_CONTEXT then
error("API disabled in the current context", 2)
end
-- rc == FFI_ERROR
error(ffi_str(errmsg[0]), 2)
end
_M.set_resp_header = set_resp_header
local function get_resp_header(tb, key)
local r = get_request()
if not r then
error("no request found")
end
if type(key) ~= "string" then
key = tostring(key)
end
local key_len = #key
local key_buf = get_string_buf(key_len + ffi_str_size * MAX_HEADER_VALUES)
local values = ffi_cast(ffi_str_type, key_buf + key_len)
local n = C.ngx_http_lua_ffi_get_resp_header(r, key, key_len, key_buf,
values, MAX_HEADER_VALUES,
errmsg)
-- print("retval: ", n)
if n == FFI_BAD_CONTEXT then
error("API disabled in the current context", 2)
end
if n == 0 then
return nil
end
if n == 1 then
local v = values[0]
return ffi_str(v.data, v.len)
end
if n > 0 then
local ret = new_tab(n, 0)
for i = 1, n do
local v = values[i - 1]
ret[i] = ffi_str(v.data, v.len)
end
return ret
end
-- n == FFI_ERROR
error(ffi_str(errmsg[0]), 2)
end
do
local mt = new_tab(0, 2)
mt.__newindex = set_resp_header
mt.__index = get_resp_header
ngx.header = setmetatable(new_tab(0, 0), mt)
end
return _M

638
resty/core/shdict.lua Normal file
View File

@@ -0,0 +1,638 @@
-- Copyright (C) Yichun Zhang (agentzh)
local ffi = require 'ffi'
local base = require "resty.core.base"
local _M = {
version = base.version
}
local ngx_shared = ngx.shared
if not ngx_shared then
return _M
end
local ffi_new = ffi.new
local ffi_str = ffi.string
local C = ffi.C
local get_string_buf = base.get_string_buf
local get_string_buf_size = base.get_string_buf_size
local get_size_ptr = base.get_size_ptr
local tonumber = tonumber
local tostring = tostring
local next = next
local type = type
local error = error
local getmetatable = getmetatable
local FFI_DECLINED = base.FFI_DECLINED
local subsystem = ngx.config.subsystem
local ngx_lua_ffi_shdict_get
local ngx_lua_ffi_shdict_incr
local ngx_lua_ffi_shdict_store
local ngx_lua_ffi_shdict_flush_all
local ngx_lua_ffi_shdict_get_ttl
local ngx_lua_ffi_shdict_set_expire
local ngx_lua_ffi_shdict_capacity
local ngx_lua_ffi_shdict_free_space
local ngx_lua_ffi_shdict_udata_to_zone
if subsystem == 'http' then
ffi.cdef[[
int ngx_http_lua_ffi_shdict_get(void *zone, const unsigned char *key,
size_t key_len, int *value_type, unsigned char **str_value_buf,
size_t *str_value_len, double *num_value, int *user_flags,
int get_stale, int *is_stale, char **errmsg);
int ngx_http_lua_ffi_shdict_incr(void *zone, const unsigned char *key,
size_t key_len, double *value, char **err, int has_init,
double init, long init_ttl, int *forcible);
int ngx_http_lua_ffi_shdict_store(void *zone, int op,
const unsigned char *key, size_t key_len, int value_type,
const unsigned char *str_value_buf, size_t str_value_len,
double num_value, long exptime, int user_flags, char **errmsg,
int *forcible);
int ngx_http_lua_ffi_shdict_flush_all(void *zone);
long ngx_http_lua_ffi_shdict_get_ttl(void *zone,
const unsigned char *key, size_t key_len);
int ngx_http_lua_ffi_shdict_set_expire(void *zone,
const unsigned char *key, size_t key_len, long exptime);
size_t ngx_http_lua_ffi_shdict_capacity(void *zone);
void *ngx_http_lua_ffi_shdict_udata_to_zone(void *zone_udata);
]]
ngx_lua_ffi_shdict_get = C.ngx_http_lua_ffi_shdict_get
ngx_lua_ffi_shdict_incr = C.ngx_http_lua_ffi_shdict_incr
ngx_lua_ffi_shdict_store = C.ngx_http_lua_ffi_shdict_store
ngx_lua_ffi_shdict_flush_all = C.ngx_http_lua_ffi_shdict_flush_all
ngx_lua_ffi_shdict_get_ttl = C.ngx_http_lua_ffi_shdict_get_ttl
ngx_lua_ffi_shdict_set_expire = C.ngx_http_lua_ffi_shdict_set_expire
ngx_lua_ffi_shdict_capacity = C.ngx_http_lua_ffi_shdict_capacity
ngx_lua_ffi_shdict_udata_to_zone =
C.ngx_http_lua_ffi_shdict_udata_to_zone
if not pcall(function ()
return C.ngx_http_lua_ffi_shdict_free_space
end)
then
ffi.cdef[[
size_t ngx_http_lua_ffi_shdict_free_space(void *zone);
]]
end
pcall(function ()
ngx_lua_ffi_shdict_free_space = C.ngx_http_lua_ffi_shdict_free_space
end)
elseif subsystem == 'stream' then
ffi.cdef[[
int ngx_stream_lua_ffi_shdict_get(void *zone, const unsigned char *key,
size_t key_len, int *value_type, unsigned char **str_value_buf,
size_t *str_value_len, double *num_value, int *user_flags,
int get_stale, int *is_stale, char **errmsg);
int ngx_stream_lua_ffi_shdict_incr(void *zone, const unsigned char *key,
size_t key_len, double *value, char **err, int has_init,
double init, long init_ttl, int *forcible);
int ngx_stream_lua_ffi_shdict_store(void *zone, int op,
const unsigned char *key, size_t key_len, int value_type,
const unsigned char *str_value_buf, size_t str_value_len,
double num_value, long exptime, int user_flags, char **errmsg,
int *forcible);
int ngx_stream_lua_ffi_shdict_flush_all(void *zone);
long ngx_stream_lua_ffi_shdict_get_ttl(void *zone,
const unsigned char *key, size_t key_len);
int ngx_stream_lua_ffi_shdict_set_expire(void *zone,
const unsigned char *key, size_t key_len, long exptime);
size_t ngx_stream_lua_ffi_shdict_capacity(void *zone);
void *ngx_stream_lua_ffi_shdict_udata_to_zone(void *zone_udata);
]]
ngx_lua_ffi_shdict_get = C.ngx_stream_lua_ffi_shdict_get
ngx_lua_ffi_shdict_incr = C.ngx_stream_lua_ffi_shdict_incr
ngx_lua_ffi_shdict_store = C.ngx_stream_lua_ffi_shdict_store
ngx_lua_ffi_shdict_flush_all = C.ngx_stream_lua_ffi_shdict_flush_all
ngx_lua_ffi_shdict_get_ttl = C.ngx_stream_lua_ffi_shdict_get_ttl
ngx_lua_ffi_shdict_set_expire = C.ngx_stream_lua_ffi_shdict_set_expire
ngx_lua_ffi_shdict_capacity = C.ngx_stream_lua_ffi_shdict_capacity
ngx_lua_ffi_shdict_udata_to_zone =
C.ngx_stream_lua_ffi_shdict_udata_to_zone
if not pcall(function ()
return C.ngx_stream_lua_ffi_shdict_free_space
end)
then
ffi.cdef[[
size_t ngx_stream_lua_ffi_shdict_free_space(void *zone);
]]
end
-- ngx_stream_lua is only compatible with NGINX >= 1.13.6, meaning it
-- cannot lack support for ngx_stream_lua_ffi_shdict_free_space.
ngx_lua_ffi_shdict_free_space = C.ngx_stream_lua_ffi_shdict_free_space
else
error("unknown subsystem: " .. subsystem)
end
if not pcall(function () return C.free end) then
ffi.cdef[[
void free(void *ptr);
]]
end
local value_type = ffi_new("int[1]")
local user_flags = ffi_new("int[1]")
local num_value = ffi_new("double[1]")
local is_stale = ffi_new("int[1]")
local forcible = ffi_new("int[1]")
local str_value_buf = ffi_new("unsigned char *[1]")
local errmsg = base.get_errmsg_ptr()
local function check_zone(zone)
if not zone or type(zone) ~= "table" then
error("bad \"zone\" argument", 3)
end
zone = zone[1]
if type(zone) ~= "userdata" then
error("bad \"zone\" argument", 3)
end
zone = ngx_lua_ffi_shdict_udata_to_zone(zone)
if zone == nil then
error("bad \"zone\" argument", 3)
end
return zone
end
local function shdict_store(zone, op, key, value, exptime, flags)
zone = check_zone(zone)
if not exptime then
exptime = 0
elseif exptime < 0 then
error('bad "exptime" argument', 2)
end
if not flags then
flags = 0
end
if key == nil then
return nil, "nil key"
end
if type(key) ~= "string" then
key = tostring(key)
end
local key_len = #key
if key_len == 0 then
return nil, "empty key"
end
if key_len > 65535 then
return nil, "key too long"
end
local str_val_buf
local str_val_len = 0
local num_val = 0
local valtyp = type(value)
-- print("value type: ", valtyp)
-- print("exptime: ", exptime)
if valtyp == "string" then
valtyp = 4 -- LUA_TSTRING
str_val_buf = value
str_val_len = #value
elseif valtyp == "number" then
valtyp = 3 -- LUA_TNUMBER
num_val = value
elseif value == nil then
valtyp = 0 -- LUA_TNIL
elseif valtyp == "boolean" then
valtyp = 1 -- LUA_TBOOLEAN
num_val = value and 1 or 0
else
return nil, "bad value type"
end
local rc = ngx_lua_ffi_shdict_store(zone, op, key, key_len,
valtyp, str_val_buf,
str_val_len, num_val,
exptime * 1000, flags, errmsg,
forcible)
-- print("rc == ", rc)
if rc == 0 then -- NGX_OK
return true, nil, forcible[0] == 1
end
-- NGX_DECLINED or NGX_ERROR
return false, ffi_str(errmsg[0]), forcible[0] == 1
end
local function shdict_set(zone, key, value, exptime, flags)
return shdict_store(zone, 0, key, value, exptime, flags)
end
local function shdict_safe_set(zone, key, value, exptime, flags)
return shdict_store(zone, 0x0004, key, value, exptime, flags)
end
local function shdict_add(zone, key, value, exptime, flags)
return shdict_store(zone, 0x0001, key, value, exptime, flags)
end
local function shdict_safe_add(zone, key, value, exptime, flags)
return shdict_store(zone, 0x0005, key, value, exptime, flags)
end
local function shdict_replace(zone, key, value, exptime, flags)
return shdict_store(zone, 0x0002, key, value, exptime, flags)
end
local function shdict_delete(zone, key)
return shdict_set(zone, key, nil)
end
local function shdict_get(zone, key)
zone = check_zone(zone)
if key == nil then
return nil, "nil key"
end
if type(key) ~= "string" then
key = tostring(key)
end
local key_len = #key
if key_len == 0 then
return nil, "empty key"
end
if key_len > 65535 then
return nil, "key too long"
end
local size = get_string_buf_size()
local buf = get_string_buf(size)
str_value_buf[0] = buf
local value_len = get_size_ptr()
value_len[0] = size
local rc = ngx_lua_ffi_shdict_get(zone, key, key_len, value_type,
str_value_buf, value_len,
num_value, user_flags, 0,
is_stale, errmsg)
if rc ~= 0 then
if errmsg[0] ~= nil then
return nil, ffi_str(errmsg[0])
end
error("failed to get the key")
end
local typ = value_type[0]
if typ == 0 then -- LUA_TNIL
return nil
end
local flags = tonumber(user_flags[0])
local val
if typ == 4 then -- LUA_TSTRING
if str_value_buf[0] ~= buf then
-- ngx.say("len: ", tonumber(value_len[0]))
buf = str_value_buf[0]
val = ffi_str(buf, value_len[0])
C.free(buf)
else
val = ffi_str(buf, value_len[0])
end
elseif typ == 3 then -- LUA_TNUMBER
val = tonumber(num_value[0])
elseif typ == 1 then -- LUA_TBOOLEAN
val = (tonumber(buf[0]) ~= 0)
else
error("unknown value type: " .. typ)
end
if flags ~= 0 then
return val, flags
end
return val
end
local function shdict_get_stale(zone, key)
zone = check_zone(zone)
if key == nil then
return nil, "nil key"
end
if type(key) ~= "string" then
key = tostring(key)
end
local key_len = #key
if key_len == 0 then
return nil, "empty key"
end
if key_len > 65535 then
return nil, "key too long"
end
local size = get_string_buf_size()
local buf = get_string_buf(size)
str_value_buf[0] = buf
local value_len = get_size_ptr()
value_len[0] = size
local rc = ngx_lua_ffi_shdict_get(zone, key, key_len, value_type,
str_value_buf, value_len,
num_value, user_flags, 1,
is_stale, errmsg)
if rc ~= 0 then
if errmsg[0] ~= nil then
return nil, ffi_str(errmsg[0])
end
error("failed to get the key")
end
local typ = value_type[0]
if typ == 0 then -- LUA_TNIL
return nil
end
local flags = tonumber(user_flags[0])
local val
if typ == 4 then -- LUA_TSTRING
if str_value_buf[0] ~= buf then
-- ngx.say("len: ", tonumber(value_len[0]))
buf = str_value_buf[0]
val = ffi_str(buf, value_len[0])
C.free(buf)
else
val = ffi_str(buf, value_len[0])
end
elseif typ == 3 then -- LUA_TNUMBER
val = tonumber(num_value[0])
elseif typ == 1 then -- LUA_TBOOLEAN
val = (tonumber(buf[0]) ~= 0)
else
error("unknown value type: " .. typ)
end
if flags ~= 0 then
return val, flags, is_stale[0] == 1
end
return val, nil, is_stale[0] == 1
end
local function shdict_incr(zone, key, value, init, init_ttl)
zone = check_zone(zone)
if key == nil then
return nil, "nil key"
end
if type(key) ~= "string" then
key = tostring(key)
end
local key_len = #key
if key_len == 0 then
return nil, "empty key"
end
if key_len > 65535 then
return nil, "key too long"
end
if type(value) ~= "number" then
value = tonumber(value)
end
num_value[0] = value
if init then
local typ = type(init)
if typ ~= "number" then
init = tonumber(init)
if not init then
error("bad init arg: number expected, got " .. typ, 2)
end
end
end
if init_ttl ~= nil then
local typ = type(init_ttl)
if typ ~= "number" then
init_ttl = tonumber(init_ttl)
if not init_ttl then
error("bad init_ttl arg: number expected, got " .. typ, 2)
end
end
if init_ttl < 0 then
error('bad "init_ttl" argument', 2)
end
if not init then
error('must provide "init" when providing "init_ttl"', 2)
end
else
init_ttl = 0
end
local rc = ngx_lua_ffi_shdict_incr(zone, key, key_len, num_value,
errmsg, init and 1 or 0,
init or 0, init_ttl * 1000,
forcible)
if rc ~= 0 then -- ~= NGX_OK
return nil, ffi_str(errmsg[0])
end
if not init then
return tonumber(num_value[0])
end
return tonumber(num_value[0]), nil, forcible[0] == 1
end
local function shdict_flush_all(zone)
zone = check_zone(zone)
ngx_lua_ffi_shdict_flush_all(zone)
end
local function shdict_ttl(zone, key)
zone = check_zone(zone)
if key == nil then
return nil, "nil key"
end
if type(key) ~= "string" then
key = tostring(key)
end
local key_len = #key
if key_len == 0 then
return nil, "empty key"
end
if key_len > 65535 then
return nil, "key too long"
end
local rc = ngx_lua_ffi_shdict_get_ttl(zone, key, key_len)
if rc == FFI_DECLINED then
return nil, "not found"
end
return tonumber(rc) / 1000
end
local function shdict_expire(zone, key, exptime)
zone = check_zone(zone)
if not exptime then
error('bad "exptime" argument', 2)
end
if key == nil then
return nil, "nil key"
end
if type(key) ~= "string" then
key = tostring(key)
end
local key_len = #key
if key_len == 0 then
return nil, "empty key"
end
if key_len > 65535 then
return nil, "key too long"
end
local rc = ngx_lua_ffi_shdict_set_expire(zone, key, key_len,
exptime * 1000)
if rc == FFI_DECLINED then
return nil, "not found"
end
-- NGINX_OK/FFI_OK
return true
end
local function shdict_capacity(zone)
zone = check_zone(zone)
return tonumber(ngx_lua_ffi_shdict_capacity(zone))
end
local shdict_free_space
if ngx_lua_ffi_shdict_free_space then
shdict_free_space = function (zone)
zone = check_zone(zone)
return tonumber(ngx_lua_ffi_shdict_free_space(zone))
end
else
shdict_free_space = function ()
error("'shm:free_space()' not supported in NGINX < 1.11.7", 2)
end
end
local _, dict = next(ngx_shared, nil)
if dict then
local mt = getmetatable(dict)
if mt then
mt = mt.__index
if mt then
mt.get = shdict_get
mt.get_stale = shdict_get_stale
mt.incr = shdict_incr
mt.set = shdict_set
mt.safe_set = shdict_safe_set
mt.add = shdict_add
mt.safe_add = shdict_safe_add
mt.replace = shdict_replace
mt.delete = shdict_delete
mt.flush_all = shdict_flush_all
mt.ttl = shdict_ttl
mt.expire = shdict_expire
mt.capacity = shdict_capacity
mt.free_space = shdict_free_space
end
end
end
return _M

124
resty/core/socket.lua Normal file
View File

@@ -0,0 +1,124 @@
local base = require "resty.core.base"
base.allows_subsystem('http')
local debug = require 'debug'
local ffi = require 'ffi'
local error = error
local tonumber = tonumber
local registry = debug.getregistry()
local ffi_new = ffi.new
local ffi_string = ffi.string
local C = ffi.C
local get_string_buf = base.get_string_buf
local get_size_ptr = base.get_size_ptr
local tostring = tostring
local option_index = {
["keepalive"] = 1,
["reuseaddr"] = 2,
["tcp-nodelay"] = 3,
["sndbuf"] = 4,
["rcvbuf"] = 5,
}
ffi.cdef[[
typedef struct ngx_http_lua_socket_tcp_upstream_s
ngx_http_lua_socket_tcp_upstream_t;
int
ngx_http_lua_ffi_socket_tcp_getoption(ngx_http_lua_socket_tcp_upstream_t *u,
int opt, int *val, unsigned char *err, size_t *errlen);
int
ngx_http_lua_ffi_socket_tcp_setoption(ngx_http_lua_socket_tcp_upstream_t *u,
int opt, int val, unsigned char *err, size_t *errlen);
]]
local output_value_buf = ffi_new("int[1]")
local FFI_OK = base.FFI_OK
local SOCKET_CTX_INDEX = 1
local ERR_BUF_SIZE = 4096
local function get_tcp_socket(cosocket)
local tcp_socket = cosocket[SOCKET_CTX_INDEX]
if not tcp_socket then
error("socket is never created nor connected")
end
return tcp_socket
end
local function getoption(cosocket, option)
local tcp_socket = get_tcp_socket(cosocket)
if option == nil then
return nil, 'missing the "option" argument'
end
if option_index[option] == nil then
return nil, "unsupported option " .. tostring(option)
end
local err = get_string_buf(ERR_BUF_SIZE)
local errlen = get_size_ptr()
errlen[0] = ERR_BUF_SIZE
local rc = C.ngx_http_lua_ffi_socket_tcp_getoption(tcp_socket,
option_index[option],
output_value_buf,
err,
errlen)
if rc ~= FFI_OK then
return nil, ffi_string(err, errlen[0])
end
return tonumber(output_value_buf[0])
end
local function setoption(cosocket, option, value)
local tcp_socket = get_tcp_socket(cosocket)
if option == nil then
return nil, 'missing the "option" argument'
end
if value == nil then
return nil, 'missing the "value" argument'
end
if option_index[option] == nil then
return nil, "unsupported option " .. tostring(option)
end
local err = get_string_buf(ERR_BUF_SIZE)
local errlen = get_size_ptr()
errlen[0] = ERR_BUF_SIZE
local rc = C.ngx_http_lua_ffi_socket_tcp_setoption(tcp_socket,
option_index[option],
value,
err,
errlen)
if rc ~= FFI_OK then
return nil, ffi_string(err, errlen[0])
end
return true
end
do
local method_table = registry.__tcp_cosocket_mt
method_table.getoption = getoption
method_table.setoption = setoption
end
return { version = base.version }

159
resty/core/time.lua Normal file
View File

@@ -0,0 +1,159 @@
-- Copyright (C) Yichun Zhang (agentzh)
local ffi = require 'ffi'
local base = require "resty.core.base"
local error = error
local tonumber = tonumber
local type = type
local C = ffi.C
local ffi_new = ffi.new
local ffi_str = ffi.string
local time_val = ffi_new("long[1]")
local get_string_buf = base.get_string_buf
local ngx = ngx
local FFI_ERROR = base.FFI_ERROR
local subsystem = ngx.config.subsystem
local ngx_lua_ffi_now
local ngx_lua_ffi_time
local ngx_lua_ffi_today
local ngx_lua_ffi_localtime
local ngx_lua_ffi_utctime
local ngx_lua_ffi_update_time
if subsystem == 'http' then
ffi.cdef[[
double ngx_http_lua_ffi_now(void);
long ngx_http_lua_ffi_time(void);
void ngx_http_lua_ffi_today(unsigned char *buf);
void ngx_http_lua_ffi_localtime(unsigned char *buf);
void ngx_http_lua_ffi_utctime(unsigned char *buf);
void ngx_http_lua_ffi_update_time(void);
int ngx_http_lua_ffi_cookie_time(unsigned char *buf, long t);
void ngx_http_lua_ffi_http_time(unsigned char *buf, long t);
void ngx_http_lua_ffi_parse_http_time(const unsigned char *str, size_t len,
long *time);
]]
ngx_lua_ffi_now = C.ngx_http_lua_ffi_now
ngx_lua_ffi_time = C.ngx_http_lua_ffi_time
ngx_lua_ffi_today = C.ngx_http_lua_ffi_today
ngx_lua_ffi_localtime = C.ngx_http_lua_ffi_localtime
ngx_lua_ffi_utctime = C.ngx_http_lua_ffi_utctime
ngx_lua_ffi_update_time = C.ngx_http_lua_ffi_update_time
elseif subsystem == 'stream' then
ffi.cdef[[
double ngx_stream_lua_ffi_now(void);
long ngx_stream_lua_ffi_time(void);
void ngx_stream_lua_ffi_today(unsigned char *buf);
void ngx_stream_lua_ffi_localtime(unsigned char *buf);
void ngx_stream_lua_ffi_utctime(unsigned char *buf);
void ngx_stream_lua_ffi_update_time(void);
]]
ngx_lua_ffi_now = C.ngx_stream_lua_ffi_now
ngx_lua_ffi_time = C.ngx_stream_lua_ffi_time
ngx_lua_ffi_today = C.ngx_stream_lua_ffi_today
ngx_lua_ffi_localtime = C.ngx_stream_lua_ffi_localtime
ngx_lua_ffi_utctime = C.ngx_stream_lua_ffi_utctime
ngx_lua_ffi_update_time = C.ngx_stream_lua_ffi_update_time
end
function ngx.now()
return tonumber(ngx_lua_ffi_now())
end
function ngx.time()
return tonumber(ngx_lua_ffi_time())
end
function ngx.update_time()
ngx_lua_ffi_update_time()
end
function ngx.today()
-- the format of today is 2010-11-19
local today_buf_size = 10
local buf = get_string_buf(today_buf_size)
ngx_lua_ffi_today(buf)
return ffi_str(buf, today_buf_size)
end
function ngx.localtime()
-- the format of localtime is 2010-11-19 20:56:31
local localtime_buf_size = 19
local buf = get_string_buf(localtime_buf_size)
ngx_lua_ffi_localtime(buf)
return ffi_str(buf, localtime_buf_size)
end
function ngx.utctime()
-- the format of utctime is 2010-11-19 20:56:31
local utctime_buf_size = 19
local buf = get_string_buf(utctime_buf_size)
ngx_lua_ffi_utctime(buf)
return ffi_str(buf, utctime_buf_size)
end
if subsystem == 'http' then
function ngx.cookie_time(sec)
if type(sec) ~= "number" then
error("number argument only", 2)
end
-- the format of cookie time is Mon, 28-Sep-2038 06:00:00 GMT
-- or Mon, 28-Sep-18 06:00:00 GMT
local cookie_time_buf_size = 29
local buf = get_string_buf(cookie_time_buf_size)
local used_size = C.ngx_http_lua_ffi_cookie_time(buf, sec)
return ffi_str(buf, used_size)
end
function ngx.http_time(sec)
if type(sec) ~= "number" then
error("number argument only", 2)
end
-- the format of http time is Mon, 28 Sep 1970 06:00:00 GMT
local http_time_buf_size = 29
local buf = get_string_buf(http_time_buf_size)
C.ngx_http_lua_ffi_http_time(buf, sec)
return ffi_str(buf, http_time_buf_size)
end
function ngx.parse_http_time(time_str)
if type(time_str) ~= "string" then
error("string argument only", 2)
end
C.ngx_http_lua_ffi_parse_http_time(time_str, #time_str, time_val)
local res = time_val[0]
if res == FFI_ERROR then
return nil
end
return tonumber(res)
end
end
return {
version = base.version
}

115
resty/core/uri.lua Normal file
View File

@@ -0,0 +1,115 @@
-- Copyright (C) Yichun Zhang (agentzh)
local ffi = require "ffi"
local base = require "resty.core.base"
local C = ffi.C
local ffi_string = ffi.string
local ngx = ngx
local type = type
local error = error
local tostring = tostring
local get_string_buf = base.get_string_buf
local subsystem = ngx.config.subsystem
local ngx_lua_ffi_escape_uri
local ngx_lua_ffi_unescape_uri
local ngx_lua_ffi_uri_escaped_length
local NGX_ESCAPE_URI = 0
local NGX_ESCAPE_URI_COMPONENT = 2
local NGX_ESCAPE_MAIL_AUTH = 6
if subsystem == "http" then
ffi.cdef[[
size_t ngx_http_lua_ffi_uri_escaped_length(const unsigned char *src,
size_t len, int type);
void ngx_http_lua_ffi_escape_uri(const unsigned char *src, size_t len,
unsigned char *dst, int type);
size_t ngx_http_lua_ffi_unescape_uri(const unsigned char *src,
size_t len, unsigned char *dst);
]]
ngx_lua_ffi_escape_uri = C.ngx_http_lua_ffi_escape_uri
ngx_lua_ffi_unescape_uri = C.ngx_http_lua_ffi_unescape_uri
ngx_lua_ffi_uri_escaped_length = C.ngx_http_lua_ffi_uri_escaped_length
elseif subsystem == "stream" then
ffi.cdef[[
size_t ngx_stream_lua_ffi_uri_escaped_length(const unsigned char *src,
size_t len, int type);
void ngx_stream_lua_ffi_escape_uri(const unsigned char *src, size_t len,
unsigned char *dst, int type);
size_t ngx_stream_lua_ffi_unescape_uri(const unsigned char *src,
size_t len, unsigned char *dst);
]]
ngx_lua_ffi_escape_uri = C.ngx_stream_lua_ffi_escape_uri
ngx_lua_ffi_unescape_uri = C.ngx_stream_lua_ffi_unescape_uri
ngx_lua_ffi_uri_escaped_length = C.ngx_stream_lua_ffi_uri_escaped_length
end
ngx.escape_uri = function (s, esc_type)
if type(s) ~= 'string' then
if not s then
s = ''
else
s = tostring(s)
end
end
if esc_type == nil then
esc_type = NGX_ESCAPE_URI_COMPONENT
else
if type(esc_type) ~= 'number' then
error("\"type\" is not a number", 3)
end
if esc_type < NGX_ESCAPE_URI or esc_type > NGX_ESCAPE_MAIL_AUTH then
error("\"type\" " .. esc_type .. " out of range", 3)
end
end
local slen = #s
local dlen = ngx_lua_ffi_uri_escaped_length(s, slen, esc_type)
-- print("dlen: ", tonumber(dlen))
if dlen == slen then
return s
end
local dst = get_string_buf(dlen)
ngx_lua_ffi_escape_uri(s, slen, dst, esc_type)
return ffi_string(dst, dlen)
end
ngx.unescape_uri = function (s)
if type(s) ~= 'string' then
if not s then
s = ''
else
s = tostring(s)
end
end
local slen = #s
local dlen = slen
local dst = get_string_buf(dlen)
dlen = ngx_lua_ffi_unescape_uri(s, slen, dst)
return ffi_string(dst, dlen)
end
return {
version = base.version,
}

46
resty/core/utils.lua Normal file
View File

@@ -0,0 +1,46 @@
-- Copyright (C) Yichun Zhang (agentzh)
local ffi = require "ffi"
local base = require "resty.core.base"
local C = ffi.C
local ffi_str = ffi.string
local ffi_copy = ffi.copy
local byte = string.byte
local str_find = string.find
local get_string_buf = base.get_string_buf
local subsystem = ngx.config.subsystem
local _M = {
version = base.version
}
if subsystem == "http" then
ffi.cdef[[
void ngx_http_lua_ffi_str_replace_char(unsigned char *buf, size_t len,
const unsigned char find, const unsigned char replace);
]]
function _M.str_replace_char(str, find, replace)
if not str_find(str, find, nil, true) then
return str
end
local len = #str
local buf = get_string_buf(len)
ffi_copy(buf, str, len)
C.ngx_http_lua_ffi_str_replace_char(buf, len, byte(find),
byte(replace))
return ffi_str(buf, len)
end
end
return _M

160
resty/core/var.lua Normal file
View File

@@ -0,0 +1,160 @@
-- Copyright (C) Yichun Zhang (agentzh)
local ffi = require "ffi"
local base = require "resty.core.base"
local C = ffi.C
local ffi_new = ffi.new
local ffi_str = ffi.string
local type = type
local error = error
local tostring = tostring
local setmetatable = setmetatable
local get_request = base.get_request
local get_string_buf = base.get_string_buf
local get_size_ptr = base.get_size_ptr
local new_tab = base.new_tab
local subsystem = ngx.config.subsystem
local ngx_lua_ffi_var_get
local ngx_lua_ffi_var_set
local ERR_BUF_SIZE = 256
ngx.var = new_tab(0, 0)
if subsystem == "http" then
ffi.cdef[[
int ngx_http_lua_ffi_var_get(ngx_http_request_t *r,
const char *name_data, size_t name_len, char *lowcase_buf,
int capture_id, char **value, size_t *value_len, char **err);
int ngx_http_lua_ffi_var_set(ngx_http_request_t *r,
const unsigned char *name_data, size_t name_len,
unsigned char *lowcase_buf, const unsigned char *value,
size_t value_len, unsigned char *errbuf, size_t *errlen);
]]
ngx_lua_ffi_var_get = C.ngx_http_lua_ffi_var_get
ngx_lua_ffi_var_set = C.ngx_http_lua_ffi_var_set
elseif subsystem == "stream" then
ffi.cdef[[
int ngx_stream_lua_ffi_var_get(ngx_stream_lua_request_t *r,
const char *name_data, size_t name_len, char *lowcase_buf,
int capture_id, char **value, size_t *value_len, char **err);
int ngx_stream_lua_ffi_var_set(ngx_stream_lua_request_t *r,
const unsigned char *name_data, size_t name_len,
unsigned char *lowcase_buf, const unsigned char *value,
size_t value_len, unsigned char *errbuf, size_t *errlen);
]]
ngx_lua_ffi_var_get = C.ngx_stream_lua_ffi_var_get
ngx_lua_ffi_var_set = C.ngx_stream_lua_ffi_var_set
end
local value_ptr = ffi_new("unsigned char *[1]")
local errmsg = base.get_errmsg_ptr()
local function var_get(self, name)
local r = get_request()
if not r then
error("no request found")
end
local value_len = get_size_ptr()
local rc
if type(name) == "number" then
rc = ngx_lua_ffi_var_get(r, nil, 0, nil, name, value_ptr, value_len,
errmsg)
else
if type(name) ~= "string" then
error("bad variable name", 2)
end
local name_len = #name
local lowcase_buf = get_string_buf(name_len)
rc = ngx_lua_ffi_var_get(r, name, name_len, lowcase_buf, 0, value_ptr,
value_len, errmsg)
end
-- ngx.log(ngx.WARN, "rc = ", rc)
if rc == 0 then -- NGX_OK
return ffi_str(value_ptr[0], value_len[0])
end
if rc == -5 then -- NGX_DECLINED
return nil
end
if rc == -1 then -- NGX_ERROR
error(ffi_str(errmsg[0]), 2)
end
end
local function var_set(self, name, value)
local r = get_request()
if not r then
error("no request found")
end
if type(name) ~= "string" then
error("bad variable name", 2)
end
local name_len = #name
local errlen = get_size_ptr()
errlen[0] = ERR_BUF_SIZE
local lowcase_buf = get_string_buf(name_len + ERR_BUF_SIZE)
local value_len
if value == nil then
value_len = 0
else
if type(value) ~= 'string' then
value = tostring(value)
end
value_len = #value
end
local errbuf = lowcase_buf + name_len
local rc = ngx_lua_ffi_var_set(r, name, name_len, lowcase_buf, value,
value_len, errbuf, errlen)
-- ngx.log(ngx.WARN, "rc = ", rc)
if rc == 0 then -- NGX_OK
return
end
if rc == -1 then -- NGX_ERROR
error(ffi_str(errbuf, errlen[0]), 2)
end
end
do
local mt = new_tab(0, 2)
mt.__index = var_get
mt.__newindex = var_set
setmetatable(ngx.var, mt)
end
return {
version = base.version
}

77
resty/core/worker.lua Normal file
View File

@@ -0,0 +1,77 @@
-- Copyright (C) Yichun Zhang (agentzh)
local ffi = require "ffi"
local base = require "resty.core.base"
local C = ffi.C
local new_tab = base.new_tab
local subsystem = ngx.config.subsystem
local ngx_lua_ffi_worker_id
local ngx_lua_ffi_worker_pid
local ngx_lua_ffi_worker_count
local ngx_lua_ffi_worker_exiting
ngx.worker = new_tab(0, 4)
if subsystem == "http" then
ffi.cdef[[
int ngx_http_lua_ffi_worker_id(void);
int ngx_http_lua_ffi_worker_pid(void);
int ngx_http_lua_ffi_worker_count(void);
int ngx_http_lua_ffi_worker_exiting(void);
]]
ngx_lua_ffi_worker_id = C.ngx_http_lua_ffi_worker_id
ngx_lua_ffi_worker_pid = C.ngx_http_lua_ffi_worker_pid
ngx_lua_ffi_worker_count = C.ngx_http_lua_ffi_worker_count
ngx_lua_ffi_worker_exiting = C.ngx_http_lua_ffi_worker_exiting
elseif subsystem == "stream" then
ffi.cdef[[
int ngx_stream_lua_ffi_worker_id(void);
int ngx_stream_lua_ffi_worker_pid(void);
int ngx_stream_lua_ffi_worker_count(void);
int ngx_stream_lua_ffi_worker_exiting(void);
]]
ngx_lua_ffi_worker_id = C.ngx_stream_lua_ffi_worker_id
ngx_lua_ffi_worker_pid = C.ngx_stream_lua_ffi_worker_pid
ngx_lua_ffi_worker_count = C.ngx_stream_lua_ffi_worker_count
ngx_lua_ffi_worker_exiting = C.ngx_stream_lua_ffi_worker_exiting
end
function ngx.worker.exiting()
return ngx_lua_ffi_worker_exiting() ~= 0
end
function ngx.worker.pid()
return ngx_lua_ffi_worker_pid()
end
function ngx.worker.id()
local id = ngx_lua_ffi_worker_id()
if id < 0 then
return nil
end
return id
end
function ngx.worker.count()
return ngx_lua_ffi_worker_count()
end
return {
_VERSION = base.version
}

982
resty/dns/resolver.lua Normal file
View File

@@ -0,0 +1,982 @@
-- Copyright (C) Yichun Zhang (agentzh)
-- local socket = require "socket"
local bit = require "bit"
local udp = ngx.socket.udp
local rand = math.random
local char = string.char
local byte = string.byte
local find = string.find
local gsub = string.gsub
local sub = string.sub
local rep = string.rep
local format = string.format
local band = bit.band
local rshift = bit.rshift
local lshift = bit.lshift
local insert = table.insert
local concat = table.concat
local re_sub = ngx.re.sub
local tcp = ngx.socket.tcp
local log = ngx.log
local DEBUG = ngx.DEBUG
local unpack = unpack
local setmetatable = setmetatable
local type = type
local ipairs = ipairs
local ok, new_tab = pcall(require, "table.new")
if not ok then
new_tab = function (narr, nrec) return {} end
end
local DOT_CHAR = byte(".")
local ZERO_CHAR = byte("0")
local COLON_CHAR = byte(":")
local IP6_ARPA = "ip6.arpa"
local TYPE_A = 1
local TYPE_NS = 2
local TYPE_CNAME = 5
local TYPE_SOA = 6
local TYPE_PTR = 12
local TYPE_MX = 15
local TYPE_TXT = 16
local TYPE_AAAA = 28
local TYPE_SRV = 33
local TYPE_SPF = 99
local CLASS_IN = 1
local SECTION_AN = 1
local SECTION_NS = 2
local SECTION_AR = 3
local _M = {
_VERSION = '0.22',
TYPE_A = TYPE_A,
TYPE_NS = TYPE_NS,
TYPE_CNAME = TYPE_CNAME,
TYPE_SOA = TYPE_SOA,
TYPE_PTR = TYPE_PTR,
TYPE_MX = TYPE_MX,
TYPE_TXT = TYPE_TXT,
TYPE_AAAA = TYPE_AAAA,
TYPE_SRV = TYPE_SRV,
TYPE_SPF = TYPE_SPF,
CLASS_IN = CLASS_IN,
SECTION_AN = SECTION_AN,
SECTION_NS = SECTION_NS,
SECTION_AR = SECTION_AR
}
local resolver_errstrs = {
"format error", -- 1
"server failure", -- 2
"name error", -- 3
"not implemented", -- 4
"refused", -- 5
}
local soa_int32_fields = { "serial", "refresh", "retry", "expire", "minimum" }
local mt = { __index = _M }
local arpa_tmpl = new_tab(72, 0)
for i = 1, #IP6_ARPA do
arpa_tmpl[64 + i] = byte(IP6_ARPA, i)
end
for i = 2, 64, 2 do
arpa_tmpl[i] = DOT_CHAR
end
function _M.new(class, opts)
if not opts then
return nil, "no options table specified"
end
local servers = opts.nameservers
if not servers or #servers == 0 then
return nil, "no nameservers specified"
end
local timeout = opts.timeout or 2000 -- default 2 sec
local n = #servers
local socks = {}
for i = 1, n do
local server = servers[i]
local sock, err = udp()
if not sock then
return nil, "failed to create udp socket: " .. err
end
local host, port
if type(server) == 'table' then
host = server[1]
port = server[2] or 53
else
host = server
port = 53
servers[i] = {host, port}
end
local ok, err = sock:setpeername(host, port)
if not ok then
return nil, "failed to set peer name: " .. err
end
sock:settimeout(timeout)
insert(socks, sock)
end
local tcp_sock, err = tcp()
if not tcp_sock then
return nil, "failed to create tcp socket: " .. err
end
tcp_sock:settimeout(timeout)
return setmetatable(
{ cur = opts.no_random and 1 or rand(1, n),
socks = socks,
tcp_sock = tcp_sock,
servers = servers,
retrans = opts.retrans or 5,
no_recurse = opts.no_recurse,
}, mt)
end
local function pick_sock(self, socks)
local cur = self.cur
if cur == #socks then
self.cur = 1
else
self.cur = cur + 1
end
return socks[cur]
end
local function _get_cur_server(self)
local cur = self.cur
local servers = self.servers
if cur == 1 then
return servers[#servers]
end
return servers[cur - 1]
end
function _M.set_timeout(self, timeout)
local socks = self.socks
if not socks then
return nil, "not initialized"
end
for i = 1, #socks do
local sock = socks[i]
sock:settimeout(timeout)
end
local tcp_sock = self.tcp_sock
if not tcp_sock then
return nil, "not initialized"
end
tcp_sock:settimeout(timeout)
end
local function _encode_name(s)
return char(#s) .. s
end
local function _decode_name(buf, pos)
local labels = {}
local nptrs = 0
local p = pos
while nptrs < 128 do
local fst = byte(buf, p)
if not fst then
return nil, 'truncated';
end
-- print("fst at ", p, ": ", fst)
if fst == 0 then
if nptrs == 0 then
pos = pos + 1
end
break
end
if band(fst, 0xc0) ~= 0 then
-- being a pointer
if nptrs == 0 then
pos = pos + 2
end
nptrs = nptrs + 1
local snd = byte(buf, p + 1)
if not snd then
return nil, 'truncated'
end
p = lshift(band(fst, 0x3f), 8) + snd + 1
-- print("resolving ptr ", p, ": ", byte(buf, p))
else
-- being a label
local label = sub(buf, p + 1, p + fst)
insert(labels, label)
-- print("resolved label ", label)
p = p + fst + 1
if nptrs == 0 then
pos = p
end
end
end
return concat(labels, "."), pos
end
local function _build_request(qname, id, no_recurse, opts)
local qtype
if opts then
qtype = opts.qtype
end
if not qtype then
qtype = 1 -- A record
end
local ident_hi = char(rshift(id, 8))
local ident_lo = char(band(id, 0xff))
local flags
if no_recurse then
-- print("found no recurse")
flags = "\0\0"
else
flags = "\1\0"
end
local nqs = "\0\1"
local nan = "\0\0"
local nns = "\0\0"
local nar = "\0\0"
local typ = char(rshift(qtype, 8), band(qtype, 0xff))
local class = "\0\1" -- the Internet class
if byte(qname, 1) == DOT_CHAR then
return nil, "bad name"
end
local name = gsub(qname, "([^.]+)%.?", _encode_name) .. '\0'
return {
ident_hi, ident_lo, flags, nqs, nan, nns, nar,
name, typ, class
}
end
local function parse_section(answers, section, buf, start_pos, size,
should_skip)
local pos = start_pos
for _ = 1, size do
-- print(format("ans %d: qtype:%d qclass:%d", i, qtype, qclass))
local ans = {}
if not should_skip then
insert(answers, ans)
end
ans.section = section
local name
name, pos = _decode_name(buf, pos)
if not name then
return nil, pos
end
ans.name = name
-- print("name: ", name)
local type_hi = byte(buf, pos)
local type_lo = byte(buf, pos + 1)
local typ = lshift(type_hi, 8) + type_lo
ans.type = typ
-- print("type: ", typ)
local class_hi = byte(buf, pos + 2)
local class_lo = byte(buf, pos + 3)
local class = lshift(class_hi, 8) + class_lo
ans.class = class
-- print("class: ", class)
local byte_1, byte_2, byte_3, byte_4 = byte(buf, pos + 4, pos + 7)
local ttl = lshift(byte_1, 24) + lshift(byte_2, 16)
+ lshift(byte_3, 8) + byte_4
-- print("ttl: ", ttl)
ans.ttl = ttl
local len_hi = byte(buf, pos + 8)
local len_lo = byte(buf, pos + 9)
local len = lshift(len_hi, 8) + len_lo
-- print("record len: ", len)
pos = pos + 10
if typ == TYPE_A then
if len ~= 4 then
return nil, "bad A record value length: " .. len
end
local addr_bytes = { byte(buf, pos, pos + 3) }
local addr = concat(addr_bytes, ".")
-- print("ipv4 address: ", addr)
ans.address = addr
pos = pos + 4
elseif typ == TYPE_CNAME then
local cname, p = _decode_name(buf, pos)
if not cname then
return nil, pos
end
if p - pos ~= len then
return nil, format("bad cname record length: %d ~= %d",
p - pos, len)
end
pos = p
-- print("cname: ", cname)
ans.cname = cname
elseif typ == TYPE_AAAA then
if len ~= 16 then
return nil, "bad AAAA record value length: " .. len
end
local addr_bytes = { byte(buf, pos, pos + 15) }
local flds = {}
for i = 1, 16, 2 do
local a = addr_bytes[i]
local b = addr_bytes[i + 1]
if a == 0 then
insert(flds, format("%x", b))
else
insert(flds, format("%x%02x", a, b))
end
end
-- we do not compress the IPv6 addresses by default
-- due to performance considerations
ans.address = concat(flds, ":")
pos = pos + 16
elseif typ == TYPE_MX then
-- print("len = ", len)
if len < 3 then
return nil, "bad MX record value length: " .. len
end
local pref_hi = byte(buf, pos)
local pref_lo = byte(buf, pos + 1)
ans.preference = lshift(pref_hi, 8) + pref_lo
local host, p = _decode_name(buf, pos + 2)
if not host then
return nil, pos
end
if p - pos ~= len then
return nil, format("bad cname record length: %d ~= %d",
p - pos, len)
end
ans.exchange = host
pos = p
elseif typ == TYPE_SRV then
if len < 7 then
return nil, "bad SRV record value length: " .. len
end
local prio_hi = byte(buf, pos)
local prio_lo = byte(buf, pos + 1)
ans.priority = lshift(prio_hi, 8) + prio_lo
local weight_hi = byte(buf, pos + 2)
local weight_lo = byte(buf, pos + 3)
ans.weight = lshift(weight_hi, 8) + weight_lo
local port_hi = byte(buf, pos + 4)
local port_lo = byte(buf, pos + 5)
ans.port = lshift(port_hi, 8) + port_lo
local name, p = _decode_name(buf, pos + 6)
if not name then
return nil, pos
end
if p - pos ~= len then
return nil, format("bad srv record length: %d ~= %d",
p - pos, len)
end
ans.target = name
pos = p
elseif typ == TYPE_NS then
local name, p = _decode_name(buf, pos)
if not name then
return nil, pos
end
if p - pos ~= len then
return nil, format("bad cname record length: %d ~= %d",
p - pos, len)
end
pos = p
-- print("name: ", name)
ans.nsdname = name
elseif typ == TYPE_TXT or typ == TYPE_SPF then
local key = (typ == TYPE_TXT) and "txt" or "spf"
local slen = byte(buf, pos)
if slen + 1 > len then
-- truncate the over-run TXT record data
slen = len
end
-- print("slen: ", len)
local val = sub(buf, pos + 1, pos + slen)
local last = pos + len
pos = pos + slen + 1
if pos < last then
-- more strings to be processed
-- this code path is usually cold, so we do not
-- merge the following loop on this code path
-- with the processing logic above.
val = {val}
local idx = 2
repeat
local slen = byte(buf, pos)
if pos + slen + 1 > last then
-- truncate the over-run TXT record data
slen = last - pos - 1
end
val[idx] = sub(buf, pos + 1, pos + slen)
idx = idx + 1
pos = pos + slen + 1
until pos >= last
end
ans[key] = val
elseif typ == TYPE_PTR then
local name, p = _decode_name(buf, pos)
if not name then
return nil, pos
end
if p - pos ~= len then
return nil, format("bad cname record length: %d ~= %d",
p - pos, len)
end
pos = p
-- print("name: ", name)
ans.ptrdname = name
elseif typ == TYPE_SOA then
local name, p = _decode_name(buf, pos)
if not name then
return nil, pos
end
ans.mname = name
pos = p
name, p = _decode_name(buf, pos)
if not name then
return nil, pos
end
ans.rname = name
for _, field in ipairs(soa_int32_fields) do
local byte_1, byte_2, byte_3, byte_4 = byte(buf, p, p + 3)
ans[field] = lshift(byte_1, 24) + lshift(byte_2, 16)
+ lshift(byte_3, 8) + byte_4
p = p + 4
end
pos = p
else
-- for unknown types, just forward the raw value
ans.rdata = sub(buf, pos, pos + len - 1)
pos = pos + len
end
end
return pos
end
local function parse_response(buf, id, opts)
local n = #buf
if n < 12 then
return nil, 'truncated';
end
-- header layout: ident flags nqs nan nns nar
local ident_hi = byte(buf, 1)
local ident_lo = byte(buf, 2)
local ans_id = lshift(ident_hi, 8) + ident_lo
-- print("id: ", id, ", ans id: ", ans_id)
if ans_id ~= id then
-- identifier mismatch and throw it away
log(DEBUG, "id mismatch in the DNS reply: ", ans_id, " ~= ", id)
return nil, "id mismatch"
end
local flags_hi = byte(buf, 3)
local flags_lo = byte(buf, 4)
local flags = lshift(flags_hi, 8) + flags_lo
-- print(format("flags: 0x%x", flags))
if band(flags, 0x8000) == 0 then
return nil, format("bad QR flag in the DNS response")
end
if band(flags, 0x200) ~= 0 then
return nil, "truncated"
end
local code = band(flags, 0xf)
-- print(format("code: %d", code))
local nqs_hi = byte(buf, 5)
local nqs_lo = byte(buf, 6)
local nqs = lshift(nqs_hi, 8) + nqs_lo
-- print("nqs: ", nqs)
if nqs ~= 1 then
return nil, format("bad number of questions in DNS response: %d", nqs)
end
local nan_hi = byte(buf, 7)
local nan_lo = byte(buf, 8)
local nan = lshift(nan_hi, 8) + nan_lo
-- print("nan: ", nan)
local nns_hi = byte(buf, 9)
local nns_lo = byte(buf, 10)
local nns = lshift(nns_hi, 8) + nns_lo
local nar_hi = byte(buf, 11)
local nar_lo = byte(buf, 12)
local nar = lshift(nar_hi, 8) + nar_lo
-- skip the question part
local ans_qname, pos = _decode_name(buf, 13)
if not ans_qname then
return nil, pos
end
-- print("qname in reply: ", ans_qname)
-- print("question: ", sub(buf, 13, pos))
if pos + 3 + nan * 12 > n then
-- print(format("%d > %d", pos + 3 + nan * 12, n))
return nil, 'truncated';
end
-- question section layout: qname qtype(2) qclass(2)
--[[
local type_hi = byte(buf, pos)
local type_lo = byte(buf, pos + 1)
local ans_type = lshift(type_hi, 8) + type_lo
]]
-- print("ans qtype: ", ans_type)
local class_hi = byte(buf, pos + 2)
local class_lo = byte(buf, pos + 3)
local qclass = lshift(class_hi, 8) + class_lo
-- print("ans qclass: ", qclass)
if qclass ~= 1 then
return nil, format("unknown query class %d in DNS response", qclass)
end
pos = pos + 4
local answers = {}
if code ~= 0 then
answers.errcode = code
answers.errstr = resolver_errstrs[code] or "unknown"
end
local authority_section, additional_section
if opts then
authority_section = opts.authority_section
additional_section = opts.additional_section
if opts.qtype == TYPE_SOA then
authority_section = true
end
end
local err
pos, err = parse_section(answers, SECTION_AN, buf, pos, nan)
if not pos then
return nil, err
end
if not authority_section and not additional_section then
return answers
end
pos, err = parse_section(answers, SECTION_NS, buf, pos, nns,
not authority_section)
if not pos then
return nil, err
end
if not additional_section then
return answers
end
pos, err = parse_section(answers, SECTION_AR, buf, pos, nar)
if not pos then
return nil, err
end
return answers
end
local function _gen_id(self)
local id = self._id -- for regression testing
if id then
return id
end
return rand(0, 65535) -- two bytes
end
local function _tcp_query(self, query, id, opts)
local sock = self.tcp_sock
if not sock then
return nil, "not initialized"
end
log(DEBUG, "query the TCP server due to reply truncation")
local server = _get_cur_server(self)
local ok, err = sock:connect(server[1], server[2])
if not ok then
return nil, "failed to connect to TCP server "
.. concat(server, ":") .. ": " .. err
end
query = concat(query, "")
local len = #query
local len_hi = char(rshift(len, 8))
local len_lo = char(band(len, 0xff))
local bytes, err = sock:send({len_hi, len_lo, query})
if not bytes then
return nil, "failed to send query to TCP server "
.. concat(server, ":") .. ": " .. err
end
local buf, err = sock:receive(2)
if not buf then
return nil, "failed to receive the reply length field from TCP server "
.. concat(server, ":") .. ": " .. err
end
len_hi = byte(buf, 1)
len_lo = byte(buf, 2)
len = lshift(len_hi, 8) + len_lo
-- print("tcp message len: ", len)
buf, err = sock:receive(len)
if not buf then
return nil, "failed to receive the reply message body from TCP server "
.. concat(server, ":") .. ": " .. err
end
local answers, err = parse_response(buf, id, opts)
if not answers then
return nil, "failed to parse the reply from the TCP server "
.. concat(server, ":") .. ": " .. err
end
sock:close()
return answers
end
function _M.tcp_query(self, qname, opts)
local socks = self.socks
if not socks then
return nil, "not initialized"
end
pick_sock(self, socks)
local id = _gen_id(self)
local query, err = _build_request(qname, id, self.no_recurse, opts)
if not query then
return nil, err
end
return _tcp_query(self, query, id, opts)
end
function _M.query(self, qname, opts, tries)
local socks = self.socks
if not socks then
return nil, "not initialized"
end
local id = _gen_id(self)
local query, err = _build_request(qname, id, self.no_recurse, opts)
if not query then
return nil, err
end
-- local cjson = require "cjson"
-- print("query: ", cjson.encode(concat(query, "")))
local retrans = self.retrans
if tries then
tries[1] = nil
end
-- print("retrans: ", retrans)
for i = 1, retrans do
local sock = pick_sock(self, socks)
local ok
ok, err = sock:send(query)
if not ok then
local server = _get_cur_server(self)
err = "failed to send request to UDP server "
.. concat(server, ":") .. ": " .. err
else
local buf
for _ = 1, 128 do
buf, err = sock:receive(4096)
if err then
local server = _get_cur_server(self)
err = "failed to receive reply from UDP server "
.. concat(server, ":") .. ": " .. err
break
end
if buf then
local answers
answers, err = parse_response(buf, id, opts)
if err == "truncated" then
answers, err = _tcp_query(self, query, id, opts)
end
if err and err ~= "id mismatch" then
break
end
if answers then
return answers, nil, tries
end
end
-- only here in case of an "id mismatch"
end
end
if tries then
tries[i] = err
tries[i + 1] = nil -- ensure termination for user supplied table
end
end
return nil, err, tries
end
function _M.compress_ipv6_addr(addr)
local addr = re_sub(addr, "^(0:)+|(:0)+$|:(0:)+", "::", "jo")
if addr == "::0" then
addr = "::"
end
return addr
end
local function _expand_ipv6_addr(addr)
if find(addr, "::", 1, true) then
local ncol, addrlen = 8, #addr
for i = 1, addrlen do
if byte(addr, i) == COLON_CHAR then
ncol = ncol - 1
end
end
if byte(addr, 1) == COLON_CHAR then
addr = "0" .. addr
end
if byte(addr, -1) == COLON_CHAR then
addr = addr .. "0"
end
addr = re_sub(addr, "::", ":" .. rep("0:", ncol), "jo")
end
return addr
end
_M.expand_ipv6_addr = _expand_ipv6_addr
function _M.arpa_str(addr)
if find(addr, ":", 1, true) then
addr = _expand_ipv6_addr(addr)
local idx, hidx, addrlen = 1, 1, #addr
for i = addrlen, 0, -1 do
local s = byte(addr, i)
if s == COLON_CHAR or not s then
for _ = hidx, 4 do
arpa_tmpl[idx] = ZERO_CHAR
idx = idx + 2
end
hidx = 1
else
arpa_tmpl[idx] = s
idx = idx + 2
hidx = hidx + 1
end
end
addr = char(unpack(arpa_tmpl))
else
addr = re_sub(addr, [[(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})]],
"$4.$3.$2.$1.in-addr.arpa", "ajo")
end
return addr
end
function _M.reverse_query(self, addr)
return self.query(self, self.arpa_str(addr),
{qtype = self.TYPE_PTR})
end
return _M

125
resty/limit/conn.lua Normal file
View File

@@ -0,0 +1,125 @@
-- Copyright (C) Yichun Zhang (agentzh)
--
-- This library is an enhanced Lua port of the standard ngx_limit_conn
-- module.
local math = require "math"
local setmetatable = setmetatable
local floor = math.floor
local ngx_shared = ngx.shared
local assert = assert
local _M = {
_VERSION = '0.07'
}
local mt = {
__index = _M
}
function _M.new(dict_name, max, burst, default_conn_delay)
local dict = ngx_shared[dict_name]
if not dict then
return nil, "shared dict not found"
end
assert(max > 0 and burst >= 0 and default_conn_delay > 0)
local self = {
dict = dict,
max = max + 0, -- just to ensure the param is good
burst = burst,
unit_delay = default_conn_delay,
}
return setmetatable(self, mt)
end
function _M.incoming(self, key, commit)
local dict = self.dict
local max = self.max
self.committed = false
local conn, err
if commit then
conn, err = dict:incr(key, 1, 0)
if not conn then
return nil, err
end
if conn > max + self.burst then
conn, err = dict:incr(key, -1)
if not conn then
return nil, err
end
return nil, "rejected"
end
self.committed = true
else
conn = (dict:get(key) or 0) + 1
if conn > max + self.burst then
return nil, "rejected"
end
end
if conn > max then
-- make the exessive connections wait
return self.unit_delay * floor((conn - 1) / max), conn
end
-- we return a 0 delay by default
return 0, conn
end
function _M.is_committed(self)
return self.committed
end
function _M.leaving(self, key, req_latency)
assert(key)
local dict = self.dict
local conn, err = dict:incr(key, -1)
if not conn then
return nil, err
end
if req_latency then
local unit_delay = self.unit_delay
self.unit_delay = (req_latency + unit_delay) / 2
end
return conn
end
function _M.uncommit(self, key)
assert(key)
local dict = self.dict
return dict:incr(key, -1)
end
function _M.set_conn(self, conn)
self.max = conn
end
function _M.set_burst(self, burst)
self.burst = burst
end
return _M

103
resty/limit/count.lua Normal file
View File

@@ -0,0 +1,103 @@
-- implement GitHub request rate limiting:
-- https://developer.github.com/v3/#rate-limiting
local ngx_shared = ngx.shared
local setmetatable = setmetatable
local assert = assert
local _M = {
_VERSION = '0.07'
}
local mt = {
__index = _M
}
-- the "limit" argument controls number of request allowed in a time window.
-- time "window" argument controls the time window in seconds.
function _M.new(dict_name, limit, window)
local dict = ngx_shared[dict_name]
if not dict then
return nil, "shared dict not found"
end
assert(limit > 0 and window > 0)
local self = {
dict = dict,
limit = limit,
window = window,
}
return setmetatable(self, mt)
end
function _M.incoming(self, key, commit)
local dict = self.dict
local limit = self.limit
local window = self.window
local remaining, ok, err
if commit then
remaining, err = dict:incr(key, -1, limit)
if not remaining then
return nil, err
end
if remaining == limit - 1 then
ok, err = dict:expire(key, window)
if not ok then
if err == "not found" then
remaining, err = dict:incr(key, -1, limit)
if not remaining then
return nil, err
end
ok, err = dict:expire(key, window)
if not ok then
return nil, err
end
else
return nil, err
end
end
end
else
remaining = (dict:get(key) or limit) - 1
end
if remaining < 0 then
return nil, "rejected"
end
return 0, remaining
end
-- uncommit remaining and return remaining value
function _M.uncommit(self, key)
assert(key)
local dict = self.dict
local limit = self.limit
local remaining, err = dict:incr(key, 1)
if not remaining then
if err == "not found" then
remaining = limit
else
return nil, err
end
end
return remaining
end
return _M

153
resty/limit/req.lua Normal file
View File

@@ -0,0 +1,153 @@
-- Copyright (C) Yichun Zhang (agentzh)
--
-- This library is an approximate Lua port of the standard ngx_limit_req
-- module.
local ffi = require "ffi"
local math = require "math"
local ngx_shared = ngx.shared
local ngx_now = ngx.now
local setmetatable = setmetatable
local ffi_cast = ffi.cast
local ffi_str = ffi.string
local abs = math.abs
local tonumber = tonumber
local type = type
local assert = assert
local max = math.max
-- TODO: we could avoid the tricky FFI cdata when lua_shared_dict supports
-- hash-typed values as in redis.
ffi.cdef[[
struct lua_resty_limit_req_rec {
unsigned long excess;
uint64_t last; /* time in milliseconds */
/* integer value, 1 corresponds to 0.001 r/s */
};
]]
local const_rec_ptr_type = ffi.typeof("const struct lua_resty_limit_req_rec*")
local rec_size = ffi.sizeof("struct lua_resty_limit_req_rec")
-- we can share the cdata here since we only need it temporarily for
-- serialization inside the shared dict:
local rec_cdata = ffi.new("struct lua_resty_limit_req_rec")
local _M = {
_VERSION = '0.07'
}
local mt = {
__index = _M
}
function _M.new(dict_name, rate, burst)
local dict = ngx_shared[dict_name]
if not dict then
return nil, "shared dict not found"
end
assert(rate > 0 and burst >= 0)
local self = {
dict = dict,
rate = rate * 1000,
burst = burst * 1000,
}
return setmetatable(self, mt)
end
-- sees an new incoming event
-- the "commit" argument controls whether should we record the event in shm.
-- FIXME we have a (small) race-condition window between dict:get() and
-- dict:set() across multiple nginx worker processes. The size of the
-- window is proportional to the number of workers.
function _M.incoming(self, key, commit)
local dict = self.dict
local rate = self.rate
local now = ngx_now() * 1000
local excess
-- it's important to anchor the string value for the read-only pointer
-- cdata:
local v = dict:get(key)
if v then
if type(v) ~= "string" or #v ~= rec_size then
return nil, "shdict abused by other users"
end
local rec = ffi_cast(const_rec_ptr_type, v)
local elapsed = now - tonumber(rec.last)
-- print("elapsed: ", elapsed, "ms")
-- we do not handle changing rate values specifically. the excess value
-- can get automatically adjusted by the following formula with new rate
-- values rather quickly anyway.
excess = max(tonumber(rec.excess) - rate * abs(elapsed) / 1000 + 1000,
0)
-- print("excess: ", excess)
if excess > self.burst then
return nil, "rejected"
end
else
excess = 0
end
if commit then
rec_cdata.excess = excess
rec_cdata.last = now
dict:set(key, ffi_str(rec_cdata, rec_size))
end
-- return the delay in seconds, as well as excess
return excess / rate, excess / 1000
end
function _M.uncommit(self, key)
assert(key)
local dict = self.dict
local v = dict:get(key)
if not v then
return nil, "not found"
end
if type(v) ~= "string" or #v ~= rec_size then
return nil, "shdict abused by other users"
end
local rec = ffi_cast(const_rec_ptr_type, v)
local excess = max(tonumber(rec.excess) - 1000, 0)
rec_cdata.excess = excess
rec_cdata.last = rec.last
dict:set(key, ffi_str(rec_cdata, rec_size))
return true
end
function _M.set_rate(self, rate)
self.rate = rate * 1000
end
function _M.set_burst(self, burst)
self.burst = burst * 1000
end
return _M

58
resty/limit/traffic.lua Normal file
View File

@@ -0,0 +1,58 @@
-- Copyright (C) Yichun Zhang (agentzh)
--
-- This is an aggregator for various concrete traffic limiter instances
-- (like instances of the resty.limit.req, resty.limit.count and
-- resty.limit.conn classes).
local max = math.max
local _M = {
_VERSION = '0.07'
}
-- the states table is user supplied. each element stores the 2nd return value
-- of each limiter if there is no error returned. for resty.limit.req, the state
-- is the "excess" value (i.e., the number of excessive requests each second),
-- and for resty.limit.conn, the state is the current concurrency level
-- (including the current new connection).
function _M.combine(limiters, keys, states)
local n = #limiters
local max_delay = 0
for i = 1, n do
local lim = limiters[i]
local delay, err = lim:incoming(keys[i], i == n)
if not delay then
return nil, err
end
if i == n then
if states then
states[i] = err
end
max_delay = delay
end
end
for i = 1, n - 1 do
local lim = limiters[i]
local delay, err = lim:incoming(keys[i], true)
if not delay then
for j = 1, i - 1 do
-- we intentionally ignore any errors returned below.
limiters[j]:uncommit(keys[j])
end
limiters[n]:uncommit(keys[n])
return nil, err
end
if states then
states[i] = err
end
max_delay = max(max_delay, delay)
end
return max_delay
end
return _M

221
resty/lock.lua Normal file
View File

@@ -0,0 +1,221 @@
-- Copyright (C) Yichun Zhang (agentzh)
require "resty.core.shdict" -- enforce this to avoid dead locks
local ffi = require "ffi"
local ffi_new = ffi.new
local shared = ngx.shared
local sleep = ngx.sleep
local log = ngx.log
local max = math.max
local min = math.min
local debug = ngx.config.debug
local setmetatable = setmetatable
local tonumber = tonumber
local _M = { _VERSION = '0.08' }
local mt = { __index = _M }
local ERR = ngx.ERR
local FREE_LIST_REF = 0
-- FIXME: we don't need this when we have __gc metamethod support on Lua
-- tables.
local memo = {}
if debug then _M.memo = memo end
local function ref_obj(key)
if key == nil then
return -1
end
local ref = memo[FREE_LIST_REF]
if ref and ref ~= 0 then
memo[FREE_LIST_REF] = memo[ref]
else
ref = #memo + 1
end
memo[ref] = key
-- print("ref key_id returned ", ref)
return ref
end
if debug then _M.ref_obj = ref_obj end
local function unref_obj(ref)
if ref >= 0 then
memo[ref] = memo[FREE_LIST_REF]
memo[FREE_LIST_REF] = ref
end
end
if debug then _M.unref_obj = unref_obj end
local function gc_lock(cdata)
local dict_id = tonumber(cdata.dict_id)
local key_id = tonumber(cdata.key_id)
-- print("key_id: ", key_id, ", key: ", memo[key_id], "dict: ",
-- type(memo[cdata.dict_id]))
if key_id > 0 then
local key = memo[key_id]
unref_obj(key_id)
local dict = memo[dict_id]
-- print("dict.delete type: ", type(dict.delete))
local ok, err = dict:delete(key)
if not ok then
log(ERR, 'failed to delete key "', key, '": ', err)
end
cdata.key_id = 0
end
unref_obj(dict_id)
end
local ctype = ffi.metatype("struct { int key_id; int dict_id; }",
{ __gc = gc_lock })
function _M.new(_, dict_name, opts)
local dict = shared[dict_name]
if not dict then
return nil, "dictionary not found"
end
local cdata = ffi_new(ctype)
cdata.key_id = 0
cdata.dict_id = ref_obj(dict)
local timeout, exptime, step, ratio, max_step
if opts then
timeout = opts.timeout
exptime = opts.exptime
step = opts.step
ratio = opts.ratio
max_step = opts.max_step
end
if not exptime then
exptime = 30
end
if timeout then
timeout = min(timeout, exptime)
if step then
step = min(step, timeout)
end
end
local self = {
cdata = cdata,
dict = dict,
timeout = timeout or 5,
exptime = exptime,
step = step or 0.001,
ratio = ratio or 2,
max_step = max_step or 0.5,
}
return setmetatable(self, mt)
end
function _M.lock(self, key)
if not key then
return nil, "nil key"
end
local dict = self.dict
local cdata = self.cdata
if cdata.key_id > 0 then
return nil, "locked"
end
local exptime = self.exptime
local ok, err = dict:add(key, true, exptime)
if ok then
cdata.key_id = ref_obj(key)
self.key = key
return 0
end
if err ~= "exists" then
return nil, err
end
-- lock held by others
local step = self.step
local ratio = self.ratio
local timeout = self.timeout
local max_step = self.max_step
local elapsed = 0
while timeout > 0 do
sleep(step)
elapsed = elapsed + step
timeout = timeout - step
local ok, err = dict:add(key, true, exptime)
if ok then
cdata.key_id = ref_obj(key)
self.key = key
return elapsed
end
if err ~= "exists" then
return nil, err
end
if timeout <= 0 then
break
end
step = min(max(0.001, step * ratio), timeout, max_step)
end
return nil, "timeout"
end
function _M.unlock(self)
local dict = self.dict
local cdata = self.cdata
local key_id = tonumber(cdata.key_id)
if key_id <= 0 then
return nil, "unlocked"
end
local key = memo[key_id]
unref_obj(key_id)
local ok, err = dict:delete(key)
if not ok then
return nil, err
end
cdata.key_id = 0
return 1
end
function _M.expire(self, time)
local dict = self.dict
local cdata = self.cdata
local key_id = tonumber(cdata.key_id)
if key_id <= 0 then
return nil, "unlocked"
end
if not time then
time = self.exptime
end
local ok, err = dict:replace(self.key, true, time)
if not ok then
return nil, err
end
return true
end
return _M

340
resty/lrucache.lua Normal file
View File

@@ -0,0 +1,340 @@
-- Copyright (C) Yichun Zhang (agentzh)
local ffi = require "ffi"
local ffi_new = ffi.new
local ffi_sizeof = ffi.sizeof
local ffi_cast = ffi.cast
local ffi_fill = ffi.fill
local ngx_now = ngx.now
local uintptr_t = ffi.typeof("uintptr_t")
local setmetatable = setmetatable
local tonumber = tonumber
local type = type
local new_tab
do
local ok
ok, new_tab = pcall(require, "table.new")
if not ok then
new_tab = function(narr, nrec) return {} end
end
end
if string.find(jit.version, " 2.0", 1, true) then
ngx.log(ngx.ALERT, "use of lua-resty-lrucache with LuaJIT 2.0 is ",
"not recommended; use LuaJIT 2.1+ instead")
end
local ok, tb_clear = pcall(require, "table.clear")
if not ok then
local pairs = pairs
tb_clear = function (tab)
for k, _ in pairs(tab) do
tab[k] = nil
end
end
end
-- queue data types
--
-- this queue is a double-ended queue and the first node
-- is reserved for the queue itself.
-- the implementation is mostly borrowed from nginx's ngx_queue_t data
-- structure.
ffi.cdef[[
typedef struct lrucache_queue_s lrucache_queue_t;
struct lrucache_queue_s {
double expire; /* in seconds */
lrucache_queue_t *prev;
lrucache_queue_t *next;
uint32_t user_flags;
};
]]
local queue_arr_type = ffi.typeof("lrucache_queue_t[?]")
local queue_type = ffi.typeof("lrucache_queue_t")
local NULL = ffi.null
-- queue utility functions
local function queue_insert_tail(h, x)
local last = h[0].prev
x.prev = last
last.next = x
x.next = h
h[0].prev = x
end
local function queue_init(size)
if not size then
size = 0
end
local q = ffi_new(queue_arr_type, size + 1)
ffi_fill(q, ffi_sizeof(queue_type, size + 1), 0)
if size == 0 then
q[0].prev = q
q[0].next = q
else
local prev = q[0]
for i = 1, size do
local e = q + i
e.user_flags = 0
prev.next = e
e.prev = prev
prev = e
end
local last = q[size]
last.next = q
q[0].prev = last
end
return q
end
local function queue_is_empty(q)
-- print("q: ", tostring(q), "q.prev: ", tostring(q), ": ", q == q.prev)
return q == q[0].prev
end
local function queue_remove(x)
local prev = x.prev
local next = x.next
next.prev = prev
prev.next = next
-- for debugging purpose only:
x.prev = NULL
x.next = NULL
end
local function queue_insert_head(h, x)
x.next = h[0].next
x.next.prev = x
x.prev = h
h[0].next = x
end
local function queue_last(h)
return h[0].prev
end
local function queue_head(h)
return h[0].next
end
-- true module stuffs
local _M = {
_VERSION = '0.11'
}
local mt = { __index = _M }
local function ptr2num(ptr)
return tonumber(ffi_cast(uintptr_t, ptr))
end
function _M.new(size)
if size < 1 then
return nil, "size too small"
end
local self = {
hasht = {},
free_queue = queue_init(size),
cache_queue = queue_init(),
key2node = {},
node2key = {},
num_items = 0,
max_items = size,
}
return setmetatable(self, mt)
end
function _M.count(self)
return self.num_items
end
function _M.capacity(self)
return self.max_items
end
function _M.get(self, key)
local hasht = self.hasht
local val = hasht[key]
if val == nil then
return nil
end
local node = self.key2node[key]
-- print(key, ": moving node ", tostring(node), " to cache queue head")
local cache_queue = self.cache_queue
queue_remove(node)
queue_insert_head(cache_queue, node)
if node.expire >= 0 and node.expire < ngx_now() then
-- print("expired: ", node.expire, " > ", ngx_now())
return nil, val, node.user_flags
end
return val, nil, node.user_flags
end
function _M.delete(self, key)
self.hasht[key] = nil
local key2node = self.key2node
local node = key2node[key]
if not node then
return false
end
key2node[key] = nil
self.node2key[ptr2num(node)] = nil
queue_remove(node)
queue_insert_tail(self.free_queue, node)
self.num_items = self.num_items - 1
return true
end
function _M.set(self, key, value, ttl, flags)
local hasht = self.hasht
hasht[key] = value
local key2node = self.key2node
local node = key2node[key]
if not node then
local free_queue = self.free_queue
local node2key = self.node2key
if queue_is_empty(free_queue) then
-- evict the least recently used key
-- assert(not queue_is_empty(self.cache_queue))
node = queue_last(self.cache_queue)
local oldkey = node2key[ptr2num(node)]
-- print(key, ": evicting oldkey: ", oldkey, ", oldnode: ",
-- tostring(node))
if oldkey then
hasht[oldkey] = nil
key2node[oldkey] = nil
end
else
-- take a free queue node
node = queue_head(free_queue)
-- only add count if we are not evicting
self.num_items = self.num_items + 1
-- print(key, ": get a new free node: ", tostring(node))
end
node2key[ptr2num(node)] = key
key2node[key] = node
end
queue_remove(node)
queue_insert_head(self.cache_queue, node)
if ttl then
node.expire = ngx_now() + ttl
else
node.expire = -1
end
if type(flags) == "number" and flags >= 0 then
node.user_flags = flags
else
node.user_flags = 0
end
end
function _M.get_keys(self, max_count, res)
if not max_count or max_count == 0 then
max_count = self.num_items
end
if not res then
res = new_tab(max_count + 1, 0) -- + 1 for trailing hole
end
local cache_queue = self.cache_queue
local node2key = self.node2key
local i = 0
local node = queue_head(cache_queue)
while node ~= cache_queue do
if i >= max_count then
break
end
i = i + 1
res[i] = node2key[ptr2num(node)]
node = node.next
end
res[i + 1] = nil
return res
end
function _M.flush_all(self)
tb_clear(self.hasht)
tb_clear(self.node2key)
tb_clear(self.key2node)
self.num_items = 0
local cache_queue = self.cache_queue
local free_queue = self.free_queue
-- splice the cache_queue into free_queue
if not queue_is_empty(cache_queue) then
local free_head = free_queue[0]
local free_last = free_head.prev
local cache_head = cache_queue[0]
local cache_first = cache_head.next
local cache_last = cache_head.prev
free_last.next = cache_first
cache_first.prev = free_last
cache_last.next = free_head
free_head.prev = cache_last
cache_head.next = cache_queue
cache_head.prev = cache_queue
end
end
return _M

606
resty/lrucache/pureffi.lua Normal file
View File

@@ -0,0 +1,606 @@
-- Copyright (C) Yichun Zhang (agentzh)
-- Copyright (C) Shuxin Yang
--[[
This module implements a key/value cache store. We adopt LRU as our
replace/evict policy. Each key/value pair is tagged with a Time-to-Live (TTL);
from user's perspective, stale pairs are automatically removed from the cache.
Why FFI
-------
In Lua, expression "table[key] = nil" does not *PHYSICALLY* remove the value
associated with the key; it just set the value to be nil! So the table will
keep growing with large number of the key/nil pairs which will be purged until
resize() operator is called.
This "feature" is terribly ill-suited to what we need. Therefore we have to
rely on FFI to build a hash-table where any entry can be physically deleted
immediately.
Under the hood:
--------------
In concept, we introduce three data structures to implement the cache store:
1. key/value vector for storing keys and values.
2. a queue to mimic the LRU.
3. hash-table for looking up the value for a given key.
Unfortunately, efficiency and clarity usually come at each other cost. The
data strucutres we are using are slightly more complicated than what we
described above.
o. Lua does not have efficient way to store a vector of pair. So, we use
two vectors for key/value pair: one for keys and the other for values
(_M.key_v and _M.val_v, respectively), and i-th key corresponds to
i-th value.
A key/value pair is identified by the "id" field in a "node" (we shall
discuss node later)
o. The queue is nothing more than a doubly-linked list of "node" linked via
lrucache_pureffi_queue_s::{next|prev} fields.
o. The hash-table has two parts:
- the _M.bucket_v[] a vector of bucket, indiced by hash-value, and
- a bucket is a singly-linked list of "node" via the
lrucache_pureffi_queue_s::conflict field.
A key must be a string, and the hash value of a key is evaluated by:
crc32(key-cast-to-pointer) % size(_M.bucket_v).
We mandate size(_M.bucket_v) being a power-of-two in order to avoid
expensive modulo operation.
At the heart of the module is an array of "node" (of type
lrucache_pureffi_queue_s). A node:
- keeps the meta-data of its corresponding key/value pair
(embodied by the "id", and "expire" field);
- is a part of LRU queue (embodied by "prev" and "next" fields);
- is a part of hash-table (embodied by the "conflict" field).
]]
local ffi = require "ffi"
local bit = require "bit"
local ffi_new = ffi.new
local ffi_sizeof = ffi.sizeof
local ffi_cast = ffi.cast
local ffi_fill = ffi.fill
local ngx_now = ngx.now
local uintptr_t = ffi.typeof("uintptr_t")
local c_str_t = ffi.typeof("const char*")
local int_t = ffi.typeof("int")
local int_array_t = ffi.typeof("int[?]")
local crc_tab = ffi.new("const unsigned int[256]", {
0x00000000, 0x77073096, 0xEE0E612C, 0x990951BA, 0x076DC419, 0x706AF48F,
0xE963A535, 0x9E6495A3, 0x0EDB8832, 0x79DCB8A4, 0xE0D5E91E, 0x97D2D988,
0x09B64C2B, 0x7EB17CBD, 0xE7B82D07, 0x90BF1D91, 0x1DB71064, 0x6AB020F2,
0xF3B97148, 0x84BE41DE, 0x1ADAD47D, 0x6DDDE4EB, 0xF4D4B551, 0x83D385C7,
0x136C9856, 0x646BA8C0, 0xFD62F97A, 0x8A65C9EC, 0x14015C4F, 0x63066CD9,
0xFA0F3D63, 0x8D080DF5, 0x3B6E20C8, 0x4C69105E, 0xD56041E4, 0xA2677172,
0x3C03E4D1, 0x4B04D447, 0xD20D85FD, 0xA50AB56B, 0x35B5A8FA, 0x42B2986C,
0xDBBBC9D6, 0xACBCF940, 0x32D86CE3, 0x45DF5C75, 0xDCD60DCF, 0xABD13D59,
0x26D930AC, 0x51DE003A, 0xC8D75180, 0xBFD06116, 0x21B4F4B5, 0x56B3C423,
0xCFBA9599, 0xB8BDA50F, 0x2802B89E, 0x5F058808, 0xC60CD9B2, 0xB10BE924,
0x2F6F7C87, 0x58684C11, 0xC1611DAB, 0xB6662D3D, 0x76DC4190, 0x01DB7106,
0x98D220BC, 0xEFD5102A, 0x71B18589, 0x06B6B51F, 0x9FBFE4A5, 0xE8B8D433,
0x7807C9A2, 0x0F00F934, 0x9609A88E, 0xE10E9818, 0x7F6A0DBB, 0x086D3D2D,
0x91646C97, 0xE6635C01, 0x6B6B51F4, 0x1C6C6162, 0x856530D8, 0xF262004E,
0x6C0695ED, 0x1B01A57B, 0x8208F4C1, 0xF50FC457, 0x65B0D9C6, 0x12B7E950,
0x8BBEB8EA, 0xFCB9887C, 0x62DD1DDF, 0x15DA2D49, 0x8CD37CF3, 0xFBD44C65,
0x4DB26158, 0x3AB551CE, 0xA3BC0074, 0xD4BB30E2, 0x4ADFA541, 0x3DD895D7,
0xA4D1C46D, 0xD3D6F4FB, 0x4369E96A, 0x346ED9FC, 0xAD678846, 0xDA60B8D0,
0x44042D73, 0x33031DE5, 0xAA0A4C5F, 0xDD0D7CC9, 0x5005713C, 0x270241AA,
0xBE0B1010, 0xC90C2086, 0x5768B525, 0x206F85B3, 0xB966D409, 0xCE61E49F,
0x5EDEF90E, 0x29D9C998, 0xB0D09822, 0xC7D7A8B4, 0x59B33D17, 0x2EB40D81,
0xB7BD5C3B, 0xC0BA6CAD, 0xEDB88320, 0x9ABFB3B6, 0x03B6E20C, 0x74B1D29A,
0xEAD54739, 0x9DD277AF, 0x04DB2615, 0x73DC1683, 0xE3630B12, 0x94643B84,
0x0D6D6A3E, 0x7A6A5AA8, 0xE40ECF0B, 0x9309FF9D, 0x0A00AE27, 0x7D079EB1,
0xF00F9344, 0x8708A3D2, 0x1E01F268, 0x6906C2FE, 0xF762575D, 0x806567CB,
0x196C3671, 0x6E6B06E7, 0xFED41B76, 0x89D32BE0, 0x10DA7A5A, 0x67DD4ACC,
0xF9B9DF6F, 0x8EBEEFF9, 0x17B7BE43, 0x60B08ED5, 0xD6D6A3E8, 0xA1D1937E,
0x38D8C2C4, 0x4FDFF252, 0xD1BB67F1, 0xA6BC5767, 0x3FB506DD, 0x48B2364B,
0xD80D2BDA, 0xAF0A1B4C, 0x36034AF6, 0x41047A60, 0xDF60EFC3, 0xA867DF55,
0x316E8EEF, 0x4669BE79, 0xCB61B38C, 0xBC66831A, 0x256FD2A0, 0x5268E236,
0xCC0C7795, 0xBB0B4703, 0x220216B9, 0x5505262F, 0xC5BA3BBE, 0xB2BD0B28,
0x2BB45A92, 0x5CB36A04, 0xC2D7FFA7, 0xB5D0CF31, 0x2CD99E8B, 0x5BDEAE1D,
0x9B64C2B0, 0xEC63F226, 0x756AA39C, 0x026D930A, 0x9C0906A9, 0xEB0E363F,
0x72076785, 0x05005713, 0x95BF4A82, 0xE2B87A14, 0x7BB12BAE, 0x0CB61B38,
0x92D28E9B, 0xE5D5BE0D, 0x7CDCEFB7, 0x0BDBDF21, 0x86D3D2D4, 0xF1D4E242,
0x68DDB3F8, 0x1FDA836E, 0x81BE16CD, 0xF6B9265B, 0x6FB077E1, 0x18B74777,
0x88085AE6, 0xFF0F6A70, 0x66063BCA, 0x11010B5C, 0x8F659EFF, 0xF862AE69,
0x616BFFD3, 0x166CCF45, 0xA00AE278, 0xD70DD2EE, 0x4E048354, 0x3903B3C2,
0xA7672661, 0xD06016F7, 0x4969474D, 0x3E6E77DB, 0xAED16A4A, 0xD9D65ADC,
0x40DF0B66, 0x37D83BF0, 0xA9BCAE53, 0xDEBB9EC5, 0x47B2CF7F, 0x30B5FFE9,
0xBDBDF21C, 0xCABAC28A, 0x53B39330, 0x24B4A3A6, 0xBAD03605, 0xCDD70693,
0x54DE5729, 0x23D967BF, 0xB3667A2E, 0xC4614AB8, 0x5D681B02, 0x2A6F2B94,
0xB40BBE37, 0xC30C8EA1, 0x5A05DF1B, 0x2D02EF8D });
local setmetatable = setmetatable
local tonumber = tonumber
local tostring = tostring
local type = type
local brshift = bit.rshift
local bxor = bit.bxor
local band = bit.band
local new_tab
do
local ok
ok, new_tab = pcall(require, "table.new")
if not ok then
new_tab = function(narr, nrec) return {} end
end
end
-- queue data types
--
-- this queue is a double-ended queue and the first node
-- is reserved for the queue itself.
-- the implementation is mostly borrowed from nginx's ngx_queue_t data
-- structure.
ffi.cdef[[
/* A lrucache_pureffi_queue_s node hook together three data structures:
* o. the key/value store as embodied by the "id" (which is in essence the
* indentifier of key/pair pair) and the "expire" (which is a metadata
* of the corresponding key/pair pair).
* o. The LRU queue via the prev/next fields.
* o. The hash-tabble as embodied by the "conflict" field.
*/
typedef struct lrucache_pureffi_queue_s lrucache_pureffi_queue_t;
struct lrucache_pureffi_queue_s {
/* Each node is assigned a unique ID at construction time, and the
* ID remain immutatble, regardless the node is in active-list or
* free-list. The queue header is assigned ID 0. Since queue-header
* is a sentinel node, 0 denodes "invalid ID".
*
* Intuitively, we can view the "id" as the identifier of key/value
* pair.
*/
int id;
/* The bucket of the hash-table is implemented as a singly-linked list.
* The "conflict" refers to the ID of the next node in the bucket.
*/
int conflict;
uint32_t user_flags;
double expire; /* in seconds */
lrucache_pureffi_queue_t *prev;
lrucache_pureffi_queue_t *next;
};
]]
local queue_arr_type = ffi.typeof("lrucache_pureffi_queue_t[?]")
--local queue_ptr_type = ffi.typeof("lrucache_pureffi_queue_t*")
local queue_type = ffi.typeof("lrucache_pureffi_queue_t")
local NULL = ffi.null
--========================================================================
--
-- Queue utility functions
--
--========================================================================
-- Append the element "x" to the given queue "h".
local function queue_insert_tail(h, x)
local last = h[0].prev
x.prev = last
last.next = x
x.next = h
h[0].prev = x
end
--[[
Allocate a queue with size + 1 elements. Elements are linked together in a
circular way, i.e. the last element's "next" points to the first element,
while the first element's "prev" element points to the last element.
]]
local function queue_init(size)
if not size then
size = 0
end
local q = ffi_new(queue_arr_type, size + 1)
ffi_fill(q, ffi_sizeof(queue_type, size + 1), 0)
if size == 0 then
q[0].prev = q
q[0].next = q
else
local prev = q[0]
for i = 1, size do
local e = q[i]
e.id = i
e.user_flags = 0
prev.next = e
e.prev = prev
prev = e
end
local last = q[size]
last.next = q
q[0].prev = last
end
return q
end
local function queue_is_empty(q)
-- print("q: ", tostring(q), "q.prev: ", tostring(q), ": ", q == q.prev)
return q == q[0].prev
end
local function queue_remove(x)
local prev = x.prev
local next = x.next
next.prev = prev
prev.next = next
-- for debugging purpose only:
x.prev = NULL
x.next = NULL
end
-- Insert the element "x" the to the given queue "h"
local function queue_insert_head(h, x)
x.next = h[0].next
x.next.prev = x
x.prev = h
h[0].next = x
end
local function queue_last(h)
return h[0].prev
end
local function queue_head(h)
return h[0].next
end
--========================================================================
--
-- Miscellaneous Utility Functions
--
--========================================================================
local function ptr2num(ptr)
return tonumber(ffi_cast(uintptr_t, ptr))
end
local function crc32_ptr(ptr)
local p = brshift(ptr2num(ptr), 3)
local b = band(p, 255)
local crc32 = crc_tab[b]
b = band(brshift(p, 8), 255)
crc32 = bxor(brshift(crc32, 8), crc_tab[band(bxor(crc32, b), 255)])
b = band(brshift(p, 16), 255)
crc32 = bxor(brshift(crc32, 8), crc_tab[band(bxor(crc32, b), 255)])
--b = band(brshift(p, 24), 255)
--crc32 = bxor(brshift(crc32, 8), crc_tab[band(bxor(crc32, b), 255)])
return crc32
end
--========================================================================
--
-- Implementation of "export" functions
--
--========================================================================
local _M = {
_VERSION = '0.11'
}
local mt = { __index = _M }
-- "size" specifies the maximum number of entries in the LRU queue, and the
-- "load_factor" designates the 'load factor' of the hash-table we are using
-- internally. The default value of load-factor is 0.5 (i.e. 50%); if the
-- load-factor is specified, it will be clamped to the range of [0.1, 1](i.e.
-- if load-factor is greater than 1, it will be saturated to 1, likewise,
-- if load-factor is smaller than 0.1, it will be clamped to 0.1).
function _M.new(size, load_factor)
if size < 1 then
return nil, "size too small"
end
-- Determine bucket size, which must be power of two.
local load_f = load_factor
if not load_factor then
load_f = 0.5
elseif load_factor > 1 then
load_f = 1
elseif load_factor < 0.1 then
load_f = 0.1
end
local bs_min = size / load_f
-- The bucket_sz *MUST* be a power-of-two. See the hash_string().
local bucket_sz = 1
repeat
bucket_sz = bucket_sz * 2
until bucket_sz >= bs_min
local self = {
size = size,
bucket_sz = bucket_sz,
free_queue = queue_init(size),
cache_queue = queue_init(0),
node_v = nil,
key_v = new_tab(size, 0),
val_v = new_tab(size, 0),
bucket_v = ffi_new(int_array_t, bucket_sz),
num_items = 0,
}
-- "node_v" is an array of all the nodes used in the LRU queue. Exprpession
-- node_v[i] evaluates to the element of ID "i".
self.node_v = self.free_queue
-- Allocate the array-part of the key_v, val_v, bucket_v.
--local key_v = self.key_v
--local val_v = self.val_v
--local bucket_v = self.bucket_v
ffi_fill(self.bucket_v, ffi_sizeof(int_t, bucket_sz), 0)
return setmetatable(self, mt)
end
function _M.count(self)
return self.num_items
end
function _M.capacity(self)
return self.size
end
local function hash_string(self, str)
local c_str = ffi_cast(c_str_t, str)
local hv = crc32_ptr(c_str)
hv = band(hv, self.bucket_sz - 1)
-- Hint: bucket is 0-based
return hv
end
-- Search the node associated with the key in the bucket, if found returns
-- the the id of the node, and the id of its previous node in the conflict list.
-- The "bucket_hdr_id" is the ID of the first node in the bucket
local function _find_node_in_bucket(key, key_v, node_v, bucket_hdr_id)
if bucket_hdr_id ~= 0 then
local prev = 0
local cur = bucket_hdr_id
while cur ~= 0 and key_v[cur] ~= key do
prev = cur
cur = node_v[cur].conflict
end
if cur ~= 0 then
return cur, prev
end
end
end
-- Return the node corresponding to the key/val.
local function find_key(self, key)
local key_hash = hash_string(self, key)
return _find_node_in_bucket(key, self.key_v, self.node_v,
self.bucket_v[key_hash])
end
--[[ This function tries to
1. Remove the given key and the associated value from the key/value store,
2. Remove the entry associated with the key from the hash-table.
NOTE: all queues remain intact.
If there was a node bound to the key/val, return that node; otherwise,
nil is returned.
]]
local function remove_key(self, key)
local key_v = self.key_v
local val_v = self.val_v
local node_v = self.node_v
local bucket_v = self.bucket_v
local key_hash = hash_string(self, key)
local cur, prev =
_find_node_in_bucket(key, key_v, node_v, bucket_v[key_hash])
if cur then
-- In an attempt to make key and val dead.
key_v[cur] = nil
val_v[cur] = nil
self.num_items = self.num_items - 1
-- Remove the node from the hash table
local next_node = node_v[cur].conflict
if prev ~= 0 then
node_v[prev].conflict = next_node
else
bucket_v[key_hash] = next_node
end
node_v[cur].conflict = 0
return cur
end
end
--[[ Bind the key/val with the given node, and insert the node into the Hashtab.
NOTE: this function does not touch any queue
]]
local function insert_key(self, key, val, node)
-- Bind the key/val with the node
local node_id = node.id
self.key_v[node_id] = key
self.val_v[node_id] = val
-- Insert the node into the hash-table
local key_hash = hash_string(self, key)
local bucket_v = self.bucket_v
node.conflict = bucket_v[key_hash]
bucket_v[key_hash] = node_id
self.num_items = self.num_items + 1
end
function _M.get(self, key)
if type(key) ~= "string" then
key = tostring(key)
end
local node_id = find_key(self, key)
if not node_id then
return nil
end
-- print(key, ": moving node ", tostring(node), " to cache queue head")
local cache_queue = self.cache_queue
local node = self.node_v + node_id
queue_remove(node)
queue_insert_head(cache_queue, node)
local expire = node.expire
if expire >= 0 and expire < ngx_now() then
-- print("expired: ", node.expire, " > ", ngx_now())
return nil, self.val_v[node_id], node.user_flags
end
return self.val_v[node_id], nil, node.user_flags
end
function _M.delete(self, key)
if type(key) ~= "string" then
key = tostring(key)
end
local node_id = remove_key(self, key);
if not node_id then
return false
end
local node = self.node_v + node_id
queue_remove(node)
queue_insert_tail(self.free_queue, node)
return true
end
function _M.set(self, key, value, ttl, flags)
if type(key) ~= "string" then
key = tostring(key)
end
local node_id = find_key(self, key)
local node
if not node_id then
local free_queue = self.free_queue
if queue_is_empty(free_queue) then
-- evict the least recently used key
-- assert(not queue_is_empty(self.cache_queue))
node = queue_last(self.cache_queue)
remove_key(self, self.key_v[node.id])
else
-- take a free queue node
node = queue_head(free_queue)
-- print(key, ": get a new free node: ", tostring(node))
end
-- insert the key
insert_key(self, key, value, node)
else
node = self.node_v + node_id
self.val_v[node_id] = value
end
queue_remove(node)
queue_insert_head(self.cache_queue, node)
if ttl then
node.expire = ngx_now() + ttl
else
node.expire = -1
end
if type(flags) == "number" and flags >= 0 then
node.user_flags = flags
else
node.user_flags = 0
end
end
function _M.get_keys(self, max_count, res)
if not max_count or max_count == 0 then
max_count = self.num_items
end
if not res then
res = new_tab(max_count + 1, 0) -- + 1 for trailing hole
end
local cache_queue = self.cache_queue
local key_v = self.key_v
local i = 0
local node = queue_head(cache_queue)
while node ~= cache_queue do
if i >= max_count then
break
end
i = i + 1
res[i] = key_v[node.id]
node = node.next
end
res[i + 1] = nil
return res
end
function _M.flush_all(self)
local cache_queue = self.cache_queue
local key_v = self.key_v
local node = queue_head(cache_queue)
while node ~= cache_queue do
local key = key_v[node.id]
node = node.next
_M.delete(self, key)
end
end
return _M

72
resty/md5.lua Normal file
View File

@@ -0,0 +1,72 @@
-- Copyright (C) by Yichun Zhang (agentzh)
local ffi = require "ffi"
local ffi_new = ffi.new
local ffi_str = ffi.string
local C = ffi.C
local setmetatable = setmetatable
--local error = error
local _M = { _VERSION = '0.14' }
local mt = { __index = _M }
ffi.cdef[[
typedef unsigned long MD5_LONG ;
enum {
MD5_CBLOCK = 64,
MD5_LBLOCK = MD5_CBLOCK/4
};
typedef struct MD5state_st
{
MD5_LONG A,B,C,D;
MD5_LONG Nl,Nh;
MD5_LONG data[MD5_LBLOCK];
unsigned int num;
} MD5_CTX;
int MD5_Init(MD5_CTX *c);
int MD5_Update(MD5_CTX *c, const void *data, size_t len);
int MD5_Final(unsigned char *md, MD5_CTX *c);
]]
local buf = ffi_new("char[16]")
local ctx_ptr_type = ffi.typeof("MD5_CTX[1]")
function _M.new(self)
local ctx = ffi_new(ctx_ptr_type)
if C.MD5_Init(ctx) == 0 then
return nil
end
return setmetatable({ _ctx = ctx }, mt)
end
function _M.update(self, s)
return C.MD5_Update(self._ctx, s, #s) == 1
end
function _M.final(self)
if C.MD5_Final(buf, self._ctx) == 1 then
return ffi_str(buf, 16)
end
return nil
end
function _M.reset(self)
return C.MD5_Init(self._ctx) == 1
end
return _M

744
resty/memcached.lua Normal file
View File

@@ -0,0 +1,744 @@
-- Copyright (C) Yichun Zhang (agentzh), CloudFlare Inc.
local escape_uri = ngx.escape_uri
local unescape_uri = ngx.unescape_uri
local match = string.match
local tcp = ngx.socket.tcp
local strlen = string.len
local concat = table.concat
local setmetatable = setmetatable
local type = type
local _M = {
_VERSION = '0.16'
}
local mt = { __index = _M }
function _M.new(self, opts)
local sock, err = tcp()
if not sock then
return nil, err
end
local escape_key = escape_uri
local unescape_key = unescape_uri
if opts then
local key_transform = opts.key_transform
if key_transform then
escape_key = key_transform[1]
unescape_key = key_transform[2]
if not escape_key or not unescape_key then
return nil, "expecting key_transform = { escape, unescape } table"
end
end
end
return setmetatable({
sock = sock,
escape_key = escape_key,
unescape_key = unescape_key,
}, mt)
end
local function set_timeouts(self, connect, send, read)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
sock:settimeouts(connect, send, read)
return 1
end
_M.set_timeouts = set_timeouts
function _M.set_timeout(self, timeout)
return set_timeouts(self, timeout, timeout, timeout)
end
function _M.connect(self, ...)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
return sock:connect(...)
end
local function _multi_get(self, keys)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
local nkeys = #keys
if nkeys == 0 then
return {}, nil
end
local escape_key = self.escape_key
local cmd = {"get"}
local n = 1
for i = 1, nkeys do
cmd[n + 1] = " "
cmd[n + 2] = escape_key(keys[i])
n = n + 2
end
cmd[n + 1] = "\r\n"
-- print("multi get cmd: ", cmd)
local bytes, err = sock:send(cmd)
if not bytes then
return nil, err
end
local unescape_key = self.unescape_key
local results = {}
while true do
local line, err = sock:receive()
if not line then
if err == "timeout" then
sock:close()
end
return nil, err
end
if line == 'END' then
break
end
local key, flags, len = match(line, '^VALUE (%S+) (%d+) (%d+)$')
-- print("key: ", key, "len: ", len, ", flags: ", flags)
if not key then
return nil, line
end
local data, err = sock:receive(len)
if not data then
if err == "timeout" then
sock:close()
end
return nil, err
end
results[unescape_key(key)] = {data, flags}
data, err = sock:receive(2) -- discard the trailing CRLF
if not data then
if err == "timeout" then
sock:close()
end
return nil, err
end
end
return results
end
function _M.get(self, key)
if type(key) == "table" then
return _multi_get(self, key)
end
local sock = self.sock
if not sock then
return nil, nil, "not initialized"
end
local bytes, err = sock:send("get " .. self.escape_key(key) .. "\r\n")
if not bytes then
return nil, nil, err
end
local line, err = sock:receive()
if not line then
if err == "timeout" then
sock:close()
end
return nil, nil, err
end
if line == 'END' then
return nil, nil, nil
end
local flags, len = match(line, '^VALUE %S+ (%d+) (%d+)$')
if not flags then
return nil, nil, line
end
-- print("len: ", len, ", flags: ", flags)
local data, err = sock:receive(len)
if not data then
if err == "timeout" then
sock:close()
end
return nil, nil, err
end
line, err = sock:receive(7) -- discard the trailing "\r\nEND\r\n"
if not line then
if err == "timeout" then
sock:close()
end
return nil, nil, err
end
return data, flags
end
local function _multi_gets(self, keys)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
local nkeys = #keys
if nkeys == 0 then
return {}, nil
end
local escape_key = self.escape_key
local cmd = {"gets"}
local n = 1
for i = 1, nkeys do
cmd[n + 1] = " "
cmd[n + 2] = escape_key(keys[i])
n = n + 2
end
cmd[n + 1] = "\r\n"
-- print("multi get cmd: ", cmd)
local bytes, err = sock:send(cmd)
if not bytes then
return nil, err
end
local unescape_key = self.unescape_key
local results = {}
while true do
local line, err = sock:receive()
if not line then
if err == "timeout" then
sock:close()
end
return nil, err
end
if line == 'END' then
break
end
local key, flags, len, cas_uniq =
match(line, '^VALUE (%S+) (%d+) (%d+) (%d+)$')
-- print("key: ", key, "len: ", len, ", flags: ", flags)
if not key then
return nil, line
end
local data, err = sock:receive(len)
if not data then
if err == "timeout" then
sock:close()
end
return nil, err
end
results[unescape_key(key)] = {data, flags, cas_uniq}
data, err = sock:receive(2) -- discard the trailing CRLF
if not data then
if err == "timeout" then
sock:close()
end
return nil, err
end
end
return results
end
function _M.gets(self, key)
if type(key) == "table" then
return _multi_gets(self, key)
end
local sock = self.sock
if not sock then
return nil, nil, nil, "not initialized"
end
local bytes, err = sock:send("gets " .. self.escape_key(key) .. "\r\n")
if not bytes then
return nil, nil, nil, err
end
local line, err = sock:receive()
if not line then
if err == "timeout" then
sock:close()
end
return nil, nil, nil, err
end
if line == 'END' then
return nil, nil, nil, nil
end
local flags, len, cas_uniq = match(line, '^VALUE %S+ (%d+) (%d+) (%d+)$')
if not flags then
return nil, nil, nil, line
end
-- print("len: ", len, ", flags: ", flags)
local data, err = sock:receive(len)
if not data then
if err == "timeout" then
sock:close()
end
return nil, nil, nil, err
end
line, err = sock:receive(7) -- discard the trailing "\r\nEND\r\n"
if not line then
if err == "timeout" then
sock:close()
end
return nil, nil, nil, err
end
return data, flags, cas_uniq
end
local function _expand_table(value)
local segs = {}
local nelems = #value
local nsegs = 0
for i = 1, nelems do
local seg = value[i]
nsegs = nsegs + 1
if type(seg) == "table" then
segs[nsegs] = _expand_table(seg)
else
segs[nsegs] = seg
end
end
return concat(segs)
end
local function _store(self, cmd, key, value, exptime, flags)
if not exptime then
exptime = 0
end
if not flags then
flags = 0
end
local sock = self.sock
if not sock then
return nil, "not initialized"
end
if type(value) == "table" then
value = _expand_table(value)
end
local req = cmd .. " " .. self.escape_key(key) .. " " .. flags .. " "
.. exptime .. " " .. strlen(value) .. "\r\n" .. value
.. "\r\n"
local bytes, err = sock:send(req)
if not bytes then
return nil, err
end
local data, err = sock:receive()
if not data then
if err == "timeout" then
sock:close()
end
return nil, err
end
if data == "STORED" then
return 1
end
return nil, data
end
function _M.set(self, ...)
return _store(self, "set", ...)
end
function _M.add(self, ...)
return _store(self, "add", ...)
end
function _M.replace(self, ...)
return _store(self, "replace", ...)
end
function _M.append(self, ...)
return _store(self, "append", ...)
end
function _M.prepend(self, ...)
return _store(self, "prepend", ...)
end
function _M.cas(self, key, value, cas_uniq, exptime, flags)
if not exptime then
exptime = 0
end
if not flags then
flags = 0
end
local sock = self.sock
if not sock then
return nil, "not initialized"
end
local req = "cas " .. self.escape_key(key) .. " " .. flags .. " "
.. exptime .. " " .. strlen(value) .. " " .. cas_uniq
.. "\r\n" .. value .. "\r\n"
-- local cjson = require "cjson"
-- print("request: ", cjson.encode(req))
local bytes, err = sock:send(req)
if not bytes then
return nil, err
end
local line, err = sock:receive()
if not line then
if err == "timeout" then
sock:close()
end
return nil, err
end
-- print("response: [", line, "]")
if line == "STORED" then
return 1
end
return nil, line
end
function _M.delete(self, key)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
key = self.escape_key(key)
local req = "delete " .. key .. "\r\n"
local bytes, err = sock:send(req)
if not bytes then
return nil, err
end
local res, err = sock:receive()
if not res then
if err == "timeout" then
sock:close()
end
return nil, err
end
if res ~= 'DELETED' then
return nil, res
end
return 1
end
function _M.set_keepalive(self, ...)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
return sock:setkeepalive(...)
end
function _M.get_reused_times(self)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
return sock:getreusedtimes()
end
function _M.flush_all(self, time)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
local req
if time then
req = "flush_all " .. time .. "\r\n"
else
req = "flush_all\r\n"
end
local bytes, err = sock:send(req)
if not bytes then
return nil, err
end
local res, err = sock:receive()
if not res then
if err == "timeout" then
sock:close()
end
return nil, err
end
if res ~= 'OK' then
return nil, res
end
return 1
end
local function _incr_decr(self, cmd, key, value)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
local req = cmd .. " " .. self.escape_key(key) .. " " .. value .. "\r\n"
local bytes, err = sock:send(req)
if not bytes then
return nil, err
end
local line, err = sock:receive()
if not line then
if err == "timeout" then
sock:close()
end
return nil, err
end
if not match(line, '^%d+$') then
return nil, line
end
return line
end
function _M.incr(self, key, value)
return _incr_decr(self, "incr", key, value)
end
function _M.decr(self, key, value)
return _incr_decr(self, "decr", key, value)
end
function _M.stats(self, args)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
local req
if args then
req = "stats " .. args .. "\r\n"
else
req = "stats\r\n"
end
local bytes, err = sock:send(req)
if not bytes then
return nil, err
end
local lines = {}
local n = 0
while true do
local line, err = sock:receive()
if not line then
if err == "timeout" then
sock:close()
end
return nil, err
end
if line == 'END' then
return lines, nil
end
if not match(line, "ERROR") then
n = n + 1
lines[n] = line
else
return nil, line
end
end
-- cannot reach here...
return lines
end
function _M.version(self)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
local bytes, err = sock:send("version\r\n")
if not bytes then
return nil, err
end
local line, err = sock:receive()
if not line then
if err == "timeout" then
sock:close()
end
return nil, err
end
local ver = match(line, "^VERSION (.+)$")
if not ver then
return nil, ver
end
return ver
end
function _M.quit(self)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
local bytes, err = sock:send("quit\r\n")
if not bytes then
return nil, err
end
return 1
end
function _M.verbosity(self, level)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
local bytes, err = sock:send("verbosity " .. level .. "\r\n")
if not bytes then
return nil, err
end
local line, err = sock:receive()
if not line then
if err == "timeout" then
sock:close()
end
return nil, err
end
if line ~= 'OK' then
return nil, line
end
return 1
end
function _M.touch(self, key, exptime)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
local bytes, err = sock:send("touch " .. self.escape_key(key) .. " "
.. exptime .. "\r\n")
if not bytes then
return nil, err
end
local line, err = sock:receive()
if not line then
if err == "timeout" then
sock:close()
end
return nil, err
end
-- moxi server from couchbase returned stored after touching
if line == "TOUCHED" or line =="STORED" then
return 1
end
return nil, line
end
function _M.close(self)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
return sock:close()
end
return _M

1410
resty/mysql.lua Normal file

File diff suppressed because it is too large Load Diff

36
resty/random.lua Normal file
View File

@@ -0,0 +1,36 @@
-- Copyright (C) by Yichun Zhang (agentzh)
local ffi = require "ffi"
local ffi_new = ffi.new
local ffi_str = ffi.string
local C = ffi.C
--local setmetatable = setmetatable
--local error = error
local _M = { _VERSION = '0.14' }
ffi.cdef[[
int RAND_bytes(unsigned char *buf, int num);
int RAND_pseudo_bytes(unsigned char *buf, int num);
]]
function _M.bytes(len, strong)
local buf = ffi_new("char[?]", len)
if strong then
if C.RAND_bytes(buf, len) == 0 then
return nil
end
else
C.RAND_pseudo_bytes(buf,len)
end
return ffi_str(buf, len)
end
return _M

676
resty/redis.lua Normal file
View File

@@ -0,0 +1,676 @@
-- Copyright (C) Yichun Zhang (agentzh)
local sub = string.sub
local byte = string.byte
local tab_insert = table.insert
local tab_remove = table.remove
local tcp = ngx.socket.tcp
local null = ngx.null
local ipairs = ipairs
local type = type
local pairs = pairs
local unpack = unpack
local setmetatable = setmetatable
local tonumber = tonumber
local tostring = tostring
local rawget = rawget
local select = select
--local error = error
local ok, new_tab = pcall(require, "table.new")
if not ok or type(new_tab) ~= "function" then
new_tab = function (narr, nrec) return {} end
end
local _M = new_tab(0, 55)
_M._VERSION = '0.29'
local common_cmds = {
"get", "set", "mget", "mset",
"del", "incr", "decr", -- Strings
"llen", "lindex", "lpop", "lpush",
"lrange", "linsert", -- Lists
"hexists", "hget", "hset", "hmget",
--[[ "hmset", ]] "hdel", -- Hashes
"smembers", "sismember", "sadd", "srem",
"sdiff", "sinter", "sunion", -- Sets
"zrange", "zrangebyscore", "zrank", "zadd",
"zrem", "zincrby", -- Sorted Sets
"auth", "eval", "expire", "script",
"sort" -- Others
}
local sub_commands = {
"subscribe", "psubscribe"
}
local unsub_commands = {
"unsubscribe", "punsubscribe"
}
local mt = { __index = _M }
function _M.new(self)
local sock, err = tcp()
if not sock then
return nil, err
end
return setmetatable({ _sock = sock,
_subscribed = false,
_n_channel = {
unsubscribe = 0,
punsubscribe = 0,
},
}, mt)
end
function _M.set_timeout(self, timeout)
local sock = rawget(self, "_sock")
if not sock then
error("not initialized", 2)
return
end
sock:settimeout(timeout)
end
function _M.set_timeouts(self, connect_timeout, send_timeout, read_timeout)
local sock = rawget(self, "_sock")
if not sock then
error("not initialized", 2)
return
end
sock:settimeouts(connect_timeout, send_timeout, read_timeout)
end
function _M.connect(self, host, port_or_opts, opts)
local sock = rawget(self, "_sock")
if not sock then
return nil, "not initialized"
end
local unix
do
local typ = type(host)
if typ ~= "string" then
error("bad argument #1 host: string expected, got " .. typ, 2)
end
if sub(host, 1, 5) == "unix:" then
unix = true
end
if unix then
typ = type(port_or_opts)
if port_or_opts ~= nil and typ ~= "table" then
error("bad argument #2 opts: nil or table expected, got " ..
typ, 2)
end
else
typ = type(port_or_opts)
if typ ~= "number" then
port_or_opts = tonumber(port_or_opts)
if port_or_opts == nil then
error("bad argument #2 port: number expected, got " ..
typ, 2)
end
end
if opts ~= nil then
typ = type(opts)
if typ ~= "table" then
error("bad argument #3 opts: nil or table expected, got " ..
typ, 2)
end
end
end
end
self._subscribed = false
local ok, err
if unix then
-- second argument of sock:connect() cannot be nil
if port_or_opts ~= nil then
ok, err = sock:connect(host, port_or_opts)
opts = port_or_opts
else
ok, err = sock:connect(host)
end
else
ok, err = sock:connect(host, port_or_opts, opts)
end
if not ok then
return ok, err
end
if opts and opts.ssl then
ok, err = sock:sslhandshake(false, opts.server_name, opts.ssl_verify)
if not ok then
return ok, "failed to do ssl handshake: " .. err
end
end
return ok, err
end
function _M.set_keepalive(self, ...)
local sock = rawget(self, "_sock")
if not sock then
return nil, "not initialized"
end
if rawget(self, "_subscribed") then
return nil, "subscribed state"
end
return sock:setkeepalive(...)
end
function _M.get_reused_times(self)
local sock = rawget(self, "_sock")
if not sock then
return nil, "not initialized"
end
return sock:getreusedtimes()
end
local function close(self)
local sock = rawget(self, "_sock")
if not sock then
return nil, "not initialized"
end
return sock:close()
end
_M.close = close
local function _read_reply(self, sock)
local line, err = sock:receive()
if not line then
if err == "timeout" and not rawget(self, "_subscribed") then
sock:close()
end
return nil, err
end
local prefix = byte(line)
if prefix == 36 then -- char '$'
-- print("bulk reply")
local size = tonumber(sub(line, 2))
if size < 0 then
return null
end
local data, err = sock:receive(size)
if not data then
if err == "timeout" then
sock:close()
end
return nil, err
end
local dummy, err = sock:receive(2) -- ignore CRLF
if not dummy then
if err == "timeout" then
sock:close()
end
return nil, err
end
return data
elseif prefix == 43 then -- char '+'
-- print("status reply")
return sub(line, 2)
elseif prefix == 42 then -- char '*'
local n = tonumber(sub(line, 2))
-- print("multi-bulk reply: ", n)
if n < 0 then
return null
end
local vals = new_tab(n, 0)
local nvals = 0
for i = 1, n do
local res, err = _read_reply(self, sock)
if res then
nvals = nvals + 1
vals[nvals] = res
elseif res == nil then
return nil, err
else
-- be a valid redis error value
nvals = nvals + 1
vals[nvals] = {false, err}
end
end
return vals
elseif prefix == 58 then -- char ':'
-- print("integer reply")
return tonumber(sub(line, 2))
elseif prefix == 45 then -- char '-'
-- print("error reply: ", n)
return false, sub(line, 2)
else
-- when `line` is an empty string, `prefix` will be equal to nil.
return nil, "unknown prefix: \"" .. tostring(prefix) .. "\""
end
end
local function _gen_req(args)
local nargs = #args
local req = new_tab(nargs * 5 + 1, 0)
req[1] = "*" .. nargs .. "\r\n"
local nbits = 2
for i = 1, nargs do
local arg = args[i]
if type(arg) ~= "string" then
arg = tostring(arg)
end
req[nbits] = "$"
req[nbits + 1] = #arg
req[nbits + 2] = "\r\n"
req[nbits + 3] = arg
req[nbits + 4] = "\r\n"
nbits = nbits + 5
end
-- it is much faster to do string concatenation on the C land
-- in real world (large number of strings in the Lua VM)
return req
end
local function _check_msg(self, res)
return rawget(self, "_subscribed") and
type(res) == "table" and (res[1] == "message" or res[1] == "pmessage")
end
local function _do_cmd(self, ...)
local args = {...}
local sock = rawget(self, "_sock")
if not sock then
return nil, "not initialized"
end
local req = _gen_req(args)
local reqs = rawget(self, "_reqs")
if reqs then
reqs[#reqs + 1] = req
return
end
-- print("request: ", table.concat(req))
local bytes, err = sock:send(req)
if not bytes then
return nil, err
end
local res, err = _read_reply(self, sock)
while _check_msg(self, res) do
if rawget(self, "_buffered_msg") == nil then
self._buffered_msg = new_tab(1, 0)
end
tab_insert(self._buffered_msg, res)
res, err = _read_reply(self, sock)
end
return res, err
end
local function _check_unsubscribed(self, res)
if type(res) == "table"
and (res[1] == "unsubscribe" or res[1] == "punsubscribe")
then
self._n_channel[res[1]] = self._n_channel[res[1]] - 1
local buffered_msg = rawget(self, "_buffered_msg")
if buffered_msg then
-- remove messages of unsubscribed channel
local msg_type =
(res[1] == "punsubscribe") and "pmessage" or "message"
local j = 1
for _, msg in ipairs(buffered_msg) do
if msg[1] == msg_type and msg[2] ~= res[2] then
-- move messages to overwrite the removed ones
buffered_msg[j] = msg
j = j + 1
end
end
-- clear remain messages
for i = j, #buffered_msg do
buffered_msg[i] = nil
end
if #buffered_msg == 0 then
self._buffered_msg = nil
end
end
if res[3] == 0 then
-- all channels are unsubscribed
self._subscribed = false
end
end
end
local function _check_subscribed(self, res)
if type(res) == "table"
and (res[1] == "subscribe" or res[1] == "psubscribe")
then
if res[1] == "subscribe" then
self._n_channel.unsubscribe = self._n_channel.unsubscribe + 1
elseif res[1] == "psubscribe" then
self._n_channel.punsubscribe = self._n_channel.punsubscribe + 1
end
end
end
function _M.read_reply(self)
local sock = rawget(self, "_sock")
if not sock then
return nil, "not initialized"
end
if not rawget(self, "_subscribed") then
return nil, "not subscribed"
end
local buffered_msg = rawget(self, "_buffered_msg")
if buffered_msg then
local msg = buffered_msg[1]
tab_remove(buffered_msg, 1)
if #buffered_msg == 0 then
self._buffered_msg = nil
end
return msg
end
local res, err = _read_reply(self, sock)
_check_unsubscribed(self, res)
return res, err
end
for i = 1, #common_cmds do
local cmd = common_cmds[i]
_M[cmd] =
function (self, ...)
return _do_cmd(self, cmd, ...)
end
end
local function handle_subscribe_result(self, cmd, nargs, res)
local err
_check_subscribed(self, res)
if nargs <= 1 then
return res
end
local results = new_tab(nargs, 0)
results[1] = res
local sock = rawget(self, "_sock")
for i = 2, nargs do
res, err = _read_reply(self, sock)
if not res then
return nil, err
end
_check_subscribed(self, res)
results[i] = res
end
return results
end
for i = 1, #sub_commands do
local cmd = sub_commands[i]
_M[cmd] =
function (self, ...)
if not rawget(self, "_subscribed") then
self._subscribed = true
end
local nargs = select("#", ...)
local res, err = _do_cmd(self, cmd, ...)
if not res then
return nil, err
end
return handle_subscribe_result(self, cmd, nargs, res)
end
end
local function handle_unsubscribe_result(self, cmd, nargs, res)
local err
_check_unsubscribed(self, res)
if self._n_channel[cmd] == 0 or nargs == 1 then
return res
end
local results = new_tab(nargs, 0)
results[1] = res
local sock = rawget(self, "_sock")
local i = 2
while nargs == 0 or i <= nargs do
res, err = _read_reply(self, sock)
if not res then
return nil, err
end
results[i] = res
i = i + 1
_check_unsubscribed(self, res)
if self._n_channel[cmd] == 0 then
-- exit the loop for unsubscribe() call
break
end
end
return results
end
for i = 1, #unsub_commands do
local cmd = unsub_commands[i]
_M[cmd] =
function (self, ...)
-- assume all channels are unsubscribed by only one time
if not rawget(self, "_subscribed") then
return nil, "not subscribed"
end
local nargs = select("#", ...)
local res, err = _do_cmd(self, cmd, ...)
if not res then
return nil, err
end
return handle_unsubscribe_result(self, cmd, nargs, res)
end
end
function _M.hmset(self, hashname, ...)
if select('#', ...) == 1 then
local t = select(1, ...)
local n = 0
for k, v in pairs(t) do
n = n + 2
end
local array = new_tab(n, 0)
local i = 0
for k, v in pairs(t) do
array[i + 1] = k
array[i + 2] = v
i = i + 2
end
-- print("key", hashname)
return _do_cmd(self, "hmset", hashname, unpack(array))
end
-- backwards compatibility
return _do_cmd(self, "hmset", hashname, ...)
end
function _M.init_pipeline(self, n)
self._reqs = new_tab(n or 4, 0)
end
function _M.cancel_pipeline(self)
self._reqs = nil
end
function _M.commit_pipeline(self)
local reqs = rawget(self, "_reqs")
if not reqs then
return nil, "no pipeline"
end
self._reqs = nil
local sock = rawget(self, "_sock")
if not sock then
return nil, "not initialized"
end
local bytes, err = sock:send(reqs)
if not bytes then
return nil, err
end
local nvals = 0
local nreqs = #reqs
local vals = new_tab(nreqs, 0)
for i = 1, nreqs do
local res, err = _read_reply(self, sock)
if res then
nvals = nvals + 1
vals[nvals] = res
elseif res == nil then
if err == "timeout" then
close(self)
end
return nil, err
else
-- be a valid redis error value
nvals = nvals + 1
vals[nvals] = {false, err}
end
end
return vals
end
function _M.array_to_hash(self, t)
local n = #t
-- print("n = ", n)
local h = new_tab(0, n / 2)
for i = 1, n, 2 do
h[t[i]] = t[i + 1]
end
return h
end
-- this method is deperate since we already do lazy method generation.
function _M.add_commands(...)
local cmds = {...}
for i = 1, #cmds do
local cmd = cmds[i]
_M[cmd] =
function (self, ...)
return _do_cmd(self, cmd, ...)
end
end
end
setmetatable(_M, {__index = function(self, cmd)
local method =
function (self, ...)
return _do_cmd(self, cmd, ...)
end
-- cache the lazily generated method in our
-- module table
_M[cmd] = method
return method
end})
return _M

19
resty/sha.lua Normal file
View File

@@ -0,0 +1,19 @@
-- Copyright (C) by Yichun Zhang (agentzh)
local ffi = require "ffi"
local _M = { _VERSION = '0.14' }
ffi.cdef[[
typedef unsigned long SHA_LONG;
typedef unsigned long long SHA_LONG64;
enum {
SHA_LBLOCK = 16
};
]];
return _M

69
resty/sha1.lua Normal file
View File

@@ -0,0 +1,69 @@
-- Copyright (C) by Yichun Zhang (agentzh)
require "resty.sha"
local ffi = require "ffi"
local ffi_new = ffi.new
local ffi_str = ffi.string
local C = ffi.C
local setmetatable = setmetatable
--local error = error
local _M = { _VERSION = '0.14' }
local mt = { __index = _M }
ffi.cdef[[
typedef struct SHAstate_st
{
SHA_LONG h0,h1,h2,h3,h4;
SHA_LONG Nl,Nh;
SHA_LONG data[SHA_LBLOCK];
unsigned int num;
} SHA_CTX;
int SHA1_Init(SHA_CTX *c);
int SHA1_Update(SHA_CTX *c, const void *data, size_t len);
int SHA1_Final(unsigned char *md, SHA_CTX *c);
]]
local digest_len = 20
local buf = ffi_new("char[?]", digest_len)
local ctx_ptr_type = ffi.typeof("SHA_CTX[1]")
function _M.new(self)
local ctx = ffi_new(ctx_ptr_type)
if C.SHA1_Init(ctx) == 0 then
return nil
end
return setmetatable({ _ctx = ctx }, mt)
end
function _M.update(self, s)
return C.SHA1_Update(self._ctx, s, #s) == 1
end
function _M.final(self)
if C.SHA1_Final(buf, self._ctx) == 1 then
return ffi_str(buf, digest_len)
end
return nil
end
function _M.reset(self)
return C.SHA1_Init(self._ctx) == 1
end
return _M

60
resty/sha224.lua Normal file
View File

@@ -0,0 +1,60 @@
-- Copyright (C) by Yichun Zhang (agentzh)
require "resty.sha256"
local ffi = require "ffi"
local ffi_new = ffi.new
local ffi_str = ffi.string
local C = ffi.C
local setmetatable = setmetatable
--local error = error
local _M = { _VERSION = '0.14' }
local mt = { __index = _M }
ffi.cdef[[
int SHA224_Init(SHA256_CTX *c);
int SHA224_Update(SHA256_CTX *c, const void *data, size_t len);
int SHA224_Final(unsigned char *md, SHA256_CTX *c);
]]
local digest_len = 28
local buf = ffi_new("char[?]", digest_len)
local ctx_ptr_type = ffi.typeof("SHA256_CTX[1]")
function _M.new(self)
local ctx = ffi_new(ctx_ptr_type)
if C.SHA224_Init(ctx) == 0 then
return nil
end
return setmetatable({ _ctx = ctx }, mt)
end
function _M.update(self, s)
return C.SHA224_Update(self._ctx, s, #s) == 1
end
function _M.final(self)
if C.SHA224_Final(buf, self._ctx) == 1 then
return ffi_str(buf, digest_len)
end
return nil
end
function _M.reset(self)
return C.SHA224_Init(self._ctx) == 1
end
return _M

69
resty/sha256.lua Normal file
View File

@@ -0,0 +1,69 @@
-- Copyright (C) by Yichun Zhang (agentzh)
require "resty.sha"
local ffi = require "ffi"
local ffi_new = ffi.new
local ffi_str = ffi.string
local C = ffi.C
local setmetatable = setmetatable
--local error = error
local _M = { _VERSION = '0.14' }
local mt = { __index = _M }
ffi.cdef[[
typedef struct SHA256state_st
{
SHA_LONG h[8];
SHA_LONG Nl,Nh;
SHA_LONG data[SHA_LBLOCK];
unsigned int num,md_len;
} SHA256_CTX;
int SHA256_Init(SHA256_CTX *c);
int SHA256_Update(SHA256_CTX *c, const void *data, size_t len);
int SHA256_Final(unsigned char *md, SHA256_CTX *c);
]]
local digest_len = 32
local buf = ffi_new("char[?]", digest_len)
local ctx_ptr_type = ffi.typeof("SHA256_CTX[1]")
function _M.new(self)
local ctx = ffi_new(ctx_ptr_type)
if C.SHA256_Init(ctx) == 0 then
return nil
end
return setmetatable({ _ctx = ctx }, mt)
end
function _M.update(self, s)
return C.SHA256_Update(self._ctx, s, #s) == 1
end
function _M.final(self)
if C.SHA256_Final(buf, self._ctx) == 1 then
return ffi_str(buf, digest_len)
end
return nil
end
function _M.reset(self)
return C.SHA256_Init(self._ctx) == 1
end
return _M

60
resty/sha384.lua Normal file
View File

@@ -0,0 +1,60 @@
-- Copyright (C) by Yichun Zhang (agentzh)
require "resty.sha512"
local ffi = require "ffi"
local ffi_new = ffi.new
local ffi_str = ffi.string
local C = ffi.C
local setmetatable = setmetatable
--local error = error
local _M = { _VERSION = '0.14' }
local mt = { __index = _M }
ffi.cdef[[
int SHA384_Init(SHA512_CTX *c);
int SHA384_Update(SHA512_CTX *c, const void *data, size_t len);
int SHA384_Final(unsigned char *md, SHA512_CTX *c);
]]
local digest_len = 48
local buf = ffi_new("char[?]", digest_len)
local ctx_ptr_type = ffi.typeof("SHA512_CTX[1]")
function _M.new(self)
local ctx = ffi_new(ctx_ptr_type)
if C.SHA384_Init(ctx) == 0 then
return nil
end
return setmetatable({ _ctx = ctx }, mt)
end
function _M.update(self, s)
return C.SHA384_Update(self._ctx, s, #s) == 1
end
function _M.final(self)
if C.SHA384_Final(buf, self._ctx) == 1 then
return ffi_str(buf, digest_len)
end
return nil
end
function _M.reset(self)
return C.SHA384_Init(self._ctx) == 1
end
return _M

75
resty/sha512.lua Normal file
View File

@@ -0,0 +1,75 @@
-- Copyright (C) by Yichun Zhang (agentzh)
require "resty.sha"
local ffi = require "ffi"
local ffi_new = ffi.new
local ffi_str = ffi.string
local C = ffi.C
local setmetatable = setmetatable
--local error = error
local _M = { _VERSION = '0.14' }
local mt = { __index = _M }
ffi.cdef[[
enum {
SHA512_CBLOCK = SHA_LBLOCK*8
};
typedef struct SHA512state_st
{
SHA_LONG64 h[8];
SHA_LONG64 Nl,Nh;
union {
SHA_LONG64 d[SHA_LBLOCK];
unsigned char p[SHA512_CBLOCK];
} u;
unsigned int num,md_len;
} SHA512_CTX;
int SHA512_Init(SHA512_CTX *c);
int SHA512_Update(SHA512_CTX *c, const void *data, size_t len);
int SHA512_Final(unsigned char *md, SHA512_CTX *c);
]]
local digest_len = 64
local buf = ffi_new("char[?]", digest_len)
local ctx_ptr_type = ffi.typeof("SHA512_CTX[1]")
function _M.new(self)
local ctx = ffi_new(ctx_ptr_type)
if C.SHA512_Init(ctx) == 0 then
return nil
end
return setmetatable({ _ctx = ctx }, mt)
end
function _M.update(self, s)
return C.SHA512_Update(self._ctx, s, #s) == 1
end
function _M.final(self)
if C.SHA512_Final(buf, self._ctx) == 1 then
return ffi_str(buf, digest_len)
end
return nil
end
function _M.reset(self)
return C.SHA512_Init(self._ctx) == 1
end
return _M

40
resty/string.lua Normal file
View File

@@ -0,0 +1,40 @@
-- Copyright (C) by Yichun Zhang (agentzh)
local ffi = require "ffi"
local ffi_new = ffi.new
local ffi_str = ffi.string
local C = ffi.C
--local setmetatable = setmetatable
--local error = error
local tonumber = tonumber
local _M = { _VERSION = '0.14' }
ffi.cdef[[
typedef unsigned char u_char;
u_char * ngx_hex_dump(u_char *dst, const u_char *src, size_t len);
intptr_t ngx_atoi(const unsigned char *line, size_t n);
]]
local str_type = ffi.typeof("uint8_t[?]")
function _M.to_hex(s)
local len = #s
local buf_len = len * 2
local buf = ffi_new(str_type, buf_len)
C.ngx_hex_dump(buf, s, len)
return ffi_str(buf, buf_len)
end
function _M.atoi(s)
return tonumber(C.ngx_atoi(s, #s))
end
return _M

267
resty/upload.lua Normal file
View File

@@ -0,0 +1,267 @@
-- Copyright (C) Yichun Zhang (agentzh)
-- local sub = string.sub
local req_socket = ngx.req.socket
local match = string.match
local setmetatable = setmetatable
local type = type
local ngx_var = ngx.var
-- local print = print
local _M = { _VERSION = '0.10' }
local CHUNK_SIZE = 4096
local MAX_LINE_SIZE = 512
local STATE_BEGIN = 1
local STATE_READING_HEADER = 2
local STATE_READING_BODY = 3
local STATE_EOF = 4
local mt = { __index = _M }
local state_handlers
local function get_boundary()
local header = ngx_var.content_type
if not header then
return nil
end
if type(header) == "table" then
header = header[1]
end
local m = match(header, ";%s*boundary=\"([^\"]+)\"")
if m then
return m
end
return match(header, ";%s*boundary=([^\",;]+)")
end
function _M.new(self, chunk_size, max_line_size)
local boundary = get_boundary()
-- print("boundary: ", boundary)
if not boundary then
return nil, "no boundary defined in Content-Type"
end
-- print('boundary: "', boundary, '"')
local sock, err = req_socket()
if not sock then
return nil, err
end
local read2boundary, err = sock:receiveuntil("--" .. boundary)
if not read2boundary then
return nil, err
end
local read_line, err = sock:receiveuntil("\r\n")
if not read_line then
return nil, err
end
return setmetatable({
sock = sock,
size = chunk_size or CHUNK_SIZE,
line_size = max_line_size or MAX_LINE_SIZE,
read2boundary = read2boundary,
read_line = read_line,
boundary = boundary,
state = STATE_BEGIN
}, mt)
end
function _M.set_timeout(self, timeout)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
return sock:settimeout(timeout)
end
local function discard_line(self)
local read_line = self.read_line
local line, err = read_line(self.line_size)
if not line then
return nil, err
end
local dummy, err = read_line(1)
if dummy then
return nil, "line too long: " .. line .. dummy .. "..."
end
if err then
return nil, err
end
return 1
end
local function discard_rest(self)
local sock = self.sock
local size = self.size
while true do
local dummy, err = sock:receive(size)
if err and err ~= 'closed' then
return nil, err
end
if not dummy then
return 1
end
end
end
local function read_body_part(self)
local read2boundary = self.read2boundary
local chunk, err = read2boundary(self.size)
if err then
return nil, nil, err
end
if not chunk then
local sock = self.sock
local data = sock:receive(2)
if data == "--" then
local ok, err = discard_rest(self)
if not ok then
return nil, nil, err
end
self.state = STATE_EOF
return "part_end"
end
if data ~= "\r\n" then
local ok, err = discard_line(self)
if not ok then
return nil, nil, err
end
end
self.state = STATE_READING_HEADER
return "part_end"
end
return "body", chunk
end
local function read_header(self)
local read_line = self.read_line
local line, err = read_line(self.line_size)
if err then
return nil, nil, err
end
local dummy, err = read_line(1)
if dummy then
return nil, nil, "line too long: " .. line .. dummy .. "..."
end
if err then
return nil, nil, err
end
-- print("read line: ", line)
if line == "" then
-- after the last header
self.state = STATE_READING_BODY
return read_body_part(self)
end
local key, value = match(line, "([^: \t]+)%s*:%s*(.+)")
if not key then
return 'header', line
end
return 'header', {key, value, line}
end
local function eof()
return "eof", nil
end
function _M.read(self)
-- local size = self.size
local handler = state_handlers[self.state]
if handler then
return handler(self)
end
return nil, nil, "bad state: " .. self.state
end
local function read_preamble(self)
local sock = self.sock
if not sock then
return nil, nil, "not initialized"
end
local size = self.size
local read2boundary = self.read2boundary
while true do
local preamble = read2boundary(size)
if not preamble then
break
end
-- discard the preamble data chunk
-- print("read preamble: ", preamble)
end
local ok, err = discard_line(self)
if not ok then
return nil, nil, err
end
local read2boundary, err = sock:receiveuntil("\r\n--" .. self.boundary)
if not read2boundary then
return nil, nil, err
end
self.read2boundary = read2boundary
self.state = STATE_READING_HEADER
return read_header(self)
end
state_handlers = {
read_preamble,
read_header,
read_body_part,
eof
}
return _M

View File

@@ -0,0 +1,707 @@
local stream_sock = ngx.socket.tcp
local log = ngx.log
local ERR = ngx.ERR
local WARN = ngx.WARN
local DEBUG = ngx.DEBUG
local sub = string.sub
local re_find = ngx.re.find
local new_timer = ngx.timer.at
local shared = ngx.shared
local debug_mode = ngx.config.debug
local concat = table.concat
local tonumber = tonumber
local tostring = tostring
local ipairs = ipairs
local ceil = math.ceil
local spawn = ngx.thread.spawn
local wait = ngx.thread.wait
local pcall = pcall
local _M = {
_VERSION = '0.05'
}
if not ngx.config
or not ngx.config.ngx_lua_version
or ngx.config.ngx_lua_version < 9005
then
error("ngx_lua 0.9.5+ required")
end
local ok, upstream = pcall(require, "ngx.upstream")
if not ok then
error("ngx_upstream_lua module required")
end
local ok, new_tab = pcall(require, "table.new")
if not ok or type(new_tab) ~= "function" then
new_tab = function (narr, nrec) return {} end
end
local set_peer_down = upstream.set_peer_down
local get_primary_peers = upstream.get_primary_peers
local get_backup_peers = upstream.get_backup_peers
local get_upstreams = upstream.get_upstreams
local upstream_checker_statuses = {}
local function warn(...)
log(WARN, "healthcheck: ", ...)
end
local function errlog(...)
log(ERR, "healthcheck: ", ...)
end
local function debug(...)
-- print("debug mode: ", debug_mode)
if debug_mode then
log(DEBUG, "healthcheck: ", ...)
end
end
local function gen_peer_key(prefix, u, is_backup, id)
if is_backup then
return prefix .. u .. ":b" .. id
end
return prefix .. u .. ":p" .. id
end
local function set_peer_down_globally(ctx, is_backup, id, value)
local u = ctx.upstream
local dict = ctx.dict
local ok, err = set_peer_down(u, is_backup, id, value)
if not ok then
errlog("failed to set peer down: ", err)
end
if not ctx.new_version then
ctx.new_version = true
end
local key = gen_peer_key("d:", u, is_backup, id)
local ok, err = dict:set(key, value)
if not ok then
errlog("failed to set peer down state: ", err)
end
end
local function peer_fail(ctx, is_backup, id, peer)
debug("peer ", peer.name, " was checked to be not ok")
local u = ctx.upstream
local dict = ctx.dict
local key = gen_peer_key("nok:", u, is_backup, id)
local fails, err = dict:get(key)
if not fails then
if err then
errlog("failed to get peer nok key: ", err)
return
end
fails = 1
-- below may have a race condition, but it is fine for our
-- purpose here.
local ok, err = dict:set(key, 1)
if not ok then
errlog("failed to set peer nok key: ", err)
end
else
fails = fails + 1
local ok, err = dict:incr(key, 1)
if not ok then
errlog("failed to incr peer nok key: ", err)
end
end
if fails == 1 then
key = gen_peer_key("ok:", u, is_backup, id)
local succ, err = dict:get(key)
if not succ or succ == 0 then
if err then
errlog("failed to get peer ok key: ", err)
return
end
else
local ok, err = dict:set(key, 0)
if not ok then
errlog("failed to set peer ok key: ", err)
end
end
end
-- print("ctx fall: ", ctx.fall, ", peer down: ", peer.down,
-- ", fails: ", fails)
if not peer.down and fails >= ctx.fall then
warn("peer ", peer.name, " is turned down after ", fails,
" failure(s)")
peer.down = true
set_peer_down_globally(ctx, is_backup, id, true)
end
end
local function peer_ok(ctx, is_backup, id, peer)
debug("peer ", peer.name, " was checked to be ok")
local u = ctx.upstream
local dict = ctx.dict
local key = gen_peer_key("ok:", u, is_backup, id)
local succ, err = dict:get(key)
if not succ then
if err then
errlog("failed to get peer ok key: ", err)
return
end
succ = 1
-- below may have a race condition, but it is fine for our
-- purpose here.
local ok, err = dict:set(key, 1)
if not ok then
errlog("failed to set peer ok key: ", err)
end
else
succ = succ + 1
local ok, err = dict:incr(key, 1)
if not ok then
errlog("failed to incr peer ok key: ", err)
end
end
if succ == 1 then
key = gen_peer_key("nok:", u, is_backup, id)
local fails, err = dict:get(key)
if not fails or fails == 0 then
if err then
errlog("failed to get peer nok key: ", err)
return
end
else
local ok, err = dict:set(key, 0)
if not ok then
errlog("failed to set peer nok key: ", err)
end
end
end
if peer.down and succ >= ctx.rise then
warn("peer ", peer.name, " is turned up after ", succ,
" success(es)")
peer.down = nil
set_peer_down_globally(ctx, is_backup, id, nil)
end
end
-- shortcut error function for check_peer()
local function peer_error(ctx, is_backup, id, peer, ...)
if not peer.down then
errlog(...)
end
peer_fail(ctx, is_backup, id, peer)
end
local function check_peer(ctx, id, peer, is_backup)
local ok
local name = peer.name
local statuses = ctx.statuses
local req = ctx.http_req
local sock, err = stream_sock()
if not sock then
errlog("failed to create stream socket: ", err)
return
end
sock:settimeout(ctx.timeout)
if peer.host then
-- print("peer port: ", peer.port)
ok, err = sock:connect(peer.host, peer.port)
else
ok, err = sock:connect(name)
end
if not ok then
if not peer.down then
errlog("failed to connect to ", name, ": ", err)
end
return peer_fail(ctx, is_backup, id, peer)
end
local bytes, err = sock:send(req)
if not bytes then
return peer_error(ctx, is_backup, id, peer,
"failed to send request to ", name, ": ", err)
end
local status_line, err = sock:receive()
if not status_line then
peer_error(ctx, is_backup, id, peer,
"failed to receive status line from ", name, ": ", err)
if err == "timeout" then
sock:close() -- timeout errors do not close the socket.
end
return
end
if statuses then
local from, to, err = re_find(status_line,
[[^HTTP/\d+\.\d+\s+(\d+)]],
"joi", nil, 1)
if err then
errlog("failed to parse status line: ", err)
end
if not from then
peer_error(ctx, is_backup, id, peer,
"bad status line from ", name, ": ",
status_line)
sock:close()
return
end
local status = tonumber(sub(status_line, from, to))
if not statuses[status] then
peer_error(ctx, is_backup, id, peer, "bad status code from ",
name, ": ", status)
sock:close()
return
end
end
peer_ok(ctx, is_backup, id, peer)
sock:close()
end
local function check_peer_range(ctx, from, to, peers, is_backup)
for i = from, to do
check_peer(ctx, i - 1, peers[i], is_backup)
end
end
local function check_peers(ctx, peers, is_backup)
local n = #peers
if n == 0 then
return
end
local concur = ctx.concurrency
if concur <= 1 then
for i = 1, n do
check_peer(ctx, i - 1, peers[i], is_backup)
end
else
local threads
local nthr
if n <= concur then
nthr = n - 1
threads = new_tab(nthr, 0)
for i = 1, nthr do
if debug_mode then
debug("spawn a thread checking ",
is_backup and "backup" or "primary", " peer ", i - 1)
end
threads[i] = spawn(check_peer, ctx, i - 1, peers[i], is_backup)
end
-- use the current "light thread" to run the last task
if debug_mode then
debug("check ", is_backup and "backup" or "primary", " peer ",
n - 1)
end
check_peer(ctx, n - 1, peers[n], is_backup)
else
local group_size = ceil(n / concur)
nthr = ceil(n / group_size) - 1
threads = new_tab(nthr, 0)
local from = 1
local rest = n
for i = 1, nthr do
local to
if rest >= group_size then
rest = rest - group_size
to = from + group_size - 1
else
rest = 0
to = from + rest - 1
end
if debug_mode then
debug("spawn a thread checking ",
is_backup and "backup" or "primary", " peers ",
from - 1, " to ", to - 1)
end
threads[i] = spawn(check_peer_range, ctx, from, to, peers,
is_backup)
from = from + group_size
if rest == 0 then
break
end
end
if rest > 0 then
local to = from + rest - 1
if debug_mode then
debug("check ", is_backup and "backup" or "primary",
" peers ", from - 1, " to ", to - 1)
end
check_peer_range(ctx, from, to, peers, is_backup)
end
end
if nthr and nthr > 0 then
for i = 1, nthr do
local t = threads[i]
if t then
wait(t)
end
end
end
end
end
local function upgrade_peers_version(ctx, peers, is_backup)
local dict = ctx.dict
local u = ctx.upstream
local n = #peers
for i = 1, n do
local peer = peers[i]
local id = i - 1
local key = gen_peer_key("d:", u, is_backup, id)
local down = false
local res, err = dict:get(key)
if not res then
if err then
errlog("failed to get peer down state: ", err)
end
else
down = true
end
if (peer.down and not down) or (not peer.down and down) then
local ok, err = set_peer_down(u, is_backup, id, down)
if not ok then
errlog("failed to set peer down: ", err)
else
-- update our cache too
peer.down = down
end
end
end
end
local function check_peers_updates(ctx)
local dict = ctx.dict
local u = ctx.upstream
local key = "v:" .. u
local ver, err = dict:get(key)
if not ver then
if err then
errlog("failed to get peers version: ", err)
return
end
if ctx.version > 0 then
ctx.new_version = true
end
elseif ctx.version < ver then
debug("upgrading peers version to ", ver)
upgrade_peers_version(ctx, ctx.primary_peers, false);
upgrade_peers_version(ctx, ctx.backup_peers, true);
ctx.version = ver
end
end
local function get_lock(ctx)
local dict = ctx.dict
local key = "l:" .. ctx.upstream
-- the lock is held for the whole interval to prevent multiple
-- worker processes from sending the test request simultaneously.
-- here we substract the lock expiration time by 1ms to prevent
-- a race condition with the next timer event.
local ok, err = dict:add(key, true, ctx.interval - 0.001)
if not ok then
if err == "exists" then
return nil
end
errlog("failed to add key \"", key, "\": ", err)
return nil
end
return true
end
local function do_check(ctx)
debug("healthcheck: run a check cycle")
check_peers_updates(ctx)
if get_lock(ctx) then
check_peers(ctx, ctx.primary_peers, false)
check_peers(ctx, ctx.backup_peers, true)
end
if ctx.new_version then
local key = "v:" .. ctx.upstream
local dict = ctx.dict
if debug_mode then
debug("publishing peers version ", ctx.version + 1)
end
dict:add(key, 0)
local new_ver, err = dict:incr(key, 1)
if not new_ver then
errlog("failed to publish new peers version: ", err)
end
ctx.version = new_ver
ctx.new_version = nil
end
end
local function update_upstream_checker_status(upstream, success)
local cnt = upstream_checker_statuses[upstream]
if not cnt then
cnt = 0
end
if success then
cnt = cnt + 1
else
cnt = cnt - 1
end
upstream_checker_statuses[upstream] = cnt
end
local check
check = function (premature, ctx)
if premature then
return
end
local ok, err = pcall(do_check, ctx)
if not ok then
errlog("failed to run healthcheck cycle: ", err)
end
local ok, err = new_timer(ctx.interval, check, ctx)
if not ok then
if err ~= "process exiting" then
errlog("failed to create timer: ", err)
end
update_upstream_checker_status(ctx.upstream, false)
return
end
end
local function preprocess_peers(peers)
local n = #peers
for i = 1, n do
local p = peers[i]
local name = p.name
if name then
local from, to, err = re_find(name, [[^(.*):\d+$]], "jo", nil, 1)
if from then
p.host = sub(name, 1, to)
p.port = tonumber(sub(name, to + 2))
end
end
end
return peers
end
function _M.spawn_checker(opts)
local typ = opts.type
if not typ then
return nil, "\"type\" option required"
end
if typ ~= "http" then
return nil, "only \"http\" type is supported right now"
end
local http_req = opts.http_req
if not http_req then
return nil, "\"http_req\" option required"
end
local timeout = opts.timeout
if not timeout then
timeout = 1000
end
local interval = opts.interval
if not interval then
interval = 1
else
interval = interval / 1000
if interval < 0.002 then -- minimum 2ms
interval = 0.002
end
end
local valid_statuses = opts.valid_statuses
local statuses
if valid_statuses then
statuses = new_tab(0, #valid_statuses)
for _, status in ipairs(valid_statuses) do
-- print("found good status ", status)
statuses[status] = true
end
end
-- debug("interval: ", interval)
local concur = opts.concurrency
if not concur then
concur = 1
end
local fall = opts.fall
if not fall then
fall = 5
end
local rise = opts.rise
if not rise then
rise = 2
end
local shm = opts.shm
if not shm then
return nil, "\"shm\" option required"
end
local dict = shared[shm]
if not dict then
return nil, "shm \"" .. tostring(shm) .. "\" not found"
end
local u = opts.upstream
if not u then
return nil, "no upstream specified"
end
local ppeers, err = get_primary_peers(u)
if not ppeers then
return nil, "failed to get primary peers: " .. err
end
local bpeers, err = get_backup_peers(u)
if not bpeers then
return nil, "failed to get backup peers: " .. err
end
local ctx = {
upstream = u,
primary_peers = preprocess_peers(ppeers),
backup_peers = preprocess_peers(bpeers),
http_req = http_req,
timeout = timeout,
interval = interval,
dict = dict,
fall = fall,
rise = rise,
statuses = statuses,
version = 0,
concurrency = concur,
}
if debug_mode and opts.no_timer then
check(nil, ctx)
else
local ok, err = new_timer(0, check, ctx)
if not ok then
return nil, "failed to create timer: " .. err
end
end
update_upstream_checker_status(u, true)
return true
end
local function gen_peers_status_info(peers, bits, idx)
local npeers = #peers
for i = 1, npeers do
local peer = peers[i]
bits[idx] = " "
bits[idx + 1] = peer.name
if peer.down then
bits[idx + 2] = " DOWN\n"
else
bits[idx + 2] = " up\n"
end
idx = idx + 3
end
return idx
end
function _M.status_page()
-- generate an HTML page
local us, err = get_upstreams()
if not us then
return "failed to get upstream names: " .. err
end
local n = #us
local bits = new_tab(n * 20, 0)
local idx = 1
for i = 1, n do
if i > 1 then
bits[idx] = "\n"
idx = idx + 1
end
local u = us[i]
bits[idx] = "Upstream "
bits[idx + 1] = u
idx = idx + 2
local ncheckers = upstream_checker_statuses[u]
if not ncheckers or ncheckers == 0 then
bits[idx] = " (NO checkers)"
idx = idx + 1
end
bits[idx] = "\n Primary Peers\n"
idx = idx + 1
local peers, err = get_primary_peers(u)
if not peers then
return "failed to get primary peers in upstream " .. u .. ": "
.. err
end
idx = gen_peers_status_info(peers, bits, idx)
bits[idx] = " Backup Peers\n"
idx = idx + 1
peers, err = get_backup_peers(u)
if not peers then
return "failed to get backup peers in upstream " .. u .. ": "
.. err
end
idx = gen_peers_status_info(peers, bits, idx)
end
return concat(bits)
end
return _M

353
resty/websocket/client.lua Normal file
View File

@@ -0,0 +1,353 @@
-- Copyright (C) Yichun Zhang (agentzh)
-- FIXME: this library is very rough and is currently just for testing
-- the websocket server.
local wbproto = require "resty.websocket.protocol"
local bit = require "bit"
local _recv_frame = wbproto.recv_frame
local _send_frame = wbproto.send_frame
local new_tab = wbproto.new_tab
local tcp = ngx.socket.tcp
local re_match = ngx.re.match
local encode_base64 = ngx.encode_base64
local concat = table.concat
local char = string.char
local str_find = string.find
local rand = math.random
local rshift = bit.rshift
local band = bit.band
local setmetatable = setmetatable
local type = type
local debug = ngx.config.debug
local ngx_log = ngx.log
local ngx_DEBUG = ngx.DEBUG
local ssl_support = true
if not ngx.config
or not ngx.config.ngx_lua_version
or ngx.config.ngx_lua_version < 9011
then
ssl_support = false
end
local _M = new_tab(0, 13)
_M._VERSION = '0.08'
local mt = { __index = _M }
function _M.new(self, opts)
local sock, err = tcp()
if not sock then
return nil, err
end
local max_payload_len, send_unmasked, timeout
if opts then
max_payload_len = opts.max_payload_len
send_unmasked = opts.send_unmasked
timeout = opts.timeout
if timeout then
sock:settimeout(timeout)
end
end
return setmetatable({
sock = sock,
max_payload_len = max_payload_len or 65535,
send_unmasked = send_unmasked,
}, mt)
end
function _M.connect(self, uri, opts)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
local m, err = re_match(uri, [[^(wss?)://([^:/]+)(?::(\d+))?(.*)]], "jo")
if not m then
if err then
return nil, "failed to match the uri: " .. err
end
return nil, "bad websocket uri"
end
local scheme = m[1]
local host = m[2]
local port = m[3]
local path = m[4]
-- ngx.say("host: ", host)
-- ngx.say("port: ", port)
if not port then
port = 80
end
if path == "" then
path = "/"
end
local ssl_verify, headers, proto_header, origin_header, sock_opts = false
if opts then
local protos = opts.protocols
if protos then
if type(protos) == "table" then
proto_header = "\r\nSec-WebSocket-Protocol: "
.. concat(protos, ",")
else
proto_header = "\r\nSec-WebSocket-Protocol: " .. protos
end
end
local origin = opts.origin
if origin then
origin_header = "\r\nOrigin: " .. origin
end
local pool = opts.pool
if pool then
sock_opts = { pool = pool }
end
if opts.ssl_verify then
if not ssl_support then
return nil, "ngx_lua 0.9.11+ required for SSL sockets"
end
ssl_verify = true
end
if opts.headers then
headers = opts.headers
if type(headers) ~= "table" then
return nil, "custom headers must be a table"
end
end
end
local ok, err
if sock_opts then
ok, err = sock:connect(host, port, sock_opts)
else
ok, err = sock:connect(host, port)
end
if not ok then
return nil, "failed to connect: " .. err
end
if scheme == "wss" then
if not ssl_support then
return nil, "ngx_lua 0.9.11+ required for SSL sockets"
end
ok, err = sock:sslhandshake(false, host, ssl_verify)
if not ok then
return nil, "ssl handshake failed: " .. err
end
end
-- check for connections from pool:
local count, err = sock:getreusedtimes()
if not count then
return nil, "failed to get reused times: " .. err
end
if count > 0 then
-- being a reused connection (must have done handshake)
return 1
end
local custom_headers
if headers then
custom_headers = concat(headers, "\r\n")
custom_headers = "\r\n" .. custom_headers
end
-- do the websocket handshake:
local bytes = char(rand(256) - 1, rand(256) - 1, rand(256) - 1,
rand(256) - 1, rand(256) - 1, rand(256) - 1,
rand(256) - 1, rand(256) - 1, rand(256) - 1,
rand(256) - 1, rand(256) - 1, rand(256) - 1,
rand(256) - 1, rand(256) - 1, rand(256) - 1,
rand(256) - 1)
local key = encode_base64(bytes)
local req = "GET " .. path .. " HTTP/1.1\r\nUpgrade: websocket\r\nHost: "
.. host .. ":" .. port
.. "\r\nSec-WebSocket-Key: " .. key
.. (proto_header or "")
.. "\r\nSec-WebSocket-Version: 13"
.. (origin_header or "")
.. "\r\nConnection: Upgrade"
.. (custom_headers or "")
.. "\r\n\r\n"
local bytes, err = sock:send(req)
if not bytes then
return nil, "failed to send the handshake request: " .. err
end
local header_reader = sock:receiveuntil("\r\n\r\n")
-- FIXME: check for too big response headers
local header, err, partial = header_reader()
if not header then
return nil, "failed to receive response header: " .. err
end
-- error("header: " .. header)
-- FIXME: verify the response headers
m, err = re_match(header, [[^\s*HTTP/1\.1\s+]], "jo")
if not m then
return nil, "bad HTTP response status line: " .. header
end
return 1
end
function _M.set_timeout(self, time)
local sock = self.sock
if not sock then
return nil, nil, "not initialized yet"
end
return sock:settimeout(time)
end
function _M.recv_frame(self)
if self.fatal then
return nil, nil, "fatal error already happened"
end
local sock = self.sock
if not sock then
return nil, nil, "not initialized yet"
end
local data, typ, err = _recv_frame(sock, self.max_payload_len, false)
if not data and not str_find(err, ": timeout", 1, true) then
self.fatal = true
end
return data, typ, err
end
local function send_frame(self, fin, opcode, payload)
if self.fatal then
return nil, "fatal error already happened"
end
if self.closed then
return nil, "already closed"
end
local sock = self.sock
if not sock then
return nil, "not initialized yet"
end
local bytes, err = _send_frame(sock, fin, opcode, payload,
self.max_payload_len,
not self.send_unmasked)
if not bytes then
self.fatal = true
end
return bytes, err
end
_M.send_frame = send_frame
function _M.send_text(self, data)
return send_frame(self, true, 0x1, data)
end
function _M.send_binary(self, data)
return send_frame(self, true, 0x2, data)
end
local function send_close(self, code, msg)
local payload
if code then
if type(code) ~= "number" or code > 0x7fff then
return nil, "bad status code"
end
payload = char(band(rshift(code, 8), 0xff), band(code, 0xff))
.. (msg or "")
end
if debug then
ngx_log(ngx_DEBUG, "sending the close frame")
end
local bytes, err = send_frame(self, true, 0x8, payload)
if not bytes then
self.fatal = true
end
self.closed = true
return bytes, err
end
_M.send_close = send_close
function _M.send_ping(self, data)
return send_frame(self, true, 0x9, data)
end
function _M.send_pong(self, data)
return send_frame(self, true, 0xa, data)
end
function _M.close(self)
if self.fatal then
return nil, "fatal error already happened"
end
local sock = self.sock
if not sock then
return nil, "not initialized"
end
if not self.closed then
local bytes, err = send_close(self)
if not bytes then
return nil, "failed to send close frame: " .. err
end
end
return sock:close()
end
function _M.set_keepalive(self, ...)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
return sock:setkeepalive(...)
end
return _M

View File

@@ -0,0 +1,345 @@
-- Copyright (C) Yichun Zhang (agentzh)
local bit = require "bit"
local ffi = require "ffi"
local byte = string.byte
local char = string.char
local sub = string.sub
local band = bit.band
local bor = bit.bor
local bxor = bit.bxor
local lshift = bit.lshift
local rshift = bit.rshift
--local tohex = bit.tohex
local tostring = tostring
local concat = table.concat
local rand = math.random
local type = type
local debug = ngx.config.debug
local ngx_log = ngx.log
local ngx_DEBUG = ngx.DEBUG
local ffi_new = ffi.new
local ffi_string = ffi.string
local ok, new_tab = pcall(require, "table.new")
if not ok then
new_tab = function (narr, nrec) return {} end
end
local _M = new_tab(0, 5)
_M.new_tab = new_tab
_M._VERSION = '0.08'
local types = {
[0x0] = "continuation",
[0x1] = "text",
[0x2] = "binary",
[0x8] = "close",
[0x9] = "ping",
[0xa] = "pong",
}
local str_buf_size = 4096
local str_buf
local c_buf_type = ffi.typeof("char[?]")
local function get_string_buf(size)
if size > str_buf_size then
return ffi_new(c_buf_type, size)
end
if not str_buf then
str_buf = ffi_new(c_buf_type, str_buf_size)
end
return str_buf
end
function _M.recv_frame(sock, max_payload_len, force_masking)
local data, err = sock:receive(2)
if not data then
return nil, nil, "failed to receive the first 2 bytes: " .. err
end
local fst, snd = byte(data, 1, 2)
local fin = band(fst, 0x80) ~= 0
-- print("fin: ", fin)
if band(fst, 0x70) ~= 0 then
return nil, nil, "bad RSV1, RSV2, or RSV3 bits"
end
local opcode = band(fst, 0x0f)
-- print("opcode: ", tohex(opcode))
if opcode >= 0x3 and opcode <= 0x7 then
return nil, nil, "reserved non-control frames"
end
if opcode >= 0xb and opcode <= 0xf then
return nil, nil, "reserved control frames"
end
local mask = band(snd, 0x80) ~= 0
if debug then
ngx_log(ngx_DEBUG, "recv_frame: mask bit: ", mask and 1 or 0)
end
if force_masking and not mask then
return nil, nil, "frame unmasked"
end
local payload_len = band(snd, 0x7f)
-- print("payload len: ", payload_len)
if payload_len == 126 then
local data, err = sock:receive(2)
if not data then
return nil, nil, "failed to receive the 2 byte payload length: "
.. (err or "unknown")
end
payload_len = bor(lshift(byte(data, 1), 8), byte(data, 2))
elseif payload_len == 127 then
local data, err = sock:receive(8)
if not data then
return nil, nil, "failed to receive the 8 byte payload length: "
.. (err or "unknown")
end
if byte(data, 1) ~= 0
or byte(data, 2) ~= 0
or byte(data, 3) ~= 0
or byte(data, 4) ~= 0
then
return nil, nil, "payload len too large"
end
local fifth = byte(data, 5)
if band(fifth, 0x80) ~= 0 then
return nil, nil, "payload len too large"
end
payload_len = bor(lshift(fifth, 24),
lshift(byte(data, 6), 16),
lshift(byte(data, 7), 8),
byte(data, 8))
end
if band(opcode, 0x8) ~= 0 then
-- being a control frame
if payload_len > 125 then
return nil, nil, "too long payload for control frame"
end
if not fin then
return nil, nil, "fragmented control frame"
end
end
-- print("payload len: ", payload_len, ", max payload len: ",
-- max_payload_len)
if payload_len > max_payload_len then
return nil, nil, "exceeding max payload len"
end
local rest
if mask then
rest = payload_len + 4
else
rest = payload_len
end
-- print("rest: ", rest)
local data
if rest > 0 then
data, err = sock:receive(rest)
if not data then
return nil, nil, "failed to read masking-len and payload: "
.. (err or "unknown")
end
else
data = ""
end
-- print("received rest")
if opcode == 0x8 then
-- being a close frame
if payload_len > 0 then
if payload_len < 2 then
return nil, nil, "close frame with a body must carry a 2-byte"
.. " status code"
end
local msg, code
if mask then
local fst = bxor(byte(data, 4 + 1), byte(data, 1))
local snd = bxor(byte(data, 4 + 2), byte(data, 2))
code = bor(lshift(fst, 8), snd)
if payload_len > 2 then
-- TODO string.buffer optimizations
local bytes = get_string_buf(payload_len - 2)
for i = 3, payload_len do
bytes[i - 3] = bxor(byte(data, 4 + i),
byte(data, (i - 1) % 4 + 1))
end
msg = ffi_string(bytes, payload_len - 2)
else
msg = ""
end
else
local fst = byte(data, 1)
local snd = byte(data, 2)
code = bor(lshift(fst, 8), snd)
-- print("parsing unmasked close frame payload: ", payload_len)
if payload_len > 2 then
msg = sub(data, 3)
else
msg = ""
end
end
return msg, "close", code
end
return "", "close", nil
end
local msg
if mask then
-- TODO string.buffer optimizations
local bytes = get_string_buf(payload_len)
for i = 1, payload_len do
bytes[i - 1] = bxor(byte(data, 4 + i),
byte(data, (i - 1) % 4 + 1))
end
msg = ffi_string(bytes, payload_len)
else
msg = data
end
return msg, types[opcode], not fin and "again" or nil
end
local function build_frame(fin, opcode, payload_len, payload, masking)
-- XXX optimize this when we have string.buffer in LuaJIT 2.1
local fst
if fin then
fst = bor(0x80, opcode)
else
fst = opcode
end
local snd, extra_len_bytes
if payload_len <= 125 then
snd = payload_len
extra_len_bytes = ""
elseif payload_len <= 65535 then
snd = 126
extra_len_bytes = char(band(rshift(payload_len, 8), 0xff),
band(payload_len, 0xff))
else
if band(payload_len, 0x7fffffff) < payload_len then
return nil, "payload too big"
end
snd = 127
-- XXX we only support 31-bit length here
extra_len_bytes = char(0, 0, 0, 0, band(rshift(payload_len, 24), 0xff),
band(rshift(payload_len, 16), 0xff),
band(rshift(payload_len, 8), 0xff),
band(payload_len, 0xff))
end
local masking_key
if masking then
-- set the mask bit
snd = bor(snd, 0x80)
local key = rand(0xffffffff)
masking_key = char(band(rshift(key, 24), 0xff),
band(rshift(key, 16), 0xff),
band(rshift(key, 8), 0xff),
band(key, 0xff))
-- TODO string.buffer optimizations
local bytes = get_string_buf(payload_len)
for i = 1, payload_len do
bytes[i - 1] = bxor(byte(payload, i),
byte(masking_key, (i - 1) % 4 + 1))
end
payload = ffi_string(bytes, payload_len)
else
masking_key = ""
end
return char(fst, snd) .. extra_len_bytes .. masking_key .. payload
end
_M.build_frame = build_frame
function _M.send_frame(sock, fin, opcode, payload, max_payload_len, masking)
-- ngx.log(ngx.WARN, ngx.var.uri, ": masking: ", masking)
if not payload then
payload = ""
elseif type(payload) ~= "string" then
payload = tostring(payload)
end
local payload_len = #payload
if payload_len > max_payload_len then
return nil, "payload too big"
end
if band(opcode, 0x8) ~= 0 then
-- being a control frame
if payload_len > 125 then
return nil, "too much payload for control frame"
end
if not fin then
return nil, "fragmented control frame"
end
end
local frame, err = build_frame(fin, opcode, payload_len, payload,
masking)
if not frame then
return nil, "failed to build frame: " .. err
end
local bytes, err = sock:send(frame)
if not bytes then
return nil, "failed to send frame: " .. err
end
return bytes
end
return _M

210
resty/websocket/server.lua Normal file
View File

@@ -0,0 +1,210 @@
-- Copyright (C) Yichun Zhang (agentzh)
local bit = require "bit"
local wbproto = require "resty.websocket.protocol"
local new_tab = wbproto.new_tab
local _recv_frame = wbproto.recv_frame
local _send_frame = wbproto.send_frame
local http_ver = ngx.req.http_version
local req_sock = ngx.req.socket
local ngx_header = ngx.header
local req_headers = ngx.req.get_headers
local str_lower = string.lower
local char = string.char
local str_find = string.find
local sha1_bin = ngx.sha1_bin
local base64 = ngx.encode_base64
local ngx = ngx
local read_body = ngx.req.read_body
local band = bit.band
local rshift = bit.rshift
local type = type
local setmetatable = setmetatable
local tostring = tostring
-- local print = print
local _M = new_tab(0, 10)
_M._VERSION = '0.08'
local mt = { __index = _M }
function _M.new(self, opts)
if ngx.headers_sent then
return nil, "response header already sent"
end
read_body()
if http_ver() ~= 1.1 then
return nil, "bad http version"
end
local headers = req_headers()
local val = headers.upgrade
if type(val) == "table" then
val = val[1]
end
if not val or str_lower(val) ~= "websocket" then
return nil, "bad \"upgrade\" request header: " .. tostring(val)
end
val = headers.connection
if type(val) == "table" then
val = val[1]
end
if not val or not str_find(str_lower(val), "upgrade", 1, true) then
return nil, "bad \"connection\" request header"
end
local key = headers["sec-websocket-key"]
if type(key) == "table" then
key = key[1]
end
if not key then
return nil, "bad \"sec-websocket-key\" request header"
end
local ver = headers["sec-websocket-version"]
if type(ver) == "table" then
ver = ver[1]
end
if not ver or ver ~= "13" then
return nil, "bad \"sec-websocket-version\" request header"
end
local protocols = headers["sec-websocket-protocol"]
if type(protocols) == "table" then
protocols = protocols[1]
end
if protocols then
ngx_header["Sec-WebSocket-Protocol"] = protocols
end
ngx_header["Upgrade"] = "websocket"
local sha1 = sha1_bin(key .. "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
ngx_header["Sec-WebSocket-Accept"] = base64(sha1)
ngx_header["Content-Type"] = nil
ngx.status = 101
local ok, err = ngx.send_headers()
if not ok then
return nil, "failed to send response header: " .. (err or "unknonw")
end
ok, err = ngx.flush(true)
if not ok then
return nil, "failed to flush response header: " .. (err or "unknown")
end
local sock
sock, err = req_sock(true)
if not sock then
return nil, err
end
local max_payload_len, send_masked, timeout
if opts then
max_payload_len = opts.max_payload_len
send_masked = opts.send_masked
timeout = opts.timeout
if timeout then
sock:settimeout(timeout)
end
end
return setmetatable({
sock = sock,
max_payload_len = max_payload_len or 65535,
send_masked = send_masked,
}, mt)
end
function _M.set_timeout(self, time)
local sock = self.sock
if not sock then
return nil, nil, "not initialized yet"
end
return sock:settimeout(time)
end
function _M.recv_frame(self)
if self.fatal then
return nil, nil, "fatal error already happened"
end
local sock = self.sock
if not sock then
return nil, nil, "not initialized yet"
end
local data, typ, err = _recv_frame(sock, self.max_payload_len, true)
if not data and not str_find(err, ": timeout", 1, true) then
self.fatal = true
end
return data, typ, err
end
local function send_frame(self, fin, opcode, payload)
if self.fatal then
return nil, "fatal error already happened"
end
local sock = self.sock
if not sock then
return nil, "not initialized yet"
end
local bytes, err = _send_frame(sock, fin, opcode, payload,
self.max_payload_len, self.send_masked)
if not bytes then
self.fatal = true
end
return bytes, err
end
_M.send_frame = send_frame
function _M.send_text(self, data)
return send_frame(self, true, 0x1, data)
end
function _M.send_binary(self, data)
return send_frame(self, true, 0x2, data)
end
function _M.send_close(self, code, msg)
local payload
if code then
if type(code) ~= "number" or code > 0x7fff then
end
payload = char(band(rshift(code, 8), 0xff), band(code, 0xff))
.. (msg or "")
end
return send_frame(self, true, 0x8, payload)
end
function _M.send_ping(self, data)
return send_frame(self, true, 0x9, data)
end
function _M.send_pong(self, data)
return send_frame(self, true, 0xa, data)
end
return _M

View File

@@ -1,2 +1,2 @@
jsConfuse()
dateReplace()
dateReplace()

33
tools.lua Normal file
View File

@@ -0,0 +1,33 @@
function i_get_cookie(s_cookie)
local cookie = {}
-- string.gfind is renamed to string.gmatch
for item in string.gmatch(s_cookie, "[^;]+") do
local _, _, k, v = string.find(item, "^%s*(%S+)%s*=%s*(%S+)%s*")
if k ~= nil and v ~= nil then
cookie[k] = v
end
end
return cookie
end
function get_cookie_table()
local raw_cookie = ngx.req.get_headers()["Cookie"]
if raw_cookie ~= nil then
return i_get_cookie(raw_cookie)
end
return nil
end
function get_cookie_raw()
return ngx.req.get_headers()["Cookie"]
end
function match_string(input_str, rule)
if input_str == nil then
return false
end
local from, to, err = ngx.re.find(input_str, rule, "jo")
return from ~= nil
end

11
waf/report.lua Normal file
View File

@@ -0,0 +1,11 @@
local _M = {}
function _M.violation(result)
full_violation_text = "violation: \n"
if result == g_violation_sql_detect then
g_result_sql_detect = full_violation_text .. "SQL Injection detected \n"
end
full_violation_text = full_violation_text .. ngx.var.request_uri .. " \n"
log(violation)
say_html()
end
return _M

6
waf/rule.lua Normal file
View File

@@ -0,0 +1,6 @@
local _M = {
sql_get = "'|\\b(and|or)\\b.+?(>|<|=|\\bin\\b|\\blike\\b)|\\/\\*.+?\\*\\/|<\\s*script\\b|\\bEXEC\\b|UNION.+?SELECT|UPDATE.+?SET|INSERT\\s+INTO.+?VALUES|(SELECT|DELETE).+?FROM|(CREATE|ALTER|DROP|TRUNCATE)\\s+(TABLE|DATABASE)",
sql_post = "\\b(and|or)\\b.{1,6}?(=|>|<|\\bin\\b|\\blike\\b)|\\/\\*.+?\\*\\/|<\\s*script\\b|\\bEXEC\\b|UNION.+?SELECT|UPDATE.+?SET|INSERT\\s+INTO.+?VALUES|(SELECT|DELETE).+?FROM|(CREATE|ALTER|DROP|TRUNCATE)\\s+(TABLE|DATABASE)",
sql_cookie = "\\b(and|or)\\b.{1,6}?(=|>|<|\\bin\\b|\\blike\\b)|\\/\\*.+?\\*\\/|<\\s*script\\b|\\bEXEC\\b|UNION.+?SELECT|UPDATE.+?SET|INSERT\\s+INTO.+?VALUES|(SELECT|DELETE).+?FROM|(CREATE|ALTER|DROP|TRUNCATE)\\s+(TABLE|DATABASE)"
}
return _M

37
waf/sql.lua Normal file
View File

@@ -0,0 +1,37 @@
local _M = {}
local ngx_base = require "resty.core.base"
local waf_rule = require "waf/rule"
local waf_voilation = require "waf/violation_list"
local waf_report = require "waf/report"
function _M.waf_sql_filter_params(arg_tables, filter_rule)
if arg_tables then
for key, val in pairs(arg_tables) do
if match_string(val, filter_rule) then
--ngx.say(key .. ":" .. val)
waf_report.violation(waf_voilation.sql_detect)
end
end
end
end
function _M.waf_sql_filter()
--local get_arags = ngx.req.get_uri_args
--for k, v in ipairs(get_arags) do
-- print(v)
--end
if ngx_base.get_request() ~= nil then
_M.waf_sql_filter_params(ngx.req.get_uri_args(), waf_rule.sql_get)
ngx.say(get_cookie_raw())
if match_string(get_cookie_raw(), waf_rule.sql_cookie) then
waf_report.violation(waf_voilation.sql_detect)
end
local is_post_method = ngx.req.get_method() == "POST"
if is_post_method then
ngx.req.read_body()
_M.waf_sql_filter_params(ngx.req.get_post_args(), waf_rule.sql_post)
end
end
end
return _M

3
waf/violation_list.lua Normal file
View File

@@ -0,0 +1,3 @@
return {
sql_detect = 1
}

8
waf/waf.lua Normal file
View File

@@ -0,0 +1,8 @@
local waf_sql = require "waf/sql"
--以后加规则配置、插件这些.现在不加
function waf_dispatch()
waf_sql.waf_sql_filter()
end
waf_dispatch()