Simplify DLL hook mechanism

This change deletes the GetProcAddress hook and exports symbols
corresponding to the hooked functions from each hook DLL instead;
we stop at redirecting LoadLibrary/GetModuleHandle calls to the
hook DLL. This simplified approach has less hidden magic going on
behind the scenes and is more readily composable (i.e. a hook DLL
can export redirect symbols for more than one dynamically-loaded
DLL).
This commit is contained in:
Tau 2021-05-22 12:29:39 -04:00
parent b4cd868f85
commit 45e2416702
11 changed files with 84 additions and 126 deletions

View File

@ -1,3 +1,8 @@
LIBRARY chunihook LIBRARY chunihook
EXPORTS EXPORTS
Direct3DCreate9
amDllVideoClose @2
amDllVideoGetVBiosVersion @4
amDllVideoOpen @1
amDllVideoSetResolution @3

View File

@ -60,7 +60,7 @@ static DWORD CALLBACK chuni_pre_startup(void)
/* Hook Win32 APIs */ /* Hook Win32 APIs */
gfx_hook_init(&chuni_hook_cfg.gfx); gfx_hook_init(&chuni_hook_cfg.gfx, chuni_hook_mod);
serial_hook_init(); serial_hook_init();
/* Initialize emulation hooks */ /* Initialize emulation hooks */

View File

@ -1 +1,8 @@
LIBRARY divahook LIBRARY divahook
EXPORTS
Direct3DCreate9
amDllVideoClose @2
amDllVideoGetVBiosVersion @4
amDllVideoOpen @1
amDllVideoSetResolution @3

View File

