#include <assert.h>
#include <stdlib.h>
#include <string.h>
#include <memory.h>
#include <nt.h>
#include <ntrtl.h>
#include <nturtl.h>
#include <ctype.h>
#include <stdio.h>
#include <windows.h>
#include "profiler.h"
#include "view.h"
#include "thread.h"
#include "dump.h"
#include "except.h"
#include "memory.h"
#include "clevel.h"
#include "cap.h"

//
// Defined elsewhere. When TRUE, HookUnchainableExceptionFilter takes the
// Win9x (VxD ring-3 handler) path instead of the NT vectored-handler path.
//
extern BOOL g_bIsWin9X;
//
// Execution CAP filter state. Not referenced by live code visible alongside
// this definition -- NOTE(review): confirm it is still used elsewhere.
//
CAPFILTER g_execFilter;
//
// Win9x only: assigned a hard-coded address by HookUnchainableExceptionFilter.
// Not called directly in this section; presumably consumed by the
// SET_CONTEXT() macro -- TODO confirm.
//
pfnExContinue g_pfnExContinue = 0;

BOOL
HookUnchainableExceptionFilter(VOID)
{
    BOOL bResult;
    pfnRtlAddVectoredExceptionHandler pfnAddExceptionHandler = 0;
    PVOID pvResult;
    HANDLE hTemp;
    DWORD dwExceptionHandler;
    DWORD dwResultSize;
    PVOID pAlternateHeap;

    //
    // If we're NT - try for the unchainable filter in ntdll
    //
    if (FALSE == g_bIsWin9X) {
       pfnAddExceptionHandler = (pfnRtlAddVectoredExceptionHandler)GetProcAddress(GetModuleHandleA("NTDLL.DLL"), 
                                                                                  "RtlAddVectoredExceptionHandler");
       if (0 == pfnAddExceptionHandler) {
          return FALSE;
       }

       pvResult = (*pfnAddExceptionHandler)(1, 
                                            (PVOID)ExceptionFilter);
       if (0 == pvResult) {
          return FALSE;
       }
    }
    else {
       //
       // Set up exception handler
       //
       hTemp = CreateFileA(NAME_OF_EXCEPTION_VXD,
                           0,
                           0,
                           0,
                           0,
                           FILE_FLAG_DELETE_ON_CLOSE,
                           0);
       if (INVALID_HANDLE_VALUE == hTemp) {
          return FALSE;
       }
 
       _asm mov dwExceptionHandler, offset Win9XExceptionDispatcher

       bResult = DeviceIoControl(hTemp,
                                 INSTALL_RING_3_HANDLER,
                                 &dwExceptionHandler,
                                 sizeof(DWORD),
                                 0,
                                 0,
                                 &dwResultSize,
                                 0);
       if (FALSE == bResult) {
          return FALSE;
       }     

       //
       // Get function pointer for ExContinue
       //
       g_pfnExContinue = (pfnExContinue)0xbff76702;
    }

    return TRUE;
}

