"""Flask backend for medical-image classification with Grad-CAM overlays.

Serves the single-page UI, runs uploaded images through one of several Keras
classifiers (see MODEL_CONFIG), renders a Grad-CAM highlight image, and
forwards follow-up questions about a prediction to Google Gemini.
"""
import os

# Hide all GPUs so TensorFlow runs on CPU; this must be set before TF is imported.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import logging
import re
import secrets
import tempfile
import uuid

from flask import Flask, render_template, request, jsonify, redirect, url_for, send_from_directory
from flask_pymongo import PyMongo
from flask_bcrypt import Bcrypt
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import numpy as np
import cv2
import google.generativeai as genai
from dotenv import load_dotenv
import certifi

load_dotenv()

app = Flask(__name__)
app.config["MONGO_URI"] = os.getenv("MONGODB_URI") or os.getenv("MONGO_URI")
# Falls back to a per-process random key, so sessions do not survive a restart.
app.config['SECRET_KEY'] = os.getenv("SECRET_KEY") or secrets.token_hex(16)
app.config.setdefault("SESSION_COOKIE_HTTPONLY", True)
app.config.setdefault("SESSION_COOKIE_SAMESITE", "Lax")

GEMINI_API_KEY = os.getenv("GEMINI_API_KEY") or os.getenv("GOOGLE_API_KEY")

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("app")

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
TEST_IMAGES_DIR = os.path.join(BASE_DIR, 'testimages')
# Uploads and Grad-CAM renders live in a dedicated sub-directory of the system
# temp dir, so the public /tmp/<filename> route cannot expose unrelated files.
UPLOAD_DIR = os.path.join(tempfile.gettempdir(), "app_uploads")

# Only try to connect when a URI is configured; otherwise `mongo` stays None.
mongo = None
if app.config["MONGO_URI"]:
    try:
        mongo = PyMongo(app, tlsCAFile=certifi.where())
    except Exception as e:
        logger.error(f"Mongo initialization error: {e}")
else:
    logger.warning("MONGO_URI not set. MongoDB operations will fail.")

bcrypt = Bcrypt(app)

gemini_model = None
if GEMINI_API_KEY:
    try:
        genai.configure(api_key=GEMINI_API_KEY)
        gemini_model = genai.GenerativeModel('gemini-2.0-flash')
    except Exception as e:
        logger.error(f"Gemini initialization error: {e}")
else:
    logger.warning("GEMINI_API_KEY/GOOGLE_API_KEY not set. /chat will return a friendly error.")

# Per-model settings: weights file, output labels (in output-index order),
# preferred Grad-CAM layer name, and the input resolution fed to the model.
MODEL_CONFIG = {
    "Pneumonia": {
        "path": "model/best_pneumonia_model.h5",
        "labels": ["Normal", "Pneumonia"],
        "last_conv_layer": "relu",
        "input_size": (224, 224)
    },
    "Tuberculosis": {
        "path": "model/best_tuberculosis_model.h5",
        "labels": ["Normal", "Tuberculosis"],
        "last_conv_layer": "relu",
        "input_size": (224, 224)
    },
    "Brain Tumor": {
        "path": "model/best_braintumor_model.h5",
        "labels": ["glioma", "meningioma", "notumor", "pituitary"],
        "last_conv_layer": "relu",
        "input_size": (224, 224)
    },
    "Skin Cancer": {
        "path": "model/best_skincancer_model.h5",
        "labels": ["Actinic keratoses", "Basal cell carcinoma", "Benign keratosis-like lesions",
                   "Dermatofibroma", "Melanoma", "Melanocytic nevi", "Vascular lesions"],
        "last_conv_layer": "relu",
        "input_size": (224, 224)
    },
    "Kvasir": {
        "path": "model/best_kvasir_model.h5",
        "labels": ["dyed-lifted-polyps", "dyed-resection-margins", "esophagitis", "normal-cecum",
                   "normal-pylorus", "normal-z-line", "polyps", "ulcerative-colitis"],
        "last_conv_layer": "relu",
        "input_size": (224, 224)
    }
}

# Heuristic filename patterns for mapping examples per model
MODEL_EXAMPLE_PATTERNS = {
    "Pneumonia": ["pneumonia", "normal-"],
    "Tuberculosis": ["tuberculosis", "tb-"],
    "Brain Tumor": ["glioma", "meningioma", "notumor", "pituitary", "brain"],
    "Skin Cancer": ["melanoma", "nev", "keratos", "carcinoma", "vascular", "dermatofibroma", "skin"],
    "Kvasir": [
        "dyedlifted", "dyedresection", "esophagitis", "normalceacum",
        "normalpylorus", "normalzline", "polypus", "ulcerative"
    ],
}

