1411 lines
34 KiB
Lua
1411 lines
34 KiB
Lua
-- Copyright (C) Yichun Zhang (agentzh)
|
|
|
|
|
|
local bit = require "bit"
|
|
local resty_sha256 = require "resty.sha256"
|
|
local sub = string.sub
|
|
local tcp = ngx.socket.tcp
|
|
local strbyte = string.byte
|
|
local strchar = string.char
|
|
local strfind = string.find
|
|
local format = string.format
|
|
local strrep = string.rep
|
|
local null = ngx.null
|
|
local band = bit.band
|
|
local bxor = bit.bxor
|
|
local bor = bit.bor
|
|
local lshift = bit.lshift
|
|
local rshift = bit.rshift
|
|
local tohex = bit.tohex
|
|
local sha1 = ngx.sha1_bin
|
|
local concat = table.concat
|
|
local setmetatable = setmetatable
|
|
local error = error
|
|
local tonumber = tonumber
|
|
local to_int = math.floor
|
|
|
|
local has_rsa, resty_rsa = pcall(require, "resty.rsa")
|
|
|
|
|
|
-- Refuse to load on runtimes that do not expose ngx.config, or on an http
-- subsystem whose ngx_lua version is missing or older than 0.9.11 (9011).
do
    local config = ngx.config
    if not config then
        error("ngx_lua 0.9.11+ or ngx_stream_lua required")
    end

    local subsystem = config.subsystem
    local is_http = not subsystem or subsystem == "http"

    if is_http then
        local version = config.ngx_lua_version
        if not version or version < 9011 then
            error("ngx_lua 0.9.11+ required")
        end
    end
end
|
|
|
|
|
|
-- Table pre-sizing helper: use the "table.new" module when it can be
-- required, otherwise fall back to an ordinary empty table (the size hints
-- are then ignored).
local ok, new_tab = pcall(require, "table.new")
if not ok then
    new_tab = function (narr, nrec)
        return {}
    end
end
|
|
|
|
|
|
local _M = { _VERSION = '0.24' }


-- constants

-- values stored in self.state
local STATE_CONNECTED = 1
local STATE_COMMAND_SENT = 2

-- command bytes sent as the first payload byte of a request
local COM_QUIT = 0x01
local COM_QUERY = 0x03

-- refer to https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags
-- DEFAULT_CLIENT_FLAGS is the OR of:
--   CLIENT_LONG_PASSWORD | CLIENT_FOUND_ROWS | CLIENT_LONG_FLAG
--   | CLIENT_CONNECT_WITH_DB | CLIENT_ODBC | CLIENT_LOCAL_FILES
--   | CLIENT_IGNORE_SPACE | CLIENT_PROTOCOL_41 | CLIENT_INTERACTIVE
--   | CLIENT_IGNORE_SIGPIPE | CLIENT_TRANSACTIONS | CLIENT_RESERVED
--   | CLIENT_SECURE_CONNECTION | CLIENT_MULTI_STATEMENTS | CLIENT_MULTI_RESULTS
local DEFAULT_CLIENT_FLAGS = 0x3f7cf
local CLIENT_SSL = 0x00000800
local CLIENT_PLUGIN_AUTH = 0x00080000
local CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA = 0x00200000

-- server status bit checked to decide whether another result set follows
local SERVER_MORE_RESULTS_EXISTS = 0x08

-- response packet kinds as classified by _recv_packet
local RESP_OK = "OK"
local RESP_AUTHMOREDATA = "AUTHMOREDATA"
local RESP_LOCALINFILE = "LOCALINFILE"
local RESP_EOF = "EOF"
local RESP_ERR = "ERR"
local RESP_DATA = "DATA"

-- modulus used by the old-password pseudo random generator
local MY_RND_MAX_VAL = 0x3FFFFFFF
local MIN_PROTOCOL_VER = 10

-- number of scramble bytes consumed by each password hashing scheme
local LEN_NATIVE_SCRAMBLE = 20
local LEN_OLD_SCRAMBLE = 8

-- 16MB - 1, the default max allowed packet size used by libmysqlclient
local FULL_PACKET_SIZE = 0xffffff
|
|
|
|
-- Character set name -> id sent in the handshake response.
-- The map is generated from the following mysql query:
--   SELECT CHARACTER_SET_NAME, ID
--   FROM information_schema.collations
--   WHERE IS_DEFAULT = 'Yes' ORDER BY id;
-- "_default" (0) is the entry used when the caller names no charset.
local CHARSET_MAP = {
    _default = 0,
    big5     = 1,   dec8     = 3,   cp850    = 4,
    hp8      = 6,   koi8r    = 7,   latin1   = 8,
    latin2   = 9,   swe7     = 10,  ascii    = 11,
    ujis     = 12,  sjis     = 13,  hebrew   = 16,
    tis620   = 18,  euckr    = 19,  koi8u    = 22,
    gb2312   = 24,  greek    = 25,  cp1250   = 26,
    gbk      = 28,  latin5   = 30,  armscii8 = 32,
    utf8     = 33,  ucs2     = 35,  cp866    = 36,
    keybcs2  = 37,  macce    = 38,  macroman = 39,
    cp852    = 40,  latin7   = 41,  utf8mb4  = 45,
    cp1251   = 51,  utf16    = 54,  utf16le  = 56,
    cp1256   = 57,  cp1257   = 59,  utf32    = 60,
    binary   = 63,  geostd8  = 92,  cp932    = 95,
    eucjpms  = 97,  gb18030  = 248,
}
|
|
|
|
-- metatable shared by all client objects created by _M.new
local mt = { __index = _M }


-- mysql field value type converters: maps a column type code to the function
-- applied to the textual value of a non-NULL field of that type
local converters = new_tab(0, 9)

do
    local numeric_types = {
        0x00,                           -- decimal
        0x01, 0x02, 0x03, 0x04, 0x05,   -- tiny, short, long, float, double
        -- 0x08 (long long) is not converted
        0x09,                           -- int24
        0x0d,                           -- year
        0xf6,                           -- newdecimal
    }

    for idx = 1, #numeric_types do
        converters[numeric_types[idx]] = tonumber
    end
