from flask import Flask, render_template, request, jsonify
from PIL import Image
import torch
from torchvision import models, transforms
from torchvision.models.detection import fasterrcnn_resnet50_fpn, FasterRCNN_ResNet50_FPN_Weights
import os
app = Flask(__name__)
# Load ImageNet class index
def load_imagenet_class_index():
    class_index_path = 'imagenet_classes.txt'
    if not os.path.exists(class_index_path):
        raise FileNotFoundError(f"ImageNet class index file not found at {class_index_path}")
    with open(class_index_path) as f:
        classes = [line.strip() for line in f]
    return classes
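
# Note: `imagenet_classes.txt` is expected to hold one label per line, in
# ImageNet index order. A commonly used copy (an assumption, not pinned by
# this repo) is:
# https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt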
imagenet_classes = load_imagenet_class_index()
# Load pre-trained models (the `weights=` enums replace the deprecated
# `pretrained=True` flag in torchvision >= 0.13)
resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
resnet.eval()
fasterrcnn = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
fasterrcnn.eval()
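
# Optional sketch: run inference on GPU when available (an assumption about
# the deployment host; CPU works as-is). Enabling this also requires moving
# the input tensors in real_image_analysis() to `device`.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# resnet.to(device)
# fasterrcnn.to(device)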
# Image transformation
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# COCO dataset class names
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]
# Function for real image analysis
def real_image_analysis(image):
    # Prepare image for classification
    img_t = transform(image)
    batch_t = torch.unsqueeze(img_t, 0)

    # Classification
    with torch.no_grad():
        output = resnet(batch_t)

    # Get top 3 predictions
    _, indices = torch.sort(output, descending=True)
    # Class confidence percentages (computed but not returned)
    percentages = torch.nn.functional.softmax(output, dim=1)[0] * 100
    objects = [imagenet_classes[idx.item()] for idx in indices[0][:3]]

    # Object detection using Faster R-CNN
    img_tensor = transforms.ToTensor()(image).unsqueeze(0)
    with torch.no_grad():
        prediction = fasterrcnn(img_tensor)

    # Keep only reasonably confident detections; without a score threshold
    # Faster R-CNN's low-confidence boxes flood the label list with noise
    detected_objects = [
        COCO_INSTANCE_CATEGORY_NAMES[label.item()]
        for label, score in zip(prediction[0]['labels'], prediction[0]['scores'])
        if score >= 0.5
    ]
    objects.extend(detected_objects)
    objects = list(set(objects))  # Remove duplicates

    # Get dominant colors
    colors = get_dominant_colors(image)

    # Determine scene (indoor/outdoor) with a simple keyword heuristic
    scene = "outdoor" if any(obj in ['sky', 'tree', 'grass', 'mountain'] for obj in objects) else "indoor"

    return {
        "objects": objects[:5],  # Limit to top 5 objects
        "colors": colors,
        "scene": scene
    }
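
# Illustrative shape of the metadata returned above (values hypothetical):
#   {"objects": ["cat", "chair", "bottle"],
#    "colors": ["red", "gray", "blue"],
#    "scene": "indoor"}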
# Function to get dominant colors from an image
def get_dominant_colors(image, num_colors=3):
    # Resize image to speed up processing
    img = image.copy()
    img.thumbnail((100, 100))

    # Quantize to an adaptive palette and count pixels per palette entry
    paletted = img.convert('P', palette=Image.ADAPTIVE, colors=num_colors)
    palette = paletted.getpalette()
    color_counts = sorted(paletted.getcolors(), reverse=True)

    colors = []
    # The quantizer may return fewer entries than requested, so clamp the loop
    for i in range(min(num_colors, len(color_counts))):
        palette_index = color_counts[i][1]
        dominant_color = palette[palette_index * 3:palette_index * 3 + 3]
        colors.append(rgb_to_name(dominant_color))
    return colors
# Function to convert RGB to color name (simplified)
def rgb_to_name(rgb):
    r, g, b = rgb
    if r > g and r > b:
        return "red"
    elif g > r and g > b:
        return "green"
    elif b > r and b > g:
        return "blue"
    else:
        return "gray"
# Function to simulate the generation of answers from metadata
def generate_answer_from_metadata(metadata, question, complexity):
    prompt = f"""
    The image contains the following objects: {', '.join(metadata['objects'])}.
    The dominant colors are {', '.join(metadata['colors'])}.
    It appears to be an {metadata['scene']} scene.
    Based on this, provide a {complexity.lower()} response to the following question: {question}
    """
    # `prompt` is assembled for a real LLM client, but since no client is
    # configured here, the response is simulated below.
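    # A minimal sketch of one possible client, assuming the `openai` package
    # (>= 1.0) and an OPENAI_API_KEY in the environment; the model name is
    # illustrative, not part of this project:
    #
    #   from openai import OpenAI
    #   client = OpenAI()
    #   response = client.chat.completions.create(
    #       model="gpt-4o-mini",
    #       messages=[{"role": "user", "content": prompt}],
    #   )
    #   return response.choices[0].message.content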
return f"Simulated answer based on metadata: {metadata}. Question: {question}, Complexity: {complexity}."
# Flask routes
@app.route('/')
def index():
    return render_template('index.html')
@app.route('/ask', methods=['POST'])
def ask_question():
    image = request.files.get('image')
    question = request.form.get('question')
    complexity = request.form.get('complexity', 'Default')

    if not image or not question:
        return jsonify({"error": "Missing image or question"}), 400

    try:
        # Process the image (decoding and analysis can fail too, so they
        # belong inside the try block)
        image = Image.open(image).convert("RGB")

        # Perform real image analysis
        metadata = real_image_analysis(image)

        # Generate the answer
        answer = generate_answer_from_metadata(metadata, question, complexity)
        return jsonify({"answer": answer})
    except Exception as e:
        print(f"Error generating answer: {str(e)}")
        return jsonify({"error": "Failed to generate answer"}), 500
if __name__ == '__main__':
    app.run(debug=True)
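
# Example request (hypothetical paths and values, shown for illustration):
#   curl -X POST http://127.0.0.1:5000/ask \
#        -F "image=@/path/to/photo.jpg" \
#        -F "question=What is in this picture?" \
#        -F "complexity=Simple"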