pendar02 committed on
Commit
1bd8049
·
verified ·
1 Parent(s): 957d2c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -9
app.py CHANGED
@@ -27,24 +27,47 @@ def load_model(model_type):
27
  """Load appropriate model based on type"""
28
  try:
29
  if model_type == "summarize":
30
- base_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
31
- # Try loading with ignore_mismatched_sizes=True and other safety parameters
 
 
 
32
  model = PeftModel.from_pretrained(
33
  base_model,
34
  "pendar02/results",
35
- is_trainable=False,
36
- config_kwargs={"inference_mode": True}
 
 
 
 
 
37
  )
38
- tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
39
  else: # question_focused
40
- base_model = AutoModelForSeq2SeqLM.from_pretrained("GanjinZero/biobart-base")
 
 
 
 
41
  model = PeftModel.from_pretrained(
42
  base_model,
43
  "pendar02/biobart-finetune",
44
- is_trainable=False,
45
- config_kwargs={"inference_mode": True}
 
46
  )
47
- tokenizer = AutoTokenizer.from_pretrained("GanjinZero/biobart-base")
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  # Ensure model is in evaluation mode
50
  model.eval()
 
27
  """Load appropriate model based on type"""
28
  try:
29
  if model_type == "summarize":
30
+ base_model = AutoModelForSeq2SeqLM.from_pretrained(
31
+ "facebook/bart-large-cnn",
32
+ cache_dir="./models"
33
+ )
34
+ # Load scientific lay summarizer model
35
  model = PeftModel.from_pretrained(
36
  base_model,
37
  "pendar02/results",
38
+ load_in_8bit=False,
39
+ device_map="auto",
40
+ torch_dtype=torch.float32
41
+ )
42
+ tokenizer = AutoTokenizer.from_pretrained(
43
+ "facebook/bart-large-cnn",
44
+ cache_dir="./models"
45
  )
 
46
  else: # question_focused
47
+ base_model = AutoModelForSeq2SeqLM.from_pretrained(
48
+ "GanjinZero/biobart-base",
49
+ cache_dir="./models"
50
+ )
51
+ # Load biobart fine-tuned model
52
  model = PeftModel.from_pretrained(
53
  base_model,
54
  "pendar02/biobart-finetune",
55
+ load_in_8bit=False,
56
+ device_map="auto",
57
+ torch_dtype=torch.float32
58
  )
59
+ tokenizer = AutoTokenizer.from_pretrained(
60
+ "GanjinZero/biobart-base",
61
+ cache_dir="./models"
62
+ )
63
+
64
+ # Ensure model is in evaluation mode
65
+ model.eval()
66
+ return model, tokenizer
67
+
68
+ except Exception as e:
69
+ st.error(f"Error loading model: {str(e)}")
70
+ raise
71
 
72
  # Ensure model is in evaluation mode
73
  model.eval()