# !/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2013-2019
# This file is part of Shinken Enterprise, all rights reserved.

from pymongo.errors import OperationFailure
from shinken.log import LoggerFactory
from mongo_bulk import InsertBulk, UpdateBulk, ReplaceBulk, UpsertBulk, BULK_TYPE
from mongo_error import MongoDriverException
from mongo_retry import retry_on_auto_reconnect

_logger = LoggerFactory.get_logger()


class MongoCollection(object):
    """Thin wrapper around one pymongo collection.
    
    Most operations are decorated with ``retry_on_auto_reconnect`` (imported
    from ``mongo_retry``); the ``_auto_reconnect_*`` settings stored here are
    presumably consumed by that decorator -- TODO confirm -- and are also
    forwarded to the bulk helpers built by ``get_bulk``.
    """
    
    def __init__(self, collection_name, database, auto_reconnect_max_try, auto_reconnect_sleep_between_try, logger=None):
        self._collection_name = collection_name
        self._database = database
        # Attribute access on the database object yields the collection handle.
        self._collection = getattr(database, collection_name)
        self._auto_reconnect_max_try = auto_reconnect_max_try
        self._auto_reconnect_sleep_between_try = auto_reconnect_sleep_between_try
        # Fall back to the module-level logger when no (truthy) logger is given.
        self.logger = logger if logger else _logger
    
    
    @retry_on_auto_reconnect
    def drop(self):
        """Drop the whole collection."""
        self._collection.drop()
    
    
    @retry_on_auto_reconnect
    def get_size(self):
        """Return storageSize + totalIndexSize as reported by the ``collstats`` command.
        
        NOTE(review): presumably bytes (collStats default scale) -- confirm.
        """
        col_stats = self._database.command('collstats', self._collection_name)
        return col_stats['storageSize'] + col_stats['totalIndexSize']
    
    
    @retry_on_auto_reconnect
    def find(self, filter=None, projection=None, limit=0, only_count=False, sort=None, next=False, skip=0):
        """Query the collection.
        
        :param filter: query document (None is forwarded as-is to pymongo)
        :param projection: fields to return
        :param limit: maximum number of documents; 0 (falsy) means no limit
        :param only_count: return the number of matches instead of documents
        :param sort: tuple of positional arguments forwarded to ``Cursor.sort``,
                     e.g. ``('field', 1)``
        :param next: return only the first document, or None when nothing matches
        :param skip: number of documents to skip; 0 (falsy) skips nothing
        :return: a list of documents, a single document or None (``next``),
                 or a count (``only_count``)
        """
        # ``filter`` and ``next`` shadow builtins, but they are part of the
        # public keyword interface and therefore kept.
        cursor = self._collection.find(filter=filter, projection=projection)
        
        if sort:
            cursor.sort(*sort)
        if skip:
            cursor.skip(skip)
        if limit:
            cursor.limit(limit)
        
        if only_count:
            # NOTE(review): pymongo's Cursor.count() defaults to
            # with_limit_and_skip=False, so the count presumably ignores
            # ``skip``/``limit`` -- confirm this is what callers expect.
            return cursor.count()
        
        if next:
            # Only the first document is wanted; an empty result yields None.
            try:
                return cursor.next()
            except StopIteration:
                return None
        
        return list(cursor)
    
    
    @retry_on_auto_reconnect
    def find_one(self, filter=None, projection=None):
        """Return the first document matching ``filter`` (pymongo returns None when nothing matches)."""
        document = self._collection.find_one(filter=filter, projection=projection)
        return document
    
    
    @retry_on_auto_reconnect
    def remove(self, filter=None):
        """Remove the documents matching ``filter`` and return pymongo's result.
        
        NOTE(review): the default None is passed straight to pymongo's
        ``Collection.remove``, which treats a None spec as "match everything".
        """
        return self._collection.remove(filter)
    
    
    @retry_on_auto_reconnect
    def save(self, document=None):
        """Save ``document`` through pymongo's ``Collection.save`` and return its result."""
        return self._collection.save(document)
    
    
    @retry_on_auto_reconnect
    def insert(self, document=None):
        """Store ``document`` and return pymongo's result.
        
        NOTE(review): this delegates to ``Collection.save`` rather than
        ``Collection.insert``, so an existing ``_id`` is overwritten instead of
        raising a duplicate-key error. Kept as-is because callers may rely on it.
        """
        return self._collection.save(document)
    
    
    @retry_on_auto_reconnect
    def update(self, query=None, entry_id=None, set_update=None, update=None, upsert=False):
        """Update documents and return pymongo's ``Collection.update`` result.
        
        :param query: selector; defaults to ``{'_id': entry_id}``
        :param entry_id: ``_id`` used to build the default query, and added to
                         ``set_update`` when upserting
        :param set_update: fields wrapped as ``{'$set': set_update}`` when
                           ``update`` is not given
        :param update: full update document; takes precedence over ``set_update``
        :param upsert: insert the document when nothing matches
        """
        if query is None:
            query = {'_id': entry_id}
        if upsert and set_update and entry_id:
            # Work on a copy: adding '_id' to the caller's own dict would leak
            # it into later calls that reuse the same dict.
            set_update = dict(set_update)
            set_update['_id'] = entry_id
        if update is None:
            update = {'$set': set_update}
        return self._collection.update(query, update, upsert=upsert)
    
    
    @retry_on_auto_reconnect
    def ensure_index(self, definition, name, expire_after_seconds=None):
        """Ensure the index ``name`` exists on ``definition``.
        
        When ``expire_after_seconds`` is given the index is ensured with
        ``expireAfterSeconds``; if an index with that name already exists with
        a different value, it is first updated in place via ``collMod``.
        """
        if expire_after_seconds is None:
            self._collection.ensure_index(definition, name=name)
            return
        
        existing_index = self.index_information().get(name)
        # -42 is a sentinel no valid TTL equals, so an existing index lacking
        # 'expireAfterSeconds' is also treated as needing an update.
        if existing_index is not None and existing_index.get('expireAfterSeconds', -42) != expire_after_seconds:
            self.logger.info('The expireAfterSeconds of the index [%s] of collection [%s] need to be update to the new value %ss (%s).' % (name, self.get_name(), expire_after_seconds, self.logger.format_duration(expire_after_seconds)))
            # Only the first (field, direction) pair is sent as keyPattern:
            # assumes a single-field index definition like [('field', 1)].
            first_field, first_direction = definition[0][0], definition[0][1]
            return_update_index = self._database.command('collMod', self.get_name(), index={'keyPattern': {first_field: first_direction}, 'expireAfterSeconds': expire_after_seconds})
            self.logger.debug('Result of update index:[%s]' % return_update_index)
        
        self._collection.ensure_index(definition, name=name, expireAfterSeconds=expire_after_seconds)
    
    
    @retry_on_auto_reconnect
    def index_information(self):
        """Return the collection's index information, or {} if the namespace does not exist."""
        try:
            return self._collection.index_information()
        except OperationFailure as e:
            # Code 26 is NamespaceNotFound: the collection (or its database)
            # does not exist yet, so there are no indexes to report.
            if e.code == 26:
                return {}
            raise
    
    
    @retry_on_auto_reconnect
    def aggregate(self, aggregate):
        """Run the aggregation ``aggregate`` and return pymongo's result."""
        return self._collection.aggregate(aggregate)
    
    
    def _run_bulk(self, bulk_type, documents):
        # Shared body of the *_many helpers: unordered bulk of the given type.
        bulk = self.get_bulk(bulk_type, order=False)
        bulk.set_all(documents)
        bulk.execute()
    
    
    def insert_many(self, documents):
        """Insert ``documents`` with one unordered INSERT bulk."""
        self._run_bulk(BULK_TYPE.INSERT, documents)
    
    
    def replace_many(self, documents):
        """Replace ``documents`` with one unordered REPLACE bulk."""
        self._run_bulk(BULK_TYPE.REPLACE, documents)
    
    
    def upsert_many(self, documents):
        """Upsert ``documents`` with one unordered UPSERT bulk."""
        self._run_bulk(BULK_TYPE.UPSERT, documents)
    
    
    def get_bulk(self, bulk_type, order=False):
        """Build a bulk helper of ``bulk_type`` bound to this collection.
        
        :raises MongoDriverException: when ``bulk_type`` is not a known BULK_TYPE
        """
        known_bulks = (
            (BULK_TYPE.INSERT, InsertBulk),
            (BULK_TYPE.UPDATE, UpdateBulk),
            (BULK_TYPE.REPLACE, ReplaceBulk),
            (BULK_TYPE.UPSERT, UpsertBulk),
        )
        for known_type, bulk_class in known_bulks:
            if bulk_type == known_type:
                return bulk_class(self._collection, self._collection_name, self._auto_reconnect_max_try, self._auto_reconnect_sleep_between_try, self.logger, order=order)
        raise MongoDriverException('you must choose a valid bulk type for get_bulk (got [%s])' % (bulk_type,))
    
    
    def get_name(self):
        """Return the name of the wrapped collection."""
        return self._collection_name
