// Windows
#include <Windows.h>

// CRT
#include <iostream>

// FlowSshCpp
#include "FlowSshCpp.h"

// name spaces
using namespace std;
using namespace FlowSshCpp;


// MyErrorHandler

class MyErrorHandler : public ErrorHandler
{
protected:
	// Reports an exception that escaped one of the user-defined FlowSsh
	// handlers (for example MyClient::OnHostKey) and then terminates the
	// process with exit code 3. `desc` may be null; a placeholder is printed then.
	virtual void OnExceptionInHandler(bool fatal, wchar_t const* desc)
	{
		wchar_t const* const prefix = fatal ? L"Error [fatal] in handler: " : L"Error in handler: ";
		wchar_t const* const text = (desc != 0) ? desc : L"<no error description>";

		wcout << prefix;
		wcout << text << endl;

		// Deliberately exits without any cleanup to keep the sample simple;
		// a debugger may therefore report memory leaks.
		exit(3);
	}

private:
	// Private destructor: instances must be created on the heap.
	~MyErrorHandler() {}
};


// MyClient

class MyClient : public Client
{
private:
	// This object must be created on the heap.
	~MyClient() {}

	// Expected SHA-256 fingerprint of the server host key, compared verbatim
	// against PublicKey::GetSha256().
	std::wstring m_hkfp;

public:
	// hkfp: expected host key SHA-256 fingerprint, in the same textual form
	//       that PublicKey::GetSha256() returns (taken from the command line).
	// explicit: a fingerprint string should never silently convert into a client.
	explicit MyClient(std::wstring const& hkfp) : m_hkfp(hkfp) {}

protected:
	// Accepts the server host key only if its SHA-256 fingerprint equals the
	// expected one. Returning false presumably makes FlowSsh reject the key
	// and fail the connection (TODO confirm against FlowSsh docs).
	bool OnHostKey(RefPtr<PublicKey> publicKey)
	{
		return publicKey->GetSha256() == m_hkfp;
	}
};


// XferThread handling 

// Direction of a transfer command.
enum XferVerb
{
	XferVerb_Put = 0,	// local -> remote (ClientSftpChannel::Upload)
	XferVerb_Get = 1	// remote -> local (ClientSftpChannel::Download)
};

// One put/get command parsed from the command line. A worker thread runs it
// repeatedly.
struct XferCommand : public RefCountable
{
private:
	// This object must be created on the heap.
	~XferCommand() {}

public:
	// Transfer direction. It has a default initializer so the member never
	// holds an indeterminate value after `new XferCommand;` (wmain always
	// assigns it explicitly afterwards).
	XferVerb m_verb = XferVerb_Put;
	// Local file: upload source for put, download destination for get.
	std::wstring m_localPath;
	// Remote file: upload destination for put, download source for get.
	std::wstring m_remotePath;
};


class Coordination : public RefCountable
{
private:
	// This object must be created on the heap.
	virtual ~Coordination() { CloseHandle(m_stopEvent); }

public:
	Coordination() : m_stopEvent(CreateEvent(0, true, false, 0))
	{ if (!m_stopEvent) throw CreateEventException(); }
	
	bool Wait(DWORD timeout) const { return WaitForSingleObject(m_stopEvent, timeout) == WAIT_OBJECT_0; }
	bool IsStopped() const { return Wait(0); }

	std::wstring GetStopReason() const
	{
		std::wstring ret;
		{
			CsLocker locker(m_cs);
			ret = m_stopReason;
		}
		return ret;
	}
	
	void Stop(std::wstring reason)
	{
		CsLocker locker(m_cs);
		{
			if (m_stopReason.length() == 0)
				m_stopReason = reason;

			SetEvent(m_stopEvent);
		}
	}

private:
	mutable CriticalSection m_cs;
	HANDLE m_stopEvent;
	std::wstring m_stopReason;
};


struct XferThreadParams : public RefCountable
{
protected:
	// This object must be created on the heap.
	~XferThreadParams()
	{
		if (!!m_thisThread && m_thisThread != INVALID_HANDLE_VALUE)
			CloseHandle(m_thisThread);
	}

public:
	XferThreadParams()
	{
		m_threadNr = m_port = 0;
		m_thisThread = INVALID_HANDLE_VALUE;
	}		
	
