
#include "stdafx.h"
#include "sspiex.h"

extern BOOL verbose;

/*
 * DumpBin - hex dump `length` bytes at `buffer` to stdout, 16 bytes per line.
 *
 * Each line holds the offset in hex, the bytes in hex (with ':' after the
 * 8th byte), then the same bytes as ASCII with non-printables shown as '.'.
 * A NULL buffer is ignored.
 */
void 
DumpBin(PBYTE buffer,
		ULONG length)
{
	static const CHAR rgbDigits[] = "0123456789abcdef";
	/* Worst case: 8-digit offset + 2 spaces + 16*3 hex + 1 + 16 ASCII + NUL = 76. */
	CHAR rgbLine[100];
	ULONG i, count, index;
	int cbLine;

	if (buffer == NULL)
		return;

	for (index = 0; length;
		length -= count, buffer += count, index += count)
	{
		count = (length > 16) ? 16 : length;

		/*
		 * Start the hex column right after what the offset actually
		 * occupied: "%4.4lx" grows past 4 digits once index > 0xffff, so a
		 * fixed column of 6 would overwrite the offset's separator.
		 */
		cbLine = sprintf_s(rgbLine, sizeof(rgbLine), "%4.4lx  ", (unsigned long)index);
		if (cbLine < 0)
			return;

		for (i = 0; i < count; i++)
		{
			rgbLine[cbLine++] = rgbDigits[buffer[i] >> 4];
			rgbLine[cbLine++] = rgbDigits[buffer[i] & 0x0f];
			rgbLine[cbLine++] = (i == 7) ? ':' : ' ';
		}

		/* Pad a short final line so the ASCII column stays aligned. */
		for (; i < 16; i++)
		{
			rgbLine[cbLine++] = ' ';
			rgbLine[cbLine++] = ' ';
			rgbLine[cbLine++] = ' ';
		}

		rgbLine[cbLine++] = ' ';

		for (i = 0; i < count; i++)
			rgbLine[cbLine++] = (buffer[i] < 32 || buffer[i] > 126) ? '.' : (CHAR)buffer[i];

		rgbLine[cbLine] = '\0';
		printf("%s\n", rgbLine);
	}
}


/*
 * DumpSecBufferType - print to stdout a short name for a SecBuffer's type,
 * then any read-only attribute flag, then a newline.
 *
 * The type and the attribute flags share sb->BufferType; SECBUFFER_ATTRMASK
 * separates the two parts.
 */
void
DumpSecBufferType(SecBuffer *sb)
{
	switch (sb->BufferType & ~SECBUFFER_ATTRMASK)
	{
	case SECBUFFER_EMPTY:
		printf(" empty ");
		break;
	case SECBUFFER_DATA:
		printf(" data ");
		break;
	case SECBUFFER_TOKEN:
		printf(" token ");
		break;
	case SECBUFFER_PKG_PARAMS:
		printf(" params ");
		break;
	case SECBUFFER_MISSING:
		printf(" missing ");
		break;
	case SECBUFFER_EXTRA:
		printf(" extra ");
		break;
	case SECBUFFER_STREAM_TRAILER:
		printf(" trailer ");
		break;
	case SECBUFFER_STREAM_HEADER:
		printf(" header ");
		break;
	case SECBUFFER_NEGOTIATION_INFO:
		printf(" neghint ");
		break;
	case SECBUFFER_PADDING:
		/* Previously printed " header ", indistinguishable from STREAM_HEADER. */
		printf(" padding ");
		break;
	case SECBUFFER_STREAM:
		printf(" stream / msg ");
		break;
	case SECBUFFER_MECHLIST:
		printf(" mechlist ");
		break;
	case SECBUFFER_MECHLIST_SIGNATURE:
		printf(" mechlist-mic ");
		break;
	case SECBUFFER_TARGET:
		printf(" target - obs ");
		break;
	case SECBUFFER_CHANNEL_BINDINGS:
		printf(" channel-bind ");
		break;
	case SECBUFFER_CHANGE_PASS_RESPONSE:
		printf(" chpwd ");
		break;
	default:
		printf(" ? ");
		break;
	}

	switch (sb->BufferType & SECBUFFER_ATTRMASK)
	{
	case SECBUFFER_READONLY:
		printf("RO ");
		break;
	case SECBUFFER_READONLY_WITH_CHECKSUM:
		printf("RO CHKSUM ");
		break;
	default:
		/* No attribute flag, or a combination not listed above: print nothing. */
		break;
	}
	printf("\n");
}

/*
 * DumpSecBuffer - print a labelled one-line summary of a SecBuffer to stdout.
 *
 * msg     - caller's label, printed verbatim (never used as a format string).
 * sb      - buffer to describe: its type, address, size and data pointer.
 * verbose - when non-zero, also hex-dump the contents with DumpBin.
 *           This parameter shadows the file-level `verbose` global.
 */
void
DumpSecBuffer(
    char *msg,
    SecBuffer *sb,
    BOOL verbose
    )
{
	/* A label containing '%' must not be interpreted as a format string. */
	printf("%s", msg);
	DumpSecBufferType(sb);
	printf("sb: (0x%p)(size: 0x%x)(buf: 0x%p)\n",
		(void *)sb,
		sb->cbBuffer,
		sb->pvBuffer);

	if (verbose)
		DumpBin((PBYTE)sb->pvBuffer, sb->cbBuffer);
}

/*
 * EncryptBuffer - seal `in` with EncryptMessage and pack the result into
 * `out` for transmission.
 *
 * Layout written to out->pvBuffer:
 *     [DWORD token size][token (security trailer)][encrypted data]
 *
 * hCtxt - established security context.
 * in    - data to protect. The SECBUFFER_DATA buffer handed to
 *         EncryptMessage aliases in->pvBuffer, so on return that memory
 *         holds ciphertext, not the original cleartext.
 * out   - out->pvBuffer must already point to at least
 *         sizeof(DWORD) + cbSecurityTrailer + in->cbBuffer bytes (nothing is
 *         allocated here). On success out->cbBuffer is set to the number of
 *         bytes actually written.
 * seq   - sequence number passed to EncryptMessage.
 * qop   - quality-of-protection value passed to EncryptMessage.
 *
 * Returns the status of QueryContextAttributes or EncryptMessage.
 */
SECURITY_STATUS
EncryptBuffer(
    PCtxtHandle hCtxt,
    SecBuffer *in,
    SecBuffer *out,
    ULONG seq,
    ULONG qop
    )
{
	SECURITY_STATUS		status;
	SecBufferDesc		sdb;
	SecBuffer			input[2];
	SecPkgContext_Sizes sizes;

	status = QueryContextAttributes(
				hCtxt,
				SECPKG_ATTR_SIZES,
				&sizes
				);

	if (!SEC_SUCCESS(status))
	{
		printf("qca (sizes) failed - 0x%08x\n", status);
		goto out;
	}

	DumpSecBuffer("cleartext", in, verbose);

	sdb.ulVersion = 0;	/* SECBUFFER_VERSION */
	sdb.cBuffers = 2;
	sdb.pBuffers = input;

	/* The token is written straight into `out`, just past the size prefix.
	 * cbSecurityTrailer is the maximum size the package may need. */
	input[0].cbBuffer = sizes.cbSecurityTrailer;
	input[0].BufferType = SECBUFFER_TOKEN;
	input[0].pvBuffer = (PVOID) (((PBYTE)out->pvBuffer) + sizeof(DWORD));

	/* The data buffer shares the caller's memory; it is encrypted in place. */
	input[1] = (*in);
	input[1].BufferType = SECBUFFER_DATA;

	status = EncryptMessage(
				hCtxt,
				qop,
				&sdb,
				seq);

	if (!SEC_SUCCESS(status))
	{
		printf("EncryptMessage failed: 0x%08x\n", status);
		goto out;
	}

	DumpSecBuffer("encrypted", in, verbose);
	DumpSecBuffer("signature", &input[0], verbose);

	/* Record the token size EncryptMessage actually produced (it can be
	 * smaller than cbSecurityTrailer) so the receiver can split the message. */
	*((DWORD *)out->pvBuffer) = input[0].cbBuffer;

	/* Append the ciphertext directly after the token. */
	memcpy(
		OFFSET_TO_POINTER(out->pvBuffer, (input[0].cbBuffer + sizeof(DWORD))),
		input[1].pvBuffer,
		input[1].cbBuffer
		);

	/* Size the message by what was actually written, not by the maximum
	 * trailer size. Otherwise unused bytes would trail the ciphertext, and
	 * DecryptBuffer (which takes everything after the token as data) would
	 * feed them to DecryptMessage. */
	out->cbBuffer = (ULONG)(sizeof(DWORD) + input[0].cbBuffer + input[1].cbBuffer);

	DumpSecBuffer("msg to send", out, verbose);

out:
	return status;
}

