Shining-Data committed on
Commit
957da21
·
verified ·
1 Parent(s): 9f8313e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -8
app.py CHANGED
@@ -79,18 +79,28 @@ def load_pipeline(model_name):
79
  Load and cache a transformers pipeline for text generation.
80
  Tries bfloat16, falls back to float16 or float32 if unsupported.
81
  """
82
- global PIPELINES
83
  if model_name in PIPELINES.keys():
84
  return PIPELINES[model_name]
85
  repo = MODELS[model_name]["repo_id"]
86
- tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
87
- model = AutoModelForCausalLM.from_pretrained(
88
- repo,
89
- device_map=device,
90
- trust_remote_code=True,
91
- )
 
 
 
 
 
 
 
 
 
 
92
  PIPELINES[model_name] = {"tokenizer": tokenizer, "model": model}
93
- return PIPELINES[model_name]
94
 
95
 
96
  def retrieve_context(query, max_results=6, max_chars=600):
 
79
  Load and cache a transformers pipeline for text generation.
80
  Tries bfloat16, falls back to float16 or float32 if unsupported.
81
  """
82
+
83
  if model_name in PIPELINES.keys():
84
  return PIPELINES[model_name]
85
  repo = MODELS[model_name]["repo_id"]
86
+ if model_name == "secgpt-mini":
87
+ tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True, subfolder="models")
88
+ model = AutoModelForCausalLM.from_pretrained(
89
+ repo,
90
+ device_map=device,
91
+ trust_remote_code=True,
92
+ subfolder="models",
93
+ )
94
+ else:
95
+ tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
96
+ model = AutoModelForCausalLM.from_pretrained(
97
+ repo,
98
+ device_map=device,
99
+ trust_remote_code=True,
100
+ )
101
+ global PIPELINES
102
  PIPELINES[model_name] = {"tokenizer": tokenizer, "model": model}
103
+ return {"tokenizer": tokenizer, "model": model}
104
 
105
 
106
  def retrieve_context(query, max_results=6, max_chars=600):