	RefPtr<Coordination> m_coord;
	unsigned int m_threadNr;
	std::wstring m_host;
	unsigned int m_port;
	std::wstring m_hkfp;
	std::wstring m_user;
	std::wstring m_pass;
	RefPtr<XferCommand> m_cmd;

	HANDLE m_thisThread;
};


// Thrown by XferThread when a connect, SFTP channel open, or transfer fails.
// `msg` is the FlowSsh-provided error description.
class XferErr : public Exception
{
public:
	// explicit: a bare string should not silently convert into an exception.
	explicit XferErr(wchar_t const* msg) : Exception(msg) {}
};


// Worker thread entry point. `param` is the XferThreadParams* that wmain
// passes to CreateThread. wmain also keeps a RefPtr to the same object in its
// `threads` vector until every worker has been joined.
//
// Each iteration:
//   1. connects,
//   2. opens an SFTP channel,
//   3. runs this thread's single put/get command,
//   4. disconnects.
// The loop repeats until the shared Coordination is stopped. That is checked
// only between iterations, so a running transfer is allowed to finish. Any
// failure stops the Coordination with a descriptive reason, which ends the
// whole run. No exception is allowed to escape this function.
DWORD WINAPI XferThread(LPVOID param)
{
	RefPtr<XferThreadParams> xt = static_cast<XferThreadParams*>(param);
	if (!xt.Get()) return 0;
	
	try
	{
		wcout << L"Thread " << xt->m_threadNr << L" started" << endl;

		while (true)
		{
			RefPtr<MyClient> client = new MyClient(xt->m_hkfp);
			{
				client->SetAppName(L"FlowSshCpp_SftpStress");
				client->SetHost(xt->m_host.c_str());
				client->SetPort(xt->m_port);
				client->SetUserName(xt->m_user.c_str());
				client->SetPassword(xt->m_pass.c_str());

				wcout << L"Thread " << xt->m_threadNr << L" connecting" << endl;

				// Client.Connect
				{
					RefPtr<ProgressEvent> progress(new ProgressEvent);
					client->Connect(progress);
					progress->WaitDone();
					if (!progress->Success())
						throw XferErr(progress->DescribeConnectError().c_str());
				}

				RefPtr<ClientSftpChannel> sftp = new ClientSftpChannel(client);
				{
					// SftpChannel.Open
					{
						RefPtr<ProgressEvent> progress = new ProgressEvent();
						sftp->Open(progress);
						progress->WaitDone();
						if (!progress->Success())
						{
							// Throw by value so `catch (Exception const&)` below catches it.
							// A `throw new ...` would throw a pointer that no handler here matches.
							throw XferErr(progress->DescribeSftpChannelOpenError().c_str());
						}
					}

					RefPtr<TransferEvent> th = new TransferEvent();

					if (xt->m_cmd->m_verb == XferVerb_Put)
					{
						wcout << L"Thread " << xt->m_threadNr << L" putting " << xt->m_cmd->m_localPath << L" -> " << xt->m_cmd->m_remotePath << endl;
						sftp->Upload(xt->m_cmd->m_localPath.c_str(), xt->m_cmd->m_remotePath.c_str(), FlowSshC_TransferFlags_Binary | FlowSshC_TransferFlags_Overwrite, th);
					}
					else
					{
						wcout << L"Thread " << xt->m_threadNr << L" getting " << xt->m_cmd->m_remotePath << L" -> " << xt->m_cmd->m_localPath << endl;
						sftp->Download(xt->m_cmd->m_remotePath.c_str(), xt->m_cmd->m_localPath.c_str(), FlowSshC_TransferFlags_Binary | FlowSshC_TransferFlags_Overwrite, th);
					}

					th->WaitDone();
					if (!th->Success())
						throw XferErr(th->GetError().Describe().c_str());

					wcout << L"Thread " << xt->m_threadNr << L" transferred " << th->GetTransferStat().m_bytesTransferred << L" bytes" << endl;
				}

				wcout << L"Thread " << xt->m_threadNr << L" disconnecting" << endl;

				// Client.Disconnect
				{
					RefPtr<ProgressEvent> progress = new ProgressEvent();
					client->Disconnect(progress);
					progress->WaitDone();
				}
			}

			if (xt->m_coord->IsStopped())
				break;
		}

		xt->m_coord->Stop(L"XferThread " + std::to_wstring((unsigned __int64) xt->m_threadNr) + L" exited without exception");
	}
	catch (Exception const& e)
	{		
		xt->m_coord->Stop(L"XferThread " + std::to_wstring((unsigned __int64) xt->m_threadNr) + L" exited with exception:\r\n" + e.What());
	}
	catch (...)
	{
		// Anything else (e.g. std::bad_alloc) must not escape the thread entry
		// point; that would typically bring down the whole process. Record it
		// as the stop reason instead, so wmain shuts down cleanly.
		xt->m_coord->Stop(L"XferThread " + std::to_wstring((unsigned __int64) xt->m_threadNr) + L" exited with unknown exception");
	}
	
	return 0;
}


