/*  Copyright (C) 2015-2017 CZ.NIC, z.s.p.o. <knot-dns@labs.nic.cz>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
 */

#include <assert.h>
#include <libdnssec/binary.h>
#include <libdnssec/crypto.h>
#include <libdnssec/error.h>
#include <libdnssec/key.h>
#include <libdnssec/sign.h>
#include <libknot/descriptor.h>
#include <libknot/packet/wire.h>
#include <libknot/rdataset.h>
#include <libknot/rrset.h>
#include <libknot/rrtype/dnskey.h>
#include <libknot/rrtype/nsec.h>
#include <libknot/rrtype/rrsig.h>
#include <contrib/wire.h>

#include "lib/defines.h"
#include "lib/dnssec/nsec.h"
#include "lib/dnssec/nsec3.h"
#include "lib/dnssec/signature.h"
#include "lib/dnssec.h"
#include "lib/resolve.h"

/* Forward declaration: kr_rrset_validate() calls this helper before its
 * definition appears further down in this file. */
static int kr_rrset_validate_with_key(kr_rrset_validation_ctx_t *vctx,
	const knot_rrset_t *covered, size_t key_pos, const struct dseckey *key);

/** Initialize the crypto layer; thin wrapper that only calls dnssec_crypto_init(). */
void kr_crypto_init(void)
{
	dnssec_crypto_init();
}

/** Tear down the crypto layer; thin wrapper that only calls dnssec_crypto_cleanup(). */
void kr_crypto_cleanup(void)
{
	dnssec_crypto_cleanup();
}

/**
 * Re-initialize the crypto layer; thin wrapper that only calls dnssec_crypto_reinit().
 * NOTE(review): presumably meant to be called in a child process after fork() --
 * confirm against the libdnssec documentation.
 */
void kr_crypto_reinit(void)
{
	dnssec_crypto_reinit();
}

/* Bit stored into the `flags` output of validate_rrsig_rr() when the RRSIG
 * Labels field is smaller than the label count of the covered owner name. */
#define FLG_WILDCARD_EXPANSION 0x01 /**< Possibly generated by using wildcard expansion. */

/**
 * Run the per-signature sanity checks of RFC4035 5.3.1 on a single RRSIG.
 *
 * The checks are performed in a fixed order and the first failing one
 * determines the return value.
 *
 * @param flags      Output; FLG_WILDCARD_EXPANSION is OR-ed in when the RRSIG
 *                   labels field is smaller than @a cov_labels.
 * @param cov_labels Covered RRSet owner label count.
 * @param rrsigs     rdata of the RRSIG being checked.
 * @param key_owner  Associated DNSKEY's owner.
 * @param key_rdata  Associated DNSKEY's rdata.
 * @param keytag     Key tag of the DNSKEY.
 * @param zone_name  The name of the zone cut.
 * @param timestamp  Validation time.
 * @return kr_ok() when all checks pass, kr_error(EAGAIN) when the signer name
 *         is missing or differs from @a zone_name, kr_error(EINVAL) otherwise.
 */
static int validate_rrsig_rr(int *flags, int cov_labels,
                             const knot_rdata_t *rrsigs,
                             const knot_dname_t *key_owner, const knot_rdata_t *key_rdata,
                             uint16_t keytag,
                             const knot_dname_t *zone_name, uint32_t timestamp)
{
	const bool args_ok = flags && rrsigs && key_owner && key_rdata && zone_name;
	if (!args_ok) {
		return kr_error(EINVAL);
	}

	/* bullets 5 and 6: the validation time has to lie inside the
	 * [inception, expiration] window of the signature. */
	if (timestamp > knot_rrsig_sig_expiration(rrsigs) ||
	    timestamp < knot_rrsig_sig_inception(rrsigs)) {
		return kr_error(EINVAL);
	}

	/* bullet 2: the signer has to be the zone we believe we are in. */
	const knot_dname_t *signer = knot_rrsig_signer_name(rrsigs);
	if (signer == NULL || !knot_dname_is_equal(signer, zone_name)) {
		return kr_error(EAGAIN);
	}

	/* bullet 4: the signature may not claim more labels than the owner has;
	 * fewer labels indicate a possible wildcard expansion. */
	const int sig_labels = knot_rrsig_labels(rrsigs);
	if (sig_labels > cov_labels) {
		return kr_error(EINVAL);
	}
	if (sig_labels != cov_labels) {
		*flags |= FLG_WILDCARD_EXPANSION;
	}

	/* bullet 7: owner, algorithm and key tag must all agree with the DNSKEY. */
	const bool key_matches = knot_dname_is_equal(key_owner, signer)
		&& knot_dnskey_alg(key_rdata) == knot_rrsig_alg(rrsigs)
		&& keytag == knot_rrsig_key_tag(rrsigs);
	if (!key_matches) {
		return kr_error(EINVAL);
	}

	/* bullet 8 is checked somewhere else.
	 * bullets 9 and 10: one of the requirements should be always fulfilled. */
	return kr_ok();
}

/**
 * Compute how many labels of an owner name exceed the RRSIG labels field,
 * i.e. the number of labels that have been added by wildcard expansion.
 * @param expanded Owner name of the (possibly wildcard-expanded) RRSet.
 * @param rrsig    rdata of the RRSIG covering that RRSet.
 * @return         Label count of @a expanded minus the RRSIG labels field,
 *                 or -1 if either argument is NULL.
 */
static inline int wildcard_radix_len_diff(const knot_dname_t *expanded,
					  const knot_rdata_t *rrsig)
{
	if (expanded == NULL || rrsig == NULL) {
		return -1;
	}

	const int owner_labels = knot_dname_labels(expanded, NULL);
	const int signed_labels = knot_rrsig_labels(rrsig);
	return owner_labels - signed_labels;
}

/**
 * Try to validate an RRSet against every key of the DNSKEY RRSet in the context.
 * @param vctx    Validation context; pkt, keys and zone_name must be set.
 * @param covered RRSet to be validated.
 * @return 0 as soon as one key validates the RRSet, kr_error(EINVAL) on
 *         missing arguments, kr_error(ENOENT) when no key succeeded.
 */
int kr_rrset_validate(kr_rrset_validation_ctx_t *vctx, const knot_rrset_t *covered)
{
	const bool args_ok = vctx && vctx->pkt && covered
			     && vctx->keys && vctx->zone_name;
	if (!args_ok) {
		return kr_error(EINVAL);
	}

	/* Each key is instantiated from its rdata by the helper (key == NULL). */
	const unsigned key_count = vctx->keys->rrs.count;
	unsigned key_pos = 0;
	while (key_pos < key_count) {
		if (kr_rrset_validate_with_key(vctx, covered, key_pos, NULL) == 0) {
			return 0;
		}
		++key_pos;
	}

	return kr_error(ENOENT);
}

/**
 * Validate RRSet using a specific key.
 *
 * Walks all RRSIG RRSets in vctx->rrs that share owner and class with
 * @a covered and tries each signature covering its type until one verifies.
 *
 * @param vctx    Pointer to validation context.
 * @param covered RRSet covered by a signature. It must be in canonical format.
 * @param key_pos Position of the key to be validated with.
 * @param key     Key to be used to validate.
 *		  If NULL, then key from DNSKEY RRSet is used.
 * @return        0 or error code, same as vctx->result:
 *                kr_error(EAGAIN) if a signature's signer differs from the
 *                zone name, kr_error(ENOENT) if nothing validated, or the
 *                error from creating the key.
 */
