//////////////////////////////////////////////////////////////////////////////
//
//  Copyright (c) 1999-2001 Microsoft Corporation
//
//  Module Name:
//      DllSrc.cpp
//
//  Description:
//      DLL services/entry points.
//
//  Maintained By:
//      David Potter    (DavidP)    19-MAR-2001
//      Geoffrey Pease  (GPease)    18-OCT-1999
//
//  Notes:
//      Switches:
//      - ENTRY_PREFIX
//          If defined, include proxy/stub code into the DLL that is
//          generated by the MIDL compiler.
//      - USE_FUSION
//          If defined, initialize and uninitialize Fusion on process
//          attach and detach respectively.  The constant IDR_MANIFEST
//          must be defined with a value that represents the resource ID
//          for the manifest resource.
//      - NO_DLL_MAIN
//          If defined, don't implement DllMain.
//      - DEBUG_SW_TRACING_ENABLED
//          If defined, initialize and uninitialize software tracing on
//          process attach and detach respectively.
//      - NO_THREAD_OPTIMIZATIONS
//          If defined, disable the thread notification calls when a thread
//          starts up or goes away.
//
//////////////////////////////////////////////////////////////////////////////

//
// DLL Globals
//
// Module instance handle; assigned in DllMain() on DLL_PROCESS_ATTACH.
HINSTANCE g_hInstance = NULL;
// Object and lock counters.  DllCanUnloadNow() reports S_FALSE while either
// one is non-zero.  They are not modified anywhere in this file; presumably
// the object and class factory implementations maintain them (confirm).
LONG      g_cObjects  = 0;
LONG      g_cLock     = 0;
// Path of this module; filled in by GetModuleFileName() in DllMain() on
// DLL_PROCESS_ATTACH.
TCHAR     g_szDllFilename[ MAX_PATH ] = { 0 };

// Handed to TraceCreateMemoryList() on process attach and to
// TraceTerminateMemoryList() on process detach.
LPVOID    g_GlobalMemoryList = NULL;    // Global memory tracking list

#if defined( ENTRY_PREFIX )
//
//  Instance handle consumed by the proxy/stub code that is linked into this
//  DLL when ENTRY_PREFIX is defined.  DllMain() sets it to g_hInstance on
//  DLL_PROCESS_ATTACH.  The variable itself is defined elsewhere (presumably
//  in the MIDL generated proxy/stub support code - confirm).
//
extern "C"
{
    extern HINSTANCE hProxyDll;
}
#endif

//
//  Macros to generate RPC entry points
//
//
//  __rpc_macro_expand( a, b ) pastes its two arguments into one identifier.
//  The extra level of indirection through __rpc_macro_expand2 makes the
//  preprocessor expand the arguments first, so that passing ENTRY_PREFIX
//  pastes the *value* of ENTRY_PREFIX (not the literal token "ENTRY_PREFIX")
//  in front of names such as DllGetClassObject.
//
#define __rpc_macro_expand2( a, b ) a##b
#define __rpc_macro_expand( a, b ) __rpc_macro_expand2( a, b )

