OpenRAG128 committed on
Commit
2e7f273
·
verified ·
1 Parent(s): 94d96f3

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +162 -0
app.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, render_template, request, jsonify
2
+ from PIL import Image
3
+ from io import BytesIO
4
+ import torch
5
+ from torchvision import models, transforms
6
+ from torchvision.models.detection import fasterrcnn_resnet50_fpn
7
+ import os
8
+
9
+ app = Flask(__name__)
10
+
11
+ # Load ImageNet class index
12
def load_imagenet_class_index():
    """Read the ImageNet class labels from ``imagenet_classes.txt``.

    The path is relative to the current working directory. Each line of the
    file becomes one entry, stripped of surrounding whitespace.

    Returns:
        list[str]: Labels in file order, so that list position N holds the
        label found on line N (0-based) of the file.

    Raises:
        FileNotFoundError: If the label file does not exist.
    """
    class_index_path = 'imagenet_classes.txt'
    if not os.path.exists(class_index_path):
        raise FileNotFoundError(f"ImageNet class index file not found at {class_index_path}")

    # Pin the encoding: the platform default differs between systems and
    # would otherwise make label decoding environment-dependent.
    with open(class_index_path, encoding="utf-8") as f:
        # Blank lines are deliberately kept so positions keep matching line
        # numbers; dropping them would shift every later index.
        return [line.strip() for line in f]
20
+
21
# Class labels used to name the classifier's output indices.
imagenet_classes = load_imagenet_class_index()

# Pre-trained networks, built once at import time and switched to inference
# mode. nn.Module.eval() returns the module itself, so the call is chained.
resnet = models.resnet50(pretrained=True).eval()
fasterrcnn = fasterrcnn_resnet50_fpn(pretrained=True).eval()
28
+
29
# Preprocessing pipeline applied to an image before it is fed to the classifier.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Per-channel mean / std used for normalization.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
36
+
37
# COCO category names, indexed by the integer label the detector emits.
# Laid out ten entries per row; the comment on each row gives its index range.
# "N/A" entries are placeholders that keep later names at their positions.
COCO_INSTANCE_CATEGORY_NAMES = [
    # 0-9
    "__background__", "person", "bicycle", "car", "motorcycle",
    "airplane", "bus", "train", "truck", "boat",
    # 10-19
    "traffic light", "fire hydrant", "N/A", "stop sign", "parking meter",
    "bench", "bird", "cat", "dog", "horse",
    # 20-29
    "sheep", "cow", "elephant", "bear", "zebra",
    "giraffe", "N/A", "backpack", "umbrella", "N/A",
    # 30-39
    "N/A", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat",
    # 40-49
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
    "N/A", "wine glass", "cup", "fork", "knife",
    # 50-59
    "spoon", "bowl", "banana", "apple", "sandwich",
    "orange", "broccoli", "carrot", "hot dog", "pizza",
    # 60-69
    "donut", "cake", "chair", "couch", "potted plant",
    "bed", "N/A", "dining table", "N/A", "N/A",
    # 70-79
    "toilet", "N/A", "tv", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "microwave", "oven",
    # 80-89
    "toaster", "sink", "refrigerator", "N/A", "book",
    "clock", "vase", "scissors", "teddy bear", "hair drier",
    # 90
    "toothbrush",
]
52
+
53
+ # Function for real image analysis
54
def real_image_analysis(image):
    """Summarize an image as object names, dominant colors and a scene guess.

    The image is classified with the module-level ``resnet`` model (top three
    classes) and scanned with the module-level ``fasterrcnn`` detector; the
    names from both are merged, de-duplicated and truncated to five.

    Args:
        image: A PIL image. Presumably RGB (three channels), since the
            normalization in ``transform`` uses three-channel statistics —
            confirm for callers other than the ``/ask`` route.

    Returns:
        dict: ``{"objects": list[str], "colors": list[str], "scene": str}``
        where ``objects`` holds at most five names and ``scene`` is either
        ``"outdoor"`` or ``"indoor"``.
    """
    # Classification: preprocess, add the batch dimension, run the model.
    batch_t = transform(image).unsqueeze(0)
    with torch.no_grad():
        output = resnet(batch_t)

    # Indices of the three highest-scoring classes, best first.
    _, indices = torch.sort(output, descending=True)
    objects = [imagenet_classes[idx] for idx in indices[0][:3].tolist()]

    # Object detection runs on the plain tensor of the full image
    # (no resize / crop / normalization).
    img_tensor = transforms.ToTensor()(image).unsqueeze(0)
    with torch.no_grad():
        prediction = fasterrcnn(img_tensor)

    # .tolist() turns the label tensor into plain ints for list indexing.
    detected_labels = prediction[0]['labels'].tolist()
    objects.extend(COCO_INSTANCE_CATEGORY_NAMES[label] for label in detected_labels)

    # De-duplicate while KEEPING first-seen order. A set would scramble the
    # ranking (string hashing is randomized per process), making the
    # "top 5" slice below arbitrary and different from run to run.
    # NOTE(review): detections are presumably ordered best-score-first by
    # torchvision — confirm; no score threshold is applied here.
    objects = list(dict.fromkeys(objects))

    colors = get_dominant_colors(image)

    # Scene heuristic: "outdoor" only if one of these exact names was found.
    # NOTE(review): none of these appear in COCO_INSTANCE_CATEGORY_NAMES, so
    # a match can only come from the ImageNet label file — verify that it
    # actually contains them, otherwise the result is always "indoor".
    outdoor_hints = {'sky', 'tree', 'grass', 'mountain'}
    scene = "outdoor" if outdoor_hints.intersection(objects) else "indoor"

    return {
        "objects": objects[:5],  # Limit to top 5 objects
        "colors": colors,
        "scene": scene,
    }
