diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..03ff76df5665b3fa05b3be5a1699b1e0dd298d41
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,10 @@
+__pycache__
+*.pyc
+*.egg-info
+dist
+
+output
+output_dir
+*.pth
+*.log
+weights
\ No newline at end of file
diff --git a/README.md b/README.md
index e0405e1925d27b053e6581abc975e9c885d3335e..2c5b4b0ef87dfe708a2d0290369750f47f2c50be 100644
--- a/README.md
+++ b/README.md
@@ -1,15 +1,100 @@
 ---
-title: LLaMA Adapter V2
+title: OneLLM
 emoji: 🚀
 colorFrom: red
 colorTo: indigo
 sdk: gradio
-sdk_version: 3.23.0
+sdk_version: 4.7.1
 app_file: app.py
 pinned: false
 ---
 
-### LLaMA-Adapter
-The official demo for LLaMA-Adapter V2.
-Please refer to our [arXiv paper](https://arxiv.org/abs/2303.16199) and [github](https://github.com/ZrrSkywalker/LLaMA-Adapter) for more details.
+# OneLLM: One Framework to Align All Modalities with Language
 
+[[Project Page](https://onellm.csuhan.com)] [[Paper](#)] [[Web Demo](https://huggingface.co/spaces/csuhan/OneLLM)]
+
+Authors: [Jiaming Han](), [Kaixiong Gong](), [Yiyuan Zhang](), [Jiaqi Wang](), [Kaipeng Zhang](), [Dahua Lin](), [Yu Qiao](), [Peng Gao](), [Xiangyu Yue]().
+
+## News
+
+- **2023.12.01** Release model weights and inference code.
+
+## Contents
+
+- [Install](#install)
+- [Models](#models)
+- [Demo](#demo)
+
+<!-- - [Evaluation](#evaluation) -->
+
+<!-- - [Training](#training) -->
+
+### TODO
+
+- [ ] Data
+- [ ] Evaluation
+- [ ] Training
+
+### Install
+
+1. Clone the repo into a local folder.
+
+```bash
+git clone https://github.com/csuhan/OneLLM
+
+cd OneLLM
+```
+
+2. Install packages.
+
+```bash
+conda create -n onellm python=3.9 -y
+conda activate onellm
+
+pip install -r requirements.txt
+
+# install pointnet
+cd lib/pointnet2
+python setup.py install
+```
+
+3. Install Apex. (Optional)
+
+```bash
+git clone https://github.com/NVIDIA/apex
+cd apex
+pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
+```
+
+### Models
+
+We provide a preview model at: [csuhan/OneLLM-7B](https://huggingface.co/csuhan/OneLLM-7B).
+
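+To fetch the checkpoint programmatically, here is a minimal sketch with `huggingface_hub` (it mirrors the download call in `app.py`; the `local_dir="weights"` target is only an example):
+
+```python
+from huggingface_hub import hf_hub_download
+
+# download the preview checkpoint and return its local path
+ckpt_path = hf_hub_download(
+    repo_id="csuhan/OneLLM-7B",
+    filename="consolidated.00-of-01.pth",
+    local_dir="weights",  # example location; any writable directory works
+)
+print(ckpt_path)
+```
+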
+### Demo
+
+**Hugging Face Demo:** [csuhan/OneLLM](https://huggingface.co/spaces/csuhan/OneLLM).
+
+**Local Demo:** Assuming you have downloaded the weights to `${WEIGHTS_DIR}`, run the following command to start a Gradio demo locally.
+
+```bash
+python demos/multi_turn_mm.py --gpu_ids 0 --tokenizer_path config/llama2/tokenizer.model --llama_config config/llama2/7B.json --pretrained_path ${WEIGHTS_DIR}/consolidated.00-of-01.pth
+```
+
+<!-- ### Evaluation -->
+
+<!-- ### Training -->
+
+## Citation
+
+```
+@article{han2023onellm,
+  title={OneLLM: One Framework to Align All Modalities with Language},
+  author={Han, Jiaming and Gong, Kaixiong and Zhang, Yiyuan and Wang, Jiaqi and Zhang, Kaipeng and Lin, Dahua and Qiao, Yu and Gao, Peng and Yue, Xiangyu},
+  journal={arXiv preprint arXiv:xxxx},
+  year={2023}
+}
+```
+
+## Acknowledgement
+
+[LLaMA](https://github.com/facebookresearch/llama), [LLaMA-Adapter](https://github.com/OpenGVLab/LLaMA-Adapter), [LLaMA2-Accessory](https://github.com/Alpha-VLLM/LLaMA2-Accessory), [Meta-Transformer](https://github.com/invictus717/MetaTransformer), [ChatBridge](https://github.com/joez17/ChatBridge)
diff --git a/app.py b/app.py
index 26c88289cdf2fe061461952331735abe6fa46172..a180cda697755716b352d8fa6456204db8aac801 100644
--- a/app.py
+++ b/app.py
@@ -1,277 +1,272 @@
-import json
-import os
-import glob
 import sys
-import time
-from pathlib import Path
-from typing import Tuple
+import os
+
+import argparse
+import multiprocessing as mp
+import numpy as np
+from typing import List, Optional
 
-from huggingface_hub import hf_hub_download
-from PIL import Image
-import gradio as gr
 import torch
-from fairscale.nn.model_parallel.initialize import initialize_model_parallel
-
-from llama import LLaMA, ModelArgs, Tokenizer, Transformer, VisionModel
-
-os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
-
-PROMPT_DICT = {
-    "prompt_input": (
-        "Below is an instruction that describes a task, paired with an input that provides further context. "
-        "Write a response that appropriately completes the request.\n\n"
-        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
-    ),
-    "prompt_no_input": (
-        "Below is an instruction that describes a task. "
-        "Write a response that appropriately completes the request.\n\n"
-        "### Instruction:\n{instruction}\n\n### Response:"
-    ),
-}
-
-
-def setup_model_parallel() -> Tuple[int, int]:
-    os.environ['RANK'] = '0'
-    os.environ['WORLD_SIZE'] = '1'
-    os.environ['MP'] = '1'
-    os.environ['MASTER_ADDR'] = '127.0.0.1'
-    os.environ['MASTER_PORT'] = '2223'
-    local_rank = int(os.environ.get("LOCAL_RANK", -1))
-    world_size = int(os.environ.get("WORLD_SIZE", -1))
-
-    torch.distributed.init_process_group("nccl")
-    initialize_model_parallel(world_size)
-    torch.cuda.set_device(local_rank)
-
-    # seed must be the same in all processes
-    torch.manual_seed(1)
-    return local_rank, world_size
-
-
-def load(
-    ckpt0_path: str,
-    ckpt1_path: str,
-    param_path: str,
-    tokenizer_path: str,
-    instruct_adapter_path: str,
-    caption_adapter_path: str,
-    local_rank: int,
-    world_size: int,
-    max_seq_len: int,
-    max_batch_size: int,
-) -> LLaMA:
-    start_time = time.time()
-    print("Loading")
-    instruct_adapter_checkpoint = torch.load(
-        instruct_adapter_path, map_location="cpu")
-    caption_adapter_checkpoint = torch.load(
-        caption_adapter_path, map_location="cpu")
-    with open(param_path, "r") as f:
-        params = json.loads(f.read())
-
-    model_args: ModelArgs = ModelArgs(
-        max_seq_len=max_seq_len, max_batch_size=max_batch_size, **params
-    )
-    model_args.adapter_layer = int(
-        instruct_adapter_checkpoint['adapter_query.weight'].shape[0] / model_args.adapter_len)
-    model_args.cap_adapter_layer = int(
-        caption_adapter_checkpoint['cap_adapter_query.weight'].shape[0] / model_args.cap_adapter_len)
-
-    tokenizer = Tokenizer(model_path=tokenizer_path)
-    model_args.vocab_size = tokenizer.n_words
-    torch.set_default_tensor_type(torch.cuda.HalfTensor)
-    model = Transformer(model_args)
-
-    # To reduce memory usuage
-    ckpt0 = torch.load(ckpt0_path, map_location='cuda')
-    model.load_state_dict(ckpt0, strict=False)
-    del ckpt0
-    torch.cuda.empty_cache()
-
-    ckpt1 = torch.load(ckpt1_path, map_location='cuda')
-    model.load_state_dict(ckpt1, strict=False)
-    del ckpt1
-    torch.cuda.empty_cache()
-
-    vision_model = VisionModel(model_args)
-
-    torch.set_default_tensor_type(torch.FloatTensor)
-    model.load_state_dict(instruct_adapter_checkpoint, strict=False)
-    model.load_state_dict(caption_adapter_checkpoint, strict=False)
-    vision_model.load_state_dict(caption_adapter_checkpoint, strict=False)
-
-    generator = LLaMA(model, tokenizer, vision_model)
-    print(f"Loaded in {time.time() - start_time:.2f} seconds")
-    return generator
-
-
-def instruct_generate(
-    instruct: str,
-    input: str = 'none',
-    max_gen_len=512,
-    temperature: float = 0.1,
-    top_p: float = 0.75,
-):
-    if input == 'none':
-        prompt = PROMPT_DICT['prompt_no_input'].format_map(
-            {'instruction': instruct, 'input': ''})
-    else:
-        prompt = PROMPT_DICT['prompt_input'].format_map(
-            {'instruction': instruct, 'input': input})
-
-    results = generator.generate(
-        [prompt], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p
-    )
-    result = results[0].strip()
-    print(result)
-    return result
-
-
-def caption_generate(
-    img: str,
-    max_gen_len=512,
-    temperature: float = 0.1,
-    top_p: float = 0.75,
-):
-    imgs = [Image.open(img).convert('RGB')]
-    prompts = ["Generate caption of this image :",] * len(imgs)
-
-    results = generator.generate(
-        prompts, imgs=imgs, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p
-    )
-    result = results[0].strip()
-    print(result)
-    return result
-
-
-def download_llama_adapter(instruct_adapter_path, caption_adapter_path):
-    if not os.path.exists(instruct_adapter_path):
-        os.system(
-            f"wget -q -O {instruct_adapter_path} https://github.com/ZrrSkywalker/LLaMA-Adapter/releases/download/v.1.0.0/llama_adapter_len10_layer30_release.pth")
-
-    if not os.path.exists(caption_adapter_path):
-        os.system(
-            f"wget -q -O {caption_adapter_path} https://github.com/ZrrSkywalker/LLaMA-Adapter/releases/download/v.1.0.0/llama_adapter_len10_layer30_caption_vit_l.pth")
-
-
-# ckpt_path = "/data1/llma/7B/consolidated.00.pth"
-# param_path = "/data1/llma/7B/params.json"
-# tokenizer_path = "/data1/llma/tokenizer.model"
-ckpt0_path = hf_hub_download(
-    repo_id="csuhan/llama_storage", filename="consolidated.00_part0.pth")
-ckpt1_path = hf_hub_download(
-    repo_id="csuhan/llama_storage", filename="consolidated.00_part1.pth")
-param_path = hf_hub_download(
-    repo_id="nyanko7/LLaMA-7B", filename="params.json")
-tokenizer_path = hf_hub_download(
-    repo_id="nyanko7/LLaMA-7B", filename="tokenizer.model")
-instruct_adapter_path = "llama_adapter_len10_layer30_release.pth"
-caption_adapter_path = "llama_adapter_len10_layer30_caption_vit_l.pth"
-max_seq_len = 512
-max_batch_size = 1
-
-# download models
-# download_llama_adapter(instruct_adapter_path, caption_adapter_path)
-
-local_rank, world_size = setup_model_parallel()
-if local_rank > 0:
-    sys.stdout = open(os.devnull, "w")
-
-generator = load(
-    ckpt0_path, ckpt1_path, param_path, tokenizer_path, instruct_adapter_path, caption_adapter_path, local_rank, world_size, max_seq_len, max_batch_size
-)
-
-
-def create_instruct_demo():
-    with gr.Blocks() as instruct_demo:
-        with gr.Row():
-            with gr.Column():
-                instruction = gr.Textbox(lines=2, label="Instruction")
-                input = gr.Textbox(
-                    lines=2, label="Context input", placeholder='none')
-                max_len = gr.Slider(minimum=1, maximum=512,
-                                    value=128, label="Max length")
-                with gr.Accordion(label='Advanced options', open=False):
-                    temp = gr.Slider(minimum=0, maximum=1,
-                                     value=0.1, label="Temperature")
-                    top_p = gr.Slider(minimum=0, maximum=1,
-                                      value=0.75, label="Top p")
-
-                run_botton = gr.Button("Run")
-
-            with gr.Column():
-                outputs = gr.Textbox(lines=10, label="Output")
-
-        inputs = [instruction, input, max_len, temp, top_p]
-
-        examples = [
-            "Tell me about alpacas.",
-            "Write a Python program that prints the first 10 Fibonacci numbers.",
-            "Write a conversation between the sun and pluto.",
-            "Write a theory to explain why cat never existed",
-        ]
-        examples = [
-            [x, "none", 128, 0.1, 0.75]
-            for x in examples]
-
-        gr.Examples(
-            examples=examples,
-            inputs=inputs,
-            outputs=outputs,
-            fn=instruct_generate,
-            cache_examples=os.getenv('SYSTEM') == 'spaces'
-        )
-        run_botton.click(fn=instruct_generate, inputs=inputs, outputs=outputs)
-    return instruct_demo
+import torch.distributed as dist
 
+from fairscale.nn.model_parallel import initialize as fs_init
 
-def create_caption_demo():
-    with gr.Blocks() as instruct_demo:
-        with gr.Row():
-            with gr.Column():
-                img = gr.Image(label='Input', type='filepath')
-                max_len = gr.Slider(minimum=1, maximum=512,
-                                    value=64, label="Max length")
-                with gr.Accordion(label='Advanced options', open=False):
-                    temp = gr.Slider(minimum=0, maximum=1,
-                                     value=0.1, label="Temperature")
-                    top_p = gr.Slider(minimum=0, maximum=1,
-                                      value=0.75, label="Top p")
-
-                run_botton = gr.Button("Run")
-
-            with gr.Column():
-                outputs = gr.Textbox(lines=10, label="Output")
-
-        inputs = [img, max_len, temp, top_p]
-
-        examples = glob.glob("caption_demo/*.jpg")
-        examples = [
-            [x, 64, 0.1, 0.75]
-            for x in examples]
-
-        gr.Examples(
-            examples=examples,
-            inputs=inputs,
-            outputs=outputs,
-            fn=caption_generate,
-            cache_examples=os.getenv('SYSTEM') == 'spaces'
-        )
-        run_botton.click(fn=caption_generate, inputs=inputs, outputs=outputs)
-    return instruct_demo
+import gradio as gr
+from util.misc import setup_for_distributed
+from util.misc import default_tensor_type
+from model.meta import MetaModel
+from data.conversation_lib import conv_templates, SeparatorStyle
+from PIL import Image
+import torchvision.transforms as transforms
+from data.fintune_dataset import make_audio_features
+from data import video_utils
+from dataclasses import dataclass
+from huggingface_hub import hf_hub_download
 
+T_random_resized_crop = transforms.Compose([
+    transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), interpolation=3,
+                                 antialias=None),  # 3 is bicubic
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
+
+
+def load_audio(audio_path):
+    fbank = make_audio_features(audio_path, mel_bins=128)
+    fbank = fbank.transpose(0, 1)[None]  # [1, 128, 1024]
+    return fbank
+    
+def load_video(video_path):
+    video_feats = video_utils.load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5)
+    return video_feats[:, :, 0]
+
+
+def model_worker(
+    rank: int, args: argparse.Namespace, barrier: mp.Barrier,
+    request_queue: mp.Queue, response_queue: Optional[mp.Queue] = None,
+) -> None:
+    """
+    The worker function that drives one GPU to run inference.
+    Exactly n_gpu workers are started, each operating on a separate GPU.
+
+    Args:
+        rank (int): Distributed rank of the worker.
+        args (argparse.Namespace): All command line arguments.
+        barrier (multiprocessing.Barrier): A barrier used to delay the start
+            of the Web UI until the model has started.
+        request_queue (mp.Queue): The queue this worker reads requests from.
+        response_queue (Optional[mp.Queue]): The queue the rank-0 worker
+            streams responses into; None for all other ranks.
+    """
+
+    world_size = len(args.gpu_ids)
+    gpu_id = args.gpu_ids[rank]
+    dist.init_process_group(
+        backend="nccl", rank=rank, world_size=world_size,
+        init_method=f"tcp://{args.master_addr}:{args.master_port}",
+    )
+    print(f"| distributed init on worker {rank}/{world_size}. "
+          f"using gpu: {gpu_id}")
+    fs_init.initialize_model_parallel(world_size)
+    torch.cuda.set_device(gpu_id)
 
-description = """
-# LLaMA-Adapter🚀
-The official demo for **LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention**.
-Please refer to our [arXiv paper](https://arxiv.org/abs/2303.16199) and [github](https://github.com/ZrrSkywalker/LLaMA-Adapter) for more details.
-"""
+    torch.manual_seed(1)
+    np.random.seed(1)
+
+    # set the print behavior.
+    setup_for_distributed(rank == 0)
+
+    target_dtype = {
+        "bf16": torch.bfloat16,
+        "fp16": torch.float16
+    }[args.dtype]
+    with default_tensor_type(dtype=target_dtype, device="cuda"):
+        model = MetaModel(args.llama_type, args.llama_config, tokenizer_path=args.tokenizer_path)
+    print("Loading pretrained weights ...")
+    checkpoint = torch.load(args.pretrained_path, map_location='cpu')
+    msg = model.load_state_dict(checkpoint, strict=False)
+    print("load result:\n", msg)
+    model.cuda()
+    model.eval()
+    print(f"Model = {str(model)}")
+
+    barrier.wait()
+
+    while True:
+        img_path, audio_path, video_path, chatbot, max_gen_len, temperature, top_p, modality = request_queue.get()
+        if 'image' in modality and img_path is not None:
+            image = Image.open(img_path).convert('RGB')
+            inputs = T_random_resized_crop(image)
+        elif 'video' in modality and video_path is not None:
+            inputs = load_video(video_path)
+        elif 'audio' in modality and audio_path is not None:
+            inputs = load_audio(audio_path)
+        else:
+            inputs = None
+        
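+        # if a modality input was provided, add a batch dim and move it to GPU in the target dtype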
+        if inputs is not None:
+            inputs = inputs[None].cuda().to(target_dtype)
+    
+        conv = conv_templates["v1"].copy()
+        for user, bot in chatbot:
+            conv.append_message(conv.roles[0], user)
+            conv.append_message(conv.roles[1], bot)
+
+        with torch.cuda.amp.autocast(dtype=target_dtype):
+            print(conv.get_prompt())
+            for stream_response in model.stream_generate(
+                conv.get_prompt(), inputs,
+                max_gen_len=max_gen_len, temperature=temperature, top_p=top_p,
+                modal=modality
+            ):
+                conv_sep = (
+                    conv.sep
+                    if conv.sep_style == SeparatorStyle.SINGLE
+                    else conv.sep2
+                )
+                end_pos = stream_response["text"].find(conv_sep)
+                if end_pos != -1:
+                    stream_response["text"] = (
+                        stream_response['text'][:end_pos].rstrip() + "\n"
+                    )
+                    stream_response["end_of_content"] = True
+
+                # keep a few characters if not end_of_content to avoid sending
+                # part of conv_sep before all of it is generated.
+                if not stream_response["end_of_content"]:
+                    if len(stream_response["text"]) < len(conv_sep):
+                        continue
+                    stream_response["text"] = (
+                        stream_response["text"][:-len(conv_sep)]
+                    )
+
+                if response_queue is not None:
+                    response_queue.put(stream_response)
+
+                if stream_response["end_of_content"]:
+                    break
+
+
+def gradio_worker(
+    request_queues: List[mp.Queue], response_queue: mp.Queue,
+    args: argparse.Namespace, barrier: mp.Barrier,
+) -> None:
+    """
+    The Gradio worker is responsible for displaying the Web UI and relaying
+    requests to the model workers. It should be launched only once.
+
+    Args:
+        request_queues (List[mp.Queue]): A list of request queues (one for
+            each model worker).
+        response_queue (mp.Queue): The queue that the rank-0 model worker
+            streams its responses into.
+        args (argparse.Namespace): All command line arguments.
+        barrier (multiprocessing.Barrier): A barrier used to delay the start
+            of the Web UI until the model has started.
+    """
+
+    def show_user_input(msg, chatbot):
+        return "", chatbot + [[msg, None]]
+
+    def stream_model_output(img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality):
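+        # broadcast the request to every model worker; only the rank-0 worker feeds response_queue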
+        for queue in request_queues:
+            queue.put((img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality))
+        while True:
+            content_piece = response_queue.get()
+            chatbot[-1][1] = content_piece["text"]
+            yield chatbot
+            if content_piece["end_of_content"]:
+                break
+
+    def undo(chatbot):
+        if len(chatbot) > 0:
+            chatbot = chatbot[:-1]
+        return chatbot
+
+    def clear():
+        chatbot = []
+        msg = ""
+        return chatbot, msg
+
+    CSS ="""
+    .contain { display: flex; flex-direction: column; }
+    #component-0 { height: 100%; }
+    #chatbot { flex-grow: 1; overflow: auto;}
+    """
+    with gr.Blocks(css=CSS) as demo:
+        gr.Markdown("## OneLLM: One Framework to Align All Modalities with Language")
+        with gr.Row(equal_height=True):
+            with gr.Column(scale=1):
+                img_path = gr.Image(label='Image Input', type='filepath')
+                video_path = gr.Video(label='Video Input')
+                audio_path = gr.Audio(label='Audio Input', type='filepath', sources=['upload'])
+                modality = gr.Radio(choices=['image', 'audio', 'video'], value='image', interactive=True, label='Input Modalities')
+
+            with gr.Column(scale=2):
+                chatbot = gr.Chatbot(elem_id="chatbot")
+                msg = gr.Textbox()
 
-with gr.Blocks(css='style.css') as demo:
-    gr.Markdown(description)
-    with gr.TabItem("Instruction-Following"):
-        create_instruct_demo()
-    with gr.TabItem("Image Captioning"):
-        create_caption_demo()
+        with gr.Row():
+            submit_button = gr.Button("Submit", variant="primary")
+            undo_button = gr.Button("Undo")
+            clear_button = gr.ClearButton([chatbot, msg, img_path, audio_path, video_path, modality])
+        with gr.Row():
+            max_gen_len = gr.Slider(
+                minimum=1, maximum=args.model_max_seq_len // 2,
+                value=args.model_max_seq_len // 2, interactive=True,
+                label="Single-turn max response length",
+            )
+            gen_t = gr.Slider(
+                minimum=0, maximum=1, value=0.1, interactive=True,
+                label="Temperature",
+            )
+            top_p = gr.Slider(
+                minimum=0, maximum=1, value=0.75, interactive=True,
+                label="Top-p",
+            )
+        msg.submit(
+            show_user_input, [msg, chatbot], [msg, chatbot],
+        ).then(
+            stream_model_output, [img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot,
+        )
+        submit_button.click(
+            show_user_input, [msg, chatbot], [msg, chatbot],
+        ).then(
+            stream_model_output, [img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot,
+        )
+        undo_button.click(undo, chatbot, chatbot)
+        # img_path.change(clear, [], [chatbot, msg])
+    barrier.wait()
+    demo.queue(api_open=True).launch(share=True, max_threads=1)
+
+
+@dataclass
+class DemoConfig:
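+    # configuration used in place of command-line arguments when launching the Space demo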
+    gpu_ids = [0]
+    tokenizer_path = "config/llama2/tokenizer.model"
+    llama_type = "onellm"
+    llama_config = "config/llama2/7B.json"
+    model_max_seq_len = 2048
+    # pretrained_path = "weights/7B_2048/consolidated.00-of-01.pth"
+    pretrained_path = hf_hub_download(repo_id="csuhan/OneLLM-7B", filename="consolidated.00-of-01.pth")
+    master_port = 23861
+    master_addr = "127.0.0.1"
+    dtype = "fp16"
+
+if __name__ == "__main__":
+    args = DemoConfig()
+    # using the default "fork" method messes up some imported libs (e.g.,
+    # pandas)
+    mp.set_start_method("spawn")
+
+    # setup the queues and start the model workers
+    request_queues = []
+    response_queue = mp.Queue()
+    worker_processes = []
+    barrier = mp.Barrier(len(args.gpu_ids) + 1)
+    for rank, gpu_id in enumerate(args.gpu_ids):
+        request_queue = mp.Queue()
+        rank_response_queue = response_queue if rank == 0 else None
+        process = mp.Process(
+            target=model_worker,
+            args=(rank, args, barrier, request_queue, rank_response_queue),
+        )
+        process.start()
+        worker_processes.append(process)
+        request_queues.append(request_queue)
 
-demo.queue(api_open=True, concurrency_count=1).launch()
+    gradio_worker(request_queues, response_queue, args, barrier)
diff --git a/config/llama2/7B.json b/config/llama2/7B.json
new file mode 100644
index 0000000000000000000000000000000000000000..6523f76675b50e9cf3a57d1fb135189abcffe1c7
--- /dev/null
+++ b/config/llama2/7B.json
@@ -0,0 +1 @@
+{"dim": 4096, "multiple_of": 256, "n_heads": 32, "n_layers": 32, "norm_eps": 1e-05, "vocab_size": -1}
diff --git a/config/llama2/tokenizer.model b/config/llama2/tokenizer.model
new file mode 100644
index 0000000000000000000000000000000000000000..6c00c742ce03c627d6cd5b795984876fa49fa899
--- /dev/null
+++ b/config/llama2/tokenizer.model
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
+size 499723
diff --git a/data/__pycache__/conversation_lib.cpython-310.pyc b/data/__pycache__/conversation_lib.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7104daf11059185efd723e40160c7debf191003b
Binary files /dev/null and b/data/__pycache__/conversation_lib.cpython-310.pyc differ
diff --git a/data/__pycache__/conversation_lib.cpython-39.pyc b/data/__pycache__/conversation_lib.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fdca3a64c8523a2b4439b3e0c894c64b2020f486
Binary files /dev/null and b/data/__pycache__/conversation_lib.cpython-39.pyc differ
diff --git a/data/__pycache__/fintune_dataset.cpython-310.pyc b/data/__pycache__/fintune_dataset.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2e4e08c37f7d6e25b19727d2847d34fc1fdd0c8e
Binary files /dev/null and b/data/__pycache__/fintune_dataset.cpython-310.pyc differ
diff --git a/data/__pycache__/fintune_dataset.cpython-39.pyc b/data/__pycache__/fintune_dataset.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..45989953b86dec0a2084a59127bc1028067b8640
Binary files /dev/null and b/data/__pycache__/fintune_dataset.cpython-39.pyc differ
diff --git a/data/__pycache__/imu_utils.cpython-310.pyc b/data/__pycache__/imu_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cae8e9e22e039ecc388eb3193589fa83b5c3847b
Binary files /dev/null and b/data/__pycache__/imu_utils.cpython-310.pyc differ
diff --git a/data/__pycache__/imu_utils.cpython-39.pyc b/data/__pycache__/imu_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..46ddf01f39b937cad76df5c6201b4f62b441694d
Binary files /dev/null and b/data/__pycache__/imu_utils.cpython-39.pyc differ
diff --git a/data/__pycache__/video_utils.cpython-310.pyc b/data/__pycache__/video_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4c007e0ea4f4c05fe867e0bc31077d4c6bc0fe79
Binary files /dev/null and b/data/__pycache__/video_utils.cpython-310.pyc differ
diff --git a/data/__pycache__/video_utils.cpython-39.pyc b/data/__pycache__/video_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..81b5c7d05231e85a9f4552385921740940514e39
Binary files /dev/null and b/data/__pycache__/video_utils.cpython-39.pyc differ
diff --git a/data/conversation_lib.py b/data/conversation_lib.py
new file mode 100644
index 0000000000000000000000000000000000000000..783fe0eb8f9dd425ec6c285e820f755d2e955a3b
--- /dev/null
+++ b/data/conversation_lib.py
@@ -0,0 +1,369 @@
+import dataclasses
+from enum import auto, Enum
+from typing import List, Tuple
+
+
+class SeparatorStyle(Enum):
+    """Different separator style."""
+    SINGLE = auto()
+    TWO = auto()
+    MPT = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+    """A class that keeps all conversation history."""
+    system: str
+    roles: List[str]
+    messages: List[List[str]]
+    offset: int
+    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+    sep: str = "###"
+    sep2: str = None
+    version: str = "Unknown"
+
+    skip_next: bool = False
+
+    def get_prompt(self):
+        if self.sep_style == SeparatorStyle.SINGLE:
+            ret = self.system + '\n\n' + self.sep
+            for role, message in self.messages:
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + ": " + message + '\n' + self.sep
+                else:
+                    ret += role + ":"
+            return ret
+        elif self.sep_style == SeparatorStyle.TWO:
+            seps = [self.sep, self.sep2]
+            ret = self.system + seps[0]
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + ": " + message + seps[i % 2]
+                else:
+                    ret += role + ":"
+            return ret
+        if self.sep_style == SeparatorStyle.MPT:
+            ret = self.system + self.sep
+            for role, message in self.messages:
+                if message:
+                    if type(message) is tuple:
+                        message, _, _ = message
+                    ret += role + message + self.sep
+                else:
+                    ret += role
+            return ret
+        else:
+            raise ValueError(f"Invalid style: {self.sep_style}")
+
+    def append_message(self, role, message):
+        self.messages.append([role, message])
+
+    def get_images(self, return_pil=False):
+        images = []
+        for i, (role, msg) in enumerate(self.messages[self.offset:]):
+            if i % 2 == 0:
+                if type(msg) is tuple:
+                    import base64
+                    from io import BytesIO
+                    from PIL import Image
+                    msg, image, image_process_mode = msg
+                    if image_process_mode == "Pad":
+                        def expand2square(pil_img, background_color=(122, 116, 104)):
+                            width, height = pil_img.size
+                            if width == height:
+                                return pil_img
+                            elif width > height:
+                                result = Image.new(pil_img.mode, (width, width), background_color)
+                                result.paste(pil_img, (0, (width - height) // 2))
+                                return result
+                            else:
+                                result = Image.new(pil_img.mode, (height, height), background_color)
+                                result.paste(pil_img, ((height - width) // 2, 0))
+                                return result
+
+                        image = expand2square(image)
+                    elif image_process_mode == "Crop":
+                        pass
+                    elif image_process_mode == "Resize":
+                        image = image.resize((224, 224))
+                    else:
+                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
+                    max_hw, min_hw = max(image.size), min(image.size)
+                    aspect_ratio = max_hw / min_hw
+                    max_len, min_len = 800, 400
+                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                    longest_edge = int(shortest_edge * aspect_ratio)
+                    W, H = image.size
+                    if H > W:
+                        H, W = longest_edge, shortest_edge
+                    else:
+                        H, W = shortest_edge, longest_edge
+                    image = image.resize((W, H))
+                    if return_pil:
+                        images.append(image)
+                    else:
+                        buffered = BytesIO()
+                        image.save(buffered, format="JPEG")
+                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                        images.append(img_b64_str)
+        return images
+
+    def to_gradio_chatbot(self):
+        ret = []
+        for i, (role, msg) in enumerate(self.messages[self.offset:]):
+            if i % 2 == 0:
+                if type(msg) is tuple:
+                    import base64
+                    from io import BytesIO
+                    msg, image, image_process_mode = msg
+                    max_hw, min_hw = max(image.size), min(image.size)
+                    aspect_ratio = max_hw / min_hw
+                    max_len, min_len = 800, 400
+                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                    longest_edge = int(shortest_edge * aspect_ratio)
+                    W, H = image.size
+                    if H > W:
+                        H, W = longest_edge, shortest_edge
+                    else:
+                        H, W = shortest_edge, longest_edge
+                    image = image.resize((W, H))
+                    # image = image.resize((224, 224))
+                    buffered = BytesIO()
+                    image.save(buffered, format="JPEG")
+                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
+                    msg = msg.replace('<image>', img_str)
+                ret.append([msg, None])
+            else:
+                ret[-1][-1] = msg
+        return ret
+
+    def copy(self):
+        return Conversation(
+            system=self.system,
+            roles=self.roles,
+            messages=[[x, y] for x, y in self.messages],
+            offset=self.offset,
+            sep_style=self.sep_style,
+            sep=self.sep,
+            sep2=self.sep2)
+
+    def dict(self):
+        if len(self.get_images()) > 0:
+            return {
+                "system": self.system,
+                "roles": self.roles,
+                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
+                "offset": self.offset,
+                "sep": self.sep,
+                "sep2": self.sep2,
+            }
+        return {
+            "system": self.system,
+            "roles": self.roles,
+            "messages": self.messages,
+            "offset": self.offset,
+            "sep": self.sep,
+            "sep2": self.sep2,
+        }
+
+
+conv_v1 = Conversation(
+    system="A chat between a curious human and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+    roles=("Human", "Assistant"),
+    messages=(
+        ("Human", "Give three tips for staying healthy."),
+        ("Assistant",
+         "Sure, here are three tips for staying healthy:\n"
+         "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
+         "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
+         "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
+         "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
+         "activities at least two days per week.\n"
+         "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
+         "vegetables, whole grains, lean proteins, and healthy fats can help support "
+         "your overall health. Try to limit your intake of processed and high-sugar foods, "
+         "and aim to drink plenty of water throughout the day.\n"
+         "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
+         "and mental health. Adults should aim for seven to nine hours of sleep per night. "
+         "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
+         "help improve the quality of your sleep.")
+    ),
+    offset=2,
+    sep_style=SeparatorStyle.SINGLE,
+    sep="###",
+)
+
+conv_v1_2 = Conversation(
+    system="A chat between a curious human and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+    roles=("Human", "Assistant"),
+    messages=(),
+
+    # (
+    #     ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
+    #     ("Assistant",
+    #         "Renewable energy sources are those that can be replenished naturally in a relatively "
+    #         "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
+    #         "Non-renewable energy sources, on the other hand, are finite and will eventually be "
+    #         "depleted, such as coal, oil, and natural gas. Here are some key differences between "
+    #         "renewable and non-renewable energy sources:\n"
+    #         "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
+    #         "energy sources are finite and will eventually run out.\n"
+    #         "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
+    #         "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
+    #         "and other negative effects.\n"
+    #         "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
+    #         "have lower operational costs than non-renewable sources.\n"
+    #         "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
+    #         "locations than non-renewable sources.\n"
+    #         "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
+    #         "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
+    #         "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
+    #         "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
+    # )
+    offset=2,
+    sep_style=SeparatorStyle.SINGLE,
+    sep="###",
+)
+
+conv_vicuna_v1_1 = Conversation(
+    system="A chat between a curious user and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
+    roles=("USER", "ASSISTANT"),
+    version="v1",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep=" ",
+    sep2="</s>",
+)
+
+conv_mpt = Conversation(
+    system="""<|im_start|>system
+- You are a helpful language and vision assistant.
+- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
+- You should follow the instructions carefully and explain your answers in detail.""",
+    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+    version="mpt",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.MPT,
+    sep="<|im_end|>",
+)
+
+conv_mpt_text = Conversation(
+    system="""<|im_start|>system
+- You are a helpful assistant chatbot trained by MosaicML.
+- You answer questions.
+- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
+- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
+    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+    version="mpt",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.MPT,
+    sep="<|im_end|>",
+)
+
+conv_bair_v1 = Conversation(
+    system="BEGINNING OF CONVERSATION:",
+    roles=("USER", "GPT"),
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep=" ",
+    sep2="</s>",
+)
+
+simple_conv = Conversation(
+    system="A chat between a curious human and an artificial intelligence assistant. "
+           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+    roles=("Human", "Assistant"),
+    messages=(
+        ("Human", "Hi!"),
+        ("Assistant", "Hi there! How can I help you today?")
+    ),
+    offset=2,
+    sep_style=SeparatorStyle.SINGLE,
+    sep="###",
+)
+
+simple_conv_multimodal = Conversation(
+    system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
+           "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+           "Follow the instructions carefully and explain your answers in detail.",
+    roles=("Human", "Assistant"),
+    messages=(
+        ("Human", "Hi!"),
+        ("Assistant", "Hi there!  How can I help you today?\n")
+    ),
+    offset=2,
+    sep_style=SeparatorStyle.SINGLE,
+    sep="###",
+)
+
+simple_conv_mpt_multimodal = Conversation(
+    system="""<|im_start|>system
+- You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab.
+- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
+- You should follow the instructions carefully and explain your answers in detail.""",
+    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+    version="mpt",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.MPT,
+    sep="<|im_end|>",
+)
+
+simple_conv_legacy = Conversation(
+    system="You are LLaVA, a large language model trained by UW Madison WAIV Lab."
+           "You are designed to assist human with a variety of tasks using natural language."
+           "Follow the instructions carefully.",
+    roles=("Human", "Assistant"),
+    messages=(
+        ("Human", "Hi!\n\n### Response:"),
+        ("Assistant", "Hi there!  How can I help you today?\n")
+    ),
+    offset=2,
+    sep_style=SeparatorStyle.SINGLE,
+    sep="###",
+)
+
+conv_llava_v1 = Conversation(
+    system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
+           "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
+           "Follow the instructions carefully and explain your answers in detail.",
+    roles=("USER", "ASSISTANT"),
+    version="v1",
+    messages=(),
+    offset=0,
+    sep_style=SeparatorStyle.TWO,
+    sep=" ",
+    sep2="</s>",
+)
+
+default_conversation = conv_v1_2
+conv_templates = {
+    "default": conv_v1_2,
+    "simple": simple_conv,
+    "simple_legacy": simple_conv_legacy,
+    "multimodal": simple_conv_multimodal,
+    "mpt_multimodal": simple_conv_mpt_multimodal,
+    "llava_v1": conv_llava_v1,
+
+    # fastchat
+    "v1": conv_v1_2,
+    "bair_v1": conv_bair_v1,
+    "vicuna_v1_1": conv_vicuna_v1_1,
+    "mpt": conv_mpt,
+    "mpt_text": conv_mpt_text,
+}
+
+if __name__ == "__main__":
+    print(default_conversation.get_prompt())
diff --git a/data/fintune_dataset.py b/data/fintune_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f787139d702fbc46ef5ca8189ade8f89a9c7df0
--- /dev/null
+++ b/data/fintune_dataset.py
@@ -0,0 +1,449 @@
+import warnings
+
+import torch
+import yaml
+from torch.utils.data import Dataset
+from PIL import Image
+import json
+from model.tokenizer import Tokenizer
+import os
+import torchvision.transforms as transforms
+import random
+import torchvision.transforms.functional as F
+import torchaudio
+from . import conversation_lib
+
+import numpy as np
+from . import video_utils
+from .imu_utils import get_imu_frames
+
+
+IGNORE_INDEX = -100
+
+DEFAULT_IMAGE_TOKEN = "<image>"
+try:
+    from torchvision.transforms import InterpolationMode
+
+    BICUBIC = InterpolationMode.BICUBIC
+except ImportError:
+    BICUBIC = Image.BICUBIC
+
+T_random_resized_crop = transforms.Compose([
+    transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), interpolation=BICUBIC,
+                                 antialias=None),  # 3 is bicubic
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
+
+
+# image transform
+transform_img_train = transforms.Compose([
+    transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(
+        0.75, 1.3333), interpolation=3, antialias=None),  # 3 is bicubic
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
+
+
+class PairRandomResizedCrop(transforms.RandomResizedCrop):
+    def forward(self, imgs):
+        i, j, h, w = self.get_params(imgs[0], self.scale, self.ratio)
+        return [F.resized_crop(img, i, j, h, w, self.size, self.interpolation, antialias=self.antialias) for img in imgs]
+
+
+class PairToTensor(transforms.ToTensor):
+    def __call__(self, pics):
+        return [F.to_tensor(pic) for pic in pics]
+
+
+class PairNormalize(transforms.Normalize):
+    def forward(self, tensors):
+        return [F.normalize(tensor, self.mean, self.std, self.inplace) for tensor in tensors]
+
+
+transform_pairimg_train = transforms.Compose([
+    PairRandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(
+        0.75, 1.3333), interpolation=3, antialias=None),  # 3 is bicubic
+    PairToTensor(),
+    PairNormalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
+
+
+def pc_norm(pc):
+    """ pc: NxC, return NxC """
+    xyz = pc[:, :3]
+    other_feature = pc[:, 3:]
+
+    centroid = torch.mean(xyz, dim=0)
+    xyz = xyz - centroid
+    m = torch.max(torch.sqrt(torch.sum(xyz ** 2, dim=1)))
+    xyz = xyz / m
+
+    pc = torch.cat((xyz, other_feature), dim=1)
+    return pc
+
+
+def make_audio_features(wav_name, mel_bins=128, target_length=1024, aug=False):
+    waveform, sr = torchaudio.load(wav_name)
+    # assert sr == 16000, 'input audio sampling rate must be 16kHz'
+    if sr != 16000:
+        trans = torchaudio.transforms.Resample(sr, 16000)
+        waveform = trans(waveform)
+
+    waveform = waveform - waveform.mean()
+
+    fbank = torchaudio.compliance.kaldi.fbank(
+        waveform, htk_compat=True, sample_frequency=16000, use_energy=False,
+        window_type='hanning', num_mel_bins=mel_bins, dither=0.0, frame_shift=10)
+
+    n_frames = fbank.shape[0]
+
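+    # pad with zeros (p > 0) or truncate (p < 0) so every clip yields exactly target_length frames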
+    p = target_length - n_frames
+    if p > 0:
+        m = torch.nn.ZeroPad2d((0, 0, 0, p))
+        fbank = m(fbank)
+    elif p < 0:
+        fbank = fbank[0:target_length, :]
+
+    if aug:
+        freqm = torchaudio.transforms.FrequencyMasking(48)
+        timem = torchaudio.transforms.TimeMasking(192)
+        fbank = torch.transpose(fbank, 0, 1)
+        fbank = fbank.unsqueeze(0)
+        fbank = freqm(fbank)
+        fbank = timem(fbank)
+        fbank = fbank.squeeze(0)
+        fbank = torch.transpose(fbank, 0, 1)
+
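+    # normalize the log-mel features with fixed mean/std statistics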
+    fbank = (fbank - (-4.2677393)) / (4.5689974 * 2)
+    return fbank
+
+
+class ConversationGenerator:
+    def __init__(self, tokenizer):
+        self.tokenizer = tokenizer
+        self.header = f"{conversation_lib.default_conversation.system}\n\n"
+        self._probe_tokenizer_style()
+
+    def _probe_tokenizer_style(self):
+        """
+        Given a sentence, e.g. "My darling", some tokenizers will make the space a seperate token,
+        while some others will merge the space into the next word, forming a token representing " darling".
+        Knowing which style the tokenizer takes is necessary for correct ground-truth label masking.
+
+        """
+        probe = "Probe am I"
+        sentence1 = self.tokenizer.encode(conversation_lib.default_conversation.roles[1] + ": " + probe,
+                                          bos=False, eos=False)
+        sentence2 = self.tokenizer.encode(probe,
+                                          bos=False, eos=False)
+        if sentence1[-len(sentence2):] == sentence2:
+            self.space_before_to_predict = False
+        else:
+            sentence3 = self.tokenizer.encode(" " + probe,
+                                              bos=False, eos=False)
+            assert sentence1[-len(sentence3):] == sentence3
+            self.space_before_to_predict = True
+
+    def add_speaker_and_signal(self, source, get_conversation=True):
+        """Add speaker and start/end signal on each round."""
+        BEGIN_SIGNAL = "### "
+        END_SIGNAL = "\n"
+        conversation = self.header
+
+        to_predict_list = []
+
+        for sentence in source:
+            from_str = sentence["from"]
+            if from_str.lower() in ["human"]:
+                from_str = conversation_lib.default_conversation.roles[0]
+            elif from_str.lower() in ["gpt", "assistant"]:
+                from_str = conversation_lib.default_conversation.roles[1]
+            else:
+                raise ValueError(f"unknown dialog role: {from_str.lower()}")
+
+            value = sentence["value"]
+            if DEFAULT_IMAGE_TOKEN in value:
+                value = value.replace(DEFAULT_IMAGE_TOKEN, '').strip()
+
+            sentence_value = BEGIN_SIGNAL + from_str + ": " + value + END_SIGNAL
+
+            if from_str == conversation_lib.default_conversation.roles[1]:
+                to_predict_value = value + END_SIGNAL + "###"
+                if self.space_before_to_predict:
+                    to_predict_value = " " + to_predict_value
+                to_predict_list.append(to_predict_value)
+
+            if get_conversation:
+                conversation = conversation + sentence_value
+
+        conversation = conversation + BEGIN_SIGNAL
+        return conversation, to_predict_list
+
+
+DATASETS = dict(
+    image=[
+        dict(path="datasets/InstructionTuning/image/llava_v1_5_mix665k_image.json", type='image'),
+        dict(path='datasets/InstructionTuning/image/cococap_train.json', type='image'),
+        dict(path="datasets/InstructionTuning/image/llava_v1_5_mix665k_text.json", type='text'),
+    ],
+    audio=[
+        dict(path="datasets/InstructionTuning/audio/audiocap_train.json", type='audio'),
+        dict(path="datasets/InstructionTuning/audio/audiocap_val.json", type='audio'),
+        dict(path="datasets/InstructionTuning/audio/audio_conversation.json", type='audio'),
+    ],
+    video=[
+        dict(path="datasets/InstructionTuning/video/msrvtt_cap_trainval.json", type='video'),
+        dict(path="datasets/InstructionTuning/video/msrvtt_cap_test.json", type='video'),
+        dict(path="datasets/InstructionTuning/video/msrvtt_vqa_train.json", type='video'),
+        dict(path="datasets/InstructionTuning/video/msrvtt_vqa_val.json", type='video'),
+        dict(path="datasets/InstructionTuning/video/msrvtt_vqa_test.json", type='video'),
+        dict(path="datasets/InstructionTuning/video/video_complex_reasoning_10k.json", type='video'),
+        dict(path="datasets/InstructionTuning/video/video_conversation_10k.json", type='video'),
+        dict(path="datasets/InstructionTuning/video/video_detail_10k.json", type='video'),
+    ],
+    point=[
+        dict(path="datasets/InstructionTuning/point/pointllm_70k_formated.json", type='point'),
+    ],
+    rgbd=[
+        dict(path="datasets/InstructionTuning/depth_normal/llava_instruct_50k_depth.json", type='rgbd'),
+    ],
+    rgbn=[
+        dict(path="datasets/InstructionTuning/depth_normal/llava_instruct_50k_normal.json", type='rgbn'),
+    ],
+    imu=[
+        dict(path="datasets/InstructionTuning/imu/imu_fixed_50k.json", type='imu'),
+    ],
+    fmri=[
+        dict(path="datasets/InstructionTuning/fmri/fmri_fixed.json", type='fmri'),
+    ],
+)
+IMU_PATH = "/mnt/petrelfs/share_data/hanjiaming/ego4d/v2/processed_imu/"
+
+
+class FinetuneDialogDataset(Dataset):
+    def __init__(self, dataset=['image'], transform=T_random_resized_crop, max_words=2048, image_words=30, tokenizer_path=None):
+        if isinstance(dataset, str):
+            dataset = [dataset]
+
+        self.dataset = dataset
+
+        group_ann = {}
+        for d in dataset:
+            for meta in DATASETS[d]:
+                meta_path, meta_type = meta['path'], meta['type']
+                meta_ext = os.path.splitext(meta_path)[-1]
+                if meta_ext == ".json":
+                    with open(meta_path) as f:
+                        meta_l = json.load(f)
+                        # add data_type
+                        # this is a temp solution
+                        new_meta_l = []
+                        for l in meta_l:
+                            l['data_type'] = meta_type
+                            new_meta_l.append(l)
+                        meta_l = new_meta_l
+                elif meta_ext == ".jsonl":
+                    meta_l = []
+                    with open(meta_path) as f:
+                        for i, line in enumerate(f):
+                            try:
+                                meta_l.append(json.loads(line))
+                            except json.decoder.JSONDecodeError as e:
+                                print(
+                                    f"Error decoding the following jsonl line ({i}):\n{line.rstrip()}", force=True)
+                                raise e
+                else:
+                    raise NotImplementedError(
+                        f"Unknown meta file extension: \"{meta_ext}\". "
+                        f"Currently, .json, .jsonl are supported. "
+                        "If you are using a supported format, please set the file extension so that the proper parsing "
+                        "routine can be called."
+                    )
+                if meta_type not in group_ann:
+                    group_ann[meta_type] = []
+                print(f"{meta_path}, type {meta_type}: len {len(meta_l)}")
+                group_ann[meta_type] += meta_l
+
+        # sort group_ann for higher efficiency (items in one global batch with similar length)
+        for meta_type, meta_l in group_ann.items():
+            meta_l.sort(key=lambda data_item: sum(
+                [len(_['value']) for _ in data_item['conversations']]))
+
+        self.group_ann = group_ann
+        self.ann = sum(list(self.group_ann.values()), start=[])
+
+        self.group_indices = {}
+        start_pos = 0
+        for meta_type, meta_l in self.group_ann.items():
+            self.group_indices[meta_type] = list(
+                range(start_pos, start_pos + len(meta_l)))
+            start_pos = start_pos + len(meta_l)
+
+        print(f"total length: {len(self)}")
+        self.transform = transform
+        print(f"transform:\n{self.transform}")
+        self.max_words = max_words
+        self.image_words = image_words
+        self.tokenizer = Tokenizer(model_path=tokenizer_path)
+        self.conversation_generator = ConversationGenerator(self.tokenizer)
+
+        self.load_funcs = dict(
+            image=self.load_image,
+            audio=self.load_audio,
+            video=self.load_video,
+            point=self.load_point,
+            rgbd=self.load_rgbx,
+            rgbn=self.load_rgbx,
+            imu=self.load_imu,
+            fmri=self.load_fmri
+        )
+
+    def __len__(self):
+        return len(self.ann)
+
+    def load_image(self, data):
+        filename = data['image']
+        image = Image.open(filename).convert('RGB')
+        image = self.transform(image)
+        return image
+
+    def load_audio(self, data):
+        audio_path = data['image']
+        fbank = make_audio_features(audio_path, mel_bins=128)
+        fbank = fbank.transpose(0, 1)[None]  # [1, 128, 1024]
+        return fbank
+
+    def load_video(self, data):
+        video_path = data['image']
+        video_feats = video_utils.load_and_transform_video_data(
+            video_path, video_path, clip_duration=1, clips_per_video=5)
+        return video_feats[:, :, 0]
+
+    def load_point(self, data):
+        point_path = data['image']
+        point_feat = torch.load(point_path, map_location='cpu')
+        point_feat = point_feat.transpose(0, 1)
+        return point_feat
+
+    def load_rgbx(self, data):
+        image_path = data['image']
+        x_image_path = data['depth_image'] if 'depth_image' in data else data['normal_image']
+        image = Image.open(image_path).convert('RGB')
+        x_image = Image.open(x_image_path).convert('RGB')
+        x_image = x_image.resize(image.size[-2:])
+
+        image, x_image = transform_pairimg_train([image, x_image])
+        # [2, 3, H, W]
+        image = torch.stack([image, x_image], dim=0)
+        return image
+
+    def load_fmri(self, data):
+        fmri_path = data['image']
+        data = np.load(fmri_path)
+        data = data.mean(axis=0)
+        data = torch.tensor(data[None])
+        return data
+
+    def load_imu(self, data_dict):
+        uid = data_dict["video_uid"]
+        w_s = data_dict["window_start"]
+        w_e = data_dict["window_end"]
+
+        imu_data = get_imu_frames(
+            IMU_PATH, uid,
+            video_start_sec=w_s,
+            video_end_sec=w_e,
+        )
+        if imu_data is None:
+            raise ValueError(f"No IMU frames found for uid {uid} in window [{w_s}, {w_e}]")
+        return imu_data['signal']
+
+    def __getitem__(self, index, expect_type=None):
+        if expect_type is None:
+            data_item = self.ann[index]
+        else:
+            # in case we want to get data of a specific data_type
+            data_item = self.group_ann[expect_type][index]
+
+        data_type = data_item['data_type']
+        if data_type != 'text':
+            if data_type in self.load_funcs:
+                try:
+                    image = self.load_funcs[data_type](data_item)
+                    if image is None:
+                        raise ValueError('Data is None')
+                except Exception:
+                    print('Error loading', data_item)
+                    rand_idx = random.randint(
+                        0, len(self.group_ann[data_type]) - 1)
+                    return self.__getitem__(rand_idx, expect_type=data_type)
+            else:
+                raise ValueError(f'Unsupported data type: {data_type}')
+        else:
+            image = None
+            # warnings.warn("pure black image for examples without image")
+            # image = torch.zeros(3, 224, 224)
+
+        source = data_item["conversations"]
+        conversation, to_predict_values = self.conversation_generator.add_speaker_and_signal(
+            source)
+        if len(to_predict_values) == 0:
+            warnings.warn(
+                f"dialog data has nothing to predict, skipping: {data_item}")
+            return self[index-1]
+
+        tokenized_conversation = self.tokenizer.encode(
+            conversation, bos=True, eos=True)
+        labels = [IGNORE_INDEX for _ in tokenized_conversation]
+
+        check_pos = 0
+        for value in to_predict_values:
+            tokenized_value = self.tokenizer.encode(
+                value, bos=False, eos=False)
+            rel_pos = find_sublist(
+                tokenized_conversation[check_pos:], tokenized_value)
+            if rel_pos == -1:
+                print(
+                    "a to-predict sentence does not match any piece of the tokenized conversation")
+                return self[index-1]
+            value_pos = rel_pos + check_pos
+            labels[value_pos:value_pos+len(tokenized_value)] = tokenized_value
+            assert labels[value_pos:value_pos+len(
+                tokenized_value)] == tokenized_conversation[value_pos:value_pos+len(tokenized_value)]
+            check_pos = value_pos+len(tokenized_value)
+
+        input2 = torch.tensor(tokenized_conversation, dtype=torch.int64)
+        labels = torch.tensor(labels, dtype=torch.int64)
+
+        if image is not None:
+            max_words = self.max_words - self.image_words
+        else:
+            max_words = self.max_words
+        padding = max_words - input2.shape[0]
+        if padding > 0:
+            input2 = torch.cat(
+                (input2, torch.zeros(padding, dtype=torch.int64) - 1))
+            labels = torch.cat(
+                (labels, torch.zeros(padding, dtype=torch.int64) - 1))
+        elif padding < 0:
+            input2 = input2[:max_words]
+            labels = labels[:max_words]
+
+        input2_mask = input2.ge(0)
+        label_mask = labels.ge(0)
+        input2[~input2_mask] = 0
+        labels[~label_mask] = 0
+        input2_mask = input2_mask.float()
+        label_mask = label_mask.float()
+        if image is None:
+            return input2, labels, data_item['data_type']
+        else:
+            return input2, labels, image, data_item['data_type']
+
+    def groups(self):
+        return list(self.group_indices.values())
+
+
+def find_sublist(a: list, b: list):
+    len_a, len_b = len(a), len(b)
+    for i in range(len_a - len_b + 1):
+        if a[i:i+len_b] == b:
+            return i
+    return -1
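+
+# Illustrative examples (not part of the original code):
+#   find_sublist([1, 2, 3, 4], [3, 4])  ->  2
+#   find_sublist([1, 2, 3, 4], [4, 3])  -> -1
+# __getitem__ relies on this helper to locate each to-predict answer inside the
+# tokenized conversation, so that only those token positions receive label targets.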
diff --git a/data/imu_utils.py b/data/imu_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..010563d67e603bd7ca5589672058380a79ee93d9
--- /dev/null
+++ b/data/imu_utils.py
@@ -0,0 +1,257 @@
+import string
+import numpy as np
+import matplotlib.animation as animation
+from matplotlib import pyplot as plt
+import json
+from collections import defaultdict
+from bisect import bisect_left
+import os
+import torch
+import torchaudio
+torchaudio.set_audio_backend("sox_io")
+
+
+def load_json(json_path: str):
+    """
+    Load a json file
+    """
+    with open(json_path, "r", encoding="utf-8") as f_name:
+        data = json.load(f_name)
+    return data
+
+
+def check_window_signal(info_t, w_s, w_e):
+    length = w_e - w_s
+    frame_offset = int(w_s * info_t.sample_rate)
+    num_frames = int(length * info_t.sample_rate)
+    if frame_offset + num_frames > int(info_t.num_frames):
+        return False
+    else:
+        return True
+
+
+def index_narrations(ann_path):
+    narration_raw = load_json(ann_path)
+
+    narration_dict = defaultdict(list)
+    summary_dict = defaultdict(list)
+    avg_len = []
+    for v_id, narr in narration_raw.items():
+        narr_list = []
+        summ_list = []
+        if "narration_pass_1" in narr:
+            narr_list += narr["narration_pass_1"]["narrations"]
+            summ_list += narr["narration_pass_1"]["summaries"]
+        if "narration_pass_2" in narr:
+            narr_list += narr["narration_pass_2"]["narrations"]
+            summ_list += narr["narration_pass_2"]["summaries"]
+
+        if len(narr_list) > 0:
+            narration_dict[v_id] = [
+                (
+                    float(n_t["timestamp_sec"]),
+                    n_t["narration_text"],
+                    n_t["annotation_uid"],
+                    n_t["timestamp_frame"],
+                )
+                for n_t in narr_list
+            ]
+            avg_len.append(len(narration_dict[v_id]))
+        else:
+            narration_dict[v_id] = []
+        if len(summ_list) > 0:
+            summary_dict[v_id] = [
+                (
+                    float(s_t["start_sec"]),
+                    float(s_t["end_sec"]),
+                    s_t["summary_text"],
+                )
+                for s_t in summ_list
+            ]
+        else:
+            summary_dict[v_id] = []
+    # print(f"Number of Videos with narration {len(narration_dict)}")
+    # print(f"Avg. narration length {np.mean(avg_len)}")
+    # print(f"Number of Videos with summaries {len(summary_dict)}")
+    return narration_dict, summary_dict
+
+
+def get_signal_info(signal_fn: str):
+    return torchaudio.info(signal_fn)
+
+
+def get_signal_frames(signal_fn: str, video_start_sec: float, video_end_sec: float):
+    """
+    Given a signal track, return the frames between video_start_sec and video_end_sec
+    """
+    info_t = get_signal_info(signal_fn)
+
+    length = video_end_sec - video_start_sec
+    aframes, _ = torchaudio.load(
+        signal_fn,
+        normalize=True,
+        frame_offset=int(video_start_sec * info_t.sample_rate),
+        num_frames=int(length * info_t.sample_rate),
+    )
+    return {"signal": aframes, "meta": info_t}
+
+
+def tosec(value):
+    return value / 1000
+
+
+def toms(value):
+    return value * 1000
+
+
+def delta(first_num: float, second_num: float):
+    """Compute the absolute value of the difference of two numbers"""
+    return abs(first_num - second_num)
+
+
+def padIMU(signal, duration_sec):
+    """
+    Pad the signal if necessary
+    """
+    expected_elements = round(duration_sec) * 200
+
+    if signal.shape[0] > expected_elements:
+        signal = signal[:expected_elements, :]
+    elif signal.shape[0] < expected_elements:
+        padding = expected_elements - signal.shape[0]
+        padded_zeros = np.zeros((padding, 6))
+        signal = np.concatenate([signal, padded_zeros], 0)
+        # signal = signal[:expected_elements, :]
+    return signal
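+# Illustrative example (hypothetical shapes): for a 2-second window the signal is
+# padded or truncated to exactly round(2) * 200 = 400 rows of 6 IMU channels, e.g.
+# an input of shape (350, 6) becomes (400, 6) with 50 zero rows appended.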
+
+
+def resample(
+    signals: np.ndarray,
+    timestamps: np.ndarray,
+    original_sample_rate: int,
+    resample_rate: int,
+):
+    """
+    Resamples data to new sample rate
+    """
+    signals = torch.as_tensor(signals)
+    timestamps = torch.from_numpy(timestamps).unsqueeze(-1)
+    signals = torchaudio.functional.resample(
+        waveform=signals.data.T,
+        orig_freq=original_sample_rate,
+        new_freq=resample_rate,
+    ).T.numpy()
+
+    nsamples = len(signals)
+
+    period = 1 / resample_rate
+
+    # timestamps are expected to be shape (N, 1)
+    initial_seconds = timestamps[0] / 1e3
+
+    ntimes = (torch.arange(nsamples) * period).view(-1, 1) + initial_seconds
+
+    timestamps = (ntimes * 1e3).squeeze().numpy()
+    return signals, timestamps
+
+
+def resampleIMU(signal, timestamps):
+    sampling_rate = int(1000 * (1 / (np.mean(np.diff(timestamps)))))
+    # resample all to 200hz
+    if sampling_rate != 200:
+        signal, timestamps = resample(signal, timestamps, sampling_rate, 200)
+    return signal, timestamps
+
+
+def get_imu_frames(
+    imu_path,
+    uid: str,
+    video_start_sec: float,
+    video_end_sec: float,
+):
+    """
+    Given an IMU signal, return the frames between video_start_sec and video_end_sec
+    """
+    signal = np.load(os.path.join(imu_path, f"{uid}.npy"))
+    signal = signal.transpose()
+    timestamps = np.load(os.path.join(imu_path, f"{uid}_timestamps.npy"))
+
+    if toms(video_start_sec) > timestamps[-1] or toms(video_end_sec) > timestamps[-1]:
+        return None
+
+    start_id = bisect_left(timestamps, toms(video_start_sec))
+    end_id = bisect_left(timestamps, toms(video_end_sec))
+
+    # make sure the retrieved window matches the requested interval within a 4-second margin
+    if (
+        delta(video_start_sec, tosec(timestamps[start_id])) > 4
+        or delta(video_end_sec, tosec(timestamps[end_id])) > 4
+    ):
+        return None
+
+    # get the window
+    if start_id == end_id:
+        start_id -= 1
+        end_id += 1
+    signal, timestamps = signal[start_id:end_id], timestamps[start_id:end_id]
+
+    if len(signal) < 10 or len(timestamps) < 10:
+        return None
+    # resample the signal at 200hz if necessary
+    signal, timestamps = resampleIMU(signal, timestamps)
+
+    # pad the signal if necessary
+    signal = padIMU(signal, video_end_sec - video_start_sec)
+
+    sample_dict = {
+        "timestamp": timestamps,
+        "signal": torch.tensor(signal.T),
+        "sampling_rate": 200,
+    }
+
+    return sample_dict
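+
+# Usage sketch (the path and uid below are hypothetical):
+#   frames = get_imu_frames("datasets/ego4d/imu", "some_video_uid",
+#                           video_start_sec=10.0, video_end_sec=12.0)
+#   if frames is not None:
+#       signal = frames["signal"]   # torch.Tensor of shape (6, 200 * window duration)
+#       assert frames["sampling_rate"] == 200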
+
+
+def display_animation(frames, title, save_path_gif):
+    fig, ax = plt.subplots()
+    frames = [[ax.imshow(frames[i])] for i in range(len(frames))]
+    plt.title(title)
+    ani = animation.ArtistAnimation(fig, frames)
+    ani.save(save_path_gif, writer="imagemagick")
+    plt.close()
+
+
+def display_animation_imu(frames, imu, title, save_path_gif):
+    fig, (ax1, ax2, ax3) = plt.subplots(3, 1)
+    ax1.set_title(title)
+    ax2.set_title("Acc.")
+    ax3.set_title("Gyro.")
+    frames = [[ax1.imshow(frames[i])] for i in range(len(frames))]
+    ani = animation.ArtistAnimation(fig, frames)
+
+    ax2.plot(imu[0].cpu().numpy(), color="red")
+    ax2.plot(imu[1].cpu().numpy(), color="blue")
+    ax2.plot(imu[2].cpu().numpy(), color="green")
+    ax3.plot(imu[3].cpu().numpy(), color="red")
+    ax3.plot(imu[4].cpu().numpy(), color="blue")
+    ax3.plot(imu[5].cpu().numpy(), color="green")
+    plt.tight_layout()
+    ani.save(save_path_gif, writer="imagemagick")
+    plt.close()
+
+
+def filter_narration(narration_text: str) -> bool:
+    if "#c" in narration_text.lower():
+        return True
+    return False
+
+
+def clean_narration_text(narration_text: str) -> str:
+    return (
+        narration_text.replace("#C C ", "")
+        .replace("#C", "")
+        .replace("#unsure", "something")
+        .strip()
+        .strip(string.punctuation)
+        .lower()[:128]
+    )
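+
+
+# Illustrative examples (not part of the original code):
+#   filter_narration("#C C opens the drawer.")               -> True  ("#c" is present)
+#   clean_narration_text("#C C opens the drawer #unsure.")   -> "opens the drawer something"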
diff --git a/data/video_utils.py b/data/video_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..43ac03067e50c8570d422a057b9d9efb18e8775b
--- /dev/null
+++ b/data/video_utils.py
@@ -0,0 +1,204 @@
+import math
+import torch
+import torch.nn as nn
+from pytorchvideo import transforms as pv_transforms
+from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
+from pytorchvideo.data.encoded_video import EncodedVideo
+from pytorchvideo.data.encoded_video_decord import EncodedVideoDecord
+from torchvision import transforms
+from torchvision.transforms._transforms_video import NormalizeVideo
+
+
+def get_clip_timepoints(clip_sampler, duration):
+    # Read out all clips in this video
+    all_clips_timepoints = []
+    is_last_clip = False
+    end = 0.0
+    while not is_last_clip:
+        start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
+        all_clips_timepoints.append((start, end))
+    return all_clips_timepoints
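+# Note (illustrative): with ConstantClipsPerVideoSampler(clip_duration=1, clips_per_video=5)
+# and a 10-second video, this yields five (start, end) pairs of 1-second clips spread across
+# the duration; the exact timepoints are determined by pytorchvideo's sampler.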
+
+
+
+def crop_boxes(boxes, x_offset, y_offset):
+    """
+    Perform crop on the bounding boxes given the offsets.
+    Args:
+        boxes (ndarray or None): bounding boxes to perform crop. The dimension
+            is `num boxes` x 4.
+        x_offset (int): cropping offset in the x axis.
+        y_offset (int): cropping offset in the y axis.
+    Returns:
+        cropped_boxes (ndarray or None): the cropped boxes with dimension of
+            `num boxes` x 4.
+    """
+    cropped_boxes = boxes.copy()
+    cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
+    cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
+
+    return cropped_boxes
+
+
+def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
+    """
+    Perform uniform spatial sampling on the images and corresponding boxes.
+    Args:
+        images (tensor): images to perform uniform crop. The dimension is
+            `num frames` x `channel` x `height` x `width`.
+        size (int): height and width of the square crop.
+        spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
+            is larger than height. Or 0, 1, or 2 for top, center, and bottom
+            crop if height is larger than width.
+        boxes (ndarray or None): optional. Corresponding boxes to images.
+            Dimension is `num boxes` x 4.
+        scale_size (int): optional. If not None, resize the images to scale_size before
+            performing any crop.
+    Returns:
+        cropped (tensor): images with dimension of
+            `num frames` x `channel` x `size` x `size`.
+        cropped_boxes (ndarray or None): the cropped boxes with dimension of
+            `num boxes` x 4.
+    """
+    assert spatial_idx in [0, 1, 2]
+    ndim = len(images.shape)
+    if ndim == 3:
+        images = images.unsqueeze(0)
+    height = images.shape[2]
+    width = images.shape[3]
+
+    if scale_size is not None:
+        if width <= height:
+            width, height = scale_size, int(height / width * scale_size)
+        else:
+            width, height = int(width / height * scale_size), scale_size
+        images = torch.nn.functional.interpolate(
+            images,
+            size=(height, width),
+            mode="bilinear",
+            align_corners=False,
+        )
+
+    y_offset = int(math.ceil((height - size) / 2))
+    x_offset = int(math.ceil((width - size) / 2))
+
+    if height > width:
+        if spatial_idx == 0:
+            y_offset = 0
+        elif spatial_idx == 2:
+            y_offset = height - size
+    else:
+        if spatial_idx == 0:
+            x_offset = 0
+        elif spatial_idx == 2:
+            x_offset = width - size
+    cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
+    cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
+    if ndim == 3:
+        cropped = cropped.squeeze(0)
+    return cropped, cropped_boxes
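+
+# Illustrative example (hypothetical tensor): for a clip of shape 16 x 3 x 256 x 320,
+# uniform_crop(clip, size=224, spatial_idx=1) returns a 16 x 3 x 224 x 224 center crop
+# (cropped_boxes is None); spatial_idx 0 / 2 instead crop the left / right side, since
+# here the width is the longer dimension.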
+
+
+class SpatialCrop(nn.Module):
+    """
+    Convert the video into 3 smaller clips spatially. Must be used after the
+        temporal crops to get spatial crops, and should be used with
+        -2 in the spatial crop at the slowfast augmentation stage (so full
+        frames are passed in here). Will return a larger list with the
+        3x spatial crops as well.
+    """
+
+    def __init__(self, crop_size: int = 224, num_crops: int = 3):
+        super().__init__()
+        self.crop_size = crop_size
+        if num_crops == 3:
+            self.crops_to_ext = [0, 1, 2]
+            self.flipped_crops_to_ext = []
+        elif num_crops == 1:
+            self.crops_to_ext = [1]
+            self.flipped_crops_to_ext = []
+        else:
+            raise NotImplementedError("Nothing else supported yet")
+
+    def forward(self, videos):
+        """
+        Args:
+            videos: A list of C, T, H, W videos.
+        Returns:
+            videos: A list with 3x the number of elements. Each video converted
+                to C, T, H', W' by spatial cropping.
+        """
+        assert isinstance(videos, list), "Must be a list of videos after temporal crops"
+        assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
+        res = []
+        for video in videos:
+            for spatial_idx in self.crops_to_ext:
+                res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
+            if not self.flipped_crops_to_ext:
+                continue
+            flipped_video = transforms.functional.hflip(video)
+            for spatial_idx in self.flipped_crops_to_ext:
+                res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
+        return res
+
+
+def load_and_transform_video_data(
+    video_file,
+    video_path,
+    clip_duration=2,
+    clips_per_video=5,
+    sample_rate=16000,
+    with_audio=False
+):
+    video_transform = transforms.Compose(
+        [
+            pv_transforms.ShortSideScale(224),
+            NormalizeVideo(
+                mean=(0.48145466, 0.4578275, 0.40821073),
+                std=(0.26862954, 0.26130258, 0.27577711),
+            ),
+        ]
+    )
+
+    clip_sampler = ConstantClipsPerVideoSampler(
+        clip_duration=clip_duration, clips_per_video=clips_per_video
+    )
+    frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration)
+
+    if isinstance(video_file, str):
+        video = EncodedVideo.from_path(
+            video_file,
+            decoder="decord",
+            decode_audio=with_audio,
+            # **{"sample_rate": sample_rate},
+        )
+    else:
+        video = EncodedVideoDecord(video_file, video_name=video_path, decode_video=True, decode_audio=with_audio, sample_rate=sample_rate)
+    
+    all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration)
+
+    all_video = []
+    for clip_timepoints in all_clips_timepoints:
+        # Read the clip, get frames
+        clip = video.get_clip(clip_timepoints[0], clip_timepoints[1])
+        if clip is None:
+            raise ValueError("No clip found")
+        video_clip = frame_sampler(clip["video"])
+        video_clip = video_clip / 255.0  # since this is float, need 0-1
+
+        all_video.append(video_clip)
+
+    all_video = [video_transform(clip) for clip in all_video]
+    all_video = SpatialCrop(224, num_crops=3)(all_video)
+
+    all_video = torch.stack(all_video, dim=0)
+
+    if not with_audio:
+        return all_video
+    else:
+        return all_video, clip['audio']
+
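+# Shape note (illustrative): with clip_duration=1 and clips_per_video=5 the returned tensor
+# stacks 5 clips x 3 spatial crops, i.e. [15, 3, clip_duration, 224, 224]; callers such as
+# the load_video() helpers then index [:, :, 0] to drop the temporal dimension.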
+if __name__ == '__main__':
+    video_path = "datasets/InstructionTuning/video/music_aqa/MUSIC-AVQA-videos-Real/00000002.mp4"
+    video, audio = load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5, with_audio=True)
+    import pdb;pdb.set_trace()
\ No newline at end of file
diff --git a/demos/multi_turn_mm.py b/demos/multi_turn_mm.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f354e6c68d0a09df50c87a1a53f110a4fe7321a
--- /dev/null
+++ b/demos/multi_turn_mm.py
@@ -0,0 +1,300 @@
+import sys
+import os
+sys.path.append(os.path.abspath(__file__).rsplit('/', 2)[0])
+
+import argparse
+import multiprocessing as mp
+import numpy as np
+from typing import List, Optional
+
+import torch
+import torch.distributed as dist
+
+from fairscale.nn.model_parallel import initialize as fs_init
+
+import gradio as gr
+from util.misc import setup_for_distributed
+from util.misc import default_tensor_type
+from model.meta import MetaModel
+from data.conversation_lib import conv_templates, SeparatorStyle
+from PIL import Image
+import torchvision.transforms as transforms
+from data.fintune_dataset import make_audio_features
+from data import video_utils 
+
+
+T_random_resized_crop = transforms.Compose([
+    transforms.RandomResizedCrop(size=(224, 224), scale=(0.9, 1.0), ratio=(0.75, 1.3333), interpolation=3,
+                                 antialias=None),  # 3 is bicubic
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])])
+
+
+def load_audio(audio_path):
+    fbank = make_audio_features(audio_path, mel_bins=128)
+    fbank = fbank.transpose(0, 1)[None] #[1, 128, 1024]
+    return fbank
+    
+def load_video(video_path):
+    video_feats = video_utils.load_and_transform_video_data(video_path, video_path, clip_duration=1, clips_per_video=5)
+    return video_feats[:, :, 0]
+
+
+def model_worker(
+    rank: int, args: argparse.Namespace, barrier: mp.Barrier,
+    request_queue: mp.Queue, response_queue: Optional[mp.Queue] = None,
+) -> None:
+    """
+    The worker function that manipulates the GPU to run the inference.
+    Exactly n_gpu workers are started, each operating on a separate GPU.
+
+    Args:
+        rank (int): Distributed rank of the worker.
+        args (argparse.Namespace): All command line arguments.
+        barrier (multiprocessing.Barrier): A barrier used to delay the start
+            of Web UI to be after the start of the model.
+    """
+
+    world_size = len(args.gpu_ids)
+    gpu_id = args.gpu_ids[rank]
+    dist.init_process_group(
+        backend="nccl", rank=rank, world_size=world_size,
+        init_method=f"tcp://{args.master_addr}:{args.master_port}",
+    )
+    print(f"| distributed init on worker {rank}/{world_size}. "
+          f"using gpu: {gpu_id}")
+    fs_init.initialize_model_parallel(world_size)
+    torch.cuda.set_device(gpu_id)
+
+    torch.manual_seed(1)
+    np.random.seed(1)
+
+    # set the print behavior.
+    setup_for_distributed(rank == 0)
+
+    target_dtype = {
+        "bf16": torch.bfloat16,
+        "fp16": torch.float16
+    }[args.dtype]
+    with default_tensor_type(dtype=target_dtype, device="cuda"):
+        model = MetaModel(args.llama_type, args.llama_config, tokenizer_path=args.tokenizer_path)
+    print("Loading pretrained weights ...")
+    checkpoint = torch.load(args.pretrained_path, map_location='cpu')
+    msg = model.load_state_dict(checkpoint, strict=False)
+    print("load result:\n", msg)
+    model.cuda()
+    model.eval()
+    print(f"Model = {str(model)}")
+
+    barrier.wait()
+
+    while True:
+        img_path, audio_path, video_path, chatbot, max_gen_len, temperature, top_p, modality = request_queue.get()
+        if 'image' in modality and img_path is not None:
+            image = Image.open(img_path).convert('RGB')
+            inputs = T_random_resized_crop(image)
+        elif 'video' in modality and video_path is not None:
+            inputs = load_video(video_path)
+        elif 'audio' in modality and audio_path is not None:
+            inputs = load_audio(audio_path)
+        else:
+            inputs = None
+        
+        if inputs is not None:
+            inputs = inputs[None].cuda().to(target_dtype)
+    
+        conv = conv_templates["v1"].copy()
+        for user, bot in chatbot:
+            conv.append_message(conv.roles[0], user)
+            conv.append_message(conv.roles[1], bot)
+
+        with torch.cuda.amp.autocast(dtype=target_dtype):
+            print(conv.get_prompt())
+            for stream_response in model.stream_generate(
+                conv.get_prompt(), inputs,
+                max_gen_len=max_gen_len, temperature=temperature, top_p=top_p,
+                modal=modality
+            ):
+                conv_sep = (
+                    conv.sep
+                    if conv.sep_style == SeparatorStyle.SINGLE
+                    else conv.sep2
+                )
+                end_pos = stream_response["text"].find(conv_sep)
+                if end_pos != -1:
+                    stream_response["text"] = (
+                        stream_response['text'][:end_pos].rstrip() + "\n"
+                    )
+                    stream_response["end_of_content"] = True
+
+                # keep a few characters if not end_of_content to avoid sending
+                # part of conv_sep before all of it is generated.
+                if not stream_response["end_of_content"]:
+                    if len(stream_response["text"]) < len(conv_sep):
+                        continue
+                    stream_response["text"] = (
+                        stream_response["text"][:-len(conv_sep)]
+                    )
+
+                if response_queue is not None:
+                    response_queue.put(stream_response)
+
+                if stream_response["end_of_content"]:
+                    break
+
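+# Note (illustrative): every item that model_worker puts on response_queue is a dict with at
+# least "text" (the decoded response so far) and "end_of_content" (bool); gradio_worker below
+# polls this queue and streams "text" into the chatbot widget until end_of_content is True.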
+
+def gradio_worker(
+    request_queues: List[mp.Queue], response_queue: mp.Queue,
+    args: argparse.Namespace, barrier: mp.Barrier,
+) -> None:
+    """
+    The gradio worker is responsible for displaying the WebUI and relaying the
+    requests to model workers. It should be launched only once.
+
+    Args:
+        request_queues (List[mp.Queue]): A list of request queues (one for
+            each model worker).
+        args (argparse.Namespace): All command line arguments.
+        barrier (multiprocessing.Barrier): A barrier used to delay the start
+            of Web UI to be after the start of the model.
+    """
+
+    def show_user_input(msg, chatbot):
+        return "", chatbot + [[msg, None]]
+
+    def stream_model_output(img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality):
+        for queue in request_queues:
+            queue.put((img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality))
+        while True:
+            content_piece = response_queue.get()
+            chatbot[-1][1] = content_piece["text"]
+            yield chatbot
+            if content_piece["end_of_content"]:
+                break
+
+    def undo(chatbot):
+        if len(chatbot) > 0:
+            chatbot = chatbot[:-1]
+        return chatbot
+
+    def clear():
+        chatbot = []
+        msg = ""
+        return chatbot, msg
+
+    CSS = """
+    .contain { display: flex; flex-direction: column; }
+    #component-0 { height: 100%; }
+    #chatbot { flex-grow: 1; overflow: auto;}
+    """
+    with gr.Blocks(css=CSS) as demo:
+        gr.Markdown("## OneLLM: One Framework to Align All Modalities with Language")
+        with gr.Row(equal_height=True):
+            with gr.Column(scale=1):
+                img_path = gr.Image(label='Image Input', type='filepath')
+                video_path = gr.Video(label='Video Input')
+                audio_path = gr.Audio(label='Audio Input', type='filepath', sources=['upload'])
+                modality = gr.Radio(choices=['image', 'audio', 'video'], value='image', interactive=True, label='Input Modalities')
+
+            with gr.Column(scale=2):
+                chatbot = gr.Chatbot(elem_id="chatbot")
+                msg = gr.Textbox()
+
+        with gr.Row():
+            submit_button = gr.Button("Submit", variant="primary")
+            undo_button = gr.Button("Undo")
+            clear_button = gr.ClearButton([chatbot, msg, img_path, audio_path, video_path, modality])
+        with gr.Row():
+            max_gen_len = gr.Slider(
+                minimum=1, maximum=args.model_max_seq_len // 2,
+                value=args.model_max_seq_len // 2, interactive=True,
+                label="Single-turn max response length",
+            )
+            gen_t = gr.Slider(
+                minimum=0, maximum=1, value=0.1, interactive=True,
+                label="Temperature",
+            )
+            top_p = gr.Slider(
+                minimum=0, maximum=1, value=0.75, interactive=True,
+                label="Top-p",
+            )
+        msg.submit(
+            show_user_input, [msg, chatbot], [msg, chatbot],
+        ).then(
+            stream_model_output, [img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot,
+        )
+        submit_button.click(
+            show_user_input, [msg, chatbot], [msg, chatbot],
+        ).then(
+            stream_model_output, [img_path, audio_path, video_path, chatbot, max_gen_len, gen_t, top_p, modality], chatbot,
+        )
+        undo_button.click(undo, chatbot, chatbot)
+        # img_path.change(clear, [], [chatbot, msg])
+    barrier.wait()
+    demo.queue(api_open=True).launch(share=True, max_threads=1)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser("Chat Demo")
+    group = parser.add_mutually_exclusive_group()
+    group.add_argument(
+        "--gpu_ids", type=int, nargs="+",
+        help="A list of space-separated gpu ids to run the model on. "
+             "The model will span across GPUs in tensor-parallel mode."
+    )
+    parser.add_argument(
+        "--tokenizer_path", type=str,
+        help="Path to the tokenizer.model file provided along with the LLaMA "
+             "model."
+    )
+    parser.add_argument(
+        "--llama_type", default="onellm", type=str, metavar="MODEL",
+        help="LLaMA model type."
+    )
+    parser.add_argument(
+        "--llama_config", type=str, required=True,
+        help="Path to the llama model config json."
+    )
+    parser.add_argument(
+        "--model_max_seq_len", type=int, default=2048,
+        help="Max sequence length accepted by the pretrained model."
+    )
+    parser.add_argument(
+        "--pretrained_path", type=str, required=True,
+        help="Path to the llama model checkpoints. A list of checkpoints is "
+             "supported and will be merged from left to right.")
+    parser.add_argument(
+        "--master_port", type=int, default=23862,
+        help="A port used by the PyTorch distributed module to initialize."
+    )
+    parser.add_argument(
+        "--master_addr", type=str, default="127.0.0.1",
+        help="An address used by the PyTorch distributed module to initialize."
+    )
+    parser.add_argument(
+        "--dtype", type=str, choices=["fp16", "bf16"], default="fp16",
+        help="The dtype used for model weights and inference."
+    )
+    args = parser.parse_args()
+
+    # using the default "fork" method messes up some imported libs (e.g.,
+    # pandas)
+    mp.set_start_method("spawn")
+
+    # setup the queues and start the model workers
+    request_queues = []
+    response_queue = mp.Queue()
+    worker_processes = []
+    barrier = mp.Barrier(len(args.gpu_ids) + 1)
+    for rank, gpu_id in enumerate(args.gpu_ids):
+        request_queue = mp.Queue()
+        rank_response_queue = response_queue if rank == 0 else None
+        process = mp.Process(
+            target=model_worker,
+            args=(rank, args, barrier, request_queue, rank_response_queue),
+        )
+        process.start()
+        worker_processes.append(process)
+        request_queues.append(request_queue)
+
+    gradio_worker(request_queues, response_queue, args, barrier)
diff --git a/lib/__pycache__/point_utils.cpython-310.pyc b/lib/__pycache__/point_utils.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b52bf4169d4d84233f3178c745896d1fa395824f
Binary files /dev/null and b/lib/__pycache__/point_utils.cpython-310.pyc differ
diff --git a/lib/point_utils.py b/lib/point_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..834733a64b540a141bfce09f6d0fae3154f89997
--- /dev/null
+++ b/lib/point_utils.py
@@ -0,0 +1,191 @@
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+import pointnet2_cuda
+
+class KNN(nn.Module):
+    def __init__(self, neighbors, transpose_mode=True):
+        super(KNN, self).__init__()
+        self.neighbors = neighbors
+
+    @torch.no_grad()
+    def forward(self, support, query):
+        """
+        Args:
+            support ([tensor]): [B, N, C]
+            query ([tensor]): [B, M, C]
+        Returns:
+            [int]: neighbor idx. [B, M, K]
+        """
+        dist = torch.cdist(support, query)
+        k_dist = dist.topk(k=self.neighbors, dim=1, largest=False)
+        return k_dist.values, k_dist.indices.transpose(1, 2).contiguous().int()
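+
+# Shape sketch (illustrative): for support [B, N, C] and query [B, M, C], KNN.forward returns
+# neighbor indices of shape [B, M, K] (K = self.neighbors) indexing into the support points,
+# and distances of shape [B, K, M] (the distances are not transposed).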
+
+
+class GroupingOperation(Function):
+
+    @staticmethod
+    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
+        """
+        :param ctx:
+        :param features: (B, C, N) tensor of features to group
+        :param idx: (B, npoint, nsample) tensor containing the indices of features to group with
+        :return:
+            output: (B, C, npoint, nsample) tensor
+        """
+        assert features.is_contiguous()
+        assert idx.is_contiguous()
+
+        B, nfeatures, nsample = idx.size()
+        _, C, N = features.size()
+        output = torch.cuda.FloatTensor(B, C, nfeatures, nsample, device=features.device)
+
+        pointnet2_cuda.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output)
+
+        ctx.for_backwards = (idx, N)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out: torch.Tensor):
+        """
+        :param ctx:
+        :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
+        :return:
+            grad_features: (B, C, N) gradient of the features
+        """
+        idx, N = ctx.for_backwards
+
+        B, C, npoint, nsample = grad_out.size()
+        grad_features = torch.zeros([B, C, N], dtype=torch.float, device=grad_out.device, requires_grad=True)
+        grad_out_data = grad_out.data.contiguous()
+        pointnet2_cuda.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data)
+        return grad_features, None
+
+grouping_operation = GroupingOperation.apply
+
+
+class KNNGroup(nn.Module):
+    def __init__(self, nsample: int,
+                 relative_xyz=True,
+                 normalize_dp=False,
+                 return_only_idx=False,
+                 **kwargs
+                 ):
+        """KNN-based grouping of neighboring points.
+
+        Args:
+            nsample (int): maximum number of neighbors to gather for each query point
+            relative_xyz (bool, optional): return grouped xyz relative to the query point. Defaults to True.
+            normalize_dp (bool, optional): normalize relative xyz by the max neighbor distance. Defaults to False.
+            return_only_idx (bool, optional): return only the neighbor indices. Defaults to False.
+        """
+        super().__init__()
+        self.nsample = nsample
+        self.knn = KNN(nsample, transpose_mode=True)
+        self.relative_xyz = relative_xyz
+        self.normalize_dp = normalize_dp
+        self.return_only_idx = return_only_idx
+
+    def forward(self, query_xyz: torch.Tensor, support_xyz: torch.Tensor, features: torch.Tensor = None):
+        """
+        :param query_xyz: (B, N, 3) xyz coordinates of the features
+        :param support_xyz: (B, npoint, 3) centroids
+        :param features: (B, C, N) descriptors of the features
+        :return:
+            new_features: (B, 3 + C, npoint, nsample)
+        """
+        _, idx = self.knn(support_xyz, query_xyz)
+        if self.return_only_idx:
+            return idx
+        idx = idx.int()
+        xyz_trans = support_xyz.transpose(1, 2).contiguous()
+        grouped_xyz = grouping_operation(xyz_trans, idx)  # (B, 3, npoint, nsample)
+        if self.relative_xyz:
+            grouped_xyz -= query_xyz.transpose(1, 2).unsqueeze(-1)  # relative position
+        if self.normalize_dp:
+            grouped_xyz /= torch.amax(torch.sqrt(torch.sum(grouped_xyz**2, dim=1)), dim=(1, 2)).view(-1, 1, 1, 1)
+        if features is not None:
+            grouped_features = grouping_operation(features, idx)
+            return grouped_xyz, grouped_features
+        else:
+            return grouped_xyz, None
+
+
+class FurthestPointSampling(Function):
+    @staticmethod
+    def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
+        """
+        Uses iterative furthest point sampling to select a set of npoint features that have the largest
+        minimum distance
+        :param ctx:
+        :param xyz: (B, N, 3) where N > npoint
+        :param npoint: int, number of features in the sampled set
+        :return:
+             output: (B, npoint) tensor containing the set (idx)
+        """
+        assert xyz.is_contiguous()
+
+        B, N, _ = xyz.size()
+        # output = torch.cuda.IntTensor(B, npoint, device=xyz.device)
+        # temp = torch.cuda.FloatTensor(B, N, device=xyz.device).fill_(1e10)
+        output = torch.cuda.IntTensor(B, npoint)
+        temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
+
+        pointnet2_cuda.furthest_point_sampling_wrapper(
+            B, N, npoint, xyz, temp, output)
+        return output
+
+    @staticmethod
+    def backward(xyz, a=None):
+        return None, None
+
+furthest_point_sample = FurthestPointSampling.apply
+
+
+class PointPatchEmbed(nn.Module):
+
+    def __init__(self,
+                 sample_ratio=0.0625,
+                 sample_number=1024,
+                 group_size=32,
+                 in_channels=6,
+                 channels=1024,
+                 kernel_size=1,
+                 stride=1,
+                 normalize_dp=False,
+                 relative_xyz=True,
+                 ):
+        super().__init__()
+        self.sample_ratio = sample_ratio
+        self.sample_number = sample_number
+        self.group_size = group_size
+
+        self.sample_fn = furthest_point_sample
+        self.grouper = KNNGroup(self.group_size, relative_xyz=relative_xyz, normalize_dp=normalize_dp)
+        
+        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=kernel_size, stride=stride)
+
+
+    def forward(self, x):
+        # coordinates
+        p = x[:, :, 3:].contiguous()
+
+        B, N, _ = p.shape[:3]
+        # idx = self.sample_fn(p, int(N * self.sample_ratio)).long()
+        idx = self.sample_fn(p, self.sample_number).long()
+        center_p = torch.gather(p, 1, idx.unsqueeze(-1).expand(-1, -1, 3))
+        # query neighbors.
+        _, fj = self.grouper(center_p, p, x.permute(0, 2, 1).contiguous()) # [B, N, 6] -> [B, 6, N] -> [B, 6, 1024, 32]
+
+        # [B, 6, 1024, 32] -> [B, channels, 1024, 32] -> max over neighbors -> [B, channels, 1024, 1]
+        fj = self.conv1(fj).max(dim=-1, keepdim=True)[0]
+
+        return fj
+
+
+if __name__ == '__main__':
+    model = PointPatchEmbed(channels=256).cuda()
+    input = torch.rand(4, 16384, 6).cuda()
+    ou = model(input)
+    import pdb;pdb.set_trace()
\ No newline at end of file
diff --git a/lib/pointnet2/pointnet2_modules.py b/lib/pointnet2/pointnet2_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f125ce5075c738897e5f6a78c71123d0e3e44a2
--- /dev/null
+++ b/lib/pointnet2/pointnet2_modules.py
@@ -0,0 +1,160 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from . import pointnet2_utils
+from . import pytorch_utils as pt_utils
+from typing import List
+
+
+class _PointnetSAModuleBase(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.npoint = None
+        self.groupers = None
+        self.mlps = None
+        self.pool_method = 'max_pool'
+
+    def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None) -> (torch.Tensor, torch.Tensor):
+        """
+        :param xyz: (B, N, 3) tensor of the xyz coordinates of the features
+        :param features: (B, N, C) tensor of the descriptors of the features
+        :param new_xyz:
+        :return:
+            new_xyz: (B, npoint, 3) tensor of the new features' xyz
+            new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors
+        """
+        new_features_list = []
+
+        xyz_flipped = xyz.transpose(1, 2).contiguous()
+        if new_xyz is None:
+            new_xyz = pointnet2_utils.gather_operation(
+                xyz_flipped,
+                pointnet2_utils.furthest_point_sample(xyz, self.npoint)
+            ).transpose(1, 2).contiguous() if self.npoint is not None else None
+
+        for i in range(len(self.groupers)):
+            new_features = self.groupers[i](xyz, new_xyz, features)  # (B, C, npoint, nsample)
+
+            new_features = self.mlps[i](new_features)  # (B, mlp[-1], npoint, nsample)
+            if self.pool_method == 'max_pool':
+                new_features = F.max_pool2d(
+                    new_features, kernel_size=[1, new_features.size(3)]
+                )  # (B, mlp[-1], npoint, 1)
+            elif self.pool_method == 'avg_pool':
+                new_features = F.avg_pool2d(
+                    new_features, kernel_size=[1, new_features.size(3)]
+                )  # (B, mlp[-1], npoint, 1)
+            else:
+                raise NotImplementedError
+
+            new_features = new_features.squeeze(-1)  # (B, mlp[-1], npoint)
+            new_features_list.append(new_features)
+
+        return new_xyz, torch.cat(new_features_list, dim=1)
+
+
+class PointnetSAModuleMSG(_PointnetSAModuleBase):
+    """Pointnet set abstraction layer with multiscale grouping"""
+
+    def __init__(self, *, npoint: int, radii: List[float], nsamples: List[int], mlps: List[List[int]], bn: bool = True,
+                 use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
+        """
+        :param npoint: int
+        :param radii: list of float, list of radii to group with
+        :param nsamples: list of int, number of samples in each ball query
+        :param mlps: list of list of int, spec of the pointnet before the global pooling for each scale
+        :param bn: whether to use batchnorm
+        :param use_xyz:
+        :param pool_method: max_pool / avg_pool
+        :param instance_norm: whether to use instance_norm
+        """
+        super().__init__()
+
+        assert len(radii) == len(nsamples) == len(mlps)
+
+        self.npoint = npoint
+        self.groupers = nn.ModuleList()
+        self.mlps = nn.ModuleList()
+        for i in range(len(radii)):
+            radius = radii[i]
+            nsample = nsamples[i]
+            self.groupers.append(
+                pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
+                if npoint is not None else pointnet2_utils.GroupAll(use_xyz)
+            )
+            mlp_spec = mlps[i]
+            if use_xyz:
+                mlp_spec[0] += 3
+
+            self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm))
+        self.pool_method = pool_method
+
+
+class PointnetSAModule(PointnetSAModuleMSG):
+    """Pointnet set abstraction layer"""
+
+    def __init__(self, *, mlp: List[int], npoint: int = None, radius: float = None, nsample: int = None,
+                 bn: bool = True, use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
+        """
+        :param mlp: list of int, spec of the pointnet before the global max_pool
+        :param npoint: int, number of features
+        :param radius: float, radius of ball
+        :param nsample: int, number of samples in the ball query
+        :param bn: whether to use batchnorm
+        :param use_xyz:
+        :param pool_method: max_pool / avg_pool
+        :param instance_norm: whether to use instance_norm
+        """
+        super().__init__(
+            mlps=[mlp], npoint=npoint, radii=[radius], nsamples=[nsample], bn=bn, use_xyz=use_xyz,
+            pool_method=pool_method, instance_norm=instance_norm
+        )
+
+
+class PointnetFPModule(nn.Module):
+    r"""Propagates the features of one set to another"""
+
+    def __init__(self, *, mlp: List[int], bn: bool = True):
+        """
+        :param mlp: list of int
+        :param bn: whether to use batchnorm
+        """
+        super().__init__()
+        self.mlp = pt_utils.SharedMLP(mlp, bn=bn)
+
+    def forward(
+            self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        :param unknown: (B, n, 3) tensor of the xyz positions of the unknown features
+        :param known: (B, m, 3) tensor of the xyz positions of the known features
+        :param unknow_feats: (B, C1, n) tensor of the features to be propagated to
+        :param known_feats: (B, C2, m) tensor of features to be propagated
+        :return:
+            new_features: (B, mlp[-1], n) tensor of the features of the unknown features
+        """
+        if known is not None:
+            dist, idx = pointnet2_utils.three_nn(unknown, known)
+            dist_recip = 1.0 / (dist + 1e-8)
+            norm = torch.sum(dist_recip, dim=2, keepdim=True)
+            weight = dist_recip / norm
+
+            interpolated_feats = pointnet2_utils.three_interpolate(known_feats, idx, weight)
+        else:
+            interpolated_feats = known_feats.expand(*known_feats.size()[0:2], unknown.size(1))
+
+        if unknow_feats is not None:
+            new_features = torch.cat([interpolated_feats, unknow_feats], dim=1)  # (B, C2 + C1, n)
+        else:
+            new_features = interpolated_feats
+
+        new_features = new_features.unsqueeze(-1)
+        new_features = self.mlp(new_features)
+
+        return new_features.squeeze(-1)
+
+
+if __name__ == "__main__":
+    pass
diff --git a/lib/pointnet2/pointnet2_utils.py b/lib/pointnet2/pointnet2_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e814102d8feb5e443e64a736e7733818e0a24685
--- /dev/null
+++ b/lib/pointnet2/pointnet2_utils.py
@@ -0,0 +1,290 @@
+import torch
+from torch.autograd import Variable
+from torch.autograd import Function
+import torch.nn as nn
+from typing import Tuple
+
+import pointnet2_cuda as pointnet2
+
+
+class FurthestPointSampling(Function):
+    @staticmethod
+    def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
+        """
+        Uses iterative furthest point sampling to select a set of npoint features that have the largest
+        minimum distance
+        :param ctx:
+        :param xyz: (B, N, 3) where N > npoint
+        :param npoint: int, number of features in the sampled set
+        :return:
+             output: (B, npoint) tensor containing the set
+        """
+        assert xyz.is_contiguous()
+
+        B, N, _ = xyz.size()
+        output = torch.cuda.IntTensor(B, npoint)
+        temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
+
+        pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
+        return output
+
+    @staticmethod
+    def backward(xyz, a=None):
+        return None, None
+
+
+furthest_point_sample = FurthestPointSampling.apply
+
+
+class GatherOperation(Function):
+
+    @staticmethod
+    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
+        """
+        :param ctx:
+        :param features: (B, C, N)
+        :param idx: (B, npoint) index tensor of the features to gather
+        :return:
+            output: (B, C, npoint)
+        """
+        assert features.is_contiguous()
+        assert idx.is_contiguous()
+
+        B, npoint = idx.size()
+        _, C, N = features.size()
+        output = torch.cuda.FloatTensor(B, C, npoint)
+
+        pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output)
+
+        ctx.for_backwards = (idx, C, N)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        idx, C, N = ctx.for_backwards
+        B, npoint = idx.size()
+
+        grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
+        grad_out_data = grad_out.data.contiguous()
+        pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data)
+        return grad_features, None
+
+
+gather_operation = GatherOperation.apply
+
+
+class ThreeNN(Function):
+
+    @staticmethod
+    def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Find the three nearest neighbors of unknown in known
+        :param ctx:
+        :param unknown: (B, N, 3)
+        :param known: (B, M, 3)
+        :return:
+            dist: (B, N, 3) l2 distance to the three nearest neighbors
+            idx: (B, N, 3) index of 3 nearest neighbors
+        """
+        assert unknown.is_contiguous()
+        assert known.is_contiguous()
+
+        B, N, _ = unknown.size()
+        m = known.size(1)
+        dist2 = torch.cuda.FloatTensor(B, N, 3)
+        idx = torch.cuda.IntTensor(B, N, 3)
+
+        pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
+        return torch.sqrt(dist2), idx
+
+    @staticmethod
+    def backward(ctx, a=None, b=None):
+        return None, None
+
+
+three_nn = ThreeNN.apply
+
+
+class ThreeInterpolate(Function):
+
+    @staticmethod
+    def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
+        """
+        Performs weighted linear interpolation on 3 features
+        :param ctx:
+        :param features: (B, C, M) Features descriptors to be interpolated from
+        :param idx: (B, n, 3) three nearest neighbors of the target features in features
+        :param weight: (B, n, 3) weights
+        :return:
+            output: (B, C, N) tensor of the interpolated features
+        """
+        assert features.is_contiguous()
+        assert idx.is_contiguous()
+        assert weight.is_contiguous()
+
+        B, c, m = features.size()
+        n = idx.size(1)
+        ctx.three_interpolate_for_backward = (idx, weight, m)
+        output = torch.cuda.FloatTensor(B, c, n)
+
+        pointnet2.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        :param ctx:
+        :param grad_out: (B, C, N) tensor with gradients of outputs
+        :return:
+            grad_features: (B, C, M) tensor with gradients of features
+            None:
+            None:
+        """
+        idx, weight, m = ctx.three_interpolate_for_backward
+        B, c, n = grad_out.size()
+
+        grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_())
+        grad_out_data = grad_out.data.contiguous()
+
+        pointnet2.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features.data)
+        return grad_features, None, None
+
+
+three_interpolate = ThreeInterpolate.apply
+
+
+class GroupingOperation(Function):
+
+    @staticmethod
+    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
+        """
+        :param ctx:
+        :param features: (B, C, N) tensor of features to group
+        :param idx: (B, npoint, nsample) tensor containing the indices of features to group with
+        :return:
+            output: (B, C, npoint, nsample) tensor
+        """
+        assert features.is_contiguous()
+        assert idx.is_contiguous()
+
+        B, nfeatures, nsample = idx.size()
+        _, C, N = features.size()
+        output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)
+
+        pointnet2.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output)
+
+        ctx.for_backwards = (idx, N)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        :param ctx:
+        :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
+        :return:
+            grad_features: (B, C, N) gradient of the features
+        """
+        idx, N = ctx.for_backwards
+
+        B, C, npoint, nsample = grad_out.size()
+        grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
+
+        grad_out_data = grad_out.data.contiguous()
+        pointnet2.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data)
+        return grad_features, None
+
+
+grouping_operation = GroupingOperation.apply
+
+
+class BallQuery(Function):
+
+    @staticmethod
+    def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
+        """
+        :param ctx:
+        :param radius: float, radius of the balls
+        :param nsample: int, maximum number of features in the balls
+        :param xyz: (B, N, 3) xyz coordinates of the features
+        :param new_xyz: (B, npoint, 3) centers of the ball query
+        :return:
+            idx: (B, npoint, nsample) tensor with the indices of the features that form the query balls
+        """
+        assert new_xyz.is_contiguous()
+        assert xyz.is_contiguous()
+
+        B, N, _ = xyz.size()
+        npoint = new_xyz.size(1)
+        idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()
+
+        pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx)
+        return idx
+
+    @staticmethod
+    def backward(ctx, a=None):
+        return None, None, None, None
+
+
+ball_query = BallQuery.apply
+
+
+class QueryAndGroup(nn.Module):
+    def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
+        """
+        :param radius: float, radius of ball
+        :param nsample: int, maximum number of features to gather in the ball
+        :param use_xyz:
+        """
+        super().__init__()
+        self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
+
+    def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None) -> Tuple[torch.Tensor]:
+        """
+        :param xyz: (B, N, 3) xyz coordinates of the features
+        :param new_xyz: (B, npoint, 3) centroids
+        :param features: (B, C, N) descriptors of the features
+        :return:
+            new_features: (B, 3 + C, npoint, nsample)
+        """
+        idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
+        xyz_trans = xyz.transpose(1, 2).contiguous()
+        grouped_xyz = grouping_operation(xyz_trans, idx)  # (B, 3, npoint, nsample)
+        grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
+
+        if features is not None:
+            grouped_features = grouping_operation(features, idx)
+            if self.use_xyz:
+                new_features = torch.cat([grouped_xyz, grouped_features], dim=1)  # (B, C + 3, npoint, nsample)
+            else:
+                new_features = grouped_features
+        else:
+            assert self.use_xyz, "use_xyz must be True when no features are provided!"
+            new_features = grouped_xyz
+
+        return new_features
+
+
+class GroupAll(nn.Module):
+    def __init__(self, use_xyz: bool = True):
+        super().__init__()
+        self.use_xyz = use_xyz
+
+    def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None):
+        """
+        :param xyz: (B, N, 3) xyz coordinates of the features
+        :param new_xyz: ignored
+        :param features: (B, C, N) descriptors of the features
+        :return:
+            new_features: (B, C + 3, 1, N)
+        """
+        grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
+        if features is not None:
+            grouped_features = features.unsqueeze(2)
+            if self.use_xyz:
+                new_features = torch.cat([grouped_xyz, grouped_features], dim=1)  # (B, 3 + C, 1, N)
+            else:
+                new_features = grouped_features
+        else:
+            new_features = grouped_xyz
+
+        return new_features
diff --git a/lib/pointnet2/pytorch_utils.py b/lib/pointnet2/pytorch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..09cb7bc76d88dde5757ac70b6e05e1e0c768cc1b
--- /dev/null
+++ b/lib/pointnet2/pytorch_utils.py
@@ -0,0 +1,236 @@
+import torch.nn as nn
+from typing import List, Tuple
+
+
+class SharedMLP(nn.Sequential):
+
+    def __init__(
+            self,
+            args: List[int],
+            *,
+            bn: bool = False,
+            activation=nn.ReLU(inplace=True),
+            preact: bool = False,
+            first: bool = False,
+            name: str = "",
+            instance_norm: bool = False,
+    ):
+        super().__init__()
+
+        for i in range(len(args) - 1):
+            self.add_module(
+                name + 'layer{}'.format(i),
+                Conv2d(
+                    args[i],
+                    args[i + 1],
+                    bn=(not first or not preact or (i != 0)) and bn,
+                    activation=activation
+                    if (not first or not preact or (i != 0)) else None,
+                    preact=preact,
+                    instance_norm=instance_norm
+                )
+            )
+
+
+class _ConvBase(nn.Sequential):
+
+    def __init__(
+            self,
+            in_size,
+            out_size,
+            kernel_size,
+            stride,
+            padding,
+            activation,
+            bn,
+            init,
+            conv=None,
+            batch_norm=None,
+            bias=True,
+            preact=False,
+            name="",
+            instance_norm=False,
+            instance_norm_func=None
+    ):
+        super().__init__()
+
+        bias = bias and (not bn)
+        conv_unit = conv(
+            in_size,
+            out_size,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            bias=bias
+        )
+        init(conv_unit.weight)
+        if bias:
+            nn.init.constant_(conv_unit.bias, 0)
+
+        if bn:
+            if not preact:
+                bn_unit = batch_norm(out_size)
+            else:
+                bn_unit = batch_norm(in_size)
+        if instance_norm:
+            if not preact:
+                in_unit = instance_norm_func(out_size, affine=False, track_running_stats=False)
+            else:
+                in_unit = instance_norm_func(in_size, affine=False, track_running_stats=False)
+
+        if preact:
+            if bn:
+                self.add_module(name + 'bn', bn_unit)
+
+            if activation is not None:
+                self.add_module(name + 'activation', activation)
+
+            if not bn and instance_norm:
+                self.add_module(name + 'in', in_unit)
+
+        self.add_module(name + 'conv', conv_unit)
+
+        if not preact:
+            if bn:
+                self.add_module(name + 'bn', bn_unit)
+
+            if activation is not None:
+                self.add_module(name + 'activation', activation)
+
+            if not bn and instance_norm:
+                self.add_module(name + 'in', in_unit)
+
+
+class _BNBase(nn.Sequential):
+
+    def __init__(self, in_size, batch_norm=None, name=""):
+        super().__init__()
+        self.add_module(name + "bn", batch_norm(in_size))
+
+        nn.init.constant_(self[0].weight, 1.0)
+        nn.init.constant_(self[0].bias, 0)
+
+
+class BatchNorm1d(_BNBase):
+
+    def __init__(self, in_size: int, *, name: str = ""):
+        super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name)
+
+
+class BatchNorm2d(_BNBase):
+
+    def __init__(self, in_size: int, name: str = ""):
+        super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name)
+
+
+class Conv1d(_ConvBase):
+
+    def __init__(
+            self,
+            in_size: int,
+            out_size: int,
+            *,
+            kernel_size: int = 1,
+            stride: int = 1,
+            padding: int = 0,
+            activation=nn.ReLU(inplace=True),
+            bn: bool = False,
+            init=nn.init.kaiming_normal_,
+            bias: bool = True,
+            preact: bool = False,
+            name: str = "",
+            instance_norm=False
+    ):
+        super().__init__(
+            in_size,
+            out_size,
+            kernel_size,
+            stride,
+            padding,
+            activation,
+            bn,
+            init,
+            conv=nn.Conv1d,
+            batch_norm=BatchNorm1d,
+            bias=bias,
+            preact=preact,
+            name=name,
+            instance_norm=instance_norm,
+            instance_norm_func=nn.InstanceNorm1d
+        )
+
+
+class Conv2d(_ConvBase):
+
+    def __init__(
+            self,
+            in_size: int,
+            out_size: int,
+            *,
+            kernel_size: Tuple[int, int] = (1, 1),
+            stride: Tuple[int, int] = (1, 1),
+            padding: Tuple[int, int] = (0, 0),
+            activation=nn.ReLU(inplace=True),
+            bn: bool = False,
+            init=nn.init.kaiming_normal_,
+            bias: bool = True,
+            preact: bool = False,
+            name: str = "",
+            instance_norm=False
+    ):
+        super().__init__(
+            in_size,
+            out_size,
+            kernel_size,
+            stride,
+            padding,
+            activation,
+            bn,
+            init,
+            conv=nn.Conv2d,
+            batch_norm=BatchNorm2d,
+            bias=bias,
+            preact=preact,
+            name=name,
+            instance_norm=instance_norm,
+            instance_norm_func=nn.InstanceNorm2d
+        )
+
+
+class FC(nn.Sequential):
+
+    def __init__(
+            self,
+            in_size: int,
+            out_size: int,
+            *,
+            activation=nn.ReLU(inplace=True),
+            bn: bool = False,
+            init=None,
+            preact: bool = False,
+            name: str = ""
+    ):
+        super().__init__()
+
+        fc = nn.Linear(in_size, out_size, bias=not bn)
+        if init is not None:
+            init(fc.weight)
+        if not bn:
+            nn.init.constant_(fc.bias, 0)
+
+        if preact:
+            if bn:
+                self.add_module(name + 'bn', BatchNorm1d(in_size))
+
+            if activation is not None:
+                self.add_module(name + 'activation', activation)
+
+        self.add_module(name + 'fc', fc)
+
+        if not preact:
+            if bn:
+                self.add_module(name + 'bn', BatchNorm1d(out_size))
+
+            if activation is not None:
+                self.add_module(name + 'activation', activation)
+
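+
+
+# Usage sketch (illustrative): SharedMLP([3, 64, 128], bn=True) stacks two 1x1
+# Conv2d blocks (3 -> 64 -> 128), each followed by BatchNorm2d and ReLU, and is
+# typically applied pointwise to grouped features of shape (B, C, npoint, nsample).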
diff --git a/lib/pointnet2/setup.py b/lib/pointnet2/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..99e59e37b90517cc38c35d100f7f9cee0e309368
--- /dev/null
+++ b/lib/pointnet2/setup.py
@@ -0,0 +1,23 @@
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
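+# Building this CUDA extension requires an nvcc/CUDA toolkit compatible with the
+# installed PyTorch build; it is compiled with `python setup.py install` as noted
+# in the repository README.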
+setup(
+    name='pointnet2',
+    ext_modules=[
+        CUDAExtension('pointnet2_cuda', [
+            'src/pointnet2_api.cpp',
+            
+            'src/ball_query.cpp', 
+            'src/ball_query_gpu.cu',
+            'src/group_points.cpp', 
+            'src/group_points_gpu.cu',
+            'src/interpolate.cpp', 
+            'src/interpolate_gpu.cu',
+            'src/sampling.cpp', 
+            'src/sampling_gpu.cu',
+        ],
+        extra_compile_args={'cxx': ['-g'],
+                            'nvcc': ['-O2']})
+    ],
+    cmdclass={'build_ext': BuildExtension}
+)
diff --git a/lib/pointnet2/src/ball_query.cpp b/lib/pointnet2/src/ball_query.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..c9b176e5da5dd89a3378652f0b806925e8ee8996
--- /dev/null
+++ b/lib/pointnet2/src/ball_query.cpp
@@ -0,0 +1,24 @@
+#include <torch/serialize/tensor.h>
+#include <vector>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAEvent.h>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+#include "ball_query_gpu.h"
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
+#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
+
+int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample, 
+    at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor) {
+    CHECK_INPUT(new_xyz_tensor);
+    CHECK_INPUT(xyz_tensor);
+    const float *new_xyz = new_xyz_tensor.data<float>();
+    const float *xyz = xyz_tensor.data<float>();
+    int *idx = idx_tensor.data<int>();
+    
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+    ball_query_kernel_launcher_fast(b, n, m, radius, nsample, new_xyz, xyz, idx, stream);
+    return 1;
+}
diff --git a/lib/pointnet2/src/ball_query_gpu.cu b/lib/pointnet2/src/ball_query_gpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..f8840aa6650693cea17d337008a15fef13ec1ebc
--- /dev/null
+++ b/lib/pointnet2/src/ball_query_gpu.cu
@@ -0,0 +1,67 @@
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "ball_query_gpu.h"
+#include "cuda_utils.h"
+
+
+__global__ void ball_query_kernel_fast(int b, int n, int m, float radius, int nsample, 
+    const float *__restrict__ new_xyz, const float *__restrict__ xyz, int *__restrict__ idx) {
+    // new_xyz: (B, M, 3)
+    // xyz: (B, N, 3)
+    // output:
+    //      idx: (B, M, nsample)
+    int bs_idx = blockIdx.y;
+    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (bs_idx >= b || pt_idx >= m) return;
+
+    new_xyz += bs_idx * m * 3 + pt_idx * 3;
+    xyz += bs_idx * n * 3;
+    idx += bs_idx * m * nsample + pt_idx * nsample;
+
+    float radius2 = radius * radius;
+    float new_x = new_xyz[0];
+    float new_y = new_xyz[1];
+    float new_z = new_xyz[2];
+
+    int cnt = 0;
+    for (int k = 0; k < n; ++k) {
+        float x = xyz[k * 3 + 0];
+        float y = xyz[k * 3 + 1];
+        float z = xyz[k * 3 + 2];
+        float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
+        if (d2 < radius2){
+            if (cnt == 0){
+                for (int l = 0; l < nsample; ++l) {
+                    idx[l] = k;
+                }
+            }
+            idx[cnt] = k;
+            ++cnt;
+            if (cnt >= nsample) break;
+        }
+    }
+}
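+
+// Note: when fewer than nsample points fall inside the ball, the remaining slots
+// of idx keep the index of the first neighbour found (pre-filled when cnt == 0),
+// so every entry of idx is a valid point index.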
+
+
+void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample, \
+    const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream) {
+    // new_xyz: (B, M, 3)
+    // xyz: (B, N, 3)
+    // output:
+    //      idx: (B, M, nsample)
+
+    cudaError_t err;
+
+    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b);  // blockIdx.x(col), blockIdx.y(row)
+    dim3 threads(THREADS_PER_BLOCK);
+
+    ball_query_kernel_fast<<<blocks, threads, 0, stream>>>(b, n, m, radius, nsample, new_xyz, xyz, idx);
+    // cudaDeviceSynchronize();  // for using printf in kernel function
+    err = cudaGetLastError();
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+}
\ No newline at end of file
diff --git a/lib/pointnet2/src/ball_query_gpu.h b/lib/pointnet2/src/ball_query_gpu.h
new file mode 100644
index 0000000000000000000000000000000000000000..ffc831a8b700f46b50e0b90d49c538aa0fedca50
--- /dev/null
+++ b/lib/pointnet2/src/ball_query_gpu.h
@@ -0,0 +1,15 @@
+#ifndef _BALL_QUERY_GPU_H
+#define _BALL_QUERY_GPU_H
+
+#include <torch/serialize/tensor.h>
+#include <vector>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+
+int ball_query_wrapper_fast(int b, int n, int m, float radius, int nsample, 
+	at::Tensor new_xyz_tensor, at::Tensor xyz_tensor, at::Tensor idx_tensor);
+
+void ball_query_kernel_launcher_fast(int b, int n, int m, float radius, int nsample,
+	const float *new_xyz, const float *xyz, int *idx, cudaStream_t stream);
+
+#endif
diff --git a/lib/pointnet2/src/cuda_utils.h b/lib/pointnet2/src/cuda_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..7fe27969179c976a88199bbe962ca4f8d97263a4
--- /dev/null
+++ b/lib/pointnet2/src/cuda_utils.h
@@ -0,0 +1,15 @@
+#ifndef _CUDA_UTILS_H
+#define _CUDA_UTILS_H
+
+#include <cmath>
+
+#define TOTAL_THREADS 1024
+#define THREADS_PER_BLOCK 256
+#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
+
+inline int opt_n_threads(int work_size) {
+    const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
+
+    return max(min(1 << pow_2, TOTAL_THREADS), 1);
+}
+#endif
diff --git a/lib/pointnet2/src/group_points.cpp b/lib/pointnet2/src/group_points.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..fa80f0e318acc57dabf76ec0a8b1d9dff482ab89
--- /dev/null
+++ b/lib/pointnet2/src/group_points.cpp
@@ -0,0 +1,34 @@
+#include <torch/serialize/tensor.h>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+#include <vector>
+#include "group_points_gpu.h"
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAEvent.h>
+
+
+
+int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample, 
+    at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) {
+
+    float *grad_points = grad_points_tensor.data<float>();
+    const int *idx = idx_tensor.data<int>();
+    const float *grad_out = grad_out_tensor.data<float>();
+
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+    group_points_grad_kernel_launcher_fast(b, c, n, npoints, nsample, grad_out, idx, grad_points, stream);
+    return 1;
+}
+
+
+int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample, 
+    at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor) {
+
+    const float *points = points_tensor.data<float>();
+    const int *idx = idx_tensor.data<int>();
+    float *out = out_tensor.data<float>();
+
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+    group_points_kernel_launcher_fast(b, c, n, npoints, nsample, points, idx, out, stream);
+    return 1;
+}
diff --git a/lib/pointnet2/src/group_points_gpu.cu b/lib/pointnet2/src/group_points_gpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..c015a8125e38aafa1f960000044978463b7853b1
--- /dev/null
+++ b/lib/pointnet2/src/group_points_gpu.cu
@@ -0,0 +1,86 @@
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "cuda_utils.h"
+#include "group_points_gpu.h"
+
+
+__global__ void group_points_grad_kernel_fast(int b, int c, int n, int npoints, int nsample, 
+    const float *__restrict__ grad_out, const int *__restrict__ idx, float *__restrict__ grad_points) {
+    // grad_out: (B, C, npoints, nsample)
+    // idx: (B, npoints, nsample)
+    // output:
+    //      grad_points: (B, C, N)
+    int bs_idx = blockIdx.z;
+    int c_idx = blockIdx.y;
+    int index = blockIdx.x * blockDim.x + threadIdx.x;
+    int pt_idx = index / nsample;
+    if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;
+
+    int sample_idx = index % nsample;
+    grad_out += bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx;
+    idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; 
+    
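+    // the same source index can appear many times in idx, so gradients are
+    // accumulated with atomicAdd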
+    atomicAdd(grad_points + bs_idx * c * n + c_idx * n + idx[0] , grad_out[0]);
+}
+
+void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 
+    const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) {
+    // grad_out: (B, C, npoints, nsample)
+    // idx: (B, npoints, nsample)
+    // output:
+    //      grad_points: (B, C, N)
+    cudaError_t err;
+    dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
+    dim3 threads(THREADS_PER_BLOCK);
+
+    group_points_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, nsample, grad_out, idx, grad_points);
+
+    err = cudaGetLastError();
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+}
+
+
+__global__ void group_points_kernel_fast(int b, int c, int n, int npoints, int nsample, 
+    const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) {
+    // points: (B, C, N)
+    // idx: (B, npoints, nsample)
+    // output:
+    //      out: (B, C, npoints, nsample)
+    int bs_idx = blockIdx.z;
+    int c_idx = blockIdx.y;
+    int index = blockIdx.x * blockDim.x + threadIdx.x;
+    int pt_idx = index / nsample;
+    if (bs_idx >= b || c_idx >= c || pt_idx >= npoints) return;
+
+    int sample_idx = index % nsample;
+
+    idx += bs_idx * npoints * nsample + pt_idx * nsample + sample_idx; 
+    int in_idx = bs_idx * c * n + c_idx * n + idx[0];
+    int out_idx = bs_idx * c * npoints * nsample + c_idx * npoints * nsample + pt_idx * nsample + sample_idx;
+
+    out[out_idx] = points[in_idx];
+}
+
+
+void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 
+    const float *points, const int *idx, float *out, cudaStream_t stream) {
+    // points: (B, C, N)
+    // idx: (B, npoints, nsample)
+    // output:
+    //      out: (B, C, npoints, nsample)
+    cudaError_t err;
+    dim3 blocks(DIVUP(npoints * nsample, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
+    dim3 threads(THREADS_PER_BLOCK);
+
+    group_points_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, nsample, points, idx, out);
+    // cudaDeviceSynchronize();  // for using printf in kernel function
+    err = cudaGetLastError();
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+}
diff --git a/lib/pointnet2/src/group_points_gpu.h b/lib/pointnet2/src/group_points_gpu.h
new file mode 100644
index 0000000000000000000000000000000000000000..76c73ca2600ef75c192b06d28f79a168f1ba368b
--- /dev/null
+++ b/lib/pointnet2/src/group_points_gpu.h
@@ -0,0 +1,22 @@
+#ifndef _GROUP_POINTS_GPU_H
+#define _GROUP_POINTS_GPU_H
+
+#include <torch/serialize/tensor.h>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+#include <vector>
+
+
+int group_points_wrapper_fast(int b, int c, int n, int npoints, int nsample, 
+    at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);
+
+void group_points_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 
+    const float *points, const int *idx, float *out, cudaStream_t stream);
+
+int group_points_grad_wrapper_fast(int b, int c, int n, int npoints, int nsample, 
+    at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor);
+
+void group_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, int nsample, 
+    const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream);
+
+#endif
diff --git a/lib/pointnet2/src/interpolate.cpp b/lib/pointnet2/src/interpolate.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..88d837f966f52696308b7d85ec1756b2395bb986
--- /dev/null
+++ b/lib/pointnet2/src/interpolate.cpp
@@ -0,0 +1,53 @@
+#include <torch/serialize/tensor.h>
+#include <vector>
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+#include "interpolate_gpu.h"
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAEvent.h>
+
+
+void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor, 
+    at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor) {
+    const float *unknown = unknown_tensor.data<float>();
+    const float *known = known_tensor.data<float>();
+    float *dist2 = dist2_tensor.data<float>();
+    int *idx = idx_tensor.data<int>();
+
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+    three_nn_kernel_launcher_fast(b, n, m, unknown, known, dist2, idx, stream);
+}
+
+
+void three_interpolate_wrapper_fast(int b, int c, int m, int n,
+                         at::Tensor points_tensor,
+                         at::Tensor idx_tensor,
+                         at::Tensor weight_tensor,
+                         at::Tensor out_tensor) {
+
+    const float *points = points_tensor.data<float>();
+    const float *weight = weight_tensor.data<float>();
+    float *out = out_tensor.data<float>();
+    const int *idx = idx_tensor.data<int>();
+
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+    three_interpolate_kernel_launcher_fast(b, c, m, n, points, idx, weight, out, stream);
+}
+
+void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m,
+                            at::Tensor grad_out_tensor,
+                            at::Tensor idx_tensor,
+                            at::Tensor weight_tensor,
+                            at::Tensor grad_points_tensor) {
+
+    const float *grad_out = grad_out_tensor.data<float>();
+    const float *weight = weight_tensor.data<float>();
+    float *grad_points = grad_points_tensor.data<float>();
+    const int *idx = idx_tensor.data<int>();
+
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+    three_interpolate_grad_kernel_launcher_fast(b, c, n, m, grad_out, idx, weight, grad_points, stream);
+}
diff --git a/lib/pointnet2/src/interpolate_gpu.cu b/lib/pointnet2/src/interpolate_gpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..a123dd8d8d4f5ed23cc4a340abb1141d140fca3c
--- /dev/null
+++ b/lib/pointnet2/src/interpolate_gpu.cu
@@ -0,0 +1,161 @@
+#include <math.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "cuda_utils.h"
+#include "interpolate_gpu.h"
+
+
+__global__ void three_nn_kernel_fast(int b, int n, int m, const float *__restrict__ unknown, 
+    const float *__restrict__ known, float *__restrict__ dist2, int *__restrict__ idx) {
+    // unknown: (B, N, 3)
+    // known: (B, M, 3)
+    // output: 
+    //      dist2: (B, N, 3)
+    //      idx: (B, N, 3)
+    
+    int bs_idx = blockIdx.y;
+    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (bs_idx >= b || pt_idx >= n) return;
+
+    unknown += bs_idx * n * 3 + pt_idx * 3;
+    known += bs_idx * m * 3;
+    dist2 += bs_idx * n * 3 + pt_idx * 3;
+    idx += bs_idx * n * 3 + pt_idx * 3;
+
+    float ux = unknown[0];
+    float uy = unknown[1];
+    float uz = unknown[2];
+
+    double best1 = 1e40, best2 = 1e40, best3 = 1e40;
+    int besti1 = 0, besti2 = 0, besti3 = 0;
+    for (int k = 0; k < m; ++k) {
+        float x = known[k * 3 + 0];
+        float y = known[k * 3 + 1];
+        float z = known[k * 3 + 2];
+        float d = (ux - x) * (ux - x) + (uy - y) * (uy - y) + (uz - z) * (uz - z);
+        if (d < best1) {
+            best3 = best2; besti3 = besti2;
+            best2 = best1; besti2 = besti1;
+            best1 = d; besti1 = k;
+        } 
+        else if (d < best2) {
+            best3 = best2; besti3 = besti2;
+            best2 = d; besti2 = k;
+        } 
+        else if (d < best3) {
+            best3 = d; besti3 = k;
+        }
+    }
+    dist2[0] = best1; dist2[1] = best2; dist2[2] = best3;
+    idx[0] = besti1; idx[1] = besti2; idx[2] = besti3;
+}
+
+
+void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown, 
+    const float *known, float *dist2, int *idx, cudaStream_t stream) {
+    // unknown: (B, N, 3)
+    // known: (B, M, 3)
+    // output: 
+    //      dist2: (B, N, 3)
+    //      idx: (B, N, 3)
+
+    cudaError_t err;
+    dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), b);  // blockIdx.x(col), blockIdx.y(row)
+    dim3 threads(THREADS_PER_BLOCK);
+
+    three_nn_kernel_fast<<<blocks, threads, 0, stream>>>(b, n, m, unknown, known, dist2, idx);
+
+    err = cudaGetLastError();
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+}
+
+
+__global__ void three_interpolate_kernel_fast(int b, int c, int m, int n, const float *__restrict__ points, 
+    const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ out) {
+    // points: (B, C, M)
+    // idx: (B, N, 3)
+    // weight: (B, N, 3)
+    // output:
+    //      out: (B, C, N)
+
+    int bs_idx = blockIdx.z;
+    int c_idx = blockIdx.y;
+    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;
+
+    weight += bs_idx * n * 3 + pt_idx * 3;
+    points += bs_idx * c * m + c_idx * m;
+    idx += bs_idx * n * 3 + pt_idx * 3;
+    out += bs_idx * c * n + c_idx * n;
+
+    out[pt_idx] = weight[0] * points[idx[0]] + weight[1] * points[idx[1]] + weight[2] * points[idx[2]];
+}
+
+void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n, 
+    const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream) {
+    // points: (B, C, M)
+    // idx: (B, N, 3)
+    // weight: (B, N, 3)
+    // output:
+    //      out: (B, C, N)
+
+    cudaError_t err;
+    dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
+    dim3 threads(THREADS_PER_BLOCK);
+    three_interpolate_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, m, n, points, idx, weight, out);
+
+    err = cudaGetLastError();
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+}
+
+
+__global__ void three_interpolate_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out, 
+    const int *__restrict__ idx, const float *__restrict__ weight, float *__restrict__ grad_points) {
+    // grad_out: (B, C, N)
+    // weight: (B, N, 3)
+    // output:
+    //      grad_points: (B, C, M)
+
+    int bs_idx = blockIdx.z;
+    int c_idx = blockIdx.y;
+    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (bs_idx >= b || c_idx >= c || pt_idx >= n) return;
+    
+    grad_out += bs_idx * c * n + c_idx * n + pt_idx;
+    weight += bs_idx * n * 3 + pt_idx * 3;
+    grad_points += bs_idx * c * m + c_idx * m;
+    idx += bs_idx * n * 3 + pt_idx * 3;
+
+
+    atomicAdd(grad_points + idx[0], grad_out[0] * weight[0]);
+    atomicAdd(grad_points + idx[1], grad_out[0] * weight[1]);
+    atomicAdd(grad_points + idx[2], grad_out[0] * weight[2]);
+}
+
+void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out, 
+    const int *idx, const float *weight, float *grad_points, cudaStream_t stream) {
+    // grad_out: (B, C, N)
+    // weight: (B, N, 3)
+    // output:
+    //      grad_points: (B, C, M)
+
+    cudaError_t err;
+    dim3 blocks(DIVUP(n, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
+    dim3 threads(THREADS_PER_BLOCK);
+    three_interpolate_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, m, grad_out, idx, weight, grad_points);
+
+    err = cudaGetLastError();
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+}
\ No newline at end of file
diff --git a/lib/pointnet2/src/interpolate_gpu.h b/lib/pointnet2/src/interpolate_gpu.h
new file mode 100644
index 0000000000000000000000000000000000000000..f1771087c5e4146e3c5775d3b929ebffffd11ccb
--- /dev/null
+++ b/lib/pointnet2/src/interpolate_gpu.h
@@ -0,0 +1,30 @@
+#ifndef _INTERPOLATE_GPU_H
+#define _INTERPOLATE_GPU_H
+
+#include <torch/serialize/tensor.h>
+#include <vector>
+#include <cuda.h>
+#include <cuda_runtime_api.h>
+
+
+void three_nn_wrapper_fast(int b, int n, int m, at::Tensor unknown_tensor, 
+  at::Tensor known_tensor, at::Tensor dist2_tensor, at::Tensor idx_tensor);
+
+void three_nn_kernel_launcher_fast(int b, int n, int m, const float *unknown,
+	const float *known, float *dist2, int *idx, cudaStream_t stream);
+
+
+void three_interpolate_wrapper_fast(int b, int c, int m, int n, at::Tensor points_tensor, 
+    at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor out_tensor);
+
+void three_interpolate_kernel_launcher_fast(int b, int c, int m, int n, 
+    const float *points, const int *idx, const float *weight, float *out, cudaStream_t stream);
+
+
+void three_interpolate_grad_wrapper_fast(int b, int c, int n, int m, at::Tensor grad_out_tensor, 
+    at::Tensor idx_tensor, at::Tensor weight_tensor, at::Tensor grad_points_tensor);
+
+void three_interpolate_grad_kernel_launcher_fast(int b, int c, int n, int m, const float *grad_out, 
+    const int *idx, const float *weight, float *grad_points, cudaStream_t stream);
+
+#endif
diff --git a/lib/pointnet2/src/pointnet2_api.cpp b/lib/pointnet2/src/pointnet2_api.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d91f0f2176a6080624f071e5535fe509a0ac83c4
--- /dev/null
+++ b/lib/pointnet2/src/pointnet2_api.cpp
@@ -0,0 +1,24 @@
+#include <torch/serialize/tensor.h>
+#include <torch/extension.h>
+
+#include "ball_query_gpu.h"
+#include "group_points_gpu.h"
+#include "sampling_gpu.h"
+#include "interpolate_gpu.h"
+
+
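+// These wrappers are exposed to Python as the `pointnet2_cuda` extension module
+// built by lib/pointnet2/setup.py.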
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("ball_query_wrapper", &ball_query_wrapper_fast, "ball_query_wrapper_fast");
+
+    m.def("group_points_wrapper", &group_points_wrapper_fast, "group_points_wrapper_fast");
+    m.def("group_points_grad_wrapper", &group_points_grad_wrapper_fast, "group_points_grad_wrapper_fast");
+
+    m.def("gather_points_wrapper", &gather_points_wrapper_fast, "gather_points_wrapper_fast");
+    m.def("gather_points_grad_wrapper", &gather_points_grad_wrapper_fast, "gather_points_grad_wrapper_fast");
+
+    m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper, "furthest_point_sampling_wrapper");
+    
+    m.def("three_nn_wrapper", &three_nn_wrapper_fast, "three_nn_wrapper_fast");
+    m.def("three_interpolate_wrapper", &three_interpolate_wrapper_fast, "three_interpolate_wrapper_fast");
+    m.def("three_interpolate_grad_wrapper", &three_interpolate_grad_wrapper_fast, "three_interpolate_grad_wrapper_fast");
+}
diff --git a/lib/pointnet2/src/sampling.cpp b/lib/pointnet2/src/sampling.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..5f54daa763ed66240c17ba6254ee9d5a39b6dfc0
--- /dev/null
+++ b/lib/pointnet2/src/sampling.cpp
@@ -0,0 +1,45 @@
+#include <torch/serialize/tensor.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <vector>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAEvent.h>
+#include "sampling_gpu.h"
+
+
+
+int gather_points_wrapper_fast(int b, int c, int n, int npoints, 
+    at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor){
+    const float *points = points_tensor.data<float>();
+    const int *idx = idx_tensor.data<int>();
+    float *out = out_tensor.data<float>();
+
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+    gather_points_kernel_launcher_fast(b, c, n, npoints, points, idx, out, stream);
+    return 1;
+}
+
+
+int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints, 
+    at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor) {
+
+    const float *grad_out = grad_out_tensor.data<float>();
+    const int *idx = idx_tensor.data<int>();
+    float *grad_points = grad_points_tensor.data<float>();
+
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+    gather_points_grad_kernel_launcher_fast(b, c, n, npoints, grad_out, idx, grad_points, stream);
+    return 1;
+}
+
+
+int furthest_point_sampling_wrapper(int b, int n, int m, 
+    at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) {
+
+    const float *points = points_tensor.data<float>();
+    float *temp = temp_tensor.data<float>();
+    int *idx = idx_tensor.data<int>();
+
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+    furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream);
+    return 1;
+}
diff --git a/lib/pointnet2/src/sampling_gpu.cu b/lib/pointnet2/src/sampling_gpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..9e49a60dd6a80449be4c6c0d0d710be7b5fe9cd5
--- /dev/null
+++ b/lib/pointnet2/src/sampling_gpu.cu
@@ -0,0 +1,253 @@
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "cuda_utils.h"
+#include "sampling_gpu.h"
+
+
+__global__ void gather_points_kernel_fast(int b, int c, int n, int m, 
+    const float *__restrict__ points, const int *__restrict__ idx, float *__restrict__ out) {
+    // points: (B, C, N)
+    // idx: (B, M)
+    // output:
+    //      out: (B, C, M)
+
+    int bs_idx = blockIdx.z;
+    int c_idx = blockIdx.y;
+    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;
+
+    out += bs_idx * c * m + c_idx * m + pt_idx;
+    idx += bs_idx * m + pt_idx;
+    points += bs_idx * c * n + c_idx * n;
+    out[0] = points[idx[0]];
+}
+
+void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints, 
+    const float *points, const int *idx, float *out, cudaStream_t stream) {
+    // points: (B, C, N)
+    // idx: (B, npoints)
+    // output:
+    //      out: (B, C, npoints)
+
+    cudaError_t err;
+    dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
+    dim3 threads(THREADS_PER_BLOCK);
+
+    gather_points_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, points, idx, out);
+
+    err = cudaGetLastError();
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+}
+
+__global__ void gather_points_grad_kernel_fast(int b, int c, int n, int m, const float *__restrict__ grad_out, 
+    const int *__restrict__ idx, float *__restrict__ grad_points) {
+    // grad_out: (B, C, M)
+    // idx: (B, M)
+    // output:
+    //      grad_points: (B, C, N)
+
+    int bs_idx = blockIdx.z;
+    int c_idx = blockIdx.y;
+    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (bs_idx >= b || c_idx >= c || pt_idx >= m) return;
+
+    grad_out += bs_idx * c * m + c_idx * m + pt_idx;
+    idx += bs_idx * m + pt_idx;
+    grad_points += bs_idx * c * n + c_idx * n;
+
+    atomicAdd(grad_points + idx[0], grad_out[0]);
+}
+
+void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, 
+    const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream) {
+    // grad_out: (B, C, npoints)
+    // idx: (B, npoints)
+    // output:
+    //      grad_points: (B, C, N)
+
+    cudaError_t err;
+    dim3 blocks(DIVUP(npoints, THREADS_PER_BLOCK), c, b);  // blockIdx.x(col), blockIdx.y(row)
+    dim3 threads(THREADS_PER_BLOCK);
+
+    gather_points_grad_kernel_fast<<<blocks, threads, 0, stream>>>(b, c, n, npoints, grad_out, idx, grad_points);
+
+    err = cudaGetLastError();
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+}
+
+
+__device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, int idx1, int idx2){
+    const float v1 = dists[idx1], v2 = dists[idx2];
+    const int i1 = dists_i[idx1], i2 = dists_i[idx2];
+    dists[idx1] = max(v1, v2);
+    dists_i[idx1] = v2 > v1 ? i2 : i1;
+}
+
+template <unsigned int block_size>
+__global__ void furthest_point_sampling_kernel(int b, int n, int m, 
+    const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) {
+    // dataset: (B, N, 3)
+    // tmp: (B, N)
+    // output:
+    //      idx: (B, M)
+
+    if (m <= 0) return;
+    __shared__ float dists[block_size];
+    __shared__ int dists_i[block_size];
+
+    int batch_index = blockIdx.x;
+    dataset += batch_index * n * 3;
+    temp += batch_index * n;
+    idxs += batch_index * m;
+
+    int tid = threadIdx.x;
+    const int stride = block_size;
+
+    int old = 0;
+    if (threadIdx.x == 0)
+        idxs[0] = old;
+
+    __syncthreads();
+    for (int j = 1; j < m; j++) {
+    int besti = 0;
+    float best = -1;
+    float x1 = dataset[old * 3 + 0];
+    float y1 = dataset[old * 3 + 1];
+    float z1 = dataset[old * 3 + 2];
+    for (int k = tid; k < n; k += stride) {
+        float x2, y2, z2;
+        x2 = dataset[k * 3 + 0];
+        y2 = dataset[k * 3 + 1];
+        z2 = dataset[k * 3 + 2];
+        // float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
+        // if (mag <= 1e-3)
+        // continue;
+
+        float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
+        float d2 = min(d, temp[k]);
+        temp[k] = d2;
+        besti = d2 > best ? k : besti;
+        best = d2 > best ? d2 : best;
+    }
+    dists[tid] = best;
+    dists_i[tid] = besti;
+    __syncthreads();
+
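+    // Parallel max-reduction over the shared-memory buffers: after these steps
+    // dists_i[0] holds the index of the point farthest from the current sample
+    // set, which becomes the next FPS sample.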
+    if (block_size >= 1024) {
+        if (tid < 512) {
+            __update(dists, dists_i, tid, tid + 512);
+        }
+        __syncthreads();
+    }
+
+    if (block_size >= 512) {
+        if (tid < 256) {
+            __update(dists, dists_i, tid, tid + 256);
+        }
+        __syncthreads();
+    }
+    if (block_size >= 256) {
+        if (tid < 128) {
+            __update(dists, dists_i, tid, tid + 128);
+        }
+        __syncthreads();
+    }
+    if (block_size >= 128) {
+        if (tid < 64) {
+            __update(dists, dists_i, tid, tid + 64);
+        }
+        __syncthreads();
+    }
+    if (block_size >= 64) {
+        if (tid < 32) {
+            __update(dists, dists_i, tid, tid + 32);
+        }
+        __syncthreads();
+    }
+    if (block_size >= 32) {
+        if (tid < 16) {
+            __update(dists, dists_i, tid, tid + 16);
+        }
+        __syncthreads();
+    }
+    if (block_size >= 16) {
+        if (tid < 8) {
+            __update(dists, dists_i, tid, tid + 8);
+        }
+        __syncthreads();
+    }
+    if (block_size >= 8) {
+        if (tid < 4) {
+            __update(dists, dists_i, tid, tid + 4);
+        }
+        __syncthreads();
+    }
+    if (block_size >= 4) {
+        if (tid < 2) {
+            __update(dists, dists_i, tid, tid + 2);
+        }
+        __syncthreads();
+    }
+    if (block_size >= 2) {
+        if (tid < 1) {
+            __update(dists, dists_i, tid, tid + 1);
+        }
+        __syncthreads();
+    }
+
+    old = dists_i[0];
+    if (tid == 0)
+        idxs[j] = old;
+    }
+}
+
+void furthest_point_sampling_kernel_launcher(int b, int n, int m, 
+    const float *dataset, float *temp, int *idxs, cudaStream_t stream) {
+    // dataset: (B, N, 3)
+    // tmp: (B, N)
+    // output:
+    //      idx: (B, M)
+
+    cudaError_t err;
+    unsigned int n_threads = opt_n_threads(n);
+
+    switch (n_threads) {
+        case 1024:
+        furthest_point_sampling_kernel<1024><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+        case 512:
+        furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+        case 256:
+        furthest_point_sampling_kernel<256><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+        case 128:
+        furthest_point_sampling_kernel<128><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+        case 64:
+        furthest_point_sampling_kernel<64><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+        case 32:
+        furthest_point_sampling_kernel<32><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+        case 16:
+        furthest_point_sampling_kernel<16><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+        case 8:
+        furthest_point_sampling_kernel<8><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+        case 4:
+        furthest_point_sampling_kernel<4><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+        case 2:
+        furthest_point_sampling_kernel<2><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+        case 1:
+        furthest_point_sampling_kernel<1><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
+        default:
+        furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
+    }
+
+    err = cudaGetLastError();
+    if (cudaSuccess != err) {
+        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
+        exit(-1);
+    }
+}
diff --git a/lib/pointnet2/src/sampling_gpu.h b/lib/pointnet2/src/sampling_gpu.h
new file mode 100644
index 0000000000000000000000000000000000000000..6200c5914e434ecd2fc3b36313985805f6dbe0cc
--- /dev/null
+++ b/lib/pointnet2/src/sampling_gpu.h
@@ -0,0 +1,29 @@
+#ifndef _SAMPLING_GPU_H
+#define _SAMPLING_GPU_H
+
+#include <torch/serialize/tensor.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <vector>
+
+
+int gather_points_wrapper_fast(int b, int c, int n, int npoints, 
+    at::Tensor points_tensor, at::Tensor idx_tensor, at::Tensor out_tensor);
+
+void gather_points_kernel_launcher_fast(int b, int c, int n, int npoints, 
+    const float *points, const int *idx, float *out, cudaStream_t stream);
+
+
+int gather_points_grad_wrapper_fast(int b, int c, int n, int npoints, 
+    at::Tensor grad_out_tensor, at::Tensor idx_tensor, at::Tensor grad_points_tensor);
+
+void gather_points_grad_kernel_launcher_fast(int b, int c, int n, int npoints, 
+    const float *grad_out, const int *idx, float *grad_points, cudaStream_t stream);
+
+
+int furthest_point_sampling_wrapper(int b, int n, int m, 
+    at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor);
+
+void furthest_point_sampling_kernel_launcher(int b, int n, int m, 
+    const float *dataset, float *temp, int *idxs, cudaStream_t stream);
+
+#endif
diff --git a/model/LLM/__init__.py b/model/LLM/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e8eb9e9325f1906f28a9d60d967ff76963ff1a8
--- /dev/null
+++ b/model/LLM/__init__.py
@@ -0,0 +1 @@
+from . import onellm
\ No newline at end of file
diff --git a/model/LLM/__pycache__/__init__.cpython-310.pyc b/model/LLM/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e70f6416d504770062ceb50661a6094181c47ea2
Binary files /dev/null and b/model/LLM/__pycache__/__init__.cpython-310.pyc differ
diff --git a/model/LLM/__pycache__/__init__.cpython-39.pyc b/model/LLM/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..be815601063517a817e23e64c9a6208e4e66d833
Binary files /dev/null and b/model/LLM/__pycache__/__init__.cpython-39.pyc differ
diff --git a/model/LLM/__pycache__/onellm.cpython-310.pyc b/model/LLM/__pycache__/onellm.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ccf829243e41031a865186ba965b5d98d44174f1
Binary files /dev/null and b/model/LLM/__pycache__/onellm.cpython-310.pyc differ
diff --git a/model/LLM/__pycache__/onellm.cpython-39.pyc b/model/LLM/__pycache__/onellm.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..320f3ae803542ebcea9d3414a83ae4f5e5845455
Binary files /dev/null and b/model/LLM/__pycache__/onellm.cpython-39.pyc differ
diff --git a/model/LLM/onellm.py b/model/LLM/onellm.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a5195737c0448e3d83c3301acbd3fce3bcd0a4e
--- /dev/null
+++ b/model/LLM/onellm.py
@@ -0,0 +1,495 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the GNU General Public License version 3.
+
+from typing import Optional, Tuple
+from dataclasses import dataclass
+import math
+import functools
+import copy
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+import fairscale.nn.model_parallel.initialize as fs_init
+from fairscale.nn.model_parallel.layers import (
+    ParallelEmbedding,
+    RowParallelLinear,
+    ColumnParallelLinear,
+)
+from ..components import RMSNorm
+from flash_attn import flash_attn_func
+
+import open_clip
+
+
+default_linear_init = nn.init.xavier_uniform_
+
+
+@dataclass
+class ModelArgs:
+    dim: int = 512
+    n_layers: int = 8
+    n_heads: int = 8
+    vocab_size: int = -1  # defined later by tokenizer
+    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
+    norm_eps: float = 1e-5
+
+    max_batch_size: int = 32
+    max_seq_len: int = 2048
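+
+    # The defaults above are placeholders for a small model; in practice they are
+    # overridden by the JSON config supplied at load time (e.g. a LLaMA-2-7B style
+    # config with dim=4096, n_layers=32, n_heads=32).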
+
+
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)
+                   [: (dim // 2)].float() / dim))
+    t = torch.arange(end, device=freqs.device)  # type: ignore
+    freqs = torch.outer(t, freqs).float()  # type: ignore
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+    return freqs_cis
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+    ndim = x.ndim
+    assert 0 <= 1 < ndim
+    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+    shape = [d if i == 1 or i == ndim - 1
+             else 1 for i, d in enumerate(x.shape)]
+    return freqs_cis.view(*shape)
+
+
+def apply_rotary_emb(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+    return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+class Attention(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+
+        self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()
+        self.head_dim = args.dim // args.n_heads
+
+        self.wq = ColumnParallelLinear(
+            args.dim,
+            args.n_heads * self.head_dim,
+            bias=False,
+            gather_output=False,
+            init_method=default_linear_init,
+        )
+        self.wk = ColumnParallelLinear(
+            args.dim,
+            args.n_heads * self.head_dim,
+            bias=False,
+            gather_output=False,
+            init_method=default_linear_init,
+        )
+        self.wv = ColumnParallelLinear(
+            args.dim,
+            args.n_heads * self.head_dim,
+            bias=False,
+            gather_output=False,
+            init_method=default_linear_init,
+        )
+        self.wo = RowParallelLinear(
+            args.n_heads * self.head_dim,
+            args.dim,
+            bias=False,
+            input_is_parallel=True,
+            init_method=default_linear_init,
+        )
+
+        self.flash = True
+        self.k_cache, self.v_cache = None, None
+
+    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], prompt=None):
+        bsz, seqlen, _ = x.shape
+        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+
+        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+        xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+        xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+
+        if freqs_cis is not None:
+            xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+
+        if self.k_cache is None or self.v_cache is None:
+            keys, values = xk, xv
+        else:
+            self.k_cache = self.k_cache.to(xk)
+            self.v_cache = self.v_cache.to(xv)
+            self.k_cache[:bsz, start_pos: start_pos + seqlen, :, :] = xk
+            self.v_cache[:bsz, start_pos: start_pos + seqlen, :, :] = xv
+            keys = self.k_cache[:bsz, :start_pos + seqlen]
+            values = self.v_cache[:bsz, :start_pos + seqlen]
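+            # incremental decoding: the new keys/values are written into the cache
+            # at start_pos and attention then runs over all cached positions so far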
+
+        output = flash_attn_func(
+            xq, keys, values, dropout_p=0.0, causal=mask is not None)
+        output = output.contiguous().view(bsz, seqlen, -1)
+
+        return self.wo(output)
+
+    def allocate_kv_cache(self, max_batch_size: int, max_seq_len: int) -> None:
+        kv_cache_shape = (max_batch_size, max_seq_len,
+                          self.n_local_heads, self.head_dim)
+        if self.k_cache is None or self.k_cache.size() != kv_cache_shape:
+            self.k_cache = torch.empty(kv_cache_shape)
+        if self.v_cache is None or self.v_cache.size() != kv_cache_shape:
+            self.v_cache = torch.empty(kv_cache_shape)
+
+    def destroy_kv_cache(self) -> None:
+        self.k_cache, self.v_cache = None, None
+
+
+class FeedForward(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        multiple_of: int,
+    ):
+        super().__init__()
+        hidden_dim = int(2 * hidden_dim / 3)
+        hidden_dim = multiple_of * \
+            ((hidden_dim + multiple_of - 1) // multiple_of)
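+        # hidden_dim above follows LLaMA's SwiGLU sizing: 2/3 of the 4*dim input,
+        # rounded up to a multiple of `multiple_of`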
+
+        self.w1 = ColumnParallelLinear(
+            dim, hidden_dim, bias=False, gather_output=False, init_method=default_linear_init,
+        )
+        self.w2 = RowParallelLinear(
+            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=default_linear_init
+        )
+        self.w3 = ColumnParallelLinear(
+            dim, hidden_dim, bias=False, gather_output=False, init_method=default_linear_init
+        )
+
+    def _silu_gating(self, x, y):
+        return F.silu(x) * y
+
+    def forward(self, x):
+        return self.w2(self._silu_gating(self.w1(x), self.w3(x)))
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, layer_id: int, args: ModelArgs):
+        super().__init__()
+        self.n_heads = args.n_heads
+        self.dim = args.dim
+        self.head_dim = args.dim // args.n_heads
+        self.attention = Attention(args)
+        self.feed_forward = FeedForward(
+            dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of
+        )
+        self.layer_id = layer_id
+        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+
+    def _forward_ffn(self, h):
+        return h + self.feed_forward(self.ffn_norm(h))
+
+    def _forward_attention(self, x, start_pos, freqs_cis, mask, prompt):
+        return x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask, prompt)
+
+    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], prompt=None):
+        h = self._forward_attention(x, start_pos, freqs_cis, mask, prompt)
+        out = self._forward_ffn(h)
+        return out
+
+
+class Mlp(nn.Module):
+    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+    """
+
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.fc2(x)
+        return x
+
+
+class Transformer(nn.Module):
+    def __init__(self, params: ModelArgs):
+        super().__init__()
+        self.params = params
+        self.vocab_size = params.vocab_size
+        self.n_layers = params.n_layers
+        self.tok_embeddings = ParallelEmbedding(
+            params.vocab_size, params.dim, init_method=nn.init.normal_,
+        )
+
+        self.layers = torch.nn.ModuleList()
+        for layer_id in range(params.n_layers):
+            self.layers.append(TransformerBlock(layer_id, params))
+
+        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
+        self.output = ColumnParallelLinear(
+            params.dim, params.vocab_size, bias=False, init_method=default_linear_init,
+        )
+
+        self.freqs_cis = precompute_freqs_cis(
+            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
+        )
+
+        # load clip
+        self.clip, _, _ = open_clip.create_model_and_transforms(
+            'ViT-L-14', pretrained='openai')
+        for param in self.clip.parameters():
+            param.requires_grad = False
+            param.data = param.data.half()
+        self.clip.transformer = None
+
+        self.image_words = 30
+        self.cache_image_words = 0  # for inference
+
+        clip_width = self.clip.visual.conv1.out_channels
+        # create modal shared modules
+        self.resample_layers = nn.ModuleDict()
+        self.num_experts = 3
+        self.num_resample_layers = 8
+        for expert in range(self.num_experts):
+            expert = str(expert)
+            self.resample_layers[expert] = nn.ModuleList()
+            resampler_params = copy.deepcopy(params)
+            resampler_params.n_heads = 16
+            resampler_params.dim = clip_width
+            for layer_id in range(self.num_resample_layers):
+                self.resample_layers[expert].append(
+                    TransformerBlock(layer_id, resampler_params))
+
+        self.conv1 = nn.ModuleDict()
+        self.positional_embedding = nn.ParameterDict()
+        self.resample_tokens = nn.ParameterDict()
+        self.clip_proj1 = nn.ModuleDict()
+        self.clip_proj2 = nn.ModuleDict()
+        self.routers = nn.ModuleDict()
+        self.start_tag = nn.ParameterDict()
+        self.end_tag = nn.ParameterDict()
+        # self.modals = ['image', 'audio', 'point', 'video', 'rgbd', 'rgbn', 'fmri', 'imu']
+        self.modals = ['image', 'audio', 'video', 'rgbd', 'rgbn', 'fmri', 'imu']
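+        # modality-specific "tokenizers": image/video/rgbd/rgbn reuse CLIP's conv1,
+        # while audio/point/fmri/imu get their own patch embedding and positional
+        # embedding below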
+        for modal in self.modals:
+            if modal in ['image', 'video', 'rgbd', 'rgbn']:
+                modal_tokens = 256 + 1
+                pass
+            elif modal == 'audio':
+                self.conv1[modal] = nn.Conv2d(
+                    1, clip_width, kernel_size=(16, 16), stride=(10, 10))
+                modal_tokens = 1212 + 1
+                self.positional_embedding[modal] = nn.Parameter(
+                    torch.empty([modal_tokens, clip_width]))
+                nn.init.normal_(self.positional_embedding[modal], std=0.02)
+            elif modal == 'point':
+                from lib.point_utils import PointPatchEmbed
+                self.conv1[modal] = PointPatchEmbed(
+                    in_channels=6, channels=clip_width)
+                modal_tokens = 1024 + 1
+                self.positional_embedding[modal] = nn.Parameter(
+                    torch.empty([modal_tokens, clip_width]))
+                nn.init.normal_(self.positional_embedding[modal], std=0.02)
+            elif modal == 'fmri':
+                self.conv1[modal] = nn.Linear(15724, 8192)
+                self.positional_embedding[modal] = nn.Parameter(
+                    torch.empty([8+1, clip_width]))
+                nn.init.normal_(self.positional_embedding[modal], std=0.02)
+            elif modal == 'imu':
+                self.conv1[modal] = nn.Conv1d(
+                    in_channels=6, out_channels=clip_width, kernel_size=10, bias=False)
+                self.positional_embedding[modal] = nn.Parameter(
+                    torch.empty([391+1, clip_width]))
+                nn.init.normal_(self.positional_embedding[modal], std=0.02)
+
+            self.routers[modal] = Mlp(
+                clip_width, clip_width * 4, self.num_experts)
+
+            self.resample_tokens[modal] = nn.Parameter(
+                torch.empty([1, 30, resampler_params.dim]))
+            nn.init.normal_(self.resample_tokens[modal], std=0.02)
+
+            self.clip_proj1[modal] = nn.Sequential(
+                nn.Linear(clip_width, resampler_params.dim),
+                nn.LayerNorm(resampler_params.dim))
+
+            self.clip_proj2[modal] = nn.Sequential(
+                nn.Linear(resampler_params.dim, params.dim),
+                nn.LayerNorm(params.dim))
+
+            self.start_tag[modal] = nn.Parameter(torch.rand(1, 1, params.dim))
+            self.end_tag[modal] = nn.Parameter(torch.rand(1, 1, params.dim))
+
+    # @torch.no_grad()
+
+    def clip_encode_image(self, x, modal='image'):
+        # shape = [*, width, grid ** 2]
+        x = x.reshape(x.shape[0], x.shape[1], -1)
+        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+
+        x = torch.cat([self.clip.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1,
+                      x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
+
+        # image-like modalities reuse CLIP's pretrained positional embedding; the rest use their own
+        pos_embedding = self.clip.visual.positional_embedding
+        if modal in ['audio', 'point', 'fmri', 'imu']:
+            pos_embedding = self.positional_embedding[modal]
+
+        x = x + pos_embedding.to(x.dtype)
+        x = self.clip.visual.ln_pre(x)
+
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.clip.visual.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+
+        # preserve all spatial tokens
+        x = self.clip.visual.ln_post(x[:, :, :])
+
+        # if self.clip.visual.proj is not None:
+        #    x = x @ self.clip.visual.proj
+
+        return x
+
+    def encode_image(self, x, modal='image'):
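+        """
+        Map an input of any supported modality to LLM token embeddings.
+
+        Pipeline: modality-specific stem -> CLIP ViT (clip_encode_image) ->
+        mean over the time dimension for multi-frame inputs -> clip_proj1 ->
+        prepend learnable resample tokens -> router-weighted mixture of
+        resampler experts -> keep the resample-token positions -> clip_proj2
+        to the LLM embedding dim, giving [B, 30, params.dim].
+        """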
+        bsz = x.size(0)
+        T = 1
+        if modal in ['image']:
+            # modified from CLIP
+            x = self.clip.visual.conv1(x)  # shape = [*, width, grid, grid]
+        elif modal in ['audio', 'imu']:
+            x = self.conv1[modal](x)
+        elif modal == 'point':
+            # [B, 16384, 6] -> [B, 1024, 1024, 1]
+            x = self.conv1[modal](x.float()).to(x.dtype)
+        elif modal in ['video', 'rgbd', 'rgbn']:
+            # [B, 15, 3, 224, 224]
+            B, T = x.shape[:2]
+            bsz = B * T
+            x = x.reshape(bsz, *x.shape[2:])
+            x = self.clip.visual.conv1(x)
+        elif modal == 'fmri':
+            x = self.conv1[modal](x)
+            # [B, 1, 8192] -> [B, 1024, 8]
+            x = x.reshape(x.size(0), self.clip.visual.conv1.out_channels, -1)
+
+        image_feats = self.clip_encode_image(x, modal=modal)
+        # take mean on time dimension
+        # all inputs are reduced to [B, L, D]
+        bsz = int(bsz / T)
+        image_feats = image_feats.reshape(
+            bsz, T, *image_feats.shape[1:]).mean(dim=1)
+
+        image_feats = self.clip_proj1[modal](image_feats)
+        image_feats = torch.cat(
+            [self.resample_tokens[modal].repeat(bsz, 1, 1), image_feats], dim=1)
+
+        # route this modality's tokens across the resampler experts
+        # [B, L, D]->[B, L, N]
+        routing_weights = self.routers[modal](image_feats).sigmoid()
+        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
+
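+        # soft mixture of resampler experts: every expert processes the full
+        # sequence, only the leading resample-token positions are kept, and the
+        # expert outputs are combined with the normalized router weights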
+        image_feats_experts = []
+        for expert_id in range(self.num_experts):
+            image_feats_expert = image_feats
+            for layer in self.resample_layers[str(expert_id)]:
+                image_feats_expert = layer(image_feats_expert, 0, None, None)
+
+            image_feats_expert = image_feats_expert[:, :self.resample_tokens[modal].size(1)]
+            routing_weight = routing_weights[:, :self.resample_tokens[modal].size(
+                1), expert_id]
+            # [B, L, D] * [B, L, 1]
+            image_feats_expert = image_feats_expert * routing_weight[:, :, None]
+
+            image_feats_experts.append(image_feats_expert)
+
+        image_feats = sum(image_feats_experts)
+        image_feats = self.clip_proj2[modal](image_feats)
+
+        return image_feats
+
+    def forward(self, examples, image=None, modal='image'):
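+        """
+        Training forward pass. When a multimodal input is given, its encoded
+        tokens, wrapped in learnable start/end tags, are spliced in right after
+        BOS, and logits are returned only for positions after that prefix so
+        the loss is computed on the caption part of the sequence.
+        """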
+        self._destroy_kv_cache()  # training always disables kv cache
+        modal = modal[0] if isinstance(modal, list) else modal
+        _bsz, seqlen = examples.shape
+        h = self.tok_embeddings(examples)
+        self.freqs_cis = self.freqs_cis.to(h.device)
+
+        start_pos = 0
+        prefix_len = 0
+        if image is not None:
+            h_bos, h_caption = h[:, :1], h[:, 1:]
+            image_tokens = self.encode_image(image, modal)
+            h = torch.cat((h_bos, self.start_tag[modal].expand(
+                _bsz, -1, -1), image_tokens, self.end_tag[modal].expand(_bsz, -1, -1), h_caption), dim=1)
+            # the prefix (bos + start_tag[modal] + image tokens) is skipped below;
+            # end_tag[modal] and the caption tokens are used for caption generation
+            prefix_len = image_tokens.shape[1] + 1 + 1
+            seqlen = h.shape[1]
+
+        freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen]
+        mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=h.device)
+        mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
+        for layer in self.layers:
+            h = layer(h, start_pos, freqs_cis, mask)
+        h = self.norm(h)
+        output = self.output(h[:, prefix_len:, :])
+        return output
+
+    @torch.inference_mode()
+    def forward_inference(self, tokens: torch.Tensor, start_pos: int, image=None, modal='image'):
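+        """
+        Incremental decoding step with a per-layer kv cache. On the first call
+        (start_pos == 0) the cache is (re)allocated and, if present, the
+        multimodal tokens are injected right after BOS; on later calls
+        start_pos is shifted by the number of cached multimodal tokens. Only
+        the logits of the last position are returned.
+        """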
+        modal = modal[0] if isinstance(modal, list) else modal
+        _bsz, seqlen = tokens.shape
+        if start_pos == 0:
+            # kv cache will not re-allocate if size is unchanged
+            self._allocate_kv_cache(_bsz)
+        h = self.tok_embeddings(tokens)
+        self.freqs_cis = self.freqs_cis.to(h.device)
+
+        if image is not None:
+            h_bos, h_caption = h[:, :1], h[:, 1:]
+            image_tokens = self.encode_image(image, modal)
+            self.cache_image_words = image_tokens.shape[1]
+            h = torch.cat((h_bos, self.start_tag[modal].repeat(_bsz, 1, 1), image_tokens, self.end_tag[modal].repeat(_bsz, 1, 1), h_caption), dim=1)
+            seqlen = h.shape[1]
+            freqs_cis = self.freqs_cis[0: seqlen]
+        else:
+            if start_pos == 0:
+                self.cache_image_words = 0
+                freqs_cis = self.freqs_cis[0: seqlen]
+            else:
+                # if an image was passed at start_pos=0, its cached tokens shift
+                # the positions used by all later forward_inference calls
+                start_pos = start_pos + self.cache_image_words
+                freqs_cis = self.freqs_cis[start_pos: start_pos + seqlen]
+
+        # freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
+
+        mask = None
+        if seqlen > 1:
+            mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
+            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
+
+        for layer in self.layers:
+            h = layer(h, start_pos, freqs_cis, mask)
+        h = self.norm(h)
+        output = self.output(h[:, -1, :])  # only compute last logits
+        return output.float()
+
+    def _allocate_kv_cache(self, max_batch_size: int) -> None:
+        for layer in self.layers:
+            layer.attention.allocate_kv_cache(
+                max_batch_size, self.params.max_seq_len)
+
+    def _destroy_kv_cache(self) -> None:
+        for layer in self.layers:
+            layer.attention.destroy_kv_cache()
diff --git a/model/__init__.py b/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/model/components.py b/model/components.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c8bc4e88484950988aaad4faf6d34ec1a4ec8bf
--- /dev/null
+++ b/model/components.py
@@ -0,0 +1,57 @@
+import warnings
+import torch
+import torch.nn as nn
+
+try:
+    from apex.normalization import FusedRMSNorm as RMSNorm
+except ImportError:
+    warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")
+
+    class RMSNorm(torch.nn.Module):
+        def __init__(self, dim: int, eps: float = 1e-6):
+            """
+            Initialize the RMSNorm normalization layer.
+
+            Args:
+                dim (int): The dimension of the input tensor.
+                eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+            Attributes:
+                eps (float): A small value added to the denominator for numerical stability.
+                weight (nn.Parameter): Learnable scaling parameter.
+
+            """
+            super().__init__()
+            self.eps = eps
+            self.weight = nn.Parameter(torch.ones(dim))
+
+        def _norm(self, x):
+            """
+            Apply the RMSNorm normalization to the input tensor.
+
+            Args:
+                x (torch.Tensor): The input tensor.
+
+            Returns:
+                torch.Tensor: The normalized tensor.
+
+            """
+            return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+        def forward(self, x):
+            """
+            Forward pass through the RMSNorm layer.
+
+            Args:
+                x (torch.Tensor): The input tensor.
+
+            Returns:
+                torch.Tensor: The output tensor after applying RMSNorm.
+
+            """
+            output = self._norm(x.float()).type_as(x)
+            return output * self.weight
+
+
+
+
diff --git a/model/meta.py b/model/meta.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0ab6daaa14337633f9d3261d78248683d04c930
--- /dev/null
+++ b/model/meta.py
@@ -0,0 +1,175 @@
+from typing import List
+import torch
+import torch.nn as nn
+import json
+import os
+from .tokenizer import Tokenizer
+from . import LLM
+
+from fairscale.nn.model_parallel import initialize as fs_init
+
+
+class MetaModel(nn.Module):
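+    """
+    Wrapper that builds the tokenizer and the underlying multimodal LLaMA
+    transformer (`self.llma`), optionally loading a consolidated checkpoint,
+    and exposes the training loss, batched `generate` and `stream_generate`.
+
+    Illustrative usage sketch (assumes fairscale model parallelism is already
+    initialized; the type name and paths are placeholders for your setup):
+
+        model = MetaModel(llama_type="onellm",        # must name a module under model.LLM
+                          llama_config="<llama_config.json>",
+                          llama_ckpt_dir="<weights_dir>",
+                          tokenizer_path="<tokenizer.model>")
+        captions = model.generate(["Describe the image."], images,
+                                  max_gen_len=128, modal=['image'])
+    """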
+
+    def __init__(self, llama_type, llama_config, llama_ckpt_dir=None, tokenizer_path=None):
+        super().__init__()
+
+        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
+
+        ModelArgs = LLM.__dict__[llama_type].ModelArgs
+        Transformer = LLM.__dict__[llama_type].Transformer
+
+        with open(llama_config, "r") as f:
+            params = json.loads(f.read())
+        model_args: ModelArgs = ModelArgs(
+            max_seq_len=2048, max_batch_size=32, **params
+        )
+        self.tokenizer = Tokenizer(model_path=tokenizer_path)
+        model_args.vocab_size = self.tokenizer.n_words
+
+        model = Transformer(model_args)
+        mp_rank = fs_init.get_model_parallel_rank()
+        if llama_ckpt_dir is not None:
+            ckpt_path = os.path.join(llama_ckpt_dir, f"consolidated.{mp_rank:02d}.pth")
+            if os.path.exists(ckpt_path):
+                checkpoint = torch.load(ckpt_path, map_location="cpu")
+                msg = model.load_state_dict(checkpoint, strict=False)
+                print(msg)
+            else:
+                print(f'Checkpoint not found at {ckpt_path}')
+        self.llma = model
+        for name, param in self.named_parameters():
+            if param.requires_grad:
+                print(f"Trainable param: {name}, {param.shape}, {param.dtype}")
+        count = sum(p.numel() for p in self.parameters() if p.requires_grad)
+        print(f"Trainable parameter count: {count}")
+
+    def forward(self, examples, labels, image=None, modal='image'):
+        output = self.llma(examples, image=image, modal=modal)
+        output = output[:, :-1, :]
+        labels = labels[:, 1:]
+
+        if labels.sum() == 0:
+            c_loss = output.mean() * 0
+        else:
+            c_loss = self.criterion(output.reshape(-1, self.tokenizer.n_words), labels.flatten())
+
+        return c_loss
+
+    def generate(
+        self,
+        prompts: List[str],
+        images,
+        max_gen_len: int,
+        temperature: float = 0.8,
+        top_p: float = 0.95,
+        modal = ['image'],
+    ) -> List[str]:
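+        """
+        Batched generation: prompts are tokenized with BOS, packed into a
+        padded tensor, and decoded step by step with `forward_inference`; the
+        multimodal input is fed only on the first step, after which its tokens
+        live in the kv cache. Uses top-p sampling when temperature > 0 and
+        greedy decoding otherwise.
+        """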
+        bsz = len(prompts)
+        params = self.llma.params
+        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
+
+        prompt_tokens = [self.tokenizer.encode(
+            x, bos=True, eos=False) for x in prompts]
+
+        min_prompt_size = min([len(t) for t in prompt_tokens])
+        max_prompt_size = max([len(t) for t in prompt_tokens])
+
+        total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
+
+        tokens = torch.full(
+            (bsz, total_len), self.tokenizer.pad_id).cuda().long()
+        for k, t in enumerate(prompt_tokens):
+            tokens[k, : len(t)] = torch.tensor(t).long()
+        input_text_mask = tokens != self.tokenizer.pad_id
+        start_pos = min_prompt_size
+        prev_pos = 0
+        for cur_pos in range(start_pos, total_len):
+            logits = self.llma.forward_inference(tokens[:, prev_pos:cur_pos], prev_pos, images if prev_pos == 0 else None, modal=modal)
+            if temperature > 0:
+                probs = torch.softmax(logits / temperature, dim=-1)
+                next_token = self.sample_top_p(probs, top_p)
+            else:
+                next_token = torch.argmax(logits, dim=-1)
+            next_token = next_token.reshape(-1)
+            # only replace token if prompt has already been generated
+            next_token = torch.where(
+                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
+            )
+            tokens[:, cur_pos] = next_token
+            prev_pos = cur_pos
+
+        decoded = []
+        for i, t in enumerate(tokens.tolist()):
+            # cut to max gen len
+            t = t[: len(prompt_tokens[i]) + max_gen_len]
+            # cut to eos tok if any
+            try:
+                t = t[: t.index(self.tokenizer.eos_id)]
+            except ValueError:
+                pass
+            decoded.append(self.tokenizer.decode(t))
+        return decoded
+    
+    @torch.inference_mode()
+    def stream_generate(
+        self,
+        prompt: str,
+        images,
+        max_gen_len: int,
+        temperature: float = 0.8,
+        top_p: float = 0.95,
+        modal = ['image'],
+    ):
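+        """
+        Streaming variant of `generate` for a single prompt: the prompt is
+        truncated from the left to leave room for generation (and for the
+        multimodal tokens), and the text decoded so far is yielded after every
+        step as {"text": ..., "end_of_content": bool}, stopping at EOS.
+        """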
+        params = self.llma.params
+
+        prompt_tokens = self.tokenizer.encode(prompt, bos=True, eos=False)
+        # truncate from the left. leave some space for generation.
+        max_seq_len = params.max_seq_len
+        if images is not None:
+            max_seq_len -= self.llma.image_words
+
+        max_prompt_size = max_seq_len - max_gen_len
+        prompt_tokens = prompt_tokens[-max_prompt_size:]
+
+        prompt_size = len(prompt_tokens)
+
+        total_len = min(max_seq_len, max_gen_len + prompt_size)
+
+        tokens = torch.full([total_len], 0).cuda().long()
+
+        tokens[:len(prompt_tokens)] = torch.tensor(prompt_tokens).long()
+        start_pos = prompt_size
+        prev_pos = 0
+        generate_until = start_pos
+        for cur_pos in range(start_pos, total_len):
+            logits = self.llma.forward_inference(tokens[None, prev_pos:cur_pos], prev_pos, images if prev_pos == 0 else None, modal = modal)
+            if temperature > 0:
+                probs = torch.softmax(logits / temperature, dim=-1)
+                next_token = self.sample_top_p(probs, top_p)
+            else:
+                next_token = torch.argmax(logits, dim=-1)
+            next_token = next_token.item()
+
+            if next_token == self.tokenizer.eos_id:
+                break
+
+            tokens[cur_pos] = next_token
+            prev_pos = cur_pos
+            generate_until = cur_pos + 1
+            yield {"text": self.tokenizer.decode(tokens[start_pos:generate_until].tolist()), "end_of_content": False}
+
+        yield {"text": self.tokenizer.decode(tokens[start_pos:generate_until].tolist()), "end_of_content": True}
+
+    def sample_top_p(self, probs, p):
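+        # nucleus (top-p) sampling: sort probabilities in descending order,
+        # zero out tokens outside the smallest prefix whose cumulative mass
+        # exceeds p, renormalize, sample, and map back to the original ids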
+        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+        probs_sum = torch.cumsum(probs_sort, dim=-1)
+        mask = probs_sum - probs_sort > p
+        probs_sort[mask] = 0.0
+        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+        next_token = torch.multinomial(probs_sort, num_samples=1)
+        next_token = torch.gather(probs_idx, -1, next_token)
+        return next_token
+
+    def get_image_words(self):
+        return self.llma.image_words
\ No newline at end of file
diff --git a/model/tokenizer.py b/model/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4315856eea5c4318499c8909898252902252f30
--- /dev/null
+++ b/model/tokenizer.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the GNU General Public License version 3.
+
+from sentencepiece import SentencePieceProcessor
+from logging import getLogger
+from typing import List
+import os
+
+
+logger = getLogger()
+
+
+class Tokenizer:
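+    """Thin wrapper around a SentencePiece model that exposes the vocab size,
+    BOS/EOS/pad ids, and encode/decode helpers."""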
+    def __init__(self, model_path: str):
+        # reload tokenizer
+        assert os.path.isfile(model_path), model_path
+        self.sp_model = SentencePieceProcessor(model_file=model_path)
+        logger.info(f"Reloaded SentencePiece model from {model_path}")
+
+        # BOS / EOS token IDs
+        self.n_words: int = self.sp_model.vocab_size()
+        self.bos_id: int = self.sp_model.bos_id()
+        self.eos_id: int = self.sp_model.eos_id()
+        self.pad_id: int = self.sp_model.pad_id()
+        logger.info(
+            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
+        )
+        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
+
+    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
+        assert type(s) is str
+        t = self.sp_model.encode(s)
+        if bos:
+            t = [self.bos_id] + t
+        if eos:
+            t = t + [self.eos_id]
+        return t
+
+    def decode(self, t: List[int]) -> str:
+        return self.sp_model.decode(t)
diff --git a/requirements.txt b/requirements.txt
index ff5a872a773e1619013dc49c7be53ad722943b40..ce74fdd1f2242fc4e7bc50f084f7030081836fa5 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,13 @@
---extra-index-url https://download.pytorch.org/whl/cu113
-torch==1.12.0+cu113
+--extra-index-url https://download.pytorch.org/whl/cu117
+torch==2.0.0+cu117
+packaging
 fairscale
 sentencepiece
 Pillow
 huggingface_hub
-git+https://github.com/csuhan/timm_0_3_2.git
-git+https://github.com/openai/CLIP.git
\ No newline at end of file
+open_clip_torch
+pytorchvideo==0.1.5
+torchaudio
+matplotlib
+flash-attn
+gradio
\ No newline at end of file
diff --git a/util/lr_sched.py b/util/lr_sched.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc4624f4fb441ea7e37e50857813cb149887a0c0
--- /dev/null
+++ b/util/lr_sched.py
@@ -0,0 +1,42 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+def adjust_learning_rate(optimizer, it, args):
+    """Decay the learning rate with half-cycle cosine after warmup"""
+    if it < args.warmup_iters: # 1) linear warmup for warmup_iters steps
+        lr = args.lr * it / args.warmup_iters
+    elif it > args.lr_decay_iters: # 2) if it > lr_decay_iters, return min learning rate
+        lr = args.min_lr
+    else: # 3) in between, use cosine decay down to min learning rate
+        decay_ratio = (it - args.warmup_iters) / (args.lr_decay_iters - args.warmup_iters)
+        assert 0 <= decay_ratio <= 1
+        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
+        lr = args.min_lr + (args.lr - args.min_lr) * coeff
+
+    for param_group in optimizer.param_groups:
+        if "lr_scale" in param_group:
+            param_group["lr"] = lr * param_group["lr_scale"]
+        else:
+            param_group["lr"] = lr
+    return lr
+
+
+def adjust_learning_rate_epoch(optimizer, epoch, args):
+    """Decay the learning rate with half-cycle cosine after warmup"""
+    if epoch < args.warmup_epochs:
+        lr = args.lr * epoch / args.warmup_epochs
+    else:
+        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
+            (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
+    for param_group in optimizer.param_groups:
+        if "lr_scale" in param_group:
+            param_group["lr"] = lr * param_group["lr_scale"]
+        else:
+            param_group["lr"] = lr
+    return lr
+
diff --git a/util/misc.py b/util/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..cea0d87e40afd9b8be34ef99da7b1409cb1e43ba
--- /dev/null
+++ b/util/misc.py
@@ -0,0 +1,516 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# BEiT: https://github.com/microsoft/unilm/tree/master/beit
+# --------------------------------------------------------
+
+import builtins
+import datetime
+import os
+import glob
+import time
+from collections import defaultdict, deque
+from pathlib import Path
+import subprocess
+
+import torch
+import torch.distributed as dist
+from torch import inf
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from torch.distributed.fsdp import (
+    FullyShardedDataParallel as FSDP,
+    StateDictType,
+    FullStateDictConfig,
+)
+from torch.distributed._shard.api import load_with_process_group
+
+from fairscale.nn.model_parallel import initialize as fs_init
+
+from types import TracebackType
+from typing import Any, Optional
+
+class SmoothedValue(object):
+    """Track a series of values and provide access to smoothed values over a
+    window or the global series average.
+    """
+
+    def __init__(self, window_size=20, fmt=None):
+        if fmt is None:
+            fmt = "{median:.4f} ({global_avg:.4f})"
+        self.deque = deque(maxlen=window_size)
+        self.total = 0.0
+        self.count = 0
+        self.fmt = fmt
+
+    def update(self, value, n=1):
+        self.deque.append(value)
+        self.count += n
+        self.total += value * n
+
+    def synchronize_between_processes(self):
+        """
+        Warning: does not synchronize the deque!
+        """
+        if not is_dist_avail_and_initialized():
+            return
+        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+        dist.barrier()
+        dist.all_reduce(t)
+        t = t.tolist()
+        self.count = int(t[0])
+        self.total = t[1]
+
+    @property
+    def median(self):
+        d = torch.tensor(list(self.deque))
+        return d.median().item()
+
+    @property
+    def avg(self):
+        d = torch.tensor(list(self.deque), dtype=torch.float32)
+        return d.mean().item()
+
+    @property
+    def global_avg(self):
+        return self.total / self.count
+
+    @property
+    def max(self):
+        return max(self.deque)
+
+    @property
+    def value(self):
+        return self.deque[-1]
+
+    def __str__(self):
+        return self.fmt.format(
+            median=self.median,
+            avg=self.avg,
+            global_avg=self.global_avg,
+            max=self.max,
+            value=self.value)
+
+
+class MetricLogger(object):
+    def __init__(self, delimiter="\t"):
+        self.meters = defaultdict(SmoothedValue)
+        self.delimiter = delimiter
+
+    def update(self, **kwargs):
+        for k, v in kwargs.items():
+            if v is None:
+                continue
+            if isinstance(v, torch.Tensor):
+                v = v.item()
+            assert isinstance(v, (float, int))
+            self.meters[k].update(v)
+
+    def __getattr__(self, attr):
+        if attr in self.meters:
+            return self.meters[attr]
+        if attr in self.__dict__:
+            return self.__dict__[attr]
+        raise AttributeError("'{}' object has no attribute '{}'".format(
+            type(self).__name__, attr))
+
+    def __str__(self):
+        loss_str = []
+        for name, meter in self.meters.items():
+            loss_str.append(
+                "{}: {}".format(name, str(meter))
+            )
+        return self.delimiter.join(loss_str)
+
+    def synchronize_between_processes(self):
+        for meter in self.meters.values():
+            meter.synchronize_between_processes()
+
+    def add_meter(self, name, meter):
+        self.meters[name] = meter
+
+    def log_every(self, iterable, print_freq, header=None, start_iter=0):
+        i = start_iter
+        if not header:
+            header = ''
+        start_time = time.time()
+        end = time.time()
+        iter_time = SmoothedValue(fmt='{avg:.4f}')
+        data_time = SmoothedValue(fmt='{avg:.4f}')
+        log_msg = [
+            header,
+            '[{0' + '}/{1}]',
+            '{meters}',
+            'time: {time}',
+            'data: {data}'
+        ]
+        if torch.cuda.is_available():
+            log_msg.append('max mem: {memory:.0f}')
+        log_msg = self.delimiter.join(log_msg)
+        MB = 1024.0 * 1024.0
+        for obj in iterable:
+            data_time.update(time.time() - end)
+            yield obj
+            iter_time.update(time.time() - end)
+            if i % print_freq == 0:
+                try:
+                    total_len = len(iterable)
+                except TypeError:  # iterable has no __len__
+                    total_len = "unknown"
+                if torch.cuda.is_available():
+                    print(log_msg.format(
+                        i, total_len,
+                        meters=str(self),
+                        time=str(iter_time), data=str(data_time),
+                        memory=torch.cuda.max_memory_allocated() / MB))
+                else:
+                    print(log_msg.format(
+                        i, total_len,
+                        meters=str(self),
+                        time=str(iter_time), data=str(data_time)))
+            i += 1
+            end = time.time()
+        total_time = time.time() - start_time
+        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+        print('{} Total time: {} ({:.4f} s / it)'.format(
+            header, total_time_str, total_time / len(iterable)))
+
+
+def setup_for_distributed(is_master):
+    """
+    This function disables printing when not in master process
+    """
+    builtin_print = builtins.print
+
+    def print(*args, **kwargs):
+        force = kwargs.pop('force', False)
+#        force = force or (get_world_size() > 8)
+        if is_master or force:
+            now = datetime.datetime.now().time()
+            builtin_print('[{}] '.format(now), end='')  # print with time stamp
+            builtin_print(*args, **kwargs)
+
+    builtins.print = print
+
+
+def is_dist_avail_and_initialized():
+    if not dist.is_available():
+        return False
+    if not dist.is_initialized():
+        return False
+    return True
+
+
+def get_world_size():
+    if not is_dist_avail_and_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+def get_rank():
+    if not is_dist_avail_and_initialized():
+        return 0
+    return dist.get_rank()
+
+
+def is_main_process():
+    return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+    if is_main_process():
+        torch.save(*args, **kwargs)
+
+def init_distributed_mode(args):
+    if args.dist_on_itp:
+        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+        args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
+        os.environ['LOCAL_RANK'] = str(args.gpu)
+        os.environ['RANK'] = str(args.rank)
+        os.environ['WORLD_SIZE'] = str(args.world_size)
+        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
+    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+        args.rank = int(os.environ["RANK"])
+        args.world_size = int(os.environ['WORLD_SIZE'])
+        args.gpu = int(os.environ['LOCAL_RANK'])
+    elif 'SLURM_PROCID' in os.environ:
+        os.environ['MASTER_PORT'] = '8994'
+        while 'MASTER_ADDR' not in os.environ or len(os.environ['MASTER_ADDR'].strip()) == 0:
+            os.environ['MASTER_ADDR'] = subprocess.check_output('sinfo -Nh -n %s | head -n 1 | awk \'{print $1}\'' % os.environ['SLURM_NODELIST'], shell=True, ).decode().strip()
+            time.sleep(1)
+        print(os.environ['MASTER_ADDR'])
+        args.world_size = int(os.environ['SLURM_NPROCS'])
+        args.rank = int(os.environ['SLURM_PROCID'])
+        args.gpu = args.rank % torch.cuda.device_count()
+        args.local_rank = args.gpu
+        os.environ['LOCAL_RANK'] = str(args.gpu)
+        os.environ['WORLD_SIZE'] = str(args.world_size)
+        os.environ['RANK'] = str(args.rank)
+    else:
+        print('Not using distributed mode')
+        setup_for_distributed(is_master=True)  # hack
+        args.distributed = False
+        return
+
+    args.distributed = True
+
+    torch.cuda.set_device(args.gpu)
+    args.dist_backend = 'nccl'
+    print('| distributed init (rank {}): {}, gpu {}'.format(
+        args.rank, args.dist_url, args.gpu), flush=True)
+    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                                         world_size=args.world_size, rank=args.rank)
+    torch.distributed.barrier()
+    setup_for_distributed(args.rank == 0)
+
+
+def init_distributed_mode1(args):
+    if args.dist_on_itp:
+        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
+        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
+        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
+        args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
+        os.environ['LOCAL_RANK'] = str(args.gpu)
+        os.environ['RANK'] = str(args.rank)
+        os.environ['WORLD_SIZE'] = str(args.world_size)
+        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
+    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+        args.rank = int(os.environ["RANK"])
+        args.world_size = int(os.environ['WORLD_SIZE'])
+        args.gpu = int(os.environ['LOCAL_RANK'])
+    elif 'SLURM_PROCID' in os.environ:
+        args.rank = int(os.environ['SLURM_PROCID'])
+        args.gpu = args.rank % torch.cuda.device_count()
+    else:
+        print('Not using distributed mode')
+        setup_for_distributed(is_master=True)  # hack
+        args.distributed = False
+        return
+
+    args.distributed = True
+
+    torch.cuda.set_device(args.gpu)
+    args.dist_backend = 'nccl'
+    print('| distributed init (rank {}): {}, gpu {}'.format(
+        args.rank, args.dist_url, args.gpu), flush=True)
+    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                                         world_size=args.world_size, rank=args.rank)
+    torch.distributed.barrier()
+    setup_for_distributed(args.rank == 0)
+
+
+class NativeScalerWithGradNormCount:
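+    """
+    Loss scaler for FSDP mixed-precision training. Wraps ShardedGradScaler:
+    scales the backward pass, unscales and clips gradients through
+    `model.clip_grad_norm_`, then steps the optimizer and updates the scale.
+    With update_grad=False only a no_sync backward is run (gradient
+    accumulation) and no gradient norm is reported.
+    """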
+    state_dict_key = "amp_scaler"
+
+    def __init__(self, args):
+        self._scaler = ShardedGradScaler(enabled=args.precision in ["fp16"])
+
+    def __call__(self, loss, optimizer, model, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
+        if update_grad:
+            self._scaler.scale(loss).backward(create_graph=create_graph)
+            if clip_grad is not None:
+                assert parameters is not None
+                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
+                # norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
+                norm = model.clip_grad_norm_(clip_grad)
+            else:
+                raise NotImplementedError("please set clip_grad to a very large value if you do not want to clip.")
+                # unreachable after the raise above:
+                # self._scaler.unscale_(optimizer)
+                # norm = get_grad_norm_(parameters)
+            self._scaler.step(optimizer)
+            self._scaler.update()
+        else:
+            with model.no_sync():
+                self._scaler.scale(loss).backward(create_graph=create_graph)
+            norm = None
+        return norm
+
+    def state_dict(self):
+        return self._scaler.state_dict()
+
+    def load_state_dict(self, state_dict):
+        self._scaler.load_state_dict(state_dict)
+
+
+def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    parameters = [p for p in parameters if p.grad is not None]
+    norm_type = float(norm_type)
+    if len(parameters) == 0:
+        return torch.tensor(0.)
+    device = parameters[0].grad.device
+    if norm_type == inf:
+        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
+    else:
+        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
+    return total_norm
+
+
+def save_model(output_dir, args, epoch, iteration, model, optimizer, loss_scaler, dataset_state):
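+    # Saves a sharded FSDP checkpoint per rank (model/optimizer/scaler/epoch/
+    # iteration/dataset state) and, when args.save_consolidated is set, a
+    # consolidated full state dict per model-parallel rank; older checkpoint
+    # directories are then pruned, keeping the two most recent unless
+    # args.keep_all is set.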
+    save_dir = os.path.join(output_dir, f"epoch_{epoch}_iter_{iteration:09d}")
+    os.makedirs(save_dir, exist_ok=True)
+    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
+        to_save = {
+            "model": model.state_dict(),
+            "optimizer": optimizer.state_dict(),
+            "iter": iteration,
+            "epoch": epoch,
+            "scaler": loss_scaler.state_dict(),
+            "args": args,
+            "dataset_state": dataset_state,
+        }
+        save_path = os.path.join(
+            save_dir,
+            f"checkpoint.{dist.get_rank():05d}-of-{dist.get_world_size():05d}.pth",
+        )
+        torch.save(to_save, save_path)
+
+    if args.save_consolidated:
+        mp_rank = fs_init.get_model_parallel_rank()
+        mp_world_size = fs_init.get_model_parallel_world_size()
+        consolidated_model_save_path = os.path.join(
+            save_dir,
+            f"consolidated.{mp_rank:02d}-of-{mp_world_size:02d}.pth",
+        )
+        with FSDP.state_dict_type(
+            model,
+            StateDictType.FULL_STATE_DICT,
+            FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
+        ):
+            save_dtype = {
+                "fp16": torch.float16,
+                "bf16": torch.bfloat16,
+                "tf32": torch.float32,
+            }[args.precision]
+            consolidated_model_state_dict = {
+                k: v.to(save_dtype) for k, v in model.state_dict().items()
+            }
+        if fs_init.get_data_parallel_rank() == 0:
+            torch.save(consolidated_model_state_dict, consolidated_model_save_path)
+    
+    # remove previous ckpts
+    ckpts = glob.glob(os.path.join(output_dir, "iter_*")) + glob.glob(os.path.join(output_dir, "epoch_*"))
+    ckpts.sort()
+    if len(ckpts)>2 and not args.keep_all:
+        for ckpt in ckpts[:-2]:
+            print('del', ckpt)
+            os.system(f'rm -rf {ckpt}')
+
+def load_model(args, model, optimizer, loss_scaler):
+    start_iter = 0
+    start_epoch = 0
+    if args.auto_resume:
+        ckpt_dirs = glob.glob(os.path.join(args.output_dir, "iter_*")) + glob.glob(os.path.join(args.output_dir, "epoch_*"))
+        ckpt_dirs.sort()
+        if len(ckpt_dirs) > 0:
+            args.resume = ckpt_dirs[-1]
+    if args.resume:
+        print("Resume checkpoint %s" % args.resume)
+        local_checkpoint_path = os.path.join(
+            args.resume,
+            f"checkpoint.{dist.get_rank():05d}-of-{dist.get_world_size():05d}.pth",
+        )
+        with load_with_process_group(fs_init.get_data_parallel_group()):
+            checkpoint = torch.load(local_checkpoint_path, map_location='cpu')
+        with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
+            model.load_state_dict(checkpoint['model'])
+        optimizer.load_state_dict(checkpoint['optimizer'])
+        loss_scaler.load_state_dict(checkpoint['scaler'])
+        start_iter = int(checkpoint['iter']) + 1
+        if 'epoch' in checkpoint:
+            start_epoch = int(checkpoint['epoch'])
+    return start_epoch, start_iter
+    
+def all_reduce_mean(x):
+    world_size = get_world_size()
+    if world_size > 1:
+        if isinstance(x, torch.Tensor):
+            x_reduce = x.clone().cuda()
+        else:
+            x_reduce = torch.tensor(x).cuda()
+        dist.all_reduce(x_reduce)
+        x_reduce /= world_size
+        return x_reduce.item()
+    else:
+        return x
+
+
+def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
+    decay = []
+    no_decay = []
+    for name, param in model.named_parameters():
+        if not param.requires_grad:
+            continue  # frozen weights
+        #if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
+        if name.endswith(".bias") or name.endswith("norm.weight"):
+            no_decay.append(param)
+        else:
+            decay.append(param)
+    return [
+        {'params': no_decay, 'weight_decay': 0.},
+        {'params': decay, 'weight_decay': weight_decay}]
+
+
+
+
+class default_tensor_type:
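+    """Context manager that pushes a (dtype, device) pair as the global torch
+    default on entry and restores the previous default on exit, so modules can
+    be constructed directly in half precision and/or on the GPU."""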
+    _tensor_type_stack = [(torch.float, "cpu")]
+    
+    def __init__(
+        self,
+        dtype: Optional[torch.dtype] = None,
+        device: Optional[str] = None,
+    ) -> None:
+        # Only limited combinations are supported.
+        assert device is None or device in ["cpu", "cuda"]
+        assert dtype is None or dtype in [torch.float, torch.bfloat16, torch.half]
+        self.dtype, self.device = dtype, device
+    
+    def __enter__(self) -> None:
+        dtype, device = self.dtype, self.device
+        if dtype is None:
+            dtype = default_tensor_type._tensor_type_stack[-1][0]
+        if device is None:
+            device = default_tensor_type._tensor_type_stack[-1][1]
+        default_tensor_type._tensor_type_stack.append((dtype, device))
+        
+        # We use all three calls because the newer APIs (set_default_device,
+        # set_default_dtype) are sometimes ineffective on their own (e.g.,
+        # set_default_device does not affect direct torch.Tensor(...) calls).
+        torch.set_default_tensor_type(default_tensor_type.get_tensor_type(dtype, device))
+        torch.set_default_device(device)
+        torch.set_default_dtype(dtype)
+
+    def __exit__(
+        self,
+        exc_type: Optional[type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
+        default_tensor_type._tensor_type_stack.pop()
+        dtype, device = default_tensor_type._tensor_type_stack[-1]
+
+        torch.set_default_tensor_type(default_tensor_type.get_tensor_type(dtype, device))
+        torch.set_default_device(device)
+        torch.set_default_dtype(dtype)
+
+    @staticmethod
+    def get_tensor_type(dtype: torch.dtype, device: str) -> Any:
+        return {
+            (torch.float, "cpu"): torch.FloatTensor,
+            (torch.bfloat16, "cpu"): torch.BFloat16Tensor,
+            (torch.half, "cpu"): torch.HalfTensor,
+            (torch.float, "cuda"): torch.cuda.FloatTensor,
+            (torch.bfloat16, "cuda"): torch.cuda.BFloat16Tensor,
+            (torch.half, "cuda"): torch.cuda.HalfTensor,
+        }[(dtype, device)]
+
diff --git a/util/pos_embed.py b/util/pos_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..1924913c1ffe7c73b889a4d3bad586ee8b3d2d7d
--- /dev/null
+++ b/util/pos_embed.py
@@ -0,0 +1,113 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# Position embedding utils
+# --------------------------------------------------------
+
+import numpy as np
+import torch
+
+# --------------------------------------------------------
+# 2D sine-cosine position embedding
+# References:
+# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
+# MoCo v3: https://github.com/facebookresearch/moco-v3
+# --------------------------------------------------------
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+    """
+    grid_size: int of the grid height and width
+    return:
+    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    """
+    grid_h = np.arange(grid_size, dtype=np.float32)
+    grid_w = np.arange(grid_size, dtype=np.float32)
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, grid_size, grid_size])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    if cls_token:
+        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+    assert embed_dim % 2 == 0
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+    emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+    """
+    embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (M,)
+    out: (M, D)
+    """
+    assert embed_dim % 2 == 0
+    omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float was removed in NumPy 1.24
+    omega /= embed_dim / 2.
+    omega = 1. / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+
+    emb_sin = np.sin(out) # (M, D/2)
+    emb_cos = np.cos(out) # (M, D/2)
+
+    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    return emb
+
+
+# --------------------------------------------------------
+# Interpolate position embeddings for high-resolution
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+def interpolate_pos_embed(model, checkpoint_model):
+    if 'pos_embed' in checkpoint_model:
+        pos_embed_checkpoint = checkpoint_model['pos_embed']
+        embedding_size = pos_embed_checkpoint.shape[-1]
+        num_patches = model.patch_embed.num_patches
+        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+        # height (== width) for the checkpoint position embedding
+        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+        # height (== width) for the new position embedding
+        new_size = int(num_patches ** 0.5)
+        # class_token and dist_token are kept unchanged
+        if orig_size != new_size:
+            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+            # only the position tokens are interpolated
+            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+            pos_tokens = torch.nn.functional.interpolate(
+                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+            checkpoint_model['pos_embed'] = new_pos_embed
+
+
+def interpolate_pos_embed_online(
+    pos_embed, orig_size, new_size, num_extra_tokens: int
+):
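+    # resize the grid part of a [N, D] positional embedding from orig_size to
+    # new_size with bicubic interpolation, keeping the first num_extra_tokens
+    # rows (e.g. the class token) unchanged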
+    # [257, 1024]
+    extra_tokens = pos_embed[:num_extra_tokens]
+    pos_tokens = pos_embed[num_extra_tokens:]
+    embedding_size = pos_tokens.shape[1]
+    pos_tokens = pos_tokens.reshape(
+        -1, orig_size[0], orig_size[1], embedding_size
+    ).permute(0, 3, 1, 2)
+    pos_tokens = torch.nn.functional.interpolate(
+        pos_tokens, size=new_size, mode="bicubic", align_corners=False,
+    )
+    pos_tokens = pos_tokens.permute(0, 2, 3, 1).reshape(-1, embedding_size)
+    new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=0)
+    return new_pos_embed
\ No newline at end of file