{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "215cfd2f-62b0-4a86-a407-777a1d32597f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2024-01-24 15:18:49,948] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
     ]
    }
   ],
   "source": [
    "from PIL import Image\n",
    "import requests\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "from transformers import AutoProcessor, CLIPVisionModel, CLIPVisionConfig, CLIPPreTrainedModel\n",
    "from transformers.models.clip.modeling_clip import CLIPVisionModelOutput, CLIPVisionTransformer\n",
    "from transformers import WhisperProcessor, WhisperForConditionalGeneration\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2244e8f3-fcc7-4309-9d4d-fea557f89f79",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llava_phi import LlavaPhiForCausalLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "587883e1-3419-4b14-b16b-38fabbc8bfaa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# model = LlavaPhiForCausalLM.from_pretrained(\"./llava-phi/checkpoints/llavaPhi-v0-3b-finetune/checkpoint-4000\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0e27a7db-e2ab-4d65-b21d-497222e318ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# processor = AutoProcessor.from_pretrained(\"./llava-phi/checkpoints/llavaPhi-v0-3b-finetune/checkpoint-4000\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "663efdd8-ea21-4231-a2ae-bcc0fb47b46a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# prompt = \"<image>\\nUSER: What's the content of the image?\\nASSISTANT:\"\n",
    "# url = \"https://www.ilankelman.org/stopsigns/australia.jpg\"\n",
    "# image = Image.open(requests.get(url, stream=True).raw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f622609f-f6a7-4ec1-ac35-c1d33d9436ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Generate\n",
    "# generate_ids = model.generate(**inputs, max_length=30)\n",
    "# processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "45f5ba72-2e41-4ccc-84c1-97d542ebee63",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llava_phi.model.builder import load_pretrained_model\n",
    "from llava_phi.mm_utils import tokenizer_image_token, get_model_name_from_path\n",
    "from llava_phi.utils import disable_torch_init\n",
    "from llava_phi.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\n",
    "from llava_phi.conversation import conv_templates, SeparatorStyle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b98ac5d3-5503-4430-81d1-19a4f8d6bd75",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = \"checkpoints/llavaPhi-v0-3b-finetune/checkpoint-4000\"\n",
    "model_name = get_model_name_from_path(model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "42fd5721-75a7-475b-bd30-5ee23aeaac64",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'llavaPhi-v0-3b-finetune_checkpoint-4000'"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "8c2076b5-3bfc-48fd-917b-5dfd06fc532f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load llaVA-Phi MLLM!!!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "20b86f2c01744081b537620c8780f12e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'device_map': 'cuda'}\n"
     ]
    }
   ],
   "source": [
    "tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "4e46221e-0907-453e-8126-76199828493e",
   "metadata": {},
   "outputs": [],
   "source": [
    "qs = \"What's the content of the image?\"\n",
    "qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + qs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "07355444-0eb8-4d4d-ad50-48b91c969664",
   "metadata": {},
   "outputs": [],
   "source": [
    "conv = conv_templates[\"default\"].copy()\n",
    "conv.append_message(conv.roles[0], qs)\n",
    "conv.append_message(conv.roles[1], None)\n",
    "prompt = conv.get_prompt()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "ccb5674f-aff8-456e-b61b-1d167864f1a6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <im_start><image><im_end>\\nWhat's the content of the image? ASSISTANT:\""
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "a89cc181-2214-4844-b966-164a41744e54",
   "metadata": {},
   "outputs": [],
   "source": [
    "url = \"https://www.ilankelman.org/stopsigns/australia.jpg\"\n",
    "image = Image.open(requests.get(url, stream=True).raw)\n",
    "image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].cuda()\n",
    "\n",
    "input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n",
    "\n",
    "stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "0d519851-64d4-4cf5-b2eb-19474f9aa260",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 55])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_ids.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "1694ff36-f214-4ed3-b2f3-d3dbd0a1a25b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "audio_ds = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n",
    "audio = audio_ds[0][\"audio\"]\n",
    "\n",
    "whisper_w_proj = WhisperWithProjection(projection_dim=512)\n",
    "audio_embed = whisper_w_proj(audio)[\"input_ids\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "9c4a9fae-d6ed-4fc2-ba02-97df64cddd93",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([1, 33]), device(type='cpu'))"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "audio_embed.shape, audio_embed.device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "c3fffe29-98fb-4f4b-ac51-4bdda9e46752",
   "metadata": {},
   "outputs": [],
   "source": [
    "input_ids = torch.concat([input_ids, audio_embed.to(\"cuda:0\")], dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "5dee1ec8-2db2-4f65-99e8-d34bd2735c9c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 88])"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_ids.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "96033b43-4f57-4f0c-bcf7-37b57ca02e47",
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.inference_mode():\n",
    "        output_ids = model.generate(\n",
    "            input_ids,\n",
    "            images=image_tensor,\n",
    "            do_sample=True,\n",
    "            temperature=0.2,\n",
    "            max_new_tokens=1024,\n",
    "            eos_token_id=tokenizer.eos_token_id,  # End of sequence token\n",
    "            pad_token_id=tokenizer.eos_token_id,  # Pad token\n",
    "            use_cache=True,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "741e8da5-0d18-4c11-b559-76054ce4ca3a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "is a Japanese character from the story of Jesus, who is a Chinese monk who is also known for his teachings. The story is based on the story of the story of Jesus Christ, and it is a representation of the story of Jesus and the story of Jesus Christ.\n"
     ]
    }
   ],
   "source": [
    "input_token_len = input_ids.shape[1]\n",
    "n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()\n",
    "if n_diff_input_output > 0:\n",
    "    print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')\n",
    "outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]\n",
    "outputs = outputs.strip()\n",
    "if outputs.endswith(stop_str):\n",
    "    outputs = outputs[:-len(stop_str)]\n",
    "outputs = outputs.strip()\n",
    "print(outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "69d494d4-d768-4645-b4d6-5c455791b50d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a340856-a13f-4b18-9911-126a4ba37816",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c56fdea-c7a1-4e67-9832-e2ed077d8704",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "89e84d39-8ed8-45db-ae82-27c156ee6dd1",
   "metadata": {},
   "outputs": [],
   "source": [
    "class AudioLanguageConnector:\n",
    "    def __init__(self, projection_dim):\n",
    "        model_name = \"microsoft/phi-2\"\n",
    "        self.phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
    "        self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token\n",
    "        self.phi2_tokenizer.max_length = projection_dim\n",
    "\n",
    "    def __call__(self, text):\n",
    "        text = f\"<audio_start> {text} <audio_end>\"\n",
    "        tokens = self.phi2_tokenizer(text, return_tensors=\"pt\", return_attention_mask=False)\n",
    "        return tokens\n",
    "        \n",
    "\n",
    "class WhisperWithProjection:\n",
    "    def __init__(self, projection_dim, device):\n",
    "        self.device = device\n",
    "        self.processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny\", device_map=device)\n",
    "        self.model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-tiny\", device_map=device)\n",
    "        self.model.config.forced_decoder_ids = None\n",
    "        self.audio_language_connector = AudioLanguageConnector(projection_dim)\n",
    "        \n",
    "    def __call__(self, audio):\n",
    "        input_features = self.processor(audio[\"array\"],\n",
    "                                   sampling_rate=audio[\"sampling_rate\"],\n",
    "                                   return_tensors=\"pt\").input_features\n",
    "        # generate token ids\n",
    "        predicted_ids = self.model.generate(input_features.to(self.device))\n",
    "        # decode token ids to text        \n",
    "        transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)\n",
    "\n",
    "        audio_embeddings = self.audio_language_connector(transcription)\n",
    "        return audio_embeddings.to(self.device)"
   ]
  },
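  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f2a9b7e-1c4d-4e5f-9a6b-7c8d9e0f1a2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal usage sketch (assumes the librispeech dummy `audio` sample loaded above and a CUDA device;\n",
    "# the names below are illustrative). The connector returns Phi-2 token ids of the Whisper transcription,\n",
    "# wrapped in <audio_start>/<audio_end> markers.\n",
    "# whisper_w_proj_gpu = WhisperWithProjection(projection_dim=512, device=\"cuda:0\")\n",
    "# audio_tokens = whisper_w_proj_gpu(audio)[\"input_ids\"]\n",
    "# audio_tokens.shape  # (1, num_transcription_tokens)"
   ]
  },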
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "75e24be0-b236-4047-83ef-5c344e262476",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultiModalPhi2:\n",
    "    def __init__(self, model_path=\"checkpoints/llavaPhi-v0-3b-finetune/checkpoint-4000\",\n",
    "                temperature=0.2,\n",
    "                max_new_tokens=1024,\n",
    "                device=\"cuda\"):\n",
    "        self.temperature = temperature\n",
    "        self.max_new_tokens = max_new_tokens\n",
    "        self.device = device\n",
    "        model_name = get_model_name_from_path(model_path)\n",
    "        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(model_path, None, model_name, device_map=device)\n",
    "        self.whisper_w_proj = WhisperWithProjection(projection_dim=512, device=device)\n",
    "        \n",
    "        \n",
    "    def __call__(self, text, audio, image):\n",
    "        qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + text\n",
    "        conv = conv_templates[\"default\"].copy()\n",
    "        conv.append_message(conv.roles[0], qs)\n",
    "        conv.append_message(conv.roles[1], None)\n",
    "        prompt = conv.get_prompt()\n",
    "\n",
    "        image_tensor = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'].cuda()\n",
    "        \n",
    "        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n",
    "\n",
    "        audio_embed = self.whisper_w_proj(audio)[\"input_ids\"]\n",
    "        \n",
    "        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n",
    "\n",
    "        input_ids = torch.concat([input_ids, audio_embed], dim=1)\n",
    "\n",
    "        with torch.inference_mode():\n",
    "            output_ids = self.model.generate(\n",
    "                input_ids,\n",
    "                images=image_tensor,\n",
    "                do_sample=True,\n",
    "                temperature=self.temperature,\n",
    "                max_new_tokens=self.max_new_tokens,\n",
    "                eos_token_id=tokenizer.eos_token_id,  # End of sequence token\n",
    "                pad_token_id=tokenizer.eos_token_id,  # Pad token\n",
    "                use_cache=True,\n",
    "            )\n",
    "\n",
    "        input_token_len = input_ids.shape[1]\n",
    "        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()\n",
    "        if n_diff_input_output > 0:\n",
    "            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')\n",
    "        outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]\n",
    "        outputs = outputs.strip()\n",
    "        if outputs.endswith(stop_str):\n",
    "            outputs = outputs[:-len(stop_str)]\n",
    "        outputs = outputs.strip()\n",
    "        return outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "4efdbad4-d88a-4477-a3a0-f5591cd0b172",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load llaVA-Phi MLLM!!!\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "492c17cf54f34d4d9e4f288fc9e72e79",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'device_map': 'cuda'}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    }
   ],
   "source": [
    "multimodal_phi2 = MultiModalPhi2()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "9a6de0b0-a231-4d50-88e8-e40c6f7216c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "text = \"tell me about the audio\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "b4919948-6a75-4d19-ba95-9ba233a7d3d9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'is a popular Japanese drama series featuring a man in a red and white costume, who is dressed as Santa Claus, is walking down the street. The scene takes place in a busy city environment, with people walking and standing on the sidewalk, likely enjoying the festive atmosphere and the festive atmosphere.'"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "multimodal_phi2(text, audio, image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "590f2d64-62ed-4e6f-b7c8-b0cf68aecaab",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "c921eb63-feb5-4fa9-993b-2faeb6dfe1db",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig, CLIPImageProcessor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "b470a2c4-806a-435d-9fc2-f17448dbe5fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llava_phi.model import LlavaPhiConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "4f7bc91a-0a41-45e5-92a4-daa1e3eea0da",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "993bc3a38cb84de4a2e3a79a3448c4d6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "device_map = \"cuda:0\"\n",
    "load_8bit = False\n",
    "load_4bit = False\n",
    "kwargs = {\"device_map\": device_map}\n",
    "if load_8bit:\n",
    "    kwargs['load_in_8bit'] = True\n",
    "elif load_4bit:\n",
    "    kwargs['load_in_4bit'] = True\n",
    "    kwargs['quantization_config'] = BitsAndBytesConfig(\n",
    "        load_in_4bit=True,\n",
    "        bnb_4bit_compute_dtype=torch.float16,\n",
    "        bnb_4bit_use_double_quant=True,\n",
    "        bnb_4bit_quant_type='nf4'\n",
    "    )\n",
    "config = LlavaPhiConfig.from_pretrained(model_path, trust_remote_code=True)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)\n",
    "model = LlavaPhiForCausalLM.from_pretrained(\n",
    "    model_path, \n",
    "    config=config, \n",
    "    use_safetensors=True, \n",
    "    **kwargs).to(\"cuda\")\n",
    "image_processor = CLIPImageProcessor.from_pretrained(model_path)\n",
    "mm_use_im_start_end = getattr(model.config, \"mm_use_im_start_end\", False)\n",
    "mm_use_im_patch_token = getattr(model.config, \"mm_use_im_patch_token\", True)\n",
    "\n",
    "# TODO: the tokenizer length of phi-2 is 50295, but the output class of lm_head is 51200\n",
    "if mm_use_im_patch_token:\n",
    "    tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)\n",
    "if mm_use_im_start_end:\n",
    "    tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)\n",
    "    \n",
    "if hasattr(model.config, \"max_sequence_length\"):\n",
    "        context_len = model.config.max_sequence_length\n",
    "else:\n",
    "    context_len = 2048"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "99355837-a297-4a25-aeb3-1670af7e9251",
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[70], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mLlava-Phi-Checkpoint\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/transformers/modeling_utils.py:2376\u001b[0m, in \u001b[0;36mPreTrainedModel.save_pretrained\u001b[0;34m(self, save_directory, is_main_process, state_dict, save_function, push_to_hub, max_shard_size, safe_serialization, variant, token, save_peft_format, **kwargs)\u001b[0m\n\u001b[1;32m   2372\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m shard_file, shard \u001b[38;5;129;01min\u001b[39;00m shards\u001b[38;5;241m.\u001b[39mitems():\n\u001b[1;32m   2373\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m safe_serialization:\n\u001b[1;32m   2374\u001b[0m         \u001b[38;5;66;03m# At some point we will need to deal better with save_function (used for TPU and other distributed\u001b[39;00m\n\u001b[1;32m   2375\u001b[0m         \u001b[38;5;66;03m# joyfulness), but for now this enough.\u001b[39;00m\n\u001b[0;32m-> 2376\u001b[0m         \u001b[43msafe_save_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43mshard\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43msave_directory\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mshard_file\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetadata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mformat\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mpt\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2377\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   2378\u001b[0m         save_function(shard, os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(save_directory, shard_file))\n",
      "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/safetensors/torch.py:281\u001b[0m, in \u001b[0;36msave_file\u001b[0;34m(tensors, filename, metadata)\u001b[0m\n\u001b[1;32m    250\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msave_file\u001b[39m(\n\u001b[1;32m    251\u001b[0m     tensors: Dict[\u001b[38;5;28mstr\u001b[39m, torch\u001b[38;5;241m.\u001b[39mTensor],\n\u001b[1;32m    252\u001b[0m     filename: Union[\u001b[38;5;28mstr\u001b[39m, os\u001b[38;5;241m.\u001b[39mPathLike],\n\u001b[1;32m    253\u001b[0m     metadata: Optional[Dict[\u001b[38;5;28mstr\u001b[39m, \u001b[38;5;28mstr\u001b[39m]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m    254\u001b[0m ):\n\u001b[1;32m    255\u001b[0m \u001b[38;5;250m    \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    256\u001b[0m \u001b[38;5;124;03m    Saves a dictionary of tensors into raw bytes in safetensors format.\u001b[39;00m\n\u001b[1;32m    257\u001b[0m \n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    279\u001b[0m \u001b[38;5;124;03m    ```\u001b[39;00m\n\u001b[1;32m    280\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[0;32m--> 281\u001b[0m     \u001b[43mserialize_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_flatten\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetadata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmetadata\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "model.save_pretrained(\"Llava-Phi-Checkpoint\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa0bec34-a148-4340-a30c-6f09dd5e71ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.push_to_hub(\"RaviNaik/Llava-Phi2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "382f74b0-2967-408a-badc-a90918810d74",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CommitInfo(commit_url='https://huggingface.co/RaviNaik/Llava-Phi2/commit/fa8f7240058241243f6bdc3d6ab44bb691f76e39', commit_message='Upload tokenizer', commit_description='', oid='fa8f7240058241243f6bdc3d6ab44bb691f76e39', pr_url=None, pr_revision=None, pr_num=None)"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.push_to_hub(\"RaviNaik/Llava-Phi2\")"
   ]
  },
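  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d7c6b5a-4e3f-4a2b-9c1d-0e9f8a7b6c5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged sketch: reload the pushed checkpoint from the Hub. The repo id mirrors the push above;\n",
    "# the custom LlavaPhi classes must be importable locally for from_pretrained to work.\n",
    "# reloaded_tokenizer = AutoTokenizer.from_pretrained(\"RaviNaik/Llava-Phi2\")\n",
    "# reloaded_model = LlavaPhiForCausalLM.from_pretrained(\"RaviNaik/Llava-Phi2\")"
   ]
  },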
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b851459b-d3ac-4fb8-99b6-17a648adc41f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}