# Author: Leo71288
# Last commit: 7c61e0b ("Update app.py", verified)
import torch
import torch.nn as nn
import re
import gradio as gr
# ====== DEVICE ======
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# ====== MODEL ======
class TextPredictor(nn.Module):
    """Fixed-window next-word classifier.

    Embeds a window of ``seq_len`` token indices, concatenates the embeddings
    into one flat vector, and maps it through a single ReLU hidden layer to
    one logit per vocabulary entry.

    Args:
        vocab_size: Number of tokens (embedding rows and output width).
        embed_dim: Width of each token embedding.
        hidden_dim: Width of the hidden layer.
        seq_len: Tokens per input window; fixes the input width of ``fc1``.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, seq_len):
        super().__init__()
        # embedding / fc1 / fc2 become state_dict keys of saved checkpoints,
        # so these attribute names must not change.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        flat_width = embed_dim * seq_len
        self.fc1 = nn.Linear(flat_width, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Return logits of shape ``(batch, vocab_size)`` for index tensor ``x``.

        ``x`` is expected to hold ``seq_len`` integer indices per batch row
        (typically ``(batch, seq_len)``); otherwise the flattened width will
        not match ``fc1``.
        """
        embedded = self.embedding(x)
        flat = embedded.view(embedded.size(0), -1)
        hidden = self.relu(self.fc1(flat))
        return self.fc2(hidden)
# ====== LOAD SAVED MODEL ======
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code. That is only acceptable because this
# checkpoint ships with the app -- never load an untrusted file this way.
checkpoint = torch.load("text_predictor_model.pth", map_location=device, weights_only=False)
vocab = checkpoint['vocab']
word2idx = checkpoint['word2idx']
idx2word = checkpoint['idx2word']
SEQ_LEN = checkpoint['seq_len']

_state_dict = checkpoint['model_state_dict']
# Take the layer widths from the saved weights instead of hard-coded numbers
# that must be kept in sync with the training script by hand. For
# TextPredictor, embedding.weight is (vocab_size, embed_dim) and fc1.weight is
# (hidden_dim, embed_dim * seq_len). The original literals remain as fallbacks
# when those keys are missing.
EMBED_DIM = _state_dict['embedding.weight'].shape[1] if 'embedding.weight' in _state_dict else 64
HIDDEN_DIM = _state_dict['fc1.weight'].shape[0] if 'fc1.weight' in _state_dict else 128

model = TextPredictor(len(vocab), EMBED_DIM, HIDDEN_DIM, SEQ_LEN).to(device)
model.load_state_dict(_state_dict)
model.eval()
# ====== TEXT GENERATION ======
def generate_text(prompt, length=60, temperature=1.0):
    """Sample a continuation of ``prompt`` from the module-level ``model``.

    The prompt is lower-cased and stripped of every character outside
    ``[a-z0-9\\s<>/]``, then split on whitespace into words. The last
    ``SEQ_LEN`` words (left-padded with ``''`` when the prompt is shorter)
    form the context window. Each step picks one word index from the
    temperature-scaled softmax and slides the window forward by one.

    Args:
        prompt: Seed text; ``None`` or ``""`` means no prompt.
        length: Number of words to generate (coerced to ``int``, because UI
            sliders may deliver floats).
        temperature: Softmax temperature. Values ``<= 0`` pick the most likely
            word greedily instead of sampling, which avoids dividing by zero.

    Returns:
        The cleaned prompt words followed by the generated words, joined with
        single spaces.
    """
    prompt = prompt.lower() if prompt else ''
    prompt = re.sub(r'[^a-z0-9\s<>/]', '', prompt)
    prompt_words = prompt.split()

    # Left-pad with '' so the context window is always SEQ_LEN long. Words not
    # in word2idx (including '' unless the vocabulary has it) map to index 0.
    # NOTE(review): confirm index 0 is a sensible "unknown" token for this checkpoint.
    context = [''] * max(0, SEQ_LEN - len(prompt_words)) + prompt_words
    current_seq = [word2idx.get(w, 0) for w in context[-SEQ_LEN:]]

    # Start the output from the real prompt words only. Previously the ''
    # padding was copied here and showed up as leading spaces in the result.
    generated = list(prompt_words)
    with torch.no_grad():
        for _ in range(int(length)):
            x = torch.tensor([current_seq], dtype=torch.long, device=device)
            logits = model(x)
            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=1).squeeze()
                next_idx = torch.multinomial(probs, num_samples=1).item()
            else:
                next_idx = torch.argmax(logits, dim=1).item()
            generated.append(idx2word[next_idx])
            current_seq = current_seq[1:] + [next_idx]
    # Skip empty tokens (possible if the vocabulary contains '') so they do not
    # leave double spaces in the output.
    return ' '.join(w for w in generated if w)
# ====== GRADIO INTERFACE ======
def gradio_generate(artist, user_prompt, length, temperature):
    """Build an artist-tagged prompt from the UI fields and return generated lyrics.

    Args:
        artist: Artist name, wrapped as ``<artist>``. Blank or ``None`` falls
            back to the ``<None>`` tag (lower-cased later by ``generate_text``).
        user_prompt: Optional seed text appended after the artist tag.
        length: Number of words to generate; coerced to ``int`` because a
            slider value may arrive as a float.
        temperature: Sampling temperature forwarded to ``generate_text``.

    Returns:
        The generated text.
    """
    # Treat None like an empty field (e.g. a null value from an API client).
    artist = (artist or "").strip().lower()
    user_prompt = (user_prompt or "").strip().lower()
    # NOTE(review): a multi-word artist ("taylor swift") becomes "<taylor swift>",
    # which is split on whitespace into two words downstream; confirm this
    # matches the tag format used in training.
    artist_tag = f"<{artist}>" if artist else "<None>"
    full_prompt = f"{artist_tag} {user_prompt}" if user_prompt else artist_tag
    return generate_text(full_prompt, length=int(length), temperature=temperature)
with gr.Blocks() as demo:
    gr.Markdown("## 🎀 AI Lyrics Generator")
    with gr.Row():
        artist_input = gr.Textbox(label="Artist (leave empty for <None>)")
        prompt_input = gr.Textbox(label="Prompt (optional)")
    with gr.Row():
        # Whole-word counts only: the value is used as a loop count.
        length_slider = gr.Slider(10, 200, value=60, step=1, label="Text Length")
        temp_slider = gr.Slider(0.5, 1.5, value=1.0, step=0.1, label="Temperature (0.5 = more predictable, 1.5 = more creative)")
    # The label was stored as mis-decoded UTF-8 ("πŸŽ™οΈ"); restored to the 🎙️ emoji.
    generate_button = gr.Button("🎙️ Generate")
    output_box = gr.Textbox(lines=20, label="Generated Lyrics")
    generate_button.click(fn=gradio_generate,
                          inputs=[artist_input, prompt_input, length_slider, temp_slider],
                          outputs=output_box)
# ====== LAUNCH ======
def _main():
    """Start the Gradio server for ``demo``."""
    demo.launch()


# Only launch when executed as a script, not when imported.
if __name__ == "__main__":
    _main()