/***********************************************************
 * Threaded chat server                                    *
 * solution to programming exercise                        *
 * Verteilte Systeme II Kapitel 6                          *
 ***********************************************************
 * 1998 by Frank Kargl (frank.kargl@rz.uni-ulm.de)         *
 ***********************************************************
 * Usage: chat <port> (default port = 1099)                *
 ***********************************************************/

/* debug version ? -- the #undef keeps every "#ifdef DEBUG" trace block
   below compiled out; change it to "#define DEBUG" to enable tracing */
#undef DEBUG

/* required by solaris for threaded builds; defined here so that it
   precedes all of the system #includes that follow */
#define _REENTRANT

#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <thread.h>
#include <synch.h>
#include <signal.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>

/* TCP port used when none is given on the command line */
#define DEFPORT (1099)

/* capacity of one message line in bytes, terminating NUL included */
#define MAXLINE (1024)

/* number of message slots in the shared buffer */
#define MAXMSG  (8)

/* one slot of the shared message buffer */
struct msg_t {
    /* id of the message currently held in this slot; incremented on
       every store and compared with "<" by readers to detect new
       messages, so an overflow of this counter would break detection */
    long int id;

    /* how many clients this message has been sent to so far;
       -1 marks the slot as free */
    int sent;

    /* reader/writer lock guarding the text in "line" */
    rwlock_t linelock;

    /* NUL-terminated message text */
    char line[MAXLINE];
};

/* start-up argument for each client thread; main() allocates it with
   malloc and the receiving thread frees it */
struct threadarg_t {
    /* connected client socket the thread works on */
    int socket;
    /* write_lines: id of the partner read_lines thread;
       read_lines: unused (main sets it to 0) */
    thread_t tid;
};

/* serializes searching for, and claiming, a free slot in "buffer" */
mutex_t bufferlock;

/* counts free slots: write_lines waits on it before claiming a slot,
   read_lines posts it when it releases one */
sema_t buffer_free;

/* the shared message slots */
struct msg_t buffer[MAXMSG];

/* number of connected clients.
   NOTE(review): shared between threads; accesses should be serialized
   (bufferlock is the natural candidate) */
int num_clients = 0;

/***********************************************************
 *                                                         *
 * Function:                                               *
 * void* getstream(int sock, int inout)                    *
 * convert a socket to a line buffered stdio stream        *
 * Parameters:                                             *
 *  sock  - socket (file descriptor) to convert            *
 *  inout - 0 for an input stream, 1 for an output stream  *
 * Return:                                                 *
 *  the FILE* (as void*), or NULL if inout is neither 0    *
 *  nor 1 or fdopen() fails                                *
 *                                                         *
 ***********************************************************/

void* getstream(int sock, int inout) {

    FILE *fp;				/* stream pointer */
    const char *mode;			/* fdopen mode for the direction */

    /* pick the stream direction; anything but 0/1 is rejected
       (the old "inout = 1" assignment made this branch unreachable) */
    switch (inout) {
    case 0:
	mode = "r";
	break;
    case 1:
	mode = "w";
	break;
    default:
	return NULL;
    }

    fp = fdopen(sock, mode);
    if (fp == NULL) {
	/* never hand a NULL stream to setvbuf */
	return NULL;
    }

    /* set line buffered; best effort - on failure the stream simply
       keeps its default buffering */
    (void)setvbuf(fp, NULL, _IOLBF, BUFSIZ);

    return fp;

}

/***********************************************************
 *                                                         *
 * Function:                                               *
 * void* write_lines(int sock)                             *
 * read line from network and write it to buffer           *
 * Parameters:                                             * 
 *  sock - socket to read from                             *
 * Return:                                                 *
 *  always NULL                                            *
 *                                                         *
 ***********************************************************/

