Canstralian committed on
Commit
acf165a
·
verified ·
1 Parent(s): c079ce1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -25
app.py CHANGED
@@ -1,30 +1,29 @@
1
  import gradio as gr
2
- import torch
3
- from transformers import AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments, Trainer
4
- from datasets import load_dataset
5
- import os
6
 
7
- # Function to fine-tune model
8
- def fine_tune(model_name, dataset_url, file, epochs, batch_size, learning_rate):
9
- try:
10
- # Load dataset
11
- if dataset_url:
12
- dataset = load_dataset(dataset_url)
13
- elif file:
14
- dataset = load_dataset("csv", data_files={"train": file.name})
15
- else:
16
- return "Please provide a dataset URL or upload a file."
17
 
18
- # Load model & tokenizer
19
- model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
20
- tokenizer = AutoTokenizer.from_pretrained(model_name)
21
 
22
- def tokenize_function(examples):
23
- return tokenizer(examples["text"], padding="max_length", truncation=True)
 
 
 
 
 
 
 
 
 
 
24
 
25
- dataset = dataset.map(tokenize_function, batched=True)
26
-
27
- # Define training arguments
28
- training_args = TrainingArguments(
29
- output_dir="./results",
30
- evaluation_strategy="epoch
 
1
  import gradio as gr
2
+ from model.model import fine_tune
3
+ from data.preprocess import load_data, preprocess_data, save_processed_data
 
 
4
 
5
def prepare_and_train(model_name, dataset_path, epochs, batch_size, learning_rate):
    """Preprocess an uploaded dataset, save it, and fine-tune a model on it.

    Args:
        model_name: Model identifier forwarded unchanged to ``fine_tune``.
        dataset_path: Value of the "Upload Dataset" input. May be a path
            string, an uploaded-file object exposing ``.name``, or ``None``
            when nothing has been uploaded yet.
        epochs: Number of training epochs, forwarded to ``fine_tune``.
        batch_size: Training batch size; coerced to ``int`` when not ``None``.
        learning_rate: Learning rate, forwarded to ``fine_tune``.

    Returns:
        Whatever ``fine_tune`` returns, or a short message string when a
        required input is missing.
    """
    import os

    # Fail fast with a readable message instead of crashing inside the
    # loaders when the form is incomplete.
    if not model_name:
        return "Please provide a model name."
    if dataset_path is None:
        return "Please upload a dataset file."

    # The file input may hand over an object rather than a plain path; fall
    # back to the value itself when there is no ``.name`` attribute.
    # NOTE(review): assumes ``.name`` is the on-disk path of the upload, as the
    # previous revision of this file did — confirm for the Gradio version in use.
    source_path = getattr(dataset_path, "name", dataset_path)

    # A numeric form field can deliver 8.0 rather than 8.
    if batch_size is not None:
        batch_size = int(batch_size)

    # Load and preprocess the dataset
    data = load_data(source_path)
    cleaned_data = preprocess_data(data)

    processed_data_path = 'data/processed/processed_dataset.csv'
    # Make sure the target directory exists before saving into it.
    os.makedirs(os.path.dirname(processed_data_path), exist_ok=True)
    save_processed_data(cleaned_data, processed_data_path)

    # Proceed with model fine-tuning
    return fine_tune(
        model_name,
        dataset_url=None,
        file=processed_data_path,
        epochs=epochs,
        batch_size=batch_size,
        learning_rate=learning_rate,
    )
 
14
 
15
# Web UI: collects a model name, an uploaded dataset, and three
# hyperparameters, then shows the text returned by prepare_and_train.
iface = gr.Interface(
    fn=prepare_and_train,
    inputs=[
        gr.Textbox(label="Model Name", placeholder="e.g., bert-base-uncased"),
        gr.File(label="Upload Dataset"),
        # precision=0 makes these fields deliver integers instead of floats.
        gr.Number(label="Epochs", value=3, precision=0),
        gr.Number(label="Batch Size", value=8, precision=0),
        gr.Number(label="Learning Rate", value=5e-5),
    ],
    outputs="text",
    # Live mode is intentionally not enabled: it re-runs the function on every
    # input change, which would start a new training job on each edit. The
    # job should start only when the user submits the form.
)

if __name__ == "__main__":
    # Start the Gradio server only when this file is run as a script.
    iface.launch()