Commit 1dc5dec6 authored by Rob Shearman's avatar Rob Shearman Committed by Alexandre Julliard

ole32: Marshal the ORPCTHIS structure prefixed to the client data when doing ORPC calls.

This is done by putting the ORPCTHIS data into the buffer when calling IRpcChannelBuffer::GetBuffer on the client side and then storing the amount we increased the buffer in a structure stored in the Handle field. This is done to present the correct Buffer pointer to the proxy so that it writes its data after the ORPCTHIS data. Unmarshal the data on the server side (during RPC_ExecuteCall) and make sure the data is consistent according to NDR rules. Also add several checks on the unmarshaled data that are specified by the DCOM draft specification.
parent e4fc45e0
......@@ -113,6 +113,21 @@ struct dispatch_params
HRESULT hr; /* hresult (out) */
};
struct message_state
{
RPC_BINDING_HANDLE binding_handle;
ULONG prefix_data_len;
};
typedef struct
{
ULONG conformance; /* NDR */
GUID id;
ULONG size;
/* [size_is((size+7)&~7)] */ unsigned char data[1];
} WIRE_ORPC_EXTENT;
static HRESULT WINAPI RpcChannelBuffer_QueryInterface(LPRPCCHANNELBUFFER iface, REFIID riid, LPVOID *ppv)
{
*ppv = NULL;
......@@ -164,11 +179,21 @@ static HRESULT WINAPI ServerRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface,
RpcChannelBuffer *This = (RpcChannelBuffer *)iface;
RPC_MESSAGE *msg = (RPC_MESSAGE *)olemsg;
RPC_STATUS status;
struct message_state *message_state;
TRACE("(%p)->(%p,%s)\n", This, olemsg, debugstr_guid(riid));
message_state = (struct message_state *)msg->Handle;
/* restore the binding handle and the real start of data */
msg->Handle = message_state->binding_handle;
msg->Buffer = (char *)msg->Buffer - message_state->prefix_data_len;
status = I_RpcGetBuffer(msg);
/* save away the message state again */
msg->Handle = message_state;
message_state->prefix_data_len = 0;
TRACE("-- %ld\n", status);
return HRESULT_FROM_WIN32(status);
......@@ -180,6 +205,8 @@ static HRESULT WINAPI ClientRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface,
RPC_MESSAGE *msg = (RPC_MESSAGE *)olemsg;
RPC_CLIENT_INTERFACE *cif;
RPC_STATUS status;
ORPCTHIS *orpcthis;
struct message_state *message_state;
TRACE("(%p)->(%p,%s)\n", This, olemsg, debugstr_guid(riid));
......@@ -187,17 +214,51 @@ static HRESULT WINAPI ClientRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface,
if (!cif)
return E_OUTOFMEMORY;
message_state = HeapAlloc(GetProcessHeap(), 0, sizeof(*message_state));
if (!message_state)
{
HeapFree(GetProcessHeap(), 0, cif);
return E_OUTOFMEMORY;
}
cif->Length = sizeof(RPC_CLIENT_INTERFACE);
/* RPC interface ID = COM interface ID */
cif->InterfaceId.SyntaxGUID = *riid;
/* COM objects always have a version of 0.0 */
cif->InterfaceId.SyntaxVersion.MajorVersion = 0;
cif->InterfaceId.SyntaxVersion.MinorVersion = 0;
msg->RpcInterfaceInformation = cif;
msg->Handle = This->bind;
msg->RpcInterfaceInformation = cif;
msg->BufferLength += FIELD_OFFSET(ORPCTHIS, extensions) + 4;
status = I_RpcGetBuffer(msg);
message_state->prefix_data_len = 0;
message_state->binding_handle = This->bind;
msg->Handle = message_state;
if (status == RPC_S_OK)
{
orpcthis = (ORPCTHIS *)msg->Buffer;
msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(ORPCTHIS, extensions);
orpcthis->version.MajorVersion = COM_MAJOR_VERSION;
orpcthis->version.MinorVersion = COM_MINOR_VERSION;
orpcthis->flags = ORPCF_NULL;
orpcthis->reserved1 = 0;
orpcthis->cid = GUID_NULL; /* FIXME */
/* NDR representation of orpcthis->extensions */
*(DWORD *)msg->Buffer = 0; /* FIXME */
msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
/* store the prefixed data length so that we can restore the real buffer
* pointer in ClientRpcChannelBuffer_SendReceive. */
message_state->prefix_data_len = (char *)msg->Buffer - (char *)orpcthis;
msg->BufferLength -= message_state->prefix_data_len;
}
TRACE("-- %ld\n", status);
return HRESULT_FROM_WIN32(status);
......@@ -264,6 +325,7 @@ static HRESULT WINAPI ClientRpcChannelBuffer_SendReceive(LPRPCCHANNELBUFFER ifac
struct dispatch_params *params;
APARTMENT *apt = NULL;
IPID ipid;
struct message_state *message_state;
TRACE("(%p) iMethod=%d\n", olemsg, olemsg->iMethod);
......@@ -278,6 +340,12 @@ static HRESULT WINAPI ClientRpcChannelBuffer_SendReceive(LPRPCCHANNELBUFFER ifac
params = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(*params));
if (!params) return E_OUTOFMEMORY;
message_state = (struct message_state *)msg->Handle;
/* restore the binding handle and the real start of data */
msg->Handle = message_state->binding_handle;
msg->Buffer = (char *)msg->Buffer - message_state->prefix_data_len;
msg->BufferLength += message_state->prefix_data_len;
params->msg = olemsg;
params->status = RPC_S_OK;
params->hr = S_OK;
......@@ -289,7 +357,7 @@ static HRESULT WINAPI ClientRpcChannelBuffer_SendReceive(LPRPCCHANNELBUFFER ifac
* a thread to process the RPC when this function is called indirectly
* from DllMain */
RpcBindingInqObject(msg->Handle, &ipid);
RpcBindingInqObject(message_state->binding_handle, &ipid);
hr = ipid_get_dispatch_params(&ipid, &apt, &params->stub, &params->chan);
params->handle = ClientRpcChannelBuffer_GetEventHandle(This);
if ((hr == S_OK) && !apt->multi_threaded)
......@@ -337,6 +405,10 @@ static HRESULT WINAPI ClientRpcChannelBuffer_SendReceive(LPRPCCHANNELBUFFER ifac
}
ClientRpcChannelBuffer_ReleaseEventHandle(This, params->handle);
/* save away the message state again */
msg->Handle = message_state;
message_state->prefix_data_len = 0;
if (hr == S_OK) hr = params->hr;
status = params->status;
......@@ -364,11 +436,21 @@ static HRESULT WINAPI ServerRpcChannelBuffer_FreeBuffer(LPRPCCHANNELBUFFER iface
{
RPC_MESSAGE *msg = (RPC_MESSAGE *)olemsg;
RPC_STATUS status;
struct message_state *message_state;
TRACE("(%p)\n", msg);
message_state = (struct message_state *)msg->Handle;
/* restore the binding handle and the real start of data */
msg->Handle = message_state->binding_handle;
msg->Buffer = (char *)msg->Buffer - message_state->prefix_data_len;
msg->BufferLength += message_state->prefix_data_len;
message_state->prefix_data_len = 0;
status = I_RpcFreeBuffer(msg);
msg->Handle = message_state;
TRACE("-- %ld\n", status);
return HRESULT_FROM_WIN32(status);
......@@ -378,13 +460,21 @@ static HRESULT WINAPI ClientRpcChannelBuffer_FreeBuffer(LPRPCCHANNELBUFFER iface
{
RPC_MESSAGE *msg = (RPC_MESSAGE *)olemsg;
RPC_STATUS status;
struct message_state *message_state;
TRACE("(%p)\n", msg);
message_state = (struct message_state *)msg->Handle;
/* restore the binding handle and the real start of data */
msg->Handle = message_state->binding_handle;
msg->Buffer = (char *)msg->Buffer - message_state->prefix_data_len;
msg->BufferLength += message_state->prefix_data_len;
status = I_RpcFreeBuffer(msg);
HeapFree(GetProcessHeap(), 0, msg->RpcInterfaceInformation);
msg->RpcInterfaceInformation = NULL;
HeapFree(GetProcessHeap(), 0, message_state);
TRACE("-- %ld\n", status);
......@@ -519,11 +609,134 @@ HRESULT RPC_CreateServerChannel(IRpcChannelBuffer **chan)
return S_OK;
}
/* unmarshals ORPCTHIS according to NDR rules, but doesn't allocate any memory */
static HRESULT unmarshal_ORPCTHIS(RPC_MESSAGE *msg, ORPCTHIS *orpcthis,
ORPC_EXTENT_ARRAY *orpc_ext_array, WIRE_ORPC_EXTENT **first_wire_orpc_extent)
{
const char *end = (char *)msg->Buffer + msg->BufferLength;
*first_wire_orpc_extent = NULL;
if (msg->BufferLength < FIELD_OFFSET(ORPCTHIS, extensions) + 4)
{
ERR("invalid buffer length\n");
return RPC_E_INVALID_HEADER;
}
memcpy(orpcthis, msg->Buffer, FIELD_OFFSET(ORPCTHIS, extensions));
msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(ORPCTHIS, extensions);
if ((const char *)msg->Buffer + sizeof(DWORD) > end)
return RPC_E_INVALID_HEADER;
if (*(DWORD *)msg->Buffer)
orpcthis->extensions = orpc_ext_array;
else
orpcthis->extensions = NULL;
msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
if (orpcthis->extensions)
{
DWORD pointer_id;
DWORD i;
memcpy(orpcthis->extensions, msg->Buffer, FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent));
msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent);
if ((const char *)msg->Buffer + 2 * sizeof(DWORD) > end)
return RPC_E_INVALID_HEADER;
pointer_id = *(DWORD *)msg->Buffer;
msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
orpcthis->extensions->extent = NULL;
if (pointer_id)
{
WIRE_ORPC_EXTENT *wire_orpc_extent;
/* conformance */
if (*(DWORD *)msg->Buffer != ((orpcthis->extensions->size+1)&~1))
return RPC_S_INVALID_BOUND;
msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
/* arbritary limit for security (don't know what native does) */
if (orpcthis->extensions->size > 256)
{
ERR("too many extensions: %ld\n", orpcthis->extensions->size);
return RPC_S_INVALID_BOUND;
}
*first_wire_orpc_extent = wire_orpc_extent = (WIRE_ORPC_EXTENT *)msg->Buffer;
for (i = 0; i < ((orpcthis->extensions->size+1)&~1); i++)
{
if ((const char *)&wire_orpc_extent->data[0] > end)
return RPC_S_INVALID_BOUND;
if (wire_orpc_extent->conformance != ((wire_orpc_extent->size+7)&~7))
return RPC_S_INVALID_BOUND;
if ((const char *)&wire_orpc_extent->data[wire_orpc_extent->conformance] > end)
return RPC_S_INVALID_BOUND;
TRACE("size %u, guid %s\n", wire_orpc_extent->size, debugstr_guid(&wire_orpc_extent->id));
wire_orpc_extent = (WIRE_ORPC_EXTENT *)&wire_orpc_extent->data[wire_orpc_extent->conformance];
}
msg->Buffer = wire_orpc_extent;
}
}
if ((orpcthis->version.MajorVersion != COM_MAJOR_VERSION) ||
(orpcthis->version.MinorVersion > COM_MINOR_VERSION))
{
ERR("COM version {%d, %d} not supported\n",
orpcthis->version.MajorVersion, orpcthis->version.MinorVersion);
return RPC_E_VERSION_MISMATCH;
}
if (orpcthis->flags & ~(ORPCF_LOCAL|ORPCF_RESERVED1|ORPCF_RESERVED2|ORPCF_RESERVED3|ORPCF_RESERVED4))
{
ERR("invalid flags 0x%lx\n", orpcthis->flags & ~(ORPCF_LOCAL|ORPCF_RESERVED1|ORPCF_RESERVED2|ORPCF_RESERVED3|ORPCF_RESERVED4));
return RPC_E_INVALID_HEADER;
}
return S_OK;
}
void RPC_ExecuteCall(struct dispatch_params *params)
{
struct message_state *message_state;
RPC_MESSAGE *msg = (RPC_MESSAGE *)params->msg;
char *original_buffer = msg->Buffer;
ORPCTHIS orpcthis;
ORPC_EXTENT_ARRAY orpc_ext_array;
WIRE_ORPC_EXTENT *first_wire_orpc_extent;
params->hr = unmarshal_ORPCTHIS(msg, &orpcthis, &orpc_ext_array, &first_wire_orpc_extent);
if (params->hr != S_OK)
goto exit;
message_state = HeapAlloc(GetProcessHeap(), 0, sizeof(*message_state));
if (!message_state)
{
params->hr = E_OUTOFMEMORY;
goto exit;
}
message_state->prefix_data_len = original_buffer - (char *)msg->Buffer;
message_state->binding_handle = msg->Handle;
msg->Handle = message_state;
msg->BufferLength -= message_state->prefix_data_len;
/* invoke the method */
params->hr = IRpcStubBuffer_Invoke(params->stub, params->msg, params->chan);
message_state = (struct message_state *)msg->Handle;
msg->Handle = message_state->binding_handle;
msg->Buffer = (char *)msg->Buffer - message_state->prefix_data_len;
msg->BufferLength += message_state->prefix_data_len;
HeapFree(GetProcessHeap(), 0, message_state);
exit:
IRpcStubBuffer_Release(params->stub);
IRpcChannelBuffer_Release(params->chan);
if (params->handle) SetEvent(params->handle);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment