#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2013-2018
# This file is part of Shinken Enterprise, all rights reserved.
#
import collections
import functools
import os
import os.path
import sys
import threading
import time
from abc import ABC
from threading import Event
from types import FrameType

from shinken.compat import unicode_to_bytes
from shinken.log import PartLogger, LoggerFactory
from shinken.subprocess_helper.error_handler import ERROR_LEVEL

try:
    import ctypes
    
    # Handle on the C library, used by get_thread_id() to issue a raw syscall.
    # Stays None when ctypes is missing or 'libc.so.6' cannot be loaded on this system.
    libc = ctypes.CDLL('libc.so.6')
except Exception:
    libc = None

from shinken.misc.type_hint import TYPE_CHECKING

if TYPE_CHECKING:
    from shinken.misc.type_hint import WrappedFunction
    from typing import Callable, ParamSpec, Optional
    
    Parameters = ParamSpec('Parameters')


# Hook threading to allow thread renaming
def patch_thread_name():
    """Make Python thread names visible to native tools (debuggers, ps/top...).
    
    * Windows: start a daemon thread that, once per second and only while a debugger is attached, publishes the name of
      every started non-main thread with the MSVC "set thread name" exception (MS_VC_EXCEPTION).
    * Linux: hook threading.Thread so the name is pushed to pthread_setname_np() when a thread starts and each time its
      ``name`` attribute is assigned.
    
    Does nothing when ctypes (or pthread_setname_np) is unavailable, or on other platforms.
    """
    if sys.platform.startswith('win'):
        
        # If no ctypes, like in a static python build: exit
        try:
            import ctypes
        except ImportError:
            return
        import threading
        import time
        from ctypes import wintypes
        
        class THREADNAME_INFO(ctypes.Structure):
            # Payload of the MS_VC_EXCEPTION raised below
            _pack_ = 8
            _fields_ = [
                ('dwType', wintypes.DWORD),
                ('szName', wintypes.LPCSTR),
                ('dwThreadID', wintypes.DWORD),
                ('dwFlags', wintypes.DWORD),
            ]
            
            
            def __init__(self):
                self.dwType = 0x1000
                self.dwFlags = 0
        
        def debugChecker():
            kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)
            RaiseException = kernel32.RaiseException
            RaiseException.argtypes = [
                wintypes.DWORD, wintypes.DWORD, wintypes.DWORD,
                ctypes.c_void_p]
            RaiseException.restype = None
            
            IsDebuggerPresent = kernel32.IsDebuggerPresent
            IsDebuggerPresent.argtypes = []
            IsDebuggerPresent.restype = wintypes.BOOL
            MS_VC_EXCEPTION = 0x406D1388
            info = THREADNAME_INFO()
            # The argument count is given in pointer-sized units and must be an integer (DWORD argtype):
            # a float (true division) is rejected by ctypes, so the name was never sent.
            nb_arguments = ctypes.sizeof(info) // ctypes.sizeof(ctypes.c_void_p)
            while True:
                time.sleep(1)
                if IsDebuggerPresent():
                    for thread in threading.enumerate():
                        if thread.ident is None:
                            continue  # not started
                        if hasattr(threading, '_MainThread'):
                            if isinstance(thread, threading._MainThread):
                                continue  # don't name the main thread
                        # LPCSTR fields only accept bytes under Python 3: a str raised TypeError here,
                        # outside any try, and killed this checker thread.
                        info.szName = unicode_to_bytes('%s (Python)' % (thread.name,))
                        info.dwThreadID = thread.ident
                        try:
                            RaiseException(MS_VC_EXCEPTION, 0, nb_arguments, ctypes.addressof(info))
                        except Exception:
                            pass
        
        
        dt = threading.Thread(target=debugChecker, name='MSVC debugging support thread')
        dt.daemon = True
        dt.start()
    elif sys.platform.startswith('linux'):
        # Return if python was build without ctypes (like in a static build)
        try:
            import ctypes
            import ctypes.util
        except ImportError:
            return
        import threading
        libpthread_path = ctypes.util.find_library('pthread')
        if not libpthread_path:
            return
        libpthread = ctypes.CDLL(libpthread_path)
        if not hasattr(libpthread, 'pthread_setname_np'):
            return
        
        pthread_setname_np = libpthread.pthread_setname_np
        pthread_setname_np.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
        pthread_setname_np.restype = ctypes.c_int
        
        # Native thread names are limited to 16 bytes including the trailing NUL: longer names are rejected
        # (ERANGE) by pthread_setname_np, so they are truncated before the call.
        max_native_name_length = 15
        
        
        def set_native_name(pthread_ident, name):
            try:
                pthread_setname_np(pthread_ident, unicode_to_bytes(name)[:max_native_name_length])
            except Exception:
                pass  # Don't care about failure to set name
        
        
        orig_setter = threading.Thread.__setattr__
        
        
        def attr_setter(self, prop_name, value):
            orig_setter(self, prop_name, value)
            if prop_name == 'name':
                ident = getattr(self, 'ident', None)
                if ident:  # not started yet: the name will be applied by namer_wapper
                    set_native_name(ident, value)
        
        
        threading.Thread.__setattr__ = attr_setter
        
        old_bootstrap_inner = threading.Thread._bootstrap_inner
        
        
        def namer_wapper(thread):
            # Runs inside the new thread, before the original _bootstrap_inner registers it: threading.get_ident() is
            # already our own pthread id, whereas current_thread() would build a throw-away _DummyThread.
            # set_native_name() never raises, which matters: an exception here would prevent _started from being set
            # and Thread.start() would wait forever.
            set_native_name(threading.get_ident(), thread.name)
            old_bootstrap_inner(thread)
        
        
        threading.Thread._bootstrap_inner = namer_wapper


