---
base_model: meta-llama/Meta-Llama-3.1-70B-Instruct
library_name: transformers
license: llama3.1
pipeline_tag: text-generation
tags:
- facebook
- meta
- pytorch
- pruning
- llama
- llama-3
---

## Model Information
This text-only Llama 3.1 41B model was pruned from the instruction-tuned, text-only Llama 3.1 70B model
using the [FLAP method](https://arxiv.org/abs/2312.11983).

> TL;DR: Not under maintenance. Poor performance, no practical value. A side product of an experiment.

Hyper parameters used for pruning:
```
metrics: WIFV
structure: AL-AM
pruning_ratio: 0.5
```
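For intuition about the `WIFV` setting: as we read the FLAP paper, WIFV (weighted input feature variance) scores each input channel `j` of a weight matrix `W` as the variance of that channel's activations over calibration samples multiplied by the squared L2 norm of column `j`. The pure-Python sketch below computes only that per-channel score and then keeps the top-scoring channels at a fixed ratio. It is a simplification for illustration, not the FLAP implementation: the `AL-AM` structure instead standardizes scores and prunes adaptively across layers and modules, and FLAP also adds bias compensation. The function names `wifv_scores` and `channels_to_keep` are hypothetical.

```python
from statistics import pvariance


def wifv_scores(activations, weight):
    """Per-input-channel score: Var(activation_j) * ||W[:, j]||^2.

    activations: calibration samples, each a list of in_features values.
    weight: out_features rows, each a list of in_features values.
    """
    in_features = len(weight[0])
    scores = []
    for j in range(in_features):
        channel = [sample[j] for sample in activations]
        col_norm_sq = sum(row[j] ** 2 for row in weight)
        scores.append(pvariance(channel) * col_norm_sq)
    return scores


def channels_to_keep(scores, pruning_ratio):
    """Indices of the highest-scoring channels after removing `pruning_ratio` of them.

    Simplification: a uniform ratio per matrix, not FLAP's adaptive AL-AM structure.
    """
    n_keep = len(scores) - int(len(scores) * pruning_ratio)
    ranked = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
    return sorted(ranked[:n_keep])


# Toy calibration data: 2 samples, 4 input channels; toy 2x4 weight matrix.
activations = [[1.0, 0.0, 2.0, 5.0], [3.0, 0.0, 2.0, 1.0]]
weight = [[1.0, 1.0, 1.0, 0.5], [0.0, 2.0, 1.0, 0.5]]
scores = wifv_scores(activations, weight)
print(scores)                         # [1.0, 0.0, 0.0, 2.0]
print(channels_to_keep(scores, 0.5))  # [0, 3]
```

Channels 1 and 2 score zero because their activations never vary, so with `pruning_ratio: 0.5` they are the ones removed even though channel 1 has the largest weight norm.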

## Limitation
This `llama3.1-41B-raw` model produces unstable output.
Fine-tuning on an instruction dataset is recommended.

The model is not currently supported by any library,
because pruning leaves the layers with inconsistent shapes (each layer keeps a different number of attention heads and MLP channels).
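One way to see the inconsistency is to read the per-layer widths from the checkpoint. The helper below is a hypothetical sketch: it takes a mapping of tensor names to `(out_features, in_features)` shapes (for example, collected with safetensors' `get_slice(name).get_shape()`) and derives each layer's attention head count and MLP width. The example shapes are made up; `head_dim=128` matches Llama 3.1 70B (hidden size 8192 divided by 64 heads).

```python
import re

# Per-layer projections whose output widths FLAP prunes.
Q_PROJ = re.compile(r"model\.layers\.(\d+)\.self_attn\.q_proj\.weight")
UP_PROJ = re.compile(r"model\.layers\.(\d+)\.mlp\.up_proj\.weight")


def layer_widths(weight_shapes, head_dim=128):
    """Return {layer_index: (num_attention_heads, mlp_intermediate_size)}.

    weight_shapes maps tensor names to (out_features, in_features) tuples.
    """
    heads, mlp = {}, {}
    for name, (out_features, _in_features) in weight_shapes.items():
        if m := Q_PROJ.fullmatch(name):
            if out_features % head_dim:
                raise ValueError(f"{name}: {out_features} is not a multiple of {head_dim}")
            heads[int(m.group(1))] = out_features // head_dim
        elif m := UP_PROJ.fullmatch(name):
            mlp[int(m.group(1))] = out_features
    return {i: (heads.get(i), mlp.get(i)) for i in sorted(heads.keys() | mlp.keys())}


# Hypothetical shapes for two pruned layers.
shapes = {
    "model.layers.0.self_attn.q_proj.weight": (4096, 8192),
    "model.layers.0.mlp.up_proj.weight": (14336, 8192),
    "model.layers.1.self_attn.q_proj.weight": (2560, 8192),
    "model.layers.1.mlp.up_proj.weight": (9000, 8192),
}
print(layer_widths(shapes))  # {0: (32, 14336), 1: (20, 9000)}
```

A standard Llama config has a single `num_attention_heads` and `intermediate_size` for all layers, so it cannot describe a layout like this; the layers have to be resized after the model is constructed.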

## Usage
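Since no library supports the model yet, the following workaround subclasses `LlamaForCausalLM`
and resizes each attention and MLP projection to match the pruned checkpoint before the weights are loaded.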
```python
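import json
import os
from functools import reduce

import torch
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from transformers import AutoTokenizer, LlamaForCausalLM


def get_module_by_name(module, access_string):
    """Resolve a dotted path such as 'model.layers.0.mlp.up_proj' to a submodule."""
    return reduce(getattr, access_string.split("."), module)


def resolve_file(name_or_path, filename):
    """Return a local path to `filename`, downloading it from the Hub if needed."""
    if os.path.isdir(name_or_path):
        return os.path.join(name_or_path, filename)
    return hf_hub_download(repo_id=name_or_path, filename=filename)


class MyLlamaForCausalLM(LlamaForCausalLM):
    """LlamaForCausalLM whose projections are resized to the pruned checkpoint shapes."""

    def __init__(self, config):
        super().__init__(config)
        index_file = resolve_file(config._name_or_path, "model.safetensors.index.json")
        with open(index_file) as f:
            weight_map = json.load(f)["weight_map"]

        # Pass 1: replace every attention/MLP projection whose shape differs from
        # the config with a Linear layer matching the checkpoint tensor.
        for name, shard in weight_map.items():
            if not name.endswith(".weight") or ("mlp." not in name and "attn." not in name):
                continue
            module_name = name[: -len(".weight")]
            layer = get_module_by_name(self, module_name)
            with safe_open(resolve_file(config._name_or_path, shard), framework="pt") as f:
                # Read only the shape; the tensor data itself is not loaded here.
                out_features, in_features = f.get_slice(name).get_shape()
            if (out_features, in_features) != (layer.out_features, layer.in_features):
                parent_name, child_name = module_name.rsplit(".", 1)
                setattr(
                    get_module_by_name(self, parent_name),
                    child_name,
                    torch.nn.Linear(
                        in_features,
                        out_features,
                        # Create a bias only if the checkpoint stores one for this layer.
                        bias=f"{module_name}.bias" in weight_map,
                        dtype=layer.weight.dtype,
                        device=layer.weight.device,
                    ),
                )

        # Pass 2: update each attention block's head counts to match its resized
        # q_proj. This assumes the pruned checkpoint keeps as many key/value heads
        # as query heads (no grouped-query sharing).
        for name in weight_map:
            if name.endswith("self_attn.q_proj.weight"):
                module = get_module_by_name(self, name[: -len(".q_proj.weight")])
                module.num_heads = module.q_proj.out_features // module.head_dim
                module.num_key_value_heads = module.num_heads
                module.num_key_value_groups = module.num_heads // module.num_key_value_heads


model = MyLlamaForCausalLM.from_pretrained(
    "npc0/llama3.1-41B-raw",
    torch_dtype=torch.float16,
    device_map="auto",
)
# Local FLAP output directory. Pruning does not change the vocabulary, so the
# base model's tokenizer ("meta-llama/Meta-Llama-3.1-70B-Instruct") should also work.
tokenizer = AutoTokenizer.from_pretrained(
    "FLAP/llm_weights/flap_p0.5_WIFV_ALAM_llama_70b")
model = model.eval()

messages = [
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"},
    {"role": "assistant", "content": "Sure! Here are some ways to eat bananas and dragonfruits together: 1. Banana and dragonfruit smoothie: Blend bananas and dragonfruits together with some milk and honey. 2. Banana and dragonfruit salad: Mix sliced bananas and dragonfruits together with some lemon juice and honey."},
    {"role": "user", "content": "What about solving the equation 2x + 3 = 7?"},
]

# add_generation_prompt=True appends the assistant header so the model answers.
model_inputs = tokenizer.apply_chat_template(messages,
                                             add_generation_prompt=True,
                                             return_tensors="pt").to(model.device)
generated_ids = model.generate(model_inputs, max_new_tokens=128)
decoded = tokenizer.batch_decode(generated_ids)
print(decoded[0])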
```