#!/usr/bin/python
# -*- coding: utf-8 -*-

# Copyright (C) 2013-2019:
# This file is part of Shinken Enterprise, all rights reserved.

import itertools
from collections import deque

from shinken.misc.type_hint import TYPE_CHECKING

if TYPE_CHECKING:
    from shinken.misc.type_hint import List, Dict, Tuple, Union, Iterator, Iterable, Optional, Type


class FilteringError(Exception):
    """Base class of the errors raised by this module while building a filter."""


class ParenthesisMismatchError(FilteringError):
    """Raised when the parentheses of a filter expression are not balanced."""


class QuoteMismatchError(FilteringError):
    """Raised when an opening quote of a filter expression is never closed."""


class MissingOperandError(FilteringError):
    """Raised when an operator of a filter expression lacks a required operand."""


try:
    # Python 2: byte strings and unicode strings are two distinct types
    _STRING_TYPES = (str, unicode)
except NameError:
    # Python 3: the `unicode` builtin is gone, `str` is the only text type
    _STRING_TYPES = (str,)


def filter_values(_filter, value):
    # type:(FilterNode, Union[List[str],Tuple[str], str]) -> bool
    """
    Check if value(s) match the filter.

    _filter: any object exposing match(text), which returns None when text does not match
    value: a single string, or an iterable of strings

    A single string is checked on its own (an empty string is still checked).
    For an iterable, every item must match the filter.
    Any other falsy value (None, empty list, empty tuple...) never matches.
    """
    if isinstance(value, _STRING_TYPES):
        # Wrap a lone string so it is not iterated character by character
        value = [value]
    elif not value:
        return False
    return all(_filter.match(v) is not None for v in value)


#
# An "extensible range" is a tuple of 3 ints: start, end_min and end_max.
# Matches are extensible ranges: a match spans from start to end_min and may be extended up to end_max (excluded).
# For performance reasons, this is implemented as a plain tuple rather than as a class.
#
#
# eg :
#           1111
# 01234567890123
# foobarbazquux
# "foo" match is (0, 3, 13)
# "bar" match is (3, 6, 13)
# "baz" match is (6, 9, 13)
#

# Indexes of the three fields of an extensible range tuple
RANGE_START, RANGE_END_MIN, RANGE_END_MAX = 0, 1, 2


def unify_range(match_range, *matches_range):
    # type: (Tuple[int, int, int], *Tuple[int, int, int]) -> Tuple[int, int, int]
    # Merge several extensible ranges into the smallest one covering them all:
    # lowest start, highest end_min and highest end_max.
    # Called with a single range, that very range is returned untouched.
    if not matches_range:
        return match_range
    every_range = (match_range,) + matches_range
    lowest_start = min(r[RANGE_START] for r in every_range)
    highest_end_min = max(r[RANGE_END_MIN] for r in every_range)
    highest_end_max = max(r[RANGE_END_MAX] for r in every_range)
    return lowest_start, highest_end_min, highest_end_max


class FilterNode(object):
    """
    Base class of the nodes of a filtering tree.

    parent: node owning this one in the tree
    pos: position in the filter text where this node is defined
    children: nodes used as operands by this node
    """

    def __init__(self, parent=None, pos=0):
        # type: (Optional["FilterNode"], int) -> None
        self.children = []  # type: List["FilterNode"]
        self.pos = pos
        self.parent = parent
        if parent:
            # Register on the parent, which also (re)sets self.parent
            parent.add_child(self)

    def __repr__(self):
        child_count = len(self.children)
        # Negative positions mark nodes which do not come from the filter text
        where = 'fictive' if self.pos < 0 else self.pos
        plural_suffix = 'ren' if child_count >= 2 else ''
        return '<%s at position %s with %s child%s>' % (type(self).__name__, where, child_count, plural_suffix)

    def add_child(self, node):
        # type: ("FilterNode") -> None
        # Append node to the children and take ownership of it
        self.children.append(node)
        node.parent = self

    def match(self, value):
        # type: (str) -> Union[Tuple[int, int, int], None]
        # Only the first match is needed: stop consuming matches as soon as one shows up
        for first_match in self.matches(value):
            return first_match
        return None

    def matches(self, value):
        # type: (str) -> Iterator[Tuple[int, int, int]]
        # Must be implemented by subclasses to produce every match for "value"
        raise NotImplementedError

    def optimize(self):
        # Optimize each child in place.
        # A child whose optimize() returns None is dropped, any other result replaces the child.
        dropped = []
        for index, child in enumerate(self.children):
            replacement = child.optimize()
            if replacement is None:
                dropped.append(child)
                continue
            replacement.parent = self
            self.children[index] = replacement

        # Removal is delayed so the indexes used above stay valid
        for child in dropped:
            child.parent = None
            if child in self.children:
                self.children.remove(child)

        return self

    def as_dict(self):
        # Recursive dictionary view of the tree rooted at this node
        children_as_dict = [child.as_dict() for child in self.children]
        return {'type': type(self).__name__, 'children': children_as_dict}