end
|
|
|
|
|
|
-- Decode a 2-byte little-endian integer starting at data[i].
-- Returns the value and the index just past the bytes consumed.
local function _get_byte2(data, i)
    local low, high = strbyte(data, i, i + 1)
    local value = bor(low, lshift(high, 8))
    return value, i + 2
end
|
|
|
|
|
|
-- Decode a 3-byte little-endian integer starting at data[i].
-- Returns the value and the index just past the bytes consumed.
local function _get_byte3(data, i)
    local b0, b1, b2 = strbyte(data, i, i + 2)
    local value = bor(b0, lshift(b1, 8), lshift(b2, 16))
    return value, i + 3
end
|
|
|
|
|
|
-- Decode a 4-byte little-endian integer starting at data[i].
-- The result comes straight from the bit library, so values with the top
-- bit set come back as negative (signed 32-bit) numbers.
-- Returns the value and the index just past the bytes consumed.
local function _get_byte4(data, i)
    local b0, b1, b2, b3 = strbyte(data, i, i + 3)
    local value = bor(b0, lshift(b1, 8), lshift(b2, 16), lshift(b3, 24))
    return value, i + 4
end
|
|
|
|
|
|
-- Decode an 8-byte little-endian integer starting at data[i].
-- Returns the value and the index just past the bytes consumed.
local function _get_byte8(data, i)
    local b1, b2, b3, b4, b5, b6, b7, b8 = strbyte(data, i, i + 7)

    -- XXX workaround for the lack of 64-bit support in bitop: the low 24 bits
    -- and the high 32 bits are assembled with bit operations (each result is
    -- in the range of signed 32 bit numbers) and then combined with plain
    -- floating point arithmetic.
    local low24 = bor(b1, lshift(b2, 8), lshift(b3, 16))
    local high32 = bor(b5, lshift(b6, 8), lshift(b7, 16), lshift(b8, 24))

    return low24 + 16777216 * b4 + high32 * 4294967296, i + 8
end
|
|
|
|
|
|
-- Encode the low 16 bits of n as a 2-byte little-endian string.
local function _set_byte2(n)
    local low = band(n, 0xff)
    local high = band(rshift(n, 8), 0xff)
    return strchar(low, high)
end
|
|
|
|
|
|
-- Encode the low 24 bits of n as a 3-byte little-endian string.
local function _set_byte3(n)
    local b0 = band(n, 0xff)
    local b1 = band(rshift(n, 8), 0xff)
    local b2 = band(rshift(n, 16), 0xff)
    return strchar(b0, b1, b2)
end
|
|
|
|
|
|
-- Encode the low 32 bits of n as a 4-byte little-endian string.
local function _set_byte4(n)
    local b0 = band(n, 0xff)
    local b1 = band(rshift(n, 8), 0xff)
    local b2 = band(rshift(n, 16), 0xff)
    local b3 = band(rshift(n, 24), 0xff)
    return strchar(b0, b1, b2, b3)
end
|
|
|
|
|
|
-- Read a NUL-terminated string that starts at data[i].
-- Returns the string (without the terminator) and the index just past the
-- terminator, or nil, nil when no NUL byte is found at or after i.
local function _from_cstring(data, i)
    local nul_at = strfind(data, "\0", i, true)
    if nul_at then
        return sub(data, i, nul_at - 1), nul_at + 1
    end

    return nil, nil
end
|
|
|
|
|
|
-- Return data with a single trailing NUL byte appended.
local function _to_cstring(data)
    local terminated = data .. "\0"
    return terminated
end
|
|
|
|
|
|
-- Debug helper: render every byte of data as unpadded lowercase hex,
-- separated by single spaces.
local function _dump(data)
    local n = #data
    local parts = new_tab(n, 0)

    local idx = 0
    while idx < n do
        idx = idx + 1
        parts[idx] = format("%x", strbyte(data, idx))
    end

    return concat(parts, " ")
end
|
|
|
|
|
|
-- Debug helper: render every byte of data as two hex digits (via bit.tohex),
-- separated by single spaces.
local function _dumphex(data)
    local n = #data
    local parts = new_tab(n, 0)

    local idx = 0
    while idx < n do
        idx = idx + 1
        parts[idx] = tohex(strbyte(data, idx), 2)
    end

    return concat(parts, " ")
end
|
|
|
|
|
|
-- Hash used by the "mysql_old_password" scheme.
-- Folds every byte of password except spaces and tabs into two running
-- hashes and returns both with the sign bit cleared.
local function _pwd_hash(password)
    local hash1, hash2 = 1345345333, 0x12345671
    local add = 7

    for i = 1, #password do
        local byte = strbyte(password, i)

        -- ' ' (32) and '\t' (9) do not take part in the hash
        local skipped = (byte == 32 or byte == 9)
        if not skipped then
            local mix = (band(hash1, 63) + add) * byte + lshift(hash1, 8)
            hash1 = bxor(hash1, mix)

            hash2 = bxor(lshift(hash2, 8), hash1) + hash2

            add = add + byte
        end
    end

    -- remove sign bit: mask with (1<<31)-1
    return band(hash1, 0x7FFFFFFF), band(hash2, 0x7FFFFFFF)
end
|
|
|
|
|
|
-- One step of the pseudo random generator used by the old password scheme.
-- Returns a value derived from the new first seed (floor(seed1 * 31 /
-- MY_RND_MAX_VAL)) followed by the two updated seeds.
local function _random_byte(seed1, seed2)
    local next1 = (seed1 * 3 + seed2) % MY_RND_MAX_VAL
    local next2 = (next1 + seed2 + 33) % MY_RND_MAX_VAL

    local value = to_int(next1 * 31 / MY_RND_MAX_VAL)
    return value, next1, next2
