#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2013-2019
# This file is part of Shinken Enterprise, all rights reserved.
import threading
import time
from datetime import timedelta, datetime

from ec_common import shared_data, STATE, ACK, DOWNTIME, STATE_TYPE, ITEM_TYPE
from ec_database_connection import ECDatabaseConnection, ECBulk
from ec_event import Event
from shinken.acknowledge import Acknowledge
from shinken.brokermoduleworker import BrokerModuleWorker
from shinken.check import NO_END_VALIDITY
from shinken.downtime import Downtime
from shinken.misc.type_hint import Dict, Union, List
from shinken.objects.itemsummary import HostSummary, CheckSummary
from shinken.thread_helper import LockWithTimer
from shinkensolutions.date_helper import get_datetime_with_local_time_zone, get_now
from shinkensolutions.lib_modules.configuration_reader import read_int_in_configuration

LOOP_SPEED = 1  # in sec
EVENT_BULK_SPEED = 1  # in sec
CHECK_MISSING_DATA_SPEED = 10  # in sec
MARGIN_BEFORE_MISSING_DATA = 30  # in sec
TIMEDELTA_MARGIN_BEFORE_MISSING_DATA = timedelta(seconds=MARGIN_BEFORE_MISSING_DATA)

SHINKEN_INACTIVE_OUTPUT = 'shinken inactive'
MISSING_DATA_OUTPUT = 'missing data'

_cache_write_downtime = set()
_cache_write_acknowledge = set()


class StateCache(object):
    comprehensive_state = 0
    state_type = 0
    expiration_time = None  # type: datetime
    event_uuid = 0
    
    
    def __init__(self, comprehensive_state, state_type, expiration_time, event_uuid):
        self.comprehensive_state = comprehensive_state
        self.state_type = state_type
        self.expiration_time = expiration_time
        self.event_uuid = event_uuid
    
    
    def update_expiration_time(self, at):
        self.expiration_time = at
    
    
    def update_state_type(self, state_type):
        self.state_type = state_type
    
    
    @classmethod
    def from_database_entry(cls, _to_load):
        # type: (dict) -> StateCache
        return StateCache(comprehensive_state=_to_load['comprehensive_state'], state_type=_to_load['state_type'], expiration_time=None, event_uuid=_to_load['event_uuid'])
    
    
    def to_database_entry(self):
        return {
            'comprehensive_state': self.comprehensive_state,
            'state_type'         : self.state_type,
            'event_uuid'         : self.event_uuid,
        }


class BrokHandlerModuleWorker(BrokerModuleWorker):
    database_connection = None  # type: ECDatabaseConnection
    _state_cache = {}  # type: Dict[str, StateCache]
    _event_bulk = None  # type: ECBulk
    _state_cache_update = set()
    _see_this_boot = set()
    _state_cache_lock = threading.RLock()
    _last_check_missing_data = -1
    _last_event_bulk_execute = -1
    _write_event_cumulative = (0.0, 0)
    
    
    def init_worker(self, configuration):
        global MARGIN_BEFORE_MISSING_DATA, CHECK_MISSING_DATA_SPEED, TIMEDELTA_MARGIN_BEFORE_MISSING_DATA
        if self.logger.is_debug():
            self._state_cache_lock = LockWithTimer(self._state_cache_lock, my_logger=self.logger, lock_name='state_cache_lock')
        
        self.database_connection = ECDatabaseConnection(configuration, self.logger)
        self.database_connection.init('%s %s' % (self.father_name, self._get_worker_display_name()))
        self._load_state_cache()
        self._last_check_missing_data = -1
        self._event_bulk = self.database_connection.event_bulk_factory()
        self._write_event_cumulative = (0.0, 0)
        
        # We do not use stats from our parent
        self.compute_stats = False
        CHECK_MISSING_DATA_SPEED = read_int_in_configuration(configuration, 'check_missing_data_speed', CHECK_MISSING_DATA_SPEED)
        MARGIN_BEFORE_MISSING_DATA = read_int_in_configuration(configuration, 'margin_before_missing_data', MARGIN_BEFORE_MISSING_DATA)
        TIMEDELTA_MARGIN_BEFORE_MISSING_DATA = timedelta(seconds=MARGIN_BEFORE_MISSING_DATA)
        
        self.logger.debug('worker start at %s' % self.logger.format_datetime(get_datetime_with_local_time_zone()))
    
    
    def on_init(self):
        pass
    
    
    def manage_service_check_result_brok(self, new_info):
        self._handle_brok(new_info)
    
    
    def manage_host_check_result_brok(self, new_info):
        self._handle_brok(new_info)
    
    
    def manage_update_service_status_brok(self, new_info):
        self._handle_brok(new_info)
    
    
    def manage_update_host_status_brok(self, new_info):
        self._handle_brok(new_info)
    
    
    def manage_initial_service_status_brok(self, new_info):
        self._handle_brok(new_info)
    
    
    def manage_initial_host_status_brok(self, new_info):
        self._handle_brok(new_info)
    
    
    def worker_main(self):
        while not self.interrupted:
            self._tick()
            self.interruptable_sleep(LOOP_SPEED)
    
    
    def get_raw_stats(self):
        data = super(BrokHandlerModuleWorker, self).get_raw_stats()
        data['write_event_cumulative'] = self._write_event_cumulative
        return data
    
    
    def _do_in_worker_manage_broks_thread_loop(self):
        super(BrokHandlerModuleWorker, self)._do_in_worker_manage_broks_thread_loop()
        now = get_now()
        if now - self._last_check_missing_data > CHECK_MISSING_DATA_SPEED:
            self._check_missing_data()
            self._last_check_missing_data = now
    
    
    @staticmethod
    def _is_state_expired(expiration_time):
        if not expiration_time or expiration_time == NO_END_VALIDITY:
            return False
        return expiration_time + TIMEDELTA_MARGIN_BEFORE_MISSING_DATA < get_datetime_with_local_time_zone()
    
    
    def _check_missing_data(self):
        for inventory_item in self.get_all_inventory():
            cache_value = self._state_cache.get(inventory_item.get_uuid(), None)  # type: StateCache
            # logger.debug('[%s] [check_missing_data] of %s : expiration_time:[%s]' % (
            #     inventory_item.get_uuid(),
            #     inventory_item.get_instance_name(),
            #     logger.format_datetime(cache_value.expiration_time) if cache_value else 'no cache_value'
            #
            # ))
            if not cache_value or not self._is_state_expired(cache_value.expiration_time):
                continue
            
            self.logger.debug('[%s] new missing_data cache_value.expiration_time:[%s] at [%s]' % (
                inventory_item.get_uuid(),
                self.logger.format_datetime(cache_value.expiration_time),
                self.logger.format_datetime(get_datetime_with_local_time_zone()),
            ))
            missing_data_event = self._build_missing_data_event(inventory_item, cache_value.expiration_time)
            self._put_event(missing_data_event)
    
    
    def _handle_brok(self, brok):
        brok_data = brok.data
        # logger.debug('[%s][%s-%s] new [%s] brok at [%s]' % (brok_data['instance_uuid'], brok_data['host_name'], brok_data.get('service_description', ''), brok.type, logger.format_time(brok_data['creation_date'])))
        event = Event.from_brok_data(brok)
        self._put_event(event)
        
        downtime = brok_data.get('in_scheduled_downtime', None)
        active_downtime_uuids = brok_data.get('active_downtime_uuids', None)
        if downtime:
            downtimes = brok_data.get('downtimes', None)
            self._create_downtime(active_downtime_uuids, downtimes)
        
        acknowledgement = brok_data.get('acknowledgement', None)
        # logger.debug('[%s][%s-%s] new [%s] brok [%s]/[%s]' % (
        #     brok_data['instance_uuid'],
        #     brok_data['host_name'],
        #     brok_data.get('service_description', ''),
        #     brok.type, brok_data.get('acknowledgement', None),
        #     brok_data.get('partial_acknowledge', None)))
        if acknowledgement:
            self._create_acknowledge(acknowledgement, event.item_uuid)
        
        partial_acknowledge = brok_data.get('partial_acknowledge', None)
        if partial_acknowledge:
            self._create_acknowledge(partial_acknowledge, event.item_uuid)
    
    
    def _create_downtime(self, uuid_downtimes, downtimes):
        # type: (List[basestring], List[Downtime]) -> None
        if not uuid_downtimes:
            return
        for uuid_downtime in uuid_downtimes[:]:
            if uuid_downtime in _cache_write_downtime:
                continue
            
            prev_downtime = self.database_connection.find_downtime(uuid_downtime)
            if prev_downtime:
                _cache_write_downtime.add(uuid_downtime)
                continue
            
            downtime = next((i for i in downtimes if i.uuid == uuid_downtime), None) if downtimes else None
            if not downtime:
                # Inherited downtime aren't in broks so we expect to have the downtime definition in the host brok
                continue
            
            downtime_entry = {
                '_id'    : uuid_downtime,
                'author' : downtime.author,
                'comment': downtime.comment,
            }
            for k in downtime.__class__.properties:
                downtime_entry[k] = getattr(downtime, k)
            _cache_write_downtime.add(uuid_downtime)
            self.database_connection.save_downtime(downtime_entry)
            self.logger.debug('Saving new downtime: %s' % downtime_entry)
    
    
    def _create_acknowledge(self, acknowledge, item_uuid):
        # type: (Acknowledge, basestring) -> None
        acknowledge_id = acknowledge.id
        if acknowledge_id in _cache_write_acknowledge:
            return
        prev_acknowledge = self.database_connection.find_acknowledge(acknowledge_id)
        if prev_acknowledge is None:
            # ok there is not such acknowledge before, save it
            acknowledge_entry = {'_id': acknowledge_id, 'item_uuid': item_uuid}
            for k in acknowledge.__class__.properties:
                acknowledge_entry[k] = getattr(acknowledge, k)
            self.database_connection.save_acknowledge(acknowledge_entry)
            self.logger.debug('Saving new acknowledge: %s' % acknowledge_entry)
        _cache_write_acknowledge.add(acknowledge.id)
    
    
    def _put_event(self, event):
        # type: (Event) -> None
        start_time = time.time()
        if self._is_state_expired(event.get_state_expiration_time()):
            # logger.debug('[%s] put_event reject event because the state is expire. Expire at [%s]' % (event.item_uuid, logger.format_datetime(event.get_state_expiration_time() + TIMEDELTA_MARGIN_BEFORE_MISSING_DATA)))
            return
        
        cache_value = self._state_cache.get(event.item_uuid, None)  # type: StateCache
        event_comprehensive_state = event.get_comprehensive_state()
        have_state_in_cache = cache_value and cache_value.comprehensive_state == event_comprehensive_state
        if have_state_in_cache:
            # logger.debug('[%s] put_event event in cache comprehensive_state:[%s]' % (event.item_uuid, event_comprehensive_state))
            if event.state_type != cache_value.state_type:
                with self._state_cache_lock:
                    # logger.debug('[%s] must update state_type of event:[%s] to [%s]' % (event.item_uuid, cache_value.event_uuid, event.state_type))
                    self._event_bulk.update_event(cache_value.event_uuid, event.state_type, event.event_hard_since)
                    self._state_cache[event.item_uuid].update_state_type(event.state_type)
                    self._state_cache_update.add(event.item_uuid)
                    self._write_event_cumulative = (start_time, self._write_event_cumulative[1] + 1)
            
            cache_value.update_expiration_time(event.get_state_expiration_time())
            # logger.debug('[%s] put_event update_expiration_time to :[%s]' % (event.item_uuid, logger.format_datetime(event.get_state_expiration_time())))
        else:
            with self._state_cache_lock:
                # logger.debug('[%s] put_event event not in cache comprehensive_state:[%s] expiration_time:[%s]' % (event.item_uuid, event_comprehensive_state, logger.format_datetime(event.get_state_expiration_time())))
                self._event_bulk.insert_event(event)
                self._state_cache[event.item_uuid] = StateCache(event_comprehensive_state, event.state_type, event.get_state_expiration_time(), event.event_uuid)
                self._state_cache_update.add(event.item_uuid)
                self._write_event_cumulative = (start_time, self._write_event_cumulative[1] + 1)
    
    
    def _tick(self):
        now = get_now()
        with self._state_cache_lock:
            if now - self._last_event_bulk_execute > EVENT_BULK_SPEED:
                self._event_bulk.bulks_execute()
                self._last_event_bulk_execute = now
            
            if self._state_cache_update:
                to_update = [{'_id': _id, 'state_cache': self._state_cache[_id].to_database_entry()} for _id in self._state_cache_update]
                self._state_cache_update = set()
                self.database_connection.update_state_cache(to_update)
    
    
    def _load_state_cache(self):
        raw_state_caches = self.database_connection.load_state_cache()
        for raw_state_cache in raw_state_caches:
            self._state_cache[raw_state_cache['_id']] = StateCache.from_database_entry(raw_state_cache['state_cache'])
    
    
    def add_shinken_inactive_event(self, item):
        item_uuid = item.get_uuid()
        if item_uuid in self._see_this_boot:
            return
        shinken_inactive_event = self._build_shinken_inactive_event(item)
        self._put_event(shinken_inactive_event)
        self._see_this_boot.add(item_uuid)
    
    
    @staticmethod
    def _build_shinken_inactive_event(item):
        # type: (Union[HostSummary, CheckSummary]) -> Event
        event_since = get_datetime_with_local_time_zone(shared_data.get_last_start_time())
        shinken_inactive_event = BrokHandlerModuleWorker._build_artificial_event(item, event_since, STATE.SHINKEN_INACTIVE, SHINKEN_INACTIVE_OUTPUT)
        return shinken_inactive_event
    
    
    @staticmethod
    def _build_missing_data_event(item, event_since):
        # type: (Union[HostSummary, CheckSummary], datetime) -> Event
        missing_data_event = BrokHandlerModuleWorker._build_artificial_event(item, event_since, STATE.MISSING_DATA, MISSING_DATA_OUTPUT)
        return missing_data_event
    
    
    @staticmethod
    def _build_artificial_event(item, event_since, state_id, output):
        missing_data_event = Event(
            item_uuid=item.get_uuid(),
            event_since=event_since,
            event_hard_since=event_since,
            item_type=item.get_type(),
            state_id=state_id,
            state_type=STATE_TYPE.HARD,
            state_validity_period=NO_END_VALIDITY,
            flapping=False,
            acknowledged=ACK.NONE,
            downtime=DOWNTIME.NONE,
            partial_flapping=False,
            partial_ack=False,
            partial_dt=False,
            active_downtime_uuids=[],
            output=output,
            long_output='',
            realm=item.get_realm(),
            name=item.get_full_name(),
            host_name=item.get_host_name(),
            check_name=item.get_name() if item.get_type() == ITEM_TYPE.CHECK else '',
        )
        return missing_data_event
    
    
    def callback__a_new_host_added(self, host_uuid):
        
        if not shared_data.get_shinken_inactive_period():
            return
        
        # Unknown element don't have shinken inactive it must be there first time
        if not self.database_connection.is_in_state_cache(host_uuid):
            return
        
        host = self.get_host_from_uuid(host_uuid)
        self.add_shinken_inactive_event(host)
        for check in host.get_checks().itervalues():
            self.add_shinken_inactive_event(check)
