|
from typing import Optional, Dict, Any |
|
from dataclasses import dataclass |
|
import os |
|
from enum import Enum |
|
import logging |
|
from openai import OpenAI |
|
from anthropic import Anthropic |
|
import groq |
|
import google.generativeai as palm |
|
from smolagents import HfApiModel, CodeAgent, DuckDuckGoSearchTool, load_tool, tool |
|
import datetime |
|
import requests |
|
import pytz |
|
import yaml |
|
from tools.final_answer import FinalAnswerTool |
|
from tools.visit_webpage import VisitWebpageTool |
|
from tools.web_search import DuckDuckGoSearchTool |
|
from tools.linkedin_job_search import LinkedInJobSearchTool |
|
from tools.odoo_documentation_search import OdooDocumentationSearchTool |
|
from tools.odoo_code_agent_16 import OdooCodeAgent16 |
|
from tools.odoo_code_agent_17 import OdooCodeAgent17 |
|
from tools.odoo_code_agent_18 import OdooCodeAgent18 |
|
|
|
from Gradio_UI import GradioUI |
|
|
|
|
|
# Configure root logging at INFO level for the whole process.
logging.basicConfig(level=logging.INFO)
# Module-level logger; used by LLMProviderManager and launch_gradio_ui below.
logger = logging.getLogger(__name__)

# Environment flags set at import time, before any tool or model is built.
# NOTE(review): presumably TRANSFORMERS_OFFLINE=1 keeps Hugging Face
# Transformers off the network and TORCH_MPS_FORCE_CPU=1 steers PyTorch away
# from the Apple MPS backend -- confirm both are honored, given that they are
# assigned only after the imports above have already executed.
os.environ["TRANSFORMERS_OFFLINE"] = "1"
os.environ["TORCH_MPS_FORCE_CPU"] = "1"
|
|
|
|
|
|
|
@tool
def my_custom_tool(arg1: str, arg2: int) -> str:
    """A tool that does nothing yet
    Args:
        arg1: the first argument
        arg2: the second argument
    """
    # Placeholder implementation: both arguments are ignored and a fixed
    # string is returned. This tool is not included in the agent's tool list
    # built in launch_gradio_ui.
    return "What magic will you build ?"
|
|
|
@tool
def get_current_time_in_timezone(timezone: str) -> str:
    """A tool that fetches the current local time in a specified timezone.
    Args:
        timezone: A string representing a valid timezone (e.g., 'America/New_York').
    """
    # The docstring above is left untouched because the @tool decorator is
    # applied to this function and may read it.
    try:
        # Resolve the zone name, then take "now" directly in that zone.
        zone = pytz.timezone(timezone)
        now_in_zone = datetime.datetime.now(zone)
        local_time = now_in_zone.strftime("%Y-%m-%d %H:%M:%S")
    except Exception as e:
        # Any failure (e.g. an unknown zone name) is reported as text rather
        # than raised, so the caller always receives a string.
        return f"Error fetching time for timezone '{timezone}': {str(e)}"
    else:
        return f"The current local time in {timezone} is: {local_time}"
|
|
|
# Load the agent prompt templates once at import time. The encoding is pinned
# to UTF-8 so non-ASCII prompt text is decoded the same way on every platform
# instead of depending on the locale's default encoding.
# yaml.safe_load is used, so only plain YAML data types are constructed.
with open("prompts.yaml", "r", encoding="utf-8") as stream:
    prompt_templates = yaml.safe_load(stream)
|
|
|
# Tool instances created once at import time; launch_gradio_ui passes them to
# the CodeAgent.
final_answer = FinalAnswerTool()
visit_webpage = VisitWebpageTool()
# DuckDuckGoSearchTool is imported from both smolagents and tools.web_search;
# the later import (tools.web_search) is the binding in effect here.
web_search = DuckDuckGoSearchTool()
job_search_tool = LinkedInJobSearchTool()
odoo_documentation_search_tool = OdooDocumentationSearchTool()
# One Odoo code-agent tool per class (16 / 17 / 18), each constructed with the
# same "system_prompt" entry from prompts.yaml.
odoo_code_agent_16_tool = OdooCodeAgent16(prompt_templates["system_prompt"])
odoo_code_agent_17_tool = OdooCodeAgent17(prompt_templates["system_prompt"])
odoo_code_agent_18_tool = OdooCodeAgent18(prompt_templates["system_prompt"])

