#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2013-2020
# This file is part of Shinken Enterprise, all rights reserved.

import time
import threading
import shinkensolutions.shinkenjson as json
import uuid
from collections import OrderedDict

import operator

from shinken.log import logger
from shinken.webui import bottlewebui as bottle
from shinken.webui.bottlewebui import request, response, parse_auth, HTTPResponse
from shinken.webui.cherrypybackend import CherryPyServerHTTP

from shinkensolutions.api.synchronizer.source.abstract_module.listener_module import ListenerModule
from shinkensolutions.api.synchronizer import ITEM_STATE, ITEM_TYPE, NOT_TO_LOOK, DEF_ITEMS, METADATA, DataProviderMongo, get_type_item_from_class
from collections import Iterable


# This module opens an HTTP service to which a user can send hosts (only hosts for now).
class BaseRESTListener(ListenerModule):
    """Listener source exposing a small REST API to manage hosts.

    Hosts are stored (ciphered) in this source's RAW_SOURCES collection through
    the Mongo data provider. Routes, relative to ``url_base_path``:

    * ``/v1/hosts`` and ``/v1/hosts/``: PUT creates a host, GET lists hosts.
    * ``/v1/hosts/:item_uuid``: GET / DELETE / POST / PUT / PATCH on one host.
    """
    url_base_path = '/shinken/listener-rest'
    
    # Built lazily by get_configuration_fields(); None until the first call.
    _configuration_fields = None
    
    # Query-string values that switch a boolean flag off. '' is included so an
    # empty value stays falsy, as it was when the raw string was passed along.
    _FALSE_QUERY_VALUES = ('', '0', 'false', 'no', 'off')
    
    
    def __init__(self, modconf):
        """Read host/port/module_name from the module configuration.

        :param modconf: module configuration object; ``port``, ``host`` and
                        ``module_name`` are read with defaults.
        :raises: re-raises any error met while reading the configuration.
        """
        super(BaseRESTListener, self).__init__(modconf)
        self.srv = None  # HTTP server returned by app.run() in _init_http()
        self.srv_lock = threading.RLock()  # serializes start_listener()/stop_listener()
        self.provider_mongo = None  # created lazily by get_dataprovider()
        self.thread = None  # thread running self.srv.start()
        self.serveropts = {}  # extra keyword arguments forwarded to app.run()
        try:
            logger.debug("[%s] Configuration starting ..." % self._get_logger_name())
            self.port = int(getattr(modconf, 'port', '7761'))
            self.host = getattr(modconf, 'host', '0.0.0.0')
            self.module_name = getattr(modconf, 'module_name', 'module-listener-rest')
            if self.module_name.startswith('module-'):
                self.module_name = self.module_name[7:]
            logger.info("[%s] Configuration done, host: %s(%s)" % (self._get_logger_name(), self.host, self.port))
        except AttributeError:
            logger.error("[%s] The module is missing a property, check module declaration in shinken-specific.cfg" % self._get_logger_name())
            raise
        except Exception as e:
            logger.error("[%s] Exception : %s" % (self._get_logger_name(), str(e)))
            raise
        # Guards reads/writes of this source's hosts in the data provider.
        self.hosts_lock = threading.RLock()
    
    
    def _get_logger_name(self):
        """Return the log prefix: get_name() without its first 7 chars, upper-cased."""
        return self.get_name()[7:].upper()
    
    
    def get_configuration_fields(self):
        """Return (and cache) the description of the configurable fields.

        Authentication (login/password) and SSL (key/cert) settings, with the
        ``display_bind`` entries linking each checkbox to the fields it enables.
        """
        if self._configuration_fields is None:
            self._configuration_fields = OrderedDict([
                ('configuration', OrderedDict([
                    ('authentication', {
                        'display_name': self._('analyzer.conf_authentication'),
                        'default'     : False,
                        'protected'   : False,
                        'help'        : '',
                        'type'        : 'checkbox',
                        'display_bind': ('login', 'password')
                    }),
                    ('login', {
                        'display_name': self._('analyzer.conf_login'),
                        'default'     : '',
                        'protected'   : False,
                        'help'        : '',
                        'type'        : 'text',
                    }),
                    ('password', {
                        'display_name': self._('analyzer.conf_password'),
                        'default'     : '',
                        'protected'   : True,
                        'help'        : '',
                        'type'        : 'text',
                    }),
                    ('use_ssl', {
                        'display_name': self._('analyzer.conf_use_ssl'),
                        'default'     : False,
                        'protected'   : False,
                        'help'        : '',
                        'type'        : 'checkbox',
                        'display_bind': ('ssl_key', 'ssl_cert')
                    }),
                    ('ssl_key', {
                        'display_name': self._('analyzer.conf_ssl_key'),
                        'default'     : '',
                        'protected'   : False,
                        'help'        : '',
                        'type'        : 'text',
                    }),
                    ('ssl_cert', {
                        'display_name': self._('analyzer.conf_ssl_cert'),
                        'default'     : '',
                        'protected'   : False,
                        'help'        : '',
                        'type'        : 'text',
                    }),
                ])
                 ),
            ])
        return self._configuration_fields
    
    
    def _init_http(self):
        """Create the Bottle app with its routes and build the HTTP server.

        The server is stored in ``self.srv``; it is started later by
        _http_start() in a dedicated thread.
        """
        logger.info("[%s] Starting http socket" % self._get_logger_name())
        conf = self.get_my_configuration()
        ssl_conf = conf.get('configuration', {})
        use_ssl = ssl_conf.get("use_ssl", "")
        ssl_key = ssl_conf.get("ssl_key", "")
        ssl_cert = ssl_conf.get("ssl_cert", "")
        try:
            # instantiate a new Bottle object, don't use the default one otherwise all module will share the same
            app = bottle.Bottle()
            app = self._init_routes(app)
            self.srv = app.run(host=self.host, port=self.port, server=CherryPyServerHTTP, use_ssl=use_ssl, ssl_key=ssl_key, ssl_cert=ssl_cert, **self.serveropts)
        except Exception as e:
            logger.error("[%s] Exception : %s" % (self._get_logger_name(), str(e)))
            raise
        logger.info("[%s] Server loaded" % self._get_logger_name())
    
    
    @classmethod
    def _parse_query_flag(cls, raw_value, default=True):
        """Convert a query-string flag to a bool.

        :param raw_value: the raw query-string value, or None when absent.
        :param default: value returned when the flag is absent.
        :return: False for '', '0', 'false', 'no', 'off' (case-insensitive), True otherwise.
        """
        if raw_value is None:
            return default
        return raw_value.strip().lower() not in cls._FALSE_QUERY_VALUES
    
    
    def _init_routes(self, app):
        """Register the REST routes on ``app`` and return it."""
        
        def base_rest_host():
            # Collection endpoint: PUT creates a host, GET lists them.
            response.content_type = 'application/json'
            self._query_check_auth()
            if request.method == 'PUT':
                return self.post_host()
            elif request.method == 'GET':
                return self.list_hosts()
        
        
        def specific_rest_item(item_uuid):
            # Item endpoint, addressed by the host _SE_UUID.
            response.content_type = 'application/json'
            self._query_check_auth()
            # Parse the flag: the raw string "false"/"0" used to be truthy.
            import_needed = self._parse_query_flag(bottle.request.GET.get('import_needed', None), default=True)
            if request.method == 'DELETE':
                return self.delete_host(item_uuid, import_needed)
            elif request.method == 'GET':
                return self.get_element(item_uuid)
            elif request.method in ('POST', 'PUT', 'PATCH'):
                return self.update_element(item_uuid)
        
        
        app.route('%s/v1/hosts/' % self.url_base_path, callback=base_rest_host, method=('PUT', 'GET',))
        app.route('%s/v1/hosts' % self.url_base_path, callback=base_rest_host, method=('PUT', 'GET',))
        app.route('%s/v1/hosts/:item_uuid' % self.url_base_path, callback=specific_rest_item, method=('DELETE', 'GET', 'POST', 'PUT', 'PATCH'))
        return app
    
    
    def _init_datamanager(self):
        """Create the Mongo data provider from the synchronizer daemon's database and cipher."""
        self.provider_mongo = DataProviderMongo(self.syncdaemon.mongodb_db, self.syncdaemon.database_cipher)
    
    
    def get_dataprovider(self):
        """Return the Mongo data provider, creating it on first use."""
        if self.provider_mongo is None:
            self._init_datamanager()
        return self.provider_mongo
    
    
    def start_listener(self):
        """Start the HTTP server thread unless it is already running."""
        self.get_dataprovider()
        # We must protect against a user that spam the enable/disable button
        with self.srv_lock:
            if self.thread is not None:
                # We already did start, skip it
                if self.thread.is_alive():
                    return
                # A dead thread is left from a previous run: reap it before restarting.
                self.thread.join()
                self.thread = None
            
            self._init_http()
            self._start_thread()
    
    
    def stop_listener(self):
        """Stop the HTTP server and wait (up to 1s) for its thread."""
        # We must protect against a user that spam the enable/disable button.
        # The "already stopped" test is done under the lock: otherwise two
        # concurrent stops could both pass it and the second would join None.
        with self.srv_lock:
            # Already no more thread, we are great
            if self.thread is None:
                return
            logger.info("[%s] Calling stop on our listener" % self._get_logger_name())
            self.srv.stop()
            self.thread.join(1)
            self.thread = None
            logger.info("[%s] Stop listener is done" % self._get_logger_name())
    
    
    def _start_thread(self):
        """Run _http_start() in a daemon thread named 'Listener-REST'."""
        self.thread = threading.Thread(None, target=self._http_start, name='Listener-REST')
        self.thread.daemon = True
        self.thread.start()
    
    
    def _http_start(self):
        """Thread body: start the HTTP server (presumably blocks until stopped)."""
        logger.info('[%s] Server starting' % self._get_logger_name())
        self.srv.start()
    
    
    # WARNING: do not call it check_auth or it will be used for Configuration UI auth!
    def _query_check_auth(self):
        """Check for auth if it's not anonymously allowed.

        :raises HTTPResponse: 401 when no basic-auth credentials are sent,
                              403 when they do not match the configuration.
        """
        auth_conf = self.get_my_configuration().get('configuration', {})
        if not auth_conf.get('authentication'):
            return
        username = auth_conf.get("login", "")
        password = auth_conf.get("password", "")
        basic = parse_auth(request.environ.get('HTTP_AUTHORIZATION', ''))
        # Maybe the user not even ask for user/pass. If so, bail out
        if not basic:
            raise HTTPResponse(self._('listener.error_auth_required'), 401)
        # Maybe he do not give the good credential?
        if basic[0] != username or basic[1] != password:
            raise HTTPResponse(self._('listener.error_auth_denied'), 403)
    
    
    def _return_in_400(self, err):
        """Log ``err`` and abort the request with HTTP 400 (always raises)."""
        logger.error('[%s] bad request object received: %s' % (self._get_logger_name(), err))
        raise HTTPResponse(err, 400)
    
    
    @staticmethod
    def _disallow_duplicate_keys(ordered_pairs):
        """json ``object_pairs_hook`` building a dict and rejecting repeated keys.

        :raises ValueError: message "Duplicate keys:<k1>, <k2>..." when a key repeats.
        """
        d = {}
        duplicate = []
        for k, v in ordered_pairs:
            if k in d:
                duplicate.append(k)
            else:
                d[k] = v
        if duplicate:
            raise ValueError("Duplicate keys:%s" % ", ".join(duplicate))
        return d
    
    
    @staticmethod
    def _find_forbidden_keys(keys):
        """Return the upper-cased/stripped ``keys`` that appear in NOT_TO_LOOK."""
        forbidden = set(name.upper().strip().decode('utf8', 'ignore') for name in NOT_TO_LOOK)
        return set(key.upper().strip() for key in keys).intersection(forbidden)
    
    
    def _get_data(self):
        """Read the request body and decode it as a flat JSON object.

        Non-string values are converted to strings: iterables are joined with
        ',' and anything else goes through str(). An empty key is dropped.

        :return: the decoded dict.
        :raises HTTPResponse: 400 when the body is empty, is not valid JSON,
                              is not a JSON object or repeats a key.
        """
        # Read the whole body: readline() only returned the first line, so a
        # pretty-printed (multi-line) JSON document was rejected as invalid.
        data = request.body.read()
        if not data:
            self._return_in_400(self._('listener.error_no_data'))
        try:
            try:
                decoded_data = json.loads(data, object_pairs_hook=self._disallow_duplicate_keys)
            except TypeError:
                # In python 2.6 object_pairs_hook does not exist. Only TypeError is
                # caught so the hook's "Duplicate keys" ValueError reaches the
                # handler below instead of being retried without detection.
                decoded_data = json.loads(data)
            if not isinstance(decoded_data, dict):
                raise ValueError('a JSON object is expected')
            decoded_data.pop('', None)
            # items() returns a copy in python 2, so values can be rewritten safely.
            for key, value in decoded_data.items():
                if isinstance(value, basestring):
                    continue
                if isinstance(value, Iterable):
                    decoded_data[key] = ','.join(value)
                else:
                    decoded_data[key] = str(value)
            return decoded_data
        except Exception as exp:
            message = getattr(exp, 'message', '')
            marker = "Duplicate keys:"
            if isinstance(message, basestring) and marker in message:
                # Split on the marker only, so keys containing ':' are kept whole.
                duplicate_keys = message.split(marker, 1)[1]
                return self._return_in_400(self._('listener.error_duplicate_key') % duplicate_keys)
            self._return_in_400(self._('listener.error_bad_json') % exp)
    
    
    def delete_host(self, item_se_uuid, import_needed=True):
        """Delete this source's raw hosts whose _SE_UUID is ``item_se_uuid``.

        :param item_se_uuid: the _SE_UUID of the host to delete.
        :param import_needed: forwarded to the synchronizer delete callback.
        :return: "done".
        :raises HTTPResponse: 404 when no such host exists.
        """
        hosts = self.get_dataprovider().find_items(ITEM_TYPE.HOSTS, item_state=ITEM_STATE.RAW_SOURCES, item_source=self.get_name(), where={'_SE_UUID': item_se_uuid})
        if not hosts:
            raise HTTPResponse(self._('listener.error_host_id') % item_se_uuid, 404)
        with self.hosts_lock:
            for host in hosts:
                self.get_dataprovider().delete_item(host, ITEM_TYPE.HOSTS, item_state=ITEM_STATE.RAW_SOURCES, item_source=self.get_name())
        response.status = 200
        self.callback_synchronizer_about_delete_elements(items_type=ITEM_TYPE.HOSTS, data={'_id': item_se_uuid}, import_needed=import_needed)
        return "done"
    
    
    def update_element(self, item_uuid):
        """Replace the data of an existing host with the request payload.

        The host is looked up first in this source's raw hosts (by _SE_UUID),
        then in the working area / staging (by _SE_UUID when ``item_uuid``
        starts with 'core', by _id otherwise). Only its name, _SE_UUID and _id
        are kept; every other property comes from the payload.

        :param item_uuid: _SE_UUID (or staging _id) of the host.
        :return: "done".
        :raises HTTPResponse: 400 on invalid payload, 404 when the host is unknown.
        """
        name_key = DEF_ITEMS[ITEM_TYPE.HOSTS]['key_name']
        hosts = self.get_dataprovider().find_items(ITEM_TYPE.HOSTS, item_state=ITEM_STATE.RAW_SOURCES, item_source=self.get_name(), where={'_SE_UUID': item_uuid})
        new_data = self._get_data()
        item_type = get_type_item_from_class(ITEM_TYPE.HOSTS, new_data)
        if item_type != ITEM_TYPE.HOSTS:
            return self._return_in_400(self._('listener.error_type_not_allowed') % item_type)
        
        reference_host = None
        if hosts:
            reference_host = hosts[0]
        else:
            # if no host in listeners, try to find it in workarea or staging
            where = {'_id': item_uuid}
            if item_uuid.startswith('core'):
                where = {'_SE_UUID': item_uuid}
            stagging_hosts = self.get_dataprovider().find_merge_state_items(ITEM_TYPE.HOSTS, item_states=[ITEM_STATE.WORKING_AREA, ITEM_STATE.STAGGING], where=where)
            if any(stagging_hosts):
                reference_host = stagging_hosts[0]
        
        # make sure the update host have the same id as the old one
        if reference_host is None:
            raise HTTPResponse(self._('listener.error_host_se_uuid') % item_uuid, 404)
        host = {
            name_key  : reference_host[name_key],
            '_SE_UUID': "core-%s-%s" % (ITEM_TYPE.HOSTS, reference_host['_id']),
            '_id'     : reference_host['_id']
        }
        
        # use the NOT_TO_LOOK as list of forbidden keys
        forbidden_keys = self._find_forbidden_keys(new_data.keys())
        if forbidden_keys:
            return self._return_in_400(self._('listener.error_forbidden_keys') % ", ".join(forbidden_keys))
        
        # add the '_id' and host name in sent data so we can keep track of this object (in the sources history)
        sent_data = new_data.copy()
        sent_data['_id'] = item_uuid
        if name_key not in sent_data:
            sent_data[name_key] = host[name_key]
        
        host.update(new_data)
        host['update_date'] = time.time()
        
        # `host` is rebuilt from scratch above, so _SYNC_KEYS is only present here
        # when the payload sent it. Otherwise derive it from the (possibly renamed)
        # host, lower-cased like post_host() does. The previous rename branch
        # produced a list holding only the new name, dropping the _SE_UUID.
        if not host.get('_SYNC_KEYS', None):
            host['_SYNC_KEYS'] = "%s,%s" % (host['_SE_UUID'].lower(), host[name_key].lower())
        # NOTE(review): a non-ASCII host name makes this encode raise
        # UnicodeEncodeError (HTTP 500) -- confirm ASCII-only sync keys are intended.
        if isinstance(host.get('_SYNC_KEYS', None), unicode):
            host['_SYNC_KEYS'] = host['_SYNC_KEYS'].encode("ascii")
        with self.hosts_lock:
            ciphered_host = self.syncdaemon.database_cipher.cipher(host, item_type=ITEM_TYPE.HOSTS)
            self.get_dataprovider().save_item(ciphered_host, item_type=ITEM_TYPE.HOSTS, item_state=ITEM_STATE.RAW_SOURCES, item_source=self.get_name())
        response.status = 200
        self.callback_synchronizer_about_update_elements(items_type=ITEM_TYPE.HOSTS, data=sent_data)
        return "done"
    
    
    def get_element(self, item_uuid):
        """Return the JSON of this source's host with _SE_UUID ``item_uuid``.

        :raises HTTPResponse: 404 when the host is unknown.
        """
        ciphered_hosts = self.get_dataprovider().find_items(ITEM_TYPE.HOSTS, item_state=ITEM_STATE.RAW_SOURCES, item_source=self.get_name(), where={'_SE_UUID': item_uuid})
        if not ciphered_hosts:
            raise HTTPResponse(self._('listener.error_host_se_uuid') % item_uuid, 404)
        host = self.syncdaemon.database_cipher.uncipher(ciphered_hosts[0], item_type=ITEM_TYPE.HOSTS)
        host.pop('@metadata', None)
        return json.dumps(host)
    
    
    def list_hosts(self):
        """Return the JSON list of all this source's hosts, unciphered."""
        raw_hosts = self.get_dataprovider().find_items(ITEM_TYPE.HOSTS, item_state=ITEM_STATE.RAW_SOURCES, item_source=self.get_name())
        hosts = []
        for ciphered_host in raw_hosts:
            # Use uncipher()'s return value (as get_element does) rather than
            # relying on it mutating its argument in place.
            host = self.syncdaemon.database_cipher.uncipher(ciphered_host, item_type=ITEM_TYPE.HOSTS)
            host.pop('@metadata', None)
            hosts.append(host)
        
        return json.dumps(hosts)
    
    
    def post_host(self):
        """Create (or merge into an existing same-named) host from the payload.

        :return: the JSON-encoded _SE_UUID of the stored host (HTTP 201).
        :raises HTTPResponse: 400 on invalid payload.
        """
        host = self._get_data()
        sent_data = host.copy()
        logger.info('[%s] did receive host from API: %s' % (self._get_logger_name(), host))
        
        item_type = get_type_item_from_class(ITEM_TYPE.HOSTS, sent_data)
        if item_type != ITEM_TYPE.HOSTS:
            return self._return_in_400(self._('listener.error_type_not_allowed') % item_type)
        # use the NOT_TO_LOOK as list of forbidden keys
        forbidden_keys = self._find_forbidden_keys(host.keys())
        if forbidden_keys:
            return self._return_in_400(self._('listener.error_forbidden_keys') % ", ".join(forbidden_keys))
        
        # A property cannot be sent both plainly and with its "[FORCE]" suffix.
        duplicate_keys = [key[:-7] for key in host if key.endswith('[FORCE]') and host.get(key[:-7], "")]
        if duplicate_keys:
            return self._return_in_400(self._('listener.error_duplicate_key') % ", ".join(duplicate_keys))
        
        # Validate the name before it is used in database lookups.
        name = host.get('host_name', '')
        if not name:
            return self._return_in_400(self._('listener.error_missing_hostname_field'))
        if not isinstance(name, basestring):
            return self._return_in_400(self._('listener.error_host_name_not_string'))
        
        # try to find if an already object exists with this name
        if '_id' not in host:
            old_hosts = self.get_dataprovider().find_items(ITEM_TYPE.HOSTS, item_state=ITEM_STATE.RAW_SOURCES, item_source=self.get_name(), where={'host_name': name})
            if any(old_hosts):
                # an old host have been found, update the host and return the id
                old_host = self.syncdaemon.database_cipher.uncipher(old_hosts[0], item_type=ITEM_TYPE.HOSTS)
                # remove the _SYNC_KEYS, in any case they will be override
                old_host.pop('_SYNC_KEYS', None)
                # Drop old properties overridden by their "[FORCE]" twin (or vice versa) in the payload.
                key_to_remove = []
                for key in old_host:
                    try_key = key[:-7] if key.endswith('[FORCE]') else "%s[FORCE]" % key
                    if host.get(try_key, ""):
                        key_to_remove.append(key)
                for key in key_to_remove:
                    old_host.pop(key)
                old_host.update(host)
                host = old_host
                # the _id is an Object, and mongo will refuse to save it again, cast it in str
                host['_id'] = str(host['_id'])
            else:
                # try to find an object with the same name in working area / staging
                # for them just reuse their _id and/or _SE_UUID if available
                stagging_hosts = self.get_dataprovider().find_merge_state_items(ITEM_TYPE.HOSTS, item_states=[ITEM_STATE.WORKING_AREA, ITEM_STATE.STAGGING], where={'host_name': name})
                if any(stagging_hosts):
                    staggin_host = stagging_hosts[0]
                    host['_SE_UUID'] = "core-%s-%s" % (item_type, staggin_host['_id'])
                    host['_id'] = staggin_host['_id']
        
        if '_SE_UUID' in host:
            h_uuid = host['_SE_UUID']
        else:
            _id = uuid.uuid4().hex
            host['_id'] = _id
            h_uuid = "core-%s-%s" % (item_type, _id)
            host['_SE_UUID'] = h_uuid
        
        # Custom ("_"-prefixed) properties must be upper-case string keys with string values.
        for (k, v) in host.iteritems():
            # Type check first: startswith() would fail on a non-string key.
            if not isinstance(k, basestring):
                return self._return_in_400(self._('listener.error_data_keys_not_strings') % (type(k), k))
            if not k.startswith('_'):
                continue
            if k != '_id' and k != k.upper():
                return self._return_in_400(self._('listener.error_data_keys_not_uppercase'))
            if not isinstance(v, basestring):
                return self._return_in_400(self._('listener.error_data_values_not_strings') % (type(v), k, v))
        
        address = host.get('address', '')
        if address and not isinstance(address, basestring):
            return self._return_in_400(self._('listener.error_address_not_string'))
        
        host['_SYNC_KEYS'] = "%s,%s" % (h_uuid.lower(), name.lower())
        host['update_date'] = time.time()
        # request.environ.get('HTTP_USER_AGENT', "") # if needed
        host['imported_from'] = u"%s %s" % (self.get_my_source().get_name(), self._("listener.sent_from") % request.environ.get('REMOTE_ADDR', "UNSET"))
        logger.info('[%s] host %s(%s) is valid' % (self._get_logger_name(), name, h_uuid))
        
        # NOTE(review): a non-ASCII host name makes this encode raise
        # UnicodeEncodeError (HTTP 500) -- confirm ASCII-only sync keys are intended.
        if isinstance(host['_SYNC_KEYS'], unicode):
            host['_SYNC_KEYS'] = host['_SYNC_KEYS'].encode("ascii")
        with self.hosts_lock:
            ciphered_host = self.syncdaemon.database_cipher.cipher(host, item_type=ITEM_TYPE.HOSTS)
            self.get_dataprovider().save_item(ciphered_host, item_type=ITEM_TYPE.HOSTS, item_state=ITEM_STATE.RAW_SOURCES, item_source=self.get_name())
        
        self.callback_synchronizer_about_new_elements(items_type=ITEM_TYPE.HOSTS, data=sent_data)
        response.status = 201
        return json.dumps(h_uuid)
    
    
    def get_all_discovery_elements(self):
        """Return this source's hosts, unciphered, in the synchronizer import format.

        ``update_date`` is moved from each host into its metadata (UPDATE_DATE).

        :return: dict with 'state', 'output', 'objects' ({'host': [...]}),
                 'errors' and 'warnings'.
        """
        # Get a copy of the values as we don't know how much time they will be keep outside this code, and so outside the lock
        with self.hosts_lock:
            my_hosts = self.get_dataprovider().find_items(ITEM_TYPE.HOSTS, item_state=ITEM_STATE.RAW_SOURCES, item_source=self.get_name())
        
        # remove unnecessary property for object after merge and set them as metadata
        hosts = []
        for ciphered_host in my_hosts:
            update_date = ciphered_host.pop('update_date', None)
            # Keep uncipher()'s result: the previous code discarded it and
            # returned the original (ciphered) dicts.
            host = self.syncdaemon.database_cipher.uncipher(ciphered_host, item_type=ITEM_TYPE.HOSTS)
            if update_date:
                METADATA.update_metadata(host, METADATA.UPDATE_DATE, int(update_date))
            hosts.append(host)
        
        output = self.syncdaemon._('listener.output_successful') % len(hosts)
        return {'state': 'OK', 'output': output, 'objects': {'host': hosts}, 'errors': [], 'warnings': []}
    
    
    def remove_source_item(self, item_type, source_item):
        """Delete ``source_item`` (by its _SE_UUID) without requesting an import.

        ``item_type`` is ignored: only hosts are handled by this source.
        """
        with self.hosts_lock:
            self.delete_host(source_item['_SE_UUID'], import_needed=False)
