Spaces:
Running
Running
from flask import Flask, render_template, request, jsonify | |
from PIL import Image | |
from io import BytesIO | |
import torch | |
from torchvision import models, transforms | |
from torchvision.models.detection import fasterrcnn_resnet50_fpn | |
import os | |
# Flask application object; it is the instance started by the __main__ guard
# at the bottom of this file.
app = Flask(import_name=__name__)
# Load ImageNet class index | |
def load_imagenet_class_index():
    """Read the class-label list from ``imagenet_classes.txt``.

    The path is relative, so the file is resolved against the current working
    directory. Every line of the file becomes one entry, stripped of
    surrounding whitespace, in file order (blank lines are kept as empty
    strings so positions are not shifted).

    Returns:
        list[str]: One label per line of the file.

    Raises:
        FileNotFoundError: If ``imagenet_classes.txt`` does not exist.
    """
    labels_file = 'imagenet_classes.txt'
    if os.path.exists(labels_file):
        with open(labels_file) as handle:
            return [raw_line.strip() for raw_line in handle]
    raise FileNotFoundError(f"ImageNet class index file not found at {labels_file}")
# Classifier label list; real_image_analysis indexes it by predicted class id.
imagenet_classes = load_imagenet_class_index()

# Pre-trained networks, put into inference mode right after construction.
# nn.Module.eval() returns the module itself, so the calls can be chained.
resnet = models.resnet50(pretrained=True).eval()
fasterrcnn = fasterrcnn_resnet50_fpn(pretrained=True).eval()

# Per-channel statistics handed to Normalize below.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to an image before it is fed to ``resnet``:
# resize, center-crop to 224, convert to a tensor, then normalize.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
])
# COCO dataset class names | |
# Names for the detector's integer labels: real_image_analysis looks a label up
# by using it directly as an index into this list, so positions must not move.
# 'N/A' entries are placeholders that keep the positions aligned (presumably
# for label ids the detector never emits; confirm against the model docs).
COCO_INSTANCE_CATEGORY_NAMES = [
    # 0-9
    '__background__', 'person', 'bicycle', 'car', 'motorcycle',
    'airplane', 'bus', 'train', 'truck', 'boat',
    # 10-19
    'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter',
    'bench', 'bird', 'cat', 'dog', 'horse',
    # 20-29
    'sheep', 'cow', 'elephant', 'bear', 'zebra',
    'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A',
    # 30-39
    'N/A', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
    # 40-49
    'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
    'N/A', 'wine glass', 'cup', 'fork', 'knife',
    # 50-59
    'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    # 60-69
    'donut', 'cake', 'chair', 'couch', 'potted plant',
    'bed', 'N/A', 'dining table', 'N/A', 'N/A',
    # 70-79
    'toilet', 'N/A', 'tv', 'laptop', 'mouse',
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
    # 80-89
    'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    # 90
    'toothbrush',
]
# Function for real image analysis | |
def real_image_analysis(image):
    """Describe an image by likely objects, dominant colors and a scene guess.

    The image is run through the ``resnet`` classifier (top three labels) and
    the ``fasterrcnn`` detector (one name per detection). The combined names
    are de-duplicated while keeping their order, so the five names returned
    are the classifier's picks followed by the detector's, in the order each
    model produced them.

    Args:
        image: A PIL image. The caller in this file converts it to RGB first.

    Returns:
        dict: ``{"objects": [...], "colors": [...], "scene": "indoor"|"outdoor"}``
        where ``objects`` holds at most five names and ``colors`` is whatever
        ``get_dominant_colors`` returns for the image.
    """
    # Classification: add a batch dimension and take the three highest logits.
    batch = transform(image).unsqueeze(0)
    with torch.no_grad():
        logits = resnet(batch)
    top_indices = torch.topk(logits, k=3, dim=1).indices[0]
    objects = [imagenet_classes[idx.item()] for idx in top_indices]

    # Detection: Faster R-CNN gets the un-normalized tensor of the full image.
    detector_input = transforms.ToTensor()(image).unsqueeze(0)
    with torch.no_grad():
        prediction = fasterrcnn(detector_input)
    objects.extend(
        COCO_INSTANCE_CATEGORY_NAMES[label]
        for label in prediction[0]['labels'].tolist()
    )

    # Drop duplicates but keep first-seen order. A set() here would scramble
    # the list and make the five names kept below arbitrary between runs.
    objects = list(dict.fromkeys(objects))

    colors = get_dominant_colors(image)

    # NOTE(review): none of these marker names occurs in
    # COCO_INSTANCE_CATEGORY_NAMES, so only a classifier label spelled exactly
    # like one of them can produce "outdoor". Confirm this heuristic is
    # intended.
    outdoor_markers = {'sky', 'tree', 'grass', 'mountain'}
    scene = "outdoor" if any(name in outdoor_markers for name in objects) else "indoor"

    return {
        "objects": objects[:5],  # Limit to the first five names
        "colors": colors,
        "scene": scene,
    }