#if ! defined( NO_DLL_MAIN ) || defined( ENTRY_PREFIX ) || defined( DEBUG )
//////////////////////////////////////////////////////////////////////////////
//++
//
//  DllMain
//
//  Description:
//      Dll entry point.
//
//      DLL_PROCESS_ATTACH:
//          Initializes tracing, records the instance handle (also into
//          hProxyDll when ENTRY_PREFIX is defined), captures the module
//          file name into g_szDllFilename, creates the global memory
//          tracking list, disables thread attach/detach notifications when
//          THREAD_OPTIMIZATIONS is in effect (non-DEBUG builds) and
//          initializes Fusion when USE_FUSION is defined.
//      DLL_PROCESS_DETACH:
//          Terminates the global memory tracking list, shuts down tracing
//          and uninitializes Fusion when USE_FUSION is defined.
//      DLL_THREAD_ATTACH / DLL_THREAD_DETACH:
//          Only handled when THREAD_OPTIMIZATIONS is not in effect (DEBUG
//          builds); per-thread tracing is initialized and run down.
//
//  Arguments:
//      hInstIn      - DLL instance handle.
//      ulReasonIn   - DLL reason code for entrance.
//      lpReservedIn - Not used.
//
//  Return Values:
//      TRUE    - Always.  Failures during initialization are traced only.
//
//--
//////////////////////////////////////////////////////////////////////////////
BOOL
WINAPI
DllMain(
    HINSTANCE   hInstIn,
    ULONG       ulReasonIn,
    LPVOID      // lpReservedIn
    )
{
    //
    // KB: THREAD_OPTIMIZATIONS gpease 19-OCT-1999
    //
    // By defining this you can prevent the linker
    // from calling your DllEntry for every new thread.
    // This makes creating new threads significantly
    // faster if every DLL in a process does it.
    // Unfortunately, not all DLLs do this.
    //
    // In CHKed/DEBUG, we keep this on for memory
    // tracking.
    //
#if ! defined( DEBUG )
    #define THREAD_OPTIMIZATIONS
#endif // DEBUG

    switch( ulReasonIn )
    {
        //////////////////////////////////////////////////////////////////////
        // DLL_PROCESS_ATTACH
        //////////////////////////////////////////////////////////////////////
        case DLL_PROCESS_ATTACH:
        {
#if defined( DEBUG_SW_TRACING_ENABLED )
            TraceInitializeProcess( g_rgTraceControlGuidList, ARRAYSIZE( g_rgTraceControlGuidList ), TRUE );
#else // ! DEBUG_SW_TRACING_ENABLED
            TraceInitializeProcess( TRUE );
#endif // DEBUG_SW_TRACING_ENABLED

            TraceFunc( "" );
            TraceMessage( TEXT(__FILE__),
                          __LINE__,
                          __MODULE__,
                          mtfDLL,
                          TEXT("DLL: DLL_PROCESS_ATTACH - ThreadID = %#x"),
                          GetCurrentThreadId()
                          );

            g_hInstance = hInstIn;

#if defined( ENTRY_PREFIX )
            hProxyDll = g_hInstance;
#endif // ENTRY_PREFIX

            //
            // Remember the full path of this module.  A failure is traced
            // instead of being silently ignored, and the buffer is always
            // terminated because older systems (Windows XP) do not
            // NUL-terminate the string when it has to be truncated to fit.
            //
            if ( GetModuleFileName( g_hInstance, g_szDllFilename, ARRAYSIZE( g_szDllFilename ) ) == 0 )
            {
                TW32MSG( GetLastError(), "GetModuleFileName()" );
            }
            g_szDllFilename[ ARRAYSIZE( g_szDllFilename ) - 1 ] = TEXT('\0');

            //
            // Create a global memory list so that memory allocated by one
            // thread and handed to another can be tracked without causing
            // unnecessary trace messages.
            //
            TraceCreateMemoryList( g_GlobalMemoryList );

#if defined( THREAD_OPTIMIZATIONS )
            {
                //
                // Disable thread library calls so that we don't get called
                // on thread attach and detach.
                //
                BOOL fResult = DisableThreadLibraryCalls( g_hInstance );
                if ( ! fResult )
                {
                    TW32MSG( GetLastError(), "DisableThreadLibraryCalls()" );
                }
            }
#endif // THREAD_OPTIMIZATIONS

#if defined( USE_FUSION )
            //
            // Initialize Fusion.
            //
            // The value of IDR_MANIFEST in the call to
            // SHFusionInitializeFromModuleID() must match the value specified in the
            // sources file for SXS_MANIFEST_RESOURCE_ID.
            //
            BOOL fResult = SHFusionInitializeFromModuleID( hInstIn, IDR_MANIFEST );
            if ( ! fResult )
            {
                TW32MSG( GetLastError(), "SHFusionInitializeFromModuleID()" );
            }
#endif // USE_FUSION

            //
            // This is necessary here because TraceFunc() defines a variable
            // on the stack which isn't available outside the scope of this
            // block.
            // This function doesn't do anything but clean up after
            // TraceFunc().
            //
            FRETURN( TRUE );

            break;
        } // case: DLL_PROCESS_ATTACH

        //////////////////////////////////////////////////////////////////////
        // DLL_PROCESS_DETACH
        //////////////////////////////////////////////////////////////////////
        case DLL_PROCESS_DETACH:
        {
            TraceFunc( "" );
            TraceMessage( TEXT(__FILE__),
                          __LINE__,
                          __MODULE__,
                          mtfDLL,
                          TEXT("DLL: DLL_PROCESS_DETACH - ThreadID = %#x [ g_cLock=%u, g_cObjects=%u ]"),
                          GetCurrentThreadId(),
                          g_cLock,
                          g_cObjects
                          );

            //
            // Cleanup the global memory list used to track memory allocated
            // in one thread and then handed to another.
            //
            TraceTerminateMemoryList( g_GlobalMemoryList );

            //
            // This is necessary here because TraceFunc() defines a variable
            // on the stack which isn't available outside the scope of this
            // block.
            // This function doesn't do anything but clean up after
            // TraceFunc().
            //
            // It is deliberately issued before tracing itself is shut down
            // below.
            //
            FRETURN( TRUE );

#if defined( DEBUG_SW_TRACING_ENABLED )
            TraceTerminateProcess( g_rgTraceControlGuidList, ARRAYSIZE( g_rgTraceControlGuidList )
                                   );
#else // ! DEBUG_SW_TRACING_ENABLED
            TraceTerminateProcess();
#endif // DEBUG_SW_TRACING_ENABLED

#if defined( USE_FUSION )
            SHFusionUninitialize();
#endif // USE_FUSION

            break;
        } // case: DLL_PROCESS_DETACH

#if ! defined( THREAD_OPTIMIZATIONS )
        //////////////////////////////////////////////////////////////////////
        // DLL_THREAD_ATTACH
        //////////////////////////////////////////////////////////////////////
        case DLL_THREAD_ATTACH:
        {
            TraceInitializeThread( NULL );
            TraceMessage( TEXT(__FILE__),
                          __LINE__,
                          __MODULE__,
                          mtfDLL,
                          TEXT("Thread %#x has started."),
                          GetCurrentThreadId()
                          );
            TraceFunc( "" );
            TraceMessage( TEXT(__FILE__),
                          __LINE__,
                          __MODULE__,
                          mtfDLL,
                          TEXT("DLL: DLL_THREAD_ATTACH - ThreadID = %#x [ g_cLock=%u, g_cObjects=%u ]"),
                          GetCurrentThreadId(),
                          g_cLock,
                          g_cObjects
                          );

            //
            // This is necessary here because TraceFunc() defines a variable
            // on the stack which isn't available outside the scope of this
            // block.
            // This function doesn't do anything but clean up after
            // TraceFunc().
            //
            FRETURN( TRUE );

            break;
        } // case: DLL_THREAD_ATTACH

        //////////////////////////////////////////////////////////////////////
        // DLL_THREAD_DETACH
        //////////////////////////////////////////////////////////////////////
        case DLL_THREAD_DETACH:
        {
            TraceFunc( "" );
            TraceMessage( TEXT(__FILE__),
                          __LINE__,
                          __MODULE__,
                          mtfDLL,
                          TEXT("DLL: DLL_THREAD_DETACH - ThreadID = %#x [ g_cLock=%u, g_cObjects=%u ]"),
                          GetCurrentThreadId(),
                          g_cLock,
                          g_cObjects
                          );

            //
            // This is necessary here because TraceFunc() defines a variable
            // on the stack which isn't available outside the scope of this
            // block.
            // This function doesn't do anything but clean up after
            // TraceFunc().
            //
            // It is deliberately issued before the per-thread tracing state
            // is run down below.
            //
            FRETURN( TRUE );

            TraceThreadRundown();

            break;
        } // case: DLL_THREAD_DETACH
#endif // ! THREAD_OPTIMIZATIONS

        default:
        {
            TraceFunc( "" );
            TraceMessage( TEXT(__FILE__),
                          __LINE__,
                          __MODULE__,
                          mtfDLL,
                          TEXT("DLL: UNKNOWN ENTRANCE REASON - ThreadID = %#x [ g_cLock=%u, g_cObjects=%u ]"),
                          GetCurrentThreadId(),
                          g_cLock,
                          g_cObjects
                          );

#if defined( THREAD_OPTIMIZATIONS )
            //
            // Thread notifications were disabled on process attach, so they
            // are not expected to show up here.
            //
            Assert( ( ulReasonIn != DLL_THREAD_ATTACH )
                &&  ( ulReasonIn != DLL_THREAD_DETACH ) );
#endif // THREAD_OPTIMIZATIONS

            //
            // This is necessary here because TraceFunc defines a variable
            // on the stack which isn't available outside the scope of this
            // block.
            // This function doesn't do anything but clean up after TraceFunc.
            //
            FRETURN( TRUE );

            break;
        } // default case
    } // switch on reason code

    return TRUE;

} //*** DllMain()
#endif // ! defined( NO_DLL_MAIN ) || defined( ENTRY_PREFIX ) || defined( DEBUG )