@ -15,8 +15,6 @@
struct dll_hook_reg { struct dll_hook_reg {
const wchar_t *name; const wchar_t *name;
HMODULE redir_mod; HMODULE redir_mod;
const struct hook_symbol *syms;
size_t nsyms;
}; };
/* Helper functions */ /* Helper functions */
@ -26,6 +24,7 @@ static HMODULE dll_hook_search_dll(const wchar_t *name);
/* Hook functions */ /* Hook functions */
static BOOL WINAPI hook_FreeLibrary(HMODULE mod);
static HMODULE WINAPI hook_GetModuleHandleA(const char *name); static HMODULE WINAPI hook_GetModuleHandleA(const char *name);
static HMODULE WINAPI hook_GetModuleHandleW(const wchar_t *name); static HMODULE WINAPI hook_GetModuleHandleW(const wchar_t *name);
static HMODULE WINAPI hook_LoadLibraryA(const char *name); static HMODULE WINAPI hook_LoadLibraryA(const char *name);
@ -34,6 +33,7 @@ static void * WINAPI hook_GetProcAddress(HMODULE mod, const char *name);
/* Link pointers */ /* Link pointers */
static BOOL (WINAPI *next_FreeLibrary)(HMODULE mod);
static HMODULE (WINAPI *next_GetModuleHandleA)(const char *name); static HMODULE (WINAPI *next_GetModuleHandleA)(const char *name);
static HMODULE (WINAPI *next_GetModuleHandleW)(const wchar_t *name); static HMODULE (WINAPI *next_GetModuleHandleW)(const wchar_t *name);
static HMODULE (WINAPI *next_LoadLibraryA)(const char *name); static HMODULE (WINAPI *next_LoadLibraryA)(const char *name);
@ -42,6 +42,10 @@ static void * (WINAPI *next_GetProcAddress)(HMODULE mod, const char *name);
static const struct hook_symbol dll_loader_syms[] = { static const struct hook_symbol dll_loader_syms[] = {
{ {
.name = "FreeLibrary",
.patch = hook_FreeLibrary,
.link = (void **) &next_FreeLibrary,
}, {
.name = "GetModuleHandleA", .name = "GetModuleHandleA",
.patch = hook_GetModuleHandleA, .patch = hook_GetModuleHandleA,
.link = (void **) &next_GetModuleHandleA, .link = (void **) &next_GetModuleHandleA,
@ -57,10 +61,6 @@ static const struct hook_symbol dll_loader_syms[] = {
.name = "LoadLibraryW", .name = "LoadLibraryW",
.patch = hook_LoadLibraryW, .patch = hook_LoadLibraryW,
.link = (void **) &next_LoadLibraryW, .link = (void **) &next_LoadLibraryW,
}, {
.name = "GetProcAddress",
.patch = hook_GetProcAddress,
.link = (void **) &next_GetProcAddress,
} }
}; };
@ -71,16 +71,13 @@ static size_t dll_hook_count;
HRESULT dll_hook_push( HRESULT dll_hook_push(
HMODULE redir_mod, HMODULE redir_mod,
const wchar_t *name, const wchar_t *name)
const struct hook_symbol *syms,
size_t nsyms)
{ {
struct dll_hook_reg *new_item; struct dll_hook_reg *new_item;
struct dll_hook_reg *new_mem; struct dll_hook_reg *new_mem;
HRESULT hr; HRESULT hr;
assert(name != NULL); assert(name != NULL);
assert(syms != NULL);
dll_hook_init(); dll_hook_init();
@ -99,8 +96,6 @@ HRESULT dll_hook_push(
new_item = &new_mem[dll_hook_count]; new_item = &new_mem[dll_hook_count];
new_item->name = name; new_item->name = name;
new_item->redir_mod = redir_mod; new_item->redir_mod = redir_mod;
new_item->syms = syms;
new_item->nsyms = nsyms;
dll_hook_list = new_mem; dll_hook_list = new_mem;
dll_hook_count++; dll_hook_count++;
@ -167,6 +162,39 @@ static HMODULE dll_hook_search_dll(const wchar_t *name)
return result; return result;
} }
static BOOL WINAPI hook_FreeLibrary(HMODULE mod)
{
bool match;
size_t i;
match = false;
EnterCriticalSection(&dll_hook_lock);
for (i = 0 ; i < dll_hook_count ; i++) {
if (mod == dll_hook_list[i].redir_mod) {
match = true;
break;
}
}
LeaveCriticalSection(&dll_hook_lock);
if (match) {
/* Block attempts to unload redirected modules, since this could cause
a hook DLL to unexpectedly vanish and crash the whole application.
Reference counting might be another solution, although it is
possible that a buggy application might cause a hook DLL unload in
that case. */
SetLastError(ERROR_SUCCESS);
return TRUE;
}
return next_FreeLibrary(mod);
}
static HMODULE WINAPI hook_GetModuleHandleA(const char *name) static HMODULE WINAPI hook_GetModuleHandleA(const char *name)
{ {
HMODULE result; HMODULE result;
@ -262,69 +290,3 @@ static HMODULE WINAPI hook_LoadLibraryW(const wchar_t *name)
} }
/* TODO LoadLibraryExA, LoadLibraryExW */ /* TODO LoadLibraryExA, LoadLibraryExW */
static void * WINAPI hook_GetProcAddress(HMODULE mod, const char *name)
{
const struct hook_symbol *syms;
uintptr_t ordinal;
size_t nsyms;
size_t i;
if (name == NULL) {
SetLastError(ERROR_INVALID_PARAMETER);
return NULL;
}
syms = NULL;
nsyms = 0;
EnterCriticalSection(&dll_hook_lock);
for (i = 0 ; i < dll_hook_count ; i++) {
if (dll_hook_list[i].redir_mod == mod) {
syms = dll_hook_list[i].syms;
nsyms = dll_hook_list[i].nsyms;
break;
}
}
LeaveCriticalSection(&dll_hook_lock);
if (syms == NULL) {
return next_GetProcAddress(mod, name);
}
ordinal = (uintptr_t) name;
if (ordinal > 0xFFFF) {
/* Import by name */
for (i = 0 ; i < nsyms ; i++) {
if (strcmp(name, syms[i].name) == 0) {
break;
}
}
} else {
/* Import by ordinal (and name != NULL so ordinal != 0) */
for (i = 0 ; i < nsyms ; i++) {
if (ordinal == syms[i].ordinal) {
break;
}
}
}
if (i < nsyms) {
SetLastError(ERROR_SUCCESS);
return syms[i].patch;
} else {
/* GetProcAddress sets this error on failure, although of course MSDN
does not see fit to document the exact error code. */
SetLastError(ERROR_PROC_NOT_FOUND);
return NULL;
}
}

View File

@ -3,10 +3,6 @@
#include <stddef.h> #include <stddef.h>
#include <stdint.h> #include <stdint.h>
#include "hook/table.h"
HRESULT dll_hook_push( HRESULT dll_hook_push(
HMODULE redir_mod, HMODULE redir_mod,
const wchar_t *name, const wchar_t *name);
const struct hook_symbol *syms,
size_t nsyms);

View File

@ -24,21 +24,20 @@ static HRESULT STDMETHODCALLTYPE my_CreateDevice(
DWORD flags, DWORD flags,
D3DPRESENT_PARAMETERS *pp, D3DPRESENT_PARAMETERS *pp,
IDirect3DDevice9 **pdev); IDirect3DDevice9 **pdev);
static IDirect3D9 * WINAPI my_Direct3DCreate9(UINT sdk_ver);
static Direct3DCreate9_t next_Direct3DCreate9;
static HRESULT gfx_frame_window(HWND hwnd); static HRESULT gfx_frame_window(HWND hwnd);
static struct gfx_config gfx_config; static struct gfx_config gfx_config;
static Direct3DCreate9_t next_Direct3DCreate9;
static const struct hook_symbol gfx_hooks[] = { static const struct hook_symbol gfx_hooks[] = {
{ {
.name = "Direct3DCreate9", .name = "Direct3DCreate9",
.patch = my_Direct3DCreate9, .patch = Direct3DCreate9,
.link = (void **) &next_Direct3DCreate9 .link = (void **) &next_Direct3DCreate9
}, },
}; };
void gfx_hook_init(const struct gfx_config *cfg) void gfx_hook_init(const struct gfx_config *cfg, HINSTANCE self)
{ {
HMODULE d3d9; HMODULE d3d9;
@ -71,13 +70,15 @@ void gfx_hook_init(const struct gfx_config *cfg)
} }
} }
dll_hook_push(NULL, L"d3d9.dll", gfx_hooks, _countof(gfx_hooks)); if (self != NULL) {
dll_hook_push(self, L"d3d9.dll");
}
fail: fail:
return; return;
} }
static IDirect3D9 * WINAPI my_Direct3DCreate9(UINT sdk_ver) IDirect3D9 * WINAPI Direct3DCreate9(UINT sdk_ver)
{ {
struct com_proxy *proxy; struct com_proxy *proxy;
IDirect3D9Vtbl *vtbl; IDirect3D9Vtbl *vtbl;

View File

@ -1,5 +1,7 @@
#pragma once #pragma once
#include <windows.h>
#include <stdbool.h> #include <stdbool.h>
struct gfx_config { struct gfx_config {
@ -9,4 +11,4 @@ struct gfx_config {
int monitor; int monitor;
}; };
void gfx_hook_init(const struct gfx_config *cfg); void gfx_hook_init(const struct gfx_config *cfg, HINSTANCE self);

View File

@ -1 +1,8 @@
LIBRARY idzhook LIBRARY idzhook
EXPORTS
Direct3DCreate9
amDllVideoClose @2
amDllVideoGetVBiosVersion @4
amDllVideoOpen @1
amDllVideoSetResolution @3

View File

@ -35,7 +35,7 @@ static DWORD CALLBACK mu3_pre_startup(void)
/* Hook Win32 APIs */ /* Hook Win32 APIs */
gfx_hook_init(&mu3_hook_cfg.gfx); gfx_hook_init(&mu3_hook_cfg.gfx, mu3_hook_mod);
serial_hook_init(); serial_hook_init();
/* Initialize emulation hooks */ /* Initialize emulation hooks */

View File

@ -1 +1,8 @@
LIBRARY mu3hook LIBRARY mu3hook
EXPORTS
Direct3DCreate9
amDllVideoClose @2
amDllVideoGetVBiosVersion @4
amDllVideoOpen @1
amDllVideoSetResolution @3

View File

@ -15,11 +15,6 @@
/* Hook functions */ /* Hook functions */
static int amDllVideoOpen(void *ctx);
static int amDllVideoClose(void *ctx);
static int amDllVideoSetResolution(void *ctx, void *param);
static int amDllVideoGetVBiosVersion(void *ctx, char *dest, size_t nchars);
static HRESULT amvideo_reg_read_name(void *bytes, uint32_t *nbytes); static HRESULT amvideo_reg_read_name(void *bytes, uint32_t *nbytes);
static HRESULT amvideo_reg_read_port_X(void *bytes, uint32_t *nbytes); static HRESULT amvideo_reg_read_port_X(void *bytes, uint32_t *nbytes);
static HRESULT amvideo_reg_read_resolution_1(void *bytes, uint32_t *nbytes); static HRESULT amvideo_reg_read_resolution_1(void *bytes, uint32_t *nbytes);
@ -88,26 +83,6 @@ static const struct reg_hook_val amvideo_reg_mode_vals[] = {
} }
}; };
static const struct hook_symbol amvideo_syms[] = {
{
.ordinal = 1,
.name = "amDllVideoOpen",
.patch = amDllVideoOpen,
}, {
.ordinal = 2,
.name = "amDllVideoClose",
.patch = amDllVideoClose,
}, {
.ordinal = 3,
.name = "amDllVideoSetResolution",
.patch = amDllVideoSetResolution,
}, {
.ordinal = 4,
.name = "amDllVideoGetVBiosVersion",
.patch = amDllVideoGetVBiosVersion,
}
};
HRESULT amvideo_hook_init(const struct amvideo_config *cfg, HMODULE redir_mod) HRESULT amvideo_hook_init(const struct amvideo_config *cfg, HMODULE redir_mod)
{ {
HRESULT hr; HRESULT hr;
@ -138,11 +113,7 @@ HRESULT amvideo_hook_init(const struct amvideo_config *cfg, HMODULE redir_mod)
return hr; return hr;
} }
hr = dll_hook_push( hr = dll_hook_push(redir_mod, amvideo_dll_name);
redir_mod,
amvideo_dll_name,
amvideo_syms,
_countof(amvideo_syms));
if (FAILED(hr)) { if (FAILED(hr)) {
return hr; return hr;
@ -151,28 +122,28 @@ HRESULT amvideo_hook_init(const struct amvideo_config *cfg, HMODULE redir_mod)
return S_OK; return S_OK;
} }
static int amDllVideoOpen(void *ctx) int amDllVideoOpen(void *ctx)
{ {
dprintf("AmVideo: %s\n", __func__); dprintf("AmVideo: %s\n", __func__);
return 0; return 0;
} }
static int amDllVideoClose(void *ctx) int amDllVideoClose(void *ctx)
{ {
dprintf("AmVideo: %s\n", __func__); dprintf("AmVideo: %s\n", __func__);
return 0; return 0;
} }
static int amDllVideoSetResolution(void *ctx, void *param) int amDllVideoSetResolution(void *ctx, void *param)
{ {
dprintf("AmVideo: %s\n", __func__); dprintf("AmVideo: %s\n", __func__);
return 0; return 0;
} }
static int amDllVideoGetVBiosVersion(void *ctx, char *dest, size_t nchars) int amDllVideoGetVBiosVersion(void *ctx, char *dest, size_t nchars)
{ {
dprintf("AmVideo: %s\n", __func__); dprintf("AmVideo: %s\n", __func__);
strcpy(dest, "01.02.03.04.05"); strcpy(dest, "01.02.03.04.05");