#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2013-2019
# This file is part of Shinken Enterprise, all rights reserved.

import threading
import time
from datetime import date as datetime_date, timedelta

from pymongo.errors import BulkWriteError

import shinkensolutions.ssh_mongodb as mongo
from shinken.log import PartLogger
from shinken.misc.type_hint import Callable
from shinken.objects.module import Module as ShinkenModuleDefinition
from shinkensolutions.date_helper import Date
from shinkensolutions.ssh_mongodb.mongo_bulk import UpdateBulk, InsertBulk, BULK_TYPE
from shinkensolutions.ssh_mongodb.mongo_client import MongoClient
from shinkensolutions.ssh_mongodb.mongo_collection import MongoCollection
from sla_abstract_component import AbstractComponent
from sla_common import shared_data, RAW_SLA_KEY
from sla_component_manager import ComponentManager


class BulkSla(object):
    """Queue SLA insert/update operations on one collection and flush them in a background thread.

    Operations are accumulated in an ordered insert bulk and an ordered update bulk.
    ``bulks_execute`` hands the current bulks (with their operation counts) to a worker
    thread and immediately replaces them with fresh, empty bulks, so callers can keep
    queueing while the previous batch is being written.
    """
    bulk_thread = None  # type: threading.Thread
    collection = None  # type: MongoCollection
    on_execute_error = None  # type: Callable
    bulk_insert = None  # type: InsertBulk
    bulk_update = None  # type: UpdateBulk
    bulk_insert_cmp = 0
    bulk_update_cmp = 0
    
    
    def __init__(self, logger, collection, on_execute_error):
        # type: (PartLogger, MongoCollection, Callable) -> None
        """
        :param logger: logger used for bulk creation and failure messages
        :param collection: collection the bulks are built on and written to
        :param on_execute_error: no-argument callback invoked when a flush does not write
                                 the expected number of documents or fails (the error logs
                                 state that the caller then resets its internal cache)
        """
        self.collection = collection
        self.on_execute_error = on_execute_error
        self.collection_name = self.collection.get_name()
        
        self._make_bulks()
        self.logger = logger
        self.logger.debug('[bulk-%s] create bulk' % self.collection_name)
    
    
    def insert(self, to_insert):
        """Queue one document insertion in the current insert bulk."""
        self.bulk_insert.insert(to_insert)
        self.bulk_insert_cmp += 1
    
    
    def update(self, to_find, to_update):
        """Queue one update (filter ``to_find``, modification ``to_update``) in the current update bulk."""
        self.bulk_update.update(to_find, to_update)
        self.bulk_update_cmp += 1
    
    
    def bulks_execute(self):
        """Flush the queued operations in a worker thread and start new, empty bulks.

        Waits for the previous flush thread (if any) to finish first, so, as long as this
        method is not itself called concurrently, only one flush runs at a time.

        :return: (queued insert count, queued update count) handed to this flush
        """
        if self.bulk_thread:
            self.bulk_thread.join()
        
        info = (self.bulk_insert_cmp, self.bulk_update_cmp)
        self.bulk_thread = threading.Thread(
            target=self._bulks_execute,
            args=(self.bulk_insert, self.bulk_insert_cmp, self.bulk_update, self.bulk_update_cmp),
        )
        self.bulk_thread.start()
        # The worker owns the old bulks now; new operations go to fresh ones.
        self._make_bulks()
        return info
    
    
    def _make_bulks(self):
        """Create fresh ordered insert/update bulks and reset their operation counters."""
        self.bulk_insert = self.collection.get_bulk(BULK_TYPE.INSERT, order=True)
        self.bulk_update = self.collection.get_bulk(BULK_TYPE.UPDATE, order=True)
        self.bulk_insert_cmp = 0
        self.bulk_update_cmp = 0
    
    
    def _bulks_execute(self, bulk_insert, bulk_insert_counter, bulk_update, bulk_update_counter):
        """Worker-thread entry point: execute the insert bulk, then the update bulk."""
        # Inserts run before updates; presumably so updates can match documents inserted
        # in the same flush -- NOTE(review): confirm this ordering requirement.
        self._bulk_generic_execute(bulk_insert, bulk_insert_counter, 'insert')
        self._bulk_generic_execute(bulk_update, bulk_update_counter, 'update')
    
    
    def _bulk_generic_execute(self, bulk, bulk_cmp, bulk_type):
        """Execute one bulk and check that every queued operation was applied.

        Runs in the flush thread. Every failure (count mismatch, BulkWriteError or any
        other exception from ``execute``) is logged and reported via ``on_execute_error``;
        nothing is re-raised because an exception here would only kill the worker thread.

        :param bulk: the insert or update bulk to execute
        :param bulk_cmp: number of operations queued in ``bulk``
        :param bulk_type: 'insert' or 'update'
        """
        if bulk_cmp == 0:
            # Nothing queued: skip the database round-trip.
            return
        
        try:
            result = bulk.execute()
        except BulkWriteError as bwe:
            write_errors = (bwe.details or {}).get('writeErrors', [])
            messages = ' | '.join(err.get('errmsg', 'errmsg not found') for err in write_errors)
            self.logger.error('[bulk-%s] %s_sla bulk fail : %s' % (self.collection_name, bulk_type, messages))
            self.on_execute_error()
            return
        except Exception as e:
            # Any other failure (e.g. lost connection) also means the batch was not written.
            self.logger.error('[bulk-%s] %s_sla bulk fail : %r' % (self.collection_name, bulk_type, e))
            self.on_execute_error()
            return
        
        if bulk_type == 'update' and result['nMatched'] != bulk_cmp:
            self.logger.error('[bulk-%s] bulk update sla fail. We reset internal cache [nMatched:%s bulk_cmp:%s]' % (self.collection_name, result['nMatched'], bulk_cmp))
            self.on_execute_error()
        if bulk_type == 'insert' and result['nInserted'] != bulk_cmp:
            self.logger.error('[bulk-%s] bulk insert sla fail. We reset internal cache [nInserted:%s bulk_cmp:%s]' % (self.collection_name, result['nInserted'], bulk_cmp))
            self.on_execute_error()


