Jingyuan-Zhu commited on
Commit
03f782f
·
verified ·
1 Parent(s): 82af86f

added the final_model weight and evaluation pipeline

Browse files
Files changed (2) hide show
  1. NN_evaluation.ipynb +269 -0
  2. final_model.pth +3 -0
NN_evaluation.ipynb ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 4,
6
+ "id": "459d08f0-b6d6-4062-8e60-d3349ae86cce",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stderr",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
14
+ "To disable this warning, you can either:\n",
15
+ "\t- Avoid using `tokenizers` before the fork if possible\n",
16
+ "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
17
+ ]
18
+ },
19
+ {
20
+ "name": "stdout",
21
+ "output_type": "stream",
22
+ "text": [
23
+ "Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n",
24
+ "Requirement already satisfied: FlagEmbedding in /opt/conda/lib/python3.11/site-packages (1.3.3)\n",
25
+ "Requirement already satisfied: torch>=1.6.0 in /opt/conda/lib/python3.11/site-packages (from FlagEmbedding) (2.2.2+cu121)\n",
26
+ "Requirement already satisfied: transformers==4.44.2 in /opt/conda/lib/python3.11/site-packages (from FlagEmbedding) (4.44.2)\n",
27
+ "Requirement already satisfied: datasets==2.19.0 in /opt/conda/lib/python3.11/site-packages (from FlagEmbedding) (2.19.0)\n",
28
+ "Requirement already satisfied: accelerate>=0.20.1 in /opt/conda/lib/python3.11/site-packages (from FlagEmbedding) (1.2.0)\n",
29
+ "Requirement already satisfied: sentence-transformers in /opt/conda/lib/python3.11/site-packages (from FlagEmbedding) (3.3.1)\n",
30
+ "Requirement already satisfied: peft in /opt/conda/lib/python3.11/site-packages (from FlagEmbedding) (0.14.0)\n",
31
+ "Requirement already satisfied: ir-datasets in /opt/conda/lib/python3.11/site-packages (from FlagEmbedding) (0.5.9)\n",
32
+ "Requirement already satisfied: sentencepiece in /opt/conda/lib/python3.11/site-packages (from FlagEmbedding) (0.2.0)\n",
33
+ "Requirement already satisfied: protobuf in /opt/conda/lib/python3.11/site-packages (from FlagEmbedding) (4.25.3)\n",
34
+ "Requirement already satisfied: filelock in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (3.9.0)\n",
35
+ "Requirement already satisfied: numpy>=1.17 in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (1.26.4)\n",
36
+ "Requirement already satisfied: pyarrow>=12.0.0 in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (15.0.2)\n",
37
+ "Requirement already satisfied: pyarrow-hotfix in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (0.6)\n",
38
+ "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (0.3.8)\n",
39
+ "Requirement already satisfied: pandas in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (2.2.2)\n",
40
+ "Requirement already satisfied: requests>=2.19.0 in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (2.31.0)\n",
41
+ "Requirement already satisfied: tqdm>=4.62.1 in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (4.66.2)\n",
42
+ "Requirement already satisfied: xxhash in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (3.5.0)\n",
43
+ "Requirement already satisfied: multiprocess in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (0.70.16)\n",
44
+ "Requirement already satisfied: fsspec<=2024.3.1,>=2023.1.0 in /opt/conda/lib/python3.11/site-packages (from fsspec[http]<=2024.3.1,>=2023.1.0->datasets==2.19.0->FlagEmbedding) (2024.3.1)\n",
45
+ "Requirement already satisfied: aiohttp in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (3.11.10)\n",
46
+ "Requirement already satisfied: huggingface-hub>=0.21.2 in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (0.26.5)\n",
47
+ "Requirement already satisfied: packaging in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (24.0)\n",
48
+ "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.11/site-packages (from datasets==2.19.0->FlagEmbedding) (6.0.1)\n",
49
+ "Requirement already satisfied: regex!=2019.12.17 in /opt/conda/lib/python3.11/site-packages (from transformers==4.44.2->FlagEmbedding) (2024.11.6)\n",
50
+ "Requirement already satisfied: safetensors>=0.4.1 in /opt/conda/lib/python3.11/site-packages (from transformers==4.44.2->FlagEmbedding) (0.4.5)\n",
51
+ "Requirement already satisfied: tokenizers<0.20,>=0.19 in /opt/conda/lib/python3.11/site-packages (from transformers==4.44.2->FlagEmbedding) (0.19.1)\n",
52
+ "Requirement already satisfied: psutil in /opt/conda/lib/python3.11/site-packages (from accelerate>=0.20.1->FlagEmbedding) (5.9.8)\n",
53
+ "Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (4.11.0)\n",
54
+ "Requirement already satisfied: sympy in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (1.12)\n",
55
+ "Requirement already satisfied: networkx in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (3.3)\n",
56
+ "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (3.1.3)\n",
57
+ "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (12.1.105)\n",
58
+ "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (12.1.105)\n",
59
+ "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (12.1.105)\n",
60
+ "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (8.9.2.26)\n",
61
+ "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (12.1.3.1)\n",
62
+ "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (11.0.2.54)\n",
63
+ "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (10.3.2.106)\n",
64
+ "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (11.4.5.107)\n",
65
+ "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (12.1.0.106)\n",
66
+ "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (2.19.3)\n",
67
+ "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (12.1.105)\n",
68
+ "Requirement already satisfied: triton==2.2.0 in /opt/conda/lib/python3.11/site-packages (from torch>=1.6.0->FlagEmbedding) (2.2.0)\n",
69
+ "Requirement already satisfied: nvidia-nvjitlink-cu12 in /opt/conda/lib/python3.11/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.6.0->FlagEmbedding) (12.4.127)\n",
70
+ "Requirement already satisfied: beautifulsoup4>=4.4.1 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (4.12.3)\n",
71
+ "Requirement already satisfied: inscriptis>=2.2.0 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (2.5.0)\n",
72
+ "Requirement already satisfied: lxml>=4.5.2 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (5.3.0)\n",
73
+ "Requirement already satisfied: trec-car-tools>=2.5.4 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (2.6)\n",
74
+ "Requirement already satisfied: lz4>=3.1.10 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (4.3.3)\n",
75
+ "Requirement already satisfied: warc3-wet>=0.2.3 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (0.2.5)\n",
76
+ "Requirement already satisfied: warc3-wet-clueweb09>=0.2.5 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (0.2.5)\n",
77
+ "Requirement already satisfied: zlib-state>=0.1.3 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (0.1.9)\n",
78
+ "Requirement already satisfied: ijson>=3.1.3 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (3.3.0)\n",
79
+ "Requirement already satisfied: unlzw3>=0.2.1 in /opt/conda/lib/python3.11/site-packages (from ir-datasets->FlagEmbedding) (0.2.2)\n",
80
+ "Requirement already satisfied: scikit-learn in /opt/conda/lib/python3.11/site-packages (from sentence-transformers->FlagEmbedding) (1.4.2)\n",
81
+ "Requirement already satisfied: scipy in /opt/conda/lib/python3.11/site-packages (from sentence-transformers->FlagEmbedding) (1.13.0)\n",
82
+ "Requirement already satisfied: Pillow in /opt/conda/lib/python3.11/site-packages (from sentence-transformers->FlagEmbedding) (10.3.0)\n",
83
+ "Requirement already satisfied: soupsieve>1.2 in /opt/conda/lib/python3.11/site-packages (from beautifulsoup4>=4.4.1->ir-datasets->FlagEmbedding) (2.5)\n",
84
+ "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /opt/conda/lib/python3.11/site-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (2.4.4)\n",
85
+ "Requirement already satisfied: aiosignal>=1.1.2 in /opt/conda/lib/python3.11/site-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (1.3.2)\n",
86
+ "Requirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.11/site-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (23.2.0)\n",
87
+ "Requirement already satisfied: frozenlist>=1.1.1 in /opt/conda/lib/python3.11/site-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (1.5.0)\n",
88
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/lib/python3.11/site-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (6.1.0)\n",
89
+ "Requirement already satisfied: propcache>=0.2.0 in /opt/conda/lib/python3.11/site-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (0.2.1)\n",
90
+ "Requirement already satisfied: yarl<2.0,>=1.17.0 in /opt/conda/lib/python3.11/site-packages (from aiohttp->datasets==2.19.0->FlagEmbedding) (1.18.3)\n",
91
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.11/site-packages (from requests>=2.19.0->datasets==2.19.0->FlagEmbedding) (3.3.2)\n",
92
+ "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.11/site-packages (from requests>=2.19.0->datasets==2.19.0->FlagEmbedding) (3.7)\n",
93
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /opt/conda/lib/python3.11/site-packages (from requests>=2.19.0->datasets==2.19.0->FlagEmbedding) (2.2.1)\n",
94
+ "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.11/site-packages (from requests>=2.19.0->datasets==2.19.0->FlagEmbedding) (2024.2.2)\n",
95
+ "Requirement already satisfied: cbor>=1.0.0 in /opt/conda/lib/python3.11/site-packages (from trec-car-tools>=2.5.4->ir-datasets->FlagEmbedding) (1.0.0)\n",
96
+ "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.11/site-packages (from jinja2->torch>=1.6.0->FlagEmbedding) (2.1.5)\n",
97
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /opt/conda/lib/python3.11/site-packages (from pandas->datasets==2.19.0->FlagEmbedding) (2.9.0)\n",
98
+ "Requirement already satisfied: pytz>=2020.1 in /opt/conda/lib/python3.11/site-packages (from pandas->datasets==2.19.0->FlagEmbedding) (2024.1)\n",
99
+ "Requirement already satisfied: tzdata>=2022.7 in /opt/conda/lib/python3.11/site-packages (from pandas->datasets==2.19.0->FlagEmbedding) (2024.1)\n",
100
+ "Requirement already satisfied: joblib>=1.2.0 in /opt/conda/lib/python3.11/site-packages (from scikit-learn->sentence-transformers->FlagEmbedding) (1.4.0)\n",
101
+ "Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/lib/python3.11/site-packages (from scikit-learn->sentence-transformers->FlagEmbedding) (3.4.0)\n",
102
+ "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.11/site-packages (from sympy->torch>=1.6.0->FlagEmbedding) (1.3.0)\n",
103
+ "Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas->datasets==2.19.0->FlagEmbedding) (1.16.0)\n",
104
+ "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n",
105
+ "\u001b[0m"
106
+ ]
107
+ },
108
+ {
109
+ "data": {
110
+ "application/vnd.jupyter.widget-view+json": {
111
+ "model_id": "3cb658d3527149008a6d17938370142d",
112
+ "version_major": 2,
113
+ "version_minor": 0
114
+ },
115
+ "text/plain": [
116
+ "Fetching 30 files: 0%| | 0/30 [00:00<?, ?it/s]"
117
+ ]
118
+ },
119
+ "metadata": {},
120
+ "output_type": "display_data"
121
+ },
122
+ {
123
+ "name": "stdout",
124
+ "output_type": "stream",
125
+ "text": [
126
+ "Encoding titles...\n"
127
+ ]
128
+ },
129
+ {
130
+ "name": "stderr",
131
+ "output_type": "stream",
132
+ "text": [
133
+ "You're using a XLMRobertaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
134
+ ]
135
+ },
136
+ {
137
+ "name": "stdout",
138
+ "output_type": "stream",
139
+ "text": [
140
+ "Processed 20/20 titles\n",
141
+ "Test Accuracy: 0.9000\n"
142
+ ]
143
+ }
144
+ ],
145
+ "source": [
146
+ "import torch\n",
147
+ "import pandas as pd\n",
148
+ "from sklearn.metrics import accuracy_score\n",
149
+ "from torch.utils.data import DataLoader, TensorDataset\n",
150
+ "!pip install -U FlagEmbedding\n",
151
+ "from FlagEmbedding import BGEM3FlagModel\n",
152
+ "import torch.nn as nn\n",
153
+ "model = BGEM3FlagModel('BAAI/bge-m3')\n",
154
+ "\n",
155
+ "\n",
156
+ "# Define paths\n",
157
+ "test_data_path = \"/home/jovyan/work/test_data_random_subset.csv\"\n",
158
+ "model_weights_path = \"/home/jovyan/work/model_weights/final_model.pth\"\n",
159
+ "\n",
160
+ "data = data = pd.read_csv(test_data_path)\n",
161
+ "titles = data['title'].tolist()\n",
162
+ "labels = data['labels'].tolist()\n",
163
+ "\n",
164
+ "batch_size = 32\n",
165
+ "embeddings = []\n",
166
+ "\n",
167
+ "print('Encoding titles...')\n",
168
+ "for i in range(0, len(titles), batch_size):\n",
169
+ " batch = titles[i:i + batch_size]\n",
170
+ " batch_embeddings = model.encode(batch, batch_size=batch_size, max_length=512)['dense_vecs']\n",
171
+ " embeddings.extend(batch_embeddings)\n",
172
+ " print(f\"Processed {i + len(batch)}/{len(titles)} titles\")\n",
173
+ "\n",
174
+ "embeddings_df = pd.DataFrame(embeddings)\n",
175
+ "embeddings_df['label'] = labels\n",
176
+ "\n",
177
+ "X_test = torch.FloatTensor(embeddings_df.iloc[:, :-1].values) # Features\n",
178
+ "y_test = torch.FloatTensor(embeddings_df.iloc[:, -1].values).view(-1, 1) # Labels\n",
179
+ "\n",
180
+ "# Create DataLoader for the test dataset\n",
181
+ "test_dataset = TensorDataset(X_test, y_test)\n",
182
+ "test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)\n",
183
+ "\n",
184
+ "# Define the model architecture\n",
185
+ "class EmbeddingMLPClassifier(nn.Module):\n",
186
+ " def __init__(self,\n",
187
+ " input_dim=1024,\n",
188
+ " hidden_layers=[512, 256, 128],\n",
189
+ " dropout_rate=0.2,\n",
190
+ " device=None):\n",
191
+ " super(EmbeddingMLPClassifier, self).__init__()\n",
192
+ " self.device = device or torch.device('cpu') # Force CPU usage\n",
193
+ "\n",
194
+ " layers = []\n",
195
+ " prev_dim = input_dim\n",
196
+ " for hidden_dim in hidden_layers:\n",
197
+ " layers.extend([\n",
198
+ " nn.Linear(prev_dim, hidden_dim),\n",
199
+ " nn.BatchNorm1d(hidden_dim),\n",
200
+ " nn.ReLU(),\n",
201
+ " nn.Dropout(dropout_rate)\n",
202
+ " ])\n",
203
+ " prev_dim = hidden_dim\n",
204
+ "\n",
205
+ " layers.append(nn.Linear(prev_dim, 1))\n",
206
+ " layers.append(nn.Sigmoid())\n",
207
+ "\n",
208
+ " self.model = nn.Sequential(*layers)\n",
209
+ " self.to(self.device)\n",
210
+ "\n",
211
+ " def forward(self, x):\n",
212
+ " return self.model(x)\n",
213
+ "\n",
214
+ "# Load the saved model\n",
215
+ "device = torch.device('cpu') # Force evaluation on CPU\n",
216
+ "model = EmbeddingMLPClassifier(input_dim=X_test.shape[1], device=device)\n",
217
+ "model.load_state_dict(torch.load(model_weights_path, map_location=device)) # Map weights to CPU\n",
218
+ "model.eval()\n",
219
+ "\n",
220
+ "# Evaluate the model on the test set\n",
221
+ "test_preds, test_labels = [], []\n",
222
+ "\n",
223
+ "with torch.no_grad():\n",
224
+ " for batch_X, batch_y in test_loader:\n",
225
+ " batch_X = batch_X.to(device) # Ensure data is on CPU\n",
226
+ " batch_y = batch_y.to(device)\n",
227
+ "\n",
228
+ " outputs = model(batch_X)\n",
229
+ " preds = (outputs > 0.5).float()\n",
230
+ "\n",
231
+ " test_preds.extend(preds.cpu().numpy())\n",
232
+ " test_labels.extend(batch_y.cpu().numpy())\n",
233
+ "\n",
234
+ "# Calculate accuracy\n",
235
+ "test_accuracy = accuracy_score(test_labels, test_preds)\n",
236
+ "print(f\"Test Accuracy: {test_accuracy:.4f}\")"
237
+ ]
238
+ },
239
+ {
240
+ "cell_type": "code",
241
+ "execution_count": null,
242
+ "id": "e2869985-46ac-4a07-a13d-c11e01ef9ec4",
243
+ "metadata": {},
244
+ "outputs": [],
245
+ "source": []
246
+ }
247
+ ],
248
+ "metadata": {
249
+ "kernelspec": {
250
+ "display_name": "Python 3 (ipykernel)",
251
+ "language": "python",
252
+ "name": "python3"
253
+ },
254
+ "language_info": {
255
+ "codemirror_mode": {
256
+ "name": "ipython",
257
+ "version": 3
258
+ },
259
+ "file_extension": ".py",
260
+ "mimetype": "text/x-python",
261
+ "name": "python",
262
+ "nbconvert_exporter": "python",
263
+ "pygments_lexer": "ipython3",
264
+ "version": "3.11.8"
265
+ }
266
+ },
267
+ "nbformat": 4,
268
+ "nbformat_minor": 5
269
+ }
final_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de1266abb692bcdb1e6ceb8462ffa0a29be5921d7bbeaf11ffe6f265e6750db8
3
+ size 2778685