import os
import threading
import traceback

import sys
import time

from shinken.log import LoggerFactory
from shinken.misc.type_hint import TYPE_CHECKING
from shinken.thread_helper import get_thread_id

if TYPE_CHECKING:
    from shinken.misc.type_hint import Optional

# Loggers of this module: 'THREADS STACK' is used by ThreadsDumper for the stack dumps,
# 'WATCH DOG' is the parent of the per-caller watchdog loggers and of fatal-deadlock reports
raw_logger = LoggerFactory.get_logger()
logger_stack = raw_logger.get_sub_part('THREADS STACK')
logger_watch_dog = raw_logger.get_sub_part('WATCH DOG')
# NOTE(review): the return value is discarded, so this call is presumably made only for a
#               side effect inside shinken.thread_helper at import time — confirm before removing
get_thread_id()


# Codes given back by the watchdog status checks to tell if the status is good or not.
# There is deliberately no WARNING level: a call is either dead-locked or it is not.
class WATCH_DOG_STATUS_CODE:  # noqa: yes, no __init__
    """First element of the tuples returned by the WatchDogFatalStatus checkers."""
    OK, CRITICAL = 0, 2


class WATCH_DOG_STATUS_KEYS:  # noqa: yes, no __init__
    """Names of the entries of the watchdog status dict (see WatchDogFatalStatus.get_status)."""
    HAVE_FATAL_DEAD_LOCK, DETECTION_EPOCH = 'have_fatal_dead_lock', 'detection_epoch'
    NAME, WAITED_TIME = 'name', 'waited_time'


