import atexit
import os
import random
import re
import shlex
import socket
import subprocess
import sys
import threading
import time

from pymongo import ReplicaSetConnection
from pymongo.connection import Connection
from pymongo.uri_parser import parse_uri

from shinken.log import LoggerFactory, logger

try:
    # Try to read in non-blocking mode, from now this only from no:
    # w on  Unix systems -> Not handle so i remove try
    import fcntl
except:
    fcntl = None
    
try:
    from pymongo.connection import Connection
    from pymongo import ReplicaSetConnection, ReadPreference
    from pymongo.uri_parser import parse_uri
    from pymongo.mongo_client import ConnectionFailure, ConfigurationError
    import pymongo
    import bson
except ImportError, exp:
    logger.error('Cannot import pymongo librairy (%s). Please install it' % exp)
    Connection = None

SSH_ADDRESS_ALREADY_IN_USE_ERROR = 'address_already_in_use'
SSH_TUNNEL_TIMEOUT_DEFAULT = 5


class ConnectionResult(object):
    def __init__(self, con, ssh_time, mongodb_time):
        self._con = con
        self._ssh_time = ssh_time
        self._mongodb_time = mongodb_time
    
    
    def get_connection(self):
        return self._con
    
    
    def get_ssh_tunnel_time(self):
        return self._ssh_time
    
    
    def get_mongodb_connection_time(self):
        return self._mongodb_time


