#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2013-2022:
# This file is part of Shinken Enterprise, all rights reserved.

import base64
import json
import os
import time

import rsa
from Crypto import Random
from Crypto.Cipher import AES

from shinken.log import logger

# Load the RSA private key (PKCS#1 format) that decrypt_file() uses to unwrap
# the enterprise AES key. This runs at import time: a missing or unreadable key
# file makes importing this module fail.
_SL_KEY_PATH = r'C:\etc\shinken\sl.key' if os.name == 'nt' else '/etc/shinken/sl.key'
with open(_SL_KEY_PATH, 'rb') as private_key_file:
    private_key = rsa.PrivateKey.load_pkcs1(private_key_file.read())

# AES block size in bytes; pad() rounds data up to a multiple of it.
BS = 16


def pad(s):
    """Apply PKCS#7 padding so that the length of *s* becomes a multiple of ``BS``.

    Between 1 and ``BS`` bytes are appended, each one equal to the number of
    bytes added; input that is already aligned receives a full extra block.

    :param s: data to pad (bytes)
    :return: the padded bytes
    """
    fill = BS - len(s) % BS
    return s + bytes((fill,)) * fill


def unpad(s):
    """Strip PKCS#7 padding from decrypted bytes *s*.

    The last byte gives the number of padding bytes, and every one of those
    bytes must carry that same value.

    :param s: padded plaintext (bytes)
    :return: *s* without its padding
    :raises ValueError: if *s* is empty or its padding is malformed (typically
        a wrong key or corrupted ciphertext)
    """
    if not s:
        raise ValueError('Cannot unpad empty data')
    pad_len = s[-1]
    # Without these checks a pad length of 0 sliced everything away
    # (s[0:-0] == b'') and an inconsistent trailer was silently truncated.
    if pad_len == 0 or pad_len > len(s) or s[-pad_len:] != bytes((pad_len,)) * pad_len:
        raise ValueError('Invalid padding')
    return s[:-pad_len]


class AESCipher:
    """Symmetric AES-CBC helper.

    Encrypted values are the base64 text of ``iv || ciphertext``, where ``iv``
    is 16 random bytes and the plaintext is PKCS#7-padded (see pad/unpad).
    NOTE(review): CBC without a MAC provides confidentiality only, not integrity.
    """

    def __init__(self, key: bytes):
        # Raw AES key, handed to AES.new on every encrypt/decrypt call.
        self.key = key


    def encrypt(self, raw: str | bytes, old_value: str = None) -> str:
        """Encrypt *raw* and return the result as base64 text.

        A str is UTF-8 encoded first. ``old_value`` is accepted for interface
        compatibility and is not used.
        """
        data = raw.encode('utf8') if isinstance(raw, str) else raw
        padded = pad(data)
        init_vector = Random.new().read(16)  # fresh random IV for every message
        ciphertext = AES.new(self.key, AES.MODE_CBC, init_vector).encrypt(padded)
        # The IV is sent in clear in front of the ciphertext; decrypt() reads it back.
        return base64.b64encode(init_vector + ciphertext).decode('utf8')


    def decrypt(self, enc: str | bytes, old_value: str = None) -> str:
        """Reverse encrypt() and return the plaintext as a str.

        *enc* is base64-decoded, the leading 16-byte IV is split off, and the
        rest is decrypted, unpadded and decoded as UTF-8. ``old_value`` is
        accepted for interface compatibility and is not used.
        """
        blob = base64.b64decode(enc)
        init_vector, ciphertext = blob[:16], blob[16:]
        plain = AES.new(self.key, AES.MODE_CBC, init_vector).decrypt(ciphertext)
        return unpad(plain).decode('utf8')


def decrypt_file(p):
    """Decrypt the AESCipher-encrypted file *p* and return its content as a str.

    The AES key is stored RSA-encrypted in /etc/shinken/enterprise.key. It is
    unwrapped with the module-level ``private_key`` before use.
    NOTE(review): unlike the sl.key path, this path is not switched on os.name.

    :param p: path of the encrypted file
    :return: the decrypted text
    :raises: whatever opening the files, ``rsa.decrypt`` or ``AESCipher.decrypt`` raises
    """
    with open('/etc/shinken/enterprise.key', 'rb') as handle:
        wrapped_aes_key = handle.read()
    aes_key = rsa.decrypt(wrapped_aes_key, private_key)

    with open(p, 'rb') as handle:
        payload = handle.read()
    return AESCipher(aes_key).decrypt(payload)


def _is_duplicated_trial(uid, key_file):
    """Check a temporary (testing) key's uid against the recorded trial uid.

    The uid of the first temporary key ever seen is saved in /var/lib/shinken/._.
    Afterwards, only that same uid is accepted.

    :param uid: the 'uuid' field of the temporary key
    :param key_file: path of the key file (used for logging only)
    :return: True if a different trial uid was recorded before, meaning the key must be skipped
    """
    marker_path = '/var/lib/shinken/._'
    if os.path.exists(marker_path):
        with open(marker_path, 'r') as marker_file:
            prev_key = marker_file.read()
        if prev_key != uid:
            logger.warning('The temporary key %s is a new temporary one but a previous one was present, we are skipping it' % key_file)
            return True
        return False
    # No trial recorded yet: remember this one.
    with open(marker_path, 'w') as marker_file:
        marker_file.write(uid)
    # Set its write to 777 because root and shinken users are sharing it.
    # NOTE(review): 0o777 makes the marker world-writable (and executable);
    # a shared group with 0o660 would be tighter.
    os.chmod(marker_path, 0o777)
    return False


def are_keys_valid():
    """Inspect the user license key files and summarize what they grant.

    Each existing file among /etc/shinken/user.key and /etc/shinken/user.key2
    is decrypted with decrypt_file() and parsed as a JSON object. A key is
    skipped (and logged) if it:
      - fails to decrypt or parse,
      - lacks a 'uuid',
      - is a duplicated trial, or
      - is expired.

    :return: dict with the keys 'are_valid', 'nodes_limit', 'is_powerful',
        'customer', 'customer_email', 'testing', 'duration', 'creation_time',
        'is_duplicated_trial', 'is_format_valid' and 'is_present'
    """
    key_path = ['/etc/shinken/user.key', '/etc/shinken/user.key2']
    are_valid = False
    nodes_limit = int((((4 + 16 - 10) * 2) // 2) * 2)
    is_powerful = False
    customer = 'unknown'
    customer_email = 'root@localhost'
    testing = False
    duration = 0
    creation_time = 0
    is_duplicated_trial = False
    is_format_valid = False
    is_present = False
    for p in key_path:
        if not os.path.exists(p):  # no file? skip it
            continue
        is_present = True

        try:
            buf = decrypt_file(p)
        except Exception as exp:
            # NOTE(review): this resets a True set by an earlier key, whereas a
            # JSON failure below leaves the flag untouched; confirm which is intended.
            is_format_valid = False
            logger.warning('The key %s is malformed! %s' % (p, str(exp)))
            logger.print_stack()
            continue

        if buf is None:
            continue

        try:
            key = json.loads(buf)
        except Exception as exp:
            logger.warning('The key %s is malformed! %s' % (p, str(exp)))
            continue
        if not isinstance(key, dict):
            # Valid JSON that is not an object used to crash on key.get() below.
            logger.warning('The key %s is malformed! %s' % (p, 'expected a JSON object'))
            continue
        is_format_valid = True

        uid = key.get('uuid', None)
        if uid is None:
            continue

        # A temporary key is only accepted if no other trial was recorded before.
        if key.get('testing', False):
            if _is_duplicated_trial(uid, p):
                is_duplicated_trial = True
                continue
            # NOTE(review): these flags are set before the expiry check, so an
            # expired trial key still reports is_powerful/testing as True.
            is_powerful = True
            testing = True

        # NOTE(review): these fields are taken from every key examined, so an
        # expired key read after a valid one overrides the valid key's
        # creation_time/duration in the result.
        creation_time = key.get('creation_time', 0)
        key_customer = key.get('customer', '')
        key_customer_email = key.get('customer_email', 'root@localhost')
        duration = key.get('duration', 0)
        key_nodes_limit = key.get('nodes_limit', nodes_limit)

        now = int(time.time())
        if now > creation_time + duration:
            logger.error('The key %s is no more valid, skipping it' % p)
            continue

        # We found a valid key!
        are_valid = True
        logger.info('Found a valid key at %s' % p)

        if now > creation_time + duration - (86400 * 15):
            logger.warning('The key %s will be no more valid soon (less than 15days), please ask for a new one' % p)

        # Only valid keys reach this point, so their limits and customer data are used as-is.
        nodes_limit = key_nodes_limit
        if key_customer:
            customer = key_customer
        if key_customer_email:
            customer_email = key_customer_email

    return {
        'are_valid'          : are_valid,
        'nodes_limit'        : nodes_limit,
        'is_powerful'        : is_powerful,
        'customer'           : customer,
        'customer_email'     : customer_email,
        'testing'            : testing,
        'duration'           : duration,
        'creation_time'      : creation_time,
        'is_duplicated_trial': is_duplicated_trial,
        'is_format_valid'    : is_format_valid,
        'is_present'         : is_present,
    }


def get_obfuscated_message_from_unicode(clear_message):
    # type: (str) -> str
    """Obfuscate *clear_message* by shifting every character's code point up by one.

    WARNING: this is meant to obfuscate messages embedded in code. The clear
    message MUST NOT be present in the source code: run this only from the
    command line and paste the obfuscated result into the source.
    get_clear_message_form_obfuscated() performs the reverse shift.
    """
    shifted_chars = (chr(ord(char) + 1) for char in clear_message)
    return ''.join(shifted_chars)


def get_clear_message_form_obfuscated(obfuscated_message):
    # type: (str) -> str
    """Recover the clear text from *obfuscated_message*.

    This reverses get_obfuscated_message_from_unicode() by shifting every
    character's code point down by one.
    """
    return ''.join(map(lambda char: chr(ord(char) - 1), obfuscated_message))
