import logging
from typing import Tuple, Optional

import numpy as np
from PIL import Image, ImageFilter
import gradio as gr
from transformers import pipeline

try:
    import cv2
    CV2_AVAILABLE = True
except ImportError:
    cv2 = None
    CV2_AVAILABLE = False
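# Note: OpenCV is optional here. When it is not installed, the PIL/numpy fallbacks
# defined below are used instead; they are slower but avoid the extra dependency.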
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class EnhancedChromoStereoizer:
    """
    Depth estimation with multi-scale fusion, detail-preserving normalization,
    and edge-aware tile blending for maximum detail retention.
    """

    def __init__(
        self,
        model_name: str = "depth-anything/Depth-Anything-V2-Small-hf",
        tile_size: int = 518,        # Matches the model's native input resolution
        overlap_ratio: float = 0.5   # Higher overlap gives smoother tile blending
    ):
        self.depth_pipe = pipeline("depth-estimation", model=model_name)
        self.tile_size = tile_size
        self.overlap_ratio = overlap_ratio
        self.last_original: Optional[Image.Image] = None
        self.last_depth_norm: Optional[np.ndarray] = None
    def _gaussian_filter(self, image: np.ndarray, sigma: float = 1.0) -> np.ndarray:
        """Gaussian blur via OpenCV when available, with a PIL fallback."""
        if CV2_AVAILABLE:
            kernel_size = max(3, int(6 * sigma + 1))
            if kernel_size % 2 == 0:
                kernel_size += 1
            return cv2.GaussianBlur(image.astype(np.float32), (kernel_size, kernel_size), sigma)
        else:
            # Fallback using PIL (expects a single-channel image in [0, 1])
            if len(image.shape) == 2:
                pil_img = Image.fromarray((image * 255).astype(np.uint8))
                blurred = pil_img.filter(ImageFilter.GaussianBlur(radius=sigma))
                return np.array(blurred, dtype=np.float32) / 255.0
            else:
                return image  # Return the input unchanged if it cannot be processed
    def _sobel_edge_detection(self, image: np.ndarray) -> np.ndarray:
        """Sobel gradient magnitude via OpenCV when available, with a numpy fallback."""
        if CV2_AVAILABLE:
            # Compute horizontal and vertical gradients separately, then their magnitude
            grad_x = cv2.Sobel(image.astype(np.float32), cv2.CV_32F, 1, 0, ksize=3)
            grad_y = cv2.Sobel(image.astype(np.float32), cv2.CV_32F, 0, 1, ksize=3)
            return np.sqrt(grad_x**2 + grad_y**2)
        else:
            # Simple (and slow) pure-numpy implementation
            sobel_x = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=np.float32)
            sobel_y = np.array([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=np.float32)
            # Pad the image so the 3x3 window stays in bounds
            padded = np.pad(image, 1, mode='edge')
            # Apply the convolution pixel by pixel
            grad_x = np.zeros_like(image)
            grad_y = np.zeros_like(image)
            for i in range(image.shape[0]):
                for j in range(image.shape[1]):
                    region = padded[i:i+3, j:j+3]
                    grad_x[i, j] = np.sum(region * sobel_x)
                    grad_y[i, j] = np.sum(region * sobel_y)
            return np.sqrt(grad_x**2 + grad_y**2)
    def _percentile_normalize(self, depth_map: np.ndarray, p_low: float = 2, p_high: float = 98) -> np.ndarray:
        """Robust normalization using percentiles to handle outliers."""
        low, high = np.percentile(depth_map, [p_low, p_high])
        normalized = np.clip((depth_map - low) / max(high - low, 1e-6), 0, 1)
        return normalized
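    # For example, a raw depth map whose values mostly lie in [10, 20] but contains a
    # few spikes near 1000 is stretched over (roughly) its 2nd-98th percentile range
    # rather than its min-max range, so the spikes cannot flatten the rest of the map.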
    def _extract_high_freq_details(self, tile_depth: np.ndarray, global_depth: np.ndarray, sigma: float = 2.0) -> np.ndarray:
        """Extract high-frequency details from the tile and add them to the global depth."""
        # Create a low-frequency version of the tile
        tile_low = self._gaussian_filter(tile_depth, sigma=sigma)
        # The high-frequency details are what the blur removed
        tile_details = tile_depth - tile_low
        # Add the details on top of the global depth
        enhanced = global_depth + tile_details * 0.5  # Adjust strength as needed
        return enhanced
    def _histogram_match_local(self, tile_depth: np.ndarray, global_region: np.ndarray,
                               preserve_details: bool = True) -> np.ndarray:
        """Histogram matching that preserves local details."""
        if preserve_details:
            # Extract details first
            tile_smooth = self._gaussian_filter(tile_depth, sigma=1.5)
            details = tile_depth - tile_smooth
            # Match smooth version to global
            matched_smooth = self._histogram_match(tile_smooth, global_region)
            # Add back details
            result = matched_smooth + details * 0.7
        else:
            result = self._histogram_match(tile_depth, global_region)
        return np.clip(result, 0, 1)
    def _histogram_match(self, source: np.ndarray, template: np.ndarray) -> np.ndarray:
        """Match the histogram of source to template via their empirical CDFs."""
        source_flat = source.flatten()
        template_flat = template.flatten()
        # Unique values, their positions in the flattened source, and their counts
        source_values, source_indices, source_counts = np.unique(
            source_flat, return_inverse=True, return_counts=True)
        template_values, template_counts = np.unique(template_flat, return_counts=True)
        # Empirical CDFs (quantile reached at each unique value)
        source_quantiles = np.cumsum(source_counts).astype(np.float64) / source_flat.size
        template_quantiles = np.cumsum(template_counts).astype(np.float64) / template_flat.size
        # Map each source quantile to the template value at the same quantile
        interp_values = np.interp(source_quantiles, template_quantiles, template_values)
        matched_flat = interp_values[source_indices]
        return matched_flat.reshape(source.shape)
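    # For instance, matching any source against a constant template (all pixels equal
    # to 0.8) maps every source value to 0.8, since the template's CDF has a single
    # step; in general, source values are remapped so their quantiles line up with
    # the template's quantiles.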
    def _edge_aware_blend(self, tile: np.ndarray, global_region: np.ndarray,
                          weight_map: np.ndarray, edge_map: np.ndarray) -> np.ndarray:
        """Edge-aware blending that preserves sharp transitions."""
        # Modify weights based on edges
        edge_threshold = 0.1
        edge_weights = np.where(edge_map > edge_threshold, 0.8, weight_map)
        # Blend with edge awareness
        blended = tile * edge_weights + global_region * (1 - edge_weights)
        return blended
    def _create_seamless_weights(self, h: int, w: int, blend_width: int = 32) -> np.ndarray:
        """Create seamless blending weights with smooth transitions."""
        weights = np.ones((h, w), dtype=np.float32)
        # Create fade regions at borders
        for i in range(min(blend_width, min(h, w) // 2)):
            alpha = i / blend_width
            # Top and bottom
            if i < h:
                weights[i, :] *= alpha
                weights[-(i+1), :] *= alpha
            # Left and right
            if i < w:
                weights[:, i] *= alpha
                weights[:, -(i+1)] *= alpha
        # Apply smoothing for even better transitions
        weights = self._gaussian_filter(weights, sigma=blend_width / 6)
        return weights
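    # For a full 518x518 tile with blend_width=32, the weights ramp from 0 at each
    # border up to ~1 at 32 px inside (and are then smoothed), so overlapping tiles
    # cross-fade into each other instead of leaving visible seams.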
    def _guided_filter_simple(self, depth: np.ndarray, guide: np.ndarray, radius: int = 8) -> np.ndarray:
        """Simplified guided filter, approximated with bilateral filtering."""
        # Note: the guide image is currently unused; the bilateral filter's range term
        # acts on the depth values themselves.
        if CV2_AVAILABLE:
            # Use bilateral filter as approximation
            depth_uint8 = (depth * 255).astype(np.uint8)
            filtered = cv2.bilateralFilter(depth_uint8, radius, 50, 50)
            return filtered.astype(np.float32) / 255.0
        else:
            # Fallback to Gaussian filter
            return self._gaussian_filter(depth, sigma=radius / 3)
    def generate_depth_map(self, img: Image.Image, mode: str) -> Tuple[Optional[Image.Image], Optional[Image.Image]]:
        """Depth map generation with selectable processing modes."""
        if img is None:
            self.last_original = None
            self.last_depth_norm = None
            return None, None
        self.last_original = img
        W, H = img.size
        # Convert to a grayscale numpy array for edge detection
        img_gray = np.array(img.convert('L'), dtype=np.float32) / 255.0
        # 1. Generate global depth map
        try:
            result_global = self.depth_pipe(img)
            raw_global = np.array(result_global["depth"], dtype=np.float32)
            if CV2_AVAILABLE:
                raw_global = cv2.resize(raw_global, (W, H), interpolation=cv2.INTER_LINEAR)
            else:
                pil_global = Image.fromarray(raw_global)
                pil_global = pil_global.resize((W, H), resample=Image.BILINEAR)
                raw_global = np.array(pil_global, dtype=np.float32)
        except Exception as e:
            logger.error(f"Global depth inference failed: {e}")
            return None, None
        # Normalize global depth
        global_normalized = self._percentile_normalize(raw_global)
        if mode == "Enhanced Tiled":
            final_depth = self._process_enhanced_tiled(img, img_gray, global_normalized, W, H)
        elif mode == "Multi-Scale Fusion":
            final_depth = self._process_multiscale_fusion(img, img_gray, global_normalized, W, H)
        else:
            final_depth = global_normalized
        self.last_depth_norm = final_depth
        depth_img = Image.fromarray((final_depth * 255).astype(np.uint8))
        # Apply the effect with the default slider values
        chromo = self.apply_effect(50, 50, 10, 50, 50, 50, 0, 100, 0)
        return depth_img.convert('RGB'), chromo
    def _process_enhanced_tiled(self, img: Image.Image, img_gray: np.ndarray,
                                global_depth: np.ndarray, W: int, H: int) -> np.ndarray:
        """Tiled processing with detail-preserving, edge-aware blending."""
        # Edge detection for guidance
        edges = self._sobel_edge_detection(img_gray)
        # Initialize accumulators
        accum = np.zeros((H, W), dtype=np.float32)
        weight_total = np.zeros((H, W), dtype=np.float32)
        ts = self.tile_size
        stride = int(ts * (1 - self.overlap_ratio))
        # Generate tile positions with full coverage
        x_positions = list(range(0, W - ts + 1, stride))
        y_positions = list(range(0, H - ts + 1, stride))
        # Ensure the right and bottom edges are covered
        if len(x_positions) == 0 or x_positions[-1] + ts < W:
            x_positions.append(max(0, W - ts))
        if len(y_positions) == 0 or y_positions[-1] + ts < H:
            y_positions.append(max(0, H - ts))
        processed_tiles = 0
        total_tiles = len(x_positions) * len(y_positions)
        for y in y_positions:
            for x in x_positions:
                processed_tiles += 1
                logger.info(f"Processing tile {processed_tiles}/{total_tiles} at ({x},{y})")
                # Extract tile region
                x_end, y_end = min(x + ts, W), min(y + ts, H)
                tile_w, tile_h = x_end - x, y_end - y
                if tile_w <= 0 or tile_h <= 0:
                    continue
                # Crop image tile
                tile_img = img.crop((x, y, x_end, y_end))
                # Pad to the full tile size if necessary
                if tile_w != ts or tile_h != ts:
                    # Pad with the tile's mean color
                    tile_array = np.array(tile_img)
                    mean_color = tuple(map(int, np.mean(tile_array.reshape(-1, tile_array.shape[-1]), axis=0)))
                    padded_tile = Image.new('RGB', (ts, ts), color=mean_color)
                    padded_tile.paste(tile_img, (0, 0))
                    tile_img = padded_tile
                # Process tile
                try:
                    tile_result = self.depth_pipe(tile_img)
                    tile_raw = np.array(tile_result["depth"], dtype=np.float32)
                    # Extract the valid (unpadded) region
                    tile_depth = tile_raw[:tile_h, :tile_w]
                    # Get the corresponding global and edge regions
                    global_region = global_depth[y:y_end, x:x_end]
                    edge_region = edges[y:y_end, x:x_end]
                    # Normalization with detail preservation
                    tile_normalized = self._histogram_match_local(
                        self._percentile_normalize(tile_depth),
                        global_region,
                        preserve_details=True
                    )
                    # Add the tile's high-frequency detail to the global structure
                    tile_enhanced = self._extract_high_freq_details(
                        tile_normalized, global_region, sigma=1.5
                    )
                    # Weight map that fades toward the tile borders
                    weight_map = self._create_seamless_weights(
                        tile_h, tile_w,
                        blend_width=min(32, min(tile_h, tile_w) // 4)
                    )
                    # Edge-aware blending
                    tile_final = self._edge_aware_blend(
                        tile_enhanced, global_region, weight_map, edge_region
                    )
                    # Accumulate
                    accum[y:y_end, x:x_end] += tile_final * weight_map
                    weight_total[y:y_end, x:x_end] += weight_map
                except Exception as e:
                    logger.error(f"Tile processing failed at ({x},{y}): {e}")
                    # Use the global region as a low-weight fallback
                    fallback_weight = np.ones((tile_h, tile_w), dtype=np.float32) * 0.1
                    accum[y:y_end, x:x_end] += global_depth[y:y_end, x:x_end] * fallback_weight
                    weight_total[y:y_end, x:x_end] += fallback_weight
                    continue
        # Final blend; pixels never covered by any tile keep the global depth
        final_depth = np.divide(accum, weight_total, out=global_depth.copy(), where=weight_total > 0)
        # Post-processing with (approximate) guided filtering
        final_depth = self._guided_filter_simple(final_depth, img_gray, radius=4)
        return np.clip(final_depth, 0, 1)
    def _process_multiscale_fusion(self, img: Image.Image, img_gray: np.ndarray,
                                   global_depth: np.ndarray, W: int, H: int) -> np.ndarray:
        """Multi-scale depth fusion for maximum detail."""
        scales = [0.5, 0.75, 1.0, 1.25]  # Processing scales relative to the input size
        fused_depth = global_depth.copy()
        for scale in scales:
            if scale == 1.0:
                continue  # The global pass already covers the native scale
            # Resize image
            new_w, new_h = int(W * scale), int(H * scale)
            if new_w < 64 or new_h < 64:  # Skip very small scales
                continue
            logger.info(f"Processing scale {scale}")
            scaled_img = img.resize((new_w, new_h), Image.BILINEAR)
            try:
                # Process at this scale
                scale_result = self.depth_pipe(scaled_img)
                scale_depth = np.array(scale_result["depth"], dtype=np.float32)
                # Resize back to the original resolution
                if CV2_AVAILABLE:
                    scale_depth = cv2.resize(scale_depth, (W, H), interpolation=cv2.INTER_LINEAR)
                else:
                    scale_pil = Image.fromarray(scale_depth)
                    scale_depth = np.array(scale_pil.resize((W, H), Image.BILINEAR), dtype=np.float32)
                # Normalize and extract details
                scale_normalized = self._percentile_normalize(scale_depth)
                details = scale_normalized - self._gaussian_filter(scale_normalized, sigma=2.0)
                # Add scaled details to the fused result
                detail_strength = 0.3 / len(scales)  # Adjust strength as needed
                fused_depth += details * detail_strength
            except Exception as e:
                logger.error(f"Multi-scale processing failed at {scale}: {e}")
                continue
        return np.clip(fused_depth, 0, 1)
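    # With the four scales above, each off-native scale contributes its high-frequency
    # residue at strength 0.3 / 4 = 0.075, so the fused map stays close to the global
    # depth while picking up finer structure from the rescaled passes.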
    def apply_effect(self, threshold_perc, depth_scale, feather_perc,
                     red_b, blue_b, gamma_perc, black_perc, white_perc, smooth_perc) -> Optional[Image.Image]:
        """Apply the chromostereopsis effect: split brightness into red and blue by depth."""
        if self.last_original is None or self.last_depth_norm is None:
            return None
        gray = np.array(self.last_original.convert('L'), dtype=np.float32)
        # Brightness/contrast adjustment via black and white levels
        black = black_perc * 2.55
        white = white_perc * 2.55
        adj = np.clip((gray - black) / max(white - black, 1e-6), 0, 1)
        # Gamma correction (maps 0-100 to a gamma of 0.1-3.0)
        gamma_v = 0.1 + (gamma_perc / 100.0) * 2.9
        adj = np.clip(adj ** gamma_v, 0, 1)
        # Optional depth smoothing
        depth_sm = self.last_depth_norm
        if smooth_perc > 0:
            sigma = smooth_perc / 100.0 * 3.0
            depth_sm = self._gaussian_filter(depth_sm, sigma=sigma)
        # Sigmoid depth mapping around a single threshold, softened by the feather amount
        thr = threshold_perc / 100.0
        steep = max(depth_scale, 1e-3) / (feather_perc / 100.0 * 10 + 1)
        blend = 1.0 / (1.0 + np.exp(-steep * (depth_sm - thr)))
        # Color mapping: red weighted by the blend, blue by its complement
        r = np.clip((red_b / 50.0) * adj * blend * 255, 0, 255).astype(np.uint8)
        b = np.clip((blue_b / 50.0) * adj * (1 - blend) * 255, 0, 255).astype(np.uint8)
        # Compose the output image
        h, w = r.shape
        out = np.zeros((h, w, 3), dtype=np.uint8)
        out[..., 0] = r  # Red channel
        out[..., 2] = b  # Blue channel
        return Image.fromarray(out, 'RGB')
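    # With the default controls (threshold 50, scale 50, feather 10) the steepness is
    # 50 / (0.1 * 10 + 1) = 25, so the red/blue split flips over a band of roughly
    # +/-0.1 around a normalized depth of 0.5; raising Edge Feather widens that band.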
    def update_effect(self, *args):
        return self.apply_effect(*args)

    def clear(self):
        self.last_original = None
        self.last_depth_norm = None
        return None, None
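
# A minimal sketch of programmatic use without the Gradio UI (assuming a local file
# named 'input.jpg'; the depth model is downloaded when the class is instantiated):
#
#   stereo = EnhancedChromoStereoizer()
#   depth_img, chromo_img = stereo.generate_depth_map(Image.open('input.jpg'), 'Standard')
#   if chromo_img is not None:
#       chromo_img.save('chromo.png')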

# Enhanced UI
stereo = EnhancedChromoStereoizer()

with gr.Blocks(title='Enhanced ChromoStereoizer Pro') as demo:
    gr.Markdown('## Enhanced ChromoStereoizer Pro - Maximum Detail Depth Processing')
    gr.Markdown('*Advanced tiled processing with multi-scale fusion and edge-aware blending*')
    with gr.Row():
        with gr.Column(scale=1):
            inp = gr.Image(type='pil', label='Upload Image')
            mode = gr.Radio([
                'Standard',
                'Enhanced Tiled',
                'Multi-Scale Fusion'
            ], value='Enhanced Tiled', label='Processing Mode')
            with gr.Accordion("Advanced Settings", open=False):
                gr.Markdown("**Processing Parameters**")
                tile_size_info = gr.Markdown("Tile Size: 518px (the model's native resolution)")
                overlap_info = gr.Markdown("Overlap: 50% (for seamless blending)")
            btn = gr.Button('Generate Depth Map', variant='primary')
        with gr.Column(scale=1):
            d_out = gr.Image(type='pil', interactive=False, show_download_button=True, label='Depth Map')
            c_out = gr.Image(type='pil', interactive=False, show_download_button=True, label='Chromostereopsis Effect')
    with gr.Accordion("Effect Controls", open=True):
        sliders = [
            gr.Slider(0, 100, 50, label='Depth Threshold'),
            gr.Slider(0, 100, 50, label='Depth Scale'),
            gr.Slider(0, 100, 10, label='Edge Feather'),
            gr.Slider(0, 100, 50, label='Red Intensity'),
            gr.Slider(0, 100, 50, label='Blue Intensity'),
            gr.Slider(0, 100, 50, label='Gamma'),
            gr.Slider(0, 100, 0, label='Black Level'),
            gr.Slider(0, 100, 100, label='White Level'),
            gr.Slider(0, 100, 0, label='Smooth Factor')
        ]
    clr = gr.Button('Clear', variant='secondary')
    # Event handlers
    btn.click(
        lambda m, i: stereo.generate_depth_map(i, m),
        [mode, inp],
        [d_out, c_out],
        show_progress=True
    )
    for slider in sliders:
        slider.change(stereo.update_effect, sliders, c_out)
    clr.click(stereo.clear, [], [d_out, c_out])

if __name__ == '__main__':
    demo.launch()
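    # When running locally, demo.launch(share=True) would additionally create a
    # temporary public link; on Hugging Face Spaces the plain launch() is sufficient.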