def get_thread_id():
    """Return the kernel thread id (TID) of the calling thread, or 0 when libc could not be loaded.
    
    threading.get_native_id() (Python 3.8+) is preferred: the raw syscall number 186 is gettid only on x86_64,
    on other architectures (aarch64, i386...) it is a different system call.
    """
    if not libc:
        return 0
    get_native_id = getattr(threading, 'get_native_id', None)
    if get_native_id is not None:
        return get_native_id()
    return libc.syscall(186)  # SYS_gettid on x86_64 only: get the thread id when you are in it :)


def async_call(func: 'Callable[Parameters, None]') -> 'Callable[Parameters, threading.Thread]':
    """Decorator: each call of ``func`` runs in a new daemon thread named "<func name>-async".
    
    The decorated function returns the started threading.Thread instead of func's result.
    """
    
    @functools.wraps(func)
    def async_call__wrapper(*args, **kwargs):
        background_thread = threading.Thread(
            target=func,
            name='%s-async' % func.__name__,
            args=args,
            kwargs=kwargs,
            daemon=True,
        )
        background_thread.start()
        return background_thread
    
    
    return async_call__wrapper


def locked_method(func):
    # type: (WrappedFunction) -> WrappedFunction
    """Decorator: run the method while holding ``_lock`` of its first argument.
    
    The first argument is the instance, or the class when stacked with @classmethod.
    The wrapper name must keep containing "wrapper" so LockWithTimer skips its frame when naming the caller.
    """
    
    @functools.wraps(func)
    def locked_method__wrapper(*args, **kwargs):
        owner = args[0]
        with owner._lock:
            return func(*args, **kwargs)
    
    
    return locked_method__wrapper


# Thread subclass -> last instance created with only_one_thread_by_class=True; a new instance of the same class asks the
# registered one to stop before replacing it. Reset by Thread.after_fork_cleanup().
_old_threads = {}


