import torch
from huggingface_hub import hf_hub_download


class MedSAM2Model(torch.nn.Module):
    """Wrap a MedSAM2 checkpoint downloaded from the Hugging Face Hub.

    The checkpoint file is expected to contain a fully pickled
    ``torch.nn.Module`` (not just a ``state_dict``). It is loaded onto the CPU
    and stored as the ``model`` submodule. ``forward`` delegates to it
    unchanged.
    """

    def __init__(self, model_filename="MedSAM2_latest.pt", repo_id="wanglab/MedSAM2"):
        """Download the checkpoint and load it as the wrapped submodule.

        Args:
            model_filename: Name of the checkpoint file inside the Hub repo.
            repo_id: Hugging Face Hub repository to download from.

        Raises:
            TypeError: If the checkpoint does not deserialize to a
                ``torch.nn.Module``.
        """
        super().__init__()
        # Download (or reuse the locally cached copy of) the checkpoint.
        model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)

        # Loading a whole pickled module requires disabling weights-only mode,
        # which became torch.load's default in PyTorch 2.6.
        # SECURITY: weights_only=False runs arbitrary unpickling code from the
        # downloaded file; only point this at repositories you trust.
        loaded = torch.load(model_path, map_location="cpu", weights_only=False)

        # Fail early with a clear message instead of a confusing
        # "'dict' object is not callable" later in forward(). A non-Module
        # would also not be registered as a submodule, so .to()/.eval()
        # would silently not reach it.
        if not isinstance(loaded, torch.nn.Module):
            raise TypeError(
                f"Expected checkpoint {model_path!r} to contain a pickled "
                f"torch.nn.Module, got {type(loaded).__name__}. It may be a "
                "state_dict that must be loaded into a constructed model "
                "instead."
            )
        self.model = loaded

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its output unchanged.

        The expected format of ``x`` is defined by the loaded checkpoint,
        not by this wrapper.
        """
        return self.model(x)