# Function to get dominant colors from an image | |
def get_dominant_colors(image, num_colors=3):
    """Name the most frequent colors of an image.

    A copy of the image is shrunk to fit within 100x100 pixels, quantized to
    an adaptive palette of ``num_colors`` entries, and the palette entries are
    ranked by how many pixels use them.

    Args:
        image: A PIL image. The input itself is not modified, only a copy.
        num_colors: Maximum number of colors to report.

    Returns:
        list[str]: Color names from ``rgb_to_name``, most frequent first. The
        list is shorter than ``num_colors`` when the quantized image uses
        fewer distinct colors (e.g. a solid-color image).
    """
    # Work on a small copy: thumbnail() resizes in place.
    thumb = image.copy()
    thumb.thumbnail((100, 100))

    paletted = thumb.convert('P', palette=Image.ADAPTIVE, colors=num_colors)
    palette = paletted.getpalette()

    # getcolors() yields (pixel_count, palette_index) pairs; sorting in reverse
    # puts the most used palette entry first. It may return None, hence the
    # fallback to an empty list.
    ranked = sorted(paletted.getcolors() or [], reverse=True)

    colors = []
    # Slice instead of indexing range(num_colors): the quantizer can produce
    # fewer colors than requested, which used to raise IndexError here.
    for _count, palette_index in ranked[:num_colors]:
        start = palette_index * 3  # palette is a flat [r, g, b, r, g, b, ...] list
        colors.append(rgb_to_name(palette[start:start + 3]))
    return colors
# Function to convert RGB to color name (simplified) | |
def rgb_to_name(rgb):
    """Map an RGB triple to a coarse color name.

    Args:
        rgb: Any iterable of exactly three comparable channel values
            ``(r, g, b)``.

    Returns:
        str: ``"red"``, ``"green"`` or ``"blue"`` when that channel is strictly
        larger than both others; ``"gray"`` when the largest value is shared
        by two or more channels.
    """
    red, green, blue = rgb
    channels = {"red": red, "green": green, "blue": blue}
    peak = max(channels.values())
    # A channel wins only if it alone holds the maximum; any tie is "gray".
    leaders = [name for name, value in channels.items() if value == peak]
    if len(leaders) == 1:
        return leaders[0]
    return "gray"
# Function to simulate the generation of answers from metadata | |
def generate_answer_from_metadata(metadata, question, complexity):
    """Produce a placeholder answer for a question about an analyzed image.

    A prompt describing the image is assembled from ``metadata`` but is not
    sent anywhere: no model client exists in this file, so the function
    returns a fixed-format "simulated" string instead.

    Args:
        metadata: Mapping with ``'objects'`` and ``'colors'`` (iterables of
            str) and ``'scene'``.
        question: The user's question.
        complexity: Requested answer style; must support ``.lower()``.

    Returns:
        str: The simulated answer, echoing metadata, question and complexity.

    Raises:
        KeyError: If ``metadata`` lacks one of the keys read for the prompt.
    """
    object_list = ', '.join(metadata['objects'])
    color_list = ', '.join(metadata['colors'])

    # Prompt for a future model/API call; built here but currently unused.
    prompt = (
        "\n"
        f"The image contains the following objects: {object_list}.\n"
        f"The dominant colors are {color_list}.\n"
        f"It appears to be an {metadata['scene']} scene.\n"
        f"Based on this, provide a {complexity.lower()} response to the following question: {question}\n"
    )

    # Replace the line below with the real client call once one is available.
    answer = f"Simulated answer based on metadata: {metadata}. Question: {question}, Complexity: {complexity}."
    return answer
# Flask routes | |
# Registered explicitly: without a route decorator the view is never attached
# to ``app`` and no URL would reach it.
@app.route('/')
def index():
    """Render and return the ``index.html`` template."""
    return render_template('index.html')
# NOTE(review): the URL path is an assumption; it must match the endpoint that
# the front-end in index.html submits to. Without a route decorator the view
# is never attached to ``app``.
@app.route('/ask', methods=['POST'])
def ask_question():
    """Answer a question about an uploaded image.

    Reads the multipart form fields ``image`` (file), ``question`` and the
    optional ``complexity`` (defaults to ``'Default'``), analyzes the image
    and returns the generated answer as JSON.

    Returns:
        A Flask JSON response:
        * ``{"answer": ...}`` on success,
        * ``{"error": ...}`` with status 400 when the image or question is
          missing or the upload cannot be decoded as an image,
        * ``{"error": ...}`` with status 500 when analysis or answer
          generation fails.
    """
    upload = request.files.get('image')
    question = request.form.get('question')
    complexity = request.form.get('complexity', 'Default')
    if not upload or not question:
        return jsonify({"error": "Missing image or question"}), 400

    # The upload is untrusted: a corrupt, non-image or oversized file is a
    # client error, not a server crash.
    try:
        image = Image.open(upload).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError):
        return jsonify({"error": "Invalid image file"}), 400

    # Model inference can fail as well, so it sits inside the guarded region
    # together with answer generation. Top-level boundary: log with traceback.
    try:
        metadata = real_image_analysis(image)
        answer = generate_answer_from_metadata(metadata, question, complexity)
    except Exception:
        app.logger.exception("Error generating answer")
        return jsonify({"error": "Failed to generate answer"}), 500
    return jsonify({"answer": answer})
if __name__ == '__main__':
    # Debug mode turns on the Werkzeug interactive debugger, which lets anyone
    # who can reach the server execute code. Keep it off unless explicitly
    # requested through the FLASK_DEBUG environment variable.
    debug_requested = os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true')
    app.run(debug=debug_requested)