/* SPDX-License-Identifier: GPL-2.0-only OR GPL-3.0-only */
/* Copyright (c) 2025 Brett Sheffield <bacs@librecast.net> */

#include <state.h>
#include <log.h>
#include <err.h>
#include <errno.h>
#include <librecast/net.h>
#include <limits.h>
#include <net/if.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/param.h>
#include <sys/stat.h>
#include <unistd.h>
#include <y.tab.h>
#include <lex.h>

extern int yyparse(state_t *state);

/* Template channel state: copied into every new channel by
 * state_push_channel() and into state->defaults by state_defaults_set().
 * Only the CHAN_ENABLE and CHAN_JOIN flags are set; every other member is
 * zero / NULL. */
state_chan_t state_channel_defaults = {
	.flags = CHAN_JOIN | CHAN_ENABLE,
};

/* Free one channel node, its owned strings and its command list.
 * Returns the node that followed it, so a list can be released with
 * `while (p) p = state_channel_free(p);`. */
state_chan_t *state_channel_free(state_chan_t *a)
{
	state_chan_t *next = a->next;
	state_cmd_t *cmd = a->cmd;

	while (cmd) {
		state_cmd_t *following = cmd->next;
		free(cmd->cmd);
		free(cmd);
		cmd = following;
	}
	free(a->seed);
	free(a->dir);
	free(a->chan_name);
	free(a->auth_key);
	free(a);
	return next;
}

/* Release everything owned by state: log, directory paths, rcfile/logfile,
 * the channel list and the strings held in state->defaults.
 * Freed pointers are reset to NULL so a repeated call is harmless. */
void free_state(state_t *state)
{
	if (state->log) log_close(state);

	/* Release the channel list first. A channel's dir may be the very same
	 * pointer as defaults.dir; clear that alias so the string is freed
	 * exactly once (below) rather than twice. */
	for (state_chan_t *a = state->chan_head; a; ) {
		if (a->dir == state->defaults.dir) a->dir = NULL;
		a = state_channel_free(a);
	}
	state->chan_head = NULL;

	free(state->dir_config);
	state->dir_config = NULL;
	free(state->dir_cache);
	state->dir_cache = NULL;
	free(state->dir_state);
	state->dir_state = NULL;
	free(state->dir_data);
	state->dir_data = NULL;
	free(state->rcfile);
	state->rcfile = NULL;
	free(state->logfile);
	state->logfile = NULL;

	/* every string below is strdup()'d by a state_*_set() setter when no
	 * channel exists yet. Channels may share defaults.chan_status_cmd, but
	 * state_channel_free() never frees that field, so this is the only free. */
	free(state->defaults.auth_key);
	state->defaults.auth_key = NULL;
	state->defaults.auth_keys = 0;
	free(state->defaults.chan_name);
	state->defaults.chan_name = NULL;
	free(state->defaults.dir);
	state->defaults.dir = NULL;
	free(state->defaults.seed);
	state->defaults.seed = NULL;
	free(state->defaults.chan_status);
	state->defaults.chan_status = NULL;
	free(state->defaults.chan_status_cmd);
	state->defaults.chan_status_cmd = NULL;
	free(state->defaults.chan_stdout);
	state->defaults.chan_stdout = NULL;
	free(state->defaults.chan_stderr);
	state->defaults.chan_stderr = NULL;
	free(state->defaults.chan_stdin);
	state->defaults.chan_stdin = NULL;
}

/* free() that leaves errno untouched, so error paths can release memory
 * without losing the errno of the call that actually failed. */
static void freesv(void *ptr)
{
	const int saved_errno = errno;

	free(ptr);
	errno = saved_errno;
}

/* Create pathname and any missing parent directories, like `mkdir -p`.
 * At most len bytes of pathname are used (stopping early at a NUL).
 * Directories are created with mode 0777, filtered by the umask as
 * mkdir(1) does.
 * Returns 0 on success (including when the directory already exists).
 * Returns -1 with errno set on failure, including when a path component
 * exists but is not a directory.
 *
 * This used to build a "mkdir -p <path>" string for system(). The path
 * comes from $HOME / $XDG_*_HOME, so a space or shell metacharacter in it
 * either broke directory creation or was interpreted by the shell. */
static int mkdirp(const char *pathname, size_t len)
{
	size_t n = 0;
	char *path;
	int rc = -1;
	int err;

	while (n < len && pathname[n] != '\0') n++;
	if (n == 0) return (errno = ENOENT), -1;
	path = malloc(n + 1);
	if (!path) return -1;
	memcpy(path, pathname, n);
	path[n] = '\0';

	/* Walk the path, temporarily cutting it at each '/' (and at the end) so
	 * every prefix is created in turn. Start at path + 1 so an absolute path
	 * does not try to create "". */
	for (char *p = path + 1; ; p++) {
		const char c = *p;
		if (c != '/' && c != '\0') continue;
		*p = '\0';
		if (mkdir(path, 0777) == -1) {
			/* an existing directory is fine; some systems report EACCES
			 * or EROFS rather than EEXIST for it, so check directly */
			struct stat sb;
			err = errno;
			if (stat(path, &sb) == -1 || !S_ISDIR(sb.st_mode)) {
				errno = err;
				goto out;
			}
		}
		*p = c;
		if (c == '\0') break;
	}
	rc = 0;
out:
	err = errno;
	free(path);
	errno = err;
	return rc;
}

/* Build and create the per-package directory.
 * When xdg is non-NULL the path is "<xdg>/<PACKAGE_NAME>"; otherwise it is
 * "<dir_home>/<subdir>/<PACKAGE_NAME>".
 * Returns the heap-allocated path (caller frees), or NULL with errno
 * preserved on formatting, allocation or creation failure. */
static char *state_dir_create(state_t *state, const char *xdg, const char *subdir)
{
	char *path;
	int need, wrote;

	if (xdg) need = snprintf(NULL, 0, "%s/%s", xdg, PACKAGE_NAME);
	else need = snprintf(NULL, 0, "%s/%s/%s", state->dir_home, subdir, PACKAGE_NAME);
	if (need == -1) return NULL;

	path = malloc(need + 1);
	if (path == NULL) return NULL;

	if (xdg) wrote = snprintf(path, need + 1, "%s/%s", xdg, PACKAGE_NAME);
	else wrote = snprintf(path, need + 1, "%s/%s/%s", state->dir_home, subdir, PACKAGE_NAME);

	if (wrote != -1 && mkdirp(path, wrote) == 0) return path;
	freesv(path);
	return NULL;
}

/* Set state->rcfile to a newly allocated "<dir_home>/<LCAGENTRC>".
 * Returns 0 on success, -1 on formatting or allocation failure
 * (state->rcfile is NULL after an allocation failure). */
static int state_rcfile(state_t *state)
{
	const int need = snprintf(NULL, 0, "%s/%s", state->dir_home, LCAGENTRC);

	if (need < 0) return -1;
	state->rcfile = malloc((size_t)need + 1);
	if (state->rcfile == NULL) return -1;
	(void)snprintf(state->rcfile, (size_t)need + 1, "%s/%s", state->dir_home, LCAGENTRC);
	return 0;
}

/* Resolve the home directory and create the config/cache/state/data dirs.
 * An explicit home replaces $HOME and also disables the XDG_*_HOME
 * environment overrides. Exits the process if no home can be found.
 * Returns 0 on success, -1 if any directory or the rcfile path fails. */
int state_dirs(state_t *state, char *home)
{
	int use_env;

	state->dir_home = home ? home : getenv("HOME");
	if (state->dir_home == NULL)
		errx(EXIT_FAILURE, "HOME environment variable not set");
	use_env = (home == NULL);

	state->dir_config = state_dir_create(state,
			use_env ? getenv("XDG_CONFIG_HOME") : NULL, ".config");
	if (state->dir_config == NULL) return -1;

	state->dir_cache = state_dir_create(state,
			use_env ? getenv("XDG_CACHE_HOME") : NULL, ".cache");
	if (state->dir_cache == NULL) return -1;

	state->dir_state = state_dir_create(state,
			use_env ? getenv("XDG_STATE_HOME") : NULL, ".local/state");
	if (state->dir_state == NULL) return -1;

	state->dir_data = state_dir_create(state,
			use_env ? getenv("XDG_DATA_HOME") : NULL, ".local/share");
	if (state->dir_data == NULL) return -1;

	return state_rcfile(state) ? -1 : 0;
}

/* Return the first channel whose (non-NULL) lc channel has multicast
 * address *addr, or NULL if none matches. */
state_chan_t *state_chan_by_addr(state_t *state, struct in6_addr *addr)
{
	state_chan_t *cur = state->chan_head;

	while (cur != NULL) {
		if (cur->chan != NULL &&
		    memcmp(addr, lc_channel_in6addr(cur->chan), sizeof *addr) == 0)
			break;
		cur = cur->next;
	}
	return cur;
}

/* Return the first channel whose chan_name equals name, or NULL. */
state_chan_t *state_chan_by_name(state_t *state, char *name)
{
	state_chan_t *cur = state->chan_head;

	while (cur != NULL && strcmp(cur->chan_name, name) != 0)
		cur = cur->next;
	return cur;
}

/* Append one hex-encoded public key (crypto_sign_PUBLICKEYBYTES * 2 chars)
 * to the current channel, or to state->defaults if no channel exists yet.
 * Keys are stored back to back, not NUL-terminated, in auth_key, and
 * auth_keys counts them.
 * Returns 0 on success. Returns -1 if the key is too short (EINVAL), the
 * size would overflow (ENOMEM), or realloc fails. On failure the channel is
 * left unchanged. */
int state_authkey_set(state_t *state, char *authkey)
{
	state_chan_t *achan = (state->chan_head) ? state->chan_head : &state->defaults;
	const size_t hexbyt = crypto_sign_PUBLICKEYBYTES * 2;
	const size_t nkeys = (size_t)achan->auth_keys + 1;
	char *buf;

	/* memcpy below reads hexbyt bytes; never read past a shorter string */
	if (strlen(authkey) < hexbyt) return (errno = EINVAL), -1;
	if (nkeys > (size_t)-1 / hexbyt) return (errno = ENOMEM), -1;
	buf = realloc(achan->auth_key, hexbyt * nkeys);
	if (!buf) return -1;
	memcpy(buf + hexbyt * (nkeys - 1), authkey, hexbyt);
	achan->auth_key = buf;
	/* update the count only once the buffer really holds nkeys keys */
	achan->auth_keys = nkeys;
	return 0;
}

/* Create a new channel named channel_name and push it onto the head of
 * state->chan_head. It is initialised from state_channel_defaults. It then
 * takes the current defaults' flags and status command, plus its own copy of
 * the default dir. Its hash and hex form are computed from the name.
 * Returns 0 on success, -1 on allocation failure (nothing is leaked and the
 * list is untouched). */
int state_push_channel(state_t *state, char *channel_name)
{
	state_chan_t *a = malloc(sizeof *a);
	if (!a) return -1;
	*a = state_channel_defaults;
	a->chan_name = strdup(channel_name);
	if (!a->chan_name) return freesv(a), -1;
	/* The channel owns its dir and state_channel_free() frees it, while
	 * free_state() frees defaults.dir. Sharing the pointer freed it twice,
	 * so take a private copy instead. */
	a->dir = NULL;
	if (state->defaults.dir) {
		a->dir = strdup(state->defaults.dir);
		if (!a->dir) {
			freesv(a->chan_name);
			freesv(a);
			return -1;
		}
	}
	/* NOTE(review): still shared with defaults; nothing frees it per channel */
	a->chan_status_cmd = state->defaults.chan_status_cmd;
	a->flags = state->defaults.flags;
	hash_generic(a->hash, HASHSIZE, (unsigned char *)channel_name, strlen(channel_name));
	hash_bin2hex(a->hex, HEXLEN, a->hash, HASHSIZE);
	a->next = state->chan_head;
	state->chan_head = a;
	return 0;
}

