import spaces
import argparse
import random
import os
os.system('python setup.py develop')
import gradio as gr
import matplotlib
import numpy as np
import torch
from PIL import ImageDraw, Image
from matplotlib import pyplot as plt
from mmcv import Config
import json
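# The block below is kept commented out for reference: it monkey-patched an
# installed mmcv so that its imports resolve against mmengine. It appears to be
# superseded by the `python setup.py develop` step above.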
# def replace_line(file_name, line_num, text):
#     lines = open(file_name, 'r').readlines()
#     lines[line_num] = text
#     out = open(file_name, 'w')
#     out.writelines(lines)
#     out.close()
# def read_lines(file_name):
#     lines = open(file_name, 'r').readlines()
#     print(lines)
# replace_line("/usr/local/lib/python3.10/site-packages/mmcv/parallel/distributed.py", 7, "from mmengine import print_log\n")
# replace_line("/usr/local/lib/python3.10/site-packages/mmcv/parallel/distributed.py", 8, "from mmengine.utils.dl_utils import TORCH_VERSION\nfrom mmengine.utils import digit_version\n")
# replace_line("/usr/local/lib/python3.10/site-packages/mmcv/parallel/registry.py", 3, 'from mmengine.registry import Registry\n')
# replace_line("/usr/local/lib/python3.10/site-packages/mmcv/fileio/io.py", 5, "from mmengine.utils import is_list_of\n")
# replace_line("/usr/local/lib/python3.10/site-packages/mmcv/runner/checkpoint.py", 23, "from mmengine.utils import digit_version, mkdir_or_exist\nfrom mmengine.utils.dl_utils import load_url\n")
# replace_line("/usr/local/lib/python3.10/site-packages/mmcv/runner/hooks/hook.py", 1, "from mmengine.registry import Registry\nfrom mmengine.utils import is_method_overridden\n")
# replace_line("/usr/local/lib/python3.10/site-packages/mmcv/runner/hooks/evaluation.py", 11, "from mmengine.utils import is_seq_of\n")
# replace_line("/usr/local/lib/python3.10/site-packages/mmcv/runner/hooks/logger/mlflow.py", 3, "from mmengine.utils.dl_utils import TORCH_VERSION\n")
# replace_line("/usr/local/lib/python3.10/site-packages/mmcv/runner/hooks/logger/tensorboard.py", 4, "from mmengine.utils.dl_utils import TORCH_VERSION\nfrom mmengine.utils import digit_version\n")
# replace_line("/usr/local/lib/python3.10/site-packages/mmcv/runner/hooks/logger/text.py", 12, "from mmengine.utils import is_tuple_of, scandir\n")
# replace_line("/usr/local/lib/python3.10/site-packages/mmcv/runner/hooks/logger/wandb.py", 5, "from mmengine.utils import scandir\n")
# replace_line("/usr/local/lib/python3.10/site-packages/mmcv/runner/hooks/optimizer.py", 11, "from mmengine.utils.dl_utils import TORCH_VERSION\nfrom mmcv.utils import IS_NPU_AVAILABLE\nfrom mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm\n")
# replace_line("/usr/local/lib/python3.10/site-packages/mmcv/runner/hooks/optimizer.py", 14, "from mmengine.utils import digit_version\n")
# replace_line("/usr/local/lib/python3.10/site-packages/mmcv/runner/fp16_utils.py", 12, "from mmcv.utils import IS_NPU_AVAILABLE\nfrom mmengine.utils.dl_utils import TORCH_VERSION\nfrom mmengine.utils import digit_version\n")
# replace_line("/usr/local/lib/python3.10/site-packages/mmcv/runner/builder.py", 4, "from mmengine.registry import Registry\n")
# replace_line("/usr/local/lib/python3.10/site-packages/mmcv/runner/optimizer/builder.py", 7, "from mmcv.utils import IS_NPU_AVAILABLE\nfrom mmengine.registry import Registry, build_from_cfg\n")
# replace_line("/usr/local/lib/python3.10/site-packages/mmcv/runner/optimizer/default_constructor.py", 8, "from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm, _InstanceNorm\nfrom mmengine.registry import build_from_cfg\nfrom mmengine.utils import is_list_of\n")
# def is_ipu_available() -> bool:
#     try:
#         import poptorch
#         return poptorch.ipuHardwareIsAvailable()
#     except ImportError:
#         return False
# IS_IPU_AVAILABLE = str(is_ipu_available())
# replace_line("/usr/local/lib/python3.10/site-packages/mmcv/device/ipu/__init__.py", 1, f'IS_IPU_AVAILABLE = {IS_IPU_AVAILABLE}\n')
# replace_line("/usr/local/lib/python3.10/site-packages/mmcv/device/scatter_gather.py", 4, "from mmengine.utils import deprecated_api_warning\n")
# replace_line("/usr/local/lib/python3.10/site-packages/mmcv/device/_functions.py", 5, "from mmengine.utils import deprecated_api_warning\n")
# replace_line("/usr/local/lib/python3.10/site-packages/mmpose/__init__.py", 1, "from mmengine.utils import digit_version\nfrom mmcv import parse_version_info\n")
# replace_line("/usr/local/lib/python3.10/site-packages/mmpose/__init__.py", 21, "import mmcv\nmmcv_version = digit_version(mmcv.__version__)\n")
# replace_line("/usr/local/lib/python3.10/site-packages/mmpose/core/optimizers/builder.py", 3, "from mmengine.registry import Registry, build_from_cfg")
from mmcv.runner import load_checkpoint
from mmpose.core import wrap_fp16_model
from mmpose.models import build_posenet
from torchvision import transforms
from demo_text import Resize_Pad
from models import *
import networkx as nx
import ast
import cv2
# matplotlib and pyplot are already imported above.
# matplotlib.use('agg')
def edges_prompt_to_list(prompt):
    # Be forgiving about missing outer brackets around the edge list.
    if prompt[0] != "[":
        prompt = "[" + prompt
    if prompt[-1] != "]":
        prompt += "]"
    return ast.literal_eval(prompt)


def descriptions_prompt_to_list(prompt):
    # Split on commas and strip surrounding whitespace from each description.
    return [desc.strip() for desc in prompt.split(',')]
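# Illustrative round trips (the input format matches the demo's textboxes):
#   edges_prompt_to_list("[[0, 1], [1, 2]]")       -> [[0, 1], [1, 2]]
#   descriptions_prompt_to_list("left eye, nose")  -> ['left eye', 'nose']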
# Function to visualize the graph
def visualize_graph(node_descriptions, edges):
    plt.close('all')
    node_descriptions = descriptions_prompt_to_list(node_descriptions)
    edges = edges_prompt_to_list(edges)
    # Create an empty graph
    G = nx.Graph()
    G.clear()
    # Add nodes with descriptions
    for i, desc in enumerate(node_descriptions):
        G.add_node(i, label=f'{i}:{desc}')
    # Add edges
    for edge in edges:
        G.add_edge(edge[0], edge[1])
    # Draw the graph
    pos = nx.spring_layout(G)  # Define layout
    labels = nx.get_node_attributes(G, 'label')  # Get labels
    nx.draw(G, pos, with_labels=True, labels=labels, node_size=1500,
            node_color='skyblue', font_size=10, font_weight='bold',
            font_color='black')  # Draw nodes with labels
    nx.draw_networkx_edges(G, pos, width=2, edge_color='gray')  # Draw edges
    plt.title("Graph Visualization")  # Set title
    plt.axis('off')  # Turn off axis
    # plt.show()  # Show plot
    # Render the figure to an RGB array instead of showing it
    fig = plt.gcf()
    # fig.tight_layout(pad=0)
    # To remove the huge white borders
    # plt.margins(0)
    fig.canvas.draw()
    image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    image_from_plot = image_from_plot.reshape(
        fig.canvas.get_width_height()[::-1] + (3,))
    plt.clf()
    return image_from_plot
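# Illustrative call: returns an HxWx3 uint8 array that Gradio can display.
#   img = visualize_graph("left eye, nose, right eye", "[[0, 1], [1, 2]]")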
checkpoint_path = ''  # set from --checkpoint in __main__
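# Draws the predicted keypoints and skeleton over the query image.
# `prediction` is a sequence of (K, >=2) tensors of normalized coordinates;
# only the last element is drawn, scaled by the image height (the query image
# is square after Resize_Pad, so one scale factor suffices).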
def plot_query_results(query_img, query_w, skeleton, prediction, radius=6):
    h, w, c = query_img.shape
    prediction = prediction[-1].cpu().numpy() * h
    # prediction = prediction.cpu().numpy() * h
    # Normalize the image to [0, 1] for display
    query_img = (query_img - np.min(query_img)) / (
        np.max(query_img) - np.min(query_img))
    for img, w, keypoint in zip([query_img], [query_w], [prediction]):
        f, axes = plt.subplots()
        plt.imshow(img)
        for k in range(keypoint.shape[0]):
            if w[k] > 0:
                kp = keypoint[k, :2]
                c = (1, 0, 0, 0.75) if w[k] == 1 else (0, 0, 1, 0.6)
                patch = plt.Circle(kp, radius, color=c)
                axes.add_patch(patch)
                axes.text(kp[0], kp[1], k)
                plt.draw()
        for l, limb in enumerate(skeleton):
            kp = keypoint[:, :2]
            # Fall back to random colors once the palette is exhausted
            if l > len(COLORS) - 1:
                c = [x / 255 for x in random.sample(range(0, 255), 3)]
            else:
                c = [x / 255 for x in COLORS[l]]
            # Only draw a limb if both endpoints are visible
            if w[limb[0]] > 0 and w[limb[1]] > 0:
                patch = plt.Line2D([kp[limb[0], 0], kp[limb[1], 0]],
                                   [kp[limb[0], 1], kp[limb[1], 1]],
                                   linewidth=6, color=c, alpha=0.6)
                axes.add_artist(patch)
        plt.axis('off')  # hide the axes
        plt.subplots_adjust(0, 0, 1, 1, 0, 0)
        plt.margins(0)
    fig = plt.gcf()
    fig.tight_layout(pad=0)
    return plt
COLORS = [
    [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0],
    [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255],
    [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], [170, 0, 255],
    [255, 0, 255], [255, 0, 170], [255, 0, 85], [255, 0, 0]
]

model = None
# @spaces.GPU(duration=30)
# def estimate(model, data):
#     with torch.no_grad():
#         model_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#         data["img_q"].to(device=model_device)
#         data['target_weight_s'][0].to(device=model_device)
#         print(f'img type: {data["img_q"].dtype}, target_weight type: {data["target_weight_s"][0].dtype}')
#         model.to(model_device)
#         model.eval()
#         # return model(**data)
#         return model(str(data))
# On ZeroGPU Spaces, CUDA is only usable inside a @spaces.GPU-decorated call,
# so the forward pass is isolated here.
@spaces.GPU(duration=30)
def estimate(data):
    global model
    model.cuda()
    with torch.no_grad():
        return model(data)
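# `data` arrives as a JSON string (see process() below) so the payload stays
# picklable across the ZeroGPU worker boundary; the model's forward is assumed
# to decode it back into tensors.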
# Custom JSON encoder to handle non-serializable objects
class CustomEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
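# Illustrative round trip (hypothetical values):
#   json.dumps({'x': np.zeros(2)}, cls=CustomEncoder)  -> '{"x": [0.0, 0.0]}'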
def process(query_img, node_descriptions, edges,
            cfg_path='configs/1shot-swin-gte/graph_split1_config.py'):
    global model
    node_descriptions = descriptions_prompt_to_list(node_descriptions)
    edges = edges_prompt_to_list(edges)
    cfg = Config.fromfile(cfg_path)
    kp_src_tensor = torch.zeros((len(node_descriptions), 2))
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        Resize_Pad(cfg.model.encoder_config.img_size,
                   cfg.model.encoder_config.img_size)])
    if len(edges) == 0:
        edges = [(0, 0)]
    # model_device = "cuda" if torch.cuda.is_available() else "cpu"
    # PIL gives RGB; [:, :, ::-1] reverses the channels, and flip(0) restores
    # the channel order on the CHW tensor after ToTensor.
    np_query = np.array(query_img)[:, :, ::-1].copy()
    q_img = preprocess(np_query).flip(0)[None]  # .to(model_device)
    # Create heatmap from keypoints
    genHeatMap = TopDownGenerateTargetFewShot()
    data_cfg = cfg.data_cfg
    data_cfg['image_size'] = np.array([cfg.model.encoder_config.img_size,
                                       cfg.model.encoder_config.img_size])
    data_cfg['joint_weights'] = None
    data_cfg['use_different_joint_weights'] = False
    kp_src_3d = torch.cat(
        (kp_src_tensor, torch.zeros(kp_src_tensor.shape[0], 1)), dim=-1)
    kp_src_3d_weight = torch.cat(
        (torch.ones_like(kp_src_tensor),
         torch.zeros(kp_src_tensor.shape[0], 1)), dim=-1)
    target_s, target_weight_s = genHeatMap._msra_generate_target(data_cfg,
                                                                 kp_src_3d,
                                                                 kp_src_3d_weight,
                                                                 sigma=1)
    target_s = torch.tensor(target_s).float()[None]
    target_weight_s = torch.ones_like(
        torch.tensor(target_weight_s).float()[None])  # .to(model_device)
    data = {
        'img_s': [0],
        'img_q': q_img,
        'target_s': [target_s],
        'target_weight_s': [target_weight_s],
        'target_q': None,
        'target_weight_q': None,
        'return_loss': False,
        'img_metas': [{'sample_skeleton': [edges],
                       'query_skeleton': edges,
                       # 'sample_point_descriptions': np.array([node_descriptions]),
                       'sample_point_descriptions': node_descriptions,
                       'sample_joints_3d': [kp_src_3d],
                       'query_joints_3d': kp_src_3d,
                       'sample_center': [kp_src_tensor.mean(dim=0)],
                       'query_center': kp_src_tensor.mean(dim=0),
                       'sample_scale': [
                           kp_src_tensor.max(dim=0)[0] -
                           kp_src_tensor.min(dim=0)[0]],
                       'query_scale': kp_src_tensor.max(dim=0)[0] -
                                      kp_src_tensor.min(dim=0)[0],
                       'sample_rotation': [0],
                       'query_rotation': 0,
                       'sample_bbox_score': [1],
                       'query_bbox_score': 1,
                       'query_image_file': '',
                       'sample_image_file': [''],
                       }]
    }
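    # 'img_s' appears to be only a placeholder: CapeX conditions on the
    # text-graph (descriptions + skeleton) rather than an annotated support image.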
    # Load model (rebuilt on every call, keeping each request self-contained)
    model = build_posenet(cfg.model)
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    load_checkpoint(model, checkpoint_path, map_location='cpu')
    # model.to(model_device)
    # with torch.no_grad():
    #     outputs = model(**data)
    # Convert every tensor in the payload to plain lists so the whole dict can
    # be JSON-serialized and shipped to the ZeroGPU worker.
    data["img_q"] = data["img_q"].cpu().numpy().tolist()
    data['target_weight_s'][0] = data['target_weight_s'][0].cpu().numpy().tolist()
    data['target_s'][0] = data['target_s'][0].cpu().numpy().tolist()
    data['img_metas'][0]['sample_joints_3d'][0] = data['img_metas'][0]['sample_joints_3d'][0].cpu().tolist()
    data['img_metas'][0]['query_joints_3d'] = data['img_metas'][0]['query_joints_3d'].cpu().tolist()
    data['img_metas'][0]['sample_center'][0] = data['img_metas'][0]['sample_center'][0].cpu().tolist()
    data['img_metas'][0]['query_center'] = data['img_metas'][0]['query_center'].cpu().tolist()
    data['img_metas'][0]['sample_scale'][0] = data['img_metas'][0]['sample_scale'][0].cpu().tolist()
    data['img_metas'][0]['query_scale'] = data['img_metas'][0]['query_scale'].cpu().tolist()
    # data['img_metas'][0]['sample_point_descriptions'] = data['img_metas'][0]['sample_point_descriptions'].tolist()
    model.eval()
    str_data = json.dumps(data, cls=CustomEncoder)
    outputs = estimate(str_data)
    # Visualize results
    vis_q_weight = target_weight_s[0]
    vis_q_image = q_img[0].detach().cpu().numpy().transpose(1, 2, 0)
    out = plot_query_results(vis_q_image, vis_q_weight, edges,
                             torch.tensor(outputs['points']).squeeze(0))
    return out
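# Minimal programmatic sketch (assumes checkpoint_path has been set and the
# example image exists locally):
#   from PIL import Image
#   fig = process(Image.open('examples/person.png'),
#                 'nose, left eye, right eye', '[[0, 1], [0, 2]]')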
def update_examples(query_img, node_descriptions, edges):
    return query_img, node_descriptions, edges
with gr.Blocks() as demo:
    state = gr.State({
        'kp_src': [],
        'skeleton': [],
        'count': 0,
        'color_idx': 0,
        'prev_pt': None,
        'prev_pt_idx': None,
        'prev_clicked': None,
        'point_descriptions': None,
    })
    gr.Markdown('''
    # CapeX Demo
    We present a novel category-agnostic pose estimation approach that uses support text-graphs
    (graphs with text on their nodes) instead of the conventional support images.
    By leveraging the abstraction power of text-graphs, CapeX achieves state-of-the-art results on the
    MP-100 benchmark while removing the need for an annotated support image.
    Please note that this demo runs CapeX-S, i.e. our approach with a SwinV2-S image backbone (not SwinV2-T).
    ### [Paper](https://arxiv.org/pdf/2406.00384) | [GitHub](https://github.com/matanr/capex)
    ## Instructions
    1. Describe the desired keypoints in text. Please refer to the examples for the correct format.
    2. Optionally provide a graph describing the connections between the keypoints, in the same format as the examples.
    3. Upload an image of the object you want to pose as the query image.
    4. Click **Evaluate** to pose the query image.
    ''')
    with gr.Row():
        # Input block for node descriptions
        node_descriptions = gr.Textbox(label="Node Descriptions (comma-separated strings)", lines=5, type="text",
                                       # value="left eye, right eye, nose, neck, root of tail, left shoulder, left elbow, "
                                       #       "left front paw, right shoulder, right elbow, right front paw, left hip, "
                                       #       "left knee, left back paw, right hip, right knee, right back paw"
                                       value="left eye, nose, right eye"
                                       )
        # Input block for edges
        edges = gr.Textbox(label="Edges (list of 2-element lists describing connections)", lines=5, type="text",
                           # value="[[0, 1], [0, 2], [1, 2], [2, 3], [3, 4], [3, 5], [5, 6], [6, 7], [3, 8], "
                           #       "[8, 9], [9, 10], [4, 11], [11, 12], [12, 13], [4, 14], [14, 15], [15, 16]]"
                           value="[[0,1], [1,2]]"
                           )

        def set_initial_text_graph():
            text_graph = visualize_graph("left eye, nose, right eye", "[[0,1], [1,2]]")
            return text_graph

        text_graph = gr.Image(label="Text-graph visualization",
                              value=set_initial_text_graph,
                              type="pil", height=400, width=400)
    with gr.Row():
        query_img = gr.Image(label="Query Image",
                             type="pil", height=400, width=400)
    with gr.Row():
        eval_btn = gr.Button(value="Evaluate")
    with gr.Row():
        output_img = gr.Plot(label="Output Image")
    with gr.Row():
        gr.Markdown("## Examples")
    with gr.Row():
        gr.Examples(
            examples=[
                ['examples/animal.png',
                 "left eye, right eye, nose, neck, root of tail, left shoulder, left elbow, "
                 "left front paw, right shoulder, right elbow, right front paw, left hip, "
                 "left knee, left back paw, right hip, right knee, right back paw",
                 "[[0, 1], [0, 2], [1, 2], [2, 3], [3, 4], [3, 5], [5, 6], [6, 7], [3, 8], [8, 9],"
                 "[9, 10], [4, 11], [11, 12], [12, 13], [4, 14], [14, 15], [15, 16]]"
                 ],
                ['examples/person.png',
                 "nose, left eye, right eye, left ear, right ear, left shoulder, right shoulder, left elbow, "
                 "right elbow, left wrist, right wrist, left hip, right hip, left knee, right knee, left ankle, "
                 "right ankle",
                 "[[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7],"
                 "[6, 8], [7, 9], [8, 10], [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6]]"
                 ],
                ['examples/chair.png',
                 "left and front leg, right and front leg, right and back leg, left and back leg, "
                 "left and front side of the seat, right and front side of the seat, right and back side of the seat, "
                 "left and back side of the seat, top left side of the backseat, top right side of the backseat",
                 "[[0, 4], [3, 7], [1, 5], [2, 6], [4, 5], [5, 6], [6, 7], [7, 4], [6, 7], [7, 8], [8, 9], [9, 6]]",
                 ],
                ['examples/car.png',
                 "front and right wheel, front and left wheel, rear and right wheel, rear and left wheel, "
                 "right headlight, left headlight, right taillight, left taillight, "
                 "front and right side of the top, front and left side of the top, rear and right side of the top, "
                 "rear and left side of the top",
                 "[[0, 2], [1, 3], [0, 1], [2, 3], [8, 10], [9, 11], [8, 9], [10, 11], [4, 0], "
                 "[4, 8], [4, 5], [5, 1], [5, 9], [6, 2], [6, 10], [7, 3], [7, 11], [6, 7]]"
                 ]
            ],
            inputs=[query_img, node_descriptions, edges],
            outputs=[query_img, node_descriptions, edges],
            fn=update_examples,
            run_on_click=True,
        )
    eval_btn.click(fn=process,
                   inputs=[query_img, node_descriptions, edges],
                   outputs=[output_img])
    # .change also fires when an example populates the textbox, so the graph
    # preview refreshes on example clicks; .input only fires on user edits.
    node_descriptions.change(visualize_graph, inputs=[node_descriptions, edges], outputs=[text_graph])
    edges.input(visualize_graph, inputs=[node_descriptions, edges], outputs=[text_graph])
    # visualize_button.click(fn=visualize_graph,
    #                        inputs=[node_descriptions, edges, state],
    #                        outputs=[text_graph, state])
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='CapeX Demo')
    parser.add_argument('--checkpoint',
                        help='checkpoint path',
                        default='swin-gte-split1.pth')
    args = parser.parse_args()
    checkpoint_path = args.checkpoint
    demo.launch()
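# Local usage sketch (the path is the argparse default above):
#   python app.py --checkpoint swin-gte-split1.pth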