// Windows
#include <Windows.h>

// CRT
#include <cwchar>
#include <iostream>
#include <string>

// FlowSshCpp
#include "FlowSshCpp.h"

// namespaces
using namespace std;
using namespace FlowSshCpp;


// Process exit codes returned by wmain.
enum ExitCodes
{
	Success      = 0,		// no exit code received from remote
	UsageError   = 1,		// invalid command line (usage text has been printed)
	SessionError = 2,		// session error, including connection errors
	FatalError   = 3,		// local failure: Windows API error or exception
	ExitCodeBase = 1000		// exit code = ExitCodeBase + exit code received from remote
};


// CmdLineParams

struct CmdLineParams
{
	bool			m_bOk;			// true if the command line was parsed without errors
	std::wstring	m_sUserName;
	std::wstring	m_sPassword;
	unsigned int	m_nPort;		// 0 = not specified (client default is used)
	std::wstring	m_sHost;

	std::wstring	m_sPty;			// terminal type for a pty request; empty = no pty request
	std::wstring	m_sCmd;			// command to execute; empty = open a shell

	CmdLineParams(int argc, wchar_t const* argv[]);
};

// Parses a decimal TCP port number in the range 1-65535.
// Returns false (leaving port untouched) if text is empty, contains trailing
// characters, or is out of range; unlike swscanf, an empty value is never accepted.
static bool ParsePortNumber(wchar_t const* text, unsigned int& port)
{
	wchar_t* end = 0;
	unsigned long const value = std::wcstoul(text, &end, 10);
	if (end == text || *end != L'\0' || value == 0 || value > 65535)
		return false;
	port = static_cast<unsigned int>(value);
	return true;
}

// Parses "-name=value" arguments into the members above. With no arguments, an
// unknown argument, or an invalid -port value, prints usage and leaves m_bOk false.
CmdLineParams::CmdLineParams(int argc, wchar_t const* argv[]) : m_bOk(false), m_nPort(0)
{
	bool syntaxError = (argc == 1);
	for (int i = 1; i < argc && !syntaxError; ++i)
	{
		wchar_t const* const arg = argv[i];
		if (std::wcsncmp(arg, L"-user=", 6) == 0)
			m_sUserName = arg + 6;
		else if (std::wcsncmp(arg, L"-pw=", 4) == 0)
			m_sPassword = arg + 4;
		else if (std::wcsncmp(arg, L"-host=", 6) == 0)
			m_sHost = arg + 6;
		else if (std::wcsncmp(arg, L"-port=", 6) == 0)
			syntaxError = !ParsePortNumber(arg + 6, m_nPort);
		else if (std::wcsncmp(arg, L"-pty=", 5) == 0)
			m_sPty = arg + 5;
		else if (std::wcsncmp(arg, L"-cmd=", 5) == 0)
			m_sCmd = arg + 5;
		else
			syntaxError = true;
	}

	if (syntaxError)
	{
		std::wcout << L"FlowSshCpp sample Exec client." << std::endl;
		std::wcout << std::endl;
		std::wcout << L"Parameters:" << std::endl;
		std::wcout << L" -host=...  (default localhost)" << std::endl;
		std::wcout << L" -port=...  (default 22)" << std::endl;
		std::wcout << L" -user=..." << std::endl;
		std::wcout << L" -pw=..." << std::endl;
		std::wcout << L" -pty=..." << std::endl;
		std::wcout << L" -cmd=...   (open shell by default)" << std::endl;
	}
	m_bOk = !syntaxError;
}


// MyErrorHandler

class MyErrorHandler : public ErrorHandler
{
private:
	// This object must be created on the heap: the private destructor rules out
	// automatic (stack) instances.
	~MyErrorHandler() {}

protected:
	// Reports an exception that escaped a user-defined FlowSsh handler (for example
	// MyClient::OnHostKey) and then terminates the process with FatalError.
	virtual void OnExceptionInHandler(bool fatal, wchar_t const* desc)
	{
		wchar_t const* const prefix = fatal ? L"Error [fatal]: " : L"Error: ";
		wchar_t const* const text   = desc ? desc : L"<no error description>";
		wcout << prefix << text << endl;

		// No cleanup is performed here, for simplicity; as a consequence the
		// debugger may report memory leaks.
		exit(FatalError);
	}
};


// ConsoleDisableEchoInput

class ConsoleDisableEchoInput
{
public:
	ConsoleDisableEchoInput()
	{
		m_stdIn = GetStdHandle(STD_INPUT_HANDLE);		

		m_restoreMode = FALSE;
		if (GetFileType(m_stdIn) == FILE_TYPE_CHAR)
			if (m_restoreMode = GetConsoleMode(m_stdIn, &m_originalMode))
				m_restoreMode = SetConsoleMode(m_stdIn, m_originalMode & ~((DWORD) ENABLE_ECHO_INPUT));
	}
	~ConsoleDisableEchoInput() { if (m_restoreMode)	SetConsoleMode(m_stdIn, m_originalMode); }

private:
	HANDLE m_stdIn;
	DWORD  m_originalMode;
	BOOL   m_restoreMode;
};


// MyClient

class MyClient : public Client
{
private: 
	// This object must be created on the heap.
	~MyClient()	{}

public:
	MyClient(volatile unsigned int& exitCode) : m_exitCode(exitCode)	{}	

	void SetClientSettings(CmdLineParams& params)
	{
		if (params.m_sUserName.size()) SetUserName(params.m_sUserName.c_str());
		if (params.m_sPassword.size()) SetPassword(params.m_sPassword.c_str());
		if (params.m_nPort)			   SetPort(params.m_nPort);
		if (params.m_sHost.size())     SetHost(params.m_sHost.c_str());
	}

protected:
	void OnSshVersion(std::wstring const& version)
	{
		wcout << L"Server version: " << version << endl;
	}

	bool OnHostKey(RefPtr<PublicKey> publicKey)
	{
		wcout << L"Received the following host key:" << endl;
		wcout << L"  MD5 Fingerprint: " << publicKey->GetMd5() << endl;
		wcout << L"  Bubble-Babble: " << publicKey->GetBubbleBabble() << endl;
		wcout << L"  SHA-256: " << publicKey->GetSha256() << endl;
		wcout << L"Accept the key (yes/no)? ";

		std::wstring input;
		wcin >> input;

		if (input == L"yes")
			return true;

		// host key rejected
		return false;
	}

	bool OnFurtherAuth(FurtherAuth& furtherAuth);
	bool OnPasswordChange(PasswordChange& passwordChange);

	void OnBanner(std::wstring const& banner)
	{	// Production code should remove volatile characters from banner first.
		wcout << endl << endl << L"User authentication banner:" << endl << banner << endl;
	}
	void OnDisconnect(unsigned int reason, std::wstring const& desc)
	{			
		if (reason != FlowSshC_DisconnectReason_ByClient)
		{
			wcout << endl << endl << L"Client disconnecting: " << endl << desc << endl;
			if (m_exitCode == Success)
				m_exitCode = SessionError;
		}
	}

private:
	volatile unsigned int& m_exitCode;
};

// Asks the user for a password (with console echo disabled) when the server still
// accepts password authentication. Returns true if a non-empty password was set on
// furtherAuth, false otherwise.
bool MyClient::OnFurtherAuth(FurtherAuth& furtherAuth)
{
	ConsoleDisableEchoInput disableEchoInput;
	bool success = false;

	if (furtherAuth.IsPasswordRemaining())
	{	// user must provide a password
		wcout << L"Password: ";

		// Read the whole line so that a password containing spaces is not truncated
		// (and its remainder does not leak into the next prompt). std::ws skips leading
		// whitespace, including a newline left by an earlier "wcin >> ..." read.
		std::wstring password;
		std::getline(wcin >> std::ws, password);

		if (password.length() > 0)
		{
			furtherAuth.SetPassword(password.c_str());
			success = true;
		}
		wcout << endl;	// the user's Enter was not echoed
	}

	return success;
}

// Shows the server's password-change prompt and asks (with console echo disabled)
// for the new password twice. Returns true and sets the new password on
// passwordChange only if it is non-empty and both entries match.
bool MyClient::OnPasswordChange(PasswordChange& passwordChange)
{
	ConsoleDisableEchoInput disableEchoInput;
	bool success = false;

	// Production code should remove volatile characters from prompt first.
	wcout << passwordChange.GetPrompt() << endl;
	wcout << L"New password: ";

	// Read whole lines: with "wcin >> x", a password such as "a b" would be cut to "a"
	// and "b" would be consumed by the repeat prompt. std::ws skips leading whitespace,
	// including a newline left by an earlier "wcin >> ..." read.
	std::wstring newPassword;
	std::getline(wcin >> std::ws, newPassword);

	if (newPassword.length() > 0)
	{
		wcout << endl << L"Repeat password: ";

		std::wstring repeatPassword;
		std::getline(wcin >> std::ws, repeatPassword);

		if (newPassword == repeatPassword)
		{
			passwordChange.SetNewPassword(newPassword.c_str());
			success = true;
		}
		else
			wcout << endl << L"Passwords do not match!";
	}

	wcout << endl;
	return success;
}


// MyClientSessionChannel

// Session channel that records the remote exit status in the shared exit code
// as ExitCodeBase + status code, unless an error code was already recorded.
class MyClientSessionChannel : public ClientSessionChannel
{
private:
	// Instances must live on the heap.
	~MyClientSessionChannel() {}

public:
	// exitCode is shared with wmain and the other handlers; it must outlive this object.
	MyClientSessionChannel(RefPtrConst<Client> const& client, volatile unsigned int& exitCode)
		: ClientSessionChannel(client)
		, m_exitCode(exitCode)
	{
	}

protected:
	// Presumably invoked when the server reports the remote command's exit status
	// (e.g. after the user enters Eof with Ctrl-Z + Enter) -- confirm against FlowSsh docs.
	virtual void OnExitStatus(FlowSshC_ExitStatus const& status)
	{
		if (m_exitCode != Success)
			return;		// keep the code that was recorded first
		m_exitCode = ExitCodeBase + status.m_code;
	}

private:
	volatile unsigned int& m_exitCode;
};


// Channel input thread

class StdInput
{
public:
	StdInput() : m_readThread(0), m_dataSize(0), m_exitCode(Success)
	{
		m_file				= GetStdHandle(STD_INPUT_HANDLE); 
		m_newDataEvent		= CreateEvent(0, false, false, 0);
		m_continueReadEvent = CreateEvent(0, false, false, 0);		

		if (!m_newDataEvent || !m_continueReadEvent)
		{
			wcout << L"CreateEvent() failed with Windows error " << GetLastError() << L"." << endl;
			m_exitCode = FatalError;
			return;
		}
		else if (!(m_readThread = CreateThread(0, 0, &StdInput::ReadThread, this, 0, 0)))
		{
			wcout << L"CreateThread() failed with Windows error " << GetLastError() << L"." << endl;
			m_exitCode = FatalError;
		}
	}

	~StdInput()
	{		
		if (m_readThread)
		{
			TerminateThread(m_readThread, 0);
			WaitForSingleObject(m_readThread, INFINITE);
		}
		if (m_readThread)		 CloseHandle(m_readThread);
		if (m_newDataEvent)		 CloseHandle(m_newDataEvent);
		if (m_continueReadEvent) CloseHandle(m_continueReadEvent);
	}

	HANDLE GetNewDataEvent() const { return m_newDataEvent; }
	void   ContinueRead() { SetEvent(m_continueReadEvent); }
	HANDLE GetReadThreadHandle() const { return m_readThread; }
	unsigned char const* GetDataPtr() const { return m_dataPtr; }
	unsigned int		 GetDataSize() const { return m_dataSize; }
	unsigned int GetExitCode() const { return m_exitCode; }

private:
	static DWORD WINAPI ReadThread(LPVOID thisPtr);

	HANDLE m_file;	
	HANDLE m_newDataEvent;
	HANDLE m_continueReadEvent;
	HANDLE m_readThread;
	unsigned int m_dataSize;
	unsigned char m_dataPtr[4*1024];
	unsigned int m_exitCode;
};

// Thread procedure for StdInput. Loop:
//   1. read up to sizeof(m_dataPtr) bytes from standard input into m_dataPtr;
//   2. store the byte count in m_dataSize (0 signals end of input);
//   3. signal m_newDataEvent;
//   4. wait on m_continueReadEvent before reading again.
// Returns the Windows error code if ReadFile fails for any reason other than end of a pipe.
DWORD WINAPI StdInput::ReadThread(LPVOID thisPtr)
{
	StdInput* p = static_cast<StdInput*>(thisPtr);
	if (!p) return 0;

	while (true)
	{
		DWORD read = 0;
		if (!ReadFile(p->m_file, p->m_dataPtr, sizeof(p->m_dataPtr), &read, 0))
		{
			DWORD const error = GetLastError();
			// When stdin is an anonymous pipe (e.g. "type file | app"), end of input is
			// reported as ERROR_BROKEN_PIPE rather than as a 0-byte read. Treat it as Eof.
			if (error != ERROR_BROKEN_PIPE)
				return error;
			read = 0;
		}

		p->m_dataSize = read;

		SetEvent(p->m_newDataEvent);
		WaitForSingleObject(p->m_continueReadEvent, INFINITE);
	}
}


// Parameters passed from wmain to ChannelInputThread. This must remain an aggregate
// because wmain brace-initializes it.
struct InputParams
{
	volatile unsigned int& m_exitCode;	// shared process exit code (see ExitCodes)
	ClientSessionChannel* m_channel;	// channel that receives data read from stdin
	HANDLE m_thisThread;				// handle of ChannelInputThread itself; closed by the destructor
	HANDLE m_outputThread;				// same as OutputParams::m_thisThread (not owned here)

	~InputParams()
	{
		if (m_thisThread != 0)
			CloseHandle(m_thisThread);
	}
};

// Thread procedure that forwards standard input to the session channel.
// It stops when any of the following happens:
//   - ChannelOutputThread exits;
//   - the user enters Eof;
//   - reading or waiting fails.
// It then sends Eof on the channel.
// param must point to an InputParams that outlives this thread.
DWORD WINAPI ChannelInputThread(LPVOID param)
{
	InputParams* inParam = static_cast<InputParams*>(param);
	if (!inParam) return 0;
	RefPtr<ClientSessionChannel> pChannel(inParam->m_channel);

	try
	{
		RefPtr<ProgressEvent> progress(new ProgressEvent);
		StdInput stdInput;

		if (stdInput.GetExitCode() != Success)
		{
			if (inParam->m_exitCode == Success)
				inParam->m_exitCode = stdInput.GetExitCode();
		}
		else
		{
			// These handles do not change while stdInput lives, so build the wait list once.
			HANDLE const waitObjects[3] = { inParam->m_outputThread, stdInput.GetReadThreadHandle(), stdInput.GetNewDataEvent() };
			while (true)
			{
				DWORD const dw = WaitForMultipleObjects(3, waitObjects, FALSE, INFINITE);
				if (dw == WAIT_OBJECT_0)
					break;	// ChannelOutputThread has finished
				if (dw == WAIT_OBJECT_0 + 1)
				{	// StdInput::ReadThread terminated
					DWORD winExitCode = 0;
					GetExitCodeThread(stdInput.GetReadThreadHandle(), &winExitCode);
					wcout << L"ReadFile() failed with Windows error " << winExitCode << L"." << endl;
					if (inParam->m_exitCode == Success)
						inParam->m_exitCode = FatalError;
					break;
				}
				if (dw != WAIT_OBJECT_0 + 2)
				{	// WAIT_FAILED: falling through would resend stale data in a tight loop.
					wcout << L"WaitForMultipleObjects() failed with Windows error " << GetLastError() << L"." << endl;
					if (inParam->m_exitCode == Success)
						inParam->m_exitCode = FatalError;
					break;
				}

				if (!stdInput.GetDataSize()) // user enters Eof (Ctrl-Z + Enter)?
					break;

				pChannel->Send(Data(stdInput.GetDataPtr(), stdInput.GetDataSize()), false, progress);
				stdInput.ContinueRead();
				progress->WaitDone();
			}
		}
	}
	catch (Exception const& e)
	{
		wcout << e.What() << endl;
		if (inParam->m_exitCode == Success)
			inParam->m_exitCode = FatalError;
	}

	try	{ pChannel->Send(Data(), true, NULL); }	// send Eof (implicitly terminates ChannelOutputThread)
	catch (Exception const&) {}
	return 0;
}


// Channel output thread

// Parameters passed from wmain to ChannelOutputThread. This must remain an aggregate
// because wmain brace-initializes it.
struct OutputParams
{
	volatile unsigned int& m_exitCode;	// shared process exit code (see ExitCodes)
	ClientSessionChannel* m_channel;	// channel whose incoming data is written to stdout/stderr
	HANDLE m_thisThread;				// handle of ChannelOutputThread itself; closed by the destructor

	~OutputParams()
	{
		if (m_thisThread != 0)
			CloseHandle(m_thisThread);
	}
};

// Thread procedure that copies data received on the session channel to the local
// stdout, or to stderr for extended (stderr) data.
// It stops when any of the following happens:
//   - the channel is closed;
//   - Eof is received;
//   - a WriteFile call fails.
// param must point to an OutputParams that outlives this thread.
DWORD WINAPI ChannelOutputThread(LPVOID param)
{
	OutputParams* outParam = static_cast<OutputParams*>(param);
	if (!outParam) return 0;

	try
	{
		HANDLE stdErr = GetStdHandle(STD_ERROR_HANDLE);
		HANDLE stdOut = GetStdHandle(STD_OUTPUT_HANDLE);

		bool writeFileError = false;
		RefPtr<ClientSessionChannel> pChannel(outParam->m_channel);
		RefPtr<ReceiveEvent> receiver(new ReceiveEvent);

		do
		{
			pChannel->Receive(receiver);
			receiver->WaitDone();

			if (receiver->Success())	// true until ClientSessionChannel is or gets closed
			{
				HANDLE const target = receiver->StdErr() ? stdErr : stdOut;
				unsigned char const* dataPtr = receiver->GetDataPtr();
				DWORD dataSize = receiver->GetDataSize();

				// WriteFile may write fewer bytes than requested; loop until all are written.
				while (dataSize && !writeFileError)
				{
					DWORD written = 0;
					if (WriteFile(target, dataPtr, dataSize, &written, 0))
					{
						dataPtr  += written;
						dataSize -= written;
					}
					else
					{
						wcout << L"WriteFile() failed with Windows error " << GetLastError() << L"." << endl;
						if (outParam->m_exitCode == Success)
							outParam->m_exitCode = FatalError;
						writeFileError = true;
					}
				}
			}
		} while (!writeFileError && receiver->Success() && !receiver->Eof());
	}
	catch (Exception const& e)
	{
		wcout << e.What() << endl;
		if (outParam->m_exitCode == Success)
			outParam->m_exitCode = FatalError;
	}

	return 0;
}


// wmain

