Commit ced5800b authored by Huw Davies's avatar Huw Davies Committed by Alexandre Julliard

ole32: Correctly release an existing pointer when unmarshalling a NULL-ptr.

However, this should not be done in the case of pUnkForRelease. Signed-off-by: 's avatarHuw Davies <huw@codeweavers.com> Signed-off-by: 's avatarAlexandre Julliard <julliard@winehq.org>
parent 9021b967
...@@ -527,6 +527,17 @@ static const IUnknownVtbl TestUnknown_Vtbl = ...@@ -527,6 +527,17 @@ static const IUnknownVtbl TestUnknown_Vtbl =
Test_IUnknown_Release, Test_IUnknown_Release,
}; };
struct test_stream
{
IStream IStream_iface;
LONG refs;
};
static inline struct test_stream *impl_from_IStream(IStream *iface)
{
return CONTAINING_RECORD(iface, struct test_stream, IStream_iface);
}
static HRESULT WINAPI Test_IStream_QueryInterface(IStream *iface, static HRESULT WINAPI Test_IStream_QueryInterface(IStream *iface,
REFIID riid, LPVOID *ppvObj) REFIID riid, LPVOID *ppvObj)
{ {
...@@ -546,12 +557,14 @@ static HRESULT WINAPI Test_IStream_QueryInterface(IStream *iface, ...@@ -546,12 +557,14 @@ static HRESULT WINAPI Test_IStream_QueryInterface(IStream *iface,
static ULONG WINAPI Test_IStream_AddRef(IStream *iface) static ULONG WINAPI Test_IStream_AddRef(IStream *iface)
{ {
return 2; /* non-heap-based object */ struct test_stream *This = impl_from_IStream(iface);
return InterlockedIncrement(&This->refs);
} }
static ULONG WINAPI Test_IStream_Release(IStream *iface) static ULONG WINAPI Test_IStream_Release(IStream *iface)
{ {
return 1; /* non-heap-based object */ struct test_stream *This = impl_from_IStream(iface);
return InterlockedDecrement(&This->refs);
} }
static const IStreamVtbl TestStream_Vtbl = static const IStreamVtbl TestStream_Vtbl =
...@@ -564,7 +577,8 @@ static const IStreamVtbl TestStream_Vtbl = ...@@ -564,7 +577,8 @@ static const IStreamVtbl TestStream_Vtbl =
static TestUnknown Test_Unknown = { {&TestUnknown_Vtbl}, 1 }; static TestUnknown Test_Unknown = { {&TestUnknown_Vtbl}, 1 };
static TestUnknown Test_Unknown2 = { {&TestUnknown_Vtbl}, 1 }; static TestUnknown Test_Unknown2 = { {&TestUnknown_Vtbl}, 1 };
static IStream Test_Stream = { &TestStream_Vtbl }; static struct test_stream Test_Stream = { {&TestStream_Vtbl}, 1 };
static struct test_stream Test_Stream2 = { {&TestStream_Vtbl}, 1 };
ULONG __RPC_USER WdtpInterfacePointer_UserSize(ULONG *, ULONG, ULONG, IUnknown *, REFIID); ULONG __RPC_USER WdtpInterfacePointer_UserSize(ULONG *, ULONG, ULONG, IUnknown *, REFIID);
unsigned char * __RPC_USER WdtpInterfacePointer_UserMarshal(ULONG *, ULONG, unsigned char *, IUnknown *, REFIID); unsigned char * __RPC_USER WdtpInterfacePointer_UserMarshal(ULONG *, ULONG, unsigned char *, IUnknown *, REFIID);
...@@ -683,7 +697,7 @@ static void test_marshal_WdtpInterfacePointer(void) ...@@ -683,7 +697,7 @@ static void test_marshal_WdtpInterfacePointer(void)
marshal_WdtpInterfacePointer(MSHCTX_INPROC, MSHCTX_DIFFERENTMACHINE,1,1,1); marshal_WdtpInterfacePointer(MSHCTX_INPROC, MSHCTX_DIFFERENTMACHINE,1,1,1);
} }
static void test_marshal_STGMEDIUM(void) static void marshal_STGMEDIUM(BOOL client, BOOL in, BOOL out)
{ {
USER_MARSHAL_CB umcb; USER_MARSHAL_CB umcb;
MIDL_STUB_MESSAGE stub_msg; MIDL_STUB_MESSAGE stub_msg;
...@@ -692,10 +706,12 @@ static void test_marshal_STGMEDIUM(void) ...@@ -692,10 +706,12 @@ static void test_marshal_STGMEDIUM(void)
ULONG size, expect_size; ULONG size, expect_size;
STGMEDIUM med, med2; STGMEDIUM med, med2;
IUnknown *unk = &Test_Unknown.IUnknown_iface; IUnknown *unk = &Test_Unknown.IUnknown_iface;
IStream *stm = &Test_Stream; IStream *stm = &Test_Stream.IStream_iface;
/* TYMED_NULL with pUnkForRelease */ /* TYMED_NULL with pUnkForRelease */
Test_Unknown.refs = 1;
init_user_marshal_cb(&umcb, &stub_msg, &rpc_msg, NULL, 0, MSHCTX_DIFFERENTMACHINE); init_user_marshal_cb(&umcb, &stub_msg, &rpc_msg, NULL, 0, MSHCTX_DIFFERENTMACHINE);
expect_size = WdtpInterfacePointer_UserSize(&umcb.Flags, umcb.Flags, 2 * sizeof(DWORD), unk, &IID_IUnknown); expect_size = WdtpInterfacePointer_UserSize(&umcb.Flags, umcb.Flags, 2 * sizeof(DWORD), unk, &IID_IUnknown);
expect_buffer = HeapAlloc(GetProcessHeap(), 0, expect_size); expect_buffer = HeapAlloc(GetProcessHeap(), 0, expect_size);
...@@ -721,17 +737,20 @@ static void test_marshal_STGMEDIUM(void) ...@@ -721,17 +737,20 @@ static void test_marshal_STGMEDIUM(void)
ok(!memcmp(buffer+8, expect_buffer + 8, expect_buffer_end - expect_buffer - 8), "buffer mismatch\n"); ok(!memcmp(buffer+8, expect_buffer + 8, expect_buffer_end - expect_buffer - 8), "buffer mismatch\n");
init_user_marshal_cb(&umcb, &stub_msg, &rpc_msg, buffer, size, MSHCTX_DIFFERENTMACHINE); init_user_marshal_cb(&umcb, &stub_msg, &rpc_msg, buffer, size, MSHCTX_DIFFERENTMACHINE);
umcb.pStubMsg->IsClient = client;
umcb.pStubMsg->fIsIn = in;
umcb.pStubMsg->fIsOut = out;
/* native crashes if this is uninitialised, presumably because it Test_Unknown2.refs = 1;
tries to release it */
med2.tymed = TYMED_NULL; med2.tymed = TYMED_NULL;
U(med2).pstm = NULL; U(med2).pstm = NULL;
med2.pUnkForRelease = NULL; med2.pUnkForRelease = &Test_Unknown2.IUnknown_iface;
STGMEDIUM_UserUnmarshal(&umcb.Flags, buffer, &med2); STGMEDIUM_UserUnmarshal(&umcb.Flags, buffer, &med2);
ok(med2.tymed == TYMED_NULL, "got tymed %x\n", med2.tymed); ok(med2.tymed == TYMED_NULL, "got tymed %x\n", med2.tymed);
ok(med2.pUnkForRelease != NULL, "Incorrectly unmarshalled\n"); ok(med2.pUnkForRelease != NULL, "Incorrectly unmarshalled\n");
ok(Test_Unknown2.refs == 0, "got %d\n", Test_Unknown2.refs);
HeapFree(GetProcessHeap(), 0, buffer); HeapFree(GetProcessHeap(), 0, buffer);
init_user_marshal_cb(&umcb, &stub_msg, &rpc_msg, NULL, 0, MSHCTX_DIFFERENTMACHINE); init_user_marshal_cb(&umcb, &stub_msg, &rpc_msg, NULL, 0, MSHCTX_DIFFERENTMACHINE);
...@@ -745,10 +764,15 @@ static void test_marshal_STGMEDIUM(void) ...@@ -745,10 +764,15 @@ static void test_marshal_STGMEDIUM(void)
init_user_marshal_cb(&umcb, &stub_msg, &rpc_msg, NULL, 0, MSHCTX_DIFFERENTMACHINE); init_user_marshal_cb(&umcb, &stub_msg, &rpc_msg, NULL, 0, MSHCTX_DIFFERENTMACHINE);
STGMEDIUM_UserFree(&umcb.Flags, &med2); STGMEDIUM_UserFree(&umcb.Flags, &med2);
ok(Test_Unknown.refs == 1, "got %d\n", Test_Unknown.refs);
HeapFree(GetProcessHeap(), 0, expect_buffer); HeapFree(GetProcessHeap(), 0, expect_buffer);
/* TYMED_ISTREAM with pUnkForRelease */ /* TYMED_ISTREAM with pUnkForRelease */
Test_Unknown.refs = 1;
Test_Stream.refs = 1;
init_user_marshal_cb(&umcb, &stub_msg, &rpc_msg, NULL, 0, MSHCTX_DIFFERENTMACHINE); init_user_marshal_cb(&umcb, &stub_msg, &rpc_msg, NULL, 0, MSHCTX_DIFFERENTMACHINE);
expect_size = WdtpInterfacePointer_UserSize(&umcb.Flags, umcb.Flags, 3 * sizeof(DWORD), (IUnknown*)stm, &IID_IStream); expect_size = WdtpInterfacePointer_UserSize(&umcb.Flags, umcb.Flags, 3 * sizeof(DWORD), (IUnknown*)stm, &IID_IStream);
expect_size = WdtpInterfacePointer_UserSize(&umcb.Flags, umcb.Flags, expect_size, unk, &IID_IUnknown); expect_size = WdtpInterfacePointer_UserSize(&umcb.Flags, umcb.Flags, expect_size, unk, &IID_IUnknown);
...@@ -782,18 +806,23 @@ static void test_marshal_STGMEDIUM(void) ...@@ -782,18 +806,23 @@ static void test_marshal_STGMEDIUM(void)
ok(!memcmp(buffer + 12, expect_buffer + 12, (buffer_end - buffer) - 12), "buffer mismatch\n"); ok(!memcmp(buffer + 12, expect_buffer + 12, (buffer_end - buffer) - 12), "buffer mismatch\n");
init_user_marshal_cb(&umcb, &stub_msg, &rpc_msg, buffer, size, MSHCTX_DIFFERENTMACHINE); init_user_marshal_cb(&umcb, &stub_msg, &rpc_msg, buffer, size, MSHCTX_DIFFERENTMACHINE);
umcb.pStubMsg->IsClient = client;
umcb.pStubMsg->fIsIn = in;
umcb.pStubMsg->fIsOut = out;
/* native crashes if this is uninitialised, presumably because it Test_Stream2.refs = 1;
tries to release it */ Test_Unknown2.refs = 1;
med2.tymed = TYMED_NULL; med2.tymed = TYMED_ISTREAM;
U(med2).pstm = NULL; U(med2).pstm = &Test_Stream2.IStream_iface;
med2.pUnkForRelease = NULL; med2.pUnkForRelease = &Test_Unknown2.IUnknown_iface;
STGMEDIUM_UserUnmarshal(&umcb.Flags, buffer, &med2); STGMEDIUM_UserUnmarshal(&umcb.Flags, buffer, &med2);
ok(med2.tymed == TYMED_ISTREAM, "got tymed %x\n", med2.tymed); ok(med2.tymed == TYMED_ISTREAM, "got tymed %x\n", med2.tymed);
ok(U(med2).pstm != NULL, "Incorrectly unmarshalled\n"); ok(U(med2).pstm != NULL, "Incorrectly unmarshalled\n");
ok(med2.pUnkForRelease != NULL, "Incorrectly unmarshalled\n"); ok(med2.pUnkForRelease != NULL, "Incorrectly unmarshalled\n");
ok(Test_Stream2.refs == 0, "got %d\n", Test_Stream2.refs);
ok(Test_Unknown2.refs == 0, "got %d\n", Test_Unknown2.refs);
HeapFree(GetProcessHeap(), 0, buffer); HeapFree(GetProcessHeap(), 0, buffer);
init_user_marshal_cb(&umcb, &stub_msg, &rpc_msg, NULL, 0, MSHCTX_DIFFERENTMACHINE); init_user_marshal_cb(&umcb, &stub_msg, &rpc_msg, NULL, 0, MSHCTX_DIFFERENTMACHINE);
...@@ -807,7 +836,68 @@ static void test_marshal_STGMEDIUM(void) ...@@ -807,7 +836,68 @@ static void test_marshal_STGMEDIUM(void)
init_user_marshal_cb(&umcb, &stub_msg, &rpc_msg, NULL, 0, MSHCTX_DIFFERENTMACHINE); init_user_marshal_cb(&umcb, &stub_msg, &rpc_msg, NULL, 0, MSHCTX_DIFFERENTMACHINE);
STGMEDIUM_UserFree(&umcb.Flags, &med2); STGMEDIUM_UserFree(&umcb.Flags, &med2);
ok(Test_Unknown.refs == 1, "got %d\n", Test_Unknown.refs);
ok(Test_Stream.refs == 1, "got %d\n", Test_Stream.refs);
HeapFree(GetProcessHeap(), 0, expect_buffer); HeapFree(GetProcessHeap(), 0, expect_buffer);
/* TYMED_ISTREAM = NULL with pUnkForRelease = NULL */
init_user_marshal_cb(&umcb, &stub_msg, &rpc_msg, NULL, 0, MSHCTX_DIFFERENTMACHINE);
expect_size = 3 * sizeof(DWORD);
med.tymed = TYMED_ISTREAM;
U(med).pstm = NULL;
med.pUnkForRelease = NULL;
init_user_marshal_cb(&umcb, &stub_msg, &rpc_msg, NULL, 0, MSHCTX_DIFFERENTMACHINE);
size = STGMEDIUM_UserSize(&umcb.Flags, 0, &med);
ok(size == expect_size, "size %d should be %d bytes\n", size, expect_size);
buffer = HeapAlloc(GetProcessHeap(), 0, size);
memset(buffer, 0xcc, size);
init_user_marshal_cb(&umcb, &stub_msg, &rpc_msg, buffer, size, MSHCTX_DIFFERENTMACHINE);
buffer_end = STGMEDIUM_UserMarshal(&umcb.Flags, buffer, &med);
ok(buffer_end - buffer == expect_size, "buffer size mismatch\n");
ok(*(DWORD*)buffer == TYMED_ISTREAM, "got %08x\n", *(DWORD*)buffer);
ok(*((DWORD*)buffer+1) == 0, "got %08x\n", *((DWORD*)buffer+1));
ok(*((DWORD*)buffer+2) == 0, "got %08x\n", *((DWORD*)buffer+2));
init_user_marshal_cb(&umcb, &stub_msg, &rpc_msg, buffer, size, MSHCTX_DIFFERENTMACHINE);
umcb.pStubMsg->IsClient = client;
umcb.pStubMsg->fIsIn = in;
umcb.pStubMsg->fIsOut = out;
Test_Stream2.refs = 1;
Test_Unknown2.refs = 1;
med2.tymed = TYMED_ISTREAM;
U(med2).pstm = &Test_Stream2.IStream_iface;
med2.pUnkForRelease = &Test_Unknown2.IUnknown_iface;
STGMEDIUM_UserUnmarshal(&umcb.Flags, buffer, &med2);
ok(med2.tymed == TYMED_ISTREAM, "got tymed %x\n", med2.tymed);
ok(U(med2).pstm == NULL, "Incorrectly unmarshalled\n");
ok(med2.pUnkForRelease == &Test_Unknown2.IUnknown_iface, "Incorrectly unmarshalled\n");
ok(Test_Stream2.refs == 0, "got %d\n", Test_Stream2.refs);
ok(Test_Unknown2.refs == 1, "got %d\n", Test_Unknown2.refs);
HeapFree(GetProcessHeap(), 0, buffer);
init_user_marshal_cb(&umcb, &stub_msg, &rpc_msg, NULL, 0, MSHCTX_DIFFERENTMACHINE);
STGMEDIUM_UserFree(&umcb.Flags, &med2);
}
static void test_marshal_STGMEDIUM(void)
{
marshal_STGMEDIUM(0, 0, 0);
marshal_STGMEDIUM(0, 0, 1);
marshal_STGMEDIUM(0, 1, 0);
marshal_STGMEDIUM(0, 1, 1);
/* For Windows versions post 2003, client side, non-[in,out] STGMEDIUMs get zero-initialised.
However since inline stubs don't set fIsIn or fIsOut this behaviour would break
ref counting in GetDataHere_Proxy for example, as we'd end up not releasing the original
interface. For simplicity we don't test or implement this. */
marshal_STGMEDIUM(1, 1, 1);
} }
static void test_marshal_SNB(void) static void test_marshal_SNB(void)
......
...@@ -1850,7 +1850,10 @@ unsigned char * __RPC_USER STGMEDIUM_UserUnmarshal(ULONG *pFlags, unsigned char ...@@ -1850,7 +1850,10 @@ unsigned char * __RPC_USER STGMEDIUM_UserUnmarshal(ULONG *pFlags, unsigned char
pBuffer = WdtpInterfacePointer_UserUnmarshal(pFlags, pBuffer, (IUnknown**)&pStgMedium->u.pstm, &IID_IStream); pBuffer = WdtpInterfacePointer_UserUnmarshal(pFlags, pBuffer, (IUnknown**)&pStgMedium->u.pstm, &IID_IStream);
} }
else else
{
if (pStgMedium->u.pstm) IStream_Release( pStgMedium->u.pstm );
pStgMedium->u.pstm = NULL; pStgMedium->u.pstm = NULL;
}
break; break;
case TYMED_ISTORAGE: case TYMED_ISTORAGE:
TRACE("TYMED_ISTORAGE\n"); TRACE("TYMED_ISTORAGE\n");
...@@ -1859,7 +1862,10 @@ unsigned char * __RPC_USER STGMEDIUM_UserUnmarshal(ULONG *pFlags, unsigned char ...@@ -1859,7 +1862,10 @@ unsigned char * __RPC_USER STGMEDIUM_UserUnmarshal(ULONG *pFlags, unsigned char
pBuffer = WdtpInterfacePointer_UserUnmarshal(pFlags, pBuffer, (IUnknown**)&pStgMedium->u.pstg, &IID_IStorage); pBuffer = WdtpInterfacePointer_UserUnmarshal(pFlags, pBuffer, (IUnknown**)&pStgMedium->u.pstg, &IID_IStorage);
} }
else else
{
if (pStgMedium->u.pstg) IStorage_Release( pStgMedium->u.pstg );
pStgMedium->u.pstg = NULL; pStgMedium->u.pstg = NULL;
}
break; break;
case TYMED_GDI: case TYMED_GDI:
TRACE("TYMED_GDI\n"); TRACE("TYMED_GDI\n");
...@@ -1888,9 +1894,10 @@ unsigned char * __RPC_USER STGMEDIUM_UserUnmarshal(ULONG *pFlags, unsigned char ...@@ -1888,9 +1894,10 @@ unsigned char * __RPC_USER STGMEDIUM_UserUnmarshal(ULONG *pFlags, unsigned char
RaiseException(DV_E_TYMED, 0, 0, NULL); RaiseException(DV_E_TYMED, 0, 0, NULL);
} }
pStgMedium->pUnkForRelease = NULL;
if (releaseunk) if (releaseunk)
pBuffer = WdtpInterfacePointer_UserUnmarshal(pFlags, pBuffer, &pStgMedium->pUnkForRelease, &IID_IUnknown); pBuffer = WdtpInterfacePointer_UserUnmarshal(pFlags, pBuffer, &pStgMedium->pUnkForRelease, &IID_IUnknown);
/* Unlike the IStream / IStorage ifaces, the existing pUnkForRelease
is left intact if a NULL ptr is unmarshalled - see the tests. */
return pBuffer; return pBuffer;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment