# -*- coding: utf-8 -*-

# Copyright (C) 2009-2014:
#     Gabes Jean, naparuba@gmail.com
#     Gerhard Lausser, Gerhard.Lausser@consol.de
#     Gregory Starck, g.starck@gmail.com
#     Hartmut Goebel, h.goebel@goebel-consult.de
#
# This file is part of Shinken.
#
# Shinken is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Shinken is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with Shinken.  If not, see <http://www.gnu.org/licenses/>.

import html
import io
import os
import sys
import threading
import time
from collections import namedtuple

from .compat import cPickle
from .log import LoggerFactory
from .misc.type_hint import TYPE_CHECKING
from .toolbox import _pickledb_dict
from .util import get_day

# Loggers: the daemon raw logger, and its sub-part used to report blocked loadings
logger_raw = LoggerFactory.get_logger()
logger_protection = logger_raw.get_sub_part('DESERIALIZATION PROTECTION')

# Limits of the recorded security errors
MAX_SECURITY_ERROR_COUNT = 3  # only the last N errors are kept (the daily counter is kept apart)
NB_DAYS_TO_KEEP_SECURITY_ERROR = 1
NB_SECONDS_IN_ONE_DAY = 24 * 60 * 60

# One blocked loading: 'module/class' route, epoch (int seconds) of the error, and who asked for the loading
SecurityError = namedtuple('SecurityError', 'route date caller')

if TYPE_CHECKING:
    from .misc.type_hint import Any, Tuple, List, Optional, Dict, Union


# Status codes returned to the checks by the check_status_* helpers.
# NOTE: there is no WARNING level on purpose: either a security error was recorded, or not
class SERIALIZATION_SECURITY_STATUS_CODE:  # noqa: yes, no __init__
    CRITICAL = 2  # at least one security error was recorded
    OK = 0  # no security error


class SERIALIZATION_SECURITY_EXCEPTION(ValueError):
    """Raised by SafeUnpickler when a pickle references a forbidden object, or when the input data has an unexpected type."""


# Keys of the dict built by SerializationSecurityContainer.get_errors_export()
class SERIALIZATION_SECURITY_KEYS:
    COUNT = 'count'  # number of errors recorded since the day started
    ERRORS = 'errors'  # list of the last recorded errors


