Contrastive Learning BERT for Medical Text

This is a custom contrastive-learning model trained with triplet loss on medical/clinical text. It produces embeddings optimized specifically for medical text similarity.

⚠️ Important: Custom Model Architecture

This model uses a custom ContrastiveModel class that implements:

  • Attention-masked mean pooling (not standard BERT pooling)
  • Custom encode() method for getting embeddings
  • Triplet loss training with margin=1.0

The embeddings from this model therefore differ from those of standard BERT models.
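
For intuition, attention-masked mean pooling followed by L2 normalization is typically implemented like the sketch below. This is a generic illustration of the technique, not the model's bundled code; last_hidden_state and attention_mask are the usual transformer tensors.

import torch
import torch.nn.functional as F

def masked_mean_pool(last_hidden_state, attention_mask):
    # Expand the mask to the hidden dimension: [batch, seq_len, 1]
    mask = attention_mask.unsqueeze(-1).float()
    # Sum token embeddings, zeroing out padding positions
    summed = (last_hidden_state * mask).sum(dim=1)
    # Divide by the number of real (non-padding) tokens per sequence
    counts = mask.sum(dim=1).clamp(min=1e-9)
    pooled = summed / counts
    # L2-normalize so dot products equal cosine similarities
    return F.normalize(pooled, p=2, dim=1)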

Model Details

  • Base Model: Simonlee711/Clinical_ModernBERT
  • Architecture: Custom ContrastiveModel with triplet loss
  • Training Method: Triplet loss with medical text triplets
  • Pooling: Attention-masked mean pooling
  • Normalization: L2 normalization

Quick Start

from transformers import AutoTokenizer, AutoModel
import torch

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("nikhil061307/contrastive-learning-bert-v2")
model = AutoModel.from_pretrained("nikhil061307/contrastive-learning-bert-v2", trust_remote_code=True)

def get_embeddings(texts, model, tokenizer):
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=256)
    
    # Use the model's custom encode method
    with torch.no_grad():
        embeddings = model.encode(**inputs)
    return embeddings

# Example usage
texts = [
    "patient has swelling in legs",
    "edema in lower extremities", 
    "patient reports headache"
]

embeddings = get_embeddings(texts, model, tokenizer)
print(f"Embeddings shape: {embeddings.shape}")

# Calculate similarity
from torch.nn.functional import cosine_similarity
sim_matrix = cosine_similarity(embeddings.unsqueeze(1), embeddings.unsqueeze(0), dim=2)
print(f"Similarity between text 1 and 2: {sim_matrix[0,1]:.4f}")

Training Data & Method

  • Training Data: Medical notes with triplet structure:
    • Anchor: Clinical sentences/notes
    • Positive: Relevant medical entities/symptoms
    • Negative: Irrelevant medical entities/symptoms
  • Loss Function: Triplet loss with margin=1.0 (sketched below)
  • Pooling: Attention-masked mean pooling
  • Normalization: L2 normalization
  • Dropout: 0.15 during training
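
The actual training script is not included with the model; as a rough sketch of the objective using standard PyTorch (nn.TripletMarginLoss with margin=1.0, which uses Euclidean distance by default), a single step could look like:

import torch.nn as nn

# Hypothetical triplet batch illustrating the anchor/positive/negative structure
anchor_texts   = ["patient has swelling in legs"]
positive_texts = ["edema in lower extremities"]
negative_texts = ["patient reports headache"]

triplet_loss = nn.TripletMarginLoss(margin=1.0)

def encode_batch(texts):
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt", max_length=256)
    return model.encode(**inputs)

# Pulls anchors toward positives and away from negatives until the
# gap exceeds the margin (assumes encode() tracks gradients in training mode)
loss = triplet_loss(encode_batch(anchor_texts),
                    encode_batch(positive_texts),
                    encode_batch(negative_texts))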

Key Features

✅ Preserves exact training behavior - same pooling and normalization as training
✅ Custom encode() method - matches the inference used during training
✅ Medical domain optimized - trained on clinical text
✅ Attention-masked pooling - proper handling of padding tokens
✅ L2-normalized embeddings - ready for cosine similarity

Performance Notes

This model was specifically trained to:

  • Distinguish between relevant/irrelevant medical entities
  • Produce semantically meaningful embeddings for clinical text
  • Handle medical terminology and clinical language patterns

Expected similarity scores:

  • Similar medical concepts: 0.8-0.95
  • Related medical concepts: 0.6-0.8
  • Unrelated concepts: 0.3-0.6
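
As a quick sanity check against these ranges, the Quick Start sentences should land roughly where the list above suggests (exact values vary by hardware and library versions):

# Uses sim_matrix computed in the Quick Start example
print(f"swelling vs. edema (similar):      {sim_matrix[0, 1]:.4f}")  # expect roughly 0.8-0.95
print(f"swelling vs. headache (unrelated): {sim_matrix[0, 2]:.4f}")  # expect roughly 0.3-0.6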

Troubleshooting

If you get different results than expected:

  1. Make sure to use trust_remote_code=True when loading
  2. Use the model's encode() method, not the standard transformer outputs (see the sketch below)
  3. This model uses custom mean pooling - don't apply your own pooling
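
The difference in points 2 and 3 looks like this in practice (a sketch; the exact return type of the standard forward pass depends on the model's remote code):

inputs = tokenizer(["patient has swelling in legs"], return_tensors="pt", padding=True)

# Not what you want: a standard forward pass returns raw (unpooled) hidden states
outputs = model(**inputs)

# What you want: encode() applies the trained masked mean pooling and L2 normalization
with torch.no_grad():
    embeddings = model.encode(**inputs)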

Usage Examples

Single Text Embedding

text = "patient shows signs of fatigue"
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
embedding = model.encode(**inputs)

Batch Processing

texts = ["text1", "text2", "text3"]
inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=256)
embeddings = model.encode(**inputs)  # Shape: [3, 768]

Similarity Calculation

import torch.nn.functional as F

text1 = "patient has swelling"
text2 = "edema present"

emb1 = get_embeddings([text1], model, tokenizer)
emb2 = get_embeddings([text2], model, tokenizer)

similarity = F.cosine_similarity(emb1, emb2)
print(f"Similarity: {similarity.item():.4f}")

Citation

@misc{contrastive-clinical-bert,
  title={Contrastive Learning BERT for Medical Text},
  author={Your Name},
  year={2025},
  note={Custom contrastive learning model with triplet loss}
}