import os
import subprocess

import gradio as gr
import torch

from model import GPTLanguageModel, decode, context

REPO_URL = "https://huggingface.co/TharunSivamani/tiny-shakespeare"
REPO_NAME = "tiny-shakespeare"
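
# Clone the checkpoint repository from the Hugging Face Hub on first run.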
if not os.path.exists(REPO_NAME):
    subprocess.run(["git", "clone", REPO_URL])
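
# Pick the GPU when available, otherwise fall back to the CPU.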
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
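
# Instantiate the model, load the trained weights, and switch to inference mode.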
model = GPTLanguageModel().to(DEVICE)
model.load_state_dict(torch.load(f"{REPO_NAME}/model.pth", map_location=DEVICE), strict=False)
model.eval()
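

# Gradio callback: generate `number` new tokens from the imported starting
# `context`, decode them, and append the result to the user's prompt text.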
def display(text, number):
    answer = decode(model.generate(context, max_new_tokens=int(number))[0].tolist())
    return text + " \n" + answer
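
# UI widgets: story prompt, token budget slider, and generated-text output box.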
input_box = gr.Textbox(label="Story Lines", value="Once Upon a Time")
input_slider = gr.Slider(
    minimum=200, maximum=500, label="Select the maximum number of tokens/words:", step=100
)
output_text = gr.Textbox()
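
# Wire the callback and widgets into a Gradio Interface and launch the app.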
gr.Interface(
    fn=display,
    inputs=[input_box, input_slider],
    outputs=output_text,
    examples=[["Shakespeare Once Said", 500], ["A Long Time Ago", 300]],
).launch()