/*******************************************************************************
* AUDIT.C
*
*     This module contains the routines for logging audit events
*
* Copyright (C) 1997-1999 Microsoft Corp.
*******************************************************************************/

#include "precomp.h"
#pragma hdrstop

#include <rpc.h>
#include <msaudite.h>
#include <ntlsa.h>
#include <authz.h>
#include <authzi.h>

//
// NOTE(review): AuditLogHandle is not referenced anywhere in this section of
// the file; its purpose (presumably an audit/security log event source
// handle) should be confirmed against other modules before relying on it.
//
HANDLE AuditLogHandle = NULL;

//
// Event source handle used by WriteErrorLogEntry. It is registered lazily
// ("TermService" source) on the first call and then reused.
//
HANDLE SystemLogHandle = NULL;

//
// Threshold used by PostErrorValueEvent to stop reporting repeated
// STATUS_COMMITMENT_LIMIT (out of commit memory) errors.
//
#define MAX_INSTANCE_MEMORYERR 20

/*
 * Global data
 */
//Authz Changes
// Process-wide Authz resource manager. AuthzInit creates it on first use
// while holding g_AuthzCritSection; AuditEnd frees it and resets it to NULL.
AUTHZ_RESOURCE_MANAGER_HANDLE hRM         = NULL;
// Serializes creation of hRM. It is defined (and must be initialized)
// elsewhere before AuthzInit is first called.
extern RTL_CRITICAL_SECTION g_AuthzCritSection;
//END Authz Changes

/*
 * Procedures defined in this module that are visible to other modules
 */

// Audits a logon-category event for a WinStation through Authz, if
// logon-success auditing is enabled.
VOID
AuditEvent( PWINSTATION pWinstation, ULONG EventId );

// Logs an event through Authz using an event type handle obtained from
// AuthzInit. Only 0 or 6 strings are supported by the implementation.
NTSTATUS
AuthzReportEventW( IN PAUTHZ_AUDIT_EVENT_TYPE_HANDLE pHAET, 
                   IN DWORD Flags, 
                   IN ULONG EventId, 
                   IN PSID pUserID, 
                   IN USHORT NumStrings,
                   IN ULONG DataSize OPTIONAL, //Future - DO NOT USE
                   IN PWSTR* Strings,
                   IN PVOID  Data OPTIONAL         //Future - DO NOT USE
                   );


// Creates an Authz audit event type handle, creating the shared resource
// manager (hRM) on first use.
BOOL AuthzInit( IN DWORD Flags,
                IN USHORT CategoryID,
                IN USHORT AuditID,
                IN USHORT ParameterCount,
                OUT PAUTHZ_AUDIT_EVENT_TYPE_HANDLE phAuditEventType
                );

// Returns TRUE if the local LSA policy audits successful logon events.
BOOLEAN
AuditingEnabled ();

// Frees the shared Authz resource manager (hRM), if any.
VOID
AuditEnd();


/*
 * Internal procedures defined
 */

// Formats a LUID as a process-heap-allocated unicode string.
NTSTATUS
AdtBuildLuidString(
    IN PLUID Value,
    OUT PUNICODE_STRING ResultantString
    );



/*
 * IsAuditLogFull
 *
 * Returns FALSE only when the event log behind LogHandle positively reports
 * that it is not full. If the log state cannot be queried, the log is
 * treated as full (TRUE).
 */
BOOLEAN
IsAuditLogFull(
    HANDLE LogHandle
    )
{
    EVENTLOG_FULL_INFORMATION FullInfo;
    DWORD cbNeeded;
    BOOL fQueried;

    fQueried = GetEventLogInformation( LogHandle,
                                       EVENTLOG_FULL_INFO,
                                       &FullInfo,
                                       sizeof(FullInfo),
                                       &cbNeeded );
    if ( !fQueried ) {
        // Unknown state: err on the side of "full".
        return TRUE;
    }

    return ( FullInfo.dwFull != FALSE ) ? TRUE : FALSE;
}



NTSTATUS
AdtBuildLuidString(
    IN PLUID Value,
    OUT PUNICODE_STRING ResultantString
    )

