#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright (C) 2009-2022:
#     Gabes Jean, naparuba@gmail.com
#     Gerhard Lausser, Gerhard.Lausser@consol.de
#     Gregory Starck, g.starck@gmail.com
#     Hartmut Goebel, h.goebel@goebel-consult.de
#
# This file is part of Shinken.
#
# Shinken is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Shinken is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with Shinken.  If not, see <http://www.gnu.org/licenses/>.

import errno
import inspect
import os
import queue
import socket
import threading
import time
import traceback
from http.server import BaseHTTPRequestHandler

import select
from cheroot.server import get_ssl_adapter_class
from cheroot.wsgi import WSGIServer

import shinkensolutions.shinkenjson as json
from shinken.compresser import compresser
from shinken.daemoninfo import daemon_info
from shinken.log import logger, LoggerFactory
from shinken.runtime_stats.cpu_stats import cpu_stats_helper
from shinken.safepickle import SafeUnpickler, SERIALIZATION_SECURITY_EXCEPTION
from shinken.thread_helper import get_thread_id
from shinken.util import make_unicode
from shinken.webui import bottlecore as bottle

# Per-call debug tracing of incoming HTTP requests (used in HTTPDaemon.register) is
# switched on only if this flag file exists at the moment the module is imported.
ALLOW_HTTP_CALLS_DEBUG = os.path.exists('/tmp/SHINKEN_ALLOW_HTTP_CALLS_DEBUG_FLAG')

# Module loggers: one sub-part for HTTP call tracing, one for the WSGI (Cheroot) server messages
raw_logger = LoggerFactory.get_logger()
logger_http = raw_logger.get_sub_part('HTTP CALL')
wsgi_logger = raw_logger.get_sub_part('WSGI')

# ssl is optional: the name is bound to None when the module cannot be imported
try:
    import ssl
except ImportError:
    ssl = None

# Switch bottle into debug mode at import time
bottle.debug(True)


class InvalidWorkDir(Exception):
    """Raised when the HTTP backend cannot be built for a reason other than a socket error."""


class PortNotFree(Exception):
    """Raised when a socket error prevents the HTTP server from using its port."""


class HTTPDaemonBusinessError(bottle.HTTPError):
    """HTTP error raised on purpose by business code.
    
    The ``business_error`` flag lets the HTTP call wrapper log it as a plain
    error message, without counting it as an HTTP error or dumping a stack.
    """
    
    def __init__(self, status=None, body=None, exception=None, _traceback=None, **options):
        super(HTTPDaemonBusinessError, self).__init__(status, body, exception, _traceback, **options)
        self.business_error = True
        self.code = status  # HTTP status code given by the raiser
        self.text = body  # error body; may be str, bytes or None
    
    
    def __repr__(self):
        # __repr__ MUST return a str: returning a None or bytes body as-is made
        # repr() raise "TypeError: __repr__ returned non-string".
        text = self.text
        if isinstance(text, bytes):
            return text.decode('utf8', 'ignore')
        if isinstance(text, str):
            return text
        return str(text)
    
    
    def __str__(self):
        return 'HTTPDaemonBusinessError[code:%s, text:%s]' % (self.code, self.text)


class ShinkenWSGI(WSGIServer):
    """Cheroot WSGI server whose internal error reporting goes to the Shinken 'WSGI' logger."""
    
    # noinspection PyShadowingNames
    def error_log(self, msg='', level=20, traceback=False):
        """Cheroot error hook override.
        
        The message is always logged at warning level (``level`` is ignored);
        when ``traceback`` is true, the logger's ``print_stack()`` is called too.
        """
        wsgi_logger.warning(msg)
        if not traceback:
            return
        wsgi_logger.print_stack()


# CherryPy is allowing us to have a HTTP 1.1 server, and so have a KeepAlive
class CherryPyServer(bottle.ServerAdapter):
    """Bottle server adapter that builds a Cheroot (CherryPy) WSGI server.
    
    Cheroot gives us an HTTP/1.1 server, and so KeepAlive support.
    ``run()`` only builds and returns the server; it does not start it.
    """
    quiet = True
    
    
    def run(self, handler):
        opts = self.options
        pool_size = opts['daemon_thread_pool_size']
        wsgi_server = ShinkenWSGI((self.host, self.port), handler, numthreads=pool_size, shutdown_timeout=1, timeout=300)
        wsgi_logger.info('Initializing a Cheroot backend with %d threads' % pool_size)
        
        ssl_wanted = opts['use_ssl']
        _ca_cert = opts['ca_cert']  # looked up, but not passed to the SSL adapter
        cert_file = opts['ssl_cert']
        key_file = opts['ssl_key']
        
        if ssl_wanted:
            adapter_class = get_ssl_adapter_class()
            wsgi_server.ssl_adapter = adapter_class(certificate=cert_file, private_key=key_file)
        return wsgi_server


class CherryPyBackend:
    """Owns the Cheroot WSGI server that serves the bottle application.
    
    The server is built in ``__init__`` (``bottle.run`` with the ``CherryPyServer``
    adapter, expected to return a not-yet-started ``ShinkenWSGI``) and only
    serves requests once ``run()`` is called.
    """
    
    def __init__(self, host, port, use_ssl, ca_cert, ssl_key, ssl_cert, hard_ssl_name_check, daemon_thread_pool_size):
        """Install the bottle HTTP 500 error page and build the WSGI server.
        
        ``hard_ssl_name_check`` is accepted for interface compatibility but not used here.
        
        :raises PortNotFree: on a socket error while building the server.
        :raises InvalidWorkDir: on any other error while building the server.
        """
        self.port = port
        self.use_ssl = use_ssl
        try:
            
            @bottle.error(500)
            def print_traceback(error):
                # Render the error as lines joined with <br/> for the HTTP 500 page
                if error.traceback:
                    formatted_lines = error.traceback.split('\n')
                    first_line = formatted_lines[0]
                    to_print = []
                    if first_line == 'None':
                        # No real traceback recorded: show the current stack instead
                        for line in traceback.format_stack():
                            to_print.append("%s" % line)
                    else:
                        to_print.append("ERROR stack : %s" % first_line)
                        for line in formatted_lines[1:]:
                            to_print.append("%s" % line)
                elif isinstance(error, bottle.HTTPError):
                    # app.abord make a HTTPError without traceback see SEF-6067
                    to_print = ['HTTPError : code[%s] message:[%s]' % (error.status, error.body)]
                else:
                    to_print = [str(error)]
                return "<br/>".join(to_print)
            
            
            # Remember the creating process: stop() behaves differently in a forked child
            self.srv_pid = os.getpid()
            self.srv = bottle.run(host=host, port=port, server=CherryPyServer, use_ssl=use_ssl, ca_cert=ca_cert, ssl_key=ssl_key, ssl_cert=ssl_cert, daemon_thread_pool_size=daemon_thread_pool_size, quiet=True)  # type: ShinkenWSGI
        except socket.error as exp:
            # %s (as in run()) rather than %d: a non-int port must not turn this
            # error report into a TypeError that hides the real socket error
            msg = "Error: Sorry, the port %s is not free: %s" % (self.port, str(exp))
            raise PortNotFree(msg) from exp
        except Exception as e:
            # must be a problem with http workdir:
            raise InvalidWorkDir(e) from e
    
    
    # When call, it do not have a socket
    def get_sockets(self):
        """Return the sockets to watch with select(): always empty for this backend."""
        return []
    
    
    def stop(self):
        """Stop the WSGI server; no-op if it is already gone.
        
        In the creating process, the server's ``stop()`` is called and any failure
        is only logged. In a forked child, ``stop()`` is NOT called: the request
        queue and the connection selector lock (which may be problematic after a
        fork) are replaced, then the server reference is dropped.
        """
        if self.srv is None:
            return
        if os.getpid() != self.srv_pid:
            logger.debug(f'Stopping WSGI server on port {self.port} in forked process ')
            # Try to reinitialize some locks (which may be problematic after a fork)
            self.srv.requests._queue = queue.Queue()
            self.srv._connections._selector._lock = threading.RLock()  # noqa
            del self.srv
            self.srv = None
            self.srv_pid = None
            logger.debug(f'Stopped WSGI server on port {self.port} in forked process ')
            return
        try:
            self.srv.stop()
        except Exception as exp:
            logger.warning('Cannot stop the CherryPy backend : %s' % exp)
    
    
    def run(self):
        """Serve requests; blocks until the server stops, then always stops it.
        
        :raises PortNotFree: on a socket error while serving.
        """
        try:
            self.srv.start()
        except socket.error as exp:
            msg = "Error: Sorry, the port %s is not free: %s" % (self.port, str(exp))
            raise PortNotFree(msg) from exp
        finally:
            self.srv.stop()


class HTTPDaemon:
    """HTTP front end of a Shinken daemon.
    
    Public bound methods of the objects given to ``register()`` are exposed as
    bottle routes whose responses are the JSON dump of the method return value;
    requests are served by a ``CherryPyBackend``. A ``port`` of 0 disables the
    HTTP server.
    """
    
    def __init__(self, host, port, use_ssl, ca_cert, ssl_key, ssl_cert, hard_ssl_name_check, daemon_thread_pool_size):
        self.port = port
        self.host = host
        self.bottle = bottle
        self.abort = bottle.abort
        self.use_ssl = use_ssl
        
        # Set up before the "port == 0" early return: register() and shutdown()
        # read these attributes and used to fail with AttributeError on a daemon
        # whose HTTP server is disabled.
        self.srv = None
        self.lock = threading.RLock()  # taken around calls of functions that need it
        self.is_stopping = False
        self.registered_fun = {}
        self.registered_fun_names = []
        self.registered_fun_defaults = {}  # function name -> {arg name: default value}
        
        # Port = 0 means "I don't want HTTP server"
        if self.port == 0:
            return
        
        protocol = 'https' if use_ssl else 'http'
        self.uri = '%s://%s:%s' % (protocol, self.host, self.port)
        logger.info("Opening HTTP socket at %s" % self.uri)
        
        # Hack the BaseHTTPServer so only IP will be looked by wsgiref, and not names
        # (address_string returns the raw client IP, no name resolution)
        BaseHTTPRequestHandler.address_string = lambda x: x.client_address[0]
        
        self.srv = CherryPyBackend(host, port, use_ssl, ca_cert, ssl_key, ssl_cert, hard_ssl_name_check, daemon_thread_pool_size)
    
    
    # Get the server socket but not if disabled or closed
    def get_sockets(self):
        if self.port == 0 or self.srv is None:
            return []
        return self.srv.get_sockets()
    
    
    def run(self):
        """Run the HTTP backend; blocks while it serves."""
        self.srv.run()
    
    
    def register(self, obj):
        """Publish the public bound methods of ``obj`` as HTTP routes.
        
        Each method ``foo_bar`` is routed at ``<obj.route_prefix>/[v<obj.API_VERSION>/]foo_bar``
        (plus a ``foo-bar`` alias) with the HTTP verb of its ``method`` attribute (GET by
        default). The method attributes ``need_lock``, ``public_api`` and ``display_name``
        override ``obj.default_lock``, ``obj.public_api`` and the logged route name.
        Methods whose name starts with '_' are skipped.
        """
        methods = inspect.getmembers(obj, predicate=inspect.ismethod)
        merge = [fname for (fname, f) in methods if fname in self.registered_fun_names]
        if merge:
            # Some names are already registered: only keep the functions defined
            # directly in obj's own class
            methods_in = [m.__name__ for m in obj.__class__.__dict__.values() if inspect.isfunction(m)]
            methods = [m for m in methods if m[0] in methods_in]
        
        for (fname, f) in methods:
            if fname.startswith('_'):
                continue
            # Get the args of the function to catch them in the queries
            argspec = inspect.getfullargspec(f)
            args = argspec.args
            defaults = argspec.defaults
            # If we got some defaults, save arg=value so we can lookup for them after
            # (computed before 'self' is removed: defaults always match the last args)
            if defaults:
                self.registered_fun_defaults[fname] = dict(zip(args[-len(defaults):], defaults))
            # remove useless self in args, because we already got a bound method f
            if 'self' in args:
                args.remove('self')
            self.registered_fun_names.append(fname)
            self.registered_fun[fname] = f
            
            
            # WARNING : we MUST do a 2 levels function here, or the f_wrapper
            # will be uniq and so will link to the last function again and again
            def register_callback(function_name, args, f, obj, lock):
                def f_wrapper():
                    caller = bottle.request.environ.get('HTTP_X_CALLER_NAME', '(unknown)')
                    remote_addr = bottle.request.environ.get('REMOTE_ADDR', '(unknown)')
                    if ALLOW_HTTP_CALLS_DEBUG:
                        logger_http.debug('[HTTP] incoming http request: thread=%s caller=%-15s/%-15s  functioncalled=%s' % (get_thread_id(), caller, remote_addr, function_name))
                    # Integer hour bucket, used as key of obj.app.http_errors_count.
                    # MUST be a floor division: with "/" (true division on Python 3) the
                    # key was a float that never matched the stored int keys, so the error
                    # counter was reset to 1 instead of being incremented.
                    hours_since_epoch = int(time.time()) // 3600
                    ret = None
                    try:
                        t0 = time.time()
                        cpu_snap = cpu_stats_helper.get_thread_cpu_snapshot()
                        args_time = aqu_lock_time = calling_time = json_time = 0
                        need_lock = getattr(f, 'need_lock', obj.default_lock)
                        
                        # Warning : put the bottle.response set inside the wrapper
                        # because outside it will break bottle
                        d = {}
                        method = getattr(f, 'method', 'get').lower()
                        for aname in args:
                            v = None
                            if method == 'post':
                                v = bottle.request.forms.get(aname, None)
                                # Post args are compressed then pickled: decompress, then
                                # unpickle through SafeUnpickler
                                if v:
                                    v = compresser.decompress(v.encode('latin1'))
                                    caller_string_display = 'HTTP(s) call "%s" by IP=%s' % (getattr(f, 'display_name', '/%s' % function_name), remote_addr)
                                    try:
                                        v = SafeUnpickler.loads(v, caller_string_display)
                                    except SERIALIZATION_SECURITY_EXCEPTION:
                                        # NOTE: we already logged in loads in ERROR with the / and the addr, so we are ok for debug/support
                                        return None  # NOTE: don't give the other any clue about this, so it can't detect if attack succeed or not
                            
                            elif method == 'get':
                                v = bottle.request.GET.get(aname, None)
                            if v is None:
                                # Maybe we got a default value?
                                default_args = self.registered_fun_defaults.get(function_name, {})
                                if aname not in default_args:
                                    _err = 'HTTP/500: missing argument %s for function %s' % (aname, function_name)
                                    if not daemon_info.daemon_is_requested_to_stop.value:
                                        logger_http.error(_err)
                                    raise bottle.HTTPError(500, _err)
                                v = default_args[aname]
                            d[aname] = v
                        
                        # Clean args from bottle+cherrypy in memory
                        # Always clean bottle environ data as bottle won't do it itself (or cherrypy?)
                        if method == 'post':
                            bottle.request.environ['wsgi.input'].truncate(0)  # StringIO
                            bottle.request.environ['bottle.request.body'].truncate(0)  # StringIO too
                            # forms parameters are kept too, in Multidict (with a .dict)
                            bottle.request.environ['bottle.request.post'].dict.clear()
                            bottle.request.environ['bottle.request.forms'].dict.clear()
                        
                        args_time = time.time() - t0
                        
                        t1 = time.time()
                        if need_lock:
                            lock.acquire()
                        aqu_lock_time = time.time() - t1
                        
                        t2 = time.time()
                        
                        try:
                            ret = f(**d)
                        # Always call the lock release if need
                        finally:
                            if need_lock:
                                lock.release()
                        calling_time = time.time() - t2
                        t3 = time.time()
                        
                        ret_json = json.dumps(make_unicode(ret))
                        json_time = time.time() - t3
                        
                        # Do not log very small times
                        if calling_time > 0.05 or aqu_lock_time > 0.05 or json_time > 0.05 or args_time > 0.05:
                            logger_http.debug(" [ PERF ] : %s [args:%.4f] [aqu_lock:%.4f] [calling:%.4f] [json:%.4f]" %
                                              (function_name, args_time, aqu_lock_time, calling_time, json_time))
                            logger_http.debug('[ %s ] %s ' % (function_name, cpu_snap.get_diff()))
                        
                        # all app daemons will have this but for schedulers app can be Scheduler and not a daemon
                        if hasattr(obj.app, 'http_errors_count'):
                            # Clean up error counts older than 24 hours (or in the future)
                            for key in list(obj.app.http_errors_count.keys()):
                                if (int(key) < (hours_since_epoch - 24)) or int(key) > hours_since_epoch:
                                    del obj.app.http_errors_count[key]
                        
                        return ret_json
                    except Exception as e:
                        if not daemon_info.daemon_is_requested_to_stop.value:
                            is_public_api = getattr(f, 'public_api', obj.public_api)
                            if is_public_api and isinstance(e, bottle.HTTPError):
                                logger_http.error('error in %s : %s' % (function_name, e))
                            elif getattr(e, 'business_error', False):
                                if hasattr(e, 'text'):
                                    error_msg = e.text
                                    if isinstance(error_msg, bytes):
                                        error_msg = error_msg.decode('utf8', 'ignore')
                                else:
                                    error_msg = e
                                logger_http.error('error in %s : %s' % (function_name, error_msg))
                            else:
                                # all app daemons will have this but for schedulers app can be Scheduler and not a daemon
                                if hasattr(obj.app, 'http_errors_count'):
                                    # Store HTTP errors per hour, for the last 24 hours
                                    if hours_since_epoch in obj.app.http_errors_count:
                                        obj.app.http_errors_count[hours_since_epoch] += 1
                                    else:
                                        obj.app.http_errors_count[hours_since_epoch] = 1
                                    
                                    logger_http.error('An error occurred when calling [%s], add http_errors_count %s' % (function_name, obj.app.http_errors_count))
                                    if ret:
                                        logger_http.debug('The function [%s] returns value [%s]' % (function_name, ret))
                                    else:
                                        logger_http.debug('The call was not successful or the function [%s] return "None"' % function_name)
                                
                                logger.print_stack()
                            raise
                
                
                # here we build the route with prefix and api version. To add an api_version, you need a route prefix.
                # In default and many case (no prefix and no api version), the route will be : /my_method
                # In the case of the interface have only a route prefix, the route will be : /my-prefix/my_method
                # In the case of the interface have route prefix and api version, the route will be: /my-prefix/v1/my_method
                version_route_prefix = ''
                if obj.API_VERSION != 0:
                    version_route_prefix = 'v%s/' % obj.API_VERSION
                final_route = '%s/%s%s' % (obj.route_prefix, version_route_prefix, function_name)
                # and the name with - instead of _ if need
                final_route_dash_replaced = '%s/%s%s' % (obj.route_prefix, version_route_prefix, function_name.replace('_', '-'))
                http_verb = getattr(f, 'method', 'get').upper()
                # Ok now really put the route in place
                bottle.route(final_route, callback=f_wrapper, method=http_verb)
                if final_route != final_route_dash_replaced:
                    bottle.route(final_route_dash_replaced, callback=f_wrapper, method=http_verb)
            
            
            register_callback(fname, args, f, obj, self.lock)
        
        bottle.route('/', callback=HTTPDaemon._slash)
    
    
    # Add a simple / page
    @staticmethod
    def _slash():
        return "OK"
    
    
    # TODO to remove ( legacy form pyro )
    def unregister(self, obj):
        return
    
    
    def handleRequests(self, s):
        # NOTE(review): legacy Pyro entry point; the CherryPyBackend set in __init__
        # defines no handle_request(), so this raises AttributeError if called.
        self.srv.handle_request()
    
    
    @staticmethod
    def create_uri(address, port, obj_name, use_ssl=False):
        """Legacy Pyro URI builder; ``use_ssl`` is ignored.
        
        Static so it works both as ``HTTPDaemon.create_uri(...)`` and on an instance
        (it previously had no ``self``, so instance calls shifted every argument).
        """
        return "PYRO:%s@%s:%d" % (obj_name, address, port)
    
    
    @staticmethod
    def set_timeout(con, timeout):
        """Legacy Pyro helper: store ``timeout`` on the connection object ``con``."""
        con._pyroTimeout = timeout
    
    
    def _do_shutdown(self):
        """Stop and drop the backend, only once (runs in the shutdown thread)."""
        if self.is_stopping:
            return
        self.is_stopping = True
        if self.srv is not None:
            self.srv.stop()
            self.srv = None
    
    
    # Close all sockets and delete the server object to be sure
    # no one is still alive
    def shutdown(self, quiet=False):
        """Start stopping the HTTP server in a daemon thread; does not wait for it."""
        if self.srv is None:
            return
        if not quiet:
            logger.debug('Closing the http socket on the process %s' % os.getpid())
        if os.getpid() != self.srv.srv_pid:
            # If shutdown is requested in a forked process, avoid deadlock by resetting the lock
            self.lock = threading.RLock()
        thread = threading.Thread(target=self._do_shutdown, name='Stopping cherrypy http server')
        thread.daemon = True
        thread.start()
        # This thread will never be joined, but it should be called only once a process
    
    
    def get_socks_activity(self, timeout):
        """Return the server sockets readable within ``timeout`` seconds ([] if interrupted)."""
        try:
            ins, _, _ = select.select(self.get_sockets(), [], [], timeout)
        except select.error as e:
            # select.error is OSError on Python 3: it cannot be unpacked as an
            # (errno, message) tuple, read its errno attribute instead
            if e.errno == errno.EINTR:
                return []
            raise
        return ins


# Module-level handle on a daemon instance; stays None until
# http_daemon_set_daemon_inst() is called.
daemon_inst = None


def http_daemon_set_daemon_inst(_inst):
    """Rebind the module-level ``daemon_inst`` to ``_inst``."""
    global daemon_inst
    daemon_inst = _inst
