/****************************************************************
 *                                                              *
 *  LIBDIST V1.0						*
 *                                                              *
 *  comm.c -- communication primitives                          *
 *                                                              *
 *  Last changed: 01.12.98                                      *
 *  Author: Frank Kargl (frank.kargl@informatik.uni-ulm.de)     *
 *                                                              *
 *  Restrictions: works only for Solaris 2.6 or above           *
 *                                                              *
 ****************************************************************/

#include "libdist.h"
#include "config.h"

#include <string.h>		/* for memset */
#include <stdlib.h>		/* for malloc, free, strtol */
#include <unistd.h>		/* for read, write, close */
#include <netdb.h>		/* for nameservice */
#include <arpa/inet.h>		/* for inet_addr */
#include <synch.h>		/* for semaphores */
#include <sys/ioctl.h>		/* ioctl to get broadcast/if addr */
#include <sys/sockio.h>		/* ioctl to get broadcast/if addr */
#include <net/if.h>		/* ioctl to get broadcast/if addr */

/***
 *** int dl_sck_getport(char *service,int type)
 ***
 *** Function: find the port number for an Internet service
 ***           service is first looked up as a service name with
 ***           getservbyname (protocol "udp" or "tcp" according to
 ***           type); if that fails it must be a plain decimal
 ***           port number between 1 and 65535
 *** Return  : port number in host byte order
 ***	       DL_ERROR if error (unknown type, NULL or empty
 ***	       service, unknown name, invalid or out-of-range number)
 ***/

int dl_sck_getport(char *service,int type) {

	struct	servent	*se_ptr;	/* service entry */
	char	proto[5];		/* either "udp" or "tcp" */
	char	*endp;			/* first char strtol did not parse */
	long	num;			/* numeric form of service */

	/* find the right protocol name */
	if ( type==DL_SCK_UDP ) {
		strcpy(proto,"udp");
	} else if ( type==DL_SCK_TCP ) {
		strcpy(proto,"tcp");
	} else {

#ifdef DEBUG
		fprintf(stderr,"LIBDIST-comm: wrong protocol\n");
#endif

		return DL_ERROR;
	}

	/* getservbyname/strtol must not see a NULL pointer */
	if ( service==NULL || *service=='\0' ) {

#ifdef DEBUG
		fprintf(stderr,"LIBDIST-comm: no service given\n");
#endif

		return DL_ERROR;
	}

	/* a service name according to services(4) ? */
	if ( (se_ptr=getservbyname(service,proto)) ) {
		/* s_port is stored in network byte order */
		return ntohs(se_ptr->s_port);
	}

	/* otherwise it must be a plain port number; unlike atoi,  */
	/* strtol lets us reject trailing garbage and values that  */
	/* would silently wrap around when stored in 16 bits       */
	num=strtol(service,&endp,10);
	if ( endp==service || *endp!='\0' || num<1 || num>65535 ) {

#ifdef DEBUG
		fprintf(stderr,"LIBDIST-comm: can't find port for %s\n",service);
#endif

		return DL_ERROR;
	}

	return (int)num;
}

/***
 *** int dl_sck_server(char *service, int type, void *(*function)(int *sock))
 ***
 *** Function: create a server socket of specified type
 ***           on the port that is specified by service
 ***	       service may either be a number or a service name
 ***	       according to services(4)
 ***           when service is NULL, then don't do bind
 ***           when function != NULL and type==DL_SCK_TCP
 ***           then dl_sck_server accepts incoming requests and,
 ***           for each one, starts a thread (dl_thr_create) that
 ***           is passed a malloc'ed int holding the client socket
 ***           (NOTE(review): this function never frees that int;
 ***           presumably 'function' does -- confirm with callers)
 ***           in this case dl_sck_server is an endless loop that
 ***           only returns (DL_ERROR) when malloc or accept fails
 ***           if type==DL_SCK_UDP then function is ignored and
 ***           should be NULL
 *** Return  : the socket id
 ***           DL_ERROR if error
 ***/

