#!/usr/bin/python
# -*- coding: utf-8 -*-

# Copyright (C) 2013-2019:
# This file is part of Shinken Enterprise, all rights reserved.

from shinken.misc.type_hint import List, Tuple, Union, NoReturn, Type


class FilteringError(Exception):
    """Base class of the filter errors defined in this file."""


class ParenthesisMismatchError(FilteringError):
    """Raised when an opening or a closing parenthesis has no counterpart in a filter expression."""


class QuoteMismatchError(FilteringError):
    """Raised when a double quote opened in a filter expression is never closed."""


class MissingOperandError(FilteringError):
    """Raised when an operator of a filter expression lacks a mandatory operand."""


def filter_values(_filter, value):
    # type:(FilterNode, Union[List[str],Tuple[str], str]) -> bool
    # Check if value(s) matches the filter.
    #
    # _filter: object exposing a 'match(text)' method returning None when 'text' does not match.
    #          A None filter accepts everything.
    # value: a single string, or a list/tuple of strings. Values are lower-cased before matching.
    #
    # Return True if the value matches. For a list or tuple, all values must match the filter
    # (so an empty list or tuple matches).
    if _filter is None:
        return True
    # isinstance (rather than an exact type check) so that list/tuple subclasses, e.g. namedtuples,
    # are handled as collections instead of failing on '.lower()'.
    if isinstance(value, (list, tuple)):
        return all(_filter.match(v.lower()) is not None for v in value)
    return _filter.match(value.lower()) is not None


#
# An "extensible range" is a tuple of 3 ints: (start, end_min, end_max).
# A match is an extensible range: it starts at start, covers at least the characters up to end_min (excluded)
# and may be extended up to end_max (excluded).
# For performance reasons, this is implemented as a plain tuple rather than as a class.
#
#
# eg :
#           1111
# 01234567890123
# foobarbazquux
# "foo" match is (0, 3, 13)
# "bar" match is (3, 6, 13)
# "baz" match is (6, 9, 13)
#

# Indexes of the fields of an "extensible range" tuple: (start, end_min, end_max)
RANGE_START, RANGE_END_MIN, RANGE_END_MAX = 0, 1, 2


def reduce_ranges(ranges):
    # type: (List[Tuple[int, int, int]]) -> List[Tuple[int, int, int]]
    # Remove ranges that are already covered by another range.
    #
    # Two ranges sharing the same start and the same end_max are redundant: only the one with the
    # smallest end_min is kept. The input list is left untouched; a new list is returned.
    #
    # This is written as a loop (and not as a recursion on the remaining ranges) so that the number
    # of ranges is not bounded by the interpreter recursion limit.
    reduced = []
    pending = list(ranges)
    while pending:
        # candidate range, compared against every other pending range (scanned from the end)
        first = pending.pop()
        survivors = []
        superseded = False
        while pending:
            other = pending.pop()
            same_span = (first[RANGE_START] == other[RANGE_START]
                         and first[RANGE_END_MAX] == other[RANGE_END_MAX])
            if not same_span:
                # unrelated range: keep it for the next pass
                survivors.append(other)
            elif first[RANGE_END_MIN] > other[RANGE_END_MIN]:
                # 'other' makes 'first' redundant: drop 'first' and restart the scan
                # with 'other' and all the ranges not examined yet
                survivors.append(other)
                survivors.extend(pending)
                superseded = True
                break
            # else: 'first' makes 'other' redundant, 'other' is simply dropped
        if not superseded:
            reduced.append(first)
        pending = survivors
    return reduced


class FilterNode(object):
    """
    Base class of the nodes of a filtering tree.

    parent: Parent node in filtering tree (None for a root node)
    pos: position in filter text where the definition of this node occurs
    children: operands for this node
    """
    
    
    def __init__(self, parent=None, pos=0):
        self.parent = parent
        self.pos = pos
        self.children = []
        # a node built with a parent registers itself as an operand of that parent
        if parent:
            parent.add_child(self)
        self.cached_results = {}
    
    
    def add_child(self, node):
        # Register 'node' as the last operand of this node
        self.children.append(node)
    
    
    def match(self, value):
        # type: (str) -> Union[Tuple[int, int, int], None]
        # Return the first match of 'value', or None when there is no match at all
        return next(iter(self.matches(value)), None)
    
    
    def matches(self, value):
        # type: (str) -> List[Tuple[int, int, int]]
        # Return all matches for "value". Must be provided by subclasses.
        raise NotImplementedError
    
    
    def optimize(self):
        # Optimize every child, replacing it with the node returned by its own optimize().
        # Children whose optimize() returns None are removed. Return this node.
        kept = []
        for child in self.children:
            optimized = child.optimize()
            if optimized is not None:
                kept.append(optimized)
        # update in place so the children list object stays the same
        self.children[:] = kept
        return self
    
    
    def as_dict(self):
        # Describe this node and, recursively, its children as nested dicts
        children = [child.as_dict() for child in self.children]
        return {'type': type(self).__name__, 'children': children}


