#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2013-2022
# This file is part of Shinken Enterprise, all rights reserved.

import os
import time

from .mongo_collection import MongoCollection
from .mongo_retry import retry_on_auto_reconnect
from shinken.log import LoggerFactory, PART_INITIALISATION, PartLogger
from shinken.misc.type_hint import TYPE_CHECKING
from shinkensolutions.ssh_mongodb.mongo_pid_changed import raise_if_pid_changed
from shinkensolutions.ssh_mongodb.sshtunnelmongomgr import mongo_by_ssh_mgr

if TYPE_CHECKING:
    from shinken.log import PartLogger
    from shinken.misc.type_hint import List, Number, Union, Optional, Dict
    from pymongo import MongoReplicaSetClient as _PyMongoReplicaSetClient, MongoClient as _PyMongoClient
    from pymongo.database import Database
    from shinkensolutions.ssh_mongodb.mongo_conf import MongoConf
    from shinkensolutions.ssh_mongodb.sshtunnelmongomgr import ConnectionResult


class MongoClient(object):
    """
    Wrapper around a pymongo client bound to a single database.
    
    The connection is obtained through ``mongo_by_ssh_mgr`` (optionally through an SSH tunnel) and is NOT opened by
    ``__init__``: call ``init()`` first, or ``recreate_connection()`` after a fork.
    Most database operations are wrapped by ``raise_if_pid_changed`` and ``retry_on_auto_reconnect``.
    """
    
    def __init__(self, conf, logger=None, log_database_parameters=True):
        # type: (MongoConf, PartLogger, bool) -> None
        """
        :param conf: Mongo configuration; its connection settings are copied onto private attributes.
        :param logger: parent logger, a 'MONGO' sub part is derived from it. Defaults to LoggerFactory.get_logger().
        :param log_database_parameters: when True, init() logs the configuration values.
        """
        self._database = None  # type: Optional[Database]
        
        if logger:
            self.logger = logger  # type: PartLogger
        else:
            self.logger = LoggerFactory.get_logger()  # type: PartLogger
        
        # logger_init is derived from the parent logger before self.logger is replaced by its 'MONGO' sub part
        self.logger_init = self.logger.get_sub_part(PART_INITIALISATION).get_sub_part('MONGO')  # type: PartLogger
        self.logger = self.logger.get_sub_part('MONGO')  # type: PartLogger
        self.log_database_parameters = log_database_parameters  # type: bool
        
        self.conf = conf
        # Mongodb connection part
        self._name_database = self.conf.name_database
        self._uri = self.conf.uri
        self._replica_set = self.conf.replica_set
        self._use_ssh_tunnel = self.conf.use_ssh_tunnel
        self._ssh_keyfile = self.conf.ssh_keyfile
        self._ssh_user = self.conf.ssh_user
        self._use_ssh_retry_failure = self.conf.use_ssh_retry_failure
        self._ssh_tunnel_timeout = self.conf.ssh_tunnel_timeout
        self._auto_reconnect_sleep_between_try = self.conf.auto_reconnect_sleep_between_try
        self._auto_reconnect_max_try = self.conf.auto_reconnect_max_try
        self.connection_information = None  # type: Optional[ConnectionResult]
        self.pid_used_to_start_connection = None  # type: Optional[int]
        self.connected = False
        self._pymongo_client = None  # type: Optional[Union[_PyMongoClient, _PyMongoReplicaSetClient]]
    
    
    # Get a connection
    # * requester: string that identify who ask for this connection, used for logging and SSH process display
    def init(self, requester='unknown'):
        # type: (unicode) -> None
        """Open the connection for the first time, logging the requester and (optionally) the configuration."""
        if requester != 'unknown':
            self.logger_init.info('Creating connection to database [%s], requested by [ %s ]' % (self._name_database, requester))
        else:
            self.logger_init.info('Creating connection to database [%s]' % self._name_database)
        if self.log_database_parameters:
            self.conf.log_configuration(log_properties=True, show_values_as_in_conf_file=True)
        self.try_connection(requester)
    
    
    # Use this after a fork to open a new connection
    def recreate_connection(self, requester='unknown'):
        # type: (unicode) -> None
        """Close the current pymongo client (if any) and open a new connection."""
        if self._pymongo_client:
            self._pymongo_client.close()
        if requester != 'unknown':
            self.logger_init.info('Resetting connection to database [%s], requested by [ %s ]' % (self._name_database, requester))
        else:
            self.logger_init.info('Resetting connection to database [%s]' % self._name_database)
        self.try_connection(requester)
    
    
    def try_connection(self, requester, logger=None, verbose=True):
        # type: (unicode, PartLogger, bool) -> None
        """
        Connect to the database; when `verbose`, log the target before and the elapsed time after.
        
        :param requester: who asked for the connection (forwarded to do_connect()).
        :param logger: logger used for the messages, defaults to the initialisation logger.
        :param verbose: whether to log at all.
        """
        if logger is None:
            logger = self.logger_init
        time_start = time.time()
        if verbose:
            logger.info('Try to open a Mongodb connection to [ %s ] database [ %s ]' % (self._uri, self._name_database))
        self.do_connect(requester)
        if verbose:
            logger.info('Mongo connection established in %s' % self.logger.format_chrono(time_start))
    
    
    def do_connect(self, requester='unknown'):
        # type: (unicode) -> None
        """Open the connection and record the pymongo client, the database handle and the current process pid."""
        self.connection_information = self.get_connection(requester)
        
        self._pymongo_client = self.connection_information.get_connection()
        self._database = getattr(self._pymongo_client, self._name_database)
        # Remembered so that raise_if_pid_changed / MongoCollection can detect use of this connection after a fork
        self.pid_used_to_start_connection = os.getpid()
        self.connected = True
    
    
    def get_database_name(self):
        # type: () -> unicode
        """Return the name of the database this client is bound to."""
        return self._name_database
    
    
    # NOTE: Private method used only in do_connect()
    @retry_on_auto_reconnect(retry_connection_on_error=False)
    def get_connection(self, requester='unknown'):
        # type: (unicode) -> ConnectionResult
        """Ask mongo_by_ssh_mgr for a connection built from this client's configuration."""
        return mongo_by_ssh_mgr.get_connection(
            self._uri,
            self.logger_init,
            fsync=False,
            replica_set=self._replica_set,
            use_ssh=self._use_ssh_tunnel,
            ssh_keyfile=self._ssh_keyfile,
            ssh_user=self._ssh_user,
            ssh_retry=self._use_ssh_retry_failure,
            ssh_tunnel_timeout=self._ssh_tunnel_timeout,
            requestor=requester
        )
    
    
    def _retry_connection(self):
        # type: () -> Optional[Database]
        """
        Reconnection callback handed to MongoCollection.
        
        Reconnects only when the current connection reports the network as unreachable (or there is no connection yet),
        then returns the (possibly new) database handle.
        """
        if self.connection_information and self.connection_information.network_is_reachable():
            return self._database
        
        try:
            self.do_connect()
        except Exception:
            # Deliberately best-effort: on failure the previous database handle is returned and the caller's own
            # retry logic decides what to do next.
            pass
        
        return self._database
    
    
    def get_collection(self, collection_name):
        # type: (unicode) -> MongoCollection
        """Return a MongoCollection wrapper sharing this client's database, retry settings and reconnection callback."""
        return MongoCollection(collection_name, self._database, self._auto_reconnect_max_try, self._auto_reconnect_sleep_between_try, self._retry_connection, self.pid_used_to_start_connection, logger=self.logger)
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def list_name_collections(self, include_system_collections=True):
        # type: (bool) -> List
        """Return the names of the collections in the database."""
        return list(self._database.collection_names(include_system_collections))
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def _unsafe_rename_collection(self, old_name, new_name):
        # type: (unicode, unicode) -> None
        """Rename `old_name` to `new_name` without checking whether `new_name` already exists."""
        getattr(self._database, old_name).rename(new_name)
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def drop_collection(self, collection_name):
        # type: (unicode) -> None
        """Drop the collection `collection_name`."""
        getattr(self._database, collection_name).drop()
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def drop_database(self):
        # type: () -> None
        """Drop the whole database this client is bound to."""
        self.connection_information.get_connection().drop_database(self._name_database)
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def repair_database(self):
        # type: () -> None
        """Run repairDatabase on the database (logged beforehand since it can take a long time)."""
        self.logger.info('Start repair/defragging database this operation may be long')
        self._database.repairDatabase()
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def get_size_of_database(self):
        # type: () -> Number
        """Return the 'storageSize' field reported by the 'dbstats' command."""
        col_stats = self._database.command('dbstats')
        total_size = col_stats['storageSize']
        return total_size
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def get_server_status(self):
        # type: () -> Dict
        """Return the result of the 'serverStatus' command."""
        return self._database.command('serverStatus')
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def get_server_info(self):
        # type: () -> Dict
        """Return the pymongo client's server_info()."""
        return self.connection_information.get_connection().server_info()
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def edit_coll_mod(self, collection_name, index=None, usePowerOf2Sizes=None):
        # type: (unicode, Optional[Number], Optional[bool]) -> None
        """Run 'collMod' on `collection_name`, passing only the options that were given (not None)."""
        kwargs_for_command = {}
        if usePowerOf2Sizes is not None:
            kwargs_for_command['usePowerOf2Sizes'] = usePowerOf2Sizes  # Will be deprecated with pymongo 4.2
        if index is not None:
            kwargs_for_command['index'] = index
        
        self._database.command('collMod', collection_name, **kwargs_for_command)
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def drop_all_collections(self):
        # type: () -> None
        """Drop every non-system collection of the database."""
        collections = list(self._database.collection_names(include_system_collections=False))
        for collection_name in collections:
            self._database.drop_collection(collection_name)
    
    
    def collection_exist(self, collection_name):
        # type: (unicode) -> bool
        """Return True if `collection_name` is one of the database's collections."""
        return collection_name in self.list_name_collections()
    
    
    def rename_collection(self, old_name, new_name):
        # type: (unicode, unicode) -> None
        """
        Rename collection `old_name` to `new_name`, replacing `new_name` if it already exists.
        
        Renaming a collection to its own name is a no-op.
        """
        if old_name == new_name:
            # Without this guard the existence check below would drop the very collection we are asked to rename
            return
        if self.collection_exist(new_name):
            # Drop the *target* to make room; dropping the source would discard the data being renamed
            self.drop_collection(new_name)
        self._unsafe_rename_collection(old_name, new_name)
    
    
    def after_fork_cleanup(self):
        # type: () -> None
        """
        Forget the loggers, the configuration copies and the connection handles inherited from the parent process.
        
        NOTE(review): `_pymongo_client`, `connected` and `pid_used_to_start_connection` are left untouched here.
        """
        self._database = None
        self.logger = None
        self.logger_init = None
        self._name_database = None
        self._uri = None
        self._replica_set = None
        self._use_ssh_tunnel = None
        self._ssh_keyfile = None
        self._ssh_user = None
        self._use_ssh_retry_failure = None
        self._ssh_tunnel_timeout = None
        self._auto_reconnect_sleep_between_try = None
        self._auto_reconnect_max_try = None
        self.connection_information = None
    
    
    def disconnect(self):
        # type: () -> None
        """Close the underlying pymongo client; does nothing if no connection was ever opened."""
        if self._pymongo_client is not None:
            self._pymongo_client.close()
    
    
    def set_logger(self, logger):
        # type: (PartLogger) -> None
        """Replace the operational logger by the 'MONGO' sub part of `logger` (the init logger is unchanged)."""
        self.logger = logger.get_sub_part('MONGO')
    
    
    def is_connection_available(self):
        # type: () -> bool
        """Return True if a 'serverStatus' command succeeds, False if it raises any ordinary exception."""
        try:
            self._database.command('serverStatus')
            return True
        except Exception:
            # Narrowed from a bare except so KeyboardInterrupt / SystemExit are no longer swallowed
            return False
    
    
    def send_command(self, _command):
        # type: (unicode) -> Optional[Dict]
        """Run an arbitrary database command and return its result."""
        return self._database.command(_command)
    
    
    def get_replica_set_status(self):
        # type: () -> Dict
        """Return the result of the 'replSetGetStatus' command."""
        return self._database.command('replSetGetStatus')