/*
 * DecryptBuffer - unpack and decrypt a message laid out as
 *     [DWORD token size][token (security trailer)][encrypted data]
 * (the layout EncryptBuffer produces).
 *
 * hCtxt     - established security context.
 * in        - the received message. It comes off the wire, so the embedded
 *             token size is validated against in->cbBuffer before use. The
 *             SECBUFFER_DATA buffer handed to DecryptMessage points into
 *             in->pvBuffer, so that region is decrypted in place.
 * decrypted - on success, receives a copy of the SECBUFFER_DATA descriptor
 *             after DecryptMessage. No data is copied; it describes memory
 *             inside in->pvBuffer.
 * pQop      - passed to DecryptMessage to receive the quality of protection.
 * seq       - expected sequence number, passed to DecryptMessage.
 *
 * Returns SEC_E_INVALID_TOKEN for a malformed message; otherwise the status
 * returned by DecryptMessage.
 */
SECURITY_STATUS
DecryptBuffer(
	SecHandle *hCtxt,
	SecBuffer *in,
	SecBuffer *decrypted,
	ULONG *pQop,
	ULONG seq
	)
{
	SECURITY_STATUS   status;
	SecBufferDesc     sbd;
	SecBuffer         sb[2];
	DWORD             signatureSize;

	/* The DWORD size prefix must be present before it can be read. */
	if (in->pvBuffer == NULL || in->cbBuffer < sizeof(DWORD))
	{
		printf("DecryptBuffer: message too short (%lu bytes)\n",
			(unsigned long)in->cbBuffer);
		return SEC_E_INVALID_TOKEN;
	}

	/* memcpy rather than a DWORD* dereference: the buffer need not be aligned. */
	memcpy(&signatureSize, in->pvBuffer, sizeof(signatureSize));

	/* The peer-supplied token size must fit in what was received; otherwise
	 * the data-length subtraction below would wrap around. */
	if (signatureSize > in->cbBuffer - sizeof(DWORD))
	{
		printf("DecryptBuffer: trailer size %lu exceeds message size %lu\n",
			(unsigned long)signatureSize, (unsigned long)in->cbBuffer);
		return SEC_E_INVALID_TOKEN;
	}

	sbd.ulVersion = 0;	/* SECBUFFER_VERSION */
	sbd.cBuffers = 2;
	sbd.pBuffers = sb;

	sb[0].BufferType = SECBUFFER_TOKEN;
	sb[0].cbBuffer = signatureSize;
	sb[0].pvBuffer = OFFSET_TO_POINTER(in->pvBuffer, sizeof(DWORD));
	DumpSecBuffer("signature", &sb[0], verbose);

	sb[1].BufferType = SECBUFFER_DATA;
	sb[1].cbBuffer = in->cbBuffer - signatureSize - (ULONG)sizeof(DWORD);
	sb[1].pvBuffer = OFFSET_TO_POINTER(in->pvBuffer,
	                    signatureSize + sizeof(DWORD));
	DumpSecBuffer("enc data", &sb[1], verbose);

	status = DecryptMessage(
				hCtxt,
				&sbd,
				seq,
				pQop
				);

	if (!SEC_SUCCESS(status))
	{
		printf("DecryptMessage() failed - 0x%08x\n", status);
		return status;
	}

	DumpSecBuffer("decrypted", &sb[1], verbose);
	*decrypted = sb[1];
	return status;
}



/*
 * Signature functions
 *
 * NOTE(review): VerifySig below is compiled out (#if 0) and would not build
 * if enabled. It references pcbMessage, pBuffer, ss and PrintHexDump, none
 * of which are declared in the function (and none appear elsewhere in this
 * part of the file). It also never uses its `msg`/`token` parameters or the
 * local `status`, and it always passes sequence number 0 to VerifySignature.
 * It appears to be a partially adapted copy of a routine named "VerifyThis".
 * TODO: port it to the SecBuffer-based parameters, or delete it.
 */