/*++

Routine Description:

    This function builds a unicode string representing the passed LUID.

    The resultant string is formatted as follows, with each half printed in
    hexadecimal without leading zeros (so the all-zero LUID becomes
    "(0x0,0x0)"):

        (0x<HighPart>,0x<LowPart>)

Arguments:

    Value - The value to be transformed to printable format (Unicode string).

    ResultantString - Points to the unicode string header. On success the
        body is allocated from the process heap, is null-terminated, and
        must be freed by the caller with
        RtlFreeHeap(RtlProcessHeap(), 0, ResultantString->Buffer).
        On failure the Buffer field is set to NULL and nothing needs to be
        freed.

Return Values:

    STATUS_SUCCESS - the string was built.

    STATUS_NO_MEMORY - indicates memory could not be allocated
        for the string body.

    All other Result Codes are generated by called routines.

--*/

{
    NTSTATUS                Status;
    UNICODE_STRING          IntegerString;

    //
    // Scratch buffer for one hex number. It is declared as ULONGs so the
    // WCHAR storage is suitably aligned.
    //
    ULONG                   Buffer[(16*sizeof(WCHAR))/sizeof(ULONG)];


    IntegerString.Buffer = (PWCHAR)&Buffer[0];
    IntegerString.Length = 0;
    IntegerString.MaximumLength = 16*sizeof(WCHAR);


    //
    // Length (in WCHARS) is  3 for   (0x
    //                       10 for   1st hex number (at most 8 digits used)
    //                        3 for   ,0x
    //                       10 for   2nd hex number (at most 8 digits used)
    //                        1 for   )
    //                        1 for   null termination
    //

    ResultantString->Length        = 0;
    ResultantString->MaximumLength = 28 * sizeof(WCHAR);

    ResultantString->Buffer = RtlAllocateHeap( RtlProcessHeap(), 0,
                                               ResultantString->MaximumLength);
    if (ResultantString->Buffer == NULL) {
        return(STATUS_NO_MEMORY);
    }

    Status = RtlAppendUnicodeToString( ResultantString, L"(0x" );
    if (!NT_SUCCESS(Status)) goto Failure;

    Status = RtlIntegerToUnicodeString( (ULONG)Value->HighPart, 16, &IntegerString );
    if (!NT_SUCCESS(Status)) goto Failure;

    //
    // Append by counted length instead of treating IntegerString.Buffer as a
    // null-terminated string; this does not depend on the conversion routine
    // having terminated the scratch buffer.
    //
    Status = RtlAppendUnicodeStringToString( ResultantString, &IntegerString );
    if (!NT_SUCCESS(Status)) goto Failure;

    Status = RtlAppendUnicodeToString( ResultantString, L",0x" );
    if (!NT_SUCCESS(Status)) goto Failure;

    Status = RtlIntegerToUnicodeString( Value->LowPart, 16, &IntegerString );
    if (!NT_SUCCESS(Status)) goto Failure;

    Status = RtlAppendUnicodeStringToString( ResultantString, &IntegerString );
    if (!NT_SUCCESS(Status)) goto Failure;

    Status = RtlAppendUnicodeToString( ResultantString, L")" );
    if (!NT_SUCCESS(Status)) goto Failure;

    //
    // Callers (e.g. AuditEvent) hand Buffer out as a plain PWSTR, so make
    // the terminator explicit. The longest possible result is 23 WCHARs,
    // so there is always room within the 28 allocated.
    //
    if (ResultantString->Length < ResultantString->MaximumLength) {
        ResultantString->Buffer[ResultantString->Length / sizeof(WCHAR)] = UNICODE_NULL;
    }

    return(STATUS_SUCCESS);

Failure:
    //
    // Don't hand back a partially built string; honor the documented
    // "Buffer is NULL on failure" contract.
    //
    RtlFreeHeap( RtlProcessHeap(), 0, ResultantString->Buffer );
    ResultantString->Buffer        = NULL;
    ResultantString->Length        = 0;
    ResultantString->MaximumLength = 0;
    return Status;
}


