"""Gradio app that serves text generation from a tiny GPT trained on Shakespeare.

Downloads pretrained weights from the Hugging Face Hub (via git clone) on first
run, loads them into ``GPTLanguageModel``, and exposes a simple Gradio UI.
"""

import os
import subprocess

import gradio as gr
import torch

from model import GPTLanguageModel, decode, context

# Location of the pretrained weights repository on the Hugging Face Hub.
REPO_URL = "https://huggingface.co/TharunSivamani/tiny-shakespeare"
REPO_NAME = "tiny-shakespeare"

# Clone the weights repository once. check=True makes a failed clone raise
# immediately instead of surfacing later as a confusing FileNotFoundError
# when torch.load looks for model.pth.
if not os.path.exists(REPO_NAME):
    subprocess.run(["git", "clone", REPO_URL], check=True)

# Run on GPU when available, otherwise CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pretrained model and switch to inference mode.
# NOTE(review): strict=False silently ignores missing/unexpected checkpoint
# keys — confirm the checkpoint actually matches GPTLanguageModel's layout.
model = GPTLanguageModel().to(DEVICE)
model.load_state_dict(
    torch.load(f"{REPO_NAME}/model.pth", map_location=DEVICE),
    strict=False,
)
model.eval()


def display(text, number):
    """Generate up to ``number`` tokens and append them to ``text``.

    Args:
        text: The user-supplied opening line. It is shown verbatim before the
            generated sample; it does NOT condition generation — the model
            samples from the fixed ``context`` imported from ``model``.
        number: Maximum number of new tokens to generate.

    Returns:
        ``text`` followed by a newline and the decoded generated sample.
    """
    answer = decode(model.generate(context, max_new_tokens=number)[0].tolist())
    return text + " \n" + answer


if __name__ == "__main__":
    # Build and launch the Gradio UI only when run as a script, so importing
    # this module (e.g. to reuse `display`) does not start a web server.
    input_box = gr.Textbox(label="Story Lines", value="Once Upon a Time")
    input_slider = gr.Slider(
        minimum=200,
        maximum=500,
        step=100,
        label="Select the maximum number of tokens/words:",
    )
    output_text = gr.Textbox()

    gr.Interface(
        fn=display,
        inputs=[input_box, input_slider],
        outputs=output_text,
        examples=[["Shakespeare Once Said", 500], ["A Long Time Ago", 300]],
    ).launch()