#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2013-2019:
# This file is part of Shinken Enterprise, all rights reserved.

import argparse
from datetime import datetime
import fnmatch
import hashlib
import json
import math
import os
import re
import shutil
import socket
import string
import subprocess
import sys
import tarfile
import threading
import traceback

from lib.color_print import *


class HostNotLocked(Exception):
    """Raised when a lock removal is requested on a host that carries no lock."""


class HostAlreadyLocked(Exception):
    """Raised when trying to lock a host that already carries a lock.

    The `locked_by` attribute holds the lock value given by the raiser.
    """
    
    
    def __init__(self, lock):
        # Forward the lock value to Exception so that str(e) and e.args are not empty
        super(HostAlreadyLocked, self).__init__(lock)
        self.locked_by = lock
    
    
    def __repr__(self):
        # __repr__ must return a string, whatever the type of the stored lock value
        return "%s" % (self.locked_by,)


class ShinkenArchiveChecksumMismatch(Exception):
    """Raised when the local archive checksum and the uploaded archive checksum differ."""


class SSHCommandTimeout(Exception):
    """Raised when a remote command did not complete within its allowed time.

    The `timeout` attribute holds the limit (in seconds) that was exceeded.
    """
    
    
    def __init__(self, timeout):
        # Forward the timeout to Exception so that str(e) and e.args are not empty
        super(SSHCommandTimeout, self).__init__(timeout)
        self.timeout = timeout
    
    
    def __repr__(self):
        # The previous implementation returned self.locked_by, an attribute this class never defines
        return "SSHCommandTimeout(timeout=%r)" % (self.timeout,)


class Configuration(object):
    """Option store of the script, resolved through three layers.
    
    get() looks an option up in the command line namespace ("cli") first, then in the
    global section of the configuration file ("global_cfg") and finally in `defaults`.
    """
    defaults = {
        "ssh_default_timeout"     : 5,
        "ping_timeout"            : 3,
        "archive_upload_timeout"  : 600,
        "archive_extract_timeout" : 180,
        "patch_upload_timeout"    : 300,
        "patch_extract_timeout"   : 120,
        "preinstall_check_timeout": 5,
        "update_timeout"          : 600,
        "patch_timeout"           : 300,
        "tmp_dir"                 : "/tmp/shinken_auto_update",
        "force_reupload"          : False,
    }
    
    # NOTE: declared on the class, so this dict is shared by every Configuration instance
    options = {
        # Options defined by command line flags
        "cli"       : {},
        # Options defined as global in configuration file
        "global_cfg": {},
    }
    
    
    def get(self, key):
        """Return the value of `key` from the first layer defining it, None if no layer does."""
        for namespace in ("cli", "global_cfg"):
            value = self.options[namespace].get(key)
            if value is not None:
                return value
        
        return self.defaults.get(key)
    
    
    def set(self, key, value, namespace):
        """Store `value` for `key` in the given namespace ("cli" or "global_cfg")."""
        self.options[namespace][key] = value
    
    
    def get_cfg_file_global_options(self):
        """Return the dict of options coming from the configuration file."""
        return self.options["global_cfg"]
    
    
    @staticmethod
    def save_to_file(config_file_path):
        """Write the global host list and the configuration file options as JSON.
        
        Reads the module-level `shinken_host_list` and `configuration` objects.
        The script exits with status 1 if the file cannot be written.
        """
        hosts_to_save = [
            {
                "addr"         : h.get_addr(),
                "names"        : list(h.get_names()),
                "configuration": h.get_configuration(),
            } for h in shinken_host_list.get_hosts()]
        
        configuration_contents = {
            "configuration": configuration.get_cfg_file_global_options(),
            "hosts"        : hosts_to_save,
        }
        try:
            with open(config_file_path, "w") as f:
                json.dump(configuration_contents, f)
            
            cprint(u' %s%s ' % (CHARACTERS.corner_bottom_left, CHARACTERS.arrow_left), color="grey", end='')
            print("Configuration file saved (%s)" % config_file_path)
        except Exception as e:
            print("")
            print_h1("CONFIGURATION ERROR", line_color="red", title_color="red")
            cprint("Could not save configuration file (%s): '%s'" % (config_file_path, str(e)), color="red")
            sys.exit(1)
    
    
    @staticmethod
    def read_from_file(config_file_path):
        """Load hosts and global options from the JSON configuration file.
        
        Does nothing when the Shinken configuration scan has already been done.
        When the file does not exist, the host list is built from the Shinken daemons
        configuration files and saved to `config_file_path`.
        The script exits with status 1 if the file exists but cannot be parsed.
        """
        if shinken_host_list.configuration_scan_done:
            return
        
        print("")
        print_h1("Reading hosts from configuration file")
        
        if not os.path.exists(config_file_path):
            cprint(u' %s%s ' % (CHARACTERS.corner_bottom_left, CHARACTERS.arrow_left), color="grey", end='')
            print("Configuration file (%s) not found" % config_file_path)
            
            cprint(u' %s%s ' % (CHARACTERS.corner_bottom_left, CHARACTERS.arrow_left), color="grey", end='')
            print("Host list will be generated from Shinken configuration file and saved into %s" % config_file_path)
            get_architecture_hosts()
            configuration.save_to_file(config_file_path)
            return
        
        cprint(u' %s%s ' % (CHARACTERS.corner_bottom_left, CHARACTERS.arrow_left), color="grey", end='')
        print("Configuration file used: %s" % config_file_path)
        try:
            # Context manager so that the file handle is closed even when parsing fails
            with open(config_file_path, 'r') as f:
                configuration_file_contents = json.load(f)
        except Exception as e:
            print("")
            print_h1("CONFIGURATION ERROR", line_color="red", title_color="red")
            cprint("Could not read hosts configuration file (%s): '%s'" % (config_file_path, str(e)), color="red")
            sys.exit(1)
        
        # A file without "hosts" (or a host without "names") is handled as an empty list
        for host in configuration_file_contents.get("hosts") or []:
            shinken_host_list.add_host(host.get("addr"), host.get("configuration"), is_shinken_host=True)
            for alt_name in host.get("names") or []:
                shinken_host_list.add_host(alt_name)
        
        cfg_file_configuration_options = configuration_file_contents.get("configuration")
        if cfg_file_configuration_options is not None:
            for key, value in cfg_file_configuration_options.items():
                configuration.set(key, value, namespace="global_cfg")


class Host(object):
    """A machine of the architecture, identified by its resolved address.
    
    `ip` is the result of HostList.resolve_addr() on the name given at creation;
    `alt_names` collects every other name the host is known by.
    """
    
    
    def __init__(self, name):
        self.alt_names = set()
        self.configuration = dict()
        self.is_mongo_cluster_member = False
        self.can_update_other_mongo_cluster_nodes = False
        self.is_shinken_host = False
        
        name = name.strip()
        self.ip = HostList.resolve_addr(name)
        if name != self.ip:
            self.add_alt_name(name)
    
    
    def _names_without_ip(self):
        # Sorted so that the result does not depend on the set iteration order
        return sorted(name for name in self.alt_names if name != self.ip)
    
    
    def __repr__(self):
        # alt_names may contain the IP itself: only real alternative names are displayed,
        # which avoids an empty "ip ()" rendering
        names = self._names_without_ip()
        if names:
            return "%s (%s)" % (self.ip, ",".join(names))
        return self.ip
    
    
    def lock(self):
        """Place the remote lock file on the host.
        
        Raises HostAlreadyLocked if a lock is already present, and Exception with the
        command error output if the lock file could not be written.
        """
        if self.is_locked():
            raise HostAlreadyLocked(self.get_lock())
        
        # Put lock on remote host to avoid performing multiple updates on the same host
        p, out, err = launch_ssh_command(self, "echo '%s' > %s" % (self.get_display_name(), SHINKEN_REMOTE_LOCK_FILE))
        if err:
            raise Exception(err)
    
    
    def unlock(self):
        """Remove the remote lock file.
        
        Raises HostNotLocked if no lock is found, and Exception with the command error
        output if the removal failed.
        """
        if not self.is_locked():
            raise HostNotLocked
        
        # Remove remote lock
        p, out, err = launch_ssh_command(self, "rm %s" % SHINKEN_REMOTE_LOCK_FILE)
        if err:
            raise Exception(err)
    
    
    def get_lock(self):
        """Return the content of the remote lock file; raise Exception if it cannot be read."""
        rc, out, err = launch_ssh_command(self, "cat %s" % SHINKEN_REMOTE_LOCK_FILE)
        if err:
            raise Exception(err)
        
        return out
    
    
    def is_locked(self):
        """Return True when the remote lock file can be read.
        
        Any failure while reading it (including a command timeout) is reported as "not locked".
        """
        try:
            self.get_lock()
            return True
        except Exception:
            return False
    
    
    def get_addr(self):
        return self.ip
    
    
    def get_names(self):
        return self.alt_names
    
    
    def get_display_name(self):
        """Return the first alternative name (alphabetical order), or the IP when there is none."""
        names = self._names_without_ip()
        if names:
            return names[0]
        
        return self.ip
    
    
    def get_configuration(self):
        return self.configuration
    
    
    def set_configuration(self, configuration):
        self.configuration = configuration
    
    
    def add_alt_name(self, name):
        self.alt_names.add(name.strip())


