#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2013-2019
# This file is part of Shinken Enterprise, all rights reserved.

import time
from datetime import datetime, timedelta

import shinkensolutions.ssh_mongodb as mongo
from shinken.log import PART_INITIALISATION
from shinken.objects.module import Module as ShinkenModuleDefinition
from shinkensolutions.date_helper import date_now, timestamp_from_date, Date, timestamp_from_datetime
from sla_abstract_component import ThreadComponent
from sla_common import RAW_SLA_KEY, shared_data
from sla_component_manager import ComponentManager
from sla_database_connection import SLADatabaseConnection

DEFAULT_CHECK_INTERVAL = 5
DEFAULT_FIRST_MONITORING_DATE = datetime.strptime('2000-01-01', '%Y-%m-%d')


class SLAInfo(ThreadComponent):
    
    def __init__(self, conf, component_manager, sla_database_connection):
        # type: (ShinkenModuleDefinition, ComponentManager, SLADatabaseConnection) -> None
        super(SLAInfo, self).__init__(conf, component_manager, only_one_thread_by_class=True, stop_thread_on_error=False)
        self.sla_database_connection = sla_database_connection
        self._cache_date = None
        self._cache_sla_info_by_uuid = {}
        self._cache_sla_info_by_name = {}
        self._cache_first_monitoring_date = None
        self._last_clean_check_time = None

    
    def init(self):
        time_start = time.time()
        self._cache_first_monitoring_date = None
        all_sla_infos = self.sla_database_connection.col_sla_info.find({})
        
        tmp_cache = {}
        # handle double name with id !=
        for sla_info in all_sla_infos:
            if sla_info['_id'] == 'SLA_INFO':
                self._cache_date = sla_info.get('last_update', -1)
                self._last_clean_check_time = sla_info.get('last_clean_check_time', None)
            else:
                name = '%s-%s' % (sla_info['host_name'], sla_info['service_description'])
                tmp = tmp_cache.get(name, [])
                tmp.append(sla_info)
                tmp_cache[name] = tmp
        
        invalide_sla_infos = []
        for tmp in tmp_cache.itervalues():
            if len(tmp) > 1:
                invalide_sla_infos.extend(tmp)
            else:
                sla_info = tmp[0]
                name = '%s-%s' % (sla_info['host_name'], sla_info['service_description'])
                sla_info_uuid = sla_info['_id']
                self._cache_sla_info_by_uuid[sla_info_uuid] = sla_info
                self._cache_sla_info_by_name[name] = sla_info
        
        for sla_info_invalide in invalide_sla_infos:
            # self.logger.debug('removing sla_info_invalide [%s]' % sla_info_invalide)
            self.sla_database_connection.col_sla_info.remove({'_id': sla_info_invalide['_id']})
        
        self.logger.info(PART_INITIALISATION, 'Load %s elements info in cache done in %s' % (len(self._cache_sla_info_by_uuid), self.logger.format_chrono(time_start)))
        self.start_thread()
    
    
    def get_thread_name(self):
        return 'sla-info-thread'
    
    
    def loop_turn(self):
        sla_info = self.sla_database_connection.col_sla_info.find_one({'_id': 'SLA_INFO'})
        if sla_info and self._cache_date != sla_info.get('last_update', -1):
            self._reload_cache()
    
    
    def _reload_cache(self):
        time_start = time.time()
        self.logger.info('Elements info was updated we reload our cache')
        self._cache_first_monitoring_date = None
        all_sla_infos = self.sla_database_connection.col_sla_info.find({})
        
        for sla_info in all_sla_infos:
            if sla_info['_id'] == 'SLA_INFO':
                self._cache_date = sla_info.get('last_update', -1)
                self._last_clean_check_time = sla_info.get('last_clean_check_time', None)
            else:
                name = '%s-%s' % (sla_info['host_name'], sla_info['service_description'])
                sla_info_uuid = sla_info['_id']
                self._cache_sla_info_by_uuid[sla_info_uuid] = sla_info
                self._cache_sla_info_by_name[name] = sla_info
        
        self.logger.info(PART_INITIALISATION, 'Reload %s elements info in cache done in %s' % (len(self._cache_sla_info_by_uuid), self.logger.format_chrono(time_start)))
    
    
    def _get_cache_sla_info(self, item_uuid='', host_name='', service_description=''):
        if item_uuid:
            _where = {'_id': item_uuid}
            sla_info = self._cache_sla_info_by_uuid.get(item_uuid, None)
        else:
            service_description = '' if service_description is None else service_description
            _where = {'host_name': host_name, 'service_description': service_description}
            name = '%s-%s' % (host_name, service_description)
            sla_info = self._cache_sla_info_by_name.get(name, None)
        
        if sla_info:
            return sla_info
        
        sla_info = self.sla_database_connection.col_sla_info.find_one(_where)
        if sla_info:
            name = '%s-%s' % (sla_info['host_name'], sla_info['service_description'])
            self._cache_sla_info_by_uuid[sla_info['_id']] = sla_info
            self._cache_sla_info_by_name[name] = sla_info
            return sla_info
        return None
    
    
    def get_all_uuids(self):
        return [uuid[0:uuid.rfind('-')] if uuid.count('-') == 3 else uuid for uuid in self._cache_sla_info_by_uuid.iterkeys()]
    
    
    def get_uuid(self, hname, sdesc):
        info = self._get_cache_sla_info(host_name=hname, service_description=sdesc)
        if info:
            item_uuid = info['_id']
            if item_uuid.count('-') == 3:
                item_uuid = item_uuid[0:item_uuid.rfind('-')]
            return item_uuid
        else:
            return None
    
    
    def get_name(self, item_uuid):
        info = self._get_cache_sla_info(item_uuid=item_uuid)
        if info:
            return info.get('host_name', None), info.get('service_description', None)
        else:
            return None, None
    
    
    def get_monitoring_start_time(self, item_uuid='', host_name='', service_description=''):
        first_monitoring_start_time = timestamp_from_datetime(self.get_first_monitoring_start_time())
        sla_info = self._get_cache_sla_info(item_uuid, host_name, service_description)
        if sla_info:
            return sla_info.get('monitoring_start_time', first_monitoring_start_time)
        return first_monitoring_start_time
    
    
    def get_check_interval(self, item_uuid):
        sla_info = self._get_cache_sla_info(item_uuid)
        if sla_info:
            return sla_info.get('check_interval', DEFAULT_CHECK_INTERVAL)
        return DEFAULT_CHECK_INTERVAL
    
    
    def get_sla_thresholds(self, item_uuid):
        sla_info = self._get_cache_sla_info(item_uuid)
        default_value = shared_data.get_default_sla_thresholds()
        if sla_info:
            return sla_info.get('sla_thresholds', default_value)
        return default_value
    
    
    def handle_brok(self, item_type, brok):
        brok_data = brok.data
        actual_check_interval = brok_data['check_interval']
        monitoring_start_time_from_brok = brok_data.get('monitoring_start_time', -1)
        host_name = brok_data['host_name']
        if item_type == 'service':
            service_description = brok_data['service_description']
            name = '%s-%s' % (host_name, service_description)
            item_uuid = brok_data['instance_uuid']
            actual_sla_thresholds = (brok_data.get('sla_warning_threshold', shared_data.get_default_sla_thresholds()[0]), brok_data.get('sla_critical_threshold', shared_data.get_default_sla_thresholds()[1]))
            sla_info_entry = {'_id': item_uuid, 'host_name': host_name, 'service_description': service_description, 'check_interval': actual_check_interval, 'sla_thresholds': actual_sla_thresholds}
            _where_sla_info_entry = {'host_name': host_name, 'service_description': service_description}
            _where_archive_name = {'hname': host_name, 'sdesc': service_description, 'type': item_type}
            _where_archive_uuid = {'uuid': item_uuid}
            _where_current_uuid = {RAW_SLA_KEY.UUID: item_uuid}
            _where_current_name = {RAW_SLA_KEY.HNAME: host_name, RAW_SLA_KEY.SDESC: service_description, RAW_SLA_KEY.TYPE: 'S'}
        else:
            service_description = None
            item_uuid = brok_data['instance_uuid']
            name = '%s-' % host_name
            actual_sla_thresholds = (brok_data.get('sla_warning_threshold', shared_data.get_default_sla_thresholds()[0]), brok_data.get('sla_critical_threshold', shared_data.get_default_sla_thresholds()[1]))
            sla_info_entry = {'_id': item_uuid, 'host_name': host_name, 'service_description': '', 'check_interval': actual_check_interval, 'sla_thresholds': actual_sla_thresholds}
            _where_sla_info_entry = {'host_name': host_name, 'service_description': ''}
            _where_archive_name = {'hname': host_name, 'type': item_type}
            _where_archive_uuid = {'uuid': item_uuid}
            _where_current_uuid = {RAW_SLA_KEY.UUID: item_uuid}
            _where_current_name = {RAW_SLA_KEY.HNAME: host_name, RAW_SLA_KEY.TYPE: 'H'}
        
        sla_info_entry_found = self._get_cache_sla_info(host_name=host_name, service_description=service_description)
        if not sla_info_entry_found:
            sla_info_entry_found = self._get_cache_sla_info(item_uuid=item_uuid)
        
        if sla_info_entry_found:
            verify_monitoring_start_time = sla_info_entry_found.get('monitoring_start_time', -1)
            verify_check_interval = sla_info_entry_found.get('check_interval', None)
            verify_sla_thresholds = tuple(sla_info_entry_found.get('sla_thresholds', []))
            
            # self.logger.debug('item_uuid:[%s] monitoring_start_time_from_brok:[%s-%s] verify_monitoring_start_time:[%s-%s]' % (
            # item_uuid, monitoring_start_time_from_brok, print_time(monitoring_start_time_from_brok), verify_monitoring_start_time, print_time(verify_monitoring_start_time)))
            
            found_name = '%s-%s' % (sla_info_entry_found['host_name'], sla_info_entry_found['service_description'])
            if found_name != name or \
                    item_uuid != sla_info_entry_found['_id'] or \
                    verify_monitoring_start_time == -1 or \
                    (monitoring_start_time_from_brok != -1 and verify_monitoring_start_time > monitoring_start_time_from_brok) or \
                    verify_check_interval != actual_check_interval or \
                    verify_sla_thresholds != actual_sla_thresholds:
                # Invalide entry was found, maybe because the item was rename so we remove them
                self.sla_database_connection.col_sla_info.remove({'_id': sla_info_entry_found['_id']})
                sla_info_entry_found = None
                self.logger.debug('Element info of [%s] must be update monitoring_start_time:[%s->%s] check_interval:[%s->%s] verify_sla_thresholds:[%s->%s]' % (
                    name, self.logger.format_time_as_sla(monitoring_start_time_from_brok), self.logger.format_time_as_sla(verify_monitoring_start_time), verify_check_interval, actual_check_interval, verify_sla_thresholds, actual_sla_thresholds))
        else:
            self.logger.debug(
                'Adding new element info for [%s] monitoring_start_time:[%s] check_interval:[%s] verify_sla_thresholds:[%s]' % (name, self.logger.format_time_as_sla(monitoring_start_time_from_brok), actual_check_interval, actual_sla_thresholds))
        
        if sla_info_entry_found:
            self.logger.debug(
                'Element info of [%s] is correct monitoring_start_time:[%s] check_interval:[%s] verify_sla_thresholds:[%s]' % (name, self.logger.format_time_as_sla(monitoring_start_time_from_brok), actual_check_interval, actual_sla_thresholds))
            return True
        
        monitoring_start_time_from_old_data = self._search_monitoring_start_time_in_archive(_where_archive_uuid, _where_archive_name)
        if monitoring_start_time_from_old_data is None:
            # The item wasn't found in archive so we will search in current collection
            monitoring_start_time_from_old_data = self._search_monitoring_start_time_in_current_day(_where_current_uuid, _where_current_name)
            if monitoring_start_time_from_old_data is None:
                # No data in current collection so we assume time in brok is good
                monitoring_start_time_from_old_data = monitoring_start_time_from_brok
        
        # print_start_time = -1 if monitoring_start_time_from_brok == -1 else print_time(monitoring_start_time_from_brok)
        # print_start_time_from_old_data = -1 if monitoring_start_time_from_old_data == -1 else print_time(monitoring_start_time_from_old_data)
        # self.logger.debug('item [%s] monitoring_start_time_from_brok [%s] / monitoring_start_time_from_old_data [%s] ' % (name, print_start_time, print_start_time_from_old_data))
        
        if monitoring_start_time_from_brok == -1 or monitoring_start_time_from_old_data < monitoring_start_time_from_brok:
            verify_monitoring_start_time = monitoring_start_time_from_old_data
        else:
            verify_monitoring_start_time = monitoring_start_time_from_brok
        
        sla_info_entry['monitoring_start_time'] = verify_monitoring_start_time
        self._cache_sla_info_by_name[name] = sla_info_entry
        self._cache_sla_info_by_uuid[item_uuid] = sla_info_entry
        
        self.sla_database_connection.col_sla_info.save(sla_info_entry)
        self._cache_date = time.time()
        self.sla_database_connection.col_sla_info.update(query={'_id': 'SLA_INFO'}, set_update={'last_update': self._cache_date}, upsert=True)
        # self.logger.debug('update sla info of [%s] monitoring_start_time:[%s] check_interval[%s]' % (name, -1 if verify_monitoring_start_time == -1 else print_time(verify_monitoring_start_time), actual_check_interval))
    
    
    def _search_monitoring_start_time_in_current_day(self, _where_current_uuid, _where_current_name):
        date = date_now()
        current_collection = self.sla_database_connection.get_raw_sla_collection(date, False)
        first_entry = current_collection.find(_where_current_uuid, RAW_SLA_KEY.START, sort=(RAW_SLA_KEY.START, mongo.ASCENDING), limit=1, next=True)
        if not first_entry and not shared_data.get_migration_daily_done():
            first_entry = current_collection.find(_where_current_name, RAW_SLA_KEY.START, sort=(RAW_SLA_KEY.START, mongo.ASCENDING), limit=1, next=True)
        
        if not first_entry:
            return None
        
        monitoring_start_time_from_old_data = first_entry[RAW_SLA_KEY.START]
        return monitoring_start_time_from_old_data
    
    
    def _search_monitoring_start_time_in_archive(self, _where_archive_uuid, _where_archive_name):
        _where_archive = _where_archive_uuid
        first_year_entry = self.sla_database_connection.col_archive.find(_where_archive, {'year': 1}, sort=('yday', mongo.ASCENDING), limit=1, next=True)
        if not first_year_entry and not shared_data.get_migration_archive_done():
            _where_archive = _where_archive_name
            first_year_entry = self.sla_database_connection.col_archive.find(_where_archive, {'year': 1}, sort=('yday', mongo.ASCENDING), limit=1, next=True)
        
        if not first_year_entry:
            return None
        
        first_entry = self.sla_database_connection.col_archive.find(_where_archive, {'yday': 1, 'year': 1}, sort=('yday', mongo.ASCENDING), limit=1, next=True)
        if not first_entry:
            return None
        
        monitoring_start_time_from_old_data = timestamp_from_date(Date(first_entry['yday'], first_entry['year']))
        return monitoring_start_time_from_old_data
    
    
    def get_first_monitoring_start_time(self):
        if self._cache_first_monitoring_date:
            return self._cache_first_monitoring_date
        
        first_entry = self.sla_database_connection.col_sla_info.find({'monitoring_start_time': {'$exists': True}}, {'monitoring_start_time': 1}, sort=('monitoring_start_time', mongo.ASCENDING), limit=1, next=True)
        if first_entry:
            _first_monitoring_date = datetime.fromtimestamp(first_entry['monitoring_start_time']) - timedelta(days=1)
        else:
            _first_monitoring_date = DEFAULT_FIRST_MONITORING_DATE
        
        self._cache_first_monitoring_date = _first_monitoring_date
        self.logger.info('Found first monitoring start time at %s' % self._cache_first_monitoring_date.strftime('%d-%m-%Y %H:%M:%S'))
        return _first_monitoring_date

    
    def get_last_clean_time(self):
        return self._last_clean_check_time
    
    
    def set_last_clean_time(self, unix_timestamp):
        # type: (time.time()) -> None
        self._last_clean_check_time = unix_timestamp
        self.sla_database_connection.col_sla_info.update(query={'_id': 'SLA_INFO'}, set_update={'last_clean_check_time': unix_timestamp}, upsert=True)
