#!/usr/bin/env python
import ast
import io
import os
import sys

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, regularizers
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split
from google.cloud import storage
from huggingface_hub import hf_hub_download, login
from PIL import Image
import gradio as gr
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()

# Access and validate the Hugging Face token.
hf_token = os.getenv('HF_TOKEN')
if hf_token:
    login(token=hf_token)
else:
    # Fall back to a token stored in the default location.
    token_path = os.path.expanduser('~/.huggingface/token')
    if os.path.exists(token_path):
        with open(token_path) as f:
            login(token=f.read().strip())
    else:
        print("Please set HF_TOKEN environment variable or store your token in ~/.huggingface/token")
        sys.exit(1)
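# Note: a token can be created at https://huggingface.co/settings/tokens; read
# access should be enough here, since the script only downloads files from the Hub.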
# ======================
# CONSTANTS & CONFIGURATION
# ======================
SCIN_GCP_PROJECT = 'dx-scin-public'
SCIN_GCS_BUCKET_NAME = 'dx-scin-public-data'
SCIN_GCS_CASES_CSV = 'dataset/scin_cases.csv'
SCIN_GCS_LABELS_CSV = 'dataset/scin_labels.csv'
SCIN_HF_MODEL_NAME = 'google/derm-foundation'
SCIN_HF_EMBEDDING_FILE = 'scin_dataset_precomputed_embeddings.npz'
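# Derm Foundation produces one 6144-dimensional embedding per image; naming the
# size here (rather than hard-coding 6144 at the build_model call) keeps the
# model definition in sync with the embeddings.
EMBEDDING_DIM = 6144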
# The 10 conditions we want to predict
CONDITIONS_TO_PREDICT = [
    'Eczema',
    'Allergic Contact Dermatitis',
    'Insect Bite',
    'Urticaria',
    'Psoriasis',
    'Folliculitis',
    'Irritant Contact Dermatitis',
    'Tinea',
    'Herpes Zoster',
    'Drug Rash'
]
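# These strings must match the label names that appear in the
# dermatologist_skin_condition_on_label_name column of the SCIN labels CSV;
# labels outside this list are dropped in prepare_data() below.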
# ======================
# HELPER FUNCTIONS FOR DATA LOADING
# ======================
def initialize_df_with_metadata(bucket, csv_path):
    """Download the SCIN cases CSV from GCS and load it into a DataFrame."""
    csv_bytes = bucket.blob(csv_path).download_as_bytes()
    df = pd.read_csv(io.BytesIO(csv_bytes), dtype={'case_id': str})
    return df
def augment_metadata_with_labels(df, bucket, csv_path):
    """Download the SCIN labels CSV and merge it with the cases DataFrame on case_id."""
    csv_bytes = bucket.blob(csv_path).download_as_bytes()
    labels_df = pd.read_csv(io.BytesIO(csv_bytes), dtype={'case_id': str})
    merged_df = pd.merge(df, labels_df, on='case_id')
    return merged_df
def load_embeddings_from_file(repo_id, object_name):
    """Download the precomputed-embeddings .npz from the Hub and load it into a dict."""
    file_path = hf_hub_download(repo_id=repo_id, filename=object_name, local_dir='./')
    with np.load(file_path, allow_pickle=True) as npz_file:
        embeddings = {key: npz_file[key] for key in npz_file.files}
    return embeddings
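# The keys of this dict are GCS image paths -- the same strings found in the
# image_1_path/image_2_path/image_3_path columns of the cases CSV -- so a quick
# sanity check (hypothetical usage) would be:
#   embeddings = load_embeddings_from_file(SCIN_HF_MODEL_NAME, SCIN_HF_EMBEDDING_FILE)
#   first_key = next(iter(embeddings))
#   print(first_key, embeddings[first_key].shape)  # expect a (EMBEDDING_DIM,) vector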
# ======================
# DATA PREPARATION FUNCTION
# ======================
def prepare_data(df, embeddings):
    """Build parallel lists of image embeddings (X) and label lists (y) from the merged SCIN DataFrame."""
    MINIMUM_CONFIDENCE = 0  # Adjust this if needed.
    X = []
    y = []
    poor_image_quality_counter = 0
    missing_embedding_counter = 0
    not_in_condition_counter = 0
    condition_confidence_low_counter = 0
    for row in df.itertuples():
        # Skip cases whose image was not graded as having sufficient quality.
        if getattr(row, 'dermatologist_gradable_for_skin_condition_1', None) != 'DEFAULT_YES_IMAGE_QUALITY_SUFFICIENT':
            poor_image_quality_counter += 1
            continue
        # Parse the label and confidence lists, which are stored as Python-literal
        # strings in the CSV. ast.literal_eval is a safe replacement for eval here.
        try:
            labels = ast.literal_eval(getattr(row, 'dermatologist_skin_condition_on_label_name', '[]'))
            confidences = ast.literal_eval(getattr(row, 'dermatologist_skin_condition_confidence', '[]'))
        except (ValueError, SyntaxError):
            continue
        row_labels = []
        for label, conf in zip(labels, confidences):
            if label not in CONDITIONS_TO_PREDICT:
                not_in_condition_counter += 1
                continue
            if conf < MINIMUM_CONFIDENCE:
                condition_confidence_low_counter += 1
                continue
            row_labels.append(label)
        # For each image associated with this case, add its embedding and labels.
        for image_path in [getattr(row, 'image_1_path', None),
                           getattr(row, 'image_2_path', None),
                           getattr(row, 'image_3_path', None)]:
            if pd.isna(image_path) or image_path is None:
                continue
            if image_path not in embeddings:
                missing_embedding_counter += 1
                continue
            X.append(embeddings[image_path])
            y.append(row_labels)
    print(f'Poor image quality count: {poor_image_quality_counter}')
    print(f'Missing embedding count: {missing_embedding_counter}')
    print(f'Condition not in list count: {not_in_condition_counter}')
    print(f'Excluded due to low confidence count: {condition_confidence_low_counter}')
    return X, y
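# Note that X and y stay parallel: one embedding per image, paired with the
# case-level label list. A case whose labels were all filtered out still
# contributes rows with an empty label list, which MultiLabelBinarizer turns
# into an all-zeros target vector (a pure negative example).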
# ======================
# MODEL BUILDING FUNCTION
# ======================
def build_model(input_dim, output_dim, weight_decay=1e-4):
    """Build a small two-layer MLP classification head over the precomputed embeddings."""
    inputs = tf.keras.Input(shape=(input_dim,))
    hidden = layers.Dense(256, activation="relu",
                          kernel_regularizer=regularizers.l2(weight_decay),
                          bias_regularizer=regularizers.l2(weight_decay))(inputs)
    hidden = layers.Dropout(0.1)(hidden)
    hidden = layers.Dense(128, activation="relu",
                          kernel_regularizer=regularizers.l2(weight_decay),
                          bias_regularizer=regularizers.l2(weight_decay))(hidden)
    hidden = layers.Dropout(0.1)(hidden)
    output = layers.Dense(output_dim, activation="sigmoid")(hidden)
    model = tf.keras.Model(inputs, output)
    model.compile(loss="binary_crossentropy",
                  optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4))
    return model
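# Design note: sigmoid outputs plus binary cross-entropy treat each condition as
# an independent yes/no prediction, which is the standard setup for multi-label
# classification (a case can carry several conditions at once). A softmax head
# would instead force the 10 probabilities to compete and sum to 1.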
# ======================
# MAIN FUNCTION & GRADIO INTERFACE
# ======================
def main():
    # Connect to the Google Cloud Storage bucket.
    storage_client = storage.Client(project=SCIN_GCP_PROJECT)
    bucket = storage_client.bucket(SCIN_GCS_BUCKET_NAME)

    # Load SCIN dataset CSVs and merge them.
    df_cases = initialize_df_with_metadata(bucket, SCIN_GCS_CASES_CSV)
    df_full = augment_metadata_with_labels(df_cases, bucket, SCIN_GCS_LABELS_CSV)
    df_full.set_index('case_id', inplace=True)
    # Load precomputed embeddings from Hugging Face.
    print("Loading embeddings...")
    embeddings = load_embeddings_from_file(SCIN_HF_MODEL_NAME, SCIN_HF_EMBEDDING_FILE)

    # Prepare the training data.
    print("Preparing training data...")
    X, y = prepare_data(df_full, embeddings)
    X = np.array(X)

    # Convert the list of label lists to binary indicator arrays.
    mlb = MultiLabelBinarizer(classes=CONDITIONS_TO_PREDICT)
    y_bin = mlb.fit_transform(y)
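    # y_bin has shape (n_samples, len(CONDITIONS_TO_PREDICT)): row i holds a 0/1
    # flag per condition for sample i, in the column order fixed by the
    # `classes` argument above.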
    # Split the data into train and test sets.
    X_train, X_test, y_train, y_test = train_test_split(X, y_bin, test_size=0.2, random_state=42)
    # If a saved model exists, load it; otherwise, build, train, and save one.
    model_file = "model.h5"
    if os.path.exists(model_file):
        print("Loading existing model from", model_file)
        model = tf.keras.models.load_model(model_file)
    else:
        print("Training model... This may take a few minutes.")
        model = build_model(input_dim=EMBEDDING_DIM, output_dim=len(CONDITIONS_TO_PREDICT))
        train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(32)
        test_ds = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(32)
        model.fit(train_ds, validation_data=test_ds, epochs=15)
        model.save(model_file)
    # Extract the list of case IDs for the dropdown.
    case_ids = list(df_full.index)
    def predict_case(case_id: str):
        """Fetch images, model predictions, and dermatologist labels for a given case ID."""
        if case_id not in df_full.index:
            return [], "Case ID not found!", "N/A", "N/A"
        row = df_full.loc[case_id]
        image_paths = [row.get('image_1_path'), row.get('image_2_path'), row.get('image_3_path')]
        images, predictions_text = [], []

        # Get the dermatologist's labels, stored as Python-literal strings.
        dermatologist_conditions = row.get('dermatologist_skin_condition_on_label_name', "N/A")
        dermatologist_confidence = row.get('dermatologist_skin_condition_confidence', "N/A")
        if isinstance(dermatologist_conditions, str):
            try:
                dermatologist_conditions = ast.literal_eval(dermatologist_conditions)
                dermatologist_confidence = ast.literal_eval(dermatologist_confidence)
            except (ValueError, SyntaxError):
                pass
        # Process images & generate predictions.
        for path in image_paths:
            if isinstance(path, str) and (path in embeddings):
                try:
                    img_bytes = bucket.blob(path).download_as_bytes()
                    img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                    images.append(img)
                except Exception:
                    continue
                # Model prediction from the precomputed embedding.
                emb = np.expand_dims(embeddings[path], axis=0)
                pred = model.predict(emb, verbose=0)[0]
                pred_dict = {cond: round(float(prob), 3) for cond, prob in zip(mlb.classes_, pred)}
                predictions_text.append(str(pred_dict))

        # Format the output.
        predictions_text = "\n".join(predictions_text) if predictions_text else "No predictions available."
        dermatologist_conditions = str(dermatologist_conditions)
        dermatologist_confidence = str(dermatologist_confidence)
        return images, predictions_text, dermatologist_conditions, dermatologist_confidence
    # Create the Gradio interface with a dropdown.
    iface = gr.Interface(
        fn=predict_case,
        inputs=gr.Dropdown(choices=case_ids, label="Select a Case ID"),
        outputs=[
            gr.Gallery(label="Case Images"),
            gr.Textbox(label="Model's Predictions"),
            gr.Textbox(label="Dermatologist's Skin Conditions"),
            gr.Textbox(label="Dermatologist's Confidence Ratings")
        ],
        title="Derm Foundation Skin Conditions Explorer",
        description="Select a Case ID from the dropdown to view images and predictions."
    )
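    # Note: share=True asks Gradio for a temporary public link, which matters
    # mainly for local runs; when hosted on Hugging Face Spaces the app is
    # already served publicly and the flag is effectively ignored.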
    iface.launch(share=True)
if __name__ == "__main__":
    main()
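# To run locally (hypothetical invocation): put HF_TOKEN=<your token> in a .env
# file next to this script, then run `python app.py`; Gradio prints a local URL
# and, with share=True, a temporary public one.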