黄腾 aopstudio committed
Commit 745354f · 1 Parent(s): fe5404c

Add support for NVIDIA LLM (#1645)


### What problem does this PR solve?

Add support for NVIDIA-hosted models (chat, embedding, rerank and image-to-text) served through the NVIDIA API.
### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <[email protected]>

conf/llm_factories.json CHANGED
@@ -1918,6 +1918,290 @@
          "model_type": "chat"
        }
      ]
+    },
+    {
+      "name": "NVIDIA",
+      "logo": "",
+      "tags": "LLM,TEXT EMBEDDING,TEXT RE-RANK",
+      "status": "1",
+      "llm": [
+        {
+          "llm_name": "nvidia/nemotron-4-340b-reward",
+          "tags": "LLM,CHAT,4K",
+          "max_tokens": 4096,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "aisingapore/sea-lion-7b-instruct",
+          "tags": "LLM,CHAT,4K",
+          "max_tokens": 4096,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "databricks/dbrx-instruct",
+          "tags": "LLM,CHAT,16K",
+          "max_tokens": 16384,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "google/gemma-7b",
+          "tags": "LLM,CHAT,32K",
+          "max_tokens": 32768,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "google/gemma-2b",
+          "tags": "LLM,CHAT,16K",
+          "max_tokens": 16384,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "google/gemma-2-9b-it",
+          "tags": "LLM,CHAT,8K",
+          "max_tokens": 8192,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "google/gemma-2-27b-it",
+          "tags": "LLM,CHAT,8K",
+          "max_tokens": 8192,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "google/recurrentgemma-2b",
+          "tags": "LLM,CHAT,4K",
+          "max_tokens": 4096,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "mediatek/breeze-7b-instruct",
+          "tags": "LLM,CHAT,8K",
+          "max_tokens": 8192,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "meta/llama2-70b",
+          "tags": "LLM,CHAT,4K",
+          "max_tokens": 4096,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "meta/llama3-8b",
+          "tags": "LLM,CHAT,8K",
+          "max_tokens": 8192,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "meta/llama3-70b",
+          "tags": "LLM,CHAT,8K",
+          "max_tokens": 8192,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "microsoft/phi-3-medium-128k-instruct",
+          "tags": "LLM,CHAT,128K",
+          "max_tokens": 131072,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "microsoft/phi-3-medium-4k-instruct",
+          "tags": "LLM,CHAT,4K",
+          "max_tokens": 4096,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "microsoft/phi-3-mini-128k-instruct",
+          "tags": "LLM,CHAT,128K",
+          "max_tokens": 131072,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "microsoft/phi-3-mini-4k-instruct",
+          "tags": "LLM,CHAT,4K",
+          "max_tokens": 4096,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "microsoft/phi-3-small-128k-instruct",
+          "tags": "LLM,CHAT,128K",
+          "max_tokens": 131072,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "microsoft/phi-3-small-8k-instruct",
+          "tags": "LLM,CHAT,8K",
+          "max_tokens": 8192,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "mistralai/mistral-7b-instruct",
+          "tags": "LLM,CHAT,4K",
+          "max_tokens": 4096,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "mistralai/mistral-7b-instruct-v0.3",
+          "tags": "LLM,CHAT,4K",
+          "max_tokens": 4096,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "mistralai/mixtral-8x7b-instruct",
+          "tags": "LLM,CHAT,32K",
+          "max_tokens": 32768,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "mistralai/mixtral-8x22b-instruct",
+          "tags": "LLM,CHAT,64K",
+          "max_tokens": 65536,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "mistralai/mistral-large",
+          "tags": "LLM,CHAT,32K",
+          "max_tokens": 32768,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "nv-mistralai/mistral-nemo-12b-instruct",
+          "tags": "LLM,CHAT,128K",
+          "max_tokens": 131072,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "nvidia/llama3-chatqa-1.5-70b",
+          "tags": "LLM,CHAT,4K",
+          "max_tokens": 4096,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "nvidia/llama3-chatqa-1.5-8b",
+          "tags": "LLM,CHAT,4K",
+          "max_tokens": 4096,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "nvidia/nemotron-4-340b-instruct",
+          "tags": "LLM,CHAT,4K",
+          "max_tokens": 4096,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "seallms/seallm-7b-v2.5",
+          "tags": "LLM,CHAT,4K",
+          "max_tokens": 4096,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "snowflake/arctic",
+          "tags": "LLM,CHAT,4K",
+          "max_tokens": 4096,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "upstage/solar-10.7b-instruct",
+          "tags": "LLM,CHAT,4K",
+          "max_tokens": 4096,
+          "model_type": "chat"
+        },
+        {
+          "llm_name": "baai/bge-m3",
+          "tags": "TEXT EMBEDDING,8K",
+          "max_tokens": 8192,
+          "model_type": "embedding"
+        },
+        {
+          "llm_name": "nvidia/embed-qa-4",
+          "tags": "TEXT EMBEDDING,512",
+          "max_tokens": 512,
+          "model_type": "embedding"
+        },
+        {
+          "llm_name": "nvidia/nv-embed-v1",
+          "tags": "TEXT EMBEDDING,32K",
+          "max_tokens": 32768,
+          "model_type": "embedding"
+        },
+        {
+          "llm_name": "nvidia/nv-embedqa-e5-v5",
+          "tags": "TEXT EMBEDDING,512",
+          "max_tokens": 512,
+          "model_type": "embedding"
+        },
+        {
+          "llm_name": "nvidia/nv-embedqa-mistral-7b-v2",
+          "tags": "TEXT EMBEDDING,512",
+          "max_tokens": 512,
+          "model_type": "embedding"
+        },
+        {
+          "llm_name": "nvidia/nv-rerankqa-mistral-4b-v3",
+          "tags": "RE-RANK,512",
+          "max_tokens": 512,
+          "model_type": "rerank"
+        },
+        {
+          "llm_name": "nvidia/rerank-qa-mistral-4b",
+          "tags": "RE-RANK,512",
+          "max_tokens": 512,
+          "model_type": "rerank"
+        },
+        {
+          "llm_name": "snowflake/arctic-embed-l",
+          "tags": "TEXT EMBEDDING,512",
+          "max_tokens": 512,
+          "model_type": "embedding"
+        },
+        {
+          "llm_name": "adept/fuyu-8b",
+          "tags": "LLM,IMAGE2TEXT,4K",
+          "max_tokens": 4096,
+          "model_type": "image2text"
+        },
+        {
+          "llm_name": "google/deplot",
+          "tags": "LLM,IMAGE2TEXT,4K",
+          "max_tokens": 4096,
+          "model_type": "image2text"
+        },
+        {
+          "llm_name": "google/paligemma",
+          "tags": "LLM,IMAGE2TEXT,4K",
+          "max_tokens": 4096,
+          "model_type": "image2text"
+        },
+        {
+          "llm_name": "liuhaotian/llava-v1.6-34b",
+          "tags": "LLM,IMAGE2TEXT,4K",
+          "max_tokens": 4096,
+          "model_type": "image2text"
+        },
+        {
+          "llm_name": "liuhaotian/llava-v1.6-mistral-7b",
+          "tags": "LLM,IMAGE2TEXT,4K",
+          "max_tokens": 4096,
+          "model_type": "image2text"
+        },
+        {
+          "llm_name": "microsoft/kosmos-2",
+          "tags": "LLM,IMAGE2TEXT,4K",
+          "max_tokens": 4096,
+          "model_type": "image2text"
+        },
+        {
+          "llm_name": "microsoft/phi-3-vision-128k-instruct",
+          "tags": "LLM,IMAGE2TEXT,128K",
+          "max_tokens": 131072,
+          "model_type": "image2text"
+        },
+        {
+          "llm_name": "nvidia/neva-22b",
+          "tags": "LLM,IMAGE2TEXT,4K",
+          "max_tokens": 4096,
+          "model_type": "image2text"
+        }
+      ]
    }
  ]
}
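Not part of the diff — a quick, illustrative way to sanity-check the new factory entry from the repository root. It does not assume the name of the top-level key in `conf/llm_factories.json`; it just takes the first top-level value that is a list of factories:

```python
import json

with open("conf/llm_factories.json", encoding="utf-8") as fh:
    conf = json.load(fh)

# Take the first top-level value that is a list of factories, then pick out
# the "NVIDIA" entry added by this PR.
factories = next(v for v in conf.values() if isinstance(v, list))
nvidia = next(item for item in factories if item["name"] == "NVIDIA")

print(len(nvidia["llm"]), "NVIDIA models registered")
for m in nvidia["llm"]:
    print(m["model_type"], m["llm_name"], m["max_tokens"])
```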
rag/llm/__init__.py CHANGED
@@ -34,7 +34,8 @@ EmbeddingModel = {
    "BAAI": DefaultEmbedding,
    "Mistral": MistralEmbed,
    "Bedrock": BedrockEmbed,
-    "Gemini":GeminiEmbed
+    "Gemini":GeminiEmbed,
+    "NVIDIA":NvidiaEmbed
 }


@@ -48,7 +49,8 @@ CvModel = {
    "Moonshot": LocalCV,
    'Gemini':GeminiCV,
    'OpenRouter':OpenRouterCV,
-    "LocalAI":LocalAICV
+    "LocalAI":LocalAICV,
+    "NVIDIA":NvidiaCV
 }


@@ -71,7 +73,8 @@ ChatModel = {
    "Bedrock": BedrockChat,
    "Groq": GroqChat,
    'OpenRouter':OpenRouterChat,
-    "StepFun":StepFunChat
+    "StepFun":StepFunChat,
+    "NVIDIA":NvidiaChat
 }


@@ -79,7 +82,8 @@ RerankModel = {
    "BAAI": DefaultRerank,
    "Jina": JinaRerank,
    "Youdao": YoudaoRerank,
-    "Xinference": XInferenceRerank
+    "Xinference": XInferenceRerank,
+    "NVIDIA":NvidiaRerank
 }

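For reference (illustrative only, not project code): the new `"NVIDIA"` keys make the factory name resolvable through the same maps as the other providers. The `nvapi-...` key below is a placeholder, and each class falls back to its default NVIDIA endpoint when no `base_url` is passed:

```python
from rag.llm import ChatModel, EmbeddingModel, RerankModel

chat_cls = ChatModel["NVIDIA"]        # NvidiaChat
embed_cls = EmbeddingModel["NVIDIA"]  # NvidiaEmbed
rerank_cls = RerankModel["NVIDIA"]    # NvidiaRerank

chat_mdl = chat_cls("nvapi-...", "meta/llama3-8b")
embed_mdl = embed_cls("nvapi-...", "nvidia/nv-embed-v1")
```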
 
rag/llm/chat_model.py CHANGED
@@ -581,7 +581,6 @@ class MiniMaxChat(Base):
        response = requests.request(
            "POST", url=self.base_url, headers=headers, data=payload
        )
-        print(response, flush=True)
        response = response.json()
        ans = response["choices"][0]["message"]["content"].strip()
        if response["choices"][0]["finish_reason"] == "length":
@@ -902,4 +901,79 @@ class StepFunChat(Base):
    def __init__(self, key, model_name, base_url="https://api.stepfun.com/v1/chat/completions"):
        if not base_url:
            base_url = "https://api.stepfun.com/v1/chat/completions"
-        super().__init__(key, model_name, base_url)
+        super().__init__(key, model_name, base_url)
+
+
+class NvidiaChat(Base):
+    def __init__(
+        self,
+        key,
+        model_name,
+        base_url="https://integrate.api.nvidia.com/v1/chat/completions",
+    ):
+        if not base_url:
+            base_url = "https://integrate.api.nvidia.com/v1/chat/completions"
+        self.base_url = base_url
+        self.model_name = model_name
+        self.api_key = key
+        self.headers = {
+            "accept": "application/json",
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json",
+        }
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        for k in list(gen_conf.keys()):
+            if k not in ["temperature", "top_p", "max_tokens"]:
+                del gen_conf[k]
+        payload = {"model": self.model_name, "messages": history, **gen_conf}
+        try:
+            response = requests.post(
+                url=self.base_url, headers=self.headers, json=payload
+            )
+            response = response.json()
+            ans = response["choices"][0]["message"]["content"].strip()
+            return ans, response["usage"]["total_tokens"]
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        for k in list(gen_conf.keys()):
+            if k not in ["temperature", "top_p", "max_tokens"]:
+                del gen_conf[k]
+        ans = ""
+        total_tokens = 0
+        payload = {
+            "model": self.model_name,
+            "messages": history,
+            "stream": True,
+            **gen_conf,
+        }
+
+        try:
+            response = requests.post(
+                url=self.base_url,
+                headers=self.headers,
+                json=payload,
+            )
+            for resp in response.text.split("\n\n"):
+                if "choices" not in resp:
+                    continue
+                resp = json.loads(resp[6:])
+                if "content" in resp["choices"][0]["delta"]:
+                    text = resp["choices"][0]["delta"]["content"]
+                else:
+                    continue
+                ans += text
+                if "usage" in resp:
+                    total_tokens = resp["usage"]["total_tokens"]
+                yield ans
+
+        except Exception as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield total_tokens
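A minimal usage sketch for the new chat class (the API key is a placeholder; `gen_conf` keys other than `temperature`, `top_p` and `max_tokens` are dropped by the class itself):

```python
mdl = NvidiaChat("nvapi-...", "meta/llama3-70b")

history = [{"role": "user", "content": "Summarize retrieval-augmented generation in one sentence."}]
answer, total_tokens = mdl.chat(
    system="You are a concise assistant.",
    history=history,
    gen_conf={"temperature": 0.2, "max_tokens": 256},
)
print(answer, total_tokens)

# chat_streamly yields the accumulated answer as it grows, then the token count last.
for chunk in mdl.chat_streamly(None, history, {"temperature": 0.2}):
    print(chunk)
```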
rag/llm/cv_model.py CHANGED
@@ -137,7 +137,6 @@ class Base(ABC):
        ]


-
class GptV4(Base):
    def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
        if not base_url: base_url="https://api.openai.com/v1"
@@ -619,3 +618,65 @@ class LocalCV(Base):

    def describe(self, image, max_tokens=1024):
        return "", 0
+
+
+class NvidiaCV(Base):
+    def __init__(
+        self,
+        key,
+        model_name,
+        lang="Chinese",
+        base_url="https://ai.api.nvidia.com/v1/vlm",
+    ):
+        if not base_url:
+            base_url = "https://ai.api.nvidia.com/v1/vlm"
+        self.lang = lang
+        factory, llm_name = model_name.split("/")
+        if factory != "liuhaotian":
+            self.base_url = os.path.join(base_url, factory, llm_name)
+        else:
+            self.base_url = os.path.join(
+                base_url, "community", llm_name.replace("-v1.6", "16")
+            )
+        self.key = key
+
+    def describe(self, image, max_tokens=1024):
+        b64 = self.image2base64(image)
+        response = requests.post(
+            url=self.base_url,
+            headers={
+                "accept": "application/json",
+                "content-type": "application/json",
+                "Authorization": f"Bearer {self.key}",
+            },
+            json={
+                "messages": self.prompt(b64),
+                "max_tokens": max_tokens,
+            },
+        )
+        response = response.json()
+        return (
+            response["choices"][0]["message"]["content"].strip(),
+            response["usage"]["total_tokens"],
+        )
+
+    def prompt(self, b64):
+        return [
+            {
+                "role": "user",
+                "content": (
+                    "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。"
+                    if self.lang.lower() == "chinese"
+                    else "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
+                )
+                + f' <img src="data:image/jpeg;base64,{b64}"/>',
+            }
+        ]
+
+    def chat_prompt(self, text, b64):
+        return [
+            {
+                "role": "user",
+                "content": text + f' <img src="data:image/jpeg;base64,{b64}"/>',
+            }
+        ]
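A minimal sketch for the vision wrapper (placeholder API key; passing raw image bytes assumes `Base.image2base64` accepts them, as it does for the other CV classes):

```python
with open("figure.png", "rb") as fh:
    img = fh.read()

cv_mdl = NvidiaCV("nvapi-...", "microsoft/kosmos-2", lang="English")
caption, total_tokens = cv_mdl.describe(img, max_tokens=512)
print(caption)
```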
rag/llm/embedding_model.py CHANGED
@@ -462,3 +462,41 @@ class GeminiEmbed(Base):
                                     title="Embedding of single string")
        token_count = num_tokens_from_string(text)
        return np.array(result['embedding']),token_count
+
+class NvidiaEmbed(Base):
+    def __init__(
+        self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings"
+    ):
+        if not base_url:
+            base_url = "https://integrate.api.nvidia.com/v1/embeddings"
+        self.api_key = key
+        self.base_url = base_url
+        self.headers = {
+            "accept": "application/json",
+            "Content-Type": "application/json",
+            "authorization": f"Bearer {self.api_key}",
+        }
+        self.model_name = model_name
+        if model_name == "nvidia/embed-qa-4":
+            self.base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings"
+            self.model_name = "NV-Embed-QA"
+        if model_name == "snowflake/arctic-embed-l":
+            self.base_url = "https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l/embeddings"
+
+    def encode(self, texts: list, batch_size=None):
+        payload = {
+            "input": texts,
+            "input_type": "query",
+            "model": self.model_name,
+            "encoding_format": "float",
+            "truncate": "END",
+        }
+        res = requests.post(self.base_url, headers=self.headers, json=payload).json()
+        return (
+            np.array([d["embedding"] for d in res["data"]]),
+            res["usage"]["total_tokens"],
+        )
+
+    def encode_queries(self, text):
+        embds, cnt = self.encode([text])
+        return np.array(embds[0]), cnt
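A minimal sketch for the embedding wrapper (placeholder API key); `encode` returns the embedding matrix plus the token usage reported by the API:

```python
embd_mdl = NvidiaEmbed("nvapi-...", "nvidia/nv-embedqa-e5-v5")

vectors, used_tokens = embd_mdl.encode([
    "What is retrieval-augmented generation?",
    "How do vector databases work?",
])
print(vectors.shape, used_tokens)

query_vec, _ = embd_mdl.encode_queries("what is RAG?")
```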
rag/llm/rerank_model.py CHANGED
@@ -164,3 +164,41 @@ class LocalAIRerank(Base):

    def similarity(self, query: str, texts: list):
        raise NotImplementedError("The LocalAIRerank has not been implement")
+
+
+class NvidiaRerank(Base):
+    def __init__(
+        self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/"
+    ):
+        if not base_url:
+            base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/"
+        self.model_name = model_name
+
+        if self.model_name == "nvidia/nv-rerankqa-mistral-4b-v3":
+            self.base_url = os.path.join(
+                base_url, "nv-rerankqa-mistral-4b-v3", "reranking"
+            )
+
+        if self.model_name == "nvidia/rerank-qa-mistral-4b":
+            self.base_url = os.path.join(base_url, "reranking")
+            self.model_name = "nv-rerank-qa-mistral-4b:1"
+
+        self.headers = {
+            "accept": "application/json",
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {key}",
+        }
+
+    def similarity(self, query: str, texts: list):
+        token_count = num_tokens_from_string(query) + sum(
+            [num_tokens_from_string(t) for t in texts]
+        )
+        data = {
+            "model": self.model_name,
+            "query": {"text": query},
+            "passages": [{"text": text} for text in texts],
+            "truncate": "END",
+            "top_n": len(texts),
+        }
+        res = requests.post(self.base_url, headers=self.headers, json=data).json()
+        return (np.array([d["logit"] for d in res["rankings"]]), token_count)
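A minimal sketch for the reranker (placeholder API key); `similarity` returns one logit per passage together with a locally computed token count:

```python
rerank_mdl = NvidiaRerank("nvapi-...", "nvidia/nv-rerankqa-mistral-4b-v3")

scores, token_count = rerank_mdl.similarity(
    "how do I rotate an API key?",
    [
        "Keys can be rotated from the account settings page.",
        "GPUs accelerate dense matrix multiplication.",
    ],
)
print(scores, token_count)  # one logit per passage, higher means more relevant
```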
web/src/assets/svg/llm/nvidia.svg ADDED
web/src/pages/user-setting/setting-model/constant.ts CHANGED
@@ -20,6 +20,7 @@ export const IconMap = {
  OpenRouter: 'open-router',
  LocalAI: 'local-ai',
  StepFun: 'stepfun',
+  NVIDIA: 'nvidia',
};

export const BedrockRegionList = [