julianaconsuegra committed
Commit 3d612f9 · verified · 1 Parent(s): 0ae53cb

added inference

Files changed (1)
  1. tasks/text.py +33 -3
tasks/text.py CHANGED
@@ -7,11 +7,19 @@ import random
 from .utils.evaluation import TextEvaluationRequest
 from .utils.emissions import tracker, clean_emissions_data, get_space_info
 
+
+import tensorflow as tf
+from huggingface_hub import hf_hub_download
+from transformers import ElectraTokenizer
+
+
 router = APIRouter()
 
-DESCRIPTION = "Random Baseline"
+DESCRIPTION = "Electra with balanced dataset"
 ROUTE = "/text"
 
+
+
 @router.post(ROUTE, tags=["Text Task"],
              description=DESCRIPTION)
 async def evaluate_text(request: TextEvaluationRequest):
@@ -37,6 +45,12 @@ async def evaluate_text(request: TextEvaluationRequest):
         "7_fossil_fuels_needed": 7
     }
 
+    # Download our pre-trained model from Hugging Face
+    model_path = hf_hub_download(repo_id="julianaconsuegra/electra-base-climate-disinformation", filename="tf_model.h5")
+
+    # Load the model
+    model = tf.keras.models.load_model(model_path)
+
     # Load and prepare the dataset
     dataset = load_dataset(request.dataset_name)
 
@@ -55,10 +69,26 @@ async def evaluate_text(request: TextEvaluationRequest):
     # YOUR MODEL INFERENCE CODE HERE
     # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
     #--------------------------------------------------------------------------------------------
+    # Load ELECTRA tokenizer
+    tokenizer = ElectraTokenizer.from_pretrained("google/electra-base-discriminator")
+
+    # Tokenize test data with same parameters as training
+    inputs = tokenizer(
+        test_dataset["text"],
+        truncation=True,
+        padding="max_length",
+        return_tensors="tf"
+    )
+
+    # Run model prediction
+    logits = model.predict({
+        "input_ids": inputs["input_ids"],
+        "attention_mask": inputs["attention_mask"]
+    })
+    predictions = tf.argmax(logits, axis=1).numpy()
 
-    # Make random predictions (placeholder for actual model inference)
+    # Get ground truth labels
     true_labels = test_dataset["label"]
-    predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
 
     #--------------------------------------------------------------------------------------------
     # YOUR MODEL INFERENCE STOPS HERE
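
A minimal standalone sketch of the inference path added in this commit, useful as a local sanity check before calling the /text endpoint. It assumes, as the diff implies, that tf_model.h5 in julianaconsuegra/electra-base-climate-disinformation loads as a full Keras model via tf.keras.models.load_model, takes inputs named "input_ids" and "attention_mask", and outputs scores over the 8 classes of the label map; the example sentences are illustrative only.

# sanity_check.py -- mirrors the tokenize -> predict -> argmax flow from the commit
import tensorflow as tf
from huggingface_hub import hf_hub_download
from transformers import ElectraTokenizer

# Download and load the model exactly as the endpoint does (assumed full Keras model)
model_path = hf_hub_download(
    repo_id="julianaconsuegra/electra-base-climate-disinformation",
    filename="tf_model.h5",
)
model = tf.keras.models.load_model(model_path)

# Same tokenizer and tokenization parameters as in the commit
tokenizer = ElectraTokenizer.from_pretrained("google/electra-base-discriminator")

# Illustrative inputs (not from the evaluation dataset)
texts = [
    "Climate change is just a natural cycle, not caused by humans.",
    "Global renewable energy capacity grew again last year.",
]
inputs = tokenizer(texts, truncation=True, padding="max_length", return_tensors="tf")

# Predict and reduce logits to class indices, as in the endpoint code
logits = model.predict({
    "input_ids": inputs["input_ids"],
    "attention_mask": inputs["attention_mask"],
})
predictions = tf.argmax(logits, axis=1).numpy()
print(predictions)  # one class index in 0..7 per input text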