class OperatorNode(FilterNode):
    """
    Common base class of the operator nodes, used to tell them apart from other nodes
    """

    # Sides on which an operand can be attached to an operator
    LHS = -1
    RHS = 1

    # Precedence of the operator over the other ones: the higher, the sooner it is applied.
    # The real value of each operator is assigned at the end of the module.
    precedence = 0

    def __init__(self, name, parent=None, pos=0):
        # type: (str, Optional["FilterNode"], int) -> None
        super(OperatorNode, self).__init__(parent=parent, pos=pos)
        # Text of the operator, used in error messages
        self.name = name

    def add_child(self, node, side=None):
        # type: (FilterNode, Optional[int]) -> None
        # Attach node as the first (LHS) or last (RHS, the default) operand.
        # Raise FilteringError on any other side value.
        node.parent = self
        if side == self.LHS:
            self.children.insert(0, node)
        elif side is None or side == self.RHS:
            self.children.append(node)
        else:
            raise FilteringError('Invalid side specification')

    def matches(self, value):
        # type: (str) -> Iterator[Tuple[int, int, int]]
        raise NotImplementedError

    def handle_missing_operand(self, side):
        # type: (int) -> None
        # Default policy: an operator without operand is always an error
        message = 'Missing operand for "%s" at position %d' % (self.name, self.pos)
        raise MissingOperandError(message)


class UnaryOperatorNode(OperatorNode):
    """
    Operator holding at most one child
    """

    # Side of the operand: unary operators expect it on their right-hand side by default
    operand_side = OperatorNode.RHS  # type: int

    def matches(self, value):
        # type: (str) -> Iterator[Tuple[int, int, int]]
        raise NotImplementedError

    def add_child(self, node, side=None):
        # type: (FilterNode, Optional[int]) -> None
        # Set node as the only child, detaching the previous one if any.
        # Raise FilteringError if a side is given and is not the operand side.
        if not (side is None or side == self.operand_side):
            raise FilteringError('Invalid side specification')
        node.parent = self
        previous_children = self.children
        if previous_children:
            previous_children[0].parent = None
        self.children = [node]

    def optimize(self):
        # A unary operator left without operand is useless: report it with None
        super(UnaryOperatorNode, self).optimize()
        return self if self.children else None

    def handle_missing_operand(self, side):
        # type: (int) -> None
        # Only a missing operand on the operand side is an error
        if side != self.operand_side or self.children:
            return None
        # Top-level call will raise a MissingOperandError
        return super(UnaryOperatorNode, self).handle_missing_operand(side)


class AssociativeUnaryOperatorNode(UnaryOperatorNode):
    """
    Unary operator for which applying it twice in a row equals applying it once

    When the child has exactly the same type, the child of that child is adopted instead
    """

    def matches(self, value):
        # type: (str) -> Iterator[Tuple[int, int, int]]
        raise NotImplementedError

    def optimize(self):
        super(AssociativeUnaryOperatorNode, self).optimize()
        if not self.children:
            return None
        nested = self.children[0]
        if type(nested) is not type(self):
            return self
        # Same operator twice in a row: skip the intermediate node
        nested.parent = None
        promoted = nested.children[0]  # Cannot be empty
        self.children[0] = promoted
        promoted.parent = self
        return self