// wmain 

int wmain(int argc, wchar_t const* argv[])
{	
	// Initialize FlowSsh and register an ErrorHandler for uncaught exceptions in user-defined handlers.
	// Example: If there is an uncaught exception in MyClient::OnHostKey, 
	//          then this is reported in MyErrorHandler::OnExceptionInHandler.
	Initializer init(new MyErrorHandler());

	if (argc < 7)
	{
		wcout << endl;
		wcout << L"Usage: <host> <port> <hostKeySha256> <user> <pass> <thread1command> [; <thread2command> ... ]" << endl;
		wcout << L"A thread command can be either a put or get:" << endl;
		wcout << L"  put \"localFilePath\" \"remoteFilePath\"" << endl;
		wcout << L"  get \"remoteFilePath\" \"localFilePath\"" << endl;
		wcout << L"Thread commands should not conflict, or have inter-dependencies." << endl;
		wcout << L"Each command will be executed repeatedly and simultaneously on a separate thread." << endl;
		return 2;
	}

	std::wstring host = argv[1];
	unsigned int port = 0; swscanf_s(argv[2], L"%u", &port);
	std::wstring hkfp = argv[3];
	std::wstring user = argv[4];
	std::wstring pass = argv[5];
	
	std::vector<RefPtr<XferCommand>> commands;
	for (int i=6; ; )
	{
		XferVerb verb;
		if (std::wstring(argv[i]) == L"put") verb = XferVerb_Put;
		else if (std::wstring(argv[i]) == L"get") verb = XferVerb_Get;
		else { wcout << L"Unrecognized command verb: " << argv[i] << endl; return 2; }

		if (++i == argc) { wcout << L"Missing command path 1" << endl; return 2; }
		std::wstring path1 = argv[i];

		if (++i == argc) { wcout << L"Missing command path 2" << endl; return 2; }
		std::wstring path2 = argv[i];

		RefPtr<XferCommand> cmd = new XferCommand;
		cmd->m_verb       = verb;
		cmd->m_localPath  = (verb == XferVerb_Put ? path1 : path2);
		cmd->m_remotePath = (verb == XferVerb_Put ? path2 : path1);
		commands.push_back(cmd);

		if (++i == argc) break;
		if (std::wstring(argv[i]) != L";") { wcout << L"Expecting ; between commands" << endl; return 2; }
		if (++i == argc)				   { wcout << L"Expecting command after ; separator" << endl; return 2; }
	}

	wcout << L"Starting threads. Press Esc to exit" << endl;

	RefPtr<Coordination> coord = new Coordination();
	std::vector<RefPtr<XferThreadParams>> threads;
	unsigned int threadNr = 1;			

	for (size_t c = 0; c < commands.size(); ++c)
	{
		RefPtr<XferThreadParams> xt = new XferThreadParams();

		xt->m_coord = coord;
		xt->m_threadNr = threadNr++;
		xt->m_host = host;
		xt->m_port = port;
		xt->m_hkfp = hkfp;
		xt->m_user = user;
		xt->m_pass = pass;
		xt->m_cmd = commands[c];

		xt->m_thisThread = CreateThread(0, 0, XferThread, xt.Get(), 0, 0);
		threads.push_back(xt);
	}

	// Wait Esc key or child thread exit

	while (true)
	{
		if(GetAsyncKeyState(VK_ESCAPE))
			coord->Stop(L"Esc key pressed");

		if (coord->Wait(200))
			break;
	}

	wcout << L"Stopping..." << endl;

	for (size_t c = 0; c < threads.size(); ++c)
		WaitForSingleObject(threads[c]->m_thisThread, INFINITE);
			
	if (coord->GetStopReason().length() != 0)
		wcout << L"Stopped: " << coord->GetStopReason() << endl;
		
	return 0;
}