Update app.py
app.py CHANGED
@@ -1,26 +1,26 @@
-
 import itertools
 import pickle
-

 # Import and download necessary NLTK data for tokenization.
 import nltk
 from nltk.translate.bleu_score import sentence_bleu
-
 nltk.download('punkt')

 # Import the ROUGE metric implementation.
 from rouge import Rouge
-
 rouge = Rouge()

 from datasets import load_dataset
-import streamlit as st
-
 # Use name="sample-10BT" to use the 10BT sample.
 fw = load_dataset("HuggingFaceFW/fineweb", name="CC-MAIN-2024-10", split="train", streaming=True)

-
 # Define helper functions for character-level accuracy and precision.
 def char_accuracy(true_output, model_output):
     # Compare matching characters in corresponding positions.
@@ -29,92 +29,85 @@ def char_accuracy(true_output, model_output):
     total = max(len(true_output), len(model_output))
     return matches / total if total > 0 else 1.0

-
 def char_precision(true_output, model_output):
-    #
     matches = sum(1 for c1, c2 in zip(true_output, model_output) if c1 == c2)
     return matches / len(model_output) if len(model_output) > 0 else 0.0


-# Initialize
-st.
-st.
-
-
-
-
-
-
-
 acc = []
-pres
-bleu
-rouges
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    tag2 = model_output_full.find("</back>")
-    model_output = model_output_full[tag1 + 6: tag2]
-    st.subheader("Model Output")
-    st.write(model_output)
-
-    # Tokenize both outputs for BLEU calculation
-    reference_tokens = nltk.word_tokenize(true_output)
-    candidate_tokens = nltk.word_tokenize(model_output)
-
-    # Compute BLEU score (using the single reference)
-    bleu_score = sentence_bleu([reference_tokens], candidate_tokens)
-    st.write("**BLEU Score:**", bleu_score)
-
-    # Compute ROUGE scores
-    rouge_scores = rouge.get_scores(model_output, true_output)
-    st.write("**ROUGE Scores:**")
-    st.json(rouge_scores)
-
-    # Compute character-level accuracy and precision
-    accuracy_metric = char_accuracy(true_output, model_output)
-    precision_metric = char_precision(true_output, model_output)
-    st.write("**Character Accuracy:**", accuracy_metric)
-    st.write("**Character Precision:**", precision_metric)
-
-    st.markdown("---")
-
-    # Append metrics to lists
-    acc.append(accuracy_metric)
-    pres.append(precision_metric)
-    bleu.append(bleu_score)
-    rouges.append(rouge_scores)
-
-# Allow the user to download the metrics
-if st.button("Download Metrics"):
     with open('accuracy.pkl', 'wb') as file:
         pickle.dump(acc, file)
     with open('precision.pkl', 'wb') as file:
@@ -123,10 +116,3 @@ if st.button("Download Metrics"):
         pickle.dump(bleu, file)
     with open('rouge.pkl', 'wb') as file:
         pickle.dump(rouges, file)
-    st.success("Metrics saved successfully!")
-
-    # Provide download links
-    st.download_button('Download Accuracy Metrics', data=open('accuracy.pkl', 'rb'), file_name='accuracy.pkl')
-    st.download_button('Download Precision Metrics', data=open('precision.pkl', 'rb'), file_name='precision.pkl')
-    st.download_button('Download BLEU Metrics', data=open('bleu.pkl', 'rb'), file_name='bleu.pkl')
-    st.download_button('Download ROUGE Metrics', data=open('rouge.pkl', 'rb'), file_name='rouge.pkl')
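Both versions of the script share the two character-level helpers; the `matches = sum(...)` line of char_accuracy sits in unchanged lines (27-28) that this diff does not display, but it is implied by the return statement and mirrors the one in char_precision. As a quick reference, here is a standalone sketch of how the helpers score the reversal task; the example strings are invented, not taken from the app:

# Illustrative sketch only: restates the helpers from app.py so the snippet runs on its own.
def char_accuracy(true_output, model_output):
    # Position-wise matches divided by the length of the longer string.
    matches = sum(1 for c1, c2 in zip(true_output, model_output) if c1 == c2)
    total = max(len(true_output), len(model_output))
    return matches / total if total > 0 else 1.0

def char_precision(true_output, model_output):
    # Position-wise matches divided by the length of the model's output.
    matches = sum(1 for c1, c2 in zip(true_output, model_output) if c1 == c2)
    return matches / len(model_output) if len(model_output) > 0 else 0.0

print(char_accuracy("dlroW olleH", "dlroW olleH"))   # 1.0: a perfect character-level reversal
print(char_accuracy("dlroW olleH", "dlroW"))         # 5/11, about 0.45: a truncated answer is penalized
print(char_precision("dlroW olleH", "dlroW"))        # 1.0: every emitted character matches its position

Because true_output is the reversed sample text, both metrics reach 1.0 only when the model reproduces the reversal character for character.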
+import streamlit as st
+import requests
+import clipboard
+import random
+import clipboard
+import json
 import itertools
 import pickle
+from huggingface_hub import InferenceClient

 # Import and download necessary NLTK data for tokenization.
 import nltk
 from nltk.translate.bleu_score import sentence_bleu
 nltk.download('punkt')

 # Import the ROUGE metric implementation.
 from rouge import Rouge
 rouge = Rouge()

 from datasets import load_dataset
 # Use name="sample-10BT" to use the 10BT sample.
 fw = load_dataset("HuggingFaceFW/fineweb", name="CC-MAIN-2024-10", split="train", streaming=True)

 # Define helper functions for character-level accuracy and precision.
 def char_accuracy(true_output, model_output):
     # Compare matching characters in corresponding positions.

     total = max(len(true_output), len(model_output))
     return matches / total if total > 0 else 1.0

 def char_precision(true_output, model_output):
+    # Here, precision is defined as matching characters divided by the length of the model's output.
     matches = sum(1 for c1, c2 in zip(true_output, model_output) if c1 == c2)
     return matches / len(model_output) if len(model_output) > 0 else 0.0


+# Initialize session state to keep track of the current index and user input
+if 'current_index' not in st.session_state:
+    st.session_state.current_index = 0
+if 'user_input' not in st.session_state:
+    st.session_state.user_input = ''
+
+def process_element(true_output, model_output):
+    global acc, pres, bleu, rouges
+    tag1 = model_output.find("<back>")
+    tag2 = model_output.find("</back>")
+    model_output = model_output[tag1 + 6:tag2]
+    print("Model output:", model_output)
+
+    # Tokenize both outputs for BLEU calculation.
+    reference_tokens = nltk.word_tokenize(true_output)
+    candidate_tokens = nltk.word_tokenize(model_output)
+
+    # Compute BLEU score (using the single reference).
+    bleu_score = sentence_bleu([reference_tokens], candidate_tokens)
+    print("BLEU score:", bleu_score)
+
+    # Compute ROUGE scores.
+    rouge_scores = rouge.get_scores(model_output, true_output)
+    print("ROUGE scores:", rouge_scores)
+
+    # Compute character-level accuracy and precision.
+    accuracy_metric = char_accuracy(true_output, model_output)
+    precision_metric = char_precision(true_output, model_output)
+    print("Character Accuracy:", accuracy_metric)
+    print("Character Precision:", precision_metric)
+    print("-" * 80)
+    acc.append(accuracy_metric)
+    pres.append(precision_metric)
+    bleu.append(bleu_score)
+    rouges.append(rouge_scores)
+
# Get 1 sample (modify the range if you need more samples).
|
74 |
+
samples = list(itertools.islice(fw, 100))
|
75 |
+
word_threshold = 100
|
76 |
acc = []
|
77 |
+
pres=[]
|
78 |
+
bleu=[]
|
79 |
+
rouges=[]
|
80 |
+
if st.session_state.current_index < len(samples):
|
81 |
+
# Display the current element
|
82 |
+
current_element = samples[st.session_state.current_index]["text"].split(" ")
|
83 |
+
st.write(f"**Element {st.session_state.current_index + 1}:** {current_element}")
|
84 |
+
current_element=" ".join(current_element).strip().replace("\n","")
|
85 |
+
prompt = "You are a helpful assistant that echoes the user's input, but backwards, do not simply rearrange the words, reverse the user's input down to the character (e.g. reverse Hello World to dlroW olleH). Surround the backwards version of the user's input with <back> </back> tags. "+current_element
|
86 |
+
clipboard.copy(prompt)
|
87 |
+
true_output = current_element[::-1]
|
88 |
+
print("True output:", true_output)
|
89 |
+
# Input widget to get user input, tied to session state
|
90 |
+
user_input = st.text_input(
|
91 |
+
"Enter your input:",
|
92 |
+
value=st.session_state.user_input,
|
93 |
+
key='input_field'
|
94 |
+
)
|
95 |
+
|
96 |
+
# Check if the user has entered input
|
97 |
+
if st.button("Submit"):
|
98 |
+
if user_input:
|
99 |
+
# Process the current element and user input
|
100 |
+
process_element(true_output, user_input)
|
101 |
+
# Clear the input field by resetting the session state
|
102 |
+
st.session_state.user_input = ''
|
103 |
+
# Move to the next element
|
104 |
+
st.session_state.current_index += 1
|
105 |
+
# Rerun the app to update the state
|
106 |
+
st.rerun()
|
107 |
+
else:
|
108 |
+
st.warning("Please enter your input before submitting.")
|
109 |
+
else:
|
110 |
+
st.success("All elements have been processed!")
|
     with open('accuracy.pkl', 'wb') as file:
         pickle.dump(acc, file)
     with open('precision.pkl', 'wb') as file:

         pickle.dump(bleu, file)
     with open('rouge.pkl', 'wb') as file:
         pickle.dump(rouges, file)
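The tag handling in process_element is the step most likely to misbehave, so here is a standalone sketch of the same slicing logic with an invented model response; note that str.find returns -1 when a tag is missing, a case the current code does not guard against:

# Sketch of the <back>...</back> extraction performed in process_element (example string is invented).
raw = "Sure! <back>dlroW olleH</back>"
tag1 = raw.find("<back>")        # index of the opening tag, or -1 if it is absent
tag2 = raw.find("</back>")       # index of the closing tag, or -1 if it is absent
extracted = raw[tag1 + 6:tag2]   # 6 == len("<back>"), so the slice starts right after the opening tag
print(extracted)                 # dlroW olleH

# If either tag is missing, find() returns -1 and the slice silently yields the wrong
# substring; a guard like this is one possible (hypothetical) addition:
if tag1 == -1 or tag2 == -1:
    extracted = raw              # fall back to scoring the full response

With such a guard, a response that omits the tags is still scored rather than being cut to an arbitrary slice.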
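Both the old and the new script persist the four metric lists with pickle; a small reader like the following (hypothetical, not part of the Space) can summarize a finished run. It assumes the .pkl files sit in the working directory and that rouge.pkl holds the list-of-dicts structures returned by rouge.get_scores:

import pickle

# Hypothetical inspection helper; not part of app.py.
def load_metrics():
    summary = {}
    for name in ("accuracy", "precision", "bleu"):
        with open(f"{name}.pkl", "rb") as f:
            values = pickle.load(f)
        summary[name] = sum(values) / len(values) if values else None
    with open("rouge.pkl", "rb") as f:
        rouge_runs = pickle.load(f)
    # Each entry is the one-element list returned by rouge.get_scores; average the ROUGE-L F1.
    f_scores = [run[0]["rouge-l"]["f"] for run in rouge_runs]
    summary["rouge-l_f"] = sum(f_scores) / len(f_scores) if f_scores else None
    return summary

if __name__ == "__main__":
    print(load_metrics())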