#!/usr/bin/python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2013-2022
# This file is part of Shinken Enterprise, all rights reserved.

import errno
import os
import socket
import struct
import threading
import time
from collections import deque

from select import select

from shinken.misc.type_hint import TYPE_CHECKING

if TYPE_CHECKING:
    from shinken.misc.type_hint import Callable, Number, Optional


# May be useful for Windows system :
# Since Python3.5 (cf. socket.py) this code emulate lacking socket pair functionality:
# Origin: https://gist.github.com/4325783, by Geert Jansen.  Public domain.
def socket_pair(family=socket.AF_INET):
    proto = 0
    sock_type = socket.SOCK_STREAM
    
    if family == socket.AF_INET:
        host = u'127.0.0.1'
    elif family == socket.AF_INET6:
        host = u'::1'
    else:
        raise ValueError(u'Only AF_INET and AF_INET6 socket address families are supported')
    
    # We create a connected TCP socket. Note the trick with
    # set blocking(False) that prevents us from having to create a thread.
    listen_sock = socket.socket(family, sock_type, proto)
    try:
        listen_sock.bind((host, 0))
        listen_sock.listen(0)
        # On IPv6, ignore flow_info and scope_id
        addr, port = listen_sock.getsockname()[:2]
        client_sock = socket.socket(family, sock_type, proto)
        try:
            client_sock.setblocking(False)
            client_sock.connect_ex((addr, port))
            client_sock.setblocking(True)
            server_sock, _ = listen_sock.accept()
        except:
            client_sock.close()
            raise
    finally:
        listen_sock.close()
    return server_sock, client_sock


# Encapsulate socket (endpoint) to have one send matching one receive of this strict data
class ShinkenSocket(object):
    def __init__(self, sock):
        # type: (socket.socket) -> None
        self.socket = sock
        self.header_size = struct.calcsize('>I')
        self.socket.setblocking(True)
        # the wait all flag is not available on Windows system
        # noinspection SpellCheckingInspection
        self.send_recv_flags = socket.MSG_WAITALL if hasattr(socket, u'MSG_WAITALL') else 0
    
    
    def poll(self, timeout=0.0):
        # type: (Optional[Number]) -> bool
        result = select([self.socket.fileno()], [], [], timeout)
        return bool(result[0])
    
    
    def write_poll(self, timeout=0.0):
        # type: (Optional[Number]) -> bool
        result = select([], [self.socket.fileno()], [], timeout)
        return bool(result[1])
    
    
    def _force_recv(self, data_size):
        # type: (int) -> bytes
        data = self.socket.recv(data_size, self.send_recv_flags)
        if data == '':
            return data
        
        # Windows system ...
        remaining_data_size = data_size - len(data)
        while remaining_data_size > 0:
            self.poll(None)
            tmp_data = self.socket.recv(remaining_data_size, self.send_recv_flags)
            if tmp_data == '':
                return data
            data = '%s%s' % (data, tmp_data)
            remaining_data_size = data_size - len(data)
        return data
    
    
    def recv(self):
        # type: () -> bytes
        
        header = self._force_recv(self.header_size)
        if not header:
            return header
        data_size = struct.unpack('>I', header)[0]
        result = self._force_recv(data_size)
        
        return result
    
    
    def _force_send(self, data, data_size):
        # type: (bytes, int) -> int
        
        already_sent_size = self.socket.send(data, self.send_recv_flags)
        if already_sent_size <= 0:
            raise IOError(errno.EAGAIN, os.strerror(errno.EAGAIN))
        
        # Windows system ....
        remaining_size = data_size - already_sent_size
        while remaining_size > 0:
            self.write_poll(None)  # block until we can write, in order to send the whole data for pickle
            already_sent_size = self.socket.send(data[-remaining_size:], self.send_recv_flags)
            if already_sent_size <= 0:
                raise IOError(errno.EAGAIN, os.strerror(errno.EAGAIN))
            remaining_size = remaining_size - already_sent_size
        return data_size - remaining_size
    
    
    def send(self, data):
        # type: (bytes) -> int
        
        data_size = len(data)
        header = struct.pack('>I', data_size)
        self._force_send(header, self.header_size)
        sent_size = self._force_send(data, data_size)
        
        return sent_size
    
    
    def close(self):
        # type: () -> None
        try:
            self.socket.shutdown(socket.SHUT_RDWR)
        except:
            pass
        try:
            self.socket.close()
        except:
            pass
        self.socket = None
        # self.poll = lambda: False
        # self.close = self.recv = self.send = lambda: None


# Utility object gathering internal data for receiver or sender threads
class _ShinkenIOThread(object):
    def __init__(self, target, name=None):
        # type: (Callable,Optional[unicode]) -> None
        self._thread = None  # type: Optional[threading.Thread]
        self.lock = None  # type: Optional[threading.Condition]
        self.queue = deque()
        self.queue_size = 0
        self._target = target
        self.keep_thread_running = True
        self.name = name
    
    
    def init(self, have_forked=False):  # Called on startup or after fork
        # type: (bool) -> None
        
        # Once the queue is init for one process you cannot use it in another process.
        if have_forked and self.lock:
            raise RuntimeError(u'Endpoint used in two distinct processes')
        if not self.lock:
            self.lock = threading.Condition(threading.RLock())
    
    
    def ensure_is_started(self):  # Ensure thread is running
        # type: () -> None
        if self._thread:
            if self._thread.is_alive():
                return
            self._thread.join()
            self._thread = None
        if self.keep_thread_running:
            self._thread = threading.Thread(target=self._target, name=self.name)
            self._thread.daemon = True
            self._thread.start()
    
    
    def is_running(self):
        # type: () -> bool
        return self.keep_thread_running and self._thread and self._thread.is_alive()
    
    
    def quit(self):
        # type: () -> None
        self.keep_thread_running = False
        if self._thread:
            if self._thread.is_alive():
                with self.lock:
                    self.lock.notify_all()
            self._thread.join()
            self._thread = None
        self.queue = deque()
        self.queue_size = 0


