from asyncio import sleep
from typing import Optional
from fastapi import FastAPI
from fastapi.encoders import jsonable_encoder
from fastapi.websockets import WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse, JSONResponse
from websockets import ConnectionClosed

from accelerator import Accelerator
from answerer import Answerer
from mapper import Mapper

# Load the sentence-embedding model behind the /map endpoint. A load failure is
# reported but not fatal, so the rest of the API can still start.
try:
  mapper = Mapper("sentence-transformers/multi-qa-distilbert-cos-v1")
except Exception as e:
  print(f"ERROR! cannot load Mapper model!\n{e}")
  # Keep the name bound so later references see an explicit "unavailable"
  # sentinel instead of crashing with NameError.
  mapper = None
    
# Local RWKV text generator; the /answer endpoint uses it whenever no
# accelerator is connected (or the accelerator returns nothing).
answerer = Answerer(
  strategy="cpu bf16",
  ctx_limit=16_384,  # 16 * 1024, consistent with the "ctx16k" model variant
  vocab="rwkv_vocab_v20230424",
  model="RWKV-5-World-3B-v2-20231118-ctx16k",
)

# Holds the optional external accelerator registered through /accelerate.
accelerator = Accelerator()

app = FastAPI()

# Minimal browser client served at "/". Each submission opens a fresh websocket
# to /answer, because that endpoint answers a single prompt and then closes.
HTML = """
<!DOCTYPE HTML>

<html>

<body>
  <form action="" onsubmit="ask(event)">
    <textarea id="prompt"></textarea>
    <br>
    <input type="submit" value="SEND" />
  </form>

  <p id="output"></p>
  <script>
    const prompt = document.getElementById("prompt");
    const output = document.getElementById("output");

    // Talk to the same host that served this page, so the client works both
    // locally and when deployed.
    const scheme = location.protocol === "https:" ? "wss://" : "ws://";
    const url = scheme + location.host + "/answer";

    function ask(event) {
      // Always stop the form from reloading the page, even when sending fails.
      event.preventDefault();

      const question = prompt.value;
      const ws = new WebSocket(url);
      ws.onopen = () => ws.send(question);
      ws.onmessage = (e) => answer(e.data);
      ws.onerror = () => answer("websocket is not connected!");
    }

    function answer(value) {
      // Render as plain text so model output is never parsed as markup.
      output.textContent = value;
    }
  </script>
</body>

</html>
"""

@app.get("/")
def index():
  """Serve the embedded single-page demo client."""
  page = HTMLResponse(content=HTML)
  return page

@app.websocket("/accelerate")
async def accelerate(ws: WebSocket):
  """Register an incoming accelerator websocket and keep the handler alive.

  Named ``accelerate`` rather than ``answer`` for two reasons. The ``/answer``
  handler defined later in this module would otherwise shadow this function.
  Both routes would also share the route name "answer", which makes
  ``app.url_path_for("answer")`` ambiguous. The URL path is unchanged.
  """
  await accelerator.connect(ws)
  # Returning from a websocket endpoint ends the connection, so poll until the
  # accelerator reports that it is no longer connected.
  while accelerator.connected():
    await sleep(10)

@app.post("/map")
def map(query: Optional[str], items: Optional[list[str]]):
  """Score ``items`` against ``query`` using the sentence-embedding mapper.

  The mapper's result goes through ``jsonable_encoder`` so it can be
  serialized as the JSON response body.

  NOTE: the function name shadows the builtin ``map`` within this module; it
  is kept for interface compatibility.
  """
  return JSONResponse(content=jsonable_encoder(mapper(query, items)))

async def handle_answerer_local(ws: WebSocket, input: str):
  """Generate an answer with the local model and send the final value.

  Only the last value yielded by ``answerer`` is sent. Presumably each yield
  is the full text generated so far (TODO: confirm against ``Answerer``).

  If the generator yields nothing, an empty string is sent. Previously this
  case crashed with UnboundLocalError; now the client still gets one reply.
  """
  last = ""
  # 128 is passed as Answerer's second positional argument; presumably a
  # generation length limit — confirm against Answerer.__call__.
  async for chunk in answerer(input, 128):
    last = chunk
  await ws.send_text(last)

async def handle_answerer_accelerated(ws: WebSocket, input: str):
  """Answer through the connected accelerator, falling back to the local model.

  Any falsy accelerator result (e.g. ``None`` or ``""``) triggers the local
  fallback instead of being sent to the client.
  """
  reply = await accelerator.accelerate(input)
  if not reply:
    await handle_answerer_local(ws, input)
    return
  await ws.send_text(reply)

@app.websocket("/answer")
async def answer(ws: WebSocket):
  """Answer exactly one prompt per websocket connection.

  Receives a single text message and answers it. The accelerator is used when
  one is connected; otherwise the local model answers. The socket is closed
  afterwards. If the connection drops first, the handler simply returns.
  """
  await ws.accept()

  try:
    prompt = await ws.receive_text()
    handler = (
      handle_answerer_accelerated if accelerator.connected()
      else handle_answerer_local
    )
    await handler(ws, prompt)
  except (ConnectionClosed, WebSocketDisconnect):
    # The connection is already gone; there is nothing left to close.
    return
  else:
    await ws.close()