rosetta / app.py
yhavinga's picture
Add default texts and detect of language direction
bc21832
raw
history blame
5.92 kB
import time
import torch
import psutil
import streamlit as st
from generator import GeneratorFactory
from langdetect import detect
from default_texts import default_texts
# Index of the last visible CUDA device; evaluates to -1 when no GPU is
# present. NOTE(review): not referenced elsewhere in this file — presumably
# consumed by other modules or kept for device selection; confirm.
device = torch.cuda.device_count() - 1

# Task identifiers. Every generator spec below is tagged with one of these,
# and main() selects generators by task.
TRANSLATION_EN_TO_NL = "translation_en_to_nl"
TRANSLATION_NL_TO_EN = "translation_nl_to_en"

# Generator specs handed to GeneratorFactory. Each spec carries the keys
# model_name, desc, task and split_sentences, in that order.
GENERATOR_LIST = [
    dict(
        model_name="yhavinga/t5-small-24L-ccmatrix-multi",
        desc="T5 small nl24 ccmatrix en->nl",
        task=TRANSLATION_EN_TO_NL,
        split_sentences=True,
    ),
    dict(
        model_name="yhavinga/t5-small-24L-ccmatrix-multi",
        desc="T5 small nl24 ccmatrix nl-en",
        task=TRANSLATION_NL_TO_EN,
        split_sentences=True,
    ),
    dict(
        model_name="Helsinki-NLP/opus-mt-en-nl",
        desc="Opus MT en->nl",
        task=TRANSLATION_EN_TO_NL,
        split_sentences=True,
    ),
    # Currently disabled en->nl models; uncomment an entry to enable it.
    # dict(
    #     model_name="yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512l-nedd-256ccmatrix-en-nl",
    #     desc="longT5 large nl8 256cc/512beta/512l en->nl",
    #     task=TRANSLATION_EN_TO_NL,
    #     split_sentences=False,
    # ),
    # dict(
    #     model_name="yhavinga/byt5-small-ccmatrix-en-nl",
    #     desc="ByT5 small ccmatrix en->nl",
    #     task=TRANSLATION_EN_TO_NL,
    #     split_sentences=True,
    # ),
    # dict(
    #     model_name="yhavinga/t5-eff-large-8l-nedd-en-nl",
    #     desc="T5 eff large nl8 en->nl",
    #     task=TRANSLATION_EN_TO_NL,
    #     split_sentences=True,
    # ),
    # dict(
    #     model_name="yhavinga/t5-base-36L-ccmatrix-multi",
    #     desc="T5 base nl36 ccmatrix en->nl",
    #     task=TRANSLATION_EN_TO_NL,
    #     split_sentences=True,
    # ),
    # dict(
    #     model_name="yhavinga/longt5-local-eff-large-nl8-voc8k-ddwn-512beta-512-nedd-en-nl",
    #     desc="longT5 large nl8 512beta/512l en->nl",
    #     task=TRANSLATION_EN_TO_NL,
    #     split_sentences=False,
    # ),
    # dict(
    #     model_name="yhavinga/t5-base-36L-nedd-x-en-nl-300",
    #     desc="T5 base 36L nedd en->nl 300",
    #     task=TRANSLATION_EN_TO_NL,
    #     split_sentences=True,
    # ),
    # dict(
    #     model_name="yhavinga/long-t5-local-small-ccmatrix-en-nl",
    #     desc="longT5 small ccmatrix en->nl",
    #     task=TRANSLATION_EN_TO_NL,
    #     split_sentences=True,
    # ),
]
def main():
    """Render the Babel page and translate the entered text when "Run" is pressed.

    The translation direction is picked automatically with ``langdetect``:
    English input is translated to Dutch and Dutch input to English; any other
    detected language is reported as unsupported. Every configured generator
    for the chosen direction is run in turn, and its output, wall-clock time
    and the generation parameters it reports are shown, followed by a system
    memory summary.
    """
    # Local imports: only needed for reproducible language detection and for
    # handling texts langdetect cannot classify.
    from langdetect import DetectorFactory
    from langdetect.lang_detect_exception import LangDetectException

    st.set_page_config(  # Alternate names: setup_page, page, layout
        page_title="Babel",  # String or None. Strings get appended with "• Streamlit".
        layout="wide",  # Can be "centered" or "wide".
        initial_sidebar_state="expanded",  # Can be "auto", "expanded", "collapsed"
        page_icon="📚",  # String, anything supported by st.image, or None.
    )

    # Build the GeneratorFactory once and keep it in session state so later
    # reruns of this script reuse it.
    if "generators" not in st.session_state:
        st.session_state["generators"] = GeneratorFactory(GENERATOR_LIST)
    generators = st.session_state["generators"]

    with open("style.css", encoding="utf-8") as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

    st.sidebar.image("babel.png", width=200)
    st.sidebar.markdown(
        """# Babel
Vertaal van en naar Engels"""
    )
    st.sidebar.title("Parameters:")
    default_text = st.sidebar.radio(
        "Change default text",
        tuple(default_texts.keys()),
        index=0,
    )
    # NOTE(review): `default_text` is normally a non-empty key, so this
    # reassigns prompt_box on every rerun; the `not in session_state` half
    # only matters for a falsy selection.
    if default_text or "prompt_box" not in st.session_state:
        st.session_state["prompt_box"] = default_texts[default_text]["text"]
    text_area = st.text_area("Enter text", st.session_state.prompt_box, height=300)
    st.session_state["text"] = text_area

    num_beams = st.sidebar.number_input("Num beams", min_value=1, max_value=10, value=1)
    num_beam_groups = st.sidebar.number_input(
        "Num beam groups", min_value=1, max_value=10, value=1
    )
    length_penalty = st.sidebar.number_input(
        "Length penalty", min_value=0.0, max_value=2.0, value=1.2, step=0.1
    )
    st.sidebar.markdown(
        """For an explanation of the parameters, head over to the [Huggingface blog post about text generation](https://huggingface.co/blog/how-to-generate)
and the [Huggingface text generation interface doc](https://huggingface.co/transformers/main_classes/model.html?highlight=generate#transformers.generation_utils.GenerationMixin.generate).
"""
    )
    params = {
        "num_beams": num_beams,
        "num_beam_groups": num_beam_groups,
        "length_penalty": length_penalty,
        "early_stopping": True,
    }

    if not st.button("Run"):
        return

    # Validate the generation settings first: the check is cheap and does not
    # depend on the input. Num beam groups should be a divisor of num beams.
    if num_beams % num_beam_groups != 0:
        st.error("Num beams should be a multiple of num beam groups")
        return

    text = st.session_state.text
    # langdetect is randomized; a fixed seed makes the detected language, and
    # therefore the translation direction, reproducible for the same input.
    DetectorFactory.seed = 0
    try:
        language = detect(text)
    except LangDetectException:
        # Raised for input with no usable features, e.g. empty text.
        st.error("Could not detect the language of the text")
        return
    if language == "en":
        task = TRANSLATION_EN_TO_NL
    elif language == "nl":
        task = TRANSLATION_NL_TO_EN
    else:
        st.error(f"Language {language} not supported")
        return

    for generator in generators.filter(task=task):
        st.markdown(f"🧮 **Model `{generator}`**")
        # perf_counter is monotonic and meant for measuring intervals.
        time_start = time.perf_counter()
        result, params_used = generator.generate(text=text, **params)
        time_diff = time.perf_counter() - time_start
        # Two trailing spaces before a newline force a Markdown hard line break.
        st.write(result.replace("\n", "  \n"))
        text_line = ", ".join(f"{k}={v}" for k, v in params_used.items())
        st.markdown(f" 🕙 *generated in {time_diff:.2f}s, `{text_line}`*")

    # Sample memory after generation so the footer reflects the state after
    # the models ran.
    memory = psutil.virtual_memory()
    st.write(
        f"""
---
*Memory: {memory.total / 10**9:.2f}GB, used: {memory.percent}%, available: {memory.available / 10**9:.2f}GB*
"""
    )
# Build the app only when this file is executed as the main script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()