"""
CompI Phase 1.E: LoRA Fine-tuning for Personal Style

This script implements LoRA (Low-Rank Adaptation) fine-tuning for Stable Diffusion
to learn your personal artistic style.

Usage:
    python src/generators/compi_phase1e_lora_training.py --dataset-dir datasets/my_style
    python src/generators/compi_phase1e_lora_training.py --help
"""

import os
import argparse
import json
import math
from pathlib import Path
from typing import Dict, List, Optional
import logging

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
from tqdm import tqdm

from diffusers import (
    StableDiffusionPipeline,
    UNet2DConditionModel,
    DDPMScheduler,
    AutoencoderKL
)
from transformers import CLIPTextModel, CLIPTokenizer
from peft import LoraConfig, get_peft_model, TaskType

DEFAULT_MODEL = "runwayml/stable-diffusion-v1-5"
DEFAULT_RESOLUTION = 512
DEFAULT_BATCH_SIZE = 1
DEFAULT_LEARNING_RATE = 1e-4
DEFAULT_EPOCHS = 100
DEFAULT_LORA_RANK = 4
DEFAULT_LORA_ALPHA = 32

class StyleDataset(Dataset):
    """Dataset class for LoRA fine-tuning."""

    def __init__(self, dataset_dir: str, split: str = "train", resolution: int = 512):
        self.dataset_dir = Path(dataset_dir)
        self.split = split
        self.resolution = resolution

        self.images_dir = self.dataset_dir / split
        self.captions_file = self.dataset_dir / f"{split}_captions.txt"

        if not self.images_dir.exists():
            raise FileNotFoundError(f"Images directory not found: {self.images_dir}")

        if not self.captions_file.exists():
            raise FileNotFoundError(f"Captions file not found: {self.captions_file}")

        # Parse "filename: caption" pairs from the captions file
        self.image_captions = {}
        with open(self.captions_file, 'r') as f:
            for line in f:
                if ':' in line:
                    filename, caption = line.strip().split(':', 1)
                    self.image_captions[filename.strip()] = caption.strip()

        # Collect image files, keeping only those that have a caption
        self.image_files = [f for f in os.listdir(self.images_dir)
                            if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
        self.image_files = [f for f in self.image_files if f in self.image_captions]

        print(f"Loaded {len(self.image_files)} images for {split} split")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        filename = self.image_files[idx]
        image_path = self.images_dir / filename
        caption = self.image_captions[filename]

        # Load and resize the image
        image = Image.open(image_path).convert('RGB')
        image = image.resize((self.resolution, self.resolution), Image.Resampling.LANCZOS)

        # Normalize to [-1, 1] and convert HWC array -> CHW float tensor
        image = np.array(image).astype(np.float32) / 255.0
        image = (image - 0.5) / 0.5
        image = torch.from_numpy(image).permute(2, 0, 1)

        return {
            'pixel_values': image,
            'caption': caption,
            'filename': filename
        }
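
# Minimal usage sketch for the dataset class (paths are illustrative):
#   ds = StyleDataset("datasets/my_style", split="train", resolution=512)
#   sample = ds[0]
#   sample['pixel_values']  # float tensor of shape [3, 512, 512], values in [-1, 1]
#   sample['caption']       # caption string paired with that image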

def setup_args():
    """Setup command line arguments."""
    parser = argparse.ArgumentParser(
        description="CompI Phase 1.E: LoRA Fine-tuning for Personal Style",
        formatter_class=argparse.RawDescriptionHelpFormatter
    )

    parser.add_argument("--dataset-dir", required=True,
                        help="Directory containing prepared dataset")
    parser.add_argument("--output-dir",
                        help="Output directory for LoRA weights (default: lora_models/{style_name})")
    parser.add_argument("--model-name", default=DEFAULT_MODEL,
                        help=f"Base Stable Diffusion model (default: {DEFAULT_MODEL})")
    parser.add_argument("--resolution", type=int, default=DEFAULT_RESOLUTION,
                        help=f"Training resolution (default: {DEFAULT_RESOLUTION})")
    parser.add_argument("--batch-size", type=int, default=DEFAULT_BATCH_SIZE,
                        help=f"Training batch size (default: {DEFAULT_BATCH_SIZE})")
    parser.add_argument("--learning-rate", type=float, default=DEFAULT_LEARNING_RATE,
                        help=f"Learning rate (default: {DEFAULT_LEARNING_RATE})")
    parser.add_argument("--epochs", type=int, default=DEFAULT_EPOCHS,
                        help=f"Number of training epochs (default: {DEFAULT_EPOCHS})")
    parser.add_argument("--lora-rank", type=int, default=DEFAULT_LORA_RANK,
                        help=f"LoRA rank (default: {DEFAULT_LORA_RANK})")
    parser.add_argument("--lora-alpha", type=int, default=DEFAULT_LORA_ALPHA,
                        help=f"LoRA alpha (default: {DEFAULT_LORA_ALPHA})")
    parser.add_argument("--save-steps", type=int, default=100,
                        help="Save checkpoint every N steps")
    parser.add_argument("--validation-steps", type=int, default=50,
                        help="Run validation every N steps")
    parser.add_argument("--mixed-precision", action="store_true",
                        help="Use mixed precision training")
    parser.add_argument("--gradient-checkpointing", action="store_true",
                        help="Use gradient checkpointing to save memory")

    return parser.parse_args()
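
# Example invocation with non-default hyperparameters (values are illustrative,
# not recommendations):
#   python src/generators/compi_phase1e_lora_training.py \
#       --dataset-dir datasets/my_style --epochs 50 --lora-rank 8 --learning-rate 5e-5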

def load_models(model_name: str, device: str):
    """Load Stable Diffusion components."""
    print(f"Loading models from {model_name}...")

    # Tokenizer and text encoder for caption conditioning
    tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder")

    # VAE for encoding images into the latent space
    vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae")

    # UNet denoiser (the only component that will receive LoRA adapters)
    unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet")

    # Noise scheduler used to add noise during training
    noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder="scheduler")

    text_encoder.to(device)
    vae.to(device)
    unet.to(device)

    # Only the UNet is trained; the encoders run in eval mode
    text_encoder.eval()
    vae.eval()
    unet.train()

    return tokenizer, text_encoder, vae, unet, noise_scheduler

def setup_lora(unet: UNet2DConditionModel, lora_rank: int, lora_alpha: int):
    """Setup LoRA adapters for UNet."""
    print(f"Setting up LoRA with rank={lora_rank}, alpha={lora_alpha}")

    # Inject low-rank adapters into the attention and feed-forward projections
    lora_config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
        target_modules=[
            "to_k", "to_q", "to_v", "to_out.0",
            "proj_in", "proj_out",
            "ff.net.0.proj", "ff.net.2"
        ],
        lora_dropout=0.1,
    )

    unet = get_peft_model(unet, lora_config)

    # Report how small the trainable fraction is compared to the full UNet
    trainable_params = sum(p.numel() for p in unet.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in unet.parameters())

    print(f"Trainable parameters: {trainable_params:,} ({100 * trainable_params / total_params:.2f}%)")

    return unet
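
# For reference: LoRA keeps each targeted weight matrix W frozen and learns a
# low-rank update, so the effective weight becomes
#     W' = W + (lora_alpha / r) * B @ A
# where A is r x d_in and B is d_out x r. This is the standard LoRA formulation;
# only A and B are optimized, which is why the trainable fraction above is so small.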

def encode_text(tokenizer, text_encoder, captions: List[str], device: str):
    """Encode text captions."""
    inputs = tokenizer(
        captions,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt"
    )

    with torch.no_grad():
        text_embeddings = text_encoder(inputs.input_ids.to(device))[0]

    return text_embeddings

def training_step(batch, unet, vae, text_encoder, tokenizer, noise_scheduler, device):
    """Single training step."""
    pixel_values = batch['pixel_values'].to(device)
    captions = batch['caption']

    # Encode images into latents (the VAE is frozen, so no gradients are needed)
    with torch.no_grad():
        latents = vae.encode(pixel_values).latent_dist.sample()
        latents = latents * vae.config.scaling_factor

    # Sample random noise and a random timestep for each image
    noise = torch.randn_like(latents)
    batch_size = latents.shape[0]
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps,
        (batch_size,), device=device
    ).long()

    # Add noise to the latents according to the sampled timesteps
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    # Encode the captions as conditioning
    text_embeddings = encode_text(tokenizer, text_encoder, captions, device)

    # Predict the noise and compare it with the noise that was actually added
    noise_pred = unet(noisy_latents, timesteps, text_embeddings).sample
    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

    return loss
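
# The loss above is the usual epsilon-prediction diffusion objective,
#     L = E_{z, c, eps, t} [ || eps - eps_theta(z_t, t, c) ||^2 ],
# where z_t is the noised latent, c the text conditioning, and eps the sampled
# noise. Only the LoRA parameters inside eps_theta receive gradients here.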

def validate_model(val_dataloader, unet, vae, text_encoder, tokenizer, noise_scheduler, device):
    """Validation step."""
    unet.eval()
    total_loss = 0
    num_batches = 0

    with torch.no_grad():
        for batch in val_dataloader:
            loss = training_step(batch, unet, vae, text_encoder, tokenizer, noise_scheduler, device)
            total_loss += loss.item()
            num_batches += 1

    unet.train()
    return total_loss / num_batches if num_batches > 0 else 0

def save_lora_weights(unet, output_dir: Path, step: int):
    """Save LoRA weights."""
    checkpoint_dir = output_dir / f"checkpoint-{step}"
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # save_pretrained on a PEFT-wrapped model stores only the adapter weights
    unet.save_pretrained(checkpoint_dir)

    print(f"💾 Saved checkpoint to: {checkpoint_dir}")
    return checkpoint_dir
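
# Sketch of how a saved checkpoint could be loaded back for inference (not part
# of this script; assumes the same base model and the peft PeftModel API):
#   from diffusers import StableDiffusionPipeline
#   from peft import PeftModel
#   pipe = StableDiffusionPipeline.from_pretrained(DEFAULT_MODEL)
#   pipe.unet = PeftModel.from_pretrained(pipe.unet, "lora_models/<style>/checkpoint-<step>")
#   image = pipe("a landscape in your style").images[0]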

def train_lora(args):
    """Main training function."""
    print(f"🎨 CompI Phase 1.E: Starting LoRA Training")
    print("=" * 50)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"🖥️ Using device: {device}")

    # Load dataset metadata (style name) if available
    dataset_dir = Path(args.dataset_dir)
    info_file = dataset_dir / "dataset_info.json"

    if info_file.exists():
        with open(info_file) as f:
            dataset_info = json.load(f)
        style_name = dataset_info.get('style_name', 'custom_style')
        print(f"🎯 Training style: {style_name}")
    else:
        style_name = dataset_dir.name
        print(f"⚠️ No dataset info found, using directory name: {style_name}")

    # Resolve the output directory
    if args.output_dir:
        output_dir = Path(args.output_dir)
    else:
        output_dir = Path("lora_models") / style_name

    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"📁 Output directory: {output_dir}")

    # Build datasets and dataloaders
    print(f"📊 Loading datasets...")
    train_dataset = StyleDataset(args.dataset_dir, "train", args.resolution)

    try:
        val_dataset = StyleDataset(args.dataset_dir, "validation", args.resolution)
        has_validation = True
    except FileNotFoundError:
        print("⚠️ No validation set found, using train set for validation")
        val_dataset = train_dataset
        has_validation = False

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )

    # Load the Stable Diffusion components and attach LoRA adapters to the UNet
    tokenizer, text_encoder, vae, unet, noise_scheduler = load_models(args.model_name, device)
    unet = setup_lora(unet, args.lora_rank, args.lora_alpha)

    # Optimizer over the (LoRA-only) trainable parameters
    optimizer = torch.optim.AdamW(
        unet.parameters(),
        lr=args.learning_rate,
        betas=(0.9, 0.999),
        weight_decay=0.01,
        eps=1e-08
    )

    total_steps = len(train_dataloader) * args.epochs
    print(f"📊 Total training steps: {total_steps}")

    # Training loop
    print(f"\n🚀 Starting training...")
    global_step = 0
    best_val_loss = float('inf')

    for epoch in range(args.epochs):
        print(f"\n📅 Epoch {epoch + 1}/{args.epochs}")

        epoch_loss = 0
        progress_bar = tqdm(train_dataloader, desc=f"Training")

        for batch in progress_bar:
            loss = training_step(batch, unet, vae, text_encoder, tokenizer, noise_scheduler, device)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            epoch_loss += loss.item()
            global_step += 1

            progress_bar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'avg_loss': f"{epoch_loss / (progress_bar.n + 1):.4f}"
            })

            # Periodic validation; keep a checkpoint whenever validation improves
            if global_step % args.validation_steps == 0:
                val_loss = validate_model(val_dataloader, unet, vae, text_encoder, tokenizer, noise_scheduler, device)
                print(f"\n📊 Step {global_step}: Train Loss = {loss.item():.4f}, Val Loss = {val_loss:.4f}")

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    save_lora_weights(unet, output_dir, global_step)

            # Periodic checkpointing
            if global_step % args.save_steps == 0:
                save_lora_weights(unet, output_dir, global_step)

        avg_epoch_loss = epoch_loss / len(train_dataloader)
        print(f"📊 Epoch {epoch + 1} complete. Average loss: {avg_epoch_loss:.4f}")

    # Save final weights and a summary of the training run
    final_checkpoint = save_lora_weights(unet, output_dir, global_step)

    training_info = {
        'style_name': style_name,
        'model_name': args.model_name,
        'total_steps': global_step,
        'epochs': args.epochs,
        'learning_rate': args.learning_rate,
        'lora_rank': args.lora_rank,
        'lora_alpha': args.lora_alpha,
        'final_checkpoint': str(final_checkpoint),
        'best_val_loss': best_val_loss
    }

    with open(output_dir / "training_info.json", 'w') as f:
        json.dump(training_info, f, indent=2)

    print(f"\n🎉 Training complete!")
    print(f"📁 LoRA weights saved to: {output_dir}")
    print(f"💡 Next steps:")
    print(f"   1. Test your style: python src/generators/compi_phase1e_style_generation.py --lora-path {final_checkpoint}")
    print(f"   2. Integrate with UI: Use the style in your Streamlit interface")

def main():
    """Main function."""
    args = setup_args()

    try:
        train_lora(args)
    except Exception as e:
        print(f"❌ Training failed: {e}")
        import traceback
        traceback.print_exc()
        return 1

    return 0


if __name__ == "__main__":
    exit(main())