/* Append a copy of command, with flags, to the command list of the most
 * recently pushed channel, keeping config order.
 * Returns 0 on success, -1 if there is no channel or allocation fails. */
int state_push_command(state_t *state, char *command, int flags)
{
	state_cmd_t *cmd, **tail;

	if (state->chan_head == NULL) return -1;
	cmd = calloc(1, sizeof *cmd);
	if (cmd == NULL) return -1;
	cmd->cmd = strdup(command);
	if (cmd->cmd == NULL) return freesv(cmd), -1;
	cmd->flags = flags;

	/* advance to the terminating NULL link, then hang the new node there */
	for (tail = &state->chan_head->cmd; *tail != NULL; tail = &(*tail)->next)
		;
	*tail = cmd;
	return 0;
}

/* Set the working dir of the current channel, or of the defaults when no
 * channel exists yet, to a copy of arg. Any previously owned value is freed
 * so repeated "dir" settings do not leak.
 * Returns 0 on success, -1 if strdup fails (old value kept). */
int state_dir_set(state_t *state, char *arg)
{
	char *dup = strdup(arg);
	if (!dup) return -1;
	if (state->chan_head) {
		/* a channel's dir may be the same pointer as defaults.dir; that one
		 * is owned by the defaults, so never free it from here */
		if (state->chan_head->dir != state->defaults.dir)
			free(state->chan_head->dir);
		state->chan_head->dir = dup;
	}
	else {
		/* no channels exist yet, so nothing can share defaults.dir */
		free(state->defaults.dir);
		state->defaults.dir = dup;
	}
	return 0;
}

/* Store a copy of the seed on the current channel (or the defaults).
 * The seed is only accepted if state->rcfile can be stat()ed and is not
 * world-readable. Returns 0 on success, -1 otherwise. */
int state_seed_set(state_t *state, char *arg)
{
	state_chan_t *target = state->chan_head ? state->chan_head : &state->defaults;
	struct stat st;
	char *copy;

	if (stat(state->rcfile, &st) == -1) return -1;
	if (st.st_mode & S_IROTH) {
		ERROR("rcfile is world readable\n");
		return -1;
	}
	if ((copy = strdup(arg)) == NULL) return -1;
	target->seed = copy;
	return 0;
}

/* Point state->logfile at a heap copy of arg. Returns 0, or -1 if the copy
 * cannot be allocated (logfile is then left as it was). */
int state_logfile_set(state_t *state, char *arg)
{
	char *copy;

	if ((copy = strdup(arg)) == NULL)
		return -1;
	state->logfile = copy;
	return 0;
}

/* Set chan_status on the current channel, or on the defaults if no channel
 * has been pushed yet, to a heap copy of arg. Returns 0, or -1 on OOM. */
int state_status_set(state_t *state, char *arg)
{
	state_chan_t *target = state->chan_head ? state->chan_head : &state->defaults;
	char *copy = strdup(arg);

	if (copy == NULL) return -1;
	target->chan_status = copy;
	return 0;
}

/* Set chan_status_cmd on the current channel, or on the defaults if no
 * channel has been pushed yet, to a heap copy of arg. Returns 0, or -1 on OOM. */
int state_statuscmd_set(state_t *state, char *arg)
{
	state_chan_t *target = state->chan_head ? state->chan_head : &state->defaults;
	char *copy = strdup(arg);

	if (copy == NULL) return -1;
	target->chan_status_cmd = copy;
	return 0;
}

/* Set chan_stdout on the current channel, or on the defaults if no channel
 * has been pushed yet, to a heap copy of arg. Returns 0, or -1 on OOM. */
int state_stdout_set(state_t *state, char *arg)
{
	state_chan_t *target = state->chan_head ? state->chan_head : &state->defaults;
	char *copy = strdup(arg);

	if (copy == NULL) return -1;
	target->chan_stdout = copy;
	return 0;
}

