|
import asyncio |
|
import json |
|
import os |
|
from typing import Dict, List, Optional, Union |
|
|
|
import anthropic |
|
import httpcore |
|
import httpx |
|
from anthropic import NOT_GIVEN |
|
from requests.exceptions import ProxyError |
|
|
|
from .base_api import AsyncBaseAPILLM, BaseAPILLM |
|
|
|
|
|
class ClaudeAPI(BaseAPILLM):
    """Synchronous chat wrapper around the Anthropic Messages API.

    Requests are issued through ``anthropic.AsyncAnthropic`` clients (one per
    API key) and driven to completion from synchronous code by :meth:`chat`.
    Keys are used round-robin; keys that the API reports as invalid or out of
    credit are skipped for the lifetime of the instance.
    """

    is_api: bool = True

    # Exact message the API returns for a key whose credit is exhausted.
    _NO_CREDIT_MESSAGE = (
        'Your credit balance is too low to access the Anthropic API. Please go to Plans & Billing to '
        'upgrade or purchase credits.'
    )

    def __init__(
        self,
        model_type: str = 'claude-3-5-sonnet-20241022',
        retry: int = 5,
        key: Union[str, List[str]] = 'ENV',
        proxies: Optional[Dict] = None,
        meta_template: Optional[List[Dict]] = [
            dict(role='system', api_role='system'),
            dict(role='user', api_role='user'),
            dict(role='assistant', api_role='assistant'),
            dict(role='environment', api_role='user'),
        ],
        temperature: float = NOT_GIVEN,
        max_new_tokens: int = 512,
        top_p: float = NOT_GIVEN,
        top_k: int = NOT_GIVEN,
        repetition_penalty: float = 0.0,
        stop_words: Union[List[str], str] = None,
    ):
        """Create the wrapper and one API client per key.

        Args:
            model_type (str): model name sent as ``model`` in every request.
            retry (int): maximum number of attempts per request.
            key (Union[str, List[str]]): a single API key, a list of keys, or
                ``'ENV'`` to read the key from the ``Claude_API_KEY``
                environment variable.
            proxies (Optional[Dict]): forwarded to ``anthropic.AsyncAnthropic``.
            meta_template (Optional[List[Dict]]): role mapping forwarded to the
                base class.
            temperature, max_new_tokens, top_p, top_k, repetition_penalty,
            stop_words: default generation parameters forwarded to the base
                class.

        Raises:
            ValueError: if ``key`` is ``'ENV'`` and ``Claude_API_KEY`` is not
                set.
        """
        import copy

        super().__init__(
            # Copy so that the shared default list is never mutated through
            # an instance.
            meta_template=copy.deepcopy(meta_template),
            model_type=model_type,
            retry=retry,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            stop_words=stop_words,
        )

        key = os.getenv('Claude_API_KEY') if key == 'ENV' else key
        if key is None:
            raise ValueError(
                "No API key provided: pass `key` explicitly or set the 'Claude_API_KEY' environment variable."
            )

        if isinstance(key, str):
            self.keys = [key]
        else:
            # De-duplicate while keeping the caller's order deterministic.
            self.keys = list(dict.fromkeys(key))
        self.clients = {k: anthropic.AsyncAnthropic(proxies=proxies, api_key=k) for k in self.keys}

        self.invalid_keys = set()
        self.key_ctr = 0
        self.model_type = model_type
        self.proxies = proxies

    def chat(
        self,
        inputs: Union[List[dict], List[List[dict]]],
        session_ids: Union[int, List[int]] = None,
        **gen_params,
    ) -> Union[str, List[str]]:
        """Generate responses given the contexts.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): a list of messages
                or list of lists of messages
            session_ids: accepted for interface compatibility; not used here.
            gen_params: additional generation configuration, overriding the
                instance defaults.

        Returns:
            Union[str, List[str]]: generated string(s). An empty ``inputs``
            list yields an empty list.
        """
        assert isinstance(inputs, list)
        if not inputs:
            return []
        gen_params = {**self.gen_params, **gen_params}
        is_single = isinstance(inputs[0], dict)
        batch = [inputs] if is_single else inputs

        import nest_asyncio

        # nest_asyncio patches asyncio so the event loop can be re-entered,
        # which is what allows `run_until_complete` below to be called while
        # a loop is already running.
        nest_asyncio.apply()

        async def run_async_tasks():
            tasks = [self._chat(self.template_parser(messages), **gen_params) for messages in batch]
            return await asyncio.gather(*tasks)

        # Only the loop lookup is guarded: a RuntimeError raised by the
        # requests themselves must propagate instead of triggering a second
        # submission of the whole batch.
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        if loop is None:
            ret = asyncio.run(run_async_tasks())
        else:
            ret = loop.run_until_complete(asyncio.ensure_future(run_async_tasks()))
        return ret[0] if is_single else ret

    def generate_request_data(self, model_type, messages, gen_params):
        """Build the keyword arguments for ``client.messages.create``.

        Args:
            model_type (str): model name placed under ``'model'``.
            messages (list): message dicts; a leading ``'system'`` message is
                moved to the top-level ``'system'`` field and ``'name'`` keys
                are dropped. The input list and its dicts are not modified.
            gen_params (dict): generation parameters. ``max_new_tokens`` is
                required and is capped at 4096; ``stop_words`` is renamed to
                ``stop_sequences`` (omitted when empty or ``None``);
                ``repetition_penalty``, ``skip_special_tokens`` and
                ``session_id`` are removed.

        Returns:
            dict: the request payload, or the tuple ``('', '')`` when the
            token budget is not positive.
        """
        gen_params = gen_params.copy()

        max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
        if max_tokens <= 0:
            return '', ''

        gen_params.pop('repetition_penalty', None)
        gen_params.pop('skip_special_tokens', None)
        gen_params.pop('session_id', None)

        stop_words = gen_params.pop('stop_words', None)
        if isinstance(stop_words, str):
            stop_words = [stop_words]
        if stop_words:
            gen_params['stop_sequences'] = list(stop_words)
        gen_params['max_tokens'] = max_tokens

        # Work on shallow copies so the caller's history is left untouched.
        messages = [dict(message) for message in messages]
        system = None
        if messages and messages[0]['role'] == 'system':
            system = messages.pop(0)['content']
        for message in messages:
            message.pop('name', None)

        data = {'model': model_type, 'messages': messages, **gen_params}
        if system:
            data['system'] = system
        return data

    def _log_warning(self, message: str) -> None:
        """Emit a warning via ``self.logger`` if present, else a module logger."""
        import logging

        logger = getattr(self, 'logger', None) or logging.getLogger(__name__)
        logger.warning(message)

    def _rotate_key(self) -> str:
        """Advance the round-robin counter to the next key not marked invalid.

        The caller must ensure at least one valid key remains.
        """
        while True:
            self.key_ctr = (self.key_ctr + 1) % len(self.keys)
            candidate = self.keys[self.key_ctr]
            if candidate not in self.invalid_keys:
                return candidate

    @staticmethod
    def _parse_api_error(error):
        """Return ``(type, message)`` from ``error.body['error']``.

        Returns ``(None, None)`` when the exception carries no such structure.
        """
        body = getattr(error, 'body', None)
        detail = body.get('error') if isinstance(body, dict) else None
        if not isinstance(detail, dict):
            return None, None
        return detail.get('type'), detail.get('message')

    async def _chat(self, messages: List[dict], **gen_params) -> str:
        """Generate completion from a list of templates.

        Args:
            messages (List[dict]): a list of prompt dictionaries
            gen_params: additional generation configuration

        Returns:
            str: The generated string, or ``''`` when the token budget is not
            positive.

        Raises:
            RuntimeError: if every key has been marked invalid, or if the
                request still fails after ``self.retry`` attempts.
        """
        assert isinstance(messages, list)

        data = self.generate_request_data(model_type=self.model_type, messages=messages, gen_params=gen_params)
        if not isinstance(data, dict):
            # No token budget: there is nothing to request.
            return ''

        max_num_retries = 0
        while max_num_retries < self.retry:
            if len(self.invalid_keys) == len(self.keys):
                raise RuntimeError('All keys have insufficient quota.')

            key = self._rotate_key()
            client = self.clients[key]

            try:
                response = await client.messages.create(**data)
                response = json.loads(response.json())
                return response['content'][0]['text'].strip()
            except (anthropic.RateLimitError, anthropic.APIConnectionError) as e:
                print(f'API请求错误: {e}')
                await asyncio.sleep(5)
            except (httpcore.ProxyError, ProxyError) as e:
                print(f'代理服务器错误: {e}')
                await asyncio.sleep(5)
            except httpx.TimeoutException as e:
                print(f'请求超时: {e}')
                await asyncio.sleep(5)
            except Exception as error:
                # KeyboardInterrupt is not an Exception subclass and
                # propagates without an explicit handler.
                err_type, err_message = self._parse_api_error(error)
                if err_type is None and err_message is None:
                    # Not a structured API error (e.g. an unexpected response
                    # shape): surface the real exception.
                    raise
                if err_message == 'invalid x-api-key':
                    self.invalid_keys.add(key)
                    self._log_warning(f'invalid key: {key}')
                elif err_type == 'overloaded_error':
                    await asyncio.sleep(5)
                elif err_message == 'Internal server error':
                    await asyncio.sleep(5)
                elif err_message == self._NO_CREDIT_MESSAGE:
                    self.invalid_keys.add(key)
                    print(f'API has no quota: {key}, Valid keys: {len(self.keys) - len(self.invalid_keys)}')
                else:
                    # Unrecognised API errors are retried; log them so the
                    # final failure message has something to point at.
                    self._log_warning(f'unrecognised API error, retrying: {error}')
            max_num_retries += 1

        raise RuntimeError(
            'Calling Claude failed after retrying for ' f'{max_num_retries} times. Check the logs for ' 'details.'
        )
