/****************************************************************
 *                                                              *
 *  LIBDIST V1.0						*
 *                                                              *
 *  dsm.c -- distributed shared memory                          *
 *                                                              *
 *  Last changed: 09.04.96                                      *
 *  Author: Frank Kargl (frank.kargl@informatik.uni-ulm.de)     *
 *                                                              *
 *  Restrictions: works only for Solaris 2.6 or above           *
 *                                                              *
 ****************************************************************/


#include "libdist.h"
#include "config.h"

#include <sys/mman.h>	/* for mmap etc */
#include <string.h>	/* for mem* */
#include <fcntl.h>	/* for open */
#include <sys/types.h>	/* general type definitions */
#include <sys/stat.h>	/* general type definitions */
#include <arpa/inet.h>	/* for inet_ntoa */
#include <errno.h>	/* for errno */
#include <sys/signal.h>	/* for signal handling */

/***
 *** int dl_dsm_init(void)
 ***
 *** Function: start the dsm_server for this machine
 ***           should be done only once per machine
 *** Return  : DL_OK if ok
 ***           DL_ERROR if error
 *** Note    : not needed right now; the current body only emits a
 ***           debug trace (when built with DEBUG) and reports DL_OK
 ***/

int dl_dsm_init(void) {
    int result = DL_OK;		/* nothing is done here, so nothing can fail */

#ifdef DEBUG
    fputs("LIBDIST-dsm: dsm init called\n", stderr);
#endif

    return result;
}

/***
 *** Module state shared by the dsm functions in this file
 ***/

static int pcode=PROT_NONE;	/* protection applied by the next dl_dsm_mappage() */
static void *pageaddr=NULL;	/* address of the dsm page, set in dl_dsm_page() */
static int fd;			/* file descriptor for /dev/zero */
static thread_t thr_id;		/* id of server thread running threadfunction() */
static int has_page=0;		/* non-zero while we are holding the page */

/*** void dl_dsm_mappage(void)
 ***
 *** Function: do the mmap call: map DL_DSM_SIZE bytes of 'fd' at
 ***           'pageaddr' using the protection currently in 'pcode'
 ***           internal function
 *** Return  : - (calls exit(1) if the mapping cannot be established)
 *** Note    : MAP_FIXED is used, so the new mapping replaces whatever
 ***           was mapped at 'pageaddr' before the call
 ***/

void dl_dsm_mappage(void) {

    void *addr;	/* return value */

    /* do mmap */
    addr=mmap(pageaddr, DL_DSM_SIZE, pcode, MAP_SHARED|MAP_FIXED, fd, 0);

    /* with MAP_FIXED any other address than the requested one is an error */
    if (addr==MAP_FAILED || addr!=pageaddr) {
#ifdef DEBUG
	fprintf(stderr,"LIBDIST-dsm: mmap failed\n");
#endif
	exit(1);
    }

#ifdef DEBUG
    /* %p instead of (int)-cast + %x: the cast truncates 64 bit pointers */
    fprintf(stderr,"LIBDIST-dsm: mapped paged to %p pcode=%x\n",addr,(unsigned)pcode);
#endif
}

/*** int dl_dsm_requestpage(void)
 ***
 *** Function: request the page via broadcast and wait for a single reply
 ***           internal function
 *** Return  : DL_OK if a complete page was received into 'pageaddr'
 ***           DL_DSM_TIMEDOUT if no reply arrived within DL_DSM_TIMEOUT
 ***           DL_ERROR on socket, select or receive errors
 *** Note    : DL_DSM_TIMEOUT is interpreted as milliseconds
 ***           a failing broadcast send ends the process via exit(1)
 ***/

