#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2013-2018:
# This file is part of Shinken Enterprise, all rights reserved.

import base64
import hashlib
import os
import sys

from shinken.log import logger
from shinken.util import make_unicode

from shinkensolutions.crypto import AESCipher
from shinkensolutions.service_override_parser import parse_service_override_property
from .def_items import ITEM_STATE, DEF_ITEMS, ITEM_TYPE, METADATA, SERVICE_OVERRIDE, LINKIFY_MANAGE_STATES, prop_is_linked
from .helpers import split_and_strip_list, get_property_separator, ShinkenDatabaseConsistencyError
from ..business.source.sourceinfoproperty import SourceInfoProperty
from ..business.sync_ui_common import syncuicommon

CRYPTO_NOT_TO_LOOK = set(('_SYNC_KEYS', '_SE_UUID', '_SE_UUID_HASH', '_id', '__SYNC_IDX__'))
DEFAULT_TAG_VALUE = u"protected_login_and_password"
TAG_FROM_CHANGE = u'protected_field_value_from_changes'
PROTECTED_TAG_FROM_CHANGE = base64.b64encode(TAG_FROM_CHANGE)
PROTECTED_DEFAULT_TAG = base64.b64encode(DEFAULT_TAG_VALUE)
FORCE_PROCESS = u'FORCE_PROCESS-25863cbb6e40885f885779ed68c1a488-'


class BaseCrypter(object):
    """Interface of the value crypters used by Cipher.
    
    Subclasses implement ``encrypt`` and ``decrypt``; both receive the value to
    transform and the previous value of the same field (callers may pass None).
    """
    tag_value = DEFAULT_TAG_VALUE
    
    
    def encrypt(self, value, old_value):
        """Return the protected form of *value*. Must be overridden."""
        raise NotImplementedError()
    
    
    def decrypt(self, value, old_value):
        """Return the clear form of *value*. Must be overridden."""
        raise NotImplementedError()


class MockCrypter(BaseCrypter):
    """Pass-through crypter: both directions hand the value back untouched."""
    tag_value = DEFAULT_TAG_VALUE
    
    
    def encrypt(self, value, old_value):
        """Identity transform; *old_value* is ignored."""
        return value
    
    
    def decrypt(self, value, old_value):
        """Identity transform; *old_value* is ignored."""
        return value


class Base64Crypter(BaseCrypter):
    """Reversible base64 encoding of values (an encoding, not real encryption).
    
    The literal "null" marks an empty field and is passed through unchanged in
    both directions. *old_value* is ignored.
    """
    
    
    def encrypt(self, value, old_value):
        """Return the base64 encoding of the UTF-8 bytes of *value*."""
        if isinstance(value, str) and value != "":
            # Normalize byte strings to text, dropping invalid UTF-8 sequences
            value = value.decode('utf-8', 'ignore')
        
        if value == "null":
            return value
        
        # b64encode works on bytes. Given a unicode object, Python 2 encodes it
        # implicitly with the ascii codec and raises UnicodeEncodeError on any
        # non-ASCII character, so encode explicitly as UTF-8 (TagValueCrypter
        # decodes decrypted bytes as UTF-8).
        if isinstance(value, unicode):
            value = value.encode('utf-8')
        return base64.b64encode(value)
    
    
    def decrypt(self, value, old_value):
        """Return the base64-decoded bytes of *value*, or 'null' for "null"."""
        if value == "null":
            # If this is not a base64-coded value, it means the field value is empty, and we don't process it
            return 'null'
        return base64.b64decode(value)


class TagValueCrypter(BaseCrypter):
    """Crypter that hides every non-"null" value behind a fixed placeholder.
    
    The placeholder is the base64 form of the class ``tag_value``. Decrypting
    the placeholder restores the previous value when one is supplied; any other
    input is base64-decoded and returned as UTF-8 text.
    """
    
    def __init__(self):
        super(TagValueCrypter, self).__init__()
        # Replace the class-level tag with its base64 form: this is what encrypt() hands out
        self.tag_value = Base64Crypter().encrypt(self.tag_value, None)
    
    
    def encrypt(self, value, old_value):
        """Return the placeholder tag, keeping the "null" empty-field marker as-is."""
        return "null" if value == "null" else self.tag_value
    
    
    def decrypt(self, value, old_value):
        """Return *old_value* for an untouched placeholder, else the decoded *value*."""
        if value == self.tag_value and old_value:
            return old_value
        raw = Base64Crypter().decrypt(value, old_value)
        return raw.decode('utf8', 'ignore')


