Update app.py
app.py CHANGED
@@ -30,22 +30,23 @@ def load_model(model_type):
     """Load appropriate model based on type with proper memory management"""
     try:
         # Clear any existing cached data
-        torch.cuda.empty_cache()
         gc.collect()

+        device = "cpu"  # Force CPU usage
+
         if model_type == "summarize":
             base_model = AutoModelForSeq2SeqLM.from_pretrained(
                 "facebook/bart-large-cnn",
                 cache_dir="./models",
                 low_cpu_mem_usage=True,
-
+                device_map={"": device}
             )
             model = PeftModel.from_pretrained(
                 base_model,
                 "pendar02/results",
-                device_map="
+                device_map={"": device},
                 torch_dtype=torch.float32
-            )
+            ).to(device)
             tokenizer = AutoTokenizer.from_pretrained(
                 "facebook/bart-large-cnn",
                 cache_dir="./models"
@@ -55,14 +56,14 @@ def load_model(model_type):
                 "GanjinZero/biobart-base",
                 cache_dir="./models",
                 low_cpu_mem_usage=True,
-
+                device_map={"": device}
             )
             model = PeftModel.from_pretrained(
                 base_model,
                 "pendar02/biobart-finetune",
-                device_map="
+                device_map={"": device},
                 torch_dtype=torch.float32
-            )
+            ).to(device)
             tokenizer = AutoTokenizer.from_pretrained(
                 "GanjinZero/biobart-base",
                 cache_dir="./models"
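For reference, the load_model change reduces to the pattern below: pin both the base model and the PEFT adapter to a single device with device_map={"": device} (the "" key maps the entire module tree), then call .to(device) as a final safety net. This is a minimal sketch assuming recent transformers and peft releases; the model and adapter names come from the diff, while the imports and surrounding wiring are reconstructed, not the Space's exact code.

# Sketch: CPU-pinned base model + LoRA adapter (repo names from the diff).
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import PeftModel

device = "cpu"  # Force CPU usage, as in the commit

base_model = AutoModelForSeq2SeqLM.from_pretrained(
    "facebook/bart-large-cnn",
    cache_dir="./models",
    low_cpu_mem_usage=True,
    device_map={"": device},  # "" = place the whole model on this device
)
model = PeftModel.from_pretrained(
    base_model,
    "pendar02/results",        # adapter repo from the diff
    device_map={"": device},
    torch_dtype=torch.float32,
).to(device)                   # catch any module left on another device
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn", cache_dir="./models")

With everything pinned to CPU there is no CUDA state to manage, which is presumably why the commit also drops the torch.cuda.empty_cache() call.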
@@ -137,6 +138,7 @@ def generate_summary(text, model, tokenizer):
     min_length = min(50, word_count)  # Dynamic min length

     inputs = tokenizer(formatted_text, return_tensors="pt", max_length=1024, truncation=True)
+    inputs = {k: v.to(model.device) for k, v in inputs.items()}

     with torch.no_grad():
         summary_ids = model.generate(
@@ -167,6 +169,7 @@ def generate_focused_summary(question, abstracts, model, tokenizer):
     combined_input = f"Question: {question} Abstracts: " + " [SEP] ".join(formatted_abstracts)

     inputs = tokenizer(combined_input, return_tensors="pt", max_length=1024, truncation=True)
+    inputs = {k: v.to(model.device) for k, v in inputs.items()}

     with torch.no_grad():
         summary_ids = model.generate(
|