# Contrastive Learning BERT for Medical Text
This is a custom contrastive learning model trained with triplet loss on medical/clinical text data. It produces high-quality embeddings specifically optimized for medical text similarity.
## ⚠️ Important: Custom Model Architecture

This model uses a custom `ContrastiveModel` class that implements:

- Attention-masked mean pooling (not standard BERT pooling)
- A custom `encode()` method for getting embeddings
- Triplet loss training with margin=1.0
The embeddings from this model will be different from standard BERT models.
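For reference, the snippet below is a minimal sketch of what attention-masked mean pooling followed by L2 normalization looks like in PyTorch. It is not the exact `ContrastiveModel` implementation shipped with this repo (the function name `masked_mean_pool` is illustrative); prefer the built-in `encode()` method in practice.

```python
import torch
import torch.nn.functional as F

def masked_mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings over non-padding positions, then L2-normalize."""
    # Expand the attention mask to the hidden dimension: [batch, seq_len, 1]
    mask = attention_mask.unsqueeze(-1).float()
    summed = (last_hidden_state * mask).sum(dim=1)  # sum of real-token vectors
    counts = mask.sum(dim=1).clamp(min=1e-9)        # number of real tokens per example
    pooled = summed / counts                        # masked mean
    return F.normalize(pooled, p=2, dim=1)          # unit-length embeddings
```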
## Model Details

- **Base Model:** Simonlee711/Clinical_ModernBERT
- **Architecture:** Custom `ContrastiveModel` with triplet loss
- **Training Method:** Triplet loss with medical text triplets
- **Pooling:** Attention-masked mean pooling
- **Normalization:** L2 normalization
## Quick Start

```python
from transformers import AutoTokenizer, AutoModel
import torch

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("nikhil061307/contrastive-learning-bert-v2")
model = AutoModel.from_pretrained("nikhil061307/contrastive-learning-bert-v2", trust_remote_code=True)

def get_embeddings(texts, model, tokenizer):
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=256)
    # Use the model's custom encode method
    with torch.no_grad():
        embeddings = model.encode(**inputs)
    return embeddings

# Example usage
texts = [
    "patient has swelling in legs",
    "edema in lower extremities",
    "patient reports headache"
]
embeddings = get_embeddings(texts, model, tokenizer)
print(f"Embeddings shape: {embeddings.shape}")

# Calculate similarity
from torch.nn.functional import cosine_similarity
sim_matrix = cosine_similarity(embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=2)
print(f"Similarity between text 1 and 2: {sim_matrix[0,1]:.4f}")
```
## Training Data & Method

**Training Data:** Medical notes with triplet structure:

- **Anchor:** Clinical sentences/notes
- **Positive:** Relevant medical entities/symptoms
- **Negative:** Irrelevant medical entities/symptoms

**Method:**

- **Loss Function:** Triplet loss with margin=1.0 (see the sketch after this list)
- **Pooling:** Attention-masked mean pooling
- **Normalization:** L2 normalization
- **Dropout:** 0.15 during training
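The training script itself is not published here; the sketch below only illustrates the objective described above, using PyTorch's built-in `TripletMarginLoss` with the stated margin of 1.0 (variable and function names are illustrative).

```python
import torch.nn as nn

# Triplet objective with margin=1.0 and Euclidean distance, as described above.
triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)

def triplet_step(anchor_emb, positive_emb, negative_emb):
    # anchor_emb, positive_emb, negative_emb: [batch, hidden] embeddings from encode().
    # The loss is zero only when d(anchor, negative) exceeds d(anchor, positive)
    # by at least the margin; otherwise the triplet is penalized.
    return triplet_loss(anchor_emb, positive_emb, negative_emb)
```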
## Key Features

- ✅ **Preserves exact training behavior** - same pooling and normalization as used during training
- ✅ **Custom `encode()` method** - matches the inference path used at training time
- ✅ **Medical domain optimized** - trained on clinical text
- ✅ **Attention-masked pooling** - proper handling of padding tokens
- ✅ **L2-normalized embeddings** - ready for cosine similarity
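Because `encode()` returns unit-length vectors, a plain matrix product of a batch of embeddings is already the cosine-similarity matrix (reusing `get_embeddings` from the Quick Start above):

```python
embeddings = get_embeddings(
    ["patient has swelling in legs", "edema in lower extremities"],
    model, tokenizer,
)
# For L2-normalized vectors, dot product == cosine similarity
sim_matrix = embeddings @ embeddings.T
print(sim_matrix)
```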
## Performance Notes
This model was specifically trained to:
- Distinguish between relevant/irrelevant medical entities
- Produce semantically meaningful embeddings for clinical text
- Handle medical terminology and clinical language patterns
Expected similarity scores:
- Similar medical concepts: 0.8-0.95
- Related medical concepts: 0.6-0.8
- Unrelated concepts: 0.3-0.6
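As a rough illustration of how these ranges can be used, the hypothetical helper below flags two texts as related when their cosine similarity clears the 0.6 boundary between related and unrelated concepts listed above; treat the threshold as a starting point to tune on your own data (it reuses `get_embeddings`, `model`, and `tokenizer` from the Quick Start).

```python
import torch.nn.functional as F

def is_related(text_a, text_b, threshold=0.6):
    """Illustrative relatedness check based on the score ranges above."""
    emb_a = get_embeddings([text_a], model, tokenizer)
    emb_b = get_embeddings([text_b], model, tokenizer)
    score = F.cosine_similarity(emb_a, emb_b).item()
    return score >= threshold, score

related, score = is_related("patient has swelling in legs", "edema in lower extremities")
print(f"related={related}, score={score:.4f}")
```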
## Troubleshooting

If you get different results than expected:

- Make sure to use `trust_remote_code=True` when loading
- Use the model's `encode()` method, not standard transformer outputs
- This model uses custom mean pooling - don't apply your own pooling
## Usage Examples

### Single Text Embedding

```python
text = "patient shows signs of fatigue"
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
embedding = model.encode(**inputs)
```
### Batch Processing

```python
texts = ["text1", "text2", "text3"]
inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=256)
embeddings = model.encode(**inputs)  # Shape: [3, 768]
```
### Similarity Calculation

```python
import torch.nn.functional as F

text1 = "patient has swelling"
text2 = "edema present"

emb1 = get_embeddings([text1], model, tokenizer)
emb2 = get_embeddings([text2], model, tokenizer)

similarity = F.cosine_similarity(emb1, emb2)
print(f"Similarity: {similarity.item():.4f}")
```
## Citation

```bibtex
@misc{contrastive-clinical-bert,
  title={Contrastive Learning BERT for Medical Text},
  author={Your Name},
  year={2025},
  note={Custom contrastive learning model with triplet loss}
}
```