Nonnormalizable committed on
Commit 3ec6adb · 1 parent: a5a3465

Fine-tune bert-base to classify text.

Files changed (2)
  1. Finetune BERT.ipynb +512 -0
  2. tasks/text.py +1 -1
Finetune BERT.ipynb ADDED
@@ -0,0 +1,512 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": 1,
+    "id": "73e72549-69f2-46b5-b0f5-655777139972",
+    "metadata": {
+     "execution": {
+      "iopub.execute_input": "2025-01-17T04:45:37.715126Z",
+      "iopub.status.busy": "2025-01-17T04:45:37.714808Z",
+      "iopub.status.idle": "2025-01-17T04:45:41.232154Z",
+      "shell.execute_reply": "2025-01-17T04:45:41.231851Z",
+      "shell.execute_reply.started": "2025-01-17T04:45:37.715090Z"
+     }
+    },
+    "outputs": [],
+    "source": [
+     "from datetime import datetime\n",
+     "import numpy as np\n",
+     "import torch\n",
+     "from torch import nn\n",
+     "from transformers import BertTokenizer, BertModel\n",
+     "from torch.utils.data import Dataset, DataLoader\n",
+     "from datasets import load_dataset"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 2,
+    "id": "d4b79fb9-5e70-4600-8885-94bc0a6e917c",
+    "metadata": {
+     "execution": {
+      "iopub.execute_input": "2025-01-17T04:45:41.232694Z",
+      "iopub.status.busy": "2025-01-17T04:45:41.232554Z",
+      "iopub.status.idle": "2025-01-17T04:45:41.236434Z",
+      "shell.execute_reply": "2025-01-17T04:45:41.236218Z",
+      "shell.execute_reply.started": "2025-01-17T04:45:41.232685Z"
+     }
+    },
+    "outputs": [],
+    "source": [
+     "def my_print(x):\n",
+     "    time_str = datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\")\n",
+     "    print(time_str, x)\n",
+     "\n",
+     "class BertClassifier(nn.Module):\n",
+     "    def __init__(self, num_classes: int = 8, bert_variety='bert-base-uncased'):\n",
+     "        super().__init__()\n",
+     "        self.bert = BertModel.from_pretrained(bert_variety)\n",
+     "        self.dropout = nn.Dropout(0.05)\n",
+     "        self.classifier = nn.Linear(self.bert.pooler.dense.out_features, num_classes)\n",
+     "\n",
+     "    def forward(self, input_ids, attention_mask):\n",
+     "        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n",
+     "        pooled_output = outputs.pooler_output\n",
+     "        pooled_output = self.dropout(pooled_output)\n",
+     "        logits = self.classifier(pooled_output)\n",
+     "        return logits\n",
+     "\n",
+     "class TextDataset(Dataset):\n",
+     "    def __init__(self, texts, labels, tokenizer, max_length=200):\n",
+     "        self.encodings = tokenizer(\n",
+     "            texts,\n",
+     "            truncation=True,\n",
+     "            padding=True,\n",
+     "            max_length=max_length,\n",
+     "            return_tensors='pt',\n",
+     "        )\n",
+     "        self.labels = torch.tensor([int(l[0]) for l in labels])\n",
+     "\n",
+     "    def __getitem__(self, idx):\n",
+     "        item = {key: val[idx] for key, val in self.encodings.items()}\n",
+     "        item['labels'] = self.labels[idx]\n",
+     "        return item\n",
+     "\n",
+     "    def __len__(self) -> int:\n",
+     "        return len(self.labels)\n",
+     "\n",
+     "def train_model(model, train_dataloader, device, num_epochs):\n",
+     "    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)\n",
+     "    criterion = nn.CrossEntropyLoss()\n",
+     "    model.train()\n",
+     "\n",
+     "    my_print('Starting epoch 1.')\n",
+     "    for epoch in range(num_epochs):\n",
+     "        total_loss = 0\n",
+     "        for batch in train_dataloader:\n",
+     "            optimizer.zero_grad()\n",
+     "\n",
+     "            input_ids = batch['input_ids'].to(device)\n",
+     "            attention_mask = batch['attention_mask'].to(device)\n",
+     "            labels = batch['labels'].to(device)\n",
+     "\n",
+     "            outputs = model(input_ids, attention_mask)\n",
+     "            loss = criterion(outputs, labels)\n",
+     "\n",
+     "            loss.backward()\n",
+     "            optimizer.step()\n",
+     "\n",
+     "            total_loss += loss.item()\n",
+     "        avg_loss = total_loss / len(train_dataloader)\n",
+     "        my_print(f'Epoch {epoch+1}/{num_epochs} done, Average Loss: {avg_loss:0.4f}')"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 3,
+    "id": "07131bce-23ad-4787-8622-cce401f3e5ce",
+    "metadata": {
+     "execution": {
+      "iopub.execute_input": "2025-01-17T04:45:41.237451Z",
+      "iopub.status.busy": "2025-01-17T04:45:41.237358Z",
+      "iopub.status.idle": "2025-01-17T04:45:41.252075Z",
+      "shell.execute_reply": "2025-01-17T04:45:41.251851Z",
+      "shell.execute_reply.started": "2025-01-17T04:45:41.237443Z"
+     }
+    },
+    "outputs": [],
+    "source": [
+     "if torch.backends.mps.is_available():\n",
+     "    device = torch.device('mps')\n",
+     "    torch.mps.empty_cache()\n",
+     "elif torch.cuda.is_available():\n",
+     "    device = torch.device('cuda')\n",
+     "else:\n",
+     "    device = torch.device('cpu')"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 4,
+    "id": "695bc080-bbd7-4937-af5b-50db1c936500",
+    "metadata": {
+     "execution": {
+      "iopub.execute_input": "2025-01-17T04:45:41.252581Z",
+      "iopub.status.busy": "2025-01-17T04:45:41.252476Z",
+      "iopub.status.idle": "2025-01-17T04:45:41.255279Z",
+      "shell.execute_reply": "2025-01-17T04:45:41.255045Z",
+      "shell.execute_reply.started": "2025-01-17T04:45:41.252572Z"
+     }
+    },
+    "outputs": [],
+    "source": [
+     "def run_training(\n",
+     "    max_dataset_size=16 * 200,\n",
+     "    bert_variety='bert-base-uncased',\n",
+     "    max_length=200,\n",
+     "    num_epochs=3,\n",
+     "    batch_size=32,\n",
+     "):\n",
+     "    hf_dataset = load_dataset(\"quotaclimat/frugalaichallenge-text-train\")\n",
+     "    if not max_dataset_size == 'full' and max_dataset_size < len(hf_dataset['train']):\n",
+     "        train_dataset = hf_dataset['train'][:max_dataset_size]\n",
+     "    else:\n",
+     "        train_dataset = hf_dataset['train']\n",
+     "\n",
+     "    tokenizer = BertTokenizer.from_pretrained(bert_variety, max_length=max_length)\n",
+     "    model = BertClassifier(bert_variety=bert_variety)\n",
+     "    if torch.backends.mps.is_available():\n",
+     "        device = torch.device('mps')\n",
+     "        torch.mps.empty_cache()\n",
+     "    elif torch.cuda.is_available():\n",
+     "        device = torch.device('cuda')\n",
+     "    else:\n",
+     "        device = torch.device('cpu')\n",
+     "    model.to(device)\n",
+     "\n",
+     "    dataset = TextDataset(\n",
+     "        train_dataset['quote'],\n",
+     "        train_dataset['label'],\n",
+     "        tokenizer=tokenizer,\n",
+     "        max_length=max_length,\n",
+     "    )\n",
+     "    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
+     "\n",
+     "    train_model(model, dataloader, device, num_epochs=num_epochs)\n",
+     "    return model, tokenizer"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 5,
+    "id": "792fd13f-e7cc-4d90-832d-c0da15e193cd",
+    "metadata": {
+     "execution": {
+      "iopub.execute_input": "2025-01-17T04:45:41.255750Z",
+      "iopub.status.busy": "2025-01-17T04:45:41.255661Z",
+      "iopub.status.idle": "2025-01-17T04:47:17.151654Z",
+      "shell.execute_reply": "2025-01-17T04:47:17.149076Z",
+      "shell.execute_reply.started": "2025-01-17T04:45:41.255742Z"
+     }
+    },
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "2025-01-16 20:45:45 Starting epoch 1.\n",
+       "2025-01-16 20:46:15 Epoch 1/3 done, Average Loss: 1.9223\n",
+       "2025-01-16 20:46:46 Epoch 2/3 done, Average Loss: 1.6052\n",
+       "2025-01-16 20:47:17 Epoch 3/3 done, Average Loss: 1.2876\n"
+      ]
+     }
+    ],
+    "source": [
+     "model, tokenizer = run_training(\n",
+     "    max_dataset_size=16 * 50,\n",
+     "    bert_variety='bert-base-uncased',\n",
+     "    max_length=200,\n",
+     "    num_epochs=3,\n",
+     "    batch_size=32,\n",
+     ")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 6,
+    "id": "0aedfcca-843e-4f4c-8062-3e4625161bcc",
+    "metadata": {
+     "execution": {
+      "iopub.execute_input": "2025-01-17T04:47:17.158101Z",
+      "iopub.status.busy": "2025-01-17T04:47:17.157305Z",
+      "iopub.status.idle": "2025-01-17T04:47:17.333568Z",
+      "shell.execute_reply": "2025-01-17T04:47:17.333317Z",
+      "shell.execute_reply.started": "2025-01-17T04:47:17.157437Z"
+     }
+    },
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "2025-01-16 20:47:17 Predictions: tensor([6, 1, 1, 6, 1, 6, 6], device='mps:0')\n"
+      ]
+     }
+    ],
+    "source": [
+     "model.eval()\n",
+     "test_text = [\n",
+     "    'This was a great experience!', # 0_not_relevant\n",
+     "    'My favorite hike is Laguna de los Tres.', # 0_not_relevant\n",
+     "    'Crops will grow great in Finland if it\\'s warmer there.', # 3_not_bad\n",
+     "    'Climate change is fake.', # 1_not_happening\n",
+     "    'The apparent warming is caused by solar cycles.', # 2_not_human\n",
+     "    'Solar panels emit bad vibes.', # 4_solutions_harmful_unnecessary\n",
+     "    'All those so-called scientists are Democrats.', # 6_proponents_biased\n",
+     "]\n",
+     "test_encoding = tokenizer(\n",
+     "    test_text,\n",
+     "    truncation=True,\n",
+     "    padding=True,\n",
+     "    return_tensors='pt',\n",
+     ")\n",
+     "\n",
+     "with torch.no_grad():\n",
+     "    test_input_ids = test_encoding['input_ids'].to(device)\n",
+     "    test_attention_mask = test_encoding['attention_mask'].to(device)\n",
+     "    outputs = model(test_input_ids, test_attention_mask)\n",
+     "    predictions = torch.argmax(outputs, dim=1)\n",
+     "    my_print(f'Predictions: {predictions}')"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 7,
+    "id": "881b738e-2392-4b7e-a0de-a0bad572ddfa",
+    "metadata": {
+     "execution": {
+      "iopub.execute_input": "2025-01-17T04:47:17.334399Z",
+      "iopub.status.busy": "2025-01-17T04:47:17.334287Z",
+      "iopub.status.idle": "2025-01-17T04:50:59.116389Z",
+      "shell.execute_reply": "2025-01-17T04:50:59.115528Z",
+      "shell.execute_reply.started": "2025-01-17T04:47:17.334390Z"
+     }
+    },
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "2025-01-16 20:47:23 Starting epoch 1.\n",
+       "2025-01-16 20:48:35 Epoch 1/3 done, Average Loss: 1.4272\n",
+       "2025-01-16 20:49:46 Epoch 2/3 done, Average Loss: 0.8694\n",
+       "2025-01-16 20:50:59 Epoch 3/3 done, Average Loss: 0.5774\n"
+      ]
+     }
+    ],
+    "source": [
+     "model, tokenizer = run_training(\n",
+     "    max_dataset_size='full',\n",
+     "    bert_variety='bert-base-uncased',\n",
+     "    max_length=64,\n",
+     "    num_epochs=3,\n",
+     "    batch_size=32,\n",
+     ")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 8,
+    "id": "1d29336e-7f88-4127-afdf-2fe043e310e1",
+    "metadata": {
+     "execution": {
+      "iopub.execute_input": "2025-01-17T04:50:59.118025Z",
+      "iopub.status.busy": "2025-01-17T04:50:59.117838Z",
+      "iopub.status.idle": "2025-01-17T04:58:02.423121Z",
+      "shell.execute_reply": "2025-01-17T04:58:02.421532Z",
+      "shell.execute_reply.started": "2025-01-17T04:50:59.118005Z"
+     }
+    },
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "2025-01-16 20:51:04 Starting epoch 1.\n",
+       "2025-01-16 20:53:20 Epoch 1/3 done, Average Loss: 1.4107\n",
+       "2025-01-16 20:55:41 Epoch 2/3 done, Average Loss: 0.8491\n",
+       "2025-01-16 20:58:02 Epoch 3/3 done, Average Loss: 0.5359\n"
+      ]
+     }
+    ],
+    "source": [
+     "model, tokenizer = run_training(\n",
+     "    max_dataset_size='full',\n",
+     "    bert_variety='bert-base-uncased',\n",
+     "    max_length=128,\n",
+     "    num_epochs=3,\n",
+     "    batch_size=32,\n",
+     ")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 9,
+    "id": "461b8f57-0c52-403a-bb69-3bc192b323bf",
+    "metadata": {
+     "execution": {
+      "iopub.execute_input": "2025-01-17T04:58:02.426159Z",
+      "iopub.status.busy": "2025-01-17T04:58:02.425896Z",
+      "iopub.status.idle": "2025-01-17T05:05:36.903446Z",
+      "shell.execute_reply": "2025-01-17T05:05:36.901961Z",
+      "shell.execute_reply.started": "2025-01-17T04:58:02.426132Z"
+     }
+    },
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "2025-01-16 20:58:08 Starting epoch 1.\n",
+       "2025-01-16 21:00:38 Epoch 1/3 done, Average Loss: 1.2946\n",
+       "2025-01-16 21:03:07 Epoch 2/3 done, Average Loss: 0.7425\n",
+       "2025-01-16 21:05:36 Epoch 3/3 done, Average Loss: 0.4126\n"
+      ]
+     }
+    ],
+    "source": [
+     "model, tokenizer = run_training(\n",
+     "    max_dataset_size='full',\n",
+     "    bert_variety='bert-base-uncased',\n",
+     "    max_length=128,\n",
+     "    num_epochs=3,\n",
+     "    batch_size=16,\n",
+     ")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 10,
+    "id": "28354e8c-886a-4523-8968-8c688c13f6a3",
+    "metadata": {
+     "execution": {
+      "iopub.execute_input": "2025-01-17T05:05:36.905668Z",
+      "iopub.status.busy": "2025-01-17T05:05:36.905353Z",
+      "iopub.status.idle": "2025-01-17T05:21:10.045463Z",
+      "shell.execute_reply": "2025-01-17T05:21:10.044788Z",
+      "shell.execute_reply.started": "2025-01-17T05:05:36.905630Z"
+     }
+    },
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "2025-01-16 21:05:43 Starting epoch 1.\n",
+       "2025-01-16 21:10:53 Epoch 1/3 done, Average Loss: 1.3415\n",
+       "2025-01-16 21:16:02 Epoch 2/3 done, Average Loss: 0.7216\n",
+       "2025-01-16 21:21:10 Epoch 3/3 done, Average Loss: 0.3978\n"
+      ]
+     }
+    ],
+    "source": [
+     "model, tokenizer = run_training(\n",
+     "    max_dataset_size='full',\n",
+     "    bert_variety='bert-base-uncased',\n",
+     "    max_length=256,\n",
+     "    num_epochs=3,\n",
+     "    batch_size=16,\n",
+     ")"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 11,
+    "id": "e3b099c6-6b98-473b-8797-5032213b9fcb",
+    "metadata": {
+     "execution": {
+      "iopub.execute_input": "2025-01-17T05:21:10.059844Z",
+      "iopub.status.busy": "2025-01-17T05:21:10.058980Z",
+      "iopub.status.idle": "2025-01-17T05:21:10.164116Z",
+      "shell.execute_reply": "2025-01-17T05:21:10.163826Z",
+      "shell.execute_reply.started": "2025-01-17T05:21:10.059552Z"
+     }
+    },
+    "outputs": [
+     {
+      "name": "stdout",
+      "output_type": "stream",
+      "text": [
+       "2025-01-16 21:21:10 Predictions: tensor([0, 0, 3, 6, 2, 4, 6], device='mps:0')\n"
+      ]
+     }
+    ],
+    "source": [
+     "model.eval()\n",
+     "test_text = [\n",
+     "    'This was a great experience!', # 0_not_relevant\n",
+     "    'My favorite hike is Laguna de los Tres.', # 0_not_relevant\n",
+     "    'Crops will grow great in Finland if it\\'s warmer there.', # 3_not_bad\n",
+     "    'Climate change is fake.', # 1_not_happening\n",
+     "    'The apparent warming is caused by solar cycles.', # 2_not_human\n",
+     "    'Solar panels emit bad vibes.', # 4_solutions_harmful_unnecessary\n",
+     "    'All those so-called scientists are Democrats.', # 6_proponents_biased\n",
+     "]\n",
+     "test_encoding = tokenizer(\n",
+     "    test_text,\n",
+     "    truncation=True,\n",
+     "    padding=True,\n",
+     "    return_tensors='pt',\n",
+     ")\n",
+     "\n",
+     "with torch.no_grad():\n",
+     "    test_input_ids = test_encoding['input_ids'].to(device)\n",
+     "    test_attention_mask = test_encoding['attention_mask'].to(device)\n",
+     "    outputs = model(test_input_ids, test_attention_mask)\n",
+     "    predictions = torch.argmax(outputs, dim=1)\n",
+     "    my_print(f'Predictions: {predictions}')"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 12,
+    "id": "befb94b5-88bf-40fc-8b26-cf373d1256e0",
+    "metadata": {
+     "execution": {
+      "iopub.execute_input": "2025-01-17T05:27:58.042752Z",
+      "iopub.status.busy": "2025-01-17T05:27:58.042151Z",
+      "iopub.status.idle": "2025-01-17T05:27:58.454054Z",
+      "shell.execute_reply": "2025-01-17T05:27:58.453644Z",
+      "shell.execute_reply.started": "2025-01-17T05:27:58.042662Z"
+     }
+    },
+    "outputs": [
+     {
+      "ename": "AttributeError",
+      "evalue": "'BertClassifier' object has no attribute 'push_to_hub'",
+      "output_type": "error",
+      "traceback": [
+       "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+       "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
+       "Cell \u001b[0;32mIn[12], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpush_to_hub\u001b[49m()\n",
+       "File \u001b[0;32m~/miniconda3/envs/py313/lib/python3.13/site-packages/torch/nn/modules/module.py:1931\u001b[0m, in \u001b[0;36mModule.__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1929\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01min\u001b[39;00m modules:\n\u001b[1;32m 1930\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m modules[name]\n\u001b[0;32m-> 1931\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mAttributeError\u001b[39;00m(\n\u001b[1;32m 1932\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mself\u001b[39m)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m object has no attribute \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1933\u001b[0m )\n",
+       "\u001b[0;31mAttributeError\u001b[0m: 'BertClassifier' object has no attribute 'push_to_hub'"
+      ]
+     }
+    ],
+    "source": [
+     "model.push_to_hub()"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "251ef9ee-8ba3-495f-8fe6-a93aa63168ce",
+    "metadata": {},
+    "outputs": [],
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "Python 3 (ipykernel)",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.13.1"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
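
Note: the final cell fails because BertClassifier subclasses plain nn.Module, which provides no push_to_hub method. A minimal sketch of one possible fix, assuming the huggingface_hub package is available (the repo id below is hypothetical, not part of this commit), is to mix in PyTorchModelHubMixin, which supplies save_pretrained, from_pretrained, and push_to_hub:

# Sketch only, not part of this commit. Assumes huggingface_hub is installed.
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from transformers import BertModel

class BertClassifier(nn.Module, PyTorchModelHubMixin):
    def __init__(self, num_classes: int = 8, bert_variety='bert-base-uncased'):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_variety)
        self.dropout = nn.Dropout(0.05)
        self.classifier = nn.Linear(self.bert.pooler.dense.out_features, num_classes)

    def forward(self, input_ids, attention_mask):
        # Same behavior as the notebook's classifier head.
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.classifier(self.dropout(outputs.pooler_output))

# model = BertClassifier()
# model.push_to_hub('your-username/bert-text-classifier')  # hypothetical repo id
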
tasks/text.py CHANGED
@@ -12,7 +12,7 @@ router = APIRouter()
 DESCRIPTION = "Most common class baseline"
 ROUTE = "/text"
 
-def baseline_model(dataset_length):
+def baseline_model(dataset_length: int):
     # Make random predictions (placeholder for actual model inference)
     #predictions = [random.randint(0, 7) for _ in range(dataset_length)]