pendar02 commited on
Commit
dde1577
·
verified ·
1 Parent(s): 2d9eebc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -15
app.py CHANGED
@@ -40,43 +40,48 @@ def load_model(model_type):
40
  manage_resources()
41
 
42
  try:
43
- # For CPU-only environment, don't use device_map
44
  if model_type == "summarize":
45
  base_model = AutoModelForSeq2SeqLM.from_pretrained(
46
  "facebook/bart-large-cnn",
47
  cache_dir="./models",
48
- torch_dtype=torch.float32,
49
- low_cpu_mem_usage=True
50
- )
 
51
  model = PeftModel.from_pretrained(
52
  base_model,
53
  "pendar02/results",
54
- torch_dtype=torch.float32
55
- )
 
 
 
56
  tokenizer = AutoTokenizer.from_pretrained(
57
  "facebook/bart-large-cnn",
58
  cache_dir="./models"
59
  )
60
  else: # question_focused
61
- base_model = AutoModelForSeq2SeqLation_model = AutoModelForSeq2SeqLM.from_pretrained(
62
  "GanjinZero/biobart-base",
63
  cache_dir="./models",
64
- torch_dtype=torch.float32,
65
- low_cpu_mem_usage=True
66
- )
 
67
  model = PeftModel.from_pretrained(
68
  base_model,
69
  "pendar02/biobart-finetune",
70
- torch_dtype=torch.float32
71
- )
 
 
 
72
  tokenizer = AutoTokenizer.from_pretrained(
73
  "GanjinZero/biobart-base",
74
  cache_dir="./models"
75
  )
76
 
77
- # Ensure model is on CPU
78
- model = model.cpu()
79
- model.eval()
80
  return model, tokenizer
81
  except Exception as e:
82
  st.error(f"Error loading model: {str(e)}")
 
40
  manage_resources()
41
 
42
  try:
 
43
  if model_type == "summarize":
44
  base_model = AutoModelForSeq2SeqLM.from_pretrained(
45
  "facebook/bart-large-cnn",
46
  cache_dir="./models",
47
+ device_map=None, # Explicitly set to None for CPU
48
+ torch_dtype=torch.float32
49
+ ).to("cpu") # Force CPU
50
+
51
  model = PeftModel.from_pretrained(
52
  base_model,
53
  "pendar02/results",
54
+ device_map=None, # Explicitly set to None for CPU
55
+ torch_dtype=torch.float32,
56
+ is_trainable=False # Set to inference mode
57
+ ).to("cpu") # Force CPU
58
+
59
  tokenizer = AutoTokenizer.from_pretrained(
60
  "facebook/bart-large-cnn",
61
  cache_dir="./models"
62
  )
63
  else: # question_focused
64
+ base_model = AutoModelForSeq2SeqLM.from_pretrained(
65
  "GanjinZero/biobart-base",
66
  cache_dir="./models",
67
+ device_map=None, # Explicitly set to None for CPU
68
+ torch_dtype=torch.float32
69
+ ).to("cpu") # Force CPU
70
+
71
  model = PeftModel.from_pretrained(
72
  base_model,
73
  "pendar02/biobart-finetune",
74
+ device_map=None, # Explicitly set to None for CPU
75
+ torch_dtype=torch.float32,
76
+ is_trainable=False # Set to inference mode
77
+ ).to("cpu") # Force CPU
78
+
79
  tokenizer = AutoTokenizer.from_pretrained(
80
  "GanjinZero/biobart-base",
81
  cache_dir="./models"
82
  )
83
 
84
+ model.eval() # Set to evaluation mode
 
 
85
  return model, tokenizer
86
  except Exception as e:
87
  st.error(f"Error loading model: {str(e)}")