diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..1a66a3a9a1405f07aea398f222adbf6f449d3146
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,36 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000000000000000000000000000000000000..94a25f7f4cb416c083d265558da75d457237d671
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="$PROJECT_DIR$" vcs="Git" />
+  </component>
+</project>
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..43184f07307c95e7d3fd796e1372c4882026899c
--- /dev/null
+++ b/README.md
@@ -0,0 +1,14 @@
+---
+title: Ov Seg
+emoji: 📊
+colorFrom: red
+colorTo: pink
+sdk: gradio
+sdk_version: 3.8.2
+app_file: app.py
+pinned: false
+license: cc-by-nc-4.0
+duplicated_from: facebook/ov-seg
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd9b5260e98a6bbdce7c0dcdb36bd3780587b4d2
--- /dev/null
+++ b/app.py
@@ -0,0 +1,96 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+import multiprocessing as mp
+
+import numpy as np
+from PIL import Image
+
+
+try:
+    import detectron2
+except ImportError:
+    import os
+    os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
+
+from detectron2.config import get_cfg
+
+from detectron2.projects.deeplab import add_deeplab_config
+from detectron2.data.detection_utils import read_image
+from open_vocab_seg import add_ovseg_config
+from open_vocab_seg.utils import VisualizationDemo, SAMVisualizationDemo
+
+import gradio as gr
+
+import gdown
+
+# ckpt_url = 'https://drive.google.com/uc?id=1cn-ohxgXDrDfkzC1QdO-fi8IjbjXmgKy'
+# output = './ovseg_swinbase_vitL14_ft_mpt.pth'
+# gdown.download(ckpt_url, output, quiet=False)
+
+def setup_cfg(config_file):
+    # load config from file and command-line arguments
+    cfg = get_cfg()
+    add_deeplab_config(cfg)
+    add_ovseg_config(cfg)
+    cfg.merge_from_file(config_file)
+    cfg.freeze()
+    return cfg
+
+
+def inference(class_names, proposal_gen, granularity, input_img):
+    mp.set_start_method("spawn", force=True)
+    config_file = './ovseg_swinB_vitL_demo.yaml'
+    cfg = setup_cfg(config_file)
+    if proposal_gen == 'MaskFormer':
+        demo = VisualizationDemo(cfg)
+    elif proposal_gen == 'Segment_Anything':
+        demo = SAMVisualizationDemo(cfg, granularity, './sam_vit_l_0b3195.pth', './ovseg_clip_l_9a1909.pth')
+    class_names = class_names.split(',')
+    img = read_image(input_img, format="BGR")
+    _, visualized_output = demo.run_on_image(img, class_names)
+
+    return Image.fromarray(np.uint8(visualized_output.get_image())).convert('RGB')
+
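+# Usage sketch (illustrative only, not part of the Gradio wiring): `inference` can be
+# called directly with the same arguments the interface passes in, e.g.
+#   result = inference('Saturn V, toys, blossom', 'MaskFormer', 1.0,
+#                      './resources/demo_samples/sample_01.jpeg')
+#   result.save('segmentation_map.png')  # the return value is a PIL Image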
+
+examples = [['Saturn V, toys, desk, wall, sunflowers, white roses, chrysanthemums, carnations, green dianthus', 'Segment_Anything', 0.8, './resources/demo_samples/sample_01.jpeg'],
+            ['red bench, yellow bench, blue bench, brown bench, green bench, blue chair, yellow chair, green chair, brown chair, yellow square painting, barrel, buddha statue', 'Segment_Anything', 0.8, './resources/demo_samples/sample_04.png'],
+            ['pillow, pipe, sweater, shirt, jeans jacket, shoes, cabinet, handbag, photo frame', 'Segment_Anything', 0.8, './resources/demo_samples/sample_05.png'],
+            ['Saturn V, toys, blossom', 'MaskFormer', 1.0, './resources/demo_samples/sample_01.jpeg'],
+            ['Oculus, Ukulele', 'MaskFormer', 1.0, './resources/demo_samples/sample_03.jpeg'],
+            ['Golden gate, yacht', 'MaskFormer', 1.0, './resources/demo_samples/sample_02.jpeg'],]
+output_labels = ['segmentation map']
+
+title = 'OVSeg (+ Segment_Anything)'
+
+description = """
+[NEW!] We incorporate OVSeg CLIP with Segment_Anything, enabling text prompts for SAM.
+Gradio Demo for Open-Vocabulary Semantic Segmentation with Mask-adapted CLIP. \n
+OVSeg performs open-vocabulary segmentation: you may input any number of classes, separated by commas. You may click one of the examples or upload your own image. \n
+It might take some time to process. Cheers!
+<p>(Colab only supports MaskFormer proposal generator) Don't want to wait in queue? <a href="https://colab.research.google.com/drive/1O4Ain5uFZNcQYUmDTG92DpEGCatga8K5?usp=sharing"><img data-canonical-src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" src="https://camo.githubusercontent.com/84f0493939e0c4de4e6dbe113251b4bfb5353e57134ffd9fcab6b8714514d4d1/68747470733a2f2f636f6c61622e72657365617263682e676f6f676c652e636f6d2f6173736574732f636f6c61622d62616467652e737667"></a></p>
+"""
+
+article = """
+<p style='text-align: center'>
+<a href='https://arxiv.org/abs/2210.04150' target='_blank'>
+Open-Vocabulary Semantic Segmentation with Mask-adapted CLIP
+</a>
+|
+<a href='https://github.com/facebookresearch/ov-seg' target='_blank'>Github Repo</a></p>
+"""
+
+gr.Interface(
+    inference,
+    inputs=[
+        gr.Textbox(
+            lines=1, placeholder=None, value='', label='class names'),
+        gr.Radio(["Segment_Anything", "MaskFormer"], label="Proposal generator", value="Segment_Anything"),
+        gr.Slider(0, 1.0, 0.8, label="For Segment_Anything only, granularity of masks from 0 (coarsest) to 1 (finest)"),
+        gr.Image(type='filepath'),
+    ],
+    outputs=gr.outputs.Image(label='segmentation map'),
+    title=title,
+    description=description,
+    article=article,
+    examples=examples).launch(enable_queue=True)
diff --git a/open_vocab_seg/.DS_Store b/open_vocab_seg/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..571ecdfd99ba3951eec8fd0206519a409e86c2b1
Binary files /dev/null and b/open_vocab_seg/.DS_Store differ
diff --git a/open_vocab_seg/__init__.py b/open_vocab_seg/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b54fce14b8a029f1355bc8b74c20884e880ee9c4
--- /dev/null
+++ b/open_vocab_seg/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+from . import data
+from . import modeling
+from .config import add_ovseg_config
+
+from .test_time_augmentation import SemanticSegmentorWithTTA
+from .ovseg_model import OVSeg, OVSegDEMO
diff --git a/open_vocab_seg/config.py b/open_vocab_seg/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..400e9a05d4995e3f3401b34a22ae687b2c9c90e0
--- /dev/null
+++ b/open_vocab_seg/config.py
@@ -0,0 +1,133 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+from detectron2.config import CfgNode as CN
+
+
+def add_mask_former_default_config(cfg):
+    # data config
+    # select the dataset mapper
+    cfg.INPUT.DATASET_MAPPER_NAME = "mask_former_semantic"
+    # Color augmentation
+    cfg.INPUT.COLOR_AUG_SSD = False
+    # We retry random cropping until no single category in semantic segmentation GT occupies more
+    # than `SINGLE_CATEGORY_MAX_AREA` part of the crop.
+    cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
+    # Pad image and segmentation GT in dataset mapper.
+    cfg.INPUT.SIZE_DIVISIBILITY = -1
+
+    # solver config
+    # test batch size
+    cfg.SOLVER.TEST_IMS_PER_BATCH = 1
+    # weight decay on embedding
+    cfg.SOLVER.WEIGHT_DECAY_EMBED = 0.0
+    # optimizer
+    cfg.SOLVER.OPTIMIZER = "ADAMW"
+    cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1
+
+    # mask_former model config
+    cfg.MODEL.MASK_FORMER = CN()
+
+    # loss
+    cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION = True
+    cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT = 0.1
+    cfg.MODEL.MASK_FORMER.DICE_WEIGHT = 1.0
+    cfg.MODEL.MASK_FORMER.MASK_WEIGHT = 20.0
+
+    # transformer config
+    cfg.MODEL.MASK_FORMER.NHEADS = 8
+    cfg.MODEL.MASK_FORMER.DROPOUT = 0.1
+    cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD = 2048
+    cfg.MODEL.MASK_FORMER.ENC_LAYERS = 0
+    cfg.MODEL.MASK_FORMER.DEC_LAYERS = 6
+    cfg.MODEL.MASK_FORMER.PRE_NORM = False
+
+    cfg.MODEL.MASK_FORMER.HIDDEN_DIM = 256
+    cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES = 100
+
+    cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE = "res5"
+    cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ = False
+
+    # mask_former inference config
+    cfg.MODEL.MASK_FORMER.TEST = CN()
+    cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON = False
+    cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD = 0.0
+    cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD = 0.0
+    cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE = False
+
+    # Sometimes `backbone.size_divisibility` is set to 0 for some backbone (e.g. ResNet)
+    # you can use this config to override
+    cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY = 32
+
+    # pixel decoder config
+    cfg.MODEL.SEM_SEG_HEAD.MASK_DIM = 256
+    # adding transformer in pixel decoder
+    cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS = 0
+    # pixel decoder
+    cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME = "BasePixelDecoder"
+
+    # swin transformer backbone
+    cfg.MODEL.SWIN = CN()
+    cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE = 224
+    cfg.MODEL.SWIN.PATCH_SIZE = 4
+    cfg.MODEL.SWIN.EMBED_DIM = 96
+    cfg.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
+    cfg.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
+    cfg.MODEL.SWIN.WINDOW_SIZE = 7
+    cfg.MODEL.SWIN.MLP_RATIO = 4.0
+    cfg.MODEL.SWIN.QKV_BIAS = True
+    cfg.MODEL.SWIN.QK_SCALE = None
+    cfg.MODEL.SWIN.NORM_INDICES = None
+    cfg.MODEL.SWIN.PROJECTION = False
+    cfg.MODEL.SWIN.PROJECT_DIM = 256
+    cfg.MODEL.SWIN.DROP_RATE = 0.0
+    cfg.MODEL.SWIN.ATTN_DROP_RATE = 0.0
+    cfg.MODEL.SWIN.DROP_PATH_RATE = 0.3
+    cfg.MODEL.SWIN.APE = False
+    cfg.MODEL.SWIN.PATCH_NORM = True
+    cfg.MODEL.SWIN.OUT_FEATURES = ["res2", "res3", "res4", "res5"]
+
+
+def add_our_config(cfg):
+    cfg.TEST.SLIDING_WINDOW = False
+    cfg.TEST.SLIDING_TILE_SIZE = 224
+    cfg.TEST.SLIDING_OVERLAP = 2 / 3.0
+    # whether to use dense crf
+    cfg.TEST.DENSE_CRF = False
+    cfg.DATASETS.SAMPLE_PER_CLASS = -1
+    cfg.DATASETS.SAMPLE_SEED = 0
+    # embedding head
+    cfg.MODEL.SEM_SEG_HEAD.EMBEDDING_DIM = 512
+    cfg.MODEL.SEM_SEG_HEAD.EMBED_HIDDEN_DIM = 1024
+    cfg.MODEL.SEM_SEG_HEAD.EMBED_LAYERS = 2
+    # clip_adapter
+    cfg.MODEL.CLIP_ADAPTER = CN()
+    cfg.MODEL.CLIP_ADAPTER.TEXT_TEMPLATES = "vild"
+    # for predefined
+    cfg.MODEL.CLIP_ADAPTER.PREDEFINED_PROMPT_TEMPLATES = ["a photo of a {}."]
+    # for learnable prompt
+    cfg.MODEL.CLIP_ADAPTER.PROMPT_CHECKPOINT = ""
+    cfg.MODEL.CLIP_ADAPTER.CLIP_MODEL_NAME = "ViT-B/16"
+    cfg.MODEL.CLIP_ADAPTER.MASK_FILL = "mean"
+    cfg.MODEL.CLIP_ADAPTER.MASK_EXPAND_RATIO = 1.0
+    cfg.MODEL.CLIP_ADAPTER.MASK_THR = 0.4
+    cfg.MODEL.CLIP_ADAPTER.MASK_MATTING = False
+    cfg.MODEL.CLIP_ADAPTER.REGION_RESIZED = True
+    cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE = True
+    cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE_WEIGHT = 0.7
+    # for mask prompt
+    cfg.MODEL.CLIP_ADAPTER.MASK_PROMPT_DEPTH = 3
+    cfg.MODEL.CLIP_ADAPTER.MASK_PROMPT_FWD = False
+
+    # wandb
+    cfg.WANDB = CN()
+    cfg.WANDB.PROJECT = "open_vocab_seg"
+    cfg.WANDB.NAME = None
+
+
+def add_ovseg_config(cfg):
+    """
+    Add config for open_vocab_seg.
+    """
+    add_mask_former_default_config(cfg)
+    add_our_config(cfg)
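+
+# Minimal usage sketch (mirrors setup_cfg in app.py): after the defaults are added, any
+# key above can be overridden from a YAML file or a key/value list, e.g.
+#   cfg = get_cfg()            # from detectron2.config
+#   add_deeplab_config(cfg)    # from detectron2.projects.deeplab
+#   add_ovseg_config(cfg)
+#   cfg.merge_from_list(["MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE_WEIGHT", 0.8])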
diff --git a/open_vocab_seg/data/.DS_Store b/open_vocab_seg/data/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..e3548df3784b4020dc5b9d6383241cbe099cb0df
Binary files /dev/null and b/open_vocab_seg/data/.DS_Store differ
diff --git a/open_vocab_seg/data/__init__.py b/open_vocab_seg/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..970e2c8ce7f90afab089bf84e249af5ee7124951
--- /dev/null
+++ b/open_vocab_seg/data/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+from .dataset_mappers import *
+from . import datasets
+from .build import (
+    build_detection_train_loader,
+    build_detection_test_loader,
+)
diff --git a/open_vocab_seg/data/augmentations.py b/open_vocab_seg/data/augmentations.py
new file mode 100644
index 0000000000000000000000000000000000000000..44e4906d4827812fa707f50e703f253a64ab6e43
--- /dev/null
+++ b/open_vocab_seg/data/augmentations.py
@@ -0,0 +1,202 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+import math
+import numbers
+import numpy as np
+from detectron2.data.transforms.augmentation import Augmentation
+from detectron2.data.transforms.transform import (
+    CropTransform,
+    ResizeTransform,
+    TransformList,
+)
+from PIL import Image
+from fvcore.transforms.transform import PadTransform
+
+
+def mask2box(mask: np.ndarray):
+    # use naive way
+    row = np.nonzero(mask.sum(axis=0))[0]
+    if len(row) == 0:
+        return None
+    x1 = row.min()
+    x2 = row.max()
+    col = np.nonzero(mask.sum(axis=1))[0]
+    y1 = col.min()
+    y2 = col.max()
+    return x1, y1, x2 + 1 - x1, y2 + 1 - y1
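+# Worked example (illustration only): for a 5x5 boolean mask whose True pixels cover
+# rows 1-2 and columns 2-3, the nonzero columns give x1=2, x2=3 and the nonzero rows
+# give y1=1, y2=2, so mask2box returns (2, 1, 2, 2) as (x, y, w, h).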
+
+
+def expand_box(x, y, w, h, expand_ratio=1.0, max_h=None, max_w=None):
+    cx = x + 0.5 * w
+    cy = y + 0.5 * h
+    w = w * expand_ratio
+    h = h * expand_ratio
+    box = [cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h]
+    if max_h is not None:
+        box[1] = max(0, box[1])
+        box[3] = min(max_h - 1, box[3])
+    if max_w is not None:
+        box[0] = max(0, box[0])
+        box[2] = min(max_w - 1, box[2])
+    box[2] = box[2] - box[0]
+    box[3] = box[3] - box[1]
+
+    return [int(b) for b in box]
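+# Worked example (illustration only): expand_box(10, 10, 20, 20, expand_ratio=1.5)
+# keeps the center at (20, 20) and scales the box to 30x30, returning [5, 5, 30, 30];
+# when max_h / max_w are given, the expanded corners are clipped to the image first.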
+
+
+class CropImageWithMask(Augmentation):
+    def __init__(self, expand_ratio=1.0, mode="choice"):
+        if isinstance(expand_ratio, numbers.Number):
+            expand_ratio = (expand_ratio, expand_ratio)
+        self.mode = mode
+        self.expand_ratio = expand_ratio
+        if self.mode == "range":
+            assert len(expand_ratio) == 2 and expand_ratio[0] < expand_ratio[1]
+
+    def get_transform(self, image, sem_seg, category_id):
+        input_size = image.shape[:2]
+        bin_mask = sem_seg == category_id
+        x, y, w, h = mask2box(bin_mask)
+        if self.mode == "choice":
+            expand_ratio = np.random.choice(self.expand_ratio)
+        else:
+            expand_ratio = np.random.uniform(self.expand_ratio[0], self.expand_ratio[1])
+        x, y, w, h = expand_box(x, y, w, h, expand_ratio, *input_size)
+        w = max(w, 1)
+        h = max(h, 1)
+        return CropTransform(x, y, w, h, input_size[1], input_size[0])
+
+
+class CropImageWithBox(Augmentation):
+    def __init__(self, expand_ratio=1.0, mode="choice"):
+        if isinstance(expand_ratio, numbers.Number):
+            expand_ratio = (expand_ratio, expand_ratio)
+        self.mode = mode
+        self.expand_ratio = expand_ratio
+        if self.mode == "range":
+            assert len(expand_ratio) == 2 and expand_ratio[0] < expand_ratio[1]
+
+    def get_transform(self, image, boxes):
+        input_size = image.shape[:2]
+        x, y, x2, y2 = boxes[0]
+        w = x2 - x + 1
+        h = y2 - y + 1
+        if self.mode == "choice":
+            expand_ratio = np.random.choice(self.expand_ratio)
+        else:
+            expand_ratio = np.random.uniform(self.expand_ratio[0], self.expand_ratio[1])
+        x, y, w, h = expand_box(x, y, w, h, expand_ratio, *input_size)
+        w = max(w, 1)
+        h = max(h, 1)
+        return CropTransform(x, y, w, h, input_size[1], input_size[0])
+
+
+class RandomResizedCrop(Augmentation):
+    def __init__(
+        self,
+        size,
+        scale=(0.08, 1.0),
+        ratio=(3.0 / 4.0, 4.0 / 3.0),
+        interpolation=Image.BILINEAR,
+    ):
+        if isinstance(size, int):
+            size = (size, size)
+        else:
+            assert isinstance(size, (tuple, list)) and len(size) == 2
+
+        self.size = size
+
+        self.scale = scale
+        self.ratio = ratio
+        self.interpolation = interpolation
+
+    def get_transform(self, image):
+        height, width = image.shape[:2]
+        area = height * width
+
+        log_ratio = np.log(np.array(self.ratio))
+        is_success = False
+        for _ in range(10):
+            target_area = area * np.random.uniform(self.scale[0], self.scale[1])
+            aspect_ratio = np.exp(np.random.uniform(log_ratio[0], log_ratio[1]))
+
+            w = int(round(math.sqrt(target_area * aspect_ratio)))
+            h = int(round(math.sqrt(target_area / aspect_ratio)))
+
+            if 0 < w <= width and 0 < h <= height:
+                i = np.random.randint(0, width - w + 1)
+                j = np.random.randint(0, height - h + 1)
+
+                is_success = True
+                break
+
+        if not is_success:
+            # Fallback to central crop
+            in_ratio = float(width) / float(height)
+            if in_ratio < min(self.ratio):
+                w = width
+                h = int(round(w / min(self.ratio)))
+            elif in_ratio > max(self.ratio):
+                h = height
+                w = int(round(h * max(self.ratio)))
+            else:  # whole image
+                w = width
+                h = height
+            i = (width - w) // 2
+            j = (height - h) // 2
+        return TransformList(
+            [
+                CropTransform(i, j, w, h, width, height),
+                ResizeTransform(
+                    h, w, self.size[1], self.size[0], interp=self.interpolation
+                ),
+            ]
+        )
+
+
+class CenterCrop(Augmentation):
+    def __init__(self, size, seg_ignore_label):
+        if isinstance(size, numbers.Number):
+            size = (int(size), int(size))
+        elif isinstance(size, (tuple, list)) and len(size) == 1:
+            size = (size[0], size[0])
+        self.size = size
+        self.seg_ignore_label = seg_ignore_label
+
+    def get_transform(self, image):
+
+        image_height, image_width = image.shape[:2]
+        crop_height, crop_width = self.size
+
+        transforms = []
+        if crop_width > image_width or crop_height > image_height:
+            padding_ltrb = [
+                (crop_width - image_width) // 2 if crop_width > image_width else 0,
+                (crop_height - image_height) // 2 if crop_height > image_height else 0,
+                (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
+                (crop_height - image_height + 1) // 2
+                if crop_height > image_height
+                else 0,
+            ]
+            transforms.append(
+                PadTransform(
+                    *padding_ltrb,
+                    orig_w=image_width,
+                    orig_h=image_height,
+                    seg_pad_value=self.seg_ignore_label
+                )
+            )
+            image_width, image_height = (
+                image_width + padding_ltrb[0] + padding_ltrb[2],
+                image_height + padding_ltrb[1] + padding_ltrb[3],
+            )
+
+        crop_top = int(round((image_height - crop_height) / 2.0))
+        crop_left = int(round((image_width - crop_width) / 2.0))
+        transforms.append(
+            CropTransform(
+                crop_left, crop_top, crop_width, crop_height, image_width, image_height
+            )
+        )
+        return TransformList(transforms)
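+
+
+# Usage sketch (illustrative): CenterCrop pads the ground truth with `seg_ignore_label`
+# when the image is smaller than the target size, then takes a center crop, e.g.
+#   aug = CenterCrop(224, seg_ignore_label=255)
+#   tfm = aug.get_transform(image)       # image: HxWxC numpy array
+#   cropped = tfm.apply_image(image)     # 224x224 result (padded first if needed)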
diff --git a/open_vocab_seg/data/build.py b/open_vocab_seg/data/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcd3b9dcebb86c319b91a632c25bcf7827292c3f
--- /dev/null
+++ b/open_vocab_seg/data/build.py
@@ -0,0 +1,344 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+import itertools
+import logging
+import numpy as np
+from collections import Counter
+import torch.utils.data
+from tabulate import tabulate
+from termcolor import colored
+
+from detectron2.utils.logger import _log_api_usage, log_first_n
+from detectron2.data.catalog import DatasetCatalog, MetadataCatalog
+from detectron2.config import configurable
+from detectron2.data.build import (
+    build_batch_data_loader,
+    trivial_batch_collator,
+    load_proposals_into_dataset,
+    filter_images_with_only_crowd_annotations,
+    filter_images_with_few_keypoints,
+    print_instances_class_histogram,
+)
+
+from detectron2.data.common import DatasetFromList, MapDataset
+from detectron2.data.dataset_mapper import DatasetMapper
+from detectron2.data.detection_utils import check_metadata_consistency
+from detectron2.data.samplers import (
+    InferenceSampler,
+    RandomSubsetTrainingSampler,
+    RepeatFactorTrainingSampler,
+    TrainingSampler,
+)
+
+"""
+This file contains the default logic to build a dataloader for training or testing.
+"""
+
+__all__ = [
+    "build_detection_train_loader",
+    "build_detection_test_loader",
+]
+
+
+def print_classification_instances_class_histogram(dataset_dicts, class_names):
+    """
+    Args:
+        dataset_dicts (list[dict]): list of dataset dicts.
+        class_names (list[str]): list of class names (zero-indexed).
+    """
+    num_classes = len(class_names)
+    hist_bins = np.arange(num_classes + 1)
+    histogram = np.zeros((num_classes,), dtype=int)
+    for entry in dataset_dicts:
+        classes = np.asarray([entry["category_id"]], dtype=int)
+        if len(classes):
+            assert classes.min() >= 0, f"Got an invalid category_id={classes.min()}"
+            assert (
+                classes.max() < num_classes
+            ), f"Got an invalid category_id={classes.max()} for a dataset of {num_classes} classes"
+        histogram += np.histogram(classes, bins=hist_bins)[0]
+
+    N_COLS = min(6, len(class_names) * 2)
+
+    def short_name(x):
+        # make long class names shorter. useful for lvis
+        if len(x) > 13:
+            return x[:11] + ".."
+        return x
+
+    data = list(
+        itertools.chain(
+            *[[short_name(class_names[i]), int(v)] for i, v in enumerate(histogram)]
+        )
+    )
+    total_num_instances = sum(data[1::2])
+    data.extend([None] * (N_COLS - (len(data) % N_COLS)))
+    if num_classes > 1:
+        data.extend(["total", total_num_instances])
+    data = itertools.zip_longest(*[data[i::N_COLS] for i in range(N_COLS)])
+    table = tabulate(
+        data,
+        headers=["category", "#instances"] * (N_COLS // 2),
+        tablefmt="pipe",
+        numalign="left",
+        stralign="center",
+    )
+    log_first_n(
+        logging.INFO,
+        "Distribution of instances among all {} categories:\n".format(num_classes)
+        + colored(table, "cyan"),
+        key="message",
+    )
+
+
+def wrap_metas(dataset_dict, **kwargs):
+    def _assign_attr(data_dict: dict, **kwargs):
+        assert not any(
+            [key in data_dict for key in kwargs]
+        ), "Assigned attributes should not exist in the original sample."
+        data_dict.update(kwargs)
+        return data_dict
+
+    return [_assign_attr(sample, meta=kwargs) for sample in dataset_dict]
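+# Worked example (illustration only): wrap_metas([{"file_name": "a.jpg"}], dataset_name="my_set")
+# returns [{"file_name": "a.jpg", "meta": {"dataset_name": "my_set"}}]; the assertion
+# fails if a sample already carries a "meta" key.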
+
+
+def get_detection_dataset_dicts(
+    names, filter_empty=True, min_keypoints=0, proposal_files=None
+):
+    """
+    Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.
+
+    Args:
+        names (str or list[str]): a dataset name or a list of dataset names
+        filter_empty (bool): whether to filter out images without instance annotations
+        min_keypoints (int): filter out images with fewer keypoints than
+            `min_keypoints`. Set to 0 to do nothing.
+        proposal_files (list[str]): if given, a list of object proposal files
+            that match each dataset in `names`.
+
+    Returns:
+        list[dict]: a list of dicts following the standard dataset dict format.
+    """
+    if isinstance(names, str):
+        names = [names]
+    assert len(names), names
+    dataset_dicts = [
+        wrap_metas(DatasetCatalog.get(dataset_name), dataset_name=dataset_name)
+        for dataset_name in names
+    ]
+    for dataset_name, dicts in zip(names, dataset_dicts):
+        assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)
+
+    if proposal_files is not None:
+        assert len(names) == len(proposal_files)
+        # load precomputed proposals from proposal files
+        dataset_dicts = [
+            load_proposals_into_dataset(dataset_i_dicts, proposal_file)
+            for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
+        ]
+
+    dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))
+
+    has_instances = "annotations" in dataset_dicts[0]
+    if filter_empty and has_instances:
+        dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
+    if min_keypoints > 0 and has_instances:
+        dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)
+
+    if has_instances:
+        try:
+            class_names = MetadataCatalog.get(names[0]).thing_classes
+            check_metadata_consistency("thing_classes", names)
+            print_instances_class_histogram(dataset_dicts, class_names)
+        except AttributeError:  # class names are not available for this dataset
+            pass
+
+    assert len(dataset_dicts), "No valid data found in {}.".format(",".join(names))
+    return dataset_dicts
+
+
+def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
+    if dataset is None:
+        dataset = get_detection_dataset_dicts(
+            cfg.DATASETS.TRAIN,
+            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
+            min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
+            if cfg.MODEL.KEYPOINT_ON
+            else 0,
+            proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN
+            if cfg.MODEL.LOAD_PROPOSALS
+            else None,
+        )
+        _log_api_usage("dataset." + cfg.DATASETS.TRAIN[0])
+
+    if mapper is None:
+        mapper = DatasetMapper(cfg, True)
+
+    if sampler is None:
+        sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
+        logger = logging.getLogger(__name__)
+        logger.info("Using training sampler {}".format(sampler_name))
+        if sampler_name == "TrainingSampler":
+            sampler = TrainingSampler(len(dataset))
+        elif sampler_name == "RepeatFactorTrainingSampler":
+            repeat_factors = (
+                RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
+                    dataset, cfg.DATALOADER.REPEAT_THRESHOLD
+                )
+            )
+            sampler = RepeatFactorTrainingSampler(repeat_factors)
+        elif sampler_name == "RandomSubsetTrainingSampler":
+            sampler = RandomSubsetTrainingSampler(
+                len(dataset), cfg.DATALOADER.RANDOM_SUBSET_RATIO
+            )
+        else:
+            raise ValueError("Unknown training sampler: {}".format(sampler_name))
+
+    return {
+        "dataset": dataset,
+        "sampler": sampler,
+        "mapper": mapper,
+        "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
+        "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
+        "num_workers": cfg.DATALOADER.NUM_WORKERS,
+    }
+
+
+# TODO can allow dataset as an iterable or IterableDataset to make this function more general
+@configurable(from_config=_train_loader_from_config)
+def build_detection_train_loader(
+    dataset,
+    *,
+    mapper,
+    sampler=None,
+    total_batch_size,
+    aspect_ratio_grouping=True,
+    num_workers=0,
+):
+    """
+    Build a dataloader for object detection with some default features.
+    This interface is experimental.
+
+    Args:
+        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
+            or a map-style pytorch dataset. They can be obtained by using
+            :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
+        mapper (callable): a callable which takes a sample (dict) from dataset and
+            returns the format to be consumed by the model.
+            When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``.
+        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
+            indices to be applied on ``dataset``. Default to :class:`TrainingSampler`,
+            which coordinates an infinite random shuffle sequence across all workers.
+        total_batch_size (int): total batch size across all workers. Batching
+            simply puts data into a list.
+        aspect_ratio_grouping (bool): whether to group images with similar
+            aspect ratio for efficiency. When enabled, it requires each
+            element in dataset be a dict with keys "width" and "height".
+        num_workers (int): number of parallel data loading workers
+
+    Returns:
+        torch.utils.data.DataLoader:
+            a dataloader. Each output from it is a ``list[mapped_element]`` of length
+            ``total_batch_size`` divided by the number of distributed workers (GPUs),
+            where ``mapped_element`` is produced by the ``mapper``.
+    """
+    if isinstance(dataset, list):
+        dataset = DatasetFromList(dataset, copy=False)
+    if mapper is not None:
+        dataset = MapDataset(dataset, mapper)
+    if sampler is None:
+        sampler = TrainingSampler(len(dataset))
+    assert isinstance(sampler, torch.utils.data.sampler.Sampler)
+    return build_batch_data_loader(
+        dataset,
+        sampler,
+        total_batch_size,
+        aspect_ratio_grouping=aspect_ratio_grouping,
+        num_workers=num_workers,
+    )
+
+
+def _test_loader_from_config(cfg, dataset_name, mapper=None):
+    """
+    Uses the given `dataset_name` argument (instead of the names in cfg), because the
+    standard practice is to evaluate each test set individually (not combining them).
+    """
+    if isinstance(dataset_name, str):
+        dataset_name = [dataset_name]
+
+    dataset = get_detection_dataset_dicts(
+        dataset_name,
+        filter_empty=False,
+        proposal_files=[
+            cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(x)]
+            for x in dataset_name
+        ]
+        if cfg.MODEL.LOAD_PROPOSALS
+        else None,
+    )
+    if mapper is None:
+        mapper = DatasetMapper(cfg, False)
+    return {
+        "dataset": dataset,
+        "mapper": mapper,
+        "num_workers": 0,
+        "samples_per_gpu": cfg.SOLVER.TEST_IMS_PER_BATCH,
+    }
+
+
+@configurable(from_config=_test_loader_from_config)
+def build_detection_test_loader(
+    dataset, *, mapper, sampler=None, num_workers=0, samples_per_gpu=1
+):
+    """
+    Similar to `build_detection_train_loader`, but uses a batch size of ``samples_per_gpu`` (1 by default),
+    and :class:`InferenceSampler`. This sampler coordinates all workers to
+    produce the exact set of all samples.
+    This interface is experimental.
+
+    Args:
+        dataset (list or torch.utils.data.Dataset): a list of dataset dicts,
+            or a map-style pytorch dataset. They can be obtained by using
+            :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`.
+        mapper (callable): a callable which takes a sample (dict) from dataset
+           and returns the format to be consumed by the model.
+           When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``.
+        sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces
+            indices to be applied on ``dataset``. Default to :class:`InferenceSampler`,
+            which splits the dataset across all workers.
+        num_workers (int): number of parallel data loading workers
+        samples_per_gpu (int): number of samples batched together per worker
+
+    Returns:
+        DataLoader: a torch DataLoader, that loads the given detection
+        dataset, with test-time transformation and batching.
+
+    Examples:
+    ::
+        data_loader = build_detection_test_loader(
+            DatasetRegistry.get("my_test"),
+            mapper=DatasetMapper(...))
+
+        # or, instantiate with a CfgNode:
+        data_loader = build_detection_test_loader(cfg, "my_test")
+    """
+    if isinstance(dataset, list):
+        dataset = DatasetFromList(dataset, copy=False)
+    if mapper is not None:
+        dataset = MapDataset(dataset, mapper)
+    if sampler is None:
+        sampler = InferenceSampler(len(dataset))
+    # Batch ``samples_per_gpu`` images per worker during inference (1 by default,
+    # which is the standard when reporting inference time in papers).
+    batch_sampler = torch.utils.data.sampler.BatchSampler(
+        sampler, samples_per_gpu, drop_last=False
+    )
+    data_loader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=num_workers,
+        batch_sampler=batch_sampler,
+        collate_fn=trivial_batch_collator,
+    )
+    return data_loader
+
diff --git a/open_vocab_seg/data/dataset_mappers/__init__.py b/open_vocab_seg/data/dataset_mappers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f63cd5c034fcb60af8c78431205ae9b410f33250
--- /dev/null
+++ b/open_vocab_seg/data/dataset_mappers/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+from .mask_former_semantic_dataset_mapper import MaskFormerSemanticDatasetMapper
diff --git a/open_vocab_seg/data/dataset_mappers/mask_former_semantic_dataset_mapper.py b/open_vocab_seg/data/dataset_mappers/mask_former_semantic_dataset_mapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..2836579942cf91c726cb34cbbd2d137c975bee37
--- /dev/null
+++ b/open_vocab_seg/data/dataset_mappers/mask_former_semantic_dataset_mapper.py
@@ -0,0 +1,208 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+import copy
+import logging
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+
+from detectron2.config import configurable
+from detectron2.data import MetadataCatalog
+from detectron2.data import detection_utils as utils
+from detectron2.data import transforms as T
+from detectron2.projects.point_rend import ColorAugSSDTransform
+from detectron2.structures import BitMasks, Instances
+
+__all__ = ["MaskFormerSemanticDatasetMapper"]
+
+
+class MaskFormerSemanticDatasetMapper:
+    """
+    A callable which takes a dataset dict in Detectron2 Dataset format,
+    and maps it into a format used by MaskFormer for semantic segmentation.
+
+    The callable currently does the following:
+
+    1. Read the image from "file_name"
+    2. Apply geometric transforms to the image and annotation
+    3. Find and apply suitable cropping to the image and annotation
+    4. Prepare the image and annotation as Tensors
+    """
+
+    @configurable
+    def __init__(
+        self,
+        is_train=True,
+        *,
+        augmentations,
+        image_format,
+        ignore_label,
+        size_divisibility,
+    ):
+        """
+        NOTE: this interface is experimental.
+        Args:
+            is_train: for training or inference
+            augmentations: a list of augmentations or deterministic transforms to apply
+            image_format: an image format supported by :func:`detection_utils.read_image`.
+            ignore_label: the label that is ignored during evaluation
+            size_divisibility: pad image size to be divisible by this value
+        """
+        self.is_train = is_train
+        self.tfm_gens = augmentations
+        self.img_format = image_format
+        self.ignore_label = ignore_label
+        self.size_divisibility = size_divisibility
+
+        logger = logging.getLogger(__name__)
+        mode = "training" if is_train else "inference"
+        logger.info(
+            f"[{self.__class__.__name__}] Augmentations used in {mode}: {augmentations}"
+        )
+
+    @classmethod
+    def from_config(cls, cfg, is_train=True):
+        # Build augmentation
+        if is_train:
+            augs = [
+                T.ResizeShortestEdge(
+                    cfg.INPUT.MIN_SIZE_TRAIN,
+                    cfg.INPUT.MAX_SIZE_TRAIN,
+                    cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING,
+                )
+            ]
+            if cfg.INPUT.CROP.ENABLED:
+                augs.append(
+                    T.RandomCrop_CategoryAreaConstraint(
+                        cfg.INPUT.CROP.TYPE,
+                        cfg.INPUT.CROP.SIZE,
+                        cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA,
+                        cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
+                    )
+                )
+            if cfg.INPUT.COLOR_AUG_SSD:
+                augs.append(ColorAugSSDTransform(img_format=cfg.INPUT.FORMAT))
+            augs.append(T.RandomFlip())
+
+            # Assume always applies to the training set.
+            dataset_names = cfg.DATASETS.TRAIN
+        else:
+            min_size = cfg.INPUT.MIN_SIZE_TEST
+            max_size = cfg.INPUT.MAX_SIZE_TEST
+            sample_style = "choice"
+            augs = [T.ResizeShortestEdge(min_size, max_size, sample_style)]
+            dataset_names = cfg.DATASETS.TEST
+        meta = MetadataCatalog.get(dataset_names[0])
+        ignore_label = meta.ignore_label
+
+        ret = {
+            "is_train": is_train,
+            "augmentations": augs,
+            "image_format": cfg.INPUT.FORMAT,
+            "ignore_label": ignore_label,
+            "size_divisibility": cfg.INPUT.SIZE_DIVISIBILITY if is_train else -1,
+        }
+        return ret
+
+    def __call__(self, dataset_dict):
+        """
+        Args:
+            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
+
+        Returns:
+            dict: a format that builtin models in detectron2 accept
+        """
+        # assert self.is_train, "MaskFormerSemanticDatasetMapper should only be used for training!"
+
+        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
+        image = utils.read_image(dataset_dict["file_name"], format=self.img_format)
+        utils.check_image_size(dataset_dict, image)
+
+        if "sem_seg_file_name" in dataset_dict:
+            # PyTorch transformation not implemented for uint16, so converting it to double first
+            sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name")).astype(
+                "double"
+            )
+        else:
+            sem_seg_gt = None
+
+        if sem_seg_gt is None:
+            raise ValueError(
+                "Cannot find 'sem_seg_file_name' for semantic segmentation dataset {}.".format(
+                    dataset_dict["file_name"]
+                )
+            )
+
+        aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
+        aug_input, transforms = T.apply_transform_gens(self.tfm_gens, aug_input)
+        image = aug_input.image
+        sem_seg_gt = aug_input.sem_seg
+
+        # Pad image and segmentation label here!
+        image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
+        if sem_seg_gt is not None:
+            sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
+
+        if self.size_divisibility > 0:
+            image_size = (image.shape[-2], image.shape[-1])
+            padding_size = [
+                0,
+                self.size_divisibility - image_size[1],
+                0,
+                self.size_divisibility - image_size[0],
+            ]
+            image = F.pad(image, padding_size, value=128).contiguous()
+            if sem_seg_gt is not None:
+                sem_seg_gt = F.pad(
+                    sem_seg_gt, padding_size, value=self.ignore_label
+                ).contiguous()
+
+        image_shape = (image.shape[-2], image.shape[-1])  # h, w
+
+        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
+        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
+        # Therefore it's important to use torch.Tensor.
+        dataset_dict["image"] = image
+
+        if sem_seg_gt is not None:
+            dataset_dict["sem_seg"] = sem_seg_gt.long()
+
+        if "annotations" in dataset_dict:
+            raise ValueError(
+                "Semantic segmentation dataset should not have 'annotations'."
+            )
+
+        # Prepare per-category binary masks
+        if sem_seg_gt is not None:
+            sem_seg_gt = sem_seg_gt.numpy()
+            instances = Instances(image_shape)
+            classes = np.unique(sem_seg_gt)
+            # remove ignored region
+            classes = classes[classes != self.ignore_label]
+            instances.gt_classes = torch.tensor(classes, dtype=torch.int64)
+
+            masks = []
+            for class_id in classes:
+                masks.append(sem_seg_gt == class_id)
+
+            if len(masks) == 0:
+                # Some images have no annotation (all pixels ignored)
+                instances.gt_masks = torch.zeros(
+                    (0, sem_seg_gt.shape[-2], sem_seg_gt.shape[-1])
+                )
+            else:
+                masks = BitMasks(
+                    torch.stack(
+                        [
+                            torch.from_numpy(np.ascontiguousarray(x.copy()))
+                            for x in masks
+                        ]
+                    )
+                )
+                instances.gt_masks = masks.tensor
+
+            dataset_dict["instances"] = instances
+
+        return dataset_dict
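+
+
+# Usage sketch (assumes an ovseg cfg and a registered training dataset that provides
+# `ignore_label` metadata); built from cfg via detectron2's @configurable mechanism:
+#   mapper = MaskFormerSemanticDatasetMapper(cfg, is_train=True)
+#   loader = build_detection_train_loader(cfg, mapper=mapper)   # open_vocab_seg.data.build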
diff --git a/open_vocab_seg/data/datasets/__init__.py b/open_vocab_seg/data/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..352792b6fcdbffefa229d5d67a5c7375769fa345
--- /dev/null
+++ b/open_vocab_seg/data/datasets/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from . import register_coco_stuff, register_voc_seg
+from . import register_cc3m
+from . import register_ade20k_full
+from . import register_pascal_context
\ No newline at end of file
diff --git a/open_vocab_seg/data/datasets/csv_data.py b/open_vocab_seg/data/datasets/csv_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a4c9e52b0b792d49c48fe8bc2693be5ea879581
--- /dev/null
+++ b/open_vocab_seg/data/datasets/csv_data.py
@@ -0,0 +1,459 @@
+#   Copyright (c) Meta Platforms, Inc. and affiliates.
+import ast
+import json
+import logging
+import math
+import os
+import random
+import sys
+import time
+from dataclasses import dataclass
+from multiprocessing import Value
+
+import braceexpand
+import numpy as np
+import pandas as pd
+import torch
+import torchvision.datasets as datasets
+import webdataset as wds
+from PIL import Image
+from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, IterableDataset, get_worker_info
+from torch.utils.data.distributed import DistributedSampler
+from webdataset.filters import _shuffle
+from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample
+
+try:
+    import horovod.torch as hvd
+except ImportError:
+    hvd = None
+
+from clip import tokenize
+
+
+class CsvDataset(Dataset):
+    def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"):
+        logging.debug(f'Loading csv data from {input_filename}.')
+        df = pd.read_csv(input_filename, sep=sep)
+
+        self.images = df[img_key].tolist()
+        self.captions = df[caption_key].tolist()
+        self.transforms = transforms
+        logging.debug('Done loading data.')
+
+    def __len__(self):
+        return len(self.captions)
+
+    def __getitem__(self, idx):
+        images = self.transforms(Image.open(str(self.images[idx])))
+        texts = tokenize([str(self.captions[idx])])[0]
+        return images, texts
+
+
+class SharedEpoch:
+    def __init__(self, epoch: int = 0):
+        self.shared_epoch = Value('i', epoch)
+
+    def set_value(self, epoch):
+        self.shared_epoch.value = epoch
+
+    def get_value(self):
+        return self.shared_epoch.value
+
+
+@dataclass
+class DataInfo:
+    dataloader: DataLoader
+    sampler: DistributedSampler = None
+    shared_epoch: SharedEpoch = None
+
+    def set_epoch(self, epoch):
+        if self.shared_epoch is not None:
+            self.shared_epoch.set_value(epoch)
+        if self.sampler is not None and isinstance(self.sampler, DistributedSampler):
+            self.sampler.set_epoch(epoch)
+
+
+def preprocess_txt(text):
+    return tokenize([str(text)])[0]
+
+
+def get_dataset_size(shards):
+    shards_list = list(braceexpand.braceexpand(shards))
+    dir_path = os.path.dirname(shards)
+    sizes_filename = os.path.join(dir_path, 'sizes.json')
+    len_filename = os.path.join(dir_path, '__len__')
+    if os.path.exists(sizes_filename):
+        sizes = json.load(open(sizes_filename, 'r'))
+        total_size = sum([int(sizes[os.path.basename(shard)]) for shard in shards_list])
+    elif os.path.exists(len_filename):
+        # FIXME this used to be eval(open(...)) but that seemed rather unsafe
+        total_size = ast.literal_eval(open(len_filename, 'r').read())
+    else:
+        total_size = None  # num samples undefined
+        # some common dataset sizes (at the time of the authors' last download)
+        # CC3M (train): 2905954
+        # CC12M: 10968539
+        # LAION-400M: 407332084
+        # LAION-2B (english): 2170337258
+    num_shards = len(shards_list)
+    return total_size, num_shards
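+# Example (illustrative shard pattern): get_dataset_size("data/cc3m-{000..575}.tar")
+# expands the braces to 576 shards; total_size comes from a sibling sizes.json or
+# __len__ file if present, otherwise it is None and the sample count must be given
+# via --train-num-samples.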
+
+
+def get_imagenet(args, preprocess_fns, split):
+    assert split in ["train", "val", "v2"]
+    is_train = split == "train"
+    preprocess_train, preprocess_val = preprocess_fns
+
+    if split == "v2":
+        from imagenetv2_pytorch import ImageNetV2Dataset
+        dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val)
+    else:
+        if is_train:
+            data_path = args.imagenet_train
+            preprocess_fn = preprocess_train
+        else:
+            data_path = args.imagenet_val
+            preprocess_fn = preprocess_val
+        assert data_path
+
+        dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)
+
+    if is_train:
+        idxs = np.zeros(len(dataset.targets))
+        target_array = np.array(dataset.targets)
+        k = 50
+        for c in range(1000):
+            m = target_array == c
+            n = len(idxs[m])
+            arr = np.zeros(n)
+            arr[:k] = 1
+            np.random.shuffle(arr)
+            idxs[m] = arr
+
+        idxs = idxs.astype('int')
+        sampler = SubsetRandomSampler(np.where(idxs)[0])
+    else:
+        sampler = None
+
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=args.batch_size,
+        num_workers=args.workers,
+        sampler=sampler,
+    )
+
+    return DataInfo(dataloader=dataloader, sampler=sampler)
+
+
+def count_samples(dataloader):
+    os.environ["WDS_EPOCH"] = "0"
+    n_elements, n_batches = 0, 0
+    for images, texts in dataloader:
+        n_batches += 1
+        n_elements += len(images)
+        assert len(images) == len(texts)
+    return n_elements, n_batches
+
+
+def filter_no_caption(sample):
+    return 'txt' in sample
+
+
+def log_and_continue(exn):
+    """Call in an exception handler to ignore any exception, isssue a warning, and continue."""
+    logging.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
+    return True
+
+
+def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
+    """Return function over iterator that groups key, value pairs into samples.
+
+    :param keys: function that splits the key into key and extension (base_plus_ext)
+    :param lcase: convert suffixes to lower case (Default value = True)
+    """
+    current_sample = None
+    for filesample in data:
+        assert isinstance(filesample, dict)
+        fname, value = filesample["fname"], filesample["data"]
+        prefix, suffix = keys(fname)
+        if prefix is None:
+            continue
+        if lcase:
+            suffix = suffix.lower()
+        # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
+        #  this happening in the current LAION400m dataset if a tar ends with same prefix as the next
+        #  begins, rare, but can happen since prefix aren't unique across tar files in that dataset
+        if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
+            if valid_sample(current_sample):
+                yield current_sample
+            current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
+        if suffixes is None or suffix in suffixes:
+            current_sample[suffix] = value
+    if valid_sample(current_sample):
+        yield current_sample
+
+
+def tarfile_to_samples_nothrow(src, handler=log_and_continue):
+    # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
+    streams = url_opener(src, handler=handler)
+    files = tar_file_expander(streams, handler=handler)
+    samples = group_by_keys_nothrow(files, handler=handler)
+    return samples
+
+
+def pytorch_worker_seed():
+    """get dataloader worker seed from pytorch"""
+    worker_info = get_worker_info()
+    if worker_info is not None:
+        # favour the seed already created for pytorch dataloader workers if it exists
+        return worker_info.seed
+    # fallback to wds rank based seed
+    return wds.utils.pytorch_worker_seed()
+
+
+_SHARD_SHUFFLE_SIZE = 2000
+_SHARD_SHUFFLE_INITIAL = 500
+_SAMPLE_SHUFFLE_SIZE = 5000
+_SAMPLE_SHUFFLE_INITIAL = 1000
+
+
+class detshuffle2(wds.PipelineStage):
+    def __init__(
+            self,
+            bufsize=1000,
+            initial=100,
+            seed=0,
+            epoch=-1,
+    ):
+        self.bufsize = bufsize
+        self.initial = initial
+        self.seed = seed
+        self.epoch = epoch
+
+    def run(self, src):
+        if isinstance(self.epoch, SharedEpoch):
+            epoch = self.epoch.get_value()
+        else:
+            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
+            # situation as different workers may wrap at different times (or not at all).
+            self.epoch += 1
+            epoch = self.epoch
+        rng = random.Random()
+        if self.seed < 0:
+            seed = pytorch_worker_seed() + epoch
+        else:
+            seed = self.seed + epoch
+        rng.seed(seed)
+        return _shuffle(src, self.bufsize, self.initial, rng)
+
+
+class ResampledShards2(IterableDataset):
+    """An iterable dataset yielding a list of urls."""
+
+    def __init__(
+        self,
+        urls,
+        nshards=sys.maxsize,
+        worker_seed=None,
+        deterministic=False,
+        epoch=-1,
+    ):
+        """Sample shards from the shard list with replacement.
+
+        :param urls: a list of URLs as a Python list or brace notation string
+        """
+        super().__init__()
+        urls = wds.shardlists.expand_urls(urls)
+        self.urls = urls
+        assert isinstance(self.urls[0], str)
+        self.nshards = nshards
+        self.rng = random.Random()
+        self.worker_seed = pytorch_worker_seed if worker_seed is None else worker_seed
+        self.deterministic = deterministic
+        self.epoch = epoch
+
+    def __iter__(self):
+        """Return an iterator over the shards."""
+        if isinstance(self.epoch, SharedEpoch):
+            epoch = self.epoch.get_value()
+        else:
+            # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
+            # situation as different workers may wrap at different times (or not at all).
+            self.epoch += 1
+            epoch = self.epoch
+        if self.deterministic:
+            # reset seed w/ epoch if deterministic, worker seed should be deterministic due to arg.seed
+            self.rng.seed(self.worker_seed() + epoch)
+        for _ in range(self.nshards):
+            yield dict(url=self.rng.choice(self.urls))
+
+
+def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False):
+    input_shards = args.train_data if is_train else args.val_data
+    assert input_shards is not None
+    resampled = getattr(args, 'dataset_resampled', False) and is_train
+
+    num_samples, num_shards = get_dataset_size(input_shards)
+    if not num_samples:
+        if is_train:
+            num_samples = args.train_num_samples
+            if not num_samples:
+                raise RuntimeError(
+                    'Currently, number of dataset samples must be specified for training dataset. '
+                    'Please specify via `--train-num-samples` if no dataset length info present.')
+        else:
+            num_samples = args.val_num_samples or 0  # eval will just exhaust the iterator if not specified
+
+    shared_epoch = SharedEpoch(epoch=epoch)  # create a shared epoch store to sync epoch to dataloader worker proc
+    if resampled:
+        pipeline = [ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch)]
+    else:
+        pipeline = [wds.SimpleShardList(input_shards)]
+
+    # at this point we have an iterator over all the shards
+    if is_train:
+        if not resampled:
+            pipeline.extend([
+                detshuffle2(
+                    bufsize=_SHARD_SHUFFLE_SIZE,
+                    initial=_SHARD_SHUFFLE_INITIAL,
+                    seed=args.seed,
+                    epoch=shared_epoch,
+                ),
+                wds.split_by_node,
+                wds.split_by_worker,
+            ])
+        pipeline.extend([
+            # at this point, we have an iterator over the shards assigned to each worker at each node
+            tarfile_to_samples_nothrow,  # wds.tarfile_to_samples(handler=log_and_continue),
+            wds.shuffle(
+                bufsize=_SAMPLE_SHUFFLE_SIZE,
+                initial=_SAMPLE_SHUFFLE_INITIAL,
+            ),
+        ])
+    else:
+        pipeline.extend([
+            wds.split_by_worker,
+            # at this point, we have an iterator over the shards assigned to each worker
+            wds.tarfile_to_samples(handler=log_and_continue),
+        ])
+    pipeline.extend([
+        wds.select(filter_no_caption),
+        wds.decode("pilrgb", handler=log_and_continue),
+        wds.rename(image="jpg;png", text="txt"),
+        wds.map_dict(image=preprocess_img, text=preprocess_txt),
+        wds.to_tuple("image", "text"),
+        wds.batched(args.batch_size, partial=not is_train),
+    ])
+
+    dataset = wds.DataPipeline(*pipeline)
+    if is_train:
+        if not resampled:
+            assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers'
+        # roll over and repeat a few samples to get same number of full batches on each node
+        round_fn = math.floor if floor else math.ceil
+        global_batch_size = args.batch_size * args.world_size
+        num_batches = round_fn(num_samples / global_batch_size)
+        num_workers = max(1, args.workers)
+        num_worker_batches = round_fn(num_batches / num_workers)  # per dataloader worker
+        num_batches = num_worker_batches * num_workers
+        num_samples = num_batches * global_batch_size
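+        # e.g. (illustrative numbers, floor=False): num_samples=10_000, batch_size=64, world_size=4,
+        # workers=2 -> global_batch_size=256, num_batches=ceil(10000/256)=40, num_worker_batches=20,
+        # so num_batches stays 40 and num_samples is rounded up to 40 * 256 = 10_240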
+        dataset = dataset.with_epoch(num_worker_batches)  # each worker is iterating over this
+    else:
+        # last batches are partial, eval is done on single (master) node
+        num_batches = math.ceil(num_samples / args.batch_size)
+
+    dataloader = wds.WebLoader(
+        dataset,
+        batch_size=None,
+        shuffle=False,
+        num_workers=args.workers,
+        persistent_workers=True,
+    )
+
+    # FIXME not clear which approach is better, with_epoch before vs after dataloader?
+    # hoping to resolve via https://github.com/webdataset/webdataset/issues/169
+    # if is_train:
+    #     # roll over and repeat a few samples to get same number of full batches on each node
+    #     global_batch_size = args.batch_size * args.world_size
+    #     num_batches = math.ceil(num_samples / global_batch_size)
+    #     num_workers = max(1, args.workers)
+    #     num_batches = math.ceil(num_batches / num_workers) * num_workers
+    #     num_samples = num_batches * global_batch_size
+    #     dataloader = dataloader.with_epoch(num_batches)
+    # else:
+    #     # last batches are partial, eval is done on single (master) node
+    #     num_batches = math.ceil(num_samples / args.batch_size)
+
+    # add meta-data to dataloader instance for convenience
+    dataloader.num_batches = num_batches
+    dataloader.num_samples = num_samples
+
+    return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
+
+
+def get_csv_dataset(args, preprocess_fn, is_train, epoch=0):
+    input_filename = args.train_data if is_train else args.val_data
+    assert input_filename
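+    # CsvDataset is assumed to read a delimited file with one row per image-caption pair,
+    # taking the image path and caption from the columns named by --csv-img-key /
+    # --csv-caption-key and splitting rows on --csv-separator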
+    dataset = CsvDataset(
+        input_filename,
+        preprocess_fn,
+        img_key=args.csv_img_key,
+        caption_key=args.csv_caption_key,
+        sep=args.csv_separator)
+    num_samples = len(dataset)
+    sampler = DistributedSampler(dataset) if args.distributed and is_train else None
+    shuffle = is_train and sampler is None
+
+    dataloader = DataLoader(
+        dataset,
+        batch_size=args.batch_size,
+        shuffle=shuffle,
+        num_workers=args.workers,
+        pin_memory=True,
+        sampler=sampler,
+        drop_last=is_train,
+    )
+    dataloader.num_samples = num_samples
+    dataloader.num_batches = len(dataloader)
+
+    return DataInfo(dataloader, sampler)
+
+
+def get_dataset_fn(data_path, dataset_type):
+    if dataset_type == "webdataset":
+        return get_wds_dataset
+    elif dataset_type == "csv":
+        return get_csv_dataset
+    elif dataset_type == "auto":
+        ext = data_path.split('.')[-1]
+        if ext in ['csv', 'tsv']:
+            return get_csv_dataset
+        elif ext in ['tar']:
+            return get_wds_dataset
+        else:
+            raise ValueError(
+                f"Tried to figure out dataset type, but failed for extention {ext}.")
+    else:
+        raise ValueError(f"Unsupported dataset type: {dataset_type}")
+
+
+def get_data(args, preprocess_fns, epoch=0):
+    preprocess_train, preprocess_val = preprocess_fns
+    data = {}
+
+    if args.train_data:
+        data["train"] = get_dataset_fn(args.train_data, args.dataset_type)(
+            args, preprocess_train, is_train=True, epoch=epoch)
+
+    if args.val_data:
+        data["val"] = get_dataset_fn(args.val_data, args.dataset_type)(
+            args, preprocess_val, is_train=False)
+
+    if args.imagenet_val is not None:
+        data["imagenet-val"] = get_imagenet(args, preprocess_fns, "val")
+
+    if args.imagenet_v2 is not None:
+        data["imagenet-v2"] = get_imagenet(args, preprocess_fns, "v2")
+
+    return data
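+
+
+# A minimal usage sketch (illustrative only; assumes DataInfo exposes its dataloader /
+# shared_epoch fields as attributes and that SharedEpoch provides a set_value() counterpart
+# to the get_value() used above):
+#
+#     data = get_data(args, (preprocess_train, preprocess_val), epoch=start_epoch)
+#     for epoch in range(start_epoch, args.epochs):
+#         if data["train"].shared_epoch is not None:
+#             data["train"].shared_epoch.set_value(epoch)
+#         for images, texts in data["train"].dataloader:
+#             ...  # forward / backward step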
diff --git a/open_vocab_seg/data/datasets/register_ade20k_full.py b/open_vocab_seg/data/datasets/register_ade20k_full.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ba35274c8ba7f03cbe92621f944c8368794497f
--- /dev/null
+++ b/open_vocab_seg/data/datasets/register_ade20k_full.py
@@ -0,0 +1,995 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import os
+
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.data.datasets import load_sem_seg
+
+ADE20K_SEM_SEG_FULL_CATEGORIES = [
+    {"name": "wall", "id": 2978, "trainId": 0},
+    {"name": "building, edifice", "id": 312, "trainId": 1},
+    {"name": "sky", "id": 2420, "trainId": 2},
+    {"name": "tree", "id": 2855, "trainId": 3},
+    {"name": "road, route", "id": 2131, "trainId": 4},
+    {"name": "floor, flooring", "id": 976, "trainId": 5},
+    {"name": "ceiling", "id": 447, "trainId": 6},
+    {"name": "bed", "id": 165, "trainId": 7},
+    {"name": "sidewalk, pavement", "id": 2377, "trainId": 8},
+    {"name": "earth, ground", "id": 838, "trainId": 9},
+    {"name": "cabinet", "id": 350, "trainId": 10},
+    {
+        "name": "person, individual, someone, somebody, mortal, soul",
+        "id": 1831,
+        "trainId": 11,
+    },
+    {"name": "grass", "id": 1125, "trainId": 12},
+    {"name": "windowpane, window", "id": 3055, "trainId": 13},
+    {"name": "car, auto, automobile, machine, motorcar", "id": 401, "trainId": 14},
+    {"name": "mountain, mount", "id": 1610, "trainId": 15},
+    {"name": "plant, flora, plant life", "id": 1910, "trainId": 16},
+    {"name": "table", "id": 2684, "trainId": 17},
+    {"name": "chair", "id": 471, "trainId": 18},
+    {"name": "curtain, drape, drapery, mantle, pall", "id": 687, "trainId": 19},
+    {"name": "door", "id": 774, "trainId": 20},
+    {"name": "sofa, couch, lounge", "id": 2473, "trainId": 21},
+    {"name": "sea", "id": 2264, "trainId": 22},
+    {"name": "painting, picture", "id": 1735, "trainId": 23},
+    {"name": "water", "id": 2994, "trainId": 24},
+    {"name": "mirror", "id": 1564, "trainId": 25},
+    {"name": "house", "id": 1276, "trainId": 26},
+    {"name": "rug, carpet, carpeting", "id": 2178, "trainId": 27},
+    {"name": "shelf", "id": 2329, "trainId": 28},
+    {"name": "armchair", "id": 57, "trainId": 29},
+    {"name": "fence, fencing", "id": 907, "trainId": 30},
+    {"name": "field", "id": 913, "trainId": 31},
+    {"name": "lamp", "id": 1395, "trainId": 32},
+    {"name": "rock, stone", "id": 2138, "trainId": 33},
+    {"name": "seat", "id": 2272, "trainId": 34},
+    {"name": "river", "id": 2128, "trainId": 35},
+    {"name": "desk", "id": 724, "trainId": 36},
+    {"name": "bathtub, bathing tub, bath, tub", "id": 155, "trainId": 37},
+    {"name": "railing, rail", "id": 2053, "trainId": 38},
+    {"name": "signboard, sign", "id": 2380, "trainId": 39},
+    {"name": "cushion", "id": 689, "trainId": 40},
+    {"name": "path", "id": 1788, "trainId": 41},
+    {"name": "work surface", "id": 3087, "trainId": 42},
+    {"name": "stairs, steps", "id": 2530, "trainId": 43},
+    {"name": "column, pillar", "id": 581, "trainId": 44},
+    {"name": "sink", "id": 2388, "trainId": 45},
+    {"name": "wardrobe, closet, press", "id": 2985, "trainId": 46},
+    {"name": "snow", "id": 2454, "trainId": 47},
+    {"name": "refrigerator, icebox", "id": 2096, "trainId": 48},
+    {"name": "base, pedestal, stand", "id": 137, "trainId": 49},
+    {"name": "bridge, span", "id": 294, "trainId": 50},
+    {"name": "blind, screen", "id": 212, "trainId": 51},
+    {"name": "runway", "id": 2185, "trainId": 52},
+    {"name": "cliff, drop, drop-off", "id": 524, "trainId": 53},
+    {"name": "sand", "id": 2212, "trainId": 54},
+    {"name": "fireplace, hearth, open fireplace", "id": 943, "trainId": 55},
+    {"name": "pillow", "id": 1869, "trainId": 56},
+    {"name": "screen door, screen", "id": 2251, "trainId": 57},
+    {
+        "name": "toilet, can, commode, crapper, pot, potty, stool, throne",
+        "id": 2793,
+        "trainId": 58,
+    },
+    {"name": "skyscraper", "id": 2423, "trainId": 59},
+    {"name": "grandstand, covered stand", "id": 1121, "trainId": 60},
+    {"name": "box", "id": 266, "trainId": 61},
+    {"name": "pool table, billiard table, snooker table", "id": 1948, "trainId": 62},
+    {"name": "palm, palm tree", "id": 1744, "trainId": 63},
+    {"name": "double door", "id": 783, "trainId": 64},
+    {"name": "coffee table, cocktail table", "id": 571, "trainId": 65},
+    {"name": "counter", "id": 627, "trainId": 66},
+    {"name": "countertop", "id": 629, "trainId": 67},
+    {"name": "chest of drawers, chest, bureau, dresser", "id": 491, "trainId": 68},
+    {"name": "kitchen island", "id": 1374, "trainId": 69},
+    {"name": "boat", "id": 223, "trainId": 70},
+    {"name": "waterfall, falls", "id": 3016, "trainId": 71},
+    {
+        "name": "stove, kitchen stove, range, kitchen range, cooking stove",
+        "id": 2598,
+        "trainId": 72,
+    },
+    {"name": "flower", "id": 978, "trainId": 73},
+    {"name": "bookcase", "id": 239, "trainId": 74},
+    {"name": "controls", "id": 608, "trainId": 75},
+    {"name": "book", "id": 236, "trainId": 76},
+    {"name": "stairway, staircase", "id": 2531, "trainId": 77},
+    {"name": "streetlight, street lamp", "id": 2616, "trainId": 78},
+    {
+        "name": "computer, computing machine, computing device, data processor, electronic computer, information processing system",
+        "id": 591,
+        "trainId": 79,
+    },
+    {
+        "name": "bus, autobus, coach, charabanc, double-decker, jitney, motorbus, motorcoach, omnibus, passenger vehicle",
+        "id": 327,
+        "trainId": 80,
+    },
+    {"name": "swivel chair", "id": 2679, "trainId": 81},
+    {"name": "light, light source", "id": 1451, "trainId": 82},
+    {"name": "bench", "id": 181, "trainId": 83},
+    {"name": "case, display case, showcase, vitrine", "id": 420, "trainId": 84},
+    {"name": "towel", "id": 2821, "trainId": 85},
+    {"name": "fountain", "id": 1023, "trainId": 86},
+    {"name": "embankment", "id": 855, "trainId": 87},
+    {
+        "name": "television receiver, television, television set, tv, tv set, idiot box, boob tube, telly, goggle box",
+        "id": 2733,
+        "trainId": 88,
+    },
+    {"name": "van", "id": 2928, "trainId": 89},
+    {"name": "hill", "id": 1240, "trainId": 90},
+    {"name": "awning, sunshade, sunblind", "id": 77, "trainId": 91},
+    {"name": "poster, posting, placard, notice, bill, card", "id": 1969, "trainId": 92},
+    {"name": "truck, motortruck", "id": 2880, "trainId": 93},
+    {"name": "airplane, aeroplane, plane", "id": 14, "trainId": 94},
+    {"name": "pole", "id": 1936, "trainId": 95},
+    {"name": "tower", "id": 2828, "trainId": 96},
+    {"name": "court", "id": 631, "trainId": 97},
+    {"name": "ball", "id": 103, "trainId": 98},
+    {
+        "name": "aircraft carrier, carrier, flattop, attack aircraft carrier",
+        "id": 3144,
+        "trainId": 99,
+    },
+    {"name": "buffet, counter, sideboard", "id": 308, "trainId": 100},
+    {"name": "hovel, hut, hutch, shack, shanty", "id": 1282, "trainId": 101},
+    {"name": "apparel, wearing apparel, dress, clothes", "id": 38, "trainId": 102},
+    {"name": "minibike, motorbike", "id": 1563, "trainId": 103},
+    {
+        "name": "animal, animate being, beast, brute, creature, fauna",
+        "id": 29,
+        "trainId": 104,
+    },
+    {"name": "chandelier, pendant, pendent", "id": 480, "trainId": 105},
+    {"name": "step, stair", "id": 2569, "trainId": 106},
+    {"name": "booth, cubicle, stall, kiosk", "id": 247, "trainId": 107},
+    {"name": "bicycle, bike, wheel, cycle", "id": 187, "trainId": 108},
+    {"name": "doorframe, doorcase", "id": 778, "trainId": 109},
+    {"name": "sconce", "id": 2243, "trainId": 110},
+    {"name": "pond", "id": 1941, "trainId": 111},
+    {"name": "trade name, brand name, brand, marque", "id": 2833, "trainId": 112},
+    {
+        "name": "bannister, banister, balustrade, balusters, handrail",
+        "id": 120,
+        "trainId": 113,
+    },
+    {"name": "bag", "id": 95, "trainId": 114},
+    {"name": "traffic light, traffic signal, stoplight", "id": 2836, "trainId": 115},
+    {"name": "gazebo", "id": 1087, "trainId": 116},
+    {"name": "escalator, moving staircase, moving stairway", "id": 868, "trainId": 117},
+    {"name": "land, ground, soil", "id": 1401, "trainId": 118},
+    {"name": "board, plank", "id": 220, "trainId": 119},
+    {"name": "arcade machine", "id": 47, "trainId": 120},
+    {"name": "eiderdown, duvet, continental quilt", "id": 843, "trainId": 121},
+    {"name": "bar", "id": 123, "trainId": 122},
+    {"name": "stall, stand, sales booth", "id": 2537, "trainId": 123},
+    {"name": "playground", "id": 1927, "trainId": 124},
+    {"name": "ship", "id": 2337, "trainId": 125},
+    {"name": "ottoman, pouf, pouffe, puff, hassock", "id": 1702, "trainId": 126},
+    {
+        "name": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
+        "id": 64,
+        "trainId": 127,
+    },
+    {"name": "bottle", "id": 249, "trainId": 128},
+    {"name": "cradle", "id": 642, "trainId": 129},
+    {"name": "pot, flowerpot", "id": 1981, "trainId": 130},
+    {
+        "name": "conveyer belt, conveyor belt, conveyer, conveyor, transporter",
+        "id": 609,
+        "trainId": 131,
+    },
+    {"name": "train, railroad train", "id": 2840, "trainId": 132},
+    {"name": "stool", "id": 2586, "trainId": 133},
+    {"name": "lake", "id": 1393, "trainId": 134},
+    {"name": "tank, storage tank", "id": 2704, "trainId": 135},
+    {"name": "ice, water ice", "id": 1304, "trainId": 136},
+    {"name": "basket, handbasket", "id": 146, "trainId": 137},
+    {"name": "manhole", "id": 1494, "trainId": 138},
+    {"name": "tent, collapsible shelter", "id": 2739, "trainId": 139},
+    {"name": "canopy", "id": 389, "trainId": 140},
+    {"name": "microwave, microwave oven", "id": 1551, "trainId": 141},
+    {"name": "barrel, cask", "id": 131, "trainId": 142},
+    {"name": "dirt track", "id": 738, "trainId": 143},
+    {"name": "beam", "id": 161, "trainId": 144},
+    {"name": "dishwasher, dish washer, dishwashing machine", "id": 747, "trainId": 145},
+    {"name": "plate", "id": 1919, "trainId": 146},
+    {"name": "screen, crt screen", "id": 3109, "trainId": 147},
+    {"name": "ruins", "id": 2179, "trainId": 148},
+    {"name": "washer, automatic washer, washing machine", "id": 2989, "trainId": 149},
+    {"name": "blanket, cover", "id": 206, "trainId": 150},
+    {"name": "plaything, toy", "id": 1930, "trainId": 151},
+    {"name": "food, solid food", "id": 1002, "trainId": 152},
+    {"name": "screen, silver screen, projection screen", "id": 2254, "trainId": 153},
+    {"name": "oven", "id": 1708, "trainId": 154},
+    {"name": "stage", "id": 2526, "trainId": 155},
+    {"name": "beacon, lighthouse, beacon light, pharos", "id": 160, "trainId": 156},
+    {"name": "umbrella", "id": 2901, "trainId": 157},
+    {"name": "sculpture", "id": 2262, "trainId": 158},
+    {"name": "aqueduct", "id": 44, "trainId": 159},
+    {"name": "container", "id": 597, "trainId": 160},
+    {"name": "scaffolding, staging", "id": 2235, "trainId": 161},
+    {"name": "hood, exhaust hood", "id": 1260, "trainId": 162},
+    {"name": "curb, curbing, kerb", "id": 682, "trainId": 163},
+    {"name": "roller coaster", "id": 2151, "trainId": 164},
+    {"name": "horse, equus caballus", "id": 3107, "trainId": 165},
+    {"name": "catwalk", "id": 432, "trainId": 166},
+    {"name": "glass, drinking glass", "id": 1098, "trainId": 167},
+    {"name": "vase", "id": 2932, "trainId": 168},
+    {"name": "central reservation", "id": 461, "trainId": 169},
+    {"name": "carousel", "id": 410, "trainId": 170},
+    {"name": "radiator", "id": 2046, "trainId": 171},
+    {"name": "closet", "id": 533, "trainId": 172},
+    {"name": "machine", "id": 1481, "trainId": 173},
+    {"name": "pier, wharf, wharfage, dock", "id": 1858, "trainId": 174},
+    {"name": "fan", "id": 894, "trainId": 175},
+    {"name": "inflatable bounce game", "id": 1322, "trainId": 176},
+    {"name": "pitch", "id": 1891, "trainId": 177},
+    {"name": "paper", "id": 1756, "trainId": 178},
+    {"name": "arcade, colonnade", "id": 49, "trainId": 179},
+    {"name": "hot tub", "id": 1272, "trainId": 180},
+    {"name": "helicopter", "id": 1229, "trainId": 181},
+    {"name": "tray", "id": 2850, "trainId": 182},
+    {"name": "partition, divider", "id": 1784, "trainId": 183},
+    {"name": "vineyard", "id": 2962, "trainId": 184},
+    {"name": "bowl", "id": 259, "trainId": 185},
+    {"name": "bullring", "id": 319, "trainId": 186},
+    {"name": "flag", "id": 954, "trainId": 187},
+    {"name": "pot", "id": 1974, "trainId": 188},
+    {"name": "footbridge, overcrossing, pedestrian bridge", "id": 1013, "trainId": 189},
+    {"name": "shower", "id": 2356, "trainId": 190},
+    {
+        "name": "bag, traveling bag, travelling bag, grip, suitcase",
+        "id": 97,
+        "trainId": 191,
+    },
+    {"name": "bulletin board, notice board", "id": 318, "trainId": 192},
+    {"name": "confessional booth", "id": 592, "trainId": 193},
+    {"name": "trunk, tree trunk, bole", "id": 2885, "trainId": 194},
+    {"name": "forest", "id": 1017, "trainId": 195},
+    {"name": "elevator door", "id": 851, "trainId": 196},
+    {"name": "laptop, laptop computer", "id": 1407, "trainId": 197},
+    {"name": "instrument panel", "id": 1332, "trainId": 198},
+    {"name": "bucket, pail", "id": 303, "trainId": 199},
+    {"name": "tapestry, tapis", "id": 2714, "trainId": 200},
+    {"name": "platform", "id": 1924, "trainId": 201},
+    {"name": "jacket", "id": 1346, "trainId": 202},
+    {"name": "gate", "id": 1081, "trainId": 203},
+    {"name": "monitor, monitoring device", "id": 1583, "trainId": 204},
+    {
+        "name": "telephone booth, phone booth, call box, telephone box, telephone kiosk",
+        "id": 2727,
+        "trainId": 205,
+    },
+    {"name": "spotlight, spot", "id": 2509, "trainId": 206},
+    {"name": "ring", "id": 2123, "trainId": 207},
+    {"name": "control panel", "id": 602, "trainId": 208},
+    {"name": "blackboard, chalkboard", "id": 202, "trainId": 209},
+    {"name": "air conditioner, air conditioning", "id": 10, "trainId": 210},
+    {"name": "chest", "id": 490, "trainId": 211},
+    {"name": "clock", "id": 530, "trainId": 212},
+    {"name": "sand dune", "id": 2213, "trainId": 213},
+    {"name": "pipe, pipage, piping", "id": 1884, "trainId": 214},
+    {"name": "vault", "id": 2934, "trainId": 215},
+    {"name": "table football", "id": 2687, "trainId": 216},
+    {"name": "cannon", "id": 387, "trainId": 217},
+    {"name": "swimming pool, swimming bath, natatorium", "id": 2668, "trainId": 218},
+    {"name": "fluorescent, fluorescent fixture", "id": 982, "trainId": 219},
+    {"name": "statue", "id": 2547, "trainId": 220},
+    {
+        "name": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
+        "id": 1474,
+        "trainId": 221,
+    },
+    {"name": "exhibitor", "id": 877, "trainId": 222},
+    {"name": "ladder", "id": 1391, "trainId": 223},
+    {"name": "carport", "id": 414, "trainId": 224},
+    {"name": "dam", "id": 698, "trainId": 225},
+    {"name": "pulpit", "id": 2019, "trainId": 226},
+    {"name": "skylight, fanlight", "id": 2422, "trainId": 227},
+    {"name": "water tower", "id": 3010, "trainId": 228},
+    {"name": "grill, grille, grillwork", "id": 1139, "trainId": 229},
+    {"name": "display board", "id": 753, "trainId": 230},
+    {"name": "pane, pane of glass, window glass", "id": 1747, "trainId": 231},
+    {"name": "rubbish, trash, scrap", "id": 2175, "trainId": 232},
+    {"name": "ice rink", "id": 1301, "trainId": 233},
+    {"name": "fruit", "id": 1033, "trainId": 234},
+    {"name": "patio", "id": 1789, "trainId": 235},
+    {"name": "vending machine", "id": 2939, "trainId": 236},
+    {"name": "telephone, phone, telephone set", "id": 2730, "trainId": 237},
+    {"name": "net", "id": 1652, "trainId": 238},
+    {
+        "name": "backpack, back pack, knapsack, packsack, rucksack, haversack",
+        "id": 90,
+        "trainId": 239,
+    },
+    {"name": "jar", "id": 1349, "trainId": 240},
+    {"name": "track", "id": 2830, "trainId": 241},
+    {"name": "magazine", "id": 1485, "trainId": 242},
+    {"name": "shutter", "id": 2370, "trainId": 243},
+    {"name": "roof", "id": 2155, "trainId": 244},
+    {"name": "banner, streamer", "id": 118, "trainId": 245},
+    {"name": "landfill", "id": 1402, "trainId": 246},
+    {"name": "post", "id": 1957, "trainId": 247},
+    {"name": "altarpiece, reredos", "id": 3130, "trainId": 248},
+    {"name": "hat, chapeau, lid", "id": 1197, "trainId": 249},
+    {"name": "arch, archway", "id": 52, "trainId": 250},
+    {"name": "table game", "id": 2688, "trainId": 251},
+    {"name": "bag, handbag, pocketbook, purse", "id": 96, "trainId": 252},
+    {"name": "document, written document, papers", "id": 762, "trainId": 253},
+    {"name": "dome", "id": 772, "trainId": 254},
+    {"name": "pier", "id": 1857, "trainId": 255},
+    {"name": "shanties", "id": 2315, "trainId": 256},
+    {"name": "forecourt", "id": 1016, "trainId": 257},
+    {"name": "crane", "id": 643, "trainId": 258},
+    {"name": "dog, domestic dog, canis familiaris", "id": 3105, "trainId": 259},
+    {"name": "piano, pianoforte, forte-piano", "id": 1849, "trainId": 260},
+    {"name": "drawing", "id": 791, "trainId": 261},
+    {"name": "cabin", "id": 349, "trainId": 262},
+    {
+        "name": "ad, advertisement, advertizement, advertising, advertizing, advert",
+        "id": 6,
+        "trainId": 263,
+    },
+    {"name": "amphitheater, amphitheatre, coliseum", "id": 3114, "trainId": 264},
+    {"name": "monument", "id": 1587, "trainId": 265},
+    {"name": "henhouse", "id": 1233, "trainId": 266},
+    {"name": "cockpit", "id": 559, "trainId": 267},
+    {"name": "heater, warmer", "id": 1223, "trainId": 268},
+    {"name": "windmill, aerogenerator, wind generator", "id": 3049, "trainId": 269},
+    {"name": "pool", "id": 1943, "trainId": 270},
+    {"name": "elevator, lift", "id": 853, "trainId": 271},
+    {"name": "decoration, ornament, ornamentation", "id": 709, "trainId": 272},
+    {"name": "labyrinth", "id": 1390, "trainId": 273},
+    {"name": "text, textual matter", "id": 2748, "trainId": 274},
+    {"name": "printer", "id": 2007, "trainId": 275},
+    {"name": "mezzanine, first balcony", "id": 1546, "trainId": 276},
+    {"name": "mattress", "id": 1513, "trainId": 277},
+    {"name": "straw", "id": 2600, "trainId": 278},
+    {"name": "stalls", "id": 2538, "trainId": 279},
+    {"name": "patio, terrace", "id": 1790, "trainId": 280},
+    {"name": "billboard, hoarding", "id": 194, "trainId": 281},
+    {"name": "bus stop", "id": 326, "trainId": 282},
+    {"name": "trouser, pant", "id": 2877, "trainId": 283},
+    {"name": "console table, console", "id": 594, "trainId": 284},
+    {"name": "rack", "id": 2036, "trainId": 285},
+    {"name": "notebook", "id": 1662, "trainId": 286},
+    {"name": "shrine", "id": 2366, "trainId": 287},
+    {"name": "pantry", "id": 1754, "trainId": 288},
+    {"name": "cart", "id": 418, "trainId": 289},
+    {"name": "steam shovel", "id": 2553, "trainId": 290},
+    {"name": "porch", "id": 1951, "trainId": 291},
+    {"name": "postbox, mailbox, letter box", "id": 1963, "trainId": 292},
+    {"name": "figurine, statuette", "id": 918, "trainId": 293},
+    {"name": "recycling bin", "id": 2086, "trainId": 294},
+    {"name": "folding screen", "id": 997, "trainId": 295},
+    {"name": "telescope", "id": 2731, "trainId": 296},
+    {"name": "deck chair, beach chair", "id": 704, "trainId": 297},
+    {"name": "kennel", "id": 1365, "trainId": 298},
+    {"name": "coffee maker", "id": 569, "trainId": 299},
+    {"name": "altar, communion table, lord's table", "id": 3108, "trainId": 300},
+    {"name": "fish", "id": 948, "trainId": 301},
+    {"name": "easel", "id": 839, "trainId": 302},
+    {"name": "artificial golf green", "id": 63, "trainId": 303},
+    {"name": "iceberg", "id": 1305, "trainId": 304},
+    {"name": "candlestick, candle holder", "id": 378, "trainId": 305},
+    {"name": "shower stall, shower bath", "id": 2362, "trainId": 306},
+    {"name": "television stand", "id": 2734, "trainId": 307},
+    {
+        "name": "wall socket, wall plug, electric outlet, electrical outlet, outlet, electric receptacle",
+        "id": 2982,
+        "trainId": 308,
+    },
+    {"name": "skeleton", "id": 2398, "trainId": 309},
+    {"name": "grand piano, grand", "id": 1119, "trainId": 310},
+    {"name": "candy, confect", "id": 382, "trainId": 311},
+    {"name": "grille door", "id": 1141, "trainId": 312},
+    {"name": "pedestal, plinth, footstall", "id": 1805, "trainId": 313},
+    {"name": "jersey, t-shirt, tee shirt", "id": 3102, "trainId": 314},
+    {"name": "shoe", "id": 2341, "trainId": 315},
+    {"name": "gravestone, headstone, tombstone", "id": 1131, "trainId": 316},
+    {"name": "shanty", "id": 2316, "trainId": 317},
+    {"name": "structure", "id": 2626, "trainId": 318},
+    {"name": "rocking chair, rocker", "id": 3104, "trainId": 319},
+    {"name": "bird", "id": 198, "trainId": 320},
+    {"name": "place mat", "id": 1896, "trainId": 321},
+    {"name": "tomb", "id": 2800, "trainId": 322},
+    {"name": "big top", "id": 190, "trainId": 323},
+    {
+        "name": "gas pump, gasoline pump, petrol pump, island dispenser",
+        "id": 3131,
+        "trainId": 324,
+    },
+    {"name": "lockers", "id": 1463, "trainId": 325},
+    {"name": "cage", "id": 357, "trainId": 326},
+    {"name": "finger", "id": 929, "trainId": 327},
+    {"name": "bleachers", "id": 209, "trainId": 328},
+    {"name": "ferris wheel", "id": 912, "trainId": 329},
+    {"name": "hairdresser chair", "id": 1164, "trainId": 330},
+    {"name": "mat", "id": 1509, "trainId": 331},
+    {"name": "stands", "id": 2539, "trainId": 332},
+    {"name": "aquarium, fish tank, marine museum", "id": 3116, "trainId": 333},
+    {
+        "name": "streetcar, tram, tramcar, trolley, trolley car",
+        "id": 2615,
+        "trainId": 334,
+    },
+    {"name": "napkin, table napkin, serviette", "id": 1644, "trainId": 335},
+    {"name": "dummy", "id": 818, "trainId": 336},
+    {"name": "booklet, brochure, folder, leaflet, pamphlet", "id": 242, "trainId": 337},
+    {"name": "sand trap", "id": 2217, "trainId": 338},
+    {"name": "shop, store", "id": 2347, "trainId": 339},
+    {"name": "table cloth", "id": 2686, "trainId": 340},
+    {"name": "service station", "id": 2300, "trainId": 341},
+    {"name": "coffin", "id": 572, "trainId": 342},
+    {"name": "drawer", "id": 789, "trainId": 343},
+    {"name": "cages", "id": 358, "trainId": 344},
+    {"name": "slot machine, coin machine", "id": 2443, "trainId": 345},
+    {"name": "balcony", "id": 101, "trainId": 346},
+    {"name": "volleyball court", "id": 2969, "trainId": 347},
+    {"name": "table tennis", "id": 2692, "trainId": 348},
+    {"name": "control table", "id": 606, "trainId": 349},
+    {"name": "shirt", "id": 2339, "trainId": 350},
+    {"name": "merchandise, ware, product", "id": 1533, "trainId": 351},
+    {"name": "railway", "id": 2060, "trainId": 352},
+    {"name": "parterre", "id": 1782, "trainId": 353},
+    {"name": "chimney", "id": 495, "trainId": 354},
+    {"name": "can, tin, tin can", "id": 371, "trainId": 355},
+    {"name": "tanks", "id": 2707, "trainId": 356},
+    {"name": "fabric, cloth, material, textile", "id": 889, "trainId": 357},
+    {"name": "alga, algae", "id": 3156, "trainId": 358},
+    {"name": "system", "id": 2683, "trainId": 359},
+    {"name": "map", "id": 1499, "trainId": 360},
+    {"name": "greenhouse", "id": 1135, "trainId": 361},
+    {"name": "mug", "id": 1619, "trainId": 362},
+    {"name": "barbecue", "id": 125, "trainId": 363},
+    {"name": "trailer", "id": 2838, "trainId": 364},
+    {
+        "name": "toilet tissue, toilet paper, bathroom tissue",
+        "id": 2792,
+        "trainId": 365,
+    },
+    {"name": "organ", "id": 1695, "trainId": 366},
+    {"name": "dishrag, dishcloth", "id": 746, "trainId": 367},
+    {"name": "island", "id": 1343, "trainId": 368},
+    {"name": "keyboard", "id": 1370, "trainId": 369},
+    {"name": "trench", "id": 2858, "trainId": 370},
+    {"name": "basket, basketball hoop, hoop", "id": 145, "trainId": 371},
+    {"name": "steering wheel, wheel", "id": 2565, "trainId": 372},
+    {"name": "pitcher, ewer", "id": 1892, "trainId": 373},
+    {"name": "goal", "id": 1103, "trainId": 374},
+    {"name": "bread, breadstuff, staff of life", "id": 286, "trainId": 375},
+    {"name": "beds", "id": 170, "trainId": 376},
+    {"name": "wood", "id": 3073, "trainId": 377},
+    {"name": "file cabinet", "id": 922, "trainId": 378},
+    {"name": "newspaper, paper", "id": 1655, "trainId": 379},
+    {"name": "motorboat", "id": 1602, "trainId": 380},
+    {"name": "rope", "id": 2160, "trainId": 381},
+    {"name": "guitar", "id": 1151, "trainId": 382},
+    {"name": "rubble", "id": 2176, "trainId": 383},
+    {"name": "scarf", "id": 2239, "trainId": 384},
+    {"name": "barrels", "id": 132, "trainId": 385},
+    {"name": "cap", "id": 394, "trainId": 386},
+    {"name": "leaves", "id": 1424, "trainId": 387},
+    {"name": "control tower", "id": 607, "trainId": 388},
+    {"name": "dashboard", "id": 700, "trainId": 389},
+    {"name": "bandstand", "id": 116, "trainId": 390},
+    {"name": "lectern", "id": 1425, "trainId": 391},
+    {"name": "switch, electric switch, electrical switch", "id": 2676, "trainId": 392},
+    {"name": "baseboard, mopboard, skirting board", "id": 141, "trainId": 393},
+    {"name": "shower room", "id": 2360, "trainId": 394},
+    {"name": "smoke", "id": 2449, "trainId": 395},
+    {"name": "faucet, spigot", "id": 897, "trainId": 396},
+    {"name": "bulldozer", "id": 317, "trainId": 397},
+    {"name": "saucepan", "id": 2228, "trainId": 398},
+    {"name": "shops", "id": 2351, "trainId": 399},
+    {"name": "meter", "id": 1543, "trainId": 400},
+    {"name": "crevasse", "id": 656, "trainId": 401},
+    {"name": "gear", "id": 1088, "trainId": 402},
+    {"name": "candelabrum, candelabra", "id": 373, "trainId": 403},
+    {"name": "sofa bed", "id": 2472, "trainId": 404},
+    {"name": "tunnel", "id": 2892, "trainId": 405},
+    {"name": "pallet", "id": 1740, "trainId": 406},
+    {"name": "wire, conducting wire", "id": 3067, "trainId": 407},
+    {"name": "kettle, boiler", "id": 1367, "trainId": 408},
+    {"name": "bidet", "id": 188, "trainId": 409},
+    {
+        "name": "baby buggy, baby carriage, carriage, perambulator, pram, stroller, go-cart, pushchair, pusher",
+        "id": 79,
+        "trainId": 410,
+    },
+    {"name": "music stand", "id": 1633, "trainId": 411},
+    {"name": "pipe, tube", "id": 1885, "trainId": 412},
+    {"name": "cup", "id": 677, "trainId": 413},
+    {"name": "parking meter", "id": 1779, "trainId": 414},
+    {"name": "ice hockey rink", "id": 1297, "trainId": 415},
+    {"name": "shelter", "id": 2334, "trainId": 416},
+    {"name": "weeds", "id": 3027, "trainId": 417},
+    {"name": "temple", "id": 2735, "trainId": 418},
+    {"name": "patty, cake", "id": 1791, "trainId": 419},
+    {"name": "ski slope", "id": 2405, "trainId": 420},
+    {"name": "panel", "id": 1748, "trainId": 421},
+    {"name": "wallet", "id": 2983, "trainId": 422},
+    {"name": "wheel", "id": 3035, "trainId": 423},
+    {"name": "towel rack, towel horse", "id": 2824, "trainId": 424},
+    {"name": "roundabout", "id": 2168, "trainId": 425},
+    {"name": "canister, cannister, tin", "id": 385, "trainId": 426},
+    {"name": "rod", "id": 2148, "trainId": 427},
+    {"name": "soap dispenser", "id": 2465, "trainId": 428},
+    {"name": "bell", "id": 175, "trainId": 429},
+    {"name": "canvas", "id": 390, "trainId": 430},
+    {"name": "box office, ticket office, ticket booth", "id": 268, "trainId": 431},
+    {"name": "teacup", "id": 2722, "trainId": 432},
+    {"name": "trellis", "id": 2857, "trainId": 433},
+    {"name": "workbench", "id": 3088, "trainId": 434},
+    {"name": "valley, vale", "id": 2926, "trainId": 435},
+    {"name": "toaster", "id": 2782, "trainId": 436},
+    {"name": "knife", "id": 1378, "trainId": 437},
+    {"name": "podium", "id": 1934, "trainId": 438},
+    {"name": "ramp", "id": 2072, "trainId": 439},
+    {"name": "tumble dryer", "id": 2889, "trainId": 440},
+    {"name": "fireplug, fire hydrant, plug", "id": 944, "trainId": 441},
+    {"name": "gym shoe, sneaker, tennis shoe", "id": 1158, "trainId": 442},
+    {"name": "lab bench", "id": 1383, "trainId": 443},
+    {"name": "equipment", "id": 867, "trainId": 444},
+    {"name": "rocky formation", "id": 2145, "trainId": 445},
+    {"name": "plastic", "id": 1915, "trainId": 446},
+    {"name": "calendar", "id": 361, "trainId": 447},
+    {"name": "caravan", "id": 402, "trainId": 448},
+    {"name": "check-in-desk", "id": 482, "trainId": 449},
+    {"name": "ticket counter", "id": 2761, "trainId": 450},
+    {"name": "brush", "id": 300, "trainId": 451},
+    {"name": "mill", "id": 1554, "trainId": 452},
+    {"name": "covered bridge", "id": 636, "trainId": 453},
+    {"name": "bowling alley", "id": 260, "trainId": 454},
+    {"name": "hanger", "id": 1186, "trainId": 455},
+    {"name": "excavator", "id": 871, "trainId": 456},
+    {"name": "trestle", "id": 2859, "trainId": 457},
+    {"name": "revolving door", "id": 2103, "trainId": 458},
+    {"name": "blast furnace", "id": 208, "trainId": 459},
+    {"name": "scale, weighing machine", "id": 2236, "trainId": 460},
+    {"name": "projector", "id": 2012, "trainId": 461},
+    {"name": "soap", "id": 2462, "trainId": 462},
+    {"name": "locker", "id": 1462, "trainId": 463},
+    {"name": "tractor", "id": 2832, "trainId": 464},
+    {"name": "stretcher", "id": 2617, "trainId": 465},
+    {"name": "frame", "id": 1024, "trainId": 466},
+    {"name": "grating", "id": 1129, "trainId": 467},
+    {"name": "alembic", "id": 18, "trainId": 468},
+    {"name": "candle, taper, wax light", "id": 376, "trainId": 469},
+    {"name": "barrier", "id": 134, "trainId": 470},
+    {"name": "cardboard", "id": 407, "trainId": 471},
+    {"name": "cave", "id": 434, "trainId": 472},
+    {"name": "puddle", "id": 2017, "trainId": 473},
+    {"name": "tarp", "id": 2717, "trainId": 474},
+    {"name": "price tag", "id": 2005, "trainId": 475},
+    {"name": "watchtower", "id": 2993, "trainId": 476},
+    {"name": "meters", "id": 1545, "trainId": 477},
+    {
+        "name": "light bulb, lightbulb, bulb, incandescent lamp, electric light, electric-light bulb",
+        "id": 1445,
+        "trainId": 478,
+    },
+    {"name": "tracks", "id": 2831, "trainId": 479},
+    {"name": "hair dryer", "id": 1161, "trainId": 480},
+    {"name": "skirt", "id": 2411, "trainId": 481},
+    {"name": "viaduct", "id": 2949, "trainId": 482},
+    {"name": "paper towel", "id": 1769, "trainId": 483},
+    {"name": "coat", "id": 552, "trainId": 484},
+    {"name": "sheet", "id": 2327, "trainId": 485},
+    {"name": "fire extinguisher, extinguisher, asphyxiator", "id": 939, "trainId": 486},
+    {"name": "water wheel", "id": 3013, "trainId": 487},
+    {"name": "pottery, clayware", "id": 1986, "trainId": 488},
+    {"name": "magazine rack", "id": 1486, "trainId": 489},
+    {"name": "teapot", "id": 2723, "trainId": 490},
+    {"name": "microphone, mike", "id": 1549, "trainId": 491},
+    {"name": "support", "id": 2649, "trainId": 492},
+    {"name": "forklift", "id": 1020, "trainId": 493},
+    {"name": "canyon", "id": 392, "trainId": 494},
+    {"name": "cash register, register", "id": 422, "trainId": 495},
+    {"name": "leaf, leafage, foliage", "id": 1419, "trainId": 496},
+    {"name": "remote control, remote", "id": 2099, "trainId": 497},
+    {"name": "soap dish", "id": 2464, "trainId": 498},
+    {"name": "windshield, windscreen", "id": 3058, "trainId": 499},
+    {"name": "cat", "id": 430, "trainId": 500},
+    {"name": "cue, cue stick, pool cue, pool stick", "id": 675, "trainId": 501},
+    {"name": "vent, venthole, vent-hole, blowhole", "id": 2941, "trainId": 502},
+    {"name": "videos", "id": 2955, "trainId": 503},
+    {"name": "shovel", "id": 2355, "trainId": 504},
+    {"name": "eaves", "id": 840, "trainId": 505},
+    {"name": "antenna, aerial, transmitting aerial", "id": 32, "trainId": 506},
+    {"name": "shipyard", "id": 2338, "trainId": 507},
+    {"name": "hen, biddy", "id": 1232, "trainId": 508},
+    {"name": "traffic cone", "id": 2834, "trainId": 509},
+    {"name": "washing machines", "id": 2991, "trainId": 510},
+    {"name": "truck crane", "id": 2879, "trainId": 511},
+    {"name": "cds", "id": 444, "trainId": 512},
+    {"name": "niche", "id": 1657, "trainId": 513},
+    {"name": "scoreboard", "id": 2246, "trainId": 514},
+    {"name": "briefcase", "id": 296, "trainId": 515},
+    {"name": "boot", "id": 245, "trainId": 516},
+    {"name": "sweater, jumper", "id": 2661, "trainId": 517},
+    {"name": "hay", "id": 1202, "trainId": 518},
+    {"name": "pack", "id": 1714, "trainId": 519},
+    {"name": "bottle rack", "id": 251, "trainId": 520},
+    {"name": "glacier", "id": 1095, "trainId": 521},
+    {"name": "pergola", "id": 1828, "trainId": 522},
+    {"name": "building materials", "id": 311, "trainId": 523},
+    {"name": "television camera", "id": 2732, "trainId": 524},
+    {"name": "first floor", "id": 947, "trainId": 525},
+    {"name": "rifle", "id": 2115, "trainId": 526},
+    {"name": "tennis table", "id": 2738, "trainId": 527},
+    {"name": "stadium", "id": 2525, "trainId": 528},
+    {"name": "safety belt", "id": 2194, "trainId": 529},
+    {"name": "cover", "id": 634, "trainId": 530},
+    {"name": "dish rack", "id": 740, "trainId": 531},
+    {"name": "synthesizer", "id": 2682, "trainId": 532},
+    {"name": "pumpkin", "id": 2020, "trainId": 533},
+    {"name": "gutter", "id": 1156, "trainId": 534},
+    {"name": "fruit stand", "id": 1036, "trainId": 535},
+    {"name": "ice floe, floe", "id": 1295, "trainId": 536},
+    {"name": "handle, grip, handgrip, hold", "id": 1181, "trainId": 537},
+    {"name": "wheelchair", "id": 3037, "trainId": 538},
+    {"name": "mousepad, mouse mat", "id": 1614, "trainId": 539},
+    {"name": "diploma", "id": 736, "trainId": 540},
+    {"name": "fairground ride", "id": 893, "trainId": 541},
+    {"name": "radio", "id": 2047, "trainId": 542},
+    {"name": "hotplate", "id": 1274, "trainId": 543},
+    {"name": "junk", "id": 1361, "trainId": 544},
+    {"name": "wheelbarrow", "id": 3036, "trainId": 545},
+    {"name": "stream", "id": 2606, "trainId": 546},
+    {"name": "toll plaza", "id": 2797, "trainId": 547},
+    {"name": "punching bag", "id": 2022, "trainId": 548},
+    {"name": "trough", "id": 2876, "trainId": 549},
+    {"name": "throne", "id": 2758, "trainId": 550},
+    {"name": "chair desk", "id": 472, "trainId": 551},
+    {"name": "weighbridge", "id": 3028, "trainId": 552},
+    {"name": "extractor fan", "id": 882, "trainId": 553},
+    {"name": "hanging clothes", "id": 1189, "trainId": 554},
+    {"name": "dish, dish aerial, dish antenna, saucer", "id": 743, "trainId": 555},
+    {"name": "alarm clock, alarm", "id": 3122, "trainId": 556},
+    {"name": "ski lift", "id": 2401, "trainId": 557},
+    {"name": "chain", "id": 468, "trainId": 558},
+    {"name": "garage", "id": 1061, "trainId": 559},
+    {"name": "mechanical shovel", "id": 1523, "trainId": 560},
+    {"name": "wine rack", "id": 3059, "trainId": 561},
+    {"name": "tramway", "id": 2843, "trainId": 562},
+    {"name": "treadmill", "id": 2853, "trainId": 563},
+    {"name": "menu", "id": 1529, "trainId": 564},
+    {"name": "block", "id": 214, "trainId": 565},
+    {"name": "well", "id": 3032, "trainId": 566},
+    {"name": "witness stand", "id": 3071, "trainId": 567},
+    {"name": "branch", "id": 277, "trainId": 568},
+    {"name": "duck", "id": 813, "trainId": 569},
+    {"name": "casserole", "id": 426, "trainId": 570},
+    {"name": "frying pan", "id": 1039, "trainId": 571},
+    {"name": "desk organizer", "id": 727, "trainId": 572},
+    {"name": "mast", "id": 1508, "trainId": 573},
+    {"name": "spectacles, specs, eyeglasses, glasses", "id": 2490, "trainId": 574},
+    {"name": "service elevator", "id": 2299, "trainId": 575},
+    {"name": "dollhouse", "id": 768, "trainId": 576},
+    {"name": "hammock", "id": 1172, "trainId": 577},
+    {"name": "clothes hanging", "id": 537, "trainId": 578},
+    {"name": "photocopier", "id": 1847, "trainId": 579},
+    {"name": "notepad", "id": 1664, "trainId": 580},
+    {"name": "golf cart", "id": 1110, "trainId": 581},
+    {"name": "footpath", "id": 1014, "trainId": 582},
+    {"name": "cross", "id": 662, "trainId": 583},
+    {"name": "baptismal font", "id": 121, "trainId": 584},
+    {"name": "boiler", "id": 227, "trainId": 585},
+    {"name": "skip", "id": 2410, "trainId": 586},
+    {"name": "rotisserie", "id": 2165, "trainId": 587},
+    {"name": "tables", "id": 2696, "trainId": 588},
+    {"name": "water mill", "id": 3005, "trainId": 589},
+    {"name": "helmet", "id": 1231, "trainId": 590},
+    {"name": "cover curtain", "id": 635, "trainId": 591},
+    {"name": "brick", "id": 292, "trainId": 592},
+    {"name": "table runner", "id": 2690, "trainId": 593},
+    {"name": "ashtray", "id": 65, "trainId": 594},
+    {"name": "street box", "id": 2607, "trainId": 595},
+    {"name": "stick", "id": 2574, "trainId": 596},
+    {"name": "hangers", "id": 1188, "trainId": 597},
+    {"name": "cells", "id": 456, "trainId": 598},
+    {"name": "urinal", "id": 2913, "trainId": 599},
+    {"name": "centerpiece", "id": 459, "trainId": 600},
+    {"name": "portable fridge", "id": 1955, "trainId": 601},
+    {"name": "dvds", "id": 827, "trainId": 602},
+    {"name": "golf club", "id": 1111, "trainId": 603},
+    {"name": "skirting board", "id": 2412, "trainId": 604},
+    {"name": "water cooler", "id": 2997, "trainId": 605},
+    {"name": "clipboard", "id": 528, "trainId": 606},
+    {"name": "camera, photographic camera", "id": 366, "trainId": 607},
+    {"name": "pigeonhole", "id": 1863, "trainId": 608},
+    {"name": "chips", "id": 500, "trainId": 609},
+    {"name": "food processor", "id": 1001, "trainId": 610},
+    {"name": "post box", "id": 1958, "trainId": 611},
+    {"name": "lid", "id": 1441, "trainId": 612},
+    {"name": "drum", "id": 809, "trainId": 613},
+    {"name": "blender", "id": 210, "trainId": 614},
+    {"name": "cave entrance", "id": 435, "trainId": 615},
+    {"name": "dental chair", "id": 718, "trainId": 616},
+    {"name": "obelisk", "id": 1674, "trainId": 617},
+    {"name": "canoe", "id": 388, "trainId": 618},
+    {"name": "mobile", "id": 1572, "trainId": 619},
+    {"name": "monitors", "id": 1584, "trainId": 620},
+    {"name": "pool ball", "id": 1944, "trainId": 621},
+    {"name": "cue rack", "id": 674, "trainId": 622},
+    {"name": "baggage carts", "id": 99, "trainId": 623},
+    {"name": "shore", "id": 2352, "trainId": 624},
+    {"name": "fork", "id": 1019, "trainId": 625},
+    {"name": "paper filer", "id": 1763, "trainId": 626},
+    {"name": "bicycle rack", "id": 185, "trainId": 627},
+    {"name": "coat rack", "id": 554, "trainId": 628},
+    {"name": "garland", "id": 1066, "trainId": 629},
+    {"name": "sports bag", "id": 2508, "trainId": 630},
+    {"name": "fish tank", "id": 951, "trainId": 631},
+    {"name": "towel dispenser", "id": 2822, "trainId": 632},
+    {"name": "carriage", "id": 415, "trainId": 633},
+    {"name": "brochure", "id": 297, "trainId": 634},
+    {"name": "plaque", "id": 1914, "trainId": 635},
+    {"name": "stringer", "id": 2619, "trainId": 636},
+    {"name": "iron", "id": 1338, "trainId": 637},
+    {"name": "spoon", "id": 2505, "trainId": 638},
+    {"name": "flag pole", "id": 955, "trainId": 639},
+    {"name": "toilet brush", "id": 2786, "trainId": 640},
+    {"name": "book stand", "id": 238, "trainId": 641},
+    {"name": "water faucet, water tap, tap, hydrant", "id": 3000, "trainId": 642},
+    {"name": "ticket office", "id": 2763, "trainId": 643},
+    {"name": "broom", "id": 299, "trainId": 644},
+    {"name": "dvd", "id": 822, "trainId": 645},
+    {"name": "ice bucket", "id": 1288, "trainId": 646},
+    {"name": "carapace, shell, cuticle, shield", "id": 3101, "trainId": 647},
+    {"name": "tureen", "id": 2894, "trainId": 648},
+    {"name": "folders", "id": 992, "trainId": 649},
+    {"name": "chess", "id": 489, "trainId": 650},
+    {"name": "root", "id": 2157, "trainId": 651},
+    {"name": "sewing machine", "id": 2309, "trainId": 652},
+    {"name": "model", "id": 1576, "trainId": 653},
+    {"name": "pen", "id": 1810, "trainId": 654},
+    {"name": "violin", "id": 2964, "trainId": 655},
+    {"name": "sweatshirt", "id": 2662, "trainId": 656},
+    {"name": "recycling materials", "id": 2087, "trainId": 657},
+    {"name": "mitten", "id": 1569, "trainId": 658},
+    {"name": "chopping board, cutting board", "id": 503, "trainId": 659},
+    {"name": "mask", "id": 1505, "trainId": 660},
+    {"name": "log", "id": 1468, "trainId": 661},
+    {"name": "mouse, computer mouse", "id": 1613, "trainId": 662},
+    {"name": "grill", "id": 1138, "trainId": 663},
+    {"name": "hole", "id": 1256, "trainId": 664},
+    {"name": "target", "id": 2715, "trainId": 665},
+    {"name": "trash bag", "id": 2846, "trainId": 666},
+    {"name": "chalk", "id": 477, "trainId": 667},
+    {"name": "sticks", "id": 2576, "trainId": 668},
+    {"name": "balloon", "id": 108, "trainId": 669},
+    {"name": "score", "id": 2245, "trainId": 670},
+    {"name": "hair spray", "id": 1162, "trainId": 671},
+    {"name": "roll", "id": 2149, "trainId": 672},
+    {"name": "runner", "id": 2183, "trainId": 673},
+    {"name": "engine", "id": 858, "trainId": 674},
+    {"name": "inflatable glove", "id": 1324, "trainId": 675},
+    {"name": "games", "id": 1055, "trainId": 676},
+    {"name": "pallets", "id": 1741, "trainId": 677},
+    {"name": "baskets", "id": 149, "trainId": 678},
+    {"name": "coop", "id": 615, "trainId": 679},
+    {"name": "dvd player", "id": 825, "trainId": 680},
+    {"name": "rocking horse", "id": 2143, "trainId": 681},
+    {"name": "buckets", "id": 304, "trainId": 682},
+    {"name": "bread rolls", "id": 283, "trainId": 683},
+    {"name": "shawl", "id": 2322, "trainId": 684},
+    {"name": "watering can", "id": 3017, "trainId": 685},
+    {"name": "spotlights", "id": 2510, "trainId": 686},
+    {"name": "post-it", "id": 1960, "trainId": 687},
+    {"name": "bowls", "id": 265, "trainId": 688},
+    {"name": "security camera", "id": 2282, "trainId": 689},
+    {"name": "runner cloth", "id": 2184, "trainId": 690},
+    {"name": "lock", "id": 1461, "trainId": 691},
+    {"name": "alarm, warning device, alarm system", "id": 3113, "trainId": 692},
+    {"name": "side", "id": 2372, "trainId": 693},
+    {"name": "roulette", "id": 2166, "trainId": 694},
+    {"name": "bone", "id": 232, "trainId": 695},
+    {"name": "cutlery", "id": 693, "trainId": 696},
+    {"name": "pool balls", "id": 1945, "trainId": 697},
+    {"name": "wheels", "id": 3039, "trainId": 698},
+    {"name": "spice rack", "id": 2494, "trainId": 699},
+    {"name": "plant pots", "id": 1908, "trainId": 700},
+    {"name": "towel ring", "id": 2827, "trainId": 701},
+    {"name": "bread box", "id": 280, "trainId": 702},
+    {"name": "video", "id": 2950, "trainId": 703},
+    {"name": "funfair", "id": 1044, "trainId": 704},
+    {"name": "breads", "id": 288, "trainId": 705},
+    {"name": "tripod", "id": 2863, "trainId": 706},
+    {"name": "ironing board", "id": 1342, "trainId": 707},
+    {"name": "skimmer", "id": 2409, "trainId": 708},
+    {"name": "hollow", "id": 1258, "trainId": 709},
+    {"name": "scratching post", "id": 2249, "trainId": 710},
+    {"name": "tricycle", "id": 2862, "trainId": 711},
+    {"name": "file box", "id": 920, "trainId": 712},
+    {"name": "mountain pass", "id": 1607, "trainId": 713},
+    {"name": "tombstones", "id": 2802, "trainId": 714},
+    {"name": "cooker", "id": 610, "trainId": 715},
+    {"name": "card game, cards", "id": 3129, "trainId": 716},
+    {"name": "golf bag", "id": 1108, "trainId": 717},
+    {"name": "towel paper", "id": 2823, "trainId": 718},
+    {"name": "chaise lounge", "id": 476, "trainId": 719},
+    {"name": "sun", "id": 2641, "trainId": 720},
+    {"name": "toilet paper holder", "id": 2788, "trainId": 721},
+    {"name": "rake", "id": 2070, "trainId": 722},
+    {"name": "key", "id": 1368, "trainId": 723},
+    {"name": "umbrella stand", "id": 2903, "trainId": 724},
+    {"name": "dartboard", "id": 699, "trainId": 725},
+    {"name": "transformer", "id": 2844, "trainId": 726},
+    {"name": "fireplace utensils", "id": 942, "trainId": 727},
+    {"name": "sweatshirts", "id": 2663, "trainId": 728},
+    {
+        "name": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
+        "id": 457,
+        "trainId": 729,
+    },
+    {"name": "tallboy", "id": 2701, "trainId": 730},
+    {"name": "stapler", "id": 2540, "trainId": 731},
+    {"name": "sauna", "id": 2231, "trainId": 732},
+    {"name": "test tube", "id": 2746, "trainId": 733},
+    {"name": "palette", "id": 1738, "trainId": 734},
+    {"name": "shopping carts", "id": 2350, "trainId": 735},
+    {"name": "tools", "id": 2808, "trainId": 736},
+    {"name": "push button, push, button", "id": 2025, "trainId": 737},
+    {"name": "star", "id": 2541, "trainId": 738},
+    {"name": "roof rack", "id": 2156, "trainId": 739},
+    {"name": "barbed wire", "id": 126, "trainId": 740},
+    {"name": "spray", "id": 2512, "trainId": 741},
+    {"name": "ear", "id": 831, "trainId": 742},
+    {"name": "sponge", "id": 2503, "trainId": 743},
+    {"name": "racket", "id": 2039, "trainId": 744},
+    {"name": "tins", "id": 2774, "trainId": 745},
+    {"name": "eyeglasses", "id": 886, "trainId": 746},
+    {"name": "file", "id": 919, "trainId": 747},
+    {"name": "scarfs", "id": 2240, "trainId": 748},
+    {"name": "sugar bowl", "id": 2636, "trainId": 749},
+    {"name": "flip flop", "id": 963, "trainId": 750},
+    {"name": "headstones", "id": 1218, "trainId": 751},
+    {"name": "laptop bag", "id": 1406, "trainId": 752},
+    {"name": "leash", "id": 1420, "trainId": 753},
+    {"name": "climbing frame", "id": 526, "trainId": 754},
+    {"name": "suit hanger", "id": 2639, "trainId": 755},
+    {"name": "floor spotlight", "id": 975, "trainId": 756},
+    {"name": "plate rack", "id": 1921, "trainId": 757},
+    {"name": "sewer", "id": 2305, "trainId": 758},
+    {"name": "hard drive", "id": 1193, "trainId": 759},
+    {"name": "sprinkler", "id": 2517, "trainId": 760},
+    {"name": "tools box", "id": 2809, "trainId": 761},
+    {"name": "necklace", "id": 1647, "trainId": 762},
+    {"name": "bulbs", "id": 314, "trainId": 763},
+    {"name": "steel industry", "id": 2560, "trainId": 764},
+    {"name": "club", "id": 545, "trainId": 765},
+    {"name": "jack", "id": 1345, "trainId": 766},
+    {"name": "door bars", "id": 775, "trainId": 767},
+    {
+        "name": "control panel, instrument panel, control board, board, panel",
+        "id": 603,
+        "trainId": 768,
+    },
+    {"name": "hairbrush", "id": 1163, "trainId": 769},
+    {"name": "napkin holder", "id": 1641, "trainId": 770},
+    {"name": "office", "id": 1678, "trainId": 771},
+    {"name": "smoke detector", "id": 2450, "trainId": 772},
+    {"name": "utensils", "id": 2915, "trainId": 773},
+    {"name": "apron", "id": 42, "trainId": 774},
+    {"name": "scissors", "id": 2242, "trainId": 775},
+    {"name": "terminal", "id": 2741, "trainId": 776},
+    {"name": "grinder", "id": 1143, "trainId": 777},
+    {"name": "entry phone", "id": 862, "trainId": 778},
+    {"name": "newspaper stand", "id": 1654, "trainId": 779},
+    {"name": "pepper shaker", "id": 1826, "trainId": 780},
+    {"name": "onions", "id": 1689, "trainId": 781},
+    {
+        "name": "central processing unit, cpu, c p u , central processor, processor, mainframe",
+        "id": 3124,
+        "trainId": 782,
+    },
+    {"name": "tape", "id": 2710, "trainId": 783},
+    {"name": "bat", "id": 152, "trainId": 784},
+    {"name": "coaster", "id": 549, "trainId": 785},
+    {"name": "calculator", "id": 360, "trainId": 786},
+    {"name": "potatoes", "id": 1982, "trainId": 787},
+    {"name": "luggage rack", "id": 1478, "trainId": 788},
+    {"name": "salt", "id": 2203, "trainId": 789},
+    {"name": "street number", "id": 2612, "trainId": 790},
+    {"name": "viewpoint", "id": 2956, "trainId": 791},
+    {"name": "sword", "id": 2681, "trainId": 792},
+    {"name": "cd", "id": 437, "trainId": 793},
+    {"name": "rowing machine", "id": 2171, "trainId": 794},
+    {"name": "plug", "id": 1933, "trainId": 795},
+    {"name": "andiron, firedog, dog, dog-iron", "id": 3110, "trainId": 796},
+    {"name": "pepper", "id": 1824, "trainId": 797},
+    {"name": "tongs", "id": 2803, "trainId": 798},
+    {"name": "bonfire", "id": 234, "trainId": 799},
+    {"name": "dog dish", "id": 764, "trainId": 800},
+    {"name": "belt", "id": 177, "trainId": 801},
+    {"name": "dumbbells", "id": 817, "trainId": 802},
+    {"name": "videocassette recorder, vcr", "id": 3145, "trainId": 803},
+    {"name": "hook", "id": 1262, "trainId": 804},
+    {"name": "envelopes", "id": 864, "trainId": 805},
+    {"name": "shower faucet", "id": 2359, "trainId": 806},
+    {"name": "watch", "id": 2992, "trainId": 807},
+    {"name": "padlock", "id": 1725, "trainId": 808},
+    {"name": "swimming pool ladder", "id": 2667, "trainId": 809},
+    {"name": "spanners", "id": 2484, "trainId": 810},
+    {"name": "gravy boat", "id": 1133, "trainId": 811},
+    {"name": "notice board", "id": 1667, "trainId": 812},
+    {"name": "trash bags", "id": 2847, "trainId": 813},
+    {"name": "fire alarm", "id": 932, "trainId": 814},
+    {"name": "ladle", "id": 1392, "trainId": 815},
+    {"name": "stethoscope", "id": 2573, "trainId": 816},
+    {"name": "rocket", "id": 2140, "trainId": 817},
+    {"name": "funnel", "id": 1046, "trainId": 818},
+    {"name": "bowling pins", "id": 264, "trainId": 819},
+    {"name": "valve", "id": 2927, "trainId": 820},
+    {"name": "thermometer", "id": 2752, "trainId": 821},
+    {"name": "cups", "id": 679, "trainId": 822},
+    {"name": "spice jar", "id": 2493, "trainId": 823},
+    {"name": "night light", "id": 1658, "trainId": 824},
+    {"name": "soaps", "id": 2466, "trainId": 825},
+    {"name": "games table", "id": 1057, "trainId": 826},
+    {"name": "slotted spoon", "id": 2444, "trainId": 827},
+    {"name": "reel", "id": 2093, "trainId": 828},
+    {"name": "scourer", "id": 2248, "trainId": 829},
+    {"name": "sleeping robe", "id": 2432, "trainId": 830},
+    {"name": "desk mat", "id": 726, "trainId": 831},
+    {"name": "dumbbell", "id": 816, "trainId": 832},
+    {"name": "hammer", "id": 1171, "trainId": 833},
+    {"name": "tie", "id": 2766, "trainId": 834},
+    {"name": "typewriter", "id": 2900, "trainId": 835},
+    {"name": "shaker", "id": 2313, "trainId": 836},
+    {"name": "cheese dish", "id": 488, "trainId": 837},
+    {"name": "sea star", "id": 2265, "trainId": 838},
+    {"name": "racquet", "id": 2043, "trainId": 839},
+    {"name": "butane gas cylinder", "id": 332, "trainId": 840},
+    {"name": "paper weight", "id": 1771, "trainId": 841},
+    {"name": "shaving brush", "id": 2320, "trainId": 842},
+    {"name": "sunglasses", "id": 2646, "trainId": 843},
+    {"name": "gear shift", "id": 1089, "trainId": 844},
+    {"name": "towel rail", "id": 2826, "trainId": 845},
+    {"name": "adding machine, totalizer, totaliser", "id": 3148, "trainId": 846},
+]
+
+
+def _get_ade20k_full_meta():
+    stuff_ids = [k["id"] for k in ADE20K_SEM_SEG_FULL_CATEGORIES]
+    assert len(stuff_ids) == 847, len(stuff_ids)
+
+    stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)}
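+    # the category list above is ordered by trainId, so the enumerate index here coincides
+    # with each category's trainId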
+    stuff_classes = [k["name"] for k in ADE20K_SEM_SEG_FULL_CATEGORIES]
+
+    ret = {
+        "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id,
+        "stuff_classes": stuff_classes,
+    }
+    return ret
+
+
+def register_all_ade20k_full(root):
+    meta = _get_ade20k_full_meta()
+    for name, dirname in [("val", "validation")]:
+        image_dir = os.path.join(root, "ADE20K_2021_17_01/images_detectron2", dirname)
+        gt_dir = os.path.join(root, "ADE20K_2021_17_01/annotations_detectron2", dirname)
+        name = f"ade20k_full_sem_seg_{name}"
+        DatasetCatalog.register(
+            name,
+            lambda x=image_dir, y=gt_dir: load_sem_seg(
+                y, x, gt_ext="tif", image_ext="jpg"
+            ),
+        )
+        MetadataCatalog.get(name).set(
+            stuff_classes=meta["stuff_classes"][:],
+            thing_classes=meta["stuff_classes"][:],  # the same as stuff_classes
+            image_root=image_dir,
+            sem_seg_root=gt_dir,
+            evaluator_type="sem_seg",
+            ignore_label=65535,  # NOTE: gt is saved in 16-bit TIFF images
+        )
+
+
+_root = os.getenv("DETECTRON2_DATASETS", "datasets")
+register_all_ade20k_full(_root)
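+
+# Hedged usage sketch (illustration only, not part of the upstream module): once this
+# file is imported, the registered split can be inspected through detectron2's catalogs:
+#
+#   from detectron2.data import DatasetCatalog, MetadataCatalog
+#   meta = MetadataCatalog.get("ade20k_full_sem_seg_val")
+#   assert len(meta.stuff_classes) == 847
+#   dicts = DatasetCatalog.get("ade20k_full_sem_seg_val")  # requires the data on disk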
diff --git a/open_vocab_seg/data/datasets/register_cc3m.py b/open_vocab_seg/data/datasets/register_cc3m.py
new file mode 100644
index 0000000000000000000000000000000000000000..8aa5cb07bc99b574505b6319835750789bb3ee26
--- /dev/null
+++ b/open_vocab_seg/data/datasets/register_cc3m.py
@@ -0,0 +1,457 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import os
+
+import pandas as pd
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.data.datasets import load_sem_seg
+from detectron2.utils.file_io import PathManager
+
+
+COCO_CATEGORIES = [
+    {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"},
+    {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"},
+    {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"},
+    {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"},
+    {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"},
+    {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"},
+    {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"},
+    {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"},
+    {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"},
+    {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"},
+    {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"},
+    {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"},
+    {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"},
+    {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"},
+    {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"},
+    {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"},
+    {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"},
+    {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"},
+    {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"},
+    {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"},
+    {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"},
+    {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"},
+    {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"},
+    {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"},
+    {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"},
+    {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"},
+    {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"},
+    {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"},
+    {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"},
+    {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"},
+    {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"},
+    {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"},
+    {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"},
+    {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"},
+    {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"},
+    {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"},
+    {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"},
+    {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"},
+    {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"},
+    {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"},
+    {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"},
+    {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"},
+    {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"},
+    {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"},
+    {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"},
+    {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"},
+    {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"},
+    {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"},
+    {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"},
+    {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"},
+    {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"},
+    {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"},
+    {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"},
+    {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"},
+    {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"},
+    {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"},
+    {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"},
+    {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"},
+    {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"},
+    {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"},
+    {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"},
+    {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"},
+    {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"},
+    {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"},
+    {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"},
+    {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"},
+    {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"},
+    {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"},
+    {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"},
+    {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"},
+    {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"},
+    {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"},
+    {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"},
+    {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"},
+    {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"},
+    {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"},
+    {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"},
+    {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"},
+    {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"},
+    {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"},
+    {"id": 92, "name": "banner", "supercategory": "textile"},
+    {"id": 93, "name": "blanket", "supercategory": "textile"},
+    {"id": 94, "name": "branch", "supercategory": "plant"},
+    {"id": 95, "name": "bridge", "supercategory": "building"},
+    {"id": 96, "name": "building-other", "supercategory": "building"},
+    {"id": 97, "name": "bush", "supercategory": "plant"},
+    {"id": 98, "name": "cabinet", "supercategory": "furniture-stuff"},
+    {"id": 99, "name": "cage", "supercategory": "structural"},
+    {"id": 100, "name": "cardboard", "supercategory": "raw-material"},
+    {"id": 101, "name": "carpet", "supercategory": "floor"},
+    {"id": 102, "name": "ceiling-other", "supercategory": "ceiling"},
+    {"id": 103, "name": "ceiling-tile", "supercategory": "ceiling"},
+    {"id": 104, "name": "cloth", "supercategory": "textile"},
+    {"id": 105, "name": "clothes", "supercategory": "textile"},
+    {"id": 106, "name": "clouds", "supercategory": "sky"},
+    {"id": 107, "name": "counter", "supercategory": "furniture-stuff"},
+    {"id": 108, "name": "cupboard", "supercategory": "furniture-stuff"},
+    {"id": 109, "name": "curtain", "supercategory": "textile"},
+    {"id": 110, "name": "desk-stuff", "supercategory": "furniture-stuff"},
+    {"id": 111, "name": "dirt", "supercategory": "ground"},
+    {"id": 112, "name": "door-stuff", "supercategory": "furniture-stuff"},
+    {"id": 113, "name": "fence", "supercategory": "structural"},
+    {"id": 114, "name": "floor-marble", "supercategory": "floor"},
+    {"id": 115, "name": "floor-other", "supercategory": "floor"},
+    {"id": 116, "name": "floor-stone", "supercategory": "floor"},
+    {"id": 117, "name": "floor-tile", "supercategory": "floor"},
+    {"id": 118, "name": "floor-wood", "supercategory": "floor"},
+    {"id": 119, "name": "flower", "supercategory": "plant"},
+    {"id": 120, "name": "fog", "supercategory": "water"},
+    {"id": 121, "name": "food-other", "supercategory": "food-stuff"},
+    {"id": 122, "name": "fruit", "supercategory": "food-stuff"},
+    {"id": 123, "name": "furniture-other", "supercategory": "furniture-stuff"},
+    {"id": 124, "name": "grass", "supercategory": "plant"},
+    {"id": 125, "name": "gravel", "supercategory": "ground"},
+    {"id": 126, "name": "ground-other", "supercategory": "ground"},
+    {"id": 127, "name": "hill", "supercategory": "solid"},
+    {"id": 128, "name": "house", "supercategory": "building"},
+    {"id": 129, "name": "leaves", "supercategory": "plant"},
+    {"id": 130, "name": "light", "supercategory": "furniture-stuff"},
+    {"id": 131, "name": "mat", "supercategory": "textile"},
+    {"id": 132, "name": "metal", "supercategory": "raw-material"},
+    {"id": 133, "name": "mirror-stuff", "supercategory": "furniture-stuff"},
+    {"id": 134, "name": "moss", "supercategory": "plant"},
+    {"id": 135, "name": "mountain", "supercategory": "solid"},
+    {"id": 136, "name": "mud", "supercategory": "ground"},
+    {"id": 137, "name": "napkin", "supercategory": "textile"},
+    {"id": 138, "name": "net", "supercategory": "structural"},
+    {"id": 139, "name": "paper", "supercategory": "raw-material"},
+    {"id": 140, "name": "pavement", "supercategory": "ground"},
+    {"id": 141, "name": "pillow", "supercategory": "textile"},
+    {"id": 142, "name": "plant-other", "supercategory": "plant"},
+    {"id": 143, "name": "plastic", "supercategory": "raw-material"},
+    {"id": 144, "name": "platform", "supercategory": "ground"},
+    {"id": 145, "name": "playingfield", "supercategory": "ground"},
+    {"id": 146, "name": "railing", "supercategory": "structural"},
+    {"id": 147, "name": "railroad", "supercategory": "ground"},
+    {"id": 148, "name": "river", "supercategory": "water"},
+    {"id": 149, "name": "road", "supercategory": "ground"},
+    {"id": 150, "name": "rock", "supercategory": "solid"},
+    {"id": 151, "name": "roof", "supercategory": "building"},
+    {"id": 152, "name": "rug", "supercategory": "textile"},
+    {"id": 153, "name": "salad", "supercategory": "food-stuff"},
+    {"id": 154, "name": "sand", "supercategory": "ground"},
+    {"id": 155, "name": "sea", "supercategory": "water"},
+    {"id": 156, "name": "shelf", "supercategory": "furniture-stuff"},
+    {"id": 157, "name": "sky-other", "supercategory": "sky"},
+    {"id": 158, "name": "skyscraper", "supercategory": "building"},
+    {"id": 159, "name": "snow", "supercategory": "ground"},
+    {"id": 160, "name": "solid-other", "supercategory": "solid"},
+    {"id": 161, "name": "stairs", "supercategory": "furniture-stuff"},
+    {"id": 162, "name": "stone", "supercategory": "solid"},
+    {"id": 163, "name": "straw", "supercategory": "plant"},
+    {"id": 164, "name": "structural-other", "supercategory": "structural"},
+    {"id": 165, "name": "table", "supercategory": "furniture-stuff"},
+    {"id": 166, "name": "tent", "supercategory": "building"},
+    {"id": 167, "name": "textile-other", "supercategory": "textile"},
+    {"id": 168, "name": "towel", "supercategory": "textile"},
+    {"id": 169, "name": "tree", "supercategory": "plant"},
+    {"id": 170, "name": "vegetable", "supercategory": "food-stuff"},
+    {"id": 171, "name": "wall-brick", "supercategory": "wall"},
+    {"id": 172, "name": "wall-concrete", "supercategory": "wall"},
+    {"id": 173, "name": "wall-other", "supercategory": "wall"},
+    {"id": 174, "name": "wall-panel", "supercategory": "wall"},
+    {"id": 175, "name": "wall-stone", "supercategory": "wall"},
+    {"id": 176, "name": "wall-tile", "supercategory": "wall"},
+    {"id": 177, "name": "wall-wood", "supercategory": "wall"},
+    {"id": 178, "name": "water-other", "supercategory": "water"},
+    {"id": 179, "name": "waterdrops", "supercategory": "water"},
+    {"id": 180, "name": "window-blind", "supercategory": "window"},
+    {"id": 181, "name": "window-other", "supercategory": "window"},
+    {"id": 182, "name": "wood", "supercategory": "solid"},
+]
+
+
+ADE20K_150_CATEGORIES = [
+    {"color": [120, 120, 120], "id": 0, "isthing": 0, "name": "wall"},
+    {"color": [180, 120, 120], "id": 1, "isthing": 0, "name": "building"},
+    {"color": [6, 230, 230], "id": 2, "isthing": 0, "name": "sky"},
+    {"color": [80, 50, 50], "id": 3, "isthing": 0, "name": "floor"},
+    {"color": [4, 200, 3], "id": 4, "isthing": 0, "name": "tree"},
+    {"color": [120, 120, 80], "id": 5, "isthing": 0, "name": "ceiling"},
+    {"color": [140, 140, 140], "id": 6, "isthing": 0, "name": "road, route"},
+    {"color": [204, 5, 255], "id": 7, "isthing": 1, "name": "bed"},
+    {"color": [230, 230, 230], "id": 8, "isthing": 1, "name": "window "},
+    {"color": [4, 250, 7], "id": 9, "isthing": 0, "name": "grass"},
+    {"color": [224, 5, 255], "id": 10, "isthing": 1, "name": "cabinet"},
+    {"color": [235, 255, 7], "id": 11, "isthing": 0, "name": "sidewalk, pavement"},
+    {"color": [150, 5, 61], "id": 12, "isthing": 1, "name": "person"},
+    {"color": [120, 120, 70], "id": 13, "isthing": 0, "name": "earth, ground"},
+    {"color": [8, 255, 51], "id": 14, "isthing": 1, "name": "door"},
+    {"color": [255, 6, 82], "id": 15, "isthing": 1, "name": "table"},
+    {"color": [143, 255, 140], "id": 16, "isthing": 0, "name": "mountain, mount"},
+    {"color": [204, 255, 4], "id": 17, "isthing": 0, "name": "plant"},
+    {"color": [255, 51, 7], "id": 18, "isthing": 1, "name": "curtain"},
+    {"color": [204, 70, 3], "id": 19, "isthing": 1, "name": "chair"},
+    {"color": [0, 102, 200], "id": 20, "isthing": 1, "name": "car"},
+    {"color": [61, 230, 250], "id": 21, "isthing": 0, "name": "water"},
+    {"color": [255, 6, 51], "id": 22, "isthing": 1, "name": "painting, picture"},
+    {"color": [11, 102, 255], "id": 23, "isthing": 1, "name": "sofa"},
+    {"color": [255, 7, 71], "id": 24, "isthing": 1, "name": "shelf"},
+    {"color": [255, 9, 224], "id": 25, "isthing": 0, "name": "house"},
+    {"color": [9, 7, 230], "id": 26, "isthing": 0, "name": "sea"},
+    {"color": [220, 220, 220], "id": 27, "isthing": 1, "name": "mirror"},
+    {"color": [255, 9, 92], "id": 28, "isthing": 0, "name": "rug"},
+    {"color": [112, 9, 255], "id": 29, "isthing": 0, "name": "field"},
+    {"color": [8, 255, 214], "id": 30, "isthing": 1, "name": "armchair"},
+    {"color": [7, 255, 224], "id": 31, "isthing": 1, "name": "seat"},
+    {"color": [255, 184, 6], "id": 32, "isthing": 1, "name": "fence"},
+    {"color": [10, 255, 71], "id": 33, "isthing": 1, "name": "desk"},
+    {"color": [255, 41, 10], "id": 34, "isthing": 0, "name": "rock, stone"},
+    {"color": [7, 255, 255], "id": 35, "isthing": 1, "name": "wardrobe, closet, press"},
+    {"color": [224, 255, 8], "id": 36, "isthing": 1, "name": "lamp"},
+    {"color": [102, 8, 255], "id": 37, "isthing": 1, "name": "tub"},
+    {"color": [255, 61, 6], "id": 38, "isthing": 1, "name": "rail"},
+    {"color": [255, 194, 7], "id": 39, "isthing": 1, "name": "cushion"},
+    {"color": [255, 122, 8], "id": 40, "isthing": 0, "name": "base, pedestal, stand"},
+    {"color": [0, 255, 20], "id": 41, "isthing": 1, "name": "box"},
+    {"color": [255, 8, 41], "id": 42, "isthing": 1, "name": "column, pillar"},
+    {"color": [255, 5, 153], "id": 43, "isthing": 1, "name": "signboard, sign"},
+    {
+        "color": [6, 51, 255],
+        "id": 44,
+        "isthing": 1,
+        "name": "chest of drawers, chest, bureau, dresser",
+    },
+    {"color": [235, 12, 255], "id": 45, "isthing": 1, "name": "counter"},
+    {"color": [160, 150, 20], "id": 46, "isthing": 0, "name": "sand"},
+    {"color": [0, 163, 255], "id": 47, "isthing": 1, "name": "sink"},
+    {"color": [140, 140, 140], "id": 48, "isthing": 0, "name": "skyscraper"},
+    {"color": [250, 10, 15], "id": 49, "isthing": 1, "name": "fireplace"},
+    {"color": [20, 255, 0], "id": 50, "isthing": 1, "name": "refrigerator, icebox"},
+    {"color": [31, 255, 0], "id": 51, "isthing": 0, "name": "grandstand, covered stand"},
+    {"color": [255, 31, 0], "id": 52, "isthing": 0, "name": "path"},
+    {"color": [255, 224, 0], "id": 53, "isthing": 1, "name": "stairs"},
+    {"color": [153, 255, 0], "id": 54, "isthing": 0, "name": "runway"},
+    {"color": [0, 0, 255], "id": 55, "isthing": 1, "name": "case, display case, showcase, vitrine"},
+    {
+        "color": [255, 71, 0],
+        "id": 56,
+        "isthing": 1,
+        "name": "pool table, billiard table, snooker table",
+    },
+    {"color": [0, 235, 255], "id": 57, "isthing": 1, "name": "pillow"},
+    {"color": [0, 173, 255], "id": 58, "isthing": 1, "name": "screen door, screen"},
+    {"color": [31, 0, 255], "id": 59, "isthing": 0, "name": "stairway, staircase"},
+    {"color": [11, 200, 200], "id": 60, "isthing": 0, "name": "river"},
+    {"color": [255, 82, 0], "id": 61, "isthing": 0, "name": "bridge, span"},
+    {"color": [0, 255, 245], "id": 62, "isthing": 1, "name": "bookcase"},
+    {"color": [0, 61, 255], "id": 63, "isthing": 0, "name": "blind, screen"},
+    {"color": [0, 255, 112], "id": 64, "isthing": 1, "name": "coffee table"},
+    {
+        "color": [0, 255, 133],
+        "id": 65,
+        "isthing": 1,
+        "name": "toilet, can, commode, crapper, pot, potty, stool, throne",
+    },
+    {"color": [255, 0, 0], "id": 66, "isthing": 1, "name": "flower"},
+    {"color": [255, 163, 0], "id": 67, "isthing": 1, "name": "book"},
+    {"color": [255, 102, 0], "id": 68, "isthing": 0, "name": "hill"},
+    {"color": [194, 255, 0], "id": 69, "isthing": 1, "name": "bench"},
+    {"color": [0, 143, 255], "id": 70, "isthing": 1, "name": "countertop"},
+    {"color": [51, 255, 0], "id": 71, "isthing": 1, "name": "stove"},
+    {"color": [0, 82, 255], "id": 72, "isthing": 1, "name": "palm, palm tree"},
+    {"color": [0, 255, 41], "id": 73, "isthing": 1, "name": "kitchen island"},
+    {"color": [0, 255, 173], "id": 74, "isthing": 1, "name": "computer"},
+    {"color": [10, 0, 255], "id": 75, "isthing": 1, "name": "swivel chair"},
+    {"color": [173, 255, 0], "id": 76, "isthing": 1, "name": "boat"},
+    {"color": [0, 255, 153], "id": 77, "isthing": 0, "name": "bar"},
+    {"color": [255, 92, 0], "id": 78, "isthing": 1, "name": "arcade machine"},
+    {"color": [255, 0, 255], "id": 79, "isthing": 0, "name": "hovel, hut, hutch, shack, shanty"},
+    {"color": [255, 0, 245], "id": 80, "isthing": 1, "name": "bus"},
+    {"color": [255, 0, 102], "id": 81, "isthing": 1, "name": "towel"},
+    {"color": [255, 173, 0], "id": 82, "isthing": 1, "name": "light"},
+    {"color": [255, 0, 20], "id": 83, "isthing": 1, "name": "truck"},
+    {"color": [255, 184, 184], "id": 84, "isthing": 0, "name": "tower"},
+    {"color": [0, 31, 255], "id": 85, "isthing": 1, "name": "chandelier"},
+    {"color": [0, 255, 61], "id": 86, "isthing": 1, "name": "awning, sunshade, sunblind"},
+    {"color": [0, 71, 255], "id": 87, "isthing": 1, "name": "street lamp"},
+    {"color": [255, 0, 204], "id": 88, "isthing": 1, "name": "booth"},
+    {"color": [0, 255, 194], "id": 89, "isthing": 1, "name": "tv"},
+    {"color": [0, 255, 82], "id": 90, "isthing": 1, "name": "plane"},
+    {"color": [0, 10, 255], "id": 91, "isthing": 0, "name": "dirt track"},
+    {"color": [0, 112, 255], "id": 92, "isthing": 1, "name": "clothes"},
+    {"color": [51, 0, 255], "id": 93, "isthing": 1, "name": "pole"},
+    {"color": [0, 194, 255], "id": 94, "isthing": 0, "name": "land, ground, soil"},
+    {
+        "color": [0, 122, 255],
+        "id": 95,
+        "isthing": 1,
+        "name": "bannister, banister, balustrade, balusters, handrail",
+    },
+    {
+        "color": [0, 255, 163],
+        "id": 96,
+        "isthing": 0,
+        "name": "escalator, moving staircase, moving stairway",
+    },
+    {
+        "color": [255, 153, 0],
+        "id": 97,
+        "isthing": 1,
+        "name": "ottoman, pouf, pouffe, puff, hassock",
+    },
+    {"color": [0, 255, 10], "id": 98, "isthing": 1, "name": "bottle"},
+    {"color": [255, 112, 0], "id": 99, "isthing": 0, "name": "buffet, counter, sideboard"},
+    {
+        "color": [143, 255, 0],
+        "id": 100,
+        "isthing": 0,
+        "name": "poster, posting, placard, notice, bill, card",
+    },
+    {"color": [82, 0, 255], "id": 101, "isthing": 0, "name": "stage"},
+    {"color": [163, 255, 0], "id": 102, "isthing": 1, "name": "van"},
+    {"color": [255, 235, 0], "id": 103, "isthing": 1, "name": "ship"},
+    {"color": [8, 184, 170], "id": 104, "isthing": 1, "name": "fountain"},
+    {
+        "color": [133, 0, 255],
+        "id": 105,
+        "isthing": 0,
+        "name": "conveyer belt, conveyor belt, conveyer, conveyor, transporter",
+    },
+    {"color": [0, 255, 92], "id": 106, "isthing": 0, "name": "canopy"},
+    {
+        "color": [184, 0, 255],
+        "id": 107,
+        "isthing": 1,
+        "name": "washer, automatic washer, washing machine",
+    },
+    {"color": [255, 0, 31], "id": 108, "isthing": 1, "name": "plaything, toy"},
+    {"color": [0, 184, 255], "id": 109, "isthing": 0, "name": "pool"},
+    {"color": [0, 214, 255], "id": 110, "isthing": 1, "name": "stool"},
+    {"color": [255, 0, 112], "id": 111, "isthing": 1, "name": "barrel, cask"},
+    {"color": [92, 255, 0], "id": 112, "isthing": 1, "name": "basket, handbasket"},
+    {"color": [0, 224, 255], "id": 113, "isthing": 0, "name": "falls"},
+    {"color": [112, 224, 255], "id": 114, "isthing": 0, "name": "tent"},
+    {"color": [70, 184, 160], "id": 115, "isthing": 1, "name": "bag"},
+    {"color": [163, 0, 255], "id": 116, "isthing": 1, "name": "minibike, motorbike"},
+    {"color": [153, 0, 255], "id": 117, "isthing": 0, "name": "cradle"},
+    {"color": [71, 255, 0], "id": 118, "isthing": 1, "name": "oven"},
+    {"color": [255, 0, 163], "id": 119, "isthing": 1, "name": "ball"},
+    {"color": [255, 204, 0], "id": 120, "isthing": 1, "name": "food, solid food"},
+    {"color": [255, 0, 143], "id": 121, "isthing": 1, "name": "step, stair"},
+    {"color": [0, 255, 235], "id": 122, "isthing": 0, "name": "tank, storage tank"},
+    {"color": [133, 255, 0], "id": 123, "isthing": 1, "name": "trade name"},
+    {"color": [255, 0, 235], "id": 124, "isthing": 1, "name": "microwave"},
+    {"color": [245, 0, 255], "id": 125, "isthing": 1, "name": "pot"},
+    {"color": [255, 0, 122], "id": 126, "isthing": 1, "name": "animal"},
+    {"color": [255, 245, 0], "id": 127, "isthing": 1, "name": "bicycle"},
+    {"color": [10, 190, 212], "id": 128, "isthing": 0, "name": "lake"},
+    {"color": [214, 255, 0], "id": 129, "isthing": 1, "name": "dishwasher"},
+    {"color": [0, 204, 255], "id": 130, "isthing": 1, "name": "screen"},
+    {"color": [20, 0, 255], "id": 131, "isthing": 0, "name": "blanket, cover"},
+    {"color": [255, 255, 0], "id": 132, "isthing": 1, "name": "sculpture"},
+    {"color": [0, 153, 255], "id": 133, "isthing": 1, "name": "hood, exhaust hood"},
+    {"color": [0, 41, 255], "id": 134, "isthing": 1, "name": "sconce"},
+    {"color": [0, 255, 204], "id": 135, "isthing": 1, "name": "vase"},
+    {"color": [41, 0, 255], "id": 136, "isthing": 1, "name": "traffic light"},
+    {"color": [41, 255, 0], "id": 137, "isthing": 1, "name": "tray"},
+    {"color": [173, 0, 255], "id": 138, "isthing": 1, "name": "trash can"},
+    {"color": [0, 245, 255], "id": 139, "isthing": 1, "name": "fan"},
+    {"color": [71, 0, 255], "id": 140, "isthing": 0, "name": "pier"},
+    {"color": [122, 0, 255], "id": 141, "isthing": 0, "name": "crt screen"},
+    {"color": [0, 255, 184], "id": 142, "isthing": 1, "name": "plate"},
+    {"color": [0, 92, 255], "id": 143, "isthing": 1, "name": "monitor"},
+    {"color": [184, 255, 0], "id": 144, "isthing": 1, "name": "bulletin board"},
+    {"color": [0, 133, 255], "id": 145, "isthing": 0, "name": "shower"},
+    {"color": [255, 214, 0], "id": 146, "isthing": 1, "name": "radiator"},
+    {"color": [25, 194, 194], "id": 147, "isthing": 1, "name": "glass, drinking glass"},
+    {"color": [102, 255, 0], "id": 148, "isthing": 1, "name": "clock"},
+    {"color": [92, 0, 255], "id": 149, "isthing": 1, "name": "flag"},
+]
+
+TEST_CATEGORIES = [
+    {"color": [143, 255, 140], "id": 16, "isthing": 0, "name": "Oculus"},
+    {"color": [204, 255, 4], "id": 17, "isthing": 0, "name": "Ukulele"},
+]
+
+COCO_BASE_CATEGORIES = [
+    c
+    for i, c in enumerate(COCO_CATEGORIES)
+    if c["id"] - 1
+    not in [20, 24, 32, 33, 40, 56, 86, 99, 105, 123, 144, 147, 148, 168, 171]
+]
+COCO_NOVEL_CATEGORIES = [
+    c
+    for i, c in enumerate(COCO_CATEGORIES)
+    if c["id"] - 1
+    in [20, 24, 32, 33, 40, 56, 86, 99, 105, 123, 144, 147, 148, 168, 171]
+]
+
+
+def load_cc_image(csv_file, img_key='filepath', caption_key='title', sep="\t"):
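+    """
+    Load (image path, caption) records from a delimited text file.
+
+    The file is expected to provide at least the columns named by ``img_key`` and
+    ``caption_key`` (by default ``filepath`` and ``title``, a layout commonly used for
+    CC3M-style caption files); each row becomes one record with "file_name" and "caption".
+    """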
+    print(f'Loading csv data from {csv_file}.')
+    df = pd.read_csv(csv_file, sep=sep)
+
+    input_files = df[img_key].tolist()
+    captions = df[caption_key].tolist()
+
+    print("Loaded {} images".format(len(input_files)))
+
+    dataset_dicts = []
+    for (img_path, text) in zip(input_files, captions):
+        record = {}
+        record["file_name"] = img_path
+        record["caption"] = text
+        dataset_dicts.append(record)
+
+    return dataset_dicts
+
+
+def _get_coco_stuff_meta(cat_list):
+    # Id 0 is reserved for ignore_label; our pre-processing remaps the
+    # ignore label from 0 to 255.
+    stuff_ids = [k["id"] for k in cat_list]
+
+    # For semantic segmentation, this maps dataset category ids (used when
+    # processing results) to contiguous ids in [0, len(cat_list)) used in models.
+    stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)}
+    stuff_classes = [k["name"] for k in cat_list]
+
+    ret = {
+        "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id,
+        "stuff_classes": stuff_classes,
+    }
+    return ret
+
+
+def register_cc_3m(csv_file):
+
+    meta = _get_coco_stuff_meta(TEST_CATEGORIES)
+    name = "cc_3m_train"
+
+    DatasetCatalog.register(
+        name,
+        lambda x=csv_file: load_cc_image(x),
+    )
+    MetadataCatalog.get(name).set(
+        csv_file=csv_file,
+        evaluator_type="dummy",
+        ignore_label=255,
+        **meta,
+    )
+
+
+# _csv_file = "/home/jeffliang/zsseg/datasets/coco/coco_train_merge_captions.csv"
+_csv_file = "/home/jeffliang/zsseg/configs/masked_images/pred/samples.csv"
+register_cc_3m(_csv_file)
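+
+# Hedged usage sketch (illustration only): the hardcoded ``_csv_file`` above points at the
+# original author's machine and presumably needs to be adapted to the local environment.
+# The expected input is a tab-separated file with 'filepath' and 'title' columns, e.g.
+#
+#   filepath<TAB>title
+#   /path/to/images/000001.jpg<TAB>a dog running on a beach
+#
+# After import, the records are available lazily via
+#   from detectron2.data import DatasetCatalog
+#   records = DatasetCatalog.get("cc_3m_train")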
diff --git a/open_vocab_seg/data/datasets/register_coco_stuff.py b/open_vocab_seg/data/datasets/register_coco_stuff.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1a0f5b571a971fe20ebc8932d27499de856a565
--- /dev/null
+++ b/open_vocab_seg/data/datasets/register_coco_stuff.py
@@ -0,0 +1,250 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import os
+
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.data.datasets import load_sem_seg
+
+
+COCO_CATEGORIES = [
+    {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"},
+    {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"},
+    {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"},
+    {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"},
+    {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"},
+    {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"},
+    {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"},
+    {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"},
+    {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"},
+    {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"},
+    {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"},
+    {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"},
+    {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"},
+    {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"},
+    {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"},
+    {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"},
+    {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"},
+    {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"},
+    {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"},
+    {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"},
+    {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"},
+    {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"},
+    {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"},
+    {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"},
+    {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"},
+    {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"},
+    {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"},
+    {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"},
+    {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"},
+    {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"},
+    {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"},
+    {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"},
+    {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"},
+    {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"},
+    {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"},
+    {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"},
+    {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"},
+    {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"},
+    {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"},
+    {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"},
+    {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"},
+    {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"},
+    {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"},
+    {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"},
+    {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"},
+    {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"},
+    {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"},
+    {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"},
+    {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"},
+    {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"},
+    {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"},
+    {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"},
+    {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"},
+    {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"},
+    {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"},
+    {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"},
+    {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"},
+    {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"},
+    {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"},
+    {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"},
+    {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"},
+    {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"},
+    {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"},
+    {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"},
+    {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"},
+    {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"},
+    {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"},
+    {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"},
+    {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"},
+    {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"},
+    {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"},
+    {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"},
+    {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"},
+    {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"},
+    {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"},
+    {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"},
+    {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"},
+    {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"},
+    {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"},
+    {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"},
+    {"id": 92, "name": "banner", "supercategory": "textile"},
+    {"id": 93, "name": "blanket", "supercategory": "textile"},
+    {"id": 94, "name": "branch", "supercategory": "plant"},
+    {"id": 95, "name": "bridge", "supercategory": "building"},
+    {"id": 96, "name": "building-other", "supercategory": "building"},
+    {"id": 97, "name": "bush", "supercategory": "plant"},
+    {"id": 98, "name": "cabinet", "supercategory": "furniture-stuff"},
+    {"id": 99, "name": "cage", "supercategory": "structural"},
+    {"id": 100, "name": "cardboard", "supercategory": "raw-material"},
+    {"id": 101, "name": "carpet", "supercategory": "floor"},
+    {"id": 102, "name": "ceiling-other", "supercategory": "ceiling"},
+    {"id": 103, "name": "ceiling-tile", "supercategory": "ceiling"},
+    {"id": 104, "name": "cloth", "supercategory": "textile"},
+    {"id": 105, "name": "clothes", "supercategory": "textile"},
+    {"id": 106, "name": "clouds", "supercategory": "sky"},
+    {"id": 107, "name": "counter", "supercategory": "furniture-stuff"},
+    {"id": 108, "name": "cupboard", "supercategory": "furniture-stuff"},
+    {"id": 109, "name": "curtain", "supercategory": "textile"},
+    {"id": 110, "name": "desk-stuff", "supercategory": "furniture-stuff"},
+    {"id": 111, "name": "dirt", "supercategory": "ground"},
+    {"id": 112, "name": "door-stuff", "supercategory": "furniture-stuff"},
+    {"id": 113, "name": "fence", "supercategory": "structural"},
+    {"id": 114, "name": "floor-marble", "supercategory": "floor"},
+    {"id": 115, "name": "floor-other", "supercategory": "floor"},
+    {"id": 116, "name": "floor-stone", "supercategory": "floor"},
+    {"id": 117, "name": "floor-tile", "supercategory": "floor"},
+    {"id": 118, "name": "floor-wood", "supercategory": "floor"},
+    {"id": 119, "name": "flower", "supercategory": "plant"},
+    {"id": 120, "name": "fog", "supercategory": "water"},
+    {"id": 121, "name": "food-other", "supercategory": "food-stuff"},
+    {"id": 122, "name": "fruit", "supercategory": "food-stuff"},
+    {"id": 123, "name": "furniture-other", "supercategory": "furniture-stuff"},
+    {"id": 124, "name": "grass", "supercategory": "plant"},
+    {"id": 125, "name": "gravel", "supercategory": "ground"},
+    {"id": 126, "name": "ground-other", "supercategory": "ground"},
+    {"id": 127, "name": "hill", "supercategory": "solid"},
+    {"id": 128, "name": "house", "supercategory": "building"},
+    {"id": 129, "name": "leaves", "supercategory": "plant"},
+    {"id": 130, "name": "light", "supercategory": "furniture-stuff"},
+    {"id": 131, "name": "mat", "supercategory": "textile"},
+    {"id": 132, "name": "metal", "supercategory": "raw-material"},
+    {"id": 133, "name": "mirror-stuff", "supercategory": "furniture-stuff"},
+    {"id": 134, "name": "moss", "supercategory": "plant"},
+    {"id": 135, "name": "mountain", "supercategory": "solid"},
+    {"id": 136, "name": "mud", "supercategory": "ground"},
+    {"id": 137, "name": "napkin", "supercategory": "textile"},
+    {"id": 138, "name": "net", "supercategory": "structural"},
+    {"id": 139, "name": "paper", "supercategory": "raw-material"},
+    {"id": 140, "name": "pavement", "supercategory": "ground"},
+    {"id": 141, "name": "pillow", "supercategory": "textile"},
+    {"id": 142, "name": "plant-other", "supercategory": "plant"},
+    {"id": 143, "name": "plastic", "supercategory": "raw-material"},
+    {"id": 144, "name": "platform", "supercategory": "ground"},
+    {"id": 145, "name": "playingfield", "supercategory": "ground"},
+    {"id": 146, "name": "railing", "supercategory": "structural"},
+    {"id": 147, "name": "railroad", "supercategory": "ground"},
+    {"id": 148, "name": "river", "supercategory": "water"},
+    {"id": 149, "name": "road", "supercategory": "ground"},
+    {"id": 150, "name": "rock", "supercategory": "solid"},
+    {"id": 151, "name": "roof", "supercategory": "building"},
+    {"id": 152, "name": "rug", "supercategory": "textile"},
+    {"id": 153, "name": "salad", "supercategory": "food-stuff"},
+    {"id": 154, "name": "sand", "supercategory": "ground"},
+    {"id": 155, "name": "sea", "supercategory": "water"},
+    {"id": 156, "name": "shelf", "supercategory": "furniture-stuff"},
+    {"id": 157, "name": "sky-other", "supercategory": "sky"},
+    {"id": 158, "name": "skyscraper", "supercategory": "building"},
+    {"id": 159, "name": "snow", "supercategory": "ground"},
+    {"id": 160, "name": "solid-other", "supercategory": "solid"},
+    {"id": 161, "name": "stairs", "supercategory": "furniture-stuff"},
+    {"id": 162, "name": "stone", "supercategory": "solid"},
+    {"id": 163, "name": "straw", "supercategory": "plant"},
+    {"id": 164, "name": "structural-other", "supercategory": "structural"},
+    {"id": 165, "name": "table", "supercategory": "furniture-stuff"},
+    {"id": 166, "name": "tent", "supercategory": "building"},
+    {"id": 167, "name": "textile-other", "supercategory": "textile"},
+    {"id": 168, "name": "towel", "supercategory": "textile"},
+    {"id": 169, "name": "tree", "supercategory": "plant"},
+    {"id": 170, "name": "vegetable", "supercategory": "food-stuff"},
+    {"id": 171, "name": "wall-brick", "supercategory": "wall"},
+    {"id": 172, "name": "wall-concrete", "supercategory": "wall"},
+    {"id": 173, "name": "wall-other", "supercategory": "wall"},
+    {"id": 174, "name": "wall-panel", "supercategory": "wall"},
+    {"id": 175, "name": "wall-stone", "supercategory": "wall"},
+    {"id": 176, "name": "wall-tile", "supercategory": "wall"},
+    {"id": 177, "name": "wall-wood", "supercategory": "wall"},
+    {"id": 178, "name": "water-other", "supercategory": "water"},
+    {"id": 179, "name": "waterdrops", "supercategory": "water"},
+    {"id": 180, "name": "window-blind", "supercategory": "window"},
+    {"id": 181, "name": "window-other", "supercategory": "window"},
+    {"id": 182, "name": "wood", "supercategory": "solid"},
+]
+
+def _get_coco_stuff_meta(cat_list):
+    # Id 0 is reserved for ignore_label; our pre-processing remaps the
+    # ignore label from 0 to 255.
+    stuff_ids = [k["id"] for k in cat_list]
+
+    # For semantic segmentation, this maps dataset category ids (used when
+    # processing results) to contiguous ids in [0, len(cat_list)) used in models.
+    stuff_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(stuff_ids)}
+    stuff_classes = [k["name"] for k in cat_list]
+
+    ret = {
+        "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id,
+        "stuff_classes": stuff_classes,
+    }
+    return ret
+
+
+def register_all_coco_stuff_10k(root):
+    root = os.path.join(root, "coco", "coco_stuff_10k")
+    meta = _get_coco_stuff_meta(COCO_CATEGORIES)
+    for name, image_dirname, sem_seg_dirname in [
+        ("train", "images_detectron2/train", "annotations_detectron2/train"),
+    ]:
+        image_dir = os.path.join(root, image_dirname)
+        gt_dir = os.path.join(root, sem_seg_dirname)
+        name = f"coco_2017_{name}_stuff_10k_sem_seg"
+        DatasetCatalog.register(
+            name,
+            lambda x=image_dir, y=gt_dir: load_sem_seg(
+                y, x, gt_ext="png", image_ext="jpg"
+            ),
+        )
+        MetadataCatalog.get(name).set(
+            image_root=image_dir,
+            sem_seg_root=gt_dir,
+            evaluator_type="sem_seg",
+            ignore_label=255,
+            **meta,
+        )
+
+
+def register_all_coco_stuff(root):
+    root = os.path.join(root, "coco")
+    meta = _get_coco_stuff_meta(COCO_CATEGORIES)
+
+    for name, image_dirname, sem_seg_dirname in [
+        ("train", "train2017", "stuffthingmaps_detectron2/train2017"),
+    ]:
+        image_dir = os.path.join(root, image_dirname)
+        gt_dir = os.path.join(root, sem_seg_dirname)
+        all_name = f"coco_2017_{name}_stuff_sem_seg"
+        DatasetCatalog.register(
+            all_name,
+            lambda x=image_dir, y=gt_dir: load_sem_seg(
+                y, x, gt_ext="png", image_ext="jpg"
+            ),
+        )
+        MetadataCatalog.get(all_name).set(
+            image_root=image_dir,
+            sem_seg_root=gt_dir,
+            evaluator_type="sem_seg",
+            ignore_label=255,
+            **meta,
+        )
+
+
+_root = os.getenv("DETECTRON2_DATASETS", "datasets")
+register_all_coco_stuff_10k(_root)
+register_all_coco_stuff(_root)
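+
+# Hedged note (illustration only): both registrations above expect the following layout
+# under $DETECTRON2_DATASETS (the paths follow directly from the os.path.join calls):
+#
+#   datasets/coco/train2017/*.jpg
+#   datasets/coco/stuffthingmaps_detectron2/train2017/*.png
+#   datasets/coco/coco_stuff_10k/images_detectron2/train/*.jpg
+#   datasets/coco/coco_stuff_10k/annotations_detectron2/train/*.png
+#
+# load_sem_seg pairs each image with its mask by matching file names (minus extension).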
diff --git a/open_vocab_seg/data/datasets/register_pascal_context.py b/open_vocab_seg/data/datasets/register_pascal_context.py
new file mode 100644
index 0000000000000000000000000000000000000000..e40f87c945da20e78c0a3ea230bc9f36d1800071
--- /dev/null
+++ b/open_vocab_seg/data/datasets/register_pascal_context.py
@@ -0,0 +1,588 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import os
+
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.data.datasets import load_sem_seg
+
+PASCALCONTEX59_NAMES = (
+    "aeroplane",
+    "bicycle",
+    "bird",
+    "boat",
+    "bottle",
+    "bus",
+    "car",
+    "cat",
+    "chair",
+    "cow",
+    "table",
+    "dog",
+    "horse",
+    "motorbike",
+    "person",
+    "pottedplant",
+    "sheep",
+    "sofa",
+    "train",
+    "tvmonitor",
+    "bag",
+    "bed",
+    "bench",
+    "book",
+    "building",
+    "cabinet",
+    "ceiling",
+    "cloth",
+    "computer",
+    "cup",
+    "door",
+    "fence",
+    "floor",
+    "flower",
+    "food",
+    "grass",
+    "ground",
+    "keyboard",
+    "light",
+    "mountain",
+    "mouse",
+    "curtain",
+    "platform",
+    "sign",
+    "plate",
+    "road",
+    "rock",
+    "shelves",
+    "sidewalk",
+    "sky",
+    "snow",
+    "bedclothes",
+    "track",
+    "tree",
+    "truck",
+    "wall",
+    "water",
+    "window",
+    "wood",
+)
+
+PASCALCONTEX459_NAMES = (
+    "accordion",
+    "aeroplane",
+    "air conditioner",
+    "antenna",
+    "artillery",
+    "ashtray",
+    "atrium",
+    "baby carriage",
+    "bag",
+    "ball",
+    "balloon",
+    "bamboo weaving",
+    "barrel",
+    "baseball bat",
+    "basket",
+    "basketball backboard",
+    "bathtub",
+    "bed",
+    "bedclothes",
+    "beer",
+    "bell",
+    "bench",
+    "bicycle",
+    "binoculars",
+    "bird",
+    "bird cage",
+    "bird feeder",
+    "bird nest",
+    "blackboard",
+    "board",
+    "boat",
+    "bone",
+    "book",
+    "bottle",
+    "bottle opener",
+    "bowl",
+    "box",
+    "bracelet",
+    "brick",
+    "bridge",
+    "broom",
+    "brush",
+    "bucket",
+    "building",
+    "bus",
+    "cabinet",
+    "cabinet door",
+    "cage",
+    "cake",
+    "calculator",
+    "calendar",
+    "camel",
+    "camera",
+    "camera lens",
+    "can",
+    "candle",
+    "candle holder",
+    "cap",
+    "car",
+    "card",
+    "cart",
+    "case",
+    "casette recorder",
+    "cash register",
+    "cat",
+    "cd",
+    "cd player",
+    "ceiling",
+    "cell phone",
+    "cello",
+    "chain",
+    "chair",
+    "chessboard",
+    "chicken",
+    "chopstick",
+    "clip",
+    "clippers",
+    "clock",
+    "closet",
+    "cloth",
+    "clothes tree",
+    "coffee",
+    "coffee machine",
+    "comb",
+    "computer",
+    "concrete",
+    "cone",
+    "container",
+    "control booth",
+    "controller",
+    "cooker",
+    "copying machine",
+    "coral",
+    "cork",
+    "corkscrew",
+    "counter",
+    "court",
+    "cow",
+    "crabstick",
+    "crane",
+    "crate",
+    "cross",
+    "crutch",
+    "cup",
+    "curtain",
+    "cushion",
+    "cutting board",
+    "dais",
+    "disc",
+    "disc case",
+    "dishwasher",
+    "dock",
+    "dog",
+    "dolphin",
+    "door",
+    "drainer",
+    "dray",
+    "drink dispenser",
+    "drinking machine",
+    "drop",
+    "drug",
+    "drum",
+    "drum kit",
+    "duck",
+    "dumbbell",
+    "earphone",
+    "earrings",
+    "egg",
+    "electric fan",
+    "electric iron",
+    "electric pot",
+    "electric saw",
+    "electronic keyboard",
+    "engine",
+    "envelope",
+    "equipment",
+    "escalator",
+    "exhibition booth",
+    "extinguisher",
+    "eyeglass",
+    "fan",
+    "faucet",
+    "fax machine",
+    "fence",
+    "ferris wheel",
+    "fire extinguisher",
+    "fire hydrant",
+    "fire place",
+    "fish",
+    "fish tank",
+    "fishbowl",
+    "fishing net",
+    "fishing pole",
+    "flag",
+    "flagstaff",
+    "flame",
+    "flashlight",
+    "floor",
+    "flower",
+    "fly",
+    "foam",
+    "food",
+    "footbridge",
+    "forceps",
+    "fork",
+    "forklift",
+    "fountain",
+    "fox",
+    "frame",
+    "fridge",
+    "frog",
+    "fruit",
+    "funnel",
+    "furnace",
+    "game controller",
+    "game machine",
+    "gas cylinder",
+    "gas hood",
+    "gas stove",
+    "gift box",
+    "glass",
+    "glass marble",
+    "globe",
+    "glove",
+    "goal",
+    "grandstand",
+    "grass",
+    "gravestone",
+    "ground",
+    "guardrail",
+    "guitar",
+    "gun",
+    "hammer",
+    "hand cart",
+    "handle",
+    "handrail",
+    "hanger",
+    "hard disk drive",
+    "hat",
+    "hay",
+    "headphone",
+    "heater",
+    "helicopter",
+    "helmet",
+    "holder",
+    "hook",
+    "horse",
+    "horse-drawn carriage",
+    "hot-air balloon",
+    "hydrovalve",
+    "ice",
+    "inflator pump",
+    "ipod",
+    "iron",
+    "ironing board",
+    "jar",
+    "kart",
+    "kettle",
+    "key",
+    "keyboard",
+    "kitchen range",
+    "kite",
+    "knife",
+    "knife block",
+    "ladder",
+    "ladder truck",
+    "ladle",
+    "laptop",
+    "leaves",
+    "lid",
+    "life buoy",
+    "light",
+    "light bulb",
+    "lighter",
+    "line",
+    "lion",
+    "lobster",
+    "lock",
+    "machine",
+    "mailbox",
+    "mannequin",
+    "map",
+    "mask",
+    "mat",
+    "match book",
+    "mattress",
+    "menu",
+    "metal",
+    "meter box",
+    "microphone",
+    "microwave",
+    "mirror",
+    "missile",
+    "model",
+    "money",
+    "monkey",
+    "mop",
+    "motorbike",
+    "mountain",
+    "mouse",
+    "mouse pad",
+    "musical instrument",
+    "napkin",
+    "net",
+    "newspaper",
+    "oar",
+    "ornament",
+    "outlet",
+    "oven",
+    "oxygen bottle",
+    "pack",
+    "pan",
+    "paper",
+    "paper box",
+    "paper cutter",
+    "parachute",
+    "parasol",
+    "parterre",
+    "patio",
+    "pelage",
+    "pen",
+    "pen container",
+    "pencil",
+    "person",
+    "photo",
+    "piano",
+    "picture",
+    "pig",
+    "pillar",
+    "pillow",
+    "pipe",
+    "pitcher",
+    "plant",
+    "plastic",
+    "plate",
+    "platform",
+    "player",
+    "playground",
+    "pliers",
+    "plume",
+    "poker",
+    "poker chip",
+    "pole",
+    "pool table",
+    "postcard",
+    "poster",
+    "pot",
+    "pottedplant",
+    "printer",
+    "projector",
+    "pumpkin",
+    "rabbit",
+    "racket",
+    "radiator",
+    "radio",
+    "rail",
+    "rake",
+    "ramp",
+    "range hood",
+    "receiver",
+    "recorder",
+    "recreational machines",
+    "remote control",
+    "road",
+    "robot",
+    "rock",
+    "rocket",
+    "rocking horse",
+    "rope",
+    "rug",
+    "ruler",
+    "runway",
+    "saddle",
+    "sand",
+    "saw",
+    "scale",
+    "scanner",
+    "scissors",
+    "scoop",
+    "screen",
+    "screwdriver",
+    "sculpture",
+    "scythe",
+    "sewer",
+    "sewing machine",
+    "shed",
+    "sheep",
+    "shell",
+    "shelves",
+    "shoe",
+    "shopping cart",
+    "shovel",
+    "sidecar",
+    "sidewalk",
+    "sign",
+    "signal light",
+    "sink",
+    "skateboard",
+    "ski",
+    "sky",
+    "sled",
+    "slippers",
+    "smoke",
+    "snail",
+    "snake",
+    "snow",
+    "snowmobiles",
+    "sofa",
+    "spanner",
+    "spatula",
+    "speaker",
+    "speed bump",
+    "spice container",
+    "spoon",
+    "sprayer",
+    "squirrel",
+    "stage",
+    "stair",
+    "stapler",
+    "stick",
+    "sticky note",
+    "stone",
+    "stool",
+    "stove",
+    "straw",
+    "stretcher",
+    "sun",
+    "sunglass",
+    "sunshade",
+    "surveillance camera",
+    "swan",
+    "sweeper",
+    "swim ring",
+    "swimming pool",
+    "swing",
+    "switch",
+    "table",
+    "tableware",
+    "tank",
+    "tap",
+    "tape",
+    "tarp",
+    "telephone",
+    "telephone booth",
+    "tent",
+    "tire",
+    "toaster",
+    "toilet",
+    "tong",
+    "tool",
+    "toothbrush",
+    "towel",
+    "toy",
+    "toy car",
+    "track",
+    "train",
+    "trampoline",
+    "trash bin",
+    "tray",
+    "tree",
+    "tricycle",
+    "tripod",
+    "trophy",
+    "truck",
+    "tube",
+    "turtle",
+    "tvmonitor",
+    "tweezers",
+    "typewriter",
+    "umbrella",
+    "unknown",
+    "vacuum cleaner",
+    "vending machine",
+    "video camera",
+    "video game console",
+    "video player",
+    "video tape",
+    "violin",
+    "wakeboard",
+    "wall",
+    "wallet",
+    "wardrobe",
+    "washing machine",
+    "watch",
+    "water",
+    "water dispenser",
+    "water pipe",
+    "water skate board",
+    "watermelon",
+    "whale",
+    "wharf",
+    "wheel",
+    "wheelchair",
+    "window",
+    "window blinds",
+    "wineglass",
+    "wire",
+    "wood",
+    "wool",
+)
+
+
+def _get_voc_meta(cat_list):
+    ret = {
+        "stuff_classes": cat_list,
+    }
+    return ret
+
+
+def register_pascal_context_59(root):
+    root = os.path.join(root, "VOCdevkit/VOC2010")
+    meta = _get_voc_meta(PASCALCONTEX59_NAMES)
+    for name, image_dirname, sem_seg_dirname in [
+        ("val", "JPEGImages", "annotations_detectron2/pc59_val"),
+    ]:
+        image_dir = os.path.join(root, image_dirname)
+        gt_dir = os.path.join(root, sem_seg_dirname)
+        all_name = f"pascal_context_59_sem_seg_{name}"
+        DatasetCatalog.register(
+            all_name,
+            lambda x=image_dir, y=gt_dir: load_sem_seg(
+                y, x, gt_ext="png", image_ext="jpg"
+            ),
+        )
+        MetadataCatalog.get(all_name).set(
+            image_root=image_dir,
+            sem_seg_root=gt_dir,
+            evaluator_type="sem_seg",
+            ignore_label=255,
+            **meta,
+        )
+
+def register_pascal_context_459(root):
+    root = os.path.join(root, "VOCdevkit/VOC2010")
+    meta = _get_voc_meta(PASCALCONTEX459_NAMES)
+    for name, image_dirname, sem_seg_dirname in [
+        ("val", "JPEGImages", "annotations_detectron2/pc459_val"),
+    ]:
+        image_dir = os.path.join(root, image_dirname)
+        gt_dir = os.path.join(root, sem_seg_dirname)
+        all_name = f"pascal_context_459_sem_seg_{name}"
+        DatasetCatalog.register(
+            all_name,
+            lambda x=image_dir, y=gt_dir: load_sem_seg(
+                y, x, gt_ext="tif", image_ext="jpg"
+            ),
+        )
+        MetadataCatalog.get(all_name).set(
+            image_root=image_dir,
+            sem_seg_root=gt_dir,
+            evaluator_type="sem_seg",
+            ignore_label=65535,  # NOTE: gt is saved in 16-bit TIFF images
+            **meta,
+        )
+
+_root = os.getenv("DETECTRON2_DATASETS", "datasets")
+register_pascal_context_59(_root)
+register_pascal_context_459(_root)
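+
+# Hedged note (illustration only): the two splits registered above differ in label set,
+# annotation directory, and ground-truth encoding: PC-59 uses PNG masks with ignore
+# label 255, while PC-459 uses 16-bit TIFF masks with ignore label 65535. Example lookup:
+#
+#   from detectron2.data import MetadataCatalog
+#   print(len(MetadataCatalog.get("pascal_context_459_sem_seg_val").stuff_classes))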
diff --git a/open_vocab_seg/data/datasets/register_voc_seg.py b/open_vocab_seg/data/datasets/register_voc_seg.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8c2be16f4bb5348de8f1051f3579e02e362488f
--- /dev/null
+++ b/open_vocab_seg/data/datasets/register_voc_seg.py
@@ -0,0 +1,62 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import os
+
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.data.datasets import load_sem_seg
+
+PASCALVOC20_NAMES = (
+    "aeroplane",
+    "bicycle",
+    "bird",
+    "boat",
+    "bottle",
+    "bus",
+    "car",
+    "cat",
+    "chair",
+    "cow",
+    "diningtable",
+    "dog",
+    "horse",
+    "motorbike",
+    "person",
+    "pottedplant",
+    "sheep",
+    "sofa",
+    "train",
+    "tvmonitor",
+)
+
+def _get_voc_meta(cat_list):
+    ret = {
+        "stuff_classes": cat_list,
+    }
+    return ret
+
+
+def register_pascalvoc(root):
+    root = os.path.join(root, "VOCdevkit/VOC2012")
+    meta = _get_voc_meta(PASCALVOC20_NAMES)
+
+    for name, image_dirname, sem_seg_dirname in [
+        ("val", "JPEGImages", "annotations_detectron2/val"),
+    ]:
+        image_dir = os.path.join(root, image_dirname)
+        gt_dir = os.path.join(root, sem_seg_dirname)
+        all_name = f"pascalvoc20_sem_seg_{name}"
+        DatasetCatalog.register(
+            all_name,
+            lambda x=image_dir, y=gt_dir: load_sem_seg(
+                y, x, gt_ext="png", image_ext="jpg"
+            ),
+        )
+        MetadataCatalog.get(all_name).set(
+            image_root=image_dir,
+            sem_seg_root=gt_dir,
+            evaluator_type="sem_seg",
+            ignore_label=255,
+            **meta,
+        )
+
+_root = os.getenv("DETECTRON2_DATASETS", "datasets")
+register_pascalvoc(_root)
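+
+# Hedged usage sketch (illustration only):
+#   from detectron2.data import MetadataCatalog
+#   names = MetadataCatalog.get("pascalvoc20_sem_seg_val").stuff_classes
+#   assert len(names) == 20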
diff --git a/open_vocab_seg/evaluation/__init__.py b/open_vocab_seg/evaluation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9d36d8e9659a1d31471273a6a0f82c2642ea982
--- /dev/null
+++ b/open_vocab_seg/evaluation/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+from .generalized_sem_seg_evaluation import GeneralizedSemSegEvaluator
diff --git a/open_vocab_seg/evaluation/generalized_sem_seg_evaluation.py b/open_vocab_seg/evaluation/generalized_sem_seg_evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce960ae7cbffde4a981be941ed03a8fc7025ed80
--- /dev/null
+++ b/open_vocab_seg/evaluation/generalized_sem_seg_evaluation.py
@@ -0,0 +1,159 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+import itertools
+import json
+import numpy as np
+import os
+from collections import OrderedDict
+import PIL.Image as Image
+import torch
+
+from detectron2.data import DatasetCatalog, MetadataCatalog
+from detectron2.utils.comm import all_gather, is_main_process, synchronize
+from detectron2.utils.file_io import PathManager
+
+from detectron2.evaluation import SemSegEvaluator
+
+
+class GeneralizedSemSegEvaluator(SemSegEvaluator):
+    """
+    Evaluate semantic segmentation metrics.
+    """
+
+    def __init__(
+        self,
+        dataset_name,
+        distributed=True,
+        output_dir=None,
+        *,
+        num_classes=None,
+        ignore_label=None,
+        post_process_func=None,
+    ):
+        super().__init__(
+            dataset_name,
+            distributed=distributed,
+            output_dir=output_dir,
+            num_classes=num_classes,
+            ignore_label=ignore_label,
+        )
+        meta = MetadataCatalog.get(dataset_name)
+        try:
+            self._evaluation_set = meta.evaluation_set
+        except AttributeError:
+            self._evaluation_set = None
+        self.post_process_func = (
+            post_process_func
+            if post_process_func is not None
+            else lambda x, **kwargs: x
+        )
+
+    def process(self, inputs, outputs):
+        """
+        Args:
+            inputs: the inputs to a model.
+                It is a list of dicts. Each dict corresponds to an image and
+                contains keys like "height", "width", "file_name".
+            outputs: the outputs of a model. It is either list of semantic segmentation predictions
+                (Tensor [H, W]) or list of dicts with key "sem_seg" that contains semantic
+                segmentation prediction in the same format.
+        """
+        for input, output in zip(inputs, outputs):
+            output = self.post_process_func(
+                output["sem_seg"], image=np.array(Image.open(input["file_name"]))
+            )
+            output = output.argmax(dim=0).to(self._cpu_device)
+            pred = np.array(output, dtype=int)  # builtin int: np.int was removed in NumPy >= 1.24
+            with PathManager.open(
+                self.input_file_to_gt_file[input["file_name"]], "rb"
+            ) as f:
+                gt = np.array(Image.open(f), dtype=int)
+
+            gt[gt == self._ignore_label] = self._num_classes
+
+            self._conf_matrix += np.bincount(
+                (self._num_classes + 1) * pred.reshape(-1) + gt.reshape(-1),
+                minlength=self._conf_matrix.size,
+            ).reshape(self._conf_matrix.shape)
+
+            self._predictions.extend(self.encode_json_sem_seg(pred, input["file_name"]))
+
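+    # Indexing note on `process` above (sketch, not in the original file): the
+    # confusion matrix is accumulated with one bincount over flattened (pred, gt)
+    # pairs.  With num_classes = 3 the matrix is 4x4 (the extra row/column holds
+    # the ignore label), and a pixel with pred = 1, gt = 2 maps to flat index
+    # (3 + 1) * 1 + 2 = 6, i.e. self._conf_matrix[1, 2] after the reshape.
+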
+    def evaluate(self):
+        """
+        Evaluates standard semantic segmentation metrics (http://cocodataset.org/#stuff-eval):
+
+        * Mean intersection-over-union averaged across classes (mIoU)
+        * Frequency Weighted IoU (fwIoU)
+        * Mean pixel accuracy averaged across classes (mACC)
+        * Pixel Accuracy (pACC)
+        """
+        if self._distributed:
+            synchronize()
+            conf_matrix_list = all_gather(self._conf_matrix)
+            self._predictions = all_gather(self._predictions)
+            self._predictions = list(itertools.chain(*self._predictions))
+            if not is_main_process():
+                return
+
+            self._conf_matrix = np.zeros_like(self._conf_matrix)
+            for conf_matrix in conf_matrix_list:
+                self._conf_matrix += conf_matrix
+
+        if self._output_dir:
+            PathManager.mkdirs(self._output_dir)
+            file_path = os.path.join(self._output_dir, "sem_seg_predictions.json")
+            with PathManager.open(file_path, "w") as f:
+                f.write(json.dumps(self._predictions))
+
+        acc = np.full(self._num_classes, np.nan, dtype=float)
+        iou = np.full(self._num_classes, np.nan, dtype=float)
+        tp = self._conf_matrix.diagonal()[:-1].astype(float)
+        pos_gt = np.sum(self._conf_matrix[:-1, :-1], axis=0).astype(float)
+        class_weights = pos_gt / np.sum(pos_gt)
+        pos_pred = np.sum(self._conf_matrix[:-1, :-1], axis=1).astype(float)
+        acc_valid = pos_gt > 0
+        acc[acc_valid] = tp[acc_valid] / pos_gt[acc_valid]
+        iou_valid = (pos_gt + pos_pred) > 0
+        union = pos_gt + pos_pred - tp
+        iou[acc_valid] = tp[acc_valid] / union[acc_valid]
+        macc = np.sum(acc[acc_valid]) / np.sum(acc_valid)
+        miou = np.sum(iou[acc_valid]) / np.sum(iou_valid)
+        fiou = np.sum(iou[acc_valid] * class_weights[acc_valid])
+        pacc = np.sum(tp) / np.sum(pos_gt)
+
+        res = {}
+        res["mIoU"] = 100 * miou
+        res["fwIoU"] = 100 * fiou
+        for i, name in enumerate(self._class_names):
+            res["IoU-{}".format(name)] = 100 * iou[i]
+        res["mACC"] = 100 * macc
+        res["pACC"] = 100 * pacc
+        for i, name in enumerate(self._class_names):
+            res["ACC-{}".format(name)] = 100 * acc[i]
+        if self._evaluation_set is not None:
+            for set_name, set_inds in self._evaluation_set.items():
+                iou_list = []
+                set_inds = np.array(set_inds, int)
+                mask = np.zeros((len(iou),)).astype(bool)
+                mask[set_inds] = 1
+                miou = np.sum(iou[mask][acc_valid[mask]]) / np.sum(iou_valid[mask])
+                pacc = np.sum(tp[mask]) / np.sum(pos_gt[mask])
+                res["mIoU-{}".format(set_name)] = 100 * miou
+                res["pAcc-{}".format(set_name)] = 100 * pacc
+                iou_list.append(miou)
+                miou = np.sum(iou[~mask][acc_valid[~mask]]) / np.sum(iou_valid[~mask])
+                pacc = np.sum(tp[~mask]) / np.sum(pos_gt[~mask])
+                res["mIoU-un{}".format(set_name)] = 100 * miou
+                res["pAcc-un{}".format(set_name)] = 100 * pacc
+                iou_list.append(miou)
+                res["hIoU-{}".format(set_name)] = (
+                    100 * len(iou_list) / sum([1 / iou for iou in iou_list])
+                )
+        if self._output_dir:
+            file_path = os.path.join(self._output_dir, "sem_seg_evaluation.pth")
+            with PathManager.open(file_path, "wb") as f:
+                torch.save(res, f)
+        results = OrderedDict({"sem_seg": res})
+        self._logger.info(results)
+        return results
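+
+
+# Integration sketch (assumption, not part of the original file): the evaluator is
+# a drop-in replacement for detectron2's stock SemSegEvaluator, e.g.
+#
+#   from detectron2.data import build_detection_test_loader
+#   from detectron2.evaluation import inference_on_dataset
+#   evaluator = GeneralizedSemSegEvaluator("pascalvoc20_sem_seg_val", distributed=False, output_dir="./eval")
+#   results = inference_on_dataset(model, build_detection_test_loader(cfg, "pascalvoc20_sem_seg_val"), evaluator)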
diff --git a/open_vocab_seg/mask_former_model.py b/open_vocab_seg/mask_former_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..3708d65de4695368b1d088abde4bdf4a9fa39b2b
--- /dev/null
+++ b/open_vocab_seg/mask_former_model.py
@@ -0,0 +1,254 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+from typing import Tuple
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from detectron2.config import configurable
+from detectron2.data import MetadataCatalog
+from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, build_sem_seg_head
+from detectron2.modeling.backbone import Backbone
+from detectron2.modeling.postprocessing import sem_seg_postprocess
+from detectron2.structures import ImageList
+
+from .modeling.criterion import SetCriterion
+from .modeling.matcher import HungarianMatcher
+
+
+@META_ARCH_REGISTRY.register()
+class MaskFormer(nn.Module):
+    """
+    Main class for mask classification semantic segmentation architectures.
+    """
+
+    @configurable
+    def __init__(
+        self,
+        *,
+        backbone: Backbone,
+        sem_seg_head: nn.Module,
+        criterion: nn.Module,
+        num_queries: int,
+        panoptic_on: bool,
+        object_mask_threshold: float,
+        overlap_threshold: float,
+        metadata,
+        size_divisibility: int,
+        sem_seg_postprocess_before_inference: bool,
+        pixel_mean: Tuple[float],
+        pixel_std: Tuple[float],
+    ):
+        """
+        Args:
+            backbone: a backbone module, must follow detectron2's backbone interface
+            sem_seg_head: a module that predicts semantic segmentation from backbone features
+            criterion: a module that defines the loss
+            num_queries: int, number of queries
+            panoptic_on: bool, whether to output panoptic segmentation prediction
+            object_mask_threshold: float, threshold to filter query based on classification score
+                for panoptic segmentation inference
+            overlap_threshold: overlap threshold used in general inference for panoptic segmentation
+            metadata: dataset metadata, used to look up `thing` and `stuff` category names
+                for panoptic segmentation inference
+            size_divisibility: Some backbones require the input height and width to be divisible by a
+                specific integer. We can use this to override such requirement.
+            sem_seg_postprocess_before_inference: whether to resize the prediction back
+                to original input size before semantic segmentation inference or after.
+                For high-resolution dataset like Mapillary, resizing predictions before
+                inference will cause OOM error.
+            pixel_mean, pixel_std: list or tuple with #channels element, representing
+                the per-channel mean and std to be used to normalize the input image
+        """
+        super().__init__()
+        self.backbone = backbone
+        self.sem_seg_head = sem_seg_head
+        self.criterion = criterion
+        self.num_queries = num_queries
+        self.overlap_threshold = overlap_threshold
+        self.panoptic_on = panoptic_on
+        self.object_mask_threshold = object_mask_threshold
+        self.metadata = metadata
+        if size_divisibility < 0:
+            # use backbone size_divisibility if not set
+            size_divisibility = self.backbone.size_divisibility
+        self.size_divisibility = size_divisibility
+        self.sem_seg_postprocess_before_inference = sem_seg_postprocess_before_inference
+        self.register_buffer(
+            "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False
+        )
+        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
+
+    @classmethod
+    def from_config(cls, cfg):
+        backbone = build_backbone(cfg)
+        sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape())
+
+        # Loss parameters:
+        deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
+        no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT
+        dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT
+        mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT
+
+        # building criterion
+        matcher = HungarianMatcher(
+            cost_class=1,
+            cost_mask=mask_weight,
+            cost_dice=dice_weight,
+        )
+
+        weight_dict = {"loss_ce": 1, "loss_mask": mask_weight, "loss_dice": dice_weight}
+        if deep_supervision:
+            dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS
+            aux_weight_dict = {}
+            for i in range(dec_layers - 1):
+                aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
+            weight_dict.update(aux_weight_dict)
+
+        losses = ["labels", "masks"]
+
+        criterion = SetCriterion(
+            sem_seg_head.num_classes,
+            matcher=matcher,
+            weight_dict=weight_dict,
+            eos_coef=no_object_weight,
+            losses=losses,
+        )
+
+        return {
+            "backbone": backbone,
+            "sem_seg_head": sem_seg_head,
+            "criterion": criterion,
+            "num_queries": cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES,
+            "panoptic_on": cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON,
+            "object_mask_threshold": cfg.MODEL.MASK_FORMER.TEST.OBJECT_MASK_THRESHOLD,
+            "overlap_threshold": cfg.MODEL.MASK_FORMER.TEST.OVERLAP_THRESHOLD,
+            "metadata": MetadataCatalog.get(cfg.DATASETS.TRAIN[0]),
+            "size_divisibility": cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY,
+            "sem_seg_postprocess_before_inference": (
+                cfg.MODEL.MASK_FORMER.TEST.SEM_SEG_POSTPROCESSING_BEFORE_INFERENCE
+                or cfg.MODEL.MASK_FORMER.TEST.PANOPTIC_ON
+            ),
+            "pixel_mean": cfg.MODEL.PIXEL_MEAN,
+            "pixel_std": cfg.MODEL.PIXEL_STD,
+        }
+
+    @property
+    def device(self):
+        return self.pixel_mean.device
+
+    def forward(self, batched_inputs):
+        """
+        Args:
+            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
+                Each item in the list contains the inputs for one image.
+                For now, each item in the list is a dict that contains:
+                   * "image": Tensor, image in (C, H, W) format.
+                   * "instances": per-region ground truth
+                   * Other information that's included in the original dicts, such as:
+                     "height", "width" (int): the output resolution of the model (may be different
+                     from input resolution), used in inference.
+        Returns:
+            list[dict]:
+                each dict has the results for one image. The dict contains the following keys:
+
+                * "sem_seg":
+                    A Tensor that represents the
+                    per-pixel segmentation predicted by the head.
+                    The prediction has shape KxHxW that represents the logits of
+                    each class for each pixel.
+                * "panoptic_seg":
+                    A tuple that represents the panoptic output
+                    panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
+                    segments_info (list[dict]): Describe each segment in `panoptic_seg`.
+                        Each dict contains keys "id", "category_id", "isthing".
+        """
+        images = [x["image"].to(self.device) for x in batched_inputs]
+        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+        images = ImageList.from_tensors(images, self.size_divisibility)
+
+        features = self.backbone(images.tensor)
+        outputs = self.sem_seg_head(features)
+
+        if self.training:
+            # mask classification target
+            if "instances" in batched_inputs[0]:
+                gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
+                targets = self.prepare_targets(gt_instances, images)
+            else:
+                targets = None
+
+            # bipartite matching-based loss
+            losses = self.criterion(outputs, targets)
+
+            for k in list(losses.keys()):
+                if k in self.criterion.weight_dict:
+                    losses[k] *= self.criterion.weight_dict[k]
+                else:
+                    # remove this loss if not specified in `weight_dict`
+                    losses.pop(k)
+
+            return losses
+        else:
+            mask_cls_results = outputs["pred_logits"]
+            mask_pred_results = outputs["pred_masks"]
+            # upsample masks
+            mask_pred_results = F.interpolate(
+                mask_pred_results,
+                size=(images.tensor.shape[-2], images.tensor.shape[-1]),
+                mode="bilinear",
+                align_corners=False,
+            )
+
+            processed_results = []
+            for mask_cls_result, mask_pred_result, input_per_image, image_size in zip(
+                mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes
+            ):
+                height = input_per_image.get("height", image_size[0])
+                width = input_per_image.get("width", image_size[1])
+
+                if self.sem_seg_postprocess_before_inference:
+                    mask_pred_result = sem_seg_postprocess(
+                        mask_pred_result, image_size, height, width
+                    )
+
+                # semantic segmentation inference
+                r = self.semantic_inference(mask_cls_result, mask_pred_result)
+                if not self.sem_seg_postprocess_before_inference:
+                    r = sem_seg_postprocess(r, image_size, height, width)
+                processed_results.append({"sem_seg": r})
+
+                # panoptic segmentation inference
+                if self.panoptic_on:
+                    panoptic_r = self.panoptic_inference(
+                        mask_cls_result, mask_pred_result
+                    )
+                    processed_results[-1]["panoptic_seg"] = panoptic_r
+
+            return processed_results
+
+    def prepare_targets(self, targets, images):
+        h, w = images.tensor.shape[-2:]
+        new_targets = []
+        for targets_per_image in targets:
+            # pad gt
+            gt_masks = targets_per_image.gt_masks
+            padded_masks = torch.zeros(
+                (gt_masks.shape[0], h, w), dtype=gt_masks.dtype, device=gt_masks.device
+            )
+            padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
+            new_targets.append(
+                {
+                    "labels": targets_per_image.gt_classes,
+                    "masks": padded_masks,
+                }
+            )
+        return new_targets
+
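+    # Shape note on `prepare_targets` above (sketch, not in the original file):
+    # with two images of 480x600 and 512x512 and size_divisibility = 32, ImageList
+    # pads the batch to 512x608, so each image's gt_masks (N_i, h_i, w_i) is
+    # zero-padded to (N_i, 512, 608) before being matched against the predicted masks.
+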
+    def semantic_inference(self, mask_cls, mask_pred):
+        mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
+        mask_pred = mask_pred.sigmoid()
+        semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
+        return semseg
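+
+    # Inference sketch (not part of the original file): mask_cls is (Q, K+1) with a
+    # trailing "no object" column and mask_pred is (Q, H, W); dropping that column
+    # after the softmax and contracting over the query axis yields per-class maps:
+    #
+    #   mask_cls = torch.softmax(torch.randn(100, 172), dim=-1)[..., :-1]  # (100, 171)
+    #   mask_pred = torch.randn(100, 128, 128).sigmoid()                   # (100, 128, 128)
+    #   semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)          # (171, 128, 128)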
diff --git a/open_vocab_seg/modeling/.DS_Store b/open_vocab_seg/modeling/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..22e04d81d5a0756554382506a89d270c4397faa6
Binary files /dev/null and b/open_vocab_seg/modeling/.DS_Store differ
diff --git a/open_vocab_seg/modeling/__init__.py b/open_vocab_seg/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b4dd2628880e93338b39b0b6562b2a5838692b5
--- /dev/null
+++ b/open_vocab_seg/modeling/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+from .backbone.swin import D2SwinTransformer
+from .backbone.clip_resnet import D2ModifiedResNet
+from .heads.mask_former_head import MaskFormerHead
+from .heads.open_vocab_mask_former_head import OpenVocabMaskFormerHead
+from .heads.pixel_decoder import BasePixelDecoder
diff --git a/open_vocab_seg/modeling/backbone/__init__.py b/open_vocab_seg/modeling/backbone/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..49f9003b7a688f5396170dd89c26ef335a2c201f
--- /dev/null
+++ b/open_vocab_seg/modeling/backbone/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
diff --git a/open_vocab_seg/modeling/backbone/clip_resnet.py b/open_vocab_seg/modeling/backbone/clip_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d40d88c1eac79a873a1396f7203b3555c68a364
--- /dev/null
+++ b/open_vocab_seg/modeling/backbone/clip_resnet.py
@@ -0,0 +1,206 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+from collections import OrderedDict
+import torch
+import torch.nn as nn
+from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, dilation=1):
+        super().__init__()
+
+        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+
+        self.conv2 = nn.Conv2d(
+            planes, planes, 3, padding=1 * dilation, bias=False, dilation=dilation
+        )
+        self.bn2 = nn.BatchNorm2d(planes)
+
+        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = None
+        self.stride = stride
+
+        if stride > 1 or inplanes != planes * Bottleneck.expansion:
+            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+            self.downsample = nn.Sequential(
+                OrderedDict(
+                    [
+                        ("-1", nn.AvgPool2d(stride)),
+                        (
+                            "0",
+                            nn.Conv2d(
+                                inplanes,
+                                planes * self.expansion,
+                                1,
+                                stride=1,
+                                bias=False,
+                            ),
+                        ),
+                        ("1", nn.BatchNorm2d(planes * self.expansion)),
+                    ]
+                )
+            )
+
+    def forward(self, x: torch.Tensor):
+        identity = x
+
+        out = self.relu(self.bn1(self.conv1(x)))
+        out = self.relu(self.bn2(self.conv2(out)))
+        out = self.avgpool(out)
+        out = self.bn3(self.conv3(out))
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+        return out
+
+
+class ModifiedResNet(nn.Module):
+    """
+    A ResNet class that is similar to torchvision's but contains the following changes:
+    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+    - The final pooling layer is a QKV attention instead of an average pool
+    """
+
+    def __init__(self, layers, width=64, strides=[2, 1, 2, 2, 2], multi_grid=[1, 1, 1]):
+        super().__init__()
+
+        # the 3-layer stem
+        self.conv1 = nn.Conv2d(
+            3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
+        )
+        self.bn1 = nn.BatchNorm2d(width // 2)
+        self.conv2 = nn.Conv2d(
+            width // 2, width // 2, kernel_size=3, padding=1, bias=False
+        )
+        self.bn2 = nn.BatchNorm2d(width // 2)
+        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(width)
+        self.avgpool = nn.AvgPool2d(strides[0]) if strides[0] > 1 else nn.Identity()
+        self.relu = nn.ReLU(inplace=True)
+
+        # residual layers
+        self._inplanes = width  # this is a *mutable* variable used during construction
+        self.layer1 = self._make_layer(width, layers[0], stride=strides[1])
+        self.layer2 = self._make_layer(width * 2, layers[1], stride=strides[2])
+        self.layer3 = self._make_layer(width * 4, layers[2], stride=strides[3])
+        self.layer4 = self._make_layer(
+            width * 8, layers[3], stride=strides[4], dilations=multi_grid
+        )
+        self.num_features = [width * 4, width * 8, width * 16, width * 32]
+
+    def _make_layer(self, planes, blocks, stride=1, dilations=None):
+        if dilations is None:
+            dilations = [1] * blocks
+        layers = [Bottleneck(self._inplanes, planes, stride, dilation=dilations[0])]
+        self._inplanes = planes * Bottleneck.expansion
+
+        for i in range(1, blocks):
+            layers.append(Bottleneck(self._inplanes, planes, dilation=dilations[i]))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        def stem(x):
+            for conv, bn in [
+                (self.conv1, self.bn1),
+                (self.conv2, self.bn2),
+                (self.conv3, self.bn3),
+            ]:
+                x = self.relu(bn(conv(x)))
+            x = self.avgpool(x)
+            return x
+
+        output = {}
+        x = x.type(self.conv1.weight.dtype)
+        x = stem(x)  # 1/4,1/4
+        x = self.layer1(x)
+        output["res2"] = x
+        x = self.layer2(x)  # 1/8,1/8
+        output["res3"] = x
+        x = self.layer3(x)  # 1/16,1/16
+        output["res4"] = x
+        x = self.layer4(x)  # 1/32,1/32
+        output["res5"] = x
+        return output
+
+
+@BACKBONE_REGISTRY.register()
+class D2ModifiedResNet(ModifiedResNet, Backbone):
+    def __init__(self, cfg, input_shape):
+        depth = cfg.MODEL.RESNETS.DEPTH
+        num_groups = cfg.MODEL.RESNETS.NUM_GROUPS
+        width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP
+        bottleneck_channels = num_groups * width_per_group
+        num_blocks_per_stage = {
+            18: [2, 2, 2, 2],
+            34: [3, 4, 6, 3],
+            50: [3, 4, 6, 3],
+            101: [3, 4, 23, 3],
+            152: [3, 8, 36, 3],
+        }[depth]
+        strides = [2, 1, 2, 2, 2]
+        multi_grid = cfg.MODEL.RESNETS.RES5_MULTI_GRID
+        if cfg.MODEL.RESNETS.STEM_TYPE == "deeplab":
+            strides = [1, 1, 2, 2, 2]
+        super().__init__(
+            num_blocks_per_stage,
+            bottleneck_channels,
+            strides=strides,
+            multi_grid=multi_grid,
+        )
+        self._out_features = cfg.MODEL.RESNETS.OUT_FEATURES
+
+        self._out_feature_strides = {
+            "res2": 4,
+            "res3": 8,
+            "res4": 16,
+            "res5": 32,
+        }
+        self._out_feature_channels = {
+            "res2": self.num_features[0],
+            "res3": self.num_features[1],
+            "res4": self.num_features[2],
+            "res5": self.num_features[3],
+        }
+
+    def forward(self, x):
+        """
+        Args:
+            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
+        Returns:
+            dict[str->Tensor]: names and the corresponding features
+        """
+        outputs = {}
+        y = super().forward(x)
+        for k in y.keys():
+            if k in self._out_features:
+                outputs[k] = y[k]
+        return outputs
+
+    def output_shape(self):
+        return {
+            name: ShapeSpec(
+                channels=self._out_feature_channels[name],
+                stride=self._out_feature_strides[name],
+            )
+            for name in self._out_features
+        }
+
+    @property
+    def size_divisibility(self):
+        return 32
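+
+
+if __name__ == "__main__":
+    # Smoke-test sketch (assumption, not part of the original file): a randomly
+    # initialised ModifiedResNet-50 on a 224x224 input should emit res2..res5 at
+    # strides 4/8/16/32 with 256/512/1024/2048 channels.
+    model = ModifiedResNet(layers=[3, 4, 6, 3], width=64)
+    feats = model(torch.randn(1, 3, 224, 224))
+    for name, feat in feats.items():
+        print(name, tuple(feat.shape))  # res2 -> (1, 256, 56, 56) ... res5 -> (1, 2048, 7, 7)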
diff --git a/open_vocab_seg/modeling/backbone/swin.py b/open_vocab_seg/modeling/backbone/swin.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa651bdab51bb353e3be4b5554f41e251803d5cb
--- /dev/null
+++ b/open_vocab_seg/modeling/backbone/swin.py
@@ -0,0 +1,832 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu, Yutong Lin, Yixuan Wei
+# --------------------------------------------------------
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation/blob/main/mmseg/models/backbones/swin_transformer.py
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
+
+
+class Mlp(nn.Module):
+    """Multilayer perceptron."""
+
+    def __init__(
+        self,
+        in_features,
+        hidden_features=None,
+        out_features=None,
+        act_layer=nn.GELU,
+        drop=0.0,
+    ):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+def window_partition(x, window_size):
+    """
+    Args:
+        x: (B, H, W, C)
+        window_size (int): window size
+    Returns:
+        windows: (num_windows*B, window_size, window_size, C)
+    """
+    B, H, W, C = x.shape
+    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+    windows = (
+        x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    )
+    return windows
+
+
+def window_reverse(windows, window_size, H, W):
+    """
+    Args:
+        windows: (num_windows*B, window_size, window_size, C)
+        window_size (int): Window size
+        H (int): Height of image
+        W (int): Width of image
+    Returns:
+        x: (B, H, W, C)
+    """
+    B = int(windows.shape[0] / (H * W / window_size / window_size))
+    x = windows.view(
+        B, H // window_size, W // window_size, window_size, window_size, -1
+    )
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+    return x
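+
+# Round-trip sketch (not part of the original file): for x of shape (2, 14, 14, 96)
+# and window_size = 7, window_partition(x, 7) returns (2 * 2 * 2, 7, 7, 96) = (8, 7, 7, 96)
+# windows, and window_reverse(windows, 7, 14, 14) recovers the original (2, 14, 14, 96).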
+
+
+class WindowAttention(nn.Module):
+    """Window based multi-head self attention (W-MSA) module with relative position bias.
+    It supports both of shifted and non-shifted window.
+    Args:
+        dim (int): Number of input channels.
+        window_size (tuple[int]): The height and width of the window.
+        num_heads (int): Number of attention heads.
+        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+    """
+
+    def __init__(
+        self,
+        dim,
+        window_size,
+        num_heads,
+        qkv_bias=True,
+        qk_scale=None,
+        attn_drop=0.0,
+        proj_drop=0.0,
+    ):
+
+        super().__init__()
+        self.dim = dim
+        self.window_size = window_size  # Wh, Ww
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+
+        # define a parameter table of relative position bias
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
+        )  # 2*Wh-1 * 2*Ww-1, nH
+
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(self.window_size[0])
+        coords_w = torch.arange(self.window_size[1])
+        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+        relative_coords = (
+            coords_flatten[:, :, None] - coords_flatten[:, None, :]
+        )  # 2, Wh*Ww, Wh*Ww
+        relative_coords = relative_coords.permute(
+            1, 2, 0
+        ).contiguous()  # Wh*Ww, Wh*Ww, 2
+        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        self.register_buffer("relative_position_index", relative_position_index)
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        trunc_normal_(self.relative_position_bias_table, std=0.02)
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, x, mask=None):
+        """Forward function.
+        Args:
+            x: input features with shape of (num_windows*B, N, C)
+            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+        """
+        B_, N, C = x.shape
+        qkv = (
+            self.qkv(x)
+            .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
+            .permute(2, 0, 3, 1, 4)
+        )
+        q, k, v = (
+            qkv[0],
+            qkv[1],
+            qkv[2],
+        )  # make torchscript happy (cannot use tensor as tuple)
+
+        q = q * self.scale
+        attn = q @ k.transpose(-2, -1)
+
+        relative_position_bias = self.relative_position_bias_table[
+            self.relative_position_index.view(-1)
+        ].view(
+            self.window_size[0] * self.window_size[1],
+            self.window_size[0] * self.window_size[1],
+            -1,
+        )  # Wh*Ww,Wh*Ww,nH
+        relative_position_bias = relative_position_bias.permute(
+            2, 0, 1
+        ).contiguous()  # nH, Wh*Ww, Wh*Ww
+        attn = attn + relative_position_bias.unsqueeze(0)
+
+        if mask is not None:
+            nW = mask.shape[0]
+            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
+                1
+            ).unsqueeze(0)
+            attn = attn.view(-1, self.num_heads, N, N)
+            attn = self.softmax(attn)
+        else:
+            attn = self.softmax(attn)
+
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class SwinTransformerBlock(nn.Module):
+    """Swin Transformer Block.
+    Args:
+        dim (int): Number of input channels.
+        num_heads (int): Number of attention heads.
+        window_size (int): Window size.
+        shift_size (int): Shift size for SW-MSA.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float, optional): Stochastic depth rate. Default: 0.0
+        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
+    """
+
+    def __init__(
+        self,
+        dim,
+        num_heads,
+        window_size=7,
+        shift_size=0,
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        qk_scale=None,
+        drop=0.0,
+        attn_drop=0.0,
+        drop_path=0.0,
+        act_layer=nn.GELU,
+        norm_layer=nn.LayerNorm,
+    ):
+        super().__init__()
+        self.dim = dim
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.shift_size = shift_size
+        self.mlp_ratio = mlp_ratio
+        assert (
+            0 <= self.shift_size < self.window_size
+        ), "shift_size must in 0-window_size"
+
+        self.norm1 = norm_layer(dim)
+        self.attn = WindowAttention(
+            dim,
+            window_size=to_2tuple(self.window_size),
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+        )
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(
+            in_features=dim,
+            hidden_features=mlp_hidden_dim,
+            act_layer=act_layer,
+            drop=drop,
+        )
+
+        self.H = None
+        self.W = None
+
+    def forward(self, x, mask_matrix):
+        """Forward function.
+        Args:
+            x: Input feature, tensor size (B, H*W, C).
+            H, W: Spatial resolution of the input feature.
+            mask_matrix: Attention mask for cyclic shift.
+        """
+        B, L, C = x.shape
+        H, W = self.H, self.W
+        assert L == H * W, "input feature has wrong size"
+
+        shortcut = x
+        x = self.norm1(x)
+        x = x.view(B, H, W, C)
+
+        # pad feature maps to multiples of window size
+        pad_l = pad_t = 0
+        pad_r = (self.window_size - W % self.window_size) % self.window_size
+        pad_b = (self.window_size - H % self.window_size) % self.window_size
+        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+        _, Hp, Wp, _ = x.shape
+
+        # cyclic shift
+        if self.shift_size > 0:
+            shifted_x = torch.roll(
+                x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
+            )
+            attn_mask = mask_matrix
+        else:
+            shifted_x = x
+            attn_mask = None
+
+        # partition windows
+        x_windows = window_partition(
+            shifted_x, self.window_size
+        )  # nW*B, window_size, window_size, C
+        x_windows = x_windows.view(
+            -1, self.window_size * self.window_size, C
+        )  # nW*B, window_size*window_size, C
+
+        # W-MSA/SW-MSA
+        attn_windows = self.attn(
+            x_windows, mask=attn_mask
+        )  # nW*B, window_size*window_size, C
+
+        # merge windows
+        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+        shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp)  # B H' W' C
+
+        # reverse cyclic shift
+        if self.shift_size > 0:
+            x = torch.roll(
+                shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
+            )
+        else:
+            x = shifted_x
+
+        if pad_r > 0 or pad_b > 0:
+            x = x[:, :H, :W, :].contiguous()
+
+        x = x.view(B, H * W, C)
+
+        # FFN
+        x = shortcut + self.drop_path(x)
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+        return x
+
+
+class PatchMerging(nn.Module):
+    """Patch Merging Layer
+    Args:
+        dim (int): Number of input channels.
+        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
+    """
+
+    def __init__(self, dim, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.dim = dim
+        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+        self.norm = norm_layer(4 * dim)
+
+    def forward(self, x, H, W):
+        """Forward function.
+        Args:
+            x: Input feature, tensor size (B, H*W, C).
+            H, W: Spatial resolution of the input feature.
+        """
+        B, L, C = x.shape
+        assert L == H * W, "input feature has wrong size"
+
+        x = x.view(B, H, W, C)
+
+        # padding
+        pad_input = (H % 2 == 1) or (W % 2 == 1)
+        if pad_input:
+            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
+
+        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
+        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
+        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
+        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
+        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
+        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C
+
+        x = self.norm(x)
+        x = self.reduction(x)
+
+        return x
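+
+    # Shape note (sketch, not in the original file): merging a 56x56 grid of 96-dim
+    # tokens, i.e. x of shape (B, 3136, 96) with H = W = 56, concatenates each 2x2
+    # neighbourhood (4 * 96 = 384 features) and projects it to 192, returning
+    # (B, 784, 192) for the next stage.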
+
+
+class BasicLayer(nn.Module):
+    """A basic Swin Transformer layer for one stage.
+    Args:
+        dim (int): Number of feature channels
+        depth (int): Depths of this stage.
+        num_heads (int): Number of attention head.
+        window_size (int): Local window size. Default: 7.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+        drop (float, optional): Dropout rate. Default: 0.0
+        attn_drop (float, optional): Attention dropout rate. Default: 0.0
+        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+    """
+
+    def __init__(
+        self,
+        dim,
+        depth,
+        num_heads,
+        window_size=7,
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        qk_scale=None,
+        drop=0.0,
+        attn_drop=0.0,
+        drop_path=0.0,
+        norm_layer=nn.LayerNorm,
+        downsample=None,
+        use_checkpoint=False,
+    ):
+        super().__init__()
+        self.window_size = window_size
+        self.shift_size = window_size // 2
+        self.depth = depth
+        self.use_checkpoint = use_checkpoint
+
+        # build blocks
+        self.blocks = nn.ModuleList(
+            [
+                SwinTransformerBlock(
+                    dim=dim,
+                    num_heads=num_heads,
+                    window_size=window_size,
+                    shift_size=0 if (i % 2 == 0) else window_size // 2,
+                    mlp_ratio=mlp_ratio,
+                    qkv_bias=qkv_bias,
+                    qk_scale=qk_scale,
+                    drop=drop,
+                    attn_drop=attn_drop,
+                    drop_path=drop_path[i]
+                    if isinstance(drop_path, list)
+                    else drop_path,
+                    norm_layer=norm_layer,
+                )
+                for i in range(depth)
+            ]
+        )
+
+        # patch merging layer
+        if downsample is not None:
+            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
+        else:
+            self.downsample = None
+
+    def forward(self, x, H, W):
+        """Forward function.
+        Args:
+            x: Input feature, tensor size (B, H*W, C).
+            H, W: Spatial resolution of the input feature.
+        """
+
+        # calculate attention mask for SW-MSA
+        Hp = int(np.ceil(H / self.window_size)) * self.window_size
+        Wp = int(np.ceil(W / self.window_size)) * self.window_size
+        img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
+        h_slices = (
+            slice(0, -self.window_size),
+            slice(-self.window_size, -self.shift_size),
+            slice(-self.shift_size, None),
+        )
+        w_slices = (
+            slice(0, -self.window_size),
+            slice(-self.window_size, -self.shift_size),
+            slice(-self.shift_size, None),
+        )
+        cnt = 0
+        for h in h_slices:
+            for w in w_slices:
+                img_mask[:, h, w, :] = cnt
+                cnt += 1
+
+        mask_windows = window_partition(
+            img_mask, self.window_size
+        )  # nW, window_size, window_size, 1
+        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
+            attn_mask == 0, float(0.0)
+        )
+
+        for blk in self.blocks:
+            blk.H, blk.W = H, W
+            if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x, attn_mask)
+            else:
+                x = blk(x, attn_mask)
+        if self.downsample is not None:
+            x_down = self.downsample(x, H, W)
+            Wh, Ww = (H + 1) // 2, (W + 1) // 2
+            return x, H, W, x_down, Wh, Ww
+        else:
+            return x, H, W, x, H, W
+
+
+class PatchEmbed(nn.Module):
+    """Image to Patch Embedding
+    Args:
+        patch_size (int): Patch token size. Default: 4.
+        in_chans (int): Number of input image channels. Default: 3.
+        embed_dim (int): Number of linear projection output channels. Default: 96.
+        norm_layer (nn.Module, optional): Normalization layer. Default: None
+    """
+
+    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+        super().__init__()
+        patch_size = to_2tuple(patch_size)
+        self.patch_size = patch_size
+
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+
+        self.proj = nn.Conv2d(
+            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
+        )
+        if norm_layer is not None:
+            self.norm = norm_layer(embed_dim)
+        else:
+            self.norm = None
+
+    def forward(self, x):
+        """Forward function."""
+        # padding
+        _, _, H, W = x.size()
+        if W % self.patch_size[1] != 0:
+            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+        if H % self.patch_size[0] != 0:
+            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+
+        x = self.proj(x)  # B C Wh Ww
+        if self.norm is not None:
+            Wh, Ww = x.size(2), x.size(3)
+            x = x.flatten(2).transpose(1, 2)
+            x = self.norm(x)
+            x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
+
+        return x
+
+
+class SwinTransformer(nn.Module):
+    """Swin Transformer backbone.
+        A PyTorch implementation of `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`:
+          https://arxiv.org/pdf/2103.14030
+    Args:
+        pretrain_img_size (int): Input image size for training the pretrained model,
+            used in absolute position embedding. Default 224.
+        patch_size (int | tuple(int)): Patch size. Default: 4.
+        in_chans (int): Number of input image channels. Default: 3.
+        embed_dim (int): Number of linear projection output channels. Default: 96.
+        depths (tuple[int]): Depths of each Swin Transformer stage.
+        num_heads (tuple[int]): Number of attention head of each stage.
+        window_size (int): Window size. Default: 7.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+        drop_rate (float): Dropout rate.
+        attn_drop_rate (float): Attention dropout rate. Default: 0.
+        drop_path_rate (float): Stochastic depth rate. Default: 0.2.
+        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
+        patch_norm (bool): If True, add normalization after patch embedding. Default: True.
+        out_indices (Sequence[int]): Output from which stages.
+        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+            -1 means not freezing any parameters.
+        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+    """
+
+    def __init__(
+        self,
+        pretrain_img_size=224,
+        patch_size=4,
+        in_chans=3,
+        embed_dim=96,
+        depths=[2, 2, 6, 2],
+        num_heads=[3, 6, 12, 24],
+        window_size=7,
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        qk_scale=None,
+        drop_rate=0.0,
+        attn_drop_rate=0.0,
+        drop_path_rate=0.2,
+        norm_layer=nn.LayerNorm,
+        ape=False,
+        patch_norm=True,
+        out_indices=(0, 1, 2, 3),
+        norm_indices=None,
+        frozen_stages=-1,
+        use_checkpoint=False,
+        projection=False,
+        project_dim=256,
+    ):
+        super().__init__()
+
+        self.pretrain_img_size = pretrain_img_size
+        self.num_layers = len(depths)
+        self.embed_dim = embed_dim
+        self.ape = ape
+        self.patch_norm = patch_norm
+        self.out_indices = out_indices
+        self.norm_indices = norm_indices if norm_indices is not None else out_indices
+        self.frozen_stages = frozen_stages
+
+        # split image into non-overlapping patches
+        self.patch_embed = PatchEmbed(
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+            norm_layer=norm_layer if self.patch_norm else None,
+        )
+
+        # absolute position embedding
+        if self.ape:
+            pretrain_img_size = to_2tuple(pretrain_img_size)
+            patch_size = to_2tuple(patch_size)
+            patches_resolution = [
+                pretrain_img_size[0] // patch_size[0],
+                pretrain_img_size[1] // patch_size[1],
+            ]
+
+            self.absolute_pos_embed = nn.Parameter(
+                torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])
+            )
+            trunc_normal_(self.absolute_pos_embed, std=0.02)
+
+        self.pos_drop = nn.Dropout(p=drop_rate)
+
+        # stochastic depth
+        dpr = [
+            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
+        ]  # stochastic depth decay rule
+
+        # build layers
+        self.layers = nn.ModuleList()
+        for i_layer in range(self.num_layers):
+            layer = BasicLayer(
+                dim=int(embed_dim * 2 ** i_layer),
+                depth=depths[i_layer],
+                num_heads=num_heads[i_layer],
+                window_size=window_size,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                qk_scale=qk_scale,
+                drop=drop_rate,
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
+                norm_layer=norm_layer,
+                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+                use_checkpoint=use_checkpoint,
+            )
+            self.layers.append(layer)
+
+        num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
+        self.num_features = num_features
+
+        # add a norm layer for each output
+        for i_layer in self.norm_indices:
+            if i_layer >= len(self.num_features):
+                continue
+            layer = norm_layer(num_features[i_layer])
+            layer_name = f"norm{i_layer}"
+            self.add_module(layer_name, layer)
+        # add projector head
+        self.projection = projection
+        if projection:
+            self.project_dim = project_dim
+            self.norm = norm_layer(self.num_features[-1])
+            self.projector = nn.Linear(self.num_features[-1], project_dim, bias=False)
+        self._freeze_stages()
+
+    def _freeze_stages(self):
+        if self.frozen_stages >= 0:
+            self.patch_embed.eval()
+            for param in self.patch_embed.parameters():
+                param.requires_grad = False
+
+        if self.frozen_stages >= 1 and self.ape:
+            self.absolute_pos_embed.requires_grad = False
+
+        if self.frozen_stages >= 2:
+            self.pos_drop.eval()
+            for i in range(0, self.frozen_stages - 1):
+                m = self.layers[i]
+                m.eval()
+                for param in m.parameters():
+                    param.requires_grad = False
+
+    def init_weights(self, pretrained=None):
+        """Initialize the weights in backbone.
+        Args:
+            pretrained (str, optional): Path to pre-trained weights.
+                Defaults to None.
+        """
+
+        def _init_weights(m):
+            if isinstance(m, nn.Linear):
+                trunc_normal_(m.weight, std=0.02)
+                if isinstance(m, nn.Linear) and m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+            elif isinstance(m, nn.LayerNorm):
+                nn.init.constant_(m.bias, 0)
+                nn.init.constant_(m.weight, 1.0)
+
+        # actually apply the default initialization defined above
+        self.apply(_init_weights)
+
+    def forward(self, x):
+        """Forward function."""
+        x = self.patch_embed(x)
+
+        Wh, Ww = x.size(2), x.size(3)
+        if self.ape:
+            # interpolate the position embedding to the corresponding size
+            absolute_pos_embed = F.interpolate(
+                self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic"
+            )
+            x = (x + absolute_pos_embed).flatten(2).transpose(1, 2)  # B Wh*Ww C
+        else:
+            x = x.flatten(2).transpose(1, 2)
+        x = self.pos_drop(x)
+
+        outs = {}
+        for i in range(self.num_layers):
+            layer = self.layers[i]
+            x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
+
+            if i in self.out_indices:
+                if i in self.norm_indices:
+                    norm_layer = getattr(self, f"norm{i}")
+                    x_out = norm_layer(x_out)
+                out = (
+                    x_out.view(-1, H, W, self.num_features[i])
+                    .permute(0, 3, 1, 2)
+                    .contiguous()
+                )
+                outs["res{}".format(i + 2)] = out
+        if self.projection:
+            x_out = self.norm(x_out)
+            x_out = x_out.view(-1, H, W, self.num_features[-1]).contiguous()
+            outs["fc"] = self.projector(x_out).permute(0, 3, 1, 2)
+
+        return outs
+
+    def train(self, mode=True):
+        """Convert the model into training mode while keep layers freezed."""
+        super(SwinTransformer, self).train(mode)
+        self._freeze_stages()
+
+
+@BACKBONE_REGISTRY.register()
+class D2SwinTransformer(SwinTransformer, Backbone):
+    def __init__(self, cfg, input_shape):
+
+        pretrain_img_size = cfg.MODEL.SWIN.PRETRAIN_IMG_SIZE
+        patch_size = cfg.MODEL.SWIN.PATCH_SIZE
+        in_chans = 3
+        embed_dim = cfg.MODEL.SWIN.EMBED_DIM
+        depths = cfg.MODEL.SWIN.DEPTHS
+        num_heads = cfg.MODEL.SWIN.NUM_HEADS
+        window_size = cfg.MODEL.SWIN.WINDOW_SIZE
+        mlp_ratio = cfg.MODEL.SWIN.MLP_RATIO
+        qkv_bias = cfg.MODEL.SWIN.QKV_BIAS
+        qk_scale = cfg.MODEL.SWIN.QK_SCALE
+        drop_rate = cfg.MODEL.SWIN.DROP_RATE
+        attn_drop_rate = cfg.MODEL.SWIN.ATTN_DROP_RATE
+        drop_path_rate = cfg.MODEL.SWIN.DROP_PATH_RATE
+        norm_layer = nn.LayerNorm
+        ape = cfg.MODEL.SWIN.APE
+        patch_norm = cfg.MODEL.SWIN.PATCH_NORM
+        norm_indices = cfg.MODEL.SWIN.NORM_INDICES
+        projection = cfg.MODEL.SWIN.PROJECTION
+        project_dim = cfg.MODEL.SWIN.PROJECT_DIM
+        super().__init__(
+            pretrain_img_size,
+            patch_size,
+            in_chans,
+            embed_dim,
+            depths,
+            num_heads,
+            window_size,
+            mlp_ratio,
+            qkv_bias,
+            qk_scale,
+            drop_rate,
+            attn_drop_rate,
+            drop_path_rate,
+            norm_layer,
+            ape,
+            patch_norm,
+            norm_indices=norm_indices,
+            projection=projection,
+            project_dim=project_dim,
+        )
+
+        self._out_features = cfg.MODEL.SWIN.OUT_FEATURES
+
+        self._out_feature_strides = {
+            "res2": 4,
+            "res3": 8,
+            "res4": 16,
+            "res5": 32,
+            "fc": 32,
+        }
+        self._out_feature_channels = {
+            "res2": self.num_features[0],
+            "res3": self.num_features[1],
+            "res4": self.num_features[2],
+            "res5": self.num_features[3],
+            "fc": self.num_features[3],
+        }
+
+    def forward(self, x):
+        """
+        Args:
+            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
+        Returns:
+            dict[str->Tensor]: names and the corresponding features
+        """
+        assert (
+            x.dim() == 4
+        ), f"SwinTransformer takes an input of shape (N, C, H, W). Got {x.shape} instead!"
+        outputs = {}
+        y = super().forward(x)
+        for k in y.keys():
+            if k in self._out_features:
+                outputs[k] = y[k]
+        return outputs
+
+    def output_shape(self):
+        return {
+            name: ShapeSpec(
+                channels=self._out_feature_channels[name],
+                stride=self._out_feature_strides[name],
+            )
+            for name in self._out_features
+        }
+
+    @property
+    def size_divisibility(self):
+        return 32
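+
+
+if __name__ == "__main__":
+    # Smoke-test sketch (assumption, not part of the original file): a randomly
+    # initialised Swin-T (the constructor defaults above) on a 224x224 input yields
+    # res2..res5 at strides 4/8/16/32 with 96/192/384/768 channels.
+    model = SwinTransformer()
+    outs = model(torch.randn(1, 3, 224, 224))
+    for name, feat in outs.items():
+        print(name, tuple(feat.shape))  # res2 -> (1, 96, 56, 56) ... res5 -> (1, 768, 7, 7)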
diff --git a/open_vocab_seg/modeling/clip_adapter/__init__.py b/open_vocab_seg/modeling/clip_adapter/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d925b068328373b26352f9a82895d197b47455c
--- /dev/null
+++ b/open_vocab_seg/modeling/clip_adapter/__init__.py
@@ -0,0 +1,25 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+from .text_template import (
+    PredefinedPromptExtractor,
+    ImageNetPromptExtractor,
+    VILDPromptExtractor,
+)
+from .adapter import ClipAdapter, MaskFormerClipAdapter
+
+
+def build_text_prompt(cfg):
+    if cfg.TEXT_TEMPLATES == "predefined":
+        text_templates = PredefinedPromptExtractor(cfg.PREDEFINED_PROMPT_TEMPLATES)
+    elif cfg.TEXT_TEMPLATES == "imagenet":
+        text_templates = ImageNetPromptExtractor()
+    elif cfg.TEXT_TEMPLATES == "vild":
+        text_templates = VILDPromptExtractor()
+    else:
+        raise NotImplementedError(
+            "Prompt learner {} is not supported".format(cfg.TEXT_TEMPLATES)
+        )
+    return text_templates
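+
+# Illustrative usage (assuming `cfg` here is the CLIP adapter sub-config, e.g. cfg.MODEL.CLIP_ADAPTER):
+#   text_templates = build_text_prompt(cfg.MODEL.CLIP_ADAPTER)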
+
+from .clip import tokenize
\ No newline at end of file
diff --git a/open_vocab_seg/modeling/clip_adapter/adapter.py b/open_vocab_seg/modeling/clip_adapter/adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..864d20b160714865b4130fab8714f323aaae2572
--- /dev/null
+++ b/open_vocab_seg/modeling/clip_adapter/adapter.py
@@ -0,0 +1,206 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+# Modified by Feng Liang from
+# https://github.com/MendelXu/zsseg.baseline/blob/master/mask_former/modeling/clip_adapter/adapter.py
+
+from typing import List
+import torch
+from torch import nn
+from torch.nn import functional as F
+from detectron2.structures import BitMasks
+from .utils import build_clip_model, crop_with_mask
+from .text_template import PromptExtractor
+
+
+PIXEL_MEAN = (0.48145466, 0.4578275, 0.40821073)
+PIXEL_STD = (0.26862954, 0.26130258, 0.27577711)
+
+
+class ClipAdapter(nn.Module):
+    def __init__(self, clip_model_name: str, mask_prompt_depth: int, text_templates: PromptExtractor):
+        super().__init__()
+        self.clip_model = build_clip_model(clip_model_name, mask_prompt_depth)
+        self.text_templates = text_templates
+        self.text_templates.init_buffer(self.clip_model)
+        self.text_feature_buffer = {}
+
+    def forward(self, image: torch.Tensor, text: List[str], **kwargs):
+        image = self._preprocess_image(image, **kwargs)
+        text_feature = self.get_text_features(text)  # k,feat_dim
+        image_features = self.get_image_features(image)
+        return self.get_sim_logits(text_feature, image_features)
+
+    def _preprocess_image(self, image: torch.Tensor):
+        return image
+
+    def _get_text_features(self, noun_list: List[str]):
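+        # Encode only class names that have not been seen before and cache their
+        # prompt-ensembled, L2-normalized text features for reuse across calls.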
+        left_noun_list = [
+            noun for noun in noun_list if noun not in self.text_feature_buffer
+        ]
+        if len(left_noun_list) > 0:
+            left_text_features = self.text_templates(
+                left_noun_list, self.clip_model
+            )
+            self.text_feature_buffer.update(
+                {
+                    noun: text_feature
+                    for noun, text_feature in zip(
+                        left_noun_list, left_text_features
+                    )
+                }
+            )
+        return torch.stack([self.text_feature_buffer[noun] for noun in noun_list])
+
+    def get_text_features(self, noun_list: List[str]):
+        return self._get_text_features(noun_list)
+
+    def get_image_features(self, image: torch.Tensor):
+        image_features = self.clip_model.visual(image)
+        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+        return image_features
+
+    def get_sim_logits(
+        self,
+        text_features: torch.Tensor,
+        image_features: torch.Tensor,
+        temperature: float = 100,
+    ):
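+        # temperature=100 mirrors the exp(logit_scale) (~100) of the released CLIP models, so these
+        # similarity logits are on the same scale as CLIP's zero-shot classification scores.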
+        return temperature * image_features @ text_features.T
+
+    def normalize_feature(self, feat: torch.Tensor):
+        return feat / feat.norm(dim=-1, keepdim=True)
+
+
+class MaskFormerClipAdapter(ClipAdapter):
+    def __init__(
+        self,
+        clip_model_name: str,
+        text_templates: PromptExtractor,
+        mask_fill: str = "mean",
+        mask_expand_ratio: float = 1.0,
+        mask_thr: float = 0.5,
+        mask_matting: bool = False,
+        region_resized: bool = True,
+        mask_prompt_depth: int = 0,
+        mask_prompt_fwd: bool = False,
+    ):
+        super().__init__(clip_model_name, mask_prompt_depth, text_templates)
+        self.non_object_embedding = nn.Parameter(
+            torch.empty(1, self.clip_model.text_projection.shape[-1])
+        )
+        nn.init.normal_(
+            self.non_object_embedding.data,
+            std=self.clip_model.transformer.width ** -0.5,
+        )
+        # mask filling strategy used at test time
+        self.mask_fill = mask_fill
+        if self.mask_fill == "zero":
+            self.mask_fill = (0.0, 0.0, 0.0)
+        elif self.mask_fill == "mean":
+            self.mask_fill = [255.0 * c for c in PIXEL_MEAN]
+        else:
+            raise NotImplementedError(
+                "Unknown mask_fill method: {}".format(self.mask_fill)
+            )
+        self.mask_expand_ratio = mask_expand_ratio
+        self.mask_thr = mask_thr
+        self.mask_matting = mask_matting
+        self.region_resized = region_resized
+        self.mask_prompt_fwd = mask_prompt_fwd
+        self.register_buffer(
+            "pixel_mean", torch.Tensor(PIXEL_MEAN).reshape(1, 3, 1, 1) * 255.0
+        )
+        self.register_buffer(
+            "pixel_std", torch.Tensor(PIXEL_STD).reshape(1, 3, 1, 1) * 255.0
+        )
+
+    def forward(
+        self,
+        image: torch.Tensor,
+        text: List[str],
+        mask: torch.Tensor,
+        normalize: bool = True,
+        fwd_w_region_mask: bool = False,
+    ):
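+        # Crop each predicted mask region out of the image, encode the crops with the CLIP
+        # visual encoder (optionally with mask prompting), and score them against the class
+        # text features plus the learned no-object embedding.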
+        (regions, unnorm_regions), region_masks, valid_flag = self._preprocess_image(image, mask, normalize=normalize)
+        if regions is None:
+            return None, valid_flag
+        if isinstance(regions, list):
+            # Note: the list-of-regions path (taken when region_resized is False) is only
+            # partially supported; each region is encoded separately without mask prompting.
+            image_features = torch.cat(
+                [self.get_image_features(image_i) for image_i in regions], dim=0
+            )
+        else:
+            if self.mask_prompt_fwd:
+                image_features = self.get_image_features(regions, region_masks)
+            else:
+                image_features = self.get_image_features(regions)
+        text_feature = self.get_text_features(text)  # k,feat_dim
+        return self.get_sim_logits(text_feature, image_features), unnorm_regions, valid_flag
+
+    def get_image_features(self, image, region_masks=None):
+        image_features = self.clip_model.visual(image, region_masks)
+        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+        return image_features
+
+    def _preprocess_image(
+        self, image: torch.Tensor, mask: torch.Tensor, normalize: bool = True
+    ):
+        """crop, mask and normalize the image
+
+        Args:
+            image ([type]): [C,H,W]
+            mask ([type]): [K,H,W
+            normalize (bool, optional): [description]. Defaults to True.
+        """
+        dtype = mask.dtype
+        bin_mask = mask > self.mask_thr
+        valid = bin_mask.sum(dim=(-1, -2)) > 0
+        bin_mask = bin_mask[valid]
+        mask = mask[valid]
+        if not self.mask_matting:
+            mask = bin_mask
+        bin_mask = BitMasks(bin_mask)
+        bboxes = bin_mask.get_bounding_boxes()
+        # crop,mask
+        regions = []
+        region_masks = []
+        for bbox, single_mask in zip(bboxes, mask):
+            region, region_mask = crop_with_mask(
+                image.type(dtype),
+                single_mask.type(dtype),
+                bbox,
+                fill=self.mask_fill,
+                expand_ratio=self.mask_expand_ratio,
+            )
+            regions.append(region.unsqueeze(0))
+            region_masks.append(region_mask.unsqueeze(0))
+        if len(regions) == 0:
+            return None, valid
+        unnorm_regions = regions
+        if normalize:
+            regions = [(r - self.pixel_mean) / self.pixel_std for r in regions]
+        # resize
+        if self.region_resized:
+            regions = [
+                F.interpolate(r, size=(224, 224), mode="bicubic") for r in regions
+            ]
+            regions = torch.cat(regions)
+            region_masks = [
+                F.interpolate(r, size=(224, 224), mode="nearest") for r in region_masks
+            ]
+            region_masks = torch.cat(region_masks)
+            unnorm_regions = [
+                F.interpolate(r, size=(224, 224), mode="bicubic") for r in unnorm_regions
+            ]
+            unnorm_regions = torch.cat(unnorm_regions)
+        return (regions, unnorm_regions), region_masks, valid
+
+    def get_text_features(self, noun_list: List[str]):
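+        # Class text features with a learned, L2-normalized "no object" embedding appended
+        # as the last row, so a region can also be matched to a background/void class.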
+        object_text_features = self._get_text_features(noun_list)
+        non_object_text_features = (
+            self.non_object_embedding
+            / self.non_object_embedding.norm(dim=-1, keepdim=True)
+        )
+        return torch.cat([object_text_features, non_object_text_features], dim=0)
diff --git a/open_vocab_seg/modeling/clip_adapter/clip/__init__.py b/open_vocab_seg/modeling/clip_adapter/clip/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcc5619538c0f7c782508bdbd9587259d805e0d9
--- /dev/null
+++ b/open_vocab_seg/modeling/clip_adapter/clip/__init__.py
@@ -0,0 +1 @@
+from .clip import *
diff --git a/open_vocab_seg/modeling/clip_adapter/clip/bpe_simple_vocab_16e6.txt.gz b/open_vocab_seg/modeling/clip_adapter/clip/bpe_simple_vocab_16e6.txt.gz
new file mode 100644
index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113
--- /dev/null
+++ b/open_vocab_seg/modeling/clip_adapter/clip/bpe_simple_vocab_16e6.txt.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+size 1356917
diff --git a/open_vocab_seg/modeling/clip_adapter/clip/clip.py b/open_vocab_seg/modeling/clip_adapter/clip/clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d733edfac02d81ba3e402eb7e702764728bdaa2
--- /dev/null
+++ b/open_vocab_seg/modeling/clip_adapter/clip/clip.py
@@ -0,0 +1,285 @@
+import hashlib
+import os
+import urllib
+import warnings
+from collections import OrderedDict
+from typing import Union, List
+
+import torch
+from packaging import version  # assumed available; used for the version check below
+from PIL import Image
+from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+from tqdm import tqdm
+
+from .model import build_model
+from .simple_tokenizer import SimpleTokenizer as _Tokenizer
+
+try:
+    from torchvision.transforms import InterpolationMode
+
+    BICUBIC = InterpolationMode.BICUBIC
+except ImportError:
+    BICUBIC = Image.BICUBIC
+
+
+if version.parse(torch.__version__) < version.parse("1.7.1"):
+    warnings.warn("PyTorch version 1.7.1 or higher is recommended")
+
+
+__all__ = ["available_models", "load", "tokenize"]
+_tokenizer = _Tokenizer()
+
+_MODELS = {
+    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
+    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
+    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
+    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
+    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
+    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
+    "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
+    "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
+}
+
+
+def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
+    os.makedirs(root, exist_ok=True)
+    filename = os.path.basename(url)
+
+    expected_sha256 = url.split("/")[-2]
+    download_target = os.path.join(root, filename)
+
+    if os.path.exists(download_target) and not os.path.isfile(download_target):
+        raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+    if os.path.isfile(download_target):
+        if (
+            hashlib.sha256(open(download_target, "rb").read()).hexdigest()
+            == expected_sha256
+        ):
+            return download_target
+        else:
+            warnings.warn(
+                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
+            )
+
+    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+        with tqdm(
+            total=int(source.info().get("Content-Length")),
+            ncols=80,
+            unit="iB",
+            unit_scale=True,
+        ) as loop:
+            while True:
+                buffer = source.read(8192)
+                if not buffer:
+                    break
+
+                output.write(buffer)
+                loop.update(len(buffer))
+
+    if (
+        hashlib.sha256(open(download_target, "rb").read()).hexdigest()
+        != expected_sha256
+    ):
+        raise RuntimeError(
+            "Model has been downloaded but the SHA256 checksum does not match"
+        )
+
+    return download_target
+
+
+def _transform(n_px):
+    return Compose(
+        [
+            Resize(n_px, interpolation=BICUBIC),
+            CenterCrop(n_px),
+            lambda image: image.convert("RGB"),
+            ToTensor(),
+            Normalize(
+                (0.48145466, 0.4578275, 0.40821073),
+                (0.26862954, 0.26130258, 0.27577711),
+            ),
+        ]
+    )
+
+
+def available_models() -> List[str]:
+    """Returns the names of available CLIP models"""
+    return list(_MODELS.keys())
+
+
+def load(
+    name: str,
+    mask_prompt_depth: int = 0,
+    device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
+    jit=False,
+):
+    """Load a CLIP model
+
+    Parameters
+    ----------
+    name : str
+        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
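+
+    mask_prompt_depth : int
+        Number of leading visual transformer blocks in which masked-out patch tokens are
+        replaced by learned mask prompt embeddings; 0 disables mask prompting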
+
+    device : Union[str, torch.device]
+        The device to put the loaded model
+
+    jit : bool
+        Whether to load the optimized JIT model or more hackable non-JIT model (default).
+
+    Returns
+    -------
+    model : torch.nn.Module
+        The CLIP model
+
+    preprocess : Callable[[PIL.Image], torch.Tensor]
+        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
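+
+    Examples
+    --------
+    A minimal sketch (downloads the ViT-B/16 checkpoint on first use):
+
+    >>> model, preprocess = load("ViT-B/16", device="cpu")
+    >>> tokens = tokenize(["a photo of a dog"])
+    >>> text_features = model.encode_text(tokens)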
+    """
+    if name in _MODELS:
+        model_path = _download(_MODELS[name])
+    elif os.path.isfile(name):
+        model_path = name
+    else:
+        raise RuntimeError(
+            f"Model {name} not found; available models = {available_models()}"
+        )
+
+    try:
+        # loading JIT archive
+        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
+        state_dict = None
+    except RuntimeError:
+        # loading saved state dict
+        if jit:
+            warnings.warn(
+                f"File {model_path} is not a JIT archive. Loading as a state dict instead"
+            )
+            jit = False
+        state_dict = torch.load(model_path, map_location="cpu")
+        if 'state_dict' in state_dict:
+            new_state_dict = OrderedDict()
+            for k, v in state_dict['state_dict'].items():
+                if k.startswith('module.'):
+                    name = k[7:]  # remove `module.`
+                    new_state_dict[name] = v
+            state_dict = new_state_dict
+
+    if not jit:
+        model = build_model(state_dict or model.state_dict(), mask_prompt_depth).to(device)
+        if str(device) == "cpu":
+            model.float()
+        return model, _transform(model.visual.input_resolution)
+
+    # patch the device names
+    device_holder = torch.jit.trace(
+        lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]
+    )
+    device_node = [
+        n
+        for n in device_holder.graph.findAllNodes("prim::Constant")
+        if "Device" in repr(n)
+    ][-1]
+
+    def patch_device(module):
+        try:
+            graphs = [module.graph] if hasattr(module, "graph") else []
+        except RuntimeError:
+            graphs = []
+
+        if hasattr(module, "forward1"):
+            graphs.append(module.forward1.graph)
+
+        for graph in graphs:
+            for node in graph.findAllNodes("prim::Constant"):
+                if "value" in node.attributeNames() and str(node["value"]).startswith(
+                    "cuda"
+                ):
+                    node.copyAttributes(device_node)
+
+    model.apply(patch_device)
+    patch_device(model.encode_image)
+    patch_device(model.encode_text)
+
+    # patch dtype to float32 on CPU
+    if str(device) == "cpu":
+        float_holder = torch.jit.trace(
+            lambda: torch.ones([]).float(), example_inputs=[]
+        )
+        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
+        float_node = float_input.node()
+
+        def patch_float(module):
+            try:
+                graphs = [module.graph] if hasattr(module, "graph") else []
+            except RuntimeError:
+                graphs = []
+
+            if hasattr(module, "forward1"):
+                graphs.append(module.forward1.graph)
+
+            for graph in graphs:
+                for node in graph.findAllNodes("aten::to"):
+                    inputs = list(node.inputs())
+                    for i in [
+                        1,
+                        2,
+                    ]:  # dtype can be the second or third argument to aten::to()
+                        if inputs[i].node()["value"] == 5:
+                            inputs[i].node().copyAttributes(float_node)
+
+        model.apply(patch_float)
+        patch_float(model.encode_image)
+        patch_float(model.encode_text)
+
+        model.float()
+
+    return model, _transform(model.input_resolution.item())
+
+
+def tokenize(
+    texts: Union[str, List[str]],
+    context_length: int = 77,
+    truncate: bool = False,
+    return_length: bool = False,
+) -> torch.LongTensor:
+    """
+    Returns the tokenized representation of given input string(s)
+
+    Parameters
+    ----------
+    texts : Union[str, List[str]]
+        An input string or a list of input strings to tokenize
+
+    context_length : int
+        The context length to use; all CLIP models use 77 as the context length
+
+    truncate: bool
+        Whether to truncate the text in case its encoding is longer than the context length
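+
+    return_length : bool
+        Whether to additionally return a list with the (possibly truncated) token length of each input string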
+
+    Returns
+    -------
+    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
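+
+    Examples
+    --------
+    >>> tokens = tokenize(["a photo of a cat", "a photo of a dog"])
+    >>> tokens.shape
+    torch.Size([2, 77])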
+    """
+    if isinstance(texts, str):
+        texts = [texts]
+
+    sot_token = _tokenizer.encoder["<|startoftext|>"]
+    eot_token = _tokenizer.encoder["<|endoftext|>"]
+    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
+    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
+    length = []
+    for i, tokens in enumerate(all_tokens):
+        if len(tokens) > context_length:
+            if truncate:
+                tokens = tokens[:context_length]
+                tokens[-1] = eot_token
+                length.append(context_length)
+            else:
+                raise RuntimeError(
+                    f"Input {texts[i]} is too long for context length {context_length}"
+                )
+        else:
+            length.append(len(tokens))
+        result[i, : len(tokens)] = torch.tensor(tokens)
+    if return_length:
+        return result, length
+    return result
diff --git a/open_vocab_seg/modeling/clip_adapter/clip/model.py b/open_vocab_seg/modeling/clip_adapter/clip/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ea730a2cc8a992f9180428bd1fec7fc96aa89dd
--- /dev/null
+++ b/open_vocab_seg/modeling/clip_adapter/clip/model.py
@@ -0,0 +1,613 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+# Modified by Feng Liang from https://github.com/openai/CLIP/blob/main/clip/model.py
+
+from collections import OrderedDict
+from typing import Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1):
+        super().__init__()
+
+        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+
+        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+
+        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = None
+        self.stride = stride
+
+        if stride > 1 or inplanes != planes * Bottleneck.expansion:
+            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+            self.downsample = nn.Sequential(
+                OrderedDict(
+                    [
+                        ("-1", nn.AvgPool2d(stride)),
+                        (
+                            "0",
+                            nn.Conv2d(
+                                inplanes,
+                                planes * self.expansion,
+                                1,
+                                stride=1,
+                                bias=False,
+                            ),
+                        ),
+                        ("1", nn.BatchNorm2d(planes * self.expansion)),
+                    ]
+                )
+            )
+
+    def forward(self, x: torch.Tensor):
+        identity = x
+
+        out = self.relu(self.bn1(self.conv1(x)))
+        out = self.relu(self.bn2(self.conv2(out)))
+        out = self.avgpool(out)
+        out = self.bn3(self.conv3(out))
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+        return out
+
+
+class AttentionPool2d(nn.Module):
+    def __init__(
+        self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
+    ):
+        super().__init__()
+        self.positional_embedding = nn.Parameter(
+            torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5
+        )
+        self.k_proj = nn.Linear(embed_dim, embed_dim)
+        self.q_proj = nn.Linear(embed_dim, embed_dim)
+        self.v_proj = nn.Linear(embed_dim, embed_dim)
+        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+        self.num_heads = num_heads
+        self.grid_size = spacial_dim
+
+    def forward(self, x, mask=None, return_cls=True):
+        b, c, gh, gw = x.shape
+        # remove irrelevant features
+        if mask is not None:
+            mask = F.interpolate(mask[:, None, ...], size=(gh, gw)).squeeze(
+                1
+            )  # [N,H,W] -> [N,grid,grid]
+            mask = (mask > 0.5).reshape(mask.shape[0], -1)
+            mask = torch.cat([mask, mask.new_ones(mask.shape[0], 1)], dim=1)
+            if x.size()[0] == 1:
+                x = x.expand(mask.shape[0], c, gh, gw)
+
+        x = x.reshape(x.shape[0], c, gh * gw).permute(2, 0, 1)  # NCHW -> (HW)NC
+
+        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
+        positional_embedding = self.positional_embedding
+        if not (self.positional_embedding.shape[0] == x.shape[0]):
+            cls_pos = positional_embedding[0:1, :]
+            per_pos_embedding = (
+                F.interpolate(
+                    positional_embedding[1:, :]
+                    .permute(1, 0)
+                    .view(1, -1, self.grid_size, self.grid_size),
+                    size=(gh, gw),
+                    mode="bicubic",
+                )
+                .reshape(-1, gh * gw)
+                .permute(1, 0)
+            )
+            positional_embedding = torch.cat([cls_pos, per_pos_embedding])
+
+        x = x + positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
+        x, _ = F.multi_head_attention_forward(
+            query=x,
+            key=x,
+            value=x,
+            embed_dim_to_check=x.shape[-1],
+            num_heads=self.num_heads,
+            q_proj_weight=self.q_proj.weight,
+            k_proj_weight=self.k_proj.weight,
+            v_proj_weight=self.v_proj.weight,
+            in_proj_weight=None,
+            in_proj_bias=torch.cat(
+                [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
+            ),
+            bias_k=None,
+            bias_v=None,
+            add_zero_attn=False,
+            dropout_p=0,
+            out_proj_weight=self.c_proj.weight,
+            out_proj_bias=self.c_proj.bias,
+            use_separate_proj_weight=True,
+            training=self.training,
+            need_weights=False,
+            key_padding_mask=mask,
+        )
+
+        if return_cls:
+            return x[0]
+        else:
+            return x
+
+
+class ModifiedResNet(nn.Module):
+    """
+    A ResNet class that is similar to torchvision's but contains the following changes:
+    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+    - The final pooling layer is a QKV attention instead of an average pool
+    """
+
+    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
+        super().__init__()
+        self.output_dim = output_dim
+        self.input_resolution = input_resolution
+
+        # the 3-layer stem
+        self.conv1 = nn.Conv2d(
+            3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
+        )
+        self.bn1 = nn.BatchNorm2d(width // 2)
+        self.conv2 = nn.Conv2d(
+            width // 2, width // 2, kernel_size=3, padding=1, bias=False
+        )
+        self.bn2 = nn.BatchNorm2d(width // 2)
+        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(width)
+        self.avgpool = nn.AvgPool2d(2)
+        self.relu = nn.ReLU(inplace=True)
+
+        # residual layers
+        self._inplanes = width  # this is a *mutable* variable used during construction
+        self.layer1 = self._make_layer(width, layers[0])
+        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+        embed_dim = width * 32  # the ResNet feature dimension
+        self.attnpool = AttentionPool2d(
+            input_resolution // 32, embed_dim, heads, output_dim
+        )
+
+    def _make_layer(self, planes, blocks, stride=1):
+        layers = [Bottleneck(self._inplanes, planes, stride)]
+
+        self._inplanes = planes * Bottleneck.expansion
+        for _ in range(1, blocks):
+            layers.append(Bottleneck(self._inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x, mask: torch.Tensor = None, return_cls=True):
+        def stem(x):
+            for conv, bn in [
+                (self.conv1, self.bn1),
+                (self.conv2, self.bn2),
+                (self.conv3, self.bn3),
+            ]:
+                x = self.relu(bn(conv(x)))
+            x = self.avgpool(x)
+            return x
+
+        x = x.type(self.conv1.weight.dtype)
+        x = stem(x)  # 1/4,1/4
+        x = self.layer1(x)
+        x = self.layer2(x)  # 1/8,1/8
+        x = self.layer3(x)  # 1/16,1/16
+        x = self.layer4(x)  # 1/32,1/32
+        b, c, gh, gw = x.shape
+        x = self.attnpool(x, mask, return_cls)
+        if not return_cls:
+            return x[1:].permute(1, 0, 2).reshape(b, gh, gw, x.shape[-1])  # [N, H, W, C]
+        return x
+
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        ret = super().forward(x.type(torch.float32))
+        return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
+    def forward(self, x: torch.Tensor):
+        return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+        super().__init__()
+
+        self.attn = nn.MultiheadAttention(d_model, n_head)
+        self.ln_1 = LayerNorm(d_model)
+        self.mlp = nn.Sequential(
+            OrderedDict(
+                [
+                    ("c_fc", nn.Linear(d_model, d_model * 4)),
+                    ("gelu", QuickGELU()),
+                    ("c_proj", nn.Linear(d_model * 4, d_model)),
+                ]
+            )
+        )
+        self.ln_2 = LayerNorm(d_model)
+        self.attn_mask = attn_mask
+
+    def attention(self, x: torch.Tensor, **kwargs):
+        self.attn_mask = (
+            self.attn_mask.to(dtype=x.dtype, device=x.device)
+            if self.attn_mask is not None
+            else None
+        )
+        return self.attn(
+            x, x, x, need_weights=False, attn_mask=self.attn_mask, **kwargs
+        )[0]
+
+    def forward(self, x: torch.Tensor, **kwargs):
+        x = x + self.attention(self.ln_1(x), **kwargs)
+        x = x + self.mlp(self.ln_2(x))
+        return x
+
+
+class Transformer(nn.Module):
+    def __init__(
+        self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None
+    ):
+        super().__init__()
+        self.width = width
+        self.layers = layers
+        self.resblocks = nn.Sequential(
+            *[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
+        )
+
+    def forward(self, x: torch.Tensor, **kwargs):
+        for block in self.resblocks:
+            x = block(x, **kwargs)
+        return x
+
+
+class VisionTransformer(nn.Module):
+    def __init__(
+        self,
+        input_resolution: int,
+        patch_size: int,
+        mask_prompt_depth: int,
+        width: int,
+        layers: int,
+        heads: int,
+        output_dim: int,
+    ):
+        super().__init__()
+        self.input_resolution = input_resolution
+        self.output_dim = output_dim
+        self.conv1 = nn.Conv2d(
+            in_channels=3,
+            out_channels=width,
+            kernel_size=patch_size,
+            stride=patch_size,
+            bias=False,
+        )
+
+        scale = width ** -0.5
+        self.class_embedding = nn.Parameter(scale * torch.randn(width))
+        self.positional_embedding = nn.Parameter(
+            scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)
+        )
+        self.grid_size = input_resolution // patch_size
+        self.ln_pre = LayerNorm(width)
+
+        self.transformer = Transformer(width, layers, heads)
+
+        self.ln_post = LayerNorm(width)
+        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+        self.mask_pool = nn.AvgPool2d(patch_size, stride=patch_size)
+        self.mask_prompt_depth = mask_prompt_depth
+        self.mask_embedding = nn.Parameter(torch.zeros(self.mask_prompt_depth, self.grid_size * self.grid_size, width))
+
+    def forward(self, x: torch.Tensor, m: torch.Tensor = None):
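+        # When a region mask `m` is provided, patch tokens outside the mask are replaced by
+        # learned mask prompt embeddings in the first `mask_prompt_depth` blocks (mask prompt
+        # tuning); with m=None this is the standard CLIP ViT forward pass.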
+        x = self.conv1(x)  # shape = [*, width, grid, grid]
+        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+        if m is not None:
+            m = self.mask_pool(m.to(torch.float).squeeze()).reshape(m.shape[0], -1).unsqueeze(-1)
+            m = torch.ceil(m)
+            if self.mask_embedding.shape[1] == 1:
+                mask_embedding = self.mask_embedding.to(x.dtype).repeat(1, x.shape[1], 1)
+            else:
+                mask_embedding = self.mask_embedding.to(x.dtype)
+            x = x * m + mask_embedding[0].unsqueeze(0) * (1 - m)
+
+        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
+        x = x + self.positional_embedding.to(x.dtype)
+        x = self.ln_pre(x)
+
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        if m is not None:
+            for i, blk in enumerate(self.transformer.resblocks):
+                d = i + 1
+                x = blk(x)
+                if d < self.mask_prompt_depth:
+                    masked_x = x[1:, :, :] * m.permute(1, 0, 2) + \
+                               mask_embedding[d].unsqueeze(0).permute(1, 0, 2) * (1 - m.permute(1, 0, 2))
+                    x = torch.cat([x[:1, :, :], masked_x], dim=0)
+        else:
+            x = self.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+
+        x = self.ln_post(x[:, 0, :])
+
+        if self.proj is not None:
+            x = x @ self.proj
+
+        return x
+
+
+class CLIP(nn.Module):
+    def __init__(
+        self,
+        embed_dim: int,
+        # vision
+        image_resolution: int,
+        vision_layers: Union[Tuple[int, int, int, int], int],
+        vision_width: int,
+        vision_patch_size: int,
+        mask_prompt_depth: int,
+        # text
+        context_length: int,
+        vocab_size: int,
+        transformer_width: int,
+        transformer_heads: int,
+        transformer_layers: int,
+    ):
+        super().__init__()
+
+        self.context_length = context_length
+
+        if isinstance(vision_layers, (tuple, list)):
+            vision_heads = vision_width * 32 // 64
+            self.visual = ModifiedResNet(
+                layers=vision_layers,
+                output_dim=embed_dim,
+                heads=vision_heads,
+                input_resolution=image_resolution,
+                width=vision_width,
+            )
+        else:
+            vision_heads = vision_width // 64
+            self.visual = VisionTransformer(
+                input_resolution=image_resolution,
+                patch_size=vision_patch_size,
+                mask_prompt_depth=mask_prompt_depth,
+                width=vision_width,
+                layers=vision_layers,
+                heads=vision_heads,
+                output_dim=embed_dim,
+            )
+
+        self.transformer = Transformer(
+            width=transformer_width,
+            layers=transformer_layers,
+            heads=transformer_heads,
+            attn_mask=self.build_attention_mask(),
+        )
+
+        self.vocab_size = vocab_size
+        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
+        self.positional_embedding = nn.Parameter(
+            torch.empty(self.context_length, transformer_width)
+        )
+        self.ln_final = LayerNorm(transformer_width)
+
+        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
+        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+        self.initialize_parameters()
+
+    def initialize_parameters(self):
+        nn.init.normal_(self.token_embedding.weight, std=0.02)
+        nn.init.normal_(self.positional_embedding, std=0.01)
+
+        if isinstance(self.visual, ModifiedResNet):
+            if self.visual.attnpool is not None:
+                std = self.visual.attnpool.c_proj.in_features ** -0.5
+                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
+                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
+                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
+                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
+
+            for resnet_block in [
+                self.visual.layer1,
+                self.visual.layer2,
+                self.visual.layer3,
+                self.visual.layer4,
+            ]:
+                for name, param in resnet_block.named_parameters():
+                    if name.endswith("bn3.weight"):
+                        nn.init.zeros_(param)
+
+        proj_std = (self.transformer.width ** -0.5) * (
+            (2 * self.transformer.layers) ** -0.5
+        )
+        attn_std = self.transformer.width ** -0.5
+        fc_std = (2 * self.transformer.width) ** -0.5
+        for block in self.transformer.resblocks:
+            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+        if self.text_projection is not None:
+            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
+
+    def build_attention_mask(self):
+        # lazily create the causal attention mask for the text transformer
+        # pytorch uses an additive attention mask; fill with -inf
+        mask = torch.empty(self.context_length, self.context_length)
+        mask.fill_(float("-inf"))
+        mask.triu_(1)  # zero out the diagonal and the lower triangle
+        return mask
+
+    @property
+    def dtype(self):
+        return self.visual.conv1.weight.dtype
+
+    def encode_image(self, image, **kwargs):
+        return self.visual(image.type(self.dtype), **kwargs)
+
+    def encode_text(self, text):
+        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
+
+        x = x + self.positional_embedding.type(self.dtype)
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        x = self.ln_final(x).type(self.dtype)
+
+        # x.shape = [batch_size, n_ctx, transformer.width]
+        # take features from the eot embedding (eot_token is the highest number in each sequence)
+        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+
+        return x
+
+    def forward(self, image, text):
+        image_features = self.encode_image(image)
+        text_features = self.encode_text(text)
+
+        # normalized features
+        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+
+        # cosine similarity as logits
+        logit_scale = self.logit_scale.exp()
+        logits_per_image = logit_scale * image_features @ text_features.t()
+        logits_per_text = logit_scale * text_features @ image_features.t()
+
+        # shape = [global_batch_size, global_batch_size]
+        return logits_per_image, logits_per_text
+
+
+def convert_weights(model: nn.Module):
+    """Convert applicable model parameters to fp16"""
+
+    def _convert_weights_to_fp16(l):
+        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+            l.weight.data = l.weight.data.half()
+            if l.bias is not None:
+                l.bias.data = l.bias.data.half()
+
+        if isinstance(l, nn.MultiheadAttention):
+            for attr in [
+                *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
+                "in_proj_bias",
+                "bias_k",
+                "bias_v",
+            ]:
+                tensor = getattr(l, attr)
+                if tensor is not None:
+                    tensor.data = tensor.data.half()
+
+        for name in ["text_projection", "proj"]:
+            if hasattr(l, name):
+                attr = getattr(l, name)
+                if attr is not None:
+                    attr.data = attr.data.half()
+
+    model.apply(_convert_weights_to_fp16)
+
+
+def build_model(state_dict: dict, mask_prompt_depth: int = 0):
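+    # Infer the visual architecture (ViT vs. ResNet variants and their size) from the checkpoint's
+    # state_dict keys, build a matching CLIP model, and load the weights with strict=False so the
+    # newly added mask prompt parameters may be absent from the checkpoint.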
+    vit = "visual.proj" in state_dict
+
+    if vit:
+        vision_width = state_dict["visual.conv1.weight"].shape[0]
+        vision_layers = len(
+            [
+                k
+                for k in state_dict.keys()
+                if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
+            ]
+        )
+        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
+        grid_size = round(
+            (state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5
+        )
+        image_resolution = vision_patch_size * grid_size
+    else:
+        assert mask_prompt_depth == 0, 'ResNets do not support mask prompt tuning'
+        counts: list = [
+            len(
+                set(
+                    k.split(".")[2]
+                    for k in state_dict
+                    if k.startswith(f"visual.layer{b}")
+                )
+            )
+            for b in [1, 2, 3, 4]
+        ]
+        vision_layers = tuple(counts)
+        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
+        output_width = round(
+            (state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5
+        )
+        vision_patch_size = None
+        assert (
+            output_width ** 2 + 1
+            == state_dict["visual.attnpool.positional_embedding"].shape[0]
+        )
+        image_resolution = output_width * 32
+
+    embed_dim = state_dict["text_projection"].shape[1]
+    context_length = state_dict["positional_embedding"].shape[0]
+    vocab_size = state_dict["token_embedding.weight"].shape[0]
+    transformer_width = state_dict["ln_final.weight"].shape[0]
+    transformer_heads = transformer_width // 64
+    transformer_layers = len(
+        set(
+            k.split(".")[2]
+            for k in state_dict
+            if k.startswith(f"transformer.resblocks")
+        )
+    )
+
+    model = CLIP(
+        embed_dim,
+        image_resolution,
+        vision_layers,
+        vision_width,
+        vision_patch_size,
+        mask_prompt_depth,
+        context_length,
+        vocab_size,
+        transformer_width,
+        transformer_heads,
+        transformer_layers,
+    )
+
+    for key in ["input_resolution", "context_length", "vocab_size"]:
+        if key in state_dict:
+            del state_dict[key]
+
+    convert_weights(model)
+    model.load_state_dict(state_dict, strict=False)
+    return model.eval()
diff --git a/open_vocab_seg/modeling/clip_adapter/clip/simple_tokenizer.py b/open_vocab_seg/modeling/clip_adapter/clip/simple_tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..56d17512b06afb700e7834e4f3f6515c315ebb74
--- /dev/null
+++ b/open_vocab_seg/modeling/clip_adapter/clip/simple_tokenizer.py
@@ -0,0 +1,150 @@
+import gzip
+import html
+import os
+from functools import lru_cache
+
+import ftfy
+import regex as re
+
+
+@lru_cache()
+def default_bpe():
+    return os.path.join(
+        os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz"
+    )
+
+
+@lru_cache()
+def bytes_to_unicode():
+    """
+    Returns a mapping between utf-8 bytes and corresponding unicode strings.
+    The reversible bpe codes work on unicode strings.
+    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a significant percentage of your normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+    This also avoids mapping to whitespace/control characters that the bpe code barfs on.
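+    For example (illustrative), the printable byte for "A" maps to "A" itself, while the space
+    byte maps to the placeholder character "Ġ".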
+    """
+    bs = (
+        list(range(ord("!"), ord("~") + 1))
+        + list(range(ord("¡"), ord("¬") + 1))
+        + list(range(ord("®"), ord("ÿ") + 1))
+    )
+    cs = bs[:]
+    n = 0
+    for b in range(2 ** 8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2 ** 8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+    """Return set of symbol pairs in a word.
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+def basic_clean(text):
+    text = ftfy.fix_text(text)
+    text = html.unescape(html.unescape(text))
+    return text.strip()
+
+
+def whitespace_clean(text):
+    text = re.sub(r"\s+", " ", text)
+    text = text.strip()
+    return text
+
+
+class SimpleTokenizer(object):
+    def __init__(self, bpe_path: str = default_bpe()):
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
+        merges = merges[1 : 49152 - 256 - 2 + 1]
+        merges = [tuple(merge.split()) for merge in merges]
+        vocab = list(bytes_to_unicode().values())
+        vocab = vocab + [v + "</w>" for v in vocab]
+        for merge in merges:
+            vocab.append("".join(merge))
+        vocab.extend(["<|startoftext|>", "<|endoftext|>"])
+        self.encoder = dict(zip(vocab, range(len(vocab))))
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.bpe_ranks = dict(zip(merges, range(len(merges))))
+        self.cache = {
+            "<|startoftext|>": "<|startoftext|>",
+            "<|endoftext|>": "<|endoftext|>",
+        }
+        self.pat = re.compile(
+            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+            re.IGNORECASE,
+        )
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token[:-1]) + (token[-1] + "</w>",)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token + "</w>"
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except ValueError:
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = " ".join(word)
+        self.cache[token] = word
+        return word
+
+    def encode(self, text):
+        bpe_tokens = []
+        text = whitespace_clean(basic_clean(text)).lower()
+        for token in re.findall(self.pat, text):
+            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
+            bpe_tokens.extend(
+                self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
+            )
+        return bpe_tokens
+
+    def decode(self, tokens):
+        text = "".join([self.decoder[token] for token in tokens])
+        text = (
+            bytearray([self.byte_decoder[c] for c in text])
+            .decode("utf-8", errors="replace")
+            .replace("</w>", " ")
+        )
+        return text
diff --git a/open_vocab_seg/modeling/clip_adapter/text_template.py b/open_vocab_seg/modeling/clip_adapter/text_template.py
new file mode 100644
index 0000000000000000000000000000000000000000..724bbef34c6bd74b0d7ead336d6b06d145bbee2d
--- /dev/null
+++ b/open_vocab_seg/modeling/clip_adapter/text_template.py
@@ -0,0 +1,156 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+# Modified by Feng Liang from
+# https://github.com/MendelXu/zsseg.baseline/blob/master/mask_former/modeling/clip_adapter/text_prompt.py
+# https://github.com/MendelXu/zsseg.baseline/blob/master/mask_former/modeling/clip_adapter/utils.py
+
+from typing import List
+
+# import clip
+from .clip import tokenize
+import torch
+from torch import nn
+
+IMAGENET_PROMPT = [
+    "a bad photo of a {}.",
+    "a photo of many {}.",
+    "a sculpture of a {}.",
+    "a photo of the hard to see {}.",
+    "a low resolution photo of the {}.",
+    "a rendering of a {}.",
+    "graffiti of a {}.",
+    "a bad photo of the {}.",
+    "a cropped photo of the {}.",
+    "a tattoo of a {}.",
+    "the embroidered {}.",
+    "a photo of a hard to see {}.",
+    "a bright photo of a {}.",
+    "a photo of a clean {}.",
+    "a photo of a dirty {}.",
+    "a dark photo of the {}.",
+    "a drawing of a {}.",
+    "a photo of my {}.",
+    "the plastic {}.",
+    "a photo of the cool {}.",
+    "a close-up photo of a {}.",
+    "a black and white photo of the {}.",
+    "a painting of the {}.",
+    "a painting of a {}.",
+    "a pixelated photo of the {}.",
+    "a sculpture of the {}.",
+    "a bright photo of the {}.",
+    "a cropped photo of a {}.",
+    "a plastic {}.",
+    "a photo of the dirty {}.",
+    "a jpeg corrupted photo of a {}.",
+    "a blurry photo of the {}.",
+    "a photo of the {}.",
+    "a good photo of the {}.",
+    "a rendering of the {}.",
+    "a {} in a video game.",
+    "a photo of one {}.",
+    "a doodle of a {}.",
+    "a close-up photo of the {}.",
+    "a photo of a {}.",
+    "the origami {}.",
+    "the {} in a video game.",
+    "a sketch of a {}.",
+    "a doodle of the {}.",
+    "a origami {}.",
+    "a low resolution photo of a {}.",
+    "the toy {}.",
+    "a rendition of the {}.",
+    "a photo of the clean {}.",
+    "a photo of a large {}.",
+    "a rendition of a {}.",
+    "a photo of a nice {}.",
+    "a photo of a weird {}.",
+    "a blurry photo of a {}.",
+    "a cartoon {}.",
+    "art of a {}.",
+    "a sketch of the {}.",
+    "a embroidered {}.",
+    "a pixelated photo of a {}.",
+    "itap of the {}.",
+    "a jpeg corrupted photo of the {}.",
+    "a good photo of a {}.",
+    "a plushie {}.",
+    "a photo of the nice {}.",
+    "a photo of the small {}.",
+    "a photo of the weird {}.",
+    "the cartoon {}.",
+    "art of the {}.",
+    "a drawing of the {}.",
+    "a photo of the large {}.",
+    "a black and white photo of a {}.",
+    "the plushie {}.",
+    "a dark photo of a {}.",
+    "itap of a {}.",
+    "graffiti of the {}.",
+    "a toy {}.",
+    "itap of my {}.",
+    "a photo of a cool {}.",
+    "a photo of a small {}.",
+    "a tattoo of the {}.",
+]
+
+VILD_PROMPT = [
+    "a photo of a {}.",
+    "This is a photo of a {}",
+    "There is a {} in the scene",
+    "There is the {} in the scene",
+    "a photo of a {} in the scene",
+    "a photo of a small {}.",
+    "a photo of a medium {}.",
+    "a photo of a large {}.",
+    "This is a photo of a small {}.",
+    "This is a photo of a medium {}.",
+    "This is a photo of a large {}.",
+    "There is a small {} in the scene.",
+    "There is a medium {} in the scene.",
+    "There is a large {} in the scene.",
+]
+
+class PromptExtractor(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self._buffer_init = False
+
+    def init_buffer(self, clip_model):
+        self._buffer_init = True
+
+    def forward(self, noun_list: List[str], clip_model: nn.Module):
+        raise NotImplementedError()
+
+
+class PredefinedPromptExtractor(PromptExtractor):
+    def __init__(self, templates: List[str]):
+        super().__init__()
+        self.templates = templates
+
+    def forward(self, noun_list: List[str], clip_model: nn.Module):
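+        # Classic CLIP prompt ensembling: encode every class name with each template
+        # (e.g. "a photo of a {}."), average the per-template embeddings, and re-normalize.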
+        text_features_bucket = []
+        for template in self.templates:
+            noun_tokens = [tokenize(template.format(noun)) for noun in noun_list]
+            text_inputs = torch.cat(noun_tokens).to(
+                clip_model.text_projection.data.device
+            )
+            text_features = clip_model.encode_text(text_inputs)
+            text_features /= text_features.norm(dim=-1, keepdim=True)
+            text_features_bucket.append(text_features)
+        del text_inputs
+        # ensemble by averaging
+        text_features = torch.stack(text_features_bucket).mean(dim=0)
+        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+
+        return text_features
+
+
+class ImageNetPromptExtractor(PredefinedPromptExtractor):
+    def __init__(self):
+        super().__init__(IMAGENET_PROMPT)
+
+
+class VILDPromptExtractor(PredefinedPromptExtractor):
+    def __init__(self):
+        super().__init__(VILD_PROMPT)
diff --git a/open_vocab_seg/modeling/clip_adapter/utils.py b/open_vocab_seg/modeling/clip_adapter/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbe5d9d5284597cca444287f6bae38e37549bde0
--- /dev/null
+++ b/open_vocab_seg/modeling/clip_adapter/utils.py
@@ -0,0 +1,81 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+from typing import Tuple
+import numpy as np
+import torch
+from .clip import load as clip_load
+from detectron2.utils.comm import get_local_rank, synchronize
+
+
+def expand_box(
+    x1: float,
+    y1: float,
+    x2: float,
+    y2: float,
+    expand_ratio: float = 1.0,
+    max_h: int = None,
+    max_w: int = None,
+):
+    cx = 0.5 * (x1 + x2)
+    cy = 0.5 * (y1 + y2)
+    w = x2 - x1
+    h = y2 - y1
+    w = w * expand_ratio
+    h = h * expand_ratio
+    box = [cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h]
+    if max_h is not None:
+        box[1] = max(0, box[1])
+        box[3] = min(max_h - 1, box[3])
+    if max_w is not None:
+        box[0] = max(0, box[0])
+        box[2] = min(max_w - 1, box[2])
+    return [int(b) for b in box]
+
+
+def mask2box(mask: torch.Tensor):
+    # use naive way
+    row = torch.nonzero(mask.sum(dim=0))[:, 0]
+    if len(row) == 0:
+        return None
+    x1 = row.min()
+    x2 = row.max()
+    col = torch.nonzero(mask.sum(dim=1))[:, 0]
+    y1 = col.min()
+    y2 = col.max()
+    return x1, y1, x2 + 1, y2 + 1
+
+
+def crop_with_mask(
+    image: torch.Tensor,
+    mask: torch.Tensor,
+    bbox: torch.Tensor,
+    fill: Tuple[float, float, float] = (0, 0, 0),
+    expand_ratio: float = 1.0,
+):
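+    # Crop the (optionally expanded) box from the image, keep pixels inside the mask (weighted
+    # by the soft mask), and fill pixels outside it with `fill`; returns the filled crop and
+    # the cropped mask.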
+    l, t, r, b = expand_box(*bbox, expand_ratio)
+    _, h, w = image.shape
+    l = max(l, 0)
+    t = max(t, 0)
+    r = min(r, w)
+    b = min(b, h)
+    new_image = torch.cat(
+        [image.new_full((1, b - t, r - l), fill_value=val) for val in fill]
+    )
+    mask_bool = mask.bool()
+    return image[:, t:b, l:r] * mask[None, t:b, l:r] + (~ mask_bool[None, t:b, l:r]) * new_image, mask[None, t:b, l:r]
+
+
+def build_clip_model(model: str, mask_prompt_depth: int = 0, frozen: bool = True):
+    rank = get_local_rank()
+    if rank == 0:
+        # download on rank 0 only
+        model, _ = clip_load(model, mask_prompt_depth=mask_prompt_depth, device="cpu")
+    synchronize()
+    if rank != 0:
+        model, _ = clip_load(model, mask_prompt_depth=mask_prompt_depth, device="cpu")
+    synchronize()
+    if frozen:
+        for param in model.parameters():
+            param.requires_grad = False
+    return model
diff --git a/open_vocab_seg/modeling/criterion.py b/open_vocab_seg/modeling/criterion.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4d5b71242f87c6f67463f9c31f873a742f3e5c7
--- /dev/null
+++ b/open_vocab_seg/modeling/criterion.py
@@ -0,0 +1,229 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/detr.py
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+"""
+MaskFormer criterion.
+"""
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from detectron2.utils.comm import get_world_size
+
+from ..utils.misc import is_dist_avail_and_initialized, nested_tensor_from_tensor_list
+
+
+def dice_loss(inputs, targets, num_masks):
+    """
+    Compute the DICE loss, similar to generalized IOU for masks
+    Args:
+        inputs: A float tensor of arbitrary shape.
+                The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+                 classification label for each element in inputs
+                (0 for the negative class and 1 for the positive class).
+        num_masks: normalizer; the (averaged) number of target masks in the batch.
+    """
+    inputs = inputs.sigmoid()
+    inputs = inputs.flatten(1)
+    numerator = 2 * (inputs * targets).sum(-1)
+    denominator = inputs.sum(-1) + targets.sum(-1)
+    loss = 1 - (numerator + 1) / (denominator + 1)
+    return loss.sum() / num_masks
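+
+# Toy sanity check (comments only): a confidently correct prediction drives the
+# per-mask dice loss towards 0.
+#
+#   logits = torch.full((1, 4), 10.0)        # near-1 probabilities everywhere
+#   target = torch.ones((1, 4))
+#   dice_loss(logits, target, num_masks=1)   # ~= 0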
+
+
+def sigmoid_focal_loss(
+    inputs, targets, num_masks, alpha: float = 0.25, gamma: float = 2
+):
+    """
+    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+    Args:
+        inputs: A float tensor of arbitrary shape.
+                The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+                 classification label for each element in inputs
+                (0 for the negative class and 1 for the positive class).
+        num_masks: normalizer; the (averaged) number of target masks in the batch.
+        alpha: (optional) Weighting factor in range (0, 1) to balance
+                positive vs negative examples. Default = 0.25.
+        gamma: Exponent of the modulating factor (1 - p_t) to
+               balance easy vs hard examples.
+    Returns:
+        Loss tensor
+    """
+    prob = inputs.sigmoid()
+    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+    p_t = prob * targets + (1 - prob) * (1 - targets)
+    loss = ce_loss * ((1 - p_t) ** gamma)
+
+    if alpha >= 0:
+        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+        loss = alpha_t * loss
+
+    return loss.mean(1).sum() / num_masks
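+
+# Toy sanity check (comments only): the (1 - p_t) ** gamma factor suppresses the
+# contribution of well-classified elements, so confident correct logits yield a
+# loss close to 0.
+#
+#   logits = torch.tensor([[8.0, -8.0]])
+#   target = torch.tensor([[1.0, 0.0]])
+#   sigmoid_focal_loss(logits, target, num_masks=1)   # ~= 0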
+
+
+class SetCriterion(nn.Module):
+    """This class computes the loss for DETR.
+    The process happens in two steps:
+        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
+        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+    """
+
+    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
+        """Create the criterion.
+        Parameters:
+            num_classes: number of object categories, omitting the special no-object category
+            matcher: module able to compute a matching between targets and proposals
+            weight_dict: dict containing as key the names of the losses and as values their relative weight.
+            eos_coef: relative classification weight applied to the no-object category
+            losses: list of all the losses to be applied. See get_loss for list of available losses.
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.matcher = matcher
+        self.weight_dict = weight_dict
+        self.eos_coef = eos_coef
+        self.losses = losses
+        if eos_coef > 0:
+            empty_weight = torch.ones(self.num_classes + 1)
+            empty_weight[-1] = self.eos_coef
+            self.register_buffer("empty_weight", empty_weight)
+            self.use_ignore_idx = False
+        else:
+            self.use_ignore_idx = True
+        self.cur_target = []
+
+    def loss_labels(self, outputs, targets, indices, num_masks):
+        """Classification loss (NLL)
+        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
+        """
+        assert "pred_logits" in outputs
+        src_logits = outputs["pred_logits"]
+
+        idx = self._get_src_permutation_idx(indices)
+        target_classes_o = torch.cat(
+            [t["labels"][J] for t, (_, J) in zip(targets, indices)]
+        )
+        target_classes = torch.full(
+            src_logits.shape[:2],
+            self.num_classes,
+            dtype=torch.int64,
+            device=src_logits.device,
+        )
+        target_classes[idx] = target_classes_o
+        if self.use_ignore_idx:
+            loss_ce = F.cross_entropy(
+                src_logits.transpose(1, 2),
+                target_classes,
+                ignore_index=self.num_classes,
+            )
+        else:
+            if "empty_weight" in outputs:
+                empty_weight = torch.cat(
+                    [outputs["empty_weight"], self.empty_weight[-1:]]
+                ).detach()
+            else:
+                empty_weight = self.empty_weight
+            loss_ce = F.cross_entropy(
+                src_logits.transpose(1, 2), target_classes, empty_weight
+            )
+        losses = {"loss_ce": loss_ce}
+        return losses
+
+    def loss_masks(self, outputs, targets, indices, num_masks):
+        """Compute the losses related to the masks: the focal loss and the dice loss.
+        targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
+        """
+        assert "pred_masks" in outputs
+
+        src_idx = self._get_src_permutation_idx(indices)
+        tgt_idx = self._get_tgt_permutation_idx(indices)
+        src_masks = outputs["pred_masks"]
+        src_masks = src_masks[src_idx]
+        masks = [t["masks"] for t in targets]
+        # TODO use valid to mask invalid areas due to padding in loss
+        target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
+        target_masks = target_masks.to(src_masks)
+        target_masks = target_masks[tgt_idx]
+
+        # upsample predictions to the target size
+        src_masks = F.interpolate(
+            src_masks[:, None],
+            size=target_masks.shape[-2:],
+            mode="bilinear",
+            align_corners=False,
+        )
+        src_masks = src_masks[:, 0].flatten(1)
+
+        target_masks = target_masks.flatten(1)
+        target_masks = target_masks.view(src_masks.shape)
+        losses = {
+            "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_masks),
+            "loss_dice": dice_loss(src_masks, target_masks, num_masks),
+        }
+        return losses
+
+    def _get_src_permutation_idx(self, indices):
+        # permute predictions following indices
+        batch_idx = torch.cat(
+            [torch.full_like(src, i) for i, (src, _) in enumerate(indices)]
+        )
+        src_idx = torch.cat([src for (src, _) in indices])
+        return batch_idx, src_idx
+
+    def _get_tgt_permutation_idx(self, indices):
+        # permute targets following indices
+        batch_idx = torch.cat(
+            [torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]
+        )
+        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
+        return batch_idx, tgt_idx
+
+    def get_loss(self, loss, outputs, targets, indices, num_masks):
+        loss_map = {"labels": self.loss_labels, "masks": self.loss_masks}
+        assert loss in loss_map, f"do you really want to compute {loss} loss?"
+        return loss_map[loss](outputs, targets, indices, num_masks)
+
+    def forward(self, outputs, targets):
+        """This performs the loss computation.
+        Parameters:
+             outputs: dict of tensors, see the output specification of the model for the format
+             targets: list of dicts, such that len(targets) == batch_size.
+                      The expected keys in each dict depend on the losses applied; see each loss' doc.
+        """
+        outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"}
+
+        # Retrieve the matching between the outputs of the last layer and the targets
+        indices = self.matcher(outputs_without_aux, targets)
+
+        # Compute the average number of target masks across all nodes, for normalization purposes
+        num_masks = sum(len(t["labels"]) for t in targets)
+        num_masks = torch.as_tensor(
+            [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device
+        )
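+        # Averaging across processes keeps this normalizer identical on every GPU,
+        # and the clamp below guards against division by zero when a process
+        # receives no target masks.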
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_masks)
+        num_masks = torch.clamp(num_masks / get_world_size(), min=1).item()
+
+        # Compute all the requested losses
+        losses = {}
+        for loss in self.losses:
+            losses.update(self.get_loss(loss, outputs, targets, indices, num_masks))
+
+        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+        if "aux_outputs" in outputs:
+            for i, aux_outputs in enumerate(outputs["aux_outputs"]):
+                indices = self.matcher(aux_outputs, targets)
+                for loss in self.losses:
+                    l_dict = self.get_loss(
+                        loss, aux_outputs, targets, indices, num_masks
+                    )
+                    l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
+                    losses.update(l_dict)
+
+        return losses
+
+    def clean_buffer(self):
+        self.cur_target = []
diff --git a/open_vocab_seg/modeling/heads/__init__.py b/open_vocab_seg/modeling/heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..52db7cce67b1686f7cab3698f15b8f309c897918
--- /dev/null
+++ b/open_vocab_seg/modeling/heads/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
\ No newline at end of file
diff --git a/open_vocab_seg/modeling/heads/mask_former_head.py b/open_vocab_seg/modeling/heads/mask_former_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f592662f92d1b0862a3ef76304e7b28b46ecf80
--- /dev/null
+++ b/open_vocab_seg/modeling/heads/mask_former_head.py
@@ -0,0 +1,135 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+import logging
+from copy import deepcopy
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import fvcore.nn.weight_init as weight_init
+from torch import nn
+from torch.nn import functional as F
+
+from detectron2.config import configurable
+from detectron2.layers import Conv2d, ShapeSpec, get_norm
+from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
+
+from ..transformer.transformer_predictor import TransformerPredictor
+from .pixel_decoder import build_pixel_decoder
+
+
+@SEM_SEG_HEADS_REGISTRY.register()
+class MaskFormerHead(nn.Module):
+
+    _version = 2
+
+    def _load_from_state_dict(
+        self,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        version = local_metadata.get("version", None)
+        if version is None or version < 2:
+            # Do not warn if train from scratch
+            scratch = True
+            logger = logging.getLogger(__name__)
+            for k in list(state_dict.keys()):
+                newk = k
+                if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
+                    newk = k.replace(prefix, prefix + "pixel_decoder.")
+                    # logger.debug(f"{k} ==> {newk}")
+                if newk != k:
+                    state_dict[newk] = state_dict[k]
+                    del state_dict[k]
+                    scratch = False
+
+            if not scratch:
+                logger.warning(
+                    f"Weight format of {self.__class__.__name__} have changed! "
+                    "Please upgrade your models. Applying automatic conversion now ..."
+                )
+
+    @configurable
+    def __init__(
+        self,
+        input_shape: Dict[str, ShapeSpec],
+        *,
+        num_classes: int,
+        pixel_decoder: nn.Module,
+        loss_weight: float = 1.0,
+        ignore_value: int = -1,
+        # extra parameters
+        transformer_predictor: nn.Module,
+        transformer_in_feature: str,
+    ):
+        """
+        NOTE: this interface is experimental.
+        Args:
+            input_shape: shapes (channels and stride) of the input features
+            num_classes: number of classes to predict
+            pixel_decoder: the pixel decoder module
+            loss_weight: loss weight
+            ignore_value: category id to be ignored during training.
+            transformer_predictor: the transformer decoder that makes prediction
+            transformer_in_feature: input feature name to the transformer_predictor
+        """
+        super().__init__()
+        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
+        self.in_features = [k for k, v in input_shape]
+        feature_strides = [v.stride for k, v in input_shape]
+        feature_channels = [v.channels for k, v in input_shape]
+
+        self.ignore_value = ignore_value
+        self.common_stride = 4
+        self.loss_weight = loss_weight
+
+        self.pixel_decoder = pixel_decoder
+        self.predictor = transformer_predictor
+        self.transformer_in_feature = transformer_in_feature
+
+        self.num_classes = num_classes
+
+    @classmethod
+    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
+        return {
+            "input_shape": {
+                k: v
+                for k, v in input_shape.items()
+                if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
+            },
+            "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
+            "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
+            "pixel_decoder": build_pixel_decoder(cfg, input_shape),
+            "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
+            "transformer_in_feature": cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE,
+            "transformer_predictor": TransformerPredictor(
+                cfg,
+                cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
+                if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder"
+                else input_shape[cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE].channels,
+                mask_classification=True,
+            ),
+        }
+
+    def forward(self, features):
+        return self.layers(features)
+
+    def layers(self, features):
+        (
+            mask_features,
+            transformer_encoder_features,
+        ) = self.pixel_decoder.forward_features(features)
+        if self.transformer_in_feature == "transformer_encoder":
+            assert (
+                transformer_encoder_features is not None
+            ), "Please use the TransformerEncoderPixelDecoder."
+            predictions = self.predictor(transformer_encoder_features, mask_features)
+        else:
+            predictions = self.predictor(
+                features[self.transformer_in_feature], mask_features
+            )
+        return predictions
diff --git a/open_vocab_seg/modeling/heads/open_vocab_mask_former_head.py b/open_vocab_seg/modeling/heads/open_vocab_mask_former_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ed84f9a44d24415b3334fdf2ea8e1188de32de6
--- /dev/null
+++ b/open_vocab_seg/modeling/heads/open_vocab_mask_former_head.py
@@ -0,0 +1,145 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+# Modified by Feng Liang from
+# https://github.com/MendelXu/zsseg.baseline/blob/master/mask_former/modeling/heads/zero_shot_mask_former_head.py
+
+import logging
+from copy import deepcopy
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import fvcore.nn.weight_init as weight_init
+from torch import nn
+from torch.nn import functional as F
+
+from detectron2.config import configurable
+from detectron2.layers import Conv2d, ShapeSpec, get_norm
+from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
+
+from ..transformer.open_vocab_transformer_predictor import OpenVocabTransformerPredictor
+from .pixel_decoder import build_pixel_decoder
+
+
+@SEM_SEG_HEADS_REGISTRY.register()
+class OpenVocabMaskFormerHead(nn.Module):
+
+    _version = 2
+
+    def _load_from_state_dict(
+        self,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        version = local_metadata.get("version", None)
+        if version is None or version < 2:
+            # Do not warn if train from scratch
+            scratch = True
+            logger = logging.getLogger(__name__)
+            for k in list(state_dict.keys()):
+                newk = k
+                if "sem_seg_head" in k and not k.startswith(prefix + "predictor"):
+                    newk = k.replace(prefix, prefix + "pixel_decoder.")
+                    # logger.debug(f"{k} ==> {newk}")
+                if newk != k:
+                    state_dict[newk] = state_dict[k]
+                    del state_dict[k]
+                    scratch = False
+
+            if not scratch:
+                logger.warning(
+                    f"Weight format of {self.__class__.__name__} have changed! "
+                    "Please upgrade your models. Applying automatic conversion now ..."
+                )
+
+    @configurable
+    def __init__(
+        self,
+        input_shape: Dict[str, ShapeSpec],
+        *,
+        num_classes: int,
+        pixel_decoder: nn.Module,
+        loss_weight: float = 1.0,
+        ignore_value: int = -1,
+        # extra parameters
+        transformer_predictor: nn.Module,
+        transformer_in_feature: str,
+    ):
+        """
+        NOTE: this interface is experimental.
+        Args:
+            input_shape: shapes (channels and stride) of the input features
+            num_classes: number of classes to predict
+            pixel_decoder: the pixel decoder module
+            loss_weight: loss weight
+            ignore_value: category id to be ignored during training.
+            transformer_predictor: the transformer decoder that makes prediction
+            transformer_in_feature: input feature name to the transformer_predictor
+        """
+        super().__init__()
+        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
+        self.in_features = [k for k, v in input_shape]
+        feature_strides = [v.stride for k, v in input_shape]
+        feature_channels = [v.channels for k, v in input_shape]
+
+        self.ignore_value = ignore_value
+        self.common_stride = 4
+        self.loss_weight = loss_weight
+
+        self.pixel_decoder = pixel_decoder
+        self.predictor = transformer_predictor
+        self.transformer_in_feature = transformer_in_feature
+
+        self.num_classes = num_classes
+
+    @classmethod
+    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
+        return {
+            "input_shape": {
+                k: v
+                for k, v in input_shape.items()
+                if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
+            },
+            "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
+            "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES,
+            "pixel_decoder": build_pixel_decoder(cfg, input_shape),
+            "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT,
+            "transformer_in_feature": cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE,
+            "transformer_predictor": OpenVocabTransformerPredictor(
+                cfg,
+                cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
+                if cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE == "transformer_encoder"
+                else input_shape[cfg.MODEL.MASK_FORMER.TRANSFORMER_IN_FEATURE].channels,
+                mask_classification=True,
+            ),
+        }
+
+    def forward(self, features):
+        return self.layers(features)
+
+    def layers(self, features):
+        (
+            mask_features,
+            transformer_encoder_features,
+        ) = self.pixel_decoder.forward_features(features)
+        if self.transformer_in_feature == "transformer_encoder":
+            assert (
+                transformer_encoder_features is not None
+            ), "Please use the TransformerEncoderPixelDecoder."
+            predictions = self.predictor(transformer_encoder_features, mask_features)
+        else:
+            predictions = self.predictor(
+                features[self.transformer_in_feature], mask_features
+            )
+        return predictions
+
+    def freeze_pretrained(self):
+        for name, module in self.named_children():
+            if name not in ["predictor"]:
+                for param in module.parameters():
+                    param.requires_grad = False
+            else:
+                module.freeze_pretrained()
diff --git a/open_vocab_seg/modeling/heads/pixel_decoder.py b/open_vocab_seg/modeling/heads/pixel_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b10089331785e937b79cf82af6d8fba55519082
--- /dev/null
+++ b/open_vocab_seg/modeling/heads/pixel_decoder.py
@@ -0,0 +1,308 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+import logging
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import fvcore.nn.weight_init as weight_init
+from torch import nn
+from torch.nn import functional as F
+
+from detectron2.config import configurable
+from detectron2.layers import Conv2d, ShapeSpec, get_norm
+from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
+
+from ..transformer.position_encoding import PositionEmbeddingSine
+from ..transformer.transformer import TransformerEncoder, TransformerEncoderLayer
+
+
+def build_pixel_decoder(cfg, input_shape):
+    """
+    Build a pixel decoder from `cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME`.
+    """
+    name = cfg.MODEL.SEM_SEG_HEAD.PIXEL_DECODER_NAME
+    model = SEM_SEG_HEADS_REGISTRY.get(name)(cfg, input_shape)
+    forward_features = getattr(model, "forward_features", None)
+    if not callable(forward_features):
+        raise ValueError(
+            "Only SEM_SEG_HEADS with forward_features method can be used as pixel decoder. "
+            f"Please implement forward_features for {name} to only return mask features."
+        )
+    return model
+
+
+@SEM_SEG_HEADS_REGISTRY.register()
+class BasePixelDecoder(nn.Module):
+    @configurable
+    def __init__(
+        self,
+        input_shape: Dict[str, ShapeSpec],
+        *,
+        conv_dim: int,
+        mask_dim: int,
+        norm: Optional[Union[str, Callable]] = None,
+    ):
+        """
+        NOTE: this interface is experimental.
+        Args:
+            input_shape: shapes (channels and stride) of the input features
+            conv_dim: number of output channels for the intermediate conv layers.
+            mask_dim: number of output channels for the final conv layer.
+            norm (str or callable): normalization for all conv layers
+        """
+        super().__init__()
+
+        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
+        self.in_features = [k for k, v in input_shape]  # starting from "res2" to "res5"
+        feature_channels = [v.channels for k, v in input_shape]
+
+        lateral_convs = []
+        output_convs = []
+
+        use_bias = norm == ""
+        for idx, in_channels in enumerate(feature_channels):
+            if idx == len(self.in_features) - 1:
+                output_norm = get_norm(norm, conv_dim)
+                output_conv = Conv2d(
+                    in_channels,
+                    conv_dim,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=use_bias,
+                    norm=output_norm,
+                    activation=F.relu,
+                )
+                weight_init.c2_xavier_fill(output_conv)
+                self.add_module("layer_{}".format(idx + 1), output_conv)
+
+                lateral_convs.append(None)
+                output_convs.append(output_conv)
+            else:
+                lateral_norm = get_norm(norm, conv_dim)
+                output_norm = get_norm(norm, conv_dim)
+
+                lateral_conv = Conv2d(
+                    in_channels,
+                    conv_dim,
+                    kernel_size=1,
+                    bias=use_bias,
+                    norm=lateral_norm,
+                )
+                output_conv = Conv2d(
+                    conv_dim,
+                    conv_dim,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=use_bias,
+                    norm=output_norm,
+                    activation=F.relu,
+                )
+                weight_init.c2_xavier_fill(lateral_conv)
+                weight_init.c2_xavier_fill(output_conv)
+                self.add_module("adapter_{}".format(idx + 1), lateral_conv)
+                self.add_module("layer_{}".format(idx + 1), output_conv)
+
+                lateral_convs.append(lateral_conv)
+                output_convs.append(output_conv)
+        # Place convs into top-down order (from low to high resolution)
+        # to make the top-down computation in forward clearer.
+        self.lateral_convs = lateral_convs[::-1]
+        self.output_convs = output_convs[::-1]
+
+        self.mask_dim = mask_dim
+        self.mask_features = Conv2d(
+            conv_dim,
+            mask_dim,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+        )
+        weight_init.c2_xavier_fill(self.mask_features)
+
+    @classmethod
+    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
+        ret = {}
+        ret["input_shape"] = {
+            k: v
+            for k, v in input_shape.items()
+            if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES
+        }
+        ret["conv_dim"] = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM
+        ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
+        ret["norm"] = cfg.MODEL.SEM_SEG_HEAD.NORM
+        return ret
+
+    def forward_features(self, features):
+        # Reverse feature maps into top-down order (from low to high resolution)
+        for idx, f in enumerate(self.in_features[::-1]):
+            x = features[f]
+            lateral_conv = self.lateral_convs[idx]
+            output_conv = self.output_convs[idx]
+            if lateral_conv is None:
+                y = output_conv(x)
+            else:
+                cur_fpn = lateral_conv(x)
+                # Following FPN implementation, we use nearest upsampling here
+                y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
+                y = output_conv(y)
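+        # The second return value is a placeholder for transformer encoder features;
+        # only TransformerEncoderPixelDecoder.forward_features provides them.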
+        return self.mask_features(y), None
+
+    def forward(self, features, targets=None):
+        logger = logging.getLogger(__name__)
+        logger.warning(
+            "Calling forward() may cause unpredicted behavior of PixelDecoder module."
+        )
+        return self.forward_features(features)
+
+
+class TransformerEncoderOnly(nn.Module):
+    def __init__(
+        self,
+        d_model=512,
+        nhead=8,
+        num_encoder_layers=6,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        normalize_before=False,
+    ):
+        super().__init__()
+
+        encoder_layer = TransformerEncoderLayer(
+            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+        )
+        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+        self.encoder = TransformerEncoder(
+            encoder_layer, num_encoder_layers, encoder_norm
+        )
+
+        self._reset_parameters()
+
+        self.d_model = d_model
+        self.nhead = nhead
+
+    def _reset_parameters(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    def forward(self, src, mask, pos_embed):
+        # flatten NxCxHxW to HWxNxC
+        bs, c, h, w = src.shape
+        src = src.flatten(2).permute(2, 0, 1)
+        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
+        if mask is not None:
+            mask = mask.flatten(1)
+
+        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
+        return memory.permute(1, 2, 0).view(bs, c, h, w)
+
+
+@SEM_SEG_HEADS_REGISTRY.register()
+class TransformerEncoderPixelDecoder(BasePixelDecoder):
+    @configurable
+    def __init__(
+        self,
+        input_shape: Dict[str, ShapeSpec],
+        *,
+        transformer_dropout: float,
+        transformer_nheads: int,
+        transformer_dim_feedforward: int,
+        transformer_enc_layers: int,
+        transformer_pre_norm: bool,
+        conv_dim: int,
+        mask_dim: int,
+        norm: Optional[Union[str, Callable]] = None,
+    ):
+        """
+        NOTE: this interface is experimental.
+        Args:
+            input_shape: shapes (channels and stride) of the input features
+            transformer_dropout: dropout probability in transformer
+            transformer_nheads: number of heads in transformer
+            transformer_dim_feedforward: dimension of feedforward network
+            transformer_enc_layers: number of transformer encoder layers
+            transformer_pre_norm: whether to use pre-layernorm or not
+            conv_dim: number of output channels for the intermediate conv layers.
+            mask_dim: number of output channels for the final conv layer.
+            norm (str or callable): normalization for all conv layers
+        """
+        super().__init__(input_shape, conv_dim=conv_dim, mask_dim=mask_dim, norm=norm)
+
+        input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride)
+        self.in_features = [k for k, v in input_shape]  # starting from "res2" to "res5"
+        feature_strides = [v.stride for k, v in input_shape]
+        feature_channels = [v.channels for k, v in input_shape]
+
+        in_channels = feature_channels[len(self.in_features) - 1]
+        self.input_proj = Conv2d(in_channels, conv_dim, kernel_size=1)
+        weight_init.c2_xavier_fill(self.input_proj)
+        self.transformer = TransformerEncoderOnly(
+            d_model=conv_dim,
+            dropout=transformer_dropout,
+            nhead=transformer_nheads,
+            dim_feedforward=transformer_dim_feedforward,
+            num_encoder_layers=transformer_enc_layers,
+            normalize_before=transformer_pre_norm,
+        )
+        N_steps = conv_dim // 2
+        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
+
+        # update layer
+        use_bias = norm == ""
+        output_norm = get_norm(norm, conv_dim)
+        output_conv = Conv2d(
+            conv_dim,
+            conv_dim,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            bias=use_bias,
+            norm=output_norm,
+            activation=F.relu,
+        )
+        weight_init.c2_xavier_fill(output_conv)
+        delattr(self, "layer_{}".format(len(self.in_features)))
+        self.add_module("layer_{}".format(len(self.in_features)), output_conv)
+        self.output_convs[0] = output_conv
+
+    @classmethod
+    def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]):
+        ret = super().from_config(cfg, input_shape)
+        ret["transformer_dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT
+        ret["transformer_nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
+        ret["transformer_dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
+        ret[
+            "transformer_enc_layers"
+        ] = cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS  # a separate config
+        ret["transformer_pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
+        return ret
+
+    def forward_features(self, features):
+        # Reverse feature maps into top-down order (from low to high resolution)
+        for idx, f in enumerate(self.in_features[::-1]):
+            x = features[f]
+            lateral_conv = self.lateral_convs[idx]
+            output_conv = self.output_convs[idx]
+            if lateral_conv is None:
+                transformer = self.input_proj(x)
+                pos = self.pe_layer(x)
+                transformer = self.transformer(transformer, None, pos)
+                y = output_conv(transformer)
+                # save intermediate feature as input to Transformer decoder
+                transformer_encoder_features = transformer
+            else:
+                cur_fpn = lateral_conv(x)
+                # Following FPN implementation, we use nearest upsampling here
+                y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
+                y = output_conv(y)
+        return self.mask_features(y), transformer_encoder_features
+
+    def forward(self, features, targets=None):
+        logger = logging.getLogger(__name__)
+        logger.warning(
+            "Calling forward() may cause unpredicted behavior of PixelDecoder module."
+        )
+        return self.forward_features(features)
diff --git a/open_vocab_seg/modeling/matcher.py b/open_vocab_seg/modeling/matcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..a72ba671ad60db078e08046357a6aa0e5e9bd5dc
--- /dev/null
+++ b/open_vocab_seg/modeling/matcher.py
@@ -0,0 +1,187 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/models/matcher.py
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+"""
+Modules to compute the matching cost and solve the corresponding LSAP.
+"""
+import torch
+import torch.nn.functional as F
+from scipy.optimize import linear_sum_assignment
+from torch import nn
+
+
+def batch_dice_loss(inputs, targets):
+    """
+    Compute the DICE loss, similar to generalized IOU for masks
+    Args:
+        inputs: A float tensor of arbitrary shape.
+                The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+                 classification label for each element in inputs
+                (0 for the negative class and 1 for the positive class).
+    """
+    inputs = inputs.sigmoid()
+    inputs = inputs.flatten(1)
+    numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets)
+    denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :]
+    loss = 1 - (numerator + 1) / (denominator + 1)
+    return loss
+
+
+def batch_sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2):
+    """
+    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+    Args:
+        inputs: A float tensor of arbitrary shape.
+                The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+                 classification label for each element in inputs
+                (0 for the negative class and 1 for the positive class).
+        alpha: (optional) Weighting factor in range (0, 1) to balance
+                positive vs negative examples. Default = 0.25.
+        gamma: Exponent of the modulating factor (1 - p_t) to
+               balance easy vs hard examples.
+    Returns:
+        Loss tensor
+    """
+    hw = inputs.shape[1]
+
+    prob = inputs.sigmoid()
+    focal_pos = ((1 - prob) ** gamma) * F.binary_cross_entropy_with_logits(
+        inputs, torch.ones_like(inputs), reduction="none"
+    )
+    focal_neg = (prob ** gamma) * F.binary_cross_entropy_with_logits(
+        inputs, torch.zeros_like(inputs), reduction="none"
+    )
+    if alpha >= 0:
+        focal_pos = focal_pos * alpha
+        focal_neg = focal_neg * (1 - alpha)
+
+    loss = torch.einsum("nc,mc->nm", focal_pos, targets) + torch.einsum(
+        "nc,mc->nm", focal_neg, (1 - targets)
+    )
+
+    return loss / hw
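+
+# Note on the pairwise cost above (comments only): with binary targets the focal
+# term splits into a positive and a negative contribution per prediction element,
+# so both are precomputed once and combined with every target mask via two einsum
+# contractions. This produces the full [num_queries, num_targets] cost matrix
+# without materializing a [num_queries, num_targets, H*W] tensor.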
+
+
+class HungarianMatcher(nn.Module):
+    """This class computes an assignment between the targets and the predictions of the network
+
+    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
+    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
+    while the others are un-matched (and thus treated as non-objects).
+    """
+
+    def __init__(
+        self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1
+    ):
+        """Creates the matcher
+
+        Params:
+            cost_class: This is the relative weight of the classification error in the matching cost
+            cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost
+            cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost
+        """
+        super().__init__()
+        self.cost_class = cost_class
+        self.cost_mask = cost_mask
+        self.cost_dice = cost_dice
+        assert (
+            cost_class != 0 or cost_mask != 0 or cost_dice != 0
+        ), "all costs cant be 0"
+
+    @torch.no_grad()
+    def memory_efficient_forward(self, outputs, targets):
+        """More memory-friendly matching"""
+        bs, num_queries = outputs["pred_logits"].shape[:2]
+
+        # Work out the mask padding size
+        masks = [v["masks"] for v in targets]
+        h_max = max([m.shape[1] for m in masks])
+        w_max = max([m.shape[2] for m in masks])
+
+        indices = []
+
+        # Iterate through batch size
+        for b in range(bs):
+
+            out_prob = outputs["pred_logits"][b].softmax(
+                -1
+            )  # [num_queries, num_classes]
+            out_mask = outputs["pred_masks"][b]  # [num_queries, H_pred, W_pred]
+
+            tgt_ids = targets[b]["labels"]
+            # gt masks are already padded when preparing target
+            tgt_mask = targets[b]["masks"].to(out_mask)
+
+            # Compute the classification cost. Contrary to the loss, we don't use the NLL,
+            # but approximate it in 1 - proba[target class].
+            # The 1 is a constant that doesn't change the matching, so it can be omitted.
+            cost_class = -out_prob[:, tgt_ids]
+
+            # Downsample gt masks to save memory
+            tgt_mask = F.interpolate(
+                tgt_mask[:, None], size=out_mask.shape[-2:], mode="nearest"
+            )
+
+            # Flatten spatial dimension
+            out_mask = out_mask.flatten(1)  # [num_queries, H_pred * W_pred]
+            tgt_mask = tgt_mask[:, 0].flatten(1)  # [num_targets_in_image, H_pred * W_pred]
+
+            # Compute the focal loss between masks
+            cost_mask = batch_sigmoid_focal_loss(out_mask, tgt_mask)
+
+            # Compute the dice loss between masks
+            cost_dice = batch_dice_loss(out_mask, tgt_mask)
+
+            # Final cost matrix
+            C = (
+                self.cost_mask * cost_mask
+                + self.cost_class * cost_class
+                + self.cost_dice * cost_dice
+            )
+            C = C.reshape(num_queries, -1).cpu()
+
+            indices.append(linear_sum_assignment(C))
+        return [
+            (
+                torch.as_tensor(i, dtype=torch.int64),
+                torch.as_tensor(j, dtype=torch.int64),
+            )
+            for i, j in indices
+        ]
+
+    @torch.no_grad()
+    def forward(self, outputs, targets):
+        """Performs the matching
+
+        Params:
+            outputs: This is a dict that contains at least these entries:
+                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
+                 "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks
+
+            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
+                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
+                           objects in the target) containing the class labels
+                 "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks
+
+        Returns:
+            A list of size batch_size, containing tuples of (index_i, index_j) where:
+                - index_i is the indices of the selected predictions (in order)
+                - index_j is the indices of the corresponding selected targets (in order)
+            For each batch element, it holds:
+                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
+        """
+        return self.memory_efficient_forward(outputs, targets)
+
+    def __repr__(self):
+        head = "Matcher " + self.__class__.__name__
+        body = [
+            "cost_class: {}".format(self.cost_class),
+            "cost_mask: {}".format(self.cost_mask),
+            "cost_dice: {}".format(self.cost_dice),
+        ]
+        _repr_indent = 4
+        lines = [head] + [" " * _repr_indent + line for line in body]
+        return "\n".join(lines)
diff --git a/open_vocab_seg/modeling/transformer/__init__.py b/open_vocab_seg/modeling/transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..49f9003b7a688f5396170dd89c26ef335a2c201f
--- /dev/null
+++ b/open_vocab_seg/modeling/transformer/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
diff --git a/open_vocab_seg/modeling/transformer/open_vocab_transformer_predictor.py b/open_vocab_seg/modeling/transformer/open_vocab_transformer_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..0efee3e14c71400a1cc5a55ea6c21b6876189aaa
--- /dev/null
+++ b/open_vocab_seg/modeling/transformer/open_vocab_transformer_predictor.py
@@ -0,0 +1,84 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+from torch import nn
+from detectron2.config import configurable
+from .transformer_predictor import TransformerPredictor, MLP
+
+
+class OpenVocabTransformerPredictor(TransformerPredictor):
+    @configurable
+    def __init__(
+        self,
+        in_channels,
+        mask_classification=True,
+        *,
+        embedding_dim: int,
+        embed_hidden_dim: int,
+        embed_layers: int,
+        hidden_dim: int,
+        num_queries: int,
+        nheads: int,
+        dropout: float,
+        dim_feedforward: int,
+        enc_layers: int,
+        dec_layers: int,
+        pre_norm: bool,
+        deep_supervision: bool,
+        mask_dim: int,
+        enforce_input_project: bool,
+    ):
+        super().__init__(
+            in_channels,
+            False,
+            num_classes=embedding_dim,
+            hidden_dim=hidden_dim,
+            num_queries=num_queries,
+            nheads=nheads,
+            dropout=dropout,
+            dim_feedforward=dim_feedforward,
+            enc_layers=enc_layers,
+            dec_layers=dec_layers,
+            pre_norm=pre_norm,
+            deep_supervision=deep_supervision,
+            mask_dim=mask_dim,
+            enforce_input_project=enforce_input_project,
+        )
+        self.mask_classification = mask_classification
+        # output FFNs
+        if self.mask_classification:
+            self.class_embed = MLP(
+                hidden_dim, embed_hidden_dim, embedding_dim, embed_layers
+            )
+
+    def freeze_pretrained(self):
+        for name, module in self.named_children():
+            if name not in ["class_embed"]:
+                for param in module.parameters():
+                    param.requires_grad = False
+
+    @classmethod
+    def from_config(cls, cfg, in_channels, mask_classification):
+        ret = {}
+        ret["in_channels"] = in_channels
+        ret["mask_classification"] = mask_classification
+
+        ret["embedding_dim"] = cfg.MODEL.SEM_SEG_HEAD.EMBEDDING_DIM
+        ret["embed_hidden_dim"] = cfg.MODEL.SEM_SEG_HEAD.EMBED_HIDDEN_DIM
+        ret["embed_layers"] = cfg.MODEL.SEM_SEG_HEAD.EMBED_LAYERS
+        ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
+        ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
+        # Transformer parameters:
+        ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
+        ret["dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT
+        ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
+        ret["enc_layers"] = cfg.MODEL.MASK_FORMER.ENC_LAYERS
+        ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS
+        ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
+        ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
+        ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ
+
+        ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
+
+        return ret
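+
+
+# Note (comments only): unlike the base TransformerPredictor, this predictor does
+# not emit fixed-vocabulary logits. class_embed maps each query to an
+# EMBEDDING_DIM vector that downstream components are expected to compare against
+# text embeddings (e.g. from CLIP), which is what makes the head open-vocabulary.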
diff --git a/open_vocab_seg/modeling/transformer/position_encoding.py b/open_vocab_seg/modeling/transformer/position_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..db236c5b36cbc4f4435a83b542bdc242cbb441c3
--- /dev/null
+++ b/open_vocab_seg/modeling/transformer/position_encoding.py
@@ -0,0 +1,58 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+"""
+Various positional encodings for the transformer.
+"""
+import math
+
+import torch
+from torch import nn
+
+
+class PositionEmbeddingSine(nn.Module):
+    """
+    This is a more standard version of the position embedding, very similar to the one
+    used by the Attention is all you need paper, generalized to work on images.
+    """
+
+    def __init__(
+        self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
+    ):
+        super().__init__()
+        self.num_pos_feats = num_pos_feats
+        self.temperature = temperature
+        self.normalize = normalize
+        if scale is not None and normalize is False:
+            raise ValueError("normalize should be True if scale is passed")
+        if scale is None:
+            scale = 2 * math.pi
+        self.scale = scale
+
+    def forward(self, x, mask=None):
+        if mask is None:
+            mask = torch.zeros(
+                (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
+            )
+        not_mask = ~mask
+        y_embed = not_mask.cumsum(1, dtype=torch.float32)
+        x_embed = not_mask.cumsum(2, dtype=torch.float32)
+        if self.normalize:
+            eps = 1e-6
+            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+        pos_x = x_embed[:, :, :, None] / dim_t
+        pos_y = y_embed[:, :, :, None] / dim_t
+        pos_x = torch.stack(
+            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+        ).flatten(3)
+        pos_y = torch.stack(
+            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+        ).flatten(3)
+        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+        return pos
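+
+
+# Shape sketch (comments only): for an input of shape [B, C, H, W] the returned
+# embedding has shape [B, 2 * num_pos_feats, H, W]; the pixel decoder builds this
+# layer with num_pos_feats = conv_dim // 2 so the encoding width matches the
+# transformer dimension.
+#
+#   pe = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
+#   pos = pe(torch.zeros(2, 256, 16, 16))   # -> torch.Size([2, 256, 16, 16])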
diff --git a/open_vocab_seg/modeling/transformer/transformer.py b/open_vocab_seg/modeling/transformer/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..76d1003b3852ce72c6ad5c3c23705f380197362f
--- /dev/null
+++ b/open_vocab_seg/modeling/transformer/transformer.py
@@ -0,0 +1,380 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+"""
+Transformer class.
+
+Copy-paste from torch.nn.Transformer with modifications:
+    * positional encodings are passed in MHattention
+    * extra LN at the end of encoder is removed
+    * decoder returns a stack of activations from all decoding layers
+"""
+import copy
+from typing import List, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+
+class Transformer(nn.Module):
+    def __init__(
+        self,
+        d_model=512,
+        nhead=8,
+        num_encoder_layers=6,
+        num_decoder_layers=6,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        normalize_before=False,
+        return_intermediate_dec=False,
+    ):
+        super().__init__()
+
+        encoder_layer = TransformerEncoderLayer(
+            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+        )
+        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+        self.encoder = TransformerEncoder(
+            encoder_layer, num_encoder_layers, encoder_norm
+        )
+
+        decoder_layer = TransformerDecoderLayer(
+            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+        )
+        decoder_norm = nn.LayerNorm(d_model)
+        self.decoder = TransformerDecoder(
+            decoder_layer,
+            num_decoder_layers,
+            decoder_norm,
+            return_intermediate=return_intermediate_dec,
+        )
+
+        self._reset_parameters()
+
+        self.d_model = d_model
+        self.nhead = nhead
+
+    def _reset_parameters(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    def forward(self, src, mask, query_embed, pos_embed):
+        # flatten NxCxHxW to HWxNxC
+        bs, c, h, w = src.shape
+        src = src.flatten(2).permute(2, 0, 1)
+        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
+        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
+        if mask is not None:
+            mask = mask.flatten(1)
+
+        tgt = torch.zeros_like(query_embed)
+        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
+        hs = self.decoder(
+            tgt,
+            memory,
+            memory_key_padding_mask=mask,
+            pos=pos_embed,
+            query_pos=query_embed,
+        )
+        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
+
+
+class TransformerEncoder(nn.Module):
+    def __init__(self, encoder_layer, num_layers, norm=None):
+        super().__init__()
+        self.layers = _get_clones(encoder_layer, num_layers)
+        self.num_layers = num_layers
+        self.norm = norm
+
+    def forward(
+        self,
+        src,
+        mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+    ):
+        output = src
+
+        for layer in self.layers:
+            output = layer(
+                output,
+                src_mask=mask,
+                src_key_padding_mask=src_key_padding_mask,
+                pos=pos,
+            )
+
+        if self.norm is not None:
+            output = self.norm(output)
+
+        return output
+
+
+class TransformerDecoder(nn.Module):
+    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
+        super().__init__()
+        self.layers = _get_clones(decoder_layer, num_layers)
+        self.num_layers = num_layers
+        self.norm = norm
+        self.return_intermediate = return_intermediate
+
+    def forward(
+        self,
+        tgt,
+        memory,
+        tgt_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+        tgt_key_padding_mask: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+        query_pos: Optional[Tensor] = None,
+    ):
+        output = tgt
+
+        intermediate = []
+
+        for layer in self.layers:
+            output = layer(
+                output,
+                memory,
+                tgt_mask=tgt_mask,
+                memory_mask=memory_mask,
+                tgt_key_padding_mask=tgt_key_padding_mask,
+                memory_key_padding_mask=memory_key_padding_mask,
+                pos=pos,
+                query_pos=query_pos,
+            )
+            if self.return_intermediate:
+                intermediate.append(self.norm(output))
+
+        if self.norm is not None:
+            output = self.norm(output)
+            if self.return_intermediate:
+                intermediate.pop()
+                intermediate.append(output)
+
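+        # With return_intermediate=True, the normalized output of every decoder
+        # layer is stacked along dim 0 (used for deep supervision); otherwise a
+        # single-element stack of the final output is returned so callers can
+        # index the result uniformly.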
+        if self.return_intermediate:
+            return torch.stack(intermediate)
+
+        return output.unsqueeze(0)
+
+
+class TransformerEncoderLayer(nn.Module):
+    def __init__(
+        self,
+        d_model,
+        nhead,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        normalize_before=False,
+    ):
+        super().__init__()
+        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+
+        self.activation = _get_activation_fn(activation)
+        self.normalize_before = normalize_before
+
+    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+        return tensor if pos is None else tensor + pos
+
+    def forward_post(
+        self,
+        src,
+        src_mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+    ):
+        q = k = self.with_pos_embed(src, pos)
+        src2 = self.self_attn(
+            q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+        )[0]
+        src = src + self.dropout1(src2)
+        src = self.norm1(src)
+        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+        src = src + self.dropout2(src2)
+        src = self.norm2(src)
+        return src
+
+    def forward_pre(
+        self,
+        src,
+        src_mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+    ):
+        src2 = self.norm1(src)
+        q = k = self.with_pos_embed(src2, pos)
+        src2 = self.self_attn(
+            q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+        )[0]
+        src = src + self.dropout1(src2)
+        src2 = self.norm2(src)
+        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
+        src = src + self.dropout2(src2)
+        return src
+
+    def forward(
+        self,
+        src,
+        src_mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+    ):
+        if self.normalize_before:
+            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
+        return self.forward_post(src, src_mask, src_key_padding_mask, pos)
+
+
+class TransformerDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        d_model,
+        nhead,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        normalize_before=False,
+    ):
+        super().__init__()
+        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.norm3 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.dropout3 = nn.Dropout(dropout)
+
+        self.activation = _get_activation_fn(activation)
+        self.normalize_before = normalize_before
+
+    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+        return tensor if pos is None else tensor + pos
+
+    def forward_post(
+        self,
+        tgt,
+        memory,
+        tgt_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+        tgt_key_padding_mask: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+        query_pos: Optional[Tensor] = None,
+    ):
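+        # Post-norm decoder block: self-attention over the queries, cross-attention into
+        # the encoder memory, then the feed-forward network.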
+        q = k = self.with_pos_embed(tgt, query_pos)
+        tgt2 = self.self_attn(
+            q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+        )[0]
+        tgt = tgt + self.dropout1(tgt2)
+        tgt = self.norm1(tgt)
+        tgt2 = self.multihead_attn(
+            query=self.with_pos_embed(tgt, query_pos),
+            key=self.with_pos_embed(memory, pos),
+            value=memory,
+            attn_mask=memory_mask,
+            key_padding_mask=memory_key_padding_mask,
+        )[0]
+        tgt = tgt + self.dropout2(tgt2)
+        tgt = self.norm2(tgt)
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+        tgt = tgt + self.dropout3(tgt2)
+        tgt = self.norm3(tgt)
+        return tgt
+
+    def forward_pre(
+        self,
+        tgt,
+        memory,
+        tgt_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+        tgt_key_padding_mask: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+        query_pos: Optional[Tensor] = None,
+    ):
+        tgt2 = self.norm1(tgt)
+        q = k = self.with_pos_embed(tgt2, query_pos)
+        tgt2 = self.self_attn(
+            q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+        )[0]
+        tgt = tgt + self.dropout1(tgt2)
+        tgt2 = self.norm2(tgt)
+        tgt2 = self.multihead_attn(
+            query=self.with_pos_embed(tgt2, query_pos),
+            key=self.with_pos_embed(memory, pos),
+            value=memory,
+            attn_mask=memory_mask,
+            key_padding_mask=memory_key_padding_mask,
+        )[0]
+        tgt = tgt + self.dropout2(tgt2)
+        tgt2 = self.norm3(tgt)
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+        tgt = tgt + self.dropout3(tgt2)
+        return tgt
+
+    def forward(
+        self,
+        tgt,
+        memory,
+        tgt_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+        tgt_key_padding_mask: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+        query_pos: Optional[Tensor] = None,
+    ):
+        if self.normalize_before:
+            return self.forward_pre(
+                tgt,
+                memory,
+                tgt_mask,
+                memory_mask,
+                tgt_key_padding_mask,
+                memory_key_padding_mask,
+                pos,
+                query_pos,
+            )
+        return self.forward_post(
+            tgt,
+            memory,
+            tgt_mask,
+            memory_mask,
+            tgt_key_padding_mask,
+            memory_key_padding_mask,
+            pos,
+            query_pos,
+        )
+
+
+def _get_clones(module, N):
+    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+
+def _get_activation_fn(activation):
+    """Return an activation function given a string"""
+    if activation == "relu":
+        return F.relu
+    if activation == "gelu":
+        return F.gelu
+    if activation == "glu":
+        return F.glu
+    raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
diff --git a/open_vocab_seg/modeling/transformer/transformer_predictor.py b/open_vocab_seg/modeling/transformer/transformer_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..72378abe29c01809a00fa1b87d275258ee9c91fa
--- /dev/null
+++ b/open_vocab_seg/modeling/transformer/transformer_predictor.py
@@ -0,0 +1,179 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from: https://github.com/facebookresearch/detr/blob/master/models/detr.py
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+import fvcore.nn.weight_init as weight_init
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from detectron2.config import configurable
+from detectron2.layers import Conv2d
+
+from .position_encoding import PositionEmbeddingSine
+from .transformer import Transformer
+
+
+class TransformerPredictor(nn.Module):
+    @configurable
+    def __init__(
+        self,
+        in_channels,
+        mask_classification=True,
+        *,
+        num_classes: int,
+        hidden_dim: int,
+        num_queries: int,
+        nheads: int,
+        dropout: float,
+        dim_feedforward: int,
+        enc_layers: int,
+        dec_layers: int,
+        pre_norm: bool,
+        deep_supervision: bool,
+        mask_dim: int,
+        enforce_input_project: bool,
+    ):
+        """
+        NOTE: this interface is experimental.
+        Args:
+            in_channels: channels of the input features
+            mask_classification: whether to add mask classifier or not
+            num_classes: number of classes
+            hidden_dim: Transformer feature dimension
+            num_queries: number of queries
+            nheads: number of heads
+            dropout: dropout in Transformer
+            dim_feedforward: feature dimension in feedforward network
+            enc_layers: number of Transformer encoder layers
+            dec_layers: number of Transformer decoder layers
+            pre_norm: whether to use pre-LayerNorm or not
+            deep_supervision: whether to add supervision to every decoder layer
+            mask_dim: mask feature dimension
+            enforce_input_project: add an input projection 1x1 conv even if the input
+                channels and hidden dim are identical
+        """
+        super().__init__()
+
+        self.mask_classification = mask_classification
+
+        # positional encoding
+        N_steps = hidden_dim // 2
+        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)
+
+        transformer = Transformer(
+            d_model=hidden_dim,
+            dropout=dropout,
+            nhead=nheads,
+            dim_feedforward=dim_feedforward,
+            num_encoder_layers=enc_layers,
+            num_decoder_layers=dec_layers,
+            normalize_before=pre_norm,
+            return_intermediate_dec=deep_supervision,
+        )
+
+        self.num_queries = num_queries
+        self.transformer = transformer
+        hidden_dim = transformer.d_model
+
+        self.query_embed = nn.Embedding(num_queries, hidden_dim)
+
+        if in_channels != hidden_dim or enforce_input_project:
+            self.input_proj = Conv2d(in_channels, hidden_dim, kernel_size=1)
+            weight_init.c2_xavier_fill(self.input_proj)
+        else:
+            self.input_proj = nn.Sequential()
+        self.aux_loss = deep_supervision
+
+        # output FFNs
+        if self.mask_classification:
+            self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
+        self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3)
+
+    @classmethod
+    def from_config(cls, cfg, in_channels, mask_classification):
+        ret = {}
+        ret["in_channels"] = in_channels
+        ret["mask_classification"] = mask_classification
+
+        ret["num_classes"] = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
+        ret["hidden_dim"] = cfg.MODEL.MASK_FORMER.HIDDEN_DIM
+        ret["num_queries"] = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
+        # Transformer parameters:
+        ret["nheads"] = cfg.MODEL.MASK_FORMER.NHEADS
+        ret["dropout"] = cfg.MODEL.MASK_FORMER.DROPOUT
+        ret["dim_feedforward"] = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD
+        ret["enc_layers"] = cfg.MODEL.MASK_FORMER.ENC_LAYERS
+        ret["dec_layers"] = cfg.MODEL.MASK_FORMER.DEC_LAYERS
+        ret["pre_norm"] = cfg.MODEL.MASK_FORMER.PRE_NORM
+        ret["deep_supervision"] = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
+        ret["enforce_input_project"] = cfg.MODEL.MASK_FORMER.ENFORCE_INPUT_PROJ
+
+        ret["mask_dim"] = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM
+
+        return ret
+
+    def forward(self, x, mask_features):
+        pos = self.pe_layer(x)
+
+        src = x
+        mask = None
+        hs, memory = self.transformer(
+            self.input_proj(src), mask, self.query_embed.weight, pos
+        )
+
+        if self.mask_classification:
+            outputs_class = self.class_embed(hs)
+            out = {"pred_logits": outputs_class[-1]}
+        else:
+            out = {}
+
+        if self.aux_loss:
+            # [l, bs, queries, embed]
+            mask_embed = self.mask_embed(hs)
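+            # Dot each per-query mask embedding with the per-pixel mask features to obtain
+            # mask logits for every intermediate decoder layer (deep supervision).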
+            outputs_seg_masks = torch.einsum(
+                "lbqc,bchw->lbqhw", mask_embed, mask_features
+            )
+            out["pred_masks"] = outputs_seg_masks[-1]
+            out["aux_outputs"] = self._set_aux_loss(
+                outputs_class if self.mask_classification else None, outputs_seg_masks
+            )
+        else:
+            # FIXME h_boxes takes the last one computed, keep this in mind
+            # [bs, queries, embed]
+            mask_embed = self.mask_embed(hs[-1])
+            outputs_seg_masks = torch.einsum(
+                "bqc,bchw->bqhw", mask_embed, mask_features
+            )
+            out["pred_masks"] = outputs_seg_masks
+        return out
+
+    @torch.jit.unused
+    def _set_aux_loss(self, outputs_class, outputs_seg_masks):
+        # this is a workaround to make torchscript happy, as torchscript
+        # doesn't support dictionary with non-homogeneous values, such
+        # as a dict having both a Tensor and a list.
+        if self.mask_classification:
+            return [
+                {"pred_logits": a, "pred_masks": b}
+                for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1])
+            ]
+        else:
+            return [{"pred_masks": b} for b in outputs_seg_masks[:-1]]
+
+
+class MLP(nn.Module):
+    """Very simple multi-layer perceptron (also called FFN)"""
+
+    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(
+            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+        )
+
+    def forward(self, x):
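+        # ReLU between hidden layers; the final projection is left linear.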
+        for i, layer in enumerate(self.layers):
+            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+        return x
diff --git a/open_vocab_seg/ovseg_model.py b/open_vocab_seg/ovseg_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..48df93168a5bdfb831715f7b8c008d7b7a5d3814
--- /dev/null
+++ b/open_vocab_seg/ovseg_model.py
@@ -0,0 +1,460 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+# Modified by Feng Liang from
+# https://github.com/MendelXu/zsseg.baseline/blob/master/mask_former/zero_shot_mask_former_model.py
+
+import logging
+from typing import Tuple
+
+import numpy as np
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from detectron2.config import configurable
+from detectron2.data import MetadataCatalog
+from detectron2.modeling import META_ARCH_REGISTRY
+from detectron2.modeling.backbone import Backbone
+from detectron2.modeling.postprocessing import sem_seg_postprocess
+from detectron2.structures import ImageList
+from detectron2.utils.logger import log_first_n
+from .modeling.clip_adapter import (
+    ClipAdapter,
+    MaskFormerClipAdapter,
+    build_text_prompt,
+)
+from .mask_former_model import MaskFormer
+from .utils.misc import get_gt_binary_masks
+
+@META_ARCH_REGISTRY.register()
+class OVSeg(MaskFormer):
+    """
+    Main class for zero shot mask classification semantic segmentation architectures.
+    """
+
+    @configurable
+    def __init__(
+        self,
+        *,
+        backbone: Backbone,
+        sem_seg_head: nn.Module,
+        clip_adapter: nn.Module,
+        criterion: nn.Module,
+        num_queries: int,
+        panoptic_on: bool,
+        object_mask_threshold: float,
+        overlap_threshold: float,
+        metadata,
+        size_divisibility: int,
+        sem_seg_postprocess_before_inference: bool,
+        clip_ensemble: bool,
+        clip_ensemble_weight: float,
+        pixel_mean: Tuple[float],
+        pixel_std: Tuple[float],
+    ):
+        """
+        Args:
+            backbone: a backbone module, must follow detectron2's backbone interface
+            sem_seg_head: a module that predicts semantic segmentation from backbone features
+            criterion: a module that defines the loss
+            clip_adapter: adapter for clip-based mask classification
+            num_queries: int, number of queries
+            panoptic_on: bool, whether to output panoptic segmentation prediction
+            object_mask_threshold: float, threshold to filter query based on classification score
+                for panoptic segmentation inference
+            overlap_threshold: overlap threshold used in general inference for panoptic segmentation
+            metadata: dataset meta, get `thing` and `stuff` category names for panoptic
+                segmentation inference
+            size_divisibility: Some backbones require the input height and width to be divisible by a
+                specific integer. We can use this to override such a requirement.
+            sem_seg_postprocess_before_inference: whether to resize the prediction back
+                to original input size before semantic segmentation inference or after.
+                For high-resolution datasets like Mapillary, resizing predictions before
+                inference will cause OOM error.
+            pixel_mean, pixel_std: list or tuple with #channels elements, representing
+                the per-channel mean and std to be used to normalize the input image
+        """
+        super().__init__(
+            backbone=backbone,
+            sem_seg_head=sem_seg_head,
+            criterion=criterion,
+            num_queries=num_queries,
+            panoptic_on=panoptic_on,
+            object_mask_threshold=object_mask_threshold,
+            overlap_threshold=overlap_threshold,
+            metadata=metadata,
+            size_divisibility=size_divisibility,
+            sem_seg_postprocess_before_inference=sem_seg_postprocess_before_inference,
+            pixel_mean=pixel_mean,
+            pixel_std=pixel_std,
+        )
+        self.clip_adapter: ClipAdapter = clip_adapter
+
+        self.clip_ensemble: bool = clip_ensemble
+        self.clip_ensemble_weight: float = clip_ensemble_weight
+
+    @classmethod
+    def from_config(cls, cfg):
+        init_kwargs = MaskFormer.from_config(cfg)
+        text_templates = build_text_prompt(cfg.MODEL.CLIP_ADAPTER)
+
+        clip_adapter = MaskFormerClipAdapter(
+            cfg.MODEL.CLIP_ADAPTER.CLIP_MODEL_NAME,
+            text_templates,
+            mask_fill=cfg.MODEL.CLIP_ADAPTER.MASK_FILL,
+            mask_expand_ratio=cfg.MODEL.CLIP_ADAPTER.MASK_EXPAND_RATIO,
+            mask_thr=cfg.MODEL.CLIP_ADAPTER.MASK_THR,
+            mask_matting=cfg.MODEL.CLIP_ADAPTER.MASK_MATTING,
+            region_resized=cfg.MODEL.CLIP_ADAPTER.REGION_RESIZED,
+            mask_prompt_depth=cfg.MODEL.CLIP_ADAPTER.MASK_PROMPT_DEPTH,
+            mask_prompt_fwd=cfg.MODEL.CLIP_ADAPTER.MASK_PROMPT_FWD,
+        )
+        init_kwargs["clip_adapter"] = clip_adapter
+        init_kwargs["clip_ensemble"] = cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE
+        init_kwargs[
+            "clip_ensemble_weight"
+        ] = cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE_WEIGHT
+
+        return init_kwargs
+
+    def forward(self, batched_inputs):
+        """
+        Args:
+            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
+                Each item in the list contains the inputs for one image.
+                For now, each item in the list is a dict that contains:
+                   * "image": Tensor, image in (C, H, W) format.
+                   * "instances": per-region ground truth
+                   * Other information that's included in the original dicts, such as:
+                     "height", "width" (int): the output resolution of the model (may be different
+                     from input resolution), used in inference.
+        Returns:
+            list[dict]:
+                each dict has the results for one image. The dict contains the following keys:
+
+                * "sem_seg":
+                    A Tensor that represents the
+                    per-pixel segmentation predicted by the head.
+                    The prediction has shape KxHxW that represents the logits of
+                    each class for each pixel.
+                * "panoptic_seg":
+                    A tuple that represents the panoptic output
+                    panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
+                    segments_info (list[dict]): Describe each segment in `panoptic_seg`.
+                        Each dict contains keys "id", "category_id", "isthing".
+        """
+        dataset_name = [x["meta"]["dataset_name"] for x in batched_inputs]
+        assert len(set(dataset_name)) == 1
+        dataset_name = dataset_name[0]
+
+        images = [x["image"].to(self.device) for x in batched_inputs]
+        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+        images = ImageList.from_tensors(images, self.size_divisibility)
+
+        features = self.backbone(images.tensor)
+        outputs = self.sem_seg_head(features)
+        class_names = self.get_class_name_list(dataset_name)
+        text_features = self.clip_adapter.get_text_features(class_names)
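+        # Score the head's per-query outputs against CLIP text embeddings of the class
+        # names; get_sim_logits turns the normalized features into similarity logits.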
+        outputs["pred_logits"] = self.clip_adapter.get_sim_logits(
+            text_features, self.clip_adapter.normalize_feature(outputs["pred_logits"])
+        )
+        if self.training:
+            if "aux_outputs" in outputs.keys():
+                for i in range(len(outputs["aux_outputs"])):
+                    outputs["aux_outputs"][i][
+                        "pred_logits"
+                    ] = self.clip_adapter.get_sim_logits(
+                        text_features,
+                        self.clip_adapter.normalize_feature(
+                            outputs["aux_outputs"][i]["pred_logits"]
+                        ),
+                    )
+            # mask classification target
+            if "instances" in batched_inputs[0]:
+                gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
+                targets = self.prepare_targets(gt_instances, images)
+            else:
+                targets = None
+
+            # bipartite matching-based loss
+            losses = self.criterion(outputs, targets)
+
+            for k in list(losses.keys()):
+                if k in self.criterion.weight_dict:
+                    losses[k] *= self.criterion.weight_dict[k]
+                else:
+                    # remove this loss if not specified in `weight_dict`
+                    losses.pop(k)
+
+            return losses
+        else:
+            mask_cls_results = outputs["pred_logits"]
+            mask_pred_results = outputs["pred_masks"]
+            # upsample masks
+            mask_pred_results = F.interpolate(
+                mask_pred_results,
+                size=(images.tensor.shape[-2], images.tensor.shape[-1]),
+                mode="bilinear",
+                align_corners=False,
+            )
+
+            processed_results = []
+            for mask_cls_result, mask_pred_result, input_per_image, image_size in zip(
+                mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes
+            ):
+                height = image_size[0]
+                width = image_size[1]
+                mask_pred_result = sem_seg_postprocess(
+                    mask_pred_result, image_size, height, width
+                )
+                image = input_per_image["image"].to(self.device)
+
+                r, regions = self.semantic_inference(
+                    mask_cls_result, mask_pred_result, image, class_names
+                )
+
+                height = input_per_image.get("height", image_size[0])
+                width = input_per_image.get("width", image_size[1])
+                r = sem_seg_postprocess(r, image_size, height, width)
+                processed_results.append({"sem_seg": r})
+
+                # panoptic segmentation inference
+                if self.panoptic_on:
+                    panoptic_r = self.panoptic_inference(
+                        mask_cls_result, mask_pred_result
+                    )
+                    processed_results[-1]["panoptic_seg"] = panoptic_r
+
+            return processed_results
+
+    def semantic_inference(self, mask_cls, mask_pred, image, class_names):
+        mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
+        mask_pred = mask_pred.sigmoid()
+
+        regions = None
+        if self.clip_ensemble:
+            clip_cls, regions, valid_flag = self.clip_adapter(
+                image, class_names, mask_pred, normalize=True
+            )
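+            # The CLIP adapter crops each predicted mask region from the image and
+            # classifies it with CLIP; valid_flag marks the masks that yielded a usable region.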
+            if clip_cls is None:
+                clip_cls = torch.empty(0, mask_cls.shape[-1] + 1, device=self.device)
+            # softmax before index or after?
+            clip_cls = F.softmax(clip_cls[:, :-1], dim=-1)
+            if self.clip_ensemble_weight > 0:
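+                # Geometric ensemble: map CLIP probabilities back onto the full query set
+                # (queries without a valid region keep a neutral score of 1), then blend the
+                # two distributions with exponents (1 - w) and w.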
+                map_back_clip_cls = mask_cls.new_ones(mask_cls.shape)
+                map_back_clip_cls[valid_flag] = clip_cls
+                mask_cls = torch.pow(mask_cls, 1 - self.clip_ensemble_weight) * \
+                           torch.pow(map_back_clip_cls, self.clip_ensemble_weight)
+            else:
+                # only clip model predictions are used
+                mask_cls = clip_cls
+                mask_pred = mask_pred[valid_flag]
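+        # Weighted sum over queries: per-query class scores times per-pixel mask
+        # probabilities gives a C x H x W semantic score map.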
+        semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred)
+        return semseg, regions
+
+    def get_class_name_list(self, dataset_name):
+        class_names = [
+            c.strip() for c in MetadataCatalog.get(dataset_name).stuff_classes
+        ]
+        return class_names
+
+
+@META_ARCH_REGISTRY.register()
+class OVSegDEMO(MaskFormer):
+    """
+    Main class for zero shot mask classification semantic segmentation architectures.
+    """
+
+    @configurable
+    def __init__(
+        self,
+        *,
+        backbone: Backbone,
+        sem_seg_head: nn.Module,
+        clip_adapter: nn.Module,
+        criterion: nn.Module,
+        num_queries: int,
+        panoptic_on: bool,
+        object_mask_threshold: float,
+        overlap_threshold: float,
+        metadata,
+        size_divisibility: int,
+        sem_seg_postprocess_before_inference: bool,
+        clip_ensemble: bool,
+        clip_ensemble_weight: float,
+        pixel_mean: Tuple[float],
+        pixel_std: Tuple[float],
+    ):
+        """
+        Args:
+            backbone: a backbone module, must follow detectron2's backbone interface
+            sem_seg_head: a module that predicts semantic segmentation from backbone features
+            criterion: a module that defines the loss
+            clip_adapter: adapter for clip-based mask classification
+            num_queries: int, number of queries
+            panoptic_on: bool, whether to output panoptic segmentation prediction
+            object_mask_threshold: float, threshold to filter query based on classification score
+                for panoptic segmentation inference
+            overlap_threshold: overlap threshold used in general inference for panoptic segmentation
+            metadata: dataset meta, get `thing` and `stuff` category names for panoptic
+                segmentation inference
+            size_divisibility: Some backbones require the input height and width to be divisible by a
+                specific integer. We can use this to override such a requirement.
+            sem_seg_postprocess_before_inference: whether to resize the prediction back
+                to original input size before semantic segmentation inference or after.
+                For high-resolution datasets like Mapillary, resizing predictions before
+                inference will cause OOM error.
+            pixel_mean, pixel_std: list or tuple with #channels elements, representing
+                the per-channel mean and std to be used to normalize the input image
+        """
+        super().__init__(
+            backbone=backbone,
+            sem_seg_head=sem_seg_head,
+            criterion=criterion,
+            num_queries=num_queries,
+            panoptic_on=panoptic_on,
+            object_mask_threshold=object_mask_threshold,
+            overlap_threshold=overlap_threshold,
+            metadata=metadata,
+            size_divisibility=size_divisibility,
+            sem_seg_postprocess_before_inference=sem_seg_postprocess_before_inference,
+            pixel_mean=pixel_mean,
+            pixel_std=pixel_std,
+        )
+        self.clip_adapter: ClipAdapter = clip_adapter
+
+        self.clip_ensemble: bool = clip_ensemble
+        self.clip_ensemble_weight: float = clip_ensemble_weight
+
+    @classmethod
+    def from_config(cls, cfg):
+        init_kwargs = MaskFormer.from_config(cfg)
+        text_templates = build_text_prompt(cfg.MODEL.CLIP_ADAPTER)
+
+        clip_adapter = MaskFormerClipAdapter(
+            cfg.MODEL.CLIP_ADAPTER.CLIP_MODEL_NAME,
+            text_templates,
+            mask_fill=cfg.MODEL.CLIP_ADAPTER.MASK_FILL,
+            mask_expand_ratio=cfg.MODEL.CLIP_ADAPTER.MASK_EXPAND_RATIO,
+            mask_thr=cfg.MODEL.CLIP_ADAPTER.MASK_THR,
+            mask_matting=cfg.MODEL.CLIP_ADAPTER.MASK_MATTING,
+            region_resized=cfg.MODEL.CLIP_ADAPTER.REGION_RESIZED,
+            mask_prompt_depth=cfg.MODEL.CLIP_ADAPTER.MASK_PROMPT_DEPTH,
+            mask_prompt_fwd=cfg.MODEL.CLIP_ADAPTER.MASK_PROMPT_FWD,
+        )
+        init_kwargs["clip_adapter"] = clip_adapter
+        init_kwargs["clip_ensemble"] = cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE
+        init_kwargs[
+            "clip_ensemble_weight"
+        ] = cfg.MODEL.CLIP_ADAPTER.CLIP_ENSEMBLE_WEIGHT
+
+        return init_kwargs
+
+    def forward(self, batched_inputs):
+        """
+        Args:
+            batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
+                Each item in the list contains the inputs for one image.
+                For now, each item in the list is a dict that contains:
+                   * "image": Tensor, image in (C, H, W) format.
+                   * "instances": per-region ground truth
+                   * Other information that's included in the original dicts, such as:
+                     "height", "width" (int): the output resolution of the model (may be different
+                     from input resolution), used in inference.
+        Returns:
+            list[dict]:
+                each dict has the results for one image. The dict contains the following keys:
+
+                * "sem_seg":
+                    A Tensor that represents the
+                    per-pixel segmentation predicted by the head.
+                    The prediction has shape KxHxW that represents the logits of
+                    each class for each pixel.
+                * "panoptic_seg":
+                    A tuple that represents the panoptic output
+                    panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment.
+                    segments_info (list[dict]): Describe each segment in `panoptic_seg`.
+                        Each dict contains keys "id", "category_id", "isthing".
+        """
+        images = [x["image"].to(self.device) for x in batched_inputs]
+        images = [(x - self.pixel_mean) / self.pixel_std for x in images]
+        images = ImageList.from_tensors(images, self.size_divisibility)
+
+        features = self.backbone(images.tensor)
+        outputs = self.sem_seg_head(features)
+        class_names = batched_inputs[0]["class_names"]
+        if len(class_names) == 1:
+            # Classification is performed in a 'contrastive' manner, so append 'others' to stand in for all remaining concepts
+            class_names.append('others')
+        text_features = self.clip_adapter.get_text_features(class_names)
+        outputs["pred_logits"] = self.clip_adapter.get_sim_logits(
+            text_features, self.clip_adapter.normalize_feature(outputs["pred_logits"])
+        )
+        mask_cls_results = outputs["pred_logits"]
+        mask_pred_results = outputs["pred_masks"]
+        # upsample masks
+        mask_pred_results = F.interpolate(
+            mask_pred_results,
+            size=(images.tensor.shape[-2], images.tensor.shape[-1]),
+            mode="bilinear",
+            align_corners=False,
+        )
+
+        processed_results = []
+        for mask_cls_result, mask_pred_result, input_per_image, image_size in zip(
+            mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes
+        ):
+            height = image_size[0]
+            width = image_size[1]
+            mask_pred_result = sem_seg_postprocess(
+                mask_pred_result, image_size, height, width
+            )
+            image = input_per_image["image"].to(self.device)
+
+            r, regions = self.demo_inference(mask_cls_result, mask_pred_result, image, class_names)
+
+            height = input_per_image.get("height", image_size[0])
+            width = input_per_image.get("width", image_size[1])
+            r = sem_seg_postprocess(r, image_size, height, width)
+            processed_results.append({"sem_seg": r})
+
+        return processed_results
+
+    def demo_inference(self, mask_cls, mask_pred, image, class_names):
+        mask_cls = F.softmax(mask_cls, dim=-1)[..., :-1]
+        mask_pred = mask_pred.sigmoid()
+
+        regions = None
+        if self.clip_ensemble:
+            clip_cls, regions, valid_flag = self.clip_adapter(
+                image, class_names, mask_pred, normalize=True
+            )
+            if clip_cls is None:
+                clip_cls = torch.empty(0, mask_cls.shape[-1] + 1, device=self.device)
+            # softmax before index or after?
+            clip_cls = F.softmax(clip_cls[:, :-1], dim=-1)
+            if self.clip_ensemble_weight > 0:
+                map_back_clip_cls = mask_cls.new_ones(mask_cls.shape)
+                map_back_clip_cls[valid_flag] = clip_cls
+                mask_cls = torch.pow(mask_cls, 1 - self.clip_ensemble_weight) * \
+                           torch.pow(map_back_clip_cls, self.clip_ensemble_weight)
+
+            else:
+                # only clip model predictions are used
+                mask_cls = clip_cls
+                mask_pred = mask_pred[valid_flag]
+        bin_mask = mask_pred > self.clip_adapter.mask_thr
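+        # For each class keep only the highest-scoring query's class scores; all other
+        # queries stay zero, so only the selected binarized masks shape the final map.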
+        select_cls = torch.zeros(sum(valid_flag), mask_cls.shape[-1], device=self.device)
+        select_mask = torch.argmax(mask_cls, dim=0)
+        if len(class_names) == 2 and class_names[-1] == 'others':
+            select_mask = select_mask[:-1]
+        for idx in select_mask:
+            select_cls[idx] = mask_cls[idx]
+        semseg = torch.einsum("qc,qhw->chw", select_cls, bin_mask.float())
+        return semseg, regions
diff --git a/open_vocab_seg/test_time_augmentation.py b/open_vocab_seg/test_time_augmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb7a51f28419c59775013c74fdee49e5166bde51
--- /dev/null
+++ b/open_vocab_seg/test_time_augmentation.py
@@ -0,0 +1,217 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+import copy
+from itertools import count
+import math
+import numpy as np
+import torch
+from fvcore.transforms import HFlipTransform
+from torch import nn
+from torch.nn.parallel import DistributedDataParallel
+
+from detectron2.data.detection_utils import read_image
+from detectron2.modeling import DatasetMapperTTA
+from detectron2.modeling.postprocessing import sem_seg_postprocess
+import logging
+from detectron2.utils.logger import log_every_n, log_first_n
+
+__all__ = [
+    "SemanticSegmentorWithTTA",
+]
+
+
+class SemanticSegmentorWithTTA(nn.Module):
+    """
+    A SemanticSegmentor with test-time augmentation enabled.
+    Its :meth:`__call__` method has the same interface as :meth:`SemanticSegmentor.forward`.
+    """
+
+    def __init__(self, cfg, model, tta_mapper=None, batch_size=1):
+        """
+        Args:
+            cfg (CfgNode):
+            model (SemanticSegmentor): a SemanticSegmentor to apply TTA on.
+            tta_mapper (callable): takes a dataset dict and returns a list of
+                augmented versions of the dataset dict. Defaults to
+                `DatasetMapperTTA(cfg)`.
+            batch_size (int): batch the augmented images into this batch size for inference.
+        """
+        super().__init__()
+        if isinstance(model, DistributedDataParallel):
+            model = model.module
+        self.cfg = cfg.clone()
+
+        self.model = model
+
+        if tta_mapper is None:
+            tta_mapper = DatasetMapperTTA(cfg)
+        self.tta_mapper = tta_mapper
+        self.batch_size = batch_size
+
+    def _inference_with_model(self, inputs):
+        if self.cfg.TEST.SLIDING_WINDOW:
+            log_first_n(logging.INFO, "Using sliding window to test")
+
+            outputs = []
+
+            for input in inputs:
+                image_size = input["image"].shape[1:]  # h,w
+                if self.cfg.TEST.SLIDING_TILE_SIZE > 0:
+                    tile_size = (
+                        self.cfg.TEST.SLIDING_TILE_SIZE,
+                        self.cfg.TEST.SLIDING_TILE_SIZE,
+                    )
+                else:
+                    selected_mapping = {256: 224, 512: 256, 768: 512, 896: 512}
+                    tile_size = min(image_size)
+                    tile_size = selected_mapping[tile_size]
+                    tile_size = (tile_size, tile_size)
+                extra_info = {
+                    k: v
+                    for k, v in input.items()
+                    if k not in ["image", "height", "width"]
+                }
+                log_every_n(
+                    logging.INFO, "split {} to {}".format(image_size, tile_size)
+                )
+                overlap = self.cfg.TEST.SLIDING_OVERLAP
+                stride = math.ceil(tile_size[0] * (1 - overlap))
+                tile_rows = int(
+                    math.ceil((image_size[0] - tile_size[0]) / stride) + 1
+                )  # strided convolution formula
+                tile_cols = int(math.ceil((image_size[1] - tile_size[1]) / stride) + 1)
+                full_probs = None
+                count_predictions = None
+                tile_counter = 0
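+                # Slide a tile window over the image with the configured overlap, pad edge
+                # tiles to the full tile size, and average the per-pixel logits across all
+                # tiles that cover a pixel.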
+
+                for row in range(tile_rows):
+                    for col in range(tile_cols):
+                        x1 = int(col * stride)
+                        y1 = int(row * stride)
+                        x2 = min(x1 + tile_size[1], image_size[1])
+                        y2 = min(y1 + tile_size[0], image_size[0])
+                        x1 = max(
+                            int(x2 - tile_size[1]), 0
+                        )  # for portrait images the x1 underflows sometimes
+                        y1 = max(
+                            int(y2 - tile_size[0]), 0
+                        )  # for very few rows y1 underflows
+
+                        img = input["image"][:, y1:y2, x1:x2]
+                        padded_img = nn.functional.pad(
+                            img,
+                            (
+                                0,
+                                tile_size[1] - img.shape[-1],
+                                0,
+                                tile_size[0] - img.shape[-2],
+                            ),
+                        )
+                        tile_counter += 1
+                        padded_input = {"image": padded_img}
+                        padded_input.update(extra_info)
+                        padded_prediction = self.model([padded_input])[0]["sem_seg"]
+                        prediction = padded_prediction[
+                            :, 0 : img.shape[1], 0 : img.shape[2]
+                        ]
+                        if full_probs is None:
+                            full_probs = prediction.new_zeros(
+                                prediction.shape[0], image_size[0], image_size[1]
+                            )
+                        if count_predictions is None:
+                            count_predictions = prediction.new_zeros(
+                                prediction.shape[0], image_size[0], image_size[1]
+                            )
+                        count_predictions[:, y1:y2, x1:x2] += 1
+                        full_probs[
+                            :, y1:y2, x1:x2
+                        ] += prediction  # accumulate the predictions also in the overlapping regions
+
+                full_probs /= count_predictions
+                full_probs = sem_seg_postprocess(
+                    full_probs,
+                    image_size,
+                    input.get("height", image_size[0]),
+                    input.get("width", image_size[1]),
+                )
+                outputs.append({"sem_seg": full_probs})
+
+            return outputs
+        else:
+            log_first_n(logging.INFO, "Using whole image to test")
+            return self.model(inputs)
+
+    def _batch_inference(self, batched_inputs):
+        """
+        Execute inference on a list of inputs,
+        using batch size = self.batch_size, instead of the length of the list.
+        Inputs & outputs have the same format as :meth:`SemanticSegmentor.forward`
+        """
+        outputs = []
+        inputs = []
+        for idx, input in zip(count(), batched_inputs):
+            inputs.append(input)
+            if len(inputs) == self.batch_size or idx == len(batched_inputs) - 1:
+                with torch.no_grad():
+                    outputs.extend(self._inference_with_model(inputs))
+                inputs = []
+        return outputs
+
+    def __call__(self, batched_inputs):
+        """
+        Same input/output format as :meth:`SemanticSegmentor.forward`
+        """
+
+        def _maybe_read_image(dataset_dict):
+            ret = copy.copy(dataset_dict)
+            if "image" not in ret:
+                image = read_image(ret.pop("file_name"), self.model.input_format)
+                image = torch.from_numpy(
+                    np.ascontiguousarray(image.transpose(2, 0, 1))
+                )  # CHW
+                ret["image"] = image
+            if "height" not in ret and "width" not in ret:
+                ret["height"] = image.shape[1]
+                ret["width"] = image.shape[2]
+            return ret
+
+        return [self._inference_one_image(_maybe_read_image(x)) for x in batched_inputs]
+
+    def _inference_one_image(self, input):
+        """
+        Args:
+            input (dict): one dataset dict with "image" field being a CHW tensor
+        Returns:
+            dict: one output dict
+        """
+        augmented_inputs, tfms = self._get_augmented_inputs(input)
+        # 1: forward with all augmented images
+        outputs = self._batch_inference(augmented_inputs)
+        # Delete now useless variables to avoid being out of memory
+        del augmented_inputs
+        # 2: merge the results
+        # handle flip specially
+        # outputs = [output.detach() for output in outputs]
+        return self._merge_auged_output(outputs, tfms)
+
+    def _merge_auged_output(self, outputs, tfms):
+        new_outputs = []
+        for output, tfm in zip(outputs, tfms):
+            if any(isinstance(t, HFlipTransform) for t in tfm.transforms):
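+                # Undo the horizontal flip so this prediction is aligned with the original
+                # image before averaging (sem_seg is C x H x W, so flip the width axis).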
+                new_outputs.append(output["sem_seg"].flip(dims=[2]))
+            else:
+                new_outputs.append(output["sem_seg"])
+        del outputs
+        # to avoid OOM with torch.stack
+        final_predictions = new_outputs[0]
+        for i in range(1, len(new_outputs)):
+            final_predictions += new_outputs[i]
+        final_predictions = final_predictions / len(new_outputs)
+        del new_outputs
+        return {"sem_seg": final_predictions}
+
+    def _get_augmented_inputs(self, input):
+        augmented_inputs = self.tta_mapper(input)
+        tfms = [x.pop("transforms") for x in augmented_inputs]
+        return augmented_inputs, tfms
diff --git a/open_vocab_seg/utils/__init__.py b/open_vocab_seg/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcf832dce405bbdcf45f2534a782494b37760cd9
--- /dev/null
+++ b/open_vocab_seg/utils/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+from .events import setup_wandb, WandbWriter
+from .predictor import VisualizationDemo, SAMVisualizationDemo
\ No newline at end of file
diff --git a/open_vocab_seg/utils/events.py b/open_vocab_seg/utils/events.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbe82ce80a7110a1018167763ba3adc90f58faa0
--- /dev/null
+++ b/open_vocab_seg/utils/events.py
@@ -0,0 +1,121 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+import os
+import wandb
+from detectron2.utils import comm
+from detectron2.utils.events import EventWriter, get_event_storage
+
+
+def setup_wandb(cfg, args):
+    if comm.is_main_process():
+        init_args = {
+            k.lower(): v
+            for k, v in cfg.WANDB.items()
+            if isinstance(k, str) and k not in ["config", "name"]
+        }
+        # only include the most relevant parts of the config to avoid an overly large table
+        # TODO: add configurable params to select which part of `cfg` should be saved in config
+        if "config_exclude_keys" in init_args:
+            init_args["config"] = cfg
+            init_args["config"]["cfg_file"] = args.config_file
+        else:
+            init_args["config"] = {
+                "model": cfg.MODEL,
+                "solver": cfg.SOLVER,
+                "cfg_file": args.config_file,
+            }
+        if ("name" not in init_args) or (init_args["name"] is None):
+            init_args["name"] = os.path.basename(args.config_file)
+        wandb.init(**init_args)
+
+
+class BaseRule(object):
+    def __call__(self, target):
+        return target
+
+
+class IsIn(BaseRule):
+    def __init__(self, keyword: str):
+        self.keyword = keyword
+
+    def __call__(self, target):
+        return self.keyword in target
+
+
+class Prefix(BaseRule):
+    def __init__(self, keyword: str):
+        self.keyword = keyword
+
+    def __call__(self, target):
+        return "/".join([self.keyword, target])
+
+
+class WandbWriter(EventWriter):
+    """
+    Write scalars, images and histograms to Weights & Biases (wandb).
+    """
+
+    def __init__(self):
+        """
+        Scalar names that already contain "/" are logged as-is; names containing
+        "loss" are grouped under the "train/" prefix (see `_group_rules`).
+        """
+        self._last_write = -1
+        self._group_rules = [
+            (IsIn("/"), BaseRule()),
+            (IsIn("loss"), Prefix("train")),
+        ]
+
+    def write(self):
+
+        storage = get_event_storage()
+
+        def _group_name(scalar_name):
+            for (rule, op) in self._group_rules:
+                if rule(scalar_name):
+                    return op(scalar_name)
+            return scalar_name
+
+        stats = {
+            _group_name(name): scalars[0]
+            for name, scalars in storage.latest().items()
+            if scalars[1] > self._last_write
+        }
+        if len(stats) > 0:
+            self._last_write = max([v[1] for k, v in storage.latest().items()])
+
+        # storage.put_{image,histogram} is only meant to be used by
+        # tensorboard writer. So we access its internal fields directly from here.
+        if len(storage._vis_data) >= 1:
+            stats["image"] = [
+                wandb.Image(img, caption=img_name)
+                for img_name, img, step_num in storage._vis_data
+            ]
+            # Storage keeps all image data and relies on this writer to clear it.
+            # As a result it assumes only one writer will use its image data.
+            # An alternative design is to let storage store limited recent
+            # data (e.g. only the most recent image) that all writers can access.
+            # In that case a writer may not see all image data if its period is long.
+            storage.clear_images()
+
+        if len(storage._histograms) >= 1:
+
+            def create_bar(tag, bucket_limits, bucket_counts, **kwargs):
+                data = [
+                    [label, val] for (label, val) in zip(bucket_limits, bucket_counts)
+                ]
+                table = wandb.Table(data=data, columns=["label", "value"])
+                return wandb.plot.bar(table, "label", "value", title=tag)
+
+            stats["hist"] = [create_bar(**params) for params in storage._histograms]
+
+            storage.clear_histograms()
+
+        if len(stats) == 0:
+            return
+        wandb.log(stats, step=storage.iter)
+
+    def close(self):
+        wandb.finish()
diff --git a/open_vocab_seg/utils/misc.py b/open_vocab_seg/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..a22d0a978c9cd89595c6e7c900885e1c148844b1
--- /dev/null
+++ b/open_vocab_seg/utils/misc.py
@@ -0,0 +1,126 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Modified by Bowen Cheng from https://github.com/facebookresearch/detr/blob/master/util/misc.py
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+"""
+Misc functions, including distributed helpers.
+
+Mostly copy-paste from torchvision references.
+"""
+from typing import List, Optional
+
+import torch
+import torch.distributed as dist
+import torchvision
+from torch import Tensor
+
+
+def _max_by_axis(the_list):
+    # type: (List[List[int]]) -> List[int]
+    maxes = the_list[0]
+    for sublist in the_list[1:]:
+        for index, item in enumerate(sublist):
+            maxes[index] = max(maxes[index], item)
+    return maxes
+
+
+class NestedTensor(object):
+    def __init__(self, tensors, mask: Optional[Tensor]):
+        self.tensors = tensors
+        self.mask = mask
+
+    def to(self, device):
+        # type: (Device) -> NestedTensor # noqa
+        cast_tensor = self.tensors.to(device)
+        mask = self.mask
+        if mask is not None:
+            assert mask is not None
+            cast_mask = mask.to(device)
+        else:
+            cast_mask = None
+        return NestedTensor(cast_tensor, cast_mask)
+
+    def decompose(self):
+        return self.tensors, self.mask
+
+    def __repr__(self):
+        return str(self.tensors)
+
+
+def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
+    # TODO make this more general
+    if tensor_list[0].ndim == 3:
+        if torchvision._is_tracing():
+            # nested_tensor_from_tensor_list() does not export well to ONNX
+            # call _onnx_nested_tensor_from_tensor_list() instead
+            return _onnx_nested_tensor_from_tensor_list(tensor_list)
+
+        # TODO make it support different-sized images
+        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
+        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
+        batch_shape = [len(tensor_list)] + max_size
+        b, c, h, w = batch_shape
+        dtype = tensor_list[0].dtype
+        device = tensor_list[0].device
+        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
+        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
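+        # Copy every image into the top-left corner of a shared max-size canvas;
+        # mask is True where the canvas is padding and False over real pixels.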
+        for img, pad_img, m in zip(tensor_list, tensor, mask):
+            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+            m[: img.shape[1], : img.shape[2]] = False
+    else:
+        raise ValueError("not supported")
+    return NestedTensor(tensor, mask)
+
+
+# _onnx_nested_tensor_from_tensor_list() is an implementation of
+# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
+@torch.jit.unused
+def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
+    max_size = []
+    for i in range(tensor_list[0].dim()):
+        max_size_i = torch.max(
+            torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
+        ).to(torch.int64)
+        max_size.append(max_size_i)
+    max_size = tuple(max_size)
+
+    # work around for
+    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+    # m[: img.shape[1], :img.shape[2]] = False
+    # which is not yet supported in onnx
+    padded_imgs = []
+    padded_masks = []
+    for img in tensor_list:
+        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
+        padded_img = torch.nn.functional.pad(
+            img, (0, padding[2], 0, padding[1], 0, padding[0])
+        )
+        padded_imgs.append(padded_img)
+
+        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
+        padded_mask = torch.nn.functional.pad(
+            m, (0, padding[2], 0, padding[1]), "constant", 1
+        )
+        padded_masks.append(padded_mask.to(torch.bool))
+
+    tensor = torch.stack(padded_imgs)
+    mask = torch.stack(padded_masks)
+
+    return NestedTensor(tensor, mask=mask)
+
+
+def is_dist_avail_and_initialized():
+    if not dist.is_available():
+        return False
+    if not dist.is_initialized():
+        return False
+    return True
+
+def get_gt_binary_masks(gt_semseg):
+    mask_ids = torch.unique(gt_semseg)
+    gt_masks = []
+    for mask_id in mask_ids:
+        if mask_id != 255:
+            gt_masks.append(gt_semseg == mask_id)
+    gt_masks = torch.stack(gt_masks).float()
+    return gt_masks
diff --git a/open_vocab_seg/utils/post_process_utils.py b/open_vocab_seg/utils/post_process_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed214319d90ceba0b47ef835072102b9ffec5179
--- /dev/null
+++ b/open_vocab_seg/utils/post_process_utils.py
@@ -0,0 +1,74 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+import torch
+from torch.nn import functional as F
+import numpy as np
+
+try:
+    import pydensecrf.densecrf as dcrf
+    from pydensecrf.utils import (
+        unary_from_softmax,
+        unary_from_labels,
+        create_pairwise_bilateral,
+        create_pairwise_gaussian,
+    )
+except ImportError:
+    dcrf = None
+
+
+def dense_crf_post_process(
+    logits,
+    image,
+    n_labels=None,
+    max_iters=5,
+    pos_xy_std=(3, 3),
+    pos_w=3,
+    bi_xy_std=(80, 80),
+    bi_rgb_std=(13, 13, 13),
+    bi_w=10,
+):
+    """
+    logits : [C,H,W]
+    image : [3,H,W]
+    """
+    if dcrf is None:
+        raise ModuleNotFoundError(
+            "pydensecrf is required to perform dense crf inference."
+        )
+    if isinstance(logits, torch.Tensor):
+        logits = F.softmax(logits, dim=0).detach().cpu().numpy()
+        U = unary_from_softmax(logits)
+        n_labels = logits.shape[0]
+    elif logits.ndim == 3:
+        U = unary_from_softmax(logits)
+        n_labels = logits.shape[0]
+    else:
+        assert n_labels is not None
+        U = unary_from_labels(logits, n_labels, zero_unsure=False)
+
+    d = dcrf.DenseCRF2D(image.shape[1], image.shape[0], n_labels)
+
+    d.setUnaryEnergy(U)
+
+    # This adds the color-independent term, features are the locations only.
+    d.addPairwiseGaussian(
+        sxy=pos_xy_std,
+        compat=pos_w,
+        kernel=dcrf.DIAG_KERNEL,
+        normalization=dcrf.NORMALIZE_SYMMETRIC,
+    )
+
+    # This adds the color-dependent term, i.e. features are (x,y,r,g,b).
+    d.addPairwiseBilateral(
+        sxy=bi_xy_std,
+        srgb=bi_rgb_std,
+        rgbim=image,
+        compat=bi_w,
+        kernel=dcrf.DIAG_KERNEL,
+        normalization=dcrf.NORMALIZE_SYMMETRIC,
+    )
+    # Run `max_iters` steps of mean-field inference.
+    logits = d.inference(max_iters)
+    logits = np.asarray(logits).reshape((n_labels, image.shape[0], image.shape[1]))
+    return torch.from_numpy(logits)
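+
+
+# Illustrative usage sketch (not part of the upstream API; assumes `logits` is a
+# [C, H, W] tensor of class scores and `rgb` is the matching H x W x 3 uint8 image):
+#   refined = dense_crf_post_process(logits, np.ascontiguousarray(rgb))
+#   pred = refined.argmax(dim=0)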
diff --git a/open_vocab_seg/utils/predictor.py b/open_vocab_seg/utils/predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f6cd8f86af6995768a27849f37baab82d6486cf
--- /dev/null
+++ b/open_vocab_seg/utils/predictor.py
@@ -0,0 +1,232 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved
+
+import numpy as np
+import torch
+from torch.nn import functional as F
+import cv2
+
+from detectron2.data import MetadataCatalog
+from detectron2.structures import BitMasks
+from detectron2.engine.defaults import DefaultPredictor
+from detectron2.utils.visualizer import ColorMode, Visualizer
+from detectron2.modeling.postprocessing import sem_seg_postprocess
+
+import open_clip
+from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
+from open_vocab_seg.modeling.clip_adapter.adapter import PIXEL_MEAN, PIXEL_STD
+from open_vocab_seg.modeling.clip_adapter.utils import crop_with_mask
+
+class OVSegPredictor(DefaultPredictor):
+    def __init__(self, cfg):
+        super().__init__(cfg)
+
+    def __call__(self, original_image, class_names):
+        """
+        Args:
+            original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).
+
+        Returns:
+            predictions (dict):
+                the output of the model for one image only.
+                See :doc:`/tutorials/models` for details about the format.
+        """
+        with torch.no_grad():  # https://github.com/sphinx-doc/sphinx/issues/4258
+            # Apply pre-processing to image.
+            if self.input_format == "RGB":
+                # whether the model expects BGR inputs or RGB
+                original_image = original_image[:, :, ::-1]
+            height, width = original_image.shape[:2]
+            image = self.aug.get_transform(original_image).apply_image(original_image)
+            image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
+
+            inputs = {"image": image, "height": height, "width": width, "class_names": class_names}
+            predictions = self.model([inputs])[0]
+            return predictions
+
+class OVSegVisualizer(Visualizer):
+    def __init__(self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE, class_names=None):
+        super().__init__(img_rgb, metadata, scale, instance_mode)
+        self.class_names = class_names
+
+    def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.8):
+        """
+        Draw semantic segmentation predictions/labels.
+
+        Args:
+            sem_seg (Tensor or ndarray): the segmentation of shape (H, W).
+                Each value is the integer label of the pixel.
+            area_threshold (int): segments with less than `area_threshold` are not drawn.
+            alpha (float): the larger it is, the more opaque the segmentations are.
+
+        Returns:
+            output (VisImage): image object with visualizations.
+        """
+        if isinstance(sem_seg, torch.Tensor):
+            sem_seg = sem_seg.numpy()
+        labels, areas = np.unique(sem_seg, return_counts=True)
+        sorted_idxs = np.argsort(-areas).tolist()
+        labels = labels[sorted_idxs]
+        class_names = self.class_names if self.class_names is not None else self.metadata.stuff_classes
+
+        for label in filter(lambda l: l < len(class_names), labels):
+            try:
+                mask_color = [x / 255 for x in self.metadata.stuff_colors[label]]
+            except (AttributeError, IndexError):
+                mask_color = None
+
+            binary_mask = (sem_seg == label).astype(np.uint8)
+            text = class_names[label]
+            self.draw_binary_mask(
+                binary_mask,
+                color=mask_color,
+                edge_color=(1.0, 1.0, 240.0 / 255),
+                text=text,
+                alpha=alpha,
+                area_threshold=area_threshold,
+            )
+        return self.output
+
+
+
+class VisualizationDemo(object):
+    def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False):
+        """
+        Args:
+            cfg (CfgNode):
+            instance_mode (ColorMode):
+            parallel (bool): whether to run the model in different processes from visualization.
+                Useful since the visualization logic can be slow.
+        """
+        self.metadata = MetadataCatalog.get(
+            cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
+        )
+
+        self.cpu_device = torch.device("cpu")
+        self.instance_mode = instance_mode
+
+        self.parallel = parallel
+        if parallel:
+            raise NotImplementedError
+        else:
+            self.predictor = OVSegPredictor(cfg)
+
+    def run_on_image(self, image, class_names):
+        """
+        Args:
+            image (np.ndarray): an image of shape (H, W, C) (in BGR order).
+                This is the format used by OpenCV.
+        Returns:
+            predictions (dict): the output of the model.
+            vis_output (VisImage): the visualized image output.
+        """
+        predictions = self.predictor(image, class_names)
+        # Convert image from OpenCV BGR format to Matplotlib RGB format.
+        image = image[:, :, ::-1]
+        visualizer = OVSegVisualizer(image, self.metadata, instance_mode=self.instance_mode, class_names=class_names)
+        if "sem_seg" in predictions:
+            r = predictions["sem_seg"]
+            # Pixels whose first-class score is exactly zero are treated as unlabeled (ignore label 255).
+            blank_area = (r[0] == 0)
+            pred_mask = r.argmax(dim=0).to('cpu')
+            pred_mask[blank_area] = 255
+            # np.int was removed in NumPy 1.24; the builtin int is equivalent here.
+            pred_mask = np.array(pred_mask, dtype=int)
+
+            vis_output = visualizer.draw_sem_seg(
+                pred_mask
+            )
+        else:
+            raise NotImplementedError
+
+        return predictions, vis_output
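+
+# Rough end-to-end sketch mirroring app.py (exact config handling there may differ slightly):
+#
+#     cfg = get_cfg()
+#     add_deeplab_config(cfg)
+#     add_ovseg_config(cfg)
+#     cfg.merge_from_file("ovseg_swinB_vitL_demo.yaml")   # points WEIGHTS at ovseg_swinbase_vitL14_ft_mpt.pth
+#     cfg.freeze()
+#     demo = VisualizationDemo(cfg)
+#     img = read_image("resources/demo_samples/sample_01.jpeg", format="BGR")
+#     preds, vis = demo.run_on_image(img, ["person", "building", "sky"])
+#     vis.save("ovseg_output.png")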
+
+class SAMVisualizationDemo(object):
+    def __init__(self, cfg, granularity, sam_path, ovsegclip_path, instance_mode=ColorMode.IMAGE, parallel=False):
+        self.metadata = MetadataCatalog.get(
+            cfg.DATASETS.TEST[0] if len(cfg.DATASETS.TEST) else "__unused"
+        )
+
+        self.cpu_device = torch.device("cpu")
+        self.instance_mode = instance_mode
+
+        self.parallel = parallel
+        self.granularity = granularity
+        # The "vit_l" registry key must match the SAM checkpoint passed as sam_path
+        # (sam_vit_l_0b3195.pth in this repo); sam_vit_h_4b8939.pth would require "vit_h" instead.
+        sam = sam_model_registry["vit_l"](checkpoint=sam_path).cuda()
+        self.predictor = SamAutomaticMaskGenerator(sam, points_per_batch=16)
+        self.clip_model, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained=ovsegclip_path)
+        self.clip_model.cuda()
+
+    def run_on_image(self, image, class_names):
+        # SAM proposes class-agnostic masks; the OVSeg CLIP model then classifies each region.
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        visualizer = OVSegVisualizer(image, self.metadata, instance_mode=self.instance_mode, class_names=class_names)
+        with torch.no_grad(), torch.cuda.amp.autocast():
+            masks = self.predictor.generate(image)
+        # Stack the binary SAM masks into an (N, H, W) array and wrap them as BitMasks.
+        pred_masks = [masks[i]['segmentation'][None, :, :] for i in range(len(masks))]
+        pred_masks = np.row_stack(pred_masks)
+        pred_masks = BitMasks(pred_masks)
+        bboxes = pred_masks.get_bounding_boxes()
+
+        # Fill value for pixels outside each mask when cropping: the pixel mean on a 0-255 scale.
+        mask_fill = [255.0 * c for c in PIXEL_MEAN]
+
+        image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
+
+        regions = []
+        for bbox, mask in zip(bboxes, pred_masks):
+            region, _ = crop_with_mask(
+                image,
+                mask,
+                bbox,
+                fill=mask_fill,
+            )
+            regions.append(region.unsqueeze(0))
+        regions = [F.interpolate(r.to(torch.float), size=(224, 224), mode="bicubic") for r in regions]
+
+        pixel_mean = torch.tensor(PIXEL_MEAN).reshape(1, -1, 1, 1)
+        pixel_std = torch.tensor(PIXEL_STD).reshape(1, -1, 1, 1)
+        imgs = [(r/255.0 - pixel_mean) / pixel_std for r in regions]
+        imgs = torch.cat(imgs)
+        # With a single query class, add a generic 'others' class so the softmax has an alternative.
+        if len(class_names) == 1:
+            class_names.append('others')
+        txts = [f'a photo of {cls_name}' for cls_name in class_names]
+        text = open_clip.tokenize(txts)
+
+        img_batches = torch.split(imgs, 32, dim=0)
+
+        with torch.no_grad(), torch.cuda.amp.autocast():
+            text_features = self.clip_model.encode_text(text.cuda())
+            text_features /= text_features.norm(dim=-1, keepdim=True)
+            image_features = []
+            for img_batch in img_batches:
+                image_feat = self.clip_model.encode_image(img_batch.cuda().half())
+                image_feat /= image_feat.norm(dim=-1, keepdim=True)
+                image_features.append(image_feat.detach())
+            image_features = torch.cat(image_features, dim=0)
+            class_preds = (100.0 * image_features @ text_features.T).softmax(dim=-1)
+        # class_preds has shape (num_masks, num_classes); keep only the scores of the masks
+        # selected below and zero out the rest.
+        select_cls = torch.zeros_like(class_preds)
+
+        # By default, keep the single best-matching mask per class (ignoring the auxiliary 'others').
+        max_scores, select_mask = torch.max(class_preds, dim=0)
+        if len(class_names) == 2 and class_names[-1] == 'others':
+            select_mask = select_mask[:-1]
+        # With granularity < 1, instead keep every mask whose score for some class exceeds
+        # granularity * best_score_for_that_class.
+        if self.granularity < 1:
+            thr_scores = max_scores * self.granularity
+            select_mask = []
+            if len(class_names) == 2 and class_names[-1] == 'others':
+                thr_scores = thr_scores[:-1]
+            for i, thr in enumerate(thr_scores):
+                cls_pred = class_preds[:, i]
+                locs = torch.where(cls_pred > thr)
+                select_mask.extend(locs[0].tolist())
+        for idx in select_mask:
+            select_cls[idx] = class_preds[idx]
+        # Scatter the selected per-mask class scores onto the image to get a (C, H, W) semantic map.
+        semseg = torch.einsum("qc,qhw->chw", select_cls.float(), pred_masks.tensor.float().cuda())
+
+        r = semseg
+        # Pixels covered by no selected mask (first channel exactly zero) get the ignore label 255.
+        blank_area = (r[0] == 0)
+        pred_mask = r.argmax(dim=0).to('cpu')
+        pred_mask[blank_area] = 255
+        pred_mask = np.array(pred_mask, dtype=int)
+
+        vis_output = visualizer.draw_sem_seg(
+            pred_mask
+        )
+
+        return None, vis_output
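+
+# Rough usage sketch with the checkpoints shipped in this repo (cfg as in the sketch above;
+# the granularity value is illustrative; a CUDA GPU is required since SAM and CLIP are moved to .cuda()):
+#
+#     demo = SAMVisualizationDemo(cfg, granularity=0.8,
+#                                 sam_path="./sam_vit_l_0b3195.pth",
+#                                 ovsegclip_path="./ovseg_clip_l_9a1909.pth")
+#     img = read_image("resources/demo_samples/sample_02.jpeg", format="BGR")
+#     _, vis = demo.run_on_image(img, ["person", "car", "road"])
+#     vis.save("sam_output.png")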
\ No newline at end of file
diff --git a/ovseg_clip_l_9a1909.pth b/ovseg_clip_l_9a1909.pth
new file mode 100644
index 0000000000000000000000000000000000000000..f88dad1269412ee0c449597330d31dbb6d3e1042
--- /dev/null
+++ b/ovseg_clip_l_9a1909.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eb5cbc83b922e18241654a19ad4cb836cf4f00169cd5684a4932d8a6f825dd36
+size 1710616901
diff --git a/ovseg_swinB_vitL_demo.yaml b/ovseg_swinB_vitL_demo.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..aaeb0e2cc02a9a7bee7c02f6cfb313add8da9794
--- /dev/null
+++ b/ovseg_swinB_vitL_demo.yaml
@@ -0,0 +1,99 @@
+MODEL:
+  META_ARCHITECTURE: "OVSegDEMO"
+  BACKBONE:
+    FREEZE_AT: 0
+    NAME: "D2SwinTransformer"
+  SWIN:
+    EMBED_DIM: 128
+    DEPTHS: [2, 2, 18, 2]
+    NUM_HEADS: [4, 8, 16, 32]
+    WINDOW_SIZE: 12
+    APE: False
+    DROP_PATH_RATE: 0.3
+    PATCH_NORM: True
+    PRETRAIN_IMG_SIZE: 384
+  WEIGHTS: "./ovseg_swinbase_vitL14_ft_mpt.pth"
+  PIXEL_MEAN: [123.675, 116.280, 103.530]
+  PIXEL_STD: [58.395, 57.120, 57.375]
+  SEM_SEG_HEAD:
+    NAME: "OpenVocabMaskFormerHead"
+    IN_FEATURES: ["res2", "res3", "res4", "res5"]
+    IGNORE_VALUE: 255
+    NUM_CLASSES: 171 # number of categories in training set
+    EMBEDDING_DIM: 768
+    EMBED_LAYERS: 2
+    COMMON_STRIDE: 4 # not used, hard-coded
+    LOSS_WEIGHT: 1.0
+    CONVS_DIM: 256
+    MASK_DIM: 256
+    NORM: "GN"
+  MASK_FORMER:
+    TRANSFORMER_IN_FEATURE: "res5"
+    DEEP_SUPERVISION: True
+    NO_OBJECT_WEIGHT: 0.1
+    DICE_WEIGHT: 1.0
+    MASK_WEIGHT: 20.0
+    HIDDEN_DIM: 256
+    NUM_OBJECT_QUERIES: 100
+    NHEADS: 8
+    DROPOUT: 0.1
+    DIM_FEEDFORWARD: 2048
+    ENC_LAYERS: 0
+    DEC_LAYERS: 6
+    PRE_NORM: False
+  CLIP_ADAPTER:
+    TEXT_TEMPLATES: "vild"
+    CLIP_MODEL_NAME: "ViT-L/14"
+    MASK_FILL: "mean"
+    MASK_EXPAND_RATIO: 1.0
+    MASK_THR: 0.35 # choose the foreground objects
+    MASK_MATTING: False # use soft background, default not used
+    MASK_PROMPT_DEPTH: 3
+    MASK_PROMPT_FWD: True # use mask prompt during forward
+    REGION_RESIZED: True # resize to the input of clip, e.g., 224
+    CLIP_ENSEMBLE: True # use ensemble of two classification branches
+    CLIP_ENSEMBLE_WEIGHT: 0.0
+DATASETS:
+  TRAIN: ("coco_2017_train_stuff_sem_seg",)
+  TEST: ("ade20k_sem_seg_val",)
+SOLVER:
+  IMS_PER_BATCH: 32
+  BASE_LR: 0.00006
+  MAX_ITER: 120000
+  WARMUP_FACTOR: 1e-6
+  WARMUP_ITERS: 1500
+  WEIGHT_DECAY: 0.01
+  WEIGHT_DECAY_NORM: 0.0
+  WEIGHT_DECAY_EMBED: 0.0
+  BACKBONE_MULTIPLIER: 1.0
+  TEST_IMS_PER_BATCH: 1
+  CLIP_GRADIENTS:
+    ENABLED: True
+    CLIP_TYPE: "full_model"
+    CLIP_VALUE: 0.01
+    NORM_TYPE: 2.0
+INPUT:
+  MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 640) for x in range(5, 21)]"]
+  MIN_SIZE_TRAIN_SAMPLING: "choice"
+  MIN_SIZE_TEST: 640
+  MAX_SIZE_TRAIN: 2560
+  MAX_SIZE_TEST: 2560
+  CROP:
+    ENABLED: True
+    TYPE: "absolute"
+    SIZE: (640, 640)
+    SINGLE_CATEGORY_MAX_AREA: 1.0
+  COLOR_AUG_SSD: True
+  SIZE_DIVISIBILITY: 640  # used in dataset mapper
+  FORMAT: "RGB"
+TEST:
+  EVAL_PERIOD: 5000
+  AUG:
+    ENABLED: False
+    MIN_SIZES: [256, 384, 512, 640, 768, 896]
+    MAX_SIZE: 3584
+    FLIP: True
+DATALOADER:
+  FILTER_EMPTY_ANNOTATIONS: True
+  NUM_WORKERS: 4
+VERSION: 2
\ No newline at end of file
diff --git a/ovseg_swinbase_vitL14_ft_mpt.pth b/ovseg_swinbase_vitL14_ft_mpt.pth
new file mode 100644
index 0000000000000000000000000000000000000000..0d2dcc4c4e721b187574f4c3829c58236713037a
--- /dev/null
+++ b/ovseg_swinbase_vitL14_ft_mpt.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dd3731dde48d96654aba63e5a93753dc837d6889162a18ddf0877f5463d94c90
+size 2129343629
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e60263206e63eced27619991b26cd842f6bc5649
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,28 @@
+numpy>=1.18.5,<1.24.0
+cython
+scipy
+shapely
+timm
+h5py
+wandb
+fire
+opencv-python
+pandas
+ftfy
+regex
+tqdm
+gdown
+# Torch
+--find-links https://download.pytorch.org/whl/cu113/torch_stable.html
+torch==1.10.1+cu113
+torchvision==0.11.2+cu113
+
+# Detectron
+--find-links https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html
+detectron2
+
+# Segment-anything
+git+https://github.com/facebookresearch/segment-anything.git
+
+# open_clip
+open_clip_torch==1.3.0
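+
+# Illustrative setup (assumes a CUDA 11.3 environment to match the pinned torch/detectron2 wheels):
+#   pip install -r requirements.txt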
diff --git a/resources/demo_samples/sample_01.jpeg b/resources/demo_samples/sample_01.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..b3e0245ab25117d340bf23059af6b5dcbfc8a811
--- /dev/null
+++ b/resources/demo_samples/sample_01.jpeg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:154943906b5ed394b620da62124c4421dfa96f858f014839eb346678aaa71fc3
+size 4323630
diff --git a/resources/demo_samples/sample_02.jpeg b/resources/demo_samples/sample_02.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..e5489289e1128c33060f1d30a353452907f0a1d8
--- /dev/null
+++ b/resources/demo_samples/sample_02.jpeg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:591c2bf26a843a62881d89dbd7f4e9a6f90dda9fb8786c9b6e5172a28623d1b0
+size 1840881
diff --git a/resources/demo_samples/sample_03.jpeg b/resources/demo_samples/sample_03.jpeg
new file mode 100644
index 0000000000000000000000000000000000000000..d34147db0c2888b6c77e92202efd8acb1e8d0f36
--- /dev/null
+++ b/resources/demo_samples/sample_03.jpeg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:33e5c7054300d5cf1871a33972416504c05dbddff238b32ff884525bcbfca695
+size 7324740
diff --git a/resources/demo_samples/sample_04.png b/resources/demo_samples/sample_04.png
new file mode 100644
index 0000000000000000000000000000000000000000..fda5012edd98754db1e8a3cfbfbea287a41a9299
--- /dev/null
+++ b/resources/demo_samples/sample_04.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:072e7e119d437c0c1fa731c40c737780b7725bd49041ef7466153e1ee7045920
+size 7903759
diff --git a/resources/demo_samples/sample_05.png b/resources/demo_samples/sample_05.png
new file mode 100644
index 0000000000000000000000000000000000000000..18bd12d1862f474cf0c9fc26edd11f55e83f34d8
--- /dev/null
+++ b/resources/demo_samples/sample_05.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1331dfcef69066c225d34c659f756a92ce3dc71965978db67814eda36b1cdc5f
+size 2645089
diff --git a/resources/ovseg.gif b/resources/ovseg.gif
new file mode 100644
index 0000000000000000000000000000000000000000..9d77dbd81f06dc65346cfd2d7a1b4742ff0597f8
--- /dev/null
+++ b/resources/ovseg.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:828e23424f7b494c6dad079d52551a4bb9a4cfb292dcec0acd376b89c5944128
+size 3789733
diff --git a/resources/proposal.png b/resources/proposal.png
new file mode 100644
index 0000000000000000000000000000000000000000..4ebf9a6ae0163ad1b733fe9cd15537ea7a016c72
--- /dev/null
+++ b/resources/proposal.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b8cacaabbc9dd7b7b5a7f974975128a2ea604759606dacc45b90a3d67b18d8e8
+size 194338
diff --git a/resources/pytorch-logo-dark.png b/resources/pytorch-logo-dark.png
new file mode 100644
index 0000000000000000000000000000000000000000..8cef4518ca3ea7cda3b046082a6035541e0f07fd
--- /dev/null
+++ b/resources/pytorch-logo-dark.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8211f1b771de99ae379db83350327139a597c0f99d6b6312e81e977d4d413c44
+size 15625
diff --git a/sam_vit_h_4b8939.pth b/sam_vit_h_4b8939.pth
new file mode 100644
index 0000000000000000000000000000000000000000..8523acce9ddab1cf7e355628a08b1aab8ce08a72
--- /dev/null
+++ b/sam_vit_h_4b8939.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
+size 2564550879
diff --git a/sam_vit_l_0b3195.pth b/sam_vit_l_0b3195.pth
new file mode 100644
index 0000000000000000000000000000000000000000..87a638d6b789dd2b10fc7414a88dacc34a50769a
--- /dev/null
+++ b/sam_vit_l_0b3195.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622
+size 1249524607