//////////////////////////////////////////////////////////////////////////////
//++
//
//  DllGetClassObject
//
//  Description:
//      OLE calls this to get the class factory from the DLL.
//
//      The class ID is looked up in g_DllClasses.  When it is found a new
//      CFactory is created, initialized with the table entry's create
//      function and queried for the requested interface.  When it is not
//      found and ENTRY_PREFIX is defined, the request is forwarded to the
//      MIDL generated DllGetClassObject.
//
//  Arguments:
//      rclsidIn
//          - Class ID of the object that the class factory should create.
//      riidIn
//          - Interface of the class factory
//      ppvOut
//          - The interface pointer to the class factory.  Set to NULL
//            before the lookup so that it never holds a stale value when a
//            failure is detected in this function.
//
//  Return Values:
//      S_OK            - Operation completed successfully.
//      E_POINTER       - Required output parameter was specified as NULL.
//      CLASS_E_CLASSNOTAVAILABLE
//                      - Class ID not supported by this DLL.
//      E_OUTOFMEMORY   - Error allocating memory.
//      Other HRESULTs to indicate failure.
//
//--
//////////////////////////////////////////////////////////////////////////////
STDAPI
DllGetClassObject(
    REFCLSID    rclsidIn,
    REFIID      riidIn,
    void **     ppvOut
    )
{
    TraceFunc( "rclsidIn, riidIn, ppvOut" );

    LPCFACTORY  lpClassFactory;
    HRESULT     hr;
    int         idx;

    if ( ppvOut == NULL )
    {
        hr = E_POINTER;
        goto Cleanup;
    } // if: bad argument

    //
    // COM expects the output pointer to be NULL when an error is returned.
    // Clear it up front so every failure path below honors that.
    //
    *ppvOut = NULL;

    //
    // Walk the class table; it is terminated by an entry whose rclsid is
    // NULL.
    //
    hr = CLASS_E_CLASSNOTAVAILABLE;
    idx = 0;
    while( g_DllClasses[ idx ].rclsid )
    {
        if ( *g_DllClasses[ idx ].rclsid == rclsidIn )
        {
            TraceMessage( TEXT(__FILE__), __LINE__, __MODULE__, mtfFUNC, L"rclsidIn = %s", g_DllClasses[ idx ].pszName );
            hr = S_OK;
            break;

        } // if: class found

        idx++;

    } // while: finding class

    // Didn't find the class ID.
    if ( hr == CLASS_E_CLASSNOTAVAILABLE )
    {
        TraceMsgGUID( mtfFUNC, "rclsidIn = ", rclsidIn );
#if defined( ENTRY_PREFIX )
        //
        //  See if the MIDL generated code can create it.
        //
        hr = STHR( __rpc_macro_expand( ENTRY_PREFIX, DllGetClassObject )( rclsidIn, riidIn, ppvOut ) );
#endif // defined( ENTRY_PREFIX )
        goto Cleanup;
    } // if: class not found

    Assert( g_DllClasses[ idx ].pfnCreateInstance != NULL );

    lpClassFactory = new CFactory;
    if ( lpClassFactory == NULL )
    {
        hr = E_OUTOFMEMORY;
        goto Cleanup;

    } // if: memory failure

    hr = THR( lpClassFactory->HrInit( g_DllClasses[ idx ].pfnCreateInstance ) );
    if ( FAILED( hr ) )
    {
        TraceDo( lpClassFactory->Release() );
        goto Cleanup;

    } // if: initialization failed

    // Can't safe type.
    hr = lpClassFactory->QueryInterface( riidIn, ppvOut );
    //
    // Release the created instance to counter the AddRef() in Init().
    // This is done whether or not the QueryInterface() succeeded.
    //
    ((IUnknown *) lpClassFactory )->Release();

Cleanup:
    HRETURN( hr );

} //*** DllGetClassObject()


