import torch
import torch.nn as nn
import re
import gradio as gr

# ====== DEVICE ======
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ====== MODEL ======
class TextPredictor(nn.Module):
    """Feed-forward next-word predictor: embed a fixed-length window of
    token ids, flatten it, and map it to vocabulary logits."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, seq_len):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.fc1 = nn.Linear(embed_dim * seq_len, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        x = self.embedding(x)          # (batch, seq_len, embed_dim)
        x = x.view(x.size(0), -1)      # flatten to (batch, seq_len * embed_dim)
        x = self.fc1(x)
        x = self.relu(x)
        return self.fc2(x)             # (batch, vocab_size) logits

# ====== LOAD SAVED MODEL ======
checkpoint = torch.load("text_predictor_model.pth", map_location=device, weights_only=False)
vocab = checkpoint['vocab']
word2idx = checkpoint['word2idx']
idx2word = checkpoint['idx2word']
SEQ_LEN = checkpoint['seq_len']
EMBED_DIM = 64
HIDDEN_DIM = 128

model = TextPredictor(len(vocab), EMBED_DIM, HIDDEN_DIM, SEQ_LEN).to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# ====== TEXT GENERATION ======
def generate_text(prompt, length=60, temperature=1.0):
    # Normalize the prompt: lowercase and keep only the characters the model
    # was trained on (letters, digits, whitespace, and <artist>-style tags).
    prompt = prompt.lower() if prompt else ''
    prompt = re.sub(r'[^a-z0-9\s<>/]', '', prompt)
    words = prompt.split()

    # Left-pad short prompts so the context window is always SEQ_LEN tokens.
    if len(words) < SEQ_LEN:
        words = [''] * (SEQ_LEN - len(words)) + words

    current_seq = [word2idx.get(w, 0) for w in words[-SEQ_LEN:]]
    generated = words.copy()

    for _ in range(int(length)):
        x = torch.tensor([current_seq], dtype=torch.long, device=device)
        with torch.no_grad():
            logits = model(x)
            probs = torch.softmax(logits / temperature, dim=1).squeeze()

        # Sample the next word from the temperature-scaled distribution,
        # then slide the context window forward by one token.
        next_idx = torch.multinomial(probs, num_samples=1).item()
        generated.append(idx2word[next_idx])
        current_seq = current_seq[1:] + [next_idx]

    # Drop the empty padding tokens before joining the output.
    return ' '.join(w for w in generated if w)

# ====== GRADIO INTERFACE ======
def gradio_generate(artist, user_prompt, length, temperature):
    artist = artist.strip().lower()
    user_prompt = user_prompt.strip().lower()

    # Prepend an <artist> tag so the model conditions on the chosen artist.
    artist_tag = f"<{artist}>" if artist else ""
    full_prompt = f"{artist_tag} {user_prompt}" if user_prompt else artist_tag

    # Slider values may arrive as floats, so cast the length to int.
    return generate_text(full_prompt, length=int(length), temperature=temperature)

with gr.Blocks() as demo:
    gr.Markdown("## 🎤 AI Lyrics Generator")
    with gr.Row():
        artist_input = gr.Textbox(label="Artist (leave empty for )")
        prompt_input = gr.Textbox(label="Prompt (optional)")
    with gr.Row():
        length_slider = gr.Slider(10, 200, value=60, label="Text Length")
        temp_slider = gr.Slider(0.5, 1.5, value=1.0, step=0.1,
                                label="Temperature (0.5 = more predictable, 1.5 = more creative)")
    generate_button = gr.Button("🎙️ Generate")
    output_box = gr.Textbox(lines=20, label="Generated Lyrics")

    generate_button.click(fn=gradio_generate,
                          inputs=[artist_input, prompt_input, length_slider, temp_slider],
                          outputs=output_box)

# ====== LAUNCH ======
if __name__ == "__main__":
    demo.launch()