""" # Melody Generation Model Development # Project: Opentunes.ai This notebook implements a Transformer-based melody generation model. The model takes text prompts and generates musical melodies in MIDI format. Key Features: - Text-to-melody generation - MIDI file handling - Transformer architecture - Training pipeline integration with HuggingFace Note: This is a starting point and might need adjustments based on: - Specific musical requirements - Available training data - Computational resources - Desired output format """ import torch import torch.nn as nn from transformers import ( AutoModelForAudio, AutoTokenizer, Trainer, TrainingArguments ) import librosa import numpy as np import pandas as pd import music21 from pathlib import Path import json import wandb # for experiment tracking # ===================================== # 1. Data Loading and Preprocessing # ===================================== class MelodyDataset(torch.utils.data.Dataset): """ Custom Dataset class for handling melody data. This class: - Loads MIDI files from a directory - Converts MIDI files to sequences of notes and durations - Provides data in format suitable for model training Args: data_dir (str): Directory containing MIDI files max_length (int): Maximum sequence length (default: 512) Features: - Handles variable-length MIDI files - Converts complex MIDI structures to simple note sequences - Implements efficient data loading and preprocessing """ def __init__(self, data_dir, max_length=512): self.data_dir = Path(data_dir) self.max_length = max_length self.midi_files = list(self.data_dir.glob("*.mid")) # Initialize tokenizer for text prompts self.tokenizer = AutoTokenizer.from_pretrained("t5-small") print(f"Found {len(self.midi_files)} MIDI files in {data_dir}") def midi_to_sequence(self, midi_path): """ Convert MIDI file to sequence of notes. Args: midi_path (Path): Path to MIDI file Returns: list: List of dictionaries containing note information Each dict has 'pitch', 'duration', and 'offset' Example output: [ {'pitch': 60, 'duration': 1.0, 'offset': 0.0}, # Middle C, quarter note {'pitch': 64, 'duration': 0.5, 'offset': 1.0}, # E, eighth note ... ] """ score = music21.converter.parse(str(midi_path)) notes = [] # Extract notes and their properties for n in score.flat.notesAndRests: if isinstance(n, music21.note.Note): notes.append({ 'pitch': n.pitch.midi, # MIDI pitch number (0-127) 'duration': n.duration.quarterLength, # Duration in quarter notes 'offset': n.offset # Start time in quarter notes }) return notes def __getitem__(self, idx): """ Get a single item from the dataset. Args: idx (int): Index of the item Returns: dict: Dictionary containing: - 'notes': Tensor of note pitches - 'durations': Tensor of note durations Note: Both tensors are padded/truncated to max_length """ midi_file = self.midi_files[idx] melody_sequence = self.midi_to_sequence(midi_file) # Convert to tensors with padding/truncation notes = torch.tensor([n['pitch'] for n in melody_sequence]) durations = torch.tensor([n['duration'] for n in melody_sequence]) # Pad or truncate sequences if len(notes) < self.max_length: # Pad with rest values pad_length = self.max_length - len(notes) notes = torch.cat([notes, torch.zeros(pad_length)]) durations = torch.cat([durations, torch.zeros(pad_length)]) else: # Truncate to max_length notes = notes[:self.max_length] durations = durations[:self.max_length] return { 'notes': notes, 'durations': durations, } def __len__(self): return len(self.midi_files) # ===================================== # 2. 
# =====================================
# 2. Model Architecture Development
# =====================================

class MelodyTransformer(nn.Module):
    """
    Transformer-based model for melody generation.

    Architecture Overview:
    1. Embedding layers for notes, durations, and positions
    2. Transformer encoder for sequence processing
    3. Separate prediction heads for notes and durations

    Args:
        num_notes (int): Size of note vocabulary (default: 128 for MIDI range)
        max_duration (int): Number of possible duration values (default: 32)
        d_model (int): Dimension of the model (default: 512)
        nhead (int): Number of attention heads (default: 8)
        num_layers (int): Number of transformer layers (default: 6)

    Forward Pass:
        - Input: note sequence, duration sequence, position indices
        - Output: predictions for next note and duration
    """

    def __init__(self,
                 num_notes=128,    # MIDI note range (0-127)
                 max_duration=32,  # Quantized duration values
                 d_model=512,      # Model dimension (as in original Transformer)
                 nhead=8,          # Multi-head attention
                 num_layers=6):    # Number of Transformer layers
        super().__init__()

        # Embedding layers
        self.note_embedding = nn.Embedding(
            num_embeddings=num_notes,
            embedding_dim=d_model,
            padding_idx=0  # Use 0 for padding
        )
        self.duration_embedding = nn.Embedding(
            num_embeddings=max_duration,
            embedding_dim=d_model,
            padding_idx=0
        )
        self.position_embedding = nn.Embedding(
            num_embeddings=1024,  # Maximum sequence length
            embedding_dim=d_model
        )

        # Transformer architecture
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=4 * d_model,  # As per original Transformer paper
            dropout=0.1,
            activation='gelu'  # Modern activation function
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=num_layers,
            norm=nn.LayerNorm(d_model)
        )

        # Output heads
        self.note_head = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_model, num_notes)
        )
        self.duration_head = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_model, max_duration)
        )

    def forward(self, notes, durations, positions):
        """
        Forward pass through the model.

        Args:
            notes (torch.Tensor): Shape [batch_size, seq_length]
                Contains MIDI note numbers
            durations (torch.Tensor): Shape [batch_size, seq_length]
                Contains quantized duration values
            positions (torch.Tensor): Shape [batch_size, seq_length]
                Contains position indices

        Returns:
            tuple: (note_logits, duration_logits)
                - note_logits: Shape [batch_size, seq_length, num_notes]
                - duration_logits: Shape [batch_size, seq_length, max_duration]

        Note: The model predicts both the next note and its duration
        simultaneously, allowing for coherent melody generation.
        """
        # Get embeddings for each component
        note_emb = self.note_embedding(notes)               # [B, S, D]
        duration_emb = self.duration_embedding(durations)   # [B, S, D]
        pos_emb = self.position_embedding(positions)        # [B, S, D]

        # Combine embeddings by summing, as in the original Transformer paper
        x = note_emb + duration_emb + pos_emb  # [B, S, D]

        # Apply Transformer
        # Note: nn.TransformerEncoder expects [S, B, D] by default
        x = x.transpose(0, 1)
        x = self.transformer(x)
        x = x.transpose(0, 1)  # Back to [B, S, D]

        # Generate predictions
        note_logits = self.note_head(x)          # [B, S, num_notes]
        duration_logits = self.duration_head(x)  # [B, S, max_duration]

        return note_logits, duration_logits

    def generate(self, prompt, max_length=512, temperature=1.0):
        """
        Generate a melody from a starting prompt.

        Args:
            prompt (dict): Initial notes and durations
            max_length (int): Maximum sequence length to generate
            temperature (float): Sampling temperature (higher = more random)

        Returns:
            tuple: (generated_notes, generated_durations)

        Example:
            >>> model = MelodyTransformer()
            >>> prompt = {'notes': [60, 64, 67], 'durations': [4, 4, 4]}  # quarter-note bins
            >>> notes, durations = model.generate(prompt)
        """
        self.eval()  # Set to evaluation mode
        device = next(self.parameters()).device

        with torch.no_grad():
            # Initialize with prompt
            current_notes = torch.tensor(prompt['notes'], dtype=torch.long, device=device).unsqueeze(0)
            current_durations = torch.tensor(prompt['durations'], dtype=torch.long, device=device).unsqueeze(0)

            generated_notes = list(prompt['notes'])
            generated_durations = list(prompt['durations'])

            # Generate one note at a time
            for i in range(len(prompt['notes']), max_length):
                # Create position tensor
                positions = torch.arange(len(generated_notes), device=device).unsqueeze(0)

                # Get predictions
                note_logits, duration_logits = self(
                    current_notes,
                    current_durations,
                    positions
                )

                # Sample from logits using temperature
                note_probs = F.softmax(note_logits[:, -1] / temperature, dim=-1)
                duration_probs = F.softmax(duration_logits[:, -1] / temperature, dim=-1)

                next_note = torch.multinomial(note_probs, 1)
                next_duration = torch.multinomial(duration_probs, 1)

                # Append to generated sequence
                generated_notes.append(next_note.item())
                generated_durations.append(next_duration.item())

                # Update current sequence
                current_notes = torch.tensor(generated_notes, dtype=torch.long, device=device).unsqueeze(0)
                current_durations = torch.tensor(generated_durations, dtype=torch.long, device=device).unsqueeze(0)

        return generated_notes, generated_durations
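
# --- Illustrative sketch: checking tensor shapes ----------------------------
# A quick smoke test of the forward pass with random inputs, assuming the
# default vocabulary sizes (128 notes, 32 duration bins). Useful for verifying
# the expected [batch, seq, vocab] output shapes before training.
def _demo_forward_shapes():
    model = MelodyTransformer()
    batch_size, seq_len = 2, 16
    notes = torch.randint(1, 128, (batch_size, seq_len))
    durations = torch.randint(1, 32, (batch_size, seq_len))
    positions = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)
    note_logits, duration_logits = model(notes, durations, positions)
    print(note_logits.shape)      # torch.Size([2, 16, 128])
    print(duration_logits.shape)  # torch.Size([2, 16, 32])
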
# =====================================
# 3. Training Pipeline
# =====================================

class MelodyTrainer:
    """
    Custom training pipeline for the melody generation model.

    Features:
    - Automated training loop
    - Validation monitoring
    - Checkpoint saving
    - Logging and metrics tracking

    Args:
        model (MelodyTransformer): The model to train
        config (dict): Training configuration
        device (str): Device to train on ('cuda' or 'cpu')
    """

    def __init__(self, model, config, device='cuda'):
        self.model = model.to(device)
        self.config = config
        self.device = device

        # Initialize training components
        self.criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore padding
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config.get('weight_decay', 0.01)
        )

        # Learning rate scheduler
        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=config['learning_rate'],
            epochs=config['epochs'],
            steps_per_epoch=config['steps_per_epoch']
        )

        # Initialize wandb for experiment tracking
        if config.get('use_wandb', False):
            wandb.init(
                project="opentunes-melody",
                config=config,
                name=f"melody_training_{datetime.now().strftime('%Y%m%d_%H%M')}"
            )

    def train_epoch(self, train_loader):
        """
        Train for one epoch.

        Args:
            train_loader (DataLoader): Training data loader

        Returns:
            dict: Training metrics for this epoch
        """
        self.model.train()
        epoch_loss = 0
        epoch_note_acc = 0
        epoch_dur_acc = 0
        num_batches = 0

        for batch in tqdm(train_loader, desc="Training"):
            # Move batch to device
            notes = batch['notes'].to(self.device)
            durations = batch['durations'].to(self.device)
            positions = torch.arange(notes.size(1)).unsqueeze(0).expand(
                notes.size(0), -1).to(self.device)

            # Forward pass
            note_logits, duration_logits = self.model(notes, durations, positions)

            # Calculate loss
            # Shift sequences for next-token prediction
            note_loss = self.criterion(
                note_logits[:, :-1].reshape(-1, note_logits.size(-1)),
                notes[:, 1:].reshape(-1)
            )
            duration_loss = self.criterion(
                duration_logits[:, :-1].reshape(-1, duration_logits.size(-1)),
                durations[:, 1:].reshape(-1)
            )
            loss = note_loss + duration_loss

            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            self.scheduler.step()

            # Calculate metrics
            with torch.no_grad():
                note_preds = note_logits.argmax(dim=-1)
                dur_preds = duration_logits.argmax(dim=-1)
                note_acc = (note_preds[:, :-1] == notes[:, 1:]).float().mean()
                dur_acc = (dur_preds[:, :-1] == durations[:, 1:]).float().mean()

            # Update running metrics
            epoch_loss += loss.item()
            epoch_note_acc += note_acc.item()
            epoch_dur_acc += dur_acc.item()
            num_batches += 1

            # Log batch metrics
            if self.config.get('use_wandb', False):
                wandb.log({
                    'batch_loss': loss.item(),
                    'note_accuracy': note_acc.item(),
                    'duration_accuracy': dur_acc.item(),
                    'learning_rate': self.scheduler.get_last_lr()[0]
                })

        # Calculate epoch metrics
        metrics = {
            'loss': epoch_loss / num_batches,
            'note_accuracy': epoch_note_acc / num_batches,
            'duration_accuracy': epoch_dur_acc / num_batches
        }
        return metrics

    def validate(self, val_loader):
        """
        Validate the model.

        Args:
            val_loader (DataLoader): Validation data loader

        Returns:
            dict: Validation metrics
        """
        self.model.eval()
        val_loss = 0
        val_note_acc = 0
        val_dur_acc = 0
        num_batches = 0

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validation"):
                notes = batch['notes'].to(self.device)
                durations = batch['durations'].to(self.device)
                positions = torch.arange(notes.size(1)).unsqueeze(0).expand(
                    notes.size(0), -1).to(self.device)

                # Forward pass
                note_logits, duration_logits = self.model(notes, durations, positions)

                # Calculate metrics (same shifting as in training)
                note_loss = self.criterion(
                    note_logits[:, :-1].reshape(-1, note_logits.size(-1)),
                    notes[:, 1:].reshape(-1)
                )
                duration_loss = self.criterion(
                    duration_logits[:, :-1].reshape(-1, duration_logits.size(-1)),
                    durations[:, 1:].reshape(-1)
                )
                loss = note_loss + duration_loss

                note_preds = note_logits.argmax(dim=-1)
                dur_preds = duration_logits.argmax(dim=-1)
                note_acc = (note_preds[:, :-1] == notes[:, 1:]).float().mean()
                dur_acc = (dur_preds[:, :-1] == durations[:, 1:]).float().mean()

                val_loss += loss.item()
                val_note_acc += note_acc.item()
                val_dur_acc += dur_acc.item()
                num_batches += 1

        metrics = {
            'val_loss': val_loss / num_batches,
            'val_note_accuracy': val_note_acc / num_batches,
            'val_duration_accuracy': val_dur_acc / num_batches
        }
        return metrics

    def train(self, train_loader, val_loader):
        """
        Full training loop.

        Args:
            train_loader (DataLoader): Training data loader
            val_loader (DataLoader): Validation data loader
        """
        best_val_loss = float('inf')

        for epoch in range(self.config['epochs']):
            print(f"\nEpoch {epoch + 1}/{self.config['epochs']}")

            # Training phase
            train_metrics = self.train_epoch(train_loader)
            print(f"Training metrics: {train_metrics}")

            # Validation phase
            val_metrics = self.validate(val_loader)
            print(f"Validation metrics: {val_metrics}")

            # Save checkpoint if best so far
            if val_metrics['val_loss'] < best_val_loss:
                best_val_loss = val_metrics['val_loss']
                self.save_checkpoint(
                    "models/melody-gen/weights/v0.1.0/best_model.pth",
                    epoch,
                    train_metrics,
                    val_metrics
                )

            # Log epoch metrics
            if self.config.get('use_wandb', False):
                wandb.log({
                    'epoch': epoch,
                    **train_metrics,
                    **val_metrics
                })

    def save_checkpoint(self, path, epoch, train_metrics, val_metrics):
        """
        Save model checkpoint.

        Args:
            path (str): Path to save checkpoint
            epoch (int): Current epoch
            train_metrics (dict): Training metrics
            val_metrics (dict): Validation metrics
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'train_metrics': train_metrics,
            'val_metrics': val_metrics,
            'config': self.config
        }
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        torch.save(checkpoint, path)
        print(f"Checkpoint saved to {path}")
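
# --- Illustrative sketch: a minimal training configuration ------------------
# MelodyTrainer reads 'learning_rate', 'epochs', 'steps_per_epoch', and the
# optional 'weight_decay' / 'use_wandb' keys from its config dict. The values
# below are placeholders, not tuned settings.
def _demo_trainer_setup(train_loader, val_loader):
    config = {
        'learning_rate': 3e-4,
        'epochs': 10,
        'steps_per_epoch': len(train_loader),
        'weight_decay': 0.01,
        'use_wandb': False,
    }
    trainer = MelodyTrainer(MelodyTransformer(), config, device='cpu')
    trainer.train(train_loader, val_loader)
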
# =====================================
# 4. Evaluation Functions
# =====================================

class MelodyEvaluator:
    """
    Comprehensive evaluation suite for melody generation models.

    Features:
    - Note accuracy metrics
    - Musical quality assessment
    - Style consistency checking
    - Sample generation and analysis

    Args:
        model (MelodyTransformer): Trained model to evaluate
        device (str): Device to run evaluation on

    Note: Several analysis helpers referenced below (_analyze_phrases,
    _analyze_contour, _analyze_rhythm, _analyze_tension,
    _create_music21_stream) are not implemented in this notebook yet and
    need to be supplied before the full quality metrics can be computed.
    """

    def __init__(self, model, device='cuda'):
        self.model = model.to(device)
        self.device = device
        self.model.eval()  # Set model to evaluation mode

    def evaluate_metrics(self, test_loader):
        """
        Compute quantitative metrics on the test set.

        Args:
            test_loader (DataLoader): Test data loader

        Returns:
            dict: Dictionary of evaluation metrics
        """
        metrics = {
            'note_accuracy': 0,
            'rhythm_accuracy': 0,
            'sequence_coherence': 0,
            'scale_consistency': 0
        }
        num_batches = 0

        with torch.no_grad():
            for batch in tqdm(test_loader, desc="Evaluating"):
                notes = batch['notes'].to(self.device)
                durations = batch['durations'].to(self.device)
                positions = torch.arange(notes.size(1)).unsqueeze(0).expand(
                    notes.size(0), -1).to(self.device)

                # Get model predictions
                note_logits, duration_logits = self.model(notes, durations, positions)

                # Calculate basic accuracy
                note_preds = note_logits.argmax(dim=-1)
                dur_preds = duration_logits.argmax(dim=-1)
                metrics['note_accuracy'] += (
                    note_preds[:, :-1] == notes[:, 1:]).float().mean().item()
                metrics['rhythm_accuracy'] += (
                    dur_preds[:, :-1] == durations[:, 1:]).float().mean().item()

                # Calculate musical coherence metrics
                metrics['sequence_coherence'] += self._calculate_coherence(note_preds)
                metrics['scale_consistency'] += self._check_scale_consistency(note_preds)

                num_batches += 1

        # Average metrics
        for key in metrics:
            metrics[key] /= num_batches

        return metrics

    def _calculate_coherence(self, note_sequence):
        """
        Calculate a musical coherence score.

        Checks for:
        - Melodic intervals (steps vs leaps)
        - Phrase structure
        - Repetition patterns

        Args:
            note_sequence (torch.Tensor): Predicted note sequence

        Returns:
            float: Coherence score between 0 and 1
        """
        # Convert to numpy for analysis
        notes = note_sequence.cpu().numpy()

        # Calculate interval distribution
        intervals = np.diff(notes, axis=1)
        step_ratio = np.mean(np.abs(intervals) <= 2)  # Proportion of stepwise motion

        # Check for phrase repetition (helper not implemented yet)
        phrase_score = self._analyze_phrases(notes)

        # Combine metrics
        coherence_score = 0.6 * step_ratio + 0.4 * phrase_score
        return coherence_score

    def _check_scale_consistency(self, note_sequence):
        """
        Check if generated notes follow consistent scale patterns.

        Args:
            note_sequence (torch.Tensor): Predicted note sequence

        Returns:
            float: Scale consistency score between 0 and 1
        """
        notes = note_sequence.cpu().numpy()

        # Create pitch class histogram
        pitch_classes = notes % 12
        histogram = np.bincount(pitch_classes.flatten(), minlength=12)

        # Check against common scales (C major / C natural minor templates)
        major_scale = np.array([1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1])
        minor_scale = np.array([1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0])

        # Calculate consistency scores
        major_score = np.sum((histogram > 0) == major_scale) / 12
        minor_score = np.sum((histogram > 0) == minor_scale) / 12

        return max(major_score, minor_score)

    def generate_and_evaluate_samples(self, num_samples=10, max_length=512):
        """
        Generate and evaluate multiple melody samples.

        Args:
            num_samples (int): Number of samples to generate
            max_length (int): Maximum length of each sample

        Returns:
            tuple: (generated_samples, evaluation_results)
        """
        samples = []
        results = []

        for i in range(num_samples):
            # Generate sample
            prompt = {
                'notes': [60],    # Start with middle C
                'durations': [4]  # Quarter note (4 sixteenth-note bins)
            }
            notes, durations = self.model.generate(
                prompt,
                max_length=max_length,
                temperature=0.8
            )

            # Evaluate sample
            sample_metrics = {
                'melodic_range': self._calculate_melodic_range(notes),
                'rhythm_variety': self._calculate_rhythm_variety(durations),
                'musical_coherence': self._evaluate_musical_qualities(notes, durations)
            }

            samples.append({'notes': notes, 'durations': durations})
            results.append(sample_metrics)

            # Save generated sample
            self._save_sample(
                notes,
                durations,
                f"models/melody-gen/examples/generated_samples/sample_{i + 1}.mid"
            )

        return samples, results

    def _calculate_melodic_range(self, notes):
        """
        Calculate the melodic range and distribution.

        Args:
            notes (list): List of MIDI note numbers

        Returns:
            dict: Melodic range statistics
        """
        return {
            'range': max(notes) - min(notes),
            'mean': np.mean(notes),
            'std': np.std(notes)
        }

    def _calculate_rhythm_variety(self, durations):
        """
        Analyze rhythm patterns and variety.

        Args:
            durations (list): List of note durations

        Returns:
            dict: Rhythm statistics
        """
        return {
            'unique_values': len(set(durations)),
            'variance': np.var(durations),
            'pattern_complexity': len(set(zip(durations[:-1], durations[1:])))
        }

    def _evaluate_musical_qualities(self, notes, durations):
        """
        Evaluate musical qualities of the generated melody.

        Checks for:
        - Phrase structure
        - Melodic contour
        - Rhythmic patterns
        - Musical tension and resolution

        Args:
            notes (list): List of MIDI note numbers
            durations (list): List of note durations

        Returns:
            dict: Musical quality metrics
        """
        # Convert to a music21 stream for analysis (helper not implemented yet)
        stream = self._create_music21_stream(notes, durations)

        return {
            'phrase_structure': self._analyze_phrases(stream),
            'melodic_contour': self._analyze_contour(notes),
            'rhythmic_complexity': self._analyze_rhythm(durations),
            'tension_resolution': self._analyze_tension(notes)
        }

    def _save_sample(self, notes, durations, filepath):
        """
        Save generated sample as MIDI file.

        Args:
            notes (list): List of MIDI note numbers
            durations (list): List of note durations in quarter lengths
            filepath (str): Path to save MIDI file
        """
        stream = music21.stream.Stream()
        for note, duration in zip(notes, durations):
            n = music21.note.Note(note)
            n.duration = music21.duration.Duration(duration)
            stream.append(n)

        Path(filepath).parent.mkdir(parents=True, exist_ok=True)
        stream.write('midi', fp=filepath)

    def generate_evaluation_report(self, test_loader):
        """
        Generate a comprehensive evaluation report.

        Args:
            test_loader (DataLoader): Test data loader

        Returns:
            dict: Complete evaluation report
        """
        # Basic metrics
        metrics = self.evaluate_metrics(test_loader)

        # Generate and evaluate samples
        samples, sample_results = self.generate_and_evaluate_samples()

        # Compile complete report
        report = {
            'quantitative_metrics': metrics,
            'sample_evaluations': sample_results,
            'generation_timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            'model_version': '0.1.0'
        }

        # Save report (default=float converts numpy scalars for JSON)
        report_path = Path('models/melody-gen/examples/evaluation_report.json')
        report_path.parent.mkdir(parents=True, exist_ok=True)
        with open(report_path, 'w') as f:
            json.dump(report, f, indent=2, default=float)

        return report
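
# --- Illustrative sketch: the scale-consistency heuristic -------------------
# Demonstrates the pitch-class histogram idea behind _check_scale_consistency
# on a toy C-major scale. The tensor is shaped [1, seq] to match the
# [batch, seq] predictions the evaluator receives.
def _demo_scale_consistency():
    evaluator = MelodyEvaluator(MelodyTransformer(), device='cpu')
    c_major = torch.tensor([[60, 62, 64, 65, 67, 69, 71, 72]])
    print(evaluator._check_scale_consistency(c_major))  # 1.0 for a pure C-major scale
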
# =====================================
# 5. Generation and Inference
# =====================================

class MelodyGenerator:
    """
    High-level interface for generating melodies using the trained model.

    Features:
    - Text-to-melody generation
    - Style conditioning
    - Batch generation
    - Format conversion and export

    Args:
        model (MelodyTransformer): Trained model
        device (str): Device to run generation on
        config (dict): Generation parameters
    """

    def __init__(self, model, device='cuda', config=None):
        self.model = model.to(device)
        self.device = device
        self.model.eval()

        # Default generation config
        self.config = {
            'temperature': 0.8,
            'max_length': 512,
            'top_k': 50,
            'top_p': 0.95,
            'repetition_penalty': 1.2
        }
        if config:
            self.config.update(config)

    def generate_from_prompt(self, prompt, style=None):
        """
        Generate melody from a text prompt.

        Args:
            prompt (str): Text description of desired melody
            style (dict, optional): Style parameters
                {
                    'genre': 'pop/jazz/classical',
                    'tempo': beats per minute,
                    'mood': 'happy/sad/energetic'
                }

        Returns:
            dict: Generated melody information
                {
                    'notes': List of MIDI notes,
                    'durations': List of note durations,
                    'midi_path': Path to saved MIDI file,
                    'metadata': Generation metadata
                }
        """
        # Process prompt and style
        generation_params = self._prepare_generation_params(prompt, style)

        with torch.no_grad():
            # Initialize sequence with a start token
            current_notes = torch.tensor([[60]]).to(self.device)     # Middle C
            current_durations = torch.tensor([[4]]).to(self.device)  # Quarter note (4 bins)

            generated_notes = []
            generated_durations = []

            # Generate sequence
            for i in range(self.config['max_length']):
                # Get position encoding
                position = torch.arange(current_notes.size(1)).unsqueeze(0).to(self.device)

                # Get predictions
                note_logits, duration_logits = self.model(
                    current_notes,
                    current_durations,
                    position
                )

                # Apply repetition penalty to the note logits before sampling
                last_note_logits = note_logits[:, -1]
                if len(generated_notes) > 0:
                    last_note_logits = self._apply_repetition_penalty(
                        last_note_logits,
                        generated_notes,
                        generation_params['repetition_penalty']
                    )

                # Apply temperature and sampling strategies
                next_note = self._sample_from_logits(
                    last_note_logits,
                    temperature=generation_params['temperature'],
                    top_k=generation_params['top_k'],
                    top_p=generation_params['top_p']
                )
                next_duration = self._sample_from_logits(
                    duration_logits[:, -1],
                    temperature=generation_params['temperature']
                )

                # Append to sequences
                generated_notes.append(next_note.item())
                generated_durations.append(next_duration.item())

                # Update input sequences
                current_notes = torch.tensor([generated_notes]).to(self.device)
                current_durations = torch.tensor([generated_durations]).to(self.device)

                # Check for end condition
                if self._check_end_condition(generated_notes, generated_durations):
                    break

        # Post-process and save
        return self._post_process_and_save(
            generated_notes,
            generated_durations,
            prompt,
            style
        )

    def batch_generate(self, prompts, styles=None):
        """
        Generate multiple melodies in batch.

        Args:
            prompts (list): List of text prompts
            styles (list, optional): List of style parameters

        Returns:
            list: List of generated melodies
        """
        results = []
        for i, prompt in enumerate(prompts):
            style = styles[i] if styles else None
            result = self.generate_from_prompt(prompt, style)
            results.append(result)
        return results

    def _prepare_generation_params(self, prompt, style):
        """
        Prepare generation parameters based on prompt and style.

        Args:
            prompt (str): Text prompt
            style (dict): Style parameters

        Returns:
            dict: Generation parameters
        """
        params = self.config.copy()

        # Adjust parameters based on style
        if style:
            if style.get('genre') == 'classical':
                params['temperature'] *= 0.9           # More conservative
                params['repetition_penalty'] *= 1.1
            elif style.get('genre') == 'jazz':
                params['temperature'] *= 1.1            # More experimental
                params['top_k'] = int(params['top_k'] * 1.2)

            if style.get('mood') == 'energetic':
                params['temperature'] *= 1.1
            elif style.get('mood') == 'calm':
                params['temperature'] *= 0.9

        return params

    def _apply_repetition_penalty(self, logits, generated_notes, penalty):
        """
        Down-weight the logits of recently generated notes (CTRL-style
        penalty) so the melody does not get stuck on a single pitch.
        The 16-note lookback window is a simple placeholder choice.

        Args:
            logits (torch.Tensor): Next-note logits, shape [1, num_notes]
            generated_notes (list): Notes generated so far
            penalty (float): Penalty factor (> 1 discourages repetition)

        Returns:
            torch.Tensor: Adjusted logits
        """
        logits = logits.clone()
        recent = torch.tensor(sorted(set(generated_notes[-16:])), device=logits.device)
        values = logits[:, recent]
        logits[:, recent] = torch.where(values > 0, values / penalty, values * penalty)
        return logits

    def _check_end_condition(self, notes, durations):
        """
        Minimal stopping heuristic: stop once the melody has stalled on a
        single repeated pitch. Intended as a placeholder for a musically
        informed cadence detector.
        """
        return len(notes) >= 8 and len(set(notes[-8:])) == 1

    def _sample_from_logits(self, logits, temperature=1.0, top_k=None, top_p=None):
        """
        Sample from logits with temperature and optional top-k/top-p filtering.

        Args:
            logits (torch.Tensor): Raw logits
            temperature (float): Sampling temperature
            top_k (int, optional): Top-k filtering parameter
            top_p (float, optional): Nucleus sampling parameter

        Returns:
            torch.Tensor: Sampled token
        """
        logits = logits / temperature

        # Top-k filtering
        if top_k is not None:
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits[indices_to_remove] = float('-inf')

        # Top-p filtering (nucleus sampling)
        if top_p is not None:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            indices_to_remove = sorted_indices_to_remove.scatter(
                dim=-1, index=sorted_indices, src=sorted_indices_to_remove
            )
            logits[indices_to_remove] = float('-inf')

        # Sample
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, 1)

    def _post_process_and_save(self, notes, durations, prompt, style):
        """
        Post-process and save the generated melody.

        Args:
            notes (list): Generated notes
            durations (list): Generated durations
            prompt (str): Original prompt
            style (dict): Style parameters

        Returns:
            dict: Generation results and metadata
        """
        # Create timestamp
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

        # Create MIDI file (reuses the shared helper from section 6)
        midi_path = f"models/melody-gen/examples/generated_samples/melody_{timestamp}.mid"
        Path(midi_path).parent.mkdir(parents=True, exist_ok=True)
        MelodyUtils.save_to_midi(notes, durations, midi_path)

        # Prepare metadata
        metadata = {
            'timestamp': timestamp,
            'prompt': prompt,
            'style': style,
            'generation_params': self.config,
            'stats': {
                'length': len(notes),
                'pitch_range': max(notes) - min(notes),
                'unique_pitches': len(set(notes)),
                'unique_durations': len(set(durations))
            }
        }

        # Save metadata
        metadata_path = f"models/melody-gen/examples/generated_samples/melody_{timestamp}.json"
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)

        return {
            'notes': notes,
            'durations': durations,
            'midi_path': midi_path,
            'metadata': metadata
        }
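
# --- Illustrative sketch: temperature / top-k / top-p sampling --------------
# Exercises the filtering used by _sample_from_logits on a dummy logit vector.
# The model here is untrained, so the sampled note is random; the point is
# only to show the sampling path end to end.
def _demo_sampling():
    generator = MelodyGenerator(MelodyTransformer(), device='cpu')
    logits = torch.randn(1, 128)  # fake next-note logits
    token = generator._sample_from_logits(logits, temperature=0.8, top_k=50, top_p=0.95)
    print(token.item())  # a MIDI note index in [0, 127]
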
# =====================================
# 6. Utility Functions and Helpers
# =====================================

class MelodyUtils:
    """
    Utility functions for melody processing and manipulation.
    """

    @staticmethod
    def save_to_midi(notes, durations, path):
        """
        Save melody to a MIDI file with enhanced musical properties.

        Args:
            notes (list): MIDI note numbers
            durations (list): Note durations
            path (str): Output path
        """
        stream = music21.stream.Stream()

        # Add time signature and tempo
        stream.append(music21.meter.TimeSignature('4/4'))
        stream.append(music21.tempo.MetronomeMark(number=120))

        # Add notes with velocity for dynamics
        for note, duration in zip(notes, durations):
            n = music21.note.Note(note)
            n.duration = music21.duration.Duration(duration)
            # Add velocity (dynamics) based on position in phrase
            n.volume.velocity = MelodyUtils._calculate_velocity(note, notes)
            stream.append(n)

        stream.write('midi', fp=path)

    @staticmethod
    def _calculate_velocity(note, notes_sequence):
        """Calculate an appropriate velocity for musical expression."""
        base_velocity = 64
        # Emphasize phrase beginnings and high points
        if note == max(notes_sequence):
            return min(base_velocity + 32, 127)
        return base_velocity
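
# --- Illustrative sketch: writing a toy melody to MIDI ----------------------
# Uses MelodyUtils.save_to_midi on a hand-written C-major arpeggio. The output
# filename is a placeholder; durations are given in quarter lengths as
# expected by music21.
def _demo_save_midi():
    notes = [60, 64, 67, 72]          # C, E, G, C
    durations = [1.0, 1.0, 1.0, 2.0]  # three quarter notes and a half note
    MelodyUtils.save_to_midi(notes, durations, "toy_arpeggio.mid")
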
# =====================================
# 7. Enhanced Generation Features
# =====================================

class EnhancedMelodyGenerator(MelodyGenerator):
    """
    Extended melody generator with additional features.

    Note: _post_process_structured_melody is referenced below but not yet
    implemented in this notebook; it is expected to stitch the generated
    sections into a single melody dict.
    """

    def generate_with_structure(self, prompt, form="AABA"):
        """
        Generate a melody with a specific musical form.

        Args:
            prompt (str): Text prompt
            form (str): Musical form (e.g., "AABA", "ABAC")

        Returns:
            dict: Generated melody with structural sections
        """
        sections = {}
        full_melody = []

        for section in form:
            if section not in sections:
                # Generate a new section
                result = self.generate_from_prompt(
                    prompt + f" for section {section}",
                    {'section': section}
                )
                sections[section] = (result['notes'], result['durations'])

            # Add section to full melody
            notes, durations = sections[section]
            full_melody.extend(zip(notes, durations))

        return self._post_process_structured_melody(full_melody, form)

    def generate_with_harmony(self, prompt, chord_progression=None):
        """
        Generate a melody with harmonic constraints.

        Args:
            prompt (str): Text prompt
            chord_progression (list): Optional chord progression

        Returns:
            dict: Generated melody with harmonic context
        """
        if chord_progression is None:
            # Placeholder I-vi-IV-V progression; replace with a learned or
            # user-supplied progression as needed.
            chord_progression = ["C", "Am", "F", "G"]

        # Pass the harmony through the style dict; generate_from_prompt
        # derives its sampling parameters from it internally.
        return self.generate_from_prompt(prompt, {'harmony': chord_progression})
# =====================================
# 8. Example Usage Scenarios
# =====================================

def example_usage():
    """Example usage of the melody generation system."""
    # In practice, load trained weights into the model before generating;
    # an untrained model is used here only to illustrate the API.
    model = MelodyTransformer()

    # 1. Basic melody generation
    generator = MelodyGenerator(model)
    result = generator.generate_from_prompt(
        "Create an upbeat pop melody in C major"
    )

    # 2. Style-conditional generation
    styled_result = generator.generate_from_prompt(
        "Create a jazz melody",
        style={
            'genre': 'jazz',
            'tempo': 120,
            'mood': 'energetic'
        }
    )

    # 3. Structured generation
    enhanced_generator = EnhancedMelodyGenerator(model)
    structured_result = enhanced_generator.generate_with_structure(
        "Create a memorable melody",
        form="AABA"
    )

    # 4. Batch generation
    prompts = [
        "Happy birthday song style",
        "Sad emotional melody",
        "Energetic dance tune"
    ]
    batch_results = generator.batch_generate(prompts)

    # 5. Generation with harmony
    harmonic_result = enhanced_generator.generate_with_harmony(
        "Create a melody",
        chord_progression=["C", "Am", "F", "G"]
    )

    return {
        'basic': result,
        'styled': styled_result,
        'structured': structured_result,
        'batch': batch_results,
        'harmonic': harmonic_result
    }
# =====================================
# 9. Integration Example
# =====================================

def run_complete_pipeline():
    """
    Complete pipeline from training to generation.
    """
    # 1. Load configuration
    with open('models/melody-gen/config/model_config.json') as f:
        model_config = json.load(f)

    # 2. Initialize model (only architecture-related keys are passed through,
    #    since the same config also carries training settings)
    model_keys = {'num_notes', 'max_duration', 'd_model', 'nhead', 'num_layers'}
    model = MelodyTransformer(**{k: v for k, v in model_config.items() if k in model_keys})

    # 3. Load datasets and wrap them in data loaders
    #    (the trainer and evaluator expect DataLoaders, not raw datasets)
    batch_size = model_config.get('batch_size', 16)
    train_loader = DataLoader(MelodyDataset('datasets/train'), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(MelodyDataset('datasets/val'), batch_size=batch_size)
    test_loader = DataLoader(MelodyDataset('datasets/test'), batch_size=batch_size)

    # 4. Training
    trainer = MelodyTrainer(model, model_config)
    trainer.train(train_loader, val_loader)

    # 5. Evaluation
    evaluator = MelodyEvaluator(model)
    eval_results = evaluator.generate_evaluation_report(test_loader)

    # 6. Generation
    generator = MelodyGenerator(model)
    samples = generator.generate_from_prompt(
        "Create an original melody",
        style={'genre': 'pop', 'mood': 'happy'}
    )

    return {
        'evaluation': eval_results,
        'samples': samples
    }
""" def generate_with_structure(self, prompt, form="AABA"): """ Generate melody with specific musical form. Args: prompt (str): Text prompt form (str): Musical form (e.g., "AABA", "ABAC") Returns: dict: Generated melody with structural sections """ sections = {} full_melody = [] for section in form: if section not in sections: # Generate new section result = self.generate_from_prompt( prompt + f" for section {section}", {'section': section} ) sections[section] = (result['notes'], result['durations']) # Add section to full melody notes, durations = sections[section] full_melody.extend(zip(notes, durations)) return self._post_process_structured_melody(full_melody, form) def generate_with_harmony(self, prompt, chord_progression=None): """ Generate melody with harmonic constraints. Args: prompt (str): Text prompt chord_progression (list): Optional chord progression Returns: dict: Generated melody with harmonic context """ if chord_progression is None: chord_progression = self._generate_chord_progression() # Generate melody considering harmony generation_params = self._prepare_generation_params(prompt, { 'harmony': chord_progression }) return self.generate_from_prompt(prompt, generation_params) # ===================================== # 8. Example Usage Scenarios # ===================================== def example_usage(): """Example usage of the melody generation system.""" # 1. Basic melody generation generator = MelodyGenerator(model) result = generator.generate_from_prompt( "Create an upbeat pop melody in C major" ) # 2. Style-conditional generation styled_result = generator.generate_from_prompt( "Create a jazz melody", style={ 'genre': 'jazz', 'tempo': 120, 'mood': 'energetic' } ) # 3. Structured generation enhanced_generator = EnhancedMelodyGenerator(model) structured_result = enhanced_generator.generate_with_structure( "Create a memorable melody", form="AABA" ) # 4. Batch generation prompts = [ "Happy birthday song style", "Sad emotional melody", "Energetic dance tune" ] batch_results = generator.batch_generate(prompts) # 5. Generation with harmony harmonic_result = enhanced_generator.generate_with_harmony( "Create a melody", chord_progression=["C", "Am", "F", "G"] ) return { 'basic': result, 'styled': styled_result, 'structured': structured_result, 'batch': batch_results, 'harmonic': harmonic_result } # ===================================== # 9. Integration Example # ===================================== def run_complete_pipeline(): """ Complete pipeline from training to generation. """ # 1. Load configuration with open('models/melody-gen/config/model_config.json') as f: model_config = json.load(f) # 2. Initialize model model = MelodyTransformer(**model_config) # 3. Load dataset train_dataset = MelodyDataset('datasets/train') val_dataset = MelodyDataset('datasets/val') test_dataset = MelodyDataset('datasets/test') # 4. Training trainer = MelodyTrainer(model, model_config) trainer.train(train_dataset, val_dataset) # 5. Evaluation evaluator = MelodyEvaluator(model) eval_results = evaluator.generate_evaluation_report(test_dataset) # 6. Generation generator = MelodyGenerator(model) samples = generator.generate_from_prompt( "Create an original melody", style={'genre': 'pop', 'mood': 'happy'} ) return { 'evaluation': eval_results, 'samples': samples } if __name__ == "__main__": # Run example usage results = example_usage() # Run complete pipeline pipeline_results = run_complete_pipeline()