/* Set chan_stderr on the current channel, or on the defaults if no channel
 * has been pushed yet, to a heap copy of arg. Returns 0, or -1 on OOM. */
int state_stderr_set(state_t *state, char *arg)
{
	state_chan_t *target = state->chan_head ? state->chan_head : &state->defaults;
	char *copy = strdup(arg);

	if (copy == NULL) return -1;
	target->chan_stderr = copy;
	return 0;
}

/* Set chan_stdin on the current channel, or on the defaults if no channel
 * has been pushed yet, to a heap copy of arg. Returns 0, or -1 on OOM. */
int state_stdin_set(state_t *state, char *arg)
{
	state_chan_t *target = state->chan_head ? state->chan_head : &state->defaults;
	char *copy = strdup(arg);

	if (copy == NULL) return -1;
	target->chan_stdin = copy;
	return 0;
}

/* Print a getopt-style "missing argument" diagnostic for opt to stderr,
 * set errno to EINVAL and return -1 for the caller to propagate. */
static int state_err_option_requires_arg(char *opt)
{
	fprintf(stderr, "option requires an argument -- '%s'\n", opt);
	errno = EINVAL;
	return -1;
}

/* Parse a bandwidth limit such as "1500", "10K", "5M", "2G" or "1T"
 * (decimal multipliers of 1000) into state->bpslimit.
 * Characters after the optional suffix are ignored, as before (e.g. "10Mbps").
 * Returns 0 on success. Returns -1 if arg does not start with a digit
 * (EINVAL), or the value does not fit in 64 bits (ERANGE).
 * state->bpslimit is only modified on success. */
static int state_bpslimit_set(state_t *state, char *arg)
{
	unsigned long long val;
	uint64_t mult = 1;
	char *endptr;

	/* Require a leading digit. This rejects "", a bare suffix, and "-N",
	 * which strtoull would silently negate into a huge value. */
	if (arg[0] < '0' || arg[0] > '9') return (errno = EINVAL), -1;
	errno = 0;
	/* strtoull, not strtoul: unsigned long is only 32 bits on some ABIs */
	val = strtoull(arg, &endptr, 10);
	if (errno) return -1;
	switch (endptr[0]) {
		case 'T': mult = UINT64_C(1000000000000); break;
		case 'G': mult = UINT64_C(1000000000); break;
		case 'M': mult = UINT64_C(1000000); break;
		case 'K': mult = UINT64_C(1000); break;
		default: break; /* no recognised suffix: plain bits per second */
	}
	if (val > UINT64_MAX / mult) return (errno = ERANGE), -1;
	state->bpslimit = (uint64_t)val * mult;
	E(STATE_VERBOSE, "bpslimit = %" PRIu64 " bps\n", state->bpslimit);
	return 0;
}

/* Consume the argument that follows option name (argv[*i + 1]). It is
 * marked as used in optmask and stored in *opt. "bpslimit" is delegated to
 * state_bpslimit_set().
 * Returns 0 on success. Returns -1 if the argument is missing, contains no
 * number (EINVAL) or is out of range (ERANGE). */
static int state_u64_set(state_t *state, int argc, char *argv[], int *i, uint64_t *opt, char *name)
{
	const char *arg;
	char *endptr;
	long long val;

	if (*i == argc - 1) return state_err_option_requires_arg(name);
	(*i)++; clrbit(state->optmask, *i);
	arg = argv[*i];
	if (!strcmp(name, "bpslimit")) return state_bpslimit_set(state, argv[*i]);
	/* strtoll instead of atoll: atoll has undefined behaviour on overflow and
	 * cannot report non-numeric input. A signed parse keeps the previous
	 * result for negative input ("-1" still becomes UINT64_MAX). Trailing
	 * characters are still ignored, as atoll did. */
	errno = 0;
	val = strtoll(arg, &endptr, 10);
	if (endptr == arg) return (errno = EINVAL), -1;
	if (errno) return -1;
	*opt = (uint64_t)val;
	return 0;
}

