/*****************************************************************************
 * Filename:    cryptfs/actions.c                                            *
 * Description: Handles the actions and does fsck/mounting                   *
 * Copyright:   2009 by Alexander Motzkau                                    *
 *****************************************************************************/

#include <errno.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <sys/wait.h>

#include "actions.h"
#include "key.h"
#include "misc.h"
#include "dm.h"
#include "tty.h"
#include "xml.h"

/* Map options that override nothing: every tristate is t_unspecified and
 * no mode, uid or gid is forced. handle_action() falls back to this when
 * its caller passes a NULL override. */
struct map_options null_override =
{
    .fsck = t_unspecified,
    .mount = t_unspecified,
    .ro = t_unspecified,
    .mode = 0,
    .uid = NULL,
    .gid = NULL
};

/* Run a single hook command through "/bin/sh -c" and wait for it.
 *
 * The command is skipped (and counts as successful) when
 *  - it is bound to the opposite outcome (on_failure while SUCCESS is true,
 *    on_success while SUCCESS is false), or
 *  - its on_mode does not apply to config->mode: on_remove commands run only
 *    in remove mode and remove mode runs only on_remove commands; otherwise
 *    the modes must be equal unless either side is om_always.
 *
 * Commands with needtty set call get_tty() before being started.
 *
 * Returns true if the command was skipped, exited with status 0, or has
 * ignore_status set; false if it failed or could not be started/waited for.
 */
enum bool handle_command(struct command *cmd, enum bool success, struct configuration* config)
{
    pid_t child;
    int status;

    if((success && cmd->on_state == on_failure) ||
       (!success && cmd->on_state == on_success))
        return true;

    if((config->mode == on_remove || cmd->on_mode == on_remove) &&
       config->mode != cmd->on_mode)
        return true;

    if(config->mode != om_always && cmd->on_mode != om_always &&
       config->mode != cmd->on_mode)
        return true;

    if(cmd->needtty)
        get_tty(config);

    child = fork();
    if(child < 0)
    {
        errnomsg("fork");
        return false;
    }

    if(child == 0)
    {
        /* NOTE(review): the return value of setuidnam() is not checked; if it
         * can return on failure, the command would run with the current
         * (possibly elevated) credentials -- confirm. */
        setuidnam(cmd->uid, cmd->gid, config);

        /* The variadic argument list must end with a null *pointer*;
         * a bare NULL may be an int 0. */
        execl("/bin/sh", "/bin/sh", "-c", cmd->cmd, (char *)NULL);
        errnomsg("execl");

        /* _exit: don't flush stdio buffers inherited from the parent
         * or run its atexit handlers a second time. */
        _exit(1);
    }

    /* Retry if a signal interrupts the wait. */
    while(waitpid(child, &status, 0) < 0)
    {
        if(errno != EINTR)
        {
            errnomsg("waitpid");
            return false;
        }
    }

    if(WIFEXITED(status) && WEXITSTATUS(status) == 0)
        return true;

    return cmd->ignore_status ? true : false;
}

enum bool handle_commands(struct command *cmd, enum bool success, struct configuration* config)
{
    while(cmd != NULL)
    {
	success = handle_command(cmd, success, config) && success;
	cmd = cmd->next;
    }
    
    return success;
}

enum bool handle_action(char* action, struct map_options* override, enum bool access, enum bool firstcall, struct configuration* config)
{
    /* Let's go for it. */
    
    struct action* ac;
    struct map* mapping;
    struct key* key;
    struct call* call;
    enum bool success;
    enum bool denied;
    
    if(override == NULL)
	override = &null_override;
    
    ac = find_action(action, config);

    /* Prevent recursion. */    
    if(ac->done)
	return true;
    ac->done = true;

    access = check_access_list(ac->access, config, access);
    denied = !access;

    /* Check, whether there's anything to do. */    
    success = false;
    for(mapping = ac->mappings; mapping != NULL; mapping = mapping->next)
    {
	unsigned long bsize;
	char *dmname;
	
	if(!check_access_list(mapping->access, config, access))
	{
	    mapping->done = true;
	    continue;
	}
	
	denied = false;
	
	dmname = get_crypt_name(mapping);
	bsize = 0;
	
	if(get_blockdev_size(dmname, &bsize)?(bsize!=0):(errno==EACCES))
	{
	    /* Device mapped device already exists. */

	    if(config->mode == on_remove ||
		(config->mode == on_mount && !is_mounted(dmname)))
	    {
		mapping->done = false;
		success = true;
	    }
	    else
		mapping->done = true;
	}
	else if(config->mode == on_remove || config->mode == on_mount)
	    mapping->done = true;
	else
	{
	    /* Check if all used block devices are available. */
	    struct storage_block *block;
	    
	    block = find_storage(mapping->name, config);
	    if(block == NULL)
		mapping->done = true;
	    else
		mapping->done = false;
	
	    while(block != NULL && !mapping->done)
	    {
		if(!get_blockdev_size(block->device->device, &bsize) || bsize==0)
		    /* This one isn't */
		    mapping->done = true;
		else
		    block = block->next;
	    }
	    
	    if(!mapping->done)
		success = true;
	}
	
        xfree(dmname);
    }
    
