import os

import gdown
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Set the title of the Streamlit app
st.title("Text Summarization with Fine-Tuned BART")

# Download a single model file from Google Drive. Skip files that already
# exist so Streamlit's script reruns don't re-download the model on every
# widget interaction.
def download_model_from_drive(file_id, dest_path):
    if os.path.exists(dest_path):
        return
    url = f"https://drive.google.com/uc?id={file_id}"
    try:
        gdown.download(url, dest_path, quiet=False)
        st.success(f"Downloaded {dest_path}")
    except Exception as e:
        st.error(f"Error downloading {dest_path}: {e}")

# Ensure the model directory exists
model_dir = "model"
os.makedirs(model_dir, exist_ok=True)

# Google Drive file IDs for the model components
file_ids = {
    "model": "1-V2bEtPR9Y3iBXK9zOR-qM5y9hKiQUnF",
    "config": "1-T2etSP_k_3j5LzunWq8viKGQCQ5RMr_",
    "tokenizer": "1-cRYNPWqlNNGRxeztympRRfVuy3hWuMY",
    "vocab": "1-t9AhomeH7YIIpAqCGTok8wjvl0tml0F",
    "merges": "1-l77_KEdK7GBFjMX_6UXGE-ZTGDraaDm",
}

# Map each component to the filename transformers expects in the model directory
model_files = {
    "model": "pytorch_model.bin",
    "config": "config.json",
    "tokenizer": "tokenizer.json",
    "vocab": "vocab.json",
    "merges": "merges.txt",
}

# Download the model files
with st.spinner("Downloading model..."):
    for key, filename in model_files.items():
        download_model_from_drive(file_ids[key], os.path.join(model_dir, filename))

# Load the model and tokenizer once and reuse them across reruns.
# st.cache_resource replaces the deprecated @st.cache(allow_output_mutation=True).
@st.cache_resource
def load_model_and_tokenizer():
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        return tokenizer, model
    except Exception as e:
        st.error(f"Error loading model or tokenizer: {e}")
        return None, None

tokenizer, model = load_model_and_tokenizer()

# Input text from the user
input_text = st.text_area("Enter the text to summarize:")

if st.button("Summarize"):
    if not input_text:
        st.write("Please enter some text to summarize.")
    elif tokenizer is None or model is None:
        st.error("Model or tokenizer not loaded.")
    else:
        try:
            # Tokenize the input, truncating to the model's 512-token limit
            inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)

            # Generate the summary with beam search; no_grad avoids building
            # the autograd graph during inference
            with torch.no_grad():
                summary_ids = model.generate(
                    inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_length=150,
                    num_beams=4,
                    early_stopping=True,
                )

            # Decode the generated token IDs back into text
            summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
            st.write(f"Summary: {summary}")
        except Exception as e:
            st.error(f"Error during summarization: {e}")
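
# A minimal usage sketch (assuming this script is saved as app.py; the
# pip command below is illustrative and versions are not pinned):
#
#   pip install streamlit transformers torch gdown
#   streamlit run app.py
#
# Streamlit serves the app locally (by default at http://localhost:8501);
# paste text into the box and click "Summarize" to see the generated summary.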