#!/usr/bin/env python
"""
simple socket server with json header protocol
"""
import os
import sys
# TODO: remove this sys.path hack (including the code-placement constraint).
# It is necessary because the Python interpreter used in the SystemConsole
# environment does not automatically add the path of this module/script to
# sys.path and does not consider the PYTHONPATH environment variable.
# Directory two levels above this file (the parent of this file's directory).
_package_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _package_dir in sys.path:
    print("not adding '%s' to sys.path - it is already there" % _package_dir)
else:
    # Put it first so it wins over any other entry.
    sys.path.insert(0, _package_dir)
    print("added '%s' to sys.path" % _package_dir)

# *** please take extra care not to insert ANY code before this line ***


# Avoid putting any PythonSV or SystemConsole or other dependencies in this
# module --Keep to standard CPython

import argparse
import json
import select
import socket
import struct
import threading
import traceback

import svtools
# Add svtools.common to sys.modules
#   so that scripts can determine the baseaccess
#   and we don't need to install svtools.common in external pythonsv
import pysv_fpga.toolext.svtools.common
svtools.common = pysv_fpga.toolext.svtools.common
sys.modules['svtools.common'] = svtools.common
from svtools.common import baseaccess

try:
    from pysv_fpga.logs import get_logger
    logger = get_logger(__name__)
except ImportError:
    # The project logging helper is unavailable: fall back to a stdlib logger
    # whose info/debug calls print, and whose to_stdout/to_file do nothing.
    print("(logging support not available, using print)")
    import logging
    logger = logging.getLogger(__name__)

    def _print(s, *args):
        """Print ``s``, %-formatted with ``args`` when any are given."""
        text = s % tuple(args) if args else s
        print(text)

    logger.to_stdout = lambda x=True: None  # no-op
    logger.to_file = lambda x, y=True: None  # no-op
    logger.info = logger.debug = _print


# Unique object usable as a "no value supplied" marker that cannot be
# confused with None or any other real value.
_missing = object()  # sentinel
# Port used when a caller does not specify one.
default_port = 65432


