# hgagent / lagent / llms / anthropic_llm.py
# Chisleep · update01 · f8d0193 · raw · history · blame · 14.7 kB
import asyncio
import json
import os
from typing import Dict, List, Optional, Union
import anthropic
import httpcore
import httpx
from anthropic import NOT_GIVEN
from requests.exceptions import ProxyError
from .base_api import AsyncBaseAPILLM, BaseAPILLM
class ClaudeAPI(BaseAPILLM):
    """Blocking (synchronous) client for the Anthropic Messages API.

    Requests are issued through ``anthropic.AsyncAnthropic`` clients (one per
    API key) and driven to completion from synchronous code in :meth:`chat`.
    Keys are used round-robin; keys that the API reports as invalid or out of
    credit are remembered in ``invalid_keys`` and skipped afterwards.

    Args:
        model_type (str): model name sent as ``model`` in every request.
        retry (int): forwarded to the base class; ``self.retry`` bounds the
            number of attempts made per request in :meth:`_chat`.
        key (Union[str, List[str]]): one API key, a list of keys, or the
            literal ``'ENV'`` to read the key from the ``Claude_API_KEY``
            environment variable.
        proxies (Optional[Dict]): passed unchanged to
            ``anthropic.AsyncAnthropic(proxies=...)``.
        meta_template (Optional[List[dict]]): role mapping forwarded to the
            base class.
        temperature, top_p, top_k, max_new_tokens, repetition_penalty,
        stop_words: generation defaults forwarded to the base class
            (presumably exposed back as ``self.gen_params`` -- confirm in
            ``BaseAPILLM``).

    Raises:
        ValueError: if ``key='ENV'`` and ``Claude_API_KEY`` is not set.
    """

    is_api: bool = True

    # Exact message the API returns when a key has run out of credit.
    _LOW_CREDIT_MESSAGE = (
        'Your credit balance is too low to access the Anthropic API. Please go to Plans & Billing to '
        'upgrade or purchase credits.'
    )

    def __init__(
        self,
        model_type: str = 'claude-3-5-sonnet-20241022',
        retry: int = 5,
        key: Union[str, List[str]] = 'ENV',
        proxies: Optional[Dict] = None,
        # NOTE: this default list is shared between instances; it is only
        # passed through to the base class here and must not be mutated.
        meta_template: Optional[List[dict]] = [
            dict(role='system', api_role='system'),
            dict(role='user', api_role='user'),
            dict(role='assistant', api_role='assistant'),
            dict(role='environment', api_role='user'),
        ],
        temperature: float = NOT_GIVEN,
        max_new_tokens: int = 512,
        top_p: float = NOT_GIVEN,
        top_k: int = NOT_GIVEN,
        repetition_penalty: float = 0.0,
        stop_words: Union[List[str], str] = None,
    ):
        super().__init__(
            meta_template=meta_template,
            model_type=model_type,
            retry=retry,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            stop_words=stop_words,
        )
        if key == 'ENV':
            key = os.getenv('Claude_API_KEY')
            if not key:
                # Fail with a clear message instead of an opaque
                # "'NoneType' object is not iterable" further down.
                raise ValueError(
                    "No API key given: pass `key` explicitly or set the 'Claude_API_KEY' environment variable."
                )
        if isinstance(key, str):
            self.keys = [key]
        else:
            # De-duplicate while keeping the caller's order deterministic.
            self.keys = list(dict.fromkeys(key))
        self.clients = {
            api_key: anthropic.AsyncAnthropic(proxies=proxies, api_key=api_key) for api_key in self.keys
        }
        # Keys that were rejected (invalid or out of credit); skipped when
        # choosing the key for a request.
        self.invalid_keys = set()
        self.key_ctr = 0
        self.model_type = model_type
        self.proxies = proxies

    def chat(
        self,
        inputs: Union[List[dict], List[List[dict]]],
        session_ids: Union[int, List[int]] = None,
        **gen_params,
    ) -> Union[str, List[str]]:
        """Generate responses given the contexts.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): a list of messages
                (one conversation) or a list of lists of messages (a batch).
            session_ids (Union[int, List[int]]): accepted for interface
                compatibility; not used by this method.
            gen_params: generation options overriding ``self.gen_params``.

        Returns:
            Union[str, List[str]]: one string for a single conversation, a
            list of strings for a batch (``[]`` for an empty ``inputs``).
        """
        assert isinstance(inputs, list)
        if not inputs:
            # Nothing to send; avoids IndexError on ``inputs[0]`` below.
            return []
        gen_params = {**self.gen_params, **gen_params}
        is_single = isinstance(inputs[0], dict)
        batch = [inputs] if is_single else inputs

        # nest_asyncio patches asyncio so that ``run_until_complete`` may be
        # called while an event loop is already running (e.g. in a notebook).
        import nest_asyncio

        nest_asyncio.apply()

        async def run_async_tasks():
            tasks = [self._chat(self.template_parser(messages), **gen_params) for messages in batch]
            return await asyncio.gather(*tasks)

        # Only the loop lookup is guarded: a RuntimeError raised by the
        # requests themselves must propagate, not trigger a second run.
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        if loop is None:
            # No running event loop: start one for this call.
            ret = asyncio.run(run_async_tasks())
        else:
            # A loop is already running: schedule on it and wait re-entrantly.
            ret = loop.run_until_complete(asyncio.ensure_future(run_async_tasks()))
        return ret[0] if is_single else ret

    def generate_request_data(self, model_type, messages, gen_params) -> Union[dict, tuple]:
        """Build the keyword arguments for ``client.messages.create``.

        The inputs are not modified; messages and ``gen_params`` are copied.

        Args:
            model_type (str): model name to put in the request.
            messages (list): chat messages (dicts with ``role``/``content``).
                A leading ``system`` message is moved to the top-level
                ``system`` field and any ``name`` key is dropped.
            gen_params (dict): generation parameters; must contain
                ``max_new_tokens``.

        Returns:
            dict: the request payload, or the legacy sentinel ``('', '')``
            when ``max_new_tokens`` is not positive (nothing can be
            generated).
        """
        # Work on a copy so the caller's dictionary is left untouched.
        gen_params = gen_params.copy()
        # Cap the requested completion length at 4096 tokens.
        max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
        if max_tokens <= 0:
            return '', ''
        # Not a Messages API parameter; tolerate its absence.
        gen_params.pop('repetition_penalty', None)
        # The API takes ``stop_sequences`` as a list of strings; omit the
        # field entirely when no stop words are configured.
        stop_words = gen_params.pop('stop_words', None)
        if isinstance(stop_words, str):
            stop_words = [stop_words]
        if stop_words:
            gen_params['stop_sequences'] = list(stop_words)
        gen_params['max_tokens'] = max_tokens
        gen_params.pop('skip_special_tokens', None)
        gen_params.pop('session_id', None)

        # Shallow-copy every message so the edits below stay local.
        messages = [dict(message) for message in messages]
        system = None
        if messages and messages[0].get('role') == 'system':
            system = messages.pop(0)['content']
        for message in messages:
            message.pop('name', None)

        data = {'model': model_type, 'messages': messages, **gen_params}
        if system:
            data['system'] = system
        return data

    def _next_valid_key(self) -> str:
        """Advance the round-robin counter to the next usable key.

        Raises:
            RuntimeError: if every configured key has been marked invalid.
        """
        if len(self.invalid_keys) >= len(self.keys):
            raise RuntimeError('All keys have insufficient quota.')
        while True:
            self.key_ctr = (self.key_ctr + 1) % len(self.keys)
            candidate = self.keys[self.key_ctr]
            if candidate not in self.invalid_keys:
                return candidate

    @staticmethod
    def _parse_api_error(error):
        """Return ``(message, type)`` from ``error.body['error']``, or
        ``(None, None)`` when the body does not have that shape."""
        body = getattr(error, 'body', None)
        detail = body.get('error') if isinstance(body, dict) else None
        if not isinstance(detail, dict):
            return None, None
        return detail.get('message'), detail.get('type')

    async def _chat(self, messages: List[dict], **gen_params) -> str:
        """Send one conversation to the API, retrying on transient failures.

        Args:
            messages (List[dict]): API-ready messages (already passed
                through the template parser by :meth:`chat`).
            gen_params: generation configuration for this request.

        Returns:
            str: the stripped text of the first content block, or ``''``
            when ``max_new_tokens`` is not positive.

        Raises:
            RuntimeError: if all keys are invalid, or if ``self.retry``
                attempts were made without a successful response.
        """
        assert isinstance(messages, list)
        data = self.generate_request_data(model_type=self.model_type, messages=messages, gen_params=gen_params)
        if not isinstance(data, dict):
            # ``generate_request_data`` signalled "nothing to generate".
            return ''

        attempts = 0
        while attempts < self.retry:
            key = self._next_valid_key()
            client = self.clients[key]
            try:
                response = await client.messages.create(**data)
                response = json.loads(response.json())
                return response['content'][0]['text'].strip()
            except (anthropic.RateLimitError, anthropic.APIConnectionError) as e:
                print(f'API请求错误: {e}')
                await asyncio.sleep(5)
            except (httpcore.ProxyError, ProxyError) as e:
                print(f'代理服务器错误: {e}')
                await asyncio.sleep(5)
            except httpx.TimeoutException as e:
                print(f'请求超时: {e}')
                await asyncio.sleep(5)
            except Exception as error:
                if not hasattr(error, 'body'):
                    # Not an API error (e.g. a bad argument or an unexpected
                    # response shape): surface it instead of masking it
                    # behind an AttributeError on ``error.body``.
                    raise
                message, err_type = self._parse_api_error(error)
                if message == 'invalid x-api-key':
                    self.invalid_keys.add(key)
                    self.logger.warning(f'invalid key: {key}')
                elif err_type == 'overloaded_error' or message == 'Internal server error':
                    # Transient server-side condition: back off, then retry.
                    await asyncio.sleep(5)
                elif message == self._LOW_CREDIT_MESSAGE:
                    self.invalid_keys.add(key)
                    print(f'API has no quota: {key}, Valid keys: {len(self.keys) - len(self.invalid_keys)}')
                else:
                    # Unrecognised API error: keep retrying as before, but
                    # leave a trace so the final error message is actionable.
                    print(f'Unhandled API error, retrying: {error}')
            attempts += 1
        raise RuntimeError(
            f'Calling Claude failed after retrying for {attempts} times. Check the logs for details.'
        )
