Zen0 committed on
Commit
9e24c35
·
verified ·
1 Parent(s): eb1168d

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +29 -91
tasks/text.py CHANGED
@@ -1,17 +1,15 @@
1
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
  from fastapi import APIRouter
3
  from datetime import datetime
4
  from datasets import load_dataset
5
  from sklearn.metrics import accuracy_score
6
- import torch
7
- from torch.utils.data import Dataset, DataLoader
8
 
9
  from .utils.evaluation import TextEvaluationRequest
10
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
11
 
12
  router = APIRouter()
13
 
14
- DESCRIPTION = "Climate Disinformation Detection"
15
  ROUTE = "/text"
16
 
17
  @router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
@@ -48,92 +46,32 @@ async def evaluate_text(request: TextEvaluationRequest):
48
  tracker.start()
49
  tracker.start_task("inference")
50
 
51
- try:
52
- # Get texts and labels
53
- texts = test_dataset["quote"]
54
- labels = test_dataset["label"]
55
-
56
- # Load model and tokenizer from local directory
57
- model_dir = "./"
58
- tokenizer = AutoTokenizer.from_pretrained(model_dir)
59
- model = AutoModelForSequenceClassification.from_pretrained(model_dir)
60
-
61
- # Define dataset class
62
- class TextDataset(Dataset):
63
- def __init__(self, texts, labels, tokenizer, max_len=128):
64
- self.texts = texts
65
- self.labels = labels
66
- self.tokenizer = tokenizer
67
- self.max_len = max_len
68
 
69
- def __len__(self):
70
- return len(self.texts)
71
-
72
- def __getitem__(self, idx):
73
- text = self.texts[idx]
74
- label = self.labels[idx]
75
- encodings = self.tokenizer(
76
- text,
77
- max_length=self.max_len,
78
- padding='max_length',
79
- truncation=True,
80
- return_tensors="pt"
81
- )
82
- return {
83
- 'input_ids': encodings['input_ids'].squeeze(0),
84
- 'attention_mask': encodings['attention_mask'].squeeze(0),
85
- 'labels': torch.tensor(label, dtype=torch.long)
86
- }
87
-
88
- # Create dataset and dataloader
89
- test_dataset = TextDataset(texts, labels, tokenizer)
90
- test_loader = DataLoader(test_dataset, batch_size=16)
91
-
92
- # Model inference
93
- model.eval()
94
- predictions = []
95
- ground_truth = []
96
- device = 'cpu'
97
-
98
- with torch.no_grad():
99
- for batch in test_loader:
100
- input_ids = batch['input_ids'].to(device)
101
- attention_mask = batch['attention_mask'].to(device)
102
- labels = batch['labels'].to(device)
103
-
104
- outputs = model(input_ids=input_ids, attention_mask=attention_mask)
105
- _, predicted = torch.max(outputs.logits, 1)
106
-
107
- predictions.extend(predicted.cpu().numpy())
108
- ground_truth.extend(labels.cpu().numpy())
109
-
110
- # Stop tracking emissions
111
- emissions_data = tracker.stop_task()
112
-
113
- # Calculate accuracy
114
- accuracy = accuracy_score(test_dataset["label"], predictions)
115
-
116
- # Prepare results
117
- results = {
118
- "username": username,
119
- "space_url": space_url,
120
- "submission_timestamp": datetime.now().isoformat(),
121
- "model_description": DESCRIPTION,
122
- "accuracy": float(accuracy),
123
- "energy_consumed_wh": emissions_data.energy_consumed * 1000,
124
- "emissions_gco2eq": emissions_data.emissions * 1000,
125
- "emissions_data": clean_emissions_data(emissions_data),
126
- "api_route": ROUTE,
127
- "dataset_config": {
128
- "dataset_name": request.dataset_name,
129
- "test_size": request.test_size,
130
- "test_seed": request.test_seed
131
- }
132
  }
133
-
134
- return results
135
-
136
- except Exception as e:
137
- # Stop tracking in case of error
138
- tracker.stop_task()
139
- raise e
 
 
1
  from fastapi import APIRouter
2
  from datetime import datetime
3
  from datasets import load_dataset
4
  from sklearn.metrics import accuracy_score
5
+ import random
 
6
 
7
  from .utils.evaluation import TextEvaluationRequest
8
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
9
 
10
  router = APIRouter()
11
 
12
+ DESCRIPTION = "Random Baseline"
13
  ROUTE = "/text"
14
 
15
  @router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
 
46
  tracker.start()
47
  tracker.start_task("inference")
48
 
49
+ # Get true labels
50
+ true_labels = test_dataset["label"]
51
+ predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+ # Stop tracking emissions
54
+ emissions_data = tracker.stop_task()
55
+
56
+ # Calculate accuracy
57
+ accuracy = accuracy_score(true_labels, predictions)
58
+
59
+ # Prepare results dictionary
60
+ results = {
61
+ "username": username,
62
+ "space_url": space_url,
63
+ "submission_timestamp": datetime.now().isoformat(),
64
+ "model_description": DESCRIPTION,
65
+ "accuracy": float(accuracy),
66
+ "energy_consumed_wh": emissions_data.energy_consumed * 1000,
67
+ "emissions_gco2eq": emissions_data.emissions * 1000,
68
+ "emissions_data": clean_emissions_data(emissions_data),
69
+ "api_route": ROUTE,
70
+ "dataset_config": {
71
+ "dataset_name": request.dataset_name,
72
+ "test_size": request.test_size,
73
+ "test_seed": request.test_seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  }
75
+ }
76
+
77
+ return results