Zen0 committed on
Commit
45b3261
·
verified ·
1 Parent(s): 9e24c35

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +83 -29
tasks/text.py CHANGED
@@ -1,15 +1,18 @@
 
1
  from fastapi import APIRouter
2
  from datetime import datetime
3
  from datasets import load_dataset
4
  from sklearn.metrics import accuracy_score
5
- import random
6
 
7
  from .utils.evaluation import TextEvaluationRequest
8
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
9
 
 
 
 
10
  router = APIRouter()
11
 
12
- DESCRIPTION = "Random Baseline"
13
  ROUTE = "/text"
14
 
15
  @router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
@@ -46,32 +49,83 @@ async def evaluate_text(request: TextEvaluationRequest):
46
  tracker.start()
47
  tracker.start_task("inference")
48
 
49
- # Get true labels
50
- true_labels = test_dataset["label"]
51
- predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
 
 
 
52
 
53
- # Stop tracking emissions
54
- emissions_data = tracker.stop_task()
55
-
56
- # Calculate accuracy
57
- accuracy = accuracy_score(true_labels, predictions)
58
-
59
- # Prepare results dictionary
60
- results = {
61
- "username": username,
62
- "space_url": space_url,
63
- "submission_timestamp": datetime.now().isoformat(),
64
- "model_description": DESCRIPTION,
65
- "accuracy": float(accuracy),
66
- "energy_consumed_wh": emissions_data.energy_consumed * 1000,
67
- "emissions_gco2eq": emissions_data.emissions * 1000,
68
- "emissions_data": clean_emissions_data(emissions_data),
69
- "api_route": ROUTE,
70
- "dataset_config": {
71
- "dataset_name": request.dataset_name,
72
- "test_size": request.test_size,
73
- "test_seed": request.test_seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  }
75
- }
76
-
77
- return results
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
  from fastapi import APIRouter
3
  from datetime import datetime
4
  from datasets import load_dataset
5
  from sklearn.metrics import accuracy_score
 
6
 
7
  from .utils.evaluation import TextEvaluationRequest
8
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
9
 
10
+ import numpy as np
11
+ import torch
12
+
13
  router = APIRouter()
14
 
15
+ DESCRIPTION = "FrugalDisinfoHunter Model"
16
  ROUTE = "/text"
17
 
18
  @router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
 
49
  tracker.start()
50
  tracker.start_task("inference")
51
 
52
+ try:
53
+ # Model configuration
54
+ model_name = "Zen0/FrugalDisinfoHunter" # Model path
55
+ tokenizer_name = "google/mobilebert-uncased" # Base MobileBERT tokenizer
56
+ BATCH_SIZE = 32 # Batch size for efficient processing
57
+ MAX_LENGTH = 128 # Maximum sequence length
58
 
59
+ # Initialize model and tokenizer
60
+ model = AutoModelForSequenceClassification.from_pretrained(
61
+ model_name,
62
+ num_labels=8,
63
+ output_hidden_states=True,
64
+ problem_type="single_label_classification"
65
+ )
66
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
67
+
68
+ # Move model to appropriate device
69
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
70
+ model = model.to(device)
71
+ model.eval() # Set model to evaluation mode
72
+
73
+ # Get test texts
74
+ test_texts = test_dataset["quote"]
75
+ predictions = []
76
+
77
+ # Process in batches
78
+ for i in range(0, len(test_texts), BATCH_SIZE):
79
+ batch_texts = test_texts[i:i + BATCH_SIZE]
80
+
81
+ # Tokenize batch
82
+ inputs = tokenizer(
83
+ batch_texts,
84
+ padding=True,
85
+ truncation=True,
86
+ return_tensors="pt",
87
+ max_length=MAX_LENGTH
88
+ )
89
+
90
+ # Move inputs to device
91
+ inputs = {key: val.to(device) for key, val in inputs.items()}
92
+
93
+ # Run inference
94
+ with torch.no_grad():
95
+ outputs = model(**inputs)
96
+ batch_preds = torch.argmax(outputs.logits, dim=1)
97
+ predictions.extend(batch_preds.cpu().numpy())
98
+
99
+ # Get true labels
100
+ true_labels = test_dataset['label']
101
+
102
+ # Stop tracking emissions
103
+ emissions_data = tracker.stop_task()
104
+
105
+ # Calculate accuracy
106
+ accuracy = accuracy_score(true_labels, predictions)
107
+
108
+ # Prepare results dictionary
109
+ results = {
110
+ "username": username,
111
+ "space_url": space_url,
112
+ "submission_timestamp": datetime.now().isoformat(),
113
+ "model_description": DESCRIPTION,
114
+ "accuracy": float(accuracy),
115
+ "energy_consumed_wh": emissions_data.energy_consumed * 1000,
116
+ "emissions_gco2eq": emissions_data.emissions * 1000,
117
+ "emissions_data": clean_emissions_data(emissions_data),
118
+ "api_route": ROUTE,
119
+ "dataset_config": {
120
+ "dataset_name": request.dataset_name,
121
+ "test_size": request.test_size,
122
+ "test_seed": request.test_seed
123
+ }
124
  }
125
+
126
+ return results
127
+
128
+ except Exception as e:
129
+ # Stop tracking in case of error
130
+ tracker.stop_task()
131
+ raise e