VOID
AuditEvent( PWINSTATION pWinstation, ULONG EventId )
{
    NTSTATUS Status, Status2;
    UNICODE_STRING LuidString;
    PWSTR StringPointerArray[6];
    USHORT StringIndex = 0;
    TOKEN_STATISTICS TokenInformation;
    ULONG ReturnLength;
    BOOLEAN WasEnabled;
    LUID LogonId = {0,0};
    AUTHZ_AUDIT_EVENT_TYPE_HANDLE hAET = NULL;

    if (!AuditingEnabled() )
        return;

    Status = RtlAdjustPrivilege(
                 SE_SECURITY_PRIVILEGE,
                 TRUE,    // Enable the PRIVILEGE
                 FALSE,    // Don't Use Thread token (under impersonation)
                 &WasEnabled
                 );

    if ( Status == STATUS_NO_TOKEN ) {
        DBGPRINT(("TERMSRV: AuditEvent: RtlAdjustPrivilege failure 0x%x\n",Status));
        return;
    }

    //
    //AUTHZ Changes 
    //

    if( !AuthzInit( 0, SE_CATEGID_LOGON, (USHORT)EventId, 6, &hAET ))
        goto badAuthzInit;
            
    if (pWinstation->UserName && (wcslen(pWinstation->UserName) > 0)) {
        StringPointerArray[StringIndex] = pWinstation->UserName;
    } else {
        StringPointerArray[StringIndex] = L"Unknown";
    }
    StringIndex++;

    if (pWinstation->Domain  && (wcslen(pWinstation->Domain) > 0)) {
        StringPointerArray[StringIndex] = pWinstation->Domain;
    } else {
        StringPointerArray [StringIndex] =  L"Unknown";
    }
    StringIndex++;

    if (pWinstation->UserToken != NULL) {
        Status = NtQueryInformationToken (
            pWinstation->UserToken,
            TokenStatistics,
            &TokenInformation,
            sizeof(TokenInformation),
            &ReturnLength
            );
    
        if (NT_SUCCESS(Status)) {
    
            Status = AdtBuildLuidString( &(TokenInformation.AuthenticationId), &LuidString );
        } else {
            Status = AdtBuildLuidString( &LogonId, &LuidString );
        }
    } else {
        Status = AdtBuildLuidString( &LogonId, &LuidString );
    }
    StringPointerArray[StringIndex] = LuidString.Buffer;
    StringIndex++;

    if (pWinstation->WinStationName && (wcslen(pWinstation->WinStationName) > 0)) {
        StringPointerArray[StringIndex] = pWinstation->WinStationName;
    } else {
        StringPointerArray[StringIndex] = L"Unknown" ;
    }
    StringIndex++;

    if (pWinstation->Client.ClientName && (wcslen(pWinstation->Client.ClientName) > 0)) {
        StringPointerArray[StringIndex] = pWinstation->Client.ClientName;
    } else {
        StringPointerArray[StringIndex] = L"Unknown";
    }

    StringIndex++;

    if (pWinstation->Client.ClientAddress && (wcslen(pWinstation->Client.ClientAddress) > 0)) {
        StringPointerArray[StringIndex] = pWinstation->Client.ClientAddress;
    } else {
        StringPointerArray[StringIndex] = L"Unknown";
    }

    StringIndex++;

    //Authz Changes
    
    Status = AuthzReportEventW( &hAET, 
                                APF_AuditSuccess, 
                                EventId, 
                                pWinstation->pUserSid, 
                                StringIndex,
                                0,
                                StringPointerArray,
                                NULL
                                );

    //end authz changes


     if ( !NT_SUCCESS(Status))
        DBGPRINT(("Termsrv - failed to report event \n" ));

    if( !WasEnabled ) {

        /*
         * Principle of least rights says to not go around with privileges
         * held you do not need. So we must disable the shutdown privilege
         * if it was just a logoff force.
         */
        Status2 = RtlAdjustPrivilege(
                      SE_SECURITY_PRIVILEGE,
                      FALSE,    // Disable the PRIVILEGE
                      FALSE,     // Don't Use Thread token
                      &WasEnabled
                      );

    }
badAuthzInit:
    if( hAET != NULL )
        AuthziFreeAuditEventType( hAET  );
}




/***************************************************************************\
* AuditingEnabled
*
* Purpose : Check auditing via LSA.
*
* Returns:  TRUE on success, FALSE on failure
*
* History:
* 5-6-92   DaveHart     Created.
\***************************************************************************/