class TextNode(FilterNode):
    """A substring match node. The searched text is stored lower-cased."""
    
    
    def __init__(self, text, parent=None, pos=0):
        super(TextNode, self).__init__(parent, pos)
        self.text = text.lower()
    
    
    def matches(self, value):
        # type: (str) -> List[Tuple[int, int, int]]
        # Yield one range per occurrence of the text in 'value' (overlapping occurrences included).
        # Each range starts at the occurrence, ends at least after it and extends up to the end of 'value'.
        needle = self.text
        
        # special case: an empty text yields the single range (0, 0, 0)
        if needle == "":
            yield 0, 0, 0
            return
        
        needle_len = len(needle)
        value_len = len(value)
        found = value.find(needle)
        while found != -1:
            yield found, found + needle_len, value_len
            # resume one character further, so overlapping occurrences are reported too
            found = value.find(needle, found + 1)
    
    
    def as_dict(self):
        # Same description as any node, plus the searched text under 'content'
        result = super(TextNode, self).as_dict()
        result['content'] = self.text
        return result


class NotNode(FilterNode):
    """Negation node: matches the whole value when its operand does not match"""
    
    
    def add_child(self, node):
        # A negation has exactly one operand: a new child replaces the previous one.
        self.children = [node]
    
    
    def matches(self, value):
        # type: (str) -> List[Tuple[int, int, int]]
        # Yield a single range covering all of 'value' if the operand has no match, nothing otherwise
        operand = self.children[0]
        if next(operand.matches(value), None) is not None:
            # the operand matches: the negation does not
            return
        size = len(value)
        yield 0, size, size


class AndNode(FilterNode):
    """Intersection node: matches only if every child matches"""
    
    
    @staticmethod
    def unions(ranges):
        # type: (List[List[Tuple[int, int, int]]]) -> List[Tuple[int, int, int]]
        # Compute unions of ranges.
        # 'ranges' holds one list of ranges per operand. The result holds, for every combination of one
        # range per operand, the range spanning them all (smallest start, greatest end_min and end_max).
        if not ranges:
            return []
        if len(ranges) == 1:
            return ranges[0]
        # The unions of the remaining operands do not depend on the range picked for the first one:
        # compute them once instead of once per range of the first operand.
        inner_ranges = AndNode.unions(ranges[1:])
        res = []
        for r in ranges[0]:
            for inner_range in inner_ranges:
                _s = min(r[RANGE_START], inner_range[RANGE_START])
                _e_min = max(r[RANGE_END_MIN], inner_range[RANGE_END_MIN])
                _e_max = max(r[RANGE_END_MAX], inner_range[RANGE_END_MAX])
                if _s <= _e_min:
                    res.append((_s, _e_min, _e_max))
        return res
    
    
    def matches(self, value):
        # type: (str) -> List[Tuple[int, int, int]]
        # Match if all children match: yield the reduced unions of the matches of every child.
        per_child_matches = []
        for c in self.children:
            child_matches = list(c.matches(value))
            if not child_matches:
                # a child doesnt match. bail out early, without evaluating the remaining children.
                return
            per_child_matches.append(child_matches)
        
        for r in reduce_ranges(AndNode.unions(per_child_matches)):
            yield r
    
    
    def optimize(self):
        # Optimize children, then adopt the operands of direct 'And' children. Return this node.
        super(AndNode, self).optimize()
        # iterate on a copy as the children list is modified in the loop
        for c in self.children[:]:
            if type(c) is AndNode:
                # collapse successive 'And's
                self.children.extend(c.children)
                self.children.remove(c)
        return self


