Upload folder using huggingface_hub
Browse files- .gitattributes +4 -0
- .ipynb_checkpoints/FalconDataSet-checkpoint.ipynb +394 -0
- .ipynb_checkpoints/language_modeling-checkpoint.ipynb +1186 -0
- .ipynb_checkpoints/language_modeling-checkpoint.py +187 -0
- FalconData.csv +3 -0
- FalconData2.csv +3 -0
- FalconDataSet.ipynb +717 -0
- FalconData_train.csv +3 -0
- FalconData_train2.csv +3 -0
- FalconData_validation.csv +0 -0
- FalconData_validation2.csv +0 -0
- LICENSE +21 -0
- README.md +71 -0
- language_modeling.ipynb +932 -0
- language_modeling.py +187 -0
- short_gpt/.ipynb_checkpoints/short_hf-checkpoint.ipynb +1679 -0
- short_gpt/.ipynb_checkpoints/short_llama-checkpoint.py +219 -0
- short_gpt/layer_removal.py +23 -0
- short_gpt/metrics.py +26 -0
- short_gpt/short_hf.ipynb +1679 -0
- short_gpt/short_llama.ipynb +573 -0
- short_gpt/short_llama.py +219 -0
.gitattributes
CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
FalconData.csv filter=lfs diff=lfs merge=lfs -text
|
37 |
+
FalconData2.csv filter=lfs diff=lfs merge=lfs -text
|
38 |
+
FalconData_train.csv filter=lfs diff=lfs merge=lfs -text
|
39 |
+
FalconData_train2.csv filter=lfs diff=lfs merge=lfs -text
|
.ipynb_checkpoints/FalconDataSet-checkpoint.ipynb
ADDED
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 7,
|
6 |
+
"id": "460d90da-b986-4c1c-8a66-eab144b0ba8d",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [
|
9 |
+
{
|
10 |
+
"name": "stdout",
|
11 |
+
"output_type": "stream",
|
12 |
+
"text": [
|
13 |
+
"Fetched data for all the Pages.\n",
|
14 |
+
"Fetched data for all the Pages.\n",
|
15 |
+
"Fetched data for all the Pages.\n",
|
16 |
+
"Fetched data for all the Pages.\n",
|
17 |
+
"Fetched data for all the Pages.\n",
|
18 |
+
"Fetched data for all the Pages.\n",
|
19 |
+
"Fetched data for all the Pages.\n",
|
20 |
+
"Fetched data for all the Pages.\n",
|
21 |
+
"Fetched data for all the Pages.\n",
|
22 |
+
"Fetched data for all the Pages.\n"
|
23 |
+
]
|
24 |
+
}
|
25 |
+
],
|
26 |
+
"source": [
|
27 |
+
"import requests\n",
|
28 |
+
"import time\n",
|
29 |
+
"\n",
|
30 |
+
"import random\n",
|
31 |
+
"pages = [\n",
|
32 |
+
" random.randint(1, 968000015)\n",
|
33 |
+
" for _ in range(10)\n",
|
34 |
+
" ]\n",
|
35 |
+
"# print(pages)\n",
|
36 |
+
"\n",
|
37 |
+
"base_url = \"https://datasets-server.huggingface.co/rows\"\n",
|
38 |
+
"params = {\n",
|
39 |
+
" \"dataset\": \"tiiuae/falcon-refinedweb\",\n",
|
40 |
+
" \"config\": \"default\",\n",
|
41 |
+
" \"split\": \"train\",\n",
|
42 |
+
" }\n",
|
43 |
+
"# response = requests.get(base_url, params=params)\n",
|
44 |
+
"# response.raise_for_status()\n",
|
45 |
+
"# for row in response.json()[\"rows\"]:\n",
|
46 |
+
"# content = row[\"row\"][\"content\"]\n",
|
47 |
+
"num_rows_per_page = 100\n",
|
48 |
+
"retry_limit = 10\n",
|
49 |
+
"retry_delay = 5\n",
|
50 |
+
"Falcon = []\n",
|
51 |
+
"\n",
|
52 |
+
"def fetch_data_for_page(page):\n",
|
53 |
+
" params[\"offset\"] = page\n",
|
54 |
+
" params[\"limit\"] = num_rows_per_page\n",
|
55 |
+
" attempt = 0\n",
|
56 |
+
" while attempt < retry_limit:\n",
|
57 |
+
" try:\n",
|
58 |
+
" response = requests.get(base_url, params=params)\n",
|
59 |
+
" response.raise_for_status() # This will raise an HTTPError if the HTTP request returned an unsuccessful status code\n",
|
60 |
+
" for row in response.json()[\"rows\"]:\n",
|
61 |
+
" content = row[\"row\"][\"content\"]\n",
|
62 |
+
" Falcon.append(content)\n",
|
63 |
+
" len(Falcon)\n",
|
64 |
+
" print(f\"Fetched data for all the Pages.\")\n",
|
65 |
+
" break\n",
|
66 |
+
" except requests.exceptions.HTTPError as e:\n",
|
67 |
+
" attempt += 1\n",
|
68 |
+
" print(\n",
|
69 |
+
" f\"Failed to fetch data, retrying. Attempt {attempt}/{retry_limit}\"\n",
|
70 |
+
" )\n",
|
71 |
+
" if attempt < retry_limit:\n",
|
72 |
+
" time.sleep(retry_delay) # Wait before the next retry\n",
|
73 |
+
" else:\n",
|
74 |
+
" print(\n",
|
75 |
+
" \"Maximum retry limit reached. Unable to fetch data.\"\n",
|
76 |
+
" )\n",
|
77 |
+
" raise\n",
|
78 |
+
"\n",
|
79 |
+
"for page in pages:\n",
|
80 |
+
" fetch_data_for_page(page)\n",
|
81 |
+
"\n"
|
82 |
+
]
|
83 |
+
},
|
84 |
+
{
|
85 |
+
"cell_type": "code",
|
86 |
+
"execution_count": 8,
|
87 |
+
"id": "f8f3baf1-5480-450b-a456-174a5c114d3e",
|
88 |
+
"metadata": {},
|
89 |
+
"outputs": [],
|
90 |
+
"source": [
|
91 |
+
"import csv\n",
|
92 |
+
"\n",
|
93 |
+
"# Open the CSV file for writing\n",
|
94 |
+
"with open(\"FalconDataEval2.csv\", \"w\", newline=\"\") as csvfile:\n",
|
95 |
+
" # Create a CSV writer object\n",
|
96 |
+
" writer = csv.writer(csvfile)\n",
|
97 |
+
"\n",
|
98 |
+
" # Write the header row\n",
|
99 |
+
" writer.writerow([\"Text\"])\n",
|
100 |
+
"\n",
|
101 |
+
" # Write each element in the list as a row in the CSV file\n",
|
102 |
+
" for element in Falcon:\n",
|
103 |
+
" writer.writerow([element])\n"
|
104 |
+
]
|
105 |
+
},
|
106 |
+
{
|
107 |
+
"cell_type": "code",
|
108 |
+
"execution_count": 9,
|
109 |
+
"id": "ea47c936-2c2b-4414-ba57-74fb6827ec0a",
|
110 |
+
"metadata": {},
|
111 |
+
"outputs": [
|
112 |
+
{
|
113 |
+
"name": "stdout",
|
114 |
+
"output_type": "stream",
|
115 |
+
"text": [
|
116 |
+
"Number of duplicate rows: 0\n",
|
117 |
+
"Empty DataFrame\n",
|
118 |
+
"Columns: [Text]\n",
|
119 |
+
"Index: []\n"
|
120 |
+
]
|
121 |
+
}
|
122 |
+
],
|
123 |
+
"source": [
|
124 |
+
"import pandas as pd\n",
|
125 |
+
"\n",
|
126 |
+
"# Read the CSV file into a pandas DataFrame\n",
|
127 |
+
"df = pd.read_csv(\"FalconDataEval2.csv\")\n",
|
128 |
+
"\n",
|
129 |
+
"# Check for duplicate rows\n",
|
130 |
+
"duplicate_rows = df[df.duplicated()]\n",
|
131 |
+
"\n",
|
132 |
+
"# Print the number of duplicate rows\n",
|
133 |
+
"print(f\"Number of duplicate rows: {len(duplicate_rows)}\")\n",
|
134 |
+
"\n",
|
135 |
+
"# Print the duplicate rows\n",
|
136 |
+
"print(duplicate_rows)"
|
137 |
+
]
|
138 |
+
},
|
139 |
+
{
|
140 |
+
"cell_type": "code",
|
141 |
+
"execution_count": 10,
|
142 |
+
"id": "f4178cd6-747f-4e05-a9bf-17b97f959e06",
|
143 |
+
"metadata": {},
|
144 |
+
"outputs": [
|
145 |
+
{
|
146 |
+
"data": {
|
147 |
+
"text/html": [
|
148 |
+
"<div>\n",
|
149 |
+
"<style scoped>\n",
|
150 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
151 |
+
" vertical-align: middle;\n",
|
152 |
+
" }\n",
|
153 |
+
"\n",
|
154 |
+
" .dataframe tbody tr th {\n",
|
155 |
+
" vertical-align: top;\n",
|
156 |
+
" }\n",
|
157 |
+
"\n",
|
158 |
+
" .dataframe thead th {\n",
|
159 |
+
" text-align: right;\n",
|
160 |
+
" }\n",
|
161 |
+
"</style>\n",
|
162 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
163 |
+
" <thead>\n",
|
164 |
+
" <tr style=\"text-align: right;\">\n",
|
165 |
+
" <th></th>\n",
|
166 |
+
" <th>Text</th>\n",
|
167 |
+
" </tr>\n",
|
168 |
+
" </thead>\n",
|
169 |
+
" <tbody>\n",
|
170 |
+
" <tr>\n",
|
171 |
+
" <th>0</th>\n",
|
172 |
+
" <td>Our Annual Garden Party is a fun-filled event ...</td>\n",
|
173 |
+
" </tr>\n",
|
174 |
+
" <tr>\n",
|
175 |
+
" <th>1</th>\n",
|
176 |
+
" <td>Photos by Philip Cosores\\n“There were many poi...</td>\n",
|
177 |
+
" </tr>\n",
|
178 |
+
" <tr>\n",
|
179 |
+
" <th>2</th>\n",
|
180 |
+
" <td>Media Matters Also Wants To Throw Out The Firs...</td>\n",
|
181 |
+
" </tr>\n",
|
182 |
+
" <tr>\n",
|
183 |
+
" <th>3</th>\n",
|
184 |
+
" <td>[More]\\nWhile bringing in your own cup is fine...</td>\n",
|
185 |
+
" </tr>\n",
|
186 |
+
" <tr>\n",
|
187 |
+
" <th>4</th>\n",
|
188 |
+
" <td>Read at : Google Alert – gardening\\nHow to Bui...</td>\n",
|
189 |
+
" </tr>\n",
|
190 |
+
" </tbody>\n",
|
191 |
+
"</table>\n",
|
192 |
+
"</div>"
|
193 |
+
],
|
194 |
+
"text/plain": [
|
195 |
+
" Text\n",
|
196 |
+
"0 Our Annual Garden Party is a fun-filled event ...\n",
|
197 |
+
"1 Photos by Philip Cosores\\n“There were many poi...\n",
|
198 |
+
"2 Media Matters Also Wants To Throw Out The Firs...\n",
|
199 |
+
"3 [More]\\nWhile bringing in your own cup is fine...\n",
|
200 |
+
"4 Read at : Google Alert – gardening\\nHow to Bui..."
|
201 |
+
]
|
202 |
+
},
|
203 |
+
"execution_count": 10,
|
204 |
+
"metadata": {},
|
205 |
+
"output_type": "execute_result"
|
206 |
+
}
|
207 |
+
],
|
208 |
+
"source": [
|
209 |
+
"df.head()"
|
210 |
+
]
|
211 |
+
},
|
212 |
+
{
|
213 |
+
"cell_type": "code",
|
214 |
+
"execution_count": 11,
|
215 |
+
"id": "264548c1-4cf4-441f-a433-2f5d57861dc4",
|
216 |
+
"metadata": {},
|
217 |
+
"outputs": [
|
218 |
+
{
|
219 |
+
"data": {
|
220 |
+
"text/html": [
|
221 |
+
"<div>\n",
|
222 |
+
"<style scoped>\n",
|
223 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
224 |
+
" vertical-align: middle;\n",
|
225 |
+
" }\n",
|
226 |
+
"\n",
|
227 |
+
" .dataframe tbody tr th {\n",
|
228 |
+
" vertical-align: top;\n",
|
229 |
+
" }\n",
|
230 |
+
"\n",
|
231 |
+
" .dataframe thead th {\n",
|
232 |
+
" text-align: right;\n",
|
233 |
+
" }\n",
|
234 |
+
"</style>\n",
|
235 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
236 |
+
" <thead>\n",
|
237 |
+
" <tr style=\"text-align: right;\">\n",
|
238 |
+
" <th></th>\n",
|
239 |
+
" <th>Text</th>\n",
|
240 |
+
" </tr>\n",
|
241 |
+
" </thead>\n",
|
242 |
+
" <tbody>\n",
|
243 |
+
" <tr>\n",
|
244 |
+
" <th>995</th>\n",
|
245 |
+
" <td>The Ketologic review Diaries\\nShould you have ...</td>\n",
|
246 |
+
" </tr>\n",
|
247 |
+
" <tr>\n",
|
248 |
+
" <th>996</th>\n",
|
249 |
+
" <td>A pack of hand cooked sea salted and red wine ...</td>\n",
|
250 |
+
" </tr>\n",
|
251 |
+
" <tr>\n",
|
252 |
+
" <th>997</th>\n",
|
253 |
+
" <td>この広告は、90日以上更新していないブログに表示しています。\\nsniperspy free...</td>\n",
|
254 |
+
" </tr>\n",
|
255 |
+
" <tr>\n",
|
256 |
+
" <th>998</th>\n",
|
257 |
+
" <td>Arthur Koestler - Wikipedia.\\nEssay - Merriam-...</td>\n",
|
258 |
+
" </tr>\n",
|
259 |
+
" <tr>\n",
|
260 |
+
" <th>999</th>\n",
|
261 |
+
" <td>Serving Software Downloads in 976 Categories, ...</td>\n",
|
262 |
+
" </tr>\n",
|
263 |
+
" </tbody>\n",
|
264 |
+
"</table>\n",
|
265 |
+
"</div>"
|
266 |
+
],
|
267 |
+
"text/plain": [
|
268 |
+
" Text\n",
|
269 |
+
"995 The Ketologic review Diaries\\nShould you have ...\n",
|
270 |
+
"996 A pack of hand cooked sea salted and red wine ...\n",
|
271 |
+
"997 この広告は、90日以上更新していないブログに表示しています。\\nsniperspy free...\n",
|
272 |
+
"998 Arthur Koestler - Wikipedia.\\nEssay - Merriam-...\n",
|
273 |
+
"999 Serving Software Downloads in 976 Categories, ..."
|
274 |
+
]
|
275 |
+
},
|
276 |
+
"execution_count": 11,
|
277 |
+
"metadata": {},
|
278 |
+
"output_type": "execute_result"
|
279 |
+
}
|
280 |
+
],
|
281 |
+
"source": [
|
282 |
+
"df.tail()"
|
283 |
+
]
|
284 |
+
},
|
285 |
+
{
|
286 |
+
"cell_type": "code",
|
287 |
+
"execution_count": 6,
|
288 |
+
"id": "be5a87a8-cfee-4f63-992e-8fa1d4a5cdbb",
|
289 |
+
"metadata": {},
|
290 |
+
"outputs": [
|
291 |
+
{
|
292 |
+
"data": {
|
293 |
+
"text/plain": [
|
294 |
+
"Text To imagine delaying myself. Hard cock, selling...\n",
|
295 |
+
"Name: 48, dtype: object"
|
296 |
+
]
|
297 |
+
},
|
298 |
+
"execution_count": 6,
|
299 |
+
"metadata": {},
|
300 |
+
"output_type": "execute_result"
|
301 |
+
}
|
302 |
+
],
|
303 |
+
"source": [
|
304 |
+
"target_row=48\n",
|
305 |
+
"specific_row = df.iloc[target_row]\n",
|
306 |
+
"specific_row"
|
307 |
+
]
|
308 |
+
},
|
309 |
+
{
|
310 |
+
"cell_type": "code",
|
311 |
+
"execution_count": 13,
|
312 |
+
"id": "e97d9e18-eaa0-4a1b-96ab-c89a0f4c738d",
|
313 |
+
"metadata": {},
|
314 |
+
"outputs": [
|
315 |
+
{
|
316 |
+
"name": "stdout",
|
317 |
+
"output_type": "stream",
|
318 |
+
"text": [
|
319 |
+
"Text The old wireline Bell telephone system was bui...\n",
|
320 |
+
"Name: 19995, dtype: object\n"
|
321 |
+
]
|
322 |
+
}
|
323 |
+
],
|
324 |
+
"source": [
|
325 |
+
"print(specific_row)"
|
326 |
+
]
|
327 |
+
},
|
328 |
+
{
|
329 |
+
"cell_type": "code",
|
330 |
+
"execution_count": 14,
|
331 |
+
"id": "940ef35f-7517-403d-9f42-73760182dcaa",
|
332 |
+
"metadata": {},
|
333 |
+
"outputs": [
|
334 |
+
{
|
335 |
+
"name": "stdout",
|
336 |
+
"output_type": "stream",
|
337 |
+
"text": [
|
338 |
+
"Text The old wireline Bell telephone system was bui...\n"
|
339 |
+
]
|
340 |
+
}
|
341 |
+
],
|
342 |
+
"source": [
|
343 |
+
"print(specific_row.to_string())"
|
344 |
+
]
|
345 |
+
},
|
346 |
+
{
|
347 |
+
"cell_type": "code",
|
348 |
+
"execution_count": 17,
|
349 |
+
"id": "915ac669-718f-47f5-b175-a5f928b407db",
|
350 |
+
"metadata": {},
|
351 |
+
"outputs": [
|
352 |
+
{
|
353 |
+
"name": "stdout",
|
354 |
+
"output_type": "stream",
|
355 |
+
"text": [
|
356 |
+
"57\n"
|
357 |
+
]
|
358 |
+
}
|
359 |
+
],
|
360 |
+
"source": [
|
361 |
+
"print(len(specific_row.to_string()))"
|
362 |
+
]
|
363 |
+
},
|
364 |
+
{
|
365 |
+
"cell_type": "code",
|
366 |
+
"execution_count": null,
|
367 |
+
"id": "ab5ee254-9ba7-496b-97c7-3b6185c21971",
|
368 |
+
"metadata": {},
|
369 |
+
"outputs": [],
|
370 |
+
"source": []
|
371 |
+
}
|
372 |
+
],
|
373 |
+
"metadata": {
|
374 |
+
"kernelspec": {
|
375 |
+
"display_name": "Python 3 (ipykernel)",
|
376 |
+
"language": "python",
|
377 |
+
"name": "python3"
|
378 |
+
},
|
379 |
+
"language_info": {
|
380 |
+
"codemirror_mode": {
|
381 |
+
"name": "ipython",
|
382 |
+
"version": 3
|
383 |
+
},
|
384 |
+
"file_extension": ".py",
|
385 |
+
"mimetype": "text/x-python",
|
386 |
+
"name": "python",
|
387 |
+
"nbconvert_exporter": "python",
|
388 |
+
"pygments_lexer": "ipython3",
|
389 |
+
"version": "3.10.12"
|
390 |
+
}
|
391 |
+
},
|
392 |
+
"nbformat": 4,
|
393 |
+
"nbformat_minor": 5
|
394 |
+
}
|
.ipynb_checkpoints/language_modeling-checkpoint.ipynb
ADDED
@@ -0,0 +1,1186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 2,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [
|
8 |
+
{
|
9 |
+
"name": "stdout",
|
10 |
+
"output_type": "stream",
|
11 |
+
"text": [
|
12 |
+
"Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.40.2)\n",
|
13 |
+
"Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (2.19.1)\n",
|
14 |
+
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.13.1)\n",
|
15 |
+
"Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.23.0)\n",
|
16 |
+
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.26.2)\n",
|
17 |
+
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.2)\n",
|
18 |
+
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n",
|
19 |
+
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2024.4.28)\n",
|
20 |
+
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n",
|
21 |
+
"Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.1)\n",
|
22 |
+
"Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.3)\n",
|
23 |
+
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.2)\n",
|
24 |
+
"Requirement already satisfied: pyarrow>=12.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (16.0.0)\n",
|
25 |
+
"Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets) (0.6)\n",
|
26 |
+
"Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.8)\n",
|
27 |
+
"Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.2.2)\n",
|
28 |
+
"Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.4.1)\n",
|
29 |
+
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.16)\n",
|
30 |
+
"Requirement already satisfied: fsspec<=2024.3.1,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.3.1,>=2023.1.0->datasets) (2023.10.0)\n",
|
31 |
+
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.9.0b0)\n",
|
32 |
+
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.1.0)\n",
|
33 |
+
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.5)\n",
|
34 |
+
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.4)\n",
|
35 |
+
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n",
|
36 |
+
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n",
|
37 |
+
"Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n",
|
38 |
+
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.19.3->transformers) (4.8.0)\n",
|
39 |
+
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n",
|
40 |
+
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.6)\n",
|
41 |
+
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.1.0)\n",
|
42 |
+
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2023.11.17)\n",
|
43 |
+
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n",
|
44 |
+
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n",
|
45 |
+
"Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n",
|
46 |
+
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n",
|
47 |
+
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
|
48 |
+
"\u001b[0m"
|
49 |
+
]
|
50 |
+
}
|
51 |
+
],
|
52 |
+
"source": [
|
53 |
+
"# Transformers installation\n",
|
54 |
+
"# ! pip install transformers datasets\n",
|
55 |
+
"# To install from source instead of the last release, comment the command above and uncomment the following one.\n",
|
56 |
+
"# ! pip install git+https://github.com/huggingface/transformers.git"
|
57 |
+
]
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"cell_type": "markdown",
|
61 |
+
"metadata": {},
|
62 |
+
"source": [
|
63 |
+
"# Causal language modeling"
|
64 |
+
]
|
65 |
+
},
|
66 |
+
{
|
67 |
+
"cell_type": "markdown",
|
68 |
+
"metadata": {},
|
69 |
+
"source": [
|
70 |
+
"There are two types of language modeling, causal and masked. This guide illustrates causal language modeling.\n",
|
71 |
+
"Causal language models are frequently used for text generation. You can use these models for creative applications like\n",
|
72 |
+
"choosing your own text adventure or an intelligent coding assistant like Copilot or CodeParrot."
|
73 |
+
]
|
74 |
+
},
|
75 |
+
{
|
76 |
+
"cell_type": "code",
|
77 |
+
"execution_count": 4,
|
78 |
+
"metadata": {
|
79 |
+
"cellView": "form",
|
80 |
+
"hide_input": true
|
81 |
+
},
|
82 |
+
"outputs": [],
|
83 |
+
"source": [
|
84 |
+
"# #@title\n",
|
85 |
+
"# from IPython.display import HTML\n",
|
86 |
+
"\n",
|
87 |
+
"# HTML('<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/Vpjb1lu0MDk?rel=0&controls=0&showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>')"
|
88 |
+
]
|
89 |
+
},
|
90 |
+
{
|
91 |
+
"cell_type": "markdown",
|
92 |
+
"metadata": {},
|
93 |
+
"source": [
|
94 |
+
"Causal language modeling predicts the next token in a sequence of tokens, and the model can only attend to tokens on\n",
|
95 |
+
"the left. This means the model cannot see future tokens. GPT-2 is an example of a causal language model.\n",
|
96 |
+
"\n",
|
97 |
+
"This guide will show you how to:\n",
|
98 |
+
"\n",
|
99 |
+
"1. Finetune [DistilGPT2](https://huggingface.co/distilgpt2) on the [r/askscience](https://www.reddit.com/r/askscience/) subset of the [ELI5](https://huggingface.co/datasets/eli5) dataset.\n",
|
100 |
+
"2. Use your finetuned model for inference.\n",
|
101 |
+
"\n",
|
102 |
+
"<Tip>\n",
|
103 |
+
"You can finetune other architectures for causal language modeling following the same steps in this guide.\n",
|
104 |
+
"Choose one of the following architectures:\n",
|
105 |
+
"\n",
|
106 |
+
"<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->\n",
|
107 |
+
"[BART](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/bart), [BERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/bert), [Bert Generation](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/bert-generation), [BigBird](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/big_bird), [BigBird-Pegasus](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/bigbird_pegasus), [BioGpt](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/biogpt), [Blenderbot](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/blenderbot), [BlenderbotSmall](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/blenderbot-small), [BLOOM](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/bloom), [CamemBERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/camembert), [CodeGen](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/codegen), [CPM-Ant](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/cpmant), [CTRL](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/ctrl), [Data2VecText](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/data2vec-text), [ELECTRA](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/electra), [ERNIE](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/ernie), [GIT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/git), [GPT-Sw3](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gpt-sw3), [OpenAI GPT-2](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gpt2), [GPTBigCode](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gpt_bigcode), [GPT Neo](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gpt_neo), [GPT NeoX](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gpt_neox), [GPT NeoX Japanese](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gpt_neox_japanese), [GPT-J](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gptj), [LLaMA](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/llama), [Marian](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/marian), [mBART](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/mbart), [MEGA](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/mega), [Megatron-BERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/megatron-bert), [MVP](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/mvp), [OpenLlama](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/open-llama), [OpenAI GPT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/openai-gpt), [OPT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/opt), [Pegasus](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/pegasus), [PLBart](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/plbart), [ProphetNet](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/prophetnet), [QDQBert](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/qdqbert), [Reformer](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/reformer), [RemBERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/rembert), [RoBERTa](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/roberta), [RoBERTa-PreLayerNorm](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/roberta-prelayernorm), [RoCBert](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/roc_bert), [RoFormer](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/roformer), [RWKV](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/rwkv), [Speech2Text2](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/speech_to_text_2), [Transformer-XL](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/transfo-xl), [TrOCR](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/trocr), [XGLM](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xglm), [XLM](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xlm), [XLM-ProphetNet](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xlm-prophetnet), [XLM-RoBERTa](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xlm-roberta), [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xlm-roberta-xl), [XLNet](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xlnet), [X-MOD](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xmod)\n",
|
108 |
+
"\n",
|
109 |
+
"\n",
|
110 |
+
"<!--End of the generated tip-->\n",
|
111 |
+
"\n",
|
112 |
+
"</Tip>\n",
|
113 |
+
"\n",
|
114 |
+
"Before you begin, make sure you have all the necessary libraries installed:\n",
|
115 |
+
"\n",
|
116 |
+
"```bash\n",
|
117 |
+
"pip install transformers datasets evaluate\n",
|
118 |
+
"```\n",
|
119 |
+
"\n",
|
120 |
+
"We encourage you to log in to your Hugging Face account so you can upload and share your model with the community. When prompted, enter your token to log in:"
|
121 |
+
]
|
122 |
+
},
|
123 |
+
{
|
124 |
+
"cell_type": "code",
|
125 |
+
"execution_count": 5,
|
126 |
+
"metadata": {},
|
127 |
+
"outputs": [
|
128 |
+
{
|
129 |
+
"data": {
|
130 |
+
"application/vnd.jupyter.widget-view+json": {
|
131 |
+
"model_id": "a6d9e280e08e40ddbbcb8fbe97e1fae9",
|
132 |
+
"version_major": 2,
|
133 |
+
"version_minor": 0
|
134 |
+
},
|
135 |
+
"text/plain": [
|
136 |
+
"VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
|
137 |
+
]
|
138 |
+
},
|
139 |
+
"metadata": {},
|
140 |
+
"output_type": "display_data"
|
141 |
+
}
|
142 |
+
],
|
143 |
+
"source": [
|
144 |
+
"# from huggingface_hub import notebook_login\n",
|
145 |
+
"\n",
|
146 |
+
"# notebook_login()"
|
147 |
+
]
|
148 |
+
},
|
149 |
+
{
|
150 |
+
"cell_type": "markdown",
|
151 |
+
"metadata": {},
|
152 |
+
"source": [
|
153 |
+
"## Load ELI5 dataset"
|
154 |
+
]
|
155 |
+
},
|
156 |
+
{
|
157 |
+
"cell_type": "markdown",
|
158 |
+
"metadata": {},
|
159 |
+
"source": [
|
160 |
+
"Start by loading a smaller subset of the r/askscience subset of the ELI5 dataset from the 🤗 Datasets library.\n",
|
161 |
+
" This'll give you a chance to experiment and make sure everything works before spending more time training on the full dataset."
|
162 |
+
]
|
163 |
+
},
|
164 |
+
{
|
165 |
+
"cell_type": "code",
|
166 |
+
"execution_count": null,
|
167 |
+
"metadata": {},
|
168 |
+
"outputs": [],
|
169 |
+
"source": [
|
170 |
+
"# from datasets import load_dataset\n",
|
171 |
+
"\n",
|
172 |
+
"# eli5 = load_dataset(\"eli5\", split=\"train_asks[:5000]\")"
|
173 |
+
]
|
174 |
+
},
|
175 |
+
{
|
176 |
+
"cell_type": "code",
|
177 |
+
"execution_count": 1,
|
178 |
+
"metadata": {},
|
179 |
+
"outputs": [
|
180 |
+
{
|
181 |
+
"data": {
|
182 |
+
"application/vnd.jupyter.widget-view+json": {
|
183 |
+
"model_id": "e5c92a52c290468496943cb8023e4479",
|
184 |
+
"version_major": 2,
|
185 |
+
"version_minor": 0
|
186 |
+
},
|
187 |
+
"text/plain": [
|
188 |
+
"Generating train split: 0 examples [00:00, ? examples/s]"
|
189 |
+
]
|
190 |
+
},
|
191 |
+
"metadata": {},
|
192 |
+
"output_type": "display_data"
|
193 |
+
},
|
194 |
+
{
|
195 |
+
"data": {
|
196 |
+
"application/vnd.jupyter.widget-view+json": {
|
197 |
+
"model_id": "cf14d12614594f51b63d4aa8259d278f",
|
198 |
+
"version_major": 2,
|
199 |
+
"version_minor": 0
|
200 |
+
},
|
201 |
+
"text/plain": [
|
202 |
+
"Generating validation split: 0 examples [00:00, ? examples/s]"
|
203 |
+
]
|
204 |
+
},
|
205 |
+
"metadata": {},
|
206 |
+
"output_type": "display_data"
|
207 |
+
}
|
208 |
+
],
|
209 |
+
"source": [
|
210 |
+
"from datasets import load_dataset\n",
|
211 |
+
"# Falcon = load_dataset(\"csv\", data_files=\"FalconData.csv\")\n",
|
212 |
+
"Falcon = load_dataset('csv', data_files={\"train\": 'FalconData.csv', \"validation\": 'FalconDataEval.csv'})"
|
213 |
+
]
|
214 |
+
},
|
215 |
+
{
|
216 |
+
"cell_type": "markdown",
|
217 |
+
"metadata": {},
|
218 |
+
"source": [
|
219 |
+
"Split the dataset's `train_asks` split into a train and test set with the [train_test_split](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset.train_test_split) method:"
|
220 |
+
]
|
221 |
+
},
|
222 |
+
{
|
223 |
+
"cell_type": "code",
|
224 |
+
"execution_count": 8,
|
225 |
+
"metadata": {},
|
226 |
+
"outputs": [],
|
227 |
+
"source": [
|
228 |
+
"# Falcon = Falcon.train_test_split(test_size=0.10)"
|
229 |
+
]
|
230 |
+
},
|
231 |
+
{
|
232 |
+
"cell_type": "markdown",
|
233 |
+
"metadata": {},
|
234 |
+
"source": [
|
235 |
+
"Then take a look at an example:"
|
236 |
+
]
|
237 |
+
},
|
238 |
+
{
|
239 |
+
"cell_type": "code",
|
240 |
+
"execution_count": 2,
|
241 |
+
"metadata": {},
|
242 |
+
"outputs": [
|
243 |
+
{
|
244 |
+
"data": {
|
245 |
+
"text/plain": [
|
246 |
+
"{'Text': 'Allow me to clarify a genuine fast for amateur online users What exactly is Youtube . com? Youtube . com is probably the most in-demand web site on the web which allow you to view and publish video lessons for free. These are generally submitted by Vimeo members on this video discussing system. Yet another thing YouTube registration is provided for free so anyone can join, however account is not required for watching video lessons. In order to sometimes observe video clips or post your own video lessons so that you can show to your friends, loved ones as well as other Vimeo members. Once you get dependent at viewing video clip, it is possible to phone yourself a YouTuber!\\n- Everything you are unable to upload? Nonetheless there are some regulations or YouTube\\'s regards to use that you should.\\n- Observing a Vimeo movie is really simple, you just need to.\\nObserving a You tube movie is absolutely simple, you just need to variety your best song or television set plan from the research discipline click on \"Research\" option and that\\'s it. It will approach your demand and give you a list of related results. You are able to click on a outcome and this will commence taking part in the recording. youtube downloader\\nAble to click on a outcome and\\nSo, just how to publish your chosen videos? Youtube . com is very popular online video discussing foundation that allows one to publish their video lessons. Uploading a relevant video online is an easy process, just select any video submit through your computer on your YouTube accounts webpage and it will surely begin posting the video. Nonetheless Vimeo will not offer any choice to down load a printed video that you will be seeing, you can easily take note of the site Link so that you can view it later, which seems handy for YouTube users.\\nEverything you cannot upload? Nevertheless there are a few regulations or YouTube\\'s terms of use that you have to comply with, specifically you happen to be unacceptable to upload any restricted content or erotic information. Nevertheless you can use it to showcase your products online.\\nA few regulations or\\nOnline video good quality once you upload Vimeo permits you to post all popular movie formats and produces good quality probable. Whenever you post a youtube video to Youtube . com, you ought to anticipate that high quality will slightly be changed, it is because YouTube optimizes the video for speedier packing. You can even add Hi-def or Hi-def video lessons nevertheless it will take much longer to weight once you observe it. Greater the high quality more slowly movie will load.\\nYou upload Vimeo\\nProbably the most well-known movie web sites online is You tube as well as for certain, you can find videos inside the web site you want to create you everywhere and adding it inside your PSP device might be what you need. However, YouTube video lessons will not be quickly down loadable. You might need a downloader to download the recording through the website and shop it inside your personal computer. video downloader\\nAfter you have saved the recording, it may possibly not certainly be around the preferred format which can be legible along with your Playstation portable. For those who have saved a structure not in mp4, you may want to transform the submit with your Computer in to a Playstation portable-pleasant structure. You may need a video clip converter for this task, and when you have changed the video tutorials, anyone can down load these to your Playstation portable.\\nWith your Playstation portable For those\\nIn accessing, simply link up your Playstation portable to the laptop or computer by means of its cord, use the Universal serial bus setting and download the video lessons and music that you want to bring along.\\nThat will help you look for a converter or a video downloader, specifically if you want to obtain video clips from Vimeo, be involved in forums and discover topics relevant to this. Certainly, you will also find a great deal of PSP movie information that may also assist you in making the best from your gadget and help you learn to see a number of videos on your gadget.\\nAlso find a great deal of PSP\\nYou can even get into membership web sites where PSP enthusiast collect and discuss information and facts and even more importantly, offers you the tools and software program that you will want to save music, videos and media records to your devices and permit you to enjoy the gizmo a lot more. Although these membership internet sites require only a minimum cost, it really is however vital that you are working with and creating dealings in a guaranteed and harmless internet site.\\n- You can even get into membership websites.\\n- One of the more preferred video clip sites on.\\n- Video quality when you post Vimeo enables you.\\n0 thoughts on “The Most Effective and Well-liked you tube downloader6675”'}"
|
247 |
+
]
|
248 |
+
},
|
249 |
+
"execution_count": 2,
|
250 |
+
"metadata": {},
|
251 |
+
"output_type": "execute_result"
|
252 |
+
}
|
253 |
+
],
|
254 |
+
"source": [
|
255 |
+
"Falcon['train'][0]"
|
256 |
+
]
|
257 |
+
},
|
258 |
+
{
|
259 |
+
"cell_type": "code",
|
260 |
+
"execution_count": 3,
|
261 |
+
"metadata": {},
|
262 |
+
"outputs": [
|
263 |
+
{
|
264 |
+
"data": {
|
265 |
+
"text/plain": [
|
266 |
+
"{'Text': 'For some reason, removing motor grease from cotton-poly blend is perceived as one of the more difficult laundry problems out there. The truth is, that there are several methods that you can use to get rid of this type of stain, which are listed here. While some of these methods may seem a little strange, each and every one of them will work. All you need to do is be willing to try it. If you are hesitant about using any of these methods at all, be sure to test them out on a similar piece of fabric to see what the end result will be. If there is any damage to your particular piece of fabric, than do not use the method to happen to have a few white t-shirts, blouses, or button-up shirts, then chances are you know the pain of having to ...Discover More\\nTablecloths are not cheap, and it is always a great idea to protect anything that is expensive. Cleaning tablecloths is ...Discover More\\nWhile it can be annoying to find that your white apparel and linens have turned yellow in the laundry, it no longer needs ...Discover More\\nFREE SERVICE: Get tips like this every week in Cleaning Tips from Tips.Net. Enter your address and click \"Subscribe.\"\\nView most recent newsletter.\\n2015-08-29 08:54:35\\nJune\\nComing from a long line of mechanics, I\\'ve always kept a bottle of LESTOIL around...works GREAT on auto grease, and cooking grease as well, just follow the directions on the bottle.\\nFREE SERVICE: Get tips like this every week in Cleaning Tips from Tips.Net. Enter your address and click \"Subscribe.\"\\n(Your e-mail address is not shared with anyone, ever.)\\nView the most recent newsletter.'}"
|
267 |
+
]
|
268 |
+
},
|
269 |
+
"execution_count": 3,
|
270 |
+
"metadata": {},
|
271 |
+
"output_type": "execute_result"
|
272 |
+
}
|
273 |
+
],
|
274 |
+
"source": [
|
275 |
+
"Falcon['validation'][0]"
|
276 |
+
]
|
277 |
+
},
|
278 |
+
{
|
279 |
+
"cell_type": "markdown",
|
280 |
+
"metadata": {},
|
281 |
+
"source": [
|
282 |
+
"While this may look like a lot, you're only really interested in the `text` field. What's cool about language modeling\n",
|
283 |
+
"tasks is you don't need labels (also known as an unsupervised task) because the next word *is* the label."
|
284 |
+
]
|
285 |
+
},
|
286 |
+
{
|
287 |
+
"cell_type": "markdown",
|
288 |
+
"metadata": {},
|
289 |
+
"source": [
|
290 |
+
"## Preprocess"
|
291 |
+
]
|
292 |
+
},
|
293 |
+
{
|
294 |
+
"cell_type": "code",
|
295 |
+
"execution_count": null,
|
296 |
+
"metadata": {
|
297 |
+
"cellView": "form",
|
298 |
+
"hide_input": true
|
299 |
+
},
|
300 |
+
"outputs": [
|
301 |
+
{
|
302 |
+
"data": {
|
303 |
+
"text/html": [
|
304 |
+
"<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/ma1TrR7gE7I?rel=0&controls=0&showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>"
|
305 |
+
],
|
306 |
+
"text/plain": [
|
307 |
+
"<IPython.core.display.HTML object>"
|
308 |
+
]
|
309 |
+
},
|
310 |
+
"execution_count": null,
|
311 |
+
"metadata": {},
|
312 |
+
"output_type": "execute_result"
|
313 |
+
}
|
314 |
+
],
|
315 |
+
"source": [
|
316 |
+
"# #@title\n",
|
317 |
+
"# from IPython.display import HTML\n",
|
318 |
+
"\n",
|
319 |
+
"# HTML('<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/ma1TrR7gE7I?rel=0&controls=0&showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>')"
|
320 |
+
]
|
321 |
+
},
|
322 |
+
{
|
323 |
+
"cell_type": "markdown",
|
324 |
+
"metadata": {},
|
325 |
+
"source": [
|
326 |
+
"The next step is to load a DistilGPT2 tokenizer to process the `text` subfield:"
|
327 |
+
]
|
328 |
+
},
|
329 |
+
{
|
330 |
+
"cell_type": "code",
|
331 |
+
"execution_count": 4,
|
332 |
+
"metadata": {},
|
333 |
+
"outputs": [
|
334 |
+
{
|
335 |
+
"name": "stderr",
|
336 |
+
"output_type": "stream",
|
337 |
+
"text": [
|
338 |
+
"/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
339 |
+
" warnings.warn(\n",
|
340 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
341 |
+
]
|
342 |
+
}
|
343 |
+
],
|
344 |
+
"source": [
|
345 |
+
"from transformers import AutoTokenizer, GPT2TokenizerFast\n",
|
346 |
+
"\n",
|
347 |
+
"# tokenizer = AutoTokenizer.from_pretrained(\"distilgpt2\")\n",
|
348 |
+
"\n",
|
349 |
+
"\n",
|
350 |
+
"tokenizer = GPT2TokenizerFast.from_pretrained(\"Xenova/gpt-4\")#, cache_dir=cache_dir)\n",
|
351 |
+
"tokenizer.pad_token = tokenizer.eos_token"
|
352 |
+
]
|
353 |
+
},
|
354 |
+
{
|
355 |
+
"cell_type": "markdown",
|
356 |
+
"metadata": {},
|
357 |
+
"source": [
|
358 |
+
"You'll notice from the example above, the `text` field is actually nested inside `answers`. This means you'll need to\n",
|
359 |
+
"extract the `text` subfield from its nested structure with the [`flatten`](https://huggingface.co/docs/datasets/process.html#flatten) method:"
|
360 |
+
]
|
361 |
+
},
|
362 |
+
{
|
363 |
+
"cell_type": "code",
|
364 |
+
"execution_count": 5,
|
365 |
+
"metadata": {},
|
366 |
+
"outputs": [
|
367 |
+
{
|
368 |
+
"data": {
|
369 |
+
"text/plain": [
|
370 |
+
"{'Text': 'Allow me to clarify a genuine fast for amateur online users What exactly is Youtube . com? Youtube . com is probably the most in-demand web site on the web which allow you to view and publish video lessons for free. These are generally submitted by Vimeo members on this video discussing system. Yet another thing YouTube registration is provided for free so anyone can join, however account is not required for watching video lessons. In order to sometimes observe video clips or post your own video lessons so that you can show to your friends, loved ones as well as other Vimeo members. Once you get dependent at viewing video clip, it is possible to phone yourself a YouTuber!\\n- Everything you are unable to upload? Nonetheless there are some regulations or YouTube\\'s regards to use that you should.\\n- Observing a Vimeo movie is really simple, you just need to.\\nObserving a You tube movie is absolutely simple, you just need to variety your best song or television set plan from the research discipline click on \"Research\" option and that\\'s it. It will approach your demand and give you a list of related results. You are able to click on a outcome and this will commence taking part in the recording. youtube downloader\\nAble to click on a outcome and\\nSo, just how to publish your chosen videos? Youtube . com is very popular online video discussing foundation that allows one to publish their video lessons. Uploading a relevant video online is an easy process, just select any video submit through your computer on your YouTube accounts webpage and it will surely begin posting the video. Nonetheless Vimeo will not offer any choice to down load a printed video that you will be seeing, you can easily take note of the site Link so that you can view it later, which seems handy for YouTube users.\\nEverything you cannot upload? Nevertheless there are a few regulations or YouTube\\'s terms of use that you have to comply with, specifically you happen to be unacceptable to upload any restricted content or erotic information. Nevertheless you can use it to showcase your products online.\\nA few regulations or\\nOnline video good quality once you upload Vimeo permits you to post all popular movie formats and produces good quality probable. Whenever you post a youtube video to Youtube . com, you ought to anticipate that high quality will slightly be changed, it is because YouTube optimizes the video for speedier packing. You can even add Hi-def or Hi-def video lessons nevertheless it will take much longer to weight once you observe it. Greater the high quality more slowly movie will load.\\nYou upload Vimeo\\nProbably the most well-known movie web sites online is You tube as well as for certain, you can find videos inside the web site you want to create you everywhere and adding it inside your PSP device might be what you need. However, YouTube video lessons will not be quickly down loadable. You might need a downloader to download the recording through the website and shop it inside your personal computer. video downloader\\nAfter you have saved the recording, it may possibly not certainly be around the preferred format which can be legible along with your Playstation portable. For those who have saved a structure not in mp4, you may want to transform the submit with your Computer in to a Playstation portable-pleasant structure. You may need a video clip converter for this task, and when you have changed the video tutorials, anyone can down load these to your Playstation portable.\\nWith your Playstation portable For those\\nIn accessing, simply link up your Playstation portable to the laptop or computer by means of its cord, use the Universal serial bus setting and download the video lessons and music that you want to bring along.\\nThat will help you look for a converter or a video downloader, specifically if you want to obtain video clips from Vimeo, be involved in forums and discover topics relevant to this. Certainly, you will also find a great deal of PSP movie information that may also assist you in making the best from your gadget and help you learn to see a number of videos on your gadget.\\nAlso find a great deal of PSP\\nYou can even get into membership web sites where PSP enthusiast collect and discuss information and facts and even more importantly, offers you the tools and software program that you will want to save music, videos and media records to your devices and permit you to enjoy the gizmo a lot more. Although these membership internet sites require only a minimum cost, it really is however vital that you are working with and creating dealings in a guaranteed and harmless internet site.\\n- You can even get into membership websites.\\n- One of the more preferred video clip sites on.\\n- Video quality when you post Vimeo enables you.\\n0 thoughts on “The Most Effective and Well-liked you tube downloader6675”'}"
|
371 |
+
]
|
372 |
+
},
|
373 |
+
"execution_count": 5,
|
374 |
+
"metadata": {},
|
375 |
+
"output_type": "execute_result"
|
376 |
+
}
|
377 |
+
],
|
378 |
+
"source": [
|
379 |
+
"Falcon = Falcon.flatten()\n",
|
380 |
+
"Falcon[\"train\"][0]"
|
381 |
+
]
|
382 |
+
},
|
383 |
+
{
|
384 |
+
"cell_type": "markdown",
|
385 |
+
"metadata": {},
|
386 |
+
"source": [
|
387 |
+
"Each subfield is now a separate column as indicated by the `answers` prefix, and the `text` field is a list now. Instead\n",
|
388 |
+
"of tokenizing each sentence separately, convert the list to a string so you can jointly tokenize them.\n",
|
389 |
+
"\n",
|
390 |
+
"Here is a first preprocessing function to join the list of strings for each example and tokenize the result:"
|
391 |
+
]
|
392 |
+
},
|
393 |
+
{
|
394 |
+
"cell_type": "code",
|
395 |
+
"execution_count": 6,
|
396 |
+
"metadata": {},
|
397 |
+
"outputs": [],
|
398 |
+
"source": [
|
399 |
+
"def preprocess_function(examples):\n",
|
400 |
+
" return tokenizer([\" \".join(x) for x in examples[\"Text\"]])"
|
401 |
+
]
|
402 |
+
},
|
403 |
+
{
|
404 |
+
"cell_type": "markdown",
|
405 |
+
"metadata": {},
|
406 |
+
"source": [
|
407 |
+
"To apply this preprocessing function over the entire dataset, use the 🤗 Datasets [map](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset.map) method. You can speed up the `map` function by setting `batched=True` to process multiple elements of the dataset at once, and increasing the number of processes with `num_proc`. Remove any columns you don't need:"
|
408 |
+
]
|
409 |
+
},
|
410 |
+
{
|
411 |
+
"cell_type": "code",
|
412 |
+
"execution_count": 7,
|
413 |
+
"metadata": {},
|
414 |
+
"outputs": [
|
415 |
+
{
|
416 |
+
"name": "stdout",
|
417 |
+
"output_type": "stream",
|
418 |
+
"text": [
|
419 |
+
"The OrderedVocab you are attempting to save contains holes for indices [100256, 100261, 100262, 100263, 100266, 100267, 100268, 100269, 100270, 100271, 100272, 100273, 100274, 100275], your vocabulary could be corrupted !\n"
|
420 |
+
]
|
421 |
+
},
|
422 |
+
{
|
423 |
+
"data": {
|
424 |
+
"application/vnd.jupyter.widget-view+json": {
|
425 |
+
"model_id": "51bff46d94664c468064b17d1a8bf1c0",
|
426 |
+
"version_major": 2,
|
427 |
+
"version_minor": 0
|
428 |
+
},
|
429 |
+
"text/plain": [
|
430 |
+
"Map (num_proc=4): 0%| | 0/20000 [00:00<?, ? examples/s]"
|
431 |
+
]
|
432 |
+
},
|
433 |
+
"metadata": {},
|
434 |
+
"output_type": "display_data"
|
435 |
+
},
|
436 |
+
{
|
437 |
+
"name": "stderr",
|
438 |
+
"output_type": "stream",
|
439 |
+
"text": [
|
440 |
+
"Token indices sequence length is longer than the specified maximum sequence length for this model (8569 > 8192). Running this sequence through the model will result in indexing errors\n",
|
441 |
+
"Token indices sequence length is longer than the specified maximum sequence length for this model (14224 > 8192). Running this sequence through the model will result in indexing errors\n",
|
442 |
+
"Token indices sequence length is longer than the specified maximum sequence length for this model (15104 > 8192). Running this sequence through the model will result in indexing errors\n",
|
443 |
+
"Token indices sequence length is longer than the specified maximum sequence length for this model (32874 > 8192). Running this sequence through the model will result in indexing errors\n"
|
444 |
+
]
|
445 |
+
},
|
446 |
+
{
|
447 |
+
"name": "stdout",
|
448 |
+
"output_type": "stream",
|
449 |
+
"text": [
|
450 |
+
"The OrderedVocab you are attempting to save contains holes for indices [100256, 100261, 100262, 100263, 100266, 100267, 100268, 100269, 100270, 100271, 100272, 100273, 100274, 100275], your vocabulary could be corrupted !\n"
|
451 |
+
]
|
452 |
+
},
|
453 |
+
{
|
454 |
+
"data": {
|
455 |
+
"application/vnd.jupyter.widget-view+json": {
|
456 |
+
"model_id": "5a093fd9868042a9ac76ed1c141711a7",
|
457 |
+
"version_major": 2,
|
458 |
+
"version_minor": 0
|
459 |
+
},
|
460 |
+
"text/plain": [
|
461 |
+
"Map (num_proc=4): 0%| | 0/2000 [00:00<?, ? examples/s]"
|
462 |
+
]
|
463 |
+
},
|
464 |
+
"metadata": {},
|
465 |
+
"output_type": "display_data"
|
466 |
+
},
|
467 |
+
{
|
468 |
+
"name": "stderr",
|
469 |
+
"output_type": "stream",
|
470 |
+
"text": [
|
471 |
+
"Token indices sequence length is longer than the specified maximum sequence length for this model (8414 > 8192). Running this sequence through the model will result in indexing errors\n",
|
472 |
+
"Token indices sequence length is longer than the specified maximum sequence length for this model (11892 > 8192). Running this sequence through the model will result in indexing errors\n",
|
473 |
+
"Token indices sequence length is longer than the specified maximum sequence length for this model (22303 > 8192). Running this sequence through the model will result in indexing errors\n",
|
474 |
+
"Token indices sequence length is longer than the specified maximum sequence length for this model (12749 > 8192). Running this sequence through the model will result in indexing errors\n"
|
475 |
+
]
|
476 |
+
}
|
477 |
+
],
|
478 |
+
"source": [
|
479 |
+
"tokenized_Falcon = Falcon.map(\n",
|
480 |
+
" preprocess_function,\n",
|
481 |
+
" batched=True,\n",
|
482 |
+
" num_proc=4,\n",
|
483 |
+
" remove_columns=Falcon[\"train\"].column_names,\n",
|
484 |
+
")"
|
485 |
+
]
|
486 |
+
},
|
487 |
+
{
|
488 |
+
"cell_type": "markdown",
|
489 |
+
"metadata": {},
|
490 |
+
"source": [
|
491 |
+
"This dataset contains the token sequences, but some of these are longer than the maximum input length for the model.\n",
|
492 |
+
"\n",
|
493 |
+
"You can now use a second preprocessing function to\n",
|
494 |
+
"- concatenate all the sequences\n",
|
495 |
+
"- split the concatenated sequences into shorter chunks defined by `block_size`, which should be both shorter than the maximum input length and short enough for your GPU RAM."
|
496 |
+
]
|
497 |
+
},
|
498 |
+
{
|
499 |
+
"cell_type": "code",
|
500 |
+
"execution_count": 8,
|
501 |
+
"metadata": {},
|
502 |
+
"outputs": [],
|
503 |
+
"source": [
|
504 |
+
"block_size = 1048\n",
|
505 |
+
"\n",
|
506 |
+
"\n",
|
507 |
+
"def group_texts(examples):\n",
|
508 |
+
" # Concatenate all texts.\n",
|
509 |
+
" concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}\n",
|
510 |
+
" total_length = len(concatenated_examples[list(examples.keys())[0]])\n",
|
511 |
+
" # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can\n",
|
512 |
+
" # customize this part to your needs.\n",
|
513 |
+
" if total_length >= block_size:\n",
|
514 |
+
" total_length = (total_length // block_size) * block_size\n",
|
515 |
+
" # Split by chunks of block_size.\n",
|
516 |
+
" result = {\n",
|
517 |
+
" k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n",
|
518 |
+
" for k, t in concatenated_examples.items()\n",
|
519 |
+
" }\n",
|
520 |
+
" result[\"labels\"] = result[\"input_ids\"].copy()\n",
|
521 |
+
" return result"
|
522 |
+
]
|
523 |
+
},
|
524 |
+
{
|
525 |
+
"cell_type": "markdown",
|
526 |
+
"metadata": {},
|
527 |
+
"source": [
|
528 |
+
"Apply the `group_texts` function over the entire dataset:"
|
529 |
+
]
|
530 |
+
},
|
531 |
+
{
|
532 |
+
"cell_type": "code",
|
533 |
+
"execution_count": 9,
|
534 |
+
"metadata": {},
|
535 |
+
"outputs": [
|
536 |
+
{
|
537 |
+
"data": {
|
538 |
+
"application/vnd.jupyter.widget-view+json": {
|
539 |
+
"model_id": "6134c09493054ce3940da711dc2e965e",
|
540 |
+
"version_major": 2,
|
541 |
+
"version_minor": 0
|
542 |
+
},
|
543 |
+
"text/plain": [
|
544 |
+
"Map (num_proc=4): 0%| | 0/20000 [00:00<?, ? examples/s]"
|
545 |
+
]
|
546 |
+
},
|
547 |
+
"metadata": {},
|
548 |
+
"output_type": "display_data"
|
549 |
+
},
|
550 |
+
{
|
551 |
+
"data": {
|
552 |
+
"application/vnd.jupyter.widget-view+json": {
|
553 |
+
"model_id": "bd3f26e9c76f42f1827aa11aa45416df",
|
554 |
+
"version_major": 2,
|
555 |
+
"version_minor": 0
|
556 |
+
},
|
557 |
+
"text/plain": [
|
558 |
+
"Map (num_proc=4): 0%| | 0/2000 [00:00<?, ? examples/s]"
|
559 |
+
]
|
560 |
+
},
|
561 |
+
"metadata": {},
|
562 |
+
"output_type": "display_data"
|
563 |
+
}
|
564 |
+
],
|
565 |
+
"source": [
|
566 |
+
"lm_dataset = tokenized_Falcon.map(group_texts, batched=True, num_proc=4)"
|
567 |
+
]
|
568 |
+
},
|
569 |
+
{
|
570 |
+
"cell_type": "markdown",
|
571 |
+
"metadata": {},
|
572 |
+
"source": [
|
573 |
+
"Now create a batch of examples using [DataCollatorForLanguageModeling](https://huggingface.co/docs/transformers/main/en/main_classes/data_collator#transformers.DataCollatorForLanguageModeling). It's more efficient to *dynamically pad* the\n",
|
574 |
+
"sentences to the longest length in a batch during collation, instead of padding the whole dataset to the maximum length.\n",
|
575 |
+
"\n",
|
576 |
+
"Use the end-of-sequence token as the padding token and set `mlm=False`. This will use the inputs as labels shifted to the right by one element:"
|
577 |
+
]
|
578 |
+
},
|
579 |
+
{
|
580 |
+
"cell_type": "code",
|
581 |
+
"execution_count": 10,
|
582 |
+
"metadata": {},
|
583 |
+
"outputs": [],
|
584 |
+
"source": [
|
585 |
+
"from transformers import DataCollatorForLanguageModeling\n",
|
586 |
+
"\n",
|
587 |
+
"tokenizer.pad_token = tokenizer.eos_token\n",
|
588 |
+
"data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)"
|
589 |
+
]
|
590 |
+
},
|
591 |
+
{
|
592 |
+
"cell_type": "markdown",
|
593 |
+
"metadata": {},
|
594 |
+
"source": [
|
595 |
+
"## Train"
|
596 |
+
]
|
597 |
+
},
|
598 |
+
{
|
599 |
+
"cell_type": "markdown",
|
600 |
+
"metadata": {},
|
601 |
+
"source": [
|
602 |
+
"<Tip>\n",
|
603 |
+
"\n",
|
604 |
+
"If you aren't familiar with finetuning a model with the [Trainer](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer), take a look at the [basic tutorial](https://huggingface.co/docs/transformers/main/en/tasks/../training#train-with-pytorch-trainer)!\n",
|
605 |
+
"\n",
|
606 |
+
"</Tip>\n",
|
607 |
+
"\n",
|
608 |
+
"You're ready to start training your model now! Load DistilGPT2 with [AutoModelForCausalLM](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForCausalLM):"
|
609 |
+
]
|
610 |
+
},
|
611 |
+
{
|
612 |
+
"cell_type": "code",
|
613 |
+
"execution_count": 11,
|
614 |
+
"metadata": {},
|
615 |
+
"outputs": [
|
616 |
+
{
|
617 |
+
"name": "stderr",
|
618 |
+
"output_type": "stream",
|
619 |
+
"text": [
|
620 |
+
"/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
621 |
+
" warnings.warn(\n"
|
622 |
+
]
|
623 |
+
},
|
624 |
+
{
|
625 |
+
"data": {
|
626 |
+
"application/vnd.jupyter.widget-view+json": {
|
627 |
+
"model_id": "f55ae69743a74a08943641e2da03e791",
|
628 |
+
"version_major": 2,
|
629 |
+
"version_minor": 0
|
630 |
+
},
|
631 |
+
"text/plain": [
|
632 |
+
"Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]"
|
633 |
+
]
|
634 |
+
},
|
635 |
+
"metadata": {},
|
636 |
+
"output_type": "display_data"
|
637 |
+
}
|
638 |
+
],
|
639 |
+
"source": [
|
640 |
+
"from transformers import AutoModelForCausalLM, TrainingArguments, Trainer\n",
|
641 |
+
"import torch\n",
|
642 |
+
"model = AutoModelForCausalLM.from_pretrained(\"tensorplex-labs/pretraining-sn9-7B-5\", torch_dtype=torch.bfloat16) "
|
643 |
+
]
|
644 |
+
},
|
645 |
+
{
|
646 |
+
"cell_type": "markdown",
|
647 |
+
"metadata": {},
|
648 |
+
"source": [
|
649 |
+
"At this point, only three steps remain:\n",
|
650 |
+
"\n",
|
651 |
+
"1. Define your training hyperparameters in [TrainingArguments](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments). The only required parameter is `output_dir` which specifies where to save your model. You'll push this model to the Hub by setting `push_to_hub=True` (you need to be signed in to Hugging Face to upload your model).\n",
|
652 |
+
"2. Pass the training arguments to [Trainer](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer) along with the model, datasets, and data collator.\n",
|
653 |
+
"3. Call [train()](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer.train) to finetune your model."
|
654 |
+
]
|
655 |
+
},
|
656 |
+
{
|
657 |
+
"cell_type": "code",
|
658 |
+
"execution_count": 28,
|
659 |
+
"metadata": {},
|
660 |
+
"outputs": [],
|
661 |
+
"source": [
|
662 |
+
"import torch\n",
|
663 |
+
"torch.cuda.empty_cache()\n"
|
664 |
+
]
|
665 |
+
},
|
666 |
+
{
|
667 |
+
"cell_type": "code",
|
668 |
+
"execution_count": 14,
|
669 |
+
"metadata": {},
|
670 |
+
"outputs": [],
|
671 |
+
"source": [
|
672 |
+
"import torch\n",
|
673 |
+
"import gc\n",
|
674 |
+
"\n",
|
675 |
+
"# del tensor_name # Delete the tensor\n",
|
676 |
+
"gc.collect() # Collect garbage\n",
|
677 |
+
"torch.cuda.empty_cache() # Clear cache"
|
678 |
+
]
|
679 |
+
},
|
680 |
+
{
|
681 |
+
"cell_type": "code",
|
682 |
+
"execution_count": 21,
|
683 |
+
"metadata": {},
|
684 |
+
"outputs": [],
|
685 |
+
"source": [
|
686 |
+
"torch.cuda.empty_cache()"
|
687 |
+
]
|
688 |
+
},
|
689 |
+
{
|
690 |
+
"cell_type": "code",
|
691 |
+
"execution_count": 20,
|
692 |
+
"metadata": {},
|
693 |
+
"outputs": [
|
694 |
+
{
|
695 |
+
"data": {
|
696 |
+
"text/plain": [
|
697 |
+
"<torch.autograd.grad_mode.no_grad at 0x7f41880db6d0>"
|
698 |
+
]
|
699 |
+
},
|
700 |
+
"execution_count": 20,
|
701 |
+
"metadata": {},
|
702 |
+
"output_type": "execute_result"
|
703 |
+
}
|
704 |
+
],
|
705 |
+
"source": [
|
706 |
+
"torch.no_grad()"
|
707 |
+
]
|
708 |
+
},
|
709 |
+
{
|
710 |
+
"cell_type": "code",
|
711 |
+
"execution_count": 12,
|
712 |
+
"metadata": {},
|
713 |
+
"outputs": [
|
714 |
+
{
|
715 |
+
"data": {
|
716 |
+
"text/plain": [
|
717 |
+
"LlamaForCausalLM(\n",
|
718 |
+
" (model): LlamaModel(\n",
|
719 |
+
" (embed_tokens): Embedding(100288, 4096)\n",
|
720 |
+
" (layers): ModuleList(\n",
|
721 |
+
" (0-29): 30 x LlamaDecoderLayer(\n",
|
722 |
+
" (self_attn): LlamaSdpaAttention(\n",
|
723 |
+
" (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
|
724 |
+
" (k_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
|
725 |
+
" (v_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
|
726 |
+
" (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
|
727 |
+
" (rotary_emb): LlamaRotaryEmbedding()\n",
|
728 |
+
" )\n",
|
729 |
+
" (mlp): LlamaMLP(\n",
|
730 |
+
" (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
|
731 |
+
" (up_proj): Linear(in_features=4096, out_features=11008, bias=False)\n",
|
732 |
+
" (down_proj): Linear(in_features=11008, out_features=4096, bias=False)\n",
|
733 |
+
" (act_fn): SiLU()\n",
|
734 |
+
" )\n",
|
735 |
+
" (input_layernorm): LlamaRMSNorm()\n",
|
736 |
+
" (post_attention_layernorm): LlamaRMSNorm()\n",
|
737 |
+
" )\n",
|
738 |
+
" )\n",
|
739 |
+
" (norm): LlamaRMSNorm()\n",
|
740 |
+
" )\n",
|
741 |
+
" (lm_head): Linear(in_features=4096, out_features=100288, bias=False)\n",
|
742 |
+
")"
|
743 |
+
]
|
744 |
+
},
|
745 |
+
"execution_count": 12,
|
746 |
+
"metadata": {},
|
747 |
+
"output_type": "execute_result"
|
748 |
+
}
|
749 |
+
],
|
750 |
+
"source": [
|
751 |
+
"model.to('cuda')"
|
752 |
+
]
|
753 |
+
},
|
754 |
+
{
|
755 |
+
"cell_type": "code",
|
756 |
+
"execution_count": 14,
|
757 |
+
"metadata": {},
|
758 |
+
"outputs": [],
|
759 |
+
"source": [
|
760 |
+
"training_args = TrainingArguments(\n",
|
761 |
+
" output_dir=\"Fine-Tuned-S9\",\n",
|
762 |
+
" bf16=True,\n",
|
763 |
+
" # evaluation_strategy=\"epoch\",\n",
|
764 |
+
" evaluation_strategy=\"steps\",\n",
|
765 |
+
" learning_rate=2e-5,\n",
|
766 |
+
" weight_decay=0.01,\n",
|
767 |
+
" num_train_epochs=1,\n",
|
768 |
+
" per_device_train_batch_size=2,\n",
|
769 |
+
" per_device_eval_batch_size=2,\n",
|
770 |
+
" # lr_scheduler_type = 'cosine',\n",
|
771 |
+
" push_to_hub=False,\n",
|
772 |
+
" save_total_limit = 2\n",
|
773 |
+
" # save_strategy = “no”\n",
|
774 |
+
" load_best_model_at_end=False\n",
|
775 |
+
")\n",
|
776 |
+
"\n",
|
777 |
+
"trainer = Trainer(\n",
|
778 |
+
" model=model,\n",
|
779 |
+
" args=training_args,\n",
|
780 |
+
" train_dataset=lm_dataset[\"train\"],\n",
|
781 |
+
" eval_dataset=lm_dataset[\"validation\"],\n",
|
782 |
+
" # eval_dataset=lm_dataset[\"test\"],\n",
|
783 |
+
" data_collator=data_collator,\n",
|
784 |
+
")\n",
|
785 |
+
"\n",
|
786 |
+
"# trainer.train()"
|
787 |
+
]
|
788 |
+
},
|
789 |
+
{
|
790 |
+
"cell_type": "code",
|
791 |
+
"execution_count": 7,
|
792 |
+
"metadata": {},
|
793 |
+
"outputs": [],
|
794 |
+
"source": [
|
795 |
+
"trainer.train()"
|
796 |
+
]
|
797 |
+
},
|
798 |
+
{
|
799 |
+
"cell_type": "markdown",
|
800 |
+
"metadata": {},
|
801 |
+
"source": [
|
802 |
+
"Once training is completed, use the [evaluate()](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer.evaluate) method to evaluate your model and get its perplexity:"
|
803 |
+
]
|
804 |
+
},
|
805 |
+
{
|
806 |
+
"cell_type": "code",
|
807 |
+
"execution_count": 17,
|
808 |
+
"metadata": {},
|
809 |
+
"outputs": [
|
810 |
+
{
|
811 |
+
"name": "stdout",
|
812 |
+
"output_type": "stream",
|
813 |
+
"text": [
|
814 |
+
"Perplexity: 2.21\n"
|
815 |
+
]
|
816 |
+
}
|
817 |
+
],
|
818 |
+
"source": [
|
819 |
+
"import math\n",
|
820 |
+
"\n",
|
821 |
+
"eval_results = trainer.evaluate()\n",
|
822 |
+
"print(f\"Perplexity: {math.exp(eval_results['eval_loss']):.2f}\")"
|
823 |
+
]
|
824 |
+
},
|
825 |
+
{
|
826 |
+
"cell_type": "markdown",
|
827 |
+
"metadata": {},
|
828 |
+
"source": [
|
829 |
+
"Then share your model to the Hub with the [push_to_hub()](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer.push_to_hub) method so everyone can use your model:"
|
830 |
+
]
|
831 |
+
},
|
832 |
+
{
|
833 |
+
"cell_type": "code",
|
834 |
+
"execution_count": null,
|
835 |
+
"metadata": {},
|
836 |
+
"outputs": [],
|
837 |
+
"source": [
|
838 |
+
"# trainer.push_to_hub()"
|
839 |
+
]
|
840 |
+
},
|
841 |
+
{
|
842 |
+
"cell_type": "markdown",
|
843 |
+
"metadata": {},
|
844 |
+
"source": [
|
845 |
+
"<Tip>\n",
|
846 |
+
"\n",
|
847 |
+
"For a more in-depth example of how to finetune a model for causal language modeling, take a look at the corresponding\n",
|
848 |
+
"[PyTorch notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling.ipynb)\n",
|
849 |
+
"or [TensorFlow notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling-tf.ipynb).\n",
|
850 |
+
"\n",
|
851 |
+
"</Tip>"
|
852 |
+
]
|
853 |
+
},
|
854 |
+
{
|
855 |
+
"cell_type": "markdown",
|
856 |
+
"metadata": {},
|
857 |
+
"source": [
|
858 |
+
"## Inference"
|
859 |
+
]
|
860 |
+
},
|
861 |
+
{
|
862 |
+
"cell_type": "markdown",
|
863 |
+
"metadata": {},
|
864 |
+
"source": [
|
865 |
+
"Great, now that you've finetuned a model, you can use it for inference!\n",
|
866 |
+
"\n",
|
867 |
+
"Come up with a prompt you'd like to generate text from:"
|
868 |
+
]
|
869 |
+
},
|
870 |
+
{
|
871 |
+
"cell_type": "code",
|
872 |
+
"execution_count": 2,
|
873 |
+
"metadata": {},
|
874 |
+
"outputs": [],
|
875 |
+
"source": [
|
876 |
+
"# prompt = \"Somatic hypermutation allows the immune system to\""
|
877 |
+
]
|
878 |
+
},
|
879 |
+
{
|
880 |
+
"cell_type": "markdown",
|
881 |
+
"metadata": {},
|
882 |
+
"source": [
|
883 |
+
"The simplest way to try out your finetuned model for inference is to use it in a [pipeline()](https://huggingface.co/docs/transformers/main/en/main_classes/pipelines#transformers.pipeline). Instantiate a `pipeline` for text generation with your model, and pass your text to it:"
|
884 |
+
]
|
885 |
+
},
|
886 |
+
{
|
887 |
+
"cell_type": "code",
|
888 |
+
"execution_count": 20,
|
889 |
+
"metadata": {},
|
890 |
+
"outputs": [
|
891 |
+
{
|
892 |
+
"ename": "ValueError",
|
893 |
+
"evalue": "Could not load model Fine-Tuned-S9/checkpoint-4000 with any of the following classes: (<class 'transformers.models.auto.modeling_auto.AutoModelForCausalLM'>, <class 'transformers.models.llama.modeling_llama.LlamaForCausalLM'>). See the original errors:\n\nwhile loading with AutoModelForCausalLM, an error is thrown:\nTraceback (most recent call last):\n File \"/usr/local/lib/python3.10/dist-packages/transformers/pipelines/base.py\", line 283, in infer_framework_load_model\n model = model_class.from_pretrained(model, **kwargs)\n File \"/usr/local/lib/python3.10/dist-packages/transformers/models/auto/auto_factory.py\", line 563, in from_pretrained\n return model_class.from_pretrained(\n File \"/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py\", line 3260, in from_pretrained\n raise EnvironmentError(\nOSError: Error no file named pytorch_model.bin, tf_model.h5, model.ckpt.index or flax_model.msgpack found in directory Fine-Tuned-S9/checkpoint-4000.\n\nwhile loading with LlamaForCausalLM, an error is thrown:\nTraceback (most recent call last):\n File \"/usr/local/lib/python3.10/dist-packages/transformers/pipelines/base.py\", line 283, in infer_framework_load_model\n model = model_class.from_pretrained(model, **kwargs)\n File \"/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py\", line 3260, in from_pretrained\n raise EnvironmentError(\nOSError: Error no file named pytorch_model.bin, tf_model.h5, model.ckpt.index or flax_model.msgpack found in directory Fine-Tuned-S9/checkpoint-4000.\n\n\n",
|
894 |
+
"output_type": "error",
|
895 |
+
"traceback": [
|
896 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
897 |
+
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
|
898 |
+
"Cell \u001b[0;32mIn[20], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtransformers\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m pipeline\n\u001b[1;32m 2\u001b[0m \u001b[38;5;66;03m# checkpoint-4000\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m generator \u001b[38;5;241m=\u001b[39m \u001b[43mpipeline\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtext-generation\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mFine-Tuned-S9/checkpoint-4000\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4\u001b[0m generator(prompt)\n",
|
899 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/pipelines/__init__.py:906\u001b[0m, in \u001b[0;36mpipeline\u001b[0;34m(task, model, config, tokenizer, feature_extractor, image_processor, framework, revision, use_fast, token, device, device_map, torch_dtype, trust_remote_code, model_kwargs, pipeline_class, **kwargs)\u001b[0m\n\u001b[1;32m 904\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(model, \u001b[38;5;28mstr\u001b[39m) \u001b[38;5;129;01mor\u001b[39;00m framework \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 905\u001b[0m model_classes \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtf\u001b[39m\u001b[38;5;124m\"\u001b[39m: targeted_task[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtf\u001b[39m\u001b[38;5;124m\"\u001b[39m], \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpt\u001b[39m\u001b[38;5;124m\"\u001b[39m: targeted_task[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpt\u001b[39m\u001b[38;5;124m\"\u001b[39m]}\n\u001b[0;32m--> 906\u001b[0m framework, model \u001b[38;5;241m=\u001b[39m \u001b[43minfer_framework_load_model\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 907\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 908\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel_classes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel_classes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 909\u001b[0m \u001b[43m \u001b[49m\u001b[43mconfig\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconfig\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 910\u001b[0m \u001b[43m \u001b[49m\u001b[43mframework\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mframework\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 911\u001b[0m \u001b[43m \u001b[49m\u001b[43mtask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 912\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mhub_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 913\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmodel_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 914\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 916\u001b[0m model_config \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mconfig\n\u001b[1;32m 917\u001b[0m hub_kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_commit_hash\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39m_commit_hash\n",
|
900 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/pipelines/base.py:296\u001b[0m, in \u001b[0;36minfer_framework_load_model\u001b[0;34m(model, config, model_classes, task, framework, **model_kwargs)\u001b[0m\n\u001b[1;32m 294\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m class_name, trace \u001b[38;5;129;01min\u001b[39;00m all_traceback\u001b[38;5;241m.\u001b[39mitems():\n\u001b[1;32m 295\u001b[0m error \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwhile loading with \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mclass_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m, an error is thrown:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00mtrace\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 296\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 297\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCould not load model \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m with any of the following classes: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mclass_tuple\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m. See the original errors:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;132;01m{\u001b[39;00merror\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 298\u001b[0m )\n\u001b[1;32m 300\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m framework \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 301\u001b[0m framework \u001b[38;5;241m=\u001b[39m infer_framework(model\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m)\n",
|
901 |
+
"\u001b[0;31mValueError\u001b[0m: Could not load model Fine-Tuned-S9/checkpoint-4000 with any of the following classes: (<class 'transformers.models.auto.modeling_auto.AutoModelForCausalLM'>, <class 'transformers.models.llama.modeling_llama.LlamaForCausalLM'>). See the original errors:\n\nwhile loading with AutoModelForCausalLM, an error is thrown:\nTraceback (most recent call last):\n File \"/usr/local/lib/python3.10/dist-packages/transformers/pipelines/base.py\", line 283, in infer_framework_load_model\n model = model_class.from_pretrained(model, **kwargs)\n File \"/usr/local/lib/python3.10/dist-packages/transformers/models/auto/auto_factory.py\", line 563, in from_pretrained\n return model_class.from_pretrained(\n File \"/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py\", line 3260, in from_pretrained\n raise EnvironmentError(\nOSError: Error no file named pytorch_model.bin, tf_model.h5, model.ckpt.index or flax_model.msgpack found in directory Fine-Tuned-S9/checkpoint-4000.\n\nwhile loading with LlamaForCausalLM, an error is thrown:\nTraceback (most recent call last):\n File \"/usr/local/lib/python3.10/dist-packages/transformers/pipelines/base.py\", line 283, in infer_framework_load_model\n model = model_class.from_pretrained(model, **kwargs)\n File \"/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py\", line 3260, in from_pretrained\n raise EnvironmentError(\nOSError: Error no file named pytorch_model.bin, tf_model.h5, model.ckpt.index or flax_model.msgpack found in directory Fine-Tuned-S9/checkpoint-4000.\n\n\n"
|
902 |
+
]
|
903 |
+
}
|
904 |
+
],
|
905 |
+
"source": [
|
906 |
+
"# from transformers import pipeline\n",
|
907 |
+
"# # checkpoint-4000\n",
|
908 |
+
"# generator = pipeline(\"text-generation\", model=\"Fine-Tuned-S9/checkpoint-4000\")\n",
|
909 |
+
"# generator(prompt)"
|
910 |
+
]
|
911 |
+
},
|
912 |
+
{
|
913 |
+
"cell_type": "markdown",
|
914 |
+
"metadata": {},
|
915 |
+
"source": [
|
916 |
+
"Tokenize the text and return the `input_ids` as PyTorch tensors:"
|
917 |
+
]
|
918 |
+
},
|
919 |
+
{
|
920 |
+
"cell_type": "code",
|
921 |
+
"execution_count": 3,
|
922 |
+
"metadata": {},
|
923 |
+
"outputs": [
|
924 |
+
{
|
925 |
+
"name": "stderr",
|
926 |
+
"output_type": "stream",
|
927 |
+
"text": [
|
928 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
929 |
+
]
|
930 |
+
}
|
931 |
+
],
|
932 |
+
"source": [
|
933 |
+
"# from transformers import AutoTokenizer\n",
|
934 |
+
"\n",
|
935 |
+
"# tokenizer = AutoTokenizer.from_pretrained(\"Xenova/gpt-4\")\n",
|
936 |
+
"# inputs = tokenizer(prompt, return_tensors=\"pt\").input_ids"
|
937 |
+
]
|
938 |
+
},
|
939 |
+
{
|
940 |
+
"cell_type": "markdown",
|
941 |
+
"metadata": {},
|
942 |
+
"source": [
|
943 |
+
"Use the [generate()](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate) method to generate text.\n",
|
944 |
+
"For more details about the different text generation strategies and parameters for controlling generation, check out the [Text generation strategies](https://huggingface.co/docs/transformers/main/en/tasks/../generation_strategies) page."
|
945 |
+
]
|
946 |
+
},
|
947 |
+
{
|
948 |
+
"cell_type": "code",
|
949 |
+
"execution_count": 4,
|
950 |
+
"metadata": {},
|
951 |
+
"outputs": [
|
952 |
+
{
|
953 |
+
"data": {
|
954 |
+
"application/vnd.jupyter.widget-view+json": {
|
955 |
+
"model_id": "7ba147780e8548d28a00a655e81e588a",
|
956 |
+
"version_major": 2,
|
957 |
+
"version_minor": 0
|
958 |
+
},
|
959 |
+
"text/plain": [
|
960 |
+
"config.json: 0%| | 0.00/688 [00:00<?, ?B/s]"
|
961 |
+
]
|
962 |
+
},
|
963 |
+
"metadata": {},
|
964 |
+
"output_type": "display_data"
|
965 |
+
},
|
966 |
+
{
|
967 |
+
"name": "stderr",
|
968 |
+
"output_type": "stream",
|
969 |
+
"text": [
|
970 |
+
"/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
971 |
+
" warnings.warn(\n"
|
972 |
+
]
|
973 |
+
},
|
974 |
+
{
|
975 |
+
"data": {
|
976 |
+
"application/vnd.jupyter.widget-view+json": {
|
977 |
+
"model_id": "04e2f536d4d1492bbb4dcf72abbf2cc3",
|
978 |
+
"version_major": 2,
|
979 |
+
"version_minor": 0
|
980 |
+
},
|
981 |
+
"text/plain": [
|
982 |
+
"model.safetensors.index.json: 0%| | 0.00/22.5k [00:00<?, ?B/s]"
|
983 |
+
]
|
984 |
+
},
|
985 |
+
"metadata": {},
|
986 |
+
"output_type": "display_data"
|
987 |
+
},
|
988 |
+
{
|
989 |
+
"data": {
|
990 |
+
"application/vnd.jupyter.widget-view+json": {
|
991 |
+
"model_id": "df7e14c799c0457f8422442a065f3b03",
|
992 |
+
"version_major": 2,
|
993 |
+
"version_minor": 0
|
994 |
+
},
|
995 |
+
"text/plain": [
|
996 |
+
"Downloading shards: 0%| | 0/3 [00:00<?, ?it/s]"
|
997 |
+
]
|
998 |
+
},
|
999 |
+
"metadata": {},
|
1000 |
+
"output_type": "display_data"
|
1001 |
+
},
|
1002 |
+
{
|
1003 |
+
"data": {
|
1004 |
+
"application/vnd.jupyter.widget-view+json": {
|
1005 |
+
"model_id": "ee74102a34694e6cb57a00210d34cf19",
|
1006 |
+
"version_major": 2,
|
1007 |
+
"version_minor": 0
|
1008 |
+
},
|
1009 |
+
"text/plain": [
|
1010 |
+
"model-00001-of-00003.safetensors: 0%| | 0.00/4.97G [00:00<?, ?B/s]"
|
1011 |
+
]
|
1012 |
+
},
|
1013 |
+
"metadata": {},
|
1014 |
+
"output_type": "display_data"
|
1015 |
+
},
|
1016 |
+
{
|
1017 |
+
"data": {
|
1018 |
+
"application/vnd.jupyter.widget-view+json": {
|
1019 |
+
"model_id": "978d214714044affb97e1b31ab6deafc",
|
1020 |
+
"version_major": 2,
|
1021 |
+
"version_minor": 0
|
1022 |
+
},
|
1023 |
+
"text/plain": [
|
1024 |
+
"model-00002-of-00003.safetensors: 0%| | 0.00/4.98G [00:00<?, ?B/s]"
|
1025 |
+
]
|
1026 |
+
},
|
1027 |
+
"metadata": {},
|
1028 |
+
"output_type": "display_data"
|
1029 |
+
},
|
1030 |
+
{
|
1031 |
+
"data": {
|
1032 |
+
"application/vnd.jupyter.widget-view+json": {
|
1033 |
+
"model_id": "0a2fb5b3f2ec4e3e8d7bc9db54a0635e",
|
1034 |
+
"version_major": 2,
|
1035 |
+
"version_minor": 0
|
1036 |
+
},
|
1037 |
+
"text/plain": [
|
1038 |
+
"model-00003-of-00003.safetensors: 0%| | 0.00/3.84G [00:00<?, ?B/s]"
|
1039 |
+
]
|
1040 |
+
},
|
1041 |
+
"metadata": {},
|
1042 |
+
"output_type": "display_data"
|
1043 |
+
},
|
1044 |
+
{
|
1045 |
+
"name": "stderr",
|
1046 |
+
"output_type": "stream",
|
1047 |
+
"text": [
|
1048 |
+
"Error while downloading from https://cdn-lfs-us-1.huggingface.co/repos/54/cf/54cf63a091d3be4443d28131b5c3686f6dd17bc8fe13dfd74b30bc4eafc5b3e2/4c4148f267d0c0cb2979c9cf8e60f11fb91770076c28a2a79f4446ea30bff523?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27model-00003-of-00003.safetensors%3B+filename%3D%22model-00003-of-00003.safetensors%22%3B&Expires=1715867899&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxNTg2Nzg5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zLzU0L2NmLzU0Y2Y2M2EwOTFkM2JlNDQ0M2QyODEzMWI1YzM2ODZmNmRkMTdiYzhmZTEzZGZkNzRiMzBiYzRlYWZjNWIzZTIvNGM0MTQ4ZjI2N2QwYzBjYjI5NzljOWNmOGU2MGYxMWZiOTE3NzAwNzZjMjhhMmE3OWY0NDQ2ZWEzMGJmZjUyMz9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPSoifV19&Signature=NRnXWL-gncnyNfcEhT0Xqi7WNbx5rVxELBfBIjnfb3zk7DCNDIqSPi-iNcrXmNkEmINWGbghFy4ifzUqvzNOmm0cJF10hMi%7E6R5DBKRBK0DRGtC2fC72sXzk9ysyJ6mQRSegUeDZy2KZqUL3wzwRC2Xhv8baK%7ENi0FGjUSh0Hmpg7Wgbs2quZRMM7lXqI-y3bkKh7L6OBXnx3W55Mlzzt87CgYLyotXuFIUrQ1W5lN6R3LWZuDvJ0ClLVuSKjTGwBv9MRQYLewybb4yqSmmEDfTkmuCphg2%7EfzNJ53Q2kqMEVC6gRPf67v8NDR9j57zOtoNSc1-SdaCem95aycbC7A__&Key-Pair-Id=KCD77M1F0VK2B: HTTPSConnectionPool(host='cdn-lfs-us-1.huggingface.co', port=443): Read timed out.\n",
|
1049 |
+
"Trying to resume download...\n"
|
1050 |
+
]
|
1051 |
+
},
|
1052 |
+
{
|
1053 |
+
"data": {
|
1054 |
+
"application/vnd.jupyter.widget-view+json": {
|
1055 |
+
"model_id": "635db10feaa74dff93285752d9e79520",
|
1056 |
+
"version_major": 2,
|
1057 |
+
"version_minor": 0
|
1058 |
+
},
|
1059 |
+
"text/plain": [
|
1060 |
+
"model-00003-of-00003.safetensors: 71%|####### | 2.71G/3.84G [00:00<?, ?B/s]"
|
1061 |
+
]
|
1062 |
+
},
|
1063 |
+
"metadata": {},
|
1064 |
+
"output_type": "display_data"
|
1065 |
+
},
|
1066 |
+
{
|
1067 |
+
"data": {
|
1068 |
+
"application/vnd.jupyter.widget-view+json": {
|
1069 |
+
"model_id": "38e479e6424d4edc8d00795ce084d4c2",
|
1070 |
+
"version_major": 2,
|
1071 |
+
"version_minor": 0
|
1072 |
+
},
|
1073 |
+
"text/plain": [
|
1074 |
+
"Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]"
|
1075 |
+
]
|
1076 |
+
},
|
1077 |
+
"metadata": {},
|
1078 |
+
"output_type": "display_data"
|
1079 |
+
},
|
1080 |
+
{
|
1081 |
+
"data": {
|
1082 |
+
"application/vnd.jupyter.widget-view+json": {
|
1083 |
+
"model_id": "602b879326a44c58bc0909a3b86cd666",
|
1084 |
+
"version_major": 2,
|
1085 |
+
"version_minor": 0
|
1086 |
+
},
|
1087 |
+
"text/plain": [
|
1088 |
+
"generation_config.json: 0%| | 0.00/121 [00:00<?, ?B/s]"
|
1089 |
+
]
|
1090 |
+
},
|
1091 |
+
"metadata": {},
|
1092 |
+
"output_type": "display_data"
|
1093 |
+
},
|
1094 |
+
{
|
1095 |
+
"name": "stderr",
|
1096 |
+
"output_type": "stream",
|
1097 |
+
"text": [
|
1098 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
1099 |
+
"Setting `pad_token_id` to `eos_token_id`:100257 for open-end generation.\n"
|
1100 |
+
]
|
1101 |
+
}
|
1102 |
+
],
|
1103 |
+
"source": [
|
1104 |
+
"# from transformers import AutoModelForCausalLM\n",
|
1105 |
+
"\n",
|
1106 |
+
"# model = AutoModelForCausalLM.from_pretrained(\"deepnet/SN6-BestLlama\")\n",
|
1107 |
+
"# outputs = model.generate(inputs, max_new_tokens=100, do_sample=True, top_k=50, top_p=0.95)"
|
1108 |
+
]
|
1109 |
+
},
|
1110 |
+
{
|
1111 |
+
"cell_type": "markdown",
|
1112 |
+
"metadata": {},
|
1113 |
+
"source": [
|
1114 |
+
"Decode the generated token ids back into text:"
|
1115 |
+
]
|
1116 |
+
},
|
1117 |
+
{
|
1118 |
+
"cell_type": "code",
|
1119 |
+
"execution_count": 5,
|
1120 |
+
"metadata": {},
|
1121 |
+
"outputs": [
|
1122 |
+
{
|
1123 |
+
"data": {
|
1124 |
+
"text/plain": [
|
1125 |
+
"['Somatic hypermutation allows the immune system to recognize foreign proteins. \\n - . \\n - \\n 1 . 3 \\n S e t s \\n 0 \\n A c c e p t s \\n A l m o s t \\n 1 \\n C l o s e d \\n T o p i c s \\n P a p e r s \\n 0 \\n P a p e r s \\n B e a r i n g \\n P a g e s \\n 0 \\n P a g e s \\n R e c o']"
|
1126 |
+
]
|
1127 |
+
},
|
1128 |
+
"execution_count": 5,
|
1129 |
+
"metadata": {},
|
1130 |
+
"output_type": "execute_result"
|
1131 |
+
}
|
1132 |
+
],
|
1133 |
+
"source": [
|
1134 |
+
"# tokenizer.batch_decode(outputs, skip_special_tokens=True)"
|
1135 |
+
]
|
1136 |
+
},
|
1137 |
+
{
|
1138 |
+
"cell_type": "code",
|
1139 |
+
"execution_count": 6,
|
1140 |
+
"metadata": {},
|
1141 |
+
"outputs": [
|
1142 |
+
{
|
1143 |
+
"data": {
|
1144 |
+
"text/plain": [
|
1145 |
+
"['Somatic hypermutation allows the immune system to recognize foreign proteins. \\n - . \\n - \\n 1 . 3 \\n S e t s \\n 0 \\n A c c e p t s \\n A l m o s t \\n 1 \\n C l o s e d \\n T o p i c s \\n P a p e r s \\n 0 \\n P a p e r s \\n B e a r i n g \\n P a g e s \\n 0 \\n P a g e s \\n R e c o']"
|
1146 |
+
]
|
1147 |
+
},
|
1148 |
+
"execution_count": 6,
|
1149 |
+
"metadata": {},
|
1150 |
+
"output_type": "execute_result"
|
1151 |
+
}
|
1152 |
+
],
|
1153 |
+
"source": [
|
1154 |
+
"# tokenizer.batch_decode(outputs, skip_special_tokens=True)"
|
1155 |
+
]
|
1156 |
+
},
|
1157 |
+
{
|
1158 |
+
"cell_type": "code",
|
1159 |
+
"execution_count": null,
|
1160 |
+
"metadata": {},
|
1161 |
+
"outputs": [],
|
1162 |
+
"source": []
|
1163 |
+
}
|
1164 |
+
],
|
1165 |
+
"metadata": {
|
1166 |
+
"kernelspec": {
|
1167 |
+
"display_name": "Python 3 (ipykernel)",
|
1168 |
+
"language": "python",
|
1169 |
+
"name": "python3"
|
1170 |
+
},
|
1171 |
+
"language_info": {
|
1172 |
+
"codemirror_mode": {
|
1173 |
+
"name": "ipython",
|
1174 |
+
"version": 3
|
1175 |
+
},
|
1176 |
+
"file_extension": ".py",
|
1177 |
+
"mimetype": "text/x-python",
|
1178 |
+
"name": "python",
|
1179 |
+
"nbconvert_exporter": "python",
|
1180 |
+
"pygments_lexer": "ipython3",
|
1181 |
+
"version": "3.10.12"
|
1182 |
+
}
|
1183 |
+
},
|
1184 |
+
"nbformat": 4,
|
1185 |
+
"nbformat_minor": 4
|
1186 |
+
}
|
.ipynb_checkpoints/language_modeling-checkpoint.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
|
4 |
+
# Transformers installation
|
5 |
+
# ! pip install transformers datasets
|
6 |
+
# To install from source instead of the last release, comment the command above and uncomment the following one.
|
7 |
+
# ! pip install git+https://github.com/huggingface/transformers.git
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
# #@title
|
12 |
+
# from IPython.display import HTML
|
13 |
+
|
14 |
+
# HTML('<iframe width="560" height="315" src="https://www.youtube.com/embed/Vpjb1lu0MDk?rel=0&controls=0&showinfo=0" frameborder="0" allowfullscreen></iframe>')
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
# from huggingface_hub import notebook_login
|
19 |
+
|
20 |
+
# notebook_login()
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
# from datasets import load_dataset
|
25 |
+
|
26 |
+
# eli5 = load_dataset("eli5", split="train_asks[:5000]")
|
27 |
+
|
28 |
+
from datasets import load_dataset
|
29 |
+
# Falcon = load_dataset("csv", data_files="FalconData.csv")
|
30 |
+
Falcon = load_dataset('csv', data_files={"train": 'FalconData_train.csv', "validation": 'FalconData_validation.csv'})
|
31 |
+
|
32 |
+
print('Dataset Loaded!')
|
33 |
+
|
34 |
+
# Falcon = Falcon.train_test_split(test_size=0.10)
|
35 |
+
|
36 |
+
"""Then take a look at an example:"""
|
37 |
+
|
38 |
+
Falcon['train'][0]
|
39 |
+
|
40 |
+
Falcon['validation'][0]
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
# #@title
|
45 |
+
# from IPython.display import HTML
|
46 |
+
|
47 |
+
# HTML('<iframe width="560" height="315" src="https://www.youtube.com/embed/ma1TrR7gE7I?rel=0&controls=0&showinfo=0" frameborder="0" allowfullscreen></iframe>')
|
48 |
+
|
49 |
+
"""The next step is to load a DistilGPT2 tokenizer to process the `text` subfield:"""
|
50 |
+
|
51 |
+
from transformers import AutoTokenizer, GPT2TokenizerFast
|
52 |
+
|
53 |
+
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
|
54 |
+
|
55 |
+
|
56 |
+
# tokenizer = GPT2TokenizerFast.from_pretrained("Xenova/gpt-4")#, cache_dir=cache_dir)
|
57 |
+
# tokenizer.pad_token
|
58 |
+
|
59 |
+
# tokenizer.eos_token=128000
|
60 |
+
# tokenizer.bos_token='128000'
|
61 |
+
# tokenizer.eos_token='128001'
|
62 |
+
|
63 |
+
tokenizer.pad_token = tokenizer.eos_token
|
64 |
+
|
65 |
+
Falcon = Falcon.flatten()
|
66 |
+
Falcon["train"][0]
|
67 |
+
|
68 |
+
|
69 |
+
|
70 |
+
def preprocess_function(examples):
|
71 |
+
return tokenizer([" ".join(x) for x in examples["Text"]])
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
tokenized_Falcon = Falcon.map(
|
76 |
+
preprocess_function,
|
77 |
+
batched=True,
|
78 |
+
num_proc=4,
|
79 |
+
remove_columns=Falcon["train"].column_names,
|
80 |
+
)
|
81 |
+
|
82 |
+
|
83 |
+
block_size = tokenizer.model_max_length
|
84 |
+
# block_size = 2048
|
85 |
+
|
86 |
+
|
87 |
+
def group_texts(examples):
|
88 |
+
# Concatenate all texts.
|
89 |
+
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
|
90 |
+
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
91 |
+
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
|
92 |
+
# customize this part to your needs.
|
93 |
+
if total_length >= block_size:
|
94 |
+
total_length = (total_length // block_size) * block_size
|
95 |
+
# Split by chunks of block_size.
|
96 |
+
result = {
|
97 |
+
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
98 |
+
for k, t in concatenated_examples.items()
|
99 |
+
}
|
100 |
+
result["labels"] = result["input_ids"].copy()
|
101 |
+
return result
|
102 |
+
|
103 |
+
"""Apply the `group_texts` function over the entire dataset:"""
|
104 |
+
|
105 |
+
lm_dataset = tokenized_Falcon.map(group_texts, batched=True, num_proc=4)
|
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
from transformers import DataCollatorForLanguageModeling
|
110 |
+
|
111 |
+
# tokenizer.pad_token
|
112 |
+
# tokenizer.bos_token='128000'
|
113 |
+
# tokenizer.eos_token='128001'
|
114 |
+
|
115 |
+
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
116 |
+
|
117 |
+
|
118 |
+
|
119 |
+
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
|
120 |
+
import torch
|
121 |
+
model = AutoModelForCausalLM.from_pretrained("rwh/tinytoo", torch_dtype=torch.bfloat16)
|
122 |
+
|
123 |
+
print('Model Loaded!')
|
124 |
+
|
125 |
+
# import torch
|
126 |
+
# torch.cuda.empty_cache()
|
127 |
+
|
128 |
+
# import torch
|
129 |
+
# import gc
|
130 |
+
|
131 |
+
# # del tensor_name # Delete the tensor
|
132 |
+
# gc.collect() # Collect garbage
|
133 |
+
# torch.cuda.empty_cache() # Clear cache
|
134 |
+
|
135 |
+
# torch.cuda.empty_cache()
|
136 |
+
|
137 |
+
# torch.no_grad()
|
138 |
+
|
139 |
+
model.to('cuda')
|
140 |
+
|
141 |
+
OutputDir = "ReadyModel3"
|
142 |
+
|
143 |
+
training_args = TrainingArguments(
|
144 |
+
output_dir=OutputDir,
|
145 |
+
overwrite_output_dir=True,
|
146 |
+
bf16=True,
|
147 |
+
# evaluation_strategy="epoch",
|
148 |
+
evaluation_strategy="steps",
|
149 |
+
# learning_rate=3.25e-06,
|
150 |
+
# learning_rate=2e-5,
|
151 |
+
learning_rate=1e-5,
|
152 |
+
# weight_decay=0.01,
|
153 |
+
weight_decay=0.001,
|
154 |
+
num_train_epochs=5,
|
155 |
+
per_device_train_batch_size=8,
|
156 |
+
per_device_eval_batch_size=8,
|
157 |
+
# lr_scheduler_type = 'cosine',
|
158 |
+
lr_scheduler_type = 'linear',
|
159 |
+
push_to_hub=False,
|
160 |
+
save_total_limit = 2,
|
161 |
+
save_strategy = "steps",
|
162 |
+
load_best_model_at_end=True,
|
163 |
+
save_safetensors=True,
|
164 |
+
)
|
165 |
+
|
166 |
+
trainer = Trainer(
|
167 |
+
model=model,
|
168 |
+
args=training_args,
|
169 |
+
train_dataset=lm_dataset["train"],
|
170 |
+
eval_dataset=lm_dataset["validation"],
|
171 |
+
# eval_dataset=lm_dataset["test"],
|
172 |
+
data_collator=data_collator,
|
173 |
+
)
|
174 |
+
|
175 |
+
# trainer.train()
|
176 |
+
print('Started Training!')
|
177 |
+
trainer.train()
|
178 |
+
|
179 |
+
trainer.save_model(OutputDir)
|
180 |
+
print('Saved Model Path:', OutputDir)
|
181 |
+
|
182 |
+
import math
|
183 |
+
|
184 |
+
eval_results = trainer.evaluate()
|
185 |
+
print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")
|
186 |
+
|
187 |
+
|
FalconData.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4da726ba49818c96e679a57343b2b03c3c34af0cff0fe5b84725d6ccbc2405c8
|
3 |
+
size 25530585
|
FalconData2.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c40586cddb6904a918b7f6e2f1b09293434df3c62f77ccae9664cc08df4aa7ef
|
3 |
+
size 129479461
|
FalconDataSet.ipynb
ADDED
@@ -0,0 +1,717 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 14,
|
6 |
+
"id": "460d90da-b986-4c1c-8a66-eab144b0ba8d",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [
|
9 |
+
{
|
10 |
+
"name": "stdout",
|
11 |
+
"output_type": "stream",
|
12 |
+
"text": [
|
13 |
+
"Started Fetching Data\n",
|
14 |
+
"Failed to fetch data, retrying. Attempt 1/10\n",
|
15 |
+
"Failed to fetch data, retrying. Attempt 1/10\n",
|
16 |
+
"Fetched data for all the Pages.\n"
|
17 |
+
]
|
18 |
+
}
|
19 |
+
],
|
20 |
+
"source": [
|
21 |
+
"import requests\n",
|
22 |
+
"import time\n",
|
23 |
+
"\n",
|
24 |
+
"import random\n",
|
25 |
+
"pages = [\n",
|
26 |
+
" random.randint(1, 968000015)\n",
|
27 |
+
" for _ in range(500)\n",
|
28 |
+
" ]\n",
|
29 |
+
"# print(pages)\n",
|
30 |
+
"\n",
|
31 |
+
"base_url = \"https://datasets-server.huggingface.co/rows\"\n",
|
32 |
+
"params = {\n",
|
33 |
+
" \"dataset\": \"tiiuae/falcon-refinedweb\",\n",
|
34 |
+
" \"config\": \"default\",\n",
|
35 |
+
" \"split\": \"train\",\n",
|
36 |
+
" }\n",
|
37 |
+
"# response = requests.get(base_url, params=params)\n",
|
38 |
+
"# response.raise_for_status()\n",
|
39 |
+
"# for row in response.json()[\"rows\"]:\n",
|
40 |
+
"# content = row[\"row\"][\"content\"]\n",
|
41 |
+
"num_rows_per_page = 100\n",
|
42 |
+
"retry_limit = 10\n",
|
43 |
+
"retry_delay = 5\n",
|
44 |
+
"Falcon = []\n",
|
45 |
+
"\n",
|
46 |
+
"print('Started Fetching Data')\n",
|
47 |
+
"def fetch_data_for_page(page):\n",
|
48 |
+
" params[\"offset\"] = page\n",
|
49 |
+
" params[\"limit\"] = num_rows_per_page\n",
|
50 |
+
" attempt = 0\n",
|
51 |
+
" while attempt < retry_limit:\n",
|
52 |
+
" try:\n",
|
53 |
+
" response = requests.get(base_url, params=params)\n",
|
54 |
+
" response.raise_for_status() # This will raise an HTTPError if the HTTP request returned an unsuccessful status code\n",
|
55 |
+
" for row in response.json()[\"rows\"]:\n",
|
56 |
+
" content = row[\"row\"][\"content\"]\n",
|
57 |
+
" Falcon.append(content)\n",
|
58 |
+
" len(Falcon)\n",
|
59 |
+
" #print(f\"Fetched data for all the Pages.\")\n",
|
60 |
+
" break\n",
|
61 |
+
" except requests.exceptions.HTTPError as e:\n",
|
62 |
+
" attempt += 1\n",
|
63 |
+
" print(\n",
|
64 |
+
" f\"Failed to fetch data, retrying. Attempt {attempt}/{retry_limit}\"\n",
|
65 |
+
" )\n",
|
66 |
+
" if attempt < retry_limit:\n",
|
67 |
+
" time.sleep(retry_delay) # Wait before the next retry\n",
|
68 |
+
" else:\n",
|
69 |
+
" print(\n",
|
70 |
+
" \"Maximum retry limit reached. Unable to fetch data.\"\n",
|
71 |
+
" )\n",
|
72 |
+
" raise\n",
|
73 |
+
"\n",
|
74 |
+
"for page in pages:\n",
|
75 |
+
" fetch_data_for_page(page)\n",
|
76 |
+
"\n",
|
77 |
+
"print(f\"Fetched data for all the Pages.\")"
|
78 |
+
]
|
79 |
+
},
|
80 |
+
{
|
81 |
+
"cell_type": "code",
|
82 |
+
"execution_count": 15,
|
83 |
+
"id": "f8f3baf1-5480-450b-a456-174a5c114d3e",
|
84 |
+
"metadata": {},
|
85 |
+
"outputs": [],
|
86 |
+
"source": [
|
87 |
+
"import csv\n",
|
88 |
+
"\n",
|
89 |
+
"# Open the CSV file for writing\n",
|
90 |
+
"with open(\"FalconData2.csv\", \"w\", newline=\"\") as csvfile:\n",
|
91 |
+
" # Create a CSV writer object\n",
|
92 |
+
" writer = csv.writer(csvfile)\n",
|
93 |
+
"\n",
|
94 |
+
" # Write the header row\n",
|
95 |
+
" writer.writerow([\"Text\"])\n",
|
96 |
+
"\n",
|
97 |
+
" # Write each element in the list as a row in the CSV file\n",
|
98 |
+
" for element in Falcon:\n",
|
99 |
+
" writer.writerow([element])\n"
|
100 |
+
]
|
101 |
+
},
|
102 |
+
{
|
103 |
+
"cell_type": "code",
|
104 |
+
"execution_count": 30,
|
105 |
+
"id": "ea47c936-2c2b-4414-ba57-74fb6827ec0a",
|
106 |
+
"metadata": {},
|
107 |
+
"outputs": [
|
108 |
+
{
|
109 |
+
"name": "stdout",
|
110 |
+
"output_type": "stream",
|
111 |
+
"text": [
|
112 |
+
"Number of duplicate rows: 5\n",
|
113 |
+
" Text\n",
|
114 |
+
"522 Name:\n",
|
115 |
+
"11746 Description.\\nReviews\\nThere are no reviews yet.\n",
|
116 |
+
"17606 Description.\\nReviews\\nThere are no reviews yet.\n",
|
117 |
+
"30436 NaN\n",
|
118 |
+
"42549 !\\n\n"
|
119 |
+
]
|
120 |
+
}
|
121 |
+
],
|
122 |
+
"source": [
|
123 |
+
"import pandas as pd\n",
|
124 |
+
"\n",
|
125 |
+
"# Read the CSV file into a pandas DataFrame\n",
|
126 |
+
"df = pd.read_csv(\"FalconData2.csv\")\n",
|
127 |
+
"\n",
|
128 |
+
"# Check for duplicate rows\n",
|
129 |
+
"duplicate_rows = df[df.duplicated()]\n",
|
130 |
+
"\n",
|
131 |
+
"# Print the number of duplicate rows\n",
|
132 |
+
"print(f\"Number of duplicate rows: {len(duplicate_rows)}\")\n",
|
133 |
+
"\n",
|
134 |
+
"# Print the duplicate rows\n",
|
135 |
+
"print(duplicate_rows)"
|
136 |
+
]
|
137 |
+
},
|
138 |
+
{
|
139 |
+
"cell_type": "code",
|
140 |
+
"execution_count": 31,
|
141 |
+
"id": "f4178cd6-747f-4e05-a9bf-17b97f959e06",
|
142 |
+
"metadata": {},
|
143 |
+
"outputs": [
|
144 |
+
{
|
145 |
+
"data": {
|
146 |
+
"text/html": [
|
147 |
+
"<div>\n",
|
148 |
+
"<style scoped>\n",
|
149 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
150 |
+
" vertical-align: middle;\n",
|
151 |
+
" }\n",
|
152 |
+
"\n",
|
153 |
+
" .dataframe tbody tr th {\n",
|
154 |
+
" vertical-align: top;\n",
|
155 |
+
" }\n",
|
156 |
+
"\n",
|
157 |
+
" .dataframe thead th {\n",
|
158 |
+
" text-align: right;\n",
|
159 |
+
" }\n",
|
160 |
+
"</style>\n",
|
161 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
162 |
+
" <thead>\n",
|
163 |
+
" <tr style=\"text-align: right;\">\n",
|
164 |
+
" <th></th>\n",
|
165 |
+
" <th>Text</th>\n",
|
166 |
+
" </tr>\n",
|
167 |
+
" </thead>\n",
|
168 |
+
" <tbody>\n",
|
169 |
+
" <tr>\n",
|
170 |
+
" <th>0</th>\n",
|
171 |
+
" <td>[…]\\nM&S bank […]\\nLowest unsecured loan rate...</td>\n",
|
172 |
+
" </tr>\n",
|
173 |
+
" <tr>\n",
|
174 |
+
" <th>1</th>\n",
|
175 |
+
" <td>JavaScript seems to be disabled in your browse...</td>\n",
|
176 |
+
" </tr>\n",
|
177 |
+
" <tr>\n",
|
178 |
+
" <th>2</th>\n",
|
179 |
+
" <td>CMTech has designed a game to foster social in...</td>\n",
|
180 |
+
" </tr>\n",
|
181 |
+
" <tr>\n",
|
182 |
+
" <th>3</th>\n",
|
183 |
+
" <td>A Storyteller's Point of View\\nMy\\nWriting\\nLe...</td>\n",
|
184 |
+
" </tr>\n",
|
185 |
+
" <tr>\n",
|
186 |
+
" <th>4</th>\n",
|
187 |
+
" <td>mspu.us was registered 1 decade 3 years ago. I...</td>\n",
|
188 |
+
" </tr>\n",
|
189 |
+
" </tbody>\n",
|
190 |
+
"</table>\n",
|
191 |
+
"</div>"
|
192 |
+
],
|
193 |
+
"text/plain": [
|
194 |
+
" Text\n",
|
195 |
+
"0 […]\\nM&S bank […]\\nLowest unsecured loan rate...\n",
|
196 |
+
"1 JavaScript seems to be disabled in your browse...\n",
|
197 |
+
"2 CMTech has designed a game to foster social in...\n",
|
198 |
+
"3 A Storyteller's Point of View\\nMy\\nWriting\\nLe...\n",
|
199 |
+
"4 mspu.us was registered 1 decade 3 years ago. I..."
|
200 |
+
]
|
201 |
+
},
|
202 |
+
"execution_count": 31,
|
203 |
+
"metadata": {},
|
204 |
+
"output_type": "execute_result"
|
205 |
+
}
|
206 |
+
],
|
207 |
+
"source": [
|
208 |
+
"df.head()"
|
209 |
+
]
|
210 |
+
},
|
211 |
+
{
|
212 |
+
"cell_type": "code",
|
213 |
+
"execution_count": 32,
|
214 |
+
"id": "264548c1-4cf4-441f-a433-2f5d57861dc4",
|
215 |
+
"metadata": {},
|
216 |
+
"outputs": [
|
217 |
+
{
|
218 |
+
"data": {
|
219 |
+
"text/html": [
|
220 |
+
"<div>\n",
|
221 |
+
"<style scoped>\n",
|
222 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
223 |
+
" vertical-align: middle;\n",
|
224 |
+
" }\n",
|
225 |
+
"\n",
|
226 |
+
" .dataframe tbody tr th {\n",
|
227 |
+
" vertical-align: top;\n",
|
228 |
+
" }\n",
|
229 |
+
"\n",
|
230 |
+
" .dataframe thead th {\n",
|
231 |
+
" text-align: right;\n",
|
232 |
+
" }\n",
|
233 |
+
"</style>\n",
|
234 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
235 |
+
" <thead>\n",
|
236 |
+
" <tr style=\"text-align: right;\">\n",
|
237 |
+
" <th></th>\n",
|
238 |
+
" <th>Text</th>\n",
|
239 |
+
" </tr>\n",
|
240 |
+
" </thead>\n",
|
241 |
+
" <tbody>\n",
|
242 |
+
" <tr>\n",
|
243 |
+
" <th>49995</th>\n",
|
244 |
+
" <td>Alumni in Action: Grace Heyne Lybrand\\nWhen Gr...</td>\n",
|
245 |
+
" </tr>\n",
|
246 |
+
" <tr>\n",
|
247 |
+
" <th>49996</th>\n",
|
248 |
+
" <td>This.\\n51.351813 -105.220438\\n12 replies on “L...</td>\n",
|
249 |
+
" </tr>\n",
|
250 |
+
" <tr>\n",
|
251 |
+
" <th>49997</th>\n",
|
252 |
+
" <td>VIDEO 1: Panel discussion with John Nichols, a...</td>\n",
|
253 |
+
" </tr>\n",
|
254 |
+
" <tr>\n",
|
255 |
+
" <th>49998</th>\n",
|
256 |
+
" <td>The Prototype DA-2A made its first flight on M...</td>\n",
|
257 |
+
" </tr>\n",
|
258 |
+
" <tr>\n",
|
259 |
+
" <th>49999</th>\n",
|
260 |
+
" <td>default search action\\nBibTeX record journals/...</td>\n",
|
261 |
+
" </tr>\n",
|
262 |
+
" </tbody>\n",
|
263 |
+
"</table>\n",
|
264 |
+
"</div>"
|
265 |
+
],
|
266 |
+
"text/plain": [
|
267 |
+
" Text\n",
|
268 |
+
"49995 Alumni in Action: Grace Heyne Lybrand\\nWhen Gr...\n",
|
269 |
+
"49996 This.\\n51.351813 -105.220438\\n12 replies on “L...\n",
|
270 |
+
"49997 VIDEO 1: Panel discussion with John Nichols, a...\n",
|
271 |
+
"49998 The Prototype DA-2A made its first flight on M...\n",
|
272 |
+
"49999 default search action\\nBibTeX record journals/..."
|
273 |
+
]
|
274 |
+
},
|
275 |
+
"execution_count": 32,
|
276 |
+
"metadata": {},
|
277 |
+
"output_type": "execute_result"
|
278 |
+
}
|
279 |
+
],
|
280 |
+
"source": [
|
281 |
+
"df.tail()"
|
282 |
+
]
|
283 |
+
},
|
284 |
+
{
|
285 |
+
"cell_type": "code",
|
286 |
+
"execution_count": 33,
|
287 |
+
"id": "3f215b09-8050-4477-860c-d3ed0a19f45d",
|
288 |
+
"metadata": {},
|
289 |
+
"outputs": [
|
290 |
+
{
|
291 |
+
"name": "stdout",
|
292 |
+
"output_type": "stream",
|
293 |
+
"text": [
|
294 |
+
"Number of Words:\n",
|
295 |
+
"0 65\n",
|
296 |
+
"1 79\n",
|
297 |
+
"2 287\n",
|
298 |
+
"3 302\n",
|
299 |
+
"4 130\n",
|
300 |
+
" ... \n",
|
301 |
+
"49995 64\n",
|
302 |
+
"49996 325\n",
|
303 |
+
"49997 58\n",
|
304 |
+
"49998 623\n",
|
305 |
+
"49999 67\n",
|
306 |
+
"Name: Text, Length: 50000, dtype: int64\n",
|
307 |
+
"Smallest Row:\n",
|
308 |
+
"Text This\n",
|
309 |
+
"Name: 270, dtype: object\n",
|
310 |
+
"\n",
|
311 |
+
"Largest Row:\n",
|
312 |
+
"Text MAMMALS\\n400. Abu Jafar, M.Z., and C. Hays-Sha...\n",
|
313 |
+
"Name: 33020, dtype: object\n"
|
314 |
+
]
|
315 |
+
}
|
316 |
+
],
|
317 |
+
"source": [
|
318 |
+
"# Calculate the word count for each row without storing it as a column\n",
|
319 |
+
"word_counts = df['Text'].apply(lambda x: len(str(x).split()))\n",
|
320 |
+
"\n",
|
321 |
+
"\n",
|
322 |
+
"print(\"Number of Words:\")\n",
|
323 |
+
"print(word_counts)\n",
|
324 |
+
"\n",
|
325 |
+
"# print(\"Smallest Count\")\n",
|
326 |
+
"# print(word_counts.min())\n",
|
327 |
+
"\n",
|
328 |
+
"# print(\"Largest Count\")\n",
|
329 |
+
"# print(word_counts.max())\n",
|
330 |
+
"\n",
|
331 |
+
"# Find the row with the smallest word count\n",
|
332 |
+
"smallest_row = df.loc[word_counts.idxmin()]\n",
|
333 |
+
"\n",
|
334 |
+
"# Find the row with the largest word count\n",
|
335 |
+
"largest_row = df.loc[word_counts.idxmax()]\n",
|
336 |
+
"\n",
|
337 |
+
"# Display the smallest and largest rows\n",
|
338 |
+
"print(\"Smallest Row:\")\n",
|
339 |
+
"print(smallest_row)\n",
|
340 |
+
"\n",
|
341 |
+
"print(\"\\nLargest Row:\")\n",
|
342 |
+
"print(largest_row)\n"
|
343 |
+
]
|
344 |
+
},
|
345 |
+
{
|
346 |
+
"cell_type": "code",
|
347 |
+
"execution_count": 34,
|
348 |
+
"id": "be5a87a8-cfee-4f63-992e-8fa1d4a5cdbb",
|
349 |
+
"metadata": {},
|
350 |
+
"outputs": [
|
351 |
+
{
|
352 |
+
"data": {
|
353 |
+
"text/plain": [
|
354 |
+
"Text NaN\n",
|
355 |
+
"Name: 30436, dtype: object"
|
356 |
+
]
|
357 |
+
},
|
358 |
+
"execution_count": 34,
|
359 |
+
"metadata": {},
|
360 |
+
"output_type": "execute_result"
|
361 |
+
}
|
362 |
+
],
|
363 |
+
"source": [
|
364 |
+
"target_row=30436\n",
|
365 |
+
"specific_row = df.iloc[target_row]\n",
|
366 |
+
"specific_row"
|
367 |
+
]
|
368 |
+
},
|
369 |
+
{
|
370 |
+
"cell_type": "code",
|
371 |
+
"execution_count": 13,
|
372 |
+
"id": "e97d9e18-eaa0-4a1b-96ab-c89a0f4c738d",
|
373 |
+
"metadata": {},
|
374 |
+
"outputs": [
|
375 |
+
{
|
376 |
+
"name": "stdout",
|
377 |
+
"output_type": "stream",
|
378 |
+
"text": [
|
379 |
+
"Text The old wireline Bell telephone system was bui...\n",
|
380 |
+
"Name: 19995, dtype: object\n"
|
381 |
+
]
|
382 |
+
}
|
383 |
+
],
|
384 |
+
"source": [
|
385 |
+
"print(specific_row)"
|
386 |
+
]
|
387 |
+
},
|
388 |
+
{
|
389 |
+
"cell_type": "code",
|
390 |
+
"execution_count": 14,
|
391 |
+
"id": "940ef35f-7517-403d-9f42-73760182dcaa",
|
392 |
+
"metadata": {},
|
393 |
+
"outputs": [
|
394 |
+
{
|
395 |
+
"name": "stdout",
|
396 |
+
"output_type": "stream",
|
397 |
+
"text": [
|
398 |
+
"Text The old wireline Bell telephone system was bui...\n"
|
399 |
+
]
|
400 |
+
}
|
401 |
+
],
|
402 |
+
"source": [
|
403 |
+
"print(specific_row.to_string())"
|
404 |
+
]
|
405 |
+
},
|
406 |
+
{
|
407 |
+
"cell_type": "code",
|
408 |
+
"execution_count": 17,
|
409 |
+
"id": "915ac669-718f-47f5-b175-a5f928b407db",
|
410 |
+
"metadata": {},
|
411 |
+
"outputs": [
|
412 |
+
{
|
413 |
+
"name": "stdout",
|
414 |
+
"output_type": "stream",
|
415 |
+
"text": [
|
416 |
+
"57\n"
|
417 |
+
]
|
418 |
+
}
|
419 |
+
],
|
420 |
+
"source": [
|
421 |
+
"print(len(specific_row.to_string()))"
|
422 |
+
]
|
423 |
+
},
|
424 |
+
{
|
425 |
+
"cell_type": "code",
|
426 |
+
"execution_count": 24,
|
427 |
+
"id": "ab5ee254-9ba7-496b-97c7-3b6185c21971",
|
428 |
+
"metadata": {},
|
429 |
+
"outputs": [
|
430 |
+
{
|
431 |
+
"name": "stdout",
|
432 |
+
"output_type": "stream",
|
433 |
+
"text": [
|
434 |
+
"Training set size: 49000\n",
|
435 |
+
"Validation set size: 1000\n"
|
436 |
+
]
|
437 |
+
}
|
438 |
+
],
|
439 |
+
"source": [
|
440 |
+
"# import pandas as pd\n",
|
441 |
+
"\n",
|
442 |
+
"# # Load the dataset\n",
|
443 |
+
"# df = pd.read_csv(\"FalconData2.csv\")\n",
|
444 |
+
"\n",
|
445 |
+
"# # Calculate the index to split the data at the last 10%\n",
|
446 |
+
"# split_index = int(len(df) * 0.980)\n",
|
447 |
+
"\n",
|
448 |
+
"# # Split the data into training and validation sets\n",
|
449 |
+
"# train_df = df.iloc[:split_index] # First 90% for training\n",
|
450 |
+
"# validation_df = df.iloc[split_index:] # Last 10% for validation\n",
|
451 |
+
"\n",
|
452 |
+
"# # Display the sizes of the training and validation sets\n",
|
453 |
+
"# print(f\"Training set size: {len(train_df)}\")\n",
|
454 |
+
"# print(f\"Validation set size: {len(validation_df)}\")\n",
|
455 |
+
"\n",
|
456 |
+
"# # Optionally, save the datasets to new CSV files\n",
|
457 |
+
"# train_df.to_csv(\"FalconData_train2.csv\", index=False)\n",
|
458 |
+
"# validation_df.to_csv(\"FalconData_validation2.csv\", index=False)\n"
|
459 |
+
]
|
460 |
+
},
|
461 |
+
{
|
462 |
+
"cell_type": "code",
|
463 |
+
"execution_count": 35,
|
464 |
+
"id": "7a16fb10-40cd-4668-b363-57ca64819ad3",
|
465 |
+
"metadata": {},
|
466 |
+
"outputs": [
|
467 |
+
{
|
468 |
+
"name": "stdout",
|
469 |
+
"output_type": "stream",
|
470 |
+
"text": [
|
471 |
+
"Number of rows removed due to NaN values: 2\n",
|
472 |
+
"Training set size: 48998\n",
|
473 |
+
"Validation set size: 1000\n"
|
474 |
+
]
|
475 |
+
}
|
476 |
+
],
|
477 |
+
"source": [
|
478 |
+
"import pandas as pd\n",
|
479 |
+
"\n",
|
480 |
+
"# Load the dataset\n",
|
481 |
+
"df = pd.read_csv(\"FalconData2.csv\")\n",
|
482 |
+
"\n",
|
483 |
+
"# Check for NaN values and remove rows with NaN values\n",
|
484 |
+
"# df = df.dropna()\n",
|
485 |
+
"original_length = len(df)\n",
|
486 |
+
"\n",
|
487 |
+
"df = df.dropna()\n",
|
488 |
+
"\n",
|
489 |
+
"removed_rows = original_length - len(df)\n",
|
490 |
+
"print(f\"Number of rows removed due to NaN values: {removed_rows}\")\n",
|
491 |
+
"\n",
|
492 |
+
"# Calculate the index to split the data at the last 2%\n",
|
493 |
+
"split_index = int(len(df) * 0.98)\n",
|
494 |
+
"\n",
|
495 |
+
"# Split the data into training and validation sets\n",
|
496 |
+
"train_df = df.iloc[:split_index] # First 98% for training\n",
|
497 |
+
"validation_df = df.iloc[split_index:] # Last 2% for validation\n",
|
498 |
+
"\n",
|
499 |
+
"# Display the sizes of the training and validation sets\n",
|
500 |
+
"print(f\"Training set size: {len(train_df)}\")\n",
|
501 |
+
"print(f\"Validation set size: {len(validation_df)}\")\n",
|
502 |
+
"\n",
|
503 |
+
"# Save the datasets to new CSV files\n",
|
504 |
+
"train_df.to_csv(\"FalconData_train2.csv\", index=False)\n",
|
505 |
+
"validation_df.to_csv(\"FalconData_validation2.csv\", index=False)\n"
|
506 |
+
]
|
507 |
+
},
|
508 |
+
{
|
509 |
+
"cell_type": "code",
|
510 |
+
"execution_count": 36,
|
511 |
+
"id": "55d929c5-c198-4a91-b31d-65dd83fa00d2",
|
512 |
+
"metadata": {},
|
513 |
+
"outputs": [
|
514 |
+
{
|
515 |
+
"name": "stdout",
|
516 |
+
"output_type": "stream",
|
517 |
+
"text": [
|
518 |
+
"Number of duplicate rows: 4\n",
|
519 |
+
" Text\n",
|
520 |
+
"522 Name:\n",
|
521 |
+
"11745 Description.\\nReviews\\nThere are no reviews yet.\n",
|
522 |
+
"17605 Description.\\nReviews\\nThere are no reviews yet.\n",
|
523 |
+
"42547 !\\n\n"
|
524 |
+
]
|
525 |
+
}
|
526 |
+
],
|
527 |
+
"source": [
|
528 |
+
"# Read the CSV file into a pandas DataFrame\n",
|
529 |
+
"df1 = pd.read_csv(\"FalconData_train2.csv\")\n",
|
530 |
+
"\n",
|
531 |
+
"# Check for duplicate rows\n",
|
532 |
+
"duplicate_rows1 = df1[df1.duplicated()]\n",
|
533 |
+
"\n",
|
534 |
+
"# Print the number of duplicate rows\n",
|
535 |
+
"print(f\"Number of duplicate rows: {len(duplicate_rows1)}\")\n",
|
536 |
+
"\n",
|
537 |
+
"# Print the duplicate rows\n",
|
538 |
+
"print(duplicate_rows1)"
|
539 |
+
]
|
540 |
+
},
|
541 |
+
{
|
542 |
+
"cell_type": "code",
|
543 |
+
"execution_count": 26,
|
544 |
+
"id": "3cc404d9-e85e-48ff-aa34-750ebe3e3d3c",
|
545 |
+
"metadata": {},
|
546 |
+
"outputs": [
|
547 |
+
{
|
548 |
+
"data": {
|
549 |
+
"text/html": [
|
550 |
+
"<div>\n",
|
551 |
+
"<style scoped>\n",
|
552 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
553 |
+
" vertical-align: middle;\n",
|
554 |
+
" }\n",
|
555 |
+
"\n",
|
556 |
+
" .dataframe tbody tr th {\n",
|
557 |
+
" vertical-align: top;\n",
|
558 |
+
" }\n",
|
559 |
+
"\n",
|
560 |
+
" .dataframe thead th {\n",
|
561 |
+
" text-align: right;\n",
|
562 |
+
" }\n",
|
563 |
+
"</style>\n",
|
564 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
565 |
+
" <thead>\n",
|
566 |
+
" <tr style=\"text-align: right;\">\n",
|
567 |
+
" <th></th>\n",
|
568 |
+
" <th>Text</th>\n",
|
569 |
+
" </tr>\n",
|
570 |
+
" </thead>\n",
|
571 |
+
" <tbody>\n",
|
572 |
+
" <tr>\n",
|
573 |
+
" <th>0</th>\n",
|
574 |
+
" <td>[…]\\nM&S bank […]\\nLowest unsecured loan rate...</td>\n",
|
575 |
+
" </tr>\n",
|
576 |
+
" <tr>\n",
|
577 |
+
" <th>1</th>\n",
|
578 |
+
" <td>JavaScript seems to be disabled in your browse...</td>\n",
|
579 |
+
" </tr>\n",
|
580 |
+
" <tr>\n",
|
581 |
+
" <th>2</th>\n",
|
582 |
+
" <td>CMTech has designed a game to foster social in...</td>\n",
|
583 |
+
" </tr>\n",
|
584 |
+
" <tr>\n",
|
585 |
+
" <th>3</th>\n",
|
586 |
+
" <td>A Storyteller's Point of View\\nMy\\nWriting\\nLe...</td>\n",
|
587 |
+
" </tr>\n",
|
588 |
+
" <tr>\n",
|
589 |
+
" <th>4</th>\n",
|
590 |
+
" <td>mspu.us was registered 1 decade 3 years ago. I...</td>\n",
|
591 |
+
" </tr>\n",
|
592 |
+
" </tbody>\n",
|
593 |
+
"</table>\n",
|
594 |
+
"</div>"
|
595 |
+
],
|
596 |
+
"text/plain": [
|
597 |
+
" Text\n",
|
598 |
+
"0 […]\\nM&S bank […]\\nLowest unsecured loan rate...\n",
|
599 |
+
"1 JavaScript seems to be disabled in your browse...\n",
|
600 |
+
"2 CMTech has designed a game to foster social in...\n",
|
601 |
+
"3 A Storyteller's Point of View\\nMy\\nWriting\\nLe...\n",
|
602 |
+
"4 mspu.us was registered 1 decade 3 years ago. I..."
|
603 |
+
]
|
604 |
+
},
|
605 |
+
"execution_count": 26,
|
606 |
+
"metadata": {},
|
607 |
+
"output_type": "execute_result"
|
608 |
+
}
|
609 |
+
],
|
610 |
+
"source": [
|
611 |
+
"df1.head()"
|
612 |
+
]
|
613 |
+
},
|
614 |
+
{
|
615 |
+
"cell_type": "code",
|
616 |
+
"execution_count": 27,
|
617 |
+
"id": "641c606f-6f7f-4097-a8de-a9f6be0047b1",
|
618 |
+
"metadata": {},
|
619 |
+
"outputs": [
|
620 |
+
{
|
621 |
+
"data": {
|
622 |
+
"text/html": [
|
623 |
+
"<div>\n",
|
624 |
+
"<style scoped>\n",
|
625 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
626 |
+
" vertical-align: middle;\n",
|
627 |
+
" }\n",
|
628 |
+
"\n",
|
629 |
+
" .dataframe tbody tr th {\n",
|
630 |
+
" vertical-align: top;\n",
|
631 |
+
" }\n",
|
632 |
+
"\n",
|
633 |
+
" .dataframe thead th {\n",
|
634 |
+
" text-align: right;\n",
|
635 |
+
" }\n",
|
636 |
+
"</style>\n",
|
637 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
638 |
+
" <thead>\n",
|
639 |
+
" <tr style=\"text-align: right;\">\n",
|
640 |
+
" <th></th>\n",
|
641 |
+
" <th>Text</th>\n",
|
642 |
+
" </tr>\n",
|
643 |
+
" </thead>\n",
|
644 |
+
" <tbody>\n",
|
645 |
+
" <tr>\n",
|
646 |
+
" <th>48995</th>\n",
|
647 |
+
" <td>A Chenango County man was charged Wednesday wi...</td>\n",
|
648 |
+
" </tr>\n",
|
649 |
+
" <tr>\n",
|
650 |
+
" <th>48996</th>\n",
|
651 |
+
" <td>2-Tone Black Personalized Embroidered One Init...</td>\n",
|
652 |
+
" </tr>\n",
|
653 |
+
" <tr>\n",
|
654 |
+
" <th>48997</th>\n",
|
655 |
+
" <td>NARAL Pro-Choice America PAC Endorses Colleen ...</td>\n",
|
656 |
+
" </tr>\n",
|
657 |
+
" <tr>\n",
|
658 |
+
" <th>48998</th>\n",
|
659 |
+
" <td>Posts Tagged by Thomas Paine\\nAEI Hosts Peter ...</td>\n",
|
660 |
+
" </tr>\n",
|
661 |
+
" <tr>\n",
|
662 |
+
" <th>48999</th>\n",
|
663 |
+
" <td>Pantry feeds families in need\\n- Details\\n- Ca...</td>\n",
|
664 |
+
" </tr>\n",
|
665 |
+
" </tbody>\n",
|
666 |
+
"</table>\n",
|
667 |
+
"</div>"
|
668 |
+
],
|
669 |
+
"text/plain": [
|
670 |
+
" Text\n",
|
671 |
+
"48995 A Chenango County man was charged Wednesday wi...\n",
|
672 |
+
"48996 2-Tone Black Personalized Embroidered One Init...\n",
|
673 |
+
"48997 NARAL Pro-Choice America PAC Endorses Colleen ...\n",
|
674 |
+
"48998 Posts Tagged by Thomas Paine\\nAEI Hosts Peter ...\n",
|
675 |
+
"48999 Pantry feeds families in need\\n- Details\\n- Ca..."
|
676 |
+
]
|
677 |
+
},
|
678 |
+
"execution_count": 27,
|
679 |
+
"metadata": {},
|
680 |
+
"output_type": "execute_result"
|
681 |
+
}
|
682 |
+
],
|
683 |
+
"source": [
|
684 |
+
"df1.tail()"
|
685 |
+
]
|
686 |
+
},
|
687 |
+
{
|
688 |
+
"cell_type": "code",
|
689 |
+
"execution_count": null,
|
690 |
+
"id": "b8f7dbf6-5d74-4f8f-85d0-e890a5b8d152",
|
691 |
+
"metadata": {},
|
692 |
+
"outputs": [],
|
693 |
+
"source": []
|
694 |
+
}
|
695 |
+
],
|
696 |
+
"metadata": {
|
697 |
+
"kernelspec": {
|
698 |
+
"display_name": "Python 3 (ipykernel)",
|
699 |
+
"language": "python",
|
700 |
+
"name": "python3"
|
701 |
+
},
|
702 |
+
"language_info": {
|
703 |
+
"codemirror_mode": {
|
704 |
+
"name": "ipython",
|
705 |
+
"version": 3
|
706 |
+
},
|
707 |
+
"file_extension": ".py",
|
708 |
+
"mimetype": "text/x-python",
|
709 |
+
"name": "python",
|
710 |
+
"nbconvert_exporter": "python",
|
711 |
+
"pygments_lexer": "ipython3",
|
712 |
+
"version": "3.11.9"
|
713 |
+
}
|
714 |
+
},
|
715 |
+
"nbformat": 4,
|
716 |
+
"nbformat_minor": 5
|
717 |
+
}
|
FalconData_train.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e7f50d9deb1ffab95c3a7026107c574f6024ae3791849e11c1705f8951caa6a2
|
3 |
+
size 23342205
|
FalconData_train2.csv
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8c9f267ded526b4479e8a8754d2554ac84100531c971a962cb3fc0d0a74c52de
|
3 |
+
size 126785171
|
FalconData_validation.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
FalconData_validation2.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2024 Shivaen
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ShortGPT
|
2 |
+
Unofficial implementations of:
|
3 |
+
- ["ShortGPT: Layers in Large Language Models are More Redundant Than You Expect"](https://arxiv.org/pdf/2403.03853)
|
4 |
+
- ["The Unreasonable Ineffectiveness of the Deeper Layers"](https://arxiv.org/abs/2403.17887)
|
5 |
+
|
6 |
+
### To Use
|
7 |
+
- Follow Llama 2 setup found [here](https://github.com/facebookresearch/llama).
|
8 |
+
- Reference `short_gpt/short_llama.ipynb` for necessary function calls.
|
9 |
+
- For HuggingFace models, reference this [branch](https://github.com/sramshetty/ShortGPT/tree/hf-models).
|
10 |
+
|
11 |
+
|
12 |
+
### Details
|
13 |
+
- Use a wrapper around Llama to collect hidden states and compute BI (block influence).
|
14 |
+
- BI implementation may be subject to change or improvements if others find issues, thanks in advance!
|
15 |
+
- Sum importance values across layers while inferencing on [pg19](https://huggingface.co/datasets/pg19).
|
16 |
+
- Dataset can be slow to load from huggingface so you may want to use an alternative.
|
17 |
+
- Use sorted layer-wise importance values to determine which layers are least important and subject to removal.
|
18 |
+
- Demonstrate *model-healing* with Mistral-7B-v0.1 described in "The Unreasonable Ineffectiveness of the Deeper Layers", where finetuning with LoRA after layer removal can recover downstream model performance.
|
19 |
+
|
20 |
+
|
21 |
+
### Results
|
22 |
+
Comparison of ShortGPT layers removed on Llama-2-7B (9 least important layers):
|
23 |
+
|
24 |
+
Paper: [27, 26, 25, 28, 24, 29, 23, 21, 22] \
|
25 |
+
This Implementation: [25, 27, 24, 26, 28, 29, 23, 22, 21]
|
26 |
+
|
27 |
+
Same layers but different order.
|
28 |
+
|
29 |
+
### TODO:
|
30 |
+
- [x] Is order significant -> Authors mention that layer order varies between datasets but their relative ordering suggests "similar levels of importance" [link](https://huggingface.co/papers/2403.03853#65f028667c916f24c80e93b3).
|
31 |
+
- [x] Add more models and metrics -> Add experimental support for HF models on this [branch](https://github.com/sramshetty/ShortGPT/tree/hf-models).
|
32 |
+
- [x] Add angular distance metric
|
33 |
+
- [x] Demonstrate model healing using HuggingFace model [here](https://github.com/sramshetty/ShortGPT/blob/hf-models/short_gpt/short_hf.ipynb).
|
34 |
+
|
35 |
+
### Citations
|
36 |
+
```bibtex
|
37 |
+
@misc{men2024shortgpt,
|
38 |
+
title={ShortGPT: Layers in Large Language Models are More Redundant Than You Expect},
|
39 |
+
author={Xin Men and Mingyu Xu and Qingyu Zhang and Bingning Wang and Hongyu Lin and Yaojie Lu and Xianpei Han and Weipeng Chen},
|
40 |
+
year={2024},
|
41 |
+
eprint={2403.03853},
|
42 |
+
archivePrefix={arXiv},
|
43 |
+
primaryClass={cs.CL}
|
44 |
+
}
|
45 |
+
|
46 |
+
@misc{gromov2024unreasonable,
|
47 |
+
title={The Unreasonable Ineffectiveness of the Deeper Layers},
|
48 |
+
author={Andrey Gromov and Kushal Tirumala and Hassan Shapourian and Paolo Glorioso and Daniel A. Roberts},
|
49 |
+
year={2024},
|
50 |
+
eprint={2403.17887},
|
51 |
+
archivePrefix={arXiv},
|
52 |
+
primaryClass={cs.CL}
|
53 |
+
}
|
54 |
+
|
55 |
+
@misc{song2024sleb,
|
56 |
+
title={SLEB: Streamlining LLMs through Redundancy Verification and Elimination of Transformer Blocks},
|
57 |
+
author={Jiwon Song and Kyungseok Oh and Taesu Kim and Hyungjun Kim and Yulhwa Kim and Jae-Joon Kim},
|
58 |
+
year={2024},
|
59 |
+
eprint={2402.09025},
|
60 |
+
archivePrefix={arXiv},
|
61 |
+
primaryClass={cs.CL}
|
62 |
+
}
|
63 |
+
|
64 |
+
@article{raecompressive2019,
|
65 |
+
author = {Rae, Jack W and Potapenko, Anna and Jayakumar, Siddhant M and Hillier, Chloe and Lillicrap, Timothy P},
|
66 |
+
title = {Compressive Transformers for Long-Range Sequence Modelling},
|
67 |
+
journal = {arXiv preprint},
|
68 |
+
url = {https://arxiv.org/abs/1911.05507},
|
69 |
+
year = {2019},
|
70 |
+
}
|
71 |
+
```
|
language_modeling.ipynb
ADDED
@@ -0,0 +1,932 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [
|
8 |
+
{
|
9 |
+
"name": "stdout",
|
10 |
+
"output_type": "stream",
|
11 |
+
"text": [
|
12 |
+
"Requirement already satisfied: transformers in /usr/local/lib/python3.11/dist-packages (4.45.0.dev0)\n",
|
13 |
+
"Requirement already satisfied: datasets in /usr/local/lib/python3.11/dist-packages (2.21.0)\n",
|
14 |
+
"Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from transformers) (3.15.4)\n",
|
15 |
+
"Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.24.6)\n",
|
16 |
+
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.11/dist-packages (from transformers) (1.26.4)\n",
|
17 |
+
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from transformers) (24.1)\n",
|
18 |
+
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from transformers) (6.0.2)\n",
|
19 |
+
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers) (2024.7.24)\n",
|
20 |
+
"Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from transformers) (2.32.3)\n",
|
21 |
+
"Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.19.1)\n",
|
22 |
+
"Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.4.5)\n",
|
23 |
+
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.11/dist-packages (from transformers) (4.66.5)\n",
|
24 |
+
"Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.11/dist-packages (from datasets) (17.0.0)\n",
|
25 |
+
"Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.11/dist-packages (from datasets) (0.3.8)\n",
|
26 |
+
"Requirement already satisfied: pandas in /usr/local/lib/python3.11/dist-packages (from datasets) (2.2.2)\n",
|
27 |
+
"Requirement already satisfied: xxhash in /usr/local/lib/python3.11/dist-packages (from datasets) (3.5.0)\n",
|
28 |
+
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.11/dist-packages (from datasets) (0.70.16)\n",
|
29 |
+
"Requirement already satisfied: fsspec<=2024.6.1,>=2023.1.0 in /usr/local/lib/python3.11/dist-packages (from fsspec[http]<=2024.6.1,>=2023.1.0->datasets) (2024.6.1)\n",
|
30 |
+
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.11/dist-packages (from datasets) (3.10.5)\n",
|
31 |
+
"Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (2.4.0)\n",
|
32 |
+
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (1.3.1)\n",
|
33 |
+
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (24.2.0)\n",
|
34 |
+
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (1.4.1)\n",
|
35 |
+
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (6.1.0)\n",
|
36 |
+
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.11/dist-packages (from aiohttp->datasets) (1.11.1)\n",
|
37 |
+
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub<1.0,>=0.23.2->transformers) (4.12.2)\n",
|
38 |
+
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (3.3.2)\n",
|
39 |
+
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (3.7)\n",
|
40 |
+
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (2.2.2)\n",
|
41 |
+
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (2024.7.4)\n",
|
42 |
+
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas->datasets) (2.9.0.post0)\n",
|
43 |
+
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas->datasets) (2024.1)\n",
|
44 |
+
"Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas->datasets) (2024.1)\n",
|
45 |
+
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n",
|
46 |
+
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n",
|
47 |
+
"\u001b[0mCollecting git+https://github.com/huggingface/transformers.git\n",
|
48 |
+
" Cloning https://github.com/huggingface/transformers.git to /tmp/pip-req-build-sok4bqyk\n",
|
49 |
+
" Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-req-build-sok4bqyk\n",
|
50 |
+
" Resolved https://github.com/huggingface/transformers.git to commit 96429e74a8191521bcb4b99f48ad1fbc8f9e6873\n",
|
51 |
+
" Installing build dependencies ... \u001b[?25ldone\n",
|
52 |
+
"\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n",
|
53 |
+
"\u001b[?25h Preparing metadata (pyproject.toml) ... \u001b[?25ldone\n",
|
54 |
+
"\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from transformers==4.45.0.dev0) (3.15.4)\n",
|
55 |
+
"Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in /usr/local/lib/python3.11/dist-packages (from transformers==4.45.0.dev0) (0.24.6)\n",
|
56 |
+
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.11/dist-packages (from transformers==4.45.0.dev0) (1.26.4)\n",
|
57 |
+
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from transformers==4.45.0.dev0) (24.1)\n",
|
58 |
+
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from transformers==4.45.0.dev0) (6.0.2)\n",
|
59 |
+
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers==4.45.0.dev0) (2024.7.24)\n",
|
60 |
+
"Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from transformers==4.45.0.dev0) (2.32.3)\n",
|
61 |
+
"Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.11/dist-packages (from transformers==4.45.0.dev0) (0.19.1)\n",
|
62 |
+
"Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.11/dist-packages (from transformers==4.45.0.dev0) (0.4.5)\n",
|
63 |
+
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.11/dist-packages (from transformers==4.45.0.dev0) (4.66.5)\n",
|
64 |
+
"Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub<1.0,>=0.23.2->transformers==4.45.0.dev0) (2024.6.1)\n",
|
65 |
+
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub<1.0,>=0.23.2->transformers==4.45.0.dev0) (4.12.2)\n",
|
66 |
+
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->transformers==4.45.0.dev0) (3.3.2)\n",
|
67 |
+
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->transformers==4.45.0.dev0) (3.7)\n",
|
68 |
+
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->transformers==4.45.0.dev0) (2.2.2)\n",
|
69 |
+
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->transformers==4.45.0.dev0) (2024.7.4)\n",
|
70 |
+
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n",
|
71 |
+
"\u001b[0m"
|
72 |
+
]
|
73 |
+
}
|
74 |
+
],
|
75 |
+
"source": [
|
76 |
+
"# Transformers installation\n",
|
77 |
+
"! pip install transformers datasets\n",
|
78 |
+
"# To install from source instead of the last release, comment the command above and uncomment the following one.\n",
|
79 |
+
"! pip install git+https://github.com/huggingface/transformers.git"
|
80 |
+
]
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"cell_type": "code",
|
84 |
+
"execution_count": 2,
|
85 |
+
"metadata": {},
|
86 |
+
"outputs": [
|
87 |
+
{
|
88 |
+
"name": "stdout",
|
89 |
+
"output_type": "stream",
|
90 |
+
"text": [
|
91 |
+
"Requirement already satisfied: accelerate in /usr/local/lib/python3.11/dist-packages (0.34.2)\n",
|
92 |
+
"Requirement already satisfied: numpy<3.0.0,>=1.17 in /usr/local/lib/python3.11/dist-packages (from accelerate) (1.26.4)\n",
|
93 |
+
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from accelerate) (24.1)\n",
|
94 |
+
"Requirement already satisfied: psutil in /usr/local/lib/python3.11/dist-packages (from accelerate) (6.0.0)\n",
|
95 |
+
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.11/dist-packages (from accelerate) (6.0.2)\n",
|
96 |
+
"Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.11/dist-packages (from accelerate) (2.4.0)\n",
|
97 |
+
"Requirement already satisfied: huggingface-hub>=0.21.0 in /usr/local/lib/python3.11/dist-packages (from accelerate) (0.24.6)\n",
|
98 |
+
"Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.11/dist-packages (from accelerate) (0.4.5)\n",
|
99 |
+
"Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.21.0->accelerate) (3.15.4)\n",
|
100 |
+
"Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.21.0->accelerate) (2024.6.1)\n",
|
101 |
+
"Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.21.0->accelerate) (2.32.3)\n",
|
102 |
+
"Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.21.0->accelerate) (4.66.5)\n",
|
103 |
+
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.21.0->accelerate) (4.12.2)\n",
|
104 |
+
"Requirement already satisfied: sympy in /usr/local/lib/python3.11/dist-packages (from torch>=1.10.0->accelerate) (1.13.2)\n",
|
105 |
+
"Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=1.10.0->accelerate) (3.3)\n",
|
106 |
+
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=1.10.0->accelerate) (3.1.4)\n",
|
107 |
+
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.11/dist-packages (from torch>=1.10.0->accelerate) (12.1.105)\n",
|
108 |
+
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.11/dist-packages (from torch>=1.10.0->accelerate) (12.1.105)\n",
|
109 |
+
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.11/dist-packages (from torch>=1.10.0->accelerate) (12.1.105)\n",
|
110 |
+
"Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch>=1.10.0->accelerate) (9.1.0.70)\n",
|
111 |
+
"Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.11/dist-packages (from torch>=1.10.0->accelerate) (12.1.3.1)\n",
|
112 |
+
"Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.11/dist-packages (from torch>=1.10.0->accelerate) (11.0.2.54)\n",
|
113 |
+
"Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.11/dist-packages (from torch>=1.10.0->accelerate) (10.3.2.106)\n",
|
114 |
+
"Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.11/dist-packages (from torch>=1.10.0->accelerate) (11.4.5.107)\n",
|
115 |
+
"Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.11/dist-packages (from torch>=1.10.0->accelerate) (12.1.0.106)\n",
|
116 |
+
"Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in /usr/local/lib/python3.11/dist-packages (from torch>=1.10.0->accelerate) (2.20.5)\n",
|
117 |
+
"Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.11/dist-packages (from torch>=1.10.0->accelerate) (12.1.105)\n",
|
118 |
+
"Requirement already satisfied: triton==3.0.0 in /usr/local/lib/python3.11/dist-packages (from torch>=1.10.0->accelerate) (3.0.0)\n",
|
119 |
+
"Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.11/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.10.0->accelerate) (12.6.20)\n",
|
120 |
+
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.5)\n",
|
121 |
+
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub>=0.21.0->accelerate) (3.3.2)\n",
|
122 |
+
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub>=0.21.0->accelerate) (3.7)\n",
|
123 |
+
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub>=0.21.0->accelerate) (2.2.2)\n",
|
124 |
+
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub>=0.21.0->accelerate) (2024.7.4)\n",
|
125 |
+
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)\n",
|
126 |
+
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n",
|
127 |
+
"\u001b[0mRequirement already satisfied: transformers in /usr/local/lib/python3.11/dist-packages (4.45.0.dev0)\n",
|
128 |
+
"Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from transformers) (3.15.4)\n",
|
129 |
+
"Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.24.6)\n",
|
130 |
+
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.11/dist-packages (from transformers) (1.26.4)\n",
|
131 |
+
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from transformers) (24.1)\n",
|
132 |
+
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from transformers) (6.0.2)\n",
|
133 |
+
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers) (2024.7.24)\n",
|
134 |
+
"Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from transformers) (2.32.3)\n",
|
135 |
+
"Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.19.1)\n",
|
136 |
+
"Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.4.5)\n",
|
137 |
+
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.11/dist-packages (from transformers) (4.66.5)\n",
|
138 |
+
"Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub<1.0,>=0.23.2->transformers) (2024.6.1)\n",
|
139 |
+
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub<1.0,>=0.23.2->transformers) (4.12.2)\n",
|
140 |
+
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (3.3.2)\n",
|
141 |
+
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (3.7)\n",
|
142 |
+
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (2.2.2)\n",
|
143 |
+
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (2024.7.4)\n",
|
144 |
+
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n",
|
145 |
+
"\u001b[0m"
|
146 |
+
]
|
147 |
+
}
|
148 |
+
],
|
149 |
+
"source": [
|
150 |
+
"! pip install -U accelerate\n",
|
151 |
+
"! pip install -U transformers"
|
152 |
+
]
|
153 |
+
},
|
154 |
+
{
|
155 |
+
"cell_type": "code",
|
156 |
+
"execution_count": 3,
|
157 |
+
"metadata": {},
|
158 |
+
"outputs": [],
|
159 |
+
"source": [
|
160 |
+
"# !pip install accelerate"
|
161 |
+
]
|
162 |
+
},
|
163 |
+
{
|
164 |
+
"cell_type": "code",
|
165 |
+
"execution_count": 4,
|
166 |
+
"metadata": {},
|
167 |
+
"outputs": [],
|
168 |
+
"source": [
|
169 |
+
"# !pip install transformers[torch]"
|
170 |
+
]
|
171 |
+
},
|
172 |
+
{
|
173 |
+
"cell_type": "markdown",
|
174 |
+
"metadata": {},
|
175 |
+
"source": [
|
176 |
+
"# Causal language modeling"
|
177 |
+
]
|
178 |
+
},
|
179 |
+
{
|
180 |
+
"cell_type": "markdown",
|
181 |
+
"metadata": {},
|
182 |
+
"source": [
|
183 |
+
"There are two types of language modeling, causal and masked. This guide illustrates causal language modeling.\n",
|
184 |
+
"Causal language models are frequently used for text generation. You can use these models for creative applications like\n",
|
185 |
+
"choosing your own text adventure or an intelligent coding assistant like Copilot or CodeParrot."
|
186 |
+
]
|
187 |
+
},
|
188 |
+
{
|
189 |
+
"cell_type": "code",
|
190 |
+
"execution_count": 5,
|
191 |
+
"metadata": {
|
192 |
+
"cellView": "form",
|
193 |
+
"hide_input": true
|
194 |
+
},
|
195 |
+
"outputs": [],
|
196 |
+
"source": [
|
197 |
+
"# #@title\n",
|
198 |
+
"# from IPython.display import HTML\n",
|
199 |
+
"\n",
|
200 |
+
"# HTML('<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/Vpjb1lu0MDk?rel=0&controls=0&showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>')"
|
201 |
+
]
|
202 |
+
},
|
203 |
+
{
|
204 |
+
"cell_type": "markdown",
|
205 |
+
"metadata": {},
|
206 |
+
"source": [
|
207 |
+
"Causal language modeling predicts the next token in a sequence of tokens, and the model can only attend to tokens on\n",
|
208 |
+
"the left. This means the model cannot see future tokens. GPT-2 is an example of a causal language model.\n",
|
209 |
+
"\n",
|
210 |
+
"This guide will show you how to:\n",
|
211 |
+
"\n",
|
212 |
+
"1. Finetune [DistilGPT2](https://huggingface.co/distilgpt2) on the [r/askscience](https://www.reddit.com/r/askscience/) subset of the [ELI5](https://huggingface.co/datasets/eli5) dataset.\n",
|
213 |
+
"2. Use your finetuned model for inference.\n",
|
214 |
+
"\n",
|
215 |
+
"<Tip>\n",
|
216 |
+
"You can finetune other architectures for causal language modeling following the same steps in this guide.\n",
|
217 |
+
"Choose one of the following architectures:\n",
|
218 |
+
"\n",
|
219 |
+
"<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->\n",
|
220 |
+
"[BART](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/bart), [BERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/bert), [Bert Generation](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/bert-generation), [BigBird](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/big_bird), [BigBird-Pegasus](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/bigbird_pegasus), [BioGpt](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/biogpt), [Blenderbot](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/blenderbot), [BlenderbotSmall](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/blenderbot-small), [BLOOM](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/bloom), [CamemBERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/camembert), [CodeGen](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/codegen), [CPM-Ant](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/cpmant), [CTRL](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/ctrl), [Data2VecText](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/data2vec-text), [ELECTRA](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/electra), [ERNIE](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/ernie), [GIT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/git), [GPT-Sw3](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gpt-sw3), [OpenAI GPT-2](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gpt2), [GPTBigCode](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gpt_bigcode), [GPT Neo](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gpt_neo), [GPT NeoX](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gpt_neox), [GPT NeoX Japanese](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gpt_neox_japanese), [GPT-J](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/gptj), [LLaMA](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/llama), [Marian](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/marian), [mBART](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/mbart), [MEGA](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/mega), [Megatron-BERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/megatron-bert), [MVP](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/mvp), [OpenLlama](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/open-llama), [OpenAI GPT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/openai-gpt), [OPT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/opt), [Pegasus](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/pegasus), [PLBart](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/plbart), [ProphetNet](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/prophetnet), [QDQBert](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/qdqbert), [Reformer](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/reformer), [RemBERT](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/rembert), [RoBERTa](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/roberta), [RoBERTa-PreLayerNorm](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/roberta-prelayernorm), [RoCBert](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/roc_bert), [RoFormer](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/roformer), [RWKV](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/rwkv), [Speech2Text2](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/speech_to_text_2), [Transformer-XL](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/transfo-xl), [TrOCR](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/trocr), [XGLM](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xglm), [XLM](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xlm), [XLM-ProphetNet](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xlm-prophetnet), [XLM-RoBERTa](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xlm-roberta), [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xlm-roberta-xl), [XLNet](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xlnet), [X-MOD](https://huggingface.co/docs/transformers/main/en/tasks/../model_doc/xmod)\n",
|
221 |
+
"\n",
|
222 |
+
"\n",
|
223 |
+
"<!--End of the generated tip-->\n",
|
224 |
+
"\n",
|
225 |
+
"</Tip>\n",
|
226 |
+
"\n",
|
227 |
+
"Before you begin, make sure you have all the necessary libraries installed:\n",
|
228 |
+
"\n",
|
229 |
+
"```bash\n",
|
230 |
+
"pip install transformers datasets evaluate\n",
|
231 |
+
"```\n",
|
232 |
+
"\n",
|
233 |
+
"We encourage you to log in to your Hugging Face account so you can upload and share your model with the community. When prompted, enter your token to log in:"
|
234 |
+
]
|
235 |
+
},
|
236 |
+
{
|
237 |
+
"cell_type": "code",
|
238 |
+
"execution_count": 6,
|
239 |
+
"metadata": {},
|
240 |
+
"outputs": [],
|
241 |
+
"source": [
|
242 |
+
"# from huggingface_hub import notebook_login\n",
|
243 |
+
"\n",
|
244 |
+
"# notebook_login()"
|
245 |
+
]
|
246 |
+
},
|
247 |
+
{
|
248 |
+
"cell_type": "markdown",
|
249 |
+
"metadata": {},
|
250 |
+
"source": [
|
251 |
+
"## Load ELI5 dataset"
|
252 |
+
]
|
253 |
+
},
|
254 |
+
{
|
255 |
+
"cell_type": "markdown",
|
256 |
+
"metadata": {},
|
257 |
+
"source": [
|
258 |
+
"Start by loading a smaller subset of the r/askscience subset of the ELI5 dataset from the 🤗 Datasets library.\n",
|
259 |
+
" This'll give you a chance to experiment and make sure everything works before spending more time training on the full dataset."
|
260 |
+
]
|
261 |
+
},
|
262 |
+
{
|
263 |
+
"cell_type": "code",
|
264 |
+
"execution_count": 7,
|
265 |
+
"metadata": {},
|
266 |
+
"outputs": [],
|
267 |
+
"source": [
|
268 |
+
"# from datasets import load_dataset\n",
|
269 |
+
"\n",
|
270 |
+
"# eli5 = load_dataset(\"eli5\", split=\"train_asks[:5000]\")"
|
271 |
+
]
|
272 |
+
},
|
273 |
+
{
|
274 |
+
"cell_type": "code",
|
275 |
+
"execution_count": 8,
|
276 |
+
"metadata": {},
|
277 |
+
"outputs": [],
|
278 |
+
"source": [
|
279 |
+
"from datasets import load_dataset\n",
|
280 |
+
"# Falcon = load_dataset(\"csv\", data_files=\"FalconData.csv\")\n",
|
281 |
+
"Falcon = load_dataset('csv', data_files={\"train\": 'FalconData_train.csv', \"validation\": 'FalconData_validation.csv'})"
|
282 |
+
]
|
283 |
+
},
|
284 |
+
{
|
285 |
+
"cell_type": "markdown",
|
286 |
+
"metadata": {},
|
287 |
+
"source": [
|
288 |
+
"Split the dataset's `train_asks` split into a train and test set with the [train_test_split](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset.train_test_split) method:"
|
289 |
+
]
|
290 |
+
},
|
291 |
+
{
|
292 |
+
"cell_type": "code",
|
293 |
+
"execution_count": 9,
|
294 |
+
"metadata": {},
|
295 |
+
"outputs": [],
|
296 |
+
"source": [
|
297 |
+
"# Falcon = Falcon.train_test_split(test_size=0.10)"
|
298 |
+
]
|
299 |
+
},
|
300 |
+
{
|
301 |
+
"cell_type": "markdown",
|
302 |
+
"metadata": {},
|
303 |
+
"source": [
|
304 |
+
"Then take a look at an example:"
|
305 |
+
]
|
306 |
+
},
|
307 |
+
{
|
308 |
+
"cell_type": "code",
|
309 |
+
"execution_count": 10,
|
310 |
+
"metadata": {},
|
311 |
+
"outputs": [
|
312 |
+
{
|
313 |
+
"data": {
|
314 |
+
"text/plain": [
|
315 |
+
"{'Text': 'Once the kind of organization is decided, right now is the time for the purpose of the huge talk with the parents. Additionally, you will have to credit your company while using the board. Right now there a few techniques which usually you can get started on the cellular phone restoration organization.\\nBecause you develop your organization, you can want to realize how to raise your skill sets and tactics. After formulating your firm notion and organizing the funds, the next idea to perform is to check out the organization. In addition , if occur to be certainly not in the automobile business yet work via the internet with consumers via the net and email, after that some of your suggestions you are going to see are certain to get the work performed to get you too.\\nWhat you will requirement for your company depends upon a great deal of factors, therefore is actually ideal to pay a visit to the Nevada Department of Insurance internet site to get detailed info. Once you wish to start up your unique enterprise, then simply it is important to apply entitlements of your have firm. The few males and ladies in little business want to know more and carry out more with a great deal fewer. For illustration, the ordinary organization runs the data centre 10 hours every day. Even more businesses experience began to take notice of the huge benefits of giving birth to a business program analyst in staff. As you take your small business to the world-wide market segments, it is going to become important to think about a lot a large number of things to ascertain the organization efficiently. Decide what kind of business being you desire to allocate to your panorama business.\\nRecuperate this will depend after the sort of assistance you give. Right now there are a lot of different varieties of Web service yet I will list the most typical types out there. Found in addition, you will need high-speed on the net service to mail and acquire job data files to your consumers.\\nMany people today are unsuccessful in organization given that they make avoidable mistakes! A put together organization is a great likelihood to communicate the fine art just the way that you like it. You can actually without difficulty control the company if it’s legitimate. While not efficient communication, the businesses could not discover the strategies to create the business and website link while using the all over the world clients and companions. A great excellent car shop tools business will make sure you experience all owners and parts manuals alongside one another with service plan directives for all of you heavy machines you purchase or perhaps let out.\\nIn case you blowing wind up going, where you began your company won’t change! It’s actually now possible to advertise your business to anybody anywhere for the purpose of practically no selling price. So you may absolutely cost-free to pay attention to different important things that matter to you such as growing your business and a lot more. If the service is mostly an operation product, you should supply a replicate within the operation contract. Websites like craigslist and or perhaps Tradelit That is certainly, in the event people are likely to build a company. Presently a days and nights Many businesses are unaware of the significance of SEO in improving the internet occurrence. If you expect to have carrying out a fee-for-service tutoring organization, then you might preference to think about signing up your company considering the state.\\nKind of organization Primarily based upon at the sort of business, you need to do business with a variety of organizations. Not only a single company are able to take advantage of a similar well-known. If an organization can better figure out their normal user’s requires, it will develop into a excellent less complicated to guarantee that every consumer has a confident knowledge in handling your business with regards to a entire. Even firms want a huge data stats official certifications prior to taking the help of a person. As a result, all of them over the world are inclined to take full advantage of technology, on particular, cordless devices and public hotspots. The organization should also be capable of offering any kind of teaching vital to buy and sell each machine safely. Daily, an increasing number of businesses are putting up or perhaps establishing an electronic business. For more info read right here whatsbakingsd.com .'}"
|
316 |
+
]
|
317 |
+
},
|
318 |
+
"execution_count": 10,
|
319 |
+
"metadata": {},
|
320 |
+
"output_type": "execute_result"
|
321 |
+
}
|
322 |
+
],
|
323 |
+
"source": [
|
324 |
+
"Falcon['train'][0]"
|
325 |
+
]
|
326 |
+
},
|
327 |
+
{
|
328 |
+
"cell_type": "code",
|
329 |
+
"execution_count": 11,
|
330 |
+
"metadata": {},
|
331 |
+
"outputs": [
|
332 |
+
{
|
333 |
+
"data": {
|
334 |
+
"text/plain": [
|
335 |
+
"{'Text': ', John Morris (19282003), historian\\nOxford Biography Index Number 101089999 [what is this?] Primary authority: Oxford DNB\\nColin Lucas, Roberts, John Morris (19282003), first published\\nJan 2007; online edn, Oct 2009, 1683 words, with portrait illustration\\n> View John Roberts complete biography [Oxford DNB subscription required; no subscription?]\\n> View John Roberts complete biography\\n[WWW subscription required; no subscription?]'}"
|
336 |
+
]
|
337 |
+
},
|
338 |
+
"execution_count": 11,
|
339 |
+
"metadata": {},
|
340 |
+
"output_type": "execute_result"
|
341 |
+
}
|
342 |
+
],
|
343 |
+
"source": [
|
344 |
+
"Falcon['validation'][0]"
|
345 |
+
]
|
346 |
+
},
|
347 |
+
{
|
348 |
+
"cell_type": "markdown",
|
349 |
+
"metadata": {},
|
350 |
+
"source": [
|
351 |
+
"While this may look like a lot, you're only really interested in the `text` field. What's cool about language modeling\n",
|
352 |
+
"tasks is you don't need labels (also known as an unsupervised task) because the next word *is* the label."
|
353 |
+
]
|
354 |
+
},
|
355 |
+
{
|
356 |
+
"cell_type": "markdown",
|
357 |
+
"metadata": {},
|
358 |
+
"source": [
|
359 |
+
"## Preprocess"
|
360 |
+
]
|
361 |
+
},
|
362 |
+
{
|
363 |
+
"cell_type": "code",
|
364 |
+
"execution_count": 12,
|
365 |
+
"metadata": {
|
366 |
+
"cellView": "form",
|
367 |
+
"hide_input": true
|
368 |
+
},
|
369 |
+
"outputs": [],
|
370 |
+
"source": [
|
371 |
+
"# #@title\n",
|
372 |
+
"# from IPython.display import HTML\n",
|
373 |
+
"\n",
|
374 |
+
"# HTML('<iframe width=\"560\" height=\"315\" src=\"https://www.youtube.com/embed/ma1TrR7gE7I?rel=0&controls=0&showinfo=0\" frameborder=\"0\" allowfullscreen></iframe>')"
|
375 |
+
]
|
376 |
+
},
|
377 |
+
{
|
378 |
+
"cell_type": "markdown",
|
379 |
+
"metadata": {},
|
380 |
+
"source": [
|
381 |
+
"The next step is to load a DistilGPT2 tokenizer to process the `text` subfield:"
|
382 |
+
]
|
383 |
+
},
|
384 |
+
{
|
385 |
+
"cell_type": "code",
|
386 |
+
"execution_count": 28,
|
387 |
+
"metadata": {},
|
388 |
+
"outputs": [
|
389 |
+
{
|
390 |
+
"name": "stderr",
|
391 |
+
"output_type": "stream",
|
392 |
+
"text": [
|
393 |
+
"/usr/local/lib/python3.11/dist-packages/transformers/tokenization_utils_base.py:1614: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be deprecated in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
|
394 |
+
" warnings.warn(\n"
|
395 |
+
]
|
396 |
+
}
|
397 |
+
],
|
398 |
+
"source": [
|
399 |
+
"from transformers import AutoTokenizer, GPT2TokenizerFast\n",
|
400 |
+
"\n",
|
401 |
+
"tokenizer = AutoTokenizer.from_pretrained(\"distilgpt2\")\n",
|
402 |
+
"\n",
|
403 |
+
"\n",
|
404 |
+
"# tokenizer = GPT2TokenizerFast.from_pretrained(\"Xenova/gpt-4\")#, cache_dir=cache_dir)\n",
|
405 |
+
"tokenizer.pad_token = tokenizer.eos_token"
|
406 |
+
]
|
407 |
+
},
|
408 |
+
{
|
409 |
+
"cell_type": "markdown",
|
410 |
+
"metadata": {},
|
411 |
+
"source": [
|
412 |
+
"You'll notice from the example above, the `text` field is actually nested inside `answers`. This means you'll need to\n",
|
413 |
+
"extract the `text` subfield from its nested structure with the [`flatten`](https://huggingface.co/docs/datasets/process.html#flatten) method:"
|
414 |
+
]
|
415 |
+
},
|
416 |
+
{
|
417 |
+
"cell_type": "code",
|
418 |
+
"execution_count": 14,
|
419 |
+
"metadata": {},
|
420 |
+
"outputs": [
|
421 |
+
{
|
422 |
+
"data": {
|
423 |
+
"text/plain": [
|
424 |
+
"{'Text': 'Once the kind of organization is decided, right now is the time for the purpose of the huge talk with the parents. Additionally, you will have to credit your company while using the board. Right now there a few techniques which usually you can get started on the cellular phone restoration organization.\\nBecause you develop your organization, you can want to realize how to raise your skill sets and tactics. After formulating your firm notion and organizing the funds, the next idea to perform is to check out the organization. In addition , if occur to be certainly not in the automobile business yet work via the internet with consumers via the net and email, after that some of your suggestions you are going to see are certain to get the work performed to get you too.\\nWhat you will requirement for your company depends upon a great deal of factors, therefore is actually ideal to pay a visit to the Nevada Department of Insurance internet site to get detailed info. Once you wish to start up your unique enterprise, then simply it is important to apply entitlements of your have firm. The few males and ladies in little business want to know more and carry out more with a great deal fewer. For illustration, the ordinary organization runs the data centre 10 hours every day. Even more businesses experience began to take notice of the huge benefits of giving birth to a business program analyst in staff. As you take your small business to the world-wide market segments, it is going to become important to think about a lot a large number of things to ascertain the organization efficiently. Decide what kind of business being you desire to allocate to your panorama business.\\nRecuperate this will depend after the sort of assistance you give. Right now there are a lot of different varieties of Web service yet I will list the most typical types out there. Found in addition, you will need high-speed on the net service to mail and acquire job data files to your consumers.\\nMany people today are unsuccessful in organization given that they make avoidable mistakes! A put together organization is a great likelihood to communicate the fine art just the way that you like it. You can actually without difficulty control the company if it’s legitimate. While not efficient communication, the businesses could not discover the strategies to create the business and website link while using the all over the world clients and companions. A great excellent car shop tools business will make sure you experience all owners and parts manuals alongside one another with service plan directives for all of you heavy machines you purchase or perhaps let out.\\nIn case you blowing wind up going, where you began your company won’t change! It’s actually now possible to advertise your business to anybody anywhere for the purpose of practically no selling price. So you may absolutely cost-free to pay attention to different important things that matter to you such as growing your business and a lot more. If the service is mostly an operation product, you should supply a replicate within the operation contract. Websites like craigslist and or perhaps Tradelit That is certainly, in the event people are likely to build a company. Presently a days and nights Many businesses are unaware of the significance of SEO in improving the internet occurrence. If you expect to have carrying out a fee-for-service tutoring organization, then you might preference to think about signing up your company considering the state.\\nKind of organization Primarily based upon at the sort of business, you need to do business with a variety of organizations. Not only a single company are able to take advantage of a similar well-known. If an organization can better figure out their normal user’s requires, it will develop into a excellent less complicated to guarantee that every consumer has a confident knowledge in handling your business with regards to a entire. Even firms want a huge data stats official certifications prior to taking the help of a person. As a result, all of them over the world are inclined to take full advantage of technology, on particular, cordless devices and public hotspots. The organization should also be capable of offering any kind of teaching vital to buy and sell each machine safely. Daily, an increasing number of businesses are putting up or perhaps establishing an electronic business. For more info read right here whatsbakingsd.com .'}"
|
425 |
+
]
|
426 |
+
},
|
427 |
+
"execution_count": 14,
|
428 |
+
"metadata": {},
|
429 |
+
"output_type": "execute_result"
|
430 |
+
}
|
431 |
+
],
|
432 |
+
"source": [
|
433 |
+
"Falcon = Falcon.flatten()\n",
|
434 |
+
"Falcon[\"train\"][0]"
|
435 |
+
]
|
436 |
+
},
|
437 |
+
{
|
438 |
+
"cell_type": "markdown",
|
439 |
+
"metadata": {},
|
440 |
+
"source": [
|
441 |
+
"Each subfield is now a separate column as indicated by the `answers` prefix, and the `text` field is a list now. Instead\n",
|
442 |
+
"of tokenizing each sentence separately, convert the list to a string so you can jointly tokenize them.\n",
|
443 |
+
"\n",
|
444 |
+
"Here is a first preprocessing function to join the list of strings for each example and tokenize the result:"
|
445 |
+
]
|
446 |
+
},
|
447 |
+
{
|
448 |
+
"cell_type": "code",
|
449 |
+
"execution_count": 15,
|
450 |
+
"metadata": {},
|
451 |
+
"outputs": [],
|
452 |
+
"source": [
|
453 |
+
"def preprocess_function(examples):\n",
|
454 |
+
" return tokenizer([\" \".join(x) for x in examples[\"Text\"]])"
|
455 |
+
]
|
456 |
+
},
|
457 |
+
{
|
458 |
+
"cell_type": "markdown",
|
459 |
+
"metadata": {},
|
460 |
+
"source": [
|
461 |
+
"To apply this preprocessing function over the entire dataset, use the 🤗 Datasets [map](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset.map) method. You can speed up the `map` function by setting `batched=True` to process multiple elements of the dataset at once, and increasing the number of processes with `num_proc`. Remove any columns you don't need:"
|
462 |
+
]
|
463 |
+
},
|
464 |
+
{
|
465 |
+
"cell_type": "code",
|
466 |
+
"execution_count": 16,
|
467 |
+
"metadata": {},
|
468 |
+
"outputs": [],
|
469 |
+
"source": [
|
470 |
+
"tokenized_Falcon = Falcon.map(\n",
|
471 |
+
" preprocess_function,\n",
|
472 |
+
" batched=True,\n",
|
473 |
+
" num_proc=4,\n",
|
474 |
+
" remove_columns=Falcon[\"train\"].column_names,\n",
|
475 |
+
")"
|
476 |
+
]
|
477 |
+
},
|
478 |
+
{
|
479 |
+
"cell_type": "markdown",
|
480 |
+
"metadata": {},
|
481 |
+
"source": [
|
482 |
+
"This dataset contains the token sequences, but some of these are longer than the maximum input length for the model.\n",
|
483 |
+
"\n",
|
484 |
+
"You can now use a second preprocessing function to\n",
|
485 |
+
"- concatenate all the sequences\n",
|
486 |
+
"- split the concatenated sequences into shorter chunks defined by `block_size`, which should be both shorter than the maximum input length and short enough for your GPU RAM."
|
487 |
+
]
|
488 |
+
},
|
489 |
+
{
|
490 |
+
"cell_type": "code",
|
491 |
+
"execution_count": 17,
|
492 |
+
"metadata": {},
|
493 |
+
"outputs": [],
|
494 |
+
"source": [
|
495 |
+
"block_size = 1048\n",
|
496 |
+
"\n",
|
497 |
+
"\n",
|
498 |
+
"def group_texts(examples):\n",
|
499 |
+
" # Concatenate all texts.\n",
|
500 |
+
" concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}\n",
|
501 |
+
" total_length = len(concatenated_examples[list(examples.keys())[0]])\n",
|
502 |
+
" # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can\n",
|
503 |
+
" # customize this part to your needs.\n",
|
504 |
+
" if total_length >= block_size:\n",
|
505 |
+
" total_length = (total_length // block_size) * block_size\n",
|
506 |
+
" # Split by chunks of block_size.\n",
|
507 |
+
" result = {\n",
|
508 |
+
" k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n",
|
509 |
+
" for k, t in concatenated_examples.items()\n",
|
510 |
+
" }\n",
|
511 |
+
" result[\"labels\"] = result[\"input_ids\"].copy()\n",
|
512 |
+
" return result"
|
513 |
+
]
|
514 |
+
},
|
515 |
+
{
|
516 |
+
"cell_type": "markdown",
|
517 |
+
"metadata": {},
|
518 |
+
"source": [
|
519 |
+
"Apply the `group_texts` function over the entire dataset:"
|
520 |
+
]
|
521 |
+
},
|
522 |
+
{
|
523 |
+
"cell_type": "code",
|
524 |
+
"execution_count": 30,
|
525 |
+
"metadata": {},
|
526 |
+
"outputs": [],
|
527 |
+
"source": [
|
528 |
+
"lm_dataset = tokenized_Falcon.map(group_texts, batched=True, num_proc=4)"
|
529 |
+
]
|
530 |
+
},
|
531 |
+
{
|
532 |
+
"cell_type": "markdown",
|
533 |
+
"metadata": {},
|
534 |
+
"source": [
|
535 |
+
"Now create a batch of examples using [DataCollatorForLanguageModeling](https://huggingface.co/docs/transformers/main/en/main_classes/data_collator#transformers.DataCollatorForLanguageModeling). It's more efficient to *dynamically pad* the\n",
|
536 |
+
"sentences to the longest length in a batch during collation, instead of padding the whole dataset to the maximum length.\n",
|
537 |
+
"\n",
|
538 |
+
"Use the end-of-sequence token as the padding token and set `mlm=False`. This will use the inputs as labels shifted to the right by one element:"
|
539 |
+
]
|
540 |
+
},
|
541 |
+
{
|
542 |
+
"cell_type": "code",
|
543 |
+
"execution_count": 29,
|
544 |
+
"metadata": {},
|
545 |
+
"outputs": [],
|
546 |
+
"source": [
|
547 |
+
"from transformers import DataCollatorForLanguageModeling\n",
|
548 |
+
"\n",
|
549 |
+
"tokenizer.pad_token = tokenizer.eos_token\n",
|
550 |
+
"data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)"
|
551 |
+
]
|
552 |
+
},
|
553 |
+
{
|
554 |
+
"cell_type": "markdown",
|
555 |
+
"metadata": {},
|
556 |
+
"source": [
|
557 |
+
"## Train"
|
558 |
+
]
|
559 |
+
},
|
560 |
+
{
|
561 |
+
"cell_type": "markdown",
|
562 |
+
"metadata": {},
|
563 |
+
"source": [
|
564 |
+
"<Tip>\n",
|
565 |
+
"\n",
|
566 |
+
"If you aren't familiar with finetuning a model with the [Trainer](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer), take a look at the [basic tutorial](https://huggingface.co/docs/transformers/main/en/tasks/../training#train-with-pytorch-trainer)!\n",
|
567 |
+
"\n",
|
568 |
+
"</Tip>\n",
|
569 |
+
"\n",
|
570 |
+
"You're ready to start training your model now! Load DistilGPT2 with [AutoModelForCausalLM](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForCausalLM):"
|
571 |
+
]
|
572 |
+
},
|
573 |
+
{
|
574 |
+
"cell_type": "code",
|
575 |
+
"execution_count": 20,
|
576 |
+
"metadata": {},
|
577 |
+
"outputs": [],
|
578 |
+
"source": [
|
579 |
+
"from transformers import AutoModelForCausalLM, TrainingArguments, Trainer\n",
|
580 |
+
"import torch\n",
|
581 |
+
"model = AutoModelForCausalLM.from_pretrained(\"rwh/tinytoo\", torch_dtype=torch.bfloat16) "
|
582 |
+
]
|
583 |
+
},
|
584 |
+
{
|
585 |
+
"cell_type": "markdown",
|
586 |
+
"metadata": {},
|
587 |
+
"source": [
|
588 |
+
"At this point, only three steps remain:\n",
|
589 |
+
"\n",
|
590 |
+
"1. Define your training hyperparameters in [TrainingArguments](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments). The only required parameter is `output_dir` which specifies where to save your model. You'll push this model to the Hub by setting `push_to_hub=True` (you need to be signed in to Hugging Face to upload your model).\n",
|
591 |
+
"2. Pass the training arguments to [Trainer](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer) along with the model, datasets, and data collator.\n",
|
592 |
+
"3. Call [train()](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer.train) to finetune your model."
|
593 |
+
]
|
594 |
+
},
|
595 |
+
{
|
596 |
+
"cell_type": "code",
|
597 |
+
"execution_count": 40,
|
598 |
+
"metadata": {},
|
599 |
+
"outputs": [],
|
600 |
+
"source": [
|
601 |
+
"import torch\n",
|
602 |
+
"torch.cuda.empty_cache()"
|
603 |
+
]
|
604 |
+
},
|
605 |
+
{
|
606 |
+
"cell_type": "code",
|
607 |
+
"execution_count": 41,
|
608 |
+
"metadata": {},
|
609 |
+
"outputs": [],
|
610 |
+
"source": [
|
611 |
+
"import torch\n",
|
612 |
+
"import gc\n",
|
613 |
+
"\n",
|
614 |
+
"# del tensor_name # Delete the tensor\n",
|
615 |
+
"gc.collect() # Collect garbage\n",
|
616 |
+
"torch.cuda.empty_cache() # Clear cache"
|
617 |
+
]
|
618 |
+
},
|
619 |
+
{
|
620 |
+
"cell_type": "code",
|
621 |
+
"execution_count": 44,
|
622 |
+
"metadata": {},
|
623 |
+
"outputs": [],
|
624 |
+
"source": [
|
625 |
+
"torch.cuda.empty_cache()"
|
626 |
+
]
|
627 |
+
},
|
628 |
+
{
|
629 |
+
"cell_type": "code",
|
630 |
+
"execution_count": 45,
|
631 |
+
"metadata": {},
|
632 |
+
"outputs": [
|
633 |
+
{
|
634 |
+
"data": {
|
635 |
+
"text/plain": [
|
636 |
+
"<torch.autograd.grad_mode.no_grad at 0x7f0a24519350>"
|
637 |
+
]
|
638 |
+
},
|
639 |
+
"execution_count": 45,
|
640 |
+
"metadata": {},
|
641 |
+
"output_type": "execute_result"
|
642 |
+
}
|
643 |
+
],
|
644 |
+
"source": [
|
645 |
+
"torch.no_grad()"
|
646 |
+
]
|
647 |
+
},
|
648 |
+
{
|
649 |
+
"cell_type": "code",
|
650 |
+
"execution_count": 25,
|
651 |
+
"metadata": {},
|
652 |
+
"outputs": [
|
653 |
+
{
|
654 |
+
"data": {
|
655 |
+
"text/plain": [
|
656 |
+
"LlamaForCausalLM(\n",
|
657 |
+
" (model): LlamaModel(\n",
|
658 |
+
" (embed_tokens): Embedding(50257, 1408)\n",
|
659 |
+
" (layers): ModuleList(\n",
|
660 |
+
" (0-23): 24 x LlamaDecoderLayer(\n",
|
661 |
+
" (self_attn): LlamaSdpaAttention(\n",
|
662 |
+
" (q_proj): Linear(in_features=1408, out_features=1408, bias=False)\n",
|
663 |
+
" (k_proj): Linear(in_features=1408, out_features=1408, bias=False)\n",
|
664 |
+
" (v_proj): Linear(in_features=1408, out_features=1408, bias=False)\n",
|
665 |
+
" (o_proj): Linear(in_features=1408, out_features=1408, bias=False)\n",
|
666 |
+
" (rotary_emb): LlamaRotaryEmbedding()\n",
|
667 |
+
" )\n",
|
668 |
+
" (mlp): LlamaMLP(\n",
|
669 |
+
" (gate_proj): Linear(in_features=1408, out_features=4340, bias=False)\n",
|
670 |
+
" (up_proj): Linear(in_features=1408, out_features=4340, bias=False)\n",
|
671 |
+
" (down_proj): Linear(in_features=4340, out_features=1408, bias=False)\n",
|
672 |
+
" (act_fn): SiLU()\n",
|
673 |
+
" )\n",
|
674 |
+
" (input_layernorm): LlamaRMSNorm((1408,), eps=1e-05)\n",
|
675 |
+
" (post_attention_layernorm): LlamaRMSNorm((1408,), eps=1e-05)\n",
|
676 |
+
" )\n",
|
677 |
+
" )\n",
|
678 |
+
" (norm): LlamaRMSNorm((1408,), eps=1e-05)\n",
|
679 |
+
" (rotary_emb): LlamaRotaryEmbedding()\n",
|
680 |
+
" )\n",
|
681 |
+
" (lm_head): Linear(in_features=1408, out_features=50257, bias=False)\n",
|
682 |
+
")"
|
683 |
+
]
|
684 |
+
},
|
685 |
+
"execution_count": 25,
|
686 |
+
"metadata": {},
|
687 |
+
"output_type": "execute_result"
|
688 |
+
}
|
689 |
+
],
|
690 |
+
"source": [
|
691 |
+
"model.to('cuda')"
|
692 |
+
]
|
693 |
+
},
|
694 |
+
{
|
695 |
+
"cell_type": "code",
|
696 |
+
"execution_count": 31,
|
697 |
+
"metadata": {},
|
698 |
+
"outputs": [
|
699 |
+
{
|
700 |
+
"name": "stderr",
|
701 |
+
"output_type": "stream",
|
702 |
+
"text": [
|
703 |
+
"/usr/local/lib/python3.11/dist-packages/transformers/training_args.py:1541: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
|
704 |
+
" warnings.warn(\n"
|
705 |
+
]
|
706 |
+
}
|
707 |
+
],
|
708 |
+
"source": [
|
709 |
+
"training_args = TrainingArguments(\n",
|
710 |
+
" output_dir=\"Fine-Tuned-S9\",\n",
|
711 |
+
" bf16=True,\n",
|
712 |
+
" # evaluation_strategy=\"epoch\",\n",
|
713 |
+
" evaluation_strategy=\"steps\",\n",
|
714 |
+
" learning_rate=2e-5,\n",
|
715 |
+
" weight_decay=0.01,\n",
|
716 |
+
" num_train_epochs=1,\n",
|
717 |
+
" per_device_train_batch_size=2,\n",
|
718 |
+
" per_device_eval_batch_size=2,\n",
|
719 |
+
" # lr_scheduler_type = 'cosine',\n",
|
720 |
+
" push_to_hub=False,\n",
|
721 |
+
" save_total_limit = 2,\n",
|
722 |
+
" # save_strategy = “no”\n",
|
723 |
+
" load_best_model_at_end=False\n",
|
724 |
+
")\n",
|
725 |
+
"\n",
|
726 |
+
"trainer = Trainer(\n",
|
727 |
+
" model=model,\n",
|
728 |
+
" args=training_args,\n",
|
729 |
+
" train_dataset=lm_dataset[\"train\"],\n",
|
730 |
+
" eval_dataset=lm_dataset[\"validation\"],\n",
|
731 |
+
" # eval_dataset=lm_dataset[\"test\"],\n",
|
732 |
+
" data_collator=data_collator,\n",
|
733 |
+
")\n",
|
734 |
+
"\n",
|
735 |
+
"# trainer.train()"
|
736 |
+
]
|
737 |
+
},
|
738 |
+
{
|
739 |
+
"cell_type": "code",
|
740 |
+
"execution_count": null,
|
741 |
+
"metadata": {},
|
742 |
+
"outputs": [],
|
743 |
+
"source": [
|
744 |
+
"trainer.train()"
|
745 |
+
]
|
746 |
+
},
|
747 |
+
{
|
748 |
+
"cell_type": "markdown",
|
749 |
+
"metadata": {},
|
750 |
+
"source": [
|
751 |
+
"Once training is completed, use the [evaluate()](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer.evaluate) method to evaluate your model and get its perplexity:"
|
752 |
+
]
|
753 |
+
},
|
754 |
+
{
|
755 |
+
"cell_type": "code",
|
756 |
+
"execution_count": null,
|
757 |
+
"metadata": {},
|
758 |
+
"outputs": [],
|
759 |
+
"source": [
|
760 |
+
"import math\n",
|
761 |
+
"\n",
|
762 |
+
"eval_results = trainer.evaluate()\n",
|
763 |
+
"print(f\"Perplexity: {math.exp(eval_results['eval_loss']):.2f}\")"
|
764 |
+
]
|
765 |
+
},
|
766 |
+
{
|
767 |
+
"cell_type": "markdown",
|
768 |
+
"metadata": {},
|
769 |
+
"source": [
|
770 |
+
"Then share your model to the Hub with the [push_to_hub()](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Trainer.push_to_hub) method so everyone can use your model:"
|
771 |
+
]
|
772 |
+
},
|
773 |
+
{
|
774 |
+
"cell_type": "code",
|
775 |
+
"execution_count": null,
|
776 |
+
"metadata": {},
|
777 |
+
"outputs": [],
|
778 |
+
"source": [
|
779 |
+
"# trainer.push_to_hub()"
|
780 |
+
]
|
781 |
+
},
|
782 |
+
{
|
783 |
+
"cell_type": "markdown",
|
784 |
+
"metadata": {},
|
785 |
+
"source": [
|
786 |
+
"<Tip>\n",
|
787 |
+
"\n",
|
788 |
+
"For a more in-depth example of how to finetune a model for causal language modeling, take a look at the corresponding\n",
|
789 |
+
"[PyTorch notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling.ipynb)\n",
|
790 |
+
"or [TensorFlow notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/language_modeling-tf.ipynb).\n",
|
791 |
+
"\n",
|
792 |
+
"</Tip>"
|
793 |
+
]
|
794 |
+
},
|
795 |
+
{
|
796 |
+
"cell_type": "markdown",
|
797 |
+
"metadata": {},
|
798 |
+
"source": [
|
799 |
+
"## Inference"
|
800 |
+
]
|
801 |
+
},
|
802 |
+
{
|
803 |
+
"cell_type": "markdown",
|
804 |
+
"metadata": {},
|
805 |
+
"source": [
|
806 |
+
"Great, now that you've finetuned a model, you can use it for inference!\n",
|
807 |
+
"\n",
|
808 |
+
"Come up with a prompt you'd like to generate text from:"
|
809 |
+
]
|
810 |
+
},
|
811 |
+
{
|
812 |
+
"cell_type": "code",
|
813 |
+
"execution_count": null,
|
814 |
+
"metadata": {},
|
815 |
+
"outputs": [],
|
816 |
+
"source": [
|
817 |
+
"# prompt = \"Somatic hypermutation allows the immune system to\""
|
818 |
+
]
|
819 |
+
},
|
820 |
+
{
|
821 |
+
"cell_type": "markdown",
|
822 |
+
"metadata": {},
|
823 |
+
"source": [
|
824 |
+
"The simplest way to try out your finetuned model for inference is to use it in a [pipeline()](https://huggingface.co/docs/transformers/main/en/main_classes/pipelines#transformers.pipeline). Instantiate a `pipeline` for text generation with your model, and pass your text to it:"
|
825 |
+
]
|
826 |
+
},
|
827 |
+
{
|
828 |
+
"cell_type": "code",
|
829 |
+
"execution_count": null,
|
830 |
+
"metadata": {},
|
831 |
+
"outputs": [],
|
832 |
+
"source": [
|
833 |
+
"# from transformers import pipeline\n",
|
834 |
+
"# # checkpoint-4000\n",
|
835 |
+
"# generator = pipeline(\"text-generation\", model=\"Fine-Tuned-S9/checkpoint-4000\")\n",
|
836 |
+
"# generator(prompt)"
|
837 |
+
]
|
838 |
+
},
|
839 |
+
{
|
840 |
+
"cell_type": "markdown",
|
841 |
+
"metadata": {},
|
842 |
+
"source": [
|
843 |
+
"Tokenize the text and return the `input_ids` as PyTorch tensors:"
|
844 |
+
]
|
845 |
+
},
|
846 |
+
{
|
847 |
+
"cell_type": "code",
|
848 |
+
"execution_count": null,
|
849 |
+
"metadata": {},
|
850 |
+
"outputs": [],
|
851 |
+
"source": [
|
852 |
+
"# from transformers import AutoTokenizer\n",
|
853 |
+
"\n",
|
854 |
+
"# tokenizer = AutoTokenizer.from_pretrained(\"Xenova/gpt-4\")\n",
|
855 |
+
"# inputs = tokenizer(prompt, return_tensors=\"pt\").input_ids"
|
856 |
+
]
|
857 |
+
},
|
858 |
+
{
|
859 |
+
"cell_type": "markdown",
|
860 |
+
"metadata": {},
|
861 |
+
"source": [
|
862 |
+
"Use the [generate()](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate) method to generate text.\n",
|
863 |
+
"For more details about the different text generation strategies and parameters for controlling generation, check out the [Text generation strategies](https://huggingface.co/docs/transformers/main/en/tasks/../generation_strategies) page."
|
864 |
+
]
|
865 |
+
},
|
866 |
+
{
|
867 |
+
"cell_type": "code",
|
868 |
+
"execution_count": null,
|
869 |
+
"metadata": {},
|
870 |
+
"outputs": [],
|
871 |
+
"source": [
|
872 |
+
"# from transformers import AutoModelForCausalLM\n",
|
873 |
+
"\n",
|
874 |
+
"# model = AutoModelForCausalLM.from_pretrained(\"deepnet/SN6-BestLlama\")\n",
|
875 |
+
"# outputs = model.generate(inputs, max_new_tokens=100, do_sample=True, top_k=50, top_p=0.95)"
|
876 |
+
]
|
877 |
+
},
|
878 |
+
{
|
879 |
+
"cell_type": "markdown",
|
880 |
+
"metadata": {},
|
881 |
+
"source": [
|
882 |
+
"Decode the generated token ids back into text:"
|
883 |
+
]
|
884 |
+
},
|
885 |
+
{
|
886 |
+
"cell_type": "code",
|
887 |
+
"execution_count": null,
|
888 |
+
"metadata": {},
|
889 |
+
"outputs": [],
|
890 |
+
"source": [
|
891 |
+
"# tokenizer.batch_decode(outputs, skip_special_tokens=True)"
|
892 |
+
]
|
893 |
+
},
|
894 |
+
{
|
895 |
+
"cell_type": "code",
|
896 |
+
"execution_count": null,
|
897 |
+
"metadata": {},
|
898 |
+
"outputs": [],
|
899 |
+
"source": [
|
900 |
+
"# tokenizer.batch_decode(outputs, skip_special_tokens=True)"
|
901 |
+
]
|
902 |
+
},
|
903 |
+
{
|
904 |
+
"cell_type": "code",
|
905 |
+
"execution_count": null,
|
906 |
+
"metadata": {},
|
907 |
+
"outputs": [],
|
908 |
+
"source": []
|
909 |
+
}
|
910 |
+
],
|
911 |
+
"metadata": {
|
912 |
+
"kernelspec": {
|
913 |
+
"display_name": "Python 3 (ipykernel)",
|
914 |
+
"language": "python",
|
915 |
+
"name": "python3"
|
916 |
+
},
|
917 |
+
"language_info": {
|
918 |
+
"codemirror_mode": {
|
919 |
+
"name": "ipython",
|
920 |
+
"version": 3
|
921 |
+
},
|
922 |
+
"file_extension": ".py",
|
923 |
+
"mimetype": "text/x-python",
|
924 |
+
"name": "python",
|
925 |
+
"nbconvert_exporter": "python",
|
926 |
+
"pygments_lexer": "ipython3",
|
927 |
+
"version": "3.11.9"
|
928 |
+
}
|
929 |
+
},
|
930 |
+
"nbformat": 4,
|
931 |
+
"nbformat_minor": 4
|
932 |
+
}
|
language_modeling.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
|
4 |
+
# Transformers installation
|
5 |
+
# ! pip install transformers datasets
|
6 |
+
# To install from source instead of the last release, comment the command above and uncomment the following one.
|
7 |
+
# ! pip install git+https://github.com/huggingface/transformers.git
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
# #@title
|
12 |
+
# from IPython.display import HTML
|
13 |
+
|
14 |
+
# HTML('<iframe width="560" height="315" src="https://www.youtube.com/embed/Vpjb1lu0MDk?rel=0&controls=0&showinfo=0" frameborder="0" allowfullscreen></iframe>')
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
+
# from huggingface_hub import notebook_login
|
19 |
+
|
20 |
+
# notebook_login()
|
21 |
+
|
22 |
+
|
23 |
+
|
24 |
+
# from datasets import load_dataset
|
25 |
+
|
26 |
+
# eli5 = load_dataset("eli5", split="train_asks[:5000]")
|
27 |
+
|
28 |
+
from datasets import load_dataset
|
29 |
+
# Falcon = load_dataset("csv", data_files="FalconData.csv")
|
30 |
+
Falcon = load_dataset('csv', data_files={"train": 'FalconData_train2.csv', "validation": 'FalconData_validation2.csv'})
|
31 |
+
|
32 |
+
print('Dataset Loaded!')
|
33 |
+
|
34 |
+
# Falcon = Falcon.train_test_split(test_size=0.10)
|
35 |
+
|
36 |
+
"""Then take a look at an example:"""
|
37 |
+
|
38 |
+
Falcon['train'][0]
|
39 |
+
|
40 |
+
Falcon['validation'][0]
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
# #@title
|
45 |
+
# from IPython.display import HTML
|
46 |
+
|
47 |
+
# HTML('<iframe width="560" height="315" src="https://www.youtube.com/embed/ma1TrR7gE7I?rel=0&controls=0&showinfo=0" frameborder="0" allowfullscreen></iframe>')
|
48 |
+
|
49 |
+
"""The next step is to load a DistilGPT2 tokenizer to process the `text` subfield:"""
|
50 |
+
|
51 |
+
from transformers import AutoTokenizer, GPT2TokenizerFast
|
52 |
+
|
53 |
+
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
|
54 |
+
|
55 |
+
|
56 |
+
# tokenizer = GPT2TokenizerFast.from_pretrained("Xenova/gpt-4")#, cache_dir=cache_dir)
|
57 |
+
# tokenizer.pad_token
|
58 |
+
|
59 |
+
# tokenizer.eos_token=128000
|
60 |
+
# tokenizer.bos_token='128000'
|
61 |
+
# tokenizer.eos_token='128001'
|
62 |
+
|
63 |
+
tokenizer.pad_token = tokenizer.eos_token
|
64 |
+
|
65 |
+
Falcon = Falcon.flatten()
|
66 |
+
Falcon["train"][0]
|
67 |
+
|
68 |
+
|
69 |
+
|
70 |
+
def preprocess_function(examples):
|
71 |
+
return tokenizer([" ".join(x) for x in examples["Text"]])
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
tokenized_Falcon = Falcon.map(
|
76 |
+
preprocess_function,
|
77 |
+
batched=True,
|
78 |
+
num_proc=4,
|
79 |
+
remove_columns=Falcon["train"].column_names,
|
80 |
+
)
|
81 |
+
|
82 |
+
|
83 |
+
block_size = tokenizer.model_max_length
|
84 |
+
# block_size = 2048
|
85 |
+
|
86 |
+
|
87 |
+
def group_texts(examples):
|
88 |
+
# Concatenate all texts.
|
89 |
+
concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
|
90 |
+
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
91 |
+
# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
|
92 |
+
# customize this part to your needs.
|
93 |
+
if total_length >= block_size:
|
94 |
+
total_length = (total_length // block_size) * block_size
|
95 |
+
# Split by chunks of block_size.
|
96 |
+
result = {
|
97 |
+
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
98 |
+
for k, t in concatenated_examples.items()
|
99 |
+
}
|
100 |
+
result["labels"] = result["input_ids"].copy()
|
101 |
+
return result
|
102 |
+
|
103 |
+
"""Apply the `group_texts` function over the entire dataset:"""
|
104 |
+
|
105 |
+
lm_dataset = tokenized_Falcon.map(group_texts, batched=True, num_proc=4)
|
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
from transformers import DataCollatorForLanguageModeling
|
110 |
+
|
111 |
+
# tokenizer.pad_token
|
112 |
+
# tokenizer.bos_token='128000'
|
113 |
+
# tokenizer.eos_token='128001'
|
114 |
+
|
115 |
+
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
116 |
+
|
117 |
+
|
118 |
+
|
119 |
+
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
|
120 |
+
import torch
|
121 |
+
model = AutoModelForCausalLM.from_pretrained("rwh/tiny8", torch_dtype=torch.bfloat16)
|
122 |
+
|
123 |
+
print('Model Loaded!')
|
124 |
+
|
125 |
+
# import torch
|
126 |
+
# torch.cuda.empty_cache()
|
127 |
+
|
128 |
+
# import torch
|
129 |
+
# import gc
|
130 |
+
|
131 |
+
# # del tensor_name # Delete the tensor
|
132 |
+
# gc.collect() # Collect garbage
|
133 |
+
# torch.cuda.empty_cache() # Clear cache
|
134 |
+
|
135 |
+
# torch.cuda.empty_cache()
|
136 |
+
|
137 |
+
# torch.no_grad()
|
138 |
+
|
139 |
+
model.to('cuda')
|
140 |
+
|
141 |
+
OutputDir = "C1ReadyModel"
|
142 |
+
|
143 |
+
training_args = TrainingArguments(
|
144 |
+
output_dir=OutputDir,
|
145 |
+
overwrite_output_dir=True,
|
146 |
+
bf16=True,
|
147 |
+
# evaluation_strategy="epoch",
|
148 |
+
evaluation_strategy="steps",
|
149 |
+
# learning_rate=3.25e-06,
|
150 |
+
# learning_rate=2e-5,
|
151 |
+
learning_rate=1e-5,
|
152 |
+
weight_decay=0.01,
|
153 |
+
# weight_decay=0.001,
|
154 |
+
num_train_epochs=6,
|
155 |
+
per_device_train_batch_size=8,
|
156 |
+
per_device_eval_batch_size=8,
|
157 |
+
# lr_scheduler_type = 'cosine',
|
158 |
+
lr_scheduler_type = 'linear',
|
159 |
+
push_to_hub=False,
|
160 |
+
save_total_limit = 2,
|
161 |
+
save_strategy = "steps",
|
162 |
+
load_best_model_at_end=True,
|
163 |
+
save_safetensors=True,
|
164 |
+
)
|
165 |
+
|
166 |
+
trainer = Trainer(
|
167 |
+
model=model,
|
168 |
+
args=training_args,
|
169 |
+
train_dataset=lm_dataset["train"],
|
170 |
+
eval_dataset=lm_dataset["validation"],
|
171 |
+
# eval_dataset=lm_dataset["test"],
|
172 |
+
data_collator=data_collator,
|
173 |
+
)
|
174 |
+
|
175 |
+
# trainer.train()
|
176 |
+
print('Started Training!')
|
177 |
+
trainer.train()
|
178 |
+
|
179 |
+
trainer.save_model(OutputDir)
|
180 |
+
print('Saved Model Path:', OutputDir)
|
181 |
+
|
182 |
+
import math
|
183 |
+
|
184 |
+
eval_results = trainer.evaluate()
|
185 |
+
print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")
|
186 |
+
|
187 |
+
|
short_gpt/.ipynb_checkpoints/short_hf-checkpoint.ipynb
ADDED
@@ -0,0 +1,1679 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [
|
8 |
+
{
|
9 |
+
"name": "stdout",
|
10 |
+
"output_type": "stream",
|
11 |
+
"text": [
|
12 |
+
"Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (2.19.1)\n",
|
13 |
+
"Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.1.1)\n",
|
14 |
+
"Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.40.2)\n",
|
15 |
+
"Requirement already satisfied: peft in /usr/local/lib/python3.10/dist-packages (0.10.0)\n",
|
16 |
+
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.13.1)\n",
|
17 |
+
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.26.2)\n",
|
18 |
+
"Requirement already satisfied: pyarrow>=12.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (16.0.0)\n",
|
19 |
+
"Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets) (0.6)\n",
|
20 |
+
"Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.8)\n",
|
21 |
+
"Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.2.2)\n",
|
22 |
+
"Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.31.0)\n",
|
23 |
+
"Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.2)\n",
|
24 |
+
"Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.4.1)\n",
|
25 |
+
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.16)\n",
|
26 |
+
"Requirement already satisfied: fsspec<=2024.3.1,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.3.1,>=2023.1.0->datasets) (2023.10.0)\n",
|
27 |
+
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.9.0b0)\n",
|
28 |
+
"Requirement already satisfied: huggingface-hub>=0.21.2 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.23.0)\n",
|
29 |
+
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (23.2)\n",
|
30 |
+
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.1)\n",
|
31 |
+
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch) (4.8.0)\n",
|
32 |
+
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.12)\n",
|
33 |
+
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.2.1)\n",
|
34 |
+
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.2)\n",
|
35 |
+
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n",
|
36 |
+
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n",
|
37 |
+
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n",
|
38 |
+
"Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch) (8.9.2.26)\n",
|
39 |
+
"Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.3.1)\n",
|
40 |
+
"Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch) (11.0.2.54)\n",
|
41 |
+
"Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch) (10.3.2.106)\n",
|
42 |
+
"Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch) (11.4.5.107)\n",
|
43 |
+
"Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.0.106)\n",
|
44 |
+
"Requirement already satisfied: nvidia-nccl-cu12==2.18.1 in /usr/local/lib/python3.10/dist-packages (from torch) (2.18.1)\n",
|
45 |
+
"Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n",
|
46 |
+
"Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch) (2.1.0)\n",
|
47 |
+
"Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch) (12.3.101)\n",
|
48 |
+
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2024.4.28)\n",
|
49 |
+
"Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.1)\n",
|
50 |
+
"Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.3)\n",
|
51 |
+
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from peft) (5.9.6)\n",
|
52 |
+
"Requirement already satisfied: accelerate>=0.21.0 in /usr/local/lib/python3.10/dist-packages (from peft) (0.30.0)\n",
|
53 |
+
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.1.0)\n",
|
54 |
+
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.5)\n",
|
55 |
+
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.4)\n",
|
56 |
+
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n",
|
57 |
+
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n",
|
58 |
+
"Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n",
|
59 |
+
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (3.3.2)\n",
|
60 |
+
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (3.6)\n",
|
61 |
+
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2.1.0)\n",
|
62 |
+
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2023.11.17)\n",
|
63 |
+
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.3)\n",
|
64 |
+
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n",
|
65 |
+
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n",
|
66 |
+
"Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n",
|
67 |
+
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n",
|
68 |
+
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n",
|
69 |
+
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
|
70 |
+
"\u001b[0mNote: you may need to restart the kernel to use updated packages.\n"
|
71 |
+
]
|
72 |
+
}
|
73 |
+
],
|
74 |
+
"source": [
|
75 |
+
"pip install datasets torch transformers peft"
|
76 |
+
]
|
77 |
+
},
|
78 |
+
{
|
79 |
+
"cell_type": "code",
|
80 |
+
"execution_count": 4,
|
81 |
+
"metadata": {},
|
82 |
+
"outputs": [],
|
83 |
+
"source": [
|
84 |
+
"from tqdm.notebook import tqdm\n",
|
85 |
+
"\n",
|
86 |
+
"from datasets import load_dataset\n",
|
87 |
+
"import torch\n",
|
88 |
+
"from torch.utils.data import DataLoader\n",
|
89 |
+
"\n",
|
90 |
+
"from peft import (\n",
|
91 |
+
" get_peft_model,\n",
|
92 |
+
" LoraConfig,\n",
|
93 |
+
" TaskType,\n",
|
94 |
+
")\n",
|
95 |
+
"from transformers import default_data_collator, Trainer, TrainingArguments\n",
|
96 |
+
"\n",
|
97 |
+
"from short_hf import ShortHFModel"
|
98 |
+
]
|
99 |
+
},
|
100 |
+
{
|
101 |
+
"cell_type": "markdown",
|
102 |
+
"metadata": {},
|
103 |
+
"source": [
|
104 |
+
"### Load Data"
|
105 |
+
]
|
106 |
+
},
|
107 |
+
{
|
108 |
+
"cell_type": "code",
|
109 |
+
"execution_count": null,
|
110 |
+
"metadata": {},
|
111 |
+
"outputs": [],
|
112 |
+
"source": [
|
113 |
+
"# data = load_dataset(\"pg19\", split=\"validation\") # authors sample 10,000 texts to compute block influences\n",
|
114 |
+
"# dataloader = DataLoader(\n",
|
115 |
+
"# data,\n",
|
116 |
+
"# batch_size=2,\n",
|
117 |
+
"# shuffle=True,\n",
|
118 |
+
"# )"
|
119 |
+
]
|
120 |
+
},
|
121 |
+
{
|
122 |
+
"cell_type": "code",
|
123 |
+
"execution_count": 5,
|
124 |
+
"metadata": {},
|
125 |
+
"outputs": [],
|
126 |
+
"source": [
|
127 |
+
"data = load_dataset(\"wikitext\", \"wikitext-103-raw-v1\", split=\"validation\") # authors sample 10,000 texts to compute block influences\n",
|
128 |
+
"dataloader = DataLoader(\n",
|
129 |
+
" data,\n",
|
130 |
+
" batch_size=1,\n",
|
131 |
+
" shuffle=True,\n",
|
132 |
+
")"
|
133 |
+
]
|
134 |
+
},
|
135 |
+
{
|
136 |
+
"cell_type": "markdown",
|
137 |
+
"metadata": {},
|
138 |
+
"source": [
|
139 |
+
"### Load Model"
|
140 |
+
]
|
141 |
+
},
|
142 |
+
{
|
143 |
+
"cell_type": "code",
|
144 |
+
"execution_count": 3,
|
145 |
+
"metadata": {},
|
146 |
+
"outputs": [],
|
147 |
+
"source": [
|
148 |
+
"# !huggingface-cli login\n",
|
149 |
+
"# pip install huggingface_hub\n",
|
150 |
+
"!python3 -c \"from huggingface_hub.hf_api import HfFolder; HfFolder.save_token('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')\""
|
151 |
+
]
|
152 |
+
},
|
153 |
+
{
|
154 |
+
"cell_type": "code",
|
155 |
+
"execution_count": null,
|
156 |
+
"metadata": {},
|
157 |
+
"outputs": [],
|
158 |
+
"source": [
|
159 |
+
"#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"
|
160 |
+
]
|
161 |
+
},
|
162 |
+
{
|
163 |
+
"cell_type": "code",
|
164 |
+
"execution_count": 3,
|
165 |
+
"metadata": {},
|
166 |
+
"outputs": [
|
167 |
+
{
|
168 |
+
"name": "stdout",
|
169 |
+
"output_type": "stream",
|
170 |
+
"text": [
|
171 |
+
"asifahmed\n"
|
172 |
+
]
|
173 |
+
}
|
174 |
+
],
|
175 |
+
"source": [
|
176 |
+
"!huggingface-cli whoami"
|
177 |
+
]
|
178 |
+
},
|
179 |
+
{
|
180 |
+
"cell_type": "code",
|
181 |
+
"execution_count": 2,
|
182 |
+
"metadata": {},
|
183 |
+
"outputs": [],
|
184 |
+
"source": [
|
185 |
+
"# pip install git+https://github.com/tri-ml/linear_open_lm.git\n"
|
186 |
+
]
|
187 |
+
},
|
188 |
+
{
|
189 |
+
"cell_type": "code",
|
190 |
+
"execution_count": 6,
|
191 |
+
"metadata": {},
|
192 |
+
"outputs": [
|
193 |
+
{
|
194 |
+
"name": "stderr",
|
195 |
+
"output_type": "stream",
|
196 |
+
"text": [
|
197 |
+
"/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
198 |
+
" warnings.warn(\n"
|
199 |
+
]
|
200 |
+
},
|
201 |
+
{
|
202 |
+
"data": {
|
203 |
+
"application/vnd.jupyter.widget-view+json": {
|
204 |
+
"model_id": "9fcf366ecc414808b39285438599f5b9",
|
205 |
+
"version_major": 2,
|
206 |
+
"version_minor": 0
|
207 |
+
},
|
208 |
+
"text/plain": [
|
209 |
+
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
210 |
+
]
|
211 |
+
},
|
212 |
+
"metadata": {},
|
213 |
+
"output_type": "display_data"
|
214 |
+
}
|
215 |
+
],
|
216 |
+
"source": [
|
217 |
+
"# from open_lm.open_lm_hf import *\n",
|
218 |
+
"\n",
|
219 |
+
"MAX_SEQ_LEN = 2048\n",
|
220 |
+
"short_model = ShortHFModel(\n",
|
221 |
+
" # model_name=\"tiiuae/falcon-7b\",\n",
|
222 |
+
" model_name=\"mistralai/Mistral-7B-v0.1\",\n",
|
223 |
+
" layers_path=\"model.layers\",\n",
|
224 |
+
" n_prune_layers=2\n",
|
225 |
+
")\n",
|
226 |
+
"# short_model.model"
|
227 |
+
]
|
228 |
+
},
|
229 |
+
{
|
230 |
+
"cell_type": "code",
|
231 |
+
"execution_count": 7,
|
232 |
+
"metadata": {},
|
233 |
+
"outputs": [
|
234 |
+
{
|
235 |
+
"data": {
|
236 |
+
"text/plain": [
|
237 |
+
"MistralForCausalLM(\n",
|
238 |
+
" (model): MistralModel(\n",
|
239 |
+
" (embed_tokens): Embedding(32000, 4096)\n",
|
240 |
+
" (layers): ModuleList(\n",
|
241 |
+
" (0-31): 32 x MistralDecoderLayer(\n",
|
242 |
+
" (self_attn): MistralSdpaAttention(\n",
|
243 |
+
" (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
|
244 |
+
" (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
|
245 |
+
" (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
|
246 |
+
" (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
|
247 |
+
" (rotary_emb): MistralRotaryEmbedding()\n",
|
248 |
+
" )\n",
|
249 |
+
" (mlp): MistralMLP(\n",
|
250 |
+
" (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
|
251 |
+
" (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
|
252 |
+
" (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
|
253 |
+
" (act_fn): SiLU()\n",
|
254 |
+
" )\n",
|
255 |
+
" (input_layernorm): MistralRMSNorm()\n",
|
256 |
+
" (post_attention_layernorm): MistralRMSNorm()\n",
|
257 |
+
" )\n",
|
258 |
+
" )\n",
|
259 |
+
" (norm): MistralRMSNorm()\n",
|
260 |
+
" )\n",
|
261 |
+
" (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n",
|
262 |
+
")"
|
263 |
+
]
|
264 |
+
},
|
265 |
+
"execution_count": 7,
|
266 |
+
"metadata": {},
|
267 |
+
"output_type": "execute_result"
|
268 |
+
}
|
269 |
+
],
|
270 |
+
"source": [
|
271 |
+
"short_model.model"
|
272 |
+
]
|
273 |
+
},
|
274 |
+
{
|
275 |
+
"cell_type": "code",
|
276 |
+
"execution_count": null,
|
277 |
+
"metadata": {},
|
278 |
+
"outputs": [],
|
279 |
+
"source": [
|
280 |
+
"# AutoModelForCausalLM.from_pretrained(\n",
|
281 |
+
"# pretrained_model_name_or_path=model_dir,\n",
|
282 |
+
"# local_files_only=True,\n",
|
283 |
+
"# use_safetensors=True,\n",
|
284 |
+
"# torch_dtype=torch.bfloat16,\n",
|
285 |
+
"# )"
|
286 |
+
]
|
287 |
+
},
|
288 |
+
{
|
289 |
+
"cell_type": "code",
|
290 |
+
"execution_count": 8,
|
291 |
+
"metadata": {},
|
292 |
+
"outputs": [
|
293 |
+
{
|
294 |
+
"data": {
|
295 |
+
"text/plain": [
|
296 |
+
"<generator object Module.parameters at 0x7f00b3917840>"
|
297 |
+
]
|
298 |
+
},
|
299 |
+
"execution_count": 8,
|
300 |
+
"metadata": {},
|
301 |
+
"output_type": "execute_result"
|
302 |
+
}
|
303 |
+
],
|
304 |
+
"source": [
|
305 |
+
"short_model.model.parameters()"
|
306 |
+
]
|
307 |
+
},
|
308 |
+
{
|
309 |
+
"cell_type": "code",
|
310 |
+
"execution_count": 9,
|
311 |
+
"metadata": {},
|
312 |
+
"outputs": [
|
313 |
+
{
|
314 |
+
"data": {
|
315 |
+
"text/plain": [
|
316 |
+
"7241732096"
|
317 |
+
]
|
318 |
+
},
|
319 |
+
"execution_count": 9,
|
320 |
+
"metadata": {},
|
321 |
+
"output_type": "execute_result"
|
322 |
+
}
|
323 |
+
],
|
324 |
+
"source": [
|
325 |
+
"pytorch_total_params = sum(p.numel() for p in short_model.model.parameters())\n",
|
326 |
+
"pytorch_total_params"
|
327 |
+
]
|
328 |
+
},
|
329 |
+
{
|
330 |
+
"cell_type": "code",
|
331 |
+
"execution_count": 36,
|
332 |
+
"metadata": {},
|
333 |
+
"outputs": [],
|
334 |
+
"source": [
|
335 |
+
" # Save the model state to the specified path.\n",
|
336 |
+
"# model_dir='ShortModelSaved/'\n",
|
337 |
+
"# short_model.model.save_pretrained(\n",
|
338 |
+
"# save_directory=model_dir,\n",
|
339 |
+
"# safe_serialization=True,\n",
|
340 |
+
"# )"
|
341 |
+
]
|
342 |
+
},
|
343 |
+
{
|
344 |
+
"cell_type": "code",
|
345 |
+
"execution_count": 10,
|
346 |
+
"metadata": {},
|
347 |
+
"outputs": [
|
348 |
+
{
|
349 |
+
"data": {
|
350 |
+
"text/plain": [
|
351 |
+
"MistralDecoderLayer(\n",
|
352 |
+
" (self_attn): MistralSdpaAttention(\n",
|
353 |
+
" (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
|
354 |
+
" (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
|
355 |
+
" (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
|
356 |
+
" (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
|
357 |
+
" (rotary_emb): MistralRotaryEmbedding()\n",
|
358 |
+
" )\n",
|
359 |
+
" (mlp): MistralMLP(\n",
|
360 |
+
" (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
|
361 |
+
" (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
|
362 |
+
" (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
|
363 |
+
" (act_fn): SiLU()\n",
|
364 |
+
" )\n",
|
365 |
+
" (input_layernorm): MistralRMSNorm()\n",
|
366 |
+
" (post_attention_layernorm): MistralRMSNorm()\n",
|
367 |
+
")"
|
368 |
+
]
|
369 |
+
},
|
370 |
+
"execution_count": 10,
|
371 |
+
"metadata": {},
|
372 |
+
"output_type": "execute_result"
|
373 |
+
}
|
374 |
+
],
|
375 |
+
"source": [
|
376 |
+
"short_model.layers[0]"
|
377 |
+
]
|
378 |
+
},
|
379 |
+
{
|
380 |
+
"cell_type": "code",
|
381 |
+
"execution_count": 12,
|
382 |
+
"metadata": {},
|
383 |
+
"outputs": [
|
384 |
+
{
|
385 |
+
"name": "stderr",
|
386 |
+
"output_type": "stream",
|
387 |
+
"text": [
|
388 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
389 |
+
"Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
|
390 |
+
]
|
391 |
+
},
|
392 |
+
{
|
393 |
+
"data": {
|
394 |
+
"text/plain": [
|
395 |
+
"['I am an avid fan of 3D printing. I have been using 3D printers for over 10 years and have been involved in the development of several 3D printers. I have also been involved in the development of several 3D printing software packages.\\n\\nI have been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages.']"
|
396 |
+
]
|
397 |
+
},
|
398 |
+
"execution_count": 12,
|
399 |
+
"metadata": {},
|
400 |
+
"output_type": "execute_result"
|
401 |
+
}
|
402 |
+
],
|
403 |
+
"source": [
|
404 |
+
"# sample generationThe evolution of AI has lead to \n",
|
405 |
+
"gen = short_model.model.generate(\n",
|
406 |
+
" short_model.tokenizer([\"I am an avid fan of \"], return_tensors='pt').input_ids.to(\"cuda\"),\n",
|
407 |
+
" max_new_tokens=256\n",
|
408 |
+
")\n",
|
409 |
+
"short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)"
|
410 |
+
]
|
411 |
+
},
|
412 |
+
{
|
413 |
+
"cell_type": "code",
|
414 |
+
"execution_count": 2,
|
415 |
+
"metadata": {},
|
416 |
+
"outputs": [],
|
417 |
+
"source": [
|
418 |
+
"# # sample generation\n",
|
419 |
+
"# gen = short_model.model.generate(\n",
|
420 |
+
"# short_model.tokenizer([\"The evolution of AI has lead to \"], return_tensors='pt').input_ids.to(\"cuda\"),\n",
|
421 |
+
"# max_new_tokens=256\n",
|
422 |
+
"# )\n",
|
423 |
+
"# short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)"
|
424 |
+
]
|
425 |
+
},
|
426 |
+
{
|
427 |
+
"cell_type": "markdown",
|
428 |
+
"metadata": {},
|
429 |
+
"source": [
|
430 |
+
"### Compute Importances"
|
431 |
+
]
|
432 |
+
},
|
433 |
+
{
|
434 |
+
"cell_type": "code",
|
435 |
+
"execution_count": 50,
|
436 |
+
"metadata": {},
|
437 |
+
"outputs": [],
|
438 |
+
"source": [
|
439 |
+
"# for i, batch in enumerate(tqdm(dataloader)):\n",
|
440 |
+
"# prompts = batch['text']\n",
|
441 |
+
"\n",
|
442 |
+
"# short_model.eval_importance(\n",
|
443 |
+
"# prompts=prompts,\n",
|
444 |
+
"# max_seq_len=MAX_SEQ_LEN,\n",
|
445 |
+
"# stride=256,\n",
|
446 |
+
"# max_gen_len=0\n",
|
447 |
+
"# )"
|
448 |
+
]
|
449 |
+
},
|
450 |
+
{
|
451 |
+
"cell_type": "code",
|
452 |
+
"execution_count": 51,
|
453 |
+
"metadata": {},
|
454 |
+
"outputs": [],
|
455 |
+
"source": [
|
456 |
+
"# short_model.importances"
|
457 |
+
]
|
458 |
+
},
|
459 |
+
{
|
460 |
+
"cell_type": "markdown",
|
461 |
+
"metadata": {},
|
462 |
+
"source": [
|
463 |
+
"### Remove unimportant layers\n",
|
464 |
+
"\n",
|
465 |
+
"Layers removed when using subset of pg19 val set: [25, 26, 24, 27, 22, 23, 28, 21, 29]\n",
|
466 |
+
"\n",
|
467 |
+
"Authors mention that the layer order is quite nuanced and can vary with different datasets. However, relative order suggests similar importance."
|
468 |
+
]
|
469 |
+
},
|
470 |
+
{
|
471 |
+
"cell_type": "code",
|
472 |
+
"execution_count": 55,
|
473 |
+
"metadata": {},
|
474 |
+
"outputs": [],
|
475 |
+
"source": [
|
476 |
+
"# short_model.remove_layers()"
|
477 |
+
]
|
478 |
+
},
|
479 |
+
{
|
480 |
+
"cell_type": "code",
|
481 |
+
"execution_count": 54,
|
482 |
+
"metadata": {},
|
483 |
+
"outputs": [],
|
484 |
+
"source": [
|
485 |
+
"# short_model.remove_layers()"
|
486 |
+
]
|
487 |
+
},
|
488 |
+
{
|
489 |
+
"cell_type": "code",
|
490 |
+
"execution_count": 56,
|
491 |
+
"metadata": {},
|
492 |
+
"outputs": [],
|
493 |
+
"source": [
|
494 |
+
"# short_model.layers"
|
495 |
+
]
|
496 |
+
},
|
497 |
+
{
|
498 |
+
"cell_type": "code",
|
499 |
+
"execution_count": 48,
|
500 |
+
"metadata": {},
|
501 |
+
"outputs": [],
|
502 |
+
"source": [
|
503 |
+
"# # reassign layer_idx to attentions for caching\n",
|
504 |
+
"# for layer_idx, module in enumerate(short_model.layers):\n",
|
505 |
+
"# module.self_attn.layer_idx = layer_idx"
|
506 |
+
]
|
507 |
+
},
|
508 |
+
{
|
509 |
+
"cell_type": "code",
|
510 |
+
"execution_count": 20,
|
511 |
+
"metadata": {},
|
512 |
+
"outputs": [
|
513 |
+
{
|
514 |
+
"data": {
|
515 |
+
"text/plain": [
|
516 |
+
"<generator object Module.parameters at 0x7f625768a2d0>"
|
517 |
+
]
|
518 |
+
},
|
519 |
+
"execution_count": 20,
|
520 |
+
"metadata": {},
|
521 |
+
"output_type": "execute_result"
|
522 |
+
}
|
523 |
+
],
|
524 |
+
"source": [
|
525 |
+
"# short_model.model.parameters()"
|
526 |
+
]
|
527 |
+
},
|
528 |
+
{
|
529 |
+
"cell_type": "code",
|
530 |
+
"execution_count": 68,
|
531 |
+
"metadata": {},
|
532 |
+
"outputs": [
|
533 |
+
{
|
534 |
+
"data": {
|
535 |
+
"text/plain": [
|
536 |
+
"7241732096"
|
537 |
+
]
|
538 |
+
},
|
539 |
+
"execution_count": 68,
|
540 |
+
"metadata": {},
|
541 |
+
"output_type": "execute_result"
|
542 |
+
}
|
543 |
+
],
|
544 |
+
"source": [
|
545 |
+
"# pytorch_total_params = sum(p.numel() for p in short_model.model.parameters())\n",
|
546 |
+
"# pytorch_total_params"
|
547 |
+
]
|
548 |
+
},
|
549 |
+
{
|
550 |
+
"cell_type": "markdown",
|
551 |
+
"metadata": {},
|
552 |
+
"source": [
|
553 |
+
"As the paper states: \\\n",
|
554 |
+
" - \"Our experiments reveal that the effect of layer removal is significantly more pronounced on generative\n",
|
555 |
+
" tasks compared to multiple-choice tasks. On benchmarks such as GSM8K (Cobbe et al., 2021) and\n",
|
556 |
+
" HumanEval (Chen et al., 2021), removing 25% of the layers often leads to a severe performance\n",
|
557 |
+
" drop, with scores approaching zero.\""
|
558 |
+
]
|
559 |
+
},
|
560 |
+
{
|
561 |
+
"cell_type": "code",
|
562 |
+
"execution_count": 53,
|
563 |
+
"metadata": {},
|
564 |
+
"outputs": [],
|
565 |
+
"source": [
|
566 |
+
"# gen = short_model.model.generate(\n",
|
567 |
+
"# short_model.tokenizer([\"I am an avid fan of \"], return_tensors='pt').input_ids.to(\"cuda\"),\n",
|
568 |
+
"# max_new_tokens=20,\n",
|
569 |
+
"# use_cache=True\n",
|
570 |
+
"# )\n",
|
571 |
+
"# short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)"
|
572 |
+
]
|
573 |
+
},
|
574 |
+
{
|
575 |
+
"cell_type": "code",
|
576 |
+
"execution_count": 52,
|
577 |
+
"metadata": {},
|
578 |
+
"outputs": [],
|
579 |
+
"source": [
|
580 |
+
"# gen = short_model.model.generate(I am an avid fan of \n",
|
581 |
+
"# short_model.tokenizer([\"The evolution of AI has lead to \"], return_tensors='pt').input_ids.to(\"cuda\"),\n",
|
582 |
+
"# max_new_tokens=20,\n",
|
583 |
+
"# use_cache=True\n",
|
584 |
+
"# )\n",
|
585 |
+
"# short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)"
|
586 |
+
]
|
587 |
+
},
|
588 |
+
{
|
589 |
+
"cell_type": "markdown",
|
590 |
+
"metadata": {},
|
591 |
+
"source": [
|
592 |
+
"### Compute Angular Importances"
|
593 |
+
]
|
594 |
+
},
|
595 |
+
{
|
596 |
+
"cell_type": "code",
|
597 |
+
"execution_count": 16,
|
598 |
+
"metadata": {},
|
599 |
+
"outputs": [
|
600 |
+
{
|
601 |
+
"data": {
|
602 |
+
"application/vnd.jupyter.widget-view+json": {
|
603 |
+
"model_id": "a6fd2bf4360b4aba801085bab0755a06",
|
604 |
+
"version_major": 2,
|
605 |
+
"version_minor": 0
|
606 |
+
},
|
607 |
+
"text/plain": [
|
608 |
+
" 0%| | 0/3760 [00:00<?, ?it/s]"
|
609 |
+
]
|
610 |
+
},
|
611 |
+
"metadata": {},
|
612 |
+
"output_type": "display_data"
|
613 |
+
}
|
614 |
+
],
|
615 |
+
"source": [
|
616 |
+
"for i, batch in enumerate(tqdm(dataloader)):\n",
|
617 |
+
" prompts = batch['text']\n",
|
618 |
+
"\n",
|
619 |
+
" short_model.eval_importance(\n",
|
620 |
+
" prompts=prompts,\n",
|
621 |
+
" max_seq_len=MAX_SEQ_LEN,\n",
|
622 |
+
" stride=256,\n",
|
623 |
+
" max_gen_len=0,\n",
|
624 |
+
" angular=True\n",
|
625 |
+
" )"
|
626 |
+
]
|
627 |
+
},
|
628 |
+
{
|
629 |
+
"cell_type": "code",
|
630 |
+
"execution_count": 17,
|
631 |
+
"metadata": {},
|
632 |
+
"outputs": [
|
633 |
+
{
|
634 |
+
"data": {
|
635 |
+
"text/plain": [
|
636 |
+
"[128390.1328125,\n",
|
637 |
+
" 80922.06787109375,\n",
|
638 |
+
" 61075.2890625,\n",
|
639 |
+
" nan,\n",
|
640 |
+
" nan,\n",
|
641 |
+
" 56557.81268310547,\n",
|
642 |
+
" nan,\n",
|
643 |
+
" 52294.552001953125,\n",
|
644 |
+
" 47928.185302734375,\n",
|
645 |
+
" 42335.215576171875,\n",
|
646 |
+
" 40547.564208984375,\n",
|
647 |
+
" 37178.684326171875,\n",
|
648 |
+
" 34713.912841796875,\n",
|
649 |
+
" 33843.728271484375,\n",
|
650 |
+
" 35384.353271484375,\n",
|
651 |
+
" 35603.388427734375,\n",
|
652 |
+
" 35621.970458984375,\n",
|
653 |
+
" 35356.719482421875,\n",
|
654 |
+
" 35365.243896484375,\n",
|
655 |
+
" 34914.025146484375,\n",
|
656 |
+
" 27854.576904296875,\n",
|
657 |
+
" 24398.073974609375,\n",
|
658 |
+
" 20450.390380859375,\n",
|
659 |
+
" 19501.300537109375,\n",
|
660 |
+
" 18430.427490234375,\n",
|
661 |
+
" 18231.873779296875,\n",
|
662 |
+
" 17917.493896484375,\n",
|
663 |
+
" 17806.815185546875,\n",
|
664 |
+
" 21227.195068359375,\n",
|
665 |
+
" 23928.313018798828,\n",
|
666 |
+
" 22738.702880859375,\n",
|
667 |
+
" 86123.783203125]"
|
668 |
+
]
|
669 |
+
},
|
670 |
+
"execution_count": 17,
|
671 |
+
"metadata": {},
|
672 |
+
"output_type": "execute_result"
|
673 |
+
}
|
674 |
+
],
|
675 |
+
"source": [
|
676 |
+
"short_model.importances"
|
677 |
+
]
|
678 |
+
},
|
679 |
+
{
|
680 |
+
"cell_type": "markdown",
|
681 |
+
"metadata": {},
|
682 |
+
"source": [
|
683 |
+
"### Remove unimportant layers"
|
684 |
+
]
|
685 |
+
},
|
686 |
+
{
|
687 |
+
"cell_type": "code",
|
688 |
+
"execution_count": 18,
|
689 |
+
"metadata": {},
|
690 |
+
"outputs": [
|
691 |
+
{
|
692 |
+
"data": {
|
693 |
+
"text/plain": [
|
694 |
+
"[27, 28]"
|
695 |
+
]
|
696 |
+
},
|
697 |
+
"execution_count": 18,
|
698 |
+
"metadata": {},
|
699 |
+
"output_type": "execute_result"
|
700 |
+
}
|
701 |
+
],
|
702 |
+
"source": [
|
703 |
+
"short_model.remove_layers(angular=True)"
|
704 |
+
]
|
705 |
+
},
|
706 |
+
{
|
707 |
+
"cell_type": "code",
|
708 |
+
"execution_count": 20,
|
709 |
+
"metadata": {},
|
710 |
+
"outputs": [
|
711 |
+
{
|
712 |
+
"data": {
|
713 |
+
"text/plain": [
|
714 |
+
"MistralDecoderLayer(\n",
|
715 |
+
" (self_attn): MistralSdpaAttention(\n",
|
716 |
+
" (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
|
717 |
+
" (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
|
718 |
+
" (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
|
719 |
+
" (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
|
720 |
+
" (rotary_emb): MistralRotaryEmbedding()\n",
|
721 |
+
" )\n",
|
722 |
+
" (mlp): MistralMLP(\n",
|
723 |
+
" (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
|
724 |
+
" (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
|
725 |
+
" (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
|
726 |
+
" (act_fn): SiLU()\n",
|
727 |
+
" )\n",
|
728 |
+
" (input_layernorm): MistralRMSNorm()\n",
|
729 |
+
" (post_attention_layernorm): MistralRMSNorm()\n",
|
730 |
+
")"
|
731 |
+
]
|
732 |
+
},
|
733 |
+
"execution_count": 20,
|
734 |
+
"metadata": {},
|
735 |
+
"output_type": "execute_result"
|
736 |
+
}
|
737 |
+
],
|
738 |
+
"source": [
|
739 |
+
"short_model.layers[0]"
|
740 |
+
]
|
741 |
+
},
|
742 |
+
{
|
743 |
+
"cell_type": "code",
|
744 |
+
"execution_count": 21,
|
745 |
+
"metadata": {},
|
746 |
+
"outputs": [
|
747 |
+
{
|
748 |
+
"data": {
|
749 |
+
"text/plain": [
|
750 |
+
"ModuleList(\n",
|
751 |
+
" (0-29): 30 x MistralDecoderLayer(\n",
|
752 |
+
" (self_attn): MistralSdpaAttention(\n",
|
753 |
+
" (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
|
754 |
+
" (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
|
755 |
+
" (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
|
756 |
+
" (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
|
757 |
+
" (rotary_emb): MistralRotaryEmbedding()\n",
|
758 |
+
" )\n",
|
759 |
+
" (mlp): MistralMLP(\n",
|
760 |
+
" (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
|
761 |
+
" (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
|
762 |
+
" (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
|
763 |
+
" (act_fn): SiLU()\n",
|
764 |
+
" )\n",
|
765 |
+
" (input_layernorm): MistralRMSNorm()\n",
|
766 |
+
" (post_attention_layernorm): MistralRMSNorm()\n",
|
767 |
+
" )\n",
|
768 |
+
")"
|
769 |
+
]
|
770 |
+
},
|
771 |
+
"execution_count": 21,
|
772 |
+
"metadata": {},
|
773 |
+
"output_type": "execute_result"
|
774 |
+
}
|
775 |
+
],
|
776 |
+
"source": [
|
777 |
+
"short_model.layers"
|
778 |
+
]
|
779 |
+
},
|
780 |
+
{
|
781 |
+
"cell_type": "code",
|
782 |
+
"execution_count": 22,
|
783 |
+
"metadata": {},
|
784 |
+
"outputs": [],
|
785 |
+
"source": [
|
786 |
+
"# reassign layer_idx to attentions for caching\n",
|
787 |
+
"for layer_idx, module in enumerate(short_model.layers):\n",
|
788 |
+
" module.self_attn.layer_idx = layer_idx"
|
789 |
+
]
|
790 |
+
},
|
791 |
+
{
|
792 |
+
"cell_type": "code",
|
793 |
+
"execution_count": 23,
|
794 |
+
"metadata": {},
|
795 |
+
"outputs": [
|
796 |
+
{
|
797 |
+
"data": {
|
798 |
+
"text/plain": [
|
799 |
+
"ModuleList(\n",
|
800 |
+
" (0-29): 30 x MistralDecoderLayer(\n",
|
801 |
+
" (self_attn): MistralSdpaAttention(\n",
|
802 |
+
" (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
|
803 |
+
" (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
|
804 |
+
" (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
|
805 |
+
" (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
|
806 |
+
" (rotary_emb): MistralRotaryEmbedding()\n",
|
807 |
+
" )\n",
|
808 |
+
" (mlp): MistralMLP(\n",
|
809 |
+
" (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
|
810 |
+
" (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
|
811 |
+
" (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
|
812 |
+
" (act_fn): SiLU()\n",
|
813 |
+
" )\n",
|
814 |
+
" (input_layernorm): MistralRMSNorm()\n",
|
815 |
+
" (post_attention_layernorm): MistralRMSNorm()\n",
|
816 |
+
" )\n",
|
817 |
+
")"
|
818 |
+
]
|
819 |
+
},
|
820 |
+
"execution_count": 23,
|
821 |
+
"metadata": {},
|
822 |
+
"output_type": "execute_result"
|
823 |
+
}
|
824 |
+
],
|
825 |
+
"source": [
|
826 |
+
"short_model.layers"
|
827 |
+
]
|
828 |
+
},
|
829 |
+
{
|
830 |
+
"cell_type": "code",
|
831 |
+
"execution_count": 24,
|
832 |
+
"metadata": {},
|
833 |
+
"outputs": [
|
834 |
+
{
|
835 |
+
"name": "stderr",
|
836 |
+
"output_type": "stream",
|
837 |
+
"text": [
|
838 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
839 |
+
"Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
|
840 |
+
]
|
841 |
+
},
|
842 |
+
{
|
843 |
+
"data": {
|
844 |
+
"text/plain": [
|
845 |
+
"['I am an avid fan of 19th century American literature. I have read all of the classics, and I have also read many of the lesser known works. I have a particular interest in the works of Charles Dickens, and I have read all of his novels. I have also read many of the novels of other 19th century authors, such as Jane Austen, William Shakespeare, and William Blake.\\n\\nI have a particular interest in the works of Charles Dickens, and I have read all of his novels. I have also read many of the novels of other 19th century authors, such as Jane Austen, William Shakespeare, and William Blake.\\n\\nI have a particular interest in the works of Charles Dickens, and I have read all of his novels. I have also read many of the novels of other 19th century authors, such as Jane Austen, William Shakespeare, and William Blake.\\n\\nI have a particular interest in the works of Charles Dickens, and I have read all of his novels. I have also read many of the novels of other 19th century authors, such as Jane Austen, William Shakespeare, and William Blake.\\n\\nI have a particular interest in the works of Charles Dickens']"
|
846 |
+
]
|
847 |
+
},
|
848 |
+
"execution_count": 24,
|
849 |
+
"metadata": {},
|
850 |
+
"output_type": "execute_result"
|
851 |
+
}
|
852 |
+
],
|
853 |
+
"source": [
|
854 |
+
"gen = short_model.model.generate(\n",
|
855 |
+
" short_model.tokenizer([\"I am an avid fan of \"], return_tensors='pt').input_ids.to(\"cuda\"),\n",
|
856 |
+
" max_new_tokens=256,\n",
|
857 |
+
" use_cache=True\n",
|
858 |
+
")\n",
|
859 |
+
"short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)"
|
860 |
+
]
|
861 |
+
},
|
862 |
+
{
|
863 |
+
"cell_type": "code",
|
864 |
+
"execution_count": 27,
|
865 |
+
"metadata": {},
|
866 |
+
"outputs": [
|
867 |
+
{
|
868 |
+
"name": "stderr",
|
869 |
+
"output_type": "stream",
|
870 |
+
"text": [
|
871 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
872 |
+
"Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
|
873 |
+
]
|
874 |
+
},
|
875 |
+
{
|
876 |
+
"data": {
|
877 |
+
"text/plain": [
|
878 |
+
"['The evolution of AI has lead to 3 major types of AI:\\n\\n1. Strong AI\\n2. Weak AI\\n3. Super AI\\n\\nStrong AI is the type of AI that is capable of performing any task that a human can perform. This type of AI is still in the development phase and is not yet available in the market.\\n\\nWeak AI is the type of AI that is capable of performing a specific task. This type of AI is available in the market and is used in a variety of applications.\\n\\nSuper AI is the type of AI that is capable of performing any task that a human can perform and is also capable of learning and adapting. This type of AI is still in the development phase and is not yet available in the market.\\n\\n## What is the difference between AI and AI?\\n\\nThe difference between AI and AI is that AI is a type of artificial intelligence that is capable of performing a specific task, while AI is a type of artificial intelligence that is capable of performing any task.\\n\\n## What is the difference between AI and AI?\\n\\nThe difference between AI and AI is that AI is a type of artificial intelligence that is capable of performing a specific task, while AI is a type of artificial intelligence that is capable']"
|
879 |
+
]
|
880 |
+
},
|
881 |
+
"execution_count": 27,
|
882 |
+
"metadata": {},
|
883 |
+
"output_type": "execute_result"
|
884 |
+
}
|
885 |
+
],
|
886 |
+
"source": [
|
887 |
+
"# gen = short_model.model.generate(I am an avid fan of \n",
|
888 |
+
"# short_model.tokenizer([\"The evolution of AI has lead to \"], return_tensors='pt').input_ids.to(\"cuda\"),\n",
|
889 |
+
"# max_new_tokens=256,\n",
|
890 |
+
"# use_cache=True\n",
|
891 |
+
"# )\n",
|
892 |
+
"# short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)\n",
|
893 |
+
"\n",
|
894 |
+
"\n",
|
895 |
+
"gen = short_model.model.generate(\n",
|
896 |
+
" short_model.tokenizer([\"The evolution of AI has lead to \"], return_tensors='pt').input_ids.to(\"cuda\"),\n",
|
897 |
+
" max_new_tokens=256,\n",
|
898 |
+
" use_cache=True\n",
|
899 |
+
")\n",
|
900 |
+
"short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)"
|
901 |
+
]
|
902 |
+
},
|
903 |
+
{
|
904 |
+
"cell_type": "code",
|
905 |
+
"execution_count": 28,
|
906 |
+
"metadata": {},
|
907 |
+
"outputs": [
|
908 |
+
{
|
909 |
+
"data": {
|
910 |
+
"text/plain": [
|
911 |
+
"6805508096"
|
912 |
+
]
|
913 |
+
},
|
914 |
+
"execution_count": 28,
|
915 |
+
"metadata": {},
|
916 |
+
"output_type": "execute_result"
|
917 |
+
}
|
918 |
+
],
|
919 |
+
"source": [
|
920 |
+
"pytorch_total_params = sum(p.numel() for p in short_model.model.parameters())\n",
|
921 |
+
"pytorch_total_params"
|
922 |
+
]
|
923 |
+
},
|
924 |
+
{
|
925 |
+
"cell_type": "code",
|
926 |
+
"execution_count": 35,
|
927 |
+
"metadata": {},
|
928 |
+
"outputs": [],
|
929 |
+
"source": [
|
930 |
+
" # Save the model state to the specified path.\n",
|
931 |
+
"model_dir='SmallModelSaved/'\n",
|
932 |
+
"short_model.model.save_pretrained(\n",
|
933 |
+
" save_directory=model_dir,\n",
|
934 |
+
" safe_serialization=True,\n",
|
935 |
+
" )"
|
936 |
+
]
|
937 |
+
},
|
938 |
+
{
|
939 |
+
"cell_type": "markdown",
|
940 |
+
"metadata": {},
|
941 |
+
"source": [
|
942 |
+
"### Model Healing"
|
943 |
+
]
|
944 |
+
},
|
945 |
+
{
|
946 |
+
"cell_type": "code",
|
947 |
+
"execution_count": 36,
|
948 |
+
"metadata": {},
|
949 |
+
"outputs": [],
|
950 |
+
"source": [
|
951 |
+
"# tokenizer = short_model.tokenizer\n",
|
952 |
+
"model = short_model.model"
|
953 |
+
]
|
954 |
+
},
|
955 |
+
{
|
956 |
+
"cell_type": "code",
|
957 |
+
"execution_count": 37,
|
958 |
+
"metadata": {},
|
959 |
+
"outputs": [
|
960 |
+
{
|
961 |
+
"name": "stdout",
|
962 |
+
"output_type": "stream",
|
963 |
+
"text": [
|
964 |
+
"Datset Loaded!\n"
|
965 |
+
]
|
966 |
+
}
|
967 |
+
],
|
968 |
+
"source": [
|
969 |
+
"from datasets import load_dataset\n",
|
970 |
+
"# Falcon = load_dataset(\"csv\", data_files=\"FalconData.csv\")\n",
|
971 |
+
"Falcon = load_dataset('csv', data_files={\"train\": 'FalconData2.csv', \"validation\": 'FalconDataEval2.csv'})\n",
|
972 |
+
"\n",
|
973 |
+
"print('Datset Loaded!')\n"
|
974 |
+
]
|
975 |
+
},
|
976 |
+
{
|
977 |
+
"cell_type": "code",
|
978 |
+
"execution_count": 38,
|
979 |
+
"metadata": {},
|
980 |
+
"outputs": [
|
981 |
+
{
|
982 |
+
"data": {
|
983 |
+
"text/plain": [
|
984 |
+
"{'Text': 'School Picture Gallery\\nFrance Ski School\\nChildren from Year 5 & 6 travelled to France from Newcastle airport to take part in a week of Ski School. The children had already spent 3 weeks learning the basics of skiing at Silksworth Ski School in Sunderland. When the children arrived in France they took part in a daily Ski School, during which the children made OUTSTANDING progress. The children also took part in French activities, explored local landmarks and took part in shopping activities in Chamonix. It was an incredible adventure for the children and staff!'}"
|
985 |
+
]
|
986 |
+
},
|
987 |
+
"execution_count": 38,
|
988 |
+
"metadata": {},
|
989 |
+
"output_type": "execute_result"
|
990 |
+
}
|
991 |
+
],
|
992 |
+
"source": [
|
993 |
+
"# Falcon = Falcon.train_test_split(test_size=0.10)\n",
|
994 |
+
"\n",
|
995 |
+
"\"\"\"Then take a look at an example:\"\"\"\n",
|
996 |
+
"\n",
|
997 |
+
"Falcon['train'][0]\n"
|
998 |
+
]
|
999 |
+
},
|
1000 |
+
{
|
1001 |
+
"cell_type": "code",
|
1002 |
+
"execution_count": 39,
|
1003 |
+
"metadata": {},
|
1004 |
+
"outputs": [
|
1005 |
+
{
|
1006 |
+
"data": {
|
1007 |
+
"text/plain": [
|
1008 |
+
"{'Text': 'Our Annual Garden Party is a fun-filled event with a ton of landscaping and garden supplies; gardening demonstrations, experts, and vendors; activities for kids; live bands; and local food. It’s been so popular that we’re extending it to TWO DAYS this year!\\nFestivities at 10am – 4pm Saturday and 11am – 3pm Sunday\\nShopping from 9am – 6pm both days\\nThroughout the winter, we collect gently-used and surplus lawn & garden supplies as well as outdoor décor and furniture. Then, we put it all out for your shopping pleasure! The sale begins at 9:00 am Saturday, but folks start lining up outside the gates even earlier, eager to dig through piles of flowerpots and shovels. (If you can’t get there in the morning, don’t worry – the staff continues to bring out items throughout the weekend.)\\nThe Garden Sale 1st.\\nThere will be prizes for people and pets dressed in garden party finery.\\nPhoto by Carrie Delesky\\nSo find yourself a dapper suit or fancy hat, and check out all the activities in store for you:\\nAnacostia Watershed Society\\nPrince George’s Chapter, Maryland Master Gardeners\\nMOM’s Organic Market\\nTreincarnation\\nVeteran Compost\\nPhoto by Carrie Delesky\\nSaturday the Forklift’s Matt Menke and Gary Barnhart of GL Barnhart Construction. Drop in for a while, or stay the whole.'}"
|
1009 |
+
]
|
1010 |
+
},
|
1011 |
+
"execution_count": 39,
|
1012 |
+
"metadata": {},
|
1013 |
+
"output_type": "execute_result"
|
1014 |
+
}
|
1015 |
+
],
|
1016 |
+
"source": [
|
1017 |
+
"Falcon['validation'][0]\n"
|
1018 |
+
]
|
1019 |
+
},
|
1020 |
+
{
|
1021 |
+
"cell_type": "code",
|
1022 |
+
"execution_count": 41,
|
1023 |
+
"metadata": {},
|
1024 |
+
"outputs": [
|
1025 |
+
{
|
1026 |
+
"name": "stderr",
|
1027 |
+
"output_type": "stream",
|
1028 |
+
"text": [
|
1029 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
1030 |
+
]
|
1031 |
+
}
|
1032 |
+
],
|
1033 |
+
"source": [
|
1034 |
+
"\"\"\"The next step is to load a DistilGPT2 tokenizer to process the `text` subfield:\"\"\"\n",
|
1035 |
+
"\n",
|
1036 |
+
"from transformers import AutoTokenizer, GPT2TokenizerFast\n",
|
1037 |
+
"\n",
|
1038 |
+
"# tokenizer = AutoTokenizer.from_pretrained(\"distilgpt2\")\n",
|
1039 |
+
"\n",
|
1040 |
+
"\n",
|
1041 |
+
"tokenizer = GPT2TokenizerFast.from_pretrained(\"Xenova/gpt-4\")#, cache_dir=cache_dir)\n",
|
1042 |
+
"tokenizer.pad_token = tokenizer.eos_token\n"
|
1043 |
+
]
|
1044 |
+
},
|
1045 |
+
{
|
1046 |
+
"cell_type": "code",
|
1047 |
+
"execution_count": 42,
|
1048 |
+
"metadata": {},
|
1049 |
+
"outputs": [
|
1050 |
+
{
|
1051 |
+
"data": {
|
1052 |
+
"text/plain": [
|
1053 |
+
"{'Text': 'School Picture Gallery\\nFrance Ski School\\nChildren from Year 5 & 6 travelled to France from Newcastle airport to take part in a week of Ski School. The children had already spent 3 weeks learning the basics of skiing at Silksworth Ski School in Sunderland. When the children arrived in France they took part in a daily Ski School, during which the children made OUTSTANDING progress. The children also took part in French activities, explored local landmarks and took part in shopping activities in Chamonix. It was an incredible adventure for the children and staff!'}"
|
1054 |
+
]
|
1055 |
+
},
|
1056 |
+
"execution_count": 42,
|
1057 |
+
"metadata": {},
|
1058 |
+
"output_type": "execute_result"
|
1059 |
+
}
|
1060 |
+
],
|
1061 |
+
"source": [
|
1062 |
+
"Falcon = Falcon.flatten()\n",
|
1063 |
+
"Falcon[\"train\"][0]"
|
1064 |
+
]
|
1065 |
+
},
|
1066 |
+
{
|
1067 |
+
"cell_type": "code",
|
1068 |
+
"execution_count": 43,
|
1069 |
+
"metadata": {},
|
1070 |
+
"outputs": [
|
1071 |
+
{
|
1072 |
+
"name": "stdout",
|
1073 |
+
"output_type": "stream",
|
1074 |
+
"text": [
|
1075 |
+
"The OrderedVocab you are attempting to save contains holes for indices [100256, 100261, 100262, 100263, 100266, 100267, 100268, 100269, 100270, 100271, 100272, 100273, 100274, 100275], your vocabulary could be corrupted !\n"
|
1076 |
+
]
|
1077 |
+
},
|
1078 |
+
{
|
1079 |
+
"data": {
|
1080 |
+
"application/vnd.jupyter.widget-view+json": {
|
1081 |
+
"model_id": "d2182d4fa561406ab7eb5fc6c19c6d17",
|
1082 |
+
"version_major": 2,
|
1083 |
+
"version_minor": 0
|
1084 |
+
},
|
1085 |
+
"text/plain": [
|
1086 |
+
"Map (num_proc=4): 0%| | 0/10000 [00:00<?, ? examples/s]"
|
1087 |
+
]
|
1088 |
+
},
|
1089 |
+
"metadata": {},
|
1090 |
+
"output_type": "display_data"
|
1091 |
+
},
|
1092 |
+
{
|
1093 |
+
"name": "stderr",
|
1094 |
+
"output_type": "stream",
|
1095 |
+
"text": [
|
1096 |
+
"Token indices sequence length is longer than the specified maximum sequence length for this model (10412 > 8192). Running this sequence through the model will result in indexing errors\n",
|
1097 |
+
"Token indices sequence length is longer than the specified maximum sequence length for this model (10738 > 8192). Running this sequence through the model will result in indexing errors\n",
|
1098 |
+
"Token indices sequence length is longer than the specified maximum sequence length for this model (12860 > 8192). Running this sequence through the model will result in indexing errors\n",
|
1099 |
+
"Token indices sequence length is longer than the specified maximum sequence length for this model (23091 > 8192). Running this sequence through the model will result in indexing errors\n"
|
1100 |
+
]
|
1101 |
+
},
|
1102 |
+
{
|
1103 |
+
"name": "stdout",
|
1104 |
+
"output_type": "stream",
|
1105 |
+
"text": [
|
1106 |
+
"The OrderedVocab you are attempting to save contains holes for indices [100256, 100261, 100262, 100263, 100266, 100267, 100268, 100269, 100270, 100271, 100272, 100273, 100274, 100275], your vocabulary could be corrupted !\n"
|
1107 |
+
]
|
1108 |
+
},
|
1109 |
+
{
|
1110 |
+
"data": {
|
1111 |
+
"application/vnd.jupyter.widget-view+json": {
|
1112 |
+
"model_id": "121ffe72baf143f4aeea4616bee88405",
|
1113 |
+
"version_major": 2,
|
1114 |
+
"version_minor": 0
|
1115 |
+
},
|
1116 |
+
"text/plain": [
|
1117 |
+
"Map (num_proc=4): 0%| | 0/1000 [00:00<?, ? examples/s]"
|
1118 |
+
]
|
1119 |
+
},
|
1120 |
+
"metadata": {},
|
1121 |
+
"output_type": "display_data"
|
1122 |
+
},
|
1123 |
+
{
|
1124 |
+
"name": "stderr",
|
1125 |
+
"output_type": "stream",
|
1126 |
+
"text": [
|
1127 |
+
"Token indices sequence length is longer than the specified maximum sequence length for this model (9078 > 8192). Running this sequence through the model will result in indexing errors\n",
|
1128 |
+
"Token indices sequence length is longer than the specified maximum sequence length for this model (15886 > 8192). Running this sequence through the model will result in indexing errors\n",
|
1129 |
+
"Token indices sequence length is longer than the specified maximum sequence length for this model (28727 > 8192). Running this sequence through the model will result in indexing errors\n",
|
1130 |
+
"Token indices sequence length is longer than the specified maximum sequence length for this model (8257 > 8192). Running this sequence through the model will result in indexing errors\n"
|
1131 |
+
]
|
1132 |
+
}
|
1133 |
+
],
|
1134 |
+
"source": [
|
1135 |
+
"def preprocess_function(examples):\n",
|
1136 |
+
" return tokenizer([\" \".join(x) for x in examples[\"Text\"]])\n",
|
1137 |
+
"\n",
|
1138 |
+
"\n",
|
1139 |
+
"\n",
|
1140 |
+
"tokenized_Falcon = Falcon.map(\n",
|
1141 |
+
" preprocess_function,\n",
|
1142 |
+
" batched=True,\n",
|
1143 |
+
" num_proc=4,\n",
|
1144 |
+
" remove_columns=Falcon[\"train\"].column_names,\n",
|
1145 |
+
")"
|
1146 |
+
]
|
1147 |
+
},
|
1148 |
+
{
|
1149 |
+
"cell_type": "code",
|
1150 |
+
"execution_count": 44,
|
1151 |
+
"metadata": {},
|
1152 |
+
"outputs": [
|
1153 |
+
{
|
1154 |
+
"data": {
|
1155 |
+
"application/vnd.jupyter.widget-view+json": {
|
1156 |
+
"model_id": "6d7b13436ae54624bd96973987373482",
|
1157 |
+
"version_major": 2,
|
1158 |
+
"version_minor": 0
|
1159 |
+
},
|
1160 |
+
"text/plain": [
|
1161 |
+
"Map (num_proc=4): 0%| | 0/10000 [00:00<?, ? examples/s]"
|
1162 |
+
]
|
1163 |
+
},
|
1164 |
+
"metadata": {},
|
1165 |
+
"output_type": "display_data"
|
1166 |
+
},
|
1167 |
+
{
|
1168 |
+
"data": {
|
1169 |
+
"application/vnd.jupyter.widget-view+json": {
|
1170 |
+
"model_id": "beade64b537441ef99a54830bb66eef2",
|
1171 |
+
"version_major": 2,
|
1172 |
+
"version_minor": 0
|
1173 |
+
},
|
1174 |
+
"text/plain": [
|
1175 |
+
"Map (num_proc=4): 0%| | 0/1000 [00:00<?, ? examples/s]"
|
1176 |
+
]
|
1177 |
+
},
|
1178 |
+
"metadata": {},
|
1179 |
+
"output_type": "display_data"
|
1180 |
+
}
|
1181 |
+
],
|
1182 |
+
"source": [
|
1183 |
+
"# block_size = tokenizer.model_max_length\n",
|
1184 |
+
"block_size = 2048\n",
|
1185 |
+
"\n",
|
1186 |
+
"\n",
|
1187 |
+
"def group_texts(examples):\n",
|
1188 |
+
" # Concatenate all texts.\n",
|
1189 |
+
" concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}\n",
|
1190 |
+
" total_length = len(concatenated_examples[list(examples.keys())[0]])\n",
|
1191 |
+
" # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can\n",
|
1192 |
+
" # customize this part to your needs.\n",
|
1193 |
+
" if total_length >= block_size:\n",
|
1194 |
+
" total_length = (total_length // block_size) * block_size\n",
|
1195 |
+
" # Split by chunks of block_size.\n",
|
1196 |
+
" result = {\n",
|
1197 |
+
" k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n",
|
1198 |
+
" for k, t in concatenated_examples.items()\n",
|
1199 |
+
" }\n",
|
1200 |
+
" result[\"labels\"] = result[\"input_ids\"].copy()\n",
|
1201 |
+
" return result\n",
|
1202 |
+
"\n",
|
1203 |
+
"\"\"\"Apply the `group_texts` function over the entire dataset:\"\"\"\n",
|
1204 |
+
"\n",
|
1205 |
+
"lm_dataset = tokenized_Falcon.map(group_texts, batched=True, num_proc=4)\n"
|
1206 |
+
]
|
1207 |
+
},
|
1208 |
+
{
|
1209 |
+
"cell_type": "code",
|
1210 |
+
"execution_count": 45,
|
1211 |
+
"metadata": {},
|
1212 |
+
"outputs": [],
|
1213 |
+
"source": [
|
1214 |
+
"from transformers import DataCollatorForLanguageModeling\n",
|
1215 |
+
"\n",
|
1216 |
+
"# tokenizer.pad_token = tokenizer.eos_token\n",
|
1217 |
+
"data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)\n"
|
1218 |
+
]
|
1219 |
+
},
|
1220 |
+
{
|
1221 |
+
"cell_type": "code",
|
1222 |
+
"execution_count": null,
|
1223 |
+
"metadata": {},
|
1224 |
+
"outputs": [],
|
1225 |
+
"source": [
|
1226 |
+
"# from transformers import AutoModelForCausalLM, TrainingArguments, Trainer\n",
|
1227 |
+
"# import torch\n",
|
1228 |
+
"# model = AutoModelForCausalLM.from_pretrained(\"tensorplex-labs/pretraining-sn9-7B-5\", torch_dtype=torch.bfloat16)\n",
|
1229 |
+
"\n",
|
1230 |
+
"# print('Model Loaded!')\n"
|
1231 |
+
]
|
1232 |
+
},
|
1233 |
+
{
|
1234 |
+
"cell_type": "code",
|
1235 |
+
"execution_count": 46,
|
1236 |
+
"metadata": {},
|
1237 |
+
"outputs": [
|
1238 |
+
{
|
1239 |
+
"data": {
|
1240 |
+
"text/plain": [
|
1241 |
+
"MistralForCausalLM(\n",
|
1242 |
+
" (model): MistralModel(\n",
|
1243 |
+
" (embed_tokens): Embedding(32000, 4096)\n",
|
1244 |
+
" (layers): ModuleList(\n",
|
1245 |
+
" (0-29): 30 x MistralDecoderLayer(\n",
|
1246 |
+
" (self_attn): MistralSdpaAttention(\n",
|
1247 |
+
" (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
|
1248 |
+
" (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
|
1249 |
+
" (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
|
1250 |
+
" (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
|
1251 |
+
" (rotary_emb): MistralRotaryEmbedding()\n",
|
1252 |
+
" )\n",
|
1253 |
+
" (mlp): MistralMLP(\n",
|
1254 |
+
" (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
|
1255 |
+
" (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
|
1256 |
+
" (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
|
1257 |
+
" (act_fn): SiLU()\n",
|
1258 |
+
" )\n",
|
1259 |
+
" (input_layernorm): MistralRMSNorm()\n",
|
1260 |
+
" (post_attention_layernorm): MistralRMSNorm()\n",
|
1261 |
+
" )\n",
|
1262 |
+
" )\n",
|
1263 |
+
" (norm): MistralRMSNorm()\n",
|
1264 |
+
" )\n",
|
1265 |
+
" (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n",
|
1266 |
+
")"
|
1267 |
+
]
|
1268 |
+
},
|
1269 |
+
"execution_count": 46,
|
1270 |
+
"metadata": {},
|
1271 |
+
"output_type": "execute_result"
|
1272 |
+
}
|
1273 |
+
],
|
1274 |
+
"source": [
|
1275 |
+
"model.to('cuda')"
|
1276 |
+
]
|
1277 |
+
},
|
1278 |
+
{
|
1279 |
+
"cell_type": "code",
|
1280 |
+
"execution_count": 47,
|
1281 |
+
"metadata": {},
|
1282 |
+
"outputs": [
|
1283 |
+
{
|
1284 |
+
"data": {
|
1285 |
+
"text/plain": [
|
1286 |
+
"6805508096"
|
1287 |
+
]
|
1288 |
+
},
|
1289 |
+
"execution_count": 47,
|
1290 |
+
"metadata": {},
|
1291 |
+
"output_type": "execute_result"
|
1292 |
+
}
|
1293 |
+
],
|
1294 |
+
"source": [
|
1295 |
+
"pytorch_total_params = sum(p.numel() for p in model.parameters())\n",
|
1296 |
+
"pytorch_total_params"
|
1297 |
+
]
|
1298 |
+
},
|
1299 |
+
{
|
1300 |
+
"cell_type": "code",
|
1301 |
+
"execution_count": 48,
|
1302 |
+
"metadata": {},
|
1303 |
+
"outputs": [],
|
1304 |
+
"source": [
|
1305 |
+
"training_args = TrainingArguments(\n",
|
1306 |
+
" output_dir=\"Fine-Tuned-S9-2\",\n",
|
1307 |
+
" overwrite_output_dir=True,\n",
|
1308 |
+
" bf16=True,\n",
|
1309 |
+
" # evaluation_strategy=\"epoch\",\n",
|
1310 |
+
" evaluation_strategy=\"steps\",\n",
|
1311 |
+
" learning_rate=2e-5,\n",
|
1312 |
+
" weight_decay=0.01,\n",
|
1313 |
+
" num_train_epochs=1,\n",
|
1314 |
+
" per_device_train_batch_size=2,\n",
|
1315 |
+
" per_device_eval_batch_size=2,\n",
|
1316 |
+
" lr_scheduler_type = 'cosine',\n",
|
1317 |
+
" push_to_hub=False,\n",
|
1318 |
+
" save_total_limit = 2,\n",
|
1319 |
+
" # save_strategy = “no”\n",
|
1320 |
+
" load_best_model_at_end=False,\n",
|
1321 |
+
")\n",
|
1322 |
+
"\n",
|
1323 |
+
"trainer = Trainer(\n",
|
1324 |
+
" model=model,\n",
|
1325 |
+
" args=training_args,\n",
|
1326 |
+
" train_dataset=lm_dataset[\"train\"],\n",
|
1327 |
+
" eval_dataset=lm_dataset[\"validation\"],\n",
|
1328 |
+
" # eval_dataset=lm_dataset[\"test\"],\n",
|
1329 |
+
" data_collator=data_collator,\n",
|
1330 |
+
")"
|
1331 |
+
]
|
1332 |
+
},
|
1333 |
+
{
|
1334 |
+
"cell_type": "code",
|
1335 |
+
"execution_count": 49,
|
1336 |
+
"metadata": {},
|
1337 |
+
"outputs": [
|
1338 |
+
{
|
1339 |
+
"name": "stdout",
|
1340 |
+
"output_type": "stream",
|
1341 |
+
"text": [
|
1342 |
+
"Started Training!\n"
|
1343 |
+
]
|
1344 |
+
},
|
1345 |
+
{
|
1346 |
+
"name": "stderr",
|
1347 |
+
"output_type": "stream",
|
1348 |
+
"text": [
|
1349 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mthatmlguy\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
|
1350 |
+
]
|
1351 |
+
},
|
1352 |
+
{
|
1353 |
+
"data": {
|
1354 |
+
"text/html": [
|
1355 |
+
"Tracking run with wandb version 0.17.0"
|
1356 |
+
],
|
1357 |
+
"text/plain": [
|
1358 |
+
"<IPython.core.display.HTML object>"
|
1359 |
+
]
|
1360 |
+
},
|
1361 |
+
"metadata": {},
|
1362 |
+
"output_type": "display_data"
|
1363 |
+
},
|
1364 |
+
{
|
1365 |
+
"data": {
|
1366 |
+
"text/html": [
|
1367 |
+
"Run data is saved locally in <code>/workspace/ShortGPT/short_gpt/wandb/run-20240516_090043-ni1hktjg</code>"
|
1368 |
+
],
|
1369 |
+
"text/plain": [
|
1370 |
+
"<IPython.core.display.HTML object>"
|
1371 |
+
]
|
1372 |
+
},
|
1373 |
+
"metadata": {},
|
1374 |
+
"output_type": "display_data"
|
1375 |
+
},
|
1376 |
+
{
|
1377 |
+
"data": {
|
1378 |
+
"text/html": [
|
1379 |
+
"Syncing run <strong><a href='https://wandb.ai/thatmlguy/huggingface/runs/ni1hktjg' target=\"_blank\">misty-serenity-4</a></strong> to <a href='https://wandb.ai/thatmlguy/huggingface' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
|
1380 |
+
],
|
1381 |
+
"text/plain": [
|
1382 |
+
"<IPython.core.display.HTML object>"
|
1383 |
+
]
|
1384 |
+
},
|
1385 |
+
"metadata": {},
|
1386 |
+
"output_type": "display_data"
|
1387 |
+
},
|
1388 |
+
{
|
1389 |
+
"data": {
|
1390 |
+
"text/html": [
|
1391 |
+
" View project at <a href='https://wandb.ai/thatmlguy/huggingface' target=\"_blank\">https://wandb.ai/thatmlguy/huggingface</a>"
|
1392 |
+
],
|
1393 |
+
"text/plain": [
|
1394 |
+
"<IPython.core.display.HTML object>"
|
1395 |
+
]
|
1396 |
+
},
|
1397 |
+
"metadata": {},
|
1398 |
+
"output_type": "display_data"
|
1399 |
+
},
|
1400 |
+
{
|
1401 |
+
"data": {
|
1402 |
+
"text/html": [
|
1403 |
+
" View run at <a href='https://wandb.ai/thatmlguy/huggingface/runs/ni1hktjg' target=\"_blank\">https://wandb.ai/thatmlguy/huggingface/runs/ni1hktjg</a>"
|
1404 |
+
],
|
1405 |
+
"text/plain": [
|
1406 |
+
"<IPython.core.display.HTML object>"
|
1407 |
+
]
|
1408 |
+
},
|
1409 |
+
"metadata": {},
|
1410 |
+
"output_type": "display_data"
|
1411 |
+
},
|
1412 |
+
{
|
1413 |
+
"data": {
|
1414 |
+
"text/html": [
|
1415 |
+
"\n",
|
1416 |
+
" <div>\n",
|
1417 |
+
" \n",
|
1418 |
+
" <progress value='2' max='6459' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
1419 |
+
" [ 2/6459 : < :, Epoch 0.00/1]\n",
|
1420 |
+
" </div>\n",
|
1421 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
1422 |
+
" <thead>\n",
|
1423 |
+
" <tr style=\"text-align: left;\">\n",
|
1424 |
+
" <th>Step</th>\n",
|
1425 |
+
" <th>Training Loss</th>\n",
|
1426 |
+
" <th>Validation Loss</th>\n",
|
1427 |
+
" </tr>\n",
|
1428 |
+
" </thead>\n",
|
1429 |
+
" <tbody>\n",
|
1430 |
+
" </tbody>\n",
|
1431 |
+
"</table><p>"
|
1432 |
+
],
|
1433 |
+
"text/plain": [
|
1434 |
+
"<IPython.core.display.HTML object>"
|
1435 |
+
]
|
1436 |
+
},
|
1437 |
+
"metadata": {},
|
1438 |
+
"output_type": "display_data"
|
1439 |
+
},
|
1440 |
+
{
|
1441 |
+
"ename": "OutOfMemoryError",
|
1442 |
+
"evalue": "CUDA out of memory. Tried to allocate 112.00 MiB. GPU ",
|
1443 |
+
"output_type": "error",
|
1444 |
+
"traceback": [
|
1445 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
1446 |
+
"\u001b[0;31mOutOfMemoryError\u001b[0m Traceback (most recent call last)",
|
1447 |
+
"Cell \u001b[0;32mIn[49], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# trainer.train()\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mStarted Training!\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m----> 3\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
|
1448 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:1859\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1857\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n\u001b[1;32m 1858\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1859\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1860\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1861\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1862\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1863\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1864\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
|
1449 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2203\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 2200\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_step_begin(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[1;32m 2202\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39maccumulate(model):\n\u001b[0;32m-> 2203\u001b[0m tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2205\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 2206\u001b[0m args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m 2207\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_xla_available()\n\u001b[1;32m 2208\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch\u001b[38;5;241m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m 2209\u001b[0m ):\n\u001b[1;32m 2210\u001b[0m \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m 2211\u001b[0m tr_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n",
|
1450 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:3138\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(self, model, inputs)\u001b[0m\n\u001b[1;32m 3135\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss_mb\u001b[38;5;241m.\u001b[39mreduce_mean()\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m 3137\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcompute_loss_context_manager():\n\u001b[0;32m-> 3138\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3140\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mn_gpu \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 3141\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mmean() \u001b[38;5;66;03m# mean() to average on multi-gpu parallel training\u001b[39;00m\n",
|
1451 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:3161\u001b[0m, in \u001b[0;36mTrainer.compute_loss\u001b[0;34m(self, model, inputs, return_outputs)\u001b[0m\n\u001b[1;32m 3159\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 3160\u001b[0m labels \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m-> 3161\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3162\u001b[0m \u001b[38;5;66;03m# Save past state if it exists\u001b[39;00m\n\u001b[1;32m 3163\u001b[0m \u001b[38;5;66;03m# TODO: this needs to be fixed and made cleaner later.\u001b[39;00m\n\u001b[1;32m 3164\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mpast_index \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m:\n",
|
1452 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
1453 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
1454 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py:822\u001b[0m, in \u001b[0;36mconvert_outputs_to_fp32.<locals>.forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 821\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 822\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmodel_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
1455 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py:810\u001b[0m, in \u001b[0;36mConvertOutputsToFp32.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 809\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 810\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m convert_to_fp32(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m)\n",
|
1456 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py:16\u001b[0m, in \u001b[0;36mautocast_decorator.<locals>.decorate_autocast\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_autocast\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m autocast_instance:\n\u001b[0;32m---> 16\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
1457 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py:1158\u001b[0m, in \u001b[0;36mMistralForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1155\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m 1157\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[0;32m-> 1158\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1159\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1160\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1161\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1162\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1163\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1164\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1165\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1166\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1167\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1168\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1170\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1171\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlm_head(hidden_states)\n",
|
1458 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
1459 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
1460 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py:1043\u001b[0m, in \u001b[0;36mMistralModel.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1033\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[1;32m 1034\u001b[0m decoder_layer\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[1;32m 1035\u001b[0m hidden_states,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1040\u001b[0m use_cache,\n\u001b[1;32m 1041\u001b[0m )\n\u001b[1;32m 1042\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1043\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1044\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1045\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1046\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1047\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1048\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1049\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1050\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1052\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1054\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_cache:\n",
|
1461 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
1462 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
1463 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py:770\u001b[0m, in \u001b[0;36mMistralDecoderLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)\u001b[0m\n\u001b[1;32m 768\u001b[0m residual \u001b[38;5;241m=\u001b[39m hidden_states\n\u001b[1;32m 769\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpost_attention_layernorm(hidden_states)\n\u001b[0;32m--> 770\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmlp\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 771\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m residual \u001b[38;5;241m+\u001b[39m hidden_states\n\u001b[1;32m 773\u001b[0m outputs \u001b[38;5;241m=\u001b[39m (hidden_states,)\n",
|
1464 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
1465 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
1466 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py:179\u001b[0m, in \u001b[0;36mMistralMLP.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[0;32m--> 179\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdown_proj(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mact_fn(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgate_proj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m) \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mup_proj(x))\n",
|
1467 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
1468 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
1469 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:116\u001b[0m, in \u001b[0;36mLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
|
1470 |
+
"\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 112.00 MiB. GPU "
|
1471 |
+
]
|
1472 |
+
}
|
1473 |
+
],
|
1474 |
+
"source": [
|
1475 |
+
"# trainer.train()\n",
|
1476 |
+
"print('Started Training!')\n",
|
1477 |
+
"trainer.train()"
|
1478 |
+
]
|
1479 |
+
},
|
1480 |
+
{
|
1481 |
+
"cell_type": "code",
|
1482 |
+
"execution_count": null,
|
1483 |
+
"metadata": {},
|
1484 |
+
"outputs": [],
|
1485 |
+
"source": [
|
1486 |
+
"import math\n",
|
1487 |
+
"\n",
|
1488 |
+
"eval_results = trainer.evaluate()\n",
|
1489 |
+
"print(f\"Perplexity: {math.exp(eval_results['eval_loss']):.2f}\")\n"
|
1490 |
+
]
|
1491 |
+
},
|
1492 |
+
{
|
1493 |
+
"cell_type": "code",
|
1494 |
+
"execution_count": 29,
|
1495 |
+
"metadata": {},
|
1496 |
+
"outputs": [],
|
1497 |
+
"source": [
|
1498 |
+
"# # referencing https://github.com/meta-llama/llama-recipes/blob/main/recipes/finetuning/huggingface_trainer/peft_finetuning.ipynb\n",
|
1499 |
+
"# eval_prompt = \"\"\"\n",
|
1500 |
+
"# Summarize this dialog:\n",
|
1501 |
+
"# A: Hi Tom, are you busy tomorrow's afternoon?\n",
|
1502 |
+
"# B: I'm pretty sure I am. What's up?\n",
|
1503 |
+
"# A: Can you go with me to the animal shelter?.\n",
|
1504 |
+
"# B: What do you want to do?\n",
|
1505 |
+
"# A: I want to get a puppy for my son.\n",
|
1506 |
+
"# B: That will make him so happy.\n",
|
1507 |
+
"# A: Yeah, we've discussed it many times. I think he's ready now.\n",
|
1508 |
+
"# B: That's good. Raising a dog is a tough issue. Like having a baby ;-) \n",
|
1509 |
+
"# A: I'll get him one of those little dogs.\n",
|
1510 |
+
"# B: One that won't grow up too big;-)\n",
|
1511 |
+
"# A: And eat too much;-))\n",
|
1512 |
+
"# B: Do you know which one he would like?\n",
|
1513 |
+
"# A: Oh, yes, I took him there last Monday. He showed me one that he really liked.\n",
|
1514 |
+
"# B: I bet you had to drag him away.\n",
|
1515 |
+
"# A: He wanted to take it home right away ;-).\n",
|
1516 |
+
"# B: I wonder what he'll name it.\n",
|
1517 |
+
"# A: He said he'd name it after his dead hamster - Lemmy - he's a great Motorhead fan :-)))\n",
|
1518 |
+
"# ---\n",
|
1519 |
+
"# Summary:\n",
|
1520 |
+
"# \"\"\"\n",
|
1521 |
+
"\n",
|
1522 |
+
"# model_input = tokenizer(eval_prompt, return_tensors=\"pt\").to(\"cuda\")\n",
|
1523 |
+
"\n",
|
1524 |
+
"# model.eval()\n",
|
1525 |
+
"# with torch.no_grad():\n",
|
1526 |
+
"# print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100, use_cache=True)[0], skip_special_tokens=True))"
|
1527 |
+
]
|
1528 |
+
},
|
1529 |
+
{
|
1530 |
+
"cell_type": "code",
|
1531 |
+
"execution_count": 30,
|
1532 |
+
"metadata": {},
|
1533 |
+
"outputs": [],
|
1534 |
+
"source": [
|
1535 |
+
"# def get_preprocessed_samsum():\n",
|
1536 |
+
"# dataset = load_dataset(\"samsum\", split=\"train\")\n",
|
1537 |
+
"\n",
|
1538 |
+
"# prompt = (\n",
|
1539 |
+
"# f\"Summarize this dialog:\\n{{dialog}}\\n---\\nSummary:\\n\"\n",
|
1540 |
+
"# )\n",
|
1541 |
+
"\n",
|
1542 |
+
"# def apply_prompt_template(sample):\n",
|
1543 |
+
"# return {\n",
|
1544 |
+
"# \"prompt\": prompt.format(dialog=sample[\"dialogue\"]),\n",
|
1545 |
+
"# \"summary\": sample[\"summary\"],\n",
|
1546 |
+
"# }\n",
|
1547 |
+
"\n",
|
1548 |
+
"# dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))\n",
|
1549 |
+
"\n",
|
1550 |
+
"# def tokenize_add_label(sample):\n",
|
1551 |
+
"# prompt = tokenizer.encode(tokenizer.bos_token + sample[\"prompt\"], add_special_tokens=False)\n",
|
1552 |
+
"# summary = tokenizer.encode(sample[\"summary\"] + tokenizer.eos_token, add_special_tokens=False)\n",
|
1553 |
+
"# sample = {\n",
|
1554 |
+
"# \"input_ids\": prompt + summary,\n",
|
1555 |
+
"# \"attention_mask\" : [1] * (len(prompt) + len(summary)),\n",
|
1556 |
+
"# \"labels\": [-100] * len(prompt) + summary,\n",
|
1557 |
+
"# }\n",
|
1558 |
+
"\n",
|
1559 |
+
"# return sample\n",
|
1560 |
+
"\n",
|
1561 |
+
"# dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))\n",
|
1562 |
+
"\n",
|
1563 |
+
"# return dataset"
|
1564 |
+
]
|
1565 |
+
},
|
1566 |
+
{
|
1567 |
+
"cell_type": "code",
|
1568 |
+
"execution_count": 31,
|
1569 |
+
"metadata": {},
|
1570 |
+
"outputs": [],
|
1571 |
+
"source": [
|
1572 |
+
"# model.train()\n",
|
1573 |
+
"\n",
|
1574 |
+
"# def create_peft_config(model):\n",
|
1575 |
+
"# peft_config = LoraConfig(\n",
|
1576 |
+
"# task_type=TaskType.CAUSAL_LM,\n",
|
1577 |
+
"# inference_mode=False,\n",
|
1578 |
+
"# r=8,\n",
|
1579 |
+
"# lora_alpha=32,\n",
|
1580 |
+
"# lora_dropout=0.05,\n",
|
1581 |
+
"# target_modules = [\"q_proj\", \"v_proj\"]\n",
|
1582 |
+
"# )\n",
|
1583 |
+
"\n",
|
1584 |
+
"# model = get_peft_model(model, peft_config)\n",
|
1585 |
+
"# model.print_trainable_parameters()\n",
|
1586 |
+
"# return model, peft_config\n",
|
1587 |
+
"\n",
|
1588 |
+
"# # create peft config\n",
|
1589 |
+
"# model, lora_config = create_peft_config(model)"
|
1590 |
+
]
|
1591 |
+
},
|
1592 |
+
{
|
1593 |
+
"cell_type": "code",
|
1594 |
+
"execution_count": 32,
|
1595 |
+
"metadata": {},
|
1596 |
+
"outputs": [],
|
1597 |
+
"source": [
|
1598 |
+
"# output_dir = \"tmp/\"\n",
|
1599 |
+
"\n",
|
1600 |
+
"# config = {\n",
|
1601 |
+
"# 'lora_config': lora_config,\n",
|
1602 |
+
"# 'learning_rate': 1e-6,\n",
|
1603 |
+
"# 'num_train_epochs': 1,\n",
|
1604 |
+
"# 'per_device_train_batch_size': 1,\n",
|
1605 |
+
"# 'gradient_checkpointing': False,\n",
|
1606 |
+
"# }\n"
|
1607 |
+
]
|
1608 |
+
},
|
1609 |
+
{
|
1610 |
+
"cell_type": "code",
|
1611 |
+
"execution_count": 33,
|
1612 |
+
"metadata": {},
|
1613 |
+
"outputs": [],
|
1614 |
+
"source": [
|
1615 |
+
"# training_args = TrainingArguments(\n",
|
1616 |
+
"# output_dir=output_dir,\n",
|
1617 |
+
"# overwrite_output_dir=True,\n",
|
1618 |
+
"# # logging strategies\n",
|
1619 |
+
"# logging_strategy=\"steps\",\n",
|
1620 |
+
"# logging_steps=10,\n",
|
1621 |
+
"# save_strategy=\"no\",\n",
|
1622 |
+
"# optim=\"adamw_torch_fused\",\n",
|
1623 |
+
"# **{k:v for k,v in config.items() if k != 'lora_config'}\n",
|
1624 |
+
"# )\n",
|
1625 |
+
"\n",
|
1626 |
+
"# # Create Trainer instance\n",
|
1627 |
+
"# trainer = Trainer(\n",
|
1628 |
+
"# model=model,\n",
|
1629 |
+
"# args=training_args,\n",
|
1630 |
+
"# train_dataset=get_preprocessed_samsum(),\n",
|
1631 |
+
"# data_collator=default_data_collator,\n",
|
1632 |
+
"# callbacks=[],\n",
|
1633 |
+
"# )\n",
|
1634 |
+
"\n",
|
1635 |
+
"# # Start training\n",
|
1636 |
+
"# trainer.train()"
|
1637 |
+
]
|
1638 |
+
},
|
1639 |
+
{
|
1640 |
+
"cell_type": "code",
|
1641 |
+
"execution_count": 34,
|
1642 |
+
"metadata": {},
|
1643 |
+
"outputs": [],
|
1644 |
+
"source": [
|
1645 |
+
"# model.eval()\n",
|
1646 |
+
"# with torch.no_grad():\n",
|
1647 |
+
"# print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))"
|
1648 |
+
]
|
1649 |
+
},
|
1650 |
+
{
|
1651 |
+
"cell_type": "code",
|
1652 |
+
"execution_count": null,
|
1653 |
+
"metadata": {},
|
1654 |
+
"outputs": [],
|
1655 |
+
"source": []
|
1656 |
+
}
|
1657 |
+
],
|
1658 |
+
"metadata": {
|
1659 |
+
"kernelspec": {
|
1660 |
+
"display_name": "Python 3 (ipykernel)",
|
1661 |
+
"language": "python",
|
1662 |
+
"name": "python3"
|
1663 |
+
},
|
1664 |
+
"language_info": {
|
1665 |
+
"codemirror_mode": {
|
1666 |
+
"name": "ipython",
|
1667 |
+
"version": 3
|
1668 |
+
},
|
1669 |
+
"file_extension": ".py",
|
1670 |
+
"mimetype": "text/x-python",
|
1671 |
+
"name": "python",
|
1672 |
+
"nbconvert_exporter": "python",
|
1673 |
+
"pygments_lexer": "ipython3",
|
1674 |
+
"version": "3.10.12"
|
1675 |
+
}
|
1676 |
+
},
|
1677 |
+
"nbformat": 4,
|
1678 |
+
"nbformat_minor": 4
|
1679 |
+
}
|
short_gpt/.ipynb_checkpoints/short_llama-checkpoint.py
ADDED
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from llama import Llama, Transformer
|
7 |
+
|
8 |
+
from metrics import *
|
9 |
+
|
10 |
+
|
11 |
+
def sample_top_p(probs: torch.Tensor, p: float):
|
12 |
+
"""
|
13 |
+
Perform top-p (nucleus) sampling on a probability distribution.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
probs (torch.Tensor): Probability distribution tensor.
|
17 |
+
p (float): Probability threshold for top-p sampling.
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
torch.Tensor: Sampled token indices.
|
21 |
+
|
22 |
+
Note:
|
23 |
+
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
|
24 |
+
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
|
25 |
+
|
26 |
+
"""
|
27 |
+
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
28 |
+
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
29 |
+
mask = probs_sum - probs_sort > p
|
30 |
+
probs_sort[mask] = 0.0
|
31 |
+
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
32 |
+
next_token = torch.multinomial(probs_sort, num_samples=1)
|
33 |
+
next_token = torch.gather(probs_idx, -1, next_token)
|
34 |
+
return next_token
|
35 |
+
|
36 |
+
|
37 |
+
class TransformerWrapper(Transformer):
|
38 |
+
def __init__(self, model):
|
39 |
+
self.__dict__ = model.__dict__.copy()
|
40 |
+
|
41 |
+
@torch.inference_mode()
|
42 |
+
def forward(
|
43 |
+
self,
|
44 |
+
tokens: torch.Tensor,
|
45 |
+
start_pos: int,
|
46 |
+
return_hiddens: Optional[bool] = False):
|
47 |
+
"""
|
48 |
+
Perform a forward pass through the Transformer model.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
tokens (torch.Tensor): Input token indices.
|
52 |
+
start_pos (int): Starting position for attention caching.
|
53 |
+
(Optional) return_hiddens (bool): Whether to return hidden states. Defaults to False.
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
torch.Tensor: Output logits after applying the Transformer model.
|
57 |
+
(Optional) List[torch.Tensor]: Hidden states for each transformer block.
|
58 |
+
"""
|
59 |
+
_bsz, seqlen = tokens.shape
|
60 |
+
h = self.tok_embeddings(tokens)
|
61 |
+
self.freqs_cis = self.freqs_cis.to(h.device)
|
62 |
+
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
|
63 |
+
|
64 |
+
mask = None
|
65 |
+
if seqlen > 1:
|
66 |
+
mask = torch.full(
|
67 |
+
(seqlen, seqlen), float("-inf"), device=tokens.device
|
68 |
+
)
|
69 |
+
|
70 |
+
mask = torch.triu(mask, diagonal=1)
|
71 |
+
|
72 |
+
# When performing key-value caching, we compute the attention scores
|
73 |
+
# only for the new sequence. Thus, the matrix of scores is of size
|
74 |
+
# (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
|
75 |
+
# j > cache_len + i, since row i corresponds to token cache_len + i.
|
76 |
+
mask = torch.hstack([
|
77 |
+
torch.zeros((seqlen, start_pos), device=tokens.device),
|
78 |
+
mask
|
79 |
+
]).type_as(h)
|
80 |
+
|
81 |
+
hiddens = [h]
|
82 |
+
for layer in self.layers:
|
83 |
+
h = layer(h, start_pos, freqs_cis, mask)
|
84 |
+
if return_hiddens:
|
85 |
+
hiddens.append(h)
|
86 |
+
|
87 |
+
h = self.norm(h)
|
88 |
+
output = self.output(h).float()
|
89 |
+
|
90 |
+
if return_hiddens:
|
91 |
+
return output, hiddens
|
92 |
+
|
93 |
+
return output
|
94 |
+
|
95 |
+
|
96 |
+
class ShortLlama():
|
97 |
+
|
98 |
+
def __init__(self, llama: Llama, n_prune_layers: Optional[int] = None):
|
99 |
+
checkpoint = llama.model.state_dict()
|
100 |
+
llama.model = TransformerWrapper(llama.model) # wrap transformer to collect hidden states
|
101 |
+
llama.model.load_state_dict(checkpoint, strict=False)
|
102 |
+
self.llama = llama
|
103 |
+
|
104 |
+
self.n_prune_layers = n_prune_layers
|
105 |
+
self.importances = [0 for _ in self.llama.model.layers] # layer-wise importance scores
|
106 |
+
|
107 |
+
def remove_layers(
|
108 |
+
self,
|
109 |
+
layers_to_remove: Optional[List[int]] = [],
|
110 |
+
angular: Optional[bool] = False
|
111 |
+
):
|
112 |
+
if angular:
|
113 |
+
assert self.importances, "Need to compute importances with eval_importance()"
|
114 |
+
assert self.n_prune_layers, "Need number of layers to prune, set `n_prune_layers`"
|
115 |
+
start_layer = np.argsort(np.array(self.importances[:-self.n_prune_layers+1]))[0]
|
116 |
+
layers_to_remove = list(range(start_layer, start_layer + self.n_prune_layers))
|
117 |
+
elif not layers_to_remove and self.n_prune_layers:
|
118 |
+
assert self.importances, "Need to compute importances with eval_importance()"
|
119 |
+
layers_to_remove = np.argsort(np.array(self.importances))[:self.n_prune_layers].tolist()
|
120 |
+
|
121 |
+
# remove layers in reverse to avoid indexing errors
|
122 |
+
for layer_idx in sorted(layers_to_remove, reverse=True):
|
123 |
+
try:
|
124 |
+
del self.llama.model.layers[layer_idx]
|
125 |
+
except IndexError:
|
126 |
+
print(f"layer {layer_idx} does not exist, function may have already been called")
|
127 |
+
return []
|
128 |
+
|
129 |
+
return layers_to_remove
|
130 |
+
|
131 |
+
def compute_bi(self, hiddens: List[torch.Tensor], angular: bool):
|
132 |
+
n = 1
|
133 |
+
if angular:
|
134 |
+
assert self.n_prune_layers is not None, "Set number of layers to prune to use angular importance"
|
135 |
+
n = self.n_prune_layers
|
136 |
+
|
137 |
+
for i in range(len(hiddens) - n):
|
138 |
+
in_hidden = hiddens[i]
|
139 |
+
out_hidden = hiddens[i+n]
|
140 |
+
if angular:
|
141 |
+
# use only last token for angular distance as described in section 3.2
|
142 |
+
# https://arxiv.org/pdf/2403.17887.pdf
|
143 |
+
in_hidden = in_hidden[:,-1:]
|
144 |
+
out_hidden = out_hidden[:,-1:]
|
145 |
+
|
146 |
+
self.importances[i] += block_influence(
|
147 |
+
in_hidden,
|
148 |
+
out_hidden,
|
149 |
+
angular=angular
|
150 |
+
).sum().cpu().item()
|
151 |
+
|
152 |
+
@torch.inference_mode()
|
153 |
+
def eval_importance(
|
154 |
+
self,
|
155 |
+
prompt_tokens: List[List[int]],
|
156 |
+
max_gen_len: Optional[int] = 0,
|
157 |
+
temperature: Optional[float] = 0.6,
|
158 |
+
top_p: Optional[float] = 0.9,
|
159 |
+
angular: Optional[bool] = False
|
160 |
+
):
|
161 |
+
"""
|
162 |
+
Computes layer-wise importances over input tokens.
|
163 |
+
|
164 |
+
NOTE: ShortGPT paper performs no generation during importance computation, which suggests a `max_gen_len`= 0.
|
165 |
+
|
166 |
+
Args:
|
167 |
+
prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers.
|
168 |
+
(Optional) max_gen_len (int): Maximum length of the generated text sequence.
|
169 |
+
(Optional) temperature (float): Temperature value for controlling randomness in sampling. Defaults to 0.6.
|
170 |
+
(Optional) top_p (float): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
|
171 |
+
(Optional) angular (bool): Whether to ues angular distance. Defaults to False.
|
172 |
+
|
173 |
+
Returns:
|
174 |
+
None
|
175 |
+
"""
|
176 |
+
params = self.llama.model.params
|
177 |
+
bsz = len(prompt_tokens)
|
178 |
+
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
|
179 |
+
|
180 |
+
min_prompt_len = min(len(t) for t in prompt_tokens)
|
181 |
+
max_prompt_len = max(len(t) for t in prompt_tokens)
|
182 |
+
assert max_prompt_len <= params.max_seq_len
|
183 |
+
total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
|
184 |
+
|
185 |
+
pad_id = self.llama.tokenizer.pad_id
|
186 |
+
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
|
187 |
+
for k, t in enumerate(prompt_tokens):
|
188 |
+
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
|
189 |
+
|
190 |
+
prev_pos = 0
|
191 |
+
eos_reached = torch.tensor([False] * bsz, device="cuda")
|
192 |
+
input_text_mask = tokens != pad_id
|
193 |
+
|
194 |
+
for cur_pos in range(min_prompt_len, total_len):
|
195 |
+
logits = self.llama.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
|
196 |
+
if temperature > 0:
|
197 |
+
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
|
198 |
+
next_token = sample_top_p(probs, top_p)
|
199 |
+
else:
|
200 |
+
next_token = torch.argmax(logits[:, -1], dim=-1)
|
201 |
+
|
202 |
+
next_token = next_token.reshape(-1)
|
203 |
+
# only replace token if prompt has already been generated
|
204 |
+
next_token = torch.where(
|
205 |
+
input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
|
206 |
+
)
|
207 |
+
tokens[:, cur_pos] = next_token
|
208 |
+
eos_reached |= (~input_text_mask[:, cur_pos]) & (
|
209 |
+
next_token == self.llama.tokenizer.eos_id
|
210 |
+
)
|
211 |
+
prev_pos = cur_pos
|
212 |
+
if all(eos_reached):
|
213 |
+
break
|
214 |
+
|
215 |
+
# compute block influence over full sequences rather than at each token
|
216 |
+
_, hiddens = self.llama.model.forward(tokens, 0, return_hiddens=True)
|
217 |
+
self.compute_bi(hiddens, angular=angular)
|
218 |
+
|
219 |
+
return
|
short_gpt/layer_removal.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
|
6 |
+
def layer_removal(
|
7 |
+
model: nn.Module,
|
8 |
+
layers_to_remove: OrderedDict
|
9 |
+
):
|
10 |
+
"""
|
11 |
+
Generic removal implementation
|
12 |
+
"""
|
13 |
+
|
14 |
+
for layer_name, layer_idx in layers_to_remove.items():
|
15 |
+
modules = layer_name.split(".")
|
16 |
+
mod = model
|
17 |
+
for m in modules[:-1]:
|
18 |
+
mod = getattr(mod, m)
|
19 |
+
|
20 |
+
if layer_idx is None:
|
21 |
+
del getattr(mod, modules[-1])
|
22 |
+
else:
|
23 |
+
del getattr(mod, modules[-1])[layer_idx]
|
short_gpt/metrics.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
def block_influence(
|
5 |
+
input_hidden_state: torch.Tensor,
|
6 |
+
output_hidden_state: torch.Tensor,
|
7 |
+
angular=False,
|
8 |
+
):
|
9 |
+
"""
|
10 |
+
input_hidden_state: B, S, D
|
11 |
+
output_hidden_state: B, S, D
|
12 |
+
"""
|
13 |
+
_, _, d = input_hidden_state.shape
|
14 |
+
input_hidden_state = input_hidden_state.reshape(-1, d)
|
15 |
+
output_hidden_state = output_hidden_state.reshape(-1, d)
|
16 |
+
|
17 |
+
norm_input = input_hidden_state.norm(dim=-1, keepdim=True)
|
18 |
+
norm_output = output_hidden_state.norm(dim=-1, keepdim=True)
|
19 |
+
|
20 |
+
sim = (input_hidden_state @ output_hidden_state.T) / (norm_input * norm_output)
|
21 |
+
sim = sim.diagonal().nan_to_num(nan=0.5)
|
22 |
+
|
23 |
+
if angular:
|
24 |
+
return (torch.arccos(sim) / torch.pi)
|
25 |
+
|
26 |
+
return 1 - sim
|
short_gpt/short_hf.ipynb
ADDED
@@ -0,0 +1,1679 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [
|
8 |
+
{
|
9 |
+
"name": "stdout",
|
10 |
+
"output_type": "stream",
|
11 |
+
"text": [
|
12 |
+
"Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (2.19.1)\n",
|
13 |
+
"Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.1.1)\n",
|
14 |
+
"Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.40.2)\n",
|
15 |
+
"Requirement already satisfied: peft in /usr/local/lib/python3.10/dist-packages (0.10.0)\n",
|
16 |
+
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.13.1)\n",
|
17 |
+
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.26.2)\n",
|
18 |
+
"Requirement already satisfied: pyarrow>=12.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (16.0.0)\n",
|
19 |
+
"Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets) (0.6)\n",
|
20 |
+
"Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.8)\n",
|
21 |
+
"Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.2.2)\n",
|
22 |
+
"Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.31.0)\n",
|
23 |
+
"Requirement already satisfied: tqdm>=4.62.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.2)\n",
|
24 |
+
"Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.4.1)\n",
|
25 |
+
"Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.16)\n",
|
26 |
+
"Requirement already satisfied: fsspec<=2024.3.1,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.3.1,>=2023.1.0->datasets) (2023.10.0)\n",
|
27 |
+
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.9.0b0)\n",
|
28 |
+
"Requirement already satisfied: huggingface-hub>=0.21.2 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.23.0)\n",
|
29 |
+
"Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (23.2)\n",
|
30 |
+
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.1)\n",
|
31 |
+
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch) (4.8.0)\n",
|
32 |
+
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.12)\n",
|
33 |
+
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.2.1)\n",
|
34 |
+
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.2)\n",
|
35 |
+
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n",
|
36 |
+
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n",
|
37 |
+
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n",
|
38 |
+
"Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch) (8.9.2.26)\n",
|
39 |
+
"Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.3.1)\n",
|
40 |
+
"Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch) (11.0.2.54)\n",
|
41 |
+
"Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch) (10.3.2.106)\n",
|
42 |
+
"Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch) (11.4.5.107)\n",
|
43 |
+
"Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.0.106)\n",
|
44 |
+
"Requirement already satisfied: nvidia-nccl-cu12==2.18.1 in /usr/local/lib/python3.10/dist-packages (from torch) (2.18.1)\n",
|
45 |
+
"Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n",
|
46 |
+
"Requirement already satisfied: triton==2.1.0 in /usr/local/lib/python3.10/dist-packages (from torch) (2.1.0)\n",
|
47 |
+
"Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch) (12.3.101)\n",
|
48 |
+
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2024.4.28)\n",
|
49 |
+
"Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.1)\n",
|
50 |
+
"Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.3)\n",
|
51 |
+
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from peft) (5.9.6)\n",
|
52 |
+
"Requirement already satisfied: accelerate>=0.21.0 in /usr/local/lib/python3.10/dist-packages (from peft) (0.30.0)\n",
|
53 |
+
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.1.0)\n",
|
54 |
+
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.5)\n",
|
55 |
+
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.4)\n",
|
56 |
+
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n",
|
57 |
+
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n",
|
58 |
+
"Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n",
|
59 |
+
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (3.3.2)\n",
|
60 |
+
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (3.6)\n",
|
61 |
+
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2.1.0)\n",
|
62 |
+
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.19.0->datasets) (2023.11.17)\n",
|
63 |
+
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.3)\n",
|
64 |
+
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n",
|
65 |
+
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n",
|
66 |
+
"Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n",
|
67 |
+
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n",
|
68 |
+
"Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n",
|
69 |
+
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
|
70 |
+
"\u001b[0mNote: you may need to restart the kernel to use updated packages.\n"
|
71 |
+
]
|
72 |
+
}
|
73 |
+
],
|
74 |
+
"source": [
|
75 |
+
"pip install datasets torch transformers peft"
|
76 |
+
]
|
77 |
+
},
|
78 |
+
{
|
79 |
+
"cell_type": "code",
|
80 |
+
"execution_count": 4,
|
81 |
+
"metadata": {},
|
82 |
+
"outputs": [],
|
83 |
+
"source": [
|
84 |
+
"from tqdm.notebook import tqdm\n",
|
85 |
+
"\n",
|
86 |
+
"from datasets import load_dataset\n",
|
87 |
+
"import torch\n",
|
88 |
+
"from torch.utils.data import DataLoader\n",
|
89 |
+
"\n",
|
90 |
+
"from peft import (\n",
|
91 |
+
" get_peft_model,\n",
|
92 |
+
" LoraConfig,\n",
|
93 |
+
" TaskType,\n",
|
94 |
+
")\n",
|
95 |
+
"from transformers import default_data_collator, Trainer, TrainingArguments\n",
|
96 |
+
"\n",
|
97 |
+
"from short_hf import ShortHFModel"
|
98 |
+
]
|
99 |
+
},
|
100 |
+
{
|
101 |
+
"cell_type": "markdown",
|
102 |
+
"metadata": {},
|
103 |
+
"source": [
|
104 |
+
"### Load Data"
|
105 |
+
]
|
106 |
+
},
|
107 |
+
{
|
108 |
+
"cell_type": "code",
|
109 |
+
"execution_count": null,
|
110 |
+
"metadata": {},
|
111 |
+
"outputs": [],
|
112 |
+
"source": [
|
113 |
+
"# data = load_dataset(\"pg19\", split=\"validation\") # authors sample 10,000 texts to compute block influences\n",
|
114 |
+
"# dataloader = DataLoader(\n",
|
115 |
+
"# data,\n",
|
116 |
+
"# batch_size=2,\n",
|
117 |
+
"# shuffle=True,\n",
|
118 |
+
"# )"
|
119 |
+
]
|
120 |
+
},
|
121 |
+
{
|
122 |
+
"cell_type": "code",
|
123 |
+
"execution_count": 5,
|
124 |
+
"metadata": {},
|
125 |
+
"outputs": [],
|
126 |
+
"source": [
|
127 |
+
"data = load_dataset(\"wikitext\", \"wikitext-103-raw-v1\", split=\"validation\") # authors sample 10,000 texts to compute block influences\n",
|
128 |
+
"dataloader = DataLoader(\n",
|
129 |
+
" data,\n",
|
130 |
+
" batch_size=1,\n",
|
131 |
+
" shuffle=True,\n",
|
132 |
+
")"
|
133 |
+
]
|
134 |
+
},
|
135 |
+
{
|
136 |
+
"cell_type": "markdown",
|
137 |
+
"metadata": {},
|
138 |
+
"source": [
|
139 |
+
"### Load Model"
|
140 |
+
]
|
141 |
+
},
|
142 |
+
{
|
143 |
+
"cell_type": "code",
|
144 |
+
"execution_count": 3,
|
145 |
+
"metadata": {},
|
146 |
+
"outputs": [],
|
147 |
+
"source": [
|
148 |
+
"# !huggingface-cli login\n",
|
149 |
+
"# pip install huggingface_hub\n",
|
150 |
+
"!python3 -c \"from huggingface_hub.hf_api import HfFolder; HfFolder.save_token('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')\""
|
151 |
+
]
|
152 |
+
},
|
153 |
+
{
|
154 |
+
"cell_type": "code",
|
155 |
+
"execution_count": null,
|
156 |
+
"metadata": {},
|
157 |
+
"outputs": [],
|
158 |
+
"source": [
|
159 |
+
"#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"
|
160 |
+
]
|
161 |
+
},
|
162 |
+
{
|
163 |
+
"cell_type": "code",
|
164 |
+
"execution_count": 3,
|
165 |
+
"metadata": {},
|
166 |
+
"outputs": [
|
167 |
+
{
|
168 |
+
"name": "stdout",
|
169 |
+
"output_type": "stream",
|
170 |
+
"text": [
|
171 |
+
"asifahmed\n"
|
172 |
+
]
|
173 |
+
}
|
174 |
+
],
|
175 |
+
"source": [
|
176 |
+
"!huggingface-cli whoami"
|
177 |
+
]
|
178 |
+
},
|
179 |
+
{
|
180 |
+
"cell_type": "code",
|
181 |
+
"execution_count": 2,
|
182 |
+
"metadata": {},
|
183 |
+
"outputs": [],
|
184 |
+
"source": [
|
185 |
+
"# pip install git+https://github.com/tri-ml/linear_open_lm.git\n"
|
186 |
+
]
|
187 |
+
},
|
188 |
+
{
|
189 |
+
"cell_type": "code",
|
190 |
+
"execution_count": 6,
|
191 |
+
"metadata": {},
|
192 |
+
"outputs": [
|
193 |
+
{
|
194 |
+
"name": "stderr",
|
195 |
+
"output_type": "stream",
|
196 |
+
"text": [
|
197 |
+
"/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
198 |
+
" warnings.warn(\n"
|
199 |
+
]
|
200 |
+
},
|
201 |
+
{
|
202 |
+
"data": {
|
203 |
+
"application/vnd.jupyter.widget-view+json": {
|
204 |
+
"model_id": "9fcf366ecc414808b39285438599f5b9",
|
205 |
+
"version_major": 2,
|
206 |
+
"version_minor": 0
|
207 |
+
},
|
208 |
+
"text/plain": [
|
209 |
+
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
210 |
+
]
|
211 |
+
},
|
212 |
+
"metadata": {},
|
213 |
+
"output_type": "display_data"
|
214 |
+
}
|
215 |
+
],
|
216 |
+
"source": [
|
217 |
+
"# from open_lm.open_lm_hf import *\n",
|
218 |
+
"\n",
|
219 |
+
"MAX_SEQ_LEN = 2048\n",
|
220 |
+
"short_model = ShortHFModel(\n",
|
221 |
+
" # model_name=\"tiiuae/falcon-7b\",\n",
|
222 |
+
" model_name=\"mistralai/Mistral-7B-v0.1\",\n",
|
223 |
+
" layers_path=\"model.layers\",\n",
|
224 |
+
" n_prune_layers=2\n",
|
225 |
+
")\n",
|
226 |
+
"# short_model.model"
|
227 |
+
]
|
228 |
+
},
|
229 |
+
{
|
230 |
+
"cell_type": "code",
|
231 |
+
"execution_count": 7,
|
232 |
+
"metadata": {},
|
233 |
+
"outputs": [
|
234 |
+
{
|
235 |
+
"data": {
|
236 |
+
"text/plain": [
|
237 |
+
"MistralForCausalLM(\n",
|
238 |
+
" (model): MistralModel(\n",
|
239 |
+
" (embed_tokens): Embedding(32000, 4096)\n",
|
240 |
+
" (layers): ModuleList(\n",
|
241 |
+
" (0-31): 32 x MistralDecoderLayer(\n",
|
242 |
+
" (self_attn): MistralSdpaAttention(\n",
|
243 |
+
" (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
|
244 |
+
" (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
|
245 |
+
" (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
|
246 |
+
" (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
|
247 |
+
" (rotary_emb): MistralRotaryEmbedding()\n",
|
248 |
+
" )\n",
|
249 |
+
" (mlp): MistralMLP(\n",
|
250 |
+
" (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
|
251 |
+
" (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
|
252 |
+
" (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
|
253 |
+
" (act_fn): SiLU()\n",
|
254 |
+
" )\n",
|
255 |
+
" (input_layernorm): MistralRMSNorm()\n",
|
256 |
+
" (post_attention_layernorm): MistralRMSNorm()\n",
|
257 |
+
" )\n",
|
258 |
+
" )\n",
|
259 |
+
" (norm): MistralRMSNorm()\n",
|
260 |
+
" )\n",
|
261 |
+
" (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n",
|
262 |
+
")"
|
263 |
+
]
|
264 |
+
},
|
265 |
+
"execution_count": 7,
|
266 |
+
"metadata": {},
|
267 |
+
"output_type": "execute_result"
|
268 |
+
}
|
269 |
+
],
|
270 |
+
"source": [
|
271 |
+
"short_model.model"
|
272 |
+
]
|
273 |
+
},
|
274 |
+
{
|
275 |
+
"cell_type": "code",
|
276 |
+
"execution_count": null,
|
277 |
+
"metadata": {},
|
278 |
+
"outputs": [],
|
279 |
+
"source": [
|
280 |
+
"# AutoModelForCausalLM.from_pretrained(\n",
|
281 |
+
"# pretrained_model_name_or_path=model_dir,\n",
|
282 |
+
"# local_files_only=True,\n",
|
283 |
+
"# use_safetensors=True,\n",
|
284 |
+
"# torch_dtype=torch.bfloat16,\n",
|
285 |
+
"# )"
|
286 |
+
]
|
287 |
+
},
|
288 |
+
{
|
289 |
+
"cell_type": "code",
|
290 |
+
"execution_count": 8,
|
291 |
+
"metadata": {},
|
292 |
+
"outputs": [
|
293 |
+
{
|
294 |
+
"data": {
|
295 |
+
"text/plain": [
|
296 |
+
"<generator object Module.parameters at 0x7f00b3917840>"
|
297 |
+
]
|
298 |
+
},
|
299 |
+
"execution_count": 8,
|
300 |
+
"metadata": {},
|
301 |
+
"output_type": "execute_result"
|
302 |
+
}
|
303 |
+
],
|
304 |
+
"source": [
|
305 |
+
"short_model.model.parameters()"
|
306 |
+
]
|
307 |
+
},
|
308 |
+
{
|
309 |
+
"cell_type": "code",
|
310 |
+
"execution_count": 9,
|
311 |
+
"metadata": {},
|
312 |
+
"outputs": [
|
313 |
+
{
|
314 |
+
"data": {
|
315 |
+
"text/plain": [
|
316 |
+
"7241732096"
|
317 |
+
]
|
318 |
+
},
|
319 |
+
"execution_count": 9,
|
320 |
+
"metadata": {},
|
321 |
+
"output_type": "execute_result"
|
322 |
+
}
|
323 |
+
],
|
324 |
+
"source": [
|
325 |
+
"pytorch_total_params = sum(p.numel() for p in short_model.model.parameters())\n",
|
326 |
+
"pytorch_total_params"
|
327 |
+
]
|
328 |
+
},
|
329 |
+
{
|
330 |
+
"cell_type": "code",
|
331 |
+
"execution_count": 36,
|
332 |
+
"metadata": {},
|
333 |
+
"outputs": [],
|
334 |
+
"source": [
|
335 |
+
" # Save the model state to the specified path.\n",
|
336 |
+
"# model_dir='ShortModelSaved/'\n",
|
337 |
+
"# short_model.model.save_pretrained(\n",
|
338 |
+
"# save_directory=model_dir,\n",
|
339 |
+
"# safe_serialization=True,\n",
|
340 |
+
"# )"
|
341 |
+
]
|
342 |
+
},
|
343 |
+
{
|
344 |
+
"cell_type": "code",
|
345 |
+
"execution_count": 10,
|
346 |
+
"metadata": {},
|
347 |
+
"outputs": [
|
348 |
+
{
|
349 |
+
"data": {
|
350 |
+
"text/plain": [
|
351 |
+
"MistralDecoderLayer(\n",
|
352 |
+
" (self_attn): MistralSdpaAttention(\n",
|
353 |
+
" (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
|
354 |
+
" (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
|
355 |
+
" (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
|
356 |
+
" (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
|
357 |
+
" (rotary_emb): MistralRotaryEmbedding()\n",
|
358 |
+
" )\n",
|
359 |
+
" (mlp): MistralMLP(\n",
|
360 |
+
" (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
|
361 |
+
" (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
|
362 |
+
" (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
|
363 |
+
" (act_fn): SiLU()\n",
|
364 |
+
" )\n",
|
365 |
+
" (input_layernorm): MistralRMSNorm()\n",
|
366 |
+
" (post_attention_layernorm): MistralRMSNorm()\n",
|
367 |
+
")"
|
368 |
+
]
|
369 |
+
},
|
370 |
+
"execution_count": 10,
|
371 |
+
"metadata": {},
|
372 |
+
"output_type": "execute_result"
|
373 |
+
}
|
374 |
+
],
|
375 |
+
"source": [
|
376 |
+
"short_model.layers[0]"
|
377 |
+
]
|
378 |
+
},
|
379 |
+
{
|
380 |
+
"cell_type": "code",
|
381 |
+
"execution_count": 12,
|
382 |
+
"metadata": {},
|
383 |
+
"outputs": [
|
384 |
+
{
|
385 |
+
"name": "stderr",
|
386 |
+
"output_type": "stream",
|
387 |
+
"text": [
|
388 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
389 |
+
"Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
|
390 |
+
]
|
391 |
+
},
|
392 |
+
{
|
393 |
+
"data": {
|
394 |
+
"text/plain": [
|
395 |
+
"['I am an avid fan of 3D printing. I have been using 3D printers for over 10 years and have been involved in the development of several 3D printers. I have also been involved in the development of several 3D printing software packages.\\n\\nI have been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages. I have also been involved in the development of several 3D printing software packages.']"
|
396 |
+
]
|
397 |
+
},
|
398 |
+
"execution_count": 12,
|
399 |
+
"metadata": {},
|
400 |
+
"output_type": "execute_result"
|
401 |
+
}
|
402 |
+
],
|
403 |
+
"source": [
|
404 |
+
"# sample generationThe evolution of AI has lead to \n",
|
405 |
+
"gen = short_model.model.generate(\n",
|
406 |
+
" short_model.tokenizer([\"I am an avid fan of \"], return_tensors='pt').input_ids.to(\"cuda\"),\n",
|
407 |
+
" max_new_tokens=256\n",
|
408 |
+
")\n",
|
409 |
+
"short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)"
|
410 |
+
]
|
411 |
+
},
|
412 |
+
{
|
413 |
+
"cell_type": "code",
|
414 |
+
"execution_count": 2,
|
415 |
+
"metadata": {},
|
416 |
+
"outputs": [],
|
417 |
+
"source": [
|
418 |
+
"# # sample generation\n",
|
419 |
+
"# gen = short_model.model.generate(\n",
|
420 |
+
"# short_model.tokenizer([\"The evolution of AI has lead to \"], return_tensors='pt').input_ids.to(\"cuda\"),\n",
|
421 |
+
"# max_new_tokens=256\n",
|
422 |
+
"# )\n",
|
423 |
+
"# short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)"
|
424 |
+
]
|
425 |
+
},
|
426 |
+
{
|
427 |
+
"cell_type": "markdown",
|
428 |
+
"metadata": {},
|
429 |
+
"source": [
|
430 |
+
"### Compute Importances"
|
431 |
+
]
|
432 |
+
},
|
433 |
+
{
|
434 |
+
"cell_type": "code",
|
435 |
+
"execution_count": 50,
|
436 |
+
"metadata": {},
|
437 |
+
"outputs": [],
|
438 |
+
"source": [
|
439 |
+
"# for i, batch in enumerate(tqdm(dataloader)):\n",
|
440 |
+
"# prompts = batch['text']\n",
|
441 |
+
"\n",
|
442 |
+
"# short_model.eval_importance(\n",
|
443 |
+
"# prompts=prompts,\n",
|
444 |
+
"# max_seq_len=MAX_SEQ_LEN,\n",
|
445 |
+
"# stride=256,\n",
|
446 |
+
"# max_gen_len=0\n",
|
447 |
+
"# )"
|
448 |
+
]
|
449 |
+
},
|
450 |
+
{
|
451 |
+
"cell_type": "code",
|
452 |
+
"execution_count": 51,
|
453 |
+
"metadata": {},
|
454 |
+
"outputs": [],
|
455 |
+
"source": [
|
456 |
+
"# short_model.importances"
|
457 |
+
]
|
458 |
+
},
|
459 |
+
{
|
460 |
+
"cell_type": "markdown",
|
461 |
+
"metadata": {},
|
462 |
+
"source": [
|
463 |
+
"### Remove unimportant layers\n",
|
464 |
+
"\n",
|
465 |
+
"Layers removed when using subset of pg19 val set: [25, 26, 24, 27, 22, 23, 28, 21, 29]\n",
|
466 |
+
"\n",
|
467 |
+
"Authors mention that the layer order is quite nuanced and can vary with different datasets. However, relative order suggests similar importance."
|
468 |
+
]
|
469 |
+
},
|
470 |
+
{
|
471 |
+
"cell_type": "code",
|
472 |
+
"execution_count": 55,
|
473 |
+
"metadata": {},
|
474 |
+
"outputs": [],
|
475 |
+
"source": [
|
476 |
+
"# short_model.remove_layers()"
|
477 |
+
]
|
478 |
+
},
|
479 |
+
{
|
480 |
+
"cell_type": "code",
|
481 |
+
"execution_count": 54,
|
482 |
+
"metadata": {},
|
483 |
+
"outputs": [],
|
484 |
+
"source": [
|
485 |
+
"# short_model.remove_layers()"
|
486 |
+
]
|
487 |
+
},
|
488 |
+
{
|
489 |
+
"cell_type": "code",
|
490 |
+
"execution_count": 56,
|
491 |
+
"metadata": {},
|
492 |
+
"outputs": [],
|
493 |
+
"source": [
|
494 |
+
"# short_model.layers"
|
495 |
+
]
|
496 |
+
},
|
497 |
+
{
|
498 |
+
"cell_type": "code",
|
499 |
+
"execution_count": 48,
|
500 |
+
"metadata": {},
|
501 |
+
"outputs": [],
|
502 |
+
"source": [
|
503 |
+
"# # reassign layer_idx to attentions for caching\n",
|
504 |
+
"# for layer_idx, module in enumerate(short_model.layers):\n",
|
505 |
+
"# module.self_attn.layer_idx = layer_idx"
|
506 |
+
]
|
507 |
+
},
|
508 |
+
{
|
509 |
+
"cell_type": "code",
|
510 |
+
"execution_count": 20,
|
511 |
+
"metadata": {},
|
512 |
+
"outputs": [
|
513 |
+
{
|
514 |
+
"data": {
|
515 |
+
"text/plain": [
|
516 |
+
"<generator object Module.parameters at 0x7f625768a2d0>"
|
517 |
+
]
|
518 |
+
},
|
519 |
+
"execution_count": 20,
|
520 |
+
"metadata": {},
|
521 |
+
"output_type": "execute_result"
|
522 |
+
}
|
523 |
+
],
|
524 |
+
"source": [
|
525 |
+
"# short_model.model.parameters()"
|
526 |
+
]
|
527 |
+
},
|
528 |
+
{
|
529 |
+
"cell_type": "code",
|
530 |
+
"execution_count": 68,
|
531 |
+
"metadata": {},
|
532 |
+
"outputs": [
|
533 |
+
{
|
534 |
+
"data": {
|
535 |
+
"text/plain": [
|
536 |
+
"7241732096"
|
537 |
+
]
|
538 |
+
},
|
539 |
+
"execution_count": 68,
|
540 |
+
"metadata": {},
|
541 |
+
"output_type": "execute_result"
|
542 |
+
}
|
543 |
+
],
|
544 |
+
"source": [
|
545 |
+
"# pytorch_total_params = sum(p.numel() for p in short_model.model.parameters())\n",
|
546 |
+
"# pytorch_total_params"
|
547 |
+
]
|
548 |
+
},
|
549 |
+
{
|
550 |
+
"cell_type": "markdown",
|
551 |
+
"metadata": {},
|
552 |
+
"source": [
|
553 |
+
"As the paper states: \\\n",
|
554 |
+
" - \"Our experiments reveal that the effect of layer removal is significantly more pronounced on generative\n",
|
555 |
+
" tasks compared to multiple-choice tasks. On benchmarks such as GSM8K (Cobbe et al., 2021) and\n",
|
556 |
+
" HumanEval (Chen et al., 2021), removing 25% of the layers often leads to a severe performance\n",
|
557 |
+
" drop, with scores approaching zero.\""
|
558 |
+
]
|
559 |
+
},
|
560 |
+
{
|
561 |
+
"cell_type": "code",
|
562 |
+
"execution_count": 53,
|
563 |
+
"metadata": {},
|
564 |
+
"outputs": [],
|
565 |
+
"source": [
|
566 |
+
"# gen = short_model.model.generate(\n",
|
567 |
+
"# short_model.tokenizer([\"I am an avid fan of \"], return_tensors='pt').input_ids.to(\"cuda\"),\n",
|
568 |
+
"# max_new_tokens=20,\n",
|
569 |
+
"# use_cache=True\n",
|
570 |
+
"# )\n",
|
571 |
+
"# short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)"
|
572 |
+
]
|
573 |
+
},
|
574 |
+
{
|
575 |
+
"cell_type": "code",
|
576 |
+
"execution_count": 52,
|
577 |
+
"metadata": {},
|
578 |
+
"outputs": [],
|
579 |
+
"source": [
|
580 |
+
"# gen = short_model.model.generate(I am an avid fan of \n",
|
581 |
+
"# short_model.tokenizer([\"The evolution of AI has lead to \"], return_tensors='pt').input_ids.to(\"cuda\"),\n",
|
582 |
+
"# max_new_tokens=20,\n",
|
583 |
+
"# use_cache=True\n",
|
584 |
+
"# )\n",
|
585 |
+
"# short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)"
|
586 |
+
]
|
587 |
+
},
|
588 |
+
{
|
589 |
+
"cell_type": "markdown",
|
590 |
+
"metadata": {},
|
591 |
+
"source": [
|
592 |
+
"### Compute Angular Importances"
|
593 |
+
]
|
594 |
+
},
|
595 |
+
{
|
596 |
+
"cell_type": "code",
|
597 |
+
"execution_count": 16,
|
598 |
+
"metadata": {},
|
599 |
+
"outputs": [
|
600 |
+
{
|
601 |
+
"data": {
|
602 |
+
"application/vnd.jupyter.widget-view+json": {
|
603 |
+
"model_id": "a6fd2bf4360b4aba801085bab0755a06",
|
604 |
+
"version_major": 2,
|
605 |
+
"version_minor": 0
|
606 |
+
},
|
607 |
+
"text/plain": [
|
608 |
+
" 0%| | 0/3760 [00:00<?, ?it/s]"
|
609 |
+
]
|
610 |
+
},
|
611 |
+
"metadata": {},
|
612 |
+
"output_type": "display_data"
|
613 |
+
}
|
614 |
+
],
|
615 |
+
"source": [
|
616 |
+
"for i, batch in enumerate(tqdm(dataloader)):\n",
|
617 |
+
" prompts = batch['text']\n",
|
618 |
+
"\n",
|
619 |
+
" short_model.eval_importance(\n",
|
620 |
+
" prompts=prompts,\n",
|
621 |
+
" max_seq_len=MAX_SEQ_LEN,\n",
|
622 |
+
" stride=256,\n",
|
623 |
+
" max_gen_len=0,\n",
|
624 |
+
" angular=True\n",
|
625 |
+
" )"
|
626 |
+
]
|
627 |
+
},
|
628 |
+
{
|
629 |
+
"cell_type": "code",
|
630 |
+
"execution_count": 17,
|
631 |
+
"metadata": {},
|
632 |
+
"outputs": [
|
633 |
+
{
|
634 |
+
"data": {
|
635 |
+
"text/plain": [
|
636 |
+
"[128390.1328125,\n",
|
637 |
+
" 80922.06787109375,\n",
|
638 |
+
" 61075.2890625,\n",
|
639 |
+
" nan,\n",
|
640 |
+
" nan,\n",
|
641 |
+
" 56557.81268310547,\n",
|
642 |
+
" nan,\n",
|
643 |
+
" 52294.552001953125,\n",
|
644 |
+
" 47928.185302734375,\n",
|
645 |
+
" 42335.215576171875,\n",
|
646 |
+
" 40547.564208984375,\n",
|
647 |
+
" 37178.684326171875,\n",
|
648 |
+
" 34713.912841796875,\n",
|
649 |
+
" 33843.728271484375,\n",
|
650 |
+
" 35384.353271484375,\n",
|
651 |
+
" 35603.388427734375,\n",
|
652 |
+
" 35621.970458984375,\n",
|
653 |
+
" 35356.719482421875,\n",
|
654 |
+
" 35365.243896484375,\n",
|
655 |
+
" 34914.025146484375,\n",
|
656 |
+
" 27854.576904296875,\n",
|
657 |
+
" 24398.073974609375,\n",
|
658 |
+
" 20450.390380859375,\n",
|
659 |
+
" 19501.300537109375,\n",
|
660 |
+
" 18430.427490234375,\n",
|
661 |
+
" 18231.873779296875,\n",
|
662 |
+
" 17917.493896484375,\n",
|
663 |
+
" 17806.815185546875,\n",
|
664 |
+
" 21227.195068359375,\n",
|
665 |
+
" 23928.313018798828,\n",
|
666 |
+
" 22738.702880859375,\n",
|
667 |
+
" 86123.783203125]"
|
668 |
+
]
|
669 |
+
},
|
670 |
+
"execution_count": 17,
|
671 |
+
"metadata": {},
|
672 |
+
"output_type": "execute_result"
|
673 |
+
}
|
674 |
+
],
|
675 |
+
"source": [
|
676 |
+
"short_model.importances"
|
677 |
+
]
|
678 |
+
},
|
679 |
+
{
|
680 |
+
"cell_type": "markdown",
|
681 |
+
"metadata": {},
|
682 |
+
"source": [
|
683 |
+
"### Remove unimportant layers"
|
684 |
+
]
|
685 |
+
},
|
686 |
+
{
|
687 |
+
"cell_type": "code",
|
688 |
+
"execution_count": 18,
|
689 |
+
"metadata": {},
|
690 |
+
"outputs": [
|
691 |
+
{
|
692 |
+
"data": {
|
693 |
+
"text/plain": [
|
694 |
+
"[27, 28]"
|
695 |
+
]
|
696 |
+
},
|
697 |
+
"execution_count": 18,
|
698 |
+
"metadata": {},
|
699 |
+
"output_type": "execute_result"
|
700 |
+
}
|
701 |
+
],
|
702 |
+
"source": [
|
703 |
+
"short_model.remove_layers(angular=True)"
|
704 |
+
]
|
705 |
+
},
|
706 |
+
{
|
707 |
+
"cell_type": "code",
|
708 |
+
"execution_count": 20,
|
709 |
+
"metadata": {},
|
710 |
+
"outputs": [
|
711 |
+
{
|
712 |
+
"data": {
|
713 |
+
"text/plain": [
|
714 |
+
"MistralDecoderLayer(\n",
|
715 |
+
" (self_attn): MistralSdpaAttention(\n",
|
716 |
+
" (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
|
717 |
+
" (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
|
718 |
+
" (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
|
719 |
+
" (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
|
720 |
+
" (rotary_emb): MistralRotaryEmbedding()\n",
|
721 |
+
" )\n",
|
722 |
+
" (mlp): MistralMLP(\n",
|
723 |
+
" (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
|
724 |
+
" (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
|
725 |
+
" (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
|
726 |
+
" (act_fn): SiLU()\n",
|
727 |
+
" )\n",
|
728 |
+
" (input_layernorm): MistralRMSNorm()\n",
|
729 |
+
" (post_attention_layernorm): MistralRMSNorm()\n",
|
730 |
+
")"
|
731 |
+
]
|
732 |
+
},
|
733 |
+
"execution_count": 20,
|
734 |
+
"metadata": {},
|
735 |
+
"output_type": "execute_result"
|
736 |
+
}
|
737 |
+
],
|
738 |
+
"source": [
|
739 |
+
"short_model.layers[0]"
|
740 |
+
]
|
741 |
+
},
|
742 |
+
{
|
743 |
+
"cell_type": "code",
|
744 |
+
"execution_count": 21,
|
745 |
+
"metadata": {},
|
746 |
+
"outputs": [
|
747 |
+
{
|
748 |
+
"data": {
|
749 |
+
"text/plain": [
|
750 |
+
"ModuleList(\n",
|
751 |
+
" (0-29): 30 x MistralDecoderLayer(\n",
|
752 |
+
" (self_attn): MistralSdpaAttention(\n",
|
753 |
+
" (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
|
754 |
+
" (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
|
755 |
+
" (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
|
756 |
+
" (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
|
757 |
+
" (rotary_emb): MistralRotaryEmbedding()\n",
|
758 |
+
" )\n",
|
759 |
+
" (mlp): MistralMLP(\n",
|
760 |
+
" (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
|
761 |
+
" (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
|
762 |
+
" (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
|
763 |
+
" (act_fn): SiLU()\n",
|
764 |
+
" )\n",
|
765 |
+
" (input_layernorm): MistralRMSNorm()\n",
|
766 |
+
" (post_attention_layernorm): MistralRMSNorm()\n",
|
767 |
+
" )\n",
|
768 |
+
")"
|
769 |
+
]
|
770 |
+
},
|
771 |
+
"execution_count": 21,
|
772 |
+
"metadata": {},
|
773 |
+
"output_type": "execute_result"
|
774 |
+
}
|
775 |
+
],
|
776 |
+
"source": [
|
777 |
+
"short_model.layers"
|
778 |
+
]
|
779 |
+
},
|
780 |
+
{
|
781 |
+
"cell_type": "code",
|
782 |
+
"execution_count": 22,
|
783 |
+
"metadata": {},
|
784 |
+
"outputs": [],
|
785 |
+
"source": [
|
786 |
+
"# reassign layer_idx to attentions for caching\n",
|
787 |
+
"for layer_idx, module in enumerate(short_model.layers):\n",
|
788 |
+
" module.self_attn.layer_idx = layer_idx"
|
789 |
+
]
|
790 |
+
},
|
791 |
+
{
|
792 |
+
"cell_type": "code",
|
793 |
+
"execution_count": 23,
|
794 |
+
"metadata": {},
|
795 |
+
"outputs": [
|
796 |
+
{
|
797 |
+
"data": {
|
798 |
+
"text/plain": [
|
799 |
+
"ModuleList(\n",
|
800 |
+
" (0-29): 30 x MistralDecoderLayer(\n",
|
801 |
+
" (self_attn): MistralSdpaAttention(\n",
|
802 |
+
" (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
|
803 |
+
" (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
|
804 |
+
" (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
|
805 |
+
" (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
|
806 |
+
" (rotary_emb): MistralRotaryEmbedding()\n",
|
807 |
+
" )\n",
|
808 |
+
" (mlp): MistralMLP(\n",
|
809 |
+
" (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
|
810 |
+
" (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
|
811 |
+
" (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
|
812 |
+
" (act_fn): SiLU()\n",
|
813 |
+
" )\n",
|
814 |
+
" (input_layernorm): MistralRMSNorm()\n",
|
815 |
+
" (post_attention_layernorm): MistralRMSNorm()\n",
|
816 |
+
" )\n",
|
817 |
+
")"
|
818 |
+
]
|
819 |
+
},
|
820 |
+
"execution_count": 23,
|
821 |
+
"metadata": {},
|
822 |
+
"output_type": "execute_result"
|
823 |
+
}
|
824 |
+
],
|
825 |
+
"source": [
|
826 |
+
"short_model.layers"
|
827 |
+
]
|
828 |
+
},
|
829 |
+
{
|
830 |
+
"cell_type": "code",
|
831 |
+
"execution_count": 24,
|
832 |
+
"metadata": {},
|
833 |
+
"outputs": [
|
834 |
+
{
|
835 |
+
"name": "stderr",
|
836 |
+
"output_type": "stream",
|
837 |
+
"text": [
|
838 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
839 |
+
"Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
|
840 |
+
]
|
841 |
+
},
|
842 |
+
{
|
843 |
+
"data": {
|
844 |
+
"text/plain": [
|
845 |
+
"['I am an avid fan of 19th century American literature. I have read all of the classics, and I have also read many of the lesser known works. I have a particular interest in the works of Charles Dickens, and I have read all of his novels. I have also read many of the novels of other 19th century authors, such as Jane Austen, William Shakespeare, and William Blake.\\n\\nI have a particular interest in the works of Charles Dickens, and I have read all of his novels. I have also read many of the novels of other 19th century authors, such as Jane Austen, William Shakespeare, and William Blake.\\n\\nI have a particular interest in the works of Charles Dickens, and I have read all of his novels. I have also read many of the novels of other 19th century authors, such as Jane Austen, William Shakespeare, and William Blake.\\n\\nI have a particular interest in the works of Charles Dickens, and I have read all of his novels. I have also read many of the novels of other 19th century authors, such as Jane Austen, William Shakespeare, and William Blake.\\n\\nI have a particular interest in the works of Charles Dickens']"
|
846 |
+
]
|
847 |
+
},
|
848 |
+
"execution_count": 24,
|
849 |
+
"metadata": {},
|
850 |
+
"output_type": "execute_result"
|
851 |
+
}
|
852 |
+
],
|
853 |
+
"source": [
|
854 |
+
"gen = short_model.model.generate(\n",
|
855 |
+
" short_model.tokenizer([\"I am an avid fan of \"], return_tensors='pt').input_ids.to(\"cuda\"),\n",
|
856 |
+
" max_new_tokens=256,\n",
|
857 |
+
" use_cache=True\n",
|
858 |
+
")\n",
|
859 |
+
"short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)"
|
860 |
+
]
|
861 |
+
},
|
862 |
+
{
|
863 |
+
"cell_type": "code",
|
864 |
+
"execution_count": 27,
|
865 |
+
"metadata": {},
|
866 |
+
"outputs": [
|
867 |
+
{
|
868 |
+
"name": "stderr",
|
869 |
+
"output_type": "stream",
|
870 |
+
"text": [
|
871 |
+
"The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
|
872 |
+
"Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
|
873 |
+
]
|
874 |
+
},
|
875 |
+
{
|
876 |
+
"data": {
|
877 |
+
"text/plain": [
|
878 |
+
"['The evolution of AI has lead to 3 major types of AI:\\n\\n1. Strong AI\\n2. Weak AI\\n3. Super AI\\n\\nStrong AI is the type of AI that is capable of performing any task that a human can perform. This type of AI is still in the development phase and is not yet available in the market.\\n\\nWeak AI is the type of AI that is capable of performing a specific task. This type of AI is available in the market and is used in a variety of applications.\\n\\nSuper AI is the type of AI that is capable of performing any task that a human can perform and is also capable of learning and adapting. This type of AI is still in the development phase and is not yet available in the market.\\n\\n## What is the difference between AI and AI?\\n\\nThe difference between AI and AI is that AI is a type of artificial intelligence that is capable of performing a specific task, while AI is a type of artificial intelligence that is capable of performing any task.\\n\\n## What is the difference between AI and AI?\\n\\nThe difference between AI and AI is that AI is a type of artificial intelligence that is capable of performing a specific task, while AI is a type of artificial intelligence that is capable']"
|
879 |
+
]
|
880 |
+
},
|
881 |
+
"execution_count": 27,
|
882 |
+
"metadata": {},
|
883 |
+
"output_type": "execute_result"
|
884 |
+
}
|
885 |
+
],
|
886 |
+
"source": [
|
887 |
+
"# gen = short_model.model.generate(I am an avid fan of \n",
|
888 |
+
"# short_model.tokenizer([\"The evolution of AI has lead to \"], return_tensors='pt').input_ids.to(\"cuda\"),\n",
|
889 |
+
"# max_new_tokens=256,\n",
|
890 |
+
"# use_cache=True\n",
|
891 |
+
"# )\n",
|
892 |
+
"# short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)\n",
|
893 |
+
"\n",
|
894 |
+
"\n",
|
895 |
+
"gen = short_model.model.generate(\n",
|
896 |
+
" short_model.tokenizer([\"The evolution of AI has lead to \"], return_tensors='pt').input_ids.to(\"cuda\"),\n",
|
897 |
+
" max_new_tokens=256,\n",
|
898 |
+
" use_cache=True\n",
|
899 |
+
")\n",
|
900 |
+
"short_model.tokenizer.batch_decode(gen, skip_special_tokens=True)"
|
901 |
+
]
|
902 |
+
},
|
903 |
+
{
|
904 |
+
"cell_type": "code",
|
905 |
+
"execution_count": 28,
|
906 |
+
"metadata": {},
|
907 |
+
"outputs": [
|
908 |
+
{
|
909 |
+
"data": {
|
910 |
+
"text/plain": [
|
911 |
+
"6805508096"
|
912 |
+
]
|
913 |
+
},
|
914 |
+
"execution_count": 28,
|
915 |
+
"metadata": {},
|
916 |
+
"output_type": "execute_result"
|
917 |
+
}
|
918 |
+
],
|
919 |
+
"source": [
|
920 |
+
"pytorch_total_params = sum(p.numel() for p in short_model.model.parameters())\n",
|
921 |
+
"pytorch_total_params"
|
922 |
+
]
|
923 |
+
},
|
924 |
+
{
|
925 |
+
"cell_type": "code",
|
926 |
+
"execution_count": 35,
|
927 |
+
"metadata": {},
|
928 |
+
"outputs": [],
|
929 |
+
"source": [
|
930 |
+
" # Save the model state to the specified path.\n",
|
931 |
+
"model_dir='SmallModelSaved/'\n",
|
932 |
+
"short_model.model.save_pretrained(\n",
|
933 |
+
" save_directory=model_dir,\n",
|
934 |
+
" safe_serialization=True,\n",
|
935 |
+
" )"
|
936 |
+
]
|
937 |
+
},
|
938 |
+
{
|
939 |
+
"cell_type": "markdown",
|
940 |
+
"metadata": {},
|
941 |
+
"source": [
|
942 |
+
"### Model Healing"
|
943 |
+
]
|
944 |
+
},
|
945 |
+
{
|
946 |
+
"cell_type": "code",
|
947 |
+
"execution_count": 36,
|
948 |
+
"metadata": {},
|
949 |
+
"outputs": [],
|
950 |
+
"source": [
|
951 |
+
"# tokenizer = short_model.tokenizer\n",
|
952 |
+
"model = short_model.model"
|
953 |
+
]
|
954 |
+
},
|
955 |
+
{
|
956 |
+
"cell_type": "code",
|
957 |
+
"execution_count": 37,
|
958 |
+
"metadata": {},
|
959 |
+
"outputs": [
|
960 |
+
{
|
961 |
+
"name": "stdout",
|
962 |
+
"output_type": "stream",
|
963 |
+
"text": [
|
964 |
+
"Datset Loaded!\n"
|
965 |
+
]
|
966 |
+
}
|
967 |
+
],
|
968 |
+
"source": [
|
969 |
+
"from datasets import load_dataset\n",
|
970 |
+
"# Falcon = load_dataset(\"csv\", data_files=\"FalconData.csv\")\n",
|
971 |
+
"Falcon = load_dataset('csv', data_files={\"train\": 'FalconData2.csv', \"validation\": 'FalconDataEval2.csv'})\n",
|
972 |
+
"\n",
|
973 |
+
"print('Datset Loaded!')\n"
|
974 |
+
]
|
975 |
+
},
|
976 |
+
{
|
977 |
+
"cell_type": "code",
|
978 |
+
"execution_count": 38,
|
979 |
+
"metadata": {},
|
980 |
+
"outputs": [
|
981 |
+
{
|
982 |
+
"data": {
|
983 |
+
"text/plain": [
|
984 |
+
"{'Text': 'School Picture Gallery\\nFrance Ski School\\nChildren from Year 5 & 6 travelled to France from Newcastle airport to take part in a week of Ski School. The children had already spent 3 weeks learning the basics of skiing at Silksworth Ski School in Sunderland. When the children arrived in France they took part in a daily Ski School, during which the children made OUTSTANDING progress. The children also took part in French activities, explored local landmarks and took part in shopping activities in Chamonix. It was an incredible adventure for the children and staff!'}"
|
985 |
+
]
|
986 |
+
},
|
987 |
+
"execution_count": 38,
|
988 |
+
"metadata": {},
|
989 |
+
"output_type": "execute_result"
|
990 |
+
}
|
991 |
+
],
|
992 |
+
"source": [
|
993 |
+
"# Falcon = Falcon.train_test_split(test_size=0.10)\n",
|
994 |
+
"\n",
|
995 |
+
"\"\"\"Then take a look at an example:\"\"\"\n",
|
996 |
+
"\n",
|
997 |
+
"Falcon['train'][0]\n"
|
998 |
+
]
|
999 |
+
},
|
1000 |
+
{
|
1001 |
+
"cell_type": "code",
|
1002 |
+
"execution_count": 39,
|
1003 |
+
"metadata": {},
|
1004 |
+
"outputs": [
|
1005 |
+
{
|
1006 |
+
"data": {
|
1007 |
+
"text/plain": [
|
1008 |
+
"{'Text': 'Our Annual Garden Party is a fun-filled event with a ton of landscaping and garden supplies; gardening demonstrations, experts, and vendors; activities for kids; live bands; and local food. It’s been so popular that we’re extending it to TWO DAYS this year!\\nFestivities at 10am – 4pm Saturday and 11am – 3pm Sunday\\nShopping from 9am – 6pm both days\\nThroughout the winter, we collect gently-used and surplus lawn & garden supplies as well as outdoor décor and furniture. Then, we put it all out for your shopping pleasure! The sale begins at 9:00 am Saturday, but folks start lining up outside the gates even earlier, eager to dig through piles of flowerpots and shovels. (If you can’t get there in the morning, don’t worry – the staff continues to bring out items throughout the weekend.)\\nThe Garden Sale 1st.\\nThere will be prizes for people and pets dressed in garden party finery.\\nPhoto by Carrie Delesky\\nSo find yourself a dapper suit or fancy hat, and check out all the activities in store for you:\\nAnacostia Watershed Society\\nPrince George’s Chapter, Maryland Master Gardeners\\nMOM’s Organic Market\\nTreincarnation\\nVeteran Compost\\nPhoto by Carrie Delesky\\nSaturday the Forklift’s Matt Menke and Gary Barnhart of GL Barnhart Construction. Drop in for a while, or stay the whole.'}"
|
1009 |
+
]
|
1010 |
+
},
|
1011 |
+
"execution_count": 39,
|
1012 |
+
"metadata": {},
|
1013 |
+
"output_type": "execute_result"
|
1014 |
+
}
|
1015 |
+
],
|
1016 |
+
"source": [
|
1017 |
+
"Falcon['validation'][0]\n"
|
1018 |
+
]
|
1019 |
+
},
|
1020 |
+
{
|
1021 |
+
"cell_type": "code",
|
1022 |
+
"execution_count": 41,
|
1023 |
+
"metadata": {},
|
1024 |
+
"outputs": [
|
1025 |
+
{
|
1026 |
+
"name": "stderr",
|
1027 |
+
"output_type": "stream",
|
1028 |
+
"text": [
|
1029 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
1030 |
+
]
|
1031 |
+
}
|
1032 |
+
],
|
1033 |
+
"source": [
|
1034 |
+
"\"\"\"The next step is to load a DistilGPT2 tokenizer to process the `text` subfield:\"\"\"\n",
|
1035 |
+
"\n",
|
1036 |
+
"from transformers import AutoTokenizer, GPT2TokenizerFast\n",
|
1037 |
+
"\n",
|
1038 |
+
"# tokenizer = AutoTokenizer.from_pretrained(\"distilgpt2\")\n",
|
1039 |
+
"\n",
|
1040 |
+
"\n",
|
1041 |
+
"tokenizer = GPT2TokenizerFast.from_pretrained(\"Xenova/gpt-4\")#, cache_dir=cache_dir)\n",
|
1042 |
+
"tokenizer.pad_token = tokenizer.eos_token\n"
|
1043 |
+
]
|
1044 |
+
},
|
1045 |
+
{
|
1046 |
+
"cell_type": "code",
|
1047 |
+
"execution_count": 42,
|
1048 |
+
"metadata": {},
|
1049 |
+
"outputs": [
|
1050 |
+
{
|
1051 |
+
"data": {
|
1052 |
+
"text/plain": [
|
1053 |
+
"{'Text': 'School Picture Gallery\\nFrance Ski School\\nChildren from Year 5 & 6 travelled to France from Newcastle airport to take part in a week of Ski School. The children had already spent 3 weeks learning the basics of skiing at Silksworth Ski School in Sunderland. When the children arrived in France they took part in a daily Ski School, during which the children made OUTSTANDING progress. The children also took part in French activities, explored local landmarks and took part in shopping activities in Chamonix. It was an incredible adventure for the children and staff!'}"
|
1054 |
+
]
|
1055 |
+
},
|
1056 |
+
"execution_count": 42,
|
1057 |
+
"metadata": {},
|
1058 |
+
"output_type": "execute_result"
|
1059 |
+
}
|
1060 |
+
],
|
1061 |
+
"source": [
|
1062 |
+
"Falcon = Falcon.flatten()\n",
|
1063 |
+
"Falcon[\"train\"][0]"
|
1064 |
+
]
|
1065 |
+
},
|
1066 |
+
{
|
1067 |
+
"cell_type": "code",
|
1068 |
+
"execution_count": 43,
|
1069 |
+
"metadata": {},
|
1070 |
+
"outputs": [
|
1071 |
+
{
|
1072 |
+
"name": "stdout",
|
1073 |
+
"output_type": "stream",
|
1074 |
+
"text": [
|
1075 |
+
"The OrderedVocab you are attempting to save contains holes for indices [100256, 100261, 100262, 100263, 100266, 100267, 100268, 100269, 100270, 100271, 100272, 100273, 100274, 100275], your vocabulary could be corrupted !\n"
|
1076 |
+
]
|
1077 |
+
},
|
1078 |
+
{
|
1079 |
+
"data": {
|
1080 |
+
"application/vnd.jupyter.widget-view+json": {
|
1081 |
+
"model_id": "d2182d4fa561406ab7eb5fc6c19c6d17",
|
1082 |
+
"version_major": 2,
|
1083 |
+
"version_minor": 0
|
1084 |
+
},
|
1085 |
+
"text/plain": [
|
1086 |
+
"Map (num_proc=4): 0%| | 0/10000 [00:00<?, ? examples/s]"
|
1087 |
+
]
|
1088 |
+
},
|
1089 |
+
"metadata": {},
|
1090 |
+
"output_type": "display_data"
|
1091 |
+
},
|
1092 |
+
{
|
1093 |
+
"name": "stderr",
|
1094 |
+
"output_type": "stream",
|
1095 |
+
"text": [
|
1096 |
+
"Token indices sequence length is longer than the specified maximum sequence length for this model (10412 > 8192). Running this sequence through the model will result in indexing errors\n",
|
1097 |
+
"Token indices sequence length is longer than the specified maximum sequence length for this model (10738 > 8192). Running this sequence through the model will result in indexing errors\n",
|
1098 |
+
"Token indices sequence length is longer than the specified maximum sequence length for this model (12860 > 8192). Running this sequence through the model will result in indexing errors\n",
|
1099 |
+
"Token indices sequence length is longer than the specified maximum sequence length for this model (23091 > 8192). Running this sequence through the model will result in indexing errors\n"
|
1100 |
+
]
|
1101 |
+
},
|
1102 |
+
{
|
1103 |
+
"name": "stdout",
|
1104 |
+
"output_type": "stream",
|
1105 |
+
"text": [
|
1106 |
+
"The OrderedVocab you are attempting to save contains holes for indices [100256, 100261, 100262, 100263, 100266, 100267, 100268, 100269, 100270, 100271, 100272, 100273, 100274, 100275], your vocabulary could be corrupted !\n"
|
1107 |
+
]
|
1108 |
+
},
|
1109 |
+
{
|
1110 |
+
"data": {
|
1111 |
+
"application/vnd.jupyter.widget-view+json": {
|
1112 |
+
"model_id": "121ffe72baf143f4aeea4616bee88405",
|
1113 |
+
"version_major": 2,
|
1114 |
+
"version_minor": 0
|
1115 |
+
},
|
1116 |
+
"text/plain": [
|
1117 |
+
"Map (num_proc=4): 0%| | 0/1000 [00:00<?, ? examples/s]"
|
1118 |
+
]
|
1119 |
+
},
|
1120 |
+
"metadata": {},
|
1121 |
+
"output_type": "display_data"
|
1122 |
+
},
|
1123 |
+
{
|
1124 |
+
"name": "stderr",
|
1125 |
+
"output_type": "stream",
|
1126 |
+
"text": [
|
1127 |
+
"Token indices sequence length is longer than the specified maximum sequence length for this model (9078 > 8192). Running this sequence through the model will result in indexing errors\n",
|
1128 |
+
"Token indices sequence length is longer than the specified maximum sequence length for this model (15886 > 8192). Running this sequence through the model will result in indexing errors\n",
|
1129 |
+
"Token indices sequence length is longer than the specified maximum sequence length for this model (28727 > 8192). Running this sequence through the model will result in indexing errors\n",
|
1130 |
+
"Token indices sequence length is longer than the specified maximum sequence length for this model (8257 > 8192). Running this sequence through the model will result in indexing errors\n"
|
1131 |
+
]
|
1132 |
+
}
|
1133 |
+
],
|
1134 |
+
"source": [
|
1135 |
+
"def preprocess_function(examples):\n",
|
1136 |
+
" return tokenizer([\" \".join(x) for x in examples[\"Text\"]])\n",
|
1137 |
+
"\n",
|
1138 |
+
"\n",
|
1139 |
+
"\n",
|
1140 |
+
"tokenized_Falcon = Falcon.map(\n",
|
1141 |
+
" preprocess_function,\n",
|
1142 |
+
" batched=True,\n",
|
1143 |
+
" num_proc=4,\n",
|
1144 |
+
" remove_columns=Falcon[\"train\"].column_names,\n",
|
1145 |
+
")"
|
1146 |
+
]
|
1147 |
+
},
|
1148 |
+
{
|
1149 |
+
"cell_type": "code",
|
1150 |
+
"execution_count": 44,
|
1151 |
+
"metadata": {},
|
1152 |
+
"outputs": [
|
1153 |
+
{
|
1154 |
+
"data": {
|
1155 |
+
"application/vnd.jupyter.widget-view+json": {
|
1156 |
+
"model_id": "6d7b13436ae54624bd96973987373482",
|
1157 |
+
"version_major": 2,
|
1158 |
+
"version_minor": 0
|
1159 |
+
},
|
1160 |
+
"text/plain": [
|
1161 |
+
"Map (num_proc=4): 0%| | 0/10000 [00:00<?, ? examples/s]"
|
1162 |
+
]
|
1163 |
+
},
|
1164 |
+
"metadata": {},
|
1165 |
+
"output_type": "display_data"
|
1166 |
+
},
|
1167 |
+
{
|
1168 |
+
"data": {
|
1169 |
+
"application/vnd.jupyter.widget-view+json": {
|
1170 |
+
"model_id": "beade64b537441ef99a54830bb66eef2",
|
1171 |
+
"version_major": 2,
|
1172 |
+
"version_minor": 0
|
1173 |
+
},
|
1174 |
+
"text/plain": [
|
1175 |
+
"Map (num_proc=4): 0%| | 0/1000 [00:00<?, ? examples/s]"
|
1176 |
+
]
|
1177 |
+
},
|
1178 |
+
"metadata": {},
|
1179 |
+
"output_type": "display_data"
|
1180 |
+
}
|
1181 |
+
],
|
1182 |
+
"source": [
|
1183 |
+
"# block_size = tokenizer.model_max_length\n",
|
1184 |
+
"block_size = 2048\n",
|
1185 |
+
"\n",
|
1186 |
+
"\n",
|
1187 |
+
"def group_texts(examples):\n",
|
1188 |
+
" # Concatenate all texts.\n",
|
1189 |
+
" concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}\n",
|
1190 |
+
" total_length = len(concatenated_examples[list(examples.keys())[0]])\n",
|
1191 |
+
" # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can\n",
|
1192 |
+
" # customize this part to your needs.\n",
|
1193 |
+
" if total_length >= block_size:\n",
|
1194 |
+
" total_length = (total_length // block_size) * block_size\n",
|
1195 |
+
" # Split by chunks of block_size.\n",
|
1196 |
+
" result = {\n",
|
1197 |
+
" k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n",
|
1198 |
+
" for k, t in concatenated_examples.items()\n",
|
1199 |
+
" }\n",
|
1200 |
+
" result[\"labels\"] = result[\"input_ids\"].copy()\n",
|
1201 |
+
" return result\n",
|
1202 |
+
"\n",
|
1203 |
+
"\"\"\"Apply the `group_texts` function over the entire dataset:\"\"\"\n",
|
1204 |
+
"\n",
|
1205 |
+
"lm_dataset = tokenized_Falcon.map(group_texts, batched=True, num_proc=4)\n"
|
1206 |
+
]
|
1207 |
+
},
|
1208 |
+
{
|
1209 |
+
"cell_type": "code",
|
1210 |
+
"execution_count": 45,
|
1211 |
+
"metadata": {},
|
1212 |
+
"outputs": [],
|
1213 |
+
"source": [
|
1214 |
+
"from transformers import DataCollatorForLanguageModeling\n",
|
1215 |
+
"\n",
|
1216 |
+
"# tokenizer.pad_token = tokenizer.eos_token\n",
|
1217 |
+
"data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)\n"
|
1218 |
+
]
|
1219 |
+
},
|
1220 |
+
{
|
1221 |
+
"cell_type": "code",
|
1222 |
+
"execution_count": null,
|
1223 |
+
"metadata": {},
|
1224 |
+
"outputs": [],
|
1225 |
+
"source": [
|
1226 |
+
"# from transformers import AutoModelForCausalLM, TrainingArguments, Trainer\n",
|
1227 |
+
"# import torch\n",
|
1228 |
+
"# model = AutoModelForCausalLM.from_pretrained(\"tensorplex-labs/pretraining-sn9-7B-5\", torch_dtype=torch.bfloat16)\n",
|
1229 |
+
"\n",
|
1230 |
+
"# print('Model Loaded!')\n"
|
1231 |
+
]
|
1232 |
+
},
|
1233 |
+
{
|
1234 |
+
"cell_type": "code",
|
1235 |
+
"execution_count": 46,
|
1236 |
+
"metadata": {},
|
1237 |
+
"outputs": [
|
1238 |
+
{
|
1239 |
+
"data": {
|
1240 |
+
"text/plain": [
|
1241 |
+
"MistralForCausalLM(\n",
|
1242 |
+
" (model): MistralModel(\n",
|
1243 |
+
" (embed_tokens): Embedding(32000, 4096)\n",
|
1244 |
+
" (layers): ModuleList(\n",
|
1245 |
+
" (0-29): 30 x MistralDecoderLayer(\n",
|
1246 |
+
" (self_attn): MistralSdpaAttention(\n",
|
1247 |
+
" (q_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
|
1248 |
+
" (k_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
|
1249 |
+
" (v_proj): Linear(in_features=4096, out_features=1024, bias=False)\n",
|
1250 |
+
" (o_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
|
1251 |
+
" (rotary_emb): MistralRotaryEmbedding()\n",
|
1252 |
+
" )\n",
|
1253 |
+
" (mlp): MistralMLP(\n",
|
1254 |
+
" (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
|
1255 |
+
" (up_proj): Linear(in_features=4096, out_features=14336, bias=False)\n",
|
1256 |
+
" (down_proj): Linear(in_features=14336, out_features=4096, bias=False)\n",
|
1257 |
+
" (act_fn): SiLU()\n",
|
1258 |
+
" )\n",
|
1259 |
+
" (input_layernorm): MistralRMSNorm()\n",
|
1260 |
+
" (post_attention_layernorm): MistralRMSNorm()\n",
|
1261 |
+
" )\n",
|
1262 |
+
" )\n",
|
1263 |
+
" (norm): MistralRMSNorm()\n",
|
1264 |
+
" )\n",
|
1265 |
+
" (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n",
|
1266 |
+
")"
|
1267 |
+
]
|
1268 |
+
},
|
1269 |
+
"execution_count": 46,
|
1270 |
+
"metadata": {},
|
1271 |
+
"output_type": "execute_result"
|
1272 |
+
}
|
1273 |
+
],
|
1274 |
+
"source": [
|
1275 |
+
"model.to('cuda')"
|
1276 |
+
]
|
1277 |
+
},
|
1278 |
+
{
|
1279 |
+
"cell_type": "code",
|
1280 |
+
"execution_count": 47,
|
1281 |
+
"metadata": {},
|
1282 |
+
"outputs": [
|
1283 |
+
{
|
1284 |
+
"data": {
|
1285 |
+
"text/plain": [
|
1286 |
+
"6805508096"
|
1287 |
+
]
|
1288 |
+
},
|
1289 |
+
"execution_count": 47,
|
1290 |
+
"metadata": {},
|
1291 |
+
"output_type": "execute_result"
|
1292 |
+
}
|
1293 |
+
],
|
1294 |
+
"source": [
|
1295 |
+
"pytorch_total_params = sum(p.numel() for p in model.parameters())\n",
|
1296 |
+
"pytorch_total_params"
|
1297 |
+
]
|
1298 |
+
},
|
1299 |
+
{
|
1300 |
+
"cell_type": "code",
|
1301 |
+
"execution_count": 48,
|
1302 |
+
"metadata": {},
|
1303 |
+
"outputs": [],
|
1304 |
+
"source": [
|
1305 |
+
"training_args = TrainingArguments(\n",
|
1306 |
+
" output_dir=\"Fine-Tuned-S9-2\",\n",
|
1307 |
+
" overwrite_output_dir=True,\n",
|
1308 |
+
" bf16=True,\n",
|
1309 |
+
" # evaluation_strategy=\"epoch\",\n",
|
1310 |
+
" evaluation_strategy=\"steps\",\n",
|
1311 |
+
" learning_rate=2e-5,\n",
|
1312 |
+
" weight_decay=0.01,\n",
|
1313 |
+
" num_train_epochs=1,\n",
|
1314 |
+
" per_device_train_batch_size=2,\n",
|
1315 |
+
" per_device_eval_batch_size=2,\n",
|
1316 |
+
" lr_scheduler_type = 'cosine',\n",
|
1317 |
+
" push_to_hub=False,\n",
|
1318 |
+
" save_total_limit = 2,\n",
|
1319 |
+
" # save_strategy = “no”\n",
|
1320 |
+
" load_best_model_at_end=False,\n",
|
1321 |
+
")\n",
|
1322 |
+
"\n",
|
1323 |
+
"trainer = Trainer(\n",
|
1324 |
+
" model=model,\n",
|
1325 |
+
" args=training_args,\n",
|
1326 |
+
" train_dataset=lm_dataset[\"train\"],\n",
|
1327 |
+
" eval_dataset=lm_dataset[\"validation\"],\n",
|
1328 |
+
" # eval_dataset=lm_dataset[\"test\"],\n",
|
1329 |
+
" data_collator=data_collator,\n",
|
1330 |
+
")"
|
1331 |
+
]
|
1332 |
+
},
|
1333 |
+
{
|
1334 |
+
"cell_type": "code",
|
1335 |
+
"execution_count": 49,
|
1336 |
+
"metadata": {},
|
1337 |
+
"outputs": [
|
1338 |
+
{
|
1339 |
+
"name": "stdout",
|
1340 |
+
"output_type": "stream",
|
1341 |
+
"text": [
|
1342 |
+
"Started Training!\n"
|
1343 |
+
]
|
1344 |
+
},
|
1345 |
+
{
|
1346 |
+
"name": "stderr",
|
1347 |
+
"output_type": "stream",
|
1348 |
+
"text": [
|
1349 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mthatmlguy\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
|
1350 |
+
]
|
1351 |
+
},
|
1352 |
+
{
|
1353 |
+
"data": {
|
1354 |
+
"text/html": [
|
1355 |
+
"Tracking run with wandb version 0.17.0"
|
1356 |
+
],
|
1357 |
+
"text/plain": [
|
1358 |
+
"<IPython.core.display.HTML object>"
|
1359 |
+
]
|
1360 |
+
},
|
1361 |
+
"metadata": {},
|
1362 |
+
"output_type": "display_data"
|
1363 |
+
},
|
1364 |
+
{
|
1365 |
+
"data": {
|
1366 |
+
"text/html": [
|
1367 |
+
"Run data is saved locally in <code>/workspace/ShortGPT/short_gpt/wandb/run-20240516_090043-ni1hktjg</code>"
|
1368 |
+
],
|
1369 |
+
"text/plain": [
|
1370 |
+
"<IPython.core.display.HTML object>"
|
1371 |
+
]
|
1372 |
+
},
|
1373 |
+
"metadata": {},
|
1374 |
+
"output_type": "display_data"
|
1375 |
+
},
|
1376 |
+
{
|
1377 |
+
"data": {
|
1378 |
+
"text/html": [
|
1379 |
+
"Syncing run <strong><a href='https://wandb.ai/thatmlguy/huggingface/runs/ni1hktjg' target=\"_blank\">misty-serenity-4</a></strong> to <a href='https://wandb.ai/thatmlguy/huggingface' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
|
1380 |
+
],
|
1381 |
+
"text/plain": [
|
1382 |
+
"<IPython.core.display.HTML object>"
|
1383 |
+
]
|
1384 |
+
},
|
1385 |
+
"metadata": {},
|
1386 |
+
"output_type": "display_data"
|
1387 |
+
},
|
1388 |
+
{
|
1389 |
+
"data": {
|
1390 |
+
"text/html": [
|
1391 |
+
" View project at <a href='https://wandb.ai/thatmlguy/huggingface' target=\"_blank\">https://wandb.ai/thatmlguy/huggingface</a>"
|
1392 |
+
],
|
1393 |
+
"text/plain": [
|
1394 |
+
"<IPython.core.display.HTML object>"
|
1395 |
+
]
|
1396 |
+
},
|
1397 |
+
"metadata": {},
|
1398 |
+
"output_type": "display_data"
|
1399 |
+
},
|
1400 |
+
{
|
1401 |
+
"data": {
|
1402 |
+
"text/html": [
|
1403 |
+
" View run at <a href='https://wandb.ai/thatmlguy/huggingface/runs/ni1hktjg' target=\"_blank\">https://wandb.ai/thatmlguy/huggingface/runs/ni1hktjg</a>"
|
1404 |
+
],
|
1405 |
+
"text/plain": [
|
1406 |
+
"<IPython.core.display.HTML object>"
|
1407 |
+
]
|
1408 |
+
},
|
1409 |
+
"metadata": {},
|
1410 |
+
"output_type": "display_data"
|
1411 |
+
},
|
1412 |
+
{
|
1413 |
+
"data": {
|
1414 |
+
"text/html": [
|
1415 |
+
"\n",
|
1416 |
+
" <div>\n",
|
1417 |
+
" \n",
|
1418 |
+
" <progress value='2' max='6459' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
1419 |
+
" [ 2/6459 : < :, Epoch 0.00/1]\n",
|
1420 |
+
" </div>\n",
|
1421 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
1422 |
+
" <thead>\n",
|
1423 |
+
" <tr style=\"text-align: left;\">\n",
|
1424 |
+
" <th>Step</th>\n",
|
1425 |
+
" <th>Training Loss</th>\n",
|
1426 |
+
" <th>Validation Loss</th>\n",
|
1427 |
+
" </tr>\n",
|
1428 |
+
" </thead>\n",
|
1429 |
+
" <tbody>\n",
|
1430 |
+
" </tbody>\n",
|
1431 |
+
"</table><p>"
|
1432 |
+
],
|
1433 |
+
"text/plain": [
|
1434 |
+
"<IPython.core.display.HTML object>"
|
1435 |
+
]
|
1436 |
+
},
|
1437 |
+
"metadata": {},
|
1438 |
+
"output_type": "display_data"
|
1439 |
+
},
|
1440 |
+
{
|
1441 |
+
"ename": "OutOfMemoryError",
|
1442 |
+
"evalue": "CUDA out of memory. Tried to allocate 112.00 MiB. GPU ",
|
1443 |
+
"output_type": "error",
|
1444 |
+
"traceback": [
|
1445 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
1446 |
+
"\u001b[0;31mOutOfMemoryError\u001b[0m Traceback (most recent call last)",
|
1447 |
+
"Cell \u001b[0;32mIn[49], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# trainer.train()\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mStarted Training!\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m----> 3\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
|
1448 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:1859\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1857\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n\u001b[1;32m 1858\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1859\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1860\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1861\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1862\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1863\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1864\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
|
1449 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2203\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 2200\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_step_begin(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[1;32m 2202\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39maccumulate(model):\n\u001b[0;32m-> 2203\u001b[0m tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2205\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 2206\u001b[0m args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m 2207\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_xla_available()\n\u001b[1;32m 2208\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch\u001b[38;5;241m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m 2209\u001b[0m ):\n\u001b[1;32m 2210\u001b[0m \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m 2211\u001b[0m tr_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n",
|
1450 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:3138\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(self, model, inputs)\u001b[0m\n\u001b[1;32m 3135\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss_mb\u001b[38;5;241m.\u001b[39mreduce_mean()\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m 3137\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcompute_loss_context_manager():\n\u001b[0;32m-> 3138\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3140\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mn_gpu \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 3141\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mmean() \u001b[38;5;66;03m# mean() to average on multi-gpu parallel training\u001b[39;00m\n",
|
1451 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/trainer.py:3161\u001b[0m, in \u001b[0;36mTrainer.compute_loss\u001b[0;34m(self, model, inputs, return_outputs)\u001b[0m\n\u001b[1;32m 3159\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 3160\u001b[0m labels \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m-> 3161\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3162\u001b[0m \u001b[38;5;66;03m# Save past state if it exists\u001b[39;00m\n\u001b[1;32m 3163\u001b[0m \u001b[38;5;66;03m# TODO: this needs to be fixed and made cleaner later.\u001b[39;00m\n\u001b[1;32m 3164\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mpast_index \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m:\n",
|
1452 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
1453 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
1454 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py:822\u001b[0m, in \u001b[0;36mconvert_outputs_to_fp32.<locals>.forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 821\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 822\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmodel_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
1455 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py:810\u001b[0m, in \u001b[0;36mConvertOutputsToFp32.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 809\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 810\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m convert_to_fp32(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m)\n",
|
1456 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py:16\u001b[0m, in \u001b[0;36mautocast_decorator.<locals>.decorate_autocast\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_autocast\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m autocast_instance:\n\u001b[0;32m---> 16\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
1457 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py:1158\u001b[0m, in \u001b[0;36mMistralForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1155\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m 1157\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[0;32m-> 1158\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1159\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1160\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1161\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1162\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1163\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1164\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1165\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1166\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1167\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1168\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1170\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1171\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlm_head(hidden_states)\n",
|
1458 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
1459 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
1460 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py:1043\u001b[0m, in \u001b[0;36mMistralModel.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1033\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[1;32m 1034\u001b[0m decoder_layer\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[1;32m 1035\u001b[0m hidden_states,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1040\u001b[0m use_cache,\n\u001b[1;32m 1041\u001b[0m )\n\u001b[1;32m 1042\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1043\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_layer\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1044\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1045\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1046\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1047\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1048\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1049\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1050\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1052\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1054\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_cache:\n",
|
1461 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
1462 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
1463 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py:770\u001b[0m, in \u001b[0;36mMistralDecoderLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, **kwargs)\u001b[0m\n\u001b[1;32m 768\u001b[0m residual \u001b[38;5;241m=\u001b[39m hidden_states\n\u001b[1;32m 769\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpost_attention_layernorm(hidden_states)\n\u001b[0;32m--> 770\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmlp\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 771\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m residual \u001b[38;5;241m+\u001b[39m hidden_states\n\u001b[1;32m 773\u001b[0m outputs \u001b[38;5;241m=\u001b[39m (hidden_states,)\n",
|
1464 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
1465 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
1466 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py:179\u001b[0m, in \u001b[0;36mMistralMLP.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[0;32m--> 179\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdown_proj(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mact_fn(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgate_proj\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m) \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mup_proj(x))\n",
|
1467 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
1468 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
1469 |
+
"File \u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:116\u001b[0m, in \u001b[0;36mLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
|
1470 |
+
"\u001b[0;31mOutOfMemoryError\u001b[0m: CUDA out of memory. Tried to allocate 112.00 MiB. GPU "
|
1471 |
+
]
|
1472 |
+
}
|
1473 |
+
],
|
1474 |
+
"source": [
|
1475 |
+
"# trainer.train()\n",
|
1476 |
+
"print('Started Training!')\n",
|
1477 |
+
"trainer.train()"
|
1478 |
+
]
|
1479 |
+
},
|
1480 |
+
{
|
1481 |
+
"cell_type": "code",
|
1482 |
+
"execution_count": null,
|
1483 |
+
"metadata": {},
|
1484 |
+
"outputs": [],
|
1485 |
+
"source": [
|
1486 |
+
"import math\n",
|
1487 |
+
"\n",
|
1488 |
+
"eval_results = trainer.evaluate()\n",
|
1489 |
+
"print(f\"Perplexity: {math.exp(eval_results['eval_loss']):.2f}\")\n"
|
1490 |
+
]
|
1491 |
+
},
|
1492 |
+
{
|
1493 |
+
"cell_type": "code",
|
1494 |
+
"execution_count": 29,
|
1495 |
+
"metadata": {},
|
1496 |
+
"outputs": [],
|
1497 |
+
"source": [
|
1498 |
+
"# # referencing https://github.com/meta-llama/llama-recipes/blob/main/recipes/finetuning/huggingface_trainer/peft_finetuning.ipynb\n",
|
1499 |
+
"# eval_prompt = \"\"\"\n",
|
1500 |
+
"# Summarize this dialog:\n",
|
1501 |
+
"# A: Hi Tom, are you busy tomorrow's afternoon?\n",
|
1502 |
+
"# B: I'm pretty sure I am. What's up?\n",
|
1503 |
+
"# A: Can you go with me to the animal shelter?.\n",
|
1504 |
+
"# B: What do you want to do?\n",
|
1505 |
+
"# A: I want to get a puppy for my son.\n",
|
1506 |
+
"# B: That will make him so happy.\n",
|
1507 |
+
"# A: Yeah, we've discussed it many times. I think he's ready now.\n",
|
1508 |
+
"# B: That's good. Raising a dog is a tough issue. Like having a baby ;-) \n",
|
1509 |
+
"# A: I'll get him one of those little dogs.\n",
|
1510 |
+
"# B: One that won't grow up too big;-)\n",
|
1511 |
+
"# A: And eat too much;-))\n",
|
1512 |
+
"# B: Do you know which one he would like?\n",
|
1513 |
+
"# A: Oh, yes, I took him there last Monday. He showed me one that he really liked.\n",
|
1514 |
+
"# B: I bet you had to drag him away.\n",
|
1515 |
+
"# A: He wanted to take it home right away ;-).\n",
|
1516 |
+
"# B: I wonder what he'll name it.\n",
|
1517 |
+
"# A: He said he'd name it after his dead hamster - Lemmy - he's a great Motorhead fan :-)))\n",
|
1518 |
+
"# ---\n",
|
1519 |
+
"# Summary:\n",
|
1520 |
+
"# \"\"\"\n",
|
1521 |
+
"\n",
|
1522 |
+
"# model_input = tokenizer(eval_prompt, return_tensors=\"pt\").to(\"cuda\")\n",
|
1523 |
+
"\n",
|
1524 |
+
"# model.eval()\n",
|
1525 |
+
"# with torch.no_grad():\n",
|
1526 |
+
"# print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100, use_cache=True)[0], skip_special_tokens=True))"
|
1527 |
+
]
|
1528 |
+
},
|
1529 |
+
{
|
1530 |
+
"cell_type": "code",
|
1531 |
+
"execution_count": 30,
|
1532 |
+
"metadata": {},
|
1533 |
+
"outputs": [],
|
1534 |
+
"source": [
|
1535 |
+
"# def get_preprocessed_samsum():\n",
|
1536 |
+
"# dataset = load_dataset(\"samsum\", split=\"train\")\n",
|
1537 |
+
"\n",
|
1538 |
+
"# prompt = (\n",
|
1539 |
+
"# f\"Summarize this dialog:\\n{{dialog}}\\n---\\nSummary:\\n\"\n",
|
1540 |
+
"# )\n",
|
1541 |
+
"\n",
|
1542 |
+
"# def apply_prompt_template(sample):\n",
|
1543 |
+
"# return {\n",
|
1544 |
+
"# \"prompt\": prompt.format(dialog=sample[\"dialogue\"]),\n",
|
1545 |
+
"# \"summary\": sample[\"summary\"],\n",
|
1546 |
+
"# }\n",
|
1547 |
+
"\n",
|
1548 |
+
"# dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))\n",
|
1549 |
+
"\n",
|
1550 |
+
"# def tokenize_add_label(sample):\n",
|
1551 |
+
"# prompt = tokenizer.encode(tokenizer.bos_token + sample[\"prompt\"], add_special_tokens=False)\n",
|
1552 |
+
"# summary = tokenizer.encode(sample[\"summary\"] + tokenizer.eos_token, add_special_tokens=False)\n",
|
1553 |
+
"# sample = {\n",
|
1554 |
+
"# \"input_ids\": prompt + summary,\n",
|
1555 |
+
"# \"attention_mask\" : [1] * (len(prompt) + len(summary)),\n",
|
1556 |
+
"# \"labels\": [-100] * len(prompt) + summary,\n",
|
1557 |
+
"# }\n",
|
1558 |
+
"\n",
|
1559 |
+
"# return sample\n",
|
1560 |
+
"\n",
|
1561 |
+
"# dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))\n",
|
1562 |
+
"\n",
|
1563 |
+
"# return dataset"
|
1564 |
+
]
|
1565 |
+
},
|
1566 |
+
{
|
1567 |
+
"cell_type": "code",
|
1568 |
+
"execution_count": 31,
|
1569 |
+
"metadata": {},
|
1570 |
+
"outputs": [],
|
1571 |
+
"source": [
|
1572 |
+
"# model.train()\n",
|
1573 |
+
"\n",
|
1574 |
+
"# def create_peft_config(model):\n",
|
1575 |
+
"# peft_config = LoraConfig(\n",
|
1576 |
+
"# task_type=TaskType.CAUSAL_LM,\n",
|
1577 |
+
"# inference_mode=False,\n",
|
1578 |
+
"# r=8,\n",
|
1579 |
+
"# lora_alpha=32,\n",
|
1580 |
+
"# lora_dropout=0.05,\n",
|
1581 |
+
"# target_modules = [\"q_proj\", \"v_proj\"]\n",
|
1582 |
+
"# )\n",
|
1583 |
+
"\n",
|
1584 |
+
"# model = get_peft_model(model, peft_config)\n",
|
1585 |
+
"# model.print_trainable_parameters()\n",
|
1586 |
+
"# return model, peft_config\n",
|
1587 |
+
"\n",
|
1588 |
+
"# # create peft config\n",
|
1589 |
+
"# model, lora_config = create_peft_config(model)"
|
1590 |
+
]
|
1591 |
+
},
|
1592 |
+
{
|
1593 |
+
"cell_type": "code",
|
1594 |
+
"execution_count": 32,
|
1595 |
+
"metadata": {},
|
1596 |
+
"outputs": [],
|
1597 |
+
"source": [
|
1598 |
+
"# output_dir = \"tmp/\"\n",
|
1599 |
+
"\n",
|
1600 |
+
"# config = {\n",
|
1601 |
+
"# 'lora_config': lora_config,\n",
|
1602 |
+
"# 'learning_rate': 1e-6,\n",
|
1603 |
+
"# 'num_train_epochs': 1,\n",
|
1604 |
+
"# 'per_device_train_batch_size': 1,\n",
|
1605 |
+
"# 'gradient_checkpointing': False,\n",
|
1606 |
+
"# }\n"
|
1607 |
+
]
|
1608 |
+
},
|
1609 |
+
{
|
1610 |
+
"cell_type": "code",
|
1611 |
+
"execution_count": 33,
|
1612 |
+
"metadata": {},
|
1613 |
+
"outputs": [],
|
1614 |
+
"source": [
|
1615 |
+
"# training_args = TrainingArguments(\n",
|
1616 |
+
"# output_dir=output_dir,\n",
|
1617 |
+
"# overwrite_output_dir=True,\n",
|
1618 |
+
"# # logging strategies\n",
|
1619 |
+
"# logging_strategy=\"steps\",\n",
|
1620 |
+
"# logging_steps=10,\n",
|
1621 |
+
"# save_strategy=\"no\",\n",
|
1622 |
+
"# optim=\"adamw_torch_fused\",\n",
|
1623 |
+
"# **{k:v for k,v in config.items() if k != 'lora_config'}\n",
|
1624 |
+
"# )\n",
|
1625 |
+
"\n",
|
1626 |
+
"# # Create Trainer instance\n",
|
1627 |
+
"# trainer = Trainer(\n",
|
1628 |
+
"# model=model,\n",
|
1629 |
+
"# args=training_args,\n",
|
1630 |
+
"# train_dataset=get_preprocessed_samsum(),\n",
|
1631 |
+
"# data_collator=default_data_collator,\n",
|
1632 |
+
"# callbacks=[],\n",
|
1633 |
+
"# )\n",
|
1634 |
+
"\n",
|
1635 |
+
"# # Start training\n",
|
1636 |
+
"# trainer.train()"
|
1637 |
+
]
|
1638 |
+
},
|
1639 |
+
{
|
1640 |
+
"cell_type": "code",
|
1641 |
+
"execution_count": 34,
|
1642 |
+
"metadata": {},
|
1643 |
+
"outputs": [],
|
1644 |
+
"source": [
|
1645 |
+
"# model.eval()\n",
|
1646 |
+
"# with torch.no_grad():\n",
|
1647 |
+
"# print(tokenizer.decode(model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))"
|
1648 |
+
]
|
1649 |
+
},
|
1650 |
+
{
|
1651 |
+
"cell_type": "code",
|
1652 |
+
"execution_count": null,
|
1653 |
+
"metadata": {},
|
1654 |
+
"outputs": [],
|
1655 |
+
"source": []
|
1656 |
+
}
|
1657 |
+
],
|
1658 |
+
"metadata": {
|
1659 |
+
"kernelspec": {
|
1660 |
+
"display_name": "Python 3 (ipykernel)",
|
1661 |
+
"language": "python",
|
1662 |
+
"name": "python3"
|
1663 |
+
},
|
1664 |
+
"language_info": {
|
1665 |
+
"codemirror_mode": {
|
1666 |
+
"name": "ipython",
|
1667 |
+
"version": 3
|
1668 |
+
},
|
1669 |
+
"file_extension": ".py",
|
1670 |
+
"mimetype": "text/x-python",
|
1671 |
+
"name": "python",
|
1672 |
+
"nbconvert_exporter": "python",
|
1673 |
+
"pygments_lexer": "ipython3",
|
1674 |
+
"version": "3.11.9"
|
1675 |
+
}
|
1676 |
+
},
|
1677 |
+
"nbformat": 4,
|
1678 |
+
"nbformat_minor": 4
|
1679 |
+
}
|
short_gpt/short_llama.ipynb
ADDED
@@ -0,0 +1,573 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"from tqdm.notebook import tqdm\n",
|
10 |
+
"\n",
|
11 |
+
"from datasets import load_dataset\n",
|
12 |
+
"import torch\n",
|
13 |
+
"from torch.utils.data import DataLoader\n",
|
14 |
+
"\n",
|
15 |
+
"from llama import Llama\n",
|
16 |
+
"\n",
|
17 |
+
"from short_llama import ShortLlama"
|
18 |
+
]
|
19 |
+
},
|
20 |
+
{
|
21 |
+
"cell_type": "markdown",
|
22 |
+
"metadata": {},
|
23 |
+
"source": [
|
24 |
+
"### Load Data"
|
25 |
+
]
|
26 |
+
},
|
27 |
+
{
|
28 |
+
"cell_type": "code",
|
29 |
+
"execution_count": 2,
|
30 |
+
"metadata": {},
|
31 |
+
"outputs": [
|
32 |
+
{
|
33 |
+
"name": "stderr",
|
34 |
+
"output_type": "stream",
|
35 |
+
"text": [
|
36 |
+
"c:\\Users\\Shivaen\\anaconda3\\envs\\shortgpt\\lib\\site-packages\\datasets\\load.py:1461: FutureWarning: The repository for pg19 contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/pg19\n",
|
37 |
+
"You can avoid this message in future by passing the argument `trust_remote_code=True`.\n",
|
38 |
+
"Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.\n",
|
39 |
+
" warnings.warn(\n"
|
40 |
+
]
|
41 |
+
}
|
42 |
+
],
|
43 |
+
"source": [
|
44 |
+
"data = load_dataset(\"pg19\", split=\"validation\") # authors sample 10,000 texts to compute block influences\n",
|
45 |
+
"dataloader = DataLoader(\n",
|
46 |
+
" data,\n",
|
47 |
+
" batch_size=1,\n",
|
48 |
+
" shuffle=True,\n",
|
49 |
+
" generator=torch.Generator(device=\"cuda\")\n",
|
50 |
+
")"
|
51 |
+
]
|
52 |
+
},
|
53 |
+
{
|
54 |
+
"cell_type": "markdown",
|
55 |
+
"metadata": {},
|
56 |
+
"source": [
|
57 |
+
"### Fetch and Wrap Model"
|
58 |
+
]
|
59 |
+
},
|
60 |
+
{
|
61 |
+
"cell_type": "code",
|
62 |
+
"execution_count": 3,
|
63 |
+
"metadata": {},
|
64 |
+
"outputs": [
|
65 |
+
{
|
66 |
+
"name": "stdout",
|
67 |
+
"output_type": "stream",
|
68 |
+
"text": [
|
69 |
+
"> initializing model parallel with size 1\n",
|
70 |
+
"> initializing ddp with size 1\n",
|
71 |
+
"> initializing pipeline with size 1\n"
|
72 |
+
]
|
73 |
+
},
|
74 |
+
{
|
75 |
+
"name": "stderr",
|
76 |
+
"output_type": "stream",
|
77 |
+
"text": [
|
78 |
+
"c:\\Users\\Shivaen\\anaconda3\\envs\\shortgpt\\lib\\site-packages\\torch\\__init__.py:696: UserWarning: torch.set_default_tensor_type() is deprecated as of PyTorch 2.1, please use torch.set_default_dtype() and torch.set_default_device() as alternatives. (Triggered internally at C:\\cb\\pytorch_1000000000000\\work\\torch\\csrc\\tensor\\python_tensor.cpp:453.)\n",
|
79 |
+
" _C._set_default_tensor_type(t)\n"
|
80 |
+
]
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"name": "stdout",
|
84 |
+
"output_type": "stream",
|
85 |
+
"text": [
|
86 |
+
"Loaded in 10.96 seconds\n"
|
87 |
+
]
|
88 |
+
}
|
89 |
+
],
|
90 |
+
"source": [
|
91 |
+
"MAX_SEQ_LEN = 1024 # authors use a context width of 1024\n",
|
92 |
+
"llama = Llama.build(\n",
|
93 |
+
" ckpt_dir=\"../../llama/llama-2-7b\",\n",
|
94 |
+
" tokenizer_path=\"../../llama/tokenizer.model\",\n",
|
95 |
+
" max_seq_len=MAX_SEQ_LEN,\n",
|
96 |
+
" max_batch_size=1,\n",
|
97 |
+
")"
|
98 |
+
]
|
99 |
+
},
|
100 |
+
{
|
101 |
+
"cell_type": "code",
|
102 |
+
"execution_count": 4,
|
103 |
+
"metadata": {},
|
104 |
+
"outputs": [
|
105 |
+
{
|
106 |
+
"data": {
|
107 |
+
"text/plain": [
|
108 |
+
"ModuleList(\n",
|
109 |
+
" (0-31): 32 x TransformerBlock(\n",
|
110 |
+
" (attention): Attention(\n",
|
111 |
+
" (wq): ColumnParallelLinear()\n",
|
112 |
+
" (wk): ColumnParallelLinear()\n",
|
113 |
+
" (wv): ColumnParallelLinear()\n",
|
114 |
+
" (wo): RowParallelLinear()\n",
|
115 |
+
" )\n",
|
116 |
+
" (feed_forward): FeedForward(\n",
|
117 |
+
" (w1): ColumnParallelLinear()\n",
|
118 |
+
" (w2): RowParallelLinear()\n",
|
119 |
+
" (w3): ColumnParallelLinear()\n",
|
120 |
+
" )\n",
|
121 |
+
" (attention_norm): RMSNorm()\n",
|
122 |
+
" (ffn_norm): RMSNorm()\n",
|
123 |
+
" )\n",
|
124 |
+
")"
|
125 |
+
]
|
126 |
+
},
|
127 |
+
"execution_count": 4,
|
128 |
+
"metadata": {},
|
129 |
+
"output_type": "execute_result"
|
130 |
+
}
|
131 |
+
],
|
132 |
+
"source": [
|
133 |
+
"short_llama = ShortLlama(llama=llama, n_prune_layers=9)\n",
|
134 |
+
"\n",
|
135 |
+
"short_llama.llama.model.layers"
|
136 |
+
]
|
137 |
+
},
|
138 |
+
{
|
139 |
+
"cell_type": "code",
|
140 |
+
"execution_count": 5,
|
141 |
+
"metadata": {},
|
142 |
+
"outputs": [
|
143 |
+
{
|
144 |
+
"data": {
|
145 |
+
"text/plain": [
|
146 |
+
"[{'generation': '1960s-70s era pop music. I grew up listening to the radio'}]"
|
147 |
+
]
|
148 |
+
},
|
149 |
+
"execution_count": 5,
|
150 |
+
"metadata": {},
|
151 |
+
"output_type": "execute_result"
|
152 |
+
}
|
153 |
+
],
|
154 |
+
"source": [
|
155 |
+
"# sample generation\n",
|
156 |
+
"short_llama.llama.text_completion(\n",
|
157 |
+
" prompts=[\"I am an avid fan of \"],\n",
|
158 |
+
" max_gen_len=20\n",
|
159 |
+
")"
|
160 |
+
]
|
161 |
+
},
|
162 |
+
{
|
163 |
+
"cell_type": "markdown",
|
164 |
+
"metadata": {},
|
165 |
+
"source": [
|
166 |
+
"### Compute Importances"
|
167 |
+
]
|
168 |
+
},
|
169 |
+
{
|
170 |
+
"cell_type": "code",
|
171 |
+
"execution_count": 6,
|
172 |
+
"metadata": {},
|
173 |
+
"outputs": [
|
174 |
+
{
|
175 |
+
"data": {
|
176 |
+
"application/vnd.jupyter.widget-view+json": {
|
177 |
+
"model_id": "bf50ed0464aa454386d996e71b4541b4",
|
178 |
+
"version_major": 2,
|
179 |
+
"version_minor": 0
|
180 |
+
},
|
181 |
+
"text/plain": [
|
182 |
+
" 0%| | 0/50 [00:00<?, ?it/s]"
|
183 |
+
]
|
184 |
+
},
|
185 |
+
"metadata": {},
|
186 |
+
"output_type": "display_data"
|
187 |
+
}
|
188 |
+
],
|
189 |
+
"source": [
|
190 |
+
"for batch in tqdm(dataloader):\n",
|
191 |
+
" prompts = batch['text']\n",
|
192 |
+
"\n",
|
193 |
+
" prompt_tokens = [short_llama.llama.tokenizer.encode(x, bos=True, eos=False) for x in prompts]\n",
|
194 |
+
" max_prompt_len = max(len(t) for t in prompt_tokens)\n",
|
195 |
+
"\n",
|
196 |
+
" # authors use a sliding window of size 1024 with a shift of 256\n",
|
197 |
+
" for start in range(0, max_prompt_len, 256):\n",
|
198 |
+
"\n",
|
199 |
+
" inputs = [p[start:start+MAX_SEQ_LEN] for p in prompt_tokens if len(p) > start]\n",
|
200 |
+
"\n",
|
201 |
+
" short_llama.eval_importance(\n",
|
202 |
+
" prompt_tokens=inputs,\n",
|
203 |
+
" max_gen_len=0\n",
|
204 |
+
" )"
|
205 |
+
]
|
206 |
+
},
|
207 |
+
{
|
208 |
+
"cell_type": "code",
|
209 |
+
"execution_count": 7,
|
210 |
+
"metadata": {},
|
211 |
+
"outputs": [
|
212 |
+
{
|
213 |
+
"data": {
|
214 |
+
"text/plain": [
|
215 |
+
"[8358921.716796875,\n",
|
216 |
+
" 5211709.220703125,\n",
|
217 |
+
" 3259066.66796875,\n",
|
218 |
+
" 3164092.5087890625,\n",
|
219 |
+
" 3518517.248046875,\n",
|
220 |
+
" 3153696.0009765625,\n",
|
221 |
+
" 3062620.751953125,\n",
|
222 |
+
" 2856062.2998046875,\n",
|
223 |
+
" 2674124.23828125,\n",
|
224 |
+
" 2545894.03125,\n",
|
225 |
+
" 2382950.501953125,\n",
|
226 |
+
" 2194983.1455078125,\n",
|
227 |
+
" 2146358.5107421875,\n",
|
228 |
+
" 2180816.779296875,\n",
|
229 |
+
" 2145900.15234375,\n",
|
230 |
+
" 2126212.3974609375,\n",
|
231 |
+
" 2180678.5244140625,\n",
|
232 |
+
" 1686190.7548828125,\n",
|
233 |
+
" 1524035.5732421875,\n",
|
234 |
+
" 1270041.162109375,\n",
|
235 |
+
" 1368594.52734375,\n",
|
236 |
+
" 954588.056640625,\n",
|
237 |
+
" 944560.7900390625,\n",
|
238 |
+
" 780482.943359375,\n",
|
239 |
+
" 743930.5283203125,\n",
|
240 |
+
" 732873.1806640625,\n",
|
241 |
+
" 745402.265625,\n",
|
242 |
+
" 733417.81640625,\n",
|
243 |
+
" 762292.994140625,\n",
|
244 |
+
" 771143.9541015625,\n",
|
245 |
+
" 1303522.251953125,\n",
|
246 |
+
" 5824847.5546875]"
|
247 |
+
]
|
248 |
+
},
|
249 |
+
"execution_count": 7,
|
250 |
+
"metadata": {},
|
251 |
+
"output_type": "execute_result"
|
252 |
+
}
|
253 |
+
],
|
254 |
+
"source": [
|
255 |
+
"short_llama.importances"
|
256 |
+
]
|
257 |
+
},
|
258 |
+
{
|
259 |
+
"cell_type": "markdown",
|
260 |
+
"metadata": {},
|
261 |
+
"source": [
|
262 |
+
"### Remove unimportant layers\n",
|
263 |
+
"\n",
|
264 |
+
"Layers removed when using pg19 val set: [25, 27, 24, 26, 28, 29, 23, 22, 21]\n",
|
265 |
+
"\n",
|
266 |
+
"Note: Different order than paper but same 9 least important layers -> [27, 26, 25, 28, 24, 29, 23, 21, 22]\n",
|
267 |
+
"\n",
|
268 |
+
"Additionally, authors mention that the layer order is quite nuanced and can vary with different datasets. However, relative order suggests similar importance."
|
269 |
+
]
|
270 |
+
},
|
271 |
+
{
|
272 |
+
"cell_type": "code",
|
273 |
+
"execution_count": 8,
|
274 |
+
"metadata": {},
|
275 |
+
"outputs": [
|
276 |
+
{
|
277 |
+
"data": {
|
278 |
+
"text/plain": [
|
279 |
+
"[25, 27, 24, 26, 28, 29, 23, 22, 21]"
|
280 |
+
]
|
281 |
+
},
|
282 |
+
"execution_count": 8,
|
283 |
+
"metadata": {},
|
284 |
+
"output_type": "execute_result"
|
285 |
+
}
|
286 |
+
],
|
287 |
+
"source": [
|
288 |
+
"short_llama.remove_layers()"
|
289 |
+
]
|
290 |
+
},
|
291 |
+
{
|
292 |
+
"cell_type": "code",
|
293 |
+
"execution_count": 9,
|
294 |
+
"metadata": {},
|
295 |
+
"outputs": [
|
296 |
+
{
|
297 |
+
"data": {
|
298 |
+
"text/plain": [
|
299 |
+
"ModuleList(\n",
|
300 |
+
" (0-22): 23 x TransformerBlock(\n",
|
301 |
+
" (attention): Attention(\n",
|
302 |
+
" (wq): ColumnParallelLinear()\n",
|
303 |
+
" (wk): ColumnParallelLinear()\n",
|
304 |
+
" (wv): ColumnParallelLinear()\n",
|
305 |
+
" (wo): RowParallelLinear()\n",
|
306 |
+
" )\n",
|
307 |
+
" (feed_forward): FeedForward(\n",
|
308 |
+
" (w1): ColumnParallelLinear()\n",
|
309 |
+
" (w2): RowParallelLinear()\n",
|
310 |
+
" (w3): ColumnParallelLinear()\n",
|
311 |
+
" )\n",
|
312 |
+
" (attention_norm): RMSNorm()\n",
|
313 |
+
" (ffn_norm): RMSNorm()\n",
|
314 |
+
" )\n",
|
315 |
+
")"
|
316 |
+
]
|
317 |
+
},
|
318 |
+
"execution_count": 9,
|
319 |
+
"metadata": {},
|
320 |
+
"output_type": "execute_result"
|
321 |
+
}
|
322 |
+
],
|
323 |
+
"source": [
|
324 |
+
"short_llama.llama.model.layers"
|
325 |
+
]
|
326 |
+
},
|
327 |
+
{
|
328 |
+
"cell_type": "markdown",
|
329 |
+
"metadata": {},
|
330 |
+
"source": [
|
331 |
+
"As the paper states: \\\n",
|
332 |
+
" - \"Our experiments reveal that the effect of layer removal is significantly more pronounced on generative\n",
|
333 |
+
" tasks compared to multiple-choice tasks. On benchmarks such as GSM8K (Cobbe et al., 2021) and\n",
|
334 |
+
" HumanEval (Chen et al., 2021), removing 25% of the layers often leads to a severe performance\n",
|
335 |
+
" drop, with scores approaching zero.\""
|
336 |
+
]
|
337 |
+
},
|
338 |
+
{
|
339 |
+
"cell_type": "code",
|
340 |
+
"execution_count": 10,
|
341 |
+
"metadata": {},
|
342 |
+
"outputs": [
|
343 |
+
{
|
344 |
+
"data": {
|
345 |
+
"text/plain": [
|
346 |
+
"[{'generation': 'Đo n Khơ 20th Century. Hinweis: In = ,t and lồ'}]"
|
347 |
+
]
|
348 |
+
},
|
349 |
+
"execution_count": 10,
|
350 |
+
"metadata": {},
|
351 |
+
"output_type": "execute_result"
|
352 |
+
}
|
353 |
+
],
|
354 |
+
"source": [
|
355 |
+
"short_llama.llama.text_completion(\n",
|
356 |
+
" prompts=[\"I am an avid fan of \"],\n",
|
357 |
+
" max_gen_len=20\n",
|
358 |
+
")"
|
359 |
+
]
|
360 |
+
},
|
361 |
+
{
|
362 |
+
"cell_type": "markdown",
|
363 |
+
"metadata": {},
|
364 |
+
"source": [
|
365 |
+
"### Compute Angular Importances"
|
366 |
+
]
|
367 |
+
},
|
368 |
+
{
|
369 |
+
"cell_type": "code",
|
370 |
+
"execution_count": 6,
|
371 |
+
"metadata": {},
|
372 |
+
"outputs": [
|
373 |
+
{
|
374 |
+
"data": {
|
375 |
+
"application/vnd.jupyter.widget-view+json": {
|
376 |
+
"model_id": "8ae0be70aa9344edbd252648c84e08e0",
|
377 |
+
"version_major": 2,
|
378 |
+
"version_minor": 0
|
379 |
+
},
|
380 |
+
"text/plain": [
|
381 |
+
" 0%| | 0/50 [00:00<?, ?it/s]"
|
382 |
+
]
|
383 |
+
},
|
384 |
+
"metadata": {},
|
385 |
+
"output_type": "display_data"
|
386 |
+
}
|
387 |
+
],
|
388 |
+
"source": [
|
389 |
+
"for batch in tqdm(dataloader):\n",
|
390 |
+
" prompts = batch['text']\n",
|
391 |
+
"\n",
|
392 |
+
" prompt_tokens = [short_llama.llama.tokenizer.encode(x, bos=True, eos=False) for x in prompts]\n",
|
393 |
+
" max_prompt_len = max(len(t) for t in prompt_tokens)\n",
|
394 |
+
"\n",
|
395 |
+
" # authors use a sliding window of size 1024 with a shift of 256\n",
|
396 |
+
" for start in range(0, max_prompt_len, 256):\n",
|
397 |
+
"\n",
|
398 |
+
" inputs = [p[start:start+MAX_SEQ_LEN] for p in prompt_tokens if len(p) > start]\n",
|
399 |
+
"\n",
|
400 |
+
" short_llama.eval_importance(\n",
|
401 |
+
" prompt_tokens=inputs,\n",
|
402 |
+
" max_gen_len=0,\n",
|
403 |
+
" angular=True\n",
|
404 |
+
" )"
|
405 |
+
]
|
406 |
+
},
|
407 |
+
{
|
408 |
+
"cell_type": "code",
|
409 |
+
"execution_count": 7,
|
410 |
+
"metadata": {},
|
411 |
+
"outputs": [
|
412 |
+
{
|
413 |
+
"data": {
|
414 |
+
"text/plain": [
|
415 |
+
"[8640.460205078125,\n",
|
416 |
+
" 7881.541015625,\n",
|
417 |
+
" 7303.3876953125,\n",
|
418 |
+
" 7156.226318359375,\n",
|
419 |
+
" 7003.533935546875,\n",
|
420 |
+
" 6749.5189208984375,\n",
|
421 |
+
" 6630.6031494140625,\n",
|
422 |
+
" 6494.6051025390625,\n",
|
423 |
+
" 6475.490295410156,\n",
|
424 |
+
" 6482.81884765625,\n",
|
425 |
+
" 6489.277587890625,\n",
|
426 |
+
" 6479.0064697265625,\n",
|
427 |
+
" 6486.2188720703125,\n",
|
428 |
+
" 6440.6580810546875,\n",
|
429 |
+
" 6338.8604736328125,\n",
|
430 |
+
" 6196.098876953125,\n",
|
431 |
+
" 6014.3204345703125,\n",
|
432 |
+
" 5677.5113525390625,\n",
|
433 |
+
" 5532.0673828125,\n",
|
434 |
+
" 5384.6334228515625,\n",
|
435 |
+
" 5314.61669921875,\n",
|
436 |
+
" 5176.587646484375,\n",
|
437 |
+
" 5425.315673828125,\n",
|
438 |
+
" 7029.1893310546875,\n",
|
439 |
+
" 0,\n",
|
440 |
+
" 0,\n",
|
441 |
+
" 0,\n",
|
442 |
+
" 0,\n",
|
443 |
+
" 0,\n",
|
444 |
+
" 0,\n",
|
445 |
+
" 0,\n",
|
446 |
+
" 0]"
|
447 |
+
]
|
448 |
+
},
|
449 |
+
"execution_count": 7,
|
450 |
+
"metadata": {},
|
451 |
+
"output_type": "execute_result"
|
452 |
+
}
|
453 |
+
],
|
454 |
+
"source": [
|
455 |
+
"short_llama.importances"
|
456 |
+
]
|
457 |
+
},
|
458 |
+
{
|
459 |
+
"cell_type": "markdown",
|
460 |
+
"metadata": {},
|
461 |
+
"source": [
|
462 |
+
"### Remove unimportant layers"
|
463 |
+
]
|
464 |
+
},
|
465 |
+
{
|
466 |
+
"cell_type": "code",
|
467 |
+
"execution_count": 8,
|
468 |
+
"metadata": {},
|
469 |
+
"outputs": [
|
470 |
+
{
|
471 |
+
"data": {
|
472 |
+
"text/plain": [
|
473 |
+
"[21, 22, 23, 24, 25, 26, 27, 28, 29]"
|
474 |
+
]
|
475 |
+
},
|
476 |
+
"execution_count": 8,
|
477 |
+
"metadata": {},
|
478 |
+
"output_type": "execute_result"
|
479 |
+
}
|
480 |
+
],
|
481 |
+
"source": [
|
482 |
+
"short_llama.remove_layers(angular=True)"
|
483 |
+
]
|
484 |
+
},
|
485 |
+
{
|
486 |
+
"cell_type": "code",
|
487 |
+
"execution_count": 9,
|
488 |
+
"metadata": {},
|
489 |
+
"outputs": [
|
490 |
+
{
|
491 |
+
"data": {
|
492 |
+
"text/plain": [
|
493 |
+
"ModuleList(\n",
|
494 |
+
" (0-22): 23 x TransformerBlock(\n",
|
495 |
+
" (attention): Attention(\n",
|
496 |
+
" (wq): ColumnParallelLinear()\n",
|
497 |
+
" (wk): ColumnParallelLinear()\n",
|
498 |
+
" (wv): ColumnParallelLinear()\n",
|
499 |
+
" (wo): RowParallelLinear()\n",
|
500 |
+
" )\n",
|
501 |
+
" (feed_forward): FeedForward(\n",
|
502 |
+
" (w1): ColumnParallelLinear()\n",
|
503 |
+
" (w2): RowParallelLinear()\n",
|
504 |
+
" (w3): ColumnParallelLinear()\n",
|
505 |
+
" )\n",
|
506 |
+
" (attention_norm): RMSNorm()\n",
|
507 |
+
" (ffn_norm): RMSNorm()\n",
|
508 |
+
" )\n",
|
509 |
+
")"
|
510 |
+
]
|
511 |
+
},
|
512 |
+
"execution_count": 9,
|
513 |
+
"metadata": {},
|
514 |
+
"output_type": "execute_result"
|
515 |
+
}
|
516 |
+
],
|
517 |
+
"source": [
|
518 |
+
"short_llama.llama.model.layers"
|
519 |
+
]
|
520 |
+
},
|
521 |
+
{
|
522 |
+
"cell_type": "code",
|
523 |
+
"execution_count": 10,
|
524 |
+
"metadata": {},
|
525 |
+
"outputs": [
|
526 |
+
{
|
527 |
+
"data": {
|
528 |
+
"text/plain": [
|
529 |
+
"[{'generation': 'Đo n Khơ 20th Century. Hinweis: In = ,t and lồ'}]"
|
530 |
+
]
|
531 |
+
},
|
532 |
+
"execution_count": 10,
|
533 |
+
"metadata": {},
|
534 |
+
"output_type": "execute_result"
|
535 |
+
}
|
536 |
+
],
|
537 |
+
"source": [
|
538 |
+
"short_llama.llama.text_completion(\n",
|
539 |
+
" prompts=[\"I am an avid fan of \"],\n",
|
540 |
+
" max_gen_len=20\n",
|
541 |
+
")"
|
542 |
+
]
|
543 |
+
},
|
544 |
+
{
|
545 |
+
"cell_type": "code",
|
546 |
+
"execution_count": null,
|
547 |
+
"metadata": {},
|
548 |
+
"outputs": [],
|
549 |
+
"source": []
|
550 |
+
}
|
551 |
+
],
|
552 |
+
"metadata": {
|
553 |
+
"kernelspec": {
|
554 |
+
"display_name": "shortgpt",
|
555 |
+
"language": "python",
|
556 |
+
"name": "python3"
|
557 |
+
},
|
558 |
+
"language_info": {
|
559 |
+
"codemirror_mode": {
|
560 |
+
"name": "ipython",
|
561 |
+
"version": 3
|
562 |
+
},
|
563 |
+
"file_extension": ".py",
|
564 |
+
"mimetype": "text/x-python",
|
565 |
+
"name": "python",
|
566 |
+
"nbconvert_exporter": "python",
|
567 |
+
"pygments_lexer": "ipython3",
|
568 |
+
"version": "3.9.18"
|
569 |
+
}
|
570 |
+
},
|
571 |
+
"nbformat": 4,
|
572 |
+
"nbformat_minor": 2
|
573 |
+
}
|
short_gpt/short_llama.py
ADDED
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from llama import Llama, Transformer
|
7 |
+
|
8 |
+
from metrics import *
|
9 |
+
|
10 |
+
|
11 |
+
def sample_top_p(probs: torch.Tensor, p: float):
|
12 |
+
"""
|
13 |
+
Perform top-p (nucleus) sampling on a probability distribution.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
probs (torch.Tensor): Probability distribution tensor.
|
17 |
+
p (float): Probability threshold for top-p sampling.
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
torch.Tensor: Sampled token indices.
|
21 |
+
|
22 |
+
Note:
|
23 |
+
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
|
24 |
+
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
|
25 |
+
|
26 |
+
"""
|
27 |
+
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
28 |
+
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
29 |
+
mask = probs_sum - probs_sort > p
|
30 |
+
probs_sort[mask] = 0.0
|
31 |
+
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
32 |
+
next_token = torch.multinomial(probs_sort, num_samples=1)
|
33 |
+
next_token = torch.gather(probs_idx, -1, next_token)
|
34 |
+
return next_token
|
35 |
+
|
36 |
+
|
37 |
+
class TransformerWrapper(Transformer):
|
38 |
+
def __init__(self, model):
|
39 |
+
self.__dict__ = model.__dict__.copy()
|
40 |
+
|
41 |
+
@torch.inference_mode()
|
42 |
+
def forward(
|
43 |
+
self,
|
44 |
+
tokens: torch.Tensor,
|
45 |
+
start_pos: int,
|
46 |
+
return_hiddens: Optional[bool] = False):
|
47 |
+
"""
|
48 |
+
Perform a forward pass through the Transformer model.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
tokens (torch.Tensor): Input token indices.
|
52 |
+
start_pos (int): Starting position for attention caching.
|
53 |
+
(Optional) return_hiddens (bool): Whether to return hidden states. Defaults to False.
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
torch.Tensor: Output logits after applying the Transformer model.
|
57 |
+
(Optional) List[torch.Tensor]: Hidden states for each transformer block.
|
58 |
+
"""
|
59 |
+
_bsz, seqlen = tokens.shape
|
60 |
+
h = self.tok_embeddings(tokens)
|
61 |
+
self.freqs_cis = self.freqs_cis.to(h.device)
|
62 |
+
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
|
63 |
+
|
64 |
+
mask = None
|
65 |
+
if seqlen > 1:
|
66 |
+
mask = torch.full(
|
67 |
+
(seqlen, seqlen), float("-inf"), device=tokens.device
|
68 |
+
)
|
69 |
+
|
70 |
+
mask = torch.triu(mask, diagonal=1)
|
71 |
+
|
72 |
+
# When performing key-value caching, we compute the attention scores
|
73 |
+
# only for the new sequence. Thus, the matrix of scores is of size
|
74 |
+
# (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
|
75 |
+
# j > cache_len + i, since row i corresponds to token cache_len + i.
|
76 |
+
mask = torch.hstack([
|
77 |
+
torch.zeros((seqlen, start_pos), device=tokens.device),
|
78 |
+
mask
|
79 |
+
]).type_as(h)
|
80 |
+
|
81 |
+
hiddens = [h]
|
82 |
+
for layer in self.layers:
|
83 |
+
h = layer(h, start_pos, freqs_cis, mask)
|
84 |
+
if return_hiddens:
|
85 |
+
hiddens.append(h)
|
86 |
+
|
87 |
+
h = self.norm(h)
|
88 |
+
output = self.output(h).float()
|
89 |
+
|
90 |
+
if return_hiddens:
|
91 |
+
return output, hiddens
|
92 |
+
|
93 |
+
return output
|
94 |
+
|
95 |
+
|
96 |
+
class ShortLlama():
|
97 |
+
|
98 |
+
def __init__(self, llama: Llama, n_prune_layers: Optional[int] = None):
|
99 |
+
checkpoint = llama.model.state_dict()
|
100 |
+
llama.model = TransformerWrapper(llama.model) # wrap transformer to collect hidden states
|
101 |
+
llama.model.load_state_dict(checkpoint, strict=False)
|
102 |
+
self.llama = llama
|
103 |
+
|
104 |
+
self.n_prune_layers = n_prune_layers
|
105 |
+
self.importances = [0 for _ in self.llama.model.layers] # layer-wise importance scores
|
106 |
+
|
107 |
+
def remove_layers(
|
108 |
+
self,
|
109 |
+
layers_to_remove: Optional[List[int]] = [],
|
110 |
+
angular: Optional[bool] = False
|
111 |
+
):
|
112 |
+
if angular:
|
113 |
+
assert self.importances, "Need to compute importances with eval_importance()"
|
114 |
+
assert self.n_prune_layers, "Need number of layers to prune, set `n_prune_layers`"
|
115 |
+
start_layer = np.argsort(np.array(self.importances[:-self.n_prune_layers+1]))[0]
|
116 |
+
layers_to_remove = list(range(start_layer, start_layer + self.n_prune_layers))
|
117 |
+
elif not layers_to_remove and self.n_prune_layers:
|
118 |
+
assert self.importances, "Need to compute importances with eval_importance()"
|
119 |
+
layers_to_remove = np.argsort(np.array(self.importances))[:self.n_prune_layers].tolist()
|
120 |
+
|
121 |
+
# remove layers in reverse to avoid indexing errors
|
122 |
+
for layer_idx in sorted(layers_to_remove, reverse=True):
|
123 |
+
try:
|
124 |
+
del self.llama.model.layers[layer_idx]
|
125 |
+
except IndexError:
|
126 |
+
print(f"layer {layer_idx} does not exist, function may have already been called")
|
127 |
+
return []
|
128 |
+
|
129 |
+
return layers_to_remove
|
130 |
+
|
131 |
+
def compute_bi(self, hiddens: List[torch.Tensor], angular: bool):
|
132 |
+
n = 1
|
133 |
+
if angular:
|
134 |
+
assert self.n_prune_layers is not None, "Set number of layers to prune to use angular importance"
|
135 |
+
n = self.n_prune_layers
|
136 |
+
|
137 |
+
for i in range(len(hiddens) - n):
|
138 |
+
in_hidden = hiddens[i]
|
139 |
+
out_hidden = hiddens[i+n]
|
140 |
+
if angular:
|
141 |
+
# use only last token for angular distance as described in section 3.2
|
142 |
+
# https://arxiv.org/pdf/2403.17887.pdf
|
143 |
+
in_hidden = in_hidden[:,-1:]
|
144 |
+
out_hidden = out_hidden[:,-1:]
|
145 |
+
|
146 |
+
self.importances[i] += block_influence(
|
147 |
+
in_hidden,
|
148 |
+
out_hidden,
|
149 |
+
angular=angular
|
150 |
+
).sum().cpu().item()
|
151 |
+
|
152 |
+
@torch.inference_mode()
|
153 |
+
def eval_importance(
|
154 |
+
self,
|
155 |
+
prompt_tokens: List[List[int]],
|
156 |
+
max_gen_len: Optional[int] = 0,
|
157 |
+
temperature: Optional[float] = 0.6,
|
158 |
+
top_p: Optional[float] = 0.9,
|
159 |
+
angular: Optional[bool] = False
|
160 |
+
):
|
161 |
+
"""
|
162 |
+
Computes layer-wise importances over input tokens.
|
163 |
+
|
164 |
+
NOTE: ShortGPT paper performs no generation during importance computation, which suggests a `max_gen_len`= 0.
|
165 |
+
|
166 |
+
Args:
|
167 |
+
prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers.
|
168 |
+
(Optional) max_gen_len (int): Maximum length of the generated text sequence.
|
169 |
+
(Optional) temperature (float): Temperature value for controlling randomness in sampling. Defaults to 0.6.
|
170 |
+
(Optional) top_p (float): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
|
171 |
+
(Optional) angular (bool): Whether to ues angular distance. Defaults to False.
|
172 |
+
|
173 |
+
Returns:
|
174 |
+
None
|
175 |
+
"""
|
176 |
+
params = self.llama.model.params
|
177 |
+
bsz = len(prompt_tokens)
|
178 |
+
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
|
179 |
+
|
180 |
+
min_prompt_len = min(len(t) for t in prompt_tokens)
|
181 |
+
max_prompt_len = max(len(t) for t in prompt_tokens)
|
182 |
+
assert max_prompt_len <= params.max_seq_len
|
183 |
+
total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
|
184 |
+
|
185 |
+
pad_id = self.llama.tokenizer.pad_id
|
186 |
+
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
|
187 |
+
for k, t in enumerate(prompt_tokens):
|
188 |
+
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
|
189 |
+
|
190 |
+
prev_pos = 0
|
191 |
+
eos_reached = torch.tensor([False] * bsz, device="cuda")
|
192 |
+
input_text_mask = tokens != pad_id
|
193 |
+
|
194 |
+
for cur_pos in range(min_prompt_len, total_len):
|
195 |
+
logits = self.llama.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
|
196 |
+
if temperature > 0:
|
197 |
+
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
|
198 |
+
next_token = sample_top_p(probs, top_p)
|
199 |
+
else:
|
200 |
+
next_token = torch.argmax(logits[:, -1], dim=-1)
|
201 |
+
|
202 |
+
next_token = next_token.reshape(-1)
|
203 |
+
# only replace token if prompt has already been generated
|
204 |
+
next_token = torch.where(
|
205 |
+
input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
|
206 |
+
)
|
207 |
+
tokens[:, cur_pos] = next_token
|
208 |
+
eos_reached |= (~input_text_mask[:, cur_pos]) & (
|
209 |
+
next_token == self.llama.tokenizer.eos_id
|
210 |
+
)
|
211 |
+
prev_pos = cur_pos
|
212 |
+
if all(eos_reached):
|
213 |
+
break
|
214 |
+
|
215 |
+
# compute block influence over full sequences rather than at each token
|
216 |
+
_, hiddens = self.llama.model.forward(tokens, 0, return_hiddens=True)
|
217 |
+
self.compute_bi(hiddens, angular=angular)
|
218 |
+
|
219 |
+
return
|