#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2013-2018
# This file is part of Shinken Enterprise, all rights reserved.
#
import os
import sys
import threading
import time
from threading import Event

import six

from shinken.log import PartLogger, LoggerFactory
from shinken.subprocess_helper.error_handler import ERROR_LEVEL

# Handle on the C library, used for raw syscalls; stays None when
# 'libc.so.6' cannot be loaded (non-glibc platform, Python built without ctypes, ...).
libc = None
try:
    import ctypes

    libc = ctypes.CDLL('libc.so.6')
except Exception:
    pass


# Hook threading to allow thread renaming
# Linux (pthread_setname_np) accepts at most 16 bytes including the trailing NUL.
_PTHREAD_NAME_MAX_BYTES = 15


def _to_native_thread_name(value, max_bytes=None):
    """Return *value* as a byte string usable by C thread-naming APIs.

    Byte strings are kept as-is; anything else is converted to text and encoded
    as UTF-8 (ctypes ``c_char_p``/``LPCSTR`` require bytes on Python 3).
    When *max_bytes* is given, the result is truncated to that many bytes.
    """
    if isinstance(value, six.binary_type):
        raw = value
    else:
        raw = six.text_type(value).encode('utf-8')
    if max_bytes is not None:
        raw = raw[:max_bytes]
    return raw


def _patch_thread_name_windows():
    """Start a daemon thread that publishes Python thread names to an attached MSVC debugger."""
    # If no ctypes, like in a static python build: exit
    try:
        import ctypes
    except ImportError:
        return
    from ctypes import wintypes

    class THREADNAME_INFO(ctypes.Structure):
        _pack_ = 8
        _fields_ = [
            ('dwType', wintypes.DWORD),
            ('szName', wintypes.LPCSTR),
            ('dwThreadID', wintypes.DWORD),
            ('dwFlags', wintypes.DWORD),
        ]

        def __init__(self):
            self.dwType = 0x1000
            self.dwFlags = 0

    def debug_checker():
        kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)
        RaiseException = kernel32.RaiseException
        RaiseException.argtypes = [
            wintypes.DWORD, wintypes.DWORD, wintypes.DWORD,
            ctypes.c_void_p]
        RaiseException.restype = None

        IsDebuggerPresent = kernel32.IsDebuggerPresent
        IsDebuggerPresent.argtypes = []
        IsDebuggerPresent.restype = wintypes.BOOL
        MS_VC_EXCEPTION = 0x406D1388
        info = THREADNAME_INFO()
        # Number of pointer-sized arguments in info: must be an integer (DWORD),
        # hence floor division (plain '/' gives a float on Python 3).
        nb_arguments = ctypes.sizeof(info) // ctypes.sizeof(ctypes.c_void_p)
        while True:
            time.sleep(1)
            if not IsDebuggerPresent():
                continue
            for thread in threading.enumerate():
                if thread.ident is None:
                    continue  # not started
                if hasattr(threading, '_MainThread') and isinstance(thread, threading._MainThread):
                    continue  # don't name the main thread
                # LPCSTR fields need bytes on Python 3
                info.szName = _to_native_thread_name('%s (Python)' % (thread.name,))
                info.dwThreadID = thread.ident
                try:
                    RaiseException(MS_VC_EXCEPTION, 0, nb_arguments, ctypes.addressof(info))
                except Exception:
                    pass  # Best effort: naming for the debugger is optional

    dt = threading.Thread(target=debug_checker, name='MSVC debugging support thread')
    dt.daemon = True
    dt.start()