class OrNode(FilterNode):
    """Union node: matches as soon as one child matches"""
    
    
    def matches(self, value):
        # type: (str) -> List[Tuple[int, int, int]]
        # Yield the matches of every child, child after child
        all_matches = (found for child in self.children for found in child.matches(value))
        for found in all_matches:
            yield found
    
    
    def optimize(self):
        # Optimize children, then adopt the operands of direct 'Or' children. Return this node.
        super(OrNode, self).optimize()
        # collapse successive 'Or's: other children keep their order, adopted operands come last
        direct = [child for child in self.children if type(child) != OrNode]
        adopted = [grandchild
                   for child in self.children if type(child) == OrNode
                   for grandchild in child.children]
        self.children[:] = direct + adopted
        return self


class StarNode(FilterNode):
    """Catch-all node: looks for its operand in every suffix of the value"""
    
    
    def matches(self, value):
        # type: (str) -> List[Tuple[int, int, int]]
        # Yield the matches of the operand in each suffix of 'value', expressed as positions in 'value'
        operand = self.children[0]
        for skipped in range(len(value)):
            # eat up 'skipped' characters, then try to match what is left
            suffix = value[skipped:]
            for found in operand.matches(suffix):
                start = found[RANGE_START] + skipped
                end_min = found[RANGE_END_MIN] + skipped
                end_max = found[RANGE_END_MAX] + skipped
                yield start, end_min, end_max
    
    
    def optimize(self):
        # Optimize children. A star left without operand is useless: return None in that case.
        super(StarNode, self).optimize()
        if not self.children:
            return None
        return self


class StartsWithNode(FilterNode):
    """Value starts with expression"""
    
    
    def matches(self, value):
        # type: (str) -> List[Tuple[int, int, int]]
        # Yield the matches of the child expression that begin at the first character of 'value'
        operand = self.children[0]
        for found in operand.matches(value):
            if found[RANGE_START] != 0:
                continue
            yield found


class EndsWithNode(FilterNode):
    """Value ends with expression"""
    
    
    def matches(self, value):
        # type: (str) -> List[Tuple[int, int, int]]
        # Yield the matches of the child expression whose minimal end is the end of 'value'.
        # Yielded ranges are no longer extensible: their end_min is set to their end_max.
        operand = self.children[0]
        value_len = len(value)
        for found in operand.matches(value):
            if found[RANGE_END_MIN] != value_len:
                continue
            end_max = found[RANGE_END_MAX]
            yield found[RANGE_START], end_max, end_max


class EqualNode(FilterNode):
    """Value strictly equal with expression"""
    
    
    def matches(self, value):
        # type: (str) -> List[Tuple[int, int, int]]
        # Yield the matches of the child expression that begin at the first character of 'value'
        # and whose minimal end is the end of 'value'
        operand = self.children[0]
        value_len = len(value)
        for found in operand.matches(value):
            if found[RANGE_START] != 0:
                continue
            if found[RANGE_END_MIN] != value_len:
                continue
            yield found