BOOLEAN
AuditingEnabled()
{
    NTSTATUS                    Status, IgnoreStatus;
    PPOLICY_AUDIT_EVENTS_INFO   AuditInfo;
    OBJECT_ATTRIBUTES           ObjectAttributes;
    SECURITY_QUALITY_OF_SERVICE SecurityQualityOfService;
    LSA_HANDLE                  PolicyHandle;

    //
    // Set up the Security Quality Of Service for connecting to the
    // LSA policy object.
    //

    SecurityQualityOfService.Length = sizeof(SECURITY_QUALITY_OF_SERVICE);
    SecurityQualityOfService.ImpersonationLevel = SecurityImpersonation;
    SecurityQualityOfService.ContextTrackingMode = SECURITY_DYNAMIC_TRACKING;
    SecurityQualityOfService.EffectiveOnly = FALSE;

    //
    // Set up the object attributes to open the Lsa policy object
    //

    InitializeObjectAttributes(
        &ObjectAttributes,
        NULL,
        0L,
        NULL,
        NULL
        );
    ObjectAttributes.SecurityQualityOfService = &SecurityQualityOfService;

    //
    // Open the local LSA policy object
    //

    Status = LsaOpenPolicy(
                 NULL,
                 &ObjectAttributes,
                 POLICY_VIEW_AUDIT_INFORMATION | POLICY_SET_AUDIT_REQUIREMENTS,
                 &PolicyHandle
                 );
    if (!NT_SUCCESS(Status)) {
        DBGPRINT(("Termsrv: Failed to open LsaPolicyObject Status = 0x%lx", Status));
        return FALSE;
    }

    Status = LsaQueryInformationPolicy(
                 PolicyHandle,
                 PolicyAuditEventsInformation,
                 (PVOID *)&AuditInfo
                 );
    IgnoreStatus = LsaClose(PolicyHandle);
    ASSERT(NT_SUCCESS(IgnoreStatus));

    if (!NT_SUCCESS(Status)) {
        DBGPRINT(("Termsrv: Failed to query audit event info Status = 0x%lx", Status));
        return FALSE;
    }

    return (AuditInfo->AuditingMode &&
            ((AuditInfo->EventAuditingOptions)[AuditCategoryLogon] &
                                          POLICY_AUDIT_EVENT_SUCCESS));
}


/***************************************************************************
* WriteErrorLogEntry
*
* Reports an EVENTLOG_ERROR_TYPE event through the "TermService" event
* source (cached in SystemLogHandle; the System log, per the handle name).
* NtStatusCode is used as the event ID, and RawDataLength bytes at pRawData
* are attached as binary data. No insertion strings are supplied.
*
* The event source is registered on first use and reused afterwards.
* Nothing is logged if registration fails or IsAuditLogFull reports the log
* as full.
*
* NOTE(review): the lazy registration is not synchronized. If this can be
* called from several threads at once, more than one source handle may be
* registered and all but one leaked.
***************************************************************************/
VOID     WriteErrorLogEntry(
            IN  NTSTATUS NtStatusCode,
            IN  PVOID    pRawData,
            IN  ULONG    RawDataLength
            )
{
    NTSTATUS Status;

    if ( !SystemLogHandle ) {
        UNICODE_STRING ModuleName;

        RtlInitUnicodeString( &ModuleName, L"TermService");

        Status = ElfRegisterEventSourceW( NULL, &ModuleName, &SystemLogHandle );

        if (!NT_SUCCESS(Status)) {
            DBGPRINT(("Termsrv - failed to open System log file\n"));
            return;
        }
    }

    //
    // Don't try to write to a log that is full (or whose state is unknown).
    //
    if (IsAuditLogFull(SystemLogHandle))
        return;

    Status = ElfReportEventW( SystemLogHandle,
                              EVENTLOG_ERROR_TYPE,
                              0,
                              NtStatusCode,
                              NULL,
                              0,
                              RawDataLength,
                              NULL,
                              pRawData,
                              0,
                              NULL,
                              NULL );
    if ( !NT_SUCCESS(Status))
        DBGPRINT(("Termsrv - failed to report event \n" ));
}