class LockWithTimer:
    """Lock wrapper warning about callers that waited too long for the lock or held it too long.
    
    Each successful acquire() records (caller description, acquire timestamp) on _caller_stack. The entry is pushed
    right after acquiring and popped right before releasing, so only the holder touches it; re-entrant acquisitions
    (RLock) stack up. Warnings go to a 'LONG LOCK (name=...)' sub-part of the logger, nested with one sub-part per
    recorded caller.
    """
    
    # Frames whose code name contains one of these strings are skipped when looking for the caller (decorator wrappers, comprehensions)
    _STACK_FRAME_NAMES_TO_SKIP = ('wrapper', '<listcomp>')
    
    
    def __init__(self, lock, my_logger=None, lock_name='generic lock', wait_time_log_threshold=1.0, release_time_log_threshold=1.0):
        """
        :param lock: wrapped lock (any object with acquire() / release())
        :param my_logger: parent logger of the warnings, LoggerFactory.get_logger() by default
        :param lock_name: name displayed in the warnings
        :param wait_time_log_threshold: warn when acquiring took at least this many seconds
        :param release_time_log_threshold: warn when the lock was held at least this many seconds
        """
        self.lock = lock
        self.lock_name = lock_name
        
        if my_logger is None:
            self.logger = LoggerFactory.get_logger()
        else:
            self.logger = my_logger
        
        self.logger = self.logger.get_sub_part(f'LONG LOCK (name={self.lock_name})')
        
        self._wait_time_log_threshold: float = wait_time_log_threshold
        self._release_time_log_threshold: float = release_time_log_threshold
        
        self._caller_stack: 'collections.deque[tuple[str, float]]' = collections.deque()
    
    
    def acquire(self, *args, caller_stack_level: int = 1, **kwargs):
        """Acquire the wrapped lock, record the caller and warn if the wait was too long.
        
        :param args: forwarded to the wrapped lock acquire() (e.g. blocking, timeout)
        :param caller_stack_level: number of frames between this method and the caller to record (2 from __enter__)
        :param kwargs: forwarded to the wrapped lock acquire()
        :return: the wrapped lock acquire() result (False when a non-blocking or timed acquisition failed)
        """
        frame: 'Optional[FrameType]'
        caller_name: str
        try:
            frame = sys._getframe(caller_stack_level)  # noqa: _getframe needed
            while frame is not None and any(key in frame.f_code.co_name for key in self._STACK_FRAME_NAMES_TO_SKIP):  # Skip decorator frames
                frame = frame.f_back
            if frame is None:
                raise ValueError('Out of bound')
            caller_name = f'{frame.f_code.co_name!r} in {os.path.basename(frame.f_code.co_filename)} (line {frame.f_lineno})'
        except Exception:
            caller_name = '<unknown function name>'
        finally:
            # Always drop the references to frame object to avoid memory leaks
            frame = None  # noqa: Yes it is unused after assignment, this is the goal.
        
        _start_wait_time = time.time()
        acquired = self.lock.acquire(*args, **kwargs)
        _acquire_timestamp = time.time()
        _wait_time = _acquire_timestamp - _start_wait_time
        
        # A failed non-blocking / timed acquisition means the lock is NOT held: recording it would let the next
        # release() pop (and time) an entry instead of the real holder's one.
        # None (lock-like objects that report nothing) is still considered as acquired.
        if acquired is not None and not acquired:
            return acquired
        
        self._caller_stack.append((caller_name, _acquire_timestamp))
        
        if _wait_time >= self._wait_time_log_threshold:
            long_lock_logger = self._long_lock_logger_with_caller_sub_parts()
            long_lock_logger.warning(f'Lock acquired after [ {PartLogger.format_duration(_wait_time)} ] seconds')
        return acquired
    
    
    def __enter__(self):
        # Like threading.Lock, expose the acquire() result to "with ... as"
        return self.acquire(caller_stack_level=2)
    
    
    def release(self, *args, **kwargs):
        """Pop the holder's record, warn if the lock was held too long, then release the wrapped lock.
        
        The wrapped lock is released even if the bookkeeping or the logging fails.
        """
        try:
            if self._caller_stack:
                # Built before popping so that the current holder appears in the sub-parts
                long_lock_logger = self._long_lock_logger_with_caller_sub_parts()
                _, acquire_start_time = self._caller_stack.pop()
                
                release_time = time.time() - acquire_start_time
                
                if release_time >= self._release_time_log_threshold:
                    long_lock_logger.warning(f'Released the locked after [ {PartLogger.format_duration(release_time)} ] seconds')
        finally:
            self.lock.release(*args, **kwargs)
    
    
    def __exit__(self, _type, value, tb):
        self.release()
    
    
    def _long_lock_logger_with_caller_sub_parts(self) -> 'PartLogger':
        """Return self.logger nested with one (unregistered) sub-part per recorded caller, outermost first."""
        logger: 'PartLogger' = self.logger
        for caller, _ in self._caller_stack:
            logger = logger.get_sub_part(caller, register=False)
        return logger