int dl_dsm_requestpage(void) {

    int commsock = -1;		/* communication socket */
    struct timeval timeouts;    /* timeout with select */
    int i;			/* loop count */
    char aport[16];		/* ascii port, big enough for any int */
    fd_set sockfds;		/* file descriptor set for select */
    int has_data;		/* number of descriptors from select */
    int reclen;			/* number of bytes received */

    /* create communication socket */
    commsock=dl_sck_sockb();
    if (commsock==DL_ERROR) {

#ifdef DEBUG
	fprintf(stderr,"LIBDIST-dsm: can't create socket\n");
#endif

	return DL_ERROR;
    }

    /* send request to every port a dsm server thread may be bound to */
    for (i=DL_DSM_MINPORT;i<=DL_DSM_MAXPORT;i++) {
	/* bounded formatting: a 5 digit port does not fit in char[5] */
	snprintf(aport,sizeof(aport),"%d",i);
#ifdef DEBUG
	fprintf(stderr,"LIBDIST-dsm: sending request to port %u\n",(unsigned)i);
#endif
	if (dl_sck_sendb(commsock,aport,DL_DSM_REQUEST) == DL_ERROR) {

#ifdef DEBUG
	    fprintf(stderr,"LIBDIST-dsm: error sending broadcast\n");
#endif

	    exit(1);
	}
    }

    /* wait for response: split milliseconds into seconds + microseconds */
    timeouts.tv_sec      = (DL_DSM_TIMEOUT) / 1000;
    timeouts.tv_usec     = ((DL_DSM_TIMEOUT) % 1000) * 1000;

    /* Prepare descriptors */
    FD_ZERO(&sockfds);
    FD_SET(commsock,&sockfds);

    has_data = select(commsock+1,&sockfds,(fd_set *)NULL,(fd_set *)NULL,&timeouts);

    switch (has_data) {

    /*** timeout ***/

    case 0:

#ifdef DEBUG
	    fprintf(stderr,"LIBDIST-dsm: timeout waiting for page\n");
#endif
	    close(commsock);
	    return DL_DSM_TIMEDOUT;

    /*** there's a message waiting for us ***/

    case 1:

	    /* the reply is written straight into the mapped page */
	    reclen=recv(commsock,pageaddr,DL_DSM_SIZE,0);
	    if (reclen<=0) {
#ifdef DEBUG
		fprintf(stderr,"LIBDIST-dsm: recv error %s\n",strerror(errno));
#endif
		break;
	    }

#ifdef DEBUG
	    fprintf(stderr,"LIBDIST-dsm: received page %p\n",pageaddr);
	    fprintf(stderr,"LIBDIST-dsm: page starts with #%s#\n",(char *)pageaddr);
#endif

	    /* anything but a full page is treated as an error */
	    if (reclen != DL_DSM_SIZE) {
#ifdef DEBUG
		fprintf(stderr,"LIBDIST-dsm: wrong page length\n");
#endif
		break;
	    }

	    /* now we have the page */
	    has_page=1;

	    /* that's it */
	    close(commsock);
	    return DL_OK;

    /*** select failed (or returned something unexpected) ***/

    default:

#ifdef DEBUG
	    fprintf(stderr,"LIBDIST-dsm: select error\n");
#endif
	    break;
    }

    /* common error exit for every 'break' above */
    close(commsock);
    return DL_ERROR;
}

/*** void *threadfunction(void *arg)
 ***
 *** Function: reply to pagesend requests
 ***           binds a UDP socket to the first free port in
 ***           DL_DSM_MINPORT..DL_DSM_MAXPORT and then loops forever:
 ***           every datagram received while 'has_page' is set is answered
 ***           with the page contents, after which the local page is
 ***           remapped PROT_NONE and 'has_page' is cleared
 ***           internal function
 *** Return  : never returns; calls exit(1) if the socket cannot be
 ***           created or bound, or if sending the page fails
 *** Note    : 'arg' is not used; the request text itself is not inspected
 ***/