# NOTE(review): trust_remote_code=True presumably allows load_tool to execute
# code fetched for "agents-course/text-to-image" -- confirm that source is
# trusted and reachable given the offline flag set above.
image_generation_tool = load_tool("agents-course/text-to-image", trust_remote_code=True)
|
|
|
class ModelProvider(Enum):
    """Closed set of LLM providers this module can initialize.

    Each member's value is the string accepted by ``ModelProvider(name)``;
    ``launch_gradio_ui`` resolves its "selected_provider" setting that way.
    """

    # Routed to LLMProviderManager._initialize_hf_model.
    QWEN = "Qwen"
    HUGGINGFACE = "HuggingFace"

    # Routed to a dedicated per-provider initializer.
    OPENAI = "OpenAI"
    ANTHROPIC = "Anthropic"
    GROQ = "Groq"
    GOOGLE = "Google"

    # Also routed to LLMProviderManager._initialize_hf_model.
    CUSTOM = "Custom"
|
|
|
@dataclass
class ProviderConfig:
    """Static settings for one LLM provider.

    The ``*_env_var`` fields hold the *names* of environment variables, not
    their values; LLMProviderManager reads the variables when a provider is
    initialized.
    """

    # Default model identifier. Optional because the CUSTOM provider is
    # configured with ``model_id=None``.
    model_id: Optional[str]
    # Name of the environment variable holding the API key, if any.
    api_key_env_var: Optional[str] = None
    # Name of the environment variable that can override ``model_id``.
    model_name_env_var: Optional[str] = None
    # Name of the environment variable holding a base URL override.
    base_url_env_var: Optional[str] = None
    # Fallbacks used when the caller does not supply these values.
    default_max_tokens: int = 1000
    default_temperature: float = 0.5
