File size: 20,390 Bytes
7929cac
 
 
 
 
 
 
 
 
 
be8e2f4
7929cac
 
be8e2f4
7929cac
 
 
 
 
 
 
 
 
 
 
 
 
8586183
083ae55
7929cac
 
 
 
 
 
 
be8e2f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7929cac
 
 
 
 
 
 
 
 
be8e2f4
 
7929cac
 
 
 
 
 
 
 
 
 
 
 
 
be8e2f4
7929cac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be8e2f4
7929cac
 
be8e2f4
 
 
7929cac
be8e2f4
 
 
7929cac
 
be8e2f4
7929cac
 
be8e2f4
 
 
 
7929cac
be8e2f4
7929cac
 
 
be8e2f4
7929cac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be8e2f4
7929cac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be8e2f4
7929cac
 
 
 
 
 
 
 
 
 
 
 
 
be8e2f4
 
 
 
7929cac
be8e2f4
 
 
7929cac
 
be8e2f4
 
 
7929cac
 
 
 
be8e2f4
 
 
7929cac
 
 
 
 
be8e2f4
 
 
 
 
7929cac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be8e2f4
 
 
 
7929cac
 
 
 
 
 
 
 
 
 
 
 
be8e2f4
 
 
 
7929cac
 
 
 
 
 
be8e2f4
7929cac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be8e2f4
7929cac
 
 
 
 
 
 
 
be8e2f4
7929cac
 
 
 
 
 
 
be8e2f4
7929cac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be8e2f4
7929cac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be8e2f4
 
7929cac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
import logging
from typing import Tuple, Optional
import numpy as np
from PIL import Image, ImageFilter
import gradio as gr
from transformers import pipeline

# OpenCV is optional: when present it accelerates blurring, resizing and
# edge detection; otherwise the code falls back to PIL/numpy equivalents.
# (The previous `from cv2 import GaussianBlur, bilateralFilter` imported
# names that were never used — every call site goes through `cv2.<fn>`.)
try:
    import cv2
    CV2_AVAILABLE = True
except ImportError:
    cv2 = None
    CV2_AVAILABLE = False

# Configure logging
# INFO level so per-tile / per-scale progress messages are visible by default.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class EnhancedChromoStereoizer:
    """
    Advanced depth estimation with multi-scale fusion, gradient-preserving normalization,
    and edge-aware blending for maximum detail preservation.

    Workflow:
      1. ``generate_depth_map`` runs a Hugging Face depth-estimation pipeline
         on the whole image, optionally refines the result tile-by-tile
         ("Enhanced Tiled") or across scales ("Multi-Scale Fusion"), and
         caches the normalized depth map.
      2. ``apply_effect`` maps the cached depth to a red/blue
         chromostereopsis image; it can be re-run cheaply as UI sliders move.

    Depth maps are float32 numpy arrays normalized to [0, 1] throughout.
    """
    def __init__(
        self,
        model_name: str = "depth-anything/Depth-Anything-V2-Small-hf",
        tile_size: int = 518,  # Smaller tiles for more detail
        overlap_ratio: float = 0.5  # Higher overlap for better blending
    ):
        # Heavy: constructing the pipeline downloads/loads the model.
        self.depth_pipe = pipeline("depth-estimation", model=model_name)
        self.tile_size = tile_size
        self.overlap_ratio = overlap_ratio
        # Cache of the last processed image and its normalized depth map so
        # slider changes can re-render the effect without re-running inference.
        self.last_original: Optional[Image.Image] = None
        self.last_depth_norm: Optional[np.ndarray] = None

    def _gaussian_filter(self, image: np.ndarray, sigma: float = 1.0) -> np.ndarray:
        """Gaussian-blur ``image`` with standard deviation ``sigma``.

        Uses OpenCV when available; otherwise falls back to PIL for 2-D
        inputs (values are assumed to lie in [0, 1] for the fallback's
        uint8 round-trip). Multichannel input without OpenCV is returned
        unchanged — callers only pass 2-D depth/weight maps here.
        """
        if CV2_AVAILABLE:
            # Kernel must be odd; ~6*sigma covers the Gaussian support.
            kernel_size = max(3, int(6 * sigma + 1))
            if kernel_size % 2 == 0:
                kernel_size += 1
            return cv2.GaussianBlur(image.astype(np.float32), (kernel_size, kernel_size), sigma)
        else:
            # Fallback using PIL
            if len(image.shape) == 2:
                pil_img = Image.fromarray((image * 255).astype(np.uint8))
                blurred = pil_img.filter(ImageFilter.GaussianBlur(radius=sigma))
                return np.array(blurred, dtype=np.float32) / 255.0
            else:
                return image  # Return original if can't process

    def _sobel_edge_detection(self, image: np.ndarray) -> np.ndarray:
        """Return the Sobel gradient magnitude sqrt(gx^2 + gy^2) of a 2-D image.

        Fix: the OpenCV branch previously called ``cv2.Sobel(..., 1, 1)``,
        which computes the mixed second derivative d2/(dx dy) — not the
        gradient magnitude the numpy fallback produced. Both branches now
        agree. The fallback is vectorized (the original looped per pixel
        in Python).
        """
        img32 = image.astype(np.float32)
        if CV2_AVAILABLE:
            grad_x = cv2.Sobel(img32, cv2.CV_32F, 1, 0, ksize=3)
            grad_y = cv2.Sobel(img32, cv2.CV_32F, 0, 1, ksize=3)
            return np.sqrt(grad_x**2 + grad_y**2)
        # Edge-replicated pad, then express each 3x3 Sobel correlation as
        # sums of shifted views — O(H*W) numpy ops instead of Python loops.
        p = np.pad(img32, 1, mode='edge')
        grad_x = ((p[:-2, 2:] + 2.0 * p[1:-1, 2:] + p[2:, 2:])
                  - (p[:-2, :-2] + 2.0 * p[1:-1, :-2] + p[2:, :-2]))
        grad_y = ((p[2:, :-2] + 2.0 * p[2:, 1:-1] + p[2:, 2:])
                  - (p[:-2, :-2] + 2.0 * p[:-2, 1:-1] + p[:-2, 2:]))
        return np.sqrt(grad_x**2 + grad_y**2)

    def _percentile_normalize(self, depth_map: np.ndarray, p_low: float = 2, p_high: float = 98) -> np.ndarray:
        """Robustly normalize to [0, 1], clipping outliers beyond the given percentiles."""
        low, high = np.percentile(depth_map, [p_low, p_high])
        # max(..., 1e-6) guards against a flat (constant) depth map.
        normalized = np.clip((depth_map - low) / max(high - low, 1e-6), 0, 1)
        return normalized

    def _extract_high_freq_details(self, tile_depth: np.ndarray, global_depth: np.ndarray, sigma: float = 2.0) -> np.ndarray:
        """Add the tile's high-frequency detail on top of the global structure.

        The detail layer is ``tile - blur(tile)``; only half of it is added so
        tiling artifacts are not amplified. (A previously computed
        ``blur(global)`` was never used and has been removed.)
        """
        tile_low = self._gaussian_filter(tile_depth, sigma=sigma)
        tile_details = tile_depth - tile_low
        enhanced = global_depth + tile_details * 0.5  # Adjust strength as needed
        return enhanced

    def _histogram_match_local(self, tile_depth: np.ndarray, global_region: np.ndarray, 
                              preserve_details: bool = True) -> np.ndarray:
        """Histogram-match the tile to the global region.

        With ``preserve_details`` the matching is applied to a smoothed copy
        and the tile's own high-frequency residual is re-added (scaled 0.7)
        so fine structure survives the tone remapping. Result is clipped to
        [0, 1].
        """
        if preserve_details:
            tile_smooth = self._gaussian_filter(tile_depth, sigma=1.5)
            details = tile_depth - tile_smooth
            # Match only the low-frequency component to the global tones.
            matched_smooth = self._histogram_match(tile_smooth, global_region)
            result = matched_smooth + details * 0.7
        else:
            result = self._histogram_match(tile_depth, global_region)
        
        return np.clip(result, 0, 1)

    def _histogram_match(self, source: np.ndarray, template: np.ndarray) -> np.ndarray:
        """Remap ``source`` values so their distribution approximates ``template``'s.

        Quantile-based: unique source values are mapped through linearly
        spaced quantiles onto interpolated unique template values.
        """
        source_flat = source.flatten()
        template_flat = template.flatten()
        
        # Unique sorted values; `source_indices` maps each pixel back to its
        # value's rank so we can substitute remapped values per pixel.
        source_values, source_indices = np.unique(source_flat, return_inverse=True)
        template_values = np.unique(template_flat)
        
        source_quantiles = np.linspace(0, 1, len(source_values))
        template_quantiles = np.linspace(0, 1, len(template_values))
        
        interp_values = np.interp(source_quantiles, template_quantiles, template_values)
        
        matched_flat = interp_values[source_indices]
        return matched_flat.reshape(source.shape)

    def _edge_aware_blend(self, tile: np.ndarray, global_region: np.ndarray, 
                         weight_map: np.ndarray, edge_map: np.ndarray) -> np.ndarray:
        """Blend tile over global, favoring the tile (weight 0.8) on strong edges.

        ``edge_map`` is the Sobel magnitude of the source image; wherever it
        exceeds the threshold the tile's sharper depth wins regardless of the
        seamless border weights.
        """
        edge_threshold = 0.1
        edge_weights = np.where(edge_map > edge_threshold, 0.8, weight_map)
        
        blended = tile * edge_weights + global_region * (1 - edge_weights)
        return blended

    def _create_seamless_weights(self, h: int, w: int, blend_width: int = 32) -> np.ndarray:
        """Build an (h, w) weight map fading linearly to 0 at all four borders.

        The fade is capped at half the smaller dimension, then the whole map
        is Gaussian-smoothed so overlapping tiles blend without visible seams.
        """
        weights = np.ones((h, w), dtype=np.float32)
        
        # Linear ramp from 0 at the border up to 1 at `blend_width` pixels in.
        for i in range(min(blend_width, min(h, w) // 2)):
            alpha = i / blend_width
            # Top and bottom
            if i < h:
                weights[i, :] *= alpha
                weights[-(i+1), :] *= alpha
            # Left and right  
            if i < w:
                weights[:, i] *= alpha
                weights[:, -(i+1)] *= alpha
        
        # Soften the piecewise-linear ramp into a smooth falloff.
        weights = self._gaussian_filter(weights, sigma=blend_width/6)
        return weights

    def _guided_filter_simple(self, depth: np.ndarray, guide: np.ndarray, radius: int = 8) -> np.ndarray:
        """Edge-preserving smoothing of ``depth`` (bilateral filter when OpenCV is available).

        NOTE(review): ``guide`` is currently unused in both branches — this is
        a bilateral approximation of a guided filter, not a true guided
        filter. Kept in the signature for future use / interface stability.
        """
        if CV2_AVAILABLE:
            # Bilateral filter needs uint8 input; round-trip through [0, 255].
            depth_uint8 = (depth * 255).astype(np.uint8)
            filtered = cv2.bilateralFilter(depth_uint8, radius, 50, 50)
            return filtered.astype(np.float32) / 255.0
        else:
            # Fallback to Gaussian filter
            return self._gaussian_filter(depth, sigma=radius/3)

    def generate_depth_map(self, img: Image.Image, mode: str) -> Tuple[Optional[Image.Image], Optional[Image.Image]]:
        """Run depth inference on ``img`` using the selected ``mode``.

        Modes: "Enhanced Tiled", "Multi-Scale Fusion", anything else = plain
        global inference. Caches the original image and normalized depth,
        then renders a default-parameter chromostereopsis preview.

        Returns ``(depth_image, chromo_image)``; ``(None, None)`` on no input
        or inference failure.
        """
        if img is None:
            self.last_original = None
            self.last_depth_norm = None
            return None, None

        self.last_original = img
        W, H = img.size
        
        # Grayscale [0, 1] copy used for Sobel edge guidance during tiling.
        img_gray = np.array(img.convert('L'), dtype=np.float32) / 255.0

        # 1. Global depth map, resized back to the input resolution.
        try:
            result_global = self.depth_pipe(img)
            raw_global = np.array(result_global["depth"], dtype=np.float32)
            if CV2_AVAILABLE:
                raw_global = cv2.resize(raw_global, (W, H), interpolation=cv2.INTER_LINEAR)
            else:
                pil_global = Image.fromarray(raw_global)
                pil_global = pil_global.resize((W, H), resample=Image.BILINEAR)
                raw_global = np.array(pil_global, dtype=np.float32)
        except Exception as e:
            logger.error(f"Global depth inference failed: {e}")
            return None, None

        # 2. Normalize, then optionally refine.
        global_normalized = self._percentile_normalize(raw_global)

        if mode == "Enhanced Tiled":
            final_depth = self._process_enhanced_tiled(img, img_gray, global_normalized, W, H)
        elif mode == "Multi-Scale Fusion":
            final_depth = self._process_multiscale_fusion(img, img_gray, global_normalized, W, H)
        else:
            final_depth = global_normalized

        self.last_depth_norm = final_depth
        depth_img = Image.fromarray((final_depth * 255).astype(np.uint8))

        # Preview with the UI sliders' default values.
        chromo = self.apply_effect(50, 50, 10, 50, 50, 50, 0, 100, 0)
        return depth_img.convert('RGB'), chromo

    def _process_enhanced_tiled(self, img: Image.Image, img_gray: np.ndarray, 
                               global_depth: np.ndarray, W: int, H: int) -> np.ndarray:
        """Refine the global depth with overlapping per-tile inference passes.

        Each tile is depth-inferred, histogram-matched to the global map,
        detail-enhanced, then blended in with seamless border weights and
        edge-aware overrides. Weighted accumulation resolves the overlaps;
        a bilateral pass smooths residual seams. Result is clipped to [0, 1].
        """
        # Edge detection for guidance
        edges = self._sobel_edge_detection(img_gray)
        
        # Weighted-sum accumulators; divided at the end.
        accum = np.zeros((H, W), dtype=np.float32)
        weight_total = np.zeros((H, W), dtype=np.float32)
        
        ts = self.tile_size
        stride = int(ts * (1 - self.overlap_ratio))
        
        # Tile origins at `stride` spacing...
        x_positions = list(range(0, W - ts + 1, stride))
        y_positions = list(range(0, H - ts + 1, stride))
        
        # ...plus a final tile flush against each far edge so the whole image
        # is covered (also handles images smaller than one tile).
        if len(x_positions) == 0 or x_positions[-1] + ts < W:
            x_positions.append(max(0, W - ts))
        if len(y_positions) == 0 or y_positions[-1] + ts < H:
            y_positions.append(max(0, H - ts))

        processed_tiles = 0
        total_tiles = len(x_positions) * len(y_positions)
        
        for y in y_positions:
            for x in x_positions:
                processed_tiles += 1
                logger.info(f"Processing tile {processed_tiles}/{total_tiles} at ({x},{y})")
                
                # Clamp the tile to the image bounds.
                x_end, y_end = min(x + ts, W), min(y + ts, H)
                tile_w, tile_h = x_end - x, y_end - y
                
                if tile_w <= 0 or tile_h <= 0:
                    continue
                
                # Crop image tile
                tile_img = img.crop((x, y, x_end, y_end))
                
                # Pad undersized edge tiles to the model's tile size using the
                # tile's mean color, to avoid hard black borders skewing depth.
                if tile_w != ts or tile_h != ts:
                    tile_array = np.array(tile_img)
                    mean_color = tuple(map(int, np.mean(tile_array.reshape(-1, tile_array.shape[-1]), axis=0)))
                    
                    padded_tile = Image.new('RGB', (ts, ts), color=mean_color)
                    padded_tile.paste(tile_img, (0, 0))
                    tile_img = padded_tile

                # Per-tile inference; on failure, fall back to the global map
                # with a small weight so coverage is still complete.
                try:
                    tile_result = self.depth_pipe(tile_img)
                    tile_raw = np.array(tile_result["depth"], dtype=np.float32)
                    
                    # Discard the padded region.
                    tile_depth = tile_raw[:tile_h, :tile_w]
                    
                    # Matching global / edge regions for this tile.
                    global_region = global_depth[y:y_end, x:x_end]
                    edge_region = edges[y:y_end, x:x_end]
                    
                    # Align the tile's tonal range with the global map while
                    # keeping its local detail.
                    tile_normalized = self._histogram_match_local(
                        self._percentile_normalize(tile_depth), 
                        global_region, 
                        preserve_details=True
                    )
                    
                    # Inject the tile's high-frequency detail into the global structure.
                    tile_enhanced = self._extract_high_freq_details(
                        tile_normalized, global_region, sigma=1.5
                    )
                    
                    # Border-faded weights sized to the tile.
                    weight_map = self._create_seamless_weights(
                        tile_h, tile_w, 
                        blend_width=min(32, min(tile_h, tile_w)//4)
                    )
                    
                    # Edge-aware blending
                    tile_final = self._edge_aware_blend(
                        tile_enhanced, global_region, weight_map, edge_region
                    )
                    
                    # Accumulate
                    accum[y:y_end, x:x_end] += tile_final * weight_map
                    weight_total[y:y_end, x:x_end] += weight_map

                except Exception as e:
                    logger.error(f"Tile processing failed at ({x},{y}): {e}")
                    # Use global region as fallback
                    fallback_weight = np.ones((tile_h, tile_w), dtype=np.float32) * 0.1
                    accum[y:y_end, x:x_end] += global_depth[y:y_end, x:x_end] * fallback_weight
                    weight_total[y:y_end, x:x_end] += fallback_weight
                    continue

        # Normalize by the accumulated weights; any zero-weight pixels keep
        # the global depth (supplied via `out=`).
        final_depth = np.divide(accum, weight_total, out=global_depth.copy(), where=weight_total > 0)
        
        # Post-processing with guided filtering
        final_depth = self._guided_filter_simple(final_depth, img_gray, radius=4)
        
        return np.clip(final_depth, 0, 1)

    def _process_multiscale_fusion(self, img: Image.Image, img_gray: np.ndarray, 
                                  global_depth: np.ndarray, W: int, H: int) -> np.ndarray:
        """Fuse high-frequency detail from depth passes at several image scales.

        For each non-unit scale the image is resized, depth-inferred, resized
        back, and its detail layer (depth minus blur) is added to the global
        map with per-scale strength 0.3 / len(scales). Failures at a scale
        are logged and skipped. Result is clipped to [0, 1].
        """
        scales = [0.5, 0.75, 1.0, 1.25]  # Different processing scales
        fused_depth = global_depth.copy()
        
        for scale in scales:
            if scale == 1.0:
                # The 1.0 pass is the global map itself.
                continue
                
            new_w, new_h = int(W * scale), int(H * scale)
            if new_w < 64 or new_h < 64:  # Skip very small scales
                continue
                
            logger.info(f"Processing scale {scale}")
            scaled_img = img.resize((new_w, new_h), Image.BILINEAR)
            
            try:
                scale_result = self.depth_pipe(scaled_img)
                scale_depth = np.array(scale_result["depth"], dtype=np.float32)
                
                # Resize back to original
                if CV2_AVAILABLE:
                    scale_depth = cv2.resize(scale_depth, (W, H), interpolation=cv2.INTER_LINEAR)
                else:
                    scale_pil = Image.fromarray(scale_depth)
                    scale_depth = np.array(scale_pil.resize((W, H), Image.BILINEAR), dtype=np.float32)
                
                # Detail layer = normalized depth minus its low-pass version.
                scale_normalized = self._percentile_normalize(scale_depth)
                details = scale_normalized - self._gaussian_filter(scale_normalized, sigma=2.0)
                
                # Add scaled details to fusion
                detail_strength = 0.3 / len(scales)  # Adjust strength
                fused_depth += details * detail_strength
                
            except Exception as e:
                logger.error(f"Multi-scale processing failed at {scale}: {e}")
                continue
        
        return np.clip(fused_depth, 0, 1)

    def apply_effect(self, threshold_perc, depth_scale, feather_perc,
                    red_b, blue_b, gamma_perc, black_perc, white_perc, smooth_perc) -> Optional[Image.Image]:
        """Render the red/blue chromostereopsis image from the cached depth.

        All parameters are 0–100 slider percentages: depth threshold, sigmoid
        steepness, feathering, red/blue intensity (50 = unity gain), gamma,
        black/white levels, and depth pre-smoothing. Returns None until
        ``generate_depth_map`` has populated the cache.
        """
        if self.last_original is None or self.last_depth_norm is None:
            return None
            
        gray = np.array(self.last_original.convert('L'), dtype=np.float32)
        
        # Levels: map [black, white] (percent of 255) onto [0, 1].
        black = black_perc * 2.55
        white = white_perc * 2.55
        adj = np.clip((gray - black) / max(white - black, 1e-6), 0, 1)
        
        # Gamma in [0.1, 3.0]; 50% => 1.55.
        gamma_v = 0.1 + (gamma_perc / 100.0) * 2.9
        adj = np.clip(adj ** gamma_v, 0, 1)
        
        # Optional depth smoothing before thresholding.
        depth_sm = self.last_depth_norm
        if smooth_perc > 0:
            sigma = smooth_perc / 100.0 * 3.0
            depth_sm = self._gaussian_filter(depth_sm, sigma=sigma)
        
        # Sigmoid around the threshold; feathering flattens the transition.
        thr = threshold_perc / 100.0
        steep = max(depth_scale, 1e-3) / (feather_perc / 100.0 * 10 + 1)
        
        # blend -> 1 for near (red), -> 0 for far (blue).
        blend = 1.0 / (1.0 + np.exp(-steep * (depth_sm - thr)))
        
        # Intensity sliders scale linearly with 50 as unity gain.
        r = np.clip((red_b / 50.0) * adj * blend * 255, 0, 255).astype(np.uint8)
        b = np.clip((blue_b / 50.0) * adj * (1 - blend) * 255, 0, 255).astype(np.uint8)
        
        # Compose RGB output; green channel stays zero by design.
        h, w = r.shape
        out = np.zeros((h, w, 3), dtype=np.uint8)
        out[..., 0] = r  # Red channel
        out[..., 2] = b  # Blue channel
        
        return Image.fromarray(out, 'RGB')

    def update_effect(self, *args):
        """Gradio slider callback: re-render the effect with the current slider values."""
        return self.apply_effect(*args)

    def clear(self):
        """Drop the cached image/depth and blank both output panes."""
        self.last_original = None
        self.last_depth_norm = None
        return None, None

# Enhanced UI
# Single shared engine; note the depth model is loaded here, at import time.
stereo = EnhancedChromoStereoizer()

with gr.Blocks(title='Enhanced ChromoStereoizer Pro') as demo:
    gr.Markdown('## Enhanced ChromoStereoizer Pro - Maximum Detail Depth Processing')
    gr.Markdown('*Advanced tiled processing with multi-scale fusion and edge-aware blending*')
    
    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.Image(type='pil', label='Upload Image')
            mode = gr.Radio([
                'Standard', 
                'Enhanced Tiled', 
                'Multi-Scale Fusion'
            ], value='Enhanced Tiled', label='Processing Mode')
            
            with gr.Accordion("Advanced Settings", open=False):
                gr.Markdown("**Processing Parameters**")
                # Derive these labels from the engine instead of hardcoding:
                # the old text claimed 384px / 75%, but the constructor
                # defaults are tile_size=518 and overlap_ratio=0.5.
                tile_size_info = gr.Markdown(f"Tile Size: {stereo.tile_size}px (optimized for detail)")
                overlap_info = gr.Markdown(f"Overlap: {int(stereo.overlap_ratio * 100)}% (optimized for seamless blending)")
            
            btn = gr.Button('Generate Depth Map', variant='primary')
            
        with gr.Column(scale=1):
            d_out = gr.Image(type='pil', interactive=False, show_download_button=True, label='Depth Map')
            c_out = gr.Image(type='pil', interactive=False, show_download_button=True, label='Chromostereopsis Effect')
            
            with gr.Accordion("Effect Controls", open=True):
                # Slider order must match apply_effect's positional parameters.
                sliders = [
                    gr.Slider(0, 100, 50, label='Depth Threshold'),
                    gr.Slider(0, 100, 50, label='Depth Scale'),
                    gr.Slider(0, 100, 10, label='Edge Feather'),
                    gr.Slider(0, 100, 50, label='Red Intensity'),
                    gr.Slider(0, 100, 50, label='Blue Intensity'),
                    gr.Slider(0, 100, 50, label='Gamma'),
                    gr.Slider(0, 100, 0, label='Black Level'),
                    gr.Slider(0, 100, 100, label='White Level'),
                    gr.Slider(0, 100, 0, label='Smooth Factor')
                ]
            
            clr = gr.Button('Clear', variant='secondary')

    # Event handlers
    # generate_depth_map takes (img, mode); the lambda reorders the inputs.
    btn.click(
        lambda m, i: stereo.generate_depth_map(i, m),
        [mode, inp],
        [d_out, c_out],
        show_progress=True
    )
    
    # Any slider change re-renders the effect from the cached depth map.
    for slider in sliders:
        slider.change(stereo.update_effect, sliders, c_out)
    
    clr.click(stereo.clear, [], [d_out, c_out])

# Launch the Gradio server only when run as a script (not when imported).
if __name__ == '__main__':
    demo.launch()