class SLADatabaseConnection(AbstractComponent, MongoClient):
    """Mongo access layer of the SLA module.

    Opens the connection, exposes the SLA collections, creates their indexes, cleans the
    deprecated 'sla' collection and builds the names of the raw SLA collections that are
    named after a ``Date``. ``collection_exist``, ``rename_collection`` (old name) and
    ``drop_collection`` also accept a ``Date`` and translate it to that collection name.
    """
    col_archive = None  # type: MongoCollection
    col_archive_before_migration = None  # type: MongoCollection
    col_acknowledge = None  # type: MongoCollection
    col_downtime = None  # type: MongoCollection
    col_sla_info = None  # type: MongoCollection
    col_sla_future_states = None  # type: MongoCollection
    
    
    def __init__(self, conf, component_manager):
        # type: (ShinkenModuleDefinition, ComponentManager) -> None
        AbstractComponent.__init__(self, conf, component_manager)
        # self.logger is presumably set by AbstractComponent.__init__; the mongo client reuses it.
        MongoClient.__init__(self, conf, self.logger)
    
    
    def tick(self):
        """No periodic work for the database connection."""
        pass
    
    
    def init(self):
        """Open the mongo connection, ensure indexes and clean the deprecated collection.

        Does nothing when the database handle is already set.
        """
        if self._database:
            return
        time_start = time.time()
        MongoClient.init(self, requester='sla')
        self._init_collections()
        self.logger_init.info('Open mongo connection done in %s' % self.logger.format_chrono(time_start))
        
        self._ensure_sla_indexes()
        self._clean_deprecated_sla_collection()
    
    
    def _ensure_sla_indexes(self):
        """Create the indexes used by the SLA queries on sla_info and sla_archive."""
        time_start = time.time()
        self.col_sla_info.ensure_index([('host_name', mongo.ASCENDING), ('service_description', mongo.ASCENDING)], name='names')
        self.col_sla_info.ensure_index([('monitoring_start_time', mongo.ASCENDING)], name='monitoring_start_time')  # used by the graphite queries, time: 0.2s for 15K checks the first time, 0.0009 after
        self.col_archive.ensure_index([('hname', mongo.ASCENDING), ('type', mongo.ASCENDING), ('year', mongo.ASCENDING), ('yday', mongo.ASCENDING), ], name='hname_type_year_yday')
        self.col_archive.ensure_index([('hname', mongo.ASCENDING), ('type', mongo.ASCENDING), ('sdesc', mongo.ASCENDING), ('year', mongo.ASCENDING), ('yday', mongo.ASCENDING), ], name='hname_sdesc_type_year_yday')
        self.col_archive.ensure_index([('uuid', mongo.ASCENDING), ('year', mongo.ASCENDING), ('yday', mongo.ASCENDING), ], name='uuid_year_yday_idx')
        self.col_archive.ensure_index([('uuid', mongo.ASCENDING)], name='uuid_idx')
        self.col_archive.ensure_index([('version', mongo.ASCENDING)], name='version_idx')
        self.logger_init.info('Ensure mongo index done in %s' % self.logger.format_chrono(time_start))
    
    
    def _clean_deprecated_sla_collection(self):
        """Drop the 'sla' collection left by version 02.03.03 when it is stale or unreadable.

        The collection is dropped (and ``repair_database`` called) when the ``last_update``
        found below is more than one week old, or when reading/dropping it raises.
        """
        time_start = time.time()
        if not self.collection_exist('sla'):
            return
        _DEPRECATED_col_sla = self.get_collection('sla')
        try:
            # NOTE(review): ascending sort + limit 1 returns the OLDEST last_update, while the
            # log below calls it "there last update"; confirm whether DESCENDING was intended.
            last_update_sla = _DEPRECATED_col_sla.find({}, {'last_update': 1}, sort=('last_update', mongo.ASCENDING), limit=1, next=True)
            if last_update_sla:
                last_week = datetime_date.today() - timedelta(weeks=1)
                last_update_sla_date = datetime_date.fromtimestamp(last_update_sla['last_update'])
                
                self.logger_init.info('DEPRECATED FORMAT', 'Old working sla collection from version 02.03.03 are found and there last update is [%s]' % last_update_sla_date)
                if last_update_sla_date < last_week:
                    self.logger_init.info('DEPRECATED FORMAT', 'Old working sla collection from version 02.03.03 is old enough to be remove')
                    _DEPRECATED_col_sla.drop()
                    self.repair_database()
        except Exception as e:
            # Best effort: an unreadable old collection is deleted, but keep the reason in the logs.
            self.logger_init.info('DEPRECATED FORMAT', 'Old working sla collection from version 02.03.03 cannot be read (%r), it will be deleted' % e)
            _DEPRECATED_col_sla.drop()
            self.repair_database()
            self.logger_init.info('DEPRECATED FORMAT', 'Old working sla collection from version 02.03.03 are deleted')
        self.logger_init.info('DEPRECATED FORMAT', 'Old working sla collection from version 02.03.03 check donne in %s' % self.logger.format_chrono(time_start))
    
    
    def _init_collections(self):
        """Bind the col_* attributes to their mongo collections."""
        self.col_archive = self.get_collection('sla_archive')
        self.col_archive_before_migration = self.get_collection('sla_archive_version_0')
        self.col_acknowledge = self.get_collection('acknowledge')
        self.col_downtime = self.get_collection('downtime')
        self.col_sla_info = self.get_collection('sla_info')
        self.col_sla_future_states = self.get_collection('sla_future_states')
    
    
    def get_raw_sla_collection(self, date, on_write):
        """Return the raw SLA collection for ``date``.

        :param date: Date used to build the collection name
        :param on_write: when True and the collection does not exist yet, its uuid index is created
        """
        collection_name = SLADatabaseConnection.get_sla_collection_name(date)
        raw_sla_collection = self.get_collection(collection_name)
        if on_write and not self.collection_exist(collection_name):
            raw_sla_collection.ensure_index([(RAW_SLA_KEY.UUID, mongo.ASCENDING)], name='uuid_idx')
        return raw_sla_collection
    
    
    @staticmethod
    def _to_sla_collection_name(name_or_date):
        """Translate a Date into its raw SLA collection name; return any other value unchanged."""
        if isinstance(name_or_date, Date):
            return SLADatabaseConnection.get_sla_collection_name(name_or_date)
        return name_or_date
    
    
    def collection_exist(self, collection_name):
        """Tell whether a collection exists; ``collection_name`` may be a name or a Date."""
        return MongoClient.collection_exist(self, SLADatabaseConnection._to_sla_collection_name(collection_name))
    
    
    def rename_collection(self, old_name, new_name):
        """Rename a collection; ``old_name`` may be a name or a Date, ``new_name`` must be a name."""
        return MongoClient.rename_collection(self, SLADatabaseConnection._to_sla_collection_name(old_name), new_name)
    
    
    def drop_collection(self, collection_name):
        """Drop a collection; ``collection_name`` may be a name or a Date."""
        return MongoClient.drop_collection(self, SLADatabaseConnection._to_sla_collection_name(collection_name))
    
    
    @staticmethod
    def get_sla_collection_name(date):
        """Return the raw SLA collection name for ``date``.

        Uses the archived naming when ``shared_data.get_already_archived()`` is true.
        ``date`` is formatted with '%s_%s', so it must unpack into exactly two values.
        """
        if shared_data.get_already_archived():
            return SLADatabaseConnection.get_archived_sla_collection_name(date)
        return '%s_%s' % date
    
    
    @staticmethod
    def get_archived_sla_collection_name(date):
        """Return the name a raw SLA collection for ``date`` has once it has been archived."""
        return 'has_been_archive_%s_%s' % date
    
    
    def bulk_sla_factory(self, date, on_execute_error):
        """Build a BulkSla writing to the raw SLA collection of ``date`` (index ensured for writing)."""
        return BulkSla(self.logger, self.get_raw_sla_collection(date, on_write=True), on_execute_error)
