# Spaces: Paused
from __future__ import annotations

import logging
import os
import tempfile
from pathlib import Path
from typing import Union, Optional, Dict, Any

import numpy as np
import torch
import trimesh
from PIL import Image
logger = logging.getLogger(__name__)


class Hunyuan3DGenerator:
    """3D model generation using Hunyuan3D-2.1.

    The model is loaded lazily on first use.  When loading or inference
    fails, a coarse relief mesh built from the image brightness is produced
    instead, so ``image_to_3d`` still yields a mesh file for a readable image.
    """

    # Sentinel stored in ``self.model`` once loading has failed, so the load
    # is not retried on every call.
    _FALLBACK = "fallback"
    # Total device memory (bytes) required before the full model is chosen.
    _MIN_FULL_MODEL_VRAM = 12 * 1024 ** 3

    def __init__(self, device: str = "cuda"):
        """Record configuration; no model weights are loaded here.

        Args:
            device: Requested torch device string (e.g. ``"cuda"``,
                ``"cuda:1"``, ``"cpu"``).  Replaced by ``"cpu"`` when CUDA
                is not available.
        """
        self.device = device if torch.cuda.is_available() else "cpu"
        self.model = None
        self.preprocessor = None

        # Model configuration
        self.model_id = "tencent/Hunyuan3D-2.1"
        self.lite_model_id = "tencent/Hunyuan3D-2.1-Lite"  # For low VRAM

        # Generation parameters
        self.num_inference_steps = 50
        self.guidance_scale = 7.5
        self.resolution = 256  # 3D resolution

        # Use lite model on CPU or when the GPU has too little memory
        self.use_lite = self.device == "cpu" or not self._check_vram()

    def _uses_cuda(self) -> bool:
        """Return True for any CUDA device string, including ``"cuda:N"``."""
        return str(self.device).startswith("cuda")

    def _model_loaded(self) -> bool:
        """Return True when ``self.model`` holds a real model object.

        Avoids truth-testing or equality-testing the model object itself:
        only ``None`` and the fallback sentinel string count as "not loaded".
        """
        model = getattr(self, "model", None)
        if model is None:
            return False
        return not (isinstance(model, str) and model == self._FALLBACK)

    def _check_vram(self) -> bool:
        """Check if we have enough VRAM for the full model.

        Returns:
            True when CUDA is available and the selected device reports at
            least 12 GiB of total memory; False otherwise, including when
            the device cannot be queried.
        """
        if not torch.cuda.is_available():
            return False
        try:
            # Honour an explicit index such as "cuda:1"; default to device 0.
            index = torch.device(str(getattr(self, "device", "cuda"))).index or 0
            total_memory = torch.cuda.get_device_properties(index).total_memory
        except Exception:
            logger.warning("Could not query CUDA device memory", exc_info=True)
            return False
        return total_memory >= self._MIN_FULL_MODEL_VRAM

    def load_model(self):
        """Lazy load the 3D generation model.

        Does nothing if a model (or the fallback sentinel) is already set.
        On any failure the error is logged and ``self.model`` becomes the
        fallback sentinel so later calls go straight to the fallback path.
        """
        if self.model is not None:
            return

        try:
            # Imported here so the class stays usable without transformers.
            from transformers import AutoModel, AutoProcessor

            model_id = self.lite_model_id if self.use_lite else self.model_id
            on_cuda = self._uses_cuda()

            preprocessor = AutoProcessor.from_pretrained(model_id)
            model = AutoModel.from_pretrained(
                model_id,
                torch_dtype=torch.float16 if on_cuda else torch.float32,
                low_cpu_mem_usage=True,
                device_map="auto" if on_cuda else None,
            )
            if not on_cuda:
                model = model.to(self.device)

            # Enable optimizations when the model offers them
            if hasattr(model, "enable_attention_slicing"):
                model.enable_attention_slicing()
        except Exception as e:
            logger.warning("Failed to load Hunyuan3D model: %s", e, exc_info=True)
            # Model loading failed, will use fallback
            self.preprocessor = None
            self.model = self._FALLBACK
        else:
            # Publish both objects together so a half-loaded state is never visible.
            self.preprocessor = preprocessor
            self.model = model

    def image_to_3d(self,
                    image: Union[str, Image.Image, np.ndarray],
                    remove_background: bool = True,
                    texture_resolution: int = 1024) -> Union[str, trimesh.Trimesh]:
        """Convert a 2D image to a 3D model.

        Args:
            image: File path, PIL image, or array accepted by
                ``Image.fromarray``.
            remove_background: Whether to strip the background before
                running the model.
            texture_resolution: Forwarded to the model's ``generate`` call.

        Returns:
            Path of the exported ``.glb`` file.  If the model is unavailable
            or generation raises, the path of the fallback mesh instead.

        Raises:
            Exception: Whatever the fallback raises when the input itself
                cannot be read as an image.
        """
        try:
            if self.model is None:
                self.load_model()
            if self._model_loaded():
                prepared = self._prepare_image(image, remove_background)
                mesh = self._run_model(prepared, texture_resolution)
                return self._save_mesh(mesh)
        except Exception as e:
            logger.warning("3D generation error: %s", e, exc_info=True)

        # The fallback runs outside the try block so it is attempted only
        # once, always from the caller's original input.
        return self._generate_fallback_3d(image)

    @staticmethod
    def _load_image(image: Union[str, os.PathLike, Image.Image, np.ndarray]) -> Image.Image:
        """Return ``image`` as a PIL image, reading it from disk if needed."""
        if isinstance(image, Image.Image):
            return image
        if isinstance(image, (str, os.PathLike)):
            # copy() forces the pixel data to load so the file can be closed.
            with Image.open(image) as opened:
                return opened.copy()
        if isinstance(image, np.ndarray):
            return Image.fromarray(image)
        raise TypeError(f"Unsupported image type: {type(image).__name__}")

    def _prepare_image(self, image, remove_background: bool) -> Image.Image:
        """Load, convert to RGB, resize to 512x512 and optionally cut out the background."""
        prepared = self._load_image(image)
        if prepared.mode != 'RGB':
            prepared = prepared.convert('RGB')
        prepared = prepared.resize((512, 512), Image.Resampling.LANCZOS)
        if remove_background:
            prepared = self._remove_background(prepared)
        return prepared

    def _run_model(self, image: Image.Image, texture_resolution: int) -> trimesh.Trimesh:
        """Run the preprocessor and the model on ``image`` and return the mesh."""
        # With device_map="auto" the weights may not sit on self.device, so
        # prefer the device the model itself reports when it exposes one.
        target_device = getattr(self.model, "device", self.device)
        with torch.no_grad():
            inputs = self.preprocessor(images=image, return_tensors="pt").to(target_device)
            outputs = self.model.generate(
                **inputs,
                num_inference_steps=self.num_inference_steps,
                guidance_scale=self.guidance_scale,
                texture_resolution=texture_resolution,
            )
        return self._extract_mesh(outputs)

    def _remove_background(self, image: Image.Image) -> Image.Image:
        """Remove the background from ``image``.

        Uses ``rembg`` when it can be imported and runs successfully.
        Otherwise every pixel whose R, G and B values all exceed 230 is made
        fully transparent white.

        Returns:
            The ``rembg`` result, or a new RGBA image from the simple filter.
        """
        try:
            from rembg import remove
            return remove(image)
        except Exception:
            logger.debug("rembg unavailable or failed; using white-threshold removal",
                         exc_info=True)

        rgba = np.array(image.convert("RGBA"))  # np.array gives a writable copy
        near_white = (rgba[..., :3] > 230).all(axis=-1)
        rgba[near_white] = (255, 255, 255, 0)
        return Image.fromarray(rgba)

    def _extract_mesh(self, model_outputs: Dict[str, Any]) -> trimesh.Trimesh:
        """Extract a mesh from model outputs.

        Placeholder: assumes the outputs expose ``'vertices'`` and
        ``'faces'`` tensors — TODO confirm against the real Hunyuan3D output
        format.  When either key is missing a placeholder sphere is returned.
        """
        if 'vertices' not in model_outputs or 'faces' not in model_outputs:
            return self._create_simple_mesh()

        vertices = model_outputs['vertices'].detach().cpu().numpy()
        faces = model_outputs['faces'].detach().cpu().numpy()
        mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
        # TODO: apply model_outputs['texture'] to the mesh when it is present.
        return mesh

    def _create_simple_mesh(self) -> trimesh.Trimesh:
        """Create a simple placeholder mesh: a randomly perturbed unit icosphere."""
        mesh = trimesh.creation.icosphere(subdivisions=3, radius=1.0)
        # Add some variation (unseeded, so every call differs)
        mesh.vertices += np.random.normal(0, 0.05, mesh.vertices.shape)
        return mesh.smoothed()

    def _generate_fallback_3d(self, image: Union[str, Image.Image, np.ndarray]) -> str:
        """Generate a fallback 3D model when the main model fails.

        The image is reduced to a 64x64 brightness map (mean of R, G, B
        scaled to 0..1) which is turned into a relief mesh.

        Returns:
            Path of the exported mesh file.
        """
        # Normalising to RGB keeps axis 2 present for grayscale/palette
        # input and keeps an alpha channel out of the brightness average.
        rgb = self._load_image(image).convert("RGB")
        pixels = np.asarray(rgb.resize((64, 64)))
        height_map = np.mean(pixels, axis=2) / 255.0
        return self._save_mesh(self._heightmap_to_mesh(height_map))

    @staticmethod
    def _heightmap_to_arrays(heightmap: np.ndarray) -> tuple:
        """Build ``(vertices, faces)`` arrays for a regular grid heightmap.

        Vertices are laid out row-major with x and y spanning ``[-1, 1)`` and
        z equal to half the height value.  Each grid cell contributes two
        triangles, in row-major cell order.

        Raises:
            ValueError: If ``heightmap`` is not two-dimensional.
        """
        heights = np.asarray(heightmap, dtype=np.float64)
        if heights.ndim != 2:
            raise ValueError(f"heightmap must be 2-D, got {heights.ndim}-D")
        h, w = heights.shape

        rows, cols = np.indices((h, w))
        vertices = np.column_stack((
            ((cols - w / 2) / w * 2).ravel(),
            ((rows - h / 2) / h * 2).ravel(),
            (heights * 0.5).ravel(),
        ))

        # Index of the top-left vertex of every grid cell, row-major.
        top_left = (np.arange(h - 1)[:, None] * w + np.arange(w - 1)[None, :]).ravel()
        first = np.stack((top_left, top_left + 1, top_left + w), axis=1)
        second = np.stack((top_left + 1, top_left + w + 1, top_left + w), axis=1)
        # Interleave so each cell's two triangles stay adjacent.
        faces = np.stack((first, second), axis=1).reshape(-1, 3)
        return vertices, faces

    def _heightmap_to_mesh(self, heightmap: np.ndarray) -> trimesh.Trimesh:
        """Convert a 2-D heightmap to a smoothed 3D mesh."""
        vertices, faces = self._heightmap_to_arrays(heightmap)
        mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
        return mesh.smoothed()

    def _save_mesh(self, mesh: trimesh.Trimesh) -> str:
        """Export ``mesh`` to a new temporary ``.glb`` file and return its path.

        The file is not deleted automatically; the caller owns it.  If the
        export fails the temporary file is removed before re-raising.
        """
        # Reserve a unique path, then close the handle before exporting.
        with tempfile.NamedTemporaryFile(suffix='.glb', delete=False) as tmp:
            mesh_path = tmp.name
        try:
            mesh.export(mesh_path)
        except Exception:
            try:
                os.remove(mesh_path)
            except OSError:
                pass
            raise
        return mesh_path

    def text_to_3d(self, text_prompt: str) -> str:
        """Generate a 3D model from a text description.

        Raises:
            NotImplementedError: Always; an image generator must produce an
                image first, and none is integrated here.
        """
        raise NotImplementedError("Text to 3D requires image generation first")

    def to(self, device: str):
        """Move the model to the specified device and remember the device."""
        self.device = device
        if self._model_loaded():
            self.model.to(device)

    def __del__(self):
        """Cleanup when the object is destroyed: drop references, free CUDA cache."""
        # A finalizer may run on a half-built instance or during interpreter
        # shutdown, so it must never raise.
        try:
            self.model = None
            self.preprocessor = None
            if torch is not None and torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:
            pass