# Spaces: Sleeping (page-status header captured along with the source; not part of the program)
import torch | |
import gradio as gr | |
import plotly.graph_objects as go | |
import trimesh | |
from pathlib import Path | |
# All inference in this file runs on the CPU.
device = torch.device("cpu")
# map_location remaps the checkpoint's tensors onto `device` while loading, so a
# scripted model that was saved from another device (e.g. CUDA) still loads here.
model = torch.jit.load('model_scripted.pt', map_location=device)
def normalize_vertices(verts):
    """Center vertices on their mean and scale every axis into [-1, 1].

    Each axis is scaled independently by its own largest absolute coordinate,
    so after normalization the furthest point along every axis sits at 1 or -1
    (the shape is stretched to fill the box, not uniformly scaled).

    Args:
        verts: Floating-point tensor with one vertex per row; the mean and
            maximum are taken over dim 0.

    Returns:
        A tensor of the same shape holding the normalized coordinates. An
        empty input is returned unchanged.
    """
    if verts.numel() == 0:
        # max(dim=0) raises on a zero-length dimension; nothing to normalize.
        return verts

    centered = verts - verts.mean(dim=0)

    # Largest absolute coordinate along each axis.
    extent = centered.abs().max(dim=0).values
    # An axis on which all vertices coincide has zero extent; divide by one
    # there so the coordinates stay at 0 instead of becoming NaN.
    extent = torch.where(extent == 0, torch.ones_like(extent), extent)
    return centered / extent
def plot_3d_results(verts, faces, uv_seam_edge_indices):
    """Build a Plotly figure showing the mesh with the given edges drawn in red.

    Args:
        verts: Tensor of vertex positions; columns 0, 1 and 2 are used as
            x, y and z.
        faces: Tensor of faces; columns 0, 1 and 2 are the vertex indices of
            each triangle.
        uv_seam_edge_indices: Iterable of edges, each indexable as
            (start_vertex, end_vertex) into ``verts``.

    Returns:
        A ``plotly.graph_objects.Figure`` containing a translucent mesh trace
        and a line trace for the edges.
    """

    def axis_style(background):
        # The three axes share every setting except the background colour.
        return dict(nticks=4, backgroundcolor=background, gridcolor="white",
                    showbackground=True, zerolinecolor="white")

    # Plotly works on NumPy arrays, so move both tensors off torch first.
    verts_np = verts.cpu().numpy()
    faces_np = faces.cpu().numpy()

    surface = go.Mesh3d(
        x=verts_np[:, 0], y=verts_np[:, 1], z=verts_np[:, 2],
        i=faces_np[:, 0], j=faces_np[:, 1], k=faces_np[:, 2],
        color='lightblue', opacity=0.50, name='Mesh',
    )

    # All edges live in one line trace; the None after each pair of endpoints
    # separates one segment from the next.
    line_x, line_y, line_z = [], [], []
    for edge in uv_seam_edge_indices:
        start_x, start_y, start_z = verts_np[edge[0]]
        end_x, end_y, end_z = verts_np[edge[1]]
        line_x += [start_x, end_x, None]
        line_y += [start_y, end_y, None]
        line_z += [start_z, end_z, None]

    seams = go.Scatter3d(
        x=line_x, y=line_y, z=line_z,
        mode='lines',
        line=dict(color='red', width=2),
        name='Predicted Edges',
    )

    fig = go.Figure(data=[surface, seams])
    fig.update_layout(
        scene=dict(
            xaxis=axis_style("rgb(200, 200, 230)"),
            yaxis=axis_style("rgb(230, 200,230)"),
            zaxis=axis_style("rgb(230, 230,200)"),
            camera=dict(up=dict(x=0, y=1, z=0), eye=dict(x=1.25, y=1.25, z=1.25)),
        ),
        title_text='Predicted Edges',
    )
    return fig
