Zen0 committed on
Commit
eb1168d
·
verified ·
1 Parent(s): f63546e

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +65 -76
tasks/text.py CHANGED
@@ -1,32 +1,20 @@
1
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
- from fastapi import FastAPI, APIRouter
3
- from fastapi.middleware.cors import CORSMiddleware
4
  from datetime import datetime
5
  from datasets import load_dataset
6
  from sklearn.metrics import accuracy_score
7
  import torch
8
- import numpy as np
9
 
10
  from .utils.evaluation import TextEvaluationRequest
11
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
12
 
13
- # Initialize FastAPI app and router
14
- app = FastAPI()
15
  router = APIRouter()
16
 
17
- # Add CORS middleware
18
- app.add_middleware(
19
- CORSMiddleware,
20
- allow_origins=["*"],
21
- allow_credentials=True,
22
- allow_methods=["*"],
23
- allow_headers=["*"],
24
- )
25
-
26
- DESCRIPTION = "Efficient Climate Disinformation Detection"
27
  ROUTE = "/text"
28
 
29
- @router.post("/text", tags=["Text Task"], description=DESCRIPTION)
30
  async def evaluate_text(request: TextEvaluationRequest):
31
  """
32
  Evaluate text classification for climate disinformation detection.
@@ -48,7 +36,11 @@ async def evaluate_text(request: TextEvaluationRequest):
48
 
49
  # Load and prepare the dataset
50
  dataset = load_dataset(request.dataset_name)
 
 
51
  dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
 
 
52
  train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
53
  test_dataset = train_test["test"]
54
 
@@ -57,65 +49,69 @@ async def evaluate_text(request: TextEvaluationRequest):
57
  tracker.start_task("inference")
58
 
59
  try:
60
- # Model configuration
61
- model_name = "distilbert-base-uncased"
62
- BATCH_SIZE = 64
63
- MAX_LENGTH = 128
64
-
65
- # Initialize tokenizer and model
66
- tokenizer = AutoTokenizer.from_pretrained(model_name)
67
- model = AutoModelForSequenceClassification.from_pretrained(
68
- model_name,
69
- num_labels=8,
70
- problem_type="single_label_classification"
71
- )
72
-
73
- # Enable mixed precision if available
74
- if torch.cuda.is_available():
75
- model = model.half()
76
-
77
- # Move model to device
78
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
79
- model = model.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  model.eval()
81
-
82
- # Get test texts
83
- test_texts = test_dataset["quote"]
84
  predictions = []
 
 
 
 
 
 
 
 
 
 
 
85
 
86
- # Process in batches
87
- for i in range(0, len(test_texts), BATCH_SIZE):
88
- if torch.cuda.is_available():
89
- torch.cuda.empty_cache()
90
-
91
- batch_texts = test_texts[i:i + BATCH_SIZE]
92
-
93
- # Tokenize batch
94
- inputs = tokenizer(
95
- batch_texts,
96
- padding=True,
97
- truncation=True,
98
- max_length=MAX_LENGTH,
99
- return_tensors="pt"
100
- )
101
-
102
- # Move inputs to device
103
- inputs = {k: v.to(device) for k, v in inputs.items()}
104
-
105
- # Run inference
106
- with torch.no_grad(), torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
107
- outputs = model(**inputs)
108
- batch_preds = torch.argmax(outputs.logits, dim=1)
109
- predictions.extend(batch_preds.cpu().numpy())
110
-
111
- # Get true labels
112
- true_labels = test_dataset['label']
113
 
114
  # Stop tracking emissions
115
  emissions_data = tracker.stop_task()
116
 
117
  # Calculate accuracy
118
- accuracy = accuracy_score(true_labels, predictions)
119
 
120
  # Prepare results
121
  results = {
@@ -138,13 +134,6 @@ async def evaluate_text(request: TextEvaluationRequest):
138
  return results
139
 
140
  except Exception as e:
 
141
  tracker.stop_task()
142
- raise e
143
-
144
- # Include the router
145
- app.include_router(router)
146
-
147
- # Add a health check endpoint
148
- @app.get("/health")
149
- async def health_check():
150
- return {"status": "healthy"}
 
1
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
+ from fastapi import APIRouter
 
3
  from datetime import datetime
4
  from datasets import load_dataset
5
  from sklearn.metrics import accuracy_score
6
  import torch
7
+ from torch.utils.data import Dataset, DataLoader
8
 
9
  from .utils.evaluation import TextEvaluationRequest
10
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
11
 
 
 
12
  router = APIRouter()
13
 
14
+ DESCRIPTION = "Climate Disinformation Detection"
 
 
 
 
 
 
 
 
 
15
  ROUTE = "/text"
16
 
17
+ @router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
18
  async def evaluate_text(request: TextEvaluationRequest):
19
  """
20
  Evaluate text classification for climate disinformation detection.
 
36
 
37
  # Load and prepare the dataset
38
  dataset = load_dataset(request.dataset_name)
39
+
40
+ # Convert string labels to integers
41
  dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
42
+
43
+ # Split dataset
44
  train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
45
  test_dataset = train_test["test"]
46
 
 
49
  tracker.start_task("inference")
50
 
51
  try:
52
+ # Get texts and labels
53
+ texts = test_dataset["quote"]
54
+ labels = test_dataset["label"]
55
+
56
+ # Load model and tokenizer from local directory
57
+ model_dir = "./"
58
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
59
+ model = AutoModelForSequenceClassification.from_pretrained(model_dir)
60
+
61
+ # Define dataset class
62
class TextDataset(Dataset):
    """Map-style dataset that tokenizes a single text each time it is indexed.

    Args:
        texts: Indexable collection of input strings.
        labels: Indexable collection of integer class ids, parallel to ``texts``.
        tokenizer: Callable invoked as
            ``tokenizer(text, max_length=..., padding=..., truncation=...,
            return_tensors=...)`` whose result exposes ``input_ids`` and
            ``attention_mask`` tensors.
        max_len: Value forwarded to the tokenizer as ``max_length``.
    """

    def __init__(self, texts, labels, tokenizer, max_len=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of texts held by the dataset."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize the text at ``idx`` and pair it with its label.

        Returns:
            A dict with ``input_ids`` and ``attention_mask`` (each with the
            first dimension squeezed away) and ``labels`` as a ``torch.long``
            scalar tensor.
        """
        raw_text = self.texts[idx]
        target = self.labels[idx]

        # Padding to the fixed max length is requested so that every item
        # can be stacked by the default DataLoader collation.
        encoded = self.tokenizer(
            raw_text,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors="pt",
        )

        # squeeze(0) removes the leading dimension of the returned tensors
        # (presumably the batch axis of size 1 added by return_tensors="pt").
        sample = {
            field: encoded[field].squeeze(0)
            for field in ('input_ids', 'attention_mask')
        }
        sample['labels'] = torch.tensor(target, dtype=torch.long)
        return sample
87
+
88
# Wrap the held-out split in a torch Dataset and batch it for inference.
# A separate name is used on purpose: rebinding `test_dataset` would hide
# the original split, and indexing the wrapper with a column name such as
# "label" fails because TextDataset only supports integer indices.
eval_dataset = TextDataset(texts, labels, tokenizer)
test_loader = DataLoader(eval_dataset, batch_size=16)

# Model inference: evaluation mode, no gradients, model and inputs on the
# same device.
device = 'cpu'
model.to(device)
model.eval()
predictions = []
ground_truth = []

with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        predicted = torch.argmax(outputs.logits, dim=1)

        predictions.extend(predicted.cpu().tolist())
        # Labels are gathered per batch so they stay aligned with the
        # predictions; they are read straight from the batch without
        # rebinding the full `labels` list.
        ground_truth.extend(batch['labels'].tolist())

# Stop tracking emissions
emissions_data = tracker.stop_task()

# Calculate accuracy against the labels collected during inference.
accuracy = accuracy_score(ground_truth, predictions)
115
 
116
  # Prepare results
117
  results = {
 
134
  return results
135
 
136
  except Exception as e:
137
+ # Stop tracking in case of error
138
  tracker.stop_task()
139
+ raise e