# Loaded Keras models keyed by MODEL_CONFIG name; models whose file is missing
# or fails to load are simply absent.
models = {}


def _resolve_model_path(path):
    """Return *path* if usable as-is (absolute or present relative to the CWD);
    otherwise resolve it against the directory containing this file."""
    if os.path.isabs(path) or os.path.exists(path):
        return path
    return os.path.join(BASE_DIR, path)


def load_all_models():
    """Load every model in MODEL_CONFIG into ``models``, logging failures instead of raising."""
    for name, config in MODEL_CONFIG.items():
        try:
            model_path = _resolve_model_path(config["path"])
            if os.path.exists(model_path):
                models[name] = load_model(model_path, compile=False)
                logger.info(f"Successfully loaded {name} model from {model_path}.")
            else:
                logger.warning(f"Model file not found at {model_path}")
        except Exception as e:
            logger.error(f"Error loading model {name}: {e}")


load_all_models()


def preprocess_image(img_path, target_size=(224, 224)):
    """Load an image file as a float32 batch of one, scaled to [0, 1].

    Grayscale arrays are replicated to three channels and an alpha channel is
    dropped, so the result always has three channels.
    """
    img = image.load_img(img_path, target_size=target_size)
    img_array = image.img_to_array(img)
    if img_array.ndim == 2:
        img_array = np.stack([img_array] * 3, axis=-1)
    elif img_array.shape[-1] == 4:
        img_array = img_array[..., :3]
    img_array = np.expand_dims(img_array, axis=0)
    img_array = img_array.astype("float32") / 255.0
    return img_array


def _safe_get_layer(model, layer_name):
    """Return ``model.get_layer(layer_name)``, or None if the lookup fails."""
    try:
        return model.get_layer(layer_name)
    except Exception:
        return None


def _layer_output_rank(layer):
    """Return the rank of *layer*'s output, or None when it cannot be determined."""
    try:
        return len(layer.output_shape)  # tf.keras / Keras 2
    except Exception:
        pass
    try:
        return len(layer.output.shape)  # Keras 3 removed Layer.output_shape
    except Exception:
        return None


def find_last_conv_layer(model):
    """Return the name of the last top-level Conv2D/DepthwiseConv2D layer with a 4-D output.

    Raises:
        ValueError: if no such layer exists.
    """
    for layer in reversed(model.layers):
        if isinstance(layer, (tf.keras.layers.Conv2D, tf.keras.layers.DepthwiseConv2D)):
            if _layer_output_rank(layer) == 4:
                return layer.name
    raise ValueError("Could not automatically find a convolutional layer in the model.")


def get_gradcam_heatmap(model, img_array, last_conv_layer_name, pred_index=None):
    """Compute a Grad-CAM heatmap for *img_array*.

    Args:
        model: Keras model that produced the prediction.
        img_array: Preprocessed input batch (as returned by preprocess_image).
        last_conv_layer_name: Layer whose activations are weighted. If it is not
            found in *model*, find_last_conv_layer picks one instead.
        pred_index: Class index to explain for multi-unit heads; defaults to the
            arg-max class. Ignored for single-unit heads, which always explain
            output unit 0.

    Returns:
        2-D numpy array (the conv layer's spatial size) with values in [0, 1];
        all zeros when no gradient flows to the conv layer.
    """
    if not _safe_get_layer(model, last_conv_layer_name):
        last_conv_layer_name = find_last_conv_layer(model)
    conv_layer = model.get_layer(last_conv_layer_name)
    # Pass model.inputs as-is: wrapping it in another list nests the input
    # structure, which newer Keras versions may warn about or reject.
    grad_model = tf.keras.models.Model(model.inputs, [conv_layer.output, model.output])

    with tf.GradientTape() as tape:
        conv_outputs, preds = grad_model(img_array, training=False)
        if isinstance(preds, (list, tuple)):
            preds = preds[0]
        preds = tf.convert_to_tensor(preds)
        if preds.shape.rank is not None and preds.shape[-1] == 1:
            class_channel = preds[:, 0]
        else:
            if pred_index is None:
                pred_index = tf.argmax(preds[0])
            class_channel = preds[:, pred_index]

    grads = tape.gradient(class_channel, conv_outputs)
    if grads is None:
        heatmap = tf.zeros(conv_outputs.shape[1:3], dtype=tf.float32)
        return heatmap.numpy()

    # Channel weights = spatially averaged gradients; heatmap = weighted sum of
    # activation maps, ReLU'd and normalised by its maximum.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    conv_outputs = conv_outputs[0]
    heatmap = tf.tensordot(conv_outputs, pooled_grads, axes=(2, 0))
    heatmap = tf.maximum(heatmap, 0)
    denom = tf.math.reduce_max(heatmap)
    heatmap = heatmap / (denom + 1e-8)
    return heatmap.numpy()


