#!/usr/bin/python3
# vim: shiftwidth=4 tabstop=4 expandtab

import argparse
import datetime
import json
import logging
import os
import re
import shutil
import subprocess  # nosec B404
import textwrap
import time

import dotenv
import humanize
import prettytable
from mako.template import Template as MakoTemplate

# Map each --sort choice to the postqueue JSON attribute used as the sort key.
sort_choices = dict(
    id="queue_id",
    state="queue_name",
    date="arrival_time",
    size="message_size",
)
# Queue names accepted by --state (compared against each message's "queue_name").
queue_names = (
    "active deferred bounce corrupt defer flush hold "
    "incoming private public save trace maildrop"
).split()

# Load ee-postfix-tools configuration files (every "*.conf", in name order) into
# the process environment. override=True lets values from these files replace
# variables that are already set, including ones loaded from an earlier file.
config_dir = "/etc/ee-postfix-tools"
if os.path.isdir(config_dir):
    conf_files = sorted(name for name in os.listdir(config_dir) if name.endswith(".conf"))
    for conf_file in conf_files:
        dotenv.load_dotenv(os.path.join(config_dir, conf_file), override=True)

parser = argparse.ArgumentParser()
parser.add_argument("-d", "--debug", action="store_true")
parser.add_argument("-v", "--verbose", action="store_true")
parser.add_argument(
    "-i",
    "--instance",
    help="Select postfix instance (optional, default: use default postfix instance)",
    default=os.environ.get("POSTFIX_INSTANCE"),
)
parser.add_argument(
    "--input",
    help="Load result of 'postqueue -j' command from specified file instead of running this "
    "command (useful for devel)",
)


# Options controlling what is rendered and where it is written.
output_args = parser.add_argument_group("Output")
output_args.add_argument("-o", "--output", help="Output file (default: stdout)")
output_args.add_argument(
    "-J",
    "--json",
    action="store_true",
    help=(
        "JSON output (by default: dictionary with message ID as key and message attributes "
        "as value, see JSON OBJECT FORMAT in the postqueue manual)"
    ),
)
output_args.add_argument(
    "-X",
    "--template",
    type=MakoTemplate,
    # Each message's attributes are passed to render() as keyword arguments.
    help="Mako template used to render each message (message attributes are available "
    "as template variables)",
)
output_args.add_argument(
    "-O",
    "--output-data",
    choices=["id", "count", "total-size"],
    help="Output only the message IDs, the message count or the total size of messages",
)
output_args.add_argument(
    "-s",
    "--sort",
    choices=sort_choices.keys(),
    help="Sort messages in output",
)
output_args.add_argument("--reverse", action="store_true", help="Reverse result set")
output_args.add_argument(
    "-l",
    "--limit",
    type=int,
    help="Limit result set to the first X message(s)",
)

filter_args = parser.add_argument_group(
    "Filtering messages", "Note: All specified filters must match."
)
# Repeatable filters: every occurrence appends one more value (raw string, or a
# compiled pattern for the *-regex options). Registration order is kept as-is
# because it is the order shown in --help.
for _flags, _type, _help in (
    (("-I", "--id"), None, "Filter on email ID"),
    (("-f", "--sender"), None, "Filter on email sender"),
    (("-F", "--sender-regex"), re.compile, "Filter on email sender using regex"),
    (("-t", "--recipient"), None, "Filter on email recipient"),
    (("-T", "--recipient-regex"), re.compile, "Filter on email recipient using regex"),
    (("-D", "--delay-reason-regex"), re.compile, "Filter on delay reason using regex"),
):
    filter_args.add_argument(*_flags, type=_type, action="append", help=_help)
# Single-valued size bounds, in bytes.
for _flags, _help in (
    (("-m", "--min-size"), "Filter email by minimal size (bytes)"),
    (("-M", "--max-size"), "Filter email by maximal size (bytes)"),
):
    filter_args.add_argument(*_flags, type=int, help=_help)
filter_args.add_argument(
    "-S",
    "--state",
    action="append",
    choices=queue_names,
    help="Filter email by state (possible values: {})".format(", ".join(queue_names)),
)

_age_filter_option_regex = re.compile(
    r"^(?P<op><|<=|>|>=)?(?P<number>[0-9]+)(?P<unit>s|m|h|d)?$", re.IGNORECASE
)
# timedelta keyword argument for each (lower-cased) unit suffix.
_AGE_UNITS = {"s": "seconds", "m": "minutes", "h": "hours", "d": "days"}
# Comparison evaluated as "<message age> <op> <threshold>".
_AGE_OPERATORS = {
    "<": lambda age, threshold: age < threshold,
    "<=": lambda age, threshold: age <= threshold,
    ">": lambda age, threshold: age > threshold,
    ">=": lambda age, threshold: age >= threshold,
}


def age_filter_option(value):
    """Parse an --age filter specification into a predicate on message age.

    Format is ``[operator][number][unit]``, for example ``>7d``, ``<=30m`` or ``90``.
    The operator is one of ``<``, ``<=``, ``>``, ``>=`` and defaults to ``>``.
    The unit (case-insensitive) is one of ``s``, ``m``, ``h``, ``d`` and defaults
    to ``m`` (minutes) when omitted.

    Returns a callable taking a ``datetime.timedelta`` (the message age) and
    returning True when that age satisfies the filter.

    Raises ValueError when ``value`` does not match the expected format; argparse
    reports that as an invalid option value.
    """
    match = _age_filter_option_regex.match(value)
    if not match:
        raise ValueError(f"invalid age filter: {value!r}")
    unit = (match.group("unit") or "m").lower()
    threshold = datetime.timedelta(**{_AGE_UNITS[unit]: int(match.group("number"))})
    compare = _AGE_OPERATORS[match.group("op") or ">"]
    return lambda age: compare(age, threshold)


# Each -a value is parsed by age_filter_option(); every given age filter must match.
filter_args.add_argument(
    "-a",
    "--age",
    type=age_filter_option,
    action="append",
    help="Filter email by age (multiple filters allowed, format: [operator][value][unit], "
    "example: '>7d' for more than 7 days in mail queue. Operators: '>' (default), '>=', '<' "
    "or '<=', units: 's', 'm' (default), 'h', or 'd')",
)

args = parser.parse_args()

# Start from warnings only and raise verbosity as requested; --debug wins over
# --verbose when both are given.
loglevel = logging.WARNING
if args.verbose:
    loglevel = logging.INFO
if args.debug:
    loglevel = logging.DEBUG

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%m/%d/%Y %I:%M:%S %p",
    level=loglevel,
)

# Obtain the raw `postqueue -j` listing, either from a file (--input) or by
# running postqueue (through postmulti when a specific instance is selected).
if args.input:
    with open(args.input, encoding="utf8") as file_desc:
        output = file_desc.read()
else:
    cmd = ["/usr/sbin/postqueue", "-j"]
    if args.instance:
        cmd = ["/usr/sbin/postmulti", "-i", args.instance, "-x"] + cmd
    logging.debug("Run '%s'", "' '".join(cmd))

    # cmd is an argument list run without a shell. stderr is not captured, so
    # the command's own diagnostics still reach the user.
    proc = subprocess.run(cmd, shell=False, stdout=subprocess.PIPE, check=False)  # nosec B603
    if proc.returncode != 0:
        # Without this check a failed listing would be reported as an empty
        # (or partial) mail queue.
        logging.error("Command '%s' failed (exit code %d)", "' '".join(cmd), proc.returncode)
        raise SystemExit(proc.returncode if proc.returncode > 0 else 1)
    # Decode defensively: replace non-UTF-8 bytes instead of aborting.
    output = proc.stdout.decode(errors="replace")


def get_terminal_size(default_witdh=None, default_height=None):
    """Return the terminal size as an ``os.terminal_size`` (columns, lines).

    ``default_witdh``/``default_height`` form the fallback that
    ``shutil.get_terminal_size`` uses when the size cannot be determined; a
    falsy value (None, 0) is replaced by 80.
    """
    fallback = (default_witdh or 80, default_height or 80)
    return shutil.get_terminal_size(fallback)


def format_timestamp(timestamp):
    """Format a POSIX timestamp as local date and time: ``YYYY/MM/DD HH:MM:SS``.

    The same full format is used whether or not the timestamp falls on today's
    date.
    """
    return datetime.datetime.fromtimestamp(timestamp).strftime("%Y/%m/%d %H:%M:%S")


def match_any(regs, values):
    """Return True if any compiled regex in ``regs`` finds a match in any of ``values``."""
    return any(regex.search(value) for regex in regs for value in values)


def filter_msg(msg):
    """Return True if ``msg`` passes every filter given on the command line.

    ``msg`` is one decoded ``postqueue -j`` record. Unset filters are ignored.
    All filters that are set must match (logical AND). Filter values are read
    from the module-level ``args`` namespace.
    """
    if args.id and msg["queue_id"] not in args.id:
        return False

    if args.state and msg["queue_name"] not in args.state:
        return False

    # Compare with None so an explicit 0 bound (e.g. --max-size 0) is honoured
    # instead of being treated as "no bound".
    if args.min_size is not None and msg["message_size"] < args.min_size:
        return False

    if args.max_size is not None and msg["message_size"] > args.max_size:
        return False

    sender = msg.get("sender")
    if args.sender and sender not in args.sender:
        return False

    # Like the exact-match filter above, a message without a sender cannot
    # match (instead of raising KeyError).
    if args.sender_regex and (
        sender is None or not all(regex.search(sender) for regex in args.sender_regex)
    ):
        return False

    recipients = msg.get("recipients", [])
    if args.recipient or args.recipient_regex:
        addresses = [recipient["address"] for recipient in recipients]
        # At least one recipient address must be among the requested ones.
        if args.recipient and set(addresses).isdisjoint(args.recipient):
            return False
        if args.recipient_regex and not match_any(args.recipient_regex, addresses):
            return False

    if args.delay_reason_regex and not match_any(
        args.delay_reason_regex,
        [
            recipient["delay_reason"]
            for recipient in recipients
            if "delay_reason" in recipient
        ],
    ):
        return False

    if args.age:
        # arrival_time is compared against the current wall-clock time.
        age = datetime.timedelta(seconds=time.time() - msg["arrival_time"])
        if not all(age_filter(age) for age_filter in args.age):
            return False

    return True