class BinaryOperatorNode(OperatorNode):
    """
    Special class used to identify binary operator nodes

    Caveat: Binary operators must have *AT LEAST* 2 children
    3+ children means successive operations and, therefore, is not an error
    """

    def matches(self, value):
        # type: (str) -> Iterator[Tuple[int, int, int]]
        raise NotImplementedError

    def handle_missing_operand(self, side):
        # type: (int) -> None
        # Raise MissingOperandError (through the parent class) while fewer than 2 operands are attached
        if len(self.children) < 2:
            return super(BinaryOperatorNode, self).handle_missing_operand(side)  # Top-level call will raise a MissingOperandError
        # Already have all its children, go away
        return None

    def optimize(self):
        # Optimize the children, then merge the operands of each child having
        # exactly the same type as this node into this node.
        super(BinaryOperatorNode, self).optimize()
        # Iterate over a copy as self.children is modified inside the loop
        for c in self.children[:]:
            if type(c) is type(self):
                # collapse successive similar operators
                c.parent = None
                try:
                    c_idx = self.children.index(c)
                except ValueError:  # list.index raises ValueError when the item is missing
                    continue
                for grand_child in c.children:
                    grand_child.parent = self
                self.children[c_idx:c_idx + 1] = c.children  # Preserve order
        return self


class TextNode(FilterNode):
    """Leaf node matching a substring, whatever the case"""

    def __init__(self, text, parent=None, pos=0):
        super(TextNode, self).__init__(parent, pos)
        # Stored lowercased: values are lowercased too before being searched
        self.text = text.lower()

    def __repr__(self):
        return '<%s %r at position %s>' % (type(self).__name__, self.text, self.pos)

    def add_child(self, node):
        # type: (FilterNode) -> None
        raise FilteringError(u'TextNode cannot have children')

    def matches(self, value):
        # type: (str) -> Iterator[Tuple[int, int, int]]
        # Yield one range per occurrence of the text in value, overlapping occurrences included
        needle = self.text
        if needle == "":
            # special case: an empty text yields a single empty range
            yield 0, 0, 0
            return

        haystack = value.lower()  # Case-insensitive match
        needle_length = len(needle)
        haystack_length = len(haystack)
        found_at = haystack.find(needle)
        while found_at != -1:
            yield found_at, found_at + needle_length, haystack_length
            # Resume one character further to catch overlapping occurrences
            found_at = haystack.find(needle, found_at + 1)

    def as_dict(self):
        as_dict = super(TextNode, self).as_dict()
        as_dict['content'] = self.text
        return as_dict


class NotNode(UnaryOperatorNode):
    """Negation node: matches the whole value when its operand does not match"""

    def matches(self, value):
        # type: (str) -> Iterator[Tuple[int, int, int]]
        first_hit = next(self.children[0].matches(value), None)
        if first_hit is not None:
            # The operand matches, so the negation does not
            return
        value_length = len(value)
        yield 0, value_length, value_length


class AndNode(BinaryOperatorNode):
    """Intersection node: matches when every operand matches"""

    def matches(self, value):
        # type: (str) -> Iterator[Tuple[int, int, int]]
        # Yield a single range covering the first match of every child
        if not self.children:
            return
        first_hits = []  # type: List[Tuple[int, int, int]]
        for child in self.children:
            hit = next(child.matches(value), None)
            if not hit:
                # One failing child is enough to fail the whole intersection
                return
            first_hits.append(hit)
        yield unify_range(*first_hits)


class OrNode(BinaryOperatorNode):
    """Union node: matches when any operand matches"""

    def matches(self, value):
        # type: (str) -> Iterator[Tuple[int, int, int]]
        # Yield the matches of every child, one child after the other.
        # Children are queried lazily, only when their matches are requested.
        per_child_matches = (child.matches(value) for child in self.children)
        for found in itertools.chain.from_iterable(per_child_matches):
            yield found


