singrdk/base/Windows/UnitTests/NtlmWinHost/NtlmWinHost.cpp

1126 lines
34 KiB
C++
Raw Normal View History

2008-11-17 18:29:00 -05:00
////////////////////////////////////////////////////////////////////////////////
//
// Microsoft Research Singularity
//
// Copyright (c) Microsoft Corporation. All rights reserved.
//
// File: NtlmWinHost.cpp
//
// Note:
//
// This program is part of the unit test for the NTLM authentication library.
// Run this program on a Windows machine (it listens on TCP port 720), and run
// Application\Tests\NtlmUnitTest on Singularity. Specify the @remote command, and
// provide -user=username, -password=password, and the IP address of the machine
// running this program. NtlmUnitTest will then connect to the Windows machine and
// perform an NTLM exchange.
//
// TODO:
//
// * Clean up this file.
//
#include <windows.h>
#define SECURITY_WIN32
#include <security.h>
using namespace System;
using namespace System::Net;
using namespace System::Net::Sockets;
using namespace System::Diagnostics;
using namespace System::Runtime::InteropServices;
using namespace System::Threading;
using namespace System::Text;
using namespace System::Reflection;
using namespace System::Runtime::CompilerServices;
using namespace System::Security::Permissions;
typedef System::Byte byte;
typedef System::UInt32 uint;
typedef System::UInt16 ushort;
typedef System::UInt64 ulong;
const int NtlmUnitTestPort = 720;
#pragma comment(lib,"secur32.lib")
//
// Framing header that precedes every message exchanged with NtlmUnitTest.
// TotalLength counts the header itself plus the payload that follows it.
//
[StructLayout(LayoutKind::Sequential)]
value struct TestMessageHeader
{
public:
    UInt32 TotalLength;     // bytes in header + payload
    UInt32 MessageType;     // a TestMessageType value
};

//
// Fixed part of a Result payload. On the wire it is followed directly by the
// result text in UTF-16, with no NUL terminator.
//
[StructLayout(LayoutKind::Sequential)]
value struct ResultMessage
{
public:
    Int32 Succeeded;        // 1 when authentication succeeded, 0 otherwise
};

//
// Values carried in TestMessageHeader::MessageType.
//
enum class TestMessageType
{
    Negotiate   = 0x1,
    Challenge   = 0x2,
    Response    = 0x3,
    Result      = 0x4,
};
//
// Miscellaneous helpers: exception and buffer dumping, hex formatting, and the
// Win32 FormatMessage binding used to turn status codes into text.
//
ref class Util
{
public:
    literal String^ HexDigits = "0123456789abcdef";

    // Prints the type name and message of 'chain' and of each InnerException below it.
    static void DumpException(Exception^ chain)
    {
        for (Exception^ ex = chain; ex != nullptr; ex = ex->InnerException) {
            Console::WriteLine("{0}: {1}", ex->GetType()->FullName, ex->Message);
        }
    }

    // Returns a new array holding arr[offset .. offset + length).
    static array<byte>^ GetSubArray(array<byte>^ arr, int offset, int length)
    {
        array<byte>^ newarray = gcnew array<byte>(length);
        Buffer::BlockCopy(arr, offset, newarray, 0, length);
        return newarray;
    }

    // Returns the bytes of 'arr' as lowercase hex, two digits per byte, no separators.
    static String^ ByteArrayToString(array<byte>^ arr)
    {
        StringBuilder^ buf = gcnew StringBuilder(arr->Length * 2);
        for (int i = 0; i < arr->Length; i++) {
            byte b = arr[i];
            buf->Append(HexDigits[b >> 4]);
            buf->Append(HexDigits[b & 0xf]);
        }
        return buf->ToString();
    }

public:
    // The following three are defined later in this file.
    static String^ FormatMessageFromSystem(ULONG message);
    static void DumpBuffer(PUCHAR buffer, int length);
    static void DumpBuffer(array<byte>^ buffer, int offset, int length);

    // P/Invoke binding for kernel32 FormatMessage (Unicode character set).
    [DllImport("KERNEL32.DLL", CharSet = CharSet::Unicode)]
    static int FormatMessage(DWORD dwFlags, LPCVOID lpSource, DWORD dwMessageId,
        DWORD dwLanguageId, LPTSTR lpBuffer, DWORD nSize, PVOID Arguments);

    //
    // Lexicographically compares array1[offset1 .. offset1 + length) with
    // array2[offset2 .. offset2 + length). Returns -1, 0 or 1.
    //
    // T is constrained to IComparable<T> because a CLR generic cannot apply the
    // relational operators to an unconstrained type parameter.
    //
    generic<typename T> where T : IComparable<T>
    static int CompareArraySpans(array<T>^ array1, int offset1, array<T>^ array2, int offset2, int length)
    {
        for (int i = 0; i < length; i++)
        {
            // Honor the caller's starting offsets (previously both spans started at 0).
            T element1 = array1[offset1 + i];
            T element2 = array2[offset2 + i];
            int order = element1->CompareTo(element2);
            if (order < 0)
                return -1;
            if (order > 0)
                return 1;
        }
        return 0;
    }
};
//
//
//struct {
//      byte    protocol[8];    // 'N', 'T', 'L', 'M', 'S', 'S', 'P', '\0'
//      byte    type;           // 0x01
//      byte    zero[3];
//      short   flags;          // 0xb203
//      byte    zero[2];
//
//0x10  short   dom_len;        // domain string length
//      short   dom_maxlen;     // domain string maximum length
//      short   dom_off;        // domain string offset
//      byte    zero[2];
//
//0x18  short   host_len;       // host string length
//      short   host_maxlen;    // host string maximum length
//      short   host_off;       // host string offset (always 0x20)
//      byte    zero[2];
//
//0x20  byte    host[*];        // host string (ASCII)
//      byte    dom[*];         // domain string (ASCII)
//  } type-1-message
//
//
//
//
// NTLMSSP message type codes; DumpMessage reads the code from byte 8 of a message.
//
public enum class NtlmMessageType
{
    Negotiate   = 0x1,
    Challenge   = 0x2,
    Response    = 0x3,
};

