#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2013-2024
# This file is part of Shinken Enterprise, all rights reserved.

import os
import time

from shinken.log import LoggerFactory, PART_INITIALISATION, PartLogger
from shinken.misc.type_hint import TYPE_CHECKING
from shinkensolutions.ssh_mongodb.mongo_collection import MongoCollection
from shinkensolutions.ssh_mongodb.mongo_pid_changed import raise_if_pid_changed
from shinkensolutions.ssh_mongodb.mongo_retry import retry_on_auto_reconnect
from shinkensolutions.ssh_mongodb.sshtunnelmongomgr import mongo_by_ssh_mgr

if TYPE_CHECKING:
    from shinken.log import PartLogger
    from shinken.misc.type_hint import List, Number, Optional, Dict
    from pymongo.database import Database
    from shinkensolutions.ssh_mongodb.mongo_conf import MongoConf
    from shinkensolutions.ssh_mongodb.sshtunnelmongomgr import ConnectionResult


class MongoClient:
    def __init__(self, conf: 'Optional[MongoConf]' = None, logger: 'PartLogger | None ' = None, log_database_parameters: bool = True, for_command_line: bool = False) -> None:
        """Build an unconnected client; this constructor itself does not open any connection.

        * conf: connection settings; may be left to None here and given to init() instead
        * logger: parent logger; when falsy, the one returned by LoggerFactory.get_logger() is used
        * log_database_parameters: stored as-is; init() reads it to decide whether to log connection details
        * for_command_line: stored as-is; when true, mongo_by_ssh_mgr.disable_data_hub() is called
        """
        base_logger = logger if logger else LoggerFactory.get_logger()
        
        # Two sub-loggers derived from the same parent: one tagged with the initialisation part, one for regular use.
        self.logger_init = base_logger.get_sub_part(PART_INITIALISATION).get_sub_part('MONGO')  # type: PartLogger
        self.logger = base_logger.get_sub_part('MONGO')  # type: PartLogger
        self.log_database_parameters = log_database_parameters  # type: bool
        self.for_command_line = for_command_line
        
        self.conf: 'Optional[MongoConf]' = conf
        
        # Connection state, filled by do_connect()
        self._database = None  # type: Optional[Database]
        self.connection_information = None  # type: Optional[ConnectionResult]
        self.pid_used_to_start_connection = None  # type: Optional[int]
        self.connected = False
        self.requester: str = ''
        
        # Connection parameters: placeholders until _read_configuration() copies them from conf
        self._name_database: str = ''
        self._uri: str = ''
        self._replica_set: str = ''
        self._use_ssh_tunnel: int = 0
        self._ssh_keyfile: str = ''
        self._ssh_user: str = ''
        self._use_ssh_retry_failure: int = 0
        self._ssh_tunnel_timeout: int = 0
        self._auto_reconnect_sleep_between_try: int = 0
        self._auto_reconnect_max_try: int = 0
        
        self._read_configuration()
        
        if for_command_line:
            mongo_by_ssh_mgr.disable_data_hub()
    
    
    # Get a connection
    # * requester: string that identifies who asks for this connection, used for logging and SSH process display
    def init(self, requester: str = 'unknown', *, conf: 'Optional[MongoConf]' = None) -> None:
        if self.conf is None:
            if conf is None:
                raise ValueError('MongoClient is not configured. Please set "conf" parameter either in constructor or in init() method.')
            self.conf = conf
        elif conf is not None:
            raise ValueError('"conf" parameter have been given in constructor AND in init() method. Even if it is the same, please choose one option.')
        
        del conf  # At this point, we use self.conf, for consistency.
        
        self._read_configuration()
        
        if requester != 'unknown' or self.requester:
            if requester != 'unknown' and requester:
                self.requester = requester
            else:
                requester = self.requester
            if self.log_database_parameters:
                self.logger_init.info('Creating connection to database [%s], requested by [ %s ]' % (self._name_database, requester))
        elif self.log_database_parameters:
            self.logger_init.info('Creating connection to database [%s]' % self._name_database)
        if self.log_database_parameters:
            self.conf.log_configuration(log_properties=True, show_values_as_in_conf_file=True)
        self.try_connection(requester, verbose=self.log_database_parameters)
    
    
    def _read_configuration(self):
        
        if not self.conf:
            return
        
        self._name_database = self.conf.name_database
        self._uri = self.conf.uri
        self._replica_set = self.conf.replica_set
        self._use_ssh_tunnel = self.conf.use_ssh_tunnel
        self._ssh_keyfile = self.conf.ssh_keyfile
        self._ssh_user = self.conf.ssh_user
        self._use_ssh_retry_failure = self.conf.use_ssh_retry_failure
        self._ssh_tunnel_timeout = self.conf.ssh_tunnel_timeout
        self._auto_reconnect_sleep_between_try = self.conf.auto_reconnect_sleep_between_try
        self._auto_reconnect_max_try = self.conf.auto_reconnect_max_try
    
    
    # Use this after a fork to open a new connection
    def recreate_connection(self, requester='unknown'):
        # type: (str) -> None
        self.connection_information.disconnect()
        if requester != 'unknown' and requester:
            self.requester = requester
        elif self.requester:
            requester = self.requester
        if requester != 'unknown':
            self.logger_init.info('Resetting connection to database [%s], requested by [ %s ]' % (self._name_database, requester))
        else:
            self.logger_init.info('Resetting connection to database [%s]' % self._name_database)
        self.try_connection(requester)
    
    
    def try_connection(self, requester, logger=None, verbose=True):
        # type: (str, PartLogger, bool) -> None
        if logger is None:
            logger = self.logger_init
        time_start = time.time()
        if verbose:
            logger.info('Try to open a Mongodb connection to [ %s ] database [ %s ]' % (self._uri, self._name_database))
        self.do_connect(requester)
        if verbose:
            logger.info('Mongo connection established in %s' % self.logger.format_chrono(time_start))
    
    
    def do_connect(self, requester='unknown'):
        # type: (str) -> None
        if requester != 'unknown' and requester:
            self.requester = requester
        elif self.requester:
            requester = self.requester
        self.connection_information = self.get_connection(requester)
        
        self._database = getattr(self.connection_information.get_connection(), self._name_database)
        self.pid_used_to_start_connection = os.getpid()
        self.connected = True
    
    
    def is_connection_closed(self) -> bool:
        return not self.connection_information or self.connection_information.is_disconnected()
    
    
    def get_database_name(self):
        return self._name_database
    
    
    # NOTE: Private method used only in do_connect()
    @retry_on_auto_reconnect(is_connecting=True)
    def get_connection(self, requester='unknown'):
        # type: (str) -> ConnectionResult
        """Ask mongo_by_ssh_mgr for a connection built from the settings held by this client.

        * requester: remembered on the instance when it is a real name (neither empty nor 'unknown');
          the remembered value is the one handed over to the manager
        """
        if requester and requester != 'unknown':
            self.requester = requester
        
        connection_options = dict(
            fsync=False,
            replica_set=self._replica_set,
            use_ssh=self._use_ssh_tunnel,
            ssh_keyfile=self._ssh_keyfile,
            ssh_user=self._ssh_user,
            ssh_retry=self._use_ssh_retry_failure,
            ssh_tunnel_timeout=self._ssh_tunnel_timeout,
            requestor=self.requester,
        )
        return mongo_by_ssh_mgr.get_connection(self._uri, self.logger_init, **connection_options)
    
    
    def _retry_connection(self):
        if self.connection_information and not self.connection_information.is_disconnected() and self.connection_information.network_is_reachable():
            return self._database
        
        try:
            self.do_connect()
        except Exception:
            pass
        
        return self._database
    
    
    def get_collection(self, collection_name):
        # type: (str) -> MongoCollection
        """Build a MongoCollection for `collection_name`, wired to this client's database, retry settings and callbacks."""
        collection = MongoCollection(
            collection_name,
            self._database,
            self._auto_reconnect_max_try,
            self._auto_reconnect_sleep_between_try,
            self._retry_connection,
            self.pid_used_to_start_connection,
            self.is_connection_closed,
            logger=self.logger,
        )
        return collection
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def list_name_collections(self, include_system_collections=True):
        # type: (bool) -> List
        """Return the collection names of the database as a list.

        * include_system_collections: forwarded as-is to the database's collection_names()
        """
        names = self._database.collection_names(include_system_collections)
        return list(names)
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def _unsafe_rename_collection(self, old_name, new_name):
        # type: (str, str) -> None
        """Rename collection `old_name` to `new_name`, without checking whether the target already exists.

        * old_name: current name of the collection
        * new_name: name to give to the collection
        """
        # Item access instead of attribute access: a collection whose name matches an attribute
        # of the database object (or starts with an underscore) is still resolved as a collection.
        self._database[old_name].rename(new_name)
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def drop_collection(self, collection_name):
        # type: (str) -> None
        """Drop the collection named `collection_name` from the database.

        * collection_name: name of the collection to drop
        """
        # Item access instead of attribute access: a collection whose name matches an attribute
        # of the database object (or starts with an underscore) is still resolved as a collection.
        self._database[collection_name].drop()
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def drop_database(self):
        # type: () -> None
        """Drop the whole database this client is configured for, through the underlying connection."""
        mongo_connection = self.connection_information.get_connection()
        mongo_connection.drop_database(self._name_database)
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def repair_database(self):
        # type: () -> None
        """Log that a repair is starting, then call repairDatabase() on the database handle."""
        start_message = 'Start repair/defragmentation of the database. This operation may be long.'
        self.logger.info(start_message)
        self._database.repairDatabase()
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def get_size_of_database(self):
        # type: () -> Number
        """Return the 'storageSize' value reported by the 'dbstats' command."""
        return self._database.command('dbstats')['storageSize']
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def get_server_status(self):
        # type: () -> Dict
        """Return the result of the 'serverStatus' command run on the database."""
        server_status = self._database.command('serverStatus')
        return server_status
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def get_server_info(self):
        # type: () -> Dict
        """Return what server_info() of the underlying connection reports."""
        mongo_connection = self.connection_information.get_connection()
        return mongo_connection.server_info()
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def edit_coll_mod(self, collection_name, index=None, usePowerOf2Sizes=None):
        # type: (str, Optional[Number|Dict], Optional[bool]) -> None
        """Run the 'collMod' command on `collection_name` with the options that were actually given.

        * collection_name: collection the command applies to
        * index: sent as the 'index' option unless None
        * usePowerOf2Sizes: sent as the 'usePowerOf2Sizes' option unless None
        """
        # Same option order as before: usePowerOf2Sizes (will be deprecated with pymongo 4.2), then index.
        candidate_options = (
            ('usePowerOf2Sizes', usePowerOf2Sizes),
            ('index', index),
        )
        given_options = {option_name: option_value for option_name, option_value in candidate_options if option_value is not None}
        
        self._database.command('collMod', collection_name, **given_options)
    
    
    @raise_if_pid_changed
    @retry_on_auto_reconnect()
    def drop_all_collections(self):
        """Drop every collection listed by collection_names(include_system_collections=False)."""
        database = self._database
        # The names are copied into a list before the first drop.
        for name in list(database.collection_names(include_system_collections=False)):
            self._database.drop_collection(name)
    
    
    def collection_exist(self, collection_name):
        # type: (str) -> bool
        return collection_name in self.list_name_collections()
    
    
    def rename_collection(self, old_name, new_name):
        # type: (str, str) -> None
        if self.collection_exist(new_name):
            self.drop_collection(old_name)
        self._unsafe_rename_collection(old_name, new_name)
    
    
    def after_fork_cleanup(self):
        # type: () -> None
        self._database = None
        self.logger = None
        self.logger_init = None
        self._name_database = None
        self._uri = None
        self._replica_set = None
        self._use_ssh_tunnel = None
        self._ssh_keyfile = None
        self._ssh_user = None
        self._use_ssh_retry_failure = None
        self._ssh_tunnel_timeout = None
        self._auto_reconnect_sleep_between_try = None
        self._auto_reconnect_max_try = None
        self.connection_information = None
        self.connected = False
    
    
    def disconnect(self):
        if hasattr(self, '_reset_collections'):
            self._reset_collections()
        self._database = None
        self.connected = False
        self.pid_used_to_start_connection = None
        self.connection_information.disconnect()
    
    
    def fully_close(self):
        """Forget the database handle and hand the connection result back to mongo_by_ssh_mgr for removal.

        When the instance has a `_reset_collections` attribute, it is called first. Unlike
        disconnect(), the stored connection result is cleared once the manager has been called.
        """
        if hasattr(self, '_reset_collections'):
            self._reset_collections()
        
        # Local state first
        self._database = None
        self.connected = False
        self.pid_used_to_start_connection = None
        
        # Then the manager, and only afterwards drop our own reference
        connection_result = self.connection_information
        mongo_by_ssh_mgr.remove_connection_result(self._uri, connection_result, requester=self.requester)
        self.connection_information = None
    
    
    def set_logger(self, logger):
        # type: (PartLogger) -> None
        self.logger = logger.get_sub_part('MONGO')
    
    
    def is_connection_available(self):
        try:
            self._database.command('serverStatus')
            return True
        except:
            return False
    
    
    def send_command(self, _command):
        # type: (str) -> Optional[Dict]
        return self._database.command(_command)
    
    
    def get_replica_set_status(self):
        # type: () -> Dict
        return self._database.command('replSetGetStatus')