class Server(threading.Thread):
    """Command server listening on 127.0.0.1 and serving one client at a time.

    Each received command message is dispatched to a method of the currently
    loaded commands object, and the result (or the error traceback) is sent
    back to the client as a message.
    """

    def __init__(self, port=0, port_file=None):
        """Create the server.

        :param port: TCP port to bind to; 0 or None lets the OS pick one.
        :param port_file: optional path that ``run`` writes the bound port
            number to.
        """
        threading.Thread.__init__(self)

        self.port = port
        self.port_file = port_file
        # Start with the general command set when it is importable; otherwise
        # no commands are available until 'register_device' loads a set.
        try:
            from pysv_fpga.commands.general import Commands as GeneralCommands
        except ImportError:
            logger.info("Error importing commands from "
                        "pysv_fpga.commands.general")
            self.commands = None
        else:
            self.commands = GeneralCommands(logger)

        # The shutdown_flag is a threading.Event object that
        # indicates whether the thread should be terminated.
        self.shutdown_flag = threading.Event()

    def run(self):
        """Bind, listen and serve clients until shut down.

        Returns when ``shutdown_flag`` is set or when a client sends the
        'close_server' command (no reply is sent for that command).

        :raises RuntimeError: if ``port_file`` is set and already exists.
        """
        if self.port is None:
            self.port = 0
        if self.port_file and os.path.exists(self.port_file):
            raise RuntimeError("port file '%s' already exists, will not "
                               "overwrite." % self.port_file)

        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(('127.0.0.1', self.port))
            # With port 0 the OS picks the port, so ask the socket for it.
            bound_port = s.getsockname()[1]
            logger.info("Starting up server on port %s, waiting for connection"
                        % bound_port)
            s.listen(1)
            # The 1 second accept timeout lets the loop below re-check
            # shutdown_flag periodically.
            s.settimeout(1)
            if self.port_file:
                logger.info("Writing port file '%s'" % self.port_file)
                with open(self.port_file, 'w') as w:
                    w.write(str(bound_port))

            while not self.shutdown_flag.is_set():
                try:
                    conn, addr = s.accept()
                    with conn:
                        logger.info("Connected to %s", addr)
                        server = Agent(conn)
                        while not self.shutdown_flag.is_set():
                            cmd_msg = server.recv_msg()
                            if cmd_msg is None:
                                logger.info("Client closed connection")
                                break

                            if cmd_msg.cmd == 'close_server':
                                logger.info("close_server command received, "
                                            "shutting down")
                                return

                            ret_msg = self.execute_cmd_msg(cmd_msg)
                            ret_package = ret_msg.to_package()
                            logger.debug("Sending:  %s", ret_package)
                            conn.sendall(ret_package)

                except socket.timeout:
                    # Nothing connected within the timeout; loop again.
                    pass

    def init_logger(self, level=None, to_stdout=False, to_file=None):
        """Initializes the logger for the rpcserver module"""
        # useful for passing along logger settings when the server is running
        # in a subprocess (required for multiprocessing on Windows systems)
        if to_stdout:
            logger.to_stdout(True)
        if to_file:
            logger.to_file(to_file)
        if level is not None:
            logger.setLevel(level)

    def execute_cmd_msg(self, msg):
        """Execute the command in ``msg`` and return a Message object.

        :param msg: message with ``cmd`` (command name) and ``args`` (keyword
            arguments as a dict, or None when there are none).
        :return: ReturnMessage with the command's result, or ErrorMessage
            holding the formatted traceback when the command raised.
        """
        # todo - add meta programming here
        cmd = msg.cmd
        logger.debug("Execute:  %s", cmd)
        try:
            # Work on a copy so popping 'bogus' below does not mutate the
            # message, and treat missing args (None) as "no arguments".
            args = dict(msg.args) if msg.args else {}
            if cmd == 'register_device':
                # register_device also used to designate a set of valid commands

                # required args for rpcserver:
                devtype = args['devtype']
                # optional args for rpcserver:
                bogus = args.pop('bogus', False)

                commands_cls = _get_commands_cls(devtype, bogus=bogus)
                self.commands = commands_cls(logger)

                # Now call the appropriate register_device function
                result = self.commands.register_device(**args)
            else:
                try:
                    func = getattr(self.commands, cmd)
                except AttributeError:
                    raise RuntimeError("Cannot find command '%s' in currently "
                                       "loaded commands (%s)"
                                       % (cmd,
                                          self.commands.__class__.__name__))
                else:
                    result = func(**args)
            return ReturnMessage(result)
        except Exception:
            # Only ordinary errors are reported back to the client;
            # KeyboardInterrupt and SystemExit are allowed to propagate.
            traceback_str = traceback.format_exc()
            logger.error("execute_cmd_msg error: " + traceback_str)
            return ErrorMessage(traceback_str)
            # raise # temporary - raise the exception to help with debug


def _get_commands_cls(commands_module_name, bogus=False):
    """Import a sibling ``commands`` module and return its commands class.

    The module imported is ``<package>.commands.<name>``, where ``<package>``
    is the name of the directory containing this file and ``<name>`` is
    ``commands_module_name`` (replaced by "gdr_toolkit" when it starts with
    "gdr").

    :param commands_module_name: name of the module inside ``commands``.
    :param bogus: when true return ``BogusCommands`` instead of ``Commands``.
    :raises AttributeError: if the module has no class of the expected name.
    """
    import importlib

    # Every name starting with "gdr" maps to the gdr_toolkit module.
    if commands_module_name.startswith("gdr"):
        commands_module_name = "gdr_toolkit"

    # try to get the sibling commands subpackage
    this_file = os.path.abspath(__file__)
    package = os.path.basename(os.path.dirname(this_file))
    full_name = '%s.commands.%s' % (package, commands_module_name)
    specific_commands_mod = importlib.import_module(full_name)

    # convention: acceptable class names are Commands and BogusCommands
    cls_name = "Commands" if not bogus else "BogusCommands"
    try:
        return getattr(specific_commands_mod, cls_name)
    except AttributeError:
        raise AttributeError("Cannot find %s class in module %s"
                             % (cls_name, full_name))


class Agent:
    """Basic socket agent, to be either a client or a server.

    Incoming data is parsed in three stages: a 2-byte big-endian header
    length, the header bytes (parsed by ``Header.from_bytes``), and then the
    message bytes, whose length is the header's ``msg_len``.
    """

    def __init__(self, sock=None):
        """Wrap ``sock`` (may be None until a subclass connects one)."""
        self._sock = sock
        self._recv_buffer = b''  # received bytes not yet consumed
        self._header_len = None
        self._header = None
        self._message = None

    def _clear(self):
        """Reset the per-message parse state (the byte buffer is kept)."""
        self._header_len = None
        self._header = None
        self._message = None

    def close(self):
        """Close the socket if there is one; safe to call repeatedly."""
        if self._sock:
            self._sock.close()
            self._sock = None

    def send_msg(self, msg):
        """Send ``msg`` (anything providing ``to_package()``) on the socket.

        :raises RuntimeError: if no socket is set.
        """
        if self._sock is None:
            raise RuntimeError("Socket is None.  Call connect or listen first.")
        package = msg.to_package()
        logger.debug("Sending:  %s", package)
        self._sock.sendall(package)

    def recv_msg(self):
        """Receive message in stages: header_len, header, message.

        Blocks until a complete message is available.

        :return: the parsed message, or None when the peer closed the
            connection before a complete message arrived.
        :raises RuntimeError: if no socket is set or the socket reports an
            exceptional condition.
        """
        if self._sock is None:
            raise RuntimeError("Socket is None.  Call connect or listen first.")

        try:
            while True:
                # Use what is already buffered first: an earlier recv() may
                # have pulled in more than one message.
                message = self._parse_buffered()
                if message is not None:
                    return message
                self._recv()
        except _NothingReceivedError:
            return None

    def _parse_buffered(self):
        """Advance the parse as far as the buffer allows.

        :return: the completed message (parse state is then reset), or None
            when more bytes are needed.
        """
        if self._header_len is None:
            self._process_header_len()
        if self._header_len is not None and self._header is None:
            self._process_header()
        if self._header is not None and self._message is None:
            self._process_message()
        if self._message is None:
            return None
        message = self._message
        self._clear()
        return message

    def _recv(self):
        """Wait until the socket is readable and append what it yields.

        :raises _NothingReceivedError: if recv() returns no data.
        :raises RuntimeError: if select reports an exceptional condition.
        """
        socks = [self._sock]
        # Wait for readability only.  Also selecting for writability makes
        # select() return immediately on a healthy connection, which turns
        # the recv_msg loop into a busy-wait.
        readables, _, errs = select.select(socks, [], socks)
        if errs:
            raise RuntimeError("Socket has a error, shutting down")
        if readables:
            received = readables[0].recv(2048)
            if not received:
                raise _NothingReceivedError()
            self._recv_buffer += received
            logger.debug("Received: %s", self._recv_buffer)

    def _process_header_len(self):
        """Consume the 2-byte big-endian header length if it is buffered."""
        req = 2
        if len(self._recv_buffer) >= req:
            self._header_len = struct.unpack('>H', self._recv_buffer[:req])[0]
            self._recv_buffer = self._recv_buffer[req:]

    def _process_header(self):
        """Consume and parse the header once ``_header_len`` bytes arrived."""
        req = self._header_len
        if len(self._recv_buffer) >= req:
            self._header = Header.from_bytes(self._recv_buffer[:req])
            self._recv_buffer = self._recv_buffer[req:]

    def _process_message(self):
        """Consume and parse the message once ``msg_len`` bytes arrived."""
        req = self._header.msg_len
        if len(self._recv_buffer) >= req:
            self._message = Message.from_bytes(self._recv_buffer[:req])
            self._recv_buffer = self._recv_buffer[req:]


class _NothingReceivedError(Exception):
    """Raised when a socket recv() call returns no data."""


class Message:
    """Message Version 0 Request/Response format.

    A message is serialized as the JSON encoding of its instance attributes,
    so subclasses define the wire fields simply by what they store on self.
    """

    def to_bytes(self):
        """Return the instance attributes as UTF-8 encoded JSON."""
        return json.dumps(self.__dict__).encode('utf-8')

    @staticmethod
    def from_bytes(msg_bytes):
        """Decode JSON bytes into the matching message object.

        The key present decides the type, checked in this order: 'err' gives
        an ErrorMessage, 'ret' a ReturnMessage, 'cmd' a CommandMessage (whose
        args are None when the 'args' key is absent).

        :raises RuntimeError: if the payload is not a JSON object or contains
            none of the recognised keys.
        :raises ValueError: if the payload is not valid JSON.
        """
        json_obj = json.loads(msg_bytes)
        if not isinstance(json_obj, dict):
            raise RuntimeError("Could not decode message str %r"
                               % (msg_bytes,))

        # Test for key presence so that an explicit null value still counts.
        if 'err' in json_obj:
            return ErrorMessage(json_obj['err'])
        if 'ret' in json_obj:
            return ReturnMessage(json_obj['ret'])

        try:
            cmd = json_obj['cmd']
        except KeyError:
            raise RuntimeError("Could not decode message str %r"
                               % (msg_bytes,))
        args = json_obj.get('args', None)  # need a sentinel value here instead?
        return CommandMessage(cmd, args)

    def to_package(self):
        """Return the full wire package for this message.

        Layout: 2-byte big-endian header length, header bytes, message bytes.
        """
        msg_bytes = self.to_bytes()
        msg_bytes += b"\n"  # seems to be necessary to TCL Client
        msg_len = len(msg_bytes)
        header_bytes = Header(msg_len).to_bytes()
        header_len = len(header_bytes)
        package = struct.pack('>H', header_len) + header_bytes + msg_bytes
        return package


class ErrorMessage(Message):
    """Message whose only attribute, ``err``, holds an error string."""

    def __init__(self, err_str):
        # Keep 'err' as the sole instance attribute.
        self.err = err_str


class CommandMessage(Message):
    """Message naming a command (``cmd``) and its arguments (``args``)."""

    def __init__(self, cmd, args):
        # Keep 'cmd' and 'args' as the only instance attributes.
        self.cmd, self.args = cmd, args


class ReturnMessage(Message):
    """Message whose only attribute, ``ret``, holds a return value."""

    def __init__(self, ret):
        # Keep 'ret' as the sole instance attribute.
        self.ret = ret


class Header:
    """Header carrying ``msg_len``, serialized as a JSON object."""

    def __init__(self, msg_len):
        #self.version = 0  # exclude for now
        self.msg_len = msg_len

    def to_bytes(self):
        """Return the instance attributes as UTF-8 encoded JSON."""
        return json.dumps(self.__dict__).encode('utf-8')

    @staticmethod
    def from_bytes(header_bytes):
        """Parse and validate header bytes.

        A missing or null 'version' is treated as version 0, the only
        supported version.

        :return: a Header with the validated ``msg_len``.
        :raises ValueError: if the payload is not a JSON object, the version
            is not an int or is unsupported, or 'msg_len' is not a
            non-negative int.
        :raises KeyError: if 'msg_len' is missing.
        """
        json_obj = json.loads(header_bytes)
        if not isinstance(json_obj, dict):
            raise ValueError("Bad header: %r (expected a JSON object)"
                             % (json_obj,))

        version = json_obj.get('version', None)
        if version is None:
            version = 0
        # bool is a subclass of int, so exclude it explicitly.
        if isinstance(version, bool) or not isinstance(version, int):
            raise ValueError("Bad version in header: %s (expected int)"
                             % (version,))
        if version != 0:
            raise ValueError("Unsupported version %d" % version)

        try:
            msg_len = json_obj["msg_len"]
        except KeyError:
            raise KeyError("header is missing required 'msg_len' field")
        if isinstance(msg_len, bool) or not isinstance(msg_len, int):
            raise ValueError("Bad 'msg_len' in header: %s (expected int)"
                             % (msg_len,))
        # A negative length would make buffer slicing by the receiver
        # silently take the wrong bytes.
        if msg_len < 0:
            raise ValueError("Bad 'msg_len' in header: %s (expected "
                             "non-negative int)" % (msg_len,))
        return Header(msg_len)


class Client(Agent):
    """Agent that connects to a server on 127.0.0.1 and executes commands.

    Usable as a context manager: entering connects (unless a socket was
    supplied) and leaving closes the socket.
    """

    def __init__(self, port=None, sock=None):
        """Create a client.

        :param port: port to connect to; falls back to ``default_port``.
            Ignored when ``sock`` is given.
        :param sock: an already connected socket to use instead of connecting.
        """
        super(Client, self).__init__(sock)
        if sock is not None:
            self._port = None
        else:
            self._port = port or default_port

    def __enter__(self):
        if self._sock is None:
            self.connect()
        return self

    def __exit__(self, t, value, traceback):
        self.close()
        # Never suppress an exception raised inside the with-block.
        return False

    def connect(self):
        """Open a TCP connection to 127.0.0.1 on the configured port.

        :raises RuntimeError: if already connected or no port is set.
        """
        if self._sock:
            raise RuntimeError("Already connected!")
        if not self._port:
            raise RuntimeError("Cannot connect, port not set")
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            sock.connect(('127.0.0.1', self._port))
        except BaseException:
            # Do not leak the socket, and keep self._sock unset so that
            # connect() can be retried after a failure.
            sock.close()
            raise
        self._sock = sock

    def close(self):
        """Close the connection if open; safe to call repeatedly."""
        if self._sock:
            logger.info("Client closing")
            self._sock.close()
            self._sock = None

    def execute(self, cmd_name, **kwargs):
        """Send a command and wait for its reply.

        :param cmd_name: name of the command to run on the server.
        :param kwargs: keyword arguments passed to the command.
        :return: the ``ret`` value of the reply, or None when no reply
            message was received.
        :raises Exception: carrying the error text when the reply is an
            ErrorMessage.
        """
        cmd_msg = CommandMessage(cmd_name, kwargs)
        self.send_msg(cmd_msg)
        ret_msg = self.recv_msg()
        if isinstance(ret_msg, ErrorMessage):
            raise Exception(ret_msg.err)
        if not ret_msg:
            return None
        return ret_msg.ret

    def close_server(self):
        """Ask the server to shut down by executing 'close_server'."""
        return self.execute("close_server")