void *threadfunction(void *arg) {

    int commsock;		/* communication socket */
    struct sockaddr_in sin;	/* an IEA */
    struct sockaddr_in client;	/* an IEA */
    int client_len;		/* length of IEA */
    char buffer[DL_MAXLINE];	/* request buffer */
    int port;			/* port used */
    int bound=0;		/* set once bind succeeded */
    int reqlen;			/* length of received request */
    int status;			/* return status */

    (void)arg;

#ifdef DEBUG
    fprintf(stderr,"LIBDIST-dsm: threadfunction starting\n");
#endif

    commsock = socket(PF_INET,SOCK_DGRAM,IPPROTO_UDP);

    if (commsock==-1) {
#ifdef DEBUG
	fprintf(stderr,"LIBDIST-dsm: error creating socket\n");
#endif
	exit(1);
    }

    /* now scan for a free port and bind commsock */
    memset(&sin,0,sizeof(sin));	/* don't hand stack garbage to bind */
    sin.sin_family	= AF_INET;
    sin.sin_addr.s_addr	= INADDR_ANY;

    /* the range check comes first, so no port outside the range is tried */
    for (port=DL_DSM_MINPORT; port<=DL_DSM_MAXPORT && port<65534; port++) {
	sin.sin_port = htons(port);
	if (bind(commsock,(struct sockaddr *)&sin,sizeof(sin)) == 0) {
	    bound=1;
	    break;
	}
    }

    /* error ? */
    if (!bound) {
#ifdef DEBUG
	fprintf(stderr,"LIBDIST-dsm: can't bind socket\n");
#endif
	exit(1);
    }

#ifdef DEBUG
    fprintf(stderr,"LIBDIST-dsm: bound to port %u\n",(unsigned)port);
#endif

    /* wait for requests */
    while(1) {

	/* get request; one byte is kept free for the terminating NUL */
	client_len=sizeof(client);
	reqlen=recvfrom(commsock,buffer,sizeof(buffer)-1,0,
			(struct sockaddr *)&client,&client_len);
	if (reqlen < 0) {
	    /* error reading ? */
#ifdef DEBUG
	    fprintf(stderr,"LIBDIST-dsm: error reading in dsm thread\n");
#endif
	    continue;
	}

	/* datagrams are not NUL terminated, but the trace below uses %s */
	buffer[reqlen]='\0';

#ifdef DEBUG
	fprintf(stderr,"LIBDIST-dsm: got request #%s# from %s, %u\n",
		    buffer,inet_ntoa(client.sin_addr),ntohs(client.sin_port));
#endif

	/* if we have the page, send it and then set PROT_NONE */
	if (has_page) {

	    /* send page to requesting client */
	    status=sendto(commsock,pageaddr,DL_DSM_SIZE,0,
		    (struct sockaddr *)&client,sizeof(client));

	    if (status == -1) {
#ifdef DEBUG
		fprintf(stderr,"LIBDIST-dsm: send error %s\n",strerror(errno));
#endif
		exit(1);
	    }

#ifdef DEBUG
	    fprintf(stderr,"LIBDIST-dsm: sent page to %s, %u\n",
			inet_ntoa(client.sin_addr),ntohs(client.sin_port));
	    fprintf(stderr,"LIBDIST-dsm: page starts with #%s#\n",(char *)pageaddr);
#endif

	    /* the page has been handed over: the next local access must fault */
	    pcode=PROT_NONE;
	    dl_dsm_mappage();

	    has_page=0;
	}
    }
}

/*** void dl_dsm_pagefault(int sign)
 ***
 *** Function: when pagefault occurs, requests page and make it RW again
 ***           internal function
 ***/

void dl_dsm_pagefault(int sign) {

    int status;

#ifdef DEBUG
    fprintf(stderr,"LIBDIST-dsm: SEGV catched\n");
#endif

    /* restore access mask */
    pcode=PROT_READ|PROT_WRITE;
    dl_dsm_mappage();

    /* Request page */
    status=dl_dsm_requestpage();

    if (status==DL_ERROR) {
#ifdef DEBUG
	fprintf(stderr,"LIBDIST-dsm: fatal dsm error\n");
#endif
	exit(1);
    }

    if (status==DL_DSM_TIMEOUT) {
#ifdef DEBUG
	fprintf(stderr,"LIBDIST-dsm: timeout getting page - retrying\n");
#endif
	pcode=PROT_NONE;
	dl_dsm_mappage();

	return;
    }

#ifdef DEBUG
    fprintf(stderr,"LIBDIST-dsm: in pagefault %x #%s#\n",(int)pageaddr,(char *)pageaddr);
#endif

    signal(SIGSEGV,dl_dsm_pagefault);

}