|
|
|
class LLMProviderManager:
    """Create model/client objects for each supported ``ModelProvider``.

    QWEN, HUGGINGFACE and CUSTOM are built as ``HfApiModel`` instances and
    honor ``max_tokens`` / ``temperature``. OPENAI, ANTHROPIC, GROQ and GOOGLE
    return the provider's own SDK client (for GOOGLE, the configured
    ``google.generativeai`` module); those paths receive only the API key and
    base URL.
    """

    def __init__(self):
        """Populate ``providers_config`` with one ``ProviderConfig`` per provider."""
        self.providers_config = {
            ModelProvider.QWEN: ProviderConfig(
                model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
            ),
            ModelProvider.HUGGINGFACE: ProviderConfig(
                model_id="https://pflgm2locj2t89co.us-east-1.aws.endpoints.huggingface.cloud",
            ),
            ModelProvider.OPENAI: ProviderConfig(
                model_id="gpt-4",
                api_key_env_var="OPENAI_API_KEY",
                model_name_env_var="OPENAI_MODEL_NAME",
                base_url_env_var="OPENAI_BASE_URL",
            ),
            ModelProvider.ANTHROPIC: ProviderConfig(
                model_id="claude-v1",
                api_key_env_var="ANTHROPIC_API_KEY",
                model_name_env_var="ANTHROPIC_MODEL_NAME",
                base_url_env_var="ANTHROPIC_BASE_URL",
            ),
            ModelProvider.GROQ: ProviderConfig(
                model_id="mixtral-8x7b-32768",
                api_key_env_var="GROQ_API_KEY",
                model_name_env_var="GROQ_MODEL_NAME",
                base_url_env_var="GROQ_BASE_URL",
            ),
            ModelProvider.GOOGLE: ProviderConfig(
                model_id="gemini-pro",
                api_key_env_var="GOOGLE_API_KEY",
                model_name_env_var="GOOGLE_MODEL_NAME",
                base_url_env_var="GOOGLE_BASE_URL",
            ),
            # NOTE(review): CUSTOM has no model id, so _initialize_hf_model
            # passes model_id=None to HfApiModel -- confirm that is intended.
            ModelProvider.CUSTOM: ProviderConfig(
                model_id=None,
                base_url_env_var="CUSTOM_BASE_URL",
            ),
        }

    def _get_api_key(self, provider: ModelProvider, custom_api_key: Optional[str] = None) -> Optional[str]:
        """Return the API key for ``provider``.

        A non-empty ``custom_api_key`` wins; otherwise the provider's
        configured environment variable is read. Returns ``None`` when the
        provider has no key variable or the variable is unset.

        Raises:
            KeyError: if ``provider`` is not in ``providers_config``.
        """
        config = self.providers_config[provider]
        if custom_api_key:
            return custom_api_key
        return os.environ.get(config.api_key_env_var) if config.api_key_env_var else None

    def _get_base_url(self, provider: ModelProvider) -> Optional[str]:
        """Return the base URL from the provider's environment variable, or ``None``.

        Raises:
            KeyError: if ``provider`` is not in ``providers_config``.
        """
        config = self.providers_config[provider]
        return os.environ.get(config.base_url_env_var) if config.base_url_env_var else None

    def _get_model_name(self, provider: ModelProvider) -> str:
        """Return the model name for ``provider``.

        The provider's model-name environment variable overrides the
        configured ``model_id`` when it is set. This helper is not called by
        the other methods of this class.

        Raises:
            KeyError: if ``provider`` is not in ``providers_config``.
        """
        config = self.providers_config[provider]
        if config.model_name_env_var:
            return os.environ.get(config.model_name_env_var, config.model_id)
        return config.model_id

    def initialize_provider(
        self,
        provider: ModelProvider,
        custom_api_key: Optional[str] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None
    ) -> Any:
        """Initialize a specific LLM provider with given configuration.

        Args:
            provider: Which provider to build.
            custom_api_key: Key that overrides the environment variable.
            max_tokens: Generation limit; only used by the HfApiModel path.
            temperature: Sampling temperature; only used by the HfApiModel path.

        Returns:
            An ``HfApiModel`` for QWEN / HUGGINGFACE / CUSTOM, otherwise the
            object returned by the provider-specific initializer.

        Raises:
            ValueError: if no initializer exists for ``provider``.
            Exception: any other initialization error; it is logged and
                re-raised unchanged.
        """
        try:
            config = self.providers_config[provider]
            api_key = self._get_api_key(provider, custom_api_key)
            base_url = self._get_base_url(provider)

            if provider in (ModelProvider.QWEN, ModelProvider.HUGGINGFACE, ModelProvider.CUSTOM):
                return self._initialize_hf_model(config, api_key, base_url, max_tokens, temperature)

            provider_initializers = {
                ModelProvider.OPENAI: self._initialize_openai,
                ModelProvider.ANTHROPIC: self._initialize_anthropic,
                ModelProvider.GROQ: self._initialize_groq,
                ModelProvider.GOOGLE: self._initialize_google,
            }

            initializer = provider_initializers.get(provider)
            if not initializer:
                raise ValueError(f"Unsupported provider: {provider}")

            # All four initializers share the (api_key, base_url) signature,
            # so no per-provider branching is needed here.
            return initializer(api_key, base_url)

        except Exception as e:
            logger.error(f"Error initializing provider {provider}: {str(e)}")
            raise

    @staticmethod
    def _build_hf_kwargs(
        config: ProviderConfig,
        api_key: Optional[str],
        base_url: Optional[str],
        max_tokens: Optional[int],
        temperature: Optional[float]
    ) -> Dict[str, Any]:
        """Assemble the keyword arguments passed to ``HfApiModel``."""
        model_kwargs = {
            # A missing or zero token limit falls back to the provider default.
            "max_tokens": max_tokens or config.default_max_tokens,
            # Compare against None explicitly: 0.0 is a valid temperature and
            # must not be replaced by the default.
            "temperature": config.default_temperature if temperature is None else temperature,
            "model_id": config.model_id,
            "custom_role_conversions": None,
        }

        # NOTE(review): confirm the installed smolagents HfApiModel accepts
        # "api_key" and "url" as keyword arguments.
        if api_key:
            model_kwargs["api_key"] = api_key
        if base_url:
            model_kwargs["url"] = base_url

        return model_kwargs

    def _initialize_hf_model(
        self,
        config: ProviderConfig,
        api_key: Optional[str],
        base_url: Optional[str],
        max_tokens: Optional[int],
        temperature: Optional[float]
    ) -> HfApiModel:
        """Build an ``HfApiModel`` from ``config`` plus optional overrides.

        ``api_key`` and ``base_url`` are forwarded only when non-empty;
        ``temperature`` falls back to the config default only when it is
        ``None``.
        """
        return HfApiModel(
            **self._build_hf_kwargs(config, api_key, base_url, max_tokens, temperature)
        )

    def _initialize_openai(self, api_key: str, base_url: Optional[str]) -> OpenAI:
        """Return an ``OpenAI`` client; ``base_url`` is passed only when set."""
        kwargs = {"api_key": api_key}
        if base_url:
            kwargs["base_url"] = base_url
        return OpenAI(**kwargs)

    def _initialize_anthropic(self, api_key: str, base_url: Optional[str]) -> Anthropic:
        """Return an ``Anthropic`` client; ``base_url`` is passed only when set."""
        kwargs = {"api_key": api_key}
        if base_url:
            kwargs["base_url"] = base_url
        return Anthropic(**kwargs)

    def _initialize_groq(self, api_key: str, _: Optional[str]) -> groq.Groq:
        """Return a ``groq.Groq`` client; the base-URL argument is ignored."""
        return groq.Groq(api_key=api_key)

    def _initialize_google(self, api_key: str, _: Optional[str]) -> Any:
        """Configure ``google.generativeai`` with ``api_key`` and return the module.

        The base-URL argument is ignored.
        """
        palm.configure(api_key=api_key)
        return palm