class HostList(object):
    """Set of Host objects, with one entry per resolved address."""
    
    # Set to True by get_architecture_hosts() once the Shinken cfg files have been scanned
    configuration_scan_done = False
    
    
    def __init__(self):
        self.hosts = set()
    
    
    def __repr__(self):
        return ",".join(h.get_display_name() for h in self.hosts)
    
    
    @staticmethod
    def resolve_addr(address):
        """Return the IP address of `address`, or `address` itself when it cannot be resolved."""
        try:
            return socket.gethostbyname(address)
        except IOError as exp:
            print("Unable to get IP address for %s : %s" % (address, exp))
            return address
    
    
    def add_host(self, address, configuration=None, is_mongo_cluster_member=None, can_update_other_mongo_cluster_nodes=None, is_shinken_host=None):
        """Register `address`, creating the Host or completing the existing one.
        
        `address` is added to the host alternative names. Each flag is raised on the host
        when given with a true value and is never lowered; passing False or None leaves
        the flag unchanged. A non-None `configuration` replaces the host configuration.
        """
        addr = HostList.resolve_addr(address)
        
        targets = [host for host in self.hosts if host.get_addr() == addr]
        if not targets:
            new_host = Host(addr)
            self.hosts.add(new_host)
            targets = [new_host]
        
        for host in targets:
            host.add_alt_name(address)
            
            # Only a true value raises a flag: an explicit False used to set it to True
            if is_mongo_cluster_member:
                host.is_mongo_cluster_member = True
            if can_update_other_mongo_cluster_nodes:
                host.can_update_other_mongo_cluster_nodes = True
            if is_shinken_host:
                host.is_shinken_host = True
            
            if configuration is not None:
                host.set_configuration(configuration)
    
    
    def get_hosts(self):
        return self.hosts
    
    
    def get_host(self, addr):
        """Return the host matching `addr` (as given, or once resolved), None if there is none."""
        resolved_addr = None
        for host in self.hosts:
            names = host.get_names()
            if host.get_addr() == addr or addr in names:
                return host
            
            # Resolve lazily and only once instead of up to twice per host
            if resolved_addr is None:
                resolved_addr = HostList.resolve_addr(addr)
            if resolved_addr == host.get_addr() or resolved_addr in names:
                return host
        
        return None
    
    
    def is_empty(self):
        return len(self.hosts) == 0


class StatusCounter(object):
    """Tally of OK / WARNING / CRITICAL results; increments are serialized by a lock."""
    
    
    def __init__(self):
        self.ok = self.warning = self.critical = 0
        self.lock = threading.Lock()
    
    
    def _increment(self, counter_name):
        # Read-modify-write of the named counter, done while holding the lock
        with self.lock:
            setattr(self, counter_name, getattr(self, counter_name) + 1)
    
    
    def add_ok(self):
        self._increment("ok")
    
    
    def add_warning(self):
        self._increment("warning")
    
    
    def add_critical(self):
        self._increment("critical")
    
    
    def get_ok(self):
        return self.ok
    
    
    def get_warning(self):
        return self.warning
    
    
    def get_critical(self):
        return self.critical


class DisplayUtils(object):
    """Static helpers formatting the console output of the tasks."""
    
    
    @staticmethod
    def display_seconds(seconds):
        """Format a duration in seconds as "<m>mn<s>s", "<m>mn" or "<s>s"."""
        minutes = math.floor(seconds / 60)
        if not minutes > 0:
            return "%ds" % seconds
        
        seconds_left = seconds - (minutes * 60)
        if seconds_left > 0:
            return "%dmn%ds" % (minutes, seconds_left)
        return "%dmn" % minutes
    
    
    @staticmethod
    def task_status(status, message, counter, print_status=True):
        """Account `status` in `counter` and, unless disabled, print the status line.
        
        "OK", "WARNING" and "ERROR" increment the matching counter; "WARNING" is displayed as "WARN".
        """
        status_colors = {
            "OK"     : "green",
            "WARNING": "yellow",
            "ERROR"  : "red",
            "INFO"   : "cyan",
            "LOG"    : "grey",
        }
        color = status_colors.get(status, "grey")
        
        if status == "OK":
            counter.add_ok()
        elif status == "WARNING":
            counter.add_warning()
            status = "WARN"
        elif status == "ERROR":
            counter.add_critical()
        
        if not print_status:
            return
        
        # The global lock keeps the two parts of the line together when several threads print
        with globalLock:
            print(" [%s]" % sprintf("%5s" % status, color=color)),
            print(message)
    
    
    @staticmethod
    def status(message_type="success"):
        """Return the colored "[ <label> ]" tag of a message type (empty label when unknown)."""
        labels = {
            "success": ("OK", "green"),
            "warning": ("WARNING", "yellow"),
            "error"  : ("ERROR", "red"),
            "info"   : ("INFO", "grey"),
        }
        msg, color = labels.get(message_type, ("", ""))
        
        return "[ %s ]" % sprintf("%-8s" % msg, color=color)
    
    
    @staticmethod
    def recap(ok=0, warning=0, critical=0):
        """Print the summary line of the counters; a counter equal to zero is displayed in grey."""
        ok_str = sprintf("OK: %d" % ok, color="green" if ok != 0 else "grey")
        warning_str = sprintf("WARNING: %d" % warning, color="yellow" if warning != 0 else "grey")
        critical_str = sprintf("CRITICAL: %d" % critical, color="red" if critical != 0 else "grey")
        
        print("")
        print("%27s %27s %26s" % (ok_str, warning_str, critical_str))
    
    
    @staticmethod
    def exit_if_errors(status_counter, title, message):
        """Print `title` and `message` in red and exit with status 1 if a critical result was counted."""
        if not status_counter.get_critical() > 0:
            return
        
        print("")
        print_h1(title, line_color="red", title_color="red")
        cprint(message, color="red")
        sys.exit(1)


class ThreadWithException(threading.Thread):
    """Thread keeping track of a failure of its target so that the joining thread can raise it."""
    
    
    def run(self):
        self.exception = None
        try:
            # Python 2 and Python 3 store the target and its arguments under different attribute names
            # https://stackoverflow.com/questions/2829329/catch-a-threads-exception-in-the-caller-thread-in-python
            if hasattr(self, '_Thread__target'):
                # Thread uses name mangling prior to Python 3.
                target, args, kwargs = self._Thread__target, self._Thread__args, self._Thread__kwargs
            else:
                target, args, kwargs = self._target, self._args, self._kwargs
            target(*args, **kwargs)
        except BaseException:
            # Only the formatted traceback is kept, wrapped into a plain Exception
            self.exception = Exception('Exception detected from thread: %s' % traceback.format_exc())
    
    
    def join_with_exception(self):
        """Wait for the thread to end, then raise the exception recorded by run(), if any."""
        super(ThreadWithException, self).join()
        failure = self.exception
        if failure:
            raise failure
        return None


# Root of the Shinken configuration; daemons are read from "<root>/<daemon type>s/*.cfg"
SHINKEN_CONFIGURATION_PATH = "/etc/shinken"
# Daemon types whose configuration folders are scanned by get_architecture_hosts()
DAEMONS = ["arbiter", "broker", "poller", "reactionner", "receiver", "scheduler", "synchronizer"]

DEFAULT_CFG_PATH = "/etc/shinken/shinken_auto_update_config.json"

# Relative path of the logs directory created by init()
INSTALL_LOGS_PATH = "shinken_auto_update_logs"

# Names given to the archive / patch once copied into the remote working folder
SHINKEN_ARCHIVE_REMOTE_NAME = "shinken_auto_update_archive.tar.gz"
SHINKEN_PATCH_REMOTE_NAME = "shinken_auto_update_patch.tar.gz"
# Filled by upload_archive() with the path of the extracted archive on the remote host
SHINKEN_INSTALL_PATH_EXEC = ""
# The timestamp is computed once, when the module is loaded
SHINKEN_REMOTE_LOCK_FILE = "/tmp/shinken_auto_update_lock_%s.lock" % datetime.now().strftime("%Y_%m_%d_%Hh%Mm%S")

# When True, the ssh / scp command lines are printed before being run
DEBUG = False

hosts_addresses = set()
shinken_host_list = HostList()
mongodb_cluster_topography = []
# Serializes console output and remote lock handling between the task threads
globalLock = threading.Lock()
configuration = Configuration()


def launch_task_on_all_hosts(task_name, error_recap_title, error_recap_message, status_counter, function_name, args_list=(), custom_host_list=None):
    """Run `function_name` on every host, one thread per host, then print the recap.
    
    The task is called as function_name(host, status_counter, *args_list).
    Hosts are taken from `custom_host_list` when it is given and not empty, from the
    global host list otherwise. `task_name` is printed as a title unless it is None.
    
    The script exits with status 1 when a task thread raised an exception, or when
    `status_counter` holds at least one critical result once all threads are done.
    """
    # NOTE: an empty custom_host_list falls back to the global host list
    hosts = custom_host_list if custom_host_list else shinken_host_list.get_hosts()
    
    if task_name is not None:
        print("")
        print_h1(task_name)
    
    # tuple() lets callers give the extra arguments as a list as well as a tuple
    extra_args = tuple(args_list)
    
    thread_pool = []
    for host in hosts:
        t = ThreadWithException(target=function_name, args=(host, status_counter) + extra_args)
        thread_pool.append({
            "thread": t,
            "host"  : host,
        })
        t.start()
    
    for info in thread_pool:
        try:
            info["thread"].join_with_exception()
        except Exception:
            # Not a bare except: KeyboardInterrupt and SystemExit must not be reported as a host failure
            print("")
            print_h1("IRRECOVERABLE ERROR DETECT WHEN PERFORMING OPERATION ON HOST", line_color="red", title_color="red")
            cprint(
                "An irrecoverable has been raised when processing host '%s'. \nTo prevent broken Shinken updates or patch applications, all operations on hosts are stopped. \n\nPlease contact your Shinken support providing the following error: \n" % info[
                    "host"].get_display_name(), color="red")
            traceback.print_exc()
            sys.exit(1)
    
    DisplayUtils.recap(ok=status_counter.get_ok(), warning=status_counter.get_warning(), critical=status_counter.get_critical())
    DisplayUtils.exit_if_errors(status_counter,
                                title=error_recap_title,
                                message=error_recap_message)


