---
language:
- en
- de
- fr
- it
- pt
- hi
- es
- th
library_name: transformers
pipeline_tag: image-text-to-text
tags:
- meta
- pytorch
- llama
- llama-3
- vision
base_model:
- meta-llama/Llama-3.2-11B-Vision-Instruct
- rombodawg/Llama-3-8B-Instruct-Coder
---
# Llama-3-8B-Instruct-Coder + Llama 3.2 Vision Adapter
This model was created with the script below, which copies the language-model weights of a text-only Llama checkpoint into the matching decoder layers of a Llama 3.2 Vision model while keeping the vision tower and cross-attention layers intact. The script is compatible with:
* Llama 3.1 8B & 70B as the text model, paired respectively with
* Llama 3.2 Vision 11B & 90B as the multimodal base
A quick way to confirm that a pair of checkpoints is compatible is sketched below.
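The following is a minimal sketch (not part of the merge script itself) of that compatibility check: the vision model's text decoder should have exactly as many extra layers as the number of cross-attention layers it inserts.
```python
# Sketch: confirm a text / vision model pair lines up before merging
from transformers import AutoConfig

text_cfg = AutoConfig.from_pretrained("rombodawg/Llama-3-8B-Instruct-Coder")
vision_cfg = AutoConfig.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")

n_text = text_cfg.num_hidden_layers                  # 32 for the 8B text model
n_vision = vision_cfg.text_config.num_hidden_layers  # 40 for the 11B vision model
print(f"Extra (cross-attention) decoder layers: {n_vision - n_text}")  # 8 for the 11B / 8B pair
```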
## Merge Script
```python
import os
import torch
from transformers import MllamaForConditionalGeneration, MllamaProcessor, AutoModelForCausalLM
# NOTE: You need enough CPU RAM to hold both models at once (otherwise the merge has to be done layer by layer, which is not shown here)
multimodal_model_path = "meta-llama/Llama-3.2-11B-Vision-Instruct" # Original Llama vision model (11B or 90B)
text_model_path = "rombodawg/Llama-3-8B-Instruct-Coder" # Model to be merged (8B or 70B)
save_path = "models/merged_model"
multimodal_model = MllamaForConditionalGeneration.from_pretrained(multimodal_model_path, device_map="cpu", torch_dtype=torch.bfloat16)
multimodal_processor = MllamaProcessor.from_pretrained(multimodal_model_path)
text_model = AutoModelForCausalLM.from_pretrained(text_model_path, device_map="cpu", torch_dtype=torch.bfloat16)
state_dict_multimodal = multimodal_model.state_dict()
state_dict_text = text_model.state_dict()
num_decoder_layers_text = text_model.config.num_hidden_layers
num_decoder_layers_vision = multimodal_model.config.text_config.num_hidden_layers
# Find the list of inserted layers in multimodal Llama
inserted_layers = set()
for key_multimodal in state_dict_multimodal.keys():
    if "language_model" in key_multimodal and "cross_attn" in key_multimodal and ".layers." in key_multimodal:
        layer_num_multimodal = int(key_multimodal.split(".layers.")[1].split(".")[0])
        inserted_layers.add(layer_num_multimodal)
# Alternatively, hard-code the list of added cross-attention layers:
# inserted_layers = {3, 8, 13, 18, 23, 28, 33, 38, 43, 48, 53, 58, 63, 68, 73, 78, 83, 88, 93, 98}  # For 90B
inserted_layers = {3, 8, 13, 18, 23, 28, 33, 38}  # For 11B
assert len(inserted_layers) == num_decoder_layers_vision - num_decoder_layers_text, "Number of added layers does not match"
# Build decoder layer map from multimodal layer# to text layer#, skipping layers listed in inserted_layers
layer_map = dict()
layer_num_multimodal = 0
for layer_num_text in range(num_decoder_layers_text):
    while layer_num_multimodal in inserted_layers:
        layer_num_multimodal += 1  # Increment to skip the inserted cross-attention layers
    layer_map[layer_num_multimodal] = layer_num_text
    layer_num_multimodal += 1
for key_multimodal in state_dict_multimodal.keys():
    if "language_model" not in key_multimodal: continue  # A multimodal-only param
    if "cross_attn" in key_multimodal: continue  # A multimodal-only param
    key_text = key_multimodal.replace("language_model.", "")
    if "embed_tokens.weight" in key_multimodal:  # Handle embed tokens separately
        assert key_text in state_dict_text, f"Key not found: {key_text}"
        extra_tokens = state_dict_multimodal[key_multimodal].shape[0] - state_dict_text[key_text].shape[0]
        state_dict_multimodal[key_multimodal][:state_dict_text[key_text].shape[0], :].copy_(state_dict_text[key_text])
        print(f"Replaced {key_multimodal} with {key_text} (preserving last {extra_tokens} tokens)")
        continue
    if "lm_head" in key_multimodal or "model.norm.weight" in key_multimodal:  # Handle other non-decoder weights separately
        assert key_text in state_dict_text, f"Key not found: {key_text}"
        state_dict_multimodal[key_multimodal].copy_(state_dict_text[key_text])
        print(f"Replaced {key_multimodal} with {key_text}")
        continue
    layer_num_multimodal = int(key_multimodal.split(".layers.")[1].split(".")[0]) if ".layers." in key_multimodal else None
    assert layer_num_multimodal is not None, f"Unknown non-decoder key encountered: {key_multimodal}"
    if layer_num_multimodal in inserted_layers: continue  # Skip the inserted cross-attention layers
    assert layer_num_multimodal in layer_map, f"Layer not found in layer_map: {layer_num_multimodal}"
    layer_num_text = layer_map[layer_num_multimodal]
    key_text = key_text.replace(f".layers.{layer_num_multimodal}.", f".layers.{layer_num_text}.")
    assert key_text in state_dict_text, f"Key not found: {key_text}"
    state_dict_multimodal[key_multimodal].copy_(state_dict_text[key_text])
    print(f"Replaced {key_multimodal} with {key_text}")
print("Merged model successfully. Saving...")
# Apply the changes
multimodal_model.load_state_dict(state_dict_multimodal)
# Create save_path if it does not exist
os.makedirs(save_path, exist_ok=True)
multimodal_model.save_pretrained(save_path, safe_serialization=True, max_shard_size="8192MB")
multimodal_processor.save_pretrained(save_path)
print(f"Model saved to {save_path}")
```
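To make the layer mapping concrete, here is a small standalone sketch (separate from the merge script) that reproduces the mapping for the 11B / 8B pair: each of the 8B model's 32 decoder layers lands on the next vision-model layer that is not one of the inserted cross-attention layers.
```python
# Standalone sketch of the layer mapping for the 11B vision / 8B text pair.
# The 8B model has 32 decoder layers; the 11B vision model has 40 (32 + 8 cross-attention layers).
num_decoder_layers_text = 32
inserted_layers = {3, 8, 13, 18, 23, 28, 33, 38}

layer_map = {}
layer_num_multimodal = 0
for layer_num_text in range(num_decoder_layers_text):
    while layer_num_multimodal in inserted_layers:
        layer_num_multimodal += 1  # Skip the extra cross-attention layers
    layer_map[layer_num_multimodal] = layer_num_text
    layer_num_multimodal += 1

print(layer_map)
# {0: 0, 1: 1, 2: 2, 4: 3, 5: 4, 6: 5, 7: 6, 9: 7, 10: 8, ...}
```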
## Model Inference
```python
import requests
import torch
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor

# Path where the merged model was saved by the merge script above
# (or the Hub repository hosting this merged model)
model_id = "models/merged_model"
model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(model_id)

# Example: describe an image (any image will do here)
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe this image in one sentence."}]}]
input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(image, input_text, add_special_tokens=False, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=64)
print(processor.decode(output[0], skip_special_tokens=True))
```
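As a rough sanity check on the merge itself (a sketch, assuming the local save path from the merge script above and the same `language_model.*` state-dict key layout it relies on), a decoder weight in the merged model should now be identical to the corresponding layer of the Coder model, shifted by the inserted cross-attention layers:
```python
import torch
from transformers import MllamaForConditionalGeneration, AutoModelForCausalLM

merged = MllamaForConditionalGeneration.from_pretrained("models/merged_model", device_map="cpu", torch_dtype=torch.bfloat16)
coder = AutoModelForCausalLM.from_pretrained("rombodawg/Llama-3-8B-Instruct-Coder", device_map="cpu", torch_dtype=torch.bfloat16)

# Multimodal decoder layer 4 maps to text decoder layer 3
# (layer 3 in the vision model is an inserted cross-attention layer)
w_merged = merged.state_dict()["language_model.model.layers.4.self_attn.q_proj.weight"]
w_coder = coder.state_dict()["model.layers.3.self_attn.q_proj.weight"]
print(torch.equal(w_merged, w_coder))  # Expected: True
```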
## License
This project is licensed under the MIT License.