#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2013-2022:
# This file is part of Shinken Enterprise, all rights reserved.
import json
import time

from shinken.misc.fast_copy import fast_deepcopy
from shinken.misc.type_hint import TYPE_CHECKING, cast
from shinken.objects.module import Module as ShinkenModuleDefinition
from shinken.toolbox.pickledb import ShinkenPickleableMeta
from shinkensolutions.data_hub.data_hub_driver.abstract_data_hub_driver_database import AbstractDataHubDriverDatabase, AbstractDataHubDriverConfigDatabase
from shinkensolutions.data_hub.data_hub_exception.data_hub_exception import DataHubItemNotFound, DataHubException
from shinkensolutions.data_hub.data_hub_factory.data_hub_factory import DataHubFactory
from shinkensolutions.ssh_mongodb.mongo_client import MongoClient
from shinkensolutions.ssh_mongodb.mongo_conf import MongoConf
from shinkensolutions.context_helper import nullcontext

if TYPE_CHECKING:
    from shinken.misc.type_hint import Optional, Dict, Any, Tuple, List, ContextManager
    from shinken.log import PartLogger
    from shinkensolutions.ssh_mongodb.mongo_collection import MongoCollection
    from shinkensolutions.data_hub.data_hub_driver.abstract_data_hub_driver import AbstractDataHubDriverConfig
    from shinkensolutions.data_hub.data_hub import DataHubConfig


class DATA_KEY:
    """Names of the document fields the MongoDB data hub driver reads and writes."""
    # Primary key of the stored document; it holds the data hub data ID
    DATA_ID = '_id'
    # Timestamp from time.time(), refreshed on every write
    LAST_MODIFICATION_DATE = 'data_hub_last_modification_date'


class DataHubMongoWrongId(DataHubException):
    """Raised when data saved under one ID already carries a different ``_id`` in its content."""
    
    def __init__(self, data_id='', saved_id='', message=''):
        # An explicit (non-empty) message wins; otherwise build the default one from both IDs
        message = message or ("Can't save data with ID [%s] because it has another ID [%s] in its data" % (data_id, saved_id))
        super().__init__(message)


class DataHubDriverConfigMongo(AbstractDataHubDriverConfigDatabase, metaclass=ShinkenPickleableMeta):
    """
    Configuration of a 'MONGO' data hub driver.
    
    It only stores plain values: the MongoDB connection settings, the data location
    (used with the namespace to name the collection) and identification of the owner
    (daemon / module / submodule names).
    """
    def __init__(self,
                 data_location_name,
                 database,
                 uri,
                 replica_set,
                 use_ssh_tunnel,
                 ssh_user,
                 ssh_keyfile,
                 ssh_tunnel_timeout,
                 auto_reconnect_max_try,
                 auto_reconnect_sleep_between_try,
                 namespace='',
                 daemon_name='',
                 module_name='',
                 submodule_name='',
                 write_concerns=None):
        # type: (str, str, str, str, bool, str, str, int, int, int, str, str, str, str, Optional[int]) -> None
        """
        :param data_location_name: second part of the collection name (``<namespace>__<data_location_name>``).
        :param database: MongoDB database name.
        :param uri: MongoDB connection URI.
        :param replica_set: replica set name.
        :param use_ssh_tunnel: whether the MongoDB connection goes through an SSH tunnel.
        :param ssh_user: SSH user for the tunnel.
        :param ssh_keyfile: SSH key file for the tunnel.
        :param ssh_tunnel_timeout: SSH tunnel timeout.
        :param auto_reconnect_max_try: maximum number of reconnection attempts.
        :param auto_reconnect_sleep_between_try: pause between two reconnection attempts.
        :param namespace: forwarded to the base config together with the driver type 'MONGO'.
        :param daemon_name: name of the daemon owning the data.
        :param module_name: name of the module owning the data.
        :param submodule_name: name of the submodule owning the data.
        :param write_concerns: MongoDB ``w`` value used when saving/removing; None keeps the driver default.
        """
        super(DataHubDriverConfigMongo, self).__init__('MONGO', namespace)
        
        self.data_location_name = data_location_name
        self.database = database
        self.uri = uri
        self.replica_set = replica_set
        self.use_ssh_tunnel = use_ssh_tunnel
        self.ssh_user = ssh_user
        self.ssh_keyfile = ssh_keyfile
        self.ssh_tunnel_timeout = ssh_tunnel_timeout
        self.auto_reconnect_max_try = auto_reconnect_max_try
        self.auto_reconnect_sleep_between_try = auto_reconnect_sleep_between_try
        self.daemon_name = daemon_name
        self.module_name = module_name
        self.submodule_name = submodule_name
        self.write_concerns = write_concerns