class SshTunnelMongoManager(object):
    def __init__(self):
        self.connection_results = {}
        self.mute_mode = False
        
        # Mongodb lib do not manage fork() calls. So when getting the connection, do NOT give connections that are
        # from another process. So when we change process, reset/close all and restart from scratch
        self.current_pid = 0
        
        # We will remember all living ssh tunnels, so we can close them when need
        self._living_sub_processes = {}
    
    
    # We did change process (fork()) so reset our pool
    def _reset_pool_after_change_process(self):
        for result in self.connection_results.values():
            try:
                # After fork we can deadlock so we reset the lock
                setattr(result.get_connection(), '__connecting_lock', threading.Lock())
                result.get_connection().disconnect()
            except:
                pass
        self.connection_results.clear()
        self._living_sub_processes.clear()
    
    
    # For a specific uri we will have one id for the destination, and we will have 1 and only one tunnel by destination
    def _get_destination_from_uri(self, uri, logger):
        # Parsing example :
        # parse_uri("mongodb://user:@example.com/my_database/?w=2")
        # {'username': 'user', 'nodelist': [('example.com', 27017)], 'database': 'my_database/', 'collection': None, 'password': '', 'options': {'w': 2}}
        uri_datas = parse_uri(uri)
        # TODO: what to do with more than one node? how to manage this?
        nodes = uri_datas['nodelist']
        if len(nodes) != 1:
            if not self.mute_mode:
                logger.error('Cannot extract destination server from the mongodb uri "%s". There must be one destination server to connect to.' % uri)
            return None
        node = nodes[0]
        addr, dest_port = node
        return addr, dest_port
    
    
    # We want a port that is 30000 < HASH PORT < 60000 based on the addr string
    @staticmethod
    def _get_tunnel_port(addr):
        h = hash(addr)
        if h < 0:
            h += sys.maxsize
        # h is here a fucking big int, map it into the 10000 => 30000 space
        port = 10000 + divmod(h, 20000)[1]
        return port
    
    
    # We want a binding port, randomly from base port.
    # _get_new_binding_port_from_base_port port is 10K=>30K, so we can random in the next 30K to have
    # a random between 30K->60K
    @staticmethod
    def _get_new_binding_port_from_base_port(base_port):
        new_binding_port = base_port + random.randint(1, 30000)
        return new_binding_port
    
    
    # We will spawn a ssh background process.
    # It will wait for a connection in the next 30s
    # * If there i no connection in the 30s, it will just stop
    # * when the last connection will close, it will stop too
    def _spawn_background_ssh(self, uri, addr, dest_port, local_port_search, user, keyfile, max_ssh_retry, requestor, logger, tunnel_end_of_life_delay, ssh_tunnel_timeout):
        
        # the +1 is here to tel if the settings is 1 so we want 1 normal try and 1 retry
        nb_try = 1
        local_port_search['current'] = self._get_new_binding_port_from_base_port(local_port_search['base'])
        if not self.mute_mode:
            logger.info('Connection to %s with a ssh tunnel:' % uri)
        while True:
            current_binding_port = local_port_search['current']
            # note: lang=C + shell so we can be sure we will have errors in english to grep them!
            # note2: -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no ==> do not look at host authorization
            # note3: -o BatchMode=yes => do not prompt password or such things
            # note4: -o ExitOnForwardFailure=true => if forward fail, exit!
            # note5: "root" user is strictly forbidden to prevent huge security holes
            # note6: -4 : ipv4 only, so it won't open only ipv6 if ipv4 is unbindable
            # note7: -o PreferredAuthentications=publickey : we do connect with key, so directly skip other methods
            # More info in : #SEF-2392
            cmd = '/usr/bin/ssh -4 -o PreferredAuthentications=publickey -o ExitOnForwardFailure=true -o BatchMode=yes -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no -L %d:localhost:%d ' % (current_binding_port, dest_port)
            
            if user is not None and user != "" and user != 'root':
                cmd += ' -l "%s"' % user
            
            if keyfile is not None and keyfile != "":
                cmd += ' -i "%s"' % os.path.expanduser(keyfile)
            
            # Protect requestor from non shell characters
            requestor = re.sub(r'\W+', '-', requestor)
            cmd += ' "%s" "echo \"CONNECTED\";echo \" Mongo SSH Tunnel. Connexion requestor is %s\";sleep %d"' % (addr, requestor, tunnel_end_of_life_delay)
            
            if not self.mute_mode:
                logger.info('   - searching a random local port available for the tunnel binding (trying %d): localhost:%s =(ssh tunnel)=> %s:22 =(mongodb)=> %s:%d (search try:%d)' %
                            (current_binding_port, current_binding_port, addr, addr, dest_port, nb_try))
            
            is_last_try = False
            if nb_try >= max_ssh_retry:
                is_last_try = True
            
            # Real error? oups
            nb_try += 1
            
            p = subprocess.Popen(shlex.split(cmd), stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False, close_fds=True, preexec_fn=os.setsid, env={'LANG': 'C'})
            try:
                ssh_result = self._look_for_ssh_start_finish(p, addr, current_binding_port, logger, user=user, keyfile=keyfile, raise_on_unknow_error=is_last_try, ssh_tunnel_timeout=ssh_tunnel_timeout)
                is_connected = ssh_result['connected']
                # If we did connect, we are done
                if is_connected:
                    if not self.mute_mode:
                        logger.info('     - tunnel creation SUCCESS: localhost:%s =(ssh tunnel)=> %s:22 =(mongodb)=> %s:%d (search try:%d, ssh pid=%s)' %
                                    (current_binding_port, addr, addr, dest_port, nb_try, p.pid))
                    
                    # Remember the process so we can clean it if need in the future
                    self._living_sub_processes[uri] = p
                    break
                # Maybe it's just a problem of port already used, if so, increase
                ssh_error = ssh_result['error']
                # Maybe it's just a already in use error, if so just try with another port
                if ssh_error == SSH_ADDRESS_ALREADY_IN_USE_ERROR:
                    local_port_search['current'] = self._get_new_binding_port_from_base_port(local_port_search['base'])
                    if not self.mute_mode:
                        logger.info('     - the binding port %s was not free, trying another port (search try:%d)' % (current_binding_port, nb_try))
                    
                    continue
            except Exception:  # timeout
                if nb_try > max_ssh_retry:
                    raise
    
    
    @staticmethod
    def _no_block_read(output):
        fd = output.fileno()
        fl = fcntl.fcntl(fd, fcntl.F_GETFL)
        fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)
        try:
            return output.read()
        except:
            return ''
    
    
    @staticmethod
    def _kill_process(process):
        try:
            process.terminate()
            process.kill()
        except:  # already done
            pass
    
    
    def _look_for_ssh_start_finish(self, process, addr, current_binding_port, logger, user='shinken', keyfile='~/.ssh/id_rsa.pub', raise_on_unknow_error=True, ssh_tunnel_timeout=SSH_TUNNEL_TIMEOUT_DEFAULT):
        start = time.time()
        if keyfile is None:
            keyfile = '~/.ssh/id_rsa.pub'
        # Don't worry, we will exit, we just don't know if it will be fast or not
        while True:
            # Global command timeout (even if ssh have one, I prefer add ANOTHER one if process startup hangs for something)
            if time.time() > start + ssh_tunnel_timeout:
                warn = '     - the ssh tunnel to %s timed out to (after %ds) when trying to open it. Please check that your server is reachable and that the SSH daemon allows your connection.' % (addr, ssh_tunnel_timeout)
                if not self.mute_mode:
                    logger.warning(warn)
                self._kill_process(process)
                raise Exception(warn)
            
            # If the process is not dead, try to read what we can as stdout/stderr and look inside if we have clue about finish or not
            process_finish = (process.poll() is not None)
            # print dir(process)
            if not process_finish:
                # Now read stdout/err, but as partial as process is still alive
                stdout = self._no_block_read(process.stdout)
                stderr = self._no_block_read(process.stderr)
            else:
                stdout, stderr = process.communicate()
            lines = stdout.splitlines()
            for line in lines:
                # All we want is this line. It means that the ssh tunnel is done and we did have a shell on the other side, so we are OK
                if 'CONNECTED' in line:
                    sock = None
                    # Try to really connect to it before give back the addr
                    try:
                        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                        sock.settimeout(ssh_tunnel_timeout)
                        sock.connect(('localhost', current_binding_port))
                        sock.close()
                    except Exception as exp:  # Oups, the socket was not so good finally
                        try:
                            sock.close()
                        except:
                            pass
                        warn = '     - the ssh tunnel to %s timed out to (after %ds) when trying to test it (%s)' % (addr, ssh_tunnel_timeout, exp)
                        if not self.mute_mode:
                            logger.warning(warn)
                        self._kill_process(process)
                        return {'connected': False, 'error': warn}
                    return {'connected': True, 'error': None}
            
            # We look at stderr, it's where the truth lie
            lines = stderr.splitlines()
            for line in lines:
                # Maybe there is just another ssh already started, not a problem in fact, the other process will allow us to connect
                if 'bind: Address already in use' in line:
                    return {'connected': False, 'error': SSH_ADDRESS_ALREADY_IN_USE_ERROR}
                # Ok classic one: no ssh key on the other side, definitive error, exit from here
                elif 'Permission denied' in stderr:
                    err = '     - the ssh tunnel to %s failed because your local ssh key (%s) is not authorized on the distant server. You can use the command "ssh-copy-id -i %s %s@%s" from this local server to fix this. If any problem occurs, please refer to the Shinken Documentation. => %s' % \
                          (addr, keyfile, keyfile, user, addr, line)
                    if not self.mute_mode:
                        logger.error(err)
                    raise Exception(err)
            
            # If the process did finish and we didn't catch a stderr we know, get a generic error
            if process_finish:
                err = '     - the ssh tunnel to %s failed to connect with the error: %s' % (addr, stderr)
                if raise_on_unknow_error:
                    if not self.mute_mode:
                        logger.error(err)
                    raise Exception(err)
                else:
                    return {'connected': False, 'error': err}
            
            # Still now finish? don't hammer the CPU
            time.sleep(0.01)
    
    
    @staticmethod
    def _do_connect(uri, replica_set, fsync):
        if replica_set:
            con = ReplicaSetConnection(uri, replicaSet=replica_set, fsync=fsync)
        else:
            con = Connection(uri, fsync=fsync)
        return con
    
    
    def _force_close_previous_ssh_process(self, uri):
        process = self._living_sub_processes.get(uri, None)
        if process:
            try:
                process.terminate()
                time.sleep(0.1)  # let the ssh process time to close
                process.kill()
            except:  # errors here are not a problem
                pass
    
    
    def close_all_tunnels(self):
        if len(self._living_sub_processes) == 0:
            return
        for process in self._living_sub_processes.values():
            try:
                process.terminate()
            except:  # already done
                pass
        # We have some process, let them some ms to finish
        time.sleep(0.1)
        # But after some time, kill them all, no mercy
        for process in self._living_sub_processes.values():
            try:
                process.kill()
            except:  # already done
                pass
        self._living_sub_processes.clear()
    
    
    # This lib can be used in checks, and here we do nto want any logger or such things, Check will do it itself
    def set_mute(self):
        self.mute_mode = True
    
    
    def get_connection(
            self,
            uri,
            logger=None,
            replica_set='',
            fsync=False,
            ssh_user=None,
            use_ssh=False,
            ssh_keyfile=None,
            ssh_retry=1,
            requestor='(unknown requestor)',
            force_ssh_tunnel_recreation=False,
            tunnel_for='mongodb',
            tunnel_end_of_life_delay=30,
            ssh_tunnel_timeout=SSH_TUNNEL_TIMEOUT_DEFAULT
    ):
        if logger is None:
            logger = LoggerFactory.get_logger(requestor)
        logger = logger.get_sub_part('SSH TUNNEL')
        
        # First look in cache, but only if the same process (if not, reset the connection pool)
        cur_pid = os.getpid()
        if cur_pid != self.current_pid:
            self._reset_pool_after_change_process()
            self.current_pid = cur_pid
        
        if force_ssh_tunnel_recreation:
            self._force_close_previous_ssh_process(uri)
        
        connection_result = self.connection_results.get(uri)
        if connection_result is not None:
            # Ok we have a connection, but maybe it's dead?
            try:
                connection_result.get_connection().admin.command('ping')
            except Exception:
                if not self.mute_mode:
                    logger.warning('The connection to the mongodb server %s is no more available, restart a new connection' % uri)
                connection_result = None
        
        # Ok we have a valid connection, give it
        if connection_result is not None:
            return connection_result
        # Ok we must create one, start to be fun
        
        # First try to get where we are going to
        addr, dest_port = self._get_destination_from_uri(uri, logger)
        local_port_base = self._get_tunnel_port('%s:%s' % (addr, dest_port))
        local_port = 0
        ssh_time = 0.0
        
        # Be sure the is a valid ssh waiting for our connection
        if use_ssh:
            before_ssh = time.time()
            local_port_search = {'base': local_port_base, 'current': local_port_base}
            self._spawn_background_ssh(uri, addr, dest_port, local_port_search, ssh_user, ssh_keyfile, ssh_retry, requestor, logger, tunnel_end_of_life_delay, ssh_tunnel_timeout)
            ssh_time = time.time() - before_ssh
            
            # The founded binding port can be different thant the base one
            local_port = local_port_search['current']
            
            # We change the mongo://addr/ => mongo://localhost/
            new_mongo_uri = uri.replace(addr, 'localhost')
            # If there was a specific port set in the uri, switch it
            if ':' in new_mongo_uri.replace('mongodb:', ''):  # avoid the first : in mongodb:
                new_mongo_uri = new_mongo_uri.replace(':%d' % dest_port, ':%d' % local_port)
            # If there was no such port, set it directly
            else:
                new_mongo_uri = new_mongo_uri.replace('localhost', 'localhost:%d' % local_port)
        else:
            new_mongo_uri = uri
        before_mongo = time.time()
        
        try:
            con = self._do_connect(new_mongo_uri, replica_set, fsync)
        except:
            raise Exception('mongo connection failure with the SSH tunnel: localhost:%s =(ssh tunnel)=> %s:22 =(mongodb)=> %s:%d' % (local_port, addr, addr, dest_port))
        
        mongodb_time = time.time() - before_mongo
        if use_ssh and not self.mute_mode:
            logger.info('   - SUCCESS mongo connection is OPENED with the SSH tunnel: localhost:%s =(ssh tunnel)=> %s:22 =(mongodb)=> %s:%d' % (local_port, addr, addr, dest_port))
        
        # Fill the final object
        connection_result = ConnectionResult(con, ssh_time, mongodb_time)
        self.connection_results[uri] = connection_result
        return connection_result
    
    
    @staticmethod
    def check_pymongo_bson_c_extension_installed():
        logger.debug('check_pymongo_c_extension_installed::start')
        pymongo_has_c = {'installed': True, 'status': 'OK', 'message': 'Your pymongo has C extension installed'}
        bson_has_c = {'installed': True, 'status': 'OK', 'message': 'Your bson lib has C extension installed'}
        if not pymongo.has_c():
            pymongo_has_c = {'installed': False, 'status': 'ERROR', 'message': 'Your pymongo lib has not the C extension installed'}
        if not bson.has_c():
            bson_has_c = {'installed': False, 'status': 'ERROR', 'message': 'Your bson lib has not the C extension installed'}
        
        return pymongo_has_c, bson_has_c
    
    
    @staticmethod
    def adapt_cluster_mongo_data_for_json_format(stats):
        election_id = stats.get('$gleStats', {}).get('electionId', None)
        last_op_time = stats.get('$gleStats', {}).get('lastOpTime', None)
        if election_id:
            stats['$gleStats']['electionId'] = str(election_id)
        if last_op_time:
            stats['$gleStats']['lastOpTime'] = last_op_time.time
        
        for cluster in stats.get('raw', {}):
            election_id = stats['raw'][cluster].get('$gleStats', {}).get('electionId', None)
            last_op_time = stats['raw'][cluster].get('$gleStats', {}).get('lastOpTime', None)
            if election_id:
                stats['raw'][cluster]['$gleStats']['electionId'] = str(election_id)
            if last_op_time:
                stats['raw'][cluster]['$gleStats']['lastOpTime'] = last_op_time.time
        return stats
    
    
    def check_connexion_mongodb(self, mongo_uri):
        logger.debug('check_mongodb_connexion::start')
        _result = []
        for url, connection_result in self.connection_results.iteritems():
            if url != mongo_uri:
                continue
            connect_msg = 'Mongodb server is available at: %s' % url
            is_connected = False
            status = 'ERROR'
            stats = {}
            pymongo_has_c, bson_has_c = self.check_pymongo_bson_c_extension_installed()
            
            try:
                mongo_connexion = connection_result.get_connection()
                mongo_all_databases = mongo_connexion.database_names()
                for database in mongo_all_databases:
                    stats[database] = self.adapt_cluster_mongo_data_for_json_format(mongo_connexion[database].command('dbstats'))
                is_connected = True
                status = 'OK'
            except ConnectionFailure as exp:
                connect_msg = 'Cannot connect to mongodb server: %s   (%s)' % (url, exp)
            except ConfigurationError as exp:
                connect_msg = 'Mongo configuration error: %s    (%s)' % (url, exp)
            except Exception as e:
                connect_msg = str(e)
            
            connexion_state = {
                'is_connected' : is_connected,
                'status'       : status,
                'connect_msg'  : connect_msg,
                'url'          : url,
                'pymongo_has_c': pymongo_has_c,
                'bson_has_c'   : bson_has_c,
                'stats'        : stats,
            }
            _result.append(connexion_state)
        
        return _result


if Connection is not None:
    mongo_by_ssh_mgr = SshTunnelMongoManager()
    # When exiting, we need to close all living tunnels
    atexit.register(mongo_by_ssh_mgr.close_all_tunnels)

else:  # missing pymongo lib, sick...
    mongo_by_ssh_mgr = None
