diff --git a/dlls/ole32/compobj.c b/dlls/ole32/compobj.c index e007440563..92ace8a7d6 100644 --- a/dlls/ole32/compobj.c +++ b/dlls/ole32/compobj.c @@ -60,6 +60,7 @@ #include "ole2ver.h" #include "ctxtcall.h" #include "dde.h" +#include "servprov.h" #include "initguid.h" #include "compobj_private.h" @@ -96,6 +97,20 @@ struct registered_psclsid CLSID clsid; }; +/* + * This is a marshallable object exposing registered local servers. + * IServiceProvider is used only because it happens meet requirements + * and already has proxy/stub code. If more functionality is needed, + * a custom interface may be used instead. + */ +struct LocalServer +{ + IServiceProvider IServiceProvider_iface; + LONG ref; + APARTMENT *apt; + IStream *marshal_stream; +}; + /* * This lock count counts the number of times CoInitialize is called. It is * decreased every time CoUninitialize is called. When it hits 0, the COM @@ -122,7 +137,6 @@ typedef struct tagRegisteredClass DWORD runContext; DWORD connectFlags; DWORD dwCookie; - LPSTREAM pMarshaledData; /* FIXME: only really need to store OXID and IPID */ void *RpcRegistration; } RegisteredClass; @@ -544,20 +558,7 @@ static void COM_RevokeRegisteredClassObject(RegisteredClass *curClass) if (curClass->runContext & CLSCTX_LOCAL_SERVER) RPC_StopLocalServer(curClass->RpcRegistration); - /* - * Release the reference to the class object. - */ IUnknown_Release(curClass->classObject); - - if (curClass->pMarshaledData) - { - LARGE_INTEGER zero; - memset(&zero, 0, sizeof(zero)); - IStream_Seek(curClass->pMarshaledData, zero, STREAM_SEEK_SET, NULL); - CoReleaseMarshalData(curClass->pMarshaledData); - IStream_Release(curClass->pMarshaledData); - } - HeapFree(GetProcessHeap(), 0, curClass); } @@ -725,6 +726,130 @@ static HRESULT ManualResetEvent_Construct(IUnknown *punkouter, REFIID iid, void return hr; } +static inline LocalServer *impl_from_IServiceProvider(IServiceProvider *iface) +{ + return CONTAINING_RECORD(iface, LocalServer, IServiceProvider_iface); +} + +static HRESULT WINAPI LocalServer_QueryInterface(IServiceProvider *iface, REFIID riid, void **ppv) +{ + LocalServer *This = impl_from_IServiceProvider(iface); + + TRACE("(%p)->(%s %p)\n", This, debugstr_guid(riid), ppv); + + if(IsEqualGUID(riid, &IID_IUnknown) || IsEqualGUID(riid, &IID_IServiceProvider)) { + *ppv = &This->IServiceProvider_iface; + }else { + *ppv = NULL; + return E_NOINTERFACE; + } + + IUnknown_AddRef((IUnknown*)*ppv); + return S_OK; +} + +static ULONG WINAPI LocalServer_AddRef(IServiceProvider *iface) +{ + LocalServer *This = impl_from_IServiceProvider(iface); + LONG ref = InterlockedIncrement(&This->ref); + + TRACE("(%p) ref=%d\n", This, ref); + + return ref; +} + +static ULONG WINAPI LocalServer_Release(IServiceProvider *iface) +{ + LocalServer *This = impl_from_IServiceProvider(iface); + LONG ref = InterlockedDecrement(&This->ref); + + TRACE("(%p) ref=%d\n", This, ref); + + if(!ref) { + assert(!This->apt); + HeapFree(GetProcessHeap(), 0, This); + } + + return ref; +} + +static HRESULT WINAPI LocalServer_QueryService(IServiceProvider *iface, REFGUID guid, REFIID riid, void **ppv) +{ + LocalServer *This = impl_from_IServiceProvider(iface); + APARTMENT *apt = COM_CurrentApt(); + RegisteredClass *iter; + HRESULT hres = E_FAIL; + + TRACE("(%p)->(%s %s %p)\n", This, debugstr_guid(guid), debugstr_guid(riid), ppv); + + if(!This->apt) + return E_UNEXPECTED; + + EnterCriticalSection(&csRegisteredClassList); + + LIST_FOR_EACH_ENTRY(iter, &RegisteredClassList, RegisteredClass, entry) { + if(iter->apartment_id == apt->oxid + && (iter->runContext & CLSCTX_LOCAL_SERVER) + && IsEqualGUID(&iter->classIdentifier, guid)) { + hres = IUnknown_QueryInterface(iter->classObject, riid, ppv); + break; + } + } + + LeaveCriticalSection( &csRegisteredClassList ); + + return hres; +} + +static const IServiceProviderVtbl LocalServerVtbl = { + LocalServer_QueryInterface, + LocalServer_AddRef, + LocalServer_Release, + LocalServer_QueryService +}; + +static HRESULT get_local_server_stream(APARTMENT *apt, IStream **ret) +{ + HRESULT hres = S_OK; + + EnterCriticalSection(&apt->cs); + + if(!apt->local_server) { + LocalServer *obj; + + obj = heap_alloc(sizeof(*obj)); + if(obj) { + obj->IServiceProvider_iface.lpVtbl = &LocalServerVtbl; + obj->ref = 1; + obj->apt = apt; + + hres = CreateStreamOnHGlobal(0, TRUE, &obj->marshal_stream); + if(SUCCEEDED(hres)) { + hres = CoMarshalInterface(obj->marshal_stream, &IID_IServiceProvider, (IUnknown*)&obj->IServiceProvider_iface, + MSHCTX_LOCAL, NULL, MSHLFLAGS_TABLESTRONG); + if(FAILED(hres)) + IStream_Release(obj->marshal_stream); + } + + if(SUCCEEDED(hres)) + apt->local_server = obj; + else + heap_free(obj); + }else { + hres = E_OUTOFMEMORY; + } + } + + if(SUCCEEDED(hres)) + hres = IStream_Clone(apt->local_server->marshal_stream, ret); + + LeaveCriticalSection(&apt->cs); + + if(FAILED(hres)) + ERR("Failed: %08x\n", hres); + return hres; +} + /*********************************************************************** * CoRevokeClassObject [OLE32.@] * @@ -855,6 +980,21 @@ DWORD apartment_release(struct apartment *apt) TRACE("destroying apartment %p, oxid %s\n", apt, wine_dbgstr_longlong(apt->oxid)); + if(apt->local_server) { + LocalServer *local_server = apt->local_server; + LARGE_INTEGER zero; + + memset(&zero, 0, sizeof(zero)); + IStream_Seek(local_server->marshal_stream, zero, STREAM_SEEK_SET, NULL); + CoReleaseMarshalData(local_server->marshal_stream); + IStream_Release(local_server->marshal_stream); + local_server->marshal_stream = NULL; + + apt->local_server = NULL; + local_server->apt = NULL; + IServiceProvider_Release(&local_server->IServiceProvider_iface); + } + /* Release the references to the registered class objects */ COM_RevokeAllClasses(apt); @@ -2411,7 +2551,6 @@ HRESULT WINAPI CoRegisterClassObject( newClass->apartment_id = apt->oxid; newClass->runContext = dwClsContext; newClass->connectFlags = flags; - newClass->pMarshaledData = NULL; newClass->RpcRegistration = NULL; if (!(newClass->dwCookie = InterlockedIncrement( &next_cookie ))) @@ -2431,23 +2570,17 @@ HRESULT WINAPI CoRegisterClassObject( *lpdwRegister = newClass->dwCookie; if (dwClsContext & CLSCTX_LOCAL_SERVER) { - hr = CreateStreamOnHGlobal(0, TRUE, &newClass->pMarshaledData); - if (hr) { - FIXME("Failed to create stream on hglobal, %x\n", hr); + IStream *marshal_stream; + + hr = get_local_server_stream(apt, &marshal_stream); + if(FAILED(hr)) return hr; - } - hr = CoMarshalInterface(newClass->pMarshaledData, &IID_IUnknown, - newClass->classObject, MSHCTX_LOCAL, NULL, - MSHLFLAGS_TABLESTRONG); - if (hr) { - FIXME("CoMarshalInterface failed, %x!\n",hr); - return hr; - } hr = RPC_StartLocalServer(&newClass->classIdentifier, - newClass->pMarshaledData, + marshal_stream, flags & (REGCLS_MULTIPLEUSE|REGCLS_MULTI_SEPARATE), &newClass->RpcRegistration); + IStream_Release(marshal_stream); } return S_OK; } diff --git a/dlls/ole32/compobj_private.h b/dlls/ole32/compobj_private.h index b76e3a72b8..f2c9d31360 100644 --- a/dlls/ole32/compobj_private.h +++ b/dlls/ole32/compobj_private.h @@ -40,6 +40,7 @@ struct apartment; typedef struct apartment APARTMENT; +typedef struct LocalServer LocalServer; DEFINE_OLEGUID( CLSID_DfMarshal, 0x0000030b, 0, 0 ); @@ -137,6 +138,7 @@ struct apartment struct list loaded_dlls; /* list of dlls loaded by this apartment (CS cs) */ DWORD host_apt_tid; /* thread ID of apartment hosting objects of differing threading model (CS cs) */ HWND host_apt_hwnd; /* handle to apartment window of host apartment (CS cs) */ + LocalServer *local_server; /* A marshallable object exposing local servers (CS cs) */ /* FIXME: OIDs should be given out by RPCSS */ OID oidc; /* object ID counter, starts at 1, zero is invalid OID (CS cs) */ @@ -312,4 +314,14 @@ extern UINT ole_private_data_clipboard_format DECLSPEC_HIDDEN; extern LSTATUS create_classes_key(HKEY, const WCHAR *, REGSAM, HKEY *) DECLSPEC_HIDDEN; extern LSTATUS open_classes_key(HKEY, const WCHAR *, REGSAM, HKEY *) DECLSPEC_HIDDEN; +static inline void *heap_alloc(size_t len) +{ + return HeapAlloc(GetProcessHeap(), 0, len); +} + +static inline BOOL heap_free(void *mem) +{ + return HeapFree(GetProcessHeap(), 0, mem); +} + #endif /* __WINE_OLE_COMPOBJ_H */ diff --git a/dlls/ole32/errorinfo.c b/dlls/ole32/errorinfo.c index b71ee2d947..d5ec17207a 100644 --- a/dlls/ole32/errorinfo.c +++ b/dlls/ole32/errorinfo.c @@ -41,16 +41,6 @@ WINE_DEFAULT_DEBUG_CHANNEL(ole); -static inline void *heap_alloc(size_t len) -{ - return HeapAlloc(GetProcessHeap(), 0, len); -} - -static inline BOOL heap_free(void *mem) -{ - return HeapFree(GetProcessHeap(), 0, mem); -} - static inline WCHAR *heap_strdupW(const WCHAR *str) { WCHAR *ret = NULL; diff --git a/dlls/ole32/rpc.c b/dlls/ole32/rpc.c index 625938074a..5bade5c714 100644 --- a/dlls/ole32/rpc.c +++ b/dlls/ole32/rpc.c @@ -39,6 +39,7 @@ #include "rpc.h" #include "winerror.h" #include "winreg.h" +#include "servprov.h" #include "wine/unicode.h" #include "compobj_private.h" @@ -1806,6 +1807,7 @@ HRESULT RPC_GetLocalClassObject(REFCLSID rclsid, REFIID iid, LPVOID *ppv) LARGE_INTEGER seekto; ULARGE_INTEGER newpos; int tries = 0; + IServiceProvider *local_server; static const int MAXTRIES = 30; /* 30 seconds */ @@ -1865,8 +1867,11 @@ HRESULT RPC_GetLocalClassObject(REFCLSID rclsid, REFIID iid, LPVOID *ppv) seekto.u.LowPart = 0;seekto.u.HighPart = 0; hres = IStream_Seek(pStm,seekto,STREAM_SEEK_SET,&newpos); - TRACE("unmarshalling classfactory\n"); - hres = CoUnmarshalInterface(pStm,&IID_IClassFactory,ppv); + TRACE("unmarshalling local server\n"); + hres = CoUnmarshalInterface(pStm, &IID_IServiceProvider, (void**)&local_server); + if(SUCCEEDED(hres)) + hres = IServiceProvider_QueryService(local_server, rclsid, iid, ppv); + IServiceProvider_Release(local_server); out: IStream_Release(pStm); return hres; @@ -1927,7 +1932,7 @@ static DWORD WINAPI local_server_thread(LPVOID param) } } - TRACE("marshalling IClassFactory to client\n"); + TRACE("marshalling LocalServer to client\n"); hres = IStream_Stat(pStm,&ststg,STATFLAG_NONAME); if (hres) @@ -1957,7 +1962,7 @@ static DWORD WINAPI local_server_thread(LPVOID param) FlushFileBuffers(hPipe); DisconnectNamedPipe(hPipe); - TRACE("done marshalling IClassFactory\n"); + TRACE("done marshalling LocalServer\n"); if (!multi_use) { diff --git a/dlls/ole32/tests/defaulthandler.c b/dlls/ole32/tests/defaulthandler.c index e524f97a78..0196e514d4 100644 --- a/dlls/ole32/tests/defaulthandler.c +++ b/dlls/ole32/tests/defaulthandler.c @@ -274,11 +274,9 @@ static void test_default_handler_run(void) CoRevokeClassObject(class_reg); todo_wine CHECK_CALLED(CF_QueryInterface_IMarshal); - SET_EXPECT(CF_QueryInterface_IMarshal); hres = CoRegisterClassObject(&test_server_clsid, (IUnknown*)&ClassFactory, CLSCTX_LOCAL_SERVER, 0, &class_reg); ok(hres == S_OK, "CoRegisterClassObject failed: %x\n", hres); - todo_wine CHECK_NOT_CALLED(CF_QueryInterface_IMarshal); hres = OleCreateDefaultHandler(&test_server_clsid, NULL, &IID_IUnknown, (void**)&unk); ok(hres == S_OK, "OleCreateDefaultHandler failed: %x\n", hres);