Spaces:
Running
Running
Upload app.py
Browse files
app.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from flask import Flask, render_template, request, jsonify
|
2 |
+
from PIL import Image
|
3 |
+
from io import BytesIO
|
4 |
+
import torch
|
5 |
+
from torchvision import models, transforms
|
6 |
+
from torchvision.models.detection import fasterrcnn_resnet50_fpn
|
7 |
+
import os
|
8 |
+
|
9 |
+
app = Flask(__name__)
|
10 |
+
|
11 |
+
# Load ImageNet class index
|
12 |
+
def load_imagenet_class_index():
|
13 |
+
class_index_path = 'imagenet_classes.txt'
|
14 |
+
if not os.path.exists(class_index_path):
|
15 |
+
raise FileNotFoundError(f"ImageNet class index file not found at {class_index_path}")
|
16 |
+
|
17 |
+
with open(class_index_path) as f:
|
18 |
+
classes = [line.strip() for line in f.readlines()]
|
19 |
+
return classes
|
20 |
+
|
21 |
+
imagenet_classes = load_imagenet_class_index()
|
22 |
+
|
23 |
+
# Load pre-trained models
|
24 |
+
resnet = models.resnet50(pretrained=True)
|
25 |
+
resnet.eval()
|
26 |
+
fasterrcnn = fasterrcnn_resnet50_fpn(pretrained=True)
|
27 |
+
fasterrcnn.eval()
|
28 |
+
|
29 |
+
# Image transformation
|
30 |
+
transform = transforms.Compose([
|
31 |
+
transforms.Resize(256),
|
32 |
+
transforms.CenterCrop(224),
|
33 |
+
transforms.ToTensor(),
|
34 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
35 |
+
])
|
36 |
+
|
37 |
+
# COCO dataset class names
|
38 |
+
COCO_INSTANCE_CATEGORY_NAMES = [
|
39 |
+
'__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
|
40 |
+
'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
|
41 |
+
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
|
42 |
+
'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
|
43 |
+
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
|
44 |
+
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
|
45 |
+
'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
|
46 |
+
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
|
47 |
+
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
|
48 |
+
'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
|
49 |
+
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
|
50 |
+
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
|
51 |
+
]
|
52 |
+
|
53 |
+
# Function for real image analysis
|
54 |
+
def real_image_analysis(image):
|
55 |
+
# Prepare image for classification
|
56 |
+
img_t = transform(image)
|
57 |
+
batch_t = torch.unsqueeze(img_t, 0)
|
58 |
+
|
59 |
+
# Classification
|
60 |
+
with torch.no_grad():
|
61 |
+
output = resnet(batch_t)
|
62 |
+
|
63 |
+
# Get top 3 predictions
|
64 |
+
_, indices = torch.sort(output, descending=True)
|
65 |
+
percentages = torch.nn.functional.softmax(output, dim=1)[0] * 100
|
66 |
+
objects = [imagenet_classes[idx.item()] for idx in indices[0][:3]]
|
67 |
+
|
68 |
+
# Object detection using Faster R-CNN
|
69 |
+
img_tensor = transforms.ToTensor()(image).unsqueeze(0)
|
70 |
+
with torch.no_grad():
|
71 |
+
prediction = fasterrcnn(img_tensor)
|
72 |
+
|
73 |
+
# Get detected objects
|
74 |
+
detected_objects = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in prediction[0]['labels']]
|
75 |
+
objects.extend(detected_objects)
|
76 |
+
objects = list(set(objects)) # Remove duplicates
|
77 |
+
|
78 |
+
# Get dominant colors
|
79 |
+
colors = get_dominant_colors(image)
|
80 |
+
|
81 |
+
# Determine scene (indoor/outdoor)
|
82 |
+
scene = "outdoor" if any(obj in ['sky', 'tree', 'grass', 'mountain'] for obj in objects) else "indoor"
|
83 |
+
|
84 |
+
return {
|
85 |
+
"objects": objects[:5], # Limit to top 5 objects
|
86 |
+
"colors": colors,
|
87 |
+
"scene": scene
|
88 |
+
}
|
89 |
+
|
90 |
+
# Function to get dominant colors from an image
|
91 |
+
def get_dominant_colors(image, num_colors=3):
|
92 |
+
# Resize image to speed up processing
|
93 |
+
img = image.copy()
|
94 |
+
img.thumbnail((100, 100))
|
95 |
+
|
96 |
+
# Get colors from the image
|
97 |
+
paletted = img.convert('P', palette=Image.ADAPTIVE, colors=num_colors)
|
98 |
+
palette = paletted.getpalette()
|
99 |
+
color_counts = sorted(paletted.getcolors(), reverse=True)
|
100 |
+
colors = []
|
101 |
+
for i in range(num_colors):
|
102 |
+
palette_index = color_counts[i][1]
|
103 |
+
dominant_color = palette[palette_index*3:palette_index*3+3]
|
104 |
+
colors.append(rgb_to_name(dominant_color))
|
105 |
+
return colors
|
106 |
+
|
107 |
+
# Function to convert RGB to color name (simplified)
|
108 |
+
def rgb_to_name(rgb):
|
109 |
+
r, g, b = rgb
|
110 |
+
if r > g and r > b:
|
111 |
+
return "red"
|
112 |
+
elif g > r and g > b:
|
113 |
+
return "green"
|
114 |
+
elif b > r and b > g:
|
115 |
+
return "blue"
|
116 |
+
else:
|
117 |
+
return "gray"
|
118 |
+
|
119 |
+
# Function to simulate the generation of answers from metadata
|
120 |
+
def generate_answer_from_metadata(metadata, question, complexity):
|
121 |
+
prompt = f"""
|
122 |
+
The image contains the following objects: {', '.join(metadata['objects'])}.
|
123 |
+
The dominant colors are {', '.join(metadata['colors'])}.
|
124 |
+
It appears to be an {metadata['scene']} scene.
|
125 |
+
|
126 |
+
Based on this, provide a {complexity.lower()} response to the following question: {question}
|
127 |
+
"""
|
128 |
+
|
129 |
+
# Since `client` is not defined, we can simulate a response here
|
130 |
+
# Replace this section with the actual client code if using an API
|
131 |
+
return f"Simulated answer based on metadata: {metadata}. Question: {question}, Complexity: {complexity}."
|
132 |
+
|
133 |
+
# Flask routes
|
134 |
+
@app.route('/')
|
135 |
+
def index():
|
136 |
+
return render_template('index.html')
|
137 |
+
|
138 |
+
@app.route('/ask', methods=['POST'])
|
139 |
+
def ask_question():
|
140 |
+
image = request.files.get('image')
|
141 |
+
question = request.form.get('question')
|
142 |
+
complexity = request.form.get('complexity', 'Default')
|
143 |
+
|
144 |
+
if not image or not question:
|
145 |
+
return jsonify({"error": "Missing image or question"}), 400
|
146 |
+
|
147 |
+
# Process the image
|
148 |
+
image = Image.open(image).convert("RGB")
|
149 |
+
|
150 |
+
# Perform real image analysis
|
151 |
+
metadata = real_image_analysis(image)
|
152 |
+
|
153 |
+
# Generate the answer
|
154 |
+
try:
|
155 |
+
answer = generate_answer_from_metadata(metadata, question, complexity)
|
156 |
+
return jsonify({"answer": answer})
|
157 |
+
except Exception as e:
|
158 |
+
print(f"Error generating answer: {str(e)}")
|
159 |
+
return jsonify({"error": "Failed to generate answer"}), 500
|
160 |
+
|
161 |
+
if __name__ == '__main__':
|
162 |
+
app.run(debug=True)
|