diff --git a/gradio_demo/app.py b/gradio_demo/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..271b4b8989a7ea7cb509b73c57befbb776dcc038
--- /dev/null
+++ b/gradio_demo/app.py
@@ -0,0 +1,545 @@
+import sys
+sys.path.append('./')
+import gradio as gr
+import random
+import numpy as np
+from gradio_demo.character_template import character_man, lorapath_man
+from gradio_demo.character_template import character_woman, lorapath_woman
+from gradio_demo.character_template import styles, lorapath_styles
+import torch
+import os
+from typing import Tuple, List
+import copy
+import argparse
+from diffusers.utils import load_image
+import cv2
+from PIL import Image
+from transformers import DPTFeatureExtractor, DPTForDepthEstimation
+from controlnet_aux import OpenposeDetector
+from controlnet_aux.open_pose.body import Body
+
+try:
+    from inference.models import YOLOWorld
+    from src.efficientvit.models.efficientvit.sam import EfficientViTSamPredictor
+    from src.efficientvit.sam_model_zoo import create_sam_model
+    import supervision as sv
+except Exception:
+    print("YOLO-World could not be loaded")
+
+try:
+    from groundingdino.models import build_model
+    from groundingdino.util import box_ops
+    from groundingdino.util.slconfig import SLConfig
+    from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
+    from groundingdino.util.inference import annotate, predict
+    from segment_anything import build_sam, SamPredictor
+    import groundingdino.datasets.transforms as T
+except Exception:
+    print("GroundingDINO could not be loaded")
+
+from src.pipelines.lora_pipeline import LoraMultiConceptPipeline
+from src.prompt_attention.p2p_attention import AttentionReplace
+from diffusers import ControlNetModel, StableDiffusionXLPipeline
+from src.pipelines.lora_pipeline import revise_regionally_controlnet_forward
+
+CHARACTER_MAN_NAMES = list(character_man.keys())
+CHARACTER_WOMAN_NAMES = list(character_woman.keys())
+STYLE_NAMES = list(styles.keys())
+MAX_SEED = np.iinfo(np.int32).max
+
+### Description
+title = r"""
+<h1 align="center">OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models</h1>
+"""
+
+description = r"""
+<b>Official 🤗 Gradio demo</b> for <a href='https://github.com/' target='_blank'><b>OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models</b></a>.<br>
+
+How to use:<br>
+1. Select two characters.
+2. Enter a text prompt, as you would for a standard text-to-image model.
+3. Click the <b>Submit</b> button to start customizing.
+4. Enjoy the generated image😊!
+"""
+
+article = r"""
+---
+📝 **Citation**
+<br>
+If our work is helpful for your research or applications, please cite us via:
+```bibtex
+@article{,
+title={OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models},
+author={},
+journal={},
+year={}
+}
+```
+"""
+
+tips = r"""
+### Usage tips of OMG
+1. Enter text prompts that describe a man and a woman.
+"""
+
+css = '''
+.gradio-container {width: 85% !important}
+'''
+
+def sample_image(pipe,
+    input_prompt,
+    input_neg_prompt=None,
+    generator=None,
+    concept_models=None,
+    num_inference_steps=50,
+    guidance_scale=7.5,
+    controller=None,
+    stage=None,
+    region_masks=None,
+    lora_list=None,
+    styleL=None,
+    **extra_kargs
+):
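+    """Thin wrapper around the LoraMultiConceptPipeline call.
+
+    Forwards the prompts, the per-concept SDXL pipeline, the attention
+    controller, the sampling stage and the optional region masks / LoRA
+    adapter names, and passes the optional `spatial_condition` (popped from
+    extra_kargs) to the pipeline via its `image` argument.
+    """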
+
+    spatial_condition = extra_kargs.pop('spatial_condition')
+    if spatial_condition is not None:
+        spatial_condition_input = [spatial_condition] * len(input_prompt)
+    else:
+        spatial_condition_input = None
+
+    images = pipe(
+        prompt=input_prompt,
+        concept_models=concept_models,
+        negative_prompt=input_neg_prompt,
+        generator=generator,
+        guidance_scale=guidance_scale,
+        num_inference_steps=num_inference_steps,
+        cross_attention_kwargs={"scale": 0.8},
+        controller=controller,
+        stage=stage,
+        region_masks=region_masks,
+        lora_list=lora_list,
+        styleL=styleL,
+        image=spatial_condition_input,
+        **extra_kargs).images
+
+    return images
+
+def load_image_yoloworld(image_source) -> np.ndarray:
+    image = np.asarray(image_source)
+    return image
+
+def load_image_dino(image_source) -> Tuple[np.ndarray, torch.Tensor]:
+    transform = T.Compose(
+        [
+            T.RandomResize([800], max_size=1333),
+            T.ToTensor(),
+            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+        ]
+    )
+    image = np.asarray(image_source)
+    image_transformed, _ = transform(image_source, None)
+    return image, image_transformed
+
+def predict_mask(segmentmodel, sam, image, TEXT_PROMPT, segmentType, confidence = 0.2, threshold = 0.5):
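+    """Segment `TEXT_PROMPT` in `image` and return its mask (or None).
+
+    With segmentType='GroundingDINO', boxes come from GroundingDINO and the
+    mask from SAM; otherwise YOLO-World detects the box and EfficientViT-SAM
+    segments it, returning None when nothing is detected.
+    """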
+    if segmentType=='GroundingDINO':
+        image_source, image = load_image_dino(image)
+        boxes, logits, phrases = predict(
+            model=segmentmodel,
+            image=image,
+            caption=TEXT_PROMPT,
+            box_threshold=0.3,
+            text_threshold=0.25
+        )
+        sam.set_image(image_source)
+        H, W, _ = image_source.shape
+        boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])
+
+        transformed_boxes = sam.transform.apply_boxes_torch(boxes_xyxy, image_source.shape[:2]).cuda()
+        masks, _, _ = sam.predict_torch(
+            point_coords=None,
+            point_labels=None,
+            boxes=transformed_boxes,
+            multimask_output=False,
+        )
+        masks=masks[0].squeeze(0)
+    else:
+        image_source = load_image_yoloworld(image)
+        segmentmodel.set_classes([TEXT_PROMPT])
+        results = segmentmodel.infer(image_source, confidence=confidence)
+        detections = sv.Detections.from_inference(results).with_nms(
+            class_agnostic=True, threshold=threshold
+        )
+        masks = None
+        if len(detections) != 0:
+            print(TEXT_PROMPT + " detected!")
+            sam.set_image(image_source, image_format="RGB")
+            masks, _, _ = sam.predict(box=detections.xyxy[0], multimask_output=False)
+            masks = torch.from_numpy(masks.squeeze())
+
+    return masks
+
+def prepare_text(prompt, region_prompts):
+    '''
+    Args:
+        prompt: the global text prompt
+        region_prompts: [subject1 prompt]-*-[subject1 negative prompt]|[subject2 prompt]-*-[subject2 negative prompt]
+    Returns:
+        (prompt, region_collection): the global prompt together with a list of
+        (region prompt, region negative prompt) pairs, one pair per region
+    '''
+    region_collection = []
+
+    regions = region_prompts.split('|')
+
+    for region in regions:
+        if region == '':
+            break
+        prompt_region, neg_prompt_region = region.split('-*-')
+        prompt_region = prompt_region.replace('[', '').replace(']', '')
+        neg_prompt_region = neg_prompt_region.replace('[', '').replace(']', '')
+
+        region_collection.append((prompt_region, neg_prompt_region))
+    return (prompt, region_collection)
+
+
+def build_model_sd(pretrained_model, controlnet_path, device, prompts):
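+    """Build the region-aware base pipeline plus a per-concept SDXL pipeline.
+
+    Loads the ControlNet and the LoraMultiConceptPipeline, hooks an
+    AttentionReplace controller into the UNet via
+    revise_regionally_controlnet_forward, and also returns a plain
+    StableDiffusionXLPipeline used for the concept LoRAs.
+    """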
+    controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16).to(device)
+    pipe = LoraMultiConceptPipeline.from_pretrained(
+        pretrained_model, controlnet=controlnet, torch_dtype=torch.float16, variant="fp16").to(device)
+    controller = AttentionReplace(prompts, 50, cross_replace_steps={"default_": 1.}, self_replace_steps=0.4, tokenizer=pipe.tokenizer, device=device, dtype=torch.float16, width=1024//32, height=1024//32)
+    revise_regionally_controlnet_forward(pipe.unet, controller)
+    pipe_concept = StableDiffusionXLPipeline.from_pretrained(pretrained_model, torch_dtype=torch.float16,
+                                                             variant="fp16").to(device)
+    return pipe, controller, pipe_concept
+
+def build_model_lora(pipe_concept, lora_paths, style_path, condition, args):
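+    """Attach the ControlNet matching `condition` and load the LoRA weights.
+
+    Optionally loads a style LoRA, then loads one adapter per '|'-separated
+    path in `lora_paths` and returns the list of adapter names.
+    """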
+    pipe_list = []
+    if condition == "Human pose":
+        controlnet = ControlNetModel.from_pretrained(args.openpose_checkpoint, torch_dtype=torch.float16).to(device)
+        pipe_concept.controlnet = controlnet
+    elif condition == "Canny Edge":
+        controlnet = ControlNetModel.from_pretrained(args.canny_checkpoint, torch_dtype=torch.float16).to(device)
+        pipe_concept.controlnet = controlnet
+    elif condition == "Depth":
+        controlnet = ControlNetModel.from_pretrained(args.depth_checkpoint, torch_dtype=torch.float16).to(device)
+        pipe_concept.controlnet = controlnet
+
+    if style_path is not None and os.path.exists(style_path):
+        pipe_concept.load_lora_weights(style_path, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')
+
+    for lora_path in lora_paths.split('|'):
+        adapter_name = lora_path.split('/')[-1].split('.')[0]
+        pipe_concept.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name=adapter_name)
+        pipe_concept.enable_xformers_memory_efficient_attention()
+        pipe_list.append(adapter_name)
+    return pipe_list
+
+def build_yolo_segment_model(sam_path, device):
+    yolo_world = YOLOWorld(model_id="yolo_world/l")
+    sam = EfficientViTSamPredictor(
+        create_sam_model(name="xl1", weight_url=sam_path).to(device).eval()
+    )
+    return yolo_world, sam
+
+def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
+    args = SLConfig.fromfile(ckpt_config_filename)
+    model = build_model(args)
+    args.device = device
+
+    checkpoint = torch.load(os.path.join(repo_id, filename), map_location='cpu')
+    log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
+    print("Model loaded from {} \n => {}".format(filename, log))
+    _ = model.eval()
+    return model
+
+def build_dino_segment_model(ckpt_repo_id, sam_checkpoint):
+    ckpt_filename = "groundingdino_swinb_cogcoor.pth"
+    ckpt_config_filename = os.path.join(ckpt_repo_id, "GroundingDINO_SwinB.cfg.py")
+    groundingdino_model = load_model_hf(ckpt_repo_id, ckpt_filename, ckpt_config_filename)
+    sam = build_sam(checkpoint=sam_checkpoint)
+    sam.cuda()
+    sam_predictor = SamPredictor(sam)
+    return groundingdino_model, sam_predictor
+
+
+
+def main(device, segment_type):
+    pipe, controller, pipe_concept = build_model_sd(args.pretrained_sdxl_model, args.openpose_checkpoint, device, prompts_tmp)
+
+    if segment_type == 'GroundingDINO':
+        detect_model, sam = build_dino_segment_model(args.dino_checkpoint, args.sam_checkpoint)
+    else:
+        detect_model, sam = build_yolo_segment_model(args.efficientViT_checkpoint, device)
+
+    resolution_list = ["1440*728",
+                       "1344*768",
+                       "1216*832",
+                       "1152*896",
+                       "1024*1024",
+                       "896*1152",
+                       "832*1216",
+                       "768*1344",
+                       "728*1440"]
+
+    condition_list = ["None",
+                      "Human pose",
+                      "Canny Edge",
+                      "Depth"]
+
+    depth_estimator = DPTForDepthEstimation.from_pretrained(args.dpt_checkpoint).to("cuda")
+    feature_extractor = DPTFeatureExtractor.from_pretrained(args.dpt_checkpoint)
+    body_model = Body(args.pose_detector_checkpoint)
+    openpose = OpenposeDetector(body_model)
+
+    def remove_tips():
+        return gr.update(visible=False)
+
+    def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+        if randomize_seed:
+            seed = random.randint(0, MAX_SEED)
+        return seed
+
+    def get_humanpose(img):
+        openpose_image = openpose(img)
+        return openpose_image
+
+    def get_cannyedge(image):
+        image = np.array(image)
+        image = cv2.Canny(image, 100, 200)
+        image = image[:, :, None]
+        image = np.concatenate([image, image, image], axis=2)
+        canny_image = Image.fromarray(image)
+        return canny_image
+
+    def get_depth(image):
+        image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
+        with torch.no_grad(), torch.autocast("cuda"):
+            depth_map = depth_estimator(image).predicted_depth
+
+        depth_map = torch.nn.functional.interpolate(
+            depth_map.unsqueeze(1),
+            size=(1024, 1024),
+            mode="bicubic",
+            align_corners=False,
+        )
+        depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
+        depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
+        depth_map = (depth_map - depth_min) / (depth_max - depth_min)
+        image = torch.cat([depth_map] * 3, dim=1)
+        image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
+        image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
+        return image
+
+    def generate_image(prompt1, prompt2, prompt3, prompt4, negative_prompt, man, woman, resolution, local_prompt1, local_prompt2, seed, condition, condition_img1, condition_img2, condition_img3, condition_img4, style):
+        try:
+            path1 = lorapath_man[man]
+            path2 = lorapath_woman[woman]
+            pipe_concept.unload_lora_weights()
+            pipe_list = build_model_lora(pipe_concept, path1 + "|" + path2, lorapath_styles[style], condition, args)
+
+            if lorapath_styles[style] is not None and os.path.exists(lorapath_styles[style]):
+                styleL = True
+            else:
+                styleL = False
+
+            input_list = [prompt1, prompt2, prompt3, prompt4]
+            condition_list = [condition_img1, condition_img2, condition_img3, condition_img4]
+            output_list = []
+
+            width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
+
+            kwargs = {
+                'height': height,
+                'width': width,
+            }
+
+            for prompt, condition_img in zip(input_list, condition_list):
+                if prompt!='':
+                    input_prompt = []
+                    p = '{prompt}, 35mm photograph, film, professional, 4k, highly detailed.'
+                    if styleL:
+                        p = styles[style] + p
+                    input_prompt.append([p.replace("{prompt}", prompt), p.replace("{prompt}", prompt)])
+                    input_prompt.append([(styles[style] + local_prompt1, character_man.get(man)[1]), (styles[style] + local_prompt2, character_woman.get(woman)[1])])
+
+                    if condition == 'Human pose' and condition_img is not None:
+                        spatial_condition = get_humanpose(condition_img).resize((width, height))
+                    elif condition == 'Canny Edge' and condition_img is not None:
+                        spatial_condition = get_cannyedge(condition_img).resize((width, height))
+                    elif condition == 'Depth' and condition_img is not None:
+                        spatial_condition = get_depth(condition_img).resize((width, height))
+                    else:
+                        spatial_condition = None
+
+                    kwargs['spatial_condition'] = spatial_condition
+
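+                    # Stage 1: sample the full scene from the merged prompts (no region masks yet).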
+                    controller.reset()
+                    image = sample_image(
+                        pipe,
+                        input_prompt=input_prompt,
+                        concept_models=pipe_concept,
+                        input_neg_prompt=[negative_prompt] * len(input_prompt),
+                        generator=torch.Generator(device).manual_seed(seed),
+                        controller=controller,
+                        stage=1,
+                        lora_list=pipe_list,
+                        styleL=styleL,
+                        **kwargs)
+
+                    controller.reset()
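+                    # Segment the man/woman regions in the stage-1 image; a mask is only
+                    # predicted when the corresponding token appears in the prompt.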
+                    if pipe.tokenizer("man")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
+                        mask1 = predict_mask(detect_model, sam, image[0], 'man', args.segment_type, confidence=0.15,
+                                             threshold=0.5)
+                    else:
+                        mask1 = None
+
+                    if pipe.tokenizer("woman")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
+                        mask2 = predict_mask(detect_model, sam, image[0], 'woman', args.segment_type, confidence=0.15,
+                                             threshold=0.5)
+                    else:
+                        mask2 = None
+
+                    if mask1 is None and mask2 is None:
+                        output_list.append(image[1])
+                    else:
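+                        # Stage 2: resample with the predicted region masks and the character LoRA adapters.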
+                        image = sample_image(
+                            pipe,
+                            input_prompt=input_prompt,
+                            concept_models=pipe_concept,
+                            input_neg_prompt=[negative_prompt] * len(input_prompt),
+                            generator=torch.Generator(device).manual_seed(seed),
+                            controller=controller,
+                            stage=2,
+                            region_masks=[mask1, mask2],
+                            lora_list=pipe_list,
+                            styleL=styleL,
+                            **kwargs)
+                        output_list.append(image[1])
+                else:
+                    output_list.append(None)
+            return output_list
+        except Exception as e:
+            print(f"Generation failed: {e}")
+            return None, None, None, None
+
+    def get_local_value_man(input):
+        return character_man[input][0]
+
+    def get_local_value_woman(input):
+        return character_woman[input][0]
+
+
+    with gr.Blocks(css=css) as demo:
+        # description
+        gr.Markdown(title)
+        gr.Markdown(description)
+
+        with gr.Row():
+            gallery = gr.Image(label="Generated Images", height=512, width=512)
+            gallery2 = gr.Image(label="Generated Images", height=512, width=512)
+            gallery3 = gr.Image(label="Generated Images", height=512, width=512)
+            gallery4 = gr.Image(label="Generated Images", height=512, width=512)
+            usage_tips = gr.Markdown(label="Usage tips of OMG", value=tips, visible=False)
+
+        with gr.Row():
+            condition_img1 = gr.Image(label="Input condition", height=128, width=128)
+            condition_img2 = gr.Image(label="Input condition", height=128, width=128)
+            condition_img3 = gr.Image(label="Input condition", height=128, width=128)
+            condition_img4 = gr.Image(label="Input condition", height=128, width=128)
+
+        # character choose
+        with gr.Row():
+            man = gr.Dropdown(label="Character 1 selection", choices=CHARACTER_MAN_NAMES, value="Harry Potter (identifier: Harry Potter)")
+            woman = gr.Dropdown(label="Character 2 selection", choices=CHARACTER_WOMAN_NAMES, value="Hermione Granger (identifier: Hermione Granger)")
+            resolution = gr.Dropdown(label="Image Resolution (width*height)", choices=resolution_list, value="1024*1024")
+            condition = gr.Dropdown(label="Input condition type", choices=condition_list, value="None")
+            style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value="None")
+
+        with gr.Row():
+            local_prompt1 = gr.Textbox(label="Character1_prompt",
+                                       info="Describe Character 1; this prompt should include the identifier of Character 1",
+                                       value="Close-up photo of the Harry Potter, 35mm photograph, film, professional, 4k, highly detailed.")
+            local_prompt2 = gr.Textbox(label="Character2_prompt",
+                                       info="Describe Character 2; this prompt should include the identifier of Character 2",
+                                       value="Close-up photo of the Hermione Granger, 35mm photograph, film, professional, 4k, highly detailed.")
+
+        man.change(get_local_value_man, man, local_prompt1)
+        woman.change(get_local_value_woman, woman, local_prompt2)
+
+        # prompt
+        with gr.Column():
+            prompt = gr.Textbox(label="Prompt 1",
+                                info="Give a simple prompt to describe the first image content",
+                                placeholder="Required",
+                                value="close-up shot, photography, the cool man and beautiful woman as they accidentally discover a mysterious island while on vacation by the sea, facing the camera smiling")
+            prompt2 = gr.Textbox(label="Prompt 2",
+                                 info="Give a simple prompt to describe the second image content",
+                                 placeholder="optional",
+                                 value="")
+            prompt3 = gr.Textbox(label="Prompt 3",
+                                 info="Give a simple prompt to describe the third image content",
+                                 placeholder="optional",
+                                 value="")
+            prompt4 = gr.Textbox(label="Prompt 4",
+                                 info="Give a simple prompt to describe the fourth image content",
+                                 placeholder="optional",
+                                 value="")
+
+        with gr.Accordion(open=False, label="Advanced Options"):
+            seed = gr.Slider(
+                label="Seed",
+                minimum=0,
+                maximum=MAX_SEED,
+                step=1,
+                value=42,
+            )
+            negative_prompt = gr.Textbox(label="Negative Prompt",
+                                placeholder="noisy, blurry, soft, deformed, ugly",
+                                value="noisy, blurry, soft, deformed, ugly")
+            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+        submit = gr.Button("Submit", variant="primary")
+
+        submit.click(
+            fn=remove_tips,
+            outputs=usage_tips,
+        ).then(
+            fn=randomize_seed_fn,
+            inputs=[seed, randomize_seed],
+            outputs=seed,
+            queue=False,
+            api_name=False,
+        ).then(
+            fn=generate_image,
+            inputs=[prompt, prompt2, prompt3, prompt4, negative_prompt, man, woman, resolution, local_prompt1, local_prompt2, seed, condition, condition_img1, condition_img2, condition_img3, condition_img4, style],
+            outputs=[gallery, gallery2, gallery3, gallery4]
+        )
+    demo.launch(server_name='0.0.0.0', server_port=7861, debug=True)
+
+def parse_args():
+    parser = argparse.ArgumentParser('', add_help=False)
+    parser.add_argument('--pretrained_sdxl_model', default='./checkpoint/stable-diffusion-xl-base-1.0', type=str)
+    parser.add_argument('--openpose_checkpoint', default='./checkpoint/controlnet-openpose-sdxl-1.0', type=str)
+    parser.add_argument('--canny_checkpoint', default='./checkpoint/controlnet-canny-sdxl-1.0', type=str)
+    parser.add_argument('--depth_checkpoint', default='./checkpoint/controlnet-depth-sdxl-1.0', type=str)
+    parser.add_argument('--efficientViT_checkpoint', default='./checkpoint/sam/xl1.pt', type=str)
+    parser.add_argument('--dino_checkpoint', default='./checkpoint/GroundingDINO', type=str)
+    parser.add_argument('--sam_checkpoint', default='./checkpoint/sam/sam_vit_h_4b8939.pth', type=str)
+    parser.add_argument('--dpt_checkpoint', default='./checkpoint/dpt-hybrid-midas', type=str)
+    parser.add_argument('--pose_detector_checkpoint', default='./checkpoint/ControlNet/annotator/ckpts/body_pose_model.pth', type=str)
+    parser.add_argument('--prompt', default='Close-up photo of the cool man and beautiful woman in surprised expressions as they accidentally discover a mysterious island while on vacation by the sea, 35mm photograph, film, professional, 4k, highly detailed.', type=str)
+    parser.add_argument('--negative_prompt', default='noisy, blurry, soft, deformed, ugly', type=str)
+    parser.add_argument('--seed', default=22, type=int)
+    parser.add_argument('--suffix', default='', type=str)
+    parser.add_argument('--segment_type', default='yoloworld', help='GroundingDINO or yoloworld', type=str)
+    return parser.parse_args()
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    prompts = [args.prompt]*2
+    prompts_tmp = copy.deepcopy(prompts)
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+    main(device, args.segment_type)
\ No newline at end of file
diff --git a/gradio_demo/app_generateOne.py b/gradio_demo/app_generateOne.py
new file mode 100644
index 0000000000000000000000000000000000000000..f716e314740ea0356d85021a261d5a1ee697a218
--- /dev/null
+++ b/gradio_demo/app_generateOne.py
@@ -0,0 +1,529 @@
+import sys
+sys.path.append('./')
+import gradio as gr
+import random
+import numpy as np
+from gradio_demo.character_template import character_man, lorapath_man
+from gradio_demo.character_template import character_woman, lorapath_woman
+from gradio_demo.character_template import styles, lorapath_styles
+import torch
+import os
+from typing import Tuple, List
+import copy
+import argparse
+from diffusers.utils import load_image
+import cv2
+from PIL import Image
+from transformers import DPTFeatureExtractor, DPTForDepthEstimation
+from controlnet_aux import OpenposeDetector
+from controlnet_aux.open_pose.body import Body
+
+try:
+    from inference.models import YOLOWorld
+    from src.efficientvit.models.efficientvit.sam import EfficientViTSamPredictor
+    from src.efficientvit.sam_model_zoo import create_sam_model
+    import supervision as sv
+except Exception:
+    print("YOLO-World could not be loaded")
+
+try:
+    from groundingdino.models import build_model
+    from groundingdino.util import box_ops
+    from groundingdino.util.slconfig import SLConfig
+    from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
+    from groundingdino.util.inference import annotate, predict
+    from segment_anything import build_sam, SamPredictor
+    import groundingdino.datasets.transforms as T
+except Exception:
+    print("GroundingDINO could not be loaded")
+
+from src.pipelines.lora_pipeline import LoraMultiConceptPipeline
+from src.prompt_attention.p2p_attention import AttentionReplace
+from diffusers import ControlNetModel, StableDiffusionXLPipeline
+from src.pipelines.lora_pipeline import revise_regionally_controlnet_forward
+
+CHARACTER_MAN_NAMES = list(character_man.keys())
+CHARACTER_WOMAN_NAMES = list(character_woman.keys())
+STYLE_NAMES = list(styles.keys())
+MAX_SEED = np.iinfo(np.int32).max
+
+### Description
+title = r"""
+<h1 align="center">OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models</h1>
+"""
+
+description = r"""
+<b>Official 🤗 Gradio demo</b> for <a href='https://github.com/' target='_blank'><b>OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models</b></a>.<br>
+
+How to use:<br>
+1. Select two characters.
+2. Enter a text prompt, as you would for a standard text-to-image model.
+3. Click the <b>Submit</b> button to start customizing.
+4. Enjoy the generated image😊!
+"""
+
+article = r"""
+---
+📝 **Citation**
+<br>
+If our work is helpful for your research or applications, please cite us via:
+```bibtex
+@article{,
+title={OMG: Occlusion-friendly Personalized Multi-concept Generation In Diffusion Models},
+author={},
+journal={},
+year={}
+}
+```
+"""
+
+tips = r"""
+### Usage tips of OMG
+1. Enter text prompts that describe a man and a woman.
+"""
+
+css = '''
+.gradio-container {width: 85% !important}
+'''
+
+def sample_image(pipe,
+    input_prompt,
+    input_neg_prompt=None,
+    generator=None,
+    concept_models=None,
+    num_inference_steps=50,
+    guidance_scale=7.5,
+    controller=None,
+    stage=None,
+    region_masks=None,
+    lora_list=None,
+    styleL=None,
+    **extra_kargs
+):
+
+    spatial_condition = extra_kargs.pop('spatial_condition')
+    if spatial_condition is not None:
+        spatial_condition_input = [spatial_condition] * len(input_prompt)
+    else:
+        spatial_condition_input = None
+
+    images = pipe(
+        prompt=input_prompt,
+        concept_models=concept_models,
+        negative_prompt=input_neg_prompt,
+        generator=generator,
+        guidance_scale=guidance_scale,
+        num_inference_steps=num_inference_steps,
+        cross_attention_kwargs={"scale": 0.8},
+        controller=controller,
+        stage=stage,
+        region_masks=region_masks,
+        lora_list=lora_list,
+        styleL=styleL,
+        image=spatial_condition_input,
+        **extra_kargs).images
+
+    return images
+
+def load_image_yoloworld(image_source) -> np.ndarray:
+    image = np.asarray(image_source)
+    return image
+
+def load_image_dino(image_source) -> Tuple[np.ndarray, torch.Tensor]:
+    transform = T.Compose(
+        [
+            T.RandomResize([800], max_size=1333),
+            T.ToTensor(),
+            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+        ]
+    )
+    image = np.asarray(image_source)
+    image_transformed, _ = transform(image_source, None)
+    return image, image_transformed
+
+def predict_mask(segmentmodel, sam, image, TEXT_PROMPT, segmentType, confidence = 0.2, threshold = 0.5):
+    if segmentType=='GroundingDINO':
+        image_source, image = load_image_dino(image)
+        boxes, logits, phrases = predict(
+            model=segmentmodel,
+            image=image,
+            caption=TEXT_PROMPT,
+            box_threshold=0.3,
+            text_threshold=0.25
+        )
+        sam.set_image(image_source)
+        H, W, _ = image_source.shape
+        boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])
+
+        transformed_boxes = sam.transform.apply_boxes_torch(boxes_xyxy, image_source.shape[:2]).cuda()
+        masks, _, _ = sam.predict_torch(
+            point_coords=None,
+            point_labels=None,
+            boxes=transformed_boxes,
+            multimask_output=False,
+        )
+        masks=masks[0].squeeze(0)
+    else:
+        image_source = load_image_yoloworld(image)
+        segmentmodel.set_classes([TEXT_PROMPT])
+        results = segmentmodel.infer(image_source, confidence=confidence)
+        detections = sv.Detections.from_inference(results).with_nms(
+            class_agnostic=True, threshold=threshold
+        )
+        masks = None
+        if len(detections) != 0:
+            print(TEXT_PROMPT + " detected!")
+            sam.set_image(image_source, image_format="RGB")
+            masks, _, _ = sam.predict(box=detections.xyxy[0], multimask_output=False)
+            masks = torch.from_numpy(masks.squeeze())
+
+    return masks
+
+def prepare_text(prompt, region_prompts):
+    '''
+    Args:
+        prompt: the global text prompt
+        region_prompts: [subject1 prompt]-*-[subject1 negative prompt]|[subject2 prompt]-*-[subject2 negative prompt]
+    Returns:
+        (prompt, region_collection): the global prompt together with a list of
+        (region prompt, region negative prompt) pairs, one pair per region
+    '''
+    region_collection = []
+
+    regions = region_prompts.split('|')
+
+    for region in regions:
+        if region == '':
+            break
+        prompt_region, neg_prompt_region = region.split('-*-')
+        prompt_region = prompt_region.replace('[', '').replace(']', '')
+        neg_prompt_region = neg_prompt_region.replace('[', '').replace(']', '')
+
+        region_collection.append((prompt_region, neg_prompt_region))
+    return (prompt, region_collection)
+
+
+def build_model_sd(pretrained_model, controlnet_path, device, prompts):
+    controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16).to(device)
+    pipe = LoraMultiConceptPipeline.from_pretrained(
+        pretrained_model, controlnet=controlnet, torch_dtype=torch.float16, variant="fp16").to(device)
+    controller = AttentionReplace(prompts, 50, cross_replace_steps={"default_": 1.}, self_replace_steps=0.4, tokenizer=pipe.tokenizer, device=device, dtype=torch.float16, width=1024//32, height=1024//32)
+    revise_regionally_controlnet_forward(pipe.unet, controller)
+    pipe_concept = StableDiffusionXLPipeline.from_pretrained(pretrained_model, torch_dtype=torch.float16,
+                                                             variant="fp16").to(device)
+    return pipe, controller, pipe_concept
+
+def build_model_lora(pipe_concept, lora_paths, style_path, condition, args):
+    pipe_list = []
+    if condition == "Human pose":
+        controlnet = ControlNetModel.from_pretrained(args.openpose_checkpoint, torch_dtype=torch.float16).to(device)
+        pipe_concept.controlnet = controlnet
+    elif condition == "Canny Edge":
+        controlnet = ControlNetModel.from_pretrained(args.canny_checkpoint, torch_dtype=torch.float16).to(device)
+        pipe_concept.controlnet = controlnet
+    elif condition == "Depth":
+        controlnet = ControlNetModel.from_pretrained(args.depth_checkpoint, torch_dtype=torch.float16).to(device)
+        pipe_concept.controlnet = controlnet
+
+    if style_path is not None and os.path.exists(style_path):
+        pipe_concept.load_lora_weights(style_path, weight_name="pytorch_lora_weights.safetensors", adapter_name='style')
+
+    for lora_path in lora_paths.split('|'):
+        adapter_name = lora_path.split('/')[-1].split('.')[0]
+        pipe_concept.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name=adapter_name)
+        pipe_concept.enable_xformers_memory_efficient_attention()
+        pipe_list.append(adapter_name)
+    return pipe_list
+
+def build_yolo_segment_model(sam_path, device):
+    yolo_world = YOLOWorld(model_id="yolo_world/l")
+    sam = EfficientViTSamPredictor(
+        create_sam_model(name="xl1", weight_url=sam_path).to(device).eval()
+    )
+    return yolo_world, sam
+
+def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
+    args = SLConfig.fromfile(ckpt_config_filename)
+    model = build_model(args)
+    args.device = device
+
+    checkpoint = torch.load(os.path.join(repo_id, filename), map_location='cpu')
+    log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
+    print("Model loaded from {} \n => {}".format(filename, log))
+    _ = model.eval()
+    return model
+
+def build_dino_segment_model(ckpt_repo_id, sam_checkpoint):
+    ckpt_filename = "groundingdino_swinb_cogcoor.pth"
+    ckpt_config_filename = os.path.join(ckpt_repo_id, "GroundingDINO_SwinB.cfg.py")
+    groundingdino_model = load_model_hf(ckpt_repo_id, ckpt_filename, ckpt_config_filename)
+    sam = build_sam(checkpoint=sam_checkpoint)
+    sam.cuda()
+    sam_predictor = SamPredictor(sam)
+    return groundingdino_model, sam_predictor
+
+
+
+def main(device, segment_type):
+    pipe, controller, pipe_concept = build_model_sd(args.pretrained_sdxl_model, args.openpose_checkpoint, device, prompts_tmp)
+
+    if segment_type == 'GroundingDINO':
+        detect_model, sam = build_dino_segment_model(args.dino_checkpoint, args.sam_checkpoint)
+    else:
+        detect_model, sam = build_yolo_segment_model(args.efficientViT_checkpoint, device)
+
+    resolution_list = ["1440*728",
+                       "1344*768",
+                       "1216*832",
+                       "1152*896",
+                       "1024*1024",
+                       "896*1152",
+                       "832*1216",
+                       "768*1344",
+                       "728*1440"]
+
+    condition_list = ["None",
+                      "Human pose",
+                      "Canny Edge",
+                      "Depth"]
+
+    depth_estimator = DPTForDepthEstimation.from_pretrained(args.dpt_checkpoint).to("cuda")
+    feature_extractor = DPTFeatureExtractor.from_pretrained(args.dpt_checkpoint)
+    body_model = Body(args.pose_detector_checkpoint)
+    openpose = OpenposeDetector(body_model)
+
+    def remove_tips():
+        return gr.update(visible=False)
+
+    def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+        if randomize_seed:
+            seed = random.randint(0, MAX_SEED)
+        return seed
+
+    def get_humanpose(img):
+        openpose_image = openpose(img)
+        return openpose_image
+
+    def get_cannyedge(image):
+        image = np.array(image)
+        image = cv2.Canny(image, 100, 200)
+        image = image[:, :, None]
+        image = np.concatenate([image, image, image], axis=2)
+        canny_image = Image.fromarray(image)
+        return canny_image
+
+    def get_depth(image):
+        image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
+        with torch.no_grad(), torch.autocast("cuda"):
+            depth_map = depth_estimator(image).predicted_depth
+
+        depth_map = torch.nn.functional.interpolate(
+            depth_map.unsqueeze(1),
+            size=(1024, 1024),
+            mode="bicubic",
+            align_corners=False,
+        )
+        depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
+        depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
+        depth_map = (depth_map - depth_min) / (depth_max - depth_min)
+        image = torch.cat([depth_map] * 3, dim=1)
+        image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
+        image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
+        return image
+
+    def generate_image(prompt1, negative_prompt, man, woman, resolution, local_prompt1, local_prompt2, seed, condition, condition_img1, style):
+        try:
+            path1 = lorapath_man[man]
+            path2 = lorapath_woman[woman]
+            pipe_concept.unload_lora_weights()
+            pipe_list = build_model_lora(pipe_concept, path1 + "|" + path2, lorapath_styles[style], condition, args)
+
+            if lorapath_styles[style] is not None and os.path.exists(lorapath_styles[style]):
+                styleL = True
+            else:
+                styleL = False
+
+            input_list = [prompt1]
+            condition_list = [condition_img1]
+            output_list = []
+
+            width, height = int(resolution.split("*")[0]), int(resolution.split("*")[1])
+
+            kwargs = {
+                'height': height,
+                'width': width,
+            }
+
+            for prompt, condition_img in zip(input_list, condition_list):
+                if prompt!='':
+                    input_prompt = []
+                    p = '{prompt}, 35mm photograph, film, professional, 4k, highly detailed.'
+                    if styleL:
+                        p = styles[style] + p
+                    input_prompt.append([p.replace("{prompt}", prompt), p.replace("{prompt}", prompt)])
+                    input_prompt.append([(styles[style] + local_prompt1, character_man.get(man)[1]), (styles[style] + local_prompt2, character_woman.get(woman)[1])])
+
+                    if condition == 'Human pose' and condition_img is not None:
+                        spatial_condition = get_humanpose(condition_img).resize((width, height))
+                    elif condition == 'Canny Edge' and condition_img is not None:
+                        spatial_condition = get_cannyedge(condition_img).resize((width, height))
+                    elif condition == 'Depth' and condition_img is not None:
+                        spatial_condition = get_depth(condition_img).resize((width, height))
+                    else:
+                        spatial_condition = None
+
+                    kwargs['spatial_condition'] = spatial_condition
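+                    # Stage 1: sample the full scene from the merged prompts (no region masks yet).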
+                    controller.reset()
+                    image = sample_image(
+                        pipe,
+                        input_prompt=input_prompt,
+                        concept_models=pipe_concept,
+                        input_neg_prompt=[negative_prompt] * len(input_prompt),
+                        generator=torch.Generator(device).manual_seed(seed),
+                        controller=controller,
+                        stage=1,
+                        lora_list=pipe_list,
+                        styleL=styleL,
+                        **kwargs)
+
+                    controller.reset()
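+                    # Segment the man/woman regions in the stage-1 image; a mask is only
+                    # predicted when the corresponding token appears in the prompt.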
+                    if pipe.tokenizer("man")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
+                        mask1 = predict_mask(detect_model, sam, image[0], 'man', args.segment_type, confidence=0.15,
+                                             threshold=0.5)
+                    else:
+                        mask1 = None
+
+                    if pipe.tokenizer("woman")["input_ids"][1] in pipe.tokenizer(args.prompt)["input_ids"][1:-1]:
+                        mask2 = predict_mask(detect_model, sam, image[0], 'woman', args.segment_type, confidence=0.15,
+                                             threshold=0.5)
+                    else:
+                        mask2 = None
+
+                    if mask1 is None and mask2 is None:
+                        output_list.append(image[1])
+                    else:
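+                        # Stage 2: resample with the predicted region masks and the character LoRA adapters.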
+                        image = sample_image(
+                            pipe,
+                            input_prompt=input_prompt,
+                            concept_models=pipe_concept,
+                            input_neg_prompt=[negative_prompt] * len(input_prompt),
+                            generator=torch.Generator(device).manual_seed(seed),
+                            controller=controller,
+                            stage=2,
+                            region_masks=[mask1, mask2],
+                            lora_list=pipe_list,
+                            styleL=styleL,
+                            **kwargs)
+                        output_list.append(image[1])
+                else:
+                    output_list.append(None)
+            output_list.append(spatial_condition)
+            return output_list
+        except Exception as e:
+            print(f"Generation failed: {e}")
+            return None, None
+
+    def get_local_value_man(input):
+        return character_man[input][0]
+
+    def get_local_value_woman(input):
+        return character_woman[input][0]
+
+
+    with gr.Blocks(css=css) as demo:
+        # description
+        gr.Markdown(title)
+        gr.Markdown(description)
+
+        with gr.Row():
+            gallery = gr.Image(label="Generated Images", height=512, width=512)
+            gen_condition = gr.Image(label="Spatial Condition", height=512, width=512)
+            usage_tips = gr.Markdown(label="Usage tips of OMG", value=tips, visible=False)
+
+        with gr.Row():
+            condition_img1 = gr.Image(label="Input an RGB image as the condition", height=128, width=128)
+
+        # character choose
+        with gr.Row():
+            man = gr.Dropdown(label="Character 1 selection", choices=CHARACTER_MAN_NAMES, value="Harry Potter (identifier: Harry Potter)")
+            woman = gr.Dropdown(label="Character 2 selection", choices=CHARACTER_WOMAN_NAMES, value="Hermione Granger (identifier: Hermione Granger)")
+            resolution = gr.Dropdown(label="Image Resolution (width*height)", choices=resolution_list, value="1024*1024")
+            condition = gr.Dropdown(label="Input condition type", choices=condition_list, value="None")
+            style = gr.Dropdown(label="Style", choices=STYLE_NAMES, value="None")
+
+        with gr.Row():
+            local_prompt1 = gr.Textbox(label="Character1_prompt",
+                                       info="Describe Character 1; this prompt should include the identifier of Character 1",
+                                       value="Close-up photo of the Harry Potter, 35mm photograph, film, professional, 4k, highly detailed.")
+            local_prompt2 = gr.Textbox(label="Character2_prompt",
+                                       info="Describe Character 2; this prompt should include the identifier of Character 2",
+                                       value="Close-up photo of the Hermione Granger, 35mm photograph, film, professional, 4k, highly detailed.")
+
+        man.change(get_local_value_man, man, local_prompt1)
+        woman.change(get_local_value_woman, woman, local_prompt2)
+
+        # prompt
+        with gr.Column():
+            prompt = gr.Textbox(label="Prompt 1",
+                                info="Give a simple prompt to describe the first image content",
+                                placeholder="Required",
+                                value="close-up shot, photography, the cool man and beautiful woman as they accidentally discover a mysterious island while on vacation by the sea, facing the camera smiling")
+
+
+        with gr.Accordion(open=False, label="Advanced Options"):
+            seed = gr.Slider(
+                label="Seed",
+                minimum=0,
+                maximum=MAX_SEED,
+                step=1,
+                value=42,
+            )
+            negative_prompt = gr.Textbox(label="Negative Prompt",
+                                placeholder="noisy, blurry, soft, deformed, ugly",
+                                value="noisy, blurry, soft, deformed, ugly")
+            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+        submit = gr.Button("Submit", variant="primary")
+
+        submit.click(
+            fn=remove_tips,
+            outputs=usage_tips,
+        ).then(
+            fn=randomize_seed_fn,
+            inputs=[seed, randomize_seed],
+            outputs=seed,
+            queue=False,
+            api_name=False,
+        ).then(
+            fn=generate_image,
+            inputs=[prompt, negative_prompt, man, woman, resolution, local_prompt1, local_prompt2, seed, condition, condition_img1, style],
+            outputs=[gallery, gen_condition]
+        )
+    demo.launch(server_name='0.0.0.0', server_port=7861, debug=True)
+
+def parse_args():
+    parser = argparse.ArgumentParser('', add_help=False)
+    parser.add_argument('--pretrained_sdxl_model', default='./checkpoint/stable-diffusion-xl-base-1.0', type=str)
+    parser.add_argument('--openpose_checkpoint', default='./checkpoint/controlnet-openpose-sdxl-1.0', type=str)
+    parser.add_argument('--canny_checkpoint', default='./checkpoint/controlnet-canny-sdxl-1.0', type=str)
+    parser.add_argument('--depth_checkpoint', default='./checkpoint/controlnet-depth-sdxl-1.0', type=str)
+    parser.add_argument('--efficientViT_checkpoint', default='./checkpoint/sam/xl1.pt', type=str)
+    parser.add_argument('--dino_checkpoint', default='./checkpoint/GroundingDINO', type=str)
+    parser.add_argument('--sam_checkpoint', default='./checkpoint/sam/sam_vit_h_4b8939.pth', type=str)
+    parser.add_argument('--dpt_checkpoint', default='./checkpoint/dpt-hybrid-midas', type=str)
+    parser.add_argument('--pose_detector_checkpoint', default='./checkpoint/ControlNet/annotator/ckpts/body_pose_model.pth', type=str)
+    parser.add_argument('--prompt', default='Close-up photo of the cool man and beautiful woman in surprised expressions as they accidentally discover a mysterious island while on vacation by the sea, 35mm photograph, film, professional, 4k, highly detailed.', type=str)
+    parser.add_argument('--negative_prompt', default='noisy, blurry, soft, deformed, ugly', type=str)
+    parser.add_argument('--seed', default=22, type=int)
+    parser.add_argument('--suffix', default='', type=str)
+    parser.add_argument('--segment_type', default='yoloworld', help='GroundingDINO or yoloworld', type=str)
+    return parser.parse_args()
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    prompts = [args.prompt]*2
+    prompts_tmp = copy.deepcopy(prompts)
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+    main(device, args.segment_type)
\ No newline at end of file
diff --git a/gradio_demo/character_template.py b/gradio_demo/character_template.py
new file mode 100644
index 0000000000000000000000000000000000000000..07bb4873a325f1fe31eee8fa383298dbb3585623
--- /dev/null
+++ b/gradio_demo/character_template.py
@@ -0,0 +1,62 @@
+character_list_man = [
+    {
+        "name": "Harry Potter (identifier: Harry Potter)",
+        "prompt": "Close-up photo of the Harry Potter, 35mm photograph, film, professional, 4k, highly detailed.",
+        "negative_prompt": "noisy, blurry, soft, deformed, ugly",
+        "path": "./checkpoint/lora/Harry_Potter.safetensors",
+    },
+    {
+        "name": "Chris Evans (identifier: Chris Evans)",
+        "prompt": "Close-up photo of the Chris Evans, 35mm photograph, film, professional, 4k, highly detailed.",
+        "negative_prompt": "noisy, blurry, soft, deformed, ugly",
+        "path": "./checkpoint/lora/chris-evans.safetensors",
+    },
+    {
+        "name": "Jordan Torres (identifier: jordan_torres)",
+        "prompt": "Close-up photo of the jordan_torres man, 35mm photograph, film, professional, 4k, highly detailed.",
+        "negative_prompt": "noisy, blurry, soft, deformed, ugly",
+        "path": "./checkpoint/lora/jordan_torres_v2_xl.safetensors",
+    },
+]
+
+character_list_woman = [
+    {
+        "name": "Hermione Granger (identifier: Hermione Granger)",
+        "prompt": "Close-up photo of the Hermione Granger, 35mm photograph, film, professional, 4k, highly detailed.",
+        "negative_prompt": "noisy, blurry, soft, deformed, ugly",
+        "path": "./checkpoint/lora/Hermione_Granger.safetensors",
+    },
+    {
+        "name": "Taylor Swift (identifier: TaylorSwift)",
+        "prompt": "Close-up photo of the TaylorSwift, 35mm photograph, film, professional, 4k, highly detailed.",
+        "negative_prompt": "noisy, blurry, soft, deformed, ugly",
+        "path": "./checkpoint/lora/TaylorSwiftSDXL.safetensors",
+    },
+    {
+        "name": "Keira Knightley (identifier: ohwx woman)",
+        "prompt": "Close-up photo of the ohwx woman, 35mm photograph, film, professional, 4k, highly detailed.",
+        "negative_prompt": "noisy, blurry, soft, deformed, ugly",
+        "path": "./checkpoint/lora/keira_lora_sdxl_v1-000008.safetensors",
+    },
+]
+
+style_list = [
+    {
+        "name": "None",
+        "prompt": "",
+        "path": "",
+    },
+    {
+        "name": "Anime sketch style",
+        "prompt": "Pencil_Sketch:1.2, messy lines, greyscale, traditional media, sketch, ",
+        "path": "./checkpoint/style/Anime_Sketch_SDXL.safetensors",
+    }
+]
+
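+# Index the templates by display name so the Gradio demos can look up each
+# character's prompt, negative prompt and LoRA path from the dropdown value.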
+character_man = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in character_list_man}
+character_woman = {k["name"]: (k["prompt"], k["negative_prompt"]) for k in character_list_woman}
+styles = {k["name"]: (k["prompt"]) for k in style_list}
+
+lorapath_man = {k["name"]: (k["path"]) for k in character_list_man}
+lorapath_woman = {k["name"]: (k["path"]) for k in character_list_woman}
+lorapath_styles = {k["name"]: (k["path"]) for k in style_list}
\ No newline at end of file
diff --git a/src/efficientvit/__init__.py b/src/efficientvit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/efficientvit/apps/__init__.py b/src/efficientvit/apps/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/efficientvit/apps/data_provider/__init__.py b/src/efficientvit/apps/data_provider/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c9a5dfa34097fdf24730a203a9f24c5c4ac0a74
--- /dev/null
+++ b/src/efficientvit/apps/data_provider/__init__.py
@@ -0,0 +1,7 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+from .augment import *
+from .base import *
+from .random_resolution import *
diff --git a/src/efficientvit/apps/data_provider/augment/__init__.py b/src/efficientvit/apps/data_provider/augment/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9ea4d65f7f5a471cc433fbd68a58d4853b217d2
--- /dev/null
+++ b/src/efficientvit/apps/data_provider/augment/__init__.py
@@ -0,0 +1,6 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+from .bbox import *
+from .color_aug import *
diff --git a/src/efficientvit/apps/data_provider/augment/bbox.py b/src/efficientvit/apps/data_provider/augment/bbox.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9f089a3f70881313a5ce4308d1f74fbf1fa0c31
--- /dev/null
+++ b/src/efficientvit/apps/data_provider/augment/bbox.py
@@ -0,0 +1,30 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import numpy as np
+
+__all__ = ["rand_bbox"]
+
+
+def rand_bbox(
+    h: int,
+    w: int,
+    lam: float,
+    rand_func: callable = np.random.uniform,
+) -> tuple[int, int, int, int]:
+    """randomly sample bbox, used in cutmix"""
+    cut_rat = np.sqrt(1.0 - lam)
+    cut_w = w * cut_rat
+    cut_h = h * cut_rat
+
+    # uniform
+    cx = rand_func(0, w)
+    cy = rand_func(0, h)
+
+    bbx1 = int(np.clip(cx - cut_w / 2, 0, w))
+    bby1 = int(np.clip(cy - cut_h / 2, 0, h))
+    bbx2 = int(np.clip(cx + cut_w / 2, 0, w))
+    bby2 = int(np.clip(cy + cut_h / 2, 0, h))
+
+    return bbx1, bby1, bbx2, bby2
diff --git a/src/efficientvit/apps/data_provider/augment/color_aug.py b/src/efficientvit/apps/data_provider/augment/color_aug.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5e1dcc6998374738c300414b06e4fdb2ed8af95
--- /dev/null
+++ b/src/efficientvit/apps/data_provider/augment/color_aug.py
@@ -0,0 +1,84 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import numpy as np
+import torchvision.transforms as transforms
+from PIL import Image
+from timm.data.auto_augment import rand_augment_transform
+
+__all__ = ["ColorAug", "RandAug"]
+
+
+class ImageAug:
+    def aug_image(self, image: Image.Image) -> Image.Image:
+        raise NotImplementedError
+
+    def __call__(
+        self, feed_dict: dict or np.ndarray or Image.Image
+    ) -> dict or np.ndarray or Image.Image:
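+        # Accepts a feed dict (augments feed_dict[self.key]), a numpy array, or
+        # a PIL image; numpy inputs are converted to PIL for the augmentation
+        # and converted back before returning.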
+        if isinstance(feed_dict, dict):
+            output_dict = feed_dict
+            image = feed_dict[self.key]
+        else:
+            output_dict = None
+            image = feed_dict
+        is_ndarray = isinstance(image, np.ndarray)
+        if is_ndarray:
+            image = Image.fromarray(image)
+
+        image = self.aug_image(image)
+
+        if is_ndarray:
+            image = np.array(image)
+
+        if output_dict is None:
+            return image
+        else:
+            output_dict[self.key] = image
+            return output_dict
+
+
+class ColorAug(transforms.ColorJitter, ImageAug):
+    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, key="data"):
+        super().__init__(
+            brightness=brightness,
+            contrast=contrast,
+            saturation=saturation,
+            hue=hue,
+        )
+        self.key = key
+
+    def aug_image(self, image: Image.Image) -> Image.Image:
+        return transforms.ColorJitter.forward(self, image)
+
+    def forward(
+        self, feed_dict: dict or np.ndarray or Image.Image
+    ) -> dict or np.ndarray or Image.Image:
+        return ImageAug.__call__(self, feed_dict)
+
+
+class RandAug(ImageAug):
+    def __init__(
+        self, config: dict[str, any], mean: tuple[float, float, float], key="data"
+    ):
+        n = config.get("n", 2)
+        m = config.get("m", 9)
+        mstd = config.get("mstd", 1.0)
+        inc = config.get("inc", 1)
+        tpct = config.get("tpct", 0.45)
+        config_str = f"rand-n{n}-m{m}-mstd{mstd}-inc{inc}"
+
+        aa_params = dict(
+            translate_pct=tpct,
+            img_mean=tuple([min(255, round(255 * x)) for x in mean]),
+            interpolation=Image.BICUBIC,
+        )
+        self.aug_op = rand_augment_transform(config_str, aa_params)
+        self.key = key
+
+    def aug_image(self, image: Image.Image) -> Image.Image:
+        return self.aug_op(image)
+
+    def __repr__(self):
+        return self.aug_op.__repr__()
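+
+
+# A small usage sketch (not part of the original module): both augmenters accept
+# either a raw PIL.Image / ndarray or a dict keyed by `key` (default "data").
+# The jitter strengths and the RandAugment config below are illustrative values only.
+if __name__ == "__main__":
+    dummy = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
+    color_aug = ColorAug(brightness=0.4, contrast=0.4, saturation=0.4)
+    rand_aug = RandAug({"n": 2, "m": 9}, mean=(0.485, 0.456, 0.406))
+    out = rand_aug(color_aug({"data": dummy}))  # dict in -> dict out
+    print(type(out["data"]), out["data"].size)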
diff --git a/src/efficientvit/apps/data_provider/base.py b/src/efficientvit/apps/data_provider/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f57679f681a6c95c48bd66216e48241c391f209
--- /dev/null
+++ b/src/efficientvit/apps/data_provider/base.py
@@ -0,0 +1,223 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import copy
+import warnings
+
+import torch.utils.data
+from torch.utils.data.distributed import DistributedSampler
+
+from src.efficientvit.apps.data_provider.random_resolution import RRSController
+from src.efficientvit.models.utils import val2tuple
+
+__all__ = ["parse_image_size", "random_drop_data", "DataProvider"]
+
+
+def parse_image_size(size: int or str) -> tuple[int, int]:
+    if isinstance(size, str):
+        size = [int(val) for val in size.split("-")]
+        return size[0], size[1]
+    else:
+        return val2tuple(size, 2)
+
+
+def random_drop_data(dataset, drop_size: int, seed: int, keys=("samples",)):
+    g = torch.Generator()
+    g.manual_seed(seed)  # set random seed before sampling validation set
+    rand_indexes = torch.randperm(len(dataset), generator=g).tolist()
+
+    dropped_indexes = rand_indexes[:drop_size]
+    remaining_indexes = rand_indexes[drop_size:]
+
+    dropped_dataset = copy.deepcopy(dataset)
+    for key in keys:
+        setattr(
+            dropped_dataset,
+            key,
+            [getattr(dropped_dataset, key)[idx] for idx in dropped_indexes],
+        )
+        setattr(dataset, key, [getattr(dataset, key)[idx] for idx in remaining_indexes])
+    return dataset, dropped_dataset
+
+
+class DataProvider:
+    data_keys = ("samples",)
+    mean_std = {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}
+    SUB_SEED = 937162211  # random seed for sampling subset
+    VALID_SEED = 2147483647  # random seed for the validation set
+
+    name: str
+
+    def __init__(
+        self,
+        train_batch_size: int,
+        test_batch_size: int or None,
+        valid_size: int or float or None,
+        n_worker: int,
+        image_size: int or list[int] or str or list[str],
+        num_replicas: int or None = None,
+        rank: int or None = None,
+        train_ratio: float or None = None,
+        drop_last: bool = False,
+    ):
+        warnings.filterwarnings("ignore")
+        super().__init__()
+
+        # batch_size & valid_size
+        self.train_batch_size = train_batch_size
+        self.test_batch_size = test_batch_size or self.train_batch_size
+        self.valid_size = valid_size
+
+        # image size
+        if isinstance(image_size, list):
+            self.image_size = [parse_image_size(size) for size in image_size]
+            self.image_size.sort()  # e.g., 160 -> 224
+            RRSController.IMAGE_SIZE_LIST = copy.deepcopy(self.image_size)
+            self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size[-1]
+        else:
+            self.image_size = parse_image_size(image_size)
+            RRSController.IMAGE_SIZE_LIST = [self.image_size]
+            self.active_image_size = RRSController.ACTIVE_SIZE = self.image_size
+
+        # distributed configs
+        self.num_replicas = num_replicas
+        self.rank = rank
+
+        # build datasets
+        train_dataset, val_dataset, test_dataset = self.build_datasets()
+
+        if train_ratio is not None and train_ratio < 1.0:
+            assert 0 < train_ratio < 1
+            _, train_dataset = random_drop_data(
+                train_dataset,
+                int(train_ratio * len(train_dataset)),
+                self.SUB_SEED,
+                self.data_keys,
+            )
+
+        # build data loader
+        self.train = self.build_dataloader(
+            train_dataset, train_batch_size, n_worker, drop_last=drop_last, train=True
+        )
+        self.valid = self.build_dataloader(
+            val_dataset, test_batch_size, n_worker, drop_last=False, train=False
+        )
+        self.test = self.build_dataloader(
+            test_dataset, test_batch_size, n_worker, drop_last=False, train=False
+        )
+        if self.valid is None:
+            self.valid = self.test
+        self.sub_train = None
+
+    @property
+    def data_shape(self) -> tuple[int, ...]:
+        return 3, self.active_image_size[0], self.active_image_size[1]
+
+    def build_valid_transform(self, image_size: tuple[int, int] or None = None) -> any:
+        raise NotImplementedError
+
+    def build_train_transform(self, image_size: tuple[int, int] or None = None) -> any:
+        raise NotImplementedError
+
+    def build_datasets(self) -> tuple[any, any, any]:
+        raise NotImplementedError
+
+    def build_dataloader(
+        self,
+        dataset: any or None,
+        batch_size: int,
+        n_worker: int,
+        drop_last: bool,
+        train: bool,
+    ):
+        if dataset is None:
+            return None
+        if isinstance(self.image_size, list) and train:
+            from src.efficientvit.apps.data_provider.random_resolution._data_loader import \
+                RRSDataLoader
+
+            dataloader_class = RRSDataLoader
+        else:
+            dataloader_class = torch.utils.data.DataLoader
+        if self.num_replicas is None:
+            return dataloader_class(
+                dataset=dataset,
+                batch_size=batch_size,
+                shuffle=True,
+                num_workers=n_worker,
+                pin_memory=True,
+                drop_last=drop_last,
+            )
+        else:
+            sampler = DistributedSampler(dataset, self.num_replicas, self.rank)
+            return dataloader_class(
+                dataset=dataset,
+                batch_size=batch_size,
+                sampler=sampler,
+                num_workers=n_worker,
+                pin_memory=True,
+                drop_last=drop_last,
+            )
+
+    def set_epoch(self, epoch: int) -> None:
+        RRSController.set_epoch(epoch, len(self.train))
+        if isinstance(self.train.sampler, DistributedSampler):
+            self.train.sampler.set_epoch(epoch)
+
+    def assign_active_image_size(self, new_size: int or tuple[int, int]) -> None:
+        self.active_image_size = val2tuple(new_size, 2)
+        new_transform = self.build_valid_transform(self.active_image_size)
+        # change the transform of the valid and test set
+        self.valid.dataset.transform = self.test.dataset.transform = new_transform
+
+    def sample_val_dataset(self, train_dataset, valid_transform) -> tuple[any, any]:
+        if self.valid_size is not None:
+            if 0 < self.valid_size < 1:
+                valid_size = int(self.valid_size * len(train_dataset))
+            else:
+                assert self.valid_size >= 1
+                valid_size = int(self.valid_size)
+            train_dataset, val_dataset = random_drop_data(
+                train_dataset,
+                valid_size,
+                self.VALID_SEED,
+                self.data_keys,
+            )
+            val_dataset.transform = valid_transform
+        else:
+            val_dataset = None
+        return train_dataset, val_dataset
+
+    def build_sub_train_loader(self, n_samples: int, batch_size: int) -> any:
+        # used for resetting BN running statistics
+        if self.sub_train is None:
+            self.sub_train = {}
+        if self.active_image_size in self.sub_train:
+            return self.sub_train[self.active_image_size]
+
+        # construct dataset and dataloader
+        train_dataset = copy.deepcopy(self.train.dataset)
+        if n_samples < len(train_dataset):
+            _, train_dataset = random_drop_data(
+                train_dataset,
+                n_samples,
+                self.SUB_SEED,
+                self.data_keys,
+            )
+        RRSController.ACTIVE_SIZE = self.active_image_size
+        train_dataset.transform = self.build_train_transform(
+            image_size=self.active_image_size
+        )
+        data_loader = self.build_dataloader(
+            train_dataset, batch_size, self.train.num_workers, drop_last=True, train=False
+        )
+
+        # pre-fetch data
+        self.sub_train[self.active_image_size] = [
+            data
+            for data in data_loader
+            for _ in range(max(1, n_samples // len(train_dataset)))
+        ]
+
+        return self.sub_train[self.active_image_size]
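+
+
+# A hedged sanity-check sketch (not part of the original module) for the two
+# module-level helpers. `_ListDataset` is a hypothetical stand-in exposing the
+# "samples" attribute that `random_drop_data` rewrites on both copies.
+if __name__ == "__main__":
+    print(parse_image_size("160-224"))  # (160, 224)
+    print(parse_image_size(224))        # scalar sizes go through val2tuple
+
+    class _ListDataset:
+        def __init__(self, n: int):
+            self.samples = list(range(n))
+
+        def __len__(self) -> int:
+            return len(self.samples)
+
+    remaining, dropped = random_drop_data(_ListDataset(10), drop_size=3, seed=0)
+    print(len(remaining.samples), len(dropped.samples))  # 7 3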
diff --git a/src/efficientvit/apps/data_provider/random_resolution/__init__.py b/src/efficientvit/apps/data_provider/random_resolution/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b831fa9d3e933e76cf78120947143e8a19133ea2
--- /dev/null
+++ b/src/efficientvit/apps/data_provider/random_resolution/__init__.py
@@ -0,0 +1,7 @@
+"""Random resolution data loader compatible with multi-processing and distributed training.
+
+Replace PyTorch's DataLoader with RRSDataLoader to support random resolution
+at training time; resolution sampling is controlled by RRSController.
+"""
+
+from .controller import *
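+
+# A hedged usage sketch (illustrative only): swap the stock DataLoader for
+# RRSDataLoader in the training loop and advance the sampled resolution once
+# per epoch through RRSController; `train_set` and `num_epochs` are placeholder
+# names, and the set_epoch call mirrors what DataProvider.set_epoch does.
+#
+#     from src.efficientvit.apps.data_provider.random_resolution._data_loader import RRSDataLoader
+#
+#     loader = RRSDataLoader(train_set, batch_size=64, shuffle=True, num_workers=8)
+#     for epoch in range(num_epochs):
+#         RRSController.set_epoch(epoch, len(loader))
+#         for images, labels in loader:
+#             ...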
diff --git a/src/efficientvit/apps/data_provider/random_resolution/_data_loader.py b/src/efficientvit/apps/data_provider/random_resolution/_data_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..df06e2b95a4468cbeecc6b56c6f97deb8c35ff2c
--- /dev/null
+++ b/src/efficientvit/apps/data_provider/random_resolution/_data_loader.py
@@ -0,0 +1,1598 @@
+r"""This file is based on torch/utils/data/data_loader.py
+
+Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter
+
+To support these two classes, in `./_utils` we define many utility methods and
+functions to be run in multiprocessing. E.g., the data loading worker loop is
+in `./_utils/worker.py`.
+"""
+
+import functools
+import itertools
+import logging
+import multiprocessing as python_multiprocessing
+import os
+import queue
+import threading
+import warnings
+from typing import (Any, Callable, Generic, Iterable, List, Optional, Sequence,
+                    TypeVar, Union)
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as multiprocessing
+import torch.utils.data.graph_settings
+from torch._utils import ExceptionWrapper
+from torch.utils.data import (BatchSampler, Dataset, IterableDataset,
+                              IterDataPipe, MapDataPipe, RandomSampler,
+                              Sampler, SequentialSampler, _utils)
+from torch.utils.data.datapipes.datapipe import (
+    _IterDataPipeSerializationWrapper, _MapDataPipeSerializationWrapper)
+
+from ._data_worker import _worker_loop
+
+__all__ = ["RRSDataLoader"]
+
+T_co = TypeVar("T_co", covariant=True)
+T = TypeVar("T")
+_worker_init_fn_t = Callable[[int], None]
+
+# Ideally we would parameterize `DataLoader` by the return type of `collate_fn`, but there is currently no way to have that
+# type parameter set to a default value if the user doesn't pass in a custom 'collate_fn'.
+# See https://github.com/python/mypy/issues/3737.
+_collate_fn_t = Callable[[List[T]], Any]
+
+
+# These functions used to be defined in this file. However, it was moved to
+# _utils/collate.py. Although it is rather hard to access this from user land
+# (one has to explicitly directly `import torch.utils.data.dataloader`), there
+# probably is user code out there using it. This aliasing maintains BC in this
+# aspect.
+default_collate: _collate_fn_t = _utils.collate.default_collate
+default_convert = _utils.collate.default_convert
+
+get_worker_info = _utils.worker.get_worker_info
+
+logger = logging.getLogger(__name__)
+
+
+class _DatasetKind:
+    Map = 0
+    Iterable = 1
+
+    @staticmethod
+    def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
+        if kind == _DatasetKind.Map:
+            return _utils.fetch._MapDatasetFetcher(
+                dataset, auto_collation, collate_fn, drop_last
+            )
+        else:
+            return _utils.fetch._IterableDatasetFetcher(
+                dataset, auto_collation, collate_fn, drop_last
+            )
+
+
+class _InfiniteConstantSampler(Sampler):
+    r"""Analogous to ``itertools.repeat(None, None)``.
+    Used as sampler for :class:`~torch.utils.data.IterableDataset`.
+
+    Args:
+        data_source (Dataset): dataset to sample from
+    """
+
+    def __init__(self):
+        super().__init__(None)
+
+    def __iter__(self):
+        while True:
+            yield None
+
+
+def _get_distributed_settings():
+    if dist.is_available() and dist.is_initialized():
+        return dist.get_world_size(), dist.get_rank()
+    else:
+        return 1, 0
+
+
+def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id):
+    global_worker_id = worker_id
+    info = torch.utils.data.get_worker_info()
+    assert info is not None
+    total_workers = info.num_workers
+    datapipe = info.dataset
+    assert isinstance(datapipe, (IterDataPipe, MapDataPipe))
+    # To distribute elements across distributed processes evenly, we should shard data on distributed
+    # processes first, then shard on worker processes
+    total_workers *= world_size
+    global_worker_id = global_worker_id * world_size + rank_id
+    # For BC, use default SHARDING_PRIORITIES
+    torch.utils.data.graph_settings.apply_sharding(
+        datapipe, total_workers, global_worker_id
+    )
+    if worker_init_fn is not None:
+        worker_init_fn(worker_id)
+
+
+def _share_dist_seed(generator, pg):
+    _shared_seed = torch.empty((), dtype=torch.int64).random_(generator=generator)
+    if isinstance(pg, dist.ProcessGroup):
+        dist.broadcast(_shared_seed, src=0, group=pg)
+    return _shared_seed.item()
+
+
+class RRSDataLoader(Generic[T_co]):
+    r"""
+    Data loader. Combines a dataset and a sampler, and provides an iterable over
+    the given dataset.
+
+    The :class:`~torch.utils.data.DataLoader` supports both map-style and
+    iterable-style datasets with single- or multi-process loading, customizing
+    loading order and optional automatic batching (collation) and memory pinning.
+
+    See :py:mod:`torch.utils.data` documentation page for more details.
+
+    Args:
+        dataset (Dataset): dataset from which to load the data.
+        batch_size (int, optional): how many samples per batch to load
+            (default: ``1``).
+        shuffle (bool, optional): set to ``True`` to have the data reshuffled
+            at every epoch (default: ``False``).
+        sampler (Sampler or Iterable, optional): defines the strategy to draw
+            samples from the dataset. Can be any ``Iterable`` with ``__len__``
+            implemented. If specified, :attr:`shuffle` must not be specified.
+        batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
+            returns a batch of indices at a time. Mutually exclusive with
+            :attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
+            and :attr:`drop_last`.
+        num_workers (int, optional): how many subprocesses to use for data
+            loading. ``0`` means that the data will be loaded in the main process.
+            (default: ``0``)
+        collate_fn (Callable, optional): merges a list of samples to form a
+            mini-batch of Tensor(s).  Used when using batched loading from a
+            map-style dataset.
+        pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
+            into device/CUDA pinned memory before returning them.  If your data elements
+            are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
+            see the example below.
+        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
+            if the dataset size is not divisible by the batch size. If ``False`` and
+            the size of dataset is not divisible by the batch size, then the last batch
+            will be smaller. (default: ``False``)
+        timeout (numeric, optional): if positive, the timeout value for collecting a batch
+            from workers. Should always be non-negative. (default: ``0``)
+        worker_init_fn (Callable, optional): If not ``None``, this will be called on each
+            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
+            input, after seeding and before data loading. (default: ``None``)
+        generator (torch.Generator, optional): If not ``None``, this RNG will be used
+            by RandomSampler to generate random indexes and multiprocessing to generate
+            `base_seed` for workers. (default: ``None``)
+        prefetch_factor (int, optional, keyword-only arg): Number of batches loaded
+            in advance by each worker. ``2`` means there will be a total of
+            2 * num_workers batches prefetched across all workers. (default value depends
+            on the set value for num_workers. If value of num_workers=0 default is ``None``.
+            Otherwise if value of num_workers>0 default is ``2``).
+        persistent_workers (bool, optional): If ``True``, the data loader will not shut down
+            the worker processes after a dataset has been consumed once. This keeps
+            the worker `Dataset` instances alive. (default: ``False``)
+        pin_memory_device (str, optional): the data loader will copy Tensors
+            into device pinned memory before returning them if pin_memory is set to true.
+
+
+    .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
+                 cannot be an unpicklable object, e.g., a lambda function. See
+                 :ref:`multiprocessing-best-practices` on more details related
+                 to multiprocessing in PyTorch.
+
+    .. warning:: ``len(dataloader)`` heuristic is based on the length of the sampler used.
+                 When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
+                 it instead returns an estimate based on ``len(dataset) / batch_size``, with proper
+                 rounding depending on :attr:`drop_last`, regardless of multi-process loading
+                 configurations. This represents the best guess PyTorch can make because PyTorch
+                 trusts user :attr:`dataset` code in correctly handling multi-process
+                 loading to avoid duplicate data.
+
+                 However, if sharding results in multiple workers having incomplete last batches,
+                 this estimate can still be inaccurate, because (1) an otherwise complete batch can
+                 be broken into multiple ones and (2) more than one batch worth of samples can be
+                 dropped when :attr:`drop_last` is set. Unfortunately, PyTorch can not detect such
+                 cases in general.
+
+                 See `Dataset Types`_ for more details on these two types of datasets and how
+                 :class:`~torch.utils.data.IterableDataset` interacts with
+                 `Multi-process data loading`_.
+
+    .. warning:: See :ref:`reproducibility`, :ref:`dataloader-workers-random-seed`, and
+                 :ref:`data-loading-randomness` notes for random seed related questions.
+    """
+
+    dataset: Dataset[T_co]
+    batch_size: Optional[int]
+    num_workers: int
+    pin_memory: bool
+    drop_last: bool
+    timeout: float
+    sampler: Union[Sampler, Iterable]
+    pin_memory_device: str
+    prefetch_factor: Optional[int]
+    _iterator: Optional["_BaseDataLoaderIter"]
+    __initialized = False
+
+    def __init__(
+        self,
+        dataset: Dataset[T_co],
+        batch_size: Optional[int] = 1,
+        shuffle: Optional[bool] = None,
+        sampler: Union[Sampler, Iterable, None] = None,
+        batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None,
+        num_workers: int = 0,
+        collate_fn: Optional[_collate_fn_t] = None,
+        pin_memory: bool = False,
+        drop_last: bool = False,
+        timeout: float = 0,
+        worker_init_fn: Optional[_worker_init_fn_t] = None,
+        multiprocessing_context=None,
+        generator=None,
+        *,
+        prefetch_factor: Optional[int] = None,
+        persistent_workers: bool = False,
+        pin_memory_device: str = ""
+    ):
+        torch._C._log_api_usage_once("python.data_loader")
+
+        if num_workers < 0:
+            raise ValueError(
+                "num_workers option should be non-negative; "
+                "use num_workers=0 to disable multiprocessing."
+            )
+
+        if timeout < 0:
+            raise ValueError("timeout option should be non-negative")
+
+        if num_workers == 0 and prefetch_factor is not None:
+            raise ValueError(
+                "prefetch_factor option could only be specified in multiprocessing."
+                "let num_workers > 0 to enable multiprocessing, otherwise set prefetch_factor to None."
+            )
+        elif num_workers > 0 and prefetch_factor is None:
+            prefetch_factor = 2
+        elif prefetch_factor is not None and prefetch_factor < 0:
+            raise ValueError("prefetch_factor option should be non-negative")
+
+        if persistent_workers and num_workers == 0:
+            raise ValueError("persistent_workers option needs num_workers > 0")
+
+        self.dataset = dataset
+        self.num_workers = num_workers
+        self.prefetch_factor = prefetch_factor
+        self.pin_memory = pin_memory
+        self.pin_memory_device = pin_memory_device
+        self.timeout = timeout
+        self.worker_init_fn = worker_init_fn
+        self.multiprocessing_context = multiprocessing_context
+
+        # Adds forward compatibilities so classic DataLoader can work with DataPipes:
+        #   _DataPipeSerializationWrapper container makes it easier to serialize without redefining pickler
+        if isinstance(self.dataset, IterDataPipe):
+            self.dataset = _IterDataPipeSerializationWrapper(self.dataset)
+        elif isinstance(self.dataset, MapDataPipe):
+            self.dataset = _MapDataPipeSerializationWrapper(self.dataset)
+
+        # Arg-check dataset related before checking samplers because we want to
+        # tell users that iterable-style datasets are incompatible with custom
+        # samplers first, so that they don't learn that this combo doesn't work
+        # after spending time fixing the custom sampler errors.
+        if isinstance(dataset, IterableDataset):
+            self._dataset_kind = _DatasetKind.Iterable
+            # NOTE [ Custom Samplers and IterableDataset ]
+            #
+            # `IterableDataset` does not support custom `batch_sampler` or
+            # `sampler` since the key is irrelevant (unless we support
+            # generator-style dataset one day...).
+            #
+            # For `sampler`, we always create a dummy sampler. This is an
+            # infinite sampler even when the dataset may have an implemented
+            # finite `__len__` because in multi-process data loading, naive
+            # settings will return duplicated data (which may be desired), and
+            # thus using a sampler with length matching that of dataset will
+            # cause data loss (you may have duplicates of the first couple
+            # batches, but never see anything afterwards). Therefore,
+            # `IterableDataset` always uses an infinite sampler, an instance of
+            # `_InfiniteConstantSampler` defined above.
+            #
+            # A custom `batch_sampler` essentially only controls the batch size.
+            # However, it is unclear how useful it would be since an iterable-style
+            # dataset can handle that within itself. Moreover, it is pointless
+            # in multi-process data loading as the assignment order of batches
+            # to workers is an implementation detail so users can not control
+            # how to batchify each worker's iterable. Thus, we disable this
+            # option. If this turns out to be useful in future, we can re-enable
+            # this, and support custom samplers that specify the assignments to
+            # specific workers.
+            if isinstance(dataset, IterDataPipe):
+                if shuffle is not None:
+                    dataset = torch.utils.data.graph_settings.apply_shuffle_settings(
+                        dataset, shuffle=shuffle
+                    )
+            # We cannot check `shuffle is not None` here, since previously `shuffle=False` was the default.
+            elif shuffle not in {False, None}:
+                raise ValueError(
+                    "DataLoader with IterableDataset: expected unspecified "
+                    "shuffle option, but got shuffle={}".format(shuffle)
+                )
+
+            if sampler is not None:
+                # See NOTE [ Custom Samplers and IterableDataset ]
+                raise ValueError(
+                    "DataLoader with IterableDataset: expected unspecified "
+                    "sampler option, but got sampler={}".format(sampler)
+                )
+            elif batch_sampler is not None:
+                # See NOTE [ Custom Samplers and IterableDataset ]
+                raise ValueError(
+                    "DataLoader with IterableDataset: expected unspecified "
+                    "batch_sampler option, but got batch_sampler={}".format(
+                        batch_sampler
+                    )
+                )
+        else:
+            shuffle = bool(shuffle)
+            self._dataset_kind = _DatasetKind.Map
+
+        if sampler is not None and shuffle:
+            raise ValueError("sampler option is mutually exclusive with " "shuffle")
+
+        if batch_sampler is not None:
+            # auto_collation with custom batch_sampler
+            if batch_size != 1 or shuffle or sampler is not None or drop_last:
+                raise ValueError(
+                    "batch_sampler option is mutually exclusive "
+                    "with batch_size, shuffle, sampler, and "
+                    "drop_last"
+                )
+            batch_size = None
+            drop_last = False
+        elif batch_size is None:
+            # no auto_collation
+            if drop_last:
+                raise ValueError(
+                    "batch_size=None option disables auto-batching "
+                    "and is mutually exclusive with drop_last"
+                )
+
+        if sampler is None:  # give default samplers
+            if self._dataset_kind == _DatasetKind.Iterable:
+                # See NOTE [ Custom Samplers and IterableDataset ]
+                sampler = _InfiniteConstantSampler()
+            else:  # map-style
+                if shuffle:
+                    sampler = RandomSampler(dataset, generator=generator)  # type: ignore[arg-type]
+                else:
+                    sampler = SequentialSampler(dataset)  # type: ignore[arg-type]
+
+        if batch_size is not None and batch_sampler is None:
+            # auto_collation without custom batch_sampler
+            batch_sampler = BatchSampler(sampler, batch_size, drop_last)
+
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+        self.sampler = sampler
+        self.batch_sampler = batch_sampler
+        self.generator = generator
+
+        if collate_fn is None:
+            if self._auto_collation:
+                collate_fn = _utils.collate.default_collate
+            else:
+                collate_fn = _utils.collate.default_convert
+
+        self.collate_fn = collate_fn
+        self.persistent_workers = persistent_workers
+
+        self.__initialized = True
+        self._IterableDataset_len_called = (
+            None  # See NOTE [ IterableDataset and __len__ ]
+        )
+
+        self._iterator = None
+
+        self.check_worker_number_rationality()
+
+        torch.set_vital("Dataloader", "enabled", "True")  # type: ignore[attr-defined]
+
+    def _get_iterator(self) -> "_BaseDataLoaderIter":
+        if self.num_workers == 0:
+            return _SingleProcessDataLoaderIter(self)
+        else:
+            self.check_worker_number_rationality()
+            return _MultiProcessingDataLoaderIter(self)
+
+    @property
+    def multiprocessing_context(self):
+        return self.__multiprocessing_context
+
+    @multiprocessing_context.setter
+    def multiprocessing_context(self, multiprocessing_context):
+        if multiprocessing_context is not None:
+            if self.num_workers > 0:
+                if isinstance(multiprocessing_context, str):
+                    valid_start_methods = multiprocessing.get_all_start_methods()
+                    if multiprocessing_context not in valid_start_methods:
+                        raise ValueError(
+                            (
+                                "multiprocessing_context option "
+                                "should specify a valid start method in {!r}, but got "
+                                "multiprocessing_context={!r}"
+                            ).format(valid_start_methods, multiprocessing_context)
+                        )
+                    multiprocessing_context = multiprocessing.get_context(
+                        multiprocessing_context
+                    )
+
+                if not isinstance(
+                    multiprocessing_context, python_multiprocessing.context.BaseContext
+                ):
+                    raise TypeError(
+                        (
+                            "multiprocessing_context option should be a valid context "
+                            "object or a string specifying the start method, but got "
+                            "multiprocessing_context={}"
+                        ).format(multiprocessing_context)
+                    )
+            else:
+                raise ValueError(
+                    (
+                        "multiprocessing_context can only be used with "
+                        "multi-process loading (num_workers > 0), but got "
+                        "num_workers={}"
+                    ).format(self.num_workers)
+                )
+
+        self.__multiprocessing_context = multiprocessing_context
+
+    def __setattr__(self, attr, val):
+        if self.__initialized and attr in (
+            "batch_size",
+            "batch_sampler",
+            "sampler",
+            "drop_last",
+            "dataset",
+            "persistent_workers",
+        ):
+            raise ValueError(
+                "{} attribute should not be set after {} is "
+                "initialized".format(attr, self.__class__.__name__)
+            )
+
+        super().__setattr__(attr, val)
+
+    # We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
+    # since '_BaseDataLoaderIter' references 'DataLoader'.
+    def __iter__(self) -> "_BaseDataLoaderIter":
+        # When using a single worker the returned iterator should be
+        # created every time to avoid resetting its state
+        # However, in the case of a multi-worker iterator
+        # the iterator is only created once in the lifetime of the
+        # DataLoader object so that workers can be reused
+        if self.persistent_workers and self.num_workers > 0:
+            if self._iterator is None:
+                self._iterator = self._get_iterator()
+            else:
+                self._iterator._reset(self)
+            return self._iterator
+        else:
+            return self._get_iterator()
+
+    @property
+    def _auto_collation(self):
+        return self.batch_sampler is not None
+
+    @property
+    def _index_sampler(self):
+        # The actual sampler used for generating indices for `_DatasetFetcher`
+        # (see _utils/fetch.py) to read data at each time. This would be
+        # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
+        # We can't change `.sampler` and `.batch_sampler` attributes for BC
+        # reasons.
+        if self._auto_collation:
+            return self.batch_sampler
+        else:
+            return self.sampler
+
+    def __len__(self) -> int:
+        if self._dataset_kind == _DatasetKind.Iterable:
+            # NOTE [ IterableDataset and __len__ ]
+            #
+            # For `IterableDataset`, `__len__` could be inaccurate when one naively
+            # does multi-processing data loading, since the samples will be duplicated.
+            # However, no real use case should be actually using that behavior, so
+            # it should count as a user error. We should generally trust user
+            # code to do the proper thing (e.g., configure each replica differently
+            # in `__iter__`), and give us the correct `__len__` if they choose to
+            # implement it (this will still throw if the dataset does not implement
+            # a `__len__`).
+            #
+            # To provide a further warning, we track if `__len__` was called on the
+            # `DataLoader`, save the returned value in `self._len_called`, and warn
+            # if the iterator ends up yielding more than this number of samples.
+
+            # Cannot statically verify that dataset is Sized
+            length = self._IterableDataset_len_called = len(self.dataset)  # type: ignore[assignment, arg-type]
+            if (
+                self.batch_size is not None
+            ):  # IterableDataset doesn't allow custom sampler or batch_sampler
+                from math import ceil
+
+                if self.drop_last:
+                    length = length // self.batch_size
+                else:
+                    length = ceil(length / self.batch_size)
+            return length
+        else:
+            return len(self._index_sampler)
+
+    def check_worker_number_rationality(self):
+        # This function checks whether the dataloader's worker number is rational based on
+        # the current system's resources. The current rule is that if the number of workers this
+        # Dataloader will create is bigger than the number of logical cpus that it is allowed to
+        # use, then we will pop up a warning to let the user pay attention.
+        #
+        # eg. If the current system has 2 physical CPUs with 16 cores each, and each core supports 2
+        #     threads, then the total logical cpus here is 2 * 16 * 2 = 64. Let's say the current
+        #     DataLoader process can use half of them which is 32, then the rational max number of
+        #     workers initiated from this process is 32.
+        #     Now, let's say the created DataLoader has num_workers = 40, which is bigger than 32.
+        #     So the warning message is triggered to notify the user to lower the worker number if
+        #     necessary.
+        #
+        #
+        # [Note] Please note that this function respects `cpuset` only when os.sched_getaffinity is
+        #        available (available on most Linux systems, but not on OSX or Windows).
+        #        When os.sched_getaffinity is not available, os.cpu_count() is called instead, but
+        #        it doesn't respect cpuset.
+        #        We don't take threading into account since each worker process is single threaded
+        #        at this time.
+        #
+        #        We don't set any threading flags (eg. OMP_NUM_THREADS, MKL_NUM_THREADS, etc)
+        #        other than setting `torch.set_num_threads` to 1 in the worker process; if the passed-in
+        #        functions use 3rd party modules that rely on those threading flags to determine
+        #        how many threads to create (eg. numpy, etc), then it is the caller's responsibility to
+        #        set those flags correctly.
+        def _create_warning_msg(num_worker_suggest, num_worker_created, cpuset_checked):
+
+            suggested_max_worker_msg = (
+                (
+                    (
+                        "Our suggested max number of worker in current system is {}{}, which is smaller "
+                        "than what this DataLoader is going to create."
+                    ).format(
+                        num_worker_suggest,
+                        (
+                            ""
+                            if cpuset_checked
+                            else " (`cpuset` is not taken into account)"
+                        ),
+                    )
+                )
+                if num_worker_suggest is not None
+                else (
+                    "DataLoader is not able to compute a suggested max number of worker in current system."
+                )
+            )
+
+            warn_msg = (
+                "This DataLoader will create {} worker processes in total. {} "
+                "Please be aware that excessive worker creation might get DataLoader running slow or even freeze, "
+                "lower the worker number to avoid potential slowness/freeze if necessary."
+            ).format(num_worker_created, suggested_max_worker_msg)
+            return warn_msg
+
+        if not self.num_workers or self.num_workers == 0:
+            return
+
+        # try to compute a suggested max number of worker based on system's resource
+        max_num_worker_suggest = None
+        cpuset_checked = False
+        if hasattr(os, "sched_getaffinity"):
+            try:
+                max_num_worker_suggest = len(os.sched_getaffinity(0))
+                cpuset_checked = True
+            except Exception:
+                pass
+        if max_num_worker_suggest is None:
+            # os.cpu_count() could return Optional[int]
+            # get cpu count first and check None in order to satisfy mypy check
+            cpu_count = os.cpu_count()
+            if cpu_count is not None:
+                max_num_worker_suggest = cpu_count
+
+        if max_num_worker_suggest is None:
+            warnings.warn(
+                _create_warning_msg(
+                    max_num_worker_suggest, self.num_workers, cpuset_checked
+                )
+            )
+            return
+
+        if self.num_workers > max_num_worker_suggest:
+            warnings.warn(
+                _create_warning_msg(
+                    max_num_worker_suggest, self.num_workers, cpuset_checked
+                )
+            )
+
+
+class _BaseDataLoaderIter:
+    def __init__(self, loader: RRSDataLoader) -> None:
+        self._dataset = loader.dataset
+        self._shared_seed = None
+        self._pg = None
+        if isinstance(self._dataset, IterDataPipe):
+            if dist.is_available() and dist.is_initialized():
+                self._pg = dist.new_group(backend="gloo")
+            self._shared_seed = _share_dist_seed(loader.generator, self._pg)
+            shared_rng = torch.Generator()
+            shared_rng.manual_seed(self._shared_seed)
+            self._dataset = torch.utils.data.graph_settings.apply_random_seed(
+                self._dataset, shared_rng
+            )
+        self._dataset_kind = loader._dataset_kind
+        self._IterableDataset_len_called = loader._IterableDataset_len_called
+        self._auto_collation = loader._auto_collation
+        self._drop_last = loader.drop_last
+        self._index_sampler = loader._index_sampler
+        self._num_workers = loader.num_workers
+        ws, rank = _get_distributed_settings()
+        self._world_size = ws
+        self._rank = rank
+        # for other backends, pin_memory_device needs to be set. if not set,
+        # the default behaviour is the CUDA device. if pin_memory_device is selected
+        # and pin_memory is not set, the default behaviour is false.
+        if len(loader.pin_memory_device) == 0:
+            self._pin_memory = loader.pin_memory and torch.cuda.is_available()
+            self._pin_memory_device = None
+        else:
+            if not loader.pin_memory:
+                warn_msg = (
+                    "pin memory device is set and pin_memory flag is not used then device pinned memory won't be used"
+                    "please set pin_memory to true, if you need to use the device pin memory"
+                )
+                warnings.warn(warn_msg)
+
+            self._pin_memory = loader.pin_memory
+            self._pin_memory_device = loader.pin_memory_device
+        self._timeout = loader.timeout
+        self._collate_fn = loader.collate_fn
+        self._sampler_iter = iter(self._index_sampler)
+        self._base_seed = (
+            torch.empty((), dtype=torch.int64)
+            .random_(generator=loader.generator)
+            .item()
+        )
+        self._persistent_workers = loader.persistent_workers
+        self._num_yielded = 0
+        self._profile_name = "enumerate(DataLoader)#{}.__next__".format(
+            self.__class__.__name__
+        )
+
+    def __iter__(self) -> "_BaseDataLoaderIter":
+        return self
+
+    def _reset(self, loader, first_iter=False):
+        self._sampler_iter = iter(self._index_sampler)
+        self._num_yielded = 0
+        self._IterableDataset_len_called = loader._IterableDataset_len_called
+        if isinstance(self._dataset, IterDataPipe):
+            self._shared_seed = _share_dist_seed(loader.generator, self._pg)
+            shared_rng = torch.Generator()
+            shared_rng.manual_seed(self._shared_seed)
+            self._dataset = torch.utils.data.graph_settings.apply_random_seed(
+                self._dataset, shared_rng
+            )
+
+    def _next_index(self):
+        return next(self._sampler_iter)  # may raise StopIteration
+
+    def _next_data(self):
+        raise NotImplementedError
+
+    def __next__(self) -> Any:
+        with torch.autograd.profiler.record_function(self._profile_name):
+            if self._sampler_iter is None:
+                self._reset()  # type: ignore[call-arg]
+            data = self._next_data()
+            self._num_yielded += 1
+            if (
+                self._dataset_kind == _DatasetKind.Iterable
+                and self._IterableDataset_len_called is not None
+                and self._num_yielded > self._IterableDataset_len_called
+            ):
+                warn_msg = (
+                    "Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
+                    "samples have been fetched. "
+                ).format(
+                    self._dataset, self._IterableDataset_len_called, self._num_yielded
+                )
+                if self._num_workers > 0:
+                    warn_msg += (
+                        "For multiprocessing data-loading, this could be caused by not properly configuring the "
+                        "IterableDataset replica at each worker. Please see "
+                        "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples."
+                    )
+                warnings.warn(warn_msg)
+            return data
+
+    def __len__(self) -> int:
+        return len(self._index_sampler)
+
+    def __getstate__(self):
+        # TODO: add limited pickling support for sharing an iterator
+        # across multiple threads for HOGWILD.
+        # Probably the best way to do this is by moving the sample pushing
+        # to a separate thread and then just sharing the data queue
+        # but signalling the end is tricky without a non-blocking API
+        raise NotImplementedError("{} cannot be pickled".format(self.__class__.__name__))
+
+
+class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
+    def __init__(self, loader):
+        super().__init__(loader)
+        assert self._timeout == 0
+        assert self._num_workers == 0
+
+        # Adds forward compatibilities so classic DataLoader can work with DataPipes:
+        #   Taking care of distributed sharding
+        if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
+            # For BC, use default SHARDING_PRIORITIES
+            torch.utils.data.graph_settings.apply_sharding(
+                self._dataset, self._world_size, self._rank
+            )
+
+        self._dataset_fetcher = _DatasetKind.create_fetcher(
+            self._dataset_kind,
+            self._dataset,
+            self._auto_collation,
+            self._collate_fn,
+            self._drop_last,
+        )
+
+    def _next_data(self):
+        index = self._next_index()  # may raise StopIteration
+        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
+        if self._pin_memory:
+            data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
+        return data
+
+
+class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
+    r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
+
+    # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
+    #
+    # Preliminary:
+    #
+    # Our data model looks like this (queues are indicated with curly brackets):
+    #
+    #                main process                              ||
+    #                     |                                    ||
+    #               {index_queue}                              ||
+    #                     |                                    ||
+    #              worker processes                            ||     DATA
+    #                     |                                    ||
+    #            {worker_result_queue}                         ||     FLOW
+    #                     |                                    ||
+    #      pin_memory_thread of main process                   ||   DIRECTION
+    #                     |                                    ||
+    #               {data_queue}                               ||
+    #                     |                                    ||
+    #                data output                               \/
+    #
+    # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
+    #      `pin_memory=False`.
+    #
+    #
+    # Terminating multiprocessing logic requires very careful design. In
+    # particular, we need to make sure that
+    #
+    #   1. The iterator gracefully exits the workers when its last reference is
+    #      gone or it is depleted.
+    #
+    #      In this case, the workers should be gracefully exited because the
+    #      main process may still need to continue to run, and we want cleaning
+    #      up code in the workers to be executed (e.g., releasing GPU memory).
+    #      Naturally, we implement the shutdown logic in `__del__` of
+    #      DataLoaderIterator.
+    #
+    #      We delay the discussion on the logic in this case until later.
+    #
+    #   2. The iterator exits the workers when the loader process and/or worker
+    #      processes exits normally or with error.
+    #
+    #      We set all workers and `pin_memory_thread` to have `daemon=True`.
+    #
+    #      You may ask, why can't we make the workers non-daemonic, and
+    #      gracefully exit using the same logic as we have in `__del__` when the
+    #      iterator gets deleted (see 1 above)?
+    #
+    #      First of all, `__del__` is **not** guaranteed to be called when
+    #      interpreter exits. Even if it is called, by the time it executes,
+    #      many Python core library resources may already be freed, and even
+    #      simple things like acquiring an internal lock of a queue may hang.
+    #      Therefore, in this case, we actually need to prevent `__del__` from
+    #      being executed, and rely on the automatic termination of daemonic
+    #      children.
+    #
+    #      Thus, we register an `atexit` hook that sets a global flag
+    #      `_utils.python_exit_status`. Since `atexit` hooks are executed in the
+    #      reverse order of registration, we are guaranteed that this flag is
+    #      set before library resources we use are freed (which, at least in
+    #      CPython, is done via an `atexit` handler defined in
+    #      `multiprocessing/util.py`
+    #      https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/util.py#L320-L362
+    #      registered when an object requiring this mechanism is first
+    #      created, e.g., `mp.Queue`
+    #      https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/context.py#L100-L103
+    #      https://github.com/python/cpython/blob/c606624af8d4cb3b4a052fb263bb983b3f87585b/Lib/multiprocessing/queues.py#L29
+    #      )
+    #
+    #      So in `__del__`, we check if `_utils.python_exit_status` is set or
+    #      `None` (freed), and perform no-op if so.
+    #
+    #      However, simply letting library clean-up codes run can also be bad,
+    #      because such codes (i.e., `multiprocessing.util._exit_function()`)
+    #      include join putting threads for `mp.Queue`, which can be blocking.
+    #      Hence, the main process putting threads are called with
+    #      `cancel_join_thread` at creation.  See later section
+    #      [ 3b. A process won't hang when putting into a queue; ]
+    #      for more details.
+    #
+    #      Here are two example cases where library clean-up codes can run
+    #      before `__del__` is called:
+    #
+    #        1. If we hold onto a reference to the iterator, it more often
+    #           than not tries to do `multiprocessing` library cleaning before
+    #           clearing the alive referenced objects (https://github.com/pytorch/pytorch/issues/48666)
+    #           and thus prevents our cleaning-up code from running first.
+    #
+    #        2. A similar issue arises when a `DataLoader` is used in a subprocess.
+    #           When a process ends, it shuts all its daemonic children
+    #           down with a SIGTERM (instead of joining them without a timeout).
+    #           Similarly for threads, but by a different mechanism. This fact,
+    #           together with a few implementation details of multiprocessing, forces
+    #           us to make workers daemonic. All of our problems arise when a
+    #           DataLoader is used in a subprocess, and are caused by multiprocessing
+    #           code which looks more or less like this:
+    #
+    #               try:
+    #                   your_function_using_a_dataloader()
+    #               finally:
+    #                   multiprocessing.util._exit_function()
+    #
+    #           The joining/termination mentioned above happens inside
+    #           `_exit_function()`. Now, if `your_function_using_a_dataloader()`
+    #           throws, the stack trace stored in the exception will prevent the
+    #           frame which uses `DataLoaderIter` to be freed. If the frame has any
+    #           reference to the `DataLoaderIter` (e.g., in a method of the iter),
+    #           its  `__del__`, which starts the shutdown procedure, will not be
+    #           called. That, in turn, means that workers aren't notified. Attempting
+    #           to join in `_exit_function` will then result in a hang.
+    #
+    #           For context, `_exit_function` is also registered as an `atexit` call.
+    #           So it is unclear to me (@ssnl) why this is needed in a finally block.
+    #           The code dates back to 2008 and there is no comment on the original
+    #           PEP 371 or patch https://bugs.python.org/issue3050 (containing both
+    #           the finally block and the `atexit` registration) that explains this.
+    #
+    #
+    #      Finally, another choice is to just shutdown workers with logic in 1
+    #      above whenever we see an error in `next`. This isn't ideal because
+    #        a. It prevents users from using try-catch to resume data loading.
+    #        b. It doesn't prevent hanging if users have references to the
+    #           iterator.
+    #
+    #   3. All processes exit if any of them die unexpectedly by fatal signals.
+    #
+    #      As shown above, the workers are set as daemonic children of the main
+    #      process. However, automatic cleaning-up of such child processes only
+    #      happens if the parent process exits gracefully (e.g., not via fatal
+    #      signals like SIGKILL). So we must ensure that each process will exit
+    #      even if the process that should send/receive data to/from it was
+    #      killed, i.e.,
+    #
+    #        a. A process won't hang when getting from a queue.
+    #
+    #           Even with carefully designed data dependencies (i.e., a `put()`
+    #           always corresponding to a `get()`), hanging on `get()` can still
+    #           happen when data in queue is corrupted (e.g., due to
+    #           `cancel_join_thread` or unexpected exit).
+    #
+    #           For child exit, we set a timeout whenever we try to get data
+    #           from `data_queue`, and check the workers' status on each timeout
+    #           and error.
+    #           See `_DataLoaderiter._get_batch()` and
+    #           `_DataLoaderiter._try_get_data()` for details.
+    #
+    #           Additionally, for child exit on non-Windows platforms, we also
+    #           register a SIGCHLD handler (which is supported on Windows) on
+    #           the main process, which checks if any of the workers fail in the
+    #           (Python) handler. This is more efficient and faster in detecting
+    #           worker failures, compared to only using the above mechanism.
+    #           See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
+    #
+    #           For `.get()` calls where the sender(s) is not the workers, we
+    #           guard them with timeouts, and check the status of the sender
+    #           when timeout happens:
+    #             + in the workers, the `_utils.worker.ManagerWatchdog` class
+    #               checks the status of the main process.
+    #             + if `pin_memory=True`, when getting from `pin_memory_thread`,
+    #               check `pin_memory_thread` status periodically until `.get()`
+    #               returns or see that `pin_memory_thread` died.
+    #
+    #        b. A process won't hang when putting into a queue;
+    #
+    #           We use `mp.Queue` which has a separate background thread to put
+    #           objects from an unbounded buffer array. The background thread is
+    #           daemonic and usually automatically joined when the process
+    #           *exits*.
+    #
+    #           In case that the receiver has ended abruptly while
+    #           reading from the pipe, the join will hang forever.  The usual
+    #           solution for this in Python is calling  `q.cancel_join_thread`,
+    #           which prevents automatically joining it when finalizing
+    #           (exiting).
+    #
+    #           Nonetheless, `cancel_join_thread` must only be called when the
+    #           queue is **not** going to be read from or write into by another
+    #           process, because it may hold onto a lock or leave corrupted data
+    #           in the queue, leading other readers/writers to hang.
+    #
+    #           Hence,
+    #             + For worker processes, we only do so (for their output
+    #               queues, i.e., `worker_result_queue`) before exiting.
+    #             + For `pin_memory_thread`, its output queue `data_queue` is a
+    #               `queue.Queue` that does blocking `put` if the queue is full.
+    #               So there is no above problem, but as a result, in
+    #               `_pin_memory_loop`, we do need to  wrap the `put` in a loop
+    #               that breaks not only upon success, but also when the main
+    #               process stops reading, i.e., is shutting down.
+    #             + For loader process, we `cancel_join_thread()` for all
+    #               `_index_queues` because the whole purpose of workers and
+    #               `pin_memory_thread` is to serve the loader process.  If
+    #               loader process is already exiting, we don't really care if
+    #               the queues are corrupted.
+    #
+    #
+    # Now let's get back to 1:
+    #   how we gracefully exit the workers when the last reference to the
+    #   iterator is gone.
+    #
+    # To achieve this, we implement the following logic along with the design
+    # choices mentioned above:
+    #
+    # `workers_done_event`:
+    #   A `multiprocessing.Event` shared among the main process and all worker
+    #   processes. This is used to signal the workers that the iterator is
+    #   shutting down. After it is set, they will not send processed data to
+    #   queues anymore, and only wait for the final `None` before exiting.
+    #   `done_event` isn't strictly needed. I.e., we can just check for `None`
+    #   from the input queue, but it allows us to skip wasting resources
+    #   processing data if we are already shutting down.
+    #
+    # `pin_memory_thread_done_event`:
+    #   A `threading.Event` for a similar purpose to that of
+    #   `workers_done_event`, but is for the `pin_memory_thread`. The reason
+    #   that separate events are needed is that `pin_memory_thread` reads from
+    #   the output queue of the workers. But a worker, upon seeing that
+    #   `workers_done_event` is set, only wants to see the final `None`, and is
+    #   not required to flush all data in the output queue (e.g., it may call
+    #   `cancel_join_thread` on that queue if its `IterableDataset` iterator
+    #   happens to exhaust coincidentally, which is out of the control of the
+    #   main process). Thus, since we will exit `pin_memory_thread` before the
+    #   workers (see below), two separate events are used.
+    #
+    # NOTE: In short, the protocol is that the main process will set these
+    #       `done_event`s and then send the corresponding processes/threads a `None`,
+    #       and that they may exit at any time after receiving the `None`.
+    #
+    # NOTE: Using `None` as the final signal is valid, since normal data will
+    #       always be a 2-tuple with the 1st element being the index of the data
+    #       transferred (different from dataset index/key), and the 2nd being
+    #       either the dataset key or the data sample (depending on which part
+    #       of the data model the queue is at).
+    #
+    # [ worker processes ]
+    #   While loader process is alive:
+    #     Get from `index_queue`.
+    #       If got something (i.e., did not time out),
+    #          Check `workers_done_event`.
+    #            If set, continue to next iteration
+    #                    i.e., keep getting until the `None` arrives, then exit.
+    #            Otherwise, process data:
+    #                If is fetching from an `IterableDataset` and the iterator
+    #                    is exhausted, send an `_IterableDatasetStopIteration`
+    #                    object to signal iteration end. The main process, upon
+    #                    receiving such an object, will send `None` to this
+    #                    worker and not use the corresponding `index_queue`
+    #                    anymore.
+    #       If timed out,
+    #          Whether or not `workers_done_event` is set (we still need to see
+    #          the final `None`), continue to the next iteration.
+    #   (outside loop)
+    #   If `workers_done_event` is set,  (this can be False with `IterableDataset`)
+    #     `data_queue.cancel_join_thread()`.  (Everything is ending here:
+    #                                          main process won't read from it;
+    #                                          other workers will also call
+    #                                          `cancel_join_thread`.)
+    #
+    # [ pin_memory_thread ]
+    #   # No need to check main thread. If this thread is alive, the main loader
+    #   # thread must be alive, because this thread is set as daemonic.
+    #   While `pin_memory_thread_done_event` is not set:
+    #     Get from `worker_result_queue`.
+    #       If timed out, continue to get in the next iteration.
+    #       Otherwise, process data.
+    #       While `pin_memory_thread_done_event` is not set:
+    #         Put processed data to `data_queue` (a `queue.Queue` with blocking put)
+    #         If timed out, continue to put in the next iteration.
+    #         Otherwise, break, i.e., continue with the outer loop.
+    #
+    #   NOTE: we don't check the status of the main thread because
+    #           1. if the process is killed by fatal signal, `pin_memory_thread`
+    #              ends.
+    #           2. in other cases, either the cleaning-up in __del__ or the
+    #              automatic exit of daemonic thread will take care of it.
+    #              This won't busy-wait either because `.get(timeout)` does not
+    #              busy-wait.
+    #
+    # [ main process ]
+    #   In the DataLoader Iter's `__del__`
+    #     b. Exit `pin_memory_thread`
+    #          i.   Set `pin_memory_thread_done_event`.
+    #          ii.  Put `None` in `worker_result_queue`.
+    #          iii. Join the `pin_memory_thread`.
+    #          iv.  `worker_result_queue.cancel_join_thread()`.
+    #
+    #     c. Exit the workers.
+    #          i.   Set `workers_done_event`.
+    #          ii.  Put `None` in each worker's `index_queue`.
+    #          iii. Join the workers.
+    #          iv.  Call `.cancel_join_thread()` on each worker's `index_queue`.
+    #
+    #        NOTE: (c) is better placed after (b) because it may leave corrupted
+    #              data in `worker_result_queue`, which `pin_memory_thread`
+    #              reads from, in which case the `pin_memory_thread` can only
+    #              exit after timing out, which is slow. Nonetheless, the same thing
+    #              happens if a worker is killed by signal at unfortunate times,
+    #              but in other cases, we are better off having a non-corrupted
+    #              `worker_result_queue` for `pin_memory_thread`.
+    #
+    #   NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
+    #         can be omitted.
+    #
+    # NB: `done_event`s isn't strictly needed. E.g., we can just check for
+    #     `None` from `index_queue`, but it allows us to skip wasting resources
+    #     processing indices already in `index_queue` if we are already shutting
+    #     down.
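+    #
+    # To make the protocol above concrete, here is a minimal, self-contained
+    # sketch (illustration only, not part of this class; names such as
+    # `toy_worker` are made up) of the "set the done event, then send the
+    # sentinel `None`" handshake:
+    #
+    #   import multiprocessing as mp
+    #
+    #   def toy_worker(index_queue, result_queue, done_event):
+    #       while True:
+    #           i = index_queue.get()
+    #           if i is None:              # final signal; may exit any time after
+    #               break
+    #           if done_event.is_set():    # shutting down: drain until the `None`
+    #               continue
+    #           result_queue.put((i, i * i))
+    #       if done_event.is_set():
+    #           # nobody will read from it anymore; avoid hanging on exit
+    #           result_queue.cancel_join_thread()
+    #
+    #   if __name__ == "__main__":
+    #       ctx = mp.get_context("spawn")
+    #       done = ctx.Event()
+    #       idx_q, res_q = ctx.Queue(), ctx.Queue()
+    #       w = ctx.Process(target=toy_worker, args=(idx_q, res_q, done))
+    #       w.start()
+    #       idx_q.put(3)
+    #       print(res_q.get())             # -> (3, 9)
+    #       done.set()                     # 1) set the done event
+    #       idx_q.put(None)                # 2) send the sentinel
+    #       w.join()                       # 3) join the worker
+    #       idx_q.cancel_join_thread()     # 4) don't block on the queue's feeder thread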
+
+    def __init__(self, loader):
+        super().__init__(loader)
+
+        self._prefetch_factor = loader.prefetch_factor
+
+        assert self._num_workers > 0
+        assert self._prefetch_factor > 0
+
+        if loader.multiprocessing_context is None:
+            multiprocessing_context = multiprocessing
+        else:
+            multiprocessing_context = loader.multiprocessing_context
+
+        self._worker_init_fn = loader.worker_init_fn
+
+        # Adds forward compatibilities so classic DataLoader can work with DataPipes:
+        #   Additional worker init function will take care of sharding in MP and Distributed
+        if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
+            self._worker_init_fn = functools.partial(
+                _sharding_worker_init_fn,
+                self._worker_init_fn,
+                self._world_size,
+                self._rank,
+            )
+
+        # No certainty which module multiprocessing_context is
+        self._worker_result_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
+        self._worker_pids_set = False
+        self._shutdown = False
+        self._workers_done_event = multiprocessing_context.Event()
+
+        self._index_queues = []
+        self._workers = []
+        for i in range(self._num_workers):
+            # No certainty which module multiprocessing_context is
+            index_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
+            # Need to `cancel_join_thread` here!
+            # See sections (2) and (3b) above.
+            index_queue.cancel_join_thread()
+            w = multiprocessing_context.Process(
+                target=_worker_loop,
+                args=(
+                    self._dataset_kind,
+                    self._dataset,
+                    index_queue,
+                    self._worker_result_queue,
+                    self._workers_done_event,
+                    self._auto_collation,
+                    self._collate_fn,
+                    self._drop_last,
+                    self._base_seed,
+                    self._worker_init_fn,
+                    i,
+                    self._num_workers,
+                    self._persistent_workers,
+                    self._shared_seed,
+                ),
+            )
+            w.daemon = True
+            # NB: Process.start() actually takes some time as it needs to
+            #     start a process and pass the arguments over via a pipe.
+            #     Therefore, we only add a worker to self._workers list after
+            #     it started, so that we do not call .join() if program dies
+            #     before it starts, and __del__ tries to join but will get:
+            #     AssertionError: can only join a started process.
+            w.start()
+            self._index_queues.append(index_queue)
+            self._workers.append(w)
+
+        if self._pin_memory:
+            self._pin_memory_thread_done_event = threading.Event()
+
+            # Queue is not type-annotated
+            self._data_queue = queue.Queue()  # type: ignore[var-annotated]
+            if self._pin_memory_device == "xpu":
+                current_device = torch.xpu.current_device()  # type: ignore[attr-defined]
+            else:
+                current_device = torch.cuda.current_device()  # choose cuda for default
+            pin_memory_thread = threading.Thread(
+                target=_utils.pin_memory._pin_memory_loop,
+                args=(
+                    self._worker_result_queue,
+                    self._data_queue,
+                    current_device,
+                    self._pin_memory_thread_done_event,
+                    self._pin_memory_device,
+                ),
+            )
+            pin_memory_thread.daemon = True
+            pin_memory_thread.start()
+            # Similar to workers (see comment above), we only register
+            # pin_memory_thread once it is started.
+            self._pin_memory_thread = pin_memory_thread
+        else:
+            self._data_queue = self._worker_result_queue
+
+        # In some rare cases, persistent workers (daemonic processes)
+        # would be terminated before `__del__` of the iterator is invoked
+        # when the main process exits. This would cause a failure when
+        # `pin_memory_thread` tries to read corrupted data from
+        # `worker_result_queue`. `atexit` is therefore used to shut down the
+        # thread and child processes in the right sequence before the main
+        # process exits.
+        if self._persistent_workers and self._pin_memory:
+            import atexit
+
+            for w in self._workers:
+                atexit.register(_MultiProcessingDataLoaderIter._clean_up_worker, w)
+
+        # .pid can be None only before process is spawned (not the case, so ignore)
+        _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))  # type: ignore[misc]
+        _utils.signal_handling._set_SIGCHLD_handler()
+        self._worker_pids_set = True
+        self._reset(loader, first_iter=True)
+
+    def _reset(self, loader, first_iter=False):
+        super()._reset(loader, first_iter)
+        self._send_idx = 0  # idx of the next task to be sent to workers
+        self._rcvd_idx = 0  # idx of the next task to be returned in __next__
+        # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
+        # map: task idx => - (worker_id,)        if data isn't fetched (outstanding)
+        #                  \ (worker_id, data)   if data is already fetched (out-of-order)
+        self._task_info = {}
+        self._tasks_outstanding = (
+            0  # always equal to count(v for v in task_info.values() if len(v) == 1)
+        )
+        # A list of booleans representing whether each worker still has work to
+        # do, i.e., not having exhausted its iterable dataset object. It always
+        # contains all `True`s if not using an iterable-style dataset
+        # (i.e., if kind != Iterable).
+        # Note that this indicates that a worker still has work to do *for this epoch*.
+        # It does not mean that a worker is dead. In case of `_persistent_workers`,
+        # the worker will be reset to available in the next epoch.
+        self._workers_status = [True for i in range(self._num_workers)]
+        # Reset the worker queue cycle so it resumes next epoch at worker 0
+        self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
+        # We resume the prefetching in case it was enabled
+        if not first_iter:
+            for idx in range(self._num_workers):
+                self._index_queues[idx].put(
+                    _utils.worker._ResumeIteration(self._shared_seed)
+                )
+            resume_iteration_cnt = self._num_workers
+            while resume_iteration_cnt > 0:
+                return_idx, return_data = self._get_data()
+                if isinstance(return_idx, _utils.worker._ResumeIteration):
+                    assert return_data is None
+                    resume_iteration_cnt -= 1
+        # prime the prefetch loop
+        for _ in range(self._prefetch_factor * self._num_workers):
+            self._try_put_index()
+
+    def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
+        # Tries to fetch data from `self._data_queue` once for a given timeout.
+        # This can also be used as inner loop of fetching without timeout, with
+        # the sender status as the loop condition.
+        #
+        # This raises a `RuntimeError` if any worker died unexpectedly. This error
+        # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
+        # (only for non-Windows platforms), or the manual check below on errors
+        # and timeouts.
+        #
+        # Returns a 2-tuple:
+        #   (bool: whether successfully get data, any: data if successful else None)
+        try:
+            data = self._data_queue.get(timeout=timeout)
+            return (True, data)
+        except Exception as e:
+            # At timeout and error, we manually check whether any worker has
+            # failed. Note that this is the only mechanism for Windows to detect
+            # worker failures.
+            failed_workers = []
+            for worker_id, w in enumerate(self._workers):
+                if self._workers_status[worker_id] and not w.is_alive():
+                    failed_workers.append(w)
+                    self._mark_worker_as_unavailable(worker_id)
+            if len(failed_workers) > 0:
+                pids_str = ", ".join(str(w.pid) for w in failed_workers)
+                raise RuntimeError(
+                    "DataLoader worker (pid(s) {}) exited unexpectedly".format(pids_str)
+                ) from e
+            if isinstance(e, queue.Empty):
+                return (False, None)
+            import errno
+            import tempfile
+
+            try:
+                # Raise an exception if we are this close to the FDs limit.
+                # Apparently, trying to open only one file is not a sufficient
+                # test.
+                # See NOTE [ DataLoader on Linux and open files limit ]
+                fds_limit_margin = 10
+                fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)]
+            except OSError as e:
+                if e.errno == errno.EMFILE:
+                    raise RuntimeError(
+                        "Too many open files. Communication with the"
+                        " workers is no longer possible. Please increase the"
+                        " limit using `ulimit -n` in the shell or change the"
+                        " sharing strategy by calling"
+                        " `torch.multiprocessing.set_sharing_strategy('file_system')`"
+                        " at the beginning of your code"
+                    ) from None
+            raise
+
+    # NOTE [ DataLoader on Linux and open files limit ]
+    #
+    # On Linux when DataLoader is used with multiprocessing we pass the data between
+    # the root process and the workers through SHM files. We remove those files from
+    # the filesystem as soon as they are created and keep them alive by
+    # passing around their file descriptors through AF_UNIX sockets. (See
+    # docs/source/multiprocessing.rst and 'Multiprocessing Technical Notes' in
+    # the wiki (https://github.com/pytorch/pytorch/wiki).)
+    #
+    # This sometimes leads to exceeding the open files limit. When that happens,
+    # and the offending file descriptor is coming over a socket, the `socket` Python
+    # package silently strips the file descriptor from the message, setting only the
+    # `MSG_CTRUNC` flag (which might be a bit misleading since the manpage says that
+    # it _indicates that some control data were discarded due to lack of space in
+    # the buffer for ancillary data_). This might reflect the C implementation of
+    # AF_UNIX sockets.
+    #
+    # This behaviour can be reproduced with the script and instructions at the
+    # bottom of this note.
+    #
+    # When that happens, the standard Python `multiprocessing` (and not
+    # `torch.multiprocessing`) raises a `RuntimeError: received 0 items of ancdata`
+    #
+    # Sometimes, instead of the FD being stripped, you may get an `OSError:
+    # Too many open files`, both in the script below and in DataLoader. However,
+    # this is rare and seems to be nondeterministic.
+    #
+    #
+    #   #!/usr/bin/env python3
+    #   import sys
+    #   import socket
+    #   import os
+    #   import array
+    #   import shutil
+    #   import socket
+    #
+    #
+    #   if len(sys.argv) != 4:
+    #       print("Usage: ", sys.argv[0], " tmp_dirname iteration (send|recv)")
+    #       sys.exit(1)
+    #
+    #   if __name__ == '__main__':
+    #       dirname = sys.argv[1]
+    #       sock_path = dirname + "/sock"
+    #       iterations = int(sys.argv[2])
+    #       def dummy_path(i):
+    #           return dirname + "/" + str(i) + ".dummy"
+    #
+    #
+    #       if sys.argv[3] == 'send':
+    #           while not os.path.exists(sock_path):
+    #               pass
+    #           client = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
+    #           client.connect(sock_path)
+    #           for i in range(iterations):
+    #               fd = os.open(dummy_path(i), os.O_WRONLY | os.O_CREAT)
+    #               ancdata = array.array('i', [fd])
+    #               msg = bytes([i % 256])
+    #               print("Sending fd ", fd, " (iteration #", i, ")")
+    #               client.sendmsg([msg], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, ancdata)])
+    #
+    #
+    #       else:
+    #           assert sys.argv[3] == 'recv'
+    #
+    #           if os.path.exists(dirname):
+    #               raise Exception("Directory exists")
+    #
+    #           os.mkdir(dirname)
+    #
+    #           print("Opening socket...")
+    #           server = socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM)
+    #           server.bind(sock_path)
+    #
+    #           print("Listening...")
+    #           for i in range(iterations):
+    #               a = array.array('i')
+    #               msg, ancdata, flags, addr = server.recvmsg(1, socket.CMSG_SPACE(a.itemsize))
+    #               assert(len(ancdata) == 1)
+    #               cmsg_level, cmsg_type, cmsg_data = ancdata[0]
+    #               a.frombytes(cmsg_data)
+    #               print("Received fd ", a[0], " (iteration #", i, ")")
+    #
+    #           shutil.rmtree(dirname)
+    #
+    # Steps to reproduce:
+    #
+    # 1. Run two shells and set lower file descriptor limit in the receiving one:
+    # (shell1) ulimit -n 1020
+    # (shell2) ulimit -n 1022
+    #
+    # 2. Run the script above with the `recv` option in the first shell
+    # (shell1) ./test_socket.py sock_tmp 1017 recv
+    #
+    # 3. Run the script with the `send` option in the second shell:
+    # (shell2) ./test_socket.py sock_tmp 1017 send
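+    #
+    # If you do hit this limit in practice, the workaround suggested in the
+    # error message above is (illustrative):
+    #
+    #   import torch.multiprocessing
+    #   torch.multiprocessing.set_sharing_strategy('file_system')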
+
+    def _get_data(self):
+        # Fetches data from `self._data_queue`.
+        #
+        # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
+        # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
+        # in a loop. This is the only mechanism to detect worker failures for
+        # Windows. For other platforms, a SIGCHLD handler is also used for
+        # worker failure detection.
+        #
+        # If `pin_memory=True`, we also need to check whether `pin_memory_thread`
+        # has died at timeouts.
+        if self._timeout > 0:
+            success, data = self._try_get_data(self._timeout)
+            if success:
+                return data
+            else:
+                raise RuntimeError(
+                    "DataLoader timed out after {} seconds".format(self._timeout)
+                )
+        elif self._pin_memory:
+            while self._pin_memory_thread.is_alive():
+                success, data = self._try_get_data()
+                if success:
+                    return data
+            else:
+                # while condition is false, i.e., pin_memory_thread died.
+                raise RuntimeError("Pin memory thread exited unexpectedly")
+            # In this case, `self._data_queue` is a `queue.Queue`. But we don't
+            # need to call `.task_done()` because we don't use `.join()`.
+        else:
+            while True:
+                success, data = self._try_get_data()
+                if success:
+                    return data
+
+    def _next_data(self):
+        while True:
+            # If the worker responsible for `self._rcvd_idx` has already ended
+            # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
+            # we try to advance `self._rcvd_idx` to find the next valid index.
+            #
+            # This part needs to run in the loop because both the `self._get_data()`
+            # call and `_IterableDatasetStopIteration` check below can mark
+            # extra worker(s) as dead.
+            while self._rcvd_idx < self._send_idx:
+                info = self._task_info[self._rcvd_idx]
+                worker_id = info[0]
+                if (
+                    len(info) == 2 or self._workers_status[worker_id]
+                ):  # has data or is still active
+                    break
+                del self._task_info[self._rcvd_idx]
+                self._rcvd_idx += 1
+            else:
+                # no valid `self._rcvd_idx` is found (i.e., didn't break)
+                if not self._persistent_workers:
+                    self._shutdown_workers()
+                raise StopIteration
+
+            # Now `self._rcvd_idx` is the batch index we want to fetch
+
+            # Check if the next sample has already been generated
+            if len(self._task_info[self._rcvd_idx]) == 2:
+                data = self._task_info.pop(self._rcvd_idx)[1]
+                return self._process_data(data)
+
+            assert not self._shutdown and self._tasks_outstanding > 0
+            idx, data = self._get_data()
+            self._tasks_outstanding -= 1
+            if self._dataset_kind == _DatasetKind.Iterable:
+                # Check for _IterableDatasetStopIteration
+                if isinstance(data, _utils.worker._IterableDatasetStopIteration):
+                    if self._persistent_workers:
+                        self._workers_status[data.worker_id] = False
+                    else:
+                        self._mark_worker_as_unavailable(data.worker_id)
+                    self._try_put_index()
+                    continue
+
+            if idx != self._rcvd_idx:
+                # store out-of-order samples
+                self._task_info[idx] += (data,)
+            else:
+                del self._task_info[idx]
+                return self._process_data(data)
+
+    def _try_put_index(self):
+        assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
+
+        try:
+            index = self._next_index()
+        except StopIteration:
+            return
+        for _ in range(self._num_workers):  # find the next active worker, if any
+            worker_queue_idx = next(self._worker_queue_idx_cycle)
+            if self._workers_status[worker_queue_idx]:
+                break
+        else:
+            # not found (i.e., didn't break)
+            return
+
+        self._index_queues[worker_queue_idx].put((self._send_idx, index))
+        self._task_info[self._send_idx] = (worker_queue_idx,)
+        self._tasks_outstanding += 1
+        self._send_idx += 1
+
+    def _process_data(self, data):
+        self._rcvd_idx += 1
+        self._try_put_index()
+        if isinstance(data, ExceptionWrapper):
+            data.reraise()
+        return data
+
+    def _mark_worker_as_unavailable(self, worker_id, shutdown=False):
+        # Mark a worker as having finished its work e.g., due to
+        # exhausting an `IterableDataset`. This should be used only when this
+        # `_MultiProcessingDataLoaderIter` is going to continue running.
+
+        assert self._workers_status[worker_id] or (
+            self._persistent_workers and shutdown
+        )
+
+        # Signal termination to that specific worker.
+        q = self._index_queues[worker_id]
+        # Indicate that no more data will be put on this queue by the current
+        # process.
+        q.put(None)
+
+        # Note that we don't actually join the worker here, nor do we remove the
+        # worker's pid from C side struct because (1) joining may be slow, and
+        # (2) since we don't join, the worker may still raise error, and we
+        # prefer capturing those, rather than ignoring them, even though they
+        # are raised after the worker has finished its job.
+        # Joining is deferred to `_shutdown_workers`, which is called when
+        # all workers finish their jobs (e.g., `IterableDataset` replicas) or
+        # when this iterator is garbage collected.
+
+        self._workers_status[worker_id] = False
+
+        assert self._workers_done_event.is_set() == shutdown
+
+    def _shutdown_workers(self):
+        # Called when shutting down this `_MultiProcessingDataLoaderIter`.
+        # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
+        # the logic of this function.
+        if (
+            _utils is None
+            or _utils.python_exit_status is True
+            or _utils.python_exit_status is None
+        ):
+            # See (2) of the note. If Python is shutting down, do no-op.
+            return
+        # Normal exit when last reference is gone / iterator is depleted.
+        # See (1) and the second half of the note.
+        if not self._shutdown:
+            self._shutdown = True
+            try:
+                # Normal exit when last reference is gone / iterator is depleted.
+                # See (1) and the second half of the note.
+
+                # Exit `pin_memory_thread` first because exiting workers may leave
+                # corrupted data in `worker_result_queue` which `pin_memory_thread`
+                # reads from.
+                if hasattr(self, "_pin_memory_thread"):
+                    # Use hasattr in case error happens before we set the attribute.
+                    self._pin_memory_thread_done_event.set()
+                    # Send something to pin_memory_thread in case it is waiting
+                    # so that it can wake up and check `pin_memory_thread_done_event`
+                    self._worker_result_queue.put((None, None))
+                    self._pin_memory_thread.join()
+                    self._worker_result_queue.cancel_join_thread()
+                    self._worker_result_queue.close()
+
+                # Exit workers now.
+                self._workers_done_event.set()
+                for worker_id in range(len(self._workers)):
+                    # Get number of workers from `len(self._workers)` instead of
+                    # `self._num_workers` in case we error before starting all
+                    # workers.
+                    # If we are using `_workers_status` with `persistent_workers`,
+                    # we have to shut the worker down because it is paused.
+                    if self._persistent_workers or self._workers_status[worker_id]:
+                        self._mark_worker_as_unavailable(worker_id, shutdown=True)
+                for w in self._workers:
+                    # We should be able to join here, but in case anything went
+                    # wrong, we set a timeout and if the workers fail to join,
+                    # they are killed in the `finally` block.
+                    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
+                for q in self._index_queues:
+                    q.cancel_join_thread()
+                    q.close()
+            finally:
+                # Even though all this function does is putting into queues that
+                # we have called `cancel_join_thread` on, weird things can
+                # happen when a worker is killed by a signal, e.g., hanging in
+                # `Event.set()`. So we need to guard this with SIGCHLD handler,
+                # and remove pids from the C side data structure only at the
+                # end.
+                #
+                if self._worker_pids_set:
+                    _utils.signal_handling._remove_worker_pids(id(self))
+                    self._worker_pids_set = False
+                for w in self._workers:
+                    if w.is_alive():
+                        # Existing mechanisms try to make the workers exit
+                        # peacefully, but in case that we unfortunately reach
+                        # here, which we shouldn't, (e.g., pytorch/pytorch#39570),
+                        # we kill the worker.
+                        w.terminate()
+
+    # staticmethod is used to remove reference to `_MultiProcessingDataLoaderIter`
+    @staticmethod
+    def _clean_up_worker(w):
+        try:
+            w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
+        finally:
+            if w.is_alive():
+                w.terminate()
+
+    def __del__(self):
+        self._shutdown_workers()
diff --git a/src/efficientvit/apps/data_provider/random_resolution/_data_worker.py b/src/efficientvit/apps/data_provider/random_resolution/_data_worker.py
new file mode 100644
index 0000000000000000000000000000000000000000..be41a3ec24d4d934a06ffa82eac24e432d27d619
--- /dev/null
+++ b/src/efficientvit/apps/data_provider/random_resolution/_data_worker.py
@@ -0,0 +1,377 @@
+r""""This file is based on torch/utils/data/_utils/worker.py
+
+Contains definitions of the methods used by the _BaseDataLoaderIter workers.
+These **need** to be in global scope since Py2 doesn't support serializing
+static methods.
+"""
+
+import os
+import queue
+import random
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+from torch._utils import ExceptionWrapper
+from torch.utils.data._utils import (HAS_NUMPY, IS_WINDOWS,
+                                     MP_STATUS_CHECK_INTERVAL, signal_handling)
+
+if TYPE_CHECKING:
+    from torch.utils.data import Dataset
+
+from .controller import RRSController
+
+if IS_WINDOWS:
+    import ctypes
+    from ctypes.wintypes import BOOL, DWORD, HANDLE
+
+    # On Windows, the parent ID of the worker process remains unchanged when the manager process
+    # is gone, and the only way to check it through OS is to let the worker have a process handle
+    # of the manager and ask if the process status has changed.
+    class ManagerWatchdog:
+        def __init__(self):
+            self.manager_pid = os.getppid()
+
+            # mypy cannot detect this code is windows only
+            self.kernel32 = ctypes.WinDLL("kernel32", use_last_error=True)  # type: ignore[attr-defined]
+            self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
+            self.kernel32.OpenProcess.restype = HANDLE
+            self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
+            self.kernel32.WaitForSingleObject.restype = DWORD
+
+            # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
+            SYNCHRONIZE = 0x00100000
+            self.manager_handle = self.kernel32.OpenProcess(
+                SYNCHRONIZE, 0, self.manager_pid
+            )
+
+            if not self.manager_handle:
+                raise ctypes.WinError(ctypes.get_last_error())  # type: ignore[attr-defined]
+
+            self.manager_dead = False
+
+        def is_alive(self):
+            if not self.manager_dead:
+                # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
+                self.manager_dead = (
+                    self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
+                )
+            return not self.manager_dead
+
+else:
+
+    class ManagerWatchdog:  # type: ignore[no-redef]
+        def __init__(self):
+            self.manager_pid = os.getppid()
+            self.manager_dead = False
+
+        def is_alive(self):
+            if not self.manager_dead:
+                self.manager_dead = os.getppid() != self.manager_pid
+            return not self.manager_dead
+
+
+_worker_info = None
+
+
+class WorkerInfo:
+    id: int
+    num_workers: int
+    seed: int
+    dataset: "Dataset"
+    __initialized = False
+
+    def __init__(self, **kwargs):
+        for k, v in kwargs.items():
+            setattr(self, k, v)
+        self.__keys = tuple(kwargs.keys())
+        self.__initialized = True
+
+    def __setattr__(self, key, val):
+        if self.__initialized:
+            raise RuntimeError(
+                "Cannot assign attributes to {} objects".format(self.__class__.__name__)
+            )
+        return super().__setattr__(key, val)
+
+    def __repr__(self):
+        items = []
+        for k in self.__keys:
+            items.append("{}={}".format(k, getattr(self, k)))
+        return "{}({})".format(self.__class__.__name__, ", ".join(items))
+
+
+def get_worker_info() -> Optional[WorkerInfo]:
+    r"""Returns the information about the current
+    :class:`~torch.utils.data.DataLoader` iterator worker process.
+
+    When called in a worker, this returns an object guaranteed to have the
+    following attributes:
+
+    * :attr:`id`: the current worker id.
+    * :attr:`num_workers`: the total number of workers.
+    * :attr:`seed`: the random seed set for the current worker. This value is
+      determined by main process RNG and the worker id. See
+      :class:`~torch.utils.data.DataLoader`'s documentation for more details.
+    * :attr:`dataset`: the copy of the dataset object in **this** process. Note
+      that this will be a different object in a different process than the one
+      in the main process.
+
+    When called in the main process, this returns ``None``.
+
+    .. note::
+       When used in a :attr:`worker_init_fn` passed over to
+       :class:`~torch.utils.data.DataLoader`, this method can be useful to
+       set up each worker process differently, for instance, using ``worker_id``
+       to configure the ``dataset`` object to only read a specific fraction of a
+       sharded dataset, or use ``seed`` to seed other libraries used in dataset
+       code.
+    """
+    return _worker_info
+
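+# Illustration only (not part of the original file): the docstring above
+# suggests using `worker_init_fn` for sharding. A minimal sketch, assuming a
+# map-style dataset that exposes a hypothetical `files` list attribute:
+#
+#   def shard_worker_init_fn(worker_id):
+#       info = get_worker_info()
+#       # each worker keeps every `num_workers`-th file of its own dataset copy
+#       info.dataset.files = info.dataset.files[info.id :: info.num_workers]
+#
+#   # DataLoader(dataset, num_workers=4, worker_init_fn=shard_worker_init_fn)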
+
+r"""Dummy class used to signal the end of an IterableDataset"""
+
+
+@dataclass(frozen=True)
+class _IterableDatasetStopIteration:
+    worker_id: int
+
+
+r"""Dummy class used to resume the fetching when worker reuse is enabled"""
+
+
+@dataclass(frozen=True)
+class _ResumeIteration:
+    seed: Optional[int] = None
+
+
+# The function `_generate_state` is adapted from `numpy.random.SeedSequence`
+# from https://github.com/numpy/numpy/blob/main/numpy/random/bit_generator.pyx
+# It's MIT licensed, here is the copyright:
+
+# Copyright (c) 2015 Melissa E. O'Neill
+# Copyright (c) 2019 NumPy Developers
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+# This function generates an array of int32 as the seed for
+# `numpy.random`, in order to prevent state collision due to same
+# seed and algorithm for `numpy.random` and `random` modules.
+def _generate_state(base_seed, worker_id):
+    INIT_A = 0x43B0D7E5
+    MULT_A = 0x931E8875
+    INIT_B = 0x8B51F9DD
+    MULT_B = 0x58F38DED
+    MIX_MULT_L = 0xCA01F9DD
+    MIX_MULT_R = 0x4973F715
+    XSHIFT = 4 * 8 // 2
+    MASK32 = 0xFFFFFFFF
+
+    entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0]
+    pool = [0] * 4
+
+    hash_const_A = INIT_A
+
+    def hash(value):
+        nonlocal hash_const_A
+        value = (value ^ hash_const_A) & MASK32
+        hash_const_A = (hash_const_A * MULT_A) & MASK32
+        value = (value * hash_const_A) & MASK32
+        value = (value ^ (value >> XSHIFT)) & MASK32
+        return value
+
+    def mix(x, y):
+        result_x = (MIX_MULT_L * x) & MASK32
+        result_y = (MIX_MULT_R * y) & MASK32
+        result = (result_x - result_y) & MASK32
+        result = (result ^ (result >> XSHIFT)) & MASK32
+        return result
+
+    # Add in the entropy to the pool.
+    for i in range(len(pool)):
+        pool[i] = hash(entropy[i])
+
+    # Mix all bits together so late bits can affect earlier bits.
+    for i_src in range(len(pool)):
+        for i_dst in range(len(pool)):
+            if i_src != i_dst:
+                pool[i_dst] = mix(pool[i_dst], hash(pool[i_src]))
+
+    hash_const_B = INIT_B
+    state = []
+    for i_dst in range(4):
+        data_val = pool[i_dst]
+        data_val = (data_val ^ hash_const_B) & MASK32
+        hash_const_B = (hash_const_B * MULT_B) & MASK32
+        data_val = (data_val * hash_const_B) & MASK32
+        data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32
+        state.append(data_val)
+    return state
+
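+# Example (illustration only): the generated state is consumed as a seed
+# sequence for NumPy, exactly as `_worker_loop` below does:
+#
+#   import numpy as np
+#   np.random.seed(_generate_state(base_seed=12345, worker_id=0))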
+
+def _worker_loop(
+    dataset_kind,
+    dataset,
+    index_queue,
+    data_queue,
+    done_event,
+    auto_collation,
+    collate_fn,
+    drop_last,
+    base_seed,
+    init_fn,
+    worker_id,
+    num_workers,
+    persistent_workers,
+    shared_seed,
+):
+    # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
+    # logic of this function.
+
+    try:
+        # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
+        # module's handlers are executed after Python returns from C low-level
+        # handlers, likely when the same fatal signal had already happened
+        # again.
+        # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
+        signal_handling._set_worker_signal_handlers()
+
+        torch.set_num_threads(1)
+        seed = base_seed + worker_id
+        random.seed(seed)
+        torch.manual_seed(seed)
+        if HAS_NUMPY:
+            np_seed = _generate_state(base_seed, worker_id)
+            import numpy as np
+
+            np.random.seed(np_seed)
+
+        from torch.utils.data import IterDataPipe
+        from torch.utils.data.graph_settings import apply_random_seed
+
+        shared_rng = torch.Generator()
+        if isinstance(dataset, IterDataPipe):
+            assert shared_seed is not None
+            shared_rng.manual_seed(shared_seed)
+            dataset = apply_random_seed(dataset, shared_rng)
+
+        global _worker_info
+        _worker_info = WorkerInfo(
+            id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset
+        )
+
+        from torch.utils.data import _DatasetKind
+
+        init_exception = None
+
+        try:
+            if init_fn is not None:
+                init_fn(worker_id)
+
+            fetcher = _DatasetKind.create_fetcher(
+                dataset_kind, dataset, auto_collation, collate_fn, drop_last
+            )
+        except Exception:
+            init_exception = ExceptionWrapper(
+                where="in DataLoader worker process {}".format(worker_id)
+            )
+
+        # When using Iterable mode, some workers can exit earlier than others due
+        # to the IterableDataset behaving differently for different workers.
+        # When such things happen, an `_IterableDatasetStopIteration` object is
+        # sent over to the main process with the ID of this worker, so that the
+        # main process won't send more tasks to this worker, and will send
+        # `None` to this worker to properly exit it.
+        #
+        # Note that we cannot set `done_event` from a worker as it is shared
+        # among all processes. Instead, we set the `iteration_end` flag to
+        # signify that the iterator is exhausted. When either `done_event` or
+        # `iteration_end` is set, we skip all processing steps and just wait for
+        # `None`.
+        iteration_end = False
+
+        watchdog = ManagerWatchdog()
+
+        while watchdog.is_alive():
+            try:
+                r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
+            except queue.Empty:
+                continue
+            if isinstance(r, _ResumeIteration):
+                # Acknowledge the main process
+                data_queue.put((r, None))
+                iteration_end = False
+
+                if isinstance(dataset, IterDataPipe):
+                    assert r.seed is not None
+                    shared_rng.manual_seed(r.seed)
+                    dataset = apply_random_seed(dataset, shared_rng)
+
+                # Recreate the fetcher for worker-reuse policy
+                fetcher = _DatasetKind.create_fetcher(
+                    dataset_kind, dataset, auto_collation, collate_fn, drop_last
+                )
+                continue
+            elif r is None:
+                # Received the final signal
+                assert done_event.is_set() or iteration_end
+                break
+            elif done_event.is_set() or iteration_end:
+                # `done_event` is set. But I haven't received the final signal
+                # (None) yet. I will keep continuing until I get it, and skip the
+                # processing steps.
+                continue
+            idx, index = r
+            """ Added """
+            RRSController.sample_resolution(batch_id=idx)
+            """ Added """
+            data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
+            if init_exception is not None:
+                data = init_exception
+                init_exception = None
+            else:
+                try:
+                    data = fetcher.fetch(index)
+                except Exception as e:
+                    if (
+                        isinstance(e, StopIteration)
+                        and dataset_kind == _DatasetKind.Iterable
+                    ):
+                        data = _IterableDatasetStopIteration(worker_id)
+                        # Set `iteration_end`
+                        #   (1) to save future `next(...)` calls, and
+                        #   (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
+                        iteration_end = True
+                    else:
+                        # It is important that we don't store exc_info in a variable.
+                        # `ExceptionWrapper` does the correct thing.
+                        # See NOTE [ Python Traceback Reference Cycle Problem ]
+                        data = ExceptionWrapper(
+                            where="in DataLoader worker process {}".format(worker_id)
+                        )
+            data_queue.put((idx, data))
+            del data, idx, index, r  # save memory
+    except KeyboardInterrupt:
+        # Main process will raise KeyboardInterrupt anyways.
+        pass
+    if done_event.is_set():
+        data_queue.cancel_join_thread()
+        data_queue.close()
diff --git a/src/efficientvit/apps/data_provider/random_resolution/controller.py b/src/efficientvit/apps/data_provider/random_resolution/controller.py
new file mode 100644
index 0000000000000000000000000000000000000000..b62134a5e829c8322806366d98760f1d01c30678
--- /dev/null
+++ b/src/efficientvit/apps/data_provider/random_resolution/controller.py
@@ -0,0 +1,94 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import copy
+
+import torch
+import torchvision.transforms as transforms
+import torchvision.transforms.functional as F
+
+from src.efficientvit.models.utils import torch_random_choices
+
+__all__ = [
+    "RRSController",
+    "get_interpolate",
+    "MyRandomResizedCrop",
+]
+
+
+class RRSController:
+    ACTIVE_SIZE = (224, 224)
+    IMAGE_SIZE_LIST = [(224, 224)]
+
+    CHOICE_LIST = None
+
+    @staticmethod
+    def get_candidates() -> list[tuple[int, int]]:
+        return copy.deepcopy(RRSController.IMAGE_SIZE_LIST)
+
+    @staticmethod
+    def sample_resolution(batch_id: int) -> None:
+        RRSController.ACTIVE_SIZE = RRSController.CHOICE_LIST[batch_id]
+
+    @staticmethod
+    def set_epoch(epoch: int, batch_per_epoch: int) -> None:
+        g = torch.Generator()
+        g.manual_seed(epoch)
+        RRSController.CHOICE_LIST = torch_random_choices(
+            RRSController.get_candidates(),
+            g,
+            batch_per_epoch,
+        )
+
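+# Illustrative usage (an assumption about the calling code, not part of this
+# file): the trainer is expected to fix the per-batch resolution schedule once
+# per epoch; each data worker then picks up `ACTIVE_SIZE` via
+# `sample_resolution(batch_id)` (see `_data_worker.py`):
+#
+#   RRSController.IMAGE_SIZE_LIST = [(160, 160), (192, 192), (224, 224)]
+#   RRSController.set_epoch(epoch, batch_per_epoch=len(train_loader))
+#   for batch_id, (images, labels) in enumerate(train_loader):
+#       ...  # this batch was resized to CHOICE_LIST[batch_id] by the workers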
+
+def get_interpolate(name: str) -> F.InterpolationMode:
+    mapping = {
+        "nearest": F.InterpolationMode.NEAREST,
+        "bilinear": F.InterpolationMode.BILINEAR,
+        "bicubic": F.InterpolationMode.BICUBIC,
+        "box": F.InterpolationMode.BOX,
+        "hamming": F.InterpolationMode.HAMMING,
+        "lanczos": F.InterpolationMode.LANCZOS,
+    }
+    if name in mapping:
+        return mapping[name]
+    elif name == "random":
+        return torch_random_choices(
+            [
+                F.InterpolationMode.NEAREST,
+                F.InterpolationMode.BILINEAR,
+                F.InterpolationMode.BICUBIC,
+                F.InterpolationMode.BOX,
+                F.InterpolationMode.HAMMING,
+                F.InterpolationMode.LANCZOS,
+            ],
+        )
+    else:
+        raise NotImplementedError
+
+
+class MyRandomResizedCrop(transforms.RandomResizedCrop):
+    def __init__(
+        self,
+        scale=(0.08, 1.0),
+        ratio=(3.0 / 4.0, 4.0 / 3.0),
+        interpolation: str = "random",
+    ):
+        super(MyRandomResizedCrop, self).__init__(224, scale, ratio)
+        self.interpolation = interpolation
+
+    def forward(self, img: torch.Tensor) -> torch.Tensor:
+        i, j, h, w = self.get_params(img, list(self.scale), list(self.ratio))
+        target_size = RRSController.ACTIVE_SIZE
+        return F.resized_crop(
+            img, i, j, h, w, list(target_size), get_interpolate(self.interpolation)
+        )
+
+    def __repr__(self) -> str:
+        format_string = self.__class__.__name__
+        format_string += f"(\n\tsize={RRSController.get_candidates()},\n"
+        format_string += f"\tscale={tuple(round(s, 4) for s in self.scale)},\n"
+        format_string += f"\tratio={tuple(round(r, 4) for r in self.ratio)},\n"
+        format_string += f"\tinterpolation={self.interpolation})"
+        return format_string
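+
+
+# A minimal sketch of plugging the crop into a standard torchvision pipeline
+# (illustration only; the surrounding pipeline is an assumption):
+#
+#   train_transform = transforms.Compose([
+#       MyRandomResizedCrop(interpolation="random"),
+#       transforms.ToTensor(),
+#   ])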
diff --git a/src/efficientvit/apps/setup.py b/src/efficientvit/apps/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cd36c4205ad64f4ccb8dfc1efd9e5ebb0ce8e3d
--- /dev/null
+++ b/src/efficientvit/apps/setup.py
@@ -0,0 +1,141 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import os
+import time
+from copy import deepcopy
+
+import torch.backends.cudnn
+import torch.distributed
+import torch.nn as nn
+
+from src.efficientvit.apps.data_provider import DataProvider
+from src.efficientvit.apps.trainer.run_config import RunConfig
+from src.efficientvit.apps.utils import (dist_init, dump_config,
+                                     get_dist_local_rank, get_dist_rank,
+                                     get_dist_size, init_modules, is_master,
+                                     load_config, partial_update_config,
+                                     zero_last_gamma)
+from src.efficientvit.models.utils import (build_kwargs_from_config,
+                                       load_state_dict_from_file)
+
+__all__ = [
+    "save_exp_config",
+    "setup_dist_env",
+    "setup_seed",
+    "setup_exp_config",
+    "setup_data_provider",
+    "setup_run_config",
+    "init_model",
+]
+
+
+def save_exp_config(exp_config: dict, path: str, name="config.yaml") -> None:
+    if not is_master():
+        return
+    dump_config(exp_config, os.path.join(path, name))
+
+
+def setup_dist_env(gpu: str or None = None) -> None:
+    if gpu is not None:
+        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
+    if not torch.distributed.is_initialized():
+        dist_init()
+    torch.backends.cudnn.benchmark = True
+    torch.cuda.set_device(get_dist_local_rank())
+
+
+def setup_seed(manual_seed: int, resume: bool) -> None:
+    if resume:
+        manual_seed = int(time.time())
+    manual_seed = get_dist_rank() + manual_seed
+    torch.manual_seed(manual_seed)
+    torch.cuda.manual_seed_all(manual_seed)
+
+
+def setup_exp_config(
+    config_path: str, recursive=True, opt_args: dict or None = None
+) -> dict:
+    # load config
+    if not os.path.isfile(config_path):
+        raise ValueError(config_path)
+
+    fpaths = [config_path]
+    if recursive:
+        extension = os.path.splitext(config_path)[1]
+        while os.path.dirname(config_path) != config_path:
+            config_path = os.path.dirname(config_path)
+            fpath = os.path.join(config_path, "default" + extension)
+            if os.path.isfile(fpath):
+                fpaths.append(fpath)
+        fpaths = fpaths[::-1]
+
+    default_config = load_config(fpaths[0])
+    exp_config = deepcopy(default_config)
+    for fpath in fpaths[1:]:
+        partial_update_config(exp_config, load_config(fpath))
+    # update config via args
+    if opt_args is not None:
+        partial_update_config(exp_config, opt_args)
+
+    return exp_config
+
+
+def setup_data_provider(
+    exp_config: dict,
+    data_provider_classes: list[type[DataProvider]],
+    is_distributed: bool = True,
+) -> DataProvider:
+    dp_config = exp_config["data_provider"]
+    dp_config["num_replicas"] = get_dist_size() if is_distributed else None
+    dp_config["rank"] = get_dist_rank() if is_distributed else None
+    dp_config["test_batch_size"] = (
+        dp_config.get("test_batch_size", None) or dp_config["base_batch_size"] * 2
+    )
+    dp_config["batch_size"] = dp_config["train_batch_size"] = dp_config[
+        "base_batch_size"
+    ]
+
+    data_provider_lookup = {
+        provider.name: provider for provider in data_provider_classes
+    }
+    data_provider_class = data_provider_lookup[dp_config["dataset"]]
+
+    data_provider_kwargs = build_kwargs_from_config(dp_config, data_provider_class)
+    data_provider = data_provider_class(**data_provider_kwargs)
+    return data_provider
+
+
+def setup_run_config(exp_config: dict, run_config_cls: type[RunConfig]) -> RunConfig:
+    exp_config["run_config"]["init_lr"] = (
+        exp_config["run_config"]["base_lr"] * get_dist_size()
+    )
+
+    run_config = run_config_cls(**exp_config["run_config"])
+
+    return run_config
+
+
+def init_model(
+    network: nn.Module,
+    init_from: str or None = None,
+    backbone_init_from: str or None = None,
+    rand_init="trunc_normal",
+    last_gamma=None,
+) -> None:
+    # initialization
+    init_modules(network, init_type=rand_init)
+    # zero gamma of last bn in each block
+    if last_gamma is not None:
+        zero_last_gamma(network, last_gamma)
+
+    # load weight
+    if init_from is not None and os.path.isfile(init_from):
+        network.load_state_dict(load_state_dict_from_file(init_from))
+        print(f"Loaded init from {init_from}")
+    elif backbone_init_from is not None and os.path.isfile(backbone_init_from):
+        network.backbone.load_state_dict(load_state_dict_from_file(backbone_init_from))
+        print(f"Loaded backbone init from {backbone_init_from}")
+    else:
+        print(f"Random init ({rand_init}) with last gamma {last_gamma}")
diff --git a/src/efficientvit/apps/trainer/__init__.py b/src/efficientvit/apps/trainer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b9219c0c05c23e46926de0988c658b79b72388b
--- /dev/null
+++ b/src/efficientvit/apps/trainer/__init__.py
@@ -0,0 +1,6 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+from .base import *
+from .run_config import *
diff --git a/src/efficientvit/apps/trainer/base.py b/src/efficientvit/apps/trainer/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f190e460bdb11f55357b5aef069486ced98939b
--- /dev/null
+++ b/src/efficientvit/apps/trainer/base.py
@@ -0,0 +1,297 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import os
+
+import torch
+import torch.nn as nn
+
+from src.efficientvit.apps.data_provider import DataProvider, parse_image_size
+from src.efficientvit.apps.trainer.run_config import RunConfig
+from src.efficientvit.apps.utils import (EMA, dist_barrier, get_dist_local_rank,
+                                     is_master)
+from src.efficientvit.models.nn.norm import reset_bn
+from src.efficientvit.models.utils import is_parallel, load_state_dict_from_file
+
+__all__ = ["Trainer"]
+
+
+class Trainer:
+    def __init__(self, path: str, model: nn.Module, data_provider: DataProvider):
+        self.path = os.path.realpath(os.path.expanduser(path))
+        self.model = model.cuda()
+        self.data_provider = data_provider
+
+        self.ema = None
+
+        self.checkpoint_path = os.path.join(self.path, "checkpoint")
+        self.logs_path = os.path.join(self.path, "logs")
+        for path in [self.path, self.checkpoint_path, self.logs_path]:
+            os.makedirs(path, exist_ok=True)
+
+        self.best_val = 0.0
+        self.start_epoch = 0
+
+    @property
+    def network(self) -> nn.Module:
+        return self.model.module if is_parallel(self.model) else self.model
+
+    @property
+    def eval_network(self) -> nn.Module:
+        if self.ema is None:
+            model = self.model
+        else:
+            model = self.ema.shadows
+        model = model.module if is_parallel(model) else model
+        return model
+
+    def write_log(self, log_str, prefix="valid", print_log=True, mode="a") -> None:
+        if is_master():
+            fout = open(os.path.join(self.logs_path, f"{prefix}.log"), mode)
+            fout.write(log_str + "\n")
+            fout.flush()
+            fout.close()
+            if print_log:
+                print(log_str)
+
+    def save_model(
+        self,
+        checkpoint=None,
+        only_state_dict=True,
+        epoch=0,
+        model_name=None,
+    ) -> None:
+        if is_master():
+            if checkpoint is None:
+                if only_state_dict:
+                    checkpoint = {"state_dict": self.network.state_dict()}
+                else:
+                    checkpoint = {
+                        "state_dict": self.network.state_dict(),
+                        "epoch": epoch,
+                        "best_val": self.best_val,
+                        "optimizer": self.optimizer.state_dict(),
+                        "lr_scheduler": self.lr_scheduler.state_dict(),
+                        "ema": self.ema.state_dict() if self.ema is not None else None,
+                        "scaler": self.scaler.state_dict() if self.fp16 else None,
+                    }
+
+            model_name = model_name or "checkpoint.pt"
+
+            latest_fname = os.path.join(self.checkpoint_path, "latest.txt")
+            model_path = os.path.join(self.checkpoint_path, model_name)
+            with open(latest_fname, "w") as _fout:
+                _fout.write(model_path + "\n")
+            torch.save(checkpoint, model_path)
+
+    def load_model(self, model_fname=None) -> None:
+        latest_fname = os.path.join(self.checkpoint_path, "latest.txt")
+        if model_fname is None and os.path.exists(latest_fname):
+            with open(latest_fname, "r") as fin:
+                model_fname = fin.readline()
+                if len(model_fname) > 0 and model_fname[-1] == "\n":
+                    model_fname = model_fname[:-1]
+        try:
+            if model_fname is None:
+                model_fname = f"{self.checkpoint_path}/checkpoint.pt"
+            elif not os.path.exists(model_fname):
+                model_fname = f"{self.checkpoint_path}/{os.path.basename(model_fname)}"
+                if not os.path.exists(model_fname):
+                    model_fname = f"{self.checkpoint_path}/checkpoint.pt"
+            print(f"=> loading checkpoint {model_fname}")
+            checkpoint = load_state_dict_from_file(model_fname, False)
+        except Exception:
+            self.write_log(f"fail to load checkpoint from {self.checkpoint_path}")
+            return
+
+        # load checkpoint
+        self.network.load_state_dict(checkpoint["state_dict"], strict=False)
+        log = []
+        if "epoch" in checkpoint:
+            self.start_epoch = checkpoint["epoch"] + 1
+            self.run_config.update_global_step(self.start_epoch)
+            log.append(f"epoch={self.start_epoch - 1}")
+        if "best_val" in checkpoint:
+            self.best_val = checkpoint["best_val"]
+            log.append(f"best_val={self.best_val:.2f}")
+        if "optimizer" in checkpoint:
+            self.optimizer.load_state_dict(checkpoint["optimizer"])
+            log.append("optimizer")
+        if "lr_scheduler" in checkpoint:
+            self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
+            log.append("lr_scheduler")
+        if "ema" in checkpoint and self.ema is not None:
+            self.ema.load_state_dict(checkpoint["ema"])
+            log.append("ema")
+        if "scaler" in checkpoint and self.fp16:
+            self.scaler.load_state_dict(checkpoint["scaler"])
+            log.append("scaler")
+        self.write_log("Loaded: " + ", ".join(log))
+
+    """ validate """
+
+    def reset_bn(
+        self,
+        network: nn.Module or None = None,
+        subset_size: int = 16000,
+        subset_batch_size: int = 100,
+        data_loader=None,
+        progress_bar=False,
+    ) -> None:
+        network = network or self.network
+        if data_loader is None:
+            data_loader = []
+            for data in self.data_provider.build_sub_train_loader(
+                subset_size, subset_batch_size
+            ):
+                if isinstance(data, list):
+                    data_loader.append(data[0])
+                elif isinstance(data, dict):
+                    data_loader.append(data["data"])
+                elif isinstance(data, torch.Tensor):
+                    data_loader.append(data)
+                else:
+                    raise NotImplementedError
+
+        network.eval()
+        reset_bn(
+            network,
+            data_loader,
+            sync=True,
+            progress_bar=progress_bar,
+        )
+
+    def _validate(self, model, data_loader, epoch) -> dict[str, any]:
+        raise NotImplementedError
+
+    def validate(
+        self, model=None, data_loader=None, is_test=True, epoch=0
+    ) -> dict[str, any]:
+        model = model or self.eval_network
+        if data_loader is None:
+            if is_test:
+                data_loader = self.data_provider.test
+            else:
+                data_loader = self.data_provider.valid
+
+        model.eval()
+        return self._validate(model, data_loader, epoch)
+
+    def multires_validate(
+        self,
+        model=None,
+        data_loader=None,
+        is_test=True,
+        epoch=0,
+        eval_image_size=None,
+    ) -> dict[str, dict[str, any]]:
+        eval_image_size = eval_image_size or self.run_config.eval_image_size
+        eval_image_size = eval_image_size or self.data_provider.image_size
+        model = model or self.eval_network
+
+        if not isinstance(eval_image_size, list):
+            eval_image_size = [eval_image_size]
+
+        output_dict = {}
+        for r in eval_image_size:
+            self.data_provider.assign_active_image_size(parse_image_size(r))
+            if self.run_config.reset_bn:
+                self.reset_bn(
+                    network=model,
+                    subset_size=self.run_config.reset_bn_size,
+                    subset_batch_size=self.run_config.reset_bn_batch_size,
+                    progress_bar=True,
+                )
+            output_dict[f"r{r}"] = self.validate(model, data_loader, is_test, epoch)
+        return output_dict
+
+    """ training """
+
+    def prep_for_training(
+        self, run_config: RunConfig, ema_decay: float or None = None, fp16=False
+    ) -> None:
+        self.run_config = run_config
+        self.model = nn.parallel.DistributedDataParallel(
+            self.model.cuda(),
+            device_ids=[get_dist_local_rank()],
+            static_graph=True,
+        )
+
+        self.run_config.global_step = 0
+        self.run_config.batch_per_epoch = len(self.data_provider.train)
+        assert self.run_config.batch_per_epoch > 0, "Training set is empty"
+
+        # build optimizer
+        self.optimizer, self.lr_scheduler = self.run_config.build_optimizer(self.model)
+
+        if ema_decay is not None:
+            self.ema = EMA(self.network, ema_decay)
+
+        # fp16
+        self.fp16 = fp16
+        self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)
+
+    def sync_model(self):
+        print("Sync model")
+        self.save_model(model_name="sync.pt")
+        dist_barrier()
+        checkpoint = torch.load(
+            os.path.join(self.checkpoint_path, "sync.pt"), map_location="cpu"
+        )
+        dist_barrier()
+        if is_master():
+            os.remove(os.path.join(self.checkpoint_path, "sync.pt"))
+        dist_barrier()
+
+        # load checkpoint
+        self.network.load_state_dict(checkpoint["state_dict"], strict=False)
+        if "optimizer" in checkpoint:
+            self.optimizer.load_state_dict(checkpoint["optimizer"])
+        if "lr_scheduler" in checkpoint:
+            self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
+        if "ema" in checkpoint and self.ema is not None:
+            self.ema.load_state_dict(checkpoint["ema"])
+        if "scaler" in checkpoint and self.fp16:
+            self.scaler.load_state_dict(checkpoint["scaler"])
+
+    def before_step(self, feed_dict: dict[str, any]) -> dict[str, any]:
+        for key in feed_dict:
+            if isinstance(feed_dict[key], torch.Tensor):
+                feed_dict[key] = feed_dict[key].cuda()
+        return feed_dict
+
+    def run_step(self, feed_dict: dict[str, any]) -> dict[str, any]:
+        raise NotImplementedError
+
+    def after_step(self) -> None:
+        self.scaler.unscale_(self.optimizer)
+        # gradient clip
+        if self.run_config.grad_clip is not None:
+            torch.nn.utils.clip_grad_value_(
+                self.model.parameters(), self.run_config.grad_clip
+            )
+        # update
+        self.scaler.step(self.optimizer)
+        self.scaler.update()
+
+        self.lr_scheduler.step()
+        self.run_config.step()
+        # update ema
+        if self.ema is not None:
+            self.ema.step(self.network, self.run_config.global_step)
+
+    def _train_one_epoch(self, epoch: int) -> dict[str, any]:
+        raise NotImplementedError
+
+    def train_one_epoch(self, epoch: int) -> dict[str, any]:
+        self.model.train()
+
+        self.data_provider.set_epoch(epoch)
+
+        train_info_dict = self._train_one_epoch(epoch)
+
+        return train_info_dict
+
+    def train(self) -> None:
+        raise NotImplementedError
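+
+
+# Hedged usage sketch (not part of the upstream file): concrete trainers are meant
+# to subclass Trainer and implement run_step / _train_one_epoch / _validate / train.
+# The feed_dict keys ("data", "label") below are illustrative assumptions, and
+# actually running this class requires a CUDA device, prep_for_training, and a
+# concrete DataProvider.
+class _ExampleClsTrainer(Trainer):
+    def run_step(self, feed_dict: dict[str, any]) -> dict[str, any]:
+        images, labels = feed_dict["data"], feed_dict["label"]
+        with torch.autocast(device_type="cuda", enabled=self.fp16):
+            loss = nn.functional.cross_entropy(self.model(images), labels)
+        self.scaler.scale(loss).backward()
+        return {"loss": loss}
+
+    def _validate(self, model, data_loader, epoch) -> dict[str, any]:
+        correct = total = 0
+        with torch.no_grad():
+            for feed_dict in data_loader:
+                feed_dict = self.before_step(feed_dict)
+                preds = model(feed_dict["data"]).argmax(dim=1)
+                correct += (preds == feed_dict["label"]).sum().item()
+                total += feed_dict["label"].numel()
+        return {"top1": 100.0 * correct / total}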
diff --git a/src/efficientvit/apps/trainer/run_config.py b/src/efficientvit/apps/trainer/run_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a233f3908aaf0cf7829a02c0073f348f14eed10
--- /dev/null
+++ b/src/efficientvit/apps/trainer/run_config.py
@@ -0,0 +1,121 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import json
+
+import numpy as np
+import torch.nn as nn
+
+from src.efficientvit.apps.utils import CosineLRwithWarmup, build_optimizer
+
+__all__ = ["Scheduler", "RunConfig"]
+
+
+class Scheduler:
+    PROGRESS = 0
+
+
+class RunConfig:
+    n_epochs: int
+    init_lr: float
+    warmup_epochs: int
+    warmup_lr: float
+    lr_schedule_name: str
+    lr_schedule_param: dict
+    optimizer_name: str
+    optimizer_params: dict
+    weight_decay: float
+    no_wd_keys: list
+    grad_clip: float  # allow none to turn off grad clipping
+    reset_bn: bool
+    reset_bn_size: int
+    reset_bn_batch_size: int
+    eval_image_size: list  # allow none to use image_size in data_provider
+
+    @property
+    def none_allowed(self):
+        return ["grad_clip", "eval_image_size"]
+
+    def __init__(self, **kwargs):  # arguments must be passed as kwargs
+        for k, val in kwargs.items():
+            setattr(self, k, val)
+
+        # check that all relevant configs are there
+        annotations = {}
+        for clas in type(self).mro():
+            if hasattr(clas, "__annotations__"):
+                annotations.update(clas.__annotations__)
+        for k, k_type in annotations.items():
+            assert hasattr(
+                self, k
+            ), f"Key {k} with type {k_type} required for initialization."
+            attr = getattr(self, k)
+            if k in self.none_allowed:
+                k_type = (k_type, type(None))
+            assert isinstance(
+                attr, k_type
+            ), f"Key {k} must be type {k_type}, provided={attr}."
+
+        self.global_step = 0
+        self.batch_per_epoch = 1
+
+    def build_optimizer(self, network: nn.Module) -> tuple[any, any]:
+        r"""require setting 'batch_per_epoch' before building optimizer & lr_scheduler"""
+        param_dict = {}
+        for name, param in network.named_parameters():
+            if param.requires_grad:
+                opt_config = [self.weight_decay, self.init_lr]
+                if self.no_wd_keys is not None and len(self.no_wd_keys) > 0:
+                    if np.any([key in name for key in self.no_wd_keys]):
+                        opt_config[0] = 0
+                opt_key = json.dumps(opt_config)
+                param_dict[opt_key] = param_dict.get(opt_key, []) + [param]
+
+        net_params = []
+        for opt_key, param_list in param_dict.items():
+            wd, lr = json.loads(opt_key)
+            net_params.append({"params": param_list, "weight_decay": wd, "lr": lr})
+
+        optimizer = build_optimizer(
+            net_params, self.optimizer_name, self.optimizer_params, self.init_lr
+        )
+        # build lr scheduler
+        if self.lr_schedule_name == "cosine":
+            decay_steps = []
+            for epoch in self.lr_schedule_param.get("step", []):
+                decay_steps.append(epoch * self.batch_per_epoch)
+            decay_steps.append(self.n_epochs * self.batch_per_epoch)
+            decay_steps.sort()
+            lr_scheduler = CosineLRwithWarmup(
+                optimizer,
+                self.warmup_epochs * self.batch_per_epoch,
+                self.warmup_lr,
+                decay_steps,
+            )
+        else:
+            raise NotImplementedError
+        return optimizer, lr_scheduler
+
+    def update_global_step(self, epoch, batch_id=0) -> None:
+        self.global_step = epoch * self.batch_per_epoch + batch_id
+        Scheduler.PROGRESS = self.progress
+
+    @property
+    def progress(self) -> float:
+        warmup_steps = self.warmup_epochs * self.batch_per_epoch
+        steps = max(0, self.global_step - warmup_steps)
+        return steps / (self.n_epochs * self.batch_per_epoch)
+
+    def step(self) -> None:
+        self.global_step += 1
+        Scheduler.PROGRESS = self.progress
+
+    def get_remaining_epoch(self, epoch, post=True) -> int:
+        return self.n_epochs + self.warmup_epochs - epoch - int(post)
+
+    def epoch_format(self, epoch: int) -> str:
+        epoch_format = f"%.{len(str(self.n_epochs))}d"
+        epoch_format = f"[{epoch_format}/{epoch_format}]"
+        epoch_format = epoch_format % (epoch + 1 - self.warmup_epochs, self.n_epochs)
+        return epoch_format
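+
+
+if __name__ == "__main__":
+    # Hedged usage sketch (not part of the upstream file): every value below is
+    # made up purely for illustration.
+    cfg = RunConfig(
+        n_epochs=10, init_lr=1e-3, warmup_epochs=1, warmup_lr=0.0,
+        lr_schedule_name="cosine", lr_schedule_param={},
+        optimizer_name="adamw", optimizer_params={}, weight_decay=0.05,
+        no_wd_keys=["bias", "norm"], grad_clip=None,
+        reset_bn=False, reset_bn_size=16000, reset_bn_batch_size=100,
+        eval_image_size=None,
+    )
+    cfg.batch_per_epoch = 100  # normally set by the Trainer before building
+    optimizer, lr_scheduler = cfg.build_optimizer(nn.Linear(8, 8))
+    print(type(optimizer).__name__, type(lr_scheduler).__name__)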
diff --git a/src/efficientvit/apps/utils/__init__.py b/src/efficientvit/apps/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c826a22544285746c588741f3f20fbe3802ccd50
--- /dev/null
+++ b/src/efficientvit/apps/utils/__init__.py
@@ -0,0 +1,12 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+from .dist import *
+from .ema import *
+from .export import *
+from .init import *
+from .lr import *
+from .metric import *
+from .misc import *
+from .opt import *
diff --git a/src/efficientvit/apps/utils/dist.py b/src/efficientvit/apps/utils/dist.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b55142678fa168d67dc33fff6907d3b8c87a485
--- /dev/null
+++ b/src/efficientvit/apps/utils/dist.py
@@ -0,0 +1,73 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import os
+
+import torch
+import torch.distributed
+
+from src.efficientvit.models.utils.list import list_mean, list_sum
+
+__all__ = [
+    "dist_init",
+    "get_dist_rank",
+    "get_dist_size",
+    "is_master",
+    "dist_barrier",
+    "get_dist_local_rank",
+    "sync_tensor",
+]
+
+
+def dist_init() -> None:
+    try:
+        torch.distributed.init_process_group(backend="nccl")
+        assert torch.distributed.is_initialized()
+    except Exception:
+        # use torchpack
+        from torchpack import distributed as dist
+
+        dist.init()
+        os.environ["RANK"] = f"{dist.rank()}"
+        os.environ["WORLD_SIZE"] = f"{dist.size()}"
+        os.environ["LOCAL_RANK"] = f"{dist.local_rank()}"
+
+
+def get_dist_rank() -> int:
+    return int(os.environ["RANK"])
+
+
+def get_dist_size() -> int:
+    return int(os.environ["WORLD_SIZE"])
+
+
+def is_master() -> bool:
+    return get_dist_rank() == 0
+
+
+def dist_barrier() -> None:
+    torch.distributed.barrier()
+
+
+def get_dist_local_rank() -> int:
+    return int(os.environ["LOCAL_RANK"])
+
+
+def sync_tensor(
+    tensor: torch.Tensor or float, reduce="mean"
+) -> torch.Tensor or list[torch.Tensor]:
+    if not isinstance(tensor, torch.Tensor):
+        tensor = torch.Tensor(1).fill_(tensor).cuda()
+    tensor_list = [torch.empty_like(tensor) for _ in range(get_dist_size())]
+    torch.distributed.all_gather(tensor_list, tensor.contiguous(), async_op=False)
+    if reduce == "mean":
+        return list_mean(tensor_list)
+    elif reduce == "sum":
+        return list_sum(tensor_list)
+    elif reduce == "cat":
+        return torch.cat(tensor_list, dim=0)
+    elif reduce == "root":
+        return tensor_list[0]
+    else:
+        return tensor_list
diff --git a/src/efficientvit/apps/utils/ema.py b/src/efficientvit/apps/utils/ema.py
new file mode 100644
index 0000000000000000000000000000000000000000..9de7f7fee67840f98ec97bf759dd0a390618a576
--- /dev/null
+++ b/src/efficientvit/apps/utils/ema.py
@@ -0,0 +1,50 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import copy
+import math
+
+import torch
+import torch.nn as nn
+
+from src.efficientvit.models.utils import is_parallel
+
+__all__ = ["EMA"]
+
+
+def update_ema(
+    ema: nn.Module, new_state_dict: dict[str, torch.Tensor], decay: float
+) -> None:
+    for k, v in ema.state_dict().items():
+        if v.dtype.is_floating_point:
+            v -= (1.0 - decay) * (v - new_state_dict[k].detach())
+
+
+class EMA:
+    def __init__(self, model: nn.Module, decay: float, warmup_steps=2000):
+        self.shadows = copy.deepcopy(
+            model.module if is_parallel(model) else model
+        ).eval()
+        self.decay = decay
+        self.warmup_steps = warmup_steps
+
+        for p in self.shadows.parameters():
+            p.requires_grad = False
+
+    def step(self, model: nn.Module, global_step: int) -> None:
+        with torch.no_grad():
+            msd = (model.module if is_parallel(model) else model).state_dict()
+            update_ema(
+                self.shadows,
+                msd,
+                self.decay * (1 - math.exp(-global_step / self.warmup_steps)),
+            )
+
+    def state_dict(self) -> dict[float, dict[str, torch.Tensor]]:
+        return {self.decay: self.shadows.state_dict()}
+
+    def load_state_dict(self, state_dict: dict[float, dict[str, torch.Tensor]]) -> None:
+        for decay in state_dict:
+            if decay == self.decay:
+                self.shadows.load_state_dict(state_dict[decay])
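+
+
+if __name__ == "__main__":
+    # Hedged usage sketch (not part of the upstream file): maintain an EMA copy of
+    # a toy model on CPU; in real training, EMA.step runs once per optimizer step.
+    model = nn.Linear(4, 4)
+    ema = EMA(model, decay=0.999)
+    for step in range(1, 11):
+        with torch.no_grad():
+            model.weight.add_(0.01)  # stand-in for an optimizer update
+        ema.step(model, global_step=step)
+    print(float(model.weight.mean()), float(ema.shadows.weight.mean()))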
diff --git a/src/efficientvit/apps/utils/export.py b/src/efficientvit/apps/utils/export.py
new file mode 100644
index 0000000000000000000000000000000000000000..d611f957a6ff22b98210d611e7344426e091d3df
--- /dev/null
+++ b/src/efficientvit/apps/utils/export.py
@@ -0,0 +1,47 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import io
+import os
+
+import onnx
+import torch
+import torch.nn as nn
+from onnxsim import simplify as simplify_func
+
+__all__ = ["export_onnx"]
+
+
+def export_onnx(
+    model: nn.Module, export_path: str, sample_inputs: any, simplify=True, opset=11
+) -> None:
+    """Export a model to a platform-specific onnx format.
+
+    Args:
+        model: a torch.nn.Module object.
+        export_path: export location.
+        sample_inputs: Any.
+        simplify: a flag to turn on onnx-simplifier
+        opset: int
+    """
+    model.eval()
+
+    buffer = io.BytesIO()
+    with torch.no_grad():
+        torch.onnx.export(model, sample_inputs, buffer, opset_version=opset)
+        buffer.seek(0, 0)
+        if simplify:
+            onnx_model = onnx.load_model(buffer)
+            onnx_model, success = simplify_func(onnx_model)
+            assert success
+            new_buffer = io.BytesIO()
+            onnx.save(onnx_model, new_buffer)
+            buffer = new_buffer
+            buffer.seek(0, 0)
+
+    if buffer.getbuffer().nbytes > 0:
+        save_dir = os.path.dirname(export_path)
+        os.makedirs(save_dir, exist_ok=True)
+        with open(export_path, "wb") as f:
+            f.write(buffer.read())
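+
+
+if __name__ == "__main__":
+    # Hedged usage sketch (not part of the upstream file): export a toy model.
+    # The output path is made up; onnx and onnxsim must be installed (they are
+    # already required by the imports above).
+    toy = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
+    export_onnx(toy, "/tmp/toy.onnx", torch.randn(1, 3, 32, 32), simplify=True, opset=11)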
diff --git a/src/efficientvit/apps/utils/init.py b/src/efficientvit/apps/utils/init.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d2ebe26ff45a7ee1de614a39e0db24198097152
--- /dev/null
+++ b/src/efficientvit/apps/utils/init.py
@@ -0,0 +1,68 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import torch
+import torch.nn as nn
+from torch.nn.modules.batchnorm import _BatchNorm
+
+__all__ = ["init_modules", "zero_last_gamma"]
+
+
+def init_modules(model: nn.Module or list[nn.Module], init_type="trunc_normal") -> None:
+    _DEFAULT_INIT_PARAM = {"trunc_normal": 0.02}
+
+    if isinstance(model, list):
+        for sub_module in model:
+            init_modules(sub_module, init_type)
+    else:
+        init_params = init_type.split("@")
+        init_params = float(init_params[1]) if len(init_params) > 1 else None
+
+        if init_type.startswith("trunc_normal"):
+            init_func = lambda param: nn.init.trunc_normal_(
+                param, std=(init_params or _DEFAULT_INIT_PARAM["trunc_normal"])
+            )
+        else:
+            raise NotImplementedError
+
+        for m in model.modules():
+            if isinstance(m, (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)):
+                init_func(m.weight)
+                if m.bias is not None:
+                    m.bias.data.zero_()
+            elif isinstance(m, nn.Embedding):
+                init_func(m.weight)
+            elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+            else:
+                weight = getattr(m, "weight", None)
+                bias = getattr(m, "bias", None)
+                if isinstance(weight, torch.nn.Parameter):
+                    init_func(weight)
+                if isinstance(bias, torch.nn.Parameter):
+                    bias.data.zero_()
+
+
+def zero_last_gamma(model: nn.Module, init_val=0) -> None:
+    import src.efficientvit.models.nn.ops as ops
+
+    for m in model.modules():
+        if isinstance(m, ops.ResidualBlock) and isinstance(
+            m.shortcut, ops.IdentityLayer
+        ):
+            if isinstance(m.main, (ops.DSConv, ops.MBConv, ops.FusedMBConv)):
+                parent_module = m.main.point_conv
+            elif isinstance(m.main, ops.ResBlock):
+                parent_module = m.main.conv2
+            elif isinstance(m.main, ops.ConvLayer):
+                parent_module = m.main
+            elif isinstance(m.main, (ops.LiteMLA)):
+                parent_module = m.main.proj
+            else:
+                parent_module = None
+            if parent_module is not None:
+                norm = getattr(parent_module, "norm", None)
+                if norm is not None:
+                    nn.init.constant_(norm.weight, init_val)
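+
+
+if __name__ == "__main__":
+    # Hedged usage sketch (not part of the upstream file): truncated-normal init
+    # with a custom std encoded via the "trunc_normal@<std>" convention parsed above.
+    toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Linear(8, 4))
+    init_modules(toy, init_type="trunc_normal@0.01")
+    print(float(toy[0].weight.std()))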
diff --git a/src/efficientvit/apps/utils/lr.py b/src/efficientvit/apps/utils/lr.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe10c6aee6ce3be5afdc766bbc028829d492903a
--- /dev/null
+++ b/src/efficientvit/apps/utils/lr.py
@@ -0,0 +1,48 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import math
+
+import torch
+
+from src.efficientvit.models.utils.list import val2list
+
+__all__ = ["CosineLRwithWarmup"]
+
+
+class CosineLRwithWarmup(torch.optim.lr_scheduler._LRScheduler):
+    def __init__(
+        self,
+        optimizer: torch.optim.Optimizer,
+        warmup_steps: int,
+        warmup_lr: float,
+        decay_steps: int or list[int],
+        last_epoch: int = -1,
+    ) -> None:
+        self.warmup_steps = warmup_steps
+        self.warmup_lr = warmup_lr
+        self.decay_steps = val2list(decay_steps)
+        super().__init__(optimizer, last_epoch)
+
+    def get_lr(self) -> list[float]:
+        if self.last_epoch < self.warmup_steps:
+            return [
+                (base_lr - self.warmup_lr) * (self.last_epoch + 1) / self.warmup_steps
+                + self.warmup_lr
+                for base_lr in self.base_lrs
+            ]
+        else:
+            current_steps = self.last_epoch - self.warmup_steps
+            decay_steps = [0] + self.decay_steps
+            idx = len(decay_steps) - 2
+            for i, decay_step in enumerate(decay_steps[:-1]):
+                if decay_step <= current_steps < decay_steps[i + 1]:
+                    idx = i
+                    break
+            current_steps -= decay_steps[idx]
+            decay_step = decay_steps[idx + 1] - decay_steps[idx]
+            return [
+                0.5 * base_lr * (1 + math.cos(math.pi * current_steps / decay_step))
+                for base_lr in self.base_lrs
+            ]
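+
+
+if __name__ == "__main__":
+    # Hedged usage sketch (not part of the upstream file): 5 linear-warmup steps
+    # followed by a single cosine decay phase ending at step 20.
+    opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
+    sched = CosineLRwithWarmup(opt, warmup_steps=5, warmup_lr=0.0, decay_steps=20)
+    lrs = []
+    for _ in range(20):
+        lrs.append(round(opt.param_groups[0]["lr"], 4))
+        opt.step()
+        sched.step()
+    print(lrs)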
diff --git a/src/efficientvit/apps/utils/metric.py b/src/efficientvit/apps/utils/metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e656679c8fcf4d3c6320fc10de72d47310f7cf2
--- /dev/null
+++ b/src/efficientvit/apps/utils/metric.py
@@ -0,0 +1,37 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import torch
+
+from src.efficientvit.apps.utils.dist import sync_tensor
+
+__all__ = ["AverageMeter"]
+
+
+class AverageMeter:
+    """Computes and stores the average and current value."""
+
+    def __init__(self, is_distributed=True):
+        self.is_distributed = is_distributed
+        self.sum = 0
+        self.count = 0
+
+    def _sync(self, val: torch.Tensor or int or float) -> torch.Tensor or int or float:
+        return sync_tensor(val, reduce="sum") if self.is_distributed else val
+
+    def update(self, val: torch.Tensor or int or float, delta_n=1):
+        self.count += self._sync(delta_n)
+        self.sum += self._sync(val * delta_n)
+
+    def get_count(self) -> torch.Tensor or int or float:
+        return (
+            self.count.item()
+            if isinstance(self.count, torch.Tensor) and self.count.numel() == 1
+            else self.count
+        )
+
+    @property
+    def avg(self):
+        avg = -1 if self.count == 0 else self.sum / self.count
+        return avg.item() if isinstance(avg, torch.Tensor) and avg.numel() == 1 else avg
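+
+
+if __name__ == "__main__":
+    # Hedged usage sketch (not part of the upstream file): single-process use, so
+    # distributed syncing is disabled; values are made-up batch accuracies.
+    meter = AverageMeter(is_distributed=False)
+    for batch_acc, batch_size in [(0.75, 32), (0.81, 32), (0.90, 16)]:
+        meter.update(batch_acc, delta_n=batch_size)
+    print(f"count={meter.get_count()}, avg={meter.avg:.4f}")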
diff --git a/src/efficientvit/apps/utils/misc.py b/src/efficientvit/apps/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..111b6618ab20bd02b5b6d8785091122c82fc8a24
--- /dev/null
+++ b/src/efficientvit/apps/utils/misc.py
@@ -0,0 +1,111 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import os
+
+import yaml
+
+__all__ = [
+    "parse_with_yaml",
+    "parse_unknown_args",
+    "partial_update_config",
+    "resolve_and_load_config",
+    "load_config",
+    "dump_config",
+]
+
+
+def parse_with_yaml(config_str: str) -> str or dict:
+    try:
+        # add space manually for dict
+        if "{" in config_str and "}" in config_str and ":" in config_str:
+            out_str = config_str.replace(":", ": ")
+        else:
+            out_str = config_str
+        return yaml.safe_load(out_str)
+    except (yaml.YAMLError, ValueError):
+        # return raw string if parsing fails
+        return config_str
+
+
+def parse_unknown_args(unknown: list) -> dict:
+    """Parse unknown args."""
+    index = 0
+    parsed_dict = {}
+    while index < len(unknown):
+        key, val = unknown[index], unknown[index + 1]
+        index += 2
+        if not key.startswith("--"):
+            continue
+        key = key[2:]
+
+        # try parsing with either dot notation or full yaml notation
+        # Note that the vanilla case "--key value" will be parsed the same
+        if "." in key:
+            # key == a.b.c, val == val --> parsed_dict[a][b][c] = val
+            keys = key.split(".")
+            dict_to_update = parsed_dict
+            for key in keys[:-1]:
+                if not (
+                    key in dict_to_update and isinstance(dict_to_update[key], dict)
+                ):
+                    dict_to_update[key] = {}
+                dict_to_update = dict_to_update[key]
+            dict_to_update[keys[-1]] = parse_with_yaml(
+                val
+            )  # so we can parse lists, bools, etc...
+        else:
+            parsed_dict[key] = parse_with_yaml(val)
+    return parsed_dict
+
+
+def partial_update_config(config: dict, partial_config: dict) -> dict:
+    for key in partial_config:
+        if (
+            key in config
+            and isinstance(partial_config[key], dict)
+            and isinstance(config[key], dict)
+        ):
+            partial_update_config(config[key], partial_config[key])
+        else:
+            config[key] = partial_config[key]
+    return config
+
+
+def resolve_and_load_config(path: str, config_name="config.yaml") -> dict:
+    path = os.path.realpath(os.path.expanduser(path))
+    if os.path.isdir(path):
+        config_path = os.path.join(path, config_name)
+    else:
+        config_path = path
+    if not os.path.isfile(config_path):
+        raise Exception(f"Cannot find a valid config at {path}")
+    config = load_config(config_path)
+    return config
+
+
+class SafeLoaderWithTuple(yaml.SafeLoader):
+    """A yaml safe loader with python tuple loading capabilities."""
+
+    def construct_python_tuple(self, node):
+        return tuple(self.construct_sequence(node))
+
+
+SafeLoaderWithTuple.add_constructor(
+    "tag:yaml.org,2002:python/tuple", SafeLoaderWithTuple.construct_python_tuple
+)
+
+
+def load_config(filename: str) -> dict:
+    """Load a yaml file."""
+    filename = os.path.realpath(os.path.expanduser(filename))
+    with open(filename) as fin:
+        return yaml.load(fin, Loader=SafeLoaderWithTuple)
+
+
+def dump_config(config: dict, filename: str) -> None:
+    """Dump a config file"""
+    filename = os.path.realpath(os.path.expanduser(filename))
+    with open(filename, "w") as fout:
+        yaml.dump(config, fout, sort_keys=False)
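+
+
+if __name__ == "__main__":
+    # Hedged usage sketch (not part of the upstream file): merge leftover CLI
+    # overrides (as produced by argparse's parse_known_args) into a base config.
+    base = {"model": {"name": "b1", "dropout": 0.0}, "lr": 0.1}
+    override = parse_unknown_args(["--model.dropout", "0.2", "--lr", "0.05"])
+    print(partial_update_config(base, override))
+    # -> {'model': {'name': 'b1', 'dropout': 0.2}, 'lr': 0.05}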
diff --git a/src/efficientvit/apps/utils/opt.py b/src/efficientvit/apps/utils/opt.py
new file mode 100644
index 0000000000000000000000000000000000000000..79a03507c8b0aa8ad6e7210657630d5af6555521
--- /dev/null
+++ b/src/efficientvit/apps/utils/opt.py
@@ -0,0 +1,31 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import torch
+
+__all__ = ["REGISTERED_OPTIMIZER_DICT", "build_optimizer"]
+
+# register optimizer here
+#   name: optimizer, kwargs with default values
+REGISTERED_OPTIMIZER_DICT: dict[str, tuple[type, dict[str, any]]] = {
+    "sgd": (torch.optim.SGD, {"momentum": 0.9, "nesterov": True}),
+    "adam": (torch.optim.Adam, {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False}),
+    "adamw": (
+        torch.optim.AdamW,
+        {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False},
+    ),
+}
+
+
+def build_optimizer(
+    net_params, optimizer_name: str, optimizer_params: dict or None, init_lr: float
+) -> torch.optim.Optimizer:
+    optimizer_class, default_params = REGISTERED_OPTIMIZER_DICT[optimizer_name]
+    optimizer_params = optimizer_params or {}
+
+    # copy the registered defaults so per-call overrides do not mutate the registry
+    default_params = dict(default_params)
+    for key in default_params:
+        if key in optimizer_params:
+            default_params[key] = optimizer_params[key]
+    optimizer = optimizer_class(net_params, init_lr, **default_params)
+    return optimizer
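+
+
+if __name__ == "__main__":
+    # Hedged usage sketch (not part of the upstream file): one parameter group with
+    # its own weight decay, overriding only the registered AdamW betas.
+    net_params = [{"params": [torch.zeros(4, requires_grad=True)], "weight_decay": 0.05}]
+    optimizer = build_optimizer(net_params, "adamw", {"betas": (0.9, 0.95)}, init_lr=1e-3)
+    print(optimizer)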
diff --git a/src/efficientvit/models/__init__.py b/src/efficientvit/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/efficientvit/models/efficientvit/__init__.py b/src/efficientvit/models/efficientvit/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cea677f24763b605249c05ea37483b579c507cbc
--- /dev/null
+++ b/src/efficientvit/models/efficientvit/__init__.py
@@ -0,0 +1,8 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+from .backbone import *
+from .cls import *
+from .sam import *
+from .seg import *
diff --git a/src/efficientvit/models/efficientvit/backbone.py b/src/efficientvit/models/efficientvit/backbone.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c861a5f1486d243bd8eb0d120f6d646dfe8615e
--- /dev/null
+++ b/src/efficientvit/models/efficientvit/backbone.py
@@ -0,0 +1,372 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import torch
+import torch.nn as nn
+
+from src.efficientvit.models.nn import (ConvLayer, DSConv, EfficientViTBlock,
+                                    FusedMBConv, IdentityLayer, MBConv,
+                                    OpSequential, ResBlock, ResidualBlock)
+from src.efficientvit.models.utils import build_kwargs_from_config
+
+__all__ = [
+    "EfficientViTBackbone",
+    "efficientvit_backbone_b0",
+    "efficientvit_backbone_b1",
+    "efficientvit_backbone_b2",
+    "efficientvit_backbone_b3",
+    "EfficientViTLargeBackbone",
+    "efficientvit_backbone_l0",
+    "efficientvit_backbone_l1",
+    "efficientvit_backbone_l2",
+    "efficientvit_backbone_l3",
+]
+
+
+class EfficientViTBackbone(nn.Module):
+    def __init__(
+        self,
+        width_list: list[int],
+        depth_list: list[int],
+        in_channels=3,
+        dim=32,
+        expand_ratio=4,
+        norm="bn2d",
+        act_func="hswish",
+    ) -> None:
+        super().__init__()
+
+        self.width_list = []
+        # input stem
+        self.input_stem = [
+            ConvLayer(
+                in_channels=3,
+                out_channels=width_list[0],
+                stride=2,
+                norm=norm,
+                act_func=act_func,
+            )
+        ]
+        for _ in range(depth_list[0]):
+            block = self.build_local_block(
+                in_channels=width_list[0],
+                out_channels=width_list[0],
+                stride=1,
+                expand_ratio=1,
+                norm=norm,
+                act_func=act_func,
+            )
+            self.input_stem.append(ResidualBlock(block, IdentityLayer()))
+        in_channels = width_list[0]
+        self.input_stem = OpSequential(self.input_stem)
+        self.width_list.append(in_channels)
+
+        # stages
+        self.stages = []
+        for w, d in zip(width_list[1:3], depth_list[1:3]):
+            stage = []
+            for i in range(d):
+                stride = 2 if i == 0 else 1
+                block = self.build_local_block(
+                    in_channels=in_channels,
+                    out_channels=w,
+                    stride=stride,
+                    expand_ratio=expand_ratio,
+                    norm=norm,
+                    act_func=act_func,
+                )
+                block = ResidualBlock(block, IdentityLayer() if stride == 1 else None)
+                stage.append(block)
+                in_channels = w
+            self.stages.append(OpSequential(stage))
+            self.width_list.append(in_channels)
+
+        for w, d in zip(width_list[3:], depth_list[3:]):
+            stage = []
+            block = self.build_local_block(
+                in_channels=in_channels,
+                out_channels=w,
+                stride=2,
+                expand_ratio=expand_ratio,
+                norm=norm,
+                act_func=act_func,
+                fewer_norm=True,
+            )
+            stage.append(ResidualBlock(block, None))
+            in_channels = w
+
+            for _ in range(d):
+                stage.append(
+                    EfficientViTBlock(
+                        in_channels=in_channels,
+                        dim=dim,
+                        expand_ratio=expand_ratio,
+                        norm=norm,
+                        act_func=act_func,
+                    )
+                )
+            self.stages.append(OpSequential(stage))
+            self.width_list.append(in_channels)
+        self.stages = nn.ModuleList(self.stages)
+
+    @staticmethod
+    def build_local_block(
+        in_channels: int,
+        out_channels: int,
+        stride: int,
+        expand_ratio: float,
+        norm: str,
+        act_func: str,
+        fewer_norm: bool = False,
+    ) -> nn.Module:
+        if expand_ratio == 1:
+            block = DSConv(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                stride=stride,
+                use_bias=(True, False) if fewer_norm else False,
+                norm=(None, norm) if fewer_norm else norm,
+                act_func=(act_func, None),
+            )
+        else:
+            block = MBConv(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                stride=stride,
+                expand_ratio=expand_ratio,
+                use_bias=(True, True, False) if fewer_norm else False,
+                norm=(None, None, norm) if fewer_norm else norm,
+                act_func=(act_func, act_func, None),
+            )
+        return block
+
+    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+        output_dict = {"input": x}
+        output_dict["stage0"] = x = self.input_stem(x)
+        for stage_id, stage in enumerate(self.stages, 1):
+            output_dict["stage%d" % stage_id] = x = stage(x)
+        output_dict["stage_final"] = x
+        return output_dict
+
+
+def efficientvit_backbone_b0(**kwargs) -> EfficientViTBackbone:
+    backbone = EfficientViTBackbone(
+        width_list=[8, 16, 32, 64, 128],
+        depth_list=[1, 2, 2, 2, 2],
+        dim=16,
+        **build_kwargs_from_config(kwargs, EfficientViTBackbone),
+    )
+    return backbone
+
+
+def efficientvit_backbone_b1(**kwargs) -> EfficientViTBackbone:
+    backbone = EfficientViTBackbone(
+        width_list=[16, 32, 64, 128, 256],
+        depth_list=[1, 2, 3, 3, 4],
+        dim=16,
+        **build_kwargs_from_config(kwargs, EfficientViTBackbone),
+    )
+    return backbone
+
+
+def efficientvit_backbone_b2(**kwargs) -> EfficientViTBackbone:
+    backbone = EfficientViTBackbone(
+        width_list=[24, 48, 96, 192, 384],
+        depth_list=[1, 3, 4, 4, 6],
+        dim=32,
+        **build_kwargs_from_config(kwargs, EfficientViTBackbone),
+    )
+    return backbone
+
+
+def efficientvit_backbone_b3(**kwargs) -> EfficientViTBackbone:
+    backbone = EfficientViTBackbone(
+        width_list=[32, 64, 128, 256, 512],
+        depth_list=[1, 4, 6, 6, 9],
+        dim=32,
+        **build_kwargs_from_config(kwargs, EfficientViTBackbone),
+    )
+    return backbone
+
+
+class EfficientViTLargeBackbone(nn.Module):
+    def __init__(
+        self,
+        width_list: list[int],
+        depth_list: list[int],
+        block_list: list[str] or None = None,
+        expand_list: list[float] or None = None,
+        fewer_norm_list: list[bool] or None = None,
+        in_channels=3,
+        qkv_dim=32,
+        norm="bn2d",
+        act_func="gelu",
+    ) -> None:
+        super().__init__()
+        block_list = block_list or ["res", "fmb", "fmb", "mb", "att"]
+        expand_list = expand_list or [1, 4, 4, 4, 6]
+        fewer_norm_list = fewer_norm_list or [False, False, False, True, True]
+
+        self.width_list = []
+        self.stages = []
+        # stage 0
+        stage0 = [
+            ConvLayer(
+                in_channels=3,
+                out_channels=width_list[0],
+                stride=2,
+                norm=norm,
+                act_func=act_func,
+            )
+        ]
+        for _ in range(depth_list[0]):
+            block = self.build_local_block(
+                block=block_list[0],
+                in_channels=width_list[0],
+                out_channels=width_list[0],
+                stride=1,
+                expand_ratio=expand_list[0],
+                norm=norm,
+                act_func=act_func,
+                fewer_norm=fewer_norm_list[0],
+            )
+            stage0.append(ResidualBlock(block, IdentityLayer()))
+        in_channels = width_list[0]
+        self.stages.append(OpSequential(stage0))
+        self.width_list.append(in_channels)
+
+        for stage_id, (w, d) in enumerate(zip(width_list[1:], depth_list[1:]), start=1):
+            stage = []
+            block = self.build_local_block(
+                block=(
+                    "mb"
+                    if block_list[stage_id] not in ["mb", "fmb"]
+                    else block_list[stage_id]
+                ),
+                in_channels=in_channels,
+                out_channels=w,
+                stride=2,
+                expand_ratio=expand_list[stage_id] * 4,
+                norm=norm,
+                act_func=act_func,
+                fewer_norm=fewer_norm_list[stage_id],
+            )
+            stage.append(ResidualBlock(block, None))
+            in_channels = w
+
+            for _ in range(d):
+                if block_list[stage_id].startswith("att"):
+                    stage.append(
+                        EfficientViTBlock(
+                            in_channels=in_channels,
+                            dim=qkv_dim,
+                            expand_ratio=expand_list[stage_id],
+                            scales=(3,) if block_list[stage_id] == "att@3" else (5,),
+                            norm=norm,
+                            act_func=act_func,
+                        )
+                    )
+                else:
+                    block = self.build_local_block(
+                        block=block_list[stage_id],
+                        in_channels=in_channels,
+                        out_channels=in_channels,
+                        stride=1,
+                        expand_ratio=expand_list[stage_id],
+                        norm=norm,
+                        act_func=act_func,
+                        fewer_norm=fewer_norm_list[stage_id],
+                    )
+                    block = ResidualBlock(block, IdentityLayer())
+                    stage.append(block)
+            self.stages.append(OpSequential(stage))
+            self.width_list.append(in_channels)
+        self.stages = nn.ModuleList(self.stages)
+
+    @staticmethod
+    def build_local_block(
+        block: str,
+        in_channels: int,
+        out_channels: int,
+        stride: int,
+        expand_ratio: float,
+        norm: str,
+        act_func: str,
+        fewer_norm: bool = False,
+    ) -> nn.Module:
+        if block == "res":
+            block = ResBlock(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                stride=stride,
+                use_bias=(True, False) if fewer_norm else False,
+                norm=(None, norm) if fewer_norm else norm,
+                act_func=(act_func, None),
+            )
+        elif block == "fmb":
+            block = FusedMBConv(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                stride=stride,
+                expand_ratio=expand_ratio,
+                use_bias=(True, False) if fewer_norm else False,
+                norm=(None, norm) if fewer_norm else norm,
+                act_func=(act_func, None),
+            )
+        elif block == "mb":
+            block = MBConv(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                stride=stride,
+                expand_ratio=expand_ratio,
+                use_bias=(True, True, False) if fewer_norm else False,
+                norm=(None, None, norm) if fewer_norm else norm,
+                act_func=(act_func, act_func, None),
+            )
+        else:
+            raise ValueError(block)
+        return block
+
+    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
+        output_dict = {"input": x}
+        for stage_id, stage in enumerate(self.stages):
+            output_dict["stage%d" % stage_id] = x = stage(x)
+        output_dict["stage_final"] = x
+        return output_dict
+
+
+def efficientvit_backbone_l0(**kwargs) -> EfficientViTLargeBackbone:
+    backbone = EfficientViTLargeBackbone(
+        width_list=[32, 64, 128, 256, 512],
+        depth_list=[1, 1, 1, 4, 4],
+        **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone),
+    )
+    return backbone
+
+
+def efficientvit_backbone_l1(**kwargs) -> EfficientViTLargeBackbone:
+    backbone = EfficientViTLargeBackbone(
+        width_list=[32, 64, 128, 256, 512],
+        depth_list=[1, 1, 1, 6, 6],
+        **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone),
+    )
+    return backbone
+
+
+def efficientvit_backbone_l2(**kwargs) -> EfficientViTLargeBackbone:
+    backbone = EfficientViTLargeBackbone(
+        width_list=[32, 64, 128, 256, 512],
+        depth_list=[1, 2, 2, 8, 8],
+        **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone),
+    )
+    return backbone
+
+
+def efficientvit_backbone_l3(**kwargs) -> EfficientViTLargeBackbone:
+    backbone = EfficientViTLargeBackbone(
+        width_list=[64, 128, 256, 512, 1024],
+        depth_list=[1, 2, 2, 8, 8],
+        **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone),
+    )
+    return backbone
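+
+
+if __name__ == "__main__":
+    # Hedged usage sketch (not part of the upstream file): CPU forward pass to
+    # inspect the multi-scale feature maps returned by the backbone.
+    backbone = efficientvit_backbone_b1()
+    features = backbone(torch.randn(1, 3, 224, 224))
+    for name, feat in features.items():
+        print(name, tuple(feat.shape))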
diff --git a/src/efficientvit/models/efficientvit/cls.py b/src/efficientvit/models/efficientvit/cls.py
new file mode 100644
index 0000000000000000000000000000000000000000..98fac3c4be1ac504585197dbee6cea6f087ffddd
--- /dev/null
+++ b/src/efficientvit/models/efficientvit/cls.py
@@ -0,0 +1,174 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import torch
+import torch.nn as nn
+
+from src.efficientvit.models.efficientvit.backbone import (
+    EfficientViTBackbone, EfficientViTLargeBackbone)
+from src.efficientvit.models.nn import ConvLayer, LinearLayer, OpSequential
+from src.efficientvit.models.utils import build_kwargs_from_config
+
+__all__ = [
+    "EfficientViTCls",
+    ######################
+    "efficientvit_cls_b0",
+    "efficientvit_cls_b1",
+    "efficientvit_cls_b2",
+    "efficientvit_cls_b3",
+    ######################
+    "efficientvit_cls_l1",
+    "efficientvit_cls_l2",
+    "efficientvit_cls_l3",
+]
+
+
+class ClsHead(OpSequential):
+    def __init__(
+        self,
+        in_channels: int,
+        width_list: list[int],
+        n_classes=1000,
+        dropout=0.0,
+        norm="bn2d",
+        act_func="hswish",
+        fid="stage_final",
+    ):
+        ops = [
+            ConvLayer(in_channels, width_list[0], 1, norm=norm, act_func=act_func),
+            nn.AdaptiveAvgPool2d(output_size=1),
+            LinearLayer(
+                width_list[0], width_list[1], False, norm="ln", act_func=act_func
+            ),
+            LinearLayer(width_list[1], n_classes, True, dropout, None, None),
+        ]
+        super().__init__(ops)
+
+        self.fid = fid
+
+    def forward(self, feed_dict: dict[str, torch.Tensor]) -> torch.Tensor:
+        x = feed_dict[self.fid]
+        return OpSequential.forward(self, x)
+
+
+class EfficientViTCls(nn.Module):
+    def __init__(
+        self, backbone: EfficientViTBackbone or EfficientViTLargeBackbone, head: ClsHead
+    ) -> None:
+        super().__init__()
+        self.backbone = backbone
+        self.head = head
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        feed_dict = self.backbone(x)
+        output = self.head(feed_dict)
+        return output
+
+
+def efficientvit_cls_b0(**kwargs) -> EfficientViTCls:
+    from src.efficientvit.models.efficientvit.backbone import efficientvit_backbone_b0
+
+    backbone = efficientvit_backbone_b0(**kwargs)
+
+    head = ClsHead(
+        in_channels=128,
+        width_list=[1024, 1280],
+        **build_kwargs_from_config(kwargs, ClsHead),
+    )
+    model = EfficientViTCls(backbone, head)
+    return model
+
+
+def efficientvit_cls_b1(**kwargs) -> EfficientViTCls:
+    from src.efficientvit.models.efficientvit.backbone import efficientvit_backbone_b1
+
+    backbone = efficientvit_backbone_b1(**kwargs)
+
+    head = ClsHead(
+        in_channels=256,
+        width_list=[1536, 1600],
+        **build_kwargs_from_config(kwargs, ClsHead),
+    )
+    model = EfficientViTCls(backbone, head)
+    return model
+
+
+def efficientvit_cls_b2(**kwargs) -> EfficientViTCls:
+    from src.efficientvit.models.efficientvit.backbone import efficientvit_backbone_b2
+
+    backbone = efficientvit_backbone_b2(**kwargs)
+
+    head = ClsHead(
+        in_channels=384,
+        width_list=[2304, 2560],
+        **build_kwargs_from_config(kwargs, ClsHead),
+    )
+    model = EfficientViTCls(backbone, head)
+    return model
+
+
+def efficientvit_cls_b3(**kwargs) -> EfficientViTCls:
+    from src.efficientvit.models.efficientvit.backbone import efficientvit_backbone_b3
+
+    backbone = efficientvit_backbone_b3(**kwargs)
+
+    head = ClsHead(
+        in_channels=512,
+        width_list=[2304, 2560],
+        **build_kwargs_from_config(kwargs, ClsHead),
+    )
+    model = EfficientViTCls(backbone, head)
+    return model
+
+
+def efficientvit_cls_l1(**kwargs) -> EfficientViTCls:
+    from src.efficientvit.models.efficientvit.backbone import efficientvit_backbone_l1
+
+    backbone = efficientvit_backbone_l1(**kwargs)
+
+    head = ClsHead(
+        in_channels=512,
+        width_list=[3072, 3200],
+        act_func="gelu",
+        **build_kwargs_from_config(kwargs, ClsHead),
+    )
+    model = EfficientViTCls(backbone, head)
+    return model
+
+
+def efficientvit_cls_l2(**kwargs) -> EfficientViTCls:
+    from src.efficientvit.models.efficientvit.backbone import efficientvit_backbone_l2
+
+    backbone = efficientvit_backbone_l2(**kwargs)
+
+    head = ClsHead(
+        in_channels=512,
+        width_list=[3072, 3200],
+        act_func="gelu",
+        **build_kwargs_from_config(kwargs, ClsHead),
+    )
+    model = EfficientViTCls(backbone, head)
+    return model
+
+
+def efficientvit_cls_l3(**kwargs) -> EfficientViTCls:
+    from src.efficientvit.models.efficientvit.backbone import efficientvit_backbone_l3
+
+    backbone = efficientvit_backbone_l3(**kwargs)
+
+    head = ClsHead(
+        in_channels=1024,
+        width_list=[6144, 6400],
+        act_func="gelu",
+        **build_kwargs_from_config(kwargs, ClsHead),
+    )
+    model = EfficientViTCls(backbone, head)
+    return model
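+
+
+if __name__ == "__main__":
+    # Hedged usage sketch (not part of the upstream file): ImageNet-style logits
+    # from a CPU forward pass.
+    model = efficientvit_cls_b1(n_classes=1000)
+    print(model(torch.randn(1, 3, 224, 224)).shape)  # expected: torch.Size([1, 1000])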
diff --git a/src/efficientvit/models/efficientvit/sam.py b/src/efficientvit/models/efficientvit/sam.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe331ef112b64c09ae338033074737645e67e11d
--- /dev/null
+++ b/src/efficientvit/models/efficientvit/sam.py
@@ -0,0 +1,653 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import copy
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms as transforms
+from segment_anything import SamAutomaticMaskGenerator
+from segment_anything.modeling import (MaskDecoder, PromptEncoder,
+                                       TwoWayTransformer)
+from segment_anything.modeling.mask_decoder import MaskDecoder
+from segment_anything.modeling.prompt_encoder import PromptEncoder
+from segment_anything.utils.amg import build_all_layer_point_grids
+from segment_anything.utils.transforms import ResizeLongestSide
+from torchvision.transforms.functional import resize, to_pil_image
+
+from src.efficientvit.models.efficientvit.backbone import (
+    EfficientViTBackbone, EfficientViTLargeBackbone)
+from src.efficientvit.models.nn import (ConvLayer, DAGBlock, FusedMBConv,
+                                    IdentityLayer, MBConv, OpSequential,
+                                    ResBlock, ResidualBlock, UpSampleLayer,
+                                    build_norm)
+from src.efficientvit.models.utils import build_kwargs_from_config, get_device
+
+__all__ = [
+    "SamPad",
+    "SamResize",
+    "SamNeck",
+    "EfficientViTSamImageEncoder",
+    "EfficientViTSam",
+    "EfficientViTSamPredictor",
+    "EfficientViTSamAutomaticMaskGenerator",
+    "efficientvit_sam_l0",
+    "efficientvit_sam_l1",
+    "efficientvit_sam_l2",
+    "efficientvit_sam_xl0",
+    "efficientvit_sam_xl1",
+]
+
+
+class SamPad:
+    def __init__(self, size: int, fill: float = 0, pad_mode="corner") -> None:
+        self.size = size
+        self.fill = fill
+        self.pad_mode = pad_mode
+
+    def __call__(self, image: torch.Tensor) -> torch.Tensor:
+        h, w = image.shape[-2:]
+        th, tw = self.size, self.size
+        assert th >= h and tw >= w
+        if self.pad_mode == "corner":
+            image = F.pad(image, (0, tw - w, 0, th - h), value=self.fill)
+        else:
+            raise NotImplementedError
+        return image
+
+    def __repr__(self) -> str:
+        return f"{type(self).__name__}(size={self.size},mode={self.pad_mode},fill={self.fill})"
+
+
+class SamResize:
+    def __init__(self, size: int) -> None:
+        self.size = size
+
+    def __call__(self, image: np.ndarray) -> np.ndarray:
+        h, w, _ = image.shape
+        long_side = max(h, w)
+        if long_side != self.size:
+            return self.apply_image(image)
+        else:
+            return image
+
+    def apply_image(self, image: np.ndarray) -> np.ndarray:
+        """
+        Expects a numpy array with shape HxWxC in uint8 format.
+        """
+        target_size = self.get_preprocess_shape(
+            image.shape[0], image.shape[1], self.size
+        )
+        return np.array(resize(to_pil_image(image), target_size))
+
+    @staticmethod
+    def get_preprocess_shape(
+        oldh: int, oldw: int, long_side_length: int
+    ) -> tuple[int, int]:
+        """
+        Compute the output size given input size and target long side length.
+        """
+        scale = long_side_length * 1.0 / max(oldh, oldw)
+        newh, neww = oldh * scale, oldw * scale
+        neww = int(neww + 0.5)
+        newh = int(newh + 0.5)
+        return (newh, neww)
+
+    def __repr__(self) -> str:
+        return f"{type(self).__name__}(size={self.size})"
+
+
+class SamNeck(DAGBlock):
+    def __init__(
+        self,
+        fid_list: list[str],
+        in_channel_list: list[int],
+        head_width: int,
+        head_depth: int,
+        expand_ratio: float,
+        middle_op: str,
+        out_dim: int = 256,
+        norm="bn2d",
+        act_func="gelu",
+    ):
+        inputs = {}
+        for fid, in_channel in zip(fid_list, in_channel_list):
+            inputs[fid] = OpSequential(
+                [
+                    ConvLayer(in_channel, head_width, 1, norm=norm, act_func=None),
+                    UpSampleLayer(size=(64, 64)),
+                ]
+            )
+
+        middle = []
+        for _ in range(head_depth):
+            if middle_op == "mb":
+                block = MBConv(
+                    head_width,
+                    head_width,
+                    expand_ratio=expand_ratio,
+                    norm=norm,
+                    act_func=(act_func, act_func, None),
+                )
+            elif middle_op == "fmb":
+                block = FusedMBConv(
+                    head_width,
+                    head_width,
+                    expand_ratio=expand_ratio,
+                    norm=norm,
+                    act_func=(act_func, None),
+                )
+            elif middle_op == "res":
+                block = ResBlock(
+                    head_width,
+                    head_width,
+                    expand_ratio=expand_ratio,
+                    norm=norm,
+                    act_func=(act_func, None),
+                )
+            else:
+                raise NotImplementedError
+            middle.append(ResidualBlock(block, IdentityLayer()))
+        middle = OpSequential(middle)
+
+        outputs = {
+            "sam_encoder": OpSequential(
+                [
+                    ConvLayer(
+                        head_width,
+                        out_dim,
+                        1,
+                        use_bias=True,
+                        norm=None,
+                        act_func=None,
+                    ),
+                ]
+            )
+        }
+
+        super(SamNeck, self).__init__(
+            inputs, "add", None, middle=middle, outputs=outputs
+        )
+
+
+class EfficientViTSamImageEncoder(nn.Module):
+    def __init__(
+        self, backbone: EfficientViTBackbone or EfficientViTLargeBackbone, neck: SamNeck
+    ):
+        super().__init__()
+        self.backbone = backbone
+        self.neck = neck
+
+        self.norm = build_norm("ln2d", 256)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        feed_dict = self.backbone(x)
+        feed_dict = self.neck(feed_dict)
+
+        output = feed_dict["sam_encoder"]
+        output = self.norm(output)
+        return output
+
+
+class EfficientViTSam(nn.Module):
+    mask_threshold: float = 0.0
+    image_format: str = "RGB"
+
+    def __init__(
+        self,
+        image_encoder: EfficientViTSamImageEncoder,
+        prompt_encoder: PromptEncoder,
+        mask_decoder: MaskDecoder,
+        image_size: tuple[int, int] = (1024, 512),
+    ) -> None:
+        super().__init__()
+        self.image_encoder = image_encoder
+        self.prompt_encoder = prompt_encoder
+        self.mask_decoder = mask_decoder
+
+        self.image_size = image_size
+
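+        # Preprocessing: resize so the long side equals image_size[1], convert to a
+        # tensor and normalise with ImageNet statistics, then pad on the right/bottom
+        # to a square image_size[1] x image_size[1] input.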
+        self.transform = transforms.Compose(
+            [
+                SamResize(self.image_size[1]),
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=[123.675 / 255, 116.28 / 255, 103.53 / 255],
+                    std=[58.395 / 255, 57.12 / 255, 57.375 / 255],
+                ),
+                SamPad(self.image_size[1]),
+            ]
+        )
+
+    def postprocess_masks(
+        self,
+        masks: torch.Tensor,
+        input_size: tuple[int, ...],
+        original_size: tuple[int, ...],
+    ) -> torch.Tensor:
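+        # Upsample the low-res logits to the square (image_size[0], image_size[0]) frame,
+        # crop the padded area away using input_size, then resize to the original
+        # image resolution.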
+        masks = F.interpolate(
+            masks,
+            (self.image_size[0], self.image_size[0]),
+            mode="bilinear",
+            align_corners=False,
+        )
+        masks = masks[..., : input_size[0], : input_size[1]]
+        masks = F.interpolate(
+            masks, original_size, mode="bilinear", align_corners=False
+        )
+        return masks
+
+
+class EfficientViTSamPredictor:
+    def __init__(self, sam_model: EfficientViTSam) -> None:
+        self.model = sam_model
+        self.reset_image()
+
+    @property
+    def transform(self):
+        return self
+
+    @property
+    def device(self):
+        return get_device(self.model)
+
+    def reset_image(self) -> None:
+        self.is_image_set = False
+        self.features = None
+        self.original_size = None
+        self.input_size = None
+
+    def apply_coords(self, coords: np.ndarray, im_size=None) -> np.ndarray:
+        old_h, old_w = self.original_size
+        new_h, new_w = self.input_size
+        coords = copy.deepcopy(coords).astype(float)
+        coords[..., 0] = coords[..., 0] * (new_w / old_w)
+        coords[..., 1] = coords[..., 1] * (new_h / old_h)
+        return coords
+
+    def apply_boxes(self, boxes: np.ndarray, im_size=None) -> np.ndarray:
+        boxes = self.apply_coords(boxes.reshape(-1, 2, 2))
+        return boxes.reshape(-1, 4)
+
+    @torch.inference_mode()
+    def set_image(self, image: np.ndarray, image_format: str = "RGB") -> None:
+        assert image_format in [
+            "RGB",
+            "BGR",
+        ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
+        if image_format != self.model.image_format:
+            image = image[..., ::-1]
+
+        self.reset_image()
+
+        self.original_size = image.shape[:2]
+        self.input_size = ResizeLongestSide.get_preprocess_shape(
+            *self.original_size, long_side_length=self.model.image_size[0]
+        )
+
+        torch_data = (
+            self.model.transform(image).unsqueeze(dim=0).to(get_device(self.model))
+        )
+        self.features = self.model.image_encoder(torch_data)
+        self.is_image_set = True
+
+    def predict(
+        self,
+        point_coords: np.ndarray or None = None,
+        point_labels: np.ndarray or None = None,
+        box: np.ndarray or None = None,
+        mask_input: np.ndarray or None = None,
+        multimask_output: bool = True,
+        return_logits: bool = False,
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+        """
+        Predict masks for the given input prompts, using the currently set image.
+
+        Arguments:
+          point_coords (np.ndarray or None): A Nx2 array of point prompts to the
+            model. Each point is in (X,Y) in pixels.
+          point_labels (np.ndarray or None): A length N array of labels for the
+            point prompts. 1 indicates a foreground point and 0 indicates a
+            background point.
+          box (np.ndarray or None): A length 4 array giving a box prompt to the
+            model, in XYXY format.
+          mask_input (np.ndarray): A low resolution mask input to the model, typically
+            coming from a previous prediction iteration. Has form 1xHxW, where
+            for SAM, H=W=256.
+          multimask_output (bool): If true, the model will return three masks.
+            For ambiguous input prompts (such as a single click), this will often
+            produce better masks than a single prediction. If only a single
+            mask is needed, the model's predicted quality score can be used
+            to select the best mask. For non-ambiguous prompts, such as multiple
+            input prompts, multimask_output=False can give better results.
+          return_logits (bool): If true, returns un-thresholded masks logits
+            instead of a binary mask.
+
+        Returns:
+          (np.ndarray): The output masks in CxHxW format, where C is the
+            number of masks, and (H, W) is the original image size.
+          (np.ndarray): An array of length C containing the model's
+            predictions for the quality of each mask.
+          (np.ndarray): An array of shape CxHxW, where C is the number
+            of masks and H=W=256. These low resolution logits can be passed to
+            a subsequent iteration as mask input.
+        """
+        if not self.is_image_set:
+            raise RuntimeError(
+                "An image must be set with .set_image(...) before mask prediction."
+            )
+
+        device = get_device(self.model)
+        # Transform input prompts
+        coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
+        if point_coords is not None:
+            assert (
+                point_labels is not None
+            ), "point_labels must be supplied if point_coords is supplied."
+            point_coords = self.apply_coords(point_coords)
+            coords_torch = torch.as_tensor(
+                point_coords, dtype=torch.float, device=device
+            )
+            labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=device)
+            coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
+        if box is not None:
+            box = self.apply_boxes(box)
+            box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
+            box_torch = box_torch[None, :]
+        if mask_input is not None:
+            mask_input_torch = torch.as_tensor(
+                mask_input, dtype=torch.float, device=device
+            )
+            mask_input_torch = mask_input_torch[None, :, :, :]
+
+        masks, iou_predictions, low_res_masks = self.predict_torch(
+            coords_torch,
+            labels_torch,
+            box_torch,
+            mask_input_torch,
+            multimask_output,
+            return_logits=return_logits,
+        )
+
+        masks = masks[0].detach().cpu().numpy()
+        iou_predictions = iou_predictions[0].detach().cpu().numpy()
+        low_res_masks = low_res_masks[0].detach().cpu().numpy()
+        return masks, iou_predictions, low_res_masks
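+    # Usage sketch (illustrative only; `sam_model`, `image`, and the box values are
+    # placeholders, not part of this file):
+    #   predictor = EfficientViTSamPredictor(sam_model)
+    #   predictor.set_image(image)  # HxWx3 uint8 RGB numpy array
+    #   masks, scores, low_res = predictor.predict(
+    #       box=np.array([50, 60, 420, 380]), multimask_output=False
+    #   )
+    #   mask = masks[0]  # boolean HxW mask at the original resolution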
+
+    @torch.inference_mode()
+    def predict_torch(
+        self,
+        point_coords: torch.Tensor or None = None,
+        point_labels: torch.Tensor or None = None,
+        boxes: torch.Tensor or None = None,
+        mask_input: torch.Tensor or None = None,
+        multimask_output: bool = True,
+        return_logits: bool = False,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Predict masks for the given input prompts, using the currently set image.
+        Input prompts are batched torch tensors and are expected to already be
+        transformed to the input frame using ResizeLongestSide.
+
+        Arguments:
+          point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
+            model. Each point is in (X,Y) in pixels.
+          point_labels (torch.Tensor or None): A BxN array of labels for the
+            point prompts. 1 indicates a foreground point and 0 indicates a
+            background point.
+          boxes (torch.Tensor or None): A Bx4 array giving a box prompt to the
+            model, in XYXY format.
+          mask_input (torch.Tensor or None): A low resolution mask input to the model, typically
+            coming from a previous prediction iteration. Has form Bx1xHxW, where
+            for SAM, H=W=256. Masks returned by a previous iteration of the
+            predict method do not need further transformation.
+          multimask_output (bool): If true, the model will return three masks.
+            For ambiguous input prompts (such as a single click), this will often
+            produce better masks than a single prediction. If only a single
+            mask is needed, the model's predicted quality score can be used
+            to select the best mask. For non-ambiguous prompts, such as multiple
+            input prompts, multimask_output=False can give better results.
+          return_logits (bool): If true, returns un-thresholded masks logits
+            instead of a binary mask.
+
+        Returns:
+          (torch.Tensor): The output masks in BxCxHxW format, where C is the
+            number of masks, and (H, W) is the original image size.
+          (torch.Tensor): An array of shape BxC containing the model's
+            predictions for the quality of each mask.
+          (torch.Tensor): An array of shape BxCxHxW, where C is the number
+            of masks and H=W=256. These low res logits can be passed to
+            a subsequent iteration as mask input.
+        """
+        if not self.is_image_set:
+            raise RuntimeError(
+                "An image must be set with .set_image(...) before mask prediction."
+            )
+
+        if point_coords is not None:
+            points = (point_coords, point_labels)
+        else:
+            points = None
+
+        # Embed prompts
+        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
+            points=points,
+            boxes=boxes,
+            masks=mask_input,
+        )
+
+        # Predict masks
+        low_res_masks, iou_predictions = self.model.mask_decoder(
+            image_embeddings=self.features,
+            image_pe=self.model.prompt_encoder.get_dense_pe(),
+            sparse_prompt_embeddings=sparse_embeddings,
+            dense_prompt_embeddings=dense_embeddings,
+            multimask_output=multimask_output,
+        )
+
+        # Upscale the masks to the original image resolution
+        masks = self.model.postprocess_masks(
+            low_res_masks, self.input_size, self.original_size
+        )
+
+        if not return_logits:
+            masks = masks > self.model.mask_threshold
+
+        return masks, iou_predictions, low_res_masks
+
+
+class EfficientViTSamAutomaticMaskGenerator(SamAutomaticMaskGenerator):
+    def __init__(
+        self,
+        model: EfficientViTSam,
+        points_per_side: int or None = 32,
+        points_per_batch: int = 64,
+        pred_iou_thresh: float = 0.88,
+        stability_score_thresh: float = 0.95,
+        stability_score_offset: float = 1.0,
+        box_nms_thresh: float = 0.7,
+        crop_n_layers: int = 0,
+        crop_nms_thresh: float = 0.7,
+        crop_overlap_ratio: float = 512 / 1500,
+        crop_n_points_downscale_factor: int = 1,
+        point_grids: list[np.ndarray] or None = None,
+        min_mask_region_area: int = 0,
+        output_mode: str = "binary_mask",
+    ) -> None:
+        assert (points_per_side is None) != (
+            point_grids is None
+        ), "Exactly one of points_per_side or point_grid must be provided."
+        if points_per_side is not None:
+            self.point_grids = build_all_layer_point_grids(
+                points_per_side,
+                crop_n_layers,
+                crop_n_points_downscale_factor,
+            )
+        elif point_grids is not None:
+            self.point_grids = point_grids
+        else:
+            raise ValueError("Can't have both points_per_side and point_grid be None.")
+
+        assert output_mode in [
+            "binary_mask",
+            "uncompressed_rle",
+            "coco_rle",
+        ], f"Unknown output_mode {output_mode}."
+        if output_mode == "coco_rle":
+            from pycocotools import \
+                mask as mask_utils  # type: ignore # noqa: F401
+
+        if min_mask_region_area > 0:
+            import cv2  # type: ignore # noqa: F401
+
+        self.predictor = EfficientViTSamPredictor(model)
+        self.points_per_batch = points_per_batch
+        self.pred_iou_thresh = pred_iou_thresh
+        self.stability_score_thresh = stability_score_thresh
+        self.stability_score_offset = stability_score_offset
+        self.box_nms_thresh = box_nms_thresh
+        self.crop_n_layers = crop_n_layers
+        self.crop_nms_thresh = crop_nms_thresh
+        self.crop_overlap_ratio = crop_overlap_ratio
+        self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
+        self.min_mask_region_area = min_mask_region_area
+        self.output_mode = output_mode
+
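+# Usage sketch (illustrative; relies on the generate() method inherited from
+# segment_anything's SamAutomaticMaskGenerator, and `sam_model` is a placeholder):
+#   mask_generator = EfficientViTSamAutomaticMaskGenerator(sam_model)
+#   annotations = mask_generator.generate(image)  # list of dicts: 'segmentation', 'area', ...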
+
+def build_efficientvit_sam(
+    image_encoder: EfficientViTSamImageEncoder, image_size: int
+) -> EfficientViTSam:
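+    # The 64x64 image_embedding_size matches the fixed 64x64 output of SamNeck's
+    # UpSampleLayer, and the 256-dim embeddings match the neck's out_dim and the
+    # decoder's transformer_dim.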
+    return EfficientViTSam(
+        image_encoder=image_encoder,
+        prompt_encoder=PromptEncoder(
+            embed_dim=256,
+            image_embedding_size=(64, 64),
+            input_image_size=(1024, 1024),
+            mask_in_chans=16,
+        ),
+        mask_decoder=MaskDecoder(
+            num_multimask_outputs=3,
+            transformer=TwoWayTransformer(
+                depth=2,
+                embedding_dim=256,
+                mlp_dim=2048,
+                num_heads=8,
+            ),
+            transformer_dim=256,
+            iou_head_depth=3,
+            iou_head_hidden_dim=256,
+        ),
+        image_size=(1024, image_size),
+    )
+
+
+def efficientvit_sam_l0(image_size: int = 512, **kwargs) -> EfficientViTSam:
+    from src.efficientvit.models.efficientvit.backbone import \
+        efficientvit_backbone_l0
+
+    backbone = efficientvit_backbone_l0(**kwargs)
+
+    neck = SamNeck(
+        fid_list=["stage4", "stage3", "stage2"],
+        in_channel_list=[512, 256, 128],
+        head_width=256,
+        head_depth=4,
+        expand_ratio=1,
+        middle_op="fmb",
+    )
+
+    image_encoder = EfficientViTSamImageEncoder(backbone, neck)
+    return build_efficientvit_sam(image_encoder, image_size)
+
+
+def efficientvit_sam_l1(image_size: int = 512, **kwargs) -> EfficientViTSam:
+    from src.efficientvit.models.efficientvit.backbone import \
+        efficientvit_backbone_l1
+
+    backbone = efficientvit_backbone_l1(**kwargs)
+
+    neck = SamNeck(
+        fid_list=["stage4", "stage3", "stage2"],
+        in_channel_list=[512, 256, 128],
+        head_width=256,
+        head_depth=8,
+        expand_ratio=1,
+        middle_op="fmb",
+    )
+
+    image_encoder = EfficientViTSamImageEncoder(backbone, neck)
+    return build_efficientvit_sam(image_encoder, image_size)
+
+
+def efficientvit_sam_l2(image_size: int = 512, **kwargs) -> EfficientViTSam:
+    from src.efficientvit.models.efficientvit.backbone import \
+        efficientvit_backbone_l2
+
+    backbone = efficientvit_backbone_l2(**kwargs)
+
+    neck = SamNeck(
+        fid_list=["stage4", "stage3", "stage2"],
+        in_channel_list=[512, 256, 128],
+        head_width=256,
+        head_depth=12,
+        expand_ratio=1,
+        middle_op="fmb",
+    )
+
+    image_encoder = EfficientViTSamImageEncoder(backbone, neck)
+    return build_efficientvit_sam(image_encoder, image_size)
+
+
+def efficientvit_sam_xl0(image_size: int = 1024, **kwargs) -> EfficientViTSam:
+    from src.efficientvit.models.efficientvit.backbone import \
+        EfficientViTLargeBackbone
+
+    backbone = EfficientViTLargeBackbone(
+        width_list=[32, 64, 128, 256, 512, 1024],
+        depth_list=[0, 1, 1, 2, 3, 3],
+        block_list=["res", "fmb", "fmb", "fmb", "att@3", "att@3"],
+        expand_list=[1, 4, 4, 4, 4, 6],
+        fewer_norm_list=[False, False, False, False, True, True],
+        **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone),
+    )
+
+    neck = SamNeck(
+        fid_list=["stage5", "stage4", "stage3"],
+        in_channel_list=[1024, 512, 256],
+        head_width=256,
+        head_depth=6,
+        expand_ratio=4,
+        middle_op="fmb",
+    )
+
+    image_encoder = EfficientViTSamImageEncoder(backbone, neck)
+    return build_efficientvit_sam(image_encoder, image_size)
+
+
+def efficientvit_sam_xl1(image_size: int = 1024, **kwargs) -> EfficientViTSam:
+    from src.efficientvit.models.efficientvit.backbone import \
+        EfficientViTLargeBackbone
+
+    backbone = EfficientViTLargeBackbone(
+        width_list=[32, 64, 128, 256, 512, 1024],
+        depth_list=[1, 2, 2, 4, 6, 6],
+        block_list=["res", "fmb", "fmb", "fmb", "att@3", "att@3"],
+        expand_list=[1, 4, 4, 4, 4, 6],
+        fewer_norm_list=[False, False, False, False, True, True],
+        **build_kwargs_from_config(kwargs, EfficientViTLargeBackbone),
+    )
+
+    neck = SamNeck(
+        fid_list=["stage5", "stage4", "stage3"],
+        in_channel_list=[1024, 512, 256],
+        head_width=256,
+        head_depth=12,
+        expand_ratio=4,
+        middle_op="fmb",
+    )
+
+    image_encoder = EfficientViTSamImageEncoder(backbone, neck)
+    return build_efficientvit_sam(image_encoder, image_size)
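+
+
+# Usage sketch (illustrative; the builders above return randomly initialised
+# weights, and `image_size` follows each builder's default):
+#   sam = efficientvit_sam_l0(image_size=512).eval()
+#   predictor = EfficientViTSamPredictor(sam)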
diff --git a/src/efficientvit/models/efficientvit/seg.py b/src/efficientvit/models/efficientvit/seg.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2e4bf44a8e13a66b63cbf39f74dbee02b9d7045
--- /dev/null
+++ b/src/efficientvit/models/efficientvit/seg.py
@@ -0,0 +1,355 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import torch
+import torch.nn as nn
+
+from src.efficientvit.models.efficientvit.backbone import (
+    EfficientViTBackbone, EfficientViTLargeBackbone)
+from src.efficientvit.models.nn import (ConvLayer, DAGBlock, FusedMBConv,
+                                    IdentityLayer, MBConv, OpSequential,
+                                    ResidualBlock, UpSampleLayer)
+from src.efficientvit.models.utils import build_kwargs_from_config
+
+__all__ = [
+    "EfficientViTSeg",
+    "efficientvit_seg_b0",
+    "efficientvit_seg_b1",
+    "efficientvit_seg_b2",
+    "efficientvit_seg_b3",
+    "efficientvit_seg_l1",
+    "efficientvit_seg_l2",
+]
+
+
+class SegHead(DAGBlock):
+    def __init__(
+        self,
+        fid_list: list[str],
+        in_channel_list: list[int],
+        stride_list: list[int],
+        head_stride: int,
+        head_width: int,
+        head_depth: int,
+        expand_ratio: float,
+        middle_op: str,
+        final_expand: float or None,
+        n_classes: int,
+        dropout=0,
+        norm="bn2d",
+        act_func="hswish",
+    ):
+        inputs = {}
+        for fid, in_channel, stride in zip(fid_list, in_channel_list, stride_list):
+            factor = stride // head_stride
+            if factor == 1:
+                inputs[fid] = ConvLayer(
+                    in_channel, head_width, 1, norm=norm, act_func=None
+                )
+            else:
+                inputs[fid] = OpSequential(
+                    [
+                        ConvLayer(in_channel, head_width, 1, norm=norm, act_func=None),
+                        UpSampleLayer(factor=factor),
+                    ]
+                )
+
+        middle = []
+        for _ in range(head_depth):
+            if middle_op == "mbconv":
+                block = MBConv(
+                    head_width,
+                    head_width,
+                    expand_ratio=expand_ratio,
+                    norm=norm,
+                    act_func=(act_func, act_func, None),
+                )
+            elif middle_op == "fmbconv":
+                block = FusedMBConv(
+                    head_width,
+                    head_width,
+                    expand_ratio=expand_ratio,
+                    norm=norm,
+                    act_func=(act_func, None),
+                )
+            else:
+                raise NotImplementedError
+            middle.append(ResidualBlock(block, IdentityLayer()))
+        middle = OpSequential(middle)
+
+        outputs = {
+            "segout": OpSequential(
+                [
+                    (
+                        None
+                        if final_expand is None
+                        else ConvLayer(
+                            head_width,
+                            head_width * final_expand,
+                            1,
+                            norm=norm,
+                            act_func=act_func,
+                        )
+                    ),
+                    ConvLayer(
+                        head_width * (final_expand or 1),
+                        n_classes,
+                        1,
+                        use_bias=True,
+                        dropout=dropout,
+                        norm=None,
+                        act_func=None,
+                    ),
+                ]
+            )
+        }
+
+        super(SegHead, self).__init__(
+            inputs, "add", None, middle=middle, outputs=outputs
+        )
+
+
+class EfficientViTSeg(nn.Module):
+    def __init__(
+        self, backbone: EfficientViTBackbone or EfficientViTLargeBackbone, head: SegHead
+    ) -> None:
+        super().__init__()
+        self.backbone = backbone
+        self.head = head
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        feed_dict = self.backbone(x)
+        feed_dict = self.head(feed_dict)
+
+        return feed_dict["segout"]
+
+
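+# Usage sketch (illustrative): the Cityscapes B0 head predicts 19-class logits at
+# 1/8 of the input resolution, e.g.
+#   model = efficientvit_seg_b0("cityscapes").eval()
+#   logits = model(torch.randn(1, 3, 1024, 2048))  # -> [1, 19, 128, 256]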
+def efficientvit_seg_b0(dataset: str, **kwargs) -> EfficientViTSeg:
+    from src.efficientvit.models.efficientvit.backbone import \
+        efficientvit_backbone_b0
+
+    backbone = efficientvit_backbone_b0(**kwargs)
+
+    if dataset == "cityscapes":
+        head = SegHead(
+            fid_list=["stage4", "stage3", "stage2"],
+            in_channel_list=[128, 64, 32],
+            stride_list=[32, 16, 8],
+            head_stride=8,
+            head_width=32,
+            head_depth=1,
+            expand_ratio=4,
+            middle_op="mbconv",
+            final_expand=4,
+            n_classes=19,
+            **build_kwargs_from_config(kwargs, SegHead),
+        )
+    else:
+        raise NotImplementedError
+    model = EfficientViTSeg(backbone, head)
+    return model
+
+
+def efficientvit_seg_b1(dataset: str, **kwargs) -> EfficientViTSeg:
+    from src.efficientvit.models.efficientvit.backbone import \
+        efficientvit_backbone_b1
+
+    backbone = efficientvit_backbone_b1(**kwargs)
+
+    if dataset == "cityscapes":
+        head = SegHead(
+            fid_list=["stage4", "stage3", "stage2"],
+            in_channel_list=[256, 128, 64],
+            stride_list=[32, 16, 8],
+            head_stride=8,
+            head_width=64,
+            head_depth=3,
+            expand_ratio=4,
+            middle_op="mbconv",
+            final_expand=4,
+            n_classes=19,
+            **build_kwargs_from_config(kwargs, SegHead),
+        )
+    elif dataset == "ade20k":
+        head = SegHead(
+            fid_list=["stage4", "stage3", "stage2"],
+            in_channel_list=[256, 128, 64],
+            stride_list=[32, 16, 8],
+            head_stride=8,
+            head_width=64,
+            head_depth=3,
+            expand_ratio=4,
+            middle_op="mbconv",
+            final_expand=None,
+            n_classes=150,
+            **build_kwargs_from_config(kwargs, SegHead),
+        )
+    else:
+        raise NotImplementedError
+    model = EfficientViTSeg(backbone, head)
+    return model
+
+
+def efficientvit_seg_b2(dataset: str, **kwargs) -> EfficientViTSeg:
+    from src.efficientvit.models.efficientvit.backbone import \
+        efficientvit_backbone_b2
+
+    backbone = efficientvit_backbone_b2(**kwargs)
+
+    if dataset == "cityscapes":
+        head = SegHead(
+            fid_list=["stage4", "stage3", "stage2"],
+            in_channel_list=[384, 192, 96],
+            stride_list=[32, 16, 8],
+            head_stride=8,
+            head_width=96,
+            head_depth=3,
+            expand_ratio=4,
+            middle_op="mbconv",
+            final_expand=4,
+            n_classes=19,
+            **build_kwargs_from_config(kwargs, SegHead),
+        )
+    elif dataset == "ade20k":
+        head = SegHead(
+            fid_list=["stage4", "stage3", "stage2"],
+            in_channel_list=[384, 192, 96],
+            stride_list=[32, 16, 8],
+            head_stride=8,
+            head_width=96,
+            head_depth=3,
+            expand_ratio=4,
+            middle_op="mbconv",
+            final_expand=None,
+            n_classes=150,
+            **build_kwargs_from_config(kwargs, SegHead),
+        )
+    else:
+        raise NotImplementedError
+    model = EfficientViTSeg(backbone, head)
+    return model
+
+
+def efficientvit_seg_b3(dataset: str, **kwargs) -> EfficientViTSeg:
+    from src.efficientvit.models.efficientvit.backbone import \
+        efficientvit_backbone_b3
+
+    backbone = efficientvit_backbone_b3(**kwargs)
+
+    if dataset == "cityscapes":
+        head = SegHead(
+            fid_list=["stage4", "stage3", "stage2"],
+            in_channel_list=[512, 256, 128],
+            stride_list=[32, 16, 8],
+            head_stride=8,
+            head_width=128,
+            head_depth=3,
+            expand_ratio=4,
+            middle_op="mbconv",
+            final_expand=4,
+            n_classes=19,
+            **build_kwargs_from_config(kwargs, SegHead),
+        )
+    elif dataset == "ade20k":
+        head = SegHead(
+            fid_list=["stage4", "stage3", "stage2"],
+            in_channel_list=[512, 256, 128],
+            stride_list=[32, 16, 8],
+            head_stride=8,
+            head_width=128,
+            head_depth=3,
+            expand_ratio=4,
+            middle_op="mbconv",
+            final_expand=None,
+            n_classes=150,
+            **build_kwargs_from_config(kwargs, SegHead),
+        )
+    else:
+        raise NotImplementedError
+    model = EfficientViTSeg(backbone, head)
+    return model
+
+
+def efficientvit_seg_l1(dataset: str, **kwargs) -> EfficientViTSeg:
+    from src.efficientvit.models.efficientvit.backbone import \
+        efficientvit_backbone_l1
+
+    backbone = efficientvit_backbone_l1(**kwargs)
+
+    if dataset == "cityscapes":
+        head = SegHead(
+            fid_list=["stage4", "stage3", "stage2"],
+            in_channel_list=[512, 256, 128],
+            stride_list=[32, 16, 8],
+            head_stride=8,
+            head_width=256,
+            head_depth=3,
+            expand_ratio=1,
+            middle_op="fmbconv",
+            final_expand=None,
+            n_classes=19,
+            act_func="gelu",
+            **build_kwargs_from_config(kwargs, SegHead),
+        )
+    elif dataset == "ade20k":
+        head = SegHead(
+            fid_list=["stage4", "stage3", "stage2"],
+            in_channel_list=[512, 256, 128],
+            stride_list=[32, 16, 8],
+            head_stride=8,
+            head_width=128,
+            head_depth=3,
+            expand_ratio=4,
+            middle_op="fmbconv",
+            final_expand=8,
+            n_classes=150,
+            act_func="gelu",
+            **build_kwargs_from_config(kwargs, SegHead),
+        )
+    else:
+        raise NotImplementedError
+    model = EfficientViTSeg(backbone, head)
+    return model
+
+
+def efficientvit_seg_l2(dataset: str, **kwargs) -> EfficientViTSeg:
+    from src.efficientvit.models.efficientvit.backbone import \
+        efficientvit_backbone_l2
+
+    backbone = efficientvit_backbone_l2(**kwargs)
+
+    if dataset == "cityscapes":
+        head = SegHead(
+            fid_list=["stage4", "stage3", "stage2"],
+            in_channel_list=[512, 256, 128],
+            stride_list=[32, 16, 8],
+            head_stride=8,
+            head_width=256,
+            head_depth=5,
+            expand_ratio=1,
+            middle_op="fmbconv",
+            final_expand=None,
+            n_classes=19,
+            act_func="gelu",
+            **build_kwargs_from_config(kwargs, SegHead),
+        )
+    elif dataset == "ade20k":
+        head = SegHead(
+            fid_list=["stage4", "stage3", "stage2"],
+            in_channel_list=[512, 256, 128],
+            stride_list=[32, 16, 8],
+            head_stride=8,
+            head_width=128,
+            head_depth=3,
+            expand_ratio=4,
+            middle_op="fmbconv",
+            final_expand=8,
+            n_classes=150,
+            act_func="gelu",
+            **build_kwargs_from_config(kwargs, SegHead),
+        )
+    else:
+        raise NotImplementedError
+    model = EfficientViTSeg(backbone, head)
+    return model
diff --git a/src/efficientvit/models/nn/__init__.py b/src/efficientvit/models/nn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6152158a1a8a0b4d2fc53622bdf338fbf34809d
--- /dev/null
+++ b/src/efficientvit/models/nn/__init__.py
@@ -0,0 +1,8 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+from .act import *
+from .drop import *
+from .norm import *
+from .ops import *
diff --git a/src/efficientvit/models/nn/act.py b/src/efficientvit/models/nn/act.py
new file mode 100644
index 0000000000000000000000000000000000000000..31d439e24a3453222265c63593537942657ff8eb
--- /dev/null
+++ b/src/efficientvit/models/nn/act.py
@@ -0,0 +1,30 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+from functools import partial
+
+import torch.nn as nn
+
+from src.efficientvit.models.utils import build_kwargs_from_config
+
+__all__ = ["build_act"]
+
+
+# register activation function here
+REGISTERED_ACT_DICT: dict[str, type] = {
+    "relu": nn.ReLU,
+    "relu6": nn.ReLU6,
+    "hswish": nn.Hardswish,
+    "silu": nn.SiLU,
+    "gelu": partial(nn.GELU, approximate="tanh"),
+}
+
+
+def build_act(name: str, **kwargs) -> nn.Module or None:
+    if name in REGISTERED_ACT_DICT:
+        act_cls = REGISTERED_ACT_DICT[name]
+        args = build_kwargs_from_config(kwargs, act_cls)
+        return act_cls(**args)
+    else:
+        return None
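+
+
+# Examples (illustrative): build_act("hswish") -> nn.Hardswish(); build_act("gelu")
+# returns the tanh-approximated nn.GELU registered above; unknown names return None.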
diff --git a/src/efficientvit/models/nn/drop.py b/src/efficientvit/models/nn/drop.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c10aa3dbc89a360fd4c19b1cb7172c2ceea71eb
--- /dev/null
+++ b/src/efficientvit/models/nn/drop.py
@@ -0,0 +1,98 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from src.efficientvit.apps.trainer.run_config import Scheduler
+from src.efficientvit.models.nn.ops import IdentityLayer, ResidualBlock
+from src.efficientvit.models.utils import build_kwargs_from_config
+
+__all__ = ["apply_drop_func"]
+
+
+def apply_drop_func(network: nn.Module, drop_config: dict[str, any] or None) -> None:
+    if drop_config is None:
+        return
+
+    drop_lookup_table = {
+        "droppath": apply_droppath,
+    }
+
+    drop_func = drop_lookup_table[drop_config["name"]]
+    drop_kwargs = build_kwargs_from_config(drop_config, drop_func)
+
+    drop_func(network, **drop_kwargs)
+
+
+def apply_droppath(
+    network: nn.Module,
+    drop_prob: float,
+    linear_decay=True,
+    scheduled=True,
+    skip=0,
+) -> None:
+    all_valid_blocks = []
+    for m in network.modules():
+        for name, sub_module in m.named_children():
+            if isinstance(sub_module, ResidualBlock) and isinstance(
+                sub_module.shortcut, IdentityLayer
+            ):
+                all_valid_blocks.append((m, name, sub_module))
+    all_valid_blocks = all_valid_blocks[skip:]
+    for i, (m, name, sub_module) in enumerate(all_valid_blocks):
+        prob = (
+            drop_prob * (i + 1) / len(all_valid_blocks) if linear_decay else drop_prob
+        )
+        new_module = DropPathResidualBlock(
+            sub_module.main,
+            sub_module.shortcut,
+            sub_module.post_act,
+            sub_module.pre_norm,
+            prob,
+            scheduled,
+        )
+        m._modules[name] = new_module
+
+
+class DropPathResidualBlock(ResidualBlock):
+    def __init__(
+        self,
+        main: nn.Module,
+        shortcut: nn.Module or None,
+        post_act=None,
+        pre_norm: nn.Module or None = None,
+        ######################################
+        drop_prob: float = 0,
+        scheduled=True,
+    ):
+        super().__init__(main, shortcut, post_act, pre_norm)
+
+        self.drop_prob = drop_prob
+        self.scheduled = scheduled
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if (
+            not self.training
+            or self.drop_prob == 0
+            or not isinstance(self.shortcut, IdentityLayer)
+        ):
+            return ResidualBlock.forward(self, x)
+        else:
+            drop_prob = self.drop_prob
+            if self.scheduled:
+                drop_prob *= np.clip(Scheduler.PROGRESS, 0, 1)
+            keep_prob = 1 - drop_prob
+
+            shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+            random_tensor = keep_prob + torch.rand(
+                shape, dtype=x.dtype, device=x.device
+            )
+            random_tensor.floor_()  # binarize
+
+            res = self.forward_main(x) / keep_prob * random_tensor + self.shortcut(x)
+            if self.post_act:
+                res = self.post_act(res)
+            return res
diff --git a/src/efficientvit/models/nn/norm.py b/src/efficientvit/models/nn/norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..03fcacba84e6b2258b4fc4b893f6e2e5151257bb
--- /dev/null
+++ b/src/efficientvit/models/nn/norm.py
@@ -0,0 +1,157 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import torch
+import torch.nn as nn
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from src.efficientvit.models.utils import build_kwargs_from_config
+
+__all__ = ["LayerNorm2d", "build_norm", "reset_bn", "set_norm_eps"]
+
+
+class LayerNorm2d(nn.LayerNorm):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        out = x - torch.mean(x, dim=1, keepdim=True)
+        out = out / torch.sqrt(torch.square(out).mean(dim=1, keepdim=True) + self.eps)
+        if self.elementwise_affine:
+            out = out * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
+        return out
+
+
+# register normalization function here
+REGISTERED_NORM_DICT: dict[str, type] = {
+    "bn2d": nn.BatchNorm2d,
+    "ln": nn.LayerNorm,
+    "ln2d": LayerNorm2d,
+}
+
+
+def build_norm(name="bn2d", num_features=None, **kwargs) -> nn.Module or None:
+    if name in ["ln", "ln2d"]:
+        kwargs["normalized_shape"] = num_features
+    else:
+        kwargs["num_features"] = num_features
+    if name in REGISTERED_NORM_DICT:
+        norm_cls = REGISTERED_NORM_DICT[name]
+        args = build_kwargs_from_config(kwargs, norm_cls)
+        return norm_cls(**args)
+    else:
+        return None
+
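+# Examples (illustrative): build_norm("bn2d", 64) -> nn.BatchNorm2d(64);
+# build_norm("ln2d", 256) -> LayerNorm2d(256); unregistered names return None.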
+
+def reset_bn(
+    model: nn.Module,
+    data_loader: list,
+    sync=True,
+    progress_bar=False,
+) -> None:
+    import copy
+
+    import torch.nn.functional as F
+    from tqdm import tqdm
+
+    from src.efficientvit.apps.utils import AverageMeter, is_master, sync_tensor
+    from src.efficientvit.models.utils import get_device, list_join
+
+    bn_mean = {}
+    bn_var = {}
+
+    tmp_model = copy.deepcopy(model)
+    for name, m in tmp_model.named_modules():
+        if isinstance(m, _BatchNorm):
+            bn_mean[name] = AverageMeter(is_distributed=False)
+            bn_var[name] = AverageMeter(is_distributed=False)
+
+            def new_forward(bn, mean_est, var_est):
+                def lambda_forward(x):
+                    x = x.contiguous()
+                    if sync:
+                        batch_mean = (
+                            x.mean(0, keepdim=True)
+                            .mean(2, keepdim=True)
+                            .mean(3, keepdim=True)
+                        )  # 1, C, 1, 1
+                        batch_mean = sync_tensor(batch_mean, reduce="cat")
+                        batch_mean = torch.mean(batch_mean, dim=0, keepdim=True)
+
+                        batch_var = (x - batch_mean) * (x - batch_mean)
+                        batch_var = (
+                            batch_var.mean(0, keepdim=True)
+                            .mean(2, keepdim=True)
+                            .mean(3, keepdim=True)
+                        )
+                        batch_var = sync_tensor(batch_var, reduce="cat")
+                        batch_var = torch.mean(batch_var, dim=0, keepdim=True)
+                    else:
+                        batch_mean = (
+                            x.mean(0, keepdim=True)
+                            .mean(2, keepdim=True)
+                            .mean(3, keepdim=True)
+                        )  # 1, C, 1, 1
+                        batch_var = (x - batch_mean) * (x - batch_mean)
+                        batch_var = (
+                            batch_var.mean(0, keepdim=True)
+                            .mean(2, keepdim=True)
+                            .mean(3, keepdim=True)
+                        )
+
+                    batch_mean = torch.squeeze(batch_mean)
+                    batch_var = torch.squeeze(batch_var)
+
+                    mean_est.update(batch_mean.data, x.size(0))
+                    var_est.update(batch_var.data, x.size(0))
+
+                    # bn forward using calculated mean & var
+                    _feature_dim = batch_mean.shape[0]
+                    return F.batch_norm(
+                        x,
+                        batch_mean,
+                        batch_var,
+                        bn.weight[:_feature_dim],
+                        bn.bias[:_feature_dim],
+                        False,
+                        0.0,
+                        bn.eps,
+                    )
+
+                return lambda_forward
+
+            m.forward = new_forward(m, bn_mean[name], bn_var[name])
+
+    # skip if there is no batch normalization layers in the network
+    if len(bn_mean) == 0:
+        return
+
+    tmp_model.eval()
+    with torch.no_grad():
+        with tqdm(
+            total=len(data_loader),
+            desc="reset bn",
+            disable=not progress_bar or not is_master(),
+        ) as t:
+            for images in data_loader:
+                images = images.to(get_device(tmp_model))
+                tmp_model(images)
+                t.set_postfix(
+                    {
+                        "bs": images.size(0),
+                        "res": list_join(images.shape[-2:], "x"),
+                    }
+                )
+                t.update()
+
+    for name, m in model.named_modules():
+        if name in bn_mean and bn_mean[name].count > 0:
+            feature_dim = bn_mean[name].avg.size(0)
+            assert isinstance(m, _BatchNorm)
+            m.running_mean.data[:feature_dim].copy_(bn_mean[name].avg)
+            m.running_var.data[:feature_dim].copy_(bn_var[name].avg)
+
+
+def set_norm_eps(model: nn.Module, eps: float or None = None) -> None:
+    for m in model.modules():
+        if isinstance(m, (nn.GroupNorm, nn.LayerNorm, _BatchNorm)):
+            if eps is not None:
+                m.eps = eps
diff --git a/src/efficientvit/models/nn/ops.py b/src/efficientvit/models/nn/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fd0716834c091142168f56037aead9527180222
--- /dev/null
+++ b/src/efficientvit/models/nn/ops.py
@@ -0,0 +1,585 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.cuda.amp import autocast
+
+from src.efficientvit.models.nn.act import build_act
+from src.efficientvit.models.nn.norm import build_norm
+from src.efficientvit.models.utils import (get_same_padding, list_sum, resize,
+                                       val2list, val2tuple)
+
+__all__ = [
+    "ConvLayer",
+    "UpSampleLayer",
+    "LinearLayer",
+    "IdentityLayer",
+    "DSConv",
+    "MBConv",
+    "FusedMBConv",
+    "ResBlock",
+    "LiteMLA",
+    "EfficientViTBlock",
+    "ResidualBlock",
+    "DAGBlock",
+    "OpSequential",
+]
+
+
+#################################################################################
+#                             Basic Layers                                      #
+#################################################################################
+
+
+class ConvLayer(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size=3,
+        stride=1,
+        dilation=1,
+        groups=1,
+        use_bias=False,
+        dropout=0,
+        norm="bn2d",
+        act_func="relu",
+    ):
+        super(ConvLayer, self).__init__()
+
+        padding = get_same_padding(kernel_size)
+        padding *= dilation
+
+        self.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else None
+        self.conv = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=(kernel_size, kernel_size),
+            stride=(stride, stride),
+            padding=padding,
+            dilation=(dilation, dilation),
+            groups=groups,
+            bias=use_bias,
+        )
+        self.norm = build_norm(norm, num_features=out_channels)
+        self.act = build_act(act_func)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.dropout is not None:
+            x = self.dropout(x)
+        x = self.conv(x)
+        if self.norm:
+            x = self.norm(x)
+        if self.act:
+            x = self.act(x)
+        return x
+
+
+class UpSampleLayer(nn.Module):
+    def __init__(
+        self,
+        mode="bicubic",
+        size: int or tuple[int, int] or list[int] or None = None,
+        factor=2,
+        align_corners=False,
+    ):
+        super(UpSampleLayer, self).__init__()
+        self.mode = mode
+        self.size = val2list(size, 2) if size is not None else None
+        self.factor = None if self.size is not None else factor
+        self.align_corners = align_corners
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if (
+            self.size is not None and tuple(x.shape[-2:]) == self.size
+        ) or self.factor == 1:
+            return x
+        return resize(x, self.size, self.factor, self.mode, self.align_corners)
+
+
+class LinearLayer(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        use_bias=True,
+        dropout=0,
+        norm=None,
+        act_func=None,
+    ):
+        super(LinearLayer, self).__init__()
+
+        self.dropout = nn.Dropout(dropout, inplace=False) if dropout > 0 else None
+        self.linear = nn.Linear(in_features, out_features, use_bias)
+        self.norm = build_norm(norm, num_features=out_features)
+        self.act = build_act(act_func)
+
+    def _try_squeeze(self, x: torch.Tensor) -> torch.Tensor:
+        if x.dim() > 2:
+            x = torch.flatten(x, start_dim=1)
+        return x
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self._try_squeeze(x)
+        if self.dropout:
+            x = self.dropout(x)
+        x = self.linear(x)
+        if self.norm:
+            x = self.norm(x)
+        if self.act:
+            x = self.act(x)
+        return x
+
+
+class IdentityLayer(nn.Module):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x
+
+
+#################################################################################
+#                             Basic Blocks                                      #
+#################################################################################
+
+
+class DSConv(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size=3,
+        stride=1,
+        use_bias=False,
+        norm=("bn2d", "bn2d"),
+        act_func=("relu6", None),
+    ):
+        super(DSConv, self).__init__()
+
+        use_bias = val2tuple(use_bias, 2)
+        norm = val2tuple(norm, 2)
+        act_func = val2tuple(act_func, 2)
+
+        self.depth_conv = ConvLayer(
+            in_channels,
+            in_channels,
+            kernel_size,
+            stride,
+            groups=in_channels,
+            norm=norm[0],
+            act_func=act_func[0],
+            use_bias=use_bias[0],
+        )
+        self.point_conv = ConvLayer(
+            in_channels,
+            out_channels,
+            1,
+            norm=norm[1],
+            act_func=act_func[1],
+            use_bias=use_bias[1],
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.depth_conv(x)
+        x = self.point_conv(x)
+        return x
+
+
+class MBConv(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size=3,
+        stride=1,
+        mid_channels=None,
+        expand_ratio=6,
+        use_bias=False,
+        norm=("bn2d", "bn2d", "bn2d"),
+        act_func=("relu6", "relu6", None),
+    ):
+        super(MBConv, self).__init__()
+
+        use_bias = val2tuple(use_bias, 3)
+        norm = val2tuple(norm, 3)
+        act_func = val2tuple(act_func, 3)
+        mid_channels = mid_channels or round(in_channels * expand_ratio)
+
+        self.inverted_conv = ConvLayer(
+            in_channels,
+            mid_channels,
+            1,
+            stride=1,
+            norm=norm[0],
+            act_func=act_func[0],
+            use_bias=use_bias[0],
+        )
+        self.depth_conv = ConvLayer(
+            mid_channels,
+            mid_channels,
+            kernel_size,
+            stride=stride,
+            groups=mid_channels,
+            norm=norm[1],
+            act_func=act_func[1],
+            use_bias=use_bias[1],
+        )
+        self.point_conv = ConvLayer(
+            mid_channels,
+            out_channels,
+            1,
+            norm=norm[2],
+            act_func=act_func[2],
+            use_bias=use_bias[2],
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.inverted_conv(x)
+        x = self.depth_conv(x)
+        x = self.point_conv(x)
+        return x
+
+
+class FusedMBConv(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size=3,
+        stride=1,
+        mid_channels=None,
+        expand_ratio=6,
+        groups=1,
+        use_bias=False,
+        norm=("bn2d", "bn2d"),
+        act_func=("relu6", None),
+    ):
+        super().__init__()
+        use_bias = val2tuple(use_bias, 2)
+        norm = val2tuple(norm, 2)
+        act_func = val2tuple(act_func, 2)
+
+        mid_channels = mid_channels or round(in_channels * expand_ratio)
+
+        self.spatial_conv = ConvLayer(
+            in_channels,
+            mid_channels,
+            kernel_size,
+            stride,
+            groups=groups,
+            use_bias=use_bias[0],
+            norm=norm[0],
+            act_func=act_func[0],
+        )
+        self.point_conv = ConvLayer(
+            mid_channels,
+            out_channels,
+            1,
+            use_bias=use_bias[1],
+            norm=norm[1],
+            act_func=act_func[1],
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.spatial_conv(x)
+        x = self.point_conv(x)
+        return x
+
+
+class ResBlock(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size=3,
+        stride=1,
+        mid_channels=None,
+        expand_ratio=1,
+        use_bias=False,
+        norm=("bn2d", "bn2d"),
+        act_func=("relu6", None),
+    ):
+        super().__init__()
+        use_bias = val2tuple(use_bias, 2)
+        norm = val2tuple(norm, 2)
+        act_func = val2tuple(act_func, 2)
+
+        mid_channels = mid_channels or round(in_channels * expand_ratio)
+
+        self.conv1 = ConvLayer(
+            in_channels,
+            mid_channels,
+            kernel_size,
+            stride,
+            use_bias=use_bias[0],
+            norm=norm[0],
+            act_func=act_func[0],
+        )
+        self.conv2 = ConvLayer(
+            mid_channels,
+            out_channels,
+            kernel_size,
+            1,
+            use_bias=use_bias[1],
+            norm=norm[1],
+            act_func=act_func[1],
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.conv1(x)
+        x = self.conv2(x)
+        return x
+
+
+class LiteMLA(nn.Module):
+    r"""Lightweight multi-scale linear attention"""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        heads: int or None = None,
+        heads_ratio: float = 1.0,
+        dim=8,
+        use_bias=False,
+        norm=(None, "bn2d"),
+        act_func=(None, None),
+        kernel_func="relu",
+        scales: tuple[int, ...] = (5,),
+        eps=1.0e-15,
+    ):
+        super(LiteMLA, self).__init__()
+        self.eps = eps
+        heads = heads or int(in_channels // dim * heads_ratio)
+
+        total_dim = heads * dim
+
+        use_bias = val2tuple(use_bias, 2)
+        norm = val2tuple(norm, 2)
+        act_func = val2tuple(act_func, 2)
+
+        self.dim = dim
+        self.qkv = ConvLayer(
+            in_channels,
+            3 * total_dim,
+            1,
+            use_bias=use_bias[0],
+            norm=norm[0],
+            act_func=act_func[0],
+        )
+        self.aggreg = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.Conv2d(
+                        3 * total_dim,
+                        3 * total_dim,
+                        scale,
+                        padding=get_same_padding(scale),
+                        groups=3 * total_dim,
+                        bias=use_bias[0],
+                    ),
+                    nn.Conv2d(
+                        3 * total_dim,
+                        3 * total_dim,
+                        1,
+                        groups=3 * heads,
+                        bias=use_bias[0],
+                    ),
+                )
+                for scale in scales
+            ]
+        )
+        self.kernel_func = build_act(kernel_func, inplace=False)
+
+        self.proj = ConvLayer(
+            total_dim * (1 + len(scales)),
+            out_channels,
+            1,
+            use_bias=use_bias[1],
+            norm=norm[1],
+            act_func=act_func[1],
+        )
+
+    @autocast(enabled=False)
+    def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor:
+        B, _, H, W = list(qkv.size())
+
+        if qkv.dtype == torch.float16:
+            qkv = qkv.float()
+
+        qkv = torch.reshape(
+            qkv,
+            (
+                B,
+                -1,
+                3 * self.dim,
+                H * W,
+            ),
+        )
+        qkv = torch.transpose(qkv, -1, -2)
+        q, k, v = (
+            qkv[..., 0 : self.dim],
+            qkv[..., self.dim : 2 * self.dim],
+            qkv[..., 2 * self.dim :],
+        )
+
+        # lightweight linear attention
+        q = self.kernel_func(q)
+        k = self.kernel_func(k)
+
+        # linear matmul
+        trans_k = k.transpose(-1, -2)
+
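+        # Pad v with a column of ones so one matmul yields both the weighted sum and
+        # its normaliser: the last channel of `out` equals q @ (k^T @ 1), which is
+        # divided out below.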
+        v = F.pad(v, (0, 1), mode="constant", value=1)
+        kv = torch.matmul(trans_k, v)
+        out = torch.matmul(q, kv)
+        out = out[..., :-1] / (out[..., -1:] + self.eps)
+
+        out = torch.transpose(out, -1, -2)
+        out = torch.reshape(out, (B, -1, H, W))
+        return out
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # generate multi-scale q, k, v
+        qkv = self.qkv(x)
+        multi_scale_qkv = [qkv]
+        for op in self.aggreg:
+            multi_scale_qkv.append(op(qkv))
+        multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1)
+
+        out = self.relu_linear_att(multi_scale_qkv)
+        out = self.proj(out)
+
+        return out
+
+
+class EfficientViTBlock(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        heads_ratio: float = 1.0,
+        dim=32,
+        expand_ratio: float = 4,
+        scales=(5,),
+        norm="bn2d",
+        act_func="hswish",
+    ):
+        super(EfficientViTBlock, self).__init__()
+        self.context_module = ResidualBlock(
+            LiteMLA(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                heads_ratio=heads_ratio,
+                dim=dim,
+                norm=(None, norm),
+                scales=scales,
+            ),
+            IdentityLayer(),
+        )
+        local_module = MBConv(
+            in_channels=in_channels,
+            out_channels=in_channels,
+            expand_ratio=expand_ratio,
+            use_bias=(True, True, False),
+            norm=(None, None, norm),
+            act_func=(act_func, act_func, None),
+        )
+        self.local_module = ResidualBlock(local_module, IdentityLayer())
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.context_module(x)
+        x = self.local_module(x)
+        return x
+
+
+#################################################################################
+#                             Functional Blocks                                 #
+#################################################################################
+
+
+class ResidualBlock(nn.Module):
+    def __init__(
+        self,
+        main: nn.Module or None,
+        shortcut: nn.Module or None,
+        post_act=None,
+        pre_norm: nn.Module or None = None,
+    ):
+        super(ResidualBlock, self).__init__()
+
+        self.pre_norm = pre_norm
+        self.main = main
+        self.shortcut = shortcut
+        self.post_act = build_act(post_act)
+
+    def forward_main(self, x: torch.Tensor) -> torch.Tensor:
+        if self.pre_norm is None:
+            return self.main(x)
+        else:
+            return self.main(self.pre_norm(x))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.main is None:
+            res = x
+        elif self.shortcut is None:
+            res = self.forward_main(x)
+        else:
+            res = self.forward_main(x) + self.shortcut(x)
+            if self.post_act:
+                res = self.post_act(res)
+        return res
+
+
+class DAGBlock(nn.Module):
+    def __init__(
+        self,
+        inputs: dict[str, nn.Module],
+        merge: str,
+        post_input: nn.Module or None,
+        middle: nn.Module,
+        outputs: dict[str, nn.Module],
+    ):
+        super(DAGBlock, self).__init__()
+
+        self.input_keys = list(inputs.keys())
+        self.input_ops = nn.ModuleList(list(inputs.values()))
+        self.merge = merge
+        self.post_input = post_input
+
+        self.middle = middle
+
+        self.output_keys = list(outputs.keys())
+        self.output_ops = nn.ModuleList(list(outputs.values()))
+
+    def forward(self, feature_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+        feat = [
+            op(feature_dict[key]) for key, op in zip(self.input_keys, self.input_ops)
+        ]
+        if self.merge == "add":
+            feat = list_sum(feat)
+        elif self.merge == "cat":
+            feat = torch.concat(feat, dim=1)
+        else:
+            raise NotImplementedError
+        if self.post_input is not None:
+            feat = self.post_input(feat)
+        feat = self.middle(feat)
+        for key, op in zip(self.output_keys, self.output_ops):
+            feature_dict[key] = op(feat)
+        return feature_dict
+
+
+class OpSequential(nn.Module):
+    def __init__(self, op_list: list[nn.Module or None]):
+        super(OpSequential, self).__init__()
+        valid_op_list = []
+        for op in op_list:
+            if op is not None:
+                valid_op_list.append(op)
+        self.op_list = nn.ModuleList(valid_op_list)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        for op in self.op_list:
+            x = op(x)
+        return x
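
For orientation, here is a minimal shape-check sketch of the block defined above. The import path `src.efficientvit.models.nn.ops` is an assumption about the repository layout; the key point is that both the LiteMLA context module and the MBConv local module are wrapped in residual connections, so the feature-map shape is preserved.

```python
# Hedged sketch: shape check for EfficientViTBlock (import path is an assumption).
import torch

from src.efficientvit.models.nn.ops import EfficientViTBlock

block = EfficientViTBlock(in_channels=64, dim=32, scales=(5,))  # 64 / 32 = 2 attention heads
x = torch.randn(1, 64, 32, 32)  # (B, C, H, W) feature map
with torch.no_grad():
    y = block(x)
assert y.shape == x.shape  # both sub-modules are residual, so the shape is unchanged
```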
diff --git a/src/efficientvit/models/utils/__init__.py b/src/efficientvit/models/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0aab6b0a576b33e1e72029210f7b4232c9b7b8b6
--- /dev/null
+++ b/src/efficientvit/models/utils/__init__.py
@@ -0,0 +1,7 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+from .list import *
+from .network import *
+from .random import *
diff --git a/src/efficientvit/models/utils/list.py b/src/efficientvit/models/utils/list.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a2c3291a88ab1d3cc77f7bc7d5eb475e9670a28
--- /dev/null
+++ b/src/efficientvit/models/utils/list.py
@@ -0,0 +1,57 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+__all__ = [
+    "list_sum",
+    "list_mean",
+    "weighted_list_sum",
+    "list_join",
+    "val2list",
+    "val2tuple",
+    "squeeze_list",
+]
+
+
+def list_sum(x: list) -> any:
+    return x[0] if len(x) == 1 else x[0] + list_sum(x[1:])
+
+
+def list_mean(x: list) -> any:
+    return list_sum(x) / len(x)
+
+
+def weighted_list_sum(x: list, weights: list) -> any:
+    assert len(x) == len(weights)
+    return (
+        x[0] * weights[0]
+        if len(x) == 1
+        else x[0] * weights[0] + weighted_list_sum(x[1:], weights[1:])
+    )
+
+
+def list_join(x: list, sep="\t", format_str="%s") -> str:
+    return sep.join([format_str % val for val in x])
+
+
+def val2list(x: list or tuple or any, repeat_time=1) -> list:
+    if isinstance(x, (list, tuple)):
+        return list(x)
+    return [x for _ in range(repeat_time)]
+
+
+def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple:
+    x = val2list(x)
+
+    # repeat elements if necessary
+    if len(x) > 0:
+        x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
+
+    return tuple(x)
+
+
+def squeeze_list(x: list or None) -> list or any:
+    if x is not None and len(x) == 1:
+        return x[0]
+    else:
+        return x
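
A few quick behavioural examples of these helpers; the package-level import relies on the `__init__.py` above re-exporting them.

```python
# Hedged usage sketch for the list helpers.
from src.efficientvit.models.utils import list_sum, val2list, val2tuple

assert list_sum([1, 2, 3]) == 6
assert val2list(5, repeat_time=3) == [5, 5, 5]   # scalars are repeated
assert val2list((1, 2)) == [1, 2]                # lists/tuples pass through unchanged
assert val2tuple(1, min_len=3) == (1, 1, 1)      # padded by repeating the element at idx_repeat
assert val2tuple([1, 2], min_len=3) == (1, 2, 2)
```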
diff --git a/src/efficientvit/models/utils/network.py b/src/efficientvit/models/utils/network.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ba96ec255dc7543be2a7995fed58f7d139d2c75
--- /dev/null
+++ b/src/efficientvit/models/utils/network.py
@@ -0,0 +1,77 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import os
+from inspect import signature
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+__all__ = [
+    "is_parallel",
+    "get_device",
+    "get_same_padding",
+    "resize",
+    "build_kwargs_from_config",
+    "load_state_dict_from_file",
+]
+
+
+def is_parallel(model: nn.Module) -> bool:
+    return isinstance(
+        model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
+    )
+
+
+def get_device(model: nn.Module) -> torch.device:
+    return model.parameters().__next__().device
+
+
+def get_same_padding(kernel_size: int or tuple[int, ...]) -> int or tuple[int, ...]:
+    if isinstance(kernel_size, tuple):
+        return tuple([get_same_padding(ks) for ks in kernel_size])
+    else:
+        assert kernel_size % 2 > 0, "kernel size should be an odd number"
+        return kernel_size // 2
+
+
+def resize(
+    x: torch.Tensor,
+    size: any or None = None,
+    scale_factor: list[float] or None = None,
+    mode: str = "bicubic",
+    align_corners: bool or None = False,
+) -> torch.Tensor:
+    if mode in {"bilinear", "bicubic"}:
+        return F.interpolate(
+            x,
+            size=size,
+            scale_factor=scale_factor,
+            mode=mode,
+            align_corners=align_corners,
+        )
+    elif mode in {"nearest", "area"}:
+        return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode)
+    else:
+        raise NotImplementedError(f"resize(mode={mode}) not implemented.")
+
+
+def build_kwargs_from_config(config: dict, target_func: callable) -> dict[str, any]:
+    valid_keys = list(signature(target_func).parameters)
+    kwargs = {}
+    for key in config:
+        if key in valid_keys:
+            kwargs[key] = config[key]
+    return kwargs
+
+
+def load_state_dict_from_file(
+    file: str, only_state_dict=True
+) -> dict[str, torch.Tensor]:
+    file = os.path.realpath(os.path.expanduser(file))
+    checkpoint = torch.load(file, map_location="cpu")
+    if only_state_dict and "state_dict" in checkpoint:
+        checkpoint = checkpoint["state_dict"]
+    return checkpoint
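
Two of these helpers are easy to sanity-check in isolation; a short sketch, with the import path taken from the file path above.

```python
# Hedged sketch for get_same_padding and build_kwargs_from_config.
import torch.nn as nn

from src.efficientvit.models.utils.network import build_kwargs_from_config, get_same_padding

assert get_same_padding(3) == 1            # "same" padding for stride-1 convs with odd kernels
assert get_same_padding((5, 3)) == (2, 1)  # applied element-wise to kernel-size tuples

# keep only the keys that the target callable actually accepts
config = {"kernel_size": 3, "stride": 2, "not_a_conv_arg": "dropped"}
assert build_kwargs_from_config(config, nn.Conv2d) == {"kernel_size": 3, "stride": 2}
```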
diff --git a/src/efficientvit/models/utils/random.py b/src/efficientvit/models/utils/random.py
new file mode 100644
index 0000000000000000000000000000000000000000..0257f7ab93a3781c159a917823c36d8ada976292
--- /dev/null
+++ b/src/efficientvit/models/utils/random.py
@@ -0,0 +1,73 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+import numpy as np
+import torch
+
+__all__ = [
+    "torch_randint",
+    "torch_random",
+    "torch_shuffle",
+    "torch_uniform",
+    "torch_random_choices",
+]
+
+
+def torch_randint(
+    low: int, high: int, generator: torch.Generator or None = None
+) -> int:
+    """uniform: [low, high)"""
+    if low == high:
+        return low
+    else:
+        assert low < high
+        return int(torch.randint(low=low, high=high, generator=generator, size=(1,)))
+
+
+def torch_random(generator: torch.Generator or None = None) -> float:
+    """uniform distribution on the interval [0, 1)"""
+    return float(torch.rand(1, generator=generator))
+
+
+def torch_shuffle(
+    src_list: list[any], generator: torch.Generator or None = None
+) -> list[any]:
+    rand_indexes = torch.randperm(len(src_list), generator=generator).tolist()
+    return [src_list[i] for i in rand_indexes]
+
+
+def torch_uniform(
+    low: float, high: float, generator: torch.Generator or None = None
+) -> float:
+    """uniform distribution on the interval [low, high)"""
+    rand_val = torch_random(generator)
+    return (high - low) * rand_val + low
+
+
+def torch_random_choices(
+    src_list: list[any],
+    generator: torch.Generator or None = None,
+    k=1,
+    weight_list: list[float] or None = None,
+) -> any or list:
+    if weight_list is None:
+        rand_idx = torch.randint(
+            low=0, high=len(src_list), generator=generator, size=(k,)
+        )
+        out_list = [src_list[i] for i in rand_idx]
+    else:
+        assert len(weight_list) == len(src_list)
+        accumulate_weight_list = np.cumsum(weight_list)
+
+        out_list = []
+        for _ in range(k):
+            val = torch_uniform(0, accumulate_weight_list[-1], generator)
+            active_id = 0
+            for i, weight_val in enumerate(accumulate_weight_list):
+                active_id = i
+                if weight_val > val:
+                    break
+            out_list.append(src_list[active_id])
+
+    return out_list[0] if k == 1 else out_list
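
These wrappers exist so sampling can be made reproducible by threading a single `torch.Generator` through every call; a brief sketch follows.

```python
# Hedged sketch: seeded shuffling and weighted choice driven by one shared generator.
import torch

from src.efficientvit.models.utils.random import torch_random_choices, torch_shuffle

g = torch.Generator().manual_seed(0)
print(torch_shuffle([0, 1, 2, 3], generator=g))
print(torch_random_choices(["a", "b", "c"], generator=g, k=2, weight_list=[0.7, 0.2, 0.1]))
```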
diff --git a/src/efficientvit/sam_model_zoo.py b/src/efficientvit/sam_model_zoo.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbc0c1c290a924ddf93e5dae326b6f3a5d17c7a1
--- /dev/null
+++ b/src/efficientvit/sam_model_zoo.py
@@ -0,0 +1,53 @@
+# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
+# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
+# International Conference on Computer Vision (ICCV), 2023
+
+from src.efficientvit.models.efficientvit import (
+    EfficientViTSam,
+    efficientvit_sam_l0,
+    efficientvit_sam_l1,
+    efficientvit_sam_l2,
+    efficientvit_sam_xl0,
+    efficientvit_sam_xl1,
+)
+from src.efficientvit.models.nn.norm import set_norm_eps
+from src.efficientvit.models.utils import load_state_dict_from_file
+
+__all__ = ["create_sam_model"]
+
+
+REGISTERED_SAM_MODEL: dict[str, str] = {
+    "l0": "assets/checkpoints/sam/l0.pt",
+    "l1": "assets/checkpoints/sam/l1.pt",
+    "l2": "assets/checkpoints/sam/l2.pt",
+    "xl0": "assets/checkpoints/sam/xl0.pt",
+    "xl1": "assets/checkpoints/sam/xl1.pt",
+}
+
+
+def create_sam_model(
+    name: str, pretrained=True, weight_url: str or None = None, **kwargs
+) -> EfficientViTSam:
+    model_dict = {
+        "l0": efficientvit_sam_l0,
+        "l1": efficientvit_sam_l1,
+        "l2": efficientvit_sam_l2,
+        "xl0": efficientvit_sam_xl0,
+        "xl1": efficientvit_sam_xl1,
+    }
+
+    model_id = name.split("-")[0]
+    if model_id not in model_dict:
+        raise ValueError(
+            f"Do not find {name} in the model zoo. List of models: {list(model_dict.keys())}"
+        )
+    else:
+        model = model_dict[model_id](**kwargs)
+    set_norm_eps(model, 1e-6)
+
+    if pretrained:
+        weight_url = weight_url or REGISTERED_SAM_MODEL.get(name, None)
+        if weight_url is None:
+            raise ValueError(f"Do not find the pretrained weight of {name}.")
+        else:
+            weight = load_state_dict_from_file(weight_url)
+            model.load_state_dict(weight)
+    return model
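
A hedged loading sketch: `create_sam_model` resolves the checkpoint path from `REGISTERED_SAM_MODEL`, so it assumes the corresponding file exists under `assets/checkpoints/sam/`; `"xl1"` is just one of the registered names.

```python
# Hedged sketch: build an EfficientViT-SAM backbone and wrap it in the predictor.
import torch

from src.efficientvit.models.efficientvit.sam import EfficientViTSamPredictor
from src.efficientvit.sam_model_zoo import create_sam_model

sam = create_sam_model("xl1", pretrained=True)  # expects assets/checkpoints/sam/xl1.pt
sam = sam.to("cuda" if torch.cuda.is_available() else "cpu").eval()
predictor = EfficientViTSamPredictor(sam)
# the predictor then follows the segment-anything-style set_image()/predict() workflow
```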
diff --git a/src/ip_adapter/attention_processor.py b/src/ip_adapter/attention_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..c38f0e4886895d987ea2b4b1f0694546a7728287
--- /dev/null
+++ b/src/ip_adapter/attention_processor.py
@@ -0,0 +1,424 @@
+# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+try:
+    import xformers
+    import xformers.ops
+
+    xformers_available = True
+except Exception as e:
+    xformers_available = False
+
+
+class AttnProcessor(nn.Module):
+    r"""
+    Default processor for performing attention-related computations.
+    """
+
+    def __init__(
+            self,
+            hidden_size=None,
+            cross_attention_dim=None,
+    ):
+        super().__init__()
+
+    def __call__(
+            self,
+            attn,
+            hidden_states,
+            encoder_hidden_states=None,
+            attention_mask=None,
+            temb=None,
+    ):
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class IPAttnProcessor(nn.Module):
+    r"""
+    Attention processor for IP-Adapter.
+    Args:
+        hidden_size (`int`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`):
+            The number of channels in the `encoder_hidden_states`.
+        scale (`float`, defaults to 1.0):
+            The weight scale of the image prompt.
+        num_tokens (`int`, defaults to 4; should be 16 for ip_adapter_plus):
+            The context length of the image features.
+    """
+
+    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.scale = scale
+        self.num_tokens = num_tokens
+
+        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+    def __call__(
+            self,
+            attn,
+            hidden_states,
+            encoder_hidden_states=None,
+            attention_mask=None,
+            temb=None,
+    ):
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            # get encoder_hidden_states, ip_hidden_states
+            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+            encoder_hidden_states, ip_hidden_states = (
+                encoder_hidden_states[:, :end_pos, :],
+                encoder_hidden_states[:, end_pos:, :],
+            )
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        if xformers_available:
+            hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+        else:
+            attention_probs = attn.get_attention_scores(query, key, attention_mask)
+            hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # for ip-adapter
+        ip_key = self.to_k_ip(ip_hidden_states)
+        ip_value = self.to_v_ip(ip_hidden_states)
+
+        ip_key = attn.head_to_batch_dim(ip_key)
+        ip_value = attn.head_to_batch_dim(ip_value)
+
+        if xformers_available:
+            ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
+        else:
+            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+            ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
+
+        hidden_states = hidden_states + self.scale * ip_hidden_states
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+    def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
+        # TODO attention_mask
+        query = query.contiguous()
+        key = key.contiguous()
+        value = value.contiguous()
+        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+        # hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        return hidden_states
+
+
+class AttnProcessor2_0(torch.nn.Module):
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
+    def __init__(
+            self,
+            hidden_size=None,
+            cross_attention_dim=None,
+    ):
+        super().__init__()
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+            self,
+            attn,
+            hidden_states,
+            encoder_hidden_states=None,
+            attention_mask=None,
+            temb=None,
+    ):
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class IPAttnProcessor2_0(torch.nn.Module):
+    r"""
+    Attention processor for IP-Adapter (PyTorch 2.0).
+    Args:
+        hidden_size (`int`):
+            The hidden size of the attention layer.
+        cross_attention_dim (`int`):
+            The number of channels in the `encoder_hidden_states`.
+        scale (`float`, defaults to 1.0):
+            The weight scale of the image prompt.
+        num_tokens (`int`, defaults to 4; should be 16 for ip_adapter_plus):
+            The context length of the image features.
+    """
+
+    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
+        super().__init__()
+
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+        self.hidden_size = hidden_size
+        self.cross_attention_dim = cross_attention_dim
+        self.scale = scale
+        self.num_tokens = num_tokens
+
+        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+    def __call__(
+            self,
+            attn,
+            hidden_states,
+            encoder_hidden_states=None,
+            attention_mask=None,
+            temb=None,
+    ):
+        residual = hidden_states
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        else:
+            # get encoder_hidden_states, ip_hidden_states
+            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+            encoder_hidden_states, ip_hidden_states = (
+                encoder_hidden_states[:, :end_pos, :],
+                encoder_hidden_states[:, end_pos:, :],
+            )
+            if attn.norm_cross:
+                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # for ip-adapter
+        ip_key = self.to_k_ip(ip_hidden_states)
+        ip_value = self.to_v_ip(ip_hidden_states)
+
+        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        ip_hidden_states = F.scaled_dot_product_attention(
+            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+        )
+        with torch.no_grad():
+            # store the image-prompt attention map for inspection; the parentheses ensure the
+            # softmax is taken over the attention scores rather than over the keys
+            self.attn_map = (query @ ip_key.transpose(-2, -1)).softmax(dim=-1)
+
+        ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        ip_hidden_states = ip_hidden_states.to(query.dtype)
+
+        hidden_states = hidden_states + self.scale * ip_hidden_states
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
\ No newline at end of file
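
For context, the usual way these processors are installed (following the upstream IP-Adapter recipe this module is modified from) is to give every cross-attention layer an `IPAttnProcessor2_0` sized to its UNet block, while self-attention layers keep the plain processor. A hedged sketch; the SDXL checkpoint id is only an example.

```python
# Hedged sketch of registering the processors on an SDXL UNet (IP-Adapter style).
from diffusers import UNet2DConditionModel

from src.ip_adapter.attention_processor import AttnProcessor2_0, IPAttnProcessor2_0

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)

attn_procs = {}
for name in unet.attn_processors.keys():
    # attn1.* are self-attention layers; attn2.* are cross-attention layers
    cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    if cross_attention_dim is None:
        attn_procs[name] = AttnProcessor2_0()
    else:
        attn_procs[name] = IPAttnProcessor2_0(
            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=4
        )
unet.set_attn_processor(attn_procs)
```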
diff --git a/src/ip_adapter/resampler.py b/src/ip_adapter/resampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..149ffa5c031bb18a948f471a83c99e031bea14a6
--- /dev/null
+++ b/src/ip_adapter/resampler.py
@@ -0,0 +1,120 @@
+# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
+import math
+
+import torch
+import torch.nn as nn
+
+
+# FFN
+def FeedForward(dim, mult=4):
+    inner_dim = int(dim * mult)
+    return nn.Sequential(
+        nn.LayerNorm(dim),
+        nn.Linear(dim, inner_dim, bias=False),
+        nn.GELU(),
+        nn.Linear(inner_dim, dim, bias=False),
+    )
+
+
+def reshape_tensor(x, heads):
+    bs, length, width = x.shape
+    # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
+    x = x.view(bs, length, heads, -1)
+    # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+    x = x.transpose(1, 2)
+    # keep (bs, n_heads, length, dim_per_head); reshape makes the transposed tensor contiguous
+    x = x.reshape(bs, heads, length, -1)
+    return x
+
+
+class PerceiverAttention(nn.Module):
+    def __init__(self, *, dim, dim_head=64, heads=8):
+        super().__init__()
+        self.scale = dim_head ** -0.5
+        self.dim_head = dim_head
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        self.norm1 = nn.LayerNorm(dim)
+        self.norm2 = nn.LayerNorm(dim)
+
+        self.to_q = nn.Linear(dim, inner_dim, bias=False)
+        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+        self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+    def forward(self, x, latents):
+        """
+        Args:
+            x (torch.Tensor): image features
+                shape (b, n1, D)
+            latents (torch.Tensor): latent features
+                shape (b, n2, D)
+        """
+        x = self.norm1(x)
+        latents = self.norm2(latents)
+
+        b, l, _ = latents.shape
+
+        q = self.to_q(latents)
+        kv_input = torch.cat((x, latents), dim=-2)
+        k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+
+        q = reshape_tensor(q, self.heads)
+        k = reshape_tensor(k, self.heads)
+        v = reshape_tensor(v, self.heads)
+
+        # attention
+        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+        weight = (q * scale) @ (k * scale).transpose(-2, -1)  # More stable with f16 than dividing afterwards
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+        out = weight @ v
+
+        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+
+        return self.to_out(out)
+
+
+class Resampler(nn.Module):
+    def __init__(
+            self,
+            dim=1024,
+            depth=8,
+            dim_head=64,
+            heads=16,
+            num_queries=8,
+            embedding_dim=768,
+            output_dim=1024,
+            ff_mult=4,
+    ):
+        super().__init__()
+
+        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
+
+        self.proj_in = nn.Linear(embedding_dim, dim)
+
+        self.proj_out = nn.Linear(dim, output_dim)
+        self.norm_out = nn.LayerNorm(output_dim)
+
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(
+                nn.ModuleList(
+                    [
+                        PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+                        FeedForward(dim=dim, mult=ff_mult),
+                    ]
+                )
+            )
+
+    def forward(self, x):
+
+        latents = self.latents.repeat(x.size(0), 1, 1)
+
+        x = self.proj_in(x)
+
+        for attn, ff in self.layers:
+            latents = attn(x, latents) + latents
+            latents = ff(latents) + latents
+
+        latents = self.proj_out(latents)
+        return self.norm_out(latents)
\ No newline at end of file
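
The resampler maps a variable-length sequence of image-encoder tokens to a fixed number of learned query tokens. A shape sketch follows; the hyperparameters are illustrative assumptions (e.g. a 1280-d image embedding), not values fixed by this file.

```python
# Hedged shape check for the Resampler.
import torch

from src.ip_adapter.resampler import Resampler

resampler = Resampler(
    dim=1280, depth=4, dim_head=64, heads=20,
    num_queries=16, embedding_dim=1280, output_dim=2048, ff_mult=4,
)
image_embeds = torch.randn(2, 257, 1280)  # (batch, image tokens, embedding_dim)
with torch.no_grad():
    tokens = resampler(image_embeds)
assert tokens.shape == (2, 16, 2048)  # num_queries tokens, projected to output_dim
```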
diff --git a/src/ip_adapter/utils.py b/src/ip_adapter/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a105f3701c15e8d3bbf838d79bacc51e91d0696
--- /dev/null
+++ b/src/ip_adapter/utils.py
@@ -0,0 +1,5 @@
+import torch.nn.functional as F
+
+
+def is_torch2_available():
+    return hasattr(F, "scaled_dot_product_attention")
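
This flag is typically used to pick between the two processor families defined above; a minimal sketch:

```python
# Hedged sketch: select SDPA-based processors when running on PyTorch >= 2.0.
from src.ip_adapter.utils import is_torch2_available

if is_torch2_available():
    from src.ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor
    from src.ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor
else:
    from src.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
```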
diff --git a/src/pipelines/instantid_pipeline.py b/src/pipelines/instantid_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..029536a35c149d06c8b2da35c6a2cbc66180a683
--- /dev/null
+++ b/src/pipelines/instantid_pipeline.py
@@ -0,0 +1,720 @@
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+)
+
+from diffusers.utils.import_utils import is_invisible_watermark_available
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import (
+    FromSingleFileMixin,
+    IPAdapterMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    TextualInversionLoaderMixin,
+)
+from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
+from diffusers.models.attention_processor import (
+    AttnProcessor2_0,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+
+
+if is_invisible_watermark_available():
+    from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from diffusers import StableDiffusionXLControlNetPipeline
+from PIL import Image
+from torchvision.transforms.functional import to_tensor
+from einops import rearrange
+from torch import einsum
+import math
+from torchvision.utils import save_image
+from diffusers.utils import load_image
+import cv2
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+class RegionControlNet_AttnProcessor:
+    def __init__(self, attention_op=None, controller=None, place_in_unet=None):
+        self.attention_op = attention_op
+        self.controller = controller
+        self.place_in_unet = place_in_unet
+
+    def __call__(
+            self,
+            attn,
+            hidden_states: torch.FloatTensor,
+            encoder_hidden_states: Optional[torch.FloatTensor] = None,
+            attention_mask: Optional[torch.FloatTensor] = None,
+            temb: Optional[torch.FloatTensor] = None,
+            scale: float = 1.0,
+            **cross_attention_kwargs
+    ) -> torch.Tensor:
+        residual = hidden_states
+
+        args = () if USE_PEFT_BACKEND else (scale,)
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states, *args)
+
+        is_cross = True
+        if encoder_hidden_states is None:
+            is_cross = False
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states, *args)
+        value = attn.to_v(encoder_hidden_states, *args)
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        attention_probs = self.controller(attention_probs, is_cross, self.place_in_unet)
+        hidden_states = torch.bmm(attention_probs, value)
+
+
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states, *args)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+def revise_regionally_controlnet_forward(unet, controller):
+    def change_forward(unet, count, place_in_unet):
+        for name, layer in unet.named_children():
+            if layer.__class__.__name__ == 'Attention':
+                layer.set_processor(RegionControlNet_AttnProcessor(controller=controller, place_in_unet=place_in_unet))
+                if 'attn2' in name:
+                    count += 1
+            else:
+                count = change_forward(layer, count, place_in_unet)
+        return count
+
+    # use this to ensure the order
+    cross_attention_idx = change_forward(unet.down_blocks, 0, "down")
+    cross_attention_idx = change_forward(unet.mid_block, cross_attention_idx, "mid")
+    cross_attention_idx = change_forward(unet.up_blocks, cross_attention_idx, "up")
+    print(f'Number of attention layers registered: {cross_attention_idx}')
+    controller.num_att_layers = cross_attention_idx * 2
+
+class InstantidMultiConceptPipeline(StableDiffusionXLControlNetPipeline):
+    # leave controlnet out on purpose because it iterates with unet
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+    _optional_components = [
+        "tokenizer",
+        "tokenizer_2",
+        "text_encoder",
+        "text_encoder_2",
+        "feature_extractor",
+        "image_encoder",
+    ]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        text_encoder_2: CLIPTextModelWithProjection,
+        tokenizer: CLIPTokenizer,
+        tokenizer_2: CLIPTokenizer,
+        unet: UNet2DConditionModel,
+        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
+        scheduler: KarrasDiffusionSchedulers,
+        force_zeros_for_empty_prompt: bool = True,
+        add_watermarker: Optional[bool] = None,
+        feature_extractor: CLIPImageProcessor = None,
+        image_encoder: CLIPVisionModelWithProjection = None,
+    ):
+        if isinstance(controlnet, (list, tuple)):
+            controlnet = MultiControlNetModel(controlnet)
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            text_encoder_2=text_encoder_2,
+            tokenizer=tokenizer,
+            tokenizer_2=tokenizer_2,
+            unet=unet,
+            controlnet=controlnet,
+            scheduler=scheduler,
+            feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
+        )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
+        self.control_image_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+        )
+        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+        if add_watermarker:
+            self.watermark = StableDiffusionXLWatermarker()
+        else:
+            self.watermark = None
+
+        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        image: PipelineImageInput = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 5.0,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+        guess_mode: bool = False,
+        control_guidance_start: Union[float, List[float]] = 0.0,
+        control_guidance_end: Union[float, List[float]] = 1.0,
+        original_size: Tuple[int, int] = None,
+        crops_coords_top_left: Tuple[int, int] = (0, 0),
+        target_size: Tuple[int, int] = None,
+        negative_original_size: Optional[Tuple[int, int]] = None,
+        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+        negative_target_size: Optional[Tuple[int, int]] = None,
+        clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        controller=None,
+        concept_models=None,
+        indices_to_alter=None,
+        face_app=None,
+        stage=None,
+        region_masks=None,
+        **kwargs,
+    ):
+        # revise_regionally_controlnet_forward(self.unet, controller)
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
+        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+        # align format for control guidance
+        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+            control_guidance_start, control_guidance_end = (
+                mult * [control_guidance_start],
+                mult * [control_guidance_end],
+            )
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            prompt_2,
+            image,
+            callback_steps,
+            negative_prompt,
+            negative_prompt_2,
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+            controlnet_conditioning_scale,
+            control_guidance_start,
+            control_guidance_end,
+            callback_on_step_end_tensor_inputs,
+        )
+
+        self._guidance_scale = guidance_scale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
+
+        # 2. Define call parameters
+        batch_size = 2
+
+        device = self._execution_device
+
+        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+
+        global_pool_conditions = (
+            controlnet.config.global_pool_conditions
+            if isinstance(controlnet, ControlNetModel)
+            else controlnet.nets[0].config.global_pool_conditions
+        )
+        guess_mode = guess_mode or global_pool_conditions
+
+        # 3.1 Encode input prompt
+        text_encoder_lora_scale = (
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+        )
+
+        global_prompt = prompt[0]
+        global_negative_prompt = negative_prompt
+        region_prompts = [pt[0] for pt in prompt[1]]
+        region_negative_prompts = [pt[1] for pt in prompt[1]]
+        ref_images = [pt[2] for pt in prompt[1]]
+
+        concat_prompts = global_prompt + region_prompts
+        concat_negative_prompts = global_negative_prompt + region_negative_prompts
+
+        (
+            concat_prompt_embeds,
+            concat_negative_prompt_embeds,
+            concat_pooled_prompt_embeds,
+            concat_negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
+            concat_prompts,
+            prompt_2,
+            device,
+            num_images_per_prompt,
+            self.do_classifier_free_guidance,
+            concat_negative_prompts,
+            negative_prompt_2,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
+            clip_skip=self.clip_skip,
+        )
+
+        prompt_embeds = concat_prompt_embeds[:2]
+        negative_prompt_embeds = concat_negative_prompt_embeds[:2]
+        pooled_prompt_embeds = concat_pooled_prompt_embeds[:2]
+        negative_pooled_prompt_embeds = concat_negative_pooled_prompt_embeds[:2]
+
+        region_prompt_embeds_list = []
+        region_add_text_embeds_list = []
+        for region_prompt_embeds, region_negative_prompt_embeds, region_pooled_prompt_embeds, region_negative_pooled_prompt_embeds in zip(concat_prompt_embeds[2:], concat_negative_prompt_embeds[2:], concat_pooled_prompt_embeds[2:], concat_negative_pooled_prompt_embeds[2:]):
+            region_prompt_embeds_list.append(
+                torch.concat([region_negative_prompt_embeds.unsqueeze(0), region_prompt_embeds.unsqueeze(0)], dim=0).to(concept_models._execution_device))
+            region_add_text_embeds_list.append(
+                torch.concat([region_negative_pooled_prompt_embeds.unsqueeze(0), region_pooled_prompt_embeds.unsqueeze(0)], dim=0).to(concept_models._execution_device))
+
+        if stage == 2:
+            mask_list = [mask.float().to(dtype=prompt_embeds.dtype, device=device) if mask is not None else None for mask in region_masks]
+            image_embedding_list = get_face_embedding(face_app, ref_images)
+            image_prompt_image_emb_list = []
+            for image_embeds in image_embedding_list:
+                prompt_image_emb = concept_models._encode_prompt_image_emb(image_embeds,
+                                                             concept_models._execution_device,
+                                                             num_images_per_prompt,
+                                                             concept_models.unet.dtype,
+                                                             True)
+                image_prompt_image_emb_list.append(prompt_image_emb)
+
+        # 4. Prepare image
+        if isinstance(controlnet, ControlNetModel) and image is not None:
+            image = self.prepare_image(
+                image=image,
+                width=width,
+                height=height,
+                batch_size=1 * num_images_per_prompt,
+                num_images_per_prompt=num_images_per_prompt,
+                device=device,
+                dtype=controlnet.dtype,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
+                guess_mode=guess_mode,
+            )
+            height, width = image.shape[-2:]
+        elif isinstance(controlnet, MultiControlNetModel) and image is not None:
+            images = []
+
+            for image_ in image:
+                image_ = self.prepare_image(
+                    image=image_,
+                    width=width,
+                    height=height,
+                    batch_size=batch_size * num_images_per_prompt,
+                    num_images_per_prompt=num_images_per_prompt,
+                    device=device,
+                    dtype=controlnet.dtype,
+                    do_classifier_free_guidance=self.do_classifier_free_guidance,
+                    guess_mode=guess_mode,
+                )
+
+                images.append(image_)
+
+            image = images
+            height, width = image[0].shape[-2:]
+        else:
+            height = height or self.unet.config.sample_size * self.vae_scale_factor
+            width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+        # 5. Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.scheduler.timesteps
+        self._num_timesteps = len(timesteps)
+
+        # 6. Prepare latent variables
+        num_channels_latents = self.unet.config.in_channels
+        latents = self.prepare_latents(
+            batch_size//2 * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        # 6.1 repeat latent
+        latents = torch.cat([latents, latents.clone()])
+
+        # 6.5 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
+        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 7.1 Create tensor stating which controlnets to keep
+        controlnet_keep = []
+        for i in range(len(timesteps)):
+            keeps = [
+                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+                for s, e in zip(control_guidance_start, control_guidance_end)
+            ]
+            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
+
+        # 7.2 Prepare added time ids & embeddings
+        if isinstance(image, list):
+            original_size = original_size or image[0].shape[-2:]
+        else:
+            original_size = original_size or (height, width)
+        target_size = target_size or (height, width)
+
+        add_text_embeds = pooled_prompt_embeds
+        if self.text_encoder_2 is None:
+            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+        else:
+            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+        add_time_ids = self._get_add_time_ids(
+            original_size,
+            crops_coords_top_left,
+            target_size,
+            dtype=prompt_embeds.dtype,
+            text_encoder_projection_dim=text_encoder_projection_dim,
+        )
+
+        add_time_ids_list = []
+        region_add_time_ids = concept_models._get_add_time_ids(original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim)
+        for _ in range(len(prompt[1])):
+            add_time_ids_list.append(torch.concat([region_add_time_ids, region_add_time_ids], dim=0).to(concept_models._execution_device))
+
+        if negative_original_size is not None and negative_target_size is not None:
+            negative_add_time_ids = self._get_add_time_ids(
+                negative_original_size,
+                negative_crops_coords_top_left,
+                negative_target_size,
+                dtype=prompt_embeds.dtype,
+                text_encoder_projection_dim=text_encoder_projection_dim,
+            )
+        else:
+            negative_add_time_ids = add_time_ids
+
+        if self.do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+
+        prompt_embeds = prompt_embeds.to(device)
+        add_text_embeds = add_text_embeds.to(device)
+        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+        # 8. Denoising loop
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        is_unet_compiled = is_compiled_module(self.unet)
+        is_controlnet_compiled = is_compiled_module(self.controlnet)
+        is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
+        # hyper-parameters
+        scale_range = np.linspace(1, 0.5, len(self.scheduler.timesteps))
+
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # Relevant thread:
+                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
+                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+                    torch._inductor.cudagraph_mark_step_begin()
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+                # controlnet(s) inference
+                if guess_mode and self.do_classifier_free_guidance:
+                    # Infer ControlNet only for the conditional batch.
+                    control_model_input = latents
+                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+                    controlnet_added_cond_kwargs = {
+                        "text_embeds": add_text_embeds.chunk(2)[1],
+                        "time_ids": add_time_ids.chunk(2)[1],
+                    }
+                else:
+                    control_model_input = latent_model_input
+                    controlnet_prompt_embeds = prompt_embeds
+                    controlnet_added_cond_kwargs = added_cond_kwargs
+
+                if isinstance(controlnet_keep[i], list):
+                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+                else:
+                    controlnet_cond_scale = controlnet_conditioning_scale
+                    if isinstance(controlnet_cond_scale, list):
+                        controlnet_cond_scale = controlnet_cond_scale[0]
+                    cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+                # predict the noise residual
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
+                    added_cond_kwargs=added_cond_kwargs,
+                    return_dict=False,
+                )[0]
+
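+                # Stage-2 regional fusion: after the early denoising steps, the noise prediction inside
+                # each concept mask is replaced (weighted by replace_ratio) with the prediction of the
+                # corresponding concept model, while pixels outside every mask keep the global prediction.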
+                if i > 15 and stage == 2:
+                    region_mask = self.get_region_mask(mask_list, noise_pred.shape[2], noise_pred.shape[3])
+                    edit_noise = torch.concat([noise_pred[1:2], noise_pred[3:4]], dim=0)
+                    new_noise_pred = torch.zeros_like(edit_noise)
+                    new_noise_pred[:, :, region_mask == 0] = edit_noise[:, :, region_mask == 0]
+                    replace_ratio = 1.0
+                    new_noise_pred[:, :, region_mask != 0] = (1 - replace_ratio) * edit_noise[:, :, region_mask != 0]
+
+                    for (region_prompt_embeds, region_add_text_embeds, region_add_time_ids,
+                         concept_mask, region_prompt, region_prompt_image_emb) in zip(
+                            region_prompt_embeds_list, region_add_text_embeds_list, add_time_ids_list,
+                            mask_list, region_prompts, image_prompt_image_emb_list):
+                        if concept_mask is not None:
+                            concept_mask = F.interpolate(concept_mask.unsqueeze(0).unsqueeze(0),
+                                                         size=(noise_pred.shape[2], noise_pred.shape[3]),
+                                                         mode='nearest').squeeze().to(dtype=noise_pred.dtype, device=concept_models._execution_device)
+
+                            region_latent_model_input = latent_model_input[3:4].clone().to(concept_models._execution_device)
+
+                            region_latent_model_input = torch.cat([region_latent_model_input] * 2)
+                            region_added_cond_kwargs = {"text_embeds": region_add_text_embeds,
+                                                        "time_ids": region_add_time_ids}
+
+                            if image is not None:
+                                down_block_res_samples, mid_block_res_sample = self.controlnet(
+                                    region_latent_model_input,
+                                    t,
+                                    encoder_hidden_states=region_prompt_image_emb,
+                                    controlnet_cond=image,
+                                    conditioning_scale=cond_scale,
+                                    guess_mode=guess_mode,
+                                    added_cond_kwargs=region_added_cond_kwargs,
+                                    return_dict=False,
+                                )
+
+                                if guess_mode and self.do_classifier_free_guidance:
+                                    # Inferred ControlNet only for the conditional batch.
+                                    # To apply the output of ControlNet to both the unconditional and conditional batches,
+                                    # add 0 to the unconditional batch to keep it unchanged.
+                                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in
+                                                              down_block_res_samples]
+                                    mid_block_res_sample = torch.cat(
+                                        [torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+                            else:
+                                down_block_res_samples = None
+                                mid_block_res_sample = None
+
+                            region_encoder_hidden_states = torch.cat([region_prompt_embeds, region_prompt_image_emb], dim=1)
+
+                            region_noise_pred = concept_models.unet(
+                                region_latent_model_input,
+                                t,
+                                encoder_hidden_states=region_encoder_hidden_states,
+                                cross_attention_kwargs=None,
+                                down_block_additional_residuals=down_block_res_samples,
+                                mid_block_additional_residual=mid_block_res_sample,
+                                added_cond_kwargs=region_added_cond_kwargs,
+                                return_dict=False,
+                            )[0]
+
+                            new_noise_pred = new_noise_pred.to(concept_models._execution_device)
+                            new_noise_pred[:, :, concept_mask==1] += replace_ratio * (region_noise_pred[:, :, concept_mask==1] / (concept_mask.reshape(1, 1, *concept_mask.shape)[:, :, concept_mask==1].to(region_noise_pred.device)))
+
+                    new_noise_pred = new_noise_pred.to(noise_pred.device)
+                    noise_pred[1, :, :, :] = new_noise_pred[0]
+                    noise_pred[3, :, :, :] = new_noise_pred[1]
+
+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        step_idx = i // getattr(self.scheduler, "order", 1)
+                        callback(step_idx, t, latents)
+
+        # manually for max memory savings
+        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+            self.upcast_vae()
+            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+        if not output_type == "latent":
+            # make sure the VAE is in float32 mode, as it overflows in float16
+            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+            if needs_upcasting:
+                self.upcast_vae()
+                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+            # cast back to fp16 if needed
+            if needs_upcasting:
+                self.vae.to(dtype=torch.float16)
+        else:
+            image = latents
+
+        if not output_type == "latent":
+            # apply watermark if available
+            if self.watermark is not None:
+                image = self.watermark.apply_watermark(image)
+
+            image = self.image_processor.postprocess(image, output_type=output_type)
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (image,)
+
+        return StableDiffusionXLPipelineOutput(images=image)
+
+    def check_image(self, image, prompt, prompt_embeds):
+        pass
+
+    def get_region_mask(self, mask_list, feat_height, feat_width):
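+        # Union of all per-concept masks, resized with nearest-neighbour interpolation to the latent
+        # feature resolution; non-zero entries mark pixels claimed by some concept region.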
+        exclusive_mask = torch.zeros((feat_height, feat_width))
+        for mask in mask_list:
+            if mask is not None:
+                mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0), size=(feat_height, feat_width),
+                                     mode='nearest').squeeze().to(dtype=exclusive_mask.dtype, device=exclusive_mask.device)
+                exclusive_mask = ((mask == 1) | (exclusive_mask == 1)).to(dtype=mask.dtype)
+        return exclusive_mask
+
+def get_face_embedding(face_app, ref_images):
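+    # Extract one identity embedding per reference image with the insightface FaceAnalysis app,
+    # keeping only the largest detected face in each image.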
+    emb_list = []
+    for img_path in ref_images:
+        face_image = load_image(img_path)
+
+        # prepare face emb
+        face_info = face_app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
+        face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]  # only use the maximum face
+        face_emb = face_info['embedding']
+        emb_list.append(face_emb)
+        # face_kps = draw_kps(face_image, face_info['kps'])
+    return emb_list
\ No newline at end of file
diff --git a/src/pipelines/instantid_single_pieline.py b/src/pipelines/instantid_single_pieline.py
new file mode 100644
index 0000000000000000000000000000000000000000..133944e2008ae327a85b3a5d9e8244042a4e82e8
--- /dev/null
+++ b/src/pipelines/instantid_single_pieline.py
@@ -0,0 +1,772 @@
+# Copyright 2024 The InstantX Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import cv2
+import math
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+
+from diffusers.image_processor import PipelineImageInput
+
+from diffusers.models import ControlNetModel
+
+from diffusers.utils import (
+    deprecate,
+    logging,
+    replace_example_docstring,
+)
+from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
+from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
+
+from diffusers import StableDiffusionXLControlNetPipeline
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from diffusers.utils.import_utils import is_xformers_available
+
+from src.ip_adapter.resampler import Resampler
+from src.ip_adapter.utils import is_torch2_available
+
+if is_torch2_available():
+    from src.ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
+else:
+    from src.ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> # !pip install opencv-python transformers accelerate insightface
+        >>> import diffusers
+        >>> from diffusers.utils import load_image
+        >>> from diffusers.models import ControlNetModel
+
+        >>> import cv2
+        >>> import torch
+        >>> import numpy as np
+        >>> from PIL import Image
+
+        >>> from insightface.app import FaceAnalysis
+        >>> from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps
+
+        >>> # download 'antelopev2' under ./models
+        >>> app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+        >>> app.prepare(ctx_id=0, det_size=(640, 640))
+
+        >>> # download models under ./checkpoints
+        >>> face_adapter = f'./checkpoints/ip-adapter.bin'
+        >>> controlnet_path = f'./checkpoints/ControlNetModel'
+
+        >>> # load IdentityNet
+        >>> controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
+
+        >>> pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
+        ...     "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
+        ... )
+        >>> pipe.cuda()
+
+        >>> # load adapter
+        >>> pipe.load_ip_adapter_instantid(face_adapter)
+
+        >>> prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality"
+        >>> negative_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured"
+
+        >>> # load an image
+        >>> face_image = load_image("your-example.jpg")
+
+        >>> face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))[-1]
+        >>> face_emb = face_info['embedding']
+        >>> face_kps = draw_kps(face_image, face_info['kps'])
+
+        >>> pipe.set_ip_adapter_scale(0.8)
+
+        >>> # generate image
+        >>> image = pipe(
+        ...     prompt, image_embeds=face_emb, image=face_kps, controlnet_conditioning_scale=0.8
+        ... ).images[0]
+        ```
+"""
+
+
+def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
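+    # Render the five facial keypoints, and the limbs connecting them to the nose, onto a black canvas.
+    # The resulting image is the spatial condition consumed by InstantID's IdentityNet ControlNet.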
+    stickwidth = 4
+    limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
+    kps = np.array(kps)
+
+    w, h = image_pil.size
+    out_img = np.zeros([h, w, 3])
+
+    for i in range(len(limbSeq)):
+        index = limbSeq[i]
+        color = color_list[index[0]]
+
+        x = kps[index][:, 0]
+        y = kps[index][:, 1]
+        length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
+        angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
+        polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0,
+                                   360, 1)
+        out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
+    out_img = (out_img * 0.6).astype(np.uint8)
+
+    for idx_kp, kp in enumerate(kps):
+        color = color_list[idx_kp]
+        x, y = kp
+        out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
+
+    out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
+    return out_img_pil
+
+
+class InstantidSingleConceptPipeline(StableDiffusionXLControlNetPipeline):
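+    # SDXL ControlNet pipeline with InstantID-style identity conditioning: a Resampler projects the face
+    # embedding into image-prompt tokens, and IP-Adapter attention processors inject those tokens into
+    # the UNet's cross-attention alongside the text prompt.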
+
+    def cuda(self, dtype=torch.float16, use_xformers=False):
+        self.to('cuda', dtype)
+
+        if hasattr(self, 'image_proj_model'):
+            self.image_proj_model.to(self.unet.device).to(self.unet.dtype)
+
+        if use_xformers:
+            if is_xformers_available():
+                import xformers
+                from packaging import version
+
+                xformers_version = version.parse(xformers.__version__)
+                if xformers_version == version.parse("0.0.16"):
+                    logger.warning(
+                        "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+                    )
+                self.enable_xformers_memory_efficient_attention()
+            else:
+                raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+    def load_ip_adapter_instantid(self, model_ckpt, image_emb_dim=512, num_tokens=16, scale=0.5):
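+        # A single InstantID checkpoint stores both the image projection weights ('image_proj') and the
+        # IP-Adapter attention weights ('ip_adapter'); load them in two steps.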
+        self.set_image_proj_model(model_ckpt, image_emb_dim, num_tokens)
+        self.set_ip_adapter(model_ckpt, num_tokens, scale)
+
+    def set_image_proj_model(self, model_ckpt, image_emb_dim=512, num_tokens=16):
+
+        image_proj_model = Resampler(
+            dim=1280,
+            depth=4,
+            dim_head=64,
+            heads=20,
+            num_queries=num_tokens,
+            embedding_dim=image_emb_dim,
+            output_dim=self.unet.config.cross_attention_dim,
+            ff_mult=4,
+        )
+
+        image_proj_model.eval()
+
+        self.image_proj_model = image_proj_model.to(self.device, dtype=self.dtype)
+        state_dict = torch.load(model_ckpt, map_location="cpu")
+        if 'image_proj' in state_dict:
+            state_dict = state_dict["image_proj"]
+        self.image_proj_model.load_state_dict(state_dict)
+
+        self.image_proj_model_in_features = image_emb_dim
+
+    def set_ip_adapter(self, model_ckpt, num_tokens, scale):
+
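+        # Replace every cross-attention processor in the UNet with an IPAttnProcessor that attends to
+        # `num_tokens` extra image tokens; self-attention layers keep the plain AttnProcessor.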
+        unet = self.unet
+        attn_procs = {}
+        for name in unet.attn_processors.keys():
+            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+            if name.startswith("mid_block"):
+                hidden_size = unet.config.block_out_channels[-1]
+            elif name.startswith("up_blocks"):
+                block_id = int(name[len("up_blocks.")])
+                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+            elif name.startswith("down_blocks"):
+                block_id = int(name[len("down_blocks.")])
+                hidden_size = unet.config.block_out_channels[block_id]
+            if cross_attention_dim is None:
+                attn_procs[name] = AttnProcessor().to(unet.device, dtype=unet.dtype)
+            else:
+                attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size,
+                                                   cross_attention_dim=cross_attention_dim,
+                                                   scale=scale,
+                                                   num_tokens=num_tokens).to(unet.device, dtype=unet.dtype)
+        unet.set_attn_processor(attn_procs)
+
+        state_dict = torch.load(model_ckpt, map_location="cpu")
+        ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
+        if 'ip_adapter' in state_dict:
+            state_dict = state_dict['ip_adapter']
+        ip_layers.load_state_dict(state_dict)
+
+    def set_ip_adapter_scale(self, scale):
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        for attn_processor in unet.attn_processors.values():
+            if isinstance(attn_processor, IPAttnProcessor):
+                attn_processor.scale = scale
+
+    def _encode_prompt_image_emb(self, prompt_image_emb, device, num_images_per_prompt, dtype,
+                                 do_classifier_free_guidance):
+
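+        # Reshape the raw identity embedding into image tokens, prepend an all-zero embedding for the
+        # unconditional branch when classifier-free guidance is enabled, then project the tokens through
+        # the Resampler into the UNet's cross-attention dimension.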
+        if isinstance(prompt_image_emb, torch.Tensor):
+            prompt_image_emb = prompt_image_emb.clone().detach()
+        else:
+            prompt_image_emb = torch.tensor(prompt_image_emb)
+
+        prompt_image_emb = prompt_image_emb.to(device=device, dtype=dtype)
+        prompt_image_emb = prompt_image_emb.reshape([1, -1, self.image_proj_model_in_features])
+
+        if do_classifier_free_guidance:
+            prompt_image_emb = torch.cat([torch.zeros_like(prompt_image_emb), prompt_image_emb], dim=0)
+        else:
+            prompt_image_emb = torch.cat([prompt_image_emb], dim=0)
+
+        prompt_image_emb = self.image_proj_model(prompt_image_emb)
+
+        bs_embed, seq_len, _ = prompt_image_emb.shape
+        prompt_image_emb = prompt_image_emb.repeat(1, num_images_per_prompt, 1)
+        prompt_image_emb = prompt_image_emb.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        return prompt_image_emb
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+            self,
+            prompt: Union[str, List[str]] = None,
+            prompt_2: Optional[Union[str, List[str]]] = None,
+            image: PipelineImageInput = None,
+            height: Optional[int] = None,
+            width: Optional[int] = None,
+            num_inference_steps: int = 50,
+            guidance_scale: float = 5.0,
+            negative_prompt: Optional[Union[str, List[str]]] = None,
+            negative_prompt_2: Optional[Union[str, List[str]]] = None,
+            num_images_per_prompt: Optional[int] = 1,
+            eta: float = 0.0,
+            generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+            latents: Optional[torch.FloatTensor] = None,
+            prompt_embeds: Optional[torch.FloatTensor] = None,
+            negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+            pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+            negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+            image_embeds: Optional[torch.FloatTensor] = None,
+            output_type: Optional[str] = "pil",
+            return_dict: bool = True,
+            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+            controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+            guess_mode: bool = False,
+            control_guidance_start: Union[float, List[float]] = 0.0,
+            control_guidance_end: Union[float, List[float]] = 1.0,
+            original_size: Tuple[int, int] = None,
+            crops_coords_top_left: Tuple[int, int] = (0, 0),
+            target_size: Tuple[int, int] = None,
+            negative_original_size: Optional[Tuple[int, int]] = None,
+            negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+            negative_target_size: Optional[Tuple[int, int]] = None,
+            clip_skip: Optional[int] = None,
+            callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+            callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+
+            # IP adapter
+            ip_adapter_scale=None,
+
+            **kwargs,
+    ):
+        r"""
+        The call function to the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                used in both text-encoders.
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+                    `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+                specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
+                accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
+                and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
+                `init`, images must be passed as a list such that each element of the list can be correctly batched for
+                input to a single ControlNet.
+            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+                The height in pixels of the generated image. Anything below 512 pixels won't work well for
+                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+                and checkpoints that are not specifically fine-tuned on low resolutions.
+            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+                The width in pixels of the generated image. Anything below 512 pixels won't work well for
+                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+                and checkpoints that are not specifically fine-tuned on low resolutions.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 5.0):
+                A higher guidance scale value encourages the model to generate images closely linked to the text
+                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+            negative_prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`
+                and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+                generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor is generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+                provided, text embeddings are generated from the `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+                not provided, pooled text embeddings are generated from `prompt` input argument.
+            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
+                weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
+                argument.
+            image_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated image embeddings.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+                to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
+                the corresponding scale as a list.
+            guess_mode (`bool`, *optional*, defaults to `False`):
+                The ControlNet encoder tries to recognize the content of the input image even if you remove all
+                prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
+            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+                The percentage of total steps at which the ControlNet starts applying.
+            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+                The percentage of total steps at which the ControlNet stops applying.
+            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+                explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                For most cases, `target_size` should be set to the desired height and width of the generated image. If
+                not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+                micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+            negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
+                micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+            negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                To negatively condition the generation process based on a target image resolution. For most cases it
+                should be the same as the `target_size`. Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+                otherwise a `tuple` is returned containing the output images.
+        """
+
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
+        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+        # align format for control guidance
+        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+            control_guidance_start, control_guidance_end = (
+                mult * [control_guidance_start],
+                mult * [control_guidance_end],
+            )
+
+        # 0. set ip_adapter_scale
+        if ip_adapter_scale is not None:
+            self.set_ip_adapter_scale(ip_adapter_scale)
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            prompt_2,
+            image,
+            callback_steps,
+            negative_prompt,
+            negative_prompt_2,
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+            controlnet_conditioning_scale,
+            control_guidance_start,
+            control_guidance_end,
+            callback_on_step_end_tensor_inputs,
+        )
+
+        self._guidance_scale = guidance_scale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+
+        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+
+        global_pool_conditions = (
+            controlnet.config.global_pool_conditions
+            if isinstance(controlnet, ControlNetModel)
+            else controlnet.nets[0].config.global_pool_conditions
+        )
+        guess_mode = guess_mode or global_pool_conditions
+
+        # 3.1 Encode input prompt
+        text_encoder_lora_scale = (
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+        )
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
+            prompt,
+            prompt_2,
+            device,
+            num_images_per_prompt,
+            self.do_classifier_free_guidance,
+            negative_prompt,
+            negative_prompt_2,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
+            clip_skip=self.clip_skip,
+        )
+
+        # 3.2 Encode image prompt
+        prompt_image_emb = self._encode_prompt_image_emb(image_embeds,
+                                                         device,
+                                                         num_images_per_prompt,
+                                                         self.unet.dtype,
+                                                         self.do_classifier_free_guidance)
+
+        # 4. Prepare image
+        if isinstance(controlnet, ControlNetModel):
+            image = self.prepare_image(
+                image=image,
+                width=width,
+                height=height,
+                batch_size=batch_size * num_images_per_prompt,
+                num_images_per_prompt=num_images_per_prompt,
+                device=device,
+                dtype=controlnet.dtype,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
+                guess_mode=guess_mode,
+            )
+            height, width = image.shape[-2:]
+        elif isinstance(controlnet, MultiControlNetModel):
+            images = []
+
+            for image_ in image:
+                image_ = self.prepare_image(
+                    image=image_,
+                    width=width,
+                    height=height,
+                    batch_size=batch_size * num_images_per_prompt,
+                    num_images_per_prompt=num_images_per_prompt,
+                    device=device,
+                    dtype=controlnet.dtype,
+                    do_classifier_free_guidance=self.do_classifier_free_guidance,
+                    guess_mode=guess_mode,
+                )
+
+                images.append(image_)
+
+            image = images
+            height, width = image[0].shape[-2:]
+        else:
+            assert False
+
+        # 5. Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.scheduler.timesteps
+        self._num_timesteps = len(timesteps)
+
+        # 6. Prepare latent variables
+        num_channels_latents = self.unet.config.in_channels
+        latents = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        # 6.5 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
+        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 7.1 Create tensor stating which controlnets to keep
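+        # keeps[j] is 1.0 only while step i lies inside the [control_guidance_start, control_guidance_end]
+        # window of ControlNet j, so its conditioning can be switched on and off over the schedule.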
+        controlnet_keep = []
+        for i in range(len(timesteps)):
+            keeps = [
+                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+                for s, e in zip(control_guidance_start, control_guidance_end)
+            ]
+            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
+
+        # 7.2 Prepare added time ids & embeddings
+        if isinstance(image, list):
+            original_size = original_size or image[0].shape[-2:]
+        else:
+            original_size = original_size or image.shape[-2:]
+        target_size = target_size or (height, width)
+
+        add_text_embeds = pooled_prompt_embeds
+        if self.text_encoder_2 is None:
+            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+        else:
+            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+        add_time_ids = self._get_add_time_ids(
+            original_size,
+            crops_coords_top_left,
+            target_size,
+            dtype=prompt_embeds.dtype,
+            text_encoder_projection_dim=text_encoder_projection_dim,
+        )
+
+        if negative_original_size is not None and negative_target_size is not None:
+            negative_add_time_ids = self._get_add_time_ids(
+                negative_original_size,
+                negative_crops_coords_top_left,
+                negative_target_size,
+                dtype=prompt_embeds.dtype,
+                text_encoder_projection_dim=text_encoder_projection_dim,
+            )
+        else:
+            negative_add_time_ids = add_time_ids
+
+        if self.do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+
+        prompt_embeds = prompt_embeds.to(device)
+        add_text_embeds = add_text_embeds.to(device)
+        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+        encoder_hidden_states = torch.cat([prompt_embeds, prompt_image_emb], dim=1)
+
+        # 8. Denoising loop
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        is_unet_compiled = is_compiled_module(self.unet)
+        is_controlnet_compiled = is_compiled_module(self.controlnet)
+        is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
+
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # Relevant thread:
+                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
+                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+                    torch._inductor.cudagraph_mark_step_begin()
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+                # controlnet(s) inference
+                if guess_mode and self.do_classifier_free_guidance:
+                    # Infer ControlNet only for the conditional batch.
+                    control_model_input = latents
+                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+                    controlnet_added_cond_kwargs = {
+                        "text_embeds": add_text_embeds.chunk(2)[1],
+                        "time_ids": add_time_ids.chunk(2)[1],
+                    }
+                else:
+                    control_model_input = latent_model_input
+                    controlnet_prompt_embeds = prompt_embeds
+                    controlnet_added_cond_kwargs = added_cond_kwargs
+
+                if isinstance(controlnet_keep[i], list):
+                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+                else:
+                    controlnet_cond_scale = controlnet_conditioning_scale
+                    if isinstance(controlnet_cond_scale, list):
+                        controlnet_cond_scale = controlnet_cond_scale[0]
+                    cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+                down_block_res_samples, mid_block_res_sample = self.controlnet(
+                    control_model_input,
+                    t,
+                    encoder_hidden_states=prompt_image_emb,
+                    controlnet_cond=image,
+                    conditioning_scale=cond_scale,
+                    guess_mode=guess_mode,
+                    added_cond_kwargs=controlnet_added_cond_kwargs,
+                    return_dict=False,
+                )
+
+                if guess_mode and self.do_classifier_free_guidance:
+                    # Inferred ControlNet only for the conditional batch.
+                    # To apply the output of ControlNet to both the unconditional and conditional batches,
+                    # add 0 to the unconditional batch to keep it unchanged.
+                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+                # predict the noise residual
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=encoder_hidden_states,
+                    timestep_cond=timestep_cond,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
+                    down_block_additional_residuals=down_block_res_samples,
+                    mid_block_additional_residual=mid_block_res_sample,
+                    added_cond_kwargs=added_cond_kwargs,
+                    return_dict=False,
+                )[0]
+
+                # perform guidance
+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        step_idx = i // getattr(self.scheduler, "order", 1)
+                        callback(step_idx, t, latents)
+
+        if not output_type == "latent":
+            # make sure the VAE is in float32 mode, as it overflows in float16
+            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+            if needs_upcasting:
+                self.upcast_vae()
+                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+            # cast back to fp16 if needed
+            if needs_upcasting:
+                self.vae.to(dtype=torch.float16)
+        else:
+            image = latents
+
+        if not output_type == "latent":
+            # apply watermark if available
+            if self.watermark is not None:
+                image = self.watermark.apply_watermark(image)
+
+            image = self.image_processor.postprocess(image, output_type=output_type)
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (image,)
+
+        return StableDiffusionXLPipelineOutput(images=image)
diff --git a/src/pipelines/lora_pipeline.py b/src/pipelines/lora_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..b24b1e64e6696897f365761c6037126ffd225395
--- /dev/null
+++ b/src/pipelines/lora_pipeline.py
@@ -0,0 +1,681 @@
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+)
+
+from diffusers.utils.import_utils import is_invisible_watermark_available
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import (
+    FromSingleFileMixin,
+    IPAdapterMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    TextualInversionLoaderMixin,
+)
+from diffusers.models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
+from diffusers.models.attention_processor import (
+    AttnProcessor2_0,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+
+
+if is_invisible_watermark_available():
+    from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
+from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from diffusers import StableDiffusionXLControlNetPipeline
+from PIL import Image
+from torchvision.transforms.functional import to_tensor
+from einops import rearrange
+from torch import einsum
+import math
+from torchvision.utils import save_image
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+class RegionControlNet_AttnProcessor:
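+    # Attention processor that routes every attention map through the prompt-attention `controller`
+    # (e.g. AttentionReplace), tagging it with whether the call is cross-attention and where in the
+    # UNet ("down", "mid", "up") the layer sits.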
+    def __init__(self, attention_op=None, controller=None, place_in_unet=None):
+        self.attention_op = attention_op
+        self.controller = controller
+        self.place_in_unet = place_in_unet
+
+    def __call__(
+            self,
+            attn,
+            hidden_states: torch.FloatTensor,
+            encoder_hidden_states: Optional[torch.FloatTensor] = None,
+            attention_mask: Optional[torch.FloatTensor] = None,
+            temb: Optional[torch.FloatTensor] = None,
+            scale: float = 1.0,
+            **cross_attention_kwargs
+    ) -> torch.Tensor:
+        residual = hidden_states
+
+        args = () if USE_PEFT_BACKEND else (scale,)
+
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states, *args)
+
+        is_cross = True
+        if encoder_hidden_states is None:
+            is_cross = False
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states, *args)
+        value = attn.to_v(encoder_hidden_states, *args)
+
+        query = attn.head_to_batch_dim(query)
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        attention_probs = self.controller(attention_probs, is_cross, self.place_in_unet)
+        hidden_states = torch.bmm(attention_probs, value)
+
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states, *args)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+def revise_regionally_controlnet_forward(unet, controller):
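+    # Recursively walk the UNet's blocks and replace every `Attention` module's processor with the
+    # region-aware processor above, then record the number of registered layers on the controller.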
+    def change_forward(unet, count, place_in_unet):
+        for name, layer in unet.named_children():
+            if layer.__class__.__name__ == 'Attention':
+                layer.set_processor(RegionControlNet_AttnProcessor(controller=controller, place_in_unet=place_in_unet))
+                if 'attn2' in name:
+                    count += 1
+            else:
+                count = change_forward(layer, count, place_in_unet)
+        return count
+
+    # use this to ensure the order
+    cross_attention_idx = change_forward(unet.down_blocks, 0, "down")
+    cross_attention_idx = change_forward(unet.mid_block, cross_attention_idx, "up")
+    cross_attention_idx = change_forward(unet.up_blocks, cross_attention_idx, "mid")
+    print(f'Number of attention layer registered {cross_attention_idx}')
+    controller.num_att_layers = cross_attention_idx*2
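+
+# Usage sketch (illustrative; names such as `pipe` are assumptions, not part of this file): build a
+# prompt-attention controller, e.g. AttentionReplace from src.prompt_attention.p2p_attention, then call
+#     revise_regionally_controlnet_forward(pipe.unet, controller)
+# once before invoking the pipeline so every Attention layer runs through the region-aware processor.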
+
+class LoraMultiConceptPipeline(StableDiffusionXLControlNetPipeline):
+    # leave controlnet out on purpose because it iterates with unet
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+    _optional_components = [
+        "tokenizer",
+        "tokenizer_2",
+        "text_encoder",
+        "text_encoder_2",
+        "feature_extractor",
+        "image_encoder",
+    ]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        text_encoder_2: CLIPTextModelWithProjection,
+        tokenizer: CLIPTokenizer,
+        tokenizer_2: CLIPTokenizer,
+        unet: UNet2DConditionModel,
+        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
+        scheduler: KarrasDiffusionSchedulers,
+        force_zeros_for_empty_prompt: bool = True,
+        add_watermarker: Optional[bool] = None,
+        feature_extractor: CLIPImageProcessor = None,
+        image_encoder: CLIPVisionModelWithProjection = None
+    ):
+        if isinstance(controlnet, (list, tuple)):
+            controlnet = MultiControlNetModel(controlnet)
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            text_encoder_2=text_encoder_2,
+            tokenizer=tokenizer,
+            tokenizer_2=tokenizer_2,
+            unet=unet,
+            controlnet=controlnet,
+            scheduler=scheduler,
+            feature_extractor=feature_extractor,
+            image_encoder=image_encoder,
+        )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
+        self.control_image_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+        )
+        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+        if add_watermarker:
+            self.watermark = StableDiffusionXLWatermarker()
+        else:
+            self.watermark = None
+
+        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        image: PipelineImageInput = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 5.0,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+        guess_mode: bool = False,
+        control_guidance_start: Union[float, List[float]] = 0.0,
+        control_guidance_end: Union[float, List[float]] = 1.0,
+        original_size: Tuple[int, int] = None,
+        crops_coords_top_left: Tuple[int, int] = (0, 0),
+        target_size: Tuple[int, int] = None,
+        negative_original_size: Optional[Tuple[int, int]] = None,
+        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+        negative_target_size: Optional[Tuple[int, int]] = None,
+        clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        controller=None,
+        concept_models=None,
+        stage=None,
+        region_masks=None,
+        lora_list=None,
+        styleL=None,
+        **kwargs,
+    ):
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
+        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+        # align format for control guidance
+        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+            control_guidance_start, control_guidance_end = (
+                mult * [control_guidance_start],
+                mult * [control_guidance_end],
+            )
+
+        self._guidance_scale = guidance_scale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
+
+        # 2. Define call parameters
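+        # The batch is fixed at two: the latents are duplicated below so the first copy keeps
+        # the global prediction while the second receives the regional (per-LoRA) edits.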
+        batch_size = 2
+
+        device = self._execution_device
+
+        if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+
+        global_pool_conditions = (
+            controlnet.config.global_pool_conditions
+            if isinstance(controlnet, ControlNetModel)
+            else controlnet.nets[0].config.global_pool_conditions
+        )
+        guess_mode = guess_mode or global_pool_conditions
+
+        # 3.1 Encode input prompt
+        text_encoder_lora_scale = (
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+        )
+
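+        # `prompt` arrives as [global_prompt, [(region_prompt, region_negative_prompt), ...]];
+        # the global prompt drives the base UNet, the per-region pairs drive the concept model.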
+        global_prompt = prompt[0]
+        global_negative_prompt = negative_prompt
+        region_prompts = [pt[0] for pt in prompt[1]]
+        region_negative_prompts = [pt[1] for pt in prompt[1]]
+
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
+            global_prompt,
+            prompt_2,
+            device,
+            num_images_per_prompt,
+            self.do_classifier_free_guidance,
+            global_negative_prompt,
+            negative_prompt_2,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
+            clip_skip=self.clip_skip,
+        )
+
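+        # Encode each region prompt separately, activating that region's LoRA (plus the
+        # optional style LoRA) on the concept model and caching [negative, positive] embeddings.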
+        region_prompt_embeds_list = []
+        region_add_text_embeds_list = []
+        for lora_param, region_prompt, region_negative_prompt in zip(lora_list, region_prompts, region_negative_prompts):
+            if styleL:
+                concept_models.set_adapters([lora_param, "style"], adapter_weights=[0.7, 0.5])
+            else:
+                concept_models.set_adapters(lora_param)
+            region_prompt_embeds, region_negative_prompt_embeds, region_pooled_prompt_embeds, region_negative_pooled_prompt_embeds = concept_models.encode_prompt(
+                prompt=region_prompt, device=concept_models._execution_device, num_images_per_prompt=1, do_classifier_free_guidance=True, negative_prompt=region_negative_prompt, lora_scale=text_encoder_lora_scale
+            )
+            region_prompt_embeds_list.append(torch.concat([region_negative_prompt_embeds, region_prompt_embeds], dim=0).to(concept_models._execution_device))
+            region_add_text_embeds_list.append(torch.concat([region_negative_pooled_prompt_embeds, region_pooled_prompt_embeds], dim=0).to(concept_models._execution_device))
+
+        if stage == 2:
+            mask_list = [mask.float().to(dtype=prompt_embeds.dtype, device=device) if mask is not None else None for mask in region_masks]
+
+        # 4. Prepare image
+        if isinstance(controlnet, ControlNetModel) and image is not None:
+            image = self.prepare_image(
+                image=image,
+                width=width,
+                height=height,
+                batch_size=batch_size * num_images_per_prompt,
+                num_images_per_prompt=num_images_per_prompt,
+                device=device,
+                dtype=controlnet.dtype,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
+                guess_mode=guess_mode,
+            )
+            height, width = image.shape[-2:]
+        elif isinstance(controlnet, MultiControlNetModel) and image is not None:
+            images = []
+
+            for image_ in image:
+                image_ = self.prepare_image(
+                    image=image_,
+                    width=width,
+                    height=height,
+                    batch_size=batch_size * num_images_per_prompt,
+                    num_images_per_prompt=num_images_per_prompt,
+                    device=device,
+                    dtype=controlnet.dtype,
+                    do_classifier_free_guidance=self.do_classifier_free_guidance,
+                    guess_mode=guess_mode,
+                )
+
+                images.append(image_)
+
+            image = images
+            height, width = image[0].shape[-2:]
+        else:
+            height = height or self.unet.config.sample_size * self.vae_scale_factor
+            width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+        # 5. Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.scheduler.timesteps
+        self._num_timesteps = len(timesteps)
+
+        # 6. Prepare latent variables
+        num_channels_latents = self.unet.config.in_channels
+        latents = self.prepare_latents(
+            batch_size//2 * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        # 6.1 repeat latent
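+        # latents[0] is the global branch; latents[1] is the copy that will be regionally edited
+        # (it maps to rows 1 and 3 of the CFG-expanded noise prediction below).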
+        latents = torch.cat([latents, latents.clone()])
+
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
+        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 7.1 Create tensor stating which controlnets to keep
+        controlnet_keep = []
+        for i in range(len(timesteps)):
+            keeps = [
+                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+                for s, e in zip(control_guidance_start, control_guidance_end)
+            ]
+            controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
+
+        # 7.2 Prepare added time ids & embeddings
+        if isinstance(image, list):
+            original_size = original_size or image[0].shape[-2:]
+        else:
+            original_size = original_size or (height, width)
+        target_size = target_size or (height, width)
+
+        add_text_embeds = pooled_prompt_embeds
+        if self.text_encoder_2 is None:
+            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+        else:
+            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+        add_time_ids = self._get_add_time_ids(
+            original_size,
+            crops_coords_top_left,
+            target_size,
+            dtype=prompt_embeds.dtype,
+            text_encoder_projection_dim=text_encoder_projection_dim,
+        )
+
+        add_time_ids_list = []
+        for _ in lora_list:
+            region_add_time_ids = concept_models._get_add_time_ids(original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype, text_encoder_projection_dim=text_encoder_projection_dim)
+            add_time_ids_list.append(torch.concat([region_add_time_ids, region_add_time_ids], dim=0).to(concept_models._execution_device))
+
+        if negative_original_size is not None and negative_target_size is not None:
+            negative_add_time_ids = self._get_add_time_ids(
+                negative_original_size,
+                negative_crops_coords_top_left,
+                negative_target_size,
+                dtype=prompt_embeds.dtype,
+                text_encoder_projection_dim=text_encoder_projection_dim,
+            )
+        else:
+            negative_add_time_ids = add_time_ids
+
+        if self.do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+
+        prompt_embeds = prompt_embeds.to(device)
+        add_text_embeds = add_text_embeds.to(device)
+        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+        # 8. Denoising loop
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        is_unet_compiled = is_compiled_module(self.unet)
+        is_controlnet_compiled = is_compiled_module(self.controlnet)
+        is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
+        # hyper-parameters
+        scale_range = np.linspace(1, 0.5, len(self.scheduler.timesteps))
+
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # Relevant thread:
+                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
+                if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
+                    torch._inductor.cudagraph_mark_step_begin()
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+                # controlnet(s) inference
+                if guess_mode and self.do_classifier_free_guidance:
+                    # Infer ControlNet only for the conditional batch.
+                    control_model_input = latents
+                    control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+                    controlnet_added_cond_kwargs = {
+                        "text_embeds": add_text_embeds.chunk(2)[1],
+                        "time_ids": add_time_ids.chunk(2)[1],
+                    }
+                else:
+                    control_model_input = latent_model_input
+                    controlnet_prompt_embeds = prompt_embeds
+                    controlnet_added_cond_kwargs = added_cond_kwargs
+
+                if isinstance(controlnet_keep[i], list):
+                    cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+                else:
+                    controlnet_cond_scale = controlnet_conditioning_scale
+                    if isinstance(controlnet_cond_scale, list):
+                        controlnet_cond_scale = controlnet_cond_scale[0]
+                    cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+                if image is not None:
+                    down_block_res_samples, mid_block_res_sample = self.controlnet(
+                        control_model_input,
+                        t,
+                        encoder_hidden_states=controlnet_prompt_embeds,
+                        controlnet_cond=image,
+                        conditioning_scale=cond_scale,
+                        guess_mode=guess_mode,
+                        added_cond_kwargs=controlnet_added_cond_kwargs,
+                        return_dict=False,
+                    )
+
+                    if guess_mode and self.do_classifier_free_guidance:
+                        # Inferred ControlNet only for the conditional batch.
+                        # To apply the output of ControlNet to both the unconditional and conditional batches,
+                        # add 0 to the unconditional batch to keep it unchanged.
+                        down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+                        mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+                else:
+                    down_block_res_samples = None
+                    mid_block_res_sample = None
+
+                # predict the noise residual
+                if image is not None:
+                    noise_pred = self.unet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep_cond=timestep_cond,
+                        cross_attention_kwargs=self.cross_attention_kwargs,
+                        down_block_additional_residuals=down_block_res_samples,
+                        mid_block_additional_residual=mid_block_res_sample,
+                        added_cond_kwargs=added_cond_kwargs,
+                        return_dict=False,
+                    )[0]
+                else:
+                    noise_pred = self.unet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep_cond=timestep_cond,
+                        cross_attention_kwargs=self.cross_attention_kwargs,
+                        added_cond_kwargs=added_cond_kwargs,
+                        return_dict=False,
+                    )[0]
+
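+                # Regional fusion (stage 2, after the first ~15 steps): keep the global prediction
+                # outside the union of region masks, then re-fill each masked region with noise
+                # predicted by that concept's LoRA-loaded UNet.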
+                if i > 15 and stage == 2:
+                    region_mask = self.get_region_mask(mask_list, noise_pred.shape[2], noise_pred.shape[3])
+                    edit_noise = torch.concat([noise_pred[1:2], noise_pred[3:4]], dim=0)
+                    new_noise_pred = torch.zeros_like(edit_noise)
+                    new_noise_pred[:, :, region_mask == 0] = edit_noise[:, :, region_mask == 0]
+                    replace_ratio = 1.0
+                    new_noise_pred[:, :, region_mask != 0] = (1 - replace_ratio) * edit_noise[:, :, region_mask != 0]
+
+                    for region_prompt_embeds, region_add_text_embeds, region_add_time_ids, concept_mask, region_prompt, lora_param in zip(region_prompt_embeds_list, region_add_text_embeds_list, add_time_ids_list, mask_list, region_prompts, lora_list):
+                        if concept_mask is not None:
+                            concept_mask = F.interpolate(concept_mask.unsqueeze(0).unsqueeze(0),
+                                                         size=(noise_pred.shape[2], noise_pred.shape[3]),
+                                                         mode='nearest').squeeze().to(dtype=noise_pred.dtype, device=concept_models._execution_device)
+
+
+                            region_latent_model_input = latent_model_input[3:4].clone().to(concept_models._execution_device)
+
+                            region_latent_model_input = torch.cat([region_latent_model_input] * 2)
+                            region_added_cond_kwargs = {"text_embeds": region_add_text_embeds,
+                                                        "time_ids": region_add_time_ids}
+                            if styleL:
+                                concept_models.set_adapters([lora_param, "style"], adapter_weights=[0.7, 0.5])
+                            else:
+                                concept_models.set_adapters(lora_param)
+                            region_noise_pred = concept_models.unet(
+                                region_latent_model_input,
+                                t,
+                                encoder_hidden_states=region_prompt_embeds,
+                                cross_attention_kwargs={'scale': 0.8},
+                                added_cond_kwargs=region_added_cond_kwargs,
+                                return_dict=False,
+                            )[0]
+
+                            new_noise_pred = new_noise_pred.to(concept_models._execution_device)
+                            new_noise_pred[:, :, concept_mask == 1] += replace_ratio * (
+                                region_noise_pred[:, :, concept_mask == 1]
+                                / concept_mask.reshape(1, 1, *concept_mask.shape)[:, :, concept_mask == 1].to(region_noise_pred.device)
+                            )
+
+
+                    new_noise_pred = new_noise_pred.to(noise_pred.device)
+                    noise_pred[1, :, :, :] = new_noise_pred[0]
+                    noise_pred[3, :, :, :] = new_noise_pred[1]
+
+
+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        step_idx = i // getattr(self.scheduler, "order", 1)
+                        callback(step_idx, t, latents)
+
+        # manually for max memory savings
+        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+            self.upcast_vae()
+            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+        if stage == 2:
+            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+        if not output_type == "latent":
+            # make sure the VAE is in float32 mode, as it overflows in float16
+            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+            if needs_upcasting:
+                self.upcast_vae()
+                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+
+            # cast back to fp16 if needed
+            if needs_upcasting:
+                self.vae.to(dtype=torch.float16)
+        else:
+            image = latents
+
+        if not output_type == "latent":
+            # apply watermark if available
+            if self.watermark is not None:
+                image = self.watermark.apply_watermark(image)
+
+            image = self.image_processor.postprocess(image, output_type=output_type)
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (image,)
+
+        return StableDiffusionXLPipelineOutput(images=image)
+
+    def check_image(self, image, prompt, prompt_embeds):
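+        # Intentionally a no-op: the parent pipeline's control-image validation is skipped here.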
+        pass
+
+    def get_region_mask(self, mask_list, feat_height, feat_width):
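+        # Union of all region masks, resized to the feature resolution with nearest-neighbour
+        # interpolation; a pixel is 1 if any concept's mask covers it.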
+        exclusive_mask = torch.zeros((feat_height, feat_width))
+        for mask in mask_list:
+            if mask is not None:
+                mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0), size=(feat_height, feat_width),
+                                     mode='nearest').squeeze().to(dtype=exclusive_mask.dtype, device=exclusive_mask.device)
+                exclusive_mask = ((mask == 1) | (exclusive_mask == 1)).to(dtype=mask.dtype)
+        return exclusive_mask
diff --git a/src/prompt_attention/p2p_attention.py b/src/prompt_attention/p2p_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..afb92d2ed66d8f8b16eb9b70ace242bb47198693
--- /dev/null
+++ b/src/prompt_attention/p2p_attention.py
@@ -0,0 +1,148 @@
+from typing import Optional, Union, Tuple, List, Callable, Dict
+import torch
+import torch.nn.functional as nnf
+import numpy as np
+import abc
+import src.prompt_attention.p2p_utils as p2p_utils
+import src.prompt_attention.seq_aligner as seq_aligner
+
+
+
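+# Base controller: invoked once per attention layer per diffusion step. Subclasses decide
+# whether to store or edit the attention maps; unless low_resource is set, only the
+# conditional half of the classifier-free-guidance batch is forwarded.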
+class AttentionControl(abc.ABC):
+
+    def step_callback(self, x_t):
+        return x_t
+
+    def between_steps(self):
+        return
+
+    @property
+    def num_uncond_att_layers(self):
+        # return self.num_att_layers if self.low_resource else 0
+        return 0
+
+    @abc.abstractmethod
+    def forward(self, attn, is_cross: bool, place_in_unet: str):
+        raise NotImplementedError
+
+    def __call__(self, attn, is_cross: bool, place_in_unet: str):
+        if self.cur_att_layer >= self.num_uncond_att_layers:
+            if self.low_resource:
+                attn = self.forward(attn, is_cross, place_in_unet)
+            else:
+                h = attn.shape[0]
+                attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
+        self.cur_att_layer += 1
+        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
+            self.cur_att_layer = 0
+            self.cur_step += 1
+            self.between_steps()
+        return attn
+
+    def reset(self):
+        self.cur_step = 0
+        self.cur_att_layer = 0
+
+    def __init__(self, low_resource=False, width=None, height=None):
+        self.cur_step = 0
+        self.num_att_layers = -1
+        self.cur_att_layer = 0
+        self.low_resource = low_resource
+        self.width = width
+        self.height = height
+
+class AttentionStore(AttentionControl):
+
+    @staticmethod
+    def get_empty_store():
+        return {"down_cross": [], "mid_cross": [], "up_cross": [],
+                "down_self": [], "mid_self": [], "up_self": []}
+
+    def forward(self, attn, is_cross: bool, place_in_unet: str):
+        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
+        # Storing into step_store is disabled here; attention maps are returned unchanged.
+        return attn
+
+    def between_steps(self):
+        if self.save_global_store:
+            if len(self.attention_store) == 0:
+                self.attention_store = self.step_store
+            else:
+                for key in self.attention_store:
+                    for i in range(len(self.attention_store[key])):
+                        self.attention_store[key][i] += self.step_store[key][i]
+        else:
+            self.attention_store = self.step_store
+        self.step_store = self.get_empty_store()
+
+    def get_average_attention(self):
+        average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in
+                             self.attention_store}
+        return average_attention
+
+    def reset(self):
+        super(AttentionStore, self).reset()
+        self.step_store = self.get_empty_store()
+        self.attention_store = {}
+
+    def __init__(self, width, height, low_resource=False, save_global_store=False):
+        super(AttentionStore, self).__init__(low_resource, width, height)
+        self.step_store = self.get_empty_store()
+        self.attention_store = {}
+        self.save_global_store = save_global_store
+
+class AttentionControlEdit(AttentionStore, abc.ABC):
+    def __init__(self, prompts, num_steps: int,
+                 cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
+                 self_replace_steps: Union[float, Tuple[float, float]],
+                 local_blend=None, width=None, height=None, tokenizer=None, device=None):
+        super(AttentionControlEdit, self).__init__(width, height)
+        self.batch_size = len(prompts)
+        self.cross_replace_alpha = p2p_utils.get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps,
+                                                                            tokenizer).to(device)
+        if type(self_replace_steps) is float:
+            self_replace_steps = 0, self_replace_steps
+        self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
+        self.local_blend = local_blend
+
+    def step_callback(self, x_t):
+        print("step_callback")
+        if self.local_blend is not None:
+            x_t = self.local_blend(x_t, self.attention_store)
+        return x_t
+
+    def replace_self_attention(self, attn_base, att_replace):
+        if att_replace.shape[2] <= self.width * self.height:
+            return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
+        else:
+            return att_replace
+
+    @abc.abstractmethod
+    def replace_cross_attention(self, attn_base, att_replace):
+        raise NotImplementedError
+
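+    # Cross-attention maps are blended towards a remapped copy of the base prompt's maps
+    # using the per-word alpha schedule; self-attention maps are copied from the base prompt
+    # (at low resolutions) during the configured fraction of steps.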
+    def forward(self, attn, is_cross: bool, place_in_unet: str):
+        super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
+        if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
+            h = attn.shape[0] // (self.batch_size)
+            attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
+            attn_base, attn_replace = attn[0], attn[1:]
+            if is_cross:
+                alpha_words = self.cross_replace_alpha[self.cur_step]
+                attn_replace_new = self.replace_cross_attention(attn_base, attn_replace) * alpha_words + (
+                            1 - alpha_words) * attn_replace
+                attn[1:] = attn_replace_new
+            else:
+                attn[1:] = self.replace_self_attention(attn_base, attn_replace)
+            attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
+        return attn
+
+class AttentionReplace(AttentionControlEdit):
+    def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, width, height,
+                 local_blend = None, tokenizer=None, device=None, dtype=None):
+        super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend, width, height, tokenizer=tokenizer, device=device)
+        self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(dtype=dtype, device=device)
+
+    def replace_cross_attention(self, attn_base, att_replace):
+        return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper)
+
diff --git a/src/prompt_attention/p2p_utils.py b/src/prompt_attention/p2p_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..13a6dd4a1443378689e89a03de49981a131571b3
--- /dev/null
+++ b/src/prompt_attention/p2p_utils.py
@@ -0,0 +1,74 @@
+# Copyright 2022 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import torch
+from PIL import Image, ImageDraw, ImageFont
+import cv2
+from typing import Optional, Union, Tuple, List, Callable, Dict
+
+
+
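+# Mark, for one prompt and an optional set of word indices, which diffusion steps should
+# apply cross-attention replacement: alpha is 1 inside [start, end) and 0 outside.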
+def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int,
+                           word_inds: Optional[torch.Tensor] = None):
+    if type(bounds) is float:
+        bounds = 0, bounds
+    start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
+    if word_inds is None:
+        word_inds = torch.arange(alpha.shape[2])
+    alpha[: start, prompt_ind, word_inds] = 0
+    alpha[start: end, prompt_ind, word_inds] = 1
+    alpha[end:, prompt_ind, word_inds] = 0
+    return alpha
+
+def get_word_inds(text: str, word_place: int, tokenizer):
+    split_text = text.split(" ")
+    if type(word_place) is str:
+        word_place = [i for i, word in enumerate(split_text) if word_place == word]
+    elif type(word_place) is int:
+        word_place = [word_place]
+    out = []
+    if len(word_place) > 0:
+        words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
+        cur_len, ptr = 0, 0
+
+        for i in range(len(words_encode)):
+            cur_len += len(words_encode[i])
+            if ptr in word_place:
+                out.append(i + 1)
+            if cur_len >= len(split_text[ptr]):
+                ptr += 1
+                cur_len = 0
+    return np.array(out)
+
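+# Expand the cross_replace_steps spec (a float, a (start, end) pair, or a per-word dict)
+# into a per-step, per-prompt, per-token alpha tensor consumed by AttentionControlEdit.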
+def get_time_words_attention_alpha(prompts, num_steps,
+                                   cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
+                                   tokenizer, max_num_words=77):
+    if type(cross_replace_steps) is not dict:
+        cross_replace_steps = {"default_": cross_replace_steps}
+    if "default_" not in cross_replace_steps:
+        cross_replace_steps["default_"] = (0., 1.)
+    alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
+    for i in range(len(prompts) - 1):
+        alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"],
+                                                  i)
+    for key, item in cross_replace_steps.items():
+        if key != "default_":
+            inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
+            for i, ind in enumerate(inds):
+                if len(ind) > 0:
+                    alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
+    alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
+    return alpha_time_words
+
diff --git a/src/prompt_attention/seq_aligner.py b/src/prompt_attention/seq_aligner.py
new file mode 100644
index 0000000000000000000000000000000000000000..4530d05595a25f395693ef814e49a49b1581ce12
--- /dev/null
+++ b/src/prompt_attention/seq_aligner.py
@@ -0,0 +1,66 @@
+import torch
+import numpy as np
+
+
+def get_word_inds(text: str, word_place: int, tokenizer):
+    split_text = text.split(" ")
+    if type(word_place) is str:
+        word_place = [i for i, word in enumerate(split_text) if word_place == word]
+    elif type(word_place) is int:
+        word_place = [word_place]
+    out = []
+    if len(word_place) > 0:
+        words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
+        cur_len, ptr = 0, 0
+
+        for i in range(len(words_encode)):
+            cur_len += len(words_encode[i])
+            if ptr in word_place:
+                out.append(i + 1)
+            if cur_len >= len(split_text[ptr]):
+                ptr += 1
+                cur_len = 0
+    return np.array(out)
+
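+# Build a (max_len x max_len) soft alignment matrix between two prompts with the same
+# word count: unchanged words map one-to-one, replaced words spread their weight evenly
+# over the replacement's tokens.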
+def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):
+    words_x = x.split(' ')
+    words_y = y.split(' ')
+    if len(words_x) != len(words_y):
+        raise ValueError(f"attention replacement edit can only be applied on prompts with the same length"
+                         f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.")
+    inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]]
+    inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace]
+    inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace]
+    mapper = np.zeros((max_len, max_len))
+    i = j = 0
+    cur_inds = 0
+    while i < max_len and j < max_len:
+        if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i:
+            inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds]
+            if len(inds_source_) == len(inds_target_):
+                mapper[inds_source_, inds_target_] = 1
+            else:
+                ratio = 1 / len(inds_target_)
+                for i_t in inds_target_:
+                    mapper[inds_source_, i_t] = ratio
+            cur_inds += 1
+            i += len(inds_source_)
+            j += len(inds_target_)
+        elif cur_inds < len(inds_source):
+            mapper[i, j] = 1
+            i += 1
+            j += 1
+        else:
+            mapper[j, j] = 1
+            i += 1
+            j += 1
+
+    return torch.from_numpy(mapper).float()
+
+def get_replacement_mapper(prompts, tokenizer, max_len=77):
+    x_seq = prompts[0]
+    mappers = []
+    for i in range(1, len(prompts)):
+        mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len)
+        mappers.append(mapper)
+    return torch.stack(mappers)
\ No newline at end of file