VPCSinfo committed on
Commit
e5f5bde
·
1 Parent(s): 743c53a

[ADD] Added a multi-provider class and the supporting structure to handle it.

Browse files
Files changed (3) hide show
  1. Gradio_UI.py +5 -1
  2. app.py +231 -53
  3. requirements.txt +4 -1
Gradio_UI.py CHANGED
@@ -278,7 +278,11 @@ class GradioUI:
278
  app_content = f.read()
279
  model_providers_match = re.search(r"model_providers = ({[^{}]*(?:{[^{}]*}[^{}]*)*})", app_content, re.DOTALL)
280
  if model_providers_match:
281
- model_providers = eval(model_providers_match.group(1))
 
 
 
 
282
  else:
283
  model_providers = {}
284
 
 
278
  app_content = f.read()
279
  model_providers_match = re.search(r"model_providers = ({[^{}]*(?:{[^{}]*}[^{}]*)*})", app_content, re.DOTALL)
280
  if model_providers_match:
281
+ try:
282
+ model_providers = eval(model_providers_match.group(1))
283
+ except Exception as e:
284
+ print(f"Error evaluating model_providers: {e}")
285
+ model_providers = {}
286
  else:
287
  model_providers = {}
288
 
app.py CHANGED
@@ -1,4 +1,13 @@
1
- from smolagents import CodeAgent,DuckDuckGoSearchTool, HfApiModel,load_tool,tool
 
 
 
 
 
 
 
 
 
2
  import datetime
3
  import requests
4
  import pytz
@@ -12,11 +21,15 @@ from tools.odoo_code_agent_16 import OdooCodeAgent16
12
  from tools.odoo_code_agent_17 import OdooCodeAgent17
13
  from tools.odoo_code_agent_18 import OdooCodeAgent18
14
 
15
- from dotenv import load_dotenv
16
  from Gradio_UI import GradioUI
17
- load_dotenv()
18
 
19
- import os
 
 
 
 
 
 
20
 
21
  # Below is an example of a tool that does nothing. Amaze us with your creativity !
22
  @tool
@@ -56,8 +69,160 @@ odoo_code_agent_16_tool = OdooCodeAgent16(prompt_templates["system_prompt"])
56
  odoo_code_agent_17_tool = OdooCodeAgent17(prompt_templates["system_prompt"])
57
  odoo_code_agent_18_tool = OdooCodeAgent18(prompt_templates["system_prompt"])
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- # If the agent does not answer, the model is overloaded, please use another model or the following Hugging Face Endpoint that also contains qwen2.5 coder:
61
  model_providers = {
62
  "Qwen": {
63
  "model_id": "Qwen/Qwen2.5-Coder-32B-Instruct",
@@ -69,74 +234,87 @@ model_providers = {
69
  },
70
  "OpenAI": {
71
  "model_id": "gpt-4",
72
- "api_key_env_var": "OPENAI_API_KEY"
 
 
73
  },
74
  "Anthropic": {
75
  "model_id": "claude-v1",
76
- "api_key_env_var": "ANTHROPIC_API_KEY"
 
 
77
  },
78
  "Groq": {
79
  "model_id": "mixtral-8x7b-32768",
80
- "api_key_env_var": "GROQ_API_KEY"
 
 
81
  },
82
  "Google": {
83
  "model_id": "gemini-pro",
84
- "api_key_env_var": "GOOGLE_API_KEY"
 
 
85
  },
86
  "Custom": {
87
  "model_id": None,
88
- "api_key_env_var": None
 
89
  }
90
  }
91
 
 
 
 
 
92
 
93
- # Import tool from Hub
94
- image_generation_tool = load_tool("agents-course/text-to-image", trust_remote_code=True)
 
 
 
 
 
 
95
 
 
 
 
 
96
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
- def launch_gradio_ui(additional_args=None):
99
- global selected_provider
100
- global model
101
- #global agent # Declare agent as global
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
- if additional_args:
104
- selected_provider = additional_args.get("selected_provider", "HuggingFace")
105
- max_steps = int(additional_args.get("max_steps", 6))
106
- max_tokens = int(additional_args.get("max_tokens", 1000))
107
- else:
108
- selected_provider = "HuggingFace"
109
- max_steps = 6
110
- max_tokens = 1000
111
 
112
- model_id = model_providers[selected_provider]["model_id"]
113
- api_key_env_var = model_providers[selected_provider]["api_key_env_var"]
114
- api_key = additional_args.get(f"{selected_provider}_api_key") if additional_args else None
115
 
116
- model_kwargs = {
117
- "max_tokens": max_tokens,
118
- "temperature": 0.5,
119
- "model_id": model_id,
120
- "custom_role_conversions": None,
121
- }
122
- if model_providers[selected_provider]["api_key_env_var"]:
123
- model_kwargs["api_key"] = api_key if api_key else os.environ.get(api_key_env_var)
124
-
125
- model = HfApiModel(**model_kwargs)
126
-
127
- agent = CodeAgent(
128
- model=model,
129
- tools=[final_answer, visit_webpage, web_search, image_generation_tool, get_current_time_in_timezone, job_search_tool, odoo_documentation_search_tool, odoo_code_agent_16_tool, odoo_code_agent_17_tool, odoo_code_agent_18_tool],
130
- max_steps=max_steps,
131
- verbosity_level=1,
132
- grammar=None,
133
- planning_interval=None,
134
- name=None,
135
- description=None,
136
- prompt_templates=prompt_templates
137
- )
138
-
139
- GradioUI(agent).launch()
140
-
141
- # Remove the direct call to launch_gradio_ui()
142
  launch_gradio_ui()
 
1
+ from typing import Optional, Dict, Any
2
+ from dataclasses import dataclass
3
+ import os
4
+ from enum import Enum
5
+ import logging
6
+ from openai import OpenAI
7
+ from anthropic import Anthropic
8
+ import groq
9
+ import google.generativeai as palm
10
+ from smolagents import HfApiModel, CodeAgent, DuckDuckGoSearchTool, load_tool, tool
11
  import datetime
12
  import requests
13
  import pytz
 
21
  from tools.odoo_code_agent_17 import OdooCodeAgent17
22
  from tools.odoo_code_agent_18 import OdooCodeAgent18
23
 
 
24
  from Gradio_UI import GradioUI
 
25
 
26
+ # Configure logging
27
+ logging.basicConfig(level=logging.INFO)
28
+ logger = logging.getLogger(__name__)
29
+
30
+ os.environ["TRANSFORMERS_OFFLINE"] = "1"
31
+ os.environ["TORCH_MPS_FORCE_CPU"] = "1"
32
+
33
 
34
  # Below is an example of a tool that does nothing. Amaze us with your creativity !
35
  @tool
 
69
  odoo_code_agent_17_tool = OdooCodeAgent17(prompt_templates["system_prompt"])
70
  odoo_code_agent_18_tool = OdooCodeAgent18(prompt_templates["system_prompt"])
71
 
72
+ # Import tool from Hub
73
+ image_generation_tool = load_tool("agents-course/text-to-image", trust_remote_code=True)
74
+
75
class ModelProvider(Enum):
    """Identifiers for the supported LLM backends.

    The enum value is the human-readable provider name used in the UI and as
    the lookup key for provider-specific settings (``ModelProvider("OpenAI")``
    round-trips from the UI selection string).
    """

    QWEN = "Qwen"                # Qwen coder model served via HF inference
    HUGGINGFACE = "HuggingFace"  # generic Hugging Face inference endpoint
    OPENAI = "OpenAI"
    ANTHROPIC = "Anthropic"
    GROQ = "Groq"
    GOOGLE = "Google"
    CUSTOM = "Custom"            # user-supplied endpoint/model
83
+
84
@dataclass
class ProviderConfig:
    """Static configuration for one LLM provider.

    Attributes:
        model_id: Default model identifier (or endpoint URL for HF endpoints).
            ``None`` is allowed: the "Custom" provider is constructed with
            ``model_id=None`` and resolves its model at runtime.
        api_key_env_var: Name of the env var holding the API key, if any.
        model_name_env_var: Name of the env var overriding the model name.
        base_url_env_var: Name of the env var overriding the base URL.
        default_max_tokens: Fallback token limit when the caller gives none.
        default_temperature: Fallback sampling temperature.
    """

    # Optional[str], not str: the CUSTOM provider passes model_id=None.
    model_id: Optional[str]
    api_key_env_var: Optional[str] = None
    model_name_env_var: Optional[str] = None
    base_url_env_var: Optional[str] = None
    default_max_tokens: int = 1000
    default_temperature: float = 0.5
92
+
93
class LLMProviderManager:
    """Factory for LLM client objects across the supported providers.

    Holds per-provider defaults (model id, env-var names) and builds a ready
    client/model object on demand.  API-key precedence: explicit argument >
    environment variable named in the provider's config > none.
    """

    def __init__(self):
        # Static defaults per provider; the env vars referenced here are read
        # lazily at initialization time, not at construction time.
        self.providers_config = {
            ModelProvider.QWEN: ProviderConfig(
                model_id="Qwen/Qwen2.5-Coder-32B-Instruct"
            ),
            ModelProvider.HUGGINGFACE: ProviderConfig(
                model_id="https://pflgm2locj2t89co.us-east-1.aws.endpoints.huggingface.cloud"
            ),
            ModelProvider.OPENAI: ProviderConfig(
                model_id="gpt-4",
                api_key_env_var="OPENAI_API_KEY",
                model_name_env_var="OPENAI_MODEL_NAME",
                base_url_env_var="OPENAI_BASE_URL"
            ),
            ModelProvider.ANTHROPIC: ProviderConfig(
                model_id="claude-v1",
                api_key_env_var="ANTHROPIC_API_KEY",
                model_name_env_var="ANTHROPIC_MODEL_NAME",
                base_url_env_var="ANTHROPIC_BASE_URL"
            ),
            ModelProvider.GROQ: ProviderConfig(
                model_id="mixtral-8x7b-32768",
                api_key_env_var="GROQ_API_KEY",
                model_name_env_var="GROQ_MODEL_NAME",
                base_url_env_var="GROQ_BASE_URL"
            ),
            ModelProvider.GOOGLE: ProviderConfig(
                model_id="gemini-pro",
                api_key_env_var="GOOGLE_API_KEY",
                model_name_env_var="GOOGLE_MODEL_NAME",
                base_url_env_var="GOOGLE_BASE_URL"
            ),
            ModelProvider.CUSTOM: ProviderConfig(
                model_id=None,
                base_url_env_var="CUSTOM_BASE_URL"
            )
        }

    def _get_api_key(self, provider: ModelProvider, custom_api_key: Optional[str] = None) -> Optional[str]:
        """Return the API key: an explicit override wins, else the provider's env var."""
        config = self.providers_config[provider]
        if custom_api_key:
            return custom_api_key
        return os.environ.get(config.api_key_env_var) if config.api_key_env_var else None

    def _get_base_url(self, provider: ModelProvider) -> Optional[str]:
        """Return the provider's base URL from its env var, if one is configured."""
        config = self.providers_config[provider]
        return os.environ.get(config.base_url_env_var) if config.base_url_env_var else None

    def _get_model_name(self, provider: ModelProvider) -> str:
        """Return the model name, preferring the env-var override over the default id."""
        config = self.providers_config[provider]
        if config.model_name_env_var:
            return os.environ.get(config.model_name_env_var, config.model_id)
        return config.model_id

    def initialize_provider(
        self,
        provider: ModelProvider,
        custom_api_key: Optional[str] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None
    ) -> Any:
        """Initialize a specific LLM provider with given configuration.

        Args:
            provider: Which backend to build a client for.
            custom_api_key: Optional key that overrides the env-var lookup.
            max_tokens: Optional token limit (HF-style providers only).
            temperature: Optional sampling temperature (HF-style providers only).

        Raises:
            ValueError: If the provider has no registered initializer.
            Exception: Any error from the underlying client constructor is
                logged and re-raised.
        """
        try:
            config = self.providers_config[provider]
            api_key = self._get_api_key(provider, custom_api_key)
            base_url = self._get_base_url(provider)

            # HF-endpoint-style providers are built via HfApiModel directly.
            if provider in [ModelProvider.QWEN, ModelProvider.HUGGINGFACE, ModelProvider.CUSTOM]:
                return self._initialize_hf_model(config, api_key, base_url, max_tokens, temperature)

            provider_initializers = {
                ModelProvider.OPENAI: self._initialize_openai,
                ModelProvider.ANTHROPIC: self._initialize_anthropic,
                ModelProvider.GROQ: self._initialize_groq,
                ModelProvider.GOOGLE: self._initialize_google
            }

            initializer = provider_initializers.get(provider)
            if not initializer:
                raise ValueError(f"Unsupported provider: {provider}")

            # All remaining initializers share the (api_key, base_url)
            # signature, so no per-provider special case is needed (the
            # previous GOOGLE-only branch was byte-identical to this path).
            return initializer(api_key, base_url)

        except Exception as e:
            logger.error(f"Error initializing provider {provider}: {str(e)}")
            raise

    def _initialize_hf_model(
        self,
        config: ProviderConfig,
        api_key: Optional[str],
        base_url: Optional[str],
        max_tokens: Optional[int],
        temperature: Optional[float]
    ) -> HfApiModel:
        """Build an HfApiModel from the config plus optional overrides."""
        model_kwargs = {
            # Use `is None` checks so explicit 0 / 0.0 values are honored;
            # `max_tokens or default` would silently replace a legitimate
            # temperature of 0.0 (greedy decoding) with the default.
            "max_tokens": config.default_max_tokens if max_tokens is None else max_tokens,
            "temperature": config.default_temperature if temperature is None else temperature,
            "model_id": config.model_id,
            "custom_role_conversions": None
        }

        if api_key:
            model_kwargs["api_key"] = api_key
        if base_url:
            model_kwargs["url"] = base_url

        return HfApiModel(**model_kwargs)

    def _initialize_openai(self, api_key: str, base_url: Optional[str]) -> OpenAI:
        """Build an OpenAI client, honoring an optional custom base URL."""
        kwargs = {"api_key": api_key}
        if base_url:
            kwargs["base_url"] = base_url
        return OpenAI(**kwargs)

    def _initialize_anthropic(self, api_key: str, base_url: Optional[str]) -> Anthropic:
        """Build an Anthropic client, honoring an optional custom base URL."""
        kwargs = {"api_key": api_key}
        if base_url:
            kwargs["base_url"] = base_url
        return Anthropic(**kwargs)

    def _initialize_groq(self, api_key: str, _: Optional[str]) -> groq.Groq:
        """Build a Groq client (base URL override not supported by this path)."""
        return groq.Groq(api_key=api_key)

    def _initialize_google(self, api_key: str, _: Optional[str]) -> Any:
        """Configure the google.generativeai module and return it.

        NOTE(review): this returns the configured module object, not a model
        instance — callers expecting a client with ``generate_content`` must
        construct one themselves; confirm against the call sites.
        """
        palm.configure(api_key=api_key)
        return palm
225
 
 
226
  model_providers = {
227
  "Qwen": {
228
  "model_id": "Qwen/Qwen2.5-Coder-32B-Instruct",
 
234
  },
235
  "OpenAI": {
236
  "model_id": "gpt-4",
237
+ "api_key_env_var": "OPENAI_API_KEY",
238
+ "model_name_env_var": "OPENAI_MODEL_NAME",
239
+ "base_url_env_var": "OPENAI_BASE_URL"
240
  },
241
  "Anthropic": {
242
  "model_id": "claude-v1",
243
+ "api_key_env_var": "ANTHROPIC_API_KEY",
244
+ "model_name_env_var": "ANTHROPIC_MODEL_NAME",
245
+ "base_url_env_var": "ANTHROPIC_BASE_URL"
246
  },
247
  "Groq": {
248
  "model_id": "mixtral-8x7b-32768",
249
+ "api_key_env_var": "GROQ_API_KEY",
250
+ "model_name_env_var": "GROQ_MODEL_NAME",
251
+ "base_url_env_var": "GROQ_BASE_URL"
252
  },
253
  "Google": {
254
  "model_id": "gemini-pro",
255
+ "api_key_env_var": "GOOGLE_API_KEY",
256
+ "model_name_env_var": "GOOGLE_MODEL_NAME",
257
+ "base_url_env_var": "GOOGLE_BASE_URL"
258
  },
259
  "Custom": {
260
  "model_id": None,
261
+ "api_key_env_var": None,
262
+ "base_url_env_var": "CUSTOM_BASE_URL"
263
  }
264
  }
265
 
266
def launch_gradio_ui(additional_args: Optional[Dict[str, Any]] = None):
    """Launch the Gradio UI with the specified LLM provider configuration.

    Args:
        additional_args: Optional settings from the UI/caller. Recognized keys:
            "selected_provider" (default "HuggingFace"), "max_steps" (default 6),
            "max_tokens" (default 1000), "temperature" (default 0.5), and
            "<provider>_api_key" for a per-provider key override.

    Raises:
        Exception: any failure during provider initialization or UI launch is
            logged and re-raised.
    """
    if additional_args is None:
        additional_args = {}

    def generate_google_content(prompt: str, model: palm.GenerativeModel):
        """Helper function to generate content using the Google provider.

        Returns the generated text, or an error string on failure (errors are
        swallowed and reported as text rather than raised).
        """
        # NOTE(review): `initialize_provider` for GOOGLE returns the configured
        # `palm` module, not a `palm.GenerativeModel` as annotated here —
        # confirm that `model.generate_content` exists on what is passed in.
        try:
            response = model.generate_content(prompt)
            return response.text
        except Exception as e:
            logger.error(f"Google Palm API error: {str(e)}")
            return f"Error generating text with Google Palm: {str(e)}"

    # UI-supplied settings, with defaults matching the previous behavior.
    provider_name = additional_args.get("selected_provider", "HuggingFace")
    max_steps = int(additional_args.get("max_steps", 6))
    max_tokens = int(additional_args.get("max_tokens", 1000))
    temperature = float(additional_args.get("temperature", 0.5))

    try:
        # Raises ValueError if provider_name is not a known provider value.
        provider = ModelProvider(provider_name)
        provider_manager = LLMProviderManager()

        # e.g. "OpenAI_api_key" — a UI-entered key overrides the env var.
        custom_api_key = additional_args.get(f"{provider_name}_api_key")
        model = provider_manager.initialize_provider(
            provider=provider,
            custom_api_key=custom_api_key,
            max_tokens=max_tokens,
            temperature=temperature
        )

        agent = CodeAgent(
            # NOTE(review): for GOOGLE this passes the bare two-argument helper
            # as the agent's model, with no `model` argument bound and the
            # initialized client discarded — verify CodeAgent can call it this
            # way, or bind the client (e.g. via functools.partial) first.
            model=generate_google_content if provider == ModelProvider.GOOGLE else model,
            tools=[
                final_answer, visit_webpage, web_search, image_generation_tool, get_current_time_in_timezone,
                job_search_tool,
                odoo_documentation_search_tool, odoo_code_agent_16_tool,
                odoo_code_agent_17_tool, odoo_code_agent_18_tool
            ],
            max_steps=max_steps,
            verbosity_level=1,
            grammar=None,
            planning_interval=None,
            name=None,
            description=None,
            prompt_templates=prompt_templates
        )

        GradioUI(agent).launch()

    except Exception as e:
        logger.error(f"Error launching Gradio UI: {str(e)}")
        raise
319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
  launch_gradio_ui()
requirements.txt CHANGED
@@ -1,4 +1,3 @@
1
- markdownify
2
  smolagents
3
  requests
4
  duckduckgo_search
@@ -10,3 +9,7 @@ transformers
10
  torch
11
  sentence-transformers
12
  numpy==1.23.5
 
 
 
 
 
 
1
  smolagents
2
  requests
3
  duckduckgo_search
 
9
  torch
10
  sentence-transformers
11
  numpy==1.23.5
12
+ openai
13
+ anthropic
14
+ groq
15
+ google-generativeai