int dl_sck_server(char *service, int type, void *(*function)(int *sock)) {
	
	struct sockaddr_in sin;	/* an Internet endpoint address (IEA) */
	struct sockaddr_in cli;	/* client's IEA */
	int	cli_len;	/* length of clients addresses */
	int	s;		/* a socket */
	int	*cs;		/* a socket ptr(for client connections) */
	int	port;		/* port to use */
	
	/* try to create a socket */
	if ( type==DL_SCK_TCP ) {
		s=socket(PF_INET,SOCK_STREAM,IPPROTO_TCP);
	} else if ( type==DL_SCK_UDP ) {
		s=socket(PF_INET,SOCK_DGRAM,IPPROTO_UDP);
	} else {

#ifdef DEBUG
		fprintf(stderr,"LIBDIST-comm: wrong socket type\n");
#endif

		return DL_ERROR;
	}

	if (s==-1) {

#ifdef DEBUG
	    fprintf(stderr,"LIBDIST-comm: can't create socket\n");
#endif

	    return DL_ERROR;
	}

	/* if service == NULL then don't bind */
	if (service) {

	    /* prepare sin: any local interface */
	    memset((void *)&sin,0,sizeof(sin));
	    sin.sin_family	= AF_INET;
	    sin.sin_addr.s_addr	= INADDR_ANY;

	    /* get the right port */
	    port=dl_sck_getport(service,type);
	    if (port == DL_ERROR) {
		close(s);
		return DL_ERROR;
	    }
	    sin.sin_port=htons(port);

	    /* bind the socket */
	    if (bind(s, (struct sockaddr *) &sin, sizeof(sin)) == -1) {

#ifdef DEBUG
		    fprintf(stderr,"LIBDIST-comm: can't bind socket\n");
#endif

		    close(s);
		    return DL_ERROR;
	    }

	    /* if TCP socket then listen */
	    if (type==DL_SCK_TCP) {
		if (listen(s,QLEN) < 0) {

#ifdef DEBUG
		    fprintf(stderr,"LIBDIST-comm: can't listen on socket\n");
#endif

		    close(s);
		    return DL_ERROR;
		}

		/* should we also handle accepts ? */
		if (function!=NULL) {

		    /* loop forever */
		    while(1) {

			/* the client socket id is handed to another */
			/* thread, so it must live on the heap       */
			cs = malloc(sizeof *cs);
			if (cs == NULL) {

#ifdef DEBUG
			    fprintf(stderr,"LIBDIST-comm: out of memory\n");
#endif

			    close(s);
			    return DL_ERROR;
			}

			/* accept() reads cli_len as the size of cli and */
			/* overwrites it, so reset it before every call  */
			cli_len = sizeof(cli);
			*cs = accept(s,(struct sockaddr *) &cli, &cli_len);
			if (*cs<0) {

#ifdef DEBUG
			    fprintf(stderr,"LIBDIST-comm: can't accept on socket\n");
#endif

			    free(cs);
			    close(s);
			    return DL_ERROR;
			}

			/* pass new socket descriptor to new thread */
			dl_thr_create((void *)function,(void *)cs);
		    }
		}
	    } else {
		/* UDP case not handled now */
	    }
	}

	return s;
}

/***
 *** int dl_sck_connect(char *rem_ip, char *service, int type)
 ***
 *** Function: open a connection to a remote server
 ***           rem_ip may either be a dotted-decimal or
 ***           DNS style address string; only IPv4 host
 ***           entries are used
 ***           type is analogous to dl_sck_server
 *** Return  : the socket id
 ***           DL_ERROR if error
 ***/

