#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2013-2022
# This file is part of Shinken Enterprise, all rights reserved.

import threading
import time

from pymongo.errors import BulkWriteError, ConnectionFailure

import shinkensolutions.ssh_mongodb as mongo
from ec_event import Event
from shinken.log import PART_INITIALISATION, PartLogger
from shinken.misc.type_hint import TYPE_CHECKING, cast
from shinkensolutions.lib_modules.configuration_reader import read_int_in_configuration
from shinkensolutions.ssh_mongodb.mongo_bulk import UpdateBulk, InsertBulk, BULK_TYPE
from shinkensolutions.ssh_mongodb.mongo_client import MongoClient
from shinkensolutions.ssh_mongodb.mongo_collection import MongoCollection
from shinkensolutions.ssh_mongodb.mongo_conf import MongoConf

if TYPE_CHECKING:
    from shinken.objects.module import Module as ShinkenModuleDefinition
    from shinken.misc.type_hint import List, Dict, Callable, Optional, Union

INFO_ENTRY_ID = 'INFO_ENTRY_ID'


class ECDatabaseError(Exception):
    def __init__(self, uri=''):
        # type: (Union[unicode, str]) -> None
        self._uri = uri
        error_message = 'Cannot connect to database : %s' % uri
        super(ECDatabaseError, self).__init__(error_message)


class ECBulk(object):
    bulk_thread = None  # type: threading.Thread
    collection = None  # type: MongoCollection
    on_execute_error = None  # type: Callable
    bulk_insert = None  # type: InsertBulk
    bulk_update = None  # type: UpdateBulk
    bulk_insert_cmp = 0
    bulk_update_cmp = 0
    
    
    def __init__(self, collection, logger, on_execute_error=None):
        # type: (MongoCollection, PartLogger, Callable) -> None
        self.logger = logger
        self.collection = collection
        if not on_execute_error:
            on_execute_error = ECBulk._default_on_execute_error
        self.on_execute_error = on_execute_error
        self.collection_name = self.collection.get_name()
        
        self._make_bulks()
        self.logger.debug('[bulk-%s] create bulk' % self.collection_name)
    
    
    @staticmethod
    def _default_on_execute_error():
        pass
    
    
    def insert_event(self, event):
        self.bulk_insert.insert(event.to_database_entry())
        self.bulk_insert_cmp += 1
    
    
    def update_event(self, event_uuid, state_type, event_hard_since):
        self.bulk_update.update({'_id': event_uuid}, {'$set': {'state_type': state_type, 'event_hard_since': event_hard_since}})
        self.bulk_update_cmp += 1
    
    
    def bulks_execute(self):
        if self.bulk_thread:
            self.bulk_thread.join()
        
        info = (self.bulk_insert_cmp, self.bulk_update_cmp)
        self.bulk_thread = threading.Thread(target=self._bulks_execute, args=(self.bulk_insert, self.bulk_insert_cmp, self.bulk_update, self.bulk_update_cmp))
        self.bulk_thread.start()
        self._make_bulks()
        return info
    
    
    def _make_bulks(self):
        self.bulk_insert = cast(InsertBulk, self.collection.get_bulk(BULK_TYPE.INSERT, order=True))
        self.bulk_update = cast(UpdateBulk, self.collection.get_bulk(BULK_TYPE.UPDATE, order=True))
        self.bulk_insert_cmp = 0
        self.bulk_update_cmp = 0
    
    
    def _bulks_execute(self, bulk_insert, bulk_insert_counter, bulk_update, bulk_update_counter):
        self._bulk_generic_execute(bulk_insert, bulk_insert_counter, 'insert')
        self._bulk_generic_execute(bulk_update, bulk_update_counter, 'update')
    
    
    def _bulk_generic_execute(self, bulk, bulk_cmp, bulk_type):
        if bulk_cmp == 0:
            return
        
        try:
            result = bulk.execute()
            
            actual_cmp = 0
            if bulk_type == 'update':
                actual_cmp = result['nMatched']
            if bulk_type == 'insert':
                actual_cmp = result['nInserted']
            
            if actual_cmp != bulk_cmp:
                self.logger.error('[bulk-%(collection_name)s] bulk %(bulk_type)s fail because we expect to %(bulk_type)s %(bulk_cmp)s item and we effectively %(bulk_type)s %(actual_cmp)s' % {
                    'collection_name': self.collection_name,
                    'bulk_type'      : bulk_type,
                    'bulk_cmp'       : bulk_cmp,
                    'actual_cmp'     : result['nMatched']
                })
                self.on_execute_error()
        except BulkWriteError as bwe:
            if bwe.details.get('writeErrors', ''):
                write_errors = ''
                for errmsg in bwe.details['writeErrors']:
                    write_errors += errmsg.get('errmsg', 'errmsg not found')
            else:
                write_errors = str(bwe)
            self.logger.error('[bulk-%s] bulk %s fail with error : %s' % (bulk_type, self.collection_name, write_errors))
            self.on_execute_error()


class ECDatabaseConnection(MongoClient):
    col_event_container = None  # type: MongoCollection
    col_state_cache = None  # type: MongoCollection
    col_info = None  # type: MongoCollection
    col_acknowledge = None  # type: MongoCollection
    col_downtime = None  # type: MongoCollection
    
    
    def __init__(self, conf, logger, ttl_index=True):
        # type: (ShinkenModuleDefinition, PartLogger, bool) -> None
        MongoClient.__init__(self, MongoConf(conf, logger=logger, default_database='event_container'), logger)
        self.logger = logger
        
        self.ttl_index = ttl_index
        if ttl_index:
            self.expire_after_seconds = read_int_in_configuration(conf, 'day_keep_data', '30') * 86400
        else:
            self.expire_after_seconds = None
    
    
    def get_day_keep_data(self):
        if self.expire_after_seconds is None:
            try:
                index_information = self.col_event_container.index_information()
            except Exception:
                raise ECDatabaseError(uri=self._uri)
            
            if not index_information:
                return 30
            ttl_idx = index_information.get('ttl_idx', {}).get('expireAfterSeconds', 30 * 86400)
            self.expire_after_seconds = ttl_idx
        
        return self.expire_after_seconds / 86400
    
    
    def init(self, requester='event_container'):
        if self._database:
            return
        MongoClient.init(self, requester=requester)
        self._init_collections()
        
        # Add index for the collections
        time_start = time.time()
        try:
            self.col_event_container.ensure_index([('ordering_uuid', mongo.DESCENDING)], name='date_idx')
            self.col_event_container.ensure_index([('item_uuid', mongo.ASCENDING)], name='item_uuid_idx')
            if self.ttl_index:
                self.col_event_container.ensure_index([('event_since', mongo.ASCENDING)], name='ttl_idx', expire_after_seconds=self.expire_after_seconds)
        except Exception:
            raise ECDatabaseError(uri=self._uri)
        _logger = self.logger.get_sub_part(PART_INITIALISATION).get_sub_part(u'MONGO')
        _logger.info(u'Ensure mongo index done in %s' % self.logger.format_chrono(time_start))
    
    
    def _init_collections(self):
        try:
            self.col_event_container = self.get_collection('ec_datas')
        except Exception:
            raise ECDatabaseError(uri=self._uri)
        self.col_state_cache = self.get_collection('ec_state_cache')
        self.col_info = self.get_collection('ec_info')
        self.col_acknowledge = self.get_collection('ec_acknowledge')
        self.col_downtime = self.get_collection('ec_downtime')
    
    
    # Note: page_nb must start at 0
    def find_events(self, filters, hint=None, skip_value=0, page_size=0):
        # type: (Dict, Optional[unicode], int, int) -> List[Event]
        
        if hint:
            hint = (hint, mongo.ASCENDING)
        else:
            hint = None
        
        try:
            start_time = time.time()
            events = [Event.from_database_entry(i) for i in self.col_event_container.find(filters, sort=[(u'ordering_uuid', mongo.DESCENDING)], limit=page_size, skip=skip_value, hint=hint)]
            self.logger.debug(u'find_events DB query [%.3f] with hint:[%s] sort:[ordering_uuid desc] limit:[%s] offset:[%s] filter:[%s]' % (time.time() - start_time, hint, page_size, skip_value, filters))
            return events
        except ConnectionFailure:
            raise ECDatabaseError(uri=self._uri)
        except Exception:
            self.logger.print_stack(u'DB query failure ')
            raise ECDatabaseError(uri=self._uri)
    
    
    def load_state_cache(self):
        return self.col_state_cache.find({})
    
    
    def is_in_state_cache(self, item_uuid):
        return bool(self.col_state_cache.find_one(filter={'_id': item_uuid}, projection={}))
    
    
    def update_state_cache(self, state_caches):
        self.col_state_cache.upsert_update_many(state_caches)
    
    
    def get_last_tick(self):
        info_entry = self.col_info.find_one({'_id': INFO_ENTRY_ID})
        return info_entry['last_tick'] if info_entry else 0
    
    
    def update_tick_info(self, tick_time):
        self.col_info.update(entry_id=INFO_ENTRY_ID, set_update={'last_tick': tick_time}, upsert=True)
    
    
    def event_bulk_factory(self):
        # type:()->ECBulk
        try:
            return ECBulk(self.col_event_container, self.logger)
        except Exception:
            raise ECDatabaseError(uri=self._uri)
    
    
    def save_acknowledge(self, acknowledge):
        self.col_acknowledge.save(acknowledge)
    
    
    def update_acknowledges(self, acknowledges):
        self.col_acknowledge.replace_many(acknowledges)
    
    
    def find_acknowledge(self, acknowledge_id):
        return self.col_acknowledge.find_one({'_id': acknowledge_id})
    
    
    def save_downtime(self, downtime):
        self.col_downtime.save(downtime)
    
    
    def find_downtime(self, downtime_id):
        return self.col_downtime.find_one({'_id': downtime_id})
    
    
    def find_downtimes(self, downtime_ids):
        return self.col_downtime.find({'_id': {'$in': downtime_ids}})
    
    
    def count_event_in_db(self):
        try:
            return self.col_event_container.count()
        except Exception:
            raise ECDatabaseError(uri=self._uri)
    
    
    def find_oldest_event_in_db(self):
        try:
            return self.col_event_container.find(limit=1, sort=[('event_since', mongo.ASCENDING)], next=True)
        except Exception:
            raise ECDatabaseError(uri=self._uri)
