from flask import Flask, render_template, request, jsonify
from PIL import Image
import torch
from torchvision import models, transforms
from torchvision.models.detection import (
    fasterrcnn_resnet50_fpn,
    FasterRCNN_ResNet50_FPN_Weights,
)
import os

app = Flask(__name__)


# Load ImageNet class index (expects one class name per line)
def load_imagenet_class_index():
    class_index_path = 'imagenet_classes.txt'
    if not os.path.exists(class_index_path):
        raise FileNotFoundError(f"ImageNet class index file not found at {class_index_path}")
    with open(class_index_path) as f:
        classes = [line.strip() for line in f]
    return classes


imagenet_classes = load_imagenet_class_index()

# Load pre-trained models (the `weights` argument replaces the deprecated `pretrained=True`)
resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
resnet.eval()

fasterrcnn = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
fasterrcnn.eval()

# Image transformation for the ResNet classifier
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# COCO dataset class names ('N/A' entries are unused label slots)
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]


# Function for real image analysis
def real_image_analysis(image):
    # Prepare image for classification
    img_t = transform(image)
    batch_t = torch.unsqueeze(img_t, 0)

    # Classification
    with torch.no_grad():
        output = resnet(batch_t)

    # Get top 3 predictions
    _, indices = torch.sort(output, descending=True)
    percentages = torch.nn.functional.softmax(output, dim=1)[0] * 100  # confidence scores (currently unused in the response)
    objects = [imagenet_classes[idx.item()] for idx in indices[0][:3]]

    # Object detection using Faster R-CNN
    img_tensor = transforms.ToTensor()(image).unsqueeze(0)
    with torch.no_grad():
        prediction = fasterrcnn(img_tensor)

    # Keep only reasonably confident detections (score > 0.5)
    detected_objects = [
        COCO_INSTANCE_CATEGORY_NAMES[label]
        for label, score in zip(prediction[0]['labels'], prediction[0]['scores'])
        if score > 0.5
    ]
    objects.extend(detected_objects)
    objects = list(set(objects))  # Remove duplicates

    # Get dominant colors
    colors = get_dominant_colors(image)

    # Determine scene (crude heuristic: outdoor-style labels suggest an outdoor scene)
    scene = "outdoor" if any(obj in ['sky', 'tree', 'grass', 'mountain'] for obj in objects) else "indoor"

    return {
        "objects": objects[:5],  # Limit to top 5 objects
        "colors": colors,
        "scene": scene
    }


# Function to get dominant colors from an image
def get_dominant_colors(image, num_colors=3):
    # Resize image to speed up processing
    img = image.copy()
    img.thumbnail((100, 100))

    # Quantize the image down to `num_colors` palette entries
    paletted = img.convert('P', palette=Image.ADAPTIVE, colors=num_colors)
    palette = paletted.getpalette()
    color_counts = sorted(paletted.getcolors(), reverse=True)

    colors = []
    # The image may contain fewer distinct colors than requested
    for i in range(min(num_colors, len(color_counts))):
        palette_index = color_counts[i][1]
        dominant_color = palette[palette_index * 3:palette_index * 3 + 3]
        colors.append(rgb_to_name(dominant_color))

    return colors


# Function to convert RGB to a coarse color name (simplified heuristic)
def rgb_to_name(rgb):
    r, g, b = rgb
    if r > g and r > b:
        return "red"
    elif g > r and g > b:
        return "green"
    elif b > r and b > g:
        return "blue"
    else:
        return "gray"


# Function to simulate the generation of answers from metadata
def generate_answer_from_metadata(metadata, question, complexity):
    prompt = f"""
    The image contains the following objects: {', '.join(metadata['objects'])}.
    The dominant colors are {', '.join(metadata['colors'])}.
    It appears to be an {metadata['scene']} scene.

    Based on this, provide a {complexity.lower()} response to the following question:
    {question}
    """

    # Since `client` is not defined, we simulate a response here.
    # Replace this section with the actual client code if using an API.
    return f"Simulated answer based on metadata: {metadata}. Question: {question}, Complexity: {complexity}."


# Flask routes
@app.route('/')
def index():
    return render_template('index.html')


@app.route('/ask', methods=['POST'])
def ask_question():
    image = request.files.get('image')
    question = request.form.get('question')
    complexity = request.form.get('complexity', 'Default')

    if not image or not question:
        return jsonify({"error": "Missing image or question"}), 400

    # Process the image
    image = Image.open(image).convert("RGB")

    # Perform real image analysis
    metadata = real_image_analysis(image)

    # Generate the answer
    try:
        answer = generate_answer_from_metadata(metadata, question, complexity)
        return jsonify({"answer": answer})
    except Exception as e:
        print(f"Error generating answer: {str(e)}")
        return jsonify({"error": "Failed to generate answer"}), 500


if __name__ == '__main__':
    app.run(debug=True)
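
# Example request against the /ask endpoint, as a rough sketch. It assumes the
# server is running locally on Flask's default port (5000) and that an image
# file named example.jpg exists next to it (hypothetical filename). The form
# field names match what ask_question() reads above:
#
#   curl -X POST http://127.0.0.1:5000/ask \
#        -F "image=@example.jpg" \
#        -F "question=What objects are in this image?" \
#        -F "complexity=Simple"
#
# A successful response is JSON of the form {"answer": "..."}.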