import sys
from typing import List
import traceback
import os
import base64
import json
import pprint

from huggingface_hub import Repository
from text_generation import Client

from requests.exceptions import ReadTimeout

# Port used by the __main__ launch stanza at the bottom of the file.
PORT = 7860

# TODO: implement maximum length (currently, each iteration is limited by the slider-specified max length, but this can be iterated, or long code entered into the editor, to get really long documents
# if os.path.exists('unlock'):
#     # create an 'unlock' file (not checked into Git) locally to get full context lengths
#     MAX_LENGTH = 8192
# else:
#     # set to a shorter value to prevent long contexts and make the demo more efficient
#     MAX_LENGTH = 1024
# TRUNCATION_MESSAGE = f'warning: This demo is limited to {MAX_LENGTH} tokens in the document for efficiency.'
# Placeholder message returned to the UI when a generation is truncated
# (truncation detection itself is still a TODO — see `generate`).
TRUNCATION_MESSAGE = f'TODO'

# Hugging Face auth token; None if not configured (the bearer header is still
# sent, with the literal value "None", in that case).
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Base URL of the text-generation-inference server. Required: Client
# construction below uses it directly, so an unset value will fail at startup.
API_URL = os.environ.get("API_URL")

# NOTE(review): HHH_PROMPT is loaded here but not referenced anywhere in this
# file's visible code — presumably used by code elsewhere or left over; verify.
with open("./HHH_prompt.txt", "r") as f:
    HHH_PROMPT = f.read() + "\n\n"

# Special tokens understood by the model for fill-in-the-middle (FIM).
FIM_PREFIX = "<fim_prefix>"
FIM_MIDDLE = "<fim_middle>"
FIM_SUFFIX = "<fim_suffix>"
END_OF_TEXT = "<|endoftext|>"

# Marker the front-end editor uses to indicate the infill location.
FIM_INDICATOR = "<infill>"

# Shared client for the remote text-generation-inference server.
client = Client(
    API_URL, headers={"Authorization": f"Bearer {HF_TOKEN}"},
)

from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, StreamingResponse
# Docs/redoc endpoints are disabled for this demo app.
app = FastAPI(docs_url=None, redoc_url=None)
# Serve the front-end assets out of ./static under the /static path.
app.mount("/static", StaticFiles(directory="static"), name="static")

@app.head("/")
@app.get("/")
def index() -> FileResponse:
    """Serve the single-page editor UI at the site root."""
    return FileResponse(path="static/index.html", media_type="text/html")