//////////////////////////////////////////////////////////////////////////////
//++
//
//  DllRegisterServer
//
//  Description:
//      OLE's register entry point.  Calls HrRegisterDll( TRUE ) and, when
//      ENTRY_PREFIX is defined and that call succeeded, chains to the MIDL
//      generated DllRegisterServer for the proxy/stub code.
//
//  Arguments:
//      None.
//
//  Return Values:
//      S_OK    - Operation completed successfully.
//      Other HRESULTs to indicate failure.
//
//--
//////////////////////////////////////////////////////////////////////////////
STDAPI
DllRegisterServer( void )
{
    TraceFunc( "" );

    // This DLL's own registration comes first.
    HRESULT hr = THR( HrRegisterDll( TRUE ) );

#if defined( ENTRY_PREFIX )
    // The proxy/stub registration is only attempted after a success above.
    if ( ! FAILED( hr ) )
    {
        hr = THR( __rpc_macro_expand( ENTRY_PREFIX, DllRegisterServer )() );
    } // if: own registration worked
#endif // defined( ENTRY_PREFIX )

    HRETURN( hr );

} //*** DllRegisterServer()

//////////////////////////////////////////////////////////////////////////////
//++
//
//  DllUnregisterServer
//
//  Description:
//      OLE's unregister entry point.  Calls HrRegisterDll( FALSE ) and,
//      when ENTRY_PREFIX is defined and that call succeeded, chains to the
//      MIDL generated DllUnregisterServer for the proxy/stub code.
//
//  Arguments:
//      None.
//
//  Return Values:
//      S_OK    - Operation completed successfully.
//      Other HRESULTs to indicate failure.
//
//--
//////////////////////////////////////////////////////////////////////////////
STDAPI
DllUnregisterServer( void )
{
    TraceFunc( "" );

    // This DLL's own unregistration comes first.
    HRESULT hr = THR( HrRegisterDll( FALSE ) );

#if defined( ENTRY_PREFIX )
    // The proxy/stub unregistration is skipped when the step above failed.
    if ( ! FAILED( hr ) )
    {
        hr = THR( __rpc_macro_expand( ENTRY_PREFIX, DllUnregisterServer )() );
    } // if: own unregistration worked
#endif // defined( ENTRY_PREFIX )

    HRETURN( hr );

} //*** DllUnregisterServer()

