# derm-foundation / app.py
# Uploaded by: taziksh
# Commit 9c5ec50 (verified): "Upload folder using huggingface_hub"
#!/usr/bin/env python
import ast
import collections
import io
import os
import sys

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, regularizers
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.model_selection import train_test_split
from google.cloud import storage
from huggingface_hub import hf_hub_download, notebook_login, login
from PIL import Image
import gradio as gr
from dotenv import load_dotenv
# Load environment variables (such as HF_TOKEN) from a .env file, if present.
load_dotenv()

# Authenticate with the Hugging Face Hub at import time.
# HF_TOKEN from the environment wins; otherwise fall back to the token file.
hf_token = os.getenv('HF_TOKEN')
if not hf_token:
    token_path = os.path.expanduser('~/.huggingface/token')
    if os.path.exists(token_path):
        with open(token_path, encoding='utf-8') as token_file:
            hf_token = token_file.read().strip()
if not hf_token:
    # sys.exit is always available (the bare `exit` builtin is injected by the
    # `site` module and may be absent); given a string it prints it to stderr
    # and exits with status 1. An empty token file counts as "no token".
    sys.exit("Please set HF_TOKEN environment variable or store your token in ~/.huggingface/token")
login(token=hf_token)
# ======================
# CONSTANTS & CONFIGURATION
# ======================

# Location of the public SCIN dataset on Google Cloud Storage.
SCIN_GCP_PROJECT = 'dx-scin-public'
SCIN_GCS_BUCKET_NAME = 'dx-scin-public-data'
SCIN_GCS_CASES_CSV = 'dataset/scin_cases.csv'
SCIN_GCS_LABELS_CSV = 'dataset/scin_labels.csv'

# Hugging Face repository and file holding the precomputed embeddings.
SCIN_HF_MODEL_NAME = 'google/derm-foundation'
SCIN_HF_EMBEDDING_FILE = 'scin_dataset_precomputed_embeddings.npz'

# Label vocabulary of the classifier (10 conditions). The list order fixes
# the column order of the binarized targets and of the model's outputs.
CONDITIONS_TO_PREDICT = [
    'Eczema',
    'Allergic Contact Dermatitis',
    'Insect Bite',
    'Urticaria',
    'Psoriasis',
    'Folliculitis',
    'Irritant Contact Dermatitis',
    'Tinea',
    'Herpes Zoster',
    'Drug Rash',
]
# ======================
# HELPER FUNCTIONS FOR DATA LOADING
# ======================
def initialize_df_with_metadata(bucket, csv_path):
    """Download a CSV from a GCS bucket and parse it into a DataFrame.

    Args:
        bucket: a ``google.cloud.storage`` bucket (anything whose
            ``blob(path).download_as_bytes()`` returns the file contents).
        csv_path: object name of the CSV inside the bucket.

    Returns:
        The parsed DataFrame with its ``case_id`` column as strings, so
        numeric-looking IDs are not converted to numbers and join
        consistently with other CSVs.
    """
    # download_as_string() is a deprecated alias of download_as_bytes().
    csv_bytes = bucket.blob(csv_path).download_as_bytes()
    df = pd.read_csv(io.BytesIO(csv_bytes), dtype={'case_id': str})
    # Re-cast defensively so the join key is always str.
    df['case_id'] = df['case_id'].astype(str)
    return df


def augment_metadata_with_labels(df, bucket, csv_path):
    """Inner-join ``df`` with the labels CSV at ``csv_path`` on ``case_id``.

    Args:
        df: DataFrame with a string ``case_id`` column.
        bucket: bucket holding the labels CSV.
        csv_path: object name of the labels CSV inside the bucket.

    Returns:
        A new DataFrame containing only case IDs present in both inputs.
    """
    # Reuse the same download/parse path as the cases CSV.
    labels_df = initialize_df_with_metadata(bucket, csv_path)
    return pd.merge(df, labels_df, on='case_id')
def load_embeddings_from_file(repo_id, object_name, local_dir='./'):
    """Download an ``.npz`` archive from the Hugging Face Hub and load it.

    Args:
        repo_id: Hub repository id.
        object_name: file name of the ``.npz`` archive in that repository.
        local_dir: directory the file is downloaded into (defaults to the
            current directory, as before).

    Returns:
        dict mapping each array name in the archive to its loaded value.
    """
    file_path = hf_hub_download(repo_id=repo_id, filename=object_name, local_dir=local_dir)
    # SECURITY NOTE(review): allow_pickle=True lets np.load unpickle object
    # arrays, which can execute arbitrary code if the downloaded file is
    # tampered with. Kept for compatibility in case the archive stores object
    # arrays; confirm and drop it if every entry is a plain numeric array.
    # The NpzFile context manager closes the archive once all arrays are read.
    with np.load(file_path, allow_pickle=True) as npz_file:
        return {key: npz_file[key] for key in npz_file.files}
