import os
import sys
import pathlib
import logging

import yaml
import torch  # required below for the torchvision compatibility shims
import kornia
import nvdiffrast.torch as dr
from easydict import EasyDict

# Apply torchvision compatibility fixes
try:
    import torchvision
    print(f"torchvision {torchvision.__version__} imported successfully")
except (RuntimeError, AttributeError) as e:
    if "operator torchvision::nms does not exist" in str(e) or "extension" in str(e):
        print("Applying torchvision compatibility fixes...")
        # Apply the same fixes as in app.py
        import types
        if not hasattr(torch, 'ops'):
            torch.ops = types.SimpleNamespace()
        if not hasattr(torch.ops, 'torchvision'):
            torch.ops.torchvision = types.SimpleNamespace()
        # Create dummy functions for problematic operators
        torchvision_ops = ['nms', 'roi_align', 'roi_pool', 'ps_roi_align', 'ps_roi_pool']
        for op_name in torchvision_ops:
            if not hasattr(torch.ops.torchvision, op_name):
                if op_name == 'nms':
                    setattr(torch.ops.torchvision, op_name,
                            lambda *args, **kwargs: torch.zeros(0, dtype=torch.int64))
                else:
                    setattr(torch.ops.torchvision, op_name,
                            lambda *args, **kwargs: torch.zeros(0))
        # Try importing again
        try:
            import torchvision
            print("torchvision imported successfully after fixes")
        except Exception as e2:
            print(f"torchvision still has issues, but continuing: {e2}")
    else:
        print(f"Other torchvision error: {e}")
except ImportError:
    print("torchvision not available, continuing without it")

from NeuralJacobianFields import SourceMesh

from nvdiffmodeling.src import render
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from utilities.video import Video
from utilities.helpers import cosine_avg, create_scene, l1_avg
from utilities.camera import CameraBatch, get_camera_params
from utilities.clip_spatial import CLIPVisualEncoder
from utilities.resize_right import resize, cubic, linear, lanczos2, lanczos3
from packages.fashion_clip.fashion_clip.fashion_clip import FashionCLIP
# The star imports are expected to re-export helpers used below
# (e.g. mesh, texture, obj, get_mesh, compute_mv_cl, np, Image)
from utils import *
from get_embeddings import *

from pytorch3d.structures import Meshes
from pytorch3d.loss import (
    chamfer_distance,
    mesh_edge_loss,
    mesh_laplacian_smoothing,
    mesh_normal_consistency,
)
from pytorch3d.ops import sample_points_from_meshes


def total_triangle_area(vertices):
    # Sum the areas of all triangles in the mesh; `vertices` is expected
    # to hold three consecutive rows (x, y, z) per triangle
    num_triangles = vertices.shape[0] // 3
    triangle_vertices = vertices.view(num_triangles, 3, 3)

    # Cross product of two edge vectors for each triangle
    cross_products = torch.cross(triangle_vertices[:, 1] - triangle_vertices[:, 0],
                                 triangle_vertices[:, 2] - triangle_vertices[:, 0],
                                 dim=1)

    # The area of each triangle is half the norm of the cross product
    areas = 0.5 * torch.norm(cross_products, dim=1)

    # Sum the areas of all triangles
    total_area = torch.sum(areas)
    return total_area


def triangle_size_regularization(vertices):
    # Regularize triangle size by penalizing the squared total surface area
    return total_triangle_area(vertices)**2
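
# A minimal sanity check for the two helpers above (illustrative only;
# this function is ours and is never called from the training loop):
# a unit right triangle has area 0.5, so a single-triangle vertex tensor
# should give total area 0.5 and a regularization value of 0.25.
def _triangle_area_sanity_check():
    tri = torch.tensor([[0.0, 0.0, 0.0],
                        [1.0, 0.0, 0.0],
                        [0.0, 1.0, 0.0]])
    assert torch.isclose(total_triangle_area(tri), torch.tensor(0.5))
    assert torch.isclose(triangle_size_regularization(tri), torch.tensor(0.25))
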
def loop(cfg):
    clip_flag = True

    output_path = pathlib.Path(cfg['output_path'])
    os.makedirs(output_path, exist_ok=True)
    with open(output_path / 'config.yml', 'w') as f:
        yaml.dump(cfg, f, default_flow_style=False)

    cfg = EasyDict(cfg)
    print(f'Output directory {cfg.output_path} created')
    os.makedirs(output_path / 'tmp', exist_ok=True)

    device = torch.device(f'cuda:{cfg.gpu}')
    torch.cuda.set_device(device)

    # Read mode flags from config if available, otherwise use defaults
    text_input = cfg.get('text_input', False)
    image_input = cfg.get('image_input', False)
    fashion_image = cfg.get('fashion_image', False)
    fashion_text = cfg.get('fashion_text', True)  # Default to fashion text mode
    use_target_mesh = cfg.get('use_target_mesh', True)
    CLIP_embeddings = False  # Always use FashionCLIP to avoid CLIP issues

    # Always use FashionCLIP to avoid CLIP loading issues
    print('Loading FashionCLIP model...')
    try:
        fclip = FashionCLIP('fashion-clip')
        print('FashionCLIP loaded successfully')
    except Exception as e:
        print(f'Error loading FashionCLIP: {e}')
        raise RuntimeError(f"Failed to load FashionCLIP: {e}")

    # Load CLIPVisualEncoder with error handling
    print('Loading CLIPVisualEncoder...')
    try:
        fe = CLIPVisualEncoder(cfg.consistency_clip_model, cfg.consistency_vit_stride, device)
        print('CLIPVisualEncoder loaded successfully')
    except Exception as e:
        print(f'Error loading CLIPVisualEncoder: {e}')
        print('Continuing without CLIPVisualEncoder...')
        fe = None

    # Use FashionCLIP for all modes to avoid CLIP loading issues
    if fashion_image:
        print('Processing with fashion image embeddings')
        target_direction_embeds, delta_direction_embeds = get_fashion_img_embeddings(fclip, cfg, device, True)
    elif fashion_text:
        print('Processing with fashion text embeddings')
        target_direction_embeds, delta_direction_embeds = get_fashion_text_embeddings(fclip, cfg, device)
    elif text_input or image_input:
        print('WARNING: Regular CLIP embeddings are disabled, using FashionCLIP instead')
        if text_input:
            target_direction_embeds, delta_direction_embeds = get_fashion_text_embeddings(fclip, cfg, device)
        else:
            target_direction_embeds, delta_direction_embeds = get_fashion_img_embeddings(fclip, cfg, device, True)

    # Standard CLIP preprocessing statistics
    clip_mean = torch.tensor([0.48145466, 0.45782750, 0.40821073], device=device)
    clip_std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=device)

    # output video
    video = Video(cfg.output_path)

    # GL Context - with fallback for headless environments
    print('Initializing nvdiffrast GL context...')
    try:
        glctx = dr.RasterizeGLContext()
        print('nvdiffrast GL context initialized successfully')
        use_gl_rendering = True
    except Exception as e:
        print(f'Error initializing nvdiffrast GL context: {e}')
        print('This is likely due to missing EGL headers in headless environment.')
        print('Using fallback rendering approach...')
        glctx = None
        use_gl_rendering = False

    def fallback_render_mesh(mesh, mvp, campos, lightpos, light_power, resolution, **kwargs):
        """
        Fallback rendering function when the GL context is not available.
        Returns a simple colored mesh visualization.
        """
        try:
            # Check if return_rast_map is requested
            return_rast_map = kwargs.get('return_rast_map', False)

            # Create a simple colored mesh visualization without proper lighting
            device = mesh.v_pos.device if hasattr(mesh, 'v_pos') and mesh.v_pos is not None else torch.device('cuda')
            batch_size = 1

            if return_rast_map:
                # Return a dummy rasterization map for consistency
                rast_map = torch.zeros(batch_size, resolution, resolution, 4, device=device)
                rast_map[..., 3] = 1.0  # Set alpha to 1
                return rast_map
            else:
                # Create a simple colored output
                color = torch.ones(batch_size, resolution, resolution, 3, device=device) * 0.5  # Gray color
                # Add some basic shading based on vertex positions
                if hasattr(mesh, 'v_pos') and mesh.v_pos is not None:
                    # Normalize vertex positions for coloring
                    v_pos_norm = (mesh.v_pos - mesh.v_pos.min(dim=0)[0]) / (mesh.v_pos.max(dim=0)[0] - mesh.v_pos.min(dim=0)[0] + 1e-8)
                    # Use vertex positions to create a simple color gradient
                    color = color * 0.3 + v_pos_norm.mean(dim=0).unsqueeze(0).unsqueeze(0).unsqueeze(0) * 0.7
                return color
        except Exception as e:
            print(f"Fallback rendering failed: {e}")
            # Return a simple colored square as a last resort
            device = mesh.v_pos.device if hasattr(mesh, 'v_pos') and mesh.v_pos is not None else torch.device('cuda')
            if kwargs.get('return_rast_map', False):
                return torch.zeros(1, resolution, resolution, 4, device=device)
            else:
                return torch.ones(1, resolution, resolution, 3, device=device) * 0.5
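
    # Assumed output contract for the fallback above, mirroring what the
    # callers below expect from render.render_mesh: color renders are
    # (1, resolution, resolution, 3) in [0, 1], and return_rast_map=True
    # yields (1, resolution, resolution, 4) with alpha fixed at 1 so
    # downstream masking still sees full coverage.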
    def safe_render_mesh(glctx, mesh, mvp, campos, lightpos, light_power, resolution, **kwargs):
        """
        Safe rendering function that uses the GL context if available,
        otherwise falls back.
        """
        if glctx is not None and use_gl_rendering:
            try:
                return render.render_mesh(glctx, mesh, mvp, campos, lightpos, light_power, resolution, **kwargs)
            except Exception as e:
                print(f"GL rendering failed, using fallback: {e}")
                return fallback_render_mesh(mesh, mvp, campos, lightpos, light_power, resolution, **kwargs)
        else:
            return fallback_render_mesh(mesh, mvp, campos, lightpos, light_power, resolution, **kwargs)

    load_mesh = get_mesh(cfg.mesh, output_path, cfg.retriangulate, cfg.bsdf)
    if use_target_mesh:
        target_mesh = get_mesh(cfg.target_mesh, output_path, cfg.retriangulate, cfg.bsdf, 'mesh_target.obj')
        # We construct a Meshes structure for the target mesh
        trg_mesh_p3d = Meshes(verts=[target_mesh.v_pos], faces=[target_mesh.t_pos_idx])

    jacobian_source = SourceMesh.SourceMesh(0, str(output_path / 'tmp' / 'mesh.obj'), {}, 1, ttype=torch.float)
    if len(list((output_path / 'tmp').glob('*.npz'))) > 0:
        logging.warning(f'Using existing Jacobian .npz files in {str(output_path)}/tmp/ ! Please check if this is intentional.')

    # Check that the mesh file exists before loading
    mesh_file_path = output_path / 'tmp' / 'mesh.obj'
    print(f"Looking for mesh file at: {mesh_file_path}")
    print(f"Absolute path: {mesh_file_path.absolute()}")
    if not mesh_file_path.exists():
        # List files in the tmp directory to see what's there
        tmp_dir = output_path / 'tmp'
        if tmp_dir.exists():
            print(f"Files in {tmp_dir}:")
            for file in tmp_dir.iterdir():
                print(f"  - {file.name}")
        else:
            print(f"Tmp directory {tmp_dir} does not exist")
        raise FileNotFoundError(f"Mesh file not found: {mesh_file_path}. This indicates an issue with the mesh loading process.")

    print(f"Mesh file exists at: {mesh_file_path}")
    print("Loading jacobian source...")
    jacobian_source.load()
    jacobian_source.to(device)

    # Validate that the jacobian source loaded properly
    if not hasattr(jacobian_source, 'jacobians_from_vertices') or jacobian_source.jacobians_from_vertices is None:
        raise ValueError("Failed to load jacobian source. The jacobians_from_vertices method is not available.")
    print("Jacobian source loaded successfully")

    with torch.no_grad():
        gt_jacobians = jacobian_source.jacobians_from_vertices(load_mesh.v_pos.unsqueeze(0))

    # Validate that gt_jacobians is not empty
    if gt_jacobians is None or gt_jacobians.shape[0] == 0:
        raise ValueError("Failed to generate jacobians from vertices. This indicates an issue with the mesh or jacobian source.")
    print(f"Generated jacobians with shape: {gt_jacobians.shape}")
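
    # gt_jacobians holds one 3x3 linear map per face (the exact layout
    # comes from NeuralJacobianFields). The loop below optimizes these
    # maps instead of raw vertex positions, and recovers vertices from
    # the Jacobians via the source mesh's solver at every iteration.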
    gt_jacobians.requires_grad_(True)
    optimizer = torch.optim.Adam([gt_jacobians], lr=cfg.lr)

    cams_data = CameraBatch(
        cfg.train_res,
        [cfg.dist_min, cfg.dist_max],
        [cfg.azim_min, cfg.azim_max],
        [cfg.elev_alpha, cfg.elev_beta, cfg.elev_max],
        [cfg.fov_min, cfg.fov_max],
        cfg.aug_loc,
        cfg.aug_light,
        cfg.aug_bkg,
        cfg.batch_size,
        rand_solid=True
    )
    cams = torch.utils.data.DataLoader(cams_data, cfg.batch_size, num_workers=0, pin_memory=True)

    best_losses = {'CLIP': np.inf, 'total': np.inf}

    for out_type in ['final', 'best_clip', 'best_total', 'target_final']:
        os.makedirs(output_path / f'mesh_{out_type}', exist_ok=True)
    os.makedirs(output_path / 'images', exist_ok=True)

    logger = SummaryWriter(str(output_path / 'logs'))
    rot_ang = 0.0
    t_loop = tqdm(range(cfg.epochs), leave=False)

    if cfg.resize_method == 'cubic':
        resize_method = cubic
    elif cfg.resize_method == 'linear':
        resize_method = linear
    elif cfg.resize_method == 'lanczos2':
        resize_method = lanczos2
    elif cfg.resize_method == 'lanczos3':
        resize_method = lanczos3

    for it in t_loop:
        # Updated vertices from jacobians
        n_vert = jacobian_source.vertices_from_jacobians(gt_jacobians).squeeze()

        # Validate that n_vert is not empty
        if n_vert is None or n_vert.shape[0] == 0:
            raise ValueError("Generated vertices are empty. This indicates an issue with the jacobian source or mesh loading.")
        print(f"Iteration {it}: Generated {n_vert.shape[0]} vertices")

        # TODO: More texture code required to make it work ...
        ready_texture = texture.Texture2D(
            kornia.filters.gaussian_blur2d(
                load_mesh.material['kd'].data.permute(0, 3, 1, 2),
                kernel_size=(7, 7),
                sigma=(3, 3),
            ).permute(0, 2, 3, 1).contiguous()
        )

        kd_notex = texture.Texture2D(torch.full_like(ready_texture.data, 0.5))

        ready_specular = texture.Texture2D(
            kornia.filters.gaussian_blur2d(
                load_mesh.material['ks'].data.permute(0, 3, 1, 2),
                kernel_size=(7, 7),
                sigma=(3, 3),
            ).permute(0, 2, 3, 1).contiguous()
        )

        ready_normal = texture.Texture2D(
            kornia.filters.gaussian_blur2d(
                load_mesh.material['normal'].data.permute(0, 3, 1, 2),
                kernel_size=(7, 7),
                sigma=(3, 3),
            ).permute(0, 2, 3, 1).contiguous()
        )

        # Final mesh
        m = mesh.Mesh(
            n_vert,
            load_mesh.t_pos_idx,
            material={
                'bsdf': cfg.bsdf,
                'kd': kd_notex,
                'ks': ready_specular,
                'normal': ready_normal,
            },
            base=load_mesh  # gets uvs etc. from here
        )

        deformed_mesh_p3d = Meshes(verts=[m.v_pos], faces=[m.t_pos_idx])

        render_mesh = create_scene([m.eval()], sz=512)
        if it == 0:
            base_mesh = render_mesh.clone()
            base_mesh = mesh.auto_normals(base_mesh)
            base_mesh = mesh.compute_tangents(base_mesh)
        render_mesh = mesh.auto_normals(render_mesh)
        render_mesh = mesh.compute_tangents(render_mesh)

        if use_target_mesh:
            # Target mesh
            m_target = mesh.Mesh(
                target_mesh.v_pos,
                target_mesh.t_pos_idx,
                material={
                    'bsdf': cfg.bsdf,
                    'kd': kd_notex,
                    'ks': ready_specular,
                    'normal': ready_normal,
                },
                base=target_mesh
            )
            render_target_mesh = create_scene([m_target.eval()], sz=512)
            if it == 0:
                base_target_mesh = render_target_mesh.clone()
                base_target_mesh = mesh.auto_normals(base_target_mesh)
                base_target_mesh = mesh.compute_tangents(base_target_mesh)
            render_target_mesh = mesh.auto_normals(render_target_mesh)
            render_target_mesh = mesh.compute_tangents(render_target_mesh)
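
        # Note: kd_notex swaps the diffuse texture for flat grey, so the
        # optimization signal comes from geometry and the blurred
        # specular/normal maps rather than from albedo detail.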
        # Logging mesh
        if it % cfg.log_interval == 0:
            with torch.no_grad():
                params = get_camera_params(
                    cfg.log_elev,
                    rot_ang,
                    cfg.log_dist,
                    cfg.log_res,
                    cfg.log_fov,
                )
                rot_ang += 5
                log_mesh = mesh.unit_size(render_mesh.eval(params))
                log_image = safe_render_mesh(glctx, log_mesh, params['mvp'], params['campos'], params['lightpos'], cfg.log_light_power, cfg.log_res)
                log_image = video.ready_image(log_image)
                logger.add_mesh('predicted_mesh', vertices=log_mesh.v_pos.unsqueeze(0), faces=log_mesh.t_pos_idx.unsqueeze(0), global_step=it)

        if cfg.adapt_dist and it > 0:
            with torch.no_grad():
                v_pos = m.v_pos.clone()
                vmin = v_pos.amin(dim=0)
                vmax = v_pos.amax(dim=0)
                v_pos -= (vmin + vmax) / 2
                mult = torch.cat([v_pos.amin(dim=0), v_pos.amax(dim=0)]).abs().amax().cpu()
                cams.dataset.dist_min = cfg.dist_min * mult
                cams.dataset.dist_max = cfg.dist_max * mult

        params_camera = next(iter(cams))
        for key in params_camera:
            params_camera[key] = params_camera[key].to(device)

        final_mesh = render_mesh.eval(params_camera)
        train_render = safe_render_mesh(glctx, final_mesh, params_camera['mvp'], params_camera['campos'], params_camera['lightpos'], cfg.light_power, cfg.train_res)
        # Handle permutation for the fallback case
        if train_render.shape[-1] == 3:  # channels-last, needs NCHW for resize
            train_render = train_render.permute(0, 3, 1, 2)
        train_render = resize(train_render, out_shape=(224, 224), interp_method=resize_method)

        if use_target_mesh:
            final_target_mesh = render_target_mesh.eval(params_camera)
            train_target_render = safe_render_mesh(glctx, final_target_mesh, params_camera['mvp'], params_camera['campos'], params_camera['lightpos'], cfg.light_power, cfg.train_res)
            # Handle permutation for the fallback case
            if train_target_render.shape[-1] == 3:
                train_target_render = train_target_render.permute(0, 3, 1, 2)
            train_target_render = resize(train_target_render, out_shape=(224, 224), interp_method=resize_method)

        train_rast_map = safe_render_mesh(
            glctx,
            final_mesh,
            params_camera['mvp'],
            params_camera['campos'],
            params_camera['lightpos'],
            cfg.light_power,
            cfg.train_res,
            return_rast_map=True
        )

        if it == 0:
            params_camera = next(iter(cams))
            for key in params_camera:
                params_camera[key] = params_camera[key].to(device)

        base_render = safe_render_mesh(glctx, base_mesh.eval(params_camera), params_camera['mvp'], params_camera['campos'], params_camera['lightpos'], cfg.light_power, cfg.train_res)
        # Handle permutation for the fallback case
        if base_render.shape[-1] == 3:
            base_render = base_render.permute(0, 3, 1, 2)
        base_render = resize(base_render, out_shape=(224, 224), interp_method=resize_method)

        if it % cfg.log_interval_im == 0:
            log_idx = torch.randperm(cfg.batch_size)[:5]
            s_log = train_render[log_idx, :, :, :]
            s_log = torchvision.utils.make_grid(s_log)
            ndarr = s_log.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
            im = Image.fromarray(ndarr)
            im.save(str(output_path / 'images' / f'epoch_{it}.png'))

            if use_target_mesh:
                s_log_target = train_target_render[log_idx, :, :, :]
                s_log_target = torchvision.utils.make_grid(s_log_target)
                ndarr = s_log_target.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
                im = Image.fromarray(ndarr)
                im.save(str(output_path / 'images' / f'epoch_{it}_target.png'))

            obj.write_obj(
                str(output_path / 'mesh_final'),
                m.eval()
            )

        optimizer.zero_grad()

        normalized_clip_render = (train_render - clip_mean[None, :, None, None]) / clip_std[None, :, None, None]

        # NOTE: the losses below assume a target mesh is available
        # (use_target_mesh=True, the default)
        deformed_features = fclip.encode_image_tensors(train_render)
        target_features = fclip.encode_image_tensors(train_target_render)
        garment_loss = l1_avg(deformed_features, target_features)
        l1_loss = l1_avg(train_render, train_target_render)
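
        # garment_loss compares FashionCLIP image embeddings of the two
        # renders (a semantic, appearance-level signal), while l1_loss
        # compares raw pixels; the mesh-level regularizers come next.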
        # We sample 10k points from the surface of each mesh
        sample_src = sample_points_from_meshes(deformed_mesh_p3d, 10000)
        sample_trg = sample_points_from_meshes(trg_mesh_p3d, 10000)

        # We compare the two sets of point clouds by computing (a) the chamfer loss
        loss_chamfer, _ = chamfer_distance(sample_trg, sample_src)
        loss_chamfer *= 25.

        # and (b) the edge length of the predicted mesh
        loss_edge = mesh_edge_loss(deformed_mesh_p3d)

        # mesh normal consistency
        loss_normal = mesh_normal_consistency(deformed_mesh_p3d)

        # mesh laplacian smoothing
        loss_laplacian = mesh_laplacian_smoothing(deformed_mesh_p3d, method="uniform")

        loss_triangles = triangle_size_regularization(deformed_mesh_p3d.verts_list()[0]) / 100000.

        logger.add_scalar('l1_loss', l1_loss, global_step=it)
        logger.add_scalar('garment_loss', garment_loss, global_step=it)

        # Jacobian regularization: penalize deviation from the identity map
        r_loss = ((gt_jacobians - torch.eye(3, 3, device=device)) ** 2).mean()
        logger.add_scalar('jacobian_regularization', r_loss, global_step=it)

        if cfg.consistency_loss_weight != 0 and fe is not None and train_rast_map is not None:
            consistency_loss = compute_mv_cl(final_mesh, fe, normalized_clip_render, params_camera, train_rast_map, cfg, device)
        else:
            consistency_loss = r_loss
        logger.add_scalar('consistency_loss', consistency_loss, global_step=it)

        logger.add_scalar('chamfer', loss_chamfer, global_step=it)
        logger.add_scalar('edge', loss_edge, global_step=it)
        logger.add_scalar('normal', loss_normal, global_step=it)
        logger.add_scalar('laplacian', loss_laplacian, global_step=it)
        logger.add_scalar('triangles', loss_triangles, global_step=it)

        # After 1000 iterations, switch off the CLIP-based terms and
        # lower the Jacobian regularization weight
        if it > 1000 and clip_flag:
            cfg.clip_weight = 0
            cfg.consistency_loss_weight = 0
            cfg.regularize_jacobians_weight = 0.025
            clip_flag = False

        regularizers = loss_chamfer + loss_edge + loss_normal + loss_laplacian + loss_triangles

        total_loss = (cfg.clip_weight * garment_loss
                      + cfg.delta_clip_weight * l1_loss
                      + cfg.regularize_jacobians_weight * r_loss
                      + cfg.consistency_loss_weight * consistency_loss
                      + regularizers)
        logger.add_scalar('total_loss', total_loss, global_step=it)

        total_loss.backward()
        optimizer.step()

        t_loop.set_description(
            f'L1 = {cfg.delta_clip_weight * l1_loss.item()}, '
            f'CLIP = {cfg.clip_weight * garment_loss.item()}, '
            f'Jacb = {cfg.regularize_jacobians_weight * r_loss.item()}, '
            f'MVC = {cfg.consistency_loss_weight * consistency_loss.item()}, '
            f'Chamf = {loss_chamfer.item()}, '
            f'Edge = {loss_edge.item()}, '
            f'Normal = {loss_normal.item()}, '
            f'Lapl = {loss_laplacian.item()}, '
            f'Triang = {loss_triangles.item()}, '
            f'Total = {total_loss.item()}')

    video.close()

    obj.write_obj(
        str(output_path / 'mesh_final'),
        m.eval()
    )
    return
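

# Illustrative entry point (an assumption; the repository's actual driver
# script may differ): load a YAML config into a plain dict and hand it to
# loop(). The config keys are the cfg fields referenced above
# (output_path, gpu, mesh, target_mesh, epochs, lr, ...).
if __name__ == '__main__':
    with open(sys.argv[1] if len(sys.argv) > 1 else 'config.yml') as f:
        config = yaml.safe_load(f)
    loop(config)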