//////////////////////////////////////////////////////////////////////////////
//++
//
//  DllCanUnloadNow
//
//  Description:
//      OLE calls this entry point to see if it can unload the DLL.  The
//      DLL has to stay loaded while g_cLock or g_cObjects is non-zero.
//      When both are zero and ENTRY_PREFIX is defined, the answer comes
//      from the MIDL generated DllCanUnloadNow.
//
//  Arguments:
//      None.
//
//  Return Values:
//      S_OK    - Can unload the DLL.
//      S_FALSE - Can not unload the DLL.
//      With ENTRY_PREFIX defined, whatever the MIDL generated
//      DllCanUnloadNow returns when there are no locks and no objects.
//
//--
//////////////////////////////////////////////////////////////////////////////
STDAPI
DllCanUnloadNow( void )
{
    TraceFunc( "" );

    HRESULT hr;

    if ( ( g_cLock == 0 ) && ( g_cObjects == 0 ) )
    {
#if defined( ENTRY_PREFIX )
        //
        //  Nothing of ours is outstanding; let the MIDL generated
        //  proxy/stubs have the final say.
        //
        hr = STHR( __rpc_macro_expand( ENTRY_PREFIX, DllCanUnloadNow )() );
#else // ! ENTRY_PREFIX
        hr = S_OK;
#endif // ENTRY_PREFIX

    } // if: no locks and no objects
    else
    {
        TraceMsg( mtfDLL, "DLL: Can't unload - g_cLock=%u, g_cObjects=%u", g_cLock, g_cObjects );
        hr = S_FALSE;

    } // else: something is still outstanding

    HRETURN( hr );

} //*** DllCanUnloadNow()

