from functools import lru_cache

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, logging


# CodeT5+ 770M seq2seq checkpoint used for code generation.
checkpoint = "Salesforce/codet5p-770m"

# Load tokenizer and model with the SAME options: previously only the
# tokenizer had trust_remote_code and only the model had cache_dir, so the
# two halves of one checkpoint were fetched with different settings and
# cached in different locations.
tokenizer = AutoTokenizer.from_pretrained(
    checkpoint, trust_remote_code=True, cache_dir="models/")
model = AutoModelForSeq2SeqLM.from_pretrained(
    checkpoint, trust_remote_code=True, cache_dir="models/")


@lru_cache(maxsize=1)
def _get_pipeline():
    """Build the generation pipeline once and reuse it for every request.

    The original code constructed a new pipeline on every Gradio call and
    passed the checkpoint *name*, forcing transformers to reload the 770M
    model each time while the module-level ``model``/``tokenizer`` objects
    went unused. Caching via ``lru_cache`` makes the load a one-time cost.
    """
    return pipeline(
        "text2text-generation",   # explicit task for a seq2seq model
        model=model,              # reuse the objects loaded at import time
        tokenizer=tokenizer,
        max_new_tokens=124,
        # do_sample=True is required: without it generation is greedy and
        # temperature / top_p are silently ignored.
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        repetition_penalty=1.15,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )


def code_gen(text):
    """Generate code from *text* with the CodeT5+ pipeline.

    Parameters
    ----------
    text : str
        The source-code prompt entered in the Gradio textbox.

    Returns
    -------
    str
        The generated text of the first (and only) pipeline result.
    """
    # Silence transformers' warning chatter in the serving path.
    logging.set_verbosity(logging.CRITICAL)

    response = _get_pipeline()(text)
    print(response)

    return response[0]['generated_text']


# Build and launch the Gradio UI. ``gr.inputs.Textbox`` is the deprecated
# Gradio 2.x namespace (removed in Gradio 4.x); ``gr.Textbox`` is the
# supported component class and behaves identically here.
iface = gr.Interface(fn=code_gen,
                     inputs=gr.Textbox(
                         label="Input Source Code"),
                     outputs="text",
                     title="Code Generation")

iface.launch()