#if 0
PBYTE 
VerifySig(
	SecHandle *hCtxt,
	SecBuffer *msg,
	SecBuffer *token,
	ULONG cbMaxSignature)
{

	SECURITY_STATUS   status;
	SecBufferDesc     BuffDesc;
	SecBuffer         SecBuff[2];
	ULONG             ulQop = 0;
	PBYTE             pSigBuffer;
	PBYTE             pDataBuffer;

	//-------------------------------------------------------------------
	//  cbMaxSignature (a parameter here, not a global) is the size of the
	//  signature at the start of the message received.

	printf ("data before verifying (including signature):\n");
	PrintHexDump (*pcbMessage, pBuffer);

	//--------------------------------------------------------------------
	//  By agreement with the server,
	//  the signature is at the beginning of the message received,
	//  and the data that was signed comes after the signature.

	pSigBuffer = pBuffer;
	pDataBuffer = pBuffer + cbMaxSignature;

	//-------------------------------------------------------------------
	//  The size of the message is reset to the size of the data only.

	*pcbMessage = *pcbMessage - (cbMaxSignature);

	//--------------------------------------------------------------------
	//  Prepare the buffers to be passed to the signature verification
	//  function: [0] = signature token, [1] = signed data.

	BuffDesc.ulVersion    = 0;
	BuffDesc.cBuffers     = 2;
	BuffDesc.pBuffers     = SecBuff;

	SecBuff[0].cbBuffer   = cbMaxSignature;
	SecBuff[0].BufferType = SECBUFFER_TOKEN;
	SecBuff[0].pvBuffer   = pSigBuffer;

	SecBuff[1].cbBuffer   = *pcbMessage;
	SecBuff[1].BufferType = SECBUFFER_DATA;
	SecBuff[1].pvBuffer   = pDataBuffer;

	ss = VerifySignature(
		hCtxt,
		&BuffDesc,
		0,
		&ulQop
		);

	if (!SEC_SUCCESS(ss)) 
	{
		fprintf(stderr, "VerifyMessage failed");
	}
	else
	{
		printf("Message was properly signed.\n");
	}

	// Returns a pointer to the data that follows the signature, whether or
	// not verification succeeded.
	return pDataBuffer;

}  // end VerifySig

#endif
/*
 * Socket handlers
 */

/*
 * SendBytes - write all cbBuf bytes at pBuf to socket s.
 *
 * send() may accept fewer bytes than requested, so keep calling it until the
 * whole buffer has been handed to the transport.
 *
 * Returns ERROR_SUCCESS, or the WSAGetLastError() code of the first failing
 * send() (which is also reported on stderr).
 */
DWORD
SendBytes(
	SOCKET s, 
	PBYTE pBuf, 
	DWORD cbBuf
	)
{
	PBYTE cursor;
	int   left;

	if (cbBuf == 0)
		return ERROR_SUCCESS;

	for (cursor = pBuf, left = cbBuf; left != 0; )
	{
		int sent = send(s, (const char *)cursor, left, 0);

		if (sent == SOCKET_ERROR)
		{
			DWORD err = WSAGetLastError();
			fprintf(stderr, "send failed: %u\n", err);
			return err;
		}

		/* Advance past whatever this call accepted. */
		cursor += sent;
		left -= sent;
	}

	return ERROR_SUCCESS;
}

/*
 * ReceiveBytes - read up to cbBuf bytes from socket s into pBuf.
 *
 * Keeps calling recv() until cbBuf bytes have arrived or recv() returns 0
 * (the peer closed its side). *pcbRead receives the number of bytes actually
 * read, which is less than cbBuf only in the closed-connection case.
 *
 * Returns ERROR_SUCCESS, or the WSAGetLastError() code of a failing recv()
 * (which is also reported on stderr); *pcbRead is not written on failure.
 */
DWORD
ReceiveBytes(
	SOCKET  s, 
	PBYTE   pBuf, 
	DWORD   cbBuf, 
	DWORD  *pcbRead
	)
{
	PBYTE cursor = pBuf;
	int   left = cbBuf;

	while (left != 0)
	{
		int got = recv(s, (char *)cursor, left, 0);

		if (got == SOCKET_ERROR)
		{
			DWORD err = WSAGetLastError();
			fprintf (stderr, "recv failed: %u\n", err);
			return err;
		}

		if (got == 0)
			break;	/* connection closed: report a short read */

		cursor += got;
		left -= got;
	}

	*pcbRead = cbBuf - left;
	return ERROR_SUCCESS;
}