def get_clean_filename(filename_to_clean):
    """Return a file-system friendly version of `filename_to_clean`.
    
    Spaces and dots become underscores; ASCII letters, digits and underscores are kept;
    every other character is dropped.
    """
    kept_chars = set("_%s%s" % (string.ascii_letters, string.digits))
    
    cleaned = []
    for c in filename_to_clean:
        if c == ' ' or c == '.':
            # Spaces used to be dropped before they could be replaced: map them explicitly
            cleaned.append('_')
        elif c in kept_chars:
            cleaned.append(c)
    
    return ''.join(cleaned)


def launch_ssh_command(host, command, timeout=None):
    """Run `command` on `host` as root through ssh and return (return code, stdout, stderr).
    
    The ssh call is wrapped by the `timeout` utility. When `timeout` is None, the host
    "ssh_default_timeout" option is used, then the global one.
    Raises SSHCommandTimeout when the return code is 124.
    """
    if timeout is None:
        fallback_timeout = configuration.get("ssh_default_timeout")
        timeout = host.get_configuration().get("ssh_default_timeout", fallback_timeout)
    
    # The remote command is placed between single quotes in the local shell command line
    ssh_command = """ssh -o BatchMode=yes -o LogLevel=ERROR -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no root@%s '%s'""" % (host.get_addr(), command)
    cmd = "timeout %d %s" % (timeout, ssh_command)
    
    if DEBUG:
        print(cmd)
    
    process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    raw_out, raw_err = process.communicate()
    out = raw_out.decode('utf8')
    err = raw_err.decode('utf8')
    
    return_code = process.returncode
    if return_code == 124:
        raise SSHCommandTimeout(timeout)
    
    return return_code, out, err


def copy_via_ssh(host, src, dest, timeout=None):
    """Copy the local file `src` to `dest` on `host` (as root) with scp.
    
    Returns (return code, stdout, stderr). The scp call is wrapped by the `timeout`
    utility. When `timeout` is None, the host "ssh_default_timeout" option is used,
    then the global one.
    Raises SSHCommandTimeout when the return code is 124.
    """
    if timeout is None:
        timeout = host.get_configuration().get("ssh_default_timeout", configuration.get("ssh_default_timeout"))
    
    cmd_without_timeout = """scp -o BatchMode=yes -o LogLevel=ERROR -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no '%s' root@%s:%s""" % (src, host.get_addr(), dest)
    cmd = "timeout %d %s" % (timeout, cmd_without_timeout)
    
    if DEBUG:
        print(cmd)
    
    p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
    out, err = p.communicate()
    out = out.decode('utf8')
    # stderr was previously overwritten with stdout, which hid every scp error from the callers
    err = err.decode('utf8')
    
    if p.returncode == 124:
        raise SSHCommandTimeout(timeout)
    
    return p.returncode, out, err


def init():
    """Prepare an empty logs directory.
    
    An existing logs directory is first renamed with a timestamped "_backup_" suffix.
    Failures are printed and do not stop the script.
    """
    if os.path.exists(INSTALL_LOGS_PATH):
        try:
            shutil.move(INSTALL_LOGS_PATH, "%s_backup_%s" % (INSTALL_LOGS_PATH, datetime.now().strftime("%Y_%m_%d_%Hh%Mm%S")))
        except OSError as e:
            # OSError has no str() method: the previous e.str() call crashed the error handler itself
            print("Could not clean logs directory (%s): %s" % (INSTALL_LOGS_PATH, e))
    
    try:
        os.mkdir(INSTALL_LOGS_PATH)
    except OSError as e:
        print("Could not create logs directory (%s): %s" % (INSTALL_LOGS_PATH, e))


################################################
######      TASK: GENERATE_HOST_LIST      ######
################################################


def read_daemon_cfg_file(file_path):
    """Parse a daemon cfg file into a dict.
    
    Each line not starting with "#" and made of a word followed by a value yields one
    entry; the value is cut at the first ";" and stripped. A key defined several
    times keeps its last value.
    """
    conf = {}
    # Context manager: the file was previously never closed
    with open(file_path) as cfg_file:
        for line in cfg_file:
            if line.startswith("#"):
                continue
            
            m = re.match(r"\s*(?P<key>\w+)\s*(?P<value>.*)", line)
            if m is None:
                continue
            
            conf[m.group('key').strip()] = m.group('value').split(";")[0].strip()
    
    return conf


def get_architecture_hosts():
    """Fill the global host list with the address of every enabled Shinken daemon.
    
    The "*.cfg" files of each daemon type folder are read; a daemon is considered
    enabled unless its "enabled" property differs from '1'. The scan is done only once.
    """
    if shinken_host_list.configuration_scan_done:
        return
    
    print("")
    print_h1("Retrieving hosts information from Shinken cfg files")
    
    for daemon_type in DAEMONS:
        daemon_dir = "%s/%ss" % (SHINKEN_CONFIGURATION_PATH, daemon_type)
        for file_name in os.listdir("%s/" % daemon_dir):
            if not fnmatch.fnmatch(file_name, "*.cfg"):
                continue
            
            daemon_conf = read_daemon_cfg_file("%s/%s" % (daemon_dir, file_name))
            if daemon_conf.get("enabled", '1') != '1':
                continue
            
            daemon_address = daemon_conf.get("address")
            if daemon_address is not None:
                shinken_host_list.add_host(daemon_address, is_shinken_host=True)
    
    shinken_host_list.configuration_scan_done = True
    
    cprint(u' %s%s ' % (CHARACTERS.corner_bottom_left, CHARACTERS.arrow_left), color="grey", end='')
    print("Correctly read addresses from Shinken daemons configuration files")


def build_host_list(args):
    """Rebuild the host list from the Shinken cfg files and save it into `args.config`.
    
    An already existing configuration file is renamed with a timestamped ".backup_"
    suffix first; the script exits with status 1 if this backup fails.
    """
    get_architecture_hosts()
    list_hosts(args)
    
    config_path = args.config
    if os.path.exists(config_path):
        backup_path = "%s.backup_%s" % (config_path, datetime.now().strftime("%Y_%m_%d_%Hh%Mm%S"))
        try:
            shutil.move(config_path, backup_path)
            cprint(u' %s%s ' % (CHARACTERS.corner_bottom_left, CHARACTERS.arrow_left), color="grey", end='')
            print("Host list is being built but a configuration file already exists")
            print("    To avoid data loss, existing configuration file (%s) has been backed up (%s)" % (config_path, backup_path))
        except OSError as e:
            print("")
            print_h1("CONFIGURATION ERROR", line_color="red", title_color="red")
            cprint("Could not backup configuration file (%s) before overwrite: %s" % (config_path, str(e)), color="red")
            sys.exit(1)
    
    configuration.save_to_file(args.config)


########################################
######      TASK: LIST_HOSTS      ######
########################################


def print_host_list(host_list):
    """Print the number of hosts followed by one line per host.
    
    Hosts flagged both as mongodb cluster member and as able to update the other
    cluster nodes get a "(mongodb cluster host)" suffix.
    """
    lines = []
    for host in host_list:
        flagged_as_mongo = host.is_mongo_cluster_member and host.can_update_other_mongo_cluster_nodes
        if not flagged_as_mongo:
            lines.append("     %s %s" % (sprintf("-", color='grey'), host))
            continue
        
        lines.append("     %s %s %s" % (sprintf("-", color='grey'), host, sprintf("(mongodb cluster host)", color="cyan")))
    
    cprint(u' %s%s ' % (CHARACTERS.corner_bottom_left, CHARACTERS.arrow_left), color="grey", end='')
    cprint(u' %s ' % CHARACTERS.check, color='green', end='')
    cprint("%d host(s) found: " % len(host_list), color='green')
    print('\n'.join(lines))


def list_hosts(args):
    """Load the hosts from the configuration file `args.config`, then print the global host list."""
    config_file_path = args.config
    configuration.read_from_file(config_file_path)
    
    print("")
    print_h1("Current host list")
    
    known_hosts = shinken_host_list.get_hosts()
    print_host_list(known_hosts)


###################################
######      TASK: CHECK      ######
###################################

def check_host_connection(host, status_counter):
    """Run a trivial command on `host` through ssh and report the result in `status_counter`.
    
    The connection is reported as ERROR when the command times out ("ping_timeout"
    option) or when it writes anything on its error output, as OK otherwise.
    """
    print(" %s Checking connection to '%s'" % (sprintf("[.....]", color="grey"), host.get_display_name()))
    
    ping_timeout = host.get_configuration().get("ping_timeout", configuration.get("ping_timeout"))
    try:
        rc, out, err = launch_ssh_command(host, "echo 'OK'", timeout=ping_timeout)
    except SSHCommandTimeout as e:
        timeout_message = "Connection to '%s' failed: could not connect to server in less than %s (timeout)" % (host.get_display_name(), DisplayUtils.display_seconds(e.timeout))
        DisplayUtils.task_status(status="ERROR", message=timeout_message, counter=status_counter)
        return
    
    if not err:
        DisplayUtils.task_status("OK", "Connection to '%s' successful" % host.get_display_name(), status_counter)
    else:
        DisplayUtils.task_status("ERROR", "Connection to '%s' failed: '%s'" % (host.get_display_name(), err.strip()), status_counter)


def check(args):
    """Print the host list, then check the ssh connection to every host of the architecture."""
    list_hosts(args)
    
    task_settings = {
        "task_name"          : "Checking connection to architecture hosts",
        "error_recap_title"  : "HOSTS CONNECTION ERROR",
        "error_recap_message": "Could not connect to every Shinken installation hosts correctly. The update process is aborted to avoid having a partial update",
        "status_counter"     : StatusCounter(),
        "function_name"      : check_host_connection,
    }
    launch_task_on_all_hosts(**task_settings)


####################################
######      TASK: UPLOAD      ######
####################################


def get_file_md5(filepath):
    """Return the hexadecimal MD5 digest of the file, read by blocks of 4096 bytes."""
    digest = hashlib.md5()
    with open(filepath, "rb") as stream:
        while True:
            block = stream.read(4096)
            if block == b"":
                # End of file reached
                break
            digest.update(block)
    
    return digest.hexdigest()


