cwinapiwindows-kernelschannel

InitializeSecurityContextW returns SEC_E_INVALID_HANDLE after second call


I'm trying to implement secure sockets in my kernel mode application using Winsock Kernel and Schannel. I'm using this code as a reference to establish the secure connection. However, I encounter a weird issue where the second call to InitializeSecurityContextW fails with error SEC_E_INVALID_HANDLE. This doesn't make sense to me, because the first call returns SEC_I_CONTINUE_NEEDED which is considered a success. From my understanding from the documentation, it should now be guaranteed that the phNewContext parameter returned a handle to a partially formed context, but in reality this handle will generate the mentioned error code when passing it to InitializeSecurityContextW again (in phContext parameter) and even worse, causes a BSOD when passing it to DeleteSecurityContext. The documentation doesn't mention such behaviour, which makes me wonder if this is a bug with the API? My current code looks as follows (I added comments to indicate where the error(s) occur):

NTSTATUS SecureSocketConnect(
    IN SOCKET_SESSION_HANDLE handle,
    IN PSOCKADDR remoteAddr,
    OUT PSECURE_SOCK_CONNECTION_HANDLE outHandle
)
{
    NTSTATUS retVal;
    SECURITY_STATUS credStatus = E_FAIL, ctxStatus;
    PSECURE_SOCKET_CONNECTION con = NULL;
    SECURITY_STRING packageName;
    SCHANNEL_CRED creds = { 0 };
    BOOLEAN destroyCtx = FALSE, isFirst = TRUE;
    DWORD flags, downloaded = 0;

    SecBuffer inBuf[2] = { 0 }, outBuf[1] = { 0 };
    SecBufferDesc inDesc, outDesc;

    //1. Allocate and init. structures

    con = AllocUserMemory(sizeof(SECURE_SOCKET_CONNECTION));

    if (!con)
    {
        retVal = STATUS_NO_MEMORY;
        goto Done;
    }

    creds.dwVersion = SCHANNEL_CRED_VERSION;
    creds.grbitEnabledProtocols = SP_PROT_TLS1_2;
    creds.dwFlags = SCH_USE_STRONG_CRYPTO         
        | SCH_CRED_MANUAL_CRED_VALIDATION
        | SCH_CRED_NO_DEFAULT_CREDS;

    flags = ISC_REQ_USE_SUPPLIED_CREDS | 
        ISC_REQ_ALLOCATE_MEMORY |
        ISC_REQ_CONFIDENTIALITY | 
        ISC_REQ_REPLAY_DETECT | 
        ISC_REQ_SEQUENCE_DETECT | 
        ISC_REQ_STREAM;

    RtlInitUnicodeString(
        &packageName, 
        L"Microsoft Unified Security Protocol Provider"
    );

    //2. Get pointer to security function table

    con->SecureFuncs = InitSecurityInterfaceW();

    if (!con->SecureFuncs)
    {
        retVal = STATUS_UNSATISFIED_DEPENDENCIES;
        goto Done;
    }

    //3. Acquire credentials handle

    credStatus = con->SecureFuncs->AcquireCredentialsHandleW(
        NULL,
        &packageName,
        SECPKG_CRED_OUTBOUND,
        NULL,
        &creds,
        NULL,
        NULL,
        &con->CredHandle,
        NULL
    );

    if (credStatus != SEC_E_OK)
    {
        retVal = STATUS_UNSUCCESSFUL;
        goto Done;
    }

    //4. Try to connect to remote address

    retVal = SocketConnect(
        handle,
        SOCK_STREAM,
        IPPROTO_TCP,
        remoteAddr,
        &con->SocketConnection
    );

    if (!NT_SUCCESS(retVal))
        goto Done;
    
    //5. Establish secure connection using InitializeSecurityContextW
    
    for(retVal = STATUS_PENDING; retVal == STATUS_PENDING; )
    {
        inBuf[0].BufferType = SECBUFFER_TOKEN;
        inBuf[0].pvBuffer = con->Incoming;
        inBuf[0].cbBuffer = con->Received;

        inBuf[1].BufferType = SECBUFFER_EMPTY;
        inBuf[1].pvBuffer = NULL;
        inBuf[1].cbBuffer = 0;

        outBuf[0].BufferType = SECBUFFER_TOKEN;
        outBuf[0].pvBuffer = NULL;
        outBuf[0].cbBuffer = 0;

        inDesc.ulVersion = SECBUFFER_VERSION;
        inDesc.pBuffers = inBuf;
        inDesc.cBuffers = 2;

        outDesc.ulVersion = SECBUFFER_VERSION;
        outDesc.cBuffers = 1;
        outDesc.pBuffers = outBuf;

        if (isFirst)
        {
            isFirst = FALSE;

            //This (first) call succeeds with SEC_I_CONTINUE_NEEDED
            ctxStatus = con->SecureFuncs->InitializeSecurityContextW(
                &con->CredHandle,
                NULL,
                NULL,
                flags,
                0,
                0,
                NULL,
                0,
                &con->ContextHandle,
                &outDesc,
                &flags,
                NULL
            );

            if (ctxStatus != SEC_I_CONTINUE_NEEDED)
            {
                retVal = STATUS_UNSUCCESSFUL;
                continue;
            }
            else
            {
                destroyCtx = TRUE;
            }
        }
        else
        {

            //ERROR: This (second) call fails with SEC_E_INVALID_HANDLE
            ctxStatus = con->SecureFuncs->InitializeSecurityContextW(
                &con->CredHandle,
                &con->ContextHandle,
                NULL,
                flags,
                0,
                0,
                &inDesc,
                0,
                NULL,
                &outDesc,
                &flags,
                NULL
            );
        }

        DbgPrintEx(0, 0, "ctxStatus: %x\n", ctxStatus);

        if (inBuf[1].BufferType == SECBUFFER_EXTRA)
        {
            RtlCopyMemory(
                con->Incoming,
                con->Incoming + (con->Received - inBuf[1].cbBuffer),
                inBuf[1].cbBuffer
            );

            con->Received = inBuf[1].cbBuffer;
        }
        else
        {
            con->Received = 0;
        }

        switch (ctxStatus)
        {

        case SEC_E_INCOMPLETE_MESSAGE:

            break;

        case SEC_I_CONTINUE_NEEDED:

            if (!outBuf[0].pvBuffer)
            {
                retVal = STATUS_NO_MEMORY;
                continue;
            }

            retVal = SocketSend(
                con->SocketConnection,
                outBuf[0].pvBuffer,
                outBuf[0].cbBuffer
            );

            con->SecureFuncs->FreeContextBuffer(outBuf[0].pvBuffer);

            if (!NT_SUCCESS(retVal))
                continue;

            break;

        case SEC_E_OK:

            retVal = STATUS_SUCCESS;
            *outHandle = con;
            continue;

        default:

            retVal = STATUS_UNSUCCESSFUL;
            continue;

        }

        if (con->Received == TLS_MAX_PACKET_SIZE)
        {
            retVal = STATUS_INVALID_BUFFER_SIZE;
            continue;
        }

        retVal = SocketRecieve(
            con->SocketConnection,
            TLS_MAX_PACKET_SIZE - con->Received,
            con->Incoming + con->Received,
            &downloaded
        );

        if (NT_SUCCESS(retVal))
        {
            con->Received += downloaded;
            retVal = STATUS_PENDING;
        }
    }

Done:

    //Freeing resources in case of error
    if (!NT_SUCCESS(retVal) && con)
    {
        if (con->SecureFuncs && destroyCtx)
            //ERROR: This causes a BSOD, even though the ContextHandle should be valid.
            con->SecureFuncs->DeleteSecurityContext(&con->ContextHandle);

        if (con->SecureFuncs && credStatus == SEC_E_OK)
            con->SecureFuncs->FreeCredentialsHandle(&con->CredHandle);

        if (con->SocketConnection)
            SocketCloseConnectionHandle(con->SocketConnection);

        FreeUserMemory(con);
    }

    return retVal;
}

Edit: After further research it appears that the documentation is wrong. The InitializeSecurityContextW function will always return SEC_E_INVALID_HANDLE if phNewContext is NULL. RbMm suggested in a comment to pass a pointer to the ContextHandle to both phContext and phNewContext for every call, but this will eventually result in a BSOD (access violation in ksecdd!InitUserModeContext+0x77). So the question remains: How to properly use this function in kernel mode? The documentation can't be trusted it seems.


Solution

  • Use "Schannel" instead of "Microsoft Unified Security Protocol Provider".