import os
from typing import Any, Optional

from openai import AsyncOpenAI, OpenAI

from langchain_anthropic import ChatAnthropic
from langchain_aws import ChatBedrock
from langchain_core.language_models.base import LanguageModelInput
from langchain_core.messages import (
    AIMessage,
    HumanMessage,
    SystemMessage,
    convert_to_messages,
)
from langchain_core.runnables import RunnableConfig
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_ibm import ChatWatsonx
from langchain_mistralai import ChatMistralAI
from langchain_ollama import ChatOllama
from langchain_openai import AzureChatOpenAI, ChatOpenAI

from src.utils import config


class DeepSeekR1ChatOpenAI(ChatOpenAI):
    """ChatOpenAI subclass for DeepSeek R1 that preserves `reasoning_content`.

    DeepSeek's OpenAI-compatible API returns the model's reasoning trace in a
    separate `reasoning_content` field, which the stock ChatOpenAI discards;
    this wrapper calls the endpoint directly and attaches it to the AIMessage.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Raw clients for the OpenAI-compatible DeepSeek endpoint; the async
        # client keeps `ainvoke` non-blocking instead of reusing the sync one.
        self.client = OpenAI(
            base_url=kwargs.get("base_url"),
            api_key=kwargs.get("api_key"),
        )
        self.async_client = AsyncOpenAI(
            base_url=kwargs.get("base_url"),
            api_key=kwargs.get("api_key"),
        )

    @staticmethod
    def _to_message_history(input: LanguageModelInput) -> list[dict]:
        """Convert LangChain input into OpenAI-style role/content dicts."""
        if isinstance(input, str):
            messages = [HumanMessage(content=input)]
        else:
            messages = convert_to_messages(input)
        message_history = []
        for message in messages:
            if isinstance(message, SystemMessage):
                role = "system"
            elif isinstance(message, AIMessage):
                role = "assistant"
            else:
                role = "user"
            message_history.append({"role": role, "content": message.content})
        return message_history

    async def ainvoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        # `config` and `stop` are accepted for Runnable compatibility but are
        # not forwarded to the completion call.
        response = await self.async_client.chat.completions.create(
            model=self.model_name,
            messages=self._to_message_history(input),
        )
        reasoning_content = response.choices[0].message.reasoning_content
        content = response.choices[0].message.content
        return AIMessage(content=content, reasoning_content=reasoning_content)

    def invoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=self._to_message_history(input),
        )
        reasoning_content = response.choices[0].message.reasoning_content
        content = response.choices[0].message.content
        return AIMessage(content=content, reasoning_content=reasoning_content)
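

# Illustrative usage only; the endpoint, key, and prompt below are
# assumptions, not values fixed by this module:
#     llm = DeepSeekR1ChatOpenAI(model="deepseek-reasoner",
#                                base_url="https://api.deepseek.com/v1",
#                                api_key="sk-...")
#     msg = llm.invoke("Why is the sky blue?")
#     msg.reasoning_content  # the model's chain of thought
#     msg.content            # the final answer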


class DeepSeekR1ChatOllama(ChatOllama):
    """ChatOllama subclass for DeepSeek R1 that splits out <think> reasoning.

    Ollama returns R1's reasoning inline as `<think>...</think>` followed by
    the final answer; this wrapper separates the two parts.
    """

    @staticmethod
    def _split_reasoning(org_content: str) -> AIMessage:
        if "</think>" in org_content:
            reasoning_content, _, content = org_content.partition("</think>")
            reasoning_content = reasoning_content.replace("<think>", "")
        else:
            # Defensive fallback: no reasoning tags in the output.
            reasoning_content, content = "", org_content
        # Some R1 checkpoints prefix the answer with a "**JSON Response:**"
        # banner; keep only what follows it.
        if "**JSON Response:**" in content:
            content = content.split("**JSON Response:**")[-1]
        return AIMessage(content=content, reasoning_content=reasoning_content)

    async def ainvoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        org_ai_message = await super().ainvoke(input=input)
        return self._split_reasoning(org_ai_message.content)

    def invoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> AIMessage:
        org_ai_message = super().invoke(input=input)
        return self._split_reasoning(org_ai_message.content)
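

# Shape of the split, on a hypothetical raw completion:
#     "<think>The user greets me ...</think>Hello!"
# becomes reasoning_content="The user greets me ..." and content="Hello!".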


def get_llm_model(provider: str, **kwargs):
    """Instantiate a LangChain chat model for the given provider.

    :param provider: provider id, e.g. "openai", "anthropic", or "ollama".
    :param kwargs: optional overrides such as model_name, temperature,
        base_url, api_key, num_ctx, and num_predict.
    :return: a configured LangChain chat model instance.
    """
    # Ollama and Bedrock resolve credentials on their own; every other
    # provider needs an API key from kwargs or from <PROVIDER>_API_KEY
    # (e.g. "openai" -> OPENAI_API_KEY).
    if provider not in ["ollama", "bedrock"]:
        env_var = f"{provider.upper()}_API_KEY"
        api_key = kwargs.get("api_key", "") or os.getenv(env_var, "")
        if not api_key:
            provider_display = config.PROVIDER_DISPLAY_NAMES.get(provider, provider.upper())
            error_msg = f"💥 {provider_display} API key not found! 🔑 Please set the `{env_var}` environment variable or provide it in the UI."
            raise ValueError(error_msg)
        kwargs["api_key"] = api_key

    if provider == "anthropic":
        base_url = kwargs.get("base_url") or "https://api.anthropic.com"
        return ChatAnthropic(
            model=kwargs.get("model_name", "claude-3-5-sonnet-20241022"),
            temperature=kwargs.get("temperature", 0.0),
            base_url=base_url,
            api_key=api_key,
        )
    elif provider == "mistral":
        # The generic check above already resolved api_key, so the former
        # per-branch re-read was dead code.
        base_url = kwargs.get("base_url") or os.getenv("MISTRAL_ENDPOINT", "https://api.mistral.ai/v1")
        return ChatMistralAI(
            model=kwargs.get("model_name", "mistral-large-latest"),
            temperature=kwargs.get("temperature", 0.0),
            base_url=base_url,
            api_key=api_key,
        )
elif provider == "openai":
|
|
if not kwargs.get("base_url", ""):
|
|
base_url = os.getenv("OPENAI_ENDPOINT", "https://api.openai.com/v1")
|
|
else:
|
|
base_url = kwargs.get("base_url")
|
|
|
|
return ChatOpenAI(
|
|
model=kwargs.get("model_name", "gpt-4o"),
|
|
temperature=kwargs.get("temperature", 0.0),
|
|
base_url=base_url,
|
|
api_key=api_key,
|
|
)
|
|
elif provider == "grok":
|
|
if not kwargs.get("base_url", ""):
|
|
base_url = os.getenv("GROK_ENDPOINT", "https://api.x.ai/v1")
|
|
else:
|
|
base_url = kwargs.get("base_url")
|
|
|
|
return ChatOpenAI(
|
|
model=kwargs.get("model_name", "grok-3"),
|
|
temperature=kwargs.get("temperature", 0.0),
|
|
base_url=base_url,
|
|
api_key=api_key,
|
|
)
|
|
elif provider == "deepseek":
|
|
if not kwargs.get("base_url", ""):
|
|
base_url = os.getenv("DEEPSEEK_ENDPOINT", "")
|
|
else:
|
|
base_url = kwargs.get("base_url")
|
|
|
|
if kwargs.get("model_name", "deepseek-chat") == "deepseek-reasoner":
|
|
return DeepSeekR1ChatOpenAI(
|
|
model=kwargs.get("model_name", "deepseek-reasoner"),
|
|
temperature=kwargs.get("temperature", 0.0),
|
|
base_url=base_url,
|
|
api_key=api_key,
|
|
)
|
|
else:
|
|
return ChatOpenAI(
|
|
model=kwargs.get("model_name", "deepseek-chat"),
|
|
temperature=kwargs.get("temperature", 0.0),
|
|
base_url=base_url,
|
|
api_key=api_key,
|
|
)
|
|
elif provider == "google":
|
|
return ChatGoogleGenerativeAI(
|
|
model=kwargs.get("model_name", "gemini-2.0-flash-exp"),
|
|
temperature=kwargs.get("temperature", 0.0),
|
|
api_key=api_key,
|
|
)
|
|
elif provider == "ollama":
|
|
if not kwargs.get("base_url", ""):
|
|
base_url = os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434")
|
|
else:
|
|
base_url = kwargs.get("base_url")
|
|
|
|
if "deepseek-r1" in kwargs.get("model_name", "qwen2.5:7b"):
|
|
return DeepSeekR1ChatOllama(
|
|
model=kwargs.get("model_name", "deepseek-r1:14b"),
|
|
temperature=kwargs.get("temperature", 0.0),
|
|
num_ctx=kwargs.get("num_ctx", 32000),
|
|
base_url=base_url,
|
|
)
|
|
else:
|
|
return ChatOllama(
|
|
model=kwargs.get("model_name", "qwen2.5:7b"),
|
|
temperature=kwargs.get("temperature", 0.0),
|
|
num_ctx=kwargs.get("num_ctx", 32000),
|
|
num_predict=kwargs.get("num_predict", 1024),
|
|
base_url=base_url,
|
|
)
|
|
elif provider == "azure_openai":
|
|
if not kwargs.get("base_url", ""):
|
|
base_url = os.getenv("AZURE_OPENAI_ENDPOINT", "")
|
|
else:
|
|
base_url = kwargs.get("base_url")
|
|
api_version = kwargs.get("api_version", "") or os.getenv("AZURE_OPENAI_API_VERSION", "2025-01-01-preview")
|
|
return AzureChatOpenAI(
|
|
model=kwargs.get("model_name", "gpt-4o"),
|
|
temperature=kwargs.get("temperature", 0.0),
|
|
api_version=api_version,
|
|
azure_endpoint=base_url,
|
|
api_key=api_key,
|
|
)
|
|
elif provider == "alibaba":
|
|
if not kwargs.get("base_url", ""):
|
|
base_url = os.getenv("ALIBABA_ENDPOINT", "https://dashscope.aliyuncs.com/compatible-mode/v1")
|
|
else:
|
|
base_url = kwargs.get("base_url")
|
|
|
|
return ChatOpenAI(
|
|
model=kwargs.get("model_name", "qwen-plus"),
|
|
temperature=kwargs.get("temperature", 0.0),
|
|
base_url=base_url,
|
|
api_key=api_key,
|
|
)
|
|
elif provider == "ibm":
|
|
parameters = {
|
|
"temperature": kwargs.get("temperature", 0.0),
|
|
"max_tokens": kwargs.get("num_ctx", 32000)
|
|
}
|
|
if not kwargs.get("base_url", ""):
|
|
base_url = os.getenv("IBM_ENDPOINT", "https://us-south.ml.cloud.ibm.com")
|
|
else:
|
|
base_url = kwargs.get("base_url")
|
|
|
|
return ChatWatsonx(
|
|
model_id=kwargs.get("model_name", "ibm/granite-vision-3.1-2b-preview"),
|
|
url=base_url,
|
|
project_id=os.getenv("IBM_PROJECT_ID"),
|
|
apikey=os.getenv("IBM_API_KEY"),
|
|
params=parameters
|
|
)
|
|
elif provider == "moonshot":
|
|
return ChatOpenAI(
|
|
model=kwargs.get("model_name", "moonshot-v1-32k-vision-preview"),
|
|
temperature=kwargs.get("temperature", 0.0),
|
|
base_url=os.getenv("MOONSHOT_ENDPOINT"),
|
|
api_key=os.getenv("MOONSHOT_API_KEY"),
|
|
)
|
|
elif provider == "unbound":
|
|
return ChatOpenAI(
|
|
model=kwargs.get("model_name", "gpt-4o-mini"),
|
|
temperature=kwargs.get("temperature", 0.0),
|
|
base_url=os.getenv("UNBOUND_ENDPOINT", "https://api.getunbound.ai"),
|
|
api_key=api_key,
|
|
)
|
|
elif provider == "siliconflow":
|
|
if not kwargs.get("api_key", ""):
|
|
api_key = os.getenv("SiliconFLOW_API_KEY", "")
|
|
else:
|
|
api_key = kwargs.get("api_key")
|
|
if not kwargs.get("base_url", ""):
|
|
base_url = os.getenv("SiliconFLOW_ENDPOINT", "")
|
|
else:
|
|
base_url = kwargs.get("base_url")
|
|
return ChatOpenAI(
|
|
api_key=api_key,
|
|
base_url=base_url,
|
|
model_name=kwargs.get("model_name", "Qwen/QwQ-32B"),
|
|
temperature=kwargs.get("temperature", 0.0),
|
|
)
|
|
elif provider == "modelscope":
|
|
if not kwargs.get("api_key", ""):
|
|
api_key = os.getenv("MODELSCOPE_API_KEY", "")
|
|
else:
|
|
api_key = kwargs.get("api_key")
|
|
if not kwargs.get("base_url", ""):
|
|
base_url = os.getenv("MODELSCOPE_ENDPOINT", "")
|
|
else:
|
|
base_url = kwargs.get("base_url")
|
|
return ChatOpenAI(
|
|
api_key=api_key,
|
|
base_url=base_url,
|
|
model_name=kwargs.get("model_name", "Qwen/QwQ-32B"),
|
|
temperature=kwargs.get("temperature", 0.0),
|
|
extra_body = {"enable_thinking": False}
|
|
)
|
|
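    elif provider == "bedrock":
        # Hypothetical sketch: "bedrock" is exempted from the API-key check
        # above and ChatBedrock is imported, but no branch handled it.
        # ChatBedrock picks up AWS credentials from the standard environment
        # or shared config; the default model id and region below are
        # assumptions, not values fixed by this module.
        return ChatBedrock(
            model_id=kwargs.get("model_name", "anthropic.claude-3-5-sonnet-20241022-v2:0"),
            model_kwargs={"temperature": kwargs.get("temperature", 0.0)},
            region_name=os.getenv("AWS_REGION", "us-east-1"),
        )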
    else:
        raise ValueError(f"Unsupported provider: {provider}")
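

# Minimal usage sketch (assumes OPENAI_API_KEY is exported; the model name
# is just an example):
#
#     llm = get_llm_model("openai", model_name="gpt-4o", temperature=0.0)
#     print(llm.invoke("Hello!").content)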