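"""Generic LLM provider factory.

Wraps a LangChain chat model behind a single interface, with asynchronous
invocation and optional streaming of output over a websocket.
"""
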
import importlib.util
import os
from typing import Any

from colorama import Fore, Style, init

_SUPPORTED_PROVIDERS = {
    "openai",
    "anthropic",
    "azure_openai",
    "cohere",
    "google_vertexai",
    "google_genai",
    "fireworks",
    "ollama",
    "together",
    "mistralai",
    "huggingface",
    "groq",
    "bedrock",
    "dashscope",
    "xai",
    "deepseek",
    "litellm",
}


class GenericLLMProvider:
    """Thin wrapper around a LangChain chat model instance."""

    def __init__(self, llm):
        self.llm = llm

    @classmethod
    def from_provider(cls, provider: str, **kwargs: Any):
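        """Instantiate the chat model for `provider`, forwarding `kwargs`
        to the underlying LangChain constructor.
        """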
        if provider == "openai":
            _check_pkg("langchain_openai")
            from langchain_openai import ChatOpenAI
            llm = ChatOpenAI(**kwargs)
        elif provider == "anthropic":
            _check_pkg("langchain_anthropic")
            from langchain_anthropic import ChatAnthropic
            llm = ChatAnthropic(**kwargs)
elif provider == "azure_openai":
_check_pkg("langchain_openai")
from langchain_openai import AzureChatOpenAI
if "model" in kwargs:
model_name = kwargs.get("model", None)
kwargs = {"azure_deployment": model_name, **kwargs}
llm = AzureChatOpenAI(**kwargs)
elif provider == "cohere":
_check_pkg("langchain_cohere")
from langchain_cohere import ChatCohere
llm = ChatCohere(**kwargs)
elif provider == "google_vertexai":
_check_pkg("langchain_google_vertexai")
from langchain_google_vertexai import ChatVertexAI
llm = ChatVertexAI(**kwargs)
elif provider == "google_genai":
_check_pkg("langchain_google_genai")
from langchain_google_genai import ChatGoogleGenerativeAI
llm = ChatGoogleGenerativeAI(**kwargs)
elif provider == "fireworks":
_check_pkg("langchain_fireworks")
from langchain_fireworks import ChatFireworks
llm = ChatFireworks(**kwargs)
elif provider == "ollama":
_check_pkg("langchain_community")
from langchain_ollama import ChatOllama
llm = ChatOllama(base_url=os.environ["OLLAMA_BASE_URL"], **kwargs)
elif provider == "together":
_check_pkg("langchain_together")
from langchain_together import ChatTogether
llm = ChatTogether(**kwargs)
elif provider == "mistralai":
_check_pkg("langchain_mistralai")
from langchain_mistralai import ChatMistralAI
llm = ChatMistralAI(**kwargs)
elif provider == "huggingface":
_check_pkg("langchain_huggingface")
from langchain_huggingface import ChatHuggingFace
if "model" in kwargs or "model_name" in kwargs:
model_id = kwargs.pop("model", None) or kwargs.pop("model_name", None)
kwargs = {"model_id": model_id, **kwargs}
llm = ChatHuggingFace(**kwargs)
elif provider == "groq":
_check_pkg("langchain_groq")
from langchain_groq import ChatGroq
llm = ChatGroq(**kwargs)
elif provider == "bedrock":
_check_pkg("langchain_aws")
from langchain_aws import ChatBedrock
if "model" in kwargs or "model_name" in kwargs:
model_id = kwargs.pop("model", None) or kwargs.pop("model_name", None)
kwargs = {"model_id": model_id, "model_kwargs": kwargs}
llm = ChatBedrock(**kwargs)
elif provider == "dashscope":
_check_pkg("langchain_dashscope")
from langchain_dashscope import ChatDashScope
llm = ChatDashScope(**kwargs)
elif provider == "xai":
_check_pkg("langchain_xai")
from langchain_xai import ChatXAI
llm = ChatXAI(**kwargs)
elif provider == "deepseek":
_check_pkg("langchain_openai")
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(openai_api_base='https://api.deepseek.com',
openai_api_key=os.environ["DEEPSEEK_API_KEY"],
**kwargs
)
elif provider == "litellm":
_check_pkg("langchain_community")
from langchain_community.chat_models.litellm import ChatLiteLLM
llm = ChatLiteLLM(**kwargs)
        else:
            supported = ", ".join(_SUPPORTED_PROVIDERS)
            raise ValueError(
                f"Unsupported provider: {provider}.\n\n"
                f"Supported model providers are: {supported}"
            )
        return cls(llm)

    async def get_chat_response(self, messages, stream, websocket=None):
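        """Return the model's reply, streaming it to `websocket` when `stream` is True."""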
        if not stream:
            # Await the full completion in a single asynchronous call.
            output = await self.llm.ainvoke(messages)
            return output.content
        else:
            return await self.stream_response(messages, websocket)

    async def stream_response(self, messages, websocket=None):
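        """Stream the reply chunk by chunk, flushing buffered output at each newline."""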
paragraph = ""
response = ""
# Streaming the response using the chain astream method from langchain
async for chunk in self.llm.astream(messages):
content = chunk.content
if content is not None:
response += content
paragraph += content
if "\n" in paragraph:
await self._send_output(paragraph, websocket)
paragraph = ""
if paragraph:
await self._send_output(paragraph, websocket)
return response

    async def _send_output(self, content, websocket=None):
        if websocket is not None:
            await websocket.send_json({"type": "report", "output": content})
        else:
            print(f"{Fore.GREEN}{content}{Style.RESET_ALL}")


def _check_pkg(pkg: str) -> None:
    if not importlib.util.find_spec(pkg):
        pkg_kebab = pkg.replace("_", "-")
        # colorama is imported at module level; init enables ANSI colors on all platforms.
        init(autoreset=True)
        raise ImportError(
            Fore.RED + f"Unable to import {pkg_kebab}. Please install it with "
            f"`pip install -U {pkg_kebab}`"
        )
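

if __name__ == "__main__":
    # Illustrative sketch only: the provider name, model id, and prompt below
    # are assumptions for demonstration, not values defined by this module.
    import asyncio

    provider = GenericLLMProvider.from_provider(
        "openai", model="gpt-4o-mini", temperature=0
    )
    reply = asyncio.run(
        provider.get_chat_response([("human", "Say hello in one word.")], stream=False)
    )
    print(reply)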