diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a717d468647f839ec11347da18ca2611b2ea889
--- /dev/null
+++ b/app.py
@@ -0,0 +1,124 @@
+import gradio as gr
+from gradio_image_prompter import ImagePrompter
+from detectron2.config import LazyConfig, instantiate
+from detectron2.checkpoint import DetectionCheckpointer
+import cv2
+import numpy as np
+import torch
+from huggingface_hub import hf_hub_download
+
+DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+model_choice = {
+    'SAM': None, 
+    'HQ-SAM': None, 
+    'SAM2': None
+}
+
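+# Download the released SEMat checkpoints from the Hugging Face Hub (cached locally after the first run).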
+for model_type in model_choice.keys():
+    model_choice[model_type] = hf_hub_download(repo_id="XiaRho/SEMat", filename=f"SEMat_{model_type}.pth", repo_type="model")
+
+def load_model(model_type='SAM2'):
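+    """Build the SEMat model from configs/SEMat_<model_type>.py, initialize LoRA if configured, and load the released checkpoint."""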
+    assert model_type in model_choice.keys()
+    config_path = './configs/SEMat_{}.py'.format(model_type)
+    cfg = LazyConfig.load(config_path)
+
+    if hasattr(cfg.model.sam_model, 'ckpt_path'):
+        cfg.model.sam_model.ckpt_path = None
+    else:
+        cfg.model.sam_model.checkpoint = None
+    model = instantiate(cfg.model)
+    if model.lora_rank is not None:
+        model.init_lora()
+    model.to(DEVICE)
+    DetectionCheckpointer(model).load(model_choice[model_type])
+    model.eval()
+    return model, model_type
+
+def transform_image_bbox(prompts):
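+    """Resize the long side of the prompted image to 1024, pad it to a 1024x1024 canvas, and map the drawn box into the padded coordinate frame."""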
+    if len(prompts["points"]) != 1:
+        raise gr.Error("Please input only one BBox.", duration=5)
+    [[x1, y1, idx_3, x2, y2, idx_6]] = prompts["points"]
+    if idx_3 != 2 or idx_6 != 3:
+        raise gr.Error("Please input BBox instead of point.", duration=5)
+
+    img = prompts["image"]
+    ori_H, ori_W, _ = img.shape
+
+    scale = 1024 * 1.0 / max(ori_H, ori_W)
+    new_H, new_W = ori_H * scale, ori_W * scale
+    new_W = int(new_W + 0.5)
+    new_H = int(new_H + 0.5)
+
+    img = cv2.resize(img, (new_W, new_H), interpolation=cv2.INTER_LINEAR)
+    padding = np.zeros([1024, 1024, 3], dtype=img.dtype)
+    padding[: new_H, : new_W, :] = img
+    img = padding
+    # img = img[:, :, ::-1].transpose((2, 0, 1)).astype(np.float32) / 255.0
+    img = img.transpose((2, 0, 1)).astype(np.float32) / 255.0
+
+    [[x1, y1, _, x2, y2, _]] = prompts["points"]
+    x1, y1, x2, y2 = int(x1 * scale + 0.5), int(y1 * scale + 0.5), int(x2 * scale + 0.5), int(y2 * scale + 0.5)
+    bbox = np.clip(np.array([[x1, y1, x2, y2]]) * 1.0, 0, 1023.0)
+
+    return img, bbox, (ori_H, ori_W), (new_H, new_W)
+
+if __name__ == '__main__':
+
+    model, model_type = load_model()
+
+    def inference_image(prompts, input_model_type):
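+        """Run SEMat on the prompted image, reloading the model first if the dropdown selection changed."""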
+
+        global model_type
+        global model
+
+        if input_model_type != model_type:
+            gr.Info('Loading the {} version of SEMat.'.format(input_model_type), duration=5)
+            _model, _ = load_model(input_model_type)
+            model_type = input_model_type
+            model = _model
+
+        image, bbox, ori_H_W, pad_H_W = transform_image_bbox(prompts)
+        input_data = {
+            'image': torch.from_numpy(image)[None].to(model.device),
+            'bbox': torch.from_numpy(bbox)[None].to(model.device),
+        }
+
+        with torch.no_grad():
+            inputs = model.preprocess_inputs(input_data) 
+            images, bbox, gt_alpha, trimap, condition = inputs['images'], inputs['bbox'], inputs['alpha'], inputs['trimap'], inputs['condition']
+
+            if model.backbone_condition:
+                condition_proj = model.condition_embedding(condition) 
+            elif model.backbone_bbox_prompt is not None or model.bbox_prompt_all_block is not None:
+                condition_proj = bbox
+            else:
+                condition_proj = None
+
+            low_res_masks, pred_alphas, pred_trimap, sam_hq_matting_token = model.forward_samhq_and_matting_decoder(images, bbox, condition_proj)
+
+
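+        # Drop the padding region and tile the single-channel alpha to three channels for display.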
+        output_alpha = np.uint8(pred_alphas[0, 0][:pad_H_W[0], :pad_H_W[1], None].repeat(1, 1, 3).cpu().numpy() * 255)
+
+        return output_alpha
+
+    with gr.Blocks() as demo:
+
+        with gr.Row():
+            with gr.Column(scale=45):
+                img_in = ImagePrompter(type='numpy', show_label=False, label="query image")
+                
+            with gr.Column(scale=45):
+                img_out = gr.Image(type='pil', label="output")
+
+        with gr.Row():
+            with gr.Column(scale=45):
+                input_model_type = gr.Dropdown(list(model_choice.keys()), value='SAM2', label="Trained SEMat Version")
+
+            with gr.Column(scale=45):
+                bt = gr.Button()
+
+        bt.click(inference_image, inputs=[img_in, input_model_type], outputs=[img_out]) 
+
+    demo.launch()
+
diff --git a/configs/SEMat_HQ-SAM.py b/configs/SEMat_HQ-SAM.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a4a90d280dedac697f958cad1985dc14c4e73e7
--- /dev/null
+++ b/configs/SEMat_HQ-SAM.py
@@ -0,0 +1,48 @@
+from .common.train import train
+from .semantic_enhanced_matting.model import model
+from .common.optimizer import optimizer
+from .common.scheduler import lr_multiplier
+from .semantic_enhanced_matting.dataloader import dataloader
+from modeling.decoder.unet_detail_capture import MattingDetailDecoder
+from detectron2.config import LazyCall as L
+
+model.sam_model.model_type = 'vit_l'
+model.sam_model.checkpoint = None
+model.vis_period = 200
+model.output_dir = '?'
+
+train.max_iter = 60000
+train.eval_period = int(train.max_iter * 1 / 10)
+train.checkpointer.period = int(train.max_iter * 1 / 10)
+train.checkpointer.max_to_keep = 1
+
+optimizer.lr = 5e-5
+
+lr_multiplier.scheduler.values = [1.0, 0.5, 0.2]
+lr_multiplier.scheduler.milestones = [0.5, 0.75]
+lr_multiplier.scheduler.num_updates = train.max_iter
+lr_multiplier.warmup_length = 250 / train.max_iter
+
+train.output_dir = './work_dirs/SEMat_HQ-SAM'
+
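+# LoRA adaptation (rank 16) of the SAM model, extra matting tokens in the frozen mask decoder, and the UNet-style detail-capture matting decoder.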
+model.lora_rank = 16
+model.lora_alpha = 16
+model.matting_decoder = L(MattingDetailDecoder)(
+    vit_intern_feat_in = 1024,
+    vit_intern_feat_index = [0, 1, 2, 3],
+    norm_type = 'SyncBN',
+    block_num = 2,
+    img_feat_in = 6,
+    norm_mask_logits = 6.5
+)
+model.backbone_bbox_prompt = 'bbox'
+model.backbone_bbox_prompt_loc = [2, 3]
+model.backbone_bbox_prompt_loss_weight = 1.0
+model.matting_token = True
+model.sam_model.matting_token = 3
+model.sam_model.frozen_decoder = True
+model.sam_hq_token_reg = 0.2
+model.reg_w_bce_loss = True
+model.matting_token_sup = 'trimap'
+model.matting_token_sup_loss_weight = 0.05
+model.trimap_loss_type = 'NGHM'
diff --git a/configs/SEMat_SAM.py b/configs/SEMat_SAM.py
new file mode 100644
index 0000000000000000000000000000000000000000..c32de0be7ef2dee2b3c7f549039d1612c55309a5
--- /dev/null
+++ b/configs/SEMat_SAM.py
@@ -0,0 +1,51 @@
+from .common.train import train
+from .semantic_enhanced_matting.model import model
+from .common.optimizer import optimizer
+from .common.scheduler import lr_multiplier
+from .semantic_enhanced_matting.dataloader import dataloader
+from modeling.decoder.unet_detail_capture import MattingDetailDecoder
+from detectron2.config import LazyCall as L
+
+model.sam_model.model_type = 'vit_l'
+model.sam_model.checkpoint = None
+model.vis_period = 200
+model.output_dir = '?'
+
+train.max_iter = 60000
+train.eval_period = int(train.max_iter * 1 / 10)
+train.checkpointer.period = int(train.max_iter * 1 / 10)
+train.checkpointer.max_to_keep = 1
+
+optimizer.lr = 5e-5
+
+lr_multiplier.scheduler.values = [1.0, 0.5, 0.2]
+lr_multiplier.scheduler.milestones = [0.5, 0.75]
+lr_multiplier.scheduler.num_updates = train.max_iter
+lr_multiplier.warmup_length = 250 / train.max_iter
+
+train.output_dir = './work_dirs/SEMat_SAM'
+
+model.lora_rank = 16
+model.lora_alpha = 16
+model.matting_decoder = L(MattingDetailDecoder)(
+    vit_intern_feat_in = 1024,
+    vit_intern_feat_index = [0, 1, 2, 3],
+    norm_type = 'SyncBN',
+    block_num = 2,
+    img_feat_in = 6,
+    norm_mask_logits = 6.5
+)
+model.backbone_bbox_prompt = 'bbox'
+model.backbone_bbox_prompt_loc = [2, 3]
+model.backbone_bbox_prompt_loss_weight = 1.0
+model.matting_token = True
+model.sam_model.matting_token = 3
+model.sam_model.frozen_decoder = True
+model.sam_hq_token_reg = 0.2
+model.reg_on_sam_logits = True
+model.reg_w_bce_loss = True
+model.matting_token_sup = 'trimap'
+model.matting_token_sup_loss_weight = 0.05
+model.trimap_loss_type = 'NGHM'
+model.sam_model.wo_hq = True
+model.sam_model.mask_matting_res_add = False
diff --git a/configs/SEMat_SAM2.py b/configs/SEMat_SAM2.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6103a578349990e70cffce818464b9d6889dbea
--- /dev/null
+++ b/configs/SEMat_SAM2.py
@@ -0,0 +1,57 @@
+from .common.train import train
+from .semantic_enhanced_matting.model import model
+from .common.optimizer import optimizer
+from .common.scheduler import lr_multiplier
+from .semantic_enhanced_matting.dataloader import dataloader
+from modeling.decoder.unet_detail_capture import MattingDetailDecoder
+from detectron2.config import LazyCall as L
+from sam2.build_sam import build_sam2
+
+model.sam_model.model_type = 'vit_l'
+model.sam_model.checkpoint = None
+model.vis_period = 200
+model.output_dir = '?'
+
+train.max_iter = 60000
+train.eval_period = int(train.max_iter * 1 / 10)
+train.checkpointer.period = int(train.max_iter * 1 / 10)
+train.checkpointer.max_to_keep = 1
+
+optimizer.lr = 5e-5
+
+lr_multiplier.scheduler.values = [1.0, 0.5, 0.2]
+lr_multiplier.scheduler.milestones = [0.5, 0.75]
+lr_multiplier.scheduler.num_updates = train.max_iter
+lr_multiplier.warmup_length = 250 / train.max_iter
+
+train.output_dir = './work_dirs/SEMat_SAM2'
+
+model.sam2 = True
+model.sam_model = L(build_sam2)(
+    config_file = 'sam2_hiera_l.yaml',
+    ckpt_path = None,
+    device = "cuda",
+    bbox_mask_matting_token = True,
+    mode="train",
+    upscaled_embedding_res_add = False
+)
+model.lora_rank = 16
+model.lora_alpha = 16
+model.matting_decoder = L(MattingDetailDecoder)(
+    vit_intern_feat_in = 1024,
+    vit_intern_feat_index = [0, 1, 2, 3],
+    norm_type = 'SyncBN',
+    block_num = 2,
+    img_feat_in = 6,
+    norm_mask_logits = 6.5,
+    sam2_multi_scale_feates = True
+)
+model.backbone_bbox_prompt = 'bbox'
+model.backbone_bbox_prompt_loc = [2, 3]
+model.backbone_bbox_prompt_loss_weight = 1.0
+model.matting_token = True
+model.sam_hq_token_reg = 0.2
+model.reg_w_bce_loss = True
+model.matting_token_sup = 'trimap'
+model.matting_token_sup_loss_weight = 0.05
+model.trimap_loss_type = 'NGHM'
diff --git a/configs/common/optimizer.py b/configs/common/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..42c4d5a5ca093fce2b7fc2578cac76ce2c7944ba
--- /dev/null
+++ b/configs/common/optimizer.py
@@ -0,0 +1,26 @@
+from detectron2 import model_zoo
+from functools import partial
+
+def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12):
+    """
+    Calculate lr decay rate for different ViT blocks.
+    Args:
+        name (string): parameter name.
+        lr_decay_rate (float): base lr decay rate.
+        num_layers (int): number of ViT blocks.
+
+    Returns:
+        lr decay rate for the given parameter.
+    """
+    layer_id = num_layers + 1
+    if name.startswith("backbone"):
+        if ".pos_embed" in name or ".patch_embed" in name:
+            layer_id = 0
+        elif ".blocks." in name and ".residual." not in name:
+            layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
+    return lr_decay_rate ** (num_layers + 1 - layer_id)
+
+# Optimizer
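+# AdamW from the detectron2 model zoo defaults, with layer-wise LR decay (0.65 per ViT block) and no weight decay on the positional embedding.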
+optimizer = model_zoo.get_config("common/optim.py").AdamW
+optimizer.params.lr_factor_func = partial(get_vit_lr_decay_rate, num_layers=12, lr_decay_rate=0.65)
+optimizer.params.overrides = {"pos_embed": {"weight_decay": 0.0}}
\ No newline at end of file
diff --git a/configs/common/scheduler.py b/configs/common/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ceaf49576628f63b62c76fa46c3ce46c5ced308
--- /dev/null
+++ b/configs/common/scheduler.py
@@ -0,0 +1,13 @@
+from detectron2.config import LazyCall as L
+from detectron2.solver import WarmupParamScheduler
+from fvcore.common.param_scheduler import MultiStepParamScheduler
+
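+# Default multi-step schedule; each SEMat_* config overrides values, milestones, num_updates and warmup_length.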
+lr_multiplier = L(WarmupParamScheduler)(
+    scheduler=L(MultiStepParamScheduler)(
+        values=[1.0, 0.1, 0.01],
+        milestones=[96778, 103579],
+        num_updates=100,
+    ),
+    warmup_length=250 / 100,
+    warmup_factor=0.001,
+)
\ No newline at end of file
diff --git a/configs/common/train.py b/configs/common/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..441786e4f0923e05bcc6be1ee992497836bb96d6
--- /dev/null
+++ b/configs/common/train.py
@@ -0,0 +1,17 @@
+train = dict(
+    output_dir="./output",
+    init_checkpoint="",
+    max_iter=90000,
+    amp=dict(enabled=False),  # options for Automatic Mixed Precision
+    ddp=dict(  # options for DistributedDataParallel
+        broadcast_buffers=True,
+        find_unused_parameters=False,
+        fp16_compression=True,
+    ),
+    checkpointer=dict(period=5000, max_to_keep=100),  # options for PeriodicCheckpointer
+    eval_period=5000,
+    log_period=20,
+    device="cuda",
+    seed=42
+    # ...
+)
\ No newline at end of file
diff --git a/configs/semantic_enhanced_matting/dataloader.py b/configs/semantic_enhanced_matting/dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba4247445876995001c7589c0616b79f63cb843f
--- /dev/null
+++ b/configs/semantic_enhanced_matting/dataloader.py
@@ -0,0 +1,62 @@
+from omegaconf import OmegaConf
+from torch.utils.data import ConcatDataset
+from detectron2.config import LazyCall as L
+
+from data.dim_dataset import build_d2_test_dataloader, AdobeCompositionEvaluator, adobe_composition_collate_fn, RW100Test, AIM500Test, AM2KTest, P3M500Test, RWP636Test, SIMTest
+
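+# Placeholder paths: point these at local copies of the evaluation benchmarks.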
+AIM500_PATH = '/path/to/datasets/AIM-500'
+RW100_PATH = '/path/to/datasets/RefMatte_RW_100'
+AM2K_PATH = '/path/to/datasets/AM-2K'
+P3M500_PATH = '/path/to/datasets/P3M-10k/validation/P3M-500-NP'
+RWP636_PATH = '/path/to/datasets/RealWorldPortrait-636'
+SIM_PATH = '/path/to/datasets/SIMD/generated_testset'
+
+dataloader = OmegaConf.create()
+test_dataset = L(ConcatDataset)(
+    datasets = [
+        L(AIM500Test)(
+            data_dir = AIM500_PATH,
+            target_size = 1024,
+            multi_fg = True,
+        ),
+        L(RW100Test)(
+            data_dir = RW100_PATH,
+            target_size = 1024,
+            multi_fg = True,
+        ),
+        L(AM2KTest)(
+            data_dir = AM2K_PATH,
+            target_size = 1024,
+            multi_fg = True,
+        ),
+        L(P3M500Test)(
+            data_dir = P3M500_PATH,
+            target_size = 1024,
+            multi_fg = True,
+        ),
+        L(RWP636Test)(
+            data_dir = RWP636_PATH,
+            target_size = 1024,
+            multi_fg = True
+        ),
+        L(SIMTest)(
+            data_dir = SIM_PATH,
+            target_size = 1024,
+            multi_fg = True
+        )
+    ]
+)
+
+dataloader.test = L(build_d2_test_dataloader)(
+    dataset = test_dataset,
+    local_batch_size = 1,
+    num_workers = 4,
+    collate_fn = adobe_composition_collate_fn
+)
+
+dataloader.evaluator = L(AdobeCompositionEvaluator)(
+    save_eval_results_step = 10, 
+    output_dir = None,  # modify in EvalHook (do_test)
+    eval_dataset_type = ['RW100', 'AIM500', 'AM2K', 'P3M500', 'RWP636', 'SIM'],
+    distributed = True,
+)
diff --git a/configs/semantic_enhanced_matting/model.py b/configs/semantic_enhanced_matting/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6ab3f69d3570585f46756ed1f213374e3073d1c
--- /dev/null
+++ b/configs/semantic_enhanced_matting/model.py
@@ -0,0 +1,35 @@
+from detectron2.config import LazyCall as L
+
+from modeling import Detail_Capture, MattingCriterion
+from modeling.meta_arch import SamHqMatte
+from modeling.semantic_enhanced_matting.build_sam import sam_model_registry_def
+# from modeling.sam_hq_matting.predictor import SamPredictor
+from modeling.semantic_enhanced_matting import MaskDecoderMatting
+
+mask_token_only = False
+
+model = L(SamHqMatte)(
+
+    # original sam_hq
+    sam_model = L(sam_model_registry_def)(
+        model_type = 'vit_b',
+        checkpoint = None,
+    ),
+    hq_token_only = True,
+    hq_features_type = 'Final',
+    multimask_output = True,
+
+    # loss function
+    criterion=L(MattingCriterion)(
+        losses = ['unknown_l1_loss', 'known_l1_loss', 'loss_pha_laplacian', 'loss_gradient_penalty']
+    ),
+    
+    # other params.
+    pixel_mean = [123.675 / 255., 116.280 / 255., 103.530 / 255.],
+    pixel_std = [58.395 / 255., 57.120 / 255., 57.375 / 255.],
+
+    lora_rank = None,
+    lora_alpha = None,
+    w_dora = False,
+    w_rslora = False,
+)
\ No newline at end of file
diff --git a/data/__init__.py b/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a6d8b9c2967e4e7660868a4b5f2d1bf4e2aaa19
--- /dev/null
+++ b/data/__init__.py
@@ -0,0 +1 @@
+from .dim_dataset import *
\ No newline at end of file
diff --git a/data/__pycache__/__init__.cpython-38.pyc b/data/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..03cb664bb8b36d0b6dd6bc608eab66e303d7b450
Binary files /dev/null and b/data/__pycache__/__init__.cpython-38.pyc differ
diff --git a/data/__pycache__/dim_dataset.cpython-38.pyc b/data/__pycache__/dim_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3d7ebc3178ffa57ef88cc28c3beb8fa108bc51bb
Binary files /dev/null and b/data/__pycache__/dim_dataset.cpython-38.pyc differ
diff --git a/data/__pycache__/evaluate.cpython-38.pyc b/data/__pycache__/evaluate.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..290aebf059a454b8e79c21878c16ab7ed26bdba4
Binary files /dev/null and b/data/__pycache__/evaluate.cpython-38.pyc differ
diff --git a/data/__pycache__/rand_augment.cpython-38.pyc b/data/__pycache__/rand_augment.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..254ab37896c285f4cedfa083a2b823ce6e726761
Binary files /dev/null and b/data/__pycache__/rand_augment.cpython-38.pyc differ
diff --git a/data/coconut_dataset.py b/data/coconut_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..494f70266577fc1dbe063aa126af22c108c29292
--- /dev/null
+++ b/data/coconut_dataset.py
@@ -0,0 +1,377 @@
+import os
+import time
+import json
+import torch
+import numpy as np
+import cv2
+from torch.utils.data import Dataset, DistributedSampler, Sampler
+from torchvision import transforms
+from detectron2.utils.logger import setup_logger
+from typing import Optional
+from operator import itemgetter
+from collections import defaultdict
+
+from data.dim_dataset import GenBBox
+
+
+def random_interp():
+    return np.random.choice([cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4])
+
+
+class SplitConcatImage(object):
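+    """Split a horizontally concatenated sample strip into its image, trimap and alpha tiles."""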
+
+    def __init__(self, concat_num=4, wo_mask_to_mattes=False):
+        self.concat_num = concat_num
+        self.wo_mask_to_mattes = wo_mask_to_mattes
+        if self.wo_mask_to_mattes:
+            assert self.concat_num == 5
+
+    def __call__(self, concat_image):
+        if isinstance(concat_image, list):
+            concat_image, image_path = concat_image[0], concat_image[1]
+        else:
+            image_path = None
+        H, W, _ = concat_image.shape
+
+        concat_num = self.concat_num
+        if image_path is not None:
+            if '06-14' in image_path:
+                concat_num = 4
+            elif 'ori_mask' in image_path or 'SEMat' in image_path:
+                concat_num = 3
+            else:
+                concat_num = 5
+        
+        assert W % concat_num == 0
+        W = W // concat_num
+
+        image = concat_image[:H, :W]
+        if self.concat_num != 3:
+            trimap = concat_image[:H, (concat_num - 2) * W: (concat_num - 1) * W]
+            if self.wo_mask_to_mattes:
+                alpha = concat_image[:H, 2 * W: 3 * W]
+            else:
+                alpha = concat_image[:H, (concat_num - 1) * W: concat_num * W]
+        else:
+            trimap = concat_image[:H, (concat_num - 1) * W: concat_num * W]
+            alpha = concat_image[:H, (concat_num - 2) * W: (concat_num - 1) * W]
+
+        return {'image': image, 'trimap': trimap, 'alpha': alpha}
+
+
+class RandomHorizontalFlip(object):
+
+    def __init__(self, prob=0.5):
+        self.prob = prob
+
+    def __call__(self, sample):
+        if np.random.uniform(0, 1) < self.prob:
+            for key in sample.keys():
+                sample[key] = cv2.flip(sample[key], 1)
+        return sample
+
+class EmptyAug(object):
+    def __call__(self, sample):
+        return sample
+
+class RandomReszieCrop(object):
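+    """Randomly rescale the sample and crop/pad it to a square canvas of size `output_size`."""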
+
+    def __init__(self, output_size=1024, aug_scale_min=0.5, aug_scale_max=1.5):
+        self.desired_size = output_size
+        self.aug_scale_min = aug_scale_min
+        self.aug_scale_max = aug_scale_max
+
+    def __call__(self, sample):
+        H, W, _ = sample['image'].shape
+
+        if self.aug_scale_min == 1.0 and self.aug_scale_max == 1.0:
+            crop_H, crop_W = H, W
+            crop_y1, crop_y2 = 0, crop_H
+            crop_x1, crop_x2 = 0, crop_W
+            scale_W, scaled_H = W, H
+        elif self.aug_scale_min == -1.0 and self.aug_scale_max == -1.0:
+            scale = min(self.desired_size / H, self.desired_size / W)
+            scaled_H, scale_W = round(H * scale), round(W * scale)
+            crop_H, crop_W = scaled_H, scale_W
+            crop_y1, crop_y2 = 0, crop_H
+            crop_x1, crop_x2 = 0, crop_W
+        else:
+            # random size
+            random_scale = np.random.uniform(0, 1) * (self.aug_scale_max - self.aug_scale_min) + self.aug_scale_min  # random_val: 0.5 ~ 1.5
+            scaled_size = round(random_scale * self.desired_size)
+
+            scale = min(scaled_size / H, scaled_size / W)
+            scaled_H, scale_W = round(H * scale), round(W * scale)
+
+            # random crop
+            crop_H, crop_W = min(self.desired_size, scaled_H), min(self.desired_size, scale_W)  # crop_size
+            margin_H, margin_W = max(scaled_H - crop_H, 0), max(scale_W - crop_W, 0)
+            offset_H, offset_W = np.random.randint(0, margin_H + 1), np.random.randint(0, margin_W + 1)
+            crop_y1, crop_y2 = offset_H, offset_H + crop_H
+            crop_x1, crop_x2 = offset_W, offset_W + crop_W
+
+        for key in sample.keys():
+            sample[key] = cv2.resize(sample[key], (scale_W, scaled_H), interpolation=random_interp())[crop_y1: crop_y2, crop_x1: crop_x2, :]  # resize and crop
+            padding = np.zeros(shape=(self.desired_size, self.desired_size, 3), dtype=sample[key].dtype)  # pad to desired_size
+            padding[: crop_H, : crop_W, :] = sample[key]
+            sample[key] = padding
+
+        return sample
+
+
+class RandomJitter(object):
+    """
+    Random change the hue of the image
+    """
+
+    def __call__(self, sample):
+
+        image = sample['image']
+
+        # convert to HSV space, convert to float32 image to keep precision during space conversion.
+        image = cv2.cvtColor(image.astype(np.float32)/255.0, cv2.COLOR_BGR2HSV)
+        # Hue noise
+        hue_jitter = np.random.randint(-40, 40)
+        image[:, :, 0] = np.remainder(image[:, :, 0].astype(np.float32) + hue_jitter, 360)
+        # Saturation noise
+        sat_bar = image[:, :, 1].mean()
+
+        sat_jitter = np.random.rand()*(1.1 - sat_bar)/5 - (1.1 - sat_bar) / 10
+        sat = image[:, :, 1]
+        sat = np.abs(sat + sat_jitter)
+        sat[sat>1] = 2 - sat[sat>1]
+        image[:, :, 1] = sat
+        # Value noise
+        val_bar = image[:, :, 2].mean()
+
+        val_jitter = np.random.rand()*(1.1 - val_bar)/5-(1.1 - val_bar) / 10
+        val = image[:, :, 2]
+        val = np.abs(val + val_jitter)
+        val[val>1] = 2 - val[val>1]
+        image[:, :, 2] = val
+        # convert back to BGR space
+        image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
+        sample['image'] = image * 255
+
+        return sample
+
+
+class ToTensor(object):
+
+    def __call__(self, sample):
+        image, alpha, trimap = sample['image'][:, :, ::-1], sample['alpha'], sample['trimap']
+
+        # image
+        image = image.transpose((2, 0, 1)) / 255.
+        sample['image'] = torch.from_numpy(image).float()
+
+        # alpha
+        alpha = alpha.transpose((2, 0, 1))[0: 1] / 255.
+        alpha[alpha < 0 ] = 0
+        alpha[alpha > 1] = 1
+        sample['alpha'] = torch.from_numpy(alpha).float()
+
+        # trimap
+        trimap = trimap.transpose((2, 0, 1))[0: 1] / 1.
+        sample['trimap'] = torch.from_numpy(trimap).float()
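+        # Quantize the trimap to {0: background, 0.5: unknown, 1: foreground}.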
+        sample['trimap'][sample['trimap'] < 85] = 0
+        sample['trimap'][sample['trimap'] >= 170] = 1
+        sample['trimap'][sample['trimap'] >= 85] = 0.5
+
+        return sample
+    
+
+class COCONutData(Dataset):
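+    """Training dataset of COCONut samples stored as concatenated image/trimap/alpha strips."""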
+    def __init__(
+        self, 
+        json_path, 
+        data_root_path, 
+        output_size = 512, 
+        aug_scale_min = 0.5, 
+        aug_scale_max = 1.5,
+        with_bbox = False, 
+        bbox_offset_factor = None,
+        phase = "train",
+        min_miou = 95,
+        miou_json = '',
+        remove_coco_transparent = False,
+        coconut_num_ratio = None,
+        return_multi_fg_info = False,
+        wo_accessory_fusion = False,
+        wo_mask_to_mattes = False,
+        return_image_name = False,
+    ):
+        
+        self.data_root_path = data_root_path
+        self.output_size = output_size
+        self.aug_scale_min = aug_scale_min
+        self.aug_scale_max = aug_scale_max
+        self.with_bbox = with_bbox
+        self.bbox_offset_factor = bbox_offset_factor
+        self.phase = phase
+        self.min_miou = min_miou
+        self.miou_json = miou_json
+        self.remove_coco_transparent = remove_coco_transparent
+        self.coconut_num_ratio = coconut_num_ratio
+        self.return_multi_fg_info = return_multi_fg_info
+        self.wo_accessory_fusion = wo_accessory_fusion # TODO
+        self.wo_mask_to_mattes = wo_mask_to_mattes
+        self.return_image_name = return_image_name
+        assert self.wo_accessory_fusion + self.wo_mask_to_mattes <= 1
+        assert self.phase == 'train'
+
+        self.data_path = []
+        with open(json_path, "r") as file:
+            coconut_matting_info = json.load(file)
+        
+        if self.miou_json != '':
+            name_2_miou_dict = defaultdict(int)
+            with open(self.miou_json, "r") as file:
+                coconut_matting_miou = json.load(file)
+            for miou, name in coconut_matting_miou:
+                name_2_miou_dict[name] = miou
+            for i in coconut_matting_info:
+                if 'accessory' in i['save_path']:
+                    self.data_path.append(i['save_path'])
+                elif name_2_miou_dict[i['save_path'].split('/')[-1]] >= self.min_miou:
+                    if not (self.remove_coco_transparent and 'glass' in i['save_path']):
+                        self.data_path.append(i['save_path'])
+        else:
+            for i in coconut_matting_info:
+                self.data_path.append(i['save_path'])
+
+        if 'accessory' in json_path:
+            concat_num = 5
+        elif 'ori_mask' in json_path:
+            concat_num = 3
+        else:
+            concat_num = 4
+
+        train_trans = [
+            SplitConcatImage(concat_num, wo_mask_to_mattes = self.wo_mask_to_mattes),
+            RandomHorizontalFlip(prob=0 if hasattr(self, 'return_image_name') and self.return_image_name else 0.5),
+            RandomReszieCrop(self.output_size, self.aug_scale_min, self.aug_scale_max),
+            EmptyAug() if hasattr(self, 'return_image_name') and self.return_image_name else RandomJitter(),
+            ToTensor(),
+            GenBBox(bbox_offset_factor=self.bbox_offset_factor)
+        ]
+        self.transform = transforms.Compose(train_trans)
+        print('coconut num: ', len(self.data_path) * self.coconut_num_ratio if self.coconut_num_ratio is not None else len(self.data_path))
+
+    def __getitem__(self, idx):
+        if self.coconut_num_ratio is not None:
+            if self.coconut_num_ratio < 1.0 or idx >= len(self.data_path):
+                idx = np.random.randint(0, len(self.data_path))
+        concat_image = cv2.imread(os.path.join(self.data_root_path, self.data_path[idx]))
+        sample = self.transform([concat_image, self.data_path[idx]])
+        sample['dataset_name'] = 'COCONut'
+        if self.return_multi_fg_info:
+            sample['multi_fg'] = False
+        if hasattr(self, 'return_image_name') and self.return_image_name:
+            sample['image_name'] = self.data_path[idx]
+        return sample
+
+    def __len__(self):
+        if self.coconut_num_ratio is not None:
+            return int(len(self.data_path) * self.coconut_num_ratio)
+        else:
+            return len(self.data_path)
+
+
+class DatasetFromSampler(Dataset):
+    """Dataset to create indexes from `Sampler`.
+
+    Args:
+        sampler: PyTorch sampler
+    """
+
+    def __init__(self, sampler: Sampler):
+        """Initialisation for DatasetFromSampler."""
+        self.sampler = sampler
+        self.sampler_list = None
+
+    def __getitem__(self, index: int):
+        """Gets element of the dataset.
+
+        Args:
+            index: index of the element in the dataset
+
+        Returns:
+            Single element by index
+        """
+        if self.sampler_list is None:
+            self.sampler_list = list(self.sampler)
+        return self.sampler_list[index]
+
+    def __len__(self) -> int:
+        """
+        Returns:
+            int: length of the dataset
+        """
+        return len(self.sampler)
+    
+
+class DistributedSamplerWrapper(DistributedSampler):
+    """
+    Wrapper over `Sampler` for distributed training.
+    Allows you to use any sampler in distributed mode.
+    It is especially useful in conjunction with
+    `torch.nn.parallel.DistributedDataParallel`. In such case, each
+    process can pass a DistributedSamplerWrapper instance as a DataLoader
+    sampler, and load a subset of subsampled data of the original dataset
+    that is exclusive to it.
+    .. note::
+        Sampler is assumed to be of constant size.
+    """
+
+    def __init__(
+        self,
+        sampler,
+        num_replicas: Optional[int] = None,
+        rank: Optional[int] = None,
+        shuffle: bool = True,
+    ):
+        """
+        Args:
+            sampler: Sampler used for subsampling
+            num_replicas (int, optional): Number of processes participating in
+              distributed training
+            rank (int, optional): Rank of the current process
+              within ``num_replicas``
+            shuffle (bool, optional): If true (default),
+              sampler will shuffle the indices
+        """
+        super(DistributedSamplerWrapper, self).__init__(
+            DatasetFromSampler(sampler),
+            num_replicas=num_replicas,
+            rank=rank,
+            shuffle=shuffle,
+        )
+        self.sampler = sampler
+
+    def __iter__(self):
+        """@TODO: Docs. Contribution is welcome."""
+        self.dataset = DatasetFromSampler(self.sampler)
+        indexes_of_indexes = super().__iter__()
+        subsampler_indexes = self.dataset
+        return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))
+    
+
+if __name__ == '__main__':
+
+    dataset = COCONutData(
+        json_path = '/root/data/my_path/Matting/DiffMatte-main/24-06-14_coco-nut_matting.json', 
+        data_root_path = '/root/data/my_path/Matting/DiffMatte-main', 
+        output_size = 1024, 
+        aug_scale_min = 0.5, 
+        aug_scale_max = 1.5,
+        with_bbox = True, 
+        bbox_offset_factor = 0.1,
+        phase = "train"
+    )
+    data = dataset[0]
+
+    for key, val in data.items():
+        print(key, val.shape, torch.min(val), torch.max(val))
\ No newline at end of file
diff --git a/data/dim_dataset.py b/data/dim_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f47549bb6e253201616934c158c6037d5893a01
--- /dev/null
+++ b/data/dim_dataset.py
@@ -0,0 +1,1476 @@
+'''
+Dataloader to process Adobe Image Matting Dataset.
+
+From GCA_Matting(https://github.com/Yaoyi-Li/GCA-Matting/tree/master/dataloader)
+'''
+import os
+import glob
+import logging
+import os.path as osp
+import functools
+import numpy as np
+import torch
+import cv2
+import math
+import numbers
+import random
+import pickle
+from   torch.utils.data import Dataset, DataLoader
+from   torch.nn import functional as F
+from   torchvision import transforms
+from easydict import EasyDict
+from detectron2.utils.logger import setup_logger
+from detectron2.utils import comm
+from detectron2.data import build_detection_test_loader
+import torchvision.transforms.functional
+
+import json
+from PIL import Image
+from detectron2.evaluation.evaluator import DatasetEvaluator
+from collections import defaultdict
+
+from data.evaluate import compute_sad_loss, compute_mse_loss, compute_mad_loss, compute_gradient_loss, compute_connectivity_error
+
+# Base default config
+CONFIG = EasyDict({})
+
+# Model config
+CONFIG.model = EasyDict({})
+# one-hot or class, choice: [3, 1]
+CONFIG.model.trimap_channel = 1
+
+# Dataloader config
+CONFIG.data = EasyDict({})
+# feed forward image size (untested)
+CONFIG.data.crop_size = 512
+# composition of two foregrounds, affine transform, crop and HSV jitter
+CONFIG.data.cutmask_prob = 0.25
+CONFIG.data.augmentation = True
+CONFIG.data.random_interp = True
+
+class Prefetcher():
+    """
+    Modified from the data_prefetcher in https://github.com/NVIDIA/apex/blob/master/examples/imagenet/main_amp.py
+    """
+    def __init__(self, loader):
+        self.orig_loader = loader
+        self.stream = torch.cuda.Stream()
+        self.next_sample = None
+
+    def preload(self):
+        try:
+            self.next_sample = next(self.loader)
+        except StopIteration:
+            self.next_sample = None
+            return
+
+        with torch.cuda.stream(self.stream):
+            for key, value in self.next_sample.items():
+                if isinstance(value, torch.Tensor):
+                    self.next_sample[key] = value.cuda(non_blocking=True)
+
+    def __next__(self):
+        torch.cuda.current_stream().wait_stream(self.stream)
+        sample = self.next_sample
+        if sample is not None:
+            for key, value in sample.items():
+                if isinstance(value, torch.Tensor):
+                    sample[key].record_stream(torch.cuda.current_stream())
+            self.preload()
+        else:
+            # throw stop exception if there is no more data to perform as a default dataloader
+            raise StopIteration("No samples in loader. example: `iterator = iter(Prefetcher(loader)); "
+                                "data = next(iterator)`")
+        return sample
+
+    def __iter__(self):
+        self.loader = iter(self.orig_loader)
+        self.preload()
+        return self
+
+
+class ImageFile(object):
+    def __init__(self, phase='train'):
+        self.phase = phase
+        self.rng = np.random.RandomState(0)
+
+    def _get_valid_names(self, *dirs, shuffle=True):
+        name_sets = [self._get_name_set(d) for d in dirs]
+
+        def _join_and(a, b):
+            return a & b
+
+        valid_names = list(functools.reduce(_join_and, name_sets))
+        if shuffle:
+            self.rng.shuffle(valid_names)
+
+        return valid_names
+
+    @staticmethod
+    def _get_name_set(dir_name):
+        path_list = glob.glob(os.path.join(dir_name, '*'))
+        name_set = set()
+        for path in path_list:
+            name = os.path.basename(path)
+            name = os.path.splitext(name)[0]
+            name_set.add(name)
+        return name_set
+
+    @staticmethod
+    def _list_abspath(data_dir, ext, data_list):
+        return [os.path.join(data_dir, name + ext)
+                for name in data_list]
+
+class ImageFileTrain(ImageFile):
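+    """File lists for composition-based training; each foreground/alpha pair is repeated (via `fg_have_bg_num`) so it can be composited over several backgrounds."""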
+    def __init__(
+        self,
+        alpha_dir="train_alpha",
+        fg_dir="train_fg",
+        bg_dir="train_bg",
+        alpha_ext=".jpg",
+        fg_ext=".jpg",
+        bg_ext=".jpg",
+        fg_have_bg_num=None,
+        alpha_ratio_json = None,
+        alpha_min_ratio = None,
+        key_sample_ratio = None,
+    ):
+        super(ImageFileTrain, self).__init__(phase="train")
+
+        self.alpha_dir  = alpha_dir
+        self.fg_dir     = fg_dir
+        self.bg_dir     = bg_dir
+        self.alpha_ext  = alpha_ext
+        self.fg_ext     = fg_ext
+        self.bg_ext     = bg_ext
+        logger = setup_logger(name=__name__)
+
+        if not isinstance(self.alpha_dir, str):
+            assert len(self.alpha_dir) == len(self.fg_dir) == len(alpha_ext) == len(fg_ext)
+            self.valid_fg_list = []
+            self.alpha = []
+            self.fg = []
+            self.key_alpha = []
+            self.key_fg = []
+            for i in range(len(self.alpha_dir)):
+                valid_fg_list = self._get_valid_names(self.fg_dir[i], self.alpha_dir[i])
+                valid_fg_list.sort()
+                alpha = self._list_abspath(self.alpha_dir[i], self.alpha_ext[i], valid_fg_list)
+                fg = self._list_abspath(self.fg_dir[i], self.fg_ext[i], valid_fg_list)
+                self.valid_fg_list += valid_fg_list
+
+                self.alpha += alpha * fg_have_bg_num[i]
+                self.fg += fg * fg_have_bg_num[i]
+
+                if alpha_ratio_json[i] is not None:
+                    tmp_key_alpha = []
+                    tmp_key_fg = []
+                    name_to_alpha_path = dict()
+                    for name in alpha:
+                        name_to_alpha_path[name.split('/')[-1].split('.')[0]] = name
+                    name_to_fg_path = dict()
+                    for name in fg:
+                        name_to_fg_path[name.split('/')[-1].split('.')[0]] = name
+
+                    with open(alpha_ratio_json[i], 'r') as file:
+                        alpha_ratio_list = json.load(file)
+                    for ratio, name in alpha_ratio_list:
+                        if ratio < alpha_min_ratio[i]:
+                            break
+                        tmp_key_alpha.append(name_to_alpha_path[name.split('.')[0]])
+                        tmp_key_fg.append(name_to_fg_path[name.split('.')[0]])
+
+                    self.key_alpha.extend(tmp_key_alpha * fg_have_bg_num[i])
+                    self.key_fg.extend(tmp_key_fg * fg_have_bg_num[i])
+
+            if len(self.key_alpha) != 0 and key_sample_ratio > 0:
+                repeat_num = key_sample_ratio * (len(self.alpha) - len(self.key_alpha)) / len(self.key_alpha) / (1 - key_sample_ratio) - 1
+                print('key sample num:', len(self.key_alpha), ', repeat num: ', repeat_num)
+                for i in range(math.ceil(repeat_num)):
+                    self.alpha += self.key_alpha
+                    self.fg += self.key_fg
+
+        else:
+            self.valid_fg_list = self._get_valid_names(self.fg_dir, self.alpha_dir)
+            self.valid_fg_list.sort()
+            self.alpha = self._list_abspath(self.alpha_dir, self.alpha_ext, self.valid_fg_list)
+            self.fg = self._list_abspath(self.fg_dir, self.fg_ext, self.valid_fg_list)
+            
+        self.valid_bg_list = [os.path.splitext(name)[0] for name in os.listdir(self.bg_dir)]
+        self.valid_bg_list.sort()
+
+        if fg_have_bg_num is not None:
+            # assert fg_have_bg_num * len(self.valid_fg_list) <= len(self.valid_bg_list)
+            # self.valid_bg_list = self.valid_bg_list[: fg_have_bg_num * len(self.valid_fg_list)]
+            assert len(self.alpha) <= len(self.valid_bg_list)
+            self.valid_bg_list = self.valid_bg_list[: len(self.alpha)]
+
+        self.bg = self._list_abspath(self.bg_dir, self.bg_ext, self.valid_bg_list)
+
+    def __len__(self):  
+        return len(self.alpha)
+
+class ImageFileTest(ImageFile):
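+    """File lists for a test set of pre-merged images with paired alpha mattes and trimaps."""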
+    def __init__(self,
+                 alpha_dir="test_alpha",
+                 merged_dir="test_merged",
+                 trimap_dir="test_trimap",
+                 alpha_ext=".png",
+                 merged_ext=".png",
+                 trimap_ext=".png"):
+        super(ImageFileTest, self).__init__(phase="test")
+
+        self.alpha_dir  = alpha_dir
+        self.merged_dir = merged_dir
+        self.trimap_dir = trimap_dir
+        self.alpha_ext  = alpha_ext
+        self.merged_ext = merged_ext
+        self.trimap_ext = trimap_ext
+
+        self.valid_image_list = self._get_valid_names(self.alpha_dir, self.merged_dir, self.trimap_dir, shuffle=False)
+
+        self.alpha = self._list_abspath(self.alpha_dir, self.alpha_ext, self.valid_image_list)
+        self.merged = self._list_abspath(self.merged_dir, self.merged_ext, self.valid_image_list)
+        self.trimap = self._list_abspath(self.trimap_dir, self.trimap_ext, self.valid_image_list)
+
+    def __len__(self):
+        return len(self.alpha)
+
+interp_list = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]
+
+
+def maybe_random_interp(cv2_interp):
+    if CONFIG.data.random_interp:
+        return np.random.choice(interp_list)
+    else:
+        return cv2_interp
+
+
+class ToTensor(object):
+    """
+    Convert ndarrays in sample to Tensors with normalization.
+    """
+    def __init__(self, phase="test"):
+        self.mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
+        self.std = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)
+        self.phase = phase
+
+    def __call__(self, sample):
+        image, alpha, trimap, mask = sample['image'][:,:,::-1], sample['alpha'], sample['trimap'], sample['mask']
+        
+        alpha[alpha < 0 ] = 0
+        alpha[alpha > 1] = 1
+     
+        image = image.transpose((2, 0, 1)).astype(np.float32)
+        alpha = np.expand_dims(alpha.astype(np.float32), axis=0)
+        
+        mask = np.expand_dims(mask.astype(np.float32), axis=0)
+
+        image /= 255.
+
+        if self.phase == "train":
+            fg = sample['fg'][:,:,::-1].transpose((2, 0, 1)).astype(np.float32) / 255.
+            sample['fg'] = torch.from_numpy(fg)
+            bg = sample['bg'][:,:,::-1].transpose((2, 0, 1)).astype(np.float32) / 255.
+            sample['bg'] = torch.from_numpy(bg)
+
+        sample['image'], sample['alpha'], sample['trimap'] = \
+            torch.from_numpy(image), torch.from_numpy(alpha), torch.from_numpy(trimap).to(torch.long)
+        sample['image'] = sample['image']
+
+        if CONFIG.model.trimap_channel == 3:
+            sample['trimap'] = F.one_hot(sample['trimap'], num_classes=3).permute(2,0,1).float()
+        elif CONFIG.model.trimap_channel == 1:
+            sample['trimap'] = sample['trimap'][None,...].float()
+        else:
+            raise NotImplementedError("CONFIG.model.trimap_channel can only be 3 or 1")
+        sample['trimap'][sample['trimap'] < 85] = 0
+        sample['trimap'][sample['trimap'] >= 170] = 1
+        sample['trimap'][sample['trimap'] >= 85] = 0.5
+
+        sample['mask'] = torch.from_numpy(mask).float()
+
+        return sample
+
+
+class RandomAffine(object):
+    """
+    Random affine translation
+    """
+    def __init__(self, degrees, translate=None, scale=None, shear=None, flip=None, resample=False, fillcolor=0):
+        if isinstance(degrees, numbers.Number):
+            if degrees < 0:
+                raise ValueError("If degrees is a single number, it must be positive.")
+            self.degrees = (-degrees, degrees)
+        else:
+            assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
+                "degrees should be a list or tuple and it must be of length 2."
+            self.degrees = degrees
+
+        if translate is not None:
+            assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
+                "translate should be a list or tuple and it must be of length 2."
+            for t in translate:
+                if not (0.0 <= t <= 1.0):
+                    raise ValueError("translation values should be between 0 and 1")
+        self.translate = translate
+
+        if scale is not None:
+            assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
+                "scale should be a list or tuple and it must be of length 2."
+            for s in scale:
+                if s <= 0:
+                    raise ValueError("scale values should be positive")
+        self.scale = scale
+
+        if shear is not None:
+            if isinstance(shear, numbers.Number):
+                if shear < 0:
+                    raise ValueError("If shear is a single number, it must be positive.")
+                self.shear = (-shear, shear)
+            else:
+                assert isinstance(shear, (tuple, list)) and len(shear) == 2, \
+                    "shear should be a list or tuple and it must be of length 2."
+                self.shear = shear
+        else:
+            self.shear = shear
+
+        self.resample = resample
+        self.fillcolor = fillcolor
+        self.flip = flip
+
+    @staticmethod
+    def get_params(degrees, translate, scale_ranges, shears, flip, img_size):
+        """Get parameters for affine transformation
+
+        Returns:
+            sequence: params to be passed to the affine transformation
+        """
+        angle = random.uniform(degrees[0], degrees[1])
+        if translate is not None:
+            max_dx = translate[0] * img_size[0]
+            max_dy = translate[1] * img_size[1]
+            translations = (np.round(random.uniform(-max_dx, max_dx)),
+                            np.round(random.uniform(-max_dy, max_dy)))
+        else:
+            translations = (0, 0)
+
+        if scale_ranges is not None:
+            scale = (random.uniform(scale_ranges[0], scale_ranges[1]),
+                     random.uniform(scale_ranges[0], scale_ranges[1]))
+        else:
+            scale = (1.0, 1.0)
+
+        if shears is not None:
+            shear = random.uniform(shears[0], shears[1])
+        else:
+            shear = 0.0
+
+        if flip is not None:
+            flip = (np.random.rand(2) < flip).astype(np.int32) * 2 - 1
+
+        return angle, translations, scale, shear, flip
+
+    def __call__(self, sample):
+        fg, alpha = sample['fg'], sample['alpha']
+        rows, cols, ch = fg.shape
+        if np.maximum(rows, cols) < 1024:
+            params = self.get_params((0, 0), self.translate, self.scale, self.shear, self.flip, fg.size)
+        else:
+            params = self.get_params(self.degrees, self.translate, self.scale, self.shear, self.flip, fg.size)
+
+        center = (cols * 0.5 + 0.5, rows * 0.5 + 0.5)
+        M = self._get_inverse_affine_matrix(center, *params)
+        M = np.array(M).reshape((2, 3))
+
+        fg = cv2.warpAffine(fg, M, (cols, rows),
+                            flags=maybe_random_interp(cv2.INTER_NEAREST) + cv2.WARP_INVERSE_MAP)
+        alpha = cv2.warpAffine(alpha, M, (cols, rows),
+                               flags=maybe_random_interp(cv2.INTER_NEAREST) + cv2.WARP_INVERSE_MAP)
+
+        sample['fg'], sample['alpha'] = fg, alpha
+
+        return sample
+
+
+    @ staticmethod
+    def _get_inverse_affine_matrix(center, angle, translate, scale, shear, flip):
+
+        angle = math.radians(angle)
+        shear = math.radians(shear)
+        scale_x = 1.0 / scale[0] * flip[0]
+        scale_y = 1.0 / scale[1] * flip[1]
+
+        # Inverted rotation matrix with scale and shear
+        d = math.cos(angle + shear) * math.cos(angle) + math.sin(angle + shear) * math.sin(angle)
+        matrix = [
+            math.cos(angle) * scale_x, math.sin(angle + shear) * scale_x, 0,
+            -math.sin(angle) * scale_y, math.cos(angle + shear) * scale_y, 0
+        ]
+        matrix = [m / d for m in matrix]
+
+        # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
+        matrix[2] += matrix[0] * (-center[0] - translate[0]) + matrix[1] * (-center[1] - translate[1])
+        matrix[5] += matrix[3] * (-center[0] - translate[0]) + matrix[4] * (-center[1] - translate[1])
+
+        # Apply center translation: C * RSS^-1 * C^-1 * T^-1
+        matrix[2] += center[0]
+        matrix[5] += center[1]
+
+        return matrix
+
+
+class RandomJitter(object):
+    """
+    Random change the hue of the image
+    """
+
+    def __call__(self, sample):
+        sample_ori = sample.copy()
+        fg, alpha = sample['fg'], sample['alpha']
+        # if alpha is all 0 skip
+        if np.all(alpha==0):
+            return sample_ori
+        # convert to HSV space, convert to float32 image to keep precision during space conversion.
+        fg = cv2.cvtColor(fg.astype(np.float32)/255.0, cv2.COLOR_BGR2HSV)
+        # Hue noise
+        hue_jitter = np.random.randint(-40, 40)
+        fg[:, :, 0] = np.remainder(fg[:, :, 0].astype(np.float32) + hue_jitter, 360)
+        # Saturation noise
+        sat_bar = fg[:, :, 1][alpha > 0].mean()
+        if np.isnan(sat_bar):
+            return sample_ori
+        sat_jitter = np.random.rand()*(1.1 - sat_bar)/5 - (1.1 - sat_bar) / 10
+        sat = fg[:, :, 1]
+        sat = np.abs(sat + sat_jitter)
+        sat[sat>1] = 2 - sat[sat>1]
+        fg[:, :, 1] = sat
+        # Value noise
+        val_bar = fg[:, :, 2][alpha > 0].mean()
+        if np.isnan(val_bar):
+            return sample_ori
+        val_jitter = np.random.rand()*(1.1 - val_bar)/5-(1.1 - val_bar) / 10
+        val = fg[:, :, 2]
+        val = np.abs(val + val_jitter)
+        val[val>1] = 2 - val[val>1]
+        fg[:, :, 2] = val
+        # convert back to BGR space
+        fg = cv2.cvtColor(fg, cv2.COLOR_HSV2BGR)
+        sample['fg'] = fg*255
+
+        return sample
+
+
+class RandomHorizontalFlip(object):
+    """
+    Random flip image and label horizontally
+    """
+    def __init__(self, prob=0.5):
+        self.prob = prob
+    def __call__(self, sample):
+        fg, alpha = sample['fg'], sample['alpha']
+        if np.random.uniform(0, 1) < self.prob:
+            fg = cv2.flip(fg, 1)
+            alpha = cv2.flip(alpha, 1)
+        sample['fg'], sample['alpha'] = fg, alpha
+
+        return sample
+
+
+class RandomCrop(object):
+    """
+    Crop randomly the image in a sample, retain the center 1/4 images, and resize to 'output_size'
+
+    :param output_size (tuple or int): Desired output size. If int, square crop
+            is made.
+    """
+
+    def __init__(self, output_size=( CONFIG.data.crop_size, CONFIG.data.crop_size)):
+        assert isinstance(output_size, (int, tuple))
+        if isinstance(output_size, int):
+            self.output_size = (output_size, output_size)
+        else:
+            assert len(output_size) == 2
+            self.output_size = output_size
+        self.margin = output_size[0] // 2
+        self.logger = logging.getLogger("Logger")
+
+    def __call__(self, sample):
+        fg, alpha, trimap, mask, name = sample['fg'],  sample['alpha'], sample['trimap'], sample['mask'], sample['image_name']
+        bg = sample['bg']
+        h, w = trimap.shape
+        bg = cv2.resize(bg, (w, h), interpolation=maybe_random_interp(cv2.INTER_CUBIC))
+        if w < self.output_size[0]+1 or h < self.output_size[1]+1:
+            ratio = 1.1*self.output_size[0]/h if h < w else 1.1*self.output_size[1]/w
+            # self.logger.warning("Size of {} is {}.".format(name, (h, w)))
+            while h < self.output_size[0]+1 or w < self.output_size[1]+1:
+                fg = cv2.resize(fg, (int(w*ratio), int(h*ratio)), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
+                alpha = cv2.resize(alpha, (int(w*ratio), int(h*ratio)),
+                                   interpolation=maybe_random_interp(cv2.INTER_NEAREST))
+                trimap = cv2.resize(trimap, (int(w*ratio), int(h*ratio)), interpolation=cv2.INTER_NEAREST)
+                bg = cv2.resize(bg, (int(w*ratio), int(h*ratio)), interpolation=maybe_random_interp(cv2.INTER_CUBIC))
+                mask = cv2.resize(mask, (int(w*ratio), int(h*ratio)), interpolation=cv2.INTER_NEAREST)
+                h, w = trimap.shape
+        small_trimap = cv2.resize(trimap, (w//4, h//4), interpolation=cv2.INTER_NEAREST)
+        unknown_list = list(zip(*np.where(small_trimap[self.margin//4:(h-self.margin)//4,
+                                                       self.margin//4:(w-self.margin)//4] == 128)))
+        unknown_num = len(unknown_list)
+        if len(unknown_list) < 10:
+            left_top = (np.random.randint(0, h-self.output_size[0]+1), np.random.randint(0, w-self.output_size[1]+1))
+        else:
+            idx = np.random.randint(unknown_num)
+            left_top = (unknown_list[idx][0]*4, unknown_list[idx][1]*4)
+
+        fg_crop = fg[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1],:]
+        alpha_crop = alpha[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1]]
+        bg_crop = bg[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1],:]
+        trimap_crop = trimap[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1]]
+        mask_crop = mask[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1]]
+
+        if len(np.where(trimap==128)[0]) == 0:
+            self.logger.error("{} does not have enough unknown area for crop. Resized to target size."
+                                "left_top: {}".format(name, left_top))
+            fg_crop = cv2.resize(fg, self.output_size[::-1], interpolation=maybe_random_interp(cv2.INTER_NEAREST))
+            alpha_crop = cv2.resize(alpha, self.output_size[::-1], interpolation=maybe_random_interp(cv2.INTER_NEAREST))
+            trimap_crop = cv2.resize(trimap, self.output_size[::-1], interpolation=cv2.INTER_NEAREST)
+            bg_crop = cv2.resize(bg, self.output_size[::-1], interpolation=maybe_random_interp(cv2.INTER_CUBIC))
+            mask_crop = cv2.resize(mask, self.output_size[::-1], interpolation=cv2.INTER_NEAREST)
+        
+        sample.update({'fg': fg_crop, 'alpha': alpha_crop, 'trimap': trimap_crop, 'mask': mask_crop, 'bg': bg_crop})
+        return sample
+
+
+class OriginScale(object):
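+    """Reflect-pad the image, trimap and mask so that both spatial sides are multiples of 32."""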
+    def __call__(self, sample):
+        h, w = sample["alpha_shape"]
+
+        if h % 32 == 0 and w % 32 == 0:
+            return sample
+
+        target_h = 32 * ((h - 1) // 32 + 1)
+        target_w = 32 * ((w - 1) // 32 + 1)
+        pad_h = target_h - h
+        pad_w = target_w - w
+
+        padded_image = np.pad(sample['image'], ((0,pad_h), (0, pad_w), (0,0)), mode="reflect")
+        padded_trimap = np.pad(sample['trimap'], ((0,pad_h), (0, pad_w)), mode="reflect")
+        padded_mask = np.pad(sample['mask'], ((0,pad_h), (0, pad_w)), mode="reflect")
+
+        sample['image'] = padded_image
+        sample['trimap'] = padded_trimap
+        sample['mask'] = padded_mask
+
+        return sample
+
+
+class GenMask(object):
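+    """Synthesize a trimap by eroding fg/bg regions and a coarse mask by thresholding the alpha with random morphology."""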
+    def __init__(self):
+        self.erosion_kernels = [None] + [cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size)) for size in range(1,30)]
+
+    def __call__(self, sample):
+        alpha_ori = sample['alpha']
+        h, w = alpha_ori.shape
+
+        max_kernel_size = 30
+        alpha = cv2.resize(alpha_ori, (640,640), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
+
+        ### generate trimap
+        fg_mask = (alpha + 1e-5).astype(np.int32).astype(np.uint8)
+        bg_mask = (1 - alpha + 1e-5).astype(np.int32).astype(np.uint8)
+        fg_mask = cv2.erode(fg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
+        bg_mask = cv2.erode(bg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
+
+        fg_width = np.random.randint(1, 30)
+        bg_width = np.random.randint(1, 30)
+        fg_mask = (alpha + 1e-5).astype(np.int32).astype(np.uint8)
+        bg_mask = (1 - alpha + 1e-5).astype(np.int32).astype(np.uint8)
+        fg_mask = cv2.erode(fg_mask, self.erosion_kernels[fg_width])
+        bg_mask = cv2.erode(bg_mask, self.erosion_kernels[bg_width])
+
+        trimap = np.ones_like(alpha) * 128
+        trimap[fg_mask == 1] = 255
+        trimap[bg_mask == 1] = 0
+
+        trimap = cv2.resize(trimap, (w,h), interpolation=cv2.INTER_NEAREST)
+        sample['trimap'] = trimap
+
+        ### generate mask
+        low = 0.01
+        high = 1.0
+        thres = random.random() * (high - low) + low
+        seg_mask = (alpha >= thres).astype(np.int32).astype(np.uint8)
+        random_num = random.randint(0,3)
+        if random_num == 0:
+            seg_mask = cv2.erode(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
+        elif random_num == 1:
+            seg_mask = cv2.dilate(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
+        elif random_num == 2:
+            seg_mask = cv2.erode(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
+            seg_mask = cv2.dilate(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
+        elif random_num == 3:
+            seg_mask = cv2.dilate(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
+            seg_mask = cv2.erode(seg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
+        
+        seg_mask = cv2.resize(seg_mask, (w,h), interpolation=cv2.INTER_NEAREST)
+        sample['mask'] = seg_mask
+
+        return sample
+
+
+class Composite(object):
+    def __call__(self, sample):
+        fg, bg, alpha = sample['fg'], sample['bg'], sample['alpha']
+        alpha[alpha < 0 ] = 0
+        alpha[alpha > 1] = 1
+        fg[fg < 0 ] = 0
+        fg[fg > 255] = 255
+        bg[bg < 0 ] = 0
+        bg[bg > 255] = 255
+
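+        # standard alpha compositing: I = alpha * F + (1 - alpha) * B, with alpha
+        # broadcast over the colour channels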
+        image = fg * alpha[:, :, None] + bg * (1 - alpha[:, :, None])
+        sample['image'] = image
+        return sample
+
+
+class CutMask(object):
+    def __init__(self, perturb_prob = 0):
+        self.perturb_prob = perturb_prob
+
+    def __call__(self, sample):
+        if np.random.rand() < self.perturb_prob:
+            return sample
+
+        mask = sample['mask'] # H x W, trimap 0--255, segmask 0--1, alpha 0--1
+        h, w = mask.shape
+        perturb_size_h, perturb_size_w = random.randint(h // 4, h // 2), random.randint(w // 4, w // 2)
+        x = random.randint(0, h - perturb_size_h)
+        y = random.randint(0, w - perturb_size_w)
+        x1 = random.randint(0, h - perturb_size_h)
+        y1 = random.randint(0, w - perturb_size_w)
+        
+        mask[x:x+perturb_size_h, y:y+perturb_size_w] = mask[x1:x1+perturb_size_h, y1:y1+perturb_size_w].copy()
+        
+        sample['mask'] = mask
+        return sample
+
+
+class ScaleFg(object):
+    def __init__(self, min_scale_fg_scale=0.5, max_scale_fg_scale=1.0):
+        self.min_scale_fg_scale = min_scale_fg_scale
+        self.max_scale_fg_scale = max_scale_fg_scale
+
+    def __call__(self, sample):
+        scale_factor = np.random.uniform(low=self.min_scale_fg_scale, high=self.max_scale_fg_scale)
+
+        fg, alpha = sample['fg'], sample['alpha']  # np.array(): [H, W, 3] 0 ~ 255 , [H, W] 0.0 ~ 1.0
+        h, w = alpha.shape
+        scale_h, scale_w = int(h * scale_factor), int(w * scale_factor)
+
+        new_fg, new_alpha = np.zeros_like(fg), np.zeros_like(alpha)
+        fg = cv2.resize(fg, (scale_w, scale_h), interpolation=cv2.INTER_LINEAR) 
+        alpha = cv2.resize(alpha, (scale_w, scale_h), interpolation=cv2.INTER_LINEAR) 
+
+        if scale_factor <= 1:
+            offset_h, offset_w = np.random.randint(h - scale_h + 1), np.random.randint(w - scale_w + 1)
+            new_fg[offset_h: offset_h + scale_h, offset_w: offset_w + scale_w, :] = fg
+            new_alpha[offset_h: offset_h + scale_h, offset_w: offset_w + scale_w] = alpha
+        else:
+            offset_h, offset_w = np.random.randint(scale_h - h + 1), np.random.randint(scale_w - w + 1)
+            # crop the enlarged foreground back to the original (h, w) size
+            new_fg = fg[offset_h: offset_h + h, offset_w: offset_w + w, :]
+            new_alpha = alpha[offset_h: offset_h + h, offset_w: offset_w + w]
+
+        sample['fg'], sample['alpha'] = new_fg, new_alpha
+        return sample
+
+class GenBBox(object):
+    def __init__(self, bbox_offset_factor = 0.1, random_crop_bbox = None, train_or_test = 'train', dataset_type = None, random_auto_matting=None):
+        self.bbox_offset_factor = bbox_offset_factor
+        self.random_crop_bbox = random_crop_bbox
+        self.train_or_test = train_or_test
+        self.dataset_type = dataset_type
+        self.random_auto_matting = random_auto_matting
+
+    def __call__(self, sample):
+
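+        # Derive a prompt box from the tight extent of the non-zero alpha region, then
+        # optionally (a) crop-and-resize the sample to that box (random_crop_bbox),
+        # (b) jitter each box edge by up to bbox_offset_factor of the box size, or
+        # (c) replace the box with the full image (random_auto_matting).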
+        alpha = sample['alpha']  # [1, H, W] 0.0 ~ 1.0
+        indices = torch.nonzero(alpha[0], as_tuple=True)
+
+        if len(indices[0]) > 0:
+
+            min_x, min_y = torch.min(indices[1]), torch.min(indices[0])
+            max_x, max_y = torch.max(indices[1]), torch.max(indices[0])
+
+            if self.random_crop_bbox is not None and np.random.uniform(0, 1) < self.random_crop_bbox:
+                ori_h_w = (sample['alpha'].shape[-2], sample['alpha'].shape[-1])
+                sample['alpha'] = F.interpolate(sample['alpha'][None, :, min_y: max_y + 1, min_x: max_x + 1], size=ori_h_w, mode='bilinear', align_corners=False)[0]
+                sample['image'] = F.interpolate(sample['image'][None, :, min_y: max_y + 1, min_x: max_x + 1], size=ori_h_w, mode='bilinear', align_corners=False)[0]
+                sample['trimap'] = F.interpolate(sample['trimap'][None, :, min_y: max_y + 1, min_x: max_x + 1], size=ori_h_w, mode='nearest')[0]
+                bbox = torch.tensor([[0, 0, ori_h_w[1] - 1, ori_h_w[0] - 1]])
+
+            elif self.bbox_offset_factor != 0:
+                bbox_w = max(1, max_x - min_x)
+                bbox_h = max(1, max_y - min_y)
+                offset_w = math.ceil(self.bbox_offset_factor * bbox_w)
+                offset_h = math.ceil(self.bbox_offset_factor * bbox_h)
+
+                min_x = max(0, min_x + np.random.randint(-offset_w, offset_w))
+                max_x = min(alpha.shape[2] - 1, max_x + np.random.randint(-offset_w, offset_w))
+                min_y = max(0, min_y + np.random.randint(-offset_h, offset_h))
+                max_y = min(alpha.shape[1] - 1, max_y + np.random.randint(-offset_h, offset_h))
+                bbox = torch.tensor([[min_x, min_y, max_x, max_y]])
+            else:
+                bbox = torch.tensor([[min_x, min_y, max_x, max_y]])
+            
+            if self.random_auto_matting is not None and np.random.uniform(0, 1) < self.random_auto_matting:
+                bbox = torch.tensor([[0, 0, alpha.shape[2] - 1, alpha.shape[1] - 1]])
+
+        else:
+            bbox = torch.zeros(1, 4)
+
+        sample['bbox'] = bbox.float()
+        return sample
+
+class DataGenerator(Dataset):
+    def __init__(
+            self, 
+            data, 
+            phase="train", 
+            crop_size=512, 
+            remove_multi_fg=False, 
+            min_scale_fg_scale=None, 
+            max_scale_fg_scale=None, 
+            with_bbox = False, 
+            bbox_offset_factor = None,
+            return_keys = None,
+            random_crop_bbox = None,
+            dataset_name = None,
+            random_auto_matting = None,
+        ):
+        self.phase = phase
+        # self.crop_size = CONFIG.data.crop_size
+        self.crop_size = crop_size
+        self.remove_multi_fg = remove_multi_fg
+        self.with_bbox = with_bbox
+        self.bbox_offset_factor = bbox_offset_factor
+        self.alpha = data.alpha
+        self.return_keys = return_keys
+        self.random_crop_bbox = random_crop_bbox
+        self.dataset_name = dataset_name
+        self.random_auto_matting = random_auto_matting
+
+        if self.phase == "train":
+            self.fg = data.fg
+            self.bg = data.bg
+            self.merged = []
+            self.trimap = []
+        else:
+            self.fg = []
+            self.bg = []
+            self.merged = data.merged
+            self.trimap = data.trimap
+
+        train_trans = [
+            RandomAffine(degrees=30, scale=[0.8, 1.25], shear=10, flip=0.5),
+            GenMask(),
+            CutMask(perturb_prob=CONFIG.data.cutmask_prob),
+            RandomCrop((self.crop_size, self.crop_size)),
+            RandomJitter(),
+            Composite(),
+            ToTensor(phase="train")
+        ]
+        if min_scale_fg_scale is not None:
+            train_trans.insert(0, ScaleFg(min_scale_fg_scale, max_scale_fg_scale))
+        if self.with_bbox:
+            train_trans.append(GenBBox(bbox_offset_factor=self.bbox_offset_factor, random_crop_bbox=self.random_crop_bbox, random_auto_matting=self.random_auto_matting))
+
+        test_trans = [ OriginScale(), ToTensor() ]
+
+        self.transform = {
+            'train':
+                transforms.Compose(train_trans),
+            'val':
+                transforms.Compose([
+                    OriginScale(),
+                    ToTensor()
+                ]),
+            'test':
+                transforms.Compose(test_trans)
+        }[phase]
+
+        self.fg_num = len(self.fg)
+
+    def select_keys(self, sample):
+        new_sample = {}
+        for key, val in sample.items():
+            if key in self.return_keys:
+                new_sample[key] = val
+        return new_sample
+
+    def __getitem__(self, idx):
+        if self.phase == "train":
+            fg = cv2.imread(self.fg[idx % self.fg_num])
+            alpha = cv2.imread(self.alpha[idx % self.fg_num], 0).astype(np.float32)/255
+            bg = cv2.imread(self.bg[idx], 1)
+
+            if not self.remove_multi_fg:
+                fg, alpha, multi_fg = self._composite_fg(fg, alpha, idx)
+            else:
+                multi_fg = False
+            image_name = os.path.split(self.fg[idx % self.fg_num])[-1]
+            sample = {'fg': fg, 'alpha': alpha, 'bg': bg, 'image_name': image_name, 'multi_fg': multi_fg}
+
+        else:
+            image = cv2.imread(self.merged[idx])
+            alpha = cv2.imread(self.alpha[idx], 0)/255.
+            trimap = cv2.imread(self.trimap[idx], 0)
+            mask = (trimap >= 170).astype(np.float32)
+            image_name = os.path.split(self.merged[idx])[-1]
+
+            sample = {'image': image, 'alpha': alpha, 'trimap': trimap, 'mask': mask, 'image_name': image_name, 'alpha_shape': alpha.shape}
+
+        sample = self.transform(sample)
+
+        if self.return_keys is not None:
+            sample = self.select_keys(sample)
+        if self.dataset_name is not None:
+            sample['dataset_name'] = self.dataset_name
+        return sample
+
+    def _composite_fg(self, fg, alpha, idx):
+        
+        multi_fg = False
+        if np.random.rand() < 0.5:
+            idx2 = np.random.randint(self.fg_num) + idx
+            fg2 = cv2.imread(self.fg[idx2 % self.fg_num])
+            alpha2 = cv2.imread(self.alpha[idx2 % self.fg_num], 0).astype(np.float32)/255.
+            h, w = alpha.shape
+            fg2 = cv2.resize(fg2, (w, h), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
+            alpha2 = cv2.resize(alpha2, (w, h), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
+
+            alpha_tmp = 1 - (1 - alpha) * (1 - alpha2)
+            if  np.any(alpha_tmp < 1):
+                fg = fg.astype(np.float32) * alpha[:,:,None] + fg2.astype(np.float32) * (1 - alpha[:,:,None])
+                # The overlap of two 50% transparency should be 25%
+                alpha = alpha_tmp
+                fg = fg.astype(np.uint8)
+            multi_fg = True
+
+        if np.random.rand() < 0.25:
+            # fg = cv2.resize(fg, (640, 640), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
+            # alpha = cv2.resize(alpha, (640, 640), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
+            fg = cv2.resize(fg, (1280, 1280), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
+            alpha = cv2.resize(alpha, (1280, 1280), interpolation=maybe_random_interp(cv2.INTER_NEAREST))
+
+        return fg, alpha, multi_fg
+
+    def __len__(self):
+        if self.phase == "train":
+            return len(self.bg)
+        else:
+            return len(self.alpha)
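+
+# Illustrative usage sketch (not executed here); `data` is assumed to be a namespace-like
+# object exposing lists of file paths: .fg / .bg / .alpha for training, or
+# .merged / .trimap / .alpha for evaluation.
+#
+#   train_set = DataGenerator(data, phase='train', crop_size=512,
+#                             with_bbox=True, bbox_offset_factor=0.1)
+#   sample = train_set[0]  # dict with 'image', 'alpha', 'trimap', 'mask', 'bbox', ...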
+
+
+class ResziePad(object):
+
+    def __init__(self, target_size=1024):
+        self.target_size = target_size
+
+    def __call__(self, sample):
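+        # Resize so the longer side equals target_size (aspect ratio preserved), then
+        # zero-pad the bottom/right into a square target_size x target_size canvas.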
+        _, H, W = sample['image'].shape
+
+        scale = self.target_size * 1.0 / max(H, W)
+        new_H, new_W = H * scale, W * scale
+        new_W = int(new_W + 0.5)
+        new_H = int(new_H + 0.5)
+
+        choice = {'image', 'trimap', 'alpha'} if 'trimap' in sample.keys() else {'image', 'alpha'}
+        for key in choice:
+            if key in {'image', 'trimap'}:
+                sample[key] = F.interpolate(sample[key][None], size=(new_H, new_W), mode='bilinear', align_corners=False)[0]
+            else:
+                # sample[key] = F.interpolate(sample[key][None], size=(new_H, new_W), mode='nearest')[0]
+                sample[key] = F.interpolate(sample[key][None], size=(new_H, new_W), mode='bilinear', align_corners=False)[0]
+            padding = torch.zeros([sample[key].shape[0], self.target_size, self.target_size], dtype=sample[key].dtype, device=sample[key].device)
+            padding[:, : new_H, : new_W] = sample[key]
+            sample[key] = padding
+
+        return sample
+    
+
+class Cv2ResziePad(object):
+
+    def __init__(self, target_size=1024):
+        self.target_size = target_size
+
+    def __call__(self, sample):
+        H, W, _ = sample['image'].shape
+
+        scale = self.target_size * 1.0 / max(H, W)
+        new_H, new_W = H * scale, W * scale
+        new_W = int(new_W + 0.5)
+        new_H = int(new_H + 0.5)
+
+        choice = {'image', 'trimap', 'alpha'} if 'trimap' in sample.keys() and sample['trimap'] is not None else {'image', 'alpha'}
+        for key in choice:
+            sample[key] = cv2.resize(sample[key], (new_W, new_H), interpolation=cv2.INTER_LINEAR)  # cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC
+
+            if key == 'image':
+                padding = np.zeros([self.target_size, self.target_size, sample[key].shape[-1]], dtype=sample[key].dtype)
+                padding[: new_H, : new_W, :] = sample[key]
+                sample[key] = padding
+                sample[key] = sample[key][:, :, ::-1].transpose((2, 0, 1)).astype(np.float32) #/ 255.0
+            else:
+                padding = np.zeros([self.target_size, self.target_size], dtype=sample[key].dtype)
+                padding[: new_H, : new_W] = sample[key]
+                sample[key] = padding
+                sample[key] = sample[key][None].astype(np.float32)
+            sample[key] = torch.from_numpy(sample[key])
+
+        return sample
+    
+
+class AdobeCompositionTest(Dataset):
+    def __init__(self, data_dir, target_size=1024, multi_fg=None):
+        self.data_dir = data_dir
+        self.file_names = sorted(os.listdir(os.path.join(self.data_dir, 'merged')))
+        
+        test_trans = [
+            ResziePad(target_size=target_size),
+            GenBBox(bbox_offset_factor=0)
+        ]
+        self.transform = transforms.Compose(test_trans)
+        self.multi_fg = multi_fg
+
+    def __len__(self):  # 1000
+        return len(self.file_names)
+
+    def __getitem__(self, idx):
+        phas = Image.open(os.path.join(self.data_dir, 'alpha_copy', self.file_names[idx])).convert('L')
+        tris = Image.open(os.path.join(self.data_dir, 'trimaps', self.file_names[idx]))
+        imgs = Image.open(os.path.join(self.data_dir, 'merged', self.file_names[idx]))
+        sample = {
+            'ori_h_w': (imgs.size[1], imgs.size[0]),
+            'data_type': 'Adobe'
+        }
+
+        sample['alpha'] = torchvision.transforms.functional.to_tensor(phas)  # [1, H, W] 0.0 ~ 1.0
+        sample['trimap'] = torchvision.transforms.functional.to_tensor(tris) * 255.0
+        sample['image'] = torchvision.transforms.functional.to_tensor(imgs)
+        sample['image_name'] = 'Adobe_' + self.file_names[idx]
+
+        sample = self.transform(sample)
+        sample['trimap'][sample['trimap'] < 85] = 0
+        sample['trimap'][sample['trimap'] >= 170] = 1
+        sample['trimap'][sample['trimap'] >= 85] = 0.5
+
+        if self.multi_fg is not None:
+            sample['multi_fg'] = torch.tensor(self.multi_fg)
+
+        return sample
+
+
+class SIMTest(Dataset):
+    def __init__(self, data_dir, target_size=1024, multi_fg=None):
+        self.data_dir = data_dir
+        self.file_names = sorted(glob.glob(os.path.join(*[data_dir, '*', 'alpha', '*'])))  # [: 10]
+        test_trans = [
+            ResziePad(target_size=target_size),
+            GenBBox(bbox_offset_factor=0)
+        ]
+        self.transform = transforms.Compose(test_trans)
+        self.multi_fg = multi_fg
+
+    def __len__(self):  # 1000
+        return len(self.file_names)
+
+    def __getitem__(self, idx):
+        phas = Image.open(self.file_names[idx]).convert('L')
+        # tris = Image.open(self.file_names[idx].replace('alpha', 'trimap'))
+        imgs = Image.open(self.file_names[idx].replace('alpha', 'merged'))
+        sample = {
+            'ori_h_w': (imgs.size[1], imgs.size[0]),
+            'data_type': 'SIM'
+        }
+
+        sample['alpha'] = torchvision.transforms.functional.to_tensor(phas)  # [1, H, W] 0.0 ~ 1.0
+        # sample['trimap'] = torchvision.transforms.functional.to_tensor(tris) * 255.0
+        sample['image'] = torchvision.transforms.functional.to_tensor(imgs)
+        sample['image_name'] = 'SIM_{}_{}'.format(self.file_names[idx].split('/')[-3], self.file_names[idx].split('/')[-1])
+
+        sample = self.transform(sample)
+        # sample['trimap'][sample['trimap'] < 85] = 0
+        # sample['trimap'][sample['trimap'] >= 170] = 1
+        # sample['trimap'][sample['trimap'] >= 85] = 0.5
+
+        if self.multi_fg is not None:
+            sample['multi_fg'] = torch.tensor(self.multi_fg)
+
+        return sample
+    
+
+class RW100Test(Dataset):
+    def __init__(self, data_dir, target_size=1024, multi_fg=None):
+        self.data_dir = data_dir
+        self.file_names = sorted(glob.glob(os.path.join(*[data_dir, 'mask', '*'])))
+
+        self.name_to_idx = dict()
+        for idx, file_name in enumerate(self.file_names):
+            self.name_to_idx[file_name.split('/')[-1].split('.')[0]] = idx
+            
+        test_trans = [
+            ResziePad(target_size=target_size),
+            GenBBox(bbox_offset_factor=0, train_or_test='test', dataset_type='RW100')
+        ]
+        self.transform = transforms.Compose(test_trans)
+        self.multi_fg = multi_fg
+
+    def __len__(self):
+        return len(self.file_names)
+
+    def __getitem__(self, idx):
+        phas = Image.open(self.file_names[idx]).convert('L')
+        imgs = Image.open(self.file_names[idx].replace('mask', 'image')[:-6] + '.jpg')
+        sample = {
+            'ori_h_w': (imgs.size[1], imgs.size[0]),
+            'data_type': 'RW100'
+        }
+
+        sample['alpha'] = torchvision.transforms.functional.to_tensor(phas)  # [1, H, W] 0.0 ~ 1.0
+        sample['image'] = torchvision.transforms.functional.to_tensor(imgs)
+        sample['image_name'] = 'RW100_' + self.file_names[idx].split('/')[-1]
+        
+        sample = self.transform(sample)
+
+        if self.multi_fg is not None:
+            sample['multi_fg'] = torch.tensor(self.multi_fg)
+
+        return sample
+    
+    
+class AIM500Test(Dataset):
+    def __init__(self, data_dir, target_size=1024, multi_fg=None):
+        self.data_dir = data_dir
+        self.file_names = sorted(glob.glob(os.path.join(*[data_dir, 'original', '*'])))
+
+        self.name_to_idx = dict()
+        for idx, file_name in enumerate(self.file_names):
+            self.name_to_idx[file_name.split('/')[-1].split('.')[0]] = idx
+
+        test_trans = [
+            ResziePad(target_size=target_size),
+            GenBBox(bbox_offset_factor=0)
+        ]
+        self.transform = transforms.Compose(test_trans)
+        self.multi_fg = multi_fg
+
+    def __len__(self):
+        return len(self.file_names)
+
+    def __getitem__(self, idx):
+        phas = Image.open(self.file_names[idx].replace('original', 'mask').replace('jpg', 'png')).convert('L')
+        # tris = Image.open(self.file_names[idx].replace('original', 'trimap').replace('jpg', 'png')).convert('L')
+        imgs = Image.open(self.file_names[idx])
+        sample = {
+            'ori_h_w': (imgs.size[1], imgs.size[0]),
+            'data_type': 'AIM500'
+        }
+
+        sample['alpha'] = torchvision.transforms.functional.to_tensor(phas)  # [1, H, W] 0.0 ~ 1.0
+        # sample['trimap'] = torchvision.transforms.functional.to_tensor(tris) * 255.0
+        sample['image'] = torchvision.transforms.functional.to_tensor(imgs)
+        sample['image_name'] = 'AIM500_' + self.file_names[idx].split('/')[-1]
+
+        sample = self.transform(sample)
+        # sample['trimap'][sample['trimap'] < 85] = 0
+        # sample['trimap'][sample['trimap'] >= 170] = 1
+        # sample['trimap'][sample['trimap'] >= 85] = 0.5
+
+        if self.multi_fg is not None:
+            sample['multi_fg'] = torch.tensor(self.multi_fg)
+
+        return sample
+
+
+class RWP636Test(Dataset):
+    def __init__(self, data_dir, target_size=1024, multi_fg=None):
+        self.data_dir = data_dir
+        self.file_names = sorted(glob.glob(os.path.join(*[data_dir, 'image', '*'])))
+
+        self.name_to_idx = dict()
+        for idx, file_name in enumerate(self.file_names):
+            self.name_to_idx[file_name.split('/')[-1].split('.')[0]] = idx
+
+        test_trans = [
+            ResziePad(target_size=target_size),
+            GenBBox(bbox_offset_factor=0)
+        ]
+        self.transform = transforms.Compose(test_trans)
+        self.multi_fg = multi_fg
+
+    def __len__(self):
+        return len(self.file_names)
+
+    def __getitem__(self, idx):
+        phas = Image.open(self.file_names[idx].replace('image', 'alpha').replace('jpg', 'png')).convert('L')
+        imgs = Image.open(self.file_names[idx])
+        sample = {
+            'ori_h_w': (imgs.size[1], imgs.size[0]),
+            'data_type': 'RWP636'
+        }
+
+        sample['alpha'] = torchvision.transforms.functional.to_tensor(phas)  # [1, H, W] 0.0 ~ 1.0
+        sample['image'] = torchvision.transforms.functional.to_tensor(imgs)
+        sample['image_name'] = 'RWP636_' + self.file_names[idx].split('/')[-1]
+
+        sample = self.transform(sample)
+
+        if self.multi_fg is not None:
+            sample['multi_fg'] = torch.tensor(self.multi_fg)
+
+        return sample
+
+
+class AM2KTest(Dataset):
+    def __init__(self, data_dir, target_size=1024, multi_fg=None):
+        self.data_dir = data_dir
+        self.file_names = sorted(glob.glob(os.path.join(*[data_dir, 'validation/original', '*'])))
+        test_trans = [
+            ResziePad(target_size=target_size),
+            GenBBox(bbox_offset_factor=0)
+        ]
+        self.transform = transforms.Compose(test_trans)
+        self.multi_fg = multi_fg
+
+    def __len__(self):
+        return len(self.file_names)
+
+    def __getitem__(self, idx):
+        phas = Image.open(self.file_names[idx].replace('original', 'mask').replace('jpg', 'png')).convert('L')
+        # tris = Image.open(self.file_names[idx].replace('original', 'trimap').replace('jpg', 'png')).convert('L')
+        imgs = Image.open(self.file_names[idx])
+        sample = {
+            'ori_h_w': (imgs.size[1], imgs.size[0]),
+            'data_type': 'AM2K'
+        }
+
+        sample['alpha'] = torchvision.transforms.functional.to_tensor(phas)  # [1, H, W] 0.0 ~ 1.0
+        # sample['trimap'] = torchvision.transforms.functional.to_tensor(tris) * 255.0
+        sample['image'] = torchvision.transforms.functional.to_tensor(imgs)
+        sample['image_name'] = 'AM2K_' + self.file_names[idx].split('/')[-1]
+
+        sample = self.transform(sample)
+        # sample['trimap'][sample['trimap'] < 85] = 0
+        # sample['trimap'][sample['trimap'] >= 170] = 1
+        # sample['trimap'][sample['trimap'] >= 85] = 0.5
+
+        if self.multi_fg is not None:
+            sample['multi_fg'] = torch.tensor(self.multi_fg)
+
+        return sample
+
+
+class P3M500Test(Dataset):
+    def __init__(self, data_dir, target_size=1024, multi_fg=None):
+        self.data_dir = data_dir
+        self.file_names = sorted(glob.glob(os.path.join(*[data_dir, 'original_image', '*'])))
+
+        self.name_to_idx = dict()
+        for idx, file_name in enumerate(self.file_names):
+            self.name_to_idx[file_name.split('/')[-1].split('.')[0]] = idx
+
+        test_trans = [
+            ResziePad(target_size=target_size),
+            GenBBox(bbox_offset_factor=0)
+        ]
+        self.transform = transforms.Compose(test_trans)
+        self.multi_fg = multi_fg
+
+    def __len__(self):
+        return len(self.file_names)
+
+    def __getitem__(self, idx):
+        phas = Image.open(self.file_names[idx].replace('original_image', 'mask').replace('jpg', 'png')).convert('L')
+        # tris = Image.open(self.file_names[idx].replace('original_image', 'trimap').replace('jpg', 'png')).convert('L')
+        imgs = Image.open(self.file_names[idx])
+        sample = {
+            'ori_h_w': (imgs.size[1], imgs.size[0]),
+            'data_type': 'P3M500'
+        }
+
+        sample['alpha'] = torchvision.transforms.functional.to_tensor(phas)  # [1, H, W] 0.0 ~ 1.0
+        # sample['trimap'] = torchvision.transforms.functional.to_tensor(tris) * 255.0
+        sample['image'] = torchvision.transforms.functional.to_tensor(imgs)
+        sample['image_name'] = 'P3M500_' + self.file_names[idx].split('/')[-1]
+
+        sample = self.transform(sample)
+        # sample['trimap'][sample['trimap'] < 85] = 0
+        # sample['trimap'][sample['trimap'] >= 170] = 1
+        # sample['trimap'][sample['trimap'] >= 85] = 0.5
+
+        if self.multi_fg is not None:
+            sample['multi_fg'] = torch.tensor(self.multi_fg)
+
+        return sample
+
+
+class MattingTest(Dataset):
+    def __init__(
+        self, 
+        data_type,
+        data_dir,
+        image_sub_path,
+        alpha_sub_path,
+        trimpa_sub_path=None,
+        target_size=1024, 
+        multi_fg=None,
+    ):
+        self.data_type = data_type
+        self.data_dir = data_dir
+
+        self.image_paths = sorted(glob.glob(os.path.join(*[data_dir, image_sub_path])))
+        self.alpha_paths = sorted(glob.glob(os.path.join(*[data_dir, alpha_sub_path])))
+        self.trimpa_paths = sorted(glob.glob(os.path.join(*[data_dir, trimpa_sub_path]))) if trimpa_sub_path is not None else None
+
+        self.name_to_idx = dict()
+        for idx, file_name in enumerate(self.image_paths):
+            self.name_to_idx[file_name.split('/')[-1].split('.')[0]] = idx
+
+        test_trans = [
+            Cv2ResziePad(target_size=target_size),
+            GenBBox(bbox_offset_factor=0)
+        ]
+        self.transform = transforms.Compose(test_trans)
+        self.multi_fg = multi_fg
+
+    def __len__(self):  # 1000
+        return len(self.image_paths)
+
+    def __getitem__(self, idx):
+
+        img = cv2.imread(self.image_paths[idx])
+        sample = {
+            'image': img.astype(np.float32) / 255,
+            'alpha': cv2.imread(self.alpha_paths[idx], 0).astype(np.float32) / 255,
+            'trimap': cv2.imread(self.trimpa_paths[idx], 0) if self.trimpa_paths is not None else None,
+            'ori_h_w': (img.shape[0], img.shape[1]),
+            'data_type': self.data_type,
+            'image_name': self.data_type + '_' + self.image_paths[idx].split('/')[-1]
+        }
+
+        sample = self.transform(sample)
+        if self.trimpa_paths is not None:
+            sample['trimap'][sample['trimap'] < 85] = 0
+            sample['trimap'][sample['trimap'] >= 170] = 1
+            sample['trimap'][sample['trimap'] >= 85] = 0.5
+        else:
+            del sample['trimap']
+
+        if self.multi_fg is not None:
+            sample['multi_fg'] = torch.tensor(self.multi_fg)
+
+        return sample
+
+
+def adobe_composition_collate_fn(batch):
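+    # stack per-sample tensors along a new batch dimension; non-tensor fields
+    # (e.g. image_name / data_type strings) stay as plain lists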
+    new_batch = defaultdict(list)
+    for sub_batch in batch:
+        for key in sub_batch.keys():
+            new_batch[key].append(sub_batch[key])
+    for key in new_batch: 
+        if isinstance(new_batch[key][0], torch.Tensor):
+            new_batch[key] = torch.stack(new_batch[key])
+    return dict(new_batch)
+
+
+def build_d2_test_dataloader(
+    dataset,
+    mapper=None,
+    total_batch_size=None,
+    local_batch_size=None,
+    num_workers=0,
+    collate_fn=None
+):
+
+    assert (total_batch_size is None) != (
+        local_batch_size is None
+    ), "Exactly one of total_batch_size or local_batch_size must be specified"
+
+    world_size = comm.get_world_size()
+
+    if total_batch_size is not None:
+        assert (
+            total_batch_size > 0 and total_batch_size % world_size == 0
+        ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
+            total_batch_size, world_size
+        )
+        batch_size = total_batch_size // world_size
+
+    if local_batch_size is not None:
+        batch_size = local_batch_size
+
+    logger = logging.getLogger(__name__)
+    if batch_size != 1:
+        logger.warning(
+            "When testing, batch size is set to 1. "
+            "This is the only mode that is supported for d2."
+        )
+
+    return build_detection_test_loader(
+        dataset=dataset,
+        mapper=mapper,
+        sampler=None,
+        num_workers=num_workers,
+        collate_fn=collate_fn,
+    )
+
+
+class AdobeCompositionEvaluator(DatasetEvaluator):
+
+    def __init__(
+        self, 
+        save_eval_results_step=-1, 
+        output_dir=None, 
+        eval_dataset_type=['Adobe'],
+        distributed=True,
+        eval_w_sam_hq_mask = False,
+    ):  
+
+        self.save_eval_results_step = save_eval_results_step
+        self.output_dir = output_dir
+        self.eval_index = 0
+        self.eval_dataset_type = eval_dataset_type
+        self.eval_w_sam_hq_mask = eval_w_sam_hq_mask
+
+        self._distributed = distributed
+        self._logger = logging.getLogger(__name__)
+
+    def reset(self):
+        self.eval_metric = dict()
+        for i in self.eval_dataset_type:
+            self.eval_metric[i + '_MSE'] = []
+            self.eval_metric[i + '_SAD'] = []
+            self.eval_metric[i + '_MAD'] = []
+            self.eval_metric[i + '_Grad'] = []
+            self.eval_metric[i + '_Conn'] = []
+
+        if self.output_dir is not None:
+            os.makedirs(self.output_dir, exist_ok=True)
+
+    def process(self, inputs, outputs):
+        """
+        Args:
+            inputs: {'alpha', 'trimap', 'image', 'bbox', 'image_name'}
+            outputs: [1, 1, H, W] 0. ~ 1.
+        """
+
+        # crop the black pad area
+        assert inputs['image'].shape[-1] == inputs['image'].shape[-2] == 1024 and len(inputs['ori_h_w']) == 1
+        inputs['ori_h_w'] = inputs['ori_h_w'][0]
+        before_pad_h, before_pad_w = int(1024 / max(inputs['ori_h_w']) * inputs['ori_h_w'][0] + 0.5), int(1024 / max(inputs['ori_h_w']) * inputs['ori_h_w'][1] + 0.5)
+        inputs['image'] = inputs['image'][:, :, :before_pad_h, :before_pad_w]
+        inputs['alpha'] = inputs['alpha'][:, :, :before_pad_h, :before_pad_w]
+
+        if self.eval_w_sam_hq_mask:
+            outputs, samhq_low_res_masks = outputs[0][:, :, :before_pad_h, :before_pad_w], outputs[1][:, :, :before_pad_h, :before_pad_w]
+            pred_alpha, label_alpha, samhq_low_res_masks = outputs.cpu().numpy(), inputs['alpha'].numpy(), (samhq_low_res_masks > 0).float().cpu()
+        else:
+            outputs = outputs[:, :, :before_pad_h, :before_pad_w]
+            pred_alpha, label_alpha = outputs.cpu().numpy(), inputs['alpha'].numpy()
+
+        # if 'trimap' in inputs.keys():
+        #     inputs['trimap'] = inputs['trimap'][:, :, :before_pad_h, :before_pad_w]
+        #     trimap = inputs['trimap'].numpy()
+        #     assert np.max(trimap) <= 1 and len(np.unique(trimap)) <= 3
+        #     sad_loss_unknown = compute_sad_loss(pred_alpha, label_alpha, trimap, area='unknown')
+        #     mse_loss_unknown = compute_mse_loss(pred_alpha, label_alpha, trimap, area='unknown')
+
+        #     self.eval_metric[inputs['data_type'][0] + '_unknown_mse (1e-3)'].append(mse_loss_unknown)
+        #     self.eval_metric[inputs['data_type'][0] + '_unknown_sad (1e3)'].append(sad_loss_unknown)
+
+        # calculate loss
+        assert np.max(pred_alpha) <= 1 and np.max(label_alpha) <= 1
+        eval_pred = np.uint8(pred_alpha[0, 0] * 255.0 + 0.5) * 1.0
+        eval_gt = label_alpha[0, 0] * 255.0
+
+        detailmap = np.zeros_like(eval_gt) + 128
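+        # an all-128 detailmap makes the metrics below cover the whole image rather than
+        # only a trimap unknown band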
+        mse_loss_ = compute_mse_loss(eval_pred, eval_gt, detailmap)
+        sad_loss_ = compute_sad_loss(eval_pred, eval_gt, detailmap)[0]
+        mad_loss_ = compute_mad_loss(eval_pred, eval_gt, detailmap)
+        grad_loss_ = compute_gradient_loss(eval_pred, eval_gt, detailmap)
+        conn_loss_ = compute_connectivity_error(eval_pred, eval_gt, detailmap)
+
+        self.eval_metric[inputs['data_type'][0] + '_MSE'].append(mse_loss_)
+        self.eval_metric[inputs['data_type'][0] + '_SAD'].append(sad_loss_)
+        self.eval_metric[inputs['data_type'][0] + '_MAD'].append(mad_loss_)
+        self.eval_metric[inputs['data_type'][0] + '_Grad'].append(grad_loss_)
+        self.eval_metric[inputs['data_type'][0] + '_Conn'].append(conn_loss_)
+
+        # vis results
+        if self.save_eval_results_step != -1 and self.eval_index % self.save_eval_results_step == 0:
+            if self.eval_w_sam_hq_mask:
+                self.save_vis_results(inputs, pred_alpha, samhq_low_res_masks)
+            else:
+                self.save_vis_results(inputs, pred_alpha)
+        self.eval_index += 1
+
+    def save_vis_results(self, inputs, pred_alpha, samhq_low_res_masks=None):
+
+        # image
+        image = inputs['image'][0].permute(1, 2, 0) * 255.0
+        l, u, r, d = int(inputs['bbox'][0, 0, 0].item()), int(inputs['bbox'][0, 0, 1].item()), int(inputs['bbox'][0, 0, 2].item()), int(inputs['bbox'][0, 0, 3].item())
+        red_line = torch.tensor([[255., 0., 0.]], device=image.device, dtype=image.dtype)
+        image[u: d, l, :] = red_line
+        image[u: d, r, :] = red_line
+        image[u, l: r, :] = red_line
+        image[d, l: r, :] = red_line
+        image = np.uint8(image.numpy())
+
+        # trimap, pred_alpha, label_alpha
+        save_results = [image]
+
+        choice = [inputs['trimap'], torch.from_numpy(pred_alpha), inputs['alpha']] if 'trimap' in inputs.keys() else [torch.from_numpy(pred_alpha), inputs['alpha']]
+        for val in choice:
+            val = val[0].permute(1, 2, 0).repeat(1, 1, 3) * 255.0 + 0.5  # +0.5 and int() = round()
+            val = np.uint8(val.numpy())
+            save_results.append(val)
+
+        if samhq_low_res_masks is not None:
+            save_results.append(np.uint8(samhq_low_res_masks[0].permute(1, 2, 0).repeat(1, 1, 3).numpy() * 255.0))
+
+        save_results = np.concatenate(save_results, axis=1)
+        save_name = os.path.join(self.output_dir, inputs['image_name'][0])
+        Image.fromarray(save_results).save(save_name.replace('.jpg', '.png'))
+
+    def evaluate(self):
+        
+        if self._distributed:
+            comm.synchronize()
+            eval_metric = comm.gather(self.eval_metric, dst=0)
+
+            if not comm.is_main_process():
+                return {}
+            
+            merges_eval_metric = defaultdict(list)
+            for sub_eval_metric in eval_metric:
+                for key, val in sub_eval_metric.items():
+                    merges_eval_metric[key] += val
+            eval_metric = merges_eval_metric
+
+        else:
+            eval_metric = self.eval_metric
+
+        eval_results = {}
+
+        for key, val in eval_metric.items():
+            if len(val) != 0:
+                # if 'mse' in key:
+                #     eval_results[key] = np.array(val).mean() * 1e3
+                # else:
+                #     assert 'sad' in key
+                #     eval_results[key] = np.array(val).mean() / 1e3
+                eval_results[key] = np.array(val).mean()
+
+        return eval_results
+
+
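+# Illustrative evaluation sketch (not executed here); the data path is a placeholder:
+#
+#   dataset = AIM500Test(data_dir='/path/to/AIM-500', target_size=1024)
+#   loader = build_d2_test_dataloader(dataset, local_batch_size=1, num_workers=2,
+#                                     collate_fn=adobe_composition_collate_fn)
+#   evaluator = AdobeCompositionEvaluator(output_dir='./eval_vis',
+#                                         eval_dataset_type=['AIM500'], distributed=False)
+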
+if __name__ == '__main__':
+    pass
\ No newline at end of file
diff --git a/data/evaluate.py b/data/evaluate.py
new file mode 100644
index 0000000000000000000000000000000000000000..53147cd3b3d2e6a14b2f4f90ed768d3bb9e5c0fd
--- /dev/null
+++ b/data/evaluate.py
@@ -0,0 +1,102 @@
+import scipy.ndimage
+import numpy as np
+from skimage.measure import label
+
+
+def gauss(x, sigma):
+    y = np.exp(-x ** 2 / (2 * sigma ** 2)) / (sigma * np.sqrt(2 * np.pi))
+    return y
+
+
+def dgauss(x, sigma):
+    y = -x * gauss(x, sigma) / (sigma ** 2)
+    return y
+
+
+def gaussgradient(im, sigma):
+    epsilon = 1e-2
+    halfsize = np.ceil(sigma * np.sqrt(-2 * np.log(np.sqrt(2 * np.pi) * sigma * epsilon))).astype(np.int32)
+    size = 2 * halfsize + 1
+    hx = np.zeros((size, size))
+    for i in range(0, size):
+        for j in range(0, size):
+            u = [i - halfsize, j - halfsize]
+            hx[i, j] = gauss(u[0], sigma) * dgauss(u[1], sigma)
+
+    hx = hx / np.sqrt(np.sum(np.abs(hx) * np.abs(hx)))
+    hy = hx.transpose()
+
+    gx = scipy.ndimage.convolve(im, hx, mode='nearest')
+    gy = scipy.ndimage.convolve(im, hy, mode='nearest')
+
+    return gx, gy
+
+
+def compute_gradient_loss(pred, target, trimap):
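+    # Gradient error: filter both alphas with first-order derivative-of-Gaussian kernels
+    # (sigma = 1.4) and sum the squared difference of the gradient magnitudes over the
+    # unknown (trimap == 128) region; reported divided by 1e3.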
+
+    pred = pred / 255.0
+    target = target / 255.0
+
+    pred_x, pred_y = gaussgradient(pred, 1.4)
+    target_x, target_y = gaussgradient(target, 1.4)
+
+    pred_amp = np.sqrt(pred_x ** 2 + pred_y ** 2)
+    target_amp = np.sqrt(target_x ** 2 + target_y ** 2)
+
+    error_map = (pred_amp - target_amp) ** 2
+    loss = np.sum(error_map[trimap == 128])
+
+    return loss / 1000.
+
+
+def getLargestCC(segmentation):
+    labels = label(segmentation, connectivity=1)
+    largestCC = labels == np.argmax(np.bincount(labels.flat))
+    return largestCC
+
+
+def compute_connectivity_error(pred, target, trimap, step=0.1):
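+    # Connectivity error (in the spirit of Rhemann et al., 2009): sweep thresholds, track
+    # the largest connected component shared by the thresholded prediction and target, and
+    # penalise pixels whose degree of connectedness differs, summed over the unknown region
+    # and divided by 1e3.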
+    pred = pred / 255.0
+    target = target / 255.0
+    h, w = pred.shape
+
+    thresh_steps = list(np.arange(0, 1 + step, step))
+    l_map = np.ones_like(pred, dtype=np.float32) * -1
+    for i in range(1, len(thresh_steps)):
+        pred_alpha_thresh = (pred >= thresh_steps[i]).astype(np.int32)
+        target_alpha_thresh = (target >= thresh_steps[i]).astype(np.int32)
+
+        omega = getLargestCC(pred_alpha_thresh * target_alpha_thresh).astype(np.int32)
+        flag = ((l_map == -1) & (omega == 0)).astype(np.int32)
+        l_map[flag == 1] = thresh_steps[i - 1]
+
+    l_map[l_map == -1] = 1
+
+    pred_d = pred - l_map
+    target_d = target - l_map
+    pred_phi = 1 - pred_d * (pred_d >= 0.15).astype(np.int32)
+    target_phi = 1 - target_d * (target_d >= 0.15).astype(np.int32)
+    loss = np.sum(np.abs(pred_phi - target_phi)[trimap == 128])
+
+    return loss / 1000.
+
+
+def compute_mse_loss(pred, target, trimap):
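+    # mean squared error over the unknown (trimap == 128) region, alphas rescaled to 0..1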
+    error_map = (pred - target) / 255.0
+    loss = np.sum((error_map ** 2) * (trimap == 128)) / (np.sum(trimap == 128) + 1e-8)
+
+    return loss
+
+
+def compute_sad_loss(pred, target, trimap):
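+    # sum of absolute differences over the unknown region (alphas rescaled to 0..1);
+    # returns the SAD and the unknown-pixel count, both divided by 1e3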
+    error_map = np.abs((pred - target) / 255.0)
+    loss = np.sum(error_map * (trimap == 128))
+
+    return loss / 1000, np.sum(trimap == 128) / 1000
+
+def compute_mad_loss(pred, target, trimap):
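+    # mean absolute difference over the unknown region, alphas rescaled to 0..1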
+    error_map = np.abs((pred - target) / 255.0)
+    loss = np.sum(error_map * (trimap == 128)) / (np.sum(trimap == 128) + 1e-8)
+
+    return loss
diff --git a/data/p3m10k_dataset.py b/data/p3m10k_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..70806ac92a87c1dd30092f5f5eb9e92e5a95a1e5
--- /dev/null
+++ b/data/p3m10k_dataset.py
@@ -0,0 +1,325 @@
+import os
+import torch
+import numpy as np
+import cv2
+from torch.utils.data import Dataset
+from torchvision import transforms
+import math
+import torch.nn.functional as F
+
+
+class GenBBox(object):
+    def __init__(self, bbox_offset_factor = 0.1, random_crop_bbox = None, train_or_test = 'train', dataset_type = None, random_auto_matting=None):
+        self.bbox_offset_factor = bbox_offset_factor
+        self.random_crop_bbox = random_crop_bbox
+        self.train_or_test = train_or_test
+        self.dataset_type = dataset_type
+        self.random_auto_matting = random_auto_matting
+
+    def __call__(self, sample):
+
+        alpha = sample['alpha']  # [1, H, W] 0.0 ~ 1.0
+        indices = torch.nonzero(alpha[0], as_tuple=True)
+
+        if len(indices[0]) > 0:
+
+            min_x, min_y = torch.min(indices[1]), torch.min(indices[0])
+            max_x, max_y = torch.max(indices[1]), torch.max(indices[0])
+
+            if self.random_crop_bbox is not None and np.random.uniform(0, 1) < self.random_crop_bbox:
+                ori_h_w = (sample['alpha'].shape[-2], sample['alpha'].shape[-1])
+                sample['alpha'] = F.interpolate(sample['alpha'][None, :, min_y: max_y + 1, min_x: max_x + 1], size=ori_h_w, mode='bilinear', align_corners=False)[0]
+                sample['image'] = F.interpolate(sample['image'][None, :, min_y: max_y + 1, min_x: max_x + 1], size=ori_h_w, mode='bilinear', align_corners=False)[0]
+                sample['trimap'] = F.interpolate(sample['trimap'][None, :, min_y: max_y + 1, min_x: max_x + 1], size=ori_h_w, mode='nearest')[0]
+                bbox = torch.tensor([[0, 0, ori_h_w[1] - 1, ori_h_w[0] - 1]])
+
+            elif self.bbox_offset_factor != 0:
+                bbox_w = max(1, max_x - min_x)
+                bbox_h = max(1, max_y - min_y)
+                offset_w = math.ceil(self.bbox_offset_factor * bbox_w)
+                offset_h = math.ceil(self.bbox_offset_factor * bbox_h)
+
+                min_x = max(0, min_x + np.random.randint(-offset_w, offset_w))
+                max_x = min(alpha.shape[2] - 1, max_x + np.random.randint(-offset_w, offset_w))
+                min_y = max(0, min_y + np.random.randint(-offset_h, offset_h))
+                max_y = min(alpha.shape[1] - 1, max_y + np.random.randint(-offset_h, offset_h))
+                bbox = torch.tensor([[min_x, min_y, max_x, max_y]])
+            else:
+                bbox = torch.tensor([[min_x, min_y, max_x, max_y]])
+            
+            if self.random_auto_matting is not None and np.random.uniform(0, 1) < self.random_auto_matting:
+                bbox = torch.tensor([[0, 0, alpha.shape[2] - 1, alpha.shape[1] - 1]])
+
+        else:
+            bbox = torch.zeros(1, 4)
+
+        sample['bbox'] = bbox.float()
+        return sample
+
+def random_interp():
+    return np.random.choice([cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4])
+
+
+class SplitConcatImage(object):
+
+    def __init__(self, concat_num=4, wo_mask_to_mattes=False):
+        self.concat_num = concat_num
+        self.wo_mask_to_mattes = wo_mask_to_mattes
+        if self.wo_mask_to_mattes:
+            assert self.concat_num == 5
+
+    def __call__(self, concat_image):
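+        # The input is a horizontal strip of `concat_num` equal-width panels (input image,
+        # intermediate results, trimap, alpha); when an image path is given the panel count
+        # is inferred from the path, and the image / trimap / alpha columns are picked out.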
+        if isinstance(concat_image, list):
+            concat_image, image_path = concat_image[0], concat_image[1]
+        else:
+            image_path = None
+        H, W, _ = concat_image.shape
+
+        concat_num = self.concat_num
+        if image_path is not None:
+            if '06-14' in image_path:
+                concat_num = 4
+            elif 'ori_mask' in image_path or 'SEMat' in image_path:
+                concat_num = 3
+            else:
+                concat_num = 5
+        
+        assert W % concat_num == 0
+        W = W // concat_num
+
+        image = concat_image[:H, :W]
+        if self.concat_num != 3:
+            trimap = concat_image[:H, (concat_num - 2) * W: (concat_num - 1) * W]
+            if self.wo_mask_to_mattes:
+                alpha = concat_image[:H, 2 * W: 3 * W]
+            else:
+                alpha = concat_image[:H, (concat_num - 1) * W: concat_num * W]
+        else:
+            trimap = concat_image[:H, (concat_num - 1) * W: concat_num * W]
+            alpha = concat_image[:H, (concat_num - 2) * W: (concat_num - 1) * W]
+
+        return {'image': image, 'trimap': trimap, 'alpha': alpha}
+
+
+class RandomHorizontalFlip(object):
+
+    def __init__(self, prob=0.5):
+        self.prob = prob
+
+    def __call__(self, sample):
+        if np.random.uniform(0, 1) < self.prob:
+            for key in sample.keys():
+                sample[key] = cv2.flip(sample[key], 1)
+        return sample
+
+class EmptyAug(object):
+    def __call__(self, sample):
+        return sample
+
+class RandomReszieCrop(object):
+
+    def __init__(self, output_size=1024, aug_scale_min=0.5, aug_scale_max=1.5):
+        self.desired_size = output_size
+        self.aug_scale_min = aug_scale_min
+        self.aug_scale_max = aug_scale_max
+
+    def __call__(self, sample):
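+        # Rescale so the longer side is roughly desired_size * U(aug_scale_min, aug_scale_max),
+        # take a random crop of at most desired_size, and zero-pad the bottom/right up to a
+        # desired_size x desired_size canvas. (1.0, 1.0) keeps the original size and
+        # (-1.0, -1.0) fits the longer side to desired_size, both without random cropping.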
+        H, W, _ = sample['image'].shape
+        sample['trimap'] = sample['trimap'][:, :, None].repeat(3, axis=-1)
+        sample['alpha'] = sample['alpha'][:, :, None].repeat(3, axis=-1)
+
+        if self.aug_scale_min == 1.0 and self.aug_scale_max == 1.0:
+            crop_H, crop_W = H, W
+            crop_y1, crop_y2 = 0, crop_H
+            crop_x1, crop_x2 = 0, crop_W
+            scale_W, scaled_H = W, H
+        elif self.aug_scale_min == -1.0 and self.aug_scale_max == -1.0:
+            scale = min(self.desired_size / H, self.desired_size / W)
+            scaled_H, scale_W = round(H * scale), round(W * scale)
+            crop_H, crop_W = scaled_H, scale_W
+            crop_y1, crop_y2 = 0, crop_H
+            crop_x1, crop_x2 = 0, crop_W
+        else:
+            # random size
+            random_scale = np.random.uniform(0, 1) * (self.aug_scale_max - self.aug_scale_min) + self.aug_scale_min  # random_val: 0.5 ~ 1.5
+            scaled_size = round(random_scale * self.desired_size)
+
+            scale = min(scaled_size / H, scaled_size / W)
+            scaled_H, scale_W = round(H * scale), round(W * scale)
+
+            # random crop
+            crop_H, crop_W = min(self.desired_size, scaled_H), min(self.desired_size, scale_W)  # crop_size
+            margin_H, margin_W = max(scaled_H - crop_H, 0), max(scale_W - crop_W, 0)
+            offset_H, offset_W = np.random.randint(0, margin_H + 1), np.random.randint(0, margin_W + 1)
+            crop_y1, crop_y2 = offset_H, offset_H + crop_H
+            crop_x1, crop_x2 = offset_W, offset_W + crop_W
+
+        for key in sample.keys():
+            sample[key] = cv2.resize(sample[key], (scale_W, scaled_H), interpolation=random_interp())[crop_y1: crop_y2, crop_x1: crop_x2, :]  # resize and crop
+            padding = np.zeros(shape=(self.desired_size, self.desired_size, 3), dtype=sample[key].dtype)  # pad to desired_size
+            padding[: crop_H, : crop_W, :] = sample[key]
+            sample[key] = padding
+
+        return sample
+
+
+class RandomJitter(object):
+    """
+    Randomly jitter the hue, saturation and value of the image
+    """
+
+    def __call__(self, sample):
+
+        image = sample['image']
+
+        # convert to HSV space, convert to float32 image to keep precision during space conversion.
+        image = cv2.cvtColor(image.astype(np.float32)/255.0, cv2.COLOR_BGR2HSV)
+        # Hue noise
+        hue_jitter = np.random.randint(-40, 40)
+        image[:, :, 0] = np.remainder(image[:, :, 0].astype(np.float32) + hue_jitter, 360)
+        # Saturation noise
+        sat_bar = image[:, :, 1].mean()
+
+        sat_jitter = np.random.rand()*(1.1 - sat_bar)/5 - (1.1 - sat_bar) / 10
+        sat = image[:, :, 1]
+        sat = np.abs(sat + sat_jitter)
+        sat[sat>1] = 2 - sat[sat>1]
+        image[:, :, 1] = sat
+        # Value noise
+        val_bar = image[:, :, 2].mean()
+
+        val_jitter = np.random.rand()*(1.1 - val_bar)/5-(1.1 - val_bar) / 10
+        val = image[:, :, 2]
+        val = np.abs(val + val_jitter)
+        val[val>1] = 2 - val[val>1]
+        image[:, :, 2] = val
+        # convert back to BGR space
+        image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
+        sample['image'] = image * 255
+
+        return sample
+
+
+class ToTensor(object):
+
+    def __call__(self, sample):
+        image, alpha, trimap = sample['image'][:, :, ::-1], sample['alpha'], sample['trimap']
+
+        # image
+        image = image.transpose((2, 0, 1)) / 255.
+        sample['image'] = torch.from_numpy(image).float()
+
+        # alpha
+        alpha = alpha.transpose((2, 0, 1))[0: 1] / 255.
+        alpha[alpha < 0 ] = 0
+        alpha[alpha > 1] = 1
+        sample['alpha'] = torch.from_numpy(alpha).float()
+
+        # trimap
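+        # map the 0 / 128 / 255 trimap to 0 (background), 0.5 (unknown), 1 (foreground)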
+        trimap = trimap.transpose((2, 0, 1))[0: 1] / 1.
+        sample['trimap'] = torch.from_numpy(trimap).float()
+        sample['trimap'][sample['trimap'] < 85] = 0
+        sample['trimap'][sample['trimap'] >= 170] = 1
+        sample['trimap'][sample['trimap'] >= 85] = 0.5
+
+        return sample
+    
+
+class GenTrimap(object):
+    def __init__(self):
+        self.erosion_kernels = [None] + [cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size)) for size in range(1,100)]
+
+    def __call__(self, sample):
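+        # Build a pseudo trimap from the alpha: erode the foreground and background with
+        # randomly sized elliptical kernels and mark everything in between as unknown (128).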
+        alpha = sample['alpha']
+        h, w = alpha.shape
+
+        max_kernel_size = max(30, int((min(h,w) / 2048) * 30))
+
+        ### generate trimap
+        fg_mask = (alpha / 255.0 + 1e-5).astype(np.int32).astype(np.uint8)
+        bg_mask = (1 - alpha / 255.0 + 1e-5).astype(np.int32).astype(np.uint8)
+        fg_mask = cv2.erode(fg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
+        bg_mask = cv2.erode(bg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
+
+        trimap = np.ones_like(alpha) * 128
+        trimap[fg_mask == 1] = 255
+        trimap[bg_mask == 1] = 0
+
+        trimap = cv2.resize(trimap, (w,h), interpolation=cv2.INTER_NEAREST)
+        sample['trimap'] = trimap
+
+        return sample
+    
+
+class P3MData(Dataset):
+    def __init__(
+        self, 
+        data_root_path = '/root/data/my_path_b/public_data/data/matting/P3M-10k/train/blurred_image/', 
+        output_size = 1024, 
+        aug_scale_min = 0.8, 
+        aug_scale_max = 1.5,
+        with_bbox = True, 
+        bbox_offset_factor = 0.05,
+        num_ratio = 4.06,  # 9421 * 4.06 = 38249.26 (38251)
+    ):
+        
+        self.data_root_path = data_root_path
+        self.output_size = output_size
+        self.aug_scale_min = aug_scale_min
+        self.aug_scale_max = aug_scale_max
+        self.with_bbox = with_bbox
+        self.bbox_offset_factor = bbox_offset_factor
+        self.num_ratio = num_ratio
+
+        self.image_names = os.listdir(self.data_root_path)
+        self.image_names = [i for i in self.image_names if 'jpg' in i]
+        self.image_names.sort()
+
+        train_trans = [
+            RandomHorizontalFlip(prob=0 if hasattr(self, 'return_image_name') and self.return_image_name else 0.5),
+            GenTrimap(),
+            RandomReszieCrop(self.output_size, self.aug_scale_min, self.aug_scale_max),
+            RandomJitter(),
+            ToTensor(),
+            GenBBox(bbox_offset_factor=self.bbox_offset_factor)
+        ]
+        self.transform = transforms.Compose(train_trans)
+
+    def __getitem__(self, idx):
+
+        if self.num_ratio is not None:
+            if self.num_ratio < 1.0:
+                idx = np.random.randint(0, len(self.image_names))
+            else:
+                idx = idx % len(self.image_names)
+
+        image_path = os.path.join(self.data_root_path, self.image_names[idx])
+        alpha_path = image_path.replace('jpg', 'png').replace('blurred_image', 'mask')
+
+        sample = self.transform({
+            'image': cv2.imread(image_path),
+            'alpha': cv2.imread(alpha_path, 0),
+        })
+
+        sample['dataset_name'] = 'P3M'
+        sample['multi_fg'] = False
+
+        return sample
+
+    def __len__(self):
+        if self.num_ratio is not None:
+            return int(len(self.image_names) * self.num_ratio)
+        else:
+            return len(self.image_names)
+
+
+if __name__ == '__main__':
+
+    dataset = P3MData()
+    data = dataset[0]
+    print(len(dataset))
+    for key, val in data.items():
+        if isinstance(val, torch.Tensor):
+            print(key, val.shape, torch.min(val), torch.max(val), torch.unique(val))
+        else:
+            print(key, val)
\ No newline at end of file
diff --git a/data/rand_augment.py b/data/rand_augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..89b10cf2c348e9cd39b89c272d38ff5479135f67
--- /dev/null
+++ b/data/rand_augment.py
@@ -0,0 +1,196 @@
+# copyright: https://github.com/ildoonet/pytorch-randaugment
+# code in this file is adapted from rpmcruz/autoaugment
+# https://github.com/rpmcruz/autoaugment/blob/master/transformations.py
+# This is a modified version of ildoonet's implementation, for the RandAugment used in FixMatch.
+
+import random
+
+import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+
+
+def AutoContrast(img, _):
+    return PIL.ImageOps.autocontrast(img)
+
+
+def Brightness(img, v):
+    assert v >= 0.0
+    return PIL.ImageEnhance.Brightness(img).enhance(v)
+
+
+def Color(img, v):
+    assert v >= 0.0
+    return PIL.ImageEnhance.Color(img).enhance(v)
+
+
+def Contrast(img, v):
+    assert v >= 0.0
+    return PIL.ImageEnhance.Contrast(img).enhance(v)
+
+
+def Equalize(img, _):
+    return PIL.ImageOps.equalize(img)
+
+
+def Invert(img, _):
+    return PIL.ImageOps.invert(img)
+
+
+def Identity(img, v):
+    return img
+
+
+def Posterize(img, v):  # [4, 8]
+    v = int(v)
+    v = max(1, v)
+    return PIL.ImageOps.posterize(img, v)
+
+
+def Rotate(img, v):  # [-30, 30]
+    #assert -30 <= v <= 30
+    #if random.random() > 0.5:
+    #    v = -v
+    return img.rotate(v)
+
+
+
+def Sharpness(img, v):  # [0.1,1.9]
+    assert v >= 0.0
+    return PIL.ImageEnhance.Sharpness(img).enhance(v)
+
+
+def ShearX(img, v):  # [-0.3, 0.3]
+    #assert -0.3 <= v <= 0.3
+    #if random.random() > 0.5:
+    #    v = -v
+    return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
+
+
+def ShearY(img, v):  # [-0.3, 0.3]
+    #assert -0.3 <= v <= 0.3
+    #if random.random() > 0.5:
+    #    v = -v
+    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
+
+
+def TranslateX(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
+    #assert -0.3 <= v <= 0.3
+    #if random.random() > 0.5:
+    #    v = -v
+    v = v * img.size[0]
+    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
+
+
+def TranslateXabs(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
+    #assert v >= 0.0
+    #if random.random() > 0.5:
+    #    v = -v
+    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
+
+
+def TranslateY(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
+    #assert -0.3 <= v <= 0.3
+    #if random.random() > 0.5:
+    #    v = -v
+    v = v * img.size[1]
+    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
+
+
+def TranslateYabs(img, v):  # [-150, 150] => percentage: [-0.45, 0.45]
+    #assert 0 <= v
+    #if random.random() > 0.5:
+    #    v = -v
+    return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
+
+
+def Solarize(img, v):  # [0, 256]
+    assert 0 <= v <= 256
+    return PIL.ImageOps.solarize(img, v)
+
+
+def Cutout(img, v):  #[0, 60] => percentage: [0, 0.2] => change to [0, 0.5]
+    assert 0.0 <= v <= 0.5
+    if v <= 0.:
+        return img
+
+    v = v * img.size[0]
+    return CutoutAbs(img, v)
+
+
+def CutoutAbs(img, v):  # [0, 60] => percentage: [0, 0.2]
+    # assert 0 <= v <= 20
+    if v < 0:
+        return img
+    w, h = img.size
+    x0 = np.random.uniform(w)
+    y0 = np.random.uniform(h)
+
+    x0 = int(max(0, x0 - v / 2.))
+    y0 = int(max(0, y0 - v / 2.))
+    x1 = min(w, x0 + v)
+    y1 = min(h, y0 + v)
+
+    xy = (x0, y0, x1, y1)
+    color = (125, 123, 114)
+    # color = (0, 0, 0)
+    img = img.copy()
+    PIL.ImageDraw.Draw(img).rectangle(xy, color)
+    return img
+
+    
+def augment_list():  
+    l = [
+        (AutoContrast, 0, 1),
+        (Brightness, 0.05, 0.95),
+        (Color, 0.05, 0.95),
+        (Contrast, 0.05, 0.95),
+        (Equalize, 0, 1),
+        (Identity, 0, 1),
+        (Posterize, 4, 8),
+        # (Rotate, -30, 30),
+        (Sharpness, 0.05, 0.95),
+        # (ShearX, -0.3, 0.3),
+        # (ShearY, -0.3, 0.3),
+        (Solarize, 0, 256),
+        # (TranslateX, -0.3, 0.3),
+        # (TranslateY, -0.3, 0.3)
+    ]
+    return l
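+
+# Note: each op selected in RandAugment.__call__ below is applied with a magnitude drawn
+# uniformly from its (min_val, max_val) range listed above.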
+
+    
+class RandAugment:
+    def __init__(self, n, m):
+        self.n = n
+        self.m = m      # [0, 30] in fixmatch, deprecated.
+        self.augment_list = augment_list()
+
+        
+    def __call__(self, img, cutout=True):
+        ops = random.choices(self.augment_list, k=self.n)
+        for op, min_val, max_val in ops:
+            val = min_val + float(max_val - min_val)*random.random()
+            img = op(img, val) 
+        if cutout:
+            cutout_val = random.random() * 0.5 
+            img = Cutout(img, cutout_val) #for fixmatch
+        return img
+
+    
+if __name__ == '__main__':
+    # randaug = RandAugment(3,5)
+    # print(randaug)
+    # for item in randaug.augment_list:
+    #     print(item)
+    import os
+
+    os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
+    img = PIL.Image.open('./u.jpg')
+    randaug = RandAugment(3,6)
+    img = randaug(img)
+    import matplotlib
+    from matplotlib import pyplot as plt 
+    plt.imshow(img)
+    plt.show()
\ No newline at end of file
diff --git a/data/refmatte_dataset.py b/data/refmatte_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2ea66a6ed590ff49611ed277839bd4d38620917
--- /dev/null
+++ b/data/refmatte_dataset.py
@@ -0,0 +1,418 @@
+import os
+import torch
+import numpy as np
+import cv2
+from torch.utils.data import Dataset
+from torchvision import transforms
+import random
+import imgaug.augmenters as iaa
+import numbers
+import math
+
+
+def random_interp():
+    return np.random.choice([cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4])
+
+class RandomAffine(object):
+    """
+    Random affine translation
+    """
+    def __init__(self, degrees, translate=None, scale=None, shear=None, flip=None, resample=False, fillcolor=0):
+        if isinstance(degrees, numbers.Number):
+            if degrees < 0:
+                raise ValueError("If degrees is a single number, it must be positive.")
+            self.degrees = (-degrees, degrees)
+        else:
+            assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \
+                "degrees should be a list or tuple and it must be of length 2."
+            self.degrees = degrees
+
+        if translate is not None:
+            assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
+                "translate should be a list or tuple and it must be of length 2."
+            for t in translate:
+                if not (0.0 <= t <= 1.0):
+                    raise ValueError("translation values should be between 0 and 1")
+        self.translate = translate
+
+        if scale is not None:
+            assert isinstance(scale, (tuple, list)) and len(scale) == 2, \
+                "scale should be a list or tuple and it must be of length 2."
+            for s in scale:
+                if s <= 0:
+                    raise ValueError("scale values should be positive")
+        self.scale = scale
+
+        if shear is not None:
+            if isinstance(shear, numbers.Number):
+                if shear < 0:
+                    raise ValueError("If shear is a single number, it must be positive.")
+                self.shear = (-shear, shear)
+            else:
+                assert isinstance(shear, (tuple, list)) and len(shear) == 2, \
+                    "shear should be a list or tuple and it must be of length 2."
+                self.shear = shear
+        else:
+            self.shear = shear
+
+        self.resample = resample
+        self.fillcolor = fillcolor
+        self.flip = flip
+
+    @staticmethod
+    def get_params(degrees, translate, scale_ranges, shears, flip, img_size):
+        """Get parameters for affine transformation
+
+        Returns:
+            sequence: params to be passed to the affine transformation
+        """
+        angle = random.uniform(degrees[0], degrees[1])
+        if translate is not None:
+            max_dx = translate[0] * img_size[0]
+            max_dy = translate[1] * img_size[1]
+            translations = (np.round(random.uniform(-max_dx, max_dx)),
+                            np.round(random.uniform(-max_dy, max_dy)))
+        else:
+            translations = (0, 0)
+
+        if scale_ranges is not None:
+            scale = (random.uniform(scale_ranges[0], scale_ranges[1]),
+                     random.uniform(scale_ranges[0], scale_ranges[1]))
+        else:
+            scale = (1.0, 1.0)
+
+        if shears is not None:
+            shear = random.uniform(shears[0], shears[1])
+        else:
+            shear = 0.0
+
+        if flip is not None:
+            flip = (np.random.rand(2) < flip).astype(np.int32) * 2 - 1
+
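+        # example draw (hypothetical values) for degrees=30, scale=[0.8, 1.5], shear=10, flip=0.5:
+        # angle ~ 12.3, translations = (0, 0), scale ~ (1.07, 0.91), shear ~ -4.2, flip = [1, -1]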
+        return angle, translations, scale, shear, flip
+
+    def __call__(self, sample):
+        fg, alpha = sample['fg'], sample['alpha']
+        rows, cols, ch = fg.shape
+        # fg is a numpy array, so pass (cols, rows) as img_size rather than fg.size (the element count)
+        if np.maximum(rows, cols) < 1024:
+            params = self.get_params((0, 0), self.translate, self.scale, self.shear, self.flip, (cols, rows))
+        else:
+            params = self.get_params(self.degrees, self.translate, self.scale, self.shear, self.flip, (cols, rows))
+
+        center = (cols * 0.5 + 0.5, rows * 0.5 + 0.5)
+        M = self._get_inverse_affine_matrix(center, *params)
+        M = np.array(M).reshape((2, 3))
+
+        fg = cv2.warpAffine(fg, M, (cols, rows), flags=random_interp() + cv2.WARP_INVERSE_MAP)
+        alpha = cv2.warpAffine(alpha, M, (cols, rows), flags=random_interp() + cv2.WARP_INVERSE_MAP)
+
+        sample['fg'], sample['alpha'] = fg, alpha
+
+        return sample
+
+    @ staticmethod
+    def _get_inverse_affine_matrix(center, angle, translate, scale, shear, flip):
+        # Helper method to compute inverse matrix for affine transformation
+
+        # As it is explained in PIL.Image.rotate
+        # We need compute INVERSE of affine transformation matrix: M = T * C * RSS * C^-1
+        # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
+        # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
+        # RSS is rotation with scale and shear matrix
+        # It is different from the original function in torchvision
+        # The order are changed to flip -> scale -> rotation -> shear
+        # x and y have different scale factors
+        # RSS(shear, a, scale, f) = [ cos(a + shear)*scale_x*f -sin(a + shear)*scale_y     0]
+        # [ sin(a)*scale_x*f          cos(a)*scale_y             0]
+        # [     0                       0                      1]
+        # Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1
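+        # e.g. with angle = shear = 0, scale = (1, 1) and flip = (1, 1), the matrix computed
+        # below reduces to [1, 0, -tx, 0, 1, -ty] (the center shifts cancel), i.e. a pure
+        # inverse translation.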
+
+        angle = math.radians(angle)
+        shear = math.radians(shear)
+        scale_x = 1.0 / scale[0] * flip[0]
+        scale_y = 1.0 / scale[1] * flip[1]
+
+        # Inverted rotation matrix with scale and shear
+        d = math.cos(angle + shear) * math.cos(angle) + math.sin(angle + shear) * math.sin(angle)
+        matrix = [
+            math.cos(angle) * scale_x, math.sin(angle + shear) * scale_x, 0,
+            -math.sin(angle) * scale_y, math.cos(angle + shear) * scale_y, 0
+        ]
+        matrix = [m / d for m in matrix]
+
+        # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
+        matrix[2] += matrix[0] * (-center[0] - translate[0]) + matrix[1] * (-center[1] - translate[1])
+        matrix[5] += matrix[3] * (-center[0] - translate[0]) + matrix[4] * (-center[1] - translate[1])
+
+        # Apply center translation: C * RSS^-1 * C^-1 * T^-1
+        matrix[2] += center[0]
+        matrix[5] += center[1]
+
+        return matrix
+    
+
+class GenTrimap(object):
+    def __init__(self):
+        self.erosion_kernels = [None] + [cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size)) for size in range(1,100)]
+
+    def __call__(self, sample):
+        alpha = sample['alpha']
+        h, w = alpha.shape
+
+        max_kernel_size = max(30, int((min(h,w) / 2048) * 30))
+
+        ### generate trimap
+        fg_mask = (alpha + 1e-5).astype(np.int32).astype(np.uint8)
+        bg_mask = (1 - alpha + 1e-5).astype(np.int32).astype(np.uint8)
+        fg_mask = cv2.erode(fg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
+        bg_mask = cv2.erode(bg_mask, self.erosion_kernels[np.random.randint(1, max_kernel_size)])
+
+        trimap = np.ones_like(alpha) * 128
+        trimap[fg_mask == 1] = 255
+        trimap[bg_mask == 1] = 0
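+        # resulting trimap values: 255 = definite foreground, 0 = definite background,
+        # 128 = unknown band whose width depends on the randomly drawn erosion kernels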
+
+        trimap = cv2.resize(trimap, (w,h), interpolation=cv2.INTER_NEAREST)
+        sample['trimap'] = trimap
+
+        return sample
+    
+
+class RandomCrop(object):
+    """
+    Crop the image in a sample to 'output_size' at a random location, preferring crops that
+    contain unknown (trimap == 128) pixels; inputs smaller than the crop size are upscaled first.
+
+    :param output_size (tuple or int): Desired output size. If int, square crop
+            is made.
+    """
+
+    def __init__(self, output_size=(1024, 1024)):
+        assert isinstance(output_size, (int, tuple))
+        if isinstance(output_size, int):
+            self.output_size = (output_size, output_size)
+        else:
+            assert len(output_size) == 2
+            self.output_size = output_size
+        self.margin = output_size[0] // 2
+
+    def __call__(self, sample):
+        fg, alpha, trimap, name = sample['fg'],  sample['alpha'], sample['trimap'], sample['image_name']
+        bg = sample['bg']
+        h, w = trimap.shape
+        bg = cv2.resize(bg, (w, h), interpolation=random_interp())
+        if w < self.output_size[0]+1 or h < self.output_size[1]+1:
+            ratio = 1.1*self.output_size[0]/h if h < w else 1.1*self.output_size[1]/w
+            # self.logger.warning("Size of {} is {}.".format(name, (h, w)))
+            while h < self.output_size[0]+1 or w < self.output_size[1]+1:
+                fg = cv2.resize(fg, (int(w*ratio), int(h*ratio)), interpolation=random_interp())
+                alpha = cv2.resize(alpha, (int(w*ratio), int(h*ratio)),
+                                   interpolation=random_interp())
+                trimap = cv2.resize(trimap, (int(w*ratio), int(h*ratio)), interpolation=cv2.INTER_NEAREST)
+                bg = cv2.resize(bg, (int(w*ratio), int(h*ratio)), interpolation=random_interp())
+                h, w = trimap.shape
+        small_trimap = cv2.resize(trimap, (w//4, h//4), interpolation=cv2.INTER_NEAREST)
+        unknown_list = list(zip(*np.where(small_trimap[self.margin//4:(h-self.margin)//4,
+                                                       self.margin//4:(w-self.margin)//4] == 128)))
+        unknown_num = len(unknown_list)
+        if len(unknown_list) < 10:
+            left_top = (np.random.randint(0, h-self.output_size[0]+1), np.random.randint(0, w-self.output_size[1]+1))
+        else:
+            idx = np.random.randint(unknown_num)
+            left_top = (unknown_list[idx][0]*4, unknown_list[idx][1]*4)
+
+        fg_crop = fg[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1],:]
+        alpha_crop = alpha[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1]]
+        bg_crop = bg[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1],:]
+        trimap_crop = trimap[left_top[0]:left_top[0]+self.output_size[0], left_top[1]:left_top[1]+self.output_size[1]]
+
+        if len(np.where(trimap==128)[0]) == 0:
+            fg_crop = cv2.resize(fg, self.output_size[::-1], interpolation=random_interp())
+            alpha_crop = cv2.resize(alpha, self.output_size[::-1], interpolation=random_interp())
+            trimap_crop = cv2.resize(trimap, self.output_size[::-1], interpolation=cv2.INTER_NEAREST)
+            bg_crop = cv2.resize(bg, self.output_size[::-1], interpolation=random_interp())
+        
+        sample.update({'fg': fg_crop, 'alpha': alpha_crop, 'trimap': trimap_crop, 'bg': bg_crop})
+        return sample
+    
+
+class Composite_Seg(object):
+    def __call__(self, sample):
+        fg, bg, alpha = sample['fg'], sample['bg'], sample['alpha']
+        fg[fg < 0 ] = 0
+        fg[fg > 255] = 255
+        image = fg
+        sample['image'] = image
+        return sample
+    
+
+class ToTensor(object):
+    """
+    Convert ndarrays in sample to Tensors with normalization.
+    """
+    def __init__(self, phase="test", real_world_aug = False):
+        # self.mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
+        # self.std = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)
+        self.mean = torch.tensor([0.0, 0.0, 0.0]).view(3,1,1)
+        self.std = torch.tensor([1.0, 1.0, 1.0]).view(3,1,1)
+        self.phase = phase
+        if real_world_aug:
+            self.RWA = iaa.SomeOf((1, None), [
+                iaa.LinearContrast((0.6, 1.4)),
+                iaa.JpegCompression(compression=(0, 60)),
+                iaa.GaussianBlur(sigma=(0.0, 3.0)),
+                iaa.AdditiveGaussianNoise(scale=(0, 0.1*255))
+            ], random_order=True)
+        else:
+            self.RWA = None
+    
+    def get_box_from_alpha(self, alpha_final):
+        bi_mask = np.zeros_like(alpha_final)
+        bi_mask[alpha_final>0.5] = 1
+        #bi_mask[alpha_final<=0.5] = 0
+        fg_set = np.where(bi_mask != 0)
+        if len(fg_set[1]) == 0 or len(fg_set[0]) == 0:
+            x_min = random.randint(1, 511)
+            x_max = random.randint(1, 511) + x_min
+            y_min = random.randint(1, 511)
+            y_max = random.randint(1, 511) + y_min
+        else:
+            x_min = np.min(fg_set[1])
+            x_max = np.max(fg_set[1])
+            y_min = np.min(fg_set[0])
+            y_max = np.max(fg_set[0])
+        bbox = np.array([x_min, y_min, x_max, y_max])
+        #cv2.rectangle(image, (x_min, y_min), (x_max, y_max), (0,255,0), 2)
+        #cv2.imwrite('../outputs/test.jpg', image)
+        #cv2.imwrite('../outputs/test_gt.jpg', alpha_single)
+        return bbox
+
+    def __call__(self, sample):
+        # convert BGR images to RGB
+        image, alpha, trimap = sample['image'][:,:,::-1], sample['alpha'], sample['trimap']
+        
+        alpha[alpha < 0 ] = 0
+        alpha[alpha > 1] = 1
+        
+        bbox = self.get_box_from_alpha(alpha)
+
+        if self.phase == 'train' and self.RWA is not None and np.random.rand() < 0.5:
+            image[image > 255] = 255
+            image[image < 0] = 0
+            image = np.round(image).astype(np.uint8)
+            image = np.expand_dims(image, axis=0)
+            image = self.RWA(images=image)
+            image = image[0, ...]
+
+        # swap color axis because
+        # numpy image: H x W x C
+        # torch image: C X H X W
+        image = image.transpose((2, 0, 1)).astype(np.float32)
+        alpha = np.expand_dims(alpha.astype(np.float32), axis=0)
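+        # map trimap values {0, 128, 255} to class indices {0, 1, 2}
+        # (RefMatteData later divides by 2.0 to obtain {0.0, 0.5, 1.0})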
+        trimap[trimap < 85] = 0
+        trimap[trimap >= 170] = 2
+        trimap[trimap >= 85] = 1
+        #image = cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (255,0,0), 3)
+        #cv2.imwrite(os.path.join('outputs', 'img_bbox.png'), image.astype('uint8'))
+        # normalize image
+        image /= 255.
+
+        if self.phase == "train":
+            # convert BGR images to RGB
+            fg = sample['fg'][:,:,::-1].transpose((2, 0, 1)).astype(np.float32) / 255.
+            sample['fg'] = torch.from_numpy(fg).sub_(self.mean).div_(self.std)
+            bg = sample['bg'][:,:,::-1].transpose((2, 0, 1)).astype(np.float32) / 255.
+            sample['bg'] = torch.from_numpy(bg).sub_(self.mean).div_(self.std)
+            del sample['image_name']
+        
+        sample['boxes'] = torch.from_numpy(bbox).to(torch.float)[None,...]
+
+        sample['image'], sample['alpha'], sample['trimap'] = \
+            torch.from_numpy(image), torch.from_numpy(alpha), torch.from_numpy(trimap).to(torch.long)
+        sample['image'] = sample['image'].sub_(self.mean).div_(self.std)
+        sample['trimap'] = sample['trimap'][None,...].float()
+
+        return sample
+
+
+class RefMatteData(Dataset):
+    def __init__(
+        self, 
+        data_root_path,
+        num_ratio = 0.34,
+    ):
+        self.data_root_path = data_root_path
+        self.num_ratio = num_ratio
+
+        self.rim_img = [os.path.join(data_root_path, name) for name in sorted(os.listdir(data_root_path))]
+        self.rim_pha = [os.path.join(data_root_path.replace('img', 'mask'), name) for name in sorted(os.listdir(data_root_path.replace('img', 'mask')))]
+        self.rim_num = len(self.rim_pha)
+
+        self.transform_spd = transforms.Compose([
+            RandomAffine(degrees=30, scale=[0.8, 1.5], shear=10, flip=0.5),
+            GenTrimap(),
+            RandomCrop((1024, 1024)),
+            Composite_Seg(),
+            ToTensor(phase="train", real_world_aug=False)
+        ])
+
+    def __getitem__(self, idx):
+        if self.num_ratio is not None:
+            if self.num_ratio < 1.0 or idx >= self.rim_num:
+                idx = np.random.randint(0, self.rim_num)
+        alpha = cv2.imread(self.rim_pha[idx % self.rim_num], 0).astype(np.float32)/255
+        alpha_img_name = self.rim_pha[idx % self.rim_num].split('/')[-1]
+        fg_img_name = alpha_img_name[:-6] + '.jpg'
+
+        fg = cv2.imread(os.path.join(self.data_root_path, fg_img_name))
+
+        if np.random.rand() < 0.25:
+            fg = cv2.resize(fg, (1280, 1280), interpolation=random_interp())
+            alpha = cv2.resize(alpha, (1280, 1280), interpolation=random_interp())
+
+        image_name = alpha_img_name  # os.path.split(self.rim_img[idx % self.rim_num])[-1]
+        sample = {'fg': fg, 'alpha': alpha, 'bg': fg, 'image_name': image_name}
+        sample = self.transform_spd(sample)
+
+        converted_sample = {
+            'image': sample['image'],
+            'trimap': sample['trimap'] / 2.0,
+            'alpha': sample['alpha'],
+            'bbox': sample['boxes'],
+            'dataset_name': 'RefMatte',
+            'multi_fg': False,
+        }
+        return converted_sample
+
+    def __len__(self):
+        if self.num_ratio is not None:
+            return int(self.rim_num * self.num_ratio)  # 112506 * 0.34 = 38252 (COCONut_num-38251 + 1)
+        else:
+            return self.rim_num  # 112506
+
+
+    
+if __name__ == '__main__':
+    dataset = RefMatteData(
+        data_root_path = '/data/my_path_b/public_data/data/matting/RefMatte/RefMatte/train/img', 
+        num_ratio=0.34,
+    )
+    data = dataset[0]
+    '''
+    fg torch.Size([3, 1024, 1024]) tensor(-2.1179) tensor(2.6400)
+    alpha torch.Size([1, 1024, 1024]) tensor(0.) tensor(1.)
+    bg torch.Size([3, 1024, 1024]) tensor(-2.1179) tensor(2.6400)
+    trimap torch.Size([1, 1024, 1024]) 0.0 or 1.0 or 2.0
+    image torch.Size([3, 1024, 1024]) tensor(-2.1179) tensor(2.6400)
+    boxes torch.Size([1, 4]) tensor(72.) tensor(676.)  0.0~1024.0
+
+    COCONut:
+        image torch.Size([3, 1024, 1024]) tensor(0.0006) tensor(0.9991)
+        trimap torch.Size([1, 1024, 1024]) 0.0 or 0.5 or 1.0
+        alpha torch.Size([1, 1024, 1024]) tensor(0.) tensor(1.)
+        bbox torch.Size([1, 4]) tensor(0.) tensor(590.)
+        dataset_name: 'COCONut'
+    '''
+    for key, val in data.items():
+        if isinstance(val, torch.Tensor):
+            print(key, val.shape, torch.min(val), torch.max(val))
+        else:
+            print(key, val.shape)
\ No newline at end of file
diff --git a/engine/__init__.py b/engine/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8a11d96bd453ed20783a091ec1110a09900cbac
--- /dev/null
+++ b/engine/__init__.py
@@ -0,0 +1 @@
+from .mattingtrainer import MattingTrainer
diff --git a/engine/hooks.py b/engine/hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb9c486b488800c63e1c5a5f6453eccfc8fa156f
--- /dev/null
+++ b/engine/hooks.py
@@ -0,0 +1,52 @@
+import inspect
+import detectron2.utils.comm as comm
+from detectron2.engine import EvalHook as _EvalHook
+from detectron2.evaluation.testing import flatten_results_dict
+
+
+class EvalHook(_EvalHook):
+    def __init__(self, eval_period, eval_function):
+        super().__init__(eval_period, eval_function)
+        func_args = inspect.getfullargspec(eval_function).args
+        assert {"final_iter", "next_iter"}.issubset(set(func_args)), (
+            f"Eval function must have either 'final_iter' or 'next_iter' as an argument."
+            f"Got {func_args} instead."
+        )
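+        # e.g. a compatible eval_function signature (hypothetical):
+        #     def eval_func(final_iter=False, next_iter=0): ...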
+
+    def _do_eval(self, final_iter=False, next_iter=0):
+        results = self._func(final_iter=final_iter, next_iter=next_iter)
+
+        if results:
+            assert isinstance(
+                results, dict
+            ), "Eval function must return a dict. Got {} instead.".format(results)
+
+            flattened_results = flatten_results_dict(results)
+            for k, v in flattened_results.items():
+                try:
+                    v = float(v)
+                except Exception as e:
+                    raise ValueError(
+                        "[EvalHook] eval_function should return a nested dict of float. "
+                        "Got '{}: {}' instead.".format(k, v)
+                    ) from e
+            self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)
+
+        # Evaluation may take different time among workers.
+        # A barrier make them start the next iteration together.
+        comm.synchronize()
+
+    def after_step(self):
+        next_iter = self.trainer.iter + 1
+        if self._period > 0 and next_iter % self._period == 0:
+            # do the last eval in after_train
+            if next_iter != self.trainer.max_iter:
+                self._do_eval(next_iter=next_iter)
+
+    def after_train(self):
+        # This condition is to prevent the eval from running after a failed training
+        if self.trainer.iter + 1 >= self.trainer.max_iter:
+            self._do_eval(final_iter=True)
+        # func is likely a closure that holds reference to the trainer
+        # therefore we clean it to avoid circular reference in the end
+        del self._func
\ No newline at end of file
diff --git a/engine/mattingtrainer.py b/engine/mattingtrainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..13849c706b9e9644625743803a6dd1c4d83f86dd
--- /dev/null
+++ b/engine/mattingtrainer.py
@@ -0,0 +1,171 @@
+from detectron2.engine import AMPTrainer
+import torch
+import time
+import logging
+
+logger = logging.getLogger("detectron2")
+
+import typing
+from collections import defaultdict
+import tabulate
+from torch import nn
+
+
+def parameter_count(model: nn.Module, trainable_only: bool = False) -> typing.DefaultDict[str, int]:
+    """
+    Count parameters of a model and its submodules.
+
+    Args:
+        model: a torch module
+        trainable_only: if True, only count parameters that require gradients
+
+    Returns:
+        dict (str-> int): the key is either a parameter name or a module name.
+        The value is the number of elements in the parameter, or in all
+        parameters of the module. The key "" corresponds to the total
+        number of parameters of the model.
+    """
+    r = defaultdict(int)
+    for name, prm in model.named_parameters():
+        if trainable_only:
+            if not prm.requires_grad:
+                continue
+        size = prm.numel()
+        name = name.split(".")
+        for k in range(0, len(name) + 1):
+            prefix = ".".join(name[:k])
+            r[prefix] += size
+    return r
+
+
+def parameter_count_table(
+    model: nn.Module, max_depth: int = 3, trainable_only: bool = False
+) -> str:
+    """
+    Format the parameter count of the model (and its submodules or parameters)
+    in a nice table. It looks like this:
+
+    ::
+
+        | name                            | #elements or shape   |
+        |:--------------------------------|:---------------------|
+        | model                           | 37.9M                |
+        |  backbone                       |  31.5M               |
+        |   backbone.fpn_lateral3         |   0.1M               |
+        |    backbone.fpn_lateral3.weight |    (256, 512, 1, 1)  |
+        |    backbone.fpn_lateral3.bias   |    (256,)            |
+        |   backbone.fpn_output3          |   0.6M               |
+        |    backbone.fpn_output3.weight  |    (256, 256, 3, 3)  |
+        |    backbone.fpn_output3.bias    |    (256,)            |
+        |   backbone.fpn_lateral4         |   0.3M               |
+        |    backbone.fpn_lateral4.weight |    (256, 1024, 1, 1) |
+        |    backbone.fpn_lateral4.bias   |    (256,)            |
+        |   backbone.fpn_output4          |   0.6M               |
+        |    backbone.fpn_output4.weight  |    (256, 256, 3, 3)  |
+        |    backbone.fpn_output4.bias    |    (256,)            |
+        |   backbone.fpn_lateral5         |   0.5M               |
+        |    backbone.fpn_lateral5.weight |    (256, 2048, 1, 1) |
+        |    backbone.fpn_lateral5.bias   |    (256,)            |
+        |   backbone.fpn_output5          |   0.6M               |
+        |    backbone.fpn_output5.weight  |    (256, 256, 3, 3)  |
+        |    backbone.fpn_output5.bias    |    (256,)            |
+        |   backbone.top_block            |   5.3M               |
+        |    backbone.top_block.p6        |    4.7M              |
+        |    backbone.top_block.p7        |    0.6M              |
+        |   backbone.bottom_up            |   23.5M              |
+        |    backbone.bottom_up.stem      |    9.4K              |
+        |    backbone.bottom_up.res2      |    0.2M              |
+        |    backbone.bottom_up.res3      |    1.2M              |
+        |    backbone.bottom_up.res4      |    7.1M              |
+        |    backbone.bottom_up.res5      |    14.9M             |
+        |    ......                       |    .....             |
+
+    Args:
+        model: a torch module
+        max_depth (int): maximum depth to recursively print submodules or
+            parameters
+
+    Returns:
+        str: the table to be printed
+    """
+    count: typing.DefaultDict[str, int] = parameter_count(model, trainable_only)
+    # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
+    param_shape: typing.Dict[str, typing.Tuple] = {
+        k: tuple(v.shape) for k, v in model.named_parameters()
+    }
+
+    # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
+    table: typing.List[typing.Tuple] = []
+
+    def format_size(x: int) -> str:
+        if x > 1e8:
+            return "{:.1f}G".format(x / 1e9)
+        if x > 1e5:
+            return "{:.1f}M".format(x / 1e6)
+        if x > 1e2:
+            return "{:.1f}K".format(x / 1e3)
+        return str(x)
+
+    def fill(lvl: int, prefix: str) -> None:
+        if lvl >= max_depth:
+            return
+        for name, v in count.items():
+            if name.count(".") == lvl and name.startswith(prefix):
+                indent = " " * (lvl + 1)
+                if name in param_shape:
+                    table.append((indent + name, indent + str(param_shape[name])))
+                else:
+                    table.append((indent + name, indent + format_size(v)))
+                    fill(lvl + 1, name + ".")
+
+    table.append(("model", format_size(count.pop(""))))
+    fill(0, "")
+
+    old_ws = tabulate.PRESERVE_WHITESPACE
+    tabulate.PRESERVE_WHITESPACE = True
+    tab = tabulate.tabulate(table, headers=["name", "#elements or shape"], tablefmt="pipe")
+    tabulate.PRESERVE_WHITESPACE = old_ws
+    return tab
+
+
+def cycle(iterable):
+    while True:
+        for x in iterable:
+            yield x
+
+class MattingTrainer(AMPTrainer):
+    def __init__(self, model, data_loader, optimizer, grad_scaler=None):
+        super().__init__(model, data_loader, optimizer, grad_scaler=grad_scaler)
+        self.data_loader_iter = iter(cycle(self.data_loader))
+
+        # print model parameters
+        logger.info("All parameters: \n" + parameter_count_table(model))
+        logger.info("Trainable parameters: \n" + parameter_count_table(model, trainable_only=True, max_depth=8))
+
+    def run_step(self):
+        """
+        Implement the AMP training logic.
+        """
+        assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
+        assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
+        from torch.cuda.amp import autocast
+
+        #matting pass
+        start = time.perf_counter()        
+        data = next(self.data_loader_iter)
+        data_time = time.perf_counter() - start
+
+        with autocast():
+            loss_dict = self.model(data)
+            if isinstance(loss_dict, torch.Tensor):
+                losses = loss_dict
+                loss_dict = {"total_loss": loss_dict}
+            else:
+                losses = sum(loss_dict.values())
+
+        self.optimizer.zero_grad()
+        self.grad_scaler.scale(losses).backward()
+
+        self._write_metrics(loss_dict, data_time)
+
+        self.grad_scaler.step(self.optimizer)
+        self.grad_scaler.update()
\ No newline at end of file
diff --git a/modeling/__init__.py b/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b89f6584f82a9ab1cf77d1b4032a8e5829dcda7
--- /dev/null
+++ b/modeling/__init__.py
@@ -0,0 +1,5 @@
+from .backbone import *
+from .criterion import *
+from .decoder import *
+from .meta_arch import *
+from .semantic_enhanced_matting import *
\ No newline at end of file
diff --git a/modeling/__pycache__/__init__.cpython-38.pyc b/modeling/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc4f753be9b7189152be5139a7df3a861069154a
Binary files /dev/null and b/modeling/__pycache__/__init__.cpython-38.pyc differ
diff --git a/modeling/backbone/__init__.py b/modeling/backbone/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9420c05c3a5ea747b4a6e884ae079b82bae9fa39
--- /dev/null
+++ b/modeling/backbone/__init__.py
@@ -0,0 +1,2 @@
+from .backbone import *
+from .vit import *
\ No newline at end of file
diff --git a/modeling/backbone/__pycache__/__init__.cpython-38.pyc b/modeling/backbone/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aa50dd88a5d5a61448f0148ebdec757140d4dbeb
Binary files /dev/null and b/modeling/backbone/__pycache__/__init__.cpython-38.pyc differ
diff --git a/modeling/backbone/__pycache__/backbone.cpython-38.pyc b/modeling/backbone/__pycache__/backbone.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..035d55d7a631c2437b53e88920227da03d3bdc04
Binary files /dev/null and b/modeling/backbone/__pycache__/backbone.cpython-38.pyc differ
diff --git a/modeling/backbone/__pycache__/utils.cpython-38.pyc b/modeling/backbone/__pycache__/utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fa110dc5bbdcdc38091d97cccdabff98f3257465
Binary files /dev/null and b/modeling/backbone/__pycache__/utils.cpython-38.pyc differ
diff --git a/modeling/backbone/__pycache__/vit.cpython-38.pyc b/modeling/backbone/__pycache__/vit.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fa38b7d5676eb701f478264a628e900434c9b590
Binary files /dev/null and b/modeling/backbone/__pycache__/vit.cpython-38.pyc differ
diff --git a/modeling/backbone/backbone.py b/modeling/backbone/backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1c765a6b38542f66cae55216bba697a6626d128
--- /dev/null
+++ b/modeling/backbone/backbone.py
@@ -0,0 +1,74 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from abc import ABCMeta, abstractmethod
+from typing import Dict
+import torch.nn as nn
+
+from detectron2.layers import ShapeSpec
+
+__all__ = ["Backbone"]
+
+
+class Backbone(nn.Module, metaclass=ABCMeta):
+    """
+    Abstract base class for network backbones.
+    """
+
+    def __init__(self):
+        """
+        The `__init__` method of any subclass can specify its own set of arguments.
+        """
+        super().__init__()
+
+    @abstractmethod
+    def forward(self):
+        """
+        Subclasses must override this method, but adhere to the same return type.
+
+        Returns:
+            dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor
+        """
+        pass
+
+    @property
+    def size_divisibility(self) -> int:
+        """
+        Some backbones require the input height and width to be divisible by a
+        specific integer. This is typically true for encoder / decoder type networks
+        with lateral connection (e.g., FPN) for which feature maps need to match
+        dimension in the "bottom up" and "top down" paths. Set to 0 if no specific
+        input size divisibility is required.
+        """
+        return 0
+
+    @property
+    def padding_constraints(self) -> Dict[str, int]:
+        """
+        This property is a generalization of size_divisibility. Some backbones and training
+        recipes require specific padding constraints, such as enforcing divisibility by a specific
+        integer (e.g., FPN) or padding to a square (e.g., ViTDet with large-scale jitter
+        in :paper:vitdet). `padding_constraints` contains these optional items like:
+        {
+            "size_divisibility": int,
+            "square_size": int,
+            # Future options are possible
+        }
+        `size_divisibility` will read from here if presented and `square_size` indicates the
+        square padding size if `square_size` > 0.
+
+        TODO: use type of Dict[str, int] to avoid torchscipt issues. The type of padding_constraints
+        could be generalized as TypedDict (Python 3.8+) to support more types in the future.
+        """
+        return {}
+
+    def output_shape(self):
+        """
+        Returns:
+            dict[str->ShapeSpec]
+        """
+        # this is a backward-compatible default
+        return {
+            name: ShapeSpec(
+                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
+            )
+            for name in self._out_features
+        }
diff --git a/modeling/backbone/utils.py b/modeling/backbone/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b89a4c3fbe079a77fd0cef947cf9ada787fc55d
--- /dev/null
+++ b/modeling/backbone/utils.py
@@ -0,0 +1,186 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = [
+    "window_partition",
+    "window_unpartition",
+    "add_decomposed_rel_pos",
+    "get_abs_pos",
+    "PatchEmbed",
+]
+
+
+def window_partition(x, window_size):
+    """
+    Partition into non-overlapping windows with padding if needed.
+    Args:
+        x (tensor): input tokens with [B, H, W, C].
+        window_size (int): window size.
+
+    Returns:
+        windows: windows after partition with [B * num_windows, window_size, window_size, C].
+        (Hp, Wp): padded height and width before partition
+    """
+    B, H, W, C = x.shape
+
+    pad_h = (window_size - H % window_size) % window_size
+    pad_w = (window_size - W % window_size) % window_size
+    if pad_h > 0 or pad_w > 0:
+        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+    Hp, Wp = H + pad_h, W + pad_w
+
+    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    return windows, (Hp, Wp)
+
+
+def window_unpartition(windows, window_size, pad_hw, hw):
+    """
+    Window unpartition into original sequences and removing padding.
+    Args:
+        windows (tensor): input windows with [B * num_windows, window_size, window_size, C].
+        window_size (int): window size.
+        pad_hw (Tuple): padded height and width (Hp, Wp).
+        hw (Tuple): original height and width (H, W) before padding.
+
+    Returns:
+        x: unpartitioned sequences with [B, H, W, C].
+    """
+    Hp, Wp = pad_hw
+    H, W = hw
+    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
+    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+
+    if Hp > H or Wp > W:
+        x = x[:, :H, :W, :].contiguous()
+    return x
+
+
+def get_rel_pos(q_size, k_size, rel_pos):
+    """
+    Get relative positional embeddings according to the relative positions of
+        query and key sizes.
+    Args:
+        q_size (int): size of query q.
+        k_size (int): size of key k.
+        rel_pos (Tensor): relative position embeddings (L, C).
+
+    Returns:
+        Extracted positional embeddings according to relative positions.
+    """
+    max_rel_dist = int(2 * max(q_size, k_size) - 1)
+    # Interpolate rel pos if needed.
+    if rel_pos.shape[0] != max_rel_dist:
+        # Interpolate rel pos.
+        rel_pos_resized = F.interpolate(
+            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
+            size=max_rel_dist,
+            mode="linear",
+        )
+        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
+    else:
+        rel_pos_resized = rel_pos
+
+    # Scale the coords with short length if shapes for q and k are different.
+    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+
+    return rel_pos_resized[relative_coords.long()]
+
+
+def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
+    """
+    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
+    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py   # noqa B950
+    Args:
+        attn (Tensor): attention map.
+        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
+        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
+        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
+        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
+        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
+
+    Returns:
+        attn (Tensor): attention map with added relative positional embeddings.
+    """
+    q_h, q_w = q_size
+    k_h, k_w = k_size
+    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
+    Rw = get_rel_pos(q_w, k_w, rel_pos_w)
+
+    B, _, dim = q.shape
+    r_q = q.reshape(B, q_h, q_w, dim)
+    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
+    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
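+    # rel_h: (B, q_h, q_w, k_h), rel_w: (B, q_h, q_w, k_w); they are broadcast-added onto
+    # the attention logits reshaped to (B, q_h, q_w, k_h, k_w) below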
+
+    attn = (
+        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
+    ).view(B, q_h * q_w, k_h * k_w)
+
+    return attn
+
+
+def get_abs_pos(abs_pos, has_cls_token, hw):
+    """
+    Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
+        dimension for the original embeddings.
+    Args:
+        abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
+        has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
+        hw (Tuple): size of input image tokens.
+
+    Returns:
+        Absolute positional embeddings after processing with shape (1, H, W, C)
+    """
+    h, w = hw
+    if has_cls_token:
+        abs_pos = abs_pos[:, 1:]
+    xy_num = abs_pos.shape[1]
+    size = int(math.sqrt(xy_num))
+    assert size * size == xy_num
+
+    if size != h or size != w:
+        new_abs_pos = F.interpolate(
+            abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2),
+            size=(h, w),
+            mode="bicubic",
+            align_corners=False,
+        )
+
+        return new_abs_pos.permute(0, 2, 3, 1)
+    else:
+        return abs_pos.reshape(1, h, w, -1)
+
+
+class PatchEmbed(nn.Module):
+    """
+    Image to Patch Embedding.
+    """
+
+    def __init__(
+        self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768
+    ):
+        """
+        Args:
+            kernel_size (Tuple): kernel size of the projection layer.
+            stride (Tuple): stride of the projection layer.
+            padding (Tuple): padding size of the projection layer.
+            in_chans (int): Number of input image channels.
+            embed_dim (int): Patch embedding dimension.
+        """
+        super().__init__()
+
+        self.proj = nn.Conv2d(
+            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
+        )
+
+    def forward(self, x):
+        x = self.proj(x)
+        # B C H W -> B H W C
+        x = x.permute(0, 2, 3, 1)
+        return x
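+
+
+if __name__ == "__main__":
+    # Minimal sanity-check sketch (not part of the original module): verify that
+    # window_partition followed by window_unpartition is an exact round trip on an
+    # arbitrary toy tensor.
+    x = torch.randn(1, 14, 14, 8)
+    windows, pad_hw = window_partition(x, window_size=8)  # -> (4, 8, 8, 8), padded to (16, 16)
+    y = window_unpartition(windows, 8, pad_hw, (14, 14))
+    assert torch.allclose(x, y)
+    print("window_partition/window_unpartition round trip OK:", tuple(windows.shape), pad_hw)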
diff --git a/modeling/backbone/vit.py b/modeling/backbone/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..f98efd3dbf386919afe652984d9a2b9f89a84ab5
--- /dev/null
+++ b/modeling/backbone/vit.py
@@ -0,0 +1,404 @@
+import logging
+import math
+import fvcore.nn.weight_init as weight_init
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from detectron2.layers import CNNBlockBase, Conv2d, get_norm
+from detectron2.modeling.backbone.fpn import _assert_strides_are_log2_contiguous
+from fairscale.nn.checkpoint import checkpoint_wrapper
+from timm.models.layers import DropPath, Mlp, trunc_normal_
+from .backbone import Backbone
+from .utils import (
+    PatchEmbed,
+    add_decomposed_rel_pos,
+    get_abs_pos,
+    window_partition,
+    window_unpartition,
+)
+
+logger = logging.getLogger(__name__)
+
+
+__all__ = ["ViT"]
+
+
+class Attention(nn.Module):
+    """Multi-head Attention block with relative position embeddings."""
+
+    def __init__(
+        self,
+        dim,
+        num_heads=8,
+        qkv_bias=True,
+        use_rel_pos=False,
+        rel_pos_zero_init=True,
+        input_size=None,
+    ):
+        """
+        Args:
+            dim (int): Number of input channels.
+            num_heads (int): Number of attention heads.
+            qkv_bias (bool): If True, add a learnable bias to query, key, value.
+            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+            input_size (int or None): Input resolution for calculating the relative positional
+                parameter size.
+        """
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = head_dim**-0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.proj = nn.Linear(dim, dim)
+
+        self.use_rel_pos = use_rel_pos
+        if self.use_rel_pos:
+            # initialize relative positional embeddings
+            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
+            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
+
+            if not rel_pos_zero_init:
+                trunc_normal_(self.rel_pos_h, std=0.02)
+                trunc_normal_(self.rel_pos_w, std=0.02)
+
+    def forward(self, x):
+        B, H, W, _ = x.shape
+        # qkv with shape (3, B, nHead, H * W, C)
+        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        # q, k, v with shape (B * nHead, H * W, C)
+        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
+
+        attn = (q * self.scale) @ k.transpose(-2, -1)
+
+        if self.use_rel_pos:
+            attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
+
+        attn = attn.softmax(dim=-1)
+        x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
+        x = self.proj(x)
+
+        return x
+
+class LayerNorm(nn.Module):
+    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 
+    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 
+    shape (batch_size, height, width, channels) while channels_first corresponds to inputs 
+    with shape (batch_size, channels, height, width).
+    """
+    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(normalized_shape))
+        self.bias = nn.Parameter(torch.zeros(normalized_shape))
+        self.eps = eps
+        self.data_format = data_format
+        if self.data_format not in ["channels_last", "channels_first"]:
+            raise NotImplementedError 
+        self.normalized_shape = (normalized_shape, )
+    
+    def forward(self, x):
+        if self.data_format == "channels_last":
+            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+        elif self.data_format == "channels_first":
+            u = x.mean(1, keepdim=True)
+            s = (x - u).pow(2).mean(1, keepdim=True)
+            x = (x - u) / torch.sqrt(s + self.eps)
+            x = self.weight[:, None, None] * x + self.bias[:, None, None]
+            return x
+
+class ResBottleneckBlock(CNNBlockBase):
+    """
+    The standard bottleneck residual block without the last activation layer.
+    It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        bottleneck_channels,
+        norm="LN",
+        act_layer=nn.GELU,
+        conv_kernels=3,
+        conv_paddings=1,
+    ):
+        """
+        Args:
+            in_channels (int): Number of input channels.
+            out_channels (int): Number of output channels.
+            bottleneck_channels (int): number of output channels for the 3x3
+                "bottleneck" conv layers.
+            norm (str or callable): normalization for all conv layers.
+                See :func:`layers.get_norm` for supported format.
+            act_layer (callable): activation for all conv layers.
+        """
+        super().__init__(in_channels, out_channels, 1)
+
+        self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
+        self.norm1 = get_norm(norm, bottleneck_channels)
+        self.act1 = act_layer()
+
+        self.conv2 = Conv2d(
+            bottleneck_channels,
+            bottleneck_channels,
+            conv_kernels,
+            padding=conv_paddings,
+            bias=False,
+        )
+        self.norm2 = get_norm(norm, bottleneck_channels)
+        self.act2 = act_layer()
+
+        self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
+        self.norm3 = get_norm(norm, out_channels)
+
+        for layer in [self.conv1, self.conv2, self.conv3]:
+            weight_init.c2_msra_fill(layer)
+        for layer in [self.norm1, self.norm2]:
+            layer.weight.data.fill_(1.0)
+            layer.bias.data.zero_()
+        # zero init last norm layer.
+        self.norm3.weight.data.zero_()
+        self.norm3.bias.data.zero_()
+
+    def forward(self, x):
+        out = x
+        for layer in self.children():
+            out = layer(out)
+
+        out = x + out
+        return out
+
+
+class Block(nn.Module):
+    """Transformer blocks with support of window attention and residual propagation blocks"""
+
+    def __init__(
+        self,
+        dim,
+        num_heads,
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        drop_path=0.0,
+        norm_layer=nn.LayerNorm,
+        act_layer=nn.GELU,
+        use_rel_pos=False,
+        rel_pos_zero_init=True,
+        window_size=0,
+        use_cc_attn = False,
+        use_residual_block=False,
+        use_convnext_block=False,
+        input_size=None,
+        res_conv_kernel_size=3,
+        res_conv_padding=1,
+    ):
+        """
+        Args:
+            dim (int): Number of input channels.
+            num_heads (int): Number of attention heads in each ViT block.
+            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+            qkv_bias (bool): If True, add a learnable bias to query, key, value.
+            drop_path (float): Stochastic depth rate.
+            norm_layer (nn.Module): Normalization layer.
+            act_layer (nn.Module): Activation layer.
+            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+            window_size (int): Window size for window attention blocks. If it equals 0, then not
+                use window attention.
+            use_residual_block (bool): If True, use a residual block after the MLP block.
+            input_size (int or None): Input resolution for calculating the relative positional
+                parameter size.
+        """
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            use_rel_pos=use_rel_pos,
+            rel_pos_zero_init=rel_pos_zero_init,
+            input_size=input_size if window_size == 0 else (window_size, window_size),
+        )
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer)
+
+        self.window_size = window_size
+
+        self.use_residual_block = use_residual_block
+        if use_residual_block:
+            # Use a residual block with bottleneck channel as dim // 2
+            self.residual = ResBottleneckBlock(
+                in_channels=dim,
+                out_channels=dim,
+                bottleneck_channels=dim // 2,
+                norm="LN",
+                act_layer=act_layer,
+                conv_kernels=res_conv_kernel_size,
+                conv_paddings=res_conv_padding,
+            )
+        self.use_convnext_block = use_convnext_block
+        if use_convnext_block:
+            # NOTE: ConvNextBlock is not defined or imported in this file; this option
+            # requires an external implementation.
+            self.convnext = ConvNextBlock(dim=dim)
+
+        if use_cc_attn:
+            # NOTE: CrissCrossAttention is likewise not defined or imported here.
+            self.attn = CrissCrossAttention(dim)
+
+
+    def forward(self, x):
+        shortcut = x
+        x = self.norm1(x)
+        # Window partition
+        if self.window_size > 0:
+            H, W = x.shape[1], x.shape[2]
+            x, pad_hw = window_partition(x, self.window_size)
+
+        x = self.attn(x)
+
+        # Reverse window partition
+        if self.window_size > 0:
+            x = window_unpartition(x, self.window_size, pad_hw, (H, W))
+
+        x = shortcut + self.drop_path(x)
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+        if self.use_residual_block:
+            x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
+        if self.use_convnext_block:
+            x = self.convnext(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
+
+        return x
+
+
+class ViT(Backbone):
+    """
+    This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`.
+    "Exploring Plain Vision Transformer Backbones for Object Detection",
+    https://arxiv.org/abs/2203.16527
+    """
+
+    def __init__(
+        self,
+        img_size=1024,
+        patch_size=16,
+        in_chans=3,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        drop_path_rate=0.0,
+        norm_layer=nn.LayerNorm,
+        act_layer=nn.GELU,
+        use_abs_pos=True,
+        use_rel_pos=False,
+        rel_pos_zero_init=True,
+        window_size=0,
+        window_block_indexes=(),
+        residual_block_indexes=(),
+        use_act_checkpoint=False,
+        pretrain_img_size=224,
+        pretrain_use_cls_token=True,
+        out_feature="last_feat",
+        res_conv_kernel_size=3, 
+        res_conv_padding=1,
+    ):
+        """
+        Args:
+            img_size (int): Input image size.
+            patch_size (int): Patch size.
+            in_chans (int): Number of input image channels.
+            embed_dim (int): Patch embedding dimension.
+            depth (int): Depth of ViT.
+            num_heads (int): Number of attention heads in each ViT block.
+            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+            qkv_bias (bool): If True, add a learnable bias to query, key, value.
+            drop_path_rate (float): Stochastic depth rate.
+            norm_layer (nn.Module): Normalization layer.
+            act_layer (nn.Module): Activation layer.
+            use_abs_pos (bool): If True, use absolute positional embeddings.
+            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+            window_size (int): Window size for window attention blocks.
+            window_block_indexes (list): Indexes for blocks using window attention.
+            residual_block_indexes (list): Indexes for blocks using conv propagation.
+            use_act_checkpoint (bool): If True, use activation checkpointing.
+            pretrain_img_size (int): input image size for pretraining models.
+            pretrain_use_cls_token (bool): If True, pretraining models use class token.
+            out_feature (str): name of the feature from the last block.
+        """
+        super().__init__()
+        self.pretrain_use_cls_token = pretrain_use_cls_token
+
+        self.patch_embed = PatchEmbed(
+            kernel_size=(patch_size, patch_size),
+            stride=(patch_size, patch_size),
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+        )
+
+        if use_abs_pos:
+            # Initialize absolute positional embedding with pretrain image size.
+            num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
+            num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
+            self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
+        else:
+            self.pos_embed = None
+
+        # stochastic depth decay rule
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
+
+        self.blocks = nn.ModuleList()
+        for i in range(depth):
+            block = Block(
+                dim=embed_dim,
+                num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                drop_path=dpr[i],
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+                use_rel_pos=use_rel_pos,
+                rel_pos_zero_init=rel_pos_zero_init,
+                window_size=window_size if i in window_block_indexes else 0,
+                use_residual_block=i in residual_block_indexes,
+                input_size=(img_size // patch_size, img_size // patch_size),
+                res_conv_kernel_size=res_conv_kernel_size,
+                res_conv_padding=res_conv_padding,
+            )
+            if use_act_checkpoint:
+                block = checkpoint_wrapper(block)
+            self.blocks.append(block)
+
+        self._out_feature_channels = {out_feature: embed_dim}
+        self._out_feature_strides = {out_feature: patch_size}
+        self._out_features = [out_feature]
+
+        if self.pos_embed is not None:
+            trunc_normal_(self.pos_embed, std=0.02)
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=0.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def forward(self, x):
+        x = self.patch_embed(x)
+        if self.pos_embed is not None:
+            x = x + get_abs_pos(
+                self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
+            )
+
+        for blk in self.blocks:
+            x = blk(x)
+
+        outputs = {self._out_features[0]: x.permute(0, 3, 1, 2)}
+
+        return outputs[self._out_features[0]]
\ No newline at end of file
diff --git a/modeling/criterion/__init__.py b/modeling/criterion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f09984281bdc1a8338f5a403e78f91bb8b3b172e
--- /dev/null
+++ b/modeling/criterion/__init__.py
@@ -0,0 +1 @@
+from .matting_criterion import MattingCriterion
\ No newline at end of file
diff --git a/modeling/criterion/__pycache__/__init__.cpython-38.pyc b/modeling/criterion/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ef71b1ae2393b3dc9f6e03e7a8467b8854fd7078
Binary files /dev/null and b/modeling/criterion/__pycache__/__init__.cpython-38.pyc differ
diff --git a/modeling/criterion/__pycache__/matting_criterion.cpython-38.pyc b/modeling/criterion/__pycache__/matting_criterion.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1a624eda9f841b05af01fd3b91242718b5bcc6f0
Binary files /dev/null and b/modeling/criterion/__pycache__/matting_criterion.cpython-38.pyc differ
diff --git a/modeling/criterion/matting_criterion.py b/modeling/criterion/matting_criterion.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f47a72e82257132b5f2100134bfc0b7696c2dbf
--- /dev/null
+++ b/modeling/criterion/matting_criterion.py
@@ -0,0 +1,271 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from collections import defaultdict
+
+
+class MattingCriterion(nn.Module):
+    def __init__(
+        self,
+        *,
+        losses,
+        image_size = 1024,
+    ):
+        super(MattingCriterion, self).__init__()
+        self.losses = losses
+        self.image_size = image_size
+
+    def loss_gradient_penalty(self, sample_map, preds, targets):
+
+        #sample_map for unknown area
+        if torch.sum(sample_map) == 0:
+            scale = 0
+        else:
+            scale = sample_map.shape[0] * (self.image_size ** 2) / torch.sum(sample_map)
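+        # `scale` renormalizes the masked mean so samples with small unknown regions are not down-weighted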
+
+        #gradient in x
+        sobel_x_kernel = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]], dtype=preds.dtype, device=preds.device)
+        delta_pred_x = F.conv2d(preds, weight=sobel_x_kernel, padding=1)
+        delta_gt_x = F.conv2d(targets, weight=sobel_x_kernel, padding=1)
+
+        #gradient in y 
+        sobel_y_kernel = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]], dtype=preds.dtype, device=preds.device)
+        delta_pred_y = F.conv2d(preds, weight=sobel_y_kernel, padding=1)
+        delta_gt_y = F.conv2d(targets, weight=sobel_y_kernel, padding=1)
+
+        #loss
+        loss = (F.l1_loss(delta_pred_x * sample_map, delta_gt_x * sample_map) * scale + \
+            F.l1_loss(delta_pred_y * sample_map, delta_gt_y * sample_map) * scale + \
+            0.01 * torch.mean(torch.abs(delta_pred_x * sample_map)) * scale +  \
+            0.01 * torch.mean(torch.abs(delta_pred_y * sample_map)) * scale)
+
+        return dict(loss_gradient_penalty=loss)
+
+    def loss_pha_laplacian(self, preds, targets):
+        loss = laplacian_loss(preds, targets)
+        return dict(loss_pha_laplacian=loss)
+
+    def unknown_l1_loss(self, sample_map, preds, targets):
+        
+        if torch.sum(sample_map) == 0:
+            scale = 0
+        else:
+            scale = sample_map.shape[0] * (self.image_size ** 2) / torch.sum(sample_map)
+        # scale = 1
+
+        loss = F.l1_loss(preds * sample_map, targets * sample_map) * scale
+
+        return dict(unknown_l1_loss=loss)
+
+    def known_l1_loss(self, sample_map, preds, targets):
+        new_sample_map = torch.zeros_like(sample_map)
+        new_sample_map[sample_map==0] = 1
+        
+        if torch.sum(new_sample_map) == 0:
+            scale = 0
+        else:
+            scale = new_sample_map.shape[0] * (self.image_size ** 2) / torch.sum(new_sample_map)
+        # scale = 1
+        
+        loss = F.l1_loss(preds * new_sample_map, targets * new_sample_map) * scale
+
+        return dict(known_l1_loss=loss)
+
+    def get_loss(self, k, sample_map, preds, targets):
+        if k=='unknown_l1_loss' or k=='known_l1_loss' or k=='loss_gradient_penalty':
+            losses = getattr(self, k)(sample_map, preds, targets)
+        else:
+            losses = getattr(self, k)(preds, targets)
+        assert len(list(losses.keys())) == 1
+        return losses[list(losses.keys())[0]]
+
+    def forward(self, sample_map, preds, targets, batch_weight=None):
+        losses = {i: torch.tensor(0.0, device=sample_map.device) for i in self.losses}
+        for k in self.losses:
+            if batch_weight is None:
+                losses[k] += self.get_loss(k, sample_map, preds, targets)
+            else:
+                for i, loss_weight in enumerate(batch_weight):
+                    if loss_weight == -1.0 and k != 'known_l1_loss':
+                        continue
+                    else:
+                        losses[k] += self.get_loss(k, sample_map[i: i + 1], preds[i: i + 1], targets[i: i + 1]) * abs(loss_weight)
+        return losses
+
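+# Minimal usage sketch (variable names and shapes assumed; not part of the training pipeline):
+#   criterion = MattingCriterion(losses=['unknown_l1_loss', 'known_l1_loss', 'loss_pha_laplacian', 'loss_gradient_penalty'])
+#   sample_map = (trimap == 0.5).float()                     # [B, 1, H, W], 1 inside the unknown region
+#   loss_dict = criterion(sample_map, pred_alpha, gt_alpha)  # dict of scalar loss terms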
+
+#-----------------Laplacian Loss-------------------------#
+def laplacian_loss(pred, true, max_levels=5):
+    kernel = gauss_kernel(device=pred.device, dtype=pred.dtype)
+    pred_pyramid = laplacian_pyramid(pred, kernel, max_levels)
+    true_pyramid = laplacian_pyramid(true, kernel, max_levels)
+    loss = 0
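+    # weight coarser pyramid levels more heavily (2 ** level), so large-scale structural errors dominate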
+    for level in range(max_levels):
+        loss += (2 ** level) * F.l1_loss(pred_pyramid[level], true_pyramid[level])
+    return loss / max_levels
+
+def laplacian_pyramid(img, kernel, max_levels):
+    current = img
+    pyramid = []
+    for _ in range(max_levels):
+        current = crop_to_even_size(current)
+        down = downsample(current, kernel)
+        up = upsample(down, kernel)
+        diff = current - up
+        pyramid.append(diff)
+        current = down
+    return pyramid
+
+def gauss_kernel(device='cpu', dtype=torch.float32):
+    kernel = torch.tensor([[1,  4,  6,  4, 1],
+                           [4, 16, 24, 16, 4],
+                           [6, 24, 36, 24, 6],
+                           [4, 16, 24, 16, 4],
+                           [1,  4,  6,  4, 1]], device=device, dtype=dtype)
+    kernel /= 256
+    kernel = kernel[None, None, :, :]
+    return kernel
+
+def gauss_convolution(img, kernel):
+    B, C, H, W = img.shape
+    img = img.reshape(B * C, 1, H, W)
+    img = F.pad(img, (2, 2, 2, 2), mode='reflect')
+    img = F.conv2d(img, kernel)
+    img = img.reshape(B, C, H, W)
+    return img
+
+def downsample(img, kernel):
+    img = gauss_convolution(img, kernel)
+    img = img[:, :, ::2, ::2]
+    return img
+
+def upsample(img, kernel):
+    B, C, H, W = img.shape
+    out = torch.zeros((B, C, H * 2, W * 2), device=img.device, dtype=img.dtype)
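+    # scale by 4 to preserve intensity: three of every four pixels are zero before the Gaussian blur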
+    out[:, :, ::2, ::2] = img * 4
+    out = gauss_convolution(out, kernel)
+    return out
+
+def crop_to_even_size(img):
+    H, W = img.shape[2:]
+    H = H - H % 2
+    W = W - W % 2
+    return img[:, :, :H, :W]
+
+def normalized_focal_loss(pred, gt, gamma=2, class_num=3, norm=True, beta_detach=False, beta_sum_detach=False):
+    pred_logits = F.softmax(pred, dim=1)  # [B, 3, H, W]
+    gt_one_hot = F.one_hot(gt, class_num).permute(0, 3, 1, 2)  # [B, 3, H, W]
+    p = (pred_logits * gt_one_hot).sum(dim=1)  # [B, H, W]
+    beta = (1 - p) ** gamma  # [B, H, W]
+    beta_sum = torch.sum(beta, dim=(-2, -1), keepdim=True) / (pred.shape[-1] * pred.shape[-2])  # [B, 1, 1]
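+    # beta_sum is the mean focal weight per image; dividing by it keeps the loss scale comparable across images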
+
+    if beta_detach:
+        beta = beta.detach()
+    if beta_sum_detach:
+        beta_sum = beta_sum.detach()
+
+    if norm:
+        loss = 1 / beta_sum * beta * (-torch.log(p))
+        return torch.mean(loss)
+    else:
+        loss = beta * (-torch.log(p))
+        return torch.mean(loss)
+
+class GHMC(nn.Module):
+    def __init__(self, bins=10, momentum=0.75, loss_weight=1.0, device='cuda', norm=False):
+        super(GHMC, self).__init__()
+        self.bins = bins
+        self.momentum = momentum
+        self.edges = torch.arange(bins + 1, dtype=torch.float32, device=device) / bins
+        self.edges[-1] += 1e-6
+        if momentum > 0:
+            self.acc_sum = torch.zeros(bins, device=device)
+        self.loss_weight = loss_weight
+        self.device = device
+        self.norm = norm
+
+    def forward(self, pred, target, *args, **kwargs):
+        """Calculate the GHM-C loss.
+        Args:
+            pred (float tensor of size [batch_num, class_num]):
+                The direct prediction of classification fc layer.
+            target (float tensor of size [batch_num, class_num]):
+                Binary class target for each sample.
+            label_weight (float tensor of size [batch_num, class_num]):
+                the value is 1 if the sample is valid and 0 if ignored.
+        Returns:
+            The gradient harmonized loss.
+        """
+
+        # pred: [B, C, H, W], target: [B, H, W]
+        pred = pred.permute(0, 2, 3, 1).reshape(-1, 3)  # [B x H x W, C]
+        target = target.reshape(-1)  # [B x H x W]
+        # self.acc_sum = self.acc_sum.type(pred.dtype)
+
+        edges = self.edges
+        mmt = self.momentum
+        weights = torch.zeros(target.shape, dtype=pred.dtype, device=self.device)
+
+        # gradient length
+        #g = 1 - torch.index_select(F.softmax(pred,dim=1).detach(), dim=0, index=target)
+        g = 1 - torch.gather(F.softmax(pred,dim=1).detach(),dim=1,index=target.unsqueeze(1))
+        #g = torch.abs(pred.softmax(2).detach() - target)
+
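+        # bucket pixels by gradient norm g and re-weight each bucket inversely to its (EMA-smoothed) population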
+        tot = 1.0
+        n = 0  # n valid bins
+        for i in range(self.bins):
+            inds = (g >= edges[i]) & (g < edges[i+1])
+            num_in_bin = inds.sum().item()
+            if num_in_bin > 0:
+                idx = torch.nonzero(inds)[:, 0]
+                if mmt > 0:
+                    self.acc_sum[i] = mmt * self.acc_sum[i] \
+                        + (1 - mmt) * num_in_bin
+                    # pdb.set_trace()#scatter_ index_put_
+                    #BB=torch.nonzero(inds)
+                    _weight_idx = tot / self.acc_sum[i]
+                    weights = weights.to(dtype=_weight_idx.dtype)
+                    weights[idx] = _weight_idx
+                    # weights.scatter_(0, torch.nonzero(inds)[:,0], tot / self.acc_sum[i])
+                    # # weights.index_put_(inds, tot / self.acc_sum[i])
+                    # weights[inds] = tot / self.acc_sum[i] # * torch.ones((len(inds)))
+                else:
+                    weights[idx] = tot / num_in_bin
+                n += 1
+        if n > 0:
+            weights = weights / n
+
+            # pdb.set_trace()
+            # loss = (weights * F.cross_entropy(pred, target, reduction='none')).sum() / tot / pred.shape[0]
+        if self.norm:
+            weights = weights / torch.sum(weights).detach()
+
+        loss = - ((weights.unsqueeze(1) * torch.gather(F.log_softmax(pred, dim=1), dim=1, index=target.unsqueeze(1))).sum() )  # / pred.shape[0]
+
+        # loss3= F.cross_entropy(pred, target, reduction='mean')
+        # loss4 = - ((torch.gather(F.log_softmax(pred, dim=1), dim=1, index=target.unsqueeze(1))).sum() / pred.shape[0])
+
+        # pro = F.softmax(logits, dim=1)
+        #
+        # label_onehot = torch.zeros_like(logits).scatter_(1, labels.unsqueeze(1), 1)
+        # with torch.no_grad():
+        #     weight_matrix = (1 - pro) ** self.gamma
+        # # pdb.set_trace()
+        # fl = - (weight_matrix * (label_onehot * (pro + self.eps).log())).sum() / pro.shape[0]
+
+        return loss
+
+if __name__ == '__main__':
+    pred = torch.randn(2, 3, 1024, 1024)
+    gt = torch.argmax(torch.randn(2, 3, 1024, 1024), dim=1)
+    loss = normalized_focal_loss(pred, gt)
+    print(loss)
+    
+
+
diff --git a/modeling/decoder/__init__.py b/modeling/decoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cd2a43e5205320ded836374f6c823a90714c30c
--- /dev/null
+++ b/modeling/decoder/__init__.py
@@ -0,0 +1 @@
+from .detail_capture import Detail_Capture, Ori_Detail_Capture
\ No newline at end of file
diff --git a/modeling/decoder/__pycache__/__init__.cpython-38.pyc b/modeling/decoder/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f3769ef0cda17fb5f15cef2f9a9d3deb0799ac15
Binary files /dev/null and b/modeling/decoder/__pycache__/__init__.cpython-38.pyc differ
diff --git a/modeling/decoder/__pycache__/detail_capture.cpython-38.pyc b/modeling/decoder/__pycache__/detail_capture.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dd1eb6bfc6209e706f0a4a9a24c627ba9ad8ea70
Binary files /dev/null and b/modeling/decoder/__pycache__/detail_capture.cpython-38.pyc differ
diff --git a/modeling/decoder/__pycache__/unet_detail_capture.cpython-38.pyc b/modeling/decoder/__pycache__/unet_detail_capture.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..691371d22ddc50efe3f802294ec7373379b4ad3c
Binary files /dev/null and b/modeling/decoder/__pycache__/unet_detail_capture.cpython-38.pyc differ
diff --git a/modeling/decoder/detail_capture.py b/modeling/decoder/detail_capture.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd8b1d68433473a5d787de3ba53efa35ff9bfbcc
--- /dev/null
+++ b/modeling/decoder/detail_capture.py
@@ -0,0 +1,185 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+class Basic_Conv3x3(nn.Module):
+    """
+    Basic convolution layers including: Conv3x3, BatchNorm2d, ReLU layers.
+    """
+    def __init__(
+        self,
+        in_chans,
+        out_chans,
+        stride=2,
+        padding=1,
+    ):
+        super().__init__()
+        self.conv = nn.Conv2d(in_chans, out_chans, 3, stride, padding, bias=False)
+        self.bn = nn.BatchNorm2d(out_chans)
+        self.relu = nn.ReLU(True)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.relu(x)
+
+        return x
+
+class ConvStream(nn.Module):
+    """
+    Simple ConvStream containing a series of basic conv3x3 layers to extract detail features.
+    """
+    def __init__(
+        self,
+        in_chans = 4,
+        out_chans = [48, 96, 192],
+    ):
+        super().__init__()
+        self.convs = nn.ModuleList()
+        
+        self.conv_chans = out_chans.copy()
+        self.conv_chans.insert(0, in_chans)
+        
+        for i in range(len(self.conv_chans)-1):
+            in_chan_ = self.conv_chans[i]
+            out_chan_ = self.conv_chans[i+1]
+            self.convs.append(
+                Basic_Conv3x3(in_chan_, out_chan_)
+            )
+    
+    def forward(self, x):
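+        # D0 is the raw input; each stride-2 Basic_Conv3x3 halves the resolution, giving D1, D2, ... at 1/2, 1/4, ... scale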
+        out_dict = {'D0': x}
+        for i in range(len(self.convs)):
+            x = self.convs[i](x)
+            name_ = 'D'+str(i+1)
+            out_dict[name_] = x
+        
+        return out_dict
+
+class Fusion_Block(nn.Module):
+    """
+    Simple fusion block to fuse feature from ConvStream and Plain Vision Transformer.
+    """
+    def __init__(
+        self,
+        in_chans,
+        out_chans,
+    ):
+        super().__init__()
+        self.conv = Basic_Conv3x3(in_chans, out_chans, stride=1, padding=1)
+
+    def forward(self, x, D):
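+        # upsample the coarse feature 2x, concatenate with the matching-resolution detail feature D, and fuse with a 3x3 conv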
+        F_up = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
+        out = torch.cat([D, F_up], dim=1)
+        out = self.conv(out)
+
+        return out    
+
+class Matting_Head(nn.Module):
+    """
+    Simple Matting Head, containing only conv3x3 and conv1x1 layers.
+    """
+    def __init__(
+        self,
+        in_chans = 32,
+        mid_chans = 16,
+    ):
+        super().__init__()
+        self.matting_convs = nn.Sequential(
+            nn.Conv2d(in_chans, mid_chans, 3, 1, 1),
+            nn.BatchNorm2d(mid_chans),
+            nn.ReLU(True),
+            nn.Conv2d(mid_chans, 1, 1, 1, 0)
+            )
+
+    def forward(self, x):
+        x = self.matting_convs(x)
+
+        return x
+
+class Detail_Capture(nn.Module):
+    """
+    Simple and Lightweight Detail Capture Module for ViT Matting.
+    """
+    def __init__(
+        self,
+        in_chans = [384, 1],
+        img_chans=4,
+        convstream_out = [48, 96, 192],
+        fusion_out = [256, 128, 64, 32],
+    ):
+        super().__init__()
+        assert len(fusion_out) == len(convstream_out) + 1
+
+        self.convstream = ConvStream(in_chans=img_chans, out_chans=convstream_out)
+        self.conv_chans = self.convstream.conv_chans  # [4, 48, 96, 192]
+
+        self.fusion_blks = nn.ModuleList()
+        self.fus_channs = fusion_out.copy()
+        self.fus_channs.insert(0, in_chans[0])  # [384, 256, 128, 64, 32]
+        for i in range(len(self.fus_channs)-1):
+            in_channels = self.fus_channs[i] + self.conv_chans[-(i+1)] if i != 2 else in_chans[1] + self.conv_chans[-(i+1)]  # [384 + 192 = 576, 256 + 96 = 352, 1 + 48 = 49, 64 + 4 = 68]
+            out_channels = self.fus_channs[i+1]  # [256, 128, 64, 32]
+            self.fusion_blks.append(
+                Fusion_Block(
+                    in_chans = in_channels,
+                    out_chans = out_channels,
+                )
+            )
+
+        self.matting_head = Matting_Head(  # 32 --> 1
+            in_chans = fusion_out[-1],  
+        )
+
+    def forward(self, features, images):
+        detail_features = self.convstream(images)  # [1, 4, 672, 992] --> D0: [1, 4, 672, 992], D1: [1, 48, 336, 496], D2: [1, 96, 168, 248], D3: [1, 192, 84, 124]
+        for i in range(len(self.fusion_blks)):  # D3 
+            d_name_ = 'D'+str(len(self.fusion_blks)-i-1)
+            features = self.fusion_blks[i](features, detail_features[d_name_])
+        
+        phas = torch.sigmoid(self.matting_head(features))
+
+        return {'phas': phas}
+
+
+class Ori_Detail_Capture(nn.Module):
+    """
+    Simple and Lightweight Detail Capture Module for ViT Matting.
+    """
+    def __init__(
+        self,
+        in_chans = 384,
+        img_chans=4,
+        convstream_out = [48, 96, 192],
+        fusion_out = [256, 128, 64, 32],
+    ):
+        super().__init__()
+        assert len(fusion_out) == len(convstream_out) + 1
+
+        self.convstream = ConvStream(in_chans = img_chans)
+        self.conv_chans = self.convstream.conv_chans
+
+        self.fusion_blks = nn.ModuleList()
+        self.fus_channs = fusion_out.copy()
+        self.fus_channs.insert(0, in_chans)
+        for i in range(len(self.fus_channs)-1):
+            self.fusion_blks.append(
+                Fusion_Block(
+                    in_chans = self.fus_channs[i] + self.conv_chans[-(i+1)],
+                    out_chans = self.fus_channs[i+1],
+                )
+            )
+
+        self.matting_head = Matting_Head(
+            in_chans = fusion_out[-1],
+        )
+
+    def forward(self, features, images):
+        detail_features = self.convstream(images)
+        for i in range(len(self.fusion_blks)):
+            d_name_ = 'D'+str(len(self.fusion_blks)-i-1)
+            features = self.fusion_blks[i](features, detail_features[d_name_])
+        
+        phas = torch.sigmoid(self.matting_head(features))
+
+        return {'phas': phas}
diff --git a/modeling/decoder/unet_detail_capture.py b/modeling/decoder/unet_detail_capture.py
new file mode 100644
index 0000000000000000000000000000000000000000..79d66ea3f8cd074881c74e8e5452da1e67de64d0
--- /dev/null
+++ b/modeling/decoder/unet_detail_capture.py
@@ -0,0 +1,429 @@
+import cv2
+import torch
+from torch import nn
+from torch.nn import functional as F
+try:
+    # nnMorpho is optional; `erosion` is only needed when GenTrimapTorch is actually used (with_trimap=True)
+    from nnMorpho.binary_operators import erosion
+except ImportError:
+    erosion = None
+from detectron2.layers.batch_norm import NaiveSyncBatchNorm
+
+
+class GenTrimapTorch(object):
+    def __init__(self, max_kernal=200):
+        self.max_kernal = max_kernal
+        self.erosion_kernels = [None] + [torch.from_numpy(cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))).float().cuda() for size in range(1, self.max_kernal)]
+
+    def __call__(self, mask, kernel_size):
+        
+        fg_width = kernel_size
+        bg_width = kernel_size
+
+        fg_mask = mask
+        bg_mask = 1 - mask
+
+        fg_mask = erosion(fg_mask, self.erosion_kernels[fg_width], border='a')
+        bg_mask = erosion(bg_mask, self.erosion_kernels[bg_width], border='a')
+
+        trimap = torch.ones_like(mask) * 0.5
+        trimap[fg_mask == 1] = 1.0
+        trimap[bg_mask == 1] = 0.0
+
+        return trimap
+
+
+class LayerNorm2d(nn.Module):
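+    """Channel-wise LayerNorm for NCHW feature maps (normalizes over the channel dimension only)."""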
+    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(num_channels))
+        self.bias = nn.Parameter(torch.zeros(num_channels))
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        u = x.mean(1, keepdim=True)
+        s = (x - u).pow(2).mean(1, keepdim=True)
+        x = (x - u) / torch.sqrt(s + self.eps)
+        x = self.weight[:, None, None] * x + self.bias[:, None, None]
+        return x
+
+
+class BasicDownBlock(nn.Module):
+    def __init__(self, in_channel, out_channel, res = True, norm=LayerNorm2d, block_num=1, kernel_size=3):
+        super().__init__()
+
+        self.res = res
+        self.basic_layer = nn.ModuleList()
+        for i in range(block_num):
+            if i == 0:
+                basic_layer_in_ch = in_channel
+                stride = 2
+            else:
+                basic_layer_in_ch = out_channel
+                stride = 1
+                self.basic_layer.append(nn.GELU())
+            self.basic_layer.append(nn.Sequential(
+                nn.Conv2d(basic_layer_in_ch, out_channel, kernel_size, stride, kernel_size // 2), 
+                norm(out_channel),
+                nn.GELU(),
+                nn.Conv2d(out_channel, out_channel, kernel_size, 1, kernel_size // 2), 
+                norm(out_channel),
+            ))
+        self.act = nn.GELU()
+
+        if self.res:
+            self.res_layer = nn.Conv2d(in_channel, out_channel, kernel_size, 2, kernel_size // 2)
+
+    def forward(self, x):
+
+        out = x
+        for layer in self.basic_layer:
+            out = layer(out)
+
+        if self.res:
+            identity = self.res_layer(x)
+        else:
+            # `out` is computed first so its spatial size is known for the parameter-free shortcut
+            identity = F.interpolate(x, size=(out.shape[-2], out.shape[-1]), mode='bilinear', align_corners=False)
+
+        out = out + identity
+        out = self.act(out)
+
+        return out
+
+
+class BasicUpBlock(nn.Module):
+
+    def __init__( self, in_channel, out_channel, res = True, skip_connect = 'concat', norm=LayerNorm2d, block_num=1, kernel_size=3):
+        super().__init__()
+        assert skip_connect in {'sum', 'concat'}
+
+        self.res = res
+        self.skip_connect = skip_connect
+        self.basic_layer = nn.ModuleList()
+        for i in range(block_num):
+            if i == 0:
+                basic_layer_in_ch = in_channel
+                first_conv = nn.ConvTranspose2d(basic_layer_in_ch, out_channel, 2, 2)
+            else:
+                basic_layer_in_ch = out_channel
+                first_conv = nn.Conv2d(out_channel, out_channel, kernel_size, 1, kernel_size // 2)
+                self.basic_layer.append(nn.GELU())
+            self.basic_layer.append(nn.Sequential(
+                first_conv, 
+                norm(out_channel),
+                nn.GELU(),
+                nn.Conv2d(out_channel, out_channel, kernel_size, 1, kernel_size // 2), 
+                norm(out_channel),
+            ))
+        self.act = nn.GELU()
+
+        if self.res:
+            self.res_layer = nn.Conv2d(in_channel, out_channel, kernel_size, 1, kernel_size // 2)
+
+
+    def forward(self, x, skip_feat, concat_feat=None):
+
+        if self.skip_connect == 'sum':
+            x = x + skip_feat
+        else:
+            x = torch.concat((x, skip_feat), dim=1)
+
+        if concat_feat is not None:
+            x = torch.concat((x, concat_feat), dim=1)
+
+        out = x
+        for layer in self.basic_layer:
+            out = layer(out)
+        # out = self.basic_layer(x)
+        
+        identity = F.interpolate(x, size=(out.shape[-2], out.shape[-1]), mode='bilinear', align_corners=False)
+        if self.res:
+            identity = self.res_layer(identity)
+
+        out = out + identity
+        out = self.act(out)
+
+        return out
+    
+
+
+class DetailUNet(nn.Module):
+    def __init__(
+        self,
+        img_feat_in = 4,
+        vit_early_feat_in = 768,
+        matting_feat_in = 5,
+        downsample_in_out = [(4, 32), (32, 64), (64, 128), (128, 256)],
+        upsample_in_out = [(256, 128), (128, 64), (64, 32), (32, 16)],
+        matting_head_in = 16,
+        skip_connect = 'sum',
+        norm_type = 'LN',
+    ):
+        super().__init__()
+
+        assert len(downsample_in_out) == len(upsample_in_out)
+        downsample_in_out[0] = (img_feat_in, downsample_in_out[0][1])
+
+        assert norm_type in {'BN', 'LN', 'SyncBN'}
+        if norm_type == 'BN':
+            self.norm = torch.nn.BatchNorm2d
+        elif norm_type == 'SyncBN':
+            self.norm = NaiveSyncBatchNorm
+        else:
+            self.norm = LayerNorm2d
+
+        self.down_blks = nn.ModuleList()
+        for in_ch, out_ch in downsample_in_out:
+            self.down_blks.append(
+                BasicDownBlock(in_ch, out_ch, norm=self.norm)
+            )
+        
+        self.mid_layer = nn.Sequential(
+            nn.Conv2d(vit_early_feat_in, downsample_in_out[-1][1], 1, 1), 
+            self.norm(downsample_in_out[-1][1]),
+            nn.GELU(),
+        )
+
+        self.up_blks = nn.ModuleList()
+        for i, (in_ch, out_ch) in enumerate(upsample_in_out):
+            if i == 2:
+                in_ch += matting_feat_in
+            self.up_blks.append(
+                BasicUpBlock(in_ch, out_ch, skip_connect=skip_connect, norm=self.norm)
+            )
+
+        self.matting_head = nn.Conv2d(matting_head_in, 1, 3, 1, 1)
+
+
+    def forward(self, x, vit_early_feat, matting_feat, return_alpha_logits=False):
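+        # encode the image with the stride-2 down blocks, then decode from the early ViT feature; matting_feat is injected at the third up block (i == 2)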
+        details = []
+        dfeatures = x
+
+        for i in range(len(self.down_blks)):
+            dfeatures = self.down_blks[i](dfeatures)
+            details.append(dfeatures)
+
+        out = self.mid_layer(vit_early_feat)
+        for i in range(len(self.up_blks)):
+            if i == 2:
+                out = self.up_blks[i](out, details[-i - 1], matting_feat)
+            else:
+                out = self.up_blks[i](out, details[-i - 1])
+        alpha = self.matting_head(out)
+        if return_alpha_logits:
+            return alpha, out
+        else:
+            return alpha
+    
+
+class MattingDetailDecoder(nn.Module):
+    def __init__(
+        self,
+        img_feat_in = 4,
+        vit_intern_feat_in = 1024,
+        vit_intern_feat_index = [0, 1, 2, 3],
+        downsample_in_out = [(4, 32), (32, 64), (64, 128), (128, 256)],
+        upsample_in_out = [(256, 128), (128, 64), (64, 32), (32, 16)],
+        matting_head_in = 16,
+        skip_connect = 'sum',
+        norm_type = 'BN',
+        norm_mask_logits = 6.5,
+        with_trimap = False,
+        min_kernel_size = 20,
+        kernel_div = 10,
+        concat_gen_trimap = False,
+        wo_hq_features = False,
+        block_num = 1,
+        wo_big_kernel = False,
+        sam2_multi_scale_feates = False,
+    ):
+        super().__init__()
+
+        assert len(downsample_in_out) == len(upsample_in_out)
+        assert skip_connect in {'sum', 'concat'}
+        downsample_in_out[0] = (img_feat_in, downsample_in_out[0][1])
+        
+        self.vit_intern_feat_in = vit_intern_feat_in
+        self.vit_intern_feat_index = vit_intern_feat_index
+        self.norm_mask_logits = norm_mask_logits
+        self.with_trimap = with_trimap
+        self.min_kernel_size = min_kernel_size
+        self.kernel_div = kernel_div
+        self.concat_gen_trimap = concat_gen_trimap
+        self.wo_hq_features = wo_hq_features
+        self.block_num = block_num
+        self.wo_big_kernel = wo_big_kernel
+        self.sam2_multi_scale_feates = sam2_multi_scale_feates
+        if self.sam2_multi_scale_feates:
+            assert downsample_in_out[0][0] == 6
+            downsample_in_out = [(4, 32), (32, 64), (64 + 32, 128), (128 + 64, 256)]
+            upsample_in_out = [(256, 128), (128, 64), (64, 32), (32, 16)]
+
+        if self.with_trimap and not self.concat_gen_trimap:
+            self.gen_trimap = GenTrimapTorch()
+        assert norm_type in {'BN', 'LN', 'SyncBN'}
+        if norm_type == 'BN':
+            self.norm = torch.nn.BatchNorm2d
+        elif norm_type == 'SyncBN':
+            self.norm = NaiveSyncBatchNorm
+        else:
+            self.norm = LayerNorm2d
+
+        if self.block_num >= 2 and not self.wo_big_kernel:
+            self.big_kernel_process = nn.Sequential(
+                nn.Conv2d(img_feat_in, 16, kernel_size=13, stride=1, padding=6), 
+                self.norm(16),
+                nn.GELU(),
+                nn.Conv2d(16, 32, kernel_size=13, stride=1, padding=6), 
+                self.norm(32),
+                nn.GELU(),
+            )
+            downsample_in_out[0] = (32, downsample_in_out[0][1])
+
+        if not self.sam2_multi_scale_feates:
+            self.vit_feat_proj = nn.ModuleDict()
+            for idx in self.vit_intern_feat_index:
+                self.vit_feat_proj[str(idx)] = nn.Conv2d(self.vit_intern_feat_in, self.vit_intern_feat_in // len(self.vit_intern_feat_index), 1, 1)
+        self.vit_feat_aggregation = nn.Sequential(
+            nn.Conv2d(self.vit_intern_feat_in // len(self.vit_intern_feat_index) * len(self.vit_intern_feat_index), downsample_in_out[-1][1], 3, 1, 1), 
+            self.norm(downsample_in_out[-1][1]),
+            nn.GELU(),
+        )
+
+        self.down_blks = nn.ModuleList()
+        for in_ch, out_ch in downsample_in_out:
+            self.down_blks.append(
+                BasicDownBlock(in_ch, out_ch, norm=self.norm, block_num=self.block_num, kernel_size=5 if self.block_num >= 2 else 3)
+            )
+        
+        if self.sam2_multi_scale_feates:
+            self.mid_layer = nn.ModuleList([
+                nn.Sequential(
+                    nn.Conv2d(32, 32, 1, 1), 
+                    self.norm(32),
+                    nn.GELU(),
+                ),
+                nn.Sequential(
+                    nn.Conv2d(64, 64, 1, 1), 
+                    self.norm(64),
+                    nn.GELU(),
+                ),
+                nn.Sequential(
+                    nn.Conv2d(256, 256, 1, 1), 
+                    self.norm(256),
+                    nn.GELU(),
+                ),
+                nn.Sequential(
+                    nn.Conv2d(512, 256, 3, 1, 1), 
+                    self.norm(256),
+                    nn.GELU(),
+                ),
+            ])
+        else:
+            self.mid_layer = nn.Sequential(
+                nn.Conv2d(downsample_in_out[-1][1] * 2, downsample_in_out[-1][1], 1, 1), 
+                self.norm(downsample_in_out[-1][1]),
+                nn.GELU(),
+            )
+
+        self.up_blks = nn.ModuleList()
+        for _, (in_ch, out_ch) in enumerate(upsample_in_out):
+            if skip_connect == 'concat':
+                self.up_blks.append(BasicUpBlock(in_ch * 2, out_ch, skip_connect=skip_connect, norm=self.norm, block_num=self.block_num))
+            else:
+                self.up_blks.append(BasicUpBlock(in_ch, out_ch, skip_connect=skip_connect, norm=self.norm, block_num=self.block_num))
+
+        self.matting_head = nn.Conv2d(matting_head_in, 1, 3, 1, 1)
+
+        if self.norm_mask_logits == 'BN':
+            self.logits_norm = self.norm(1)
+
+
+    def preprocess_inputs(self, images, hq_features, pred_trimap):
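+        # hq_features are SAM mask logits; they are normalized (fixed divisor, BN, sigmoid/softmax, or per-image std)
+        # and optionally extended with a predicted or generated trimap before being concatenated to the image.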
+
+        if self.wo_hq_features:
+            return images
+
+        if isinstance(self.norm_mask_logits, float):
+            norm_hq_features = hq_features / self.norm_mask_logits
+        elif self.norm_mask_logits == 'BN':
+            norm_hq_features = self.logits_norm(hq_features)
+        elif self.norm_mask_logits == 'Sigmoid':
+            if hq_features.shape[1] == 1:
+                norm_hq_features = torch.sigmoid(hq_features)
+            else:
+                norm_hq_features = torch.softmax(hq_features, dim=1)
+        elif self.norm_mask_logits:
+            norm_hq_features = hq_features / torch.std(hq_features, dim=(1, 2, 3), keepdim=True)
+        else:
+            norm_hq_features = hq_features
+
+        if self.concat_gen_trimap:
+            pred_trimap = F.interpolate(pred_trimap, size=(images.shape[-2], images.shape[-1]), mode='bilinear', align_corners=False)
+            pred_trimap = torch.argmax(pred_trimap, dim=1, keepdim=True).float() / 2.0
+            norm_hq_features = torch.concat((norm_hq_features, pred_trimap.detach()), dim=1)
+        elif self.with_trimap:
+            mask = (norm_hq_features > 0).float()
+            for i_batch in range(images.shape[0]):
+                mask_area = torch.sum(mask[i_batch])
+                kernel_size = max(self.min_kernel_size, int((mask_area ** 0.5) / self.kernel_div))
+                kernel_size = min(kernel_size, self.gen_trimap.max_kernal - 1)
+                mask[i_batch, 0] = self.gen_trimap(mask[i_batch, 0], kernel_size=kernel_size)
+            trimaps = mask
+            norm_hq_features = torch.concat((norm_hq_features, trimaps), dim=1)
+
+        conditional_images = torch.concatenate((images, norm_hq_features), dim=1)
+        return conditional_images
+
+    def forward(self, images, hq_features, vit_intern_feat, return_alpha_logits=False, pred_trimap=None):
+        
+        condition_input = self.preprocess_inputs(images, hq_features, pred_trimap)
+
+        if not self.sam2_multi_scale_feates:
+            # aggregate 4 vit_intern_feat
+            # assert len(vit_intern_feat) == self.vit_intern_feat_num
+            vit_feats = []
+            for idx in self.vit_intern_feat_index:
+                vit_feats.append(self.vit_feat_proj[str(idx)](vit_intern_feat[idx].permute(0, 3, 1, 2)))
+            vit_feats = torch.concat(vit_feats, dim=1)
+            vit_aggregation_feats = self.vit_feat_aggregation(vit_feats)
+
+        details = []
+        dfeatures = condition_input
+
+        if hasattr(self, 'big_kernel_process'):
+            dfeatures = self.big_kernel_process(dfeatures)
+
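+        # with sam2_multi_scale_feates, the two SAM2 high-res feature maps are concatenated into the encoder at down stages 2 and 3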
+        for i in range(len(self.down_blks)):
+            if self.sam2_multi_scale_feates:
+                if i == 2:
+                    dfeatures = torch.concat((dfeatures, self.mid_layer[0](vit_intern_feat['high_res_feats'][0])), dim=1)
+                elif i == 3:
+                    dfeatures = torch.concat((dfeatures, self.mid_layer[1](vit_intern_feat['high_res_feats'][1])), dim=1)
+            dfeatures = self.down_blks[i](dfeatures)
+            details.append(dfeatures)
+
+        if self.sam2_multi_scale_feates:
+            out = torch.concat((details[-1], self.mid_layer[2](vit_intern_feat['image_embed'])), dim=1)
+            out = self.mid_layer[3](out)
+        else:
+            out = self.mid_layer(torch.concat((details[-1], vit_aggregation_feats), dim=1))
+        for i in range(len(self.up_blks)):
+            out = self.up_blks[i](out, details[-i - 1])
+        alpha = torch.sigmoid(self.matting_head(out))
+        if return_alpha_logits:
+            return alpha, out
+        else:
+            return alpha
+
+
+
+if __name__ == '__main__':
+
+    from engine.mattingtrainer import parameter_count_table
+
+    model = MattingDetailDecoder(img_feat_in = 5, vit_intern_feat_index=[0])
+    x = torch.randn((2, 5, 1024, 1024))
+    hq_features = torch.randn((2, 1, 1024, 1024))
+    vit_feat = [torch.randn((2, 64, 64, 1024)) for _ in range(4)]
+
+    out = model(x, hq_features, vit_feat)
+    print(out.shape)
+
+    print("Trainable parameters: \n" + parameter_count_table(model, trainable_only=True, max_depth=5))
diff --git a/modeling/meta_arch/__init__.py b/modeling/meta_arch/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..87aa29eb1547d347cd6e41fd1d831c318908be47
--- /dev/null
+++ b/modeling/meta_arch/__init__.py
@@ -0,0 +1 @@
+from .sam_hq_matting import SamHqMatte
\ No newline at end of file
diff --git a/modeling/meta_arch/__pycache__/__init__.cpython-38.pyc b/modeling/meta_arch/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bc7fd1eb6507a9c9f894d49888d8353b22ca0c5a
Binary files /dev/null and b/modeling/meta_arch/__pycache__/__init__.cpython-38.pyc differ
diff --git a/modeling/meta_arch/__pycache__/sam_hq_matting.cpython-38.pyc b/modeling/meta_arch/__pycache__/sam_hq_matting.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..155e8015bd75aa6a8740e168b880463db3b9fb5b
Binary files /dev/null and b/modeling/meta_arch/__pycache__/sam_hq_matting.cpython-38.pyc differ
diff --git a/modeling/meta_arch/sam_hq_matting.py b/modeling/meta_arch/sam_hq_matting.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4a7bca52d5c3f5ab16201dd9e1453f8650b96f0
--- /dev/null
+++ b/modeling/meta_arch/sam_hq_matting.py
@@ -0,0 +1,671 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+import os
+import numpy as np
+from PIL import Image
+from copy import deepcopy
+from collections import defaultdict
+
+from detectron2.structures import ImageList
+from detectron2.utils.comm import get_local_rank
+from modeling.semantic_enhanced_matting.predictor import SamPredictor
+from modeling.semantic_enhanced_matting.condition_conv import ConditionConv, ConditionEmbedding, ConditionAdd, BBoxEmbedInteract, BBoxInteract, BBoxInteractInOut
+from modeling.semantic_enhanced_matting.modeling.image_encoder import PatchEmbed
+from modeling.semantic_enhanced_matting.modeling.common import LayerNorm2d
+from modeling.decoder.unet_detail_capture import MattingDetailDecoder
+from modeling.semantic_enhanced_matting.feature_fusion import FeatureFusion
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+from modeling.semantic_enhanced_matting.modeling.mask_decoder_hq_matting import MaskDecoderHQMatting
+from modeling.semantic_enhanced_matting.modeling import TwoWayTransformer
+
+from peft import LoraConfig, get_peft_model
+from peft.tuners.lora.layer import LoraLayer
+from peft.tuners.tuners_utils import BaseTunerLayer
+
+from data.rand_augment import RandAugment
+import random
+import kornia.filters as kf
+
+
+class SamHqMatte(nn.Module):
+
+    target_length = 1024
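+    # SAM-style backbones expect 1024 x 1024 inputs; images and box prompts are resized to this side length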
+
+    def __init__(
+        self,
+        *,
+        sam_model,
+        hq_token_only,
+        hq_features_type,
+        matting_decoder,
+        criterion,
+        pixel_mean,
+        pixel_std,
+        multimask_output=False,
+        vis_period=None,
+        output_dir=None,
+        lora_rank = None,
+        lora_alpha = None,
+        lora_target_modules = ["qkv", "proj"],
+        lora_dropout = 0.1,
+        w_dora = False,
+        w_rslora = False,
+        lora_on_mask_decoder = False,
+        frozen_sam_hq_reg = None,
+        reg_margin = 0.85,
+        w_attention_mask = False,
+        alpha_reg_range = None,
+        alpha_reg_weight = 1.0,
+        coconut_pl = False,
+        coconut_pl_alpha = 1.0,
+        coconut_self_training = False,
+        eval_w_sam_hq_mask = False,
+        backbone_condition = False,
+        condition_wo_conv = False,
+        w_only_bbox_cond = False,
+        coconut_only_known_l1 = False,
+        backbone_bbox_prompt = None,
+        backbone_bbox_prompt_loc = [2, 3], 
+        backbone_bbox_prompt_loss_weight = 1.0,
+        concat_gen_trimap = False,
+        multi_matting_decoder = None,
+        w_all_logits = False,
+        bbox_prompt_all_block = None,
+        matting_token = False,
+        test_w_hq_token = False,
+        sam_hq_token_reg = None,
+        feat_cross_attn_fusion = False,
+        trimap_loss_type = None,
+        reg_on_sam_logits = False,
+        reg_w_bce_loss = False,
+        complex_trimap_pred_layer = False,
+        matting_token_sup = None,
+        matting_token_sup_loss_weight = None,
+        sam2 = False,
+    ):
+        super(SamHqMatte, self).__init__()
+
+        self.sam_model = sam_model
+        self.sam_predictor = SamPredictor(self.sam_model) if not sam2 else SAM2ImagePredictor(self.sam_model)  # already in eval mode and no_grad
+        self.hq_token_only = hq_token_only
+        self.multimask_output = multimask_output
+        self.hq_features_type = hq_features_type
+
+        self.matting_decoder = matting_decoder
+
+        self.criterion = criterion
+
+        self.register_buffer(
+            "pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False
+        )
+        self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False)
+        assert (
+            self.pixel_mean.shape == self.pixel_std.shape
+        ), f"{self.pixel_mean} and {self.pixel_std} have different shapes!"
+
+        self.vis_period = vis_period
+        if output_dir is not None and output_dir != '?':
+            self.output_dir = os.path.join(output_dir, 'vis_results')
+            os.makedirs(self.output_dir, exist_ok=True)
+        self.train_iter_index = 0
+
+        self.lora_rank = lora_rank
+        self.lora_alpha = lora_alpha
+        self.lora_target_modules = lora_target_modules
+        self.lora_dropout = lora_dropout
+        self.w_dora = w_dora
+        self.w_rslora = w_rslora
+        self.lora_on_mask_decoder = lora_on_mask_decoder
+        self.frozen_sam_hq_reg = frozen_sam_hq_reg
+        self.reg_margin = reg_margin
+        self.w_attention_mask = w_attention_mask
+        self.alpha_reg_range = alpha_reg_range
+        self.alpha_reg_weight = alpha_reg_weight
+        self.coconut_pl = coconut_pl
+        self.coconut_pl_alpha = coconut_pl_alpha
+        self.coconut_self_training = coconut_self_training
+        self.eval_w_sam_hq_mask = eval_w_sam_hq_mask
+        self.backbone_condition = backbone_condition
+        self.condition_wo_conv = condition_wo_conv
+        self.w_only_bbox_cond = w_only_bbox_cond
+        self.coconut_only_known_l1 = coconut_only_known_l1
+        self.backbone_bbox_prompt = backbone_bbox_prompt
+        self.backbone_bbox_prompt_loc = backbone_bbox_prompt_loc
+        self.backbone_bbox_prompt_loss_weight = backbone_bbox_prompt_loss_weight
+        self.concat_gen_trimap = concat_gen_trimap
+        self.multi_matting_decoder = multi_matting_decoder
+        self.w_all_logits = w_all_logits
+        self.bbox_prompt_all_block = bbox_prompt_all_block
+        self.matting_token = matting_token
+        self.test_w_hq_token = test_w_hq_token
+        self.sam_hq_token_reg = sam_hq_token_reg
+        self.feat_cross_attn_fusion = feat_cross_attn_fusion
+        self.trimap_loss_type = trimap_loss_type
+        self.reg_on_sam_logits = reg_on_sam_logits
+        self.reg_w_bce_loss = reg_w_bce_loss
+        self.complex_trimap_pred_layer = complex_trimap_pred_layer
+        self.matting_token_sup = matting_token_sup
+        self.sam2 = sam2
+        assert self.matting_token_sup in {'alpha', 'trimap', None}
+        self.matting_token_sup_loss_weight = matting_token_sup_loss_weight
+        if self.matting_token_sup is not None:
+            assert self.backbone_bbox_prompt in {'bbox', None}
+        if self.frozen_sam_hq_reg is not None:
+            assert self.lora_rank is not None
+        if self.w_attention_mask:
+            self.attention_head = deepcopy(self.matting_decoder)
+        if self.coconut_self_training:
+            self.rand_aug = RandAugment(3,6)
+            self.warm_iter_coconut_self_training = 5000
+        if self.backbone_condition:
+            assert self.lora_rank is not None
+        if self.backbone_bbox_prompt is not None:
+            assert self.lora_rank is not None
+        if self.w_all_logits:
+            self.sam_predictor.model.mask_decoder.w_all_logits = True
+        if self.bbox_prompt_all_block:
+            assert self.lora_rank is not None
+        if self.matting_token and not self.sam2:
+            self.sam_predictor.model.mask_decoder.hq_token_only = self.hq_token_only
+
+    @property
+    def device(self):
+        return self.pixel_mean.device
+
+    def init_lora(self, model=None):
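+        # If a wrapped model is given, apply LoRA to that model (optionally including its mask decoder); otherwise LoRA-wrap the SAM image encoder in place.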
+        if model is not None and self.lora_rank >= 1:
+            if self.lora_on_mask_decoder:
+                self.lora_target_modules += ["q_proj", "k_proj", "v_proj", "out_proj"]
+                modules_to_save = None
+            else:
+                modules_to_save = ['matting_decoder']
+
+            lora_config = LoraConfig(
+                r=self.lora_rank,
+                lora_alpha=self.lora_alpha,
+                use_rslora=self.w_rslora,
+                use_dora=self.w_dora,
+                init_lora_weights="gaussian",
+                target_modules=self.lora_target_modules,
+                lora_dropout=self.lora_dropout,
+                modules_to_save=modules_to_save
+            )
+            model = get_peft_model(model, lora_config)
+            if self.lora_on_mask_decoder:
+                for n, p in model.matting_decoder.named_parameters():
+                    if n.split('modules_to_save.default.')[-1] in model.matting_decoder.trainable_params_str:
+                        p.requires_grad = True
+            else:
+                for n, p in model.matting_decoder.named_parameters():
+                    if n.split('modules_to_save.default.')[-1] in model.matting_decoder.frozen_params_str:
+                        p.requires_grad = False
+            return model
+        elif self.lora_rank >= 1:
+            lora_config = LoraConfig(
+                r=self.lora_rank,
+                lora_alpha=self.lora_alpha,
+                use_rslora=self.w_rslora,
+                use_dora=self.w_dora,
+                init_lora_weights="gaussian",
+                target_modules=self.lora_target_modules,
+                lora_dropout=self.lora_dropout,
+            )
+            self.sam_predictor.model.image_encoder = get_peft_model(self.sam_predictor.model.image_encoder, lora_config)
+
+            if self.sam2:
+                for n, p in self.sam_predictor.model.image_encoder.named_parameters():
+                    if 'bbox_mask' in n:
+                        p.requires_grad = True
+
+        if self.backbone_condition:
+            if self.w_only_bbox_cond:
+                self.condition_embedding = ConditionEmbedding(condition_num = 4, pos_embedding_dim = 160)
+            else:
+                self.condition_embedding = ConditionEmbedding(condition_num = 5, pos_embedding_dim = 128)
+
+            if self.condition_wo_conv:
+                self.condition_conv = nn.ModuleList([ConditionAdd() for _ in range(4)])
+            else:
+                self.condition_conv = nn.ModuleList([ConditionConv(
+                    in_channels = self.sam_predictor.model.image_encoder.embed_dim, 
+                    out_channels = self.sam_predictor.model.image_encoder.embed_dim,
+                    bottleneck_channels = 512
+                ) for _ in range(4)])
+        
+        if self.backbone_bbox_prompt is not None and not self.sam2:
+            self.condition_layer = nn.ModuleDict()
+            self.condition_layer['patch_embed'] =  PatchEmbed(
+                kernel_size=(self.sam_predictor.model.image_encoder.patch_size, self.sam_predictor.model.image_encoder.patch_size),
+                stride=(self.sam_predictor.model.image_encoder.patch_size, self.sam_predictor.model.image_encoder.patch_size),
+                in_chans=4,
+                embed_dim=self.sam_predictor.model.image_encoder.embed_dim,
+            )
+            if self.multi_matting_decoder is None:
+                if self.backbone_bbox_prompt in {'trimap', 'alpha_trimap'}:
+                    transformer_dim = self.sam_predictor.model.image_encoder.embed_dim
+                    for i in self.backbone_bbox_prompt_loc:
+                        if self.complex_trimap_pred_layer:
+                            self.condition_layer['{}_pred_layer'.format(i)] = nn.Sequential(
+                                nn.ConvTranspose2d(transformer_dim, transformer_dim // 2, kernel_size=2, stride=2),
+                                LayerNorm2d(transformer_dim // 2),  # 512
+                                nn.GELU(),
+                                nn.Conv2d(transformer_dim // 2, transformer_dim // 4, kernel_size=3, stride=1, padding=1),
+                                LayerNorm2d(transformer_dim // 4),  # 256
+                                nn.GELU(),
+                                nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
+                                LayerNorm2d(transformer_dim // 8),  # 128
+                                nn.GELU(),
+                                nn.Conv2d(transformer_dim // 8, transformer_dim // 16, kernel_size=3, stride=1, padding=1),
+                                LayerNorm2d(transformer_dim // 16),  # 64
+                                nn.GELU(),
+                                nn.Conv2d(transformer_dim // 16, 3, kernel_size=3, stride=1, padding=1),
+                            )
+                        else:
+                            self.condition_layer['{}_pred_layer'.format(i)] = nn.Sequential(
+                                nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
+                                LayerNorm2d(transformer_dim // 4),
+                                nn.GELU(),
+                                nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
+                                nn.GELU(),
+                                nn.Conv2d(transformer_dim // 8, 3, kernel_size=1, stride=1),
+                            )
+                elif self.backbone_bbox_prompt == 'alpha':
+                    transformer_dim = self.sam_predictor.model.image_encoder.embed_dim
+                    for i in self.backbone_bbox_prompt_loc:
+                        if self.complex_trimap_pred_layer:
+                            self.condition_layer['{}_pred_layer'.format(i)] = nn.Sequential(
+                                nn.ConvTranspose2d(transformer_dim, transformer_dim // 2, kernel_size=2, stride=2),
+                                LayerNorm2d(transformer_dim // 2),  # 512
+                                nn.GELU(),
+                                nn.Conv2d(transformer_dim // 2, transformer_dim // 4, kernel_size=3, stride=1, padding=1),
+                                LayerNorm2d(transformer_dim // 4),  # 256
+                                nn.GELU(),
+                                nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
+                                LayerNorm2d(transformer_dim // 8),  # 128
+                                nn.GELU(),
+                                nn.Conv2d(transformer_dim // 8, transformer_dim // 16, kernel_size=3, stride=1, padding=1),
+                                LayerNorm2d(transformer_dim // 16),  # 64
+                                nn.GELU(),
+                                nn.Conv2d(transformer_dim // 16, 1, kernel_size=3, stride=1, padding=1),
+                                nn.Sigmoid()
+                            )
+                        else:
+                            self.condition_layer['{}_pred_layer'.format(i)] = nn.Sequential(
+                                nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
+                                LayerNorm2d(transformer_dim // 4),
+                                nn.GELU(),
+                                nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
+                                nn.GELU(),
+                                nn.Conv2d(transformer_dim // 8, 1, kernel_size=1, stride=1),
+                                nn.Sigmoid()
+                            )
+            if self.bbox_prompt_all_block is not None:
+                if self.bbox_prompt_all_block == 'reuse_cross-self-attn':
+                    self.condition_layer['prompt_layer'] = BBoxInteract(
+                        position_point_embedding = deepcopy(self.sam_predictor.model.prompt_encoder.pe_layer), 
+                        point_weight = deepcopy(self.sam_predictor.model.prompt_encoder.point_embeddings)
+                    )
+                elif self.bbox_prompt_all_block == 'in-out-bbox_cross-self-attn':
+                    self.condition_layer['prompt_layer'] = BBoxInteractInOut(downsample_rate = 2)
+                else:
+                    embed_type, interact_type = self.bbox_prompt_all_block.split('_')
+                    self.condition_layer['prompt_layer'] = BBoxEmbedInteract(embed_type, interact_type)
+
+        if self.feat_cross_attn_fusion:
+            self.condition_layer['feature_fusion'] = FeatureFusion(in_channels=self.sam_predictor.model.image_encoder.embed_dim, attn_compression_ratio=8)
+
+    def condition_bbox_and_instance_num(self):
+        self.sam_predictor.model.image_encoder.conv_necks = None
+
+    def forward_samhq_and_matting_decoder(self, images, bbox, condition_proj=None, return_hq_token=False):
+        # get features from SAM image encoder
+        if self.sam2:
+            interm_features, sam2_logits, matting_logits, pred_trimap = self.forward_samhq(images, bbox, condition_proj)
+            sam2_logits = F.interpolate(sam2_logits, size=images.shape[-2:], mode='bilinear', align_corners=False)
+            matting_logits = F.interpolate(matting_logits, size=images.shape[-2:], mode='bilinear', align_corners=False)
+            sam_hq_matting_token = {
+                'masks_hq': sam2_logits,
+                'masks_matting': matting_logits
+            }
+            hq_features = matting_logits
+            low_res_masks = matting_logits
+        else:
+            if self.matting_token:
+                features, image_pe, sparse_embeddings, dense_embeddings, interm_features, sam_hq_matting_token, pred_trimap = self.forward_samhq(images, bbox, condition_proj)
+                if return_hq_token:
+                    return sam_hq_matting_token['masks_hq']
+                else:
+                    if not self.training and self.test_w_hq_token:
+                        low_res_masks, hq_features = sam_hq_matting_token['masks_hq'], sam_hq_matting_token['masks_hq']
+                    else:
+                        low_res_masks, hq_features = sam_hq_matting_token['masks_matting'], sam_hq_matting_token['masks_matting']
+            else:
+                features, image_pe, sparse_embeddings, dense_embeddings, interm_features, hq_features, sam_logits, low_res_masks, pred_trimap = self.forward_samhq(images, bbox, condition_proj)
+                if return_hq_token:
+                    return hq_features
+                sam_hq_matting_token = {'masks_hq': hq_features, 'masks_sam': sam_logits}
+
+        # get alpha from our proposed matting_decoder
+        if isinstance(self.matting_decoder, MattingDetailDecoder):
+            pred_alpha = self.matting_decoder(
+                images = images,
+                hq_features = hq_features,
+                vit_intern_feat = interm_features,
+                return_alpha_logits = (self.alpha_reg_range is not None),
+                pred_trimap = pred_trimap
+            )
+        else:
+            pred_alpha = self.matting_decoder(
+                image_embeddings = features,  # [B, 256, 64, 64]
+                image_pe = image_pe,
+                sparse_prompt_embeddings = sparse_embeddings,
+                dense_prompt_embeddings = dense_embeddings,
+                multimask_output = False,
+                interm_embeddings = interm_features,  # [B, 256, 64, 64]
+                hq_features = hq_features,
+                images = images,
+                return_alpha_logits = (self.alpha_reg_range is not None),
+                pred_trimap = pred_trimap
+            )
+        return low_res_masks, pred_alpha, pred_trimap, sam_hq_matting_token
+
+    def forward(self, batched_inputs):  # image: [1, 3, 643, 960]: 0.0~1.0, trimap: [1, 1, 643, 960]: 0.0~1.0
+
+        inputs = self.preprocess_inputs(batched_inputs) 
+        images, bbox, gt_alpha, trimap, condition = inputs['images'], inputs['bbox'], inputs['alpha'], inputs['trimap'], inputs['condition']
+
+        if self.backbone_condition:
+            condition_proj = self.condition_embedding(condition) 
+        elif self.backbone_bbox_prompt is not None or self.bbox_prompt_all_block is not None:
+            condition_proj = bbox
+        else:
+            condition_proj = None
+
+        low_res_masks, pred_alpha, pred_trimap, sam_hq_matting_token = self.forward_samhq_and_matting_decoder(images, bbox, condition_proj)
+        
+        assert not self.training
+        if self.eval_w_sam_hq_mask:
+            self.sam_predictor.model.image_encoder.disable_adapter_layers()
+            with torch.no_grad():
+                ori_features, ori_interm_features = self.sam_predictor.model.image_encoder(images)
+                samhq_low_res_masks = self.forward_samhq_others(images, bbox, ori_features, ori_interm_features)[-1]
+                samhq_low_res_masks = F.interpolate(samhq_low_res_masks, size=(images.shape[-2], images.shape[-1]), mode='bilinear', align_corners=False)
+            self.sam_predictor.model.image_encoder.enable_adapter_layers()
+
+            return pred_alpha, samhq_low_res_masks
+        else:
+            return pred_alpha
+        
+    def forward_samhq_image_encoder(self, images, condition_proj=None):
+        if self.sam2:
+            backbone_out = self.sam_predictor.model.forward_image([images, condition_proj])
+            _, vision_feats, _, _ = self.sam_predictor.model._prepare_backbone_features(backbone_out)
+            # Add no_mem_embed, which is added to the lowest-resolution feature map during training on videos
+            if self.sam_predictor.model.directly_add_no_mem_embed:
+                vision_feats[-1] = vision_feats[-1] + self.sam_predictor.model.no_mem_embed
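+            # fold the token sequences [HW, B, C] back into spatial maps [B, C, H, W];
+            # the last entry is the low-res image embedding, the rest are high-res features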
+            feats = [
+                feat.permute(1, 2, 0).view(feat.shape[1], -1, *feat_size)
+                for feat, feat_size in zip(vision_feats[::-1], self.sam_predictor._bb_feat_sizes[::-1])
+            ][::-1]
+            return {"image_embed": feats[-1], "high_res_feats": feats[:-1]}, None, None
+        else:
+            if self.backbone_condition:
+                condition_layer = self.condition_conv 
+            elif self.backbone_bbox_prompt:
+                condition_layer = self.condition_layer
+            else:
+                condition_layer = None
+            # [B, 3, 1024, 1024]: -2. ~ 2. --> [B, 256, 64, 64], 4 x [B, 64, 64, 768]
+            features, interm_features, pred_trimap = self.sam_predictor.model.image_encoder(images, condition_proj, condition_layer)
+            return features, interm_features, pred_trimap
+    
+    # @torch.no_grad()
+    def forward_samhq_others(self, images, bbox, features, interm_features):
+        if self.sam2:
+            sam2_logits, matting_logits = self.sam_predictor.predict_batch_boxes_and_features(bbox, features)
+            return features, sam2_logits, matting_logits
+        
+        image_pe = self.sam_predictor.model.prompt_encoder.get_dense_pe()
+
+        cat_sparse_embeddings = []
+        cat_dense_prompt_embeddings = []
+        cat_hq_features = []
+        cat_sam_logits = []
+        cat_low_res_masks = []
+        cat_sam_hq_matting_token = defaultdict(list)
+
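+        # prompts are encoded per sample, since each image may carry a different number of boxes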
+        for idx in range(images.shape[0]):
+            # get hq_features from SAM_HQ mask decoder
+
+            # Embed prompts
+            sparse_embeddings, dense_embeddings = self.sam_predictor.model.prompt_encoder(
+                points=None,
+                # boxes=bbox[idx: idx + 1],
+                boxes=bbox[idx],  # [N, 4]
+                masks=None,
+            )  # [B, 2, 256], [B, 256, 64, 64]
+
+            # Predict masks
+            if isinstance(self.sam_predictor.model.mask_decoder, MaskDecoderHQMatting):
+                sam_hq_matting_token = self.sam_predictor.model.mask_decoder(
+                    image_embeddings = features[idx: idx + 1],
+                    image_pe = image_pe,
+                    sparse_prompt_embeddings = sparse_embeddings,
+                    dense_prompt_embeddings = dense_embeddings,
+                    multimask_output = self.multimask_output,
+                    interm_embeddings = [interm_feature[idx: idx + 1] for interm_feature in interm_features],
+                )
+                for key in sam_hq_matting_token.keys():
+                    cat_sam_hq_matting_token[key].append(sam_hq_matting_token[key])
+            else:
+                low_res_masks, masks_sam, hq_features = self.sam_predictor.model.mask_decoder(
+                    image_embeddings = features[idx: idx + 1],
+                    image_pe = image_pe,
+                    sparse_prompt_embeddings = sparse_embeddings,
+                    dense_prompt_embeddings = dense_embeddings,
+                    multimask_output = self.multimask_output,
+                    hq_token_only = self.hq_token_only,
+                    interm_embeddings = [interm_feature[idx: idx + 1] for interm_feature in interm_features],
+                    return_hq_features_type = self.hq_features_type
+                )
+                cat_hq_features.append(hq_features)
+                cat_sam_logits.append(masks_sam)
+                cat_low_res_masks.append(low_res_masks)
+
+            cat_sparse_embeddings.append(sparse_embeddings)
+            cat_dense_prompt_embeddings.append(dense_embeddings)
+            
+        sparse_embeddings = torch.stack(cat_sparse_embeddings, dim=0)  # [B, 1, 2, 256]
+        dense_embeddings = torch.stack(cat_dense_prompt_embeddings, dim=0)  # [B, 1, 256, 64, 64]
+        
+        if self.matting_token:
+            for key in cat_sam_hq_matting_token.keys():
+                cat_sam_hq_matting_token[key] = torch.cat(cat_sam_hq_matting_token[key], dim=0)
+                cat_sam_hq_matting_token[key] = F.interpolate(cat_sam_hq_matting_token[key], size=images.shape[-2:], mode='bilinear', align_corners=False)
+            sam_hq_matting_token = cat_sam_hq_matting_token
+            return features, image_pe, sparse_embeddings, dense_embeddings, interm_features, sam_hq_matting_token
+        else:
+            hq_features = torch.cat(cat_hq_features, dim=0)  # [B, 1, 256, 256]
+            low_res_masks = torch.cat(cat_low_res_masks, dim=0)  # [B, 1, 256, 256]
+            hq_features = F.interpolate(hq_features, size=images.shape[-2:], mode='bilinear', align_corners=False)  # [B, 1, 256, 256] --> [B, 1, 1024, 1024]
+            sam_logits = torch.cat(cat_sam_logits, dim=0)
+            sam_logits = F.interpolate(sam_logits, size=images.shape[-2:], mode='bilinear', align_corners=False)  # [B, 1, 256, 256] --> [B, 1, 1024, 1024]
+            return features, image_pe, sparse_embeddings, dense_embeddings, interm_features, hq_features, sam_logits, low_res_masks
+
+    def forward_samhq(self, images, bbox, condition_proj=None):
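+        # without LoRA adapters the image encoder stays frozen, so skip gradient tracking to save memory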
+        if self.lora_rank is None:
+            with torch.no_grad():
+                features, interm_features, pred_trimap = self.forward_samhq_image_encoder(images, condition_proj)
+        else:
+            features, interm_features, pred_trimap = self.forward_samhq_image_encoder(images, condition_proj)
+
+        return self.forward_samhq_others(images, bbox, features, interm_features) + (pred_trimap, )
+
+    def get_frozen_sam_logits(self, images, bbox, mask_type='hq'):
+        
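+        # produce reference logits without the learned matting branch (SAM2 predictor or the frozen SAM-HQ decoder)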
+        if self.sam2:
+            features, _, _ = self.forward_samhq_image_encoder(images)
+            sam2_logits = self.sam_predictor.predict_batch_boxes_and_features(bbox, features, wo_matting_token=True)
+            sam2_logits = F.interpolate(sam2_logits, size=images.shape[-2:], mode='bilinear', align_corners=False)
+            return sam2_logits
+
+        assert mask_type in {'hq', 'sam'} 
+        features, interm_features, _ = self.forward_samhq_image_encoder(images)
+        image_pe = self.sam_predictor.model.prompt_encoder.get_dense_pe()
+
+        cat_logits = []
+        for idx in range(images.shape[0]):
+            sparse_embeddings, dense_embeddings = self.sam_predictor.model.prompt_encoder(points=None, boxes=bbox[idx], masks=None)
+
+            low_res_masks, masks_sam, hq_features = self.sam_predictor.model.frozen_mask_decoder(
+                image_embeddings = features[idx: idx + 1],
+                image_pe = image_pe,
+                sparse_prompt_embeddings = sparse_embeddings,
+                dense_prompt_embeddings = dense_embeddings,
+                multimask_output = self.multimask_output,
+                hq_token_only = self.hq_token_only,
+                interm_embeddings = [interm_feature[idx: idx + 1] for interm_feature in interm_features],
+                return_hq_features_type = self.hq_features_type
+            )
+            if mask_type == 'hq':
+                cat_logits.append(hq_features)  
+            else:
+                cat_logits.append(masks_sam)  
+        
+        logits = torch.cat(cat_logits, dim=0)  # [B, 1, 256, 256]
+        logits = F.interpolate(logits, size=images.shape[-2:], mode='bilinear', align_corners=False)  # [B, 1, 256, 256] --> [B, 1, 1024, 1024]
+        return logits
+
+    def vis_training_results(self, **kwargs):
+        # images, bbox, trimap, low_res_masks, pred_alpha, alpha
+        self.train_iter_index += 1
+        if self.train_iter_index % self.vis_period == 0:
+            batch_save_results = []
+            save_path = os.path.join(self.output_dir, '{:06d}_rank{}.jpg'.format(self.train_iter_index, get_local_rank()))
+            
+            # [('images', (4, 3, 1024, 1024), -2.117904, 2.64), ('bbox', (4, 1, 4), 0.0, 1023.0), ('trimap', (4, 1, 1024, 1024), 0.0, 1.0), ('low_res_masks', (4, 1, 256, 256), -20.38, 10.15), ('pred_alpha', (4, 1, 1024, 1024), 0.1547, 0.791), ('alpha', (4, 1, 1024, 1024), 0.0, 1.0)]
+            for key in kwargs.keys():
+                if key == 'bbox':
+                    continue
+                # turn all tensor to [B, H, W, 3]: 0~255 np.int8
+                if key == 'images':
+                    kwargs[key] = kwargs[key] * self.pixel_std + self.pixel_mean
+                    kwargs[key] = kwargs[key].permute(0, 2, 3, 1) * 255.0
+                    for i in range(kwargs['images'].shape[0]):
+                        l, u, r, d = int(kwargs['bbox'][i, 0, 0].item()), int(kwargs['bbox'][i, 0, 1].item()), int(kwargs['bbox'][i, 0, 2].item()), int(kwargs['bbox'][i, 0, 3].item())
+                        red_line = torch.tensor([[255., 0., 0.]], device=kwargs[key].device, dtype=kwargs[key].dtype)
+                        kwargs[key][i, u: d, l, :] = red_line
+                        kwargs[key][i, u: d, r, :] = red_line
+                        kwargs[key][i, u, l: r, :] = red_line
+                        kwargs[key][i, d, l: r, :] = red_line
+                elif key in {'low_res_masks', 'frozen_hq_token'}:
+                    if torch.max(kwargs[key]) <= 1:  # coconut ori alpha
+                        kwargs[key] = kwargs[key].permute(0, 2, 3, 1).repeat(1, 1, 1, 3) * 255.0
+                    else:
+                        kwargs[key] = F.interpolate(kwargs[key], size=(kwargs['images'].shape[-3], kwargs['images'].shape[-2]), mode='bilinear', align_corners=False)
+                        kwargs[key] = (kwargs[key] > self.sam_predictor.model.mask_threshold).float().permute(0, 2, 3, 1).repeat(1, 1, 1, 3) * 255.0
+                else:
+                    kwargs[key] = kwargs[key].permute(0, 2, 3, 1).repeat(1, 1, 1, 3) * 255.0
+
+                kwargs[key] = np.uint8(kwargs[key].detach().cpu().numpy())
+
+            for i in range(kwargs['images'].shape[0]):
+                save_results = []
+                for key in kwargs.keys():
+                    if key != 'bbox':
+                        save_results.append(kwargs[key][i])
+                batch_save_results.append(np.concatenate(save_results, axis=1))
+            
+            Image.fromarray(np.concatenate(batch_save_results, axis=0)).save(save_path)
+
+    def preprocess_inputs(self, batched_inputs):
+        """
+        Normalize, pad and batch the input images.
+        """
+        output = dict()
+
+        if "alpha" in batched_inputs:
+            alpha = batched_inputs["alpha"].to(self.device)
+        else:
+            alpha = None
+
+        bbox = batched_inputs["bbox"].to(self.device)
+
+        if self.training and self.coconut_self_training and sum([i == 'COCONut' for i in batched_inputs['dataset_name']]) >= 1:
+            output['coconut_ori_img'] = []
+            output['coconut_trimap'] = []
+            output['coconut_bbox'] = []
+            output['coconut_idx'] = []
+            for i, dataset_name in enumerate(batched_inputs['dataset_name']):
+                if dataset_name == 'COCONut':
+                    # generate coconut_aug_img
+                    img_np = np.uint8(batched_inputs["image"][i].permute(1, 2, 0).cpu().numpy() * 255.)
+                    strong_aug_img = self.rand_aug(Image.fromarray(img_np), cutout = False)
+                    strong_aug_img_tensor = torch.from_numpy(np.array(strong_aug_img)).to(self.device).permute(2, 0, 1)[None] / 255.
+                    blur_kernel_sigma = 1.0 + random.random()  # random from 1.0 ~ 2.0
+                    blur_filter = kf.GaussianBlur2d((101, 101), (blur_kernel_sigma, blur_kernel_sigma))
+                    blur_strong_aug_img_tensor = blur_filter(strong_aug_img_tensor)[0]
+
+                    output['coconut_ori_img'].append(batched_inputs["image"][i])
+                    batched_inputs["image"][i] = blur_strong_aug_img_tensor
+
+                    # generate coconut_trimap
+                    coconut_mask = (alpha[i] != 0).float()
+                    mask_area = torch.sum(coconut_mask)
+                    kernel_size = max(self.matting_decoder.min_kernel_size, int((mask_area ** 0.5) / 7))  # self.matting_decoder.kernel_div
+                    kernel_size = min(kernel_size, self.matting_decoder.gen_trimap.max_kernal - 1)
+                    output['coconut_trimap'].append(self.matting_decoder.gen_trimap(coconut_mask[0], kernel_size=kernel_size)[None])
+
+                    output['coconut_bbox'].append(bbox[i])
+                    output['coconut_idx'].append(i)
+
+            output['coconut_ori_img'] = torch.stack(output['coconut_ori_img']).to(self.device)
+            output['coconut_ori_img'] = (output['coconut_ori_img'] - self.pixel_mean) / self.pixel_std
+            output['coconut_trimap'] = torch.stack(output['coconut_trimap']).to(self.device)
+            output['coconut_bbox'] = torch.stack(output['coconut_bbox']).to(self.device)
+
+        images = batched_inputs["image"].to(self.device)
+        images = (images - self.pixel_mean) / self.pixel_std
+        assert images.shape[-2] == images.shape[-1] == 1024
+
+        if 'trimap' in batched_inputs.keys():
+            trimap = batched_inputs["trimap"].to(self.device)
+            assert len(torch.unique(trimap)) <= 3
+        else:
+            trimap = None
+
+        output['images'] = images
+        output['bbox'] = bbox
+        output['alpha'] = alpha
+        output['trimap'] = trimap
+
+        if 'hr_images' in batched_inputs.keys():
+            hr_images = batched_inputs["hr_images"].to(self.device)
+            hr_images = (hr_images - self.pixel_mean) / self.pixel_std
+            _, _, H, W = hr_images.shape
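+            # zero-pad H and W up to the next multiple of 16; the original content stays in the top-left corner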
+            if hr_images.shape[-1] % 16 != 0 or hr_images.shape[-2] % 16 != 0:
+                new_H = (16 - hr_images.shape[-2] % 16) + H if hr_images.shape[-2] % 16 != 0 else H
+                new_W = (16 - hr_images.shape[-1] % 16) + W if hr_images.shape[-1] % 16 != 0 else W
+                new_hr_images = torch.zeros((hr_images.shape[0], hr_images.shape[1], new_H, new_W)).to(self.device)
+                new_hr_images[:,:,:H,:W] = hr_images[:,:,:,:]
+                del hr_images
+                hr_images = new_hr_images
+            output['hr_images'] = hr_images
+            output['hr_images_ori_h_w'] = (H, W)
+
+        if 'dataset_name' in batched_inputs.keys():
+            output['dataset_name'] = batched_inputs["dataset_name"]
+
+        if self.backbone_condition:
+            if self.w_only_bbox_cond:
+                output['condition'] = output['bbox'][:, 0, :]
+            else:
+                multi_fg_float = batched_inputs["multi_fg"].to(bbox.device).float()[:, None] * 512
+                output['condition'] = torch.concat((output['bbox'][:, 0, :], multi_fg_float), dim=-1)
+        else:
+            output['condition'] = None
+
+        return output
diff --git a/modeling/semantic_enhanced_matting/__init__.py b/modeling/semantic_enhanced_matting/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..34433aa701e25ef6cb385b2de9ba7a82037822b4
--- /dev/null
+++ b/modeling/semantic_enhanced_matting/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .build_sam import (
+    build_sam,
+    build_sam_vit_h,
+    build_sam_vit_l,
+    build_sam_vit_b,
+    sam_model_registry,
+)
+from .build_sam_baseline import sam_model_registry_baseline
+from .predictor import SamPredictor
+from .automatic_mask_generator import SamAutomaticMaskGenerator
+from .mask_decoder_matting import MaskDecoderMatting
\ No newline at end of file
diff --git a/modeling/semantic_enhanced_matting/automatic_mask_generator.py b/modeling/semantic_enhanced_matting/automatic_mask_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..427ebebd831f848dfff219f695c45302228e449a
--- /dev/null
+++ b/modeling/semantic_enhanced_matting/automatic_mask_generator.py
@@ -0,0 +1,374 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+from torchvision.ops.boxes import batched_nms, box_area  # type: ignore
+
+from typing import Any, Dict, List, Optional, Tuple
+
+from .modeling import Sam
+from .predictor import SamPredictor
+from .utils.amg import (
+    MaskData,
+    area_from_rle,
+    batch_iterator,
+    batched_mask_to_box,
+    box_xyxy_to_xywh,
+    build_all_layer_point_grids,
+    calculate_stability_score,
+    coco_encode_rle,
+    generate_crop_boxes,
+    is_box_near_crop_edge,
+    mask_to_rle_pytorch,
+    remove_small_regions,
+    rle_to_mask,
+    uncrop_boxes_xyxy,
+    uncrop_masks,
+    uncrop_points,
+)
+
+
+class SamAutomaticMaskGenerator:
+    def __init__(
+        self,
+        model: Sam,
+        points_per_side: Optional[int] = 32,
+        points_per_batch: int = 64,
+        pred_iou_thresh: float = 0.88,
+        stability_score_thresh: float = 0.95,
+        stability_score_offset: float = 1.0,
+        box_nms_thresh: float = 0.7,
+        crop_n_layers: int = 0,
+        crop_nms_thresh: float = 0.7,
+        crop_overlap_ratio: float = 512 / 1500,
+        crop_n_points_downscale_factor: int = 1,
+        point_grids: Optional[List[np.ndarray]] = None,
+        min_mask_region_area: int = 0,
+        output_mode: str = "binary_mask",
+    ) -> None:
+        """
+        Using a SAM model, generates masks for the entire image.
+        Generates a grid of point prompts over the image, then filters
+        low quality and duplicate masks. The default settings are chosen
+        for SAM with a ViT-H backbone.
+
+        Arguments:
+          model (Sam): The SAM model to use for mask prediction.
+          points_per_side (int or None): The number of points to be sampled
+            along one side of the image. The total number of points is
+            points_per_side**2. If None, 'point_grids' must provide explicit
+            point sampling.
+          points_per_batch (int): Sets the number of points run simultaneously
+            by the model. Higher numbers may be faster but use more GPU memory.
+          pred_iou_thresh (float): A filtering threshold in [0,1], using the
+            model's predicted mask quality.
+          stability_score_thresh (float): A filtering threshold in [0,1], using
+            the stability of the mask under changes to the cutoff used to binarize
+            the model's mask predictions.
+          stability_score_offset (float): The amount to shift the cutoff when
+            calculating the stability score.
+          box_nms_thresh (float): The box IoU cutoff used by non-maximal
+            suppression to filter duplicate masks.
+          crop_n_layers (int): If >0, mask prediction will be run again on
+            crops of the image. Sets the number of layers to run, where each
+            layer has 2**i_layer number of image crops.
+          crop_nms_thresh (float): The box IoU cutoff used by non-maximal
+            suppression to filter duplicate masks between different crops.
+          crop_overlap_ratio (float): Sets the degree to which crops overlap.
+            In the first crop layer, crops will overlap by this fraction of
+            the image length. Later layers with more crops scale down this overlap.
+          crop_n_points_downscale_factor (int): The number of points-per-side
+            sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
+          point_grids (list(np.ndarray) or None): A list over explicit grids
+            of points used for sampling, normalized to [0,1]. The nth grid in the
+            list is used in the nth crop layer. Exclusive with points_per_side.
+          min_mask_region_area (int): If >0, postprocessing will be applied
+            to remove disconnected regions and holes in masks with area smaller
+            than min_mask_region_area. Requires opencv.
+          output_mode (str): The form masks are returned in. Can be 'binary_mask',
+            'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
+            For large resolutions, 'binary_mask' may consume large amounts of
+            memory.
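+
+        Example (illustrative; assumes `sam` is an already-built Sam model and
+        `image` is an HWC uint8 RGB array):
+            generator = SamAutomaticMaskGenerator(sam)
+            masks = generator.generate(image)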
+        """
+
+        assert (points_per_side is None) != (
+            point_grids is None
+        ), "Exactly one of points_per_side or point_grid must be provided."
+        if points_per_side is not None:
+            self.point_grids = build_all_layer_point_grids(
+                points_per_side,
+                crop_n_layers,
+                crop_n_points_downscale_factor,
+            )
+        elif point_grids is not None:
+            self.point_grids = point_grids
+        else:
+            raise ValueError("Can't have both points_per_side and point_grid be None.")
+
+        assert output_mode in [
+            "binary_mask",
+            "uncompressed_rle",
+            "coco_rle",
+        ], f"Unknown output_mode {output_mode}."
+        if output_mode == "coco_rle":
+            from pycocotools import mask as mask_utils  # type: ignore # noqa: F401
+
+        if min_mask_region_area > 0:
+            import cv2  # type: ignore # noqa: F401
+
+        self.predictor = SamPredictor(model)
+        self.points_per_batch = points_per_batch
+        self.pred_iou_thresh = pred_iou_thresh
+        self.stability_score_thresh = stability_score_thresh
+        self.stability_score_offset = stability_score_offset
+        self.box_nms_thresh = box_nms_thresh
+        self.crop_n_layers = crop_n_layers
+        self.crop_nms_thresh = crop_nms_thresh
+        self.crop_overlap_ratio = crop_overlap_ratio
+        self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
+        self.min_mask_region_area = min_mask_region_area
+        self.output_mode = output_mode
+
+    @torch.no_grad()
+    def generate(self, image: np.ndarray, multimask_output: bool = True) -> List[Dict[str, Any]]:
+        """
+        Generates masks for the given image.
+
+        Arguments:
+          image (np.ndarray): The image to generate masks for, in HWC uint8 format.
+          multimask_output (bool): Whether the underlying predictor returns multiple
+            candidate masks per point prompt; all candidates are kept and then filtered
+            by the IoU and stability-score thresholds.
+
+        Returns:
+           list(dict(str, any)): A list over records for masks. Each record is
+             a dict containing the following keys:
+               segmentation (dict(str, any) or np.ndarray): The mask. If
+                 output_mode='binary_mask', is an array of shape HW. Otherwise,
+                 is a dictionary containing the RLE.
+               bbox (list(float)): The box around the mask, in XYWH format.
+               area (int): The area in pixels of the mask.
+               predicted_iou (float): The model's own prediction of the mask's
+                 quality. This is filtered by the pred_iou_thresh parameter.
+               point_coords (list(list(float))): The point coordinates input
+                 to the model to generate this mask.
+               stability_score (float): A measure of the mask's quality. This
+                 is filtered on using the stability_score_thresh parameter.
+               crop_box (list(float)): The crop of the image used to generate
+                 the mask, given in XYWH format.
+        """
+
+        # Generate masks
+        mask_data = self._generate_masks(image, multimask_output)
+
+        # Filter small disconnected regions and holes in masks
+        if self.min_mask_region_area > 0:
+            mask_data = self.postprocess_small_regions(
+                mask_data,
+                self.min_mask_region_area,
+                max(self.box_nms_thresh, self.crop_nms_thresh),
+            )
+
+        # Encode masks
+        if self.output_mode == "coco_rle":
+            mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
+        elif self.output_mode == "binary_mask":
+            mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
+        else:
+            mask_data["segmentations"] = mask_data["rles"]
+
+        # Write mask records
+        curr_anns = []
+        for idx in range(len(mask_data["segmentations"])):
+            ann = {
+                "segmentation": mask_data["segmentations"][idx],
+                "area": area_from_rle(mask_data["rles"][idx]),
+                "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
+                "predicted_iou": mask_data["iou_preds"][idx].item(),
+                "point_coords": [mask_data["points"][idx].tolist()],
+                "stability_score": mask_data["stability_score"][idx].item(),
+                "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
+            }
+            curr_anns.append(ann)
+
+        return curr_anns
+
+    def _generate_masks(self, image: np.ndarray, multimask_output: bool = True) -> MaskData:
+        orig_size = image.shape[:2]
+        crop_boxes, layer_idxs = generate_crop_boxes(
+            orig_size, self.crop_n_layers, self.crop_overlap_ratio
+        )
+
+        # Iterate over image crops
+        data = MaskData()
+        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
+            crop_data = self._process_crop(image, crop_box, layer_idx, orig_size, multimask_output)
+            data.cat(crop_data)
+
+        # Remove duplicate masks between crops
+        if len(crop_boxes) > 1:
+            # Prefer masks from smaller crops
+            scores = 1 / box_area(data["crop_boxes"])
+            scores = scores.to(data["boxes"].device)
+            keep_by_nms = batched_nms(
+                data["boxes"].float(),
+                scores,
+                torch.zeros_like(data["boxes"][:, 0]),  # categories
+                iou_threshold=self.crop_nms_thresh,
+            )
+            data.filter(keep_by_nms)
+
+        data.to_numpy()
+        return data
+
+    def _process_crop(
+        self,
+        image: np.ndarray,
+        crop_box: List[int],
+        crop_layer_idx: int,
+        orig_size: Tuple[int, ...],
+        multimask_output: bool = True,
+    ) -> MaskData:
+        # Crop the image and calculate embeddings
+        x0, y0, x1, y1 = crop_box
+        cropped_im = image[y0:y1, x0:x1, :]
+        cropped_im_size = cropped_im.shape[:2]
+        self.predictor.set_image(cropped_im)
+
+        # Get points for this crop
+        points_scale = np.array(cropped_im_size)[None, ::-1]
+        points_for_image = self.point_grids[crop_layer_idx] * points_scale
+
+        # Generate masks for this crop in batches
+        data = MaskData()
+        for (points,) in batch_iterator(self.points_per_batch, points_for_image):
+            batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size, multimask_output)
+            data.cat(batch_data)
+            del batch_data
+        self.predictor.reset_image()
+
+        # Remove duplicates within this crop.
+        keep_by_nms = batched_nms(
+            data["boxes"].float(),
+            data["iou_preds"],
+            torch.zeros_like(data["boxes"][:, 0]),  # categories
+            iou_threshold=self.box_nms_thresh,
+        )
+        data.filter(keep_by_nms)
+
+        # Return to the original image frame
+        data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
+        data["points"] = uncrop_points(data["points"], crop_box)
+        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
+
+        return data
+
+    def _process_batch(
+        self,
+        points: np.ndarray,
+        im_size: Tuple[int, ...],
+        crop_box: List[int],
+        orig_size: Tuple[int, ...],
+        multimask_output: bool = True,
+    ) -> MaskData:
+        orig_h, orig_w = orig_size
+
+        # Run model on this batch
+        transformed_points = self.predictor.transform.apply_coords(points, im_size)
+        in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
+        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
+        masks, iou_preds, _ = self.predictor.predict_torch(
+            in_points[:, None, :],
+            in_labels[:, None],
+            multimask_output=multimask_output,
+            return_logits=True,
+        )
+
+        # Serialize predictions and store in MaskData
+        data = MaskData(
+            masks=masks.flatten(0, 1),
+            iou_preds=iou_preds.flatten(0, 1),
+            points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
+        )
+        del masks
+
+        # Filter by predicted IoU
+        if self.pred_iou_thresh > 0.0:
+            keep_mask = data["iou_preds"] > self.pred_iou_thresh
+            data.filter(keep_mask)
+
+        # Calculate stability score
+        data["stability_score"] = calculate_stability_score(
+            data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
+        )
+        if self.stability_score_thresh > 0.0:
+            keep_mask = data["stability_score"] >= self.stability_score_thresh
+            data.filter(keep_mask)
+
+        # Threshold masks and calculate boxes
+        data["masks"] = data["masks"] > self.predictor.model.mask_threshold
+        data["boxes"] = batched_mask_to_box(data["masks"])
+
+        # Filter boxes that touch crop boundaries
+        keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
+        if not torch.all(keep_mask):
+            data.filter(keep_mask)
+
+        # Compress to RLE
+        data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
+        data["rles"] = mask_to_rle_pytorch(data["masks"])
+        del data["masks"]
+
+        return data
+
+    @staticmethod
+    def postprocess_small_regions(
+        mask_data: MaskData, min_area: int, nms_thresh: float
+    ) -> MaskData:
+        """
+        Removes small disconnected regions and holes in masks, then reruns
+        box NMS to remove any new duplicates.
+
+        Edits mask_data in place.
+
+        Requires open-cv as a dependency.
+        """
+        if len(mask_data["rles"]) == 0:
+            return mask_data
+
+        # Filter small disconnected regions and holes
+        new_masks = []
+        scores = []
+        for rle in mask_data["rles"]:
+            mask = rle_to_mask(rle)
+
+            mask, changed = remove_small_regions(mask, min_area, mode="holes")
+            unchanged = not changed
+            mask, changed = remove_small_regions(mask, min_area, mode="islands")
+            unchanged = unchanged and not changed
+
+            new_masks.append(torch.as_tensor(mask).unsqueeze(0))
+            # Give score=0 to changed masks and score=1 to unchanged masks
+            # so NMS will prefer ones that didn't need postprocessing
+            scores.append(float(unchanged))
+
+        # Recalculate boxes and remove any new duplicates
+        masks = torch.cat(new_masks, dim=0)
+        boxes = batched_mask_to_box(masks)
+        keep_by_nms = batched_nms(
+            boxes.float(),
+            torch.as_tensor(scores),
+            torch.zeros_like(boxes[:, 0]),  # categories
+            iou_threshold=nms_thresh,
+        )
+
+        # Only recalculate RLEs for masks that have changed
+        for i_mask in keep_by_nms:
+            if scores[i_mask] == 0.0:
+                mask_torch = masks[i_mask].unsqueeze(0)
+                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
+                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
+        mask_data.filter(keep_by_nms)
+
+        return mask_data
diff --git a/modeling/semantic_enhanced_matting/build_sam.py b/modeling/semantic_enhanced_matting/build_sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..242944a15470ae6975c0e893e7e498461844db58
--- /dev/null
+++ b/modeling/semantic_enhanced_matting/build_sam.py
@@ -0,0 +1,234 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from functools import partial
+
+from .modeling import ImageEncoderViT, MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer, TinyViT
+from .modeling.mask_decoder_hq_matting import MaskDecoderHQMatting
+
+
+def build_sam_vit_h(checkpoint=None):
+    return _build_sam(
+        encoder_embed_dim=1280,
+        encoder_depth=32,
+        encoder_num_heads=16,
+        encoder_global_attn_indexes=[7, 15, 23, 31],
+        checkpoint=checkpoint,
+    )
+
+
+build_sam = build_sam_vit_h
+
+
+def build_sam_vit_l(checkpoint=None, matting_token=0, wo_hq=False, frozen_decoder=False, mask_matting_res_add=True):
+    return _build_sam(
+        encoder_embed_dim=1024,
+        encoder_depth=24,
+        encoder_num_heads=16,
+        encoder_global_attn_indexes=[5, 11, 17, 23],
+        checkpoint=checkpoint,
+        matting_token=matting_token,
+        wo_hq=wo_hq,
+        frozen_decoder=frozen_decoder,
+        mask_matting_res_add=mask_matting_res_add
+    )
+
+
+def build_sam_vit_b(checkpoint=None, matting_token=0, wo_hq=False, frozen_decoder=False):
+    return _build_sam(
+        encoder_embed_dim=768,
+        encoder_depth=12,
+        encoder_num_heads=12,
+        encoder_global_attn_indexes=[2, 5, 8, 11],
+        checkpoint=checkpoint,
+        matting_token=matting_token,
+        wo_hq=wo_hq,
+        frozen_decoder=frozen_decoder
+    )
+
+
+def build_sam_vit_t(checkpoint=None):
+    prompt_embed_dim = 256
+    image_size = 1024
+    vit_patch_size = 16
+    image_embedding_size = image_size // vit_patch_size
+    mobile_sam = Sam(
+            image_encoder=TinyViT(img_size=1024, in_chans=3, num_classes=1000,
+                embed_dims=[64, 128, 160, 320],
+                depths=[2, 2, 6, 2],
+                num_heads=[2, 4, 5, 10],
+                window_sizes=[7, 7, 14, 7],
+                mlp_ratio=4.,
+                drop_rate=0.,
+                drop_path_rate=0.0,
+                use_checkpoint=False,
+                mbconv_expand_ratio=4.0,
+                local_conv_size=3,
+                layer_lr_decay=0.8
+            ),
+            prompt_encoder=PromptEncoder(
+            embed_dim=prompt_embed_dim,
+            image_embedding_size=(image_embedding_size, image_embedding_size),
+            input_image_size=(image_size, image_size),
+            mask_in_chans=16,
+            ),
+            mask_decoder=MaskDecoderHQ(
+                    num_multimask_outputs=3,
+                    transformer=TwoWayTransformer(
+                    depth=2,
+                    embedding_dim=prompt_embed_dim,
+                    mlp_dim=2048,
+                    num_heads=8,
+                ),
+                transformer_dim=prompt_embed_dim,
+                iou_head_depth=3,
+                iou_head_hidden_dim=256,
+                vit_dim=160,
+            ),
+            pixel_mean=[123.675, 116.28, 103.53],
+            pixel_std=[58.395, 57.12, 57.375],
+        )
+
+    mobile_sam.eval()
+    if checkpoint is not None:
+        with open(checkpoint, "rb") as f:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+            state_dict = torch.load(f, map_location=device)
+        info = mobile_sam.load_state_dict(state_dict, strict=False)
+        print(info)
+    for n, p in mobile_sam.named_parameters():
+        if 'hf_token' not in n and 'hf_mlp' not in n and 'compress_vit_feat' not in n and 'embedding_encoder' not in n and 'embedding_maskfeature' not in n:
+            p.requires_grad = False
+    return mobile_sam
+
+sam_model_registry = {
+    "default": build_sam_vit_h,
+    "vit_h": build_sam_vit_h,
+    "vit_l": build_sam_vit_l,
+    "vit_b": build_sam_vit_b,
+    "vit_tiny": build_sam_vit_t
+}
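+
+# usage sketch (illustrative; the checkpoint path is a placeholder):
+#   sam = sam_model_registry['vit_l'](checkpoint='path/to/sam_vit_l.pth', matting_token=1)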
+
+def sam_model_registry_def(model_type, checkpoint, matting_token = 0, wo_hq = False, frozen_decoder = False, mask_matting_res_add=True):
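+    # note: the matting-related kwargs are only accepted by the vit_l / vit_b builders above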
+    assert model_type in {"default", "vit_h", "vit_l", "vit_b", "vit_tiny"}
+    return sam_model_registry[model_type](checkpoint=checkpoint, matting_token=matting_token, wo_hq=wo_hq, frozen_decoder=frozen_decoder, mask_matting_res_add=mask_matting_res_add)
+
+def _build_sam(
+    encoder_embed_dim,
+    encoder_depth,
+    encoder_num_heads,
+    encoder_global_attn_indexes,
+    checkpoint=None,
+    matting_token=0,
+    wo_hq=False,
+    frozen_decoder=False,
+    mask_matting_res_add=True
+):
+    # disabling mask_matting_res_add only works with wo_hq and matting tokens enabled
+    if not mask_matting_res_add:
+        assert matting_token > 0
+
+    prompt_embed_dim = 256
+    image_size = 1024
+    vit_patch_size = 16
+    image_embedding_size = image_size // vit_patch_size
+
+    if matting_token > 0:
+        mask_decoder = MaskDecoderHQMatting(
+            num_multimask_outputs=3,
+            transformer=TwoWayTransformer(
+                depth=2,
+                embedding_dim=prompt_embed_dim,
+                mlp_dim=2048,
+                num_heads=8,
+            ),
+            transformer_dim=prompt_embed_dim,
+            iou_head_depth=3,
+            iou_head_hidden_dim=256,
+            vit_dim=encoder_embed_dim,
+            wo_hq=wo_hq,
+            matting_token_num=matting_token,
+            mask_matting_res_add=mask_matting_res_add
+        )
+    else:
+        mask_decoder = MaskDecoderHQ(
+            num_multimask_outputs=3,
+            transformer=TwoWayTransformer(
+                depth=2,
+                embedding_dim=prompt_embed_dim,
+                mlp_dim=2048,
+                num_heads=8,
+            ),
+            transformer_dim=prompt_embed_dim,
+            iou_head_depth=3,
+            iou_head_hidden_dim=256,
+            vit_dim=encoder_embed_dim,
+            wo_hq=wo_hq
+        )
+    sam = Sam(
+        image_encoder=ImageEncoderViT(
+            depth=encoder_depth,
+            embed_dim=encoder_embed_dim,
+            img_size=image_size,
+            mlp_ratio=4,
+            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
+            num_heads=encoder_num_heads,
+            patch_size=vit_patch_size,
+            qkv_bias=True,
+            use_rel_pos=True,
+            global_attn_indexes=encoder_global_attn_indexes,
+            window_size=14,
+            out_chans=prompt_embed_dim,
+        ),
+        prompt_encoder=PromptEncoder(
+            embed_dim=prompt_embed_dim,
+            image_embedding_size=(image_embedding_size, image_embedding_size),
+            input_image_size=(image_size, image_size),
+            mask_in_chans=16,
+        ),
+        mask_decoder=mask_decoder,
+        pixel_mean=[123.675, 116.28, 103.53],
+        pixel_std=[58.395, 57.12, 57.375],
+    )
+    sam.eval()
+    if checkpoint is not None:
+        with open(checkpoint, "rb") as f:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+            state_dict = torch.load(f, map_location=device)
+        info = sam.load_state_dict(state_dict, strict=False)
+        print(info)
+
+    if frozen_decoder and checkpoint is not None:
+        sam.frozen_mask_decoder = MaskDecoderHQ(
+            num_multimask_outputs=3,
+            transformer=TwoWayTransformer(
+                depth=2,
+                embedding_dim=prompt_embed_dim,
+                mlp_dim=2048,
+                num_heads=8,
+            ),
+            transformer_dim=prompt_embed_dim,
+            iou_head_depth=3,
+            iou_head_hidden_dim=256,
+            vit_dim=encoder_embed_dim,
+            wo_hq=wo_hq
+        )
+        sam.frozen_mask_decoder.eval()
+        info = sam.frozen_mask_decoder.load_state_dict({key.split('mask_decoder.')[1]: val for key, val in state_dict.items() if 'mask_decoder.' in key}, strict=False)
+        print('load frozen_mask_decoder', info)
+        # for n, p in sam.frozen_mask_decoder.named_parameters():
+        #     p = state_dict['mask_decoder.' + n]
+
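+    # freeze every SAM parameter except the newly added matting-token branches ('matting' in the name)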
+    for n, p in sam.named_parameters():
+        # if 'hf_token' not in n and 'hf_mlp' not in n and 'compress_vit_feat' not in n and 'embedding_encoder' not in n and 'embedding_maskfeature' not in n:
+        #     p.requires_grad = False
+        if 'matting' not in n:
+            p.requires_grad = False
+        # p.requires_grad = False
+
+    return sam
diff --git a/modeling/semantic_enhanced_matting/build_sam_baseline.py b/modeling/semantic_enhanced_matting/build_sam_baseline.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1d34d702821ef49dd451daa20bb3897e76357f2
--- /dev/null
+++ b/modeling/semantic_enhanced_matting/build_sam_baseline.py
@@ -0,0 +1,156 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from functools import partial
+
+from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer, TinyViT
+
+
+def build_sam_vit_h(checkpoint=None):
+    return _build_sam(
+        encoder_embed_dim=1280,
+        encoder_depth=32,
+        encoder_num_heads=16,
+        encoder_global_attn_indexes=[7, 15, 23, 31],
+        checkpoint=checkpoint,
+    )
+
+
+build_sam = build_sam_vit_h
+
+
+def build_sam_vit_l(checkpoint=None):
+    return _build_sam(
+        encoder_embed_dim=1024,
+        encoder_depth=24,
+        encoder_num_heads=16,
+        encoder_global_attn_indexes=[5, 11, 17, 23],
+        checkpoint=checkpoint,
+    )
+
+
+def build_sam_vit_b(checkpoint=None):
+    return _build_sam(
+        encoder_embed_dim=768,
+        encoder_depth=12,
+        encoder_num_heads=12,
+        encoder_global_attn_indexes=[2, 5, 8, 11],
+        checkpoint=checkpoint,
+    )
+
+
+def build_sam_vit_t(checkpoint=None):
+    prompt_embed_dim = 256
+    image_size = 1024
+    vit_patch_size = 16
+    image_embedding_size = image_size // vit_patch_size
+    mobile_sam = Sam(
+            image_encoder=TinyViT(img_size=1024, in_chans=3, num_classes=1000,
+                embed_dims=[64, 128, 160, 320],
+                depths=[2, 2, 6, 2],
+                num_heads=[2, 4, 5, 10],
+                window_sizes=[7, 7, 14, 7],
+                mlp_ratio=4.,
+                drop_rate=0.,
+                drop_path_rate=0.0,
+                use_checkpoint=False,
+                mbconv_expand_ratio=4.0,
+                local_conv_size=3,
+                layer_lr_decay=0.8
+            ),
+            prompt_encoder=PromptEncoder(
+            embed_dim=prompt_embed_dim,
+            image_embedding_size=(image_embedding_size, image_embedding_size),
+            input_image_size=(image_size, image_size),
+            mask_in_chans=16,
+            ),
+            mask_decoder=MaskDecoder(
+                    num_multimask_outputs=3,
+                    transformer=TwoWayTransformer(
+                    depth=2,
+                    embedding_dim=prompt_embed_dim,
+                    mlp_dim=2048,
+                    num_heads=8,
+                ),
+                transformer_dim=prompt_embed_dim,
+                iou_head_depth=3,
+                iou_head_hidden_dim=256,
+            ),
+            pixel_mean=[123.675, 116.28, 103.53],
+            pixel_std=[58.395, 57.12, 57.375],
+        )
+
+    mobile_sam.eval()
+    if checkpoint is not None:
+        with open(checkpoint, "rb") as f:
+            state_dict = torch.load(f)
+        mobile_sam.load_state_dict(state_dict)
+    return mobile_sam
+
+sam_model_registry_baseline = {
+    "default": build_sam_vit_h,
+    "vit_h": build_sam_vit_h,
+    "vit_l": build_sam_vit_l,
+    "vit_b": build_sam_vit_b,
+    "vit_tiny": build_sam_vit_t
+}
+
+
+def _build_sam(
+    encoder_embed_dim,
+    encoder_depth,
+    encoder_num_heads,
+    encoder_global_attn_indexes,
+    checkpoint=None,
+):
+    prompt_embed_dim = 256
+    image_size = 1024
+    vit_patch_size = 16
+    image_embedding_size = image_size // vit_patch_size
+    sam = Sam(
+        image_encoder=ImageEncoderViT(
+            depth=encoder_depth,
+            embed_dim=encoder_embed_dim,
+            img_size=image_size,
+            mlp_ratio=4,
+            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
+            num_heads=encoder_num_heads,
+            patch_size=vit_patch_size,
+            qkv_bias=True,
+            use_rel_pos=True,
+            global_attn_indexes=encoder_global_attn_indexes,
+            window_size=14,
+            out_chans=prompt_embed_dim,
+        ),
+        prompt_encoder=PromptEncoder(
+            embed_dim=prompt_embed_dim,
+            image_embedding_size=(image_embedding_size, image_embedding_size),
+            input_image_size=(image_size, image_size),
+            mask_in_chans=16,
+        ),
+        mask_decoder=MaskDecoder(
+            num_multimask_outputs=3,
+            transformer=TwoWayTransformer(
+                depth=2,
+                embedding_dim=prompt_embed_dim,
+                mlp_dim=2048,
+                num_heads=8,
+            ),
+            transformer_dim=prompt_embed_dim,
+            iou_head_depth=3,
+            iou_head_hidden_dim=256,
+        ),
+        pixel_mean=[123.675, 116.28, 103.53],
+        pixel_std=[58.395, 57.12, 57.375],
+    )
+    sam.eval()
+    if checkpoint is not None:
+        with open(checkpoint, "rb") as f:
+            state_dict = torch.load(f)
+        sam.load_state_dict(state_dict)
+    return sam
\ No newline at end of file
diff --git a/modeling/semantic_enhanced_matting/condition_conv.py b/modeling/semantic_enhanced_matting/condition_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..c03b8e48ffa98b4cea4b57f6a95492c9cd6b9c33
--- /dev/null
+++ b/modeling/semantic_enhanced_matting/condition_conv.py
@@ -0,0 +1,504 @@
+import math
+import torch
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+from detectron2.layers import Conv2d
+import fvcore.nn.weight_init as weight_init
+from typing import Any, Optional, Tuple, Type
+
+from modeling.semantic_enhanced_matting.modeling.image_encoder import Attention
+from modeling.semantic_enhanced_matting.modeling.transformer import Attention as DownAttention
+from modeling.semantic_enhanced_matting.feature_fusion import PositionEmbeddingRandom as ImagePositionEmbedding
+from modeling.semantic_enhanced_matting.modeling.common import MLPBlock
+
+class LayerNorm2d(nn.Module):
+    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(num_channels))
+        self.bias = nn.Parameter(torch.zeros(num_channels))
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
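+        # LayerNorm over the channel dimension (dim=1) of an NCHW tensor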
+        u = x.mean(1, keepdim=True)
+        s = (x - u).pow(2).mean(1, keepdim=True)
+        x = (x - u) / torch.sqrt(s + self.eps)
+        x = self.weight[:, None, None] * x + self.bias[:, None, None]
+        return x
+    
+
+class ConditionConv(nn.Module):
+    """
+    The standard bottleneck residual block without the last activation layer.
+    It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
+    """
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        bottleneck_channels,
+        norm=LayerNorm2d,
+        act_layer=nn.GELU,
+        conv_kernels=3,
+        conv_paddings=1,
+        condtition_channels = 1024
+    ):
+        """
+        Args:
+            in_channels (int): Number of input channels.
+            out_channels (int): Number of output channels.
+            bottleneck_channels (int): number of output channels for the 3x3
+                "bottleneck" conv layers.
+            norm (callable): normalization layer constructor applied after each conv.
+            act_layer (callable): activation for all conv layers.
+            conv_kernels (int): kernel size of the middle "bottleneck" conv layer.
+            conv_paddings (int): padding of the middle "bottleneck" conv layer.
+            condtition_channels (int): dimension of the incoming condition embedding.
+        """
+        super().__init__()
+
+        self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
+        self.norm1 = norm(bottleneck_channels)
+        self.act1 = act_layer()
+
+        self.conv2 = Conv2d(
+            bottleneck_channels,
+            bottleneck_channels,
+            conv_kernels,
+            padding=conv_paddings,
+            bias=False,
+        )
+        self.norm2 = norm(bottleneck_channels)
+        self.act2 = act_layer()
+
+        self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
+        self.norm3 = norm(out_channels)
+
+        self.init_weight()
+
+        self.condition_embedding = nn.Sequential(
+            act_layer(),
+            nn.Linear(condtition_channels, bottleneck_channels, bias=True)
+        )
+
+    def init_weight(self):
+        for layer in [self.conv1, self.conv2, self.conv3]:
+            weight_init.c2_msra_fill(layer)
+        for layer in [self.norm1, self.norm2]:
+            layer.weight.data.fill_(1.0)
+            layer.bias.data.zero_()
+        # zero init last norm layer.
+        self.norm3.weight.data.zero_()
+        self.norm3.bias.data.zero_()
+
+    # def embed_bbox_and_instance(self, bbox, instance):
+    #     assert isinstance(instance, bool)
+
+    def forward(self, x, condition):
+        # [B, 64, 64, 1024]
+        out = x.permute(0, 3, 1, 2)
+
+        out = self.act1(self.norm1(self.conv1(out)))
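+        # inject the condition embedding as a per-channel bias on the bottleneck features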
+        out = self.conv2(out) + self.condition_embedding(condition)[:, :, None, None]
+        out = self.act2(self.norm2(out))
+        out = self.norm3(self.conv3(out))
+
+        out = x + out.permute(0, 2, 3, 1)
+        return out
+
+
+class ConditionAdd(nn.Module):
+    def __init__(
+        self,
+        act_layer=nn.GELU,
+        condtition_channels = 1024
+    ):
+        super().__init__()
+
+        self.condition_embedding = nn.Sequential(
+            act_layer(),
+            nn.Linear(condtition_channels, condtition_channels, bias=True)
+        )
+
+    def forward(self, x, condition):
+        # [B, 64, 64, 1024]
+        condition = self.condition_embedding(condition)[:, None, None, :]
+        return x + condition
+
+class ConditionEmbedding(nn.Module):
+    def __init__(
+        self,
+        condition_num = 5,
+        pos_embedding_dim = 128,
+        embedding_scale = 1.0,
+        embedding_max_period = 10000,
+        embedding_flip_sin_to_cos = True,
+        embedding_downscale_freq_shift = 1.0,
+        time_embed_dim = 1024,
+        split_embed = False
+    ):
+        super().__init__()
+        self.condition_num = condition_num
+        self.pos_embedding_dim = pos_embedding_dim
+        self.embedding_scale = embedding_scale
+        self.embedding_max_period = embedding_max_period
+        self.embedding_flip_sin_to_cos = embedding_flip_sin_to_cos
+        self.embedding_downscale_freq_shift = embedding_downscale_freq_shift
+        self.split_embed = split_embed
+
+        if self.split_embed:
+            self.linear_1 = nn.Linear(pos_embedding_dim, time_embed_dim, True)
+        else:
+            self.linear_1 = nn.Linear(condition_num * pos_embedding_dim, time_embed_dim, True)
+        self.act = nn.GELU()
+        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, True)
+
+    def proj_embedding(self, condition):
+        sample = self.linear_1(condition)
+        sample = self.act(sample)
+        sample = self.linear_2(sample)
+        return sample
+    
+    def position_embedding(self, condition):
+        # [B, 5] --> [B, 5, 128] --> [B, 5 * 128]
+
+        assert condition.shape[-1] == self.condition_num
+
+        half_dim = self.pos_embedding_dim // 2
+        exponent = -math.log(self.embedding_max_period) * torch.arange(
+            start=0, end=half_dim, dtype=torch.float32, device=condition.device
+        )
+        exponent = exponent / (half_dim - self.embedding_downscale_freq_shift)
+
+        emb = torch.exp(exponent)
+        emb = condition[:, :, None].float() * emb[None, None, :]  # [B, 5, 1] * [1, 1, 64] --> [B, 5, 64]
+
+        # scale embeddings
+        emb = self.embedding_scale * emb
+
+        # concat sine and cosine embeddings
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)  # [B, 5, 64] --> [B, 5, 128]
+
+        # flip sine and cosine embeddings
+        if self.embedding_flip_sin_to_cos:
+            emb = torch.cat([emb[:, :, half_dim:], emb[:, :, :half_dim]], dim=-1)
+
+        # zero pad
+        # if self.pos_embedding_dim % 2 == 1:
+        #     emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+        if self.split_embed:
+            emb = emb.reshape(-1, emb.shape[-1])
+        else:
+            emb = emb.reshape(emb.shape[0], -1)
+
+        return emb
+
+    def forward(self, condition):
+        condition = self.position_embedding(condition)
+        condition = self.proj_embedding(condition)
+        return condition.float()
+
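+# Shape walk-through for ConditionEmbedding (illustrative sketch, not extra API):
+# with the default condition_num=5 and pos_embedding_dim=128, a [B, 5] condition
+# tensor becomes [B, 5, 128] sinusoidal features, is flattened to [B, 640], and
+# is projected to [B, time_embed_dim], e.g.
+#
+#   embed = ConditionEmbedding(condition_num=5, pos_embedding_dim=128, time_embed_dim=1024)
+#   out = embed(torch.tensor([[100., 200., 300., 400., 512.]]))  # -> [1, 1024]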
+
+
+class PositionEmbeddingRandom(nn.Module):
+    """
+    Positional encoding using random spatial frequencies.
+    """
+
+    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
+        super().__init__()
+        if scale is None or scale <= 0.0:
+            scale = 1.0
+        
+        self.positional_encoding_gaussian_matrix = nn.Parameter(scale * torch.randn((2, num_pos_feats // 2)))
+        # self.register_buffer(
+        #     "positional_encoding_gaussian_matrix",
+        #     scale * torch.randn((2, num_pos_feats)),
+        # )
+        point_embeddings = [nn.Embedding(1, num_pos_feats) for i in range(2)]
+        self.point_embeddings = nn.ModuleList(point_embeddings)
+
+    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
+        """Positionally encode points that are normalized to [0,1]."""
+        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+        coords = 2 * coords - 1
+        coords = coords @ self.positional_encoding_gaussian_matrix
+        coords = 2 * np.pi * coords
+        # outputs d_1 x ... x d_n x C shape
+        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
+
+    def forward(
+        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
+    ) -> torch.Tensor:
+        """Positionally encode points that are not normalized to [0,1]."""
+        coords = coords_input.clone()
+        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
+        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
+        coords =  self._pe_encoding(coords.to(torch.float))  # B x N x C
+
+        coords[:, 0, :] += self.point_embeddings[0].weight
+        coords[:, 1, :] += self.point_embeddings[1].weight
+
+        return coords
+
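+# Usage sketch for PositionEmbeddingRandom above (shapes are assumptions for
+# orientation only): given two box corners in pixel coordinates, the random
+# Fourier features plus the learned per-corner embeddings give one token per
+# corner, e.g.
+#
+#   pe = PositionEmbeddingRandom(num_pos_feats=1024)
+#   corners = torch.tensor([[[100., 200.], [300., 400.]]])   # [B, 2, 2]
+#   tokens = pe(corners, image_size=(1024, 1024))            # [B, 2, 1024]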
+
+class CrossSelfAttn(nn.Module):
+    """
+    Positional encoding using random spatial frequencies.
+    """
+
+    def __init__(self, embedding_dim=1024, num_heads=4, downsample_rate=4) -> None:
+        super().__init__()
+
+        self.cross_attn = DownAttention(embedding_dim=embedding_dim, num_heads=num_heads, downsample_rate=downsample_rate)
+        self.norm1 = nn.LayerNorm(embedding_dim)
+        self.mlp = MLPBlock(embedding_dim, mlp_dim=512)
+        self.norm2 = nn.LayerNorm(embedding_dim)
+        self.self_attn = DownAttention(embedding_dim=embedding_dim, num_heads=num_heads, downsample_rate=downsample_rate)
+        self.norm3 = nn.LayerNorm(embedding_dim)
+
+    def forward(self, block_feat, bbox_token, feat_pe=None, bbox_pe=None):
+        B, H, W, C = block_feat.shape
+        block_feat = block_feat.reshape(B, H * W, C)
+        # positional terms are optional: BBoxEmbedInteract already adds its PE to block_feat
+        feat_pe = 0 if feat_pe is None else feat_pe
+        bbox_pe = 0 if bbox_pe is None else bbox_pe
+
+        block_feat = block_feat + self.cross_attn(q=block_feat + feat_pe, k=bbox_token + bbox_pe, v=bbox_token)
+        block_feat = self.norm1(block_feat)
+
+        block_feat = block_feat + self.mlp(block_feat)
+        block_feat = self.norm2(block_feat)
+
+        concat_token = torch.concat((block_feat + feat_pe, bbox_token + bbox_pe), dim=1) 
+        block_feat = block_feat + self.self_attn(q=concat_token, k=concat_token, v=concat_token)[:, :-bbox_token.shape[1]]
+        block_feat = self.norm3(block_feat)
+        output = block_feat.reshape(B, H, W, C)
+
+        return output
+
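+# Shape sketch for CrossSelfAttn (illustrative): block_feat [B, 64, 64, C] is
+# flattened to [B, 4096, C], cross-attends to the bbox tokens [B, N, C], passes
+# through an MLP, is self-attended jointly with the bbox tokens (the bbox part
+# of the output is discarded), and is reshaped back to [B, 64, 64, C].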
+
+class BBoxEmbedInteract(nn.Module):
+    def __init__(
+        self,
+        embed_type = 'fourier',
+        interact_type = 'attn',
+        layer_num = 3
+    ):
+        super().__init__()
+        assert embed_type in {'fourier', 'position', 'conv'}
+        assert interact_type in {'add', 'attn', 'cross-self-attn'}
+        self.embed_type = embed_type
+        self.interact_type = interact_type
+        self.layer_num = layer_num
+
+        if self.embed_type == 'fourier' and self.interact_type == 'add':
+            self.embed_layer = ConditionEmbedding(condition_num = 4, pos_embedding_dim = 256)
+        elif self.embed_type == 'fourier':
+            self.embed_layer = ConditionEmbedding(condition_num = 4, pos_embedding_dim = 256, split_embed = True)
+        elif self.embed_type == 'conv':
+            mask_in_chans = 16
+            activation = nn.GELU
+            self.embed_layer = nn.Sequential(
+                nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
+                LayerNorm2d(mask_in_chans // 4),
+                activation(),
+                nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
+                LayerNorm2d(mask_in_chans),
+                activation(),
+                nn.Conv2d(mask_in_chans, 1024, kernel_size=1),
+            )
+        else:
+            if self.interact_type == 'add':
+                self.embed_layer = PositionEmbeddingRandom(num_pos_feats = 512)
+            else:
+                self.embed_layer = PositionEmbeddingRandom(num_pos_feats = 1024)
+
+        self.interact_layer = nn.ModuleList()
+        for _ in range(self.layer_num):
+            if self.interact_type == 'attn':
+                self.interact_layer.append(Attention(dim = 1024))
+            elif self.interact_type == 'add' and self.embed_type != 'conv':
+                self.interact_layer.append(nn.Sequential(
+                    nn.GELU(),
+                    nn.Linear(1024, 1024, bias=True)
+                ))
+            elif self.interact_type == 'cross-self-attn':
+                self.interact_layer.append(CrossSelfAttn(embedding_dim=1024, num_heads=4, downsample_rate=4))
+
+        # build the shared position embedding once; it does not depend on the layer index
+        self.position_layer = ImagePositionEmbedding(num_pos_feats=1024 // 2)
+
+    def forward(self, block_feat, bbox, layer_index):
+        # input: [B, 1, 4], [B, 64, 64, 1024]
+        if layer_index == self.layer_num:
+            return block_feat
+        interact_layer = self.interact_layer[layer_index]
+
+        bbox = bbox + 0.5  # Shift to center of pixel
+        if self.embed_type == 'fourier' and self.interact_type == 'add':
+            embedding = self.embed_layer(bbox[:, 0])  # [B, 1, 4] --> reshape [B, 4] --> [B, 1024 * 1] --> reshape [B, 1, 1024]
+            embedding = embedding.reshape(embedding.shape[0], 1, -1)
+        elif self.embed_type == 'fourier':
+            embedding = self.embed_layer(bbox[:, 0])  # [B, 1, 4] --> reshape [B, 4] --> [B, 1024 * 4] --> reshape [B, 4, 1024]
+            embedding = embedding.reshape(-1, 4, embedding.shape[-1])
+        elif self.embed_type == 'conv':
+            # concat mask and img as condition
+            bbox_mask = torch.zeros(size=(block_feat.shape[0], 1, 256, 256), device=block_feat.device, dtype=block_feat.dtype)  # [B, 1, 512, 512]
+            for i in range(bbox.shape[0]):
+                l, u, r, d = bbox[i, 0, :] / 4
+                bbox_mask[i, :, int(u + 0.5): int(d + 0.5), int(l + 0.5): int(r + 0.5)] = 1.0  # int(x + 0.5) = round(x)
+            embedding = self.embed_layer(bbox_mask)  # [B, 1024, 64, 64]
+        elif self.embed_type == 'position':
+            embedding = self.embed_layer(bbox.reshape(-1, 2, 2), (1024, 1024))  # [B, 1, 4] --> reshape [B, 2, 2] --> [B, 2, 1024/512]
+            if self.interact_type == 'add':
+                embedding = embedding.reshape(embedding.shape[0], 1, -1)
+
+        # add position embedding to block_feat
+        pe = self.position_layer(size=(64, 64)).reshape(1, 64, 64, 1024)
+        block_feat = block_feat + pe
+
+        if self.interact_type == 'attn':
+            add_token_num = embedding.shape[1]
+            B, H, W, C = block_feat.shape
+            block_feat = block_feat.reshape(B, H * W, C)
+            concat_token = torch.concat((block_feat, embedding), dim=1)  # [B, 64 * 64 + 2, 1024]
+            output_token = interact_layer.forward_token(concat_token)[:, :-add_token_num]
+            output = output_token.reshape(B, H, W, C)
+        elif self.embed_type == 'conv':
+            output = block_feat + embedding.permute(0, 2, 3, 1)
+        elif self.interact_type == 'add':
+            output = interact_layer(embedding[:, None]) + block_feat
+        elif self.interact_type == 'cross-self-attn':
+            output = interact_layer(block_feat, embedding)
+
+        return output
+        
+
+# reuse the position_point_embedding in prompt_encoder
+class BBoxInteract(nn.Module):
+    def __init__(
+        self,
+        position_point_embedding,
+        point_weight,
+        layer_num = 3,
+    ):
+        super().__init__()
+
+        self.position_point_embedding = position_point_embedding
+        self.point_weight = point_weight
+        for _, p in self.named_parameters():
+            p.requires_grad = False
+
+        self.layer_num = layer_num
+        self.input_image_size = (1024, 1024)
+
+        self.interact_layer = nn.ModuleList()
+        for _ in range(self.layer_num):
+            self.interact_layer.append(CrossSelfAttn(embedding_dim=1024, num_heads=4, downsample_rate=4))
+    
+    @torch.no_grad()
+    def get_bbox_token(self, boxes):
+        boxes = boxes + 0.5  # Shift to center of pixel
+        coords = boxes.reshape(-1, 2, 2)
+        corner_embedding = self.position_point_embedding.forward_with_coords(coords, self.input_image_size)
+        corner_embedding[:, 0, :] += self.point_weight[2].weight
+        corner_embedding[:, 1, :] += self.point_weight[3].weight
+        corner_embedding = F.interpolate(corner_embedding[..., None], size=(1024, 1), mode='bilinear', align_corners=False)[..., 0]
+        return corner_embedding  # [B, 2, 1024]
+    
+    @torch.no_grad()
+    def get_position_embedding(self, size=(64, 64)):
+        pe = self.position_point_embedding(size=size)
+        pe = F.interpolate(pe.permute(1, 2, 0)[..., None], size=(1024, 1), mode='bilinear', align_corners=False)[..., 0][None]
+        pe = pe.reshape(1, -1, 1024)
+        return pe  # [1, 64 * 64, 1024]
+
+    def forward(self, block_feat, bbox, layer_index):
+        # input: [B, 1, 4], [B, 64, 64, 1024]
+        if layer_index == self.layer_num:
+            return block_feat
+        interact_layer = self.interact_layer[layer_index]
+
+        pe = self.get_position_embedding()
+        bbox_token = self.get_bbox_token(bbox)
+
+        output = interact_layer(block_feat, bbox_token, feat_pe=pe, bbox_pe=bbox_token)
+
+        return output
+        
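+# Note (illustrative summary): BBoxInteract above reuses the frozen SAM
+# prompt-encoder positional embedding to build two corner tokens and a 64x64
+# feature PE (both interpolated to 1024 channels), then refines block_feat with
+# the CrossSelfAttn layers; only those interact layers receive gradients.
+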
+class InOutBBoxCrossSelfAttn(nn.Module):
+
+    def __init__(self, embedding_dim=1024, num_heads=4, downsample_rate=4) -> None:
+        super().__init__()
+
+        self.self_attn = DownAttention(embedding_dim=embedding_dim, num_heads=num_heads, downsample_rate=downsample_rate)
+        self.norm1 = nn.LayerNorm(embedding_dim)
+        self.mlp = MLPBlock(embedding_dim, mlp_dim=embedding_dim // 2)
+        self.norm2 = nn.LayerNorm(embedding_dim)
+        self.cross_attn = DownAttention(embedding_dim=embedding_dim, num_heads=num_heads, downsample_rate=downsample_rate)
+        self.norm3 = nn.LayerNorm(embedding_dim)
+
+    def forward(self, in_box_token, out_box_token):
+
+        # self-attn
+        short_cut = in_box_token
+        in_box_token = self.norm1(in_box_token)
+        in_box_token = self.self_attn(q=in_box_token, k=in_box_token, v=in_box_token)
+        in_box_token = short_cut + in_box_token
+
+        # mlp
+        in_box_token = in_box_token + self.mlp(self.norm2(in_box_token))
+
+        # cross-attn
+        short_cut = in_box_token
+        in_box_token = self.norm3(in_box_token)
+        in_box_token = self.cross_attn(q=in_box_token, k=out_box_token, v=out_box_token)
+        in_box_token = short_cut + in_box_token
+
+        return in_box_token
+
+
+class BBoxInteractInOut(nn.Module):
+    def __init__(
+        self,
+        num_heads = 4, 
+        downsample_rate = 4,
+        layer_num = 3,
+    ):
+        super().__init__()
+
+        self.layer_num = layer_num
+        self.input_image_size = (1024, 1024)
+
+        self.interact_layer = nn.ModuleList()
+        for _ in range(self.layer_num):
+            self.interact_layer.append(InOutBBoxCrossSelfAttn(embedding_dim=1024, num_heads=num_heads, downsample_rate=downsample_rate))
+
+    def forward(self, block_feat, bbox, layer_index):
+
+        # input: [B, 1, 4], [B, 64, 64, 1024]
+        if layer_index == self.layer_num:
+            return block_feat
+        interact_layer = self.interact_layer[layer_index]
+
+        # split_in_out_bbox_token
+        bbox = torch.round(bbox / self.input_image_size[0] * (block_feat.shape[1] - 1)).int()
+        for i in range(block_feat.shape[0]):
+            in_bbox_mask = torch.zeros((block_feat.shape[1], block_feat.shape[2]), dtype=bool, device=bbox.device)
+            in_bbox_mask[bbox[i, 0, 1]: bbox[i, 0, 3], bbox[i, 0, 0]: bbox[i, 0, 2]] = True
+            in_bbox_token = block_feat[i: i + 1, in_bbox_mask, :]
+            out_bbox_token = block_feat[i: i + 1, ~in_bbox_mask, :]
+            block_feat[i, in_bbox_mask, :] = interact_layer(in_bbox_token, out_bbox_token)
+
+        return block_feat
+
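+# Note (illustrative summary): BBoxInteractInOut rounds the box to the 64x64
+# feature grid, lets the tokens inside the box self-attend and then cross-attend
+# to the tokens outside the box, and writes the refined in-box tokens back into
+# block_feat; the per-sample loop is needed because each sample can contain a
+# different number of in-box tokens.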
+
+if __name__ == '__main__':
+    # embed = ConditionEmbedding()
+    # input = torch.tensor([[100, 200, 300, 400, 512], [100, 200, 300, 400, 1024]])
+    # print(input.shape)
+    # output = embed(input)  # [B, 5] --> [B, 5 * 128] --> [B, 1024]
+
+    embed = BBoxEmbedInteract(
+        embed_type = 'position',
+        interact_type = 'cross-self-attn'
+    )
+    bbox = torch.tensor([[[100, 200, 300, 400]], [[100, 200, 300, 400]]])  # [B, 1, 4]
+    print(bbox.shape)
+    output = embed(torch.randn((2, 64, 64, 1024)), bbox, layer_index=0)  # [B, 64, 64, 1024]
\ No newline at end of file
diff --git a/modeling/semantic_enhanced_matting/feature_fusion.py b/modeling/semantic_enhanced_matting/feature_fusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..27970af201f9bb3769027f9674e6be955a333080
--- /dev/null
+++ b/modeling/semantic_enhanced_matting/feature_fusion.py
@@ -0,0 +1,283 @@
+import torch
+import torch.nn as nn
+from typing import Type, Optional, Tuple
+import numpy as np
+
+from .modeling.transformer import Attention
+from .modeling.common import MLPBlock
+# from modeling.transformer import Attention
+# from modeling.common import MLPBlock
+
+
+class MutualCrossAttention(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int = 1024,
+        num_heads: int = 8,
+        mlp_dim: int = 1024,
+        activation: Type[nn.Module] = nn.GELU,
+        attention_downsample_rate: int = 4,
+    ) -> None:
+        super().__init__()
+
+        self.cross_attn_token_to_image = Attention(
+            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+        )
+        self.norm1 = nn.LayerNorm(embedding_dim)
+
+        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
+        self.norm2 = nn.LayerNorm(embedding_dim)
+
+        self.norm3 = nn.LayerNorm(embedding_dim)
+        self.cross_attn_image_to_token = Attention(
+            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+        )
+
+    def forward(self, queries, keys, query_pe=None, key_pe=None):
+
+        # Cross attention block, tokens attending to image embedding
+        q = queries + query_pe if query_pe is not None else queries
+        k = keys + key_pe if key_pe is not None else keys
+        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
+        queries = queries + attn_out
+        queries = self.norm1(queries)
+
+        # MLP block
+        mlp_out = self.mlp(queries)
+        queries = queries + mlp_out
+        queries = self.norm2(queries)
+
+        # Cross attention block, image embedding attending to tokens
+        q = queries + query_pe if query_pe is not None else queries
+        k = keys + key_pe if key_pe is not None else keys
+        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
+        keys = keys + attn_out
+        keys = self.norm3(keys)
+
+        return queries, keys
+
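+# Minimal usage sketch (shapes are assumptions for illustration): two token sets
+# of the same width attend to each other and are both returned updated, e.g.
+#
+#   attn = MutualCrossAttention(embedding_dim=1024, mlp_dim=256, attention_downsample_rate=4)
+#   a, b = torch.randn(2, 4096, 1024), torch.randn(2, 4096, 1024)
+#   a, b = attn(a, b)  # queries and keys refined in both directions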
+
+class PositionEmbeddingRandom(nn.Module):
+    """
+    Positional encoding using random spatial frequencies.
+    """
+
+    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
+        super().__init__()
+        if scale is None or scale <= 0.0:
+            scale = 1.0
+        self.register_buffer(
+            "positional_encoding_gaussian_matrix",
+            scale * torch.randn((2, num_pos_feats)),
+        )
+
+    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
+        """Positionally encode points that are normalized to [0,1]."""
+        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+        coords = 2 * coords - 1
+        coords = coords @ self.positional_encoding_gaussian_matrix
+        coords = 2 * np.pi * coords
+        # outputs d_1 x ... x d_n x C shape
+        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
+
+    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
+        """Generate positional encoding for a grid of the specified size."""
+        h, w = size
+        device = self.positional_encoding_gaussian_matrix.device
+        grid = torch.ones((h, w), device=device, dtype=torch.float32)
+        y_embed = grid.cumsum(dim=0) - 0.5
+        x_embed = grid.cumsum(dim=1) - 0.5
+        y_embed = y_embed / h
+        x_embed = x_embed / w
+
+        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
+        # return pe.permute(2, 0, 1)  # C x H x W
+        return pe.reshape(h * w, -1)[None]  # 1 x (H x W) x C
+    
+
+class FeatureFusion(nn.Module):
+    def __init__(
+        self,
+        in_channels=1024,
+        input_compression_ratio=1,
+        attn_compression_ratio=4,
+        features_num=4,
+        w_pe=True,
+    ):
+        super().__init__()
+
+        self.input_compression_ratio = input_compression_ratio
+        if self.input_compression_ratio != 1:
+            self.mlp_in = nn.ModuleList([nn.Sequential(
+                nn.Linear(in_channels, in_channels // input_compression_ratio),
+                # activation(),
+                # nn.Linear(embedding_dim // compression_ratio, embedding_dim // compression_ratio)
+            ) for _ in range(features_num)])
+
+            self.mlp_out = nn.ModuleList([nn.Sequential(
+                nn.Linear(in_channels // input_compression_ratio, in_channels),
+                # activation(),
+                # nn.Linear(embedding_dim, embedding_dim)
+            ) for _ in range(features_num)])
+
+        in_channels = in_channels // input_compression_ratio
+        self.mutual_cross_attn = nn.ModuleList([
+            MutualCrossAttention(embedding_dim=in_channels, mlp_dim=in_channels // attn_compression_ratio, attention_downsample_rate=attn_compression_ratio) for _ in range(features_num - 1)
+        ])
+        self.w_pe = w_pe
+        if self.w_pe:
+            # no grad
+            self.get_pe = PositionEmbeddingRandom(in_channels // 2)
+            with torch.no_grad():
+                self.pe = self.get_pe(size=(64, 64))
+
+    def forward(self, features):
+        # [B, 64, 64, 1024] x 4
+        
+        b, h, w, _ = features[0].shape
+        for i in range(len(features)):
+            features[i] = features[i].reshape(b, h * w, -1)
+            if self.input_compression_ratio != 1:
+                features[i] = self.mlp_in[i](features[i])
+
+        # move the cached position embedding to the feature device; skip it when w_pe is False
+        pe = self.pe.to(features[0].device) if self.w_pe else None
+        for i in range(len(features) - 1):
+            features[i], features[i + 1] = self.mutual_cross_attn[i](features[i], features[i + 1], pe, pe)
+
+        for i in range(len(features)):
+            features[i] = features[i].reshape(b, h, w, -1)
+            if self.input_compression_ratio != 1:
+                features[i] = self.mlp_out[i](features[i])
+
+        return features
+
+
+if __name__ == '__main__':
+
+    import typing
+    from collections import defaultdict
+    import tabulate
+    from torch import nn
+
+
+    def parameter_count(model: nn.Module, trainable_only: bool = False) -> typing.DefaultDict[str, int]:
+        """
+        Count parameters of a model and its submodules.
+
+        Args:
+            model: a torch module
+
+        Returns:
+            dict (str-> int): the key is either a parameter name or a module name.
+            The value is the number of elements in the parameter, or in all
+            parameters of the module. The key "" corresponds to the total
+            number of parameters of the model.
+        """
+        r = defaultdict(int)
+        for name, prm in model.named_parameters():
+            if trainable_only:
+                if not prm.requires_grad:
+                    continue
+            size = prm.numel()
+            name = name.split(".")
+            for k in range(0, len(name) + 1):
+                prefix = ".".join(name[:k])
+                r[prefix] += size
+        return r
+
+
+    def parameter_count_table(
+        model: nn.Module, max_depth: int = 3, trainable_only: bool = False
+    ) -> str:
+        """
+        Format the parameter count of the model (and its submodules or parameters)
+        in a nice table. It looks like this:
+
+        ::
+
+            | name                            | #elements or shape   |
+            |:--------------------------------|:---------------------|
+            | model                           | 37.9M                |
+            |  backbone                       |  31.5M               |
+            |   backbone.fpn_lateral3         |   0.1M               |
+            |    backbone.fpn_lateral3.weight |    (256, 512, 1, 1)  |
+            |    backbone.fpn_lateral3.bias   |    (256,)            |
+            |   backbone.fpn_output3          |   0.6M               |
+            |    backbone.fpn_output3.weight  |    (256, 256, 3, 3)  |
+            |    backbone.fpn_output3.bias    |    (256,)            |
+            |   backbone.fpn_lateral4         |   0.3M               |
+            |    backbone.fpn_lateral4.weight |    (256, 1024, 1, 1) |
+            |    backbone.fpn_lateral4.bias   |    (256,)            |
+            |   backbone.fpn_output4          |   0.6M               |
+            |    backbone.fpn_output4.weight  |    (256, 256, 3, 3)  |
+            |    backbone.fpn_output4.bias    |    (256,)            |
+            |   backbone.fpn_lateral5         |   0.5M               |
+            |    backbone.fpn_lateral5.weight |    (256, 2048, 1, 1) |
+            |    backbone.fpn_lateral5.bias   |    (256,)            |
+            |   backbone.fpn_output5          |   0.6M               |
+            |    backbone.fpn_output5.weight  |    (256, 256, 3, 3)  |
+            |    backbone.fpn_output5.bias    |    (256,)            |
+            |   backbone.top_block            |   5.3M               |
+            |    backbone.top_block.p6        |    4.7M              |
+            |    backbone.top_block.p7        |    0.6M              |
+            |   backbone.bottom_up            |   23.5M              |
+            |    backbone.bottom_up.stem      |    9.4K              |
+            |    backbone.bottom_up.res2      |    0.2M              |
+            |    backbone.bottom_up.res3      |    1.2M              |
+            |    backbone.bottom_up.res4      |    7.1M              |
+            |    backbone.bottom_up.res5      |    14.9M             |
+            |    ......                       |    .....             |
+
+        Args:
+            model: a torch module
+            max_depth (int): maximum depth to recursively print submodules or
+                parameters
+
+        Returns:
+            str: the table to be printed
+        """
+        count: typing.DefaultDict[str, int] = parameter_count(model, trainable_only)
+        # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
+        param_shape: typing.Dict[str, typing.Tuple] = {
+            k: tuple(v.shape) for k, v in model.named_parameters()
+        }
+
+        # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
+        table: typing.List[typing.Tuple] = []
+
+        def format_size(x: int) -> str:
+            if x > 1e8:
+                return "{:.1f}G".format(x / 1e9)
+            if x > 1e5:
+                return "{:.1f}M".format(x / 1e6)
+            if x > 1e2:
+                return "{:.1f}K".format(x / 1e3)
+            return str(x)
+
+        def fill(lvl: int, prefix: str) -> None:
+            if lvl >= max_depth:
+                return
+            for name, v in count.items():
+                if name.count(".") == lvl and name.startswith(prefix):
+                    indent = " " * (lvl + 1)
+                    if name in param_shape:
+                        table.append((indent + name, indent + str(param_shape[name])))
+                    else:
+                        table.append((indent + name, indent + format_size(v)))
+                        fill(lvl + 1, name + ".")
+
+        table.append(("model", format_size(count.pop(""))))
+        fill(0, "")
+
+        old_ws = tabulate.PRESERVE_WHITESPACE
+        tabulate.PRESERVE_WHITESPACE = True
+        tab = tabulate.tabulate(table, headers=["name", "#elements or shape"], tablefmt="pipe")
+        tabulate.PRESERVE_WHITESPACE = old_ws
+        return tab
+
+    feature_fusion = FeatureFusion(in_channels=1024, attn_compression_ratio=8)
+    print("All parameters: \n" + parameter_count_table(feature_fusion, max_depth=8))
+    features = [torch.randn(2, 64, 64, 1024) for _ in range(4)]
+    out = feature_fusion(features)
+    for i in out:
+        print(i.shape)
+    print('done')
diff --git a/modeling/semantic_enhanced_matting/mask_decoder_matting.py b/modeling/semantic_enhanced_matting/mask_decoder_matting.py
new file mode 100644
index 0000000000000000000000000000000000000000..1378f5445f2e25aab117042a5337ca41af4b0cdd
--- /dev/null
+++ b/modeling/semantic_enhanced_matting/mask_decoder_matting.py
@@ -0,0 +1,356 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import List, Tuple
+import numpy as np
+import cv2
+from detectron2.layers.batch_norm import NaiveSyncBatchNorm
+
+from modeling.semantic_enhanced_matting.modeling import TwoWayTransformer, MaskDecoder
+from modeling.decoder.detail_capture import Detail_Capture
+from modeling.decoder.unet_detail_capture import DetailUNet
+# from nnMorpho.binary_operators import erosion
+
+
+# class GenTrimapTorch(object):
+#     def __init__(self, max_kernal=200):
+#         self.max_kernal = max_kernal
+#         self.erosion_kernels = [None] + [torch.from_numpy(cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))).float().cuda() for size in range(1, self.max_kernal)]
+
+#     def __call__(self, mask, kernel_size):
+        
+#         fg_width = kernel_size
+#         bg_width = kernel_size
+
+#         fg_mask = mask
+#         bg_mask = 1 - mask
+
+#         fg_mask = erosion(fg_mask, self.erosion_kernels[fg_width], border='a')
+#         bg_mask = erosion(bg_mask, self.erosion_kernels[bg_width], border='a')
+
+#         trimap = torch.ones_like(mask) * 0.5
+#         trimap[fg_mask == 1] = 1.0
+#         trimap[bg_mask == 1] = 0.0
+
+#         return trimap
+
+
+class LayerNorm2d(nn.Module):
+    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(num_channels))
+        self.bias = nn.Parameter(torch.zeros(num_channels))
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        u = x.mean(1, keepdim=True)
+        s = (x - u).pow(2).mean(1, keepdim=True)
+        x = (x - u) / torch.sqrt(s + self.eps)
+        x = self.weight[:, None, None] * x + self.bias[:, None, None]
+        return x
+
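+# Sanity sketch (illustrative): this LayerNorm2d normalizes over the channel
+# dimension of an NCHW tensor, so for x of shape [B, C, H, W] it is equivalent
+# (up to eps) to nn.LayerNorm(C) applied to x.permute(0, 2, 3, 1) and permuted back.
+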
+class MLP(nn.Module):
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dim: int,
+        output_dim: int,
+        num_layers: int,
+        sigmoid_output: bool = False,
+    ) -> None:
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(
+            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+        )
+        self.sigmoid_output = sigmoid_output
+
+    def forward(self, x):
+        for i, layer in enumerate(self.layers):
+            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+        if self.sigmoid_output:
+            x = F.sigmoid(x)
+        return x
+
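+# Quick sketch (illustrative): MLP(256, 256, 32, 3) builds Linear(256, 256) ->
+# Linear(256, 256) -> Linear(256, 32) with ReLU between layers, matching the
+# hf_mlp below that projects the matting token into the 32-channel mask space.
+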
+class MaskDecoderMatting(MaskDecoder):
+    def __init__(
+        self, 
+        model_type, 
+        checkpoint_path, 
+        detail_capture, 
+        mask_token_only, 
+        norm_type = 'LN', 
+        norm_mask_logits = False,
+        with_trimap = False,
+        min_kernel_size = 20,
+        kernel_div = 10,
+        concat_gen_trimap = False,
+    ):
+        super().__init__(
+            transformer_dim=256,
+            transformer=TwoWayTransformer(
+                depth=2,
+                embedding_dim=256,
+                mlp_dim=2048,
+                num_heads=8,
+            ),
+            num_multimask_outputs=3,
+            activation=nn.GELU,
+            iou_head_depth=3,
+            iou_head_hidden_dim=256,
+        )
+        assert model_type in ["vit_b","vit_l","vit_h"]
+        
+        assert norm_type in {'BN', 'LN', 'SyncBN'}
+        if norm_type == 'BN':
+            self.norm = torch.nn.BatchNorm2d
+        elif norm_type == 'SyncBN':
+            self.norm = NaiveSyncBatchNorm
+        else:
+            self.norm = LayerNorm2d
+
+        # checkpoint_dict = {"vit_b":"pretrained_checkpoint/sam_vit_b_maskdecoder.pth",
+        #                    "vit_l":"pretrained_checkpoint/sam_vit_l_maskdecoder.pth",
+        #                    'vit_h':"pretrained_checkpoint/sam_vit_h_maskdecoder.pth"}
+        # checkpoint_path = checkpoint_dict[model_type]
+
+        self.load_state_dict(torch.load(checkpoint_path))
+        print("Matting Decoder init from SAM MaskDecoder")
+
+        self.frozen_params_str = set()
+        for n, p in self.named_parameters():
+            p.requires_grad = False
+            self.frozen_params_str.add(n)
+
+        self.detail_capture = detail_capture
+        self.mask_token_only = mask_token_only
+        self.norm_mask_logits = norm_mask_logits
+
+        transformer_dim = 256
+        vit_dim_dict = {"vit_b":768,"vit_l":1024,"vit_h":1280}
+        vit_dim = vit_dim_dict[model_type]
+
+        self.hf_token = nn.Embedding(1, transformer_dim)
+        self.hf_mlp = MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
+        self.num_mask_tokens = self.num_mask_tokens + 1
+        self.concat_gen_trimap = concat_gen_trimap
+
+        self.compress_vit_feat = nn.Sequential(
+            nn.ConvTranspose2d(vit_dim, transformer_dim, kernel_size=2, stride=2),
+            self.norm(transformer_dim),
+            nn.GELU(), 
+            nn.ConvTranspose2d(transformer_dim, transformer_dim // 8, kernel_size=2, stride=2)
+        )
+        
+        self.embedding_encoder = nn.Sequential(
+            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
+            self.norm(transformer_dim // 4),
+            nn.GELU(),
+            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
+        )
+
+        self.embedding_maskfeature = nn.Sequential(
+            nn.Conv2d(transformer_dim // 8, transformer_dim // 4, 3, 1, 1), 
+            self.norm(transformer_dim // 4),
+            nn.GELU(),
+            nn.Conv2d(transformer_dim // 4, transformer_dim // 8, 3, 1, 1)
+        )
+
+        if isinstance(self.detail_capture, Detail_Capture):
+            self.glue_layer_0 = nn.Conv2d(self.detail_capture.fus_channs[2], transformer_dim // 8, 3, 1, 1)
+        else:
+            assert isinstance(self.detail_capture, DetailUNet)
+
+        self.trainable_params_str = set()
+        for n, p in self.named_parameters():
+            if p.requires_grad:
+                self.trainable_params_str.add(n)
+
+        self.with_trimap = with_trimap
+        self.min_kernel_size = min_kernel_size
+        self.kernel_div = kernel_div
+        if self.with_trimap and not self.concat_gen_trimap:
+            # self.gen_trimap = GenTrimapTorch()
+            raise ValueError('Discard GenTrimapTorch')
+
+        # self.trainable_params_str = {'detail_capture', 'hf_token', 'hf_mlp', 'compress_vit_feat', 'embedding_encoder', 'embedding_maskfeature', 'glue_layer_0'}
+        # for n, p in self.named_parameters():
+        #     if p.requires_grad:
+        #         assert n.split('.')[0] in self.trainable_params_str
+
+    def forward(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        multimask_output: bool,
+        # hq_token_only: bool,
+        interm_embeddings: torch.Tensor,
+        hq_features: torch.Tensor,
+        images: torch.Tensor,
+        hr_images_ori_h_w = None,
+        return_alpha_logits = False,
+        pred_trimap=None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Predict masks given image and prompt embeddings.
+
+        Arguments:
+          image_embeddings (torch.Tensor): the embeddings from the ViT image encoder
+          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
+          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
+          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
+          multimask_output (bool): Whether to return multiple masks or a single
+            mask.
+
+        Returns:
+          torch.Tensor: the predicted alpha matte (paired with its logits when return_alpha_logits is True)
+        """
+        vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT [B, 64, 64, 768]
+
+        # upsample image_embeddings x4.0 with detail_capture & embedding_encoder & compress_vit_feat,
+        # regarding hq_features as the condition
+        if isinstance(self.norm_mask_logits, float):
+            norm_hq_features = hq_features / self.norm_mask_logits
+        elif self.norm_mask_logits:
+            norm_hq_features = hq_features / torch.std(hq_features, dim=(1, 2, 3), keepdim=True)
+        else:
+            norm_hq_features = hq_features
+
+        if hr_images_ori_h_w is not None:
+            assert not isinstance(self.detail_capture, Detail_Capture) and hq_features.shape[-2] == hq_features.shape[-1] == 1024
+            lr_images_before_pad_h_w = (1024 / max(hr_images_ori_h_w) * hr_images_ori_h_w[0], 1024 / max(hr_images_ori_h_w) * hr_images_ori_h_w[1])
+            lr_images_before_pad_h_w = (int(lr_images_before_pad_h_w[0] + 0.5), int(lr_images_before_pad_h_w[1] + 0.5))
+            norm_hq_features = F.interpolate(
+                norm_hq_features[:, :, :lr_images_before_pad_h_w[0], :lr_images_before_pad_h_w[1]], 
+                size = (images.shape[-2], images.shape[-1]), 
+                mode = 'bilinear', 
+                align_corners = False
+            )
+        
+        if self.concat_gen_trimap:
+            pred_trimap = F.interpolate(pred_trimap, size=(images.shape[-2], images.shape[-1]), mode='bilinear', align_corners=False)
+            pred_trimap = torch.argmax(pred_trimap, dim=1, keepdim=True).float() / 2.0
+            norm_hq_features = torch.concat((norm_hq_features, pred_trimap), dim=1)
+        elif self.with_trimap:
+            mask = (norm_hq_features > 0).float()
+            for i_batch in range(image_embeddings.shape[0]):
+                mask_area = torch.sum(mask[i_batch])
+                kernel_size = max(self.min_kernel_size, int((mask_area ** 0.5) / self.kernel_div))
+                kernel_size = min(kernel_size, self.gen_trimap.max_kernal - 1)
+                mask[i_batch, 0] = self.gen_trimap(mask[i_batch, 0], kernel_size=kernel_size)
+            trimaps = mask
+            norm_hq_features = torch.concat((norm_hq_features, trimaps), dim=1)
+
+        conditional_images = torch.concatenate((images, norm_hq_features), dim=1)  # [B, 4, 1024, 1024]
+
+        if isinstance(self.detail_capture, Detail_Capture):
+            detail_features = self.detail_capture.convstream(conditional_images)  # [B, 4, 1024, 1024] --> D0: [B, 4, 1024, 1024], D1: [B, 48, 512, 512], D2: [B, 96, 256, 256], D3: [B, 192, 128, 128]
+            matting_features = self.detail_capture.fusion_blks[0](image_embeddings, detail_features['D3'])  # [B, 256, 64, 64] & [B, 192, 128, 128] --> [B, 256, 128, 128]
+            matting_features = self.detail_capture.fusion_blks[1](matting_features, detail_features['D2'])  # [B, 256, 128, 128] & [B, 96, 256, 256] --> [B, 128, 256, 256]
+            matting_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(vit_features) + self.glue_layer_0(matting_features)  # [B, 32, 256, 256]
+        else:
+            matting_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(vit_features)
+
+        batch_len = len(image_embeddings)
+        masks = []
+        iou_preds = []
+        for i_batch in range(batch_len):
+            mask, iou_pred = self.predict_masks(
+                image_embeddings=image_embeddings[i_batch].unsqueeze(0),
+                image_pe=image_pe,
+                sparse_prompt_embeddings=sparse_prompt_embeddings[i_batch],
+                dense_prompt_embeddings=dense_prompt_embeddings[i_batch],
+                matting_feature = matting_features[i_batch].unsqueeze(0)
+            )
+            masks.append(mask)
+            iou_preds.append(iou_pred)
+        masks = torch.cat(masks, 0)  # [B, 5, 256, 256]
+        iou_preds = torch.cat(iou_preds, 0)  # [B, 4]
+
+        if self.mask_token_only:
+            masks_matting = masks[:,slice(self.num_mask_tokens-1, self.num_mask_tokens), :, :]  # [B, 1, 256, 256]
+        else:
+            masks_matting = masks  # [B, 5, 256, 256]
+
+        if hr_images_ori_h_w is not None:
+            vit_features = F.interpolate(
+                vit_features[:, :, :math.ceil(lr_images_before_pad_h_w[0] / 16), :math.ceil(lr_images_before_pad_h_w[1] / 16)], 
+                size = (images.shape[-2] // 16, images.shape[-1] // 16), 
+                mode = 'bilinear', 
+                align_corners = False
+            )
+            masks_matting = F.interpolate(
+                masks_matting[:, :, :math.ceil(lr_images_before_pad_h_w[0] / 4), :math.ceil(lr_images_before_pad_h_w[1] / 4)], 
+                size = (images.shape[-2] // 4, images.shape[-1] // 4), 
+                mode = 'bilinear', 
+                align_corners = False
+            )
+
+        if isinstance(self.detail_capture, Detail_Capture):
+            matting_features = self.detail_capture.fusion_blks[2](masks_matting, detail_features['D1'])
+            matting_features = self.detail_capture.fusion_blks[3](matting_features, detail_features['D0'])
+            alpha = torch.sigmoid(self.detail_capture.matting_head(matting_features))
+        else:
+            if return_alpha_logits:
+                output = self.detail_capture(conditional_images, vit_features, masks_matting, return_alpha_logits = True)
+                alpha = torch.sigmoid(output[0]), output[1]
+            else:
+                alpha = torch.sigmoid(self.detail_capture(conditional_images, vit_features, masks_matting, return_alpha_logits = False))
+
+        if hr_images_ori_h_w is not None:
+            alpha = alpha[:, :, :hr_images_ori_h_w[0], :hr_images_ori_h_w[1]]
+        
+        return alpha
+    
+    def predict_masks(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        matting_feature: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Predicts masks. See 'forward' for more details."""
+
+        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hf_token.weight], dim=0)  # [6, 256]
+        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)  # [1, 6, 256]
+        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)  # [1, 8, 256]
+
+        # Expand per-image data in batch direction to be per-mask
+        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)   # [1, 256, 64, 64]
+        src = src + dense_prompt_embeddings  # [1, 256, 64, 64] + [1, 256, 64, 64]
+        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)  # [1, 256, 64, 64]
+        b, c, h, w = src.shape
+
+        # Run the transformer
+        hs, src = self.transformer(src, pos_src, tokens)  # [1, 8, 256], [1, 4096, 256]
+        iou_token_out = hs[:, 0, :]  # [1, 256]
+        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]  # [1, 5, 256]
+
+        # Upscale mask embeddings and predict masks using the mask tokens
+        src = src.transpose(1, 2).view(b, c, h, w)  # [1, 256, 64, 64]
+
+        upscaled_embedding_sam = self.output_upscaling(src)  # [1, 32, 256, 256]
+        upscaled_embedding_ours = self.embedding_maskfeature(upscaled_embedding_sam) + matting_feature  # [1, 32, 256, 256]
+        
+        hyper_in_list: List[torch.Tensor] = []
+        for i in range(self.num_mask_tokens):
+            if i < 4:
+                hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
+            else:
+                hyper_in_list.append(self.hf_mlp(mask_tokens_out[:, i, :]))
+
+        hyper_in = torch.stack(hyper_in_list, dim=1)  # 5 * [1, 32] --> [1, 5, 32]
+        b, c, h, w = upscaled_embedding_sam.shape
+
+        masks_sam = (hyper_in[:,:4] @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w)  # [1, 4, 32] @ [1, 32, 65536] --> [1, 4, 256, 256]
+        masks_ours = (hyper_in[:,4:] @ upscaled_embedding_ours.view(b, c, h * w)).view(b, -1, h, w)  # [1, 1, 32] @ [1, 32, 65536] --> [1, 1, 256, 256]
+        masks = torch.cat([masks_sam,masks_ours], dim=1)  # [1, 5, 256, 256]
+        
+        iou_pred = self.iou_prediction_head(iou_token_out)
+
+        return masks, iou_pred
\ No newline at end of file
diff --git a/modeling/semantic_enhanced_matting/modeling/__init__.py b/modeling/semantic_enhanced_matting/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa0a07c7f8b75a3d1882bd4e7a4a3bc83e9da51c
--- /dev/null
+++ b/modeling/semantic_enhanced_matting/modeling/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .sam import Sam
+from .image_encoder import ImageEncoderViT
+from .mask_decoder_hq import MaskDecoderHQ
+from .mask_decoder import MaskDecoder
+from .prompt_encoder import PromptEncoder
+from .transformer import TwoWayTransformer
+from .tiny_vit_sam import TinyViT
diff --git a/modeling/semantic_enhanced_matting/modeling/__pycache__/__init__.cpython-38.pyc b/modeling/semantic_enhanced_matting/modeling/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fab22e68038c55b95e55130303275bf62aa1d8c0
Binary files /dev/null and b/modeling/semantic_enhanced_matting/modeling/__pycache__/__init__.cpython-38.pyc differ
diff --git a/modeling/semantic_enhanced_matting/modeling/__pycache__/common.cpython-38.pyc b/modeling/semantic_enhanced_matting/modeling/__pycache__/common.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..49b83405778545969ce45fae770bf5d0b6c4f7cc
Binary files /dev/null and b/modeling/semantic_enhanced_matting/modeling/__pycache__/common.cpython-38.pyc differ
diff --git a/modeling/semantic_enhanced_matting/modeling/__pycache__/image_encoder.cpython-38.pyc b/modeling/semantic_enhanced_matting/modeling/__pycache__/image_encoder.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0984055070520495df8c875e099f0a77ad3e84e5
Binary files /dev/null and b/modeling/semantic_enhanced_matting/modeling/__pycache__/image_encoder.cpython-38.pyc differ
diff --git a/modeling/semantic_enhanced_matting/modeling/__pycache__/mask_decoder.cpython-38.pyc b/modeling/semantic_enhanced_matting/modeling/__pycache__/mask_decoder.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ec75a3ace2bf1978f084adfdf05be12e0c0ebdbb
Binary files /dev/null and b/modeling/semantic_enhanced_matting/modeling/__pycache__/mask_decoder.cpython-38.pyc differ
diff --git a/modeling/semantic_enhanced_matting/modeling/__pycache__/mask_decoder_hq.cpython-38.pyc b/modeling/semantic_enhanced_matting/modeling/__pycache__/mask_decoder_hq.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ede602c19484a67e6ea264872a2a43bb849ca332
Binary files /dev/null and b/modeling/semantic_enhanced_matting/modeling/__pycache__/mask_decoder_hq.cpython-38.pyc differ
diff --git a/modeling/semantic_enhanced_matting/modeling/__pycache__/mask_decoder_hq_matting.cpython-38.pyc b/modeling/semantic_enhanced_matting/modeling/__pycache__/mask_decoder_hq_matting.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8819dd5b191c831acad58b371c383a06a51dc559
Binary files /dev/null and b/modeling/semantic_enhanced_matting/modeling/__pycache__/mask_decoder_hq_matting.cpython-38.pyc differ
diff --git a/modeling/semantic_enhanced_matting/modeling/__pycache__/prompt_encoder.cpython-38.pyc b/modeling/semantic_enhanced_matting/modeling/__pycache__/prompt_encoder.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..829c46dbb26b33310f8538b01e1bf503df803928
Binary files /dev/null and b/modeling/semantic_enhanced_matting/modeling/__pycache__/prompt_encoder.cpython-38.pyc differ
diff --git a/modeling/semantic_enhanced_matting/modeling/__pycache__/sam.cpython-38.pyc b/modeling/semantic_enhanced_matting/modeling/__pycache__/sam.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..181fefe93e565ea8dfd5344e05444dcfd414ef80
Binary files /dev/null and b/modeling/semantic_enhanced_matting/modeling/__pycache__/sam.cpython-38.pyc differ
diff --git a/modeling/semantic_enhanced_matting/modeling/__pycache__/tiny_vit_sam.cpython-38.pyc b/modeling/semantic_enhanced_matting/modeling/__pycache__/tiny_vit_sam.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f02aa71dfd72b5235064230fd965b9a21704820b
Binary files /dev/null and b/modeling/semantic_enhanced_matting/modeling/__pycache__/tiny_vit_sam.cpython-38.pyc differ
diff --git a/modeling/semantic_enhanced_matting/modeling/__pycache__/transformer.cpython-38.pyc b/modeling/semantic_enhanced_matting/modeling/__pycache__/transformer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3809a541f8d339f52f0d6c3f037fa2212e5ec10d
Binary files /dev/null and b/modeling/semantic_enhanced_matting/modeling/__pycache__/transformer.cpython-38.pyc differ
diff --git a/modeling/semantic_enhanced_matting/modeling/common.py b/modeling/semantic_enhanced_matting/modeling/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bf15236a3eb24d8526073bc4fa2b274cccb3f96
--- /dev/null
+++ b/modeling/semantic_enhanced_matting/modeling/common.py
@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+
+from typing import Type
+
+
+class MLPBlock(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        mlp_dim: int,
+        act: Type[nn.Module] = nn.GELU,
+    ) -> None:
+        super().__init__()
+        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
+        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
+        self.act = act()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.lin2(self.act(self.lin1(x)))
+
+
+# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
+# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119  # noqa
+class LayerNorm2d(nn.Module):
+    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(num_channels))
+        self.bias = nn.Parameter(torch.zeros(num_channels))
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        u = x.mean(1, keepdim=True)
+        s = (x - u).pow(2).mean(1, keepdim=True)
+        x = (x - u) / torch.sqrt(s + self.eps)
+        x = self.weight[:, None, None] * x + self.bias[:, None, None]
+        return x
diff --git a/modeling/semantic_enhanced_matting/modeling/image_encoder.py b/modeling/semantic_enhanced_matting/modeling/image_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..9353c9c0c63cf7447ff5fd14fca7a0e9ee74ce17
--- /dev/null
+++ b/modeling/semantic_enhanced_matting/modeling/image_encoder.py
@@ -0,0 +1,447 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from typing import Optional, Tuple, Type
+
+from .common import LayerNorm2d, MLPBlock
+
+
+# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
+class ImageEncoderViT(nn.Module):
+    def __init__(
+        self,
+        img_size: int = 1024,
+        patch_size: int = 16,
+        in_chans: int = 3,
+        embed_dim: int = 768,
+        depth: int = 12,
+        num_heads: int = 12,
+        mlp_ratio: float = 4.0,
+        out_chans: int = 256,
+        qkv_bias: bool = True,
+        norm_layer: Type[nn.Module] = nn.LayerNorm,
+        act_layer: Type[nn.Module] = nn.GELU,
+        use_abs_pos: bool = True,
+        use_rel_pos: bool = False,
+        rel_pos_zero_init: bool = True,
+        window_size: int = 0,
+        global_attn_indexes: Tuple[int, ...] = (),
+    ) -> None:
+        """
+        Args:
+            img_size (int): Input image size.
+            patch_size (int): Patch size.
+            in_chans (int): Number of input image channels.
+            embed_dim (int): Patch embedding dimension.
+            depth (int): Depth of ViT.
+            num_heads (int): Number of attention heads in each ViT block.
+            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+            qkv_bias (bool): If True, add a learnable bias to query, key, value.
+            norm_layer (nn.Module): Normalization layer.
+            act_layer (nn.Module): Activation layer.
+            use_abs_pos (bool): If True, use absolute positional embeddings.
+            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+            window_size (int): Window size for window attention blocks.
+            global_attn_indexes (list): Indexes for blocks using global attention.
+        """
+        super().__init__()
+        self.img_size = img_size
+
+        self.patch_size = patch_size
+        self.embed_dim = embed_dim
+
+        self.patch_embed = PatchEmbed(
+            kernel_size=(patch_size, patch_size),
+            stride=(patch_size, patch_size),
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+        )
+
+        self.pos_embed: Optional[nn.Parameter] = None
+        if use_abs_pos:
+            # Initialize absolute positional embedding with pretrain image size.
+            self.pos_embed = nn.Parameter(
+                torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
+            )
+
+        self.blocks = nn.ModuleList()
+        for i in range(depth):
+            block = Block(
+                dim=embed_dim,
+                num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+                use_rel_pos=use_rel_pos,
+                rel_pos_zero_init=rel_pos_zero_init,
+                window_size=window_size if i not in global_attn_indexes else 0,
+                input_size=(img_size // patch_size, img_size // patch_size),
+            )
+            self.blocks.append(block)
+
+        self.neck = nn.Sequential(
+            nn.Conv2d(
+                embed_dim,
+                out_chans,
+                kernel_size=1,
+                bias=False,
+            ),
+            LayerNorm2d(out_chans),
+            nn.Conv2d(
+                out_chans,
+                out_chans,
+                kernel_size=3,
+                padding=1,
+                bias=False,
+            ),
+            LayerNorm2d(out_chans),
+        )
+
+    def forward(self, img: torch.Tensor, condition: torch.Tensor = None, condition_layer = None) -> torch.Tensor:
+        
+        x = self.patch_embed(img)
+        if isinstance(condition_layer, nn.ModuleDict) and condition is not None:
+            # concat mask and img as condition
+            bbox_mask = torch.zeros_like(img)[:, 0:1]
+            for i in range(condition.shape[0]):
+                l, u, r, d = condition[i, 0, :]
+                bbox_mask[i, :, int(u): int(d), int(l): int(r)] = 1.0
+            condition_input = torch.concat((img, bbox_mask), dim=1)
+
+            x = x + condition_layer['patch_embed'](condition_input)
+        
+        if self.pos_embed is not None:
+            x = x + self.pos_embed
+
+        index = 0
+        interm_embeddings = []
+        pred_trimap = []
+        for blk in self.blocks:
+            x = blk(x)
+            if blk.window_size == 0:
+
+                interm_embeddings.append(x)
+
+                # predict intermediate trimap logits
+                if isinstance(condition_layer, nn.ModuleDict) and '{}_pred_layer'.format(index) in condition_layer.keys() and condition is not None:
+                    pred_trimap.append(condition_layer['{}_pred_layer'.format(index)](x.permute(0, 3, 1, 2)))
+
+                # add intermediate prompt interaction
+                if isinstance(condition_layer, nn.ModuleList):
+                    x = condition_layer[index](x, condition)
+                elif isinstance(condition_layer, nn.ModuleDict) and 'prompt_layer' in condition_layer.keys() and condition is not None:
+                    x = x + condition_layer['prompt_layer'](x, condition, index)
+                    
+                index += 1
+                
+        x = self.neck(x.permute(0, 3, 1, 2))
+
+        if isinstance(condition_layer, nn.ModuleDict) and len(pred_trimap) != 0 and condition is not None:
+            pred_trimap = sum(pred_trimap) / len(pred_trimap)
+            pred_trimap = F.interpolate(pred_trimap, size=(img.shape[-2], img.shape[-1]), mode='bilinear', align_corners=False)
+        else:
+            pred_trimap = None
+        
+        if isinstance(condition_layer, nn.ModuleDict) and 'feature_fusion' in condition_layer.keys() and condition is not None:
+            interm_embeddings = condition_layer['feature_fusion'](interm_embeddings)
+
+        return x, interm_embeddings, pred_trimap
+
+
+class Block(nn.Module):
+    """Transformer blocks with support of window attention and residual propagation blocks"""
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        qkv_bias: bool = True,
+        norm_layer: Type[nn.Module] = nn.LayerNorm,
+        act_layer: Type[nn.Module] = nn.GELU,
+        use_rel_pos: bool = False,
+        rel_pos_zero_init: bool = True,
+        window_size: int = 0,
+        input_size: Optional[Tuple[int, int]] = None,
+    ) -> None:
+        """
+        Args:
+            dim (int): Number of input channels.
+            num_heads (int): Number of attention heads in each ViT block.
+            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+            qkv_bias (bool): If True, add a learnable bias to query, key, value.
+            norm_layer (nn.Module): Normalization layer.
+            act_layer (nn.Module): Activation layer.
+            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+            window_size (int): Window size for window attention blocks. If it equals 0, then
+                use global attention.
+            input_size (tuple(int, int) or None): Input resolution for calculating the relative
+                positional parameter size.
+        """
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            use_rel_pos=use_rel_pos,
+            rel_pos_zero_init=rel_pos_zero_init,
+            input_size=input_size if window_size == 0 else (window_size, window_size),
+        )
+
+        self.norm2 = norm_layer(dim)
+        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
+
+        self.window_size = window_size
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        shortcut = x
+        x = self.norm1(x)
+        # Window partition
+        if self.window_size > 0:
+            H, W = x.shape[1], x.shape[2]
+            x, pad_hw = window_partition(x, self.window_size)
+
+        x = self.attn(x)
+        # Reverse window partition
+        if self.window_size > 0:
+            x = window_unpartition(x, self.window_size, pad_hw, (H, W))
+
+        x = shortcut + x
+        x = x + self.mlp(self.norm2(x))
+
+        return x
+
+
+class Attention(nn.Module):
+    """Multi-head Attention block with relative position embeddings."""
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        qkv_bias: bool = True,
+        use_rel_pos: bool = False,
+        rel_pos_zero_init: bool = True,
+        input_size: Optional[Tuple[int, int]] = None,
+    ) -> None:
+        """
+        Args:
+            dim (int): Number of input channels.
+            num_heads (int): Number of attention heads.
+            qkv_bias (bool):  If True, add a learnable bias to query, key, value.
+            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+            input_size (tuple(int, int) or None): Input resolution for calculating the relative
+                positional parameter size.
+        """
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = head_dim**-0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.proj = nn.Linear(dim, dim)
+
+        self.use_rel_pos = use_rel_pos
+        if self.use_rel_pos:
+            assert (
+                input_size is not None
+            ), "Input size must be provided if using relative positional encoding."
+            # initialize relative positional embeddings
+            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
+            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
+
+    def forward_token(self, x: torch.Tensor) -> torch.Tensor:
+        B, N, _ = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.reshape(3, B * self.num_heads, N, -1).unbind(0)
+        attn = (q * self.scale) @ k.transpose(-2, -1)
+        assert not self.use_rel_pos
+        attn = attn.softmax(dim=-1)
+        x = (attn @ v).view(B, self.num_heads, N, -1).permute(0, 2, 1, 3).reshape(B, N, -1)
+        x = self.proj(x)
+        return x
+    
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        B, H, W, _ = x.shape
+        # qkv with shape (3, B, nHead, H * W, C)
+        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        # q, k, v with shape (B * nHead, H * W, C)
+        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
+
+        attn = (q * self.scale) @ k.transpose(-2, -1)
+
+        if self.use_rel_pos:
+            attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
+
+        attn = attn.softmax(dim=-1)
+        x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
+        x = self.proj(x)
+
+        return x
+
+
+def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
+    """
+    Partition into non-overlapping windows with padding if needed.
+    Args:
+        x (tensor): input tokens with [B, H, W, C].
+        window_size (int): window size.
+
+    Returns:
+        windows: windows after partition with [B * num_windows, window_size, window_size, C].
+        (Hp, Wp): padded height and width before partition
+    """
+    B, H, W, C = x.shape
+
+    pad_h = (window_size - H % window_size) % window_size
+    pad_w = (window_size - W % window_size) % window_size
+    if pad_h > 0 or pad_w > 0:
+        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+    Hp, Wp = H + pad_h, W + pad_w
+
+    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    return windows, (Hp, Wp)
+
+
+def window_unpartition(
+    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
+) -> torch.Tensor:
+    """
+    Window unpartition into original sequences and removing padding.
+    Args:
+        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
+        window_size (int): window size.
+        pad_hw (Tuple): padded height and width (Hp, Wp).
+        hw (Tuple): original height and width (H, W) before padding.
+
+    Returns:
+        x: unpartitioned sequences with [B, H, W, C].
+    """
+    Hp, Wp = pad_hw
+    H, W = hw
+    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
+    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+
+    if Hp > H or Wp > W:
+        x = x[:, :H, :W, :].contiguous()
+    return x
+
+
+def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
+    """
+    Get relative positional embeddings according to the relative positions of
+        query and key sizes.
+    Args:
+        q_size (int): size of query q.
+        k_size (int): size of key k.
+        rel_pos (Tensor): relative position embeddings (L, C).
+
+    Returns:
+        Extracted positional embeddings according to relative positions.
+    """
+    max_rel_dist = int(2 * max(q_size, k_size) - 1)
+    # Interpolate rel pos if needed.
+    if rel_pos.shape[0] != max_rel_dist:
+        # Interpolate rel pos.
+        rel_pos_resized = F.interpolate(
+            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
+            size=max_rel_dist,
+            mode="linear",
+        )
+        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
+    else:
+        rel_pos_resized = rel_pos
+
+    # Scale the coords with short length if shapes for q and k are different.
+    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
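+    # e.g. for q_size = k_size = 2: q_coords = [[0], [1]], k_coords = [[0, 1]], so
+    # relative_coords = [[1, 0], [2, 1]], indexing the 2 * 2 - 1 = 3 rows of rel_pos.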
+
+    return rel_pos_resized[relative_coords.long()]
+
+
+def add_decomposed_rel_pos(
+    attn: torch.Tensor,
+    q: torch.Tensor,
+    rel_pos_h: torch.Tensor,
+    rel_pos_w: torch.Tensor,
+    q_size: Tuple[int, int],
+    k_size: Tuple[int, int],
+) -> torch.Tensor:
+    """
+    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
+    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py   # noqa B950
+    Args:
+        attn (Tensor): attention map.
+        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
+        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
+        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
+        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
+        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
+
+    Returns:
+        attn (Tensor): attention map with added relative positional embeddings.
+    """
+    q_h, q_w = q_size
+    k_h, k_w = k_size
+    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
+    Rw = get_rel_pos(q_w, k_w, rel_pos_w)
+
+    B, _, dim = q.shape
+    r_q = q.reshape(B, q_h, q_w, dim)
+    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
+    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
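+    # rel_h: (B, q_h, q_w, k_h), rel_w: (B, q_h, q_w, k_w); the broadcasts below add the
+    # height term along k_h and the width term along k_w of the decomposed attention map.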
+
+    attn = (
+        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
+    ).view(B, q_h * q_w, k_h * k_w)
+
+    return attn
+
+
+class PatchEmbed(nn.Module):
+    """
+    Image to Patch Embedding.
+    """
+
+    def __init__(
+        self,
+        kernel_size: Tuple[int, int] = (16, 16),
+        stride: Tuple[int, int] = (16, 16),
+        padding: Tuple[int, int] = (0, 0),
+        in_chans: int = 3,
+        embed_dim: int = 768,
+    ) -> None:
+        """
+        Args:
+            kernel_size (Tuple): kernel size of the projection layer.
+            stride (Tuple): stride of the projection layer.
+            padding (Tuple): padding size of the projection layer.
+            in_chans (int): Number of input image channels.
+            embed_dim (int): Patch embedding dimension.
+        """
+        super().__init__()
+
+        self.proj = nn.Conv2d(
+            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.proj(x)
+        # B C H W -> B H W C
+        x = x.permute(0, 2, 3, 1)
+        return x
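A minimal sketch of the bbox-conditioning step in ImageEncoderViT.forward above: each box in condition is rasterized into a binary mask channel and concatenated with the image before the condition patch-embed branch. Shapes are illustrative; boxes are assumed to be (left, up, right, down) in input-pixel coordinates, as the loop above implies.

import torch

img = torch.rand(2, 3, 1024, 1024)                        # B x 3 x H x W, already resized/padded
condition = torch.tensor([[[100., 150., 600., 900.]],
                          [[  0.,   0., 512., 512.]]])    # B x 1 x 4 boxes, (l, u, r, d)

bbox_mask = torch.zeros_like(img)[:, 0:1]                 # B x 1 x H x W
for i in range(condition.shape[0]):
    l, u, r, d = condition[i, 0, :]
    bbox_mask[i, :, int(u):int(d), int(l):int(r)] = 1.0   # ones inside the box

condition_input = torch.cat((img, bbox_mask), dim=1)      # B x 4 x H x W
print(condition_input.shape)                              # torch.Size([2, 4, 1024, 1024])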
diff --git a/modeling/semantic_enhanced_matting/modeling/mask_decoder.py b/modeling/semantic_enhanced_matting/modeling/mask_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..caddba072321a2234b9185e70de1abc6799c6f9c
--- /dev/null
+++ b/modeling/semantic_enhanced_matting/modeling/mask_decoder.py
@@ -0,0 +1,178 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from typing import List, Tuple, Type
+
+from .common import LayerNorm2d
+
+
+class MaskDecoder(nn.Module):
+    def __init__(
+        self,
+        *,
+        transformer_dim: int,
+        transformer: nn.Module,
+        num_multimask_outputs: int = 3,
+        activation: Type[nn.Module] = nn.GELU,
+        iou_head_depth: int = 3,
+        iou_head_hidden_dim: int = 256,
+    ) -> None:
+        """
+        Predicts masks given an image and prompt embeddings, using a
+        transformer architecture.
+
+        Arguments:
+          transformer_dim (int): the channel dimension of the transformer
+          transformer (nn.Module): the transformer used to predict masks
+          num_multimask_outputs (int): the number of masks to predict
+            when disambiguating masks
+          activation (nn.Module): the type of activation to use when
+            upscaling masks
+          iou_head_depth (int): the depth of the MLP used to predict
+            mask quality
+          iou_head_hidden_dim (int): the hidden dimension of the MLP
+            used to predict mask quality
+        """
+        super().__init__()
+        self.transformer_dim = transformer_dim
+        self.transformer = transformer
+
+        self.num_multimask_outputs = num_multimask_outputs
+
+        self.iou_token = nn.Embedding(1, transformer_dim)
+        self.num_mask_tokens = num_multimask_outputs + 1
+        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
+
+        self.output_upscaling = nn.Sequential(
+            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
+            LayerNorm2d(transformer_dim // 4),
+            activation(),
+            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
+            activation(),
+        )
+        self.output_hypernetworks_mlps = nn.ModuleList(
+            [
+                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
+                for i in range(self.num_mask_tokens)
+            ]
+        )
+
+        self.iou_prediction_head = MLP(
+            transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
+        )
+
+    def forward(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        multimask_output: bool,
+        hq_token_only: bool,
+        interm_embeddings: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Predict masks given image and prompt embeddings.
+
+        Arguments:
+          image_embeddings (torch.Tensor): the embeddings from the image encoder
+          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
+          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
+          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
+          multimask_output (bool): Whether to return multiple masks or a single
+            mask.
+
+        Returns:
+          torch.Tensor: batched predicted masks
+          torch.Tensor: batched predictions of mask quality
+        """
+        masks, iou_pred = self.predict_masks(
+            image_embeddings=image_embeddings,
+            image_pe=image_pe,
+            sparse_prompt_embeddings=sparse_prompt_embeddings,
+            dense_prompt_embeddings=dense_prompt_embeddings,
+        )
+
+        # Select the correct mask or masks for output
+        if multimask_output:
+            mask_slice = slice(1, None)
+        else:
+            mask_slice = slice(0, 1)
+        masks = masks[:, mask_slice, :, :]
+        iou_pred = iou_pred[:, mask_slice]
+
+        # Prepare output
+        return masks, iou_pred
+
+    def predict_masks(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Predicts masks. See 'forward' for more details."""
+        # Concatenate output tokens
+        output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
+        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
+        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+
+        # Expand per-image data in batch direction to be per-mask
+        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+        src = src + dense_prompt_embeddings
+        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+        b, c, h, w = src.shape
+
+        # Run the transformer
+        hs, src = self.transformer(src, pos_src, tokens)
+        iou_token_out = hs[:, 0, :]
+        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
+
+        # Upscale mask embeddings and predict masks using the mask tokens
+        src = src.transpose(1, 2).view(b, c, h, w) 
+        upscaled_embedding = self.output_upscaling(src)
+        hyper_in_list: List[torch.Tensor] = []
+        for i in range(self.num_mask_tokens):
+            hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
+        hyper_in = torch.stack(hyper_in_list, dim=1)
+        b, c, h, w = upscaled_embedding.shape
+        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
+
+        # Generate mask quality predictions
+        iou_pred = self.iou_prediction_head(iou_token_out)
+
+        return masks, iou_pred
+
+
+# Lightly adapted from
+# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
+class MLP(nn.Module):
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dim: int,
+        output_dim: int,
+        num_layers: int,
+        sigmoid_output: bool = False,
+    ) -> None:
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(
+            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+        )
+        self.sigmoid_output = sigmoid_output
+
+    def forward(self, x):
+        for i, layer in enumerate(self.layers):
+            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+        if self.sigmoid_output:
+            x = F.sigmoid(x)
+        return x
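A shape-only sketch of the hypernetwork step in MaskDecoder.predict_masks above, assuming SAM's usual transformer_dim = 256 and a 64x64 image embedding: each mask token is mapped by its MLP to a C/8 vector, which acts as a per-mask filter over the upscaled embedding.

import torch

b, c, h, w = 1, 256, 64, 64
num_mask_tokens = 3 + 1                                 # num_multimask_outputs + 1
upscaled = torch.rand(b, c // 8, 4 * h, 4 * w)          # after the two stride-2 deconvs
hyper_in = torch.rand(b, num_mask_tokens, c // 8)       # one hypernetwork output per mask token

masks = (hyper_in @ upscaled.view(b, c // 8, -1)).view(b, -1, 4 * h, 4 * w)
print(masks.shape)                                      # torch.Size([1, 4, 256, 256])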
diff --git a/modeling/semantic_enhanced_matting/modeling/mask_decoder_hq.py b/modeling/semantic_enhanced_matting/modeling/mask_decoder_hq.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4bc9d8096e3d140422d7f60a1b5860184b53c15
--- /dev/null
+++ b/modeling/semantic_enhanced_matting/modeling/mask_decoder_hq.py
@@ -0,0 +1,276 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Modified by HQ-SAM team
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from typing import List, Tuple, Type
+
+from .common import LayerNorm2d
+
+
+class MaskDecoderHQ(nn.Module):
+    def __init__(
+        self,
+        *,
+        transformer_dim: int,
+        transformer: nn.Module,
+        num_multimask_outputs: int = 3,
+        activation: Type[nn.Module] = nn.GELU,
+        iou_head_depth: int = 3,
+        iou_head_hidden_dim: int = 256,
+        vit_dim: int = 1024,
+        w_all_logits: bool = False,
+        wo_hq: bool = False,
+    ) -> None:
+        """
+        Predicts masks given an image and prompt embeddings, using a
+        transformer architecture.
+
+        Arguments:
+          transformer_dim (int): the channel dimension of the transformer
+          transformer (nn.Module): the transformer used to predict masks
+          num_multimask_outputs (int): the number of masks to predict
+            when disambiguating masks
+          activation (nn.Module): the type of activation to use when
+            upscaling masks
+          iou_head_depth (int): the depth of the MLP used to predict
+            mask quality
+          iou_head_hidden_dim (int): the hidden dimension of the MLP
+            used to predict mask quality
+        """
+        super().__init__()
+        self.transformer_dim = transformer_dim
+        self.transformer = transformer
+        self.vit_dim = vit_dim
+        
+        self.num_multimask_outputs = num_multimask_outputs
+
+        self.iou_token = nn.Embedding(1, transformer_dim)
+        self.num_mask_tokens = num_multimask_outputs + 1
+        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
+
+        self.output_upscaling = nn.Sequential(
+            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
+            LayerNorm2d(transformer_dim // 4),
+            activation(),
+            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
+            activation(),
+        )
+        self.output_hypernetworks_mlps = nn.ModuleList(
+            [
+                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
+                for i in range(self.num_mask_tokens)
+            ]
+        )
+
+        self.iou_prediction_head = MLP(
+            transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
+        )
+
+        self.wo_hq = wo_hq
+        if not self.wo_hq:
+            # HQ-SAM parameters
+            self.hf_token = nn.Embedding(1, transformer_dim)  # HQ-Output-Token
+            self.hf_mlp = MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)  # corresponding new MLP layer for the HQ-Output-Token
+            self.num_mask_tokens = self.num_mask_tokens + 1
+            
+            # three conv fusion layers for obtaining HQ-Feature
+            self.compress_vit_feat = nn.Sequential(
+                                            nn.ConvTranspose2d(vit_dim, transformer_dim, kernel_size=2, stride=2),
+                                            LayerNorm2d(transformer_dim),
+                                            nn.GELU(), 
+                                            nn.ConvTranspose2d(transformer_dim, transformer_dim // 8, kernel_size=2, stride=2))
+            
+            self.embedding_encoder = nn.Sequential(
+                                            nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
+                                            LayerNorm2d(transformer_dim // 4),
+                                            nn.GELU(),
+                                            nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
+                                        )
+            self.embedding_maskfeature = nn.Sequential(
+                                            nn.Conv2d(transformer_dim // 8, transformer_dim // 4, 3, 1, 1), 
+                                            LayerNorm2d(transformer_dim // 4),
+                                            nn.GELU(),
+                                            nn.Conv2d(transformer_dim // 4, transformer_dim // 8, 3, 1, 1))
+
+            self.w_all_logits = w_all_logits
+
+    def forward(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        multimask_output: bool,
+        hq_token_only: bool,
+        interm_embeddings: torch.Tensor,
+        return_hq_features_type: str,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Predict masks given image and prompt embeddings.
+
+        Arguments:
+          image_embeddings (torch.Tensor): the embeddings from the ViT image encoder
+          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
+          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
+          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
+          multimask_output (bool): Whether to return multiple masks or a single
+            mask.
+
+        Returns:
+          torch.Tensor: batched low-resolution mask logits (HQ, SAM, or their sum, depending on hq_token_only)
+          torch.Tensor: batched SAM mask logits selected by multimask_output
+          torch.Tensor: the HQ features or logits selected by return_hq_features_type
+        """
+        if not self.wo_hq:
+            vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT
+            hq_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(vit_features)
+            return_hq_features = None
+            if return_hq_features_type == 'Early':
+                return_hq_features = hq_features
+        else:
+            hq_features = None
+
+        masks, iou_pred, mid_fin_hq_features = self.predict_masks(
+            image_embeddings=image_embeddings,
+            image_pe=image_pe,
+            sparse_prompt_embeddings=sparse_prompt_embeddings,
+            dense_prompt_embeddings=dense_prompt_embeddings,
+            hq_features=hq_features,
+            return_hq_features_type=return_hq_features_type
+        )
+        if return_hq_features_type in {'Middle', 'Final'}:
+            return_hq_features = mid_fin_hq_features
+
+        # Select the correct mask or masks for output
+        if multimask_output:
+            # mask with highest score
+            if not self.wo_hq:
+                mask_slice = slice(1, self.num_mask_tokens-1)
+            else:
+                mask_slice = slice(1, self.num_mask_tokens)
+            iou_pred = iou_pred[:, mask_slice]
+            iou_pred, max_iou_idx = torch.max(iou_pred,dim=1)
+            iou_pred = iou_pred.unsqueeze(1)
+            masks_multi = masks[:, mask_slice, :, :]
+            masks_sam = masks_multi[torch.arange(masks_multi.size(0)),max_iou_idx].unsqueeze(1)
+        else:
+            # single mask output (default)
+            mask_slice = slice(0, 1)
+            iou_pred = iou_pred[:,mask_slice]
+            masks_sam = masks[:,mask_slice]
+
+        if not self.wo_hq:
+            masks_hq = masks[:,slice(self.num_mask_tokens-1, self.num_mask_tokens)]
+
+            if hq_token_only:
+                low_res_masks = masks_hq
+            else:
+                low_res_masks = masks_sam + masks_hq
+            
+            if return_hq_features_type == 'Final':
+                return_hq_features = low_res_masks
+
+            if self.w_all_logits:
+                return_hq_features = masks
+        else:
+            low_res_masks = masks_sam
+            return_hq_features = masks_sam
+        # Prepare output
+        return low_res_masks, masks_sam, return_hq_features
+
+    def predict_masks(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        hq_features: torch.Tensor,
+        return_hq_features_type: str
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Predicts masks. See 'forward' for more details."""
+        # Concatenate output tokens
+        if not self.wo_hq:
+            output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hf_token.weight], dim=0)
+        else:
+            output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
+        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
+        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+
+        # Expand per-image data in batch direction to be per-mask
+        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+        src = src + dense_prompt_embeddings
+        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+        b, c, h, w = src.shape
+
+        # Run the transformer
+        hs, src = self.transformer(src, pos_src, tokens)
+        iou_token_out = hs[:, 0, :]
+        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
+
+        # Upscale mask embeddings and predict masks using the mask tokens
+        src = src.transpose(1, 2).view(b, c, h, w)
+
+        upscaled_embedding_sam = self.output_upscaling(src)
+        if not self.wo_hq:
+            upscaled_embedding_hq = self.embedding_maskfeature(upscaled_embedding_sam) + hq_features.repeat(b,1,1,1)
+            return_hq_features = None
+            if return_hq_features_type == 'Middle':
+                return_hq_features = upscaled_embedding_hq
+
+        hyper_in_list: List[torch.Tensor] = []
+        for i in range(self.num_mask_tokens):
+            if i < self.num_mask_tokens - 1 or self.wo_hq:
+                hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
+            else:
+                hyper_in_list.append(self.hf_mlp(mask_tokens_out[:, i, :]))
+
+        hyper_in = torch.stack(hyper_in_list, dim=1)
+        b, c, h, w = upscaled_embedding_sam.shape
+        if not self.wo_hq:
+            masks_sam = (hyper_in[:,:self.num_mask_tokens-1] @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w)
+            masks_sam_hq = (hyper_in[:,self.num_mask_tokens-1:] @ upscaled_embedding_hq.view(b, c, h * w)).view(b, -1, h, w)
+            if return_hq_features_type == 'Final':
+                return_hq_features = masks_sam_hq
+            masks = torch.cat([masks_sam, masks_sam_hq],dim=1)
+        else:
+            masks_sam = (hyper_in @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w)
+            return_hq_features = masks_sam
+            masks = masks_sam
+
+        # Generate mask quality predictions
+        iou_pred = self.iou_prediction_head(iou_token_out)
+
+        return masks, iou_pred, return_hq_features
+
+
+# Lightly adapted from
+# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
+class MLP(nn.Module):
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dim: int,
+        output_dim: int,
+        num_layers: int,
+        sigmoid_output: bool = False,
+    ) -> None:
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(
+            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+        )
+        self.sigmoid_output = sigmoid_output
+
+    def forward(self, x):
+        for i, layer in enumerate(self.layers):
+            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+        if self.sigmoid_output:
+            x = F.sigmoid(x)
+        return x
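A small sketch of the output-channel layout that MaskDecoderHQ above produces and slices, assuming num_multimask_outputs = 3 and wo_hq = False: channel 0 is the single-mask default, channels 1-3 are the multimask outputs, and channel 4 is the HQ mask.

num_mask_tokens = (3 + 1) + 1                            # SAM mask tokens + the extra HQ token
multimask_slice = slice(1, num_mask_tokens - 1)          # the three disambiguation masks
hq_slice = slice(num_mask_tokens - 1, num_mask_tokens)   # the HQ mask

print(list(range(*multimask_slice.indices(num_mask_tokens))))  # [1, 2, 3]
print(list(range(*hq_slice.indices(num_mask_tokens))))         # [4]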
diff --git a/modeling/semantic_enhanced_matting/modeling/mask_decoder_hq_matting.py b/modeling/semantic_enhanced_matting/modeling/mask_decoder_hq_matting.py
new file mode 100644
index 0000000000000000000000000000000000000000..59ff8f44cb1419ebcbcd957a5b8119afba74e1d7
--- /dev/null
+++ b/modeling/semantic_enhanced_matting/modeling/mask_decoder_hq_matting.py
@@ -0,0 +1,200 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# Modified by HQ-SAM team
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from typing import List, Tuple, Type
+
+from .common import LayerNorm2d
+from .mask_decoder_hq import MaskDecoderHQ, MLP
+
+
+class MaskDecoderHQMatting(MaskDecoderHQ):
+    def __init__(
+        self,
+        hq_token_only=False,
+        matting_token_num=1,
+        mask_matting_res_add=True,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        self.hq_token_only = hq_token_only
+        self.matting_token_num = matting_token_num
+        self.mask_matting_res_add = mask_matting_res_add
+        if not self.mask_matting_res_add:
+            assert self.wo_hq
+
+        # Matting token parameters
+        self.matting_hf_token = nn.Embedding(self.matting_token_num, self.transformer_dim)  # Matting-Output-Token
+        self.matting_hf_mlp = MLP(self.transformer_dim, self.transformer_dim, self.transformer_dim // 8, 3)  # corresponding new MLP layer for the Matting-Output-Token
+        self.num_mask_tokens = self.num_mask_tokens + self.matting_token_num
+        
+        # three conv fusion layers for obtaining Matting-Feature
+        self.matting_compress_vit_feat = nn.Sequential(
+                                        nn.ConvTranspose2d(self.vit_dim, self.transformer_dim, kernel_size=2, stride=2),
+                                        LayerNorm2d(self.transformer_dim),
+                                        nn.GELU(), 
+                                        nn.ConvTranspose2d(self.transformer_dim, self.transformer_dim // 8, kernel_size=2, stride=2))
+        
+        self.matting_embedding_encoder = nn.Sequential(
+                                        nn.ConvTranspose2d(self.transformer_dim, self.transformer_dim // 4, kernel_size=2, stride=2),
+                                        LayerNorm2d(self.transformer_dim // 4),
+                                        nn.GELU(),
+                                        nn.ConvTranspose2d(self.transformer_dim // 4, self.transformer_dim // 8, kernel_size=2, stride=2),
+                                    )
+        self.matting_embedding_maskfeature = nn.Sequential(
+                                        nn.Conv2d(self.transformer_dim // 8, self.transformer_dim // 4, 3, 1, 1), 
+                                        LayerNorm2d(self.transformer_dim // 4),
+                                        nn.GELU(),
+                                        nn.Conv2d(self.transformer_dim // 4, self.transformer_dim // 8, 3, 1, 1))
+
+
+    def forward(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        multimask_output: bool,
+        interm_embeddings: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Predict masks given image and prompt embeddings.
+
+        Arguments:
+          image_embeddings (torch.Tensor): the embeddings from the ViT image encoder
+          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
+          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
+          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
+          multimask_output (bool): Whether to return multiple masks or a single
+            mask.
+
+        Returns:
+          dict: a dictionary with keys 'masks_sam', 'masks_hq' and 'masks_matting',
+            each mapping to batched low-resolution mask logits for the corresponding tokens.
+        """
+        vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT
+        if not self.wo_hq:
+            hq_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(vit_features)
+        else:
+            hq_features = None
+        matting_hq_features = self.matting_embedding_encoder(image_embeddings) + self.matting_compress_vit_feat(vit_features)
+
+        masks, iou_pred = self.predict_masks(
+            image_embeddings=image_embeddings,
+            image_pe=image_pe,
+            sparse_prompt_embeddings=sparse_prompt_embeddings,
+            dense_prompt_embeddings=dense_prompt_embeddings,
+            hq_features=hq_features,
+            matting_hq_features=matting_hq_features
+        )
+
+        # Select the correct mask or masks for output
+        if multimask_output:
+            # mask with highest score
+            if not self.wo_hq:
+                mask_slice = slice(1,self.num_mask_tokens - (self.matting_token_num + 1))  # matting_token_num + hq_token_num
+            else:
+                mask_slice = slice(1,self.num_mask_tokens - self.matting_token_num)  # matting_token_num
+            iou_pred = iou_pred[:, mask_slice]
+            iou_pred, max_iou_idx = torch.max(iou_pred,dim=1)
+            iou_pred = iou_pred.unsqueeze(1)
+            masks_multi = masks[:, mask_slice, :, :]
+            masks_sam = masks_multi[torch.arange(masks_multi.size(0)),max_iou_idx].unsqueeze(1)
+        else:
+            # single mask output (default)
+            mask_slice = slice(0, 1)
+            iou_pred = iou_pred[:,mask_slice]
+            masks_sam = masks[:,mask_slice]
+
+        if not self.wo_hq:
+            masks_hq = masks[:,slice(self.num_mask_tokens - (self.matting_token_num + 1), self.num_mask_tokens - self.matting_token_num)]
+        masks_matting = masks[:,slice(self.num_mask_tokens - self.matting_token_num, self.num_mask_tokens)]
+
+        if not self.wo_hq:
+            if self.hq_token_only:
+                # masks_hq += masks_sam
+                masks_matting += masks_hq
+            else:
+                masks_hq += masks_sam
+                masks_matting += masks_hq 
+        else:
+            masks_hq = masks_sam
+            if self.mask_matting_res_add:
+                masks_matting = masks_sam + masks_matting
+            else:
+                masks_matting = masks_matting
+        # Prepare output
+        return {'masks_sam': masks_sam, 'masks_hq': masks_hq, 'masks_matting': masks_matting}
+
+    def predict_masks(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        hq_features: torch.Tensor,
+        matting_hq_features: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Predicts masks. See 'forward' for more details."""
+        # Concatenate output tokens
+        if not self.wo_hq:
+            output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.hf_token.weight, self.matting_hf_token.weight], dim=0)
+        else:
+            output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight, self.matting_hf_token.weight], dim=0)
+        output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
+        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+
+        # Expand per-image data in batch direction to be per-mask
+        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+        src = src + dense_prompt_embeddings
+        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+        b, c, h, w = src.shape
+
+        # Run the transformer
+        hs, src = self.transformer(src, pos_src, tokens)
+        iou_token_out = hs[:, 0, :]
+        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
+
+        # Upscale mask embeddings and predict masks using the mask tokens
+        src = src.transpose(1, 2).view(b, c, h, w)
+
+        upscaled_embedding_sam = self.output_upscaling(src)
+        if not self.wo_hq:
+            upscaled_embedding_hq = self.embedding_maskfeature(upscaled_embedding_sam) + hq_features.repeat(b,1,1,1)
+        upscaled_embedding_matting_hq = self.matting_embedding_maskfeature(upscaled_embedding_sam) + matting_hq_features.repeat(b,1,1,1)
+
+        hyper_in_list: List[torch.Tensor] = []
+        for i in range(self.num_mask_tokens):
+            if i < self.num_mask_tokens - (self.matting_token_num + 1) or (self.wo_hq and i < self.num_mask_tokens - self.matting_token_num):
+                hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
+            elif i == self.num_mask_tokens - (self.matting_token_num + 1):
+                hyper_in_list.append(self.hf_mlp(mask_tokens_out[:, i, :]))
+            else:
+                hyper_in_list.append(self.matting_hf_mlp(mask_tokens_out[:, i, :]))
+
+        hyper_in = torch.stack(hyper_in_list, dim=1)
+        b, c, h, w = upscaled_embedding_sam.shape
+
+        if not self.wo_hq:
+            masks_sam = (hyper_in[:,:self.num_mask_tokens - (self.matting_token_num + 1)] @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w)
+            masks_sam_hq = (hyper_in[:,self.num_mask_tokens - (self.matting_token_num + 1) : self.num_mask_tokens - self.matting_token_num] @ upscaled_embedding_hq.view(b, c, h * w)).view(b, -1, h, w)
+        else:
+            masks_sam = (hyper_in[:,:self.num_mask_tokens - self.matting_token_num] @ upscaled_embedding_sam.view(b, c, h * w)).view(b, -1, h, w)
+        masks_sam_matting_hq = (hyper_in[:, self.num_mask_tokens - self.matting_token_num:] @ upscaled_embedding_matting_hq.view(b, c, h * w)).view(b, -1, h, w)
+
+        if not self.wo_hq:
+            masks = torch.cat([masks_sam, masks_sam_hq, masks_sam_matting_hq],dim=1)
+        else:
+            masks = torch.cat([masks_sam, masks_sam_matting_hq],dim=1)
+        # Generate mask quality predictions
+        iou_pred = self.iou_prediction_head(iou_token_out)
+
+        return masks, iou_pred
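The same layout sketch extended for MaskDecoderHQMatting above, assuming num_multimask_outputs = 3, wo_hq = False and the default matting_token_num = 1, so six mask channels in total.

matting_token_num = 1
num_mask_tokens = (3 + 1) + 1 + matting_token_num        # SAM tokens + HQ token + matting tokens
sam_multi = slice(1, num_mask_tokens - (matting_token_num + 1))                              # channels 1-3
hq = slice(num_mask_tokens - (matting_token_num + 1), num_mask_tokens - matting_token_num)   # channel 4
matting = slice(num_mask_tokens - matting_token_num, num_mask_tokens)                        # channel 5

print(sam_multi, hq, matting)  # slice(1, 4, None) slice(4, 5, None) slice(5, 6, None)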
diff --git a/modeling/semantic_enhanced_matting/modeling/prompt_encoder.py b/modeling/semantic_enhanced_matting/modeling/prompt_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3143f4f8e02ddd7ca8587b40ff5d47c3a6b7ef3
--- /dev/null
+++ b/modeling/semantic_enhanced_matting/modeling/prompt_encoder.py
@@ -0,0 +1,214 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+from torch import nn
+
+from typing import Any, Optional, Tuple, Type
+
+from .common import LayerNorm2d
+
+
+class PromptEncoder(nn.Module):
+    def __init__(
+        self,
+        embed_dim: int,
+        image_embedding_size: Tuple[int, int],
+        input_image_size: Tuple[int, int],
+        mask_in_chans: int,
+        activation: Type[nn.Module] = nn.GELU,
+    ) -> None:
+        """
+        Encodes prompts for input to SAM's mask decoder.
+
+        Arguments:
+          embed_dim (int): The prompts' embedding dimension
+          image_embedding_size (tuple(int, int)): The spatial size of the
+            image embedding, as (H, W).
+          input_image_size (int): The padded size of the image as input
+            to the image encoder, as (H, W).
+          mask_in_chans (int): The number of hidden channels used for
+            encoding input masks.
+          activation (nn.Module): The activation to use when encoding
+            input masks.
+        """
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.input_image_size = input_image_size
+        self.image_embedding_size = image_embedding_size
+        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
+
+        self.num_point_embeddings: int = 4  # pos/neg point + 2 box corners
+        point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
+        self.point_embeddings = nn.ModuleList(point_embeddings)
+        self.not_a_point_embed = nn.Embedding(1, embed_dim)
+
+        self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
+        self.mask_downscaling = nn.Sequential(
+            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
+            LayerNorm2d(mask_in_chans // 4),
+            activation(),
+            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
+            LayerNorm2d(mask_in_chans),
+            activation(),
+            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
+        )
+        self.no_mask_embed = nn.Embedding(1, embed_dim)
+
+    def get_dense_pe(self) -> torch.Tensor:
+        """
+        Returns the positional encoding used to encode point prompts,
+        applied to a dense set of points the shape of the image encoding.
+
+        Returns:
+          torch.Tensor: Positional encoding with shape
+            1x(embed_dim)x(embedding_h)x(embedding_w)
+        """
+        return self.pe_layer(self.image_embedding_size).unsqueeze(0)
+
+    def _embed_points(
+        self,
+        points: torch.Tensor,
+        labels: torch.Tensor,
+        pad: bool,
+    ) -> torch.Tensor:
+        """Embeds point prompts."""
+        points = points + 0.5  # Shift to center of pixel
+        if pad:
+            padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
+            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
+            points = torch.cat([points, padding_point], dim=1)
+            labels = torch.cat([labels, padding_label], dim=1)
+        point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
+        point_embedding[labels == -1] = 0.0
+        point_embedding[labels == -1] += self.not_a_point_embed.weight
+        point_embedding[labels == 0] += self.point_embeddings[0].weight
+        point_embedding[labels == 1] += self.point_embeddings[1].weight
+        return point_embedding
+
+    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
+        """Embeds box prompts."""
+        boxes = boxes + 0.5  # Shift to center of pixel
+        coords = boxes.reshape(-1, 2, 2)
+        corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
+        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
+        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
+        return corner_embedding
+
+    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
+        """Embeds mask inputs."""
+        mask_embedding = self.mask_downscaling(masks)
+        return mask_embedding
+
+    def _get_batch_size(
+        self,
+        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+        boxes: Optional[torch.Tensor],
+        masks: Optional[torch.Tensor],
+    ) -> int:
+        """
+        Gets the batch size of the output given the batch size of the input prompts.
+        """
+        if points is not None:
+            return points[0].shape[0]
+        elif boxes is not None:
+            return boxes.shape[0]
+        elif masks is not None:
+            return masks.shape[0]
+        else:
+            return 1
+
+    def _get_device(self) -> torch.device:
+        return self.point_embeddings[0].weight.device
+
+    def forward(
+        self,
+        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+        boxes: Optional[torch.Tensor],
+        masks: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Embeds different types of prompts, returning both sparse and dense
+        embeddings.
+
+        Arguments:
+          points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
+            and labels to embed.
+          boxes (torch.Tensor or none): boxes to embed
+          masks (torch.Tensor or none): masks to embed
+
+        Returns:
+          torch.Tensor: sparse embeddings for the points and boxes, with shape
+            BxNx(embed_dim), where N is determined by the number of input points
+            and boxes.
+          torch.Tensor: dense embeddings for the masks, in the shape
+            Bx(embed_dim)x(embed_H)x(embed_W)
+        """
+        bs = self._get_batch_size(points, boxes, masks)
+        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
+        if points is not None:
+            coords, labels = points
+            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
+            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
+        if boxes is not None:
+            box_embeddings = self._embed_boxes(boxes)
+            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
+
+        if masks is not None:
+            dense_embeddings = self._embed_masks(masks)
+        else:
+            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+            )
+
+        return sparse_embeddings, dense_embeddings
+
+
+class PositionEmbeddingRandom(nn.Module):
+    """
+    Positional encoding using random spatial frequencies.
+    """
+
+    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
+        super().__init__()
+        if scale is None or scale <= 0.0:
+            scale = 1.0
+        self.register_buffer(
+            "positional_encoding_gaussian_matrix",
+            scale * torch.randn((2, num_pos_feats)),
+        )
+
+    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
+        """Positionally encode points that are normalized to [0,1]."""
+        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+        coords = 2 * coords - 1
+        coords = coords @ self.positional_encoding_gaussian_matrix
+        coords = 2 * np.pi * coords
+        # outputs d_1 x ... x d_n x C shape
+        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
+
+    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
+        """Generate positional encoding for a grid of the specified size."""
+        h, w = size
+        device: Any = self.positional_encoding_gaussian_matrix.device
+        grid = torch.ones((h, w), device=device, dtype=torch.float32)
+        y_embed = grid.cumsum(dim=0) - 0.5
+        x_embed = grid.cumsum(dim=1) - 0.5
+        y_embed = y_embed / h
+        x_embed = x_embed / w
+
+        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
+        return pe.permute(2, 0, 1)  # C x H x W
+
+    def forward_with_coords(
+        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
+    ) -> torch.Tensor:
+        """Positionally encode points that are not normalized to [0,1]."""
+        coords = coords_input.clone()
+        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
+        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
+        return self._pe_encoding(coords.to(torch.float))  # B x N x C
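A minimal usage sketch of the PromptEncoder above with a single box prompt; the import path mirrors this file's location and is an assumption, and the sizes follow the usual SAM configuration (1024 input, 64x64 embedding, embed_dim 256).

import torch
from modeling.semantic_enhanced_matting.modeling.prompt_encoder import PromptEncoder  # assumed path

pe = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(1024, 1024),
    mask_in_chans=16,
)
boxes = torch.tensor([[100.0, 150.0, 600.0, 900.0]])     # 1 x 4 box in input-pixel coordinates
sparse, dense = pe(points=None, boxes=boxes, masks=None)
print(sparse.shape, dense.shape)                         # (1, 2, 256) and (1, 256, 64, 64)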
diff --git a/modeling/semantic_enhanced_matting/modeling/sam.py b/modeling/semantic_enhanced_matting/modeling/sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfca5c7b72253ab390036d565aeed7a19405a40d
--- /dev/null
+++ b/modeling/semantic_enhanced_matting/modeling/sam.py
@@ -0,0 +1,177 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from typing import Any, Dict, List, Tuple
+
+from .image_encoder import ImageEncoderViT
+from .mask_decoder import MaskDecoder
+from .prompt_encoder import PromptEncoder
+
+
+class Sam(nn.Module):
+    mask_threshold: float = 0.0
+    image_format: str = "RGB"
+
+    def __init__(
+        self,
+        image_encoder: ImageEncoderViT,
+        prompt_encoder: PromptEncoder,
+        mask_decoder: MaskDecoder,
+        pixel_mean: List[float] = [123.675, 116.28, 103.53],
+        pixel_std: List[float] = [58.395, 57.12, 57.375],
+    ) -> None:
+        """
+        SAM predicts object masks from an image and input prompts.
+
+        Arguments:
+          image_encoder (ImageEncoderViT): The backbone used to encode the
+            image into image embeddings that allow for efficient mask prediction.
+          prompt_encoder (PromptEncoder): Encodes various types of input prompts.
+          mask_decoder (MaskDecoder): Predicts masks from the image embeddings
+            and encoded prompts.
+          pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
+          pixel_std (list(float)): Std values for normalizing pixels in the input image.
+        """
+        super().__init__()
+        self.image_encoder = image_encoder
+        self.prompt_encoder = prompt_encoder
+        self.mask_decoder = mask_decoder
+        self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
+        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
+
+    @property
+    def device(self) -> Any:
+        return self.pixel_mean.device
+
+    def forward(
+        self,
+        batched_input: List[Dict[str, Any]],
+        multimask_output: bool,
+        hq_token_only: bool = False,
+    ) -> List[Dict[str, torch.Tensor]]:
+        """
+        Predicts masks end-to-end from provided images and prompts.
+        If prompts are not known in advance, using SamPredictor is
+        recommended over calling the model directly.
+
+        Arguments:
+          batched_input (list(dict)): A list over input images, each a
+            dictionary with the following keys. A prompt key can be
+            excluded if it is not present.
+              'image': The image as a torch tensor in 3xHxW format,
+                already transformed for input to the model.
+              'original_size': (tuple(int, int)) The original size of
+                the image before transformation, as (H, W).
+              'point_coords': (torch.Tensor) Batched point prompts for
+                this image, with shape BxNx2. Already transformed to the
+                input frame of the model.
+              'point_labels': (torch.Tensor) Batched labels for point prompts,
+                with shape BxN.
+              'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
+                Already transformed to the input frame of the model.
+              'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
+                in the form Bx1xHxW.
+          multimask_output (bool): Whether the model should predict multiple
+            disambiguating masks, or return a single mask.
+
+        Returns:
+          (list(dict)): A list over input images, where each element is
+            as dictionary with the following keys.
+              'masks': (torch.Tensor) Batched binary mask predictions,
+                with shape BxCxHxW, where B is the number of input prompts,
+                C is determined by multimask_output, and (H, W) is the
+                original size of the image.
+              'iou_predictions': (torch.Tensor) The model's predictions
+                of mask quality, in shape BxC.
+              'low_res_logits': (torch.Tensor) Low resolution logits with
+                shape BxCxHxW, where H=W=256. Can be passed as mask input
+                to subsequent iterations of prediction.
+        """
+        input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
+        image_embeddings, interm_embeddings = self.image_encoder(input_images)
+        interm_embeddings = interm_embeddings[0] # early layer
+
+        outputs = []
+        for image_record, curr_embedding, curr_interm in zip(batched_input, image_embeddings, interm_embeddings):
+            if "point_coords" in image_record:
+                points = (image_record["point_coords"], image_record["point_labels"])
+            else:
+                points = None
+            sparse_embeddings, dense_embeddings = self.prompt_encoder(
+                points=points,
+                boxes=image_record.get("boxes", None),
+                masks=image_record.get("mask_inputs", None),
+            )
+            low_res_masks, iou_predictions = self.mask_decoder(
+                image_embeddings=curr_embedding.unsqueeze(0),
+                image_pe=self.prompt_encoder.get_dense_pe(),
+                sparse_prompt_embeddings=sparse_embeddings,
+                dense_prompt_embeddings=dense_embeddings,
+                multimask_output=multimask_output,
+                hq_token_only=hq_token_only,
+                interm_embeddings=curr_interm.unsqueeze(0).unsqueeze(0),
+            )
+            masks = self.postprocess_masks(
+                low_res_masks,
+                input_size=image_record["image"].shape[-2:],
+                original_size=image_record["original_size"],
+            )
+            masks = masks > self.mask_threshold
+            outputs.append(
+                {
+                    "masks": masks,
+                    "iou_predictions": iou_predictions,
+                    "low_res_logits": low_res_masks,
+                }
+            )
+        return outputs, 
+
+    def postprocess_masks(
+        self,
+        masks: torch.Tensor,
+        input_size: Tuple[int, ...],
+        original_size: Tuple[int, ...],
+    ) -> torch.Tensor:
+        """
+        Remove padding and upscale masks to the original image size.
+
+        Arguments:
+          masks (torch.Tensor): Batched masks from the mask_decoder,
+            in BxCxHxW format.
+          input_size (tuple(int, int)): The size of the image input to the
+            model, in (H, W) format. Used to remove padding.
+          original_size (tuple(int, int)): The original size of the image
+            before resizing for input to the model, in (H, W) format.
+
+        Returns:
+          (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
+            is given by original_size.
+        """
+        masks = F.interpolate(
+            masks,
+            (self.image_encoder.img_size, self.image_encoder.img_size),
+            mode="bilinear",
+            align_corners=False,
+        )
+        masks = masks[..., : input_size[0], : input_size[1]]
+        masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
+        return masks
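+
+    # Worked example (illustrative): with a 1024-input encoder and an original
+    # image of (H, W) = (750, 1000), the low-res logits are upsampled to
+    # 1024x1024, cropped to input_size = (768, 1024) to drop the bottom padding,
+    # and finally resized back to original_size = (750, 1000).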
+
+    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
+        """Normalize pixel values and pad to a square input."""
+        # Normalize colors
+        x = (x - self.pixel_mean) / self.pixel_std
+
+        # Pad
+        h, w = x.shape[-2:]
+        padh = self.image_encoder.img_size - h
+        padw = self.image_encoder.img_size - w
+        x = F.pad(x, (0, padw, 0, padh))
+        return x
diff --git a/modeling/semantic_enhanced_matting/modeling/tiny_vit_sam.py b/modeling/semantic_enhanced_matting/modeling/tiny_vit_sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..65f04aa374599f6bb70fe69c81660df9d4e786e1
--- /dev/null
+++ b/modeling/semantic_enhanced_matting/modeling/tiny_vit_sam.py
@@ -0,0 +1,724 @@
+# --------------------------------------------------------
+# TinyViT Model Architecture
+# Copyright (c) 2022 Microsoft
+# Adapted from LeViT and Swin Transformer
+#   LeViT: (https://github.com/facebookresearch/levit)
+#   Swin: (https://github.com/microsoft/swin-transformer)
+# Build the TinyViT Model
+# --------------------------------------------------------
+
+import itertools
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import DropPath as TimmDropPath, to_2tuple, trunc_normal_
+from timm.models.registry import register_model
+from typing import Tuple
+
+
+class Conv2d_BN(torch.nn.Sequential):
+    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
+                 groups=1, bn_weight_init=1):
+        super().__init__()
+        self.add_module('c', torch.nn.Conv2d(
+            a, b, ks, stride, pad, dilation, groups, bias=False))
+        bn = torch.nn.BatchNorm2d(b)
+        torch.nn.init.constant_(bn.weight, bn_weight_init)
+        torch.nn.init.constant_(bn.bias, 0)
+        self.add_module('bn', bn)
+
+    @torch.no_grad()
+    def fuse(self):
+        c, bn = self._modules.values()
+        w = bn.weight / (bn.running_var + bn.eps)**0.5
+        w = c.weight * w[:, None, None, None]
+        b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps)**0.5
+        m = torch.nn.Conv2d(
+            w.size(1) * self.c.groups, w.size(0), w.shape[2:],
+            stride=self.c.stride, padding=self.c.padding,
+            dilation=self.c.dilation, groups=self.c.groups)
+        m.weight.data.copy_(w)
+        m.bias.data.copy_(b)
+        return m
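+
+    # Note (illustrative): fuse() folds the BatchNorm into the convolution so a
+    # single Conv2d reproduces conv+BN at inference time:
+    #   w_fused = w_conv * gamma / sqrt(var + eps)
+    #   b_fused = beta - mean * gamma / sqrt(var + eps)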
+
+
+class DropPath(TimmDropPath):
+    def __init__(self, drop_prob=None):
+        super().__init__(drop_prob=drop_prob)
+        self.drop_prob = drop_prob
+
+    def __repr__(self):
+        msg = super().__repr__()
+        msg += f'(drop_prob={self.drop_prob})'
+        return msg
+
+
+class PatchEmbed(nn.Module):
+    def __init__(self, in_chans, embed_dim, resolution, activation):
+        super().__init__()
+        img_size: Tuple[int, int] = to_2tuple(resolution)
+        self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
+        self.num_patches = self.patches_resolution[0] * \
+            self.patches_resolution[1]
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+        n = embed_dim
+        self.seq = nn.Sequential(
+            Conv2d_BN(in_chans, n // 2, 3, 2, 1),
+            activation(),
+            Conv2d_BN(n // 2, n, 3, 2, 1),
+        )
+
+    def forward(self, x):
+        return self.seq(x)
+
+
+class MBConv(nn.Module):
+    def __init__(self, in_chans, out_chans, expand_ratio,
+                 activation, drop_path):
+        super().__init__()
+        self.in_chans = in_chans
+        self.hidden_chans = int(in_chans * expand_ratio)
+        self.out_chans = out_chans
+
+        self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1)
+        self.act1 = activation()
+
+        self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans,
+                               ks=3, stride=1, pad=1, groups=self.hidden_chans)
+        self.act2 = activation()
+
+        self.conv3 = Conv2d_BN(
+            self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0)
+        self.act3 = activation()
+
+        self.drop_path = DropPath(
+            drop_path) if drop_path > 0. else nn.Identity()
+
+    def forward(self, x):
+        shortcut = x
+
+        x = self.conv1(x)
+        x = self.act1(x)
+
+        x = self.conv2(x)
+        x = self.act2(x)
+
+        x = self.conv3(x)
+
+        x = self.drop_path(x)
+
+        x += shortcut
+        x = self.act3(x)
+
+        return x
+
+
+class PatchMerging(nn.Module):
+    def __init__(self, input_resolution, dim, out_dim, activation):
+        super().__init__()
+
+        self.input_resolution = input_resolution
+        self.dim = dim
+        self.out_dim = out_dim
+        self.act = activation()
+        self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
+        stride_c = 2
+        if out_dim in (320, 448, 576):
+            stride_c = 1
+        self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
+        self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
+
+    def forward(self, x):
+        if x.ndim == 3:
+            H, W = self.input_resolution
+            B = len(x)
+            # (B, C, H, W)
+            x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
+
+        x = self.conv1(x)
+        x = self.act(x)
+
+        x = self.conv2(x)
+        x = self.act(x)
+        x = self.conv3(x)
+        x = x.flatten(2).transpose(1, 2)
+        return x
+
+
+class ConvLayer(nn.Module):
+    def __init__(self, dim, input_resolution, depth,
+                 activation,
+                 drop_path=0., downsample=None, use_checkpoint=False,
+                 out_dim=None,
+                 conv_expand_ratio=4.,
+                 ):
+
+        super().__init__()
+        self.dim = dim
+        self.input_resolution = input_resolution
+        self.depth = depth
+        self.use_checkpoint = use_checkpoint
+
+        # build blocks
+        self.blocks = nn.ModuleList([
+            MBConv(dim, dim, conv_expand_ratio, activation,
+                   drop_path[i] if isinstance(drop_path, list) else drop_path,
+                   )
+            for i in range(depth)])
+
+        # patch merging layer
+        if downsample is not None:
+            self.downsample = downsample(
+                input_resolution, dim=dim, out_dim=out_dim, activation=activation)
+        else:
+            self.downsample = None
+
+    def forward(self, x):
+        for blk in self.blocks:
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x)
+            else:
+                x = blk(x)
+        if self.downsample is not None:
+            x = self.downsample(x)
+        return x
+
+
+class Mlp(nn.Module):
+    def __init__(self, in_features, hidden_features=None,
+                 out_features=None, act_layer=nn.GELU, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.norm = nn.LayerNorm(in_features)
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.act = act_layer()
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.norm(x)
+
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class Attention(torch.nn.Module):
+    def __init__(self, dim, key_dim, num_heads=8,
+                 attn_ratio=4,
+                 resolution=(14, 14),
+                 ):
+        super().__init__()
+        # (h, w)
+        assert isinstance(resolution, tuple) and len(resolution) == 2
+        self.num_heads = num_heads
+        self.scale = key_dim ** -0.5
+        self.key_dim = key_dim
+        self.nh_kd = nh_kd = key_dim * num_heads
+        self.d = int(attn_ratio * key_dim)
+        self.dh = int(attn_ratio * key_dim) * num_heads
+        self.attn_ratio = attn_ratio
+        h = self.dh + nh_kd * 2
+
+        self.norm = nn.LayerNorm(dim)
+        self.qkv = nn.Linear(dim, h)
+        self.proj = nn.Linear(self.dh, dim)
+
+        points = list(itertools.product(
+            range(resolution[0]), range(resolution[1])))
+        N = len(points)
+        attention_offsets = {}
+        idxs = []
+        for p1 in points:
+            for p2 in points:
+                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
+                if offset not in attention_offsets:
+                    attention_offsets[offset] = len(attention_offsets)
+                idxs.append(attention_offsets[offset])
+        self.attention_biases = torch.nn.Parameter(
+            torch.zeros(num_heads, len(attention_offsets)))
+        self.register_buffer('attention_bias_idxs',
+                             torch.LongTensor(idxs).view(N, N),
+                             persistent=False)
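+
+        # Note (illustrative): attention_biases stores one learnable value per
+        # head for each unique absolute offset (|dy|, |dx|); for a 14x14 window
+        # that is 196 entries per head, indexed back to the full NxN bias matrix
+        # via attention_bias_idxs.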
+
+    @torch.no_grad()
+    def train(self, mode=True):
+        super().train(mode)
+        if mode and hasattr(self, 'ab'):
+            del self.ab
+        else:
+            self.register_buffer('ab',
+                                 self.attention_biases[:, self.attention_bias_idxs],
+                                 persistent=False)
+
+    def forward(self, x):  # x (B,N,C)
+        B, N, _ = x.shape
+
+        # Normalization
+        x = self.norm(x)
+
+        qkv = self.qkv(x)
+        # (B, N, num_heads, d)
+        q, k, v = qkv.view(B, N, self.num_heads, -1).split(
+            [self.key_dim, self.key_dim, self.d], dim=3)
+        # (B, num_heads, N, d)
+        q = q.permute(0, 2, 1, 3)
+        k = k.permute(0, 2, 1, 3)
+        v = v.permute(0, 2, 1, 3)
+
+        attn = (
+            (q @ k.transpose(-2, -1)) * self.scale
+            +
+            (self.attention_biases[:, self.attention_bias_idxs]
+             if self.training else self.ab)
+        )
+        attn = attn.softmax(dim=-1)
+        x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
+        x = self.proj(x)
+        return x
+
+
+class TinyViTBlock(nn.Module):
+    r""" TinyViT Block.
+
+    Args:
+        dim (int): Number of input channels.
+        input_resolution (tuple[int, int]): Input resolution.
+        num_heads (int): Number of attention heads.
+        window_size (int): Window size.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        drop (float, optional): Dropout rate. Default: 0.0
+        drop_path (float, optional): Stochastic depth rate. Default: 0.0
+        local_conv_size (int): the kernel size of the convolution between
+                               Attention and MLP. Default: 3
+        activation: the activation function. Default: nn.GELU
+    """
+
+    def __init__(self, dim, input_resolution, num_heads, window_size=7,
+                 mlp_ratio=4., drop=0., drop_path=0.,
+                 local_conv_size=3,
+                 activation=nn.GELU,
+                 ):
+        super().__init__()
+        self.dim = dim
+        self.input_resolution = input_resolution
+        self.num_heads = num_heads
+        assert window_size > 0, 'window_size must be greater than 0'
+        self.window_size = window_size
+        self.mlp_ratio = mlp_ratio
+
+        self.drop_path = DropPath(
+            drop_path) if drop_path > 0. else nn.Identity()
+
+        assert dim % num_heads == 0, 'dim must be divisible by num_heads'
+        head_dim = dim // num_heads
+
+        window_resolution = (window_size, window_size)
+        self.attn = Attention(dim, head_dim, num_heads,
+                              attn_ratio=1, resolution=window_resolution)
+
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        mlp_activation = activation
+        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
+                       act_layer=mlp_activation, drop=drop)
+
+        pad = local_conv_size // 2
+        self.local_conv = Conv2d_BN(
+            dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)
+
+    def forward(self, x):
+        H, W = self.input_resolution
+        B, L, C = x.shape
+        assert L == H * W, "input feature has wrong size"
+        res_x = x
+        if H == self.window_size and W == self.window_size:
+            x = self.attn(x)
+        else:
+            x = x.view(B, H, W, C)
+            pad_b = (self.window_size - H % self.window_size) % self.window_size
+            pad_r = (self.window_size - W % self.window_size) % self.window_size
+            padding = pad_b > 0 or pad_r > 0
+
+            if padding:
+                x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
+
+            pH, pW = H + pad_b, W + pad_r
+            nH = pH // self.window_size
+            nW = pW // self.window_size
+            # window partition
+            x = x.view(B, nH, self.window_size, nW, self.window_size, C).transpose(2, 3).reshape(
+                B * nH * nW, self.window_size * self.window_size, C)
+            x = self.attn(x)
+            # window reverse
+            x = x.view(B, nH, nW, self.window_size, self.window_size,
+                       C).transpose(2, 3).reshape(B, pH, pW, C)
+
+            if padding:
+                x = x[:, :H, :W].contiguous()
+
+            x = x.view(B, L, C)
+
+        x = res_x + self.drop_path(x)
+
+        x = x.transpose(1, 2).reshape(B, C, H, W)
+        x = self.local_conv(x)
+        x = x.view(B, C, L).transpose(1, 2)
+
+        x = x + self.drop_path(self.mlp(x))
+        return x
+
+    def extra_repr(self) -> str:
+        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
+               f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
+
+
+class BasicLayer(nn.Module):
+    """ A basic TinyViT layer for one stage.
+
+    Args:
+        dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+        depth (int): Number of blocks.
+        num_heads (int): Number of attention heads.
+        window_size (int): Local window size.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        drop (float, optional): Dropout rate. Default: 0.0
+        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+        local_conv_size: the kernel size of the depthwise convolution between attention and MLP. Default: 3
+        activation: the activation function. Default: nn.GELU
+        out_dim: the output dimension of the layer. Default: dim
+    """
+
+    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
+                 mlp_ratio=4., drop=0.,
+                 drop_path=0., downsample=None, use_checkpoint=False,
+                 local_conv_size=3,
+                 activation=nn.GELU,
+                 out_dim=None,
+                 ):
+
+        super().__init__()
+        self.dim = dim
+        self.input_resolution = input_resolution
+        self.depth = depth
+        self.use_checkpoint = use_checkpoint
+
+        # build blocks
+        self.blocks = nn.ModuleList([
+            TinyViTBlock(dim=dim, input_resolution=input_resolution,
+                         num_heads=num_heads, window_size=window_size,
+                         mlp_ratio=mlp_ratio,
+                         drop=drop,
+                         drop_path=drop_path[i] if isinstance(
+                             drop_path, list) else drop_path,
+                         local_conv_size=local_conv_size,
+                         activation=activation,
+                         )
+            for i in range(depth)])
+
+        # patch merging layer
+        if downsample is not None:
+            self.downsample = downsample(
+                input_resolution, dim=dim, out_dim=out_dim, activation=activation)
+        else:
+            self.downsample = None
+
+    def forward(self, x):
+        for blk in self.blocks:
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x)
+            else:
+                x = blk(x)
+        if self.downsample is not None:
+            x = self.downsample(x)
+        return x
+
+    def extra_repr(self) -> str:
+        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+class LayerNorm2d(nn.Module):
+    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(num_channels))
+        self.bias = nn.Parameter(torch.zeros(num_channels))
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        u = x.mean(1, keepdim=True)
+        s = (x - u).pow(2).mean(1, keepdim=True)
+        x = (x - u) / torch.sqrt(s + self.eps)
+        x = self.weight[:, None, None] * x + self.bias[:, None, None]
+        return x
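+
+    # Note (illustrative): unlike nn.LayerNorm over the last dimension, this
+    # normalizes over the channel dimension of an NCHW tensor, i.e. per spatial
+    # location, as used by the SAM-style neck below.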
+
+
+class TinyViT(nn.Module):
+    def __init__(self, img_size=224, in_chans=3, num_classes=1000,
+                 embed_dims=[96, 192, 384, 768], depths=[2, 2, 6, 2],
+                 num_heads=[3, 6, 12, 24],
+                 window_sizes=[7, 7, 14, 7],
+                 mlp_ratio=4.,
+                 drop_rate=0.,
+                 drop_path_rate=0.1,
+                 use_checkpoint=False,
+                 mbconv_expand_ratio=4.0,
+                 local_conv_size=3,
+                 layer_lr_decay=1.0,
+                 ):
+        super().__init__()
+        self.img_size = img_size
+        self.num_classes = num_classes
+        self.depths = depths
+        self.num_layers = len(depths)
+        self.mlp_ratio = mlp_ratio
+
+        activation = nn.GELU
+
+        self.patch_embed = PatchEmbed(in_chans=in_chans,
+                                      embed_dim=embed_dims[0],
+                                      resolution=img_size,
+                                      activation=activation)
+
+        patches_resolution = self.patch_embed.patches_resolution
+        self.patches_resolution = patches_resolution
+
+        # stochastic depth
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate,
+                                                sum(depths))]  # stochastic depth decay rule
+
+        # build layers
+        self.layers = nn.ModuleList()
+        for i_layer in range(self.num_layers):
+            kwargs = dict(dim=embed_dims[i_layer],
+                        input_resolution=(patches_resolution[0] // (2 ** (i_layer-1 if i_layer == 3 else i_layer)),
+                                patches_resolution[1] // (2 ** (i_layer-1 if i_layer == 3 else i_layer))),
+                        #   input_resolution=(patches_resolution[0] // (2 ** i_layer),
+                        #                     patches_resolution[1] // (2 ** i_layer)),
+                          depth=depths[i_layer],
+                          drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+                          downsample=PatchMerging if (
+                              i_layer < self.num_layers - 1) else None,
+                          use_checkpoint=use_checkpoint,
+                          out_dim=embed_dims[min(
+                              i_layer + 1, len(embed_dims) - 1)],
+                          activation=activation,
+                          )
+            if i_layer == 0:
+                layer = ConvLayer(
+                    conv_expand_ratio=mbconv_expand_ratio,
+                    **kwargs,
+                )
+            else:
+                layer = BasicLayer(
+                    num_heads=num_heads[i_layer],
+                    window_size=window_sizes[i_layer],
+                    mlp_ratio=self.mlp_ratio,
+                    drop=drop_rate,
+                    local_conv_size=local_conv_size,
+                    **kwargs)
+            self.layers.append(layer)
+
+        # Classifier head
+        self.norm_head = nn.LayerNorm(embed_dims[-1])
+        self.head = nn.Linear(
+            embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
+
+        # init weights
+        self.apply(self._init_weights)
+        self.set_layer_lr_decay(layer_lr_decay)
+        self.neck = nn.Sequential(
+            nn.Conv2d(
+                embed_dims[-1],
+                256,
+                kernel_size=1,
+                bias=False,
+            ),
+            LayerNorm2d(256),
+            nn.Conv2d(
+                256,
+                256,
+                kernel_size=3,
+                padding=1,
+                bias=False,
+            ),
+            LayerNorm2d(256),
+        )
+
+    def set_layer_lr_decay(self, layer_lr_decay):
+        decay_rate = layer_lr_decay
+
+        # layers -> blocks (depth)
+        depth = sum(self.depths)
+        lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]
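+        # Note (illustrative): block i out of `depth` blocks gets an LR scale of
+        # decay_rate**(depth - i - 1), so the deepest block trains at the full
+        # learning rate and earlier blocks at geometrically smaller rates
+        # (all 1.0 when layer_lr_decay == 1.0).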
+        #print("LR SCALES:", lr_scales)
+
+        def _set_lr_scale(m, scale):
+            for p in m.parameters():
+                p.lr_scale = scale
+
+        self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0]))
+        i = 0
+        for layer in self.layers:
+            for block in layer.blocks:
+                block.apply(lambda x: _set_lr_scale(x, lr_scales[i]))
+                i += 1
+            if layer.downsample is not None:
+                layer.downsample.apply(
+                    lambda x: _set_lr_scale(x, lr_scales[i - 1]))
+        assert i == depth
+        for m in [self.norm_head, self.head]:
+            m.apply(lambda x: _set_lr_scale(x, lr_scales[-1]))
+
+        for k, p in self.named_parameters():
+            p.param_name = k
+
+        def _check_lr_scale(m):
+            for p in m.parameters():
+                assert hasattr(p, 'lr_scale'), p.param_name
+
+        self.apply(_check_lr_scale)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    @torch.jit.ignore
+    def no_weight_decay_keywords(self):
+        return {'attention_biases'}
+
+    def forward_features(self, x):
+        # x: (N, C, H, W)
+        x = self.patch_embed(x)
+
+        x = self.layers[0](x)
+        start_i = 1
+
+        interm_embeddings=[]
+        for i in range(start_i, len(self.layers)):
+            layer = self.layers[i]
+            x = layer(x)
+            if i == 1:
+                interm_embeddings.append(x.view(x.shape[0], 64, 64, -1))
+
+        B, _, C = x.size()
+        x = x.view(B, 64, 64, C)
+        x = x.permute(0, 3, 1, 2)
+        x = self.neck(x)
+        return x, interm_embeddings
+
+    def forward(self, x):
+        x, interm_embeddings = self.forward_features(x)
+        #x = self.norm_head(x)
+        #x = self.head(x)
+        return x, interm_embeddings
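+
+    # Note (illustrative): the hard-coded 64x64 reshape in forward_features
+    # assumes a 1024x1024 input, for which the final token grid is 64x64; the
+    # neck then yields a B x 256 x 64 x 64 feature map matching the SAM
+    # prompt/mask decoder.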
+
+
+_checkpoint_url_format = \
+    'https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/{}.pth'
+_provided_checkpoints = {
+    'tiny_vit_5m_224': 'tiny_vit_5m_22kto1k_distill',
+    'tiny_vit_11m_224': 'tiny_vit_11m_22kto1k_distill',
+    'tiny_vit_21m_224': 'tiny_vit_21m_22kto1k_distill',
+    'tiny_vit_21m_384': 'tiny_vit_21m_22kto1k_384_distill',
+    'tiny_vit_21m_512': 'tiny_vit_21m_22kto1k_512_distill',
+}
+
+
+def register_tiny_vit_model(fn):
+    '''Register a TinyViT model
+    It is a wrapper of `register_model` with loading the pretrained checkpoint.
+    '''
+    def fn_wrapper(pretrained=False, **kwargs):
+        model = fn()
+        if pretrained:
+            model_name = fn.__name__
+            assert model_name in _provided_checkpoints, \
+                f'Checkpoint `{model_name}` is not provided yet.'
+            url = _checkpoint_url_format.format(
+                _provided_checkpoints[model_name])
+            checkpoint = torch.hub.load_state_dict_from_url(
+                url=url,
+                map_location='cpu', check_hash=False,
+            )
+            model.load_state_dict(checkpoint['model'])
+
+        return model
+
+    # rename the name of fn_wrapper
+    fn_wrapper.__name__ = fn.__name__
+    return register_model(fn_wrapper)
+
+
+@register_tiny_vit_model
+def tiny_vit_5m_224(pretrained=False, num_classes=1000, drop_path_rate=0.0):
+    return TinyViT(
+        num_classes=num_classes,
+        embed_dims=[64, 128, 160, 320],
+        depths=[2, 2, 6, 2],
+        num_heads=[2, 4, 5, 10],
+        window_sizes=[7, 7, 14, 7],
+        drop_path_rate=drop_path_rate,
+    )
+
+
+@register_tiny_vit_model
+def tiny_vit_11m_224(pretrained=False, num_classes=1000, drop_path_rate=0.1):
+    return TinyViT(
+        num_classes=num_classes,
+        embed_dims=[64, 128, 256, 448],
+        depths=[2, 2, 6, 2],
+        num_heads=[2, 4, 8, 14],
+        window_sizes=[7, 7, 14, 7],
+        drop_path_rate=drop_path_rate,
+    )
+
+
+@register_tiny_vit_model
+def tiny_vit_21m_224(pretrained=False, num_classes=1000, drop_path_rate=0.2):
+    return TinyViT(
+        num_classes=num_classes,
+        embed_dims=[96, 192, 384, 576],
+        depths=[2, 2, 6, 2],
+        num_heads=[3, 6, 12, 18],
+        window_sizes=[7, 7, 14, 7],
+        drop_path_rate=drop_path_rate,
+    )
+
+
+@register_tiny_vit_model
+def tiny_vit_21m_384(pretrained=False, num_classes=1000, drop_path_rate=0.1):
+    return TinyViT(
+        img_size=384,
+        num_classes=num_classes,
+        embed_dims=[96, 192, 384, 576],
+        depths=[2, 2, 6, 2],
+        num_heads=[3, 6, 12, 18],
+        window_sizes=[12, 12, 24, 12],
+        drop_path_rate=drop_path_rate,
+    )
+
+
+@register_tiny_vit_model
+def tiny_vit_21m_512(pretrained=False, num_classes=1000, drop_path_rate=0.1):
+    return TinyViT(
+        img_size=512,
+        num_classes=num_classes,
+        embed_dims=[96, 192, 384, 576],
+        depths=[2, 2, 6, 2],
+        num_heads=[3, 6, 12, 18],
+        window_sizes=[16, 16, 32, 16],
+        drop_path_rate=drop_path_rate,
+    )
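+
+
+# Usage sketch (illustrative): once registered with timm, the variants above
+# can be built by name, e.g.
+#   import timm
+#   model = timm.create_model('tiny_vit_5m_224', pretrained=False)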
diff --git a/modeling/semantic_enhanced_matting/modeling/transformer.py b/modeling/semantic_enhanced_matting/modeling/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..28fafea52288603fea275f3a100790471825c34a
--- /dev/null
+++ b/modeling/semantic_enhanced_matting/modeling/transformer.py
@@ -0,0 +1,240 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch import Tensor, nn
+
+import math
+from typing import Tuple, Type
+
+from .common import MLPBlock
+
+
+class TwoWayTransformer(nn.Module):
+    def __init__(
+        self,
+        depth: int,
+        embedding_dim: int,
+        num_heads: int,
+        mlp_dim: int,
+        activation: Type[nn.Module] = nn.ReLU,
+        attention_downsample_rate: int = 2,
+    ) -> None:
+        """
+        A transformer decoder that attends to an input image using
+        queries whose positional embedding is supplied.
+
+        Args:
+          depth (int): number of layers in the transformer
+          embedding_dim (int): the channel dimension for the input embeddings
+          num_heads (int): the number of heads for multihead attention. Must
+            divide embedding_dim
+          mlp_dim (int): the channel dimension internal to the MLP block
+          activation (nn.Module): the activation to use in the MLP block
+        """
+        super().__init__()
+        self.depth = depth
+        self.embedding_dim = embedding_dim
+        self.num_heads = num_heads
+        self.mlp_dim = mlp_dim
+        self.layers = nn.ModuleList()
+
+        for i in range(depth):
+            self.layers.append(
+                TwoWayAttentionBlock(
+                    embedding_dim=embedding_dim,
+                    num_heads=num_heads,
+                    mlp_dim=mlp_dim,
+                    activation=activation,
+                    attention_downsample_rate=attention_downsample_rate,
+                    skip_first_layer_pe=(i == 0),
+                )
+            )
+
+        self.final_attn_token_to_image = Attention(
+            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+        )
+        self.norm_final_attn = nn.LayerNorm(embedding_dim)
+
+    def forward(
+        self,
+        image_embedding: Tensor,
+        image_pe: Tensor,
+        point_embedding: Tensor,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Args:
+          image_embedding (torch.Tensor): image to attend to. Should be shape
+            B x embedding_dim x h x w for any h and w.
+          image_pe (torch.Tensor): the positional encoding to add to the image. Must
+            have the same shape as image_embedding.
+          point_embedding (torch.Tensor): the embedding to add to the query points.
+            Must have shape B x N_points x embedding_dim for any N_points.
+
+        Returns:
+          torch.Tensor: the processed point_embedding
+          torch.Tensor: the processed image_embedding
+        """
+        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
+        bs, c, h, w = image_embedding.shape
+        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
+        image_pe = image_pe.flatten(2).permute(0, 2, 1)
+
+        # Prepare queries
+        queries = point_embedding
+        keys = image_embedding
+
+        # Apply transformer blocks and final layernorm
+        for layer in self.layers:
+            queries, keys = layer(
+                queries=queries,
+                keys=keys,
+                query_pe=point_embedding,
+                key_pe=image_pe,
+            )
+
+        # Apply the final attention layer from the points to the image
+        q = queries + point_embedding
+        k = keys + image_pe
+        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
+        queries = queries + attn_out
+        queries = self.norm_final_attn(queries)
+
+        return queries, keys
+
+
+class TwoWayAttentionBlock(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_heads: int,
+        mlp_dim: int = 2048,
+        activation: Type[nn.Module] = nn.ReLU,
+        attention_downsample_rate: int = 2,
+        skip_first_layer_pe: bool = False,
+    ) -> None:
+        """
+        A transformer block with four layers: (1) self-attention of sparse
+        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
+        block on sparse inputs, and (4) cross attention of dense inputs to sparse
+        inputs.
+
+        Arguments:
+          embedding_dim (int): the channel dimension of the embeddings
+          num_heads (int): the number of heads in the attention layers
+          mlp_dim (int): the hidden dimension of the mlp block
+          activation (nn.Module): the activation of the mlp block
+          skip_first_layer_pe (bool): skip the PE on the first layer
+        """
+        super().__init__()
+        self.self_attn = Attention(embedding_dim, num_heads)
+        self.norm1 = nn.LayerNorm(embedding_dim)
+
+        self.cross_attn_token_to_image = Attention(
+            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+        )
+        self.norm2 = nn.LayerNorm(embedding_dim)
+
+        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
+        self.norm3 = nn.LayerNorm(embedding_dim)
+
+        self.norm4 = nn.LayerNorm(embedding_dim)
+        self.cross_attn_image_to_token = Attention(
+            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+        )
+
+        self.skip_first_layer_pe = skip_first_layer_pe
+
+    def forward(
+        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
+    ) -> Tuple[Tensor, Tensor]:
+        # Self attention block
+        if self.skip_first_layer_pe:
+            queries = self.self_attn(q=queries, k=queries, v=queries)
+        else:
+            q = queries + query_pe
+            attn_out = self.self_attn(q=q, k=q, v=queries)
+            queries = queries + attn_out
+        queries = self.norm1(queries)
+
+        # Cross attention block, tokens attending to image embedding
+        q = queries + query_pe
+        k = keys + key_pe
+        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
+        queries = queries + attn_out
+        queries = self.norm2(queries)
+
+        # MLP block
+        mlp_out = self.mlp(queries)
+        queries = queries + mlp_out
+        queries = self.norm3(queries)
+
+        # Cross attention block, image embedding attending to tokens
+        q = queries + query_pe
+        k = keys + key_pe
+        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
+        keys = keys + attn_out
+        keys = self.norm4(keys)
+
+        return queries, keys
+
+
+class Attention(nn.Module):
+    """
+    An attention layer that allows for downscaling the size of the embedding
+    after projection to queries, keys, and values.
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_heads: int,
+        downsample_rate: int = 1,
+    ) -> None:
+        super().__init__()
+        self.embedding_dim = embedding_dim
+        self.internal_dim = embedding_dim // downsample_rate
+        self.num_heads = num_heads
+        assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
+
+        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
+        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
+        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
+        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
+
+    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
+        b, n, c = x.shape
+        x = x.reshape(b, n, num_heads, c // num_heads)
+        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head
+
+    def _recombine_heads(self, x: Tensor) -> Tensor:
+        b, n_heads, n_tokens, c_per_head = x.shape
+        x = x.transpose(1, 2)
+        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C
+
+    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+        # Input projections
+        q = self.q_proj(q)
+        k = self.k_proj(k)
+        v = self.v_proj(v)
+
+        # Separate into heads
+        q = self._separate_heads(q, self.num_heads)
+        k = self._separate_heads(k, self.num_heads)
+        v = self._separate_heads(v, self.num_heads)
+
+        # Attention
+        _, _, _, c_per_head = q.shape
+        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
+        attn = attn / math.sqrt(c_per_head)
+        attn = torch.softmax(attn, dim=-1)
+
+        # Get output
+        out = attn @ v
+        out = self._recombine_heads(out)
+        out = self.out_proj(out)
+
+        return out
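+
+
+# Shape sketch (illustrative): with embedding_dim=256, num_heads=8 and
+# downsample_rate=2, q/k/v of shape B x N x 256 are projected to B x N x 128,
+# split into 8 heads of 16 channels each, attended, recombined, and projected
+# back to B x N x 256.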
diff --git a/modeling/semantic_enhanced_matting/predictor.py b/modeling/semantic_enhanced_matting/predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..73d078d03b6190985d8fa611231efa6844a6d247
--- /dev/null
+++ b/modeling/semantic_enhanced_matting/predictor.py
@@ -0,0 +1,276 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+
+from .modeling import Sam
+
+from typing import Optional, Tuple
+
+from .utils.transforms import ResizeLongestSide
+
+
+class SamPredictor:
+    def __init__(
+        self,
+        sam_model: Sam,
+    ) -> None:
+        """
+        Uses SAM to calculate the image embedding for an image, and then
+        allow repeated, efficient mask prediction given prompts.
+
+        Arguments:
+          sam_model (Sam): The model to use for mask prediction.
+        """
+        super().__init__()
+        self.model = sam_model
+        self.transform = ResizeLongestSide(
+            sam_model.image_encoder.img_size
+            if hasattr(sam_model.image_encoder, 'img_size')
+            else sam_model.image_size
+        )
+        self.reset_image()
+
+    def set_image(
+        self,
+        image: np.ndarray,
+        image_format: str = "RGB",
+    ) -> None:
+        """
+        Calculates the image embeddings for the provided image, allowing
+        masks to be predicted with the 'predict' method.
+
+        Arguments:
+          image (np.ndarray): The image for calculating masks. Expects an
+            image in HWC uint8 format, with pixel values in [0, 255].
+          image_format (str): The color format of the image, in ['RGB', 'BGR'].
+        """
+        assert image_format in [
+            "RGB",
+            "BGR",
+        ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
+        if image_format != self.model.image_format:
+            image = image[..., ::-1]
+
+        # Transform the image to the form expected by the model
+        input_image = self.transform.apply_image(image)
+        input_image_torch = torch.as_tensor(input_image, device=self.device)
+        input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
+
+        self.set_torch_image(input_image_torch, image.shape[:2])
+
+    @torch.no_grad()
+    def set_torch_image(
+        self,
+        transformed_image: torch.Tensor,
+        original_image_size: Tuple[int, ...],
+    ) -> None:
+        """
+        Calculates the image embeddings for the provided image, allowing
+        masks to be predicted with the 'predict' method. Expects the input
+        image to be already transformed to the format expected by the model.
+
+        Arguments:
+          transformed_image (torch.Tensor): The input image, with shape
+            1x3xHxW, which has been transformed with ResizeLongestSide.
+          original_image_size (tuple(int, int)): The size of the image
+            before transformation, in (H, W) format.
+        """
+        assert (
+            len(transformed_image.shape) == 4
+            and transformed_image.shape[1] == 3
+            and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
+        ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
+        self.reset_image()
+
+        self.original_size = original_image_size
+        self.input_size = tuple(transformed_image.shape[-2:])
+        input_image = self.model.preprocess(transformed_image)
+        self.features, self.interm_features = self.model.image_encoder(input_image)
+        self.is_image_set = True
+
+    def predict(
+        self,
+        point_coords: Optional[np.ndarray] = None,
+        point_labels: Optional[np.ndarray] = None,
+        box: Optional[np.ndarray] = None,
+        mask_input: Optional[np.ndarray] = None,
+        multimask_output: bool = True,
+        return_logits: bool = False,
+        hq_token_only: bool = False,
+    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+        """
+        Predict masks for the given input prompts, using the currently set image.
+
+        Arguments:
+          point_coords (np.ndarray or None): A Nx2 array of point prompts to the
+            model. Each point is in (X,Y) in pixels.
+          point_labels (np.ndarray or None): A length N array of labels for the
+            point prompts. 1 indicates a foreground point and 0 indicates a
+            background point.
+          box (np.ndarray or None): A length 4 array giving a box prompt to the
+            model, in XYXY format.
+          mask_input (np.ndarray): A low resolution mask input to the model, typically
+            coming from a previous prediction iteration. Has form 1xHxW, where
+            for SAM, H=W=256.
+          multimask_output (bool): If true, the model will return three masks.
+            For ambiguous input prompts (such as a single click), this will often
+            produce better masks than a single prediction. If only a single
+            mask is needed, the model's predicted quality score can be used
+            to select the best mask. For non-ambiguous prompts, such as multiple
+            input prompts, multimask_output=False can give better results.
+          return_logits (bool): If true, returns un-thresholded masks logits
+            instead of a binary mask.
+
+        Returns:
+          (np.ndarray): The output masks in CxHxW format, where C is the
+            number of masks, and (H, W) is the original image size.
+          (np.ndarray): An array of length C containing the model's
+            predictions for the quality of each mask.
+          (np.ndarray): An array of shape CxHxW, where C is the number
+            of masks and H=W=256. These low resolution logits can be passed to
+            a subsequent iteration as mask input.
+        """
+        if not self.is_image_set:
+            raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
+
+        # Transform input prompts
+        coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
+        if point_coords is not None:
+            assert (
+                point_labels is not None
+            ), "point_labels must be supplied if point_coords is supplied."
+            point_coords = self.transform.apply_coords(point_coords, self.original_size)
+            coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
+            labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
+            coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
+        if box is not None:
+            box = self.transform.apply_boxes(box, self.original_size)
+            box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
+            box_torch = box_torch[None, :]
+        if mask_input is not None:
+            mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
+            mask_input_torch = mask_input_torch[None, :, :, :]
+
+        masks, iou_predictions, low_res_masks = self.predict_torch(
+            coords_torch,
+            labels_torch,
+            box_torch,
+            mask_input_torch,
+            multimask_output,
+            return_logits=return_logits,
+            hq_token_only=hq_token_only,
+        )
+
+        masks_np = masks[0].detach().cpu().numpy()
+        iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
+        low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
+        return masks_np, iou_predictions_np, low_res_masks_np
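+
+    # Usage sketch (illustrative, assuming an already constructed `sam_model`):
+    #   predictor = SamPredictor(sam_model)
+    #   predictor.set_image(image_rgb)                      # HxWx3 uint8, RGB
+    #   masks, scores, low_res_logits = predictor.predict(
+    #       box=np.array([x1, y1, x2, y2]),
+    #       multimask_output=False,
+    #   )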
+
+    @torch.no_grad()
+    def predict_torch(
+        self,
+        point_coords: Optional[torch.Tensor],
+        point_labels: Optional[torch.Tensor],
+        boxes: Optional[torch.Tensor] = None,
+        mask_input: Optional[torch.Tensor] = None,
+        multimask_output: bool = True,
+        return_logits: bool = False,
+        hq_token_only: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Predict masks for the given input prompts, using the currently set image.
+        Input prompts are batched torch tensors and are expected to already be
+        transformed to the input frame using ResizeLongestSide.
+
+        Arguments:
+          point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
+            model. Each point is in (X,Y) in pixels.
+          point_labels (torch.Tensor or None): A BxN array of labels for the
+            point prompts. 1 indicates a foreground point and 0 indicates a
+            background point.
+          boxes (torch.Tensor or None): A Bx4 array giving a box prompt to the
+            model, in XYXY format.
+          mask_input (torch.Tensor): A low resolution mask input to the model, typically
+            coming from a previous prediction iteration. Has form Bx1xHxW, where
+            for SAM, H=W=256. Masks returned by a previous iteration of the
+            predict method do not need further transformation.
+          multimask_output (bool): If true, the model will return three masks.
+            For ambiguous input prompts (such as a single click), this will often
+            produce better masks than a single prediction. If only a single
+            mask is needed, the model's predicted quality score can be used
+            to select the best mask. For non-ambiguous prompts, such as multiple
+            input prompts, multimask_output=False can give better results.
+          return_logits (bool): If true, returns un-thresholded masks logits
+            instead of a binary mask.
+
+        Returns:
+          (torch.Tensor): The output masks in BxCxHxW format, where C is the
+            number of masks, and (H, W) is the original image size.
+          (torch.Tensor): An array of shape BxC containing the model's
+            predictions for the quality of each mask.
+          (torch.Tensor): An array of shape BxCxHxW, where C is the number
+            of masks and H=W=256. These low res logits can be passed to
+            a subsequent iteration as mask input.
+        """
+        if not self.is_image_set:
+            raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
+
+        if point_coords is not None:
+            points = (point_coords, point_labels)
+        else:
+            points = None
+
+        # Embed prompts
+        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
+            points=points,
+            boxes=boxes,
+            masks=mask_input,
+        )
+
+        # Predict masks
+        low_res_masks, iou_predictions = self.model.mask_decoder(
+            image_embeddings=self.features,
+            image_pe=self.model.prompt_encoder.get_dense_pe(),
+            sparse_prompt_embeddings=sparse_embeddings,
+            dense_prompt_embeddings=dense_embeddings,
+            multimask_output=multimask_output,
+            hq_token_only=hq_token_only,
+            interm_embeddings=self.interm_features,
+        )
+
+        # Upscale the masks to the original image resolution
+        masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
+
+        if not return_logits:
+            masks = masks > self.model.mask_threshold
+
+        return masks, iou_predictions, low_res_masks
+
+    def get_image_embedding(self) -> torch.Tensor:
+        """
+        Returns the image embeddings for the currently set image, with
+        shape 1xCxHxW, where C is the embedding dimension and (H,W) are
+        the embedding spatial dimension of SAM (typically C=256, H=W=64).
+        """
+        if not self.is_image_set:
+            raise RuntimeError(
+                "An image must be set with .set_image(...) to generate an embedding."
+            )
+        assert self.features is not None, "Features must exist if an image has been set."
+        return self.features
+
+    @property
+    def device(self) -> torch.device:
+        return self.model.device
+
+    def reset_image(self) -> None:
+        """Resets the currently set image."""
+        self.is_image_set = False
+        self.features = None
+        self.orig_h = None
+        self.orig_w = None
+        self.input_h = None
+        self.input_w = None
diff --git a/modeling/semantic_enhanced_matting/utils/__init__.py b/modeling/semantic_enhanced_matting/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae
--- /dev/null
+++ b/modeling/semantic_enhanced_matting/utils/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/modeling/semantic_enhanced_matting/utils/__pycache__/__init__.cpython-38.pyc b/modeling/semantic_enhanced_matting/utils/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f79164a21c51300b8d96a5a74c38600cb3fb4305
Binary files /dev/null and b/modeling/semantic_enhanced_matting/utils/__pycache__/__init__.cpython-38.pyc differ
diff --git a/modeling/semantic_enhanced_matting/utils/__pycache__/amg.cpython-38.pyc b/modeling/semantic_enhanced_matting/utils/__pycache__/amg.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..583eb6bdabf07e41d55ce5df361ccaa702a47d70
Binary files /dev/null and b/modeling/semantic_enhanced_matting/utils/__pycache__/amg.cpython-38.pyc differ
diff --git a/modeling/semantic_enhanced_matting/utils/__pycache__/transforms.cpython-38.pyc b/modeling/semantic_enhanced_matting/utils/__pycache__/transforms.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..084df8b376efbe2fac8deca61d316bf8730cc6d5
Binary files /dev/null and b/modeling/semantic_enhanced_matting/utils/__pycache__/transforms.cpython-38.pyc differ
diff --git a/modeling/semantic_enhanced_matting/utils/amg.py b/modeling/semantic_enhanced_matting/utils/amg.py
new file mode 100644
index 0000000000000000000000000000000000000000..be064071ef399fea96c673ad173689656c23534a
--- /dev/null
+++ b/modeling/semantic_enhanced_matting/utils/amg.py
@@ -0,0 +1,346 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+
+import math
+from copy import deepcopy
+from itertools import product
+from typing import Any, Dict, Generator, ItemsView, List, Tuple
+
+
+class MaskData:
+    """
+    A structure for storing masks and their related data in batched format.
+    Implements basic filtering and concatenation.
+    """
+
+    def __init__(self, **kwargs) -> None:
+        for v in kwargs.values():
+            assert isinstance(
+                v, (list, np.ndarray, torch.Tensor)
+            ), "MaskData only supports list, numpy arrays, and torch tensors."
+        self._stats = dict(**kwargs)
+
+    def __setitem__(self, key: str, item: Any) -> None:
+        assert isinstance(
+            item, (list, np.ndarray, torch.Tensor)
+        ), "MaskData only supports list, numpy arrays, and torch tensors."
+        self._stats[key] = item
+
+    def __delitem__(self, key: str) -> None:
+        del self._stats[key]
+
+    def __getitem__(self, key: str) -> Any:
+        return self._stats[key]
+
+    def items(self) -> ItemsView[str, Any]:
+        return self._stats.items()
+
+    def filter(self, keep: torch.Tensor) -> None:
+        for k, v in self._stats.items():
+            if v is None:
+                self._stats[k] = None
+            elif isinstance(v, torch.Tensor):
+                self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
+            elif isinstance(v, np.ndarray):
+                self._stats[k] = v[keep.detach().cpu().numpy()]
+            elif isinstance(v, list) and keep.dtype == torch.bool:
+                self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
+            elif isinstance(v, list):
+                self._stats[k] = [v[i] for i in keep]
+            else:
+                raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+    def cat(self, new_stats: "MaskData") -> None:
+        for k, v in new_stats.items():
+            if k not in self._stats or self._stats[k] is None:
+                self._stats[k] = deepcopy(v)
+            elif isinstance(v, torch.Tensor):
+                self._stats[k] = torch.cat([self._stats[k], v], dim=0)
+            elif isinstance(v, np.ndarray):
+                self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
+            elif isinstance(v, list):
+                self._stats[k] = self._stats[k] + deepcopy(v)
+            else:
+                raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+    def to_numpy(self) -> None:
+        for k, v in self._stats.items():
+            if isinstance(v, torch.Tensor):
+                self._stats[k] = v.detach().cpu().numpy()
+
+
+def is_box_near_crop_edge(
+    boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
+) -> torch.Tensor:
+    """Filter masks at the edge of a crop, but not at the edge of the original image."""
+    crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
+    orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
+    boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
+    near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
+    near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
+    near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
+    return torch.any(near_crop_edge, dim=1)
+
+
+def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
+    box_xywh = deepcopy(box_xyxy)
+    box_xywh[2] = box_xywh[2] - box_xywh[0]
+    box_xywh[3] = box_xywh[3] - box_xywh[1]
+    return box_xywh
+
+
+def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
+    assert len(args) > 0 and all(
+        len(a) == len(args[0]) for a in args
+    ), "Batched iteration must have inputs of all the same size."
+    n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
+    for b in range(n_batches):
+        yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
+
+
+def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
+    """
+    Encodes masks to an uncompressed RLE, in the format expected by
+    pycoco tools.
+    """
+    # Put in fortran order and flatten h,w
+    b, h, w = tensor.shape
+    tensor = tensor.permute(0, 2, 1).flatten(1)
+
+    # Compute change indices
+    diff = tensor[:, 1:] ^ tensor[:, :-1]
+    change_indices = diff.nonzero()
+
+    # Encode run length
+    out = []
+    for i in range(b):
+        cur_idxs = change_indices[change_indices[:, 0] == i, 1]
+        cur_idxs = torch.cat(
+            [
+                torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
+                cur_idxs + 1,
+                torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
+            ]
+        )
+        btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
+        counts = [] if tensor[i, 0] == 0 else [0]
+        counts.extend(btw_idxs.detach().cpu().tolist())
+        out.append({"size": [h, w], "counts": counts})
+    return out
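+
+
+# Note (illustrative): the uncompressed RLE alternates run lengths of 0s and 1s
+# over the column-major (Fortran-order) flattened mask, starting with the run
+# of 0s; e.g. a 2x2 mask [[0, 1], [0, 1]] encodes as
+# {"size": [2, 2], "counts": [2, 2]}.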
+
+
+def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
+    """Compute a binary mask from an uncompressed RLE."""
+    h, w = rle["size"]
+    mask = np.empty(h * w, dtype=bool)
+    idx = 0
+    parity = False
+    for count in rle["counts"]:
+        mask[idx : idx + count] = parity
+        idx += count
+        parity ^= True
+    mask = mask.reshape(w, h)
+    return mask.transpose()  # Put in C order
+
+
+def area_from_rle(rle: Dict[str, Any]) -> int:
+    return sum(rle["counts"][1::2])
+
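+# Illustrative sketch of the RLE round trip on a tiny mask. Runs are counted in
+# Fortran (column-major) order and always start with a (possibly empty) background run:
+#
+#   m = torch.tensor([[[0, 1], [1, 1]]], dtype=torch.bool)   # (B, H, W)
+#   rle = mask_to_rle_pytorch(m)[0]                          # {"size": [2, 2], "counts": [1, 3]}
+#   area_from_rle(rle)                                       # 3
+#   (rle_to_mask(rle) == m[0].numpy()).all()                 # True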
+
+def calculate_stability_score(
+    masks: torch.Tensor, mask_threshold: float, threshold_offset: float
+) -> torch.Tensor:
+    """
+    Computes the stability score for a batch of masks. The stability
+    score is the IoU between the binary masks obtained by thresholding
+    the predicted mask logits at high and low values.
+    """
+    # One mask is always contained inside the other.
+    # Save memory by preventing unnecessary cast to torch.int64
+    intersections = (
+        (masks > (mask_threshold + threshold_offset))
+        .sum(-1, dtype=torch.int16)
+        .sum(-1, dtype=torch.int32)
+    )
+    unions = (
+        (masks > (mask_threshold - threshold_offset))
+        .sum(-1, dtype=torch.int16)
+        .sum(-1, dtype=torch.int32)
+    )
+    return intersections / unions
+
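+# Illustrative sketch: with the defaults used elsewhere in this repo
+# (mask_threshold=0.0, threshold_offset=1.0) the score is the IoU between the
+# mask thresholded at +1.0 and the mask thresholded at -1.0:
+#
+#   logits = torch.tensor([[[2.0, 0.5], [-0.5, -2.0]]])
+#   calculate_stability_score(logits, 0.0, 1.0)   # tensor([0.3333]) = 1 pixel / 3 pixels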
+
+def build_point_grid(n_per_side: int) -> np.ndarray:
+    """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
+    offset = 1 / (2 * n_per_side)
+    points_one_side = np.linspace(offset, 1 - offset, n_per_side)
+    points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
+    points_y = np.tile(points_one_side[:, None], (1, n_per_side))
+    points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
+    return points
+
+
+def build_all_layer_point_grids(
+    n_per_side: int, n_layers: int, scale_per_layer: int
+) -> List[np.ndarray]:
+    """Generates point grids for all crop layers."""
+    points_by_layer = []
+    for i in range(n_layers + 1):
+        n_points = int(n_per_side / (scale_per_layer**i))
+        points_by_layer.append(build_point_grid(n_points))
+    return points_by_layer
+
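+# Illustrative sketch: grid points are cell centres in normalized [0, 1]
+# coordinates, and each extra crop layer shrinks the grid by scale_per_layer:
+#
+#   build_point_grid(2)
+#   # array([[0.25, 0.25], [0.75, 0.25], [0.25, 0.75], [0.75, 0.75]])
+#
+#   [g.shape[0] for g in build_all_layer_point_grids(32, n_layers=2, scale_per_layer=2)]
+#   # [1024, 256, 64]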
+
+def generate_crop_boxes(
+    im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
+) -> Tuple[List[List[int]], List[int]]:
+    """
+    Generates a list of crop boxes of different sizes. The ith layer
+    has (2**i)**2 boxes.
+    """
+    crop_boxes, layer_idxs = [], []
+    im_h, im_w = im_size
+    short_side = min(im_h, im_w)
+
+    # Original image
+    crop_boxes.append([0, 0, im_w, im_h])
+    layer_idxs.append(0)
+
+    def crop_len(orig_len, n_crops, overlap):
+        return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
+
+    for i_layer in range(n_layers):
+        n_crops_per_side = 2 ** (i_layer + 1)
+        overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
+
+        crop_w = crop_len(im_w, n_crops_per_side, overlap)
+        crop_h = crop_len(im_h, n_crops_per_side, overlap)
+
+        crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
+        crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
+
+        # Crops in XYXY format
+        for x0, y0 in product(crop_box_x0, crop_box_y0):
+            box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
+            crop_boxes.append(box)
+            layer_idxs.append(i_layer + 1)
+
+    return crop_boxes, layer_idxs
+
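+# Illustrative sketch: for a 512x512 image, one crop layer and overlap_ratio=0.25,
+# the output is the full image plus a 2x2 grid of overlapping 320x320 crops:
+#
+#   boxes, layers = generate_crop_boxes((512, 512), n_layers=1, overlap_ratio=0.25)
+#   # boxes  -> [[0, 0, 512, 512], [0, 0, 320, 320], [0, 192, 320, 512],
+#   #            [192, 0, 512, 320], [192, 192, 512, 512]]
+#   # layers -> [0, 1, 1, 1, 1]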
+
+def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+    x0, y0, _, _ = crop_box
+    offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
+    # Check if boxes has a channel dimension
+    if len(boxes.shape) == 3:
+        offset = offset.unsqueeze(1)
+    return boxes + offset
+
+
+def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+    x0, y0, _, _ = crop_box
+    offset = torch.tensor([[x0, y0]], device=points.device)
+    # Check if points has a channel dimension
+    if len(points.shape) == 3:
+        offset = offset.unsqueeze(1)
+    return points + offset
+
+
+def uncrop_masks(
+    masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
+) -> torch.Tensor:
+    x0, y0, x1, y1 = crop_box
+    if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
+        return masks
+    # Coordinate transform masks
+    pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
+    pad = (x0, pad_x - x0, y0, pad_y - y0)
+    return torch.nn.functional.pad(masks, pad, value=0)
+
+
+def remove_small_regions(
+    mask: np.ndarray, area_thresh: float, mode: str
+) -> Tuple[np.ndarray, bool]:
+    """
+    Removes small disconnected regions and holes in a mask. Returns the
+    mask and an indicator of whether the mask has been modified.
+    """
+    import cv2  # type: ignore
+
+    assert mode in ["holes", "islands"]
+    correct_holes = mode == "holes"
+    working_mask = (correct_holes ^ mask).astype(np.uint8)
+    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
+    sizes = stats[:, -1][1:]  # Row 0 is background label
+    small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
+    if len(small_regions) == 0:
+        return mask, False
+    fill_labels = [0] + small_regions
+    if not correct_holes:
+        fill_labels = [i for i in range(n_labels) if i not in fill_labels]
+        # If every region is below threshold, keep largest
+        if len(fill_labels) == 0:
+            fill_labels = [int(np.argmax(sizes)) + 1]
+    mask = np.isin(regions, fill_labels)
+    return mask, True
+
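+# Illustrative sketch (needs opencv): filling a single-pixel hole in a solid mask.
+#
+#   m = np.ones((8, 8), dtype=bool)
+#   m[4, 4] = False
+#   filled, changed = remove_small_regions(m, area_thresh=10, mode="holes")
+#   changed, filled.all()   # (True, True)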
+
+def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
+    from pycocotools import mask as mask_utils  # type: ignore
+
+    h, w = uncompressed_rle["size"]
+    rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
+    rle["counts"] = rle["counts"].decode("utf-8")  # Necessary to serialize with json
+    return rle
+
+
+def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
+    """
+    Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
+    an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
+    """
+    # torch.max below raises an error on empty inputs, just skip in this case
+    if torch.numel(masks) == 0:
+        return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
+
+    # Normalize shape to CxHxW
+    shape = masks.shape
+    h, w = shape[-2:]
+    if len(shape) > 2:
+        masks = masks.flatten(0, -3)
+    else:
+        masks = masks.unsqueeze(0)
+
+    # Get top and bottom edges
+    in_height, _ = torch.max(masks, dim=-1)
+    in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
+    bottom_edges, _ = torch.max(in_height_coords, dim=-1)
+    in_height_coords = in_height_coords + h * (~in_height)
+    top_edges, _ = torch.min(in_height_coords, dim=-1)
+
+    # Get left and right edges
+    in_width, _ = torch.max(masks, dim=-2)
+    in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
+    right_edges, _ = torch.max(in_width_coords, dim=-1)
+    in_width_coords = in_width_coords + w * (~in_width)
+    left_edges, _ = torch.min(in_width_coords, dim=-1)
+
+    # If the mask is empty the right edge will be to the left of the left edge.
+    # Replace these boxes with [0, 0, 0, 0]
+    empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+    out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
+    out = out * (~empty_filter).unsqueeze(-1)
+
+    # Return to original shape
+    if len(shape) > 2:
+        out = out.reshape(*shape[:-2], 4)
+    else:
+        out = out[0]
+
+    return out
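+
+# Illustrative sketch: boxes are returned as inclusive pixel indices in XYXY order,
+# and empty masks map to [0, 0, 0, 0]:
+#
+#   m = torch.zeros(1, 4, 4, dtype=torch.bool)
+#   m[0, 1:3, 1:4] = True
+#   batched_mask_to_box(m)   # tensor([[1, 1, 3, 2]])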
diff --git a/modeling/semantic_enhanced_matting/utils/onnx.py b/modeling/semantic_enhanced_matting/utils/onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..8013dc43d0373f1d84cd7ff7950822ff12b82a82
--- /dev/null
+++ b/modeling/semantic_enhanced_matting/utils/onnx.py
@@ -0,0 +1,155 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+from typing import Tuple
+
+from ..modeling import Sam
+from .amg import calculate_stability_score
+
+
+class SamOnnxModel(nn.Module):
+    """
+    This model should not be called directly, but is used in ONNX export.
+    It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
+    with some functions modified to enable model tracing. Also supports extra
+    options controlling what information is returned. See the ONNX export script for details.
+    """
+
+    def __init__(
+        self,
+        model: Sam,
+        hq_token_only: bool = False,
+        multimask_output: bool = False,
+        use_stability_score: bool = False,
+        return_extra_metrics: bool = False,
+    ) -> None:
+        super().__init__()
+        self.mask_decoder = model.mask_decoder
+        self.model = model
+        self.img_size = model.image_encoder.img_size
+        self.hq_token_only = hq_token_only
+        self.multimask_output = multimask_output
+        self.use_stability_score = use_stability_score
+        self.stability_score_offset = 1.0
+        self.return_extra_metrics = return_extra_metrics
+
+    @staticmethod
+    def resize_longest_image_size(
+        input_image_size: torch.Tensor, longest_side: int
+    ) -> torch.Tensor:
+        input_image_size = input_image_size.to(torch.float32)
+        scale = longest_side / torch.max(input_image_size)
+        transformed_size = scale * input_image_size
+        transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
+        return transformed_size
+
+    def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
+        point_coords = point_coords + 0.5
+        point_coords = point_coords / self.img_size
+        point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
+        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
+
+        point_embedding = point_embedding * (point_labels != -1)
+        point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
+            point_labels == -1
+        )
+
+        for i in range(self.model.prompt_encoder.num_point_embeddings):
+            point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
+                i
+            ].weight * (point_labels == i)
+
+        return point_embedding
+
+    def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
+        mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
+        mask_embedding = mask_embedding + (
+            1 - has_mask_input
+        ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
+        return mask_embedding
+
+    def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
+        masks = F.interpolate(
+            masks,
+            size=(self.img_size, self.img_size),
+            mode="bilinear",
+            align_corners=False,
+        )
+
+        prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
+        masks = masks[..., : prepadded_size[0], : prepadded_size[1]]  # type: ignore
+
+        orig_im_size = orig_im_size.to(torch.int64)
+        h, w = orig_im_size[0], orig_im_size[1]
+        masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
+        return masks
+
+    @torch.no_grad()
+    def forward(
+        self,
+        image_embeddings: torch.Tensor,
+        interm_embeddings: torch.Tensor,
+        point_coords: torch.Tensor,
+        point_labels: torch.Tensor,
+        mask_input: torch.Tensor,
+        has_mask_input: torch.Tensor,
+        orig_im_size: torch.Tensor,
+    ):
+        sparse_embedding = self._embed_points(point_coords, point_labels)
+        dense_embedding = self._embed_masks(mask_input, has_mask_input)
+
+        vit_features = interm_embeddings[0].permute(0, 3, 1, 2) # early-layer ViT feature, after 1st global attention block in ViT
+        hq_features = self.model.mask_decoder.embedding_encoder(image_embeddings) + self.model.mask_decoder.compress_vit_feat(vit_features)
+
+        masks, scores = self.model.mask_decoder.predict_masks(
+            image_embeddings=image_embeddings,
+            image_pe=self.model.prompt_encoder.get_dense_pe(),
+            sparse_prompt_embeddings=sparse_embedding,
+            dense_prompt_embeddings=dense_embedding,
+            hq_features=hq_features,
+        )
+
+        if self.use_stability_score:
+            scores = calculate_stability_score(
+                masks, self.model.mask_threshold, self.stability_score_offset
+            )
+
+        if self.multimask_output:
+            # Keep the SAM mask with the highest score among the multi-mask outputs.
+            mask_slice = slice(1, self.model.mask_decoder.num_mask_tokens - 1)
+            scores = scores[:, mask_slice]
+            scores, max_iou_idx = torch.max(scores, dim=1)
+            scores = scores.unsqueeze(1)
+            masks_multi = masks[:, mask_slice, :, :]
+            masks_sam = masks_multi[torch.arange(masks_multi.size(0)), max_iou_idx].unsqueeze(1)
+        else:
+            # Single mask output (default).
+            mask_slice = slice(0, 1)
+            scores = scores[:, mask_slice]
+            masks_sam = masks[:, mask_slice]
+
+        # The last mask token is the HQ token.
+        masks_hq = masks[:, self.model.mask_decoder.num_mask_tokens - 1 : self.model.mask_decoder.num_mask_tokens]
+
+        if self.hq_token_only:
+            masks = masks_hq
+        else:
+            masks = masks_sam + masks_hq
+
+        upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
+
+        if self.return_extra_metrics:
+            stability_scores = calculate_stability_score(
+                upscaled_masks, self.model.mask_threshold, self.stability_score_offset
+            )
+            areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
+            return upscaled_masks, scores, stability_scores, areas, masks
+
+        return upscaled_masks, scores, masks
diff --git a/modeling/semantic_enhanced_matting/utils/transforms.py b/modeling/semantic_enhanced_matting/utils/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..c08ba1e3db751f3a5483a003be38c69c2cf2df85
--- /dev/null
+++ b/modeling/semantic_enhanced_matting/utils/transforms.py
@@ -0,0 +1,102 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+from torchvision.transforms.functional import resize, to_pil_image  # type: ignore
+
+from copy import deepcopy
+from typing import Tuple
+
+
+class ResizeLongestSide:
+    """
+    Resizes images to the longest side 'target_length', as well as provides
+    methods for resizing coordinates and boxes. Provides methods for
+    transforming both numpy array and batched torch tensors.
+    """
+
+    def __init__(self, target_length: int) -> None:
+        self.target_length = target_length
+
+    def apply_image(self, image: np.ndarray) -> np.ndarray:
+        """
+        Expects a numpy array with shape HxWxC in uint8 format.
+        """
+        target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
+        return np.array(resize(to_pil_image(image), target_size))
+
+    def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
+        """
+        Expects a numpy array of length 2 in the final dimension. Requires the
+        original image size in (H, W) format.
+        """
+        old_h, old_w = original_size
+        new_h, new_w = self.get_preprocess_shape(
+            original_size[0], original_size[1], self.target_length
+        )
+        coords = deepcopy(coords).astype(float)
+        coords[..., 0] = coords[..., 0] * (new_w / old_w)
+        coords[..., 1] = coords[..., 1] * (new_h / old_h)
+        return coords
+
+    def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
+        """
+        Expects a numpy array shape Bx4. Requires the original image size
+        in (H, W) format.
+        """
+        boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
+        return boxes.reshape(-1, 4)
+
+    def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
+        """
+        Expects batched images with shape BxCxHxW and float format. This
+        transformation may not exactly match apply_image. apply_image is
+        the transformation expected by the model.
+        """
+        # Expects an image in BCHW format. May not exactly match apply_image.
+        target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
+        return F.interpolate(
+            image, target_size, mode="bilinear", align_corners=False, antialias=True
+        )
+
+    def apply_coords_torch(
+        self, coords: torch.Tensor, original_size: Tuple[int, ...]
+    ) -> torch.Tensor:
+        """
+        Expects a torch tensor with length 2 in the last dimension. Requires the
+        original image size in (H, W) format.
+        """
+        old_h, old_w = original_size
+        new_h, new_w = self.get_preprocess_shape(
+            original_size[0], original_size[1], self.target_length
+        )
+        coords = deepcopy(coords).to(torch.float)
+        coords[..., 0] = coords[..., 0] * (new_w / old_w)
+        coords[..., 1] = coords[..., 1] * (new_h / old_h)
+        return coords
+
+    def apply_boxes_torch(
+        self, boxes: torch.Tensor, original_size: Tuple[int, ...]
+    ) -> torch.Tensor:
+        """
+        Expects a torch tensor with shape Bx4. Requires the original image
+        size in (H, W) format.
+        """
+        boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
+        return boxes.reshape(-1, 4)
+
+    @staticmethod
+    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
+        """
+        Compute the output size given input size and target long side length.
+        """
+        scale = long_side_length * 1.0 / max(oldh, oldw)
+        newh, neww = oldh * scale, oldw * scale
+        neww = int(neww + 0.5)
+        newh = int(newh + 0.5)
+        return (newh, neww)
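+
+# Illustrative sketch: with target_length=1024, a 600x800 (HxW) image maps to
+# 768x1024 and coordinates / boxes are scaled by the same factor (1.28):
+#
+#   t = ResizeLongestSide(1024)
+#   t.get_preprocess_shape(600, 800, 1024)                    # (768, 1024)
+#   t.apply_coords(np.array([[400.0, 300.0]]), (600, 800))    # [[512., 384.]]
+#   t.apply_boxes(np.array([[100.0, 100.0, 400.0, 300.0]]), (600, 800))
+#   # [[128., 128., 512., 384.]]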
diff --git a/pretrained/preprocess.py b/pretrained/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..b28412b4329757ed8acb0dc5978539a0e4fc4645
--- /dev/null
+++ b/pretrained/preprocess.py
@@ -0,0 +1,27 @@
+import torch
+import wget
+
+def preprocess(model, name='dino', embed_dim=384):
+    new_model = {}
+    for k in model.keys():
+        if 'patch_embed.proj.weight' in k:
+            x = torch.zeros(embed_dim, 4, 16, 16)
+            x[:, :3] = model[k]
+            new_model['backbone.'+k] = x
+        else:
+            new_model['backbone.'+k] = model[k]
+    if embed_dim==384:
+        size='s'
+    else:
+        size='b'
+    torch.save(new_model, name+'_vit_'+ size + '_fna.pth')
+
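+# Illustrative sketch for the DINO ViT-S/16 weights: the patch-embedding kernel is
+# padded from 3 to 4 input channels (the extra channel stays zero) and every key is
+# prefixed with 'backbone.':
+#
+#   sd = torch.load('dino_deitsmall16_pretrain.pth')
+#   sd['patch_embed.proj.weight'].shape                       # (384, 3, 16, 16)
+#   preprocess(sd, 'dino', 384)                               # writes dino_vit_s_fna.pth
+#   torch.load('dino_vit_s_fna.pth')['backbone.patch_embed.proj.weight'].shape
+#   # (384, 4, 16, 16)
+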
+if __name__ == "__main__":
+
+    wget.download('https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth')
+    wget.download('https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth')
+
+    dino_model = torch.load('dino_deitsmall16_pretrain.pth')
+    mae_model = torch.load('mae_pretrain_vit_base.pth')['model']
+    preprocess(dino_model, 'dino', 384)
+    preprocess(mae_model, 'mae', 768)
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4847f250554f6dd66d546a7719b24d6705094a42
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,17 @@
+torch==2.0.0
+torchvision
+tensorboard
+timm==0.5.4
+opencv-python==4.5.3.56
+setuptools==58.2.0
+easydict
+wget
+scikit-image
+fairscale
+imgaug
+peft
+kornia
+gradio==4.44.1
+gradio_image_prompter
+huggingface_hub
+detectron2 @ git+https://github.com/facebookresearch/detectron2@v0.6
diff --git a/sam2/__init__.py b/sam2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff90d1042707d3190277036f3fd8e0ff177fd365
--- /dev/null
+++ b/sam2/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from hydra import initialize_config_module
+
+initialize_config_module("sam2_configs", version_base="1.2")
diff --git a/sam2/__pycache__/__init__.cpython-38.pyc b/sam2/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..95dd7344f8c877202887c8f36677189d99419e57
Binary files /dev/null and b/sam2/__pycache__/__init__.cpython-38.pyc differ
diff --git a/sam2/__pycache__/build_sam.cpython-38.pyc b/sam2/__pycache__/build_sam.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..de2c2814ddcc56f027d4e49dfba22793f8e64229
Binary files /dev/null and b/sam2/__pycache__/build_sam.cpython-38.pyc differ
diff --git a/sam2/__pycache__/sam2_image_predictor.cpython-38.pyc b/sam2/__pycache__/sam2_image_predictor.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4d44530cce89b105284fcc04952805cc49123b39
Binary files /dev/null and b/sam2/__pycache__/sam2_image_predictor.cpython-38.pyc differ
diff --git a/sam2/automatic_mask_generator.py b/sam2/automatic_mask_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..67668b2fc937010d237f8195d5c059c7cc481a3e
--- /dev/null
+++ b/sam2/automatic_mask_generator.py
@@ -0,0 +1,434 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+from torchvision.ops.boxes import batched_nms, box_area  # type: ignore
+
+from sam2.modeling.sam2_base import SAM2Base
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+from sam2.utils.amg import (
+    area_from_rle,
+    batch_iterator,
+    batched_mask_to_box,
+    box_xyxy_to_xywh,
+    build_all_layer_point_grids,
+    calculate_stability_score,
+    coco_encode_rle,
+    generate_crop_boxes,
+    is_box_near_crop_edge,
+    mask_to_rle_pytorch,
+    MaskData,
+    remove_small_regions,
+    rle_to_mask,
+    uncrop_boxes_xyxy,
+    uncrop_masks,
+    uncrop_points,
+)
+
+
+class SAM2AutomaticMaskGenerator:
+    def __init__(
+        self,
+        model: SAM2Base,
+        points_per_side: Optional[int] = 32,
+        points_per_batch: int = 64,
+        pred_iou_thresh: float = 0.8,
+        stability_score_thresh: float = 0.95,
+        stability_score_offset: float = 1.0,
+        mask_threshold: float = 0.0,
+        box_nms_thresh: float = 0.7,
+        crop_n_layers: int = 0,
+        crop_nms_thresh: float = 0.7,
+        crop_overlap_ratio: float = 512 / 1500,
+        crop_n_points_downscale_factor: int = 1,
+        point_grids: Optional[List[np.ndarray]] = None,
+        min_mask_region_area: int = 0,
+        output_mode: str = "binary_mask",
+        use_m2m: bool = False,
+        multimask_output: bool = True,
+    ) -> None:
+        """
+        Using a SAM 2 model, generates masks for the entire image.
+        Generates a grid of point prompts over the image, then filters
+        low quality and duplicate masks. The default settings are chosen
+        for SAM 2 with a HieraL backbone.
+
+        Arguments:
+          model (Sam): The SAM 2 model to use for mask prediction.
+          points_per_side (int or None): The number of points to be sampled
+            along one side of the image. The total number of points is
+            points_per_side**2. If None, 'point_grids' must provide explicit
+            point sampling.
+          points_per_batch (int): Sets the number of points run simultaneously
+            by the model. Higher numbers may be faster but use more GPU memory.
+          pred_iou_thresh (float): A filtering threshold in [0,1], using the
+            model's predicted mask quality.
+          stability_score_thresh (float): A filtering threshold in [0,1], using
+            the stability of the mask under changes to the cutoff used to binarize
+            the model's mask predictions.
+          stability_score_offset (float): The amount to shift the cutoff when
+            calculating the stability score.
+          mask_threshold (float): Threshold for binarizing the mask logits
+          box_nms_thresh (float): The box IoU cutoff used by non-maximal
+            suppression to filter duplicate masks.
+          crop_n_layers (int): If >0, mask prediction will be run again on
+            crops of the image. Sets the number of layers to run, where each
+            layer has 2**i_layer number of image crops.
+          crop_nms_thresh (float): The box IoU cutoff used by non-maximal
+            suppression to filter duplicate masks between different crops.
+          crop_overlap_ratio (float): Sets the degree to which crops overlap.
+            In the first crop layer, crops will overlap by this fraction of
+            the image length. Later layers with more crops scale down this overlap.
+          crop_n_points_downscale_factor (int): The number of points-per-side
+            sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
+          point_grids (list(np.ndarray) or None): A list over explicit grids
+            of points used for sampling, normalized to [0,1]. The nth grid in the
+            list is used in the nth crop layer. Exclusive with points_per_side.
+          min_mask_region_area (int): If >0, postprocessing will be applied
+            to remove disconnected regions and holes in masks with area smaller
+            than min_mask_region_area. Requires opencv.
+          output_mode (str): The form masks are returned in. Can be 'binary_mask',
+            'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
+            For large resolutions, 'binary_mask' may consume large amounts of
+            memory.
+          use_m2m (bool): Whether to add a one-step refinement using previous mask predictions.
+          multimask_output (bool): Whether to output multiple masks at each point of the grid.
+        """
+
+        assert (points_per_side is None) != (
+            point_grids is None
+        ), "Exactly one of points_per_side or point_grids must be provided."
+        if points_per_side is not None:
+            self.point_grids = build_all_layer_point_grids(
+                points_per_side,
+                crop_n_layers,
+                crop_n_points_downscale_factor,
+            )
+        elif point_grids is not None:
+            self.point_grids = point_grids
+        else:
+            raise ValueError("Can't have both points_per_side and point_grids be None.")
+
+        assert output_mode in [
+            "binary_mask",
+            "uncompressed_rle",
+            "coco_rle",
+        ], f"Unknown output_mode {output_mode}."
+        if output_mode == "coco_rle":
+            try:
+                from pycocotools import mask as mask_utils  # type: ignore  # noqa: F401
+            except ImportError as e:
+                print("Please install pycocotools")
+                raise e
+
+        self.predictor = SAM2ImagePredictor(
+            model,
+            max_hole_area=min_mask_region_area,
+            max_sprinkle_area=min_mask_region_area,
+        )
+        self.points_per_batch = points_per_batch
+        self.pred_iou_thresh = pred_iou_thresh
+        self.stability_score_thresh = stability_score_thresh
+        self.stability_score_offset = stability_score_offset
+        self.mask_threshold = mask_threshold
+        self.box_nms_thresh = box_nms_thresh
+        self.crop_n_layers = crop_n_layers
+        self.crop_nms_thresh = crop_nms_thresh
+        self.crop_overlap_ratio = crop_overlap_ratio
+        self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
+        self.min_mask_region_area = min_mask_region_area
+        self.output_mode = output_mode
+        self.use_m2m = use_m2m
+        self.multimask_output = multimask_output
+
+    @torch.no_grad()
+    def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
+        """
+        Generates masks for the given image.
+
+        Arguments:
+          image (np.ndarray): The image to generate masks for, in HWC uint8 format.
+
+        Returns:
+           list(dict(str, any)): A list over records for masks. Each record is
+             a dict containing the following keys:
+               segmentation (dict(str, any) or np.ndarray): The mask. If
+                 output_mode='binary_mask', is an array of shape HW. Otherwise,
+                 is a dictionary containing the RLE.
+               bbox (list(float)): The box around the mask, in XYWH format.
+               area (int): The area in pixels of the mask.
+               predicted_iou (float): The model's own prediction of the mask's
+                 quality. This is filtered by the pred_iou_thresh parameter.
+               point_coords (list(list(float))): The point coordinates input
+                 to the model to generate this mask.
+               stability_score (float): A measure of the mask's quality. This
+                 is filtered on using the stability_score_thresh parameter.
+               crop_box (list(float)): The crop of the image used to generate
+                 the mask, given in XYWH format.
+        """
+
+        # Generate masks
+        mask_data = self._generate_masks(image)
+
+        # Encode masks
+        if self.output_mode == "coco_rle":
+            mask_data["segmentations"] = [
+                coco_encode_rle(rle) for rle in mask_data["rles"]
+            ]
+        elif self.output_mode == "binary_mask":
+            mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
+        else:
+            mask_data["segmentations"] = mask_data["rles"]
+
+        # Write mask records
+        curr_anns = []
+        for idx in range(len(mask_data["segmentations"])):
+            ann = {
+                "segmentation": mask_data["segmentations"][idx],
+                "area": area_from_rle(mask_data["rles"][idx]),
+                "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
+                "predicted_iou": mask_data["iou_preds"][idx].item(),
+                "point_coords": [mask_data["points"][idx].tolist()],
+                "stability_score": mask_data["stability_score"][idx].item(),
+                "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
+            }
+            curr_anns.append(ann)
+
+        return curr_anns
+
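+    # Illustrative usage sketch; `sam2_model` stands for an already-built SAM2Base
+    # instance and `image` for an HxWx3 uint8 numpy array (both are placeholders):
+    #
+    #   generator = SAM2AutomaticMaskGenerator(sam2_model, points_per_side=32)
+    #   records = generator.generate(image)
+    #   largest = max(records, key=lambda r: r["area"])
+    #   # each record also carries "bbox", "predicted_iou", "stability_score",
+    #   # "point_coords" and "crop_box"
+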
+    def _generate_masks(self, image: np.ndarray) -> MaskData:
+        orig_size = image.shape[:2]
+        crop_boxes, layer_idxs = generate_crop_boxes(
+            orig_size, self.crop_n_layers, self.crop_overlap_ratio
+        )
+
+        # Iterate over image crops
+        data = MaskData()
+        for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
+            crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
+            data.cat(crop_data)
+
+        # Remove duplicate masks between crops
+        if len(crop_boxes) > 1:
+            # Prefer masks from smaller crops
+            scores = 1 / box_area(data["crop_boxes"])
+            scores = scores.to(data["boxes"].device)
+            keep_by_nms = batched_nms(
+                data["boxes"].float(),
+                scores,
+                torch.zeros_like(data["boxes"][:, 0]),  # categories
+                iou_threshold=self.crop_nms_thresh,
+            )
+            data.filter(keep_by_nms)
+        data.to_numpy()
+        return data
+
+    def _process_crop(
+        self,
+        image: np.ndarray,
+        crop_box: List[int],
+        crop_layer_idx: int,
+        orig_size: Tuple[int, ...],
+    ) -> MaskData:
+        # Crop the image and calculate embeddings
+        x0, y0, x1, y1 = crop_box
+        cropped_im = image[y0:y1, x0:x1, :]
+        cropped_im_size = cropped_im.shape[:2]
+        self.predictor.set_image(cropped_im)
+
+        # Get points for this crop
+        points_scale = np.array(cropped_im_size)[None, ::-1]
+        points_for_image = self.point_grids[crop_layer_idx] * points_scale
+
+        # Generate masks for this crop in batches
+        data = MaskData()
+        for (points,) in batch_iterator(self.points_per_batch, points_for_image):
+            batch_data = self._process_batch(
+                points, cropped_im_size, crop_box, orig_size, normalize=True
+            )
+            data.cat(batch_data)
+            del batch_data
+        self.predictor.reset_predictor()
+
+        # Remove duplicates within this crop.
+        keep_by_nms = batched_nms(
+            data["boxes"].float(),
+            data["iou_preds"],
+            torch.zeros_like(data["boxes"][:, 0]),  # categories
+            iou_threshold=self.box_nms_thresh,
+        )
+        data.filter(keep_by_nms)
+
+        # Return to the original image frame
+        data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
+        data["points"] = uncrop_points(data["points"], crop_box)
+        data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
+
+        return data
+
+    def _process_batch(
+        self,
+        points: np.ndarray,
+        im_size: Tuple[int, ...],
+        crop_box: List[int],
+        orig_size: Tuple[int, ...],
+        normalize=False,
+    ) -> MaskData:
+        orig_h, orig_w = orig_size
+
+        # Run model on this batch
+        points = torch.as_tensor(points, device=self.predictor.device)
+        in_points = self.predictor._transforms.transform_coords(
+            points, normalize=normalize, orig_hw=im_size
+        )
+        in_labels = torch.ones(
+            in_points.shape[0], dtype=torch.int, device=in_points.device
+        )
+        masks, iou_preds, low_res_masks = self.predictor._predict(
+            in_points[:, None, :],
+            in_labels[:, None],
+            multimask_output=self.multimask_output,
+            return_logits=True,
+        )
+
+        # Serialize predictions and store in MaskData
+        data = MaskData(
+            masks=masks.flatten(0, 1),
+            iou_preds=iou_preds.flatten(0, 1),
+            points=points.repeat_interleave(masks.shape[1], dim=0),
+            low_res_masks=low_res_masks.flatten(0, 1),
+        )
+        del masks
+
+        if not self.use_m2m:
+            # Filter by predicted IoU
+            if self.pred_iou_thresh > 0.0:
+                keep_mask = data["iou_preds"] > self.pred_iou_thresh
+                data.filter(keep_mask)
+
+            # Calculate and filter by stability score
+            data["stability_score"] = calculate_stability_score(
+                data["masks"], self.mask_threshold, self.stability_score_offset
+            )
+            if self.stability_score_thresh > 0.0:
+                keep_mask = data["stability_score"] >= self.stability_score_thresh
+                data.filter(keep_mask)
+        else:
+            # One step refinement using previous mask predictions
+            in_points = self.predictor._transforms.transform_coords(
+                data["points"], normalize=normalize, orig_hw=im_size
+            )
+            labels = torch.ones(
+                in_points.shape[0], dtype=torch.int, device=in_points.device
+            )
+            masks, ious = self.refine_with_m2m(
+                in_points, labels, data["low_res_masks"], self.points_per_batch
+            )
+            data["masks"] = masks.squeeze(1)
+            data["iou_preds"] = ious.squeeze(1)
+
+            if self.pred_iou_thresh > 0.0:
+                keep_mask = data["iou_preds"] > self.pred_iou_thresh
+                data.filter(keep_mask)
+
+            data["stability_score"] = calculate_stability_score(
+                data["masks"], self.mask_threshold, self.stability_score_offset
+            )
+            if self.stability_score_thresh > 0.0:
+                keep_mask = data["stability_score"] >= self.stability_score_thresh
+                data.filter(keep_mask)
+
+        # Threshold masks and calculate boxes
+        data["masks"] = data["masks"] > self.mask_threshold
+        data["boxes"] = batched_mask_to_box(data["masks"])
+
+        # Filter boxes that touch crop boundaries
+        keep_mask = ~is_box_near_crop_edge(
+            data["boxes"], crop_box, [0, 0, orig_w, orig_h]
+        )
+        if not torch.all(keep_mask):
+            data.filter(keep_mask)
+
+        # Compress to RLE
+        data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
+        data["rles"] = mask_to_rle_pytorch(data["masks"])
+        del data["masks"]
+
+        return data
+
+    @staticmethod
+    def postprocess_small_regions(
+        mask_data: MaskData, min_area: int, nms_thresh: float
+    ) -> MaskData:
+        """
+        Removes small disconnected regions and holes in masks, then reruns
+        box NMS to remove any new duplicates.
+
+        Edits mask_data in place.
+
+        Requires open-cv as a dependency.
+        """
+        if len(mask_data["rles"]) == 0:
+            return mask_data
+
+        # Filter small disconnected regions and holes
+        new_masks = []
+        scores = []
+        for rle in mask_data["rles"]:
+            mask = rle_to_mask(rle)
+
+            mask, changed = remove_small_regions(mask, min_area, mode="holes")
+            unchanged = not changed
+            mask, changed = remove_small_regions(mask, min_area, mode="islands")
+            unchanged = unchanged and not changed
+
+            new_masks.append(torch.as_tensor(mask).unsqueeze(0))
+            # Give score=0 to changed masks and score=1 to unchanged masks
+            # so NMS will prefer ones that didn't need postprocessing
+            scores.append(float(unchanged))
+
+        # Recalculate boxes and remove any new duplicates
+        masks = torch.cat(new_masks, dim=0)
+        boxes = batched_mask_to_box(masks)
+        keep_by_nms = batched_nms(
+            boxes.float(),
+            torch.as_tensor(scores),
+            torch.zeros_like(boxes[:, 0]),  # categories
+            iou_threshold=nms_thresh,
+        )
+
+        # Only recalculate RLEs for masks that have changed
+        for i_mask in keep_by_nms:
+            if scores[i_mask] == 0.0:
+                mask_torch = masks[i_mask].unsqueeze(0)
+                mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
+                mask_data["boxes"][i_mask] = boxes[i_mask]  # update res directly
+        mask_data.filter(keep_by_nms)
+
+        return mask_data
+
+    def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch):
+        new_masks = []
+        new_iou_preds = []
+
+        for cur_points, cur_point_labels, low_res_mask in batch_iterator(
+            points_per_batch, points, point_labels, low_res_masks
+        ):
+            best_masks, best_iou_preds, _ = self.predictor._predict(
+                cur_points[:, None, :],
+                cur_point_labels[:, None],
+                mask_input=low_res_mask[:, None, :],
+                multimask_output=False,
+                return_logits=True,
+            )
+            new_masks.append(best_masks)
+            new_iou_preds.append(best_iou_preds)
+        masks = torch.cat(new_masks, dim=0)
+        return masks, torch.cat(new_iou_preds, dim=0)
diff --git a/sam2/build_sam.py b/sam2/build_sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad12041ee9d689a05a7c43149c183ffe84993336
--- /dev/null
+++ b/sam2/build_sam.py
@@ -0,0 +1,174 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+import torch
+from hydra import compose
+from hydra.utils import instantiate
+from omegaconf import OmegaConf
+
+
+def build_sam2(
+    config_file,
+    ckpt_path=None,
+    device="cuda",
+    mode="eval",
+    hydra_overrides_extra=[],
+    apply_postprocessing=True,
+    bbox_mask_matting_token = False,
+    matting_logits_res_add = False,
+    upscaled_embedding_res_add = True,
+):
+
+    if apply_postprocessing:
+        hydra_overrides_extra = hydra_overrides_extra.copy()
+        hydra_overrides_extra += [
+            # dynamically fall back to multi-mask if the single mask is not stable
+            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
+            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
+            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
+        ]
+    
+    if bbox_mask_matting_token:
+        hydra_overrides_extra = hydra_overrides_extra.copy()
+        hydra_overrides_extra += [
+            "++model.mask_decoder_matting_token=true",
+            "++model.image_encoder.trunk._target_=sam2.modeling.backbones.hieradet.HieraBBoxMask",
+            "++model.matting_logits_res_add=true" if matting_logits_res_add else "++model.matting_logits_res_add=false",
+            "++model.upscaled_embedding_res_add=true" if upscaled_embedding_res_add else "++model.upscaled_embedding_res_add=false",
+        ]
+
+    # Read config and init model
+    cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
+    OmegaConf.resolve(cfg)
+    model = instantiate(cfg.model, _recursive_=True)
+    _load_checkpoint(model, ckpt_path, add_new_layer_weights=True)
+
+    model = model.to(device)
+    if mode == "eval":
+        model.eval()
+    
+    if bbox_mask_matting_token:
+        for n, p in model.named_parameters():
+            if 'matting' in n or 'bbox_mask' in n:
+                p.requires_grad = True
+            else:
+                p.requires_grad = False
+
+    return model
+
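+# Illustrative sketch; the config / checkpoint names follow the official SAM 2
+# release naming used in build_sam2_hf below and are assumed to be available locally:
+#
+#   model = build_sam2(
+#       config_file="sam2_hiera_l.yaml",
+#       ckpt_path="./sam2_hiera_large.pt",
+#       bbox_mask_matting_token=True,
+#   )
+#   # With bbox_mask_matting_token=True, only the newly added matting / bbox-mask
+#   # parameters keep requires_grad=True; all other parameters are frozen.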
+
+def build_sam2_video_predictor(
+    config_file,
+    ckpt_path=None,
+    device="cuda",
+    mode="eval",
+    hydra_overrides_extra=[],
+    apply_postprocessing=True,
+):
+    hydra_overrides = [
+        "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
+    ]
+    if apply_postprocessing:
+        hydra_overrides_extra = hydra_overrides_extra.copy()
+        hydra_overrides_extra += [
+            # dynamically fall back to multi-mask if the single mask is not stable
+            "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
+            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
+            "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
+            # binarize the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
+            "++model.binarize_mask_from_pts_for_mem_enc=true",
+            # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
+            "++model.fill_hole_area=8",
+        ]
+    hydra_overrides.extend(hydra_overrides_extra)
+
+    # Read config and init model
+    cfg = compose(config_name=config_file, overrides=hydra_overrides)
+    OmegaConf.resolve(cfg)
+    model = instantiate(cfg.model, _recursive_=True)
+    _load_checkpoint(model, ckpt_path)
+    model = model.to(device)
+    if mode == "eval":
+        model.eval()
+    return model
+
+
+def build_sam2_hf(model_id, **kwargs):
+
+    from huggingface_hub import hf_hub_download
+
+    model_id_to_filenames = {
+        "facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"),
+        "facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"),
+        "facebook/sam2-hiera-base-plus": (
+            "sam2_hiera_b+.yaml",
+            "sam2_hiera_base_plus.pt",
+        ),
+        "facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"),
+    }
+    config_name, checkpoint_name = model_id_to_filenames[model_id]
+    ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
+    return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)
+
+
+def build_sam2_video_predictor_hf(model_id, **kwargs):
+
+    from huggingface_hub import hf_hub_download
+
+    model_id_to_filenames = {
+        "facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"),
+        "facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"),
+        "facebook/sam2-hiera-base-plus": (
+            "sam2_hiera_b+.yaml",
+            "sam2_hiera_base_plus.pt",
+        ),
+        "facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"),
+    }
+    config_name, checkpoint_name = model_id_to_filenames[model_id]
+    ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
+    return build_sam2_video_predictor(
+        config_file=config_name, ckpt_path=ckpt_path, **kwargs
+    )
+
+
+def _load_checkpoint(model, ckpt_path, add_new_layer_weights=False):
+    # if add_new_layer_weights:
+    #     assert ckpt_path is not None
+    if ckpt_path is not None:
+        sd = torch.load(ckpt_path, map_location="cpu")["model"]
+        if add_new_layer_weights:
+
+            # bbox patch embed
+            sd['image_encoder.trunk.bbox_mask_patch_embed.proj.weight'] = torch.concat((
+                sd['image_encoder.trunk.patch_embed.proj.weight'], 
+                torch.mean(sd['image_encoder.trunk.patch_embed.proj.weight'], dim=1, keepdim=True)
+            ), dim=1)
+            sd['image_encoder.trunk.bbox_mask_patch_embed.proj.bias'] = sd['image_encoder.trunk.patch_embed.proj.bias']
+
+            # matting token
+            sd['sam_mask_decoder.matting_mask_tokens.weight'] = torch.mean(sd['sam_mask_decoder.mask_tokens.weight'], dim=0, keepdim=True).repeat(model.sam_mask_decoder.matting_token_num, 1)
+            
+            output_hypernetworks_mlps_0_keys = [key for key in sd.keys() if 'output_hypernetworks_mlps.0' in key]
+            for i in range(model.sam_mask_decoder.matting_token_num):
+                for key in output_hypernetworks_mlps_0_keys:
+                    target_key = key.replace('output_hypernetworks_mlps.0', 'matting_output_hypernetworks_mlps.{}'.format(i))
+                    sd[target_key] = sd[key]
+
+            output_upscaling_keys = [key for key in sd.keys() if 'output_upscaling' in key]
+            for key in output_upscaling_keys:
+                target_key = key.replace('output_upscaling', 'matting_output_upscaling')
+                sd[target_key] = sd[key]
+
+        missing_keys, unexpected_keys = model.load_state_dict(sd)
+        if missing_keys:
+            logging.error(missing_keys)
+            raise RuntimeError()
+        if unexpected_keys:
+            logging.error(unexpected_keys)
+            raise RuntimeError()
+        logging.info("Loaded checkpoint successfully")
diff --git a/sam2/csrc/connected_components.cu b/sam2/csrc/connected_components.cu
new file mode 100644
index 0000000000000000000000000000000000000000..ced21eb32eaaadb818d441c1322b99d1bf068f45
--- /dev/null
+++ b/sam2/csrc/connected_components.cu
@@ -0,0 +1,289 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+// adapted from https://github.com/zsef123/Connected_components_PyTorch
+// with license found in the LICENSE_cctorch file in the root directory.
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <torch/extension.h>
+#include <torch/script.h>
+#include <vector>
+
+// 2d
+#define BLOCK_ROWS 16
+#define BLOCK_COLS 16
+
+namespace cc2d {
+
+template <typename T>
+__device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) {
+  return (bitmap >> pos) & 1;
+}
+
+__device__ int32_t find(const int32_t* s_buf, int32_t n) {
+  while (s_buf[n] != n)
+    n = s_buf[n];
+  return n;
+}
+
+__device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) {
+  const int32_t id = n;
+  while (s_buf[n] != n) {
+    n = s_buf[n];
+    s_buf[id] = n;
+  }
+  return n;
+}
+
+__device__ void union_(int32_t* s_buf, int32_t a, int32_t b) {
+  bool done;
+  do {
+    a = find(s_buf, a);
+    b = find(s_buf, b);
+
+    if (a < b) {
+      int32_t old = atomicMin(s_buf + b, a);
+      done = (old == b);
+      b = old;
+    } else if (b < a) {
+      int32_t old = atomicMin(s_buf + a, b);
+      done = (old == a);
+      a = old;
+    } else
+      done = true;
+
+  } while (!done);
+}
+
+__global__ void
+init_labeling(int32_t* label, const uint32_t W, const uint32_t H) {
+  const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
+  const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
+  const uint32_t idx = row * W + col;
+
+  if (row < H && col < W)
+    label[idx] = idx;
+}
+
+__global__ void
+merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) {
+  const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
+  const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
+  const uint32_t idx = row * W + col;
+
+  if (row >= H || col >= W)
+    return;
+
+  uint32_t P = 0;
+
+  if (img[idx])
+    P |= 0x777;
+  if (row + 1 < H && img[idx + W])
+    P |= 0x777 << 4;
+  if (col + 1 < W && img[idx + 1])
+    P |= 0x777 << 1;
+
+  if (col == 0)
+    P &= 0xEEEE;
+  if (col + 1 >= W)
+    P &= 0x3333;
+  else if (col + 2 >= W)
+    P &= 0x7777;
+
+  if (row == 0)
+    P &= 0xFFF0;
+  if (row + 1 >= H)
+    P &= 0xFF;
+
+  if (P > 0) {
+    // If need check about top-left pixel(if flag the first bit) and hit the
+    // top-left pixel
+    if (hasBit(P, 0) && img[idx - W - 1]) {
+      union_(label, idx, idx - 2 * W - 2); // top left block
+    }
+
+    if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1]))
+      union_(label, idx, idx - 2 * W); // top bottom block
+
+    if (hasBit(P, 3) && img[idx + 2 - W])
+      union_(label, idx, idx - 2 * W + 2); // top right block
+
+    if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1]))
+      union_(label, idx, idx - 2); // just left block
+  }
+}
+
+__global__ void compression(int32_t* label, const int32_t W, const int32_t H) {
+  const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
+  const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
+  const uint32_t idx = row * W + col;
+
+  if (row < H && col < W)
+    find_n_compress(label, idx);
+}
+
+__global__ void final_labeling(
+    const uint8_t* img,
+    int32_t* label,
+    const int32_t W,
+    const int32_t H) {
+  const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2;
+  const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2;
+  const uint32_t idx = row * W + col;
+
+  if (row >= H || col >= W)
+    return;
+
+  int32_t y = label[idx] + 1;
+
+  if (img[idx])
+    label[idx] = y;
+  else
+    label[idx] = 0;
+
+  if (col + 1 < W) {
+    if (img[idx + 1])
+      label[idx + 1] = y;
+    else
+      label[idx + 1] = 0;
+
+    if (row + 1 < H) {
+      if (img[idx + W + 1])
+        label[idx + W + 1] = y;
+      else
+        label[idx + W + 1] = 0;
+    }
+  }
+
+  if (row + 1 < H) {
+    if (img[idx + W])
+      label[idx + W] = y;
+    else
+      label[idx + W] = 0;
+  }
+}
+
+__global__ void init_counting(
+    const int32_t* label,
+    int32_t* count_init,
+    const int32_t W,
+    const int32_t H) {
+  const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
+  const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
+  const uint32_t idx = row * W + col;
+
+  if (row >= H || col >= W)
+    return;
+
+  int32_t y = label[idx];
+  if (y > 0) {
+    int32_t count_idx = y - 1;
+    atomicAdd(count_init + count_idx, 1);
+  }
+}
+
+__global__ void final_counting(
+    const int32_t* label,
+    const int32_t* count_init,
+    int32_t* count_final,
+    const int32_t W,
+    const int32_t H) {
+  const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y);
+  const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x);
+  const uint32_t idx = row * W + col;
+
+  if (row >= H || col >= W)
+    return;
+
+  int32_t y = label[idx];
+  if (y > 0) {
+    int32_t count_idx = y - 1;
+    count_final[idx] = count_init[count_idx];
+  } else {
+    count_final[idx] = 0;
+  }
+}
+
+} // namespace cc2d
+
+std::vector<torch::Tensor> get_connected_componnets(
+    const torch::Tensor& inputs) {
+  AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor");
+  AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape");
+  AT_ASSERTM(
+      inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type");
+
+  const uint32_t N = inputs.size(0);
+  const uint32_t C = inputs.size(1);
+  const uint32_t H = inputs.size(2);
+  const uint32_t W = inputs.size(3);
+
+  AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape");
+  AT_ASSERTM((H % 2) == 0, "height must be an even number");
+  AT_ASSERTM((W % 2) == 0, "width must be an even number");
+
+  // label must be uint32_t
+  auto label_options =
+      torch::TensorOptions().dtype(torch::kInt32).device(inputs.device());
+  torch::Tensor labels = torch::zeros({N, C, H, W}, label_options);
+  torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options);
+  torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options);
+
+  dim3 grid = dim3(
+      ((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS,
+      ((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS);
+  dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS);
+  dim3 grid_count =
+      dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS);
+  dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  for (int n = 0; n < N; n++) {
+    uint32_t offset = n * H * W;
+
+    cc2d::init_labeling<<<grid, block, 0, stream>>>(
+        labels.data_ptr<int32_t>() + offset, W, H);
+    cc2d::merge<<<grid, block, 0, stream>>>(
+        inputs.data_ptr<uint8_t>() + offset,
+        labels.data_ptr<int32_t>() + offset,
+        W,
+        H);
+    cc2d::compression<<<grid, block, 0, stream>>>(
+        labels.data_ptr<int32_t>() + offset, W, H);
+    cc2d::final_labeling<<<grid, block, 0, stream>>>(
+        inputs.data_ptr<uint8_t>() + offset,
+        labels.data_ptr<int32_t>() + offset,
+        W,
+        H);
+
+    // get the counting of each pixel
+    cc2d::init_counting<<<grid_count, block_count, 0, stream>>>(
+        labels.data_ptr<int32_t>() + offset,
+        counts_init.data_ptr<int32_t>() + offset,
+        W,
+        H);
+    cc2d::final_counting<<<grid_count, block_count, 0, stream>>>(
+        labels.data_ptr<int32_t>() + offset,
+        counts_init.data_ptr<int32_t>() + offset,
+        counts_final.data_ptr<int32_t>() + offset,
+        W,
+        H);
+  }
+
+  // returned values are [labels, counts]
+  std::vector<torch::Tensor> outputs;
+  outputs.push_back(labels);
+  outputs.push_back(counts_final);
+  return outputs;
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def(
+      "get_connected_componnets",
+      &get_connected_componnets,
+      "get_connected_componnets");
+}
diff --git a/sam2/modeling/__init__.py b/sam2/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae
--- /dev/null
+++ b/sam2/modeling/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/sam2/modeling/backbones/__init__.py b/sam2/modeling/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae
--- /dev/null
+++ b/sam2/modeling/backbones/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/sam2/modeling/backbones/hieradet.py b/sam2/modeling/backbones/hieradet.py
new file mode 100644
index 0000000000000000000000000000000000000000..690041c6f70d967d79b5c4e84b1fd97e54b2e242
--- /dev/null
+++ b/sam2/modeling/backbones/hieradet.py
@@ -0,0 +1,339 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from functools import partial
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from sam2.modeling.backbones.utils import (
+    PatchEmbed,
+    window_partition,
+    window_unpartition,
+)
+
+from sam2.modeling.sam2_utils import DropPath, MLP
+
+
+def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
+    if pool is None:
+        return x
+    # (B, H, W, C) -> (B, C, H, W)
+    x = x.permute(0, 3, 1, 2)
+    x = pool(x)
+    # (B, C, H', W') -> (B, H', W', C)
+    x = x.permute(0, 2, 3, 1)
+    if norm:
+        x = norm(x)
+
+    return x
+
+
+class MultiScaleAttention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        dim_out: int,
+        num_heads: int,
+        q_pool: nn.Module = None,
+    ):
+        super().__init__()
+
+        self.dim = dim
+        self.dim_out = dim_out
+
+        self.num_heads = num_heads
+        head_dim = dim_out // num_heads
+        self.scale = head_dim**-0.5
+
+        self.q_pool = q_pool
+        self.qkv = nn.Linear(dim, dim_out * 3)
+        self.proj = nn.Linear(dim_out, dim_out)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        B, H, W, _ = x.shape
+        # qkv with shape (B, H * W, 3, nHead, C)
+        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
+        # q, k, v with shape (B, H * W, nheads, C)
+        q, k, v = torch.unbind(qkv, 2)
+
+        # Q pooling (for downsample at stage changes)
+        if self.q_pool:
+            q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
+            H, W = q.shape[1:3]  # downsampled shape
+            q = q.reshape(B, H * W, self.num_heads, -1)
+
+        # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
+        x = F.scaled_dot_product_attention(
+            q.transpose(1, 2),
+            k.transpose(1, 2),
+            v.transpose(1, 2),
+        )
+        # Transpose back
+        x = x.transpose(1, 2)
+        x = x.reshape(B, H, W, -1)
+
+        x = self.proj(x)
+
+        return x
+
+
+class MultiScaleBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        dim_out: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        drop_path: float = 0.0,
+        norm_layer: Union[nn.Module, str] = "LayerNorm",
+        q_stride: Tuple[int, int] = None,
+        act_layer: nn.Module = nn.GELU,
+        window_size: int = 0,
+    ):
+        super().__init__()
+
+        if isinstance(norm_layer, str):
+            norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
+
+        self.dim = dim
+        self.dim_out = dim_out
+        self.norm1 = norm_layer(dim)
+
+        self.window_size = window_size
+
+        self.pool, self.q_stride = None, q_stride
+        if self.q_stride:
+            self.pool = nn.MaxPool2d(
+                kernel_size=q_stride, stride=q_stride, ceil_mode=False
+            )
+
+        self.attn = MultiScaleAttention(
+            dim,
+            dim_out,
+            num_heads=num_heads,
+            q_pool=self.pool,
+        )
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+        self.norm2 = norm_layer(dim_out)
+        self.mlp = MLP(
+            dim_out,
+            int(dim_out * mlp_ratio),
+            dim_out,
+            num_layers=2,
+            activation=act_layer,
+        )
+
+        if dim != dim_out:
+            self.proj = nn.Linear(dim, dim_out)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        shortcut = x  # B, H, W, C
+        x = self.norm1(x)
+
+        # Skip connection
+        if self.dim != self.dim_out:
+            shortcut = do_pool(self.proj(x), self.pool)
+
+        # Window partition
+        window_size = self.window_size
+        if window_size > 0:
+            H, W = x.shape[1], x.shape[2]
+            x, pad_hw = window_partition(x, window_size)
+
+        # Window Attention + Q Pooling (if stage change)
+        x = self.attn(x)
+        if self.q_stride:
+            # Shapes have changed due to Q pooling
+            window_size = self.window_size // self.q_stride[0]
+            H, W = shortcut.shape[1:3]
+
+            pad_h = (window_size - H % window_size) % window_size
+            pad_w = (window_size - W % window_size) % window_size
+            pad_hw = (H + pad_h, W + pad_w)
+
+        # Reverse window partition
+        if self.window_size > 0:
+            x = window_unpartition(x, window_size, pad_hw, (H, W))
+
+        x = shortcut + self.drop_path(x)
+        # MLP
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class Hiera(nn.Module):
+    """
+    Reference: https://arxiv.org/abs/2306.00989
+    """
+
+    def __init__(
+        self,
+        embed_dim: int = 96,  # initial embed dim
+        num_heads: int = 1,  # initial number of heads
+        drop_path_rate: float = 0.0,  # stochastic depth
+        q_pool: int = 3,  # number of q_pool stages
+        q_stride: Tuple[int, int] = (2, 2),  # downsample stride bet. stages
+        stages: Tuple[int, ...] = (2, 3, 16, 3),  # blocks per stage
+        dim_mul: float = 2.0,  # dim_mul factor at stage shift
+        head_mul: float = 2.0,  # head_mul factor at stage shift
+        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
+        # window size per stage, when not using global att.
+        window_spec: Tuple[int, ...] = (
+            8,
+            4,
+            14,
+            7,
+        ),
+        # global attn in these blocks
+        global_att_blocks: Tuple[int, ...] = (
+            12,
+            16,
+            20,
+        ),
+        return_interm_layers=True,  # return feats from every stage
+    ):
+        super().__init__()
+
+        assert len(stages) == len(window_spec)
+        self.window_spec = window_spec
+
+        depth = sum(stages)
+        self.q_stride = q_stride
+        self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
+        assert 0 <= q_pool <= len(self.stage_ends[:-1])
+        self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
+        self.return_interm_layers = return_interm_layers
+
+        self.patch_embed = PatchEmbed(
+            embed_dim=embed_dim,
+        )
+        # Which blocks have global att?
+        self.global_att_blocks = global_att_blocks
+
+        # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
+        self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
+        self.pos_embed = nn.Parameter(
+            torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
+        )
+        self.pos_embed_window = nn.Parameter(
+            torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
+        )
+
+        dpr = [
+            x.item() for x in torch.linspace(0, drop_path_rate, depth)
+        ]  # stochastic depth decay rule
+
+        cur_stage = 1
+        self.blocks = nn.ModuleList()
+
+        for i in range(depth):
+            dim_out = embed_dim
+            # lags by a block, so first block of
+            # next stage uses an initial window size
+            # of previous stage and final window size of current stage
+            window_size = self.window_spec[cur_stage - 1]
+
+            if self.global_att_blocks is not None:
+                window_size = 0 if i in self.global_att_blocks else window_size
+
+            if i - 1 in self.stage_ends:
+                dim_out = int(embed_dim * dim_mul)
+                num_heads = int(num_heads * head_mul)
+                cur_stage += 1
+
+            block = MultiScaleBlock(
+                dim=embed_dim,
+                dim_out=dim_out,
+                num_heads=num_heads,
+                drop_path=dpr[i],
+                q_stride=self.q_stride if i in self.q_pool_blocks else None,
+                window_size=window_size,
+            )
+
+            embed_dim = dim_out
+            self.blocks.append(block)
+
+        self.channel_list = (
+            [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
+            if return_interm_layers
+            else [self.blocks[-1].dim_out]
+        )
+
+    def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
+        h, w = hw
+        window_embed = self.pos_embed_window
+        pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
+        pos_embed = pos_embed + window_embed.tile(
+            [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
+        )
+        pos_embed = pos_embed.permute(0, 2, 3, 1)
+        return pos_embed
+
+    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+        x = self.patch_embed(x)
+        # x: (B, H, W, C)
+
+        # Add pos embed
+        x = x + self._get_pos_embed(x.shape[1:3])
+
+        outputs = []
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if (i == self.stage_ends[-1]) or (
+                i in self.stage_ends and self.return_interm_layers
+            ):
+                feats = x.permute(0, 3, 1, 2)
+                outputs.append(feats)
+
+        return outputs
+
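+# Minimal usage sketch for the trunk above (illustrative only; shapes assume the
+# default arguments and a 1024x1024 input):
+#
+#   trunk = Hiera()
+#   feats = trunk(torch.randn(1, 3, 1024, 1024))
+#
+# With return_interm_layers=True, `feats` contains one (B, C, H, W) map per stage,
+# at strides 4, 8, 16 and 32 with channels 96, 192, 384 and 768 respectively.
+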
+
+class HieraBBoxMask(Hiera):
+    def __init__(
+        self,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        self.bbox_mask_patch_embed = PatchEmbed(
+            in_chans=4,
+            embed_dim=self.patch_embed.proj.out_channels,
+        )
+
+    def forward(self, x: Tuple[torch.Tensor, Optional[torch.Tensor]]) -> List[torch.Tensor]:
+        # x is an (image, bbox_condition) pair; bbox_condition may be None
+        img, condition = x[0], x[1]
+        if condition is not None:
+            # rasterize the bbox prompt into a binary mask and concatenate it with the image
+            bbox_mask = torch.zeros_like(img)[:, 0:1]
+            for i in range(condition.shape[0]):
+                l, u, r, d = condition[i, 0, :]
+                bbox_mask[i, :, int(u): int(d), int(l): int(r)] = 1.0
+            condition_input = torch.concat((img, bbox_mask), dim=1)
+            x = self.patch_embed(img) + self.bbox_mask_patch_embed(condition_input)
+        else:
+            x = self.patch_embed(img)
+        # x: (B, H, W, C)
+
+        # Add pos embed
+        x = x + self._get_pos_embed(x.shape[1:3])
+
+        outputs = []
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if (i == self.stage_ends[-1]) or (
+                i in self.stage_ends and self.return_interm_layers
+            ):
+                feats = x.permute(0, 3, 1, 2)
+                outputs.append(feats)
+
+        return outputs
\ No newline at end of file
diff --git a/sam2/modeling/backbones/image_encoder.py b/sam2/modeling/backbones/image_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f92baf47dcab96385ff99899fd3e3a642c1cf9c
--- /dev/null
+++ b/sam2/modeling/backbones/image_encoder.py
@@ -0,0 +1,133 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ImageEncoder(nn.Module):
+    def __init__(
+        self,
+        trunk: nn.Module,
+        neck: nn.Module,
+        scalp: int = 0,
+    ):
+        super().__init__()
+        self.trunk = trunk
+        self.neck = neck
+        self.scalp = scalp
+        assert (
+            self.trunk.channel_list == self.neck.backbone_channel_list
+        ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"
+
+    def forward(self, sample: torch.Tensor):
+        # Forward through backbone
+        features, pos = self.neck(self.trunk(sample))
+        if self.scalp > 0:
+            # Discard the lowest resolution features
+            features, pos = features[: -self.scalp], pos[: -self.scalp]
+
+        src = features[-1]
+        output = {
+            "vision_features": src,
+            "vision_pos_enc": pos,
+            "backbone_fpn": features,
+        }
+        return output
+
+
+class FpnNeck(nn.Module):
+    """
+    A modified variant of Feature Pyramid Network (FPN) neck
+    (we remove output conv and also do bicubic interpolation similar to ViT
+    pos embed interpolation)
+    """
+
+    def __init__(
+        self,
+        position_encoding: nn.Module,
+        d_model: int,
+        backbone_channel_list: List[int],
+        kernel_size: int = 1,
+        stride: int = 1,
+        padding: int = 0,
+        fpn_interp_model: str = "bilinear",
+        fuse_type: str = "sum",
+        fpn_top_down_levels: Optional[List[int]] = None,
+    ):
+        """Initialize the neck
+        :param trunk: the backbone
+        :param position_encoding: the positional encoding to use
+        :param d_model: the dimension of the model
+        :param neck_norm: the normalization to use
+        """
+        super().__init__()
+        self.position_encoding = position_encoding
+        self.convs = nn.ModuleList()
+        self.backbone_channel_list = backbone_channel_list
+        for dim in backbone_channel_list:
+            current = nn.Sequential()
+            current.add_module(
+                "conv",
+                nn.Conv2d(
+                    in_channels=dim,
+                    out_channels=d_model,
+                    kernel_size=kernel_size,
+                    stride=stride,
+                    padding=padding,
+                ),
+            )
+
+            self.convs.append(current)
+        self.fpn_interp_model = fpn_interp_model
+        assert fuse_type in ["sum", "avg"]
+        self.fuse_type = fuse_type
+
+        # levels to have top-down features in its outputs
+        # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
+        # have top-down propagation, while outputs of level 0 and level 1 have only
+        # lateral features from the same backbone level.
+        if fpn_top_down_levels is None:
+            # default is to have top-down features on all levels
+            fpn_top_down_levels = range(len(self.convs))
+        self.fpn_top_down_levels = list(fpn_top_down_levels)
+
+    def forward(self, xs: List[torch.Tensor]):
+
+        out = [None] * len(self.convs)
+        pos = [None] * len(self.convs)
+        assert len(xs) == len(self.convs)
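+        # Note on ordering (inferred from the channel bookkeeping): self.convs follows
+        # backbone_channel_list (lowest-resolution level first), while xs arrives
+        # highest-resolution first, so convs[n - i] is the lateral conv whose
+        # in_channels match xs[i].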
+        # fpn forward pass
+        # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
+        prev_features = None
+        # forward in top-down order (from low to high resolution)
+        n = len(self.convs) - 1
+        for i in range(n, -1, -1):
+            x = xs[i]
+            lateral_features = self.convs[n - i](x)
+            if i in self.fpn_top_down_levels and prev_features is not None:
+                top_down_features = F.interpolate(
+                    prev_features.to(dtype=torch.float32),
+                    scale_factor=2.0,
+                    mode=self.fpn_interp_model,
+                    align_corners=(
+                        None if self.fpn_interp_model == "nearest" else False
+                    ),
+                    antialias=False,
+                )
+                prev_features = lateral_features + top_down_features
+                if self.fuse_type == "avg":
+                    prev_features /= 2
+            else:
+                prev_features = lateral_features
+            x_out = prev_features
+            out[i] = x_out
+            pos[i] = self.position_encoding(x_out).to(x_out.dtype)
+
+        return out, pos
diff --git a/sam2/modeling/backbones/utils.py b/sam2/modeling/backbones/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..32d55c7545f064de133a5ff0200ba1ece9b504b7
--- /dev/null
+++ b/sam2/modeling/backbones/utils.py
@@ -0,0 +1,95 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Some utilities for backbones, in particular for windowing"""
+
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def window_partition(x, window_size):
+    """
+    Partition into non-overlapping windows with padding if needed.
+    Args:
+        x (tensor): input tokens with [B, H, W, C].
+        window_size (int): window size.
+    Returns:
+        windows: windows after partition with [B * num_windows, window_size, window_size, C].
+        (Hp, Wp): padded height and width before partition
+    """
+    B, H, W, C = x.shape
+
+    pad_h = (window_size - H % window_size) % window_size
+    pad_w = (window_size - W % window_size) % window_size
+    if pad_h > 0 or pad_w > 0:
+        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+    Hp, Wp = H + pad_h, W + pad_w
+
+    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
+    windows = (
+        x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    )
+    return windows, (Hp, Wp)
+
+
+def window_unpartition(windows, window_size, pad_hw, hw):
+    """
+    Reverse window_partition: restore the original feature map and remove padding.
+    Args:
+        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
+        window_size (int): window size.
+        pad_hw (Tuple): padded height and width (Hp, Wp).
+        hw (Tuple): original height and width (H, W) before padding.
+    Returns:
+        x: unpartitioned sequences with [B, H, W, C].
+    """
+    Hp, Wp = pad_hw
+    H, W = hw
+    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
+    x = windows.view(
+        B, Hp // window_size, Wp // window_size, window_size, window_size, -1
+    )
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+
+    if Hp > H or Wp > W:
+        x = x[:, :H, :W, :].contiguous()
+    return x
+
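+# Round-trip sketch (illustrative): for x of shape (B, H, W, C),
+#   windows, pad_hw = window_partition(x, 7)
+#   x2 = window_unpartition(windows, 7, pad_hw, (H, W))
+# recovers the original tensor exactly, i.e. torch.equal(x, x2) is True.
+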
+
+class PatchEmbed(nn.Module):
+    """
+    Image to Patch Embedding.
+    """
+
+    def __init__(
+        self,
+        kernel_size: Tuple[int, ...] = (7, 7),
+        stride: Tuple[int, ...] = (4, 4),
+        padding: Tuple[int, ...] = (3, 3),
+        in_chans: int = 3,
+        embed_dim: int = 768,
+    ):
+        """
+        Args:
+            kernel_size (Tuple): kernel size of the projection layer.
+            stride (Tuple): stride of the projection layer.
+            padding (Tuple): padding size of the projection layer.
+            in_chans (int): Number of input image channels.
+            embed_dim (int): Patch embedding dimension.
+        """
+        super().__init__()
+        self.proj = nn.Conv2d(
+            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.proj(x)
+        # B C H W -> B H W C
+        x = x.permute(0, 2, 3, 1)
+        return x
diff --git a/sam2/modeling/memory_attention.py b/sam2/modeling/memory_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b07f9d87e3d8194ca5e11fc20f01604d591a59d
--- /dev/null
+++ b/sam2/modeling/memory_attention.py
@@ -0,0 +1,169 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional
+
+import torch
+from torch import nn, Tensor
+
+from sam2.modeling.sam.transformer import RoPEAttention
+
+from sam2.modeling.sam2_utils import get_activation_fn, get_clones
+
+
+class MemoryAttentionLayer(nn.Module):
+
+    def __init__(
+        self,
+        activation: str,
+        cross_attention: nn.Module,
+        d_model: int,
+        dim_feedforward: int,
+        dropout: float,
+        pos_enc_at_attn: bool,
+        pos_enc_at_cross_attn_keys: bool,
+        pos_enc_at_cross_attn_queries: bool,
+        self_attention: nn.Module,
+    ):
+        super().__init__()
+        self.d_model = d_model
+        self.dim_feedforward = dim_feedforward
+        self.dropout_value = dropout
+        self.self_attn = self_attention
+        self.cross_attn_image = cross_attention
+
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.norm3 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.dropout3 = nn.Dropout(dropout)
+
+        self.activation_str = activation
+        self.activation = get_activation_fn(activation)
+
+        # Where to add pos enc
+        self.pos_enc_at_attn = pos_enc_at_attn
+        self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
+        self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
+
+    def _forward_sa(self, tgt, query_pos):
+        # Self-Attention
+        tgt2 = self.norm1(tgt)
+        q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
+        tgt2 = self.self_attn(q, k, v=tgt2)
+        tgt = tgt + self.dropout1(tgt2)
+        return tgt
+
+    def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
+        kwds = {}
+        if num_k_exclude_rope > 0:
+            assert isinstance(self.cross_attn_image, RoPEAttention)
+            kwds = {"num_k_exclude_rope": num_k_exclude_rope}
+
+        # Cross-Attention
+        tgt2 = self.norm2(tgt)
+        tgt2 = self.cross_attn_image(
+            q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
+            k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
+            v=memory,
+            **kwds,
+        )
+        tgt = tgt + self.dropout2(tgt2)
+        return tgt
+
+    def forward(
+        self,
+        tgt,
+        memory,
+        pos: Optional[Tensor] = None,
+        query_pos: Optional[Tensor] = None,
+        num_k_exclude_rope: int = 0,
+    ) -> torch.Tensor:
+
+        # Self-Attn, Cross-Attn
+        tgt = self._forward_sa(tgt, query_pos)
+        tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
+        # MLP
+        tgt2 = self.norm3(tgt)
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+        tgt = tgt + self.dropout3(tgt2)
+        return tgt
+
+
+class MemoryAttention(nn.Module):
+    def __init__(
+        self,
+        d_model: int,
+        pos_enc_at_input: bool,
+        layer: nn.Module,
+        num_layers: int,
+        batch_first: bool = True,  # Do layers expect batch first input?
+    ):
+        super().__init__()
+        self.d_model = d_model
+        self.layers = get_clones(layer, num_layers)
+        self.num_layers = num_layers
+        self.norm = nn.LayerNorm(d_model)
+        self.pos_enc_at_input = pos_enc_at_input
+        self.batch_first = batch_first
+
+    def forward(
+        self,
+        curr: torch.Tensor,  # self-attention inputs
+        memory: torch.Tensor,  # cross-attention inputs
+        curr_pos: Optional[Tensor] = None,  # pos_enc for self-attention inputs
+        memory_pos: Optional[Tensor] = None,  # pos_enc for cross-attention inputs
+        num_obj_ptr_tokens: int = 0,  # number of object pointer *tokens*
+    ):
+        if isinstance(curr, list):
+            assert isinstance(curr_pos, list)
+            assert len(curr) == len(curr_pos) == 1
+            curr, curr_pos = (
+                curr[0],
+                curr_pos[0],
+            )
+
+        assert (
+            curr.shape[1] == memory.shape[1]
+        ), "Batch size must be the same for curr and memory"
+
+        output = curr
+        if self.pos_enc_at_input and curr_pos is not None:
+            output = output + 0.1 * curr_pos
+
+        if self.batch_first:
+            # Convert to batch first
+            output = output.transpose(0, 1)
+            curr_pos = curr_pos.transpose(0, 1)
+            memory = memory.transpose(0, 1)
+            memory_pos = memory_pos.transpose(0, 1)
+
+        for layer in self.layers:
+            kwds = {}
+            if isinstance(layer.cross_attn_image, RoPEAttention):
+                kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
+
+            output = layer(
+                tgt=output,
+                memory=memory,
+                pos=memory_pos,
+                query_pos=curr_pos,
+                **kwds,
+            )
+        normed_output = self.norm(output)
+
+        if self.batch_first:
+            # Convert back to seq first
+            normed_output = normed_output.transpose(0, 1)
+            curr_pos = curr_pos.transpose(0, 1)
+
+        return normed_output
diff --git a/sam2/modeling/memory_encoder.py b/sam2/modeling/memory_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..f60202dfaba87232c3870fb2101b5322a119d985
--- /dev/null
+++ b/sam2/modeling/memory_encoder.py
@@ -0,0 +1,181 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d
+
+
+class MaskDownSampler(nn.Module):
+    """
+    Progressively downsample a mask by total_stride, each time by stride.
+    Note that LayerNorm is applied per *token*, like in ViT.
+
+    With each downsample (by a factor stride**2), channel capacity increases by the same factor.
+    In the end, we linearly project to embed_dim channels.
+    """
+
+    def __init__(
+        self,
+        embed_dim=256,
+        kernel_size=4,
+        stride=4,
+        padding=0,
+        total_stride=16,
+        activation=nn.GELU,
+    ):
+        super().__init__()
+        num_layers = int(math.log2(total_stride) // math.log2(stride))
+        assert stride**num_layers == total_stride
+        self.encoder = nn.Sequential()
+        mask_in_chans, mask_out_chans = 1, 1
+        for _ in range(num_layers):
+            mask_out_chans = mask_in_chans * (stride**2)
+            self.encoder.append(
+                nn.Conv2d(
+                    mask_in_chans,
+                    mask_out_chans,
+                    kernel_size=kernel_size,
+                    stride=stride,
+                    padding=padding,
+                )
+            )
+            self.encoder.append(LayerNorm2d(mask_out_chans))
+            self.encoder.append(activation())
+            mask_in_chans = mask_out_chans
+
+        self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
+
+    def forward(self, x):
+        return self.encoder(x)
+
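+# Shape sketch: with the default arguments (stride=4, total_stride=16, embed_dim=256)
+# the encoder above is two strided convs, 1 -> 16 -> 256 channels, followed by a 1x1
+# projection, so a (B, 1, H, W) mask comes out as (B, 256, H/16, W/16).
+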
+
+# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
+class CXBlock(nn.Module):
+    r"""ConvNeXt Block. There are two equivalent implementations:
+    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
+    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+    We use (2) as we find it slightly faster in PyTorch
+
+    Args:
+        dim (int): Number of input channels.
+        drop_path (float): Stochastic depth rate. Default: 0.0
+        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+    """
+
+    def __init__(
+        self,
+        dim,
+        kernel_size=7,
+        padding=3,
+        drop_path=0.0,
+        layer_scale_init_value=1e-6,
+        use_dwconv=True,
+    ):
+        super().__init__()
+        self.dwconv = nn.Conv2d(
+            dim,
+            dim,
+            kernel_size=kernel_size,
+            padding=padding,
+            groups=dim if use_dwconv else 1,
+        )  # depthwise conv
+        self.norm = LayerNorm2d(dim, eps=1e-6)
+        self.pwconv1 = nn.Linear(
+            dim, 4 * dim
+        )  # pointwise/1x1 convs, implemented with linear layers
+        self.act = nn.GELU()
+        self.pwconv2 = nn.Linear(4 * dim, dim)
+        self.gamma = (
+            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
+            if layer_scale_init_value > 0
+            else None
+        )
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+    def forward(self, x):
+        input = x
+        x = self.dwconv(x)
+        x = self.norm(x)
+        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
+        x = self.pwconv1(x)
+        x = self.act(x)
+        x = self.pwconv2(x)
+        if self.gamma is not None:
+            x = self.gamma * x
+        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
+
+        x = input + self.drop_path(x)
+        return x
+
+
+class Fuser(nn.Module):
+    def __init__(self, layer, num_layers, dim=None, input_projection=False):
+        super().__init__()
+        self.proj = nn.Identity()
+        self.layers = get_clones(layer, num_layers)
+
+        if input_projection:
+            assert dim is not None
+            self.proj = nn.Conv2d(dim, dim, kernel_size=1)
+
+    def forward(self, x):
+        # normally x: (N, C, H, W)
+        x = self.proj(x)
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+
+class MemoryEncoder(nn.Module):
+    def __init__(
+        self,
+        out_dim,
+        mask_downsampler,
+        fuser,
+        position_encoding,
+        in_dim=256,  # in_dim of pix_feats
+    ):
+        super().__init__()
+
+        self.mask_downsampler = mask_downsampler
+
+        self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
+        self.fuser = fuser
+        self.position_encoding = position_encoding
+        self.out_proj = nn.Identity()
+        if out_dim != in_dim:
+            self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
+
+    def forward(
+        self,
+        pix_feat: torch.Tensor,
+        masks: torch.Tensor,
+        skip_mask_sigmoid: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        ## Process masks
+        # apply sigmoid so the predicted masks better match the binary GT masks (less domain shift)
+        if not skip_mask_sigmoid:
+            masks = F.sigmoid(masks)
+        masks = self.mask_downsampler(masks)
+
+        ## Fuse pix_feats and downsampled masks
+        # in case the visual features are on CPU, move them to the masks' device
+        pix_feat = pix_feat.to(masks.device)
+
+        x = self.pix_feat_proj(pix_feat)
+        x = x + masks
+        x = self.fuser(x)
+        x = self.out_proj(x)
+
+        pos = self.position_encoding(x).to(x.dtype)
+
+        return {"vision_features": x, "vision_pos_enc": [pos]}
diff --git a/sam2/modeling/position_encoding.py b/sam2/modeling/position_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4b57ae7b431f3859af1368acdf4597d671cda32
--- /dev/null
+++ b/sam2/modeling/position_encoding.py
@@ -0,0 +1,216 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import Any, Optional, Tuple
+
+import numpy as np
+
+import torch
+from torch import nn
+
+
+class PositionEmbeddingSine(nn.Module):
+    """
+    This is a more standard version of the position embedding, very similar to the one
+    used by the Attention is all you need paper, generalized to work on images.
+    """
+
+    def __init__(
+        self,
+        num_pos_feats,
+        temperature: int = 10000,
+        normalize: bool = True,
+        scale: Optional[float] = None,
+    ):
+        super().__init__()
+        assert num_pos_feats % 2 == 0, "Expecting even model width"
+        self.num_pos_feats = num_pos_feats // 2
+        self.temperature = temperature
+        self.normalize = normalize
+        if scale is not None and normalize is False:
+            raise ValueError("normalize should be True if scale is passed")
+        if scale is None:
+            scale = 2 * math.pi
+        self.scale = scale
+
+        self.cache = {}
+
+    def _encode_xy(self, x, y):
+        # The positions are expected to be normalized
+        assert len(x) == len(y) and x.ndim == y.ndim == 1
+        x_embed = x * self.scale
+        y_embed = y * self.scale
+
+        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+        pos_x = x_embed[:, None] / dim_t
+        pos_y = y_embed[:, None] / dim_t
+        pos_x = torch.stack(
+            (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
+        ).flatten(1)
+        pos_y = torch.stack(
+            (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
+        ).flatten(1)
+        return pos_x, pos_y
+
+    @torch.no_grad()
+    def encode_boxes(self, x, y, w, h):
+        pos_x, pos_y = self._encode_xy(x, y)
+        pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
+        return pos
+
+    encode = encode_boxes  # Backwards compatibility
+
+    @torch.no_grad()
+    def encode_points(self, x, y, labels):
+        (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
+        assert bx == by and nx == ny and bx == bl and nx == nl
+        pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
+        pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
+        pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
+        return pos
+
+    @torch.no_grad()
+    def forward(self, x: torch.Tensor):
+        cache_key = (x.shape[-2], x.shape[-1])
+        if cache_key in self.cache:
+            return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
+        y_embed = (
+            torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
+            .view(1, -1, 1)
+            .repeat(x.shape[0], 1, x.shape[-1])
+        )
+        x_embed = (
+            torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
+            .view(1, 1, -1)
+            .repeat(x.shape[0], x.shape[-2], 1)
+        )
+
+        if self.normalize:
+            eps = 1e-6
+            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+        pos_x = x_embed[:, :, :, None] / dim_t
+        pos_y = y_embed[:, :, :, None] / dim_t
+        pos_x = torch.stack(
+            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+        ).flatten(3)
+        pos_y = torch.stack(
+            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+        ).flatten(3)
+        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+        self.cache[cache_key] = pos[0]
+        return pos
+
+
+class PositionEmbeddingRandom(nn.Module):
+    """
+    Positional encoding using random spatial frequencies.
+    """
+
+    def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
+        super().__init__()
+        if scale is None or scale <= 0.0:
+            scale = 1.0
+        self.register_buffer(
+            "positional_encoding_gaussian_matrix",
+            scale * torch.randn((2, num_pos_feats)),
+        )
+
+    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
+        """Positionally encode points that are normalized to [0,1]."""
+        # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
+        coords = 2 * coords - 1
+        coords = coords @ self.positional_encoding_gaussian_matrix
+        coords = 2 * np.pi * coords
+        # outputs d_1 x ... x d_n x C shape
+        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
+
+    def forward(self, size: Tuple[int, int]) -> torch.Tensor:
+        """Generate positional encoding for a grid of the specified size."""
+        h, w = size
+        device: Any = self.positional_encoding_gaussian_matrix.device
+        grid = torch.ones((h, w), device=device, dtype=torch.float32)
+        y_embed = grid.cumsum(dim=0) - 0.5
+        x_embed = grid.cumsum(dim=1) - 0.5
+        y_embed = y_embed / h
+        x_embed = x_embed / w
+
+        pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
+        return pe.permute(2, 0, 1)  # C x H x W
+
+    def forward_with_coords(
+        self, coords_input: torch.Tensor, image_size: Tuple[int, int]
+    ) -> torch.Tensor:
+        """Positionally encode points that are not normalized to [0,1]."""
+        coords = coords_input.clone()
+        coords[:, :, 0] = coords[:, :, 0] / image_size[1]
+        coords[:, :, 1] = coords[:, :, 1] / image_size[0]
+        return self._pe_encoding(coords.to(torch.float))  # B x N x C
+
+
+# Rotary Positional Encoding, adapted from:
+# 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
+# 2. https://github.com/naver-ai/rope-vit
+# 3. https://github.com/lucidrains/rotary-embedding-torch
+
+
+def init_t_xy(end_x: int, end_y: int):
+    t = torch.arange(end_x * end_y, dtype=torch.float32)
+    t_x = (t % end_x).float()
+    t_y = torch.div(t, end_x, rounding_mode="floor").float()
+    return t_x, t_y
+
+
+def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
+    freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+    freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
+
+    t_x, t_y = init_t_xy(end_x, end_y)
+    freqs_x = torch.outer(t_x, freqs_x)
+    freqs_y = torch.outer(t_y, freqs_y)
+    freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
+    freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
+    return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
+
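+# Shape sketch (illustrative): compute_axial_cis(dim=head_dim, end_x=W, end_y=H)
+# returns a complex tensor of shape (H * W, head_dim // 2), where half of each row
+# encodes the x coordinate and half the y coordinate; apply_rotary_enc below
+# broadcasts it against the query/key tensors.
+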
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+    ndim = x.ndim
+    assert 0 <= 1 < ndim
+    assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
+    shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
+    return freqs_cis.view(*shape)
+
+
+def apply_rotary_enc(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    freqs_cis: torch.Tensor,
+    repeat_freqs_k: bool = False,
+):
+    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+    xk_ = (
+        torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+        if xk.shape[-2] != 0
+        else None
+    )
+    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+    if xk_ is None:
+        # no keys to rotate, due to dropout
+        return xq_out.type_as(xq).to(xq.device), xk
+    # repeat freqs along seq_len dim to match k seq_len
+    if repeat_freqs_k:
+        r = xk_.shape[-2] // xq_.shape[-2]
+        freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
+    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+    return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
diff --git a/sam2/modeling/sam/__init__.py b/sam2/modeling/sam/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae
--- /dev/null
+++ b/sam2/modeling/sam/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/sam2/modeling/sam/mask_decoder.py b/sam2/modeling/sam/mask_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..775e05895572e884043f4e1c2e72cd6205302661
--- /dev/null
+++ b/sam2/modeling/sam/mask_decoder.py
@@ -0,0 +1,458 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from copy import deepcopy
+from typing import List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+
+from sam2.modeling.sam2_utils import LayerNorm2d, MLP
+
+
+class MaskDecoder(nn.Module):
+    def __init__(
+        self,
+        *,
+        transformer_dim: int,
+        transformer: nn.Module,
+        num_multimask_outputs: int = 3,
+        activation: Type[nn.Module] = nn.GELU,
+        iou_head_depth: int = 3,
+        iou_head_hidden_dim: int = 256,
+        use_high_res_features: bool = False,
+        iou_prediction_use_sigmoid=False,
+        dynamic_multimask_via_stability=False,
+        dynamic_multimask_stability_delta=0.05,
+        dynamic_multimask_stability_thresh=0.98,
+        pred_obj_scores: bool = False,
+        pred_obj_scores_mlp: bool = False,
+        use_multimask_token_for_obj_ptr: bool = False,
+    ) -> None:
+        """
+        Predicts masks given an image and prompt embeddings, using a
+        transformer architecture.
+
+        Arguments:
+          transformer_dim (int): the channel dimension of the transformer
+          transformer (nn.Module): the transformer used to predict masks
+          num_multimask_outputs (int): the number of masks to predict
+            when disambiguating masks
+          activation (nn.Module): the type of activation to use when
+            upscaling masks
+          iou_head_depth (int): the depth of the MLP used to predict
+            mask quality
+          iou_head_hidden_dim (int): the hidden dimension of the MLP
+            used to predict mask quality
+        """
+        super().__init__()
+        self.transformer_dim = transformer_dim
+        self.transformer = transformer
+
+        self.num_multimask_outputs = num_multimask_outputs
+
+        self.iou_token = nn.Embedding(1, transformer_dim)
+        self.num_mask_tokens = num_multimask_outputs + 1
+        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
+
+        self.pred_obj_scores = pred_obj_scores
+        if self.pred_obj_scores:
+            self.obj_score_token = nn.Embedding(1, transformer_dim)
+        self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
+
+        self.output_upscaling = nn.Sequential(
+            nn.ConvTranspose2d(
+                transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
+            ),
+            LayerNorm2d(transformer_dim // 4),
+            activation(),
+            nn.ConvTranspose2d(
+                transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
+            ),
+            activation(),
+        )
+        self.use_high_res_features = use_high_res_features
+        if use_high_res_features:
+            self.conv_s0 = nn.Conv2d(
+                transformer_dim, transformer_dim // 8, kernel_size=1, stride=1
+            )
+            self.conv_s1 = nn.Conv2d(
+                transformer_dim, transformer_dim // 4, kernel_size=1, stride=1
+            )
+
+        self.output_hypernetworks_mlps = nn.ModuleList(
+            [
+                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
+                for i in range(self.num_mask_tokens)
+            ]
+        )
+
+        self.iou_prediction_head = MLP(
+            transformer_dim,
+            iou_head_hidden_dim,
+            self.num_mask_tokens,
+            iou_head_depth,
+            sigmoid_output=iou_prediction_use_sigmoid,
+        )
+        if self.pred_obj_scores:
+            self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
+            if pred_obj_scores_mlp:
+                self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
+
+        # When outputting a single mask, optionally we can dynamically fall back to the best
+        # multimask output token if the single mask output token gives low stability scores.
+        self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
+        self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
+        self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
+
+    def forward(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        multimask_output: bool,
+        repeat_image: bool,
+        high_res_features: Optional[List[torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Predict masks given image and prompt embeddings.
+
+        Arguments:
+          image_embeddings (torch.Tensor): the embeddings from the image encoder
+          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
+          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
+          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
+          multimask_output (bool): Whether to return multiple masks or a single
+            mask.
+
+        Returns:
+          torch.Tensor: batched predicted masks
+          torch.Tensor: batched predictions of mask quality
+          torch.Tensor: batched SAM token for mask output
+          torch.Tensor: batched object score logits
+        """
+        masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
+            image_embeddings=image_embeddings,
+            image_pe=image_pe,
+            sparse_prompt_embeddings=sparse_prompt_embeddings,
+            dense_prompt_embeddings=dense_prompt_embeddings,
+            repeat_image=repeat_image,
+            high_res_features=high_res_features,
+        )
+
+        # Select the correct mask or masks for output
+        if multimask_output:
+            masks = masks[:, 1:, :, :]
+            iou_pred = iou_pred[:, 1:]
+        elif self.dynamic_multimask_via_stability and not self.training:
+            masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
+        else:
+            masks = masks[:, 0:1, :, :]
+            iou_pred = iou_pred[:, 0:1]
+
+        if multimask_output and self.use_multimask_token_for_obj_ptr:
+            sam_tokens_out = mask_tokens_out[:, 1:]  # [b, 3, c] shape
+        else:
+            # Take the mask output token. Here we *always* use the token for single mask output.
+            # At test time, even if we track after 1-click (and using multimask_output=True),
+            # we still take the single mask token here. The rationale is that we always track
+            # after multiple clicks during training, so the past tokens seen during training
+            # are always the single mask token (and we'll let it be the object-memory token).
+            sam_tokens_out = mask_tokens_out[:, 0:1]  # [b, 1, c] shape
+
+        # Prepare output
+        return masks, iou_pred, sam_tokens_out, object_score_logits
+
+    def predict_masks(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        repeat_image: bool,
+        high_res_features: Optional[List[torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Predicts masks. See 'forward' for more details."""
+        # Concatenate output tokens
+        s = 0
+        if self.pred_obj_scores:
+            output_tokens = torch.cat(
+                [
+                    self.obj_score_token.weight,
+                    self.iou_token.weight,
+                    self.mask_tokens.weight,
+                ],
+                dim=0,
+            )
+            s = 1
+        else:
+            output_tokens = torch.cat(
+                [self.iou_token.weight, self.mask_tokens.weight], dim=0
+            )
+        output_tokens = output_tokens.unsqueeze(0).expand(
+            sparse_prompt_embeddings.size(0), -1, -1
+        )
+        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+
+        # Expand per-image data in batch direction to be per-mask
+        if repeat_image:
+            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+        else:
+            assert image_embeddings.shape[0] == tokens.shape[0]
+            src = image_embeddings
+        src = src + dense_prompt_embeddings
+        assert (
+            image_pe.size(0) == 1
+        ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
+        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+        b, c, h, w = src.shape
+
+        # Run the transformer
+        hs, src = self.transformer(src, pos_src, tokens)
+        iou_token_out = hs[:, s, :]
+        mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
+
+        # Upscale mask embeddings and predict masks using the mask tokens
+        src = src.transpose(1, 2).view(b, c, h, w)
+        if not self.use_high_res_features:
+            upscaled_embedding = self.output_upscaling(src)
+        else:
+            dc1, ln1, act1, dc2, act2 = self.output_upscaling
+            feat_s0, feat_s1 = high_res_features
+            upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
+            upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
+
+        hyper_in_list: List[torch.Tensor] = []
+        for i in range(self.num_mask_tokens):
+            hyper_in_list.append(
+                self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
+            )
+        hyper_in = torch.stack(hyper_in_list, dim=1)
+        b, c, h, w = upscaled_embedding.shape
+        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
+
+        # Generate mask quality predictions
+        iou_pred = self.iou_prediction_head(iou_token_out)
+        if self.pred_obj_scores:
+            assert s == 1
+            object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
+        else:
+            # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
+            object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
+
+        return masks, iou_pred, mask_tokens_out, object_score_logits
+
+    def _get_stability_scores(self, mask_logits):
+        """
+        Compute stability scores of the mask logits based on the IoU between upper and
+        lower thresholds, similar to https://github.com/fairinternal/onevision/pull/568.
+        """
+        mask_logits = mask_logits.flatten(-2)
+        stability_delta = self.dynamic_multimask_stability_delta
+        area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
+        area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
+        stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
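+        # Example: with stability_delta = 0.05, a mask whose thresholded area is nearly
+        # the same at logit cutoffs +0.05 and -0.05 scores close to 1.0, while a mask
+        # that shrinks a lot between the two cutoffs scores much lower.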
+        return stability_scores
+
+    def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
+        """
+        When outputting a single mask, if the stability score from the current single-mask
+        output (based on output token 0) falls below a threshold, we instead select from
+        multi-mask outputs (based on output token 1~3) the mask with the highest predicted
+        IoU score. This is intended to ensure a valid mask for both clicking and tracking.
+        """
+        # The best mask from multimask output tokens (1~3)
+        multimask_logits = all_mask_logits[:, 1:, :, :]
+        multimask_iou_scores = all_iou_scores[:, 1:]
+        best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
+        batch_inds = torch.arange(
+            multimask_iou_scores.size(0), device=all_iou_scores.device
+        )
+        best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
+        best_multimask_logits = best_multimask_logits.unsqueeze(1)
+        best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
+        best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
+
+        # The mask from singlemask output token 0 and its stability score
+        singlemask_logits = all_mask_logits[:, 0:1, :, :]
+        singlemask_iou_scores = all_iou_scores[:, 0:1]
+        stability_scores = self._get_stability_scores(singlemask_logits)
+        is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
+
+        # Dynamically fall back to best multimask output upon low stability scores.
+        mask_logits_out = torch.where(
+            is_stable[..., None, None].expand_as(singlemask_logits),
+            singlemask_logits,
+            best_multimask_logits,
+        )
+        iou_scores_out = torch.where(
+            is_stable.expand_as(singlemask_iou_scores),
+            singlemask_iou_scores,
+            best_multimask_iou_scores,
+        )
+        return mask_logits_out, iou_scores_out
+
+
+from copy import deepcopy
+
+class MaskDecoderMattingToken(MaskDecoder):
+    def __init__(
+        self,
+        matting_token_num: int = 3,
+        upscaled_embedding_res_add: bool = True,
+        matting_logits_res_add: bool = False,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        self.matting_token_num = matting_token_num
+        self.upscaled_embedding_res_add = upscaled_embedding_res_add
+        self.matting_logits_res_add = matting_logits_res_add
+
+        self.num_mask_tokens = self.num_mask_tokens + self.matting_token_num
+        self.matting_mask_tokens = nn.Embedding(self.matting_token_num, self.transformer_dim)
+        self.matting_output_hypernetworks_mlps = nn.ModuleList(
+            [
+                MLP(self.transformer_dim, self.transformer_dim, self.transformer_dim // 8, 3)
+                for i in range(self.matting_token_num)
+            ]
+        )
+        self.matting_output_upscaling = deepcopy(self.output_upscaling)
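+        # The matting head reuses the SAM2 decoder machinery: `matting_token_num`
+        # extra learnable tokens are appended to the mask tokens, each one gets its
+        # own hypernetwork MLP, and the upscaling stack is duplicated for the matting
+        # branch. `upscaled_embedding_res_add` and `matting_logits_res_add` control
+        # whether the matting branch is fused residually with the SAM2 branch.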
+
+    def forward(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        multimask_output: bool,
+        repeat_image: bool,
+        high_res_features: Optional[List[torch.Tensor]] = None,
+        wo_matting_token: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+
+        sam2_logits, matting_logits = self.predict_masks(
+            image_embeddings=image_embeddings,
+            image_pe=image_pe,
+            sparse_prompt_embeddings=sparse_prompt_embeddings,
+            dense_prompt_embeddings=dense_prompt_embeddings,
+            repeat_image=repeat_image,
+            high_res_features=high_res_features,
+            wo_matting_token=wo_matting_token,
+        )
+        sam2_logits = sam2_logits[:, 0:1, :, :]
+        if self.matting_logits_res_add and matting_logits is not None:
+            matting_logits = matting_logits + sam2_logits
+
+        return sam2_logits, matting_logits
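+    # Output sketch (shapes are an assumption based on the SAM2 defaults used in
+    # SAM2Base._build_sam_heads, i.e. a 64x64 image embedding with transformer_dim=256):
+    # `sam2_logits` is [B, 1, 256, 256] and `matting_logits` is
+    # [B, matting_token_num, 256, 256], or None when `wo_matting_token=True`.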
+    
+    def predict_masks(
+        self,
+        image_embeddings: torch.Tensor,
+        image_pe: torch.Tensor,
+        sparse_prompt_embeddings: torch.Tensor,
+        dense_prompt_embeddings: torch.Tensor,
+        repeat_image: bool,
+        high_res_features: Optional[List[torch.Tensor]] = None,
+        wo_matting_token: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Predicts masks. See 'forward' for more details."""
+        # Concatenate output tokens
+        s = 0
+        if self.pred_obj_scores:
+            output_tokens = torch.cat(
+                [
+                    self.obj_score_token.weight,
+                    self.iou_token.weight,
+                    self.mask_tokens.weight,
+                    self.matting_mask_tokens.weight
+                ],
+                dim=0,
+            )
+            s = 1
+        else:
+            output_tokens = torch.cat(
+                [self.iou_token.weight, self.mask_tokens.weight, self.matting_mask_tokens.weight], dim=0
+            )
+
+        if wo_matting_token:
+            output_tokens = output_tokens[:-self.matting_token_num]
+
+        output_tokens = output_tokens.unsqueeze(0).expand(
+            sparse_prompt_embeddings.size(0), -1, -1
+        )
+        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
+
+        # Expand per-image data in batch direction to be per-mask
+        if repeat_image:
+            src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
+        else:
+            assert image_embeddings.shape[0] == tokens.shape[0]
+            src = image_embeddings
+        src = src + dense_prompt_embeddings
+        assert (
+            image_pe.size(0) == 1
+        ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
+        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
+        b, c, h, w = src.shape
+
+        # Run the transformer
+        hs, src = self.transformer(src, pos_src, tokens)
+        iou_token_out = hs[:, s, :]
+        mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
+
+        # Upscale mask embeddings and predict masks using the mask tokens
+        src = src.transpose(1, 2).view(b, c, h, w)
+        # The base class' low-resolution path (`self.output_upscaling(src)` without
+        # high-res features) is not supported by the matting decoder.
+        assert self.use_high_res_features
+        # Original SAM2 upscaling branch
+        dc1, ln1, act1, dc2, act2 = self.output_upscaling
+        feat_s0, feat_s1 = high_res_features
+        upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
+        upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
+
+        if not wo_matting_token:
+            # Matting upscaling branch (mirrors the SAM2 branch above)
+            matting_dc1, matting_ln1, matting_act1, matting_dc2, matting_act2 = self.matting_output_upscaling
+            matting_upscaled_embedding = matting_act1(matting_ln1(matting_dc1(src) + feat_s1))
+            matting_upscaled_embedding = matting_act2(matting_dc2(matting_upscaled_embedding) + feat_s0)
+            if self.upscaled_embedding_res_add:
+                matting_upscaled_embedding = upscaled_embedding + matting_upscaled_embedding  # use res form
+
+        hyper_in_list: List[torch.Tensor] = []
+        for i in range(self.num_mask_tokens - self.matting_token_num if wo_matting_token else self.num_mask_tokens):
+            if i < self.num_mask_tokens - self.matting_token_num:
+                hyper_in_list.append(
+                    self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
+                )
+            # matting token
+            else:
+                hyper_in_list.append(
+                    self.matting_output_hypernetworks_mlps[i - (self.num_mask_tokens - self.matting_token_num)](mask_tokens_out[:, i, :])
+                )
+
+        hyper_in = torch.stack(hyper_in_list, dim=1)
+        b, c, h, w = upscaled_embedding.shape
+
+        # Logits from the original SAM2 mask tokens
+        sam2_logits = (hyper_in[:, :self.num_mask_tokens - self.matting_token_num] @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
+        if not wo_matting_token:
+            # Logits from the matting tokens
+            matting_logits = (hyper_in[:, -self.matting_token_num: ] @ matting_upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
+        else:
+            matting_logits = None
+        # Generate mask quality predictions
+        # iou_pred = self.iou_prediction_head(iou_token_out)
+        # if self.pred_obj_scores:
+        #     assert s == 1
+        #     object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
+        # else:
+        #     # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
+        #     object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
+
+        return sam2_logits, matting_logits
\ No newline at end of file
diff --git a/sam2/modeling/sam/prompt_encoder.py b/sam2/modeling/sam/prompt_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b3bbb95be0aea9c88f49f586ac959a9fda1b18b
--- /dev/null
+++ b/sam2/modeling/sam/prompt_encoder.py
@@ -0,0 +1,182 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional, Tuple, Type
+
+import torch
+from torch import nn
+
+from sam2.modeling.position_encoding import PositionEmbeddingRandom
+
+from sam2.modeling.sam2_utils import LayerNorm2d
+
+
+class PromptEncoder(nn.Module):
+    def __init__(
+        self,
+        embed_dim: int,
+        image_embedding_size: Tuple[int, int],
+        input_image_size: Tuple[int, int],
+        mask_in_chans: int,
+        activation: Type[nn.Module] = nn.GELU,
+    ) -> None:
+        """
+        Encodes prompts for input to SAM's mask decoder.
+
+        Arguments:
+          embed_dim (int): The prompts' embedding dimension
+          image_embedding_size (tuple(int, int)): The spatial size of the
+            image embedding, as (H, W).
+          input_image_size (int): The padded size of the image as input
+            to the image encoder, as (H, W).
+          mask_in_chans (int): The number of hidden channels used for
+            encoding input masks.
+          activation (nn.Module): The activation to use when encoding
+            input masks.
+        """
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.input_image_size = input_image_size
+        self.image_embedding_size = image_embedding_size
+        self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
+
+        self.num_point_embeddings: int = 4  # pos/neg point + 2 box corners
+        point_embeddings = [
+            nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
+        ]
+        self.point_embeddings = nn.ModuleList(point_embeddings)
+        self.not_a_point_embed = nn.Embedding(1, embed_dim)
+
+        self.mask_input_size = (
+            4 * image_embedding_size[0],
+            4 * image_embedding_size[1],
+        )
+        self.mask_downscaling = nn.Sequential(
+            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
+            LayerNorm2d(mask_in_chans // 4),
+            activation(),
+            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
+            LayerNorm2d(mask_in_chans),
+            activation(),
+            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
+        )
+        self.no_mask_embed = nn.Embedding(1, embed_dim)
+
+    def get_dense_pe(self) -> torch.Tensor:
+        """
+        Returns the positional encoding used to encode point prompts,
+        applied to a dense set of points the shape of the image encoding.
+
+        Returns:
+          torch.Tensor: Positional encoding with shape
+            1x(embed_dim)x(embedding_h)x(embedding_w)
+        """
+        return self.pe_layer(self.image_embedding_size).unsqueeze(0)
+
+    def _embed_points(
+        self,
+        points: torch.Tensor,
+        labels: torch.Tensor,
+        pad: bool,
+    ) -> torch.Tensor:
+        """Embeds point prompts."""
+        points = points + 0.5  # Shift to center of pixel
+        if pad:
+            padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
+            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
+            points = torch.cat([points, padding_point], dim=1)
+            labels = torch.cat([labels, padding_label], dim=1)
+        point_embedding = self.pe_layer.forward_with_coords(
+            points, self.input_image_size
+        )
+        point_embedding[labels == -1] = 0.0
+        point_embedding[labels == -1] += self.not_a_point_embed.weight
+        point_embedding[labels == 0] += self.point_embeddings[0].weight
+        point_embedding[labels == 1] += self.point_embeddings[1].weight
+        point_embedding[labels == 2] += self.point_embeddings[2].weight
+        point_embedding[labels == 3] += self.point_embeddings[3].weight
+        return point_embedding
+
+    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
+        """Embeds box prompts."""
+        boxes = boxes + 0.5  # Shift to center of pixel
+        coords = boxes.reshape(-1, 2, 2)
+        corner_embedding = self.pe_layer.forward_with_coords(
+            coords, self.input_image_size
+        )
+        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
+        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
+        return corner_embedding
+
+    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
+        """Embeds mask inputs."""
+        mask_embedding = self.mask_downscaling(masks)
+        return mask_embedding
+
+    def _get_batch_size(
+        self,
+        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+        boxes: Optional[torch.Tensor],
+        masks: Optional[torch.Tensor],
+    ) -> int:
+        """
+        Gets the batch size of the output given the batch size of the input prompts.
+        """
+        if points is not None:
+            return points[0].shape[0]
+        elif boxes is not None:
+            return boxes.shape[0]
+        elif masks is not None:
+            return masks.shape[0]
+        else:
+            return 1
+
+    def _get_device(self) -> torch.device:
+        return self.point_embeddings[0].weight.device
+
+    def forward(
+        self,
+        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
+        boxes: Optional[torch.Tensor],
+        masks: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Embeds different types of prompts, returning both sparse and dense
+        embeddings.
+
+        Arguments:
+          points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
+            and labels to embed.
+          boxes (torch.Tensor or none): boxes to embed
+          masks (torch.Tensor or none): masks to embed
+
+        Returns:
+          torch.Tensor: sparse embeddings for the points and boxes, with shape
+            BxNx(embed_dim), where N is determined by the number of input points
+            and boxes.
+          torch.Tensor: dense embeddings for the masks, in the shape
+            Bx(embed_dim)x(embed_H)x(embed_W)
+        """
+        bs = self._get_batch_size(points, boxes, masks)
+        sparse_embeddings = torch.empty(
+            (bs, 0, self.embed_dim), device=self._get_device()
+        )
+        if points is not None:
+            coords, labels = points
+            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
+            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
+        if boxes is not None:
+            box_embeddings = self._embed_boxes(boxes)
+            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
+
+        if masks is not None:
+            dense_embeddings = self._embed_masks(masks)
+        else:
+            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
+                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
+            )
+
+        return sparse_embeddings, dense_embeddings
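+
+# Usage sketch (illustrative only; the sizes assume the SAM convention of a 1024x1024
+# padded input embedded at stride 16, i.e. a 64x64 embedding grid):
+#
+#   pe = PromptEncoder(embed_dim=256, image_embedding_size=(64, 64),
+#                      input_image_size=(1024, 1024), mask_in_chans=16)
+#   sparse, dense = pe(points=None, boxes=torch.rand(2, 4) * 1024, masks=None)
+#   # sparse: [2, 2, 256] (two box-corner tokens), dense: [2, 256, 64, 64]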
diff --git a/sam2/modeling/sam/transformer.py b/sam2/modeling/sam/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5b6fa2f87e85a7f222fb2ba0b661734dc57a08a
--- /dev/null
+++ b/sam2/modeling/sam/transformer.py
@@ -0,0 +1,360 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import contextlib
+import math
+import warnings
+from functools import partial
+from typing import Tuple, Type
+
+import torch
+import torch.nn.functional as F
+from torch import nn, Tensor
+
+from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
+from sam2.modeling.sam2_utils import MLP
+from sam2.utils.misc import get_sdpa_settings
+
+warnings.simplefilter(action="ignore", category=FutureWarning)
+# Check whether Flash Attention is available (and use it by default)
+OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
+# A fallback setting to allow all available kernels if Flash Attention fails
+ALLOW_ALL_KERNELS = False
+
+
+def sdp_kernel_context(dropout_p):
+    """
+    Get the context for the attention scaled dot-product kernel. We use Flash Attention
+    by default, but fall back to all available kernels if Flash Attention fails.
+    """
+    if ALLOW_ALL_KERNELS:
+        return contextlib.nullcontext()
+
+    return torch.backends.cuda.sdp_kernel(
+        enable_flash=USE_FLASH_ATTN,
+        # if Flash attention kernel is off, then math kernel needs to be enabled
+        enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,
+        enable_mem_efficient=OLD_GPU,
+    )
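+
+# Note: `torch.backends.cuda.sdp_kernel` is deprecated on recent PyTorch releases in
+# favour of `torch.nn.attention.sdpa_kernel`; the try/except fallback in `Attention`
+# below also handles environments where the preferred kernel is unavailable.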
+
+
+class TwoWayTransformer(nn.Module):
+    def __init__(
+        self,
+        depth: int,
+        embedding_dim: int,
+        num_heads: int,
+        mlp_dim: int,
+        activation: Type[nn.Module] = nn.ReLU,
+        attention_downsample_rate: int = 2,
+    ) -> None:
+        """
+        A transformer decoder that attends to an input image using
+        queries whose positional embedding is supplied.
+
+        Args:
+          depth (int): number of layers in the transformer
+          embedding_dim (int): the channel dimension for the input embeddings
+          num_heads (int): the number of heads for multihead attention. Must
+            divide embedding_dim
+          mlp_dim (int): the channel dimension internal to the MLP block
+          activation (nn.Module): the activation to use in the MLP block
+        """
+        super().__init__()
+        self.depth = depth
+        self.embedding_dim = embedding_dim
+        self.num_heads = num_heads
+        self.mlp_dim = mlp_dim
+        self.layers = nn.ModuleList()
+
+        for i in range(depth):
+            self.layers.append(
+                TwoWayAttentionBlock(
+                    embedding_dim=embedding_dim,
+                    num_heads=num_heads,
+                    mlp_dim=mlp_dim,
+                    activation=activation,
+                    attention_downsample_rate=attention_downsample_rate,
+                    skip_first_layer_pe=(i == 0),
+                )
+            )
+
+        self.final_attn_token_to_image = Attention(
+            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+        )
+        self.norm_final_attn = nn.LayerNorm(embedding_dim)
+
+    def forward(
+        self,
+        image_embedding: Tensor,
+        image_pe: Tensor,
+        point_embedding: Tensor,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Args:
+          image_embedding (torch.Tensor): image to attend to. Should be shape
+            B x embedding_dim x h x w for any h and w.
+          image_pe (torch.Tensor): the positional encoding to add to the image. Must
+            have the same shape as image_embedding.
+          point_embedding (torch.Tensor): the embedding to add to the query points.
+            Must have shape B x N_points x embedding_dim for any N_points.
+
+        Returns:
+          torch.Tensor: the processed point_embedding
+          torch.Tensor: the processed image_embedding
+        """
+        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
+        bs, c, h, w = image_embedding.shape
+        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
+        image_pe = image_pe.flatten(2).permute(0, 2, 1)
+
+        # Prepare queries
+        queries = point_embedding
+        keys = image_embedding
+
+        # Apply transformer blocks and final layernorm
+        for layer in self.layers:
+            queries, keys = layer(
+                queries=queries,
+                keys=keys,
+                query_pe=point_embedding,
+                key_pe=image_pe,
+            )
+
+        # Apply the final attention layer from the points to the image
+        q = queries + point_embedding
+        k = keys + image_pe
+        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
+        queries = queries + attn_out
+        queries = self.norm_final_attn(queries)
+
+        return queries, keys
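+    # Shape sketch (illustrative): for image_embedding [B, 256, 64, 64], a matching
+    # image_pe, and point_embedding [B, N, 256], the returned queries are [B, N, 256]
+    # (processed tokens) and keys are [B, 64*64, 256] (processed image tokens).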
+
+
+class TwoWayAttentionBlock(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_heads: int,
+        mlp_dim: int = 2048,
+        activation: Type[nn.Module] = nn.ReLU,
+        attention_downsample_rate: int = 2,
+        skip_first_layer_pe: bool = False,
+    ) -> None:
+        """
+        A transformer block with four layers: (1) self-attention of sparse
+        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
+        block on sparse inputs, and (4) cross attention of dense inputs to sparse
+        inputs.
+
+        Arguments:
+          embedding_dim (int): the channel dimension of the embeddings
+          num_heads (int): the number of heads in the attention layers
+          mlp_dim (int): the hidden dimension of the mlp block
+          activation (nn.Module): the activation of the mlp block
+          skip_first_layer_pe (bool): skip the PE on the first layer
+        """
+        super().__init__()
+        self.self_attn = Attention(embedding_dim, num_heads)
+        self.norm1 = nn.LayerNorm(embedding_dim)
+
+        self.cross_attn_token_to_image = Attention(
+            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+        )
+        self.norm2 = nn.LayerNorm(embedding_dim)
+
+        self.mlp = MLP(
+            embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
+        )
+        self.norm3 = nn.LayerNorm(embedding_dim)
+
+        self.norm4 = nn.LayerNorm(embedding_dim)
+        self.cross_attn_image_to_token = Attention(
+            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
+        )
+
+        self.skip_first_layer_pe = skip_first_layer_pe
+
+    def forward(
+        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
+    ) -> Tuple[Tensor, Tensor]:
+        # Self attention block
+        if self.skip_first_layer_pe:
+            queries = self.self_attn(q=queries, k=queries, v=queries)
+        else:
+            q = queries + query_pe
+            attn_out = self.self_attn(q=q, k=q, v=queries)
+            queries = queries + attn_out
+        queries = self.norm1(queries)
+
+        # Cross attention block, tokens attending to image embedding
+        q = queries + query_pe
+        k = keys + key_pe
+        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
+        queries = queries + attn_out
+        queries = self.norm2(queries)
+
+        # MLP block
+        mlp_out = self.mlp(queries)
+        queries = queries + mlp_out
+        queries = self.norm3(queries)
+
+        # Cross attention block, image embedding attending to tokens
+        q = queries + query_pe
+        k = keys + key_pe
+        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
+        keys = keys + attn_out
+        keys = self.norm4(keys)
+
+        return queries, keys
+
+
+class Attention(nn.Module):
+    """
+    An attention layer that allows for downscaling the size of the embedding
+    after projection to queries, keys, and values.
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        num_heads: int,
+        downsample_rate: int = 1,
+        dropout: float = 0.0,
+        kv_in_dim: int = None,
+    ) -> None:
+        super().__init__()
+        self.embedding_dim = embedding_dim
+        self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
+        self.internal_dim = embedding_dim // downsample_rate
+        self.num_heads = num_heads
+        assert (
+            self.internal_dim % num_heads == 0
+        ), "num_heads must divide embedding_dim."
+
+        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
+        self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+        self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
+        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
+
+        self.dropout_p = dropout
+
+    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
+        b, n, c = x.shape
+        x = x.reshape(b, n, num_heads, c // num_heads)
+        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head
+
+    def _recombine_heads(self, x: Tensor) -> Tensor:
+        b, n_heads, n_tokens, c_per_head = x.shape
+        x = x.transpose(1, 2)
+        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C
+
+    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
+        # Input projections
+        q = self.q_proj(q)
+        k = self.k_proj(k)
+        v = self.v_proj(v)
+
+        # Separate into heads
+        q = self._separate_heads(q, self.num_heads)
+        k = self._separate_heads(k, self.num_heads)
+        v = self._separate_heads(v, self.num_heads)
+
+        dropout_p = self.dropout_p if self.training else 0.0
+        # Attention
+        try:
+            with sdp_kernel_context(dropout_p):
+                out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+        except Exception as e:
+            # Fall back to all kernels if the Flash attention kernel fails
+            warnings.warn(
+                f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
+                f"kernels for scaled_dot_product_attention (which may have a slower speed).",
+                category=UserWarning,
+                stacklevel=2,
+            )
+            global ALLOW_ALL_KERNELS
+            ALLOW_ALL_KERNELS = True
+            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+
+        out = self._recombine_heads(out)
+        out = self.out_proj(out)
+
+        return out
+
+
+class RoPEAttention(Attention):
+    """Attention with rotary position encoding."""
+
+    def __init__(
+        self,
+        *args,
+        rope_theta=10000.0,
+        # whether to repeat q rope to match k length
+        # this is needed for cross-attention to memories
+        rope_k_repeat=False,
+        feat_sizes=(32, 32),  # [w, h] for stride 16 feats at 512 resolution
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+
+        self.compute_cis = partial(
+            compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
+        )
+        freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
+        self.freqs_cis = freqs_cis
+        self.rope_k_repeat = rope_k_repeat
+
+    def forward(
+        self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
+    ) -> Tensor:
+        # Input projections
+        q = self.q_proj(q)
+        k = self.k_proj(k)
+        v = self.v_proj(v)
+
+        # Separate into heads
+        q = self._separate_heads(q, self.num_heads)
+        k = self._separate_heads(k, self.num_heads)
+        v = self._separate_heads(v, self.num_heads)
+
+        # Apply rotary position encoding
+        w = h = math.sqrt(q.shape[-2])
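+        # (assumes a square token grid, i.e. q.shape[-2] must be a perfect square)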
+        self.freqs_cis = self.freqs_cis.to(q.device)
+        if self.freqs_cis.shape[0] != q.shape[-2]:
+            self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
+        if q.shape[-2] != k.shape[-2]:
+            assert self.rope_k_repeat
+
+        num_k_rope = k.size(-2) - num_k_exclude_rope
+        q, k[:, :, :num_k_rope] = apply_rotary_enc(
+            q,
+            k[:, :, :num_k_rope],
+            freqs_cis=self.freqs_cis,
+            repeat_freqs_k=self.rope_k_repeat,
+        )
+
+        dropout_p = self.dropout_p if self.training else 0.0
+        # Attention
+        try:
+            with sdp_kernel_context(dropout_p):
+                out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+        except Exception as e:
+            # Fall back to all kernels if the Flash attention kernel fails
+            warnings.warn(
+                f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
+                f"kernels for scaled_dot_product_attention (which may have a slower speed).",
+                category=UserWarning,
+                stacklevel=2,
+            )
+            global ALLOW_ALL_KERNELS
+            ALLOW_ALL_KERNELS = True
+            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
+
+        out = self._recombine_heads(out)
+        out = self.out_proj(out)
+
+        return out
diff --git a/sam2/modeling/sam2_base.py b/sam2/modeling/sam2_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..b79524c80f1453b96361c9ffeede6ad08fe87bf0
--- /dev/null
+++ b/sam2/modeling/sam2_base.py
@@ -0,0 +1,864 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.distributed
+import torch.nn.functional as F
+
+from torch.nn.init import trunc_normal_
+
+from sam2.modeling.sam.mask_decoder import MaskDecoder, MaskDecoderMattingToken
+from sam2.modeling.sam.prompt_encoder import PromptEncoder
+from sam2.modeling.sam.transformer import TwoWayTransformer
+from sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames
+
+# a large negative value as a placeholder score for missing objects
+NO_OBJ_SCORE = -1024.0
+
+
+class SAM2Base(torch.nn.Module):
+    def __init__(
+        self,
+        image_encoder,
+        memory_attention,
+        memory_encoder,
+        num_maskmem=7,  # default 1 input frame + 6 previous frames
+        image_size=512,
+        backbone_stride=16,  # stride of the image backbone output
+        sigmoid_scale_for_mem_enc=1.0,  # scale factor for mask sigmoid prob
+        sigmoid_bias_for_mem_enc=0.0,  # bias factor for mask sigmoid prob
+        # During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks
+        binarize_mask_from_pts_for_mem_enc=False,
+        use_mask_input_as_output_without_sam=False,  # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder
+        # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit,
+        # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model
+        # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM.
+        max_cond_frames_in_attn=-1,
+        # on the first frame, whether to directly add the no-memory embedding to the image feature
+        # (instead of using the transformer encoder)
+        directly_add_no_mem_embed=False,
+        # whether to use high-resolution feature maps in the SAM mask decoder
+        use_high_res_features_in_sam=False,
+        # whether to output multiple (3) masks for the first click on initial conditioning frames
+        multimask_output_in_sam=False,
+        # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`;
+        # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points)
+        multimask_min_pt_num=1,
+        multimask_max_pt_num=1,
+        # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`)
+        multimask_output_for_tracking=False,
+        # Whether to use multimask tokens for obj ptr; Only relevant when both
+        # use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True
+        use_multimask_token_for_obj_ptr: bool = False,
+        # whether to use sigmoid to restrict ious prediction to [0-1]
+        iou_prediction_use_sigmoid=False,
+        # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5).
+        # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of
+        # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame.
+        memory_temporal_stride_for_eval=1,
+        # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click
+        # if `add_all_frames_to_correct_as_cond` is False, we restrict the conditioning frame list to only the initial conditioning frames
+        add_all_frames_to_correct_as_cond=False,
+        # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks)
+        non_overlap_masks_for_mem_enc=False,
+        # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+        use_obj_ptrs_in_encoder=False,
+        # the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`)
+        max_obj_ptrs_in_encoder=16,
+        # whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`)
+        add_tpos_enc_to_obj_ptrs=True,
+        # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference
+        # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`)
+        proj_tpos_enc_in_obj_ptrs=False,
+        # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation
+        # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking)
+        only_obj_ptrs_in_the_past_for_eval=False,
+        # Whether to predict if there is an object in the frame
+        pred_obj_scores: bool = False,
+        # Whether to use an MLP to predict object scores
+        pred_obj_scores_mlp: bool = False,
+        # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True;
+        # Whether to have a fixed no obj pointer when there is no object present
+        # or to use it as an additive embedding with obj_ptr produced by decoder
+        fixed_no_obj_ptr: bool = False,
+        # Soft no object, i.e. mix in no_obj_ptr softly,
+        # hope to make recovery easier if there is a mistake and mitigate accumulation of errors
+        soft_no_obj_ptr: bool = False,
+        use_mlp_for_obj_ptr_proj: bool = False,
+        # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class.
+        sam_mask_decoder_extra_args=None,
+        compile_image_encoder: bool = False,
+        # xrh defined
+        mask_decoder_matting_token: bool = False,
+        matting_logits_res_add: bool = False,
+        upscaled_embedding_res_add: bool = True,
+    ):
+        super().__init__()
+        
+        # xrh
+        self.mask_decoder_matting_token = mask_decoder_matting_token
+        self.matting_logits_res_add = matting_logits_res_add
+        self.upscaled_embedding_res_add = upscaled_embedding_res_add
+
+        # Part 1: the image backbone
+        self.image_encoder = image_encoder
+        # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting
+        self.use_high_res_features_in_sam = use_high_res_features_in_sam
+        self.num_feature_levels = 3 if use_high_res_features_in_sam else 1
+        self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder
+        self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder
+        if use_obj_ptrs_in_encoder:
+            # A conv layer to downsample the mask prompt to stride 4 (the same stride as
+            # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale,
+            # so that it can be fed into the SAM mask decoder to generate a pointer.
+            self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4)
+        self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs
+        if proj_tpos_enc_in_obj_ptrs:
+            assert add_tpos_enc_to_obj_ptrs  # these options need to be used together
+        self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs
+        self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval
+
+        # Part 2: memory attention to condition current frame's visual features
+        # with memories (and obj ptrs) from past frames
+        self.memory_attention = memory_attention
+        self.hidden_dim = memory_attention.d_model
+
+        # Part 3: memory encoder for the previous frame's outputs
+        self.memory_encoder = memory_encoder
+        self.mem_dim = self.hidden_dim
+        if hasattr(self.memory_encoder, "out_proj") and hasattr(
+            self.memory_encoder.out_proj, "weight"
+        ):
+            # if there is compression of memories along channel dim
+            self.mem_dim = self.memory_encoder.out_proj.weight.shape[0]
+        self.num_maskmem = num_maskmem  # Number of memories accessible
+        # Temporal encoding of the memories
+        self.maskmem_tpos_enc = torch.nn.Parameter(
+            torch.zeros(num_maskmem, 1, 1, self.mem_dim)
+        )
+        trunc_normal_(self.maskmem_tpos_enc, std=0.02)
+        # a single token to indicate no memory embedding from previous frames
+        self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
+        self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim))
+        trunc_normal_(self.no_mem_embed, std=0.02)
+        trunc_normal_(self.no_mem_pos_enc, std=0.02)
+        self.directly_add_no_mem_embed = directly_add_no_mem_embed
+        # Apply sigmoid to the output raw mask logits (to turn them from
+        # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder
+        self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc
+        self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc
+        self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc
+        self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc
+        self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval
+        # On frames with mask input, whether to directly output the input mask without
+        # using a SAM prompt encoder + mask decoder
+        self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam
+        self.multimask_output_in_sam = multimask_output_in_sam
+        self.multimask_min_pt_num = multimask_min_pt_num
+        self.multimask_max_pt_num = multimask_max_pt_num
+        self.multimask_output_for_tracking = multimask_output_for_tracking
+        self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
+        self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid
+
+        # Part 4: SAM-style prompt encoder (for both mask and point inputs)
+        # and SAM-style mask decoder for the final mask output
+        self.image_size = image_size
+        self.backbone_stride = backbone_stride
+        self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args
+        self.pred_obj_scores = pred_obj_scores
+        self.pred_obj_scores_mlp = pred_obj_scores_mlp
+        self.fixed_no_obj_ptr = fixed_no_obj_ptr
+        self.soft_no_obj_ptr = soft_no_obj_ptr
+        if self.fixed_no_obj_ptr:
+            assert self.pred_obj_scores
+            assert self.use_obj_ptrs_in_encoder
+        if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
+            self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
+            trunc_normal_(self.no_obj_ptr, std=0.02)
+        self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj
+
+        self._build_sam_heads()
+        self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond
+        self.max_cond_frames_in_attn = max_cond_frames_in_attn
+
+        # Model compilation
+        if compile_image_encoder:
+            # Compile the forward function (not the full module) to allow loading checkpoints.
+            print(
+                "Image encoder compilation is enabled. First forward pass will be slow."
+            )
+            self.image_encoder.forward = torch.compile(
+                self.image_encoder.forward,
+                mode="max-autotune",
+                fullgraph=True,
+                dynamic=False,
+            )
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+    def forward(self, *args, **kwargs):
+        raise NotImplementedError(
+            "Please use the corresponding methods in SAM2VideoPredictor for inference."
+            "See notebooks/video_predictor_example.ipynb for an example."
+        )
+
+    def _build_sam_heads(self):
+        """Build SAM-style prompt encoder and mask decoder."""
+        self.sam_prompt_embed_dim = self.hidden_dim
+        self.sam_image_embedding_size = self.image_size // self.backbone_stride
+
+        # build PromptEncoder and MaskDecoder from SAM
+        # (their hyperparameters like `mask_in_chans=16` are from SAM code)
+        self.sam_prompt_encoder = PromptEncoder(
+            embed_dim=self.sam_prompt_embed_dim,
+            image_embedding_size=(
+                self.sam_image_embedding_size,
+                self.sam_image_embedding_size,
+            ),
+            input_image_size=(self.image_size, self.image_size),
+            mask_in_chans=16,
+        )
+
+        # xrh
+        if self.mask_decoder_matting_token:
+            self.sam_mask_decoder = MaskDecoderMattingToken(
+                # xrh
+                upscaled_embedding_res_add=self.upscaled_embedding_res_add,
+                matting_logits_res_add=self.matting_logits_res_add,
+
+                num_multimask_outputs=3,
+                transformer=TwoWayTransformer(
+                    depth=2,
+                    embedding_dim=self.sam_prompt_embed_dim,
+                    mlp_dim=2048,
+                    num_heads=8,
+                ),
+                transformer_dim=self.sam_prompt_embed_dim,
+                iou_head_depth=3,
+                iou_head_hidden_dim=256,
+                use_high_res_features=self.use_high_res_features_in_sam,
+                iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
+                pred_obj_scores=self.pred_obj_scores,
+                pred_obj_scores_mlp=self.pred_obj_scores_mlp,
+                use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
+                **(self.sam_mask_decoder_extra_args or {}),
+            )
+        else:
+            self.sam_mask_decoder = MaskDecoder(
+                num_multimask_outputs=3,
+                transformer=TwoWayTransformer(
+                    depth=2,
+                    embedding_dim=self.sam_prompt_embed_dim,
+                    mlp_dim=2048,
+                    num_heads=8,
+                ),
+                transformer_dim=self.sam_prompt_embed_dim,
+                iou_head_depth=3,
+                iou_head_hidden_dim=256,
+                use_high_res_features=self.use_high_res_features_in_sam,
+                iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid,
+                pred_obj_scores=self.pred_obj_scores,
+                pred_obj_scores_mlp=self.pred_obj_scores_mlp,
+                use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr,
+                **(self.sam_mask_decoder_extra_args or {}),
+            )
+        if self.use_obj_ptrs_in_encoder:
+            # a linear projection on SAM output tokens to turn them into object pointers
+            self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
+            if self.use_mlp_for_obj_ptr_proj:
+                self.obj_ptr_proj = MLP(
+                    self.hidden_dim, self.hidden_dim, self.hidden_dim, 3
+                )
+        else:
+            self.obj_ptr_proj = torch.nn.Identity()
+        if self.proj_tpos_enc_in_obj_ptrs:
+            # a linear projection on temporal positional encoding in object pointers to
+            # avoid potential interference with spatial positional encoding
+            self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim)
+        else:
+            self.obj_ptr_tpos_proj = torch.nn.Identity()
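+        # Configuration sketch (illustrative; the exact SEMat LazyConfig files are not
+        # reproduced here): the matting decoder is selected purely via constructor
+        # flags, e.g.
+        #   SAM2Base(..., mask_decoder_matting_token=True,
+        #            sam_mask_decoder_extra_args={"matting_token_num": 3})
+        # where the extra kwargs are forwarded into MaskDecoderMattingToken.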
+
+    def _forward_sam_heads(
+        self,
+        backbone_features,
+        point_inputs=None,
+        mask_inputs=None,
+        high_res_features=None,
+        multimask_output=False,
+    ):
+        """
+        Forward SAM prompt encoders and mask heads.
+
+        Inputs:
+        - backbone_features: image features of [B, C, H, W] shape
+        - point_inputs: a dictionary with "point_coords" and "point_labels", where
+          1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the
+             absolute pixel-unit coordinate in (x, y) format of the P input points
+          2) "point_labels" has shape [B, P] and int32 dtype, where 1 means
+             positive clicks, 0 means negative clicks, and -1 means padding
+        - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the
+          same spatial size as the image.
+        - high_res_features: either 1) None or 2) a list of length 2 containing
+          two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively,
+          which will be used as high-resolution feature maps for SAM decoder.
+        - multimask_output: if it's True, we output 3 candidate masks and their 3
+          corresponding IoU estimates, and if it's False, we output only 1 mask and
+          its corresponding IoU estimate.
+
+        Outputs:
+        - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if
+          `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM
+          output mask logits (before sigmoid) for the low-resolution masks, with 4x
+          the resolution (1/4 stride) of the input backbone_features.
+        - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3
+          if `multimask_output=True` and M = 1 if `multimask_output=False`),
+          upsampled from the low-resolution masks to the same spatial size as the image
+          (stride is 1 pixel).
+        - ious: [B, M] shape (where M = 3 if `multimask_output=True` and M = 1
+          if `multimask_output=False`), the estimated IoU of each output mask.
+        - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`.
+          If `multimask_output=True`, it's the mask with the highest IoU estimate.
+          If `multimask_output=False`, it's the same as `low_res_multimasks`.
+        - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`.
+          If `multimask_output=True`, it's the mask with the highest IoU estimate.
+          If `multimask_output=False`, it's the same as `high_res_multimasks`.
+        - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted
+          based on the output token from the SAM mask decoder.
+        """
+        B = backbone_features.size(0)
+        device = backbone_features.device
+        assert backbone_features.size(1) == self.sam_prompt_embed_dim
+        assert backbone_features.size(2) == self.sam_image_embedding_size
+        assert backbone_features.size(3) == self.sam_image_embedding_size
+
+        # a) Handle point prompts
+        if point_inputs is not None:
+            sam_point_coords = point_inputs["point_coords"]
+            sam_point_labels = point_inputs["point_labels"]
+            assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
+        else:
+            # If no points are provided, pad with an empty point (with label -1)
+            sam_point_coords = torch.zeros(B, 1, 2, device=device)
+            sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
+
+        # b) Handle mask prompts
+        if mask_inputs is not None:
+            # If mask_inputs is provided, downsize it into low-res mask input if needed
+            # and feed it as a dense mask prompt into the SAM mask encoder
+            assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
+            if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
+                sam_mask_prompt = F.interpolate(
+                    mask_inputs.float(),
+                    size=self.sam_prompt_encoder.mask_input_size,
+                    align_corners=False,
+                    mode="bilinear",
+                    antialias=True,  # use antialias for downsampling
+                )
+            else:
+                sam_mask_prompt = mask_inputs
+        else:
+            # Otherwise, simply feed None (and SAM's prompt encoder will add
+            # a learned `no_mask_embed` to indicate no mask input in this case).
+            sam_mask_prompt = None
+
+        sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
+            points=(sam_point_coords, sam_point_labels),
+            boxes=None,
+            masks=sam_mask_prompt,
+        )
+        (
+            low_res_multimasks,
+            ious,
+            sam_output_tokens,
+            object_score_logits,
+        ) = self.sam_mask_decoder(
+            image_embeddings=backbone_features,
+            image_pe=self.sam_prompt_encoder.get_dense_pe(),
+            sparse_prompt_embeddings=sparse_embeddings,
+            dense_prompt_embeddings=dense_embeddings,
+            multimask_output=multimask_output,
+            repeat_image=False,  # the image is already batched
+            high_res_features=high_res_features,
+        )
+        if self.pred_obj_scores:
+            is_obj_appearing = object_score_logits > 0
+
+            # Mask used for spatial memories is always a *hard* choice between obj and no obj,
+            # consistent with the actual mask prediction
+            low_res_multimasks = torch.where(
+                is_obj_appearing[:, None, None],
+                low_res_multimasks,
+                NO_OBJ_SCORE,
+            )
+
+        # convert masks from possibly bfloat16 (or float16) to float32
+        # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
+        low_res_multimasks = low_res_multimasks.float()
+        high_res_multimasks = F.interpolate(
+            low_res_multimasks,
+            size=(self.image_size, self.image_size),
+            mode="bilinear",
+            align_corners=False,
+        )
+
+        sam_output_token = sam_output_tokens[:, 0]
+        if multimask_output:
+            # take the best mask prediction (with the highest IoU estimation)
+            best_iou_inds = torch.argmax(ious, dim=-1)
+            batch_inds = torch.arange(B, device=device)
+            low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
+            high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
+            if sam_output_tokens.size(1) > 1:
+                sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
+        else:
+            low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
+
+        # Extract object pointer from the SAM output token (with occlusion handling)
+        obj_ptr = self.obj_ptr_proj(sam_output_token)
+        if self.pred_obj_scores:
+            # Allow *soft* no obj ptr, unlike for masks
+            if self.soft_no_obj_ptr:
+                # Only hard possible with gt
+                assert not self.teacher_force_obj_scores_for_mem
+                lambda_is_obj_appearing = object_score_logits.sigmoid()
+            else:
+                lambda_is_obj_appearing = is_obj_appearing.float()
+
+            if self.fixed_no_obj_ptr:
+                obj_ptr = lambda_is_obj_appearing * obj_ptr
+            obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
+
+        return (
+            low_res_multimasks,
+            high_res_multimasks,
+            ious,
+            low_res_masks,
+            high_res_masks,
+            obj_ptr,
+            object_score_logits,
+        )
+
+    def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs):
+        """
+        Directly turn binary `mask_inputs` into output mask logits without using SAM.
+        (same input and output shapes as in _forward_sam_heads above).
+        """
+        # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid).
+        out_scale, out_bias = 20.0, -10.0  # sigmoid(-10.0)=4.5398e-05
+        mask_inputs_float = mask_inputs.float()
+        high_res_masks = mask_inputs_float * out_scale + out_bias
+        low_res_masks = F.interpolate(
+            high_res_masks,
+            size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4),
+            align_corners=False,
+            mode="bilinear",
+            antialias=True,  # use antialias for downsampling
+        )
+        # a dummy IoU prediction of all 1's under mask input
+        ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float()
+        if not self.use_obj_ptrs_in_encoder:
+            # all zeros as a dummy object pointer (of shape [B, C])
+            obj_ptr = torch.zeros(
+                mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device
+            )
+        else:
+            # produce an object pointer using the SAM decoder from the mask input
+            _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads(
+                backbone_features=backbone_features,
+                mask_inputs=self.mask_downsample(mask_inputs_float),
+                high_res_features=high_res_features,
+            )
+        # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem;
+        # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying
+        # on the object_scores from the SAM decoder.
+        is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1)
+        is_obj_appearing = is_obj_appearing[..., None]
+        lambda_is_obj_appearing = is_obj_appearing.float()
+        object_score_logits = out_scale * lambda_is_obj_appearing + out_bias
+        if self.pred_obj_scores:
+            if self.fixed_no_obj_ptr:
+                obj_ptr = lambda_is_obj_appearing * obj_ptr
+            obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
+
+        return (
+            low_res_masks,
+            high_res_masks,
+            ious,
+            low_res_masks,
+            high_res_masks,
+            obj_ptr,
+            object_score_logits,
+        )
+
+    def forward_image(self, img_batch: torch.Tensor):
+        """Get the image feature on the input batch."""
+        backbone_out = self.image_encoder(img_batch)
+        if self.use_high_res_features_in_sam:
+            # precompute projected level 0 and level 1 features in SAM decoder
+            # to avoid running it again on every SAM click
+            backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(
+                backbone_out["backbone_fpn"][0]
+            )
+            backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(
+                backbone_out["backbone_fpn"][1]
+            )
+        return backbone_out
+
+    def _prepare_backbone_features(self, backbone_out):
+        """Prepare and flatten visual features."""
+        backbone_out = backbone_out.copy()
+        assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
+        assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels
+
+        feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :]
+        vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :]
+
+        feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
+        # flatten NxCxHxW to HWxNxC
+        vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
+        vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds]
+
+        return backbone_out, vision_feats, vision_pos_embeds, feat_sizes
+
+    def _prepare_memory_conditioned_features(
+        self,
+        frame_idx,
+        is_init_cond_frame,
+        current_vision_feats,
+        current_vision_pos_embeds,
+        feat_sizes,
+        output_dict,
+        num_frames,
+        track_in_reverse=False,  # tracking in reverse time order (for demo usage)
+    ):
+        """Fuse the current frame's visual feature map with previous memory."""
+        B = current_vision_feats[-1].size(1)  # batch size on this frame
+        C = self.hidden_dim
+        H, W = feat_sizes[-1]  # top-level (lowest-resolution) feature size
+        device = current_vision_feats[-1].device
+        # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images.
+        # In this case, we skip the fusion with any memory.
+        if self.num_maskmem == 0:  # Disable memory and skip fusion
+            pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
+            return pix_feat
+
+        num_obj_ptr_tokens = 0
+        # Step 1: condition the visual features of the current frame on previous memories
+        if not is_init_cond_frame:
+            # Retrieve the memories encoded with the maskmem backbone
+            to_cat_memory, to_cat_memory_pos_embed = [], []
+            # Add the conditioning frames' outputs first (all cond frames have t_pos=0
+            # when computing the temporal positional embedding below)
+            assert len(output_dict["cond_frame_outputs"]) > 0
+            # Select a maximum number of temporally closest cond frames for cross attention
+            cond_outputs = output_dict["cond_frame_outputs"]
+            selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames(
+                frame_idx, cond_outputs, self.max_cond_frames_in_attn
+            )
+            t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()]
+            # Add the last (self.num_maskmem - 1) frames before the current frame as non-conditioning memory;
+            # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1.
+            # We also allow taking the memory frame non-consecutively (with r>1), in which case
+            # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame.
+            r = self.memory_temporal_stride_for_eval
+            for t_pos in range(1, self.num_maskmem):
+                t_rel = self.num_maskmem - t_pos  # how many frames before current frame
+                if t_rel == 1:
+                    # for t_rel == 1, we take the last frame (regardless of r)
+                    if not track_in_reverse:
+                        # the frame immediately before this frame (i.e. frame_idx - 1)
+                        prev_frame_idx = frame_idx - t_rel
+                    else:
+                        # the frame immediately after this frame (i.e. frame_idx + 1)
+                        prev_frame_idx = frame_idx + t_rel
+                else:
+                    # for t_rel >= 2, we take the memory frame from every r-th frames
+                    if not track_in_reverse:
+                        # first find the nearest frame among every r-th frames before this frame
+                        # for r=1, this would be (frame_idx - 2)
+                        prev_frame_idx = ((frame_idx - 2) // r) * r
+                        # then seek further among every r-th frames
+                        prev_frame_idx = prev_frame_idx - (t_rel - 2) * r
+                    else:
+                        # first find the nearest frame among every r-th frames after this frame
+                        # for r=1, this would be (frame_idx + 2)
+                        prev_frame_idx = -(-(frame_idx + 2) // r) * r
+                        # then seek further among every r-th frames
+                        prev_frame_idx = prev_frame_idx + (t_rel - 2) * r
+                out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None)
+                if out is None:
+                    # If an unselected conditioning frame is among the last (self.num_maskmem - 1)
+                    # frames, we still attend to it as if it's a non-conditioning frame.
+                    out = unselected_cond_outputs.get(prev_frame_idx, None)
+                t_pos_and_prevs.append((t_pos, out))
+
+            for t_pos, prev in t_pos_and_prevs:
+                if prev is None:
+                    continue  # skip padding frames
+                # "maskmem_features" might have been offloaded to CPU in demo use cases,
+                # so we load it back to GPU (it's a no-op if it's already on GPU).
+                feats = prev["maskmem_features"].cuda(non_blocking=True)
+                to_cat_memory.append(feats.flatten(2).permute(2, 0, 1))
+                # Spatial positional encoding (it might have been offloaded to CPU in eval)
+                maskmem_enc = prev["maskmem_pos_enc"][-1].cuda()
+                maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1)
+                # Temporal positional encoding
+                maskmem_enc = (
+                    maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1]
+                )
+                to_cat_memory_pos_embed.append(maskmem_enc)
+
+            # Construct the list of past object pointers
+            if self.use_obj_ptrs_in_encoder:
+                max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder)
+                # First add those object pointers from selected conditioning frames
+                # (optionally, only include object pointers in the past during evaluation)
+                if not self.training and self.only_obj_ptrs_in_the_past_for_eval:
+                    ptr_cond_outputs = {
+                        t: out
+                        for t, out in selected_cond_outputs.items()
+                        if (t >= frame_idx if track_in_reverse else t <= frame_idx)
+                    }
+                else:
+                    ptr_cond_outputs = selected_cond_outputs
+                pos_and_ptrs = [
+                    # Temporal pos encoding contains how far away each pointer is from current frame
+                    (abs(frame_idx - t), out["obj_ptr"])
+                    for t, out in ptr_cond_outputs.items()
+                ]
+                # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame
+                for t_diff in range(1, max_obj_ptrs_in_encoder):
+                    t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff
+                    if t < 0 or (num_frames is not None and t >= num_frames):
+                        break
+                    out = output_dict["non_cond_frame_outputs"].get(
+                        t, unselected_cond_outputs.get(t, None)
+                    )
+                    if out is not None:
+                        pos_and_ptrs.append((t_diff, out["obj_ptr"]))
+                # If we have at least one object pointer, add them to the cross attention
+                if len(pos_and_ptrs) > 0:
+                    pos_list, ptrs_list = zip(*pos_and_ptrs)
+                    # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape
+                    obj_ptrs = torch.stack(ptrs_list, dim=0)
+                    # a temporal positional embedding based on how far each object pointer is from
+                    # the current frame (sine embedding normalized by the max pointer num).
+                    if self.add_tpos_enc_to_obj_ptrs:
+                        t_diff_max = max_obj_ptrs_in_encoder - 1
+                        tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim
+                        obj_pos = torch.tensor(pos_list, device=device)
+                        obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim)
+                        obj_pos = self.obj_ptr_tpos_proj(obj_pos)
+                        obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim)
+                    else:
+                        obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim)
+                    if self.mem_dim < C:
+                        # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C
+                        obj_ptrs = obj_ptrs.reshape(
+                            -1, B, C // self.mem_dim, self.mem_dim
+                        )
+                        obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1)
+                        obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0)
+                    to_cat_memory.append(obj_ptrs)
+                    to_cat_memory_pos_embed.append(obj_pos)
+                    num_obj_ptr_tokens = obj_ptrs.shape[0]
+                else:
+                    num_obj_ptr_tokens = 0
+        else:
+            # for initial conditioning frames, encode them without using any previous memory
+            if self.directly_add_no_mem_embed:
+                # directly add no-mem embedding (instead of using the transformer encoder)
+                pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed
+                pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
+                return pix_feat_with_mem
+
+            # Use a dummy token on the first frame (to avoid an empty memory input to the transformer encoder)
+            to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)]
+            to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)]
+
+        # Step 2: Concatenate the memories and forward through the transformer encoder
+        memory = torch.cat(to_cat_memory, dim=0)
+        memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0)
+
+        pix_feat_with_mem = self.memory_attention(
+            curr=current_vision_feats,
+            curr_pos=current_vision_pos_embeds,
+            memory=memory,
+            memory_pos=memory_pos_embed,
+            num_obj_ptr_tokens=num_obj_ptr_tokens,
+        )
+        # reshape the output (HW)BC => BCHW
+        pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W)
+        return pix_feat_with_mem
+
+    def _encode_new_memory(
+        self,
+        current_vision_feats,
+        feat_sizes,
+        pred_masks_high_res,
+        is_mask_from_pts,
+    ):
+        """Encode the current image and its prediction into a memory feature."""
+        B = current_vision_feats[-1].size(1)  # batch size on this frame
+        C = self.hidden_dim
+        H, W = feat_sizes[-1]  # top-level (lowest-resolution) feature size
+        # top-level feature, (HW)BC => BCHW
+        pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W)
+        if self.non_overlap_masks_for_mem_enc and not self.training:
+            # optionally, apply non-overlapping constraints to the masks (it's applied
+            # in the batch dimension and should only be used during eval, where all
+            # the objects come from the same video under batch size 1).
+            pred_masks_high_res = self._apply_non_overlapping_constraints(
+                pred_masks_high_res
+            )
+        # scale the raw mask logits with a temperature before applying sigmoid
+        binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts
+        if binarize and not self.training:
+            mask_for_mem = (pred_masks_high_res > 0).float()
+        else:
+            # apply sigmoid on the raw mask logits to turn them into range (0, 1)
+            mask_for_mem = torch.sigmoid(pred_masks_high_res)
+        # apply scale and bias terms to the sigmoid probabilities
+        if self.sigmoid_scale_for_mem_enc != 1.0:
+            mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc
+        if self.sigmoid_bias_for_mem_enc != 0.0:
+            mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc
+        maskmem_out = self.memory_encoder(
+            pix_feat, mask_for_mem, skip_mask_sigmoid=True  # sigmoid already applied
+        )
+        maskmem_features = maskmem_out["vision_features"]
+        maskmem_pos_enc = maskmem_out["vision_pos_enc"]
+
+        return maskmem_features, maskmem_pos_enc
+
+    def track_step(
+        self,
+        frame_idx,
+        is_init_cond_frame,
+        current_vision_feats,
+        current_vision_pos_embeds,
+        feat_sizes,
+        point_inputs,
+        mask_inputs,
+        output_dict,
+        num_frames,
+        track_in_reverse=False,  # tracking in reverse time order (for demo usage)
+        # Whether to run the memory encoder on the predicted masks. Sometimes we might want
+        # to skip the memory encoder with `run_mem_encoder=False`. For example,
+        # in demo we might call `track_step` multiple times for each user click,
+        # and only encode the memory when the user finalizes their clicks. And in ablation
+        # settings like SAM training on static images, we don't need the memory encoder.
+        run_mem_encoder=True,
+        # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
+        prev_sam_mask_logits=None,
+    ):
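+        """Run a single tracking step on one frame: fuse the frame's features with the
+        memory bank, run the SAM heads (or use the provided mask input directly as the
+        output), and optionally encode the prediction into a new memory feature."""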
+        current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
+        # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
+        if len(current_vision_feats) > 1:
+            high_res_features = [
+                x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
+                for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
+            ]
+        else:
+            high_res_features = None
+        if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
+            # When use_mask_input_as_output_without_sam=True, we directly output the mask input
+            # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
+            pix_feat = current_vision_feats[-1].permute(1, 2, 0)
+            pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
+            sam_outputs = self._use_mask_as_output(
+                pix_feat, high_res_features, mask_inputs
+            )
+        else:
+            # fuse the visual feature with previous memory features in the memory bank
+            pix_feat_with_mem = self._prepare_memory_conditioned_features(
+                frame_idx=frame_idx,
+                is_init_cond_frame=is_init_cond_frame,
+                current_vision_feats=current_vision_feats[-1:],
+                current_vision_pos_embeds=current_vision_pos_embeds[-1:],
+                feat_sizes=feat_sizes[-1:],
+                output_dict=output_dict,
+                num_frames=num_frames,
+                track_in_reverse=track_in_reverse,
+            )
+            # apply SAM-style segmentation head
+            # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
+            # e.g. in demo where such logits come from earlier interaction instead of correction sampling
+            # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
+            if prev_sam_mask_logits is not None:
+                assert point_inputs is not None and mask_inputs is None
+                mask_inputs = prev_sam_mask_logits
+            multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
+            sam_outputs = self._forward_sam_heads(
+                backbone_features=pix_feat_with_mem,
+                point_inputs=point_inputs,
+                mask_inputs=mask_inputs,
+                high_res_features=high_res_features,
+                multimask_output=multimask_output,
+            )
+        (
+            _,
+            _,
+            _,
+            low_res_masks,
+            high_res_masks,
+            obj_ptr,
+            _,
+        ) = sam_outputs
+
+        current_out["pred_masks"] = low_res_masks
+        current_out["pred_masks_high_res"] = high_res_masks
+        current_out["obj_ptr"] = obj_ptr
+
+        # Finally run the memory encoder on the predicted mask to encode
+        # it into a new memory feature (that can be used in future frames)
+        if run_mem_encoder and self.num_maskmem > 0:
+            high_res_masks_for_mem_enc = high_res_masks
+            maskmem_features, maskmem_pos_enc = self._encode_new_memory(
+                current_vision_feats=current_vision_feats,
+                feat_sizes=feat_sizes,
+                pred_masks_high_res=high_res_masks_for_mem_enc,
+                is_mask_from_pts=(point_inputs is not None),
+            )
+            current_out["maskmem_features"] = maskmem_features
+            current_out["maskmem_pos_enc"] = maskmem_pos_enc
+        else:
+            current_out["maskmem_features"] = None
+            current_out["maskmem_pos_enc"] = None
+
+        return current_out
+
+    def _use_multimask(self, is_init_cond_frame, point_inputs):
+        """Whether to use multimask output in the SAM head."""
+        num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
+        multimask_output = (
+            self.multimask_output_in_sam
+            and (is_init_cond_frame or self.multimask_output_for_tracking)
+            and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num)
+        )
+        return multimask_output
+
+    def _apply_non_overlapping_constraints(self, pred_masks):
+        """
+        Apply non-overlapping constraints to the object scores in pred_masks. Here we
+        keep only the highest scoring object at each spatial location in pred_masks.
+        """
+        batch_size = pred_masks.size(0)
+        if batch_size == 1:
+            return pred_masks
+
+        device = pred_masks.device
+        # "max_obj_inds": object index of the object with the highest score at each location
+        max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True)
+        # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks`
+        batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None]
+        keep = max_obj_inds == batch_obj_inds
+        # suppress overlapping regions' scores below -10.0 so that the foreground regions
+        # don't overlap (here sigmoid(-10.0)=4.5398e-05)
+        pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
+        return pred_masks
diff --git a/sam2/modeling/sam2_utils.py b/sam2/modeling/sam2_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d9705963efc57d74b7d1bff31692d7d293a46ad
--- /dev/null
+++ b/sam2/modeling/sam2_utils.py
@@ -0,0 +1,149 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import copy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
+    """
+    Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs`
+    that are temporally closest to the current frame at `frame_idx`. Here, we take
+    - a) the closest conditioning frame before `frame_idx` (if any);
+    - b) the closest conditioning frame after `frame_idx` (if any);
+    - c) any other temporally closest conditioning frames until reaching a total
+         of `max_cond_frame_num` conditioning frames.
+
+    Outputs:
+    - selected_outputs: selected items (keys & values) from `cond_frame_outputs`.
+    - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`.
+    """
+    if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
+        selected_outputs = cond_frame_outputs
+        unselected_outputs = {}
+    else:
+        assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
+        selected_outputs = {}
+
+        # the closest conditioning frame before `frame_idx` (if any)
+        idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
+        if idx_before is not None:
+            selected_outputs[idx_before] = cond_frame_outputs[idx_before]
+
+        # the closest conditioning frame after `frame_idx` (if any)
+        idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
+        if idx_after is not None:
+            selected_outputs[idx_after] = cond_frame_outputs[idx_after]
+
+        # add other temporally closest conditioning frames until reaching a total
+        # of `max_cond_frame_num` conditioning frames.
+        num_remain = max_cond_frame_num - len(selected_outputs)
+        inds_remain = sorted(
+            (t for t in cond_frame_outputs if t not in selected_outputs),
+            key=lambda x: abs(x - frame_idx),
+        )[:num_remain]
+        selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
+        unselected_outputs = {
+            t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs
+        }
+
+    return selected_outputs, unselected_outputs
+
+
+def get_1d_sine_pe(pos_inds, dim, temperature=10000):
+    """
+    Get 1D sine positional embedding as in the original Transformer paper.
+    """
+    pe_dim = dim // 2
+    dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
+    dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
+
+    pos_embed = pos_inds.unsqueeze(-1) / dim_t
+    pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
+    return pos_embed
+
+
+def get_activation_fn(activation):
+    """Return an activation function given a string"""
+    if activation == "relu":
+        return F.relu
+    if activation == "gelu":
+        return F.gelu
+    if activation == "glu":
+        return F.glu
+    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
+
+
+def get_clones(module, N):
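+    """Return an nn.ModuleList containing N deep copies of `module`."""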
+    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+class DropPath(nn.Module):
+    # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
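+    """Stochastic depth: randomly drops the residual path per sample during training,
+    optionally rescaling the kept samples by 1 / keep_prob."""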
+    def __init__(self, drop_prob=0.0, scale_by_keep=True):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def forward(self, x):
+        if self.drop_prob == 0.0 or not self.training:
+            return x
+        keep_prob = 1 - self.drop_prob
+        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+        if keep_prob > 0.0 and self.scale_by_keep:
+            random_tensor.div_(keep_prob)
+        return x * random_tensor
+
+
+# Lightly adapted from
+# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
+class MLP(nn.Module):
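+    """A simple multi-layer perceptron (feed-forward network) with a configurable number
+    of layers and activation, and an optional sigmoid on the output."""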
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dim: int,
+        output_dim: int,
+        num_layers: int,
+        activation: nn.Module = nn.ReLU,
+        sigmoid_output: bool = False,
+    ) -> None:
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(
+            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+        )
+        self.sigmoid_output = sigmoid_output
+        self.act = activation()
+
+    def forward(self, x):
+        for i, layer in enumerate(self.layers):
+            x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
+        if self.sigmoid_output:
+            x = F.sigmoid(x)
+        return x
+
+
+# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
+# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119  # noqa
+class LayerNorm2d(nn.Module):
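+    """LayerNorm applied over the channel dimension of NCHW tensors."""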
+    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(num_channels))
+        self.bias = nn.Parameter(torch.zeros(num_channels))
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        u = x.mean(1, keepdim=True)
+        s = (x - u).pow(2).mean(1, keepdim=True)
+        x = (x - u) / torch.sqrt(s + self.eps)
+        x = self.weight[:, None, None] * x + self.bias[:, None, None]
+        return x
diff --git a/sam2/sam2_image_predictor.py b/sam2/sam2_image_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..14843fec0f791be2572d69dd8020d10e0b0669f2
--- /dev/null
+++ b/sam2/sam2_image_predictor.py
@@ -0,0 +1,517 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from PIL.Image import Image
+
+from sam2.modeling.sam2_base import SAM2Base
+
+from sam2.utils.transforms import SAM2Transforms
+
+
+class SAM2ImagePredictor:
+    def __init__(
+        self,
+        sam_model: SAM2Base,
+        mask_threshold=0.0,
+        max_hole_area=0.0,
+        max_sprinkle_area=0.0,
+    ) -> None:
+        """
+        Uses SAM-2 to calculate the image embedding for an image, and then
+        allow repeated, efficient mask prediction given prompts.
+
+        Arguments:
+          sam_model (SAM2Base): The SAM-2 model to use for mask prediction.
+          mask_threshold (float): The threshold to use when converting mask logits
+            to binary masks. Masks are thresholded at 0 by default.
+          max_hole_area (float): If max_hole_area > 0, we fill small holes of up to
+            this maximum area in low_res_masks.
+          max_sprinkle_area (float): If max_sprinkle_area > 0, we remove small sprinkles
+            of up to this maximum area in low_res_masks.
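+
+        Example (a minimal sketch; assumes `sam_model` is an already-built SAM-2 model
+        and `image` is an HxWx3 RGB uint8 array):
+
+            predictor = SAM2ImagePredictor(sam_model)
+            predictor.set_image(image)
+            masks, ious, low_res_logits = predictor.predict(box=np.array([10, 10, 200, 200]))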
+        """
+        super().__init__()
+        self.model = sam_model
+        self._transforms = SAM2Transforms(
+            resolution=self.model.image_size,
+            mask_threshold=mask_threshold,
+            max_hole_area=max_hole_area,
+            max_sprinkle_area=max_sprinkle_area,
+        )
+
+        # Predictor state
+        self._is_image_set = False
+        self._features = None
+        self._orig_hw = None
+        # Whether the predictor is set for single image or a batch of images
+        self._is_batch = False
+
+        # Predictor config
+        self.mask_threshold = mask_threshold
+
+        # Spatial dim for backbone feature maps
+        self._bb_feat_sizes = [
+            (256, 256),
+            (128, 128),
+            (64, 64),
+        ]
+
+    @classmethod
+    def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor":
+        """
+        Load a pretrained model from the Hugging Face hub.
+
+        Arguments:
+          model_id (str): The Hugging Face repository ID.
+          **kwargs: Additional arguments to pass to the model constructor.
+
+        Returns:
+          (SAM2ImagePredictor): The loaded model.
+        """
+        from sam2.build_sam import build_sam2_hf
+
+        sam_model = build_sam2_hf(model_id, **kwargs)
+        return cls(sam_model)
+
+    @torch.no_grad()
+    def set_image(
+        self,
+        image: Union[np.ndarray, Image],
+    ) -> None:
+        """
+        Calculates the image embeddings for the provided image, allowing
+        masks to be predicted with the 'predict' method.
+
+        Arguments:
+          image (np.ndarray or PIL Image): The input image to embed, in RGB format with
+            pixel values in [0, 255]. The image should be in HWC format if np.ndarray,
+            or WHC format if PIL Image.
+        """
+        self.reset_predictor()
+        # Transform the image to the form expected by the model
+        if isinstance(image, np.ndarray):
+            logging.info("For numpy array image, we assume (HxWxC) format")
+            self._orig_hw = [image.shape[:2]]
+        elif isinstance(image, Image):
+            w, h = image.size
+            self._orig_hw = [(h, w)]
+        else:
+            raise NotImplementedError("Image format not supported")
+
+        input_image = self._transforms(image)
+        input_image = input_image[None, ...].to(self.device)
+
+        assert (
+            len(input_image.shape) == 4 and input_image.shape[1] == 3
+        ), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
+        logging.info("Computing image embeddings for the provided image...")
+        backbone_out = self.model.forward_image(input_image)
+        _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
+        # Add no_mem_embed, which is added to the lowest-resolution feature map during training on videos
+        if self.model.directly_add_no_mem_embed:
+            vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
+
+        feats = [
+            feat.permute(1, 2, 0).view(1, -1, *feat_size)
+            for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
+        ][::-1]
+        self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
+        self._is_image_set = True
+        logging.info("Image embeddings computed.")
+
+    @torch.no_grad()
+    def set_image_batch(
+        self,
+        image_list: List[np.ndarray],
+    ) -> None:
+        """
+        Calculates the image embeddings for the provided image batch, allowing
+        masks to be predicted with the 'predict_batch' method.
+
+        Arguments:
+          image_list (List[np.ndarray]): The input images to embed, in RGB format. Each
+            image should be an np.ndarray in HWC format with pixel values in [0, 255].
+        """
+        self.reset_predictor()
+        assert isinstance(image_list, list)
+        self._orig_hw = []
+        for image in image_list:
+            assert isinstance(
+                image, np.ndarray
+            ), "Images are expected to be np.ndarray in RGB format, of shape HWC"
+            self._orig_hw.append(image.shape[:2])
+        # Transform the image to the form expected by the model
+        img_batch = self._transforms.forward_batch(image_list)
+        img_batch = img_batch.to(self.device)
+        batch_size = img_batch.shape[0]
+        assert (
+            len(img_batch.shape) == 4 and img_batch.shape[1] == 3
+        ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}"
+        logging.info("Computing image embeddings for the provided images...")
+        backbone_out = self.model.forward_image(img_batch)
+        _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out)
+        # Add no_mem_embed, which is added to the lowest-resolution feature map during training on videos
+        if self.model.directly_add_no_mem_embed:
+            vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed
+
+        feats = [
+            feat.permute(1, 2, 0).view(batch_size, -1, *feat_size)
+            for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
+        ][::-1]
+        self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}
+        self._is_image_set = True
+        self._is_batch = True
+        logging.info("Image embeddings computed.")
+
+    def predict_batch(
+        self,
+        point_coords_batch: List[np.ndarray] = None,
+        point_labels_batch: List[np.ndarray] = None,
+        box_batch: List[np.ndarray] = None,
+        mask_input_batch: List[np.ndarray] = None,
+        multimask_output: bool = True,
+        return_logits: bool = False,
+        normalize_coords=True,
+    ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
+        """This function is very similar to predict(...); however, it is used in batched mode, when the model is expected to generate predictions on multiple images.
+        It returns a tuple of lists of masks, ious, and low-res mask logits.
+        """
+        assert self._is_batch, "This function should only be used when in batched mode"
+        if not self._is_image_set:
+            raise RuntimeError(
+                "An image must be set with .set_image_batch(...) before mask prediction."
+            )
+        num_images = len(self._features["image_embed"])
+        all_masks = []
+        all_ious = []
+        all_low_res_masks = []
+        for img_idx in range(num_images):
+            # Transform input prompts
+            point_coords = (
+                point_coords_batch[img_idx] if point_coords_batch is not None else None
+            )
+            point_labels = (
+                point_labels_batch[img_idx] if point_labels_batch is not None else None
+            )
+            box = box_batch[img_idx] if box_batch is not None else None
+            mask_input = (
+                mask_input_batch[img_idx] if mask_input_batch is not None else None
+            )
+            mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
+                point_coords,
+                point_labels,
+                box,
+                mask_input,
+                normalize_coords,
+                img_idx=img_idx,
+            )
+            masks, iou_predictions, low_res_masks = self._predict(
+                unnorm_coords,
+                labels,
+                unnorm_box,
+                mask_input,
+                multimask_output,
+                return_logits=return_logits,
+                img_idx=img_idx,
+            )
+            masks_np = masks.squeeze(0).float().detach().cpu().numpy()
+            iou_predictions_np = (
+                iou_predictions.squeeze(0).float().detach().cpu().numpy()
+            )
+            low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
+            all_masks.append(masks_np)
+            all_ious.append(iou_predictions_np)
+            all_low_res_masks.append(low_res_masks_np)
+
+        return all_masks, all_ious, all_low_res_masks
+
+    def predict(
+        self,
+        point_coords: Optional[np.ndarray] = None,
+        point_labels: Optional[np.ndarray] = None,
+        box: Optional[np.ndarray] = None,
+        mask_input: Optional[np.ndarray] = None,
+        multimask_output: bool = True,
+        return_logits: bool = False,
+        normalize_coords=True,
+    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+        """
+        Predict masks for the given input prompts, using the currently set image.
+
+        Arguments:
+          point_coords (np.ndarray or None): A Nx2 array of point prompts to the
+            model. Each point is in (X,Y) in pixels.
+          point_labels (np.ndarray or None): A length N array of labels for the
+            point prompts. 1 indicates a foreground point and 0 indicates a
+            background point.
+          box (np.ndarray or None): A length 4 array given a box prompt to the
+            model, in XYXY format.
+          mask_input (np.ndarray): A low resolution mask input to the model, typically
+            coming from a previous prediction iteration. Has form 1xHxW, where
+            for SAM, H=W=256.
+          multimask_output (bool): If true, the model will return three masks.
+            For ambiguous input prompts (such as a single click), this will often
+            produce better masks than a single prediction. If only a single
+            mask is needed, the model's predicted quality score can be used
+            to select the best mask. For non-ambiguous prompts, such as multiple
+            input prompts, multimask_output=False can give better results.
+          return_logits (bool): If true, returns un-thresholded mask logits
+            instead of a binary mask.
+          normalize_coords (bool): If true, the point coordinates will be normalized to
+            the range [0, 1]; in that case point_coords is expected to be given with
+            respect to the original image dimensions.
+
+        Returns:
+          (np.ndarray): The output masks in CxHxW format, where C is the
+            number of masks, and (H, W) is the original image size.
+          (np.ndarray): An array of length C containing the model's
+            predictions for the quality of each mask.
+          (np.ndarray): An array of shape CxHxW, where C is the number
+            of masks and H=W=256. These low resolution logits can be passed to
+            a subsequent iteration as mask input.
+        """
+        if not self._is_image_set:
+            raise RuntimeError(
+                "An image must be set with .set_image(...) before mask prediction."
+            )
+
+        # Transform input prompts
+
+        mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts(
+            point_coords, point_labels, box, mask_input, normalize_coords
+        )
+
+        masks, iou_predictions, low_res_masks = self._predict(
+            unnorm_coords,
+            labels,
+            unnorm_box,
+            mask_input,
+            multimask_output,
+            return_logits=return_logits,
+        )
+
+        masks_np = masks.squeeze(0).float().detach().cpu().numpy()
+        iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy()
+        low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy()
+        return masks_np, iou_predictions_np, low_res_masks_np
+
+    def _prep_prompts(
+        self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1
+    ):
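+        """Convert raw prompts (points, labels, box, mask logits) into tensors on the
+        model device and transform coordinates into the model's input frame."""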
+
+        unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None
+        if point_coords is not None:
+            assert (
+                point_labels is not None
+            ), "point_labels must be supplied if point_coords is supplied."
+            point_coords = torch.as_tensor(
+                point_coords, dtype=torch.float, device=self.device
+            )
+            unnorm_coords = self._transforms.transform_coords(
+                point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
+            )
+            labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
+            if len(unnorm_coords.shape) == 2:
+                unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]
+        if box is not None:
+            box = torch.as_tensor(box, dtype=torch.float, device=self.device)
+            unnorm_box = self._transforms.transform_boxes(
+                box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx]
+            )  # Bx2x2
+        if mask_logits is not None:
+            mask_input = torch.as_tensor(
+                mask_logits, dtype=torch.float, device=self.device
+            )
+            if len(mask_input.shape) == 3:
+                mask_input = mask_input[None, :, :, :]
+        return mask_input, unnorm_coords, labels, unnorm_box
+
+    @torch.no_grad()
+    def _predict(
+        self,
+        point_coords: Optional[torch.Tensor],
+        point_labels: Optional[torch.Tensor],
+        boxes: Optional[torch.Tensor] = None,
+        mask_input: Optional[torch.Tensor] = None,
+        multimask_output: bool = True,
+        return_logits: bool = False,
+        img_idx: int = -1,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Predict masks for the given input prompts, using the currently set image.
+        Input prompts are batched torch tensors and are expected to already be
+        transformed to the input frame using SAM2Transforms.
+
+        Arguments:
+          point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
+            model. Each point is in (X,Y) in pixels.
+          point_labels (torch.Tensor or None): A BxN array of labels for the
+            point prompts. 1 indicates a foreground point and 0 indicates a
+            background point.
+          boxes (torch.Tensor or None): A Bx4 array given a box prompt to the
+            model, in XYXY format.
+          mask_input (torch.Tensor): A low resolution mask input to the model, typically
+            coming from a previous prediction iteration. Has form Bx1xHxW, where
+            for SAM, H=W=256. Masks returned by a previous iteration of the
+            predict method do not need further transformation.
+          multimask_output (bool): If true, the model will return three masks.
+            For ambiguous input prompts (such as a single click), this will often
+            produce better masks than a single prediction. If only a single
+            mask is needed, the model's predicted quality score can be used
+            to select the best mask. For non-ambiguous prompts, such as multiple
+            input prompts, multimask_output=False can give better results.
+          return_logits (bool): If true, returns un-thresholded mask logits
+            instead of a binary mask.
+
+        Returns:
+          (torch.Tensor): The output masks in BxCxHxW format, where C is the
+            number of masks, and (H, W) is the original image size.
+          (torch.Tensor): An array of shape BxC containing the model's
+            predictions for the quality of each mask.
+          (torch.Tensor): An array of shape BxCxHxW, where C is the number
+            of masks and H=W=256. These low res logits can be passed to
+            a subsequent iteration as mask input.
+        """
+        if not self._is_image_set:
+            raise RuntimeError(
+                "An image must be set with .set_image(...) before mask prediction."
+            )
+
+        if point_coords is not None:
+            concat_points = (point_coords, point_labels)
+        else:
+            concat_points = None
+
+        # Embed prompts
+        if boxes is not None:
+            box_coords = boxes.reshape(-1, 2, 2)
+            box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
+            box_labels = box_labels.repeat(boxes.size(0), 1)
+            # we merge "boxes" and "points" into a single "concat_points" input (where
+            # boxes are added at the beginning) to sam_prompt_encoder
+            if concat_points is not None:
+                concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
+                concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
+                concat_points = (concat_coords, concat_labels)
+            else:
+                concat_points = (box_coords, box_labels)
+
+        sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
+            points=concat_points,
+            boxes=None,
+            masks=mask_input,
+        )
+
+        # Predict masks
+        batched_mode = (
+            concat_points is not None and concat_points[0].shape[0] > 1
+        )  # multi object prediction
+        high_res_features = [
+            feat_level[img_idx].unsqueeze(0)
+            for feat_level in self._features["high_res_feats"]
+        ]
+        low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder(
+            image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0),
+            image_pe=self.model.sam_prompt_encoder.get_dense_pe(),
+            sparse_prompt_embeddings=sparse_embeddings,
+            dense_prompt_embeddings=dense_embeddings,
+            multimask_output=multimask_output,
+            repeat_image=batched_mode,
+            high_res_features=high_res_features,
+        )
+
+        # Upscale the masks to the original image resolution
+        masks = self._transforms.postprocess_masks(
+            low_res_masks, self._orig_hw[img_idx]
+        )
+        low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
+        if not return_logits:
+            masks = masks > self.mask_threshold
+
+        return masks, iou_predictions, low_res_masks
+
+    def get_image_embedding(self) -> torch.Tensor:
+        """
+        Returns the image embeddings for the currently set image, with
+        shape 1xCxHxW, where C is the embedding dimension and (H,W) are
+        the embedding spatial dimension of SAM (typically C=256, H=W=64).
+        """
+        if not self._is_image_set:
+            raise RuntimeError(
+                "An image must be set with .set_image(...) to generate an embedding."
+            )
+        assert (
+            self._features is not None
+        ), "Features must exist if an image has been set."
+        return self._features["image_embed"]
+
+    @property
+    def device(self) -> torch.device:
+        return self.model.device
+
+    def reset_predictor(self) -> None:
+        """
+        Resets the image embeddings and other state variables.
+        """
+        self._is_image_set = False
+        self._features = None
+        self._orig_hw = None
+        self._is_batch = False
+
+    def predict_batch_boxes_and_features(
+        self,
+        boxes,
+        features,
+        multimask_output: bool = False,
+        return_logits: bool = True,
+        wo_matting_token: bool = False,
+    ):
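+        """
+        Predict mask logits for a batch of box prompts from precomputed image features,
+        running the SAM mask decoder one image at a time. Returns the SAM-2 logits, and
+        additionally the matting-token logits unless `wo_matting_token` is True.
+        """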
+
+        box_coords = boxes.reshape(-1, 2, 2)
+        box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
+        box_labels = box_labels.repeat(boxes.size(0), 1)
+        # we merge "boxes" and "points" into a single "concat_points" input (where
+        # boxes are added at the beginning) to sam_prompt_encoder
+        concat_points = (box_coords, box_labels)
+
+        image_pe = self.model.sam_prompt_encoder.get_dense_pe()
+        sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder(
+            points=concat_points,
+            boxes=None,
+            masks=None,
+        )
+
+        batch_size = boxes.shape[0]
+        cat_sam2_logits = []
+        cat_matting_logits = []
+        for bs_idx in range(batch_size):
+
+            high_res_features = [
+                feat_level[bs_idx].unsqueeze(0)
+                for feat_level in features["high_res_feats"]
+            ]
+
+            sam2_logits, matting_logits = self.model.sam_mask_decoder(
+                image_embeddings=features["image_embed"][bs_idx].unsqueeze(0),
+                image_pe=image_pe,
+                sparse_prompt_embeddings=sparse_embeddings[bs_idx].unsqueeze(0),
+                dense_prompt_embeddings=dense_embeddings[bs_idx].unsqueeze(0),
+                multimask_output=multimask_output,
+                repeat_image=False,
+                high_res_features=high_res_features,
+                wo_matting_token=wo_matting_token,
+            )
+            cat_sam2_logits.append(sam2_logits)
+            if not wo_matting_token:
+                cat_matting_logits.append(matting_logits)
+
+        sam2_logits = torch.concatenate(cat_sam2_logits, dim=0)
+        if not wo_matting_token:
+            matting_logits = torch.concatenate(cat_matting_logits, dim=0)
+            return sam2_logits, matting_logits
+        else:
+            return sam2_logits
\ No newline at end of file
diff --git a/sam2/sam2_video_predictor.py b/sam2/sam2_video_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5a6bdf4b5742c3214e655657abd4b6bb9b7ede2
--- /dev/null
+++ b/sam2/sam2_video_predictor.py
@@ -0,0 +1,957 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import warnings
+from collections import OrderedDict
+
+import torch
+
+from tqdm import tqdm
+
+from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
+from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
+
+
+class SAM2VideoPredictor(SAM2Base):
+    """The predictor class to handle user interactions and manage inference states."""
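+    # Typical interactive flow (a minimal sketch; `video_path`, `points`, and `labels`
+    # are placeholders supplied by the caller):
+    #   state = predictor.init_state(video_path)
+    #   predictor.add_new_points_or_box(state, frame_idx=0, obj_id=1,
+    #                                   points=points, labels=labels)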
+
+    def __init__(
+        self,
+        fill_hole_area=0,
+        # whether to apply non-overlapping constraints on the output object masks
+        non_overlap_masks=False,
+        # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
+        # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True
+        clear_non_cond_mem_around_input=False,
+        # whether to also clear non-conditioning memory of the surrounding frames in the multi-object case (only effective when `clear_non_cond_mem_around_input` is True).
+        clear_non_cond_mem_for_multi_obj=False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.fill_hole_area = fill_hole_area
+        self.non_overlap_masks = non_overlap_masks
+        self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
+        self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
+
+    @torch.inference_mode()
+    def init_state(
+        self,
+        video_path,
+        offload_video_to_cpu=False,
+        offload_state_to_cpu=False,
+        async_loading_frames=False,
+    ):
+        """Initialize an inference state."""
+        images, video_height, video_width = load_video_frames(
+            video_path=video_path,
+            image_size=self.image_size,
+            offload_video_to_cpu=offload_video_to_cpu,
+            async_loading_frames=async_loading_frames,
+        )
+        inference_state = {}
+        inference_state["images"] = images
+        inference_state["num_frames"] = len(images)
+        # whether to offload the video frames to CPU memory
+        # turning on this option saves the GPU memory with only a very small overhead
+        inference_state["offload_video_to_cpu"] = offload_video_to_cpu
+        # whether to offload the inference state to CPU memory
+        # turning on this option saves the GPU memory at the cost of a lower tracking fps
+        # (e.g. in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
+        # and from 24 to 21 when tracking two objects)
+        inference_state["offload_state_to_cpu"] = offload_state_to_cpu
+        # the original video height and width, used for resizing final output scores
+        inference_state["video_height"] = video_height
+        inference_state["video_width"] = video_width
+        inference_state["device"] = torch.device("cuda")
+        if offload_state_to_cpu:
+            inference_state["storage_device"] = torch.device("cpu")
+        else:
+            inference_state["storage_device"] = torch.device("cuda")
+        # inputs on each frame
+        inference_state["point_inputs_per_obj"] = {}
+        inference_state["mask_inputs_per_obj"] = {}
+        # visual features on a small number of recently visited frames for quick interactions
+        inference_state["cached_features"] = {}
+        # values that don't change across frames (so we only need to hold one copy of them)
+        inference_state["constants"] = {}
+        # mapping between client-side object id and model-side object index
+        inference_state["obj_id_to_idx"] = OrderedDict()
+        inference_state["obj_idx_to_id"] = OrderedDict()
+        inference_state["obj_ids"] = []
+        # A storage to hold the model's tracking results and states on each frame
+        inference_state["output_dict"] = {
+            "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
+            "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
+        }
+        # Slice (view) of each object tracking results, sharing the same memory with "output_dict"
+        inference_state["output_dict_per_obj"] = {}
+        # A temporary storage to hold new outputs when user interact with a frame
+        # to add clicks or mask (it's merged into "output_dict" before propagation starts)
+        inference_state["temp_output_dict_per_obj"] = {}
+        # Frames that already hold consolidated outputs from click or mask inputs
+        # (we directly use their consolidated outputs during tracking)
+        inference_state["consolidated_frame_inds"] = {
+            "cond_frame_outputs": set(),  # set containing frame indices
+            "non_cond_frame_outputs": set(),  # set containing frame indices
+        }
+        # metadata for each tracking frame (e.g. which direction it's tracked)
+        inference_state["tracking_has_started"] = False
+        inference_state["frames_already_tracked"] = {}
+        # Warm up the visual backbone and cache the image feature on frame 0
+        self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
+        return inference_state
+
+    @classmethod
+    def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor":
+        """
+        Load a pretrained model from the Hugging Face hub.
+
+        Arguments:
+          model_id (str): The Hugging Face repository ID.
+          **kwargs: Additional arguments to pass to the model constructor.
+
+        Returns:
+          (SAM2VideoPredictor): The loaded model.
+        """
+        from sam2.build_sam import build_sam2_video_predictor_hf
+
+        sam_model = build_sam2_video_predictor_hf(model_id, **kwargs)
+        return cls(sam_model)
+
+    def _obj_id_to_idx(self, inference_state, obj_id):
+        """Map client-side object id to model-side object index."""
+        obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
+        if obj_idx is not None:
+            return obj_idx
+
+        # This is a new object id that has not been seen before. We only allow adding
+        # new objects *before* tracking starts.
+        allow_new_object = not inference_state["tracking_has_started"]
+        if allow_new_object:
+            # get the next object slot
+            obj_idx = len(inference_state["obj_id_to_idx"])
+            inference_state["obj_id_to_idx"][obj_id] = obj_idx
+            inference_state["obj_idx_to_id"][obj_idx] = obj_id
+            inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"])
+            # set up input and output structures for this object
+            inference_state["point_inputs_per_obj"][obj_idx] = {}
+            inference_state["mask_inputs_per_obj"][obj_idx] = {}
+            inference_state["output_dict_per_obj"][obj_idx] = {
+                "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
+                "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
+            }
+            inference_state["temp_output_dict_per_obj"][obj_idx] = {
+                "cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
+                "non_cond_frame_outputs": {},  # dict containing {frame_idx: <out>}
+            }
+            return obj_idx
+        else:
+            raise RuntimeError(
+                f"Cannot add new object id {obj_id} after tracking starts. "
+                f"All existing object ids: {inference_state['obj_ids']}. "
+                f"Please call 'reset_state' to restart from scratch."
+            )
+
+    def _obj_idx_to_id(self, inference_state, obj_idx):
+        """Map model-side object index to client-side object id."""
+        return inference_state["obj_idx_to_id"][obj_idx]
+
+    def _get_obj_num(self, inference_state):
+        """Get the total number of unique object ids received so far in this session."""
+        return len(inference_state["obj_idx_to_id"])
+
+    @torch.inference_mode()
+    def add_new_points_or_box(
+        self,
+        inference_state,
+        frame_idx,
+        obj_id,
+        points=None,
+        labels=None,
+        clear_old_points=True,
+        normalize_coords=True,
+        box=None,
+    ):
+        """Add new points to a frame."""
+        obj_idx = self._obj_id_to_idx(inference_state, obj_id)
+        point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
+        mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
+
+        if (points is not None) != (labels is not None):
+            raise ValueError("points and labels must be provided together")
+        if points is None and box is None:
+            raise ValueError("at least one of points or box must be provided as input")
+
+        if points is None:
+            points = torch.zeros(0, 2, dtype=torch.float32)
+        elif not isinstance(points, torch.Tensor):
+            points = torch.tensor(points, dtype=torch.float32)
+        if labels is None:
+            labels = torch.zeros(0, dtype=torch.int32)
+        elif not isinstance(labels, torch.Tensor):
+            labels = torch.tensor(labels, dtype=torch.int32)
+        if points.dim() == 2:
+            points = points.unsqueeze(0)  # add batch dimension
+        if labels.dim() == 1:
+            labels = labels.unsqueeze(0)  # add batch dimension
+
+        # If `box` is provided, we add it as the first two points with labels 2 and 3
+        # along with the user-provided points (consistent with how SAM 2 is trained).
+        if box is not None:
+            if not clear_old_points:
+                raise ValueError(
+                    "cannot add box without clearing old points, since "
+                    "box prompt must be provided before any point prompt "
+                    "(please use clear_old_points=True instead)"
+                )
+            if inference_state["tracking_has_started"]:
+                warnings.warn(
+                    "You are adding a box after tracking starts. SAM 2 may not always be "
+                    "able to incorporate a box prompt for *refinement*. If you intend to "
+                    "use box prompt as an *initial* input before tracking, please call "
+                    "'reset_state' on the inference state to restart from scratch.",
+                    category=UserWarning,
+                    stacklevel=2,
+                )
+            if not isinstance(box, torch.Tensor):
+                box = torch.tensor(box, dtype=torch.float32, device=points.device)
+            box_coords = box.reshape(1, 2, 2)
+            box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device)
+            box_labels = box_labels.reshape(1, 2)
+            points = torch.cat([box_coords, points], dim=1)
+            labels = torch.cat([box_labels, labels], dim=1)
+
+        if normalize_coords:
+            video_H = inference_state["video_height"]
+            video_W = inference_state["video_width"]
+            points = points / torch.tensor([video_W, video_H]).to(points.device)
+        # scale the (normalized) coordinates by the model's internal image size
+        points = points * self.image_size
+        points = points.to(inference_state["device"])
+        labels = labels.to(inference_state["device"])
+
+        if not clear_old_points:
+            point_inputs = point_inputs_per_frame.get(frame_idx, None)
+        else:
+            point_inputs = None
+        point_inputs = concat_points(point_inputs, points, labels)
+
+        point_inputs_per_frame[frame_idx] = point_inputs
+        mask_inputs_per_frame.pop(frame_idx, None)
+        # If this frame hasn't been tracked before, we treat it as an initial conditioning
+        # frame, meaning that the input points are used to generate segments on this frame
+        # without using any memory from other frames, as in SAM. Otherwise (if it has been
+        # tracked), the input points will be used to correct the already tracked masks.
+        is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
+        # whether to track in reverse time order
+        if is_init_cond_frame:
+            reverse = False
+        else:
+            reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
+        obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+        obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
+        # Add a frame to conditioning output if it's an initial conditioning frame or
+        # if the model sees all frames receiving clicks/mask as conditioning frames.
+        is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
+        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+
+        # Get any previously predicted mask logits on this object and feed it along with
+        # the new clicks into the SAM mask decoder.
+        prev_sam_mask_logits = None
+        # lookup temporary output dict first, which contains the most recent output
+        # (if not found, then lookup conditioning and non-conditioning frame output)
+        prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
+        if prev_out is None:
+            prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
+            if prev_out is None:
+                prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
+
+        if prev_out is not None and prev_out["pred_masks"] is not None:
+            prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True)
+            # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
+            prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
+        current_out, _ = self._run_single_frame_inference(
+            inference_state=inference_state,
+            output_dict=obj_output_dict,  # run on the slice of a single object
+            frame_idx=frame_idx,
+            batch_size=1,  # run on the slice of a single object
+            is_init_cond_frame=is_init_cond_frame,
+            point_inputs=point_inputs,
+            mask_inputs=None,
+            reverse=reverse,
+            # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
+            # at the beginning of `propagate_in_video` (after the user finalizes their clicks). This
+            # allows us to enforce non-overlapping constraints on all objects before encoding
+            # them into memory.
+            run_mem_encoder=False,
+            prev_sam_mask_logits=prev_sam_mask_logits,
+        )
+        # Add the output to the output dict (to be used as future memory)
+        obj_temp_output_dict[storage_key][frame_idx] = current_out
+
+        # Resize the output mask to the original video resolution
+        obj_ids = inference_state["obj_ids"]
+        consolidated_out = self._consolidate_temp_output_across_obj(
+            inference_state,
+            frame_idx,
+            is_cond=is_cond,
+            run_mem_encoder=False,
+            consolidate_at_video_res=True,
+        )
+        _, video_res_masks = self._get_orig_video_res_output(
+            inference_state, consolidated_out["pred_masks_video_res"]
+        )
+        return frame_idx, obj_ids, video_res_masks
+
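+    # A minimal usage sketch (hedged; the frame index, object id, and coordinates below are
+    # illustrative values only, and `np` is assumed to be numpy imported by the caller):
+    #
+    #   points = np.array([[210.0, 350.0]], dtype=np.float32)  # one (x, y) click
+    #   labels = np.array([1], dtype=np.int32)                 # 1 = positive, 0 = negative click
+    #   frame_idx, obj_ids, masks = predictor.add_new_points_or_box(
+    #       inference_state, frame_idx=0, obj_id=1, points=points, labels=labels
+    #   )
+    #   # or prompt with a box in (x1, y1, x2, y2) form instead:
+    #   # ... = predictor.add_new_points_or_box(inference_state, 0, 1, box=np.array([100, 100, 400, 400]))
+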
+    def add_new_points(self, *args, **kwargs):
+        """Deprecated method. Please use `add_new_points_or_box` instead."""
+        return self.add_new_points_or_box(*args, **kwargs)
+
+    @torch.inference_mode()
+    def add_new_mask(
+        self,
+        inference_state,
+        frame_idx,
+        obj_id,
+        mask,
+    ):
+        """Add new mask to a frame."""
+        obj_idx = self._obj_id_to_idx(inference_state, obj_id)
+        point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
+        mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
+
+        if not isinstance(mask, torch.Tensor):
+            mask = torch.tensor(mask, dtype=torch.bool)
+        assert mask.dim() == 2
+        mask_H, mask_W = mask.shape
+        mask_inputs_orig = mask[None, None]  # add batch and channel dimension
+        mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"])
+
+        # resize the mask if it doesn't match the model's image size
+        if mask_H != self.image_size or mask_W != self.image_size:
+            mask_inputs = torch.nn.functional.interpolate(
+                mask_inputs_orig,
+                size=(self.image_size, self.image_size),
+                align_corners=False,
+                mode="bilinear",
+                antialias=True,  # use antialias for downsampling
+            )
+            mask_inputs = (mask_inputs >= 0.5).float()
+        else:
+            mask_inputs = mask_inputs_orig
+
+        mask_inputs_per_frame[frame_idx] = mask_inputs
+        point_inputs_per_frame.pop(frame_idx, None)
+        # If this frame hasn't been tracked before, we treat it as an initial conditioning
+        # frame, meaning that the input mask is used to generate segments on this frame
+        # without using any memory from other frames, as in SAM. Otherwise (if it has been
+        # tracked), the input mask will be used to correct the already tracked masks.
+        is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"]
+        # whether to track in reverse time order
+        if is_init_cond_frame:
+            reverse = False
+        else:
+            reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"]
+        obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+        obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
+        # Add a frame to conditioning output if it's an initial conditioning frame or
+        # if the model sees all frames receiving clicks/mask as conditioning frames.
+        is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond
+        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+
+        current_out, _ = self._run_single_frame_inference(
+            inference_state=inference_state,
+            output_dict=obj_output_dict,  # run on the slice of a single object
+            frame_idx=frame_idx,
+            batch_size=1,  # run on the slice of a single object
+            is_init_cond_frame=is_init_cond_frame,
+            point_inputs=None,
+            mask_inputs=mask_inputs,
+            reverse=reverse,
+            # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
+            # at the beginning of `propagate_in_video` (after the user finalizes their clicks). This
+            # allows us to enforce non-overlapping constraints on all objects before encoding
+            # them into memory.
+            run_mem_encoder=False,
+        )
+        # Add the output to the output dict (to be used as future memory)
+        obj_temp_output_dict[storage_key][frame_idx] = current_out
+
+        # Resize the output mask to the original video resolution
+        obj_ids = inference_state["obj_ids"]
+        consolidated_out = self._consolidate_temp_output_across_obj(
+            inference_state,
+            frame_idx,
+            is_cond=is_cond,
+            run_mem_encoder=False,
+            consolidate_at_video_res=True,
+        )
+        _, video_res_masks = self._get_orig_video_res_output(
+            inference_state, consolidated_out["pred_masks_video_res"]
+        )
+        return frame_idx, obj_ids, video_res_masks
+
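+    # A minimal usage sketch (hedged; the mask below is a toy binary array at the original
+    # video resolution, and `np` is assumed to be numpy imported by the caller):
+    #
+    #   mask = np.zeros((video_height, video_width), dtype=bool)
+    #   mask[100:300, 200:400] = True
+    #   frame_idx, obj_ids, masks = predictor.add_new_mask(
+    #       inference_state, frame_idx=0, obj_id=1, mask=mask
+    #   )
+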
+    def _get_orig_video_res_output(self, inference_state, any_res_masks):
+        """
+        Resize the object scores to the original video resolution (video_res_masks)
+        and apply non-overlapping constraints for final output.
+        """
+        device = inference_state["device"]
+        video_H = inference_state["video_height"]
+        video_W = inference_state["video_width"]
+        any_res_masks = any_res_masks.to(device, non_blocking=True)
+        if any_res_masks.shape[-2:] == (video_H, video_W):
+            video_res_masks = any_res_masks
+        else:
+            video_res_masks = torch.nn.functional.interpolate(
+                any_res_masks,
+                size=(video_H, video_W),
+                mode="bilinear",
+                align_corners=False,
+            )
+        if self.non_overlap_masks:
+            video_res_masks = self._apply_non_overlapping_constraints(video_res_masks)
+        return any_res_masks, video_res_masks
+
+    def _consolidate_temp_output_across_obj(
+        self,
+        inference_state,
+        frame_idx,
+        is_cond,
+        run_mem_encoder,
+        consolidate_at_video_res=False,
+    ):
+        """
+        Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on
+        a frame into a single output for all objects, including
+        1) fill any missing objects either from `output_dict_per_obj` (if they exist in
+           `output_dict_per_obj` for this frame) or leave them as placeholder values
+           (if they don't exist in `output_dict_per_obj` for this frame);
+        2) if specified, rerun the memory encoder after applying non-overlapping constraints
+           on the object scores.
+        """
+        batch_size = self._get_obj_num(inference_state)
+        storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+        # Optionally, we allow consolidating the temporary outputs at the original
+        # video resolution (to provide a better editing experience for mask prompts).
+        if consolidate_at_video_res:
+            assert not run_mem_encoder, "memory encoder cannot run at video resolution"
+            consolidated_H = inference_state["video_height"]
+            consolidated_W = inference_state["video_width"]
+            consolidated_mask_key = "pred_masks_video_res"
+        else:
+            consolidated_H = consolidated_W = self.image_size // 4
+            consolidated_mask_key = "pred_masks"
+
+        # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc"
+        # will be added when rerunning the memory encoder after applying non-overlapping
+        # constraints to object scores. Its "pred_masks" are prefilled with a large
+        # negative value (NO_OBJ_SCORE) to represent missing objects.
+        consolidated_out = {
+            "maskmem_features": None,
+            "maskmem_pos_enc": None,
+            consolidated_mask_key: torch.full(
+                size=(batch_size, 1, consolidated_H, consolidated_W),
+                fill_value=NO_OBJ_SCORE,
+                dtype=torch.float32,
+                device=inference_state["storage_device"],
+            ),
+            "obj_ptr": torch.full(
+                size=(batch_size, self.hidden_dim),
+                fill_value=NO_OBJ_SCORE,
+                dtype=torch.float32,
+                device=inference_state["device"],
+            ),
+        }
+        empty_mask_ptr = None
+        for obj_idx in range(batch_size):
+            obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx]
+            obj_output_dict = inference_state["output_dict_per_obj"][obj_idx]
+            out = obj_temp_output_dict[storage_key].get(frame_idx, None)
+            # If the object doesn't appear in "temp_output_dict_per_obj" on this frame,
+            # we fall back and look up its previous output in "output_dict_per_obj".
+            # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in
+            # "output_dict_per_obj" to find a previous output for this object.
+            if out is None:
+                out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None)
+            if out is None:
+                out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None)
+            # If the object doesn't appear in "output_dict_per_obj" either, we skip it
+            # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE
+            # placeholder above) and set its object pointer to be a dummy pointer.
+            if out is None:
+                # Fill in dummy object pointers for those objects without any inputs or
+                # tracking outcomes on this frame (only do it under `run_mem_encoder=True`,
+                # i.e. when we need to build the memory for tracking).
+                if run_mem_encoder:
+                    if empty_mask_ptr is None:
+                        empty_mask_ptr = self._get_empty_mask_ptr(
+                            inference_state, frame_idx
+                        )
+                    # fill object pointer with a dummy pointer (based on an empty mask)
+                    consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr
+                continue
+            # Add the temporary object output mask to consolidated output mask
+            obj_mask = out["pred_masks"]
+            consolidated_pred_masks = consolidated_out[consolidated_mask_key]
+            if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]:
+                consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask
+            else:
+                # Resize first if temporary object mask has a different resolution
+                resized_obj_mask = torch.nn.functional.interpolate(
+                    obj_mask,
+                    size=consolidated_pred_masks.shape[-2:],
+                    mode="bilinear",
+                    align_corners=False,
+                )
+                consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask
+            consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"]
+
+        # Optionally, apply non-overlapping constraints on the consolidated scores
+        # and rerun the memory encoder
+        if run_mem_encoder:
+            device = inference_state["device"]
+            high_res_masks = torch.nn.functional.interpolate(
+                consolidated_out["pred_masks"].to(device, non_blocking=True),
+                size=(self.image_size, self.image_size),
+                mode="bilinear",
+                align_corners=False,
+            )
+            if self.non_overlap_masks_for_mem_enc:
+                high_res_masks = self._apply_non_overlapping_constraints(high_res_masks)
+            maskmem_features, maskmem_pos_enc = self._run_memory_encoder(
+                inference_state=inference_state,
+                frame_idx=frame_idx,
+                batch_size=batch_size,
+                high_res_masks=high_res_masks,
+                is_mask_from_pts=True,  # these frames are what the user interacted with
+            )
+            consolidated_out["maskmem_features"] = maskmem_features
+            consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc
+
+        return consolidated_out
+
+    def _get_empty_mask_ptr(self, inference_state, frame_idx):
+        """Get a dummy object pointer based on an empty mask on the current frame."""
+        # A dummy (empty) mask with a single object
+        batch_size = 1
+        mask_inputs = torch.zeros(
+            (batch_size, 1, self.image_size, self.image_size),
+            dtype=torch.float32,
+            device=inference_state["device"],
+        )
+
+        # Retrieve correct image features
+        (
+            _,
+            _,
+            current_vision_feats,
+            current_vision_pos_embeds,
+            feat_sizes,
+        ) = self._get_image_feature(inference_state, frame_idx, batch_size)
+
+        # Feed the empty mask and image feature above to get a dummy object pointer
+        current_out = self.track_step(
+            frame_idx=frame_idx,
+            is_init_cond_frame=True,
+            current_vision_feats=current_vision_feats,
+            current_vision_pos_embeds=current_vision_pos_embeds,
+            feat_sizes=feat_sizes,
+            point_inputs=None,
+            mask_inputs=mask_inputs,
+            output_dict={},
+            num_frames=inference_state["num_frames"],
+            track_in_reverse=False,
+            run_mem_encoder=False,
+            prev_sam_mask_logits=None,
+        )
+        return current_out["obj_ptr"]
+
+    @torch.inference_mode()
+    def propagate_in_video_preflight(self, inference_state):
+        """Prepare inference_state and consolidate temporary outputs before tracking."""
+        # Tracking has started and we don't allow adding new objects until the session is reset.
+        inference_state["tracking_has_started"] = True
+        batch_size = self._get_obj_num(inference_state)
+
+        # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and
+        # add them into "output_dict".
+        temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
+        output_dict = inference_state["output_dict"]
+        # "consolidated_frame_inds" contains indices of those frames where consolidated
+        # temporary outputs have been added (either in this call or any previous calls
+        # to `propagate_in_video_preflight`).
+        consolidated_frame_inds = inference_state["consolidated_frame_inds"]
+        for is_cond in [False, True]:
+            # Separately consolidate conditioning and non-conditioning temp outputs
+            storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+            # Find all the frames that contain temporary outputs for any objects
+            # (these should be the frames that have just received clicks or mask inputs
+            # via `add_new_points_or_box` or `add_new_mask`)
+            temp_frame_inds = set()
+            for obj_temp_output_dict in temp_output_dict_per_obj.values():
+                temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
+            consolidated_frame_inds[storage_key].update(temp_frame_inds)
+            # consolidate the temporary outputs across all objects on this frame
+            for frame_idx in temp_frame_inds:
+                consolidated_out = self._consolidate_temp_output_across_obj(
+                    inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
+                )
+                # merge them into "output_dict" and also create per-object slices
+                output_dict[storage_key][frame_idx] = consolidated_out
+                self._add_output_per_object(
+                    inference_state, frame_idx, consolidated_out, storage_key
+                )
+                clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
+                    self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
+                )
+                if clear_non_cond_mem:
+                    # clear non-conditioning memory of the surrounding frames
+                    self._clear_non_cond_mem_around_input(inference_state, frame_idx)
+
+            # clear temporary outputs in `temp_output_dict_per_obj`
+            for obj_temp_output_dict in temp_output_dict_per_obj.values():
+                obj_temp_output_dict[storage_key].clear()
+
+        # edge case: if an output is added to "cond_frame_outputs", we remove any prior
+        # output on the same frame in "non_cond_frame_outputs"
+        for frame_idx in output_dict["cond_frame_outputs"]:
+            output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
+        for obj_output_dict in inference_state["output_dict_per_obj"].values():
+            for frame_idx in obj_output_dict["cond_frame_outputs"]:
+                obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
+        for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
+            assert frame_idx in output_dict["cond_frame_outputs"]
+            consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
+
+        # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
+        # with either points or mask inputs (which should be true under a correct workflow).
+        all_consolidated_frame_inds = (
+            consolidated_frame_inds["cond_frame_outputs"]
+            | consolidated_frame_inds["non_cond_frame_outputs"]
+        )
+        input_frames_inds = set()
+        for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
+            input_frames_inds.update(point_inputs_per_frame.keys())
+        for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
+            input_frames_inds.update(mask_inputs_per_frame.keys())
+        assert all_consolidated_frame_inds == input_frames_inds
+
+    @torch.inference_mode()
+    def propagate_in_video(
+        self,
+        inference_state,
+        start_frame_idx=None,
+        max_frame_num_to_track=None,
+        reverse=False,
+    ):
+        """Propagate the input points across frames to track in the entire video."""
+        self.propagate_in_video_preflight(inference_state)
+
+        output_dict = inference_state["output_dict"]
+        consolidated_frame_inds = inference_state["consolidated_frame_inds"]
+        obj_ids = inference_state["obj_ids"]
+        num_frames = inference_state["num_frames"]
+        batch_size = self._get_obj_num(inference_state)
+        if len(output_dict["cond_frame_outputs"]) == 0:
+            raise RuntimeError("No points are provided; please add points first")
+        clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
+            self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
+        )
+
+        # set start index, end index, and processing order
+        if start_frame_idx is None:
+            # default: start from the earliest frame with input points
+            start_frame_idx = min(output_dict["cond_frame_outputs"])
+        if max_frame_num_to_track is None:
+            # default: track all the frames in the video
+            max_frame_num_to_track = num_frames
+        if reverse:
+            end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
+            if start_frame_idx > 0:
+                processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
+            else:
+                processing_order = []  # skip reverse tracking if starting from frame 0
+        else:
+            end_frame_idx = min(
+                start_frame_idx + max_frame_num_to_track, num_frames - 1
+            )
+            processing_order = range(start_frame_idx, end_frame_idx + 1)
+
+        for frame_idx in tqdm(processing_order, desc="propagate in video"):
+            # We skip those frames already in consolidated outputs (these are frames
+            # that received input clicks or masks). Note that we cannot directly run
+            # batched forward on them via `_run_single_frame_inference` because the
+            # number of clicks on each object might be different.
+            if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
+                storage_key = "cond_frame_outputs"
+                current_out = output_dict[storage_key][frame_idx]
+                pred_masks = current_out["pred_masks"]
+                if clear_non_cond_mem:
+                    # clear non-conditioning memory of the surrounding frames
+                    self._clear_non_cond_mem_around_input(inference_state, frame_idx)
+            elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
+                storage_key = "non_cond_frame_outputs"
+                current_out = output_dict[storage_key][frame_idx]
+                pred_masks = current_out["pred_masks"]
+            else:
+                storage_key = "non_cond_frame_outputs"
+                current_out, pred_masks = self._run_single_frame_inference(
+                    inference_state=inference_state,
+                    output_dict=output_dict,
+                    frame_idx=frame_idx,
+                    batch_size=batch_size,
+                    is_init_cond_frame=False,
+                    point_inputs=None,
+                    mask_inputs=None,
+                    reverse=reverse,
+                    run_mem_encoder=True,
+                )
+                output_dict[storage_key][frame_idx] = current_out
+            # Create slices of per-object outputs for subsequent interaction with each
+            # individual object after tracking.
+            self._add_output_per_object(
+                inference_state, frame_idx, current_out, storage_key
+            )
+            inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse}
+
+            # Resize the output mask to the original video resolution (we directly use
+            # the mask scores on GPU for output to avoid any CPU conversion in between)
+            _, video_res_masks = self._get_orig_video_res_output(
+                inference_state, pred_masks
+            )
+            yield frame_idx, obj_ids, video_res_masks
+
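+    # A minimal usage sketch (hedged): collect binary masks per frame by thresholding the
+    # returned mask logits at 0.
+    #
+    #   video_segments = {}
+    #   for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(inference_state):
+    #       video_segments[frame_idx] = {
+    #           obj_id: (mask_logits[i] > 0.0).cpu().numpy() for i, obj_id in enumerate(obj_ids)
+    #       }
+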
+    def _add_output_per_object(
+        self, inference_state, frame_idx, current_out, storage_key
+    ):
+        """
+        Split a multi-object output into per-object output slices and add them into
+        `output_dict_per_obj`. The resulting slices share the same tensor storage.
+        """
+        maskmem_features = current_out["maskmem_features"]
+        assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor)
+
+        maskmem_pos_enc = current_out["maskmem_pos_enc"]
+        assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list)
+
+        output_dict_per_obj = inference_state["output_dict_per_obj"]
+        for obj_idx, obj_output_dict in output_dict_per_obj.items():
+            obj_slice = slice(obj_idx, obj_idx + 1)
+            obj_out = {
+                "maskmem_features": None,
+                "maskmem_pos_enc": None,
+                "pred_masks": current_out["pred_masks"][obj_slice],
+                "obj_ptr": current_out["obj_ptr"][obj_slice],
+            }
+            if maskmem_features is not None:
+                obj_out["maskmem_features"] = maskmem_features[obj_slice]
+            if maskmem_pos_enc is not None:
+                obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc]
+            obj_output_dict[storage_key][frame_idx] = obj_out
+
+    @torch.inference_mode()
+    def reset_state(self, inference_state):
+        """Remove all input points or mask in all frames throughout the video."""
+        self._reset_tracking_results(inference_state)
+        # Remove all object ids
+        inference_state["obj_id_to_idx"].clear()
+        inference_state["obj_idx_to_id"].clear()
+        inference_state["obj_ids"].clear()
+        inference_state["point_inputs_per_obj"].clear()
+        inference_state["mask_inputs_per_obj"].clear()
+        inference_state["output_dict_per_obj"].clear()
+        inference_state["temp_output_dict_per_obj"].clear()
+
+    def _reset_tracking_results(self, inference_state):
+        """Reset all tracking inputs and results across the videos."""
+        for v in inference_state["point_inputs_per_obj"].values():
+            v.clear()
+        for v in inference_state["mask_inputs_per_obj"].values():
+            v.clear()
+        for v in inference_state["output_dict_per_obj"].values():
+            v["cond_frame_outputs"].clear()
+            v["non_cond_frame_outputs"].clear()
+        for v in inference_state["temp_output_dict_per_obj"].values():
+            v["cond_frame_outputs"].clear()
+            v["non_cond_frame_outputs"].clear()
+        inference_state["output_dict"]["cond_frame_outputs"].clear()
+        inference_state["output_dict"]["non_cond_frame_outputs"].clear()
+        inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear()
+        inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear()
+        inference_state["tracking_has_started"] = False
+        inference_state["frames_already_tracked"].clear()
+
+    def _get_image_feature(self, inference_state, frame_idx, batch_size):
+        """Compute the image features on a given frame."""
+        # Look up in the cache first
+        image, backbone_out = inference_state["cached_features"].get(
+            frame_idx, (None, None)
+        )
+        if backbone_out is None:
+            # Cache miss -- we will run inference on a single image
+            image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0)
+            backbone_out = self.forward_image(image)
+            # Cache the most recent frame's feature (for repeated interactions with
+            # a frame; we can use an LRU cache for more frames in the future).
+            inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
+
+        # expand the features to have the same dimension as the number of objects
+        expanded_image = image.expand(batch_size, -1, -1, -1)
+        expanded_backbone_out = {
+            "backbone_fpn": backbone_out["backbone_fpn"].copy(),
+            "vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
+        }
+        for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
+            expanded_backbone_out["backbone_fpn"][i] = feat.expand(
+                batch_size, -1, -1, -1
+            )
+        for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
+            pos = pos.expand(batch_size, -1, -1, -1)
+            expanded_backbone_out["vision_pos_enc"][i] = pos
+
+        features = self._prepare_backbone_features(expanded_backbone_out)
+        features = (expanded_image,) + features
+        return features
+
+    def _run_single_frame_inference(
+        self,
+        inference_state,
+        output_dict,
+        frame_idx,
+        batch_size,
+        is_init_cond_frame,
+        point_inputs,
+        mask_inputs,
+        reverse,
+        run_mem_encoder,
+        prev_sam_mask_logits=None,
+    ):
+        """Run tracking on a single frame based on current inputs and previous memory."""
+        # Retrieve correct image features
+        (
+            _,
+            _,
+            current_vision_feats,
+            current_vision_pos_embeds,
+            feat_sizes,
+        ) = self._get_image_feature(inference_state, frame_idx, batch_size)
+
+        # point and mask should not appear as input simultaneously on the same frame
+        assert point_inputs is None or mask_inputs is None
+        current_out = self.track_step(
+            frame_idx=frame_idx,
+            is_init_cond_frame=is_init_cond_frame,
+            current_vision_feats=current_vision_feats,
+            current_vision_pos_embeds=current_vision_pos_embeds,
+            feat_sizes=feat_sizes,
+            point_inputs=point_inputs,
+            mask_inputs=mask_inputs,
+            output_dict=output_dict,
+            num_frames=inference_state["num_frames"],
+            track_in_reverse=reverse,
+            run_mem_encoder=run_mem_encoder,
+            prev_sam_mask_logits=prev_sam_mask_logits,
+        )
+
+        # optionally offload the output to CPU memory to save GPU space
+        storage_device = inference_state["storage_device"]
+        maskmem_features = current_out["maskmem_features"]
+        if maskmem_features is not None:
+            maskmem_features = maskmem_features.to(torch.bfloat16)
+            maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
+        pred_masks_gpu = current_out["pred_masks"]
+        # potentially fill holes in the predicted masks
+        if self.fill_hole_area > 0:
+            pred_masks_gpu = fill_holes_in_mask_scores(
+                pred_masks_gpu, self.fill_hole_area
+            )
+        pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
+        # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
+        maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
+        # object pointer is a small tensor, so we always keep it on GPU memory for fast access
+        obj_ptr = current_out["obj_ptr"]
+        # make a compact version of this frame's output to reduce the state size
+        compact_current_out = {
+            "maskmem_features": maskmem_features,
+            "maskmem_pos_enc": maskmem_pos_enc,
+            "pred_masks": pred_masks,
+            "obj_ptr": obj_ptr,
+        }
+        return compact_current_out, pred_masks_gpu
+
+    def _run_memory_encoder(
+        self, inference_state, frame_idx, batch_size, high_res_masks, is_mask_from_pts
+    ):
+        """
+        Run the memory encoder on `high_res_masks`. This is usually after applying
+        non-overlapping constraints to object scores. Since their scores changed, their
+        memory also needs to be computed again with the memory encoder.
+        """
+        # Retrieve correct image features
+        _, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
+            inference_state, frame_idx, batch_size
+        )
+        maskmem_features, maskmem_pos_enc = self._encode_new_memory(
+            current_vision_feats=current_vision_feats,
+            feat_sizes=feat_sizes,
+            pred_masks_high_res=high_res_masks,
+            is_mask_from_pts=is_mask_from_pts,
+        )
+
+        # optionally offload the output to CPU memory to save GPU space
+        storage_device = inference_state["storage_device"]
+        maskmem_features = maskmem_features.to(torch.bfloat16)
+        maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
+        # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
+        maskmem_pos_enc = self._get_maskmem_pos_enc(
+            inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
+        )
+        return maskmem_features, maskmem_pos_enc
+
+    def _get_maskmem_pos_enc(self, inference_state, current_out):
+        """
+        `maskmem_pos_enc` is the same across frames and objects, so we cache it as
+        a constant in the inference session to reduce session storage size.
+        """
+        model_constants = inference_state["constants"]
+        # "out_maskmem_pos_enc" should be either a list of tensors or None
+        out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
+        if out_maskmem_pos_enc is not None:
+            if "maskmem_pos_enc" not in model_constants:
+                assert isinstance(out_maskmem_pos_enc, list)
+                # only take the slice for one object, since it's same across objects
+                maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
+                model_constants["maskmem_pos_enc"] = maskmem_pos_enc
+            else:
+                maskmem_pos_enc = model_constants["maskmem_pos_enc"]
+            # expand the cached maskmem_pos_enc to the actual batch size
+            batch_size = out_maskmem_pos_enc[0].size(0)
+            expanded_maskmem_pos_enc = [
+                x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
+            ]
+        else:
+            expanded_maskmem_pos_enc = None
+        return expanded_maskmem_pos_enc
+
+    def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
+        """
+        Remove the non-conditioning memory around the input frame. When users provide
+        correction clicks, the surrounding frames' non-conditioning memories can still
+        contain outdated object appearance information and could confuse the model.
+
+        This method clears those non-conditioning memories surrounding the interacted
+        frame to avoid giving the model both old and new information about the object.
+        """
+        r = self.memory_temporal_stride_for_eval
+        frame_idx_begin = frame_idx - r * self.num_maskmem
+        frame_idx_end = frame_idx + r * self.num_maskmem
+        output_dict = inference_state["output_dict"]
+        non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
+        for t in range(frame_idx_begin, frame_idx_end + 1):
+            non_cond_frame_outputs.pop(t, None)
+            for obj_output_dict in inference_state["output_dict_per_obj"].values():
+                obj_output_dict["non_cond_frame_outputs"].pop(t, None)
diff --git a/sam2/utils/__init__.py b/sam2/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae
--- /dev/null
+++ b/sam2/utils/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/sam2/utils/__pycache__/__init__.cpython-38.pyc b/sam2/utils/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b21e2ff1656de771fe65e29da0d2acf066754878
Binary files /dev/null and b/sam2/utils/__pycache__/__init__.cpython-38.pyc differ
diff --git a/sam2/utils/__pycache__/misc.cpython-38.pyc b/sam2/utils/__pycache__/misc.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9ce20911a73e11523e49706ef94ea22a8b3767ba
Binary files /dev/null and b/sam2/utils/__pycache__/misc.cpython-38.pyc differ
diff --git a/sam2/utils/__pycache__/transforms.cpython-38.pyc b/sam2/utils/__pycache__/transforms.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3eec9b72dae7c7cb5d30d8d3bf329af70df64c79
Binary files /dev/null and b/sam2/utils/__pycache__/transforms.cpython-38.pyc differ
diff --git a/sam2/utils/amg.py b/sam2/utils/amg.py
new file mode 100644
index 0000000000000000000000000000000000000000..986842960cf5deca00614b7b1cde1ab77dad7e6e
--- /dev/null
+++ b/sam2/utils/amg.py
@@ -0,0 +1,348 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from copy import deepcopy
+from itertools import product
+from typing import Any, Dict, Generator, ItemsView, List, Tuple
+
+import numpy as np
+import torch
+
+# Very lightly adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/utils/amg.py
+
+
+class MaskData:
+    """
+    A structure for storing masks and their related data in batched format.
+    Implements basic filtering and concatenation.
+    """
+
+    def __init__(self, **kwargs) -> None:
+        for v in kwargs.values():
+            assert isinstance(
+                v, (list, np.ndarray, torch.Tensor)
+            ), "MaskData only supports list, numpy arrays, and torch tensors."
+        self._stats = dict(**kwargs)
+
+    def __setitem__(self, key: str, item: Any) -> None:
+        assert isinstance(
+            item, (list, np.ndarray, torch.Tensor)
+        ), "MaskData only supports list, numpy arrays, and torch tensors."
+        self._stats[key] = item
+
+    def __delitem__(self, key: str) -> None:
+        del self._stats[key]
+
+    def __getitem__(self, key: str) -> Any:
+        return self._stats[key]
+
+    def items(self) -> ItemsView[str, Any]:
+        return self._stats.items()
+
+    def filter(self, keep: torch.Tensor) -> None:
+        for k, v in self._stats.items():
+            if v is None:
+                self._stats[k] = None
+            elif isinstance(v, torch.Tensor):
+                self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
+            elif isinstance(v, np.ndarray):
+                self._stats[k] = v[keep.detach().cpu().numpy()]
+            elif isinstance(v, list) and keep.dtype == torch.bool:
+                self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
+            elif isinstance(v, list):
+                self._stats[k] = [v[i] for i in keep]
+            else:
+                raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+    def cat(self, new_stats: "MaskData") -> None:
+        for k, v in new_stats.items():
+            if k not in self._stats or self._stats[k] is None:
+                self._stats[k] = deepcopy(v)
+            elif isinstance(v, torch.Tensor):
+                self._stats[k] = torch.cat([self._stats[k], v], dim=0)
+            elif isinstance(v, np.ndarray):
+                self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
+            elif isinstance(v, list):
+                self._stats[k] = self._stats[k] + deepcopy(v)
+            else:
+                raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+    def to_numpy(self) -> None:
+        for k, v in self._stats.items():
+            if isinstance(v, torch.Tensor):
+                self._stats[k] = v.float().detach().cpu().numpy()
+
+
+def is_box_near_crop_edge(
+    boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
+) -> torch.Tensor:
+    """Filter masks at the edge of a crop, but not at the edge of the original image."""
+    crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
+    orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
+    boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
+    near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
+    near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
+    near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
+    return torch.any(near_crop_edge, dim=1)
+
+
+def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
+    box_xywh = deepcopy(box_xyxy)
+    box_xywh[2] = box_xywh[2] - box_xywh[0]
+    box_xywh[3] = box_xywh[3] - box_xywh[1]
+    return box_xywh
+
+
+def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
+    assert len(args) > 0 and all(
+        len(a) == len(args[0]) for a in args
+    ), "Batched iteration must have inputs of all the same size."
+    n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
+    for b in range(n_batches):
+        yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
+
+
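+# A hedged sketch of how `batch_iterator` chunks its inputs; `_batch_iterator_example`
+# is a hypothetical helper added for illustration, not part of the upstream API.
+def _batch_iterator_example() -> None:
+    data = np.arange(10)
+    # 10 items in batches of 4 -> chunk sizes 4, 4, 2 (each yielded item is a list of args)
+    sizes = [len(batch[0]) for batch in batch_iterator(4, data)]
+    assert sizes == [4, 4, 2]
+
+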
+def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
+    """
+    Encodes masks to an uncompressed RLE, in the format expected by
+    pycoco tools.
+    """
+    # Put in fortran order and flatten h,w
+    b, h, w = tensor.shape
+    tensor = tensor.permute(0, 2, 1).flatten(1)
+
+    # Compute change indices
+    diff = tensor[:, 1:] ^ tensor[:, :-1]
+    change_indices = diff.nonzero()
+
+    # Encode run length
+    out = []
+    for i in range(b):
+        cur_idxs = change_indices[change_indices[:, 0] == i, 1]
+        cur_idxs = torch.cat(
+            [
+                torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
+                cur_idxs + 1,
+                torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
+            ]
+        )
+        btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
+        counts = [] if tensor[i, 0] == 0 else [0]
+        counts.extend(btw_idxs.detach().cpu().tolist())
+        out.append({"size": [h, w], "counts": counts})
+    return out
+
+
+def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
+    """Compute a binary mask from an uncompressed RLE."""
+    h, w = rle["size"]
+    mask = np.empty(h * w, dtype=bool)
+    idx = 0
+    parity = False
+    for count in rle["counts"]:
+        mask[idx : idx + count] = parity
+        idx += count
+        parity ^= True
+    mask = mask.reshape(w, h)
+    return mask.transpose()  # Put in C order
+
+
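+# A hedged round-trip sketch for the two RLE helpers above; `_rle_round_trip_example`
+# is a hypothetical helper added for illustration, not part of the upstream API.
+def _rle_round_trip_example() -> None:
+    toy = torch.zeros(1, 4, 4, dtype=torch.bool)
+    toy[0, 1:3, 1:3] = True  # a 2x2 foreground square
+    rles = mask_to_rle_pytorch(toy)  # counts are [5, 2, 2, 2, 5] in Fortran order
+    recovered = rle_to_mask(rles[0])
+    assert (recovered == toy[0].numpy()).all()
+
+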
+def area_from_rle(rle: Dict[str, Any]) -> int:
+    return sum(rle["counts"][1::2])
+
+
+def calculate_stability_score(
+    masks: torch.Tensor, mask_threshold: float, threshold_offset: float
+) -> torch.Tensor:
+    """
+    Computes the stability score for a batch of masks. The stability
+    score is the IoU between the binary masks obtained by thresholding
+    the predicted mask logits at high and low values.
+    """
+    # One mask is always contained inside the other.
+    # Save memory by preventing unnecessary cast to torch.int64
+    intersections = (
+        (masks > (mask_threshold + threshold_offset))
+        .sum(-1, dtype=torch.int16)
+        .sum(-1, dtype=torch.int32)
+    )
+    unions = (
+        (masks > (mask_threshold - threshold_offset))
+        .sum(-1, dtype=torch.int16)
+        .sum(-1, dtype=torch.int32)
+    )
+    return intersections / unions
+
+
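+# A hedged numerical sketch for `calculate_stability_score`; `_stability_score_example`
+# is a hypothetical helper added for illustration, not part of the upstream API.
+def _stability_score_example() -> torch.Tensor:
+    # Logits far from the threshold give the same binary mask at +/- offset, so the score is 1.0.
+    logits = torch.full((1, 8, 8), -10.0)
+    logits[0, 2:6, 2:6] = 10.0
+    return calculate_stability_score(logits, mask_threshold=0.0, threshold_offset=1.0)
+
+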
+def build_point_grid(n_per_side: int) -> np.ndarray:
+    """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
+    offset = 1 / (2 * n_per_side)
+    points_one_side = np.linspace(offset, 1 - offset, n_per_side)
+    points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
+    points_y = np.tile(points_one_side[:, None], (1, n_per_side))
+    points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
+    return points
+
+
+def build_all_layer_point_grids(
+    n_per_side: int, n_layers: int, scale_per_layer: int
+) -> List[np.ndarray]:
+    """Generates point grids for all crop layers."""
+    points_by_layer = []
+    for i in range(n_layers + 1):
+        n_points = int(n_per_side / (scale_per_layer**i))
+        points_by_layer.append(build_point_grid(n_points))
+    return points_by_layer
+
+
+def generate_crop_boxes(
+    im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
+) -> Tuple[List[List[int]], List[int]]:
+    """
+    Generates a list of crop boxes of different sizes. Each layer
+    has (2**i)**2 boxes for the ith layer.
+    """
+    crop_boxes, layer_idxs = [], []
+    im_h, im_w = im_size
+    short_side = min(im_h, im_w)
+
+    # Original image
+    crop_boxes.append([0, 0, im_w, im_h])
+    layer_idxs.append(0)
+
+    def crop_len(orig_len, n_crops, overlap):
+        return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
+
+    for i_layer in range(n_layers):
+        n_crops_per_side = 2 ** (i_layer + 1)
+        overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
+
+        crop_w = crop_len(im_w, n_crops_per_side, overlap)
+        crop_h = crop_len(im_h, n_crops_per_side, overlap)
+
+        crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
+        crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
+
+        # Crops in XYXY format
+        for x0, y0 in product(crop_box_x0, crop_box_y0):
+            box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
+            crop_boxes.append(box)
+            layer_idxs.append(i_layer + 1)
+
+    return crop_boxes, layer_idxs
+
+
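+# A hedged sketch of `generate_crop_boxes` on a toy image size; `_crop_boxes_example`
+# is a hypothetical helper added for illustration, not part of the upstream API.
+def _crop_boxes_example() -> None:
+    boxes, layer_idxs = generate_crop_boxes(im_size=(100, 150), n_layers=1, overlap_ratio=0.34)
+    # One full-image box for layer 0 plus (2**1)**2 = 4 overlapping crops for layer 1.
+    assert layer_idxs == [0, 1, 1, 1, 1] and len(boxes) == 5
+
+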
+def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+    x0, y0, _, _ = crop_box
+    offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
+    # Check if boxes has a channel dimension
+    if len(boxes.shape) == 3:
+        offset = offset.unsqueeze(1)
+    return boxes + offset
+
+
+def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+    x0, y0, _, _ = crop_box
+    offset = torch.tensor([[x0, y0]], device=points.device)
+    # Check if points has a channel dimension
+    if len(points.shape) == 3:
+        offset = offset.unsqueeze(1)
+    return points + offset
+
+
+def uncrop_masks(
+    masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
+) -> torch.Tensor:
+    x0, y0, x1, y1 = crop_box
+    if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
+        return masks
+    # Coordinate transform masks
+    pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
+    pad = (x0, pad_x - x0, y0, pad_y - y0)
+    return torch.nn.functional.pad(masks, pad, value=0)
+
+
+def remove_small_regions(
+    mask: np.ndarray, area_thresh: float, mode: str
+) -> Tuple[np.ndarray, bool]:
+    """
+    Removes small disconnected regions and holes in a mask. Returns the
+    mask and an indicator of if the mask has been modified.
+    """
+    import cv2  # type: ignore
+
+    assert mode in ["holes", "islands"]
+    correct_holes = mode == "holes"
+    working_mask = (correct_holes ^ mask).astype(np.uint8)
+    n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
+    sizes = stats[:, -1][1:]  # Row 0 is background label
+    small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
+    if len(small_regions) == 0:
+        return mask, False
+    fill_labels = [0] + small_regions
+    if not correct_holes:
+        fill_labels = [i for i in range(n_labels) if i not in fill_labels]
+        # If every region is below threshold, keep largest
+        if len(fill_labels) == 0:
+            fill_labels = [int(np.argmax(sizes)) + 1]
+    mask = np.isin(regions, fill_labels)
+    return mask, True
+
+
+def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
+    from pycocotools import mask as mask_utils  # type: ignore
+
+    h, w = uncompressed_rle["size"]
+    rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
+    rle["counts"] = rle["counts"].decode("utf-8")  # Necessary to serialize with json
+    return rle
+
+
+def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
+    """
+    Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
+    an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
+    """
+    # torch.max below raises an error on empty inputs, just skip in this case
+    if torch.numel(masks) == 0:
+        return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
+
+    # Normalize shape to CxHxW
+    shape = masks.shape
+    h, w = shape[-2:]
+    if len(shape) > 2:
+        masks = masks.flatten(0, -3)
+    else:
+        masks = masks.unsqueeze(0)
+
+    # Get top and bottom edges
+    in_height, _ = torch.max(masks, dim=-1)
+    in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
+    bottom_edges, _ = torch.max(in_height_coords, dim=-1)
+    in_height_coords = in_height_coords + h * (~in_height)
+    top_edges, _ = torch.min(in_height_coords, dim=-1)
+
+    # Get left and right edges
+    in_width, _ = torch.max(masks, dim=-2)
+    in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
+    right_edges, _ = torch.max(in_width_coords, dim=-1)
+    in_width_coords = in_width_coords + w * (~in_width)
+    left_edges, _ = torch.min(in_width_coords, dim=-1)
+
+    # If the mask is empty the right edge will be to the left of the left edge.
+    # Replace these boxes with [0, 0, 0, 0]
+    empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+    out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
+    out = out * (~empty_filter).unsqueeze(-1)
+
+    # Return to original shape
+    if len(shape) > 2:
+        out = out.reshape(*shape[:-2], 4)
+    else:
+        out = out[0]
+
+    return out
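+
+
+# A hedged sketch of `batched_mask_to_box` on a toy mask; `_mask_to_box_example`
+# is a hypothetical helper added for illustration, not part of the upstream API.
+def _mask_to_box_example() -> torch.Tensor:
+    mask = torch.zeros(1, 6, 6, dtype=torch.bool)
+    mask[0, 1:3, 1:3] = True  # foreground covering rows/cols 1..2
+    return batched_mask_to_box(mask)  # tensor([[1, 1, 2, 2]]) in XYXY format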
diff --git a/sam2/utils/misc.py b/sam2/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..df97b4a8e96f968d8993473344bf60eb8fadfd65
--- /dev/null
+++ b/sam2/utils/misc.py
@@ -0,0 +1,252 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import warnings
+from threading import Thread
+
+import numpy as np
+import torch
+from PIL import Image
+from tqdm import tqdm
+
+
+def get_sdpa_settings():
+    if torch.cuda.is_available():
+        old_gpu = torch.cuda.get_device_properties(0).major < 7
+        # only use Flash Attention on Ampere (8.0) or newer GPUs
+        use_flash_attn = torch.cuda.get_device_properties(0).major >= 8
+        if not use_flash_attn:
+            warnings.warn(
+                "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.",
+                category=UserWarning,
+                stacklevel=2,
+            )
+        # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only
+        # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases)
+        pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2])
+        if pytorch_version < (2, 2):
+            warnings.warn(
+                f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. "
+                "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).",
+                category=UserWarning,
+                stacklevel=2,
+            )
+        math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn
+    else:
+        old_gpu = True
+        use_flash_attn = False
+        math_kernel_on = True
+
+    return old_gpu, use_flash_attn, math_kernel_on
+
+
+def get_connected_components(mask):
+    """
+    Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W).
+
+    Inputs:
+    - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is
+            background.
+
+    Outputs:
+    - labels: A tensor of shape (N, 1, H, W) containing the connected component labels
+              for foreground pixels and 0 for background pixels.
+    - counts: A tensor of shape (N, 1, H, W) containing the area of the connected
+              components for foreground pixels and 0 for background pixels.
+    """
+    from sam2 import _C
+
+    return _C.get_connected_componnets(mask.to(torch.uint8).contiguous())  # (sic) name matches the CUDA extension binding
+
+
+def mask_to_box(masks: torch.Tensor):
+    """
+    Compute bounding boxes from input masks.
+
+    Inputs:
+    - masks: [B, 1, H, W] boolean masks (torch.Tensor)
+
+    Returns:
+    - box_coords: [B, 1, 4], containing (x, y) coordinates of the top-left and bottom-right box corners (torch.Tensor)
+    """
+    B, _, h, w = masks.shape
+    device = masks.device
+    xs = torch.arange(w, device=device, dtype=torch.int32)
+    ys = torch.arange(h, device=device, dtype=torch.int32)
+    grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy")
+    grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w)
+    grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w)
+    min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1)
+    max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1)
+    min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1)
+    max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1)
+    bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1)
+
+    return bbox_coords
+
+
+def _load_img_as_tensor(img_path, image_size):
+    img_pil = Image.open(img_path)
+    img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size)))
+    if img_np.dtype == np.uint8:  # np.uint8 is expected for JPEG images
+        img_np = img_np / 255.0
+    else:
+        raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}")
+    img = torch.from_numpy(img_np).permute(2, 0, 1)
+    video_width, video_height = img_pil.size  # the original video size
+    return img, video_height, video_width
+
+
+class AsyncVideoFrameLoader:
+    """
+    A list of video frames to be loaded asynchronously without blocking session start.
+    """
+
+    def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std):
+        self.img_paths = img_paths
+        self.image_size = image_size
+        self.offload_video_to_cpu = offload_video_to_cpu
+        self.img_mean = img_mean
+        self.img_std = img_std
+        # items in `self.images` will be loaded asynchronously
+        self.images = [None] * len(img_paths)
+        # catch and raise any exceptions in the async loading thread
+        self.exception = None
+        # video_height and video_width will be filled when loading the first image
+        self.video_height = None
+        self.video_width = None
+
+        # load the first frame to fill video_height and video_width and also
+        # to cache it (since it's most likely where the user will click)
+        self.__getitem__(0)
+
+        # load the rest of the frames asynchronously without blocking the session start
+        def _load_frames():
+            try:
+                for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"):
+                    self.__getitem__(n)
+            except Exception as e:
+                self.exception = e
+
+        self.thread = Thread(target=_load_frames, daemon=True)
+        self.thread.start()
+
+    def __getitem__(self, index):
+        if self.exception is not None:
+            raise RuntimeError("Failure in frame loading thread") from self.exception
+
+        img = self.images[index]
+        if img is not None:
+            return img
+
+        img, video_height, video_width = _load_img_as_tensor(
+            self.img_paths[index], self.image_size
+        )
+        self.video_height = video_height
+        self.video_width = video_width
+        # normalize by mean and std
+        img -= self.img_mean
+        img /= self.img_std
+        if not self.offload_video_to_cpu:
+            img = img.cuda(non_blocking=True)
+        self.images[index] = img
+        return img
+
+    def __len__(self):
+        return len(self.images)
+
+
+def load_video_frames(
+    video_path,
+    image_size,
+    offload_video_to_cpu,
+    img_mean=(0.485, 0.456, 0.406),
+    img_std=(0.229, 0.224, 0.225),
+    async_loading_frames=False,
+):
+    """
+    Load the video frames from a directory of JPEG files ("<frame_index>.jpg" format).
+
+    The frames are resized to image_size x image_size and are loaded to GPU if
+    `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`.
+
+    You can load frames asynchronously by setting `async_loading_frames` to `True`.
+    """
+    if isinstance(video_path, str) and os.path.isdir(video_path):
+        jpg_folder = video_path
+    else:
+        raise NotImplementedError("Only JPEG frames are supported at this moment")
+
+    frame_names = [
+        p
+        for p in os.listdir(jpg_folder)
+        if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
+    ]
+    frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
+    num_frames = len(frame_names)
+    if num_frames == 0:
+        raise RuntimeError(f"no images found in {jpg_folder}")
+    img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names]
+    img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
+    img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
+
+    if async_loading_frames:
+        lazy_images = AsyncVideoFrameLoader(
+            img_paths, image_size, offload_video_to_cpu, img_mean, img_std
+        )
+        return lazy_images, lazy_images.video_height, lazy_images.video_width
+
+    images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
+    for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
+        images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
+    if not offload_video_to_cpu:
+        images = images.cuda()
+        img_mean = img_mean.cuda()
+        img_std = img_std.cuda()
+    # normalize by mean and std
+    images -= img_mean
+    images /= img_std
+    return images, video_height, video_width
+
+
+def fill_holes_in_mask_scores(mask, max_area):
+    """
+    A post-processing step that fills small holes in mask scores whose area is at most `max_area`.
+    """
+    # Holes are those connected components in the background with area <= max_area
+    # (background regions are those with mask scores <= 0)
+    assert max_area > 0, "max_area must be positive"
+
+    input_mask = mask
+    try:
+        labels, areas = get_connected_components(mask <= 0)
+        is_hole = (labels > 0) & (areas <= max_area)
+        # We fill holes with a small positive mask score (0.1) to change them to foreground.
+        mask = torch.where(is_hole, 0.1, mask)
+    except Exception as e:
+        # Skip the post-processing step on removing small holes if the CUDA kernel fails
+        warnings.warn(
+            f"{e}\n\nSkipping the post-processing step due to the error above. "
+            "Consider building SAM 2 with CUDA extension to enable post-processing (see "
+            "https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
+            category=UserWarning,
+            stacklevel=2,
+        )
+        mask = input_mask
+
+    return mask
+
+
+def concat_points(old_point_inputs, new_points, new_labels):
+    """Add new points and labels to previous point inputs (add at the end)."""
+    if old_point_inputs is None:
+        points, labels = new_points, new_labels
+    else:
+        points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1)
+        labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1)
+
+    return {"point_coords": points, "point_labels": labels}
diff --git a/sam2/utils/transforms.py b/sam2/utils/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..995baf989da5a8e4927c87b1bbb0777067b673cc
--- /dev/null
+++ b/sam2/utils/transforms.py
@@ -0,0 +1,117 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision.transforms import Normalize, Resize, ToTensor
+
+
+class SAM2Transforms(nn.Module):
+    def __init__(
+        self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0
+    ):
+        """
+        Transforms for SAM2.
+        """
+        super().__init__()
+        self.resolution = resolution
+        self.mask_threshold = mask_threshold
+        self.max_hole_area = max_hole_area
+        self.max_sprinkle_area = max_sprinkle_area
+        self.mean = [0.485, 0.456, 0.406]
+        self.std = [0.229, 0.224, 0.225]
+        self.to_tensor = ToTensor()
+        self.transforms = torch.jit.script(
+            nn.Sequential(
+                Resize((self.resolution, self.resolution)),
+                Normalize(self.mean, self.std),
+            )
+        )
+
+    def __call__(self, x):
+        x = self.to_tensor(x)
+        return self.transforms(x)
+
+    def forward_batch(self, img_list):
+        img_batch = [self.transforms(self.to_tensor(img)) for img in img_list]
+        img_batch = torch.stack(img_batch, dim=0)
+        return img_batch
+
+    def transform_coords(
+        self, coords: torch.Tensor, normalize=False, orig_hw=None
+    ) -> torch.Tensor:
+        """
+        Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates.
+        If the coords are in absolute image coordinates, normalize should be set to True and the original image size is required.
+
+        Returns
+            Un-normalized coordinates in the range [0, self.resolution], as expected by the SAM2 model.
+        """
+        if normalize:
+            assert orig_hw is not None
+            h, w = orig_hw
+            coords = coords.clone()
+            coords[..., 0] = coords[..., 0] / w
+            coords[..., 1] = coords[..., 1] / h
+
+        coords = coords * self.resolution  # unnormalize coords
+        return coords
+
+    def transform_boxes(
+        self, boxes: torch.Tensor, normalize=False, orig_hw=None
+    ) -> torch.Tensor:
+        """
+        Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates.
+        If the coords are in absolute image coordinates, normalize should be set to True and the original image size is required.
+        """
+        boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw)
+        return boxes
+
+    def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
+        """
+        Perform PostProcessing on output masks.
+        """
+        from sam2.utils.misc import get_connected_components
+
+        masks = masks.float()
+        input_masks = masks
+        mask_flat = masks.flatten(0, 1).unsqueeze(1)  # flatten as 1-channel image
+        try:
+            if self.max_hole_area > 0:
+                # Holes are those connected components in the background with area <= self.max_hole_area
+                # (background regions are those with mask scores <= self.mask_threshold)
+                labels, areas = get_connected_components(
+                    mask_flat <= self.mask_threshold
+                )
+                is_hole = (labels > 0) & (areas <= self.max_hole_area)
+                is_hole = is_hole.reshape_as(masks)
+                # We fill holes with a small positive mask score (10.0) to change them to foreground.
+                masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
+
+            if self.max_sprinkle_area > 0:
+                labels, areas = get_connected_components(
+                    mask_flat > self.mask_threshold
+                )
+                is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
+                is_hole = is_hole.reshape_as(masks)
+                # We suppress small foreground sprinkles with a negative mask score (-10.0) to change them to background.
+                masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)
+        except Exception as e:
+            # Skip the post-processing step if the CUDA kernel fails
+            warnings.warn(
+                f"{e}\n\nSkipping the post-processing step due to the error above. "
+                "Consider building SAM 2 with CUDA extension to enable post-processing (see "
+                "https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).",
+                category=UserWarning,
+                stacklevel=2,
+            )
+            masks = input_masks
+
+        masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
+        return masks
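For reference, a hedged usage sketch of `SAM2Transforms` as defined above: it resizes and ImageNet-normalizes an input image, and `transform_boxes` maps absolute pixel boxes to the model resolution. The hole/sprinkle post-processing only runs when `max_hole_area`/`max_sprinkle_area` are positive and the CUDA extension is available.

```python
# Hedged usage sketch for SAM2Transforms (no CUDA extension needed here, since
# hole/sprinkle post-processing is left disabled).
import numpy as np
import torch

from sam2.utils.transforms import SAM2Transforms

transforms = SAM2Transforms(resolution=1024, mask_threshold=0.0)

image = np.zeros((480, 640, 3), dtype=np.uint8)  # H x W x 3 uint8 image
x = transforms(image)
print(x.shape)  # torch.Size([3, 1024, 1024]), resized and ImageNet-normalized

# Absolute-pixel boxes are normalized by the original (H, W), then scaled to
# the model resolution.
boxes = torch.tensor([[100.0, 50.0, 300.0, 200.0]])
print(transforms.transform_boxes(boxes, normalize=True, orig_hw=(480, 640)))
# -> tensor([[[160.0000, 106.6667], [480.0000, 426.6667]]])
```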
diff --git a/sam2_configs/__init__.py b/sam2_configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae
--- /dev/null
+++ b/sam2_configs/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/sam2_configs/sam2_hiera_b+.yaml b/sam2_configs/sam2_hiera_b+.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..58f3eb81554018e873f8515ecb98e36d16ac29e4
--- /dev/null
+++ b/sam2_configs/sam2_hiera_b+.yaml
@@ -0,0 +1,113 @@
+# @package _global_
+
+# Model
+model:
+  _target_: sam2.modeling.sam2_base.SAM2Base
+  image_encoder:
+    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+    scalp: 1
+    trunk:
+      _target_: sam2.modeling.backbones.hieradet.Hiera
+      embed_dim: 112
+      num_heads: 2
+    neck:
+      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+      position_encoding:
+        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        num_pos_feats: 256
+        normalize: true
+        scale: null
+        temperature: 10000
+      d_model: 256
+      backbone_channel_list: [896, 448, 224, 112]
+      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
+      fpn_interp_model: nearest
+
+  memory_attention:
+    _target_: sam2.modeling.memory_attention.MemoryAttention
+    d_model: 256
+    pos_enc_at_input: true
+    layer:
+      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+      activation: relu
+      dim_feedforward: 2048
+      dropout: 0.1
+      pos_enc_at_attn: false
+      self_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [32, 32]
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+      d_model: 256
+      pos_enc_at_cross_attn_keys: true
+      pos_enc_at_cross_attn_queries: false
+      cross_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [32, 32]
+        rope_k_repeat: True
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+        kv_in_dim: 64
+    num_layers: 4
+
+  memory_encoder:
+      _target_: sam2.modeling.memory_encoder.MemoryEncoder
+      out_dim: 64
+      position_encoding:
+        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        num_pos_feats: 64
+        normalize: true
+        scale: null
+        temperature: 10000
+      mask_downsampler:
+        _target_: sam2.modeling.memory_encoder.MaskDownSampler
+        kernel_size: 3
+        stride: 2
+        padding: 1
+      fuser:
+        _target_: sam2.modeling.memory_encoder.Fuser
+        layer:
+          _target_: sam2.modeling.memory_encoder.CXBlock
+          dim: 256
+          kernel_size: 7
+          padding: 3
+          layer_scale_init_value: 1e-6
+          use_dwconv: True  # depth-wise convs
+        num_layers: 2
+
+  num_maskmem: 7
+  image_size: 1024
+  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+  sigmoid_scale_for_mem_enc: 20.0
+  sigmoid_bias_for_mem_enc: -10.0
+  use_mask_input_as_output_without_sam: true
+  # Memory
+  directly_add_no_mem_embed: true
+  # use high-resolution feature map in the SAM mask decoder
+  use_high_res_features_in_sam: true
+  # output 3 masks on the first click on initial conditioning frames
+  multimask_output_in_sam: true
+  # SAM heads
+  iou_prediction_use_sigmoid: True
+  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+  use_obj_ptrs_in_encoder: true
+  add_tpos_enc_to_obj_ptrs: false
+  only_obj_ptrs_in_the_past_for_eval: true
+  # object occlusion prediction
+  pred_obj_scores: true
+  pred_obj_scores_mlp: true
+  fixed_no_obj_ptr: true
+  # multimask tracking settings
+  multimask_output_for_tracking: true
+  use_multimask_token_for_obj_ptr: true
+  multimask_min_pt_num: 0
+  multimask_max_pt_num: 1
+  use_mlp_for_obj_ptr_proj: true
+  # Compilation flag
+  compile_image_encoder: False
diff --git a/sam2_configs/sam2_hiera_l.yaml b/sam2_configs/sam2_hiera_l.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..918667f50c3e1ad2dcf77c0c14cb4dd114cfd080
--- /dev/null
+++ b/sam2_configs/sam2_hiera_l.yaml
@@ -0,0 +1,117 @@
+# @package _global_
+
+# Model
+model:
+  _target_: sam2.modeling.sam2_base.SAM2Base
+  image_encoder:
+    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+    scalp: 1
+    trunk:
+      _target_: sam2.modeling.backbones.hieradet.Hiera
+      embed_dim: 144
+      num_heads: 2
+      stages: [2, 6, 36, 4]
+      global_att_blocks: [23, 33, 43]
+      window_pos_embed_bkg_spatial_size: [7, 7]
+      window_spec: [8, 4, 16, 8]
+    neck:
+      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+      position_encoding:
+        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        num_pos_feats: 256
+        normalize: true
+        scale: null
+        temperature: 10000
+      d_model: 256
+      backbone_channel_list: [1152, 576, 288, 144]
+      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
+      fpn_interp_model: nearest
+
+  memory_attention:
+    _target_: sam2.modeling.memory_attention.MemoryAttention
+    d_model: 256
+    pos_enc_at_input: true
+    layer:
+      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+      activation: relu
+      dim_feedforward: 2048
+      dropout: 0.1
+      pos_enc_at_attn: false
+      self_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [32, 32]
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+      d_model: 256
+      pos_enc_at_cross_attn_keys: true
+      pos_enc_at_cross_attn_queries: false
+      cross_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [32, 32]
+        rope_k_repeat: True
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+        kv_in_dim: 64
+    num_layers: 4
+
+  memory_encoder:
+      _target_: sam2.modeling.memory_encoder.MemoryEncoder
+      out_dim: 64
+      position_encoding:
+        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        num_pos_feats: 64
+        normalize: true
+        scale: null
+        temperature: 10000
+      mask_downsampler:
+        _target_: sam2.modeling.memory_encoder.MaskDownSampler
+        kernel_size: 3
+        stride: 2
+        padding: 1
+      fuser:
+        _target_: sam2.modeling.memory_encoder.Fuser
+        layer:
+          _target_: sam2.modeling.memory_encoder.CXBlock
+          dim: 256
+          kernel_size: 7
+          padding: 3
+          layer_scale_init_value: 1e-6
+          use_dwconv: True  # depth-wise convs
+        num_layers: 2
+
+  num_maskmem: 7
+  image_size: 1024
+  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+  sigmoid_scale_for_mem_enc: 20.0
+  sigmoid_bias_for_mem_enc: -10.0
+  use_mask_input_as_output_without_sam: true
+  # Memory
+  directly_add_no_mem_embed: true
+  # use high-resolution feature map in the SAM mask decoder
+  use_high_res_features_in_sam: true
+  # output 3 masks on the first click on initial conditioning frames
+  multimask_output_in_sam: true
+  # SAM heads
+  iou_prediction_use_sigmoid: True
+  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+  use_obj_ptrs_in_encoder: true
+  add_tpos_enc_to_obj_ptrs: false
+  only_obj_ptrs_in_the_past_for_eval: true
+  # object occlusion prediction
+  pred_obj_scores: true
+  pred_obj_scores_mlp: true
+  fixed_no_obj_ptr: true
+  # multimask tracking settings
+  multimask_output_for_tracking: true
+  use_multimask_token_for_obj_ptr: true
+  multimask_min_pt_num: 0
+  multimask_max_pt_num: 1
+  use_mlp_for_obj_ptr_proj: true
+  # Compilation flag
+  compile_image_encoder: False
diff --git a/sam2_configs/sam2_hiera_s.yaml b/sam2_configs/sam2_hiera_s.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..26e5d4d39f7b2892396106005c37c7ffe6c83bc2
--- /dev/null
+++ b/sam2_configs/sam2_hiera_s.yaml
@@ -0,0 +1,116 @@
+# @package _global_
+
+# Model
+model:
+  _target_: sam2.modeling.sam2_base.SAM2Base
+  image_encoder:
+    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+    scalp: 1
+    trunk:
+      _target_: sam2.modeling.backbones.hieradet.Hiera
+      embed_dim: 96
+      num_heads: 1
+      stages: [1, 2, 11, 2]
+      global_att_blocks: [7, 10, 13]
+      window_pos_embed_bkg_spatial_size: [7, 7]
+    neck:
+      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+      position_encoding:
+        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        num_pos_feats: 256
+        normalize: true
+        scale: null
+        temperature: 10000
+      d_model: 256
+      backbone_channel_list: [768, 384, 192, 96]
+      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
+      fpn_interp_model: nearest
+
+  memory_attention:
+    _target_: sam2.modeling.memory_attention.MemoryAttention
+    d_model: 256
+    pos_enc_at_input: true
+    layer:
+      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+      activation: relu
+      dim_feedforward: 2048
+      dropout: 0.1
+      pos_enc_at_attn: false
+      self_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [32, 32]
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+      d_model: 256
+      pos_enc_at_cross_attn_keys: true
+      pos_enc_at_cross_attn_queries: false
+      cross_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [32, 32]
+        rope_k_repeat: True
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+        kv_in_dim: 64
+    num_layers: 4
+
+  memory_encoder:
+      _target_: sam2.modeling.memory_encoder.MemoryEncoder
+      out_dim: 64
+      position_encoding:
+        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        num_pos_feats: 64
+        normalize: true
+        scale: null
+        temperature: 10000
+      mask_downsampler:
+        _target_: sam2.modeling.memory_encoder.MaskDownSampler
+        kernel_size: 3
+        stride: 2
+        padding: 1
+      fuser:
+        _target_: sam2.modeling.memory_encoder.Fuser
+        layer:
+          _target_: sam2.modeling.memory_encoder.CXBlock
+          dim: 256
+          kernel_size: 7
+          padding: 3
+          layer_scale_init_value: 1e-6
+          use_dwconv: True  # depth-wise convs
+        num_layers: 2
+
+  num_maskmem: 7
+  image_size: 1024
+  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+  sigmoid_scale_for_mem_enc: 20.0
+  sigmoid_bias_for_mem_enc: -10.0
+  use_mask_input_as_output_without_sam: true
+  # Memory
+  directly_add_no_mem_embed: true
+  # use high-resolution feature map in the SAM mask decoder
+  use_high_res_features_in_sam: true
+  # output 3 masks on the first click on initial conditioning frames
+  multimask_output_in_sam: true
+  # SAM heads
+  iou_prediction_use_sigmoid: True
+  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+  use_obj_ptrs_in_encoder: true
+  add_tpos_enc_to_obj_ptrs: false
+  only_obj_ptrs_in_the_past_for_eval: true
+  # object occlusion prediction
+  pred_obj_scores: true
+  pred_obj_scores_mlp: true
+  fixed_no_obj_ptr: true
+  # multimask tracking settings
+  multimask_output_for_tracking: true
+  use_multimask_token_for_obj_ptr: true
+  multimask_min_pt_num: 0
+  multimask_max_pt_num: 1
+  use_mlp_for_obj_ptr_proj: true
+  # Compilation flag
+  compile_image_encoder: False
diff --git a/sam2_configs/sam2_hiera_t.yaml b/sam2_configs/sam2_hiera_t.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a62c903aaa5f80828077c6e06a59626926570ed6
--- /dev/null
+++ b/sam2_configs/sam2_hiera_t.yaml
@@ -0,0 +1,118 @@
+# @package _global_
+
+# Model
+model:
+  _target_: sam2.modeling.sam2_base.SAM2Base
+  image_encoder:
+    _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
+    scalp: 1
+    trunk:
+      _target_: sam2.modeling.backbones.hieradet.Hiera
+      embed_dim: 96
+      num_heads: 1
+      stages: [1, 2, 7, 2]
+      global_att_blocks: [5, 7, 9]
+      window_pos_embed_bkg_spatial_size: [7, 7]
+    neck:
+      _target_: sam2.modeling.backbones.image_encoder.FpnNeck
+      position_encoding:
+        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        num_pos_feats: 256
+        normalize: true
+        scale: null
+        temperature: 10000
+      d_model: 256
+      backbone_channel_list: [768, 384, 192, 96]
+      fpn_top_down_levels: [2, 3]  # output level 0 and 1 directly use the backbone features
+      fpn_interp_model: nearest
+
+  memory_attention:
+    _target_: sam2.modeling.memory_attention.MemoryAttention
+    d_model: 256
+    pos_enc_at_input: true
+    layer:
+      _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
+      activation: relu
+      dim_feedforward: 2048
+      dropout: 0.1
+      pos_enc_at_attn: false
+      self_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [32, 32]
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+      d_model: 256
+      pos_enc_at_cross_attn_keys: true
+      pos_enc_at_cross_attn_queries: false
+      cross_attention:
+        _target_: sam2.modeling.sam.transformer.RoPEAttention
+        rope_theta: 10000.0
+        feat_sizes: [32, 32]
+        rope_k_repeat: True
+        embedding_dim: 256
+        num_heads: 1
+        downsample_rate: 1
+        dropout: 0.1
+        kv_in_dim: 64
+    num_layers: 4
+
+  memory_encoder:
+      _target_: sam2.modeling.memory_encoder.MemoryEncoder
+      out_dim: 64
+      position_encoding:
+        _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
+        num_pos_feats: 64
+        normalize: true
+        scale: null
+        temperature: 10000
+      mask_downsampler:
+        _target_: sam2.modeling.memory_encoder.MaskDownSampler
+        kernel_size: 3
+        stride: 2
+        padding: 1
+      fuser:
+        _target_: sam2.modeling.memory_encoder.Fuser
+        layer:
+          _target_: sam2.modeling.memory_encoder.CXBlock
+          dim: 256
+          kernel_size: 7
+          padding: 3
+          layer_scale_init_value: 1e-6
+          use_dwconv: True  # depth-wise convs
+        num_layers: 2
+
+  num_maskmem: 7
+  image_size: 1024
+  # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
+  sigmoid_scale_for_mem_enc: 20.0
+  sigmoid_bias_for_mem_enc: -10.0
+  use_mask_input_as_output_without_sam: true
+  # Memory
+  directly_add_no_mem_embed: true
+  # use high-resolution feature map in the SAM mask decoder
+  use_high_res_features_in_sam: true
+  # output 3 masks on the first click on initial conditioning frames
+  multimask_output_in_sam: true
+  # SAM heads
+  iou_prediction_use_sigmoid: True
+  # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
+  use_obj_ptrs_in_encoder: true
+  add_tpos_enc_to_obj_ptrs: false
+  only_obj_ptrs_in_the_past_for_eval: true
+  # object occlusion prediction
+  pred_obj_scores: true
+  pred_obj_scores_mlp: true
+  fixed_no_obj_ptr: true
+  # multimask tracking settings
+  multimask_output_for_tracking: true
+  use_multimask_token_for_obj_ptr: true
+  multimask_min_pt_num: 0
+  multimask_max_pt_num: 1
+  use_mlp_for_obj_ptr_proj: true
+  # Compilation flag
+  # HieraT does not currently support compilation, should always be set to False
+  compile_image_encoder: False
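The four `sam2_hiera_*.yaml` configs differ only in the Hiera trunk hyper-parameters (`embed_dim`, `num_heads`, `stages`, `global_att_blocks`, window settings) and the matching `backbone_channel_list`; the memory attention, memory encoder, and tracking flags are shared. Because they are Hydra-style configs, the model graph can be instantiated directly from the `_target_` entries. A minimal sketch, assuming the vendored `sam2` package is importable and with checkpoint loading left out of scope (upstream SAM 2 wraps this pattern in a `build_sam2()` helper):

```python
# Hedged sketch: building SAM2Base straight from one of these configs with
# OmegaConf + hydra.utils.instantiate. The config path is an assumption
# relative to the repository root; weights are not loaded here.
import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("sam2_configs/sam2_hiera_t.yaml")
model = instantiate(cfg.model, _recursive_=True)  # resolves every nested _target_
model.eval().to("cuda" if torch.cuda.is_available() else "cpu")
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters (random init)")
```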