#!/usr/bin/python

# -*- coding: utf-8 -*-

# Copyright (C) 2014:
#    Gabes Jean, j.gabes@shinken-solutions.com
#
# This file is part of Shinken Enterprise


import time
import socket
from Queue import Queue, Empty
from threading import Thread

from IPy import IP
from shinken.basemodule import BaseModule
from shinken.log import logger
from shinkensolutions.api.synchronizer import SourceInfo, SourceInfoProperty, ITEM_TYPE, METADATA

properties = {
    'daemons': ['synchronizer'],
    'type'   : 'sync-ip-tag',
}


# called by the plugin manager to get a module
def get_instance(plugin):
    logger.info("[Sync IPTag] Get a SyncIPTag module for plugin %s" % plugin.get_name())
    
    # Catch errors
    ip_range = plugin.ip_range
    prop = plugin.property
    value = plugin.value
    method = getattr(plugin, 'method', 'replace')
    ignore_hosts = getattr(plugin, 'ignore_hosts', None)
    
    instance = Sync_Ip_Tag_Arbiter(plugin, ip_range, prop, value, method, ignore_hosts)
    return instance


# Thread executing tasks from a given tasks queue
class Worker(Thread):
    def __init__(self, tasks):
        Thread.__init__(self)
        self.must_run = True
        self.tasks = tasks
        self.daemon = True
        self.start()
    
    
    def stop(self):
        self.must_run = False
    
    
    def run(self):
        while self.must_run:
            if not self.tasks.empty():
                try:
                    func, args, kargs = self.tasks.get(False)
                    func(*args, **kargs)
                except Empty:
                    pass
                except Exception as exp:
                    str_exp = str(exp)
                    if str_exp:
                        logger.error('[Sync IPTag] exp [%s]' % str_exp)
                finally:
                    try:
                        self.tasks.task_done()
                    except ValueError:  # task done called too many time, not a big deal
                        pass
            time.sleep(0.1)


# Pool of threads consuming tasks from a queue"""
class ThreadPool:
    def __init__(self, num_threads):
        self.tasks = Queue(num_threads)
        self.threads = []
        for _ in range(num_threads):
            w = Worker(self.tasks)
            self.threads.append(w)
    
    
    def add_task(self, func, *args, **kargs):
        """Add a task to the queue"""
        self.tasks.put((func, args, kargs))
    
    
    def wait_completion(self):
        """Wait for completion of all the tasks in the queue"""
        self.tasks.join()
        for w in self.threads:
            w.stop()
            w.join()


