| 
#include <u.h>
#include <libc.h>
#include <oventi.h>
#include "session.h"
static char EAuthState[] = "bad authentication state";
static char ENotServer[] = "not a server session";
static char EVersion[] = "incorrect version number";
static char EProtocolBotch[] = "venti protocol botch";
VtSession *
vtServerAlloc(VtServerVtbl *vtbl)
{
	VtSession *z = vtAlloc();
	z->vtbl = vtMemAlloc(sizeof(VtServerVtbl));
	*z->vtbl = *vtbl;
	return z;
}
static int
srvHello(VtSession *z, char *version, char *uid, int , uchar *, int , uchar *, int )
{
	vtLock(z->lk);
	if(z->auth.state != VtAuthHello) {
		vtSetError(EAuthState);
		goto Err;
	}
	if(strcmp(version, vtGetVersion(z)) != 0) {
		vtSetError(EVersion);
		goto Err;
	}
	vtMemFree(z->uid);
	z->uid = vtStrDup(uid);
	z->auth.state = VtAuthOK;
	vtUnlock(z->lk);
	return 1;
Err:
	z->auth.state = VtAuthFailed;
	vtUnlock(z->lk);
	return 0;
}
static int
dispatchHello(VtSession *z, Packet **pkt)
{
	char *version, *uid;
	uchar *crypto, *codec;
	uchar buf[10];
	int ncrypto, ncodec, cryptoStrength;
	int ret;
	Packet *p;
	p = *pkt;
	version = nil;	
	uid = nil;
	crypto = nil;
	codec = nil;
	ret = 0;
	if(!vtGetString(p, &version))
		goto Err;
	if(!vtGetString(p, &uid))
		goto Err;
	if(!packetConsume(p, buf, 2))
		goto Err;
	cryptoStrength = buf[0];
	ncrypto = buf[1];
	crypto = vtMemAlloc(ncrypto);
	if(!packetConsume(p, crypto, ncrypto))
		goto Err;
	if(!packetConsume(p, buf, 1))
		goto Err;
	ncodec = buf[0];
	codec = vtMemAlloc(ncodec);
	if(!packetConsume(p, codec, ncodec))
		goto Err;
	if(packetSize(p) != 0) {
		vtSetError(EProtocolBotch);
		goto Err;
	}
	if(!srvHello(z, version, uid, cryptoStrength, crypto, ncrypto, codec, ncodec)) {
		packetFree(p);
		*pkt = nil;
	} else {
		if(!vtAddString(p, vtGetSid(z)))
			goto Err;
		buf[0] = vtGetCrypto(z);
		buf[1] = vtGetCodec(z);
		packetAppend(p, buf, 2);
	}
	ret = 1;
Err:
	vtMemFree(version);
	vtMemFree(uid);
	vtMemFree(crypto);
	vtMemFree(codec);
	return ret;
}
static int
dispatchRead(VtSession *z, Packet **pkt)
{
	Packet *p;
	int type, n;
	uchar score[VtScoreSize], buf[4];
	p = *pkt;
	if(!packetConsume(p, score, VtScoreSize))
		return 0;
	if(!packetConsume(p, buf, 4))
		return 0;
	type = buf[0];
	n = (buf[2]<<8) | buf[3];
	if(packetSize(p) != 0) {
		vtSetError(EProtocolBotch);
		return 0;
	}
	packetFree(p);
	*pkt = (*z->vtbl->read)(z, score, type, n);
	return 1;
}
static int
dispatchWrite(VtSession *z, Packet **pkt)
{
	Packet *p;
	int type;
	uchar score[VtScoreSize], buf[4];
	p = *pkt;
	if(!packetConsume(p, buf, 4))
		return 0;
	type = buf[0];
	if(!(z->vtbl->write)(z, score, type, p)) {
		*pkt = 0;
	} else {
		*pkt = packetAlloc();
		packetAppend(*pkt, score, VtScoreSize);
	}
	return 1;
}
static int
dispatchSync(VtSession *z, Packet **pkt)
{
	(z->vtbl->sync)(z);
	if(packetSize(*pkt) != 0) {
		vtSetError(EProtocolBotch);
		return 0;
	}
	return 1;
}
int
vtExport(VtSession *z)
{
	Packet *p;
	uchar buf[10], *hdr;
	int op, tid, clean;
	if(z->vtbl == nil) {
		vtSetError(ENotServer);
		return 0;
	}
	/* fork off slave */
	switch(rfork(RFNOWAIT|RFMEM|RFPROC)){
	case -1:
		vtOSError();
		return 0;
	case 0:
		break;
	default:
		return 1;
	}
	
	p = nil;
	clean = 0;
	vtAttach();
	if(!vtConnect(z, nil))
		goto Exit;
	vtDebug(z, "server connected!\n");
if(0)	vtSetDebug(z, 1);
	for(;;) {
		p = vtRecvPacket(z);
		if(p == nil) {
			break;
		}
		vtDebug(z, "server recv: ");
		vtDebugMesg(z, p, "\n");
		if(!packetConsume(p, buf, 2)) {
			vtSetError(EProtocolBotch);
			break;
		}
		op = buf[0];
		tid = buf[1];
		switch(op) {
		default:
			vtSetError(EProtocolBotch);
			goto Exit;
		case VtQPing:
			break;
		case VtQGoodbye:
			clean = 1;
			goto Exit;
		case VtQHello:
			if(!dispatchHello(z, &p))
				goto Exit;
			break;
		case VtQRead:
			if(!dispatchRead(z, &p))
				goto Exit;
			break;
		case VtQWrite:
			if(!dispatchWrite(z, &p))
				goto Exit;
			break;
		case VtQSync:
			if(!dispatchSync(z, &p))
				goto Exit;
			break;
		}
		if(p != nil) {
			hdr = packetHeader(p, 2);
			hdr[0] = op+1;
			hdr[1] = tid;
		} else {
			p = packetAlloc();
			hdr = packetHeader(p, 2);
			hdr[0] = VtRError;
			hdr[1] = tid;
			if(!vtAddString(p, vtGetError()))
				goto Exit;
		}
		vtDebug(z, "server send: ");
		vtDebugMesg(z, p, "\n");
		if(!vtSendPacket(z, p)) {
			p = nil;
			goto Exit;
		}
	}
Exit:
	if(p != nil)
		packetFree(p);
	if(z->vtbl->closing)
		z->vtbl->closing(z, clean);
	vtClose(z);
	vtFree(z);
	vtDetach();
	exits(0);
	return 0;	/* never gets here */
}
 |