Commit 74989b6e authored by Jacek Caban's avatar Jacek Caban Committed by Alexandre Julliard

urlmon: Added support for COM aggregation of file protocol handler.

parent 114b0705
...@@ -25,9 +25,12 @@ ...@@ -25,9 +25,12 @@
WINE_DEFAULT_DEBUG_CHANNEL(urlmon); WINE_DEFAULT_DEBUG_CHANNEL(urlmon);
typedef struct { typedef struct {
IUnknown IUnknown_outer;
IInternetProtocolEx IInternetProtocolEx_iface; IInternetProtocolEx IInternetProtocolEx_iface;
IInternetPriority IInternetPriority_iface; IInternetPriority IInternetPriority_iface;
IUnknown *outer;
HANDLE file; HANDLE file;
ULONG size; ULONG size;
LONG priority; LONG priority;
...@@ -35,6 +38,11 @@ typedef struct { ...@@ -35,6 +38,11 @@ typedef struct {
LONG ref; LONG ref;
} FileProtocol; } FileProtocol;
static inline FileProtocol *impl_from_IUnknown(IUnknown *iface)
{
return CONTAINING_RECORD(iface, FileProtocol, IUnknown_outer);
}
static inline FileProtocol *impl_from_IInternetProtocolEx(IInternetProtocolEx *iface) static inline FileProtocol *impl_from_IInternetProtocolEx(IInternetProtocolEx *iface)
{ {
return CONTAINING_RECORD(iface, FileProtocol, IInternetProtocolEx_iface); return CONTAINING_RECORD(iface, FileProtocol, IInternetProtocolEx_iface);
...@@ -45,14 +53,14 @@ static inline FileProtocol *impl_from_IInternetPriority(IInternetPriority *iface ...@@ -45,14 +53,14 @@ static inline FileProtocol *impl_from_IInternetPriority(IInternetPriority *iface
return CONTAINING_RECORD(iface, FileProtocol, IInternetPriority_iface); return CONTAINING_RECORD(iface, FileProtocol, IInternetPriority_iface);
} }
static HRESULT WINAPI FileProtocol_QueryInterface(IInternetProtocolEx *iface, REFIID riid, void **ppv) static HRESULT WINAPI FileProtocolUnk_QueryInterface(IUnknown *iface, REFIID riid, void **ppv)
{ {
FileProtocol *This = impl_from_IInternetProtocolEx(iface); FileProtocol *This = impl_from_IUnknown(iface);
*ppv = NULL; *ppv = NULL;
if(IsEqualGUID(&IID_IUnknown, riid)) { if(IsEqualGUID(&IID_IUnknown, riid)) {
TRACE("(%p)->(IID_IUnknown %p)\n", This, ppv); TRACE("(%p)->(IID_IUnknown %p)\n", This, ppv);
*ppv = &This->IInternetProtocolEx_iface; *ppv = &This->IUnknown_outer;
}else if(IsEqualGUID(&IID_IInternetProtocolRoot, riid)) { }else if(IsEqualGUID(&IID_IInternetProtocolRoot, riid)) {
TRACE("(%p)->(IID_IInternetProtocolRoot %p)\n", This, ppv); TRACE("(%p)->(IID_IInternetProtocolRoot %p)\n", This, ppv);
*ppv = &This->IInternetProtocolEx_iface; *ppv = &This->IInternetProtocolEx_iface;
...@@ -68,7 +76,7 @@ static HRESULT WINAPI FileProtocol_QueryInterface(IInternetProtocolEx *iface, RE ...@@ -68,7 +76,7 @@ static HRESULT WINAPI FileProtocol_QueryInterface(IInternetProtocolEx *iface, RE
} }
if(*ppv) { if(*ppv) {
IInternetProtocolEx_AddRef(iface); IUnknown_AddRef((IUnknown*)*ppv);
return S_OK; return S_OK;
} }
...@@ -76,17 +84,17 @@ static HRESULT WINAPI FileProtocol_QueryInterface(IInternetProtocolEx *iface, RE ...@@ -76,17 +84,17 @@ static HRESULT WINAPI FileProtocol_QueryInterface(IInternetProtocolEx *iface, RE
return E_NOINTERFACE; return E_NOINTERFACE;
} }
static ULONG WINAPI FileProtocol_AddRef(IInternetProtocolEx *iface) static ULONG WINAPI FileProtocolUnk_AddRef(IUnknown *iface)
{ {
FileProtocol *This = impl_from_IInternetProtocolEx(iface); FileProtocol *This = impl_from_IUnknown(iface);
LONG ref = InterlockedIncrement(&This->ref); LONG ref = InterlockedIncrement(&This->ref);
TRACE("(%p) ref=%d\n", This, ref); TRACE("(%p) ref=%d\n", This, ref);
return ref; return ref;
} }
static ULONG WINAPI FileProtocol_Release(IInternetProtocolEx *iface) static ULONG WINAPI FileProtocolUnk_Release(IUnknown *iface)
{ {
FileProtocol *This = impl_from_IInternetProtocolEx(iface); FileProtocol *This = impl_from_IUnknown(iface);
LONG ref = InterlockedDecrement(&This->ref); LONG ref = InterlockedDecrement(&This->ref);
TRACE("(%p) ref=%d\n", This, ref); TRACE("(%p) ref=%d\n", This, ref);
...@@ -102,6 +110,33 @@ static ULONG WINAPI FileProtocol_Release(IInternetProtocolEx *iface) ...@@ -102,6 +110,33 @@ static ULONG WINAPI FileProtocol_Release(IInternetProtocolEx *iface)
return ref; return ref;
} }
static const IUnknownVtbl FileProtocolUnkVtbl = {
FileProtocolUnk_QueryInterface,
FileProtocolUnk_AddRef,
FileProtocolUnk_Release
};
static HRESULT WINAPI FileProtocol_QueryInterface(IInternetProtocolEx *iface, REFIID riid, void **ppv)
{
FileProtocol *This = impl_from_IInternetProtocolEx(iface);
TRACE("(%p)->(%s %p)\n", This, debugstr_guid(riid), ppv);
return IUnknown_QueryInterface(This->outer, riid, ppv);
}
static ULONG WINAPI FileProtocol_AddRef(IInternetProtocolEx *iface)
{
FileProtocol *This = impl_from_IInternetProtocolEx(iface);
TRACE("(%p)\n", This);
return IUnknown_AddRef(This->outer);
}
static ULONG WINAPI FileProtocol_Release(IInternetProtocolEx *iface)
{
FileProtocol *This = impl_from_IInternetProtocolEx(iface);
TRACE("(%p)\n", This);
return IUnknown_Release(This->outer);
}
static HRESULT WINAPI FileProtocol_Start(IInternetProtocolEx *iface, LPCWSTR szUrl, static HRESULT WINAPI FileProtocol_Start(IInternetProtocolEx *iface, LPCWSTR szUrl,
IInternetProtocolSink *pOIProtSink, IInternetBindInfo *pOIBindInfo, IInternetProtocolSink *pOIProtSink, IInternetBindInfo *pOIBindInfo,
DWORD grfPI, HANDLE_PTR dwReserved) DWORD grfPI, HANDLE_PTR dwReserved)
...@@ -383,22 +418,24 @@ static const IInternetPriorityVtbl FilePriorityVtbl = { ...@@ -383,22 +418,24 @@ static const IInternetPriorityVtbl FilePriorityVtbl = {
FilePriority_GetPriority FilePriority_GetPriority
}; };
HRESULT FileProtocol_Construct(IUnknown *pUnkOuter, LPVOID *ppobj) HRESULT FileProtocol_Construct(IUnknown *outer, LPVOID *ppobj)
{ {
FileProtocol *ret; FileProtocol *ret;
TRACE("(%p %p)\n", pUnkOuter, ppobj); TRACE("(%p %p)\n", outer, ppobj);
URLMON_LockModule(); URLMON_LockModule();
ret = heap_alloc(sizeof(FileProtocol)); ret = heap_alloc(sizeof(FileProtocol));
ret->IUnknown_outer.lpVtbl = &FileProtocolUnkVtbl;
ret->IInternetProtocolEx_iface.lpVtbl = &FileProtocolExVtbl; ret->IInternetProtocolEx_iface.lpVtbl = &FileProtocolExVtbl;
ret->IInternetPriority_iface.lpVtbl = &FilePriorityVtbl; ret->IInternetPriority_iface.lpVtbl = &FilePriorityVtbl;
ret->file = INVALID_HANDLE_VALUE; ret->file = INVALID_HANDLE_VALUE;
ret->priority = 0; ret->priority = 0;
ret->ref = 1; ret->ref = 1;
ret->outer = outer ? outer : (IUnknown*)&ret->IUnknown_outer;
*ppobj = &ret->IInternetProtocolEx_iface; *ppobj = &ret->IUnknown_outer;
return S_OK; return S_OK;
} }
...@@ -125,6 +125,7 @@ DEFINE_EXPECT(MimeFilter_Continue); ...@@ -125,6 +125,7 @@ DEFINE_EXPECT(MimeFilter_Continue);
DEFINE_EXPECT(Stream_Seek); DEFINE_EXPECT(Stream_Seek);
DEFINE_EXPECT(Stream_Read); DEFINE_EXPECT(Stream_Read);
DEFINE_EXPECT(Redirect); DEFINE_EXPECT(Redirect);
DEFINE_EXPECT(outer_QI_test);
static const WCHAR wszIndexHtml[] = {'i','n','d','e','x','.','h','t','m','l',0}; static const WCHAR wszIndexHtml[] = {'i','n','d','e','x','.','h','t','m','l',0};
static const WCHAR index_url[] = static const WCHAR index_url[] =
...@@ -3964,6 +3965,68 @@ static void test_binding(int prot, DWORD grf_pi, DWORD test_flags) ...@@ -3964,6 +3965,68 @@ static void test_binding(int prot, DWORD grf_pi, DWORD test_flags)
IInternetSession_Release(session); IInternetSession_Release(session);
} }
static const IID outer_test_iid = {0xabcabc00,0,0,{0,0,0,0,0,0,0,0x66}};
static HRESULT WINAPI outer_QueryInterface(IUnknown *iface, REFIID riid, void **ppv)
{
if(IsEqualGUID(riid, &outer_test_iid)) {
CHECK_EXPECT(outer_QI_test);
*ppv = (IUnknown*)0xdeadbeef;
return S_OK;
}
ok(0, "unexpected call %s\n", wine_dbgstr_guid(riid));
return E_NOINTERFACE;
}
static ULONG WINAPI outer_AddRef(IUnknown *iface)
{
return 2;
}
static ULONG WINAPI outer_Release(IUnknown *iface)
{
return 1;
}
static const IUnknownVtbl outer_vtbl = {
outer_QueryInterface,
outer_AddRef,
outer_Release
};
static void test_com_aggregation(const CLSID *clsid)
{
IUnknown outer = { &outer_vtbl };
IClassFactory *class_factory;
IUnknown *unk, *unk2, *unk3;
HRESULT hres;
hres = CoGetClassObject(clsid, CLSCTX_INPROC_SERVER, NULL, &IID_IClassFactory, (void**)&class_factory);
ok(hres == S_OK, "CoGetClassObject failed: %08x\n", hres);
hres = IClassFactory_CreateInstance(class_factory, &outer, &IID_IUnknown, (void**)&unk);
ok(hres == S_OK, "CreateInstance returned: %08x\n", hres);
hres = IUnknown_QueryInterface(unk, &IID_IInternetProtocol, (void**)&unk2);
ok(hres == S_OK, "Could not get IDispatch iface: %08x\n", hres);
SET_EXPECT(outer_QI_test);
hres = IUnknown_QueryInterface(unk2, &outer_test_iid, (void**)&unk3);
CHECK_CALLED(outer_QI_test);
ok(hres == S_OK, "Could not get IInternetProtocol iface: %08x\n", hres);
ok(unk3 == (IUnknown*)0xdeadbeef, "unexpected unk2\n");
IUnknown_Release(unk2);
IUnknown_Release(unk);
unk = (void*)0xdeadbeef;
hres = IClassFactory_CreateInstance(class_factory, &outer, &IID_IInternetProtocol, (void**)&unk);
ok(hres == CLASS_E_NOAGGREGATION, "CreateInstance returned: %08x\n", hres);
ok(!unk, "unk = %p\n", unk);
IClassFactory_Release(class_factory);
}
START_TEST(protocol) START_TEST(protocol)
{ {
HMODULE hurlmon; HMODULE hurlmon;
...@@ -4037,5 +4100,7 @@ START_TEST(protocol) ...@@ -4037,5 +4100,7 @@ START_TEST(protocol)
CloseHandle(event_continue); CloseHandle(event_continue);
CloseHandle(event_continue_done); CloseHandle(event_continue_done);
test_com_aggregation(&CLSID_FileProtocol);
OleUninitialize(); OleUninitialize();
} }
...@@ -303,19 +303,31 @@ static ULONG WINAPI CF_Release(IClassFactory *iface) ...@@ -303,19 +303,31 @@ static ULONG WINAPI CF_Release(IClassFactory *iface)
} }
static HRESULT WINAPI CF_CreateInstance(IClassFactory *iface, IUnknown *pOuter, static HRESULT WINAPI CF_CreateInstance(IClassFactory *iface, IUnknown *outer,
REFIID riid, LPVOID *ppobj) REFIID riid, void **ppv)
{ {
ClassFactory *This = impl_from_IClassFactory(iface); ClassFactory *This = impl_from_IClassFactory(iface);
IUnknown *unk;
HRESULT hres; HRESULT hres;
LPUNKNOWN punk;
TRACE("(%p)->(%p,%s,%p)\n",This,pOuter,debugstr_guid(riid),ppobj); TRACE("(%p)->(%p %s %p)\n", This, outer, debugstr_guid(riid), ppv);
*ppobj = NULL; if(outer && !IsEqualGUID(riid, &IID_IUnknown)) {
if(SUCCEEDED(hres = This->pfnCreateInstance(pOuter, (LPVOID *) &punk))) { *ppv = NULL;
hres = IUnknown_QueryInterface(punk, riid, ppobj); return CLASS_E_NOAGGREGATION;
IUnknown_Release(punk); }
hres = This->pfnCreateInstance(outer, (void**)&unk);
if(FAILED(hres)) {
*ppv = NULL;
return hres;
}
if(!IsEqualGUID(riid, &IID_IUnknown)) {
hres = IUnknown_QueryInterface(unk, riid, ppv);
IUnknown_Release(unk);
}else {
*ppv = unk;
} }
return hres; return hres;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment