Commit 28b916b2 authored by Huw Davies's avatar Huw Davies Committed by Alexandre Julliard

ole32: Fix ref counting in GetDataHere Proxy.

parent ced5800b
......@@ -85,6 +85,91 @@ static void init_user_marshal_cb(USER_MARSHAL_CB *umcb,
umcb->CBType = buffer ? USER_MARSHAL_CB_UNMARSHALL : USER_MARSHAL_CB_BUFFER_SIZE;
}
#define RELEASEMARSHALDATA WM_USER
struct host_object_data
{
IStream *stream;
IID iid;
IUnknown *object;
MSHLFLAGS marshal_flags;
HANDLE marshal_event;
IMessageFilter *filter;
};
static DWORD CALLBACK host_object_proc(LPVOID p)
{
struct host_object_data *data = p;
HRESULT hr;
MSG msg;
CoInitializeEx(NULL, COINIT_APARTMENTTHREADED);
if (data->filter)
{
IMessageFilter * prev_filter = NULL;
hr = CoRegisterMessageFilter(data->filter, &prev_filter);
if (prev_filter) IMessageFilter_Release(prev_filter);
ok(hr == S_OK, "got %08x\n", hr);
}
hr = CoMarshalInterface(data->stream, &data->iid, data->object, MSHCTX_INPROC, NULL, data->marshal_flags);
ok(hr == S_OK, "got %08x\n", hr);
/* force the message queue to be created before signaling parent thread */
PeekMessageA(&msg, NULL, WM_USER, WM_USER, PM_NOREMOVE);
SetEvent(data->marshal_event);
while (GetMessageA(&msg, NULL, 0, 0))
{
if (msg.hwnd == NULL && msg.message == RELEASEMARSHALDATA)
{
CoReleaseMarshalData(data->stream);
SetEvent((HANDLE)msg.lParam);
}
else
DispatchMessageA(&msg);
}
HeapFree(GetProcessHeap(), 0, data);
CoUninitialize();
return hr;
}
static DWORD start_host_object2(IStream *stream, REFIID riid, IUnknown *object, MSHLFLAGS marshal_flags, IMessageFilter *filter, HANDLE *thread)
{
DWORD tid = 0;
HANDLE marshal_event = CreateEventA(NULL, FALSE, FALSE, NULL);
struct host_object_data *data = HeapAlloc(GetProcessHeap(), 0, sizeof(*data));
data->stream = stream;
data->iid = *riid;
data->object = object;
data->marshal_flags = marshal_flags;
data->marshal_event = marshal_event;
data->filter = filter;
*thread = CreateThread(NULL, 0, host_object_proc, data, 0, &tid);
/* wait for marshaling to complete before returning */
ok( !WaitForSingleObject(marshal_event, 10000), "wait timed out\n" );
CloseHandle(marshal_event);
return tid;
}
static void end_host_object(DWORD tid, HANDLE thread)
{
BOOL ret = PostThreadMessageA(tid, WM_QUIT, 0, 0);
ok(ret, "PostThreadMessage failed with error %d\n", GetLastError());
/* be careful of races - don't return until hosting thread has terminated */
ok( !WaitForSingleObject(thread, 10000), "wait timed out\n" );
CloseHandle(thread);
}
static const char cf_marshaled[] =
{
0x9, 0x0, 0x0, 0x0,
......@@ -1105,9 +1190,156 @@ static void test_marshal_HBRUSH(void)
DeleteObject(hBrush);
}
struct obj
{
IDataObject IDataObject_iface;
};
static HRESULT WINAPI obj_QueryInterface(IDataObject *iface, REFIID iid, void **obj)
{
*obj = NULL;
if (IsEqualGUID(iid, &IID_IUnknown) ||
IsEqualGUID(iid, &IID_IDataObject))
*obj = iface;
if (*obj)
{
IDataObject_AddRef(iface);
return S_OK;
}
return E_NOINTERFACE;
}
static ULONG WINAPI obj_AddRef(IDataObject *iface)
{
return 2;
}
static ULONG WINAPI obj_Release(IDataObject *iface)
{
return 1;
}
static HRESULT WINAPI obj_DO_GetDataHere(IDataObject *iface, FORMATETC *fmt,
STGMEDIUM *med)
{
ok( med->pUnkForRelease == NULL, "got %p\n", med->pUnkForRelease );
if (fmt->cfFormat == 2)
{
IStream_Release(U(med)->pstm);
U(med)->pstm = &Test_Stream2.IStream_iface;
}
return S_OK;
}
static const IDataObjectVtbl obj_data_object_vtbl =
{
obj_QueryInterface,
obj_AddRef,
obj_Release,
NULL, /* GetData */
obj_DO_GetDataHere,
NULL, /* QueryGetData */
NULL, /* GetCanonicalFormatEtc */
NULL, /* SetData */
NULL, /* EnumFormatEtc */
NULL, /* DAdvise */
NULL, /* DUnadvise */
NULL /* EnumDAdvise */
};
static struct obj obj =
{
{&obj_data_object_vtbl}
};
static void test_GetDataHere_Proxy(void)
{
HRESULT hr;
IStream *stm;
HANDLE thread;
DWORD tid;
static const LARGE_INTEGER zero;
IDataObject *data;
FORMATETC fmt;
STGMEDIUM med;
hr = CreateStreamOnHGlobal( NULL, TRUE, &stm );
ok( hr == S_OK, "got %08x\n", hr );
tid = start_host_object2( stm, &IID_IDataObject, (IUnknown *)&obj.IDataObject_iface, MSHLFLAGS_NORMAL, NULL, &thread );
IStream_Seek( stm, zero, STREAM_SEEK_SET, NULL );
hr = CoUnmarshalInterface( stm, &IID_IDataObject, (void **)&data );
ok( hr == S_OK, "got %08x\n", hr );
IStream_Release( stm );
Test_Stream.refs = 1;
Test_Stream2.refs = 1;
Test_Unknown.refs = 1;
fmt.cfFormat = 1;
fmt.ptd = NULL;
fmt.dwAspect = DVASPECT_CONTENT;
fmt.lindex = -1;
U(med).pstm = NULL;
med.pUnkForRelease = &Test_Unknown.IUnknown_iface;
fmt.tymed = med.tymed = TYMED_NULL;
hr = IDataObject_GetDataHere( data, &fmt, &med );
ok( hr == DV_E_TYMED, "got %08x\n", hr );
for (fmt.tymed = TYMED_HGLOBAL; fmt.tymed <= TYMED_ENHMF; fmt.tymed <<= 1)
{
med.tymed = fmt.tymed;
hr = IDataObject_GetDataHere( data, &fmt, &med );
ok( hr == (fmt.tymed <= TYMED_ISTORAGE ? S_OK : DV_E_TYMED), "got %08x for tymed %d\n", hr, fmt.tymed );
ok( Test_Unknown.refs == 1, "got %d\n", Test_Unknown.refs );
}
fmt.tymed = TYMED_ISTREAM;
med.tymed = TYMED_ISTORAGE;
hr = IDataObject_GetDataHere( data, &fmt, &med );
ok( hr == DV_E_TYMED, "got %08x\n", hr );
fmt.tymed = med.tymed = TYMED_ISTREAM;
U(med).pstm = &Test_Stream.IStream_iface;
med.pUnkForRelease = &Test_Unknown.IUnknown_iface;
hr = IDataObject_GetDataHere( data, &fmt, &med );
ok( hr == S_OK, "got %08x\n", hr );
ok( U(med).pstm == &Test_Stream.IStream_iface, "stm changed\n" );
ok( med.pUnkForRelease == &Test_Unknown.IUnknown_iface, "punk changed\n" );
ok( Test_Stream.refs == 1, "got %d\n", Test_Stream.refs );
ok( Test_Unknown.refs == 1, "got %d\n", Test_Unknown.refs );
fmt.cfFormat = 2;
fmt.tymed = med.tymed = TYMED_ISTREAM;
U(med).pstm = &Test_Stream.IStream_iface;
med.pUnkForRelease = &Test_Unknown.IUnknown_iface;
hr = IDataObject_GetDataHere( data, &fmt, &med );
ok( hr == S_OK, "got %08x\n", hr );
ok( U(med).pstm == &Test_Stream.IStream_iface, "stm changed\n" );
ok( med.pUnkForRelease == &Test_Unknown.IUnknown_iface, "punk changed\n" );
ok( Test_Stream.refs == 1, "got %d\n", Test_Stream.refs );
ok( Test_Unknown.refs == 1, "got %d\n", Test_Unknown.refs );
ok( Test_Stream2.refs == 0, "got %d\n", Test_Stream2.refs );
IDataObject_Release( data );
end_host_object( tid, thread );
}
START_TEST(usrmarshal)
{
CoInitialize(NULL);
CoInitializeEx(NULL, COINIT_APARTMENTTHREADED);
test_marshal_CLIPFORMAT();
test_marshal_HWND();
......@@ -1122,5 +1354,7 @@ START_TEST(usrmarshal)
test_marshal_HICON();
test_marshal_HBRUSH();
test_GetDataHere_Proxy();
CoUninitialize();
}
......@@ -2783,13 +2783,39 @@ HRESULT __RPC_STUB IDataObject_GetData_Stub(
return IDataObject_GetData(This, pformatetcIn, pRemoteMedium);
}
HRESULT CALLBACK IDataObject_GetDataHere_Proxy(
IDataObject* This,
FORMATETC *pformatetc,
STGMEDIUM *pmedium)
HRESULT CALLBACK IDataObject_GetDataHere_Proxy(IDataObject *iface, FORMATETC *fmt, STGMEDIUM *med)
{
TRACE("(%p)->(%p, %p)\n", This, pformatetc, pmedium);
return IDataObject_RemoteGetDataHere_Proxy(This, pformatetc, pmedium);
IUnknown *release;
IStorage *stg = NULL;
HRESULT hr;
TRACE("(%p)->(%p, %p)\n", iface, fmt, med);
if ((med->tymed & (TYMED_HGLOBAL | TYMED_FILE | TYMED_ISTREAM | TYMED_ISTORAGE)) == 0)
return DV_E_TYMED;
if (med->tymed != fmt->tymed)
return DV_E_TYMED;
release = med->pUnkForRelease;
med->pUnkForRelease = NULL;
if (med->tymed == TYMED_ISTREAM || med->tymed == TYMED_ISTORAGE)
{
stg = med->u.pstg; /* This may actually be a stream, but that's ok */
if (stg) IStorage_AddRef( stg );
}
hr = IDataObject_RemoteGetDataHere_Proxy(iface, fmt, med);
med->pUnkForRelease = release;
if (stg)
{
if (med->u.pstg)
IStorage_Release( med->u.pstg );
med->u.pstg = stg;
}
return hr;
}
HRESULT __RPC_STUB IDataObject_GetDataHere_Stub(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment