import inspect
import math
import os
import signal
import sys
import time
import uuid
from ctypes import c_bool
from multiprocessing import Process, Event, Queue, Value
from threading import Thread, Lock

from shinken.misc.type_hint import Optional, NoReturn, TYPE_CHECKING
from shinken.subprocess_helper.after_fork_cleanup import after_fork_cleanup
from shinken.subprocess_helper.error_handler import ERROR_LEVEL, ErrorHandler
from .compat import Empty as EmptyQueue
from .log import logger as _logger, LoggerFactory, PART_INITIALISATION, PartLogger
from .runtime_stats.threads_dumper import thread_dumper
from .runtime_stats.memory_stats import MemoryStats
from .thread_helper import Thread
from .util import start_malloc_trim_thread, set_process_name

if TYPE_CHECKING:
    from .ipc.shinken_queue.shinken_queue import ShinkenQueue
    from .log import PartLogger
    from .misc.type_hint import List, Any, Optional, Callable, Str, Number, Union

# Separator used to pack a command name and its arguments into one string
# (ToModuleCommandRequest) and to split them again (CommandQueueHandler).
ARG_SEPARATOR = ':::'
# Longest single time.sleep() done by BaseSubProcess.interruptable_sleep: longer sleeps are cut
# into chunks of this size so the ``interrupted`` flag is checked between two chunks.
# noinspection SpellCheckingInspection
_UNINTERRUPTABLE_SLEEP_TIME = 1.0
# Registry of sub-process objects, shaped {pid: {class: instance}}; see register_old_process / get_old_process.
_old_process = {}

# Signal handlers (including SIGUSR1 / SIGPWR) are only installed on Linux.
ON_LINUX = sys.platform.startswith("linux")


def get_old_process(class_name):
    """Return the instance registered for *class_name* by the current process, or None.
    
    Lookups are keyed by the current pid, so entries stored under another pid are never returned.
    """
    registered_here = _old_process.get(os.getpid())
    if registered_here is None:
        return None
    return registered_here.get(class_name, None)


def register_old_process(class_name, instance):
    """Remember *instance* as the object registered for *class_name* in the current process.
    
    Entries stored under any other pid (for instance a copy inherited through a fork) are
    removed first, so the registry only ever describes the current process.
    """
    my_pid = os.getpid()
    # Iterate over a snapshot of the keys: deleting from a dict while iterating its live
    # key view raises RuntimeError on Python 3 (Python 2's keys() already returned a copy).
    for pid in list(_old_process):
        if pid != my_pid:
            del _old_process[pid]
    
    _old_process.setdefault(my_pid, {})[class_name] = instance


class LookAtMyFatherThread(Thread):
    """Watchdog thread that ends the current process as soon as the father pid cannot be signalled."""
    
    def __init__(self, father_pid, father_name, to_kill_name, loop_speed=1, logger=None):
        super(LookAtMyFatherThread, self).__init__(loop_speed=loop_speed, logger=logger)
        # pid/name of the process we depend on, and the display name of the process we would end
        self.father_pid, self.father_name, self.to_kill_name = father_pid, father_name, to_kill_name
    
    
    def loop_turn(self):
        """Probe the father with signal 0; on any failure, log it and exit the whole process."""
        try:
            # Signal 0 delivers nothing: it only checks that the pid can be signalled.
            os.kill(self.father_pid, 0)
        except:
            details = (self.father_pid, self.father_name, os.getpid(), self.to_kill_name)
            self.logger.error('The father process [%s]-[%s] is dead, I kill [%s]-[%s].' % details)
            # os._exit ends the process immediately (not only this thread) without running cleanup handlers.
            # noinspection PyUnresolvedReferences, PyProtectedMember
            os._exit(0)
    
    
    def get_thread_name(self):
        return 'look-at-father'


