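"""Gradio demo app for GPT text generation.

Loads a trained checkpoint ("best_model.pt") via model_utils.load_model and
serves an interactive web UI around generate_text(). Run with `python app.py`.
"""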
import os

import gradio as gr
import torch

from model_utils import load_model, generate_text, GPTConfig

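# Load the model once at import time so every request reuses it.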
try:
    model_path = "best_model.pt"
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file {model_path} not found")

    model = load_model(model_path)
    if model is None:
        raise ValueError("Model failed to load")

    # Prefer the GPU when available; otherwise fall back to CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    model = model.to(device)
except Exception as e:
    # Keep the app importable even if loading fails; predict() reports the error.
    print(f"Error loading model: {e}")
    model = None


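# Gradio calls this once per request, passing the widget values in order.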
def predict(prompt, max_tokens, temperature, top_k):
    """Wrapper around generate_text() for the Gradio interface."""
    if not prompt or not prompt.strip():
        return "Error: Please enter a prompt"
    if model is None:
        return "Error: Model failed to load. Please check the logs."

    try:
        # Sliders can deliver floats, so cast the integer-valued parameters.
        generated_text = generate_text(
            model=model,
            prompt=prompt.strip(),
            max_new_tokens=int(max_tokens),
            temperature=float(temperature),
            top_k=int(top_k),
        )
        return generated_text
    except Exception as e:
        return f"Error during generation: {str(e)}"


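# Build the UI: a text prompt plus sliders for the sampling parameters.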
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Enter your prompt", lines=3, placeholder="Type your text here..."),
        gr.Slider(minimum=1, maximum=200, value=50, step=1, label="Max Tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, step=0.1, label="Temperature"),
        gr.Slider(minimum=1, maximum=100, value=40, step=1, label="Top-k"),
    ],
    outputs=gr.Textbox(label="Generated Text", lines=5),
    title="GPT Text Generation",
    description="Enter a text prompt and the model will generate a continuation.",
    examples=[
        ["The quick brown fox", 50, 0.8, 40],
        ["Once upon a time", 100, 0.9, 50],
        ["In the distant future", 75, 0.7, 30],
    ],
)

if __name__ == "__main__":
    # Run standalone: python app.py
    demo.launch(share=False)
else:
    # When imported (e.g. by a hosting platform), launch() returns a
    # (app, local_url, share_url) tuple; keep only the server app and
    # avoid blocking the importing thread.
    app, _, _ = demo.launch(share=False, prevent_thread_lock=True)