# If a deadlock is here since TOO long, then we will be able to warn through
# monitoring checks that the process is locked and MUST be restarted manually
class WatchDogFatalStatus:
    """Remember the first fatal deadlock reported by a watchdog, for the monitoring checks.
    
    Watchdog threads report through ``set_fatal_deadlock``; ``get_status`` exposes the state
    as a dict keyed by the WATCH_DOG_STATUS_KEYS values, and the ``check_status_and_output_*``
    methods turn such a dict into a ``(WATCH_DOG_STATUS_CODE, output)`` tuple.
    Once set, the deadlock flag is never cleared (except by ``_reset_for_test``).
    """
    
    def __init__(self):
        # set_fatal_deadlock() may be called by several watchdog threads at the same time:
        # the lock keeps the "first detection wins" rule and keeps the four fields coherent
        self._lock = threading.Lock()
        self._did_watchdog_detect_fatal_deadlock = False
        self._locked_name = ''  # name of the locked call
        self._lock_detection_date = 0  # epoch
        self._waited_time = 0  # how much time did we wait before
    
    
    def _reset_for_test(self):  # for testing
        self.__init__()
    
    
    def set_fatal_deadlock(self, name, waited_time):
        """Record that the call ``name`` has been blocked for ``waited_time`` seconds.
        
        Only the first report is kept: a later one may be a SECOND lock stacked over the one
        already detected, and overwriting it would hide the root problem. Later reports are
        only logged.
        """
        with self._lock:
            already_detected = self._did_watchdog_detect_fatal_deadlock
            if not already_detected:
                self._did_watchdog_detect_fatal_deadlock = True
                self._locked_name = name
                self._lock_detection_date = int(time.time())
                self._waited_time = waited_time
            first_name = self._locked_name
            first_date = self._lock_detection_date
        # Log outside the lock: other reporters do not have to wait for the logger
        if already_detected:
            logger_watch_dog.error('Another dead lock was existing (%s at %s) before the new lock %s' % (first_name, first_date, name))
    
    
    def is_daemon_in_fatal_deadlock(self):
        """Return True once a fatal deadlock has been recorded (it is never cleared)."""
        return self._did_watchdog_detect_fatal_deadlock
    
    
    # IMPORTANT: this is directly send to the get_raw_stats so do NOT edit without
    #            checking tests are ok with the change
    def get_status(self):
        """Return a coherent snapshot of the status, keyed by the WATCH_DOG_STATUS_KEYS values."""
        with self._lock:
            return {
                WATCH_DOG_STATUS_KEYS.HAVE_FATAL_DEAD_LOCK: self._did_watchdog_detect_fatal_deadlock,
                WATCH_DOG_STATUS_KEYS.DETECTION_EPOCH     : self._lock_detection_date,
                WATCH_DOG_STATUS_KEYS.NAME                : self._locked_name,
                WATCH_DOG_STATUS_KEYS.WAITED_TIME         : self._waited_time,
            }
    
    
    @staticmethod
    def __assert_status_coherency(status):
        """Return a CRITICAL ``(code, text)`` tuple if ``status`` is malformed, else None."""
        # A non-dict status (e.g. None) would make the `in` tests below raise instead of
        # being reported as a broken status
        if not isinstance(status, dict):
            return (WATCH_DOG_STATUS_CODE.CRITICAL, 'The watch dog status is broken, it is not a dictionary.')  # noqa: pycharm, I want a tuple, I want to see it with ()
        must_have_properties = (WATCH_DOG_STATUS_KEYS.HAVE_FATAL_DEAD_LOCK, WATCH_DOG_STATUS_KEYS.DETECTION_EPOCH, WATCH_DOG_STATUS_KEYS.NAME, WATCH_DOG_STATUS_KEYS.WAITED_TIME)
        # First a sanity check
        for key in must_have_properties:
            if key not in status:
                return (WATCH_DOG_STATUS_CODE.CRITICAL, 'The watch dog status is broken, missing key %s.' % key)  # noqa: pycharm, I want a tuple, I want to see it with ()
        # All is well
        return None
    
    
    def _get_early_result(self, status):
        """Return the final ``(code, output)`` tuple when no deadlock has to be described.
        
        That is a CRITICAL tuple for a malformed status, ``(OK, '')`` when no fatal deadlock
        was detected, and None when the caller must describe the deadlock.
        """
        _coherency = self.__assert_status_coherency(status)
        if _coherency:  # maybe there is a problem in the status structure, give it back (it's flat text)
            return _coherency
        # Maybe there is no problem, if so, we are finish as we are only raising if problem ^^
        if not status[WATCH_DOG_STATUS_KEYS.HAVE_FATAL_DEAD_LOCK]:
            return (WATCH_DOG_STATUS_CODE.OK, '')  # noqa: pycharm, I want a tuple, I want to see it with ()
        return None
    
    
    @staticmethod
    def _format_detection_date(status):
        """Format the detection epoch of ``status`` as a local 'YYYY-mm-dd HH:MM:SS' string."""
        return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(status[WATCH_DOG_STATUS_KEYS.DETECTION_EPOCH]))
    
    
    # Called by the checks to look at the get_status() return
    # NOTE: we give a tuple and not a dict because so the dev MUST check that the caller is
    #       updated too
    def check_status_and_output_text(self, status, optional_text_output_prefix=''):
        """Return ``(code, output)`` describing ``status`` as one line of flat text.
        
        :param status: a dict as returned by get_status()
        :param optional_text_output_prefix: replaces the default first sentence when not empty/None
        """
        early_result = self._get_early_result(status)
        if early_result is not None:
            return early_result
        
        output_prefix = optional_text_output_prefix or 'The daemon have a lock, it\'s not working and MUST be restarted. '
        
        # Arg, we have a real problem here...
        outputs = [
            output_prefix,
            'Please contact your support to analyse the daemon logs: ',
            '"%s" was locked more than %ds, ' % (status[WATCH_DOG_STATUS_KEYS.NAME], status[WATCH_DOG_STATUS_KEYS.WAITED_TIME]),
            'and detected at %s [ WATCH DOG ]' % self._format_detection_date(status),
        ]
        output = ''.join(outputs)  # we do not want any \n here
        return (WATCH_DOG_STATUS_CODE.CRITICAL, output)  # noqa: pycharm, I want a tuple, I want to see it with ()
    
    
    # Called by the checks to look at the get_status() return
    # NOTE: we give a tuple and not a dict because so the dev MUST check that the caller is
    #       updated too
    def check_status_and_output_html(self, status, optional_html_output_prefix=''):
        """Return ``(code, output)`` describing ``status`` as an HTML fragment.
        
        Structural errors are still given back as flat text.
        
        :param status: a dict as returned by get_status()
        :param optional_html_output_prefix: replaces the default first sentence when not empty/None
        """
        early_result = self._get_early_result(status)
        if early_result is not None:
            return early_result
        
        output_prefix = optional_html_output_prefix or 'The daemon have a <b>lock</b>, it\'s <b>not working</b> and <b>MUST</b> be restarted.<br/>'
        
        # Arg, we have a real problem here...
        outputs = [
            output_prefix,
            'Please contact your support to analyse the daemon logs:<br/>',
            '<ul>',
            '<li>"%s" was locked more than %ds</li>' % (status[WATCH_DOG_STATUS_KEYS.NAME], status[WATCH_DOG_STATUS_KEYS.WAITED_TIME]),
            '<li>Detected at %s [ WATCH DOG ]</li>' % self._format_detection_date(status),
            '</ul>',
        ]
        output = ''.join(outputs)  # we do not want any \n here
        return (WATCH_DOG_STATUS_CODE.CRITICAL, output)  # noqa: pycharm, I want a tuple, I want to see it with ()


