"""Provider-agnostic LLM wrapper: one interface over Gemini and OpenAI."""

import logging
from abc import ABC, abstractmethod
from typing import Dict, List, Optional

from openai import OpenAI
from vertexai.generative_models import GenerativeModel

logger = logging.getLogger(__name__)

class LLMProvider(ABC):
    """Common interface for chat-completion backends."""

    @abstractmethod
    def generate(self, prompt: Optional[str] = None, messages: Optional[List[Dict]] = None, **kwargs) -> str:
        """Return the model's text response for a prompt or a message list."""

class GeminiProvider(LLMProvider):
    """Backend that calls Gemini models through Vertex AI."""

    def generate(self, prompt: Optional[str] = None, messages: Optional[List[Dict]] = None, **kwargs) -> str:
        if not prompt and not messages:
            raise ValueError("Either 'prompt' or 'messages' must be provided.")

        model = GenerativeModel(model_name=kwargs.get("model", "gemini-pro"))

        # generate_content takes a single string here, so a chat-style
        # message list is flattened into "role: content" lines.
        content = prompt
        if messages and not prompt:
            content = "\n".join(f"{m['role']}: {m['content']}" for m in messages)

        # Map the OpenAI-style response_format flag onto Gemini's MIME type
        # so callers can request JSON the same way on both providers.
        generation_config = None
        if kwargs.get("response_format", {}).get("type") == "json_object":
            generation_config = {"response_mime_type": "application/json"}

        response = model.generate_content(content, generation_config=generation_config)
        return response.text

class OpenAIProvider(LLMProvider):
    """Backend that calls the OpenAI Chat Completions API."""

    def __init__(self, client: OpenAI):
        self.client = client

    def generate(self, prompt: Optional[str] = None, messages: Optional[List[Dict]] = None, **kwargs) -> str:
        # A bare prompt becomes a single-turn user message.
        if prompt and not messages:
            messages = [{"role": "user", "content": prompt}]
        if not messages:
            raise ValueError("Either 'prompt' or 'messages' must be provided.")

        completion_kwargs = {
            "model": kwargs.get("model", "gpt-4o"),
            "messages": messages,
            "max_tokens": kwargs.get("max_tokens", 1000),
            "temperature": kwargs.get("temperature", 1.0),
        }

        # Pass response_format (e.g. {"type": "json_object"}) through unchanged.
        if "response_format" in kwargs:
            completion_kwargs["response_format"] = kwargs["response_format"]

        response = self.client.chat.completions.create(**completion_kwargs)
        return response.choices[0].message.content

class LLMService:
    """Thin facade that routes chat calls to the configured provider."""

    def __init__(self, provider: LLMProvider):
        self.provider = provider

    def chat(
        self,
        prompt: Optional[str] = None,
        messages: Optional[List[Dict[str, str]]] = None,
        **kwargs
    ) -> str:
        try:
            return self.provider.generate(prompt=prompt, messages=messages, **kwargs)
        except Exception:
            # Log the failure, then re-raise so the caller decides how to recover.
            logger.exception("LLM API error")
            raise
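
# --- Usage sketch (illustrative, not part of the library) --------------------
# A minimal example of wiring both providers into LLMService. The GCP project
# ID and region below are placeholder assumptions; the OpenAI client reads
# OPENAI_API_KEY from the environment by default. Adjust for your setup.
if __name__ == "__main__":
    import vertexai

    # Gemini: Vertex AI must be initialized once per process before use.
    vertexai.init(project="my-gcp-project", location="us-central1")  # hypothetical values
    gemini_service = LLMService(GeminiProvider())
    print(gemini_service.chat(prompt="Name three uses of abstract base classes."))

    # OpenAI: the same response_format flag works on both providers.
    openai_service = LLMService(OpenAIProvider(OpenAI()))
    print(openai_service.chat(
        prompt="Return a JSON object with a 'status' key.",
        response_format={"type": "json_object"},
    ))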