end
|
|
|
|
|
|
-- Build the auth response for the "mysql_old_password" plugin.
-- Returns "" for an empty password; otherwise a NUL-terminated string of
-- LEN_OLD_SCRAMBLE bytes derived from the password and the first
-- LEN_OLD_SCRAMBLE bytes of the scramble.
local function _compute_old_token(password, scramble)
    if password == "" then
        return ""
    end

    local seed = sub(scramble, 1, LEN_OLD_SCRAMBLE)

    local pw1, pw2 = _pwd_hash(password)
    local sc1, sc2 = _pwd_hash(seed)

    local seed1 = bxor(pw1, sc1) % MY_RND_MAX_VAL
    local seed2 = bxor(pw2, sc2) % MY_RND_MAX_VAL

    -- first pass: one pseudo random value per output byte, offset by 64
    local out = new_tab(LEN_OLD_SCRAMBLE, 0)
    local rnd
    for i = 1, LEN_OLD_SCRAMBLE do
        rnd, seed1, seed2 = _random_byte(seed1, seed2)
        out[i] = rnd + 64
    end

    -- second pass: xor every byte with one extra pseudo random value
    local mask = _random_byte(seed1, seed2)
    for i = 1, LEN_OLD_SCRAMBLE do
        out[i] = strchar(bxor(out[i], mask))
    end

    return concat(out) .. "\0"
end
|
|
|
|
|
|
-- Build the scrambled password used by "caching_sha2_password":
--   SHA256(password) XOR SHA256(SHA256(SHA256(password)) .. scramble)
-- Returns "" for an empty password, the token string on success, or
-- nil plus an error message when the sha256 object fails.
local function _compute_sha256_token(password, scramble)
    if password == "" then
        return ""
    end

    local hasher = resty_sha256:new()
    if not hasher then
        return nil, "failed to create the sha256 object"
    end

    local update_failed = "failed to update string to sha256"

    -- stage 1: SHA256(password)
    if not hasher:update(password) then
        return nil, update_failed
    end
    local pw_digest = hasher:final()
    hasher:reset()

    -- stage 2: SHA256(stage 1)
    if not hasher:update(pw_digest) then
        return nil, update_failed
    end
    local pw_digest2 = hasher:final()
    hasher:reset()

    -- stage 3: SHA256(stage 2 .. scramble), fed in two pieces
    if not hasher:update(pw_digest2) then
        return nil, update_failed
    end
    if not hasher:update(scramble) then
        return nil, update_failed
    end
    local mask = hasher:final()

    -- xor stage 1 with stage 3, byte by byte
    local len = #mask
    local out = new_tab(len, 0)
    for i = 1, len do
        out[i] = strchar(bxor(strbyte(pw_digest, i), strbyte(mask, i)))
    end

    return concat(out)
end
|
|
|
|
|
|
-- Build the auth response for "mysql_native_password":
--   SHA1(password) XOR SHA1(scramble .. SHA1(SHA1(password)))
-- using only the first LEN_NATIVE_SCRAMBLE bytes of the scramble.
-- Returns "" for an empty password.
local function _compute_token(password, scramble)
    if password == "" then
        return ""
    end

    local seed = sub(scramble, 1, LEN_NATIVE_SCRAMBLE)

    local pw_sha = sha1(password)
    local mask = sha1(seed .. sha1(pw_sha))

    local len = #pw_sha
    local out = new_tab(len, 0)
    for i = 1, len do
        out[i] = strchar(bxor(strbyte(mask, i), strbyte(pw_sha, i)))
    end

    return concat(out)
end
|
|
|
|
|
|
-- Frame req as one protocol packet and write it to the socket.
-- The header is the 3-byte little-endian size followed by the sequence
-- number, which is self.packet_no incremented by one (low 8 bits only).
-- Returns whatever sock:send returns.
local function _send_packet(self, req, size)
    local seq = self.packet_no + 1
    self.packet_no = seq

    local header = _set_byte3(size) .. strchar(band(seq, 255))

    return self.sock:send(header .. req)
end
|
|
|
|
|
|
-- first payload byte -> response kind; every other value means RESP_DATA
local PACKET_TYPES = {
    [0x00] = RESP_OK,
    [0x01] = RESP_AUTHMOREDATA,
    [0xfb] = RESP_LOCALINFILE,
    [0xfe] = RESP_EOF,
    [0xff] = RESP_ERR,
}


-- Read one packet from the socket.
-- Stores the received sequence number in self.packet_no.
-- Returns payload, kind on success (kind is one of the RESP_* constants),
-- or nil, nil, error message on failure, on an empty packet, or when the
-- announced length exceeds self._max_packet_size.
local function _recv_packet(self)
    local sock = self.sock

    local header, err = sock:receive(4) -- 3 bytes length + 1 byte sequence
    if not header then
        return nil, nil, "failed to receive packet header: " .. err
    end

    local len, pos = _get_byte3(header, 1)

    if len == 0 then
        return nil, nil, "empty packet"
    end

    if len > self._max_packet_size then
        return nil, nil, "packet size too big: " .. len
    end

    self.packet_no = strbyte(header, pos)

    local payload
    payload, err = sock:receive(len)
    if not payload then
        return nil, nil, "failed to read packet content: " .. err
    end

    return payload, PACKET_TYPES[strbyte(payload, 1)] or RESP_DATA
end
|
|
|
|
|
|
-- Decode a length-coded integer at data[pos].
-- Returns value, next position where value is:
--   * the byte itself for 0..250
--   * ngx.null for 251
--   * a 2-, 3- or 8-byte little-endian integer for 252, 253, 254
--   * nil when pos is past the end of data (position returned unchanged)
--     or when the first byte is 255
local function _from_length_coded_bin(data, pos)
    local first = strbyte(data, pos)
    if not first then
        return nil, pos
    end

    local after = pos + 1

    if first <= 250 then
        return first, after

    elseif first == 251 then
        return null, after

    elseif first == 252 then
        return _get_byte2(data, after)

    elseif first == 253 then
        return _get_byte3(data, after)

    elseif first == 254 then
        return _get_byte8(data, after)
    end

    return nil, after