def _weld_mesh(vertices, faces):
    """Merge vertices with identical coordinates and collect unique edges.

    Args:
        vertices: Tensor of vertex positions, one vertex per row.
        faces: Iterable of triangles, each holding three row indices into
            ``vertices``.

    Returns:
        A ``(verts, faces, edges)`` tuple of tensors: the distinct vertices in
        order of first use, the faces re-indexed against them (long, N x 3),
        and the undirected edges as sorted index pairs (long, M x 2).
        Vertices that no face references are dropped.
    """
    # Convert the whole tensor once instead of once per face corner.
    rows = vertices.tolist()

    index_of = {}   # coordinate tuple -> index among the welded vertices
    kept_rows = []  # original row of each welded vertex, in first-use order
    new_faces = []
    for face in faces:
        new_face = []
        for orig_index in face:
            key = tuple(rows[orig_index])
            welded_index = index_of.get(key)
            if welded_index is None:
                welded_index = len(kept_rows)
                index_of[key] = welded_index
                kept_rows.append(int(orig_index))
            new_face.append(welded_index)
        new_faces.append(new_face)

    # Sorting each pair makes (a, b) and (b, a) the same set entry, so an edge
    # shared by two triangles is stored once.
    edge_set = set()
    for v1, v2, v3 in new_faces:
        edge_set.add(tuple(sorted((v1, v2))))
        edge_set.add(tuple(sorted((v2, v3))))
        edge_set.add(tuple(sorted((v1, v3))))

    verts = vertices[torch.tensor(kept_rows, dtype=torch.long)]
    welded_faces = torch.tensor(new_faces, dtype=torch.long).reshape(-1, 3)
    edges = torch.tensor(list(edge_set), dtype=torch.long).reshape(-1, 2)
    return verts, welded_faces, edges


def generate_prediction(file_input, treshold_value=0.5):
    """Predict UV-seam edges for a mesh file and return a plot of the result.

    The mesh is loaded with trimesh, its vertices are normalized and welded,
    every unique edge is scored by the module-level ``model``, and edges whose
    sigmoid score exceeds the threshold are highlighted.

    Args:
        file_input: Path of the mesh file to load. A falsy value (nothing
            selected) makes the function return ``None``.
        treshold_value: Score above which an edge counts as a seam.

    Returns:
        The figure produced by ``plot_3d_results``, or ``None`` when no file
        was given.

    Raises:
        gr.Error: If the loaded mesh yields no edges (it has no faces).
    """
    if not file_input:
        return None

    mesh = trimesh.load_mesh(file_input)

    vertices = torch.tensor(mesh.vertices, dtype=torch.float32)
    vertices = normalize_vertices(vertices)

    verts, faces, edges = _weld_mesh(vertices, mesh.faces)
    if edges.numel() == 0:
        # Without this the model would be called on empty tensors and fail
        # with an error that says nothing about the input file.
        raise gr.Error("The selected mesh has no faces, so there are no edges to classify.")

    model.eval()
    with torch.no_grad():
        # Inputs must live on the same device as the model.
        logits = model(verts.to(device), edges.to(device))
        probabilities = torch.sigmoid(logits)

    # reshape(-1) flattens the scores to one entry per edge without collapsing
    # a single-edge result to a 0-dim mask the way squeeze() would.
    seam_mask = (probabilities > treshold_value).reshape(-1).cpu()
    uv_seam_edges = edges[seam_mask].tolist()

    return plot_3d_results(verts, faces, uv_seam_edges)
def run_gradio():
    """Assemble the demo UI and start the Gradio server.

    The page shows a file explorer for picking an ``.obj`` mesh, a plot for the
    result, a threshold slider and a "Predict" button that runs
    ``generate_prediction`` on the selected file and threshold.
    """
    # NOTE(review): the source lost its indentation, so the nesting of the
    # Row/Column containers below is reconstructed - confirm against the
    # intended layout.
    with gr.Blocks() as demo:
        gr.Label("Proof of concept demo. Predict UV seams on a 3D sphere meshes.")

        with gr.Row():
            mesh_picker = gr.FileExplorer(
                label="Sphere Prototype Model",
                file_count='single',
                value='randomSphere_180.obj',
                glob='**/*.obj',
            )
            with gr.Column():
                result_plot = gr.Plot()
                threshold_slider = gr.Slider(minimum=0, maximum=1, value=0.6, label="Threshold")
                predict_button = gr.Button("Predict")

        # Clicking the button feeds the chosen file and threshold to the
        # predictor and shows the returned figure in the plot.
        predict_button.click(
            generate_prediction,
            inputs=[mesh_picker, threshold_slider],
            outputs=result_plot,
        )

    demo.launch()
if __name__ == "__main__":
    # Start the demo only when this file is executed as a script, so that
    # importing the module does not launch a web server as a side effect.
    run_gradio()