---
license: mit
datasets:
- blesspearl/stackexchange-math-sample
language:
- en
library_name: transformers
---

[Guide](https://medium.com/@rajatsharma_33357/fine-tuning-llama-using-lora-fb3f48a557d5)

# Fine-Tuned LLaMA 3.1 Model on Stack Exchange Math Dataset

This repository contains a LLaMA 3.1 model fine-tuned with LoRA on a dataset collected from Stack Exchange Math. The model is designed to answer mathematical questions in the style of accepted Stack Exchange answers.

## Model Details

- **Base Model:** [Meta-Llama-3.1-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B)
- **Fine-Tuned Model:** [math-stackexchange](https://huggingface.co/blesspearl/math-stackexchange)
- **Dataset:** [stackexchange-math-sample](https://huggingface.co/datasets/blesspearl/stackexchange-math-sample)
- **Training Environment:**
  - Framework: PyTorch with Transformers
  - Platform: Google Colab
  - Hardware: 1 x T4 GPU (15 GB)

## Data Preparation

The fine-tuning dataset consists of 1,000 samples collected from Stack Exchange Math, each pairing a question with its accepted answer.

### Preprocessing

The data was preprocessed in three steps:

1. Load the dataset from Hugging Face.
2. Shuffle the dataset and select 1,000 samples.
3. Format each sample into a chat template suitable for training (a sketch of the resulting format follows this list).

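For illustration only: `setup_chat_format` from trl (used in the training code below) applies a ChatML-style template, so a formatted training sample should look roughly like the following sketch (the question and answer here are invented, not taken from the dataset):

```
<|im_start|>user
How do I show that the sum of two even integers is even?<|im_end|>
<|im_start|>assistant
Write the two integers as 2a and 2b; their sum is 2(a + b), which is even.<|im_end|>
```
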
## Training Details

### Libraries and Dependencies

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, pipeline, logging
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from google.colab import drive, userdata
import os, torch, wandb
from trl import SFTTrainer, setup_chat_format
from huggingface_hub import login
```

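The snippet imports `login` but never shows the call; the gated Meta-Llama-3.1-8B weights require Hugging Face authentication before download. A minimal sketch (the `HF_TOKEN` secret name is an assumption, not from the original notebook):

```python
# Authenticate to the Hugging Face Hub so the gated base model can be downloaded.
# "HF_TOKEN" is a hypothetical Colab secret name; use whatever your notebook defines.
login(token=userdata.get("HF_TOKEN"))
```
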
### Loading Data and Model

```python
model_name = "meta-llama/Meta-Llama-3.1-8B"
dataset_name = "blesspearl/stackexchange-math-sample"

torch_dtype = torch.float16
attn_implementation = "eager"

# Log the run to Weights & Biases
wandb.login(key=userdata.get("WANDB_API_KEY"))
run = wandb.init(
    project='Fine-tuning LLama-3.1-8b on math-stack-exchange',
    job_type="training",
    anonymous="allow"
)

# Load the base model with 4-bit NF4 quantization to fit on a single T4
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Attach a ChatML chat template to the model and tokenizer
model, tokenizer = setup_chat_format(model, tokenizer)
```

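Note that `prepare_model_for_kbit_training` is imported above but never called in the snippet. When fine-tuning a 4-bit model it is commonly applied before attaching the LoRA adapter; a one-line sketch (an addition, not shown in the original script):

```python
# Casts layer norms/embeddings to fp32 and enables input-gradient hooks,
# which stabilizes k-bit training (common practice in QLoRA-style setups).
model = prepare_model_for_kbit_training(model)
```
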
### LoRA Configuration

```python
peft_config = LoraConfig(
    r=16,            # LoRA rank
    lora_alpha=32,   # scaling factor
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    # Apply LoRA to all attention and MLP projection matrices
    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)
model = get_peft_model(model, peft_config)
```

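To confirm how small the trainable footprint is under this configuration, peft's built-in helper can be called on the wrapped model:

```python
# Prints trainable vs. total parameter counts for the LoRA-wrapped model.
model.print_trainable_parameters()
```
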
### Data Preparation

```python
dataset = load_dataset(dataset_name, split="all")
dataset = dataset.shuffle(seed=65).select(range(1000))

# Render each question/answer pair through the tokenizer's chat template
def format_chat_template(row):
    row_json = [{"role": "user", "content": row["question_body"]},
                {"role": "assistant", "content": row["accepted_answer"]}]
    row["text"] = tokenizer.apply_chat_template(row_json, tokenize=False)
    return row

dataset = dataset.map(format_chat_template, num_proc=4)
dataset = dataset.train_test_split(test_size=0.2)
dataset = dataset.remove_columns(["question_body", "accepted_answer"])
```

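A quick way to verify that the template was applied correctly is to print one processed example:

```python
# Should show ChatML-style markers wrapping the question and answer.
print(dataset["train"][0]["text"])
```
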
### Training Configuration

```python
training_arguments = TrainingArguments(
    output_dir="math-stackexchange",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,   # effective batch size of 2
    optim="paged_adamw_32bit",
    num_train_epochs=1,
    evaluation_strategy="steps",
    eval_steps=0.2,                  # evaluate every 20% of training steps
    logging_steps=1,
    warmup_steps=10,
    logging_strategy="steps",
    learning_rate=2e-4,
    fp16=False,
    bf16=False,
    group_by_length=True,
    report_to="wandb"
)

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    peft_config=peft_config,
    max_seq_length=512,
    dataset_text_field="text",
    tokenizer=tokenizer,
    args=training_arguments,
    packing=False,
)
trainer.train()
wandb.finish()
# Enable the KV cache for generation after training
model.config.use_cache = True
```

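The snippet ends right after training; the README does not show how the adapter was saved or uploaded. A typical follow-up, included here only as a hedged sketch, would be:

```python
# Hypothetical persistence step, not part of the original script:
# save the LoRA adapter and tokenizer locally (push to the Hub if desired).
trainer.model.save_pretrained("math-stackexchange")
tokenizer.save_pretrained("math-stackexchange")
```
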
## Model and Dataset

- **Model:** [math-stackexchange](https://huggingface.co/blesspearl/math-stackexchange)
- **Dataset:** [stackexchange-math-sample](https://huggingface.co/datasets/blesspearl/stackexchange-math-sample)

## Usage

To use the fine-tuned model for inference, load it with the Hugging Face Transformers library and pass in your question.

### Example Code

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "blesspearl/math-stackexchange"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

def answer_question(question):
    # The model was trained on chat-formatted text, so apply the same
    # chat template at inference time instead of passing the raw question.
    messages = [{"role": "user", "content": question}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Without max_new_tokens, generate() falls back to a very short default length.
    outputs = model.generate(**inputs, max_new_tokens=256)
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return answer

question = "What is the derivative of sin(x)?"
answer = answer_question(question)
print(answer)
```

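If the Hub repository stores only the LoRA adapter rather than fully merged weights, `AutoModelForCausalLM.from_pretrained` on the adapter repo will not work on its own; in that case, load the base model first and attach the adapter with peft. A sketch under that assumption:

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumes adapter-only weights on the Hub; adjust if the repo ships merged weights.
base = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B", device_map="auto")
model = PeftModel.from_pretrained(base, "blesspearl/math-stackexchange")
tokenizer = AutoTokenizer.from_pretrained("blesspearl/math-stackexchange")
```
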
## Conclusion

This document describes the fine-tuning of LLaMA 3.1 with LoRA on the Stack Exchange Math dataset. Both the model and the dataset are available on Hugging Face for further use and exploration.

For questions or issues, feel free to open an issue on the [model repository](https://huggingface.co/blesspearl/math-stackexchange).