def find_remote_extracted_archive_path(host, args):
    """List the remote working folder and keep the first entry named like an extracted archive.
    
    The expected name prefix depends on `args.patch`. Returns the
    (return code, stdout, stderr) tuple of launch_ssh_command().
    """
    remote_dir = host.get_configuration().get("tmp_dir", configuration.get("tmp_dir"))
    
    grep_str = "^shinken-enterprise_.*"
    if getattr(args, "patch", False):
        grep_str = "^shinken-enterprise-patch_.*"
    
    listing_command = "ls %s/ | grep -E '%s' | head -n 1" % (remote_dir, grep_str)
    return launch_ssh_command(host, listing_command)


def check_uploaded_archive_checksum(host, args):
    """Compare the MD5 checksum of the local archive with the one of the uploaded archive.
    
    Returns None when both checksums match.
    Raises ShinkenArchiveChecksumMismatch when they differ, and Exception when one of
    the checksums could not be computed (local read error, remote error or timeout).
    """
    shinken_archive_remote_path = host.get_configuration().get("tmp_dir", configuration.get("tmp_dir"))
    if getattr(args, "patch", False):
        shinken_archive_remote_fullpath = "%s/%s" % (shinken_archive_remote_path, SHINKEN_PATCH_REMOTE_NAME)
    else:
        shinken_archive_remote_fullpath = "%s/%s" % (shinken_archive_remote_path, SHINKEN_ARCHIVE_REMOTE_NAME)
    
    try:
        local_archive_checksum = get_file_md5(args.archive)
    except Exception as e:
        raise Exception("Could not get local Shinken archive checksum: %s" % e)
    
    try:
        rc, out, err = launch_ssh_command(host, "md5sum %s | cut -d\" \" -f 1" % shinken_archive_remote_fullpath)
    except SSHCommandTimeout as e:
        raise Exception("Could not get remote Shinken archive checksum on '%s' in less than %s (timeout)" % (host.get_display_name(), DisplayUtils.display_seconds(e.timeout)))
    
    if err:
        # strip(), not split(): the message must carry the error text, not the repr of a list of words
        raise Exception("Could not get remote Shinken archive checksum: %s" % err.strip())
    
    remote_archive_checksum = out.strip()
    if local_archive_checksum != remote_archive_checksum:
        raise ShinkenArchiveChecksumMismatch(
            "An archive has been found on the remote server but local and uploaded archive checksums do not match: local archive checksum is '%s' while remote archive checksum is '%s'" % (local_archive_checksum, remote_archive_checksum))


def _release_upload_lock(host, status_counter):
    # Remove the remote lock placed by upload_archive(); a failure is reported as an ERROR
    try:
        with globalLock:
            host.unlock()
    except HostNotLocked:
        DisplayUtils.task_status("ERROR", "No lock has been found on '%s' when trying to remove it." % host.get_display_name(), status_counter)
    except SSHCommandTimeout as e:
        DisplayUtils.task_status("ERROR",
                                 "Could not remove remote lock for '%s' in less than %s (timeout). Future operations performed by this script might fail" % (host.get_display_name(), DisplayUtils.display_seconds(e.timeout)),
                                 status_counter)
    except Exception as e:
        DisplayUtils.task_status("ERROR",
                                 "Could not remove remote lock on '%s': %s. Future operations performed by this script might fail" % (host.get_display_name(), e),
                                 status_counter)


def _upload_archive_to_locked_host(host, status_counter, args):
    # Upload and extract the archive on a host already locked by upload_archive(),
    # unless an identical archive is already there; the outcome is reported in status_counter
    global SHINKEN_INSTALL_PATH_EXEC
    
    host_configuration = host.get_configuration()
    is_patch = getattr(args, "patch", False)
    
    shinken_archive_remote_path = host_configuration.get("tmp_dir", configuration.get("tmp_dir"))
    remote_name = SHINKEN_PATCH_REMOTE_NAME if is_patch else SHINKEN_ARCHIVE_REMOTE_NAME
    shinken_archive_remote_fullpath = "%s/%s" % (shinken_archive_remote_path, remote_name)
    
    try:
        rc, out, err = launch_ssh_command(host, "stat %s" % shinken_archive_remote_fullpath)
        remote_shinken_archive_found = rc == 0 and not (err and "No such file or directory" in err)
    except SSHCommandTimeout:
        # rc and err do not exist in that case: the archive is simply considered as not found
        remote_shinken_archive_found = False
    
    force_archive_reupload = False
    if remote_shinken_archive_found:
        try:
            check_uploaded_archive_checksum(host, args)
        except ShinkenArchiveChecksumMismatch as e:
            DisplayUtils.task_status(status="INFO",
                                     message="%s. Shinken archive will be reuploaded" % str(e),
                                     counter=status_counter)
            force_archive_reupload = True
        except Exception as e:
            DisplayUtils.task_status(status="WARNING",
                                     message="Could not compare local and remote (host %s) Shinken archive checksum: '%s'. Shinken archive will be reuploaded" % (host.get_display_name(), e),
                                     counter=status_counter)
            force_archive_reupload = True
    
    force_archive_reupload = force_archive_reupload or host_configuration.get("force_reupload", configuration.get("force_reupload"))
    if remote_shinken_archive_found and not force_archive_reupload:
        DisplayUtils.task_status("OK", "Shinken archive already present on '%s'. No upload needed" % host.get_display_name(), status_counter)
        return
    
    # Step 1: start from an empty working folder
    try:
        rc, out, err = launch_ssh_command(host, "rm -rf %s && mkdir -p %s" % (shinken_archive_remote_path, shinken_archive_remote_path))
    except SSHCommandTimeout as e:
        DisplayUtils.task_status(status="ERROR",
                                 message="Could not clean working folder on '%s' in less than %s (timeout)" % (host.get_display_name(), DisplayUtils.display_seconds(e.timeout)),
                                 counter=status_counter)
        return
    if err:
        DisplayUtils.task_status(status="ERROR",
                                 message="Could not clean working folder on '%s' (%s): %s" % (host.get_display_name(), shinken_archive_remote_path, err.strip()),
                                 counter=status_counter)
        return
    
    # Patches and full archives have their own upload / extraction timeouts
    timeout_prefix = "patch" if is_patch else "archive"
    upload_timeout_key = "%s_upload_timeout" % timeout_prefix
    extract_timeout_key = "%s_extract_timeout" % timeout_prefix
    
    # Step 2: copy the archive
    try:
        timeout = host_configuration.get(upload_timeout_key, configuration.get(upload_timeout_key))
        rc, out, err = copy_via_ssh(host, args.archive, shinken_archive_remote_fullpath, timeout=timeout)
    except SSHCommandTimeout as e:
        DisplayUtils.task_status(status="ERROR",
                                 message="Could not upload Shinken archive to '%s' in less than %s (timeout)" % (host.get_display_name(), DisplayUtils.display_seconds(e.timeout)),
                                 counter=status_counter)
        return
    if err:
        DisplayUtils.task_status(status="ERROR",
                                 message="Shinken archive upload to '%s' failed: '%s'" % (host.get_display_name(), err.strip()),
                                 counter=status_counter)
        return
    
    # Step 3: extract it in the working folder
    try:
        timeout = host_configuration.get(extract_timeout_key, configuration.get(extract_timeout_key))
        rc, out, err = launch_ssh_command(host, "cd %s/ && tar -xzvf %s" % (shinken_archive_remote_path, shinken_archive_remote_fullpath),
                                          timeout=timeout)
    except SSHCommandTimeout as e:
        DisplayUtils.task_status(status="ERROR",
                                 message="Could not extract Shinken archive to '%s' in less than %s (timeout)" % (host.get_display_name(), DisplayUtils.display_seconds(e.timeout)),
                                 counter=status_counter)
        return
    if err:
        DisplayUtils.task_status(status="ERROR",
                                 message="Could not extract archive on '%s': %s" % (host.get_display_name(), err.strip()),
                                 counter=status_counter)
        return
    
    # Step 4: locate the extracted folder
    try:
        rc, out, err = find_remote_extracted_archive_path(host, args)
    except SSHCommandTimeout as e:
        DisplayUtils.task_status(status="ERROR",
                                 message="Could not find extracted Shinken archive files on '%s' in less than %s (timeout)" % (host.get_display_name(), DisplayUtils.display_seconds(e.timeout)),
                                 counter=status_counter)
        return
    
    if err or out.strip() == "":
        msg = "Could not locate extracted archive path on '%s'" % host.get_display_name()
        if err:
            upload_message = "%s: %s" % (msg, err.strip())
        else:
            upload_message = "%s: %s" % (msg, "Archive path not found in '%s' on remote host" % shinken_archive_remote_path)
        
        DisplayUtils.task_status("ERROR", upload_message, status_counter)
        return
    
    SHINKEN_INSTALL_PATH_EXEC = "%s/%s" % (shinken_archive_remote_path, out.strip())
    
    DisplayUtils.task_status("OK", "Shinken archive uploaded to '%s'" % host.get_display_name(), status_counter)


