# authentic2 - versatile identity manager
# Copyright (C) 2010-2019 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import logging
import re

import defusedxml.ElementTree as ET
import requests
from django.contrib.contenttypes.models import ContentType
from django.core.management.base import BaseCommand
from django.db.models import Q
from django.utils.encoding import force_text

from authentic2.a2_rbac.models import OrganizationalUnit
from authentic2.attributes_ng.engine import get_service_attributes
from authentic2.compat_lasso import lasso
from authentic2.saml.models import (
    LibertyProvider,
    LibertyServiceProvider,
    SAMLAttribute,
    SPOptionsIdPPolicy,
)

# Identifiers (first item of each entry) of the name formats declared on
# SAMLAttribute; listed in the help text of the -A/--attribute option.
ATTRIBUTE_NAME_FORMATS = [choice[0] for choice in SAMLAttribute.ATTRIBUTE_NAME_FORMATS]
# Default value of -t/--timeout, in seconds, used when downloading metadata.
DEFAULT_REQUESTS_TIMEOUT = 5


class Command(BaseCommand):
    """Management command registering a SAML service provider from its metadata.

    The metadata is read from a local file (-M/--metadata) or downloaded from
    a URL (-U/--URL). Once it has been checked to be an EntityDescriptor
    holding an SPSSODescriptor, a LibertyProvider and a LibertyServiceProvider
    are saved, followed by one SAMLAttribute per -A/--attribute option.
    """

    # Field layouts accepted by -A/--attribute, keyed by the number of
    # ":"-separated items found in the specification.
    _ATTRIBUTE_SPEC_FIELDS = {
        1: ("attribute_name",),
        2: ("name", "attribute_name"),
        3: ("name", "friendly_name", "attribute_name"),
        4: ("name_format", "name", "friendly_name", "attribute_name"),
        5: ("name_format", "name", "friendly_name", "attribute_name", "enabled"),
    }

    def add_arguments(self, parser):
        """Declare the command line options on *parser*."""
        parser.add_argument("-n", "--name", help="Service name")
        parser.add_argument("-s", "--slug", help="Service slug")
        parser.add_argument("-U", "--URL", help="Service metadata URL")
        parser.add_argument("-M", "--metadata", help="Service metadata file path")
        parser.add_argument("-O", "--ou", help="Service Organizational Unit name (or slug)")
        parser.add_argument(
            "-S",
            "--sp-options-policy",
            dest="sp_options_policy",
            help="SP options policy name",
        )
        parser.add_argument(
            "-E",
            "--enable-following-sp-options-policy",
            action="store_true",
            dest="enable_fspop",
            help="Enable following SP options policy",
        )

        parser.add_argument(
            "-A",
            "--attribute",
            action="append",
            dest="attributes",
            help=(
                'Attributes specifications separated by ":". Could be: just the attribute name,'
                " name:attribute_name, name:friendly_name:attribute_name,"
                " name_format:name:friendly_name:attribute_name or"
                " name_format:name:friendly_name:attribute_name:enabled. Available attributes: %s."
                " Field name_format must be one of following : %s. Field enabled must be true or "
                "false"
            )
            % (
                ",".join(self.get_exiting_attributes()),
                ",".join(ATTRIBUTE_NAME_FORMATS),
            ),
        )

        parser.add_argument(
            "-N",
            "--no-confirm",
            action="store_true",
            dest="no_confirm",
            help="Disable confirmation",
        )

        parser.add_argument(
            "-t",
            "--timeout",
            type=int,
            help="Timeout on retrieving service metadata (in seconds, default: %ss)"
            % DEFAULT_REQUESTS_TIMEOUT,
            default=DEFAULT_REQUESTS_TIMEOUT,
        )

    # Per-instance cache filled by get_exiting_attributes(); None means
    # "not loaded yet".
    __exiting_attributes = None

    def get_exiting_attributes(self):
        """Return the mapping of known attribute names to their labels.

        The mapping is built from get_service_attributes(None) on first use,
        skipping entries with an empty name, and cached on the instance.
        """
        # Compare against None so that an empty mapping is cached as well.
        if self.__exiting_attributes is None:
            self.__exiting_attributes = {
                name: label for name, label in get_service_attributes(None) if name
            }
        return self.__exiting_attributes

    def is_existing_attribute(self, name):
        """Tell whether *name* is one of the known attribute names."""
        return name in self.get_exiting_attributes()

    def _parse_attribute_spec(self, orig_attribute_specs, logger):
        """Parse one -A/--attribute value into SAMLAttribute keyword arguments.

        Returns a dict holding the non-empty fields of the specification, or
        None (after logging the reason on *logger*) when the specification is
        malformed, has an invalid ``enabled`` value or names an unknown
        attribute.
        """
        attribute_specs = [spec.strip() for spec in orig_attribute_specs.split(":")]
        fields = self._ATTRIBUTE_SPEC_FIELDS.get(len(attribute_specs))
        if fields is None:
            logger.fatal('Invalid attribute spec "%s"', orig_attribute_specs)
            return None

        attribute = {}
        for arg, value in zip(fields, attribute_specs):
            if arg == "enabled":
                if value.lower() not in ("true", "false"):
                    logger.fatal(
                        'Invalid attribute spec "%s": enabled must be true or false',
                        orig_attribute_specs,
                    )
                    # Abort instead of carrying an invalid flag to the model.
                    return None
                value = value.lower() == "true"
            elif not value:
                # Empty fields are left out so they are not passed to the model.
                continue
            attribute[arg] = value

        # attribute_name is absent when it was given as an empty string.
        attribute_name = attribute.get("attribute_name")
        if not attribute_name or not self.is_existing_attribute(attribute_name):
            logger.fatal('Unknown attribute "%s"', attribute)
            return None
        return attribute

    def handle(self, *args, **kwargs):
        """Validate the options and metadata, then create the service provider.

        Invalid options or metadata are logged at the critical level and abort
        the command before anything is saved. Errors raised by full_clean() on
        the provider propagate to the caller. Failures while saving individual
        attributes are logged and do not stop the command.
        """
        root_logger = logging.getLogger()
        logger = logging.getLogger(__name__)

        # ensure log messages are displayed only once on terminal
        stream_handlers = [
            candidate
            for candidate in root_logger.handlers
            if isinstance(candidate, logging.StreamHandler) and candidate.stream.isatty()
        ]
        if stream_handlers:
            handler = stream_handlers[0]
        else:
            handler = logging.StreamHandler()
            logger.addHandler(handler)

        # add timestamp to messages
        formatter = logging.Formatter(fmt="%(asctime)s %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
        handler.setFormatter(formatter)

        # Other verbosity values leave the logger level untouched.
        level = {1: logging.ERROR, 2: logging.INFO, 3: logging.DEBUG}.get(
            int(kwargs["verbosity"])
        )
        if level is not None:
            logger.setLevel(level)

        name = kwargs.get("name")
        if not name:
            logger.fatal("You must provide service name using -n/--name parameter")
            return

        # Handle -s / --slug parameter
        slug = kwargs.get("slug")
        if not slug:
            # Derive the slug from the name, keeping only a-z and "_".
            slug = re.sub(r"[^a-z\_]", "", name.lower())

        # Handle metadata parameters
        metadata_filepath = kwargs.get("metadata")
        metadata_url = kwargs.get("URL")
        if not metadata_filepath and not metadata_url:
            logger.fatal(
                "You must provide service metadata URL (-U/--URL) or filepath (-M/--metadata)."
            )
            return

        # The file path takes precedence when both sources are given.
        if metadata_filepath:
            logger.info(
                'Add service provider "%s" (%s) using metadata file "%s"',
                name,
                slug,
                metadata_filepath,
            )
        else:
            logger.info(
                'Add service provider "%s" (%s) using metadata URL "%s"',
                name,
                slug,
                metadata_url,
            )

        # Handle -O / --ou parameter (OrganizationalUnit)
        ou = None
        if kwargs.get("ou"):
            ou = OrganizationalUnit.objects.filter(Q(name=kwargs["ou"]) | Q(slug=kwargs["ou"]))
            if not ou:
                logger.fatal('Unknown OU "%s"', kwargs["ou"])
                return
            if len(ou) > 1:
                logger.fatal('Duplicated OU "%s" found', kwargs["ou"])
                return
            ou = ou[0]
            logger.info('OU "%s" selected', ou)

        # Handle -S / --sp-options-policy parameter
        sp_options_policy = None
        if kwargs.get("sp_options_policy"):
            sp_options_policy = SPOptionsIdPPolicy.objects.filter(
                name=kwargs.get("sp_options_policy")
            )
            if not sp_options_policy:
                logger.fatal('Unknown service options policy "%s"', kwargs["sp_options_policy"])
                return
            if len(sp_options_policy) > 1:
                logger.fatal(
                    'Duplicated service options policy "%s" found',
                    kwargs["sp_options_policy"],
                )
                return
            sp_options_policy = sp_options_policy[0]
            logger.info('SP options policy "%s" selected', sp_options_policy)

        # Handle -A / --attribute parameters
        attributes = []
        for orig_attribute_specs in kwargs.get("attributes") or []:
            attribute = self._parse_attribute_spec(orig_attribute_specs, logger)
            if attribute is None:
                # The reason has already been logged by the parser.
                return
            logger.info(
                "Attribute specs:\n  - %s",
                "\n  - ".join(f"{key}: {value}" for key, value in attribute.items()),
            )
            attributes.append(attribute)

        # Load metadata
        if metadata_filepath:
            try:
                with open(metadata_filepath) as fd:
                    content = fd.read()
            except Exception:
                logger.fatal(
                    'Fail to read content of metadata file "%s"',
                    metadata_filepath,
                    exc_info=True,
                )
                return
        else:
            try:
                response = requests.get(metadata_url, timeout=kwargs["timeout"])  # nosec
                response.raise_for_status()
                content = force_text(response.content)
            except requests.RequestException:
                logger.fatal("Retrieval of %s failed.", metadata_url, exc_info=True)
                return

        # Check metadata
        try:
            root = ET.fromstring(content)
        except (ET.ParseError, ValueError):
            # ParseError: malformed XML; ValueError: base class of the
            # exceptions defusedxml raises for forbidden constructs.
            logger.fatal("Invalid SAML metadata: unable to parse XML content", exc_info=True)
            return
        if root.tag != "{%s}EntityDescriptor" % lasso.SAML2_METADATA_HREF:
            logger.fatal("Invalid SAML metadata: missing EntityDescriptor tag")
            return
        is_sp = root.find("{%s}SPSSODescriptor" % lasso.SAML2_METADATA_HREF) is not None
        if not is_sp:
            logger.fatal("Invalid SAML metadata: missing SPSSODescriptor tags")
            return

        logger.debug("Loaded metadata:\n  %s", re.sub("\n", "\n  ", content))

        # Confirmation
        if not kwargs["no_confirm"]:
            answer = input("Do you confirm you want to add this service provider? [y/N] ")
            if answer.lower().strip() != "y":
                logger.warning("User canceled")
                return

        liberty_provider_kwargs = {}
        if metadata_url:
            liberty_provider_kwargs["metadata_url"] = metadata_url
        if ou:
            liberty_provider_kwargs["ou"] = ou
        liberty_provider = LibertyProvider(
            name=name, slug=slug, metadata=content, **liberty_provider_kwargs
        )
        liberty_provider.full_clean(exclude=("entity_id", "protocol_conformance"))

        liberty_service_provider = LibertyServiceProvider(
            liberty_provider=liberty_provider,
            enabled=True,
            sp_options_policy=sp_options_policy,
            enable_following_sp_options_policy=kwargs.get("enable_fspop", False),
        )

        liberty_provider.save()
        # Assign the provider again now that it has been saved.
        liberty_service_provider.liberty_provider = liberty_provider
        liberty_service_provider.save()
        logger.info("Service %s added.", name)

        if attributes:
            # NOTE(review): the content type is LibertyProvider while object_id
            # is the service provider pk - presumably both share the same
            # primary key; confirm against the models.
            content_type = ContentType.objects.get_for_model(LibertyProvider)
            for attribute_specs in attributes:
                # Most specific identifier available, used in log messages.
                label = (
                    attribute_specs.get("name")
                    or attribute_specs.get("attribute_name")
                    or attribute_specs
                )
                try:
                    attribute = SAMLAttribute(
                        content_type=content_type,
                        object_id=liberty_service_provider.pk,
                        **attribute_specs,
                    )
                    attribute.save()
                    logger.info('Attribute "%s" added', label)
                except Exception:
                    # Best effort: a failing attribute must not prevent the
                    # remaining ones from being added.
                    logger.error(
                        'Exception occurred adding attribute "%s"',
                        label,
                        exc_info=True,
                    )
