# Copyright (C) Internet Systems Consortium, Inc. ("ISC")
#
# SPDX-License-Identifier: MPL-2.0
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0.  If a copy of the MPL was not distributed with this
# file, you can obtain one at https://mozilla.org/MPL/2.0/.
#
# See the COPYRIGHT file distributed with this work for additional
# information regarding copyright ownership.

from typing import cast

import difflib
import shutil

from dns.edns import EDECode, EDEOption

import dns.edns
import dns.flags
import dns.message
import dns.rcode
import dns.rrset
import dns.zone

import isctest.log


def rcode(message: dns.message.Message, expected_rcode) -> None:
    """Assert that the message's rcode equals expected_rcode.

    On failure the assertion message is the text form of the whole message.
    """
    actual_rcode = message.rcode()
    assert actual_rcode == expected_rcode, str(message)


def noerror(message: dns.message.Message) -> None:
    """Assert that the message's rcode is NOERROR."""
    rcode(message, expected_rcode=dns.rcode.NOERROR)


def notimp(message: dns.message.Message) -> None:
    """Assert that the message's rcode is NOTIMP."""
    rcode(message, expected_rcode=dns.rcode.NOTIMP)


def refused(message: dns.message.Message) -> None:
    """Assert that the message's rcode is REFUSED."""
    rcode(message, expected_rcode=dns.rcode.REFUSED)


def servfail(message: dns.message.Message) -> None:
    """Assert that the message's rcode is SERVFAIL."""
    rcode(message, expected_rcode=dns.rcode.SERVFAIL)


def adflag(message: dns.message.Message) -> None:
    """Assert that the AD flag bit is set in the message flags."""
    ad_bit = message.flags & dns.flags.AD
    assert ad_bit, str(message)


def noadflag(message: dns.message.Message) -> None:
    """Assert that the AD flag bit is clear in the message flags."""
    ad_bit = message.flags & dns.flags.AD
    assert not ad_bit, str(message)


def rdflag(message: dns.message.Message) -> None:
    """Assert that the RD flag bit is set in the message flags."""
    rd_bit = message.flags & dns.flags.RD
    assert rd_bit, str(message)


def nordflag(message: dns.message.Message) -> None:
    """Assert that the RD flag bit is clear in the message flags."""
    rd_bit = message.flags & dns.flags.RD
    assert not rd_bit, str(message)


def raflag(message: dns.message.Message) -> None:
    """Assert that the RA flag bit is set in the message flags."""
    ra_bit = message.flags & dns.flags.RA
    assert ra_bit, str(message)


def noraflag(message: dns.message.Message) -> None:
    """Assert that the RA flag bit is clear in the message flags."""
    ra_bit = message.flags & dns.flags.RA
    assert not ra_bit, str(message)


def _extract_ede_options(message: dns.message.Message) -> list[EDEOption]:
    """Return the options of the DNS message whose option type is EDE.

    The order of the options in the message is preserved.
    """
    ede_options = []
    for option in message.options:
        if option.otype == dns.edns.OptionType.EDE:
            ede_options.append(option)
    # cast only informs the type checker; the list is returned unchanged
    return cast(list[EDEOption], ede_options)


def noede(message: dns.message.Message) -> None:
    """Assert that the message carries no EDE option at all."""
    found = _extract_ede_options(message)
    assert len(found) == 0, f"unexpected EDE options {found} in {message}"


def ede(message: dns.message.Message, code: EDECode, text: str | None = None) -> None:
    """Assert that the message contains an EDE option with the given code.

    When text is given, at least one of the options with that code must also
    carry exactly that text; when it is None, only the code is checked.
    """
    candidates = [
        option for option in _extract_ede_options(message) if option.code == code
    ]
    assert candidates, f"missing EDE code {code} in {message}"

    # a matching code alone is sufficient when no text was requested
    if text is None or any(option.text == text for option in candidates):
        return

    listing = ", ".join(option.to_text() for option in candidates)
    assert False, f'EDE text "{text}" not found in [{listing}]'


def section_equal(first_section: list, second_section: list) -> None:
    """Assert that each section contains every entry of the other.

    Membership is tested with ``in``, so ordering is ignored. The first
    section is checked against the second one before the reverse direction,
    and the first missing entry is reported in the assertion message.
    """
    absent_from_second = [
        rrset for rrset in first_section if rrset not in second_section
    ]
    assert (
        not absent_from_second
    ), f"No corresponding RRset found in second section: {absent_from_second[0]}"

    absent_from_first = [
        rrset for rrset in second_section if rrset not in first_section
    ]
    assert (
        not absent_from_first
    ), f"No corresponding RRset found in first section: {absent_from_first[0]}"


def same_data(res1: dns.message.Message, res2: dns.message.Message):
    """Assert that two messages agree on all four sections and on the rcode.

    The sections are compared with section_equal() in the order question,
    answer, authority, additional; the rcode is compared last.
    """
    for section_name in ("question", "answer", "authority", "additional"):
        section_equal(getattr(res1, section_name), getattr(res2, section_name))
    first_rcode, second_rcode = res1.rcode(), res2.rcode()
    assert first_rcode == second_rcode


def same_answer(res1: dns.message.Message, res2: dns.message.Message):
    """Assert that two messages agree on question, answer and rcode.

    The authority and additional sections are not compared.
    """
    for section_name in ("question", "answer"):
        section_equal(getattr(res1, section_name), getattr(res2, section_name))
    first_rcode, second_rcode = res1.rcode(), res2.rcode()
    assert first_rcode == second_rcode


