# # imports
# import os
# import json
# import base64
# from io import BytesIO
# from dotenv import load_dotenv
# from openai import OpenAI
# import gradio as gr
# import numpy as np
# from PIL import Image, ImageDraw
# import requests
# import torch
# from transformers import (
#     AutoProcessor,
#     Owlv2ForObjectDetection,
#     AutoModelForZeroShotObjectDetection
# )
# # from transformers import AutoProcessor, Owlv2ForObjectDetection
# from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
# # Initialization
# load_dotenv()
# os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-here')
# PLANTNET_API_KEY = os.getenv('PLANTNET_API_KEY', 'your-plantnet-key-here')
# MODEL = "gpt-4o"
# openai = OpenAI()
# # Initialize models
# device = "cuda" if torch.cuda.is_available() else "cpu"
# # Owlv2
# owlv2_processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16")
# owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
# # DINO
# dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
# dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to(device)
# system_message = """You are an expert in object detection. When users mention:
# 1. "count [object(s)]" - Use detect_objects with proper format based on model
# 2. "detect [object(s)]" - Same as count
# 3. "show [object(s)]" - Same as count
# For DINO model: Format queries as "a [object]." (e.g., "a frog.")
# For Owlv2 model: Format as [["a photo of [object]", "a photo of [object2]"]]
# Always use object detection tool when counting/detecting is mentioned."""
# system_message += "\nAlways be accurate. If you don't know the answer, say so."
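# # Illustrative example (hypothetical request "count frogs and lizards") of the
# # query formats described above, as produced by format_query_for_model below:
# #   Owlv2 queries -> ["a photo of frogs", "a photo of lizards"]
# #   DINO query    -> "a frogs. a lizards."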
# class State:
#     def __init__(self):
#         self.current_image = None
#         self.last_prediction = None
#         self.current_model = "owlv2"  # Default model
# state = State()
# def get_preprocessed_image(pixel_values):
#     pixel_values = pixel_values.squeeze().numpy()
#     unnormalized_image = (pixel_values * np.array(OPENAI_CLIP_STD)[:, None, None]) + np.array(OPENAI_CLIP_MEAN)[:, None, None]
#     unnormalized_image = (unnormalized_image * 255).astype(np.uint8)
#     unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)
#     return unnormalized_image
# def encode_image_to_base64(image_array):
#     if image_array is None:
#         return None
#     image = Image.fromarray(image_array)
#     buffered = BytesIO()
#     image.save(buffered, format="JPEG")
#     return base64.b64encode(buffered.getvalue()).decode('utf-8')
# def format_query_for_model(text_input, model_type="owlv2"):
#     """Format query based on model requirements"""
#     # Extract objects (e.g., "detect a lion" -> "lion")
#     text = text_input.lower()
#     words = [w.strip('.,?!') for w in text.split()
#              if w not in ['count', 'detect', 'show', 'me', 'the', 'and', 'a', 'an']]
#     if model_type == "owlv2":
#         # Return just the list of queries for Owlv2, not nested list
#         queries = ["a photo of " + obj for obj in words]
#         print("Owlv2 queries:", queries)
#         return queries
#     else:  # DINO
#         # DINO query format: one "a [object]." clause per object
#         query = ". ".join("a " + obj for obj in words) + "."
#         print("DINO query:", query)
#         return query
# def detect_objects(query_text):
#     if state.current_image is None:
#         return {"count": 0, "message": "No image provided"}
#     image = Image.fromarray(state.current_image)
#     draw = ImageDraw.Draw(image)
#     if state.current_model == "owlv2":
#         # For Owlv2, pass the text queries directly
#         inputs = owlv2_processor(text=query_text, images=image, return_tensors="pt").to(device)
#         with torch.no_grad():
#             outputs = owlv2_model(**inputs)
#         results = owlv2_processor.post_process_object_detection(
#             outputs=outputs, threshold=0.2, target_sizes=torch.Tensor([image.size[::-1]])
#         )
#     else:  # DINO
#         # For DINO, pass the single text query
#         inputs = dino_processor(images=image, text=query_text, return_tensors="pt").to(device)
#         with torch.no_grad():
#             outputs = dino_model(**inputs)
#         results = dino_processor.post_process_grounded_object_detection(
#             outputs, inputs.input_ids, box_threshold=0.1, text_threshold=0.3,
#             target_sizes=[image.size[::-1]]
#         )
#     # Draw detection boxes
#     boxes = results[0]["boxes"]
#     scores = results[0]["scores"]
#     for box, score in zip(boxes, scores):
#         box = [round(i) for i in box.tolist()]
#         draw.rectangle(box, outline="red", width=3)
#         draw.text((box[0], box[1]), f"Score: {score:.2f}", fill="red")
#     state.last_prediction = np.array(image)
#     return {
#         "count": len(boxes),
#         "confidence": scores.tolist(),
#         "message": f"Detected {len(boxes)} objects"
#     }
# def identify_plant():
#     if state.current_image is None:
#         return {"error": "No image provided"}
#     image = Image.fromarray(state.current_image)
#     img_byte_arr = BytesIO()
#     image.save(img_byte_arr, format='JPEG')
#     img_byte_arr = img_byte_arr.getvalue()
#     api_endpoint = f"https://my-api.plantnet.org/v2/identify/all?api-key={PLANTNET_API_KEY}"
#     files = [('images', ('image.jpg', img_byte_arr))]
#     data = {'organs': ['leaf']}
#     try:
#         response = requests.post(api_endpoint, files=files, data=data)
#         if response.status_code == 200:
#             result = response.json()
#             best_match = result['results'][0]
#             return {
#                 "scientific_name": best_match['species']['scientificName'],
#                 "common_names": best_match['species'].get('commonNames', []),
#                 "family": best_match['species']['family']['scientificName'],
#                 "genus": best_match['species']['genus']['scientificName'],
#                 "confidence": f"{best_match['score']*100:.1f}%"
#             }
#         else:
#             return {"error": f"API Error: {response.status_code}"}
#     except Exception as e:
#         return {"error": f"Error: {str(e)}"}
# # Tool definitions
# object_detection_function = {
#     "name": "detect_objects",
#     "description": "Use this function to detect and count objects in images based on text queries.",
#     "parameters": {
#         "type": "object",
#         "properties": {
#             "query_text": {
#                 "type": "array",
#                 "description": "List of text queries describing objects to detect",
#                 "items": {"type": "string"}
#             }
#         }
#     }
# }
# plant_identification_function = {
#     "name": "identify_plant",
#     "description": "Use this when asked about plant species identification or botanical classification.",
#     "parameters": {
#         "type": "object",
#         "properties": {},
#         "required": []
#     }
# }
# tools = [
#     {"type": "function", "function": object_detection_function},
#     {"type": "function", "function": plant_identification_function}
# ]
# def format_tool_response(tool_response_content):
#     data = json.loads(tool_response_content)
#     if "error" in data:
#         return f"Error: {data['error']}"
#     elif "scientific_name" in data:
#         return f"""📋 Plant Identification Results:
# 🌿 Scientific Name: {data['scientific_name']}
# 👥 Common Names: {', '.join(data['common_names']) if data['common_names'] else 'Not available'}
# 👪 Family: {data['family']}
# 🎯 Confidence: {data['confidence']}"""
#     else:
#         return f"I detected {data['count']} objects in the image."
# def chat(message, image, history):
#     if image is not None:
#         state.current_image = image
#     if state.current_image is None:
#         return "Please upload an image first.", None
#     base64_image = encode_image_to_base64(state.current_image)
#     messages = [{"role": "system", "content": system_message}]
#     for human, assistant in history:
#         messages.append({"role": "user", "content": human})
#         messages.append({"role": "assistant", "content": assistant})
#     # Extract objects to detect from user message
#     # This could be enhanced with better NLP
#     objects_to_detect = message.lower()
#     formatted_query = format_query_for_model(objects_to_detect, state.current_model)
#     messages.append({
#         "role": "user",
#         "content": [
#             {"type": "text", "text": message},
#             {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
#         ]
#     })
#     response = openai.chat.completions.create(
#         model=MODEL,
#         messages=messages,
#         tools=tools,
#         max_tokens=300
#     )
#     if response.choices[0].finish_reason == "tool_calls":
#         message = response.choices[0].message
#         messages.append(message)
#         for tool_call in message.tool_calls:
#             if tool_call.function.name == "detect_objects":
#                 results = detect_objects(formatted_query)
#             else:
#                 results = identify_plant()
#             tool_response = {
#                 "role": "tool",
#                 "content": json.dumps(results),
#                 "tool_call_id": tool_call.id
#             }
#             messages.append(tool_response)
#         response = openai.chat.completions.create(
#             model=MODEL,
#             messages=messages,
#             max_tokens=300
#         )
#     return response.choices[0].message.content, state.last_prediction
# def update_model(choice):
#     print(f"Model switched to: {choice}")
#     state.current_model = choice.lower()
#     return f"Model switched to {choice}"
# # Create Gradio interface
# with gr.Blocks() as demo:
#     gr.Markdown("# Object Detection and Plant Analysis System")
#     with gr.Row():
#         with gr.Column():
#             model_choice = gr.Radio(
#                 choices=["Owlv2", "DINO"],
#                 value="Owlv2",
#                 label="Select Detection Model",
#                 interactive=True
#             )
#             image_input = gr.Image(type="numpy", label="Upload Image")
#             text_input = gr.Textbox(
#                 label="Ask about the image",
#                 placeholder="e.g., 'What objects do you see?' or 'What species is this plant?'"
#             )
#             with gr.Row():
#                 submit_btn = gr.Button("Analyze")
#                 reset_btn = gr.Button("Reset")
#         with gr.Column():
#             chatbot = gr.Chatbot()
#             # output_image = gr.Image(label="Detected Objects")
#             output_image = gr.Image(type="numpy", label="Detected Objects")
#     def process_interaction(message, image, history):
#         response, pred_image = chat(message, image, history)
#         history.append((message, response))
#         return "", pred_image, history
#     def reset_interface():
#         state.current_image = None
#         state.last_prediction = None
#         return None, None, None, []
#     model_choice.change(fn=update_model, inputs=[model_choice], outputs=[gr.Textbox(visible=False)])
#     submit_btn.click(
#         fn=process_interaction,
#         inputs=[text_input, image_input, chatbot],
#         outputs=[text_input, output_image, chatbot]
#     )
#     reset_btn.click(
#         fn=reset_interface,
#         inputs=[],
#         outputs=[image_input, output_image, text_input, chatbot]
#     )
#     gr.Markdown("""## Instructions
# 1. Select the detection model (Owlv2 or DINO)
# 2. Upload an image
# 3. Ask specific questions about objects or plants
# 4. Click Analyze to get results""")
# demo.launch(share=True)
import os
import openai
import gradio as gr
import vision_agent.tools as T

# Set your OpenAI API key (ensure the environment variable is set or replace with your key)
openai.api_key = os.getenv("OPENAI_API_KEY", "your-openai-api-key-here")

def get_single_prompt(user_input):
    """
    Uses OpenAI to rephrase the user's chatter into a single, concise prompt for object detection.
    The generated prompt will not include any question marks.
    """
    if not user_input.strip():
        user_input = "Detect objects in the image"
    prompt_instruction = (
        f"Based on the following user input, generate a single, concise prompt for object detection. "
        f"Do not include any question marks in the output. "
        f"User input: \"{user_input}\""
    )
    response = openai.chat.completions.create(
        model="gpt-4o",  # adjust model name if needed
        messages=[{"role": "user", "content": prompt_instruction}],
        temperature=0.3,
        max_tokens=50,
    )
    generated_prompt = response.choices[0].message.content.strip()
    # Ensure no question marks remain.
    generated_prompt = generated_prompt.replace("?", "")
    return generated_prompt
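
# Illustrative usage only (assumes a valid OPENAI_API_KEY; the exact wording of the
# model's answer will vary from run to run):
#   get_single_prompt("How many bicycles can you see?")
#   -> something like "Detect and count the bicycles in the image"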

def is_count_query(user_input):
    """
    Check if the user's input indicates a counting request.
    Looks for common keywords such as "count", "how many", "number of", etc.
    """
    keywords = ["count", "how many", "number of", "total", "get me a count"]
    for kw in keywords:
        if kw.lower() in user_input.lower():
            return True
    return False
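
# Illustrative examples (hypothetical inputs):
#   is_count_query("How many bicycles can you see?")  -> True
#   is_count_query("Highlight the bicycles")          -> False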

def process_question_and_detect(user_input, image):
    """
    1. Uses OpenAI to generate a single, concise prompt (without question marks) from the user's input.
    2. Feeds that prompt to the VisionAgent detection function.
    3. Overlays the detection bounding boxes on the image.
    4. If the user's input implies a counting request, it also returns the count of detected objects.
    """
    if image is None:
        return None, "Please upload an image."
    # Generate the concise prompt from the user's input.
    generated_prompt = get_single_prompt(user_input)
    # Run object detection using the generated prompt.
    dets = T.agentic_object_detection(generated_prompt, image)
    # Overlay bounding boxes on the image.
    viz = T.overlay_bounding_boxes(image, dets)
    # If the user's input implies a counting request, include the count.
    count_text = ""
    if is_count_query(user_input):
        count = len(dets)
        count_text = f"Detected {count} objects."
    output_text = f"Generated prompt: {generated_prompt}\n{count_text}"
    print(output_text)
    return viz, output_text
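
# Note: `dets` is whatever T.agentic_object_detection returns; it is assumed here to be a
# list with one entry per detection (e.g., dicts carrying label/score/bbox fields), which is
# why len(dets) serves as the object count and the list is passed straight to
# T.overlay_bounding_boxes. Check the installed vision_agent version for the exact schema.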

with gr.Blocks() as demo:
    gr.Markdown("# VisionAgent Object Detection and Counting App")
    gr.Markdown(
        """
Enter your input (for example:
- "What is the number of fruit in my image?"
- "How many bicycles can you see?"
- "Get me a count of my bottles")
and upload an image.
The app uses OpenAI to generate a single, concise prompt for object detection (without question marks),
then runs the detection. If your input implies a counting request, it will also display the count of detected objects.
        """
    )
    with gr.Row():
        user_input = gr.Textbox(label="Enter your input", placeholder="Type your input here...")
        image_input = gr.Image(label="Upload Image", type="numpy")
    submit_btn = gr.Button("Detect and Count")
    output_image = gr.Image(label="Detection Result")
    output_text = gr.Textbox(label="Output Details")
    submit_btn.click(fn=process_question_and_detect, inputs=[user_input, image_input], outputs=[output_image, output_text])

demo.launch(share=True)