end
|
|
|
|
|
|
-- Decode a length-coded string at data[pos].
-- Returns the string and the position just past it, or ngx.null and the
-- position reported by _from_length_coded_bin when the length is missing
-- or is ngx.null.
local function _from_length_coded_str(data, pos)
    local len, start = _from_length_coded_bin(data, pos)
    if not len or len == null then
        return null, start
    end

    local stop = start + len
    return sub(data, start, stop - 1), stop
end
|
|
|
|
|
|
-- Parse an OK packet into a table with affected_rows, insert_id,
-- server_status, warning_count and, when one is present, message.
local function _parse_ok_packet(packet)
    -- byte 1 is the packet marker, fields start at byte 2
    local affected_rows, pos = _from_length_coded_bin(packet, 2)

    local insert_id
    insert_id, pos = _from_length_coded_bin(packet, pos)

    local server_status
    server_status, pos = _get_byte2(packet, pos)

    local warning_count
    warning_count, pos = _get_byte2(packet, pos)

    local res = new_tab(0, 5)
    res.affected_rows = affected_rows
    res.insert_id = insert_id
    res.server_status = server_status
    res.warning_count = warning_count

    local message = _from_length_coded_str(packet, pos)
    if message and message ~= null then
        res.message = message
    end

    return res
end
|
|
|
|
|
|
-- Parse an EOF packet: two 2-byte integers following the marker byte.
-- Returns warning_count, status_flags.
local function _parse_eof_packet(packet)
    local warning_count, next_pos = _get_byte2(packet, 2)
    local status_flags = _get_byte2(packet, next_pos)

    return warning_count, status_flags
end
|
|
|
|
|
|
-- Parse an ERR packet.
-- Returns errno, message, sqlstate; sqlstate is nil unless the error number
-- is followed by a '#' marker and a 5-character state.
local function _parse_err_packet(packet)
    local errno, pos = _get_byte2(packet, 2)

    local sqlstate
    if sub(packet, pos, pos) == '#' then
        -- with sqlstate: '#' followed by 5 characters
        sqlstate = sub(packet, pos + 1, pos + 5)
        pos = pos + 6
    end

    local message = sub(packet, pos)
    return errno, message, sqlstate
end
|
|
|
|
|
|
-- Parse a result set header: two consecutive length-coded integers.
-- Returns field_count, extra.
local function _parse_result_set_header_packet(packet)
    local field_count, next_pos = _from_length_coded_bin(packet, 1)
    local extra = _from_length_coded_bin(packet, next_pos)

    return field_count, extra
end
|
|
|
|
|
|
-- Parse a column definition packet.
-- Only the column name and the type byte are kept; the returned table has
-- the fields "name" and "type". Everything else is read just to advance
-- the cursor.
local function _parse_field_packet(data)
    local col = new_tab(0, 2)
    local pos = 1
    local skipped

    -- catalog, db, table, orig_table
    for _ = 1, 4 do
        skipped, pos = _from_length_coded_str(data, pos)
    end

    col.name, pos = _from_length_coded_str(data, pos)

    -- orig_name
    skipped, pos = _from_length_coded_str(data, pos)

    pos = pos + 1 -- ignore the filler

    -- charsetnr (2 bytes) and length (4 bytes)
    skipped, pos = _get_byte2(data, pos)
    skipped, pos = _get_byte4(data, pos)

    col.type = strbyte(data, pos)

    -- the bytes after the type (flags, decimals, default) are not parsed

    return col
end
|
|
|
|
|
|
-- Parse one row packet into a Lua table.
-- cols is the array of column tables (with "name" and "type"). When compact
-- is truthy the row is an array indexed by column position, otherwise a hash
-- keyed by column name. NULL fields are stored as ngx.null; other values go
-- through the converter registered for the column type, if any.
local function _parse_row_data_packet(data, cols, compact)
    local ncols = #cols
    local row = compact and new_tab(ncols, 0) or new_tab(0, ncols)

    local pos = 1
    for i = 1, ncols do
        local value
        value, pos = _from_length_coded_str(data, pos)

        local col = cols[i]

        if value ~= null then
            local conv = converters[col.type]
            if conv then
                value = conv(value)
            end
        end

        if compact then
            row[i] = value
        else
            row[col.name] = value
        end
    end

    return row
end
|
|
|
|
|
|
-- Receive and parse one column definition packet.
-- Returns the column table, or nil, err (plus errno and sqlstate when the
-- server answered with an ERR packet).
local function _recv_field_packet(self)
    local packet, typ, err = _recv_packet(self)
    if not packet then
        return nil, err
    end

    if typ == RESP_DATA then
        return _parse_field_packet(packet)
    end

    if typ == RESP_ERR then
        local errno, msg, sqlstate = _parse_err_packet(packet)
        return nil, msg, errno, sqlstate
    end

    return nil, "bad field packet type: " .. typ