int dl_sck_connect(char *rem_ip, char *service, int type) {

	struct hostent *he_ptr;	/* Pointer to host information entry */
	struct sockaddr_in sin;	/* IEA */
	int	s;		/* a socket */
	int	port;		/* port to use */

	/* gethostbyname/inet_addr must not see a NULL pointer */
	if (rem_ip == NULL) {

#ifdef DEBUG
		fprintf(stderr,"LIBDIST-comm: no remote address\n");
#endif

		return DL_ERROR;
	}

	/* Prepare sin */
	memset((void *)&sin,0,sizeof(sin));
	sin.sin_family = AF_INET;

	/* get the right port */
	port=dl_sck_getport(service,type);
	if (port == DL_ERROR) {
	    return DL_ERROR;
	}
	sin.sin_port=htons(port);
	
	/* get the right address; sin_addr holds exactly one IPv4   */
	/* address, so only copy from an AF_INET entry of that size */
	he_ptr=gethostbyname(rem_ip);
	if ( he_ptr != NULL
	     && he_ptr->h_addrtype == AF_INET
	     && he_ptr->h_length == (int)sizeof(sin.sin_addr) ) {
		memcpy((char *)&sin.sin_addr, he_ptr->h_addr_list[0], sizeof(sin.sin_addr));
	} else if ( (sin.sin_addr.s_addr = inet_addr(rem_ip)) == INADDR_NONE ) {

#ifdef DEBUG
		fprintf(stderr,"LIBDIST-comm: can't gethostbyname\n");
#endif

		return DL_ERROR;
	}

	/* now try to create a socket */
	if ( type==DL_SCK_TCP ) {
		s=socket(PF_INET,SOCK_STREAM,IPPROTO_TCP);
	} else if ( type==DL_SCK_UDP ) {
		s=socket(PF_INET,SOCK_DGRAM,IPPROTO_UDP);
	} else {

#ifdef DEBUG
		fprintf(stderr,"LIBDIST-comm: wrong socket type\n");
#endif

		return DL_ERROR;
	}

	if (s == -1) {

#ifdef DEBUG
		fprintf(stderr,"LIBDIST-comm: can't create socket\n");
#endif

		return DL_ERROR;
	}
	
	/* connect the socket */
	if (connect(s, (struct sockaddr *)&sin, sizeof(sin)) < 0) {

#ifdef DEBUG
		fprintf(stderr,"LIBDIST-comm: can't connect socket\n");
#endif

		close(s);
		return DL_ERROR;
	}

	return s;
}

/***
 *** void dl_sck_close(int socket)
 ***
 *** Function: close a socket connection by releasing its
 ***           descriptor; an error reported by close(2) is ignored
 *** Return  : -
 ***/

void dl_sck_close(int socket)
{
	(void) close(socket);
}

/***
 *** char *dl_sck_receive(int socket)
 ***
 *** Function: read one line of text from 'socket'
 ***           every socket gets its own entry in a static table
 ***           (at most DL_SCK_MAXELE sockets at a time) that holds
 ***           data already read but not yet returned
 *** Return  : one textline stripped of any trailing <LF> or <CR><LF>
 ***           if the peer closes the connection after a line that
 ***           has no <LF>, that line is returned as is
 ***           (char *)EOF when end of connection reached; the
 ***           socket's table entry is released at that point
 ***           NULL indicates an error (read error, no free table
 ***           entry, or a line too long for the entry's buffer)
 ***           as the buffer for line is static the result
 ***           should be copied elsewhere before the next call
 ***           to dl_sck_receive
 ***/

