---
language:
  - en
  - de
  - fr
  - it
  - pt
  - hi
  - es
  - th
library_name: transformers
pipeline_tag: image-text-to-text
tags:
  - meta
  - pytorch
  - llama
  - llama-3
  - vision
base_model:
- meta-llama/Llama-3.2-11B-Vision-Instruct
- rombodawg/Llama-3-8B-Instruct-Coder
---
# Llama-3-8B-Instruct-Coder + Llama3.2Vision Adapter

This model was created using the script below. The script replaces the text-decoder weights of a Llama Vision checkpoint with those of a text-only fine-tune and works with the following pairings:

* Llama 3.1 8B with Llama 3.2 Vision 11B
* Llama 3.1 70B with Llama 3.2 Vision 90B
  
## Merge Script

```python
import os
import torch
from transformers import MllamaForConditionalGeneration, MllamaProcessor, AutoModelForCausalLM

# NOTE: You need enough CPU RAM to load both models at once (otherwise you would have to process the weights layer by layer, which is not shown here)

multimodal_model_path = "meta-llama/Llama-3.2-11B-Vision-Instruct"  # Original Llama vision model (11B or 90B)
text_model_path = "rombodawg/Llama-3-8B-Instruct-Coder"  # Model to be merged (8B or 70B)
save_path = "models/merged_model"

multimodal_model = MllamaForConditionalGeneration.from_pretrained(multimodal_model_path, device_map="cpu", torch_dtype=torch.bfloat16)
multimodal_processor = MllamaProcessor.from_pretrained(multimodal_model_path)
text_model = AutoModelForCausalLM.from_pretrained(text_model_path, device_map="cpu", torch_dtype=torch.bfloat16)

state_dict_multimodal = multimodal_model.state_dict()
state_dict_text = text_model.state_dict()

num_decoder_layers_text = text_model.config.num_hidden_layers
num_decoder_layers_vision = multimodal_model.config.text_config.num_hidden_layers
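# Example: the 11B vision model's language tower has 8 more decoder layers than the 8B text model
# (20 more for 90B); the extra layers are the inserted cross-attention layers identified below.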

# Find the list of inserted layers in multimodal Llama
inserted_layers = set()
for key_multimodal in state_dict_multimodal.keys():
    if "language_model" in key_multimodal and "cross_attn" in key_multimodal and ".layers." in key_multimodal:
        layer_num_multimodal = int(key_multimodal.split(".layers.")[1].split(".")[0]) if ".layers." in key_multimodal else None
        if layer_num_multimodal is not None: inserted_layers.add(layer_num_multimodal)
# For reference, the hard-coded sets of inserted layers are:
# inserted_layers = {3, 8, 13, 18, 23, 28, 33, 38, 43, 48, 53, 58, 63, 68, 73, 78, 83, 88, 93, 98}  # For 90B
inserted_layers = {3, 8, 13, 18, 23, 28, 33, 38}  # For 11B

assert len(inserted_layers) == num_decoder_layers_vision - num_decoder_layers_text, "Number of inserted layers does not match"

# Build decoder layer map from multimodal layer# to text layer#, skipping layers listed in inserted_layers
layer_map = dict()
layer_num_multimodal = 0
for layer_num_text in range(num_decoder_layers_text):
    while layer_num_multimodal in inserted_layers: layer_num_multimodal += 1  # Increment to skip mismatched layers
    layer_map[layer_num_multimodal] = layer_num_text
    layer_num_multimodal += 1
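# For 11B the loop above yields layer_map = {0: 0, 1: 1, 2: 2, 4: 3, 5: 4, 6: 5, 7: 6, 9: 7, ...};
# each inserted cross-attention layer (3, 8, ...) has no counterpart in the text model and is skipped.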

for key_multimodal in state_dict_multimodal.keys():
    if "language_model" not in key_multimodal: continue  # A multi-modal param
    if "cross_attn" in key_multimodal: continue  # A multi-modal param
    key_text = key_multimodal.replace("language_model.", "")
    if "embed_tokens.weight" in key_multimodal:  # Handle embed tokens separately
        assert key_text in state_dict_text, f"Key not found: {key_text}"
        extra_tokens = state_dict_multimodal[key_multimodal].shape[0] - state_dict_text[key_text].shape[0]
        state_dict_multimodal[key_multimodal][:state_dict_text[key_text].shape[0], :].copy_(state_dict_text[key_text])
        print(f"Replaced {key_multimodal} with {key_text} (preserving last {extra_tokens} tokens)")
        continue
    if "lm_head" in key_multimodal or "model.norm.weight" in key_multimodal:  # Handle other non-decoder layers separately
        assert key_text in state_dict_text, f"Key not found: {key_text}"
        state_dict_multimodal[key_multimodal].copy_(state_dict_text[key_text])
        print(f"Replaced {key_multimodal} with {key_text}")
        continue
    layer_num_multimodal = int(key_multimodal.split(".layers.")[1].split(".")[0]) if ".layers." in key_multimodal else None
    assert layer_num_multimodal is not None, f"Unknown non-decoder key encountered: {key_multimodal}"
    if layer_num_multimodal in inserted_layers: continue  # Skip mismatched layers
    assert layer_num_multimodal in layer_map, f"Layer not found in layer_map: {layer_num_multimodal}"
    layer_num_text = layer_map[layer_num_multimodal]
    key_text = key_text.replace(f".layers.{layer_num_multimodal}.", f".layers.{layer_num_text}.")
    assert key_text in state_dict_text, f"Key not found: {key_text}"
    state_dict_multimodal[key_multimodal].copy_(state_dict_text[key_text])
    print(f"Replaced {key_multimodal} with {key_text}")

print("Merged model successfully. Saving...")
# Apply the changes
multimodal_model.load_state_dict(state_dict_multimodal)

# Create save_path if it does not exist
os.makedirs(save_path, exist_ok=True)
multimodal_model.save_pretrained(save_path, safe_serialization=True, max_shard_size="8192MB")
multimodal_processor.save_pretrained(save_path)
print(f"Model saved to {save_path}")
```
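
As a quick sanity check (a minimal sketch, not part of the original merge script), you can reload the saved checkpoint and confirm that a regular decoder weight now equals the corresponding weight from the text model; the key names below follow the same `language_model.` prefix convention the script relies on:

```python
import torch
from transformers import MllamaForConditionalGeneration, AutoModelForCausalLM

merged = MllamaForConditionalGeneration.from_pretrained("models/merged_model", device_map="cpu", torch_dtype=torch.bfloat16)
text = AutoModelForCausalLM.from_pretrained("rombodawg/Llama-3-8B-Instruct-Coder", device_map="cpu", torch_dtype=torch.bfloat16)

# Decoder layer 0 is not a cross-attention layer, so it maps 1:1 onto text layer 0
merged_w = merged.state_dict()["language_model.model.layers.0.self_attn.q_proj.weight"]
text_w = text.state_dict()["model.layers.0.self_attn.q_proj.weight"]
assert torch.equal(merged_w, text_w), "Layer 0 q_proj was not replaced correctly"
print("Layer 0 q_proj matches the text model")
```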

## Model Inference

```python
import requests
import torch
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor

model_id = "rombodawg/Llama-3-8B-Instruct-Coder"

model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(model_id)
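
# Example generation with the merged model (a minimal sketch; any reachable image URL works here)
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"
image = Image.open(requests.get(url, stream=True).raw)

messages = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": "Write a short Python function that returns a one-sentence description of this image."},
    ]},
]
input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(image, input_text, add_special_tokens=False, return_tensors="pt").to(model.device)

output = model.generate(**inputs, max_new_tokens=128)
print(processor.decode(output[0], skip_special_tokens=True))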
```

## License

This project is licensed under the MIT License.