end
|
|
|
|
|
|
-- refer to https://dev.mysql.com/doc/internals/en/connection-phase-packets.html
|
|
-- Read and parse the server's initial handshake packet.
-- On success stores protocol_ver, _server_ver, _server_lang, _server_status
-- and capabilities on self and returns scramble, plugin name.
-- On failure returns nil, nil, err (plus errno and sqlstate when the server
-- sent an ERR packet instead of a handshake).
local function _read_hand_shake_packet(self)
    local packet, typ, err = _recv_packet(self)
    if not packet then
        return nil, nil, err
    end

    if typ == RESP_ERR then
        local errno, msg, sqlstate = _parse_err_packet(packet)
        return nil, nil, msg, errno, sqlstate
    end

    local protocol_ver = tonumber(strbyte(packet))
    if not protocol_ver then
        return nil, nil,
               "bad handshake initialization packet: bad protocol version"
    end

    if protocol_ver < MIN_PROTOCOL_VER then
        return nil, nil, "unsupported protocol version " .. protocol_ver
                         .. ", version " .. MIN_PROTOCOL_VER
                         .. " or higher is required"
    end

    self.protocol_ver = protocol_ver

    local server_ver, pos = _from_cstring(packet, 2)
    if not server_ver then
        return nil, nil,
               "bad handshake initialization packet: bad server version"
    end

    self._server_ver = server_ver

    local packet_len = #packet

    -- string.sub never returns nil, so a short packet has to be detected by
    -- length: thread id (4 bytes) + first scramble part (8 bytes)
    if packet_len < pos + 11 then
        return nil, nil, "1st part of scramble not found"
    end

    local _
    _, pos = _get_byte4(packet, pos) -- thread id, not used

    local scramble = sub(packet, pos, pos + 8 - 1)

    pos = pos + 9 -- skip filler(8 + 1)

    -- lower capabilities (2) + language (1) + status (2)
    -- + upper capabilities (2); reading past the end would raise an error
    -- inside the bit operations instead of reporting a bad packet
    if packet_len < pos + 6 then
        return nil, nil,
               "bad handshake initialization packet: truncated packet"
    end

    -- two lower bytes
    local capabilities -- server capabilities
    capabilities, pos = _get_byte2(packet, pos)

    self._server_lang = strbyte(packet, pos)
    pos = pos + 1

    self._server_status, pos = _get_byte2(packet, pos)

    local more_capabilities
    more_capabilities, pos = _get_byte2(packet, pos)

    self.capabilities = bor(capabilities, lshift(more_capabilities, 16))

    pos = pos + 11 -- skip length of auth-plugin-data(1) and reserved(10)

    -- follow official Python library uses the fixed length 12
    -- and the 13th byte is "\0" byte; a shorter (or absent) second part is
    -- tolerated and simply yields a shorter scramble
    local scramble_part2 = sub(packet, pos, pos + 12 - 1)

    pos = pos + 13

    local plugin = _from_cstring(packet, pos)
    if not plugin then
        -- EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
        -- \NUL otherwise
        plugin = sub(packet, pos)
    end

    return scramble .. scramble_part2, plugin
end
|
|
|
|
|
|
-- Prefix the auth response with its length-encoded size.
-- Responses of up to 250 bytes get a single length byte. Longer ones switch
-- on CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA in self.DEFAULT_CLIENT_FLAGS and
-- use the 0xfc (2-byte), 0xfd (3-byte) or 0xfe (8-byte) little-endian form.
-- Returns the prefixed string and its total length.
local function _append_auth_length(self, data)
    local n = #data

    if n <= 250 then
        return strchar(n) .. data, 1 + n
    end

    self.DEFAULT_CLIENT_FLAGS = bor(self.DEFAULT_CLIENT_FLAGS,
                                    CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA)

    if n <= 0xffff then
        return strchar(0xfc) .. _set_byte2(n) .. data, 3 + n
    end

    if n <= 0xffffff then
        return strchar(0xfd) .. _set_byte3(n) .. data, 4 + n
    end

    -- The bit library only honours the low 5 bits of a shift count, so
    -- rshift(n, 32) would be rshift(n, 0) and the upper four bytes would
    -- repeat the lower four. Split the value arithmetically instead.
    local high = to_int(n / 4294967296)

    return strchar(0xfe) .. _set_byte4(n) .. _set_byte4(high) .. data, 9 + n
end
|
|
|
|
|
|
-- Send the handshake response (preceded by an SSL request packet and the
-- TLS handshake when self.use_ssl is set).
-- Returns nil on success or an error message string on failure.
local function _write_hand_shake_response(self, auth_resp, plugin)
    -- done first: it may add a flag to self.DEFAULT_CLIENT_FLAGS
    local append_auth, len = _append_auth_length(self, auth_resp)

    -- flags (4) + max packet size (4) + charset (1) + 23 zero bytes
    local fixed_len = 4 + 4 + 1 + 23
    local function fixed_part(flags)
        return _set_byte4(flags)
               .. _set_byte4(self._max_packet_size)
               .. strchar(self.charset)
               .. strrep("\0", 23)
    end

    if self.use_ssl then
        if band(self.capabilities, CLIENT_SSL) == 0 then
            return "ssl disabled on server"
        end

        -- send a SSL Request Packet
        local ssl_req = fixed_part(bor(self.DEFAULT_CLIENT_FLAGS, CLIENT_SSL))

        local sent, send_err = _send_packet(self, ssl_req, fixed_len)
        if not sent then
            return "failed to send client authentication packet: " .. send_err
        end

        local ok, ssl_err = self.sock:sslhandshake(false, nil, self.ssl_verify)
        if not ok then
            return "failed to do ssl handshake: " .. (ssl_err or "")
        end
    end

    local user = self.user
    local database = self.database

    local req = fixed_part(self.DEFAULT_CLIENT_FLAGS)
                .. _to_cstring(user)
                .. append_auth
                .. _to_cstring(database)
                .. _to_cstring(plugin)

    local packet_len = fixed_len + #user + 1
                       + len + #database + 1 + #plugin + 1

    local sent, send_err = _send_packet(self, req, packet_len)
    if not sent then
        return "failed to send client authentication packet: " .. send_err
    end

    return nil
end
|
|
|
|
|
|
-- Read the server's reply to an authentication packet.
-- Return shapes (the third value is non-nil only on failure):
--   RESP_OK, ""                  authentication finished
--   data, ""                     "auth more data" payload
--   nil, "mysql_old_password"    one-byte EOF: switch to the old protocol
--   data, plugin                 auth switch request with new auth data
--   errno, sqlstate, msg         server sent an ERR packet
--   nil, nil, err                receive failure or unexpected packet
-- old_auth_data and the incoming plugin value are not consulted.
local function _read_auth_result(self, old_auth_data, plugin)
    local packet, typ, err = _recv_packet(self)
    if not packet then
        return nil, nil, "failed to receive the result packet: " .. err
    end

    if typ == RESP_OK then
        return RESP_OK, ""

    elseif typ == RESP_AUTHMOREDATA then
        return sub(packet, 2), ""

    elseif typ == RESP_ERR then
        local errno, msg, sqlstate = _parse_err_packet(packet)
        return errno, sqlstate, msg

    elseif typ ~= RESP_EOF then
        return nil, nil, "bad packet type: " .. typ
    end

    -- typ == RESP_EOF
    if #packet == 1 then -- old pre-4.1 authentication protocol
        return nil, "mysql_old_password"
    end

    local switch_to, pos = _from_cstring(packet, 2)
    if not switch_to then
        return nil, nil, "malformed packet"
    end

    return sub(packet, pos), switch_to