    if(config->mode != on_remove)
	key = insert_key(ac->key, config);
    else
	key = NULL; /* To make the compiler happy. */
    
    if(success)
    {
	/* Something to do. */
    
	if(!handle_commands(ac->preexec, true, config))
	{
	    handle_commands(ac->postexec, false, config);
	    return false;
	}
    
        success = false;

	for(mapping = ac->mappings; mapping != NULL; mapping = mapping->next)
	{
	    enum bool retry;
	    
	    if(mapping->done)
		continue;
	
	    retry = false;
	    
	    do
	    {
		struct key* mapkey;
		char* dmname;
		void *k;
		int klen;
	
		k = NULL;
	
		/* There is nothing we can do different this time. */
		if((config->mode == on_mount || config->mode == on_remove) && retry)
		    break;
		
		if(!config->noenc && config->mode != on_mount && config->mode != on_remove && mapping->key == NULL)
		{
		    if(retry && !reset_key(key, config))
			break;
		
		    if((k=handle_key(key, &klen, config))==NULL)
		    {
			handle_commands(ac->postexec, false, config);
		        return false;
		    }
		}
	
		if(!retry && !handle_commands(mapping->preexec, true, config))
		    break;
	
		if(!config->noenc && config->mode != on_remove)
		    mapkey = insert_key(mapping->key, config);

	        if(!config->noenc && config->mode != on_mount && config->mode != on_remove && mapping->key != NULL)
	        {
		    if(retry && !reset_key(mapkey, config))
			break;
		    
		    if((k=handle_key(mapkey, &klen, config))==NULL)
			break;
	        }
	
		if(config->mode != on_mount && config->mode != on_remove &&
		    !setup_crypt(mapping, k, klen, override, config))
			break;
		
		retry = true;

		dmname = get_crypt_name(mapping);
		
		if(config->mode == on_remove && is_mounted(dmname))
		{
	    	    pid_t child;
	    
		    child = fork();
	    
	            if(!child)
		    {
		        /* umount devname */
		        char* umnt_arg[3];
	    
			umnt_arg[0] = "/bin/umount";
			umnt_arg[1] = dmname;
			umnt_arg[2] = NULL;
			
			setuid(geteuid());
			setgid(geteuid());

		        execv("/bin/umount", umnt_arg);
		        errnomsg("execv /bin/umount");
		        exit(1);
	            }
	            else if(child<0)
		    {
		        errnomsg("fork");
		        continue;
		    }
		    else
	    	    {
			int status;
		
			if(waitpid(child, &status, 0)<0)
		        {
			    errnomsg("waitpid");
			    continue;
		        }

		        if(!WIFEXITED(status) || WEXITSTATUS(status)!=0)
			{
		    	    errmsg("%s: umount failed.\n", dmname);
		    	    continue;
			}
	    	    }
		}

		if(config->mode == on_remove)
		{
		    remove_crypt(mapping, config);
		}
	
		if(config->mode != on_map && config->mode != on_remove &&
		    t_resolve(override->fsck, mapping->opts.fsck,
			(T_RESOLVE2(override->ro, mapping->opts.ro) == t_true)?t_false:config->fsck))
		{
		    pid_t child;
	    
	    	    get_tty(config);
	    
		    child = fork();

	    	    if(!child)
		    {
			/* fsck -C [ -t mapping->fs ] devname */
		        char* fsck_arg[6];
			int a;
	    
		        fsck_arg[0] = "/sbin/fsck";
			fsck_arg[1] = "-C";
			a = 2;
			
		        if(mapping->fs != NULL && strlen(mapping->fs))
			{
		    	    fsck_arg[a++] = "-t";
		    	    fsck_arg[a++] = mapping->fs;
			}
		        fsck_arg[a++] = dmname;
			fsck_arg[a++] = NULL;
		
		        execv("/sbin/fsck", fsck_arg);
			errnomsg("execv /sbin/fsck");
		        exit(1);
	    	    }
	    	    else if(child<0)
	    	    {
			errnomsg("fork");
		        remove_crypt(mapping, config);
		        continue;
	    	    }
	    	    else
	    	    {
			int status;
		
		        if(waitpid(child, &status, 0)<0)
			{
		    	    errnomsg("waitpid");
		    	    remove_crypt(mapping, config);
		    	    continue;
			}
	    
		        if(!WIFEXITED(status) || (WEXITSTATUS(status)&~1)!=0)
			{
		    	    errmsg("%s: fsck failed.\n", dmname);
			    remove_crypt(mapping, config);
			    continue;
			}
	    	    }
		}
	
	        if(config->mode != on_map && config->mode != on_remove &&
		    t_resolve(override->mount, mapping->opts.mount, config->mount))
		{
	    	    pid_t child;
	    
		    child = fork();
	    
	            if(!child)
		    {
		        /* mount [ -t mapping->fs ] [ -o mapping->options ]
			         devname [ mapping->point] */
		
		        char* mnt_arg[7];
			int a;
	    
			mnt_arg[0] = "/bin/mount";
		        a = 1;
	    
		        if(mapping->fs != NULL && strlen(mapping->fs))
			{
		    	    mnt_arg[a++] = "-t";
			    mnt_arg[a++] = mapping->fs;
		        }

		        if(mapping->options != NULL && strlen(mapping->options))
			{
		    	    mnt_arg[a++] = "-o";
			    mnt_arg[a++] = mapping->options;
		        }
			mnt_arg[a++] = dmname;
		        if(mapping->point != NULL && strlen(mapping->point))
			    mnt_arg[a++] = mapping->point;
		        mnt_arg[a++] = NULL;
		    
			setuid(geteuid());
			setgid(geteuid());

		        execv("/bin/mount", mnt_arg);
		        errnomsg("execv /bin/mount");
		        exit(1);
	            }
	            else if(child<0)
		    {
		        errnomsg("fork");
			remove_crypt(mapping, config);
		        continue;
		    }
		    else
	    	    {
			int status;
		
			if(waitpid(child, &status, 0)<0)
		        {
			    errnomsg("waitpid");
		            remove_crypt(mapping, config);
			    continue;
		        }

		        if(!WIFEXITED(status) || WEXITSTATUS(status)!=0)
			{
		    	    errmsg("%s: mount failed.\n", dmname);
		    	    remove_crypt(mapping, config);
		    	    continue;
			}
	    	    }
		}
	
	        if(!handle_commands(mapping->postexec, true, config))
		{
	    	    remove_crypt(mapping, config);
	    	    continue;
		}
		
		mapping->done = true;
	    } while(!mapping->done);
	    
	    
	    if(mapping->done)
		success = true; /* At least one successful mapping. */
	    else
		handle_commands(mapping->postexec, false, config);
        }

        handle_commands(ac->postexec, success, config);
    }
    