/* Handle one "--name" argument at argv[*i].
 * "--help" and "--version" are rewritten in place to "help" / "version".
 * NOTE(review): the caller has already cleared this index in optmask, so
 * arg_pop() will not return it; confirm where help/version are consumed.
 * Boolean long options set their state flag. Numeric long options consume
 * the following argument via state_u64_set().
 * Returns 0 on success, or -1 with errno = EINVAL for an unknown option. */
static int state_parse_long_option(state_t *state, int argc, char *argv[], int *i)
{
	const char *lopt = argv[*i] + 2; /* skip the leading "--" */

	(void) argc; /* only referenced if STATE_OPTIONS_U64 has entries */
	if (!strcmp(lopt, "help") || !strcmp(lopt, "version")) {
		argv[*i] += 2;
		return 0;
	}
#define X(a, b, c) \
	if (!strcmp(lopt, a)) return (STATE_SET(state, c)), 0;
	STATE_OPTIONS_LONG
#undef X
#define X(a, b, c, d) \
	if (!strcmp(lopt, a)) return state_u64_set(state, argc, argv, i, &state->d, a);
	STATE_OPTIONS_U64
#undef X
	errno = EINVAL;
	return -1;
}

/* Handle a cluster of short options at argv[*i] (e.g. "-vq", "-i eth0").
 * 'i' consumes the next argument as an interface name or, failing that, a
 * numeric interface index. Other letters come from STATE_OPTIONS_SHORT.
 * A lone "-" is only allowed as the last argument. Its NUL terminator is
 * then dispatched through the switch once (TODO confirm
 * STATE_OPTIONS_SHORT has a '\0' entry).
 * Returns 0 on success, -1 with errno = EINVAL on a bad option or a missing
 * argument. */
static int state_parse_short_options(state_t *state, int argc, char *argv[], int *i)
{
	char *options = argv[*i] + 1;
	size_t optlen = strlen(options);
	if (!optlen) {
		/* "-" MUST be last argument */
		if (argc - 1 != *i) return (errno = EINVAL), -1;
		optlen++;
	}
	for (size_t z = 0; z < optlen; z++) {
		switch (options[z]) {
			case 'i':
				if (*i == argc - 1) {
					/* Name only this option letter. &options[z] would
					 * print the rest of the cluster too (e.g. 'iv'). */
					char opt[2] = { options[z], '\0' };
					return state_err_option_requires_arg(opt);
				}
				(*i)++; clrbit(state->optmask, *i);
				state->ifx = if_nametoindex(argv[*i]);
				if (state->ifx > 0) continue;
				else {
					char ifname[IF_NAMESIZE];
					state->ifx = atoi(argv[*i]);
					/* NOTE(review): result ignored, so an unknown index is
					 * not rejected here */
					if (state->ifx) if_indextoname(state->ifx, ifname);
				}
				break;
#define X(a, b, c) \
			case a: \
				STATE_SET(state, c); \
				continue;
				STATE_OPTIONS_SHORT
#undef X
			default:
				return (errno = EINVAL), -1;
		}
	}
	return 0;
}

/* Map a command word to state->verb using the STATE_VERBS table.
 * Returns 0 on a match, otherwise logs an error and returns -1 with
 * errno = EINVAL. */
static int state_parse_arg(state_t *state, char *arg)
{
#define X(a, b, c, d) \
	if (strcmp(arg, b) == 0) { state->verb = a; return 0; }
	STATE_VERBS
#undef X
	ERROR("unknown command '%s'\n", arg);
	errno = EINVAL;
	return -1;
}

/* Return the lowest-indexed argv entry still marked unconsumed in optmask,
 * and mark it consumed. Returns NULL when none remain. */
char *arg_pop(state_t *state)
{
	int idx = 1;

	while (idx < state->argc && !isset(state->optmask, idx))
		idx++;
	if (idx >= state->argc) return NULL;
	clrbit(state->optmask, idx);
	return state->argv[idx];
}

