# V-MAGE-EVAL-DEMO / provider/OpenAIProvider.py
# Fengx1n — "first commit" (504b2e4)
import configparser
import time
from typing import Any, Dict, List
from datetime import datetime
from utils.dict_utils import get_with_warning
class OpenAIProvider:
    """Chat-completion provider backed by an OpenAI-compatible HTTP API.

    Connection and sampling settings are read from ``params`` at construction
    time. Token usage reported by the server is accumulated across calls and
    exposed through :meth:`get_tokens_usage`.
    """

    # Seconds to pause after a failed request before returning the error, so a
    # caller that retries in a loop does not hammer the endpoint.
    _ERROR_BACKOFF_SECONDS = 5

    def __init__(self, params) -> None:
        """Store connection and generation settings.

        Args:
            params: provider settings looked up via ``get_with_warning``.
                ``model``, ``api_key`` and ``base_url`` are read with no
                default; ``top_p``, ``temperature`` and ``max_new_tokens``
                default to 0.9, 0.8 and 2048 respectively.
        """
        self.params = params
        self.model_path = get_with_warning(params, 'model')
        self.openai_api_key = get_with_warning(params, 'api_key')
        self.openai_api_base = get_with_warning(params, 'base_url')
        # NOTE(review): presumably tells callers to build image content in the
        # OpenAI message schema — confirm against the prompt builders.
        self.image_prompt_format = "openai"
        # Generation config
        self.top_p = float(get_with_warning(params, 'top_p', default=0.9))
        self.temperature = float(get_with_warning(params, 'temperature', default=0.8))
        self.max_new_tokens = int(get_with_warning(params, 'max_new_tokens', default=2048))
        # Running token totals across all successful requests.
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.total_tokens = 0
        # API client, created on first use and reused so connections are pooled.
        self._client = None

    def reset(self):
        """Hook invoked after a failed request; currently does nothing.

        The cumulative token counters are not cleared here.
        """
        pass

    def get_tokens_usage(self):
        """Return the cumulative token counts as a dict."""
        return {
            "prompt_tokens": self.prompt_tokens,
            "completion_tokens": self.completion_tokens,
            "total_tokens": self.total_tokens
        }

    @staticmethod
    def _mask_secret(secret):
        """Return a redacted form of ``secret`` that is safe to print.

        Falsy values are returned as ``str(secret)`` (e.g. ``"None"``) so a
        missing key is still visible in diagnostics. Secrets shorter than 12
        characters are fully starred; longer ones keep only their first 3 and
        last 4 characters.
        """
        if not secret:
            return str(secret)
        text = str(secret)
        if len(text) < 12:
            return "*" * len(text)
        return f"{text[:3]}...{text[-4:]}"

    def _accumulate_usage(self, usage):
        """Add one response's token counts to the running totals.

        Args:
            usage: the response's ``usage`` object, or ``None``. Missing or
                ``None`` counts are treated as 0, because some
                OpenAI-compatible servers omit them.

        Returns:
            ``(prompt_tokens, completion_tokens, total_tokens)`` that were added.
        """
        if usage is None:
            return 0, 0, 0
        prompt_tokens = getattr(usage, "prompt_tokens", None) or 0
        completion_tokens = getattr(usage, "completion_tokens", None) or 0
        total_tokens = prompt_tokens + completion_tokens
        self.prompt_tokens += prompt_tokens
        self.completion_tokens += completion_tokens
        self.total_tokens += total_tokens
        return prompt_tokens, completion_tokens, total_tokens

    def create_completion(self, message_prompts):
        """Send one chat-completion request and return its text.

        Args:
            message_prompts: chat messages, passed unchanged as ``messages`` to
                ``client.chat.completions.create``.

        Returns:
            ``(True, response_text)`` on success. ``response_text`` is the
            first choice's ``message.content``.

            ``(False, str(error))`` if building the client, sending the
            request, or reading the response raises. In that case the call
            first sleeps ``_ERROR_BACKOFF_SECONDS`` seconds.

        Raises:
            ImportError: if the ``openai`` package is not installed. The import
                is deliberately outside the ``try`` so this is not swallowed.
        """
        print(datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3], f" {self.model_path} create local completion...")
        from openai import OpenAI
        try:
            client = getattr(self, "_client", None)
            if client is None:
                client = OpenAI(
                    api_key=self.openai_api_key,
                    base_url=self.openai_api_base,
                )
                self._client = client
            chat_response = client.chat.completions.create(
                model=self.model_path,
                messages=message_prompts,
                temperature=self.temperature,
                top_p=self.top_p,
                max_tokens=self.max_new_tokens
            )
            response_text = chat_response.choices[0].message.content
            # Report token usage. Servers that omit `usage` must not turn an
            # otherwise successful reply into a failure.
            usage = getattr(chat_response, "usage", None)
            if usage is None:
                print("Token Usage - not reported by server")
            else:
                prompt_tokens, completion_tokens, total_tokens = self._accumulate_usage(usage)
                print(f"Token Usage - Prompt Tokens: {prompt_tokens}, Completion Tokens: {completion_tokens}, Total Tokens: {total_tokens}")
            return True, response_text
        except Exception as e:
            # Best-effort contract: report the failure to the caller rather
            # than raising.
            print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())), e)
            # Dump the request configuration for diagnosis. The API key is
            # masked so credentials never reach logs.
            print(self.model_path)
            print(self._mask_secret(self.openai_api_key))
            print(self.openai_api_base)
            print(self.top_p)
            print(self.temperature)
            time.sleep(self._ERROR_BACKOFF_SECONDS)
            self.reset()
            return False, str(e)