c++ windows dll shared-libraries

c++ dll loading: Call function before global variable constructors


I'm working on a plugin-based system that uses shared libraries for the plugins. I want to set up some static variables in these libraries so I can use them in various places, including global constructors. Can the main executable call a function in a shared library so that it runs before that library's global constructors? Whether I use shared memory or register the static variables by passing them to the shared library, this setup code would otherwise run after the DLL's global constructors.

For further context, some use cases are loggers, allocators, and console variables (cvars). These systems are used pervasively across the program, and I want to be able to use their static variables from these global constructors. AutoCVars in particular are a common case: a cvar is declared as a global in whichever file uses it, so it can register its data immediately and live for the lifetime of the application: https://vkguide.dev/docs/extra-chapter/cvar_system/
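A minimal, hypothetical sketch of the self-registering global pattern described above (AutoCVarInt and CVarRegistry are illustrative names, loosely modeled on the idea and not the vkguide API). Within a single module it works because the registry is a function-local static, created on first use by whichever global constructor reaches it first. On Windows, however, each DLL that compiles this code gets its own copy of that static, which is the crux of the question.

```cpp
#include <map>
#include <string>

// Registry created on first use, so it already exists when any global
// AutoCVarInt constructor runs, regardless of cross-file initialization order.
// NOTE: each module (EXE or DLL) that compiles this gets its own registry.
inline std::map<std::string, int>& CVarRegistry() {
    static std::map<std::string, int> registry;
    return registry;
}

// Hypothetical integer cvar that registers itself when constructed.
struct AutoCVarInt {
    std::string name;

    AutoCVarInt(const char* cvarName, int defaultValue) : name(cvarName) {
        CVarRegistry().emplace(name, defaultValue);
    }

    int Get() const { return CVarRegistry().at(name); }
    void Set(int value) { CVarRegistry()[name] = value; }
};

// Declared at namespace scope in whichever file needs it.
AutoCVarInt maxFrames("render.maxFrames", 3);
```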

Here's a code example:

Allocator.hpp:

#pragma once

#include <cstddef> // size_t

class Allocator {
public:
    static Allocator* GetAllocator();
    static void SetAllocator(Allocator* allocator);

    void* Allocate(size_t size);
    size_t GetUsedSize() const;

protected:
    static Allocator allocatorSingleton;
    size_t usedSize = 0;
};

Allocator.cpp:

#include "Allocator.hpp"

static Allocator* allocatorState = nullptr;

Allocator* Allocator::GetAllocator() {
    return allocatorState;
}

void Allocator::SetAllocator(Allocator* allocator) {
    allocatorState = allocator;
}

void* Allocator::Allocate(size_t size) {
    // ALLOCATION OCCURS HERE
    usedSize += size;

    return nullptr;
}

size_t Allocator::GetUsedSize() const {
    return usedSize;
}

Main.cpp:

#include <cstdio>   // printf
#include <cstdlib>  // EXIT_FAILURE
#include <iostream>
#include <Windows.h>
#include "Allocator.hpp"

int main() {
    Allocator allocator;
    Allocator::SetAllocator(&allocator);

    HINSTANCE dllHandle = LoadLibraryW(L"Library.dll");
    if (!dllHandle) {
        printf("Could not load the dynamic library.\n");
        return EXIT_FAILURE;
    }

    printf("Allocated Size Before: %zu.\n", allocator.GetUsedSize());
    auto fn = reinterpret_cast<void(*)(Allocator*)>(GetProcAddress(dllHandle, "RegisterAllocator"));
    if (!fn) {
        printf("Could not locate the function.\n");
        return EXIT_FAILURE;
    }

    fn(&allocator);
    printf("Allocated Size After: %zu.\n", allocator.GetUsedSize());

    return 0;
}

DllMain.cpp:

#include "../CVarTest/Allocator.hpp"

class SampleClass {
public:
    SampleClass() {
        Allocator::GetAllocator()->Allocate(42);
    }
};

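// Here, SampleClass doesn't work: its constructor runs during LoadLibrary, before RegisterAllocator
// is called, so GetAllocator() returns nullptr and Allocate() is called through a null pointer.
// This is what I want to get working.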
SampleClass myInstance;

extern "C" {
    void __declspec(dllexport) RegisterAllocator(Allocator* sourceAllocator) {
        Allocator::SetAllocator(sourceAllocator);

        // Here, SampleClass does work, and the final usedSize is 42.
        // SampleClass myInstance;
    }
}

Solution

  • Both answers from @RbMm and @AhmedAEK seemed to work well. Thanks to everyone for their suggestions. I believe I'll use Method 2, as it seems cleaner. Here are code examples of each:

    Method 1: LdrRegisterDllNotification from @RbMm

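    This method registers a DLL load notification callback with ntdll.dll (LdrRegisterDllNotification) and uses it to call a function in each newly loaded library. According to Microsoft's documentation, the load notification is delivered before dynamic linking takes place, which is why it runs ahead of the DLL's global constructors. That also means the library is not fully initialized yet, so the exported function called from the callback should be kept trivial (here it only stores a pointer). Microsoft also notes that this API may change or be removed in future versions of Windows, so this approach needs more thorough investigation before relying on it.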

    #include <cstdio>   // printf
    #include <cstdlib>  // EXIT_FAILURE
    #include <iostream>
    #include <Windows.h>
    #include "Allocator.hpp"
    
    typedef struct _UNICODE_STR
    {
        USHORT Length;
        USHORT MaximumLength;
        PWSTR pBuffer;
    } UNICODE_STR, * PUNICODE_STR;
    
    // Sources:
    // https://shorsec.io/blog/dll-notification-injection/
    // https://modexp.wordpress.com/2020/08/06/windows-data-structures-and-callbacks-part-1/
    
    typedef struct _LDR_DLL_LOADED_NOTIFICATION_DATA {
        ULONG           Flags;             // Reserved.
        PUNICODE_STR FullDllName;       // The full path name of the DLL module.
        PUNICODE_STR BaseDllName;       // The base file name of the DLL module.
        PVOID           DllBase;           // A pointer to the base address for the DLL in memory.
        ULONG           SizeOfImage;       // The size of the DLL image, in bytes.
    } LDR_DLL_LOADED_NOTIFICATION_DATA, * PLDR_DLL_LOADED_NOTIFICATION_DATA;
    
    typedef struct _LDR_DLL_UNLOADED_NOTIFICATION_DATA {
        ULONG           Flags;             // Reserved.
        PUNICODE_STR FullDllName;       // The full path name of the DLL module.
        PUNICODE_STR BaseDllName;       // The base file name of the DLL module.
        PVOID           DllBase;           // A pointer to the base address for the DLL in memory.
        ULONG           SizeOfImage;       // The size of the DLL image, in bytes.
    } LDR_DLL_UNLOADED_NOTIFICATION_DATA, * PLDR_DLL_UNLOADED_NOTIFICATION_DATA;
    
    typedef union _LDR_DLL_NOTIFICATION_DATA {
        LDR_DLL_LOADED_NOTIFICATION_DATA   Loaded;
        LDR_DLL_UNLOADED_NOTIFICATION_DATA Unloaded;
    } LDR_DLL_NOTIFICATION_DATA, * PLDR_DLL_NOTIFICATION_DATA;
    
    typedef VOID(CALLBACK* PLDR_DLL_NOTIFICATION_FUNCTION)(
        ULONG                       NotificationReason,
        PLDR_DLL_NOTIFICATION_DATA  NotificationData,
        PVOID                       Context);
    
    typedef NTSTATUS(NTAPI* _LdrRegisterDllNotification) (
        ULONG                          Flags,
        PLDR_DLL_NOTIFICATION_FUNCTION NotificationFunction,
        PVOID                          Context,
        PVOID* Cookie);
    
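    // Declared CALLBACK so it matches PLDR_DLL_NOTIFICATION_FUNCTION exactly and needs no cast
    // (a mismatched calling convention would corrupt the stack in 32-bit builds).
    VOID CALLBACK MyCallback(ULONG NotificationReason, PLDR_DLL_NOTIFICATION_DATA NotificationData, PVOID Context) {
        const ULONG LDR_DLL_NOTIFICATION_REASON_LOADED = 1;
        if (NotificationReason != LDR_DLL_NOTIFICATION_REASON_LOADED) {
            return; // ignore unload notifications
        }
    
        // BaseDllName is a counted string: Length is in bytes and the buffer
        // is not guaranteed to be null-terminated.
        PUNICODE_STR name = NotificationData->Loaded.BaseDllName;
        if (CompareStringOrdinal(name->pBuffer, name->Length / sizeof(WCHAR), L"Library.dll", -1, TRUE) != CSTR_EQUAL) {
            return;
        }
    
        HMODULE dllHandle = reinterpret_cast<HMODULE>(NotificationData->Loaded.DllBase);
        auto fn = reinterpret_cast<void(*)(Allocator*)>(GetProcAddress(dllHandle, "RegisterAllocator"));
        if (!fn) {
            printf("Could not locate the function.\n");
            return;
        }
    
        fn(Allocator::GetAllocator());
    }
    
    int main() {
        Allocator allocator;
        Allocator::SetAllocator(&allocator);
    
        HMODULE hNtdll = GetModuleHandleA("NTDLL.dll");
        if (hNtdll == NULL) {
            printf("Failed to load NTDLL.dll! Exiting\n");
            return EXIT_FAILURE;
        }
    
        auto pLdrRegisterDllNotification = reinterpret_cast<_LdrRegisterDllNotification>(
            GetProcAddress(hNtdll, "LdrRegisterDllNotification"));
        if (pLdrRegisterDllNotification == NULL) {
            printf("Could not locate LdrRegisterDllNotification! Exiting\n");
            return EXIT_FAILURE;
        }
    
        PVOID cookie = NULL;
        NTSTATUS status = pLdrRegisterDllNotification(0, MyCallback, NULL, &cookie);
        if (status != 0) {
            printf("Failed to load DLL Callback! Exiting\n");
            return EXIT_FAILURE;
        }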
    
        printf("Allocated Size Before: %zu.\n", allocator.GetUsedSize());
        HINSTANCE dllHandle = LoadLibraryW(L"Library.dll");
        if (!dllHandle) {
            printf("Could not load the dynamic library.\n");
            return EXIT_FAILURE;
        }
    
        printf("Allocated Size After: %zu.\n", allocator.GetUsedSize());
    
        return 0;
    }
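The BaseDllName field in the notification data is a counted string: following the Windows UNICODE_STRING convention, Length is a byte count and the buffer is not guaranteed to be null-terminated, so a name check in the callback should honor Length rather than rely on a terminator. The portable sketch below shows such a case-insensitive comparison. CountedWideString and CountedEqualsIgnoreCase are hypothetical names, and the struct only mirrors the UNICODE_STR layout declared above. It uses towlower from the current C locale, which is adequate for ASCII names like Library.dll; on Windows, CompareStringOrdinal with bIgnoreCase set to TRUE serves the same purpose.

```cpp
#include <cstddef>
#include <cwctype>

// Hypothetical mirror of the UNICODE_STR layout: lengths are in BYTES.
struct CountedWideString {
    unsigned short Length;        // bytes used by the string (no terminator counted)
    unsigned short MaximumLength; // total buffer size in bytes
    wchar_t* pBuffer;             // not guaranteed to be null-terminated
};

// Case-insensitive equality that reads exactly Length bytes from s and
// requires name (a normal null-terminated string) to have the same length.
bool CountedEqualsIgnoreCase(const CountedWideString& s, const wchar_t* name) {
    const std::size_t count = s.Length / sizeof(wchar_t);
    std::size_t i = 0;
    for (; i < count; ++i) {
        if (name[i] == L'\0') {
            return false; // name is shorter than s
        }
        if (std::towlower(s.pBuffer[i]) != std::towlower(name[i])) {
            return false;
        }
    }
    return name[i] == L'\0'; // name must not be longer than s
}
```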
    

    Method 2: Looking up the variable through the singleton's getter, from @AhmedAEK

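    This change to Allocator.cpp makes the singleton's getter fall back to the main executable when the DLL's own copy of the pointer is not set yet, for example while the DLL's global constructors are running. MAIN_EXECUTABLE is a preprocessor macro defined only when compiling the executable, which exports a non-member GetAllocator function with __declspec(dllexport). In the DLL build, GetModuleHandle(NULL) returns the handle of the main executable, and GetProcAddress finds that exported function, which returns the singleton's value.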

    // ...
    
    #if MAIN_EXECUTABLE
    extern "C" {
        __declspec(dllexport) Allocator* GetAllocator() {
            return Allocator::GetAllocator();
        }
    }
    
    Allocator* Allocator::GetAllocator() {
        return allocatorState;
    }
    #else
    #include <Windows.h>
    
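    Allocator* Allocator::GetAllocator() {
        if (allocatorState == nullptr) {
            // This module's copy is not set yet (e.g. during the DLL's global
            // constructors), so ask the main executable for its singleton.
            HMODULE module = GetModuleHandle(NULL);
            auto GetAllocatorFn = reinterpret_cast<Allocator*(*)()>(GetProcAddress(module, "GetAllocator"));
            if (GetAllocatorFn != NULL) {
                // Cache the result so the lookup only happens once.
                allocatorState = GetAllocatorFn();
            }
        }
    
        return allocatorState;
    }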
    #endif
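To show the control flow of Method 2 without Windows, here is a portable sketch that simulates both modules in one translation unit. It is only an illustration under stated assumptions: ResolveFromHost is a hypothetical stand-in for GetProcAddress(GetModuleHandle(NULL), "GetAllocator"), Allocator is a simplified version of the class above, and the names hostAllocator, HostGetAllocator, and PluginGetAllocator are made up for the example. The key point is that the fallback lookup needs nothing that a global constructor could run ahead of, so SampleClass can allocate before any "register" function is called.

```cpp
#include <cstddef>

// Simplified allocator (hypothetical; mirrors only what the sketch needs).
struct Allocator {
    std::size_t usedSize = 0;
    void* Allocate(std::size_t size) { usedSize += size; return nullptr; }
};

// ---- Simulated "executable" side ----
Allocator hostAllocator;
Allocator* HostGetAllocator() { return &hostAllocator; }

// ---- Simulated "plugin" side ----
// Hypothetical stand-in for the exported-symbol lookup. It is constant-initialized,
// so it is already valid when dynamic initializers (global constructors) run.
Allocator* (*const ResolveFromHost)() = &HostGetAllocator;

Allocator* pluginAllocatorState = nullptr; // the plugin's own copy, never set explicitly

Allocator* PluginGetAllocator() {
    if (pluginAllocatorState == nullptr) {
        // Fall back to the host once, then cache the result.
        pluginAllocatorState = ResolveFromHost();
    }
    return pluginAllocatorState;
}

// A plugin global whose constructor needs the allocator immediately.
struct SampleClass {
    SampleClass() { PluginGetAllocator()->Allocate(42); }
};
SampleClass myInstance;
```

Caching the looked-up pointer means later calls skip the lookup entirely; the trade-off is that the plugin will not notice if the host later swaps its allocator via SetAllocator.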