def save_gradcam_image(img_path, heatmap, output_path, threshold=0.6, alpha=0.4):
    """Write *img_path* with pixels whose heatmap value exceeds *threshold* tinted red.

    The heatmap is resized to the image size. Only masked pixels are blended
    (weight *alpha* for red); all others keep their original colour.

    Returns:
        output_path.

    Raises:
        ValueError: if OpenCV cannot read *img_path*.
    """
    img = cv2.imread(img_path)
    if img is None:
        raise ValueError("Failed to read image with OpenCV.")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    mask = heatmap > threshold
    overlay = np.zeros_like(img, dtype=np.uint8)
    overlay[mask] = [255, 0, 0]
    superimposed_img = cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0)
    superimposed_img[~mask] = img[~mask]
    superimposed_img = cv2.cvtColor(superimposed_img, cv2.COLOR_RGB2BGR)
    cv2.imwrite(output_path, superimposed_img)
    return output_path


@app.route("/")
def home():
    return redirect(url_for('index'))


@app.route('/tmp/<path:filename>')
def serve_tmp_file(filename):
    """Serve an uploaded image or Grad-CAM render from UPLOAD_DIR."""
    return send_from_directory(UPLOAD_DIR, filename)


@app.route('/testimages/<path:filename>')
def serve_test_image(filename):
    """Serve a bundled example image from TEST_IMAGES_DIR."""
    return send_from_directory(TEST_IMAGES_DIR, filename)


@app.route('/example_images')
def example_images():
    """Return ``{"images": [...]}`` with URLs of the bundled example images.

    When the ``model`` query parameter names a model with known filename
    patterns, only files whose lower-cased name contains one of them are listed.
    Any error is logged and yields an empty list.
    """
    try:
        files = []
        selected_model = (request.args.get('model') or '').strip()
        patterns = MODEL_EXAMPLE_PATTERNS.get(selected_model, []) if selected_model else []
        if os.path.isdir(TEST_IMAGES_DIR):
            # Sorted so the UI shows a stable order across requests/filesystems.
            for f in sorted(os.listdir(TEST_IMAGES_DIR)):
                lf = f.lower()
                if not lf.endswith(('.png', '.jpg', '.jpeg')):
                    continue
                if patterns and not any(p in lf for p in patterns):
                    continue
                files.append(url_for('serve_test_image', filename=f))
        return jsonify({"images": files})
    except Exception as e:
        logger.error(f"example_images error: {e}")
        return jsonify({"images": []})


@app.route('/login', methods=['GET', 'POST'])
def login():
    return redirect(url_for('index'))


@app.route('/signup', methods=['GET', 'POST'])
def signup():
    return redirect(url_for('index'))


@app.route('/index')
def index():
    return render_template('index.html')


@app.route('/logout')
def logout():
    return redirect(url_for('index'))


def _postprocess_binary_prediction(raw):
    """Return the positive-class probability from a single-unit model output.

    The first element of *raw* is used. Values outside [0, 1] are treated as
    logits and passed through a (numerically stable) sigmoid.
    """
    prob = float(np.asarray(raw, dtype=np.float32).reshape(-1)[0])
    if prob < 0.0 or prob > 1.0:
        if prob >= 0.0:
            prob = float(1.0 / (1.0 + np.exp(-prob)))
        else:
            # Equivalent form that avoids overflow in exp(-prob) for large negative logits.
            e = float(np.exp(prob))
            prob = e / (1.0 + e)
    return min(max(prob, 0.0), 1.0)


_UNSAFE_FILENAME_CHARS = re.compile(r"[^A-Za-z0-9._-]+")


def _safe_upload_name(original_name):
    """Reduce a client-supplied filename to a short basename of safe characters.

    Directory components (``/`` or ``\\`` separated) are dropped and other
    characters replaced with ``_``, so the name cannot escape UPLOAD_DIR. The
    tail is kept when truncating so the file extension survives.
    """
    base = os.path.basename((original_name or "").replace("\\", "/"))
    cleaned = _UNSAFE_FILENAME_CHARS.sub("_", base)[-100:]
    return cleaned or "upload"


