#!/usr/bin/python
# -*- coding: utf-8 -*-

# Copyright (C) 2013-2018:
# This file is part of Shinken Enterprise, all rights reserved.

"""
This Class is a plugin for a external module for Shinken.
"""

import threading
import time
import traceback
import pymongo

from shinkensolutions.ssh_mongodb.sshtunnelmongomgr import mongo_by_ssh_mgr
import os
import tarfile
import datetime
import shutil
import re
from shinken.basemodule import BaseModule, ModuleState
from shinken.modulesctx import modulesctx
from shinken.modulesmanager import ModulesManager
from shinken.daemon import Daemon
from shinken.log import logger
from shinken.message import Message
from bson import BSON
from shinkensolutions.system_tools import create_tree
from work_hours import WorkHours

properties = {
    'daemons' : ['synchronizer'],
    'type'    : 'synchronizer_module_database_backup',
    'phases'  : ['running'],
    'external': True,
}

LOOP_INTERVAL = 60

DEFAULT_BACKUP_RATE = 60
DEFAULT_RETENTION_DAYS = 21
DEFAULT_BACKUP_DIRECTORY = '/var/shinken-user/backup/synchronizer-module-database-backup'

REGEX_BACKUP_NAME = r'^[0-9]{4}-[0-9]{2}-[0-9]{2}_[0-9]{2}-[0-9]{2}_.*\.tgz$'


# called by the plugin manager to get an instance
def get_instance(plugin):
    logger.debug("Get a SynchronizerModuleDatabaseBackup instance for plugin %s" % plugin.get_name())
    
    return SynchronizerModuleDatabaseBackup(plugin)


