|
from __future__ import annotations |
|
|
|
import requests |
|
from aiohttp import ClientSession |
|
|
|
from .OpenaiAPI import OpenaiAPI |
|
from ...typing import AsyncResult, Messages, Cookies |
|
from ...requests.raise_for_status import raise_for_status |
|
from ...cookies import get_cookies |
|
|
|
class Cerebras(OpenaiAPI):
    """Provider for Cerebras Inference, built on the shared ``OpenaiAPI`` base.

    Adds two things on top of the base class: a model listing fetched from
    the Cerebras API (with a static fallback), and an API-key lookup that
    reads a demo key from the logged-in web session when no key is supplied.
    """

    label = "Cerebras Inference"
    url = "https://inference.cerebras.ai/"
    working = True
    default_model = "llama3.1-70b"
    # Used by get_models() whenever the remote listing cannot be obtained.
    fallback_models = [
        "llama3.1-70b",
        "llama3.1-8b",
    ]
    # Maps dashed model names onto the identifiers used by this provider.
    model_aliases = {"llama-3.1-70b": "llama3.1-70b", "llama-3.1-8b": "llama3.1-8b"}

    @classmethod
    def get_models(cls, api_key: str = None):
        """Return the available model names, fetching them once per class.

        The listing is requested from ``https://api.cerebras.ai/v1/models``
        and cached on ``cls.models``. Any failure (network error, bad status,
        unexpected payload) deliberately falls back to ``fallback_models``
        instead of raising.

        Args:
            api_key: Optional API key sent as a bearer token.

        Returns:
            The cached list of model names.
        """
        if not cls.models:
            try:
                headers = {}
                if api_key:
                    # Python f-strings interpolate with "{...}"; a "$" prefix
                    # would be sent literally and corrupt the token.
                    headers["authorization"] = f"Bearer {api_key}"
                response = requests.get("https://api.cerebras.ai/v1/models", headers=headers)
                raise_for_status(response)
                payload = response.json()
                # NOTE(review): an OpenAI-compatible listing would use
                # "data"/"id"; the "models"/"model" shape read by the previous
                # implementation is still accepted. Confirm against the live API.
                entries = payload.get("data") or payload.get("models") or []
                names = [entry.get("id") or entry.get("model") for entry in entries]
                names = [name for name in names if name]
                # Copy the fallback so later mutation of `models` cannot
                # alter `fallback_models`.
                cls.models = names or list(cls.fallback_models)
            except Exception:
                # Best effort: listing models must never break the provider.
                cls.models = list(cls.fallback_models)
        return cls.models

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        api_base: str = "https://api.cerebras.ai/v1",
        api_key: str = None,
        cookies: Cookies = None,
        **kwargs
    ) -> AsyncResult:
        """Yield response chunks produced by the parent implementation.

        When no ``api_key`` is given, a demo key is looked up from the
        ``inference.cerebras.ai`` auth-session endpoint, using the supplied
        ``cookies`` or, if none were supplied, the cookies returned by
        ``get_cookies(".cerebras.ai")``.

        Args:
            model: Model name forwarded to the parent implementation.
            messages: Conversation forwarded to the parent implementation.
            api_base: Base URL of the API.
            api_key: API key; when ``None`` the demo key lookup is attempted.
            cookies: Cookies for the session lookup (only used without a key).
            **kwargs: Passed through to the parent implementation.

        Yields:
            Each chunk yielded by ``super().create_async_generator``.
        """
        if api_key is None:
            # Caller-supplied cookies must be honoured; only load the stored
            # ones when nothing was passed in.
            if cookies is None:
                cookies = get_cookies(".cerebras.ai")
            async with ClientSession(cookies=cookies) as session:
                async with session.get("https://inference.cerebras.ai/api/auth/session") as response:
                    # NOTE(review): the helper is defined elsewhere; it is
                    # assumed to return an awaitable for async responses and
                    # None otherwise. Awaiting only a non-None result keeps
                    # both cases working. Confirm against the helper.
                    status_check = raise_for_status(response)
                    if status_check is not None:
                        await status_check
                    data = await response.json()
                    if data:
                        # "user" may be present but null, so guard before .get().
                        api_key = (data.get("user") or {}).get("demoApiKey")
        async for chunk in super().create_async_generator(
            model, messages,
            api_base=api_base,
            impersonate="chrome",
            api_key=api_key,
            headers={
                "User-Agent": "ex/JS 1.5.0",
            },
            **kwargs
        ):
            yield chunk
|
|