static int kr_rrset_validate_with_key(kr_rrset_validation_ctx_t *vctx,
				const knot_rrset_t *covered,
				size_t key_pos, const struct dseckey *key)
{
	const knot_pkt_t *pkt         = vctx->pkt;
	const knot_rrset_t *keys      = vctx->keys;
	const knot_dname_t *zone_name = vctx->zone_name;
	uint32_t timestamp            = vctx->timestamp;
	bool has_nsec3                = vctx->has_nsec3;
	/* Non-NULL only when the key was instantiated here and must be freed here. */
	struct dseckey *created_key   = NULL;

	/* It's just caller's approximation that the RR is in that particular zone.
	 * We MUST guard against attempts of zones signing out-of-bailiwick records. */
	if (knot_dname_in_bailiwick(covered->owner, zone_name) < 0) {
		vctx->result = kr_error(ENOENT);
		return vctx->result;
	}

	const knot_rdata_t *key_rdata = knot_rdataset_at(&keys->rrs, key_pos);
	if (key == NULL) {
		int ret = kr_dnssec_key_from_rdata(&created_key, keys->owner,
						   key_rdata->data, key_rdata->len);
		if (ret != 0) {
			vctx->result = ret;
			return vctx->result;
		}
		key = created_key;
	}
	uint16_t keytag = dnssec_key_get_keytag((dnssec_key_t *)key);
	int covered_labels = knot_dname_labels(covered->owner, NULL);
	if (knot_dname_is_wildcard(covered->owner)) {
		/* The asterisk does not count, RFC4034 3.1.3, paragraph 3. */
		--covered_labels;
	}

	/* Default outcome: no signature validated with this key. */
	int result = kr_error(ENOENT);

	/* The index has the same width as the array length, so it cannot wrap
	 * around before reaching the end of a large array. */
	for (size_t i = 0; i < vctx->rrs->len; ++i) {
		/* Consider every RRSIG that matches owner and covers the class/type. */
		const knot_rrset_t *rrsig = vctx->rrs->at[i]->rr;
		if (rrsig->type != KNOT_RRTYPE_RRSIG) {
			continue;
		}
		if ((covered->rclass != rrsig->rclass) || !knot_dname_is_equal(covered->owner, rrsig->owner)) {
			continue;
		}
		knot_rdata_t *rdata_j = rrsig->rrs.rdata;
		for (uint16_t j = 0; j < rrsig->rrs.count; ++j, rdata_j = knot_rdataset_next(rdata_j)) {
			int val_flgs = 0;
			int trim_labels = 0;
			if (knot_rrsig_type_covered(rdata_j) != covered->type) {
				continue;
			}
			int ret = validate_rrsig_rr(&val_flgs, covered_labels, rdata_j,
			                      keys->owner, key_rdata, keytag,
			                      zone_name, timestamp);
			if (ret == kr_error(EAGAIN)) {
				/* Signer name mismatch: report it to the caller right away. */
				result = ret;
				goto finish;
			} else if (ret != 0) {
				continue;
			}
			if (val_flgs & FLG_WILDCARD_EXPANSION) {
				trim_labels = wildcard_radix_len_diff(covered->owner, rdata_j);
				if (trim_labels < 0) {
					/* Give up on this RRSIG RRSet, try the next one. */
					break;
				}
			}
			if (kr_check_signature(rdata_j, (dnssec_key_t *) key, covered, trim_labels) != 0) {
				continue;
			}
			if (val_flgs & FLG_WILDCARD_EXPANSION) {
				/* A wildcard-expanded answer is accepted only if the
				 * NSEC/NSEC3 check on the authority section passes too. */
				int wc_ret = 0;
				if (!has_nsec3) {
					wc_ret = kr_nsec_wildcard_answer_response_check(pkt, KNOT_AUTHORITY, covered->owner);
				} else {
					wc_ret = kr_nsec3_wildcard_answer_response_check(pkt, KNOT_AUTHORITY, covered->owner, trim_labels - 1);
					if (wc_ret == kr_error(KNOT_ERANGE)) {
						wc_ret = 0;
						vctx->flags |= KR_DNSSEC_VFLG_OPTOUT;
					}
				}
				if (wc_ret != 0) {
					continue;
				}
				vctx->flags |= KR_DNSSEC_VFLG_WEXPAND;
			}
			/* Validated with current key, OK */
			result = kr_ok();
			goto finish;
		}
	}
	/* Falling out of the loops: no applicable signature found, cannot be validated. */

finish:
	kr_dnssec_key_free(&created_key);
	vctx->result = result;
	return result;
}