/*
 * SendMsg - send one length-prefixed message on socket s.
 *
 * Wire format: the DWORD cbBuf in native byte order, then cbBuf payload bytes.
 * A zero-length message is not sent at all (not even the prefix), so the
 * peer sees nothing for it.
 *
 * Returns ERROR_SUCCESS or the first error reported by SendBytes.
 */
DWORD 
SendMsg(
	SOCKET s,
	PBYTE pBuf, 
	DWORD cbBuf
    )
{
	DWORD rc;

	if (cbBuf == 0)
		return ERROR_SUCCESS;

	rc = SendBytes(s, (PBYTE)&cbBuf, sizeof(cbBuf));
	if (rc != ERROR_SUCCESS)
		return rc;

	return SendBytes(s, pBuf, cbBuf);
}

/*
 * ReceiveMsg - receive one length-prefixed message (as sent by SendMsg).
 *
 * Reads the DWORD length prefix, then exactly that many bytes into pBuf.
 *
 * Returns:
 *   ERROR_SUCCESS             - *pcbRead holds the message length.
 *   ERROR_INSUFFICIENT_BUFFER - the prefix or body was cut short by the peer
 *                               closing, or the message exceeds cbBuf. In the
 *                               oversized case the body is left unread on the
 *                               socket.
 *   other                     - a socket error from ReceiveBytes.
 */
DWORD 
ReceiveMsg(
	SOCKET  s, 
	PBYTE   pBuf, 
	DWORD   cbBuf, 
	DWORD  *pcbRead
    )
{
	DWORD msgLen;
	DWORD got;
	DWORD rc;

	rc = ReceiveBytes(s, (PBYTE)&msgLen, sizeof(msgLen), &got);
	if (rc != ERROR_SUCCESS)
	{
		printf("ReceiveBytes failed %u\n", rc);
		return rc;
	}

	/* Check the prefix arrived in full before trusting msgLen. */
	if (got != sizeof(msgLen) || msgLen > cbBuf)
		return ERROR_INSUFFICIENT_BUFFER;

	rc = ReceiveBytes(s, pBuf, msgLen, &got);
	if (rc != ERROR_SUCCESS)
	{
		printf("ReceiveBytes(2) failed %u\n", rc);
		return rc;
	}

	if (got != msgLen)
		return ERROR_INSUFFICIENT_BUFFER;

	*pcbRead = got;
	return ERROR_SUCCESS;
}



/*
 * BuildAuthIdentity - allocate a SEC_WINNT_AUTH_IDENTITY_EXW that points at
 * the caller's Unicode user, domain and password strings.
 *
 * The strings are NOT copied: they must stay valid for as long as the
 * identity is used. A NULL argument is stored as a NULL pointer with
 * length 0. Lengths are in characters, excluding the terminator.
 *
 * id - receives the new structure on success; the caller releases it with
 *      LocalFree.
 *
 * Returns ERROR_SUCCESS or ERROR_NOT_ENOUGH_MEMORY.
 */
DWORD
BuildAuthIdentity(
	  LPWSTR user,
	  LPWSTR password,
	  LPWSTR domain,
	  PSEC_WINNT_AUTH_IDENTITY_EXW *id
	  )
{
	PSEC_WINNT_AUTH_IDENTITY_EXW localId;

	/* LPTR zero-fills, so PackageList and anything not set below is 0/NULL. */
	localId = (PSEC_WINNT_AUTH_IDENTITY_EXW) LocalAlloc(LPTR, sizeof(*localId));
	if (!localId)
		return ERROR_NOT_ENOUGH_MEMORY;

	localId->Version = SEC_WINNT_AUTH_IDENTITY_VERSION;
	localId->Length = sizeof(SEC_WINNT_AUTH_IDENTITY_EXW);
	localId->Flags = SEC_WINNT_AUTH_IDENTITY_UNICODE;

	localId->User = (USHORT *) user;
	localId->UserLength = user ? (ULONG) wcslen(user) : 0;

	localId->Domain = (USHORT *) domain;
	localId->DomainLength = domain ? (ULONG) wcslen(domain) : 0;

	/* Previously this length was stored in DomainLength, clobbering the
	 * domain length and leaving PasswordLength at 0. */
	localId->Password = (USHORT *) password;
	localId->PasswordLength = password ? (ULONG) wcslen(password) : 0;

	*id = localId;

	return (ERROR_SUCCESS);
}