/* Parse the command line into state.
 * optmask keeps one bit per argv slot: set = not yet consumed. Options
 * clear their own bit (and any argument they consume). "--" stops option
 * processing. The first remaining unconsumed argument is then taken as the
 * verb.
 * Returns 0 on success, or -1 (errno set) on error. E2BIG means argc is too
 * large for optmask. */
int state_parse_args(state_t *state, int argc, char *argv[])
{
	const int masklen = sizeof state->optmask;
	char *verb;

	if (howmany(argc, CHAR_BIT) > masklen) {
		fprintf(stderr, "%s: %s\n", PACKAGE_NAME, strerror(E2BIG));
		errno = E2BIG;
		return -1;
	}
	memset(state->optmask, ~0, masklen);
	clrbit(state->optmask, 0); /* argv[0] is the program name */
	state->argc = argc;
	state->argv = argv;

	for (int n = 1; n < argc; n++) {
		const char *arg = argv[n];
		int rc;

		if (arg[0] != '-') continue; /* positional: left for arg_pop() */
		clrbit(state->optmask, n);
		if (strcmp(arg, "--") == 0) break; /* end of options */
		rc = (arg[1] == '-')
			? state_parse_long_option(state, argc, argv, &n)
			: state_parse_short_options(state, argc, argv, &n);
		if (rc) return rc;
	}

	verb = arg_pop(state);
	return verb ? state_parse_arg(state, verb) : 0;
}

/* Parse the config file at path configfile into state using the generated
 * parser. Returns 0 if the file opened and parsed cleanly, -1 otherwise
 * (errno from fopen() when the file cannot be opened). */
int state_parse_configfile(state_t *state, char *configfile)
{
	int rc = -1;
	yyin = fopen(configfile, "r");
	if (yyin != NULL) {
		/* If an earlier parse stopped on an error, the scanner still holds
		 * input buffered from that (now closed) file. Resetting onto the new
		 * yyin discards it so stale tokens are not parsed here. */
		yyrestart(yyin);
		rc = yyparse(state);
		fclose(yyin);
	}
	return (rc == 0) ? 0 : -1;
}

/* Parse len bytes of in-memory configuration text into state.
 * Returns 0 on a clean parse, or -1 on a parse error. Also returns -1 with
 * errno = E2BIG if len cannot be passed to the scanner's int-sized length. */
int state_parse_config(state_t *state, char *config, size_t len)
{
	YY_BUFFER_STATE buf;
	int rc;

	/* flex's yy_scan_bytes() takes an int length; a larger size_t would be
	 * truncated or turn negative */
	if (len > INT_MAX) return (errno = E2BIG), -1;
	buf = yy_scan_bytes(config, (int)len);
	rc = yyparse(state);
	yy_delete_buffer(buf);
	return (rc == 0) ? 0 : -1;
}

/* Return a newly allocated "<dir_state>/<PACKAGE_NAME>.pid"; free() after use.
 * Returns NULL with errno = ENOENT if dir_state is unset, or NULL on
 * formatting/allocation failure. */
char *state_pidfile(state_t *state)
{
	char *path;
	int need;

	if (state->dir_state == NULL) {
		errno = ENOENT;
		return NULL;
	}
	need = snprintf(NULL, 0, "%s/%s.pid", state->dir_state, PACKAGE_NAME);
	if (need < 0) return NULL;
	path = malloc(need + 1);
	if (path != NULL &&
	    snprintf(path, need + 1, "%s/%s.pid", state->dir_state, PACKAGE_NAME) != need) {
		free(path);
		path = NULL;
	}
	return path;
}

/* Reset state->defaults to the state_channel_defaults template.
 * Any pointers previously held in state->defaults are overwritten without
 * being freed. Always returns 0. */
int state_defaults_set(state_t *state)
{
	memcpy(&state->defaults, &state_channel_defaults, sizeof state->defaults);
	return 0;
}
