JuanJoseMV's picture
hotfix
c16a6f3
raw
history blame
2.8 kB
import gradio as gr
from NeuralTextGenerator import BertTextGenerator
# --- Model setup -----------------------------------------------------------
# Three generators are built at import time: a pre-trained baseline and two
# fine-tuned checkpoints. Each BertTextGenerator exposes a `.tokenizer` and a
# `.model` attribute, both of which are used below.

# Baseline checkpoint (an alternative that was tried:
# "dbmdz/bert-base-italian-uncased").
model_name = "cardiffnlp/twitter-xlm-roberta-base"
en_model = BertTextGenerator(model_name, tokenizer=model_name)

# Fine-tuned BERT checkpoint; note it is paired with the stock
# 'bert-base-uncased' tokenizer rather than one loaded from its own repo.
finetunned_BERT_model_name = "JuanJoseMV/BERT_text_gen"
finetunned_BERT_en_model = BertTextGenerator(
    finetunned_BERT_model_name, tokenizer='bert-base-uncased'
)

# Fine-tuned XLM-RoBERTa checkpoint, using the tokenizer from the same repo.
finetunned_RoBERTa_model_name = "JuanJoseMV/XLM_RoBERTa_text_gen"
finetunned_RoBERTa_en_model = BertTextGenerator(
    finetunned_RoBERTa_model_name, tokenizer=finetunned_RoBERTa_model_name
)

# Sentiment control tokens; sentence_builder prepends three of these
# ([<SENTIMENT>-0..2]) to the seed text.
special_tokens = [
    '[POSITIVE-0]',
    '[POSITIVE-1]',
    '[POSITIVE-2]',
    '[NEGATIVE-0]',
    '[NEGATIVE-1]',
    '[NEGATIVE-2]',
]

# Register the control tokens with every tokenizer and resize each model's
# embedding matrix to the length of ITS OWN tokenizer. Sizing every model
# from en_model's tokenizer is wrong for the BERT generator, which uses a
# different tokenizer ('bert-base-uncased').
for _generator in (en_model, finetunned_BERT_en_model, finetunned_RoBERTa_en_model):
    _generator.tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})
    _generator.model.resize_token_embeddings(len(_generator.tokenizer))
def sentence_builder(selected_model, n_sentences, max_iter, sentiment, seed_text):
    """Generate tweets with the selected model and format them as one string.

    Args:
        selected_model: Model choice coming from the UI. "Finetuned_RoBERTA"
            selects the fine-tuned XLM-RoBERTa generator; "Finetuned_BERT"
            (or the misspelled "Finetunned_BERT") selects the fine-tuned BERT
            generator; any other value selects the pre-trained generator.
        n_sentences: Number of tweets to generate.
        max_iter: Maximum number of iterations for the generation.
        sentiment: Sentiment name used to build the three control tokens
            "[<sentiment>-0] [<sentiment>-1] [<sentiment>-2]" that are
            prepended to the seed text.
        seed_text: User-provided text appended after the control tokens.

    Returns:
        A string with one "- GENERATED TWEET #<index>: <text>" line per
        generated sentence, each terminated by a newline. Indices start at 0.
    """
    if selected_model == "Finetuned_RoBERTA":
        generator = finetunned_RoBERTa_en_model
    elif selected_model in ("Finetuned_BERT", "Finetunned_BERT"):
        # Accept both spellings so that either one reaches the fine-tuned
        # BERT model instead of silently falling back to the pre-trained one.
        generator = finetunned_BERT_en_model
    else:
        generator = en_model

    parameters = {
        # UI sliders can deliver numeric values as floats; these two are
        # counts, so coerce them defensively.
        # NOTE(review): presumably the generator expects ints here - confirm.
        'n_sentences': int(n_sentences),
        'batch_size': 2,
        'avg_len': 30,
        'max_len': 50,
        'generation_method': 'parallel',
        'sample': True,
        'burnin': 450,
        'max_iter': int(max_iter),
        'top_k': 100,
        # Control tokens first, then the user's seed text.
        'seed_text': f"[{sentiment}-0] [{sentiment}-1] [{sentiment}-2] {seed_text}",
        'verbose': True,
    }
    sents = generator.generate(**parameters)

    # Build the report in one pass rather than with repeated concatenation.
    return ''.join(
        f'- GENERATED TWEET #{i}: {s}\n' for i, s in enumerate(sents)
    )
# Gradio UI. The inputs are passed positionally to sentence_builder in this
# order: model choice, number of tweets, max iterations, sentiment, seed text.
demo = gr.Interface(
    sentence_builder,
    [
        # Model selector. The choice strings must match the ones compared in
        # sentence_builder; anything it does not recognise falls back to the
        # pre-trained model.
        gr.Radio(
            ["Pre-trained", "Finetuned_RoBERTA", "Finetuned_BERT"],
            value="Pre-trained",
            label="Model",
        ),
        gr.Slider(1, 15, value=2, label="Num. Tweets", step=1, info="Number of tweets to be generated."),
        # step=1 keeps the iteration count on whole numbers, like the slider above.
        gr.Slider(50, 500, value=100, step=1, label="Max. iter", info="Maximum number of iterations for the generation."),
        gr.Radio(["POSITIVE", "NEGATIVE"], value="POSITIVE", label="Sentiment to generate"),
        gr.Textbox('', label="Seed text", info="Seed text for the generation."),
    ],
    "text",
)
demo.launch()