#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2013-2019
# This file is part of Shinken Enterprise, all rights reserved.
import uuid

import sla_common
from shinken.check import NO_END_VALIDITY
from shinken.log import PART_INITIALISATION
from shinken.misc.type_hint import List, Dict
from shinken.objects.module import Module as ShinkenModuleDefinition
from shinkensolutions.date_helper import DATE_COMPARE, date_now, get_start_of_day, get_end_of_day, compare_date, get_now, Date
from shinkensolutions.lib_modules.configuration_reader import read_int_in_configuration
from sla_abstract_component import AbstractComponent
from sla_common import STATUS, LIST_STATUS, RAW_SLA_KEY, shared_data
from sla_component_manager import ComponentManager
from sla_compute_percent_sla import ComputePercentSla
from sla_database import SLADatabase
from sla_info import SLAInfo

# Default gap tolerated between two Shinken activity ranges before the gap is
# reported as a SHINKEN_INACTIVE range (unit presumably seconds - TODO confirm).
# SLAArchive.__init__ overwrites this module-level value with the
# 'time_before_shinken_inactive' configuration parameter.
MARGIN_SLA_INACTIVE = 30


class SLAArchive(AbstractComponent):
    
    def __init__(self, conf, component_manager, sla_info, compute_percent_sla, sla_database):
        # type: (ShinkenModuleDefinition, ComponentManager, SLAInfo, ComputePercentSla, SLADatabase) -> None
        """Keep references to the collaborating components and load the inactivity margin.

        The 'time_before_shinken_inactive' parameter read from ``conf`` replaces the
        module-level MARGIN_SLA_INACTIVE (its previous value is used as default), so the
        setting is shared by every helper of this module.
        """
        global MARGIN_SLA_INACTIVE
        # The parent constructor runs first; self.logger is expected to come from it.
        super(SLAArchive, self).__init__(conf, component_manager)
        self.sla_database = sla_database
        self.compute_percent_sla = compute_percent_sla
        self.sla_info = sla_info

        configured_margin = read_int_in_configuration(conf, 'time_before_shinken_inactive', MARGIN_SLA_INACTIVE)
        MARGIN_SLA_INACTIVE = configured_margin

        startup_messages = (
            'Parameter load for build sla archive',
            '   - time_before_shinken_inactive:[%s]' % configured_margin,
        )
        for message in startup_messages:
            self.logger.info(PART_INITIALISATION, message)
    
    
    def init(self):
        """Do nothing: this component has no initialisation step (hook presumably required by AbstractComponent)."""
    
    
    def tick(self):
        """Do nothing: this component has no periodic work (hook presumably required by AbstractComponent)."""
    
    
    def build_archive_for_missing_day(self, date, item_uuid, start_of_range=None, end_of_day=None, sla_thresholds=None, only_existing_day=False):
        """Build a placeholder archive made of a single SHINKEN_INACTIVE range.

        :param date: day to archive; its ``yday`` and ``year`` attributes are stored in the archive.
        :param item_uuid: uuid of the archived item.
        :param start_of_range: optional lower bound; used instead of the start of the day
            when it is neither None nor -1 and is later than the start of the day.
        :param end_of_day: end of the range, defaults to ``get_end_of_day(date)``.
        :param sla_thresholds: thresholds stored in the archive, defaults to ``(99, 97)``.
        :param only_existing_day: when true, return None for a day flagged as in the future
            or ending before the item monitoring start time.
        :return: the archive dict (marked ``'missing': True``) or None.
        """
        today = date_now()
        monitoring_start = self.sla_info.get_monitoring_start_time(item_uuid)

        range_start = get_start_of_day(date)
        has_explicit_start = start_of_range is not None and start_of_range != -1
        if has_explicit_start and start_of_range > range_start:
            range_start = int(start_of_range)

        if end_of_day is None:
            end_of_day = get_end_of_day(date)
        if sla_thresholds is None:
            sla_thresholds = (99, 97)

        # Flag names mirror the keys written in the archive below.
        day_in_future = compare_date(today, date) == DATE_COMPARE.IS_AFTER
        day_before_monitoring = end_of_day < monitoring_start

        if only_existing_day and (day_in_future or day_before_monitoring):
            return None

        # Monitoring started during the day: the range begins at that moment.
        if range_start < monitoring_start < end_of_day:
            range_start = monitoring_start

        archive = {
            '_id'             : uuid.uuid4().hex,
            'build_at'        : get_now(),
            'uuid'            : item_uuid,
            'yday'            : date.yday,
            'year'            : date.year,
            'ranges'          : [{'rc': STATUS.SHINKEN_INACTIVE, 'end': end_of_day, 'start': range_start, 'ack': False, 'dt': False, 'flg': False}],
            'version'         : sla_common.CURRENT_ARCHIVE_VERSION,
            'history_inactive': end_of_day - range_start,
            'history_total'   : end_of_day - range_start,
            'thresholds'      : sla_thresholds,
            'missing'         : True,
        }

        self.compute_percent_sla.compute_sla(archive)
        # Copy every non-zero 'sla_<status>' value to its 'archive_<status>' counterpart.
        for status_name in LIST_STATUS:
            status_value = archive.get('sla_%s' % status_name, 0)
            if status_value:
                archive['archive_%s' % status_name] = status_value
        archive['archive_total'] = archive['sla_total']
        archive['archive_format'] = archive['sla_format']
        archive['total_ranges'] = len(archive['ranges'])

        # 'in_future' takes precedence over 'in_past'.
        if day_in_future:
            marker = 'in_future'
        elif day_before_monitoring:
            marker = 'in_past'
        else:
            marker = None
        if marker is not None:
            archive[marker] = True

        return archive
    
    
    def build_archive(self, date, item_uuid, end_of_day=None, inactive_ranges=None):
        # type: (Date, str, int, List) -> Dict
        """Build the archive of ``item_uuid`` for the day ``date``.

        An archive already stored for that item and day is removed first when
        ``shared_data.get_already_archived()`` is true. When no range can be built, the
        result of ``build_archive_for_missing_day`` is returned instead.

        :param date: day to archive.
        :param item_uuid: uuid of the archived item.
        :param end_of_day: end of the archived period, defaults to ``get_end_of_day(date)``.
        :param inactive_ranges: optional inactive ranges forwarded to ``_build_range``.
        :return: the archive dict.
        """
        database = self.sla_database
        if shared_data.get_already_archived():
            if database.find_archive(item_uuid, date, lookup={'_id': 1}):
                database.remove_archive(item_uuid, date)

        archive = {'uuid': item_uuid, 'yday': date.yday, 'year': date.year}

        if end_of_day is None:
            end_of_day = get_end_of_day(date)

        monitoring_start = self.sla_info.get_monitoring_start_time(item_uuid)
        thresholds = self.sla_info.get_sla_thresholds(item_uuid)
        day_ranges = self._build_range(date, monitoring_start, end_of_day, item_uuid, inactive_ranges=inactive_ranges)
        archive['ranges'] = day_ranges
        if not day_ranges:
            return self.build_archive_for_missing_day(date, item_uuid, start_of_range=monitoring_start, end_of_day=end_of_day, sla_thresholds=thresholds)

        percent_computer = self.compute_percent_sla
        percent_computer.compute_state_sum(day_ranges, archive)
        percent_computer.compute_sla(archive)
        # Copy every non-zero 'sla_<status>' value to its 'archive_<status>' counterpart.
        for status_name in LIST_STATUS:
            status_value = archive.get('sla_%s' % status_name, 0)
            if status_value:
                archive['archive_%s' % status_name] = status_value

        archive['archive_total'] = archive['sla_total']
        archive['archive_format'] = archive['sla_format']
        archive['thresholds'] = thresholds
        archive['total_ranges'] = len(day_ranges)
        archive['version'] = sla_common.CURRENT_ARCHIVE_VERSION
        archive['build_at'] = get_now()
        archive['_id'] = uuid.uuid4().hex
        return archive
    
    
    def compute_inactive_ranges(self, date, start_of_day, end_of_range):
        """List the periods of [start_of_day, end_of_range] not covered by a Shinken activity range.

        The activity ranges come from ``sla_database.find_raw_sla_status(date)``. A gap is
        reported only when it is longer than MARGIN_SLA_INACTIVE. Without any stored
        status, the whole period is returned as one inactive range.

        :return: list of dicts flagged ``'active_range': False`` with a SHINKEN_INACTIVE status.
        """

        def _inactive(range_start, range_end):
            return {
                'active_range'                : False,
                RAW_SLA_KEY.CONTEXT_AND_STATUS: STATUS.SHINKEN_INACTIVE * 100,
                RAW_SLA_KEY.START             : range_start,
                RAW_SLA_KEY.END               : range_end,
            }

        stored_status = self.sla_database.find_raw_sla_status(date)
        if not stored_status:
            return [_inactive(start_of_day, end_of_range)]

        active_ranges = sorted(stored_status['active_ranges'], key=lambda _value: _value['start'])

        gaps = []
        covered_until = start_of_day
        for active in active_ranges:
            start = active['start']
            end = active['end']
            if start > end:
                self.logger.warning('An inconsistency in shinken activity was found. Check your ntp server.')
                self.logger.warning('Inconsistency : start:[%s] - end:[%s]' % (self.logger.format_time(start), self.logger.format_time(end)))
                continue

            if start - covered_until > MARGIN_SLA_INACTIVE:
                # Defensive clamp: covered_until starts at start_of_day and only grows.
                if covered_until < start_of_day:
                    covered_until = start_of_day
                gaps.append(_inactive(covered_until, start))
            if end > covered_until:
                covered_until = end

        # Gap between the last activity range and the end of the period.
        if end_of_range - covered_until > MARGIN_SLA_INACTIVE:
            if covered_until < start_of_day:
                covered_until = start_of_day
            gaps.append(_inactive(covered_until, end_of_range))
        return gaps
    
    
    def _build_range(self, date, start_of_range, end_of_range, item_uuid, inactive_ranges=None):
        # type: (Date, int, int, str, List) -> List
        """Build the archive ranges of ``item_uuid`` for ``date``.

        The raw sla entries of the item and the Shinken inactive ranges are normalised
        with ``_prepare_ranges``, merged with ``_mix_ranges``, cleaned of short missing
        data and converted with ``_raw_range_to_range``.

        :param start_of_range: used as the start of the period when it is neither None
            nor -1 and is later than the start of the day.
        :param inactive_ranges: inactive ranges to use; computed with
            ``compute_inactive_ranges`` when empty or None.
        :return: the ranges in reverse order of the merged timeline.
        """
        period_start = get_start_of_day(date)
        has_explicit_start = start_of_range is not None and start_of_range != -1
        if has_explicit_start and start_of_range > period_start:
            period_start = int(start_of_range)

        if not inactive_ranges:
            inactive_ranges = self.compute_inactive_ranges(date, period_start, end_of_range)

        raw_entries = self.sla_database.find_raw_sla(date, item_uuid)

        # A hole must exceed twice the inactivity margin to become a missing data range.
        gap_margin = 2 * MARGIN_SLA_INACTIVE
        prepared_inactive = SLAArchive._prepare_ranges('inactive_ranges', inactive_ranges, gap_margin, period_start, end_of_range)
        prepared_raw = SLAArchive._prepare_ranges('raw_sla', raw_entries, gap_margin, period_start, end_of_range)

        timeline = SLAArchive._mix_ranges(prepared_inactive, prepared_raw)
        timeline = SLAArchive._remove_small_missing_data(timeline)

        converted = [SLAArchive._raw_range_to_range(raw_range) for raw_range in timeline]
        return converted[::-1]
    
    
    @staticmethod
    def _mix_ranges(ranges1, ranges2):
        """Merge two lists of prepared ranges into a single timeline.

        Every range must carry a 'name' entry identifying its list. Ranges are walked by
        start time; when the latest ranges of both lists collide, ``_mix_state`` chooses
        the one whose state is kept. Successive ranges with the same context and status
        are fused into one.

        :return: the merged ranges (copies of the inputs), or the other list unchanged
            when one of the two lists is empty.
        """
        if not ranges1:
            return ranges2
        if not ranges2:
            return ranges1

        first_name = ranges1[0]['name']
        second_name = ranges2[0]['name']

        timeline = list(ranges1)
        timeline.extend(ranges2)
        timeline.sort(key=lambda _value: _value[RAW_SLA_KEY.START])

        merged = []
        open_range = None
        latest_first = None
        latest_second = None
        for entry in timeline:
            entry_name = entry['name']
            if entry_name == first_name:
                latest_first = entry
            if entry_name == second_name:
                latest_second = entry

            if open_range is None:
                open_range = entry.copy()
                merged.append(open_range)
                continue

            if SLAArchive._range_collide(latest_first, latest_second):
                winner = SLAArchive._mix_state(latest_first, latest_second)
            else:
                winner = entry.copy()

            # Only the state of the winner is used, never its boundaries.
            winner.pop(RAW_SLA_KEY.START)
            winner.pop(RAW_SLA_KEY.END)

            entry_start = entry[RAW_SLA_KEY.START]
            if open_range[RAW_SLA_KEY.CONTEXT_AND_STATUS] != winner[RAW_SLA_KEY.CONTEXT_AND_STATUS]:
                if open_range[RAW_SLA_KEY.START] == entry_start:
                    # Same start: the open range simply takes the winning state.
                    open_range.update(winner)
                else:
                    # Close the open range where the new one begins.
                    open_range[RAW_SLA_KEY.END] = entry_start
                    open_range = entry.copy()
                    open_range.update(winner)
                    merged.append(open_range)

            entry_end = entry[RAW_SLA_KEY.END]
            if open_range[RAW_SLA_KEY.END] < entry_end:
                open_range[RAW_SLA_KEY.END] = entry_end
        return merged
    
    
    @staticmethod
    def _remove_small_missing_data(ranges):
        """Drop the MISSING_DATA ranges shorter than MARGIN_SLA_INACTIVE.

        The start of the range following a dropped one is moved back to the start of the
        dropped range; this edits the dicts of ``ranges`` in place.

        :return: a new list with the kept ranges.
        """
        kept = []
        last_index = len(ranges) - 1
        for index, candidate in enumerate(ranges):
            length = candidate[RAW_SLA_KEY.END] - candidate[RAW_SLA_KEY.START]
            is_missing = candidate[RAW_SLA_KEY.CONTEXT_AND_STATUS] == 100 * STATUS.MISSING_DATA
            if not (is_missing and length < MARGIN_SLA_INACTIVE):
                kept.append(candidate)
                continue
            if index < last_index:
                ranges[index + 1][RAW_SLA_KEY.START] = candidate[RAW_SLA_KEY.START]

        return kept
    
    
    @staticmethod
    def _range_collide(current_range1, current_range2):
        """Tell whether the start of one range lies inside the other range (bounds included).

        :return: the falsy argument when one of the ranges is missing, else a bool.
        """
        if not current_range1:
            return current_range1
        if not current_range2:
            return current_range2

        start1 = current_range1[RAW_SLA_KEY.START]
        start2 = current_range2[RAW_SLA_KEY.START]
        if start1 <= start2 and start2 <= current_range1[RAW_SLA_KEY.END]:
            return True
        return start2 <= start1 <= current_range2[RAW_SLA_KEY.END]
    
    
    @staticmethod
    def _mix_state(current_range1, current_range2):
        """Return a copy of the range whose status has the highest priority.

        The status is the hundreds digit of the context-and-status code. It is extracted
        with floor division so the result stays an integer whatever the division
        semantics in use (plain ``/`` only works with Python 2 integer division).

        :return: a copy of ``current_range2`` when the status of ``current_range1`` is the
            lowest of the two (or both are equal), else a copy of ``current_range1``.
        """
        status1 = current_range1[RAW_SLA_KEY.CONTEXT_AND_STATUS] // 100 % 10
        status2 = current_range2[RAW_SLA_KEY.CONTEXT_AND_STATUS] // 100 % 10
        # from low to high
        state_priorities = (STATUS.MISSING_DATA, STATUS.UNKNOWN, STATUS.OK, STATUS.WARN, STATUS.CRIT, STATUS.SHINKEN_INACTIVE)
        for state_priority in state_priorities:
            # The first range found holding the lowest status loses.
            if status1 == state_priority:
                return current_range2.copy()
            if status2 == state_priority:
                return current_range1.copy()

        # Neither status is a known one: keep the first range.
        return current_range1.copy()
    
    
    @staticmethod
    def _prepare_ranges(name, values, margin_new_range, start_of_range, end_of_range):
        # type: (str, List, int, int, int) -> List
        """Turn raw entries into a list of named ranges covering the requested period.

        Entries are sorted by start. Successive entries with the same context and status
        are fused, the first range starts at ``start_of_range`` and a MISSING_DATA range is
        inserted for each hole longer than ``margin_new_range``. Without any entry, a
        single MISSING_DATA range covering the whole period is returned.

        NOTE(review): when the loop stops because a range would start after
        ``end_of_range``, the final end clamp is applied to that unstored range and not
        to the last stored one - confirm this is intended.

        :param name: value stored under 'name' in every built range.
        :param values: raw entries holding either a context-and-status code or the
            ack / dt / status (/ flapping) fields it is computed from.
        :return: the list of built ranges.
        """

        def _new_range(status_code, range_start, range_end, source=None):
            # Ranges built from an entry also carry its acknowledge / output / downtime data.
            new_range = {
                'name'                        : name,
                RAW_SLA_KEY.CONTEXT_AND_STATUS: status_code,
                RAW_SLA_KEY.START             : range_start,
                RAW_SLA_KEY.END               : range_end,
            }
            if source is not None:
                new_range[RAW_SLA_KEY.ACK_UUID] = source.get(RAW_SLA_KEY.ACK_UUID, '')
                new_range[RAW_SLA_KEY.OUTPUT] = source.get(RAW_SLA_KEY.OUTPUT, None)
                new_range[RAW_SLA_KEY.LONG_OUTPUT] = source.get(RAW_SLA_KEY.LONG_OUTPUT, None)
                new_range[RAW_SLA_KEY.DOWNTIMES_UUID] = source.get(RAW_SLA_KEY.DOWNTIMES_UUID, [])
            return new_range

        built = []
        current = None
        for value in sorted(values, key=lambda _value: _value[RAW_SLA_KEY.START]):
            if RAW_SLA_KEY.CONTEXT_AND_STATUS in value:
                status_code = value[RAW_SLA_KEY.CONTEXT_AND_STATUS]
            else:
                status_code = value[RAW_SLA_KEY.ACK] + 10 * value[RAW_SLA_KEY.DT] + 100 * value[RAW_SLA_KEY.STATUS] + 1000 * value.get(RAW_SLA_KEY.FLAPPING, 0)
            start = value[RAW_SLA_KEY.START]
            end = value[RAW_SLA_KEY.END]

            previous_end = current[RAW_SLA_KEY.END] if current is not None else start_of_range
            if previous_end == NO_END_VALIDITY:
                previous_end = start

            # Hole before this entry: fill it with a missing data range.
            if start - previous_end > margin_new_range:
                current = _new_range(100 * STATUS.MISSING_DATA, previous_end, start)
                if previous_end > end_of_range:
                    break
                built.append(current)

            if current is None:
                # Very first range: it always begins at the start of the period.
                current = _new_range(status_code, start_of_range, end, value)
                if start_of_range > end_of_range:
                    break
                built.append(current)
            elif current[RAW_SLA_KEY.CONTEXT_AND_STATUS] != status_code:
                # State change: close the current range and open a new one.
                current[RAW_SLA_KEY.END] = start
                current = _new_range(status_code, start, end, value)
                if start > end_of_range:
                    break
                built.append(current)
            else:
                current[RAW_SLA_KEY.END] = end

        if not current:
            built.append(_new_range(100 * STATUS.MISSING_DATA, start_of_range, end_of_range))
            return built

        last_end = current[RAW_SLA_KEY.END]
        if last_end != NO_END_VALIDITY and end_of_range - last_end > margin_new_range:
            current = _new_range(100 * STATUS.MISSING_DATA, last_end, end_of_range)
            built.append(current)

        current[RAW_SLA_KEY.END] = end_of_range
        return built
    
    
    @staticmethod
    def _raw_range_to_range(raw_range):
        """Convert a raw range into the range format stored in an archive.

        The context-and-status code is read digit by digit (units: ack, tens: dt,
        hundreds: rc, thousands: flg, then p_flg, p_ack and p_dt). Floor division is used
        so each digit stays an exact integer whatever the division semantics in use;
        with true division the flags would be true for any non-zero lower digit.

        :param raw_range: dict holding the code, the start and the end of the range and,
            optionally, the acknowledge uuid, the downtime uuids and the outputs.
        :return: the archive range; 'downtimes_uuid', 'output' and 'long_output' are only
            present when the raw range provides a value for them.
        """
        cache_value = raw_range[RAW_SLA_KEY.CONTEXT_AND_STATUS]
        ack = cache_value % 10
        downtimes_uuid = raw_range.get(RAW_SLA_KEY.DOWNTIMES_UUID, [])
        output = raw_range.get(RAW_SLA_KEY.OUTPUT, None)
        long_output = raw_range.get(RAW_SLA_KEY.LONG_OUTPUT, None)

        _range = {
            'ack'     : ack,
            'dt'      : cache_value // 10 % 10,
            'rc'      : cache_value // 100 % 10,
            'flg'     : bool(cache_value // 1000 % 10),
            'p_flg'   : bool(cache_value // 10000 % 10),
            'p_ack'   : bool(cache_value // 100000 % 10),
            'p_dt'    : bool(cache_value // 1000000 % 10),
            # The acknowledge uuid is only kept for an acknowledged range.
            'ack_uuid': raw_range.get(RAW_SLA_KEY.ACK_UUID, '') if ack else '',
            'start'   : raw_range[RAW_SLA_KEY.START],
            'end'     : raw_range[RAW_SLA_KEY.END]
        }
        if downtimes_uuid:
            _range['downtimes_uuid'] = downtimes_uuid
        if output is not None:
            _range['output'] = output
        if long_output is not None:
            _range['long_output'] = long_output
        return _range
