Update app.py
Browse files
app.py
CHANGED
|
@@ -48,25 +48,19 @@ labels = st.radio("Select Label List", ("Label List #1", "Label List #2", "New L
|
|
| 48 |
if text == "Text #1": selected_text = text_1
|
| 49 |
elif text == "Text #2": selected_text = text_2
|
| 50 |
elif text == "New Text":
|
| 51 |
-
|
| 52 |
|
| 53 |
if labels == "Label List #1": selected_labels = label_list_1
|
| 54 |
elif labels == "Label List #2": selected_labels = label_list_2
|
| 55 |
elif labels == "New Label List":
|
| 56 |
-
|
|
|
|
|
|
|
| 57 |
|
| 58 |
-
@st.cache(allow_output_mutation=True)
|
| 59 |
-
def setModel(model_name):
|
| 60 |
-
tokenizer = MT5Tokenizer.from_pretrained(model_name)
|
| 61 |
-
model = MT5ForConditionalGeneration.from_pretrained(model_name)
|
| 62 |
-
model.eval()
|
| 63 |
-
return pipeline("zero-shot-classification", model=model, tokenizer=tokenizer)
|
| 64 |
|
| 65 |
Run_Button = st.button("Run", key=None)
|
| 66 |
if Run_Button == True:
|
| 67 |
-
|
| 68 |
-
zstc_pipeline = setModel(model_checkpoint)
|
| 69 |
-
output = zstc_pipeline(sequences=selected_text, candidate_labels=selected_labels)
|
| 70 |
output_labels = output["labels"]
|
| 71 |
output_scores = output["scores"]
|
| 72 |
|
|
|
|
| 48 |
if text == "Text #1": selected_text = text_1
|
| 49 |
elif text == "Text #2": selected_text = text_2
|
| 50 |
elif text == "New Text":
|
| 51 |
+
sequence_to_classify = st.text_area("New Text", value="", height=128)
|
| 52 |
|
| 53 |
if labels == "Label List #1": selected_labels = label_list_1
|
| 54 |
elif labels == "Label List #2": selected_labels = label_list_2
|
| 55 |
elif labels == "New Label List":
|
| 56 |
+
candidate_labels = st.text_area("New Label List (Pls Input as comma-separated)", value="", height=16).split(",")
|
| 57 |
+
|
| 58 |
+
hypothesis_template = "Bu yazı {} konusundadır."
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
Run_Button = st.button("Run", key=None)
|
| 62 |
if Run_Button == True:
|
| 63 |
+
output = runModel(model_name, sequence_to_classify, candidate_labels, hypothesis_template)
|
|
|
|
|
|
|
| 64 |
output_labels = output["labels"]
|
| 65 |
output_scores = output["scores"]
|
| 66 |
|