//
// NTLMSSP negotiate flag bits (only the low 16 bits are listed here).
//
[Flags]
public enum class NtlmNegotiateFlags
{
    None                            = 0,
    NegotiateUnicode                = 0x0001,   // text strings are Unicode
    NegotiateOem                    = 0x0002,   // text strings are OEM
    RequestTarget                   = 0x0004,   // server should return its authentication realm
    NegotiateSign                   = 0x0010,   // signing capability requested
    NegotiateSeal                   = 0x0020,   // confidentiality requested
    NegotiateDatagram               = 0x0040,   // datagram-style authentication
    NegotiateLmKey                  = 0x0080,   // LM session key used for sign/seal
    NegotiateNetware                = 0x0100,   // NetWare authentication
    NegotiateNtlm                   = 0x0200,   // NTLM authentication
    NegotiateNtOnly                 = 0x0400,   // NT authentication only (no LM)
    NegotiateNullSession            = 0x0800,   // NULL sessions on NT 5.0 and later
    NegotiateOemDomainSupplied      = 0x1000,   // domain name supplied in Negotiate
    NegotiateOemWorkstationSupplied = 0x2000,   // workstation name supplied in Negotiate
    NegotiateLocalCall              = 0x4000,   // client and server are the same machine
    NegotiateAlwaysSign             = 0x8000,   // sign at all security levels
};
// Size in bytes of the server challenge carried in a Challenge message.
#define NTLM_CHALLENGE_LENGTH 8