# This class is a singleton that keeps the security errors raised by SafeUnpickler.find_class
# in this daemon:
# - errors are kept for the current day only, with only the last MAX_SECURITY_ERROR_COUNT of them
#   (the daily counter counts all of them)
# - get_errors_export() gives a copy of them (for get_raw_stats)
# - the check_status_* static methods turn that export into a status for checks & healthcheck
class SerializationSecurityContainer:
    def __init__(self):
        self._security_errors_lock = threading.RLock()
        self._security_errors_lock_created_pid = os.getpid()
        self._security_errors = []  # type: List[SecurityError]
        self._error_count = 0
        self._last_error_day = get_day(int(time.time()))
    
    
    # Calls by tests, and ONLY TEST to drop all errors
    def reset_for_tests(self):
        # type: () -> None
        self._reset_errors_and_counter()
    
    
    # IMPORTANT: this object is created at import time and is kept across all forks, so the
    #            threading.RLock() inherited from the parent can be unusable (e.g. held by a
    #            thread that does not exist in the child): we recreate it in the new process
    def _get_security_lock(self):
        # type: () -> threading.RLock
        # NOTE: we cannot .release() the old one, it can be a zombie lock, sorry
        current_pid = os.getpid()
        if current_pid != self._security_errors_lock_created_pid:
            self._security_errors_lock = threading.RLock()
            # Store the new owner pid on the instance: without it a NEW lock would be created at
            # each call in the child process, and the lock would not protect anything anymore.
            # NOTE(review): two threads of a freshly forked child can still both recreate it at first use
            self._security_errors_lock_created_pid = current_pid
        return self._security_errors_lock
    
    
    def _reset_errors_and_counter(self):
        # type: () -> None
        with self._get_security_lock():
            self._security_errors = []
            self._error_count = 0
    
    
    # Drop the errors and the counter when the day changed since the last reset,
    # so we only keep (and report) the errors of the current day
    def _check_day(self):
        # type: () -> None
        today = get_day(int(time.time()))
        with self._get_security_lock():
            if self._last_error_day != today:
                self._reset_errors_and_counter()  # RLock: re-entrant, ok to take it again
                # Remember the new day, otherwise every following call would wipe the errors again
                self._last_error_day = today
    
    
    # Record one blocked loading of module_name/class_name asked by caller
    def add_error(self, module_name, class_name, caller):
        # type: (str, str, str) -> None
        with self._get_security_lock():
            self._check_day()
            error = SecurityError('%s/%s' % (module_name, class_name), int(time.time()), caller)
            self._security_errors.append(error)
            # Keep only the last MAX_SECURITY_ERROR_COUNT errors (drop the oldest ones). Done under the
            # same lock as the append, so concurrent callers cannot go over the limit
            overflow = len(self._security_errors) - MAX_SECURITY_ERROR_COUNT
            if overflow > 0:
                del self._security_errors[:overflow]  # NOTE: realloc the list, but it's ok as it's very small
            
            # And finally incrementing counter
            self._error_count += 1
    
    
    # Returns {ERRORS: copy of the last errors, COUNT: number of errors of the day}
    def get_errors_export(self):
        # type: () -> Dict[str, Union[List[SecurityError], int]]
        with self._get_security_lock():
            self._check_day()
            return {SERIALIZATION_SECURITY_KEYS.ERRORS: self._security_errors[:],  # copy here as we will json/read it outside
                    SERIALIZATION_SECURITY_KEYS.COUNT : self._error_count}
    
    
    # Same as the watchdog method: look at export, and give a status code, a title and one line per error
    # NOTE: we will only display the errors, we don't care about security ok
    @staticmethod
    def check_status_and_output_text(export):
        # type: (Dict[str, Union[List[SecurityError], int]]) -> Tuple[int, str, List[str]]
        errors = export[SERIALIZATION_SECURITY_KEYS.ERRORS]
        error_count = export[SERIALIZATION_SECURITY_KEYS.COUNT]
        
        if error_count == 0:  # no error, we won't display anything in the healthcheck
            return (SERIALIZATION_SECURITY_STATUS_CODE.OK, '', [])  # noqa: pycharm, I want a tuple, I want to see it with ()
        
        title = 'There were [ %s ] security breaches blocked today (last %d):' % (error_count, MAX_SECURITY_ERROR_COUNT)
        # NOTE: entries are unpacked as (route, date, caller) tuples, as built by get_errors_export()
        _errs = ['( %s ) by "%s" at %s' % (mod_and_class, caller, time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(epoch)))
                 for (mod_and_class, epoch, caller) in errors]
        return (SERIALIZATION_SECURITY_STATUS_CODE.CRITICAL, title, _errs)  # noqa: pycharm, I want a tuple, I want to see it with ()
    
    
    # Same as the watchdog method: look at export, and give a status code and an HTML text
    # NOTE: we will only display the errors, we don't care about security ok
    @staticmethod
    def check_status_and_output_html(export):
        # type: (Dict[str, Union[List[SecurityError], int]]) -> Tuple[int, str]
        errors = export[SERIALIZATION_SECURITY_KEYS.ERRORS]
        error_count = export[SERIALIZATION_SECURITY_KEYS.COUNT]
        
        # Counter and list should agree, but check both: nothing to list if there are no errors
        if error_count == 0 or len(errors) == 0:  # no error, we won't display anything in the healthcheck
            return (SERIALIZATION_SECURITY_STATUS_CODE.OK, '')  # noqa: pycharm, I want a tuple, I want to see it with ()
        
        output = 'There were [ %s ] security breaches blocked today (last %d):<ul>' % (error_count, MAX_SECURITY_ERROR_COUNT)
        # SECURITY: module/class names come from the (untrusted) pickle stream, escape them before putting them in HTML
        _errs = ['<li>[ %s ] by [ %s ] at [ %s ]</li>' % (html.escape(str(mod_and_class)),
                                                          html.escape(str(caller)),
                                                          time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(epoch)))
                 for (mod_and_class, epoch, caller) in errors]
        return (SERIALIZATION_SECURITY_STATUS_CODE.CRITICAL, output + ''.join(_errs) + '</ul>')  # noqa: pycharm, I want a tuple, I want to see it with ()


serialization_security_container = SerializationSecurityContainer()

# Name of the caller of the current SafeUnpickler.load(): set by load(), read by find_class() when
# it records a security error.
# NOTE: I didn't find a way to avoid this global, if you have an idea, let me know
#       tips: if your idea has locks in it, rethink it, that will fail ^^ (threads & forked processes)
#       Side effect: with concurrent loads in several threads, an error may be recorded with the
#       caller of another thread
current_caller = ''


def _merge_pickledb_dictionaries(destination: 'dict[str, set[str]]', source: 'dict[str, set[str]]') -> None:
    for module, classes in source.items():
        destination.setdefault(module, set()).update(classes)


class ShinkenSafeUnpickler(cPickle.Unpickler):
    # Every global (class/function) lookup done while unpickling goes through SafeUnpickler's allow-list
    def find_class(self, __module_name, __global_name):
        allow_listed_lookup = SafeUnpickler.find_class
        return allow_listed_lookup(__module_name, __global_name)


# Restricted unpickling: only the globals (classes/functions) listed in PICKLE_SAFE, plus anything
# from the shinken.objects.* modules, can be resolved while loading a pickle, so a crafted payload
# cannot reach arbitrary callables through its global references. Every refused lookup is recorded
# in serialization_security_container and raises SERIALIZATION_SECURITY_EXCEPTION.
# Original Code from Graphite::carbon project, with lot more things ^^
class SafeUnpickler:
    PICKLE_SAFE = {
        'copyreg'                                                               : {'_reconstructor'},
        'builtins'                                                              : {'object', 'set', 'tuple'},
        '_codecs'                                                               : {'encode'},
        'collections'                                                           : {'defaultdict'},
        're'                                                                    : {'_compile'},  # for result modulation that need its regexp
        # Now what we need to pickle that is NOT in shinken.objects
        'shinken.acknowledge'                                                   : {'Acknowledge'},
        'shinken.arbiterlink'                                                   : {'ArbiterLink', 'ArbiterLinks'},
        'shinken.basesubprocess'                                                : {'FromModuleCommandRespond', 'ToModuleCommandRequest'},
        'shinken.brok'                                                          : {'Brok', 'PersistantBrok'},
        'shinken.brokerlink'                                                    : {'BrokerLink', 'BrokerLinks'},
        'shinken.check'                                                         : {'Check'},
        'shinken.contactdowntime'                                               : {'ContactDowntime'},
        'shinken.commandcall'                                                   : {'CommandCall'},
        'shinken.configuration_incarnation'                                     : {'ConfigurationIncarnation', 'PartConfigurationIncarnation'},
        'shinken.dependencynode'                                                : {'DependencyNode', 'StateRule'},
        'shinken.daterange'                                                     : {'Timerange', 'Daterange', 'CalendarDaterange', 'StandardDaterange', 'MonthWeekDayDaterange', 'MonthDateDaterange', 'WeekDayDaterange', 'MonthDayDaterange'},
        'shinken.downtime'                                                      : {'Downtime'},
        'shinken.external_command'                                              : {'ExternalCommand'},
        'shinken.graph'                                                         : {'Graph'},
        # NOTE for logger: The only code that need it was the realm, and is no more need
        #                  if you need it, ask you if you need to pickle a logger (tips: you don't need it)
        # 'shinken.log'                      : set(['PartLogger', 'LoggersInfo', 'Log']),
        'shinken.message'                                                       : {'Message'},
        'shinken.inter_daemon_message'                                          : {'InterDaemonMessage', 'MessageRef'},
        'shinken.notification'                                                  : {'Notification'},
        'shinken.eventhandler'                                                  : {'EventHandler'},
        'shinken.pollerlink'                                                    : {'PollerLink', 'PollerLinks'},
        'shinken.reactionnerlink'                                               : {'ReactionnerLink', 'ReactionnerLinks'},
        'shinken.receiverlink'                                                  : {'ReceiverLink', 'ReceiverLinks'},
        'shinken.schedulerlink'                                                 : {'SchedulerLink', 'SchedulerLinks'},
        'shinken.synchronizerlink'                                              : {'SynchronizerLink', 'SynchronizerLinks'},
        'shinken.withworkersandinventorymodule'                                 : {'FromDaemonToModuleMessage'},
        
        # [ FCLA ] Only for the DIRISI fix -> auto add dict from sources/python3/shinken/toolbox/_pickledb_dict.py
        'webui_module_service_weather.misc.inter_daemon_message_service_weather': {'MessageData', 'MessageDataWeatherHasBeenCalled'},
    }
    
    _merge_pickledb_dictionaries(PICKLE_SAFE, _pickledb_dict._SAFE_PICKLEABLE_CLASSES_DATABASE)
    
    _BYPASS_PICKLE_SAFE_CHECK = False  # Do not touch this variable outside the tests
    
    # Python 2 pickles reference some modules by their old name: "__builtin__" was renamed
    # "builtins" and "copy_reg" was renamed "copyreg" in py3
    _PY2_MODULE_RENAMES = {'__builtin__': 'builtins', 'copy_reg': 'copyreg'}
    
    
    # Called by the unpickler for each global referenced by the pickle stream.
    # Returns the object if module/name is allowed, else records the error (with current_caller)
    # and raises SERIALIZATION_SECURITY_EXCEPTION.
    @classmethod
    def find_class(cls, module, name):
        module = cls._PY2_MODULE_RENAMES.get(module, module)
        
        # NOTE(review): the 'shinken.objects.' prefix allows ANY attribute of those modules, including
        #               names they import from elsewhere; consider also checking that the resolved
        #               object is really defined in that module
        is_allowed = (cls._BYPASS_PICKLE_SAFE_CHECK
                      or module.startswith('shinken.objects.')
                      or name in cls.PICKLE_SAFE.get(module, ()))
        if not is_allowed:
            err = 'Loading of object "%s/%s" not allowed for security reason (not in the allowed list)' % (module, name)
            serialization_security_container.add_error(module, name, current_caller)
            raise SERIALIZATION_SECURITY_EXCEPTION(err)  # Note: we don't have access to caller here string as this is a C object with a fix prototype, cannot give caller here
        
        __import__(module)
        return getattr(sys.modules[module], name)
    
    
    # NOTE: we will "loads" the pickle string but important:
    # * cPickle.Unpickler is a C object, that cannot be extended, so cannot give it directly the caller
    #   property, need to give in a global way, but without lock because of threads/process issues
    # str input is utf8-encoded first; any other type than bytes/str raises SERIALIZATION_SECURITY_EXCEPTION
    @classmethod
    def loads(cls, pickle_string, caller):
        # type: (Union[bytes,str], str) -> Any
        if isinstance(pickle_string, str):
            pickle_string = pickle_string.encode('utf8')
        elif not isinstance(pickle_string, bytes):
            raise SERIALIZATION_SECURITY_EXCEPTION('unexpected data type <%s> in SafeUnpickler.loads()' % type(pickle_string))
        return cls.load(io.BytesIO(pickle_string), caller)
    
    
    # Load one object from pickle_stringio: either a binary file-like object (with read/readline,
    # e.g. io.BytesIO or a file opened in 'rb'), or raw bytes that will be wrapped in io.BytesIO.
    # Security errors are logged with the caller, then re-raised.
    @classmethod
    def load(cls, pickle_stringio, caller):
        # type: (Union[bytes,io.BytesIO], str) -> Any
        global current_caller
        current_caller = caller
        # Only wrap raw data: a real binary file object cannot be given to io.BytesIO()
        if not hasattr(pickle_stringio, 'read'):
            pickle_stringio = io.BytesIO(pickle_stringio)
        unpickler = ShinkenSafeUnpickler(pickle_stringio)
        unpickler.find_class = cls.find_class
        try:
            return unpickler.load()
        except SERIALIZATION_SECURITY_EXCEPTION as exp:
            logger_protection.error('[ %s ] %s ' % (caller, exp))
            raise