class StarNode(BinaryOperatorNode):
    """Catch-all node: operands must be found in order, whatever lies between them"""

    def matches(self, value):
        # type: (str) -> Iterator[Tuple[int, int, int]]
        value_length = len(value)
        if not self.children:
            # Ultimate joker, all match :)
            yield 0, value_length, value_length
            return

        found_by_child = {}  # type: Dict[FilterNode, Tuple[int, int, int]]
        cursor = 0  # Where the search of the next child begins
        for child in self.children:
            offset = cursor
            while offset < value_length:
                # eat up characters until a match is found
                hit = next(child.matches(value[offset:]), None)
                if hit is not None:
                    # Positions are relative to the slice: shift them back to the full value
                    shifted = hit[RANGE_START] + offset, hit[RANGE_END_MIN] + offset, hit[RANGE_END_MAX] + offset
                    found_by_child[child] = shifted
                    cursor = shifted[RANGE_END_MIN]
                    break
                offset += 1

        if len(found_by_child) != len(self.children):
            # At least one child was not found
            return
        for child in self.children:
            yield found_by_child[child]

    def handle_missing_operand(self, side):
        # type: (int) -> None
        # VERY SPECIAL CASE: Both sides of a StarNode are optional !
        # It is a binary operator, but it can be alone :)
        return None

    def optimize(self):
        # A star with a single operand is replaced by that operand
        super(StarNode, self).optimize()
        return self.children[0] if len(self.children) == 1 else self


class StartsWithNode(AssociativeUnaryOperatorNode):
    """Value starts with expression"""

    def matches(self, value):
        # type: (str) -> Iterator[Tuple[int, int, int]]
        # Keep only the matches of the child expression beginning at the very start of value
        operand = self.children[0]
        for candidate in operand.matches(value):
            if candidate[RANGE_START] != 0:
                continue
            yield candidate


class EndsWithNode(AssociativeUnaryOperatorNode):
    """Value ends with expression"""

    # Special case: Operand in the left-hand side
    operand_side = OperatorNode.LHS

    def matches(self, value):
        # type: (str) -> Iterator[Tuple[int, int, int]]
        # Keep only the matches of the child expression finishing at the very end of value
        operand = self.children[0]
        for candidate in operand.matches(value):
            if candidate[RANGE_END_MIN] == len(value):
                upper_bound = candidate[RANGE_END_MAX]
                yield candidate[RANGE_START], upper_bound, upper_bound


class EqualNode(AssociativeUnaryOperatorNode):
    """Value strictly equal with expression"""

    def matches(self, value):
        # type: (str) -> Iterator[Tuple[int, int, int]]
        # Keep only the matches of the child expression covering value from its start to its end
        operand = self.children[0]
        for candidate in operand.matches(value):
            if candidate[RANGE_START] != 0:
                continue
            if candidate[RANGE_END_MIN] != len(value):
                continue
            yield candidate


