Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
import re | |
import gradio as gr | |
# ====== DEVICE ======
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# ====== MODEL ====== | |
class TextPredictor(nn.Module):
    """Feed-forward next-word predictor over a fixed-length window of token ids.

    Each of the ``seq_len`` input ids is embedded, the embeddings are
    flattened into one vector per batch row, and that vector goes through a
    single ReLU hidden layer before being projected to ``vocab_size`` logits.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, seq_len):
        """Build the embedding table and the two linear layers.

        Args:
            vocab_size: Number of embedding rows and of output logits.
            embed_dim: Size of each token embedding.
            hidden_dim: Width of the hidden layer.
            seq_len: Number of token ids expected per input row.
        """
        super().__init__()
        # The first linear layer consumes all seq_len embeddings side by side.
        flattened_dim = embed_dim * seq_len
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.fc1 = nn.Linear(flattened_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Return one row of vocabulary logits per input row of token ids."""
        embedded = self.embedding(x)
        # Keep the leading (batch) dimension, merge everything else.
        flat = embedded.view(embedded.size(0), -1)
        hidden = self.relu(self.fc1(flat))
        return self.fc2(hidden)
# ====== LOAD SAVED MODEL ======
# SECURITY NOTE(review): weights_only=False lets torch.load unpickle arbitrary
# Python objects, which can execute code. It is needed here because the
# checkpoint stores plain Python vocab structures next to the weights; only
# ever load a checkpoint file you trust.
checkpoint = torch.load("text_predictor_model.pth", map_location=device, weights_only=False)
vocab = checkpoint['vocab']
word2idx = checkpoint['word2idx']
idx2word = checkpoint['idx2word']
SEQ_LEN = checkpoint['seq_len']

# Read the layer sizes from the saved weights instead of hard-coding them, so
# the rebuilt network always matches the checkpoint:
#   embedding.weight has shape (vocab_size, embed_dim)
#   fc1.weight       has shape (hidden_dim, embed_dim * seq_len)
_state_dict = checkpoint['model_state_dict']
EMBED_DIM = _state_dict['embedding.weight'].shape[1]
HIDDEN_DIM = _state_dict['fc1.weight'].shape[0]

model = TextPredictor(len(vocab), EMBED_DIM, HIDDEN_DIM, SEQ_LEN).to(device)
model.load_state_dict(_state_dict)
# Inference only: switch the module out of training mode.
model.eval()
# ====== TEXT GENERATION ====== | |
def generate_text(prompt, length=60, temperature=1.0):
    """Continue ``prompt`` by sampling ``length`` words from the model.

    The prompt is lower-cased, every character other than ``a-z``, ``0-9``,
    whitespace, ``<``, ``>`` and ``/`` is removed, and the result is split on
    whitespace. The last ``SEQ_LEN`` words form the model's initial context
    window; a shorter prompt is left-padded with empty-string tokens. Any
    word missing from ``word2idx`` (padding included) is encoded as index 0.

    Args:
        prompt: Seed text. ``None`` or ``""`` is treated as an empty prompt.
        length: Number of words to generate. Coerced with ``int()`` so a
            float coming from a UI slider is accepted.
        temperature: Divisor applied to the logits before the softmax.
            Must be strictly positive.

    Returns:
        The cleaned prompt words followed by the generated words, joined by
        single spaces. Padding tokens are not part of the returned text.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    if temperature <= 0:
        # Zero would divide the logits by 0 and a negative value would invert
        # the distribution; fail loudly instead of sampling garbage.
        raise ValueError(f"temperature must be > 0, got {temperature!r}")
    num_words = int(length)

    cleaned = re.sub(r'[^a-z0-9\s<>/]', '', prompt.lower() if prompt else '')
    prompt_words = cleaned.split()

    # Left-pad so the model always receives exactly SEQ_LEN token ids.
    padding = [''] * max(SEQ_LEN - len(prompt_words), 0)
    window = (padding + prompt_words)[-SEQ_LEN:]
    current_seq = [word2idx.get(word, 0) for word in window]

    # Start the output from the real prompt words only; including the ''
    # padding here would put spurious leading spaces in the joined text.
    generated = list(prompt_words)

    with torch.no_grad():
        for _ in range(num_words):
            x = torch.tensor([current_seq], dtype=torch.long, device=device)
            logits = model(x)
            # squeeze(0) drops only the batch dimension, so the result stays
            # 1-D even for a single-entry vocabulary.
            probs = torch.softmax(logits / temperature, dim=1).squeeze(0)
            next_idx = torch.multinomial(probs, num_samples=1).item()
            generated.append(idx2word[next_idx])
            # Slide the context window forward by one token.
            current_seq = current_seq[1:] + [next_idx]

    return ' '.join(generated)
# ====== GRADIO INTERFACE ====== | |
def gradio_generate(artist, user_prompt, length, temperature):
    """Click handler: build the artist-tagged prompt and generate lyrics.

    Args:
        artist: Artist name from the UI; ``None`` or blank selects the
            ``<None>`` tag.
        user_prompt: Optional seed text from the UI; ``None`` or blank means
            the prompt consists of the artist tag alone.
        length: Number of words to generate; coerced with ``int()`` because
            a slider may deliver a float.
        temperature: Sampling temperature, passed through unchanged.

    Returns:
        The text returned by ``generate_text`` for the assembled prompt.
    """
    # Treat a missing value like an empty textbox instead of crashing on
    # None.strip().
    artist = (artist or "").strip().lower()
    user_prompt = (user_prompt or "").strip().lower()

    # NOTE: generate_text lower-cases its prompt, so "<None>" reaches the
    # model as "<none>".
    artist_tag = f"<{artist}>" if artist else "<None>"
    full_prompt = f"{artist_tag} {user_prompt}" if user_prompt else artist_tag

    return generate_text(full_prompt, length=int(length), temperature=temperature)
# NOTE(review): the original indentation was lost, so the Row nesting below is
# reconstructed (text inputs in one row, sliders in the next, button and
# output underneath) - confirm against the deployed layout.
# NOTE(review): the two emoji were mis-decoded in the source text and have
# been restored on a best-guess basis - confirm the intended glyphs.
with gr.Blocks() as demo:
    gr.Markdown("## 🎤 AI Lyrics Generator")

    with gr.Row():
        artist_input = gr.Textbox(label="Artist (leave empty for <None>)")
        prompt_input = gr.Textbox(label="Prompt (optional)")

    with gr.Row():
        # step=1 keeps the requested word count a whole number.
        length_slider = gr.Slider(10, 200, value=60, step=1, label="Text Length")
        temp_slider = gr.Slider(0.5, 1.5, value=1.0, step=0.1, label="Temperature (0.5 = more predictable, 1.5 = more creative)")

    generate_button = gr.Button("🎙️ Generate")
    output_box = gr.Textbox(lines=20, label="Generated Lyrics")

    # Input order must match gradio_generate's positional parameters.
    generate_button.click(
        fn=gradio_generate,
        inputs=[artist_input, prompt_input, length_slider, temp_slider],
        outputs=output_box,
    )

# ====== LAUNCH ======
# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()