pendar02 commited on
Commit
957d2c2
·
verified ·
1 Parent(s): 52d880e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -9
app.py CHANGED
@@ -25,16 +25,43 @@ if 'text_processor' not in st.session_state:
25
 
26
  def load_model(model_type):
27
  """Load appropriate model based on type"""
28
- if model_type == "summarize":
29
- base_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
30
- model = PeftModel.from_pretrained(base_model, "pendar02/results")
31
- tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
32
- else: # question_focused
33
- base_model = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-base")
34
- model = PeftModel.from_pretrained(base_model, "pendar02/biobart-finetune")
35
- tokenizer = AutoTokenizer.from_pretrained("GanjinZero/biobart-base")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- return model, tokenizer
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  @st.cache_data
40
  def process_excel(uploaded_file):
 
25
 
26
  def load_model(model_type):
27
  """Load appropriate model based on type"""
28
+ try:
29
+ if model_type == "summarize":
30
+ base_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
31
+ # Try loading with ignore_mismatched_sizes=True and other safety parameters
32
+ model = PeftModel.from_pretrained(
33
+ base_model,
34
+ "pendar02/results",
35
+ is_trainable=False,
36
+ config_kwargs={"inference_mode": True}
37
+ )
38
+ tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
39
+ else: # question_focused
40
+ base_model = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-base")
41
+ model = PeftModel.from_pretrained(
42
+ base_model,
43
+ "pendar02/biobart-finetune",
44
+ is_trainable=False,
45
+ config_kwargs={"inference_mode": True}
46
+ )
47
+ tokenizer = AutoTokenizer.from_pretrained("GanjinZero/biobart-base")
48
+
49
+ # Ensure model is in evaluation mode
50
+ model.eval()
51
+ return model, tokenizer
52
 
53
+ except Exception as e:
54
+ # Fallback to base model if PEFT loading fails
55
+ st.warning(f"Error loading PEFT model: {str(e)}. Falling back to base model.")
56
+ if model_type == "summarize":
57
+ model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
58
+ tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
59
+ else:
60
+ model = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-base")
61
+ tokenizer = AutoTokenizer.from_pretrained("GanjinZero/biobart-base")
62
+
63
+ model.eval()
64
+ return model, tokenizer
65
 
66
  @st.cache_data
67
  def process_excel(uploaded_file):