char *dl_sck_receive(int socket) {

	static sema_t dl_sck_lock = DL_SEM_UISEM;
				/* semaphore for shared access to bufferfield */
				/* initialized this way so we can determine */
				/* if we have done sema_init yet */

	char *buffer;		/* data area of this socket's entry */
	char *beg_ptr,*end_ptr; /* Pointers to begin/end of used buffer */

	static struct buffers {
		int sock;		/* associated with this socket;  */
					/* -1 marks the first never used */
					/* entry, -2 a released entry    */
		char *begin;		/* begin of used area */
		char *end;		/* end of used area */
		char data[DL_MAXLINE*2];	/* actual line */
	} bufferfield[DL_SCK_MAXELE] =	/* Field of buffer entries */
		{ {-1, NULL, NULL, ""} };
				/* init value needed to determine */
				/* an empty bufferfield */

	int i,empty;		/* indexes into buffer */

	char *ptr,*ptr2;	/* temporary pointers */
	int numbytes;		/* number of bytes read */
	size_t space;		/* free bytes behind end_ptr */
	size_t rest;		/* length of not yet returned data */
	
	/* but first lock everything */
	/* NOTE(review): this lazy sema_init is itself unsynchronized */
	/* and relies on the internal 'count' field of sema_t         */
	if (dl_sck_lock.count!=1 && dl_sck_lock.count!=0) {
		sema_init(&dl_sck_lock, 1, USYNC_THREAD, NULL);
	}
	sema_wait(&dl_sck_lock);

	/* get the right buffer */

	empty = -1;
	for (i=0;i<DL_SCK_MAXELE;i++) {
		switch (bufferfield[i].sock) {

		/*** found a released entry, remember the first one ***/
		case -2:

			if (empty == -1) {
				empty=i;
			}
			break;

		/*** end of used entries: this socket is new ***/
		case -1:

			if (empty == -1) {
				/* use new field and move the end mark, */
				/* but never beyond the table           */
				empty=i;
				if (i+1 < DL_SCK_MAXELE) {
					bufferfield[i+1].sock = -1;
				}
			}
			goto dl_sck_newslot;

		/*** perhaps the one we wanted ? ***/
		default:

			if (bufferfield[i].sock == socket) {
				/* socket found */
				empty=i;
				goto dl_sck_jmppnt;
			}
			/* go on */
			break;
		}
	}

	/* whole table in use (no end mark) and socket not found: */
	/* take a released entry if there is one, else give up    */
	if (empty == -1) {
		sema_post(&dl_sck_lock);

#ifdef DEBUG
		fprintf(stderr,"LIBDIST-comm: no free receive buffer\n");
#endif

		return NULL;
	}

dl_sck_newslot:

	/* claim and reset the chosen entry for this socket */
	bufferfield[empty].sock=socket;
	bufferfield[empty].data[0]='\0';
	bufferfield[empty].begin=bufferfield[empty].data;
	bufferfield[empty].end=bufferfield[empty].data;

dl_sck_jmppnt:

	/* and now free semaphore again */
	sema_post(&dl_sck_lock);

	/* set the actual values of pointers */
	buffer=bufferfield[empty].data;
	beg_ptr=bufferfield[empty].begin;
	end_ptr=bufferfield[empty].end;

	/* Nothing more to do ? -> harakiri */
	if (*beg_ptr == (char)EOF) {	/* ATTENTION: as char(EOF) may occur */
					/* within the data, this could lead  */
					/* to a premature end of connection  */

		/* free ringbuffer entry: reset it first and mark it   */
		/* free last, under the lock, because the lookup above */
		/* may hand it to another thread as soon as sock is -2 */
		sema_wait(&dl_sck_lock);
		bufferfield[empty].data[0]='\0';
		bufferfield[empty].begin=bufferfield[empty].data;
		bufferfield[empty].end=bufferfield[empty].data;
		bufferfield[empty].sock = -2;
		sema_post(&dl_sck_lock);
		return (char *)EOF;
	}
	
	/* not really an endless loop, but until we got one complete line */
	while (1) {

		/* now test if there's a whole line left in the buffer */
		/* we know this from the trailing LF */
		if ((ptr=strchr(beg_ptr,'\n'))) {
			/* yes, so strip out CRLF (only within this line) */
			*ptr='\0';
			if ((ptr>beg_ptr) && (*(ptr-1)=='\r')) {
				*(ptr-1)='\0';
			}

			ptr2=beg_ptr;
			beg_ptr=ptr+1;

			/* we have our line, go down to the end */
			break;
		}

		/*** as there was not enough left, read a little bit more ***/

		/* if buffer is filled to more than 50% move remainder */
		/* to begin of buffer; memmove because both areas may  */
		/* overlap (strcpy would be undefined behaviour here)  */
		if (end_ptr>=(buffer+DL_MAXLINE)) {
			rest=strlen(beg_ptr);
			memmove(buffer,beg_ptr,rest+1);
			beg_ptr=buffer;
			end_ptr=buffer+rest;
		}

		/* never read beyond the buffer; keep one byte for '\0' */
		space=sizeof(bufferfield[empty].data)-1-(size_t)(end_ptr-buffer);
		if (space==0) {

#ifdef DEBUG
			fprintf(stderr,"LIBDIST-comm: line too long\n");
#endif

			bufferfield[empty].begin=beg_ptr;
			bufferfield[empty].end=end_ptr;
			return NULL;
		}
		if (space>(size_t)DL_MAXLINE) {
			space=(size_t)DL_MAXLINE;
		}

		/* read a little bit */
		numbytes=read(socket,(void *)end_ptr,space);
		if (numbytes == -1) {

#ifdef DEBUG
			fprintf(stderr,"LIBDIST-comm: Read Error\n");
#endif

			/* the data may have been moved above: store the */
			/* current pointers so a later call stays valid  */
			*end_ptr='\0';
			bufferfield[empty].begin=beg_ptr;
			bufferfield[empty].end=end_ptr;
			return NULL;
		}

		if (numbytes==0) {
		/* nothing read ? flush buffers */
			if (*beg_ptr == '\0') {
			/* nothing left in the buffer anyway */
				/* free ringbuffer entry (see above) */
				sema_wait(&dl_sck_lock);
				bufferfield[empty].data[0]='\0';
				bufferfield[empty].begin=bufferfield[empty].data;
				bufferfield[empty].end=bufferfield[empty].data;
				bufferfield[empty].sock = -2;
				sema_post(&dl_sck_lock);
				return (char *)EOF;
			}

			/* return the last (unterminated) line; *end_ptr */
			/* is its '\0', so park the EOF mark right behind */
			/* it (space>=1 above keeps end_ptr+1 in bounds)  */
			ptr2=beg_ptr;
			beg_ptr=end_ptr+1;
			*beg_ptr=(char)EOF;
			end_ptr=beg_ptr;

			/* return the last line */
			break;
		} 

		/* adjust end_ptr */
		end_ptr+=numbytes;
		*end_ptr='\0';

	} /* while */

	/* enter begin/end values in array */
	bufferfield[empty].begin=beg_ptr;
	bufferfield[empty].end=end_ptr;

#ifdef DEBUG
	fprintf(stderr,"LIBDIST-comm: received #%s#\n",ptr2);
#endif

	/* return the line */
	return ptr2;
}