# Shared state between the watched caller and its watchdog thread
class WatchdogMessage:
    """Tell the watchdog thread whether the watched call is running and since when.
    
    ``start_time`` is the epoch given by ``set_starting`` and is None whenever the call is
    not in progress (new object, or after ``set_finished``).
    """
    
    def __init__(self):
        self._store(False, None)
    
    
    def _store(self, finished, started_at):
        # Single place where both fields change, so they always move together
        self._is_call_finished = finished
        self.start_time = started_at
    
    
    def set_finished(self):
        self._store(True, None)
    
    
    def set_starting(self):
        self._store(False, time.time())
    
    
    def has_finished(self):
        return self._is_call_finished


# This class is a static thread dumper: if asked, it will log stacks from all
# the threads
class ThreadsDumper:
    """Log the current Python stack of every live thread (used by the watchdog)."""
    
    @staticmethod
    def dump_all_threads():
        """Log, on the THREADS STACK logger, the stack of every currently running thread.
        
        A failure while dumping one thread is logged and does not stop the dump of the others.
        """
        all_threads = threading.enumerate()
        logger_stack.info('Starting to dump all the %d threads currently running' % (len(all_threads)))
        # Take ONE snapshot of every frame: sys._current_frames() builds a dict covering all
        # threads on each call, so calling it per thread was quadratic in the thread count
        all_frames = sys._current_frames()  # noqa: yes we go in _current_frames
        for th in all_threads:
            try:
                thread_frame = all_frames.get(th.ident)
                if thread_frame is None:
                    # The thread can be killed at this point so sys._current_frames() don't have frames for this thread
                    logger_stack.debug('Thread : %s is already dead. We cannot print this stack.' % str(th))
                    continue
                formatted = traceback.format_stack(thread_frame)
                logger_stack.info('* Thread: %s' % str(th))
                for entry in formatted:
                    # The line is in fact 2 lines
                    for line in entry.split('\n'):
                        if line:
                            logger_stack.info('  %s' % line.strip())
            except Exception:
                # Best effort: keep dumping the other threads, but do not lose the cause
                logger_stack.error('fail to dump stack of thread : %s' % str(th))
                logger_stack.debug(traceback.format_exc())


