#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2013-2022:
# This file is part of Shinken Enterprise, all rights reserved.

import time
from collections import defaultdict
from contextlib import contextmanager
from threading import RLock

from shinken.misc.type_hint import TYPE_CHECKING, cast
from shinkensolutions.cache.cache_storage import CacheStorage, CacheItemNotFound
from shinkensolutions.data_hub.data_hub_driver.abstract_data_hub_driver import AbstractDataHubDriver, AbstractDataHubDriverConfig
from shinkensolutions.data_hub.data_hub_factory.data_hub_factory import DataHubFactory

if TYPE_CHECKING:
    from shinken.misc.type_hint import Any, Number, Tuple, List, Iterator
    from shinken.log import PartLogger
    from shinkensolutions.data_hub.data_hub import DataHubConfig
    from shinkensolutions.data_hub.data_hub_driver.abstract_data_hub_driver import AbstractDataHubDriverConfig


class DataHubMetaDriverConfigCache(AbstractDataHubDriverConfig):
    """Configuration of the 'CACHE' meta driver.

    Holds the configuration of the wrapped driver (``config_main_driver``, built by
    data_hub_meta_driver_cache_factory) and ``refresh_interval``, an optional extra
    time-based invalidation delay in seconds used by DataHubMetaDriverCache (0 disables it).
    """
    
    def __init__(self, config_main_driver, refresh_interval=0):
        # type: (AbstractDataHubDriverConfig, Number) -> None
        super(DataHubMetaDriverConfigCache, self).__init__('CACHE')
        self.refresh_interval = refresh_interval
        self.config_main_driver = config_main_driver


def data_hub_meta_driver_cache_factory(logger, driver_config, data_hub_config):
    # type: (PartLogger, AbstractDataHubDriverConfig, DataHubConfig) -> DataHubMetaDriverCache
    """Build a DataHubMetaDriverCache together with the driver it wraps.

    The wrapped (main) driver is built through DataHubFactory from
    ``driver_config.config_main_driver`` and the same ``data_hub_config``.
    """
    # Presumably a typing-style cast (narrows the type for checkers, no runtime conversion) - confirm.
    cache_config = cast(DataHubMetaDriverConfigCache, driver_config)
    wrapped_driver = DataHubFactory.build_driver(logger, cache_config.config_main_driver, data_hub_config)
    return DataHubMetaDriverCache(logger, cache_config, wrapped_driver)


# Module-level registration, executed at import time: associates DataHubMetaDriverConfigCache with
# data_hub_meta_driver_cache_factory in DataHubFactory, presumably so that DataHubFactory.build_driver()
# can construct a cache driver from this config type (verify against DataHubFactory).
DataHubFactory.register_driver_factory(DataHubMetaDriverConfigCache, data_hub_meta_driver_cache_factory)


class CacheObject(object):
    """Cache entry: a stored object plus the modification date it was read or written with.

    ``__slots__`` keeps instances free of a per-instance ``__dict__``, since one entry exists per cached data_id.
    """
    __slots__ = ('object', 'last_modification_date')
    
    def __init__(self, object_to_cache, last_modification_date=0):
        # type: (Any, Number) -> None
        self.last_modification_date = last_modification_date  # type: Number
        self.object = object_to_cache  # type: Any


