Canstralian commited on
Commit
6da5491
·
verified ·
1 Parent(s): 704d821

Update fine_tuner.py

Browse files
Files changed (1) hide show
  1. fine_tuner.py +18 -9
fine_tuner.py CHANGED
@@ -3,20 +3,29 @@ from transformers import AutoModelForSequenceClassification, Trainer, TrainingAr
3
  from datasets import load_dataset
4
 
5
  def fine_tune_model(dataset, model_name, epochs, batch_size, learning_rate):
 
6
  model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
 
 
7
  training_args = TrainingArguments(
8
- output_dir='./results',
9
- num_train_epochs=epochs,
10
- per_device_train_batch_size=batch_size,
11
- learning_rate=learning_rate,
12
- logging_dir='./logs',
13
- logging_steps=10,
14
  )
 
 
15
  trainer = Trainer(
16
  model=model,
17
  args=training_args,
18
- train_dataset=dataset['train'],
19
- eval_dataset=dataset['validation'],
20
  )
 
 
21
  trainer.train()
22
- return {"status": "Training complete"}
 
 
 
3
  from datasets import load_dataset
4
 
5
  def fine_tune_model(dataset, model_name, epochs, batch_size, learning_rate):
6
+ # Load the pre-trained model for sequence classification
7
  model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
8
+
9
+ # Define the training arguments
10
  training_args = TrainingArguments(
11
+ output_dir='./results', # Directory for storing results
12
+ num_train_epochs=epochs, # Number of training epochs
13
+ per_device_train_batch_size=batch_size, # Batch size for training
14
+ learning_rate=learning_rate, # Learning rate for the optimizer
15
+ logging_dir='./logs', # Directory for storing logs
16
+ logging_steps=10, # Log every 10 steps
17
  )
18
+
19
+ # Initialize the Trainer with the model, arguments, and dataset
20
  trainer = Trainer(
21
  model=model,
22
  args=training_args,
23
+ train_dataset=dataset['train'], # Training dataset
24
+ eval_dataset=dataset['validation'], # Validation dataset
25
  )
26
+
27
+ # Train the model
28
  trainer.train()
29
+
30
+ # Return a status message after training completes
31
+ return {"status": "Training complete"}