|
|
|
# Plain-dict provider table keyed by provider display name. Its entries repeat
# the model ids and environment-variable names that
# LLMProviderManager.providers_config holds; nothing in this file reads it.
model_providers = {
    # Providers without an API-key environment variable.
    "Qwen": {
        "model_id": "Qwen/Qwen2.5-Coder-32B-Instruct",
        "api_key_env_var": None,
    },
    "HuggingFace": {
        "model_id": "https://pflgm2locj2t89co.us-east-1.aws.endpoints.huggingface.cloud",
        "api_key_env_var": None,
    },
    # Providers configured entirely through environment variables.
    "OpenAI": {
        "model_id": "gpt-4",
        "api_key_env_var": "OPENAI_API_KEY",
        "model_name_env_var": "OPENAI_MODEL_NAME",
        "base_url_env_var": "OPENAI_BASE_URL",
    },
    "Anthropic": {
        "model_id": "claude-v1",
        "api_key_env_var": "ANTHROPIC_API_KEY",
        "model_name_env_var": "ANTHROPIC_MODEL_NAME",
        "base_url_env_var": "ANTHROPIC_BASE_URL",
    },
    "Groq": {
        "model_id": "mixtral-8x7b-32768",
        "api_key_env_var": "GROQ_API_KEY",
        "model_name_env_var": "GROQ_MODEL_NAME",
        "base_url_env_var": "GROQ_BASE_URL",
    },
    "Google": {
        "model_id": "gemini-pro",
        "api_key_env_var": "GOOGLE_API_KEY",
        "model_name_env_var": "GOOGLE_MODEL_NAME",
        "base_url_env_var": "GOOGLE_BASE_URL",
    },
    # Custom endpoint: no model id, only a base-URL variable.
    "Custom": {
        "model_id": None,
        "api_key_env_var": None,
        "base_url_env_var": "CUSTOM_BASE_URL",
    },
}
|
|
|
def launch_gradio_ui(additional_args: Optional[Dict[str, Any]] = None):
    """Launch the Gradio UI with the specified LLM provider configuration.

    Args:
        additional_args: Optional settings mapping. Keys read here:
            "selected_provider" (default "HuggingFace"), "max_steps"
            (default 6), "max_tokens" (default 1000), "temperature"
            (default 0.5) and "<selected_provider>_api_key" (optional).

    Raises:
        Exception: any error raised while resolving the provider, building
            the agent or launching the UI; it is logged and re-raised.
            Errors converting the numeric settings are raised without being
            logged, because that conversion happens before the guarded
            section.
    """
    if additional_args is None:
        additional_args = {}

    # NOTE(review): when the provider is Google this helper itself is handed
    # to CodeAgent as ``model``. It requires a second ``model`` argument that
    # nothing in this function binds -- confirm CodeAgent can actually call it.
    def generate_google_content(prompt: str, model: palm.GenerativeModel):
        """Helper function to generate content using the Google provider."""
        try:
            return model.generate_content(prompt).text
        except Exception as e:
            logger.error(f"Google Palm API error: {str(e)}")
            return f"Error generating text with Google Palm: {str(e)}"

    # Read the settings, coercing the numeric ones to their expected types.
    provider_name = additional_args.get("selected_provider", "HuggingFace")
    max_steps = int(additional_args.get("max_steps", 6))
    max_tokens = int(additional_args.get("max_tokens", 1000))
    temperature = float(additional_args.get("temperature", 0.5))

    try:
        provider = ModelProvider(provider_name)
        provider_manager = LLMProviderManager()

        # A per-provider key in the settings overrides the environment variable.
        model = provider_manager.initialize_provider(
            provider=provider,
            custom_api_key=additional_args.get(f"{provider_name}_api_key"),
            max_tokens=max_tokens,
            temperature=temperature,
        )

        # NOTE(review): for OpenAI / Anthropic / Groq the raw SDK client is
        # passed as the agent's model -- confirm CodeAgent accepts that.
        if provider == ModelProvider.GOOGLE:
            agent_model = generate_google_content
        else:
            agent_model = model

        agent_tools = [
            final_answer,
            visit_webpage,
            web_search,
            image_generation_tool,
            get_current_time_in_timezone,
            job_search_tool,
            odoo_documentation_search_tool,
            odoo_code_agent_16_tool,
            odoo_code_agent_17_tool,
            odoo_code_agent_18_tool,
        ]

        agent = CodeAgent(
            model=agent_model,
            tools=agent_tools,
            max_steps=max_steps,
            verbosity_level=1,
            grammar=None,
            planning_interval=None,
            name=None,
            description=None,
            prompt_templates=prompt_templates,
        )

        GradioUI(agent).launch()

    except Exception as e:
        logger.error(f"Error launching Gradio UI: {str(e)}")
        raise
|
|
|
# Start the UI with default settings. This call is unguarded, so it runs
# whenever the module is executed or imported.
launch_gradio_ui()
|
|