LONG 
ExceptionFilter(struct _EXCEPTION_POINTERS *ExceptionInfo)
{
    DWORD dwThreadId;
    DWORD dwCounter;
    BOOL bResult;
    LONG lRet;
    PCONTEXT pContext = ExceptionInfo->ContextRecord;
    PEXCEPTION_RECORD pRecord = ExceptionInfo->ExceptionRecord;
    PVIEWCHAIN pView = 0;
    PTHREADFAULT pThreadFault = 0;
    CHAR szBuffer[MAX_PATH];

    //
    // Retrieve thread data
    //
    dwThreadId = GetCurrentThreadId();

    pThreadFault = GetProfilerThreadData();
    if (0 == pThreadFault) {
       //
       // NT only code path
       //
       pThreadFault = AllocateProfilerThreadData();
       if (0 == pThreadFault) {
          //
          // This wasn't suppose to happen
          //
          ExitProcess(-1);
       }
    }
    
    //
    // Rehook the view
    //
    if (STATUS_SINGLE_STEP == pRecord->ExceptionCode) {
       //
       // Trace is used to map into call or jumps types we can't forward map
       //
       if (pThreadFault->dwPrevBP) {
          //
          // If we're a call - patch the return address so we can maintain call level
          //
          if (pThreadFault->prevBPType == Call) {
             //
             // Push the return level hook
             //
             bResult = PushCaller((PVOID)pThreadFault,
                                  (PVOID)pContext->Esp);
             if (FALSE == bResult) {
                //
                // Ooops
                //
                ExitProcess(-1);
             }
          }

          RestoreAddressFromView(pThreadFault->dwPrevBP,
                                 FALSE);

          if ((pThreadFault->prevBPType == Call) ||
              (pThreadFault->prevBPType == Jump)) {
             //
             // Profile this routine if it hasn't been mapped
             //
             pView = FindView((DWORD)pRecord->ExceptionAddress);
             if (0 == pView) {
                //
                // Add this address as a mapping breakpoint
                //
                pView = AddViewToMonitor((DWORD)pRecord->ExceptionAddress,
                                         Map);
                if (pView) {
                   bResult = MapCode(pView);
                   if (FALSE == bResult) {
                      //
                      // This is fatal
                      //
                      ExitProcess(-1);
                   }
                }
             }
          }

          pThreadFault->dwPrevBP = 0;
          pThreadFault->prevBPType = None;

          return EXCEPTION_CONTINUE_EXECUTION;
       }

       //
       // Trace exception wasn't generated by us
       //
       sprintf(szBuffer, "Unhandled Trace %08X\r\n", (DWORD)pRecord->ExceptionAddress);
       WriteError(szBuffer);

       return EXCEPTION_CONTINUE_SEARCH;
    }
 
    //
    // Restore the view
    //
    if (STATUS_BREAKPOINT == pRecord->ExceptionCode) {
       //
       // Restore any BP that hasn't been restored
       //
       if (pThreadFault->dwPrevBP) {
          RestoreAddressFromView(pThreadFault->dwPrevBP,
                                 FALSE);

          if ((DWORD)pRecord->ExceptionAddress == pThreadFault->dwPrevBP) {
             pThreadFault->dwPrevBP = 0;
             pThreadFault->prevBPType = None;

             return EXCEPTION_CONTINUE_EXECUTION;
          }
       }

/*
       //
       // Add address to the execution filter
       //
       bResult = AddToCap(&g_execFilter,
                          (DWORD)pRecord->ExceptionAddress);
       if (FALSE == bResult) {
          //
          // This is fatal
          //
          ExitProcess(-1);
       }

       //
       // If we've hit iteration - disable this and the previous breakpoints
       //
       if (0 != g_execFilter.dwIterationLock) {
          for (dwCounter = 0; dwCounter < g_execFilter.dwRunLength; dwCounter++) {
              //
              // Replace the munged code
              //
              pView = RestoreAddressFromView(g_execFilter.dwArray[g_execFilter.dwCursor - dwCounter - 1],
                                             TRUE);

              //
              // Add runtime event to log
              //
              sprintf(szBuffer, "CAP'ed %08X\r\n", g_execFilter.dwArray[g_execFilter.dwCursor - dwCounter - 1]);
              AddToDump(szBuffer, 
                        strlen(szBuffer), 
                        FALSE);
          }

          //
          // Clear breakpoint monitor flags
          //
          pThreadFault->dwPrevBP = 0;
          pThreadFault->prevBPType = None;

          return EXCEPTION_CONTINUE_EXECUTION;
       }
*/

       //
       // Replace the munged code
       //
       pView = RestoreAddressFromView((DWORD)pRecord->ExceptionAddress,
                                      TRUE);
       if (pView) {
          //
          // See if we've mapped this address range in yet
          //
          if (FALSE == pView->bMapped) {
             //
             // See if this address is already mapped
             //
             bResult = MapCode(pView);
             if (FALSE == bResult) {
                //
                // This is fatal
                //
                ExitProcess(-1);
             }
          }

          //
          // Set the trace so the last bp can be rehooked (unless we just executed a map bp)
          //          
          pContext->EFlags |= 0x00000100;
          pThreadFault->dwPrevBP = (DWORD)pRecord->ExceptionAddress;
          pThreadFault->prevBPType = pView->bpType;

          //
          // Add runtime event to log
          //
          if (pView->bpType != ThreadStart) {
              WriteExeFlow(dwThreadId,
                           (DWORD)pRecord->ExceptionAddress,
                           pThreadFault->dwCallLevel);
          }
          else {
              WriteThreadStart(dwThreadId,
                               (DWORD)pRecord->ExceptionAddress);
          }

          return EXCEPTION_CONTINUE_EXECUTION;
       }

       //
       // BP exception wasn't generated by us
       //
       sprintf(szBuffer, "Unhandled BP %08X\r\n", (DWORD)pRecord->ExceptionAddress);
       WriteError(szBuffer);

       return EXCEPTION_CONTINUE_SEARCH;
    }

    //
    // Continue searching the chain
    //

    return EXCEPTION_CONTINUE_SEARCH;
}

//
// Win9x entry point for the ring-3 handler installed through the exception
// VxD (its address is passed to INSTALL_RING_3_HANDLER by
// HookUnchainableExceptionFilter).
//
// Runs ExceptionFilter on the fault. If the filter did not return
// EXCEPTION_CONTINUE_EXECUTION, it simply returns. Otherwise it invokes
// SET_CONTEXT(), which is expected not to return (see comment below).
// NOTE(review): SET_CONTEXT is defined elsewhere and presumably resumes via
// g_pfnExContinue using ExceptionInfo's context. Do not restructure this
// function without checking that macro, since it may depend on this frame.
//
VOID
Win9XExceptionDispatcher(struct _EXCEPTION_POINTERS *ExceptionInfo)
{
    LONG lResult;

    //
    // Call exception handler
    //
    lResult = ExceptionFilter(ExceptionInfo);
    if (lResult != EXCEPTION_CONTINUE_EXECUTION) {
       //
       // Fault not handled - page fault will terminate app
       //
       return;
    }

    //
    // Set the context results
    //
    SET_CONTEXT();

    //
    // This code path is never executed (unless the above call fails)
    //
    return;
}