def _match_protected_property(property_name, protected_property_names, item_type):
    """Tell whether *property_name* of an item of *item_type* holds a protected value.
    
    A property is protected when it is listed in the item type's
    'protected_fields' definition (['password'] when *item_type* is unknown to
    DEF_ITEMS), or when it is a '_'-prefixed data property whose upper-cased
    name contains one of the entries of *protected_property_names* (split with
    split_and_strip_list) and is not listed in CRYPTO_NOT_TO_LOOK.
    An empty *protected_property_names* and '@'-prefixed properties never match.
    """
    if protected_property_names == "" or property_name.startswith('@'):
        return False
    
    protected_fields = DEF_ITEMS.get(item_type, {'protected_fields': ['password']}).get('protected_fields', ())
    if property_name in protected_fields:
        return True
    if not property_name.startswith("_"):
        return False
    
    upper_name = property_name.upper()
    # Check the name as written as well as upper-cased: CRYPTO_NOT_TO_LOOK holds the
    # lower-case '_id', which an upper-cased-only lookup ('_ID') could never match.
    if property_name in CRYPTO_NOT_TO_LOOK or upper_name in CRYPTO_NOT_TO_LOOK:
        return False
    substrings = split_and_strip_list(protected_property_names) or ()
    return any(substring in upper_name for substring in substrings)


