import os

import google.generativeai as genai
from dotenv import load_dotenv

# Shared generation settings: cap the response length and keep sampling
# near-deterministic.
generation_config = genai.types.GenerationConfig(
    max_output_tokens=4096,
    temperature=0.1,
)


class GeminiModel:
    """
    Wraps a Google Gemini model for text generation.

    Args:
        model_name: Name of the model to use. Defaults to 'gemini-pro'.
    """

    def __init__(self, model_name: str = 'gemini-pro'):
        # Load the API key from a .env file and configure the client.
        load_dotenv()
        genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
        self.model = genai.GenerativeModel(model_name)

    def execute(self, prompt: str) -> str:
        """Generate text for the given prompt, logging token counts."""
        try:
            total_tokens = self.model.count_tokens(prompt).total_tokens
            print(f"Input tokens: {total_tokens}")
            response = self.model.generate_content(
                prompt, generation_config=generation_config
            )
            output_tokens = self.model.count_tokens(response.text).total_tokens
            print(f"Output tokens: {output_tokens}")
            return response.text
        except Exception as e:
            # Return the error as text so callers always receive a string.
            return f"An error occurred: {e}"