sabssag commited on
Commit
c487204
·
verified ·
1 Parent(s): f68da3f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import streamlit as st
2
- from transformers import BartTokenizer, BartForConditionalGeneration, pipeline
3
  import zipfile
4
  import os
5
  import nltk
@@ -8,7 +8,7 @@ import nltk
8
  nltk.download('punkt')
9
  from nltk.tokenize import sent_tokenize
10
 
11
- # Define the path to the saved model zip file
12
  zip_model_path = 'bart_model-20240724T051306Z-001.zip'
13
 
14
  # Define the directory to extract the model
@@ -19,7 +19,7 @@ with zipfile.ZipFile(zip_model_path, 'r') as zip_ref:
19
  zip_ref.extractall(model_dir)
20
 
21
  # After unzipping, the model should be in a specific directory, check the directory structure
22
- model_path = os.path.join(model_dir, 'Bart')
23
 
24
  # Print out the model_path for debugging
25
  print("Model Path:", model_path)
@@ -28,9 +28,9 @@ print("Model Path:", model_path)
28
  if not os.path.exists(model_path):
29
  st.error(f"Model directory {model_path} does not exist or is incorrect.")
30
  else:
31
- # Load the tokenizer and model
32
- tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
33
- model = BartForConditionalGeneration.from_pretrained(model_path)
34
 
35
  # Create a summarization pipeline
36
  summarizer = pipeline("summarization", model=model, tokenizer=tokenizer)
 
1
  import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
3
  import zipfile
4
  import os
5
  import nltk
 
8
  nltk.download('punkt')
9
  from nltk.tokenize import sent_tokenize
10
 
11
+ # Define the path to the saved model zip file (ensure there is no extra space)
12
  zip_model_path = 'bart_model-20240724T051306Z-001.zip'
13
 
14
  # Define the directory to extract the model
 
19
  zip_ref.extractall(model_dir)
20
 
21
  # After unzipping, the model should be in a specific directory, check the directory structure
22
+ model_path = os.path.join(model_dir, 'Bart_model')
23
 
24
  # Print out the model_path for debugging
25
  print("Model Path:", model_path)
 
28
  if not os.path.exists(model_path):
29
  st.error(f"Model directory {model_path} does not exist or is incorrect.")
30
  else:
31
+ # Load the tokenizer and model using AutoTokenizer and AutoModelForSeq2SeqLM
32
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
33
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
34
 
35
  # Create a summarization pipeline
36
  summarizer = pipeline("summarization", model=model, tokenizer=tokenizer)