# Decode postqueue's output line by line (one JSON record per line) and keep
# the records that pass the command-line filters, keyed by queue ID.
msgs = {}
for line in output.split("\n"):
    # Ignore empty line
    if not line:
        continue
    try:
        msg = json.loads(line)
    except ValueError:
        logging.warning("Fail to decode JSON message line: '%s'", line)
        continue
    # msgs is keyed by queue ID: skip records that are not JSON objects or lack
    # an ID instead of crashing with TypeError/KeyError.
    if not isinstance(msg, dict) or "queue_id" not in msg:
        logging.warning("Ignore message line without queue ID: '%s'", line)
        continue
    if filter_msg(msg):
        logging.debug("Message %s: %s", msg["queue_id"], msg)
        msgs[msg["queue_id"]] = msg

# Order, reverse and truncate the selected messages. Dicts preserve insertion
# order, so each step rebuilds msgs in the wanted order.
if args.sort:
    logging.debug("Sort by %s", args.sort)
    sort_key = sort_choices[args.sort]
    msgs = dict(sorted(msgs.items(), key=lambda item: item[1].get(sort_key)))

if args.reverse:
    msgs = dict(reversed(msgs.items()))

# Compare with None so that --limit 0 yields an empty result instead of
# silently disabling the limit.
if args.limit is not None:
    msgs = dict(list(msgs.items())[: args.limit])

# Render the selected messages according to the output options. Every branch
# leaves a string in `output`.
if args.json:
    if args.output_data == "id":
        output = list(msgs)
    elif args.output_data == "count":
        output = len(msgs)
    elif args.output_data == "total-size":
        output = sum(msg["message_size"] for msg in msgs.values())
    elif args.template:
        output = {msg_id: args.template.render(**msg) for msg_id, msg in msgs.items()}
    else:
        output = msgs
    output = json.dumps(output, indent=2)
elif args.output_data == "id":
    output = "\n".join(msgs)
elif args.output_data == "count":
    # Keep a string: an int cannot be written to a file, and a count of 0
    # (falsy as an int) must still be emitted.
    output = str(len(msgs))
elif args.output_data == "total-size":
    output = humanize.naturalsize(sum(msg["message_size"] for msg in msgs.values()))
elif args.template:
    output = "\n".join(args.template.render(**msg) for msg in msgs.values())
else:
    table = prettytable.PrettyTable(
        field_names=["ID", "State", "Arrival", "Size", "Sender & Recipients"],
        hrules=prettytable.ALL,
    )
    # hasattr guard: DOUBLE_BORDER may be missing from some prettytable versions.
    if hasattr(prettytable, "DOUBLE_BORDER"):
        table.set_style(prettytable.DOUBLE_BORDER)  # pylint: disable=no-member
    table.align["Sender & Recipients"] = "l"
    # Delay reasons are wrapped to half the terminal width, capped at 80 columns.
    line_width = min(80, get_terminal_size().columns // 2)
    logging.debug("Output line width: %d", line_width)

    def _wrap_delay_reason(reason):
        """Wrap a recipient delay reason to line_width, indented by two spaces."""
        return textwrap.fill(
            reason, width=line_width, initial_indent="  ", subsequent_indent="  "
        )

    total_size = 0
    for msg_id, msg in msgs.items():
        recipients = msg["recipients"]
        sender_recipients = [f"Sender: {msg['sender']}"]
        if len(recipients) == 1:
            # Compact single-recipient form. Index instead of pop() so the
            # message data is not mutated.
            recipient = recipients[0]
            sender_recipients.append(f"Recipient: {recipient['address']}")
            if "delay_reason" in recipient:
                sender_recipients.append(_wrap_delay_reason(recipient["delay_reason"]))
        else:
            sender_recipients.append("Recipients:")
            for recipient in recipients:
                sender_recipients.append(recipient["address"])
                if "delay_reason" in recipient:
                    sender_recipients.append(_wrap_delay_reason(recipient["delay_reason"]))
        table.add_row(
            [
                msg_id,
                msg["queue_name"],
                humanize.naturaltime(datetime.datetime.fromtimestamp(msg["arrival_time"])),
                humanize.naturalsize(msg["message_size"]),
                "\n".join(sender_recipients),
            ]
        )
        total_size += msg["message_size"]
    output = (
        table.get_string()
        + "\n"
        + f"Total: {len(msgs)} message(s) (size: {humanize.naturalsize(total_size)})"
    )

# Emit the rendered result: to stdout unless --output names a file. An empty
# result prints nothing on stdout.
if not args.output:
    if output:
        print(output)
else:
    with open(args.output, "w", encoding="utf8") as file_desc:
        file_desc.write(output)
