#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2013-2018:
# This file is part of Shinken Enterprise, all rights reserved.

import copy
import re
import time
import uuid

from pymongo import ASCENDING
from pymongo.errors import BulkWriteError

from dataprovider import DataProvider
from shinken.log import logger
from shinken.misc.fast_copy import fast_deepcopy
from shinken.misc.type_hint import TYPE_CHECKING
from shinkensolutions.date_helper import get_datetime_from_utc_to_local_time_zone
from ..def_items import DEF_ITEMS, ITEM_STATE, WORKING_AREA_STATUS, METADATA, NAGIOS_TABLE_KEYS, ITEM_TYPE
from ..helpers import split_list_attr
from ..transactions.transactions import get_transaction_object
from ...dao import DataException

if TYPE_CHECKING:
    from shinkensolutions.ssh_mongodb.mongo_client import MongoClient
    from shinken.misc.type_hint import NoReturn
    from ..callbacks.callback_history_info import HistoryEntry
    from ..crypto import Cipher

COLLECTIONS = {
    ITEM_STATE.NEW            : 'newelements-%s',
    ITEM_STATE.STAGGING       : 'configuration-stagging-%s',
    ITEM_STATE.PRODUCTION     : 'configuration-production-%s',
    ITEM_STATE.PREPROD        : 'configuration-preprod-%s',
    ITEM_STATE.MERGE_SOURCES  : 'merge_from_sources-%s',
    ITEM_STATE.RAW_SOURCES    : 'data-%s-%s',
    ITEM_STATE.CHANGES        : 'changeelements-%s',
    ITEM_STATE.DELETED        : 'deleted-stagging-%s',
    ITEM_STATE.WORKING_AREA   : 'configuration-working-area-%s',
    ITEM_STATE.SUBMIT_STAGGING: 'configuration-working-area-%s',
}


# TODO implement Lookup parameter for get functions
class DataProviderMongo(DataProvider):
    
    def __init__(self, mongo, database_cipher=None):
        # type: (MongoClient, Cipher) -> NoReturn
        self.mongo = mongo
        self.database_cipher = database_cipher
    
    
    @staticmethod
    def get_collection_name(item_type, item_state, item_source=''):
        try:
            table = DEF_ITEMS[item_type]['table']
            if item_state == ITEM_STATE.RAW_SOURCES:
                collection_name = COLLECTIONS.get(item_state, item_state) % (item_source, table)
            else:
                collection_name = COLLECTIONS.get(item_state, item_state) % table
        except Exception as e:
            raise DataException('[%s] Collections type[%s] - state[%s] not found %s' % (DataProviderMongo.__name__, item_type, item_state, e))
        
        return collection_name
    
    
    def get_collection(self, item_type, item_state, item_source=''):
        if not (isinstance(item_state, basestring)):
            raise DataException('[%s] get_collection : item_state [%s] must be a string' % (self.__class__.__name__, item_state))
        if not (isinstance(item_type, basestring)):
            raise DataException('[%s] get_collection : item_type [%s] must be a string' % (self.__class__.__name__, item_type))
        
        try:
            col_name = DataProviderMongo.get_collection_name(item_type, item_state, item_source)
            col = self.mongo.get_collection(col_name)
        except Exception as e:
            raise DataException('[%s] Collections (type:[%s]-state:[%s]) not found. [%s]' % (self.__class__.__name__, item_type, item_state, e))
        
        return col
    
    
    def find_item_by_name(self, item_name, item_type='', item_state='', item_source='', lookup=None):
        if not item_name:
            raise DataException('[%s] find_item_by_name : Please set item_name' % self.__class__.__name__)
        if not item_type:
            raise DataException('[%s] find_item_by_name : Please set item_type' % self.__class__.__name__)
        if not item_state:
            raise DataException('[%s] find_item_by_name : Please set item_state' % self.__class__.__name__)
        
        where = self._get_where(item_type, item_state)
        
        target_def_item = DEF_ITEMS[item_type]
        
        where.update({target_def_item['key_name']: re.compile(u'^%s$' % item_name, re.IGNORECASE)})
        
        collection = self.get_collection(item_type, item_state, item_source)
        item = collection.find_one(where)
        if self.database_cipher and item:
            item = self.database_cipher.uncipher(item, item_type, item_state)
        return item
    
    
    def find_item_by_name_with_case_sensitive(self, item_name, item_type='', item_state='', item_source='', lookup=None):
        if not item_name:
            raise DataException('[%s] find_item_by_name : Please set item_name' % self.__class__.__name__)
        if not item_type:
            raise DataException('[%s] find_item_by_name : Please set item_type' % self.__class__.__name__)
        if not item_state:
            raise DataException('[%s] find_item_by_name : Please set item_state' % self.__class__.__name__)
        
        where = self._get_where(item_type, item_state)
        
        target_def_item = DEF_ITEMS[item_type]
        
        where.update({target_def_item['key_name']: item_name})
        
        collection = self.get_collection(item_type, item_state, item_source)
        item = collection.find_one(where)
        if self.database_cipher and item:
            item = self.database_cipher.uncipher(item, item_type, item_state)
        return item
    
    
    def find_item_by_id(self, item_id, item_type='', item_state='', item_source='', lookup=None):
        if not item_id:
            raise DataException('[%s] find_item_by_id : Please set item_id' % self.__class__.__name__)
        if not item_type:
            raise DataException('[%s] find_item_by_id : Please set item_type' % self.__class__.__name__)
        if not item_state:
            raise DataException('[%s] find_item_by_id : Please set item_state' % self.__class__.__name__)
        
        where = self._get_where(item_type, item_state)
        where.update({'_id': item_id})
        
        collection = self.get_collection(item_type, item_state, item_source)
        item = collection.find_one(where)
        if self.database_cipher and item:
            item = self.database_cipher.uncipher(item, item_type, item_state)
        return item
    
    
    def find_items(self, item_type, item_state='', item_source='', where=None, lookup=None):
        if not item_type:
            raise DataException('[%s] find_items : Please set item_type' % self.__class__.__name__)
        if not item_state:
            raise DataException('[%s] find_items : Please set item_state' % self.__class__.__name__)
        
        if not where:
            where = {}
        
        if item_type == ITEM_TYPE.ELEMENTS:
            item_type = DEF_ITEMS.keys()
        if isinstance(item_type, basestring):
            item_type = [item_type]
        if isinstance(item_state, basestring):
            item_state = [item_state]
        
        items = []
        for current_type in item_type:
            for current_state in item_state:
                where_for_type = where
                if item_state != ITEM_STATE.CHANGES:
                    where_for_type = self._get_where(current_type, current_state)
                    where_for_type.update(where)
                collection = self.get_collection(current_type, current_state, item_source)
                items_for_type_and_state = list(collection.find(where_for_type, lookup))
                for item in items_for_type_and_state:
                    if self.database_cipher:
                        self.database_cipher.uncipher(item, current_type, current_state)
                    METADATA.update_metadata(item, METADATA.ITEM_TYPE, current_type)
                    METADATA.update_metadata(item, METADATA.STATE, current_state)
                items.extend(items_for_type_and_state)
        
        return items
    
    
    def find_merge_state_items(self, item_type, item_states, item_source='', where=None, lookup=None):
        _item_states = list(item_states)
        _item_states.reverse()
        
        items = {}
        for _item_state in _item_states:
            _items = self.find_items(item_type, _item_state, item_source, where, lookup)
            for _item in _items:
                METADATA.update_metadata(_item, METADATA.STATE, _item_state)
                items[_item['_id']] = _item
        return items.values()
    
    
    def find_double_link_items(self, item_link_name, item_type_link_to, item_type='', item_state='', item_source=''):
        if not item_link_name:
            raise DataException('[%s] find_double_link_items : Please set item_link_name' % self.__class__.__name__)
        if not item_type_link_to:
            raise DataException('[%s] find_double_link_items : Please set item_type_link_to' % self.__class__.__name__)
        if not item_type:
            raise DataException('[%s] find_double_link_items : Please set item_type' % self.__class__.__name__)
        if not item_state:
            raise DataException('[%s] find_double_link_items : Please set item_state' % self.__class__.__name__)
        
        item_link_name = item_link_name.strip()
        where = None
        double_links = DEF_ITEMS[item_type_link_to].get('double_links', [])
        target_attr = ""
        for double_link in double_links:
            if double_link['of_type'] == item_type:
                target_attr = double_link['is_link_with_attr']
                where = {target_attr: re.compile(re.escape(item_link_name), re.IGNORECASE)}
                break
        
        if where:
            items = list(self.find_items(item_type, item_state, where=where))
            filter_items = []
            
            for item in items:
                if item_link_name in split_list_attr(item, target_attr):
                    filter_items.append(self.database_cipher.uncipher(item, item_type, item_state) if self.database_cipher else item)
            
            return filter_items
        else:
            return []
    
    
    def save_item(self, item, item_type='', item_state='', item_source='', **kwargs):
        if not item:
            raise DataException('[%s] save_item : Please set item' % self.__class__.__name__)
        if not item_type:
            raise DataException('[%s] save_item : Please set item_type' % self.__class__.__name__)
        if not item_state:
            raise DataException('[%s] save_item : Please set item_state' % self.__class__.__name__)
        
        collection = self.get_collection(item_type, item_state, item_source)
        item_to_save = self.database_cipher.cipher(item, item_type, item_state) if self.database_cipher else item
        
        transaction = get_transaction_object()
        if transaction:
            col_name = collection.get_name()
            container = transaction.get_container(self.__class__.__name__)
            if not col_name in container:
                container[col_name] = {}
            if not 'save' in container[col_name]:
                container[col_name]['save'] = {}
            container[col_name]['save'][item_to_save['_id']] = fast_deepcopy(item_to_save)
        else:
            collection.save(item_to_save)
        
        return item
    
    
    def delete_item(self, item, item_type='', item_state='', item_source=''):
        if not item:
            raise DataException('[%s] delete_item : Please set item' % self.__class__.__name__)
        if not item_type:
            raise DataException('[%s] delete_item : Please set item_type' % self.__class__.__name__)
        if not item_state:
            raise DataException('[%s] delete_item : Please set item_state' % self.__class__.__name__)
        
        collection = self.get_collection(item_type, item_state, item_source)
        item_id = item['_id']
        transaction = get_transaction_object()
        if transaction:
            col_name = collection.get_name()
            container = transaction.get_container(self.__class__.__name__)
            if not col_name in container:
                container[col_name] = {}
            if not 'remove' in container[col_name]:
                container[col_name]['remove'] = {}
            container[col_name]['remove'][item_id] = item_id
        else:
            collection.remove({'_id': item_id})
    
    
    def count_items(self, item_type, item_state='', item_source='', where=None):
        if not item_type:
            raise DataException('[%s] count_items : Please set item_type' % self.__class__.__name__)
        if not item_state:
            raise DataException('[%s] count_items : Please set item_state' % self.__class__.__name__)
        
        collection = self.get_collection(item_type, item_state, item_source)
        
        if not where:
            where = {}
        # In change collections we don't store item just change.
        if not item_state == ITEM_STATE.CHANGES:
            where_for_type = self._get_where(item_type, item_state)
            where_for_type.update(where)
            where = where_for_type
        
        return collection.find(where, projection={'_id': 1}, only_count=True)
    
    
    def update_all_item(self, item_state, update):
        for table in NAGIOS_TABLE_KEYS.iterkeys():
            from_col = self.mongo.get_collection(COLLECTIONS[item_state] % table)
            from_col.update({}, update=update, upsert=False, multi=True)
    
    
    def copy_state(self, from_state, to_state):
        tmp_col_at = int(time.time())
        tmp_col_uuid = uuid.uuid1().hex
        for table in NAGIOS_TABLE_KEYS.iterkeys():
            from_col = self.mongo.get_collection(COLLECTIONS[from_state] % table)
            to_col = self.mongo.get_collection(COLLECTIONS[to_state] % table)
            from_items = list(from_col.find({}))
            
            tmp_col_name = 'tmpBatch-%s-%s-%s-%s-%s' % (from_state, to_state, tmp_col_at, tmp_col_uuid, table)
            tmp_col = self.mongo.get_collection(tmp_col_name)
            if len(from_items) != 0:
                tmp_col.insert_many(from_items)
                tmp_col.rename(COLLECTIONS[to_state] % table, dropTarget=True)
            else:
                to_col.remove()
    
    
    def _get_where(self, item_type, item_state):
        if item_state == ITEM_STATE.CHANGES:
            return {}
        target_def_item = DEF_ITEMS[item_type]
        where = target_def_item.get('where', {}).copy()
        if item_state == ITEM_STATE.SUBMIT_STAGGING:
            where.update({'work_area_info.status': WORKING_AREA_STATUS.PROPOSED})
        
        return where
    
    
    # Transaction stuff
    def start_transaction(self, transaction_uuid):
        pass
    
    
    def find_in_transaction(self, item_id, item_type, item_state):
        transaction = get_transaction_object()
        container = transaction.get_container(self.__class__.__name__)
        col_name = self.get_collection(item_type, item_state).get_name()
        # if we are in a transaction context, find if the raw change is not in a pending delete
        collection = container.get(col_name, {})
        if item_id in collection.get('save', {}):
            raw_change = copy.deepcopy(container[col_name]['save'][item_id])
        elif item_id in collection.get('remove', {}):
            raw_change = None
        else:
            raw_change = self.find_item_by_id(item_id, item_type, item_state)
        
        if raw_change:
            raw_change = self.database_cipher.uncipher(raw_change, item_type, item_state) if self.database_cipher else raw_change
        return raw_change
    
    
    def commit_transaction(self, transaction_uuid):
        # create a bulk operation for each collection to save
        transaction = get_transaction_object()
        nb_op = 0
        if transaction:
            if transaction.transaction_uuid != transaction_uuid:
                raise EnvironmentError("The transaction uuid mismatch with the commit transaction uuid")
            
            logger.debug('MONGO Transaction %s will be commit' % transaction_uuid)
            transaction_time = time.time()
            container = transaction.get_container(self.__class__.__name__)
            for col_name, action_dict in container.iteritems():
                saved_info = {}
                col_time = time.time()
                if not col_name in saved_info:
                    saved_info[col_name] = {}
                    saved_info[col_name]['save'] = len(action_dict.get('save', []))
                    saved_info[col_name]['remove'] = len(action_dict.get('remove', []))
                
                col = self.mongo.get_collection(col_name)
                
                try:
                    # update items
                    if action_dict.get('save', None):
                        col.replace_many([item for item in list(action_dict.get('save', {}).itervalues())], upsert=True)
                    
                    # remove items
                    if action_dict.get('remove', None):
                        col.delete_many({'_id': {'$in': list(action_dict.get('remove', {}).itervalues())}})
                        
                    for col_name, action_info in saved_info.iteritems():
                        nb_save = action_info['save']
                        nb_remove = action_info['remove']
                        nb_op += nb_save + nb_remove
                        if nb_save:
                            logger.debug('   %s insert/update %s items ([%.3f]s)' % (col_name, nb_save, time.time() - col_time))
                        if nb_remove:
                            logger.debug('   %s remove %s items ([%.3f]s)' % (col_name, nb_remove, time.time() - col_time))
                except BulkWriteError as bwe:
                    logger.error("commit transaction error [%s] items in collection [%s] failed." % (saved_info, col_name))
                    logger.error("BulkWriteError.details [%s]" % bwe.details)
                    raise
            logger.debug('   all mongo ops done in [%.3f]s' % (time.time() - transaction_time))
        return nb_op
    
    
    def save_history(self, history_entry):
        # type: (HistoryEntry) -> NoReturn
        
        collection = self.mongo.get_collection(u'history')
        history_entry = self.database_cipher.cipher_history_entry(history_entry) if self.database_cipher else history_entry
        
        transaction = get_transaction_object()
        if transaction:
            col_name = collection.get_name()
            container = transaction.get_container(self.__class__.__name__)
            if col_name not in container:
                container[col_name] = {}
            if u'save' not in container[col_name]:
                container[col_name][u'save'] = {}
            container[col_name][u'save'][history_entry[u'_id']] = fast_deepcopy(history_entry)
        else:
            collection.save(history_entry)
        
        return history_entry
    
    
    def find_history_for_item(self, item_id, item_type):
        # type: (unicode, unicode) -> list
        ret = []
        if not item_id:
            raise DataException(u'[%s] find_history_for_item : Please set item_id' % self.__class__.__name__)
        if not item_type:
            raise DataException(u'[%s] find_history_for_item : Please set item_type' % self.__class__.__name__)
        
        target_def_item = DEF_ITEMS[item_type]
        where = target_def_item.get(u'where', {}).copy()
        where.update({u'item_uuid': item_id})
        
        history_entries = self.mongo.get_collection(u'history').find(where, sort=[(u'date', ASCENDING)])
        if self.database_cipher:
            for entry in history_entries:
                entry[u'date'] = get_datetime_from_utc_to_local_time_zone(entry[u'date'])
                ret.append(self.database_cipher.uncipher_history_entry(entry))
        return ret
