SpherePredictor / app.py
masterblaster22's picture
multiple examples added
f05270d verified
import torch
import gradio as gr
import plotly.graph_objects as go
import trimesh
from pathlib import Path
# All inference in this app runs on the CPU.
device = torch.device("cpu")
# map_location remaps the saved tensors onto `device` while loading. Without it,
# TorchScript first tries to restore them to the device they were saved from,
# which fails on a host that lacks that device (e.g. a GPU-saved model on a
# CPU-only machine) before a later .to(device) could ever run.
model = torch.jit.load('model_scripted.pt', map_location=device)
def normalize_vertices(verts):
    """Center the vertices on their mean and rescale every axis independently.

    Args:
        verts: tensor with one vertex per row.

    Returns:
        A tensor of the same shape in which, along each column, the largest
        absolute value is 1. A column that is all zeros after centering is
        left unscaled instead of being divided by zero.
    """
    centered = verts - verts.mean(dim=0)
    # Largest absolute coordinate per column, i.e. the extent along each axis.
    axis_extent = centered.abs().max(dim=0).values
    # Swap zero extents for 1 so flat axes divide cleanly.
    safe_extent = torch.where(axis_extent == 0, torch.ones_like(axis_extent), axis_extent)
    return centered / safe_extent
def plot_3d_results(verts, faces, uv_seam_edge_indices):
    """Build a Plotly figure of the mesh with the predicted edges drawn on top.

    Args:
        verts: tensor of vertex positions; columns 0, 1, 2 are read as x, y, z.
        faces: tensor of faces; columns 0, 1, 2 are read as vertex indices.
        uv_seam_edge_indices: iterable of edges, each holding two indices
            into ``verts``.

    Returns:
        A ``go.Figure`` containing a translucent light-blue mesh trace and a
        red line trace for the given edges.
    """
    vertex_array = verts.cpu().numpy()
    face_array = faces.cpu().numpy()

    surface = go.Mesh3d(
        x=vertex_array[:, 0], y=vertex_array[:, 1], z=vertex_array[:, 2],
        i=face_array[:, 0], j=face_array[:, 1], k=face_array[:, 2],
        color='lightblue', opacity=0.50, name='Mesh',
    )

    # Every edge contributes "start, end, None" per axis; the None entry leaves
    # a gap in the line trace so consecutive edges are not joined together.
    line_x, line_y, line_z = [], [], []
    for edge in uv_seam_edge_indices:
        start_x, start_y, start_z = vertex_array[edge[0]]
        end_x, end_y, end_z = vertex_array[edge[1]]
        line_x += [start_x, end_x, None]
        line_y += [start_y, end_y, None]
        line_z += [start_z, end_z, None]

    seam_lines = go.Scatter3d(
        x=line_x, y=line_y, z=line_z,
        mode='lines', line=dict(color='red', width=2),
        name='Predicted Edges',
    )

    def axis_style(background):
        # The three axes share every setting except the background colour.
        return dict(nticks=4, backgroundcolor=background, gridcolor="white",
                    showbackground=True, zerolinecolor="white")

    scene = dict(
        xaxis=axis_style("rgb(200, 200, 230)"),
        yaxis=axis_style("rgb(230, 200,230)"),
        zaxis=axis_style("rgb(230, 230,200)"),
        camera=dict(up=dict(x=0, y=1, z=0), eye=dict(x=1.25, y=1.25, z=1.25)),
    )

    figure = go.Figure(data=[surface, seam_lines])
    figure.update_layout(scene=scene, title_text='Predicted Edges')
    return figure
def _weld_vertices(vertices, faces):
    """Merge vertices that share identical coordinates and re-index the faces.

    Args:
        vertices: float tensor of vertex positions, one row per vertex.
        faces: list of faces, each a list of integer indices into ``vertices``.

    Returns:
        A ``(verts, new_faces)`` pair. ``verts`` holds one row per distinct
        position, ordered by first appearance while walking ``faces``;
        ``new_faces`` is a list of index lists into ``verts``.
    """
    # Convert the whole tensor once; indexing a tensor row and calling
    # .tolist() for every face corner is needlessly slow.
    coords = [tuple(row) for row in vertices.tolist()]
    index_of = {}    # coordinate tuple -> index in the welded vertex list
    first_seen = []  # original index of the first vertex seen at each position
    new_faces = []
    for face in faces:
        new_face = []
        for orig_index in face:
            key = coords[orig_index]
            if key not in index_of:
                index_of[key] = len(first_seen)
                first_seen.append(orig_index)
            new_face.append(index_of[key])
        new_faces.append(new_face)
    return vertices[first_seen], new_faces


