Spaces:
Running
Running
import importlib | |
from typing import Any | |
from colorama import Fore, Style, init | |
import os | |
# Provider identifiers that GenericLLMProvider.from_provider has a branch for.
# The set is also rendered into the error message raised for an unknown
# provider, so keep it in sync with the branches there.
_SUPPORTED_PROVIDERS = {
    "anthropic",
    "azure_openai",
    "bedrock",
    "cohere",
    "dashscope",
    "deepseek",
    "fireworks",
    "google_genai",
    "google_vertexai",
    "groq",
    "huggingface",
    "litellm",
    "mistralai",
    "ollama",
    "openai",
    "together",
    "xai",
}
class GenericLLMProvider:
    """Wrapper around a LangChain chat model chosen by provider name.

    Build instances with :meth:`from_provider`, then call
    :meth:`get_chat_response` to obtain a completion either in one piece or
    streamed paragraph by paragraph.
    """

    def __init__(self, llm):
        """Store the already-constructed chat model used for every request."""
        self.llm = llm

    @staticmethod
    def _pop_model_id(kwargs: dict) -> Any:
        """Remove ``model`` and ``model_name`` from *kwargs*; return the first truthy one.

        Both keys are always removed so that neither leaks into the keyword
        arguments forwarded to the underlying chat model.
        """
        model = kwargs.pop("model", None)
        model_name = kwargs.pop("model_name", None)
        return model or model_name

    @classmethod
    def from_provider(cls, provider: str, **kwargs: Any):
        """Create a provider wrapper for the chat model named by *provider*.

        The matching LangChain integration is imported lazily, so only the
        package for the requested provider has to be installed.

        Args:
            provider: One of the identifiers in ``_SUPPORTED_PROVIDERS``.
            **kwargs: Forwarded to the chat model constructor. Some providers
                remap ``model`` / ``model_name`` to the keyword their
                constructor expects (see the inline comments).

        Returns:
            A new instance of ``cls`` wrapping the constructed chat model.

        Raises:
            ValueError: If *provider* is not supported.
            ImportError: If the integration package is not installed.
            KeyError: If ``OLLAMA_BASE_URL`` (ollama) or ``DEEPSEEK_API_KEY``
                (deepseek) is missing from the environment.
        """
        if provider == "openai":
            _check_pkg("langchain_openai")
            from langchain_openai import ChatOpenAI

            llm = ChatOpenAI(**kwargs)
        elif provider == "anthropic":
            _check_pkg("langchain_anthropic")
            from langchain_anthropic import ChatAnthropic

            llm = ChatAnthropic(**kwargs)
        elif provider == "azure_openai":
            _check_pkg("langchain_openai")
            from langchain_openai import AzureChatOpenAI

            if "model" in kwargs:
                # Use the model name as the deployment name unless the caller
                # passed an explicit ``azure_deployment`` (later keys win).
                kwargs = {"azure_deployment": kwargs["model"], **kwargs}
            llm = AzureChatOpenAI(**kwargs)
        elif provider == "cohere":
            _check_pkg("langchain_cohere")
            from langchain_cohere import ChatCohere

            llm = ChatCohere(**kwargs)
        elif provider == "google_vertexai":
            _check_pkg("langchain_google_vertexai")
            from langchain_google_vertexai import ChatVertexAI

            llm = ChatVertexAI(**kwargs)
        elif provider == "google_genai":
            _check_pkg("langchain_google_genai")
            from langchain_google_genai import ChatGoogleGenerativeAI

            llm = ChatGoogleGenerativeAI(**kwargs)
        elif provider == "fireworks":
            _check_pkg("langchain_fireworks")
            from langchain_fireworks import ChatFireworks

            llm = ChatFireworks(**kwargs)
        elif provider == "ollama":
            # Check the package that is actually imported below.
            _check_pkg("langchain_ollama")
            from langchain_ollama import ChatOllama

            llm = ChatOllama(base_url=os.environ["OLLAMA_BASE_URL"], **kwargs)
        elif provider == "together":
            _check_pkg("langchain_together")
            from langchain_together import ChatTogether

            llm = ChatTogether(**kwargs)
        elif provider == "mistralai":
            _check_pkg("langchain_mistralai")
            from langchain_mistralai import ChatMistralAI

            llm = ChatMistralAI(**kwargs)
        elif provider == "huggingface":
            _check_pkg("langchain_huggingface")
            from langchain_huggingface import ChatHuggingFace

            if "model" in kwargs or "model_name" in kwargs:
                model_id = cls._pop_model_id(kwargs)
                # An explicit ``model_id`` in kwargs still takes precedence.
                kwargs = {"model_id": model_id, **kwargs}
            llm = ChatHuggingFace(**kwargs)
        elif provider == "groq":
            _check_pkg("langchain_groq")
            from langchain_groq import ChatGroq

            llm = ChatGroq(**kwargs)
        elif provider == "bedrock":
            _check_pkg("langchain_aws")
            from langchain_aws import ChatBedrock

            if "model" in kwargs or "model_name" in kwargs:
                model_id = cls._pop_model_id(kwargs)
                # Everything except the model identifier is passed through
                # as ``model_kwargs``.
                kwargs = {"model_id": model_id, "model_kwargs": kwargs}
            llm = ChatBedrock(**kwargs)
        elif provider == "dashscope":
            _check_pkg("langchain_dashscope")
            from langchain_dashscope import ChatDashScope

            llm = ChatDashScope(**kwargs)
        elif provider == "xai":
            _check_pkg("langchain_xai")
            from langchain_xai import ChatXAI

            llm = ChatXAI(**kwargs)
        elif provider == "deepseek":
            _check_pkg("langchain_openai")
            from langchain_openai import ChatOpenAI

            # DeepSeek is reached through the OpenAI client pointed at its endpoint.
            llm = ChatOpenAI(
                openai_api_base='https://api.deepseek.com',
                openai_api_key=os.environ["DEEPSEEK_API_KEY"],
                **kwargs
            )
        elif provider == "litellm":
            _check_pkg("langchain_community")
            from langchain_community.chat_models.litellm import ChatLiteLLM

            llm = ChatLiteLLM(**kwargs)
        else:
            # Sort so the message is stable across runs (set order is not).
            supported = ", ".join(sorted(_SUPPORTED_PROVIDERS))
            raise ValueError(
                f"Unsupported {provider}.\n\nSupported model providers are: {supported}"
            )
        return cls(llm)

    async def get_chat_response(self, messages, stream: bool, websocket=None):
        """Return the model's reply to *messages*.

        Args:
            messages: Conversation passed unchanged to the chat model.
            stream: When falsy, await a single ``ainvoke`` call and return its
                ``content``; otherwise delegate to :meth:`stream_response`.
            websocket: Optional sink for streamed output; only used when
                streaming.

        Returns:
            The reply content (the full accumulated text when streaming).
        """
        if stream:
            return await self.stream_response(messages, websocket)
        # Non-streaming: one asynchronous invocation of the model.
        output = await self.llm.ainvoke(messages)
        return output.content

    async def stream_response(self, messages, websocket=None) -> str:
        """Stream the reply, emitting buffered text each time a newline arrives.

        Chunks from ``self.llm.astream`` are accumulated; whenever the buffer
        contains a newline it is flushed through :meth:`_send_output`. Any
        remaining text is flushed once the stream ends.

        Returns:
            The concatenation of every non-``None`` chunk content.
        """
        paragraph = ""
        response = ""
        async for chunk in self.llm.astream(messages):
            content = chunk.content
            if content is None:
                continue
            response += content
            paragraph += content
            if "\n" in paragraph:
                await self._send_output(paragraph, websocket)
                paragraph = ""
        # Flush the trailing text that never reached a newline.
        if paragraph:
            await self._send_output(paragraph, websocket)
        return response

    async def _send_output(self, content, websocket=None) -> None:
        """Send *content* to *websocket* as a report message, or print it in green."""
        if websocket is not None:
            await websocket.send_json({"type": "report", "output": content})
        else:
            print(f"{Fore.GREEN}{content}{Style.RESET_ALL}")
def _check_pkg(pkg: str) -> None:
    """Raise a coloured ``ImportError`` when *pkg* cannot be located.

    Args:
        pkg: Importable module name, e.g. ``"langchain_openai"``.

    Raises:
        ImportError: If no import spec is found for *pkg*. The message
            suggests the pip command using the name with underscores
            replaced by hyphens.
    """
    # ``import importlib`` alone does not guarantee the ``util`` submodule is
    # loaded, so import it explicitly before using ``find_spec``.
    import importlib.util

    if importlib.util.find_spec(pkg) is not None:
        return

    pkg_kebab = pkg.replace("_", "-")
    # Initialise colorama so the red error text renders and resets correctly.
    init(autoreset=True)
    raise ImportError(
        Fore.RED + f"Unable to import {pkg_kebab}. Please install with "
        f"`pip install -U {pkg_kebab}`"
    )