|
|
|
""" |
|
CompI Phase 1.E: LoRA Style Management System |
|
|
|
Manage multiple LoRA styles, switch between them, and organize trained models. |
|
|
|
Usage: |
|
python src/generators/compi_phase1e_style_manager.py --list |
|
python src/generators/compi_phase1e_style_manager.py --info my_style |
|
python src/generators/compi_phase1e_style_manager.py --cleanup |
|
""" |
|
|
|
import os |
|
import argparse |
|
import json |
|
import shutil |
|
from datetime import datetime |
|
from pathlib import Path |
|
from typing import Dict, List, Optional, Tuple |
|
|
|
import pandas as pd |
|
|
|
|
|
|
|
LORA_MODELS_DIR = "lora_models" |
|
STYLES_CONFIG_FILE = "lora_styles_config.json" |
|
|
|
|
|
|
|
class LoRAStyleManager: |
|
"""Manager for LoRA styles and models.""" |
|
|
|
def __init__(self, models_dir: str = LORA_MODELS_DIR): |
|
self.models_dir = Path(models_dir) |
|
self.models_dir.mkdir(exist_ok=True) |
|
self.config_file = self.models_dir / STYLES_CONFIG_FILE |
|
self.config = self.load_config() |
|
|
|
def load_config(self) -> Dict: |
|
"""Load styles configuration.""" |
|
if self.config_file.exists(): |
|
with open(self.config_file) as f: |
|
return json.load(f) |
|
return {"styles": {}, "last_updated": datetime.now().isoformat()} |
|
|
|
def save_config(self): |
|
"""Save styles configuration.""" |
|
self.config["last_updated"] = datetime.now().isoformat() |
|
with open(self.config_file, 'w') as f: |
|
json.dump(self.config, f, indent=2) |
|
|
|
def scan_styles(self) -> Dict[str, Dict]: |
|
"""Scan for available LoRA styles.""" |
|
styles = {} |
|
|
|
for style_dir in self.models_dir.iterdir(): |
|
if not style_dir.is_dir() or style_dir.name.startswith('.'): |
|
continue |
|
|
|
|
|
checkpoints = list(style_dir.glob("checkpoint-*")) |
|
if not checkpoints: |
|
continue |
|
|
|
|
|
latest_checkpoint = max(checkpoints, key=lambda x: int(x.name.split('-')[1])) |
|
|
|
|
|
info_file = style_dir / "training_info.json" |
|
if info_file.exists(): |
|
with open(info_file) as f: |
|
training_info = json.load(f) |
|
else: |
|
training_info = {} |
|
|
|
|
|
dataset_info = {} |
|
for dataset_dir in [style_dir / "dataset", Path("datasets") / style_dir.name]: |
|
dataset_info_file = dataset_dir / "dataset_info.json" |
|
if dataset_info_file.exists(): |
|
with open(dataset_info_file) as f: |
|
dataset_info = json.load(f) |
|
break |
|
|
|
|
|
style_info = { |
|
"name": style_dir.name, |
|
"path": str(style_dir), |
|
"latest_checkpoint": str(latest_checkpoint), |
|
"checkpoints": [str(cp) for cp in checkpoints], |
|
"training_info": training_info, |
|
"dataset_info": dataset_info, |
|
"last_scanned": datetime.now().isoformat() |
|
} |
|
|
|
styles[style_dir.name] = style_info |
|
|
|
return styles |
|
|
|
def refresh_styles(self): |
|
"""Refresh the styles database.""" |
|
print("π Scanning for LoRA styles...") |
|
scanned_styles = self.scan_styles() |
|
|
|
|
|
self.config["styles"] = scanned_styles |
|
self.save_config() |
|
|
|
print(f"β
Found {len(scanned_styles)} LoRA style(s)") |
|
return scanned_styles |
|
|
|
def list_styles(self, detailed: bool = False) -> List[Dict]: |
|
"""List available styles.""" |
|
styles = self.config.get("styles", {}) |
|
|
|
if not styles: |
|
styles = self.refresh_styles() |
|
|
|
if detailed: |
|
return list(styles.values()) |
|
else: |
|
return [{"name": name, "checkpoints": len(info["checkpoints"])} |
|
for name, info in styles.items()] |
|
|
|
def get_style_info(self, style_name: str) -> Optional[Dict]: |
|
"""Get detailed information about a specific style.""" |
|
styles = self.config.get("styles", {}) |
|
return styles.get(style_name) |
|
|
|
def get_best_checkpoint(self, style_name: str) -> Optional[str]: |
|
"""Get the best checkpoint for a style.""" |
|
style_info = self.get_style_info(style_name) |
|
if not style_info: |
|
return None |
|
|
|
|
|
|
|
return style_info.get("latest_checkpoint") |
|
|
|
def delete_style(self, style_name: str, confirm: bool = False) -> bool: |
|
"""Delete a LoRA style.""" |
|
if not confirm: |
|
print("β οΈ Use --confirm to actually delete the style") |
|
return False |
|
|
|
style_dir = self.models_dir / style_name |
|
if not style_dir.exists(): |
|
print(f"β Style not found: {style_name}") |
|
return False |
|
|
|
try: |
|
shutil.rmtree(style_dir) |
|
|
|
|
|
if style_name in self.config.get("styles", {}): |
|
del self.config["styles"][style_name] |
|
self.save_config() |
|
|
|
print(f"β
Deleted style: {style_name}") |
|
return True |
|
|
|
except Exception as e: |
|
print(f"β Error deleting style: {e}") |
|
return False |
|
|
|
def cleanup_checkpoints(self, style_name: str, keep_last: int = 3) -> int: |
|
"""Clean up old checkpoints, keeping only the most recent ones.""" |
|
style_dir = self.models_dir / style_name |
|
if not style_dir.exists(): |
|
print(f"β Style not found: {style_name}") |
|
return 0 |
|
|
|
checkpoints = list(style_dir.glob("checkpoint-*")) |
|
if len(checkpoints) <= keep_last: |
|
print(f"β
No cleanup needed for {style_name} ({len(checkpoints)} checkpoints)") |
|
return 0 |
|
|
|
|
|
checkpoints.sort(key=lambda x: int(x.name.split('-')[1])) |
|
|
|
|
|
to_remove = checkpoints[:-keep_last] |
|
removed_count = 0 |
|
|
|
for checkpoint in to_remove: |
|
try: |
|
shutil.rmtree(checkpoint) |
|
removed_count += 1 |
|
except Exception as e: |
|
print(f"β οΈ Failed to remove {checkpoint}: {e}") |
|
|
|
print(f"β
Cleaned up {removed_count} old checkpoints for {style_name}") |
|
return removed_count |
|
|
|
def export_style_info(self, output_file: str = None) -> str: |
|
"""Export styles information to CSV.""" |
|
styles = self.list_styles(detailed=True) |
|
|
|
if not styles: |
|
print("β No styles found") |
|
return "" |
|
|
|
|
|
rows = [] |
|
for style in styles: |
|
training_info = style.get("training_info", {}) |
|
dataset_info = style.get("dataset_info", {}) |
|
|
|
row = { |
|
"style_name": style["name"], |
|
"checkpoints": len(style["checkpoints"]), |
|
"latest_checkpoint": Path(style["latest_checkpoint"]).name, |
|
"total_steps": training_info.get("total_steps", "unknown"), |
|
"epochs": training_info.get("epochs", "unknown"), |
|
"learning_rate": training_info.get("learning_rate", "unknown"), |
|
"lora_rank": training_info.get("lora_rank", "unknown"), |
|
"dataset_images": dataset_info.get("total_images", "unknown"), |
|
"trigger_word": dataset_info.get("trigger_word", "unknown"), |
|
"last_scanned": style.get("last_scanned", "unknown") |
|
} |
|
rows.append(row) |
|
|
|
|
|
df = pd.DataFrame(rows) |
|
|
|
if output_file is None: |
|
output_file = f"lora_styles_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv" |
|
|
|
df.to_csv(output_file, index=False) |
|
print(f"π Exported styles info to: {output_file}") |
|
return output_file |
|
|
|
|
|
|
|
def setup_args(): |
|
"""Setup command line arguments.""" |
|
parser = argparse.ArgumentParser( |
|
description="CompI Phase 1.E: LoRA Style Management", |
|
formatter_class=argparse.RawDescriptionHelpFormatter, |
|
epilog=""" |
|
Examples: |
|
# List all available styles |
|
python %(prog)s --list |
|
|
|
# Get detailed info about a specific style |
|
python %(prog)s --info my_style |
|
|
|
# Refresh styles database |
|
python %(prog)s --refresh |
|
|
|
# Clean up old checkpoints |
|
python %(prog)s --cleanup my_style --keep 2 |
|
|
|
# Export styles information |
|
python %(prog)s --export styles_report.csv |
|
""" |
|
) |
|
|
|
parser.add_argument("--list", action="store_true", |
|
help="List all available LoRA styles") |
|
|
|
parser.add_argument("--list-detailed", action="store_true", |
|
help="List styles with detailed information") |
|
|
|
parser.add_argument("--info", metavar="STYLE_NAME", |
|
help="Show detailed information about a specific style") |
|
|
|
parser.add_argument("--refresh", action="store_true", |
|
help="Refresh the styles database") |
|
|
|
parser.add_argument("--cleanup", metavar="STYLE_NAME", |
|
help="Clean up old checkpoints for a style") |
|
|
|
parser.add_argument("--keep", type=int, default=3, |
|
help="Number of recent checkpoints to keep during cleanup") |
|
|
|
parser.add_argument("--delete", metavar="STYLE_NAME", |
|
help="Delete a LoRA style") |
|
|
|
parser.add_argument("--confirm", action="store_true", |
|
help="Confirm destructive operations") |
|
|
|
parser.add_argument("--export", metavar="OUTPUT_FILE", |
|
help="Export styles information to CSV") |
|
|
|
parser.add_argument("--models-dir", default=LORA_MODELS_DIR, |
|
help=f"LoRA models directory (default: {LORA_MODELS_DIR})") |
|
|
|
return parser.parse_args() |
|
|
|
def print_style_info(style_info: Dict): |
|
"""Print detailed style information.""" |
|
print(f"π¨ Style: {style_info['name']}") |
|
print("=" * 40) |
|
|
|
|
|
print(f"π Path: {style_info['path']}") |
|
print(f"π Checkpoints: {len(style_info['checkpoints'])}") |
|
print(f"π Latest: {Path(style_info['latest_checkpoint']).name}") |
|
|
|
|
|
training_info = style_info.get("training_info", {}) |
|
if training_info: |
|
print(f"\nπ Training Information:") |
|
print(f" Steps: {training_info.get('total_steps', 'unknown')}") |
|
print(f" Epochs: {training_info.get('epochs', 'unknown')}") |
|
print(f" Learning Rate: {training_info.get('learning_rate', 'unknown')}") |
|
print(f" LoRA Rank: {training_info.get('lora_rank', 'unknown')}") |
|
print(f" LoRA Alpha: {training_info.get('lora_alpha', 'unknown')}") |
|
|
|
|
|
dataset_info = style_info.get("dataset_info", {}) |
|
if dataset_info: |
|
print(f"\nπ Dataset Information:") |
|
print(f" Total Images: {dataset_info.get('total_images', 'unknown')}") |
|
print(f" Train Images: {dataset_info.get('train_images', 'unknown')}") |
|
print(f" Validation Images: {dataset_info.get('validation_images', 'unknown')}") |
|
print(f" Trigger Word: {dataset_info.get('trigger_word', 'unknown')}") |
|
print(f" Image Size: {dataset_info.get('image_size', 'unknown')}") |
|
|
|
print(f"\nπ Last Scanned: {style_info.get('last_scanned', 'unknown')}") |
|
|
|
def main(): |
|
"""Main function.""" |
|
args = setup_args() |
|
|
|
|
|
manager = LoRAStyleManager(args.models_dir) |
|
|
|
print("π¨ CompI Phase 1.E: LoRA Style Manager") |
|
print("=" * 40) |
|
|
|
|
|
if args.refresh: |
|
manager.refresh_styles() |
|
|
|
elif args.list or args.list_detailed: |
|
styles = manager.list_styles(detailed=args.list_detailed) |
|
|
|
if not styles: |
|
print("β No LoRA styles found") |
|
print("π‘ Train a style first using: python src/generators/compi_phase1e_lora_training.py") |
|
else: |
|
print(f"π Available LoRA Styles ({len(styles)}):") |
|
print("-" * 40) |
|
|
|
if args.list_detailed: |
|
for style in styles: |
|
print_style_info(style) |
|
print() |
|
else: |
|
for style in styles: |
|
print(f"π¨ {style['name']} ({style['checkpoints']} checkpoints)") |
|
|
|
elif args.info: |
|
style_info = manager.get_style_info(args.info) |
|
if style_info: |
|
print_style_info(style_info) |
|
else: |
|
print(f"β Style not found: {args.info}") |
|
print("π‘ Use --list to see available styles") |
|
|
|
elif args.cleanup: |
|
removed = manager.cleanup_checkpoints(args.cleanup, args.keep) |
|
if removed > 0: |
|
manager.refresh_styles() |
|
|
|
elif args.delete: |
|
manager.delete_style(args.delete, args.confirm) |
|
if args.confirm: |
|
manager.refresh_styles() |
|
|
|
elif args.export: |
|
manager.export_style_info(args.export) |
|
|
|
else: |
|
print("β No command specified. Use --help for usage information.") |
|
print("π‘ Common commands:") |
|
print(" --list List available styles") |
|
print(" --info STYLE_NAME Show style details") |
|
print(" --refresh Refresh styles database") |
|
|
|
return 0 |
|
|
|
if __name__ == "__main__": |
|
exit(main()) |
|
|