/*** void *dl_dsm_page(char *key, int timeout)
 ***
 *** Function: create a new page associated with 'key' or give access to an
 ***           existing one
 ***           maps DL_DSM_SIZE bytes of /dev/zero, tries to fetch the
 ***           current contents via dl_dsm_requestpage(), installs the
 ***           SIGSEGV handler and starts the server thread
 *** Return  : pointer to the shared memory segment
 ***           (void *)DL_ERROR if error
 *** Restrictions: only one page can be used, key is ignored, dito timeout
 *** Note    : if nobody answers the request we assume to be the first
 ***           dsm user and keep the freshly mapped page
 ***/

void *dl_dsm_page(char *key, int timeout) {

    int status;		/* status of dl_dsm_requestpage */

    (void)key;
    (void)timeout;

    /* create dsm page */
    if ((fd =open("/dev/zero", O_RDWR)) == -1) {

#ifdef DEBUG
	fprintf(stderr,"LIBDIST-dsm: /dev/zero not available\n");
#endif

	return (void *)DL_ERROR;
    }

    pcode=PROT_READ|PROT_WRITE;
    pageaddr=mmap((void *)0, DL_DSM_SIZE, pcode, MAP_SHARED, fd, 0);
    if (pageaddr==MAP_FAILED) {

#ifdef DEBUG
	fprintf(stderr,"LIBDIST-dsm: mmap failed\n");
#endif

	/* don't leak the descriptor opened above */
	pageaddr=NULL;
	close(fd);
	return (void *)DL_ERROR;
    }

    /* now get the page */
    status=dl_dsm_requestpage();
    if (status == DL_ERROR) {

#ifdef DEBUG
	fprintf(stderr,"LIBDIST-dsm: error getting page\n");
#endif

	/* undo only what has been set up so far: neither the signal
	   handler nor the server thread exist yet, so there is no
	   thread that could be signalled */
	munmap(pageaddr,DL_DSM_SIZE);
	pageaddr=NULL;
	close(fd);
	return (void *)DL_ERROR;
    }

    if (status == DL_DSM_TIMEDOUT) {
#ifdef DEBUG
	fprintf(stderr,"LIBDIST-dsm: timeout - we are the first dsm user\n");
#endif
	/* assume we are the first one using dsm */
	has_page=1;
    }

#ifdef DEBUG
    fprintf(stderr,"LIBDIST-dsm: mapped page first time\n");
#endif

    /* Setup signal handler */
    signal(SIGSEGV,dl_dsm_pagefault);

    /* Start dsm thread */
    thr_id=dl_thr_create(threadfunction,NULL);

    return pageaddr;
}

/***
 *** void dl_dsm_remove(char *key)
 ***
 *** Function: remove a shared memory segment from your use
 ***           signals the server thread, restores the default SIGSEGV
 ***           action, unmaps the page and closes /dev/zero
 *** Return  : -
 *** Restrictions: only one page can be used, key is ignored
 ***/

void dl_dsm_remove(char *key) {

    (void)key;

#ifdef DEBUG
    fputs("LIBDIST-dsm: dl_dsm_remove called\n", stderr);
#endif

    /* teardown runs in the reverse order of the setup in dl_dsm_page() */
    dl_thr_kill(thr_id,SIGKILL);	/* server thread */
    signal(SIGSEGV,SIG_DFL);		/* pagefault handler */
    munmap(pageaddr,DL_DSM_SIZE);	/* the page itself */
    close(fd);				/* /dev/zero */
}

/***
 *** void dl_dsm_end(void)
 ***
 *** Function: end all dsm activity for this machine
 ***           stop the dsm_server
 *** Return  : -
 *** Note    : not needed right now; the current body only emits a
 ***           debug trace when built with DEBUG
 ***/

void dl_dsm_end(void) {
#ifdef DEBUG
    fputs("LIBDIST-dsm: dl_dsm_end called\n", stderr);
#endif
    return;
}
