#!/usr/bin/env python3
#
# Gateway FaiL Over Wan (formerly known as Dead Gateway Detection)
# Copyright (c) 2009 Julien Danjou <jdanjou@easter-eggs.com>
# Copyright (c) 2013-2019 Cyril Lacoux <clacoux@easter-eggs.com>
#

""" Gateway FaiL Over Wan """

from argparse import ArgumentParser
from configparser import ConfigParser
import json
from json import JSONDecodeError
import logging
from logging.config import fileConfig
import os
import os.path
import re
import stat
import subprocess
import sys


# Default paths of the external commands, used for every command that is
# given neither on the command line nor in the [commands] configuration section.
COMMANDS = {
    'ip': '/bin/ip',
    'fping': '/usr/bin/fping',
}


class ExceptionFilter(logging.Filter):
    """
    Logging filter that increases the level of records carrying an exception
    to the extra EXCEPTION level (``logging.EXCEPTION``).

    This is very useful when you want to log *only* exceptions on a specific
    handler. Records are never dropped; only their level is rewritten.
    """

    def filter(self, record):
        """
        Promote ``record`` to the EXCEPTION level when it carries a real exception

        ``exc_info`` may also be None, False (``exc_info=False`` passed
        explicitly) or ``(None, None, None)`` (``log.exception()`` called outside
        an ``except`` block). None of these describe an exception, so the level
        is left untouched for them.

        @param logging.LogRecord record, the record being logged

        @return boolean True, the record is always kept
        """

        exc_info = record.exc_info
        if isinstance(exc_info, tuple) and exc_info and exc_info[0] is not None:
            # Increase level to EXCEPTION
            record.levelno = logging.EXCEPTION
            record.levelname = logging.getLevelName(record.levelno)

        return True


# Monkey patch the logging module with an extra EXCEPTION level (60, above
# CRITICAL) so that a handler can be restricted to exceptions only.
logging.EXCEPTION = 60
logging.addLevelName(level=logging.EXCEPTION, levelName='EXCEPTION')

# Root logger used by the whole script; its filter promotes records carrying
# exception information to the EXCEPTION level.
log = logging.getLogger(name=None)
log.addFilter(filter=ExceptionFilter())


class GFLOW:
    """
    Gateway FaiL Over Wan

    Monitor a set of gateways and keep the routing table consistent with their
    status: the prefixes of a dead gateway are removed, the prefixes of an
    alive gateway are restored, and the default route uses the first alive
    gateway (or every alive gateway, weighted, in multipath mode).

    Use it as a context manager: gateways are loaded from the configuration
    and cache files when entering the ``with`` statement; when leaving it, the
    cache file is written back (except in dry-run mode) and the route cache is
    flushed.
    """

    def __init__(self, filepath, commands, dry_run=False, multipath=False):
        """
        @param string filepath, path of the configuration file
        @param dict commands, paths of the "ip" and "fping" commands; a None
            value is resolved from the configuration file (or COMMANDS)
        @param boolean dry_run, when True, commands modifying the system are
            only logged and the cache file is not written
        @param boolean multipath, whether to use a multipath default route;
            None means "use the configuration file setting"
        """

        self.commands = commands
        self.conf_filepath = filepath
        self.dry_run = dry_run
        self.multipath = multipath

        # Both are set by __load() when entering the ``with`` statement
        self.cache_filepath = None
        self.gateways = []

    def __enter__(self):
        """ Method called when entering the ```with``` statement """

        self.gateways = self.__load()

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """ Method called when exiting the ```with``` statement """

        # cache_filepath stays None when the configuration could not be
        # loaded: do not wipe the cache (nor try to write to None) then.
        if not self.dry_run and self.cache_filepath is not None:
            save_cache(self.gateways, self.cache_filepath)

        self.gateways = []
        self.route_flush()

    def __load(self):
        """
        Load gateways from configuration and cache files

        The runtime state saved in the cache file (status, prefixes, ...) is
        merged into each configured gateway; cached gateways that are no longer
        configured are dropped.

        @return list gateways, sorted by index; an empty list on configuration
            error (self.cache_filepath is then left unset)
        """

        parser = ConfigParser(defaults=dict(here=os.path.dirname(self.conf_filepath)))
        parser.read(self.conf_filepath)

        # Resolve commands before any early return: route_flush() is always
        # called on exit and needs the ip command.
        for key, fallback in COMMANDS.items():
            if self.commands.get(key) is None:
                self.commands[key] = parser.get('commands', key, fallback=fallback)

        if not parser.getboolean('global', 'configured', fallback=False):
            log.error('Please configure gflow before running it. Then, change the "configured" option from "global" section to "yes" (in %s).', self.conf_filepath)
            return []

        cache_filepath = parser.get('global', 'cache_file', fallback='/tmp/gflow.json')

        if self.multipath is None:
            self.multipath = parser.getboolean('global', 'multipath', fallback=False)

        # Load gateways from cache file, ignoring malformed content
        cached = load_cache(cache_filepath, fallback=[])
        if not isinstance(cached, list):
            log.warning('Ignoring invalid content of cache file %s', cache_filepath)
            cached = []

        data = dict(
            (gateway['id'], gateway)
            for gateway in cached
            if isinstance(gateway, dict) and 'id' in gateway
        )

        # Load gateways from configuration file
        gateways = []
        for section in parser.sections():
            if not section.startswith('gateway:'):
                continue

            try:
                index = int(section[len('gateway:'):])
            except ValueError:
                log.error('Failed to extract index for gateway "%s", exiting.', section)
                return []

            options = parser.options(section)
            if 'address' in options:
                id_ = address = parser.get(section, 'address')
                iface = None
            elif 'iface' in options:
                id_ = iface = parser.get(section, 'iface')
                address = None
            else:
                log.error('Missing `address` or `iface` parameter for gateway #%s, exiting.', index)
                return []

            gateway = data.pop(id_, None)
            if gateway is None:
                log.info('Adding new monitored gateway #%s...', id_)
                gateway = dict(id=id_)

            # The configuration is authoritative: the peer of a ptp gateway
            # (address None) and the interface of an address gateway (iface
            # None) are looked up again on every run instead of trusting
            # possibly stale values from the cache.
            gateway['address'] = address
            gateway['iface'] = iface

            gateway['index'] = index
            gateway['reachable'] = parser.getboolean(section, 'reachable', fallback=True)
            gateway['remotes'] = parser.get(section, 'remotes', fallback='').split()

            # One check for the gateway itself plus one per remote
            gateway['total'] = len(gateway['remotes']) + 1
            gateway['score'] = parser.getint(section, 'score', fallback=gateway['total'])
            gateway['weight'] = parser.getint(section, 'weight', fallback=1)

            for event in ('change', 'down', 'up'):
                gateway[event] = parser.get(section, event, fallback=None)

            gateways.append(gateway)

        # Removing old gateways
        for id_ in data:
            log.info('Removing old gateway #%s...', id_)

        self.cache_filepath = cache_filepath
        return sorted(gateways, key=lambda x: x['index'])

    def __setup(self, address=None):
        """
        Probe every gateway and refresh its runtime state

        For each gateway: resolve the peer of ptp gateways, read its active
        prefixes from the routing table, check whether it is alive and compute
        ``is_active``. Then flag the gateways used by a default route.

        @param string address, gateway whose configured prefixes are replaced
            by its active ones ("all" for every gateway, None for none)

        @return boolean result, False if no gateway is configured, True otherwise
        """

        if not self.gateways:
            log.error('No gateway set. Exiting.')
            return False

        alive_count = 0
        for gateway in self.gateways:
            # Recomputed by load_default_routes() below
            gateway['is_default'] = False

            if gateway.get('address') is None and not self.load_peer(gateway):
                log.warning('Failed to get peer for gateway #%s, tunnel is probably down.', gateway['id'])
                self._mark_unreachable(gateway)
                continue

            self.load_active_prefixes(gateway)
            self.check_gateway(gateway)

            if address in ('all', gateway['address']) or 'prefixes' not in gateway:
                # Copy: run() mutates active_prefixes, which must not alter
                # the configured prefixes.
                gateway['prefixes'] = list(gateway['active_prefixes'])

            gateway['is_active'] = gateway['prefixes'] == gateway['active_prefixes']

            if gateway['is_alive']:
                alive_count += 1

        if not alive_count:
            log.error('No gateway is alive, this is usually bad.')

        self.load_default_routes()

        return True

    def _mark_unreachable(self, gateway):
        """
        Mark a gateway whose peer could not be resolved as dead and inactive

        Makes sure every runtime key read by scan() and run() exists, and runs
        the "change" and "down" hooks when the gateway was previously alive.

        @param dict gateway, the gateway to update (without ``address``)
        """

        # Without an address, this only resets active_prefixes/is_active and
        # keeps the configured prefixes.
        self.load_active_prefixes(gateway)

        previous_status = gateway.get('is_alive', False)
        gateway['is_alive'] = False
        if previous_status:
            self.run_hook(gateway, 'change')
            self.run_hook(gateway, 'down')

    def check_address(self, address):
        """
        Check an IP address using the fping command

        @param string address, the address to check

        @return boolean result, True if the address is alive, False otherwise
        """

        result, _, _ = run_command('{0} {1}'.format(self.commands['fping'], address))

        return result

    def check_gateway(self, gateway):
        """
        Check if a gateway is alive

        The gateway itself (always counted as passed when not ``reachable``)
        and each of its remotes, probed through it, count as one check each.
        Remotes are only probed when the gateway check passed. The gateway is
        alive when the number of passed checks is at least its score. The
        "change" and "up"/"down" hooks run on a status transition.

        @param dict gateway, the gateway to check

        @return boolean result, True if gateway is alive, False otherwise
        """

        if gateway.get('address') is None:
            gateway['is_alive'] = False
            return False

        gateway['passed'] = 0
        if not gateway['reachable'] or self.check_address(gateway['address']):
            # Gateway is alive
            log.debug(' address %s is alive', gateway['address'])
            gateway['passed'] += 1

            # Testing remotes through this gateway with a temporary host route
            for remote in gateway['remotes']:
                if not self.route_add('{0}/32'.format(remote), gateway['address']):
                    log.warning(' failed to add route for %s via %s', remote, gateway['address'])
                    continue

                if self.check_address(remote):
                    # Remote is alive
                    log.debug(' address %s is alive', remote)
                    gateway['passed'] += 1
                else:
                    log.warning(' address %s is unreachable', remote)

                self.route_del('{0}/32'.format(remote), gateway['address'])

        previous_status = gateway.get('is_alive', False)
        if gateway['passed'] >= gateway['score']:
            # Alive
            log.debug('Gateway %(address)s is alive (%(passed)s/%(total)s, needed %(score)s)', gateway)
            gateway['is_alive'] = True
            if not previous_status:
                self.run_hook(gateway, 'change')
                self.run_hook(gateway, 'up')
        else:
            # Dead
            log.error('Gateway %(address)s is dead (%(passed)s/%(total)s, needed=%(score)s)', gateway)
            gateway['is_alive'] = False
            if previous_status:
                self.run_hook(gateway, 'change')
                self.run_hook(gateway, 'down')

        return gateway['is_alive']

    def list(self):
        """
        List gateways

        Same as scan(address=None) but forces dry-run mode, so neither the
        routes nor the cache file are modified.
        """

        self.dry_run = True
        return self.scan(address=None)

    def load_active_prefixes(self, gateway):
        """
        Load gateway's active prefixes from routing table

        Also records the outgoing interface of the gateway's routes when its
        ``iface`` is not known yet.

        @param dict gateway, the gateway to get prefixes from

        @return boolean result, True if the routing table could be read for
            this gateway, False otherwise (no address, or ip command failure)
        """

        gateway['active_prefixes'] = []

        if gateway.get('address') is None:
            if 'prefixes' not in gateway:
                gateway['prefixes'] = []
            gateway['is_active'] = False
            return False

        result, stdout, _ = run_command('{0} route show via {1} scope global'.format(self.commands['ip'], gateway['address']))
        if not result:
            return False

        for match in re.finditer(r'(?P<prefix>.+)\s+dev\s+(?P<iface>\S+)', stdout):
            if gateway.get('iface') is None:
                gateway['iface'] = match.group('iface')
            prefix = match.group('prefix')
            if prefix == 'default':
                # Default routes are handled separately (load_default_routes)
                continue
            gateway['active_prefixes'].append(prefix)

        gateway['active_prefixes'].sort()

        return True

    def load_default_routes(self):
        """
        Load current default routes from routing table

        Set ``is_default`` on every gateway used by a default route, either as
        its single next hop ("default via ADDRESS dev IFACE", "default dev
        IFACE") or as a next hop of a multipath route ("nexthop via ADDRESS dev
        IFACE ...", "nexthop dev IFACE ...").

        @return int count, the number of matching default next hops
        """

        result, stdout, _ = run_command('{0} route show 0.0.0.0/0'.format(self.commands['ip']))
        if not result:
            return 0

        count = 0

        # Get active default routes for "normal" interfaces
        addresses = [gateway.get('address') for gateway in self.gateways]
        for match in re.finditer(r'(?:default|nexthop)\s+via\s+(?P<address>[\d\.]+)\s+dev\s+(?P<iface>\S+)', stdout):
            address = match.group('address')
            if address in addresses:
                gateway = self.gateways[addresses.index(address)]
                gateway['is_default'] = True
                count += 1

        # Get active default routes for ptp interfaces
        ifaces = [gateway.get('iface') for gateway in self.gateways]
        for match in re.finditer(r'(?:default|nexthop)\s+dev\s+(?P<iface>\S+)', stdout):
            iface = match.group('iface')
            if iface in ifaces:
                gateway = self.gateways[ifaces.index(iface)]
                gateway['is_default'] = True
                count += 1

        return count

    def load_peer(self, gateway):
        """
        Load peer for ptp connections

        The ``address`` key is removed first and only set again on success.

        @param dict gateway, the gateway we want to get peer

        @return boolean result, True if success, False otherwise
        """

        gateway.pop('address', None)

        result, stdout, _ = run_command('{0} route show dev {1} scope link'.format(self.commands['ip'], gateway['iface']))
        if not result:
            return False

        # Old iproute2 releases separate fields with two spaces, recent ones
        # with a single space: accept any amount of whitespace.
        match = re.match(r'(?P<address>\S+)\s+proto\s+kernel\b.*?\bsrc\s+(?P<local>\S+)', stdout)
        if match is None:
            return False

        gateway['address'] = match.group('address')
        return True

    def route_add(self, dst, via, flush=False):
        """
        Add a route to routing table

        @param string dst, destination
        @param string via, gateway to use to reach destination
        @param boolean flush, wether to flush the routing cache or not

        @return boolean result, True on success, False otherwise
        """

        result, _, _ = run_command('{0} route add {1} nexthop via {2}'.format(self.commands['ip'], dst, via), dry_run=self.dry_run)
        if not result or not flush:
            return result

        return self.route_flush()

    def route_change(self, dst, via, flush=False):
        """
        Change a route in routing table

        @param string dst, destination
        @param string via, gateway to use to reach destination
        @param boolean flush, wether to flush the routing cache or not

        @return boolean result, True on success, False otherwise
        """

        result, _, _ = run_command('{0} route change {1} nexthop via {2}'.format(self.commands['ip'], dst, via), dry_run=self.dry_run)
        if not result or not flush:
            return result

        return self.route_flush()

    def route_del(self, dst, via, flush=False):
        """
        Remove a route from routing table

        @param string dst, destination
        @param string via, gateway to use to reach destination
        @param boolean flush, wether to flush the routing cache or not

        @return boolean result, True on success, False otherwise
        """

        result, _, _ = run_command('{0} route del {1} via {2}'.format(self.commands['ip'], dst, via), dry_run=self.dry_run)
        if not result or not flush:
            return result

        return self.route_flush()

    def route_flush(self):
        """ Flush route cache. Return True if OK, False otherwise. """

        result, _, _ = run_command('{0} route flush cache'.format(self.commands['ip']), dry_run=self.dry_run)
        return result

    def run_hook(self, gateway, event):
        """
        Run event hook for gateway

        The hook command receives GFLOW_GATEWAY, GFLOW_REMOTES, GFLOW_EVENT,
        GFLOW_IFACE and GFLOW_WEIGHT in its environment (empty strings when the
        gateway address or interface is unknown).

        @param dict gateway, the gateway which triggered the event
        @param string event, event that was triggered by the gateway

        @return boolean result, True if the hook was ran successully, False otherwise
        """

        # The address may be missing for a ptp gateway whose peer is unknown
        name = gateway.get('address') or gateway['id']

        if event not in ('change', 'down', 'up'):
            log.error('Invalid event %s for gateway %s', event, name)
            return False

        hook_cmd = gateway.get(event)
        if hook_cmd is None:
            log.debug('No hook for event %s for gateway %s', event, name)
            return True

        log.info('Triggering %s event for gateway %s', event, name)

        environ = os.environ.copy()
        environ['GFLOW_GATEWAY'] = gateway.get('address') or ''
        environ['GFLOW_REMOTES'] = ' '.join(gateway['remotes'])
        environ['GFLOW_EVENT'] = event
        environ['GFLOW_IFACE'] = gateway.get('iface') or ''
        environ['GFLOW_WEIGHT'] = str(gateway['weight'])

        result, _, _ = run_command(hook_cmd, environ=environ, dry_run=self.dry_run)
        return result

    def scan(self, address='all'):
        """
        Scan for gateways, refresh active prefixes and print the gateways

        @param string address, gateway whose configured prefixes are replaced
            by its active ones ("all" for every gateway, None for none); only
            this gateway is printed unless address is "all" or None

        @return boolean result, False if no gateway is configured, True otherwise
        """

        if not self.__setup(address=address):
            return False

        print('Gateways list:')

        for gateway in self.gateways:
            if address not in (None, 'all', gateway.get('address')):
                continue

            print('\n** Gateway #{id} (index={index}) **'.format(**gateway))

            if gateway.get('address') is not None:
                print(' address: {0}'.format(gateway['address']))

            if gateway['remotes']:
                print(' remotes:')
                print('\n'.join('   {0}'.format(remote) for remote in gateway['remotes']))

            print(' reachable: {0}'.format('yes' if gateway['reachable'] else 'no'))
            print(' weight: {weight}'.format(**gateway))
            print(' default: {0}'.format('yes' if gateway['is_default'] else 'no'))
            print(' alive: {0}'.format('yes' if gateway['is_alive'] else 'no'))
            print(' active: {0}'.format('yes' if gateway['is_active'] else 'no'))

            if gateway['prefixes']:
                print(' prefixes:')
                print('\n'.join('   {0}'.format(prefix) for prefix in gateway['prefixes']))

            if gateway['active_prefixes']:
                print(' active prefixes:')
                print('\n'.join('   {0}'.format(prefix) for prefix in gateway['active_prefixes']))

        return True

    def _sync_prefixes(self, gateway):
        """
        Add the missing prefixes of an alive gateway, remove those of a dead one

        @param dict gateway, the gateway to synchronize

        @return boolean result, True if every route operation succeeded
        """

        result = True

        if gateway['is_alive']:
            # Activate gateways that are alive but inactive
            if not gateway['is_active']:
                log.info('Activating alive gateway %s...', gateway['address'])
                for prefix in gateway['prefixes']:
                    if prefix in gateway['active_prefixes']:
                        continue

                    if self.route_add(prefix, gateway['address']):
                        log.info('  added missing prefix %s', prefix)
                        gateway['active_prefixes'].append(prefix)
                    else:
                        log.error('  failed to add missing prefix %s', prefix)
                        result = False

            gateway['active_prefixes'].sort()

        elif gateway['active_prefixes']:
            log.info('Deactivating dead gateway %s...', gateway['address'])
            # Iterate over a copy: successfully removed prefixes are dropped
            for prefix in list(gateway['active_prefixes']):
                if self.route_del(prefix, gateway['address']):
                    log.info('  removed active prefix %s', prefix)
                    gateway['active_prefixes'].remove(prefix)
                else:
                    log.error('  failed to remove active prefix %s', prefix)
                    result = False

        return result

    def _default_route_nexthops(self, default_count):
        """
        Compute the wanted default route

        @param int default_count, number of gateways currently used by a default route

        @return tuple (via, msg, need_change): ``via`` is the next hop
            specification passed to ``ip route`` ("" when no gateway is alive),
            ``msg`` a description for logs, and ``need_change`` whether the
            current default route differs from the wanted one
        """

        need_change = False

        if self.multipath:
            # Using all alive gateways as default route
            msg = []
            via = []
            for gateway in self.gateways:
                if gateway['is_alive']:
                    msg.append('{address} ({weight})'.format(**gateway))
                    via.append('{address} weight {weight}'.format(**gateway))
                    if not gateway['is_default']:
                        need_change = True
                elif gateway['is_default']:
                    need_change = True

            return ' nexthop via '.join(via), ', '.join(msg), need_change

        # Using first alive gateway as default route
        via = ''
        for gateway in self.gateways:
            if gateway['is_alive']:
                via = gateway['address']
                if not gateway['is_default']:
                    need_change = True
                break

        if default_count > 1:
            need_change = True

        return via, via, need_change

    def run(self):
        """
        The job is done here

        Probe every gateway, add the missing prefixes of alive gateways,
        remove the prefixes of dead ones and update the default route.

        @return boolean result, True if every routing change succeeded, False otherwise
        """

        if not self.__setup():
            return False

        result = True
        default_count = 0

        for gateway in self.gateways:
            if gateway['is_default']:
                default_count += 1

            if not self._sync_prefixes(gateway):
                result = False

        # Default route setup
        via, msg, need_change = self._default_route_nexthops(default_count)
        if not need_change:
            return result

        if not via:
            # "ip route ... nexthop via" without any address is invalid: keep
            # the current default route instead of issuing a failing command.
            log.error('No alive gateway to use as default route, leaving it unchanged.')
            return False

        if default_count:
            if self.route_change('default', via):
                log.info('Successfully changed default route using gateway(s) %s', msg)
                # Still report failures of the prefix synchronization above
                return result

            log.error('Failed to change default route using gateway(s) %s', msg)

        if self.route_add('default', via):
            log.info('Successfully added default route using gateway(s) %s', msg)
            return result

        log.error('Failed to add default route using gateway(s) %s', msg)
        return False


