# ``import importlib`` alone does not guarantee ``importlib.util`` is loaded;
# import the submodule explicitly because _check_pkg uses find_spec.
import importlib.util
import os
from typing import Any

from colorama import Fore, Style, init

_SUPPORTED_PROVIDERS = {
    "openai",
    "anthropic",
    "azure_openai",
    "cohere",
    "google_vertexai",
    "google_genai",
    "fireworks",
    "ollama",
    "together",
    "mistralai",
    "huggingface",
    "groq",
    "bedrock",
    "dashscope",
    "xai",
    "deepseek",
    "litellm",
}


class GenericLLMProvider:
    """Wrap a LangChain chat model selected by provider name.

    The wrapped object is used through ``ainvoke`` (non-streaming) and
    ``astream`` (streaming); results are read from each result's ``.content``.
    """

    def __init__(self, llm):
        # The underlying LangChain chat model instance.
        self.llm = llm

    @classmethod
    def from_provider(cls, provider: str, **kwargs: Any):
        """Build a provider wrapper for the named LangChain backend.

        Args:
            provider: One of the names in ``_SUPPORTED_PROVIDERS``.
            **kwargs: Forwarded to the provider's chat-model constructor
                (some providers remap ``model``/``model_name`` first).

        Returns:
            A ``GenericLLMProvider`` wrapping the constructed chat model.

        Raises:
            ImportError: If the provider's integration package is not installed.
            ValueError: If ``provider`` is not supported.
        """
        if provider == "openai":
            _check_pkg("langchain_openai")
            from langchain_openai import ChatOpenAI

            llm = ChatOpenAI(**kwargs)
        elif provider == "anthropic":
            _check_pkg("langchain_anthropic")
            from langchain_anthropic import ChatAnthropic

            llm = ChatAnthropic(**kwargs)
        elif provider == "azure_openai":
            _check_pkg("langchain_openai")
            from langchain_openai import AzureChatOpenAI

            if "model" in kwargs:
                # Default the Azure deployment name to the model name; an
                # explicit ``azure_deployment`` in kwargs still wins.
                model_name = kwargs.get("model", None)
                kwargs = {"azure_deployment": model_name, **kwargs}
            llm = AzureChatOpenAI(**kwargs)
        elif provider == "cohere":
            _check_pkg("langchain_cohere")
            from langchain_cohere import ChatCohere

            llm = ChatCohere(**kwargs)
        elif provider == "google_vertexai":
            _check_pkg("langchain_google_vertexai")
            from langchain_google_vertexai import ChatVertexAI

            llm = ChatVertexAI(**kwargs)
        elif provider == "google_genai":
            _check_pkg("langchain_google_genai")
            from langchain_google_genai import ChatGoogleGenerativeAI

            llm = ChatGoogleGenerativeAI(**kwargs)
        elif provider == "fireworks":
            _check_pkg("langchain_fireworks")
            from langchain_fireworks import ChatFireworks

            llm = ChatFireworks(**kwargs)
        elif provider == "ollama":
            # Check the package that is actually imported below.
            _check_pkg("langchain_ollama")
            from langchain_ollama import ChatOllama

            # An explicit base_url wins; otherwise use OLLAMA_BASE_URL when set,
            # and leave ChatOllama to its own default when neither is given.
            if "base_url" not in kwargs and "OLLAMA_BASE_URL" in os.environ:
                kwargs["base_url"] = os.environ["OLLAMA_BASE_URL"]
            llm = ChatOllama(**kwargs)
        elif provider == "together":
            _check_pkg("langchain_together")
            from langchain_together import ChatTogether

            llm = ChatTogether(**kwargs)
        elif provider == "mistralai":
            _check_pkg("langchain_mistralai")
            from langchain_mistralai import ChatMistralAI

            llm = ChatMistralAI(**kwargs)
        elif provider == "huggingface":
            _check_pkg("langchain_huggingface")
            from langchain_huggingface import ChatHuggingFace

            if "model" in kwargs or "model_name" in kwargs:
                model_id = _pop_model_id(kwargs)
                kwargs = {"model_id": model_id, **kwargs}
            llm = ChatHuggingFace(**kwargs)
        elif provider == "groq":
            _check_pkg("langchain_groq")
            from langchain_groq import ChatGroq

            llm = ChatGroq(**kwargs)
        elif provider == "bedrock":
            _check_pkg("langchain_aws")
            from langchain_aws import ChatBedrock

            if "model" in kwargs or "model_name" in kwargs:
                model_id = _pop_model_id(kwargs)
                # Every remaining kwarg is forwarded as Bedrock model_kwargs.
                kwargs = {"model_id": model_id, "model_kwargs": kwargs}
            llm = ChatBedrock(**kwargs)
        elif provider == "dashscope":
            _check_pkg("langchain_dashscope")
            from langchain_dashscope import ChatDashScope

            llm = ChatDashScope(**kwargs)
        elif provider == "xai":
            _check_pkg("langchain_xai")
            from langchain_xai import ChatXAI

            llm = ChatXAI(**kwargs)
        elif provider == "deepseek":
            _check_pkg("langchain_openai")
            from langchain_openai import ChatOpenAI

            # DeepSeek exposes an OpenAI-compatible API. Fill in the endpoint and
            # key only when the caller did not supply them, which avoids a
            # duplicate-keyword TypeError.
            kwargs.setdefault("openai_api_base", "https://api.deepseek.com")
            if "openai_api_key" not in kwargs:
                kwargs["openai_api_key"] = os.environ["DEEPSEEK_API_KEY"]
            llm = ChatOpenAI(**kwargs)
        elif provider == "litellm":
            _check_pkg("langchain_community")
            from langchain_community.chat_models.litellm import ChatLiteLLM

            llm = ChatLiteLLM(**kwargs)
        else:
            # Sorted so the message is stable across runs (set order is not).
            supported = ", ".join(sorted(_SUPPORTED_PROVIDERS))
            raise ValueError(
                f"Unsupported {provider}.\n\nSupported model providers are: {supported}"
            )

        return cls(llm)

    async def get_chat_response(self, messages, stream, websocket=None):
        """Return the model's reply text for ``messages``.

        When ``stream`` is false, the whole reply is fetched with ``ainvoke``.
        Otherwise it is streamed through ``stream_response``, which forwards
        partial output to ``websocket`` (or stdout) and returns the full text.
        """
        if not stream:
            # Asynchronous single-shot invocation.
            output = await self.llm.ainvoke(messages)
            return output.content
        return await self.stream_response(messages, websocket)

    async def stream_response(self, messages, websocket=None):
        """Stream the reply, emitting output whenever a newline is received.

        Text is buffered until the buffer contains a newline, then sent via
        ``_send_output``; any trailing text without a newline is flushed at
        the end. Returns the full concatenated reply.
        """
        paragraph = ""
        response = ""

        async for chunk in self.llm.astream(messages):
            content = chunk.content
            if content is not None:
                response += content
                paragraph += content
                if "\n" in paragraph:
                    await self._send_output(paragraph, websocket)
                    paragraph = ""

        if paragraph:
            await self._send_output(paragraph, websocket)

        return response

    async def _send_output(self, content, websocket=None):
        """Send ``content`` as a JSON "report" message, or print it in green."""
        if websocket is not None:
            await websocket.send_json({"type": "report", "output": content})
        else:
            print(f"{Fore.GREEN}{content}{Style.RESET_ALL}")


def _pop_model_id(kwargs: dict) -> Any:
    """Remove both ``model`` and ``model_name`` from kwargs; return the chosen one.

    ``model`` is preferred when truthy. Both keys are always removed so that
    a leftover ``model_name`` is not forwarded to the model constructor.
    """
    model = kwargs.pop("model", None)
    model_name = kwargs.pop("model_name", None)
    return model or model_name


def _check_pkg(pkg: str) -> None:
    """Raise ImportError with a pip install hint if ``pkg`` cannot be found.

    ``pkg`` is a top-level module name; its pip name is derived by replacing
    underscores with hyphens.
    """
    if not importlib.util.find_spec(pkg):
        pkg_kebab = pkg.replace("_", "-")
        # Initialize colorama so the colored message renders on all terminals.
        init(autoreset=True)
        raise ImportError(
            Fore.RED + f"Unable to import {pkg_kebab}. Please install with "
            f"`pip install -U {pkg_kebab}`"
        )