BabaSambar commited on
Commit
5e13dab
·
1 Parent(s): a8fe7dc

add python file

Browse files
Files changed (2) hide show
  1. evaluate.py +49 -0
  2. requirements.txt +4 -0
evaluate.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
+ tokenizer = AutoTokenizer.from_pretrained("typeform/distilbert-base-uncased-mnli")
6
+ model = AutoModelForSequenceClassification.from_pretrained("typeform/distilbert-base-uncased-mnli")
7
+
8
+ # Label mapping (matches MNLI outputs)
9
+ label_mapping = ["entailment", "neutral", "contradiction"]
10
+
11
+ def check_entailment(premise, hypothesis):
12
+ """
13
+ Computes the entailment scores between a premise and a hypothesis.
14
+ Call this function several times for each RAG similarity.
15
+ Args:
16
+ premise (str): The reference text.
17
+ hypothesis (str): The statement to check.
18
+
19
+ Returns:
20
+ dict: A dictionary containing scores for entailment, neutral, and contradiction.
21
+ """
22
+ inputs = tokenizer(premise, hypothesis, return_tensors="pt")
23
+
24
+ with torch.no_grad():
25
+ logits = model(**inputs).logits
26
+
27
+ # Apply softmax
28
+ probabilities = torch.nn.functional.softmax(logits, dim=-1)[0]
29
+ # Convert to dictionary
30
+ scores = {label_mapping[i]: probabilities[i].item() for i in range(len(label_mapping))}
31
+
32
+ return scores
33
+
34
+
35
+ demo = gr.Interface(
36
+ fn=check_entailment,
37
+ inputs=["text", "text"],
38
+ outputs=["json"]
39
+ )
40
+ demo.launch(share=True)
41
+ api_url = demo.share_url
42
+
43
+ data = {
44
+ "paragraph": "Software is instructions (computer programs) that when executed provide desired function and performance, data structures that enable the programs to adequately manipulate information, and documents that describe the operation and use of the programs.",
45
+ "hypothesis": "Software doesnt wear out."
46
+ }
47
+
48
+ response = requests.post(api_url, json=data)
49
+ print(response.json)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio==5.17.1
2
+ requests
3
+ torch
4
+ transformers