Commit b0218db9 authored by Robert Shearman's avatar Robert Shearman Committed by Alexandre Julliard

oleaut32: Fix circular reference count in Typelib marshaler.

The current method of handling typelib-marshaled interfaces that derive from IDispatch is to query for an IDispatch pointer from the proxy, but this causes a circular reference count. Fix the reference counting by loading using the IRpcProxyBuffer of IDispatch without an outer unknown, so that the lifetime is controlled by the typelib-marshaled interface's proxy. The IDispatch proxy now shares the same channel as the typelib-marshaled interface, so fix up the stub side to handle this.
parent fd81d9c5
......@@ -355,6 +355,7 @@ typedef struct _TMProxyImpl {
CRITICAL_SECTION crit;
IUnknown *outerunknown;
IDispatch *dispatch;
IRpcProxyBuffer *dispatch_proxy;
} TMProxyImpl;
static HRESULT WINAPI
......@@ -391,7 +392,7 @@ TMProxyImpl_Release(LPRPCPROXYBUFFER iface)
if (!refCount)
{
if (This->dispatch) IDispatch_Release(This->dispatch);
if (This->dispatch_proxy) IRpcProxyBuffer_Release(This->dispatch_proxy);
DeleteCriticalSection(&This->crit);
if (This->chanbuf) IRpcChannelBuffer_Release(This->chanbuf);
VirtualFree(This->asmstubs, 0, MEM_RELEASE);
......@@ -415,6 +416,9 @@ TMProxyImpl_Connect(
LeaveCriticalSection(&This->crit);
if (This->dispatch_proxy)
IRpcProxyBuffer_Connect(This->dispatch_proxy, pRpcChannelBuffer);
return S_OK;
}
......@@ -431,6 +435,9 @@ TMProxyImpl_Disconnect(LPRPCPROXYBUFFER iface)
This->chanbuf = NULL;
LeaveCriticalSection(&This->crit);
if (This->dispatch_proxy)
IRpcProxyBuffer_Disconnect(This->dispatch_proxy);
}
......@@ -1380,56 +1387,29 @@ ULONG WINAPI ProxyIUnknown_Release(IUnknown *iface)
static HRESULT WINAPI ProxyIDispatch_GetTypeInfoCount(LPDISPATCH iface, UINT * pctinfo)
{
TMProxyImpl *This = (TMProxyImpl *)iface;
HRESULT hr;
TRACE("(%p)\n", pctinfo);
if (!This->dispatch)
{
hr = IUnknown_QueryInterface(This->outerunknown, &IID_IDispatch,
(LPVOID *)&This->dispatch);
}
if (This->dispatch)
hr = IDispatch_GetTypeInfoCount(This->dispatch, pctinfo);
return hr;
return IDispatch_GetTypeInfoCount(This->dispatch, pctinfo);
}
static HRESULT WINAPI ProxyIDispatch_GetTypeInfo(LPDISPATCH iface, UINT iTInfo, LCID lcid, ITypeInfo** ppTInfo)
{
TMProxyImpl *This = (TMProxyImpl *)iface;
HRESULT hr = S_OK;
TRACE("(%d, %lx, %p)\n", iTInfo, lcid, ppTInfo);
if (!This->dispatch)
{
hr = IUnknown_QueryInterface(This->outerunknown, &IID_IDispatch,
(LPVOID *)&This->dispatch);
}
if (This->dispatch)
hr = IDispatch_GetTypeInfo(This->dispatch, iTInfo, lcid, ppTInfo);
return hr;
return IDispatch_GetTypeInfo(This->dispatch, iTInfo, lcid, ppTInfo);
}
static HRESULT WINAPI ProxyIDispatch_GetIDsOfNames(LPDISPATCH iface, REFIID riid, LPOLESTR * rgszNames, UINT cNames, LCID lcid, DISPID * rgDispId)
{
TMProxyImpl *This = (TMProxyImpl *)iface;
HRESULT hr;
TRACE("(%s, %p, %d, 0x%lx, %p)\n", debugstr_guid(riid), rgszNames, cNames, lcid, rgDispId);
if (!This->dispatch)
{
hr = IUnknown_QueryInterface(This->outerunknown, &IID_IDispatch,
(LPVOID *)&This->dispatch);
}
if (This->dispatch)
hr = IDispatch_GetIDsOfNames(This->dispatch, riid, rgszNames,
cNames, lcid, rgDispId);
return hr;
return IDispatch_GetIDsOfNames(This->dispatch, riid, rgszNames,
cNames, lcid, rgDispId);
}
static HRESULT WINAPI ProxyIDispatch_Invoke(LPDISPATCH iface, DISPID dispIdMember, REFIID riid, LCID lcid,
......@@ -1437,21 +1417,25 @@ static HRESULT WINAPI ProxyIDispatch_Invoke(LPDISPATCH iface, DISPID dispIdMembe
EXCEPINFO * pExcepInfo, UINT * puArgErr)
{
TMProxyImpl *This = (TMProxyImpl *)iface;
HRESULT hr;
TRACE("(%ld, %s, 0x%lx, 0x%x, %p, %p, %p, %p)\n", dispIdMember, debugstr_guid(riid), lcid, wFlags, pDispParams, pVarResult, pExcepInfo, puArgErr);
TRACE("(%ld, %s, 0x%lx, 0x%x, %p, %p, %p, %p)\n", dispIdMember,
debugstr_guid(riid), lcid, wFlags, pDispParams, pVarResult,
pExcepInfo, puArgErr);
if (!This->dispatch)
{
hr = IUnknown_QueryInterface(This->outerunknown, &IID_IDispatch,
(LPVOID *)&This->dispatch);
}
if (This->dispatch)
hr = IDispatch_Invoke(This->dispatch, dispIdMember, riid, lcid,
wFlags, pDispParams, pVarResult, pExcepInfo,
puArgErr);
return IDispatch_Invoke(This->dispatch, dispIdMember, riid, lcid,
wFlags, pDispParams, pVarResult, pExcepInfo,
puArgErr);
}
return hr;
static inline HRESULT get_facbuf_for_iid(REFIID riid, IPSFactoryBuffer **facbuf)
{
HRESULT hr;
CLSID clsid;
if ((hr = CoGetPSClsid(riid, &clsid)))
return hr;
return CoGetClassObject(&clsid, CLSCTX_INPROC_SERVER, NULL,
&IID_IPSFactoryBuffer, (LPVOID*)facbuf);
}
static HRESULT WINAPI
......@@ -1479,6 +1463,7 @@ PSFacBuf_CreateProxy(
assert(sizeof(TMAsmProxy) == 12);
proxy->dispatch = NULL;
proxy->dispatch_proxy = NULL;
proxy->outerunknown = pUnkOuter;
proxy->asmstubs = VirtualAlloc(NULL, sizeof(TMAsmProxy) * nroffuncs, MEM_COMMIT, PAGE_EXECUTE_READWRITE);
if (!proxy->asmstubs) {
......@@ -1558,10 +1543,22 @@ PSFacBuf_CreateProxy(
{
if (typeattr->wTypeFlags & TYPEFLAG_FDISPATCHABLE)
{
proxy->lpvtbl[3] = ProxyIDispatch_GetTypeInfoCount;
proxy->lpvtbl[4] = ProxyIDispatch_GetTypeInfo;
proxy->lpvtbl[5] = ProxyIDispatch_GetIDsOfNames;
proxy->lpvtbl[6] = ProxyIDispatch_Invoke;
IPSFactoryBuffer *factory_buffer;
hres = get_facbuf_for_iid(&IID_IDispatch, &factory_buffer);
if (hres == S_OK)
{
hres = IPSFactoryBuffer_CreateProxy(factory_buffer, NULL,
&IID_IDispatch, &proxy->dispatch_proxy,
(void **)&proxy->dispatch);
IPSFactoryBuffer_Release(factory_buffer);
}
if (hres == S_OK)
{
proxy->lpvtbl[3] = ProxyIDispatch_GetTypeInfoCount;
proxy->lpvtbl[4] = ProxyIDispatch_GetTypeInfo;
proxy->lpvtbl[5] = ProxyIDispatch_GetIDsOfNames;
proxy->lpvtbl[6] = ProxyIDispatch_Invoke;
}
}
ITypeInfo_ReleaseTypeAttr(tinfo, typeattr);
}
......@@ -1572,10 +1569,16 @@ PSFacBuf_CreateProxy(
proxy->tinfo = tinfo;
memcpy(&proxy->iid,riid,sizeof(*riid));
proxy->chanbuf = 0;
*ppv = (LPVOID)proxy;
*ppProxy = (IRpcProxyBuffer *)&(proxy->lpvtbl2);
IUnknown_AddRef((IUnknown *)*ppv);
return S_OK;
if (hres == S_OK)
{
*ppv = (LPVOID)proxy;
*ppProxy = (IRpcProxyBuffer *)&(proxy->lpvtbl2);
IUnknown_AddRef((IUnknown *)*ppv);
return S_OK;
}
else
TMProxyImpl_Release((IRpcProxyBuffer *)&proxy->lpvtbl2);
return hres;
}
typedef struct _TMStubImpl {
......@@ -1585,6 +1588,7 @@ typedef struct _TMStubImpl {
LPUNKNOWN pUnk;
ITypeInfo *tinfo;
IID iid;
IRpcStubBuffer *dispatch_stub;
} TMStubImpl;
static HRESULT WINAPI
......@@ -1636,6 +1640,10 @@ TMStubImpl_Connect(LPRPCSTUBBUFFER iface, LPUNKNOWN pUnkServer)
IUnknown_AddRef(pUnkServer);
This->pUnk = pUnkServer;
if (This->dispatch_stub)
IRpcStubBuffer_Connect(This->dispatch_stub, pUnkServer);
return S_OK;
}
......@@ -1651,6 +1659,9 @@ TMStubImpl_Disconnect(LPRPCSTUBBUFFER iface)
IUnknown_Release(This->pUnk);
This->pUnk = NULL;
}
if (This->dispatch_stub)
IRpcStubBuffer_Disconnect(This->dispatch_stub);
}
static HRESULT WINAPI
......@@ -1668,12 +1679,6 @@ TMStubImpl_Invoke(
BSTR iname = NULL;
ITypeInfo *tinfo;
memset(&buf,0,sizeof(buf));
buf.size = xmsg->cbBuffer;
buf.base = HeapAlloc(GetProcessHeap(), 0, xmsg->cbBuffer);
memcpy(buf.base, xmsg->Buffer, xmsg->cbBuffer);
buf.curoff = 0;
TRACE("...\n");
if (xmsg->iMethod < 3) {
......@@ -1681,6 +1686,15 @@ TMStubImpl_Invoke(
return E_UNEXPECTED;
}
if (This->dispatch_stub && xmsg->iMethod < sizeof(IDispatchVtbl)/sizeof(void *))
return IRpcStubBuffer_Invoke(This->dispatch_stub, xmsg, rpcchanbuf);
memset(&buf,0,sizeof(buf));
buf.size = xmsg->cbBuffer;
buf.base = HeapAlloc(GetProcessHeap(), 0, xmsg->cbBuffer);
memcpy(buf.base, xmsg->Buffer, xmsg->cbBuffer);
buf.curoff = 0;
hres = _get_funcdesc(This->tinfo,xmsg->iMethod,&tinfo,&fdesc,&iname,NULL);
if (hres) {
ERR("GetFuncDesc on method %ld failed with %lx\n",xmsg->iMethod,hres);
......@@ -1839,25 +1853,48 @@ PSFacBuf_CreateStub(
HRESULT hres;
ITypeInfo *tinfo;
TMStubImpl *stub;
TYPEATTR *typeattr;
TRACE("(%s,%p,%p)\n",debugstr_guid(riid),pUnkServer,ppStub);
hres = _get_typeinfo_for_iid(riid,&tinfo);
if (hres) {
ERR("No typeinfo for %s?\n",debugstr_guid(riid));
return hres;
}
stub = CoTaskMemAlloc(sizeof(TMStubImpl));
if (!stub)
return E_OUTOFMEMORY;
stub->lpvtbl = &tmstubvtbl;
stub->ref = 1;
stub->tinfo = tinfo;
stub->dispatch_stub = NULL;
memcpy(&(stub->iid),riid,sizeof(*riid));
hres = IRpcStubBuffer_Connect((LPRPCSTUBBUFFER)stub,pUnkServer);
*ppStub = (LPRPCSTUBBUFFER)stub;
TRACE("IRpcStubBuffer: %p\n", stub);
if (hres)
ERR("Connect to pUnkServer failed?\n");
/* if we derive from IDispatch then defer to its stub for some of its methods */
hres = ITypeInfo_GetTypeAttr(tinfo, &typeattr);
if (hres == S_OK)
{
if (typeattr->wTypeFlags & TYPEFLAG_FDISPATCHABLE)
{
IPSFactoryBuffer *factory_buffer;
hres = get_facbuf_for_iid(&IID_IDispatch, &factory_buffer);
if (hres == S_OK)
{
hres = IPSFactoryBuffer_CreateStub(factory_buffer, &IID_IDispatch,
pUnkServer, &stub->dispatch_stub);
IPSFactoryBuffer_Release(factory_buffer);
}
}
ITypeInfo_ReleaseTypeAttr(tinfo, typeattr);
}
return hres;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment