#!/usr/bin/env python3
# coding=utf-8
#
# Copyright 2020 Institute of Formal and Applied Linguistics, Faculty of
# Mathematics and Physics, Charles University, Czech Republic.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

"""Word embeddings server class."""

import email.parser
import http.server
import json
import random
import socketserver
import sys
import threading
import time
import urllib.error
import urllib.parse
import urllib.request


class FrontendRESTServer(socketserver.TCPServer):
    """REST server forwarding requests to a set of backend REST servers.

    The `/models` endpoint is answered directly with the union of the models of
    all backends. Every other request is forwarded with its original path and
    body to a randomly chosen backend providing the requested model (resolved
    through the optional aliases file). If that backend fails before any response
    data has been sent to the client, another backend is tried. Every connection
    is handled in its own thread, with at most `args.max_concurrency` of them
    running at once.
    """

    class Backend():
        """A backend REST server together with the models it provides."""

        def __init__(self, server):
            """Contact the backend `server` and download its list of models.

            Parameters:
              server: the authority part (e.g. "host:port") of the backend URL.

            Raises:
              urllib.error.URLError: when the backend cannot be contacted.
              ValueError: when the `/models` response is not valid JSON, or lacks
                a `models` dictionary or a `default_model` string.
            """
            self._server = server

            with self.request("/models") as response:
                data = json.loads(response.read())

            # Validate explicitly; an `assert` would be skipped under `python -O`.
            if not isinstance(data, dict) or not isinstance(data.get("models"), dict):
                raise ValueError("Backend '{}' returned no 'models' dictionary.".format(server))
            self.models = data["models"]

            if not isinstance(data.get("default_model"), str):
                raise ValueError("Backend '{}' returned no 'default_model' string.".format(server))
            self.default_model = data["default_model"]

        def request(self, url, data=None, data_content_type=None):
            """Open `url` (a path such as "/models") on this backend.

            A GET request is sent when `data` is None. Otherwise a POST request is
            sent, with `data` as the body and `data_content_type` as its
            Content-Type. Returns the `urllib.request.urlopen` response, which is a
            context manager. HTTP error responses raise `urllib.error.HTTPError`.
            """
            return urllib.request.urlopen(urllib.request.Request(
                url="http://{}{}".format(self._server, url),
                data=data,
                headers={} if data is None else {"Content-Type": data_content_type},
            ))


    class FrontendRESTServer(http.server.BaseHTTPRequestHandler):
        """Request handler answering `/models` and forwarding everything else to a backend."""
        protocol_version = "HTTP/1.1"

        # Maps "\n" to "\r" and deletes "\r", so logged values contain no "\n".
        format_for_log_table = str.maketrans("\n", "\r", "\r")
        def format_for_log(request, data, limit=None):
            """Return the string `data` prepared for a log entry.

            With a positive `limit`, data longer than `limit` is shortened to its
            beginning and its end joined by " ... ". With a non-positive `limit`,
            only the length in characters is logged, as "[<length>B]".
            """
            if limit is not None:
                if limit <= 0:
                    data = "[{}B]".format(len(data))
                elif len(data) > limit:
                    data = data[:limit // 2] + " ... " + data[min(-1, -limit // 2):]
            return data.translate(request.format_for_log_table)

        def respond(request, content_type, code=200, additional_headers=None):
            """Send the status line and headers of a response.

            The response has no Content-Length. Its end is signalled by closing
            the connection, which `close_connection` requests once the current
            request has been handled.

            Parameters:
              content_type: value of the Content-Type header.
              code: the HTTP status code.
              additional_headers: optional mapping of extra header names to values.
            """
            request.close_connection = True
            request.send_response(code)
            request.send_header("Connection", "close")
            request.send_header("Content-Type", content_type)
            request.send_header("Access-Control-Allow-Origin", "*")
            for key, value in (additional_headers or {}).items():
                request.send_header(key, value)
            request.end_headers()

        def respond_error(request, message, code=400):
            """Send a plain-text error response containing `message`."""
            request.respond("text/plain", code)
            request.wfile.write(message.encode("utf-8"))

        def handle_expect_100(request):
            """Reject an `Expect: 100-continue` request announcing a too-large payload.

            Returns False (with an error response already sent) for a payload
            larger than `max_request_size`. Otherwise it defers to the default
            implementation.
            """
            try:
                request_too_long = int(request.headers["Content-Length"]) > request.server._args.max_request_size
            except (KeyError, TypeError, ValueError):
                # A missing or malformed Content-Length is left for do_GET to handle.
                request_too_long = False

            if request_too_long:
                request.respond_error("The payload size is too large.")
                return False
            return super().handle_expect_100()

        def do_GET(request):
            """Handle a GET or POST request.

            `/models` is answered directly. Any other request is forwarded with its
            original path and body to a backend. Parameters are taken from the URL
            query and, for POST requests, from a multipart/form-data or
            application/x-www-form-urlencoded body. They are used only for logging
            and for choosing the backends by the `model` parameter.
            """
            # Parse the parameters from the URL
            params, body, body_content_type = {}, None, None
            try:
                # http.server decodes the request line as ISO-8859-1; recover the UTF-8 text.
                encoded_path = request.path.encode("iso-8859-1").decode("utf-8")
                # The path is appended verbatim to the backend address. Anything other
                # than an absolute path (e.g. "models" or "@host/x") would corrupt the
                # host/port part of the backend URL.
                if not encoded_path.startswith("/"):
                    raise ValueError("The request path does not start with '/'.")
                url = urllib.parse.urlparse(encoded_path)
                for name, value in urllib.parse.parse_qsl(url.query, encoding="utf-8", keep_blank_values=True, errors="strict"):
                    params[name] = value
            except ValueError:  # Also covers UnicodeEncodeError and UnicodeDecodeError.
                return request.respond_error("Cannot parse request URL.")

            # Parse the body of a POST request
            if request.command == "POST":
                if request.headers.get("Transfer-Encoding", "identity").lower() != "identity":
                    return request.respond_error("Only 'identity' Transfer-Encoding of payload is supported for now.")

                try:
                    content_length = int(request.headers["Content-Length"])
                    if content_length < 0:
                        # rfile.read() with a negative size would block until the client disconnects.
                        raise ValueError("Negative Content-Length.")
                except (KeyError, TypeError, ValueError):
                    return request.respond_error("The Content-Length of payload is required.")

                if content_length > request.server._args.max_request_size:
                    # Consume the payload before replying (presumably so that the client
                    # is not reset in the middle of its upload).
                    while content_length:
                        read = request.rfile.read(min(content_length, 65536))
                        # An empty read means EOF; stop draining.
                        content_length -= len(read) if read else content_length
                    return request.respond_error("The payload size is too large.")

                body = request.rfile.read(content_length)
                body_content_type = request.headers.get("Content-Type", "")
                # Media types are case-insensitive, so match them in lowercase.
                media_type = body_content_type.lower()

                # multipart/form-data
                if media_type.startswith("multipart/form-data"):
                    try:
                        parser = email.parser.BytesFeedParser()
                        parser.feed(b"Content-Type: " + body_content_type.encode("ascii") + b"\r\n\r\n")
                        parser.feed(body)
                        for part in parser.close().get_payload():
                            name = part.get_param("name", header="Content-Disposition")
                            if name:
                                params[name] = part.get_payload(decode=True).decode("utf-8")
                    except Exception:  # Any failure of the email parser means a malformed payload.
                        return request.respond_error("Cannot parse the multipart/form-data payload.")
                # x-www-form-urlencoded
                elif media_type.startswith("application/x-www-form-urlencoded"):
                    try:
                        for name, value in urllib.parse.parse_qsl(
                                body.decode("utf-8"), encoding="utf-8", keep_blank_values=True, errors="strict"):
                            params[name] = value
                    except ValueError:  # Also covers UnicodeDecodeError.
                        return request.respond_error("Cannot parse the application/x-www-form-urlencoded payload.")

            # Log if required
            if request.server._args.log_data:
                print(url.path, " ".join(request.headers.get_all("X-Forwarded-For", [])),
                      *["{}:{}".format(key, request.format_for_log(value)) for key, value in params.items() if key != "data"],
                      "data:" + request.format_for_log(params.get("data", ""), request.server._args.log_data),
                      sep="\t", file=sys.stderr, flush=True)

            # Handle /models
            if url.path == "/models":
                response = {
                    "models": {name: value for backend in request.server.backends for name, value in backend.models.items()},
                    "default_model": request.server.backends[0].default_model,
                }
                request.respond("application/json")
                request.wfile.write(json.dumps(response, indent=1).encode("utf-8"))
            # Handle everything else
            else:
                # Start by finding appropriate backends: those providing the model the
                # requested name is an alias of. Fall back to all backends when there
                # is no such alias or no backend provides the model.
                backends = request.server.backends.copy()
                model = params.get("model", request.server.backends[0].default_model)
                if model in request.server.aliases:
                    resolved_model = request.server.aliases[model]
                    backends = [backend for backend in request.server.backends if resolved_model in backend.models] or backends

                # Forward the request to a random backend. Retry with another one while
                # nothing has been sent to the client yet.
                started_responding = False
                try:
                    if not backends:
                        raise RuntimeError("No backends found!")
                    while backends:
                        backend = random.choice(backends) if len(backends) > 1 else backends[0]
                        backends.remove(backend)
                        try:
                            with backend.request(request.path, body, body_content_type) as response:
                                while True:
                                    # Read before responding, so a failing first read can still be retried.
                                    data = response.read(32768)
                                    if not started_responding:
                                        started_responding = True
                                        billing_infclen = response.getheader("X-Billing-Input-NFC-Len", None)
                                        headers = {"X-Billing-Input-NFC-Len": billing_infclen} if billing_infclen is not None else {}
                                        request.respond(response.getheader("Content-Type", "application/json"), code=response.code,
                                                        additional_headers=headers)
                                    if len(data) == 0: break
                                    request.wfile.write(data)
                        except urllib.error.HTTPError as error:
                            # The backend answered with an HTTP error; pass it on to the client.
                            if not started_responding:
                                started_responding = True
                                request.respond(error.headers.get("Content-Type", "text/plain"), code=error.code)
                                request.wfile.write(error.file.read())
                                break
                            raise
                        except Exception:
                            if backends and not started_responding:
                                import traceback
                                traceback.print_exc(file=sys.stderr)
                                print("The above error occurred during request processing on '{}',".format(backend._server),
                                      "but more backends are available, retrying.", file=sys.stderr, flush=True)
                                continue
                            raise
                        break
                except Exception:
                    import traceback
                    traceback.print_exc(file=sys.stderr)
                    sys.stderr.flush()

                    if not started_responding:
                        request.respond_error("An internal error occurred during processing.")
                    else:
                        # The response is partially sent, so the error can only be appended.
                        # This deliberately yields invalid JSON, making the failure detectable.
                        request.wfile.write(b'",\n"An internal error occurred during processing, producing incorrect JSON!"')

        def do_POST(request):
            """Handle a POST request exactly like a GET request (do_GET parses the body)."""
            return request.do_GET()

    def __init__(self, args):
        """Contact all backends, load the model aliases and bind the listening port.

        Parameters:
          args: the parsed command-line arguments. `backends`, `aliases` and `port`
            are used here; `log_data`, `max_concurrency` and `max_request_size`
            are used while serving requests.

        Raises:
          ValueError: on a malformed aliases-file line or backend `/models` response.
        """
        self._args = args

        # Initialize all backends
        self.backends = [self.Backend(backend) for backend in args.backends]

        # Initialize the aliases. Each non-empty, non-comment line has 3-4
        # whitespace-separated columns. The first column is a ":"-separated list
        # of model names. Every such name becomes an alias of the first name on
        # the line, and so does every prefix obtained by dropping trailing
        # "-"-separated components. Earlier definitions take precedence.
        self.aliases = {}
        if args.aliases is not None:
            with open(args.aliases, "r", encoding="utf-8") as aliases_file:
                for line in aliases_file:
                    line = line.rstrip("\r\n")
                    if not line or line.startswith("#"):
                        continue
                    columns = line.split()
                    if len(columns) not in (3, 4):
                        raise ValueError("Expected 3-4 columns in the aliases file: line '{}'".format(line))
                    names = columns[0].split(":")
                    for name in names:
                        components = name.split("-")
                        for dropped in range(len(components)):
                            prefix = "-".join(components[:len(components) - dropped])
                            self.aliases.setdefault(prefix, names[0])

        # Initialize the server
        self._threads = []  # Handler threads; finished ones are pruned in service_actions.
        super().__init__(("", self._args.port), self.FrontendRESTServer)

    def server_bind(self):
        """Bind the socket, allowing reuse of the address and, where supported, of the port."""
        import socket
        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # SO_REUSEPORT is not provided on every platform (e.g. Windows).
        if hasattr(socket, "SO_REUSEPORT"):
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        super().server_bind()

    def server_activate(self):
        """Start listening, with a backlog of 256 pending connections."""
        self.socket.listen(256)

    def process_request_thread(self, request, client_address):
        """Handle one connection in a worker thread, always shutting it down afterwards."""
        try:
            self.finish_request(request, client_address)
        except Exception:
            self.handle_error(request, client_address)
        finally:
            self.shutdown_request(request)

    def process_request(self, request, client_address):
        """Handle the connection in a new non-daemon thread, remembered for server_close."""
        thread = threading.Thread(target=self.process_request_thread, args=(request, client_address), daemon=False)
        self._threads.append(thread)
        thread.start()

    def service_actions(self):
        """Block the serving loop while `max_concurrency` or more handler threads are alive."""
        if len(self._threads) >= self._args.max_concurrency:
            self._threads = [thread for thread in self._threads if thread.is_alive()]

        while len(self._threads) >= self._args.max_concurrency:
            time.sleep(0.1)
            self._threads = [thread for thread in self._threads if thread.is_alive()]

    def server_close(self):
        """Close the listening socket and wait for all handler threads to finish."""
        super().server_close()
        for thread in self._threads:
            thread.join()


if __name__ == "__main__":
    # Run the frontend server until SIGINT or SIGUSR1 is received, then shut it
    # down gracefully, waiting for all requests being processed.
    import argparse
    import signal

    # Parse server arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("port", type=int, help="Port to use")
    parser.add_argument("backends", type=str, nargs="+", help="Backends to use")
    parser.add_argument("--aliases", default=None, type=str, help="Path to model aliases")
    parser.add_argument("--logfile", default=None, type=str, help="Log path")
    parser.add_argument("--log_data", default=None, type=int,
                        help="Log every request, with at most this many characters of its 'data' parameter")
    parser.add_argument("--max_concurrency", default=256, type=int, help="Maximum concurrency")
    parser.add_argument("--max_request_size", default=4096*1024, type=int, help="Maximum request size")
    args = parser.parse_args()

    # Log stderr to logfile if given
    if args.logfile is not None:
        sys.stderr = open(args.logfile, "a", encoding="utf-8")

    # Create the server; this contacts all the backends.
    server = FrontendRESTServer(args)

    # Block the shutdown signals BEFORE starting any thread. Threads inherit the
    # signal mask, so the signals then remain pending for sigwait() below instead
    # of possibly being delivered to the server or request-handling threads.
    signal.pthread_sigmask(signal.SIG_BLOCK, [signal.SIGINT, signal.SIGUSR1])

    server_thread = threading.Thread(target=server.serve_forever, daemon=True)
    server_thread.start()

    print("Started Frontend REST server on port {}.".format(args.port), file=sys.stderr)
    print("To stop it gracefully, either send SIGINT (Ctrl+C) or SIGUSR1.", file=sys.stderr, flush=True)

    # Wait until the server should be closed
    signal.sigwait([signal.SIGINT, signal.SIGUSR1])
    print("Initiating shutdown of the Frontend REST server.", file=sys.stderr, flush=True)
    server.shutdown()
    print("Stopped handling new requests, processing all current ones.", file=sys.stderr, flush=True)
    server.server_close()
    print("Finished shutdown of the Frontend REST server.", file=sys.stderr, flush=True)