import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face Hub id of the checkpoint served by this demo.
# Alternatives previously left here as commented-out assignments:
#   "Salesforce/codegen25-7b-instruct"
#   "Salesforce/codegen-2B-nl"
checkpoint = "Salesforce/codegen2-1B"

# The tokenizer and model are loaded once at import time; code_gen reads both
# as module-level globals on every call.
# An earlier variant passed cache_dir="models/" to the model's from_pretrained.
# NOTE(review): trust_remote_code=True is set for the tokenizer only, not for
# the model -- confirm against the checkpoint's model card that this asymmetry
# is intended.
tokenizer = AutoTokenizer.from_pretrained(
    checkpoint,
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
)

def code_gen(text):
    """Generate a completion for ``text`` with the module-level model.

    Args:
        text: Prompt string submitted through the UI.

    Returns:
        The decoded output of ``model.generate`` with special tokens
        stripped (for a causal LM this is the prompt followed by the
        continuation), or ``""`` when ``text`` is empty or ``None``.
    """
    # An empty textbox would tokenize to a zero-length prompt; there is
    # nothing to complete, so skip the tokenizer and model entirely.
    if not text:
        return ""

    encoded = tokenizer(text, return_tensors="pt")
    generated_ids = model.generate(
        encoded.input_ids,
        # Forward the tokenizer's mask so generate() does not have to infer it.
        attention_mask=encoded.get("attention_mask"),
        # Budget counts newly generated tokens only; max_length would also
        # count the prompt and leave no room once the prompt nears 128 tokens.
        max_new_tokens=128,
    )
    response = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    # Echo the result to the server console as well as returning it to the UI.
    print(response)
    return response


# Single-textbox web UI: the submitted text is passed to code_gen and the
# string it returns is shown as plain text.
# gr.Textbox is used instead of the legacy gr.inputs.Textbox, which was
# deprecated in Gradio 3 and removed in Gradio 4.
iface = gr.Interface(
    fn=code_gen,
    inputs=gr.Textbox(label="Input Source Code"),
    outputs="text",
    title="Code Generation",
)

# Start the Gradio server with its default settings.
iface.launch()