    /* Now call other actions. */
    for(call = ac->calls; call != NULL; call = call->next)
    {
	struct map_options* new_override;

	new_override = copy_map_options(override);
	if(call->overrides.uid != NULL && new_override->uid == NULL)
	    new_override->uid = xmallocstr(call->overrides.uid);
	if(call->overrides.gid != NULL && new_override->gid == NULL)
	    new_override->gid = xmallocstr(call->overrides.gid);
	if(new_override->fsck == t_unspecified)
	    new_override->fsck = call->overrides.fsck;
	if(new_override->mount == t_unspecified)
	    new_override->mount = call->overrides.mount;
	if(new_override->ro == t_unspecified)
	    new_override->ro = call->overrides.ro;
	if(!new_override->mode)
	    new_override->mode = call->overrides.mode;
	
	if(call->filename == NULL)
	{
	    /* Internal call. */
	    if(find_action(call->action, config)!=NULL)
		if(handle_action(call->action, new_override, access, false, config))
		    denied = false;
	}
	else
	{
	    /* External call. */
	    pid_t child;
	    int status;
	    
	    child = fork();
	    if(!child)
	    {
		void* key_tree;
		struct key* tree_keys;
		char* action;
		char* file;
		enum on_mode mode;
		
		/* Save the keys. */
		key_tree = config->key_tree;
		config->key_tree = NULL;
		tree_keys = config->keys_in_tree;
		config->keys_in_tree = NULL;
		/* ... and other data. */
		action = call->action;
		call->action = NULL;
		file = call->filename;
		call->filename = NULL;
		mode = config->mode;
		
		free_config(config);
		
		config = parse_file(file);
    		if(config == NULL)
	    	    exit(1);
		
		if(!complete_config(config))
		    exit(1);
		    
		config->mode = mode;
		config->key_tree = key_tree;
		config->keys_in_tree = tree_keys;
		
		if(find_action(action, config)!=NULL)
		    exit(handle_action(action, new_override, check_access_list(config->access, config, access), false, config)?0:1);
		
		exit(1);
	    }
	    else if(child<0)
		errnomsg("fork");
	    else if(waitpid(child, &status, 0)<0)
		errnomsg("waitpid");
	    else if(WIFEXITED(status) && !WEXITSTATUS(status))
		denied = false;
	}
	
	free_map_options(new_override);
    }
    
    if(firstcall)
    {
	if(denied)
	    errmsg("%s: Access denied.\n", ac->name);
	return true;
    }
    else
	return !denied;
}

enum bool check_action(struct action* action, enum bool access, struct configuration* config)
{
    struct map* mapping;
    
    access = check_access_list(action->access, config, access);
    if(access)
	return true;
	
    for(mapping = action->mappings; mapping != NULL; mapping = mapping->next)
	if(check_access_list(mapping->access, config, access))
	    return true;

    return false;
}