class EventHandler(Thread):
    """Thread that runs ``callback()`` each time ``send_event()`` is called.
    
    Subclasses implement ``callback``. Several events sent before the handler wakes up are
    merged into a single callback run (the event is a simple flag).
    """
    event = None  # type: Event
    event_name = u'event-handler'  # type:unicode
    
    
    def __init__(self, event_name=u'event-handler', error_handler=None, time_wait_for_started=0):
        # type: (unicode, ErrorHandler, float) -> None
        """
        :param event_name: name of the thread
        :param error_handler: forwarded to the base Thread
        :param time_wait_for_started: how many seconds send_event() may wait for the thread to be
                                      running before raising (0 = raise immediately)
        """
        # The running flag lives in a multiprocessing Value so it is shared across processes. It is
        # created before the base __init__, presumably because that __init__ may already set
        # ``_running`` (routed to this Value by the property below) — TODO confirm.
        self._private_running = Value(c_bool, False)
        super(EventHandler, self).__init__(error_handler=error_handler)
        self.event = Event()
        self.event_name = event_name
        self.time_wait_for_started = time_wait_for_started
    
    
    @property
    def _running(self):
        return self._private_running.value
    
    
    @_running.setter
    def _running(self, _running_value):
        self._private_running.value = _running_value
    
    
    def get_thread_name(self):
        return self.event_name
    
    
    def loop_turn(self):
        """Block until an event is sent, then run the callback once."""
        self.event.wait()
        # Clear BEFORE running the callback. An event sent while the callback runs must trigger
        # another run rather than be wiped by a clear() done afterwards. This also means a callback
        # that raises does not leave the event set.
        self.event.clear()
        self.callback()
    
    
    def send_event(self):
        """Wake the handler up.
        
        :raises Exception: if the handler is not running and does not become running within
                           ``time_wait_for_started`` seconds (polled every 0.1s).
        """
        time_wait = 0
        while not self._running:
            if not self.time_wait_for_started or time_wait > self.time_wait_for_started:
                raise Exception('You can\'t send an event before starting the event handler')
            time.sleep(0.1)
            time_wait += 0.1
        self.event.set()
    
    
    def callback(self):
        """Work to do on each event; must be implemented by subclasses."""
        raise NotImplementedError()


class QueueHandler(Thread):
    """Thread that moves items from a multiprocessing Queue into a bounded in-memory history.
    
    ``put()`` pushes an item into the queue. ``loop_turn()`` drains the queue, appends each item
    to ``all_items`` as a ``(reception_time, item)`` tuple (oldest first) and calls
    ``on_add_item(item)``. The history keeps at most ``max_size`` items. When
    ``item_keep_duration`` is non-zero, items older than that many seconds are also dropped.
    """
    
    def __init__(self, max_size, queue_name=u'queue-handler', error_handler=None, item_keep_duration=0):
        super(QueueHandler, self).__init__(error_handler=error_handler)
        self.owner_pid = os.getpid()
        self._queue = Queue()
        self._queue_name = queue_name
        self._max_size = max_size
        self.all_items = []  # (reception_time, item) tuples, oldest first
        self.item_keep_duration = item_keep_duration
    
    
    def init(self):
        # Replace the queue with a fresh one (the previous queue is not closed here).
        self._queue = Queue()
    
    
    def stop(self):
        # Stop the thread, then close the queue and wait for its background feeder thread to flush.
        Thread.stop(self)
        if self._queue:
            self._queue.close()
            self._queue.join_thread()
    
    
    def reset(self):
        # Recreate the queue and the history for the current process, then restart the thread.
        self.stop()
        self.owner_pid = os.getpid()
        self._queue = Queue()
        self.all_items = []
        if self.is_running() and self._thread:
            self.ask_stop()
            self._thread.join()
        self.start_thread()
        
        self.logger.info(u'Reset %s for owner_pid:%s' % (self._queue_name, self.owner_pid))
    
    
    def get_thread_name(self):
        return self._queue_name
    
    
    def loop_turn(self):
        """Drain the queue into ``all_items``, trimming the history after each item."""
        try:
            while True:
                item = self._queue.get_nowait()
                now = time.time()
                self.all_items.append((now, item))
                self.on_add_item(item)
                self._trim_history(now)
                # NOTE(review): sleeping after every item throttles draining to one item per
                # loop_speed seconds — confirm this is intended.
                self.interruptable_sleep(self.loop_speed)
        except EmptyQueue:
            # we will continue to loop until new item is available
            pass
        self.interruptable_sleep(self.loop_speed)
    
    
    def _trim_history(self, now):
        # Drop expired items (by age) then surplus items (by count), always from the oldest end.
        # Items are appended in time order, so expired ones are always at the head of the list;
        # popping the head in a while loop (instead of while iterating the same list) drops them all.
        if self.item_keep_duration:
            while self.all_items and (now - self.all_items[0][0]) > self.item_keep_duration:
                self.all_items.pop(0)
        while self.all_items and len(self.all_items) > self._max_size:
            self.all_items.pop(0)
    
    
    def on_add_item(self, item):
        # Hook called for every received item; does nothing by default.
        pass
    
    
    def put(self, item):
        self._queue.put(item)
    
    
    def after_fork_cleanup(self):
        # Forget the queue and the history; as the name says, meant to run right after a fork.
        super(QueueHandler, self).after_fork_cleanup()
        self._queue = None
        self.all_items = []


