Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -38,15 +38,15 @@ def load_model(model_type):
|
|
38 |
base_model = AutoModelForSeq2SeqLM.from_pretrained(
|
39 |
"facebook/bart-large-cnn",
|
40 |
cache_dir="./models",
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
model = PeftModel.from_pretrained(
|
45 |
base_model,
|
46 |
"pendar02/results",
|
47 |
-
|
48 |
-
torch_dtype=torch.float32
|
49 |
).to(device)
|
|
|
50 |
tokenizer = AutoTokenizer.from_pretrained(
|
51 |
"facebook/bart-large-cnn",
|
52 |
cache_dir="./models"
|
@@ -55,15 +55,15 @@ def load_model(model_type):
|
|
55 |
base_model = AutoModelForSeq2SeqLM.from_pretrained(
|
56 |
"GanjinZero/biobart-base",
|
57 |
cache_dir="./models",
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
model = PeftModel.from_pretrained(
|
62 |
base_model,
|
63 |
"pendar02/biobart-finetune",
|
64 |
-
|
65 |
-
torch_dtype=torch.float32
|
66 |
).to(device)
|
|
|
67 |
tokenizer = AutoTokenizer.from_pretrained(
|
68 |
"GanjinZero/biobart-base",
|
69 |
cache_dir="./models"
|
@@ -74,6 +74,12 @@ def load_model(model_type):
|
|
74 |
except Exception as e:
|
75 |
st.error(f"Error loading model: {str(e)}")
|
76 |
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
77 |
|
78 |
def cleanup_model(model, tokenizer):
|
79 |
"""Properly cleanup model resources"""
|
|
|
38 |
base_model = AutoModelForSeq2SeqLM.from_pretrained(
|
39 |
"facebook/bart-large-cnn",
|
40 |
cache_dir="./models",
|
41 |
+
torch_dtype=torch.float32
|
42 |
+
).to(device)
|
43 |
+
|
44 |
model = PeftModel.from_pretrained(
|
45 |
base_model,
|
46 |
"pendar02/results",
|
47 |
+
is_trainable=False
|
|
|
48 |
).to(device)
|
49 |
+
|
50 |
tokenizer = AutoTokenizer.from_pretrained(
|
51 |
"facebook/bart-large-cnn",
|
52 |
cache_dir="./models"
|
|
|
55 |
base_model = AutoModelForSeq2SeqLM.from_pretrained(
|
56 |
"GanjinZero/biobart-base",
|
57 |
cache_dir="./models",
|
58 |
+
torch_dtype=torch.float32
|
59 |
+
).to(device)
|
60 |
+
|
61 |
model = PeftModel.from_pretrained(
|
62 |
base_model,
|
63 |
"pendar02/biobart-finetune",
|
64 |
+
is_trainable=False
|
|
|
65 |
).to(device)
|
66 |
+
|
67 |
tokenizer = AutoTokenizer.from_pretrained(
|
68 |
"GanjinZero/biobart-base",
|
69 |
cache_dir="./models"
|
|
|
74 |
except Exception as e:
|
75 |
st.error(f"Error loading model: {str(e)}")
|
76 |
raise
|
77 |
+
|
78 |
+
model.eval()
|
79 |
+
return model, tokenizer
|
80 |
+
except Exception as e:
|
81 |
+
st.error(f"Error loading model: {str(e)}")
|
82 |
+
raise
|
83 |
|
84 |
def cleanup_model(model, tokenizer):
|
85 |
"""Properly cleanup model resources"""
|