/**
 * Tell whether at least one DS record in the set uses both a digest type
 * and a key algorithm that libdnssec reports as supported.
 * @param ta DS RRSet to inspect.
 * @return true if such a DS record exists, false otherwise.
 */
static bool kr_ds_algo_support(const knot_rrset_t *ta)
{
	knot_rdata_t *ds = ta->rrs.rdata;
	for (uint16_t left = ta->rrs.count; left > 0;
			--left, ds = knot_rdataset_next(ds)) {
		if (!dnssec_algorithm_digest_support(knot_ds_digest_type(ds))) {
			continue;
		}
		if (dnssec_algorithm_key_support(knot_ds_alg(ds))) {
			return true;
		}
	}
	return false;
}

/**
 * Check whether the DNSKEY RRSet in the context can be trusted via a DS RRSet.
 *
 * A key is accepted when it has the zone-key flag set, is not revoked, is
 * authenticated by @a ta and validates the DNSKEY RRSet itself.
 *
 * @param vctx Validation context; pkt and keys must be set. The outcome is
 *             also stored in vctx->result (except for invalid arguments).
 * @param ta   Trust anchor: a non-empty DS RRSet.
 * @return 0 on success, kr_error(EINVAL) on invalid arguments,
 *         kr_error(DNSSEC_INVALID_DS_ALGORITHM) if no DS uses a supported
 *         algorithm pair, kr_error(ENOENT) if no usable key was found.
 */
int kr_dnskeys_trusted(kr_rrset_validation_ctx_t *vctx, const knot_rrset_t *ta)
{
	/* The context pointer is validated together with everything else,
	 * before any of its members is read. */
	const bool ok = vctx && vctx->pkt && vctx->keys
			&& ta && ta->rrs.count && ta->rrs.rdata
			&& ta->type == KNOT_RRTYPE_DS;
	if (!ok) {
		assert(false);
		return kr_error(EINVAL);
	}
	const knot_rrset_t *keys = vctx->keys;

	/* Check if at least one DS has a usable algorithm pair. */
	if (!kr_ds_algo_support(ta)) {
		/* See RFC6840 5.2. */
		return vctx->result = kr_error(DNSSEC_INVALID_DS_ALGORITHM);
	}

	/* RFC4035 5.2, bullet 1
	 * The supplied DS record has been authenticated.
	 * It has been validated or is part of a configured trust anchor.
	 */
	for (uint16_t i = 0; i < keys->rrs.count; ++i) {
		/* RFC4035 5.3.1, bullet 8 */ /* ZSK */
		/* LATER(optim.): more efficient way to iterate than _at() */
		const knot_rdata_t *krr = knot_rdataset_at(&keys->rrs, i);
		if (!kr_dnssec_key_zsk(krr->data) || kr_dnssec_key_revoked(krr->data)) {
			continue;
		}

		struct dseckey *key = NULL;
		if (kr_dnssec_key_from_rdata(&key, keys->owner, krr->data, krr->len) != 0) {
			continue;
		}
		/* The key must both match the DS and sign the DNSKEY RRSet. */
		int ret = kr_authenticate_referral(ta, (dnssec_key_t *) key);
		if (ret == 0) {
			ret = kr_rrset_validate_with_key(vctx, keys, i, key);
		}
		kr_dnssec_key_free(&key);
		if (ret != 0) {
			continue;
		}
		assert (vctx->result == 0);
		return vctx->result;
	}

	/* No useable key found */
	vctx->result = kr_error(ENOENT);
	return vctx->result;
}

/**
 * Return true if bit 0x0100 is set in the 16-bit flags word at the start of
 * the DNSKEY rdata (the Zone Key flag, RFC 4034 section 2.1.1).
 */
