import spaces
import torch
import gc
import os
from typing import Optional, List, Dict, Any
from datetime import datetime
from pathlib import Path
import numpy as np
from PIL import Image
import tempfile

# Model imports (to be implemented)
from models.stt_processor import KyutaiSTTProcessor
from models.text_generator import QwenTextGenerator
from models.image_generator import OmniGenImageGenerator
from models.model_3d_generator import Hunyuan3DGenerator
from models.rigging_processor import UniRigProcessor
from utils.fallbacks import FallbackManager
from utils.caching import ModelCache


class MonsterGenerationPipeline:
    """Main AI pipeline for monster generation.

    Orchestrates five sequential stages — speech-to-text, trait/dialogue
    generation, image generation, image-to-3D conversion, and optional
    rigging — loading one model at a time and unloading it immediately
    afterwards so the pipeline fits in limited GPU memory (low-VRAM mode).
    Every stage has a fallback path through ``FallbackManager``, so a
    usable result dict is always returned even when individual models fail.
    """

    def __init__(self, device: str = "cuda"):
        """Initialize the pipeline.

        Args:
            device: Preferred compute device. Silently falls back to "cpu"
                when CUDA is not available, regardless of the request.
        """
        self.device = device if torch.cuda.is_available() else "cpu"
        self.cache = ModelCache()
        self.fallback_manager = FallbackManager()
        self.models = {}  # model_type -> loaded model instance
        self.model_loaded = {
            'stt': False,
            'text_gen': False,
            'image_gen': False,
            '3d_gen': False,
            'rigging': False,
        }
        # Pipeline configuration
        self.config = {
            'max_retries': 3,
            'timeout': 180,           # seconds (presumably; not enforced here)
            'enable_caching': True,
            'low_vram_mode': True,    # load/unload one model at a time
        }

    def _cleanup_memory(self):
        """Release cached GPU memory and force a garbage-collection pass."""
        if self.device == "cuda":
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        gc.collect()

    def _lazy_load_model(self, model_type: str):
        """Load a model on first use; return the cached instance afterwards.

        Args:
            model_type: One of 'stt', 'text_gen', 'image_gen', '3d_gen',
                'rigging'.

        Returns:
            The loaded model instance, or None when loading fails so that
            callers can fall through to their fallback paths.
        """
        if self.model_loaded[model_type]:
            return self.models[model_type]

        # Free as much memory as possible before loading (low-VRAM strategy).
        self._cleanup_memory()

        # Dispatch table instead of an if/elif chain; an unknown model_type
        # raises KeyError inside the try and is reported without corrupting
        # the model_loaded bookkeeping.
        factories = {
            'stt': KyutaiSTTProcessor,
            'text_gen': QwenTextGenerator,
            'image_gen': OmniGenImageGenerator,
            '3d_gen': Hunyuan3DGenerator,
            'rigging': UniRigProcessor,
        }
        try:
            self.models[model_type] = factories[model_type](device=self.device)
            self.model_loaded[model_type] = True
            return self.models[model_type]
        except Exception as e:
            print(f"Failed to load {model_type}: {e}")
            return None

    def _unload_model(self, model_type: str):
        """Move a model to CPU (when supported), drop it, and free memory."""
        if model_type in self.models and self.model_loaded[model_type]:
            if hasattr(self.models[model_type], 'to'):
                self.models[model_type].to('cpu')
            del self.models[model_type]
            self.model_loaded[model_type] = False
            self._cleanup_memory()

    @spaces.GPU(duration=300)
    def generate_monster(self,
                         audio_input: Optional[str] = None,
                         text_input: Optional[str] = None,
                         reference_images: Optional[List] = None,
                         user_id: Optional[str] = None) -> Dict[str, Any]:
        """Run the full monster generation pipeline.

        Args:
            audio_input: Optional path to an audio file; transcribed into
                the monster description when present.
            text_input: Optional text description, used when no audio is
                supplied or STT fails.
            reference_images: Optional reference images forwarded to the
                image generator.
            user_id: Identifier embedded into output filenames and the log.

        Returns:
            Dict with keys 'description', 'traits', 'dialogue', 'image',
            'model_3d' (saved file path), 'download_files',
            'generation_log', and 'status'. On unrecoverable failure the
            FallbackManager's complete-fallback result is returned instead.
        """
        generation_log = {
            'user_id': user_id,
            'timestamp': datetime.now().isoformat(),
            'stages_completed': [],
            'fallbacks_used': [],
            'success': False,
        }
        # Bind early so the outer except handler can always reference it.
        description = ""

        try:
            # Stage 1: Speech to Text (if audio provided)
            if audio_input and os.path.exists(audio_input):
                try:
                    stt_model = self._lazy_load_model('stt')
                    if stt_model:
                        description = stt_model.transcribe(audio_input)
                        generation_log['stages_completed'].append('stt')
                    else:
                        raise Exception("STT model failed to load")
                except Exception as e:
                    print(f"STT failed: {e}")
                    description = text_input or "Create a friendly digital monster"
                    generation_log['fallbacks_used'].append('stt')
                finally:
                    # Unload STT to free memory
                    self._unload_model('stt')
            else:
                description = text_input or "Create a friendly digital monster"

            # Stage 2: Generate monster characteristics
            monster_traits = {}
            monster_dialogue = ""
            try:
                text_gen = self._lazy_load_model('text_gen')
                if text_gen:
                    monster_traits = text_gen.generate_traits(description)
                    monster_dialogue = text_gen.generate_dialogue(monster_traits)
                    generation_log['stages_completed'].append('text_gen')
                else:
                    raise Exception("Text generation model failed to load")
            except Exception as e:
                print(f"Text generation failed: {e}")
                monster_traits, monster_dialogue = \
                    self.fallback_manager.handle_text_gen_failure(description)
                generation_log['fallbacks_used'].append('text_gen')
            finally:
                self._unload_model('text_gen')

            # Stage 3: Generate monster image
            monster_image = None
            try:
                image_gen = self._lazy_load_model('image_gen')
                if image_gen:
                    # Create enhanced prompt from traits
                    image_prompt = self._create_image_prompt(description, monster_traits)
                    monster_image = image_gen.generate(
                        prompt=image_prompt,
                        reference_images=reference_images,
                        width=512,
                        height=512,
                    )
                    generation_log['stages_completed'].append('image_gen')
                else:
                    raise Exception("Image generation model failed to load")
            except Exception as e:
                print(f"Image generation failed: {e}")
                monster_image = self.fallback_manager.handle_image_gen_failure(description)
                generation_log['fallbacks_used'].append('image_gen')
            finally:
                self._unload_model('image_gen')

            # Stage 4: Convert to 3D model
            model_3d = None
            model_3d_path = None
            try:
                model_3d_gen = self._lazy_load_model('3d_gen')
                if model_3d_gen and monster_image:
                    model_3d = model_3d_gen.image_to_3d(monster_image)
                    # Save 3D model
                    model_3d_path = self._save_3d_model(model_3d, user_id)
                    generation_log['stages_completed'].append('3d_gen')
                else:
                    raise Exception("3D generation failed")
            except Exception as e:
                print(f"3D generation failed: {e}")
                model_3d = self.fallback_manager.handle_3d_gen_failure(monster_image)
                generation_log['fallbacks_used'].append('3d_gen')
            finally:
                self._unload_model('3d_gen')

            # Stage 5: Add rigging (optional — disabled by default for speed)
            rigged_model = model_3d
            if model_3d and self.config.get('enable_rigging', False):
                try:
                    rigging_proc = self._lazy_load_model('rigging')
                    if rigging_proc:
                        rigged_model = rigging_proc.rig_mesh(model_3d)
                        generation_log['stages_completed'].append('rigging')
                except Exception as e:
                    # Rigging is best-effort: keep the unrigged mesh on failure.
                    print(f"Rigging failed: {e}")
                    generation_log['fallbacks_used'].append('rigging')
                finally:
                    self._unload_model('rigging')

            # Prepare download files
            download_files = self._prepare_download_files(
                rigged_model or model_3d, monster_image, user_id
            )

            generation_log['success'] = True

            return {
                'description': description,
                'traits': monster_traits,
                'dialogue': monster_dialogue,
                'image': monster_image,
                'model_3d': model_3d_path,
                'download_files': download_files,
                'generation_log': generation_log,
                'status': 'success',
            }

        except Exception as e:
            generation_log['error'] = str(e)
            print(f"Pipeline error: {e}")
            return self.fallback_generation(description or "digital monster", generation_log)

    def _create_image_prompt(self, base_description: str, traits: Dict) -> str:
        """Build a comma-joined image prompt from the description and traits.

        Appends appearance/personality/color details (when present in
        ``traits``) plus fixed quality keywords.
        """
        prompt_parts = [base_description]

        if traits:
            if 'appearance' in traits:
                prompt_parts.append(traits['appearance'])
            if 'personality' in traits:
                prompt_parts.append(f"with {traits['personality']} personality")
            if 'color_scheme' in traits:
                prompt_parts.append(f"featuring {traits['color_scheme']} colors")

        prompt_parts.extend([
            "digital monster",
            "creature design",
            "game character",
            "high quality",
            "detailed",
        ])
        return ", ".join(prompt_parts)

    def _save_3d_model(self, model_3d, user_id: str) -> Optional[str]:
        """Save a 3D model to persistent storage.

        Args:
            model_3d: The generated mesh object; falsy values are skipped.
            user_id: Embedded into the output filename.

        Returns:
            The saved file path, or None when there is nothing to save.
        """
        if not model_3d:
            return None

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"monster_{user_id}_{timestamp}.glb"

        # Use HuggingFace Spaces persistent storage when /data is mounted.
        # BUG FIX: the original wrote to a literal "(unknown)" path and never
        # used the computed filename.
        if os.path.exists("/data"):
            filepath = f"/data/models/{filename}"
        else:
            filepath = f"./data/models/{filename}"
        os.makedirs(os.path.dirname(filepath), exist_ok=True)

        # Save model (implementation depends on model format)
        # This is a placeholder - actual implementation would depend on model format
        with open(filepath, 'wb') as f:
            if hasattr(model_3d, 'export'):
                model_3d.export(f)
            else:
                # Fallback: persist a textual representation as bytes.
                f.write(str(model_3d).encode())
        return filepath

    def _prepare_download_files(self, model_3d, image, user_id: str) -> List[str]:
        """Prepare downloadable files for the user.

        Writes the image to /tmp; 3D paths are appended without being
        written here.

        Returns:
            List of file paths to offer for download.
        """
        files = []
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        # Save image
        if image:
            if isinstance(image, Image.Image):
                image_path = f"/tmp/monster_{user_id}_{timestamp}.png"
                image.save(image_path)
                files.append(image_path)
            elif isinstance(image, np.ndarray):
                image_path = f"/tmp/monster_{user_id}_{timestamp}.png"
                Image.fromarray(image).save(image_path)
                files.append(image_path)

        # Save 3D model in multiple formats if available
        # NOTE(review): these paths are appended but no file is written here —
        # presumably export happens elsewhere; confirm before shipping.
        if model_3d:
            # GLB format
            glb_path = f"/tmp/monster_{user_id}_{timestamp}.glb"
            files.append(glb_path)
            # OBJ format (optional)
            obj_path = f"/tmp/monster_{user_id}_{timestamp}.obj"
            files.append(obj_path)

        return files

    def fallback_generation(self, description: str, generation_log: Dict) -> Dict[str, Any]:
        """Complete fallback generation when the whole pipeline fails."""
        return self.fallback_manager.complete_fallback_generation(description, generation_log)

    def cleanup(self):
        """Unload every loaded model and free GPU/CPU memory."""
        for model_type in list(self.models.keys()):
            self._unload_model(model_type)
        self._cleanup_memory()