class Cipher(object):
    """Apply ``self.crypter`` to the protected properties of configuration items.
    
    Items are modified in place. Whether a property is protected is decided by
    match_protected_property; the default crypter (MockCrypter) returns values
    unchanged and is replaced by subclasses.
    """
    
    def __init__(self, enable, protected_property_names):
        self.enable = enable
        self.encryption_enable = enable
        # Raw setting handed to split_and_strip_list when matching '_'-prefixed properties
        self.protected_property_names = protected_property_names
        # default crypter that does nothing
        self.crypter = MockCrypter()
    
    
    def match_protected_property(self, item_property, item_type, user=None):
        """Return a truthy value when the cipher is enabled and *item_property* is protected.
        
        *user* is not used here; subclasses may rely on it.
        """
        return self.enable and _match_protected_property(item_property, self.protected_property_names, item_type)
    
    
    @staticmethod
    def _drop_metadata_flag(item, flag):
        # Remove a 'crypted'/'uncrypted' marker, and '@metadata' itself once it is empty
        del item['@metadata'][flag]
        if item['@metadata'] == {}:
            del item['@metadata']
    
    
    def cipher(self, item, item_type, item_state=ITEM_STATE.STAGGING, user=None):
        """Encrypt the protected values of *item* in place and return it.
        
        Items already flagged as crypted are returned untouched; an 'uncrypted'
        flag is removed first. The crypted flag is set when something was processed.
        """
        if not self.enable:
            return item
        if item_type == ITEM_TYPE.ELEMENTS:
            item_type = METADATA.get_metadata(item, METADATA.ITEM_TYPE)
        if METADATA.get_metadata(item, METADATA.CRYPTED, False):
            return item
        if METADATA.get_metadata(item, METADATA.UNCRYPTED, False):
            self._drop_metadata_flag(item, 'uncrypted')
        
        edited = self._find_properties_to_process(item, item_type, item_state, self._cipher_value, user=user)
        
        if edited:
            METADATA.update_metadata(item, METADATA.CRYPTED, True)
        return item
    
    
    def uncipher(self, item, item_type, item_state=ITEM_STATE.STAGGING, old_item=None, user=None):
        """Decrypt the protected values of *item* in place and return it.
        
        Items already flagged as uncrypted are returned untouched; a 'crypted'
        flag is removed first. *old_item*, when given, supplies the previous
        values handed to the crypter. The uncrypted flag is set when something
        was processed.
        """
        if not self.enable:
            return item
        if item_type == ITEM_TYPE.ELEMENTS:
            item_type = METADATA.get_metadata(item, METADATA.ITEM_TYPE)
        if METADATA.get_metadata(item, METADATA.UNCRYPTED, False):
            return item
        if METADATA.get_metadata(item, METADATA.CRYPTED, False):
            self._drop_metadata_flag(item, 'crypted')
        
        edited = self._find_properties_to_process(item, item_type, item_state, self._uncipher_value, old_item, user=user)
        
        if edited:
            METADATA.update_metadata(item, METADATA.UNCRYPTED, True)
        return item
    
    
    def _process(self, item_property, value, process_function, item_type, old_value, user=None, force_process=False):
        """Return ``make_unicode(process_function(value, old_value))`` for a non-empty
        protected (or forced) value; otherwise return *value* unchanged."""
        if value and (force_process or self.match_protected_property(item_property, item_type, user=user)):
            return make_unicode(process_function(value, old_value))
        return value
    
    
    def _find_properties_to_process(self, item, item_type, item_state, process_function, old_item=None, user=None):
        """Run *process_function* over the protected values of *item*, in place.
        
        Returns True when something was processed; cipher/uncipher use it to
        set the crypted/uncrypted metadata flag.
        """
        if item_state == ITEM_STATE.CHANGES:
            return self._process_changes(item, item_type, item_state, process_function, old_item, user)
        return self._process_item_properties(item, item_type, item_state, process_function, old_item, user)
    
    
    def _process_changes(self, item, item_type, item_state, process_function, old_item, user):
        # Process the entries of item['changes'] (mutable [value_0, value_1, source_info] sequences) in place.
        edited = False
        changes = item.get('changes') or {}
        old_changes = (old_item.get('changes') or {}) if old_item else {}
        for item_property, values in changes.iteritems():
            info_ppty = values[2]
            need_reset_info_ppty_from_dict = isinstance(info_ppty, SourceInfoProperty)
            if need_reset_info_ppty_from_dict:
                info_ppty = info_ppty.as_dict()
            
            # ORDERED and SET properties hold no protectable data (only _DATA and password properties do),
            # so they are skipped -- except service overrides.
            if info_ppty['property_type'] in (SourceInfoProperty.ORDERED_TYPE, SourceInfoProperty.SET_TYPE) and info_ppty['property_key'] != SERVICE_OVERRIDE:
                continue
            
            prop_value = info_ppty['property_value']
            if info_ppty['property_key'] == SERVICE_OVERRIDE:
                crypted_value_0, effective_0 = self._read_service_override(values[0], process_function, item_state, None, item_type, old_item, user=user)
                crypted_value_1, effective_1 = self._read_service_override(values[1], process_function, item_state, None, item_type, old_item, user=user)
                # Accumulate: a plain assignment would forget edits made to earlier properties
                edited = edited or effective_0 or effective_1
                
                if effective_0:
                    values[0] = crypted_value_0
                if effective_1:
                    separator = get_property_separator(item_type, SERVICE_OVERRIDE)
                    values[1] = crypted_value_1
                    # NOTE(review): values[1] was just overwritten with crypted_value_1, so both sides of this
                    # comparison are identical and prop_value is never updated. Comparing against the value
                    # before processing looks like the intent, but the entry order produced by
                    # _read_service_override is not guaranteed to match prop_value's order; kept as-is.
                    for index, (crypted_value, value) in enumerate(zip(split_and_strip_list(crypted_value_1, separator), split_and_strip_list(values[1], separator))):
                        if crypted_value != value:
                            prop_value[index] = (prop_value[index][0], crypted_value)
            else:
                old_values = old_changes.get(item_property, ['', ''])
                crypted_value_0 = self._process(item_property, values[0], process_function, item_type, old_values[0], user=user)
                crypted_value_1 = self._process(item_property, values[1], process_function, item_type, old_values[1], user=user)
                
                if crypted_value_0 != values[0]:
                    values[0] = crypted_value_0
                    edited = True
                if crypted_value_1 != values[1]:
                    values[1] = crypted_value_1
                    edited = True
                
                if prop_value:
                    if isinstance(prop_value[0], basestring):
                        prop_value[1] = values[1]
                    else:
                        prop_value[0] = (prop_value[0][0], values[1])
            if need_reset_info_ppty_from_dict:
                separator = get_property_separator(item_type, item_property)
                values[2] = SourceInfoProperty.from_dict(info_ppty, separator)
        return edited
    
    
    def _process_item_properties(self, item, item_type, item_state, process_function, old_item, user):
        # Process every top-level property of a regular (non-CHANGES) item, in place.
        edited = False
        # Properties with a non-empty entry in the FROM metadata are skipped. Looked up once:
        # the loop below never writes to that metadata.
        from_metadata = METADATA.get_metadata(item, METADATA.FROM, {})
        for item_property, value in item.iteritems():
            if from_metadata.get(item_property, '') != u'':
                continue
            if item_property == 'work_area_info' and 'diff_item' in value:
                for diff_dict in value['diff_item']:
                    for key in ('new', 'stagging'):
                        before = diff_dict[key]
                        diff_dict[key] = self._process(diff_dict['prop'], before, process_function, item_type, None, user=user)
                        # Report the change so the crypted/uncrypted flag is set, like every other branch
                        edited = edited or diff_dict[key] != before
            elif item_property == 'last_modification' and 'change' in value:
                for changed_dict in value['change']:
                    if changed_dict['prop'] == SERVICE_OVERRIDE:
                        changed_dict['new'], edited_new = self._read_service_override(changed_dict['new'], process_function, item_state, None, item_type, old_item, user=user)
                        changed_dict['old'], edited_old = self._read_service_override(changed_dict['old'], process_function, item_state, None, item_type, old_item, user=user)
                        # Accumulate: a plain assignment would forget edits made earlier in the loop
                        edited = edited or edited_new or edited_old
                    else:
                        changed_dict['new'] = self._process(changed_dict['prop'], changed_dict['new'], process_function, item_type, None, user=user)
                        changed_dict['old'] = self._process(changed_dict['prop'], changed_dict['old'], process_function, item_type, None, user=user)
                        # Counted as edited even when neither value actually changed
                        edited = True
            elif item_property == SERVICE_OVERRIDE:
                item[item_property], effective = self._read_service_override(value, process_function, item_state, item, item_type, old_item, user=user)
                edited = effective or edited
            else:
                if item_state == ITEM_STATE.MERGE_SOURCES:
                    source_info_property = METADATA.get_metadata(item, METADATA.SOURCE_INFO, {}).get('_info', {}).get(item_property)
                    # NOTE(review): the second test is redundant unless SourceInfoProperty.ORDERED_TYPE equals
                    # 'SourceInfoProperty'; the intended check is unclear, kept as-is.
                    if source_info_property and source_info_property['property_type'] == 'SourceInfoProperty' and source_info_property['property_type'] != SourceInfoProperty.ORDERED_TYPE:
                        sourced_values = source_info_property.get('property_value', [])
                        for index, sourced_value in enumerate(sourced_values):
                            crypted_value = self._process(item_property, sourced_value[1], process_function, item_type, None, user=user)
                            if crypted_value != sourced_value[1]:
                                sourced_values[index] = (sourced_value[0], crypted_value)
                                edited = True
                
                old_value = old_item.get(item_property, None) if old_item else None
                process_value = self._process(item_property, value, process_function, item_type, old_value, user=user)
                if process_value != value:
                    item[item_property] = process_value
                    edited = True
        return edited
    
    
    def _read_service_override(self, service_override_value, process_function, item_state, item, item_type, old_item, user=None):
        """Process the protected values inside a service_overrides property.
        
        *service_override_value* is either its string form (parsed with
        parse_service_override_property and re-joined) or a dict holding 'links'
        and 'raw_value', which is modified in place. *item* may be None.
        
        Returns ``(new_value, effective)``; *effective* is True when at least one
        override value changed (the forced 'raw_value' processing is not counted).
        Unparsable strings are returned unchanged.
        """
        effective = False
        unparsed = []
        
        # We force ITEM_TYPE.HOSTS because the caller sends us ITEM_TYPE.CONTACTS for every item type
        # If we are reading a service_overrides property, it means we are processing a host (or cluster)
        separator = get_property_separator(ITEM_TYPE.HOSTS, SERVICE_OVERRIDE)
        
        old_override = old_item.get(SERVICE_OVERRIDE, None) if old_item else None
        
        if service_override_value is None:
            return service_override_value, effective
        
        if isinstance(service_override_value, basestring):
            try:
                parsed = parse_service_override_property(service_override_value)
                if old_override:
                    if isinstance(old_override, basestring):
                        old_override = parse_service_override_property(old_override)
                    else:
                        old_override = parse_service_override_property(old_item._flatten_prop_service_overrides(old_override, SERVICE_OVERRIDE, LINKIFY_MANAGE_STATES))
            except (SyntaxError, KeyError):
                return service_override_value, effective
            
            for check, so_value in parsed.iteritems():
                for prop_name, value in so_value.iteritems():
                    # Both spellings (with and without a space after the comma) are matched against the source info
                    unprocessed_values = (u'''%s, %s %s''' % (check, prop_name, value), u'''%s,%s %s''' % (check, prop_name, value))
                    
                    old_value = old_override.get(check, {}).get(prop_name, None) if old_override else None
                    processed = self._process(prop_name, value, process_function, item_type, old_value, user=user)
                    value_changed = value != processed
                    effective = effective or value_changed
                    process_value = u'%s, %s %s' % (check, prop_name, processed)
                    unparsed.append(process_value)
                    
                    # Mirror only this entry's change into the merged source info. Keying this on the cumulative
                    # `effective` flag rewrote every later, unchanged entry too; without an item there is nothing to update.
                    if item_state == ITEM_STATE.MERGE_SOURCES and value_changed and item is not None:
                        source_info_property = METADATA.get_metadata(item, METADATA.SOURCE_INFO, {}).get('_info', {}).get(SERVICE_OVERRIDE)
                        if source_info_property:
                            sourced_values = source_info_property.get('property_value', [])
                            for index, sourced_value in enumerate(sourced_values):
                                if sourced_value[1] in unprocessed_values:
                                    sourced_values[index] = (sourced_value[0], process_value)
            
            service_override_value = separator.join(unparsed)
        else:
            for link in service_override_value['links']:
                if prop_is_linked(item_type, link['key']):
                    continue
                old_link_value = self._get_old_link_value(link, old_override, old_item)
                process_value = self._process(link['key'], link['value'], process_function, item_type, old_link_value, user=user)
                effective = link['value'] != process_value or effective
                link['value'] = process_value
                
                if 'value_with_link' in link:
                    link['value_with_link'] = self._process(link['key'], link['value_with_link'], process_function, item_type, old_link_value, user=user)
            
            old_raw_value = old_override['raw_value'] if old_override else None
            # raw_value in service override is always processed
            service_override_value['raw_value'] = self._process(FORCE_PROCESS, service_override_value['raw_value'], process_function, item_type, old_raw_value, user=user, force_process=True)
        return service_override_value, effective
    
    
    def _get_old_link_value(self, link, old_override, item):
        """Return the 'value' of the link in *old_override* matching *link* (see _compare_link), or None."""
        if not old_override:
            return None
        
        old_link = next((old_link for old_link in old_override.get('links', []) if self._compare_link(old_link, link, item)), {})
        return old_link.get('value', None)
    
    
    @staticmethod
    def _compare_link(link1, link2, item):
        """Tell whether two service-override links have the same key and target the same check.
        
        Checks match by '_id' (with the same 'dfe_key') or by 'name'. If neither
        matches and the two links do not both carry an '_id', the check names are
        resolved with resolve_name and compared.
        """
        link1_check_link = link1.get('check_link', {})
        link2_check_link = link2.get('check_link', {})
        same_key = link1.get('key', 'link1') == link2.get('key', 'link2')
        
        same_dfe_key = link1.get('dfe_key', '') == link2.get('dfe_key', '')
        same_id = (link1_check_link.get('_id', 'id_link1') == link2_check_link.get('_id', 'id_link2') and same_dfe_key)
        same_name = link1_check_link.get('name', 'name_link1') == link2_check_link.get('name', 'name_link2')
        same_check_link = same_id or same_name
        if not same_check_link and not (link1_check_link.get('_id', '') and link2_check_link.get('_id', '')):
            n1 = Cipher.resolve_name(link1_check_link, link1.get('dfe_key', ''), item)
            n2 = Cipher.resolve_name(link2_check_link, link2.get('dfe_key', ''), item)
            same_check_link = n1 == n2
        return same_key and same_check_link
    
    
    def _cipher_value(self, value, old_value=None):
        return self.crypter.encrypt(value, old_value)
    
    
    def _uncipher_value(self, value, old_value=None):
        return self.crypter.decrypt(value, old_value)
    
    
    @staticmethod
    def resolve_name(link_check_link, dfe_key, item):
        """Return the check name a link points to.
        
        Uses the link's 'name' when set; otherwise looks the check up by '_id' and
        'item_type' in the staging data and returns its name, with '$KEY$'
        replaced by *dfe_key* when one is given.
        
        Raises ShinkenDatabaseConsistencyError when the check cannot be found.
        """
        if link_check_link.get('name', ''):
            return link_check_link.get('name', '')
        
        item_id = link_check_link['_id']
        item_type = link_check_link['item_type']
        datamanager_v2 = syncuicommon.app.datamanagerV2
        if not hasattr(datamanager_v2, 'data_provider'):
            # In Synchronizer daemon process datamanagerV2 is a string
            raise Exception(datamanager_v2)
        # We use internal find because we are in a write lock
        check = datamanager_v2.data_provider._find_item_by_id(item_id, item_type, ITEM_STATE.STAGGING)
        if not check:
            raise ShinkenDatabaseConsistencyError(link_check_link, item['_id'], '', '', LINKIFY_MANAGE_STATES)
        _name = check.get_name().replace('$KEY$', dfe_key) if dfe_key else check.get_name()
        return _name


