# "Spaces: Sleeping" — Hugging Face Spaces status banner captured when this
# file was scraped from the web page; kept as a comment so the file parses.
"""Streamlit app: summarize text with a fine-tuned BART model."""
import os

import gdown
import streamlit as st
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Page heading shown at the top of the app.
st.title("Text Summarization with Fine-Tuned BART")
def download_model_from_drive(file_id, dest_path):
    """Download one model file from Google Drive into *dest_path*.

    Args:
        file_id: Google Drive file id (the ``id=`` query parameter).
        dest_path: Local filesystem path the file is written to.

    Shows a success or error message in the Streamlit UI; never raises.
    """
    url = f'https://drive.google.com/uc?id={file_id}'
    try:
        # gdown.download returns the output path on success and None on
        # failure (it does not always raise), so check the return value
        # instead of unconditionally reporting success.
        result = gdown.download(url, dest_path, quiet=False)
        if result is None:
            st.error(f"Error downloading {dest_path}: gdown returned no file")
        else:
            st.success(f"Downloaded {dest_path}")
    except Exception as e:
        st.error(f"Error downloading {dest_path}: {e}")
# Local directory that will hold the downloaded model artifacts.
model_dir = 'model'
os.makedirs(model_dir, exist_ok=True)

# Google Drive file ids for each model component.
file_ids = {
    'model': '1-V2bEtPR9Y3iBXK9zOR-qM5y9hKiQUnF',
    'config': '1-T2etSP_k_3j5LzunWq8viKGQCQ5RMr_',
    'tokenizer': '1-cRYNPWqlNNGRxeztympRRfVuy3hWuMY',
    'vocab': '1-t9AhomeH7YIIpAqCGTok8wjvl0tml0F',
    'merges': '1-l77_KEdK7GBFjMX_6UXGE-ZTGDraaDm',
}

# Filename on disk that transformers expects for each component, in the
# same order the files were originally fetched.
_component_files = {
    'model': 'pytorch_model.bin',
    'config': 'config.json',
    'tokenizer': 'tokenizer.json',
    'vocab': 'vocab.json',
    'merges': 'merges.txt',
}

# Fetch every component while showing a spinner in the UI.
with st.spinner("Downloading model..."):
    for component, filename in _component_files.items():
        download_model_from_drive(file_ids[component],
                                  os.path.join(model_dir, filename))
@st.cache_resource
def load_model_and_tokenizer():
    """Load the tokenizer and seq2seq model from ``model_dir``.

    Decorated with ``st.cache_resource`` so the (expensive) load from disk
    happens once per process instead of on every Streamlit rerun.

    Returns:
        ``(tokenizer, model)`` on success, ``(None, None)`` on failure
        (the error is shown in the UI instead of being raised).
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        return tokenizer, model
    except Exception as e:
        st.error(f"Error loading model or tokenizer: {e}")
        return None, None


tokenizer, model = load_model_and_tokenizer()
# --- Summarization UI ---
input_text = st.text_area("Enter the text to summarize:")
if st.button("Summarize"):
    if input_text:
        if tokenizer and model:
            try:
                # Tokenize, truncating to the model's 512-token input limit.
                inputs = tokenizer(input_text, return_tensors="pt",
                                   max_length=512, truncation=True)
                # Generate without tracking gradients (inference only).
                # Pass the attention mask the tokenizer produced so the
                # encoder ignores padding tokens, instead of dropping it.
                with torch.no_grad():
                    summary_ids = model.generate(
                        inputs['input_ids'],
                        attention_mask=inputs['attention_mask'],
                        max_length=150,
                        num_beams=4,
                        early_stopping=True,
                    )
                # Decode the generated ids back into readable text.
                summary = tokenizer.decode(summary_ids[0],
                                           skip_special_tokens=True)
                st.write(f"Summary: {summary}")
            except Exception as e:
                st.error(f"Error during summarization: {e}")
        else:
            st.error("Model or tokenizer not loaded.")
    else:
        st.write("Please enter some text to summarize.")