bool kr_dnssec_key_zsk(const uint8_t *dnskey_rdata)
{
	const uint16_t flags = wire_read_u16(dnskey_rdata);
	return (flags & 0x0100) != 0;
}

/**
 * Return true if bit 0x0001 is set in the 16-bit flags word at the start of
 * the DNSKEY rdata (the Secure Entry Point flag, RFC 4034 section 2.1.1).
 */
bool kr_dnssec_key_ksk(const uint8_t *dnskey_rdata)
{
	const uint16_t flags = wire_read_u16(dnskey_rdata);
	return (flags & 0x0001) != 0;
}

/**
 * Return true if the DNSKEY is revoked, i.e. bit 0x0080 is set in the
 * 16-bit flags word at the start of the DNSKEY rdata.
 */
bool kr_dnssec_key_revoked(const uint8_t *dnskey_rdata)
{
	const uint16_t flags = wire_read_u16(dnskey_rdata);
	return (flags & 0x0080) != 0;
}

/**
 * Obtain the key tag from DS or DNSKEY rdata.
 * @param rrtype KNOT_RRTYPE_DS or KNOT_RRTYPE_DNSKEY.
 * @param rdata  Record data in wire format.
 * @param rdlen  Length of @a rdata in octets.
 * @return Key tag (0..65535) on success; kr_error(EINVAL) for missing or too
 *         short data or an unsupported type; for DNSKEY also any error
 *         returned by kr_dnssec_key_from_rdata().
 */
int kr_dnssec_key_tag(uint16_t rrtype, const uint8_t *rdata, size_t rdlen)
{
	if (!rdata || rdlen == 0) {
		return kr_error(EINVAL);
	}

	switch (rrtype) {
	case KNOT_RRTYPE_DS:
		/* The tag is read as a 16-bit value from the start of the rdata,
		 * so at least two octets must be present. */
		if (rdlen < sizeof(uint16_t)) {
			return kr_error(EINVAL);
		}
		return wire_read_u16(rdata);
	case KNOT_RRTYPE_DNSKEY: {
		/* The tag is not stored in DNSKEY rdata; let libdnssec compute it. */
		struct dseckey *key = NULL;
		int ret = kr_dnssec_key_from_rdata(&key, NULL, rdata, rdlen);
		if (ret != 0) {
			return ret;
		}
		uint16_t keytag = dnssec_key_get_keytag((dnssec_key_t *)key);
		kr_dnssec_key_free(&key);
		return keytag;
	}
	default:
		return kr_error(EINVAL);
	}
}

/**
 * Compare two DNSKEYs given as rdata.
 * @param key_a_rdata First DNSKEY rdata.
 * @param key_a_rdlen Length of the first rdata.
 * @param key_b_rdata Second DNSKEY rdata.
 * @param key_b_rdlen Length of the second rdata.
 * @return 0 if algorithm and public key are equal, kr_error(ENOENT) if they
 *         differ (or a public key cannot be obtained), or the error from
 *         kr_dnssec_key_from_rdata() if either key cannot be built.
 */
int kr_dnssec_key_match(const uint8_t *key_a_rdata, size_t key_a_rdlen,
                        const uint8_t *key_b_rdata, size_t key_b_rdlen)
{
	dnssec_key_t *first = NULL;
	dnssec_key_t *second = NULL;

	int ret = kr_dnssec_key_from_rdata((struct dseckey **)&first, NULL,
					   key_a_rdata, key_a_rdlen);
	if (ret != 0) {
		return ret;
	}
	ret = kr_dnssec_key_from_rdata((struct dseckey **)&second, NULL,
				       key_b_rdata, key_b_rdlen);
	if (ret != 0) {
		dnssec_key_free(first);
		return ret;
	}

	/* If the algorithm and the public key match, we can be sure
	 * that they are the same key. */
	bool same = false;
	if (dnssec_key_get_algorithm(first) == dnssec_key_get_algorithm(second)) {
		dnssec_binary_t pub_first, pub_second;
		if (dnssec_key_get_pubkey(first, &pub_first) == DNSSEC_EOK &&
		    dnssec_key_get_pubkey(second, &pub_second) == DNSSEC_EOK) {
			same = pub_first.size == pub_second.size
			       && memcmp(pub_first.data, pub_second.data, pub_first.size) == 0;
		}
	}

	dnssec_key_free(first);
	dnssec_key_free(second);
	return same ? 0 : kr_error(ENOENT);
}