|
|
|
|
|
class AsyncClaudeAPI(AsyncBaseAPILLM):
    """Asynchronous chat wrapper around the Anthropic Messages API.

    One ``anthropic.AsyncAnthropic`` client is created per API key. Keys are
    used round-robin; keys that the API reports as invalid or out of credit
    are skipped for the lifetime of the instance.
    """

    is_api: bool = True

    # Exact message the API returns for a key whose credit is exhausted.
    _NO_CREDIT_MESSAGE = (
        'Your credit balance is too low to access the Anthropic API. Please go to Plans & Billing to'
        ' upgrade or purchase credits.'
    )

    def __init__(
        self,
        model_type: str = 'claude-3-5-sonnet-20241022',
        retry: int = 5,
        key: Union[str, List[str]] = 'ENV',
        proxies: Optional[Dict] = None,
        meta_template: Optional[List[Dict]] = [
            dict(role='system', api_role='system'),
            dict(role='user', api_role='user'),
            dict(role='assistant', api_role='assistant'),
            dict(role='environment', api_role='user'),
        ],
        temperature: float = NOT_GIVEN,
        max_new_tokens: int = 512,
        top_p: float = NOT_GIVEN,
        top_k: int = NOT_GIVEN,
        repetition_penalty: float = 0.0,
        stop_words: Union[List[str], str] = None,
    ):
        """Create the wrapper and one API client per key.

        Args:
            model_type (str): model name sent as ``model`` in every request.
            retry (int): maximum number of attempts per request.
            key (Union[str, List[str]]): a single API key, a list of keys, or
                ``'ENV'`` to read the key from the ``Claude_API_KEY``
                environment variable.
            proxies (Optional[Dict]): forwarded to ``anthropic.AsyncAnthropic``.
            meta_template (Optional[List[Dict]]): role mapping forwarded to the
                base class.
            temperature, max_new_tokens, top_p, top_k, repetition_penalty,
            stop_words: default generation parameters forwarded to the base
                class.

        Raises:
            ValueError: if ``key`` is ``'ENV'`` and ``Claude_API_KEY`` is not
                set.
        """
        import copy

        super().__init__(
            model_type=model_type,
            retry=retry,
            # Copy so that the shared default list is never mutated through
            # an instance.
            meta_template=copy.deepcopy(meta_template),
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            stop_words=stop_words,
        )

        key = os.getenv('Claude_API_KEY') if key == 'ENV' else key
        if key is None:
            raise ValueError(
                "No API key provided: pass `key` explicitly or set the 'Claude_API_KEY' environment variable."
            )

        if isinstance(key, str):
            self.keys = [key]
        else:
            # De-duplicate while keeping the caller's order deterministic.
            self.keys = list(dict.fromkeys(key))
        self.clients = {k: anthropic.AsyncAnthropic(proxies=proxies, api_key=k) for k in self.keys}

        self.invalid_keys = set()
        self.key_ctr = 0
        self.model_type = model_type
        self.proxies = proxies

    async def chat(
        self,
        inputs: Union[List[dict], List[List[dict]]],
        session_ids: Union[int, List[int]] = None,
        **gen_params,
    ) -> Union[str, List[str]]:
        """Generate responses given the contexts.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): a list of messages
                or list of lists of messages
            session_ids: accepted for interface compatibility; not used here.
            gen_params: additional generation configuration, overriding the
                instance defaults.

        Returns:
            Union[str, List[str]]: generated string(s). An empty ``inputs``
            list yields an empty list.
        """
        assert isinstance(inputs, list)
        if not inputs:
            return []
        gen_params = {**self.gen_params, **gen_params}
        is_single = isinstance(inputs[0], dict)
        batch = [inputs] if is_single else inputs

        # All conversations are requested concurrently.
        ret = await asyncio.gather(*(self._chat(messages, **gen_params) for messages in batch))
        return ret[0] if is_single else ret

    def generate_request_data(self, model_type, messages, gen_params):
        """Build the keyword arguments for ``client.messages.create``.

        Args:
            model_type (str): model name placed under ``'model'``.
            messages (list): message dicts; a leading ``'system'`` message is
                moved to the top-level ``'system'`` field and ``'name'`` keys
                are dropped. The input list and its dicts are not modified.
            gen_params (dict): generation parameters. ``max_new_tokens`` is
                required and is capped at 4096; ``stop_words`` is renamed to
                ``stop_sequences`` (omitted when empty or ``None``);
                ``repetition_penalty``, ``skip_special_tokens`` and
                ``session_id`` are removed.

        Returns:
            dict: the request payload, or the tuple ``('', '')`` when the
            token budget is not positive.
        """
        gen_params = gen_params.copy()

        max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
        if max_tokens <= 0:
            return '', ''

        gen_params.pop('repetition_penalty', None)
        gen_params.pop('skip_special_tokens', None)
        gen_params.pop('session_id', None)

        stop_words = gen_params.pop('stop_words', None)
        if isinstance(stop_words, str):
            stop_words = [stop_words]
        if stop_words:
            gen_params['stop_sequences'] = list(stop_words)
        gen_params['max_tokens'] = max_tokens

        # Work on shallow copies so the caller's history is left untouched.
        messages = [dict(message) for message in messages]
        system = None
        if messages and messages[0]['role'] == 'system':
            system = messages.pop(0)['content']
        for message in messages:
            message.pop('name', None)

        data = {'model': model_type, 'messages': messages, **gen_params}
        if system:
            data['system'] = system
        return data

    def _log_warning(self, message: str) -> None:
        """Emit a warning via ``self.logger`` if present, else a module logger."""
        import logging

        logger = getattr(self, 'logger', None) or logging.getLogger(__name__)
        logger.warning(message)

    def _rotate_key(self) -> str:
        """Advance the round-robin counter to the next key not marked invalid.

        The caller must ensure at least one valid key remains.
        """
        while True:
            self.key_ctr = (self.key_ctr + 1) % len(self.keys)
            candidate = self.keys[self.key_ctr]
            if candidate not in self.invalid_keys:
                return candidate

    @staticmethod
    def _parse_api_error(error):
        """Return ``(type, message)`` from ``error.body['error']``.

        Returns ``(None, None)`` when the exception carries no such structure.
        """
        body = getattr(error, 'body', None)
        detail = body.get('error') if isinstance(body, dict) else None
        if not isinstance(detail, dict):
            return None, None
        return detail.get('type'), detail.get('message')

    async def _chat(self, messages: List[dict], **gen_params) -> str:
        """Generate completion from a list of templates.

        Args:
            messages (List[dict]): a list of prompt dictionaries; they are
                passed through ``self.template_parser`` before the request is
                built.
            gen_params: additional generation configuration

        Returns:
            str: The generated string, or ``''`` when the token budget is not
            positive.

        Raises:
            RuntimeError: if every key has been marked invalid, or if the
                request still fails after ``self.retry`` attempts.
            Exception: any error not recognised as retryable is re-raised.
        """
        assert isinstance(messages, list)
        messages = self.template_parser(messages)
        data = self.generate_request_data(model_type=self.model_type, messages=messages, gen_params=gen_params)
        if not isinstance(data, dict):
            # No token budget: there is nothing to request.
            return ''

        max_num_retries = 0
        while max_num_retries < self.retry:
            if len(self.invalid_keys) == len(self.keys):
                raise RuntimeError('All keys have insufficient quota.')

            key = self._rotate_key()
            client = self.clients[key]

            try:
                response = await client.messages.create(**data)
                response = json.loads(response.json())
                return response['content'][0]['text'].strip()
            except (anthropic.RateLimitError, anthropic.APIConnectionError) as e:
                print(f'API请求错误: {e}')
                await asyncio.sleep(5)
            except (httpcore.ProxyError, ProxyError) as e:
                print(f'代理服务器错误: {e}')
                await asyncio.sleep(5)
            except httpx.TimeoutException as e:
                print(f'请求超时: {e}')
                await asyncio.sleep(5)
            except Exception as error:
                # KeyboardInterrupt is not an Exception subclass and
                # propagates without an explicit handler.
                err_type, err_message = self._parse_api_error(error)
                if err_message == 'invalid x-api-key':
                    self.invalid_keys.add(key)
                    self._log_warning(f'invalid key: {key}')
                elif err_type == 'overloaded_error':
                    await asyncio.sleep(5)
                elif err_message == 'Internal server error':
                    await asyncio.sleep(5)
                elif err_message == self._NO_CREDIT_MESSAGE:
                    self.invalid_keys.add(key)
                    print(f'API has no quota: {key}, Valid keys: {len(self.keys) - len(self.invalid_keys)}')
                else:
                    # Unknown or unstructured errors are not retryable;
                    # surface the original exception unchanged.
                    raise
            max_num_retries += 1

        raise RuntimeError(
            'Calling Claude failed after retrying for ' f'{max_num_retries} times. Check the logs for ' 'details.'
        )
|
|