class WatchDogThreadDumper:
    """Context manager that watches the duration of the ``with`` body from a helper thread.
    
    While the body runs, a daemon thread named ``WD-<caller_name>`` waits ``wait_time``
    seconds. After that it logs a warning and dumps all the thread stacks, at most once
    every ``dump_interval`` seconds. If ``fatal_dead_lock_delay`` is given and exceeded, the
    call is reported to ``watchdog_fatal_status`` so monitoring checks go CRITICAL.
    
    With ``multi_usage=True`` the thread is kept in standby between ``with`` blocks and must
    be stopped with ``quit()``. Otherwise it is joined at the end of each ``with`` block.
    """
    
    def __init__(self, caller_name, wait_time, dump_interval, fatal_dead_lock_delay=None, multi_usage=False):
        # type: (str, int, int, Optional[int], Optional[bool]) -> None
        self._caller_name = caller_name
        self._wait_time = wait_time
        self._dump_interval = dump_interval
        self._fatal_dead_lock_delay = fatal_dead_lock_delay
        self._keep_thread = multi_usage
        self._is_running = True
        # Protects _exchange_message / _is_running and is used to wake the watchdog thread
        self._cond = threading.Condition(threading.RLock())
        # Fatal must be at least wait time, or it's a nonsense
        if self._fatal_dead_lock_delay is not None and self._fatal_dead_lock_delay < self._wait_time:
            raise Exception('The WatchDogThreadDumper %s definition is wrong: fatal time:%s < wait time:%s' % (caller_name, self._fatal_dead_lock_delay, self._wait_time))
        
        self._logger = logger_watch_dog.get_sub_part(caller_name, register=False)
        
        self._exchange_message = WatchdogMessage()
        
        self._thread = None  # type: Optional[threading.Thread]
    
    
    def _thread_launch_watchdog(self):
        """Watch one call until the caller marks it finished.
        
        Runs with self._cond held. self._cond.wait() releases it while sleeping.
        """
        last_dump_time = 0.0  # do not dump too much: at most once every dump_interval seconds
        has_fatal_delay = self._fatal_dead_lock_delay is not None
        started_at = self._exchange_message.start_time
        while not self._exchange_message.has_finished():
            # In multi_usage the caller may have ended and restarted a call while we slept:
            # always follow the latest start
            started_at = self._exchange_message.start_time
            now = time.time()
            elapsed_time = now - started_at
            
            # Time is over, but do not dump too much
            if elapsed_time >= self._wait_time and now - last_dump_time >= self._dump_interval:
                self._logger.warning('The call is too long (%ds > %ds) so we are dumping all threads to help find what is blocking' % (elapsed_time, self._wait_time))
                thread_dumper.dump_all_threads()
                last_dump_time = now
            
            # Maybe we are entering a fatal deadlock time, so checks will be warned
            # NOTE: this is a one way thing, you don't came back from it
            if has_fatal_delay and elapsed_time >= self._fatal_dead_lock_delay and not watchdog_fatal_status.is_daemon_in_fatal_deadlock():
                self._logger.error('The call did reach a fatal dead lock period (%ds > %ds). Now monitoring checks will be in ERRORS until the daemon is restarted.' % (elapsed_time, self._fatal_dead_lock_delay))
                watchdog_fatal_status.set_fatal_deadlock(self._caller_name, elapsed_time)
            
            # Sleep until the next event: end of the wait time, then the next allowed dump, or
            # the fatal delay if it comes first. This is computed AFTER a possible dump so the
            # fresh last_dump_time is used, and dumps keep happening between wait and fatal.
            if elapsed_time < self._wait_time:
                timeout = self._wait_time - elapsed_time
            else:
                timeout = self._dump_interval - (now - last_dump_time)
                if has_fatal_delay and elapsed_time < self._fatal_dead_lock_delay:
                    timeout = min(timeout, self._fatal_dead_lock_delay - elapsed_time)
            
            if self._logger.is_debug():
                self._logger.debug('Watchdog we will wait %.3fs before dumping threads' % timeout)
            
            self._cond.wait(timeout)
        if self._logger.is_debug():
            # start_time is reset by set_finished(), so use the start seen by the loop
            if started_at:
                self._logger.debug('[ %.3fs ] Watchdog completed' % (time.time() - started_at))
            else:
                self._logger.debug('Watchdog completed')
    
    
    def _thread_loop(self):
        """Run as the body of the watchdog thread: watch each call, then stand by for the next.
        
        Without multi_usage the thread leaves after its first watched call. Any unexpected error
        is logged and then the whole process is killed with os._exit(2).
        """
        try:
            with self._cond:
                while self._is_running:
                    if not self._exchange_message.has_finished():
                        self._thread_launch_watchdog()
                        if not self._keep_thread:
                            break
                        continue
                    
                    if self._logger.is_debug():
                        self._logger.debug('Watchdog thread is entering standby mode')
                    self._cond.wait()
            if self._logger.is_debug():
                self._logger.debug('Watchdog thread is exiting')
        except:  # noqa
            self._logger.error('The thread did crash with an unknown error, exiting: %s' % traceback.format_exc())
            os._exit(2)  # noqa
    
    
    def quit(self):
        """Stop the watchdog thread (needed with multi_usage) and wait for it to end."""
        if self._logger.is_debug():
            self._logger.debug('The caller requests watchdog quit')
        with self._cond:
            self._exchange_message.set_finished()
            self._is_running = False
            self._cond.notify_all()
        if self._thread:
            self._thread.join()
            self._thread = None
    
    
    def __enter__(self):
        """Start watching the call and launch the watchdog thread if none is alive."""
        if self._logger.is_debug():
            self._logger.debug('The caller requests a watchdog start')
        with self._cond:
            self._exchange_message.set_starting()
            if not self._thread:
                # A previous single-usage __exit__ or a quit() left _is_running at False: re-arm
                # it, else the new thread below would exit at once and watch nothing
                self._is_running = True
            self._cond.notify_all()
        
        if not self._thread:
            thread_name = 'WD-%s' % self._caller_name
            self._thread = threading.Thread(None, target=self._thread_loop, name=thread_name)
            self._thread.daemon = True
            self._thread.start()
        return self
    
    
    # We are called from the caller thread, so let the thread
    # know we have finished, and wait a bit it finishes
    def __exit__(self, exc_type, exc_value, traceback_):
        if self._logger.is_debug():
            self._logger.debug('The caller requests a watchdog end')
        with self._cond:
            self._exchange_message.set_finished()
            if not self._keep_thread:
                self._is_running = False
            self._cond.notify_all()
        # quit() may already have joined and cleared the thread inside the with block
        if not self._keep_thread and self._thread:
            self._thread.join()
            self._thread = None


# Process-wide singletons shared by every watchdog:
# - watchdog_fatal_status keeps the first fatal deadlock, read back by the monitoring checks
# - thread_dumper logs all thread stacks when a watched call is too long
watchdog_fatal_status = WatchDogFatalStatus()
thread_dumper = ThreadsDumper()
