pendar02 committed on
Commit
2d9eebc
·
verified ·
1 Parent(s): 1229bf2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -15
app.py CHANGED
@@ -40,49 +40,42 @@ def load_model(model_type):
40
  manage_resources()
41
 
42
  try:
43
- # Set lower precision to reduce memory usage
44
- torch_dtype = torch.float32
45
- if torch.cuda.is_available():
46
- device = "cuda"
47
- else:
48
- device = "cpu"
49
- torch_dtype = torch.float32 # Use float32 for CPU
50
-
51
  if model_type == "summarize":
52
  base_model = AutoModelForSeq2SeqLM.from_pretrained(
53
  "facebook/bart-large-cnn",
54
  cache_dir="./models",
55
- torch_dtype=torch_dtype,
56
  low_cpu_mem_usage=True
57
  )
58
  model = PeftModel.from_pretrained(
59
  base_model,
60
  "pendar02/results",
61
- device_map=device,
62
- torch_dtype=torch_dtype
63
  )
64
  tokenizer = AutoTokenizer.from_pretrained(
65
  "facebook/bart-large-cnn",
66
  cache_dir="./models"
67
  )
68
  else: # question_focused
69
- base_model = AutoModelForSeq2SeqLM.from_pretrained(
70
  "GanjinZero/biobart-base",
71
  cache_dir="./models",
72
- torch_dtype=torch_dtype,
73
  low_cpu_mem_usage=True
74
  )
75
  model = PeftModel.from_pretrained(
76
  base_model,
77
  "pendar02/biobart-finetune",
78
- device_map=device,
79
- torch_dtype=torch_dtype
80
  )
81
  tokenizer = AutoTokenizer.from_pretrained(
82
  "GanjinZero/biobart-base",
83
  cache_dir="./models"
84
  )
85
 
 
 
86
  model.eval()
87
  return model, tokenizer
88
  except Exception as e:
 
40
  manage_resources()
41
 
42
  try:
43
+ # For CPU-only environment, don't use device_map
 
 
 
 
 
 
 
44
  if model_type == "summarize":
45
  base_model = AutoModelForSeq2SeqLM.from_pretrained(
46
  "facebook/bart-large-cnn",
47
  cache_dir="./models",
48
+ torch_dtype=torch.float32,
49
  low_cpu_mem_usage=True
50
  )
51
  model = PeftModel.from_pretrained(
52
  base_model,
53
  "pendar02/results",
54
+ torch_dtype=torch.float32
 
55
  )
56
  tokenizer = AutoTokenizer.from_pretrained(
57
  "facebook/bart-large-cnn",
58
  cache_dir="./models"
59
  )
60
  else: # question_focused
61
+ base_model = AutoModelForSeq2SeqLation_model = AutoModelForSeq2SeqLM.from_pretrained(
62
  "GanjinZero/biobart-base",
63
  cache_dir="./models",
64
+ torch_dtype=torch.float32,
65
  low_cpu_mem_usage=True
66
  )
67
  model = PeftModel.from_pretrained(
68
  base_model,
69
  "pendar02/biobart-finetune",
70
+ torch_dtype=torch.float32
 
71
  )
72
  tokenizer = AutoTokenizer.from_pretrained(
73
  "GanjinZero/biobart-base",
74
  cache_dir="./models"
75
  )
76
 
77
+ # Ensure model is on CPU
78
+ model = model.cpu()
79
  model.eval()
80
  return model, tokenizer
81
  except Exception as e: