# File size: 9,862 Bytes
# fe24641
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
import torch
import numpy as np
from PIL import Image
import trimesh
import tempfile
from typing import Union, Optional, Dict, Any
from pathlib import Path
import os

class Hunyuan3DGenerator:
    """3D model generation using Hunyuan3D-2.1.

    The heavy model is loaded lazily on the first call to ``image_to_3d``
    (or an explicit ``load_model``).  When the model cannot be loaded, or
    generation fails, a simple height-map relief built from the image
    brightness is produced instead, so ``image_to_3d`` still returns a mesh
    file whenever the input image itself is readable.
    """

    # Sentinel stored in ``self.model`` after a failed load so that the
    # (expensive) load is not retried on every call.
    _FALLBACK = "fallback"

    # Total VRAM, in bytes, required before the full (non-lite) model is used.
    _MIN_FULL_MODEL_VRAM = 12 * 1024 ** 3

    def __init__(self, device: str = "cuda"):
        """Record configuration; no model weights are loaded here.

        Args:
            device: Requested torch device string.  It is replaced by
                ``"cpu"`` when CUDA is not available.
        """
        self.device = device if torch.cuda.is_available() else "cpu"
        self.model = None
        self.preprocessor = None

        # Model configuration
        self.model_id = "tencent/Hunyuan3D-2.1"
        self.lite_model_id = "tencent/Hunyuan3D-2.1-Lite"  # For low VRAM

        # Generation parameters
        self.num_inference_steps = 50
        self.guidance_scale = 7.5
        self.resolution = 256  # 3D resolution

        # Use lite model on CPU or when VRAM is insufficient
        self.use_lite = self.device == "cpu" or not self._check_vram()

    @staticmethod
    def _is_cuda_device(device) -> bool:
        """Return True for ``"cuda"`` as well as indexed forms like ``"cuda:1"``."""
        return str(device).split(":", 1)[0] == "cuda"

    def _has_real_model(self) -> bool:
        """Return True when ``self.model`` is a loaded model, not None/sentinel."""
        model = getattr(self, "model", None)
        # The fallback sentinel is the only string ever stored in ``model``.
        return model is not None and not isinstance(model, str)

    def _check_vram(self) -> bool:
        """Check if we have enough VRAM for the full model.

        Returns:
            True when CUDA is available and the selected GPU reports at
            least 12 GiB of total memory; False otherwise, including when
            the device properties cannot be queried.
        """
        if not torch.cuda.is_available():
            return False

        # Inspect the GPU that was actually requested; fall back to GPU 0
        # when the configured device is not a CUDA device.
        device = self.device if self._is_cuda_device(self.device) else 0
        try:
            vram = torch.cuda.get_device_properties(device).total_memory
        except Exception:
            # Unknown/invalid device or driver problem: be conservative.
            return False
        return vram >= self._MIN_FULL_MODEL_VRAM

    def load_model(self):
        """Lazily load the 3D generation model and its preprocessor.

        Does nothing when a model (or the fallback sentinel) is already
        set.  Never raises: on any failure the error is printed,
        ``self.preprocessor`` is reset to None and ``self.model`` is set to
        the ``"fallback"`` sentinel.
        """
        if self.model is not None:
            return

        try:
            # Imported here so the module can be imported without transformers.
            from transformers import AutoModel, AutoProcessor

            model_id = self.lite_model_id if self.use_lite else self.model_id
            on_cuda = self._is_cuda_device(self.device)

            preprocessor = AutoProcessor.from_pretrained(model_id)

            # Half precision and automatic placement on GPU; full precision
            # and explicit placement everywhere else.
            model = AutoModel.from_pretrained(
                model_id,
                torch_dtype=torch.float16 if on_cuda else torch.float32,
                low_cpu_mem_usage=True,
                device_map="auto" if on_cuda else None,
            )
            if not on_cuda:
                model = model.to(self.device)

            # Enable optimizations when the model offers them
            if hasattr(model, 'enable_attention_slicing'):
                model.enable_attention_slicing()
        except Exception as e:
            print(f"Failed to load Hunyuan3D model: {e}")
            # Do not keep a half-initialised pair around.
            self.preprocessor = None
            self.model = self._FALLBACK
        else:
            # Publish both objects only after everything succeeded.
            self.preprocessor = preprocessor
            self.model = model

    @staticmethod
    def _load_rgb_image(image) -> "Image.Image":
        """Return ``image`` as an RGB PIL image.

        Args:
            image: A filesystem path (``str`` or ``os.PathLike``), a numpy
                array accepted by ``Image.fromarray``, or a PIL image.

        Raises:
            TypeError: If ``image`` is none of the supported types.
        """
        if isinstance(image, (str, os.PathLike)):
            # convert() reads the pixel data into a new image, so the file
            # can be closed as soon as the block exits.
            with Image.open(image) as opened:
                return opened.convert('RGB')
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if not isinstance(image, Image.Image):
            raise TypeError(f"Unsupported image type: {type(image).__name__}")
        return image if image.mode == 'RGB' else image.convert('RGB')

    def image_to_3d(self,
                    image: "Union[str, Image.Image, np.ndarray]",
                    remove_background: bool = True,
                    texture_resolution: int = 1024) -> "Union[str, trimesh.Trimesh]":
        """Convert a 2D image to a 3D model.

        Args:
            image: Path to an image file, a PIL image, or a numpy array.
            remove_background: Strip the background before generation.
            texture_resolution: Forwarded to the model's ``generate`` call.

        Returns:
            Path of the exported ``.glb`` file.  When the model is
            unavailable or generation fails, the file contains the
            height-map fallback mesh instead.

        Raises:
            Exception: Whatever the fallback itself raises (for example an
                unreadable image path); nothing is left to fall back to.
        """
        # load_model() never raises; it leaves the sentinel behind on failure.
        if self.model is None:
            self.load_model()

        # Kept outside the try block so a failing fallback is not retried.
        if not self._has_real_model():
            return self._generate_fallback_3d(image)

        fallback_input = image
        try:
            prepared = self._load_rgb_image(image)
            # From here on the fallback can reuse the decoded image.
            fallback_input = prepared

            # Resize for processing
            prepared = prepared.resize((512, 512), Image.Resampling.LANCZOS)

            if remove_background:
                prepared = self._remove_background(prepared)

            with torch.no_grad():
                inputs = self.preprocessor(images=prepared, return_tensors="pt").to(self.device)

                outputs = self.model.generate(
                    **inputs,
                    num_inference_steps=self.num_inference_steps,
                    guidance_scale=self.guidance_scale,
                    texture_resolution=texture_resolution
                )

                mesh = self._extract_mesh(outputs)

            return self._save_mesh(mesh)

        except Exception as e:
            print(f"3D generation error: {e}")
            return self._generate_fallback_3d(fallback_input)

    def _remove_background(self, image: "Image.Image") -> "Image.Image":
        """Remove the background from ``image``.

        Uses ``rembg`` when it is installed and working.  Otherwise every
        pixel whose R, G and B values are all above 230 is replaced by
        fully transparent white.

        Returns:
            The ``rembg`` result, or an RGBA image from the threshold
            fallback.
        """
        try:
            from rembg import remove
            return remove(image)
        except Exception as e:
            # A missing rembg is the expected case and stays quiet; any
            # other failure is reported before falling back (best effort).
            if not isinstance(e, ImportError):
                print(f"rembg background removal failed, using threshold fallback: {e}")

        # np.array() copies the pixel data, so the input image is untouched.
        rgba = np.array(image.convert("RGBA"))
        near_white = (rgba[..., :3] > 230).all(axis=-1)
        rgba[near_white] = (255, 255, 255, 0)
        return Image.fromarray(rgba)

    def _extract_mesh(self, model_outputs: "Dict[str, Any]") -> "trimesh.Trimesh":
        """Extract a mesh from the model outputs.

        Placeholder implementation: the real Hunyuan3D output format is
        not handled yet.  When the outputs contain ``'vertices'`` and
        ``'faces'`` tensors they are turned into a mesh; otherwise a
        placeholder sphere is returned.
        """
        if 'vertices' in model_outputs and 'faces' in model_outputs:
            # detach() makes the conversion safe even for tensors that
            # still carry autograd history.
            vertices = model_outputs['vertices'].detach().cpu().numpy()
            faces = model_outputs['faces'].detach().cpu().numpy()

            # TODO: apply model_outputs['texture'] to the mesh when present.
            return trimesh.Trimesh(vertices=vertices, faces=faces)

        # Make it visible that the result is not derived from the input.
        print("Model output has no 'vertices'/'faces'; using placeholder mesh")
        return self._create_simple_mesh()

    def _create_simple_mesh(self) -> "trimesh.Trimesh":
        """Create a placeholder mesh: a unit icosphere with random jitter."""
        mesh = trimesh.creation.icosphere(subdivisions=3, radius=1.0)

        # Add some variation (uses numpy's global random state)
        mesh.vertices += np.random.normal(0, 0.05, mesh.vertices.shape)

        # NOTE(review): trimesh's smoothed() returns a smooth-shaded copy;
        # it is not a geometric smoothing filter -- confirm that is intended.
        return mesh.smoothed()

    def _generate_fallback_3d(self, image: "Union[Image.Image, np.ndarray]") -> str:
        """Generate the fallback 3D model used when the main model fails.

        The image is reduced to a 64x64 brightness map (mean of R, G and
        B), which is turned into a relief mesh.

        Args:
            image: PIL image, numpy array, or path to an image file.  It
                is converted to RGB first, so grayscale, palette and RGBA
                inputs are accepted.

        Returns:
            Path of the exported ``.glb`` file.
        """
        rgb = self._load_rgb_image(image)

        # Analyze image for basic shape
        pixels = np.asarray(rgb.resize((64, 64)), dtype=np.float64)

        # Create height map from image brightness, scaled to [0, 1]
        height_map = pixels.mean(axis=2) / 255.0

        mesh = self._heightmap_to_mesh(height_map)
        return self._save_mesh(mesh)

    @staticmethod
    def _heightmap_geometry(heightmap):
        """Build the vertex and face arrays of a height-map grid.

        Args:
            heightmap: 2D array-like of shape ``(h, w)``.

        Returns:
            ``(vertices, faces)``.  ``vertices`` is ``(h*w, 3)`` float64 in
            row-major order (vertex ``i*w + j`` belongs to cell ``(i, j)``)
            with x and y spanning ``[-1, 1)`` and ``z = height * 0.5``.
            ``faces`` is ``(2*(h-1)*(w-1), 3)`` integer, two triangles per
            grid square; it is empty, with shape ``(0, 3)``, when the grid
            has fewer than two rows or columns.

        Raises:
            ValueError: If ``heightmap`` is not two-dimensional.
        """
        heights = np.asarray(heightmap, dtype=np.float64)
        if heights.ndim != 2:
            raise ValueError(f"heightmap must be 2D, got {heights.ndim}D")
        h, w = heights.shape

        xs = (np.arange(w) - w / 2) / w * 2
        ys = (np.arange(h) - h / 2) / h * 2
        # Default 'xy' indexing gives (h, w) grids: x varies along columns.
        grid_x, grid_y = np.meshgrid(xs, ys)
        vertices = np.stack([grid_x, grid_y, heights * 0.5], axis=-1).reshape(h * w, 3)

        # Index of the top-left corner of every grid square, row by row.
        top_left = (np.arange(h - 1)[:, None] * w + np.arange(w - 1)[None, :]).ravel()
        top_right = top_left + 1
        bottom_left = top_left + w
        bottom_right = bottom_left + 1

        # Stacking on axis 1 keeps both triangles of a square adjacent.
        faces = np.stack(
            [
                np.stack([top_left, top_right, bottom_left], axis=1),
                np.stack([top_right, bottom_right, bottom_left], axis=1),
            ],
            axis=1,
        ).reshape(2 * top_left.size, 3)

        return vertices, faces

    def _heightmap_to_mesh(self, heightmap: np.ndarray) -> "trimesh.Trimesh":
        """Convert a 2D heightmap to a 3D mesh.

        Raises:
            ValueError: If ``heightmap`` is not two-dimensional.
        """
        vertices, faces = self._heightmap_geometry(heightmap)
        mesh = trimesh.Trimesh(vertices=vertices, faces=faces)

        # NOTE(review): smoothed() affects shading rather than geometry --
        # confirm against the installed trimesh version.
        return mesh.smoothed()

    def _save_mesh(self, mesh: "trimesh.Trimesh") -> str:
        """Export ``mesh`` to a new temporary ``.glb`` file.

        Returns:
            Path of the written file.  The caller owns the file and is
            responsible for deleting it.

        Raises:
            Exception: Whatever ``mesh.export`` raises; the temporary file
                is removed before the error propagates.
        """
        # Only the unique name is needed; the export below reopens the path.
        with tempfile.NamedTemporaryFile(suffix='.glb', delete=False) as tmp:
            mesh_path = tmp.name

        try:
            mesh.export(mesh_path)
        except Exception:
            # Do not leave an empty or partial file behind.
            try:
                os.remove(mesh_path)
            except OSError:
                pass
            raise

        return mesh_path

    def text_to_3d(self, text_prompt: str) -> str:
        """Generate a 3D model from a text description (not implemented).

        Raises:
            NotImplementedError: Always; an image generator would have to
                produce an image from the prompt first.
        """
        raise NotImplementedError("Text to 3D requires image generation first")

    def to(self, device: str):
        """Record ``device`` and move the model there if one is loaded."""
        self.device = device
        if self._has_real_model():
            self.model.to(device)

    def __del__(self):
        """Drop model references and release cached CUDA memory.

        Written to be safe on partially constructed instances and during
        interpreter shutdown, when module globals may already be gone.
        """
        self.model = None
        self.preprocessor = None
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:
            # Nothing useful can be done about cleanup errors here.
            pass