class UnknownCommandException(Exception):
    """Error type named for an unknown command; it is not raised anywhere in the code shown here."""


# Object sent by the daemon to a module to request a specific call (with optional args).
# Its uuid is used to check that an answer carries the same uuid as the request it responds to.
class ToModuleCommandRequest(object):
    """A command for a module, tagged with a random hex uuid so its answer can be matched.
    
    The command name and its arguments are packed into a single string joined with ARG_SEPARATOR.
    """
    
    def __init__(self, command_name, args):
        # type: (Str, Optional[List[Str]]) -> None
        if args:
            packed_command = ARG_SEPARATOR.join([command_name] + list(args))
        else:
            packed_command = command_name
        self._command = packed_command  # type: Str
        self._uuid = uuid.uuid4().hex  # type: Str
    
    
    def get_command(self):
        # type: () -> Str
        return self._command
    
    
    def get_uuid(self):
        # type: () -> Str
        return self._uuid
    
    
    def create_respond(self, result_payload):
        # type: (Any) -> FromModuleCommandRespond
        """Wrap *result_payload* in a response carrying this request's uuid."""
        return FromModuleCommandRespond(self._uuid, result_payload)
    
    
    def __str__(self):
        return u'ToModuleCommandRequest[%s-%s]' % (self._uuid, self._command)


# Created by a request (carrying that request's uuid) to give a result back from the module.
# The original request uses it to check that the response uuid matches its own (stale responses can linger in the queues).
class FromModuleCommandRespond(object):
    """A module's answer to a ToModuleCommandRequest: the request uuid plus the result payload."""
    
    def __init__(self, request_uuid, result_payload):
        # type: (Str, Any) -> None
        self._uuid, self._payload = request_uuid, result_payload
    
    
    def do_match_request(self, request):
        # type: (ToModuleCommandRequest) -> bool
        """Tell whether this response answers *request* (same uuid)."""
        return self._uuid == request.get_uuid()
    
    
    def get_payload(self):
        # type: () -> Any
        return self._payload


