# TharunSivamani's picture
# final code
# 1c954f5 verified
import os
import subprocess
import warnings

import gradio as gr
import torch

from model import GPTLanguageModel, decode, context
# Clone the repository if not already cloned; later runs reuse the local copy.
REPO_URL = "https://huggingface.co/TharunSivamani/tiny-shakespeare"
REPO_NAME = "tiny-shakespeare"
if not os.path.exists(REPO_NAME):
    # check=True turns a failed clone into an immediate CalledProcessError
    # instead of a later, less obvious failure when the weights are read.
    # The destination is passed explicitly so the checkout always lands in
    # the directory that the existence test above looks for.
    subprocess.run(["git", "clone", REPO_URL, REPO_NAME], check=True)
# Set the device: use the GPU when one is available, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model and move it to the selected device.
model = GPTLanguageModel().to(DEVICE)

# NOTE(review): torch.load unpickles the checkpoint, and this file comes from
# a cloned remote repository. Consider passing weights_only=True if the
# installed torch version supports it.
_load_result = model.load_state_dict(
    torch.load(f"{REPO_NAME}/model.pth", map_location=DEVICE),
    strict=False,
)

# strict=False tolerates key mismatches; report them so a checkpoint that does
# not line up with the model definition is not accepted silently.
if _load_result.missing_keys or _load_result.unexpected_keys:
    warnings.warn(
        "Checkpoint does not fully match the model: "
        f"missing keys {list(_load_result.missing_keys)}, "
        f"unexpected keys {list(_load_result.unexpected_keys)}"
    )

# Switch to evaluation mode before serving requests.
model.eval()
# Define the display function
def display(text, number):
    """Generate text with the model and append it to the given story line.

    Args:
        text: Story line entered by the user. It is only prepended to the
            result; generation always starts from the imported ``context``.
        number: Maximum number of new tokens to generate.

    Returns:
        ``text``, followed by a space and a newline, followed by the decoded
        generated text.
    """
    # NOTE(review): the UI slider may deliver its value as a float (e.g.
    # 300.0) - confirm; a token count has to be an integer either way.
    max_new_tokens = int(number)

    # Inference only: no gradients are needed while sampling.
    with torch.no_grad():
        generated = model.generate(context, max_new_tokens=max_new_tokens)

    answer = decode(generated[0].tolist())
    return text + " \n" + answer
# Gradio app interface: a text box and a slider feed `display`, and the
# returned string is shown in a second text box.
input_box = gr.Textbox(value="Once Upon a Time", label="Story Lines")
input_slider = gr.Slider(
    label="Select the maximum number of tokens/words:",
    minimum=200,
    maximum=500,
    step=100,
)
output_text = gr.Textbox()

demo = gr.Interface(
    fn=display,
    inputs=[input_box, input_slider],
    outputs=output_text,
    examples=[
        ["Shakespeare Once Said", 500],
        ["A Long Time Ago", 300],
    ],
)
demo.launch()