/***
 *** char *dl_sck_recvtio(int socket,int timeout)
 ***
 *** Function: read a datagram from socket but with timeout
 ***           timeout is given in milliseconds
 *** Return  : the datagram as string (at most DL_MAXLINE-1 bytes,
 ***           so there is always room for the terminating '\0')
 ***           NULL indicates a timeout or an error
 ***           as the buffer for line is static the result
 ***           should be copied elsewhere before the next call
 ***           to dl_sck_recvtio
 ***/

char *dl_sck_recvtio(int socket, int timeout) {

    fd_set sockfds;		/* file descriptor set for select */
    int maxsockfds=0;		/* maximum descriptor used in sockfds */
    struct timeval timeouts;	/* timeout with select */
    int has_data;		/* number of descriptors from select */
    static char buffer[DL_MAXLINE];	/* receive buffer */
    int reclen;			/* number of bytes received */

    /* FD_SET is undefined for descriptors outside [0, FD_SETSIZE) */
    if (socket < 0 || socket >= FD_SETSIZE) {

#ifdef DEBUG
	fprintf(stderr,"LIBDIST-comm: bad socket\n");
#endif

	return NULL;
    }

    /* calculate the timeout: whole seconds, plus the remaining   */
    /* milliseconds converted to the microseconds tv_usec expects */
    timeouts.tv_sec      = timeout / 1000;
    timeouts.tv_usec     = (timeout % 1000) * 1000;

    /* Prepare descriptors */
    FD_ZERO(&sockfds);
    FD_SET(socket,&sockfds);
    maxsockfds = socket+1;

    has_data = select(maxsockfds,&sockfds,(fd_set *)NULL,(fd_set *)NULL,&timeouts);

    switch (has_data) {

	/*** timeout ***/

	case 0:

		return NULL;
	
	/*** there's a message waiting for us ***/

	case 1:

		/* leave one byte for the terminating '\0' */
		reclen=recv(socket,buffer,sizeof(buffer)-1,0);

		if (reclen<=0) {
#ifdef DEBUG
		    fprintf(stderr,"LIBDIST-comm: recv error\n");
#endif

		    return NULL;
		}

		buffer[reclen]='\0';

#ifdef DEBUG
		fprintf(stderr,"LIBDIST-comm: received #%s#\n",buffer);
#endif

		return buffer;
	
	default:

#ifdef DEBUG
		fprintf(stderr,"LIBDIST-comm: select error\n");
#endif

		return NULL;

    }
}