# Just print some stuff
class Sync_Ip_Tag_Arbiter(BaseModule):
    def __init__(self, mod_conf, ip_range, prop, value, method, ignore_hosts=None):
        BaseModule.__init__(self, mod_conf)
        self.ip_range = IP(ip_range, make_net=True)
        self.property = prop
        self.value = value
        self.method = method
        if ignore_hosts:
            self.ignore_hosts = ignore_hosts.split(', ')
            logger.debug("[Sync IPTag] Ignoring hosts : %s" % self.ignore_hosts)
        else:
            self.ignore_hosts = []
        self.pool_size = int(getattr(mod_conf, 'pool_size', '1'))
    
    
    # Called by Arbiter to say 'let's prepare yourself guy'
    def init(self):
        logger.info("[Sync IPTag] Initialization of the ip range tagger module")
    
    
    def hook_post_merge(self, all_hosts):
        logger.info("[Sync IpTag] in hook post merge")
        
        # Get a pool for jobs
        pool = ThreadPool(32)
        
        for host in all_hosts:
            if 'address' not in host and 'host_name' not in host:
                continue
            
            if 'host_name' not in host:
                continue
            
            source_info = METADATA.get_metadata(host, METADATA.SOURCE_INFO)
            
            hname = host.get('host_name', None)
            if hname in self.ignore_hosts:
                logger.debug("[Sync IPTag] Ignoring host %s" % hname)
                continue
            
            # The address to resolve
            addr = None
            # By default take the address, if not, take host_name
            if 'address' in host:
                addr = host['address']
            else:
                addr = host['host_name']
            
            # logger.debug("[Sync IPTag] Looking for name:%s   and   addr:%s  %s" % (hname, addr, host))
            # logger.debug("[Sync IPTag] Address is %s" % str(addr))
            h_ip = None
            try:
                IP(addr)
                # If we reach here, it's it was a real IP :)
                h_ip = addr
            except:
                pass
            
            pool.add_task(self.job, host, h_ip, addr, source_info)
        
        logger.debug("[Sync IPTag] all task in poll")
        # Now wait for all jobs to finish if need
        if pool:
            t1 = time.time()
            pool.wait_completion()
            logger.debug("[Sync IPTag] poll.wait_completion %.3f" % (time.time() - t1))
    
    
    # Main job, will manage eachhost in asyncronous mode thanks to gevent
    def job(self, h, h_ip, addr, source_info):
        # Ok, try again with name resolution
        if not h_ip:
            try:
                h_ip = socket.gethostbyname(addr)
            except Exception, exp:
                # logger.debug('[Sync IPTag] skiping host: cannot get addr from DNS %s : %s' % (addr, exp))
                pass
        
        if isinstance(source_info, dict):
            source_info = SourceInfo.from_dict(source_info, ITEM_TYPE.HOSTS)
        
        # Ok, maybe we succeed :)
        # logger.debug("[Sync IPTag] Host ip is: %s" % str(h_ip))
        # If we got an ip that match and the object do not already got
        # the property, tag it!
        if h_ip and h_ip in self.ip_range:
            logger.debug("[Sync IPTag] Is in the range %s", h_ip)
            # 4 cases: append, prepend, replace and set
            # append will join with the value (on the END)
            # prepend will join with the value (on the BEGINING)
            # replace will replace it if NOT existing
            # set put the value even if the property exists
            if self.method == 'append':
                orig_v = [e.strip() for e in h.get(self.property, '').split(',') if e.strip()]
                suffixlst = [e.strip() for e in self.value.split(',') if e.strip()]
                for e in suffixlst:
                    if e not in orig_v:
                        orig_v.append(e)
                new_v = ','.join(orig_v)
                h[self.property] = new_v
                # Let the source info know that this value came from us
                # Here : we did not overrite
                source_info.update_field(self.property, self.value, self.get_name(), field_type=SourceInfoProperty.ORDERED_TYPE)
            
            # Same but we put before
            if self.method == 'prepend':
                orig_v = [e.strip() for e in h.get(self.property, '').split(',') if e.strip()]
                suffixlst = [e.strip() for e in self.value.split(',') if e.strip()]
                for e in suffixlst:
                    if e not in orig_v:
                        orig_v.insert(0, e)
                new_v = ','.join(orig_v)
                h[self.property] = new_v
                # Let the source info know that this value came from us
                # Here : we did not overrite
                source_info.update_field(self.property, self.value, self.get_name(), field_type=SourceInfoProperty.ORDERED_TYPE)
            
            if self.method == 'replace':
                if not self.property in h:
                    # Ok, set the value!
                    h[self.property] = self.value
                    # Let the source info know that this value came from us
                    # Here : we did overwrite it
                    source_info.update_field(self.property, self.value, self.get_name(), field_type=SourceInfoProperty.ORDERED_TYPE, overwrite=True)
            
            if self.method == 'set':
                h[self.property] = self.value
                # Let the source info know that this value came from us
                # Here : we did overwrite it
                source_info.update_field(self.property, self.value, self.get_name(), field_type=SourceInfoProperty.ORDERED_TYPE, overwrite=True)
            
            _source_info = source_info.as_dict()
            METADATA.update_metadata(h, METADATA.SOURCE_INFO, _source_info)
            
            logger.debug('[Sync IPTag] Generated property: host:%s  ip:%s %s => %s' % (h.get('host_name', ''), h_ip, self.property, h[self.property]))