// This function is duplicated in \nt\termsrv\sessdir\dis\tssdis.cpp.
// Keep the two copies in sync.
/****************************************************************************/
// PostErrorValueEvent
//
// Utility function used to create a system log error event containing one
// hex DWORD error code value, reported under the gpszServiceName event source.
//
// STATUS_COMMITMENT_LIMIT (out of commit memory) errors are rate-limited.
// At most MAX_INSTANCE_MEMORYERR of them are logged per process, and the
// last one logged says that further occurrences will be suppressed.
/****************************************************************************/
void PostErrorValueEvent(unsigned EventCode, DWORD ErrVal)
{
    HANDLE hLog;
    WCHAR hrString[128];
    PWSTR String = NULL;
    extern WCHAR gpszServiceName[];

    // Number of out-of-memory errors already counted.
    // NOTE(review): not synchronized; concurrent callers may race on it.
    static DWORD numInstances = 0;

    if( (DWORD)STATUS_COMMITMENT_LIMIT == ErrVal )
    {
        //
        // Already logged the maximum number of these; stay quiet.
        //
        if( numInstances >= MAX_INSTANCE_MEMORYERR )
            return;

        //
        // This is the last one that will be logged: tell the user so.
        //
        if( numInstances == MAX_INSTANCE_MEMORYERR - 1 ) {
            wsprintfW(hrString, L"0x%X. This type of error will not be logged again to avoid clutter.", ErrVal);
            String = hrString;
        }
        numInstances++;
    }

    hLog = RegisterEventSource(NULL, gpszServiceName);
    if (hLog != NULL) {
        if( NULL == String ) {
            wsprintfW(hrString, L"0x%X", ErrVal);
            String = hrString;
        }
        ReportEvent(hLog, EVENTLOG_ERROR_TYPE, 0, EventCode, NULL, 1, 0,
                (const WCHAR **)&String, NULL);
        DeregisterEventSource(hLog);
    }
}

/*************************************************************
* AuthzInit
*
* Purpose : Initialize authz for logging an event to the security log.
*           On first use it creates the process-wide resource manager (hRM)
*           while holding g_AuthzCritSection. It then creates an audit
*           event type handle for the caller.
*
* Flags            - Passed through to AuthziInitializeAuditEventType
* CategoryID       - Security category to which this event belongs
* AuditID          - An id for the event
* ParameterCount   - Number of parameters that will be passed to the
*                    logging function later
* phAuditEventType - Receives the audit event type handle. The caller
*                    releases it with AuthziFreeAuditEventType. It is set to
*                    NULL on failure.
*
* Returns TRUE on success. Returns FALSE on failure; GetLastError gives
* the reason.
****************************************************************/

BOOL AuthzInit( IN DWORD Flags,
                IN USHORT CategoryID,
                IN USHORT AuditID,
                IN USHORT ParameterCount,
                OUT PAUTHZ_AUDIT_EVENT_TYPE_HANDLE phAuditEventType
                )
{
    BOOL fAuthzInit = TRUE;

    if( NULL == phAuditEventType )
    {
        SetLastError( ERROR_INVALID_PARAMETER );
        return FALSE;
    }

    *phAuditEventType = NULL;

    //
    // Only one thread can create hRM. The lock is released on every path,
    // including failure; otherwise every other thread that audits would
    // block on it forever.
    //
    RtlEnterCriticalSection( &g_AuthzCritSection );
    if( NULL == hRM )
    {
        fAuthzInit = AuthzInitializeResourceManager( 0,
                                                     NULL,
                                                     NULL,
                                                     NULL,
                                                     L"Terminal Server",
                                                     &hRM
                                                     );

        if ( !fAuthzInit )
        {
            DBGPRINT(("TERMSRV: AuditEvent: AuthzInitializeResourceManager failed with %d\n", GetLastError()));
        }
    }
    RtlLeaveCriticalSection( &g_AuthzCritSection );

    if ( !fAuthzInit )
        return FALSE;

    fAuthzInit = AuthziInitializeAuditEventType( Flags,
                                                 CategoryID,
                                                 AuditID,
                                                 ParameterCount,
                                                 phAuditEventType
                                                 );

    if ( !fAuthzInit )
    {
        DBGPRINT(("TERMSRV: AuditEvent: AuthziInitializeAuditEventType failed with %d\n", GetLastError()));

        //
        // Don't hand a half-initialized handle back to the caller.
        //
        if( NULL != *phAuditEventType )
        {
            if( !AuthziFreeAuditEventType( *phAuditEventType ))
                DBGPRINT(("TERMSRV: AuditEvent: AuthziFreeAuditEventType failed with %d\n", GetLastError()));
            *phAuditEventType = NULL;
        }
    }

    return fAuthzInit;
}


