From fcace799df65f9dab98211857b381613be07513a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=90=B4=E5=BF=83?= Date: Thu, 3 Mar 2022 16:16:53 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0sql=E7=9A=84waf=E5=8A=9F?= =?UTF-8?q?=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- init.lua | 102 ++- log.lua | 56 +- resty/aes.lua | 303 +++++++ resty/core.lua | 35 + resty/core/base.lua | 259 ++++++ resty/core/base64.lua | 115 +++ resty/core/ctx.lua | 143 ++++ resty/core/exit.lua | 66 ++ resty/core/hash.lua | 154 ++++ resty/core/misc.lua | 240 ++++++ resty/core/ndk.lua | 92 +++ resty/core/phase.lua | 59 ++ resty/core/regex.lua | 1213 +++++++++++++++++++++++++++ resty/core/request.lua | 451 ++++++++++ resty/core/response.lua | 183 +++++ resty/core/shdict.lua | 638 +++++++++++++++ resty/core/socket.lua | 124 +++ resty/core/time.lua | 159 ++++ resty/core/uri.lua | 115 +++ resty/core/utils.lua | 46 ++ resty/core/var.lua | 160 ++++ resty/core/worker.lua | 77 ++ resty/dns/resolver.lua | 982 ++++++++++++++++++++++ resty/limit/conn.lua | 125 +++ resty/limit/count.lua | 103 +++ resty/limit/req.lua | 153 ++++ resty/limit/traffic.lua | 58 ++ resty/lock.lua | 221 +++++ resty/lrucache.lua | 340 ++++++++ resty/lrucache/pureffi.lua | 606 ++++++++++++++ resty/md5.lua | 72 ++ resty/memcached.lua | 744 +++++++++++++++++ resty/mysql.lua | 1410 ++++++++++++++++++++++++++++++++ resty/random.lua | 36 + resty/redis.lua | 676 +++++++++++++++ resty/sha.lua | 19 + resty/sha1.lua | 69 ++ resty/sha224.lua | 60 ++ resty/sha256.lua | 69 ++ resty/sha384.lua | 60 ++ resty/sha512.lua | 75 ++ resty/string.lua | 40 + resty/upload.lua | 267 ++++++ resty/upstream/healthcheck.lua | 707 ++++++++++++++++ resty/websocket/client.lua | 353 ++++++++ resty/websocket/protocol.lua | 345 ++++++++ resty/websocket/server.lua | 210 +++++ rsp_body.lua | 2 +- tools.lua | 33 + waf/report.lua | 11 + waf/rule.lua | 6 + waf/sql.lua | 37 + waf/violation_list.lua | 3 + waf/waf.lua | 8 + 54 files changed, 12617 insertions(+), 73 deletions(-) create mode 100644 resty/aes.lua create mode 100644 resty/core.lua create mode 100644 resty/core/base.lua create mode 100644 resty/core/base64.lua create mode 100644 resty/core/ctx.lua create mode 100644 resty/core/exit.lua create mode 100644 resty/core/hash.lua create mode 100644 resty/core/misc.lua create mode 100644 resty/core/ndk.lua create mode 100644 resty/core/phase.lua create mode 100644 resty/core/regex.lua create mode 100644 resty/core/request.lua create mode 100644 resty/core/response.lua create mode 100644 resty/core/shdict.lua create mode 100644 resty/core/socket.lua create mode 100644 resty/core/time.lua create mode 100644 resty/core/uri.lua create mode 100644 resty/core/utils.lua create mode 100644 resty/core/var.lua create mode 100644 resty/core/worker.lua create mode 100644 resty/dns/resolver.lua create mode 100644 resty/limit/conn.lua create mode 100644 resty/limit/count.lua create mode 100644 resty/limit/req.lua create mode 100644 resty/limit/traffic.lua create mode 100644 resty/lock.lua create mode 100644 resty/lrucache.lua create mode 100644 resty/lrucache/pureffi.lua create mode 100644 resty/md5.lua create mode 100644 resty/memcached.lua create mode 100644 resty/mysql.lua create mode 100644 resty/random.lua create mode 100644 resty/redis.lua create mode 100644 resty/sha.lua create mode 100644 resty/sha1.lua create mode 100644 resty/sha224.lua create mode 100644 resty/sha256.lua create mode 100644 resty/sha384.lua create mode 100644 resty/sha512.lua create mode 100644 resty/string.lua create mode 100644 resty/upload.lua create mode 100644 resty/upstream/healthcheck.lua create mode 100644 resty/websocket/client.lua create mode 100644 resty/websocket/protocol.lua create mode 100644 resty/websocket/server.lua create mode 100644 tools.lua create mode 100644 waf/report.lua create mode 100644 waf/rule.lua create mode 100644 waf/sql.lua create mode 100644 waf/violation_list.lua create mode 100644 waf/waf.lua diff --git a/init.lua b/init.lua index 599feb4..4fd4bf2 100644 --- a/init.lua +++ b/init.lua @@ -1,42 +1,45 @@ -require 'config' -require 'b64' -require 'aes' -require 'log' -require '403' -require 'tableXstring' -require 'fileio' -require 'randomStr' -require 'whiteList' +require "config" +require "b64" +require "aes" +require "log" +require "403" +require "tableXstring" +require "fileio" +require "randomStr" +require "whiteList" +require "tools" +require "waf/waf" -local optionIsOn = function (options) return options == "on" and true or false end +local optionIsOn = function(options) + return options == "on" and true or false +end ToolsProtect = optionIsOn(toolsProtect) ShiroProtect = optionIsOn(shiroProtect) JsProtect = optionIsOn(jsProtect) JsConfuse = false SensitiveProtect = optionIsOn(sensitiveProtect) - -- cookie加密 function reqCookieParse() if ShiroProtect then local userCookieX9 = ngx.var.cookie_x9i7RDYX23 - if not userCookieX9 then -- 没有cookie - log('0-cookie 无cookie', '') - ngx.req.set_header('Cookie', '') -- 移除其他cookie - elseif #userCookieX9 < 32 then -- 判断cookie长度 - log('1-cookie 不符合要求', userCookieX9) - ngx.say('4') + if not userCookieX9 then -- 没有cookie + log("0-cookie 无cookie", "") + ngx.req.set_header("Cookie", "") -- 移除其他cookie + elseif #userCookieX9 < 32 then -- 判断cookie长度 + log("1-cookie 不符合要求", userCookieX9) + ngx.say("4") say_html() - else --有cookie + else --有cookie local result = xpcall(dencrypT, emptyPrint, userCookieX9, aesKey) if not result then --解密失败 - log('2-cookie 无法解密', userCookieX9) - ngx.say('5') + log("2-cookie 无法解密", userCookieX9) + ngx.say("5") say_html() - else --解密成功 + else --解密成功 local originCookie = StrToTable(dencrypT(userCookieX9, aesKey)) - ngx.req.set_header('Cookie', transTable(originCookie)) - log('3-cookie 解密成功', userCookieX9) + ngx.req.set_header("Cookie", transTable(originCookie)) + log("3-cookie 解密成功", userCookieX9) end end end @@ -46,9 +49,9 @@ function respCookieEncrypt() if ShiroProtect then local value = ngx.resp.get_headers()["Set-Cookie"] if value then - local encryptedCookie = cookieD.."="..encrypT(TableToStr(value), aesKey) + local encryptedCookie = cookieD .. "=" .. encrypT(TableToStr(value), aesKey) ngx.header["Set-Cookie"] = encryptedCookie - log('4-cookie 加密成功',encryptedCookie) + log("4-cookie 加密成功", encryptedCookie) end end end @@ -58,30 +61,30 @@ function toolsInfoSpider() if ToolsProtect and not whiteExtCheck() then local clientCookieA = ngx.var.cookie_h0yGbdRv local clientCookieB = ngx.var.cookie_kQpFHdoh - if not (clientCookieA and clientCookieB) then --没有cookieA进入reload,302至html生成cookie后再请求原地址 - local ip = 'xxx' - local finalPath = 'http://'..ip..'/'..jsPath..'?origin='..encodeBase64(ngx.var.request_uri) - log('1-tools 无cookieA/B', '') + if not (clientCookieA and clientCookieB) then --没有cookieA进入reload,302至html生成cookie后再请求原地址 + local ip = "xxx" + local finalPath = "http://" .. ip .. "/" .. jsPath .. "?origin=" .. encodeBase64(ngx.var.request_uri) + log("1-tools 无cookieA/B", "") ngx.redirect(finalPath, 302) else local result = xpcall(dencrypT, emptyPrint, clientCookieB, clientCookieA) if not result then - log('2-tools 解密失败', clientCookieA..', '..clientCookieB) - ngx.say('1') + log("2-tools 解密失败", clientCookieA .. ", " .. clientCookieB) + ngx.say("1") say_html() -- 解密失败 - else-- 可以解密,提取数据 + else -- 可以解密,提取数据 local result2 = dencrypT(clientCookieB, clientCookieA) if #result2 < 1 then - log('3-tools 解密失败', result2) + log("3-tools 解密失败", result2) else - local srs = split(result2, ',') - local _,e = string.find(srs[1], '0') + local srs = split(result2, ",") + local _, e = string.find(srs[1], "0") if e ~= nil then - log('4-tools 工具请求', result2) - ngx.say('2') + log("4-tools 工具请求", result2) + ngx.say("2") say_html() else - log('0-tools 工具验证通过, 记录浏览器指纹', '', srs[2]) + log("0-tools 工具验证通过, 记录浏览器指纹", "", srs[2]) end end end @@ -93,7 +96,7 @@ end function jsExtDetect() if JsProtect then local ext = string.match(ngx.var.uri, ".+%.(%w+)$") - if ext == 'js' then -- 加入检查,js文件是否存在 + if ext == "js" then -- 加入检查,js文件是否存在 JsConfuse = true end end @@ -102,14 +105,14 @@ end function jsConfuse() if JsConfuse then local originBody = ngx.arg[1] - if #originBody > 200 then -- 筛选空js + if #originBody > 200 then -- 筛选空js local s = getRandom(8) - local path = '/tmp/'..s - writefile(path, originBody, 'w+') - local t = io.popen('export NODE_PATH=/usr/lib/node_modules && node /gate/node/js_confuse.js '..path) + local path = "/tmp/" .. s + writefile(path, originBody, "w+") + local t = io.popen("export NODE_PATH=/usr/lib/node_modules && node /gate/node/js_confuse.js " .. path) local a = t:read("*all") ngx.arg[1] = a - os.execute('rm -f '..path) + os.execute("rm -f " .. path) end JsConfuse = false end @@ -122,14 +125,3 @@ function dateReplace() ngx.arg[1] = replaceTelephone end end - - - - - - - - - - - diff --git a/log.lua b/log.lua index f5ba796..a68f211 100644 --- a/log.lua +++ b/log.lua @@ -1,21 +1,24 @@ -require 'config' +require "config" - -local optionIsOn = function (options) return options == "on" and true or false end +local optionIsOn = function(options) + return options == "on" and true or false +end local Attacklog = optionIsOn(attacklog) local logpath = logdir local function getClientIp() - IP = ngx.var.remote_addr + IP = ngx.var.remote_addr if IP == nil then - IP = "unknown" + IP = "unknown" end return IP end -local function write(logfile,msg) - local fd = io.open(logfile,"ab") - if fd == nil then return end +local function write(logfile, msg) + local fd = io.open(logfile, "ab") + if fd == nil then + return + end fd:write(msg) fd:flush() fd:close() @@ -23,19 +26,38 @@ end function log(data, ruletag, fp) if Attacklog then - local fingerprint = fp or '' + local fingerprint = fp or "" local realIp = getClientIp() local method = ngx.var.request_method local ua = ngx.var.http_user_agent - local servername=ngx.var.server_name + local servername = ngx.var.server_name local url = ngx.var.request_uri - local time=ngx.localtime() - if ua then - line = realIp.." ["..time.."] \""..method.." "..servername..url.."\" \""..ruletag.."\" \""..ua.."\" \""..data.."\" \""..fingerprint.."\"\n" + local time = ngx.localtime() + if ua then + line = + realIp .. + " [" .. + time .. + '] "' .. + method .. + " " .. + servername .. + url .. + '" "' .. + ruletag .. + '" "' .. ua .. '" "' .. data .. '" "' .. fingerprint .. '"\n' else - line = realIp.." ["..time.."] \""..method.." "..servername..url.."\" \""..ruletag.."\" - \""..data.."\" \""..fingerprint.."\"\n" + line = + realIp .. + " [" .. + time .. + '] "' .. + method .. + " " .. + servername .. + url .. '" "' .. ruletag .. '" - "' .. data .. '" "' .. fingerprint .. '"\n' end - local filename = logpath..'/'..servername.."_"..ngx.today().."_sec.log" - write(filename,line) + local filename = logpath .. "/" .. servername .. "_" .. ngx.today() .. "_sec.log" + write(filename, line) end -end \ No newline at end of file +end diff --git a/resty/aes.lua b/resty/aes.lua new file mode 100644 index 0000000..57aafe8 --- /dev/null +++ b/resty/aes.lua @@ -0,0 +1,303 @@ +-- Copyright (C) by Yichun Zhang (agentzh) + + +--local asn1 = require "resty.asn1" +local ffi = require "ffi" +local ffi_new = ffi.new +local ffi_gc = ffi.gc +local ffi_str = ffi.string +local ffi_copy = ffi.copy +local C = ffi.C +local setmetatable = setmetatable +--local error = error +local type = type + + +local _M = { _VERSION = '0.14' } + +local mt = { __index = _M } + +local EVP_CTRL_AEAD_SET_IVLEN = 0x09 +local EVP_CTRL_AEAD_GET_TAG = 0x10 +local EVP_CTRL_AEAD_SET_TAG = 0x11 + +ffi.cdef[[ +typedef struct engine_st ENGINE; + +typedef struct evp_cipher_st EVP_CIPHER; +typedef struct evp_cipher_ctx_st EVP_CIPHER_CTX; + +typedef struct env_md_ctx_st EVP_MD_CTX; +typedef struct env_md_st EVP_MD; + +const EVP_MD *EVP_md5(void); +const EVP_MD *EVP_sha(void); +const EVP_MD *EVP_sha1(void); +const EVP_MD *EVP_sha224(void); +const EVP_MD *EVP_sha256(void); +const EVP_MD *EVP_sha384(void); +const EVP_MD *EVP_sha512(void); + +const EVP_CIPHER *EVP_aes_128_ecb(void); +const EVP_CIPHER *EVP_aes_128_cbc(void); +const EVP_CIPHER *EVP_aes_128_cfb1(void); +const EVP_CIPHER *EVP_aes_128_cfb8(void); +const EVP_CIPHER *EVP_aes_128_cfb128(void); +const EVP_CIPHER *EVP_aes_128_ofb(void); +const EVP_CIPHER *EVP_aes_128_ctr(void); +const EVP_CIPHER *EVP_aes_192_ecb(void); +const EVP_CIPHER *EVP_aes_192_cbc(void); +const EVP_CIPHER *EVP_aes_192_cfb1(void); +const EVP_CIPHER *EVP_aes_192_cfb8(void); +const EVP_CIPHER *EVP_aes_192_cfb128(void); +const EVP_CIPHER *EVP_aes_192_ofb(void); +const EVP_CIPHER *EVP_aes_192_ctr(void); +const EVP_CIPHER *EVP_aes_256_ecb(void); +const EVP_CIPHER *EVP_aes_256_cbc(void); +const EVP_CIPHER *EVP_aes_256_cfb1(void); +const EVP_CIPHER *EVP_aes_256_cfb8(void); +const EVP_CIPHER *EVP_aes_256_cfb128(void); +const EVP_CIPHER *EVP_aes_256_ofb(void); +const EVP_CIPHER *EVP_aes_128_gcm(void); +const EVP_CIPHER *EVP_aes_192_gcm(void); +const EVP_CIPHER *EVP_aes_256_gcm(void); + +EVP_CIPHER_CTX *EVP_CIPHER_CTX_new(); +void EVP_CIPHER_CTX_free(EVP_CIPHER_CTX *a); +int EVP_CIPHER_CTX_block_size(const EVP_CIPHER_CTX *ctx); + +int EVP_EncryptInit_ex(EVP_CIPHER_CTX *ctx,const EVP_CIPHER *cipher, + ENGINE *impl, unsigned char *key, const unsigned char *iv); + +int EVP_EncryptUpdate(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl, + const unsigned char *in, int inl); + +int EVP_EncryptFinal_ex(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl); + +int EVP_DecryptInit_ex(EVP_CIPHER_CTX *ctx,const EVP_CIPHER *cipher, + ENGINE *impl, unsigned char *key, const unsigned char *iv); + +int EVP_DecryptUpdate(EVP_CIPHER_CTX *ctx, unsigned char *out, int *outl, + const unsigned char *in, int inl); + +int EVP_DecryptFinal_ex(EVP_CIPHER_CTX *ctx, unsigned char *outm, int *outl); + +int EVP_BytesToKey(const EVP_CIPHER *type,const EVP_MD *md, + const unsigned char *salt, const unsigned char *data, int datal, + int count, unsigned char *key,unsigned char *iv); + +int EVP_CIPHER_CTX_ctrl(EVP_CIPHER_CTX *ctx, int type, int arg, void *ptr); +]] + +local hash +hash = { + md5 = C.EVP_md5(), + sha1 = C.EVP_sha1(), + sha224 = C.EVP_sha224(), + sha256 = C.EVP_sha256(), + sha384 = C.EVP_sha384(), + sha512 = C.EVP_sha512() +} +_M.hash = hash + +local EVP_MAX_BLOCK_LENGTH = 32 + +local cipher +cipher = function (size, _cipher) + local _size = size or 128 + local _cipher = _cipher or "cbc" + local func = "EVP_aes_" .. _size .. "_" .. _cipher + if C[func] then + return { size=_size, cipher=_cipher, method=C[func]()} + else + return nil + end +end +_M.cipher = cipher + +function _M.new(self, key, salt, _cipher, _hash, hash_rounds, iv_len) + local encrypt_ctx = C.EVP_CIPHER_CTX_new() + if encrypt_ctx == nil then + return nil, "no memory" + end + + ffi_gc(encrypt_ctx, C.EVP_CIPHER_CTX_free) + + local decrypt_ctx = C.EVP_CIPHER_CTX_new() + if decrypt_ctx == nil then + return nil, "no memory" + end + + ffi_gc(decrypt_ctx, C.EVP_CIPHER_CTX_free) + + local _cipher = _cipher or cipher() + local _hash = _hash or hash.md5 + local hash_rounds = hash_rounds or 1 + local _cipherLength = _cipher.size/8 + local gen_key = ffi_new("unsigned char[?]",_cipherLength) + local gen_iv = ffi_new("unsigned char[?]",_cipherLength) + iv_len = iv_len or _cipherLength + + if type(_hash) == "table" then + if not _hash.iv then + return nil, "iv is needed" + end + + --[[ Depending on the encryption algorithm, the length of iv will be + different. For detailed, please refer to + https://www.openssl.org/docs/man1.1.0/man3/EVP_CIPHER_CTX_ctrl.html + ]] + iv_len = #_hash.iv + if iv_len > _cipherLength then + return nil, "bad iv length" + end + + if _hash.method then + local tmp_key = _hash.method(key) + + if #tmp_key ~= _cipherLength then + return nil, "bad key length" + end + + ffi_copy(gen_key, tmp_key, _cipherLength) + + elseif #key ~= _cipherLength then + return nil, "bad key length" + + else + ffi_copy(gen_key, key, _cipherLength) + end + + ffi_copy(gen_iv, _hash.iv, iv_len) + + else + if salt and #salt ~= 8 then + return nil, "salt must be 8 characters or nil" + end + + if C.EVP_BytesToKey(_cipher.method, _hash, salt, key, #key, + hash_rounds, gen_key, gen_iv) + ~= _cipherLength + then + return nil, "failed to generate key and iv" + end + end + + if C.EVP_EncryptInit_ex(encrypt_ctx, _cipher.method, nil, + nil, nil) == 0 or + C.EVP_DecryptInit_ex(decrypt_ctx, _cipher.method, nil, + nil, nil) == 0 then + return nil, "failed to init ctx" + end + + local cipher_name = _cipher.cipher + if cipher_name == "gcm" + or cipher_name == "ccm" + or cipher_name == "ocb" then + if C.EVP_CIPHER_CTX_ctrl(encrypt_ctx, EVP_CTRL_AEAD_SET_IVLEN, + iv_len, nil) == 0 or + C.EVP_CIPHER_CTX_ctrl(decrypt_ctx, EVP_CTRL_AEAD_SET_IVLEN, + iv_len, nil) == 0 then + return nil, "failed to set IV length" + end + end + + return setmetatable({ + _encrypt_ctx = encrypt_ctx, + _decrypt_ctx = decrypt_ctx, + _cipher = _cipher.cipher, + _key = gen_key, + _iv = gen_iv + }, mt) +end + + +function _M.encrypt(self, s) + local typ = type(self) + if typ ~= "table" then + error("bad argument #1 self: table expected, got " .. typ, 2) + end + + local s_len = #s + local max_len = s_len + 2 * EVP_MAX_BLOCK_LENGTH + local buf = ffi_new("unsigned char[?]", max_len) + local out_len = ffi_new("int[1]") + local tmp_len = ffi_new("int[1]") + local ctx = self._encrypt_ctx + + if C.EVP_EncryptInit_ex(ctx, nil, nil, self._key, self._iv) == 0 then + return nil, "EVP_EncryptInit_ex failed" + end + + if C.EVP_EncryptUpdate(ctx, buf, out_len, s, s_len) == 0 then + return nil, "EVP_EncryptUpdate failed" + end + + if self._cipher == "gcm" then + local encrypt_data = ffi_str(buf, out_len[0]) + if C.EVP_EncryptFinal_ex(ctx, buf, out_len) == 0 then + return nil, "EVP_DecryptFinal_ex failed" + end + + -- FIXME: For OCB mode the taglen must either be 16 + -- or the value previously set via EVP_CTRL_OCB_SET_TAGLEN. + -- so we should extend this api in the future + C.EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_GET_TAG, 16, buf); + local tag = ffi_str(buf, 16) + return {encrypt_data, tag} + end + + if C.EVP_EncryptFinal_ex(ctx, buf + out_len[0], tmp_len) == 0 then + return nil, "EVP_EncryptFinal_ex failed" + end + + return ffi_str(buf, out_len[0] + tmp_len[0]) +end + + +function _M.decrypt(self, s, tag) + local typ = type(self) + if typ ~= "table" then + error("bad argument #1 self: table expected, got " .. typ, 2) + end + + local s_len = #s + local max_len = s_len + 2 * EVP_MAX_BLOCK_LENGTH + local buf = ffi_new("unsigned char[?]", max_len) + local out_len = ffi_new("int[1]") + local tmp_len = ffi_new("int[1]") + local ctx = self._decrypt_ctx + + if C.EVP_DecryptInit_ex(ctx, nil, nil, self._key, self._iv) == 0 then + return nil, "EVP_DecryptInit_ex failed" + end + + if C.EVP_DecryptUpdate(ctx, buf, out_len, s, s_len) == 0 then + return nil, "EVP_DecryptUpdate failed" + end + + if self._cipher == "gcm" then + local plain_txt = ffi_str(buf, out_len[0]) + if tag ~= nil then + local tag_buf = ffi_new("unsigned char[?]", 16) + ffi.copy(tag_buf, tag, 16) + C.EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_SET_TAG, 16, tag_buf); + end + + if C.EVP_DecryptFinal_ex(ctx, buf + out_len[0], tmp_len) == 0 then + return nil, "EVP_DecryptFinal_ex failed" + end + + return plain_txt + end + + if C.EVP_DecryptFinal_ex(ctx, buf + out_len[0], tmp_len) == 0 then + return nil, "EVP_DecryptFinal_ex failed" + end + + return ffi_str(buf, out_len[0] + tmp_len[0]) +end + + +return _M + diff --git a/resty/core.lua b/resty/core.lua new file mode 100644 index 0000000..5472230 --- /dev/null +++ b/resty/core.lua @@ -0,0 +1,35 @@ +-- Copyright (C) Yichun Zhang (agentzh) + +local subsystem = ngx.config.subsystem + + +require "resty.core.var" +require "resty.core.worker" +require "resty.core.regex" +require "resty.core.shdict" +require "resty.core.time" +require "resty.core.hash" +require "resty.core.uri" +require "resty.core.exit" +require "resty.core.base64" +require "resty.core.request" + + +if subsystem == 'http' then + require "resty.core.response" + require "resty.core.phase" + require "resty.core.ndk" + require "resty.core.socket" +end + + +require "resty.core.misc" +require "resty.core.ctx" + + +local base = require "resty.core.base" + + +return { + version = base.version +} diff --git a/resty/core/base.lua b/resty/core/base.lua new file mode 100644 index 0000000..608e110 --- /dev/null +++ b/resty/core/base.lua @@ -0,0 +1,259 @@ +-- Copyright (C) Yichun Zhang (agentzh) + + +local ffi = require 'ffi' +local ffi_new = ffi.new +local error = error +local select = select +local ceil = math.ceil +local subsystem = ngx.config.subsystem + + +local str_buf_size = 4096 +local str_buf +local size_ptr +local FREE_LIST_REF = 0 + + +if subsystem == 'http' then + local ngx_lua_v = ngx.config.ngx_lua_version + if not ngx.config + or not ngx.config.ngx_lua_version + or (ngx_lua_v ~= 10019 and ngx_lua_v ~= 10020) + then + error("ngx_http_lua_module 0.10.19 or 0.10.20 required") + end + +elseif subsystem == 'stream' then + if not ngx.config + or not ngx.config.ngx_lua_version + or ngx.config.ngx_lua_version ~= 10 + then + error("ngx_stream_lua_module 0.0.10 required") + end + +else + error("ngx_http_lua_module 0.10.20 or " + .. "ngx_stream_lua_module 0.0.10 required") +end + + +if string.find(jit.version, " 2.0", 1, true) then + ngx.log(ngx.ALERT, "use of lua-resty-core with LuaJIT 2.0 is ", + "not recommended; use LuaJIT 2.1+ instead") +end + + +local ok, new_tab = pcall(require, "table.new") +if not ok then + new_tab = function (narr, nrec) return {} end +end + + +local clear_tab +ok, clear_tab = pcall(require, "table.clear") +if not ok then + local pairs = pairs + clear_tab = function (tab) + for k, _ in pairs(tab) do + tab[k] = nil + end + end +end + + +-- XXX for now LuaJIT 2.1 cannot compile require() +-- so we make the fast code path Lua only in our own +-- wrapper so that most of the require() calls in hot +-- Lua code paths can be JIT compiled. +do + local orig_require = require + local pkg_loaded = package.loaded + local function my_require(name) + local mod = pkg_loaded[name] + if mod then + return mod + end + return orig_require(name) + end + getfenv(0).require = my_require +end + + +if not pcall(ffi.typeof, "ngx_str_t") then + ffi.cdef[[ + typedef struct { + size_t len; + const unsigned char *data; + } ngx_str_t; + ]] +end + + +if subsystem == 'http' then + if not pcall(ffi.typeof, "ngx_http_request_t") then + ffi.cdef[[ + typedef struct ngx_http_request_s ngx_http_request_t; + ]] + end + + if not pcall(ffi.typeof, "ngx_http_lua_ffi_str_t") then + ffi.cdef[[ + typedef struct { + int len; + const unsigned char *data; + } ngx_http_lua_ffi_str_t; + ]] + end + +elseif subsystem == 'stream' then + if not pcall(ffi.typeof, "ngx_stream_lua_request_t") then + ffi.cdef[[ + typedef struct ngx_stream_lua_request_s ngx_stream_lua_request_t; + ]] + end + + if not pcall(ffi.typeof, "ngx_stream_lua_ffi_str_t") then + ffi.cdef[[ + typedef struct { + int len; + const unsigned char *data; + } ngx_stream_lua_ffi_str_t; + ]] + end + +else + error("unknown subsystem: " .. subsystem) +end + + +local c_buf_type = ffi.typeof("char[?]") + + +local _M = new_tab(0, 18) + + +_M.version = "0.1.22" +_M.new_tab = new_tab +_M.clear_tab = clear_tab + + +local errmsg + + +function _M.get_errmsg_ptr() + if not errmsg then + errmsg = ffi_new("char *[1]") + end + return errmsg +end + + +if not ngx then + error("no existing ngx. table found") +end + + +function _M.set_string_buf_size(size) + if size <= 0 then + return + end + if str_buf then + str_buf = nil + end + str_buf_size = ceil(size) +end + + +function _M.get_string_buf_size() + return str_buf_size +end + + +function _M.get_size_ptr() + if not size_ptr then + size_ptr = ffi_new("size_t[1]") + end + + return size_ptr +end + + +function _M.get_string_buf(size, must_alloc) + -- ngx.log(ngx.ERR, "str buf size: ", str_buf_size) + if size > str_buf_size or must_alloc then + return ffi_new(c_buf_type, size) + end + + if not str_buf then + str_buf = ffi_new(c_buf_type, str_buf_size) + end + + return str_buf +end + + +function _M.ref_in_table(tb, key) + if key == nil then + return -1 + end + local ref = tb[FREE_LIST_REF] + if ref and ref ~= 0 then + tb[FREE_LIST_REF] = tb[ref] + + else + ref = #tb + 1 + end + tb[ref] = key + + -- print("ref key_id returned ", ref) + return ref +end + + +function _M.allows_subsystem(...) + local total = select("#", ...) + + for i = 1, total do + if select(i, ...) == subsystem then + return + end + end + + error("unsupported subsystem: " .. subsystem, 2) +end + + +_M.FFI_OK = 0 +_M.FFI_NO_REQ_CTX = -100 +_M.FFI_BAD_CONTEXT = -101 +_M.FFI_ERROR = -1 +_M.FFI_AGAIN = -2 +_M.FFI_BUSY = -3 +_M.FFI_DONE = -4 +_M.FFI_DECLINED = -5 + + +do + local exdata + + ok, exdata = pcall(require, "thread.exdata") + if ok and exdata then + function _M.get_request() + local r = exdata() + if r ~= nil then + return r + end + end + + else + local getfenv = getfenv + + function _M.get_request() + return getfenv(0).__ngx_req + end + end +end + + +return _M diff --git a/resty/core/base64.lua b/resty/core/base64.lua new file mode 100644 index 0000000..8a0e463 --- /dev/null +++ b/resty/core/base64.lua @@ -0,0 +1,115 @@ +-- Copyright (C) Yichun Zhang (agentzh) + + +local ffi = require "ffi" +local base = require "resty.core.base" + + +local C = ffi.C +local ffi_string = ffi.string +local ngx = ngx +local type = type +local error = error +local floor = math.floor +local tostring = tostring +local get_string_buf = base.get_string_buf +local get_size_ptr = base.get_size_ptr +local subsystem = ngx.config.subsystem + + +local ngx_lua_ffi_encode_base64 +local ngx_lua_ffi_decode_base64 + + +if subsystem == "http" then + ffi.cdef[[ + size_t ngx_http_lua_ffi_encode_base64(const unsigned char *src, + size_t len, unsigned char *dst, + int no_padding); + + int ngx_http_lua_ffi_decode_base64(const unsigned char *src, + size_t len, unsigned char *dst, + size_t *dlen); + ]] + + ngx_lua_ffi_encode_base64 = C.ngx_http_lua_ffi_encode_base64 + ngx_lua_ffi_decode_base64 = C.ngx_http_lua_ffi_decode_base64 + +elseif subsystem == "stream" then + ffi.cdef[[ + size_t ngx_stream_lua_ffi_encode_base64(const unsigned char *src, + size_t len, unsigned char *dst, + int no_padding); + + int ngx_stream_lua_ffi_decode_base64(const unsigned char *src, + size_t len, unsigned char *dst, + size_t *dlen); + ]] + + ngx_lua_ffi_encode_base64 = C.ngx_stream_lua_ffi_encode_base64 + ngx_lua_ffi_decode_base64 = C.ngx_stream_lua_ffi_decode_base64 +end + + +local function base64_encoded_length(len, no_padding) + return no_padding and floor((len * 8 + 5) / 6) or + floor((len + 2) / 3) * 4 +end + + +ngx.encode_base64 = function (s, no_padding) + if type(s) ~= 'string' then + if not s then + s = '' + else + s = tostring(s) + end + end + + local slen = #s + local no_padding_bool = false; + local no_padding_int = 0; + + if no_padding then + if no_padding ~= true then + local typ = type(no_padding) + error("bad no_padding: boolean expected, got " .. typ, 2) + end + + no_padding_bool = true + no_padding_int = 1; + end + + local dlen = base64_encoded_length(slen, no_padding_bool) + local dst = get_string_buf(dlen) + local r_dlen = ngx_lua_ffi_encode_base64(s, slen, dst, no_padding_int) + -- if dlen ~= r_dlen then error("discrepancy in len") end + return ffi_string(dst, r_dlen) +end + + +local function base64_decoded_length(len) + return floor((len + 3) / 4) * 3 +end + + +ngx.decode_base64 = function (s) + if type(s) ~= 'string' then + error("string argument only", 2) + end + local slen = #s + local dlen = base64_decoded_length(slen) + -- print("dlen: ", tonumber(dlen)) + local dst = get_string_buf(dlen) + local pdlen = get_size_ptr() + local ok = ngx_lua_ffi_decode_base64(s, slen, dst, pdlen) + if ok == 0 then + return nil + end + return ffi_string(dst, pdlen[0]) +end + + +return { + version = base.version +} diff --git a/resty/core/ctx.lua b/resty/core/ctx.lua new file mode 100644 index 0000000..1495c60 --- /dev/null +++ b/resty/core/ctx.lua @@ -0,0 +1,143 @@ +-- Copyright (C) Yichun Zhang (agentzh) + + +local ffi = require "ffi" +local debug = require "debug" +local base = require "resty.core.base" +local misc = require "resty.core.misc" + + +local C = ffi.C +local register_getter = misc.register_ngx_magic_key_getter +local register_setter = misc.register_ngx_magic_key_setter +local registry = debug.getregistry() +local new_tab = base.new_tab +local ref_in_table = base.ref_in_table +local get_request = base.get_request +local FFI_NO_REQ_CTX = base.FFI_NO_REQ_CTX +local FFI_OK = base.FFI_OK +local error = error +local setmetatable = setmetatable +local type = type +local subsystem = ngx.config.subsystem + + +local ngx_lua_ffi_get_ctx_ref +local ngx_lua_ffi_set_ctx_ref + + +if subsystem == "http" then + ffi.cdef[[ + int ngx_http_lua_ffi_get_ctx_ref(ngx_http_request_t *r, int *in_ssl_phase, + int *ssl_ctx_ref); + int ngx_http_lua_ffi_set_ctx_ref(ngx_http_request_t *r, int ref); + ]] + + ngx_lua_ffi_get_ctx_ref = C.ngx_http_lua_ffi_get_ctx_ref + ngx_lua_ffi_set_ctx_ref = C.ngx_http_lua_ffi_set_ctx_ref + +elseif subsystem == "stream" then + ffi.cdef[[ + int ngx_stream_lua_ffi_get_ctx_ref(ngx_stream_lua_request_t *r, + int *in_ssl_phase, int *ssl_ctx_ref); + int ngx_stream_lua_ffi_set_ctx_ref(ngx_stream_lua_request_t *r, int ref); + ]] + + ngx_lua_ffi_get_ctx_ref = C.ngx_stream_lua_ffi_get_ctx_ref + ngx_lua_ffi_set_ctx_ref = C.ngx_stream_lua_ffi_set_ctx_ref +end + + +local _M = { + _VERSION = base.version +} + + +local get_ctx_table +do + local in_ssl_phase = ffi.new("int[1]") + local ssl_ctx_ref = ffi.new("int[1]") + + function get_ctx_table(ctx) + local r = get_request() + + if not r then + error("no request found") + end + + local ctx_ref = ngx_lua_ffi_get_ctx_ref(r, in_ssl_phase, ssl_ctx_ref) + if ctx_ref == FFI_NO_REQ_CTX then + error("no request ctx found") + end + + local ctxs = registry.ngx_lua_ctx_tables + if ctx_ref < 0 then + ctx_ref = ssl_ctx_ref[0] + if ctx_ref > 0 and ctxs[ctx_ref] then + if in_ssl_phase[0] ~= 0 then + return ctxs[ctx_ref] + end + + if not ctx then + ctx = new_tab(0, 4) + end + + ctx = setmetatable(ctx, ctxs[ctx_ref]) + + else + if in_ssl_phase[0] ~= 0 then + if not ctx then + ctx = new_tab(1, 4) + end + + -- to avoid creating another table, we assume the users + -- won't overwrite the `__index` key + ctx.__index = ctx + + elseif not ctx then + ctx = new_tab(0, 4) + end + end + + ctx_ref = ref_in_table(ctxs, ctx) + if ngx_lua_ffi_set_ctx_ref(r, ctx_ref) ~= FFI_OK then + return nil + end + return ctx + end + return ctxs[ctx_ref] + end +end +register_getter("ctx", get_ctx_table) +_M.get_ctx_table = get_ctx_table + + +local function set_ctx_table(ctx) + local ctx_type = type(ctx) + if ctx_type ~= "table" then + error("ctx should be a table while getting a " .. ctx_type) + end + + local r = get_request() + + if not r then + error("no request found") + end + + local ctx_ref = ngx_lua_ffi_get_ctx_ref(r, nil, nil) + if ctx_ref == FFI_NO_REQ_CTX then + error("no request ctx found") + end + + local ctxs = registry.ngx_lua_ctx_tables + if ctx_ref < 0 then + ctx_ref = ref_in_table(ctxs, ctx) + ngx_lua_ffi_set_ctx_ref(r, ctx_ref) + return + end + ctxs[ctx_ref] = ctx +end +register_setter("ctx", set_ctx_table) + + +return _M diff --git a/resty/core/exit.lua b/resty/core/exit.lua new file mode 100644 index 0000000..30a7b61 --- /dev/null +++ b/resty/core/exit.lua @@ -0,0 +1,66 @@ +-- Copyright (C) Yichun Zhang (agentzh) + + +local ffi = require "ffi" +local base = require "resty.core.base" + + +local C = ffi.C +local ffi_string = ffi.string +local ngx = ngx +local error = error +local get_string_buf = base.get_string_buf +local get_size_ptr = base.get_size_ptr +local get_request = base.get_request +local co_yield = coroutine._yield +local subsystem = ngx.config.subsystem + + +local ngx_lua_ffi_exit + + +if subsystem == "http" then + ffi.cdef[[ + int ngx_http_lua_ffi_exit(ngx_http_request_t *r, int status, + unsigned char *err, size_t *errlen); + ]] + + ngx_lua_ffi_exit = C.ngx_http_lua_ffi_exit + +elseif subsystem == "stream" then + ffi.cdef[[ + int ngx_stream_lua_ffi_exit(ngx_stream_lua_request_t *r, int status, + unsigned char *err, size_t *errlen); + ]] + + ngx_lua_ffi_exit = C.ngx_stream_lua_ffi_exit +end + + +local ERR_BUF_SIZE = 128 +local FFI_DONE = base.FFI_DONE + + +ngx.exit = function (rc) + local err = get_string_buf(ERR_BUF_SIZE) + local errlen = get_size_ptr() + local r = get_request() + if r == nil then + error("no request found") + end + errlen[0] = ERR_BUF_SIZE + rc = ngx_lua_ffi_exit(r, rc, err, errlen) + if rc == 0 then + -- print("yielding...") + return co_yield() + end + if rc == FFI_DONE then + return + end + error(ffi_string(err, errlen[0]), 2) +end + + +return { + version = base.version +} diff --git a/resty/core/hash.lua b/resty/core/hash.lua new file mode 100644 index 0000000..062f3ff --- /dev/null +++ b/resty/core/hash.lua @@ -0,0 +1,154 @@ +-- Copyright (C) Yichun Zhang (agentzh) + + +local ffi = require "ffi" +local base = require "resty.core.base" + + +local C = ffi.C +local ffi_new = ffi.new +local ffi_string = ffi.string +local ngx = ngx +local type = type +local error = error +local tostring = tostring +local subsystem = ngx.config.subsystem + + +local ngx_lua_ffi_md5 +local ngx_lua_ffi_md5_bin +local ngx_lua_ffi_sha1_bin +local ngx_lua_ffi_crc32_long +local ngx_lua_ffi_crc32_short + + +if subsystem == "http" then + ffi.cdef[[ + void ngx_http_lua_ffi_md5_bin(const unsigned char *src, size_t len, + unsigned char *dst); + + void ngx_http_lua_ffi_md5(const unsigned char *src, size_t len, + unsigned char *dst); + + int ngx_http_lua_ffi_sha1_bin(const unsigned char *src, size_t len, + unsigned char *dst); + + unsigned int ngx_http_lua_ffi_crc32_long(const unsigned char *src, + size_t len); + + unsigned int ngx_http_lua_ffi_crc32_short(const unsigned char *src, + size_t len); + ]] + + ngx_lua_ffi_md5 = C.ngx_http_lua_ffi_md5 + ngx_lua_ffi_md5_bin = C.ngx_http_lua_ffi_md5_bin + ngx_lua_ffi_sha1_bin = C.ngx_http_lua_ffi_sha1_bin + ngx_lua_ffi_crc32_short = C.ngx_http_lua_ffi_crc32_short + ngx_lua_ffi_crc32_long = C.ngx_http_lua_ffi_crc32_long + +elseif subsystem == "stream" then + ffi.cdef[[ + void ngx_stream_lua_ffi_md5_bin(const unsigned char *src, size_t len, + unsigned char *dst); + + void ngx_stream_lua_ffi_md5(const unsigned char *src, size_t len, + unsigned char *dst); + + int ngx_stream_lua_ffi_sha1_bin(const unsigned char *src, size_t len, + unsigned char *dst); + + unsigned int ngx_stream_lua_ffi_crc32_long(const unsigned char *src, + size_t len); + + unsigned int ngx_stream_lua_ffi_crc32_short(const unsigned char *src, + size_t len); + ]] + + ngx_lua_ffi_md5 = C.ngx_stream_lua_ffi_md5 + ngx_lua_ffi_md5_bin = C.ngx_stream_lua_ffi_md5_bin + ngx_lua_ffi_sha1_bin = C.ngx_stream_lua_ffi_sha1_bin + ngx_lua_ffi_crc32_short = C.ngx_stream_lua_ffi_crc32_short + ngx_lua_ffi_crc32_long = C.ngx_stream_lua_ffi_crc32_long +end + + +local MD5_DIGEST_LEN = 16 +local md5_buf = ffi_new("unsigned char[?]", MD5_DIGEST_LEN) + +ngx.md5_bin = function (s) + if type(s) ~= 'string' then + if not s then + s = '' + else + s = tostring(s) + end + end + ngx_lua_ffi_md5_bin(s, #s, md5_buf) + return ffi_string(md5_buf, MD5_DIGEST_LEN) +end + + +local MD5_HEX_DIGEST_LEN = MD5_DIGEST_LEN * 2 +local md5_hex_buf = ffi_new("unsigned char[?]", MD5_HEX_DIGEST_LEN) + +ngx.md5 = function (s) + if type(s) ~= 'string' then + if not s then + s = '' + else + s = tostring(s) + end + end + ngx_lua_ffi_md5(s, #s, md5_hex_buf) + return ffi_string(md5_hex_buf, MD5_HEX_DIGEST_LEN) +end + + +local SHA_DIGEST_LEN = 20 +local sha_buf = ffi_new("unsigned char[?]", SHA_DIGEST_LEN) + +ngx.sha1_bin = function (s) + if type(s) ~= 'string' then + if not s then + s = '' + else + s = tostring(s) + end + end + local ok = ngx_lua_ffi_sha1_bin(s, #s, sha_buf) + if ok == 0 then + error("SHA-1 support missing in Nginx") + end + return ffi_string(sha_buf, SHA_DIGEST_LEN) +end + + +ngx.crc32_short = function (s) + if type(s) ~= "string" then + if not s then + s = "" + else + s = tostring(s) + end + end + + return ngx_lua_ffi_crc32_short(s, #s) +end + + +ngx.crc32_long = function (s) + if type(s) ~= "string" then + if not s then + s = "" + else + s = tostring(s) + end + end + + return ngx_lua_ffi_crc32_long(s, #s) +end + + +return { + version = base.version +} diff --git a/resty/core/misc.lua b/resty/core/misc.lua new file mode 100644 index 0000000..ff7954a --- /dev/null +++ b/resty/core/misc.lua @@ -0,0 +1,240 @@ +-- Copyright (C) Yichun Zhang (agentzh) + + +local base = require "resty.core.base" +local ffi = require "ffi" +local os = require "os" + + +local C = ffi.C +local ffi_new = ffi.new +local ffi_str = ffi.string +local ngx = ngx +local type = type +local error = error +local rawget = rawget +local rawset = rawset +local tonumber = tonumber +local setmetatable = setmetatable +local FFI_OK = base.FFI_OK +local FFI_NO_REQ_CTX = base.FFI_NO_REQ_CTX +local FFI_BAD_CONTEXT = base.FFI_BAD_CONTEXT +local new_tab = base.new_tab +local get_request = base.get_request +local get_size_ptr = base.get_size_ptr +local get_string_buf = base.get_string_buf +local get_string_buf_size = base.get_string_buf_size +local subsystem = ngx.config.subsystem + + +local ngx_lua_ffi_get_resp_status +local ngx_lua_ffi_get_conf_env +local ngx_magic_key_getters +local ngx_magic_key_setters + + +local _M = new_tab(0, 3) +local ngx_mt = new_tab(0, 2) + + +if subsystem == "http" then + ngx_magic_key_getters = new_tab(0, 4) + ngx_magic_key_setters = new_tab(0, 2) + +elseif subsystem == "stream" then + ngx_magic_key_getters = new_tab(0, 2) + ngx_magic_key_setters = new_tab(0, 1) +end + + +local function register_getter(key, func) + ngx_magic_key_getters[key] = func +end +_M.register_ngx_magic_key_getter = register_getter + + +local function register_setter(key, func) + ngx_magic_key_setters[key] = func +end +_M.register_ngx_magic_key_setter = register_setter + + +ngx_mt.__index = function (tb, key) + local f = ngx_magic_key_getters[key] + if f then + return f() + end + return rawget(tb, key) +end + + +ngx_mt.__newindex = function (tb, key, ctx) + local f = ngx_magic_key_setters[key] + if f then + return f(ctx) + end + return rawset(tb, key, ctx) +end + + +setmetatable(ngx, ngx_mt) + + +if subsystem == "http" then + ffi.cdef[[ + int ngx_http_lua_ffi_get_resp_status(ngx_http_request_t *r); + int ngx_http_lua_ffi_set_resp_status(ngx_http_request_t *r, int r); + int ngx_http_lua_ffi_is_subrequest(ngx_http_request_t *r); + int ngx_http_lua_ffi_headers_sent(ngx_http_request_t *r); + int ngx_http_lua_ffi_get_conf_env(const unsigned char *name, + unsigned char **env_buf, + size_t *name_len); + ]] + + + ngx_lua_ffi_get_resp_status = C.ngx_http_lua_ffi_get_resp_status + ngx_lua_ffi_get_conf_env = C.ngx_http_lua_ffi_get_conf_env + + + -- ngx.status + + + local function set_status(status) + local r = get_request() + + if not r then + error("no request found") + end + + if type(status) ~= 'number' then + status = tonumber(status) + end + + local rc = C.ngx_http_lua_ffi_set_resp_status(r, status) + + if rc == FFI_BAD_CONTEXT then + error("API disabled in the current context", 2) + end + + return + end + register_setter("status", set_status) + + + -- ngx.is_subrequest + + + local function is_subreq() + local r = get_request() + + if not r then + error("no request found") + end + + local rc = C.ngx_http_lua_ffi_is_subrequest(r) + + if rc == FFI_BAD_CONTEXT then + error("API disabled in the current context", 2) + end + + return rc == 1 + end + register_getter("is_subrequest", is_subreq) + + + -- ngx.headers_sent + + + local function headers_sent() + local r = get_request() + + if not r then + error("no request found") + end + + local rc = C.ngx_http_lua_ffi_headers_sent(r) + + if rc == FFI_NO_REQ_CTX then + error("no request ctx found") + end + + if rc == FFI_BAD_CONTEXT then + error("API disabled in the current context", 2) + end + + return rc == 1 + end + register_getter("headers_sent", headers_sent) + +elseif subsystem == "stream" then + ffi.cdef[[ + int ngx_stream_lua_ffi_get_resp_status(ngx_stream_lua_request_t *r); + int ngx_stream_lua_ffi_get_conf_env(const unsigned char *name, + unsigned char **env_buf, + size_t *name_len); + ]] + + ngx_lua_ffi_get_resp_status = C.ngx_stream_lua_ffi_get_resp_status + ngx_lua_ffi_get_conf_env = C.ngx_stream_lua_ffi_get_conf_env +end + + +-- ngx.status + + +local function get_status() + local r = get_request() + + if not r then + error("no request found") + end + + local rc = ngx_lua_ffi_get_resp_status(r) + + if rc == FFI_BAD_CONTEXT then + error("API disabled in the current context", 2) + end + + return rc +end +register_getter("status", get_status) + + +do + local _getenv = os.getenv + local env_ptr = ffi_new("unsigned char *[1]") + + os.getenv = function (name) + local r = get_request() + if r then + -- past init_by_lua* phase now + os.getenv = _getenv + env_ptr = nil + return os.getenv(name) + end + + local size = get_string_buf_size() + env_ptr[0] = get_string_buf(size) + local name_len_ptr = get_size_ptr() + + local rc = ngx_lua_ffi_get_conf_env(name, env_ptr, name_len_ptr) + if rc == FFI_OK then + return ffi_str(env_ptr[0] + name_len_ptr[0] + 1) + end + + -- FFI_DECLINED + + local value = _getenv(name) + if value ~= nil then + return value + end + + return nil + end +end + + +_M._VERSION = base.version + + +return _M diff --git a/resty/core/ndk.lua b/resty/core/ndk.lua new file mode 100644 index 0000000..6547fe5 --- /dev/null +++ b/resty/core/ndk.lua @@ -0,0 +1,92 @@ +-- Copyright (C) by OpenResty Inc. + + +local ffi = require 'ffi' +local base = require "resty.core.base" +base.allows_subsystem('http') + + +local C = ffi.C +local ffi_cast = ffi.cast +local ffi_new = ffi.new +local ffi_str = ffi.string +local FFI_OK = base.FFI_OK +local new_tab = base.new_tab +local get_string_buf = base.get_string_buf +local get_request = base.get_request +local setmetatable = setmetatable +local type = type +local tostring = tostring +local error = error + + +local _M = { + version = base.version +} + + +ffi.cdef[[ +typedef void * ndk_set_var_value_pt; + +int ngx_http_lua_ffi_ndk_lookup_directive(const unsigned char *var_data, + size_t var_len, ndk_set_var_value_pt *func); +int ngx_http_lua_ffi_ndk_set_var_get(ngx_http_request_t *r, + ndk_set_var_value_pt func, const unsigned char *arg_data, size_t arg_len, + ngx_http_lua_ffi_str_t *value); +]] + + +local func_p = ffi_new("void*[1]") +local ffi_str_size = ffi.sizeof("ngx_http_lua_ffi_str_t") +local ffi_str_type = ffi.typeof("ngx_http_lua_ffi_str_t*") + + +local function ndk_set_var_get(self, var) + if type(var) ~= "string" then + var = tostring(var) + end + + if C.ngx_http_lua_ffi_ndk_lookup_directive(var, #var, func_p) ~= FFI_OK then + error('ndk.set_var: directive "' .. var + .. '" not found or does not use ndk_set_var_value', 2) + end + + local func = func_p[0] + + return function (arg) + local r = get_request() + if not r then + error("no request found") + end + + if type(arg) ~= "string" then + arg = tostring(arg) + end + + local buf = get_string_buf(ffi_str_size) + local value = ffi_cast(ffi_str_type, buf) + local rc = C.ngx_http_lua_ffi_ndk_set_var_get(r, func, arg, #arg, value) + if rc ~= FFI_OK then + error("calling directive " .. var .. " failed with code " .. rc, 2) + end + + return ffi_str(value.data, value.len) + end +end + + +local function ndk_set_var_set() + error("not allowed", 2) +end + + +if ndk then + local mt = new_tab(0, 2) + mt.__newindex = ndk_set_var_set + mt.__index = ndk_set_var_get + + ndk.set_var = setmetatable(new_tab(0, 0), mt) +end + + +return _M diff --git a/resty/core/phase.lua b/resty/core/phase.lua new file mode 100644 index 0000000..331752a --- /dev/null +++ b/resty/core/phase.lua @@ -0,0 +1,59 @@ +local ffi = require 'ffi' +local base = require "resty.core.base" + +local C = ffi.C +local FFI_ERROR = base.FFI_ERROR +local get_request = base.get_request +local error = error +local tostring = tostring + + +ffi.cdef[[ +int ngx_http_lua_ffi_get_phase(ngx_http_request_t *r, char **err) +]] + + +local errmsg = base.get_errmsg_ptr() +local context_names = { + [0x0001] = "set", + [0x0002] = "rewrite", + [0x0004] = "access", + [0x0008] = "content", + [0x0010] = "log", + [0x0020] = "header_filter", + [0x0040] = "body_filter", + [0x0080] = "timer", + [0x0100] = "init_worker", + [0x0200] = "balancer", + [0x0400] = "ssl_cert", + [0x0800] = "ssl_session_store", + [0x1000] = "ssl_session_fetch", + [0x2000] = "exit_worker", +} + + +function ngx.get_phase() + local r = get_request() + + -- if we have no request object, assume we are called from the "init" phase + if not r then + return "init" + end + + local context = C.ngx_http_lua_ffi_get_phase(r, errmsg) + if context == FFI_ERROR then -- NGX_ERROR + error(errmsg, 2) + end + + local phase = context_names[context] + if not phase then + error("unknown phase: " .. tostring(context)) + end + + return phase +end + + +return { + version = base.version +} diff --git a/resty/core/regex.lua b/resty/core/regex.lua new file mode 100644 index 0000000..6a14416 --- /dev/null +++ b/resty/core/regex.lua @@ -0,0 +1,1213 @@ +-- Copyright (C) Yichun Zhang (agentzh) + + +local ffi = require 'ffi' +local base = require "resty.core.base" +local bit = require "bit" +local subsystem = ngx.config.subsystem +require "resty.core.time" -- for ngx.now used by resty.lrucache + +if subsystem == 'http' then + require "resty.core.phase" -- for ngx.get_phase +end + +local lrucache = require "resty.lrucache" + +local lrucache_get = lrucache.get +local lrucache_set = lrucache.set +local ffi_string = ffi.string +local ffi_gc = ffi.gc +local ffi_copy = ffi.copy +local ffi_cast = ffi.cast +local C = ffi.C +local bor = bit.bor +local band = bit.band +local lshift = bit.lshift +local sub = string.sub +local fmt = string.format +local byte = string.byte +local ngx = ngx +local type = type +local tostring = tostring +local error = error +local setmetatable = setmetatable +local tonumber = tonumber +local get_string_buf = base.get_string_buf +local get_string_buf_size = base.get_string_buf_size +local new_tab = base.new_tab +local ngx_phase = ngx.get_phase +local ngx_log = ngx.log +local ngx_NOTICE = ngx.NOTICE + + +local _M = { + version = base.version +} + + +ngx.re = new_tab(0, 5) + + +local pcre_ver_fn + +if subsystem == 'http' then + ffi.cdef[[ + const char *ngx_http_lua_ffi_pcre_version(void); + ]] + pcre_ver_fn = C.ngx_http_lua_ffi_pcre_version + +elseif subsystem == 'stream' then + ffi.cdef[[ + const char *ngx_stream_lua_ffi_pcre_version(void); + ]] + pcre_ver_fn = C.ngx_stream_lua_ffi_pcre_version + +else + error("unsupported subsystem: " .. tostring(subsystem)) +end + +local pcre_ver + +if not pcall(function() pcre_ver = ffi_string(pcre_ver_fn()) end) then + setmetatable(ngx.re, { + __index = function(_, key) + error("no support for 'ngx.re." .. key .. "': OpenResty was " .. + "compiled without PCRE support", 2) + end + }) + + _M.no_pcre = true + + return _M +end + + +local MAX_ERR_MSG_LEN = 128 + + +local FLAG_COMPILE_ONCE = 0x01 +local FLAG_DFA = 0x02 +local FLAG_JIT = 0x04 +local FLAG_DUPNAMES = 0x08 +local FLAG_NO_UTF8_CHECK = 0x10 + + +local PCRE_CASELESS = 0x0000001 +local PCRE_MULTILINE = 0x0000002 +local PCRE_DOTALL = 0x0000004 +local PCRE_EXTENDED = 0x0000008 +local PCRE_ANCHORED = 0x0000010 +local PCRE_UTF8 = 0x0000800 +local PCRE_DUPNAMES = 0x0080000 +local PCRE_JAVASCRIPT_COMPAT = 0x2000000 + + +local PCRE_ERROR_NOMATCH = -1 + + +local regex_match_cache +local regex_sub_func_cache = new_tab(0, 4) +local regex_sub_str_cache = new_tab(0, 4) +local max_regex_cache_size +local regex_cache_size = 0 +local script_engine +local ngx_lua_ffi_max_regex_cache_size +local ngx_lua_ffi_destroy_regex +local ngx_lua_ffi_compile_regex +local ngx_lua_ffi_exec_regex +local ngx_lua_ffi_create_script_engine +local ngx_lua_ffi_destroy_script_engine +local ngx_lua_ffi_init_script_engine +local ngx_lua_ffi_compile_replace_template +local ngx_lua_ffi_script_eval_len +local ngx_lua_ffi_script_eval_data + +-- PCRE 8.43 on macOS introduced the MAP_JIT option when creating the memory +-- region used to store JIT compiled code, which does not survive across +-- `fork()`, causing further usage of PCRE JIT compiler to segfault in worker +-- processes. +-- +-- This flag prevents any regex used in the init phase to be JIT compiled or +-- cached when running under macOS, even if the user requests so. Caching is +-- thus disabled to prevent further calls of same regex in worker to have poor +-- performance. +-- +-- TODO: improve this workaround when PCRE allows for unspecifying the MAP_JIT +-- option. +local no_jit_in_init + +if jit.os == "OSX" then + local maj, min = string.match(pcre_ver, "^(%d+)%.(%d+)") + if maj and min then + local pcre_ver_num = tonumber(maj .. min) + + if pcre_ver_num >= 843 then + no_jit_in_init = true + end + + else + -- assume this version is faulty as well + no_jit_in_init = true + end +end + + +if subsystem == 'http' then + ffi.cdef[[ + + typedef struct { + ngx_str_t value; + void *lengths; + void *values; + } ngx_http_lua_complex_value_t; + + typedef struct { + void *pool; + unsigned char *name_table; + int name_count; + int name_entry_size; + + int ncaptures; + int *captures; + + void *regex; + void *regex_sd; + + ngx_http_lua_complex_value_t *replace; + + const char *pattern; + } ngx_http_lua_regex_t; + + ngx_http_lua_regex_t * + ngx_http_lua_ffi_compile_regex(const unsigned char *pat, + size_t pat_len, int flags, + int pcre_opts, unsigned char *errstr, + size_t errstr_size); + + int ngx_http_lua_ffi_exec_regex(ngx_http_lua_regex_t *re, int flags, + const unsigned char *s, size_t len, int pos); + + void ngx_http_lua_ffi_destroy_regex(ngx_http_lua_regex_t *re); + + int ngx_http_lua_ffi_compile_replace_template(ngx_http_lua_regex_t *re, + const unsigned char + *replace_data, + size_t replace_len); + + struct ngx_http_lua_script_engine_s; + typedef struct ngx_http_lua_script_engine_s *ngx_http_lua_script_engine_t; + + ngx_http_lua_script_engine_t *ngx_http_lua_ffi_create_script_engine(void); + + void ngx_http_lua_ffi_init_script_engine(ngx_http_lua_script_engine_t *e, + const unsigned char *subj, + ngx_http_lua_regex_t *compiled, + int count); + + void ngx_http_lua_ffi_destroy_script_engine( + ngx_http_lua_script_engine_t *e); + + size_t ngx_http_lua_ffi_script_eval_len(ngx_http_lua_script_engine_t *e, + ngx_http_lua_complex_value_t *cv); + + size_t ngx_http_lua_ffi_script_eval_data(ngx_http_lua_script_engine_t *e, + ngx_http_lua_complex_value_t *cv, + unsigned char *dst); + + uint32_t ngx_http_lua_ffi_max_regex_cache_size(void); + ]] + + ngx_lua_ffi_max_regex_cache_size = C.ngx_http_lua_ffi_max_regex_cache_size + ngx_lua_ffi_destroy_regex = C.ngx_http_lua_ffi_destroy_regex + ngx_lua_ffi_compile_regex = C.ngx_http_lua_ffi_compile_regex + ngx_lua_ffi_exec_regex = C.ngx_http_lua_ffi_exec_regex + ngx_lua_ffi_create_script_engine = C.ngx_http_lua_ffi_create_script_engine + ngx_lua_ffi_init_script_engine = C.ngx_http_lua_ffi_init_script_engine + ngx_lua_ffi_destroy_script_engine = C.ngx_http_lua_ffi_destroy_script_engine + ngx_lua_ffi_compile_replace_template = + C.ngx_http_lua_ffi_compile_replace_template + ngx_lua_ffi_script_eval_len = C.ngx_http_lua_ffi_script_eval_len + ngx_lua_ffi_script_eval_data = C.ngx_http_lua_ffi_script_eval_data + +elseif subsystem == 'stream' then + ffi.cdef[[ + + typedef struct { + ngx_str_t value; + void *lengths; + void *values; + } ngx_stream_lua_complex_value_t; + + typedef struct { + void *pool; + unsigned char *name_table; + int name_count; + int name_entry_size; + + int ncaptures; + int *captures; + + void *regex; + void *regex_sd; + + ngx_stream_lua_complex_value_t *replace; + + const char *pattern; + } ngx_stream_lua_regex_t; + + ngx_stream_lua_regex_t * + ngx_stream_lua_ffi_compile_regex(const unsigned char *pat, + size_t pat_len, int flags, + int pcre_opts, unsigned char *errstr, + size_t errstr_size); + + int ngx_stream_lua_ffi_exec_regex(ngx_stream_lua_regex_t *re, int flags, + const unsigned char *s, size_t len, int pos); + + void ngx_stream_lua_ffi_destroy_regex(ngx_stream_lua_regex_t *re); + + int ngx_stream_lua_ffi_compile_replace_template(ngx_stream_lua_regex_t *re, + const unsigned char + *replace_data, + size_t replace_len); + + struct ngx_stream_lua_script_engine_s; + typedef struct ngx_stream_lua_script_engine_s + *ngx_stream_lua_script_engine_t; + + ngx_stream_lua_script_engine_t * + ngx_stream_lua_ffi_create_script_engine(void); + + void ngx_stream_lua_ffi_init_script_engine( + ngx_stream_lua_script_engine_t *e, const unsigned char *subj, + ngx_stream_lua_regex_t *compiled, int count); + + void ngx_stream_lua_ffi_destroy_script_engine( + ngx_stream_lua_script_engine_t *e); + + size_t ngx_stream_lua_ffi_script_eval_len( + ngx_stream_lua_script_engine_t *e, ngx_stream_lua_complex_value_t *cv); + + size_t ngx_stream_lua_ffi_script_eval_data( + ngx_stream_lua_script_engine_t *e, ngx_stream_lua_complex_value_t *cv, + unsigned char *dst); + + uint32_t ngx_stream_lua_ffi_max_regex_cache_size(void); + ]] + + ngx_lua_ffi_max_regex_cache_size = C.ngx_stream_lua_ffi_max_regex_cache_size + ngx_lua_ffi_destroy_regex = C.ngx_stream_lua_ffi_destroy_regex + ngx_lua_ffi_compile_regex = C.ngx_stream_lua_ffi_compile_regex + ngx_lua_ffi_exec_regex = C.ngx_stream_lua_ffi_exec_regex + ngx_lua_ffi_create_script_engine = C.ngx_stream_lua_ffi_create_script_engine + ngx_lua_ffi_init_script_engine = C.ngx_stream_lua_ffi_init_script_engine + ngx_lua_ffi_destroy_script_engine = + C.ngx_stream_lua_ffi_destroy_script_engine + ngx_lua_ffi_compile_replace_template = + C.ngx_stream_lua_ffi_compile_replace_template + ngx_lua_ffi_script_eval_len = C.ngx_stream_lua_ffi_script_eval_len + ngx_lua_ffi_script_eval_data = C.ngx_stream_lua_ffi_script_eval_data +end + + +local c_str_type = ffi.typeof("const char *") + +local cached_re_opts = new_tab(0, 4) + +local buf_grow_ratio = 2 + + +function _M.set_buf_grow_ratio(ratio) + buf_grow_ratio = ratio +end + + +local function get_max_regex_cache_size() + if max_regex_cache_size then + return max_regex_cache_size + end + max_regex_cache_size = ngx_lua_ffi_max_regex_cache_size() + return max_regex_cache_size +end + + +local regex_cache_is_empty = true + + +function _M.is_regex_cache_empty() + return regex_cache_is_empty +end + + +local function lrucache_set_wrapper(...) + regex_cache_is_empty = false + lrucache_set(...) +end + + +local parse_regex_opts = function (opts) + local t = cached_re_opts[opts] + if t then + return t[1], t[2] + end + + local flags = 0 + local pcre_opts = 0 + local len = #opts + + for i = 1, len do + local opt = byte(opts, i) + if opt == byte("o") then + flags = bor(flags, FLAG_COMPILE_ONCE) + + elseif opt == byte("j") then + flags = bor(flags, FLAG_JIT) + + elseif opt == byte("i") then + pcre_opts = bor(pcre_opts, PCRE_CASELESS) + + elseif opt == byte("s") then + pcre_opts = bor(pcre_opts, PCRE_DOTALL) + + elseif opt == byte("m") then + pcre_opts = bor(pcre_opts, PCRE_MULTILINE) + + elseif opt == byte("u") then + pcre_opts = bor(pcre_opts, PCRE_UTF8) + + elseif opt == byte("U") then + pcre_opts = bor(pcre_opts, PCRE_UTF8) + flags = bor(flags, FLAG_NO_UTF8_CHECK) + + elseif opt == byte("x") then + pcre_opts = bor(pcre_opts, PCRE_EXTENDED) + + elseif opt == byte("d") then + flags = bor(flags, FLAG_DFA) + + elseif opt == byte("a") then + pcre_opts = bor(pcre_opts, PCRE_ANCHORED) + + elseif opt == byte("D") then + pcre_opts = bor(pcre_opts, PCRE_DUPNAMES) + flags = bor(flags, FLAG_DUPNAMES) + + elseif opt == byte("J") then + pcre_opts = bor(pcre_opts, PCRE_JAVASCRIPT_COMPAT) + + else + error(fmt('unknown flag "%s" (flags "%s")', sub(opts, i, i), opts), + 3) + end + end + + cached_re_opts[opts] = {flags, pcre_opts} + return flags, pcre_opts +end + + +if no_jit_in_init then + local parse_regex_opts_ = parse_regex_opts + + parse_regex_opts = function (opts) + if ngx_phase() ~= "init" then + -- past init_by_lua* phase now + parse_regex_opts = parse_regex_opts_ + return parse_regex_opts(opts) + end + + local t = cached_re_opts[opts] + if t then + return t[1], t[2] + end + + local flags = 0 + local pcre_opts = 0 + local len = #opts + + for i = 1, len do + local opt = byte(opts, i) + if opt == byte("o") then + ngx_log(ngx_NOTICE, "regex compilation cache disabled in init ", + "phase under macOS") + + elseif opt == byte("j") then + ngx_log(ngx_NOTICE, "regex compilation disabled in init ", + "phase under macOS") + + elseif opt == byte("i") then + pcre_opts = bor(pcre_opts, PCRE_CASELESS) + + elseif opt == byte("s") then + pcre_opts = bor(pcre_opts, PCRE_DOTALL) + + elseif opt == byte("m") then + pcre_opts = bor(pcre_opts, PCRE_MULTILINE) + + elseif opt == byte("u") then + pcre_opts = bor(pcre_opts, PCRE_UTF8) + + elseif opt == byte("U") then + pcre_opts = bor(pcre_opts, PCRE_UTF8) + flags = bor(flags, FLAG_NO_UTF8_CHECK) + + elseif opt == byte("x") then + pcre_opts = bor(pcre_opts, PCRE_EXTENDED) + + elseif opt == byte("d") then + flags = bor(flags, FLAG_DFA) + + elseif opt == byte("a") then + pcre_opts = bor(pcre_opts, PCRE_ANCHORED) + + elseif opt == byte("D") then + pcre_opts = bor(pcre_opts, PCRE_DUPNAMES) + flags = bor(flags, FLAG_DUPNAMES) + + elseif opt == byte("J") then + pcre_opts = bor(pcre_opts, PCRE_JAVASCRIPT_COMPAT) + + else + error(fmt('unknown flag "%s" (flags "%s")', sub(opts, i, i), + opts), 3) + end + end + + cached_re_opts[opts] = {flags, pcre_opts} + return flags, pcre_opts + end +end + + +local function collect_named_captures(compiled, flags, res) + local name_count = compiled.name_count + local name_table = compiled.name_table + local entry_size = compiled.name_entry_size + + local ind = 0 + local dup_names = (band(flags, FLAG_DUPNAMES) ~= 0) + for i = 1, name_count do + local n = bor(lshift(name_table[ind], 8), name_table[ind + 1]) + -- ngx.say("n = ", n) + local name = ffi_string(name_table + ind + 2) + local cap = res[n] + if dup_names then + -- unmatched captures (false) are not collected + if cap then + local old = res[name] + if old then + old[#old + 1] = cap + else + res[name] = {cap} + end + end + else + res[name] = cap + end + + ind = ind + entry_size + end +end + + +local function collect_captures(compiled, rc, subj, flags, res) + local cap = compiled.captures + local ncap = compiled.ncaptures + local name_count = compiled.name_count + + if not res then + res = new_tab(ncap, name_count) + end + + local i = 0 + local n = 0 + while i <= ncap do + if i > rc then + res[i] = false + else + local from = cap[n] + if from >= 0 then + local to = cap[n + 1] + res[i] = sub(subj, from + 1, to) + else + res[i] = false + end + end + i = i + 1 + n = n + 2 + end + + if name_count > 0 then + collect_named_captures(compiled, flags, res) + end + + return res +end + + +_M.collect_captures = collect_captures + + +local function destroy_compiled_regex(compiled) + ngx_lua_ffi_destroy_regex(ffi_gc(compiled, nil)) +end + + +_M.destroy_compiled_regex = destroy_compiled_regex + + +local function re_match_compile(regex, opts) + local flags = 0 + local pcre_opts = 0 + + if opts then + flags, pcre_opts = parse_regex_opts(opts) + else + opts = "" + end + + local compiled, key + local compile_once = (band(flags, FLAG_COMPILE_ONCE) == 1) + + -- FIXME: better put this in the outer scope when fixing the ngx.re API's + -- compatibility in the init_by_lua* context. + if not regex_match_cache then + local sz = get_max_regex_cache_size() + if sz <= 0 then + compile_once = false + else + regex_match_cache = lrucache.new(sz) + end + end + + if compile_once then + key = regex .. '\0' .. opts + compiled = lrucache_get(regex_match_cache, key) + end + + -- compile the regex + + if compiled == nil then + -- print("compiled regex not found, compiling regex...") + local errbuf = get_string_buf(MAX_ERR_MSG_LEN) + + compiled = ngx_lua_ffi_compile_regex(regex, #regex, flags, + pcre_opts, errbuf, + MAX_ERR_MSG_LEN) + + if compiled == nil then + return nil, ffi_string(errbuf) + end + + ffi_gc(compiled, ngx_lua_ffi_destroy_regex) + + -- print("ncaptures: ", compiled.ncaptures) + + if compile_once then + -- print("inserting compiled regex into cache") + lrucache_set_wrapper(regex_match_cache, key, compiled) + end + end + + return compiled, compile_once, flags +end + + +_M.re_match_compile = re_match_compile + + +local function re_match_helper(subj, regex, opts, ctx, want_caps, res, nth) + -- we need to cast this to strings to avoid exceptions when they are + -- something else. + subj = tostring(subj) + + local compiled, compile_once, flags = re_match_compile(regex, opts) + if compiled == nil then + -- compiled_once holds the error string + if not want_caps then + return nil, nil, compile_once + end + return nil, compile_once + end + + -- exec the compiled regex + + local rc + do + local pos + if ctx then + pos = ctx.pos + if not pos or pos <= 0 then + pos = 0 + else + pos = pos - 1 + end + + else + pos = 0 + end + + rc = ngx_lua_ffi_exec_regex(compiled, flags, subj, #subj, pos) + end + + if rc == PCRE_ERROR_NOMATCH then + if not compile_once then + destroy_compiled_regex(compiled) + end + return nil + end + + if rc < 0 then + if not compile_once then + destroy_compiled_regex(compiled) + end + if not want_caps then + return nil, nil, "pcre_exec() failed: " .. rc + end + return nil, "pcre_exec() failed: " .. rc + end + + if rc == 0 then + if band(flags, FLAG_DFA) == 0 then + if not want_caps then + return nil, nil, "capture size too small" + end + return nil, "capture size too small" + end + + rc = 1 + end + + -- print("cap 0: ", compiled.captures[0]) + -- print("cap 1: ", compiled.captures[1]) + + if ctx then + ctx.pos = compiled.captures[1] + 1 + end + + if not want_caps then + if not nth or nth < 0 then + nth = 0 + end + + if nth > compiled.ncaptures then + return nil, nil, "nth out of bound" + end + + if nth >= rc then + return nil, nil + end + + local from = compiled.captures[nth * 2] + 1 + local to = compiled.captures[nth * 2 + 1] + + if from < 0 or to < 0 then + return nil, nil + end + + return from, to + end + + res = collect_captures(compiled, rc, subj, flags, res) + + if not compile_once then + destroy_compiled_regex(compiled) + end + + return res +end + + +function ngx.re.match(subj, regex, opts, ctx, res) + return re_match_helper(subj, regex, opts, ctx, true, res) +end + + +function ngx.re.find(subj, regex, opts, ctx, nth) + return re_match_helper(subj, regex, opts, ctx, false, nil, nth) +end + + +do + local function destroy_re_gmatch_iterator(iterator) + if not iterator._compile_once then + destroy_compiled_regex(iterator._compiled) + end + iterator._compiled = nil + iterator._pos = nil + iterator._subj = nil + end + + + local function iterate_re_gmatch(self) + local compiled = self._compiled + local subj = self._subj + local subj_len = self._subj_len + local flags = self._flags + local pos = self._pos + + if not pos then + -- The iterator is exhausted. + return nil + end + + local rc = ngx_lua_ffi_exec_regex(compiled, flags, subj, subj_len, pos) + + if rc == PCRE_ERROR_NOMATCH then + destroy_re_gmatch_iterator(self) + return nil + end + + if rc < 0 then + destroy_re_gmatch_iterator(self) + return nil, "pcre_exec() failed: " .. rc + end + + if rc == 0 then + if band(flags, FLAG_DFA) == 0 then + destroy_re_gmatch_iterator(self) + return nil, "capture size too small" + end + + rc = 1 + end + + local cp_pos = tonumber(compiled.captures[1]) + if cp_pos == compiled.captures[0] then + cp_pos = cp_pos + 1 + if cp_pos > subj_len then + local res = collect_captures(compiled, rc, subj, flags) + destroy_re_gmatch_iterator(self) + return res + end + end + self._pos = cp_pos + return collect_captures(compiled, rc, subj, flags) + end + + + local re_gmatch_iterator_mt = { __call = iterate_re_gmatch } + + function ngx.re.gmatch(subj, regex, opts) + subj = tostring(subj) + + local compiled, compile_once, flags = re_match_compile(regex, opts) + if compiled == nil then + -- compiled_once holds the error string + return nil, compile_once + end + + local re_gmatch_iterator = { + _compiled = compiled, + _compile_once = compile_once, + _subj = subj, + _subj_len = #subj, + _flags = flags, + _pos = 0, + } + + return setmetatable(re_gmatch_iterator, re_gmatch_iterator_mt) + end +end -- do + + +local function new_script_engine(subj, compiled, count) + if not script_engine then + script_engine = ngx_lua_ffi_create_script_engine() + if script_engine == nil then + return nil + end + ffi_gc(script_engine, ngx_lua_ffi_destroy_script_engine) + end + + ngx_lua_ffi_init_script_engine(script_engine, subj, compiled, count) + return script_engine +end + + +local function check_buf_size(buf, buf_size, pos, len, new_len, must_alloc) + if new_len > buf_size then + buf_size = buf_size * buf_grow_ratio + if buf_size < new_len then + buf_size = new_len + end + local new_buf = get_string_buf(buf_size, must_alloc) + ffi_copy(new_buf, buf, len) + buf = new_buf + pos = buf + len + end + return buf, buf_size, pos, new_len +end + + +_M.check_buf_size = check_buf_size + + +local function re_sub_compile(regex, opts, replace, func) + local flags = 0 + local pcre_opts = 0 + + if opts then + flags, pcre_opts = parse_regex_opts(opts) + else + opts = "" + end + + local compiled + local compile_once = (band(flags, FLAG_COMPILE_ONCE) == 1) + if compile_once then + if func then + local subcache = regex_sub_func_cache[opts] + if subcache then + -- print("cache hit!") + compiled = subcache[regex] + end + + else + local subcache = regex_sub_str_cache[opts] + if subcache then + local subsubcache = subcache[regex] + if subsubcache then + -- print("cache hit!") + compiled = subsubcache[replace] + end + end + end + end + + -- compile the regex + + if compiled == nil then + -- print("compiled regex not found, compiling regex...") + local errbuf = get_string_buf(MAX_ERR_MSG_LEN) + + compiled = ngx_lua_ffi_compile_regex(regex, #regex, flags, pcre_opts, + errbuf, MAX_ERR_MSG_LEN) + + if compiled == nil then + return nil, ffi_string(errbuf) + end + + ffi_gc(compiled, ngx_lua_ffi_destroy_regex) + + if func == nil then + local rc = + ngx_lua_ffi_compile_replace_template(compiled, replace, + #replace) + if rc ~= 0 then + if not compile_once then + destroy_compiled_regex(compiled) + end + return nil, "failed to compile the replacement template" + end + end + + -- print("ncaptures: ", compiled.ncaptures) + + if compile_once then + if regex_cache_size < get_max_regex_cache_size() then + -- print("inserting compiled regex into cache") + if func then + local subcache = regex_sub_func_cache[opts] + if not subcache then + regex_sub_func_cache[opts] = {[regex] = compiled} + + else + subcache[regex] = compiled + end + + else + local subcache = regex_sub_str_cache[opts] + if not subcache then + regex_sub_str_cache[opts] = + {[regex] = {[replace] = compiled}} + + else + local subsubcache = subcache[regex] + if not subsubcache then + subcache[regex] = {[replace] = compiled} + + else + subsubcache[replace] = compiled + end + end + end + + regex_cache_size = regex_cache_size + 1 + else + compile_once = false + end + end + end + + return compiled, compile_once, flags +end + + +_M.re_sub_compile = re_sub_compile + + +local function re_sub_func_helper(subj, regex, replace, opts, global) + local compiled, compile_once, flags = + re_sub_compile(regex, opts, nil, replace) + if not compiled then + -- error string is in compile_once + return nil, nil, compile_once + end + + -- exec the compiled regex + + subj = tostring(subj) + local subj_len = #subj + local count = 0 + local pos = 0 + local cp_pos = 0 + + local dst_buf_size = get_string_buf_size() + -- Note: we have to always allocate the string buffer because + -- the user might call whatever resty.core's API functions recursively + -- in the user callback function. + local dst_buf = get_string_buf(dst_buf_size, true) + local dst_pos = dst_buf + local dst_len = 0 + + while true do + local rc = ngx_lua_ffi_exec_regex(compiled, flags, subj, subj_len, pos) + if rc == PCRE_ERROR_NOMATCH then + break + end + + if rc < 0 then + if not compile_once then + destroy_compiled_regex(compiled) + end + return nil, nil, "pcre_exec() failed: " .. rc + end + + if rc == 0 then + if band(flags, FLAG_DFA) == 0 then + if not compile_once then + destroy_compiled_regex(compiled) + end + return nil, nil, "capture size too small" + end + + rc = 1 + end + + count = count + 1 + local prefix_len = compiled.captures[0] - cp_pos + + local res = collect_captures(compiled, rc, subj, flags) + + local piece = tostring(replace(res)) + local piece_len = #piece + + local new_dst_len = dst_len + prefix_len + piece_len + dst_buf, dst_buf_size, dst_pos, dst_len = + check_buf_size(dst_buf, dst_buf_size, dst_pos, dst_len, + new_dst_len, true) + + if prefix_len > 0 then + ffi_copy(dst_pos, ffi_cast(c_str_type, subj) + cp_pos, + prefix_len) + dst_pos = dst_pos + prefix_len + end + + if piece_len > 0 then + ffi_copy(dst_pos, piece, piece_len) + dst_pos = dst_pos + piece_len + end + + cp_pos = compiled.captures[1] + pos = cp_pos + if pos == compiled.captures[0] then + pos = pos + 1 + if pos > subj_len then + break + end + end + + if not global then + break + end + end + + if not compile_once then + destroy_compiled_regex(compiled) + end + + if count > 0 then + if pos < subj_len then + local suffix_len = subj_len - cp_pos + + local new_dst_len = dst_len + suffix_len + local _ + dst_buf, _, dst_pos, dst_len = + check_buf_size(dst_buf, dst_buf_size, dst_pos, dst_len, + new_dst_len, true) + + ffi_copy(dst_pos, ffi_cast(c_str_type, subj) + cp_pos, + suffix_len) + end + return ffi_string(dst_buf, dst_len), count + end + + return subj, 0 +end + + +local function re_sub_str_helper(subj, regex, replace, opts, global) + local compiled, compile_once, flags = + re_sub_compile(regex, opts, replace, nil) + if not compiled then + -- error string is in compile_once + return nil, nil, compile_once + end + + -- exec the compiled regex + + subj = tostring(subj) + local subj_len = #subj + local count = 0 + local pos = 0 + local cp_pos = 0 + + local dst_buf_size = get_string_buf_size() + local dst_buf = get_string_buf(dst_buf_size) + local dst_pos = dst_buf + local dst_len = 0 + + while true do + local rc = ngx_lua_ffi_exec_regex(compiled, flags, subj, subj_len, pos) + if rc == PCRE_ERROR_NOMATCH then + break + end + + if rc < 0 then + if not compile_once then + destroy_compiled_regex(compiled) + end + return nil, nil, "pcre_exec() failed: " .. rc + end + + if rc == 0 then + if band(flags, FLAG_DFA) == 0 then + if not compile_once then + destroy_compiled_regex(compiled) + end + return nil, nil, "capture size too small" + end + + rc = 1 + end + + count = count + 1 + local prefix_len = compiled.captures[0] - cp_pos + + local cv = compiled.replace + if cv.lengths ~= nil then + local e = new_script_engine(subj, compiled, rc) + if e == nil then + return nil, nil, "failed to create script engine" + end + + local bit_len = ngx_lua_ffi_script_eval_len(e, cv) + local new_dst_len = dst_len + prefix_len + bit_len + dst_buf, dst_buf_size, dst_pos, dst_len = + check_buf_size(dst_buf, dst_buf_size, dst_pos, dst_len, + new_dst_len) + + if prefix_len > 0 then + ffi_copy(dst_pos, ffi_cast(c_str_type, subj) + cp_pos, + prefix_len) + dst_pos = dst_pos + prefix_len + end + + if bit_len > 0 then + ngx_lua_ffi_script_eval_data(e, cv, dst_pos) + dst_pos = dst_pos + bit_len + end + + else + local bit_len = cv.value.len + + dst_buf, dst_buf_size, dst_pos, dst_len = + check_buf_size(dst_buf, dst_buf_size, dst_pos, dst_len, + dst_len + prefix_len + bit_len) + + if prefix_len > 0 then + ffi_copy(dst_pos, ffi_cast(c_str_type, subj) + cp_pos, + prefix_len) + dst_pos = dst_pos + prefix_len + end + + if bit_len > 0 then + ffi_copy(dst_pos, cv.value.data, bit_len) + dst_pos = dst_pos + bit_len + end + end + + cp_pos = compiled.captures[1] + pos = cp_pos + if pos == compiled.captures[0] then + pos = pos + 1 + if pos > subj_len then + break + end + end + + if not global then + break + end + end + + if not compile_once then + destroy_compiled_regex(compiled) + end + + if count > 0 then + if pos < subj_len then + local suffix_len = subj_len - cp_pos + + local new_dst_len = dst_len + suffix_len + local _ + dst_buf, _, dst_pos, dst_len = + check_buf_size(dst_buf, dst_buf_size, dst_pos, dst_len, + new_dst_len) + + ffi_copy(dst_pos, ffi_cast(c_str_type, subj) + cp_pos, + suffix_len) + end + return ffi_string(dst_buf, dst_len), count + end + + return subj, 0 +end + + +local function re_sub_helper(subj, regex, replace, opts, global) + local repl_type = type(replace) + if repl_type == "function" then + return re_sub_func_helper(subj, regex, replace, opts, global) + end + + if repl_type ~= "string" then + replace = tostring(replace) + end + + return re_sub_str_helper(subj, regex, replace, opts, global) +end + + +function ngx.re.sub(subj, regex, replace, opts) + return re_sub_helper(subj, regex, replace, opts, false) +end + + +function ngx.re.gsub(subj, regex, replace, opts) + return re_sub_helper(subj, regex, replace, opts, true) +end + + +return _M diff --git a/resty/core/request.lua b/resty/core/request.lua new file mode 100644 index 0000000..747b8c2 --- /dev/null +++ b/resty/core/request.lua @@ -0,0 +1,451 @@ +-- Copyright (C) Yichun Zhang (agentzh) + + +local ffi = require 'ffi' +local base = require "resty.core.base" +local utils = require "resty.core.utils" + + +local subsystem = ngx.config.subsystem +local FFI_BAD_CONTEXT = base.FFI_BAD_CONTEXT +local FFI_DECLINED = base.FFI_DECLINED +local FFI_OK = base.FFI_OK +local new_tab = base.new_tab +local C = ffi.C +local ffi_cast = ffi.cast +local ffi_new = ffi.new +local ffi_str = ffi.string +local get_string_buf = base.get_string_buf +local get_size_ptr = base.get_size_ptr +local setmetatable = setmetatable +local lower = string.lower +local rawget = rawget +local ngx = ngx +local get_request = base.get_request +local type = type +local error = error +local tostring = tostring +local tonumber = tonumber +local str_replace_char = utils.str_replace_char + + +local _M = { + version = base.version +} + + +local ngx_lua_ffi_req_start_time + + +if subsystem == "stream" then + ffi.cdef[[ + double ngx_stream_lua_ffi_req_start_time(ngx_stream_lua_request_t *r); + ]] + + ngx_lua_ffi_req_start_time = C.ngx_stream_lua_ffi_req_start_time + +elseif subsystem == "http" then + ffi.cdef[[ + double ngx_http_lua_ffi_req_start_time(ngx_http_request_t *r); + ]] + + ngx_lua_ffi_req_start_time = C.ngx_http_lua_ffi_req_start_time +end + + +function ngx.req.start_time() + local r = get_request() + if not r then + error("no request found") + end + + return tonumber(ngx_lua_ffi_req_start_time(r)) +end + + +if subsystem == "stream" then + return _M +end + + +local errmsg = base.get_errmsg_ptr() +local ffi_str_type = ffi.typeof("ngx_http_lua_ffi_str_t*") +local ffi_str_size = ffi.sizeof("ngx_http_lua_ffi_str_t") + + +ffi.cdef[[ + typedef struct { + ngx_http_lua_ffi_str_t key; + ngx_http_lua_ffi_str_t value; + } ngx_http_lua_ffi_table_elt_t; + + int ngx_http_lua_ffi_req_get_headers_count(ngx_http_request_t *r, + int max, int *truncated); + + int ngx_http_lua_ffi_req_get_headers(ngx_http_request_t *r, + ngx_http_lua_ffi_table_elt_t *out, int count, int raw); + + int ngx_http_lua_ffi_req_get_uri_args_count(ngx_http_request_t *r, + int max, int *truncated); + + size_t ngx_http_lua_ffi_req_get_querystring_len(ngx_http_request_t *r); + + int ngx_http_lua_ffi_req_get_uri_args(ngx_http_request_t *r, + unsigned char *buf, ngx_http_lua_ffi_table_elt_t *out, int count); + + int ngx_http_lua_ffi_req_get_method(ngx_http_request_t *r); + + int ngx_http_lua_ffi_req_get_method_name(ngx_http_request_t *r, + unsigned char **name, size_t *len); + + int ngx_http_lua_ffi_req_set_method(ngx_http_request_t *r, int method); + + int ngx_http_lua_ffi_req_set_header(ngx_http_request_t *r, + const unsigned char *key, size_t key_len, const unsigned char *value, + size_t value_len, ngx_http_lua_ffi_str_t *mvals, size_t mvals_len, + int override, char **errmsg); +]] + + +local table_elt_type = ffi.typeof("ngx_http_lua_ffi_table_elt_t*") +local table_elt_size = ffi.sizeof("ngx_http_lua_ffi_table_elt_t") +local truncated = ffi.new("int[1]") + +local req_headers_mt = { + __index = function (tb, key) + return rawget(tb, (str_replace_char(lower(key), '_', '-'))) + end +} + + +function ngx.req.get_headers(max_headers, raw) + local r = get_request() + if not r then + error("no request found") + end + + if not max_headers then + max_headers = -1 + end + + if not raw then + raw = 0 + else + raw = 1 + end + + local n = C.ngx_http_lua_ffi_req_get_headers_count(r, max_headers, + truncated) + if n == FFI_BAD_CONTEXT then + error("API disabled in the current context", 2) + end + + if n == 0 then + local headers = {} + if raw == 0 then + headers = setmetatable(headers, req_headers_mt) + end + + return headers + end + + local raw_buf = get_string_buf(n * table_elt_size) + local buf = ffi_cast(table_elt_type, raw_buf) + + local rc = C.ngx_http_lua_ffi_req_get_headers(r, buf, n, raw) + if rc == 0 then + local headers = new_tab(0, n) + for i = 0, n - 1 do + local h = buf[i] + + local key = h.key + key = ffi_str(key.data, key.len) + + local value = h.value + value = ffi_str(value.data, value.len) + + local existing = headers[key] + if existing then + if type(existing) == "table" then + existing[#existing + 1] = value + else + headers[key] = {existing, value} + end + + else + headers[key] = value + end + end + + if raw == 0 then + headers = setmetatable(headers, req_headers_mt) + end + + if truncated[0] ~= 0 then + return headers, "truncated" + end + + return headers + end + + return nil +end + + +function ngx.req.get_uri_args(max_args) + local r = get_request() + if not r then + error("no request found") + end + + if not max_args then + max_args = -1 + end + + local n = C.ngx_http_lua_ffi_req_get_uri_args_count(r, max_args, truncated) + if n == FFI_BAD_CONTEXT then + error("API disabled in the current context", 2) + end + + if n == 0 then + return {} + end + + local args_len = C.ngx_http_lua_ffi_req_get_querystring_len(r) + + local strbuf = get_string_buf(args_len + n * table_elt_size) + local kvbuf = ffi_cast(table_elt_type, strbuf + args_len) + + local nargs = C.ngx_http_lua_ffi_req_get_uri_args(r, strbuf, kvbuf, n) + + local args = new_tab(0, nargs) + for i = 0, nargs - 1 do + local arg = kvbuf[i] + + local key = arg.key + key = ffi_str(key.data, key.len) + + local value = arg.value + local len = value.len + if len == -1 then + value = true + else + value = ffi_str(value.data, len) + end + + local existing = args[key] + if existing then + if type(existing) == "table" then + existing[#existing + 1] = value + else + args[key] = {existing, value} + end + + else + args[key] = value + end + end + + if truncated[0] ~= 0 then + return args, "truncated" + end + + return args +end + + +do + local methods = { + [0x0002] = "GET", + [0x0004] = "HEAD", + [0x0008] = "POST", + [0x0010] = "PUT", + [0x0020] = "DELETE", + [0x0040] = "MKCOL", + [0x0080] = "COPY", + [0x0100] = "MOVE", + [0x0200] = "OPTIONS", + [0x0400] = "PROPFIND", + [0x0800] = "PROPPATCH", + [0x1000] = "LOCK", + [0x2000] = "UNLOCK", + [0x4000] = "PATCH", + [0x8000] = "TRACE", + } + + local namep = ffi_new("unsigned char *[1]") + + function ngx.req.get_method() + local r = get_request() + if not r then + error("no request found") + end + + do + local id = C.ngx_http_lua_ffi_req_get_method(r) + if id == FFI_BAD_CONTEXT then + error("API disabled in the current context", 2) + end + + local method = methods[id] + if method then + return method + end + end + + local sizep = get_size_ptr() + local rc = C.ngx_http_lua_ffi_req_get_method_name(r, namep, sizep) + if rc ~= 0 then + return nil + end + + return ffi_str(namep[0], sizep[0]) + end +end -- do + + +function ngx.req.set_method(method) + local r = get_request() + if not r then + error("no request found") + end + + if type(method) ~= "number" then + error("bad method number", 2) + end + + local rc = C.ngx_http_lua_ffi_req_set_method(r, method) + if rc == FFI_OK then + return + end + + if rc == FFI_BAD_CONTEXT then + error("API disabled in the current context", 2) + end + + if rc == FFI_DECLINED then + error("unsupported HTTP method: " .. method, 2) + end + + error("unknown error: " .. rc) +end + + +do + local function set_req_header(name, value, override) + local r = get_request() + if not r then + error("no request found", 3) + end + + if name == nil then + error("bad 'name' argument: string expected, got nil", 3) + end + + if type(name) ~= "string" then + name = tostring(name) + end + + local rc + + if value == nil then + if not override then + error("bad 'value' argument: string or table expected, got nil", + 3) + end + + rc = C.ngx_http_lua_ffi_req_set_header(r, name, #name, nil, 0, nil, + 0, 1, errmsg) + + else + local sval, sval_len, mvals, mvals_len, buf + local value_type = type(value) + + if value_type == "table" then + mvals_len = #value + if mvals_len == 0 and not override then + error("bad 'value' argument: non-empty table expected", 3) + end + + buf = get_string_buf(ffi_str_size * mvals_len) + mvals = ffi_cast(ffi_str_type, buf) + + for i = 1, mvals_len do + local s = value[i] + if type(s) ~= "string" then + s = tostring(s) + value[i] = s + end + + local str = mvals[i - 1] + str.data = s + str.len = #s + end + + sval_len = 0 + + else + if value_type ~= "string" then + sval = tostring(value) + else + sval = value + end + + sval_len = #sval + mvals_len = 0 + end + + rc = C.ngx_http_lua_ffi_req_set_header(r, name, #name, sval, + sval_len, mvals, mvals_len, + override and 1 or 0, errmsg) + end + + if rc == FFI_OK or rc == FFI_DECLINED then + return + end + + if rc == FFI_BAD_CONTEXT then + error("API disabled in the current context", 3) + end + + -- rc == FFI_ERROR + error(ffi_str(errmsg[0])) + end + + + _M.set_req_header = set_req_header + + + function ngx.req.set_header(name, value) + set_req_header(name, value, true) -- override + end +end -- do + + +function ngx.req.clear_header(name) + local r = get_request() + if not r then + error("no request found") + end + + if type(name) ~= "string" then + name = tostring(name) + end + + local rc = C.ngx_http_lua_ffi_req_set_header(r, name, #name, nil, 0, nil, 0, + 1, errmsg) + + if rc == FFI_OK or rc == FFI_DECLINED then + return + end + + if rc == FFI_BAD_CONTEXT then + error("API disabled in the current context", 2) + end + + -- rc == FFI_ERROR + error(ffi_str(errmsg[0])) +end + + +return _M diff --git a/resty/core/response.lua b/resty/core/response.lua new file mode 100644 index 0000000..891a07e --- /dev/null +++ b/resty/core/response.lua @@ -0,0 +1,183 @@ +-- Copyright (C) Yichun Zhang (agentzh) + + +local ffi = require 'ffi' +local base = require "resty.core.base" + + +local C = ffi.C +local ffi_cast = ffi.cast +local ffi_str = ffi.string +local new_tab = base.new_tab +local FFI_BAD_CONTEXT = base.FFI_BAD_CONTEXT +local FFI_NO_REQ_CTX = base.FFI_NO_REQ_CTX +local FFI_DECLINED = base.FFI_DECLINED +local get_string_buf = base.get_string_buf +local setmetatable = setmetatable +local type = type +local tostring = tostring +local get_request = base.get_request +local error = error +local ngx = ngx + + +local _M = { + version = base.version +} + + +local MAX_HEADER_VALUES = 100 +local errmsg = base.get_errmsg_ptr() +local ffi_str_type = ffi.typeof("ngx_http_lua_ffi_str_t*") +local ffi_str_size = ffi.sizeof("ngx_http_lua_ffi_str_t") + + +ffi.cdef[[ + int ngx_http_lua_ffi_set_resp_header(ngx_http_request_t *r, + const char *key_data, size_t key_len, int is_nil, + const char *sval, size_t sval_len, ngx_http_lua_ffi_str_t *mvals, + size_t mvals_len, int override, char **errmsg); + + int ngx_http_lua_ffi_get_resp_header(ngx_http_request_t *r, + const unsigned char *key, size_t key_len, + unsigned char *key_buf, ngx_http_lua_ffi_str_t *values, + int max_nvalues, char **errmsg); +]] + + +local function set_resp_header(tb, key, value, no_override) + local r = get_request() + if not r then + error("no request found") + end + + if type(key) ~= "string" then + key = tostring(key) + end + + local rc + if value == nil then + if no_override then + error("invalid header value", 3) + end + + rc = C.ngx_http_lua_ffi_set_resp_header(r, key, #key, true, nil, 0, nil, + 0, 1, errmsg) + else + local sval, sval_len, mvals, mvals_len, buf + + if type(value) == "table" then + mvals_len = #value + if mvals_len == 0 and no_override then + return + end + + buf = get_string_buf(ffi_str_size * mvals_len) + mvals = ffi_cast(ffi_str_type, buf) + for i = 1, mvals_len do + local s = value[i] + if type(s) ~= "string" then + s = tostring(s) + value[i] = s + end + local str = mvals[i - 1] + str.data = s + str.len = #s + end + + sval_len = 0 + + else + if type(value) ~= "string" then + sval = tostring(value) + else + sval = value + end + sval_len = #sval + + mvals_len = 0 + end + + local override_int = no_override and 0 or 1 + rc = C.ngx_http_lua_ffi_set_resp_header(r, key, #key, false, sval, + sval_len, mvals, mvals_len, + override_int, errmsg) + end + + if rc == 0 or rc == FFI_DECLINED then + return + end + + if rc == FFI_NO_REQ_CTX then + error("no request ctx found") + end + + if rc == FFI_BAD_CONTEXT then + error("API disabled in the current context", 2) + end + + -- rc == FFI_ERROR + error(ffi_str(errmsg[0]), 2) +end + + +_M.set_resp_header = set_resp_header + + +local function get_resp_header(tb, key) + local r = get_request() + if not r then + error("no request found") + end + + if type(key) ~= "string" then + key = tostring(key) + end + + local key_len = #key + + local key_buf = get_string_buf(key_len + ffi_str_size * MAX_HEADER_VALUES) + local values = ffi_cast(ffi_str_type, key_buf + key_len) + local n = C.ngx_http_lua_ffi_get_resp_header(r, key, key_len, key_buf, + values, MAX_HEADER_VALUES, + errmsg) + + -- print("retval: ", n) + + if n == FFI_BAD_CONTEXT then + error("API disabled in the current context", 2) + end + + if n == 0 then + return nil + end + + if n == 1 then + local v = values[0] + return ffi_str(v.data, v.len) + end + + if n > 0 then + local ret = new_tab(n, 0) + for i = 1, n do + local v = values[i - 1] + ret[i] = ffi_str(v.data, v.len) + end + return ret + end + + -- n == FFI_ERROR + error(ffi_str(errmsg[0]), 2) +end + + +do + local mt = new_tab(0, 2) + mt.__newindex = set_resp_header + mt.__index = get_resp_header + + ngx.header = setmetatable(new_tab(0, 0), mt) +end + + +return _M diff --git a/resty/core/shdict.lua b/resty/core/shdict.lua new file mode 100644 index 0000000..dedf12c --- /dev/null +++ b/resty/core/shdict.lua @@ -0,0 +1,638 @@ +-- Copyright (C) Yichun Zhang (agentzh) + + +local ffi = require 'ffi' +local base = require "resty.core.base" + + +local _M = { + version = base.version +} + +local ngx_shared = ngx.shared +if not ngx_shared then + return _M +end + + +local ffi_new = ffi.new +local ffi_str = ffi.string +local C = ffi.C +local get_string_buf = base.get_string_buf +local get_string_buf_size = base.get_string_buf_size +local get_size_ptr = base.get_size_ptr +local tonumber = tonumber +local tostring = tostring +local next = next +local type = type +local error = error +local getmetatable = getmetatable +local FFI_DECLINED = base.FFI_DECLINED +local subsystem = ngx.config.subsystem + + +local ngx_lua_ffi_shdict_get +local ngx_lua_ffi_shdict_incr +local ngx_lua_ffi_shdict_store +local ngx_lua_ffi_shdict_flush_all +local ngx_lua_ffi_shdict_get_ttl +local ngx_lua_ffi_shdict_set_expire +local ngx_lua_ffi_shdict_capacity +local ngx_lua_ffi_shdict_free_space +local ngx_lua_ffi_shdict_udata_to_zone + + +if subsystem == 'http' then + ffi.cdef[[ +int ngx_http_lua_ffi_shdict_get(void *zone, const unsigned char *key, + size_t key_len, int *value_type, unsigned char **str_value_buf, + size_t *str_value_len, double *num_value, int *user_flags, + int get_stale, int *is_stale, char **errmsg); + +int ngx_http_lua_ffi_shdict_incr(void *zone, const unsigned char *key, + size_t key_len, double *value, char **err, int has_init, + double init, long init_ttl, int *forcible); + +int ngx_http_lua_ffi_shdict_store(void *zone, int op, + const unsigned char *key, size_t key_len, int value_type, + const unsigned char *str_value_buf, size_t str_value_len, + double num_value, long exptime, int user_flags, char **errmsg, + int *forcible); + +int ngx_http_lua_ffi_shdict_flush_all(void *zone); + +long ngx_http_lua_ffi_shdict_get_ttl(void *zone, + const unsigned char *key, size_t key_len); + +int ngx_http_lua_ffi_shdict_set_expire(void *zone, + const unsigned char *key, size_t key_len, long exptime); + +size_t ngx_http_lua_ffi_shdict_capacity(void *zone); + +void *ngx_http_lua_ffi_shdict_udata_to_zone(void *zone_udata); + ]] + + ngx_lua_ffi_shdict_get = C.ngx_http_lua_ffi_shdict_get + ngx_lua_ffi_shdict_incr = C.ngx_http_lua_ffi_shdict_incr + ngx_lua_ffi_shdict_store = C.ngx_http_lua_ffi_shdict_store + ngx_lua_ffi_shdict_flush_all = C.ngx_http_lua_ffi_shdict_flush_all + ngx_lua_ffi_shdict_get_ttl = C.ngx_http_lua_ffi_shdict_get_ttl + ngx_lua_ffi_shdict_set_expire = C.ngx_http_lua_ffi_shdict_set_expire + ngx_lua_ffi_shdict_capacity = C.ngx_http_lua_ffi_shdict_capacity + ngx_lua_ffi_shdict_udata_to_zone = + C.ngx_http_lua_ffi_shdict_udata_to_zone + + if not pcall(function () + return C.ngx_http_lua_ffi_shdict_free_space + end) + then + ffi.cdef[[ +size_t ngx_http_lua_ffi_shdict_free_space(void *zone); + ]] + end + + pcall(function () + ngx_lua_ffi_shdict_free_space = C.ngx_http_lua_ffi_shdict_free_space + end) + +elseif subsystem == 'stream' then + + ffi.cdef[[ +int ngx_stream_lua_ffi_shdict_get(void *zone, const unsigned char *key, + size_t key_len, int *value_type, unsigned char **str_value_buf, + size_t *str_value_len, double *num_value, int *user_flags, + int get_stale, int *is_stale, char **errmsg); + +int ngx_stream_lua_ffi_shdict_incr(void *zone, const unsigned char *key, + size_t key_len, double *value, char **err, int has_init, + double init, long init_ttl, int *forcible); + +int ngx_stream_lua_ffi_shdict_store(void *zone, int op, + const unsigned char *key, size_t key_len, int value_type, + const unsigned char *str_value_buf, size_t str_value_len, + double num_value, long exptime, int user_flags, char **errmsg, + int *forcible); + +int ngx_stream_lua_ffi_shdict_flush_all(void *zone); + +long ngx_stream_lua_ffi_shdict_get_ttl(void *zone, + const unsigned char *key, size_t key_len); + +int ngx_stream_lua_ffi_shdict_set_expire(void *zone, + const unsigned char *key, size_t key_len, long exptime); + +size_t ngx_stream_lua_ffi_shdict_capacity(void *zone); + +void *ngx_stream_lua_ffi_shdict_udata_to_zone(void *zone_udata); + ]] + + ngx_lua_ffi_shdict_get = C.ngx_stream_lua_ffi_shdict_get + ngx_lua_ffi_shdict_incr = C.ngx_stream_lua_ffi_shdict_incr + ngx_lua_ffi_shdict_store = C.ngx_stream_lua_ffi_shdict_store + ngx_lua_ffi_shdict_flush_all = C.ngx_stream_lua_ffi_shdict_flush_all + ngx_lua_ffi_shdict_get_ttl = C.ngx_stream_lua_ffi_shdict_get_ttl + ngx_lua_ffi_shdict_set_expire = C.ngx_stream_lua_ffi_shdict_set_expire + ngx_lua_ffi_shdict_capacity = C.ngx_stream_lua_ffi_shdict_capacity + ngx_lua_ffi_shdict_udata_to_zone = + C.ngx_stream_lua_ffi_shdict_udata_to_zone + + if not pcall(function () + return C.ngx_stream_lua_ffi_shdict_free_space + end) + then + ffi.cdef[[ +size_t ngx_stream_lua_ffi_shdict_free_space(void *zone); + ]] + end + + -- ngx_stream_lua is only compatible with NGINX >= 1.13.6, meaning it + -- cannot lack support for ngx_stream_lua_ffi_shdict_free_space. + ngx_lua_ffi_shdict_free_space = C.ngx_stream_lua_ffi_shdict_free_space + +else + error("unknown subsystem: " .. subsystem) +end + +if not pcall(function () return C.free end) then + ffi.cdef[[ +void free(void *ptr); + ]] +end + + +local value_type = ffi_new("int[1]") +local user_flags = ffi_new("int[1]") +local num_value = ffi_new("double[1]") +local is_stale = ffi_new("int[1]") +local forcible = ffi_new("int[1]") +local str_value_buf = ffi_new("unsigned char *[1]") +local errmsg = base.get_errmsg_ptr() + + +local function check_zone(zone) + if not zone or type(zone) ~= "table" then + error("bad \"zone\" argument", 3) + end + + zone = zone[1] + if type(zone) ~= "userdata" then + error("bad \"zone\" argument", 3) + end + + zone = ngx_lua_ffi_shdict_udata_to_zone(zone) + if zone == nil then + error("bad \"zone\" argument", 3) + end + + return zone +end + + +local function shdict_store(zone, op, key, value, exptime, flags) + zone = check_zone(zone) + + if not exptime then + exptime = 0 + elseif exptime < 0 then + error('bad "exptime" argument', 2) + end + + if not flags then + flags = 0 + end + + if key == nil then + return nil, "nil key" + end + + if type(key) ~= "string" then + key = tostring(key) + end + + local key_len = #key + if key_len == 0 then + return nil, "empty key" + end + if key_len > 65535 then + return nil, "key too long" + end + + local str_val_buf + local str_val_len = 0 + local num_val = 0 + local valtyp = type(value) + + -- print("value type: ", valtyp) + -- print("exptime: ", exptime) + + if valtyp == "string" then + valtyp = 4 -- LUA_TSTRING + str_val_buf = value + str_val_len = #value + + elseif valtyp == "number" then + valtyp = 3 -- LUA_TNUMBER + num_val = value + + elseif value == nil then + valtyp = 0 -- LUA_TNIL + + elseif valtyp == "boolean" then + valtyp = 1 -- LUA_TBOOLEAN + num_val = value and 1 or 0 + + else + return nil, "bad value type" + end + + local rc = ngx_lua_ffi_shdict_store(zone, op, key, key_len, + valtyp, str_val_buf, + str_val_len, num_val, + exptime * 1000, flags, errmsg, + forcible) + + -- print("rc == ", rc) + + if rc == 0 then -- NGX_OK + return true, nil, forcible[0] == 1 + end + + -- NGX_DECLINED or NGX_ERROR + return false, ffi_str(errmsg[0]), forcible[0] == 1 +end + + +local function shdict_set(zone, key, value, exptime, flags) + return shdict_store(zone, 0, key, value, exptime, flags) +end + + +local function shdict_safe_set(zone, key, value, exptime, flags) + return shdict_store(zone, 0x0004, key, value, exptime, flags) +end + + +local function shdict_add(zone, key, value, exptime, flags) + return shdict_store(zone, 0x0001, key, value, exptime, flags) +end + + +local function shdict_safe_add(zone, key, value, exptime, flags) + return shdict_store(zone, 0x0005, key, value, exptime, flags) +end + + +local function shdict_replace(zone, key, value, exptime, flags) + return shdict_store(zone, 0x0002, key, value, exptime, flags) +end + + +local function shdict_delete(zone, key) + return shdict_set(zone, key, nil) +end + + +local function shdict_get(zone, key) + zone = check_zone(zone) + + if key == nil then + return nil, "nil key" + end + + if type(key) ~= "string" then + key = tostring(key) + end + + local key_len = #key + if key_len == 0 then + return nil, "empty key" + end + if key_len > 65535 then + return nil, "key too long" + end + + local size = get_string_buf_size() + local buf = get_string_buf(size) + str_value_buf[0] = buf + local value_len = get_size_ptr() + value_len[0] = size + + local rc = ngx_lua_ffi_shdict_get(zone, key, key_len, value_type, + str_value_buf, value_len, + num_value, user_flags, 0, + is_stale, errmsg) + if rc ~= 0 then + if errmsg[0] ~= nil then + return nil, ffi_str(errmsg[0]) + end + + error("failed to get the key") + end + + local typ = value_type[0] + + if typ == 0 then -- LUA_TNIL + return nil + end + + local flags = tonumber(user_flags[0]) + + local val + + if typ == 4 then -- LUA_TSTRING + if str_value_buf[0] ~= buf then + -- ngx.say("len: ", tonumber(value_len[0])) + buf = str_value_buf[0] + val = ffi_str(buf, value_len[0]) + C.free(buf) + else + val = ffi_str(buf, value_len[0]) + end + + elseif typ == 3 then -- LUA_TNUMBER + val = tonumber(num_value[0]) + + elseif typ == 1 then -- LUA_TBOOLEAN + val = (tonumber(buf[0]) ~= 0) + + else + error("unknown value type: " .. typ) + end + + if flags ~= 0 then + return val, flags + end + + return val +end + + +local function shdict_get_stale(zone, key) + zone = check_zone(zone) + + if key == nil then + return nil, "nil key" + end + + if type(key) ~= "string" then + key = tostring(key) + end + + local key_len = #key + if key_len == 0 then + return nil, "empty key" + end + if key_len > 65535 then + return nil, "key too long" + end + + local size = get_string_buf_size() + local buf = get_string_buf(size) + str_value_buf[0] = buf + local value_len = get_size_ptr() + value_len[0] = size + + local rc = ngx_lua_ffi_shdict_get(zone, key, key_len, value_type, + str_value_buf, value_len, + num_value, user_flags, 1, + is_stale, errmsg) + if rc ~= 0 then + if errmsg[0] ~= nil then + return nil, ffi_str(errmsg[0]) + end + + error("failed to get the key") + end + + local typ = value_type[0] + + if typ == 0 then -- LUA_TNIL + return nil + end + + local flags = tonumber(user_flags[0]) + local val + + if typ == 4 then -- LUA_TSTRING + if str_value_buf[0] ~= buf then + -- ngx.say("len: ", tonumber(value_len[0])) + buf = str_value_buf[0] + val = ffi_str(buf, value_len[0]) + C.free(buf) + else + val = ffi_str(buf, value_len[0]) + end + + elseif typ == 3 then -- LUA_TNUMBER + val = tonumber(num_value[0]) + + elseif typ == 1 then -- LUA_TBOOLEAN + val = (tonumber(buf[0]) ~= 0) + + else + error("unknown value type: " .. typ) + end + + if flags ~= 0 then + return val, flags, is_stale[0] == 1 + end + + return val, nil, is_stale[0] == 1 +end + + +local function shdict_incr(zone, key, value, init, init_ttl) + zone = check_zone(zone) + + if key == nil then + return nil, "nil key" + end + + if type(key) ~= "string" then + key = tostring(key) + end + + local key_len = #key + if key_len == 0 then + return nil, "empty key" + end + if key_len > 65535 then + return nil, "key too long" + end + + if type(value) ~= "number" then + value = tonumber(value) + end + num_value[0] = value + + if init then + local typ = type(init) + if typ ~= "number" then + init = tonumber(init) + + if not init then + error("bad init arg: number expected, got " .. typ, 2) + end + end + end + + if init_ttl ~= nil then + local typ = type(init_ttl) + if typ ~= "number" then + init_ttl = tonumber(init_ttl) + + if not init_ttl then + error("bad init_ttl arg: number expected, got " .. typ, 2) + end + end + + if init_ttl < 0 then + error('bad "init_ttl" argument', 2) + end + + if not init then + error('must provide "init" when providing "init_ttl"', 2) + end + + else + init_ttl = 0 + end + + local rc = ngx_lua_ffi_shdict_incr(zone, key, key_len, num_value, + errmsg, init and 1 or 0, + init or 0, init_ttl * 1000, + forcible) + if rc ~= 0 then -- ~= NGX_OK + return nil, ffi_str(errmsg[0]) + end + + if not init then + return tonumber(num_value[0]) + end + + return tonumber(num_value[0]), nil, forcible[0] == 1 +end + + +local function shdict_flush_all(zone) + zone = check_zone(zone) + + ngx_lua_ffi_shdict_flush_all(zone) +end + + +local function shdict_ttl(zone, key) + zone = check_zone(zone) + + if key == nil then + return nil, "nil key" + end + + if type(key) ~= "string" then + key = tostring(key) + end + + local key_len = #key + if key_len == 0 then + return nil, "empty key" + end + + if key_len > 65535 then + return nil, "key too long" + end + + local rc = ngx_lua_ffi_shdict_get_ttl(zone, key, key_len) + + if rc == FFI_DECLINED then + return nil, "not found" + end + + return tonumber(rc) / 1000 +end + + +local function shdict_expire(zone, key, exptime) + zone = check_zone(zone) + + if not exptime then + error('bad "exptime" argument', 2) + end + + if key == nil then + return nil, "nil key" + end + + if type(key) ~= "string" then + key = tostring(key) + end + + local key_len = #key + if key_len == 0 then + return nil, "empty key" + end + + if key_len > 65535 then + return nil, "key too long" + end + + local rc = ngx_lua_ffi_shdict_set_expire(zone, key, key_len, + exptime * 1000) + + if rc == FFI_DECLINED then + return nil, "not found" + end + + -- NGINX_OK/FFI_OK + + return true +end + + +local function shdict_capacity(zone) + zone = check_zone(zone) + + return tonumber(ngx_lua_ffi_shdict_capacity(zone)) +end + + +local shdict_free_space +if ngx_lua_ffi_shdict_free_space then + shdict_free_space = function (zone) + zone = check_zone(zone) + + return tonumber(ngx_lua_ffi_shdict_free_space(zone)) + end + +else + shdict_free_space = function () + error("'shm:free_space()' not supported in NGINX < 1.11.7", 2) + end +end + + +local _, dict = next(ngx_shared, nil) +if dict then + local mt = getmetatable(dict) + if mt then + mt = mt.__index + if mt then + mt.get = shdict_get + mt.get_stale = shdict_get_stale + mt.incr = shdict_incr + mt.set = shdict_set + mt.safe_set = shdict_safe_set + mt.add = shdict_add + mt.safe_add = shdict_safe_add + mt.replace = shdict_replace + mt.delete = shdict_delete + mt.flush_all = shdict_flush_all + mt.ttl = shdict_ttl + mt.expire = shdict_expire + mt.capacity = shdict_capacity + mt.free_space = shdict_free_space + end + end +end + + +return _M diff --git a/resty/core/socket.lua b/resty/core/socket.lua new file mode 100644 index 0000000..1a504ec --- /dev/null +++ b/resty/core/socket.lua @@ -0,0 +1,124 @@ +local base = require "resty.core.base" +base.allows_subsystem('http') +local debug = require 'debug' +local ffi = require 'ffi' + + +local error = error +local tonumber = tonumber +local registry = debug.getregistry() +local ffi_new = ffi.new +local ffi_string = ffi.string +local C = ffi.C +local get_string_buf = base.get_string_buf +local get_size_ptr = base.get_size_ptr +local tostring = tostring + + +local option_index = { + ["keepalive"] = 1, + ["reuseaddr"] = 2, + ["tcp-nodelay"] = 3, + ["sndbuf"] = 4, + ["rcvbuf"] = 5, +} + + +ffi.cdef[[ +typedef struct ngx_http_lua_socket_tcp_upstream_s + ngx_http_lua_socket_tcp_upstream_t; + +int +ngx_http_lua_ffi_socket_tcp_getoption(ngx_http_lua_socket_tcp_upstream_t *u, + int opt, int *val, unsigned char *err, size_t *errlen); + +int +ngx_http_lua_ffi_socket_tcp_setoption(ngx_http_lua_socket_tcp_upstream_t *u, + int opt, int val, unsigned char *err, size_t *errlen); +]] + + +local output_value_buf = ffi_new("int[1]") +local FFI_OK = base.FFI_OK +local SOCKET_CTX_INDEX = 1 +local ERR_BUF_SIZE = 4096 + + +local function get_tcp_socket(cosocket) + local tcp_socket = cosocket[SOCKET_CTX_INDEX] + if not tcp_socket then + error("socket is never created nor connected") + end + + return tcp_socket +end + + +local function getoption(cosocket, option) + local tcp_socket = get_tcp_socket(cosocket) + + if option == nil then + return nil, 'missing the "option" argument' + end + + if option_index[option] == nil then + return nil, "unsupported option " .. tostring(option) + end + + local err = get_string_buf(ERR_BUF_SIZE) + local errlen = get_size_ptr() + errlen[0] = ERR_BUF_SIZE + + local rc = C.ngx_http_lua_ffi_socket_tcp_getoption(tcp_socket, + option_index[option], + output_value_buf, + err, + errlen) + if rc ~= FFI_OK then + return nil, ffi_string(err, errlen[0]) + end + + return tonumber(output_value_buf[0]) +end + + +local function setoption(cosocket, option, value) + local tcp_socket = get_tcp_socket(cosocket) + + if option == nil then + return nil, 'missing the "option" argument' + end + + if value == nil then + return nil, 'missing the "value" argument' + end + + if option_index[option] == nil then + return nil, "unsupported option " .. tostring(option) + end + + local err = get_string_buf(ERR_BUF_SIZE) + local errlen = get_size_ptr() + errlen[0] = ERR_BUF_SIZE + + local rc = C.ngx_http_lua_ffi_socket_tcp_setoption(tcp_socket, + option_index[option], + value, + err, + errlen) + if rc ~= FFI_OK then + return nil, ffi_string(err, errlen[0]) + end + + return true +end + + +do + local method_table = registry.__tcp_cosocket_mt + method_table.getoption = getoption + method_table.setoption = setoption +end + + +return { version = base.version } diff --git a/resty/core/time.lua b/resty/core/time.lua new file mode 100644 index 0000000..10ae72e --- /dev/null +++ b/resty/core/time.lua @@ -0,0 +1,159 @@ +-- Copyright (C) Yichun Zhang (agentzh) + + +local ffi = require 'ffi' +local base = require "resty.core.base" + + +local error = error +local tonumber = tonumber +local type = type +local C = ffi.C +local ffi_new = ffi.new +local ffi_str = ffi.string +local time_val = ffi_new("long[1]") +local get_string_buf = base.get_string_buf +local ngx = ngx +local FFI_ERROR = base.FFI_ERROR +local subsystem = ngx.config.subsystem + + +local ngx_lua_ffi_now +local ngx_lua_ffi_time +local ngx_lua_ffi_today +local ngx_lua_ffi_localtime +local ngx_lua_ffi_utctime +local ngx_lua_ffi_update_time + + +if subsystem == 'http' then + ffi.cdef[[ +double ngx_http_lua_ffi_now(void); +long ngx_http_lua_ffi_time(void); +void ngx_http_lua_ffi_today(unsigned char *buf); +void ngx_http_lua_ffi_localtime(unsigned char *buf); +void ngx_http_lua_ffi_utctime(unsigned char *buf); +void ngx_http_lua_ffi_update_time(void); +int ngx_http_lua_ffi_cookie_time(unsigned char *buf, long t); +void ngx_http_lua_ffi_http_time(unsigned char *buf, long t); +void ngx_http_lua_ffi_parse_http_time(const unsigned char *str, size_t len, + long *time); + ]] + + ngx_lua_ffi_now = C.ngx_http_lua_ffi_now + ngx_lua_ffi_time = C.ngx_http_lua_ffi_time + ngx_lua_ffi_today = C.ngx_http_lua_ffi_today + ngx_lua_ffi_localtime = C.ngx_http_lua_ffi_localtime + ngx_lua_ffi_utctime = C.ngx_http_lua_ffi_utctime + ngx_lua_ffi_update_time = C.ngx_http_lua_ffi_update_time + +elseif subsystem == 'stream' then + ffi.cdef[[ +double ngx_stream_lua_ffi_now(void); +long ngx_stream_lua_ffi_time(void); +void ngx_stream_lua_ffi_today(unsigned char *buf); +void ngx_stream_lua_ffi_localtime(unsigned char *buf); +void ngx_stream_lua_ffi_utctime(unsigned char *buf); +void ngx_stream_lua_ffi_update_time(void); + ]] + + ngx_lua_ffi_now = C.ngx_stream_lua_ffi_now + ngx_lua_ffi_time = C.ngx_stream_lua_ffi_time + ngx_lua_ffi_today = C.ngx_stream_lua_ffi_today + ngx_lua_ffi_localtime = C.ngx_stream_lua_ffi_localtime + ngx_lua_ffi_utctime = C.ngx_stream_lua_ffi_utctime + ngx_lua_ffi_update_time = C.ngx_stream_lua_ffi_update_time +end + + +function ngx.now() + return tonumber(ngx_lua_ffi_now()) +end + + +function ngx.time() + return tonumber(ngx_lua_ffi_time()) +end + + +function ngx.update_time() + ngx_lua_ffi_update_time() +end + + +function ngx.today() + -- the format of today is 2010-11-19 + local today_buf_size = 10 + local buf = get_string_buf(today_buf_size) + ngx_lua_ffi_today(buf) + return ffi_str(buf, today_buf_size) +end + + +function ngx.localtime() + -- the format of localtime is 2010-11-19 20:56:31 + local localtime_buf_size = 19 + local buf = get_string_buf(localtime_buf_size) + ngx_lua_ffi_localtime(buf) + return ffi_str(buf, localtime_buf_size) +end + + +function ngx.utctime() + -- the format of utctime is 2010-11-19 20:56:31 + local utctime_buf_size = 19 + local buf = get_string_buf(utctime_buf_size) + ngx_lua_ffi_utctime(buf) + return ffi_str(buf, utctime_buf_size) +end + + +if subsystem == 'http' then + +function ngx.cookie_time(sec) + if type(sec) ~= "number" then + error("number argument only", 2) + end + + -- the format of cookie time is Mon, 28-Sep-2038 06:00:00 GMT + -- or Mon, 28-Sep-18 06:00:00 GMT + local cookie_time_buf_size = 29 + local buf = get_string_buf(cookie_time_buf_size) + local used_size = C.ngx_http_lua_ffi_cookie_time(buf, sec) + return ffi_str(buf, used_size) +end + + +function ngx.http_time(sec) + if type(sec) ~= "number" then + error("number argument only", 2) + end + + -- the format of http time is Mon, 28 Sep 1970 06:00:00 GMT + local http_time_buf_size = 29 + local buf = get_string_buf(http_time_buf_size) + C.ngx_http_lua_ffi_http_time(buf, sec) + return ffi_str(buf, http_time_buf_size) +end + + +function ngx.parse_http_time(time_str) + if type(time_str) ~= "string" then + error("string argument only", 2) + end + + C.ngx_http_lua_ffi_parse_http_time(time_str, #time_str, time_val) + + local res = time_val[0] + if res == FFI_ERROR then + return nil + end + + return tonumber(res) +end + +end + +return { + version = base.version +} diff --git a/resty/core/uri.lua b/resty/core/uri.lua new file mode 100644 index 0000000..96b1ab4 --- /dev/null +++ b/resty/core/uri.lua @@ -0,0 +1,115 @@ +-- Copyright (C) Yichun Zhang (agentzh) + + +local ffi = require "ffi" +local base = require "resty.core.base" + + +local C = ffi.C +local ffi_string = ffi.string +local ngx = ngx +local type = type +local error = error +local tostring = tostring +local get_string_buf = base.get_string_buf +local subsystem = ngx.config.subsystem + + +local ngx_lua_ffi_escape_uri +local ngx_lua_ffi_unescape_uri +local ngx_lua_ffi_uri_escaped_length + +local NGX_ESCAPE_URI = 0 +local NGX_ESCAPE_URI_COMPONENT = 2 +local NGX_ESCAPE_MAIL_AUTH = 6 + + +if subsystem == "http" then + ffi.cdef[[ + size_t ngx_http_lua_ffi_uri_escaped_length(const unsigned char *src, + size_t len, int type); + + void ngx_http_lua_ffi_escape_uri(const unsigned char *src, size_t len, + unsigned char *dst, int type); + + size_t ngx_http_lua_ffi_unescape_uri(const unsigned char *src, + size_t len, unsigned char *dst); + ]] + + ngx_lua_ffi_escape_uri = C.ngx_http_lua_ffi_escape_uri + ngx_lua_ffi_unescape_uri = C.ngx_http_lua_ffi_unescape_uri + ngx_lua_ffi_uri_escaped_length = C.ngx_http_lua_ffi_uri_escaped_length + +elseif subsystem == "stream" then + ffi.cdef[[ + size_t ngx_stream_lua_ffi_uri_escaped_length(const unsigned char *src, + size_t len, int type); + + void ngx_stream_lua_ffi_escape_uri(const unsigned char *src, size_t len, + unsigned char *dst, int type); + + size_t ngx_stream_lua_ffi_unescape_uri(const unsigned char *src, + size_t len, unsigned char *dst); + ]] + + ngx_lua_ffi_escape_uri = C.ngx_stream_lua_ffi_escape_uri + ngx_lua_ffi_unescape_uri = C.ngx_stream_lua_ffi_unescape_uri + ngx_lua_ffi_uri_escaped_length = C.ngx_stream_lua_ffi_uri_escaped_length +end + + +ngx.escape_uri = function (s, esc_type) + if type(s) ~= 'string' then + if not s then + s = '' + + else + s = tostring(s) + end + end + + if esc_type == nil then + esc_type = NGX_ESCAPE_URI_COMPONENT + + else + if type(esc_type) ~= 'number' then + error("\"type\" is not a number", 3) + end + + if esc_type < NGX_ESCAPE_URI or esc_type > NGX_ESCAPE_MAIL_AUTH then + error("\"type\" " .. esc_type .. " out of range", 3) + end + end + + local slen = #s + local dlen = ngx_lua_ffi_uri_escaped_length(s, slen, esc_type) + + -- print("dlen: ", tonumber(dlen)) + if dlen == slen then + return s + end + local dst = get_string_buf(dlen) + ngx_lua_ffi_escape_uri(s, slen, dst, esc_type) + return ffi_string(dst, dlen) +end + + +ngx.unescape_uri = function (s) + if type(s) ~= 'string' then + if not s then + s = '' + else + s = tostring(s) + end + end + local slen = #s + local dlen = slen + local dst = get_string_buf(dlen) + dlen = ngx_lua_ffi_unescape_uri(s, slen, dst) + return ffi_string(dst, dlen) +end + + +return { + version = base.version, +} diff --git a/resty/core/utils.lua b/resty/core/utils.lua new file mode 100644 index 0000000..fda074a --- /dev/null +++ b/resty/core/utils.lua @@ -0,0 +1,46 @@ +-- Copyright (C) Yichun Zhang (agentzh) + + +local ffi = require "ffi" +local base = require "resty.core.base" + + +local C = ffi.C +local ffi_str = ffi.string +local ffi_copy = ffi.copy +local byte = string.byte +local str_find = string.find +local get_string_buf = base.get_string_buf +local subsystem = ngx.config.subsystem + + +local _M = { + version = base.version +} + + +if subsystem == "http" then + ffi.cdef[[ + void ngx_http_lua_ffi_str_replace_char(unsigned char *buf, size_t len, + const unsigned char find, const unsigned char replace); + ]] + + + function _M.str_replace_char(str, find, replace) + if not str_find(str, find, nil, true) then + return str + end + + local len = #str + local buf = get_string_buf(len) + ffi_copy(buf, str, len) + + C.ngx_http_lua_ffi_str_replace_char(buf, len, byte(find), + byte(replace)) + + return ffi_str(buf, len) + end +end + + +return _M diff --git a/resty/core/var.lua b/resty/core/var.lua new file mode 100644 index 0000000..ea9c763 --- /dev/null +++ b/resty/core/var.lua @@ -0,0 +1,160 @@ +-- Copyright (C) Yichun Zhang (agentzh) + + +local ffi = require "ffi" +local base = require "resty.core.base" + + +local C = ffi.C +local ffi_new = ffi.new +local ffi_str = ffi.string +local type = type +local error = error +local tostring = tostring +local setmetatable = setmetatable +local get_request = base.get_request +local get_string_buf = base.get_string_buf +local get_size_ptr = base.get_size_ptr +local new_tab = base.new_tab +local subsystem = ngx.config.subsystem + + +local ngx_lua_ffi_var_get +local ngx_lua_ffi_var_set + + +local ERR_BUF_SIZE = 256 + + +ngx.var = new_tab(0, 0) + + +if subsystem == "http" then + ffi.cdef[[ + int ngx_http_lua_ffi_var_get(ngx_http_request_t *r, + const char *name_data, size_t name_len, char *lowcase_buf, + int capture_id, char **value, size_t *value_len, char **err); + + int ngx_http_lua_ffi_var_set(ngx_http_request_t *r, + const unsigned char *name_data, size_t name_len, + unsigned char *lowcase_buf, const unsigned char *value, + size_t value_len, unsigned char *errbuf, size_t *errlen); + ]] + + ngx_lua_ffi_var_get = C.ngx_http_lua_ffi_var_get + ngx_lua_ffi_var_set = C.ngx_http_lua_ffi_var_set + +elseif subsystem == "stream" then + ffi.cdef[[ + int ngx_stream_lua_ffi_var_get(ngx_stream_lua_request_t *r, + const char *name_data, size_t name_len, char *lowcase_buf, + int capture_id, char **value, size_t *value_len, char **err); + + int ngx_stream_lua_ffi_var_set(ngx_stream_lua_request_t *r, + const unsigned char *name_data, size_t name_len, + unsigned char *lowcase_buf, const unsigned char *value, + size_t value_len, unsigned char *errbuf, size_t *errlen); + ]] + + ngx_lua_ffi_var_get = C.ngx_stream_lua_ffi_var_get + ngx_lua_ffi_var_set = C.ngx_stream_lua_ffi_var_set +end + + +local value_ptr = ffi_new("unsigned char *[1]") +local errmsg = base.get_errmsg_ptr() + + +local function var_get(self, name) + local r = get_request() + if not r then + error("no request found") + end + + local value_len = get_size_ptr() + local rc + if type(name) == "number" then + rc = ngx_lua_ffi_var_get(r, nil, 0, nil, name, value_ptr, value_len, + errmsg) + + else + if type(name) ~= "string" then + error("bad variable name", 2) + end + + local name_len = #name + local lowcase_buf = get_string_buf(name_len) + + rc = ngx_lua_ffi_var_get(r, name, name_len, lowcase_buf, 0, value_ptr, + value_len, errmsg) + end + + -- ngx.log(ngx.WARN, "rc = ", rc) + + if rc == 0 then -- NGX_OK + return ffi_str(value_ptr[0], value_len[0]) + end + + if rc == -5 then -- NGX_DECLINED + return nil + end + + if rc == -1 then -- NGX_ERROR + error(ffi_str(errmsg[0]), 2) + end +end + + +local function var_set(self, name, value) + local r = get_request() + if not r then + error("no request found") + end + + if type(name) ~= "string" then + error("bad variable name", 2) + end + local name_len = #name + + local errlen = get_size_ptr() + errlen[0] = ERR_BUF_SIZE + local lowcase_buf = get_string_buf(name_len + ERR_BUF_SIZE) + + local value_len + if value == nil then + value_len = 0 + else + if type(value) ~= 'string' then + value = tostring(value) + end + value_len = #value + end + + local errbuf = lowcase_buf + name_len + local rc = ngx_lua_ffi_var_set(r, name, name_len, lowcase_buf, value, + value_len, errbuf, errlen) + + -- ngx.log(ngx.WARN, "rc = ", rc) + + if rc == 0 then -- NGX_OK + return + end + + if rc == -1 then -- NGX_ERROR + error(ffi_str(errbuf, errlen[0]), 2) + end +end + + +do + local mt = new_tab(0, 2) + mt.__index = var_get + mt.__newindex = var_set + + setmetatable(ngx.var, mt) +end + + +return { + version = base.version +} diff --git a/resty/core/worker.lua b/resty/core/worker.lua new file mode 100644 index 0000000..c336deb --- /dev/null +++ b/resty/core/worker.lua @@ -0,0 +1,77 @@ +-- Copyright (C) Yichun Zhang (agentzh) + + +local ffi = require "ffi" +local base = require "resty.core.base" + + +local C = ffi.C +local new_tab = base.new_tab +local subsystem = ngx.config.subsystem + + +local ngx_lua_ffi_worker_id +local ngx_lua_ffi_worker_pid +local ngx_lua_ffi_worker_count +local ngx_lua_ffi_worker_exiting + + +ngx.worker = new_tab(0, 4) + + +if subsystem == "http" then + ffi.cdef[[ + int ngx_http_lua_ffi_worker_id(void); + int ngx_http_lua_ffi_worker_pid(void); + int ngx_http_lua_ffi_worker_count(void); + int ngx_http_lua_ffi_worker_exiting(void); + ]] + + ngx_lua_ffi_worker_id = C.ngx_http_lua_ffi_worker_id + ngx_lua_ffi_worker_pid = C.ngx_http_lua_ffi_worker_pid + ngx_lua_ffi_worker_count = C.ngx_http_lua_ffi_worker_count + ngx_lua_ffi_worker_exiting = C.ngx_http_lua_ffi_worker_exiting + +elseif subsystem == "stream" then + ffi.cdef[[ + int ngx_stream_lua_ffi_worker_id(void); + int ngx_stream_lua_ffi_worker_pid(void); + int ngx_stream_lua_ffi_worker_count(void); + int ngx_stream_lua_ffi_worker_exiting(void); + ]] + + ngx_lua_ffi_worker_id = C.ngx_stream_lua_ffi_worker_id + ngx_lua_ffi_worker_pid = C.ngx_stream_lua_ffi_worker_pid + ngx_lua_ffi_worker_count = C.ngx_stream_lua_ffi_worker_count + ngx_lua_ffi_worker_exiting = C.ngx_stream_lua_ffi_worker_exiting +end + + +function ngx.worker.exiting() + return ngx_lua_ffi_worker_exiting() ~= 0 +end + + +function ngx.worker.pid() + return ngx_lua_ffi_worker_pid() +end + + +function ngx.worker.id() + local id = ngx_lua_ffi_worker_id() + if id < 0 then + return nil + end + + return id +end + + +function ngx.worker.count() + return ngx_lua_ffi_worker_count() +end + + +return { + _VERSION = base.version +} diff --git a/resty/dns/resolver.lua b/resty/dns/resolver.lua new file mode 100644 index 0000000..a67b3c1 --- /dev/null +++ b/resty/dns/resolver.lua @@ -0,0 +1,982 @@ +-- Copyright (C) Yichun Zhang (agentzh) + + +-- local socket = require "socket" +local bit = require "bit" +local udp = ngx.socket.udp +local rand = math.random +local char = string.char +local byte = string.byte +local find = string.find +local gsub = string.gsub +local sub = string.sub +local rep = string.rep +local format = string.format +local band = bit.band +local rshift = bit.rshift +local lshift = bit.lshift +local insert = table.insert +local concat = table.concat +local re_sub = ngx.re.sub +local tcp = ngx.socket.tcp +local log = ngx.log +local DEBUG = ngx.DEBUG +local unpack = unpack +local setmetatable = setmetatable +local type = type +local ipairs = ipairs + + +local ok, new_tab = pcall(require, "table.new") +if not ok then + new_tab = function (narr, nrec) return {} end +end + + +local DOT_CHAR = byte(".") +local ZERO_CHAR = byte("0") +local COLON_CHAR = byte(":") + +local IP6_ARPA = "ip6.arpa" + +local TYPE_A = 1 +local TYPE_NS = 2 +local TYPE_CNAME = 5 +local TYPE_SOA = 6 +local TYPE_PTR = 12 +local TYPE_MX = 15 +local TYPE_TXT = 16 +local TYPE_AAAA = 28 +local TYPE_SRV = 33 +local TYPE_SPF = 99 + +local CLASS_IN = 1 + +local SECTION_AN = 1 +local SECTION_NS = 2 +local SECTION_AR = 3 + + +local _M = { + _VERSION = '0.22', + TYPE_A = TYPE_A, + TYPE_NS = TYPE_NS, + TYPE_CNAME = TYPE_CNAME, + TYPE_SOA = TYPE_SOA, + TYPE_PTR = TYPE_PTR, + TYPE_MX = TYPE_MX, + TYPE_TXT = TYPE_TXT, + TYPE_AAAA = TYPE_AAAA, + TYPE_SRV = TYPE_SRV, + TYPE_SPF = TYPE_SPF, + CLASS_IN = CLASS_IN, + SECTION_AN = SECTION_AN, + SECTION_NS = SECTION_NS, + SECTION_AR = SECTION_AR +} + + +local resolver_errstrs = { + "format error", -- 1 + "server failure", -- 2 + "name error", -- 3 + "not implemented", -- 4 + "refused", -- 5 +} + +local soa_int32_fields = { "serial", "refresh", "retry", "expire", "minimum" } + +local mt = { __index = _M } + + +local arpa_tmpl = new_tab(72, 0) + +for i = 1, #IP6_ARPA do + arpa_tmpl[64 + i] = byte(IP6_ARPA, i) +end + +for i = 2, 64, 2 do + arpa_tmpl[i] = DOT_CHAR +end + + +function _M.new(class, opts) + if not opts then + return nil, "no options table specified" + end + + local servers = opts.nameservers + if not servers or #servers == 0 then + return nil, "no nameservers specified" + end + + local timeout = opts.timeout or 2000 -- default 2 sec + + local n = #servers + + local socks = {} + + for i = 1, n do + local server = servers[i] + local sock, err = udp() + if not sock then + return nil, "failed to create udp socket: " .. err + end + + local host, port + if type(server) == 'table' then + host = server[1] + port = server[2] or 53 + + else + host = server + port = 53 + servers[i] = {host, port} + end + + local ok, err = sock:setpeername(host, port) + if not ok then + return nil, "failed to set peer name: " .. err + end + + sock:settimeout(timeout) + + insert(socks, sock) + end + + local tcp_sock, err = tcp() + if not tcp_sock then + return nil, "failed to create tcp socket: " .. err + end + + tcp_sock:settimeout(timeout) + + return setmetatable( + { cur = opts.no_random and 1 or rand(1, n), + socks = socks, + tcp_sock = tcp_sock, + servers = servers, + retrans = opts.retrans or 5, + no_recurse = opts.no_recurse, + }, mt) +end + + +local function pick_sock(self, socks) + local cur = self.cur + + if cur == #socks then + self.cur = 1 + else + self.cur = cur + 1 + end + + return socks[cur] +end + + +local function _get_cur_server(self) + local cur = self.cur + + local servers = self.servers + + if cur == 1 then + return servers[#servers] + end + + return servers[cur - 1] +end + + +function _M.set_timeout(self, timeout) + local socks = self.socks + if not socks then + return nil, "not initialized" + end + + for i = 1, #socks do + local sock = socks[i] + sock:settimeout(timeout) + end + + local tcp_sock = self.tcp_sock + if not tcp_sock then + return nil, "not initialized" + end + + tcp_sock:settimeout(timeout) +end + + +local function _encode_name(s) + return char(#s) .. s +end + + +local function _decode_name(buf, pos) + local labels = {} + local nptrs = 0 + local p = pos + while nptrs < 128 do + local fst = byte(buf, p) + + if not fst then + return nil, 'truncated'; + end + + -- print("fst at ", p, ": ", fst) + + if fst == 0 then + if nptrs == 0 then + pos = pos + 1 + end + break + end + + if band(fst, 0xc0) ~= 0 then + -- being a pointer + if nptrs == 0 then + pos = pos + 2 + end + + nptrs = nptrs + 1 + + local snd = byte(buf, p + 1) + if not snd then + return nil, 'truncated' + end + + p = lshift(band(fst, 0x3f), 8) + snd + 1 + + -- print("resolving ptr ", p, ": ", byte(buf, p)) + + else + -- being a label + local label = sub(buf, p + 1, p + fst) + insert(labels, label) + + -- print("resolved label ", label) + + p = p + fst + 1 + + if nptrs == 0 then + pos = p + end + end + end + + return concat(labels, "."), pos +end + + +local function _build_request(qname, id, no_recurse, opts) + local qtype + + if opts then + qtype = opts.qtype + end + + if not qtype then + qtype = 1 -- A record + end + + local ident_hi = char(rshift(id, 8)) + local ident_lo = char(band(id, 0xff)) + + local flags + if no_recurse then + -- print("found no recurse") + flags = "\0\0" + else + flags = "\1\0" + end + + local nqs = "\0\1" + local nan = "\0\0" + local nns = "\0\0" + local nar = "\0\0" + local typ = char(rshift(qtype, 8), band(qtype, 0xff)) + local class = "\0\1" -- the Internet class + + if byte(qname, 1) == DOT_CHAR then + return nil, "bad name" + end + + local name = gsub(qname, "([^.]+)%.?", _encode_name) .. '\0' + + return { + ident_hi, ident_lo, flags, nqs, nan, nns, nar, + name, typ, class + } +end + + +local function parse_section(answers, section, buf, start_pos, size, + should_skip) + local pos = start_pos + + for _ = 1, size do + -- print(format("ans %d: qtype:%d qclass:%d", i, qtype, qclass)) + local ans = {} + + if not should_skip then + insert(answers, ans) + end + + ans.section = section + + local name + name, pos = _decode_name(buf, pos) + if not name then + return nil, pos + end + + ans.name = name + + -- print("name: ", name) + + local type_hi = byte(buf, pos) + local type_lo = byte(buf, pos + 1) + local typ = lshift(type_hi, 8) + type_lo + + ans.type = typ + + -- print("type: ", typ) + + local class_hi = byte(buf, pos + 2) + local class_lo = byte(buf, pos + 3) + local class = lshift(class_hi, 8) + class_lo + + ans.class = class + + -- print("class: ", class) + + local byte_1, byte_2, byte_3, byte_4 = byte(buf, pos + 4, pos + 7) + + local ttl = lshift(byte_1, 24) + lshift(byte_2, 16) + + lshift(byte_3, 8) + byte_4 + + -- print("ttl: ", ttl) + + ans.ttl = ttl + + local len_hi = byte(buf, pos + 8) + local len_lo = byte(buf, pos + 9) + local len = lshift(len_hi, 8) + len_lo + + -- print("record len: ", len) + + pos = pos + 10 + + if typ == TYPE_A then + + if len ~= 4 then + return nil, "bad A record value length: " .. len + end + + local addr_bytes = { byte(buf, pos, pos + 3) } + local addr = concat(addr_bytes, ".") + -- print("ipv4 address: ", addr) + + ans.address = addr + + pos = pos + 4 + + elseif typ == TYPE_CNAME then + + local cname, p = _decode_name(buf, pos) + if not cname then + return nil, pos + end + + if p - pos ~= len then + return nil, format("bad cname record length: %d ~= %d", + p - pos, len) + end + + pos = p + + -- print("cname: ", cname) + + ans.cname = cname + + elseif typ == TYPE_AAAA then + + if len ~= 16 then + return nil, "bad AAAA record value length: " .. len + end + + local addr_bytes = { byte(buf, pos, pos + 15) } + local flds = {} + for i = 1, 16, 2 do + local a = addr_bytes[i] + local b = addr_bytes[i + 1] + if a == 0 then + insert(flds, format("%x", b)) + + else + insert(flds, format("%x%02x", a, b)) + end + end + + -- we do not compress the IPv6 addresses by default + -- due to performance considerations + + ans.address = concat(flds, ":") + + pos = pos + 16 + + elseif typ == TYPE_MX then + + -- print("len = ", len) + + if len < 3 then + return nil, "bad MX record value length: " .. len + end + + local pref_hi = byte(buf, pos) + local pref_lo = byte(buf, pos + 1) + + ans.preference = lshift(pref_hi, 8) + pref_lo + + local host, p = _decode_name(buf, pos + 2) + if not host then + return nil, pos + end + + if p - pos ~= len then + return nil, format("bad cname record length: %d ~= %d", + p - pos, len) + end + + ans.exchange = host + + pos = p + + elseif typ == TYPE_SRV then + if len < 7 then + return nil, "bad SRV record value length: " .. len + end + + local prio_hi = byte(buf, pos) + local prio_lo = byte(buf, pos + 1) + ans.priority = lshift(prio_hi, 8) + prio_lo + + local weight_hi = byte(buf, pos + 2) + local weight_lo = byte(buf, pos + 3) + ans.weight = lshift(weight_hi, 8) + weight_lo + + local port_hi = byte(buf, pos + 4) + local port_lo = byte(buf, pos + 5) + ans.port = lshift(port_hi, 8) + port_lo + + local name, p = _decode_name(buf, pos + 6) + if not name then + return nil, pos + end + + if p - pos ~= len then + return nil, format("bad srv record length: %d ~= %d", + p - pos, len) + end + + ans.target = name + + pos = p + + elseif typ == TYPE_NS then + + local name, p = _decode_name(buf, pos) + if not name then + return nil, pos + end + + if p - pos ~= len then + return nil, format("bad cname record length: %d ~= %d", + p - pos, len) + end + + pos = p + + -- print("name: ", name) + + ans.nsdname = name + + elseif typ == TYPE_TXT or typ == TYPE_SPF then + + local key = (typ == TYPE_TXT) and "txt" or "spf" + + local slen = byte(buf, pos) + if slen + 1 > len then + -- truncate the over-run TXT record data + slen = len + end + + -- print("slen: ", len) + + local val = sub(buf, pos + 1, pos + slen) + local last = pos + len + pos = pos + slen + 1 + + if pos < last then + -- more strings to be processed + -- this code path is usually cold, so we do not + -- merge the following loop on this code path + -- with the processing logic above. + + val = {val} + local idx = 2 + repeat + local slen = byte(buf, pos) + if pos + slen + 1 > last then + -- truncate the over-run TXT record data + slen = last - pos - 1 + end + + val[idx] = sub(buf, pos + 1, pos + slen) + idx = idx + 1 + pos = pos + slen + 1 + + until pos >= last + end + + ans[key] = val + + elseif typ == TYPE_PTR then + + local name, p = _decode_name(buf, pos) + if not name then + return nil, pos + end + + if p - pos ~= len then + return nil, format("bad cname record length: %d ~= %d", + p - pos, len) + end + + pos = p + + -- print("name: ", name) + + ans.ptrdname = name + + elseif typ == TYPE_SOA then + local name, p = _decode_name(buf, pos) + if not name then + return nil, pos + end + ans.mname = name + + pos = p + name, p = _decode_name(buf, pos) + if not name then + return nil, pos + end + ans.rname = name + + for _, field in ipairs(soa_int32_fields) do + local byte_1, byte_2, byte_3, byte_4 = byte(buf, p, p + 3) + ans[field] = lshift(byte_1, 24) + lshift(byte_2, 16) + + lshift(byte_3, 8) + byte_4 + p = p + 4 + end + + pos = p + + else + -- for unknown types, just forward the raw value + + ans.rdata = sub(buf, pos, pos + len - 1) + pos = pos + len + end + end + + return pos +end + + +local function parse_response(buf, id, opts) + local n = #buf + if n < 12 then + return nil, 'truncated'; + end + + -- header layout: ident flags nqs nan nns nar + + local ident_hi = byte(buf, 1) + local ident_lo = byte(buf, 2) + local ans_id = lshift(ident_hi, 8) + ident_lo + + -- print("id: ", id, ", ans id: ", ans_id) + + if ans_id ~= id then + -- identifier mismatch and throw it away + log(DEBUG, "id mismatch in the DNS reply: ", ans_id, " ~= ", id) + return nil, "id mismatch" + end + + local flags_hi = byte(buf, 3) + local flags_lo = byte(buf, 4) + local flags = lshift(flags_hi, 8) + flags_lo + + -- print(format("flags: 0x%x", flags)) + + if band(flags, 0x8000) == 0 then + return nil, format("bad QR flag in the DNS response") + end + + if band(flags, 0x200) ~= 0 then + return nil, "truncated" + end + + local code = band(flags, 0xf) + + -- print(format("code: %d", code)) + + local nqs_hi = byte(buf, 5) + local nqs_lo = byte(buf, 6) + local nqs = lshift(nqs_hi, 8) + nqs_lo + + -- print("nqs: ", nqs) + + if nqs ~= 1 then + return nil, format("bad number of questions in DNS response: %d", nqs) + end + + local nan_hi = byte(buf, 7) + local nan_lo = byte(buf, 8) + local nan = lshift(nan_hi, 8) + nan_lo + + -- print("nan: ", nan) + + local nns_hi = byte(buf, 9) + local nns_lo = byte(buf, 10) + local nns = lshift(nns_hi, 8) + nns_lo + + local nar_hi = byte(buf, 11) + local nar_lo = byte(buf, 12) + local nar = lshift(nar_hi, 8) + nar_lo + + -- skip the question part + + local ans_qname, pos = _decode_name(buf, 13) + if not ans_qname then + return nil, pos + end + + -- print("qname in reply: ", ans_qname) + + -- print("question: ", sub(buf, 13, pos)) + + if pos + 3 + nan * 12 > n then + -- print(format("%d > %d", pos + 3 + nan * 12, n)) + return nil, 'truncated'; + end + + -- question section layout: qname qtype(2) qclass(2) + + --[[ + local type_hi = byte(buf, pos) + local type_lo = byte(buf, pos + 1) + local ans_type = lshift(type_hi, 8) + type_lo + ]] + + -- print("ans qtype: ", ans_type) + + local class_hi = byte(buf, pos + 2) + local class_lo = byte(buf, pos + 3) + local qclass = lshift(class_hi, 8) + class_lo + + -- print("ans qclass: ", qclass) + + if qclass ~= 1 then + return nil, format("unknown query class %d in DNS response", qclass) + end + + pos = pos + 4 + + local answers = {} + + if code ~= 0 then + answers.errcode = code + answers.errstr = resolver_errstrs[code] or "unknown" + end + + local authority_section, additional_section + + if opts then + authority_section = opts.authority_section + additional_section = opts.additional_section + if opts.qtype == TYPE_SOA then + authority_section = true + end + end + + local err + + pos, err = parse_section(answers, SECTION_AN, buf, pos, nan) + + if not pos then + return nil, err + end + + if not authority_section and not additional_section then + return answers + end + + pos, err = parse_section(answers, SECTION_NS, buf, pos, nns, + not authority_section) + + if not pos then + return nil, err + end + + if not additional_section then + return answers + end + + pos, err = parse_section(answers, SECTION_AR, buf, pos, nar) + + if not pos then + return nil, err + end + + return answers +end + + +local function _gen_id(self) + local id = self._id -- for regression testing + if id then + return id + end + return rand(0, 65535) -- two bytes +end + + +local function _tcp_query(self, query, id, opts) + local sock = self.tcp_sock + if not sock then + return nil, "not initialized" + end + + log(DEBUG, "query the TCP server due to reply truncation") + + local server = _get_cur_server(self) + + local ok, err = sock:connect(server[1], server[2]) + if not ok then + return nil, "failed to connect to TCP server " + .. concat(server, ":") .. ": " .. err + end + + query = concat(query, "") + local len = #query + + local len_hi = char(rshift(len, 8)) + local len_lo = char(band(len, 0xff)) + + local bytes, err = sock:send({len_hi, len_lo, query}) + if not bytes then + return nil, "failed to send query to TCP server " + .. concat(server, ":") .. ": " .. err + end + + local buf, err = sock:receive(2) + if not buf then + return nil, "failed to receive the reply length field from TCP server " + .. concat(server, ":") .. ": " .. err + end + + len_hi = byte(buf, 1) + len_lo = byte(buf, 2) + len = lshift(len_hi, 8) + len_lo + + -- print("tcp message len: ", len) + + buf, err = sock:receive(len) + if not buf then + return nil, "failed to receive the reply message body from TCP server " + .. concat(server, ":") .. ": " .. err + end + + local answers, err = parse_response(buf, id, opts) + if not answers then + return nil, "failed to parse the reply from the TCP server " + .. concat(server, ":") .. ": " .. err + end + + sock:close() + + return answers +end + + +function _M.tcp_query(self, qname, opts) + local socks = self.socks + if not socks then + return nil, "not initialized" + end + + pick_sock(self, socks) + + local id = _gen_id(self) + + local query, err = _build_request(qname, id, self.no_recurse, opts) + if not query then + return nil, err + end + + return _tcp_query(self, query, id, opts) +end + + +function _M.query(self, qname, opts, tries) + local socks = self.socks + if not socks then + return nil, "not initialized" + end + + local id = _gen_id(self) + + local query, err = _build_request(qname, id, self.no_recurse, opts) + if not query then + return nil, err + end + + -- local cjson = require "cjson" + -- print("query: ", cjson.encode(concat(query, ""))) + + local retrans = self.retrans + if tries then + tries[1] = nil + end + + -- print("retrans: ", retrans) + + for i = 1, retrans do + local sock = pick_sock(self, socks) + + local ok + ok, err = sock:send(query) + if not ok then + local server = _get_cur_server(self) + err = "failed to send request to UDP server " + .. concat(server, ":") .. ": " .. err + + else + local buf + + for _ = 1, 128 do + buf, err = sock:receive(4096) + if err then + local server = _get_cur_server(self) + err = "failed to receive reply from UDP server " + .. concat(server, ":") .. ": " .. err + break + end + + if buf then + local answers + answers, err = parse_response(buf, id, opts) + if err == "truncated" then + answers, err = _tcp_query(self, query, id, opts) + end + + if err and err ~= "id mismatch" then + break + end + + if answers then + return answers, nil, tries + end + end + -- only here in case of an "id mismatch" + end + end + + if tries then + tries[i] = err + tries[i + 1] = nil -- ensure termination for user supplied table + end + end + + return nil, err, tries +end + + +function _M.compress_ipv6_addr(addr) + local addr = re_sub(addr, "^(0:)+|(:0)+$|:(0:)+", "::", "jo") + if addr == "::0" then + addr = "::" + end + + return addr +end + + +local function _expand_ipv6_addr(addr) + if find(addr, "::", 1, true) then + local ncol, addrlen = 8, #addr + + for i = 1, addrlen do + if byte(addr, i) == COLON_CHAR then + ncol = ncol - 1 + end + end + + if byte(addr, 1) == COLON_CHAR then + addr = "0" .. addr + end + + if byte(addr, -1) == COLON_CHAR then + addr = addr .. "0" + end + + addr = re_sub(addr, "::", ":" .. rep("0:", ncol), "jo") + end + + return addr +end + + +_M.expand_ipv6_addr = _expand_ipv6_addr + + +function _M.arpa_str(addr) + if find(addr, ":", 1, true) then + addr = _expand_ipv6_addr(addr) + local idx, hidx, addrlen = 1, 1, #addr + + for i = addrlen, 0, -1 do + local s = byte(addr, i) + if s == COLON_CHAR or not s then + for _ = hidx, 4 do + arpa_tmpl[idx] = ZERO_CHAR + idx = idx + 2 + end + hidx = 1 + else + arpa_tmpl[idx] = s + idx = idx + 2 + hidx = hidx + 1 + end + end + + addr = char(unpack(arpa_tmpl)) + else + addr = re_sub(addr, [[(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})]], + "$4.$3.$2.$1.in-addr.arpa", "ajo") + end + + return addr +end + + +function _M.reverse_query(self, addr) + return self.query(self, self.arpa_str(addr), + {qtype = self.TYPE_PTR}) +end + + +return _M diff --git a/resty/limit/conn.lua b/resty/limit/conn.lua new file mode 100644 index 0000000..b672487 --- /dev/null +++ b/resty/limit/conn.lua @@ -0,0 +1,125 @@ +-- Copyright (C) Yichun Zhang (agentzh) +-- +-- This library is an enhanced Lua port of the standard ngx_limit_conn +-- module. + + +local math = require "math" + + +local setmetatable = setmetatable +local floor = math.floor +local ngx_shared = ngx.shared +local assert = assert + + +local _M = { + _VERSION = '0.07' +} + + +local mt = { + __index = _M +} + + +function _M.new(dict_name, max, burst, default_conn_delay) + local dict = ngx_shared[dict_name] + if not dict then + return nil, "shared dict not found" + end + + assert(max > 0 and burst >= 0 and default_conn_delay > 0) + + local self = { + dict = dict, + max = max + 0, -- just to ensure the param is good + burst = burst, + unit_delay = default_conn_delay, + } + + return setmetatable(self, mt) +end + + +function _M.incoming(self, key, commit) + local dict = self.dict + local max = self.max + + self.committed = false + + local conn, err + if commit then + conn, err = dict:incr(key, 1, 0) + if not conn then + return nil, err + end + + if conn > max + self.burst then + conn, err = dict:incr(key, -1) + if not conn then + return nil, err + end + return nil, "rejected" + end + self.committed = true + + else + conn = (dict:get(key) or 0) + 1 + if conn > max + self.burst then + return nil, "rejected" + end + end + + if conn > max then + -- make the exessive connections wait + return self.unit_delay * floor((conn - 1) / max), conn + end + + -- we return a 0 delay by default + return 0, conn +end + + +function _M.is_committed(self) + return self.committed +end + + +function _M.leaving(self, key, req_latency) + assert(key) + local dict = self.dict + + local conn, err = dict:incr(key, -1) + if not conn then + return nil, err + end + + if req_latency then + local unit_delay = self.unit_delay + self.unit_delay = (req_latency + unit_delay) / 2 + end + + return conn +end + + +function _M.uncommit(self, key) + assert(key) + local dict = self.dict + + return dict:incr(key, -1) +end + + +function _M.set_conn(self, conn) + self.max = conn +end + + +function _M.set_burst(self, burst) + self.burst = burst +end + + +return _M diff --git a/resty/limit/count.lua b/resty/limit/count.lua new file mode 100644 index 0000000..fcb1042 --- /dev/null +++ b/resty/limit/count.lua @@ -0,0 +1,103 @@ +-- implement GitHub request rate limiting: +-- https://developer.github.com/v3/#rate-limiting + +local ngx_shared = ngx.shared +local setmetatable = setmetatable +local assert = assert + + +local _M = { + _VERSION = '0.07' +} + + +local mt = { + __index = _M +} + + +-- the "limit" argument controls number of request allowed in a time window. +-- time "window" argument controls the time window in seconds. +function _M.new(dict_name, limit, window) + local dict = ngx_shared[dict_name] + if not dict then + return nil, "shared dict not found" + end + + assert(limit > 0 and window > 0) + + local self = { + dict = dict, + limit = limit, + window = window, + } + + return setmetatable(self, mt) +end + + +function _M.incoming(self, key, commit) + local dict = self.dict + local limit = self.limit + local window = self.window + + local remaining, ok, err + + if commit then + remaining, err = dict:incr(key, -1, limit) + if not remaining then + return nil, err + end + + if remaining == limit - 1 then + ok, err = dict:expire(key, window) + if not ok then + if err == "not found" then + remaining, err = dict:incr(key, -1, limit) + if not remaining then + return nil, err + end + + ok, err = dict:expire(key, window) + if not ok then + return nil, err + end + + else + return nil, err + end + end + end + + else + remaining = (dict:get(key) or limit) - 1 + end + + if remaining < 0 then + return nil, "rejected" + end + + return 0, remaining +end + + +-- uncommit remaining and return remaining value +function _M.uncommit(self, key) + assert(key) + local dict = self.dict + local limit = self.limit + + local remaining, err = dict:incr(key, 1) + if not remaining then + if err == "not found" then + remaining = limit + else + return nil, err + end + end + + return remaining +end + + +return _M diff --git a/resty/limit/req.lua b/resty/limit/req.lua new file mode 100644 index 0000000..9313cb7 --- /dev/null +++ b/resty/limit/req.lua @@ -0,0 +1,153 @@ +-- Copyright (C) Yichun Zhang (agentzh) +-- +-- This library is an approximate Lua port of the standard ngx_limit_req +-- module. + + +local ffi = require "ffi" +local math = require "math" + + +local ngx_shared = ngx.shared +local ngx_now = ngx.now +local setmetatable = setmetatable +local ffi_cast = ffi.cast +local ffi_str = ffi.string +local abs = math.abs +local tonumber = tonumber +local type = type +local assert = assert +local max = math.max + + +-- TODO: we could avoid the tricky FFI cdata when lua_shared_dict supports +-- hash-typed values as in redis. +ffi.cdef[[ + struct lua_resty_limit_req_rec { + unsigned long excess; + uint64_t last; /* time in milliseconds */ + /* integer value, 1 corresponds to 0.001 r/s */ + }; +]] +local const_rec_ptr_type = ffi.typeof("const struct lua_resty_limit_req_rec*") +local rec_size = ffi.sizeof("struct lua_resty_limit_req_rec") + +-- we can share the cdata here since we only need it temporarily for +-- serialization inside the shared dict: +local rec_cdata = ffi.new("struct lua_resty_limit_req_rec") + + +local _M = { + _VERSION = '0.07' +} + + +local mt = { + __index = _M +} + + +function _M.new(dict_name, rate, burst) + local dict = ngx_shared[dict_name] + if not dict then + return nil, "shared dict not found" + end + + assert(rate > 0 and burst >= 0) + + local self = { + dict = dict, + rate = rate * 1000, + burst = burst * 1000, + } + + return setmetatable(self, mt) +end + + +-- sees an new incoming event +-- the "commit" argument controls whether should we record the event in shm. +-- FIXME we have a (small) race-condition window between dict:get() and +-- dict:set() across multiple nginx worker processes. The size of the +-- window is proportional to the number of workers. +function _M.incoming(self, key, commit) + local dict = self.dict + local rate = self.rate + local now = ngx_now() * 1000 + + local excess + + -- it's important to anchor the string value for the read-only pointer + -- cdata: + local v = dict:get(key) + if v then + if type(v) ~= "string" or #v ~= rec_size then + return nil, "shdict abused by other users" + end + local rec = ffi_cast(const_rec_ptr_type, v) + local elapsed = now - tonumber(rec.last) + + -- print("elapsed: ", elapsed, "ms") + + -- we do not handle changing rate values specifically. the excess value + -- can get automatically adjusted by the following formula with new rate + -- values rather quickly anyway. + excess = max(tonumber(rec.excess) - rate * abs(elapsed) / 1000 + 1000, + 0) + + -- print("excess: ", excess) + + if excess > self.burst then + return nil, "rejected" + end + + else + excess = 0 + end + + if commit then + rec_cdata.excess = excess + rec_cdata.last = now + dict:set(key, ffi_str(rec_cdata, rec_size)) + end + + -- return the delay in seconds, as well as excess + return excess / rate, excess / 1000 +end + + +function _M.uncommit(self, key) + assert(key) + local dict = self.dict + + local v = dict:get(key) + if not v then + return nil, "not found" + end + + if type(v) ~= "string" or #v ~= rec_size then + return nil, "shdict abused by other users" + end + + local rec = ffi_cast(const_rec_ptr_type, v) + + local excess = max(tonumber(rec.excess) - 1000, 0) + + rec_cdata.excess = excess + rec_cdata.last = rec.last + dict:set(key, ffi_str(rec_cdata, rec_size)) + return true +end + + +function _M.set_rate(self, rate) + self.rate = rate * 1000 +end + + +function _M.set_burst(self, burst) + self.burst = burst * 1000 +end + + +return _M diff --git a/resty/limit/traffic.lua b/resty/limit/traffic.lua new file mode 100644 index 0000000..f65586f --- /dev/null +++ b/resty/limit/traffic.lua @@ -0,0 +1,58 @@ +-- Copyright (C) Yichun Zhang (agentzh) +-- +-- This is an aggregator for various concrete traffic limiter instances +-- (like instances of the resty.limit.req, resty.limit.count and +-- resty.limit.conn classes). + + +local max = math.max + + +local _M = { + _VERSION = '0.07' +} + + +-- the states table is user supplied. each element stores the 2nd return value +-- of each limiter if there is no error returned. for resty.limit.req, the state +-- is the "excess" value (i.e., the number of excessive requests each second), +-- and for resty.limit.conn, the state is the current concurrency level +-- (including the current new connection). +function _M.combine(limiters, keys, states) + local n = #limiters + local max_delay = 0 + for i = 1, n do + local lim = limiters[i] + local delay, err = lim:incoming(keys[i], i == n) + if not delay then + return nil, err + end + if i == n then + if states then + states[i] = err + end + max_delay = delay + end + end + for i = 1, n - 1 do + local lim = limiters[i] + local delay, err = lim:incoming(keys[i], true) + if not delay then + for j = 1, i - 1 do + -- we intentionally ignore any errors returned below. + limiters[j]:uncommit(keys[j]) + end + limiters[n]:uncommit(keys[n]) + return nil, err + end + if states then + states[i] = err + end + + max_delay = max(max_delay, delay) + end + return max_delay +end + + +return _M diff --git a/resty/lock.lua b/resty/lock.lua new file mode 100644 index 0000000..75ccf64 --- /dev/null +++ b/resty/lock.lua @@ -0,0 +1,221 @@ +-- Copyright (C) Yichun Zhang (agentzh) + + +require "resty.core.shdict" -- enforce this to avoid dead locks + +local ffi = require "ffi" +local ffi_new = ffi.new +local shared = ngx.shared +local sleep = ngx.sleep +local log = ngx.log +local max = math.max +local min = math.min +local debug = ngx.config.debug +local setmetatable = setmetatable +local tonumber = tonumber + +local _M = { _VERSION = '0.08' } +local mt = { __index = _M } + +local ERR = ngx.ERR +local FREE_LIST_REF = 0 + +-- FIXME: we don't need this when we have __gc metamethod support on Lua +-- tables. +local memo = {} +if debug then _M.memo = memo end + + +local function ref_obj(key) + if key == nil then + return -1 + end + local ref = memo[FREE_LIST_REF] + if ref and ref ~= 0 then + memo[FREE_LIST_REF] = memo[ref] + + else + ref = #memo + 1 + end + memo[ref] = key + + -- print("ref key_id returned ", ref) + return ref +end +if debug then _M.ref_obj = ref_obj end + + +local function unref_obj(ref) + if ref >= 0 then + memo[ref] = memo[FREE_LIST_REF] + memo[FREE_LIST_REF] = ref + end +end +if debug then _M.unref_obj = unref_obj end + + +local function gc_lock(cdata) + local dict_id = tonumber(cdata.dict_id) + local key_id = tonumber(cdata.key_id) + + -- print("key_id: ", key_id, ", key: ", memo[key_id], "dict: ", + -- type(memo[cdata.dict_id])) + if key_id > 0 then + local key = memo[key_id] + unref_obj(key_id) + local dict = memo[dict_id] + -- print("dict.delete type: ", type(dict.delete)) + local ok, err = dict:delete(key) + if not ok then + log(ERR, 'failed to delete key "', key, '": ', err) + end + cdata.key_id = 0 + end + + unref_obj(dict_id) +end + + +local ctype = ffi.metatype("struct { int key_id; int dict_id; }", + { __gc = gc_lock }) + + +function _M.new(_, dict_name, opts) + local dict = shared[dict_name] + if not dict then + return nil, "dictionary not found" + end + local cdata = ffi_new(ctype) + cdata.key_id = 0 + cdata.dict_id = ref_obj(dict) + + local timeout, exptime, step, ratio, max_step + if opts then + timeout = opts.timeout + exptime = opts.exptime + step = opts.step + ratio = opts.ratio + max_step = opts.max_step + end + + if not exptime then + exptime = 30 + end + + if timeout then + timeout = min(timeout, exptime) + + if step then + step = min(step, timeout) + end + end + + local self = { + cdata = cdata, + dict = dict, + timeout = timeout or 5, + exptime = exptime, + step = step or 0.001, + ratio = ratio or 2, + max_step = max_step or 0.5, + } + return setmetatable(self, mt) +end + + +function _M.lock(self, key) + if not key then + return nil, "nil key" + end + + local dict = self.dict + local cdata = self.cdata + if cdata.key_id > 0 then + return nil, "locked" + end + local exptime = self.exptime + local ok, err = dict:add(key, true, exptime) + if ok then + cdata.key_id = ref_obj(key) + self.key = key + return 0 + end + if err ~= "exists" then + return nil, err + end + -- lock held by others + local step = self.step + local ratio = self.ratio + local timeout = self.timeout + local max_step = self.max_step + local elapsed = 0 + while timeout > 0 do + sleep(step) + elapsed = elapsed + step + timeout = timeout - step + + local ok, err = dict:add(key, true, exptime) + if ok then + cdata.key_id = ref_obj(key) + self.key = key + return elapsed + end + + if err ~= "exists" then + return nil, err + end + + if timeout <= 0 then + break + end + + step = min(max(0.001, step * ratio), timeout, max_step) + end + + return nil, "timeout" +end + + +function _M.unlock(self) + local dict = self.dict + local cdata = self.cdata + local key_id = tonumber(cdata.key_id) + if key_id <= 0 then + return nil, "unlocked" + end + + local key = memo[key_id] + unref_obj(key_id) + + local ok, err = dict:delete(key) + if not ok then + return nil, err + end + cdata.key_id = 0 + + return 1 +end + + +function _M.expire(self, time) + local dict = self.dict + local cdata = self.cdata + local key_id = tonumber(cdata.key_id) + if key_id <= 0 then + return nil, "unlocked" + end + + if not time then + time = self.exptime + end + + local ok, err = dict:replace(self.key, true, time) + if not ok then + return nil, err + end + + return true +end + + +return _M diff --git a/resty/lrucache.lua b/resty/lrucache.lua new file mode 100644 index 0000000..e52a6e3 --- /dev/null +++ b/resty/lrucache.lua @@ -0,0 +1,340 @@ +-- Copyright (C) Yichun Zhang (agentzh) + + +local ffi = require "ffi" +local ffi_new = ffi.new +local ffi_sizeof = ffi.sizeof +local ffi_cast = ffi.cast +local ffi_fill = ffi.fill +local ngx_now = ngx.now +local uintptr_t = ffi.typeof("uintptr_t") +local setmetatable = setmetatable +local tonumber = tonumber +local type = type +local new_tab +do + local ok + ok, new_tab = pcall(require, "table.new") + if not ok then + new_tab = function(narr, nrec) return {} end + end +end + + +if string.find(jit.version, " 2.0", 1, true) then + ngx.log(ngx.ALERT, "use of lua-resty-lrucache with LuaJIT 2.0 is ", + "not recommended; use LuaJIT 2.1+ instead") +end + + +local ok, tb_clear = pcall(require, "table.clear") +if not ok then + local pairs = pairs + tb_clear = function (tab) + for k, _ in pairs(tab) do + tab[k] = nil + end + end +end + + +-- queue data types +-- +-- this queue is a double-ended queue and the first node +-- is reserved for the queue itself. +-- the implementation is mostly borrowed from nginx's ngx_queue_t data +-- structure. + +ffi.cdef[[ + typedef struct lrucache_queue_s lrucache_queue_t; + struct lrucache_queue_s { + double expire; /* in seconds */ + lrucache_queue_t *prev; + lrucache_queue_t *next; + uint32_t user_flags; + }; +]] + +local queue_arr_type = ffi.typeof("lrucache_queue_t[?]") +local queue_type = ffi.typeof("lrucache_queue_t") +local NULL = ffi.null + + +-- queue utility functions + +local function queue_insert_tail(h, x) + local last = h[0].prev + x.prev = last + last.next = x + x.next = h + h[0].prev = x +end + + +local function queue_init(size) + if not size then + size = 0 + end + local q = ffi_new(queue_arr_type, size + 1) + ffi_fill(q, ffi_sizeof(queue_type, size + 1), 0) + + if size == 0 then + q[0].prev = q + q[0].next = q + + else + local prev = q[0] + for i = 1, size do + local e = q + i + e.user_flags = 0 + prev.next = e + e.prev = prev + prev = e + end + + local last = q[size] + last.next = q + q[0].prev = last + end + + return q +end + + +local function queue_is_empty(q) + -- print("q: ", tostring(q), "q.prev: ", tostring(q), ": ", q == q.prev) + return q == q[0].prev +end + + +local function queue_remove(x) + local prev = x.prev + local next = x.next + + next.prev = prev + prev.next = next + + -- for debugging purpose only: + x.prev = NULL + x.next = NULL +end + + +local function queue_insert_head(h, x) + x.next = h[0].next + x.next.prev = x + x.prev = h + h[0].next = x +end + + +local function queue_last(h) + return h[0].prev +end + + +local function queue_head(h) + return h[0].next +end + + +-- true module stuffs + +local _M = { + _VERSION = '0.11' +} +local mt = { __index = _M } + + +local function ptr2num(ptr) + return tonumber(ffi_cast(uintptr_t, ptr)) +end + + +function _M.new(size) + if size < 1 then + return nil, "size too small" + end + + local self = { + hasht = {}, + free_queue = queue_init(size), + cache_queue = queue_init(), + key2node = {}, + node2key = {}, + num_items = 0, + max_items = size, + } + return setmetatable(self, mt) +end + + +function _M.count(self) + return self.num_items +end + + +function _M.capacity(self) + return self.max_items +end + + +function _M.get(self, key) + local hasht = self.hasht + local val = hasht[key] + if val == nil then + return nil + end + + local node = self.key2node[key] + + -- print(key, ": moving node ", tostring(node), " to cache queue head") + local cache_queue = self.cache_queue + queue_remove(node) + queue_insert_head(cache_queue, node) + + if node.expire >= 0 and node.expire < ngx_now() then + -- print("expired: ", node.expire, " > ", ngx_now()) + return nil, val, node.user_flags + end + + return val, nil, node.user_flags +end + + +function _M.delete(self, key) + self.hasht[key] = nil + + local key2node = self.key2node + local node = key2node[key] + + if not node then + return false + end + + key2node[key] = nil + self.node2key[ptr2num(node)] = nil + + queue_remove(node) + queue_insert_tail(self.free_queue, node) + self.num_items = self.num_items - 1 + return true +end + + +function _M.set(self, key, value, ttl, flags) + local hasht = self.hasht + hasht[key] = value + + local key2node = self.key2node + local node = key2node[key] + if not node then + local free_queue = self.free_queue + local node2key = self.node2key + + if queue_is_empty(free_queue) then + -- evict the least recently used key + -- assert(not queue_is_empty(self.cache_queue)) + node = queue_last(self.cache_queue) + + local oldkey = node2key[ptr2num(node)] + -- print(key, ": evicting oldkey: ", oldkey, ", oldnode: ", + -- tostring(node)) + if oldkey then + hasht[oldkey] = nil + key2node[oldkey] = nil + end + + else + -- take a free queue node + node = queue_head(free_queue) + -- only add count if we are not evicting + self.num_items = self.num_items + 1 + -- print(key, ": get a new free node: ", tostring(node)) + end + + node2key[ptr2num(node)] = key + key2node[key] = node + end + + queue_remove(node) + queue_insert_head(self.cache_queue, node) + + if ttl then + node.expire = ngx_now() + ttl + else + node.expire = -1 + end + + if type(flags) == "number" and flags >= 0 then + node.user_flags = flags + + else + node.user_flags = 0 + end +end + + +function _M.get_keys(self, max_count, res) + if not max_count or max_count == 0 then + max_count = self.num_items + end + + if not res then + res = new_tab(max_count + 1, 0) -- + 1 for trailing hole + end + + local cache_queue = self.cache_queue + local node2key = self.node2key + + local i = 0 + local node = queue_head(cache_queue) + + while node ~= cache_queue do + if i >= max_count then + break + end + + i = i + 1 + res[i] = node2key[ptr2num(node)] + node = node.next + end + + res[i + 1] = nil + + return res +end + + +function _M.flush_all(self) + tb_clear(self.hasht) + tb_clear(self.node2key) + tb_clear(self.key2node) + + self.num_items = 0 + + local cache_queue = self.cache_queue + local free_queue = self.free_queue + + -- splice the cache_queue into free_queue + if not queue_is_empty(cache_queue) then + local free_head = free_queue[0] + local free_last = free_head.prev + + local cache_head = cache_queue[0] + local cache_first = cache_head.next + local cache_last = cache_head.prev + + free_last.next = cache_first + cache_first.prev = free_last + + cache_last.next = free_head + free_head.prev = cache_last + + cache_head.next = cache_queue + cache_head.prev = cache_queue + end +end + + +return _M diff --git a/resty/lrucache/pureffi.lua b/resty/lrucache/pureffi.lua new file mode 100644 index 0000000..a47377a --- /dev/null +++ b/resty/lrucache/pureffi.lua @@ -0,0 +1,606 @@ +-- Copyright (C) Yichun Zhang (agentzh) +-- Copyright (C) Shuxin Yang + +--[[ + This module implements a key/value cache store. We adopt LRU as our +replace/evict policy. Each key/value pair is tagged with a Time-to-Live (TTL); +from user's perspective, stale pairs are automatically removed from the cache. + +Why FFI +------- + In Lua, expression "table[key] = nil" does not *PHYSICALLY* remove the value +associated with the key; it just set the value to be nil! So the table will +keep growing with large number of the key/nil pairs which will be purged until +resize() operator is called. + + This "feature" is terribly ill-suited to what we need. Therefore we have to +rely on FFI to build a hash-table where any entry can be physically deleted +immediately. + +Under the hood: +-------------- + In concept, we introduce three data structures to implement the cache store: + 1. key/value vector for storing keys and values. + 2. a queue to mimic the LRU. + 3. hash-table for looking up the value for a given key. + + Unfortunately, efficiency and clarity usually come at each other cost. The +data strucutres we are using are slightly more complicated than what we +described above. + + o. Lua does not have efficient way to store a vector of pair. So, we use + two vectors for key/value pair: one for keys and the other for values + (_M.key_v and _M.val_v, respectively), and i-th key corresponds to + i-th value. + + A key/value pair is identified by the "id" field in a "node" (we shall + discuss node later) + + o. The queue is nothing more than a doubly-linked list of "node" linked via + lrucache_pureffi_queue_s::{next|prev} fields. + + o. The hash-table has two parts: + - the _M.bucket_v[] a vector of bucket, indiced by hash-value, and + - a bucket is a singly-linked list of "node" via the + lrucache_pureffi_queue_s::conflict field. + + A key must be a string, and the hash value of a key is evaluated by: + crc32(key-cast-to-pointer) % size(_M.bucket_v). + We mandate size(_M.bucket_v) being a power-of-two in order to avoid + expensive modulo operation. + + At the heart of the module is an array of "node" (of type + lrucache_pureffi_queue_s). A node: + - keeps the meta-data of its corresponding key/value pair + (embodied by the "id", and "expire" field); + - is a part of LRU queue (embodied by "prev" and "next" fields); + - is a part of hash-table (embodied by the "conflict" field). +]] + +local ffi = require "ffi" +local bit = require "bit" + + +local ffi_new = ffi.new +local ffi_sizeof = ffi.sizeof +local ffi_cast = ffi.cast +local ffi_fill = ffi.fill +local ngx_now = ngx.now +local uintptr_t = ffi.typeof("uintptr_t") +local c_str_t = ffi.typeof("const char*") +local int_t = ffi.typeof("int") +local int_array_t = ffi.typeof("int[?]") + + +local crc_tab = ffi.new("const unsigned int[256]", { + 0x00000000, 0x77073096, 0xEE0E612C, 0x990951BA, 0x076DC419, 0x706AF48F, + 0xE963A535, 0x9E6495A3, 0x0EDB8832, 0x79DCB8A4, 0xE0D5E91E, 0x97D2D988, + 0x09B64C2B, 0x7EB17CBD, 0xE7B82D07, 0x90BF1D91, 0x1DB71064, 0x6AB020F2, + 0xF3B97148, 0x84BE41DE, 0x1ADAD47D, 0x6DDDE4EB, 0xF4D4B551, 0x83D385C7, + 0x136C9856, 0x646BA8C0, 0xFD62F97A, 0x8A65C9EC, 0x14015C4F, 0x63066CD9, + 0xFA0F3D63, 0x8D080DF5, 0x3B6E20C8, 0x4C69105E, 0xD56041E4, 0xA2677172, + 0x3C03E4D1, 0x4B04D447, 0xD20D85FD, 0xA50AB56B, 0x35B5A8FA, 0x42B2986C, + 0xDBBBC9D6, 0xACBCF940, 0x32D86CE3, 0x45DF5C75, 0xDCD60DCF, 0xABD13D59, + 0x26D930AC, 0x51DE003A, 0xC8D75180, 0xBFD06116, 0x21B4F4B5, 0x56B3C423, + 0xCFBA9599, 0xB8BDA50F, 0x2802B89E, 0x5F058808, 0xC60CD9B2, 0xB10BE924, + 0x2F6F7C87, 0x58684C11, 0xC1611DAB, 0xB6662D3D, 0x76DC4190, 0x01DB7106, + 0x98D220BC, 0xEFD5102A, 0x71B18589, 0x06B6B51F, 0x9FBFE4A5, 0xE8B8D433, + 0x7807C9A2, 0x0F00F934, 0x9609A88E, 0xE10E9818, 0x7F6A0DBB, 0x086D3D2D, + 0x91646C97, 0xE6635C01, 0x6B6B51F4, 0x1C6C6162, 0x856530D8, 0xF262004E, + 0x6C0695ED, 0x1B01A57B, 0x8208F4C1, 0xF50FC457, 0x65B0D9C6, 0x12B7E950, + 0x8BBEB8EA, 0xFCB9887C, 0x62DD1DDF, 0x15DA2D49, 0x8CD37CF3, 0xFBD44C65, + 0x4DB26158, 0x3AB551CE, 0xA3BC0074, 0xD4BB30E2, 0x4ADFA541, 0x3DD895D7, + 0xA4D1C46D, 0xD3D6F4FB, 0x4369E96A, 0x346ED9FC, 0xAD678846, 0xDA60B8D0, + 0x44042D73, 0x33031DE5, 0xAA0A4C5F, 0xDD0D7CC9, 0x5005713C, 0x270241AA, + 0xBE0B1010, 0xC90C2086, 0x5768B525, 0x206F85B3, 0xB966D409, 0xCE61E49F, + 0x5EDEF90E, 0x29D9C998, 0xB0D09822, 0xC7D7A8B4, 0x59B33D17, 0x2EB40D81, + 0xB7BD5C3B, 0xC0BA6CAD, 0xEDB88320, 0x9ABFB3B6, 0x03B6E20C, 0x74B1D29A, + 0xEAD54739, 0x9DD277AF, 0x04DB2615, 0x73DC1683, 0xE3630B12, 0x94643B84, + 0x0D6D6A3E, 0x7A6A5AA8, 0xE40ECF0B, 0x9309FF9D, 0x0A00AE27, 0x7D079EB1, + 0xF00F9344, 0x8708A3D2, 0x1E01F268, 0x6906C2FE, 0xF762575D, 0x806567CB, + 0x196C3671, 0x6E6B06E7, 0xFED41B76, 0x89D32BE0, 0x10DA7A5A, 0x67DD4ACC, + 0xF9B9DF6F, 0x8EBEEFF9, 0x17B7BE43, 0x60B08ED5, 0xD6D6A3E8, 0xA1D1937E, + 0x38D8C2C4, 0x4FDFF252, 0xD1BB67F1, 0xA6BC5767, 0x3FB506DD, 0x48B2364B, + 0xD80D2BDA, 0xAF0A1B4C, 0x36034AF6, 0x41047A60, 0xDF60EFC3, 0xA867DF55, + 0x316E8EEF, 0x4669BE79, 0xCB61B38C, 0xBC66831A, 0x256FD2A0, 0x5268E236, + 0xCC0C7795, 0xBB0B4703, 0x220216B9, 0x5505262F, 0xC5BA3BBE, 0xB2BD0B28, + 0x2BB45A92, 0x5CB36A04, 0xC2D7FFA7, 0xB5D0CF31, 0x2CD99E8B, 0x5BDEAE1D, + 0x9B64C2B0, 0xEC63F226, 0x756AA39C, 0x026D930A, 0x9C0906A9, 0xEB0E363F, + 0x72076785, 0x05005713, 0x95BF4A82, 0xE2B87A14, 0x7BB12BAE, 0x0CB61B38, + 0x92D28E9B, 0xE5D5BE0D, 0x7CDCEFB7, 0x0BDBDF21, 0x86D3D2D4, 0xF1D4E242, + 0x68DDB3F8, 0x1FDA836E, 0x81BE16CD, 0xF6B9265B, 0x6FB077E1, 0x18B74777, + 0x88085AE6, 0xFF0F6A70, 0x66063BCA, 0x11010B5C, 0x8F659EFF, 0xF862AE69, + 0x616BFFD3, 0x166CCF45, 0xA00AE278, 0xD70DD2EE, 0x4E048354, 0x3903B3C2, + 0xA7672661, 0xD06016F7, 0x4969474D, 0x3E6E77DB, 0xAED16A4A, 0xD9D65ADC, + 0x40DF0B66, 0x37D83BF0, 0xA9BCAE53, 0xDEBB9EC5, 0x47B2CF7F, 0x30B5FFE9, + 0xBDBDF21C, 0xCABAC28A, 0x53B39330, 0x24B4A3A6, 0xBAD03605, 0xCDD70693, + 0x54DE5729, 0x23D967BF, 0xB3667A2E, 0xC4614AB8, 0x5D681B02, 0x2A6F2B94, + 0xB40BBE37, 0xC30C8EA1, 0x5A05DF1B, 0x2D02EF8D }); + +local setmetatable = setmetatable +local tonumber = tonumber +local tostring = tostring +local type = type + +local brshift = bit.rshift +local bxor = bit.bxor +local band = bit.band + +local new_tab +do + local ok + ok, new_tab = pcall(require, "table.new") + if not ok then + new_tab = function(narr, nrec) return {} end + end +end + +-- queue data types +-- +-- this queue is a double-ended queue and the first node +-- is reserved for the queue itself. +-- the implementation is mostly borrowed from nginx's ngx_queue_t data +-- structure. + +ffi.cdef[[ + /* A lrucache_pureffi_queue_s node hook together three data structures: + * o. the key/value store as embodied by the "id" (which is in essence the + * indentifier of key/pair pair) and the "expire" (which is a metadata + * of the corresponding key/pair pair). + * o. The LRU queue via the prev/next fields. + * o. The hash-tabble as embodied by the "conflict" field. + */ + typedef struct lrucache_pureffi_queue_s lrucache_pureffi_queue_t; + struct lrucache_pureffi_queue_s { + /* Each node is assigned a unique ID at construction time, and the + * ID remain immutatble, regardless the node is in active-list or + * free-list. The queue header is assigned ID 0. Since queue-header + * is a sentinel node, 0 denodes "invalid ID". + * + * Intuitively, we can view the "id" as the identifier of key/value + * pair. + */ + int id; + + /* The bucket of the hash-table is implemented as a singly-linked list. + * The "conflict" refers to the ID of the next node in the bucket. + */ + int conflict; + + uint32_t user_flags; + + double expire; /* in seconds */ + + lrucache_pureffi_queue_t *prev; + lrucache_pureffi_queue_t *next; + }; +]] + +local queue_arr_type = ffi.typeof("lrucache_pureffi_queue_t[?]") +--local queue_ptr_type = ffi.typeof("lrucache_pureffi_queue_t*") +local queue_type = ffi.typeof("lrucache_pureffi_queue_t") +local NULL = ffi.null + + +--======================================================================== +-- +-- Queue utility functions +-- +--======================================================================== + +-- Append the element "x" to the given queue "h". +local function queue_insert_tail(h, x) + local last = h[0].prev + x.prev = last + last.next = x + x.next = h + h[0].prev = x +end + + +--[[ +Allocate a queue with size + 1 elements. Elements are linked together in a +circular way, i.e. the last element's "next" points to the first element, +while the first element's "prev" element points to the last element. +]] +local function queue_init(size) + if not size then + size = 0 + end + local q = ffi_new(queue_arr_type, size + 1) + ffi_fill(q, ffi_sizeof(queue_type, size + 1), 0) + + if size == 0 then + q[0].prev = q + q[0].next = q + + else + local prev = q[0] + for i = 1, size do + local e = q[i] + e.id = i + e.user_flags = 0 + prev.next = e + e.prev = prev + prev = e + end + + local last = q[size] + last.next = q + q[0].prev = last + end + + return q +end + + +local function queue_is_empty(q) + -- print("q: ", tostring(q), "q.prev: ", tostring(q), ": ", q == q.prev) + return q == q[0].prev +end + + +local function queue_remove(x) + local prev = x.prev + local next = x.next + + next.prev = prev + prev.next = next + + -- for debugging purpose only: + x.prev = NULL + x.next = NULL +end + + +-- Insert the element "x" the to the given queue "h" +local function queue_insert_head(h, x) + x.next = h[0].next + x.next.prev = x + x.prev = h + h[0].next = x +end + + +local function queue_last(h) + return h[0].prev +end + + +local function queue_head(h) + return h[0].next +end + + +--======================================================================== +-- +-- Miscellaneous Utility Functions +-- +--======================================================================== + +local function ptr2num(ptr) + return tonumber(ffi_cast(uintptr_t, ptr)) +end + + +local function crc32_ptr(ptr) + local p = brshift(ptr2num(ptr), 3) + local b = band(p, 255) + local crc32 = crc_tab[b] + + b = band(brshift(p, 8), 255) + crc32 = bxor(brshift(crc32, 8), crc_tab[band(bxor(crc32, b), 255)]) + + b = band(brshift(p, 16), 255) + crc32 = bxor(brshift(crc32, 8), crc_tab[band(bxor(crc32, b), 255)]) + + --b = band(brshift(p, 24), 255) + --crc32 = bxor(brshift(crc32, 8), crc_tab[band(bxor(crc32, b), 255)]) + return crc32 +end + + +--======================================================================== +-- +-- Implementation of "export" functions +-- +--======================================================================== + +local _M = { + _VERSION = '0.11' +} +local mt = { __index = _M } + + +-- "size" specifies the maximum number of entries in the LRU queue, and the +-- "load_factor" designates the 'load factor' of the hash-table we are using +-- internally. The default value of load-factor is 0.5 (i.e. 50%); if the +-- load-factor is specified, it will be clamped to the range of [0.1, 1](i.e. +-- if load-factor is greater than 1, it will be saturated to 1, likewise, +-- if load-factor is smaller than 0.1, it will be clamped to 0.1). +function _M.new(size, load_factor) + if size < 1 then + return nil, "size too small" + end + + -- Determine bucket size, which must be power of two. + local load_f = load_factor + if not load_factor then + load_f = 0.5 + elseif load_factor > 1 then + load_f = 1 + elseif load_factor < 0.1 then + load_f = 0.1 + end + + local bs_min = size / load_f + -- The bucket_sz *MUST* be a power-of-two. See the hash_string(). + local bucket_sz = 1 + repeat + bucket_sz = bucket_sz * 2 + until bucket_sz >= bs_min + + local self = { + size = size, + bucket_sz = bucket_sz, + free_queue = queue_init(size), + cache_queue = queue_init(0), + node_v = nil, + key_v = new_tab(size, 0), + val_v = new_tab(size, 0), + bucket_v = ffi_new(int_array_t, bucket_sz), + num_items = 0, + } + -- "node_v" is an array of all the nodes used in the LRU queue. Exprpession + -- node_v[i] evaluates to the element of ID "i". + self.node_v = self.free_queue + + -- Allocate the array-part of the key_v, val_v, bucket_v. + --local key_v = self.key_v + --local val_v = self.val_v + --local bucket_v = self.bucket_v + ffi_fill(self.bucket_v, ffi_sizeof(int_t, bucket_sz), 0) + + return setmetatable(self, mt) +end + + +function _M.count(self) + return self.num_items +end + + +function _M.capacity(self) + return self.size +end + + +local function hash_string(self, str) + local c_str = ffi_cast(c_str_t, str) + + local hv = crc32_ptr(c_str) + hv = band(hv, self.bucket_sz - 1) + -- Hint: bucket is 0-based + return hv +end + + +-- Search the node associated with the key in the bucket, if found returns +-- the the id of the node, and the id of its previous node in the conflict list. +-- The "bucket_hdr_id" is the ID of the first node in the bucket +local function _find_node_in_bucket(key, key_v, node_v, bucket_hdr_id) + if bucket_hdr_id ~= 0 then + local prev = 0 + local cur = bucket_hdr_id + + while cur ~= 0 and key_v[cur] ~= key do + prev = cur + cur = node_v[cur].conflict + end + + if cur ~= 0 then + return cur, prev + end + end +end + + +-- Return the node corresponding to the key/val. +local function find_key(self, key) + local key_hash = hash_string(self, key) + return _find_node_in_bucket(key, self.key_v, self.node_v, + self.bucket_v[key_hash]) +end + + +--[[ This function tries to + 1. Remove the given key and the associated value from the key/value store, + 2. Remove the entry associated with the key from the hash-table. + + NOTE: all queues remain intact. + + If there was a node bound to the key/val, return that node; otherwise, + nil is returned. +]] +local function remove_key(self, key) + local key_v = self.key_v + local val_v = self.val_v + local node_v = self.node_v + local bucket_v = self.bucket_v + + local key_hash = hash_string(self, key) + local cur, prev = + _find_node_in_bucket(key, key_v, node_v, bucket_v[key_hash]) + + if cur then + -- In an attempt to make key and val dead. + key_v[cur] = nil + val_v[cur] = nil + self.num_items = self.num_items - 1 + + -- Remove the node from the hash table + local next_node = node_v[cur].conflict + if prev ~= 0 then + node_v[prev].conflict = next_node + else + bucket_v[key_hash] = next_node + end + node_v[cur].conflict = 0 + + return cur + end +end + + +--[[ Bind the key/val with the given node, and insert the node into the Hashtab. + NOTE: this function does not touch any queue +]] +local function insert_key(self, key, val, node) + -- Bind the key/val with the node + local node_id = node.id + self.key_v[node_id] = key + self.val_v[node_id] = val + + -- Insert the node into the hash-table + local key_hash = hash_string(self, key) + local bucket_v = self.bucket_v + node.conflict = bucket_v[key_hash] + bucket_v[key_hash] = node_id + self.num_items = self.num_items + 1 +end + + +function _M.get(self, key) + if type(key) ~= "string" then + key = tostring(key) + end + + local node_id = find_key(self, key) + if not node_id then + return nil + end + + -- print(key, ": moving node ", tostring(node), " to cache queue head") + local cache_queue = self.cache_queue + local node = self.node_v + node_id + queue_remove(node) + queue_insert_head(cache_queue, node) + + local expire = node.expire + if expire >= 0 and expire < ngx_now() then + -- print("expired: ", node.expire, " > ", ngx_now()) + return nil, self.val_v[node_id], node.user_flags + end + + return self.val_v[node_id], nil, node.user_flags +end + + +function _M.delete(self, key) + if type(key) ~= "string" then + key = tostring(key) + end + + local node_id = remove_key(self, key); + if not node_id then + return false + end + + local node = self.node_v + node_id + queue_remove(node) + queue_insert_tail(self.free_queue, node) + return true +end + + +function _M.set(self, key, value, ttl, flags) + if type(key) ~= "string" then + key = tostring(key) + end + + local node_id = find_key(self, key) + local node + if not node_id then + local free_queue = self.free_queue + if queue_is_empty(free_queue) then + -- evict the least recently used key + -- assert(not queue_is_empty(self.cache_queue)) + node = queue_last(self.cache_queue) + remove_key(self, self.key_v[node.id]) + else + -- take a free queue node + node = queue_head(free_queue) + -- print(key, ": get a new free node: ", tostring(node)) + end + + -- insert the key + insert_key(self, key, value, node) + else + node = self.node_v + node_id + self.val_v[node_id] = value + end + + queue_remove(node) + queue_insert_head(self.cache_queue, node) + + if ttl then + node.expire = ngx_now() + ttl + else + node.expire = -1 + end + + if type(flags) == "number" and flags >= 0 then + node.user_flags = flags + + else + node.user_flags = 0 + end +end + + +function _M.get_keys(self, max_count, res) + if not max_count or max_count == 0 then + max_count = self.num_items + end + + if not res then + res = new_tab(max_count + 1, 0) -- + 1 for trailing hole + end + + local cache_queue = self.cache_queue + local key_v = self.key_v + + local i = 0 + local node = queue_head(cache_queue) + + while node ~= cache_queue do + if i >= max_count then + break + end + + i = i + 1 + res[i] = key_v[node.id] + node = node.next + end + + res[i + 1] = nil + + return res +end + + +function _M.flush_all(self) + local cache_queue = self.cache_queue + local key_v = self.key_v + + local node = queue_head(cache_queue) + + while node ~= cache_queue do + local key = key_v[node.id] + node = node.next + _M.delete(self, key) + end +end + + +return _M diff --git a/resty/md5.lua b/resty/md5.lua new file mode 100644 index 0000000..d01e350 --- /dev/null +++ b/resty/md5.lua @@ -0,0 +1,72 @@ +-- Copyright (C) by Yichun Zhang (agentzh) + + +local ffi = require "ffi" +local ffi_new = ffi.new +local ffi_str = ffi.string +local C = ffi.C +local setmetatable = setmetatable +--local error = error + + +local _M = { _VERSION = '0.14' } + +local mt = { __index = _M } + + +ffi.cdef[[ +typedef unsigned long MD5_LONG ; + +enum { + MD5_CBLOCK = 64, + MD5_LBLOCK = MD5_CBLOCK/4 +}; + +typedef struct MD5state_st + { + MD5_LONG A,B,C,D; + MD5_LONG Nl,Nh; + MD5_LONG data[MD5_LBLOCK]; + unsigned int num; + } MD5_CTX; + +int MD5_Init(MD5_CTX *c); +int MD5_Update(MD5_CTX *c, const void *data, size_t len); +int MD5_Final(unsigned char *md, MD5_CTX *c); +]] + +local buf = ffi_new("char[16]") +local ctx_ptr_type = ffi.typeof("MD5_CTX[1]") + + +function _M.new(self) + local ctx = ffi_new(ctx_ptr_type) + if C.MD5_Init(ctx) == 0 then + return nil + end + + return setmetatable({ _ctx = ctx }, mt) +end + + +function _M.update(self, s) + return C.MD5_Update(self._ctx, s, #s) == 1 +end + + +function _M.final(self) + if C.MD5_Final(buf, self._ctx) == 1 then + return ffi_str(buf, 16) + end + + return nil +end + + +function _M.reset(self) + return C.MD5_Init(self._ctx) == 1 +end + + +return _M + diff --git a/resty/memcached.lua b/resty/memcached.lua new file mode 100644 index 0000000..5cf384e --- /dev/null +++ b/resty/memcached.lua @@ -0,0 +1,744 @@ +-- Copyright (C) Yichun Zhang (agentzh), CloudFlare Inc. + + +local escape_uri = ngx.escape_uri +local unescape_uri = ngx.unescape_uri +local match = string.match +local tcp = ngx.socket.tcp +local strlen = string.len +local concat = table.concat +local setmetatable = setmetatable +local type = type + + +local _M = { + _VERSION = '0.16' +} + + +local mt = { __index = _M } + + +function _M.new(self, opts) + local sock, err = tcp() + if not sock then + return nil, err + end + + local escape_key = escape_uri + local unescape_key = unescape_uri + + if opts then + local key_transform = opts.key_transform + + if key_transform then + escape_key = key_transform[1] + unescape_key = key_transform[2] + if not escape_key or not unescape_key then + return nil, "expecting key_transform = { escape, unescape } table" + end + end + end + + return setmetatable({ + sock = sock, + escape_key = escape_key, + unescape_key = unescape_key, + }, mt) +end + + +local function set_timeouts(self, connect, send, read) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + sock:settimeouts(connect, send, read) + return 1 +end +_M.set_timeouts = set_timeouts + +function _M.set_timeout(self, timeout) + return set_timeouts(self, timeout, timeout, timeout) +end + + +function _M.connect(self, ...) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + return sock:connect(...) +end + + +local function _multi_get(self, keys) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + local nkeys = #keys + + if nkeys == 0 then + return {}, nil + end + + local escape_key = self.escape_key + local cmd = {"get"} + local n = 1 + + for i = 1, nkeys do + cmd[n + 1] = " " + cmd[n + 2] = escape_key(keys[i]) + n = n + 2 + end + cmd[n + 1] = "\r\n" + + -- print("multi get cmd: ", cmd) + + local bytes, err = sock:send(cmd) + if not bytes then + return nil, err + end + + local unescape_key = self.unescape_key + local results = {} + + while true do + local line, err = sock:receive() + if not line then + if err == "timeout" then + sock:close() + end + return nil, err + end + + if line == 'END' then + break + end + + local key, flags, len = match(line, '^VALUE (%S+) (%d+) (%d+)$') + -- print("key: ", key, "len: ", len, ", flags: ", flags) + + if not key then + return nil, line + end + + local data, err = sock:receive(len) + if not data then + if err == "timeout" then + sock:close() + end + return nil, err + end + + results[unescape_key(key)] = {data, flags} + + data, err = sock:receive(2) -- discard the trailing CRLF + if not data then + if err == "timeout" then + sock:close() + end + return nil, err + end + end + + return results +end + + +function _M.get(self, key) + if type(key) == "table" then + return _multi_get(self, key) + end + + local sock = self.sock + if not sock then + return nil, nil, "not initialized" + end + + local bytes, err = sock:send("get " .. self.escape_key(key) .. "\r\n") + if not bytes then + return nil, nil, err + end + + local line, err = sock:receive() + if not line then + if err == "timeout" then + sock:close() + end + return nil, nil, err + end + + if line == 'END' then + return nil, nil, nil + end + + local flags, len = match(line, '^VALUE %S+ (%d+) (%d+)$') + if not flags then + return nil, nil, line + end + + -- print("len: ", len, ", flags: ", flags) + + local data, err = sock:receive(len) + if not data then + if err == "timeout" then + sock:close() + end + return nil, nil, err + end + + line, err = sock:receive(7) -- discard the trailing "\r\nEND\r\n" + if not line then + if err == "timeout" then + sock:close() + end + return nil, nil, err + end + + return data, flags +end + + +local function _multi_gets(self, keys) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + local nkeys = #keys + + if nkeys == 0 then + return {}, nil + end + + local escape_key = self.escape_key + local cmd = {"gets"} + local n = 1 + for i = 1, nkeys do + cmd[n + 1] = " " + cmd[n + 2] = escape_key(keys[i]) + n = n + 2 + end + cmd[n + 1] = "\r\n" + + -- print("multi get cmd: ", cmd) + + local bytes, err = sock:send(cmd) + if not bytes then + return nil, err + end + + local unescape_key = self.unescape_key + local results = {} + + while true do + local line, err = sock:receive() + if not line then + if err == "timeout" then + sock:close() + end + return nil, err + end + + if line == 'END' then + break + end + + local key, flags, len, cas_uniq = + match(line, '^VALUE (%S+) (%d+) (%d+) (%d+)$') + + -- print("key: ", key, "len: ", len, ", flags: ", flags) + + if not key then + return nil, line + end + + local data, err = sock:receive(len) + if not data then + if err == "timeout" then + sock:close() + end + return nil, err + end + + results[unescape_key(key)] = {data, flags, cas_uniq} + + data, err = sock:receive(2) -- discard the trailing CRLF + if not data then + if err == "timeout" then + sock:close() + end + return nil, err + end + end + + return results +end + + +function _M.gets(self, key) + if type(key) == "table" then + return _multi_gets(self, key) + end + + local sock = self.sock + if not sock then + return nil, nil, nil, "not initialized" + end + + local bytes, err = sock:send("gets " .. self.escape_key(key) .. "\r\n") + if not bytes then + return nil, nil, nil, err + end + + local line, err = sock:receive() + if not line then + if err == "timeout" then + sock:close() + end + return nil, nil, nil, err + end + + if line == 'END' then + return nil, nil, nil, nil + end + + local flags, len, cas_uniq = match(line, '^VALUE %S+ (%d+) (%d+) (%d+)$') + if not flags then + return nil, nil, nil, line + end + + -- print("len: ", len, ", flags: ", flags) + + local data, err = sock:receive(len) + if not data then + if err == "timeout" then + sock:close() + end + return nil, nil, nil, err + end + + line, err = sock:receive(7) -- discard the trailing "\r\nEND\r\n" + if not line then + if err == "timeout" then + sock:close() + end + return nil, nil, nil, err + end + + return data, flags, cas_uniq +end + + +local function _expand_table(value) + local segs = {} + local nelems = #value + local nsegs = 0 + for i = 1, nelems do + local seg = value[i] + nsegs = nsegs + 1 + if type(seg) == "table" then + segs[nsegs] = _expand_table(seg) + else + segs[nsegs] = seg + end + end + return concat(segs) +end + + +local function _store(self, cmd, key, value, exptime, flags) + if not exptime then + exptime = 0 + end + + if not flags then + flags = 0 + end + + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + if type(value) == "table" then + value = _expand_table(value) + end + + local req = cmd .. " " .. self.escape_key(key) .. " " .. flags .. " " + .. exptime .. " " .. strlen(value) .. "\r\n" .. value + .. "\r\n" + local bytes, err = sock:send(req) + if not bytes then + return nil, err + end + + local data, err = sock:receive() + if not data then + if err == "timeout" then + sock:close() + end + return nil, err + end + + if data == "STORED" then + return 1 + end + + return nil, data +end + + +function _M.set(self, ...) + return _store(self, "set", ...) +end + + +function _M.add(self, ...) + return _store(self, "add", ...) +end + + +function _M.replace(self, ...) + return _store(self, "replace", ...) +end + + +function _M.append(self, ...) + return _store(self, "append", ...) +end + + +function _M.prepend(self, ...) + return _store(self, "prepend", ...) +end + + +function _M.cas(self, key, value, cas_uniq, exptime, flags) + if not exptime then + exptime = 0 + end + + if not flags then + flags = 0 + end + + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + local req = "cas " .. self.escape_key(key) .. " " .. flags .. " " + .. exptime .. " " .. strlen(value) .. " " .. cas_uniq + .. "\r\n" .. value .. "\r\n" + + -- local cjson = require "cjson" + -- print("request: ", cjson.encode(req)) + + local bytes, err = sock:send(req) + if not bytes then + return nil, err + end + + local line, err = sock:receive() + if not line then + if err == "timeout" then + sock:close() + end + return nil, err + end + + -- print("response: [", line, "]") + + if line == "STORED" then + return 1 + end + + return nil, line +end + + +function _M.delete(self, key) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + key = self.escape_key(key) + + local req = "delete " .. key .. "\r\n" + + local bytes, err = sock:send(req) + if not bytes then + return nil, err + end + + local res, err = sock:receive() + if not res then + if err == "timeout" then + sock:close() + end + return nil, err + end + + if res ~= 'DELETED' then + return nil, res + end + + return 1 +end + + +function _M.set_keepalive(self, ...) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + return sock:setkeepalive(...) +end + + +function _M.get_reused_times(self) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + return sock:getreusedtimes() +end + + +function _M.flush_all(self, time) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + local req + if time then + req = "flush_all " .. time .. "\r\n" + else + req = "flush_all\r\n" + end + + local bytes, err = sock:send(req) + if not bytes then + return nil, err + end + + local res, err = sock:receive() + if not res then + if err == "timeout" then + sock:close() + end + return nil, err + end + + if res ~= 'OK' then + return nil, res + end + + return 1 +end + + +local function _incr_decr(self, cmd, key, value) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + local req = cmd .. " " .. self.escape_key(key) .. " " .. value .. "\r\n" + + local bytes, err = sock:send(req) + if not bytes then + return nil, err + end + + local line, err = sock:receive() + if not line then + if err == "timeout" then + sock:close() + end + return nil, err + end + + if not match(line, '^%d+$') then + return nil, line + end + + return line +end + + +function _M.incr(self, key, value) + return _incr_decr(self, "incr", key, value) +end + + +function _M.decr(self, key, value) + return _incr_decr(self, "decr", key, value) +end + + +function _M.stats(self, args) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + local req + if args then + req = "stats " .. args .. "\r\n" + else + req = "stats\r\n" + end + + local bytes, err = sock:send(req) + if not bytes then + return nil, err + end + + local lines = {} + local n = 0 + while true do + local line, err = sock:receive() + if not line then + if err == "timeout" then + sock:close() + end + return nil, err + end + + if line == 'END' then + return lines, nil + end + + if not match(line, "ERROR") then + n = n + 1 + lines[n] = line + else + return nil, line + end + end + + -- cannot reach here... + return lines +end + + +function _M.version(self) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + local bytes, err = sock:send("version\r\n") + if not bytes then + return nil, err + end + + local line, err = sock:receive() + if not line then + if err == "timeout" then + sock:close() + end + return nil, err + end + + local ver = match(line, "^VERSION (.+)$") + if not ver then + return nil, ver + end + + return ver +end + + +function _M.quit(self) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + local bytes, err = sock:send("quit\r\n") + if not bytes then + return nil, err + end + + return 1 +end + + +function _M.verbosity(self, level) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + local bytes, err = sock:send("verbosity " .. level .. "\r\n") + if not bytes then + return nil, err + end + + local line, err = sock:receive() + if not line then + if err == "timeout" then + sock:close() + end + return nil, err + end + + if line ~= 'OK' then + return nil, line + end + + return 1 +end + + +function _M.touch(self, key, exptime) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + local bytes, err = sock:send("touch " .. self.escape_key(key) .. " " + .. exptime .. "\r\n") + if not bytes then + return nil, err + end + + local line, err = sock:receive() + if not line then + if err == "timeout" then + sock:close() + end + return nil, err + end + + -- moxi server from couchbase returned stored after touching + if line == "TOUCHED" or line =="STORED" then + return 1 + end + return nil, line +end + + +function _M.close(self) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + return sock:close() +end + + +return _M diff --git a/resty/mysql.lua b/resty/mysql.lua new file mode 100644 index 0000000..3b97e88 --- /dev/null +++ b/resty/mysql.lua @@ -0,0 +1,1410 @@ +-- Copyright (C) Yichun Zhang (agentzh) + + +local bit = require "bit" +local resty_sha256 = require "resty.sha256" +local sub = string.sub +local tcp = ngx.socket.tcp +local strbyte = string.byte +local strchar = string.char +local strfind = string.find +local format = string.format +local strrep = string.rep +local null = ngx.null +local band = bit.band +local bxor = bit.bxor +local bor = bit.bor +local lshift = bit.lshift +local rshift = bit.rshift +local tohex = bit.tohex +local sha1 = ngx.sha1_bin +local concat = table.concat +local setmetatable = setmetatable +local error = error +local tonumber = tonumber +local to_int = math.floor + +local has_rsa, resty_rsa = pcall(require, "resty.rsa") + + +if not ngx.config then + error("ngx_lua 0.9.11+ or ngx_stream_lua required") +end + +if (not ngx.config.subsystem + or ngx.config.subsystem == "http") -- subsystem is http + and (not ngx.config.ngx_lua_version + or ngx.config.ngx_lua_version < 9011) -- old version +then + error("ngx_lua 0.9.11+ required") +end + + +local ok, new_tab = pcall(require, "table.new") +if not ok then + new_tab = function (narr, nrec) return {} end +end + + +local _M = { _VERSION = '0.24' } + + +-- constants + +local STATE_CONNECTED = 1 +local STATE_COMMAND_SENT = 2 + +local COM_QUIT = 0x01 +local COM_QUERY = 0x03 + +-- refer to https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags +-- CLIENT_LONG_PASSWORD | CLIENT_FOUND_ROWS | CLIENT_LONG_FLAG +-- | CLIENT_CONNECT_WITH_DB | CLIENT_ODBC | CLIENT_LOCAL_FILES +-- | CLIENT_IGNORE_SPACE | CLIENT_PROTOCOL_41 | CLIENT_INTERACTIVE +-- | CLIENT_IGNORE_SIGPIPE | CLIENT_TRANSACTIONS | CLIENT_RESERVED +-- | CLIENT_SECURE_CONNECTION | CLIENT_MULTI_STATEMENTS | CLIENT_MULTI_RESULTS +local DEFAULT_CLIENT_FLAGS = 0x3f7cf +local CLIENT_SSL = 0x00000800 +local CLIENT_PLUGIN_AUTH = 0x00080000 +local CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA = 0x00200000 + +local SERVER_MORE_RESULTS_EXISTS = 8 + +local RESP_OK = "OK" +local RESP_AUTHMOREDATA = "AUTHMOREDATA" +local RESP_LOCALINFILE = "LOCALINFILE" +local RESP_EOF = "EOF" +local RESP_ERR = "ERR" +local RESP_DATA = "DATA" + +local MY_RND_MAX_VAL = 0x3FFFFFFF +local MIN_PROTOCOL_VER = 10 + +local LEN_NATIVE_SCRAMBLE = 20 +local LEN_OLD_SCRAMBLE = 8 + +-- 16MB - 1, the default max allowed packet size used by libmysqlclient +local FULL_PACKET_SIZE = 16777215 + +-- the following charset map is generated from the following mysql query: +-- SELECT CHARACTER_SET_NAME, ID +-- FROM information_schema.collations +-- WHERE IS_DEFAULT = 'Yes' ORDER BY id; +local CHARSET_MAP = { + _default = 0, + big5 = 1, + dec8 = 3, + cp850 = 4, + hp8 = 6, + koi8r = 7, + latin1 = 8, + latin2 = 9, + swe7 = 10, + ascii = 11, + ujis = 12, + sjis = 13, + hebrew = 16, + tis620 = 18, + euckr = 19, + koi8u = 22, + gb2312 = 24, + greek = 25, + cp1250 = 26, + gbk = 28, + latin5 = 30, + armscii8 = 32, + utf8 = 33, + ucs2 = 35, + cp866 = 36, + keybcs2 = 37, + macce = 38, + macroman = 39, + cp852 = 40, + latin7 = 41, + utf8mb4 = 45, + cp1251 = 51, + utf16 = 54, + utf16le = 56, + cp1256 = 57, + cp1257 = 59, + utf32 = 60, + binary = 63, + geostd8 = 92, + cp932 = 95, + eucjpms = 97, + gb18030 = 248 +} + +local mt = { __index = _M } + + +-- mysql field value type converters +local converters = new_tab(0, 9) + +for i = 0x01, 0x05 do + -- tiny, short, long, float, double + converters[i] = tonumber +end +converters[0x00] = tonumber -- decimal +-- converters[0x08] = tonumber -- long long +converters[0x09] = tonumber -- int24 +converters[0x0d] = tonumber -- year +converters[0xf6] = tonumber -- newdecimal + + +local function _get_byte2(data, i) + local a, b = strbyte(data, i, i + 1) + return bor(a, lshift(b, 8)), i + 2 +end + + +local function _get_byte3(data, i) + local a, b, c = strbyte(data, i, i + 2) + return bor(a, lshift(b, 8), lshift(c, 16)), i + 3 +end + + +local function _get_byte4(data, i) + local a, b, c, d = strbyte(data, i, i + 3) + return bor(a, lshift(b, 8), lshift(c, 16), lshift(d, 24)), i + 4 +end + + +local function _get_byte8(data, i) + local a, b, c, d, e, f, g, h = strbyte(data, i, i + 7) + + -- XXX workaround for the lack of 64-bit support in bitop: + -- XXX return results in the range of signed 32 bit numbers + local lo = bor(a, lshift(b, 8), lshift(c, 16)) + local hi = bor(e, lshift(f, 8), lshift(g, 16), lshift(h, 24)) + return lo + 16777216 * d + hi * 4294967296, i + 8 + + -- return bor(a, lshift(b, 8), lshift(c, 16), lshift(d, 24), lshift(e, 32), + -- lshift(f, 40), lshift(g, 48), lshift(h, 56)), i + 8 +end + + +local function _set_byte2(n) + return strchar(band(n, 0xff), band(rshift(n, 8), 0xff)) +end + + +local function _set_byte3(n) + return strchar(band(n, 0xff), + band(rshift(n, 8), 0xff), + band(rshift(n, 16), 0xff)) +end + + +local function _set_byte4(n) + return strchar(band(n, 0xff), + band(rshift(n, 8), 0xff), + band(rshift(n, 16), 0xff), + band(rshift(n, 24), 0xff)) +end + + +local function _from_cstring(data, i) + local last = strfind(data, "\0", i, true) + if not last then + return nil, nil + end + + return sub(data, i, last - 1), last + 1 +end + + +local function _to_cstring(data) + return data .. "\0" +end + + +local function _dump(data) + local len = #data + local bytes = new_tab(len, 0) + for i = 1, len do + bytes[i] = format("%x", strbyte(data, i)) + end + return concat(bytes, " ") +end + + +local function _dumphex(data) + local len = #data + local bytes = new_tab(len, 0) + for i = 1, len do + bytes[i] = tohex(strbyte(data, i), 2) + end + return concat(bytes, " ") +end + + +local function _pwd_hash(password) + local add = 7 + + local hash1 = 1345345333 + local hash2 = 0x12345671 + + local len = #password + for i = 1, len do + -- skip spaces and tabs in password + local byte = strbyte(password, i) + if byte ~= 32 and byte ~= 9 then -- not ' ' or '\t' + hash1 = bxor(hash1, (band(hash1, 63) + add) * byte + + lshift(hash1, 8)) + + hash2 = bxor(lshift(hash2, 8), hash1) + hash2 + + add = add + byte + end + end + + -- remove sign bit (1<<31)-1) + return band(hash1, 0x7FFFFFFF), band(hash2, 0x7FFFFFFF) +end + + +local function _random_byte(seed1, seed2) + seed1 = (seed1 * 3 + seed2) % MY_RND_MAX_VAL + seed2 = (seed1 + seed2 + 33) % MY_RND_MAX_VAL + + return to_int(seed1 * 31 / MY_RND_MAX_VAL), seed1, seed2 +end + + +local function _compute_old_token(password, scramble) + if password == "" then + return "" + end + + scramble = sub(scramble, 1, LEN_OLD_SCRAMBLE) + + local hash_pw1, hash_pw2 = _pwd_hash(password) + local hash_sc1, hash_sc2 = _pwd_hash(scramble) + + local seed1 = bxor(hash_pw1, hash_sc1) % MY_RND_MAX_VAL + local seed2 = bxor(hash_pw2, hash_sc2) % MY_RND_MAX_VAL + local rand_byte + + local bytes = new_tab(LEN_OLD_SCRAMBLE, 0) + for i = 1, LEN_OLD_SCRAMBLE do + rand_byte, seed1, seed2 = _random_byte(seed1, seed2) + bytes[i] = rand_byte + 64 + end + + rand_byte = _random_byte(seed1, seed2) + for i = 1, LEN_OLD_SCRAMBLE do + bytes[i] = strchar(bxor(bytes[i], rand_byte)) + end + + return _to_cstring(concat(bytes)) +end + + +local function _compute_sha256_token(password, scramble) + if password == "" then + return "" + end + + local sha256 = resty_sha256:new() + if not sha256 then + return nil, "failed to create the sha256 object" + end + + if not sha256:update(password) then + return nil, "failed to update string to sha256" + end + + local message1 = sha256:final() + + sha256:reset() + + if not sha256:update(message1) then + return nil, "failed to update string to sha256" + end + + local message1_hash = sha256:final() + + sha256:reset() + + if not sha256:update(message1_hash) then + return nil, "failed to update string to sha256" + end + + if not sha256:update(scramble) then + return nil, "failed to update string to sha256" + end + + local message2 = sha256:final() + + local n = #message2 + local bytes = new_tab(n, 0) + for i = 1, n do + bytes[i] = strchar(bxor(strbyte(message1, i), strbyte(message2, i))) + end + + return concat(bytes) +end + + +local function _compute_token(password, scramble) + if password == "" then + return "" + end + + scramble = sub(scramble, 1, LEN_NATIVE_SCRAMBLE) + + local stage1 = sha1(password) + local stage2 = sha1(stage1) + local stage3 = sha1(scramble .. stage2) + local n = #stage1 + local bytes = new_tab(n, 0) + for i = 1, n do + bytes[i] = strchar(bxor(strbyte(stage3, i), strbyte(stage1, i))) + end + + return concat(bytes) +end + + +local function _send_packet(self, req, size) + local sock = self.sock + + self.packet_no = self.packet_no + 1 + + -- print("packet no: ", self.packet_no) + + local packet = _set_byte3(size) .. strchar(band(self.packet_no, 255)) .. req + + -- print("sending packet: ", _dump(packet)) + + -- print("sending packet... of size " .. #packet) + + return sock:send(packet) +end + + +local function _recv_packet(self) + local sock = self.sock + + local data, err = sock:receive(4) -- packet header + if not data then + return nil, nil, "failed to receive packet header: " .. err + end + + --print("packet header: ", _dump(data)) + + local len, pos = _get_byte3(data, 1) + + --print("packet length: ", len) + + if len == 0 then + return nil, nil, "empty packet" + end + + if len > self._max_packet_size then + return nil, nil, "packet size too big: " .. len + end + + local num = strbyte(data, pos) + + --print("recv packet: packet no: ", num) + + self.packet_no = num + + data, err = sock:receive(len) + + --print("receive returned") + + if not data then + return nil, nil, "failed to read packet content: " .. err + end + + --print("packet content: ", _dump(data)) + --print("packet content (ascii): ", data) + + local field_count = strbyte(data, 1) + + local typ + if field_count == 0x00 then + typ = RESP_OK + elseif field_count == 0x01 then + typ = RESP_AUTHMOREDATA + elseif field_count == 0xfb then + typ = RESP_LOCALINFILE + elseif field_count == 0xfe then + typ = RESP_EOF + elseif field_count == 0xff then + typ = RESP_ERR + else + typ = RESP_DATA + end + + return data, typ +end + + +local function _from_length_coded_bin(data, pos) + local first = strbyte(data, pos) + + --print("LCB: first: ", first) + + if not first then + return nil, pos + end + + if first >= 0 and first <= 250 then + return first, pos + 1 + end + + if first == 251 then + return null, pos + 1 + end + + if first == 252 then + pos = pos + 1 + return _get_byte2(data, pos) + end + + if first == 253 then + pos = pos + 1 + return _get_byte3(data, pos) + end + + if first == 254 then + pos = pos + 1 + return _get_byte8(data, pos) + end + + return nil, pos + 1 +end + + +local function _from_length_coded_str(data, pos) + local len + len, pos = _from_length_coded_bin(data, pos) + if not len or len == null then + return null, pos + end + + return sub(data, pos, pos + len - 1), pos + len +end + + +local function _parse_ok_packet(packet) + local res = new_tab(0, 5) + local pos + + res.affected_rows, pos = _from_length_coded_bin(packet, 2) + + --print("affected rows: ", res.affected_rows, ", pos:", pos) + + res.insert_id, pos = _from_length_coded_bin(packet, pos) + + --print("insert id: ", res.insert_id, ", pos:", pos) + + res.server_status, pos = _get_byte2(packet, pos) + + --print("server status: ", res.server_status, ", pos:", pos) + + res.warning_count, pos = _get_byte2(packet, pos) + + --print("warning count: ", res.warning_count, ", pos: ", pos) + + local message = _from_length_coded_str(packet, pos) + if message and message ~= null then + res.message = message + end + + --print("message: ", res.message, ", pos:", pos) + + return res +end + + +local function _parse_eof_packet(packet) + local pos = 2 + + local warning_count, pos = _get_byte2(packet, pos) + local status_flags = _get_byte2(packet, pos) + + return warning_count, status_flags +end + + +local function _parse_err_packet(packet) + local errno, pos = _get_byte2(packet, 2) + local marker = sub(packet, pos, pos) + local sqlstate + if marker == '#' then + -- with sqlstate + pos = pos + 1 + sqlstate = sub(packet, pos, pos + 5 - 1) + pos = pos + 5 + end + + local message = sub(packet, pos) + return errno, message, sqlstate +end + + +local function _parse_result_set_header_packet(packet) + local field_count, pos = _from_length_coded_bin(packet, 1) + + local extra + extra = _from_length_coded_bin(packet, pos) + + return field_count, extra +end + + +local function _parse_field_packet(data) + local col = new_tab(0, 2) + local catalog, db, table, orig_table, orig_name, charsetnr, length + local pos + catalog, pos = _from_length_coded_str(data, 1) + + --print("catalog: ", col.catalog, ", pos:", pos) + + db, pos = _from_length_coded_str(data, pos) + table, pos = _from_length_coded_str(data, pos) + orig_table, pos = _from_length_coded_str(data, pos) + col.name, pos = _from_length_coded_str(data, pos) + + orig_name, pos = _from_length_coded_str(data, pos) + + pos = pos + 1 -- ignore the filler + + charsetnr, pos = _get_byte2(data, pos) + + length, pos = _get_byte4(data, pos) + + col.type = strbyte(data, pos) + + --[[ + pos = pos + 1 + + col.flags, pos = _get_byte2(data, pos) + + col.decimals = strbyte(data, pos) + pos = pos + 1 + + local default = sub(data, pos + 2) + if default and default ~= "" then + col.default = default + end + --]] + + return col +end + + +local function _parse_row_data_packet(data, cols, compact) + local pos = 1 + local ncols = #cols + local row + if compact then + row = new_tab(ncols, 0) + else + row = new_tab(0, ncols) + end + for i = 1, ncols do + local value + value, pos = _from_length_coded_str(data, pos) + local col = cols[i] + local typ = col.type + local name = col.name + + --print("row field value: ", value, ", type: ", typ) + + if value ~= null then + local conv = converters[typ] + if conv then + value = conv(value) + end + end + + if compact then + row[i] = value + + else + row[name] = value + end + end + + return row +end + + +local function _recv_field_packet(self) + local packet, typ, err = _recv_packet(self) + if not packet then + return nil, err + end + + if typ == RESP_ERR then + local errno, msg, sqlstate = _parse_err_packet(packet) + return nil, msg, errno, sqlstate + end + + if typ ~= RESP_DATA then + return nil, "bad field packet type: " .. typ + end + + -- typ == RESP_DATA + + return _parse_field_packet(packet) +end + + +-- refer to https://dev.mysql.com/doc/internals/en/connection-phase-packets.html +local function _read_hand_shake_packet(self) + local packet, typ, err = _recv_packet(self) + if not packet then + return nil, nil, err + end + + if typ == RESP_ERR then + local errno, msg, sqlstate = _parse_err_packet(packet) + return nil, nil, msg, errno, sqlstate + end + + local protocol_ver = tonumber(strbyte(packet)) + if not protocol_ver then + return nil, nil, + "bad handshake initialization packet: bad protocol version" + end + + if protocol_ver < MIN_PROTOCOL_VER then + return nil, nil, "unsupported protocol version " .. protocol_ver + .. ", version " .. MIN_PROTOCOL_VER + .. " or higher is required" + end + + self.protocol_ver = protocol_ver + + local server_ver, pos = _from_cstring(packet, 2) + if not server_ver then + return nil, nil, + "bad handshake initialization packet: bad server version" + end + + self._server_ver = server_ver + + local thread_id, pos = _get_byte4(packet, pos) + + local scramble = sub(packet, pos, pos + 8 - 1) + if not scramble then + return nil, nil, "1st part of scramble not found" + end + + pos = pos + 9 -- skip filler(8 + 1) + + -- two lower bytes + local capabilities -- server capabilities + capabilities, pos = _get_byte2(packet, pos) + + self._server_lang = strbyte(packet, pos) + pos = pos + 1 + + self._server_status, pos = _get_byte2(packet, pos) + + local more_capabilities + more_capabilities, pos = _get_byte2(packet, pos) + + self.capabilities = bor(capabilities, lshift(more_capabilities, 16)) + + pos = pos + 11 -- skip length of auth-plugin-data(1) and reserved(10) + + -- follow official Python library uses the fixed length 12 + -- and the 13th byte is "\0 byte + local scramble_part2 = sub(packet, pos, pos + 12 - 1) + if not scramble_part2 then + return nil, nil, "2nd part of scramble not found" + end + + pos = pos + 13 + + local plugin, _ = _from_cstring(packet, pos) + if not plugin then + -- EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2) + -- \NUL otherwise + plugin = sub(packet, pos) + end + + return scramble .. scramble_part2, plugin +end + + +local function _append_auth_length(self, data) + local n = #data + + if n <= 250 then + data = strchar(n) .. data + return data, 1 + n + end + + self.DEFAULT_CLIENT_FLAGS = bor(self.DEFAULT_CLIENT_FLAGS, + CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) + + if n <= 0xffff then + data = strchar(0xfc, band(n, 0xff), band(rshift(n, 8), 0xff)) .. data + return data, 3 + n + end + + if n <= 0xffffff then + data = strchar(0xfd, + band(n, 0xff), + band(rshift(n, 8), 0xff), + band(rshift(n, 16), 0xff)) + .. data + return data, 4 + n + end + + data = strchar(0xfe, + band(n, 0xff), + band(rshift(n, 8), 0xff), + band(rshift(n, 16), 0xff), + band(rshift(n, 24), 0xff), + band(rshift(n, 32), 0xff), + band(rshift(n, 40), 0xff), + band(rshift(n, 48), 0xff), + band(rshift(n, 56), 0xff)) + .. data + return data, 9 + n +end + + +local function _write_hand_shake_response(self, auth_resp, plugin) + local append_auth, len = _append_auth_length(self, auth_resp) + + if self.use_ssl then + if band(self.capabilities, CLIENT_SSL) == 0 then + return "ssl disabled on server" + end + + -- send a SSL Request Packet + local req = _set_byte4(bor(self.DEFAULT_CLIENT_FLAGS, CLIENT_SSL)) + .. _set_byte4(self._max_packet_size) + .. strchar(self.charset) + .. strrep("\0", 23) + + local packet_len = 4 + 4 + 1 + 23 + local bytes, err = _send_packet(self, req, packet_len) + if not bytes then + return "failed to send client authentication packet: " .. err + end + + local sock = self.sock + + local ok, err = sock:sslhandshake(false, nil, self.ssl_verify) + if not ok then + return "failed to do ssl handshake: " .. (err or "") + end + end + + local req = _set_byte4(self.DEFAULT_CLIENT_FLAGS) + .. _set_byte4(self._max_packet_size) + .. strchar(self.charset) + .. strrep("\0", 23) + .. _to_cstring(self.user) + .. append_auth + .. _to_cstring(self.database) + .. _to_cstring(plugin) + + local packet_len = 4 + 4 + 1 + 23 + #self.user + 1 + + len + #self.database + 1 + #plugin + 1 + + local bytes, err = _send_packet(self, req, packet_len) + if not bytes then + return "failed to send client authentication packet: " .. err + end + + return nil +end + + +local function _read_auth_result(self, old_auth_data, plugin) + local packet, typ, err = _recv_packet(self) + if not packet then + return nil, nil, "failed to receive the result packet: " .. err + end + + if typ == RESP_OK then + return RESP_OK, "" + end + + if typ == RESP_AUTHMOREDATA then + return sub(packet, 2), "" + end + + if typ == RESP_EOF then + if #packet == 1 then -- old pre-4.1 authentication protocol + return nil, "mysql_old_password" + end + + local pos + + plugin, pos = _from_cstring(packet, 2) + if not plugin then + return nil, nil, "malformed packet" + end + + return sub(packet, pos), plugin + end + + if typ == RESP_ERR then + local errno, msg, sqlstate = _parse_err_packet(packet) + return errno, sqlstate, msg + end + + return nil, nil, "bad packet type: " .. typ +end + + +local function _read_ok_result(self) + local packet, typ, err = _recv_packet(self) + if not packet then + return "failed to receive the result packet: " .. err + end + + if typ == RESP_ERR then + local errno, msg, sqlstate = _parse_err_packet(packet) + return msg, errno, sqlstate + end + + if typ ~= RESP_OK then + return "bad packet type: " .. typ + end +end + + +local function _encrypt_password(self, auth_data, public_key) + if not has_rsa then + error("auth plugin caching_sha2_password or sha256_password are not" .. + " supported because resty.rsa is not installed", 2) + end + + local password = _to_cstring(self.password) + local n = #password + local l = #auth_data + local bytes = new_tab(n, 0) + + for i = 1, n do + local j = i % l + bytes[i] = strchar(bxor(strbyte(password, i), strbyte(auth_data, j))) + end + + local pub, err = resty_rsa:new({ + public_key = public_key, + key_type = resty_rsa.KEY_TYPE.PKCS8, + padding = resty_rsa.PADDING.RSA_PKCS1_OAEP_PADDING, + algorithm = "sha1", + }) + if not pub then + return nil, "new rsa err: " .. err + end + + local enc, err = pub:encrypt(concat(bytes)) + if not enc then + return nil, "encode password packet: " .. err + end + + return enc +end + + +local function _write_encode_password(self, auth_data, public_key) + local enc, err = _encrypt_password(self, auth_data, public_key) + + local bytes, err = _send_packet(self, enc, #enc) + if not bytes then + return "failed to send encode password packet: " .. err + end +end + + +local function _auth(self, auth_data, plugin) + local password = self.password + + if plugin == "caching_sha2_password" then + local auth_resp, err = _compute_sha256_token(password, auth_data) + if err then + return nil, "failed to compute sha256 token: " .. err + end + + return auth_resp + end + + if plugin == "mysql_old_password" then + return _compute_old_token(password, auth_data) + end + + if plugin == "mysql_clear_password" then + return _to_cstring(password) + end + + if plugin == "mysql_native_password" then + return _compute_token(password, auth_data) + end + + if plugin == "sha256_password" then + if self.is_unix or self.use_ssl or #password == 0 then + return _to_cstring(password) + end + + local public_key = self.public_key + if public_key then + return _encrypt_password(self, auth_data, public_key) + end + + return "\1" -- request public key from server + end + + return nil, "unknown plugin: " .. plugin +end + + +local function _handle_auth_result(self, old_auth_data, plugin) + local auth_data, new_plugin, err = _read_auth_result(self, old_auth_data, + plugin) + + if err ~= nil then + local errno, sqlstate = auth_data, new_plugin + return err, errno, sqlstate + end + + if auth_data == RESP_OK then + return + end + + if new_plugin ~= "" then + if not auth_data then + auth_data = old_auth_data + else + old_auth_data = auth_data + end + + plugin = new_plugin + + local auth_resp, err = _auth(self, auth_data, plugin) + if not auth_resp then + return err + end + + local bytes, err = _send_packet(self, auth_resp, #auth_resp) + if not bytes then + return "failed to send client authentication packet: " .. err + end + + auth_data, new_plugin, err = _read_auth_result(self, old_auth_data, + plugin) + + if err ~= nil then + local errno, sqlstate = auth_data, new_plugin + return err, errno, sqlstate + end + + if auth_data == RESP_OK then + return + end + + if new_plugin ~= "" then + return "malformed packet" + end + end + + if plugin == "caching_sha2_password" then + local len = #auth_data + if len == 0 then + return + end + + if len == 1 then + local status = strbyte(auth_data) + -- caching_sha2_password fast auth success + if status == 3 then + return _read_ok_result(self) + end + + -- caching_sha2_password perform full authentication + if status == 4 then + if self.is_unix or self.use_ssl then + local bytes, err = _send_packet(self, + _to_cstring(self.password), + #self.password + 1) + + if not bytes then + return "failed to send cleartext auth packet: " + .. err + end + + else + local public_key = self.public_key + if not public_key then + -- caching_sha2_password request public_key + local bytes, err = _send_packet(self, "\2", 1) + if not bytes then + return "failed to send password request packet: " + .. err + end + + local packet, _, err = _recv_packet(self) + if not packet then + return "failed to receive the result packet: " + .. err + end + + public_key = sub(packet, 2) + end + + err = _write_encode_password(self, old_auth_data, + public_key) + + if err then + return err + end + + self.public_key = public_key + end + + return _read_ok_result(self) + end + end + + return "malformed packet" + end + + if plugin == "sha256_password" then + if #auth_data ~= 0 then + local enc, err = _write_encode_password(self, old_auth_data, + auth_data) + + if err then + return err + end + + return _read_ok_result(self) + end + end +end + + +function _M.new(self) + local sock, err = tcp() + if not sock then + return nil, err + end + return setmetatable({ sock = sock }, mt) +end + + +function _M.set_timeout(self, timeout) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + return sock:settimeout(timeout) +end + + +function _M.connect(self, opts) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + local max_packet_size = opts.max_packet_size + if not max_packet_size then + max_packet_size = 1024 * 1024 -- default 1 MB + end + self._max_packet_size = max_packet_size + + local ok, err + + self.compact = opts.compact_arrays + + self.database = opts.database or "" + self.user = opts.user or "" + + self.charset = CHARSET_MAP[opts.charset or "_default"] + if not self.charset then + return nil, "charset '" .. opts.charset .. "' is not supported" + end + + local pool = opts.pool + + self.ssl_verify = opts.ssl_verify + self.use_ssl = opts.ssl or opts.ssl_verify + + self.password = opts.password or "" + + local host = opts.host + if host then + local port = opts.port or 3306 + if not pool then + pool = self.user .. ":" .. self.database .. ":" .. host .. ":" + .. port + end + + ok, err = sock:connect(host, port, { pool = pool, + pool_size = opts.pool_size, + backlog = opts.backlog }) + + else + local path = opts.path + if not path then + return nil, 'neither "host" nor "path" options are specified' + end + + if not pool then + pool = self.user .. ":" .. self.database .. ":" .. path + end + + self.is_unix = true + ok, err = sock:connect("unix:" .. path, { pool = pool, + pool_size = opts.pool_size, + backlog = opts.backlog }) + end + + if not ok then + return nil, 'failed to connect: ' .. err + end + + local reused = sock:getreusedtimes() + + if reused and reused > 0 then + self.state = STATE_CONNECTED + return 1 + end + + self.DEFAULT_CLIENT_FLAGS = bor(DEFAULT_CLIENT_FLAGS, CLIENT_PLUGIN_AUTH) + + local auth_data, plugin, err, errno, sqlstate + = _read_hand_shake_packet(self) + + if err ~= nil then + return nil, err + end + + local auth_resp, err = _auth(self, auth_data, plugin) + if not auth_resp then + return nil, err + end + + err = _write_hand_shake_response(self, auth_resp, plugin) + if err ~= nil then + return nil, err + end + + local err, errno, sqlstate = _handle_auth_result(self, auth_data, plugin) + if err ~= nil then + return nil, err, errno, sqlstate + end + + self.state = STATE_CONNECTED + + return 1 +end + + +function _M.set_keepalive(self, ...) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + if self.state ~= STATE_CONNECTED then + return nil, "cannot be reused in the current connection state: " + .. (self.state or "nil") + end + + self.state = nil + return sock:setkeepalive(...) +end + + +function _M.get_reused_times(self) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + return sock:getreusedtimes() +end + + +function _M.close(self) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + self.state = nil + + local bytes, err = _send_packet(self, strchar(COM_QUIT), 1) + if not bytes then + return nil, err + end + + return sock:close() +end + + +function _M.server_ver(self) + return self._server_ver +end + + +local function send_query(self, query) + if self.state ~= STATE_CONNECTED then + return nil, "cannot send query in the current context: " + .. (self.state or "nil") + end + + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + self.packet_no = -1 + + local cmd_packet = strchar(COM_QUERY) .. query + local packet_len = 1 + #query + + local bytes, err = _send_packet(self, cmd_packet, packet_len) + if not bytes then + return nil, err + end + + self.state = STATE_COMMAND_SENT + + --print("packet sent ", bytes, " bytes") + + return bytes +end +_M.send_query = send_query + + +local function read_result(self, est_nrows) + if self.state ~= STATE_COMMAND_SENT then + return nil, "cannot read result in the current context: " + .. (self.state or "nil") + end + + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + local packet, typ, err = _recv_packet(self) + if not packet then + return nil, err + end + + if typ == RESP_ERR then + self.state = STATE_CONNECTED + + local errno, msg, sqlstate = _parse_err_packet(packet) + return nil, msg, errno, sqlstate + end + + if typ == RESP_OK then + local res = _parse_ok_packet(packet) + if res and band(res.server_status, SERVER_MORE_RESULTS_EXISTS) ~= 0 then + return res, "again" + end + + self.state = STATE_CONNECTED + return res + end + + if typ == RESP_LOCALINFILE then + self.state = STATE_CONNECTED + + return nil, "packet type " .. typ .. " not supported" + end + + -- typ == RESP_DATA or RESP_AUTHMOREDATA(also mean RESP_DATA here) + + --print("read the result set header packet") + + local field_count, extra = _parse_result_set_header_packet(packet) + + --print("field count: ", field_count) + + local cols = new_tab(field_count, 0) + for i = 1, field_count do + local col, err, errno, sqlstate = _recv_field_packet(self) + if not col then + return nil, err, errno, sqlstate + end + + cols[i] = col + end + + local packet, typ, err = _recv_packet(self) + if not packet then + return nil, err + end + + if typ ~= RESP_EOF then + return nil, "unexpected packet type " .. typ .. " while eof packet is " + .. "expected" + end + + -- typ == RESP_EOF + + local compact = self.compact + + local rows = new_tab(est_nrows or 4, 0) + local i = 0 + while true do + --print("reading a row") + + packet, typ, err = _recv_packet(self) + if not packet then + return nil, err + end + + if typ == RESP_EOF then + local warning_count, status_flags = _parse_eof_packet(packet) + + --print("status flags: ", status_flags) + + if band(status_flags, SERVER_MORE_RESULTS_EXISTS) ~= 0 then + return rows, "again" + end + + break + end + + local row = _parse_row_data_packet(packet, cols, compact) + i = i + 1 + rows[i] = row + end + + self.state = STATE_CONNECTED + + return rows +end +_M.read_result = read_result + + +function _M.query(self, query, est_nrows) + local bytes, err = send_query(self, query) + if not bytes then + return nil, "failed to send query: " .. err + end + + return read_result(self, est_nrows) +end + + +function _M.set_compact_arrays(self, value) + self.compact = value +end + + +return _M diff --git a/resty/random.lua b/resty/random.lua new file mode 100644 index 0000000..a2703f1 --- /dev/null +++ b/resty/random.lua @@ -0,0 +1,36 @@ +-- Copyright (C) by Yichun Zhang (agentzh) + + +local ffi = require "ffi" +local ffi_new = ffi.new +local ffi_str = ffi.string +local C = ffi.C +--local setmetatable = setmetatable +--local error = error + + +local _M = { _VERSION = '0.14' } + + +ffi.cdef[[ +int RAND_bytes(unsigned char *buf, int num); +int RAND_pseudo_bytes(unsigned char *buf, int num); +]] + + +function _M.bytes(len, strong) + local buf = ffi_new("char[?]", len) + if strong then + if C.RAND_bytes(buf, len) == 0 then + return nil + end + else + C.RAND_pseudo_bytes(buf,len) + end + + return ffi_str(buf, len) +end + + +return _M + diff --git a/resty/redis.lua b/resty/redis.lua new file mode 100644 index 0000000..4ddbac3 --- /dev/null +++ b/resty/redis.lua @@ -0,0 +1,676 @@ +-- Copyright (C) Yichun Zhang (agentzh) + + +local sub = string.sub +local byte = string.byte +local tab_insert = table.insert +local tab_remove = table.remove +local tcp = ngx.socket.tcp +local null = ngx.null +local ipairs = ipairs +local type = type +local pairs = pairs +local unpack = unpack +local setmetatable = setmetatable +local tonumber = tonumber +local tostring = tostring +local rawget = rawget +local select = select +--local error = error + + +local ok, new_tab = pcall(require, "table.new") +if not ok or type(new_tab) ~= "function" then + new_tab = function (narr, nrec) return {} end +end + + +local _M = new_tab(0, 55) + +_M._VERSION = '0.29' + + +local common_cmds = { + "get", "set", "mget", "mset", + "del", "incr", "decr", -- Strings + "llen", "lindex", "lpop", "lpush", + "lrange", "linsert", -- Lists + "hexists", "hget", "hset", "hmget", + --[[ "hmset", ]] "hdel", -- Hashes + "smembers", "sismember", "sadd", "srem", + "sdiff", "sinter", "sunion", -- Sets + "zrange", "zrangebyscore", "zrank", "zadd", + "zrem", "zincrby", -- Sorted Sets + "auth", "eval", "expire", "script", + "sort" -- Others +} + + +local sub_commands = { + "subscribe", "psubscribe" +} + + +local unsub_commands = { + "unsubscribe", "punsubscribe" +} + + +local mt = { __index = _M } + + +function _M.new(self) + local sock, err = tcp() + if not sock then + return nil, err + end + return setmetatable({ _sock = sock, + _subscribed = false, + _n_channel = { + unsubscribe = 0, + punsubscribe = 0, + }, + }, mt) +end + + +function _M.set_timeout(self, timeout) + local sock = rawget(self, "_sock") + if not sock then + error("not initialized", 2) + return + end + + sock:settimeout(timeout) +end + + +function _M.set_timeouts(self, connect_timeout, send_timeout, read_timeout) + local sock = rawget(self, "_sock") + if not sock then + error("not initialized", 2) + return + end + + sock:settimeouts(connect_timeout, send_timeout, read_timeout) +end + + +function _M.connect(self, host, port_or_opts, opts) + local sock = rawget(self, "_sock") + if not sock then + return nil, "not initialized" + end + + local unix + + do + local typ = type(host) + if typ ~= "string" then + error("bad argument #1 host: string expected, got " .. typ, 2) + end + + if sub(host, 1, 5) == "unix:" then + unix = true + end + + if unix then + typ = type(port_or_opts) + if port_or_opts ~= nil and typ ~= "table" then + error("bad argument #2 opts: nil or table expected, got " .. + typ, 2) + end + + else + typ = type(port_or_opts) + if typ ~= "number" then + port_or_opts = tonumber(port_or_opts) + if port_or_opts == nil then + error("bad argument #2 port: number expected, got " .. + typ, 2) + end + end + + if opts ~= nil then + typ = type(opts) + if typ ~= "table" then + error("bad argument #3 opts: nil or table expected, got " .. + typ, 2) + end + end + end + + end + + self._subscribed = false + + local ok, err + + if unix then + -- second argument of sock:connect() cannot be nil + if port_or_opts ~= nil then + ok, err = sock:connect(host, port_or_opts) + opts = port_or_opts + else + ok, err = sock:connect(host) + end + else + ok, err = sock:connect(host, port_or_opts, opts) + end + + if not ok then + return ok, err + end + + if opts and opts.ssl then + ok, err = sock:sslhandshake(false, opts.server_name, opts.ssl_verify) + if not ok then + return ok, "failed to do ssl handshake: " .. err + end + end + + return ok, err +end + + +function _M.set_keepalive(self, ...) + local sock = rawget(self, "_sock") + if not sock then + return nil, "not initialized" + end + + if rawget(self, "_subscribed") then + return nil, "subscribed state" + end + + return sock:setkeepalive(...) +end + + +function _M.get_reused_times(self) + local sock = rawget(self, "_sock") + if not sock then + return nil, "not initialized" + end + + return sock:getreusedtimes() +end + + +local function close(self) + local sock = rawget(self, "_sock") + if not sock then + return nil, "not initialized" + end + + return sock:close() +end +_M.close = close + + +local function _read_reply(self, sock) + local line, err = sock:receive() + if not line then + if err == "timeout" and not rawget(self, "_subscribed") then + sock:close() + end + return nil, err + end + + local prefix = byte(line) + + if prefix == 36 then -- char '$' + -- print("bulk reply") + + local size = tonumber(sub(line, 2)) + if size < 0 then + return null + end + + local data, err = sock:receive(size) + if not data then + if err == "timeout" then + sock:close() + end + return nil, err + end + + local dummy, err = sock:receive(2) -- ignore CRLF + if not dummy then + if err == "timeout" then + sock:close() + end + return nil, err + end + + return data + + elseif prefix == 43 then -- char '+' + -- print("status reply") + + return sub(line, 2) + + elseif prefix == 42 then -- char '*' + local n = tonumber(sub(line, 2)) + + -- print("multi-bulk reply: ", n) + if n < 0 then + return null + end + + local vals = new_tab(n, 0) + local nvals = 0 + for i = 1, n do + local res, err = _read_reply(self, sock) + if res then + nvals = nvals + 1 + vals[nvals] = res + + elseif res == nil then + return nil, err + + else + -- be a valid redis error value + nvals = nvals + 1 + vals[nvals] = {false, err} + end + end + + return vals + + elseif prefix == 58 then -- char ':' + -- print("integer reply") + return tonumber(sub(line, 2)) + + elseif prefix == 45 then -- char '-' + -- print("error reply: ", n) + + return false, sub(line, 2) + + else + -- when `line` is an empty string, `prefix` will be equal to nil. + return nil, "unknown prefix: \"" .. tostring(prefix) .. "\"" + end +end + + +local function _gen_req(args) + local nargs = #args + + local req = new_tab(nargs * 5 + 1, 0) + req[1] = "*" .. nargs .. "\r\n" + local nbits = 2 + + for i = 1, nargs do + local arg = args[i] + if type(arg) ~= "string" then + arg = tostring(arg) + end + + req[nbits] = "$" + req[nbits + 1] = #arg + req[nbits + 2] = "\r\n" + req[nbits + 3] = arg + req[nbits + 4] = "\r\n" + + nbits = nbits + 5 + end + + -- it is much faster to do string concatenation on the C land + -- in real world (large number of strings in the Lua VM) + return req +end + + +local function _check_msg(self, res) + return rawget(self, "_subscribed") and + type(res) == "table" and (res[1] == "message" or res[1] == "pmessage") +end + + +local function _do_cmd(self, ...) + local args = {...} + + local sock = rawget(self, "_sock") + if not sock then + return nil, "not initialized" + end + + local req = _gen_req(args) + + local reqs = rawget(self, "_reqs") + if reqs then + reqs[#reqs + 1] = req + return + end + + -- print("request: ", table.concat(req)) + + local bytes, err = sock:send(req) + if not bytes then + return nil, err + end + + local res, err = _read_reply(self, sock) + while _check_msg(self, res) do + if rawget(self, "_buffered_msg") == nil then + self._buffered_msg = new_tab(1, 0) + end + + tab_insert(self._buffered_msg, res) + res, err = _read_reply(self, sock) + end + + return res, err +end + + +local function _check_unsubscribed(self, res) + if type(res) == "table" + and (res[1] == "unsubscribe" or res[1] == "punsubscribe") + then + self._n_channel[res[1]] = self._n_channel[res[1]] - 1 + + local buffered_msg = rawget(self, "_buffered_msg") + if buffered_msg then + -- remove messages of unsubscribed channel + local msg_type = + (res[1] == "punsubscribe") and "pmessage" or "message" + local j = 1 + for _, msg in ipairs(buffered_msg) do + if msg[1] == msg_type and msg[2] ~= res[2] then + -- move messages to overwrite the removed ones + buffered_msg[j] = msg + j = j + 1 + end + end + + -- clear remain messages + for i = j, #buffered_msg do + buffered_msg[i] = nil + end + + if #buffered_msg == 0 then + self._buffered_msg = nil + end + end + + if res[3] == 0 then + -- all channels are unsubscribed + self._subscribed = false + end + end +end + + +local function _check_subscribed(self, res) + if type(res) == "table" + and (res[1] == "subscribe" or res[1] == "psubscribe") + then + if res[1] == "subscribe" then + self._n_channel.unsubscribe = self._n_channel.unsubscribe + 1 + + elseif res[1] == "psubscribe" then + self._n_channel.punsubscribe = self._n_channel.punsubscribe + 1 + end + end +end + + +function _M.read_reply(self) + local sock = rawget(self, "_sock") + if not sock then + return nil, "not initialized" + end + + if not rawget(self, "_subscribed") then + return nil, "not subscribed" + end + + local buffered_msg = rawget(self, "_buffered_msg") + if buffered_msg then + local msg = buffered_msg[1] + tab_remove(buffered_msg, 1) + + if #buffered_msg == 0 then + self._buffered_msg = nil + end + + return msg + end + + local res, err = _read_reply(self, sock) + _check_unsubscribed(self, res) + + return res, err +end + + +for i = 1, #common_cmds do + local cmd = common_cmds[i] + + _M[cmd] = + function (self, ...) + return _do_cmd(self, cmd, ...) + end +end + + +local function handle_subscribe_result(self, cmd, nargs, res) + local err + _check_subscribed(self, res) + + if nargs <= 1 then + return res + end + + local results = new_tab(nargs, 0) + results[1] = res + local sock = rawget(self, "_sock") + + for i = 2, nargs do + res, err = _read_reply(self, sock) + if not res then + return nil, err + end + + _check_subscribed(self, res) + results[i] = res + end + + return results +end + +for i = 1, #sub_commands do + local cmd = sub_commands[i] + + _M[cmd] = + function (self, ...) + if not rawget(self, "_subscribed") then + self._subscribed = true + end + + local nargs = select("#", ...) + + local res, err = _do_cmd(self, cmd, ...) + if not res then + return nil, err + end + + return handle_subscribe_result(self, cmd, nargs, res) + end +end + + +local function handle_unsubscribe_result(self, cmd, nargs, res) + local err + _check_unsubscribed(self, res) + + if self._n_channel[cmd] == 0 or nargs == 1 then + return res + end + + local results = new_tab(nargs, 0) + results[1] = res + local sock = rawget(self, "_sock") + local i = 2 + + while nargs == 0 or i <= nargs do + res, err = _read_reply(self, sock) + if not res then + return nil, err + end + + results[i] = res + i = i + 1 + + _check_unsubscribed(self, res) + if self._n_channel[cmd] == 0 then + -- exit the loop for unsubscribe() call + break + end + end + + return results +end + +for i = 1, #unsub_commands do + local cmd = unsub_commands[i] + + _M[cmd] = + function (self, ...) + -- assume all channels are unsubscribed by only one time + if not rawget(self, "_subscribed") then + return nil, "not subscribed" + end + + local nargs = select("#", ...) + + local res, err = _do_cmd(self, cmd, ...) + if not res then + return nil, err + end + + return handle_unsubscribe_result(self, cmd, nargs, res) + end +end + + +function _M.hmset(self, hashname, ...) + if select('#', ...) == 1 then + local t = select(1, ...) + + local n = 0 + for k, v in pairs(t) do + n = n + 2 + end + + local array = new_tab(n, 0) + + local i = 0 + for k, v in pairs(t) do + array[i + 1] = k + array[i + 2] = v + i = i + 2 + end + -- print("key", hashname) + return _do_cmd(self, "hmset", hashname, unpack(array)) + end + + -- backwards compatibility + return _do_cmd(self, "hmset", hashname, ...) +end + + +function _M.init_pipeline(self, n) + self._reqs = new_tab(n or 4, 0) +end + + +function _M.cancel_pipeline(self) + self._reqs = nil +end + + +function _M.commit_pipeline(self) + local reqs = rawget(self, "_reqs") + if not reqs then + return nil, "no pipeline" + end + + self._reqs = nil + + local sock = rawget(self, "_sock") + if not sock then + return nil, "not initialized" + end + + local bytes, err = sock:send(reqs) + if not bytes then + return nil, err + end + + local nvals = 0 + local nreqs = #reqs + local vals = new_tab(nreqs, 0) + for i = 1, nreqs do + local res, err = _read_reply(self, sock) + if res then + nvals = nvals + 1 + vals[nvals] = res + + elseif res == nil then + if err == "timeout" then + close(self) + end + return nil, err + + else + -- be a valid redis error value + nvals = nvals + 1 + vals[nvals] = {false, err} + end + end + + return vals +end + + +function _M.array_to_hash(self, t) + local n = #t + -- print("n = ", n) + local h = new_tab(0, n / 2) + for i = 1, n, 2 do + h[t[i]] = t[i + 1] + end + return h +end + + +-- this method is deperate since we already do lazy method generation. +function _M.add_commands(...) + local cmds = {...} + for i = 1, #cmds do + local cmd = cmds[i] + _M[cmd] = + function (self, ...) + return _do_cmd(self, cmd, ...) + end + end +end + + +setmetatable(_M, {__index = function(self, cmd) + local method = + function (self, ...) + return _do_cmd(self, cmd, ...) + end + + -- cache the lazily generated method in our + -- module table + _M[cmd] = method + return method +end}) + + +return _M diff --git a/resty/sha.lua b/resty/sha.lua new file mode 100644 index 0000000..3f0b3bc --- /dev/null +++ b/resty/sha.lua @@ -0,0 +1,19 @@ +-- Copyright (C) by Yichun Zhang (agentzh) + + +local ffi = require "ffi" + + +local _M = { _VERSION = '0.14' } + + +ffi.cdef[[ +typedef unsigned long SHA_LONG; +typedef unsigned long long SHA_LONG64; + +enum { + SHA_LBLOCK = 16 +}; +]]; + +return _M diff --git a/resty/sha1.lua b/resty/sha1.lua new file mode 100644 index 0000000..1d42f6d --- /dev/null +++ b/resty/sha1.lua @@ -0,0 +1,69 @@ +-- Copyright (C) by Yichun Zhang (agentzh) + + +require "resty.sha" +local ffi = require "ffi" +local ffi_new = ffi.new +local ffi_str = ffi.string +local C = ffi.C +local setmetatable = setmetatable +--local error = error + + +local _M = { _VERSION = '0.14' } + + +local mt = { __index = _M } + + +ffi.cdef[[ +typedef struct SHAstate_st + { + SHA_LONG h0,h1,h2,h3,h4; + SHA_LONG Nl,Nh; + SHA_LONG data[SHA_LBLOCK]; + unsigned int num; + } SHA_CTX; + +int SHA1_Init(SHA_CTX *c); +int SHA1_Update(SHA_CTX *c, const void *data, size_t len); +int SHA1_Final(unsigned char *md, SHA_CTX *c); +]] + +local digest_len = 20 + +local buf = ffi_new("char[?]", digest_len) +local ctx_ptr_type = ffi.typeof("SHA_CTX[1]") + + +function _M.new(self) + local ctx = ffi_new(ctx_ptr_type) + if C.SHA1_Init(ctx) == 0 then + return nil + end + + return setmetatable({ _ctx = ctx }, mt) +end + + +function _M.update(self, s) + return C.SHA1_Update(self._ctx, s, #s) == 1 +end + + +function _M.final(self) + if C.SHA1_Final(buf, self._ctx) == 1 then + return ffi_str(buf, digest_len) + end + + return nil +end + + +function _M.reset(self) + return C.SHA1_Init(self._ctx) == 1 +end + + +return _M + diff --git a/resty/sha224.lua b/resty/sha224.lua new file mode 100644 index 0000000..43c2a92 --- /dev/null +++ b/resty/sha224.lua @@ -0,0 +1,60 @@ +-- Copyright (C) by Yichun Zhang (agentzh) + + +require "resty.sha256" +local ffi = require "ffi" +local ffi_new = ffi.new +local ffi_str = ffi.string +local C = ffi.C +local setmetatable = setmetatable +--local error = error + + +local _M = { _VERSION = '0.14' } + + +local mt = { __index = _M } + + +ffi.cdef[[ +int SHA224_Init(SHA256_CTX *c); +int SHA224_Update(SHA256_CTX *c, const void *data, size_t len); +int SHA224_Final(unsigned char *md, SHA256_CTX *c); +]] + +local digest_len = 28 + +local buf = ffi_new("char[?]", digest_len) +local ctx_ptr_type = ffi.typeof("SHA256_CTX[1]") + + +function _M.new(self) + local ctx = ffi_new(ctx_ptr_type) + if C.SHA224_Init(ctx) == 0 then + return nil + end + + return setmetatable({ _ctx = ctx }, mt) +end + + +function _M.update(self, s) + return C.SHA224_Update(self._ctx, s, #s) == 1 +end + + +function _M.final(self) + if C.SHA224_Final(buf, self._ctx) == 1 then + return ffi_str(buf, digest_len) + end + + return nil +end + + +function _M.reset(self) + return C.SHA224_Init(self._ctx) == 1 +end + + +return _M diff --git a/resty/sha256.lua b/resty/sha256.lua new file mode 100644 index 0000000..3f37a66 --- /dev/null +++ b/resty/sha256.lua @@ -0,0 +1,69 @@ +-- Copyright (C) by Yichun Zhang (agentzh) + + +require "resty.sha" +local ffi = require "ffi" +local ffi_new = ffi.new +local ffi_str = ffi.string +local C = ffi.C +local setmetatable = setmetatable +--local error = error + + +local _M = { _VERSION = '0.14' } + + +local mt = { __index = _M } + + +ffi.cdef[[ +typedef struct SHA256state_st + { + SHA_LONG h[8]; + SHA_LONG Nl,Nh; + SHA_LONG data[SHA_LBLOCK]; + unsigned int num,md_len; + } SHA256_CTX; + +int SHA256_Init(SHA256_CTX *c); +int SHA256_Update(SHA256_CTX *c, const void *data, size_t len); +int SHA256_Final(unsigned char *md, SHA256_CTX *c); +]] + +local digest_len = 32 + +local buf = ffi_new("char[?]", digest_len) +local ctx_ptr_type = ffi.typeof("SHA256_CTX[1]") + + +function _M.new(self) + local ctx = ffi_new(ctx_ptr_type) + if C.SHA256_Init(ctx) == 0 then + return nil + end + + return setmetatable({ _ctx = ctx }, mt) +end + + +function _M.update(self, s) + return C.SHA256_Update(self._ctx, s, #s) == 1 +end + + +function _M.final(self) + if C.SHA256_Final(buf, self._ctx) == 1 then + return ffi_str(buf, digest_len) + end + + return nil +end + + +function _M.reset(self) + return C.SHA256_Init(self._ctx) == 1 +end + + +return _M + diff --git a/resty/sha384.lua b/resty/sha384.lua new file mode 100644 index 0000000..625c3a9 --- /dev/null +++ b/resty/sha384.lua @@ -0,0 +1,60 @@ +-- Copyright (C) by Yichun Zhang (agentzh) + + +require "resty.sha512" +local ffi = require "ffi" +local ffi_new = ffi.new +local ffi_str = ffi.string +local C = ffi.C +local setmetatable = setmetatable +--local error = error + + +local _M = { _VERSION = '0.14' } + + +local mt = { __index = _M } + + +ffi.cdef[[ +int SHA384_Init(SHA512_CTX *c); +int SHA384_Update(SHA512_CTX *c, const void *data, size_t len); +int SHA384_Final(unsigned char *md, SHA512_CTX *c); +]] + +local digest_len = 48 + +local buf = ffi_new("char[?]", digest_len) +local ctx_ptr_type = ffi.typeof("SHA512_CTX[1]") + + +function _M.new(self) + local ctx = ffi_new(ctx_ptr_type) + if C.SHA384_Init(ctx) == 0 then + return nil + end + + return setmetatable({ _ctx = ctx }, mt) +end + + +function _M.update(self, s) + return C.SHA384_Update(self._ctx, s, #s) == 1 +end + + +function _M.final(self) + if C.SHA384_Final(buf, self._ctx) == 1 then + return ffi_str(buf, digest_len) + end + + return nil +end + + +function _M.reset(self) + return C.SHA384_Init(self._ctx) == 1 +end + +return _M + diff --git a/resty/sha512.lua b/resty/sha512.lua new file mode 100644 index 0000000..f73ed7d --- /dev/null +++ b/resty/sha512.lua @@ -0,0 +1,75 @@ +-- Copyright (C) by Yichun Zhang (agentzh) + + +require "resty.sha" +local ffi = require "ffi" +local ffi_new = ffi.new +local ffi_str = ffi.string +local C = ffi.C +local setmetatable = setmetatable +--local error = error + + +local _M = { _VERSION = '0.14' } + + +local mt = { __index = _M } + + +ffi.cdef[[ +enum { + SHA512_CBLOCK = SHA_LBLOCK*8 +}; + +typedef struct SHA512state_st + { + SHA_LONG64 h[8]; + SHA_LONG64 Nl,Nh; + union { + SHA_LONG64 d[SHA_LBLOCK]; + unsigned char p[SHA512_CBLOCK]; + } u; + unsigned int num,md_len; + } SHA512_CTX; + +int SHA512_Init(SHA512_CTX *c); +int SHA512_Update(SHA512_CTX *c, const void *data, size_t len); +int SHA512_Final(unsigned char *md, SHA512_CTX *c); +]] + +local digest_len = 64 + +local buf = ffi_new("char[?]", digest_len) +local ctx_ptr_type = ffi.typeof("SHA512_CTX[1]") + + +function _M.new(self) + local ctx = ffi_new(ctx_ptr_type) + if C.SHA512_Init(ctx) == 0 then + return nil + end + + return setmetatable({ _ctx = ctx }, mt) +end + + +function _M.update(self, s) + return C.SHA512_Update(self._ctx, s, #s) == 1 +end + + +function _M.final(self) + if C.SHA512_Final(buf, self._ctx) == 1 then + return ffi_str(buf, digest_len) + end + + return nil +end + + +function _M.reset(self) + return C.SHA512_Init(self._ctx) == 1 +end + + +return _M diff --git a/resty/string.lua b/resty/string.lua new file mode 100644 index 0000000..6070041 --- /dev/null +++ b/resty/string.lua @@ -0,0 +1,40 @@ +-- Copyright (C) by Yichun Zhang (agentzh) + + +local ffi = require "ffi" +local ffi_new = ffi.new +local ffi_str = ffi.string +local C = ffi.C +--local setmetatable = setmetatable +--local error = error +local tonumber = tonumber + + +local _M = { _VERSION = '0.14' } + + +ffi.cdef[[ +typedef unsigned char u_char; + +u_char * ngx_hex_dump(u_char *dst, const u_char *src, size_t len); + +intptr_t ngx_atoi(const unsigned char *line, size_t n); +]] + +local str_type = ffi.typeof("uint8_t[?]") + + +function _M.to_hex(s) + local len = #s + local buf_len = len * 2 + local buf = ffi_new(str_type, buf_len) + C.ngx_hex_dump(buf, s, len) + return ffi_str(buf, buf_len) +end + +function _M.atoi(s) + return tonumber(C.ngx_atoi(s, #s)) +end + + +return _M diff --git a/resty/upload.lua b/resty/upload.lua new file mode 100644 index 0000000..37da961 --- /dev/null +++ b/resty/upload.lua @@ -0,0 +1,267 @@ +-- Copyright (C) Yichun Zhang (agentzh) + + +-- local sub = string.sub +local req_socket = ngx.req.socket +local match = string.match +local setmetatable = setmetatable +local type = type +local ngx_var = ngx.var +-- local print = print + + +local _M = { _VERSION = '0.10' } + + +local CHUNK_SIZE = 4096 +local MAX_LINE_SIZE = 512 + +local STATE_BEGIN = 1 +local STATE_READING_HEADER = 2 +local STATE_READING_BODY = 3 +local STATE_EOF = 4 + + +local mt = { __index = _M } + +local state_handlers + + +local function get_boundary() + local header = ngx_var.content_type + if not header then + return nil + end + + if type(header) == "table" then + header = header[1] + end + + local m = match(header, ";%s*boundary=\"([^\"]+)\"") + if m then + return m + end + + return match(header, ";%s*boundary=([^\",;]+)") +end + + +function _M.new(self, chunk_size, max_line_size) + local boundary = get_boundary() + + -- print("boundary: ", boundary) + + if not boundary then + return nil, "no boundary defined in Content-Type" + end + + -- print('boundary: "', boundary, '"') + + local sock, err = req_socket() + if not sock then + return nil, err + end + + local read2boundary, err = sock:receiveuntil("--" .. boundary) + if not read2boundary then + return nil, err + end + + local read_line, err = sock:receiveuntil("\r\n") + if not read_line then + return nil, err + end + + return setmetatable({ + sock = sock, + size = chunk_size or CHUNK_SIZE, + line_size = max_line_size or MAX_LINE_SIZE, + read2boundary = read2boundary, + read_line = read_line, + boundary = boundary, + state = STATE_BEGIN + }, mt) +end + + +function _M.set_timeout(self, timeout) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + return sock:settimeout(timeout) +end + + +local function discard_line(self) + local read_line = self.read_line + + local line, err = read_line(self.line_size) + if not line then + return nil, err + end + + local dummy, err = read_line(1) + if dummy then + return nil, "line too long: " .. line .. dummy .. "..." + end + + if err then + return nil, err + end + + return 1 +end + + +local function discard_rest(self) + local sock = self.sock + local size = self.size + + while true do + local dummy, err = sock:receive(size) + if err and err ~= 'closed' then + return nil, err + end + + if not dummy then + return 1 + end + end +end + + +local function read_body_part(self) + local read2boundary = self.read2boundary + + local chunk, err = read2boundary(self.size) + if err then + return nil, nil, err + end + + if not chunk then + local sock = self.sock + + local data = sock:receive(2) + if data == "--" then + local ok, err = discard_rest(self) + if not ok then + return nil, nil, err + end + + self.state = STATE_EOF + return "part_end" + end + + if data ~= "\r\n" then + local ok, err = discard_line(self) + if not ok then + return nil, nil, err + end + end + + self.state = STATE_READING_HEADER + return "part_end" + end + + return "body", chunk +end + + +local function read_header(self) + local read_line = self.read_line + + local line, err = read_line(self.line_size) + if err then + return nil, nil, err + end + + local dummy, err = read_line(1) + if dummy then + return nil, nil, "line too long: " .. line .. dummy .. "..." + end + + if err then + return nil, nil, err + end + + -- print("read line: ", line) + + if line == "" then + -- after the last header + self.state = STATE_READING_BODY + return read_body_part(self) + end + + local key, value = match(line, "([^: \t]+)%s*:%s*(.+)") + if not key then + return 'header', line + end + + return 'header', {key, value, line} +end + + +local function eof() + return "eof", nil +end + + +function _M.read(self) + -- local size = self.size + + local handler = state_handlers[self.state] + if handler then + return handler(self) + end + + return nil, nil, "bad state: " .. self.state +end + + +local function read_preamble(self) + local sock = self.sock + if not sock then + return nil, nil, "not initialized" + end + + local size = self.size + local read2boundary = self.read2boundary + + while true do + local preamble = read2boundary(size) + if not preamble then + break + end + + -- discard the preamble data chunk + -- print("read preamble: ", preamble) + end + + local ok, err = discard_line(self) + if not ok then + return nil, nil, err + end + + local read2boundary, err = sock:receiveuntil("\r\n--" .. self.boundary) + if not read2boundary then + return nil, nil, err + end + + self.read2boundary = read2boundary + + self.state = STATE_READING_HEADER + return read_header(self) +end + + +state_handlers = { + read_preamble, + read_header, + read_body_part, + eof +} + + +return _M diff --git a/resty/upstream/healthcheck.lua b/resty/upstream/healthcheck.lua new file mode 100644 index 0000000..72dac88 --- /dev/null +++ b/resty/upstream/healthcheck.lua @@ -0,0 +1,707 @@ +local stream_sock = ngx.socket.tcp +local log = ngx.log +local ERR = ngx.ERR +local WARN = ngx.WARN +local DEBUG = ngx.DEBUG +local sub = string.sub +local re_find = ngx.re.find +local new_timer = ngx.timer.at +local shared = ngx.shared +local debug_mode = ngx.config.debug +local concat = table.concat +local tonumber = tonumber +local tostring = tostring +local ipairs = ipairs +local ceil = math.ceil +local spawn = ngx.thread.spawn +local wait = ngx.thread.wait +local pcall = pcall + +local _M = { + _VERSION = '0.05' +} + +if not ngx.config + or not ngx.config.ngx_lua_version + or ngx.config.ngx_lua_version < 9005 +then + error("ngx_lua 0.9.5+ required") +end + +local ok, upstream = pcall(require, "ngx.upstream") +if not ok then + error("ngx_upstream_lua module required") +end + +local ok, new_tab = pcall(require, "table.new") +if not ok or type(new_tab) ~= "function" then + new_tab = function (narr, nrec) return {} end +end + +local set_peer_down = upstream.set_peer_down +local get_primary_peers = upstream.get_primary_peers +local get_backup_peers = upstream.get_backup_peers +local get_upstreams = upstream.get_upstreams + +local upstream_checker_statuses = {} + +local function warn(...) + log(WARN, "healthcheck: ", ...) +end + +local function errlog(...) + log(ERR, "healthcheck: ", ...) +end + +local function debug(...) + -- print("debug mode: ", debug_mode) + if debug_mode then + log(DEBUG, "healthcheck: ", ...) + end +end + +local function gen_peer_key(prefix, u, is_backup, id) + if is_backup then + return prefix .. u .. ":b" .. id + end + return prefix .. u .. ":p" .. id +end + +local function set_peer_down_globally(ctx, is_backup, id, value) + local u = ctx.upstream + local dict = ctx.dict + local ok, err = set_peer_down(u, is_backup, id, value) + if not ok then + errlog("failed to set peer down: ", err) + end + + if not ctx.new_version then + ctx.new_version = true + end + + local key = gen_peer_key("d:", u, is_backup, id) + local ok, err = dict:set(key, value) + if not ok then + errlog("failed to set peer down state: ", err) + end +end + +local function peer_fail(ctx, is_backup, id, peer) + debug("peer ", peer.name, " was checked to be not ok") + + local u = ctx.upstream + local dict = ctx.dict + + local key = gen_peer_key("nok:", u, is_backup, id) + local fails, err = dict:get(key) + if not fails then + if err then + errlog("failed to get peer nok key: ", err) + return + end + fails = 1 + + -- below may have a race condition, but it is fine for our + -- purpose here. + local ok, err = dict:set(key, 1) + if not ok then + errlog("failed to set peer nok key: ", err) + end + else + fails = fails + 1 + local ok, err = dict:incr(key, 1) + if not ok then + errlog("failed to incr peer nok key: ", err) + end + end + + if fails == 1 then + key = gen_peer_key("ok:", u, is_backup, id) + local succ, err = dict:get(key) + if not succ or succ == 0 then + if err then + errlog("failed to get peer ok key: ", err) + return + end + else + local ok, err = dict:set(key, 0) + if not ok then + errlog("failed to set peer ok key: ", err) + end + end + end + + -- print("ctx fall: ", ctx.fall, ", peer down: ", peer.down, + -- ", fails: ", fails) + + if not peer.down and fails >= ctx.fall then + warn("peer ", peer.name, " is turned down after ", fails, + " failure(s)") + peer.down = true + set_peer_down_globally(ctx, is_backup, id, true) + end +end + +local function peer_ok(ctx, is_backup, id, peer) + debug("peer ", peer.name, " was checked to be ok") + + local u = ctx.upstream + local dict = ctx.dict + + local key = gen_peer_key("ok:", u, is_backup, id) + local succ, err = dict:get(key) + if not succ then + if err then + errlog("failed to get peer ok key: ", err) + return + end + succ = 1 + + -- below may have a race condition, but it is fine for our + -- purpose here. + local ok, err = dict:set(key, 1) + if not ok then + errlog("failed to set peer ok key: ", err) + end + else + succ = succ + 1 + local ok, err = dict:incr(key, 1) + if not ok then + errlog("failed to incr peer ok key: ", err) + end + end + + if succ == 1 then + key = gen_peer_key("nok:", u, is_backup, id) + local fails, err = dict:get(key) + if not fails or fails == 0 then + if err then + errlog("failed to get peer nok key: ", err) + return + end + else + local ok, err = dict:set(key, 0) + if not ok then + errlog("failed to set peer nok key: ", err) + end + end + end + + if peer.down and succ >= ctx.rise then + warn("peer ", peer.name, " is turned up after ", succ, + " success(es)") + peer.down = nil + set_peer_down_globally(ctx, is_backup, id, nil) + end +end + +-- shortcut error function for check_peer() +local function peer_error(ctx, is_backup, id, peer, ...) + if not peer.down then + errlog(...) + end + peer_fail(ctx, is_backup, id, peer) +end + +local function check_peer(ctx, id, peer, is_backup) + local ok + local name = peer.name + local statuses = ctx.statuses + local req = ctx.http_req + + local sock, err = stream_sock() + if not sock then + errlog("failed to create stream socket: ", err) + return + end + + sock:settimeout(ctx.timeout) + + if peer.host then + -- print("peer port: ", peer.port) + ok, err = sock:connect(peer.host, peer.port) + else + ok, err = sock:connect(name) + end + if not ok then + if not peer.down then + errlog("failed to connect to ", name, ": ", err) + end + return peer_fail(ctx, is_backup, id, peer) + end + + local bytes, err = sock:send(req) + if not bytes then + return peer_error(ctx, is_backup, id, peer, + "failed to send request to ", name, ": ", err) + end + + local status_line, err = sock:receive() + if not status_line then + peer_error(ctx, is_backup, id, peer, + "failed to receive status line from ", name, ": ", err) + if err == "timeout" then + sock:close() -- timeout errors do not close the socket. + end + return + end + + if statuses then + local from, to, err = re_find(status_line, + [[^HTTP/\d+\.\d+\s+(\d+)]], + "joi", nil, 1) + if err then + errlog("failed to parse status line: ", err) + end + + if not from then + peer_error(ctx, is_backup, id, peer, + "bad status line from ", name, ": ", + status_line) + sock:close() + return + end + + local status = tonumber(sub(status_line, from, to)) + if not statuses[status] then + peer_error(ctx, is_backup, id, peer, "bad status code from ", + name, ": ", status) + sock:close() + return + end + end + + peer_ok(ctx, is_backup, id, peer) + sock:close() +end + +local function check_peer_range(ctx, from, to, peers, is_backup) + for i = from, to do + check_peer(ctx, i - 1, peers[i], is_backup) + end +end + +local function check_peers(ctx, peers, is_backup) + local n = #peers + if n == 0 then + return + end + + local concur = ctx.concurrency + if concur <= 1 then + for i = 1, n do + check_peer(ctx, i - 1, peers[i], is_backup) + end + else + local threads + local nthr + + if n <= concur then + nthr = n - 1 + threads = new_tab(nthr, 0) + for i = 1, nthr do + + if debug_mode then + debug("spawn a thread checking ", + is_backup and "backup" or "primary", " peer ", i - 1) + end + + threads[i] = spawn(check_peer, ctx, i - 1, peers[i], is_backup) + end + -- use the current "light thread" to run the last task + if debug_mode then + debug("check ", is_backup and "backup" or "primary", " peer ", + n - 1) + end + check_peer(ctx, n - 1, peers[n], is_backup) + + else + local group_size = ceil(n / concur) + nthr = ceil(n / group_size) - 1 + + threads = new_tab(nthr, 0) + local from = 1 + local rest = n + for i = 1, nthr do + local to + if rest >= group_size then + rest = rest - group_size + to = from + group_size - 1 + else + rest = 0 + to = from + rest - 1 + end + + if debug_mode then + debug("spawn a thread checking ", + is_backup and "backup" or "primary", " peers ", + from - 1, " to ", to - 1) + end + + threads[i] = spawn(check_peer_range, ctx, from, to, peers, + is_backup) + from = from + group_size + if rest == 0 then + break + end + end + if rest > 0 then + local to = from + rest - 1 + + if debug_mode then + debug("check ", is_backup and "backup" or "primary", + " peers ", from - 1, " to ", to - 1) + end + + check_peer_range(ctx, from, to, peers, is_backup) + end + end + + if nthr and nthr > 0 then + for i = 1, nthr do + local t = threads[i] + if t then + wait(t) + end + end + end + end +end + +local function upgrade_peers_version(ctx, peers, is_backup) + local dict = ctx.dict + local u = ctx.upstream + local n = #peers + for i = 1, n do + local peer = peers[i] + local id = i - 1 + local key = gen_peer_key("d:", u, is_backup, id) + local down = false + local res, err = dict:get(key) + if not res then + if err then + errlog("failed to get peer down state: ", err) + end + else + down = true + end + if (peer.down and not down) or (not peer.down and down) then + local ok, err = set_peer_down(u, is_backup, id, down) + if not ok then + errlog("failed to set peer down: ", err) + else + -- update our cache too + peer.down = down + end + end + end +end + +local function check_peers_updates(ctx) + local dict = ctx.dict + local u = ctx.upstream + local key = "v:" .. u + local ver, err = dict:get(key) + if not ver then + if err then + errlog("failed to get peers version: ", err) + return + end + + if ctx.version > 0 then + ctx.new_version = true + end + + elseif ctx.version < ver then + debug("upgrading peers version to ", ver) + upgrade_peers_version(ctx, ctx.primary_peers, false); + upgrade_peers_version(ctx, ctx.backup_peers, true); + ctx.version = ver + end +end + +local function get_lock(ctx) + local dict = ctx.dict + local key = "l:" .. ctx.upstream + + -- the lock is held for the whole interval to prevent multiple + -- worker processes from sending the test request simultaneously. + -- here we substract the lock expiration time by 1ms to prevent + -- a race condition with the next timer event. + local ok, err = dict:add(key, true, ctx.interval - 0.001) + if not ok then + if err == "exists" then + return nil + end + errlog("failed to add key \"", key, "\": ", err) + return nil + end + return true +end + +local function do_check(ctx) + debug("healthcheck: run a check cycle") + + check_peers_updates(ctx) + + if get_lock(ctx) then + check_peers(ctx, ctx.primary_peers, false) + check_peers(ctx, ctx.backup_peers, true) + end + + if ctx.new_version then + local key = "v:" .. ctx.upstream + local dict = ctx.dict + + if debug_mode then + debug("publishing peers version ", ctx.version + 1) + end + + dict:add(key, 0) + local new_ver, err = dict:incr(key, 1) + if not new_ver then + errlog("failed to publish new peers version: ", err) + end + + ctx.version = new_ver + ctx.new_version = nil + end +end + +local function update_upstream_checker_status(upstream, success) + local cnt = upstream_checker_statuses[upstream] + if not cnt then + cnt = 0 + end + + if success then + cnt = cnt + 1 + else + cnt = cnt - 1 + end + + upstream_checker_statuses[upstream] = cnt +end + +local check +check = function (premature, ctx) + if premature then + return + end + + local ok, err = pcall(do_check, ctx) + if not ok then + errlog("failed to run healthcheck cycle: ", err) + end + + local ok, err = new_timer(ctx.interval, check, ctx) + if not ok then + if err ~= "process exiting" then + errlog("failed to create timer: ", err) + end + + update_upstream_checker_status(ctx.upstream, false) + return + end +end + +local function preprocess_peers(peers) + local n = #peers + for i = 1, n do + local p = peers[i] + local name = p.name + + if name then + local from, to, err = re_find(name, [[^(.*):\d+$]], "jo", nil, 1) + if from then + p.host = sub(name, 1, to) + p.port = tonumber(sub(name, to + 2)) + end + end + end + return peers +end + +function _M.spawn_checker(opts) + local typ = opts.type + if not typ then + return nil, "\"type\" option required" + end + + if typ ~= "http" then + return nil, "only \"http\" type is supported right now" + end + + local http_req = opts.http_req + if not http_req then + return nil, "\"http_req\" option required" + end + + local timeout = opts.timeout + if not timeout then + timeout = 1000 + end + + local interval = opts.interval + if not interval then + interval = 1 + + else + interval = interval / 1000 + if interval < 0.002 then -- minimum 2ms + interval = 0.002 + end + end + + local valid_statuses = opts.valid_statuses + local statuses + if valid_statuses then + statuses = new_tab(0, #valid_statuses) + for _, status in ipairs(valid_statuses) do + -- print("found good status ", status) + statuses[status] = true + end + end + + -- debug("interval: ", interval) + + local concur = opts.concurrency + if not concur then + concur = 1 + end + + local fall = opts.fall + if not fall then + fall = 5 + end + + local rise = opts.rise + if not rise then + rise = 2 + end + + local shm = opts.shm + if not shm then + return nil, "\"shm\" option required" + end + + local dict = shared[shm] + if not dict then + return nil, "shm \"" .. tostring(shm) .. "\" not found" + end + + local u = opts.upstream + if not u then + return nil, "no upstream specified" + end + + local ppeers, err = get_primary_peers(u) + if not ppeers then + return nil, "failed to get primary peers: " .. err + end + + local bpeers, err = get_backup_peers(u) + if not bpeers then + return nil, "failed to get backup peers: " .. err + end + + local ctx = { + upstream = u, + primary_peers = preprocess_peers(ppeers), + backup_peers = preprocess_peers(bpeers), + http_req = http_req, + timeout = timeout, + interval = interval, + dict = dict, + fall = fall, + rise = rise, + statuses = statuses, + version = 0, + concurrency = concur, + } + + if debug_mode and opts.no_timer then + check(nil, ctx) + + else + local ok, err = new_timer(0, check, ctx) + if not ok then + return nil, "failed to create timer: " .. err + end + end + + update_upstream_checker_status(u, true) + + return true +end + +local function gen_peers_status_info(peers, bits, idx) + local npeers = #peers + for i = 1, npeers do + local peer = peers[i] + bits[idx] = " " + bits[idx + 1] = peer.name + if peer.down then + bits[idx + 2] = " DOWN\n" + else + bits[idx + 2] = " up\n" + end + idx = idx + 3 + end + return idx +end + +function _M.status_page() + -- generate an HTML page + local us, err = get_upstreams() + if not us then + return "failed to get upstream names: " .. err + end + + local n = #us + local bits = new_tab(n * 20, 0) + local idx = 1 + for i = 1, n do + if i > 1 then + bits[idx] = "\n" + idx = idx + 1 + end + + local u = us[i] + + bits[idx] = "Upstream " + bits[idx + 1] = u + idx = idx + 2 + + local ncheckers = upstream_checker_statuses[u] + if not ncheckers or ncheckers == 0 then + bits[idx] = " (NO checkers)" + idx = idx + 1 + end + + bits[idx] = "\n Primary Peers\n" + idx = idx + 1 + + local peers, err = get_primary_peers(u) + if not peers then + return "failed to get primary peers in upstream " .. u .. ": " + .. err + end + + idx = gen_peers_status_info(peers, bits, idx) + + bits[idx] = " Backup Peers\n" + idx = idx + 1 + + peers, err = get_backup_peers(u) + if not peers then + return "failed to get backup peers in upstream " .. u .. ": " + .. err + end + + idx = gen_peers_status_info(peers, bits, idx) + end + return concat(bits) +end + +return _M diff --git a/resty/websocket/client.lua b/resty/websocket/client.lua new file mode 100644 index 0000000..067b2a5 --- /dev/null +++ b/resty/websocket/client.lua @@ -0,0 +1,353 @@ +-- Copyright (C) Yichun Zhang (agentzh) + + +-- FIXME: this library is very rough and is currently just for testing +-- the websocket server. + + +local wbproto = require "resty.websocket.protocol" +local bit = require "bit" + + +local _recv_frame = wbproto.recv_frame +local _send_frame = wbproto.send_frame +local new_tab = wbproto.new_tab +local tcp = ngx.socket.tcp +local re_match = ngx.re.match +local encode_base64 = ngx.encode_base64 +local concat = table.concat +local char = string.char +local str_find = string.find +local rand = math.random +local rshift = bit.rshift +local band = bit.band +local setmetatable = setmetatable +local type = type +local debug = ngx.config.debug +local ngx_log = ngx.log +local ngx_DEBUG = ngx.DEBUG +local ssl_support = true + +if not ngx.config + or not ngx.config.ngx_lua_version + or ngx.config.ngx_lua_version < 9011 +then + ssl_support = false +end + +local _M = new_tab(0, 13) +_M._VERSION = '0.08' + + +local mt = { __index = _M } + + +function _M.new(self, opts) + local sock, err = tcp() + if not sock then + return nil, err + end + + local max_payload_len, send_unmasked, timeout + if opts then + max_payload_len = opts.max_payload_len + send_unmasked = opts.send_unmasked + timeout = opts.timeout + + if timeout then + sock:settimeout(timeout) + end + end + + return setmetatable({ + sock = sock, + max_payload_len = max_payload_len or 65535, + send_unmasked = send_unmasked, + }, mt) +end + + +function _M.connect(self, uri, opts) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + local m, err = re_match(uri, [[^(wss?)://([^:/]+)(?::(\d+))?(.*)]], "jo") + if not m then + if err then + return nil, "failed to match the uri: " .. err + end + + return nil, "bad websocket uri" + end + + local scheme = m[1] + local host = m[2] + local port = m[3] + local path = m[4] + + -- ngx.say("host: ", host) + -- ngx.say("port: ", port) + + if not port then + port = 80 + end + + if path == "" then + path = "/" + end + + local ssl_verify, headers, proto_header, origin_header, sock_opts = false + + if opts then + local protos = opts.protocols + if protos then + if type(protos) == "table" then + proto_header = "\r\nSec-WebSocket-Protocol: " + .. concat(protos, ",") + + else + proto_header = "\r\nSec-WebSocket-Protocol: " .. protos + end + end + + local origin = opts.origin + if origin then + origin_header = "\r\nOrigin: " .. origin + end + + local pool = opts.pool + if pool then + sock_opts = { pool = pool } + end + + if opts.ssl_verify then + if not ssl_support then + return nil, "ngx_lua 0.9.11+ required for SSL sockets" + end + ssl_verify = true + end + + if opts.headers then + headers = opts.headers + if type(headers) ~= "table" then + return nil, "custom headers must be a table" + end + end + end + + local ok, err + if sock_opts then + ok, err = sock:connect(host, port, sock_opts) + else + ok, err = sock:connect(host, port) + end + if not ok then + return nil, "failed to connect: " .. err + end + + if scheme == "wss" then + if not ssl_support then + return nil, "ngx_lua 0.9.11+ required for SSL sockets" + end + ok, err = sock:sslhandshake(false, host, ssl_verify) + if not ok then + return nil, "ssl handshake failed: " .. err + end + end + + -- check for connections from pool: + + local count, err = sock:getreusedtimes() + if not count then + return nil, "failed to get reused times: " .. err + end + if count > 0 then + -- being a reused connection (must have done handshake) + return 1 + end + + local custom_headers + if headers then + custom_headers = concat(headers, "\r\n") + custom_headers = "\r\n" .. custom_headers + end + + -- do the websocket handshake: + + local bytes = char(rand(256) - 1, rand(256) - 1, rand(256) - 1, + rand(256) - 1, rand(256) - 1, rand(256) - 1, + rand(256) - 1, rand(256) - 1, rand(256) - 1, + rand(256) - 1, rand(256) - 1, rand(256) - 1, + rand(256) - 1, rand(256) - 1, rand(256) - 1, + rand(256) - 1) + + local key = encode_base64(bytes) + local req = "GET " .. path .. " HTTP/1.1\r\nUpgrade: websocket\r\nHost: " + .. host .. ":" .. port + .. "\r\nSec-WebSocket-Key: " .. key + .. (proto_header or "") + .. "\r\nSec-WebSocket-Version: 13" + .. (origin_header or "") + .. "\r\nConnection: Upgrade" + .. (custom_headers or "") + .. "\r\n\r\n" + + local bytes, err = sock:send(req) + if not bytes then + return nil, "failed to send the handshake request: " .. err + end + + local header_reader = sock:receiveuntil("\r\n\r\n") + -- FIXME: check for too big response headers + local header, err, partial = header_reader() + if not header then + return nil, "failed to receive response header: " .. err + end + + -- error("header: " .. header) + + -- FIXME: verify the response headers + + m, err = re_match(header, [[^\s*HTTP/1\.1\s+]], "jo") + if not m then + return nil, "bad HTTP response status line: " .. header + end + + return 1 +end + + +function _M.set_timeout(self, time) + local sock = self.sock + if not sock then + return nil, nil, "not initialized yet" + end + + return sock:settimeout(time) +end + + +function _M.recv_frame(self) + if self.fatal then + return nil, nil, "fatal error already happened" + end + + local sock = self.sock + if not sock then + return nil, nil, "not initialized yet" + end + + local data, typ, err = _recv_frame(sock, self.max_payload_len, false) + if not data and not str_find(err, ": timeout", 1, true) then + self.fatal = true + end + return data, typ, err +end + + +local function send_frame(self, fin, opcode, payload) + if self.fatal then + return nil, "fatal error already happened" + end + + if self.closed then + return nil, "already closed" + end + + local sock = self.sock + if not sock then + return nil, "not initialized yet" + end + + local bytes, err = _send_frame(sock, fin, opcode, payload, + self.max_payload_len, + not self.send_unmasked) + if not bytes then + self.fatal = true + end + return bytes, err +end +_M.send_frame = send_frame + + +function _M.send_text(self, data) + return send_frame(self, true, 0x1, data) +end + + +function _M.send_binary(self, data) + return send_frame(self, true, 0x2, data) +end + + +local function send_close(self, code, msg) + local payload + if code then + if type(code) ~= "number" or code > 0x7fff then + return nil, "bad status code" + end + payload = char(band(rshift(code, 8), 0xff), band(code, 0xff)) + .. (msg or "") + end + + if debug then + ngx_log(ngx_DEBUG, "sending the close frame") + end + + local bytes, err = send_frame(self, true, 0x8, payload) + + if not bytes then + self.fatal = true + end + + self.closed = true + + return bytes, err +end +_M.send_close = send_close + + +function _M.send_ping(self, data) + return send_frame(self, true, 0x9, data) +end + + +function _M.send_pong(self, data) + return send_frame(self, true, 0xa, data) +end + + +function _M.close(self) + if self.fatal then + return nil, "fatal error already happened" + end + + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + if not self.closed then + local bytes, err = send_close(self) + if not bytes then + return nil, "failed to send close frame: " .. err + end + end + + return sock:close() +end + + +function _M.set_keepalive(self, ...) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + return sock:setkeepalive(...) +end + + +return _M diff --git a/resty/websocket/protocol.lua b/resty/websocket/protocol.lua new file mode 100644 index 0000000..0d75b55 --- /dev/null +++ b/resty/websocket/protocol.lua @@ -0,0 +1,345 @@ +-- Copyright (C) Yichun Zhang (agentzh) + + +local bit = require "bit" +local ffi = require "ffi" + + +local byte = string.byte +local char = string.char +local sub = string.sub +local band = bit.band +local bor = bit.bor +local bxor = bit.bxor +local lshift = bit.lshift +local rshift = bit.rshift +--local tohex = bit.tohex +local tostring = tostring +local concat = table.concat +local rand = math.random +local type = type +local debug = ngx.config.debug +local ngx_log = ngx.log +local ngx_DEBUG = ngx.DEBUG +local ffi_new = ffi.new +local ffi_string = ffi.string + + +local ok, new_tab = pcall(require, "table.new") +if not ok then + new_tab = function (narr, nrec) return {} end +end + + +local _M = new_tab(0, 5) + +_M.new_tab = new_tab +_M._VERSION = '0.08' + + +local types = { + [0x0] = "continuation", + [0x1] = "text", + [0x2] = "binary", + [0x8] = "close", + [0x9] = "ping", + [0xa] = "pong", +} + +local str_buf_size = 4096 +local str_buf +local c_buf_type = ffi.typeof("char[?]") + + +local function get_string_buf(size) + if size > str_buf_size then + return ffi_new(c_buf_type, size) + end + if not str_buf then + str_buf = ffi_new(c_buf_type, str_buf_size) + end + + return str_buf +end + + +function _M.recv_frame(sock, max_payload_len, force_masking) + local data, err = sock:receive(2) + if not data then + return nil, nil, "failed to receive the first 2 bytes: " .. err + end + + local fst, snd = byte(data, 1, 2) + + local fin = band(fst, 0x80) ~= 0 + -- print("fin: ", fin) + + if band(fst, 0x70) ~= 0 then + return nil, nil, "bad RSV1, RSV2, or RSV3 bits" + end + + local opcode = band(fst, 0x0f) + -- print("opcode: ", tohex(opcode)) + + if opcode >= 0x3 and opcode <= 0x7 then + return nil, nil, "reserved non-control frames" + end + + if opcode >= 0xb and opcode <= 0xf then + return nil, nil, "reserved control frames" + end + + local mask = band(snd, 0x80) ~= 0 + + if debug then + ngx_log(ngx_DEBUG, "recv_frame: mask bit: ", mask and 1 or 0) + end + + if force_masking and not mask then + return nil, nil, "frame unmasked" + end + + local payload_len = band(snd, 0x7f) + -- print("payload len: ", payload_len) + + if payload_len == 126 then + local data, err = sock:receive(2) + if not data then + return nil, nil, "failed to receive the 2 byte payload length: " + .. (err or "unknown") + end + + payload_len = bor(lshift(byte(data, 1), 8), byte(data, 2)) + + elseif payload_len == 127 then + local data, err = sock:receive(8) + if not data then + return nil, nil, "failed to receive the 8 byte payload length: " + .. (err or "unknown") + end + + if byte(data, 1) ~= 0 + or byte(data, 2) ~= 0 + or byte(data, 3) ~= 0 + or byte(data, 4) ~= 0 + then + return nil, nil, "payload len too large" + end + + local fifth = byte(data, 5) + if band(fifth, 0x80) ~= 0 then + return nil, nil, "payload len too large" + end + + payload_len = bor(lshift(fifth, 24), + lshift(byte(data, 6), 16), + lshift(byte(data, 7), 8), + byte(data, 8)) + end + + if band(opcode, 0x8) ~= 0 then + -- being a control frame + if payload_len > 125 then + return nil, nil, "too long payload for control frame" + end + + if not fin then + return nil, nil, "fragmented control frame" + end + end + + -- print("payload len: ", payload_len, ", max payload len: ", + -- max_payload_len) + + if payload_len > max_payload_len then + return nil, nil, "exceeding max payload len" + end + + local rest + if mask then + rest = payload_len + 4 + + else + rest = payload_len + end + -- print("rest: ", rest) + + local data + if rest > 0 then + data, err = sock:receive(rest) + if not data then + return nil, nil, "failed to read masking-len and payload: " + .. (err or "unknown") + end + else + data = "" + end + + -- print("received rest") + + if opcode == 0x8 then + -- being a close frame + if payload_len > 0 then + if payload_len < 2 then + return nil, nil, "close frame with a body must carry a 2-byte" + .. " status code" + end + + local msg, code + if mask then + local fst = bxor(byte(data, 4 + 1), byte(data, 1)) + local snd = bxor(byte(data, 4 + 2), byte(data, 2)) + code = bor(lshift(fst, 8), snd) + + if payload_len > 2 then + -- TODO string.buffer optimizations + local bytes = get_string_buf(payload_len - 2) + for i = 3, payload_len do + bytes[i - 3] = bxor(byte(data, 4 + i), + byte(data, (i - 1) % 4 + 1)) + end + msg = ffi_string(bytes, payload_len - 2) + + else + msg = "" + end + + else + local fst = byte(data, 1) + local snd = byte(data, 2) + code = bor(lshift(fst, 8), snd) + + -- print("parsing unmasked close frame payload: ", payload_len) + + if payload_len > 2 then + msg = sub(data, 3) + + else + msg = "" + end + end + + return msg, "close", code + end + + return "", "close", nil + end + + local msg + if mask then + -- TODO string.buffer optimizations + local bytes = get_string_buf(payload_len) + for i = 1, payload_len do + bytes[i - 1] = bxor(byte(data, 4 + i), + byte(data, (i - 1) % 4 + 1)) + end + msg = ffi_string(bytes, payload_len) + + else + msg = data + end + + return msg, types[opcode], not fin and "again" or nil +end + + +local function build_frame(fin, opcode, payload_len, payload, masking) + -- XXX optimize this when we have string.buffer in LuaJIT 2.1 + local fst + if fin then + fst = bor(0x80, opcode) + else + fst = opcode + end + + local snd, extra_len_bytes + if payload_len <= 125 then + snd = payload_len + extra_len_bytes = "" + + elseif payload_len <= 65535 then + snd = 126 + extra_len_bytes = char(band(rshift(payload_len, 8), 0xff), + band(payload_len, 0xff)) + + else + if band(payload_len, 0x7fffffff) < payload_len then + return nil, "payload too big" + end + + snd = 127 + -- XXX we only support 31-bit length here + extra_len_bytes = char(0, 0, 0, 0, band(rshift(payload_len, 24), 0xff), + band(rshift(payload_len, 16), 0xff), + band(rshift(payload_len, 8), 0xff), + band(payload_len, 0xff)) + end + + local masking_key + if masking then + -- set the mask bit + snd = bor(snd, 0x80) + local key = rand(0xffffffff) + masking_key = char(band(rshift(key, 24), 0xff), + band(rshift(key, 16), 0xff), + band(rshift(key, 8), 0xff), + band(key, 0xff)) + + -- TODO string.buffer optimizations + local bytes = get_string_buf(payload_len) + for i = 1, payload_len do + bytes[i - 1] = bxor(byte(payload, i), + byte(masking_key, (i - 1) % 4 + 1)) + end + payload = ffi_string(bytes, payload_len) + + else + masking_key = "" + end + + return char(fst, snd) .. extra_len_bytes .. masking_key .. payload +end +_M.build_frame = build_frame + + +function _M.send_frame(sock, fin, opcode, payload, max_payload_len, masking) + -- ngx.log(ngx.WARN, ngx.var.uri, ": masking: ", masking) + + if not payload then + payload = "" + + elseif type(payload) ~= "string" then + payload = tostring(payload) + end + + local payload_len = #payload + + if payload_len > max_payload_len then + return nil, "payload too big" + end + + if band(opcode, 0x8) ~= 0 then + -- being a control frame + if payload_len > 125 then + return nil, "too much payload for control frame" + end + if not fin then + return nil, "fragmented control frame" + end + end + + local frame, err = build_frame(fin, opcode, payload_len, payload, + masking) + if not frame then + return nil, "failed to build frame: " .. err + end + + local bytes, err = sock:send(frame) + if not bytes then + return nil, "failed to send frame: " .. err + end + return bytes +end + + +return _M diff --git a/resty/websocket/server.lua b/resty/websocket/server.lua new file mode 100644 index 0000000..c56f07b --- /dev/null +++ b/resty/websocket/server.lua @@ -0,0 +1,210 @@ +-- Copyright (C) Yichun Zhang (agentzh) + + +local bit = require "bit" +local wbproto = require "resty.websocket.protocol" + +local new_tab = wbproto.new_tab +local _recv_frame = wbproto.recv_frame +local _send_frame = wbproto.send_frame +local http_ver = ngx.req.http_version +local req_sock = ngx.req.socket +local ngx_header = ngx.header +local req_headers = ngx.req.get_headers +local str_lower = string.lower +local char = string.char +local str_find = string.find +local sha1_bin = ngx.sha1_bin +local base64 = ngx.encode_base64 +local ngx = ngx +local read_body = ngx.req.read_body +local band = bit.band +local rshift = bit.rshift +local type = type +local setmetatable = setmetatable +local tostring = tostring +-- local print = print + + +local _M = new_tab(0, 10) +_M._VERSION = '0.08' + +local mt = { __index = _M } + + +function _M.new(self, opts) + if ngx.headers_sent then + return nil, "response header already sent" + end + + read_body() + + if http_ver() ~= 1.1 then + return nil, "bad http version" + end + + local headers = req_headers() + + local val = headers.upgrade + if type(val) == "table" then + val = val[1] + end + if not val or str_lower(val) ~= "websocket" then + return nil, "bad \"upgrade\" request header: " .. tostring(val) + end + + val = headers.connection + if type(val) == "table" then + val = val[1] + end + if not val or not str_find(str_lower(val), "upgrade", 1, true) then + return nil, "bad \"connection\" request header" + end + + local key = headers["sec-websocket-key"] + if type(key) == "table" then + key = key[1] + end + if not key then + return nil, "bad \"sec-websocket-key\" request header" + end + + local ver = headers["sec-websocket-version"] + if type(ver) == "table" then + ver = ver[1] + end + if not ver or ver ~= "13" then + return nil, "bad \"sec-websocket-version\" request header" + end + + local protocols = headers["sec-websocket-protocol"] + if type(protocols) == "table" then + protocols = protocols[1] + end + + if protocols then + ngx_header["Sec-WebSocket-Protocol"] = protocols + end + ngx_header["Upgrade"] = "websocket" + + local sha1 = sha1_bin(key .. "258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + ngx_header["Sec-WebSocket-Accept"] = base64(sha1) + + ngx_header["Content-Type"] = nil + + ngx.status = 101 + local ok, err = ngx.send_headers() + if not ok then + return nil, "failed to send response header: " .. (err or "unknonw") + end + ok, err = ngx.flush(true) + if not ok then + return nil, "failed to flush response header: " .. (err or "unknown") + end + + local sock + sock, err = req_sock(true) + if not sock then + return nil, err + end + + local max_payload_len, send_masked, timeout + if opts then + max_payload_len = opts.max_payload_len + send_masked = opts.send_masked + timeout = opts.timeout + + if timeout then + sock:settimeout(timeout) + end + end + + return setmetatable({ + sock = sock, + max_payload_len = max_payload_len or 65535, + send_masked = send_masked, + }, mt) +end + + +function _M.set_timeout(self, time) + local sock = self.sock + if not sock then + return nil, nil, "not initialized yet" + end + + return sock:settimeout(time) +end + + +function _M.recv_frame(self) + if self.fatal then + return nil, nil, "fatal error already happened" + end + + local sock = self.sock + if not sock then + return nil, nil, "not initialized yet" + end + + local data, typ, err = _recv_frame(sock, self.max_payload_len, true) + if not data and not str_find(err, ": timeout", 1, true) then + self.fatal = true + end + return data, typ, err +end + + +local function send_frame(self, fin, opcode, payload) + if self.fatal then + return nil, "fatal error already happened" + end + + local sock = self.sock + if not sock then + return nil, "not initialized yet" + end + + local bytes, err = _send_frame(sock, fin, opcode, payload, + self.max_payload_len, self.send_masked) + if not bytes then + self.fatal = true + end + return bytes, err +end +_M.send_frame = send_frame + + +function _M.send_text(self, data) + return send_frame(self, true, 0x1, data) +end + + +function _M.send_binary(self, data) + return send_frame(self, true, 0x2, data) +end + + +function _M.send_close(self, code, msg) + local payload + if code then + if type(code) ~= "number" or code > 0x7fff then + end + payload = char(band(rshift(code, 8), 0xff), band(code, 0xff)) + .. (msg or "") + end + return send_frame(self, true, 0x8, payload) +end + + +function _M.send_ping(self, data) + return send_frame(self, true, 0x9, data) +end + + +function _M.send_pong(self, data) + return send_frame(self, true, 0xa, data) +end + + +return _M diff --git a/rsp_body.lua b/rsp_body.lua index e43b051..d71bb3e 100644 --- a/rsp_body.lua +++ b/rsp_body.lua @@ -1,2 +1,2 @@ jsConfuse() -dateReplace() \ No newline at end of file +dateReplace() diff --git a/tools.lua b/tools.lua new file mode 100644 index 0000000..6dca087 --- /dev/null +++ b/tools.lua @@ -0,0 +1,33 @@ +function i_get_cookie(s_cookie) + local cookie = {} + + -- string.gfind is renamed to string.gmatch + for item in string.gmatch(s_cookie, "[^;]+") do + local _, _, k, v = string.find(item, "^%s*(%S+)%s*=%s*(%S+)%s*") + if k ~= nil and v ~= nil then + cookie[k] = v + end + end + + return cookie +end + +function get_cookie_table() + local raw_cookie = ngx.req.get_headers()["Cookie"] + if raw_cookie ~= nil then + return i_get_cookie(raw_cookie) + end + return nil +end + +function get_cookie_raw() + return ngx.req.get_headers()["Cookie"] +end + +function match_string(input_str, rule) + if input_str == nil then + return false + end + local from, to, err = ngx.re.find(input_str, rule, "jo") + return from ~= nil +end diff --git a/waf/report.lua b/waf/report.lua new file mode 100644 index 0000000..92a912e --- /dev/null +++ b/waf/report.lua @@ -0,0 +1,11 @@ +local _M = {} +function _M.violation(result) + full_violation_text = "violation: \n" + if result == g_violation_sql_detect then + g_result_sql_detect = full_violation_text .. "SQL Injection detected \n" + end + full_violation_text = full_violation_text .. ngx.var.request_uri .. " \n" + log(violation) + say_html() +end +return _M diff --git a/waf/rule.lua b/waf/rule.lua new file mode 100644 index 0000000..fea6201 --- /dev/null +++ b/waf/rule.lua @@ -0,0 +1,6 @@ +local _M = { + sql_get = "'|\\b(and|or)\\b.+?(>|<|=|\\bin\\b|\\blike\\b)|\\/\\*.+?\\*\\/|<\\s*script\\b|\\bEXEC\\b|UNION.+?SELECT|UPDATE.+?SET|INSERT\\s+INTO.+?VALUES|(SELECT|DELETE).+?FROM|(CREATE|ALTER|DROP|TRUNCATE)\\s+(TABLE|DATABASE)", + sql_post = "\\b(and|or)\\b.{1,6}?(=|>|<|\\bin\\b|\\blike\\b)|\\/\\*.+?\\*\\/|<\\s*script\\b|\\bEXEC\\b|UNION.+?SELECT|UPDATE.+?SET|INSERT\\s+INTO.+?VALUES|(SELECT|DELETE).+?FROM|(CREATE|ALTER|DROP|TRUNCATE)\\s+(TABLE|DATABASE)", + sql_cookie = "\\b(and|or)\\b.{1,6}?(=|>|<|\\bin\\b|\\blike\\b)|\\/\\*.+?\\*\\/|<\\s*script\\b|\\bEXEC\\b|UNION.+?SELECT|UPDATE.+?SET|INSERT\\s+INTO.+?VALUES|(SELECT|DELETE).+?FROM|(CREATE|ALTER|DROP|TRUNCATE)\\s+(TABLE|DATABASE)" +} +return _M diff --git a/waf/sql.lua b/waf/sql.lua new file mode 100644 index 0000000..c65f169 --- /dev/null +++ b/waf/sql.lua @@ -0,0 +1,37 @@ +local _M = {} +local ngx_base = require "resty.core.base" +local waf_rule = require "waf/rule" +local waf_voilation = require "waf/violation_list" +local waf_report = require "waf/report" + +function _M.waf_sql_filter_params(arg_tables, filter_rule) + if arg_tables then + for key, val in pairs(arg_tables) do + if match_string(val, filter_rule) then + --ngx.say(key .. ":" .. val) + waf_report.violation(waf_voilation.sql_detect) + end + end + end +end + +function _M.waf_sql_filter() + --local get_arags = ngx.req.get_uri_args + --for k, v in ipairs(get_arags) do + -- print(v) + --end + if ngx_base.get_request() ~= nil then + _M.waf_sql_filter_params(ngx.req.get_uri_args(), waf_rule.sql_get) + ngx.say(get_cookie_raw()) + if match_string(get_cookie_raw(), waf_rule.sql_cookie) then + waf_report.violation(waf_voilation.sql_detect) + end + + local is_post_method = ngx.req.get_method() == "POST" + if is_post_method then + ngx.req.read_body() + _M.waf_sql_filter_params(ngx.req.get_post_args(), waf_rule.sql_post) + end + end +end +return _M diff --git a/waf/violation_list.lua b/waf/violation_list.lua new file mode 100644 index 0000000..1a711f3 --- /dev/null +++ b/waf/violation_list.lua @@ -0,0 +1,3 @@ +return { + sql_detect = 1 +} diff --git a/waf/waf.lua b/waf/waf.lua new file mode 100644 index 0000000..3e77a90 --- /dev/null +++ b/waf/waf.lua @@ -0,0 +1,8 @@ +local waf_sql = require "waf/sql" + +--以后加规则配置、插件这些.现在不加 +function waf_dispatch() + waf_sql.waf_sql_filter() +end + +waf_dispatch()