"""Streamlit app for image-guided (one-shot) object detection with OWL-ViT.

The user uploads an input image and a query image showing the object to look
for. After pressing the start button, the regions of the input image that
match the query are outlined and the annotated image is displayed.
"""
import streamlit as st
import cv2
import numpy as np
from PIL import Image
import torch
import matplotlib.pyplot as plt
from transformers import OwlViTProcessor, OwlViTForObjectDetection
from transformers.image_utils import ImageFeatureExtractionMixin

MODEL_NAME = "google/owlvit-base-patch32"
NMS_THRESHOLD = 0.3
# The annotated array is kept in RGB channel order, so this is blue.
BOX_COLOR_RGB = (0, 0, 255)
BOX_THICKNESS = 5

st.set_option('deprecation.showfileUploaderEncoding', False)

# NOTE(review): Streamlit re-executes this script on every widget interaction,
# so the model is re-instantiated each time. Consider wrapping the loading in
# the resource-caching decorator offered by the installed Streamlit version.
model = OwlViTForObjectDetection.from_pretrained(MODEL_NAME)
processor = OwlViTProcessor.from_pretrained(MODEL_NAME)
device = torch.device("cpu")
model = model.to(device)
model.eval()


def _detect_boxes(image, query_image, threshold):
    """Run image-guided detection and return the post-processed boxes.

    Args:
        image: RGB PIL image to search in.
        query_image: RGB PIL image showing the object to look for.
        threshold: Minimum score for a detection to be kept.

    Returns:
        The ``"boxes"`` entry of the first post-processed result; each box is
        later unpacked as two corner points (x_min, y_min, x_max, y_max).
    """
    # PIL reports size as (width, height); it is reversed to (height, width)
    # before being handed to the post-processing step.
    target_sizes = torch.Tensor([image.size[::-1]])
    inputs = processor(
        images=image, query_images=query_image, return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        outputs = model.image_guided_detection(**inputs)

    # Post-processing is done on CPU tensors.
    outputs.logits = outputs.logits.cpu()
    outputs.target_pred_boxes = outputs.target_pred_boxes.cpu()
    results = processor.post_process_image_guided_detection(
        outputs=outputs,
        threshold=threshold,
        nms_threshold=NMS_THRESHOLD,
        target_sizes=target_sizes,
    )
    return results[0]["boxes"]


def _draw_boxes(image, boxes):
    """Return an RGB array copy of *image* with every box outlined."""
    canvas = np.array(image)  # fresh, writable copy in RGB channel order
    for box in boxes:
        x_min, y_min, x_max, y_max = (int(v) for v in box.tolist())
        # Points are passed as tuples, the form OpenCV accepts everywhere.
        canvas = cv2.rectangle(
            canvas, (x_min, y_min), (x_max, y_max), BOX_COLOR_RGB, BOX_THICKNESS
        )
    return canvas


st.title('Zero-shot Object Detection')

# Input image and query image upload
col1, col2 = st.columns(2)
with col1:
    uploaded_image = st.file_uploader(
        "Upload input image(image to predict)", type=["jpg", "jpeg", "png"]
    )
    if uploaded_image is not None:
        image = Image.open(uploaded_image)
        st.image(image, caption='Input Image', use_column_width=True)
with col2:
    uploaded_query = st.file_uploader(
        "Upload query image(image contains object we wanna predict)",
        type=["jpg", "jpeg", "png"],
    )
    if uploaded_query is not None:
        query_image = Image.open(uploaded_query)
        st.image(query_image, caption='Query Image', use_column_width=True)

# Threshold ratio bar
threshold_ratio = st.slider(
    'Select threshold ratio:', min_value=0.0, max_value=1.0, step=0.1, value=0.6
)
start_button = st.button('Start prediction')

if uploaded_image is None or uploaded_query is None:
    # Only prompt for uploads when one is actually missing, not merely
    # because the start button has not been pressed yet.
    st.write('Please upload an image and a query image.')
elif start_button:
    # PNG uploads may be RGBA, palette or grayscale; normalise both images to
    # three-channel RGB so the processor and the drawing code get one layout.
    rgb_image = image.convert("RGB")
    rgb_query = query_image.convert("RGB")

    boxes = _detect_boxes(rgb_image, rgb_query, threshold_ratio)
    output_image = _draw_boxes(rgb_image, boxes)
    st.image(output_image, caption='Predicted Image', use_column_width=True)