Kevin Hu commited on
Commit
96edfc5
·
1 Parent(s): 684f1d7

refine xinference (#2521)

Browse files

### What problem does this PR solve?

#1588

### Type of change

- [x] Refactoring

rag/llm/cv_model.py CHANGED
@@ -449,6 +449,8 @@ class LocalAICV(GptV4):
449
 
450
  class XinferenceCV(Base):
451
  def __init__(self, key, model_name="", lang="Chinese", base_url=""):
 
 
452
  self.client = OpenAI(api_key="xxx", base_url=base_url)
453
  self.model_name = model_name
454
  self.lang = lang
 
449
 
450
  class XinferenceCV(Base):
451
  def __init__(self, key, model_name="", lang="Chinese", base_url=""):
452
+ if base_url.split("/")[-1] != "v1":
453
+ base_url = os.path.join(base_url, "v1")
454
  self.client = OpenAI(api_key="xxx", base_url=base_url)
455
  self.model_name = model_name
456
  self.lang = lang
rag/llm/embedding_model.py CHANGED
@@ -268,6 +268,8 @@ class FastEmbed(Base):
268
 
269
  class XinferenceEmbed(Base):
270
  def __init__(self, key, model_name="", base_url=""):
 
 
271
  self.client = OpenAI(api_key="xxx", base_url=base_url)
272
  self.model_name = model_name
273
 
 
268
 
269
  class XinferenceEmbed(Base):
270
  def __init__(self, key, model_name="", base_url=""):
271
+ if base_url.split("/")[-1] != "v1":
272
+ base_url = os.path.join(base_url, "v1")
273
  self.client = OpenAI(api_key="xxx", base_url=base_url)
274
  self.model_name = model_name
275
 
rag/llm/rerank_model.py CHANGED
@@ -140,6 +140,8 @@ class YoudaoRerank(DefaultRerank):
140
 
141
  class XInferenceRerank(Base):
142
  def __init__(self, key="xxxxxxx", model_name="", base_url=""):
 
 
143
  self.model_name = model_name
144
  self.base_url = base_url
145
  self.headers = {
 
140
 
141
  class XInferenceRerank(Base):
142
  def __init__(self, key="xxxxxxx", model_name="", base_url=""):
143
+ if base_url.split("/")[-1] != "v1":
144
+ base_url = os.path.join(base_url, "v1")
145
  self.model_name = model_name
146
  self.base_url = base_url
147
  self.headers = {
rag/llm/sequence2txt_model.py CHANGED
@@ -93,6 +93,8 @@ class AzureSeq2txt(Base):
93
 
94
  class XinferenceSeq2txt(Base):
95
  def __init__(self, key, model_name="", base_url=""):
 
 
96
  self.client = OpenAI(api_key="xxx", base_url=base_url)
97
  self.model_name = model_name
98
 
 
93
 
94
  class XinferenceSeq2txt(Base):
95
  def __init__(self, key, model_name="", base_url=""):
96
+ if base_url.split("/")[-1] != "v1":
97
+ base_url = os.path.join(base_url, "v1")
98
  self.client = OpenAI(api_key="xxx", base_url=base_url)
99
  self.model_name = model_name
100