void* write_lines(void* arg) {
    
    struct threadarg_t* targ;		/* thread argument for conversion */

    char line[MAXLINE];			/* line buffer */
    FILE *fpin;				/* file pointer to read from */
    int i;
    int sock;				/* socket to use */
    int tid;				/* thread id of corresponding thread */

#ifdef DEBUG
    fprintf(stderr, "Thread %d starting\n", thr_self());
    fflush(stderr); sleep(1);
#endif

    /* convert arguments */
    targ = (struct threadarg_t*)arg;
    sock = targ->socket;
    tid  = targ->tid;
    free(targ);

#ifdef DEBUG
    fprintf(stderr, "%d: socket = %d, tid = %d\n",thr_self(),sock,tid);
#endif

    /* convert socket to stream */
    fpin = getstream(sock, 0);
    
    /* read lines */
    while( fgets(line, MAXLINE, fpin) != NULL ) {
	
#ifdef DEBUG
	fprintf(stderr, "%d: read line #%s#\n",thr_self(),line);
#endif
	
	/* wait for free buffer slot */
	sema_wait(&buffer_free);
	
	/* lock the whole buffer so we can allocate an entry */
	mutex_lock(&bufferlock);

	/* search a free buffer */
	i=0;
	while (i<=MAXMSG) {
	    if (i == MAXMSG) {
		fprintf(stderr, "Error: no free buffer\n");
		exit(1);
	    }
	    if (buffer[i].sent == -1) {
		/* got it */
		break;
	    }
	    i++;
	}
	
	/* set buffer to allocted */
	buffer[i].sent = 0;

	/* obtain writelock on buffer line */
	rw_wrlock(&buffer[i].linelock);

	/* free the whole buffer */
	mutex_unlock(&bufferlock);

	/* copy buffer */

#ifdef DEBUG
	fprintf(stderr, "%d: writing to buffer %d\n",thr_self(),i);
#endif

	strcpy(buffer[i].line, line);
	buffer[i].id++;
	
	/* free buffer line lock */
	rw_unlock(&buffer[i].linelock);
    }
    
    /* quit this client */
    num_clients--;

    /* kill corresponding read thread */
    thr_kill(tid, SIGINT);

    return NULL;
    
}

/***********************************************************
 *                                                         *
 * Function:                                               *
 * void* read_lines(int sock)                              *
 * read lines from buffer and write them to network        *
 * Parameters:                                             * 
 *  sock - socket to write to                              *
 * Return:                                                 *
 *  always NULL                                            *
 *                                                         *
 ***********************************************************/

void* read_lines(void* arg) {

    struct threadarg_t* targ;		/* thread argument for conversion */
    int msgvec[MAXMSG];			/* here we store the message
					   vector that contains the
					   message ids as far as they have
					   been delivered */
    int sock;				/* socket to use */
    int i;

#ifdef DEBUG
    fprintf(stderr, "Thread %d starting\n", thr_self());
    fflush(stderr); sleep(1);
#endif

    /* convert arguments */
    targ = (struct threadarg_t*)arg;
    sock = targ->socket;
    free(targ);

#ifdef DEBUG
    fprintf(stderr, "%d: socket = %d\n",thr_self(),sock);
#endif

    /* set message vector to current values */
    for (i=0; i<MAXMSG; i++) {
	msgvec[i] = buffer[i].id;
    }

    /* loop forever */
    while(1) {

	/* check if any new messages appeared */
	for (i=0; i<MAXMSG; i++) {
	    if ( msgvec[i] < buffer[i].id ) {

#ifdef DEBUG
		fprintf(stderr, "%d: new message in buffer %d\n",
			thr_self(),i);
#endif

		/* a new msg has arrived in this buffer line */
		msgvec[i] = buffer[i].id;
		/* obtain readlock on buffer line */
		rw_rdlock(&buffer[i].linelock);
		/* copy buffer */
		
#ifdef DEBUG
		fprintf(stderr,"%d: printing line #%s#\n",
			thr_self(), buffer[i].line);
#endif

		write(sock, buffer[i].line, strlen(buffer[i].line)+1);

#ifdef DEBUG
		fprintf(stderr,"%d: printed\n",thr_self());
#endif

		/* free lock */
		rw_unlock(&buffer[i].linelock);
		/* increase client count */
		buffer[i].sent++;
		
#ifdef DEBUG
		fprintf(stderr,"%d: buffer[%d].sent = %d, num_clients = %d\n",
			thr_self(), i, buffer[i].sent, num_clients);
#endif

		if (buffer[i].sent >= num_clients) {
		    /* free buffer line */
		    
#ifdef DEBUG
		    fprintf(stderr, "%d: buffer[%d] freed\n",
			    thr_self(), i);
#endif

		    buffer[i].sent = -1;
		    sema_post(&buffer_free);
		}
	    }
	}
    }
    
    return NULL;
    
}

