#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2013-2022:
# This file is part of Shinken Enterprise, all rights reserved.

import time

from shinken.misc.type_hint import TYPE_CHECKING, cast
from shinken.objects.module import Module as ShinkenModuleDefinition
from shinkensolutions.data_hub.data_hub_driver.abstract_data_hub_driver_database import AbstractDataHubDriverDatabase, AbstractDataHubDriverConfigDatabase
from shinkensolutions.data_hub.data_hub_exception.data_hub_exception import DataHubItemNotFound, DataHubException
from shinkensolutions.data_hub.data_hub_factory.data_hub_factory import DataHubFactory
from shinkensolutions.ssh_mongodb.mongo_client import MongoClient
from shinkensolutions.ssh_mongodb.mongo_conf import MongoConf

if TYPE_CHECKING:
    from shinken.misc.type_hint import Optional, Dict, Any, Tuple, List
    from shinken.log import PartLogger
    from shinkensolutions.ssh_mongodb.mongo_collection import MongoCollection
    from shinkensolutions.data_hub.data_hub_driver.abstract_data_hub_driver import AbstractDataHubDriverConfig
    from shinkensolutions.data_hub.data_hub import DataHubConfig


class DATA_KEY(object):
    LAST_MODIFICATION_DATE = u'data_hub_last_modification_date'
    DATA_ID = u'_id'


class DataHubMongoWrongId(DataHubException):
    def __init__(self, data_id=u'', saved_id=u'', message=u''):
        if not message:
            message = u'Can\'t save data with ID [%s] because it has another ID [%s] in its data' % (data_id, saved_id)
        super(DataHubMongoWrongId, self).__init__(message)


class DataHubDriverConfigMongo(AbstractDataHubDriverConfigDatabase):
    def __init__(self,
                 data_location_name,
                 database,
                 uri,
                 replica_set,
                 use_ssh_tunnel,
                 ssh_user,
                 ssh_keyfile,
                 ssh_tunnel_timeout,
                 auto_reconnect_max_try,
                 auto_reconnect_sleep_between_try,
                 namespace=u'',
                 daemon_name=u'',
                 module_name=u'',
                 submodule_name=u'',
                 write_concerns=None):
        # type: (unicode, unicode, unicode, unicode, bool, unicode, unicode, int, int, int, unicode, unicode, unicode, unicode, int) -> None
        super(DataHubDriverConfigMongo, self).__init__(u'MONGO', namespace)
        
        self.data_location_name = data_location_name
        self.database = database
        self.uri = uri
        self.replica_set = replica_set
        self.use_ssh_tunnel = use_ssh_tunnel
        self.ssh_user = ssh_user
        self.ssh_keyfile = ssh_keyfile
        self.ssh_tunnel_timeout = ssh_tunnel_timeout
        self.auto_reconnect_max_try = auto_reconnect_max_try
        self.auto_reconnect_sleep_between_try = auto_reconnect_sleep_between_try
        self.daemon_name = daemon_name
        self.module_name = module_name
        self.submodule_name = submodule_name
        self.write_concerns = write_concerns


def data_hub_driver_mongo_factory(logger, driver_config, _data_hub_config):
    # type: (PartLogger, AbstractDataHubDriverConfig, DataHubConfig) -> DataHubDriverMongo
    return DataHubDriverMongo(logger, cast(DataHubDriverConfigMongo, driver_config))


DataHubFactory.register_driver_factory(DataHubDriverConfigMongo, data_hub_driver_mongo_factory)