def rrsets_equal(
    first_rrset: dns.rrset.RRset,
    second_rrset: dns.rrset.RRset,
    compare_ttl: bool | None = False,
) -> None:
    """Assert that two RRsets hold equal records (optionally including TTL).

    Every item obtained by iterating one RRset must compare equal to some
    item of the other one, in both directions. Both RRsets are written to
    the debug log before the comparison starts.

    NOTE(review): with compare_ttl set, ``.ttl`` is read from the iterated
    items themselves - confirm that the objects passed in yield items which
    expose a ``ttl`` attribute.
    """

    def assert_has_match(record, candidates):
        # pick the first candidate that compares equal to the record
        match = None
        for candidate in candidates:
            if record == candidate:
                match = candidate
                break
        assert match is not None, f"No corresponding RR found for: {record}"
        if compare_ttl:
            assert record.ttl == match.ttl

    first_dump = "\n".join(str(rr) for rr in first_rrset)
    isctest.log.debug(
        "%s() first RRset:\n%s",
        rrsets_equal.__name__,
        first_dump,
    )
    second_dump = "\n".join(str(rr) for rr in second_rrset)
    isctest.log.debug(
        "%s() second RRset:\n%s",
        rrsets_equal.__name__,
        second_dump,
    )

    # check both directions so that extra records on either side are caught
    for record in first_rrset:
        assert_has_match(record, second_rrset)
    for record in second_rrset:
        assert_has_match(record, first_rrset)


def zones_equal(
    first_zone: dns.zone.Zone,
    second_zone: dns.zone.Zone,
    compare_ttl: bool | None = False,
) -> None:
    """Assert that two zones are equal (optionally including TTL).

    Both zones are written to the debug log in non-relativized text form,
    then compared with ``==``. With compare_ttl set, every rdataset of the
    first zone is additionally looked up in the second zone by owner name
    and rdtype, and the TTLs of the two rdatasets must match.
    """

    first_text = first_zone.to_text(relativize=False)
    isctest.log.debug(
        "%s() first zone:\n%s",
        zones_equal.__name__,
        first_text,
    )
    second_text = second_zone.to_text(relativize=False)
    isctest.log.debug(
        "%s() second zone:\n%s",
        zones_equal.__name__,
        second_text,
    )

    assert first_zone == second_zone
    if not compare_ttl:
        return

    # TTLs are verified separately, rdataset by rdataset
    for owner, node in first_zone.nodes.items():
        for expected in node:
            actual = second_zone.find_rdataset(name=owner, rdtype=expected.rdtype)
            assert actual
            assert actual.ttl == expected.ttl


def is_executable(cmd: str, errmsg: str) -> None:
    """Assert that shutil.which() can locate cmd; fail with errmsg otherwise."""
    assert shutil.which(cmd) is not None, errmsg


def named_alive(named_proc, resolver_ip):
    """Assert that named is still running and answers a version.bind query.

    named_proc.poll() must return None (presumably a subprocess.Popen-like
    object - confirm against callers). A "version.bind" TXT CH query is then
    sent over TCP to resolver_ip with NOERROR as the expected rcode.

    NOTE(review): this relies on isctest.query being reachable through the
    isctest package; verify it is imported wherever this module is loaded.
    """
    exit_status = named_proc.poll()
    assert exit_status is None, "named isn't running"
    version_query = isctest.query.create("version.bind", "TXT", "CH")
    isctest.query.tcp(version_query, resolver_ip, expected_rcode=dns.rcode.NOERROR)


def notauth(message: dns.message.Message) -> None:
    """Assert that the message's rcode is NOTAUTH."""
    rcode(message, expected_rcode=dns.rcode.NOTAUTH)


def nxdomain(message: dns.message.Message) -> None:
    """Assert that the message's rcode is NXDOMAIN."""
    rcode(message, expected_rcode=dns.rcode.NXDOMAIN)


def single_question(message: dns.message.Message) -> None:
    """Assert that the question section of the message has exactly one entry."""
    question_count = len(message.question)
    assert question_count == 1, str(message)


def empty_answer(message: dns.message.Message) -> None:
    """Assert that the answer section of the message is empty."""
    answer_section = message.answer
    assert not answer_section, str(message)


def rr_count_eq(section: list, expected: int):
    """Assert that the lengths of all RRsets in section add up to expected."""
    # NOTE: OPT and TSIG records aren't included in the count for ADDITIONAL section
    total = 0
    for rrset in section:
        total += len(rrset)
    assert total == expected, str(section)


def is_response_to(response: dns.message.Message, query: dns.message.Message) -> None:
    """Assert that response is a response to query.

    Both messages must contain exactly one question (the response is checked
    first), and query.is_response(response) must hold.
    """
    for msg in (response, query):
        single_question(msg)
    assert query.is_response(response), str(response)


def file_contents_equal(file1, file2) -> None:
    """Assert that two text files are equal, ignoring whitespace differences.

    Both files are read as UTF-8. Before the comparison, every line is
    normalized: leading and trailing whitespace is removed and each run of
    whitespace inside the line is collapsed into a single space.

    Raises AssertionError when the normalized contents differ; the message
    names both files and includes a unified diff of the normalized lines so
    that the mismatch can be located without re-running the comparison.
    """

    def normalize_line(line):
        # remove trailing&leading whitespace and replace multiple whitespaces
        return " ".join(line.split())

    def read_lines(file_path):
        with open(file_path, "r", encoding="utf-8") as file:
            return [normalize_line(line) for line in file]

    lines1 = read_lines(file1)
    lines2 = read_lines(file2)

    def render_diff():
        # only called while building the failure message, so identical files
        # never pay for the diff computation
        return "\n".join(
            difflib.unified_diff(
                lines1,
                lines2,
                fromfile=str(file1),
                tofile=str(file2),
                lineterm="",
            )
        )

    assert (
        lines1 == lines2
    ), f'file contents of "{file1}" and "{file2}" differ:\n{render_diff()}'