def _patch_thread_name_linux():
    """Mirror Python thread names to the kernel with ``pthread_setname_np``."""
    # Return if python was built without ctypes (like in a static build)
    try:
        import ctypes
        import ctypes.util
    except ImportError:
        return
    libpthread_path = ctypes.util.find_library("pthread")
    if not libpthread_path:
        return
    libpthread = ctypes.CDLL(libpthread_path)
    if not hasattr(libpthread, "pthread_setname_np"):
        return
    pthread_setname_np = libpthread.pthread_setname_np
    pthread_setname_np.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
    pthread_setname_np.restype = ctypes.c_int

    orig_setter = threading.Thread.__setattr__

    def attr_setter(self, prop_name, value):
        if prop_name == 'name' and six.PY2 and isinstance(value, six.text_type):
            # Python 2 thread names are kept as UTF-8 byte strings
            value = value.encode('utf-8')

        orig_setter(self, prop_name, value)
        if prop_name != 'name':
            return
        ident = getattr(self, 'ident', None)
        if not ident:
            return  # not started yet: there is no OS thread to name
        try:
            # Longer names are rejected by the kernel, so truncate instead of failing
            pthread_setname_np(ident, _to_native_thread_name(value, _PTHREAD_NAME_MAX_BYTES))
        except Exception:
            pass  # Don't care about failure to set name

    threading.Thread.__setattr__ = attr_setter
    # TODO: manage python3 (there is no equivalent bootstrap hook installed there)
    if six.PY2:
        old_bootstrap_inner = threading.Thread._Thread__bootstrap_inner

        def namer_wrapper(thread):
            # Assigning .name goes through attr_setter, which sets the native name
            threading.currentThread().name = thread.name[:_PTHREAD_NAME_MAX_BYTES]
            old_bootstrap_inner(thread)

        threading.Thread._Thread__bootstrap_inner = namer_wrapper


def patch_thread_name():
    """Hook threading so Python thread names are visible at the OS/debugger level.

    On Windows, a daemon thread pushes thread names to an attached MSVC debugger.
    On Linux, assigning ``Thread.name`` also calls ``pthread_setname_np``
    (names truncated to 15 bytes). On Python 2, new threads also name themselves
    when they start. Other platforms, and Python builds without ctypes, are
    left untouched.
    """
    if sys.platform.startswith('win'):
        _patch_thread_name_windows()
    elif sys.platform.startswith('linux'):
        _patch_thread_name_linux()


def _gettid_syscall_number():
    """Return the gettid() syscall number for this process' CPU ABI, or None if unknown."""
    import platform
    machine = platform.machine().lower()
    # A 32-bit Python on a 64-bit kernel must use the 32-bit syscall table
    is_64bit = sys.maxsize > 2 ** 32
    if machine in ('x86_64', 'amd64', 'i386', 'i486', 'i586', 'i686'):
        return 186 if is_64bit else 224
    if machine in ('aarch64', 'arm64') or machine.startswith('arm'):
        return 178 if is_64bit else 224
    return None


def get_thread_id():
    """Return the kernel thread id (TID) of the calling thread, or 0 when unavailable.

    Uses ``threading.get_native_id`` when the interpreter provides it (Python 3.8+).
    Otherwise it falls back to the raw gettid() syscall through ``libc``. The
    syscall number depends on the CPU ABI: the previous hard-coded 186 is gettid
    only on x86_64 and would call an unrelated syscall elsewhere.
    """
    get_native_id = getattr(threading, 'get_native_id', None)
    if get_native_id is not None:
        return get_native_id()
    if not libc:
        return 0
    syscall_number = _gettid_syscall_number()
    if syscall_number is None:
        return 0
    return libc.syscall(syscall_number)


def async_call(func):
    """Decorator: run *func* in a new daemon thread each time it is called.

    The decorated callable starts a thread named ``'<func name>-async'``. It passes
    the call's positional and keyword arguments and returns the started
    ``threading.Thread`` right away. *func*'s own return value is not collected.
    """
    import functools

    # Keep func's __name__/__doc__ on the wrapper so decorated functions stay introspectable
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        thread_name = '%s-async' % func.__name__
        _thread = threading.Thread(target=func, name=thread_name, args=args, kwargs=kwargs)
        _thread.daemon = True
        _thread.start()
        return _thread

    return wrapper


# Latest Thread instance per subclass created with only_one_thread_by_class=True
_old_threads = dict()


class LockWithTimer(object):
    """Wrap a lock and log, at debug level, how long callers waited to acquire it.

    Usable through ``acquire()``/``release()`` or as a context manager.

    :param lock: the lock to wrap (anything with ``acquire``/``release``).
    :param my_logger: logger for wait reports; defaults to LoggerFactory.get_logger().
    :param lock_name: label used in the log message.
    """

    def __init__(self, lock, my_logger=None, lock_name='generic lock'):
        self.lock = lock
        self.lock_name = lock_name

        if my_logger is None:
            self.logger = LoggerFactory.get_logger()
        else:
            self.logger = my_logger

    def acquire(self, *args, **kwargs):
        """Acquire the wrapped lock, logging the wait when it exceeds 10 ms.

        Arguments are forwarded to the wrapped lock's ``acquire``. Its return value
        is passed back to the caller, so non-blocking or timed acquisitions can be
        checked (for ``threading.Lock``: True if acquired, False otherwise).
        """
        _start_wait_time = time.time()
        acquired = self.lock.acquire(*args, **kwargs)
        _wait_time = time.time() - _start_wait_time
        if _wait_time > 0.01:
            self.logger.debug('[%s] wait %s' % (self.lock_name, PartLogger.format_duration(_wait_time)))
        return acquired

    def __enter__(self):
        # Mirror threading.Lock, whose __enter__ returns the acquire() result
        return self.acquire()

    def release(self, *args, **kwargs):
        """Release the wrapped lock, forwarding any arguments."""
        self.lock.release(*args, **kwargs)

    def __exit__(self, _type, value, tb):
        self.release()


# Sentinels: loop_speed == NO_SLEEP skips the pause between loop turns, and
# _starting_pid == UNSET_PID means the Thread object was never started.
NO_SLEEP, UNSET_PID = -1, -1

class Thread(object):
    """Base class for a looping background worker.

    Subclasses implement ``get_thread_name()`` and ``loop_turn()``. After
    ``start_thread()``, a ``threading.Thread`` calls ``loop_turn()`` repeatedly
    and sleeps ``loop_speed`` seconds between turns (no sleep when it is
    ``NO_SLEEP``). ``notify()`` and ``ask_stop()`` can cut the sleep short.

    :param only_one_thread_by_class: if True, creating an instance asks the previous
        instance of the same class to stop, and registers the new one instead.
    :param loop_speed: seconds to sleep between two loop turns, or NO_SLEEP.
    :param force_stop_with_application: run the worker as a daemon thread.
    :param stop_thread_on_error: stop looping after loop_turn() raises.
    :param stop_app_on_error: call os._exit(1) after loop_turn() raises.
    :param error_handler: optional object whose handle_exception() reports errors.
    :param logger: logger to use; defaults to LoggerFactory.get_logger().
    """

    def __init__(self, only_one_thread_by_class=False, loop_speed=1, force_stop_with_application=True, stop_thread_on_error=True, stop_app_on_error=False, error_handler=None, logger=None):
        my_class = type(self)
        if only_one_thread_by_class and _old_threads.get(my_class, None):
            _old_threads[my_class].ask_stop()

        self._running = False
        self._thread = None
        self.loop_speed = loop_speed
        self._force_stop_with_application = force_stop_with_application
        self._stop_thread_on_error = stop_thread_on_error
        self._stop_app_on_error = stop_app_on_error
        self._error_handler = error_handler
        self._pause_event = Event()
        self._starting_pid = UNSET_PID  # Used to detect zombie thread object
        if logger:
            self.logger = logger
        else:
            self.logger = LoggerFactory.get_logger()

        if only_one_thread_by_class:
            _old_threads[my_class] = self

    def start_thread(self):
        """Start the worker thread unless it is already running."""
        if not self._running:
            thread_name = self.get_thread_name()
            self._thread = threading.Thread(None, target=self._run, name=thread_name)
            self._thread.daemon = self._force_stop_with_application
            self._running = True
            self._starting_pid = os.getpid()  # we cannot live across fork() so we do not need to update this value after our start()
            self._pause_event.clear()
            self._thread.start()

    # A zombie thread object is one that was started in a process which then called fork():
    # fork() does not copy threads, so in the child this object's thread is dead (a zombie object).
    def _is_zombie_thread(self):
        return self._starting_pid != UNSET_PID and os.getpid() != self._starting_pid

    def is_running(self):
        # We cannot be running if we are a zombie
        if self._is_zombie_thread():
            return False
        return self._running

    def ask_stop(self):
        """Ask the loop to end after the current turn, waking it if it is sleeping."""
        if self._is_zombie_thread():  # Already stopped: we must NOT touch the event and such broken objects
            return
        self._running = False
        # after_fork_cleanup() drops the event
        if self._pause_event is not None:
            self._pause_event.set()

    def stop(self):
        """Ask the loop to end and wait for the worker thread to finish."""
        if self._is_zombie_thread():  # Already stopped: we must NOT touch the event and such broken objects
            return
        if self._thread:
            self.ask_stop()
            # A thread cannot join itself (RuntimeError). When stop() is called from
            # loop_turn(), asking the loop to end is enough.
            if self._thread is not threading.current_thread():
                self._thread.join()

    def get_thread_name(self):
        raise NotImplementedError()

    def _run(self):
        while self._running:
            try:
                self.loop_turn()
                if self.loop_speed != NO_SLEEP:
                    self.interruptable_sleep(self.loop_speed)
            except Exception as e:
                if self._error_handler:
                    try:
                        self._error_handler.handle_exception('Fatal error caused by : %s' % e, e, self.logger, level=ERROR_LEVEL.FATAL)
                    except Exception:
                        self.logger.error('Thread %s have a fatal error : %s' % (self.get_thread_name(), e))
                        self.logger.print_stack()
                else:
                    self.logger.error('Thread %s have a fatal error : %s' % (self.get_thread_name(), e))
                    self.logger.print_stack()
                if self._stop_app_on_error:
                    self.logger.error('Thread %s force stopping application' % self.get_thread_name())
                    os._exit(1)
                if self._stop_thread_on_error:
                    self._running = False
                elif self.loop_speed != NO_SLEEP:
                    # Still pause before retrying. Without this, a loop_turn() that
                    # keeps failing would spin at full speed and flood the logs.
                    self.interruptable_sleep(self.loop_speed)

    def loop_turn(self):
        raise NotImplementedError()

    def interruptable_sleep(self, raw_duration):
        """Sleep up to raw_duration seconds; notify() or ask_stop() wakes us early."""
        self._pause_event.wait(raw_duration)
        if self._running and self._pause_event.is_set():
            self._pause_event.clear()

    def notify(self):
        """Wake the worker if it is sleeping between two loop turns."""
        # after_fork_cleanup() drops the event
        if self._pause_event is not None:
            self._pause_event.set()

    def after_fork_cleanup(self):
        """Drop references held by this object; meant to be called in a forked child."""
        global _old_threads
        _old_threads = {}
        if self._thread:
            try:
                # Best effort, possibly useless: the OS thread behind this object
                # does not exist in a process created by fork()
                self._thread.join()
            except Exception:
                pass
        self._thread = None
        self._pause_event = None
        self._error_handler = None
        self.logger = None


class DelayedThread(Thread):
    """A Thread whose loop only starts after an initial delay of ``delay_to_wait`` seconds."""

    def loop_turn(self):
        raise NotImplementedError()

    def get_thread_name(self):
        raise NotImplementedError()

    def __init__(self, delay_to_wait=0, only_one_thread_by_class=False, loop_speed=1, force_stop_with_application=True, stop_thread_on_error=True, stop_app_on_error=False, error_handler=None, logger=None):
        self.delay_to_wait = delay_to_wait
        super(DelayedThread, self).__init__(only_one_thread_by_class, loop_speed, force_stop_with_application, stop_thread_on_error, stop_app_on_error, error_handler, logger)

    def _run(self):
        # type: () -> None
        # If we are not delayed, then no need to log that we are delayed
        if self.delay_to_wait:
            self.logger.info(u'Waiting for [ %d ] seconds before running the thread' % self.delay_to_wait)
            # Wait on the pause event instead of time.sleep(), so that ask_stop()/stop()
            # interrupt the delay instead of blocking until it ends. notify() only
            # wakes the wait, which then resumes for the remaining time.
            clock = getattr(time, 'monotonic', time.time)
            deadline = clock() + self.delay_to_wait
            while self._running:
                remaining = deadline - clock()
                if remaining <= 0:
                    break
                self.interruptable_sleep(remaining)
        super(DelayedThread, self)._run()

    def set_delay_to_wait(self, delay_to_wait):
        # type: (int) -> None
        # Only affects the next start: the delay is read when the worker thread begins
        self.delay_to_wait = delay_to_wait
