#!/usr/bin/env python
import os
import io
import ast
import sys

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, regularizers
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split
from google.cloud import storage
from huggingface_hub import hf_hub_download, login
from PIL import Image
import gradio as gr
from dotenv import load_dotenv

# Load environment variables from a .env file, if present.
load_dotenv()

# Access and validate the Hugging Face token.
hf_token = os.getenv('HF_TOKEN')
if hf_token:
    login(token=hf_token)
else:
    # Fall back to a token stored in the default location.
    token_path = os.path.expanduser('~/.huggingface/token')
    if os.path.exists(token_path):
        with open(token_path) as f:
            login(token=f.read().strip())
    else:
        print("Please set the HF_TOKEN environment variable or store your "
              "token in ~/.huggingface/token")
        sys.exit(1)

# ======================
# CONSTANTS & CONFIGURATION
# ======================
SCIN_GCP_PROJECT = 'dx-scin-public'
SCIN_GCS_BUCKET_NAME = 'dx-scin-public-data'
SCIN_GCS_CASES_CSV = 'dataset/scin_cases.csv'
SCIN_GCS_LABELS_CSV = 'dataset/scin_labels.csv'
SCIN_HF_MODEL_NAME = 'google/derm-foundation'
SCIN_HF_EMBEDDING_FILE = 'scin_dataset_precomputed_embeddings.npz'

# The 10 conditions we want to predict.
CONDITIONS_TO_PREDICT = [
    'Eczema',
    'Allergic Contact Dermatitis',
    'Insect Bite',
    'Urticaria',
    'Psoriasis',
    'Folliculitis',
    'Irritant Contact Dermatitis',
    'Tinea',
    'Herpes Zoster',
    'Drug Rash',
]

# ======================
# HELPER FUNCTIONS FOR DATA LOADING
# ======================
def initialize_df_with_metadata(bucket, csv_path):
    """Download a CSV from the bucket and load it with case_id as a string."""
    csv_bytes = bucket.blob(csv_path).download_as_string()
    df = pd.read_csv(io.BytesIO(csv_bytes), dtype={'case_id': str})
    df['case_id'] = df['case_id'].astype(str)
    return df

def augment_metadata_with_labels(df, bucket, csv_path):
    """Merge the dermatologist-labels CSV into the case metadata on case_id."""
    csv_bytes = bucket.blob(csv_path).download_as_string()
    labels_df = pd.read_csv(io.BytesIO(csv_bytes), dtype={'case_id': str})
    labels_df['case_id'] = labels_df['case_id'].astype(str)
    return pd.merge(df, labels_df, on='case_id')

def load_embeddings_from_file(repo_id, object_name):
    """Download the precomputed-embeddings .npz from the Hub and unpack it."""
    file_path = hf_hub_download(repo_id=repo_id, filename=object_name,
                                local_dir='./')
    embeddings = {}
    with open(file_path, 'rb') as f:
        npz_file = np.load(f, allow_pickle=True)
        for key, value in npz_file.items():
            embeddings[key] = value
    return embeddings

# ======================
# DATA PREPARATION FUNCTION
# ======================
def prepare_data(df, embeddings):
    MINIMUM_CONFIDENCE = 0  # Raise this to exclude low-confidence labels.
    X = []
    y = []
    poor_image_quality_counter = 0
    missing_embedding_counter = 0
    not_in_condition_counter = 0
    condition_confidence_low_counter = 0
    for row in df.itertuples():
        # Skip images not marked as having sufficient quality.
        if getattr(row, 'dermatologist_gradable_for_skin_condition_1',
                   None) != 'DEFAULT_YES_IMAGE_QUALITY_SUFFICIENT':
            poor_image_quality_counter += 1
            continue
        # Parse the stringified label and confidence lists.
        # ast.literal_eval is safer than eval() for untrusted CSV content.
        try:
            labels = ast.literal_eval(
                getattr(row, 'dermatologist_skin_condition_on_label_name', '[]'))
            confidences = ast.literal_eval(
                getattr(row, 'dermatologist_skin_condition_confidence', '[]'))
        except (ValueError, SyntaxError):
            continue
        row_labels = []
        for label, conf in zip(labels, confidences):
            if label not in CONDITIONS_TO_PREDICT:
                not_in_condition_counter += 1
                continue
            if conf < MINIMUM_CONFIDENCE:
                condition_confidence_low_counter += 1
                continue
            row_labels.append(label)
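        # Note: a case whose retained label list ends up empty still
        # contributes an all-zero target vector below, i.e. it acts as a
        # negative example for all ten conditions. To drop such cases instead
        # (an optional tweak, not part of the original pipeline), uncomment:
        # if not row_labels:
        #     continue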
        # For each image associated with this case, add its embedding and
        # the case's labels.
        for image_path in [getattr(row, 'image_1_path', None),
                           getattr(row, 'image_2_path', None),
                           getattr(row, 'image_3_path', None)]:
            if image_path is None or pd.isna(image_path):
                continue
            if image_path not in embeddings:
                missing_embedding_counter += 1
                continue
            X.append(embeddings[image_path])
            y.append(row_labels)
    print(f'Poor image quality count: {poor_image_quality_counter}')
    print(f'Missing embedding count: {missing_embedding_counter}')
    print(f'Condition not in list count: {not_in_condition_counter}')
    print(f'Excluded due to low confidence count: {condition_confidence_low_counter}')
    return X, y

# ======================
# MODEL BUILDING FUNCTION
# ======================
def build_model(input_dim, output_dim, weight_decay=1e-4):
    """Two-layer MLP head with sigmoid outputs for multi-label prediction."""
    inputs = tf.keras.Input(shape=(input_dim,))
    hidden = layers.Dense(256, activation="relu",
                          kernel_regularizer=regularizers.l2(weight_decay),
                          bias_regularizer=regularizers.l2(weight_decay))(inputs)
    hidden = layers.Dropout(0.1)(hidden)
    hidden = layers.Dense(128, activation="relu",
                          kernel_regularizer=regularizers.l2(weight_decay),
                          bias_regularizer=regularizers.l2(weight_decay))(hidden)
    hidden = layers.Dropout(0.1)(hidden)
    output = layers.Dense(output_dim, activation="sigmoid")(hidden)
    model = tf.keras.Model(inputs, output)
    model.compile(loss="binary_crossentropy",
                  optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4))
    return model

# ======================
# MAIN FUNCTION & GRADIO INTERFACE
# ======================
def main():
    # Connect to the Google Cloud Storage bucket.
    storage_client = storage.Client(project=SCIN_GCP_PROJECT)
    bucket = storage_client.bucket(SCIN_GCS_BUCKET_NAME)

    # Load the SCIN dataset CSVs and merge them.
    df_cases = initialize_df_with_metadata(bucket, SCIN_GCS_CASES_CSV)
    df_full = augment_metadata_with_labels(df_cases, bucket, SCIN_GCS_LABELS_CSV)
    df_full.set_index('case_id', inplace=True)

    # Load precomputed embeddings from Hugging Face.
    print("Loading embeddings...")
    embeddings = load_embeddings_from_file(SCIN_HF_MODEL_NAME,
                                           SCIN_HF_EMBEDDING_FILE)

    # Prepare the training data.
    print("Preparing training data...")
    X, y = prepare_data(df_full, embeddings)
    X = np.array(X)

    # Convert the list of label lists to binary indicator arrays.
    mlb = MultiLabelBinarizer(classes=CONDITIONS_TO_PREDICT)
    y_bin = mlb.fit_transform(y)

    # Split the data into train and test sets.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y_bin, test_size=0.2, random_state=42)

    # Build the model (Derm Foundation embeddings are 6144-dimensional).
    model = build_model(input_dim=6144, output_dim=len(CONDITIONS_TO_PREDICT))

    # If a saved model exists, load it; otherwise, train and save it.
    model_file = "model.h5"
    if os.path.exists(model_file):
        print("Loading existing model from", model_file)
        model = tf.keras.models.load_model(model_file)
    else:
        print("Training model... This may take a few minutes.")
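        # A minimal sketch (an optional addition, not in the original script):
        # early stopping on validation loss can guard against overfitting on a
        # dataset this small. To use it, uncomment below and pass
        # callbacks=callbacks to model.fit:
        # callbacks = [tf.keras.callbacks.EarlyStopping(
        #     monitor="val_loss", patience=3, restore_best_weights=True)]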
        train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(32)
        test_ds = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(32)
        model.fit(train_ds, validation_data=test_ds, epochs=15)
        model.save(model_file)

    # Extract the list of case IDs for the dropdown.
    case_ids = list(df_full.index)

    def predict_case(case_id: str):
        """Fetch images, model predictions, and dermatologist labels for a case."""
        if case_id not in df_full.index:
            return [], "Case ID not found!", "N/A", "N/A"
        row = df_full.loc[case_id]
        image_paths = [row.get('image_1_path'), row.get('image_2_path'),
                       row.get('image_3_path')]
        images, predictions_text = [], []

        # Get the dermatologist's labels and confidence ratings.
        dermatologist_conditions = row.get(
            'dermatologist_skin_condition_on_label_name', "N/A")
        dermatologist_confidence = row.get(
            'dermatologist_skin_condition_confidence', "N/A")
        if isinstance(dermatologist_conditions, str):
            try:
                dermatologist_conditions = ast.literal_eval(dermatologist_conditions)
                dermatologist_confidence = ast.literal_eval(dermatologist_confidence)
            except (ValueError, SyntaxError):
                pass

        # Download each image and generate a prediction from its embedding.
        for path in image_paths:
            if isinstance(path, str) and path in embeddings:
                try:
                    img_bytes = bucket.blob(path).download_as_string()
                    img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                    images.append(img)
                except Exception:
                    continue
                # Model prediction from the precomputed embedding.
                emb = np.expand_dims(embeddings[path], axis=0)
                pred = model.predict(emb)[0]
                pred_dict = {cond: round(float(prob), 3)
                             for cond, prob in zip(mlb.classes_, pred)}
                predictions_text.append(str(pred_dict))

        # Format the output.
        predictions_text = ("\n".join(predictions_text)
                            if predictions_text else "No predictions available.")
        return (images, predictions_text,
                str(dermatologist_conditions), str(dermatologist_confidence))

    # Create the Gradio interface with a dropdown of case IDs.
    iface = gr.Interface(
        fn=predict_case,
        inputs=gr.Dropdown(choices=case_ids, label="Select a Case ID"),
        outputs=[
            gr.Gallery(label="Case Images"),
            gr.Textbox(label="Model's Predictions"),
            gr.Textbox(label="Dermatologist's Skin Conditions"),
            gr.Textbox(label="Dermatologist's Confidence Ratings"),
        ],
        title="Derm Foundation Skin Conditions Explorer",
        description="Select a Case ID from the dropdown to view images and predictions.",
    )
    iface.launch(share=True)

if __name__ == "__main__":
    main()
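# ---------------------------------------------------------------------------
# Example usage (a sketch, assuming the script has already been run once so
# that "model.h5" exists). The 6144-dim vector below is random stand-in data,
# not a real Derm Foundation embedding, so the scores it produces are
# meaningless placeholders; it only shows the scoring pattern.
# ---------------------------------------------------------------------------
# import numpy as np
# import tensorflow as tf
# model = tf.keras.models.load_model("model.h5")
# fake_embedding = np.random.rand(1, 6144).astype("float32")
# probs = model.predict(fake_embedding)[0]
# for name, p in zip(CONDITIONS_TO_PREDICT, probs):
#     print(f"{name}: {p:.3f}")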