class ParenthesisNode(FilterNode):
    """
    Group of nodes, as written between parentheses in a filter expression

    Children are matched in sequence. Before evaluation, reorder() turns the flat list
    of tokens into a tree honoring operator sides and precedences.
    """

    # operand side specification
    LHS = OperatorNode.LHS
    RHS = OperatorNode.RHS

    def __init__(self, parent=None, pos=0, children=None):
        # type: (Optional["FilterNode"], int, Optional[Iterable[FilterNode]]) -> None
        super(ParenthesisNode, self).__init__(parent, pos)
        if children:
            # Copy the given nodes and take ownership of them
            self.children[:] = children
            for child in self.children:
                child.parent = self

    def match_next(self, children, value, offset):
        # type: (List[FilterNode], str, int) -> Iterator[Tuple[int, int, int]]
        # Yield the matches of "children" taken in sequence, the first one being searched in value[offset:].
        # The given list is left untouched: each backtracking branch must see the same remaining
        # children, so the list is split into head and tail instead of being consumed.
        if not children:
            return
        c = children[0]
        remaining = children[1:]
        for m in c.matches(value[offset:]):
            if not remaining:
                # Last child of the sequence: its match is the result
                yield m
                continue
            # Try every position allowed by the extensible range of m for the rest of the sequence
            for i in range(m[RANGE_END_MIN], m[RANGE_END_MAX] + 1):
                # get remaining matches starting at offset + i
                # NOTE(review): the yielded start is m[RANGE_START] + offset + i, kept as in the
                # original implementation - confirm this is the intended start of the combined range
                for s in self.match_next(remaining, value, offset + i):
                    yield m[RANGE_START] + offset + i, s[RANGE_END_MIN] + offset + i, s[RANGE_END_MAX] + offset + i

    def matches(self, value):
        # type: (str) -> Iterator[Tuple[int, int, int]]
        # Check children in sequence and yield every range found in value.
        # An empty parenthesis matches the whole value.
        if not self.children:
            yield 0, len(value), len(value)
        else:
            for m in self.match_next(self.children[:], value, 0):
                yield m

    def optimize(self):
        # A parenthesis around a single node is replaced by that node
        super(ParenthesisNode, self).optimize()
        if len(self.children) == 1:
            return self.children[0]
        return self

    def reorder(self):
        # Reorder children according to types and operator precedences.
        # May raise MissingOperandError through the handle_missing_operand() of the operators.

        actual_op = None  # type: Optional[OperatorNode]

        # Container for leaf operands or operators which will be used as operand
        left_operands_queue_next_op = deque()  # type: deque[FilterNode]
        right_operands_queue_actual_op = deque()  # type: deque[FilterNode]

        # We recreate a new graph so there is no need to keep them as children
        all_nodes = self.children
        self.children = []

        for node in all_nodes:
            node.parent = None  # Officially detach node from parenthesis as it will be handled to have a new parent
            if not isinstance(node, OperatorNode):
                # Just pick the node
                if isinstance(node, ParenthesisNode) and node is not self:  # Ensure recursive call
                    node.reorder()
                right_operands_queue_actual_op.append(node)
                continue

            actual_op = self._handle_operator(actual_op, node, left_operands_queue_next_op, right_operands_queue_actual_op)

        # Do it a last time for incomplete loop
        self._handle_operator(actual_op, None, left_operands_queue_next_op, right_operands_queue_actual_op)

        # All remaining nodes on both queues return to parenthesis
        self.children[:] = itertools.chain(left_operands_queue_next_op, right_operands_queue_actual_op)
        for child in self.children:
            child.parent = self

    @staticmethod
    def _operator_accepts_operand_on_side(op, side):
        # type: (OperatorNode, int) -> bool
        # Unary operators accept an operand on their operand side only
        if isinstance(op, UnaryOperatorNode):
            return op.operand_side == side
        # Other operators are considered as binary operators
        return True

    @classmethod
    def _handle_operator(cls, actual_op, next_op, left_operands_queue_next_op, right_operands_queue_actual_op):
        # type: (Optional[OperatorNode], Optional[OperatorNode], deque[FilterNode], deque[FilterNode]) -> Optional[OperatorNode]
        # Close the expression of actual_op (giving it its right-hand operand) and open the one
        # of next_op (giving it its left-hand operand). Both queues are updated in place.
        # Return the operator to use as actual_op on the next call.
        if actual_op is None and next_op is None:
            return None

        # Put needed functions to local scope
        operator_accepts_operand_on_side = cls._operator_accepts_operand_on_side

        if actual_op is not None:
            # Caveat: Unary node with operand to the left-hand side see its expression finished here
            if operator_accepts_operand_on_side(actual_op, cls.RHS):  # Does operator accept operand on the right-hand side ?
                if next_op is not None and next_op.precedence > actual_op.precedence:
                    # Skip it, will be handled later
                    right_operands_queue_actual_op.append(next_op)
                    return actual_op
                if not right_operands_queue_actual_op:
                    # Successive operators or end of list
                    if isinstance(actual_op, UnaryOperatorNode) and (next_op is not None and actual_op.precedence == next_op.precedence):
                        actual_op.add_child(next_op, side=cls.RHS)
                    else:
                        actual_op.handle_missing_operand(cls.RHS)
                        # No raise ? Nice !
                elif isinstance(right_operands_queue_actual_op[0], ParenthesisNode) and not any(isinstance(n, OperatorNode) for n in right_operands_queue_actual_op):
                    # Only the next parenthesis node is the operand
                    actual_op.add_child(right_operands_queue_actual_op.popleft(), side=cls.RHS)
                else:
                    # All the grabbed nodes form a single operand
                    operand = ParenthesisNode(children=right_operands_queue_actual_op, pos=-1)
                    actual_op.add_child(operand, side=cls.RHS)
                    right_operands_queue_actual_op.clear()
                    operand.reorder()  # Force reordering if operators with higher precedence were taken
            if not actual_op.parent:  # If this operator is not already in the new graph
                left_operands_queue_next_op.append(actual_op)

        # All unused right-hand operands switch to left-hand side
        if right_operands_queue_actual_op:
            left_operands_queue_next_op.extend(right_operands_queue_actual_op)
            right_operands_queue_actual_op.clear()

        # grab all nodes as operator's left operand
        if next_op is not None and operator_accepts_operand_on_side(next_op, cls.LHS):  # Does operator accept operand on the left-hand side ?
            if not left_operands_queue_next_op:
                next_op.handle_missing_operand(cls.LHS)
                # No raise ? Nice again !
            elif isinstance(left_operands_queue_next_op[-1], ParenthesisNode):
                next_op.add_child(left_operands_queue_next_op.pop(), side=cls.LHS)
            else:
                next_op.add_child(ParenthesisNode(children=left_operands_queue_next_op, pos=-1), side=cls.LHS)
                left_operands_queue_next_op.clear()
        return next_op


