#!/usr/bin/python
# -*- coding: utf-8 -*-

# Copyright (C) 2009-2022:
#    Gabes Jean, naparuba@gmail.com
#    Gerhard Lausser, Gerhard.Lausser@consol.de
#    Gregory Starck, g.starck@gmail.com
#    Hartmut Goebel, h.goebel@goebel-consult.de
#
# This file is part of Shinken.
#
# Shinken is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Shinken is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with Shinken.  If not, see <http://www.gnu.org/licenses/>.

import socket

from shinken.misc.type_hint import TYPE_CHECKING

if TYPE_CHECKING:
    from shinken.misc.type_hint import Optional, Any, Union, List


class Protocol(object):
    """
    Description of a URL scheme known by BaseUrl: its name, the port used when a
    URL gives none, and whether it is flagged as using SSL.
    """
    
    def __init__(self, url_protocol, default_port, use_ssl):
        # type: (unicode, int, bool) -> None
        self._url_protocol = url_protocol
        self._default_port = default_port
        self._use_ssl = use_ssl
    
    
    def get_default_port(self):
        # type: () -> int
        """Return the port used when a URL of this scheme does not specify one."""
        return self._default_port
    
    
    def get_url_protocol(self):
        # type: () -> unicode
        """Return the scheme name as written in a URL (e.g. u'https')."""
        return self._url_protocol
    
    
    def get_use_ssl(self):
        # type: () -> bool
        """Return the SSL flag this scheme was declared with."""
        return self._use_ssl


class ShinkenBaseUrlException(Exception):
    """Base class of every error raised (strict mode) or collected (non-strict mode) by BaseUrl."""


class ShinkenBaseUrlExceptionInvalidPort(ShinkenBaseUrlException):
    """The port is not an integer between BaseUrl.HTTP_PORT_MIN and BaseUrl.HTTP_PORT_MAX."""


class ShinkenBaseUrlExceptionUnknownProtocol(ShinkenBaseUrlException):
    """The protocol is not one of BaseUrl.PROTOCOLS."""


class ShinkenBaseUrlExceptionEmptyHost(ShinkenBaseUrlException):
    """The host name or IP address is empty."""


