Spaces:
Sleeping
Sleeping
Commit
·
64c61f3
1
Parent(s):
41ac6cc
Add Anthropic Opus support
Browse files
utils.py
CHANGED
|
@@ -24,9 +24,9 @@ from huggingface_hub import hf_hub_download
|
|
| 24 |
|
| 25 |
URL = "http://localhost:5834/v1/chat/completions"
|
| 26 |
in_memory_llm = None
|
| 27 |
-
worker_options = ["runpod", "http", "in_memory", "mistral"]
|
| 28 |
|
| 29 |
-
LLM_WORKER = env.get("LLM_WORKER", "
|
| 30 |
if LLM_WORKER not in worker_options:
|
| 31 |
raise ValueError(f"Invalid worker: {LLM_WORKER}")
|
| 32 |
N_GPU_LAYERS = int(env.get("N_GPU_LAYERS", -1)) # Default to -1, use all layers if available
|
|
@@ -250,11 +250,62 @@ def llm_stream_mistral_api(prompt: str, pydantic_model_class=None, attempts=0) -
|
|
| 250 |
return json.loads(output)
|
| 251 |
|
| 252 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
|
| 254 |
def query_ai_prompt(prompt, replacements, model_class):
|
| 255 |
prompt = replace_text(prompt, replacements)
|
| 256 |
-
if LLM_WORKER == "
|
| 257 |
-
result =
|
| 258 |
if LLM_WORKER == "mistral":
|
| 259 |
result = llm_stream_mistral_api(prompt, model_class)
|
| 260 |
if LLM_WORKER == "runpod":
|
|
|
|
# Endpoint of the local chat-completions server (presumably used by the
# "http" worker — TODO confirm against the rest of the file).
URL = "http://localhost:5834/v1/chat/completions"
# Lazily-initialized in-process model handle (for the "in_memory" worker).
in_memory_llm = None
# Backends that can serve LLM requests; LLM_WORKER must be one of these.
worker_options = ["runpod", "http", "in_memory", "mistral", "anthropic"]

# Selected backend, overridable via the LLM_WORKER environment variable.
LLM_WORKER = env.get("LLM_WORKER", "anthropic")
if LLM_WORKER not in worker_options:
    raise ValueError(f"Invalid worker: {LLM_WORKER}")
N_GPU_LAYERS = int(env.get("N_GPU_LAYERS", -1))  # Default to -1, use all layers if available
|
|
|
|
| 250 |
return json.loads(output)
|
| 251 |
|
| 252 |
|
| 253 |
def send_anthropic_request(prompt: str):
    """Send a single-turn prompt to the Anthropic Messages API and return the reply text.

    Args:
        prompt: The user message to send.

    Returns:
        The text of the first content block of the response, or None when
        ANTHROPIC_API_KEY is unset (a message is printed in that case).

    Raises:
        ValueError: If the API responds with a non-200 status code.
    """
    api_key = env.get("ANTHROPIC_API_KEY")
    if not api_key:
        print("API key not found. Please set the ANTHROPIC_API_KEY environment variable.")
        return

    headers = {
        'x-api-key': api_key,
        'anthropic-version': '2023-06-01',
        'Content-Type': 'application/json',
    }

    data = {
        "model": "claude-3-opus-20240229",
        "max_tokens": 1024,
        "messages": [{"role": "user", "content": prompt}]
    }

    # Fix: a timeout keeps a hung connection from blocking the caller forever.
    response = requests.post(
        'https://api.anthropic.com/v1/messages',
        headers=headers,
        data=json.dumps(data),
        timeout=120,
    )
    if response.status_code != 200:
        # Build the message once so the log line and the exception always agree
        # (previously the f-string was duplicated and could drift apart).
        error = f"Unexpected Anthropic API status code: {response.status_code} with body: {response.text}"
        print(error)
        raise ValueError(error)
    j = response.json()

    # Messages API returns a list of content blocks; take the first text block.
    text = j['content'][0]["text"]
    print(text)
    return text
def llm_anthropic_api(prompt: str, pydantic_model_class=None, attempts=0) -> Union[str, Dict[str, Any]]:
    """Query the Anthropic API (non-streaming) and optionally validate the JSON reply.

    With no streaming or rate limits, we use the Anthropic API; we have string
    input and output from send_anthropic_request, but we need to convert it to
    JSON for the pydantic model class like the other APIs.

    Args:
        prompt: The prompt to send.
        pydantic_model_class: Optional pydantic model used to validate that the
            reply is JSON of the required shape.
        attempts: Internal retry counter; callers should leave this at 0.

    Returns:
        The reply parsed with json.loads. NOTE(review): if the API key is
        missing, send_anthropic_request returns None and json.loads will raise
        a TypeError here — confirm whether that path can occur in deployment.

    Raises:
        ValueError: If the reply still fails validation after 3 retries.
    """
    output = send_anthropic_request(prompt)
    if pydantic_model_class:
        try:
            # This will raise an exception if the model is invalid.
            parsed_result = pydantic_model_class.model_validate_json(output)
            print(parsed_result)
            return json.loads(output)
        except Exception as e:
            print(f"Error validating pydantic model: {e}")
            if attempts == 0:
                # On the first failure, remind the model to output JSON in the
                # required format before retrying.
                prompt = f"{prompt} You must output the JSON in the required format only, with no remarks or prefacing remarks - JUST JSON!"
            if attempts < 3:
                attempts += 1
                print(f"Retrying Anthropic API call, attempt {attempts}")
                return llm_anthropic_api(prompt, pydantic_model_class, attempts)
            # Fix: previously this fell through and implicitly returned None
            # after retries were exhausted; surface the failure explicitly.
            raise ValueError(
                f"Anthropic API output failed validation after {attempts} attempts"
            ) from e
    else:
        print("No pydantic model class provided, returning without class validation")
        return json.loads(output)
| 304 |
|
| 305 |
def query_ai_prompt(prompt, replacements, model_class):
|
| 306 |
prompt = replace_text(prompt, replacements)
|
| 307 |
+
if LLM_WORKER == "anthropic":
|
| 308 |
+
result = llm_anthropic_api(prompt, model_class)
|
| 309 |
if LLM_WORKER == "mistral":
|
| 310 |
result = llm_stream_mistral_api(prompt, model_class)
|
| 311 |
if LLM_WORKER == "runpod":
|