def load_cache(filepath, fallback=None):
    """
    Safely load data from cache file using JSON format

    @param string filepath, path of the cache file
    @param mixed fallback, value returned when the file is missing or unreadable

    @return mixed data, the decoded JSON document, or fallback
    """

    if not os.path.isfile(filepath):
        return fallback

    try:
        # JSON text is UTF-8: do not depend on the locale encoding
        with open(filepath, 'r', encoding='utf-8') as fp:
            data = json.load(fp)

    # UnicodeDecodeError (a corrupted, non UTF-8 file) must not crash the run
    except (OSError, JSONDecodeError, UnicodeDecodeError):
        log.exception('Failed to open data file %s', filepath)
        return fallback

    return data


def run_command(command, environ=None, dry_run=False):
    """
    Execute a command

    The command is a shell command line (executed with shell=True), so hook
    commands from the configuration file may use shell syntax.

    @param string command, the command to execute
    @param dict environ, os environment to use when executing the command
    @param boolean dry_run, dry-run mode, the command is not executed

    @return tuple (result, stdout, stderr), result is True on success, False
        otherwise; stdout and stderr are None in dry-run mode
    """

    if dry_run:
        log.debug('[DRY RUN MODE] Executing `%s`', command)
        return (True, None, None)

    log.debug('Executing `%s`', command)

    process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=environ)

    # communicate() drains both pipes while waiting: calling wait() first
    # deadlocks as soon as the command fills a pipe buffer.
    stdout, stderr = process.communicate()

    if process.returncode != 0:
        log.error('Command `%s` failed with exit code %s', command, process.returncode)

    return (
        process.returncode == 0,
        stdout.decode('utf-8', errors='replace'),
        stderr.decode('utf-8', errors='replace'),
    )


def save_cache(data, filepath):
    """
    Save data to cache file using JSON format

    The data is written to a temporary file created next to ``filepath``
    (readable and writable by the owner only), which then replaces
    ``filepath`` with os.replace(). A serialization error therefore never
    leaves a truncated cache behind. The cache is never readable with the
    umask permissions, and a symlink at ``filepath`` is replaced instead of
    being written through.

    @param list data, JSON serializable data to store
    @param string filepath, path of the cache file

    @return boolean result, True on success, False otherwise
    """

    import tempfile

    tmp_filepath = None
    try:
        fd, tmp_filepath = tempfile.mkstemp(
            prefix='.{0}.'.format(os.path.basename(filepath)),
            suffix='.tmp',
            dir=os.path.dirname(filepath) or '.',
        )
        with os.fdopen(fd, 'w', encoding='utf-8') as fp:
            json.dump(data, fp, indent=2)

        os.chmod(tmp_filepath, stat.S_IRUSR | stat.S_IWUSR)
        os.replace(tmp_filepath, filepath)
        tmp_filepath = None

    except (OSError, TypeError, ValueError):
        log.exception('Failed to save data file %s', filepath)
        return False

    finally:
        # Remove the temporary file when it could not be moved into place
        if tmp_filepath is not None:
            try:
                os.unlink(tmp_filepath)
            except OSError:
                pass

    return True