// Constant data shared by all NTLMSSP messages.
public ref class NtlmConstants
{
public:
    // "NTLMSSP\0": the 8-byte signature at the start of every NTLMSSP message.
    static initonly array<Byte>^ MessageSignature = gcnew array<Byte> { 'N', 'T', 'L', 'M', 'S', 'S', 'P', '\0' };
};
//
// Helpers for decoding and printing NTLMSSP messages (Negotiate, Challenge,
// Response) as diagnostic output. Multi-byte fields are little-endian.
//
public ref class NtlmUtil
{
private:
    // DumpMessage reports any message shorter than this as invalid.
    literal int HeaderLength = 0x10;

    //
    // Decodes the counted-field descriptor at 'pos' (STRING32: USHORT Length,
    // USHORT MaximumLength, ULONG Offset) and checks that the bytes it describes
    // lie inside 'message'. Throws Exception when they do not.
    //
    static void GetCountedField(array<Byte>^ message, int pos, int% length, int% offset)
    {
        length = GetUInt16(message, pos + 0);
        // pos + 2 holds MaximumLength, which is not needed to locate the data.
        uint rawOffset = GetUInt32(message, pos + 4);
        // An offset equal to message->Length is acceptable for an empty field.
        if (rawOffset > (uint)message->Length)
            throw gcnew Exception("String has invalid offset");
        offset = (int)rawOffset;
        if (length > message->Length - offset)
            throw gcnew Exception("String has invalid offset / length");
    }

    // Returns the counted field at 'pos' decoded as an 8-bit (OEM / ASCII) string.
    static String^ GetCountedOemStringAt(array<Byte>^ message, int pos)
    {
        int length;
        int offset;
        GetCountedField(message, pos, length, offset);
        return Encoding::ASCII->GetString(message, offset, length);
    }

public:
    // Reads the little-endian 16-bit value at message[pos].
    static ushort GetUInt16(array<Byte>^ message, int pos)
    {
        return (ushort)(message[pos] | (message[pos + 1] << 8));
    }

    // Reads the little-endian 32-bit value at message[pos].
    static uint GetUInt32(array<Byte>^ message, int pos)
    {
        return (uint)message[pos]
            | ((uint)message[pos + 1] << 8)
            | ((uint)message[pos + 2] << 16)
            | ((uint)message[pos + 3] << 24);
    }

    // Returns a copy of the bytes described by the counted field at 'pos'.
    static array<Byte>^ GetCountedBytesAt(array<Byte>^ message, int pos)
    {
        int length;
        int offset;
        GetCountedField(message, pos, length, offset);
        return Util::GetSubArray(message, offset, length);
    }

    // Returns the counted field at 'pos' decoded as a UTF-16 string.
    static String^ GetCountedStringAt(array<Byte>^ message, int pos)
    {
        int length;
        int offset;
        GetCountedField(message, pos, length, offset);
        return Encoding::Unicode->GetString(message, offset, length);
    }

    static void DumpMessage(array<Byte>^ message)
    {
        DumpMessage(message, message->Length);
    }

    //
    // Prints a hex dump of the first 'length' bytes of 'message', then decodes and
    // prints the fields of the NTLMSSP message they contain. Malformed messages are
    // reported on the console rather than thrown, because this is diagnostic output.
    //
    static void DumpMessage(array<Byte>^ message, int length)
    {
        if (message == nullptr)
            throw gcnew ArgumentNullException("message");
        if (length < 0 || length > message->Length)
            throw gcnew ArgumentOutOfRangeException("length");
        // Decode only the bytes being dumped, so field offsets are checked against them.
        if (length < message->Length)
            message = Util::GetSubArray(message, 0, length);

        Console::WriteLine("");
        Console::WriteLine("NTLM message:");
        Util::DumpBuffer(message, 0, length);
        if (length < HeaderLength) {
            Console::WriteLine(" Message is invalid; too short");
            return;
        }
        for (int i = 0; i < NtlmConstants::MessageSignature->Length; i++) {
            if (message[i] != NtlmConstants::MessageSignature[i]) {
                Console::WriteLine(" Message is invalid; signature does not match");
                return;
            }
        }

        try {
            NtlmMessageType type = (NtlmMessageType)message[8];
            switch (type) {
            case NtlmMessageType::Negotiate:
                {
                    //
                    //typedef struct _NEGOTIATE_MESSAGE {
                    //    UCHAR Signature[sizeof(NTLMSSP_SIGNATURE)];   // 0x00
                    //    NTLM_MESSAGE_TYPE MessageType;                // 0x08
                    //    ULONG NegotiateFlags;                         // 0x0c
                    //    STRING32 OemDomainName;                       // 0x10
                    //    STRING32 OemWorkstationName;                  // 0x18
                    //    ULONG64 Version;                              // 0x20
                    //} NEGOTIATE_MESSAGE, *PNEGOTIATE_MESSAGE;
                    //
                    Console::WriteLine(" Type: Negotiate");
                    // NegotiateFlags is a full 32-bit field.
                    NtlmNegotiateFlags flags = (NtlmNegotiateFlags)GetUInt32(message, 12);
                    Console::WriteLine(String::Format(" Flags: 0x{0:x8} {1}", (UInt32)flags, flags.ToString()));
                    // The Negotiate domain name is an OEM (8-bit) string, not UTF-16.
                    String^ domain = GetCountedOemStringAt(message, 0x10);
                    Console::WriteLine(" Domain: " + domain);
                }
                break;

            case NtlmMessageType::Challenge:
                {
                    //
                    //typedef struct _CHALLENGE_MESSAGE {
                    //    UCHAR Signature[sizeof(NTLMSSP_SIGNATURE)];   // 0x00
                    //    NTLM_MESSAGE_TYPE MessageType;                // 0x08
                    //    STRING32 TargetName;                          // 0x0c
                    //    ULONG NegotiateFlags;                         // 0x14
                    //    UCHAR Challenge[MSV1_0_CHALLENGE_LENGTH];     // 0x18
                    //    ULONG64 ServerContextHandle;                  // 0x20
                    //    STRING32 TargetInfo;                          // 0x28
                    //    ULONG64 Version;                              // 0x30
                    //                                                  // 0x38
                    //} CHALLENGE_MESSAGE, *PCHALLENGE_MESSAGE;
                    //
                    Console::WriteLine(" Type: Challenge");
                    String^ TargetName = GetCountedStringAt(message, 0x0c);
                    // The 8-byte STRING32 at 0x0c and the 4-byte flags at 0x14 put the challenge at 0x18.
                    array<Byte>^ Challenge = Util::GetSubArray(message, 0x18, NTLM_CHALLENGE_LENGTH);
                    Console::WriteLine(" TargetName: " + TargetName);
                    Console::WriteLine(" Challenge: " + Util::ByteArrayToString(Challenge));
                }
                break;

            case NtlmMessageType::Response:
                {
                    //
                    //typedef struct _AUTHENTICATE_MESSAGE {
                    //    UCHAR Signature[sizeof(NTLMSSP_SIGNATURE)];
                    //    NTLM_MESSAGE_TYPE MessageType;
                    //    STRING32 LmChallengeResponse;
                    //    STRING32 NtChallengeResponse;
                    //    STRING32 DomainName;
                    //    STRING32 UserName;
                    //    STRING32 Workstation;
                    //    STRING32 SessionKey;
                    //    ULONG NegotiateFlags;
                    //    ULONG64 Version;
                    //} AUTHENTICATE_MESSAGE, *PAUTHENTICATE_MESSAGE;
                    //
                    //typedef struct _OLD_AUTHENTICATE_MESSAGE {
                    //    UCHAR Signature[sizeof(NTLMSSP_SIGNATURE)];   //  0 0x00
                    //    NTLM_MESSAGE_TYPE MessageType;                //  8 0x08
                    //    STRING32 LmChallengeResponse;                 // 12 0x0c
                    //    STRING32 NtChallengeResponse;                 // 20 0x14
                    //    STRING32 DomainName;                          // 28 0x1c
                    //    STRING32 UserName;                            // 36 0x24
                    //    STRING32 Workstation;                         // 44 0x2c
                    //} OLD_AUTHENTICATE_MESSAGE, *POLD_AUTHENTICATE_MESSAGE;
                    //
                    // NOTE(review): the strings are decoded as UTF-16, which assumes
                    // NegotiateUnicode was negotiated.
                    //
                    Console::WriteLine(" Type: Response");
                    array<Byte>^ LmChallengeResponse = GetCountedBytesAt(message, 0x0c);
                    array<Byte>^ NtChallengeResponse = GetCountedBytesAt(message, 0x14);
                    String^ DomainName = GetCountedStringAt(message, 0x1c);
                    String^ UserName = GetCountedStringAt(message, 0x24);
                    String^ Workstation = GetCountedStringAt(message, 0x2c);
                    Console::WriteLine(" LmChallengeResponse: " + Util::ByteArrayToString(LmChallengeResponse));
                    Console::WriteLine(" NtChallengeResponse: " + Util::ByteArrayToString(NtChallengeResponse));
                    Console::WriteLine(" DomainName: " + DomainName);
                    Console::WriteLine(" UserName: " + UserName);
                    Console::WriteLine(" Workstation: " + Workstation);
                }
                break;

            default:
                Console::WriteLine(" Message is invalid; message type byte is not recognized");
                return;
            }
        } catch (Exception^ ex) {
            // Truncated or inconsistent fields: report them instead of aborting the caller.
            Console::WriteLine(" Message is invalid; " + ex->Message);
            return;
        }
        Console::WriteLine("");
    }
};
//
// ThrowStatus(status, message): logs 'message' and the system text for the
// status code 'status', then throws a System::Exception describing both.
// 'message' may be a string literal or a String^. String::Concat is used
// because "literal" + "literal" would be pointer addition, not concatenation.
// The status text is formatted once and reused.
//
#define ThrowStatus(status, message) do { \
        String^ throwStatusText_ = Util::FormatMessageFromSystem(status); \
        Console::WriteLine(String::Concat("ThrowStatus: ", message)); \
        Console::WriteLine(throwStatusText_); \
        throw gcnew Exception(String::Format("FAILED: {0} - {1}", message, throwStatusText_)); \
    } while (0)

