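# Gradio demo: given a paper title and abstract, extract keyphrases with KeyBERT and use
# the Guanaco-7B model (LLaMA-7B base with Tim Dettmers' QLoRA adapters) to generate a
# keyphrase elaboration and a plain-language summary.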
import gradio as gr
from nltk.tokenize import sent_tokenize
import torch
import ujson as json
from transformers import AutoModelForCausalLM, LlamaTokenizer, BitsAndBytesConfig
from peft import PeftModel
from keybert import KeyBERT
from keyphrase_vectorizers import KeyphraseCountVectorizer
import nltk

nltk.download('punkt')
# Load the Guanaco 7B model (LLaMA-7B base + QLoRA adapters) - takes around 2-3 minutes.
model_name = "decapoda-research/llama-7b-hf"
adapters_name = 'timdettmers/guanaco-7b'

# Run on GPU when available; bfloat16 generation on CPU works but is very slow.
device = "cuda" if torch.cuda.is_available() else "cpu"

m = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16)
m = PeftModel.from_pretrained(m, adapters_name)
m = m.merge_and_unload()
m = m.to(device)

tok = LlamaTokenizer.from_pretrained(model_name)
tok.bos_token_id = 1
stop_token_ids = [0]

print('Guanaco model loaded into memory.')
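# Both generation paths below wrap the user request in the Guanaco-style chat template
# ("### Human: ... ### Assistant:") used for the adapters' instruction tuning, generate a
# continuation, and keep only the text after "### Assistant: " in the decoded output.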
def keyphraseElaboration(title, abstract, userGivenKeyphrases, maxTokensElaboration, numAbstractSentencesKeyphrase):
    numKeywordsToExtract = 2
    # If the document has one or fewer abstract sentences, it is too short for keyphrase
    # extraction/elaboration to give a meaningful output.
    tooShort = True
    keyBERTKeywords = []
    if userGivenKeyphrases == "":
        '''
        Process Abstract (strip the word "abstract" at the front and split into sentences)
        '''
        if abstract.lower()[0:9] == "abstract.":
            abstract = abstract[9:]
        elif abstract.lower()[0:8] == "abstract":
            abstract = abstract[8:]
        abstractSentences = sent_tokenize(abstract)
        numAbstractSentences = len(abstractSentences)
        if numAbstractSentences > 1:
            tooShort = False
            numAbstractSentencesKeyphrase = min(numAbstractSentences, numAbstractSentencesKeyphrase)
            doc = f"{title}. {' '.join(abstractSentences[:numAbstractSentencesKeyphrase])}"
            kw_model = KeyBERT(model="all-MiniLM-L6-v2")
            vectorizer = KeyphraseCountVectorizer()
            keywordsOut = kw_model.extract_keywords(doc, stop_words="english", top_n=numKeywordsToExtract, vectorizer=vectorizer, use_mmr=True)
            keyBERTKeywords = [x[0] for x in keywordsOut]
            for entry in keyBERTKeywords:
                print(entry)

    keywordString = ""
    if userGivenKeyphrases != "":
        keywordString = userGivenKeyphrases
    elif not tooShort:
        keywordString = ', '.join(keyBERTKeywords)

    response = ""
    if keywordString != "":
        prompt = "What is the purpose of studying " + keywordString + "? Comment on areas of application."
        formatted_prompt = (
            f"A chat between a curious human and an artificial intelligence assistant. "
            f"The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
            f"### Human: {prompt} \n"
            f"### Assistant:"
        )
        inputs = tok(formatted_prompt, return_tensors="pt").to(m.device)
        outputs = m.generate(inputs=inputs.input_ids, max_new_tokens=maxTokensElaboration)
        output = tok.decode(outputs[0], skip_special_tokens=True)
        # Keep only the assistant's reply, trimmed to the last full sentence.
        index_response = output.find("### Assistant: ") + 15
        end_response = output.rfind('.') + 1
        response = output[index_response:end_response]
    return keywordString, response
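# Builds a plain-language-summary request from the title and the first
# numAbstractSentencesSummary sentences of the abstract, then generates with Guanaco.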
def plainLanguageSummary(title, abstract, maxTokensSummary, numAbstractSentencesSummary):
    '''
    Process Abstract (strip the word "abstract" at the front and split into sentences)
    '''
    if abstract.lower()[0:9] == "abstract.":
        abstract = abstract[9:]
    elif abstract.lower()[0:8] == "abstract":
        abstract = abstract[8:]
    abstractSentences = sent_tokenize(abstract)

    '''
    This is for summarization
    '''
    prompt = """
    Can you explain the main idea of what is being studied in the following paragraph for someone who is not familiar with the topic? Comment on areas of application:
    """
    numAbstractSentences = len(abstractSentences)
    numAbstractSentencesSummary = min(numAbstractSentences, numAbstractSentencesSummary)
    text = f"{title}. {' '.join(abstractSentences[:numAbstractSentencesSummary])}"

    formatted_prompt = (
        f"A chat between a curious human and an artificial intelligence assistant. "
        f"The assistant gives helpful, detailed, and polite answers to the user's questions.\n"
        f"### Human: {prompt + text} \n"
        f"### Assistant:"
    )
    inputs = tok(formatted_prompt, return_tensors="pt").to(m.device)
    outputs = m.generate(inputs=inputs.input_ids, max_new_tokens=maxTokensSummary)
    output = tok.decode(outputs[0], skip_special_tokens=True)
    # Keep only the assistant's reply; skip a leading "Certainly!"/"Certainly," filler
    # and trim to the last full sentence.
    index_response = output.find("### Assistant: ") + 15
    if output[index_response:index_response + 10] in ("Certainly!", "Certainly,"):
        index_response += 10
    end_response = output.rfind('.') + 1
    response = output[index_response:end_response]
    return response
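# UI: the left column takes the title, abstract, optional user keyphrases, and generation
# parameters; the right column shows the extracted keyphrases, their elaboration, and the
# plain-language summary.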
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            title = gr.Textbox(label="Title")
            abstract = gr.Textbox(label="Abstract")
            userDefinedKeyphrases = gr.Textbox(label="Your keyphrases (Optional - Model will elaborate on these keyphrases without using the title or abstract)")
            keyphraseButton = gr.Button("Generate Keyphrase Elaboration")
            summaryButton = gr.Button("Generate Plain Language Summary")
            with gr.Accordion(label="Parameters", open=False):
                maxTokensElaboration = gr.Slider(
                    label="Maximum Number of Elaboration Tokens",
                    value=500,
                    minimum=0,
                    maximum=2048,
                    step=10,
                    interactive=True,
                    info="Length of Keyphrase Elaboration",
                )
                maxTokensSummary = gr.Slider(
                    label="Maximum Number of Summary Tokens",
                    value=300,
                    minimum=0,
                    maximum=2048,
                    step=10,
                    interactive=True,
                    info="Length of Plain Language Summary",
                )
                numAbstractSentencesKeyphrase = gr.Slider(
                    label="Number of Abstract Sentences to use for Keyphrase Extraction",
                    value=2,
                    minimum=0,
                    maximum=20,
                    step=1,
                    interactive=True,
                    info="Default: use first two sentences of abstract.",
                )
                numAbstractSentencesSummary = gr.Slider(
                    label="Number of Abstract Sentences to use for Plain Language Summary",
                    value=2,
                    minimum=0,
                    maximum=20,
                    step=1,
                    interactive=True,
                    info="Default: use first two sentences of abstract.",
                )
        with gr.Column():
            outputKeyphrase = [gr.Textbox(label="Keyphrases"), gr.Textbox(label="Keyphrase Elaboration")]
            outputSummary = gr.Textbox(label="Plain Language Summary")

    keyphraseButton.click(fn=keyphraseElaboration, inputs=[title, abstract, userDefinedKeyphrases, maxTokensElaboration, numAbstractSentencesKeyphrase], outputs=outputKeyphrase)
    summaryButton.click(fn=plainLanguageSummary, inputs=[title, abstract, maxTokensSummary, numAbstractSentencesSummary], outputs=outputSummary)

demo.launch(share=True)
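# share=True asks Gradio to expose a temporary public *.gradio.live link in addition to the
# local server; drop it to keep the demo local-only.
#
# Example of calling the pipeline directly, without the UI (hypothetical inputs):
#   keywords, elaboration = keyphraseElaboration(
#       "Topological Insulators", "Abstract. We study ... . Applications include ...", "", 500, 2)
#   summary = plainLanguageSummary(
#       "Topological Insulators", "Abstract. We study ... . Applications include ...", 300, 2)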