def data_hub_driver_mongo_factory(logger, driver_config, _data_hub_config):
    # type: (PartLogger, AbstractDataHubDriverConfig, DataHubConfig) -> DataHubDriverMongo
    """
    Build a DataHubDriverMongo from a generic driver config.
    
    The data hub config argument is accepted for the factory signature but not used.
    """
    mongo_config = cast(DataHubDriverConfigMongo, driver_config)  # narrowing for the type checker only
    return DataHubDriverMongo(logger, mongo_config)


# Module import side effect: associate DataHubDriverConfigMongo with its factory in DataHubFactory
DataHubFactory.register_driver_factory(DataHubDriverConfigMongo, data_hub_driver_mongo_factory)


class DataHubDriverMongo(AbstractDataHubDriverDatabase, MongoClient):
    """
    Data hub driver storing each data item as one document of a MongoDB collection.
    
    The collection is named ``<namespace>__<data_location_name>``. Each document uses the
    data ID as ``_id`` and receives a ``data_hub_last_modification_date`` field (time.time())
    on write. Both fields are stripped from the data returned by the read/find methods.
    """
    
    def __init__(self, logger, driver_config):
        # type: (PartLogger, DataHubDriverConfigMongo) -> None
        
        # MongoClient takes its settings as a MongoConf built from a Shinken module definition
        configuration = MongoConf(ShinkenModuleDefinition({
            'database'                        : driver_config.database,
            'uri'                             : driver_config.uri,
            'replica_set'                     : driver_config.replica_set,
            'use_ssh_tunnel'                  : driver_config.use_ssh_tunnel,
            'ssh_user'                        : driver_config.ssh_user,
            'ssh_keyfile'                     : driver_config.ssh_keyfile,
            'ssh_tunnel_timeout'              : driver_config.ssh_tunnel_timeout,
            'auto_reconnect_max_try'          : driver_config.auto_reconnect_max_try,
            'auto_reconnect_sleep_between_try': driver_config.auto_reconnect_sleep_between_try,
        }))
        
        MongoClient.__init__(self, configuration, logger=logger, log_database_parameters=True)
        AbstractDataHubDriverDatabase.__init__(self, logger, driver_config)
        
        # logger_init is defined in MongoClient
        self.logger_init = self._logger_init
        
        self._data_location_name = driver_config.data_location_name
        self._daemon_name = driver_config.daemon_name
        self._module_name = driver_config.module_name
        self._submodule_name = driver_config.submodule_name
        # getattr: presumably tolerates config objects created without this attribute — confirm
        self._write_concerns = getattr(driver_config, 'write_concerns', None)
        self._collection = None  # type: Optional[MongoCollection]
    
    
    def init(self):
        # type: () -> None
        """Connect to MongoDB and bind the collection used by this driver."""
        MongoClient.init(self, requester='data_hub %s' % self._name)
        self._collection = self.get_collection(self._compute_collection_name())
    
    
    def lock_context(self, data_id):
        # type: (str) -> ContextManager[Any]
        """Return a no-op context manager: this driver provides no locking."""
        # There is no API to lock mongo
        return nullcontext()
    
    
    def _compute_collection_name(self):
        # type: () -> str
        """Collection name: ``<namespace>__<data_location_name>``."""
        return '__'.join([self._namespace, self._data_location_name])
    
    
    def write(self, data_id, data):
        # type: (str, Dict[str, Any]) -> None
        """
        Insert or replace the document stored under ``data_id``.
        
        :param data_id: ID of the data, stored as the document ``_id``.
        :param data: data to store; it is deep-copied, so the caller's dict is not modified.
        :raises DataHubMongoWrongId: if ``data`` already holds a different (truthy) ``_id``.
        """
        # Work on a copy: the _id and the last modification date are added to the stored dict
        data = fast_deepcopy(data)
        actual_id = data.get(DATA_KEY.DATA_ID, None)
        if actual_id and actual_id != data_id:
            raise DataHubMongoWrongId(data_id, actual_id)
        data[DATA_KEY.DATA_ID] = data_id
        data[DATA_KEY.LAST_MODIFICATION_DATE] = time.time()
        # Fall back to w=1 when no write concern is configured
        write_concern = self._write_concerns if self._write_concerns is not None else 1
        self._collection.save(data, w=write_concern)
    
    
    def write_raw(self, data_id, data_raw):
        # type: (str, str) -> None
        """Write data given as a JSON string (see write)."""
        self.write(data_id, json.loads(data_raw))
    
    
    def remove(self, data_id):
        # type: (str) -> None
        """Delete the document stored under ``data_id``."""
        query = {DATA_KEY.DATA_ID: data_id}
        if self._write_concerns is None:
            # No w argument: the collection's own default write concern applies
            self._collection.remove(query)
        else:
            self._collection.remove(query, w=self._write_concerns)
    
    
    def get_last_modification_date(self, data_id):
        # type: (str) -> float
        """
        Return the time.time() value saved by the last write of ``data_id`` (0 if the field is absent).
        
        :raises DataHubItemNotFound: if no document has this ID.
        """
        data = self._collection.find_one({DATA_KEY.DATA_ID: data_id}, {DATA_KEY.DATA_ID: 1, DATA_KEY.LAST_MODIFICATION_DATE: 1})
        if data is None:
            raise DataHubItemNotFound(self._data_type, data_id)
        return data.get(DATA_KEY.LAST_MODIFICATION_DATE, 0)
    
    
    def read_and_get_last_modification_date(self, data_id, log_error=True):
        # type: (str, bool) -> Tuple[Any, float]
        """
        Return ``(data, last_modification_date)`` for ``data_id`` (date defaults to 0).
        
        :raises DataHubItemNotFound: if no document has this ID.
        """
        data = self._collection.find_one({DATA_KEY.DATA_ID: data_id}, {DATA_KEY.DATA_ID: 0})
        if data is None:
            raise DataHubItemNotFound(self._data_type, data_id)
        last_modification_date = data.pop(DATA_KEY.LAST_MODIFICATION_DATE, 0)
        return data, last_modification_date
    
    
    def read(self, data_id, log_error=True):
        # type: (str, bool) -> Any
        """
        Return the data stored under ``data_id``, without the ``_id`` and date fields.
        
        :raises DataHubItemNotFound: if no document has this ID.
        """
        data = self._collection.find_one({DATA_KEY.DATA_ID: data_id}, {DATA_KEY.DATA_ID: 0, DATA_KEY.LAST_MODIFICATION_DATE: 0})
        # Test for None, not falsiness: data saved as {} comes back as an empty dict once
        # both internal fields are projected out, and it still exists.
        if data is None:
            raise DataHubItemNotFound(self._data_type, data_id)
        return data
    
    
    def read_all(self, data_id_list, log_error=True):
        # type: (List[str], bool) -> List[Any]
        """Return the data of every stored ID in ``data_id_list`` (missing IDs are skipped)."""
        return self._fetch_all_raw_data_from_mongo({DATA_KEY.DATA_ID: {'$in': list(data_id_list)}})
    
    
    def read_all_with_id(self, data_id_list, log_error=True):
        # type: (List[str], bool) -> List[Tuple[str, Any]]
        """Return ``(data_id, data)`` pairs for every stored ID in ``data_id_list``."""
        return self._fetch_all_raw_data_from_mongo_with_their_id({DATA_KEY.DATA_ID: {'$in': list(data_id_list)}})
    
    
    def read_raw(self, data_id, log_error=True):
        # type: (str, bool) -> str
        """Return the data of ``data_id`` serialized as a JSON string (see read)."""
        return json.dumps(self.read(data_id, log_error=log_error))
    
    
    def get_all_data_id(self):
        # type: () -> List[str]
        """Return the IDs of all stored documents."""
        return self.find_data_id({})
    
    
    def get_all_data(self, log_error=True):
        # type: (bool) -> List[Any]
        """Return the data of all stored documents."""
        return self._fetch_all_raw_data_from_mongo({})
    
    
    def get_all_data_with_id(self, log_error=True):
        # type: (bool) -> List[Tuple[str, Any]]
        """Return ``(data_id, data)`` pairs for all stored documents."""
        return self._fetch_all_raw_data_from_mongo_with_their_id({})
    
    
    def find_data_id(self, filters):
        # type: (Any) -> List[str]
        """Return the IDs of the documents matching the MongoDB query ``filters``."""
        return [data[DATA_KEY.DATA_ID] for data in self._collection.find(filters, {DATA_KEY.DATA_ID: 1})]
    
    
    def find_data_with_id(self, filters, log_error=True):
        # type: (Any, bool) -> List[Tuple[str, Any]]
        """Return ``(data_id, data)`` pairs for the documents matching ``filters``."""
        return self._fetch_all_raw_data_from_mongo_with_their_id(filters)
    
    
    def find_data(self, filters, log_error=True):
        # type: (Any, bool) -> List[Any]
        """Return the data of the documents matching ``filters``."""
        return self._fetch_all_raw_data_from_mongo(filters)
    
    
    def _fetch_all_raw_data_from_mongo(self, filters):
        # type: (Any) -> List[Any]
        """Return the matching documents without their ``_id`` and date fields, as a list."""
        # Materialize the result so callers get a real list (as declared and as the sibling
        # helpers return) rather than a single-use cursor.
        return list(self._collection.find(filters, {DATA_KEY.DATA_ID: 0, DATA_KEY.LAST_MODIFICATION_DATE: 0}))
    
    
    def _fetch_all_raw_data_from_mongo_with_their_id(self, filters):
        # type: (Any) -> List[Tuple[str, Any]]
        """Return ``(data_id, data)`` pairs; ``_id`` is popped out of each returned document."""
        # The tuple is built left to right: pop() removes _id before the dict itself is taken
        return [(data.pop(DATA_KEY.DATA_ID), data) for data in self._collection.find(filters, {DATA_KEY.LAST_MODIFICATION_DATE: 0})]
    
    
    def is_data_correct(self, data_id):
        # type: (str) -> bool
        """
        Return False if the data stored under ``data_id`` is empty, True otherwise.
        
        :raises DataHubItemNotFound: propagated from read when no document has this ID.
        """
        if not self.read(data_id, log_error=False):
            return False
        return True
    
    
    def destroy(self):
        # type: () -> None
        """Drop the whole collection, then disconnect."""
        self._collection.drop()
        self.stop()
    
    
    def stop(self):
        # type: () -> None
        """Disconnect from MongoDB."""
        self.disconnect()
    
    
    def get_number_of_stored_data(self):
        # type: () -> int
        """Return the number of documents in the collection."""
        return self._collection.count()
    
    
    def get_total_size(self):
        # Not supported by this driver
        raise NotImplementedError
    
    
    def get_size_of(self, data_id):
        # Not supported by this driver
        raise NotImplementedError
