|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import constants |
|
import pandas as pd |
|
import streamlit as st |
|
from huggingface_hub import hf_hub_download |
|
from GlotScript import get_script_predictor |
|
import matplotlib.pyplot as plt |
|
import fasttext |
|
import altair as alt |
|
from altair import X, Y, Scale |
|
import base64 |
|
|
|
|
|
@st.cache_resource
def load_sp():
    """Build the GlotScript script predictor.

    Wrapped in ``st.cache_resource`` so the predictor object is created via
    ``get_script_predictor()`` and then reused by Streamlit's resource cache.

    Returns:
        The object returned by ``get_script_predictor()``.
    """
    return get_script_predictor()


# Module-level predictor instance used by get_script().
sp = load_sp()
|
|
|
def get_script(text):
    """Return the writing system detected for ``text``.

    Args:
        text: The text whose script should be identified.

    Returns:
        The first element of the module-level predictor's result for
        ``text`` (presumably the main script code -- confirm against the
        GlotScript documentation).
    """
    prediction = sp(text)
    return prediction[0]
|
|
|
@st.cache_data
def render_svg(svg):
    """Display an SVG string as a centred inline image.

    The SVG text is UTF-8 encoded, base64-encoded, and embedded as a
    ``data:image/svg+xml`` URI inside an ``<img>`` tag that is written to a
    fresh Streamlit container.

    Args:
        svg: The SVG document as a string.
    """
    encoded = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
    container = st.container()
    # unsafe_allow_html is needed so the <img> markup is rendered, not escaped.
    container.write(
        rf'<p align="center"> <img src="data:image/svg+xml;base64,{encoded}"/> </p>',
        unsafe_allow_html=True,
    )
|
|
|
|
|
@st.cache_data
def convert_df(df):
    """Serialise a DataFrame to UTF-8 encoded CSV bytes.

    Args:
        df: The DataFrame to export.

    Returns:
        The CSV text produced by ``df.to_csv(index=None)``, encoded as UTF-8.
    """
    # index=None is passed through unchanged; presumably pandas treats it
    # like False and omits the index column -- confirm against pandas docs.
    csv_text = df.to_csv(index=None)
    return csv_text.encode("utf-8")
|
|
|
|
|
@st.cache_resource
def load_model(model_name):
    """Fetch and load the fastText model stored in a Hugging Face Hub repo.

    Args:
        model_name: Repository id passed to ``hf_hub_download`` as
            ``repo_id``; the file ``model.bin`` is requested from it.

    Returns:
        The model object returned by ``fasttext.load_model``.
    """
    local_path = hf_hub_download(repo_id=model_name, filename="model.bin")
    return fasttext.load_model(local_path)


# Module-level model instance used by compute().
model = load_model(constants.MODEL_NAME)
|
|
|
|
|
def compute(sentences, batch_size=1):
    """Predict a language label and confidence for each sentence.

    Args:
        sentences: A list of sentences (strings).
        batch_size: Number of sentences handed to the model per call. The
            progress bar advances once per batch. Defaults to 1.

    Returns:
        A ``(probs, labels)`` tuple of two lists aligned with ``sentences``:
        the top prediction's probability clamped to ``[0.0, 1.0]`` and the
        top prediction's label. A sentence for which the model returns no
        prediction gets ``0.0`` and ``None`` so the lists stay aligned.
    """
    total = len(sentences)
    if total == 0:
        # Nothing to predict: avoid creating a progress bar for no work.
        return [], []

    progress_text = "Computing Language..."
    my_bar = st.progress(0, text=progress_text)

    probs = []
    labels = []

    for first_index in range(0, total, batch_size):
        # fastText's predict() rejects text containing "\n" (it processes one
        # line at a time), so flatten embedded line breaks to spaces.
        batch = [
            sentence.replace("\n", " ")
            for sentence in sentences[first_index : first_index + batch_size]
        ]
        batch_labels, batch_probs = model.predict(batch)

        # Collect the top prediction of EVERY sentence in the batch, not just
        # the first one, so the result is correct for any batch size.
        for sentence_labels, sentence_probs in zip(batch_labels, batch_probs):
            if len(sentence_labels) == 0:
                labels.append(None)
                probs.append(0.0)
                continue
            labels.append(sentence_labels[0])
            # Model scores can drift marginally outside [0, 1]; clamp them.
            probs.append(min(max(float(sentence_probs[0]), 0.0), 1.0))

        # Use a float cap: an int 1 would be read by st.progress as 1 %.
        my_bar.progress(
            min((first_index + batch_size) / total, 1.0),
            text=progress_text,
        )

    my_bar.empty()
    return probs, labels