# Special loop_speed value: run loop turns back to back, without any pause
NO_SLEEP = -1
# Value of Thread._starting_pid until start_thread() records the pid of the starting process
UNSET_PID = -1


class Thread:
    """Base class running loop_turn() in a loop inside its own threading.Thread.
    
    Subclasses implement get_thread_name() and loop_turn(). Between two turns the loop pauses:
      * loop_speed seconds minus the turn duration (the full loop_speed with force_pause_between_loops),
      * until notify() / ask_stop() when loop_speed is None,
      * not at all when loop_speed is NO_SLEEP.
    notify() and ask_stop() cut a pause short.
    """
    
    def __init__(self, only_one_thread_by_class=False, loop_speed: 'float|None' = 1, force_stop_with_application=True, stop_thread_on_error=True, stop_app_on_error=False, error_handler=None, logger=None, force_pause_between_loops=False):
        """
        :param only_one_thread_by_class: ask the previously registered instance of the same class to stop, and register this one
        :param loop_speed: pause between loop turns in seconds, None to wait for notify(), NO_SLEEP for no pause
        :param force_stop_with_application: make the thread a daemon thread
        :param stop_thread_on_error: leave the loop when loop_turn() raises
        :param stop_app_on_error: kill the process (os._exit(1)) when loop_turn() raises
        :param error_handler: object whose handle_exception() reports loop_turn() errors instead of a direct log
        :param logger: logger to use, LoggerFactory.get_logger() by default
        :param force_pause_between_loops: pause the full loop_speed instead of subtracting the loop turn duration
        """
        my_class = type(self)
        if only_one_thread_by_class and _old_threads.get(my_class, None):
            _old_threads[my_class].ask_stop()
        
        self._running = False
        self._thread: 'threading.Thread|None' = None
        self.loop_speed = loop_speed
        self._force_stop_with_application = force_stop_with_application
        self._stop_thread_on_error = stop_thread_on_error
        self._stop_app_on_error = stop_app_on_error
        self._error_handler = error_handler
        self._pause_event = Event()
        self._starting_pid = UNSET_PID  # Used to detect zombie thread object
        self._force_pause_between_loops = force_pause_between_loops
        if logger:
            self.logger = logger
        else:
            self.logger = LoggerFactory.get_logger()
        
        if only_one_thread_by_class:
            _old_threads[my_class] = self
    
    
    def start_thread(self):
        """Start the loop in a new thread, unless it is already running."""
        if not self._running:
            thread_name = self.get_thread_name()
            self._thread = threading.Thread(None, target=self._run, name=thread_name)
            self._thread.daemon = self._force_stop_with_application
            self._running = True
            self._starting_pid = os.getpid()  # we cannot live across fork() so we do not need to update this value after our start()
            self._pause_event.clear()
            try:
                self._thread.start()
            except Exception:
                # e.g. "can't start new thread": do not pretend to be running, so that a later start_thread() can retry
                self._running = False
                raise
    
    
    # A zombie thread is a thread that was started in a process, and is now in another process after a fork()
    # but fork() do not copy threads, so this object IS dead, it's now a zombie object
    def _is_zombie_thread(self):
        return self._starting_pid != UNSET_PID and os.getpid() != self._starting_pid
    
    
    def is_running(self):
        # We cannot be running if we are a zombie
        if self._is_zombie_thread():
            return False
        return self._running
    
    
    def ask_stop(self):
        """Ask the loop to stop at the end of the current turn (interrupts the pause), without waiting."""
        if self._is_zombie_thread():  # Do not call more thing, we are already stop, we must NOT touch event and such broken object
            return
        self._running = False
        self._pause_event.set()
    
    
    def stop(self):
        """Ask the loop to stop and wait for the thread to end (unless called from the thread itself)."""
        if self._is_zombie_thread():  # Do not call more thing, we are already stop, we must NOT touch event and such broken object
            return
        if self._thread:
            self.ask_stop()
            # stop() may be called from loop_turn(): a thread cannot join itself (RuntimeError), and asking
            # the loop to stop is enough since _run() exits at the end of the current turn.
            if self._thread is not threading.current_thread():
                self._thread.join()
    
    
    def get_thread_name(self):
        raise NotImplementedError()
    
    
    def _run(self):
        while self._running:
            try:
                _loop_start_time = time.monotonic()
                self.loop_turn()
                _loop_end_time = time.monotonic()
                if self.loop_speed is None:
                    self.interruptable_sleep(None)
                elif self.loop_speed != NO_SLEEP:
                    wait_time = self.loop_speed - (_loop_end_time - _loop_start_time) if not self._force_pause_between_loops else self.loop_speed
                    self.interruptable_sleep(max(wait_time, 0))
            except Exception as e:
                if self._error_handler:
                    try:
                        self._error_handler.handle_exception('Fatal error caused by : %s' % e, e, self.logger, level=ERROR_LEVEL.FATAL)
                    except Exception:
                        self.logger.error('Thread %s have a fatal error : %s' % (self.get_thread_name(), e))
                        self.logger.print_stack()
                else:
                    self.logger.error('Thread %s have a fatal error : %s' % (self.get_thread_name(), e))
                    self.logger.print_stack()
                if self._stop_app_on_error:
                    self.logger.error('Thread %s force stopping application' % self.get_thread_name())
                    os._exit(1)
                if self._stop_thread_on_error:
                    self._running = False
        if self.logger.is_debug():
            self.logger.debug(f'Ending thread {threading.current_thread().name}')
    
    
    def loop_turn(self):
        raise NotImplementedError()
    
    
    def interruptable_sleep(self, raw_duration: 'float | None') -> None:
        """Wait raw_duration seconds (forever if None) or until notify() / ask_stop()."""
        self._pause_event.wait(raw_duration)
        # Consume a notify() while running; once stopping, keep the event set so that any later wait returns at once
        if self._running and self._pause_event.is_set():
            self._pause_event.clear()
    
    
    def notify(self):
        """Wake the loop up from its current (or next) pause."""
        self._pause_event.set()
    
    
    def after_fork_cleanup(self):
        """Drop the thread, event, error handler and logger references (and the per-class registry) after a fork()."""
        global _old_threads
        _old_threads = {}
        if self._thread:
            try:
                # this may be useless ?
                self._thread.join()
            except:
                pass
        self._thread = None
        self._pause_event = None
        self._error_handler = None
        self.logger = None