class DatabaseCipher(Cipher):
    """Cipher whose crypter is an AESCipher keyed from *keyfile*.
    
    The key file is only read when *enable* is true: the text after its first
    '|' is base64-decoded to obtain the key. The process exits with status 1
    when the key file cannot be opened.
    """
    def __init__(self, enable, protected_property_names, keyfile):
        super(DatabaseCipher, self).__init__(enable, protected_property_names)
        if not enable:
            return
        try:
            with open(keyfile) as fd:
                data = fd.read()
        except (IOError, OSError) as e:
            # Python 2's open() raises IOError (not OSError) for a missing or unreadable file
            logger.critical("[protected fields] Unable to open key file %s : %s" % (keyfile, str(e)))
            sys.exit(1)
        key = base64.decodestring(data[data.index("|") + 1:])
        self.crypter = AESCipher(key)


class FrontendCipher(Cipher):
    """Always-enabled cipher that masks protected values with a placeholder tag.
    
    Whether a property is masked depends on the encryption setting, the
    requesting user and *are_viewable_by_admin_si*.
    """
    def __init__(self, encryption_enable, protected_property_names, are_viewable_by_admin_si):
        super(FrontendCipher, self).__init__(True, protected_property_names)
        self.crypter = TagValueCrypter()
        self.encryption_enable = encryption_enable
        self.are_viewable_by_admin_si = are_viewable_by_admin_si
    
    
    def match_protected_property(self, item_property, item_type, user=None):
        """Tell whether *item_property* must be masked for *user*.
        
        Raises ValueError when *user* is missing. With encryption enabled every
        protected property is masked. Otherwise admins (``is_admin == '1'``) see
        everything. For other users, protected properties are masked unless
        are_viewable_by_admin_si is set, in which case they are masked only on
        non-host items.
        """
        if not user:
            raise ValueError('Missing user argument')
        
        is_admin = user.get('is_admin', '0') == '1'
        
        mask_all_protected = self.encryption_enable or not (is_admin or self.are_viewable_by_admin_si)
        if mask_all_protected:
            return _match_protected_property(item_property, self.protected_property_names, item_type)
        if is_admin:
            return False
        # Non-admin user with are_viewable_by_admin_si set: host values are shown
        return item_type != ITEM_TYPE.HOSTS and _match_protected_property(item_property, self.protected_property_names, item_type)


def generate_cipher_key():
    """Return a new random 256-bit key, base64-encoded and newline-terminated.
    
    The SHA-256 digest (32 bytes) fits on a single base64 line, so
    ``b64encode(digest) + '\\n'`` is byte-identical to what the deprecated
    ``base64.encodestring`` alias produced (that alias is gone in Python 3.9+).
    """
    digest = hashlib.sha256(os.urandom(1000)).digest()
    return base64.b64encode(digest) + b'\n'


# Used at install time to generate a key.
if __name__ == '__main__':
    new_key = generate_cipher_key()
    print(new_key)