end
|
|
|
|
|
|
-- Expect an OK packet from the server.
-- Returns nothing on success; otherwise an error message (followed by errno
-- and sqlstate when the server answered with an ERR packet).
local function _read_ok_result(self)
    local packet, typ, err = _recv_packet(self)
    if not packet then
        return "failed to receive the result packet: " .. err
    end

    if typ == RESP_OK then
        return
    end

    if typ == RESP_ERR then
        local errno, msg, sqlstate = _parse_err_packet(packet)
        return msg, errno, sqlstate
    end

    return "bad packet type: " .. typ
end
|
|
|
|
|
|
-- XOR the NUL-terminated password with the (repeating) auth data and encrypt
-- the result with the server's RSA public key (PKCS8 key, OAEP padding).
-- Raises an error when resty.rsa is not installed.
-- Returns the encrypted string, or nil and an error message.
local function _encrypt_password(self, auth_data, public_key)
    if not has_rsa then
        error("auth plugin caching_sha2_password or sha256_password are not" ..
              " supported because resty.rsa is not installed", 2)
    end

    local password = _to_cstring(self.password)
    local n = #password
    local l = #auth_data

    if l == 0 then
        -- nothing to xor with; the modulo below would be undefined
        return nil, "missing auth data for password encryption"
    end

    local bytes = new_tab(n, 0)

    for i = 1, n do
        -- Lua strings are 1-based: cycle through 1..l. A plain "i % l"
        -- yields 0 whenever i is a multiple of l, for which strbyte returns
        -- nothing and bxor raises an error.
        local j = (i - 1) % l + 1
        bytes[i] = strchar(bxor(strbyte(password, i), strbyte(auth_data, j)))
    end

    local pub, err = resty_rsa:new({
        public_key = public_key,
        key_type = resty_rsa.KEY_TYPE.PKCS8,
        padding = resty_rsa.PADDING.RSA_PKCS1_OAEP_PADDING,
        algorithm = "sha1",
    })
    if not pub then
        return nil, "new rsa err: " .. err
    end

    local enc, err = pub:encrypt(concat(bytes))
    if not enc then
        return nil, "encode password packet: " .. err
    end

    return enc