def _unique_edges(faces):
    """Return the undirected edges of triangular ``faces`` without duplicates."""
    edge_set = set()
    for v1, v2, v3 in faces:
        # Sorting each pair makes (a, b) and (b, a) collapse to a single edge.
        edge_set.add(tuple(sorted((v1, v2))))
        edge_set.add(tuple(sorted((v2, v3))))
        edge_set.add(tuple(sorted((v1, v3))))
    return list(edge_set)


def generate_prediction(file_input, treshold_value=0.5):
    """Predict which mesh edges are UV seams and return a plot of the result.

    Args:
        file_input: path of the mesh file to load; a falsy value (nothing
            selected) short-circuits the call.
        treshold_value: probability above which an edge is reported as a seam.

    Returns:
        The figure produced by ``plot_3d_results``, or ``None`` when
        ``file_input`` is empty.

    Raises:
        ValueError: if the loaded mesh has no faces.
    """
    if not file_input:
        return None

    mesh = trimesh.load_mesh(file_input)
    # NOTE(review): getattr guards against a loader result that has no
    # ``faces`` attribute at all -- confirm which types load_mesh can return.
    mesh_faces = getattr(mesh, "faces", None)
    if mesh_faces is None or len(mesh_faces) == 0:
        # Fail with a readable message instead of an opaque tensor error later.
        raise ValueError("The selected mesh contains no faces to predict on.")

    vertices = torch.tensor(mesh.vertices, dtype=torch.float32)
    vertices = normalize_vertices(vertices)

    # Weld duplicated positions so that each location is a single graph node.
    verts, new_faces = _weld_vertices(vertices, mesh_faces.tolist())
    edges = torch.tensor(_unique_edges(new_faces), dtype=torch.long)
    faces = torch.tensor(new_faces, dtype=torch.long)

    model.eval()
    with torch.no_grad():
        # Move the inputs to the model's device so that changing `device`
        # does not cause a device mismatch inside the model.
        logits = model(verts.to(device), edges.to(device))
        probabilities = torch.sigmoid(logits)

    # Flatten rather than squeeze: a single-edge result must stay 1-D so it
    # can be used as a boolean mask over `edges`.
    seam_mask = (probabilities > treshold_value).reshape(-1).cpu()
    uv_seam_edges = edges[seam_mask].tolist()

    return plot_3d_results(verts, faces, uv_seam_edges)
def run_gradio():
    """Build the Gradio Blocks UI for the demo and launch it.

    The UI consists of a label, a file explorer restricted to ``*.obj`` files
    (pre-selecting 'randomSphere_180.obj'), a plot output, a threshold slider
    ranging from 0 to 1 (initial value 0.6) and a "Predict" button that calls
    ``generate_prediction`` with the selected file and the slider value.
    """
    # NOTE(review): the nesting of the layout containers below was
    # reconstructed from a source whose indentation was lost -- confirm that
    # the row/column grouping matches the intended layout.
    with gr.Blocks() as demo:
        gr.Label("Proof of concept demo. Predict UV seams on a 3D sphere meshes.")
        with gr.Row():
            model3d_input = gr.FileExplorer(label="Sphere Prototype Model",
                                            file_count='single',
                                            value='randomSphere_180.obj',
                                            glob='**/*.obj')
            with gr.Column():
                model3d_output = gr.Plot()
                treshold_value = gr.Slider(minimum=0, maximum=1, value=0.6, label="Threshold")
                button = gr.Button("Predict")
        # The click sends the chosen file and threshold to the predictor and
        # renders the returned figure in the plot component.
        button.click(generate_prediction, inputs=[model3d_input, treshold_value], outputs=model3d_output)
    demo.launch()
# Build and launch the demo as soon as this module's top-level code runs.
run_gradio()