def main(argv=None):
    """
    Main function

    @param list argv, command line arguments (defaults to sys.argv[1:])

    @return boolean result, True on success, False otherwise
    """

    if argv is None:
        argv = sys.argv[1:]

    parser = ArgumentParser(description='Gateway FaiL Over Wan')

    parser.add_argument(
        'config_file',
        nargs='?',
        help='Configuration file to use.'
    )

    parser.add_argument(
        '-d', '--debug',
        action='store_true',
        help='Show debug messages'
    )

    parser.add_argument(
        '-t', '--dry-run',
        action='store_true',
        help='Test mode'
    )

    parser.add_argument(
        '-f', '--fping-cmd',
        dest='cmd_fping',
        help='Path of fping command'
    )

    parser.add_argument(
        '-i', '--ip-cmd',
        dest='cmd_ip',
        help='Path of ip command'
    )

    group = parser.add_mutually_exclusive_group()
    group.add_argument(
        '--enable-multipath',
        action='store_true',
        default=None,
        dest='multipath',
        help='Enable multipath route setup'
    )

    group.add_argument(
        '--disable-multipath',
        action='store_false',
        default=None,
        dest='multipath',
        help='Disable multipath route setup'
    )

    group = parser.add_mutually_exclusive_group()
    group.add_argument(
        '-s', '--scan',
        help='Scan gateway prefixes (use "all" to scan all gateways)'
    )
    group.add_argument(
        '-l', '--list',
        action='store_true',
        help='List current configuration (Implies --dry-run)'
    )

    args = parser.parse_args(argv)

    if not args.config_file:
        config_filepath = '/etc/gflow.ini'
    else:
        config_filepath = os.path.abspath(args.config_file)

    # Check access to the configuration file before using it: since Python
    # 3.12, fileConfig() raises FileNotFoundError on a missing file.
    if not os.path.isfile(config_filepath) or not os.access(config_filepath, os.R_OK):
        parser.error('Invalid configuration file {0}'.format(config_filepath))

    # Logging initialization (the logging sections are optional). fileConfig()
    # raises KeyError on missing sections, and RuntimeError (Python 3.12+) on
    # an empty or unparsable file.
    try:
        fileConfig(config_filepath)
    except (KeyError, RuntimeError):
        print('Warning: logging configuration is invalid or empty in file {0}'.format(config_filepath), file=sys.stderr)

    if args.debug:
        log.setLevel(logging.DEBUG)

    commands = dict(
        fping=args.cmd_fping,
        ip=args.cmd_ip,
    )

    # Go!
    with GFLOW(config_filepath, commands, multipath=args.multipath, dry_run=args.dry_run) as gflow:
        if args.scan:
            return gflow.scan(address=args.scan)

        if args.list:
            return gflow.list()

        return gflow.run()


if __name__ == '__main__':
    # main() returns True on success: exit status 0, anything else: 1
    sys.exit(0 if main() else 1)
