Spaces:
Sleeping
Sleeping
from abc import ABC, abstractmethod | |
from vertexai.generative_models import GenerativeModel | |
from openai import OpenAI | |
import json | |
from typing import Dict, List, Optional, Union | |
class LLMProvider(ABC):
    """Abstract interface for a text-generation backend.

    Concrete providers must override :meth:`generate`. The method is marked
    abstract so that an incomplete provider fails at instantiation time
    instead of silently returning ``None`` from the base implementation.
    """

    @abstractmethod
    def generate(self, prompt: Optional[str] = None, messages: Optional[List[Dict]] = None, **kwargs) -> str:
        """Produce a text completion.

        Args:
            prompt: A plain-text prompt.
            messages: A list of message dicts describing a conversation.
            **kwargs: Provider-specific options.

        Returns:
            The generated text.
        """
class GeminiProvider(LLMProvider):
    """Provider that generates text through a Vertex AI ``GenerativeModel``."""

    def generate(self, prompt: Optional[str] = None, messages: Optional[List[Dict]] = None, **kwargs) -> str:
        """Generate text from a prompt or from a list of chat messages.

        Args:
            prompt: Plain-text prompt. When non-empty it is used as-is and
                ``messages`` is ignored.
            messages: Message dicts with ``role`` and ``content`` keys. Used
                only when ``prompt`` is empty; the messages are flattened
                into one ``"role: content"`` line each.
            **kwargs: Recognised keys are ``model`` (defaults to
                ``"gemini-pro"``) and ``response_format``; a dict whose
                ``type`` is ``"json_object"`` requests a JSON response.
                Other keys are ignored by this provider.

        Returns:
            The ``text`` attribute of the model response.

        Raises:
            ValueError: If neither ``prompt`` nor ``messages`` yields any
                content to send.
        """
        content = prompt
        if messages and not prompt:
            content = "\n".join(f"{m['role']}: {m['content']}" for m in messages)

        # Fail early with a clear message instead of handing None to the SDK.
        if not content:
            raise ValueError("Either 'prompt' or 'messages' must be provided")

        model = GenerativeModel(model_name=kwargs.get('model', 'gemini-pro'))

        # response_format may be absent, None, or a non-dict value; only a
        # dict with type == "json_object" switches the response to JSON.
        response_format = kwargs.get("response_format")
        wants_json = (
            isinstance(response_format, dict)
            and response_format.get("type") == "json_object"
        )

        if wants_json:
            generation_config = {
                "response_mime_type": "application/json",
            }
            response = model.generate_content(
                content,
                generation_config=generation_config,
            )
        else:
            response = model.generate_content(content)

        return response.text
class OpenAIProvider(LLMProvider):
    """Provider that generates text through a chat-completions client."""

    def __init__(self, client):
        """Store the client used for requests.

        Args:
            client: Object exposing ``chat.completions.create(...)``.
        """
        self.client = client

    def generate(self, prompt: Optional[str] = None, messages: Optional[List[Dict]] = None, **kwargs) -> str:
        """Generate text from a prompt or from a list of chat messages.

        Args:
            prompt: Plain-text prompt. Wrapped into a single ``user`` message
                when ``messages`` is empty; ignored otherwise.
            messages: Chat messages passed to the API unchanged.
            **kwargs: Recognised keys are ``model`` (default ``"gpt-4o"``),
                ``max_tokens`` (default ``1000``), ``temperature`` (default
                ``1.0``) and ``response_format`` (forwarded only when the
                key is present).

        Returns:
            The message content of the first choice in the response.

        Raises:
            ValueError: If neither ``prompt`` nor ``messages`` is provided.
        """
        if prompt and not messages:
            messages = [{"role": "user", "content": prompt}]

        # Reject empty input locally rather than sending messages=None and
        # getting an opaque error back from the remote API.
        if not messages:
            raise ValueError("Either 'prompt' or 'messages' must be provided")

        completion_kwargs = {
            "model": kwargs.get("model", "gpt-4o"),
            "messages": messages,
            "max_tokens": kwargs.get("max_tokens", 1000),
            "temperature": kwargs.get("temperature", 1.0),
        }
        if "response_format" in kwargs:
            completion_kwargs["response_format"] = kwargs["response_format"]

        response = self.client.chat.completions.create(**completion_kwargs)
        return response.choices[0].message.content
class LLMService:
    """Thin facade that routes chat requests to a configured provider."""

    def __init__(self, provider: LLMProvider):
        """Remember the provider that will serve every request."""
        self.provider = provider

    def chat(
        self,
        prompt: Optional[str] = None,
        messages: Optional[List[Dict[str, str]]] = None,
        **kwargs
    ) -> str:
        """Delegate to the provider's ``generate`` and return its result.

        Any exception raised by the provider is reported on stdout and then
        re-raised unchanged to the caller.
        """
        try:
            reply = self.provider.generate(prompt=prompt, messages=messages, **kwargs)
        except Exception as exc:
            # Report the failure, then let the original exception propagate.
            print(f"LLM API error: {exc}")
            raise
        return reply