//
// CheckStatus(status, message): invokes ThrowStatus unless 'status' equals
// SEC_E_OK; on success logs "<message> - ok".
//
#define CheckStatus(status, message) do { \
        if ((status) != SEC_E_OK) ThrowStatus(status, message); \
        Console::WriteLine(String::Concat(message, " - ok")); \
    } while (false)
//
// Scoped native copy of a managed string as a NUL-terminated wide-character
// buffer. The buffer is allocated with Marshal::StringToHGlobalUni and released
// with Marshal::FreeHGlobal when the wrapper is destroyed.
//
// The wrapper owns its buffer, so copying is disabled: an implicit copy would
// free the same HGLOBAL twice.
//
class StringWrapperW
{
private:
    wchar_t* _buffer;

    // Declared but not defined: the class is non-copyable.
    StringWrapperW(const StringWrapperW&);
    StringWrapperW& operator=(const StringWrapperW&);

public:
    StringWrapperW(String^ str)
    {
        _buffer = (wchar_t*)(void*)Marshal::StringToHGlobalUni(str);
    }

    ~StringWrapperW()
    {
        Marshal::FreeHGlobal((IntPtr)(void*)_buffer);
    }

    // Exposes the native buffer; it remains valid only while this wrapper lives.
    operator wchar_t*()
    {
        return _buffer;
    }
};
//
// Scoped native copy of a managed string as a NUL-terminated ANSI buffer.
// The buffer is allocated with Marshal::StringToHGlobalAnsi and released with
// Marshal::FreeHGlobal when the wrapper is destroyed.
//
// The wrapper owns its buffer, so copying is disabled: an implicit copy would
// free the same HGLOBAL twice.
//
class StringWrapperA
{
private:
    char* _buffer;

    // Declared but not defined: the class is non-copyable.
    StringWrapperA(const StringWrapperA&);
    StringWrapperA& operator=(const StringWrapperA&);

public:
    StringWrapperA(String^ str)
    {
        _buffer = (char*)(void*)Marshal::StringToHGlobalAnsi(str);
    }

    ~StringWrapperA()
    {
        Marshal::FreeHGlobal((IntPtr)(void*)_buffer);
    }

    // Exposes the native buffer; it remains valid only while this wrapper lives.
    operator char*()
    {
        return _buffer;
    }
};
//
// Value-type copy of an SSPI SecHandle (its two ULONG_PTR words), so that
// credential and context handles can be stored as fields of ref classes.
// Converts to and from SecHandle without an explicit cast.
//
value struct ManagedSecHandle
{
public:
    ULONG_PTR dwLower;
    ULONG_PTR dwUpper;

public:
    ManagedSecHandle(SecHandle handle)
        : dwLower(handle.dwLower), dwUpper(handle.dwUpper)
    {
    }

    operator SecHandle()
    {
        SecHandle native = { dwLower, dwUpper };
        return native;
    }
};

