# jutor_write/llms.py
from abc import ABC, abstractmethod
from typing import Dict, List, Optional

from openai import OpenAI
from vertexai.generative_models import GenerativeModel


class LLMProvider(ABC):
    """Abstract interface that every LLM backend must implement."""

    @abstractmethod
    def generate(self, prompt: Optional[str] = None, messages: Optional[List[Dict]] = None, **kwargs) -> str:
        """Return the model's text reply for a raw prompt or a chat message list."""

class GeminiProvider(LLMProvider):
    """Generates text with Gemini models via the Vertex AI SDK."""

    def generate(self, prompt: Optional[str] = None, messages: Optional[List[Dict]] = None, **kwargs) -> str:
        model_name = kwargs.get("model", "gemini-pro")
        model = GenerativeModel(model_name=model_name)

        # Prefer the raw prompt; otherwise flatten the chat messages into a
        # single text block, since generate_content() is called with one string.
        content = prompt
        if messages and not prompt:
            content = "\n".join(f"{m['role']}: {m['content']}" for m in messages)
        if content is None:
            raise ValueError("Either 'prompt' or 'messages' must be provided.")

        # Mirror OpenAI-style JSON mode by requesting a JSON MIME type.
        if "response_format" in kwargs and kwargs["response_format"].get("type") == "json_object":
            generation_config = {"response_mime_type": "application/json"}
            response = model.generate_content(content, generation_config=generation_config)
        else:
            response = model.generate_content(content)
        return response.text
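

# Illustrative usage sketch (not from the original file). Assumes vertexai.init()
# has already been called with a Google Cloud project and location:
#
#   provider = GeminiProvider()
#   text = provider.generate(
#       prompt="Reply with a JSON object containing a 'greeting' key.",
#       response_format={"type": "json_object"},
#   )
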
class OpenAIProvider(LLMProvider):
    """Generates text through an injected OpenAI client (chat completions API)."""

    def __init__(self, client):
        self.client = client

    def generate(self, prompt: Optional[str] = None, messages: Optional[List[Dict]] = None, **kwargs) -> str:
        # A bare prompt is wrapped as a single user message.
        if prompt and not messages:
            messages = [{"role": "user", "content": prompt}]
        if not messages:
            raise ValueError("Either 'prompt' or 'messages' must be provided.")

        completion_kwargs = {
            "model": kwargs.get("model", "gpt-4o"),
            "messages": messages,
            "max_tokens": kwargs.get("max_tokens", 1000),
            "temperature": kwargs.get("temperature", 1.0),
        }
        # Forward response_format untouched, e.g. {"type": "json_object"}.
        if "response_format" in kwargs:
            completion_kwargs["response_format"] = kwargs["response_format"]

        response = self.client.chat.completions.create(**completion_kwargs)
        return response.choices[0].message.content
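

# Illustrative usage sketch (not from the original file). The client is injected,
# so tests can pass a stub in place of a real OpenAI() instance:
#
#   provider = OpenAIProvider(client=OpenAI())
#   text = provider.generate(prompt="Hello!", model="gpt-4o", temperature=0.3)
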
class LLMService:
    """Thin facade that routes chat calls to whichever provider it was built with."""

    def __init__(self, provider: LLMProvider):
        self.provider = provider

    def chat(
        self,
        prompt: Optional[str] = None,
        messages: Optional[List[Dict[str, str]]] = None,
        **kwargs,
    ) -> str:
        try:
            return self.provider.generate(prompt=prompt, messages=messages, **kwargs)
        except Exception as e:
            # Log the failure, then re-raise so callers can handle it themselves.
            print(f"LLM API error: {str(e)}")
            raise
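

if __name__ == "__main__":
    # Minimal end-to-end sketch (not part of the original file). Assumes the
    # OPENAI_API_KEY environment variable is set; swap in GeminiProvider() on a
    # configured Vertex AI environment to exercise the other backend.
    service = LLMService(OpenAIProvider(client=OpenAI()))
    reply = service.chat(prompt="Summarize the plot of Hamlet in one sentence.")
    print(reply)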