class ParenthesisNode(FilterNode):
    """
    Group node: its children are matched in sequence, one after the other.
    Also holds the logic turning the flat list of nodes produced by the tokenizer into a tree (see reorder).
    """
    
    # operand side specification
    LHS = -1
    RHS = 1
    ORHS = 3  # Optional right-hand side, for '*'
    
    # unary operators have equal precedence
    # node type -> (operator symbol, side of its operand)
    unary_ops = {
        EndsWithNode  : ('<', LHS),
        StartsWithNode: ('>', RHS),
        StarNode      : ('*', ORHS),
        EqualNode     : ('=', RHS),
        NotNode       : ('!', RHS)
    }
    
    
    def match_next(self, children, value, offset):
        # type: (List[FilterNode], str, int) -> List[Tuple[int, int, int]]
        # Yield the matches of the sequence 'children' in value[offset:].
        # Yielded positions are relative to value[offset:]. 'children' is not modified.
        if not children:
            return
        head = children[0]
        # The remaining children are taken as a new list on purpose: the same remaining sequence must be
        # tried again for every match of 'head' and every resume position, so it must never be consumed.
        tail = children[1:]
        for m in head.matches(value[offset:]):
            if not tail:
                yield m
                continue
            # the rest of the sequence may resume anywhere in the extensible part of the match
            for i in range(m[RANGE_END_MIN], m[RANGE_END_MAX] + 1):
                for s in self.match_next(tail, value, offset + i):
                    # 's' is relative to value[offset + i:]: shift its ends by 'i' to make them relative
                    # to value[offset:]. The sequence starts where the match of its first child starts.
                    yield m[RANGE_START], s[RANGE_END_MIN] + i, s[RANGE_END_MAX] + i
    
    
    def matches(self, value):
        # type: (str) -> List[Tuple[int, int, int]]
        # Check children in sequence.
        # Yield the ranges of 'value' matched by the whole sequence of children, nothing if there is none.
        # An empty group matches the whole value.
        if not self.children:
            yield 0, len(value), len(value)
        else:
            for m in self.match_next(self.children, value, 0):
                yield m
    
    
    def _iterate_on_type(self, not_in, *types):
        # type: (Union[List[FilterNode], None], *Type[FilterNode]) -> List[FilterNode]
        # Return the children whose exact type is one of 'types', except those listed in 'not_in'
        if not_in is None:
            not_in = []
        return [_op for _op in self.children if type(_op) in types and _op not in not_in]
    
    
    def optimize(self):
        # Optimize children. A group holding a single child is replaced by that child.
        super(ParenthesisNode, self).optimize()
        if len(self.children) == 1:
            return self.children[0]
        return self
    
    
    def _pop_operand(self, op, name, side, delete):
        # Pop operand from children list, either right or left-hand side.
        # Leave operand in list if 'delete" is false.
        #
        # op: operator node, which must be one of the children
        # name: operator symbol, used in error messages
        # side: LHS, RHS or ORHS
        #
        # Return the operand, or None if an optional operand is missing.
        # Raise MissingOperandError if a mandatory operand is missing,
        # FilteringError if 'side' is not a valid side specification.
        
        if side == self.ORHS:
            side = self.RHS
            _optional = True
        else:
            _optional = False
        
        if side not in [self.LHS, self.RHS]:
            raise FilteringError('Invalid operand side specification')
        
        try:
            idx = self.children.index(op) + side
            if idx < 0:
                # a negative index would silently wrap around to the end of the list
                raise IndexError
            if delete:
                result = self.children.pop(idx)
            else:
                result = self.children[idx]
        except IndexError:
            if _optional:
                return None
            raise MissingOperandError('Missing operand for "%s" at position %d' % (name, self.pos))
        return result
    
    
    def _parse_unary(self, op):
        # type: (FilterNode) -> NoReturn
        # Parse an unary operator, moving its operand to its children.
        # Raise FilteringError if the operand is a binary operator.
        
        op_name, op_side = self.unary_ops[type(op)]
        
        # get operand
        operand = self._pop_operand(op, op_name, op_side, delete=False)
        
        if type(operand) in (AndNode, OrNode):
            raise FilteringError('Invalid "|" or "&" after "%s"' % op_name)
        
        if type(operand) in self.unary_ops and op_side != self.LHS:
            # handle successive unary operators. LHS operators have already been handled.
            self._parse_unary(operand)
        
        # operand is None when the optional operand of '*' is missing
        if operand is not None:
            op.add_child(operand)
            self.children.remove(operand)
    
    
    def _parse_binary(self, node_type, name):
        # Move the left-hand and right-hand side neighbours of every 'node_type' operator to its children
        for op in self._iterate_on_type(None, node_type):
            # pop left-hand side and right-hand side operands
            # (right-hand side first, so the index of the left-hand side is not shifted)
            rhs = self._pop_operand(op, name, self.RHS, delete=True)
            lhs = self._pop_operand(op, name, self.LHS, delete=True)
            op.add_child(lhs)
            op.add_child(rhs)
    
    
    def reorder(self):
        # Reorder children according to types and operator precedences
        
        # Reorder sub-parenthesis first
        for op in self._iterate_on_type(None, ParenthesisNode):
            op.reorder()
        
        # Then unary operators
        op_done = []
        while True:
            nodes = self._iterate_on_type(op_done, *self.unary_ops.keys())
            if not nodes:
                break
            op = nodes[0]
            self._parse_unary(op)
            op_done.append(op)
        
        # And Node
        self._parse_binary(AndNode, '&')
        
        # Or Node
        self._parse_binary(OrNode, '|')


