# Spaces: Runtime error
# (page-status text captured together with the source; not part of the program)
import shutil
from itertools import cycle
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import streamlit as st
import torch
from PIL import Image
from stqdm import stqdm
from transformers import AutoModelForObjectDetection, AutoFeatureExtractor
# Load the model
# Identifier passed to both ``from_pretrained`` calls below.
best_model_path = "zoheb/yolos-small-balloon"

# Run on CUDA when torch reports it as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

feature_extractor = AutoFeatureExtractor.from_pretrained(
    best_model_path,
    size=512,
    max_size=864,
)
model_pt = AutoModelForObjectDetection.from_pretrained(best_model_path)
model_pt = model_pt.to(device)

# colors for visualization (RGB triples in the 0-1 range, one per drawn box)
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]
# Convert Video to Frames
def video_to_frames(video, dir, every_n=5):
    """Write a subsample of a video's frames to ``dir`` as JPEG files.

    After each successful read the capture's frame position is queried; the
    frame is saved when that position is a multiple of ``every_n``. Saved
    files are named ``frame_0.jpg``, ``frame_1.jpg``, ... in the order they
    are written.

    Args:
        video: Path (or path-like) of the video file to read.
        dir: Existing directory that receives the JPEG files.
        every_n: Sampling stride; must be >= 1. Defaults to 5.

    Raises:
        ValueError: If ``every_n`` is smaller than 1.
        IOError: If the video cannot be opened.
    """
    if every_n < 1:
        raise ValueError(f"every_n must be >= 1, got {every_n}")

    cap = cv2.VideoCapture(str(video))
    try:
        if not cap.isOpened():
            raise IOError(f"Could not open video {video}")

        frame_count = 0
        success, image = cap.read()
        while success:
            # Position reported by the capture after the read above.
            frame_id = int(round(cap.get(cv2.CAP_PROP_POS_FRAMES)))
            if frame_id % every_n == 0:
                cv2.imwrite(str(Path(dir, f"frame_{frame_count}.jpg")), image)
                frame_count += 1
            success, image = cap.read()
    finally:
        # Release the capture even when reading or writing fails.
        cap.release()
# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to corner form.

    Args:
        x: Tensor whose dimension 1 holds the four values
            (center_x, center_y, width, height) of each box.

    Returns:
        Tensor of the same layout holding (x_min, y_min, x_max, y_max).
    """
    center_x, center_y, width, height = x.unbind(1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = (
        center_x - half_w,
        center_y - half_h,
        center_x + half_w,
        center_y + half_h,
    )
    return torch.stack(corners, dim=1)
# rescale bboxes
def rescale_bboxes(out_bbox, size):
    """Convert center-format boxes to corner format and scale them to an image.

    Args:
        out_bbox: Tensor of boxes as (center_x, center_y, width, height);
            the caller in this file passes values relative to the image size.
        size: ``(width, height)`` of the target image.

    Returns:
        Tensor of (x_min, y_min, x_max, y_max) boxes multiplied by the image
        width and height.
    """
    img_w, img_h = size
    # Build the scale on the same device as the boxes so the multiplication
    # also works when the caller has not moved them to the CPU.
    scale = torch.tensor(
        [img_w, img_h, img_w, img_h],
        dtype=torch.float32,
        device=out_bbox.device,
    )
    return box_cxcywh_to_xyxy(out_bbox) * scale
# Save predicted frame
def save_results(pil_img, prob, boxes, mod_img_path):
    """Draw labelled boxes over an image and save the figure to disk.

    Args:
        pil_img: Image to display as the figure background.
        prob: Per-box class scores; the arg-max of each row selects the label
            and its value is printed next to the box.
        boxes: Tensor of (x_min, y_min, x_max, y_max) boxes in pixel units.
        mod_img_path: Destination path handed to ``plt.savefig``.
    """
    id2label = {0: 'balloon'}
    fig = plt.figure(figsize=(18, 10))
    try:
        plt.imshow(pil_img)
        ax = plt.gca()
        # cycle() repeats the palette indefinitely, so zip() is bounded only
        # by the number of boxes and never drops one.
        for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), cycle(COLORS)):
            ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                       fill=False, color=c, linewidth=3))
            cl = p.argmax()
            text = f'{id2label[cl.item()]}: {p[cl]:0.2f}'
            ax.text(xmin, ymin, text, fontsize=15,
                    bbox=dict(facecolor='yellow', alpha=0.5))
        plt.axis('off')
        plt.tight_layout(pad=0)
        plt.savefig(mod_img_path, transparent=True)
    finally:
        # Close the figure even if drawing or saving raised; otherwise one
        # figure per failed frame would stay alive in pyplot.
        plt.close(fig)
# Save predictions
def save_predictions(image, outputs, mod_img_path, threshold=0.9):
    """Filter a model output by confidence and save the annotated image.

    Args:
        image: Image the prediction was made on; its ``size`` is used to
            scale the boxes and it is drawn as the background.
        outputs: Model output exposing ``logits`` and ``pred_boxes``; only
            the first batch item is used.
        mod_img_path: Destination path of the annotated image.
        threshold: A prediction is kept when its best class probability is
            strictly greater than this value.
    """
    # Softmax over the class axis; the last column is dropped
    # (presumably the "no object" class -- confirm against the model).
    class_probs = outputs.logits.softmax(-1)[0, :, :-1]
    best_scores = class_probs.max(-1).values
    keep = best_scores > threshold

    # Scale the kept boxes to the pixel dimensions of the image.
    kept_boxes = outputs.pred_boxes[0, keep].cpu()
    bboxes_scaled = rescale_bboxes(kept_boxes, image.size)

    save_results(image, class_probs[keep], bboxes_scaled, mod_img_path)
# Predict on frames
def predict_on_frames(dir, mod_dir):
    """Run the detector on every extracted frame and save annotated copies.

    Frames are the ``*.jpg`` files of ``dir``, processed in the numeric order
    of the index that follows the 6-character ``frame_`` prefix of the file
    name. Each annotated frame is saved under the same file name in
    ``mod_dir``.

    Args:
        dir: ``pathlib.Path`` of the directory holding the extracted frames.
        mod_dir: Directory that receives the annotated frames.
    """
    files = [f for f in dir.glob('*.jpg') if f.is_file()]
    # Sort numerically so frame_10 comes after frame_9.
    files.sort(key=lambda x: int(x.stem[6:]))

    for frame_path in stqdm(files, desc="Generating... this is a slow task"):
        # glob() already yields paths that include ``dir``; joining them with
        # ``dir`` again would duplicate the prefix for a relative ``dir``.
        with Image.open(str(frame_path)) as img:
            # extract features
            img_ftr = feature_extractor(images=img, return_tensors="pt")
            pixel_values = img_ftr["pixel_values"].to(device)

            # Inference only: skip autograd bookkeeping for the forward pass.
            with torch.no_grad():
                outputs = model_pt(pixel_values=pixel_values)

            mod_img_path = Path(mod_dir, frame_path.name)
            save_predictions(img, outputs, mod_img_path)
# Convert frames to video
def frames_to_video(dir, path, fps=5, fourcc='DIVX'):
    """Assemble the ``*.jpg`` frames of a directory into a video file.

    Frames are written in the numeric order of the index that follows the
    6-character ``frame_`` prefix of the file name. The video size is taken
    from the first frame; all frames are expected to share that size.

    Args:
        dir: ``pathlib.Path`` of the directory holding the frames.
        path: Destination path of the video file.
        fps: Frames per second of the generated video.
        fourcc: Four-character codec code passed to
            ``cv2.VideoWriter_fourcc``.

    Raises:
        ValueError: If ``dir`` contains no frames or a frame cannot be read.
    """
    files = [f for f in dir.glob('*.jpg') if f.is_file()]
    if not files:
        raise ValueError(f"No frames found in {dir}")
    # Sort numerically so frame_10 comes after frame_9.
    files.sort(key=lambda x: int(x.stem[6:]))

    out = None
    try:
        # Write frames as they are read instead of buffering them all.
        for file in files:
            img = cv2.imread(str(file))
            if img is None:
                raise ValueError(f"Could not read frame {file}")
            if out is None:
                height, width = img.shape[:2]
                out = cv2.VideoWriter(str(path), cv2.VideoWriter_fourcc(*fourcc),
                                      fps, (width, height))
            out.write(img)
    finally:
        # Release the writer even when a frame fails part-way through.
        if out is not None:
            out.release()
# Main
if __name__ == '__main__':
    st.title('Detect Balloons using YOLOS')

    # All dir and Files
    BASE_DIR = Path(__file__).parent.absolute()
    FRAMES_DIR = Path(BASE_DIR, "extracted_images")
    MOD_DIR = Path(BASE_DIR, "modified_images")
    generated_video = Path(BASE_DIR, "final_video.mp4")

    # Start every run with empty working directories.
    for work_dir in (FRAMES_DIR, MOD_DIR):
        if work_dir.exists() and work_dir.is_dir():
            shutil.rmtree(work_dir)
        work_dir.mkdir(parents=True, exist_ok=True)

    # Upload the video
    uploaded_file = st.file_uploader("Upload a small video containing Balloons", type=["mp4"])
    if uploaded_file is not None:
        st.video(uploaded_file)

        # The name comes from the client: keep only its last path component
        # so the upload cannot be written outside BASE_DIR.
        vid = Path(uploaded_file.name).name
        st.info(f'Uploaded {vid}')

        # Write the upload to the same path it is read back from, whatever
        # the current working directory is. getvalue() returns the whole
        # buffer independently of the current read position.
        uploaded_video = Path(BASE_DIR, vid)
        uploaded_video.write_bytes(uploaded_file.getvalue())

        # Detect balloon in the frames and generate video
        try:
            video_to_frames(uploaded_video, FRAMES_DIR)
            predict_on_frames(FRAMES_DIR, MOD_DIR)
            frames_to_video(MOD_DIR, generated_video)
            st.success("Successfully Generated!!")

            # Video file Generated
            video_bytes = generated_video.read_bytes()
            st.video(video_bytes)
            st.download_button('Download the Video', video_bytes, file_name=generated_video.name)
        except Exception as e:
            # Top-level boundary of the app: report any failure in the UI.
            st.error(f"Could not convert the file due to {e}")
    else:
        st.info('File Not Uploaded Yet!!!')