class DelayedThread(Thread):
    """Thread whose first loop turn only runs delay_to_wait seconds after start_thread()."""
    
    def loop_turn(self):
        raise NotImplementedError()
    
    
    def get_thread_name(self):
        raise NotImplementedError()
    
    
    def __init__(self, delay_to_wait=0, only_one_thread_by_class=False, loop_speed=1, force_stop_with_application=True, stop_thread_on_error=True, stop_app_on_error=False, error_handler=None, logger=None, force_pause_between_loops=False):
        """
        :param delay_to_wait: seconds to wait, inside the thread, before the first loop turn (0: no delay)
        :param force_pause_between_loops: forwarded to Thread like every other parameter (was not reachable before)
        """
        self.delay_to_wait = delay_to_wait
        super(DelayedThread, self).__init__(only_one_thread_by_class, loop_speed, force_stop_with_application, stop_thread_on_error, stop_app_on_error, error_handler, logger, force_pause_between_loops)
    
    
    def _run(self):
        # type: () -> None
        # If we are not delayed, then no need to log that we are delayed
        if self.delay_to_wait:
            self.logger.info('Waiting for [ %d ] seconds before running the thread' % self.delay_to_wait)
            self._wait_initial_delay(self.delay_to_wait)
        super(DelayedThread, self)._run()
    
    
    def _wait_initial_delay(self, delay):
        # type: (float) -> None
        """Wait delay seconds, returning early when ask_stop()/stop() is called.
        
        A plain time.sleep() made stop() block for the whole delay (and crashed on a negative delay).
        """
        deadline = time.monotonic() + delay
        while self._running:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return
            self._pause_event.wait(remaining)
            if self._running and self._pause_event.is_set():
                # A notify() during the delay is served by the first loop turn, run right after the delay;
                # clear it so this wait does not spin until the deadline.
                self._pause_event.clear()
    
    
    def set_delay_to_wait(self, delay_to_wait):
        # type: (int) -> None
        self.delay_to_wait = delay_to_wait


