|
import os |
|
def _ensure_packages(packages):
    """Install any of *packages* that cannot be imported by this interpreter.

    Each entry is used both as the importable module name and as the pip
    distribution name, which holds for the packages this app needs.

    Args:
        packages: Iterable of module/distribution names.

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    import importlib.util
    import subprocess
    import sys

    # Skip pip entirely when everything is already importable. This script is
    # re-executed on every Streamlit interaction, so this check runs often.
    missing = [name for name in packages if importlib.util.find_spec(name) is None]
    if missing:
        # Use ``sys.executable -m pip`` so packages land in the interpreter that
        # is actually running this app, not whatever ``pip`` is first on PATH.
        subprocess.run([sys.executable, "-m", "pip", "install", *missing], check=True)


_ensure_packages(("streamlit", "transformers", "torch"))
|
|
|
import streamlit as st |
|
from transformers import BartTokenizer, BartForConditionalGeneration |
|
|
|
|
|
model_name = 'ahmadrocks/facebook_bart_base_new'


@st.cache_resource
def _load_summarizer(name):
    """Load the BART tokenizer and model for the checkpoint *name*.

    Decorated with ``st.cache_resource`` so the weights are loaded once and
    reused across script reruns, instead of being reloaded on every widget
    interaction.

    Args:
        name: Hugging Face model identifier or local path.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    return (
        BartTokenizer.from_pretrained(name),
        BartForConditionalGeneration.from_pretrained(name),
    )


tokenizer, model = _load_summarizer(model_name)
|
|
|
def summarize_text(
    text,
    *,
    max_length=150,
    min_length=30,
    length_penalty=2.0,
    num_beams=4,
):
    """Summarize *text* with the module-level BART ``tokenizer`` and ``model``.

    The input is tokenized with truncation enabled, then a summary is generated
    with beam search and decoded without special tokens.

    Args:
        text: The passage to summarize.
        max_length: Maximum length of the generated summary, in tokens.
        min_length: Minimum length of the generated summary, in tokens.
        length_penalty: Length penalty forwarded to ``model.generate``.
        num_beams: Number of beams used for beam search.

    Returns:
        The decoded summary string, or ``""`` if *text* is empty or
        whitespace-only.
    """
    # Nothing to summarize: otherwise the model would be forced to emit at
    # least ``min_length`` tokens of invented text.
    if not text or not text.strip():
        return ""

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding="longest")

    summary_ids = model.generate(
        inputs["input_ids"],
        # Pass the mask explicitly rather than letting generate() infer it.
        attention_mask=inputs["attention_mask"],
        max_length=max_length,
        min_length=min_length,
        length_penalty=length_penalty,
        num_beams=num_beams,
        early_stopping=True,
    )

    # A single input string gives a batch of one; decode its only sequence.
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
|
|
# --- Page layout: title, instructions, input box, and Summarize button. ---
st.title("Text Summarization with Fine-Tuned Model")
st.write("Enter text to generate a summary using the fine-tuned summarization model.")

text = st.text_area("Input Text", height=200)

if st.button("Summarize"):
    # Whitespace-only input counts as empty, so the model is never run on it.
    if text and text.strip():
        with st.spinner("Summarizing..."):
            summary = summarize_text(text)
        st.success("Summary Generated")
        st.write(summary)
    else:
        st.warning("Please enter some text to summarize.")
|
|
|
# The original guard used ``_name_``/``_main_`` (single underscores), which
# raised NameError on every run; the dunder names are the intended check.
if __name__ == "__main__":
    from streamlit.errors import StreamlitAPIException

    # Best-effort: newer Streamlit releases no longer recognize this option,
    # and st.set_option raises StreamlitAPIException for unknown keys. The app
    # does not use st.file_uploader, so skipping it is harmless.
    try:
        st.set_option('deprecation.showfileUploaderEncoding', False)
    except StreamlitAPIException:
        pass

    # Fixed ``center.`` -> ``center;`` (the period made the declaration invalid).
    # NOTE(review): ``.reportview-container`` is a class name from older
    # Streamlit releases; confirm the selector still matches the current DOM.
    st.markdown(
        "<style>"
        ".reportview-container { flex-direction: row; justify-content: center; }"
        "</style>",
        unsafe_allow_html=True,
    )