# ======================
# DATA PREPARATION FUNCTION
# ======================
def prepare_data(df, embeddings, minimum_confidence=0, conditions=None):
    """Build parallel lists of image embeddings and per-image label lists.

    The following rows are skipped:
    - rows whose ``dermatologist_gradable_for_skin_condition_1`` is not
      ``'DEFAULT_YES_IMAGE_QUALITY_SUFFICIENT'``;
    - rows whose label/confidence cells cannot be parsed as Python literals.

    Every image path of a kept row that has an entry in ``embeddings``
    becomes one sample. The sample is labelled with that row's labels that
    are in ``conditions`` and whose confidence is not below
    ``minimum_confidence`` (the list may be empty). Exclusion counts are
    printed.

    Args:
        df: merged cases/labels DataFrame, iterated with ``itertuples``.
        embeddings: mapping from image path to embedding vector.
        minimum_confidence: labels with a confidence below this value are
            dropped. Defaults to 0 (the previous hard-coded value).
        conditions: labels to keep; defaults to ``CONDITIONS_TO_PREDICT``.

    Returns:
        ``(X, y)`` where ``X[i]`` is an embedding and ``y[i]`` its label list.
    """
    if conditions is None:
        conditions = CONDITIONS_TO_PREDICT
    image_columns = ('image_1_path', 'image_2_path', 'image_3_path')

    X = []
    y = []
    poor_image_quality_counter = 0
    missing_embedding_counter = 0
    not_in_condition_counter = 0
    condition_confidence_low_counter = 0
    unparseable_labels_counter = 0

    for row in df.itertuples():
        # Check if the image is marked as having sufficient quality.
        if getattr(row, 'dermatologist_gradable_for_skin_condition_1', None) != 'DEFAULT_YES_IMAGE_QUALITY_SUFFICIENT':
            poor_image_quality_counter += 1
            continue

        # The label/confidence cells are presumably stringified lists (they
        # are zipped below). ast.literal_eval accepts literals only, so unlike
        # eval() a malformed or hostile CSV cell cannot execute code; missing
        # (non-string) cells raise and the row is skipped as before.
        try:
            labels = ast.literal_eval(getattr(row, 'dermatologist_skin_condition_on_label_name', '[]'))
            confidences = ast.literal_eval(getattr(row, 'dermatologist_skin_condition_confidence', '[]'))
        except (ValueError, SyntaxError, TypeError):
            unparseable_labels_counter += 1
            continue

        row_labels = []
        for label, conf in zip(labels, confidences):
            if label not in conditions:
                not_in_condition_counter += 1
            elif conf < minimum_confidence:
                condition_confidence_low_counter += 1
            else:
                row_labels.append(label)

        # Each usable image of the case becomes its own sample with the case labels.
        for column in image_columns:
            image_path = getattr(row, column, None)
            if image_path is None or pd.isna(image_path):
                continue
            if image_path not in embeddings:
                missing_embedding_counter += 1
                continue
            X.append(embeddings[image_path])
            y.append(row_labels)

    print(f'Poor image quality count: {poor_image_quality_counter}')
    print(f'Missing embedding count: {missing_embedding_counter}')
    print(f'Condition not in list count: {not_in_condition_counter}')
    print(f'Excluded due to low confidence count: {condition_confidence_low_counter}')
    print(f'Unparseable label rows count: {unparseable_labels_counter}')
    return X, y
# ======================
# MODEL BUILDING FUNCTION
# ======================
def build_model(input_dim, output_dim, weight_decay=1e-4):
    """Create and compile a small multi-label classification head.

    Architecture:
    - two ReLU dense layers of 256 and 128 units, each with L2 penalties of
      ``weight_decay`` on kernel and bias and followed by 10% dropout;
    - a sigmoid output layer with ``output_dim`` units.

    The model is compiled with binary cross-entropy and Adam (lr 1e-4).

    Args:
        input_dim: length of each input feature vector.
        output_dim: number of independent labels to score.
        weight_decay: L2 regularization factor for the hidden layers.

    Returns:
        The compiled ``tf.keras.Model``.
    """
    features = tf.keras.Input(shape=(input_dim,))
    x = features
    for units in (256, 128):
        x = layers.Dense(
            units,
            activation="relu",
            kernel_regularizer=regularizers.l2(weight_decay),
            bias_regularizer=regularizers.l2(weight_decay),
        )(x)
        x = layers.Dropout(0.1)(x)
    probabilities = layers.Dense(output_dim, activation="sigmoid")(x)

    classifier = tf.keras.Model(inputs=features, outputs=probabilities)
    classifier.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        loss="binary_crossentropy",
    )
    return classifier