/***********************************************************
 *                                                         *
 * Function:                                               *
 * void usage(char* name)                                  *
 * show the command line syntax and terminate              *
 * Parameters:                                             *
 *  name - program name to put into the message            *
 * Return:                                                 *
 *  never returns; exits with status 1                     *
 *                                                         *
 ***********************************************************/

void usage(char* name) {
    FILE *out = stdout;		/* usage text goes to standard output */

    fprintf(out, "%s - a simple chat server\n", name);
    fprintf(out, "Usage: %s <port> (default port = %d)\n", name, DEFPORT);

    /* always a failure exit: only called on bad arguments */
    exit(1);
}

int main(int argc, char** argv) {
    
    int lsocket;			/* listen socket */
    int asocket;			/* accept socket */
    int port = DEFPORT;			/* port to use */
    struct sockaddr_in servaddr;	/* server address */
    struct sockaddr_in cliaddr;		/* client address */
    int cliaddr_len;			/* length of client address */
    int ret;				/* generic return value */
    thread_t tid1, tid2;		/* thread ids */
    struct threadarg_t* targ;		/* arg for thread */
    int i;

    /* init semaphore */
    sema_init(&buffer_free, MAXMSG, USYNC_THREAD, NULL);
    /* init mutex */
    mutex_init(&bufferlock, USYNC_THREAD, NULL);
    /* init rw locks and buffer lines */
    for (i=0; i<MAXMSG; i++) {
	rwlock_init(&buffer[i].linelock, USYNC_THREAD, NULL);
	buffer[i].sent = -1;
    }

    /* check for port */
    if (argc == 2) {
	port = atoi(argv[1]);
    } else if (argc != 1) {
	usage(argv[0]);
    }
    
    /* open listen socket */
    lsocket = socket(AF_INET, SOCK_STREAM, 0);
    if (lsocket == -1) {
	perror("Can't open lsocket");
	exit(1);
    }
    
    /* bind socket to port */
    servaddr.sin_addr.s_addr = htonl(INADDR_ANY);
    servaddr.sin_port = htons(port);
    ret = bind(lsocket, (struct sockaddr*) &servaddr, sizeof(servaddr));
    if (ret == -1) {
	perror("Can't bind lsocket");
	exit(1);
    }
    
    /* listen to socket */
    ret = listen(lsocket, 10);
    if (ret == -1) {
	perror("Can't listen to lsocket");
	exit(1);
    }
    
    /* server loop */
    while (1) {
	
	/* accept connection */
	cliaddr_len = sizeof(cliaddr);
	asocket = accept(lsocket, (struct sockaddr*) &cliaddr, &cliaddr_len);
	if (ret == -1) {
	    perror("Can't accept on lsocket");
	    exit(1);
	}

	/* "log" connection */
	printf("Connection from %s, port %d\n",
	       	inet_ntoa(cliaddr.sin_addr),
		ntohs(cliaddr.sin_port));
	num_clients++;
	
	/* create two new threads to handle this request */
	targ = malloc(sizeof(struct threadarg_t));
	targ->socket = asocket;
	targ->tid = 0;
	ret = thr_create(NULL, NULL, read_lines, targ, THR_NEW_LWP, &tid1);
	if (ret != 0) {
	    fprintf(stderr, "Thread creation error\n");
	    exit(1);
	}
	targ = malloc(sizeof(struct threadarg_t));
	targ->socket = asocket;
	targ->tid = tid1;
	ret = thr_create(NULL, NULL, write_lines, targ, THR_NEW_LWP, &tid2);
	if (ret != 0) {
	    fprintf(stderr, "Thread creation error\n");
	    exit(1);
	}

    }
    
    /* when do we close lsocket ? */
    /* propably should catch signal or so */

}