end
|
|
|
|
|
|
-- Encrypt the password with the given public key and send it as one packet.
-- Returns nothing on success, or an error message string when either the
-- encryption or the send fails.
local function _write_encode_password(self, auth_data, public_key)
    local enc, err = _encrypt_password(self, auth_data, public_key)
    if not enc then
        -- without this check "#enc" below would raise on a nil value
        return err or "failed to encrypt password"
    end

    local bytes, send_err = _send_packet(self, enc, #enc)
    if not bytes then
        return "failed to send encode password packet: " .. send_err
    end
end
|
|
|
|
|
|
-- Compute the client's authentication response for the given plugin.
-- Returns the response string, or nil and an error message. For
-- "sha256_password" without a unix socket, SSL or a cached public key the
-- response is "\1", which asks the server for its public key.
local function _auth(self, auth_data, plugin)
    local password = self.password

    if plugin == "mysql_native_password" then
        return _compute_token(password, auth_data)

    elseif plugin == "mysql_old_password" then
        return _compute_old_token(password, auth_data)

    elseif plugin == "mysql_clear_password" then
        return _to_cstring(password)

    elseif plugin == "caching_sha2_password" then
        local auth_resp, err = _compute_sha256_token(password, auth_data)
        if err then
            return nil, "failed to compute sha256 token: " .. err
        end

        return auth_resp

    elseif plugin == "sha256_password" then
        -- the password is sent as-is when the channel is a unix socket or
        -- SSL, or when the password is empty
        if self.is_unix or self.use_ssl or #password == 0 then
            return _to_cstring(password)
        end

        local public_key = self.public_key
        if public_key then
            return _encrypt_password(self, auth_data, public_key)
        end

        return "\1" -- request public key from server
    end

    return nil, "unknown plugin: " .. plugin
end
|
|
|
|
|
|
-- Drive the rest of the authentication exchange after the handshake response
-- has been sent: handles a single auth switch request, the
-- caching_sha2_password fast/full paths and the sha256_password public key
-- exchange.
-- Returns nothing on success; otherwise err, errno, sqlstate (errno and
-- sqlstate are only present when the server sent an ERR packet).
local function _handle_auth_result(self, old_auth_data, plugin)
    local auth_data, new_plugin, err = _read_auth_result(self, old_auth_data,
                                                         plugin)

    if err ~= nil then
        -- on failure the first two values carry errno and sqlstate
        local errno, sqlstate = auth_data, new_plugin
        return err, errno, sqlstate
    end

    if auth_data == RESP_OK then
        return
    end

    if new_plugin ~= "" then
        -- auth switch request: answer with the new plugin, reusing the
        -- previous scramble when the server did not send a new one
        if not auth_data then
            auth_data = old_auth_data
        else
            old_auth_data = auth_data
        end

        plugin = new_plugin

        local auth_resp, err = _auth(self, auth_data, plugin)
        if not auth_resp then
            return err
        end

        local bytes, err = _send_packet(self, auth_resp, #auth_resp)
        if not bytes then
            return "failed to send client authentication packet: " .. err
        end

        auth_data, new_plugin, err = _read_auth_result(self, old_auth_data,
                                                       plugin)

        if err ~= nil then
            local errno, sqlstate = auth_data, new_plugin
            return err, errno, sqlstate
        end

        if auth_data == RESP_OK then
            return
        end

        -- a second switch request is not supported
        if new_plugin ~= "" then
            return "malformed packet"
        end
    end

    if plugin == "caching_sha2_password" then
        local len = #auth_data
        if len == 0 then
            return
        end

        if len == 1 then
            local status = strbyte(auth_data)

            -- caching_sha2_password fast auth success
            if status == 3 then
                return _read_ok_result(self)
            end

            -- caching_sha2_password perform full authentication
            if status == 4 then
                if self.is_unix or self.use_ssl then
                    local bytes, err = _send_packet(self,
                                                    _to_cstring(self.password),
                                                    #self.password + 1)
                    if not bytes then
                        return "failed to send cleartext auth packet: "
                               .. err
                    end

                else
                    local public_key = self.public_key
                    if not public_key then
                        -- caching_sha2_password request public_key
                        local bytes, err = _send_packet(self, "\2", 1)
                        if not bytes then
                            return "failed to send password request packet: "
                                   .. err
                        end

                        local packet, typ, err = _recv_packet(self)
                        if not packet then
                            return "failed to receive the result packet: "
                                   .. err
                        end

                        -- an ERR packet here must not be mistaken for a key
                        if typ == RESP_ERR then
                            local errno, msg, sqlstate
                                = _parse_err_packet(packet)
                            return msg, errno, sqlstate
                        end

                        public_key = sub(packet, 2)
                    end

                    local err = _write_encode_password(self, old_auth_data,
                                                       public_key)
                    if err then
                        return err
                    end

                    -- cache the key only once it has been used successfully
                    self.public_key = public_key
                end

                return _read_ok_result(self)
            end
        end

        return "malformed packet"
    end

    if plugin == "sha256_password" then
        if #auth_data ~= 0 then
            -- auth_data holds the server's public key here.
            -- _write_encode_password returns a single value (the error
            -- message), so it must be captured as the error, not skipped.
            local err = _write_encode_password(self, old_auth_data,
                                               auth_data)
            if err then
                return err
            end

            return _read_ok_result(self)
        end
    end
end
|
|
|
|
|
|
-- Create a new client object wrapping a fresh cosocket.
-- Returns the object, or nil and the error reported by ngx.socket.tcp.
function _M.new(self)
    local sock, err = tcp()
    if sock then
        return setmetatable({ sock = sock }, mt)
    end

    return nil, err
end
|
|
|
|
|
|
-- Forward the timeout to the underlying socket.
-- Returns nil, "not initialized" when the object has no socket.
function _M.set_timeout(self, timeout)
    local sock = self.sock
    if sock then
        return sock:settimeout(timeout)
    end

    return nil, "not initialized"
end
|
|
|
|
|
|
-- Connect to a MySQL server and authenticate.
-- opts fields read here: max_packet_size (default 1 MB), compact_arrays,
-- database, user, password, charset, pool, pool_size, backlog, ssl,
-- ssl_verify, and either host (with optional port, default 3306) or path
-- (unix domain socket).
-- Returns 1 on success. On failure returns nil, err, and additionally errno
-- and sqlstate when the server rejected the connection with an ERR packet.
-- A connection taken from the keepalive pool skips the handshake.
function _M.connect(self, opts)
    local sock = self.sock
    if not sock then
        return nil, "not initialized"
    end

    local max_packet_size = opts.max_packet_size
    if not max_packet_size then
        max_packet_size = 1024 * 1024 -- default 1 MB
    end
    self._max_packet_size = max_packet_size

    local ok, err

    self.compact = opts.compact_arrays

    self.database = opts.database or ""
    self.user = opts.user or ""

    self.charset = CHARSET_MAP[opts.charset or "_default"]
    if not self.charset then
        -- tostring keeps the message safe for non-string charset values
        return nil, "charset '" .. tostring(opts.charset)
                    .. "' is not supported"
    end

    local pool = opts.pool

    self.ssl_verify = opts.ssl_verify
    self.use_ssl = opts.ssl or opts.ssl_verify

    self.password = opts.password or ""

    local host = opts.host
    if host then
        local port = opts.port or 3306
        if not pool then
            pool = self.user .. ":" .. self.database .. ":" .. host .. ":"
                   .. port
        end

        -- clear a flag possibly left over from an earlier unix socket
        -- connection on this object, so that cleartext-over-unix-socket
        -- shortcuts in the auth code are never taken for a TCP connection
        self.is_unix = nil

        ok, err = sock:connect(host, port, { pool = pool,
                                             pool_size = opts.pool_size,
                                             backlog = opts.backlog })

    else
        local path = opts.path
        if not path then
            return nil, 'neither "host" nor "path" options are specified'
        end

        if not pool then
            pool = self.user .. ":" .. self.database .. ":" .. path
        end

        self.is_unix = true
        ok, err = sock:connect("unix:" .. path, { pool = pool,
                                                  pool_size = opts.pool_size,
                                                  backlog = opts.backlog })
    end

    if not ok then
        return nil, 'failed to connect: ' .. err
    end

    local reused = sock:getreusedtimes()

    if reused and reused > 0 then
        -- pooled connection: already authenticated
        self.state = STATE_CONNECTED
        return 1
    end

    self.DEFAULT_CLIENT_FLAGS = bor(DEFAULT_CLIENT_FLAGS, CLIENT_PLUGIN_AUTH)

    local auth_data, plugin, err, errno, sqlstate
        = _read_hand_shake_packet(self)

    if err ~= nil then
        -- errno/sqlstate are set when the server answered the connection
        -- attempt with an ERR packet instead of a handshake
        return nil, err, errno, sqlstate
    end

    local auth_resp, err = _auth(self, auth_data, plugin)
    if not auth_resp then
        return nil, err
    end

    err = _write_hand_shake_response(self, auth_resp, plugin)
    if err ~= nil then
        return nil, err
    end

    local err, errno, sqlstate = _handle_auth_result(self, auth_data, plugin)
    if err ~= nil then
        return nil, err, errno, sqlstate
    end

    self.state = STATE_CONNECTED

    return 1
end
|
|
|
|
|
|
-- Hand the connection to the keepalive pool.
-- Only allowed while the object is in the connected state; the state is
-- cleared before the socket is released. Extra arguments are passed through
-- to sock:setkeepalive.
function _M.set_keepalive(self, ...)
    local sock = self.sock
    if not sock then
        return nil, "not initialized"
    end

    local state = self.state
    if state ~= STATE_CONNECTED then
        return nil, "cannot be reused in the current connection state: "
                    .. (state or "nil")
    end

    self.state = nil
    return sock:setkeepalive(...)
end
|
|
|
|
|
|
-- Report how many times the underlying socket has been reused.
-- Returns nil, "not initialized" when the object has no socket.
function _M.get_reused_times(self)
    local sock = self.sock
    if sock then
        return sock:getreusedtimes()
    end

    return nil, "not initialized"
end
|
|
|
|
|
|
-- Send COM_QUIT and close the socket.
-- Returns the result of sock:close(), or nil and an error message when the
-- object has no socket or the quit packet cannot be sent.
function _M.close(self)
    local sock = self.sock
    if not sock then
        return nil, "not initialized"
    end

    self.state = nil

    -- COM_QUIT starts a new command, so the sequence number restarts at 0
    -- (_send_packet increments before sending). Without the reset the packet
    -- would carry a stale number, and on a pooled connection that never ran
    -- a query packet_no would still be nil and the increment would raise.
    self.packet_no = -1

    local bytes, err = _send_packet(self, strchar(COM_QUIT), 1)
    if not bytes then
        return nil, err
    end

    return sock:close()
end
|
|
|
|
|
|
-- Return the server version string recorded while reading the handshake
-- (nil if no handshake has been read on this object).
function _M.server_ver(self)
    local version = self._server_ver
    return version
end
|
|
|
|
|
|
-- Send a COM_QUERY packet carrying the given SQL text.
-- Requires the connected state; on success moves to the "command sent"
-- state and returns the value reported by the socket send. Returns nil and
-- an error message otherwise.
local function send_query(self, query)
    local state = self.state
    if state ~= STATE_CONNECTED then
        return nil, "cannot send query in the current context: "
                    .. (state or "nil")
    end

    if not self.sock then
        return nil, "not initialized"
    end

    -- a new command restarts the sequence: _send_packet increments to 0
    self.packet_no = -1

    local payload = strchar(COM_QUERY) .. query

    local bytes, err = _send_packet(self, payload, 1 + #query)
    if not bytes then
        return nil, err
    end

    self.state = STATE_COMMAND_SENT

    return bytes
end
_M.send_query = send_query
|
|
|
|
|
|
-- Read the response to a previously sent query.
-- est_nrows is an optional hint used to pre-size the rows table.
-- Returns:
--   res                      OK packet table, or array of rows
--   res, "again"             same, and more result sets follow
--   nil, err[, errno, sqlstate]   on failure (errno and sqlstate only when
--                                 the server sent an ERR packet)
-- The state goes back to connected once a complete response (including an
-- ERR packet) has been consumed and no further result set is pending.
local function read_result(self, est_nrows)
    if self.state ~= STATE_COMMAND_SENT then
        return nil, "cannot read result in the current context: "
                    .. (self.state or "nil")
    end

    local sock = self.sock
    if not sock then
        return nil, "not initialized"
    end

    local packet, typ, err = _recv_packet(self)
    if not packet then
        return nil, err
    end

    if typ == RESP_ERR then
        self.state = STATE_CONNECTED

        local errno, msg, sqlstate = _parse_err_packet(packet)
        return nil, msg, errno, sqlstate
    end

    if typ == RESP_OK then
        local res = _parse_ok_packet(packet)
        if res and band(res.server_status, SERVER_MORE_RESULTS_EXISTS) ~= 0 then
            return res, "again"
        end

        self.state = STATE_CONNECTED
        return res
    end

    if typ == RESP_LOCALINFILE then
        self.state = STATE_CONNECTED

        return nil, "packet type " .. typ .. " not supported"
    end

    -- typ == RESP_DATA or RESP_AUTHMOREDATA(also mean RESP_DATA here)

    local field_count, extra = _parse_result_set_header_packet(packet)

    local cols = new_tab(field_count, 0)
    for i = 1, field_count do
        local col, err, errno, sqlstate = _recv_field_packet(self)
        if not col then
            return nil, err, errno, sqlstate
        end

        cols[i] = col
    end

    -- the column definitions are terminated by an EOF packet
    local packet, typ, err = _recv_packet(self)
    if not packet then
        return nil, err
    end

    if typ ~= RESP_EOF then
        return nil, "unexpected packet type " .. typ .. " while eof packet is "
                    .. "expected"
    end

    -- typ == RESP_EOF

    local compact = self.compact

    local rows = new_tab(est_nrows or 4, 0)
    local i = 0
    while true do
        packet, typ, err = _recv_packet(self)
        if not packet then
            return nil, err
        end

        if typ == RESP_EOF then
            local warning_count, status_flags = _parse_eof_packet(packet)

            if band(status_flags, SERVER_MORE_RESULTS_EXISTS) ~= 0 then
                return rows, "again"
            end

            break
        end

        -- The server may abort a result set with an ERR packet in place of a
        -- row. A row can never start with 0xff, so this is unambiguous;
        -- report the error instead of decoding the packet as row data.
        if typ == RESP_ERR then
            self.state = STATE_CONNECTED

            local errno, msg, sqlstate = _parse_err_packet(packet)
            return nil, msg, errno, sqlstate
        end

        local row = _parse_row_data_packet(packet, cols, compact)
        i = i + 1
        rows[i] = row
    end

    self.state = STATE_CONNECTED

    return rows
end
_M.read_result = read_result
|
|
|
|
|
|
-- Convenience wrapper: send the query, then read its first result.
-- Returns whatever read_result returns, or nil and an error message when the
-- query could not be sent.
function _M.query(self, query, est_nrows)
    local bytes, err = send_query(self, query)
    if bytes then
        return read_result(self, est_nrows)
    end

    return nil, "failed to send query: " .. err
end
|
|
|
|
|
|
-- Choose the row format used by read_result: a truthy value makes rows
-- arrays indexed by column position, a falsy one makes them hashes keyed by
-- column name.
function _M.set_compact_arrays(self, value)
    local client = self
    client.compact = value
end


return _M
|