class SynchronizerModuleDatabaseBackup(BaseModule, Daemon):
    
    def __init__(self, modconf):
        BaseModule.__init__(self, modconf)
        self.errors = []
        self.warnings = []
        
        # Backup options
        self.backup_name = getattr(self.myconf, 'backup_name', '')
        self.backup_dir = getattr(self.myconf, 'backup_directory', None) or DEFAULT_BACKUP_DIRECTORY
        self.backup_rate = int(getattr(self.myconf, 'backup_rate', DEFAULT_BACKUP_RATE))
        self.last_backup = None
        self.work_hours = None  # Will be set later
        
        # Retention
        self.retention_days = int(getattr(self.myconf, 'retention_days', DEFAULT_RETENTION_DAYS))
        
        # Mongo
        self.uri = getattr(self.myconf, 'uri', 'mongodb://localhost/?safe=true')
        self.replica_set = getattr(self.myconf, 'replica_set', '')
        self.use_ssh_tunnel = getattr(self.myconf, 'use_ssh_tunnel', False) in ['1', True]
        self.use_ssh_retry_failure = int(getattr(self.myconf, 'use_ssh_retry_failure', 1))
        self.ssh_user = getattr(self.myconf, 'ssh_user', os.getenv('USER'))
        self.ssh_keyfile = getattr(self.myconf, 'ssh_keyfile', '~shinken/.ssh/id_rsa')
        self.ssh_tunnel_timeout = int(getattr(self.myconf, 'ssh_tunnel_timeout', '5'))
        self.database_name = getattr(self.myconf, 'database', 'synchronizer')
        
        self.mongo_database = None
    
    
    def init(self):
        
        # create the tree for backup_directory and check the rights.
        create_tree(self.backup_dir)
        self.clean_tmp_file()
        
        # Set the mongo and work_hours parameters
        self.init_mongo_database()
        self.init_work_hours_and_rate()
    
    
    def init_mongo_database(self):
        # Get a connection. If need, we will use a SSH tunnel
        requester = '%s-%s' % (logger.name, self.get_name())
        con_result = mongo_by_ssh_mgr.get_connection(
            self.uri,
            fsync=False,
            replica_set=self.replica_set,
            use_ssh=self.use_ssh_tunnel,
            ssh_keyfile=self.ssh_keyfile,
            ssh_user=self.ssh_user,
            ssh_retry=self.use_ssh_retry_failure,
            ssh_tunnel_timeout=self.ssh_tunnel_timeout,
            requestor=requester,
        )
        mongo_client = con_result.get_connection()
        
        self.mongo_database = getattr(mongo_client, self.database_name)
    
    
    def init_work_hours_and_rate(self):
        use_work_hours = getattr(self.myconf, 'enable_specific_backup_interval_during_working_hours', '0')
        days_worked = getattr(self.myconf, 'days_worked', None)
        rate_workings_hours = getattr(self.myconf, 'backup_interval_during_working_hours', '0')
        start = getattr(self.myconf, 'work_hours_start', None)
        end = getattr(self.myconf, 'work_hours_end', None)
        
        try:
            self.work_hours = WorkHours(use_work_hours, start, end, days_worked, self.backup_rate, rate_workings_hours)
            self.work_hours.is_correct()
        except Exception as e:
            logger.warning('[%s] WorkHours configuration : %s' % (self.name, e.message))
            logger.warning('[%s] There is a problem in the work_hours configuration. Will disable the option use_work_hours.' % self.name)
            self.warnings.append(e.message)
            if self.work_hours:
                self.work_hours.enable = False
    
    
    def main(self):
        logger.set_name(self.name)
        self.debug_output = []
        
        self.modules_dir = modulesctx.get_modulesdir()
        self.modules_manager = ModulesManager('synchronizer-module-database-backup', self.modules_dir, [])
        self.modules_manager.set_modules(self.modules)
        
        self.do_load_modules()
        for inst in self.modules_manager.get_all_alive_instances():
            f = getattr(inst, 'load', None)
            if f and callable(f):
                f(self)
        
        for s in self.debug_output:
            logger.debug(s)
        del self.debug_output
        
        try:
            self.do_main()
        except Exception, exp:
            msg = Message(id=0, type='ICrash', data={'name': self.get_name(), 'exception': exp, 'trace': traceback.format_exc()})
            self.from_module_to_main_daemon_queue.put(msg)
            # wait 2 sec so we know that the synchronizer got our message, and die
            time.sleep(2)
            raise
    
    
    # Real main function
    def do_main(self):
        
        # I register my exit function
        self.set_exit_handler()
        
        self.thread_backup = None
        self.thread_clean_retention = None
        
        while not self.interrupted or not self.errors:
            if self.need_backup():
                self.thread_backup = threading.Thread(None, self.backup_thread, 'backupThread')
                # This thread must stay alive to terminate his job !
                self.thread_backup.daemon = False
                self.thread_backup.start()
            
            if self.need_clean_retention():
                self.thread_clean_retention = threading.Thread(None, self.clean_retention_thread, 'cleanRetentionThread')
                # This thread must stay alive to terminate his job !
                self.thread_clean_retention.daemon = False
                self.thread_clean_retention.start()
            
            time.sleep(LOOP_INTERVAL)
    
    
    def need_backup(self):
        now = time.time()
        if not self.last_backup:
            self.last_backup = now
            return True
        
        # The delta between 2 backup must be minutes
        delta = (now - self.last_backup) / 60
        
        # The self.work_hours can be empty if it encounter an error at init
        _rate = self.work_hours.get_rate() if self.work_hours else self.backup_rate
        
        if delta >= _rate:
            self.last_backup = now
            return True
        else:
            return False
    
    
    def need_clean_retention(self):
        return True
    
    
    def backup_thread(self):
        # To launch the backup, we need to get the time to create the directory
        start_time = time.time()
        logger.debug('[Backup] Start thread')
        # Then we compute the backup directory name with date and name set by user
        backup_date = datetime.datetime.fromtimestamp(start_time).strftime('%Y-%m-%d_%H-%M')
        backup_name = '%s_%s' % (backup_date, self.backup_name)
        archive_name = '%s.tgz' % backup_name
        full_backup_directory_name = os.path.join(self.backup_dir, backup_name)
        
        # ready ? let's go to backup, compress, and clean dump
        self.dump_mongo_database(full_backup_directory_name)
        logger.log_perf(start_time, self, 'dump mongo database : %s' % (self.database_name))
        
        # Then we compress this dump
        compress_start_time = time.time()
        self.compress_directory(archive_name, backup_name)
        logger.log_perf(compress_start_time, self, 'compress dump : %s into %s' % (full_backup_directory_name, archive_name))
        
        # Now need to clean the raw_dump
        clean_start_time = time.time()
        self.clean_raw_dump(full_backup_directory_name)
        logger.log_perf(clean_start_time, self, 'clean raw dump')
        
        full_archive_name = os.path.join(self.backup_dir, archive_name)
        logger.log_perf(start_time, self, 'full backup creation [%s]' % (full_archive_name))
        logger.info('[Backup] Archive created [%s]' % (full_archive_name))
    
    
    def clean_retention_thread(self):
        # To know which archive to delete, we need to get the date
        start_time = time.time()
        
        # Then we compute the older timestamp to keep
        older_to_keep = datetime.datetime.fromtimestamp(start_time) - datetime.timedelta(days=self.retention_days)
        removed_files = []
        
        # Will keep only the file with the format YYYY-MM-DD_HH-MM_All_You-Want Here.tgz
        archives = [arch for arch in os.listdir(self.backup_dir) if re.match(REGEX_BACKUP_NAME, arch)]
        for archive in archives:
            archive_time = archive[:16]
            archive_date = datetime.datetime.strptime(archive_time, '%Y-%m-%d_%H-%M')
            if archive_date < older_to_keep:
                removed_files.append(archive)
                os.remove(os.path.join(self.backup_dir, archive))
        
        if removed_files:
            logger.info('[Clean retention] Files removed : %s' % ', '.join(removed_files))
    
    
    def dump_mongo_database(self, backup_dir):
        dump_directory = os.path.join(backup_dir, self.database_name)
        try:
            create_tree(dump_directory)
        except OSError as e:
            _error = '%s: %s' % (e.strerror, e.filename)
            logger.error(_error)
            if _error not in self.errors:
                self.errors.append(_error)
            raise e
        collections = self.mongo_database.collection_names()
        for i, collection_name in enumerate(collections):
            if collection_name.startswith('tmp-'):
                continue
            col = getattr(self.mongo_database, collections[i])
            raw_col = col.find()
            file_name = os.path.join(dump_directory, '%s.bson' % collection_name)
            with open(file_name, 'wb+') as f:
                for doc in raw_col:
                    f.write(BSON.encode(doc))
    
    
    def compress_directory(self, backup_name, backup_dir):
        cwd = os.getcwd()
        os.chdir(self.backup_dir)
        _name = 'tmp-%s' % backup_name
        try:
            targz = tarfile.open(_name.encode('utf-8'), 'w:gz')
            targz.add(backup_dir)
        except AttributeError:
            pass
        finally:
            targz.close()
        
        os.rename(_name, backup_name)
        os.chdir(cwd)
    
    
    def clean_raw_dump(self, backup_dir):
        try:
            shutil.rmtree(backup_dir)
        except OSError as e:
            print ("Error: %s - %s." % (e.filename, e.strerror))
    
    
    def get_module_info(self):
        module_info = {}
        tmp = BaseModule.get_submodule_states(self)
        if tmp:
            module_info.update(tmp)
        
        module_info['errors'] = self.errors
        module_info['warnings'] = self.warnings
        if self.errors:
            # Be carreful : the module state in "CRITICAL" is not correctly shown in the healthcheck.
            module_info['status'] = ModuleState.CRITICAL
            module_info['output'] = '<br>'.join(self.errors)
        elif self.warnings:
            module_info['status'] = ModuleState.WARNING
            module_info['output'] = '<br>'.join(self.warnings)
        return module_info
    
    
    def clean_tmp_file(self):
        files = os.listdir(self.backup_dir)
        for file in files:
            if file.startswith('tmp-'):
                os.remove(os.path.join(self.backup_dir, file))
