#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2013-2022
# This file is part of Shinken Enterprise, all rights reserved.

import time

from pymongo.errors import AutoReconnect

from mongo_error import MongoDriverException, ShinkenMongoException
from mongo_retry import retry_on_auto_reconnect
from shinken.misc.type_hint import TYPE_CHECKING

if TYPE_CHECKING:
    from pymongo.bulk import BulkOperationBuilder
    from pymongo.collection import Collection
    from shinken.misc.type_hint import Number, Callable, List, Dict, Optional
    from shinken.log import PartLogger


class BULK_TYPE(object):
    # Identifiers of the supported bulk kinds, kept as unicode constants.
    
    # Plain write kinds.
    INSERT = u'insert'
    UPDATE = u'update'
    REPLACE = u'replace'
    REMOVE = u'remove'
    
    # Upsert variants of update / replace.
    UPSERT_UPDATE = u'upsert_update'
    UPSERT_REPLACE = u'upsert_replace'


class AbstractBulk(object):
    """Queue write operations for one collection and send them as a single pymongo bulk.
    
    Subclasses define how the queued entries (``self._ops``) are built (``set_all`` and their
    own queueing methods) and how they are stacked on the bulk builder (``_stack_action``).
    """
    
    def __init__(self, collection, collection_name, auto_reconnect_max_try, auto_reconnect_sleep_between_try, logger, reconnect_method, order=False):
        # type: (Collection, unicode, Number, Number, PartLogger, Callable, bool) -> None
        self.logger = logger
        self._collection = collection
        self._collection_name = collection_name
        # Operations waiting for the next execute(); reset once a bulk has been executed.
        self._ops = []
        # True -> initialize_ordered_bulk_op, False -> initialize_unordered_bulk_op.
        self._order = order
        self._auto_reconnect_max_try = auto_reconnect_max_try
        self._auto_reconnect_sleep_between_try = auto_reconnect_sleep_between_try
        self._mongo_collection_reconnect_method = reconnect_method
    
    
    def _retry_connection(self):
        # type: () -> None
        """Ask the reconnect callback for a collection and switch to it when one is returned."""
        collection = self._mongo_collection_reconnect_method()
        # A falsy answer keeps the current collection object.
        if collection:
            self._collection = collection
    
    
    def set_all(self, values):
        # type: (List) -> None
        """Replace every queued operation with entries built from ``values`` (subclass hook)."""
        raise NotImplementedError()
    
    
    def _stack_action(self, bulk_ops):
        # type: (BulkOperationBuilder) -> None
        """Add every queued operation to the pymongo bulk builder ``bulk_ops`` (subclass hook)."""
        raise NotImplementedError()
    
    
    @retry_on_auto_reconnect()
    def execute(self):
        # type: () -> Optional[Dict]
        """Send all queued operations as one bulk and return pymongo's bulk result.
        
        Returns None without contacting the database when nothing is queued. The queue is
        cleared only after ``bulk_ops.execute()`` returned. Retrying on AutoReconnect is
        presumably handled by the ``retry_on_auto_reconnect`` decorator (see mongo_retry).
        """
        if not self._ops:
            return None
        
        if self._order:
            bulk_ops = self._collection.initialize_ordered_bulk_op()
        else:
            bulk_ops = self._collection.initialize_unordered_bulk_op()
        
        self._stack_action(bulk_ops)
        
        bulk_result = bulk_ops.execute()
        self._ops = []
        return bulk_result


class UpdateBulk(AbstractBulk):
    """Bulk of ``update`` operations; each queued entry is a (filter, update document) pair."""
    
    def update(self, find, value):
        # type: (Dict, Dict) -> None
        """Queue an update of the documents matched by ``find`` with ``value``."""
        pending = (find, value)
        self._ops.append(pending)
    
    
    def set_all(self, values):
        # type: (List) -> None
        """Replace the queue with one entry per document, each matched on its own ``_id``.
        
        NOTE(review): pymongo's bulk ``update()`` expects a document of ``$`` operators, while
        the raw document is stored here; this presumably only suits UpsertUpdateBulk, which
        wraps it in ``$set`` — confirm before calling set_all on a plain UpdateBulk.
        """
        new_ops = []
        for document in values:
            id_filter = {u'_id': document[u'_id']}
            new_ops.append((id_filter, document))
        self._ops = new_ops
    
    
    def _stack_action(self, bulk_ops):
        # type: (BulkOperationBuilder) -> None
        """Stack every queued (filter, update) pair on ``bulk_ops``."""
        for entry in self._ops:
            selector, update_document = entry
            bulk_ops.find(selector).update(update_document)


class UpsertUpdateBulk(UpdateBulk):
    """Same queue as UpdateBulk, but each value is applied through ``$set`` with upsert enabled."""
    
    def _stack_action(self, bulk_ops):
        # type: (BulkOperationBuilder) -> None
        """Stack every queued pair as an upserting ``$set`` update on ``bulk_ops``."""
        for selector, fields in self._ops:
            upserting_view = bulk_ops.find(selector).upsert()
            upserting_view.update({u'$set': fields})


class ReplaceBulk(AbstractBulk):
    """Bulk of ``replace_one`` operations; each queued entry is a (filter, replacement) pair."""
    
    def replace(self, find, value):
        # type: (Dict, Dict) -> None
        """Queue the replacement of the document matched by ``find`` with ``value``."""
        pending = (find, value)
        self._ops.append(pending)
    
    
    def set_all(self, values):
        # type: (List) -> None
        """Replace the queue with one entry per document, each matched on its own ``_id``."""
        replacements = []
        for document in values:
            replacements.append(({u'_id': document[u'_id']}, document))
        self._ops = replacements
    
    
    def _stack_action(self, bulk_ops):
        # type: (BulkOperationBuilder) -> None
        """Stack every queued (filter, replacement) pair on ``bulk_ops``."""
        for selector, replacement in self._ops:
            matching = bulk_ops.find(selector)
            matching.replace_one(replacement)


class UpsertReplaceBulk(ReplaceBulk):
    """Same queue as ReplaceBulk, but every replacement is sent with upsert enabled."""
    
    def _stack_action(self, bulk_ops):
        # type: (BulkOperationBuilder) -> None
        """Stack every queued pair as an upserting ``replace_one`` on ``bulk_ops``."""
        for entry in self._ops:
            selector, replacement = entry
            upserting_view = bulk_ops.find(selector).upsert()
            upserting_view.replace_one(replacement)


class InsertBulk(AbstractBulk):
    """Bulk of plain document inserts that can resume after a failed execute.
    
    Every queued document must carry its own ``_id``. After a failure, the ids are
    looked up in the collection so that only the documents that are still missing
    are inserted again.
    """
    
    # Not used here: execute is overridden and stacks the inserts itself.
    def _stack_action(self, bulk_ops):
        # type: (BulkOperationBuilder) -> None
        pass
    
    
    def insert(self, value):
        # type: (Dict) -> None
        """Queue one document (it must contain ``_id``) for the next execute."""
        self._ops.append(value)
    
    
    def set_all(self, values):
        # type: (List) -> None
        """Replace the queued documents with ``values``.
        
        A private list copy is kept so that later insert() calls never append to the
        caller's list, and so that any iterable is usable by the resume logic.
        """
        self._ops = list(values)
    
    
    def _unsafe_insert_many(self, documents):
        # type: (List) -> Dict
        """Insert ``documents`` in one bulk, without any retry, and return the bulk result.
        
        :raises MongoDriverException: if one of the documents has no ``_id``.
        """
        if self._order:
            bulk_ops = self._collection.initialize_ordered_bulk_op()
        else:
            bulk_ops = self._collection.initialize_unordered_bulk_op()
        
        for document in documents:
            if u'_id' not in document:
                raise MongoDriverException(u'_id must be define in all document to call insert_many')
            bulk_ops.insert(document)
        
        return bulk_ops.execute()
    
    
    def _filter_already_insert(self):
        # type: () -> List
        """Return the queued documents whose ``_id`` is not found in the collection."""
        ids = [d[u'_id'] for d in self._ops]
        already_in = set(d[u'_id'] for d in self._collection.find({u'_id': {u'$in': ids}}, {u'_id': 1}))
        return [d for d in self._ops if d[u'_id'] not in already_in]
    
    
    def _resume_insert_many(self):
        # type: () -> Optional[Dict]
        """Insert the queued documents that are still missing, retrying on AutoReconnect.
        
        Returns the bulk result, or None when every queued document is already stored.
        :raises ShinkenMongoException: when AutoReconnect is raised on every remaining try.
        """
        left_round = self._auto_reconnect_max_try - 1  # We already have done a try here
        while True:
            # Reset on each round so a failure is reported against the step that really failed.
            operation_name = u'filter_already_insert in %s' % self._collection_name
            try:
                to_retry = self._filter_already_insert()
                if not to_retry:
                    # Everything is already stored (e.g. the previous attempt was applied but its
                    # answer was lost). pymongo refuses to execute an empty bulk, so stop here.
                    return None
                operation_name = u'insert_many in %s (retry on %s/%s element)' % (self._collection_name, len(to_retry), len(self._ops))
                return self._unsafe_insert_many(to_retry)
            except AutoReconnect as e:
                if left_round <= 1:
                    error_msg = u'Mongo raised ( %s ) on the operation %s. Operation failed : %s/%s. We tried %s times but it kept failing.' % (
                        e, operation_name, self._auto_reconnect_max_try - left_round + 1, self._auto_reconnect_max_try, self._auto_reconnect_max_try)
                    self.logger.error(error_msg)
                    raise ShinkenMongoException(error_msg)
                else:
                    self.logger.info(u'Mongo raised ( %s ) on the operation %s. Operation failed : %s/%s' % (e, operation_name, self._auto_reconnect_max_try - left_round + 1, self._auto_reconnect_max_try))
                time.sleep(self._auto_reconnect_sleep_between_try)
                self._retry_connection()
                left_round -= 1
    
    
    def execute(self):
        # type: () -> Optional[Dict]
        """Insert every queued document and return the bulk result.
        
        Returns None when nothing is queued, or when a resume finds every document already
        stored. Any failure other than a missing ``_id`` switches to _resume_insert_many,
        which only re-inserts the documents not yet in the collection. The queue is cleared
        only once the documents are stored.
        """
        if not self._ops:
            return None
        try:
            bulk_result = self._unsafe_insert_many(self._ops)
        except MongoDriverException:
            # Invalid input (document without _id): resuming cannot fix it and would only hide
            # this error behind a KeyError raised while collecting the ids to filter.
            raise
        except Exception:
            bulk_result = self._resume_insert_many()
        self._ops = []
        return bulk_result