# ======================
# MAIN FUNCTION & GRADIO INTERFACE
# ======================
def main():
    """Train (or load) the SCIN condition classifier and launch the Gradio explorer.

    Steps:
    1. Read the SCIN case and label CSVs from Google Cloud Storage.
    2. Load the precomputed embeddings from the Hugging Face Hub.
    3. Train a small multi-label classifier on them, or load a cached
       ``model.h5``.
    4. Serve a Gradio UI that shows a case's images, the model's per-image
       predictions and the dermatologist labels.
    """
    # Connect to the Google Cloud Storage bucket.
    storage_client = storage.Client(SCIN_GCP_PROJECT)
    bucket = storage_client.bucket(SCIN_GCS_BUCKET_NAME)

    # Load SCIN dataset CSVs, merge them and index by case ID.
    df_cases = initialize_df_with_metadata(bucket, SCIN_GCS_CASES_CSV)
    df_full = augment_metadata_with_labels(df_cases, bucket, SCIN_GCS_LABELS_CSV)
    df_full.set_index('case_id', inplace=True)

    # Load precomputed embeddings from Hugging Face.
    print("Loading embeddings...")
    embeddings = load_embeddings_from_file(SCIN_HF_MODEL_NAME, SCIN_HF_EMBEDDING_FILE)

    # Prepare the training data.
    print("Preparing training data...")
    X, y = prepare_data(df_full, embeddings)
    X = np.array(X)

    # Convert the list of label lists to a binary indicator matrix whose
    # columns follow CONDITIONS_TO_PREDICT.
    mlb = MultiLabelBinarizer(classes=CONDITIONS_TO_PREDICT)
    y_bin = mlb.fit_transform(y)

    # Split the data into train and test sets (fixed seed for reproducibility).
    X_train, X_test, y_train, y_test = train_test_split(X, y_bin, test_size=0.2, random_state=42)

    # If a saved model exists, load it; otherwise build, train and save one.
    model_file = "model.h5"
    if os.path.exists(model_file):
        print("Loading existing model from", model_file)
        model = tf.keras.models.load_model(model_file)
    else:
        # Build only when training (a loaded model would replace it anyway)
        # and size the input from the data instead of hard-coding 6144.
        model = build_model(input_dim=X.shape[-1], output_dim=len(CONDITIONS_TO_PREDICT))
        print("Training model... This may take a few minutes.")
        train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(32)
        test_ds = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(32)
        model.fit(train_ds, validation_data=test_ds, epochs=15)
        model.save(model_file)

    # Extract a list of case IDs for the dropdown.
    case_ids = list(df_full.index)

    def predict_case(case_id: str):
        """Return images, model predictions and dermatologist labels for a case.

        Returns a 4-tuple: (list of PIL images, predictions text,
        dermatologist conditions text, dermatologist confidences text).
        """
        if case_id not in df_full.index:
            return [], "Case ID not found!", "N/A", "N/A"
        row = df_full.loc[case_id]
        if isinstance(row, pd.DataFrame):
            # Duplicate case IDs make .loc return a frame; use the first match.
            row = row.iloc[0]
        image_paths = [row.get('image_1_path'), row.get('image_2_path'), row.get('image_3_path')]
        images, predictions_text = [], []

        # Get the dermatologist's labels; parse stringified literals for display.
        dermatologist_conditions = row.get('dermatologist_skin_condition_on_label_name', "N/A")
        dermatologist_confidence = row.get('dermatologist_skin_condition_confidence', "N/A")
        if isinstance(dermatologist_conditions, str):
            # literal_eval only accepts Python literals, so unlike eval() a
            # malformed or hostile CSV cell cannot execute code.
            try:
                dermatologist_conditions = ast.literal_eval(dermatologist_conditions)
                dermatologist_confidence = ast.literal_eval(dermatologist_confidence)
            except (ValueError, SyntaxError, TypeError):
                pass  # Not a valid literal: display the raw value instead.

        # Process images and generate one prediction per displayed image.
        for path in image_paths:
            if not (isinstance(path, str) and path in embeddings):
                continue
            try:
                img_bytes = bucket.blob(path).download_as_bytes()
                img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
            except Exception as err:  # Best effort: skip images that fail to load.
                print(f"Skipping image {path}: {err}")
                continue
            images.append(img)

            # Model prediction for this image's precomputed embedding.
            emb = np.expand_dims(embeddings[path], axis=0)
            pred = model.predict(emb, verbose=0)[0]
            pred_dict = {cond: round(float(prob), 3) for cond, prob in zip(mlb.classes_, pred)}
            predictions_text.append(str(pred_dict))

        # Format the output.
        predictions_text = "\n".join(predictions_text) if predictions_text else "No predictions available."
        dermatologist_conditions = str(dermatologist_conditions)
        dermatologist_confidence = str(dermatologist_confidence)
        return images, predictions_text, dermatologist_conditions, dermatologist_confidence

    # Create the Gradio interface with a dropdown.
    iface = gr.Interface(
        fn=predict_case,
        inputs=gr.Dropdown(choices=case_ids, label="Select a Case ID"),
        outputs=[
            gr.Gallery(label="Case Images"),
            gr.Textbox(label="Model's Predictions"),
            gr.Textbox(label="Dermatologist's Skin Conditions"),
            gr.Textbox(label="Dermatologist's Confidence Ratings"),
        ],
        title="Derm Foundation Skin Conditions Explorer",
        description="Select a Case ID from the dropdown to view images and predictions.",
    )
    iface.launch(share=True)
# Launch the app only when this file is executed as a script.
# Note: the Hugging Face login above still runs on import.
if __name__ == "__main__":
    main()