int wmain(int argc, wchar_t const* argv[])
{
	CmdLineParams params(argc, argv);
	if (!params.m_bOk)
		return UsageError;

	volatile unsigned int exitCode = Success;
	try
	{
		// Initialize FlowSsh and register an ErrorHandler for uncaught exceptions in user-defined handlers.
		// Example: If there is an uncaught exception in MyClient::OnHostKey, 
		//          then this is reported in MyErrorHandler::OnExceptionInHandler.
		Initializer init(new MyErrorHandler());

		// For use in deployed applications where Bitvise SSH Client might not be installed, 
		// provide an activation code using SetActCode to ensure that FlowSsh does not  
		// display an evaluation dialog. On computers with Bitvise SSH Client, use of 
		// FlowSsh is permitted under the same terms as the Client; in this case, there 
		// will be no evaluation dialog, and SetActCode does not have to be called.
		//
		//init.SetActCode(L"Enter Your Activation Code Here");

		// create client

		RefPtr<MyClient> client(new MyClient(exitCode));	
		// CHANGE APPLICATION NAME IN PRODUCTION CODE!
		client->SetAppName(L"FlowSshCpp_Exec 1.0");
		client->SetClientSettings(params);

		// connect client

		RefPtr<ProgressEvent> progress(new ProgressEvent);
		client->Connect(progress);
		progress->WaitDone();
		
		if (!progress->Success())
		{	// Alternatively we could derive a class from ProgressEvent,  
			// override ProgressEvent::OnError, and do the logging there.
			std::wstring auxInfo;
			if (progress->GetAuxInfo().size()) auxInfo = progress->GetAuxInfo();
			else auxInfo = L"(no additional info)";

			switch(progress->GetTaskSpecificStep())
			{
			case FlowSshC_ConnectStep_ConnectToProxy:		wcout << L"Connecting to proxy server failed: " << auxInfo << endl; break;
			case FlowSshC_ConnectStep_ConnectToSshServer:	wcout << L"Connecting to SSH server failed: " << auxInfo << endl; break;
			case FlowSshC_ConnectStep_SshVersionString:		wcout << L"SSH version string failed: " << auxInfo << endl;	break;
			case FlowSshC_ConnectStep_SshKeyExchange:		wcout << L"SSH key exchange failed: " << auxInfo << endl; break;
			case FlowSshC_ConnectStep_SshUserAuth:			wcout << L"SSH authentication failed: " << auxInfo << endl;	break;
			default:										wcout << L"Connecting failed at unknown step: " << auxInfo << endl; break;
			}
			return SessionError;
		}

		// open session channel

		RefPtr<MyClientSessionChannel> channel(new MyClientSessionChannel(client, exitCode));			
		channel->OpenRequest(progress);
		progress->WaitDone();

		if (!progress->Success())
		{
			if (progress->GetAuxInfo().size())
				wcout << L"Opening session channel failed: " << progress->GetAuxInfo() << endl;
			else
				wcout << L"Opening session channel failed: (no additional info)" << endl;
			return SessionError;
		}

		// Create ChannelOutputThread (thread for reading channel's incoming data).
		// Note that data can arrive at any point after channel is opened.

		OutputParams outputParams = { exitCode, channel.Get(), 0 };
		outputParams.m_thisThread = CreateThread(0, 0, ChannelOutputThread, &outputParams, 0, 0);

		if (!outputParams.m_thisThread)
		{
			wcout << L"CreateThread() failed with Windows error " << GetLastError() << L"." << endl;
			return FatalError;
		}

		// send requests (pty, exec, shell)

		if (params.m_sPty.length())
		{
			channel->PtyRequest(params.m_sPty.c_str(), 80, 25, progress);
			progress->WaitDone();
			if (!progress->Success())
			{				
				if (progress->GetAuxInfo().size())
					wcout << L"Pty request failed: " << progress->GetAuxInfo() << endl;
				else
					wcout << L"Pty request failed: (no additional info)" << endl;
				return SessionError;
			}
		}

		if (params.m_sCmd.length())
		{				
			channel->ExecRequest(params.m_sCmd.c_str(), progress);
			progress->WaitDone();
			if (!progress->Success())
			{				
				if (progress->GetAuxInfo().size())
					wcout << L"Exec request failed: " << progress->GetAuxInfo() << endl;
				else
					wcout << L"Exec request failed: (no additional info)" << endl;
				return SessionError;
			}
		}
		else
		{
			channel->ShellRequest(progress);
			progress->WaitDone();
			if (!progress->Success())
			{				
				if (progress->GetAuxInfo().size())
					wcout << L"Shell request failed: " << progress->GetAuxInfo() << endl;
				else
					wcout << L"Shell request failed: (no additional info)" << endl;
				return SessionError;
			}
		}
				
		// create ChannelInputThread

		InputParams inputParams = { exitCode, channel.Get(), 0, outputParams.m_thisThread };		
		inputParams.m_thisThread = CreateThread(0, 0, ChannelInputThread, &inputParams, 0, 0);

		if (!inputParams.m_thisThread)
		{
			wcout << L"CreateThread() failed with Windows error " << GetLastError() << L"." << endl;
			
			channel->Close(NULL);	// close channel (implicitly terminates ChannelOutputThread)
			WaitForSingleObject(outputParams.m_thisThread, INFINITE);

			return FatalError;
		}

		// Output thread exits if channel gets closed or if Eof is received.
		// Input thread exits if output thread has finished or if user enters Eof (i.e. Ctrl + Z).
		HANDLE waitThreads[2] = { inputParams.m_thisThread, outputParams.m_thisThread };
		WaitForMultipleObjects(2, waitThreads, true, INFINITE);

		client->Disconnect(progress);
		progress->WaitDone();
	}
	catch (Exception const& e)
	{
		wcout << e.What() << endl;
		return FatalError;
	}
	return (int)exitCode;
}