pendar02 commited on
Commit
c620786
·
verified ·
1 Parent(s): 2ad2c86

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -10
app.py CHANGED
@@ -38,15 +38,15 @@ def load_model(model_type):
38
  base_model = AutoModelForSeq2SeqLM.from_pretrained(
39
  "facebook/bart-large-cnn",
40
  cache_dir="./models",
41
- low_cpu_mem_usage=True,
42
- device_map={"": device}
43
- )
44
  model = PeftModel.from_pretrained(
45
  base_model,
46
  "pendar02/results",
47
- device_map={"": device},
48
- torch_dtype=torch.float32
49
  ).to(device)
 
50
  tokenizer = AutoTokenizer.from_pretrained(
51
  "facebook/bart-large-cnn",
52
  cache_dir="./models"
@@ -55,15 +55,15 @@ def load_model(model_type):
55
  base_model = AutoModelForSeq2SeqLM.from_pretrained(
56
  "GanjinZero/biobart-base",
57
  cache_dir="./models",
58
- low_cpu_mem_usage=True,
59
- device_map={"": device}
60
- )
61
  model = PeftModel.from_pretrained(
62
  base_model,
63
  "pendar02/biobart-finetune",
64
- device_map={"": device},
65
- torch_dtype=torch.float32
66
  ).to(device)
 
67
  tokenizer = AutoTokenizer.from_pretrained(
68
  "GanjinZero/biobart-base",
69
  cache_dir="./models"
@@ -74,6 +74,12 @@ def load_model(model_type):
74
  except Exception as e:
75
  st.error(f"Error loading model: {str(e)}")
76
  raise
 
 
 
 
 
 
77
 
78
  def cleanup_model(model, tokenizer):
79
  """Properly cleanup model resources"""
 
38
  base_model = AutoModelForSeq2SeqLM.from_pretrained(
39
  "facebook/bart-large-cnn",
40
  cache_dir="./models",
41
+ torch_dtype=torch.float32
42
+ ).to(device)
43
+
44
  model = PeftModel.from_pretrained(
45
  base_model,
46
  "pendar02/results",
47
+ is_trainable=False
 
48
  ).to(device)
49
+
50
  tokenizer = AutoTokenizer.from_pretrained(
51
  "facebook/bart-large-cnn",
52
  cache_dir="./models"
 
55
  base_model = AutoModelForSeq2SeqLM.from_pretrained(
56
  "GanjinZero/biobart-base",
57
  cache_dir="./models",
58
+ torch_dtype=torch.float32
59
+ ).to(device)
60
+
61
  model = PeftModel.from_pretrained(
62
  base_model,
63
  "pendar02/biobart-finetune",
64
+ is_trainable=False
 
65
  ).to(device)
66
+
67
  tokenizer = AutoTokenizer.from_pretrained(
68
  "GanjinZero/biobart-base",
69
  cache_dir="./models"
 
74
  except Exception as e:
75
  st.error(f"Error loading model: {str(e)}")
76
  raise
77
+
78
+ model.eval()
79
+ return model, tokenizer
80
+ except Exception as e:
81
+ st.error(f"Error loading model: {str(e)}")
82
+ raise
83
 
84
  def cleanup_model(model, tokenizer):
85
  """Properly cleanup model resources"""