import configparser
import time
from typing import Any, Dict, List, Optional, Tuple
from datetime import datetime

from utils.dict_utils import get_with_warning


class OpenAIProvider:
    """Chat-completion provider backed by an OpenAI-compatible HTTP endpoint.

    Connection settings and sampling parameters are read from ``params`` via
    ``get_with_warning``.  The instance keeps running totals of the token
    usage reported by the backend across all successful calls.
    """

    # Seconds to pause after a failed request before returning to the caller.
    # NOTE(review): presumably a crude back-off for callers that retry --
    # confirm against call sites.
    ERROR_BACKOFF_SECONDS = 5

    def __init__(self, params) -> None:
        """Read connection and generation settings from ``params``.

        Args:
            params: Configuration object understood by ``get_with_warning``.
                The keys ``model``, ``api_key`` and ``base_url`` are read
                without a default; ``top_p`` (default 0.9), ``temperature``
                (default 0.8) and ``max_new_tokens`` (default 2048) are
                optional.
        """
        self.params = params
        self.model_path = get_with_warning(params, 'model')
        self.openai_api_key = get_with_warning(params, 'api_key')
        self.openai_api_base = get_with_warning(params, 'base_url')
        # Not used inside this class; presumably read by callers that build
        # image prompts -- TODO confirm.
        self.image_prompt_format = "openai"

        # Generation config
        self.top_p = float(get_with_warning(params, 'top_p', default=0.9))
        self.temperature = float(get_with_warning(params, 'temperature', default=0.8))
        self.max_new_tokens = int(get_with_warning(params, 'max_new_tokens', default=2048))

        # Running token totals, accumulated by create_completion().
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.total_tokens = 0

    def reset(self):
        """No-op hook invoked after a failed request."""
        pass

    def get_tokens_usage(self) -> Dict[str, int]:
        """Return the accumulated prompt, completion and total token counts."""
        return {
            "prompt_tokens": self.prompt_tokens,
            "completion_tokens": self.completion_tokens,
            "total_tokens": self.total_tokens
        }

    @staticmethod
    def _redact(secret: Any) -> str:
        """Return a form of ``secret`` that is safe to write to logs."""
        if not secret:
            return "<unset>"
        text = str(secret)
        if len(text) <= 8:
            # Too short to reveal any part of it without exposing most of it.
            return "*" * len(text)
        return f"{text[:3]}...{text[-4:]}"

    def create_completion(
        self, message_prompts: List[Dict[str, Any]]
    ) -> Tuple[bool, Optional[str]]:
        """Send ``message_prompts`` to the chat-completions endpoint.

        Args:
            message_prompts: Messages passed unchanged as the ``messages``
                argument of ``client.chat.completions.create``.

        Returns:
            ``(True, text)`` on success, where ``text`` is the content of the
            first choice (as returned by the client, so it may be ``None``).
            ``(False, error_message)`` if any exception was raised; in that
            case diagnostics are printed, the call sleeps for
            ``ERROR_BACKOFF_SECONDS`` and ``reset()`` is invoked first.
        """
        print(datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3], f" {self.model_path} create local completion...")
        from openai import OpenAI
        try:
            client = OpenAI(
                api_key=self.openai_api_key,
                base_url=self.openai_api_base,
            )
            chat_response = client.chat.completions.create(
                model=self.model_path,
                messages=message_prompts,
                temperature=self.temperature,
                top_p=self.top_p,
                max_tokens=self.max_new_tokens
            )
            response_text = chat_response.choices[0].message.content

            # Some OpenAI-compatible servers omit the usage block; a missing
            # block must not turn a successful completion into a failure, so
            # absent counts are treated as 0.
            usage = getattr(chat_response, "usage", None)
            prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
            completion_tokens = getattr(usage, "completion_tokens", 0) or 0

            # Print token usage statistics
            total_tokens = prompt_tokens + completion_tokens
            print(f"Token Usage - Prompt Tokens: {prompt_tokens}, Completion Tokens: {completion_tokens}, Total Tokens: {total_tokens}")

            self.prompt_tokens += prompt_tokens
            self.completion_tokens += completion_tokens
            self.total_tokens += total_tokens

            return True, response_text
        except Exception as e:
            # Boundary handler: callers rely on the (False, message) return
            # rather than on an exception propagating.
            print(time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())), e)
            print(self.model_path)
            # Never write the raw API key to stdout/logs.
            print(self._redact(self.openai_api_key))
            print(self.openai_api_base)
            print(self.top_p)
            print(self.temperature)
            time.sleep(self.ERROR_BACKOFF_SECONDS)
            self.reset()
            return False, str(e)