def upload_archive(host, status_counter, args):
    """Upload the Shinken archive (or patch) `args.archive` to `host` and extract it there.
    
    The host is locked during the operation, so that a machine reachable through
    several addresses is processed only once. The upload is skipped when an archive
    with the same checksum is already present and no re-upload is forced.
    Every outcome is reported through DisplayUtils.task_status() into `status_counter`.
    On success of a new upload, the global SHINKEN_INSTALL_PATH_EXEC is set to the
    remote path of the extracted archive.
    """
    print(" %s Uploading Shinken archive to '%s'" % (sprintf("[.....]", color="grey"), host.get_display_name()))
    configuration.set("force_reupload", args.force_reupload, namespace="cli")
    
    if not getattr(args, "patch", False) and "shinken-enterprise-patch_" in args.archive:
        DisplayUtils.task_status(status="ERROR",
                                 message="The provided archive seems to be a Shinken patch but the --is-patch option has not been set. If you really want to upload a patch, try adding the --patch option",
                                 counter=status_counter)
        return
    
    try:
        with globalLock:
            host.lock()
    except HostAlreadyLocked as e:
        DisplayUtils.task_status("WARNING",
                                 "Archive upload on host '%s' has been skipped because it is the same machine as '%s'" % (host.get_display_name(), e.locked_by.strip()),
                                 status_counter)
        return
    except SSHCommandTimeout as e:
        DisplayUtils.task_status("ERROR",
                                 "Could not place remote lock for '%s' in less than %s (timeout)" % (host.get_display_name(), DisplayUtils.display_seconds(e.timeout)),
                                 status_counter)
        return
    except Exception as e:
        DisplayUtils.task_status("ERROR", "Could not place remote lock for '%s': %s" % (host.get_display_name(), e), status_counter)
        return
    
    try:
        _upload_archive_to_locked_host(host, status_counter, args)
    finally:
        # The lock is released whatever the outcome; it used to stay on the host after a failed upload
        _release_upload_lock(host, status_counter)


def upload(args):
    """Run check(args), then launch the upload_archive task on all hosts.

    The task title and the error recap texts depend on whether ``args`` carries
    a truthy ``patch`` attribute (patch upload) or not (archive upload).
    """
    check(args)
    
    if getattr(args, "patch", False):
        task_texts = {
            "task_name"          : "Uploading Shinken patch to architecture hosts",
            "error_recap_title"  : "SHINKEN PATCH UPLOAD ERROR",
            "error_recap_message": "Could not copy the Shinken patch correctly on every installation hosts. The patch process is aborted to avoid having a partial patch application",
        }
    else:
        task_texts = {
            "task_name"          : "Uploading Shinken archive to architecture hosts",
            "error_recap_title"  : "SHINKEN ARCHIVE UPLOAD ERROR",
            "error_recap_message": "Could not copy the Shinken archive correctly on every installation hosts. The update process is aborted to avoid having a partial update",
        }
    
    launch_task_on_all_hosts(status_counter=StatusCounter(),
                             function_name=upload_archive,
                             args_list=(args,),
                             **task_texts)


#######################################################
######      TASK: PREINSTALL/PREPATCH_CHECK      ######
#######################################################

def perform_check(host, status_counter, args, check):
    """Run one preinstall check on a remote host through SSH and report its status.

    :param host: host on which the check command is launched
    :param status_counter: counter forwarded to DisplayUtils.task_status
    :param args: parsed command line arguments (not used by this function)
    :param check: dict describing the check. "type" selects how the command is built:
                  "commonsh" sources lib/common.sh and calls the function named by "name",
                  any other type executes "path" from the extracted archive folder.
                  "display_name" is used in the reported messages.
    :return: True when the remote command exits with code 0, False otherwise (SSH timeout included)
    """
    shinken_archive_remote_path = host.get_configuration().get("tmp_dir", configuration.get("tmp_dir"))
    
    if check.get("type") == "commonsh":
        ssh_command = "bash -c \"cd %s/shinken-enterprise*/ && source %s/shinken-enterprise*/lib/common.sh && %s\"" % (shinken_archive_remote_path, shinken_archive_remote_path, check.get("name"))
    else:
        ssh_command = "bash -c \"cd %s/shinken-enterprise*/ && %s\"" % (shinken_archive_remote_path, check.get("path"))
    
    try:
        rc, out, err = launch_ssh_command(host, ssh_command, timeout=host.get_configuration().get("preinstall_check_timeout", configuration.get("preinstall_check_timeout")))
    except SSHCommandTimeout as e:
        DisplayUtils.task_status(status="ERROR",
                                 message="Could not execute preinstall check '%s' on '%s' in less than %s (timeout)" % (check.get("display_name"), host.get_display_name(), DisplayUtils.display_seconds(e.timeout)),
                                 counter=status_counter)
        # A timed out check is a failed check: return an explicit boolean like every other path
        return False
    
    if rc != 0:
        # Standard output is preferred over standard error to describe the failure
        if out:
            error_message = "Check failure on '%s' (%s): '%s'" % (host.get_display_name(), check.get("display_name"), out.strip())
        elif err:
            error_message = "Check failure on '%s' (%s): '%s'" % (host.get_display_name(), check.get("display_name"), err.strip())
        else:
            error_message = "Check failure on '%s' (%s): no error message" % (host.get_display_name(), check.get("display_name"))
        
        DisplayUtils.task_status("ERROR", error_message, status_counter)
        return False
    
    # Successful checks are counted but not printed one by one
    DisplayUtils.task_status("OK", "", status_counter, print_status=False)
    return True


def extract_mongo_cluster_members_from_raw_string(raw_str):
    """Extract the host names of Mongodb cluster members from a mongo shell scan result.

    Accepted forms (the surrounding double quotes added by JSON.stringify() are optional):
    - sharded cluster:  replicasetname/node1:27017,node2:27017,node3:27017
    - replicaset only:  "node1:27017,node2:27017,node3:27017"

    :param raw_str: stripped output of the scan command
    :return: set of member host names, without their port
    :raise Exception: when the string is neither quoted nor of the "name/members" form
    """
    # JSON.stringify() wraps any string result in double quotes, whatever the scan command was
    if len(raw_str) >= 2 and raw_str.startswith('"') and raw_str.endswith('"'):
        raw_str = raw_str[1:-1]
        # A replicaset-only result has no "name/" prefix: add a dummy one so both forms are parsed by the same code.
        # A quoted sharded result already has its prefix and must not get a second one.
        if "/" not in raw_str:
            raw_str = "replicaset/%s" % raw_str
    
    raw_members = raw_str.split("/")
    if len(raw_members) < 2:
        raise Exception("Could not retrieve mongodb cluster members from the following string: '%s'" % raw_str)
    
    members = set()
    for member_with_port in raw_members[1].split(","):
        member = member_with_port.split(":")[0].strip()
        # An empty member list must not produce an empty host name
        if member:
            members.add(member)
    
    return members


def get_mongodb_cluster_hosts(host):
    """Query the mongo shell on ``host`` to list the nodes of the Mongodb cluster it belongs to.

    Exceptions raised by launch_ssh_command are not caught here.

    :param host: host on which the mongo scan commands are launched through SSH
    :return: HostList of the nodes reported by the scan commands (added with is_mongo_cluster_member=True);
             empty when every command ended with return code 252
    :raise Exception: when a scan command writes to its error output or ends with a return code other than 0 or 252
    """
    # 2 different cluster configurations can be used
    scan_commands = [
        # Complex mongo cluster with sharding (mongos and mongo-configsrv daemons)
        "mongo --quiet --eval \"JSON.stringify(db.adminCommand( { listShards: 1 } ).shards[0].host);\"",
        # Simple mongo cluster with only mongod nodes and a replicaset
        "mongo --quiet --eval \"JSON.stringify(rs.config().members.map(function(m) { return m.host }).join());\""
    ]
    
    cluster_nodes_list = HostList()
    for command in scan_commands:
        rc, out, err = launch_ssh_command(host, command)
        
        # 252 is the code returned by mongo when the command failed. On classic mongodb installations or simple mongodb cluster installations, at least one of these commands will fail because there is no shard or replicaset configured
        if rc == 252:
            continue
        
        if rc != 0 or err:
            # The return code goes in the exception message so the caller can report it
            error_output = err.strip() if err else ""
            if error_output:
                error_message = "%s (return code: %s)" % (error_output, rc)
            else:
                error_message = "Mongodb scan command failed with return code %s and no error output" % rc
            raise Exception(error_message)
        
        raw_cluster_member_list = extract_mongo_cluster_members_from_raw_string(out.strip())
        for node in raw_cluster_member_list:
            cluster_nodes_list.add_host(node, is_mongo_cluster_member=True)
    
    return cluster_nodes_list


def mongodb_cluster_preinstall_scan(host, status_counter, args, mongodb_cluster_topography):
    print " %s Checking Mongodb configuration on '%s' to handle possible Mongodb cluster setup during update" % (sprintf("[.....]", color="grey"), host.get_display_name())
    
    try:
        members = get_mongodb_cluster_hosts(host)
    except SSHCommandTimeout as e:
        DisplayUtils.task_status(status="ERROR",
                                 message="Could not check Mongodb configuration on '%s' in less than %s (timeout)" % (host.get_display_name(), DisplayUtils.display_seconds(e.timeout)),
                                 counter=status_counter)
        return
    except Exception as e:
        DisplayUtils.task_status(status="ERROR",
                                 message="Could not check Mongodb configuration on '%s': %s" % (host.get_display_name(), e),
                                 counter=status_counter)
        return
    
    if members.is_empty():
        DisplayUtils.task_status(status="OK",
                                 message="The host '%s' is not part of a Mongodb cluster" % host.get_display_name(),
                                 counter=status_counter)
        return
    
    # Check if member is part of an already scanned Mongodb cluster
    must_save_cluster_in_topography_set = False
    for scanned_cluster in mongodb_cluster_topography:
        # Cluster already saved but some elements differ
        if len(members.get_hosts().intersection(scanned_cluster.get_hosts())) > 0:
            diff_list = members.get_hosts().intersection(scanned_cluster.get_hosts())
            for host in diff_list:
                members.add_host(host.get_addr())
            must_save_cluster_in_topography_set = True
    
    if must_save_cluster_in_topography_set or len(mongodb_cluster_topography) == 0:
        with globalLock:
            mongodb_cluster_topography.append(members)
    
    DisplayUtils.task_status(status="OK",
                             message="The host '%s' is part of a Mongodb cluster containing the following hosts: '%s'" % (host.get_display_name(), members),
                             counter=status_counter)