/**
 * Build a key object from DNSKEY rdata.
 * @param key   Output; receives the new key on success and is left untouched
 *              on failure. Release it with kr_dnssec_key_free().
 * @param kown  Key owner name to attach to the key; may be NULL to skip that.
 * @param rdata DNSKEY rdata.
 * @param rdlen Length of @a rdata; must be non-zero.
 * @return kr_ok() on success, kr_error(EINVAL) on bad arguments,
 *         kr_error(ENOMEM) if the key cannot be allocated or the owner cannot
 *         be set, or the libdnssec error if the rdata is rejected.
 */
int kr_dnssec_key_from_rdata(struct dseckey **key, const knot_dname_t *kown, const uint8_t *rdata, size_t rdlen)
{
	if (key == NULL || rdata == NULL || rdlen == 0) {
		return kr_error(EINVAL);
	}

	dnssec_key_t *fresh = NULL;
	if (dnssec_key_new(&fresh) != DNSSEC_EOK) {
		return kr_error(ENOMEM);
	}

	/* Wrap the caller's buffer without copying it. */
	const dnssec_binary_t wire = {
		.size = rdlen,
		.data = (uint8_t *)rdata
	};
	const int rdata_ret = dnssec_key_set_rdata(fresh, &wire);
	if (rdata_ret != DNSSEC_EOK) {
		dnssec_key_free(fresh);
		return kr_error(rdata_ret);
	}

	if (kown != NULL && dnssec_key_set_dname(fresh, kown) != DNSSEC_EOK) {
		dnssec_key_free(fresh);
		return kr_error(ENOMEM);
	}

	*key = (struct dseckey *) fresh;
	return kr_ok();
}

void kr_dnssec_key_free(struct dseckey **key)
{
	assert(key);

	dnssec_key_free((dnssec_key_t *) *key);
	*key = NULL;
}

/**
 * Search the ranked RR array for a secure NSEC or NSEC3 RRSet that matches
 * the given name and type.
 *
 * Only entries that belong to @a qry_uid, are not yielded and pass
 * kr_rank_test(rank, KR_RANK_SECURE) are examined.
 *
 * @param rrs     Ranked RR array to search.
 * @param qry_uid Query uid the entries must belong to.
 * @param name    Name passed on to the NSEC/NSEC3 matcher.
 * @param type    Type passed on to the NSEC/NSEC3 matcher.
 * @return kr_ok() on the first match, kr_error(ENOENT) if nothing matched,
 *         or the first other error reported by a matcher.
 */
int kr_dnssec_matches_name_and_type(const ranked_rr_array_t *rrs, uint32_t qry_uid,
				    const knot_dname_t *name, uint16_t type)
{
	for (size_t pos = 0; pos < rrs->len; ++pos) {
		const ranked_rr_array_entry_t *item = rrs->at[pos];
		const knot_rrset_t *rr = item->rr;

		if (item->qry_uid != qry_uid || item->yielded) {
			continue;
		}
		const bool is_nsec = (rr->type == KNOT_RRTYPE_NSEC);
		if (!is_nsec && rr->type != KNOT_RRTYPE_NSEC3) {
			continue;
		}
		if (!kr_rank_test(item->rank, KR_RANK_SECURE)) {
			continue;
		}

		const int ret = is_nsec
			? kr_nsec_matches_name_and_type(rr, name, type)
			: kr_nsec3_matches_name_and_type(rr, name, type);
		/* Anything but "not found" ends the search: success or a hard error. */
		if (ret != kr_error(ENOENT)) {
			return ret;
		}
	}
	return kr_error(ENOENT);
}
