Shreyas094's picture
Upload 528 files
372531f verified
import importlib
from typing import Any
from colorama import Fore, Style, init
import os
_SUPPORTED_PROVIDERS = {
"openai",
"anthropic",
"azure_openai",
"cohere",
"google_vertexai",
"google_genai",
"fireworks",
"ollama",
"together",
"mistralai",
"huggingface",
"groq",
"bedrock",
"dashscope",
"xai",
"deepseek",
"litellm",
}
class GenericLLMProvider:
    """Wrapper around a LangChain chat model chosen by provider name.

    Instances are normally created with :meth:`from_provider`, which checks
    that the matching integration package is installed, imports it lazily and
    constructs the chat model from the supplied keyword arguments.
    """

    def __init__(self, llm):
        # The wrapped chat model; only its ``ainvoke`` and ``astream``
        # coroutines are used by this class.
        self.llm = llm

    @staticmethod
    def _pop_model_id(kwargs: dict) -> Any:
        """Remove ``model`` and ``model_name`` from *kwargs* and return the id.

        ``model`` takes precedence when both are present (falling back to
        ``model_name`` when ``model`` is falsy). Both keys are always removed
        so a leftover ``model_name`` is never forwarded to the chat model.
        """
        model = kwargs.pop("model", None)
        model_name = kwargs.pop("model_name", None)
        return model or model_name

    @classmethod
    def from_provider(cls, provider: str, **kwargs: Any):
        """Build a ``GenericLLMProvider`` for *provider*.

        Args:
            provider: One of the names in ``_SUPPORTED_PROVIDERS``.
            **kwargs: Forwarded to the provider's LangChain chat model class.
                Some providers remap ``model``/``model_name`` first (see the
                azure_openai, huggingface and bedrock branches).

        Returns:
            A ``GenericLLMProvider`` wrapping the constructed chat model.

        Raises:
            ImportError: If the provider's integration package is missing.
            ValueError: If *provider* is not a supported name.
            KeyError: For ``deepseek`` when no ``openai_api_key`` is passed
                and the ``DEEPSEEK_API_KEY`` environment variable is unset.
        """
        if provider == "openai":
            _check_pkg("langchain_openai")
            from langchain_openai import ChatOpenAI

            llm = ChatOpenAI(**kwargs)
        elif provider == "anthropic":
            _check_pkg("langchain_anthropic")
            from langchain_anthropic import ChatAnthropic

            llm = ChatAnthropic(**kwargs)
        elif provider == "azure_openai":
            _check_pkg("langchain_openai")
            from langchain_openai import AzureChatOpenAI

            if "model" in kwargs:
                # Default the deployment name to the model name; an explicit
                # ``azure_deployment`` in kwargs still overrides it.
                kwargs = {"azure_deployment": kwargs["model"], **kwargs}
            llm = AzureChatOpenAI(**kwargs)
        elif provider == "cohere":
            _check_pkg("langchain_cohere")
            from langchain_cohere import ChatCohere

            llm = ChatCohere(**kwargs)
        elif provider == "google_vertexai":
            _check_pkg("langchain_google_vertexai")
            from langchain_google_vertexai import ChatVertexAI

            llm = ChatVertexAI(**kwargs)
        elif provider == "google_genai":
            _check_pkg("langchain_google_genai")
            from langchain_google_genai import ChatGoogleGenerativeAI

            llm = ChatGoogleGenerativeAI(**kwargs)
        elif provider == "fireworks":
            _check_pkg("langchain_fireworks")
            from langchain_fireworks import ChatFireworks

            llm = ChatFireworks(**kwargs)
        elif provider == "ollama":
            # ChatOllama is imported from langchain_ollama, so that is the
            # package whose presence must be checked.
            _check_pkg("langchain_ollama")
            from langchain_ollama import ChatOllama

            # An explicit base_url wins; otherwise use OLLAMA_BASE_URL when it
            # is set, and leave ChatOllama's own default in place when not.
            if "base_url" not in kwargs and "OLLAMA_BASE_URL" in os.environ:
                kwargs["base_url"] = os.environ["OLLAMA_BASE_URL"]
            llm = ChatOllama(**kwargs)
        elif provider == "together":
            _check_pkg("langchain_together")
            from langchain_together import ChatTogether

            llm = ChatTogether(**kwargs)
        elif provider == "mistralai":
            _check_pkg("langchain_mistralai")
            from langchain_mistralai import ChatMistralAI

            llm = ChatMistralAI(**kwargs)
        elif provider == "huggingface":
            _check_pkg("langchain_huggingface")
            from langchain_huggingface import ChatHuggingFace

            if "model" in kwargs or "model_name" in kwargs:
                model_id = cls._pop_model_id(kwargs)
                kwargs = {"model_id": model_id, **kwargs}
            llm = ChatHuggingFace(**kwargs)
        elif provider == "groq":
            _check_pkg("langchain_groq")
            from langchain_groq import ChatGroq

            llm = ChatGroq(**kwargs)
        elif provider == "bedrock":
            _check_pkg("langchain_aws")
            from langchain_aws import ChatBedrock

            if "model" in kwargs or "model_name" in kwargs:
                model_id = cls._pop_model_id(kwargs)
                # Every remaining kwarg is forwarded inside ``model_kwargs``.
                kwargs = {"model_id": model_id, "model_kwargs": kwargs}
            llm = ChatBedrock(**kwargs)
        elif provider == "dashscope":
            _check_pkg("langchain_dashscope")
            from langchain_dashscope import ChatDashScope

            llm = ChatDashScope(**kwargs)
        elif provider == "xai":
            _check_pkg("langchain_xai")
            from langchain_xai import ChatXAI

            llm = ChatXAI(**kwargs)
        elif provider == "deepseek":
            _check_pkg("langchain_openai")
            from langchain_openai import ChatOpenAI

            # DeepSeek is reached through ChatOpenAI pointed at its base URL.
            # Caller-supplied values take precedence instead of triggering a
            # duplicate-keyword TypeError.
            kwargs.setdefault("openai_api_base", "https://api.deepseek.com")
            if "openai_api_key" not in kwargs:
                kwargs["openai_api_key"] = os.environ["DEEPSEEK_API_KEY"]
            llm = ChatOpenAI(**kwargs)
        elif provider == "litellm":
            _check_pkg("langchain_community")
            from langchain_community.chat_models.litellm import ChatLiteLLM

            llm = ChatLiteLLM(**kwargs)
        else:
            # Sort so the message is stable across runs (set order is not).
            supported = ", ".join(sorted(_SUPPORTED_PROVIDERS))
            raise ValueError(
                f"Unsupported {provider}.\n\nSupported model providers are: {supported}"
            )
        return cls(llm)

    async def get_chat_response(self, messages, stream, websocket=None):
        """Return the model's reply to *messages* as a string.

        When *stream* is false the whole reply is fetched with ``ainvoke``;
        otherwise it is streamed (see :meth:`stream_response`) and forwarded
        to *websocket* or stdout as it arrives.
        """
        if not stream:
            output = await self.llm.ainvoke(messages)
            return output.content
        return await self.stream_response(messages, websocket)

    async def stream_response(self, messages, websocket=None):
        """Stream the reply, emitting it paragraph-by-paragraph.

        Text is buffered until a chunk introduces a newline, then the buffer
        is sent via :meth:`_send_output`; any trailing text is flushed at the
        end. Chunks whose ``content`` is ``None`` are skipped.

        Returns:
            The full concatenated reply.
        """
        paragraph = ""
        response = ""
        async for chunk in self.llm.astream(messages):
            content = chunk.content
            if content is not None:
                response += content
                paragraph += content
                if "\n" in paragraph:
                    await self._send_output(paragraph, websocket)
                    paragraph = ""
        if paragraph:
            await self._send_output(paragraph, websocket)
        return response

    async def _send_output(self, content, websocket=None):
        """Send *content* as a ``report`` JSON message, or print it in green."""
        if websocket is not None:
            await websocket.send_json({"type": "report", "output": content})
        else:
            print(f"{Fore.GREEN}{content}{Style.RESET_ALL}")
def _check_pkg(pkg: str) -> None:
if not importlib.util.find_spec(pkg):
pkg_kebab = pkg.replace("_", "-")
# Import colorama and initialize it
init(autoreset=True)
# Use Fore.RED to color the error message
raise ImportError(
Fore.RED + f"Unable to import {pkg_kebab}. Please install with "
f"`pip install -U {pkg_kebab}`"
)