# Encapsulate socket (endpoint) to have 2 threads, one managing data sending, the other managing data receiving with an internal dequeue storing pending data
# The main goal is avoiding saturation of socket's system buffer, and to have a fast data transfer (without latency)
class ShinkenThreadedSocket(ShinkenSocket):
    def __init__(self, sock, name=None):
        # type: (socket.socket,Optional[unicode]) -> None
        super(ShinkenThreadedSocket, self).__init__(sock)
        self._my_pid = os.getpid()
        self._sender_thread = _ShinkenIOThread(self._sender_thread_loop, (u'%s-snd' % name) if name else None)
        self._receiver_thread = _ShinkenIOThread(self._receiver_thread_loop, (u'%s-rcv' % name) if name else None)
    
    
    def _sender_thread_loop(self):
        # type: () -> None
        sender = self._sender_thread
        poll = self.write_poll
        send = super(ShinkenThreadedSocket, self).send
        
        with sender.lock:
            while sender.keep_thread_running:
                if len(sender.queue) == 0:
                    sender.lock.wait()
                    continue
                data = sender.queue.popleft()
                if data == '':  # Empty data will be interpreted as "Connection Closed" by receiver part
                    sender.lock.notify_all()
                    continue
                sender.queue_size = sender.queue_size - len(data)
                sender.lock.release()
                try:
                    poll(None)
                    send(data)
                except:
                    sender.lock.acquire()  # Overkill, but we need the lock to exit the context ...
                    sender.keep_thread_running = False
                    sender.lock.notify_all()
                    break
                sender.lock.acquire()
                sender.lock.notify_all()
    
    
    def _receiver_thread_loop(self):
        # type: () -> None
        receiver = self._receiver_thread
        recv = super(ShinkenThreadedSocket, self).recv
        
        with receiver.lock:
            while receiver.keep_thread_running:
                receiver.lock.release()
                try:
                    data = recv()
                except:
                    data = ''
                receiver.lock.acquire()
                
                if data == '':  # No data means connection has been closed
                    receiver.lock.notify_all()
                    receiver.keep_thread_running = False
                    break
                receiver.queue.append(data)
                receiver.queue_size = receiver.queue_size + len(data)
                receiver.lock.notify_all()
    
    
    def _init_threads(self, current_thread):
        # type: (_ShinkenIOThread) -> bool
        have_forked = False
        if os.getpid() != self._my_pid:
            self._my_pid = os.getpid()
            have_forked = True
        
        self._sender_thread.init(have_forked)
        self._receiver_thread.init(have_forked)
        
        current_thread.ensure_is_started()
        return current_thread.is_running()
    
    
    def close(self):
        # type: () -> None
        super(ShinkenThreadedSocket, self).close()  # Will unlock I/O waiting in threads
        self._sender_thread.quit()
        self._receiver_thread.quit()
    
    
    def send(self, data):
        # type: (bytes) -> int
        io_thread = self._sender_thread
        if not self._init_threads(io_thread):
            return 0
        
        sent_size = 0
        with io_thread.lock:
            if io_thread.is_running():
                io_thread.queue.append(data)
                io_thread.queue_size = io_thread.queue_size + len(data)
                io_thread.lock.notify_all()
                sent_size = len(data)
        return sent_size
    
    
    def recv(self):
        # type: () -> bytes
        self.poll(None)
        io_thread = self._receiver_thread
        
        data = ''
        if self._init_threads(io_thread):
            io_thread.lock.acquire()
            has_lock = True
        else:
            # Thread is not running, we still allow to fetch already received data
            has_lock = False
        
        if len(io_thread.queue) > 0:
            data = io_thread.queue.popleft()
            io_thread.queue_size = io_thread.queue_size - len(data)
        if has_lock:
            io_thread.lock.release()
        return data
    
    
    def poll(self, timeout=0.0):
        # type: (Optional[Number]) -> bool
        io_thread = self._receiver_thread
        if not self._init_threads(io_thread):
            return len(io_thread.queue) > 0
        
        start = time.time()
        with io_thread.lock:
            wait_time = 1 if timeout is None else (timeout - (time.time() - start))
            while len(io_thread.queue) <= 0 < wait_time and io_thread.is_running():
                if timeout is None:
                    io_thread.lock.wait(None)
                else:
                    io_thread.lock.wait(wait_time)
                    wait_time = timeout - (time.time() - start)
            return len(io_thread.queue) > 0
    
    
    def flush(self, timeout=0.0):
        # type: (Optional[Number]) -> bool
        io_thread = self._sender_thread
        if not self._init_threads(io_thread):
            return False
        
        start = time.time()
        with io_thread.lock:
            io_thread.lock.notify_all()
            wait_time = 1 if timeout is None else (timeout - (time.time() - start))
            while len(io_thread.queue) > 0 and wait_time > 0 and io_thread.is_running():
                if timeout is None:
                    io_thread.lock.wait(None)
                else:
                    io_thread.lock.wait(wait_time)
                    wait_time = timeout - (time.time() - start)
        return len(io_thread.queue) == 0
    
    
    def get_queues_size(self):
        return len(self._receiver_thread.queue), self._receiver_thread.queue_size, len(self._sender_thread.queue), self._sender_thread.queue_size