def check_mongo_node_can_connect_to_all_other_nodes(host, status_counter, cluster):
    can_connect_to_other_nodes = True
    
    for other_node in cluster.get_hosts():
        print " %s Checking SSH connection from '%s' to '%s'" % (sprintf("[.....]", color="grey"), host.get_display_name(), other_node.get_display_name())
        try:
            rc, out, err = launch_ssh_command(host,
                                              "ssh -o BatchMode=yes -o LogLevel=ERROR -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no root@%s 'echo \"OK\"'" % other_node.get_addr(),
                                              timeout=host.get_configuration().get("ping_timeout", configuration.get("ping_timeout")))
        except SSHCommandTimeout as e:
            can_connect_to_other_nodes = False
            DisplayUtils.task_status(status="LOG",
                                     message="SSH connection cannot be established from '%s' to '%s' for Mongodb cluster update in less than %d (timeout) (not a critical error)" % (
                                         host.get_display_name(), other_node.get_display_name(), DisplayUtils.display_seconds(e.timeout)),
                                     counter=status_counter)
            break
        
        if err or rc != 0:
            can_connect_to_other_nodes = False
            DisplayUtils.task_status(status="LOG",
                                     message="SSH connection cannot be established from '%s' to '%s' for Mongodb cluster update (not a critical error): %s" % (host.get_display_name(), other_node.get_display_name(), err.strip()),
                                     counter=status_counter)
            break
    
    host.can_update_other_mongo_cluster_nodes = can_connect_to_other_nodes
    if can_connect_to_other_nodes:
        DisplayUtils.task_status(status="OK",
                                 message="SSH connection can be established from '%s' to '%s' for Mongodb cluster update" % (host.get_display_name(), cluster),
                                 counter=status_counter)


def find_mongo_cluster_update_candidates(mongo_cluster):
    print ""
    print_h1("Looking for Mongodb cluster update candidates")
    launch_task_on_all_hosts(task_name=None,
                             error_recap_title="MONGODB CLUSTER HOSTS SCAN ERROR",
                             error_recap_message="Errors have been detected when scanning Mongodb cluster members to find candidates for update. The update process is stopped on all hosts to avoid incomplete/broken updates due to an unhandled Mongodb configuration. \nNothing has been done on remote hosts yet except uploading and extracting the Shinken archive in a temporary folder",
                             status_counter=StatusCounter(),
                             function_name=check_mongo_node_can_connect_to_all_other_nodes,
                             args_list=(mongo_cluster,),
                             custom_host_list=mongo_cluster.get_hosts())
    
    return [h for h in mongo_cluster.get_hosts() if h.can_update_other_mongo_cluster_nodes]


def host_preinstall_check(host, status_counter, args, is_patch=False):
    if is_patch:
        setattr(args, "patch", True)
        print " %s Performing pre patch checks on '%s'" % (sprintf("[.....]", color="grey"), host.get_display_name())
    else:
        print " %s Performing pre update checks on '%s'" % (sprintf("[.....]", color="grey"), host.get_display_name())
    
    shinken_archive_remote_path = host.get_configuration().get("tmp_dir", configuration.get("tmp_dir"))
    
    try:
        rc, out, err = find_remote_extracted_archive_path(host, args)
    except SSHCommandTimeout as e:
        DisplayUtils.task_status(status="ERROR",
                                 message="Could not find extracted Shinken archive files on '%s' in less than %s (timeout)" % (host.get_display_name(), DisplayUtils.display_seconds(e.timeout)),
                                 counter=status_counter)
        return
    if err or out.strip() == "":
        upload_status = "ERROR"
        msg = "Could not locate extracted archive path on '%s'" % host.get_display_name()
        if err:
            upload_message = "%s: %s" % (msg, err.strip())
        else:
            upload_message = "%s: %s" % (msg, "Archive path not found in %s on remote host" % shinken_archive_remote_path)
        
        DisplayUtils.task_status(upload_status, upload_message, status_counter)
        return
    
    global SHINKEN_INSTALL_PATH_EXEC
    SHINKEN_INSTALL_PATH_EXEC = "%s/%s" % (shinken_archive_remote_path, out.strip())
    
    check_list = [
        {
            "name"        : "assert_is_installed",
            "display_name": "Shinken is installed",
            "type"        : "commonsh",
        },
        {
            "name"        : "assert_redhat_or_centos",
            "display_name": "Check CentOS/RHEL version",
            "type"        : "commonsh",
        },
        {
            "name"        : "assert_current_version_not_too_old",
            "display_name": "Current Shinken version check",
            "type"        : "commonsh",
        },
        {
            "display_name": "No duplicates in Synchronizer DB",
            "type"        : "script",
            "path"        : "lib/preupdate-checks/check_db_for_duplicates_wrapper.sh",
        },
        {
            "name"        : "assert_only_valid_mongo",
            "display_name": "Mongodb version",
            "type"        : "commonsh",
        },
        {
            "display_name": "At least 3 Gb available on disk",
            "type"        : "script",
            "path"        : "/usr/lib64/nagios/plugins/check_disk -c 3072 --units MB -p \"/\"",
        },
        {
            "name"        : "assert_protected_fields_conf_not_corrupted",
            "display_name": "Protected fields configuration not corrupted",
            "type"        : "commonsh",
        },
    ]
    
    checks_status = "OK"
    for check in check_list:
        if not perform_check(host, status_counter, args, check):
            checks_status = "ERROR"
    
    if checks_status != "OK":
        return
    
    checks_message = "Checks passed on %s" % host.get_display_name()
    
    DisplayUtils.task_status(checks_status, checks_message, status_counter)


def preinstall_check(args):
    hosts = shinken_host_list.get_hosts()
    if len(hosts) == 0:
        list_hosts(args)
    
    is_patch = False
    if args.subcommand == "prepatch-check" or args.subcommand == "patch" or args.subcommand == "launch-patch-application":
        is_patch = True
        task_name = "Performing sanity checks before patch application"
        error_recap_title = "SANITY CHECKS ERROR"
        error_recap_message = "Errors have been detected on hosts. The patch application process is stopped on all hosts to avoid incomplete/broken patch application. \nNothing has been done on remote hosts yet except uploading and extracting the Shinken patch in a temporary folder"
    else:
        task_name = "Performing sanity checks before update"
        error_recap_title = "SANITY CHECKS ERROR"
        error_recap_message = "Errors have been detected on hosts. The update process is stopped on all hosts to avoid incomplete/broken updates. \nNothing has been done on remote hosts yet except uploading and extracting the Shinken archive in a temporary folder"
    
    launch_task_on_all_hosts(task_name=task_name,
                             error_recap_title=error_recap_title,
                             error_recap_message=error_recap_message,
                             status_counter=StatusCounter(),
                             function_name=host_preinstall_check,
                             args_list=(args, is_patch,))
    
    launch_task_on_all_hosts(task_name="Mongodb configuration scan",
                             error_recap_title="MONGODB CONFIGURATION SCAN ERROR",
                             error_recap_message="Errors have been detected when scanning Mongodb configuration on hosts. The update process is stopped on all hosts to avoid incomplete/broken updates due to an unhandled Mongodb configuration. \nNothing has been done on remote hosts yet except uploading and extracting the Shinken archive in a temporary folder",
                             status_counter=StatusCounter(),
                             function_name=mongodb_cluster_preinstall_scan,
                             args_list=(args, mongodb_cluster_topography,))
    
    print ""
    print_h1("Mongodb configuration scan results")
    if len(mongodb_cluster_topography) == 0:
        print " [%s] Scan successful. No Mongodb clusters have been found on your Shinken architecture host. Listed hosts will be updated as usual" % (sprintf(" INFO", color="cyan"))
        return
    
    cprint(u' %s%s ' % (CHARACTERS.corner_bottom_left, CHARACTERS.arrow_left), color="grey", end='')
    cprint(u' %s ' % CHARACTERS.check, color='green', end='')
    cprint(" %s Mongodb cluster(s) found" % (len(mongodb_cluster_topography)), color="green")
    
    for cluster in mongodb_cluster_topography:
        cprint(u'     %s ' % (sprintf("-", color='grey')), end='')
        print ', '.join(str(host) for host in cluster.get_hosts())
    
    mongo_cluster_update_candidates_host_list = HostList()
    for mongo_cluster in mongodb_cluster_topography:
        update_candidates = find_mongo_cluster_update_candidates(mongo_cluster)
        # If no updates candidates have been found, we must check if the hosts of the mongodb cluster correspond to a Shinken host.
        # If the mongodb cluster is not hosted on Shinken machines, we do not have to udate it ourselves.
        if len(update_candidates) == 0:
            for mongo_host in mongo_cluster.get_hosts():
                if shinken_host_list.get_host(mongo_host.get_addr()) is not None:
                    message = "None of the hosts in the cluster '%s' can be used for updating Mongodb. The Mongodb update done by the Shinken update script will fail. \nThe update process is stopped on all hosts to avoid incomplete/broken updates due to an unhandled Mongodb configuration. \nNothing has been done on remote hosts yet except uploading and extracting the Shinken archive in a temporary folder" % mongo_cluster
                    
                    print ""
                    print_h1("MONGODB UPDATE CANDIDATES ERROR", line_color="red", title_color="red")
                    cprint(message, color="red")
                    sys.exit(1)
        # If we have candidates, we must check that these candidates are Shinken hosts for the update.sh script to be executed from
        else:
            shinken_host_candidate_found = False
            for candidate in update_candidates:
                shinken_host = shinken_host_list.get_host(candidate.get_addr())
                if shinken_host is not None:
                    candidate.is_shinken_host = True
                    shinken_host_candidate_found = True
            
            if not shinken_host_candidate_found:
                message = "The hosts '%s' can be used for updating Mongodb but none of them are part of the Shinken architecture. The Mongodb update done by the Shinken update script will fail. \nThe update process is stopped on all hosts to avoid incomplete/broken updates due to an unhandled Mongodb configuration. \nNothing has been done on remote hosts yet except uploading and extracting the Shinken archive in a temporary folder" % ",".join(
                    [h.get_display_name() for h in update_candidates if h.is_mongo_cluster_member and h.can_update_other_mongo_cluster_nodes])
                
                print ""
                print_h1("MONGODB UPDATE CANDIDATES ERROR", line_color="red", title_color="red")
                cprint(message, color="red")
                sys.exit(1)
            
            for host in [h for h in update_candidates if h.is_shinken_host and h.is_mongo_cluster_member and h.can_update_other_mongo_cluster_nodes]:
                mongo_cluster_update_candidates_host_list.add_host(host.get_addr())
    
    # Update shinken_host_list to specify which hosts are to considered as Mongodb hosts
    for shinken_host in shinken_host_list.get_hosts():
        # If shinken host exists in final mongodb update candidates list, mark it as such in the Shinken hosts list
        if mongo_cluster_update_candidates_host_list.get_host(shinken_host.get_addr()):
            shinken_host.is_mongo_cluster_member = True
            shinken_host.can_update_other_mongo_cluster_nodes = True
    
    print ""
    print_h1("Host list after Mongodb scan")
    print_host_list(shinken_host_list.get_hosts())