def test_client(port=None):
    """Run a small self-test against a running server.

    Executes the server's ``add`` command once with valid operands (2 + 3)
    and once with mismatched operand types, expecting the result 5 and then
    an error mentioning the int/str TypeError.

    :param port: server port; falls back to ``default_port``.
    :raises Exception: if either check does not behave as expected.
    """
    port = port or default_port
    logger.info("Running test_client with port %s" % port)
    expected_error = ("TypeError: unsupported operand type(s) for +: "
                      "'int' and 'str'")
    with Client(port) as client:
        # Positive case: plain integer addition.
        if client.execute("add", a=2, b=3) != 5:
            raise Exception("Simple add did not work")
        logger.info("Executed 2+3 correctly  --GOOD")

        # Negative case: the error must be reported back to the client.
        try:
            client.execute("add", a=2, b="three")
        except Exception as e:
            if expected_error not in str(e):
                raise Exception("Got an unexpected exception: %s" % e)
            logger.info("Negative case: caught the expected exception  --GOOD")
        else:
            raise Exception("Did not get expected exception")


def parse_args(override=None):
    """Parse the command-line options.

    :param override: list of argument strings to parse instead of
        ``sys.argv[1:]``; None means use the real command line.
    :return: argparse namespace with ``log_file``, ``port`` (int or None),
        ``port_file``, ``verbose`` (count) and ``testclient`` (bool).
    """
    parser = argparse.ArgumentParser(description="Start a py socket server")
    parser.add_argument("--log-file",
                        help='output the logs to the given file path')
    # type=int lets argparse report a bad port as a normal usage error
    # instead of a ValueError traceback from a later int() call.
    parser.add_argument("-p",
                        "--port",
                        type=int,
                        help='set connection port number')
    parser.add_argument("--port-file",
                        help='write port number to given file path')
    parser.add_argument("-v",
                        "--verbose",
                        action='count',
                        default=0,
                        help='set verbose level of logger (-v=DEBUG)')
    parser.add_argument("--testclient",
                        action="store_true",
                        help='run a test as a simple client to add 2 numbers'
                             ' and also check error handling.')
    # parse_args(None) reads sys.argv[1:], matching the no-override case.
    return parser.parse_args(override)


def _is_interactive():
    """Return the truth value of ``sys.ps1`` when it is defined, otherwise
    that of ``sys.flags.interactive``."""
    try:
        marker = sys.ps1
    except AttributeError:
        marker = sys.flags.interactive
    return bool(marker)


def main(override=None):
    """Command-line entry point: run the server or the test client.

    :param override: optional argument list handed to ``parse_args``.
    :raises ValueError: if the requested log file already exists.
    """
    args = parse_args(override)

    # Logger setup: always log to stdout; level 10 when verbose, else 20.
    logger.to_stdout(True)
    logger.setLevel(10 if args.verbose else 20)  # can add more levels if needed
    log_path = args.log_file
    if log_path:
        if os.path.exists(log_path):
            raise ValueError("destination log file %s already exists"
                             % log_path)
        logger.to_file(log_path)
    logger.info("sys.path: %s", sys.path)

    if not args.testclient:
        # Serve in the calling thread (run() is invoked directly).
        Server(args.port, args.port_file).run()
        return

    # Client mode: take the port from the port file when none was given.
    if args.port is None and args.port_file is not None:
        with open(args.port_file) as r:
            args.port = int(r.read().strip())

    if not _is_interactive():
        test_client(args.port)
        return

    # create a Client and add it to the global namespace
    import __main__ as _main
    client = Client(args.port)
    client.connect()
    _main.client = client
    import atexit

    @atexit.register
    def close_client():
        print("(Trying to close client on exit)")
        client.close()


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()
