C++/COM/Proxy Dlls: method override/method forwarding (COM implementation inheritance)

Hello and good day to you.

Situation:
For some reason, from time to time I run into a situation where I need to override one or two methods of a COM interface used by an older application I don't have the source code for. The interface is normally Direct3D/DirectInput related (i.e. it is created by calling an exported DLL function such as Direct3DCreate9, not by CoCreateInstance). Normally I deal with this by writing a proxy DLL that overrides the function that creates the interface I need to "modify" and replaces the original interface with my own. This is usually required to make an older application work properly without crashes or artifacts.

Compiler:
I use Visual Studio Express 2008 on a Windows machine, so no C++0x features are available. The system has msysgit, MSYS, Python, Perl, GNU utilities (awk/sed/wc/bash/etc.), GNU make and qmake (Qt 4.7.1) installed and available on PATH.

Problem:
Overriding one method of a COM interface is a pain (especially if the original interface has a hundred methods or so), because I need to forward every other call to the original interface, and I currently see no way to simplify or automate the process. For example, an override of IDirect3D9 looks like this:

class MyD3D9: public IDirect3D9{
protected:
    volatile LONG refCount;
    IDirect3D9 *orig;
public:
    STDMETHOD(QueryInterface)(THIS_ REFIID riid, LPVOID * ppvObj){
        if (!ppvObj)
            return E_INVALIDARG;
        *ppvObj = NULL;
        if (riid == IID_IUnknown  || riid == IID_IDirect3D9){
            *ppvObj = (LPVOID)this;
            AddRef();
            return NOERROR;
        }
        return E_NOINTERFACE;
    }

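    STDMETHOD_(ULONG,AddRef)(THIS){
        // Return the value produced by the interlocked call; re-reading
        // refCount afterwards would race with other threads.
        return InterlockedIncrement(&refCount);
    }
    STDMETHOD_(ULONG,Release)(THIS){
        ULONG ref = InterlockedDecrement(&refCount);
        if (ref == 0)   // test the local copy, not the shared counter
            delete this;
        return ref;
    }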

    /*** IDirect3D9 methods ***/
    STDMETHOD(RegisterSoftwareDevice)(THIS_ void* pInitializeFunction){
        if (!orig)
            return E_FAIL;
        return orig->RegisterSoftwareDevice(pInitializeFunction);
    }

    STDMETHOD_(UINT, GetAdapterCount)(THIS){
        if (!orig)
            return 0;
        return orig->GetAdapterCount();
    }

    STDMETHOD(GetAdapterIdentifier)(THIS_ UINT Adapter,DWORD Flags, D3DADAPTER_IDENTIFIER9* pIdentifier){
        if (!orig)
            return E_FAIL;
        return orig->GetAdapterIdentifier(Adapter, Flags, pIdentifier);
    }

    STDMETHOD_(UINT, GetAdapterModeCount)(THIS_ UINT Adapter,D3DFORMAT Format){
        if (!orig)
            return 0;
        return orig->GetAdapterModeCount(Adapter, Format);
    }
/* some code skipped*/

    MyD3D9(IDirect3D9* origD3D9)
        :refCount(1), orig(origD3D9){
    }

    ~MyD3D9(){
        if (orig){
            orig->Release();
            orig = 0;
        }
    }
};

As you can see, this is very inefficient, error-prone and requires a lot of copy-pasting.

Question:
How can I simplify overriding a single method of a COM interface in this situation? I would like to specify only the method I change, but I currently see no way to do so. I also don't see a way to elegantly shorten the "forwarded" methods with macros or templates, because they have a variable number of arguments. Another approach I saw is to directly patch the method table of the interface returned by another method (change its access rights with VirtualProtect, then write into the table), which I don't exactly like.

Limitations:
I would prefer to solve this in C++ source code (macros/templates) and without code generators, unless using the code generator is extremely simple and elegant: writing my own code generator is not ok, but using an already available one that I can set up in minutes and that solves the whole thing in one line of code is ok. Boost is okay only if it doesn't add an extra DLL dependency. MS-specific compiler directives and language extensions are also ok.

Ideas? Thanks in advance.

Asked Mar 14 '26 19:03 by SigTerm


1 Answer

Okay, since I don't like unanswered questions...

There's currently no sane and compact pure-C++ way to implement "COM implementation inheritance". This is mostly because C++ forbids creating an instance of an abstract class (so a class derived from the interface must implement every pure virtual method before it can be instantiated) and offers no way to manipulate the virtual method table directly. As a result, there are two commonly used solutions:

  1. Write a forwarding wrapper for every method manually.
  2. Hack the dispatch table (vtable) of the existing object.

The advantage of #1 is that this approach is safe and you can store additional data within the custom class. The disadvantage of #1 is that writing a wrapper for every single method is an extremely tedious procedure.
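Approach #1 can at least be made much shorter without C++0x. The trick for the "variable number of arguments" problem is to pass the parameter list and the argument list as two parenthesized macro arguments: commas inside parentheses don't split macro arguments, so one macro covers every arity. Below is a minimal sketch; `IGreeter`, `RealGreeter`, `GreeterProxy` and `FORWARD_METHOD` are hypothetical names, and an ordinary abstract class stands in for a COM interface so the example compiles anywhere.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for a COM interface with several methods.
struct IGreeter {
    virtual ~IGreeter() {}
    virtual int Count() = 0;
    virtual int Add(int a, int b) = 0;
    virtual std::string Greet(const std::string& name, int times) = 0;
};

// Stand-in for the original object we cannot modify.
struct RealGreeter : IGreeter {
    int Count() { return 3; }
    int Add(int a, int b) { return a + b; }
    std::string Greet(const std::string& name, int times) {
        std::string s;
        for (int i = 0; i < times; ++i)
            s += "hi " + name + ";";
        return s;
    }
};

// Forwards one method to `orig`. Params and Args are parenthesized lists, so
// commas inside them don't split macro arguments; no variadic macros needed.
#define FORWARD_METHOD(Ret, Name, Params, Args) \
    Ret Name Params { return orig->Name Args; }

class GreeterProxy : public IGreeter {
    IGreeter* orig;  // not owned in this sketch
public:
    explicit GreeterProxy(IGreeter* o) : orig(o) { assert(o != 0); }

    // Unchanged methods: one line each.
    FORWARD_METHOD(int, Count, (), ())
    FORWARD_METHOD(std::string, Greet, (const std::string& name, int times), (name, times))

    // The one method that is actually overridden.
    int Add(int a, int b) { return orig->Add(a, b) * 10; }
};
```

For the IDirect3D9 wrapper in the question, a line would look like `FORWARD_METHOD(HRESULT STDMETHODCALLTYPE, GetAdapterIdentifier, (UINT Adapter, DWORD Flags, D3DADAPTER_IDENTIFIER9* pIdentifier), (Adapter, Flags, pIdentifier))`. Methods returning void work too, because `return f();` is legal in a void function when f returns void. It is still one line per method, the sketch drops the `if (!orig)` checks, and a return type containing an unparenthesized comma would need a typedef first.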

The advantage of #2 is that this approach is compact: you replace a single method. The disadvantage of #2 is that the dispatch table might be located in write-protected memory (with MSVC-built code it typically is, since vtables go into a read-only section, so you have to change the page protection with VirtualProtect first), and you can't store custom data in the hacked interface. As a result, although it is simple and short, it is quite limiting.
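Approach #2 in miniature: the sketch below uses a C-style interface (the object's first member points to a table of function pointers, which is how a COM object looks from C), so it doesn't depend on any compiler's vtable layout. The table is deliberately writable here; for a real Direct3D object you would first have to make the page writable with VirtualProtect. Two consequences are visible: the hook must keep the original entry to call through, and because every object of the class shares one table, patching it changes the method for all instances while per-object data has to live outside the object (here, a global counter). All names are hypothetical.

```cpp
#include <cassert>

// C-style interface: the object's first member points to a table of
// function pointers, which is the layout COM uses.
struct ICounter;
struct ICounterVtbl {
    int  (*GetValue)(ICounter* self);
    void (*Increment)(ICounter* self, int by);
};
struct ICounter {
    ICounterVtbl* lpVtbl;  // one table shared by every object of the class
    int value;             // the original implementation's private state
};

// Original implementation that we cannot modify.
int  Counter_GetValue(ICounter* self)          { return self->value; }
void Counter_Increment(ICounter* self, int by) { self->value += by; }

// Writable here; a real vtable usually is not (hence VirtualProtect).
ICounterVtbl g_counterVtbl = { Counter_GetValue, Counter_Increment };

// Saved original entry, needed to call through from the hook.
void (*g_origIncrement)(ICounter*, int) = 0;
// Per-object data can't be added to the patched object, so it lives here.
int g_hookCalls = 0;

void Hooked_Increment(ICounter* self, int by) {
    ++g_hookCalls;
    g_origIncrement(self, by * 2);  // changed behaviour, then the original
}

// Replaces exactly one slot; all objects sharing the table are affected.
void installHook(ICounterVtbl* vtbl) {
    assert(g_origIncrement == 0);   // install only once
    g_origIncrement = vtbl->Increment;
    vtbl->Increment = Hooked_Increment;
}
```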

And there's a third approach (which nobody has suggested, for some reason).

Short description: instead of using the virtual method table provided by C++, write a non-virtual class that emulates a virtual method table.

Example:

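// Copies the bits of src into dst (used to store member-function pointers in
// the table). The array size becomes -1, a compile error, if the sizes differ.
template<typename T1, typename T2> void unsafeCast(T1 &dst, const T2 &src){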
    int i[sizeof(dst) == sizeof(src)? 1: -1] = {0};
    union{
        T2 src;
        T1 dst;
    }u;
    u.src = src;
    dst = u.dst;
}

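// Generic forwarder for table slot Index (x86, __stdcall): replaces `this` on
// the stack with this->orig and jumps to orig's own method at that slot.
// MSVC supports neither __asm nor __declspec(naked) for x64 targets, so this
// code cannot simply be adjusted for 64-bit builds.
template<int Index> void __declspec(naked) vtblMapper(){
#define pointerSize 4 // size of a pointer on x86
    static const int methodOffset = sizeof(void*)*Index;
    __asm{
        mov eax, [esp + pointerSize]  // eax = this (first stack argument)
        mov eax, [eax + pointerSize]  // eax = this->orig (second member)
        mov [esp + pointerSize], eax  // pass orig as `this` instead
        mov eax, [eax]                // eax = orig's vtable
        add eax, methodOffset         // eax = &vtable[Index]
        mov eax, [eax]                // eax = vtable[Index]
        jmp eax                       // the original method returns straight to the caller
    };
#undef pointerSize
}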

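// Type of one slot in the hand-built method table.
typedef void (*VtblMethod)();

struct MyD3DIndexBuffer9{
protected:
    VtblMethod* vtbl;               // must be the first member: callers use it as the vtable pointer
    IDirect3DIndexBuffer9* orig;    // must be the second member: vtblMapper reads it at [this + 4]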
    volatile LONG refCount;
    enum{vtblSize = 14};
    DWORD flags;
    bool dynamic, writeonly;
public:
    inline IDirect3DIndexBuffer9*getOriginalPtr(){
        return orig;
    }

    HRESULT __declspec(nothrow) __stdcall QueryInterface(REFIID riid, LPVOID * ppvObj){
        if (!ppvObj)
            return E_INVALIDARG;
        *ppvObj = NULL;
        if (riid == IID_IUnknown  || riid == IID_IDirect3DIndexBuffer9){
            *ppvObj = (LPVOID)this;
            AddRef();
            return NOERROR;
        }
        return E_NOINTERFACE;
    }

    ULONG __declspec(nothrow) __stdcall AddRef(){
        // Use the interlocked result; re-reading refCount would race.
        return InterlockedIncrement(&refCount);
    }

    ULONG __declspec(nothrow) __stdcall Release(){
        ULONG ref = InterlockedDecrement(&refCount);
        if (ref == 0)   // test the local copy, not the shared counter
            delete this;
        return ref;
    }

    MyD3DIndexBuffer9(IDirect3DIndexBuffer9* origIb, DWORD flags_)
            :vtbl(0), orig(origIb), refCount(1), flags(flags_), dynamic(false), writeonly(false){
        dynamic = (flags & D3DUSAGE_DYNAMIC) != 0;
        writeonly = (flags & D3DUSAGE_WRITEONLY) != 0;
        vtbl = new VtblMethod[vtblSize];
        initVtbl();
    }

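    // The overridden method (slot 11). This version only forwards; the custom
    // behaviour that motivates the wrapper goes here.
    HRESULT __declspec(nothrow) __stdcall Lock(UINT OffsetToLock, UINT SizeToLock, void** ppbData, DWORD Flags){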
        if (!orig)
            return E_FAIL;
        return orig->Lock(OffsetToLock, SizeToLock, ppbData, Flags);
    }

    ~MyD3DIndexBuffer9(){
        if (orig){
            orig->Release();
            orig = 0;
        }
        delete[] vtbl;
    }
private:
    void initVtbl(){
        int index = 0;
        for (int i = 0; i < vtblSize; i++)
            vtbl[i] = 0;

#define defaultInit(i) vtbl[i] = &vtblMapper<(i)>; index++
        //STDMETHOD(QueryInterface)(THIS_ REFIID riid, void** ppvObj) PURE;
        unsafeCast(vtbl[0], &MyD3DIndexBuffer9::QueryInterface); index++;
        //STDMETHOD_(ULONG,AddRef)(THIS) PURE;
        unsafeCast(vtbl[1], &MyD3DIndexBuffer9::AddRef); index++;
        //STDMETHOD_(ULONG,Release)(THIS) PURE;
        unsafeCast(vtbl[2], &MyD3DIndexBuffer9::Release); index++;

        // IDirect3DResource9 methods 
        //STDMETHOD(GetDevice)(THIS_ IDirect3DDevice9** ppDevice) PURE;
        defaultInit(3);
        //STDMETHOD(SetPrivateData)(THIS_ REFGUID refguid,CONST void* pData,DWORD SizeOfData,DWORD Flags) PURE;
        defaultInit(4);
        //STDMETHOD(GetPrivateData)(THIS_ REFGUID refguid,void* pData,DWORD* pSizeOfData) PURE;
        defaultInit(5);
        //STDMETHOD(FreePrivateData)(THIS_ REFGUID refguid) PURE;
        defaultInit(6);
        //STDMETHOD_(DWORD, SetPriority)(THIS_ DWORD PriorityNew) PURE;
        defaultInit(7);
        //STDMETHOD_(DWORD, GetPriority)(THIS) PURE;
        defaultInit(8);
        //STDMETHOD_(void, PreLoad)(THIS) PURE;
        defaultInit(9);
        //STDMETHOD_(D3DRESOURCETYPE, GetType)(THIS) PURE;
        defaultInit(10);
        //STDMETHOD(Lock)(THIS_ UINT OffsetToLock,UINT SizeToLock,void** ppbData,DWORD Flags) PURE;
        //defaultInit(11);
        unsafeCast(vtbl[11], &MyD3DIndexBuffer9::Lock); index++;
        //STDMETHOD(Unlock)(THIS) PURE;
        defaultInit(12);
        //STDMETHOD(GetDesc)(THIS_ D3DINDEXBUFFER_DESC *pDesc) PURE;
        defaultInit(13);
#undef defaultInit
    }
};

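Because MyD3DIndexBuffer9 doesn't derive from IDirect3DIndexBuffer9, converting between the two pointer types (handing your object to the application as the real interface, or getting it back when the application passes the interface to another call) requires reinterpret_cast: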

        MyD3DIndexBuffer9* myIb = reinterpret_cast<MyD3DIndexBuffer9*>(pIndexData);
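When the application hands such a pointer back to another call (for example, passing the index buffer to IDirect3DDevice9::SetIndices), the wrapper has to be turned back into the original object, and a blind reinterpret_cast would break on objects you never wrapped. A minimal sketch of a guarded unwrap follows, using a C-style interface so it compiles anywhere; `IFoo`, `Wrapper` and `unwrap` are hypothetical names. Here every wrapper shares one static table, so comparing the table pointer is enough; with the per-object `new VtblMethod[vtblSize]` tables in the code above you would compare a slot you own, such as slot 0, instead.

```cpp
#include <cassert>

// Minimal C-style interface: first member points to the method table.
struct IFoo;
struct IFooVtbl { int (*GetValue)(IFoo* self); };
struct IFoo { const IFooVtbl* lpVtbl; };

// Stand-in for the original object.
struct RealFoo { IFoo iface; int value; };
int Real_GetValue(IFoo* p) { return reinterpret_cast<RealFoo*>(p)->value; }
const IFooVtbl g_realVtbl = { Real_GetValue };

// Our wrapper; the interface header must be its first member.
struct Wrapper { IFoo iface; IFoo* orig; };
int Wrapper_GetValue(IFoo* p) {
    IFoo* o = reinterpret_cast<Wrapper*>(p)->orig;
    assert(o != 0);
    return o->lpVtbl->GetValue(o) + 1;  // the overridden behaviour
}
const IFooVtbl g_wrapperVtbl = { Wrapper_GetValue };

// Returns the wrapped original if `p` is one of our wrappers, else `p` itself.
// The reinterpret_cast to Wrapper* happens only after the table check.
IFoo* unwrap(IFoo* p) {
    if (p != 0 && p->lpVtbl == &g_wrapperVtbl)
        return reinterpret_cast<Wrapper*>(p)->orig;
    return p;
}
```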

As you can see, the third approach requires assembly, macros and templates combined with casting class method pointers to plain function pointers. It is also compiler-dependent (MSVC, although you should be able to do the same trick with g++) and architecture-dependent (32/64-bit; MSVC doesn't support inline assembly for x64 at all). Plus it is unsafe (as with dispatch table hacking).
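For comparison, the same idea (a hand-built table in which untouched slots forward to the wrapped object by swapping `self` for `orig`) can be sketched without inline assembly if you accept two assumptions: a C++11 compiler (which VS2008 is not) and an interface described as a C-style struct of function pointers instead of MSVC's C++ vtable. `IFoo`, `IFooVtbl`, `Wrapper`, `Forward`, `FORWARD_SLOT` and the clamping in `Wrapper_SetValue` are hypothetical. `Forward<...>::call` plays the role of `vtblMapper<Index>`, but because it is generated from each slot's declared type, the compiler handles argument passing. For real x86 COM slots, which are `__stdcall`, the specialization would presumably have to spell out that calling convention too.

```cpp
#include <cassert>

// C-style interface: what a COM object looks like to its caller.
struct IFoo;
struct IFooVtbl {
    int  (*GetValue)(IFoo* self);
    int  (*Scale)(IFoo* self, int factor);
    void (*SetValue)(IFoo* self, int v);
};
struct IFoo { const IFooVtbl* lpVtbl; };

// Stand-in for the original object we cannot modify.
struct RealFoo { IFoo iface; int value; };
int  Real_GetValue(IFoo* p)        { return reinterpret_cast<RealFoo*>(p)->value; }
int  Real_Scale(IFoo* p, int f)    { return reinterpret_cast<RealFoo*>(p)->value * f; }
void Real_SetValue(IFoo* p, int v) { reinterpret_cast<RealFoo*>(p)->value = v; }
const IFooVtbl g_realVtbl = { Real_GetValue, Real_Scale, Real_SetValue };

// The wrapper: a plain struct whose first member is the interface header.
struct Wrapper {
    IFoo  iface;     // must stay first; callers only ever see &iface
    IFoo* orig;      // wrapped object
    int   setCalls;  // extra per-object data
};

// Generic forwarder for one slot: swaps `self` for `orig` and calls the
// original entry, like vtblMapper<Index> but driven by the slot's type.
template <typename F, F IFooVtbl::*Slot> struct Forward;

template <typename R, typename... A, R (*IFooVtbl::*Slot)(IFoo*, A...)>
struct Forward<R (*)(IFoo*, A...), Slot> {
    static R call(IFoo* self, A... args) {
        IFoo* orig = reinterpret_cast<Wrapper*>(self)->orig;
        return (orig->lpVtbl->*Slot)(orig, args...);
    }
};
#define FORWARD_SLOT(name) \
    &Forward<decltype(IFooVtbl::name), &IFooVtbl::name>::call

// The single overridden method (hypothetical fix: clamp negative values).
void Wrapper_SetValue(IFoo* self, int v) {
    Wrapper* w = reinterpret_cast<Wrapper*>(self);
    assert(w->orig != 0);
    ++w->setCalls;
    w->orig->lpVtbl->SetValue(w->orig, v < 0 ? 0 : v);
}

const IFooVtbl g_wrapperVtbl = {
    FORWARD_SLOT(GetValue), FORWARD_SLOT(Scale), Wrapper_SetValue
};
```

Filling a table is then one `FORWARD_SLOT(name)` per untouched method, and a signature mismatch between a slot and its replacement is a compile error rather than a crash.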

The advantage compared to dispatch-table hacking is that you can use a custom class and store additional data within the interface object. However:

  1. All virtual methods are forbidden, including a virtual destructor. As far as I know, declaring any virtual method makes the compiler insert a hidden vtable pointer (4 bytes on x86) at the beginning of the object, which breaks everything.
  2. The calling convention must be __stdcall (it should also work with __cdecl, but any other convention needs a different wrapper).
  3. You have to initialize the entire table yourself, which is very error-prone: one mistake, and everything crashes.
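Limitation 1, and the hard-coded `[eax + pointerSize]` offset in `vtblMapper`, both rest on the class layout, so it may help to let the compiler verify it. A minimal sketch assuming a C++11 compiler (VS2008 has no `static_assert` or `<type_traits>`, but the array-size trick from `unsafeCast` can still express the offset check, as the last declaration shows). `MyIndexBufferLayout` is a hypothetical mirror of the leading members of `MyD3DIndexBuffer9`; checks on the real class would have to be placed inside it (for example in `initVtbl`), because its data members are protected.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>

typedef void (*VtblMethod)();
struct IDirect3DIndexBuffer9;  // opaque here; only a pointer is stored

// Hypothetical mirror of the leading members of MyD3DIndexBuffer9.
struct MyIndexBufferLayout {
    VtblMethod* vtbl;             // callers read this as the vtable pointer
    IDirect3DIndexBuffer9* orig;  // vtblMapper loads it from [this + pointerSize]
    volatile long refCount;
    unsigned long flags;
};

// Any virtual function (including a virtual destructor) adds a hidden
// vtable pointer and shifts every member.
static_assert(!std::is_polymorphic<MyIndexBufferLayout>::value,
              "no virtual methods allowed");
// offsetof is only guaranteed to work on standard-layout types.
static_assert(std::is_standard_layout<MyIndexBufferLayout>::value,
              "layout must be standard");
static_assert(offsetof(MyIndexBufferLayout, vtbl) == 0,
              "table pointer must come first");
static_assert(offsetof(MyIndexBufferLayout, orig) == sizeof(void*),
              "vtblMapper expects orig right after the table pointer");

// C++03 form of the last check (array size -1 is a compile error):
typedef char origOffsetCheck[
    offsetof(MyIndexBufferLayout, orig) == sizeof(void*) ? 1 : -1];
```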
Answered Mar 16 '26 08:03 by SigTerm