"""Enhanced ChromoStereoizer: monocular depth estimation with tiled and
multi-scale fusion, rendered as a red/blue chromostereopsis effect in Gradio.
"""

import logging
from typing import Tuple, Optional

import numpy as np
from PIL import Image, ImageFilter
import gradio as gr
from transformers import pipeline

try:
    import cv2
    CV2_AVAILABLE = True
except ImportError:
    cv2 = None
    CV2_AVAILABLE = False

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EnhancedChromoStereoizer:
    """
    Advanced depth estimation with multi-scale fusion, gradient-preserving
    normalization, and edge-aware blending for maximum detail preservation.
    """

    def __init__(
        self,
        model_name: str = "depth-anything/Depth-Anything-V2-Small-hf",
        tile_size: int = 518,       # smaller tiles -> more local detail per inference
        overlap_ratio: float = 0.5  # fraction of tile overlap; higher -> smoother blending
    ):
        """Load the depth-estimation pipeline and store tiling parameters."""
        self.depth_pipe = pipeline("depth-estimation", model=model_name)
        self.tile_size = tile_size
        self.overlap_ratio = overlap_ratio
        # Cached last input and normalized depth so the effect sliders can
        # re-render without re-running depth inference.
        self.last_original: Optional[Image.Image] = None
        self.last_depth_norm: Optional[np.ndarray] = None

    def _gaussian_filter(self, image: np.ndarray, sigma: float = 1.0) -> np.ndarray:
        """Gaussian blur via OpenCV when available, PIL as a 2-D fallback.

        NOTE(review): the PIL fallback round-trips through uint8 and therefore
        assumes ``image`` lies in [0, 1]; callers pass normalized maps, so
        this holds in practice.
        """
        if CV2_AVAILABLE:
            # Kernel should span roughly +/-3 sigma and must be odd.
            kernel_size = max(3, int(6 * sigma + 1))
            if kernel_size % 2 == 0:
                kernel_size += 1
            return cv2.GaussianBlur(
                image.astype(np.float32), (kernel_size, kernel_size), sigma
            )
        if image.ndim == 2:
            pil_img = Image.fromarray((image * 255).astype(np.uint8))
            blurred = pil_img.filter(ImageFilter.GaussianBlur(radius=sigma))
            return np.array(blurred, dtype=np.float32) / 255.0
        return image  # multi-channel without cv2: return unchanged

    def _sobel_edge_detection(self, image: np.ndarray) -> np.ndarray:
        """Return the Sobel gradient-magnitude edge map of a 2-D image.

        Bug fix: the previous OpenCV path called ``cv2.Sobel(..., 1, 1)``,
        which computes the *mixed* second derivative d2/(dx dy), not an edge
        magnitude, and disagreed with the NumPy fallback. Both paths now
        compute sqrt(gx**2 + gy**2).
        """
        img32 = image.astype(np.float32)
        if CV2_AVAILABLE:
            grad_x = cv2.Sobel(img32, cv2.CV_32F, 1, 0, ksize=3)
            grad_y = cv2.Sobel(img32, cv2.CV_32F, 0, 1, ksize=3)
        else:
            # Vectorized 3x3 Sobel correlation: the kernels
            #   gx = [[-1,0,1],[-2,0,2],[-1,0,1]]
            #   gy = [[-1,-2,-1],[0,0,0],[1,2,1]]
            # expanded into shifted slices of the edge-padded image.
            # Replaces the original O(H*W) Python double loop; same output.
            padded = np.pad(img32, 1, mode='edge')
            grad_x = (
                (padded[:-2, 2:] + 2 * padded[1:-1, 2:] + padded[2:, 2:])
                - (padded[:-2, :-2] + 2 * padded[1:-1, :-2] + padded[2:, :-2])
            )
            grad_y = (
                (padded[2:, :-2] + 2 * padded[2:, 1:-1] + padded[2:, 2:])
                - (padded[:-2, :-2] + 2 * padded[:-2, 1:-1] + padded[:-2, 2:])
            )
        return np.sqrt(grad_x ** 2 + grad_y ** 2)

    def _percentile_normalize(
        self, depth_map: np.ndarray, p_low: float = 2, p_high: float = 98
    ) -> np.ndarray:
        """Robust [0, 1] normalization using percentiles to suppress outliers."""
        low, high = np.percentile(depth_map, [p_low, p_high])
        # max(..., 1e-6) guards against a flat depth map (high == low).
        return np.clip((depth_map - low) / max(high - low, 1e-6), 0, 1)

    def _extract_high_freq_details(
        self, tile_depth: np.ndarray, global_depth: np.ndarray, sigma: float = 2.0
    ) -> np.ndarray:
        """Add the tile's high-frequency detail on top of the global structure."""
        # High frequencies = tile minus its low-pass version.
        tile_low = self._gaussian_filter(tile_depth, sigma=sigma)
        tile_details = tile_depth - tile_low
        # 0.5 tempers detail injection strength; adjust as needed.
        return global_depth + tile_details * 0.5

    def _histogram_match_local(
        self,
        tile_depth: np.ndarray,
        global_region: np.ndarray,
        preserve_details: bool = True,
    ) -> np.ndarray:
        """Histogram-match a tile to its global region, optionally keeping
        local high-frequency detail out of the matching."""
        if preserve_details:
            # Match only the smooth component, then re-add the detail layer
            # (scaled by 0.7) so matching can't flatten fine structure.
            tile_smooth = self._gaussian_filter(tile_depth, sigma=1.5)
            details = tile_depth - tile_smooth
            matched_smooth = self._histogram_match(tile_smooth, global_region)
            result = matched_smooth + details * 0.7
        else:
            result = self._histogram_match(tile_depth, global_region)
        return np.clip(result, 0, 1)

    def _histogram_match(self, source: np.ndarray, template: np.ndarray) -> np.ndarray:
        """Match the value distribution of ``source`` to that of ``template``
        via quantile interpolation over their unique values."""
        source_values, source_indices = np.unique(
            source.flatten(), return_inverse=True
        )
        template_values = np.unique(template.flatten())
        # Map equal-spaced source quantiles onto the template's value range.
        source_quantiles = np.linspace(0, 1, len(source_values))
        template_quantiles = np.linspace(0, 1, len(template_values))
        interp_values = np.interp(
            source_quantiles, template_quantiles, template_values
        )
        return interp_values[source_indices].reshape(source.shape)

    def _edge_aware_blend(
        self,
        tile: np.ndarray,
        global_region: np.ndarray,
        weight_map: np.ndarray,
        edge_map: np.ndarray,
    ) -> np.ndarray:
        """Blend tile into the global region, favouring the tile (weight 0.8)
        wherever the image has strong edges so sharp transitions survive."""
        edge_threshold = 0.1
        edge_weights = np.where(edge_map > edge_threshold, 0.8, weight_map)
        return tile * edge_weights + global_region * (1 - edge_weights)

    def _create_seamless_weights(
        self, h: int, w: int, blend_width: int = 32
    ) -> np.ndarray:
        """Build an (h, w) weight map that fades linearly to 0 at the borders,
        then smooth it for seamless tile accumulation."""
        weights = np.ones((h, w), dtype=np.float32)
        # Loop bound guarantees i < h and i < w, so no extra guards needed.
        for i in range(min(blend_width, min(h, w) // 2)):
            alpha = i / blend_width
            weights[i, :] *= alpha       # top
            weights[-(i + 1), :] *= alpha  # bottom
            weights[:, i] *= alpha       # left
            weights[:, -(i + 1)] *= alpha  # right
        # Gaussian smoothing removes the piecewise-linear kinks.
        return self._gaussian_filter(weights, sigma=blend_width / 6)

    def _guided_filter_simple(
        self, depth: np.ndarray, guide: np.ndarray, radius: int = 8
    ) -> np.ndarray:
        """Edge-preserving smoothing: bilateral filter with cv2, Gaussian
        fallback otherwise. ``guide`` is currently unused by both paths."""
        if CV2_AVAILABLE:
            depth_uint8 = (depth * 255).astype(np.uint8)
            filtered = cv2.bilateralFilter(depth_uint8, radius, 50, 50)
            return filtered.astype(np.float32) / 255.0
        return self._gaussian_filter(depth, sigma=radius / 3)

    def generate_depth_map(
        self, img: Image.Image, mode: str
    ) -> Tuple[Optional[Image.Image], Optional[Image.Image]]:
        """Run depth estimation in the chosen mode and return
        (depth image, default chromostereopsis render), or (None, None)."""
        if img is None:
            self.last_original = None
            self.last_depth_norm = None
            return None, None

        self.last_original = img
        W, H = img.size
        # Grayscale float image in [0, 1] for edge guidance.
        img_gray = np.array(img.convert('L'), dtype=np.float32) / 255.0

        # 1. Global (whole-image) depth map.
        try:
            result_global = self.depth_pipe(img)
            raw_global = np.array(result_global["depth"], dtype=np.float32)
            if CV2_AVAILABLE:
                raw_global = cv2.resize(
                    raw_global, (W, H), interpolation=cv2.INTER_LINEAR
                )
            else:
                pil_global = Image.fromarray(raw_global)
                pil_global = pil_global.resize((W, H), resample=Image.BILINEAR)
                raw_global = np.array(pil_global, dtype=np.float32)
        except Exception as e:
            logger.error("Global depth inference failed: %s", e)
            return None, None

        global_normalized = self._percentile_normalize(raw_global)

        # 2. Optional detail-enhancing passes.
        if mode == "Enhanced Tiled":
            final_depth = self._process_enhanced_tiled(
                img, img_gray, global_normalized, W, H
            )
        elif mode == "Multi-Scale Fusion":
            final_depth = self._process_multiscale_fusion(
                img, img_gray, global_normalized, W, H
            )
        else:
            final_depth = global_normalized

        self.last_depth_norm = final_depth
        depth_img = Image.fromarray((final_depth * 255).astype(np.uint8))

        # Render with the UI sliders' default values.
        chromo = self.apply_effect(50, 50, 10, 50, 50, 50, 0, 100, 0)
        return depth_img.convert('RGB'), chromo

    def _process_enhanced_tiled(
        self,
        img: Image.Image,
        img_gray: np.ndarray,
        global_depth: np.ndarray,
        W: int,
        H: int,
    ) -> np.ndarray:
        """Tiled depth refinement: per-tile inference, histogram matching,
        detail injection and edge-aware weighted accumulation."""
        edges = self._sobel_edge_detection(img_gray)

        accum = np.zeros((H, W), dtype=np.float32)
        weight_total = np.zeros((H, W), dtype=np.float32)

        ts = self.tile_size
        stride = int(ts * (1 - self.overlap_ratio))

        x_positions = list(range(0, W - ts + 1, stride))
        y_positions = list(range(0, H - ts + 1, stride))
        # Always include a tile flush with the right/bottom edge.
        if len(x_positions) == 0 or x_positions[-1] + ts < W:
            x_positions.append(max(0, W - ts))
        if len(y_positions) == 0 or y_positions[-1] + ts < H:
            y_positions.append(max(0, H - ts))

        processed_tiles = 0
        total_tiles = len(x_positions) * len(y_positions)

        for y in y_positions:
            for x in x_positions:
                processed_tiles += 1
                logger.info(
                    "Processing tile %d/%d at (%d,%d)",
                    processed_tiles, total_tiles, x, y,
                )

                x_end, y_end = min(x + ts, W), min(y + ts, H)
                tile_w, tile_h = x_end - x, y_end - y
                if tile_w <= 0 or tile_h <= 0:
                    continue

                tile_img = img.crop((x, y, x_end, y_end))

                # Pad undersized border tiles with the tile's mean colour so
                # the model sees a full-size, roughly in-distribution input.
                if tile_w != ts or tile_h != ts:
                    tile_array = np.array(tile_img)
                    mean_color = tuple(map(int, np.mean(
                        tile_array.reshape(-1, tile_array.shape[-1]), axis=0
                    )))
                    padded_tile = Image.new('RGB', (ts, ts), color=mean_color)
                    padded_tile.paste(tile_img, (0, 0))
                    tile_img = padded_tile

                try:
                    tile_result = self.depth_pipe(tile_img)
                    tile_raw = np.array(tile_result["depth"], dtype=np.float32)
                    tile_depth = tile_raw[:tile_h, :tile_w]  # drop padding

                    global_region = global_depth[y:y_end, x:x_end]
                    edge_region = edges[y:y_end, x:x_end]

                    # Normalize, then pull the tile's distribution onto the
                    # global one while keeping local detail.
                    tile_normalized = self._histogram_match_local(
                        self._percentile_normalize(tile_depth),
                        global_region,
                        preserve_details=True,
                    )
                    tile_enhanced = self._extract_high_freq_details(
                        tile_normalized, global_region, sigma=1.5
                    )

                    weight_map = self._create_seamless_weights(
                        tile_h, tile_w,
                        blend_width=min(32, min(tile_h, tile_w) // 4),
                    )
                    tile_final = self._edge_aware_blend(
                        tile_enhanced, global_region, weight_map, edge_region
                    )

                    accum[y:y_end, x:x_end] += tile_final * weight_map
                    weight_total[y:y_end, x:x_end] += weight_map
                except Exception as e:
                    logger.error("Tile processing failed at (%d,%d): %s", x, y, e)
                    # Fall back to the global depth at low weight so the
                    # region is still covered.
                    fallback_weight = np.ones((tile_h, tile_w), dtype=np.float32) * 0.1
                    accum[y:y_end, x:x_end] += (
                        global_depth[y:y_end, x:x_end] * fallback_weight
                    )
                    weight_total[y:y_end, x:x_end] += fallback_weight
                    continue

        # Weighted average; pixels never covered keep the global depth.
        final_depth = np.divide(
            accum, weight_total, out=global_depth.copy(),
            where=weight_total > 0,
        )
        final_depth = self._guided_filter_simple(final_depth, img_gray, radius=4)
        return np.clip(final_depth, 0, 1)

    def _process_multiscale_fusion(
        self,
        img: Image.Image,
        img_gray: np.ndarray,
        global_depth: np.ndarray,
        W: int,
        H: int,
    ) -> np.ndarray:
        """Fuse high-frequency depth detail gathered at several image scales
        into the global depth map."""
        scales = [0.5, 0.75, 1.0, 1.25]
        fused_depth = global_depth.copy()

        for scale in scales:
            if scale == 1.0:
                continue  # 1.0 is the global pass already in fused_depth
            new_w, new_h = int(W * scale), int(H * scale)
            if new_w < 64 or new_h < 64:
                continue  # too small for the model to be useful

            logger.info("Processing scale %s", scale)
            scaled_img = img.resize((new_w, new_h), Image.BILINEAR)

            try:
                scale_result = self.depth_pipe(scaled_img)
                scale_depth = np.array(scale_result["depth"], dtype=np.float32)

                if CV2_AVAILABLE:
                    scale_depth = cv2.resize(
                        scale_depth, (W, H), interpolation=cv2.INTER_LINEAR
                    )
                else:
                    scale_pil = Image.fromarray(scale_depth)
                    scale_depth = np.array(
                        scale_pil.resize((W, H), Image.BILINEAR),
                        dtype=np.float32,
                    )

                # Only the high-frequency residual of each scale is added.
                scale_normalized = self._percentile_normalize(scale_depth)
                details = scale_normalized - self._gaussian_filter(
                    scale_normalized, sigma=2.0
                )
                detail_strength = 0.3 / len(scales)  # total detail budget split
                fused_depth += details * detail_strength
            except Exception as e:
                logger.error("Multi-scale processing failed at %s: %s", scale, e)
                continue

        return np.clip(fused_depth, 0, 1)

    def apply_effect(
        self,
        threshold_perc,
        depth_scale,
        feather_perc,
        red_b,
        blue_b,
        gamma_perc,
        black_perc,
        white_perc,
        smooth_perc,
    ) -> Optional[Image.Image]:
        """Render the chromostereopsis image: near pixels red, far pixels
        blue, modulated by brightness/contrast/gamma controls (all 0-100)."""
        if self.last_original is None or self.last_depth_norm is None:
            return None

        gray = np.array(self.last_original.convert('L'), dtype=np.float32)

        # Levels: map [black, white] (percent -> 0-255) onto [0, 1].
        black = black_perc * 2.55
        white = white_perc * 2.55
        adj = np.clip((gray - black) / max(white - black, 1e-6), 0, 1)

        # Gamma in [0.1, 3.0].
        gamma_v = 0.1 + (gamma_perc / 100.0) * 2.9
        adj = np.clip(adj ** gamma_v, 0, 1)

        # Optional depth smoothing (does not mutate the cached depth map).
        depth_sm = self.last_depth_norm
        if smooth_perc > 0:
            sigma = smooth_perc / 100.0 * 3.0
            depth_sm = self._gaussian_filter(depth_sm, sigma=sigma)

        # Sigmoid split around the threshold; feather widens the transition.
        thr = threshold_perc / 100.0
        steep = max(depth_scale, 1e-3) / (feather_perc / 100.0 * 10 + 1)
        blend = 1.0 / (1.0 + np.exp(-steep * (depth_sm - thr)))

        # Red for near (blend -> 1), blue for far (blend -> 0).
        r = np.clip((red_b / 50.0) * adj * blend * 255, 0, 255).astype(np.uint8)
        b = np.clip((blue_b / 50.0) * adj * (1 - blend) * 255, 0, 255).astype(np.uint8)

        h, w = r.shape
        out = np.zeros((h, w, 3), dtype=np.uint8)
        out[..., 0] = r  # red channel
        out[..., 2] = b  # blue channel
        return Image.fromarray(out, 'RGB')

    def update_effect(self, *args):
        """Slider callback: re-render the effect from cached depth."""
        return self.apply_effect(*args)

    def clear(self):
        """Drop cached state and blank both output images."""
        self.last_original = None
        self.last_depth_norm = None
        return None, None


# Enhanced UI
stereo = EnhancedChromoStereoizer()

with gr.Blocks(title='Enhanced ChromoStereoizer Pro') as demo:
    gr.Markdown('## Enhanced ChromoStereoizer Pro - Maximum Detail Depth Processing')
    gr.Markdown('*Advanced tiled processing with multi-scale fusion and edge-aware blending*')

    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.Image(type='pil', label='Upload Image')
            mode = gr.Radio([
                'Standard', 'Enhanced Tiled', 'Multi-Scale Fusion'
            ], value='Enhanced Tiled', label='Processing Mode')
            with gr.Accordion("Advanced Settings", open=False):
                gr.Markdown("**Processing Parameters**")
                # Bug fix: these read-only labels previously hard-coded
                # "384px" / "75%", contradicting the actual configuration;
                # derive them from the instance instead.
                tile_size_info = gr.Markdown(
                    f"Tile Size: {stereo.tile_size}px (optimized for detail)"
                )
                overlap_info = gr.Markdown(
                    f"Overlap: {int(stereo.overlap_ratio * 100)}% "
                    "(optimized for seamless blending)"
                )
            btn = gr.Button('Generate Depth Map', variant='primary')
        with gr.Column(scale=1):
            d_out = gr.Image(type='pil', interactive=False,
                             show_download_button=True, label='Depth Map')
            c_out = gr.Image(type='pil', interactive=False,
                             show_download_button=True, label='Chromostereopsis Effect')

    with gr.Accordion("Effect Controls", open=True):
        sliders = [
            gr.Slider(0, 100, 50, label='Depth Threshold'),
            gr.Slider(0, 100, 50, label='Depth Scale'),
            gr.Slider(0, 100, 10, label='Edge Feather'),
            gr.Slider(0, 100, 50, label='Red Intensity'),
            gr.Slider(0, 100, 50, label='Blue Intensity'),
            gr.Slider(0, 100, 50, label='Gamma'),
            gr.Slider(0, 100, 0, label='Black Level'),
            gr.Slider(0, 100, 100, label='White Level'),
            gr.Slider(0, 100, 0, label='Smooth Factor')
        ]
    clr = gr.Button('Clear', variant='secondary')

    # Event handlers
    btn.click(
        lambda m, i: stereo.generate_depth_map(i, m),
        [mode, inp],
        [d_out, c_out],
        show_progress=True
    )
    for slider in sliders:
        slider.change(stereo.update_effect, sliders, c_out)
    clr.click(stereo.clear, [], [d_out, c_out])

if __name__ == '__main__':
    demo.launch()