class BaseUrl(object):
    """
    A protocol / host / port triple with validation.
    
    With strict=True, an invalid part raises a ShinkenBaseUrlException subclass.
    With strict=False, the error is recorded (see has_errors / get_exceptions /
    get_error_messages) and a fallback value is used instead: DEFAULT_INTERFACE for
    the host, DEFAULT_PROTOCOL for the protocol, the protocol default for the port.
    """
    PROTOCOLS = [
        Protocol(u'http', 80, False),
        Protocol(u'https', 443, True),
        Protocol(u'ws', 80, False),
        Protocol(u'wss', 443, True),
        Protocol(u'mongodb', 27017, False),
        Protocol(u'sftp', 22, True),
        # 21 is the FTP control port, i.e. the default for ftp:// URLs (20 is the active-mode data port)
        Protocol(u'ftp', 21, False),
    ]
    
    DEFAULT_PROTOCOL = Protocol(u'http', 80, False)
    DEFAULT_INTERFACE = u'0.0.0.0'
    
    LOCAL_ADDRESSES = (u'localhost', u'127.0.0.1')
    LOCAL_IP = u'127.0.0.1'
    HTTP_PORT_MAX = 65535
    HTTP_PORT_MIN = 0
    
    
    def __init__(self, host, port=None, protocol=None, strict=True):
        # type: (unicode, Optional[int], Optional[unicode], bool) -> None
        """
        :param host: host name or IP address; must not be empty.
        :param port: port number (int or numeric string); None / u'' means the protocol default.
        :param protocol: scheme name (case-insensitive); None / empty means DEFAULT_PROTOCOL.
        :param strict: raise on invalid input instead of recording the error and falling back.
        :raises ShinkenBaseUrlException: (strict only) on empty host, unknown protocol or invalid port.
        """
        self.__cache__ip = None  # type: Optional[unicode]
        self._errors = []  # type: List[ShinkenBaseUrlException]
        
        self._host = self._compute_host(host, strict)
        self._protocol = self._compute_protocol(protocol, strict)  # type: Protocol
        # The port fallback depends on the protocol, so the protocol must be computed first
        self._port = self._compute_port(port, self._protocol.get_default_port(), strict)
    
    
    @staticmethod
    def from_url(url, strict=True):
        # type: (unicode, bool) -> BaseUrl
        """
        Build a BaseUrl from ``[protocol://][user[:password]@]host[:port][/path][?query][#fragment]``.
        
        The string is stripped and lower-cased first. Only protocol, host and port are
        kept; credentials, path, query string and fragment are discarded.
        Anything after a second ':' in the host part is ignored.
        """
        url = url.strip().lower()
        protocol = None
        port = None
        
        if u'://' in url:
            protocol, url = url.split(u'://', 1)
        
        # Keep only the authority part: drop path, query string and fragment
        for separator in (u'/', u'?', u'#'):
            url = url.split(separator, 1)[0]
        
        # Drop credentials (e.g. mongodb://user:password@host:port); the last '@' ends the userinfo
        url = url.rsplit(u'@', 1)[-1]
        
        if u':' in url:
            url, port = url.split(u':', 2)[:2]
        
        return BaseUrl(url.strip(), port, protocol, strict=strict)
    
    
    def _compute_protocol(self, url_protocol, strict):
        # type: (Optional[unicode], bool) -> Protocol
        """Return the Protocol matching url_protocol, or DEFAULT_PROTOCOL (see class docstring)."""
        if not url_protocol:
            return self.DEFAULT_PROTOCOL
        # URL schemes are case-insensitive (RFC 3986, section 3.1); from_url already lower-cases them
        wanted_protocol = url_protocol.lower()
        for inst_proto in self.PROTOCOLS:
            if inst_proto.get_url_protocol() == wanted_protocol:
                return inst_proto
        exp = ShinkenBaseUrlExceptionUnknownProtocol(u'The protocol [ %s ] is unknown.' % url_protocol)
        if strict:
            raise exp
        self._errors.append(exp)
        return self.DEFAULT_PROTOCOL
    
    
    def _compute_port(self, port, default_port, strict):
        # type: (Optional[Union[int, unicode]], int, bool) -> int
        """Return port as an int within [HTTP_PORT_MIN, HTTP_PORT_MAX], or default_port (see class docstring)."""
        # We did not check 'if not port' because 0 is a 'valid' port
        if port is None or port == u'':
            return default_port
        try:
            casted_port = int(port)
        except (ValueError, TypeError):
            casted_port = None
        if casted_port is not None and BaseUrl.HTTP_PORT_MIN <= casted_port <= BaseUrl.HTTP_PORT_MAX:
            return casted_port
        exp = ShinkenBaseUrlExceptionInvalidPort(
            u'The port [ %s ] is not valid. Valid values are integers from %s to %s.' % (
                port, BaseUrl.HTTP_PORT_MIN, BaseUrl.HTTP_PORT_MAX))
        if strict:
            raise exp
        self._errors.append(exp)
        return default_port
    
    
    def _compute_host(self, host, strict):
        # type: (unicode, bool) -> unicode
        """Return host unchanged when non-empty, else DEFAULT_INTERFACE (see class docstring)."""
        if host:
            return host
        exp = ShinkenBaseUrlExceptionEmptyHost(u'The hostname or IP address is empty or not found.')
        if strict:
            raise exp
        self._errors.append(exp)
        return self.DEFAULT_INTERFACE
    
    
    @staticmethod
    def get_local_ip():
        # type: () -> Optional[unicode]
        """Return the IP address resolved for this machine's host name, or None if resolution fails."""
        try:
            return socket.gethostbyname(socket.gethostname())
        except Exception:
            # Best effort: any resolution failure means "unknown"
            return None
    
    
    def get_ip(self):
        # type: () -> Optional[unicode]
        """Return the resolved IP of the host (cached after the first success), or None if it cannot be resolved."""
        if self.__cache__ip is None:
            try:
                self.__cache__ip = socket.gethostbyname(self._host)
            except Exception:
                # We do not want to save anything if the operation fails, so a later call retries
                pass
        return self.__cache__ip
    
    
    def is_localhost(self):
        # type: () -> bool
        """Return True when the host is literally one of LOCAL_ADDRESSES (no resolution)."""
        return self._host in self.LOCAL_ADDRESSES
    
    
    def is_local(self):
        # type: () -> bool
        """Return True when the host is a loopback name/address or resolves to this machine's IP."""
        if self.is_localhost():
            return True
        return self.get_host_identifier() == self.get_local_ip()
    
    
    def create_from(self, host=None, port=None, protocol=None, strict=True):
        # type: (Optional[unicode], Optional[int], Optional[unicode], bool) -> BaseUrl
        """Return a new BaseUrl copying this one, with any given part overridden."""
        return BaseUrl(
            host if host is not None else self._host,
            port if port is not None else self._port,
            protocol if protocol is not None else self._protocol.get_url_protocol(),
            strict=strict
        )
    
    
    def get_host_identifier(self):
        # type: () -> unicode
        """Return the resolved IP when available, otherwise the host as given."""
        ip = self.get_ip()
        return ip if ip else self.get_host()
    
    
    def get_host(self):
        # type: () -> unicode
        return self._host
    
    
    def get_port(self):
        # type: () -> int
        return self._port
    
    
    def get_use_ssl(self):
        # type: () -> bool
        return self._protocol.get_use_ssl()
    
    
    def get_protocol(self):
        # type: () -> unicode
        return self._protocol.get_url_protocol()
    
    
    def get_url(self):
        # type: () -> unicode
        """Return ``protocol://host:port`` (the port is always explicit)."""
        return u'%s://%s:%s' % (self.get_protocol(), self._host, self._port)
    
    
    def get_url_identifier(self):
        # type: () -> unicode
        """Like get_url, but with the resolved IP in place of the host when resolvable."""
        return u'%s://%s:%s' % (self.get_protocol(), self.get_host_identifier(), self._port)
    
    
    def get_display_host_with_ip(self):
        # type: () -> unicode
        """Return ``host (ip)`` when the resolved IP differs from the host, else just the host."""
        ip = self.get_ip()
        host = self.get_host()
        if ip and ip != host:
            return u'%s (%s)' % (host, ip)
        return host
    
    
    def build_url_with_path(self, path):
        # type: (unicode) -> unicode
        """Return get_url() followed by path, inserting a '/' separator when path lacks one."""
        url_template = u'%s%s'
        if not path.startswith(u'/'):
            url_template = u'%s/%s'
        return url_template % (self.get_url(), path)
    
    
    def has_errors(self):
        # type: () -> bool
        """Return True when non-strict construction recorded at least one error."""
        return bool(self._errors)
    
    
    def get_exceptions(self):
        # type: () -> List[ShinkenBaseUrlException]
        return self._errors
    
    
    def get_error_messages(self):
        # type: () -> List[unicode]
        """Return the text of every recorded error, in the order they were found."""
        # BaseException.message is deprecated since Python 2.6 and removed in Python 3;
        # every error here is built with a single message argument, available as args[0]
        return [exp.args[0] if exp.args else u'' for exp in self._errors]
    
    
    def __str__(self):
        # type: () -> unicode
        return self.get_url()
    
    
    def __unicode__(self):
        # type: () -> unicode
        return self.get_url()
    
    
    def __eq__(self, other):
        # type: (Any) -> bool
        """Two BaseUrl are equal when their get_url_identifier() (protocol, IP-or-host, port) match."""
        if not isinstance(other, BaseUrl):
            return False
        return self.get_url_identifier() == other.get_url_identifier()
    
    
    def __ne__(self, other):
        # type: (Any) -> bool
        # Python 2 does not derive '!=' from __eq__, so it must be defined explicitly
        # NOTE(review): __hash__ is not aligned with __eq__ (identity hash on Python 2, unhashable on Python 3)
        return not self.__eq__(other)