class DataHubDriverMongo(AbstractDataHubDriverDatabase, MongoClient):
    def __init__(self, logger, driver_config):
        # type: (PartLogger, DataHubDriverConfigMongo) -> None
        
        configuration = MongoConf(ShinkenModuleDefinition({
            u'database'                        : driver_config.database,
            u'uri'                             : driver_config.uri,
            u'replica_set'                     : driver_config.replica_set,
            u'use_ssh_tunnel'                  : driver_config.use_ssh_tunnel,
            u'ssh_user'                        : driver_config.ssh_user,
            u'ssh_keyfile'                     : driver_config.ssh_keyfile,
            u'ssh_tunnel_timeout'              : driver_config.ssh_tunnel_timeout,
            u'auto_reconnect_max_try'          : driver_config.auto_reconnect_max_try,
            u'auto_reconnect_sleep_between_try': driver_config.auto_reconnect_sleep_between_try,
        }))
        
        MongoClient.__init__(self, configuration, logger=logger, log_database_parameters=True)
        AbstractDataHubDriverDatabase.__init__(self, logger, driver_config)
        
        # logger_init is define in MongoClient
        self.logger_init = self._logger_init
        
        self._data_location_name = driver_config.data_location_name
        self._daemon_name = driver_config.daemon_name
        self._module_name = driver_config.module_name
        self._submodule_name = driver_config.submodule_name
        self._write_concerns = getattr(driver_config, u'write_concerns', None)
        self._collection = None  # type: Optional[MongoCollection]
    
    
    def init(self):
        # type: () -> None
        MongoClient.init(self, requester=u'data_hub %s' % self._name)
        self._collection = self.get_collection(self._compute_collection_name())
    
    
    def _compute_collection_name(self):
        # type: () -> unicode
        return u'__'.join([self._namespace, self._data_location_name])
    
    
    def write(self, data_id, data):
        # type: (unicode, Dict[unicode, Any]) -> None
        data = data.copy()  # We need to add the data_hub_last_modification_date, so we create a copy of the dict to not edit it
        actual_id = data.get(DATA_KEY.DATA_ID, None)
        if actual_id and actual_id != data_id:
            raise DataHubMongoWrongId(data_id, actual_id)
        data[DATA_KEY.DATA_ID] = data_id
        data[DATA_KEY.LAST_MODIFICATION_DATE] = time.time()
        if self._write_concerns is not None:
            self._collection.save(data, w=self._write_concerns)
        else:
            self._collection.save(data, w=1)
    
    
    def remove(self, data_id):
        # type: (unicode) -> None
        if self._write_concerns is not None:
            self._collection.remove({DATA_KEY.DATA_ID: data_id}, w=self._write_concerns)
        else:
            self._collection.remove({DATA_KEY.DATA_ID: data_id})
    
    
    def get_last_modification_date(self, data_id):
        # type: (unicode) -> int
        data = self._collection.find_one({DATA_KEY.DATA_ID: data_id})
        if not data:
            raise DataHubItemNotFound(self._data_type, data_id)
        return data.get(DATA_KEY.LAST_MODIFICATION_DATE, 0)
    
    
    def read_and_get_last_modification_date(self, data_id, log_error=True):
        # type: (unicode, bool) -> Tuple[Any, int]
        
        data = self._collection.find_one({DATA_KEY.DATA_ID: data_id})
        if data:
            last_modification_date = data.pop(DATA_KEY.LAST_MODIFICATION_DATE, 0)
        else:
            raise DataHubItemNotFound(self._data_type, data_id)
        return data, last_modification_date
    
    
    def read(self, data_id, log_error=True):
        # type: (unicode, bool) -> Any
        data = self._collection.find_one({DATA_KEY.DATA_ID: data_id})
        if data:
            data.pop(DATA_KEY.LAST_MODIFICATION_DATE, 0)
        if not data:
            raise DataHubItemNotFound(self._data_type, data_id)
        return data
    
    
    def get_all_data_id(self):
        # type: () -> List[unicode]
        return [data[DATA_KEY.DATA_ID] for data in self._collection.find({})]
    
    
    def is_data_correct(self, data_id):
        # type: (unicode) -> bool
        if not self.read(data_id, log_error=False):
            return False
        return True
    
    
    def destroy(self):
        # type: () -> None
        self._collection.drop()
    
    
    def get_number_of_stored_data(self):
        # type: () -> int
        return self._collection.count()
    
    
    def get_total_size(self):
        raise NotImplementedError
