#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright (C) 2009-2012:
#    Gabes Jean, naparuba@gmail.com
#    Gerhard Lausser, Gerhard.Lausser@consol.de
#    Gregory Starck, g.starck@gmail.com
#    Hartmut Goebel, h.goebel@goebel-consult.de
#
# This file is part of Shinken.
#
# Shinken is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Shinken is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with Shinken.  If not, see <http://www.gnu.org/licenses/>.

import sys
import threading
import time
import traceback
from collections import deque
from typing import NamedTuple

from shinken.compat import cPickle, Empty

from shinken.load import AvgForFixSizeCall
from shinken.misc.type_hint import TYPE_CHECKING
from shinken.moduleworker import ModuleWorker
from shinken.safepickle import SafeUnpickler

if TYPE_CHECKING:
    from shinken.misc.type_hint import Optional, Any
    from shinken.brok import Brok

# Length (in seconds) of the window used for the average work-time statistics
ONE_MINUTE: int = 60
# Period (in seconds) of one work-time sample; ONE_MINUTE / SAMPLING_TIME samples are kept
SAMPLING_TIME: int = 10

# Stacked broks are pushed to the worker once the stack holds more than this many broks, or at least once a second
MAX_BROKS_TO_SEND_QUEUE_SIZE: int = 100

MANAGE_BROKS_LOGGER_SUB_PART: str = 'MANAGE BROKS'
# Delay given to interruptable_sleep between two "pending broks" log lines (presumably seconds)
MANAGE_BROKS_LOG_LOOP_TIME: int = 30


class _BrokSizeDetails(NamedTuple):
    item_uuid: str
    number_of_elements: 'list[tuple[int, str]]'
    elements_size: 'list[tuple[int, str]]'


# The Broker module Worker is a classic worker with extra capabilities:
# * it inherits the inventory features
# * it adds methods to send broks to the sub-process
# * it adds threads that get the broks from the main process and call the user-defined manage_<type>_brok methods
class BrokerModuleWorker(ModuleWorker):
    """ModuleWorker specialised for broker modules.

    Main process side: broks are serialized and sent to the worker sub-process
    through ``_main_process_to_worker_broks_queue``. Broks that cannot be sent
    are stacked and retried by ``_force_push_to_worker``.

    Worker sub-process side: ``start_worker_specific_treads`` starts

    * a thread reading broks sent by the main process into an in-memory "manage queue";
    * a thread popping broks from the manage queue and calling the subclass'
      ``manage_<brok.type>_brok`` method when it exists;
    * a thread periodically logging how many broks are waiting in the manage queue.

    Work-time statistics of both brok threads are exposed by ``get_raw_stats``.
    """

    def __init__(self, worker_id, mod_conf, name, from_module_to_main_daemon_queue, queue_factory, daemon_display_name):
        super(BrokerModuleWorker, self).__init__(worker_id, mod_conf, name, from_module_to_main_daemon_queue, queue_factory, daemon_display_name)

        # Serialization-time thresholds (stored in seconds) above which a brok is logged as oversized
        self._brok_serialization_time_warning_threshold: float = 0.0
        self._brok_serialization_time_error_threshold: float = 0.0

        self.read_configuration_from_father_config()

        # __init__ runs in the main process: the queue created here is inherited by the worker sub-process
        self._main_process_to_worker_broks_queue = self._queue_factory('BRK', name, 'W[%s]' % worker_id)

        # Main process side: time of the last successful send, and serialized broks waiting to be (re)sent, oldest first
        self._last_brok_push = 0.0
        self._to_send_broks_stack = []

        # Worker side: the condition is only created in start_worker_specific_treads (in the worker process)
        self._in_worker_to_manage_queue_lock = None  # type: Optional[threading.Condition]
        self._in_worker_to_manage_queue = deque()

        # STATS:
        # If compute_stats is at False only cumulative time will be set
        self.compute_stats = True

        # Cumulative tuples are (last end time, total time taken, number of broks handled).
        # Sampling lists hold one slot per SAMPLING_TIME period over ONE_MINUTE; -1 means "no sample yet".
        self._manage_work_time = AvgForFixSizeCall(time_limit=ONE_MINUTE)
        self._manage_work_time_sampling = self._get_time_sampling_struct()
        self._manage_work_time_last_print = time.time()
        self._manage_work_time_cumulative = (0.0, 0.0, 0)

        # Number of broks waiting in the worker manage queue (refreshed on each enqueue/dequeue)
        self._manage_work_remaining_broks = 0

        self._get_broks_work_time = AvgForFixSizeCall(time_limit=ONE_MINUTE)
        self._get_broks_work_time_sampling = self._get_time_sampling_struct()
        self._get_broks_work_time_last_print = time.time()
        self._get_broks_work_time_cumulative = (0.0, 0.0, 0)

        self._get_broks_time_manage_broks = 0
        self._get_broks_time_compute_stats = 0

        self._time_loop_manage_brok = 0
        self._time_for_manage_broks = 0
        self._time_for_prepare_broks = 0
        self._time_for_compute_stats = 0
        self._time_for_get_from_manage_queue = 0
        self._time_waiting_for_manage_queue = 0


    def set_father_config_in_main_process(self, father_config: 'dict[str, Any]') -> None:
        """Store the father (daemon) configuration and re-read the settings this worker takes from it."""
        self.myconf.set_father_config(father_config)
        self.read_configuration_from_father_config()


    def read_configuration_from_father_config(self) -> None:
        """Load the brok serialization-time thresholds from the father configuration.

        Thresholds are configured in milliseconds (defaults: 100 for the warning,
        500 for the error) and stored in seconds. If the error threshold is lower
        than the warning threshold, the warning threshold is used for both.
        """
        father_config: 'dict[str, Any]' = self.myconf.father_config
        warning_threshold_ms = father_config.get('broker__manage_brok__oversized_data_warning_threshold__serialization_time', 100)
        error_threshold_ms = father_config.get('broker__manage_brok__oversized_data_error_threshold__serialization_time', 500)

        init_logger = self.logger.get_sub_part('INITIALISATION')
        init_logger.info(f'oversized_data_warning_threshold__serialization_time - [ {warning_threshold_ms} ]')
        init_logger.info(f'oversized_data_error_threshold__serialization_time --- [ {error_threshold_ms} ]')

        # Unit conversion: ms -> s. float() also accepts numeric strings, in case the configuration provides them as such.
        self._brok_serialization_time_warning_threshold = float(warning_threshold_ms) / 1000
        self._brok_serialization_time_error_threshold = float(error_threshold_ms) / 1000

        if self._brok_serialization_time_error_threshold < self._brok_serialization_time_warning_threshold:
            init_logger.warning('Inconsistency warning: "broker__manage_brok__oversized_data_error_threshold__serialization_time" is less than "broker__manage_brok__oversized_data_warning_threshold__serialization_time". The error threshold will use the warning threshold.')
            self._brok_serialization_time_error_threshold = self._brok_serialization_time_warning_threshold


    def worker_main(self):
        """Entry point of the worker sub-process; must be implemented by subclasses."""
        raise NotImplementedError()


    @staticmethod
    def _get_time_sampling_struct():
        """Return a fresh sampling list: one -1 ("no sample") slot per SAMPLING_TIME period in ONE_MINUTE."""
        return [-1 for _ in range(int(ONE_MINUTE / SAMPLING_TIME))]


    def _force_push_to_worker(self):
        """Send every stacked (already serialized) brok to the worker, oldest first.

        Broks that were successfully sent are removed from the stack even if a
        later send fails, so a retry never sends the same brok twice.

        :raises Exception: when a brok cannot be sent; the broks not sent yet stay stacked, in order.
        """
        stack = self._to_send_broks_stack
        if not stack:
            return
        queue = self._main_process_to_worker_broks_queue

        nb_sent = 0
        try:
            for item in stack:
                queue.send(item)
                nb_sent += 1
        except Exception as exp:
            raise Exception('[%s] Cannot send the brok to the worker %s: %s' % (self.get_name(), self._worker_id, exp)) from exp
        finally:
            if nb_sent:
                # Drop only what actually went through; the rest is kept for the next attempt
                del stack[:nb_sent]
                self._last_brok_push = time.time()


    def _send_brok_to_worker(self, brok):
        """Serialize ``brok`` and send it to the worker sub-process (main process side).

        If older broks are still stacked (a previous send failed), the new brok is
        stacked behind them instead of being sent directly, so the worker receives
        broks in the order they were given. A brok whose send fails is stacked too.
        The stack is flushed when the last successful send is more than one second
        old or when it holds more than MAX_BROKS_TO_SEND_QUEUE_SIZE broks.

        :param brok: the brok to send.
        :raises Exception: propagated from _force_push_to_worker when the flush fails.
        """
        # NOTE: we do pre-serialize the broks because the multiprocessing lib seems to be using pickle and not cPickle.
        serialize_start_time = time.time()
        serialized_brok = cPickle.dumps(brok, cPickle.HIGHEST_PROTOCOL)
        serialize_duration = time.time() - serialize_start_time
        if serialize_duration >= self._brok_serialization_time_warning_threshold:
            try:
                # len() is the payload size; sys.getsizeof() would add the bytes object header
                self._log_brok_with_too_long_serialization_time(brok, serialize_duration, len(serialized_brok))
            except Exception as exp:
                # Only a diagnostic: never lose the brok because its details could not be logged
                self.logger.warning('Cannot log the details of the brok that took too much time to be serialized: %s' % exp)

        if self._to_send_broks_stack:
            # Older broks are still waiting: keep the order by stacking this one behind them
            self._to_send_broks_stack.append(serialized_brok)
        else:
            try:
                self._main_process_to_worker_broks_queue.send(serialized_brok)
                self._last_brok_push = time.time()
            except Exception:
                # Kept for a later retry by _force_push_to_worker
                self._to_send_broks_stack.append(serialized_brok)

        # and check if we should send the stacked broks now or not
        now = time.time()
        if now > self._last_brok_push + 1 or len(self._to_send_broks_stack) > MAX_BROKS_TO_SEND_QUEUE_SIZE:
            self._force_push_to_worker()


    def _log_brok_with_too_long_serialization_time(self, brok: 'Brok', serialization_duration: float, serialized_brok_size: int) -> None:
        """Log that a brok was slow to serialize, with details on its potentially expensive content.

        Two lines are emitted (element counts on the 'DETAILS' sub-logger, content
        sizes on the 'SIZE' sub-logger), as errors when the duration reaches the
        error threshold and as warnings otherwise.

        :param brok: the brok that was serialized.
        :param serialization_duration: serialization time, in seconds.
        :param serialized_brok_size: size of the serialized brok, in bytes.
        """
        oversize_data_logger = self.logger.get_sub_part(MANAGE_BROKS_LOGGER_SUB_PART, register=False).get_sub_part('OVERSIZED DATA')
        nb_elements_details_logger = oversize_data_logger.get_sub_part('DETAILS')
        elements_size_details_logger = oversize_data_logger.get_sub_part('SIZE')

        is_error = serialization_duration >= self._brok_serialization_time_error_threshold
        nb_elements_details_log = nb_elements_details_logger.error if is_error else nb_elements_details_logger.warning
        elements_size_details_log = elements_size_details_logger.error if is_error else elements_size_details_logger.warning

        brok_size_details = self._get_brok_size_details(brok)
        brok_type: str = f'"{brok.type}"'
        if brok_size_details.item_uuid:
            brok_type = f'{brok_type} (item uuid: {brok_size_details.item_uuid})'

        nb_elements_details_msg = elements_size_details_msg = f'The brok of type {brok_type} took too much time to be serialized [{serialization_duration:.3f}s] (with size {serialized_brok_size}B) and may cause Brok management slow down.'
        if brok_size_details.number_of_elements:
            counts = ', '.join(f'{label}:{count}' for count, label in brok_size_details.number_of_elements)
            nb_elements_details_msg = f'{nb_elements_details_msg} Detail of potential expensive content: {counts}'
        if brok_size_details.elements_size:
            sizes = ', '.join(f'{label}:{size}B' for size, label in brok_size_details.elements_size)
            elements_size_details_msg = f'{elements_size_details_msg} Size of potential expensive content: {sizes}'

        nb_elements_details_log(nb_elements_details_msg)
        elements_size_details_log(elements_size_details_msg)


    @staticmethod
    def _get_brok_size_details(brok: 'Brok') -> '_BrokSizeDetails':
        """Collect counts and sizes of the potentially expensive parts of a brok.

        Only broks whose data holds a non-empty ``instance_uuid`` (broks about a
        monitoring item) get details; for other broks both lists stay empty.
        Sizes come from ``sys.getsizeof``, which is shallow (contained objects are
        not followed). Both lists are sorted in descending order.

        :param brok: the brok to inspect.
        :return: the collected details.
        """
        item_uuid: str = brok.data.get('instance_uuid', '')
        brok_size_details = _BrokSizeDetails(item_uuid=item_uuid, number_of_elements=[], elements_size=[])

        if item_uuid:
            # Brok for a monitoring item
            element: 'dict[str, Any]' = brok.data

            notification_list: 'list[dict[str,Any]]' = element.get('notification_list', [])
            incident_nb = len(notification_list)
            notification_nb = sum(len(x.get('bloc_content', [])) for x in notification_list)
            brok_size_details.number_of_elements.append((incident_nb, 'incident nb'))
            brok_size_details.number_of_elements.append((notification_nb, 'total notifications nb'))

            outputs_size = sys.getsizeof(element.get('output', '')) + sys.getsizeof(element.get('long_output', ''))
            brok_size_details.elements_size.append((outputs_size, 'outputs size'))

            perf_data = sys.getsizeof(element.get('perf_data', ''))
            brok_size_details.elements_size.append((perf_data, 'current perf data size'))

            downtimes = element.get('downtimes', [])
            brok_size_details.number_of_elements.append((len(downtimes), 'downtimes nb'))
            downtimes_size = sum((sys.getsizeof(dt.author) + sys.getsizeof(dt.comment)) for dt in downtimes)
            brok_size_details.elements_size.append((downtimes_size, 'downtimes user content size'))

            acknowledgement_size = 0
            for ack in filter(None, [element.get('acknowledgement', None), element.get('partial_acknowledge', None)]):
                acknowledgement_size += sys.getsizeof(ack.author) + sys.getsizeof(ack.comment)
            brok_size_details.elements_size.append((acknowledgement_size, 'acknowledgement user content size'))

            nb_source_problems = len(element.get('source_problems', []))
            brok_size_details.number_of_elements.append((nb_source_problems, 'source problems nb'))

            svc_hst_distinct_lists_properties = ('parent_dependencies', 'child_dependencies', 'impacts')
            for prop in svc_hst_distinct_lists_properties:
                prop_label = prop.replace('_', ' ')
                hosts_and_services: 'dict[str, list[str]]' = element.get(prop, {})
                for item_class in ('hosts', 'services'):
                    brok_size_details.number_of_elements.append((len(hosts_and_services.get(item_class, [])), f'{prop_label} ({item_class}) nb'))

            if 'childs' in element:  # Only hosts have this property.
                nb_network_dep_children = len(element['childs'])
                brok_size_details.number_of_elements.append((nb_network_dep_children, 'network dependent children nb'))

        brok_size_details.number_of_elements.sort(reverse=True)
        brok_size_details.elements_size.sort(reverse=True)

        return brok_size_details


    def get_raw_stats(self):
        """Return the base raw stats plus the work-time stats of the get-broks and manage-broks threads.

        ``work_time_cumulative['total']`` combines both threads: the latest end
        time, the summed time taken and the number of managed broks. When
        ``compute_stats`` is set, ``work_time`` (summed averages) and
        ``work_time_sampling`` (per-slot sums, -1 when neither thread has a sample)
        are added.
        """
        data = super(BrokerModuleWorker, self).get_raw_stats()

        total_cumulative = (
            max(self._manage_work_time_cumulative[0], self._get_broks_work_time_cumulative[0]),
            self._manage_work_time_cumulative[1] + self._get_broks_work_time_cumulative[1],
            self._manage_work_time_cumulative[2]
        )
        data['work_time_cumulative'] = {
            'manage_work_time'   : self._manage_work_time_cumulative,
            'get_broks_work_time': self._get_broks_work_time_cumulative,
            'total'              : total_cumulative,
        }
        data['work_remaining_broks'] = self._manage_work_remaining_broks
        if not self.compute_stats:
            return data
        # Times
        manage_work_times = self._manage_work_time.get_sum_in_range(avg_on_time=True, with_range_size=True)
        get_broks_work_times = self._get_broks_work_time.get_sum_in_range(avg_on_time=True, with_range_size=True)
        # Beware: sum value, but do not sum the range size
        data['work_time'] = (manage_work_times[0] + get_broks_work_times[0], manage_work_times[1])

        # Sampling: sum both threads slot by slot, -1 meaning "no sample" must never be added to a value
        work_time_sampling = self._get_time_sampling_struct()
        for idx, sample in enumerate(self._manage_work_time_sampling):
            if sample != -1:
                work_time_sampling[idx] = sample
        for idx, sample in enumerate(self._get_broks_work_time_sampling):
            if sample == -1:
                continue
            if work_time_sampling[idx] == -1:
                work_time_sampling[idx] = sample
            else:
                work_time_sampling[idx] += sample
        data['work_time_sampling'] = work_time_sampling
        return data


    def in_main_process_tick(self):
        """Main process periodic hook: push the stacked broks so they are not delivered late.

        Presumably called about once a second by the daemon (per the original intent).
        """
        self._force_push_to_worker()


    def _do_in_worker_get_broks_thread(self):
        """Worker thread body: forever read broks from the main process into the manage queue.

        Any exception is handed to _send_exception_error_to_main_process_and_exit.
        """
        while True:
            try:
                self._do_in_worker_get_broks_thread_loop()
            except Exception as exp:
                self._send_exception_error_to_main_process_and_exit(traceback.format_exc(), exp, 'broks reading')


    # You can overload this method to have your own "manage queue"
    # duo with _dequeue_brok_from_manage_queue
    def _enqueue_brok_into_manage_queue(self, brok):
        # type: (Brok) -> None
        """Append a brok to the manage queue and wake up threads waiting on it."""
        with self._in_worker_to_manage_queue_lock:
            self._in_worker_to_manage_queue.append(brok)
            self._manage_work_remaining_broks = len(self._in_worker_to_manage_queue)
            self._in_worker_to_manage_queue_lock.notify_all()


    def _do_in_worker_get_broks_thread_loop(self):
        """Read at most one brok from the main process (waiting up to 1s), unpickle it and enqueue it."""
        # We get from the Queue update from the receiver.
        # NOTE: as we are blocking, we are not using CPU and so no need for sleep
        try:
            if self._main_process_to_worker_broks_queue.poll(timeout=1):
                serialized_brok = self._main_process_to_worker_broks_queue.recv()
            else:
                return
        except Empty:
            return

        start_time = time.time()
        t0 = start_time
        get_time = 0

        brok = SafeUnpickler.loads(serialized_brok, 'Brok get from main daemon')
        get_time += time.time() - t0

        t0 = time.time()
        self._enqueue_brok_into_manage_queue(brok)
        get_time += time.time() - t0
        self._get_broks_time_manage_broks += get_time

        # Compute stats
        end_time = time.time()
        time_taken = end_time - start_time
        self._get_broks_work_time_cumulative = (end_time, self._get_broks_work_time_cumulative[1] + time_taken, self._get_broks_work_time_cumulative[2] + 1)
        if self.compute_stats:
            self._get_broks_work_time.update_avg(time_taken)
            last_print_period = end_time - self._get_broks_work_time_last_print
            if last_print_period > SAMPLING_TIME:
                # Slide the sampling window: drop the oldest slot, append the sum over the elapsed period
                self._get_broks_work_time_sampling.pop(0)
                self._get_broks_work_time_sampling.append(self._get_broks_work_time.get_sum_in_range(avg_on_time=False, with_range_size=False, time_limit_overload=last_print_period))
                self._get_broks_work_time_last_print = end_time
        self._get_broks_time_compute_stats += time.time() - end_time


    def _do_in_worker_manage_broks_thread(self):
        """Worker thread body: once the worker init is done, forever manage broks from the manage queue.

        Any exception is handed to _send_exception_error_to_main_process_and_exit.
        """
        # Important: we cannot manage broks until the worker is fully init
        self._wait_for_worker_init_to_be_done()

        while True:
            try:
                self._do_in_worker_manage_broks_thread_loop()
            except Exception as exp:
                self._send_exception_error_to_main_process_and_exit(traceback.format_exc(), exp, 'broks managing')


    # NOTE: you can overload this method if you want another "manage queue".
    # duo with _enqueue_brok_into_manage_queue
    def _dequeue_brok_from_manage_queue(self):
        # type: () -> Optional[Brok]
        """Pop the oldest brok of the manage queue, waiting up to 1s when it is empty.

        :return: the brok, or None if the queue is still empty after the wait.
        """
        brok = None
        get_time = 0
        wait_time = 0
        t0 = time.time()
        with self._in_worker_to_manage_queue_lock:
            if len(self._in_worker_to_manage_queue) == 0:
                get_time = time.time() - t0
                t0 = time.time()
                self._in_worker_to_manage_queue_lock.wait(1.0)
                wait_time = time.time() - t0
                t0 = time.time()
            # Re-check: the wait may end on timeout without any new brok
            if len(self._in_worker_to_manage_queue) != 0:
                brok = self._in_worker_to_manage_queue.popleft()
                self._manage_work_remaining_broks = len(self._in_worker_to_manage_queue)
        get_time = get_time + time.time() - t0
        self._time_for_get_from_manage_queue += get_time
        self._time_waiting_for_manage_queue += wait_time
        return brok


    def _get_brok_from_manage_queue(self):
        """Return the next brok to manage, or None (delegates to _dequeue_brok_from_manage_queue)."""
        return self._dequeue_brok_from_manage_queue()


    def _do_in_worker_manage_broks_thread_loop(self):
        """Manage one brok: prepare it and call ``manage_<brok.type>_brok`` when defined, then update stats."""
        full_loop_start = time.time()
        brok = self._get_brok_from_manage_queue()
        if brok is None:
            self._time_loop_manage_brok += time.time() - full_loop_start
            return

        start_time = time.time()
        manage = getattr(self, 'manage_' + brok.type + '_brok', None)
        if manage:
            # Be sure the brok is prepared before call it
            brok.prepare()
            t0 = time.time()
            self._time_for_prepare_broks += t0 - start_time
            manage(brok)
            t1 = time.time()
            self._time_for_manage_broks += t1 - t0

        # Compute stats
        end_time = time.time()
        time_taken = end_time - start_time
        self._manage_work_time_cumulative = (end_time, self._manage_work_time_cumulative[1] + time_taken, self._manage_work_time_cumulative[2] + 1)
        if self.compute_stats:
            self._manage_work_time.update_avg(time_taken)
            last_print_period = end_time - self._manage_work_time_last_print
            if last_print_period > SAMPLING_TIME:
                # Slide the sampling window: drop the oldest slot, append the sum over the elapsed period
                self._manage_work_time_sampling.pop(0)
                self._manage_work_time_sampling.append(self._manage_work_time.get_sum_in_range(avg_on_time=False, with_range_size=False, time_limit_overload=last_print_period))
                self._manage_work_time_last_print = end_time

        self._time_for_compute_stats += time.time() - end_time
        self._time_loop_manage_brok += time.time() - full_loop_start


    def _do_in_worker_log_pending_broks_thread(self):
        """Worker thread body: log the number of pending broks every MANAGE_BROKS_LOG_LOOP_TIME until interrupted."""
        manage_broks_logger = self.logger.get_sub_part(MANAGE_BROKS_LOGGER_SUB_PART)
        while not self.interrupted:
            manage_broks_logger.info('%d brok%s waiting in worker queue' % (self._manage_work_remaining_broks, 's' if self._manage_work_remaining_broks > 1 else ''))
            self.interruptable_sleep(MANAGE_BROKS_LOG_LOOP_TIME)


    def start_worker_specific_treads(self):
        """In the worker process: create the manage-queue condition and start the broker threads."""
        # We are now in the worker process, we can create the locks
        self._in_worker_to_manage_queue_lock = threading.Condition(threading.RLock())

        # Thread reading the broks sent by the main process into the manage queue
        thr = threading.Thread(target=self._do_in_worker_get_broks_thread, name='WorkerGetBroks')
        thr.daemon = True
        thr.start()

        # Thread calling manage_<type>_brok for each brok of the manage queue
        thr = threading.Thread(target=self._do_in_worker_manage_broks_thread, name='WorkerManageBroks')
        thr.daemon = True
        thr.start()

        # Thread periodically logging the number of pending broks to manage
        thr = threading.Thread(target=self._do_in_worker_log_pending_broks_thread, name='WorkerLogBroksPending')
        thr.daemon = True
        thr.start()

        # The main-process-to-worker queue needs its receiver thread too
        self._main_process_to_worker_broks_queue.get_queues_size()  # note: this call will start the receiver thread


    def get_main_process_to_worker_broks_queue_size(self):
        """Return the size of the main-process-to-worker brok queue, or -1 (with a warning) if it cannot be read."""
        try:
            qsize = self._main_process_to_worker_broks_queue.qsize()
        except Exception as exp:
            self.logger.warning('Cannot get the size of the main process to workers queue: %s' % exp)
            qsize = -1
        return qsize
