/*****************************************************************************
 * Filename:    cryptfs/key.c                                                *
 * Description: Key handling                                                 *
 * Copyright:   2009 by Alexander Motzkau                                    *
 *****************************************************************************/

#include <stdio.h>
#include <string.h>
#include <termios.h>
#include <unistd.h>
#include <sys/ioctl.h>
#include <sys/types.h>
#include <sys/wait.h>

#include <gcrypt.h>

#include "actions.h"
#include "misc.h"
#include "key.h"
#include "tty.h"

/*
 * Convert raw key input into the final key bytes and store them in
 * key->result / key->resultlen.
 *
 * key    - key description; the fields algo, len and newline are read,
 *          result and resultlen are written.
 * input  - NUL-terminated input; its length is taken with strlen().
 * config - not used by this function.
 *
 * If key->algo is NULL a SHA variant is chosen from key->len (in bits).
 * If key->algo is "none" the input is copied unhashed, otherwise the
 * input is hashed with the named libgcrypt digest.
 *
 * Unless key->newline is nl_unchanged, one trailing '\n' is removed from
 * the input; with nl_add a single '\n' is then appended again.
 *
 * Returns true on success. Returns false (leaving key->result untouched)
 * if the digest name is unknown or the digest handle cannot be opened.
 */
enum bool process_input(struct key* key, char* input, struct configuration* config)
{
    int algo;
    int inplen;
    gcry_md_hd_t hd;
    gcry_error_t err;

    if(key->algo == NULL)
    {
        /* No digest requested: pick one that is wide enough for key->len. */
        if(key->len>0 && key->len <= 160)
            algo = GCRY_MD_SHA1;
        else if(!key->len || (key->len>160 && key->len<=256))
            algo = GCRY_MD_SHA256;
        else if(key->len>256 && key->len<=384)
            algo = GCRY_MD_SHA384;
        else /* 385 and greater... */
            algo = GCRY_MD_SHA512;
    }
    else if(!strcmp(key->algo, "none"))
        algo = 0;
    else
    {
        algo = gcry_md_map_name(key->algo);
        if(!algo)
        {
            errmsg("Algorithm '%s' not supported.\n", key->algo);
            return false;
        }
    }

    inplen = strlen(input);
    /* The length check keeps us from reading input[-1] on empty input. */
    if(key->newline != nl_unchanged && inplen > 0 && input[inplen-1] == '\n')
        inplen--;

    if(!algo)
    {
        /* Unhashed key: the (possibly newline-adjusted) input is the key. */
        key->resultlen = inplen+((key->newline == nl_add)?1:0);

        /* One extra byte for a terminating NUL that is not part of the
           key; it is always inside the allocation. */
        key->result = xmalloc(key->resultlen+1);
        memcpy(key->result, input, inplen);
        if(key->newline == nl_add)
            ((char*)key->result)[inplen] = '\n';
        ((char*)key->result)[key->resultlen] = 0;
    }
    else
    {
        unsigned int mdflags;
        gcry_control( GCRYCTL_DISABLE_SECMEM_WARN );

        /* Initializing libgcrypt's secure memory drops our privileges,
           so secure memory is only used when not running setuid.
         */
        if(getuid() == geteuid())
        {
            gcry_control( GCRYCTL_INIT_SECMEM, 16384, 0 );
            mdflags = GCRY_MD_FLAG_SECURE;
        }
        else
            mdflags = 0;

        err = gcry_md_open(&hd, algo, mdflags);
        if (err)
        {
            errmsg("Failure: %s/%s\n", gcry_strsource (err), gcry_strerror (err));
            return false;
        }

        gcry_md_write(hd, input, inplen);
        if(key->newline == nl_add)
            gcry_md_write(hd, "\n", 1);
        gcry_md_final(hd);

        /* The digest buffer belongs to the handle, so copy it out
           before the handle is closed. */
        key->resultlen = gcry_md_get_algo_dlen(algo);
        key->result = xmalloc(key->resultlen);
        memcpy(key->result, gcry_md_read(hd, algo), key->resultlen);

        gcry_md_close(hd);

        gcry_control( GCRYCTL_TERM_SECMEM );
    }

    return true;
}

/*
 * Produce the key bytes described by 'key' and cache them in key->result.
 *
 * key    - key description; key->type selects how the key is obtained.
 * len    - receives the key length in bytes; only written on success.
 * config - passed on to handle_commands(), find_key(), get_tty(),
 *          setuidnam() and process_input().
 *
 * A key that already has a result is returned as is. Otherwise the
 * key's preexec commands are run, the key is obtained according to its
 * type, and the postexec commands are run with a flag telling whether a
 * key was produced. If key->len (bits) is positive, a longer result is
 * truncated to key->len/8 bytes.
 *
 * Returns key->result, which stays owned by 'key' (for a 'ref' key it is
 * the referenced key's buffer), or NULL on failure.
 */