# BNF filter grammar (sort-of)
# expression := "(" expression ")" | ["] string ["] | unary_operator expression | expression binary_operator expression | expression "<"
# unary_operator := "!" | ">" | "=" | "*"
# binary_operator := "|" | "&"
# operators priority : ", (, then unary operators (=, >, <, !, * have equal precedence), then &, then |


# Single-character operators recognized by parse_filter, mapped to the type of node created for each of them
simple_operators = {
    '>': StartsWithNode,
    '<': EndsWithNode,
    '=': EqualNode,
    '|': OrNode,
    '&': AndNode,
    '*': StarNode,
    '!': NotNode
}


class TokenizerState:
    # State of the tokenizer while a filter expression is being parsed:
    #   content: text accumulated for the text node being read, None if there is none
    #   trailing_spaces: whitespaces read after 'content', kept only if more text follows
    #   current_node: parenthesis node receiving the nodes being created
    
    def __init__(self):
        self._reset()
        self.current_node = ParenthesisNode()
    
    
    def append_to_content(self, chars):
        # Append 'chars' to the pending text. Pending whitespaces are kept between the previous text and 'chars'.
        if self.content is None:
            self.content = chars
        else:
            self.content += self.trailing_spaces + chars
        self.trailing_spaces = ''
    
    
    def stack_parenthesis(self, pos):
        # Open a new parenthesis group, child of the current one, and make it the current node
        self.flush_expression(pos)
        # record where the group is opened, so that errors raised for this group report its position
        new_node = ParenthesisNode(self.current_node, pos)
        self.current_node = new_node
    
    
    def unstack_parenthesis(self, pos):
        # Close the current parenthesis group and get back to its parent.
        # Raise ParenthesisMismatchError if there is no group to close.
        self.flush_expression(pos)
        if self.current_node.parent is None:
            raise ParenthesisMismatchError('invalid closing parenthesis at position %s' % pos)
        self.current_node = self.current_node.parent
    
    
    def add_child(self, child, pos):
        # Add 'child' to the current node, after the pending text if any
        self.flush_expression(pos)
        self.current_node.add_child(child)
    
    
    def add_trailing_space(self, char):
        # Remember a whitespace that will be kept only if more text follows
        self.trailing_spaces += char
    
    
    def _reset(self):
        # Forget pending text and whitespaces
        self.content = None
        self.trailing_spaces = ''
    
    
    def flush_expression(self, pos):
        # flush text content if any.
        # note that trailing whitespaces are discarded.
        if self.content is not None:
            # Create a text node and link it as a child of current_node
            TextNode(self.content, self.current_node, pos)
        
        self._reset()


def parse_filter(expression):
    # type: (str) -> FilterNode
    # Parse filter expression. Return filtering tree.
    # Raise QuoteMismatchError on an unclosed quote, ParenthesisMismatchError on unbalanced parenthesis.
    
    state = TokenizerState()
    
    # characters ending a run of plain text
    stop_chars = '"() \t' + ''.join(simple_operators.keys())
    length = len(expression)
    pos = 0
    
    while pos < length:
        current_char = expression[pos]
        
        if current_char == '"':
            # eat characters until next quotes
            closing = expression.find('"', pos + 1)
            if closing == -1:
                raise QuoteMismatchError('missing closing quotes at position %d' % pos)
            # whitespaces pending before the quotes are dropped
            state.trailing_spaces = ''
            state.append_to_content(expression[pos + 1:closing])
            pos = closing + 1
            continue
        
        if current_char == '(':
            state.stack_parenthesis(pos)
        elif current_char == ')':
            state.unstack_parenthesis(pos)
        elif current_char in simple_operators:
            # any other operator simply flush content and create a new node
            state.add_child(simple_operators[current_char](), pos)
        elif current_char.isspace():
            # maybe this space will be kept.
            state.add_trailing_space(current_char)
        else:
            # append characters to content, up to the next stop character
            end = pos + 1
            while end < length and expression[end] not in stop_chars:
                end += 1
            state.append_to_content(expression[pos:end])
            pos = end
            continue
        
        # all the other tokens are one character long
        pos += 1
    
    # flush pending characters
    state.flush_expression(length)
    root = state.current_node
    if root.parent is not None:
        # raise an exception if current_node is not root
        raise ParenthesisMismatchError('missing closing parenthesis')
    
    # re-order nodes for evaluation
    root.reorder()
    
    return root.optimize()
