#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright (C) 2009-2012:
#     Gabes Jean, naparuba@gmail.com
#     Gerhard Lausser, Gerhard.Lausser@consol.de
#     Gregory Starck, g.starck@gmail.com
#     Hartmut Goebel, h.goebel@goebel-consult.de
#
# This file is part of Shinken.
#
# Shinken is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Shinken is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with Shinken.  If not, see <http://www.gnu.org/licenses/>.

import copy
import dataclasses
import io
import json
import re
import socket
import threading
import urllib.error
import urllib.parse
import urllib.request

import pycurl

from shinken.compat import cPickle, SHINKEN_PICKLE_PROTOCOL
from shinken.compresser import compresser
from shinken.log import logger
from shinken.misc.type_hint import TYPE_CHECKING
from shinkensolutions.localinstall import VERSION

# Names imported under this guard are only referenced from string annotations in this module.
if TYPE_CHECKING:
    from typing import Self
    from shinken.misc.type_hint import Any, Literal, Mapping, Optional
    
    # Values accepted by the `wait` argument of the request methods:
    # 'short' selects `timeout`, 'long' selects `data_timeout`.
    WaitValue = Literal['short', 'long']

# Global libcurl initialisation, done once when this module is imported.
pycurl.global_init(pycurl.GLOBAL_ALL)
# Second field of pycurl.version_info(), sent in the User-Agent header.
# NOTE(review): presumably this is the libcurl version string rather than the pycurl one - TODO confirm.
PYCURL_VERSION = pycurl.version_info()[1]


class HTTPException(Exception):
    """Error raised by the HTTP client when a request fails.

    Attributes:
        errno: numeric code describing the failure. In this module it is either
            the pycurl error number (transport failure) or the HTTP status code
            (non-200 answer); -1 when no code is available.
        errstr: error text reported by pycurl, or '' when not provided.
    """
    
    # pycurl error numbers for an operation timeout and an SSL connection failure.
    TIMED_OUT = pycurl.E_OPERATION_TIMEDOUT
    SSL_FAILED = pycurl.E_SSL_CONNECT_ERROR
    
    
    def __init__(self, message: str, errno: int = -1, errstr: str = '') -> None:
        super().__init__(message)
        self.errstr: str = errstr
        self.errno: int = errno


# Second name bound to the very same class.
# NOTE(review): presumably kept for callers that import the plural spelling - TODO confirm.
HTTPExceptions = HTTPException

# Extracts the value of the `charset` parameter from a Content-Type header value (group 1).
# The parameter name is matched case-insensitively and an optional opening quote is skipped.
# The captured value stops at whitespace, a quote or a ';', so that `charset="utf-8"` and
# `charset=utf-8;boundary=x` both yield `utf-8` instead of a string no codec lookup accepts.
_HTTP_CONTENT_CHARSET = re.compile(r'''charset=["']?([^\s"';]+)''', re.IGNORECASE)


def _hkey(name: str) -> str:
    """Return the canonical spelling used as dictionary key for a header name.

    Underscores become dashes and every dash-separated word is title-cased,
    e.g. ``content_type`` and ``CONTENT-TYPE`` both give ``Content-Type``.
    """
    dashed = name.replace('_', '-')
    return dashed.title()


@dataclasses.dataclass(frozen=True)
class HTTPResponse:
    """Immutable result of an HTTP request: status code, raw body and headers."""
    
    # HTTP status code of the response.
    status: int
    # Raw, undecoded response body (left out of repr()).
    body: bytes = dataclasses.field(repr=False)
    # Header values grouped by canonical header name (see `_hkey`), in reception order.
    headers: 'dict[str, list[str]]' = dataclasses.field(default_factory=dict)
    
    
    def get_header(self, name: str, default: str = '', index: int = -1) -> str:
        """Return one value of header `name`, or `default` when it is missing.

        `index` selects the occurrence when the header was received several
        times; the default (-1) is the last one. An out-of-range index also
        gives `default`.
        """
        try:
            return self.headers[_hkey(name)][index]
        except LookupError:
            return default
    
    
    def get_all_headers(self, name: str) -> 'list[str]':
        """Return every value received for header `name` (empty list when absent)."""
        return self.headers.get(_hkey(name)) or []
    
    
    @property
    def content_type(self) -> str:
        """Value of the Content-Type header, or '' when the header is absent."""
        return self.get_header('Content-Type', '')
    
    
    def content_charset(self, *, default: str = '') -> str:
        """Return the charset announced in Content-Type.

        Falls back to `default`, then to 'utf-8', when no usable charset
        parameter is present.
        """
        if (content_type := self.content_type) and (match := _HTTP_CONTENT_CHARSET.search(content_type)):
            # The parameter value may be quoted and/or directly followed by ';other=param':
            # keep the bare charset token only.
            charset = match.group(1).split(';', 1)[0].strip('"\'')
            if charset:
                return charset
        return default or 'utf-8'
    
    
    def as_text(self, encoding: str = '', errors: str = 'ignore') -> str:
        """Decode the body to text.

        The charset announced by the response wins; `encoding` is only used
        when the response does not announce one. If a codec name is unknown to
        Python, the next candidate is tried (announced charset, `encoding`,
        then 'utf-8') instead of raising LookupError.
        """
        for candidate in (self.content_charset(default=encoding), encoding):
            if not candidate:
                continue
            try:
                return self.body.decode(candidate, errors)
            except LookupError:
                # Unknown codec name: do not let a bogus header hide the payload.
                continue
        return self.body.decode('utf-8', errors)
    
    
    def as_json(self) -> 'Any':
        """Decode the body with `as_text()` and parse it as JSON.

        Raises:
            json.JSONDecodeError: when the decoded body is not valid JSON.
        """
        return json.loads(self.as_text())


class HTTPClient:
    def __init__(self, address: str = '', port: int = 0, use_ssl: bool = False, timeout: int = 3, data_timeout: int = 120, uri: str = '', *, strong_ssl: bool = False, verbose: bool = False, caller_name: str = '') -> None:
        self.address = address
        self.port = port
        self.timeout = timeout
        self.data_timeout = max(data_timeout, timeout)
        
        if uri:
            self.uri = uri.removesuffix('/')
            try:
                uri_split = urllib.parse.urlsplit(uri)
                # uri_split.hostname & url_split.port could raise a ValueError
                self.address = uri_split.hostname or ''
                self.port = uri_split.port or (443 if uri_split.scheme == 'https' else 80)
            except (TypeError, ValueError):
                logger.error(f'Failed to parse given uri [ {uri} ] ')
                logger.print_stack()
        else:
            self.uri = f'{"https" if use_ssl else "http"}://{self.address}:{self.port}'
        
        self.con = pycurl.Curl()
        self._verbose_curl = bool(verbose)
        self._strong_ssl = bool(strong_ssl)
        
        self.lock = threading.RLock()
        
        self._default_headers: 'dict[str, str]' = {}
        
        self._compute_default_headers(caller_name)
        self._set_con_opt()
    
    
    def __enter__(self) -> 'Self':
        return self
    
    
    def __exit__(self, *exc_args: 'Any') -> None:
        self.con.close()
    
    
    def _compute_default_headers(self, caller_name: str) -> None:
        self._default_headers = {
            # Remove the Expect: 100-Continue default behavior of pycurl, because swsgiref do not manage it
            'Expect'       : '',
            'Keep-Alive'   : '300',
            'Connection'   : 'Keep-Alive',
            # Set the daemon name into the header, so the http daemon can display it in debug mode if asked
            'X-Caller-Name': caller_name.strip() or str(logger.name).strip(),
        }
        try:
            host_name = socket.getfqdn() or socket.gethostname()
        except Exception:
            pass
        else:
            self._default_headers['X-Forwarded-For'] = host_name
    
    
    def _set_con_opt(self) -> None:
        with self.lock:
            if self._verbose_curl:
                self.con.setopt(pycurl.VERBOSE, True)
            self.con.setopt(pycurl.USERAGENT, f'shinken:{VERSION} pycurl:{PYCURL_VERSION}')
            self.con.setopt(pycurl.FOLLOWLOCATION, 1)
            self.con.setopt(pycurl.CONNECTTIMEOUT, self.timeout)
            self.con.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_1_1)
            
            # Also set the SSL options to do not look at the certificates too much
            # unless the admin asked for it
            if self._strong_ssl:
                self.con.setopt(pycurl.SSL_VERIFYPEER, 1)
                self.con.setopt(pycurl.SSL_VERIFYHOST, 2)
            else:
                self.con.setopt(pycurl.SSL_VERIFYPEER, 0)
                self.con.setopt(pycurl.SSL_VERIFYHOST, 0)
    
    
    # Used by copy.deepcopy()
    def __getstate__(self) -> 'dict[str, Any]':
        state = self.__dict__.copy()
        state.pop('con', None)
        state.pop('lock', None)
        return state
    
    
    # Used by copy.deepcopy()
    def __setstate__(self, state: 'dict[str, Any]') -> None:
        self.__dict__.update(state)
        self.con = pycurl.Curl()
        self.lock = threading.RLock()
        self._set_con_opt()
    
    
    def clone(self, *, data_timeout: 'int|None' = None) -> 'HTTPClient':
        # Uses __getstate__ and __setstate__ to reconstruct the object.
        new_self = copy.deepcopy(self)
        if data_timeout is not None:
            new_self.data_timeout = max(data_timeout, new_self.timeout)
        return new_self
    
    
    @staticmethod
    def _build_http_headers_from_mapping(headers: 'Mapping[str, str]') -> 'list[str]':
        return [f'{name}: {value}'.strip() for name, value in headers.items()]
    
    
    @staticmethod
    def _normalize_headers(headers: 'Mapping[str, str]') -> 'dict[str, str]':
        """Return a new dict whose keys are the canonical spelling of the header names.

        When two names normalize to the same key, the value met last wins.
        """
        return dict((_hkey(name), value) for name, value in headers.items())
    
    
    def raw_get(self, path: str, *, query: 'Optional[Mapping[str, Any]]' = None, headers: 'Optional[Mapping[str, str]]' = None, wait: 'WaitValue' = 'short') -> 'HTTPResponse':
        with self.lock:
            self.con.setopt(pycurl.POST, 0)
            self.con.setopt(pycurl.HTTPGET, 1)
            return self._perform_any_request(path=path, query=query, headers=headers, wait=wait)
    
    
    def raw_post(self, path: str, data: 'bytes', *, content_type: str = '', query: 'Optional[Mapping[str, Any]]' = None, headers: 'Optional[Mapping[str, str]]' = None, wait: 'WaitValue' = 'short') -> 'HTTPResponse':
        with self.lock:
            self.con.setopt(pycurl.HTTPGET, 0)
            self.con.setopt(pycurl.POST, 1)
            
            if content_type:
                headers = self._normalize_headers(headers) if headers else {}
                headers['Content-Type'] = content_type
            
            with io.BytesIO(data) as data_file:
                self.con.setopt(pycurl.READFUNCTION, data_file.read)
                self.con.setopt(pycurl.POSTFIELDSIZE, len(data))
                try:
                    return self._perform_any_request(path=path, query=query, headers=headers, wait=wait)
                finally:
                    # Make curl release the reference to the BytesIO object.
                    self.con.setopt(pycurl.READFUNCTION, lambda x: b'')
    
    
    def raw_post_fields(self, path: str, form: 'Mapping[str, Any]', *, query: 'Optional[Mapping[str, Any]]' = None, headers: 'Optional[Mapping[str, str]]' = None, wait: 'WaitValue' = 'short') -> 'HTTPResponse':
        with self.lock:
            self.con.setopt(pycurl.HTTPGET, 0)
            self.con.setopt(pycurl.POST, 1)
            self.con.setopt(pycurl.POSTFIELDS, urllib.parse.urlencode(form))
            return self._perform_any_request(path=path, query=query, headers=headers, wait=wait)
    
    
    def raw_post_multipart(self, path: str, form: 'Mapping[str, bytes]', *, query: 'Optional[Mapping[str, Any]]' = None, headers: 'Optional[Mapping[str, str]]' = None, wait: 'WaitValue' = 'short') -> 'HTTPResponse':
        with self.lock:
            self.con.setopt(pycurl.HTTPGET, 0)
            self.con.setopt(pycurl.POST, 1)
            # Pycurl want a list of tuple as args
            self.con.setopt(pycurl.HTTPPOST, list(form.items()))
            return self._perform_any_request(path=path, query=query, headers=headers, wait=wait)
    
    
    # Fast way for JSON data
    def raw_post_json(self, path: str, json_data: 'Any', *, query: 'Optional[Mapping[str, Any]]' = None, headers: 'Optional[Mapping[str, str]]' = None, wait: 'WaitValue' = 'short') -> 'HTTPResponse':
        json_data = json.dumps(json_data).encode('utf-8')
        return self.raw_post(path, json_data, content_type='application/json; charset=utf-8', query=query, headers=headers, wait=wait)
    
    
    # NOTE: The connection object must have been configured before calling this function.
    #       The LOCK must be HELD!
    def _perform_any_request(self, path: str, *, query: 'Optional[Mapping[str, Any]]' = None, headers: 'Optional[Mapping[str, str]]' = None, wait: 'WaitValue' = 'short') -> 'HTTPResponse':
        curl = self.con
        
        if headers:
            headers = {**self._default_headers, **self._normalize_headers(headers)}
        else:
            headers = self._default_headers
        
        curl.setopt(pycurl.HTTPHEADER, self._build_http_headers_from_mapping(headers))
        
        # long:data_timeout, like for huge broks receptions
        # short:timeout, like for just "ok" connection
        match wait:
            case 'short':
                curl.setopt(pycurl.TIMEOUT, self.timeout)
            case 'long':
                curl.setopt(pycurl.TIMEOUT, self.data_timeout)
            case _:
                logger.error(f'Unknown wait argument "{wait}". Assuming value is "long"')
                curl.setopt(pycurl.TIMEOUT, self.data_timeout)
        
        # curl must not raise pycurl.error for 4xx and 5xx status codes.
        curl.setopt(pycurl.FAILONERROR, False)
        
        if not path.startswith('/'):
            path = f'/{path}'
        full_url = f'{self.uri}{path}'
        if query:
            full_url = f'{full_url}?{urllib.parse.urlencode(query)}'
        
        curl.setopt(pycurl.URL, full_url)
        
        response_headers_parser = _IncrementalHeadersParser()
        curl.setopt(pycurl.HEADERFUNCTION, response_headers_parser.add_header)
        
        with io.BytesIO() as response:
            curl.setopt(pycurl.WRITEFUNCTION, response.write)
            try:
                curl.perform()
            except pycurl.error as exc:
                errno, errstr = exc.args
                raise HTTPException(f'Connection error to {self.uri} : [Errno {errno}] {errstr}', errno, errstr) from exc
            finally:
                # Make curl release the reference to the BytesIO object.
                curl.setopt(pycurl.WRITEFUNCTION, lambda x: None)
                # Make curl release the reference to the parser object.
                curl.setopt(pycurl.HEADERFUNCTION, lambda x: None)
            
            status = curl.getinfo(pycurl.HTTP_CODE)
            
            # Do NOT close the connection, we want a keep alive
            return HTTPResponse(status=status, body=response.getvalue(), headers=response_headers_parser.result)
    
    
    def get(self, path: str, args: 'Optional[Mapping[str, Any]]' = None, wait: 'WaitValue' = 'short') -> 'Any':
        response = self.raw_get(path, query=args, wait=wait)
        
        if response.status != 200:
            err = f'Connection error to {self.uri}/{path.removeprefix("/")} : [Status code {response.status}] {response.as_text()}'
            logger.error(err)
            raise HTTPException(err, response.status)
        
        return response.as_json()
    
    
    def post(self, path: str, args: 'Mapping[str, Any]', wait: 'WaitValue' = 'short') -> str:
        # Take args, pickle them and then compress the result
        args = dict(args)
        for (k, v) in args.items():
            _to_compress = cPickle.dumps(v, SHINKEN_PICKLE_PROTOCOL)
            args[k] = compresser.compress(_to_compress)
        
        response = self.raw_post_multipart(path, args, wait=wait)
        if response.status != 200:
            err = f'Connection error to {self.uri}/{path.removeprefix("/")} : [Status code {response.status}] {response.as_text()}'
            logger.error(err)
            raise HTTPException(err, response.status)
        
        return response.as_text('utf-8', 'replace')


class _IncrementalHeadersParser:
    """Collect the response headers that curl hands over one line at a time.

    `result` maps the canonical header name to the list of values received for
    it, in reception order. Only the headers of the most recent response are
    kept: each status line starts a new response and resets the collection.
    """
    
    
    def __init__(self) -> None:
        self.result: 'dict[str, list[str]]' = {}
    
    
    def add_header(self, header_line: bytes) -> None:
        """Record one raw header line; meant to be used as curl HEADERFUNCTION.

        Lines without a colon (blank separator line, malformed header) are ignored.
        """
        # HTTP standard specifies that headers are encoded in iso-8859-1.
        line = header_line.decode('iso-8859-1')
        
        if line[:5].upper() == 'HTTP/':
            # Status line: a new response begins (redirect followed, interim 1xx answer, ...).
            # Forget the headers of the previous one so that `result` describes the same
            # response as the final status code and body.
            self.result.clear()
            return
        
        name, colon, value = line.partition(':')
        
        if not colon:
            # Invalid header, bail out.
            return
        
        self.result.setdefault(_hkey(name.strip()), []).append(value.strip())