// Credential and context handles share the SecHandle representation.
typedef ManagedSecHandle ManagedCredHandle;
typedef ManagedSecHandle ManagedCtxtHandle;
//
// Disabled client-side (supplicant) counterpart to SspiNtlmAuthenticator, kept
// for reference. It is excluded by "#if 0" and would not build as written: the
// base class NtlmSupplicant is not defined in this file, and _tcslen is used
// although <tchar.h> is not included here.
// NOTE(review): the credential and context handles it acquires are never freed,
// and both ThrowStatus calls label the failure "AcquireCredentialsHandle" even
// though the failing call is InitializeSecurityContext.
//
#if 0
ref class SspiNtlmSupplicant : NtlmSupplicant
{
public:
ManagedCredHandle _credentials;
ManagedCtxtHandle _context;
// Acquires outbound NTLM credentials for username/domain/password and returns
// the first (Negotiate) token produced by InitializeSecurityContext.
virtual array<Byte>^ GetNegotiate(
String^ username,
String^ domain,
String^ password) override
{
SECURITY_STATUS status;
StringWrapperW username_w(username);
StringWrapperW domain_w(domain);
StringWrapperW password_w(password);
//
// First, generate the NTLM client credentials
//
SEC_WINNT_AUTH_IDENTITY clientIdentity;
ZeroMemory(&clientIdentity, sizeof(clientIdentity));
clientIdentity.Domain = domain_w;
clientIdentity.DomainLength = _tcslen(domain_w);
clientIdentity.User = username_w;
clientIdentity.UserLength = _tcslen(username_w);
clientIdentity.Password = password_w;
clientIdentity.PasswordLength = _tcslen(password_w);
clientIdentity.Flags = SEC_WINNT_AUTH_IDENTITY_UNICODE;
TimeStamp expires;
CredHandle credentials = { 0, 0 };
status = AcquireCredentialsHandle(
NULL,
L"NTLM",
SECPKG_CRED_OUTBOUND, // credentials use
NULL, // logon id (LUID, not used)
&clientIdentity, // auth data
NULL, // pGetKeyFn - not used
NULL, // pGetKeyArgument - not used
&credentials, // the credentials returned
&expires // not used
);
CheckStatus(status, "Client - Failed to acquire credentials");
_credentials = ManagedCredHandle(credentials);
TimeStamp contextExpires;
SecBuffer outputBuffer;
// Fixed 0x100-byte output buffer for the Negotiate token.
array<Byte>^ hello = gcnew array<Byte>(0x100);
pin_ptr<Byte> hello_pinned = &hello[0];
outputBuffer.BufferType = SECBUFFER_TOKEN;
outputBuffer.cbBuffer = hello->Length;
outputBuffer.pvBuffer = hello_pinned;
SecBufferDesc outputDesc;
outputDesc.pBuffers = &outputBuffer;
outputDesc.cBuffers = 1;
outputDesc.ulVersion = SECBUFFER_VERSION;
ULONG attributes;
CtxtHandle context = { 0, 0 };
status = InitializeSecurityContext(
&credentials, // credentials
NULL, // existing context handle (none)
NULL, // service principal name (none)
ISC_REQ_CONNECTION, // context requirements (most don't apply)
0, // reserved
SECURITY_NETWORK_DREP, // data representation
NULL, // input token (none on first call)
0, // reserved
&context, // the new context handle
&outputDesc, // the output buffer
&attributes, // returns context attributes
&contextExpires // when it expires
);
switch (status) {
case SEC_I_CONTINUE_NEEDED:
// This is the expected case.
Console::WriteLine("Client - SEC_I_CONTINUE_NEEDED");
Console::WriteLine(String::Format(" attributes: 0x{0:x}", attributes));
_context = ManagedCtxtHandle(context);
// cbBuffer was updated by SSPI to the actual token size.
return Util::GetSubArray(hello, 0, outputBuffer.cbBuffer);
default:
ThrowStatus(status, "Client - AcquireCredentialsHandle");
}
}
// Feeds the server's Challenge token to the existing context and returns the
// Response token.
virtual array<Byte>^ ProcessChallenge(array<Byte>^ challenge) override
{
TimeStamp contextExpires;
pin_ptr<Byte> challenge_pinned = &challenge[0];
SecBuffer challengeSecBuffer;
challengeSecBuffer.BufferType = SECBUFFER_TOKEN;
challengeSecBuffer.cbBuffer = challenge->Length;
challengeSecBuffer.pvBuffer = challenge_pinned;
SecBufferDesc challengeDesc;
challengeDesc.cBuffers = 1;
challengeDesc.pBuffers = &challengeSecBuffer;
challengeDesc.ulVersion = SECBUFFER_VERSION;
// Fixed 0x100-byte output buffer for the Response token.
array<Byte>^ response_buffer = gcnew array<Byte>(0x100);
pin_ptr<Byte> response_pinned = &response_buffer[0];
SecBuffer outputBuffer;
outputBuffer.BufferType = SECBUFFER_TOKEN;
outputBuffer.cbBuffer = response_buffer->Length;
outputBuffer.pvBuffer = response_pinned;
SecBufferDesc outputDesc;
outputDesc.pBuffers = &outputBuffer;
outputDesc.cBuffers = 1;
outputDesc.ulVersion = SECBUFFER_VERSION;
CredHandle credentials = _credentials;
CtxtHandle context = _context;
ULONG attributes;
SECURITY_STATUS status = InitializeSecurityContext(
&credentials, // credentials
&context, // existing context handle
NULL, // service principal name (none)
0, // context requirements (most don't apply)
0, // reserved
SECURITY_NETWORK_DREP, // data representation
&challengeDesc, // input token (the server's challenge)
0, // reserved
&context, // the new context handle
&outputDesc, // the output buffer
&attributes, // returns context attributes
&contextExpires // when it expires
);
_context = ManagedCtxtHandle(context);
switch (status) {
case SEC_E_OK:
// This is the expected case.
Console::WriteLine("Client - SEC_E_OK");
Console::WriteLine(String::Format(" attributes: 0x{0:x}", attributes));
return Util::GetSubArray(response_buffer, 0, outputBuffer.cbBuffer);
default:
ThrowStatus(status, "Client - AcquireCredentialsHandle");
}
}
};
#endif
//
// Server-side wrapper around the Windows NTLM security package, driven through SSPI:
//   Initialize()      acquires inbound NTLM credentials,
//   GetChallenge()    consumes the client's Negotiate token and returns a Challenge,
//   VerifyResponse()  consumes the client's Response token and throws on failure.
//
// The SSPI credential and context handles are released by the destructor
// (run at scope exit for stack-semantics instances) or, failing that, by the
// finalizer.
//
ref class SspiNtlmAuthenticator
{
private:
    ManagedCredHandle _credentials;
    ManagedCtxtHandle _context;
    // True while the corresponding field holds a live SSPI handle that must be freed.
    bool _haveCredentials;
    bool _haveContext;

    // Deletes the security context, if one is held.
    void ReleaseContext()
    {
        if (_haveContext) {
            CtxtHandle context = _context;
            DeleteSecurityContext(&context);
            _haveContext = false;
        }
    }

    // Frees the credentials handle, if one is held.
    void ReleaseCredentials()
    {
        if (_haveCredentials) {
            CredHandle credentials = _credentials;
            FreeCredentialsHandle(&credentials);
            _haveCredentials = false;
        }
    }

public:
    ~SspiNtlmAuthenticator()
    {
        this->!SspiNtlmAuthenticator();
    }

protected:
    // Releases the native SSPI handles if Dispose was never called.
    !SspiNtlmAuthenticator()
    {
        ReleaseContext();
        ReleaseCredentials();
    }

public:
    // Acquires inbound NTLM credentials for this process. Throws on failure.
    void Initialize()
    {
        // Re-initializing discards anything left from a previous exchange.
        ReleaseContext();
        ReleaseCredentials();

        SECURITY_STATUS status;
        CredHandle credentials = { 0, 0 };
        TimeStamp expires;
        status = AcquireCredentialsHandle(
            NULL,
            L"NTLM",
            SECPKG_CRED_INBOUND,    // credentials use
            NULL,                   // logon id (LUID, not used)
            NULL,                   // auth data
            NULL,                   // pGetKeyFn - not used
            NULL,                   // pGetKeyArgument - not used
            &credentials,           // the credentials returned
            &expires                // not used
            );
        CheckStatus(status, "Server - AcquireCredentialsHandle");
        _credentials = ManagedCredHandle(credentials);
        _haveCredentials = true;
    }

    //
    // Passes the client's Negotiate token to AcceptSecurityContext and returns the
    // Challenge token to send back. Throws unless SSPI reports SEC_I_CONTINUE_NEEDED.
    //
    array<byte>^ GetChallenge(array<byte>^ negotiate)
    {
        // &negotiate[0] below needs at least one element.
        if (negotiate == nullptr || negotiate->Length == 0)
            throw gcnew ArgumentException("Negotiate message is empty.", "negotiate");
        if (!_haveCredentials)
            throw gcnew InvalidOperationException("Initialize must be called before GetChallenge.");
        // A new exchange starts here; drop any context from a previous one.
        ReleaseContext();

        pin_ptr<byte> negotiate_pinned = &negotiate[0];
        SecBuffer negotiateSecBuffer;
        negotiateSecBuffer.BufferType = SECBUFFER_TOKEN | SECBUFFER_READONLY;
        negotiateSecBuffer.cbBuffer = negotiate->Length;
        negotiateSecBuffer.pvBuffer = negotiate_pinned;
        SecBufferDesc negotiateDesc;
        negotiateDesc.cBuffers = 1;
        negotiateDesc.pBuffers = &negotiateSecBuffer;
        negotiateDesc.ulVersion = SECBUFFER_VERSION;

        // Fixed 0x200-byte output buffer; SSPI updates cbBuffer to the token size.
        array<byte>^ challenge_buffer = gcnew array<Byte>(0x200);
        pin_ptr<byte> challenge_pinned = &challenge_buffer[0];
        SecBuffer challengeSecBuffer;
        challengeSecBuffer.BufferType = SECBUFFER_TOKEN;
        challengeSecBuffer.cbBuffer = challenge_buffer->Length;
        challengeSecBuffer.pvBuffer = challenge_pinned;
        SecBufferDesc challengeDesc;
        challengeDesc.ulVersion = SECBUFFER_VERSION;
        challengeDesc.cBuffers = 1;
        challengeDesc.pBuffers = &challengeSecBuffer;

        TimeStamp expires;
        ULONG attributes;
        CredHandle credentials = _credentials;
        CtxtHandle context = { 0, 0 };
        SECURITY_STATUS status = AcceptSecurityContext(
            &credentials,
            NULL,                   // no previous context
            &negotiateDesc,         // input buffer
            0,                      // context requirements
            SECURITY_NETWORK_DREP,  // data representation
            &context,
            &challengeDesc,
            &attributes,
            &expires
            );
        switch (status) {
        case SEC_I_CONTINUE_NEEDED:
            {
                // This is the expected success case.
                Console::WriteLine("Server - AcceptSecurityContext succeeded");
                array<byte>^ challenge = gcnew array<Byte>(challengeSecBuffer.cbBuffer);
                Array::Copy(challenge_buffer, 0, challenge, 0, challenge->Length);
                _context = ManagedCtxtHandle(context);
                _haveContext = true;
                return challenge;
            }
        default:
            ThrowStatus(status, "Server - AcceptSecurityContext");
        }
    }

    //
    // Passes the client's Response token to AcceptSecurityContext on the context
    // created by GetChallenge. Throws unless SSPI reports SEC_E_OK; on success
    // prints the value of the token handle for the authenticated client.
    //
    void VerifyResponse(array<Byte>^ response)
    {
        // &response[0] below needs at least one element.
        if (response == nullptr || response->Length == 0)
            throw gcnew ArgumentException("Response message is empty.", "response");
        if (!_haveContext)
            throw gcnew InvalidOperationException("GetChallenge must succeed before VerifyResponse.");

        SECURITY_STATUS status;
        pin_ptr<Byte> response_pinned = &response[0];
        SecBuffer responseSecBuffer;
        responseSecBuffer.BufferType = SECBUFFER_TOKEN;
        responseSecBuffer.cbBuffer = response->Length;
        responseSecBuffer.pvBuffer = response_pinned;
        SecBufferDesc responseDesc;
        responseDesc.ulVersion = SECBUFFER_VERSION;
        responseDesc.pBuffers = &responseSecBuffer;
        responseDesc.cBuffers = 1;

        TimeStamp expires;
        ULONG attributes;
        CtxtHandle context = _context;
        CredHandle credentials = _credentials;
        status = AcceptSecurityContext(
            &credentials,
            &context,               // previous context
            &responseDesc,          // input buffer
            0,                      // context requirements
            SECURITY_NETWORK_DREP,  // data representation
            &context,
            NULL,                   // no output buffer
            &attributes,
            &expires
            );
        // The context stays owned by this object (and is freed later) whether or not this leg succeeded.
        _context = ManagedCtxtHandle(context);
        switch (status) {
        case SEC_E_OK:
            Console::WriteLine("Server - Authentication complete");
            break;
        default:
            ThrowStatus(status, "Server - Authentication failed");
        }

        HANDLE token = NULL;
        status = QuerySecurityContextToken(&context, &token);
        CheckStatus(status, "Server - QuerySecurityContextToken");
        try {
            Console::WriteLine(String::Format("token - 0x{0:x8}", (ULONG_PTR)token));
        } finally {
            // The token handle belongs to the caller of QuerySecurityContextToken.
            CloseHandle(token);
        }
    }
};
//
// One connection from NtlmUnitTest, serviced on its own thread. Each round trip
// is Negotiate -> Challenge -> Response -> Result, using SspiNtlmAuthenticator
// for the server side. Rounds repeat until the peer disconnects or an error occurs.
//
// Every message is a TestMessageHeader (TotalLength includes the header) followed
// by TotalLength - sizeof(TestMessageHeader) payload bytes.
//
ref class Client
{
private:
    initonly Socket^ _socket;

    Client(Socket^ socket)
    {
        _socket = socket;
    }

    //
    // Receives exactly 'count' bytes into buffer[0 .. count). Socket::Receive on a
    // stream socket may return fewer bytes than requested, so this loops. Returns
    // the number of bytes received; it is less than 'count' only if the peer closed
    // the connection first.
    //
    int ReceiveAll(array<byte>^ buffer, int count)
    {
        int received = 0;
        while (received < count) {
            int chunk = _socket->Receive(buffer, received, count - received, SocketFlags::None);
            if (chunk == 0)
                break;      // the peer shut down the connection
            received += chunk;
        }
        return received;
    }

    // Reads one framed message, stores its type in 'type' and returns its payload.
    array<Byte>^ ReceiveMessage(TestMessageType% type)
    {
        array<byte>^ headerBuffer = gcnew array<byte>(sizeof(TestMessageHeader));
        int length = ReceiveAll(headerBuffer, (int)sizeof(TestMessageHeader));
        if (length == 0) {
            throw gcnew Exception("Server has closed socket.");
        }
        if (length < (int)sizeof(TestMessageHeader)) {
            throw gcnew Exception("Received short data from server.");
        }

        pin_ptr<byte> headerBuffer_ptr = &headerBuffer[0];
        TestMessageHeader header = *(TestMessageHeader*)headerBuffer_ptr;
        if (header.TotalLength < sizeof(TestMessageHeader)) {
            throw gcnew Exception("Received invalid header from server (length is too short)");
        }
        if (header.TotalLength > 0x10000) {
            throw gcnew Exception("Received excessively large message from server.");
        }

        int bodyLength = (int)(header.TotalLength - sizeof(TestMessageHeader));
        array<byte>^ body = gcnew array<byte>(bodyLength);
        // An empty payload is legitimate; only a shortfall indicates a closed connection.
        length = ReceiveAll(body, bodyLength);
        if (length < bodyLength) {
            if (length == 0)
                throw gcnew Exception("Server has closed socket.");
            throw gcnew Exception("Received short data (payload) from server.");
        }

        type = (TestMessageType)header.MessageType;
        return body;
    }

    // Reads one message and throws unless its type is 'type'.
    array<byte>^ ReceiveExpectedMessage(TestMessageType type)
    {
        TestMessageType actualType;
        array<byte>^ msg = ReceiveMessage(actualType);
        if (actualType != type) {
            Console::WriteLine("Received message, but its type is not the expected type!");
            Console::WriteLine("Received type {0}, wanted type {1}", actualType, type);
            throw gcnew Exception("Invalid message received.");
        }
        return msg;
    }

    // Sends a TestMessageHeader followed by 'payload'.
    void SendMessage(TestMessageType type, array<byte>^ payload)
    {
        TestMessageHeader header;
        header.TotalLength = (uint)(sizeof(TestMessageHeader) + payload->Length);
        header.MessageType = (uint)type;
        array<byte>^ headerBuffer = gcnew array<byte>(sizeof(TestMessageHeader));
        pin_ptr<byte> headerBuffer_pinned = &headerBuffer[0];
        *(TestMessageHeader*)headerBuffer_pinned = header;
        _socket->Send(headerBuffer);
        _socket->Send(payload);
    }

    // Thread body: runs exchanges until an exception ends the session, then closes the socket.
    void ThreadRoutine()
    {
        try {
            for (;;) {
                Console::WriteLine("Waiting for Negotiate");
                array<Byte>^ negotiate = ReceiveExpectedMessage(TestMessageType::Negotiate);
                Console::WriteLine("Received Negotiate:");
                NtlmUtil::DumpMessage(negotiate);

                SspiNtlmAuthenticator authenticator;
                authenticator.Initialize();
                array<byte>^ challenge = authenticator.GetChallenge(negotiate);
                Console::WriteLine("Sending Challenge:");
                NtlmUtil::DumpMessage(challenge);
                SendMessage(TestMessageType::Challenge, challenge);

                Console::WriteLine("Waiting for Response");
                array<byte>^ response = ReceiveExpectedMessage(TestMessageType::Response);
                Console::WriteLine("Received Response:");
                NtlmUtil::DumpMessage(response);

                String^ resulttext;
                bool succeeded;
                try {
                    authenticator.VerifyResponse(response);
                    Console::WriteLine("Authentication succeeded");
                    resulttext = "OK";
                    succeeded = true;
                } catch(Exception^ ex) {
                    Console::WriteLine("Authentication failed");
                    Util::DumpException(ex);
                    resulttext = "FAILED: " + ex->Message;
                    succeeded = false;
                }

                // Result payload: ResultMessage followed by the UTF-16 result text.
                ResultMessage result;
                result.Succeeded = succeeded ? 1 : 0;
                Console::WriteLine("Sending Result: " + resulttext);
                Encoding^ encoding = Encoding::Unicode;
                array<byte>^ response_buffer = gcnew array<byte>(sizeof(ResultMessage) + encoding->GetByteCount(resulttext));
                pin_ptr<byte> response_pinned = &response_buffer[0];
                *(ResultMessage*)response_pinned = result;
                encoding->GetBytes(resulttext, 0, resulttext->Length, response_buffer, sizeof(ResultMessage));
                SendMessage(TestMessageType::Result, response_buffer);
            }
        } catch (Exception^ ex) {
            Console::WriteLine("Exception on client: " + _socket->RemoteEndPoint->ToString());
            Util::DumpException(ex);
        } finally {
            _socket->Close();
        }
        Console::WriteLine("Client thread has terminated.");
    }

public:
    // Services 'socket' on a new thread; the Client object owns and closes the socket.
    static void StartThread(Socket^ socket)
    {
        Client^ client = gcnew Client(socket);
        Thread^ thread = gcnew Thread(gcnew ThreadStart(client, &Client::ThreadRoutine));
        thread->Start();
    }
};
//
// Listener: binds a TCP socket to NtlmUnitTestPort on all IPv4 interfaces and hands
// each accepted connection to a new Client thread. Returns only when an exception
// escapes (for example when the port cannot be bound), after printing it.
//
ref class NtlmWinHost
{
private:
    // Binds 'server' to 'endpoint'. On failure prints a hint, then rethrows.
    static void BindListener(Socket^ server, IPEndPoint^ endpoint)
    {
        try {
            server->Bind(endpoint);
        } catch (Exception^) {
            Console::WriteLine("Failed to bind to listen address: " + endpoint);
            Console::WriteLine("Check to see if any other apps (including an instance of this one) are using the port.");
            throw;
        }
    }

public:
    static void Main(array<String^>^ args)
    {
        try {
            Socket^ server = gcnew Socket(AddressFamily::InterNetwork, SocketType::Stream, ProtocolType::Tcp);
            IPEndPoint^ endpoint = gcnew IPEndPoint(IPAddress::Any, NtlmUnitTestPort);
            BindListener(server, endpoint);
            server->Listen(4);
            while (true) {
                Console::WriteLine("Waiting for clients on address: " + endpoint);
                Socket^ accepted = server->Accept();
                Client::StartThread(accepted);
            }
        } catch (Exception^ error) {
            Util::DumpException(error);
        }
    }
};
int main(array<String^> ^args)
{
NtlmWinHost::Main(args);
return 0;
}
//
// Returns the system-provided description of the status code 'message', followed
// by the code in hex and decimal: "<text> (0x<hex> <decimal>)". When the system has
// no text for the code, returns "(unknown error 0x<hex> <decimal>)".
//
String^ FormatMessageFromSystem(ULONG message)
{
    array<Char>^ buffer = gcnew array<Char>(0x200);
    pin_ptr<Char> buffer_pinned = &buffer[0];
    // No argument list is supplied, so insert sequences (%1, ...) in the system text
    // must be left unexpanded; FORMAT_MESSAGE_IGNORE_INSERTS requests exactly that.
    int length = Util::FormatMessage(
        FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
        NULL,
        message,
        LANG_NEUTRAL,
        buffer_pinned,
        buffer->Length,
        NULL);
    if (length == 0)
        return String::Format("(unknown error 0x{0:x8} {0})", message);

    // Strip trailing whitespace (system message text can end with a line break) so the
    // result embeds cleanly in single-line log and exception text.
    String^ text = (gcnew String(buffer, 0, length))->TrimEnd();
    return String::Format("{0} (0x{1:x8} {1})", text, message);
}
// Util's entry point for status-code text; delegates to the file-scope
// ::FormatMessageFromSystem defined above.
String^ Util::FormatMessageFromSystem(ULONG message)
{
    String^ description = ::FormatMessageFromSystem(message);
    return description;
}
//
// Prints 'length' bytes starting at 'buffer' as a hex dump: 16 bytes per line,
// each shown as " xx", then " : " and the bytes rendered as printable ASCII
// ('.' for anything else).
//
void Util::DumpBuffer(PUCHAR buffer, int length)
{
    StringBuilder line;
    for (int i = 0; i < length; i += 0x10) {
        line.Length = 0;
        for (int j = 0; j < 0x10; j++) {
            if (i + j < length) {
                byte b = buffer[i + j];
                line.Append(" ");
                line.Append((Char)HexDigits[b >> 4]);
                line.Append((Char)HexDigits[b & 0xf]);
            } else {
                // Pad with the width of one " xx" cell so the text column lines up.
                line.Append("   ");
            }
        }
        line.Append(" : ");
        for (int j = 0; j < 0x10 && i + j < length; j++) {
            byte b = buffer[i + j];
            // 0x20..0x7e are printable ASCII; 0x7f (DEL) is a control character.
            if (b >= 0x20 && b < 0x7f) {
                line.Append((Char)b);
            } else {
                line.Append(".");
            }
        }
        Console::WriteLine(line.ToString());
    }
}
//
// Prints buffer[offset .. offset + length) as a hex dump: 16 bytes per line,
// each shown as " xx", then " : " and the bytes rendered as printable ASCII
// ('.' for anything else).
//
void Util::DumpBuffer(array<byte>^ buffer, int offset, int length)
{
    StringBuilder line;
    for (int i = 0; i < length; i += 0x10) {
        line.Length = 0;
        for (int j = 0; j < 0x10; j++) {
            if (i + j < length) {
                byte b = buffer[offset + i + j];
                line.Append(" ");
                line.Append((Char)HexDigits[b >> 4]);
                line.Append((Char)HexDigits[b & 0xf]);
            } else {
                // Pad with the width of one " xx" cell so the text column lines up.
                line.Append("   ");
            }
        }
        line.Append(" : ");
        for (int j = 0; j < 0x10 && i + j < length; j++) {
            byte b = buffer[offset + i + j];
            // 0x20..0x7e are printable ASCII; 0x7f (DEL) is a control character.
            if (b >= 0x20 && b < 0x7f) {
                line.Append((Char)b);
            } else {
                line.Append(".");
            }
        }
        Console::WriteLine(line.ToString());
    }
}
//
// General information about an assembly is controlled through the following
// set of attributes. Change these attribute values to modify the information
// associated with the assembly.
//
[assembly:AssemblyTitleAttribute("NtlmWinHost")];
[assembly:AssemblyDescriptionAttribute("Unit test for NTLM authentication library")];
[assembly:AssemblyConfigurationAttribute("")];
[assembly:AssemblyCompanyAttribute("Microsoft")];
[assembly:AssemblyProductAttribute("NtlmWinHost")];
[assembly:AssemblyCopyrightAttribute("Copyright (c) 2006")];
[assembly:AssemblyTrademarkAttribute("")];
[assembly:AssemblyCultureAttribute("")];
//
// Version information for an assembly consists of the following four values:
//
// Major Version
// Minor Version
// Build Number
// Revision
//
// You can specify all the values, or you can let the Build and Revision numbers
// default by using '*', as shown below:
[assembly:AssemblyVersionAttribute("1.0.*")];
[assembly:ComVisible(false)];
// NOTE(review): public members of NtlmUtil use UInt16/UInt32 (e.g. GetUInt16),
// which are not CLS-compliant types; confirm whether CLSCompliant(true) is intended.
[assembly:CLSCompliantAttribute(true)];
// Declares that the assembly needs permission to call unmanaged code (SSPI, kernel32).
// NOTE(review): SecurityAction::RequestMinimum is obsolete in later .NET Framework
// versions — confirm against the target framework.
[assembly:SecurityPermission(SecurityAction::RequestMinimum, UnmanagedCode = true)];