pendar02 committed
Commit 2ad2c86 · verified · 1 Parent(s): 65d5daa

Update app.py

Files changed (1)
  1. app.py +10 -7
app.py CHANGED
@@ -30,22 +30,23 @@ def load_model(model_type):
     """Load appropriate model based on type with proper memory management"""
     try:
         # Clear any existing cached data
-        torch.cuda.empty_cache()
         gc.collect()
 
+        device = "cpu"  # Force CPU usage
+
         if model_type == "summarize":
             base_model = AutoModelForSeq2SeqLM.from_pretrained(
                 "facebook/bart-large-cnn",
                 cache_dir="./models",
                 low_cpu_mem_usage=True,
-                torch_dtype=torch.float32
+                device_map={"": device}
             )
             model = PeftModel.from_pretrained(
                 base_model,
                 "pendar02/results",
-                device_map="auto",
+                device_map={"": device},
                 torch_dtype=torch.float32
-            )
+            ).to(device)
             tokenizer = AutoTokenizer.from_pretrained(
                 "facebook/bart-large-cnn",
                 cache_dir="./models"
@@ -55,14 +56,14 @@ def load_model(model_type):
                 "GanjinZero/biobart-base",
                 cache_dir="./models",
                 low_cpu_mem_usage=True,
-                torch_dtype=torch.float32
+                device_map={"": device}
             )
             model = PeftModel.from_pretrained(
                 base_model,
                 "pendar02/biobart-finetune",
-                device_map="auto",
+                device_map={"": device},
                 torch_dtype=torch.float32
-            )
+            ).to(device)
             tokenizer = AutoTokenizer.from_pretrained(
                 "GanjinZero/biobart-base",
                 cache_dir="./models"
@@ -137,6 +138,7 @@ def generate_summary(text, model, tokenizer):
     min_length = min(50, word_count)  # Dynamic min length
 
     inputs = tokenizer(formatted_text, return_tensors="pt", max_length=1024, truncation=True)
+    inputs = {k: v.to(model.device) for k, v in inputs.items()}
 
     with torch.no_grad():
         summary_ids = model.generate(
@@ -167,6 +169,7 @@ def generate_focused_summary(question, abstracts, model, tokenizer):
     combined_input = f"Question: {question} Abstracts: " + " [SEP] ".join(formatted_abstracts)
 
     inputs = tokenizer(combined_input, return_tensors="pt", max_length=1024, truncation=True)
+    inputs = {k: v.to(model.device) for k, v in inputs.items()}
 
     with torch.no_grad():
         summary_ids = model.generate(
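
The core of this commit is the device handling in load_model: torch.cuda.empty_cache() and device_map="auto" are dropped in favor of pinning everything to the CPU. Passing device_map={"": device} maps the root module (the empty-string key), and therefore every submodule, onto that device, and the trailing .to(device) makes the PEFT wrapper's placement explicit as well. A minimal sketch of the resulting loading path for the "summarize" branch, using the same checkpoints as the diff (the function name load_summarize_model is illustrative, not from app.py):

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import PeftModel

def load_summarize_model():
    """Sketch of the post-commit loading path: all weights pinned to the CPU."""
    device = "cpu"  # Force CPU usage, as in the diff

    # device_map={"": device} assigns the root module ("") -- and hence
    # every submodule -- to the given device, so no CUDA memory is touched.
    base_model = AutoModelForSeq2SeqLM.from_pretrained(
        "facebook/bart-large-cnn",
        cache_dir="./models",
        low_cpu_mem_usage=True,
        device_map={"": device},
    )
    model = PeftModel.from_pretrained(
        base_model,
        "pendar02/results",
        device_map={"": device},
        torch_dtype=torch.float32,
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(
        "facebook/bart-large-cnn",
        cache_dir="./models",
    )
    return model, tokenizer

Unlike device_map="auto", which lets accelerate decide where each layer lands based on available hardware, the explicit {"": "cpu"} dict makes placement deterministic on hosts without a usable GPU.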
 
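
The other half of the change keeps generation inputs on the same device as the model: both generate_summary and generate_focused_summary now move every tokenizer output tensor to model.device before calling generate. A minimal sketch of that pattern (the generate keyword arguments and the final decode are illustrative, since the diff is truncated before them):

import torch

def summarize(formatted_text, model, tokenizer):
    """Sketch of the post-commit generation path: inputs follow model.device."""
    inputs = tokenizer(formatted_text, return_tensors="pt", max_length=1024, truncation=True)
    # Move input_ids, attention_mask, etc. to wherever the model lives;
    # this is a no-op when both already sit on the CPU.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        summary_ids = model.generate(
            **inputs,
            max_length=150,  # illustrative; app.py computes these dynamically
            min_length=50,
        )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

Routing through model.device rather than hard-coding "cpu" means the same helpers keep working if the loader is ever switched back to GPU placement.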