# BNF filter grammar (sort-of)
# expression := "(" expression ")" | ["] string ["] | lhs_unary_operator expression | expression binary_operator expression | expression rhs_unary_operator
# lhs_unary_operator := "!" | ">" | "="
# rhs_unary_operator := "<"
# binary_operator := "|" | "&" | "*"
# operators precedence (highest first) : *, =, >, !, <, &, |

# This precedence means for example: ( >"d" | !a & b ) means ( ( > ("d") ) | ( (!a) & (b) ) )
# Caveat 1: Unary operators MUST have higher precedence than binary operators (except for the star operator, which is... special)
# Caveat 2: In our case, all the unary operators have the same precedence

# Operators MUST be declared IN ORDER according to the precedence declared above (highest precedence first)
# The numeric precedence of each operator class is computed from its rank in this list
# Operators with equal precedence must be in the same inner list
# Do not forget to declare an operator here, else its precedence will be the default one (0)
OPERATOR_PRECEDENCE = [
    [StarNode],
    [EqualNode, StartsWithNode, EndsWithNode, NotNode],
    [AndNode],
    [OrNode],
]  # type: List[List[Type[OperatorNode]]]

# Single-character operators recognized by the tokenizer, with the node class created for each of them
SIMPLE_OPERATORS = {
    '=': EqualNode,
    '*': StarNode,
    '>': StartsWithNode,
    '<': EndsWithNode,
    '!': NotNode,
    '&': AndNode,
    '|': OrNode,
}  # type: Dict[str, Type[OperatorNode]]


def _assign_precedence_to_operators():
    # Give each operator class of OPERATOR_PRECEDENCE its numeric precedence.
    # The list is declared from highest to lowest precedence, so it is walked backwards:
    # the last group gets 1 and every earlier group gets one more than the group after it.
    rank = 0
    for same_rank_operators in OPERATOR_PRECEDENCE[::-1]:
        rank += 1
        for operator_cls in same_rank_operators:
            operator_cls.precedence = rank


_assign_precedence_to_operators()


