Update web-demos/hugging_face/app.py
web-demos/hugging_face/app.py CHANGED
```diff
@@ -13,7 +13,6 @@ import torch
 import torchvision
 import numpy as np
 import gradio as gr
-from PIL import Image
 
 from tools.painter import mask_painter
 from track_anything import TrackingAnything
```
```diff
@@ -253,14 +252,10 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
         template_mask[0][0]=1
         operation_log = [("Please add at least one mask to track by clicking the image in step2.","Error"), ("","")]
         # return video_output, video_state, interactive_state, operation_error
-    masks, logits, painted_images
-
+    masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
     # clear GPU memory
     model.cutie.clear_memory()
 
-    # save alpha-channel masks to the state (for display or for saving the video)
-    video_state["alpha_visuals"] = alpha_visuals
-
     if interactive_state["track_end_number"]:
         video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
         video_state["logits"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = logits
```
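For context: `template_mask[0][0]=1` is a fallback for when the user has added no mask at all, since the tracker needs at least one labeled pixel; the code sets a single dummy pixel and logs an error instead of crashing. The restored `model.generator` call is the tracking step itself, returning per-frame masks, logits, and mask-painted frames for every frame after the selected one. A minimal sketch of the interface this code appears to assume (the guard condition, shapes, and dtypes are inferences from the surrounding calls, not confirmed by the diff):

```python
import numpy as np

# Hypothetical inputs mirroring what the diff passes to model.generator.
H, W, T = 480, 854, 24
following_frames = [np.zeros((H, W, 3), dtype=np.uint8) for _ in range(T)]  # RGB frames
template_mask = np.zeros((H, W), dtype=np.uint8)  # per-pixel object labels

if len(np.unique(template_mask)) == 1:  # assumed guard: no object labeled anywhere
    template_mask[0][0] = 1             # same dummy-pixel fallback as the diff

# masks, logits, painted_images = model.generator(
#     images=following_frames, template_mask=template_mask)
# Each return value should have one entry per frame in following_frames.
```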
```diff
@@ -272,10 +267,6 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
 
     video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/track/{}".format(video_state["video_name"]), fps=float(fps)) # import video_input to name the output video
     interactive_state["inference_times"] += 1
-    # additionally: export an alpha-mask video
-    if "alpha_visuals" in video_state:
-        generate_video_from_frames(video_state["alpha_visuals"], output_path="./result/track/{}_alpha.mp4".format(video_state["video_name"].split('.')[0]), fps=float(fps), is_rgba=True)
-
 
     print("Tracking resolution:", following_frames[0].shape)
 
```
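This hunk deletes the optional alpha-channel export outright rather than fixing it. If RGBA output were still wanted, it could plausibly be rebuilt from the masks the function already stores in `video_state["masks"]`. A sketch of such a helper, under the assumption that the masks are `H×W` binary arrays aligned with the painted RGB frames (`masks_to_rgba` is hypothetical, not part of the codebase):

```python
import numpy as np

def masks_to_rgba(frames, masks):
    """Hypothetical stand-in for the removed alpha_visuals path:
    attach each binary mask to its RGB frame as an alpha channel."""
    rgba_frames = []
    for frame, mask in zip(frames, masks):
        alpha = (np.asarray(mask) > 0).astype(np.uint8) * 255         # H x W
        rgba = np.dstack([np.asarray(frame, dtype=np.uint8), alpha])  # H x W x 4
        rgba_frames.append(rgba)
    return rgba_frames
```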
```diff
@@ -334,7 +325,7 @@ def inpaint_video(video_state, resize_ratio_number, dilate_radius_number, raft_i
 
 
 # generate video after vos inference
-def generate_video_from_frames(frames, output_path, fps=30,is_rgba=False):
+def generate_video_from_frames(frames, output_path, fps=30):
     """
     Generates a video from a list of frames.
 
```
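Dropping `is_rgba` here pairs with the final hunk below, which deletes the conversion branch that flag guarded: after this commit, `generate_video_from_frames` no longer converts its input at all.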
```diff
@@ -365,11 +356,6 @@ def generate_video_from_frames(frames, output_path, fps=30,is_rgba=False):
     if not os.path.exists(os.path.dirname(output_path)):
         os.makedirs(os.path.dirname(output_path))
 
-    if is_rgba:
-        frames = torch.from_numpy(np.asarray(frames).astype(np.uint8))
-    else:
-        frames = torch.from_numpy(np.asarray(frames))
-
     # Write the video
     torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
     return output_path
```
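One consequence worth noting: `torchvision.io.write_video` expects the video array as a `(T, H, W, C)` uint8 tensor, and the deleted branch was what turned a Python list of numpy frames into that form. As far as this diff shows, callers must now pass frames that are already in tensor-compatible layout. A minimal sketch of doing the conversion at the call site (the wrapper name and dummy frames are for illustration only):

```python
import numpy as np
import torch
import torchvision

def write_frames(frames, output_path, fps=30):
    # Stack the per-frame H x W x 3 uint8 arrays into one (T, H, W, C) tensor,
    # the layout torchvision.io.write_video expects.
    video = torch.from_numpy(np.asarray(frames).astype(np.uint8))
    torchvision.io.write_video(output_path, video, fps=fps, video_codec="libx264")
    return output_path

# Usage with dummy frames:
frames = [np.zeros((64, 64, 3), dtype=np.uint8) for _ in range(10)]
write_frames(frames, "out.mp4", fps=24)
```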