@app.route("/predict", methods=["POST"])
def predict():
    """Classify an uploaded image and return prediction and image URLs as JSON.

    Form fields: ``file`` (the image) and ``model`` (a loaded model name).
    Returns 400 for missing or invalid input, and 500 with the error message if
    inference fails. A Grad-CAM failure is only logged; the response then
    carries ``"gradcam_image": null``.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file part"}), 400
    file = request.files["file"]
    model_name = request.form.get("model")
    if not file or file.filename == "":
        return jsonify({"error": "No selected file"}), 400
    if model_name not in models:
        return jsonify({"error": "Invalid model selected"}), 400

    try:
        # The client-supplied name is untrusted: sanitise it and prefix a UUID
        # so concurrent uploads never collide.
        filename = f"{uuid.uuid4()}_{_safe_upload_name(file.filename)}"
        os.makedirs(UPLOAD_DIR, exist_ok=True)
        filepath = os.path.join(UPLOAD_DIR, filename)
        file.save(filepath)

        model_config = MODEL_CONFIG[model_name]
        model = models[model_name]
        labels = model_config["labels"]
        input_size = model_config.get("input_size", (224, 224))

        img_array = preprocess_image(filepath, target_size=input_size)
        prediction = np.array(model.predict(img_array, verbose=0))

        if len(labels) == 2 and prediction.ndim >= 1 and prediction.shape[-1] == 1 and prediction.size >= 1:
            # Single-unit head: its value is read as the probability of labels[1].
            prob_pos = _postprocess_binary_prediction(prediction)
            if prob_pos >= 0.5:
                predicted_index, confidence = 1, prob_pos
            else:
                predicted_index, confidence = 0, 1.0 - prob_pos
            predicted_label = labels[predicted_index]
        else:
            vec = prediction[0] if prediction.ndim == 2 else prediction.reshape(-1)
            # Apply softmax only when the output does not already look like a
            # probability distribution (values outside [0, 1] or sum != 1).
            if np.any(vec < 0) or np.any(vec > 1) or not np.isclose(np.sum(vec), 1.0, atol=1e-3):
                exps = np.exp(vec - np.max(vec))
                probs = exps / (np.sum(exps) + 1e-8)
            else:
                probs = vec
            predicted_index = int(np.argmax(probs))
            predicted_label = labels[predicted_index]
            confidence = float(np.max(probs))

        gradcam_url = None
        try:
            last_conv_layer_name = model_config.get('last_conv_layer') or ""
            heatmap = get_gradcam_heatmap(model, img_array, last_conv_layer_name, pred_index=predicted_index)
            gradcam_filename = f"gradcam_{filename}"
            gradcam_filepath = os.path.join(UPLOAD_DIR, gradcam_filename)
            save_gradcam_image(filepath, heatmap, gradcam_filepath)
            gradcam_url = url_for('serve_tmp_file', filename=gradcam_filename)
        except Exception as e:
            logger.error(f"Grad-CAM error: {e}")

        return jsonify({
            "original_image": url_for('serve_tmp_file', filename=filename),
            "gradcam_image": gradcam_url,
            "prediction": str(predicted_label),
            "confidence": float(confidence),
            "model_used": str(model_name)
        })
    except Exception as e:
        logger.exception("Prediction error")
        return jsonify({"error": str(e)}), 500


@app.route("/chat", methods=["POST"])
def chat():
    """Answer a user question about a prediction via Gemini.

    Expects JSON ``{"message": str, "context": {"model_used", "prediction",
    "confidence"}}``. Missing or malformed fields fall back to placeholders.
    Returns ``{"response": text}``, or ``{"error": ...}`` with status 500 when
    Gemini is not configured or the call fails.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    user_message = data.get("message", "")
    prediction_context = data.get("context")
    # Guard against a non-object "context" (e.g. a string), which would
    # otherwise raise AttributeError on .get() outside any error handling.
    if not isinstance(prediction_context, dict):
        prediction_context = {}

    model_used = prediction_context.get('model_used', 'Unknown Model')
    pred_label = prediction_context.get('prediction', 'Unknown')
    conf = prediction_context.get('confidence', 0.0)
    try:
        conf_pct = float(conf) * 100.0
    except Exception:
        conf_pct = 0.0

    prompt = f"""
You are a helpful medical assistant chatbot.
A medical image was analyzed with the following results:
- Model Used: {model_used}
- Prediction: {pred_label}
- Confidence Score: {conf_pct:.2f}%

The user's question is: "{user_message}"

Based on this context, provide a helpful and informative response. Do not provide a diagnosis. Advise the user to consult a medical professional.
"""
    try:
        if gemini_model is None:
            return jsonify({"error": "Gemini API not configured. Set GEMINI_API_KEY in environment."}), 500
        response = gemini_model.generate_content(prompt)
        text = getattr(response, "text", None)
        if not text:
            text = str(response)
        return jsonify({"response": text})
    except Exception as e:
        logger.exception("Chat error")
        return jsonify({"error": str(e)}), 500


if __name__ == "__main__":
    # The Werkzeug debugger allows arbitrary code execution, so it is opt-in
    # via FLASK_DEBUG=1 rather than always on.
    app.run(debug=os.getenv("FLASK_DEBUG", "0").strip().lower() in ("1", "true", "yes"))