|
|
|
|
|
# Show the project logo at the top of the page. The context manager closes
# the file handle, and the explicit encoding avoids platform-dependent decoding.
with open("assets/GlotLID_logo.svg", encoding="utf-8") as logo_file:
    render_svg(logo_file.read())

tab1, tab2 = st.tabs(["Input a Sentence", "Upload a File"])

with tab1:
    sent = st.text_input(
        "Sentence:", placeholder="Enter a sentence.", on_change=None
    )

    # The return value is not consulted below; the prediction runs whenever
    # the text box is non-empty.
    clicked = st.button("Submit")

    if sent:
        probs, labels = compute([sent])
        prob = probs[0]
        label = labels[0]

        ORANGE_COLOR = "#FF8000"
        fig, ax = plt.subplots(figsize=(8, 1))
        # Transparent backgrounds so the chart blends with the page theme.
        fig.patch.set_facecolor("none")
        ax.set_facecolor("none")

        ax.spines["left"].set_color(ORANGE_COLOR)
        ax.spines["bottom"].set_color(ORANGE_COLOR)
        ax.tick_params(axis="x", colors=ORANGE_COLOR)

        ax.spines[["right", "top"]].set_visible(False)

        # Single horizontal bar whose length is the prediction confidence.
        ax.barh(y=[0], width=[prob], color=ORANGE_COLOR)
        ax.set_xlim(0, 1)
        ax.set_ylim(-1, 1)
        ax.set_title(f"Language is: {label}", color=ORANGE_COLOR)
        ax.get_yaxis().set_visible(False)
        ax.set_xlabel("Confidence", color=ORANGE_COLOR)
        st.pyplot(fig)
        # Release the figure; otherwise pyplot keeps one alive per rerun.
        plt.close(fig)

        # Record the submitted sentence on stdout and in a local log file.
        # Explicit UTF-8 because the input can be in any script.
        print(sent)
        with open("logs.txt", "a", encoding="utf-8") as f:
            f.write(sent + "\n")
|
|
|
with tab2:
    file = st.file_uploader("Upload a file", type=["txt"])
    if file is not None:
        # Treat the upload as plain text with one sentence per line. Reading
        # the lines directly (instead of parsing as tab-separated CSV) keeps
        # sentences that contain tabs or quote characters intact and keeps
        # numeric-looking lines as strings. "utf-8-sig" drops a leading BOM;
        # undecodable bytes are replaced rather than aborting the upload.
        raw_text = file.getvalue().decode("utf-8-sig", errors="replace")
        # Blank lines carry no sentence and are skipped.
        sentences = [line for line in raw_text.splitlines() if line.strip()]

        if not sentences:
            st.warning("The uploaded file does not contain any sentences.")
        else:
            df = pd.DataFrame({"Sentence": sentences})
            df["Probs"], df["Language"] = compute(sentences)

            st.markdown("""---""")

            # Area chart of the confidence per sentence, in file order.
            chart = (
                alt.Chart(df.reset_index())
                .mark_area(color="darkorange", opacity=0.5)
                .encode(
                    x=X(field="index", title="Sentence Index"),
                    y=Y("Probs", scale=Scale(domain=[0, 1])),
                )
            )
            st.altair_chart(chart.interactive(), use_container_width=True)

            col1, col2 = st.columns([4, 1])

            with col1:
                st.table(
                    df,
                )

            with col2:
                csv = convert_df(df)
                st.download_button(
                    label=":file_folder: Download predictions as CSV",
                    data=csv,
                    file_name="GlotLID.csv",
                    mime="text/csv",
                )
|
|