class CommandQueueHandler(Thread):
    """Command channel between a father process and a module, built on two queues.
    
    Executing side: ``loop_turn`` reads ToModuleCommandRequest objects from ``commands_to_q``,
    calls the same-named attribute of ``main_process`` and puts the response in
    ``commands_from_q``. Calling side: ``send_command`` pushes a request and waits for the
    response whose uuid matches.
    """
    
    def __init__(self, name, parent_logger, commands_to_q, commands_from_q, main_process):
        # type: (Str, PartLogger, ShinkenQueue, ShinkenQueue, object) -> None
        super(CommandQueueHandler, self).__init__(loop_speed=-1, logger=parent_logger)
        self.daemon = True
        self.logger = parent_logger
        self.handler_name = name
        self.commands_to_q = commands_to_q
        self.commands_from_q = commands_from_q
        self.main_process = main_process
        self.name = name
        
        # Lock for commands from different threads racing for the result
        self.send_command_lock = Lock()  # type: Lock
    
    
    def get_thread_name(self):
        return self.name
    
    
    def loop_turn(self):
        self.get_and_execute_command_from_master()
    
    
    @staticmethod
    def _get_expected_arg_count(f):
        # type: (Callable) -> int
        """Return how many positional parameters a caller must pass to *f* (an already bound ``self`` excluded)."""
        # inspect.getargspec is the Python 2 API (removed in Python 3.11); prefer getfullargspec when available.
        get_arg_spec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
        nb_args = len(get_arg_spec(f).args)
        # The arg spec of a bound method still lists ``self``, which the caller must not pass again.
        if inspect.ismethod(f) and getattr(f, '__self__', None) is not None:
            nb_args -= 1
        return nb_args
    
    
    def get_and_execute_command_from_master(self):
        """Execute at most one pending command and send its result back. Never raises.
        
        Waits up to 1s for a request. The arguments are passed only when their count equals the
        callable's positional parameter count; otherwise the callable is called without arguments.
        """
        logger_command_call = self.logger.get_sub_part(u'COMMAND CALL', len(u'COMMAND CALL'))
        cmd_and_arg = ''
        try:  # NOTE: this thread is not allowed to die
            
            # Will block so we don't hammer cpu
            try:
                request = self.commands_to_q.get(block=True, timeout=1)  # type: Optional[ToModuleCommandRequest]
            except:
                request = None
            # Nothing in the queue, just loop
            if request is None:
                return
            cmd_and_arg = request.get_command()
            _split = cmd_and_arg.split(ARG_SEPARATOR)
            cmd = _split[0]
            arg = _split[1:]
            f = getattr(self.main_process, cmd, None)
            if callable(f):
                logger_command_call.debug('[PID:%s] Executing command [%s] with param %s' % (os.getpid(), cmd, arg))
                if arg and self._get_expected_arg_count(f) == len(arg):
                    result = f(*arg)
                else:
                    result = f()
                respond = request.create_respond(result)
                self.commands_from_q.put(respond)
            else:
                logger_command_call.warning('[PID:%s] Received unknown command [%s] from father process !' % (os.getpid(), cmd))
        except:
            logger_command_call.error('Our father process did send us the command [%s] that did fail because:' % cmd_and_arg)
            logger_command_call.print_stack()
            time.sleep(0.01)  # if we crash in loop, do not hammer CPU
    
    
    def send_command(self, command_name, args=None, timeout=10):
        # type: (str, Optional[List[str]], Number) -> Any
        """Send a command to the other side and return its result payload.
        
        Calls are serialised by ``send_command_lock``. Each wait on ``commands_from_q`` lasts up
        to ``timeout`` seconds and one timeout is tolerated, so a call gives up after roughly
        ``2 * timeout`` seconds (plus the time spent skipping stale responses).
        
        :raises Exception: on timeout, or on any other error unless ``main_process.interrupted``
                           is set (then the error is swallowed and None is returned).
        """
        logger_command_call = self.logger.get_sub_part(u'COMMAND CALL', len(u'COMMAND CALL'))
        with self.send_command_lock:
            try:
                request = ToModuleCommandRequest(command_name, args)
                retry_count = 0
                before = time.time()
                self.commands_to_q.put(request)
                # A first timeout only logs a warning and waits once more for a late answer (the request
                # is NOT re-sent). Responses with another uuid (leftovers of older, timed-out calls) are skipped.
                while True:
                    try:
                        response = self.commands_from_q.get(block=True, timeout=timeout)  # type: FromModuleCommandRespond
                    except EmptyQueue:
                        # Empty? Means a timeout
                        retry_count += 1
                        if retry_count == 1:
                            logger_command_call.warning('The command call [%s] for module %s was sent but the call did timeout (%ss). We will retry one time.' % (command_name, self.handler_name, timeout))
                            continue
                        # Ok still an error? need to give a real error about the timeout
                        raise
                    if not response.do_match_request(request):
                        logger_command_call.warning('The command call [%s] was called but another respond was received. Retrying.' % command_name)
                        continue
                    logger_command_call.debug('The command call [%s] was executed by the module %s in %.3fs' % (command_name, self.handler_name, time.time() - before))
                    payload = response.get_payload()
                    return payload
            
            # Empty queue means a timeout
            except EmptyQueue:
                _message = 'Fail to send command call [%s] for module %s because the module did timeout (%ss)' % (command_name, self.handler_name, timeout)
                logger_command_call.error(_message)
                raise Exception(_message)
            
            # Another exception? It's not good and must be shown as a stack
            except Exception as e:
                if getattr(self.main_process, 'interrupted', False) is False:
                    logger_command_call.print_stack()
                    # ``e.message`` only exists on Python 2 and is empty for multi-argument exceptions:
                    # fall back to the exception itself.
                    error_text = getattr(e, 'message', None) or e
                    _message = 'Fail to send command call [%s] for module %s because of an unknown error %s' % (command_name, self.handler_name, error_text)
                    logger_command_call.error(_message)
                    raise Exception(_message)