class AsyncClaudeAPI(AsyncBaseAPILLM):
    """Asynchronous client for the Anthropic Messages API.

    One ``anthropic.AsyncAnthropic`` client is created per API key. Keys are
    used round-robin; keys that the API reports as invalid or out of credit
    are remembered in ``invalid_keys`` and skipped afterwards.

    Args:
        model_type (str): model name sent as ``model`` in every request.
        retry (int): forwarded to the base class; ``self.retry`` bounds the
            number of attempts made per request in :meth:`_chat`.
        key (Union[str, List[str]]): one API key, a list of keys, or the
            literal ``'ENV'`` to read the key from the ``Claude_API_KEY``
            environment variable.
        proxies (Optional[Dict]): passed unchanged to
            ``anthropic.AsyncAnthropic(proxies=...)``.
        meta_template (Optional[List[dict]]): role mapping forwarded to the
            base class.
        temperature, top_p, top_k, max_new_tokens, repetition_penalty,
        stop_words: generation defaults forwarded to the base class
            (presumably exposed back as ``self.gen_params`` -- confirm in
            ``AsyncBaseAPILLM``).

    Raises:
        ValueError: if ``key='ENV'`` and ``Claude_API_KEY`` is not set.
    """

    is_api: bool = True

    # Exact message the API returns when a key has run out of credit.
    _LOW_CREDIT_MESSAGE = (
        'Your credit balance is too low to access the Anthropic API. Please go to Plans & Billing to'
        ' upgrade or purchase credits.'
    )

    def __init__(
        self,
        model_type: str = 'claude-3-5-sonnet-20241022',
        retry: int = 5,
        key: Union[str, List[str]] = 'ENV',
        proxies: Optional[Dict] = None,
        # NOTE: this default list is shared between instances; it is only
        # passed through to the base class here and must not be mutated.
        meta_template: Optional[List[dict]] = [
            dict(role='system', api_role='system'),
            dict(role='user', api_role='user'),
            dict(role='assistant', api_role='assistant'),
            dict(role='environment', api_role='user'),
        ],
        temperature: float = NOT_GIVEN,
        max_new_tokens: int = 512,
        top_p: float = NOT_GIVEN,
        top_k: int = NOT_GIVEN,
        repetition_penalty: float = 0.0,
        stop_words: Union[List[str], str] = None,
    ):
        super().__init__(
            model_type=model_type,
            retry=retry,
            meta_template=meta_template,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            stop_words=stop_words,
        )
        if key == 'ENV':
            key = os.getenv('Claude_API_KEY')
            if not key:
                # Fail with a clear message instead of an opaque
                # "'NoneType' object is not iterable" further down.
                raise ValueError(
                    "No API key given: pass `key` explicitly or set the 'Claude_API_KEY' environment variable."
                )
        if isinstance(key, str):
            self.keys = [key]
        else:
            # De-duplicate while keeping the caller's order deterministic.
            self.keys = list(dict.fromkeys(key))
        self.clients = {
            api_key: anthropic.AsyncAnthropic(proxies=proxies, api_key=api_key) for api_key in self.keys
        }
        # Keys that were rejected (invalid or out of credit); skipped when
        # choosing the key for a request.
        self.invalid_keys = set()
        self.key_ctr = 0
        self.model_type = model_type
        self.proxies = proxies

    async def chat(
        self,
        inputs: Union[List[dict], List[List[dict]]],
        session_ids: Union[int, List[int]] = None,
        **gen_params,
    ) -> Union[str, List[str]]:
        """Generate responses given the contexts.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): a list of messages
                (one conversation) or a list of lists of messages (a batch).
            session_ids (Union[int, List[int]]): accepted for interface
                compatibility; not used by this method.
            gen_params: generation options overriding ``self.gen_params``.

        Returns:
            Union[str, List[str]]: one string for a single conversation, a
            list of strings for a batch (``[]`` for an empty ``inputs``).
        """
        assert isinstance(inputs, list)
        if not inputs:
            # Nothing to send; avoids IndexError on ``inputs[0]`` below.
            return []
        gen_params = {**self.gen_params, **gen_params}
        is_single = isinstance(inputs[0], dict)
        batch = [inputs] if is_single else inputs
        # Conversations in a batch are requested concurrently.
        ret = await asyncio.gather(*(self._chat(messages, **gen_params) for messages in batch))
        return ret[0] if is_single else ret

    def generate_request_data(self, model_type, messages, gen_params) -> Union[dict, tuple]:
        """Build the keyword arguments for ``client.messages.create``.

        The inputs are not modified; messages and ``gen_params`` are copied.

        Args:
            model_type (str): model name to put in the request.
            messages (list): chat messages (dicts with ``role``/``content``).
                A leading ``system`` message is moved to the top-level
                ``system`` field and any ``name`` key is dropped.
            gen_params (dict): generation parameters; must contain
                ``max_new_tokens``.

        Returns:
            dict: the request payload, or the legacy sentinel ``('', '')``
            when ``max_new_tokens`` is not positive (nothing can be
            generated).
        """
        # Work on a copy so the caller's dictionary is left untouched.
        gen_params = gen_params.copy()
        # Cap the requested completion length at 4096 tokens.
        max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
        if max_tokens <= 0:
            return '', ''
        # Not a Messages API parameter; tolerate its absence.
        gen_params.pop('repetition_penalty', None)
        # The API takes ``stop_sequences`` as a list of strings; omit the
        # field entirely when no stop words are configured.
        stop_words = gen_params.pop('stop_words', None)
        if isinstance(stop_words, str):
            stop_words = [stop_words]
        if stop_words:
            gen_params['stop_sequences'] = list(stop_words)
        gen_params['max_tokens'] = max_tokens
        gen_params.pop('skip_special_tokens', None)
        gen_params.pop('session_id', None)

        # Shallow-copy every message so the edits below stay local.
        messages = [dict(message) for message in messages]
        system = None
        if messages and messages[0].get('role') == 'system':
            system = messages.pop(0)['content']
        for message in messages:
            message.pop('name', None)

        data = {'model': model_type, 'messages': messages, **gen_params}
        if system:
            data['system'] = system
        return data

    def _next_valid_key(self) -> str:
        """Advance the round-robin counter to the next usable key.

        Raises:
            RuntimeError: if every configured key has been marked invalid.
        """
        if len(self.invalid_keys) >= len(self.keys):
            raise RuntimeError('All keys have insufficient quota.')
        while True:
            self.key_ctr = (self.key_ctr + 1) % len(self.keys)
            candidate = self.keys[self.key_ctr]
            if candidate not in self.invalid_keys:
                return candidate

    @staticmethod
    def _parse_api_error(error):
        """Return ``(message, type)`` from ``error.body['error']``, or
        ``(None, None)`` when the body does not have that shape."""
        body = getattr(error, 'body', None)
        detail = body.get('error') if isinstance(body, dict) else None
        if not isinstance(detail, dict):
            return None, None
        return detail.get('message'), detail.get('type')

    async def _chat(self, messages: List[dict], **gen_params) -> str:
        """Send one conversation to the API, retrying on transient failures.

        Args:
            messages (List[dict]): raw messages; they are passed through
                ``self.template_parser`` here before the request is built.
            gen_params: generation configuration for this request.

        Returns:
            str: the stripped text of the first content block, or ``''``
            when ``max_new_tokens`` is not positive.

        Raises:
            RuntimeError: if all keys are invalid, or if ``self.retry``
                attempts were made without a successful response.
            Exception: any error that is not one of the recognised
                transient or key-related conditions is re-raised unchanged.
        """
        assert isinstance(messages, list)
        messages = self.template_parser(messages)
        data = self.generate_request_data(model_type=self.model_type, messages=messages, gen_params=gen_params)
        if not isinstance(data, dict):
            # ``generate_request_data`` signalled "nothing to generate".
            return ''

        attempts = 0
        while attempts < self.retry:
            key = self._next_valid_key()
            client = self.clients[key]
            try:
                response = await client.messages.create(**data)
                response = json.loads(response.json())
                return response['content'][0]['text'].strip()
            except (anthropic.RateLimitError, anthropic.APIConnectionError) as e:
                print(f'API请求错误: {e}')
                await asyncio.sleep(5)
            except (httpcore.ProxyError, ProxyError) as e:
                print(f'代理服务器错误: {e}')
                await asyncio.sleep(5)
            except httpx.TimeoutException as e:
                print(f'请求超时: {e}')
                await asyncio.sleep(5)
            except Exception as error:
                # Read the error body defensively: errors without a
                # dict-shaped ``body`` fall through to the re-raise below
                # instead of being masked by an AttributeError/TypeError.
                message, err_type = self._parse_api_error(error)
                if message == 'invalid x-api-key':
                    self.invalid_keys.add(key)
                    self.logger.warning(f'invalid key: {key}')
                elif err_type == 'overloaded_error' or message == 'Internal server error':
                    # Transient server-side condition: back off, then retry.
                    await asyncio.sleep(5)
                elif message == self._LOW_CREDIT_MESSAGE:
                    self.invalid_keys.add(key)
                    print(f'API has no quota: {key}, Valid keys: {len(self.keys) - len(self.invalid_keys)}')
                else:
                    # Unknown failure: propagate with the original traceback.
                    raise
            attempts += 1
        raise RuntimeError(
            f'Calling Claude failed after retrying for {attempts} times. Check the logs for details.'
        )