/***
 *** int dl_sck_send(int socket, char *line)
 ***
 *** Function: send one line of text to 'socket'
 ***           <CR><LF> is automatically added unless line already
 ***           ends with <LF> (an empty line is sent as just <CR><LF>)
 ***           short writes are retried until everything is sent
 ***	       suited only for DL_SCK_TCP
 *** Return  : the number of characters sent
 ***           -1 if a write fails
 ***/

/* write all len bytes of buf to fd, retrying after short writes */
/* returns the number of bytes written, -1 if write(2) fails     */
static ssize_t dl_sck_writeall(int fd, const char *buf, size_t len) {

	size_t	done = 0;	/* bytes written so far */
	ssize_t	n;		/* result of one write */

	while (done < len) {
		n = write(fd, buf + done, len - done);
		if (n < 0) {
			return -1;
		}
		if (n == 0) {
			/* no progress at all: stop instead of spinning */
			break;
		}
		done += (size_t)n;
	}
	return (ssize_t)done;
}

int dl_sck_send(int socket, char *line) {

	size_t	len;		/* length of line */
	ssize_t	sent;		/* result of one dl_sck_writeall */
	int	numbytes;	/* number of bytes sent */

	len = strlen(line);

	sent = dl_sck_writeall(socket, line, len);
	if (sent < 0) {
		return -1;
	}
	numbytes = (int)sent;

	/* test len first: for "" line[len-1] would read before line */
	if ( len == 0 || line[len-1] != '\n' ) {
		sent = dl_sck_writeall(socket, "\r\n", 2);
		if (sent < 0) {
			return -1;
		}
		numbytes += (int)sent;
	}
	
#ifdef DEBUG
	fprintf(stderr,"LIBDIST-comm: sent #%s#\n",line);
#endif

	return numbytes;
}

/***
 *** int dl_sck_sockb(void);
 ***
 *** Function: open a UDP socket with SO_BROADCAST enabled
 ***           as long as no broadcast address is cached yet, the
 ***           broadcast address of interface DL_SCK_ETHER is
 ***           fetched (SIOCGIFBRDADDR) and stored in dl_sck_broad
 ***           for dl_sck_sendb
 *** Return  : the socket id
 ***           DL_ERROR if error (the socket is closed again)
 ***/

static struct sockaddr_in dl_sck_broad = { AF_MAX+1 };
			/* IEA for broadcasting */
			/* must be global so dl_sck_sendb can read it */
			/* sin_family==AF_MAX+1 means "not fetched yet" */
			/* (relies on sin_family being the first member) */

