import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import gdown
import os
# Set the title of the Streamlit app
st.title("Text Summarization with Fine-Tuned BART")
# Function to download the model from Google Drive
def download_model_from_drive(file_id, dest_path):
    url = f'https://drive.google.com/uc?id={file_id}'
    try:
        gdown.download(url, dest_path, quiet=False)
        st.success(f"Downloaded {dest_path}")
    except Exception as e:
        st.error(f"Error downloading {dest_path}: {e}")
# Ensure the model directory exists
model_dir = 'model'
os.makedirs(model_dir, exist_ok=True)
# File IDs for your model components
file_ids = {
    'model': '1-V2bEtPR9Y3iBXK9zOR-qM5y9hKiQUnF',
    'config': '1-T2etSP_k_3j5LzunWq8viKGQCQ5RMr_',
    'tokenizer': '1-cRYNPWqlNNGRxeztympRRfVuy3hWuMY',
    'vocab': '1-t9AhomeH7YIIpAqCGTok8wjvl0tml0F',
    'merges': '1-l77_KEdK7GBFjMX_6UXGE-ZTGDraaDm'
}
# Download the model files once; Streamlit reruns this script on every
# interaction, so skip the download if the weights are already present
if not os.path.exists(os.path.join(model_dir, 'pytorch_model.bin')):
    with st.spinner("Downloading model..."):
        download_model_from_drive(file_ids['model'], os.path.join(model_dir, 'pytorch_model.bin'))
        download_model_from_drive(file_ids['config'], os.path.join(model_dir, 'config.json'))
        download_model_from_drive(file_ids['tokenizer'], os.path.join(model_dir, 'tokenizer.json'))
        download_model_from_drive(file_ids['vocab'], os.path.join(model_dir, 'vocab.json'))
        download_model_from_drive(file_ids['merges'], os.path.join(model_dir, 'merges.txt'))
# Load the model and tokenizer, cached so they are loaded only once
# (st.cache_resource replaces the deprecated st.cache(allow_output_mutation=True))
@st.cache_resource
def load_model_and_tokenizer():
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        return tokenizer, model
    except Exception as e:
        st.error(f"Error loading model or tokenizer: {e}")
        return None, None
tokenizer, model = load_model_and_tokenizer()
# Input text from user
input_text = st.text_area("Enter the text to summarize:")
if st.button("Summarize"):
    if input_text:
        if tokenizer and model:
            try:
                # Tokenize the input text, truncating to the model's input limit
                inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
                # Generate the summary without tracking gradients;
                # passing **inputs supplies the attention mask along with the input IDs
                with torch.no_grad():
                    summary_ids = model.generate(**inputs, max_length=150, num_beams=4, early_stopping=True)
                # Decode the generated token IDs back into text
                summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
                st.write(f"Summary: {summary}")
            except Exception as e:
                st.error(f"Error during summarization: {e}")
        else:
            st.error("Model or tokenizer not loaded.")
    else:
        st.warning("Please enter some text to summarize.")