class OnEventWorkingThread(Thread, ABC):
    """Thread that stays idle until notify() is called, then runs work() once per wake-up.
    
    loop_speed is forced to None, so the base loop waits for notify() between turns.
    """
    
    def __init__(self, only_one_thread_by_class=False, force_stop_with_application=True, stop_thread_on_error=True, stop_app_on_error=False, error_handler=None, logger=None):
        # None as loop_speed: no periodic turn, the loop sleeps until notify()
        super().__init__(only_one_thread_by_class, None, force_stop_with_application, stop_thread_on_error, stop_app_on_error, error_handler, logger)
        # The base loop runs one turn before its first wait: that turn must not call work() before any notify()
        self._skip_first_loop_turn_call = True
    
    
    def work(self):
        raise NotImplementedError()
    
    
    def loop_turn(self):
        if not self._skip_first_loop_turn_call:
            # Reset the event BEFORE work(): a notify() arriving while work() runs must survive
            # and trigger another turn instead of being lost (which would leave the thread idle).
            self._pause_event.clear()
            self.work()
            return
        self._skip_first_loop_turn_call = False


class ScheduledCall:
    """Re-armable, one-shot delayed call of run() on a daemon threading.Timer.
    
    reschedule() arms the timer (replacing any pending one) and cancel() disarms it. Only the most recently armed
    timer can trigger run(), which executes in the timer thread outside the internal lock.
    """
    
    def __init__(self, *, target: 'Callable[[], object] | None' = None, name: 'str | None' = None) -> None:
        self._target = target
        self._timer_lock = threading.RLock()
        self._timer = None  # type: Optional[threading.Timer]
        self._default_name = name if name else self.__class__.__name__
    
    
    # Can be overridden
    def get_thread_name(self) -> str:
        return self._default_name
    
    
    # Can be overridden
    def run(self) -> None:
        target = self._target
        if target is None:
            return
        target()
    
    
    def cancel(self) -> None:
        """Disarm the pending timer, if any."""
        with self._timer_lock:
            pending, self._timer = self._timer, None
            if pending is not None:
                pending.cancel()
    
    
    def reschedule(self, wait_time_in_secs: float) -> None:
        """(Re)arm the timer so that run() is called in wait_time_in_secs seconds, replacing any pending timer."""
        thread_name = f'{self.get_thread_name()} ({wait_time_in_secs} seconds timer)'
        
        with self._timer_lock:
            previous = self._timer
            if previous is not None:
                previous.cancel()
            
            
            def _on_timer_done() -> None:
                # A reschedule() or cancel() may have happened while this timer was firing (race condition):
                # only the timer still registered is allowed to call run().
                with self._timer_lock:
                    still_current = self._timer is new_timer
                    if still_current:
                        self._timer = None
                
                # run() is called outside of the lock scope
                if still_current:
                    self.run()
            
            
            new_timer = threading.Timer(wait_time_in_secs, _on_timer_done)
            new_timer.name = thread_name
            new_timer.daemon = True
            
            self._timer = new_timer
            new_timer.start()