class BaseSubProcess(object):
    """Base class for a worker whose real work runs in a child process.
    
    The object is created in the father process. ``start()`` launches a multiprocessing.Process
    whose target is ``main()``. ``main()`` runs the common warm-up, then ``on_init()``, then calls
    ``loop_turn()`` followed by ``interruptable_sleep(loop_speed)`` until ``interrupted`` is set,
    and finally ``on_close()``. Subclasses must implement ``loop_turn``.
    """
    
    def __init__(self, name, father_name=u'UNSET', loop_speed=1, stop_process_on_error=True, only_one_process_by_class=False, error_handler=None):
        # type: (unicode, unicode, int, bool, bool, Optional[ErrorHandler]) -> None
        """
        :param name: name of this sub-process (logger name and process name)
        :param father_name: name of the father process, shown in the process name
        :param loop_speed: seconds slept between two loop_turn() calls
        :param stop_process_on_error: when True, an exception raised by loop_turn() ends the main loop
        :param only_one_process_by_class: when True, stop the instance of the same class previously
                                          registered by this process before creating this one
        :param error_handler: optional ErrorHandler receiving fatal errors (otherwise they are logged)
        """
        my_class = type(self)
        old_process = get_old_process(my_class)
        if only_one_process_by_class and old_process:
            old_process.stop()
        
        # this part is launch in the father process
        self.name = name
        self.father_name = father_name
        self.father_pid = os.getpid()
        # set initial logger with name. we call _set_logger_name in _sub_process_common_warmup cause _set_logger_name must be call after all init.
        self.logger = LoggerFactory.get_logger(self.name)
        
        self._process = None
        self.interrupted = False
        self.loop_speed = loop_speed
        self._stop_process_on_error = stop_process_on_error
        self._error_handler = error_handler
        
        self.look_at_father_thread = LookAtMyFatherThread(self.father_pid, self.father_name, self.name, logger=self.logger)
        
        register_old_process(my_class, self)
        
        if TYPE_CHECKING:
            self.loaded_into = ''
    
    
    def update_father(self, father_name):
        # type: (unicode) -> None
        """Make the current process the father watched by the sub-process, under *father_name*."""
        self.father_name = father_name
        self.father_pid = os.getpid()
        self.look_at_father_thread.father_name = self.father_name
        self.look_at_father_thread.father_pid = self.father_pid
    
    
    def manage_signal(self, sig, frame):
        """Signal handler (Linux only).
        
        SIGUSR1 triggers a full memory dump and SIGPWR dumps all thread stacks. Any other signal
        routed here (SIGINT / SIGTERM by default) sets ``interrupted``, which ends the main loop.
        """
        if not ON_LINUX:
            return
        if sig == signal.SIGUSR1:  # if USR1, ask a memory dump
            MemoryStats.dump_memory_full_memory_dump(self.name)
        elif sig == signal.SIGPWR:  # SIGPWR (old signal not used) dump all threads stacks
            thread_dumper.dump_all_threads()
        else:
            self.interrupted = True
    
    
    def set_signal_handler(self, sigs=None):
        """Route *sigs* (default: SIGINT, SIGTERM, SIGUSR1, SIGPWR) to manage_signal; no-op outside Linux."""
        if not ON_LINUX:
            return
        
        if sigs is None:
            sigs = (signal.SIGINT, signal.SIGTERM, signal.SIGUSR1, signal.SIGPWR)
        
        for sig in sigs:
            signal.signal(sig, self.manage_signal)
    
    
    set_exit_handler = set_signal_handler
    
    
    # In the sub process, we should start some threads, like the one that look at parent
    # etc. etc. Sub Class can have other threads for special purpose
    def _start_sub_process_threads(self):
        start_malloc_trim_thread()
        self.look_at_father_thread.start_thread()
    
    
    def _set_logger_name(self):
        # get_logger_name() may return a list/tuple. Its first element becomes the global logger name
        # and the following elements (square brackets removed) become nested logger parts.
        logger_name = self.get_logger_name()
        if isinstance(logger_name, (list, tuple)):
            logger_names = [i.replace('[', '').replace(']', '') for i in logger_name]
            if len(logger_names) > 1:
                _logger_part_names = logger_names[1:]
                my_logger = self.logger
                my_logger.set_default_part(_logger_part_names[0])
                for part_name in _logger_part_names[1:]:
                    my_logger = my_logger.get_sub_part(part_name)
                self.logger = my_logger
                self.look_at_father_thread.logger = my_logger
            logger_name = logger_name[0]
        else:
            self.logger = LoggerFactory.get_logger()
        
        logger_name = logger_name.replace('[', '(').replace(']', ')')
        _logger.set_name(logger_name)
    
    
    def _sub_process_common_warmup(self):
        # Runs first in the child: logger, process name, after-fork cleanup, http shutdown, signals, utility threads.
        self._set_logger_name()
        logger_init = self.logger.get_sub_part(PART_INITIALISATION)
        logger_init.debug('_sub_process_common_warmup')
        set_process_name(self.get_process_name())
        after_fork_cleanup.do_after_fork_cleanup(self)
        
        logger_init.debug('stop http')
        from http_daemon import daemon_inst
        if daemon_inst:
            # We are shutdown the http daemon because we should not answer to the http daemon calls.
            # NOTE: this is done in a thread
            daemon_inst.shutdown(quiet=True)
        
        logger_init.debug('set signal handler')
        self.set_signal_handler()
        
        logger_init.debug('start sub process')
        # Our sub process need some utility threads, like:
        # * malloc trimming
        # * looking at father process (if dead, we die)
        # * others set by the subclass
        self._start_sub_process_threads()
        logger_init.info('Sub process ready to work.')
    
    
    def get_process_name(self):
        # type: () -> unicode
        return u'%s [ %s ]' % (self.father_name, self.name)
    
    
    def get_logger_name(self):
        # type: () -> unicode
        return self.name
    
    
    def main(self):
        """Body of the child process: warm-up and on_init, then the main loop, then on_close."""
        try:
            self._sub_process_common_warmup()
            self.on_init()
        except Exception as e:
            self.interrupted = True
            self._report_fatal_error(e)
        
        while not self.interrupted:
            try:
                self.loop_turn()
                self.interruptable_sleep(self.loop_speed)
            except Exception as e:
                # Uses the same guarded reporting as the init phase: a failing error handler must
                # not escape main() and skip on_close().
                self._report_fatal_error(e)
                if self._stop_process_on_error:
                    self.interrupted = True
        self.on_close()
    
    
    def _report_fatal_error(self, e):
        # Send *e* to the error handler; log it with a stack if there is no handler or the handler fails.
        if not self._error_handler:
            self.logger.error('Sub-process %s have a fatal error : %s' % (self.name, e))
            self.logger.print_stack()
            return
        try:
            self._error_handler.handle_exception('Fatal error caused by : %s' % e, e, self.logger, level=ERROR_LEVEL.FATAL)
        except Exception:
            self.logger.error('Sub-process %s have a fatal error : %s' % (self.name, e))
            self.logger.print_stack()
    
    
    def interruptable_sleep(self, raw_duration):
        # type: (Number) -> None
        """Sleep *raw_duration* seconds, checking ``interrupted`` every _UNINTERRUPTABLE_SLEEP_TIME seconds.
        
        Durations up to one chunk are slept in one go. Longer ones are slept as whole chunks plus a
        remainder; once ``interrupted`` is set, the rest of the sleep is skipped.
        """
        if raw_duration <= _UNINTERRUPTABLE_SLEEP_TIME:
            time.sleep(raw_duration)
            return
        
        nb_chunks = int(math.floor(raw_duration / _UNINTERRUPTABLE_SLEEP_TIME))
        # The remainder subtracts the time actually slept (chunk count * chunk size), not the chunk count.
        remainder = raw_duration - nb_chunks * _UNINTERRUPTABLE_SLEEP_TIME
        # range (not xrange) so this also runs on Python 3; the chunk count is small.
        for _ in range(nb_chunks):
            time.sleep(_UNINTERRUPTABLE_SLEEP_TIME)
            if self.interrupted:
                return
        
        # Guard against a tiny negative float remainder (time.sleep rejects negative values).
        if remainder > 0:
            time.sleep(remainder)
    
    
    def loop_turn(self):
        # type: () -> None
        """One iteration of the sub-process work; must be implemented by subclasses."""
        raise NotImplementedError()
    
    
    def on_init(self):
        # type: () -> None
        """Hook run in the child after the common warm-up, before the main loop."""
        pass
    
    
    def on_close(self):
        # type: () -> None
        """Hook run in the child once the main loop has ended."""
        pass
    
    
    def join(self):
        # type: () -> None
        if self._process:
            self._process.join()
    
    
    def start(self, father_name=''):
        # type: (unicode) -> None
        """Start the child process unless it is already alive (optionally renaming the father first)."""
        if self._process and self._process.is_alive():
            return
        self._process = Process(target=self.main)
        self.interrupted = False
        if father_name:
            self.update_father(father_name)
        self._process.start()
    
    
    def stop(self):
        # type: () -> None
        """Terminate the child process, escalating to _kill() if it is still alive after 1s."""
        self.logger.info('Stopping process %s' % self.get_process_name())
        if self._process and self._process.is_alive():
            self._process.terminate()
            self._process.join(timeout=1)
            if self._process.is_alive():
                self.logger.info('The [%s] sub-process is still alive, I help it to die' % self.get_process_name())
                self._kill()
                self._process.join(timeout=1)
        self._process = None
    
    
    def is_alive(self):
        # type: () -> bool
        """Return True when a child process was started and is still running."""
        return bool(self._process and self._process.is_alive())
    
    
    def _kill(self):
        # Force the child to stop: terminate() on Windows; elsewhere SIGTERM, then SIGKILL 1s later if needed.
        if self._process is None:
            return
        if os.name == 'nt':
            self._process.terminate()
            return
        pid = self._process.pid
        # Ok, let him 1 second before really KILL IT
        try:
            os.kill(pid, signal.SIGTERM)
        except OSError:
            # The process is already gone (it may have exited since the caller's is_alive() check).
            return
        time.sleep(1)
        # You do not let me another choice guy...
        try:
            if self._process.is_alive():
                self.logger.info('The [%s] pid:[%s] sub-process is still alive, I kill it (kill -9)' % (self.get_process_name(), self._process.pid))
                os.kill(pid, signal.SIGKILL)
        except AssertionError:
            # NOTE(review): multiprocessing's is_alive() raises AssertionError when it is not called from the
            # parent process — presumably the case meant here (previously labelled "zombie process"). Kill blindly.
            try:
                os.kill(pid, signal.SIGKILL)
            except OSError:
                pass
        except OSError:
            # Exited between is_alive() and the SIGKILL: nothing left to kill.
            pass
