#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2013-2022:
# This file is part of Shinken Enterprise, all rights reserved.

import sys
import time

from pymongo.errors import OperationFailure
from pymongo import WriteConcern

from mongo_bulk import InsertBulk, UpdateBulk, ReplaceBulk, UpsertUpdateBulk, BULK_TYPE, UpsertReplaceBulk
from mongo_error import MongoDriverException
from mongo_retry import retry_on_auto_reconnect
from shinken.log import LoggerFactory
from shinken.misc.type_hint import TYPE_CHECKING
from shinkensolutions.ssh_mongodb.mongo_bulk import AbstractBulk
from shinkensolutions.ssh_mongodb.mongo_pid_changed import raise_if_pid_changed

if TYPE_CHECKING:
    import pymongo.database
    from pymongo.results import UpdateResult
    from shinken.misc.type_hint import Number, Callable, Optional, List, Tuple, Dict, Union, Any, Set, Type, Iterator
    from shinken.log import PartLogger
    from pymongo.results import UpdateResult

# Module-level fallback logger: MongoCollection uses it when no logger is given to its constructor
_logger = LoggerFactory.get_logger()


class MongoCollection(object):
    
    def __init__(self, collection_name, database, auto_reconnect_max_try, auto_reconnect_sleep_between_try, reconnect_method, pid_used_to_start_connection, logger=None):
        # type: (unicode, pymongo.database.Database, Number, Number, Callable, Number, Optional[PartLogger]) -> None
        """Bind this wrapper to the collection `collection_name` of `database`.
        
        :param collection_name: name of the wrapped collection
        :param database: database object, indexed by name to get the collection
        :param auto_reconnect_max_try: stored and handed over to the bulk objects built by get_bulk
                                       (presumably also read by retry_on_auto_reconnect - to confirm)
        :param auto_reconnect_sleep_between_try: stored and handed over to the bulk objects built by get_bulk
        :param reconnect_method: callable without argument, called by _retry_connection to get a fresh database
        :param pid_used_to_start_connection: stored as is (presumably read by raise_if_pid_changed - to confirm)
        :param logger: logger to use, the module-level logger is used when it is missing
        """
        # Target of every operation
        self._database = database
        self._collection_name = collection_name
        self._collection = database[collection_name]
        
        # Reconnection settings
        self._mongo_client_reconnect_method = reconnect_method
        self._auto_reconnect_max_try = auto_reconnect_max_try
        self._auto_reconnect_sleep_between_try = auto_reconnect_sleep_between_try
        self.pid_used_to_start_connection = pid_used_to_start_connection
        
        # Fall back on the module logger when none (or a falsy one) is given
        self.logger = logger or _logger
    
    
    def _retry_connection(self):
        # type: () -> None
        """Ask the reconnect method for a database and, if one is returned, rebind this wrapper on it."""
        fresh_database = self._mongo_client_reconnect_method()
        
        # A falsy result means no new database is available: keep the current handles
        if not fresh_database:
            return
        
        self._database = fresh_database
        self._collection = fresh_database[self._collection_name]
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def drop(self):
        # type: () -> None
        """Drop the wrapped collection."""
        wrapped_collection = self._collection
        wrapped_collection.drop()
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def get_size(self):
        # type: () -> Number
        """Return the sum of 'storageSize' and 'totalIndexSize' reported by the collstats command."""
        stats = self._database.command('collstats', self._collection_name)
        return stats['storageSize'] + stats['totalIndexSize']
    
    
    # @raise_if_pid_changed
    # @retry_on_auto_reconnect()
    # def count(self):
    #     return self._collection.find().count()
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def find(self, filter=None, projection=None, limit=0, only_count=False, sort=None, next=False, skip=0, hint=None, modifiers=None):
        # type: (Optional[Dict], Optional[Union[List, Dict]], Number, bool, Optional[Union[List[Tuple], Tuple]], bool, Number, Optional[List[Tuple[unicode, int]]], Optional[Dict[unicode, Any]]) -> Union[List[Dict[unicode, Any]], Number, None, Dict[unicode, Any]]
        """Query the collection and return the result in the form selected by the flags.
        
        :param filter: query filter given to Collection.find
        :param projection: projection given to Collection.find
        :param limit: applied on the cursor only when truthy
        :param only_count: when true, return cursor.count() instead of documents (checked before `next`)
        :param sort: a list is given as a single argument to cursor.sort, anything else is unpacked
                     as positional arguments (for example a (key, direction) tuple)
        :param next: when true, return only the first document, or None when the cursor is empty
        :param skip: applied on the cursor only when truthy
        :param hint: applied on the cursor only when truthy
        :param modifiers: extra keyword arguments forwarded to Collection.find
        :return: a count, a single document, None, or the list of all matching documents
        """
        # NOTE: since mongo 3.4, modifiers= is no longer allowed, the options must be given as **kwargs instead
        extra_find_kwargs = {} if modifiers is None else modifiers
        cursor = self._collection.find(filter=filter, projection=projection, **extra_find_kwargs)
        
        if sort:
            sort_args = (sort,) if isinstance(sort, list) else sort
            cursor.sort(*sort_args)
        if skip:
            cursor.skip(skip)
        if limit:
            cursor.limit(limit)
        if hint:
            cursor.hint(hint)
        
        if only_count:
            # NOTE(review): with PyMongo 3.x, Cursor.count() ignores skip/limit by default - confirm this is the wanted count
            return cursor.count()  # type: Number
        
        if not next:
            return list(cursor)
        
        # The `next` parameter shadows the builtin, so the cursor method is called directly
        try:
            return cursor.next()
        except StopIteration:
            return None
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def find_one(self, filter=None, projection=None):
        # type: (Optional[Dict], Optional[Union[Dict, List]]) -> Optional[Dict]
        """Return the result of Collection.find_one for `filter` and `projection`."""
        return self._collection.find_one(filter=filter, projection=projection)
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def remove(self, filter=None, w=None, modifiers=None):
        # type: (Optional[Dict], Optional[Number], Optional[Dict]) -> Optional[Dict]
        """Remove the documents matching `filter` and return the result of Collection.remove.
        
        :param filter: query selecting the documents to remove
        :param w: write concern level; when it is 0 the collection is used as is, for any other value
                  (None included) a copy of the collection built with WriteConcern(w=w) is used
        :param modifiers: extra keyword arguments forwarded to Collection.remove
        """
        extra_remove_kwargs = {} if modifiers is None else modifiers
        
        # NOTE: since mongo 3.4, another write type (like 1==flush on disk before return) can only be
        # requested through a copy of the collection carrying the wanted write concern
        # NOTE(review): w=0 keeps the collection's own write concern instead of asking for w=0 - confirm this is intended
        if w == 0:
            target_collection = self._collection
        else:
            target_collection = self._collection.with_options(write_concern=WriteConcern(w=w))
        
        # NOTE: since mongo 3.4, modifiers= is no longer allowed, the options must be given as **kwargs instead
        return target_collection.remove(filter, **extra_remove_kwargs)
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def rename(self, new_name, dropTarget=None):
        # type: (unicode, Optional[bool]) -> Optional[Dict]
        """Rename the wrapped collection, then make this wrapper follow the new name.
        
        :param new_name: new name of the collection
        :param dropTarget: forwarded to Collection.rename only when it is not None
        :return: the result of Collection.rename
        """
        rename_options = {}
        if dropTarget is not None:
            rename_options['dropTarget'] = dropTarget
        
        rename_result = self._collection.rename(new_name, **rename_options)
        
        # Switch only once the rename went through: if the call raises (and is possibly retried after
        # a reconnection), the wrapper must still point on the collection that really exists.
        self._collection_name = new_name
        # Rebind the handle too, otherwise next operations would still hit the old (now missing) name
        self._collection = self._database[new_name]
        return rename_result
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def save(self, document=None, w=None):
        # type: (Optional[Dict], Optional[Number]) -> Optional[Dict]
        """Save `document` with Collection.save, forwarding `w` only when it is given.
        
        :param document: document to save
        :param w: write concern level forwarded as the `w` keyword, left out when None
        :return: the result of Collection.save
        """
        save_options = {} if w is None else {'w': w}
        return self._collection.save(document, **save_options)
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def insert(self, document=None):
        # type: (Optional[Dict]) -> Optional[Dict]
        """Store `document` in the collection and return the result of Collection.save."""
        # NOTE(review): this goes through Collection.save and not an insert call, so a document whose _id
        # already exists is presumably replaced instead of being rejected - confirm this is intended
        saved = self._collection.save(document)
        return saved
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def update(self, filter=None, entry_id=None, set_update=None, update=None, upsert=False, multi=False):
        # type: (Optional[Dict], Optional[Any], Optional[Dict], Optional[Dict], bool, bool) -> Optional[Dict]
        """Update documents with Collection.update, building the filter and the update when they are missing.
        
        :param filter: query selecting the documents, defaults to {'_id': entry_id}
        :param entry_id: _id used for the default filter and for the upsert payload
        :param set_update: fields to set, used as {'$set': set_update} when `update` is missing;
                           this dict is never modified
        :param update: full update document, takes precedence over `set_update`
        :param upsert: forwarded to Collection.update
        :param multi: forwarded to Collection.update
        :return: the result of Collection.update
        """
        if filter is None:
            filter = {'_id': entry_id}
        
        if update is None:
            fields_to_set = set_update
            if upsert and set_update and entry_id:
                # Work on a copy: the _id is only needed in the payload sent to mongo,
                # the dict of the caller must stay as it was given
                fields_to_set = set_update.copy()
                fields_to_set['_id'] = entry_id
            update = {'$set': fields_to_set}
        
        return self._collection.update(filter, update, upsert=upsert, multi=multi)
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def update_many(self, filter, update):
        # type: (Dict, Dict) -> UpdateResult
        """Apply `update` on the documents matching `filter` and return the result of Collection.update_many."""
        update_result = self._collection.update_many(filter, update)
        return update_result
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def ensure_index(self, definition, name, expire_after_seconds=None, unique=False):
        # type: (List[Union[List, Tuple]], unicode, Optional[Number], Optional[bool]) -> None
        """Ensure the index `name` exists, updating its expireAfterSeconds first when it differs.
        
        :param definition: list of (key, direction) pairs; only the first pair is used as key pattern
                           when the expiration of an existing index is updated with collMod
        :param name: name of the index
        :param expire_after_seconds: expiration of the index, not given to mongo when None
        :param unique: forwarded to Collection.ensure_index
        """
        index_options = {'name': name, 'unique': unique, 'cache_for': 0}
        
        if expire_after_seconds is None:
            self._collection.ensure_index(definition, **index_options)
            return
        
        known_indexes = self.index_information()
        if name in known_indexes:
            # -42 stands for "no expiration set", it never matches a requested value
            current_expiration = known_indexes[name].get('expireAfterSeconds', -42)
            if current_expiration != expire_after_seconds:
                self.logger.info('The expireAfterSeconds of the index [%s] of collection [%s] need to be update to the new value %ss (%s).' % (name, self.get_name(), expire_after_seconds, self.logger.format_duration(expire_after_seconds)))
                first_key, first_direction = definition[0][0], definition[0][1]
                new_index_settings = {'keyPattern': {first_key: first_direction}, 'expireAfterSeconds': expire_after_seconds}
                return_update_index = self._database.command('collMod', self.get_name(), index=new_index_settings)
                self.logger.debug('Result of update index:[%s]' % return_update_index)
        
        index_options['expireAfterSeconds'] = expire_after_seconds
        self._collection.ensure_index(definition, **index_options)
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def index_information(self):
        # type: () -> Dict
        """Return Collection.index_information(), or an empty dict when mongo answers with error code 26.
        
        :raises OperationFailure: for any other error code
        """
        try:
            indexes = self._collection.index_information()
        except OperationFailure as error:
            # Code 26: no database -> no index
            if error.code != 26:
                raise
            indexes = {}
        return indexes
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def aggregate(self, aggregate):
        # type: (List) -> pymongo.cursor.Cursor
        """Run the pipeline `aggregate` with Collection.aggregate and return what it returns."""
        pipeline = aggregate
        return self._collection.aggregate(pipeline)
    
    
    def delete_many(self, delete_filter, bulk_size=10000, pause_time=1):
        # type: (Union[Dict,List], Optional[int], Number) -> None
        """Remove documents, by chunks of `bulk_size` separated by a pause of `pause_time` seconds.
        
        :param delete_filter: either a list of _id values, or a query filter
        :param bulk_size: maximum number of documents removed by one remove call; None removes
                          everything in a single call
        :param pause_time: seconds slept between two chunks
        :raises ValueError: when bulk_size is given but is not strictly positive
        """
        if bulk_size is not None and bulk_size <= 0:
            # Such a size can never make the chunk loops progress
            raise ValueError('bulk_size must be strictly positive or None, got [%s]' % bulk_size)
        
        # A bare list is a list of _id to remove
        if isinstance(delete_filter, list):
            if bulk_size is None:
                self.remove({'_id': {'$in': delete_filter}})
            else:
                self._remove_ids_by_chunk(delete_filter, bulk_size, pause_time)
            return
        
        if bulk_size is None:
            self.remove(delete_filter)
            return
        
        # The _id criteria may be a plain value (no .get on it), only a dict can carry a $in list
        id_criteria = delete_filter.get(u'_id')
        to_delete_ids = id_criteria.get(u'$in') if isinstance(id_criteria, dict) else None
        if to_delete_ids:
            # The ids are already known: chunk them, but keep the other criteria of the filter
            self._remove_ids_by_chunk(to_delete_ids, bulk_size, pause_time, base_filter=delete_filter)
            return
        
        # The ids are not known: fetch a page of them, remove it, and start again.
        # No skip here: the removed documents do not match anymore, so the next page is always the first one.
        while True:
            to_delete_ids = [doc[u'_id'] for doc in self.find(delete_filter, projection={u'_id': 1}, limit=bulk_size)]
            if not to_delete_ids:
                break
            self.remove({u'_id': {u'$in': to_delete_ids}})
            if len(to_delete_ids) < bulk_size:
                # Partial page: nothing is left to remove
                break
            time.sleep(pause_time)
    
    
    def _remove_ids_by_chunk(self, ids, bulk_size, pause_time, base_filter=None):
        # type: (List, int, Number, Optional[Dict]) -> None
        """Remove the documents whose _id is in `ids`, `bulk_size` ids per remove call.
        
        When `base_filter` is given, its criteria are kept and only its _id.$in list is replaced by the chunk.
        Nothing is sent to mongo when `ids` is empty.
        """
        for lower_bound in range(0, len(ids), bulk_size):
            if lower_bound:
                # Pause between two chunks, not before the first one nor after the last one
                time.sleep(pause_time)
            
            chunk_filter = dict(base_filter) if base_filter else {}
            # Copy the _id criteria as well, the filter of the caller must not be modified
            chunk_id_criteria = dict(chunk_filter.get(u'_id') or {})
            chunk_id_criteria[u'$in'] = ids[lower_bound:lower_bound + bulk_size]
            chunk_filter[u'_id'] = chunk_id_criteria
            self.remove(chunk_filter)
    
    
    def insert_many(self, documents, bulk_size=10000, pause_time=1):
        # type: (List, Optional[Number], Number) -> None
        """Insert `documents` through insert bulks of at most `bulk_size` documents.
        
        :param documents: documents to insert
        :param bulk_size: size of one bulk, None sends everything in a single bulk
        :param pause_time: seconds slept between two bulks
        """
        bulk_type = BULK_TYPE.INSERT
        self._bulk_action(documents, bulk_type, bulk_size, pause_time)
    
    
    def replace_many(self, documents, bulk_size=10000, pause_time=1, upsert=False):
        # type: (List, Number, Number, bool) -> None
        """Replace `documents` through bulks of at most `bulk_size` documents.
        
        :param documents: documents to replace
        :param bulk_size: size of one bulk
        :param pause_time: seconds slept between two bulks
        :param upsert: selects the UPSERT_REPLACE bulk type instead of the REPLACE one
        """
        bulk_type = BULK_TYPE.UPSERT_REPLACE if upsert else BULK_TYPE.REPLACE
        self._bulk_action(documents, bulk_type, bulk_size, pause_time)
    
    
    def upsert_update_many(self, documents, bulk_size=10000, pause_time=1):
        # type: (List, Number, Number) -> None
        """Send `documents` through UPSERT_UPDATE bulks of at most `bulk_size` documents.
        
        :param documents: documents to send
        :param bulk_size: size of one bulk
        :param pause_time: seconds slept between two bulks
        """
        bulk_type = BULK_TYPE.UPSERT_UPDATE
        self._bulk_action(documents, bulk_type, bulk_size, pause_time)
    
    
    def get_bulk(self, bulk_type, order=False):
        # type: (Union[unicode, Type[AbstractBulk]], bool) -> Union[AbstractBulk,InsertBulk,UpdateBulk,ReplaceBulk,UpsertUpdateBulk,UpsertReplaceBulk]
        """Build a bulk object bound to this collection.
        
        :param bulk_type: one of the BULK_TYPE values, or directly a subclass of AbstractBulk
        :param order: forwarded to the bulk constructor
        :return: a new bulk instance
        :raises MongoDriverException: when `bulk_type` is neither a known BULK_TYPE nor an AbstractBulk subclass
        """
        known_bulk_classes = (
            (BULK_TYPE.INSERT, InsertBulk),
            (BULK_TYPE.UPDATE, UpdateBulk),
            (BULK_TYPE.REPLACE, ReplaceBulk),
            (BULK_TYPE.UPSERT_UPDATE, UpsertUpdateBulk),
            (BULK_TYPE.UPSERT_REPLACE, UpsertReplaceBulk),
        )
        
        bulk_class = None
        for known_type, known_class in known_bulk_classes:
            if bulk_type == known_type:
                bulk_class = known_class
                break
        
        if bulk_class is None:
            # issubclass raises a TypeError on anything that is not a class (an unknown string for example):
            # such a value is an invalid bulk type like any other
            try:
                if issubclass(bulk_type, AbstractBulk):
                    bulk_class = bulk_type
            except TypeError:
                pass
        
        if bulk_class is None:
            raise MongoDriverException('you must chose a bulk for get_bulk')
        
        return bulk_class(self._collection, self._collection_name, self._auto_reconnect_max_try, self._auto_reconnect_sleep_between_try, self.logger, self._retry_connection, order=order)
    
    
    def get_name(self):
        # type: () -> unicode
        """Return the name of the wrapped collection."""
        collection_name = self._collection_name
        return collection_name
    
    
    @property
    def name(self):
        # type: () -> unicode
        """Name of the wrapped collection, same value as get_name()."""
        collection_name = self._collection_name
        return collection_name
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def _bulk_action(self, documents, bulk_type, bulk_size, pause_time):
        # type: (List, unicode, Optional[Number], Number) -> None
        """Execute unordered bulks of type `bulk_type` over `documents`.
        
        :param documents: documents given to the bulks with set_all
        :param bulk_type: bulk type given to get_bulk
        :param bulk_size: number of documents per bulk, None puts everything in a single bulk
        :param pause_time: seconds slept between two bulks (not after the last one)
        """
        if bulk_size is None:
            whole_bulk = self.get_bulk(bulk_type, order=False)
            whole_bulk.set_all(documents)
            whole_bulk.execute()
            return
        
        nb_documents = len(documents)
        chunk_start = 0
        # At least one bulk is always executed, even when `documents` is empty
        while True:
            chunk_end = chunk_start + bulk_size
            chunk_bulk = self.get_bulk(bulk_type, order=False)
            chunk_bulk.set_all(documents[chunk_start:chunk_end])
            chunk_bulk.execute()
            if chunk_end >= nb_documents:
                return
            time.sleep(pause_time)
            chunk_start = chunk_end
    
    
    def is_connection_available(self):
        # type: () -> bool
        """Return True when the serverStatus command succeeds on the database, False when it raises."""
        try:
            self._database.command(u'serverStatus')
        except Exception:
            # Any error means the connection cannot be used. Only Exception is caught, so that
            # KeyboardInterrupt and SystemExit still stop the process instead of being reported as False.
            return False
        return True
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def count(self):
        # type: () -> Number
        """Return the result of Collection.count() on the wrapped collection."""
        nb_documents = self._collection.count()
        return nb_documents