/*
 * BuildAuthIdentityM - allocate a self-contained ("marshalled")
 * SEC_WINNT_AUTH_IDENTITY_EXW. The user, domain and (optional) password
 * strings are copied into the same LocalAlloc'd block, right after the
 * structure. User/Domain/Password hold values from POINTER_TO_OFFSET_S
 * (presumably offsets from the start of the block — confirm against
 * sspiex.h) rather than ordinary pointers.
 *
 * user, domain - required; must be non-NULL (wcslen is called on them).
 * password     - optional; NULL leaves Password NULL and PasswordLength 0.
 * id           - receives the block on success. The caller owns it and
 *                releases it with LocalFree; it does not reference the input
 *                strings afterwards.
 *
 * Returns ERROR_SUCCESS or ERROR_NOT_ENOUGH_MEMORY.
 */
DWORD
BuildAuthIdentityM(
	  LPWSTR user,
	  LPWSTR password,
	  LPWSTR domain,
	  PSEC_WINNT_AUTH_IDENTITY_EXW *id
	  )
{
	PSEC_WINNT_AUTH_IDENTITY_EXW localId = NULL;
	ULONG localSize = sizeof(SEC_WINNT_AUTH_IDENTITY_EXW);
	PBYTE copyTo;

	/* Space for each string, padded to WCHAR alignment. Per the original
	 * design no terminating NUL is required (lengths are stored explicitly);
	 * TODO confirm WSZ_BYTECOUNT excludes the terminator. */
	localSize += ALIGN_UP_LPWSTR(WSZ_BYTECOUNT(user));

	if (password)
		localSize += ALIGN_UP_LPWSTR(WSZ_BYTECOUNT(password));

	localSize += ALIGN_UP_LPWSTR(WSZ_BYTECOUNT(domain));

	/* LPTR zero-fills: PackageList, and Password when NULL, stay 0. */
	localId = (PSEC_WINNT_AUTH_IDENTITY_EXW) LocalAlloc(LPTR, localSize);
	if (!localId)
		return ERROR_NOT_ENOUGH_MEMORY;

	localId->Flags = SEC_WINNT_AUTH_IDENTITY_UNICODE |
		SEC_WINNT_AUTH_IDENTITY_MARSHALLED;

	/* The EX structure must carry SEC_WINNT_AUTH_IDENTITY_VERSION (as
	 * BuildAuthIdentity sets); 0 is not a valid version for it. */
	localId->Version = SEC_WINNT_AUTH_IDENTITY_VERSION;
	localId->Length = sizeof(SEC_WINNT_AUTH_IDENTITY_EXW);

	/* Marshal the strings into the tail of the block. */
	copyTo = (PBYTE) localId + sizeof(SEC_WINNT_AUTH_IDENTITY_EXW);

	localId->UserLength = (ULONG) wcslen(user);
	memcpy(copyTo, user, WSZ_BYTECOUNT(user));
	localId->User = (USHORT*) POINTER_TO_OFFSET_S(localId, copyTo);
	copyTo += ALIGN_UP_LPWSTR(WSZ_BYTECOUNT(user));

	localId->DomainLength = (ULONG) wcslen(domain);
	memcpy(copyTo, domain, WSZ_BYTECOUNT(domain));
	localId->Domain = (USHORT*)POINTER_TO_OFFSET_S(localId, copyTo);
	copyTo += ALIGN_UP_LPWSTR(WSZ_BYTECOUNT(domain));

	/* Usually no password is given, so the default credentials are used. */
	if (password)
	{
		localId->PasswordLength = (ULONG) wcslen(password);
		memcpy(copyTo, password, WSZ_BYTECOUNT(password));
		localId->Password = (USHORT*)POINTER_TO_OFFSET_S(localId, copyTo);
	}

	*id = localId;

	return (ERROR_SUCCESS);
}
