Commit 5788ee9f authored by Rob Shearman's avatar Rob Shearman Committed by Alexandre Julliard

ole32: Implement CoRegisterChannelHook and call channel hook methods on the client side.

parent 1dc5dec6
...@@ -3297,6 +3297,26 @@ HRESULT WINAPI CoGetObject(LPCWSTR pszName, BIND_OPTS *pBindOptions, ...@@ -3297,6 +3297,26 @@ HRESULT WINAPI CoGetObject(LPCWSTR pszName, BIND_OPTS *pBindOptions,
} }
/*********************************************************************** /***********************************************************************
* CoRegisterChannelHook [OLE32.@]
*
* Registers a process-wide hook that is called during ORPC calls.
*
* PARAMS
* guidExtension [I] GUID of the channel hook to register.
* pChannelHook [I] Channel hook object to register.
*
* RETURNS
* Success: S_OK.
* Failure: HRESULT code.
*/
HRESULT WINAPI CoRegisterChannelHook(REFGUID guidExtension, IChannelHook *pChannelHook)
{
TRACE("(%s, %p)\n", debugstr_guid(guidExtension), pChannelHook);
return RPC_RegisterChannelHook(guidExtension, pChannelHook);
}
/***********************************************************************
* DllMain (OLE32.@) * DllMain (OLE32.@)
*/ */
BOOL WINAPI DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID fImpLoad) BOOL WINAPI DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID fImpLoad)
...@@ -3313,6 +3333,7 @@ BOOL WINAPI DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID fImpLoad) ...@@ -3313,6 +3333,7 @@ BOOL WINAPI DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID fImpLoad)
case DLL_PROCESS_DETACH: case DLL_PROCESS_DETACH:
if (TRACE_ON(ole)) CoRevokeMallocSpy(); if (TRACE_ON(ole)) CoRevokeMallocSpy();
COMPOBJ_UninitProcess(); COMPOBJ_UninitProcess();
RPC_UnregisterAllChannelHooks();
OLE32_hInstance = 0; OLE32_hInstance = 0;
break; break;
......
...@@ -219,6 +219,8 @@ HRESULT RPC_RegisterInterface(REFIID riid); ...@@ -219,6 +219,8 @@ HRESULT RPC_RegisterInterface(REFIID riid);
void RPC_UnregisterInterface(REFIID riid); void RPC_UnregisterInterface(REFIID riid);
void RPC_StartLocalServer(REFCLSID clsid, IStream *stream); void RPC_StartLocalServer(REFCLSID clsid, IStream *stream);
HRESULT RPC_GetLocalClassObject(REFCLSID rclsid, REFIID iid, LPVOID *ppv); HRESULT RPC_GetLocalClassObject(REFCLSID rclsid, REFIID iid, LPVOID *ppv);
HRESULT RPC_RegisterChannelHook(REFGUID rguid, IChannelHook *hook);
void RPC_UnregisterAllChannelHooks(void);
/* This function initialize the Running Object Table */ /* This function initialize the Running Object Table */
HRESULT WINAPI RunningObjectTableImpl_Initialize(void); HRESULT WINAPI RunningObjectTableImpl_Initialize(void);
......
...@@ -52,7 +52,7 @@ ...@@ -52,7 +52,7 @@
@ stdcall CoQueryClientBlanket(ptr ptr ptr ptr ptr ptr ptr) @ stdcall CoQueryClientBlanket(ptr ptr ptr ptr ptr ptr ptr)
@ stdcall CoQueryProxyBlanket(ptr ptr ptr ptr ptr ptr ptr ptr) @ stdcall CoQueryProxyBlanket(ptr ptr ptr ptr ptr ptr ptr ptr)
@ stub CoQueryReleaseObject @ stub CoQueryReleaseObject
@ stub CoRegisterChannelHook @ stdcall CoRegisterChannelHook(ptr ptr)
@ stdcall CoRegisterClassObject(ptr ptr long long ptr) @ stdcall CoRegisterClassObject(ptr ptr long long ptr)
@ stdcall CoRegisterMallocSpy (ptr) @ stdcall CoRegisterMallocSpy (ptr)
@ stdcall CoRegisterMessageFilter(ptr ptr) @ stdcall CoRegisterMessageFilter(ptr ptr)
......
...@@ -68,6 +68,16 @@ static CRITICAL_SECTION_DEBUG csRegIf_debug = ...@@ -68,6 +68,16 @@ static CRITICAL_SECTION_DEBUG csRegIf_debug =
}; };
static CRITICAL_SECTION csRegIf = { &csRegIf_debug, -1, 0, 0, 0, 0 }; static CRITICAL_SECTION csRegIf = { &csRegIf_debug, -1, 0, 0, 0, 0 };
static struct list channel_hooks = LIST_INIT(channel_hooks); /* (CS csChannelHook) */
static CRITICAL_SECTION csChannelHook;
static CRITICAL_SECTION_DEBUG csChannelHook_debug =
{
0, 0, &csChannelHook,
{ &csChannelHook_debug.ProcessLocksList, &csChannelHook_debug.ProcessLocksList },
0, 0, { (DWORD_PTR)(__FILE__ ": channel hooks") }
};
static CRITICAL_SECTION csChannelHook = { &csChannelHook_debug, -1, 0, 0, 0, 0 };
static WCHAR wszRpcTransport[] = {'n','c','a','l','r','p','c',0}; static WCHAR wszRpcTransport[] = {'n','c','a','l','r','p','c',0};
...@@ -127,6 +137,145 @@ typedef struct ...@@ -127,6 +137,145 @@ typedef struct
/* [size_is((size+7)&~7)] */ unsigned char data[1]; /* [size_is((size+7)&~7)] */ unsigned char data[1];
} WIRE_ORPC_EXTENT; } WIRE_ORPC_EXTENT;
struct channel_hook_entry
{
struct list entry;
GUID id;
IChannelHook *hook;
};
struct channel_hook_buffer_data
{
GUID id;
ULONG extension_size;
};
/* Channel Hook Functions */
static ULONG ChannelHooks_ClientGetSize(SChannelHookCallInfo *info,
struct channel_hook_buffer_data **data, unsigned int *hook_count,
ULONG *extension_count)
{
struct channel_hook_entry *entry;
ULONG total_size = 0;
unsigned int hook_index = 0;
*hook_count = 0;
*extension_count = 0;
EnterCriticalSection(&csChannelHook);
LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
(*hook_count)++;
if (hook_count)
*data = HeapAlloc(GetProcessHeap(), 0, *hook_count * sizeof(struct channel_hook_buffer_data));
else
*data = NULL;
LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
{
ULONG extension_size = 0;
IChannelHook_ClientGetSize(entry->hook, &entry->id, &info->iid, &extension_size);
TRACE("%s: extension_size = %u\n", debugstr_guid(&entry->id), extension_size);
extension_size = (extension_size+7)&~7;
(*data)[hook_index].id = entry->id;
(*data)[hook_index].extension_size = extension_size;
/* an extension is only put onto the wire if it has data to write */
if (extension_size)
{
total_size += FIELD_OFFSET(WIRE_ORPC_EXTENT, data[extension_size]);
(*extension_count)++;
}
hook_index++;
}
LeaveCriticalSection(&csChannelHook);
return total_size;
}
static unsigned char * ChannelHooks_ClientFillBuffer(SChannelHookCallInfo *info,
unsigned char *buffer, struct channel_hook_buffer_data *data,
unsigned int hook_count)
{
struct channel_hook_entry *entry;
EnterCriticalSection(&csChannelHook);
LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
{
unsigned int i;
ULONG extension_size = 0;
WIRE_ORPC_EXTENT *wire_orpc_extent = (WIRE_ORPC_EXTENT *)buffer;
for (i = 0; i < hook_count; i++)
if (IsEqualGUID(&entry->id, &data[i].id))
extension_size = data[i].extension_size;
/* an extension is only put onto the wire if it has data to write */
if (!extension_size)
continue;
IChannelHook_ClientFillBuffer(entry->hook, &entry->id, &info->iid,
&extension_size, buffer + FIELD_OFFSET(WIRE_ORPC_EXTENT, data[0]));
TRACE("%s: extension_size = %u\n", debugstr_guid(&entry->id), extension_size);
/* FIXME: set unused portion of wire_orpc_extent->data to 0? */
wire_orpc_extent->conformance = (extension_size+7)&~7;
wire_orpc_extent->size = extension_size;
memcpy(&wire_orpc_extent->id, &entry->id, sizeof(wire_orpc_extent->id));
buffer += FIELD_OFFSET(WIRE_ORPC_EXTENT, data[wire_orpc_extent->conformance]);
}
LeaveCriticalSection(&csChannelHook);
HeapFree(GetProcessHeap(), 0, data);
return buffer;
}
HRESULT RPC_RegisterChannelHook(REFGUID rguid, IChannelHook *hook)
{
struct channel_hook_entry *entry;
TRACE("(%s, %p)\n", debugstr_guid(rguid), hook);
entry = HeapAlloc(GetProcessHeap(), 0, sizeof(*entry));
if (!entry)
return E_OUTOFMEMORY;
memcpy(&entry->id, rguid, sizeof(entry->id));
entry->hook = hook;
IChannelHook_AddRef(hook);
EnterCriticalSection(&csChannelHook);
list_add_tail(&channel_hooks, &entry->entry);
LeaveCriticalSection(&csChannelHook);
return S_OK;
}
void RPC_UnregisterAllChannelHooks(void)
{
struct channel_hook_entry *cursor;
struct channel_hook_entry *cursor2;
EnterCriticalSection(&csChannelHook);
LIST_FOR_EACH_ENTRY_SAFE(cursor, cursor2, &channel_hooks, struct channel_hook_entry, entry)
HeapFree(GetProcessHeap(), 0, cursor);
LeaveCriticalSection(&csChannelHook);
}
/* RPC Channel Buffer Functions */
static HRESULT WINAPI RpcChannelBuffer_QueryInterface(LPRPCCHANNELBUFFER iface, REFIID riid, LPVOID *ppv) static HRESULT WINAPI RpcChannelBuffer_QueryInterface(LPRPCCHANNELBUFFER iface, REFIID riid, LPVOID *ppv)
{ {
...@@ -207,6 +356,11 @@ static HRESULT WINAPI ClientRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface, ...@@ -207,6 +356,11 @@ static HRESULT WINAPI ClientRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface,
RPC_STATUS status; RPC_STATUS status;
ORPCTHIS *orpcthis; ORPCTHIS *orpcthis;
struct message_state *message_state; struct message_state *message_state;
ULONG extensions_size;
struct channel_hook_buffer_data *channel_hook_data;
unsigned int channel_hook_count;
ULONG extension_count;
SChannelHookCallInfo channel_hook_info;
TRACE("(%p)->(%p,%s)\n", This, olemsg, debugstr_guid(riid)); TRACE("(%p)->(%p,%s)\n", This, olemsg, debugstr_guid(riid));
...@@ -230,8 +384,24 @@ static HRESULT WINAPI ClientRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface, ...@@ -230,8 +384,24 @@ static HRESULT WINAPI ClientRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface,
msg->Handle = This->bind; msg->Handle = This->bind;
msg->RpcInterfaceInformation = cif; msg->RpcInterfaceInformation = cif;
channel_hook_info.iid = *riid;
channel_hook_info.cbSize = sizeof(channel_hook_info);
channel_hook_info.uCausality = GUID_NULL; /* FIXME */
channel_hook_info.dwServerPid = 0; /* FIXME */
channel_hook_info.iMethod = msg->ProcNum;
channel_hook_info.pObject = NULL; /* only present on server-side */
extensions_size = ChannelHooks_ClientGetSize(&channel_hook_info,
&channel_hook_data, &channel_hook_count, &extension_count);
msg->BufferLength += FIELD_OFFSET(ORPCTHIS, extensions) + 4; msg->BufferLength += FIELD_OFFSET(ORPCTHIS, extensions) + 4;
if (extensions_size)
{
msg->BufferLength += FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent) + 2*sizeof(DWORD) + extensions_size;
if (extension_count & 1)
msg->BufferLength += FIELD_OFFSET(WIRE_ORPC_EXTENT, data[0]);
}
status = I_RpcGetBuffer(msg); status = I_RpcGetBuffer(msg);
message_state->prefix_data_len = 0; message_state->prefix_data_len = 0;
...@@ -245,14 +415,42 @@ static HRESULT WINAPI ClientRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface, ...@@ -245,14 +415,42 @@ static HRESULT WINAPI ClientRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface,
orpcthis->version.MajorVersion = COM_MAJOR_VERSION; orpcthis->version.MajorVersion = COM_MAJOR_VERSION;
orpcthis->version.MinorVersion = COM_MINOR_VERSION; orpcthis->version.MinorVersion = COM_MINOR_VERSION;
orpcthis->flags = ORPCF_NULL; orpcthis->flags = channel_hook_info.dwServerPid ? ORPCF_LOCAL : ORPCF_NULL;
orpcthis->reserved1 = 0; orpcthis->reserved1 = 0;
orpcthis->cid = GUID_NULL; /* FIXME */ orpcthis->cid = channel_hook_info.uCausality;
/* NDR representation of orpcthis->extensions */ /* NDR representation of orpcthis->extensions */
*(DWORD *)msg->Buffer = 0; /* FIXME */ *(DWORD *)msg->Buffer = extensions_size ? 1 : 0;
msg->Buffer = (char *)msg->Buffer + sizeof(DWORD); msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
if (extensions_size)
{
ORPC_EXTENT_ARRAY *orpc_extent_array = msg->Buffer;
orpc_extent_array->size = extension_count;
orpc_extent_array->reserved = 0;
msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent);
/* NDR representation of orpc_extent_array->extent */
*(DWORD *)msg->Buffer = 1;
msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
/* NDR representation of [size_is] attribute of orpc_extent_array->extent */
*(DWORD *)msg->Buffer = (extension_count + 1) & ~1;
msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
msg->Buffer = ChannelHooks_ClientFillBuffer(&channel_hook_info,
msg->Buffer, channel_hook_data, channel_hook_count);
/* we must add a dummy extension if there is an odd extension
* count to meet the contract specified by the size_is attribute */
if (extension_count & 1)
{
WIRE_ORPC_EXTENT *wire_orpc_extent = msg->Buffer;
wire_orpc_extent->conformance = 0;
memcpy(&wire_orpc_extent->id, &GUID_NULL, sizeof(wire_orpc_extent->id));
wire_orpc_extent->size = 0;
msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(WIRE_ORPC_EXTENT, data[0]);
}
}
/* store the prefixed data length so that we can restore the real buffer /* store the prefixed data length so that we can restore the real buffer
* pointer in ClientRpcChannelBuffer_SendReceive. */ * pointer in ClientRpcChannelBuffer_SendReceive. */
message_state->prefix_data_len = (char *)msg->Buffer - (char *)orpcthis; message_state->prefix_data_len = (char *)msg->Buffer - (char *)orpcthis;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment