#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2013-2019
# This file is part of Shinken Enterprise, all rights reserved.

import time

from mongo_error import MongoDriverException
from mongo_retry import retry_on_auto_reconnect, AutoReconnect


class BULK_TYPE(object):
    """String identifiers for the bulk kinds defined in this module.

    The four kinds are insert, replace, update and upsert.
    """
    INSERT = 'insert'
    REPLACE = 'replace'
    UPDATE = 'update'
    UPSERT = 'upsert'


class AbstractBulk(object):
    """Queue of write operations flushed to a collection as one bulk.

    Subclasses keep their pending operations in ``self._ops`` and turn them
    into bulk-builder calls in ``_stack_action``.

    :param collection: collection object exposing ``initialize_ordered_bulk_op``
        and ``initialize_unordered_bulk_op``.
    :param collection_name: collection name (subclasses use it in log messages).
    :param auto_reconnect_max_try: number of attempts allowed on AutoReconnect
        (presumably also read by ``retry_on_auto_reconnect`` -- confirm).
    :param auto_reconnect_sleep_between_try: delay between two attempts.
    :param logger: logger used to report retries.
    :param order: True for an ordered bulk, False (default) for an unordered one.
    """
    
    def __init__(self, collection, collection_name, auto_reconnect_max_try, auto_reconnect_sleep_between_try, logger, order=False):
        self._collection = collection
        self._collection_name = collection_name
        self._order = order
        # Retry settings used when the driver raises AutoReconnect.
        self._auto_reconnect_max_try = auto_reconnect_max_try
        self._auto_reconnect_sleep_between_try = auto_reconnect_sleep_between_try
        self.logger = logger
        # Pending operations, stored in a subclass-specific format.
        self._ops = []
    
    
    def set_all(self, values):
        """Replace every pending operation with ones built from ``values``."""
        raise NotImplementedError()
    
    
    def _stack_action(self, bulk_ops):
        """Add every pending operation to the bulk builder ``bulk_ops``."""
        raise NotImplementedError()
    
    
    @retry_on_auto_reconnect
    def execute(self):
        """Send every pending operation as one bulk and empty the queue.

        The queue is only cleared after the bulk ``execute()`` returned
        without raising. Returns the driver's bulk result unchanged.
        """
        make_bulk = (self._collection.initialize_ordered_bulk_op
                     if self._order else
                     self._collection.initialize_unordered_bulk_op)
        builder = make_bulk()
        self._stack_action(builder)
        result = builder.execute()
        self._ops = []
        return result


class UpdateBulk(AbstractBulk):
    """Bulk of ``(selector, update document)`` pairs sent with bulk ``update``."""
    
    def update(self, find, value):
        """Queue ``value`` as the update for documents matching ``find``."""
        pending = (find, value)
        self._ops.append(pending)
    
    
    def set_all(self, values):
        """Queue one update per document of ``values``, selected by its ``_id``.

        NOTE(review): each whole document is used as the update document.
        pymongo's bulk ``update()`` expects ``$``-operator documents, so
        confirm that callers of this method supply them.
        """
        queued = []
        for document in values:
            queued.append(({'_id': document['_id']}, document))
        self._ops = queued
    
    
    def _stack_action(self, bulk_ops):
        """Add one bulk ``update`` per queued pair."""
        for selector, update_doc in self._ops:
            bulk_ops.find(selector).update(update_doc)


class UpsertBulk(UpdateBulk):
    """Update bulk in upsert mode: each queued value is applied with ``$set``.

    Documents that match no selector are created by the upsert.
    """
    
    def _stack_action(self, bulk_ops):
        """Add one upserting ``$set`` update per queued pair."""
        for selector, fields in self._ops:
            upsert_op = bulk_ops.find(selector).upsert()
            upsert_op.update({'$set': fields})


class ReplaceBulk(AbstractBulk):
    """Bulk of whole-document replacements, one ``replace_one`` per queued pair."""
    
    def replace(self, find, value):
        """Queue ``value`` as the replacement for the document matching ``find``."""
        pending = (find, value)
        self._ops.append(pending)
    
    
    def set_all(self, values):
        """Queue one replacement per document of ``values``, selected by its ``_id``."""
        queued = []
        for document in values:
            queued.append(({'_id': document['_id']}, document))
        self._ops = queued
    
    
    def _stack_action(self, bulk_ops):
        """Add one bulk ``replace_one`` per queued pair."""
        for selector, replacement in self._ops:
            bulk_ops.find(selector).replace_one(replacement)


class InsertBulk(AbstractBulk):
    """Bulk of document inserts that can resume after an AutoReconnect.

    Every queued document must carry its own ``_id``. When the connection
    drops during the insert, some documents may already have been written.
    The retry therefore first asks the collection which ``_id`` values are
    present and only re-sends the missing documents.
    """
    
    # Not used here: execute() is overridden and builds its own bulk.
    def _stack_action(self, bulk_ops):
        pass
    
    
    def insert(self, value):
        """Queue one document for insertion (it must contain an ``_id``)."""
        self._ops.append(value)
    
    
    def set_all(self, values):
        """Replace the queued documents with the documents in ``values``.

        ``values`` is copied into a fresh list, so a later insert() never
        appends to the caller's own list. Any iterable is accepted.
        """
        self._ops = list(values)
    
    
    def _unsafe_insert_many(self, documents):
        """Insert ``documents`` in a single bulk, without any retry.

        :raises MongoDriverException: if a document has no ``_id``. This is
            raised before the bulk is executed.
        Errors raised by the bulk execution (e.g. AutoReconnect) propagate.
        :return: the result of the bulk ``execute()``.
        """
        if self._order:
            bulk_ops = self._collection.initialize_ordered_bulk_op()
        else:
            bulk_ops = self._collection.initialize_unordered_bulk_op()
        
        for document in documents:
            if '_id' not in document:
                raise MongoDriverException('_id must be define in all document to call insert_many')
            bulk_ops.insert(document)
        
        return bulk_ops.execute()
    
    
    def _filter_already_insert(self):
        """Return the queued documents whose ``_id`` is not yet in the collection."""
        ids = [document['_id'] for document in self._ops]
        # find() yields documents, not ids. Fetch only _id and collect the
        # values, so the membership test below compares ids with ids.
        cursor = self._collection.find({'_id': {'$in': ids}}, {'_id': 1})
        already_in = set(found['_id'] for found in cursor)
        return [document for document in self._ops if document['_id'] not in already_in]
    
    
    def _resume_insert_many(self):
        """Retry the insert after execute() got an AutoReconnect.

        Attempt 1 was the plain insert made by execute(). This method runs
        attempts 2 .. ``auto_reconnect_max_try``. Each attempt first drops the
        documents that already reached the collection.

        :return: the bulk result of the successful attempt, or None when every
            queued document was already present (nothing left to send).
        :raises AutoReconnect: when the last allowed attempt still fails.
        """
        max_try = self._auto_reconnect_max_try
        operation_name = 'filter_already_insert in %s' % self._collection_name
        
        for attempt in range(2, max_try + 1):
            try:
                to_retry = self._filter_already_insert()
                if not to_retry:
                    # Everything was written before the connection dropped.
                    # pymongo refuses to execute an empty bulk (InvalidOperation).
                    return None
                operation_name = 'insert_many in %s (retry on %s/%s element)' % (self._collection_name, len(to_retry), len(self._ops))
                return self._unsafe_insert_many(to_retry)
            except AutoReconnect:
                if attempt >= max_try:
                    self.logger.info('[MONGO] ' 'Mongo ask a AutoReconnect on the operation %s and we try %s time but it keep failing' % (operation_name, max_try))
                    raise
                self.logger.info('[MONGO] ' 'Mongo ask a AutoReconnect on the operation %s retrying %s/%s' % (operation_name, attempt, max_try))
                time.sleep(self._auto_reconnect_sleep_between_try)
        return None
    
    
    def execute(self):
        """Insert every queued document and empty the queue on success.

        On AutoReconnect the insert is resumed by _resume_insert_many(). If
        no retry is allowed (``auto_reconnect_max_try <= 1``), the original
        error is re-raised. In both failure cases the queue is kept, so no
        document is silently dropped.
        """
        try:
            bulk_result = self._unsafe_insert_many(self._ops)
        except AutoReconnect:
            if self._auto_reconnect_max_try <= 1:
                raise
            bulk_result = self._resume_insert_many()
        self._ops = []
        return bulk_result