89
+
90
+ # Function to get dominant colors from an image
91
def get_dominant_colors(image, num_colors=3):
    """Name the most frequent colors of an image.

    The image is shrunk, reduced to an adaptive palette of ``num_colors``
    entries, and each used palette entry is mapped to a coarse color name
    with ``rgb_to_name``.

    Args:
        image: A PIL image. It is not modified; work happens on a copy.
        num_colors: Maximum number of colors to report (default 3).

    Returns:
        list[str]: Color names, most frequent first. The list is shorter than
        ``num_colors`` when the image uses fewer distinct palette entries
        (for example a solid-color image).
    """
    # convert() returns a new image, so thumbnail() below (which works in
    # place) never touches the caller's object. Normalizing to RGB also lets
    # non-RGB inputs go through the adaptive-palette conversion.
    img = image.convert('RGB')
    # Resize image to speed up processing
    img.thumbnail((100, 100))

    paletted = img.convert('P', palette=Image.ADAPTIVE, colors=num_colors)
    palette = paletted.getpalette()
    # getcolors() yields (pixel_count, palette_index) pairs for the entries
    # actually used; sorting descending puts the most frequent first.
    color_counts = sorted(paletted.getcolors(), reverse=True)

    colors = []
    # Slice instead of indexing range(num_colors): fewer entries than
    # requested used to raise IndexError.
    for _count, palette_index in color_counts[:num_colors]:
        start = palette_index * 3  # palette is a flat [r, g, b, r, g, b, ...] list
        colors.append(rgb_to_name(palette[start:start + 3]))
    return colors
106
+
107
+ # Function to convert RGB to color name (simplified)
108
def rgb_to_name(rgb):
    """Map an RGB triple to a coarse color name (simplified).

    Args:
        rgb: Sequence of exactly three comparable channel values (r, g, b).

    Returns:
        str: ``"red"``, ``"green"`` or ``"blue"`` when that channel is
        strictly greater than both others; ``"gray"`` on any tie for the
        largest value.
    """
    red, green, blue = rgb

    # A channel wins only if it strictly beats the larger of the other two.
    if red > max(green, blue):
        return "red"
    if green > max(red, blue):
        return "green"
    if blue > max(red, green):
        return "blue"
    return "gray"
118
+
119
+ # Function to simulate the generation of answers from metadata
120
def generate_answer_from_metadata(metadata, question, complexity):
    """Produce a placeholder answer for a question about an analyzed image.

    Args:
        metadata: Mapping with ``'objects'`` and ``'colors'`` (iterables of
            strings) and ``'scene'``.
        question: The user's question.
        complexity: String describing the desired answer style; it is
            lower-cased for the prompt.

    Returns:
        str: A simulated answer that echoes the metadata, question and
        complexity.

    Raises:
        KeyError: If ``metadata`` lacks one of the keys read for the prompt.
    """
    object_list = ', '.join(metadata['objects'])
    color_list = ', '.join(metadata['colors'])
    scene_kind = metadata['scene']
    tone = complexity.lower()

    # The prompt is assembled but not sent anywhere yet (see below).
    prompt = f"""
    The image contains the following objects: {object_list}.
    The dominant colors are {color_list}.
    It appears to be an {scene_kind} scene.

    Based on this, provide a {tone} response to the following question: {question}
    """

    # Since `client` is not defined, we can simulate a response here
    # Replace this section with the actual client code if using an API
    simulated = f"Simulated answer based on metadata: {metadata}. Question: {question}, Complexity: {complexity}."
    return simulated
132
+
133
+ # Flask routes
134
@app.route('/')
def index():
    """Serve the landing page rendered from the ``index.html`` template."""
    page = render_template('index.html')
    return page
137
+
138
@app.route('/ask', methods=['POST'])
def ask_question():
    """Answer a question about an uploaded image.

    Reads the multipart form fields ``image`` (file), ``question`` and the
    optional ``complexity`` (defaults to ``'Default'``).

    Returns:
        A JSON response:
        - 200 ``{"answer": ...}`` on success;
        - 400 ``{"error": ...}`` when the image or question is missing, or
          the upload cannot be decoded as an image;
        - 500 ``{"error": ...}`` when analysis or answer generation fails.
    """
    upload = request.files.get('image')
    question = request.form.get('question')
    complexity = request.form.get('complexity', 'Default')

    if not upload or not question:
        return jsonify({"error": "Missing image or question"}), 400

    # The upload is untrusted: a file that is not a decodable image is a
    # client error, not a server crash.
    try:
        image = Image.open(upload).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError):
        return jsonify({"error": "Invalid image file"}), 400

    # Request boundary: any failure in the analysis or answer step is logged
    # with its traceback and reported as a generic 500.
    try:
        metadata = real_image_analysis(image)
        answer = generate_answer_from_metadata(metadata, question, complexity)
    except Exception:
        app.logger.exception("Error generating answer")
        return jsonify({"error": "Failed to generate answer"}), 500

    return jsonify({"answer": answer})
160
+
161
if __name__ == '__main__':
    # Debug mode enables the interactive Werkzeug debugger, which lets anyone
    # who can reach the server execute code. Keep it off unless explicitly
    # requested through the FLASK_DEBUG environment variable.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true', 'yes')
    app.run(debug=debug_enabled)