void *handle_key(struct key* key, int *len, struct configuration* config)
{
    if(key->result != NULL)
    {
        *len = key->resultlen;
        return key->result;
    }

    if(!handle_commands(key->preexec, true, config))
    {
        handle_commands(key->postexec, false, config);
        return NULL;
    }

    switch(key->type)
    {
        case ref:
        {
            struct key* rkey;

            rkey = find_key(key->id, config);

            /* A reference to itself would recurse endlessly. */
            if(rkey == NULL || rkey == key)
            {
                errmsg("Key not found.\n");
                break;
            }

            /* The buffer is shared with the referenced key, not copied. */
            key->result = handle_key(rkey, &key->resultlen, config);
        }
        break;

        case composite:
        {
            struct key* k;

            /* key->result is NULL here, so no stale length may be used. */
            key->resultlen = 0;

            /* XOR all subkeys together; the result is as long as the
               longest subkey. */
            for(k = key->subkeys; k != NULL; k = k->next)
            {
                void *kres;
                int kreslen;

                /* handle_key() leaves the length untouched on failure. */
                kreslen = 0;

                kres = handle_key(k, &kreslen, config);
                if(!kreslen || kres==NULL)
                {
                    xfree(key->result);
                    key->resultlen = 0;
                    key->result=NULL;
                    break;
                }

                if(kreslen > key->resultlen)
                {
                    /* The subkey is longer than what we have so far:
                       continue in a copy of it and fold the old
                       (shorter) result into that copy. */
                    void * tempkey;

                    tempkey = xmalloc(kreslen);
                    memcpy(tempkey, kres, kreslen);

                    if(key->result != NULL)
                    {
                        char* t1 = (char*) tempkey;
                        char* t2 = (char*) key->result;
                        int i;

                        for(i=key->resultlen; i--; t1++, t2++)
                            *t1 ^= *t2;

                        xfree(key->result);
                    }
                    key->result = tempkey;
                    key->resultlen = kreslen;
                }
                else
                {
                    char* t1 = (char*) key->result;
                    char* t2 = (char*) kres;
                    int i;

                    for(i=kreslen; i--; t1++, t2++)
                        *t1 ^= *t2;
                }
            }
        }
        break;

        case passphrase:
        {
            char pp[4096];
            int pplen, ttyrows;
            int got_input;
            struct termios old, new;
            enum bool term;

            get_tty(config);

            /* Turn off echo if stdin is a terminal. */
            term = tcgetattr(fileno(stdin), &old) == 0;
            if(term)
            {
                new = old;
                new.c_lflag &= ~ECHO;
                tcsetattr(fileno(stdin), TCSAFLUSH, &new);
            }

            ttyrows = 25;

            if(key->vt100reserve > 0)
            {
                int i;
                struct winsize size;

                if(!ioctl(fileno(stdout), TIOCGWINSZ, &size))
                    ttyrows = size.ws_row;

                for(i=0; i<key->vt100reserve; i++)
                    fprintf(stdout, "\n");

                fprintf(stdout, "\e[%dA\e7\e[%d;1H\e[?25l\n",
                    key->vt100reserve, ttyrows - key->vt100reserve);
            }

            if(key->value && *(key->value)) /* Not empty */
                fprintf(stdout, "%s\n", key->value);
            fprintf(stdout, "Passphrase: ");

            if(key->vt100reserve > 0)
                fprintf(stdout, "\e[5m_\e[25m\e[1;%dr\e8",
                    ttyrows - key->vt100reserve);

            /* The prompt has no trailing newline; make sure it is shown. */
            fflush(stdout);

            /* On EOF or read error the buffer contents cannot be relied
               upon, so force an empty string in that case. */
            got_input = fgets(pp, 4096, stdin) != NULL;
            if(!got_input)
                pp[0] = 0;
            pplen = strlen(pp);

            /* Restore the terminal before any error handling. */
            if(key->vt100reserve > 0)
                fprintf(stdout, "\e7\e[1;%dr\e8\e[?25h\e[J", ttyrows);

            if(term) tcsetattr(fileno(stdin), TCSAFLUSH, &old);
            fprintf(stdout, "\n");

            if(!got_input)
            {
                errmsg("Could not read passphrase.\n");
                break;
            }

            /* A full buffer without newline and without EOF means that
               more input is still pending. */
            if(pplen == 4095 && pp[4094]!='\n' && !feof(stdin))
            {
                errmsg("Passphrase too long.\n");
                break;
            }

            if(pplen > 0 && pp[pplen-1]=='\n')
                pp[pplen-1] = 0;

            /* On failure key->result stays NULL, which is handled below. */
            process_input(key, pp, config);
        }
        break;

        case file:
        {
            char fkey[4096];
            FILE *inkey;

            inkey= fopen(key->value, "rb");
            if(inkey == NULL)
            {
                errnomsg(key->value);
                break;
            }

            /* Read key->len/8 bytes if a usable length is given,
               otherwise up to the buffer size. */
            key->resultlen = fread(fkey, 1,
                    (key->len>0 && key->len<4096*8)?(key->len/8):4096, inkey);

            /* Close before the check so the stream is released on
               every path. */
            fclose(inkey);

            if(key->resultlen==0)
            {
                errmsg("%s: File is empty.\n", key->value);
                break;
            }

            key->result = xmalloc(key->resultlen);
            memcpy(key->result, fkey, key->resultlen);
        }
        break;

        case program:
        {
            int fds[2];
            pid_t child;

            if(key->needtty)
                get_tty(config);

            if(pipe(fds)<0)
            {
                errnomsg("pipe");
                break;
            }

            child = fork();

            if(!child)
            {
                /* Child process */

                setuidnam(key->uid, key->gid, config);

                /* The command's standard output goes into the pipe. */
                close(fds[0]);
                if(dup2(fds[1], STDOUT_FILENO)<0)
                {
                    close(fds[1]);
                    errnomsg("dup2");
                    exit(1);
                }
                close(fds[1]);

                execl("/bin/sh", "/bin/sh", "-c", key->value, NULL);
                errnomsg("execl");
                exit(1);
            }
            else if(child<0)
            {
                /* Report first, closing might change errno. Then leave
                   through the common exit so that the pipe is released
                   and the postexec commands are run. */
                errnomsg("fork");
                close(fds[0]);
                close(fds[1]);
                break;
            }
            else
            {
#define CHUNKSIZE 16384
                int status;
                int read_failed;
                void *first;
                void *current;
                unsigned int len, pos;
                ssize_t r;

                /* Parent process */
                close(fds[1]);

                /* The output is collected in a singly linked list of
                   chunks. Each chunk starts with the pointer to the
                   next chunk, the data follows behind it. */
                first=current=xmalloc(CHUNKSIZE);
                pos = sizeof(void*);
                len = 0;

                while((r=read(fds[0], (char*)current+pos, CHUNKSIZE-pos)) > 0)
                {
                    len+=r;
                    pos+=r;

                    if(pos==CHUNKSIZE)
                    {
                        void* next;

                        next = xmalloc(CHUNKSIZE);
                        *(void**)current = next;
                        current = next;
                        pos = sizeof(void*);
                    }
                }

                /* r is 0 at end of file and negative on a read error. */
                read_failed = (r < 0);
                if(read_failed)
                    errnomsg("read");

                *(void**)current = NULL;
                close(fds[0]);

                /* Join the chunks into one NUL-terminated buffer. The
                   loop walks the whole list so that every chunk is
                   freed, including chunks that hold no data. */
                current = xmalloc(len+1);
                pos = 0;

                while(first!=NULL)
                {
                    unsigned int s;
                    void * old;

                    s = CHUNKSIZE-sizeof(void*);
                    if(len<s)
                        s = len;

                    if(s > 0)
                        memcpy((char*)current+pos, (char*)first+sizeof(void*), s);
                    len -= s;
                    pos += s;

                    old = first;
                    first = *(void**)first;
                    xfree(old);
                }

                ((char*)current)[pos] = 0;

                /* Always collect the child, even after a read error. */
                if(waitpid(child, &status, 0)<0)
                {
                    errnomsg("waitpid");
                    xfree(current);
                    break;
                }

                if(read_failed)
                {
                    xfree(current);
                    break;
                }

                if(!key->ignore_status &&
                    (!WIFEXITED(status) || WEXITSTATUS(status)!=0))
                {
                    errmsg("Command for key \"%s\" failed.\n", key->id==NULL?"":key->id);
                    xfree(current);
                    break;
                }

                /* On failure key->result stays NULL, which is handled
                   below; the collected output is released either way. */
                process_input(key, current, config);

                xfree(current);
            }
        }
        break;

        case literal:
            if(key->value==NULL)
            {
                errmsg("No literal given.\n");
                break;
            }

            /* On failure key->result stays NULL, which is handled below. */
            process_input(key, key->value, config);
            break;
    }

    if(key->result!=NULL)
    {
        /* key->len is in bits; cut down a result that is too long. */
        if(key->len>0 && key->resultlen*8 > key->len)
            key->resultlen = key->len / 8;

        *len = key->resultlen;

        handle_commands(key->postexec, true, config);
    }
    else
    {
        handle_commands(key->postexec, false, config);
    }
    return key->result;
}

enum bool reset_key(struct key* key, struct configuration* config)
{
    enum bool success;
    
    if(key->result == NULL)
	return false;

    success = false;

    switch(key->type)
    {
	case ref:
	{
	    struct key* rkey;
	    
	    rkey = find_key(key->id, config);
	    
	    if(rkey == NULL || rkey == key)
		return false;
	    
	    success = reset_key(rkey, config);
	}
	break;
	
	case composite:
	{
	    struct key* k;
	    
	    for(k = key->subkeys; k != NULL; k = k->next)
		if(reset_key(k, config))
		    success = true;
	}
	break;
	
	case passphrase:
	case file:
	case program:
	    if(key->retry != 0)
	    {
		if(key->retry > 0)
		    key->retry--;

		success = true;
	    }
	    break;
	
	case literal:
	    return false;
    }

    if(success)
    {
	if(key->type != ref)
    	    xfree(key->result);

	key->result = NULL;
	key->resultlen = 0;
	return true;
    }
    
    return false;
}
