Plum551 committed
Commit b4769ba · verified · 1 parent: f038bbb

Upload train_depression_model.py.py

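The script fine-tunes roberta-base with a dropout + linear head for binary depression detection. It expects a CSV with clean_text and is_depression columns, makes a stratified 70/15/15 train/validation/test split, trains with AdamW (defaults: lr 2e-5, batch size 16, 3 epochs) under gradient clipping, checkpoints the best validation-accuracy weights to best_model.pt, and reports accuracy, precision, recall, and F1 on the held-out test set. Run it as: python train_depression_model.py.py --data_path depression_dataset_reddit_cleaned_final.csv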
Files changed (1):
  1. train_depression_model.py.py +263 -0
train_depression_model.py.py ADDED
@@ -0,0 +1,263 @@
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaTokenizer, RobertaModel
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from tqdm import tqdm
import argparse

# 1. Dataset Class
class DepressionDataset(Dataset):
    def __init__(self, df, tokenizer, max_length=256):
        self.texts = df['clean_text'].values
        self.labels = df['is_depression'].values
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

# 2. Model Class
class DepressionClassifier(nn.Module):
    def __init__(self, dropout_rate=0.1):
        super(DepressionClassifier, self).__init__()
        self.roberta = RobertaModel.from_pretrained('roberta-base')
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(768, 2)  # roberta-base hidden size -> 2 classes

    def forward(self, input_ids, attention_mask):
        outputs = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        # Pool the first (<s>) token as the sequence representation
        sequence_output = outputs.last_hidden_state[:, 0, :]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        return logits

# 3. Prepare data loaders
def prepare_dataloaders(df, batch_size=16):
    # Split data: 70% train, 15% validation, 15% test, stratified by label
    train_df, temp_df = train_test_split(df, test_size=0.3, stratify=df['is_depression'], random_state=42)
    val_df, test_df = train_test_split(temp_df, test_size=0.5, stratify=temp_df['is_depression'], random_state=42)

    # Initialize tokenizer
    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

    # Create datasets
    train_dataset = DepressionDataset(train_df, tokenizer)
    val_dataset = DepressionDataset(val_df, tokenizer)
    test_dataset = DepressionDataset(test_df, tokenizer)

    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    return train_loader, val_loader, test_loader

# 4. Training function
def train_model(model, train_loader, val_loader, device, epochs=3, learning_rate=2e-5):
    # Move model to device
    model = model.to(device)

    # Initialize optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    # Initialize loss function
    loss_fn = nn.CrossEntropyLoss()

    # Training loop
    best_accuracy = 0

    for epoch in range(epochs):
        print(f'Epoch {epoch + 1}/{epochs}')

        # TRAINING
        model.train()
        train_loss = 0
        train_preds = []
        train_labels = []

        # Progress bar for training
        progress_bar = tqdm(train_loader, desc="Training")

        for batch in progress_bar:
            # Get batch data
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            # Forward pass
            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask)
            loss = loss_fn(outputs, labels)

            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Track metrics
            train_loss += loss.item()
            _, preds = torch.max(outputs, dim=1)
            train_preds.extend(preds.cpu().tolist())
            train_labels.extend(labels.cpu().tolist())

            # Update progress bar
            progress_bar.set_postfix({'loss': loss.item()})

        # Calculate training metrics
        avg_train_loss = train_loss / len(train_loader)
        train_accuracy = accuracy_score(train_labels, train_preds)

        # VALIDATION
        model.eval()
        val_loss = 0
        val_preds = []
        val_labels = []

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validation"):
                # Get batch data
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['label'].to(device)

                # Forward pass
                outputs = model(input_ids, attention_mask)
                loss = loss_fn(outputs, labels)

                # Track metrics
                val_loss += loss.item()
                _, preds = torch.max(outputs, dim=1)
                val_preds.extend(preds.cpu().tolist())
                val_labels.extend(labels.cpu().tolist())

        # Calculate validation metrics
        avg_val_loss = val_loss / len(val_loader)
        val_accuracy = accuracy_score(val_labels, val_preds)

        # Print metrics
        print(f'Train Loss: {avg_train_loss:.4f} | Train Accuracy: {train_accuracy:.4f}')
        print(f'Val Loss: {avg_val_loss:.4f} | Val Accuracy: {val_accuracy:.4f}')

        # Save best model
        if val_accuracy > best_accuracy:
            torch.save(model.state_dict(), 'best_model.pt')
            best_accuracy = val_accuracy
            print(f'New best model saved with accuracy: {val_accuracy:.4f}')

        print('-' * 50)

    # Load best model
    model.load_state_dict(torch.load('best_model.pt'))
    return model

# 5. Evaluation function
def evaluate_model(model, test_loader, device):
    model.eval()
    test_preds = []
    test_labels = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            outputs = model(input_ids, attention_mask)
            _, preds = torch.max(outputs, dim=1)

            test_preds.extend(preds.cpu().tolist())
            test_labels.extend(labels.cpu().tolist())

    # Calculate metrics
    accuracy = accuracy_score(test_labels, test_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(
        test_labels, test_preds, average='binary'
    )

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }

# 6. Main function
def main():
    parser = argparse.ArgumentParser(description='Train depression classifier')
    parser.add_argument('--data_path', type=str, default='depression_dataset_reddit_cleaned_final.csv',
                        help='Path to the cleaned dataset')
    parser.add_argument('--batch_size', type=int, default=16, help='Batch size for training')
    parser.add_argument('--epochs', type=int, default=3, help='Number of training epochs')
    parser.add_argument('--learning_rate', type=float, default=2e-5, help='Learning rate')
    args = parser.parse_args()

    # Check for GPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Load data
    df = pd.read_csv(args.data_path)
    print(f'Loaded dataset with {len(df)} examples')

    # Prepare data
    train_loader, val_loader, test_loader = prepare_dataloaders(
        df, batch_size=args.batch_size
    )
    print(f'Training samples: {len(train_loader.dataset)}')
    print(f'Validation samples: {len(val_loader.dataset)}')
    print(f'Testing samples: {len(test_loader.dataset)}')

    # Create model
    model = DepressionClassifier()
    print('Model created')

    # Train model
    print('Starting training...')
    trained_model = train_model(
        model,
        train_loader,
        val_loader,
        device,
        epochs=args.epochs,
        learning_rate=args.learning_rate
    )

    # Evaluate model
    print('Evaluating model...')
    metrics = evaluate_model(trained_model, test_loader, device)

    # Print results
    print('\nTest Results:')
    for metric, value in metrics.items():
        print(f'{metric}: {value:.4f}')

if __name__ == '__main__':
    main()
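For reference, a minimal inference sketch, not part of this commit: it assumes the uploaded file is renamed to train_depression_model.py (the doubled .py.py extension prevents a plain import) and that best_model.pt exists from a completed training run. The hypothetical snippet mirrors the tokenizer settings used in DepressionDataset.

import torch
from transformers import RobertaTokenizer

# Assumes the script was renamed so the class is importable (hypothetical setup)
from train_depression_model import DepressionClassifier

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Rebuild the architecture, then restore the best checkpoint
model = DepressionClassifier()
model.load_state_dict(torch.load('best_model.pt', map_location=device))
model.to(device)
model.eval()

# Tokenize with the same settings DepressionDataset uses during training
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
encoding = tokenizer(
    "sample text to classify",
    max_length=256,
    padding='max_length',
    truncation=True,
    return_tensors='pt'
)

with torch.no_grad():
    logits = model(
        encoding['input_ids'].to(device),
        encoding['attention_mask'].to(device)
    )
pred = logits.argmax(dim=1).item()  # 1 = depression, 0 = not, per is_depression
print(f'Predicted label: {pred}')

Keeping max_length, padding, and truncation identical to training avoids a train/serve mismatch in how inputs are encoded.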