//////////////////////////////////////////////////////////////////////////////
//++
//
//  HrCoCreateInternalInstance
//
//  Description:
//      Mimic CoCreateInstance() except that the DLL class table
//      (g_DllClasses) is consulted first so that a class implemented by
//      this DLL can be created with a direct call to its create function.
//      Anything that ends up as CLASS_E_CLASSNOTAVAILABLE is handed to
//      CoCreateInstance().
//
//  Arguments: (matches CoCreateInstance)
//      rclsidIn        -   Class identifier (CLSID) of the object
//      pUnkOuterIn     -   Pointer to controlling IUnknown
//      dwClsContextIn  -   Context for running executable code
//      riidIn          -   Reference to the identifier of the interface
//      ppvOut          -   Address of output variable that receives the
//                          requested interface pointer.  Must not be NULL
//                          (asserted).
//
//  Return Values:
//      S_OK            - Success.
//      E_OUTOFMEMORY   - Out of memory.
//      other HRESULT values
//
//  Notes:
//      NOTE(review): the shortcut is gated on the CLSCTX_INPROC_HANDLER bit
//      only, so a caller passing just CLSCTX_INPROC_SERVER always takes the
//      CoCreateInstance() path - confirm that this is intended.
//
//--
//////////////////////////////////////////////////////////////////////////////
HRESULT
HrCoCreateInternalInstance(
    REFCLSID rclsidIn,
    LPUNKNOWN pUnkOuterIn,
    DWORD dwClsContextIn,
    REFIID riidIn,
    LPVOID * ppvOut
    )
{
    TraceFunc( "" );

    Assert( ppvOut != NULL );

    HRESULT hr = CLASS_E_CLASSNOTAVAILABLE;

    //
    // The table shortcut is limited to requests that carry the
    // CLSCTX_INPROC_HANDLER bit and that do not ask for aggregation.
    //
    const bool fTryClassTable =
            ( ( dwClsContextIn & CLSCTX_INPROC_HANDLER ) != 0 )
        &&  ( pUnkOuterIn == NULL );

    if ( fTryClassTable )
    {
        int idx = 0;

        //
        // Scan the table up to its terminating entry (rclsid == NULL).
        //
        while ( g_DllClasses[ idx ].rclsid != NULL )
        {
            if ( *g_DllClasses[ idx ].rclsid == rclsidIn )
            {
                LPUNKNOWN punk;

                Assert( g_DllClasses[ idx ].pfnCreateInstance != NULL );

                hr = THR( g_DllClasses[ idx ].pfnCreateInstance( &punk ) );
                if ( SUCCEEDED( hr ) )
                {
                    //
                    // Hand out the requested interface, then drop the
                    // reference obtained from the create function.
                    //
                    hr = THR( punk->QueryInterface( riidIn, ppvOut ) );
                    punk->Release();
                } // if: object created

                break;  // first match wins

            } // if: matching class ID

            idx++;

        } // while: scanning the class table

    } // if: table shortcut applies

    //
    // Not in the table, or the shortcut did not apply (or something above
    // reported CLASS_E_CLASSNOTAVAILABLE): let COM create the object.
    //
    if ( hr == CLASS_E_CLASSNOTAVAILABLE )
    {
        hr = THR( CoCreateInstance( rclsidIn, pUnkOuterIn, dwClsContextIn, riidIn, ppvOut ) );

    } // if: fall back to COM

    HRETURN( hr );

} //*** HrCoCreateInternalInstance()


//
// TODO:    gpease 27-NOV-1999
//          While perusing around the MIDL SDK, I found that
//          RPC creates the same type of class table we do. Maybe
//          we can leverage the MIDL code to create our objects(??).
//