/*********************************************************
* AuthzReportEventW
*
* Purpose : Log an event to the security log through Authz.
*
* pHAET      - Audit event type handle obtained from AuthzInit() above
* Flags      - APF_AuditSuccess or another APF_* value from the Authz headers
* EventId    - Not used by this implementation
* pUserSID   - Not used by this implementation
* NumStrings - Number of strings in "Strings"; only 0 and 6 are supported
* DataSize   - Reserved; not used
* Strings    - Array of NumStrings unicode strings, logged as APT_String
*              parameters
* Data       - Reserved; not used
*
* Returns STATUS_SUCCESS if the event was logged. Otherwise it returns
* STATUS_ACCESS_DENIED, including when hRM or the event type handle is
* missing.
**********************************************************/
NTSTATUS
AuthzReportEventW( IN PAUTHZ_AUDIT_EVENT_TYPE_HANDLE pHAET, 
                   IN DWORD Flags, 
                   IN ULONG EventId, 
                   IN PSID pUserSID, 
                   IN USHORT NumStrings,
                   IN ULONG DataSize OPTIONAL, //Future - DO NOT USE
                   IN PWSTR* Strings,
                   IN PVOID  Data OPTIONAL         //Future - DO NOT USE
                  )
{
    AUTHZ_AUDIT_EVENT_HANDLE hAuditEvent  = NULL;
    PAUDIT_PARAMS            pAuditParams = NULL;
    BOOL                     fOk          = FALSE;

    //
    // Nothing can be logged without both the resource manager and an
    // initialized event type.
    //
    if( hRM == NULL || pHAET == NULL || *pHAET == NULL )
        return STATUS_ACCESS_DENIED;

    fOk = AuthziAllocateAuditParams( &pAuditParams, NumStrings );
    if ( !fOk )
    {
        DBGPRINT(("TERMSRV: AuditEvent: AuthzAllocateAuditParams failed with %d\n", GetLastError()));
        goto Cleanup;
    }

    switch( NumStrings )
    {
    case 6:
        fOk = AuthziInitializeAuditParamsWithRM( Flags,
                                                 hRM,
                                                 NumStrings,
                                                 pAuditParams,
                                                 APT_String, Strings[0],
                                                 APT_String, Strings[1],
                                                 APT_String, Strings[2],
                                                 APT_String, Strings[3],
                                                 APT_String, Strings[4],
                                                 APT_String, Strings[5]
                                                 );
        break;

    case 0:
        fOk = AuthziInitializeAuditParamsWithRM( Flags,
                                                 hRM,
                                                 NumStrings,
                                                 pAuditParams
                                                 );
        break;

    default:
        //
        // Any other parameter count is not supported.
        //
        fOk = FALSE;
        DBGPRINT(("TERMSRV: AuditEvent: unsupported audit type \n"));
        goto Cleanup;
    }

    if ( !fOk )
    {
        DBGPRINT(("TERMSRV: AuditEvent: AuthziInitializeAuditParamsWithRM failed with %d\n", GetLastError()));
        goto Cleanup;
    }

    fOk = AuthziInitializeAuditEvent( 0,
                                      hRM,
                                      *pHAET,
                                      pAuditParams,
                                      NULL,
                                      INFINITE,
                                      L"",
                                      L"",
                                      L"",
                                      L"",
                                      &hAuditEvent
                                      );
    if ( !fOk )
    {
        DBGPRINT(("TERMSRV: AuditEvent: AuthziInitializeAuditEvent failed with %d\n", GetLastError()));
        goto Cleanup;
    }

    fOk = AuthziLogAuditEvent( 0, hAuditEvent, NULL );
    if ( !fOk )
    {
        DBGPRINT(("TERMSRV: AuditEvent: AuthziLogAuditEvent failed with %d\n", GetLastError()));
    }

Cleanup:
    if( hAuditEvent != NULL )
        AuthzFreeAuditEvent( hAuditEvent );

    if( pAuditParams != NULL )
        AuthziFreeAuditParams( pAuditParams );

    return fOk ? STATUS_SUCCESS : STATUS_ACCESS_DENIED;
}


//
//should only be called once per our process
//
VOID AuditEnd()
{
    if( NULL != hRM )
    {
        if( !AuthzFreeResourceManager( hRM ))
            DBGPRINT(("TERMSRV: AuditEvent: AuthzFreeResourceManager failed with %d\n", GetLastError()));
        hRM = NULL;
    }
}