int dl_sck_sockb(void) {

	int	s;		/* a socket */
	int	on = 1;		/* switch broadcast on */
	struct ifreq interf;	/* interface data */

	/* now try to create a socket */
	s=socket(PF_INET,SOCK_DGRAM,0);

	if (s == -1) {

#ifdef DEBUG
		fprintf(stderr,"LIBDIST-comm: can't create socket\n");
#endif

		return DL_ERROR;
	}
	
	/* get the broadcast address if not already done */
	if ( dl_sck_broad.sin_family == AF_MAX+1 ) {

	    /* what interface to use */
	    strcpy(interf.ifr_name,DL_SCK_ETHER);

	    /* now get the address */
	    if (ioctl(s, SIOCGIFBRDADDR, (char *)&interf) < 0) {

#ifdef DEBUG
		fprintf(stderr,"LIBDIST-comm: ioctl error\n");
#endif

		close(s);
		return DL_ERROR;
	    }

	    memcpy((char *)&dl_sck_broad, (char *)&interf.ifr_broadaddr, sizeof(struct sockaddr));

	}

	/* allow broadcasting; without it sendto to the */
	/* broadcast address would fail later anyway    */
	if (setsockopt(s,SOL_SOCKET,SO_BROADCAST,(char *)&on,sizeof(on)) < 0) {

#ifdef DEBUG
		fprintf(stderr,"LIBDIST-comm: can't enable broadcast\n");
#endif

		close(s);
		return DL_ERROR;
	}

	return s;
}

/***
 *** int dl_sck_sendb(int socket, char *service, char *line)
 ***
 *** Function: send text to 'socket' as a broadcast
 ***           the port to use is specified by 'service'
 ***           line is sent including its terminating '\0'
 ***           the broadcast address must have been fetched by
 ***           a successful dl_sck_sockb before
 *** Return  : the number of characters sent
 ***           DL_ERROR indicates an error (sendto failures are
 ***           returned as sendto reports them)
 ***/

int dl_sck_sendb(int socket, char *service, char *line) {

	struct sockaddr_in dest;	/* per-call copy of the broadcast IEA */
	int numbytes;	/* number of bytes sent */
	int port;	/* port to use */

	/* no broadcast address yet ? (see dl_sck_sockb) */
	if ( dl_sck_broad.sin_family == AF_MAX+1 ) {

#ifdef DEBUG
		fprintf(stderr,"LIBDIST-comm: no broadcast address\n");
#endif

		return DL_ERROR;
	}

	/* get the right port */
	port=dl_sck_getport(service,DL_SCK_UDP);
	if (port == DL_ERROR) {
	    return DL_ERROR;
	}

	/* set the port on a private copy, so concurrent calls with */
	/* different services do not overwrite each other's port    */
	dest=dl_sck_broad;
	dest.sin_port=htons(port);

#ifdef DEBUG
	fprintf(stderr,"LIBDIST-comm: sending #%s# as broadcast\n",line);
#endif

	numbytes=sendto(socket,line,strlen(line)+1,0,
			(struct sockaddr *)&dest,sizeof(dest));

	return numbytes;
}

/***
 *** struct sockaddr_in *dl_sck_getifadr(void)
 ***
 *** Function: get the interface address of interface DL_SCK_ETHER
 ***           (via ioctl SIOCGIFADDR)
 *** Return  : an Internet endpoint address; it points into a
 ***           static buffer that the next call overwrites
 ***           NULL indicates an error
 ***/

struct sockaddr_in *dl_sck_getifadr(void) {

    int s;				/* socket used for ioctl */
    static struct ifreq interf;		/* interface data */
    struct sockaddr_in *sin_ptr;	/* IEA */

    /* create the socket */
    s=socket(PF_INET,SOCK_DGRAM,IPPROTO_UDP);
    if (s == -1) {

#ifdef DEBUG
	fprintf(stderr,"LIBDIST-comm: can't create socket\n");
#endif

	return NULL;
    }

    /* what interface to use */
    strcpy(interf.ifr_name,DL_SCK_ETHER);

    /* now get the address */
    if (ioctl(s, SIOCGIFADDR, (char *)&interf) < 0) {

#ifdef DEBUG
	fprintf(stderr,"LIBDIST-comm: ioctl error\n");
#endif

	/* do not leak the helper socket on failure */
	close(s);
	return NULL;
    }

    close(s);

    /* the result lives in the static interf, not in s */
    sin_ptr=(struct sockaddr_in *)&(interf.ifr_addr);
    return sin_ptr;
}