class TokenizerState:
    # State of the lexer while it walks through a filter expression:
    # - current_node: the parenthesis node receiving the tokens being read
    # - content: text read so far for the pending text token (None if there is none)
    # - trailing_spaces: spaces read after content, kept only if more text follows

    def __init__(self):
        self._reset()
        # Fictive root parenthesis holding the whole expression
        self.current_node = ParenthesisNode(pos=-1)

    def append_to_content(self, chars):
        # Add chars to the pending text, keeping the spaces met since the previous chars
        if self.content is not None:
            chars = self.content + self.trailing_spaces + chars
        self.content = chars
        self.trailing_spaces = ''

    def stack_parenthesis(self, pos):
        # Open a new parenthesis inside the current one and make it the current node
        self.flush_expression(pos)
        self.current_node = ParenthesisNode(self.current_node, pos)

    def unstack_parenthesis(self, pos):
        # Close the current parenthesis and go back to the one enclosing it
        self.flush_expression(pos)
        enclosing_node = self.current_node.parent
        if enclosing_node is None:
            raise ParenthesisMismatchError('invalid closing parenthesis at position %s' % pos)
        self.current_node = enclosing_node

    def add_child(self, child, pos):
        # Flush the pending text, then add child to the current node
        self.flush_expression(pos)
        self.current_node.add_child(child)
        child.pos = pos

    def add_trailing_space(self, char):
        self.trailing_spaces = self.trailing_spaces + char

    def _reset(self):
        # Forget the pending text and spaces
        self.trailing_spaces = ''
        self.content = None

    def flush_expression(self, pos):
        # flush text content if any.
        # note that trailing whitespaces are discarded.
        pending_text = self.content
        if pending_text is not None:
            # Create a text node and link it as a child of current_node
            TextNode(pending_text, self.current_node, pos - len(pending_text))

        self._reset()


def tokenize_filter_expression(expression):
    # type: (str) -> ParenthesisNode
    # Filter expression lexer. Return the root of the tokenized (not yet reordered) filtering tree.
    # Raise QuoteMismatchError on an unclosed quote and ParenthesisMismatchError on unbalanced parentheses.

    state = TokenizerState()

    # Characters ending a bare (unquoted) word
    stop_chars = '"()' + ''.join(SIMPLE_OPERATORS)
    expression_length = len(expression)
    pos = 0

    while pos < expression_length:
        current_char = expression[pos]

        if current_char == '"':
            # quoted text: everything up to the next quote is taken as is
            closing_quote = expression.find('"', pos + 1)
            if closing_quote == -1:
                raise QuoteMismatchError('missing closing quotes at position %d' % pos)
            state.trailing_spaces = ''
            state.append_to_content(expression[pos + 1:closing_quote])
            pos = closing_quote + 1
            continue

        if current_char == '(':
            state.stack_parenthesis(pos)
        elif current_char == ')':
            state.unstack_parenthesis(pos)
        elif current_char in SIMPLE_OPERATORS:
            # any other operator simply flush content and create a new node
            operator_cls = SIMPLE_OPERATORS[current_char]
            state.add_child(operator_cls(current_char), pos)
        elif current_char.isspace():
            # maybe this space will be kept.
            state.add_trailing_space(current_char)
        else:
            # bare word: runs until a space or a special character
            word_end = pos + 1
            while word_end < expression_length:
                following_char = expression[word_end]
                if following_char in stop_chars or following_char.isspace():
                    break
                word_end += 1
            state.append_to_content(expression[pos:word_end])
            pos = word_end
            continue

        # Every remaining case consumed exactly one character
        pos += 1

    # flush pending characters
    state.flush_expression(expression_length)
    if state.current_node.parent is not None:
        # raise an exception if current_node is not root
        raise ParenthesisMismatchError('missing closing parenthesis')
    return state.current_node


def parse_filter(expression):
    # type: (str) -> FilterNode
    # Build the filtering tree of expression, ready for evaluation.

    # Lexing comes first and gives a flat tree of tokens
    root = tokenize_filter_expression(expression)

    # Then operators get their operands and the tree is simplified
    root.reorder()
    optimized = root.optimize()

    # A useless filter is replaced by an empty text node
    return TextNode('') if optimized is None else optimized
