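"""Gradio demo for RF-DETR, Roboflow's real-time transformer-based object detector."""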
from typing import Union

import gradio as gr
import numpy as np
import supervision as sv
from PIL import Image
from rfdetr import RFDETRBase, RFDETRLarge
from rfdetr.detr import RFDETR
from rfdetr.util.coco_classes import COCO_CLASSES

from utils.image import calculate_resolution_wh
from utils.video import create_directory

MARKDOWN = """
# RF-DETR 🔥

<div>
    <a href="https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/how-to-finetune-rf-detr-on-detection-dataset.ipynb">
        <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="colab" style="display:inline-block;">
    </a>
    <a href="https://blog.roboflow.com/rf-detr">
        <img src="https://raw.githubusercontent.com/roboflow-ai/notebooks/main/assets/badges/roboflow-blogpost.svg" alt="roboflow" style="display:inline-block;">
    </a>
    <a href="https://github.com/roboflow/rf-detr">
        <img src="https://badges.aleen42.com/src/github.svg" alt="github" style="display:inline-block;">
    </a>
</div>

RF-DETR is a real-time, transformer-based object detection model architecture developed
by [Roboflow](https://roboflow.com/) and released under the Apache 2.0 license.
"""
IMAGE_EXAMPLES = [
    ['https://media.roboflow.com/supervision/image-examples/people-walking.png', 0.3, 728, "large"],
    ['https://media.roboflow.com/supervision/image-examples/vehicles.png', 0.3, 728, "large"],
    ['https://media.roboflow.com/notebooks/examples/dog-2.jpeg', 0.5, 560, "base"],
]
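
# 12-color palette shared by the box and label annotators; colors are assigned by class ID.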
COLOR = sv.ColorPalette.from_hex([
"#ffff00", "#ff9b00", "#ff8080", "#ff66b2", "#ff66ff", "#b266ff",
"#9999ff", "#3399ff", "#66ffff", "#33ff99", "#66ff66", "#99ff00"
])
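
# Settings for the video pipeline (not wired up yet): frames are meant to be
# downscaled by this factor, and annotated videos written under "tmp".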
VIDEO_SCALE_FACTOR = 0.5
VIDEO_TARGET_DIRECTORY = "tmp"
create_directory(directory_path=VIDEO_TARGET_DIRECTORY)
def detect_and_annotate(model: RFDETR, image: Union[Image.Image, np.ndarray], confidence: float):
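    """Run RF-DETR on a single image and draw boxes with class/confidence labels."""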
    detections = model.predict(image, threshold=confidence)

    resolution_wh = calculate_resolution_wh(image)
    text_scale = sv.calculate_optimal_text_scale(resolution_wh=resolution_wh) - 0.2
    thickness = sv.calculate_optimal_line_thickness(resolution_wh=resolution_wh)

    bbox_annotator = sv.BoxAnnotator(color=COLOR, thickness=thickness)
    label_annotator = sv.LabelAnnotator(
        color=COLOR,
        text_color=sv.Color.BLACK,
        text_scale=text_scale,
        smart_position=True
    )

    labels = [
        f"{COCO_CLASSES[class_id]} {confidence:.2f}"
        for class_id, confidence
        in zip(detections.class_id, detections.confidence)
    ]

    annotated_image = image.copy()
    annotated_image = bbox_annotator.annotate(annotated_image, detections)
    annotated_image = label_annotator.annotate(annotated_image, detections, labels)
    return annotated_image

def image_processing_inference(input_image: Image.Image, confidence: float, resolution: int, checkpoint: str):
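    """Annotate a single image with the chosen checkpoint; RF-DETR expects a resolution divisible by 56."""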
    model_class = RFDETRBase if checkpoint == "base" else RFDETRLarge
    model = model_class(resolution=resolution)
    return detect_and_annotate(model=model, image=input_image, confidence=confidence)

def video_processing_inference(input_video: str, confidence: float, resolution: int, checkpoint: str):
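    """Video inference placeholder: the UI is wired up, but frame processing is not implemented yet."""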
    model_class = RFDETRBase if checkpoint == "base" else RFDETRLarge
    model = model_class(resolution=resolution)
    # Frame-by-frame processing is not implemented yet; the uploaded video is returned unchanged.
    return input_video
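
# A minimal sketch (not part of the original app) of how frame-by-frame video inference
# could be wired into video_processing_inference, using supervision's process_video helper.
# The output filename under VIDEO_TARGET_DIRECTORY is an assumption.
def _video_inference_sketch(model: RFDETR, input_video: str, confidence: float) -> str:
    output_video = f"{VIDEO_TARGET_DIRECTORY}/annotated.mp4"

    def callback(frame: np.ndarray, _index: int) -> np.ndarray:
        # Downscale each frame before inference to trade accuracy for speed; note that
        # frames arrive as OpenCV-style BGR arrays and may need conversion to RGB.
        frame = sv.scale_image(frame, VIDEO_SCALE_FACTOR)
        return detect_and_annotate(model=model, image=frame, confidence=confidence)

    sv.process_video(
        source_path=input_video,
        target_path=output_video,
        callback=callback
    )
    return output_video
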
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Tab("Image"):
        with gr.Row():
            image_processing_input_image = gr.Image(
                label="Upload image",
                image_mode='RGB',
                type='pil',
                height=600
            )
            image_processing_output_image = gr.Image(
                label="Output image",
                image_mode='RGB',
                type='pil',
                height=600
            )
        with gr.Row():
            with gr.Column():
                image_processing_confidence_slider = gr.Slider(
                    label="Confidence",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.5,
                )
                image_processing_resolution_slider = gr.Slider(
                    label="Inference resolution",
                    minimum=560,
                    maximum=1120,
                    step=56,
                    value=728,
                )
                image_processing_checkpoint_dropdown = gr.Dropdown(
                    label="Checkpoint",
                    choices=["base", "large"],
                    value="base"
                )
            with gr.Column():
                image_processing_submit_button = gr.Button("Submit", variant="primary")
        gr.Examples(
            fn=image_processing_inference,
            examples=IMAGE_EXAMPLES,
            inputs=[
                image_processing_input_image,
                image_processing_confidence_slider,
                image_processing_resolution_slider,
                image_processing_checkpoint_dropdown
            ],
            outputs=image_processing_output_image,
            cache_examples=True
        )
        image_processing_submit_button.click(
            image_processing_inference,
            inputs=[
                image_processing_input_image,
                image_processing_confidence_slider,
                image_processing_resolution_slider,
                image_processing_checkpoint_dropdown
            ],
            outputs=image_processing_output_image
        )
with gr.Tab("Video"):
with gr.Row():
video_processing_input_video = gr.Video(
label='Upload video',
height=600
)
video_processing_output_video = gr.Video(
label='Output video',
height=600
)
with gr.Row():
with gr.Column():
video_processing_confidence_slider = gr.Slider(
label="Confidence",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.5,
)
video_processing_resolution_slider = gr.Slider(
label="Inference resolution",
minimum=560,
maximum=1120,
step=56,
value=728,
)
video_processing_checkpoint_dropdown = gr.Dropdown(
label="Checkpoint",
choices=["base", "large"],
value="base"
)
with gr.Column():
                video_processing_submit_button = gr.Button("Submit", variant="primary")
        video_processing_submit_button.click(
            video_processing_inference,
            inputs=[
                video_processing_input_video,
                video_processing_confidence_slider,
                video_processing_resolution_slider,
                video_processing_checkpoint_dropdown
            ],
            outputs=video_processing_output_video
        )

demo.launch(debug=False, show_error=True)