##################################################
######      TASK: LAUNCH-UPDATE-SCRIPT      ######
##################################################

def detect_shinken_update_errors(updater_output):
    """Tell whether the output of the update script contains a known error marker.

    :param updater_output: text printed by the update script
    :return: True when at least one line contains "ERROR", "FAIL" or "bad_start" (case sensitive), False otherwise
    """
    error_tags = ("ERROR", "FAIL", "bad_start")
    
    return any(
        error_tag in output_line.strip()
        for output_line in updater_output.splitlines()
        for error_tag in error_tags
    )


def update_host(host, status_counter):
    print " %s Launching Shinken update on '%s'" % (sprintf("[.....]", color="grey"), host.get_display_name())
    
    lock_message = ""
    lock_status = ""
    try:
        with globalLock:
            host.lock()
    except HostAlreadyLocked as e:
        lock_message = "Shinken update on host '%s' has been skipped because it is the same machine as '%s'" % (host.get_display_name(), e.locked_by.strip())
        lock_status = "WARNING"
    except SSHCommandTimeout as e:
        lock_message = "Could not place remote lock for '%s' in less than %s (timeout)" % (host.get_display_name(), DisplayUtils.display_seconds(e.timeout))
        lock_status = "ERROR"
    except Exception as e:
        lock_message = "Could not place remote lock for '%s': %s" % (host.get_display_name(), e)
        lock_status = "ERROR"
    finally:
        if lock_status != "":
            DisplayUtils.task_status(lock_status, lock_message, status_counter)
            return
    
    try:
        rc, out, err = launch_ssh_command(host, "cd %s && ./update.sh --disable-important-notices-user-input" % SHINKEN_INSTALL_PATH_EXEC, timeout=host.get_configuration().get("update_timeout", configuration.get("update_timeout")))
    except SSHCommandTimeout as e:
        DisplayUtils.task_status(status="ERROR",
                                 message="Could not execute Shinken update script on '%s' in less than %s (timeout)" % (host.get_display_name(), DisplayUtils.display_seconds(e.timeout)),
                                 counter=status_counter)
        return
    
    log_file_name = get_clean_filename("auto_update_%s_output" % host.get_display_name())
    with open("%s/%s.log" % (INSTALL_LOGS_PATH, log_file_name), 'w') as log_file:
        log_file.write("\nOUTPUT\n======\n")
        log_file.write(out.encode('utf8'))
        if err:
            log_file.write("\nERRORS\n======\n")
            log_file.write(err.encode('utf8'))
    
    if err or rc != 0 or detect_shinken_update_errors(out):
        DisplayUtils.task_status(status="ERROR",
                                 message="Could not perform Shinken update correctly on host %s (see %s/%s.log for more information)" % (host.get_display_name(), INSTALL_LOGS_PATH, log_file_name),
                                 counter=status_counter)
        return
    
    DisplayUtils.task_status(status="OK",
                             message="Shinken update done on %s" % host.get_display_name(),
                             counter=status_counter)
    
    try:
        with globalLock:
            host.unlock()
    except HostNotLocked:
        lock_message = "No lock has been found on '%s' when trying to remove it." % host.get_display_name()
        lock_status = "ERROR"
    except SSHCommandTimeout as e:
        lock_message = "Could not remove remote lock for '%s' in less than %s (timeout). Future operations performed by this script might fail" % (host.get_display_name(), DisplayUtils.display_seconds(e.timeout))
        lock_status = "ERROR"
    except Exception as e:
        lock_message = "Could not remove remote lock on '%s': %s. Future operations performed by this script might fail" % (host.get_display_name(), e)
        lock_status = "ERROR"
    finally:
        if lock_status != "":
            DisplayUtils.task_status(lock_status, lock_message, status_counter)
            return


def launch_update_script(args):
    """Run the preinstall checks, then the Shinken update on all hosts.

    Hosts flagged with can_update_other_mongo_cluster_nodes are updated first, in a dedicated task.
    The remaining hosts (or all hosts when none is flagged) are updated afterwards.

    :param args: parsed command line arguments, forwarded to preinstall_check
    """
    preinstall_check(args)
    
    mongo_cluster_update_hosts = []
    for shinken_host in shinken_host_list.get_hosts():
        if shinken_host.can_update_other_mongo_cluster_nodes:
            mongo_cluster_update_hosts.append(shinken_host)
    
    if len(mongo_cluster_update_hosts) == 0:
        shinken_hosts_to_update = shinken_host_list.get_hosts()
    else:
        shinken_hosts_to_update = []
        for shinken_host in shinken_host_list.get_hosts():
            if not shinken_host.can_update_other_mongo_cluster_nodes:
                shinken_hosts_to_update.append(shinken_host)
        
        # Updating Mongodb cluster hosts first
        mongo_error_recap_message = "Some problems were detected during the Shinken update of hosts that are part of a Mongodb cluster. The Shinken installation may be unstable/broken until update errors are identified and resolved. The update has not yet been launched on Shinken hosts that are not part of a Mongodb cluster. \nUpdate logs can be found in the '%s' folder for more information." % INSTALL_LOGS_PATH
        launch_task_on_all_hosts(task_name="Performing Shinken update on Mongodb cluster hosts",
                                 error_recap_title="SHINKEN UPDATE ERROR",
                                 error_recap_message=mongo_error_recap_message,
                                 status_counter=StatusCounter(),
                                 function_name=update_host,
                                 custom_host_list=mongo_cluster_update_hosts)
    
    error_recap_message = "Some problems were detected during hosts Shinken update. The Shinken installation may be unstable/broken until update errors are identified and resolved. \nUpdate logs can be found in the '%s' folder for more information." % INSTALL_LOGS_PATH
    launch_task_on_all_hosts(task_name="Performing Shinken update on architecture hosts",
                             error_recap_title="SHINKEN UPDATE ERROR",
                             error_recap_message=error_recap_message,
                             status_counter=StatusCounter(),
                             function_name=update_host,
                             custom_host_list=shinken_hosts_to_update)


####################################
######      TASK: UPDATE      ######
####################################

def update(args):
    if not os.path.exists(args.config):
        print ""
        print_h1("CONFIGURATION ERROR", line_color="red", title_color="red")
        cprint("No configuration file has been found. \nPlease generate one using the build-host-list subcommand or specify the configuration file path using the --config flag.", color="red")
        sys.exit(1)
    
    upload(args)
    launch_update_script(args)


######################################################
######      TASK: LAUNCH_PATCH_APPLICATION      ######
######################################################

def apply_patch(host, status_counter, args):
    print " %s Applying Shinken patch on '%s'" % (sprintf("[.....]", color="grey"), host.get_display_name())
    
    lock_message = ""
    lock_status = ""
    try:
        with globalLock:
            host.lock()
    except HostAlreadyLocked as e:
        lock_message = "Shinken patch application on host '%s' has been skipped because it is the same machine as '%s'" % (host.get_display_name(), e.locked_by.strip())
        lock_status = "WARNING"
    except SSHCommandTimeout as e:
        lock_message = "Could not place remote lock for '%s' in less than %s (timeout)" % (host.get_display_name(), DisplayUtils.display_seconds(e.timeout))
        lock_status = "ERROR"
    except Exception as e:
        lock_message = "Could not place remote lock for '%s': %s" % (host.get_display_name(), e)
        lock_status = "ERROR"
    finally:
        if lock_status != "":
            DisplayUtils.task_status(lock_status, lock_message, status_counter)
            return
    
    timeout = host.get_configuration().get("patch_timeout", configuration.get("patch_timeout"))
    
    # TODO: Implement patch application
    DisplayUtils.task_status(status="ERROR",
                             message="The patch application feature is not available yet. To complete patch application, connect to '%s' and execute the install-patch.sh script present in the '%s' folder" % (
                                 host.get_display_name(), configuration.get("tmp_dir")),
                             counter=status_counter)
    
    try:
        with globalLock:
            host.unlock()
    except HostNotLocked:
        lock_message = "No lock has been found on '%s' when trying to remove it." % host.get_display_name()
        lock_status = "ERROR"
    except SSHCommandTimeout as e:
        lock_message = "Could not remove remote lock for '%s' in less than %s (timeout). Future operations performed by this script might fail" % (host.get_display_name(), DisplayUtils.display_seconds(e.timeout))
        lock_status = "ERROR"
    except Exception as e:
        lock_message = "Could not remove remote lock on '%s': %s. Future operations performed by this script might fail" % (host.get_display_name(), e)
        lock_status = "ERROR"
    finally:
        if lock_status != "":
            DisplayUtils.task_status(lock_status, lock_message, status_counter)
            return