def generate(prefix, suffix=None, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
    """Run one generation (or fill-in-the-middle) request against the model server.

    Args:
        prefix: Text before the cursor / infill point.
        suffix: Text after the infill point. When not None, fill-in-the-middle
            (FIM) mode is used; otherwise plain left-to-right generation.
        temperature: Sampling temperature; clamped below to 0.01.
        max_new_tokens: Upper bound on newly generated tokens for this call.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens.

    Returns:
        dict with keys:
            'type': 'infill' or 'generate'
            'text': full document text with the generation spliced in
            'parts': the input segments ([prefix, suffix] or [prompt])
            'truncated': whether output hit the length limit (TODO: always False)
            'infills': (FIM mode only) list with the generated middle segment
    """
    # TODO: deduplicate code between this and `infill`
    temperature = float(temperature)
    # Clamp to a small positive value; the backend rejects temperature == 0.
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    # FIM mode is selected purely by the presence of a suffix.
    fim_mode = suffix is not None

    if fim_mode:
        prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
    else:
        prompt = prefix
    output = client.generate(prompt, **generate_kwargs)
    generated_text = output.generated_text
    # TODO: set this based on stop reason from client.generate
    truncated = False
    # Strip any trailing end-of-text marker(s) the model may emit.
    while generated_text.endswith(END_OF_TEXT):
        generated_text = generated_text[:-len(END_OF_TEXT)]
    generation = {
        'truncated': truncated,
    }
    if fim_mode:
        generation['type'] = 'infill'
        generation['text'] = prefix + generated_text + suffix
        generation['parts'] = [prefix, suffix]
        generation['infills'] = [generated_text]
    else:
        generation['type'] = 'generate'
        generation['text'] = prompt + generated_text
        generation['parts'] = [prompt]
    return generation

@app.get('/generate')
async def generate_maybe(info: str):
    """Left-to-right generation endpoint.

    `info` is a base64-encoded, url-escaped JSON string (GET doesn't support a
    body, and POST leads to CORS issues). Expected JSON keys: 'prompt',
    'length', 'temperature'.
    """
    # Restore base64 padding, following https://stackoverflow.com/a/9956217/1319683
    # `-len(info) % 4` yields the exact 0-3 pad chars needed; the previous
    # `4 - len(info) % 4` appended four '=' when the length was already a
    # multiple of 4 and relied on the decoder tolerating the extra padding.
    info = base64.urlsafe_b64decode(info + '=' * (-len(info) % 4)).decode('utf-8')
    form = json.loads(info)
    prompt = form['prompt']
    length_limit = int(form['length'])
    temperature = float(form['temperature'])
    try:
        generation = generate(prompt, temperature=temperature, max_new_tokens=length_limit, top_p=0.95, repetition_penalty=1.0)
        message = TRUNCATION_MESSAGE if generation['truncated'] else ''
        return {'result': 'success', 'type': 'generate', 'prompt': prompt, 'text': generation['text'], 'message': message}
    except ReadTimeout as e:
        print(e)
        return {'result': 'error', 'type': 'generate', 'prompt': prompt, 'message': 'Request timed out.'}
    except Exception as e:
        traceback.print_exception(*sys.exc_info())
        return {'result': 'error', 'type': 'generate', 'prompt': prompt, 'message': f'Error: {e}.'}

@app.get('/infill')
async def infill_maybe(info: str):
    """Fill-in-the-middle endpoint.

    `info` is a base64-encoded, url-escaped JSON string (GET doesn't support a
    body, and POST leads to CORS issues). Expected JSON keys: 'parts' (the
    document split on the <infill> marker), 'length', 'temperature'.
    """
    # Restore base64 padding; `-len % 4` yields exactly the 0-3 '=' needed.
    # Following https://stackoverflow.com/a/9956217/1319683
    info = base64.urlsafe_b64decode(info + '=' * (-len(info) % 4)).decode('utf-8')
    form = json.loads(info)
    length_limit = int(form['length'])
    temperature = float(form['temperature'])
    # Reconstructed document, used for error responses.
    document = ''.join(form['parts'])
    try:
        # Exactly one <infill> marker must split the document into two parts.
        if len(form['parts']) > 2:
            return {'result': 'error', 'text': document, 'type': 'infill', 'message': "error: Only a single <infill> token is supported!"}
        elif len(form['parts']) == 1:
            return {'result': 'error', 'text': document, 'type': 'infill', 'message': "error: Must have an <infill> token present!"}
        prefix, suffix = form['parts']
        generation = generate(prefix, suffix=suffix, temperature=temperature, max_new_tokens=length_limit, top_p=0.95, repetition_penalty=1.0)
        generation['result'] = 'success'
        generation['message'] = TRUNCATION_MESSAGE if generation['truncated'] else ''
        return generation
    except ReadTimeout as e:
        print(e)
        # Fixed: this handler previously referenced an undefined `prompt`
        # variable (NameError at runtime) and mislabeled the type 'generate'.
        return {'result': 'error', 'type': 'infill', 'text': document, 'message': 'Request timed out.'}
    except Exception as e:
        traceback.print_exception(*sys.exc_info())
        return {'result': 'error', 'type': 'infill', 'message': f'Error: {e}.'}

if __name__ == "__main__":
    # NOTE(review): FastAPI instances have no .run() method — this is Flask
    # style and will raise AttributeError if the file is executed directly
    # (`threaded` is also a Flask-only argument). The app is presumably
    # launched via an ASGI server in deployment (e.g. `uvicorn app:app --port
    # 7860`); confirm, and replace this with `uvicorn.run(app, host='0.0.0.0',
    # port=PORT)` if direct execution is desired.
    app.run(host='0.0.0.0', port=PORT, threaded=False)