Kevin Hu
commited on
Commit
·
96edfc5
1
Parent(s):
684f1d7
refine xinference (#2521)
Browse files### What problem does this PR solve?
#1588
### Type of change
- [x] Refactoring
- rag/llm/cv_model.py +2 -0
- rag/llm/embedding_model.py +2 -0
- rag/llm/rerank_model.py +2 -0
- rag/llm/sequence2txt_model.py +2 -0
rag/llm/cv_model.py
CHANGED
@@ -449,6 +449,8 @@ class LocalAICV(GptV4):
|
|
449 |
|
450 |
class XinferenceCV(Base):
|
451 |
def __init__(self, key, model_name="", lang="Chinese", base_url=""):
|
|
|
|
|
452 |
self.client = OpenAI(api_key="xxx", base_url=base_url)
|
453 |
self.model_name = model_name
|
454 |
self.lang = lang
|
|
|
449 |
|
450 |
class XinferenceCV(Base):
|
451 |
def __init__(self, key, model_name="", lang="Chinese", base_url=""):
|
452 |
+
if base_url.split("/")[-1] != "v1":
|
453 |
+
base_url = os.path.join(base_url, "v1")
|
454 |
self.client = OpenAI(api_key="xxx", base_url=base_url)
|
455 |
self.model_name = model_name
|
456 |
self.lang = lang
|
rag/llm/embedding_model.py
CHANGED
@@ -268,6 +268,8 @@ class FastEmbed(Base):
|
|
268 |
|
269 |
class XinferenceEmbed(Base):
|
270 |
def __init__(self, key, model_name="", base_url=""):
|
|
|
|
|
271 |
self.client = OpenAI(api_key="xxx", base_url=base_url)
|
272 |
self.model_name = model_name
|
273 |
|
|
|
268 |
|
269 |
class XinferenceEmbed(Base):
|
270 |
def __init__(self, key, model_name="", base_url=""):
|
271 |
+
if base_url.split("/")[-1] != "v1":
|
272 |
+
base_url = os.path.join(base_url, "v1")
|
273 |
self.client = OpenAI(api_key="xxx", base_url=base_url)
|
274 |
self.model_name = model_name
|
275 |
|
rag/llm/rerank_model.py
CHANGED
@@ -140,6 +140,8 @@ class YoudaoRerank(DefaultRerank):
|
|
140 |
|
141 |
class XInferenceRerank(Base):
|
142 |
def __init__(self, key="xxxxxxx", model_name="", base_url=""):
|
|
|
|
|
143 |
self.model_name = model_name
|
144 |
self.base_url = base_url
|
145 |
self.headers = {
|
|
|
140 |
|
141 |
class XInferenceRerank(Base):
|
142 |
def __init__(self, key="xxxxxxx", model_name="", base_url=""):
|
143 |
+
if base_url.split("/")[-1] != "v1":
|
144 |
+
base_url = os.path.join(base_url, "v1")
|
145 |
self.model_name = model_name
|
146 |
self.base_url = base_url
|
147 |
self.headers = {
|
rag/llm/sequence2txt_model.py
CHANGED
@@ -93,6 +93,8 @@ class AzureSeq2txt(Base):
|
|
93 |
|
94 |
class XinferenceSeq2txt(Base):
|
95 |
def __init__(self, key, model_name="", base_url=""):
|
|
|
|
|
96 |
self.client = OpenAI(api_key="xxx", base_url=base_url)
|
97 |
self.model_name = model_name
|
98 |
|
|
|
93 |
|
94 |
class XinferenceSeq2txt(Base):
|
95 |
def __init__(self, key, model_name="", base_url=""):
|
96 |
+
if base_url.split("/")[-1] != "v1":
|
97 |
+
base_url = os.path.join(base_url, "v1")
|
98 |
self.client = OpenAI(api_key="xxx", base_url=base_url)
|
99 |
self.model_name = model_name
|
100 |
|