def launch_patch_application(args):
    """Call list_hosts and preinstall_check, then launch the apply_patch task on all hosts.

    :param args: parsed command line arguments, forwarded to every step
    """
    list_hosts(args)
    preinstall_check(args)
    
    patch_error_recap_message = "Some problems were detected during patch application. The Shinken installation may be unstable/broken until patch application errors are identified and resolved. \nPatch application logs can be found in the '%s' folder for more information." % INSTALL_LOGS_PATH
    task_options = {
        "task_name"          : "Applying Shinken patch on architecture hosts",
        "error_recap_title"  : "SHINKEN PATCH APPLICATION ERROR",
        "error_recap_message": patch_error_recap_message,
        "status_counter"     : StatusCounter(),
        "function_name"      : apply_patch,
        "args_list"          : (args,),
    }
    launch_task_on_all_hosts(**task_options)


###################################
######      TASK: PATCH      ######
###################################

def patch(args):
    if not os.path.exists(args.config):
        print ""
        print_h1("CONFIGURATION ERROR", line_color="red", title_color="red")
        cprint("No configuration file has been found. \nPlease generate one using the build-host-list subcommand or specify the configuration file path using the --config flag.", color="red")
        sys.exit(1)
    
    setattr(args, "patch", True)
    upload(args)
    launch_patch_application(args)


# Timeout options exposed on the command line.
# Each entry is (command line flag, argparse attribute, configuration key, help text).
# The same table drives both the parser declaration and the configuration setup so that
# the two can not drift apart.
_CLI_TIMEOUT_OPTIONS = (
    ("--ssh-default-timeout", "ssh_default_timeout", "ssh_default_timeout", "Default timeout for other commands done on remote hosts over SSH"),
    ("--ping-timeout", "ping_timeout", "ping_timeout", "Remote host ping timeout"),
    ("--patch-upload-timeout", "patch_upload_timeout", "patch_upload_timeout", "Shinken patch archive upload timeout"),
    ("--patch-extract-timeout", "patch_extract_timeout", "patch_extract_timeout", "Shinken patch archive extraction timeout"),
    ("--upload-timeout", "upload_timeout", "archive_upload_timeout", "Shinken update archive upload timeout"),
    ("--extract-timeout", "extract_timeout", "archive_extract_timeout", "Shinken update archive extraction timeout"),
    ("--preinstall-check-timeout", "preinstall_check_timeout", "preinstall_check_timeout", "Shinken preinstall check timeout"),
    ("--update-timeout", "update_timeout", "update_timeout", "Shinken update timeout"),
    ("--patch-timeout", "patch_timeout", "patch_timeout", "Shinken patch application timeout"),
)

# Subcommands taking the path of an archive as positional argument
_ARCHIVE_SUBCOMMANDS = ("upload", "update", "patch")


def _build_argument_parser():
    """Declare the global options and every subcommand, and return the parser.
    
    Each subparser stores the function implementing it in the 'handler' default.
    """
    description = """
Performs Shinken update of the specified archive on all machines of the Shinken architecture.

Updating a Shinken installation with this module is done in multiple steps:
- Create the hosts list to upgrade by scanning Shinken configuration files (only needs to be executed the first time):
    %(prog_name)s build-host-list
- List hosts to upgrade:
    %(prog_name)s list-hosts
- Check SSH connexion to all hosts:
    %(prog_name)s check
- Upload Shinken archive:
    %(prog_name)s upload /path/to/shinken_archive.tar.gz
- Launch pre update checks:
    %(prog_name)s preinstall-check
- Launch update:
    %(prog_name)s launch-update-script

All the previous steps can be performed at once with the 'update' subcommand:
- Update hosts: %(prog_name)s update /path/to/shinken_archive.tar.gz


Patch can be applied with the following process
- Create the hosts list to upgrade by scanning Shinken configuration files (only needs to be executed the first time):
    %(prog_name)s build-host-list
- List hosts to upgrade:
    %(prog_name)s list-hosts
- Check SSH connexion to all hosts:
    %(prog_name)s check
- Upload Shinken patch archive:
    %(prog_name)s upload --is-patch /path/to/shinken_patch.tar.gz
- Launch pre patch checks:
    %(prog_name)s prepatch-check
- Launch patch application:
    %(prog_name)s launch-patch-application

All the previous steps can be performed at once with the 'patch' subcommand:
- Patch hosts: %(prog_name)s patch /path/to/shinken_patch.tar.gz


    """ % {
        "prog_name": "./shinken-architecture-update"
    }
    # RawDescriptionHelpFormatter keeps the line breaks of the description above
    parser = argparse.ArgumentParser(description=description,
                                     formatter_class=argparse.RawDescriptionHelpFormatter)
    subparsers = parser.add_subparsers(help="subcommands help", dest="subcommand")
    
    parser.add_argument("--tmp-dir", type=str, help="Remote upload directory")
    parser.add_argument("-c", "--config", type=str, help="Configuration file to use", default=DEFAULT_CFG_PATH)
    
    # Timeouts configuration
    for flag, _attribute, _config_key, help_text in _CLI_TIMEOUT_OPTIONS:
        parser.add_argument(flag, type=int, help=help_text)
    
    parser.add_argument("-d", "--debug", action="store_true", default=False, help="Print debug information for Shinken Enterprise support")
    
    # Subcommands
    #############
    build_host_list_parser = subparsers.add_parser("build-host-list", help="build-host-list subcommand help")
    build_host_list_parser.set_defaults(handler=build_host_list)
    
    list_hosts_parser = subparsers.add_parser("list-hosts", help="list-hosts subcommand help")
    list_hosts_parser.set_defaults(handler=list_hosts)
    
    check_parser = subparsers.add_parser("check", help="check subcommand help")
    check_parser.set_defaults(handler=check)
    
    upload_parser = subparsers.add_parser("upload", help="upload subcommand help")
    upload_parser.add_argument("archive", type=str, help="Shinken archive (tar.gz) to install")
    upload_parser.add_argument("--force-reupload", action="store_true", default=None, help="Force Shinken archive or patch reupload")
    upload_parser.add_argument("--is-patch", action="store_true", dest="patch", default=None, help="Consider the provided archive as a Shinken patch and not as a Shinken archive")
    upload_parser.set_defaults(handler=upload)
    
    preinstall_check_parser = subparsers.add_parser("preinstall-check", help="preinstall-check subcommand help")
    preinstall_check_parser.set_defaults(handler=preinstall_check)
    
    # 'prepatch-check' is served by the same handler as 'preinstall-check'
    prepatch_check_parser = subparsers.add_parser("prepatch-check", help="prepatch-check subcommand help")
    prepatch_check_parser.set_defaults(handler=preinstall_check)
    
    launch_update_parser = subparsers.add_parser("launch-update-script", help="launch-update-script subcommand help")
    launch_update_parser.set_defaults(handler=launch_update_script)
    
    update_parser = subparsers.add_parser("update", help="update subcommand help")
    update_parser.add_argument("archive", type=str, help="Shinken archive (tar.gz) to install")
    update_parser.add_argument("--force-reupload", action="store_true", default=None, help="Force Shinken archive reupload")
    update_parser.set_defaults(handler=update)
    
    launch_patch_parser = subparsers.add_parser("launch-patch-application", help="launch-patch-application subcommand help")
    # TODO: Handle --revert parameter
    launch_patch_parser.add_argument("--revert", action="store_true", default=None, help="Revert patch installation")
    launch_patch_parser.set_defaults(handler=launch_patch_application)
    
    patch_parser = subparsers.add_parser("patch", help="patch subcommand help")
    patch_parser.add_argument("archive", type=str, help="Patch archive (tar.gz) to install")
    patch_parser.add_argument("--force-reupload", action="store_true", default=None, help="Force Shinken patch reupload")
    # TODO: Handle --revert parameter
    patch_parser.add_argument("--revert", action="store_true", default=None, help="Revert patch installation")
    patch_parser.set_defaults(handler=patch)
    
    return parser


def _apply_cli_configuration(args):
    """Copy the global command line options into the 'cli' namespace of the configuration.
    
    Options that were not given on the command line are None in args and are
    stored as such.
    """
    configuration.set("tmp_dir", args.tmp_dir, namespace="cli")
    
    for _flag, attribute, config_key, _help_text in _CLI_TIMEOUT_OPTIONS:
        configuration.set(config_key, getattr(args, attribute), namespace="cli")


def _check_archive_argument(args):
    """Exit with status 1 when the selected subcommand needs an archive and the given one is unusable.
    
    Nothing is checked for subcommands that do not take an archive.
    """
    if args.subcommand not in _ARCHIVE_SUBCOMMANDS:
        return
    
    # Parenthesised single-argument print: same output as the print statement under Python 2
    if not args.archive:
        print("%s Path to Shinken archive must be present when using the following commands: %s" % (DisplayUtils.status("error"), ', '.join(_ARCHIVE_SUBCOMMANDS)))
        sys.exit(1)
    if not os.path.isfile(args.archive):
        print("%s The file '%s' has not been found" % (DisplayUtils.status("error"), args.archive))
        sys.exit(1)
    if not tarfile.is_tarfile(args.archive):
        print("%s The Shinken archive provided ('%s') does not seem to be a .tar.gz archive" % (DisplayUtils.status("error"), args.archive))
        sys.exit(1)


def main():
    """Command line entry point.
    
    Parses the command line, stores the global options in the configuration,
    sets the module level DEBUG flag, validates the archive path for the
    subcommands that take one, calls init() and finally runs the handler
    attached to the selected subcommand.
    """
    global DEBUG
    
    args = _build_argument_parser().parse_args()
    
    # Configuration setup
    #####################
    _apply_cli_configuration(args)
    DEBUG = args.debug
    
    _check_archive_argument(args)
    
    init()
    args.handler(args)


# Only start the command line interface when this file is executed as a script
if __name__ == "__main__":
    main()