class DataHubMetaDriverCache(AbstractDataHubDriver):
    """Meta driver keeping an in-memory cache of objects in front of another data hub driver.

    ``read``/``read_and_get_last_modification_date`` are served from ``self.cache`` while the cached
    entry is not dirty (see ``_is_cache_object_dirty``); otherwise the object is re-read from
    ``main_driver`` and cached again. Raw reads, listing, searching and size queries are delegated
    to ``main_driver``. Cache access for a given data_id is serialized with a per-data_id ``RLock``.
    """
    
    def __init__(self, logger, driver_config, main_driver):
        # type: (PartLogger, DataHubMetaDriverConfigCache, AbstractDataHubDriver) -> None
        super(DataHubMetaDriverCache, self).__init__(logger, driver_config)
        
        # One re-entrant lock per data_id, created lazily on first access.
        self.cache_lock = defaultdict(RLock)  # type: defaultdict[str, RLock]
        # data_id -> CacheObject
        self.cache = CacheStorage()
        self.main_driver = main_driver
        
        # Extra time-based invalidation delay in seconds (compared with time.time()); <= 0 disables it.
        self._refresh_interval = driver_config.refresh_interval
    
    
    def init(self):
        """Initialize the wrapped driver, then drop every per-data_id lock and cached object."""
        self.main_driver.init()
        self.cache_lock.clear()
        self.cache.flush()
    
    
    def _is_cache_object_dirty(self, data_id, cache_object):
        # type: (str, CacheObject) -> bool
        """Tell whether ``cache_object`` must be reloaded from the main driver.

        It is dirty when its modification date differs from the one the main driver currently
        reports, or, when a refresh interval is configured, once
        ``last_modification_date + refresh_interval`` has been reached.
        """
        if cache_object.last_modification_date != self.main_driver.get_last_modification_date(data_id):
            return True
        # NOTE(review): the interval is measured from the data's modification date, not from the moment
        # the entry was cached, so data left unmodified for longer than refresh_interval is reloaded on
        # every access. Confirm this is the intended semantics.
        return self._refresh_interval > 0 and cache_object.last_modification_date + self._refresh_interval <= time.time()
    
    
    def _peek_from_cache(self, data_id, check_if_dirty):
        # type: (str, bool) -> CacheObject
        """Return the cached entry of ``data_id``.

        Raises CacheItemNotFound on a cache miss (as expected from CacheStorage.get_object by
        read_and_get_last_modification_date), or when ``check_if_dirty`` is set and the entry is
        dirty, in which case the entry is evicted first.
        """
        cache_object = self.cache.get_object(data_id)  # type: CacheObject
        if check_if_dirty and self._is_cache_object_dirty(data_id, cache_object):  # We will act as it never existed :)
            self.cache.delete_object(data_id, raising=False)
            raise CacheItemNotFound(data_id, message='Invalid object')
        return cache_object
    
    
    def _reload_from_driver(self, data_id, log_error):
        # type: (str, bool) -> CacheObject
        """Read ``data_id`` from the main driver, cache it and return the new CacheObject."""
        main_driver_data, last_modification_date = self.main_driver.read_and_get_last_modification_date(data_id, log_error=log_error)
        cache_object = CacheObject(main_driver_data, last_modification_date=last_modification_date)
        self.cache.store_object(data_id, cache_object)
        return cache_object
    
    
    def get_number_of_stored_data(self):
        """Delegate to the main driver."""
        return self.main_driver.get_number_of_stored_data()
    
    
    def get_all_data_id(self):
        """Delegate to the main driver."""
        return self.main_driver.get_all_data_id()
    
    
    def find_data_id(self, filters):
        # type: (Any) -> List[str]
        """Delegate to the main driver."""
        return self.main_driver.find_data_id(filters)
    
    
    def is_data_correct(self, data_id):
        # type: (str) -> bool
        """Tell whether ``data_id`` is valid according to the main driver and can be loaded.

        On success the cache holds an up-to-date entry; on failure the entry is evicted. Errors from
        the main driver are turned into ``False`` instead of being propagated.
        """
        with self.cache_lock[data_id]:
            try:
                is_valid = self.main_driver.is_data_correct(data_id)
            except Exception:  # noqa: Should not happen but must not be propagated
                is_valid = False
            if not is_valid:  # Invalidate cached object if not valid
                self.cache.delete_object(data_id, raising=False)
                return False
            cache_object = self.cache.get_object_default(data_id, None)
            try:
                # The dirty check queries the main driver too (get_last_modification_date), so it is
                # guarded like the reload: e.g. data removed concurrently must yield False, not raise.
                if cache_object is None or self._is_cache_object_dirty(data_id, cache_object):
                    self._reload_from_driver(data_id, log_error=False)
            except Exception:  # noqa: Expected behaviour.
                self.cache.delete_object(data_id, raising=False)
                is_valid = False
        return is_valid
    
    
    @contextmanager
    def lock_context(self, data_id):
        # type: (str) -> Iterator[None]
        """Hold this cache's lock for ``data_id``, then the main driver's lock context for it."""
        with self.cache_lock[data_id]:
            with self.main_driver.lock_context(data_id):
                yield
    
    
    def write(self, data_id, data):
        """Write ``data`` through the main driver, then cache it with the modification date the main driver now reports."""
        with self.lock_context(data_id):
            self.main_driver.write(data_id, data)
            # NOTE: ``data`` is cached by reference; later mutations by the caller are visible to readers.
            self.cache.store_object(data_id, CacheObject(data, last_modification_date=self.main_driver.get_last_modification_date(data_id)))
    
    
    def write_raw(self, data_id, raw_data):
        """Write ``raw_data`` through the main driver and evict the cached object of ``data_id``.

        The cached object was built from the previous content. Relying only on the modification-date
        check would miss a rewrite reporting the same date, so the entry is dropped and the next
        read reloads it.
        """
        with self.lock_context(data_id):
            try:
                self.main_driver.write_raw(data_id, raw_data)
            finally:
                # Evict even on failure: a partial write may already have changed the stored content.
                self.cache.delete_object(data_id, raising=False)
    
    
    def read(self, data_id, log_error=True):
        """Return the object stored under ``data_id``, from the cache when it is still valid."""
        data, _ = self.read_and_get_last_modification_date(data_id, log_error=log_error)
        return data
    
    
    def read_raw(self, data_id, log_error=True):
        """Delegate to the main driver (raw content is not cached)."""
        return self.main_driver.read_raw(data_id, log_error=log_error)
    
    
    def remove(self, data_id):
        """Remove ``data_id`` from the main driver and always evict its cache entry and lock."""
        try:
            with self.lock_context(data_id):
                try:
                    self.main_driver.remove(data_id)
                finally:
                    self.cache.delete_object(data_id, raising=False)
        finally:
            # NOTE(review): a thread already waiting on the popped RLock still acquires it, while a later
            # caller gets a fresh RLock from the defaultdict, so both may run concurrently for this
            # data_id. Confirm that overlap is harmless.
            self.cache_lock.pop(data_id, None)
    
    
    def read_and_get_last_modification_date(self, data_id, log_error=True):
        # type: (str, bool) -> Tuple[Any, Number]
        """Return ``(object, last_modification_date)``, reloading from the main driver on a miss or dirty entry."""
        with self.cache_lock[data_id]:
            try:
                cache_object = self._peek_from_cache(data_id, check_if_dirty=True)
            except CacheItemNotFound:
                cache_object = self._reload_from_driver(data_id, log_error=log_error)
            # NOTE: the cached object itself is returned (no copy); mutating it alters the cache.
            return cache_object.object, cache_object.last_modification_date
    
    
    def get_last_modification_date(self, data_id):
        """Delegate to the main driver."""
        return self.main_driver.get_last_modification_date(data_id)
    
    
    def destroy(self):
        """Destroy the main driver; locks and cache are cleared even if that fails."""
        try:
            self.main_driver.destroy()
        finally:
            self.cache_lock.clear()
            self.cache.flush()
    
    
    def stop(self):
        """Stop the main driver; locks and cache are cleared even if that fails."""
        try:
            self.main_driver.stop()
        finally:
            self.cache_lock.clear()
            self.cache.flush()
    
    
    def get_total_size(self):
        """Delegate to the main driver; on any failure other than NotImplementedError, flush the cache and re-raise."""
        try:
            return self.main_driver.get_total_size()
        except NotImplementedError:
            raise
        except BaseException:  # same as the previous bare ``except:``, written explicitly
            self.cache.flush()
            raise
    
    
    def get_size_of(self, data_id):
        """Delegate to the main driver; on any failure other than NotImplementedError, flush the cache and re-raise."""
        try:
            return self.main_driver.get_size_of(data_id)
        except NotImplementedError:
            raise
        except BaseException:  # same as the previous bare ``except:``, written explicitly
            self.cache.flush()
            raise
