import configparser
import time
from typing import Any, Dict, List
from datetime import datetime

from utils.dict_utils import get_with_warning
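# Assumption: get_with_warning(params, key, default=...) returns params[key] when the
# key is present, otherwise logs a warning and returns the given default. Its actual
# implementation lives in utils.dict_utils and is not shown in this file.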

class OpenAIProvider:
    """Chat-completion provider for OpenAI-compatible endpoints."""

    def __init__(self, params: Dict[str, Any]) -> None:
        self.params = params
        # Endpoint configuration
        self.model_path = get_with_warning(params, 'model')
        self.openai_api_key = get_with_warning(params, 'api_key')
        self.openai_api_base = get_with_warning(params, 'base_url')
        self.image_prompt_format = "openai"
        # Generation config
        self.top_p = float(get_with_warning(params, 'top_p', default=0.9))
        self.temperature = float(get_with_warning(params, 'temperature', default=0.8))
        self.max_new_tokens = int(get_with_warning(params, 'max_new_tokens', default=2048))
        # Cumulative token usage across all completion calls
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.total_tokens = 0
    def reset(self) -> None:
        """Reset per-conversation state (nothing to reset for this provider)."""
        pass

    def get_tokens_usage(self) -> Dict[str, int]:
        """Return cumulative token usage accumulated over all calls."""
        return {
            "prompt_tokens": self.prompt_tokens,
            "completion_tokens": self.completion_tokens,
            "total_tokens": self.total_tokens
        }
    def create_completion(self, message_prompts):
        """Send a chat completion request and return (success, response_text_or_error)."""
        print(datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')[:-3],
              f" {self.model_path} create local completion...")
        from openai import OpenAI
        try:
            client = OpenAI(
                api_key=self.openai_api_key,
                base_url=self.openai_api_base,
            )
            chat_response = client.chat.completions.create(
                model=self.model_path,
                messages=message_prompts,
                temperature=self.temperature,
                top_p=self.top_p,
                max_tokens=self.max_new_tokens
            )
            response_text = chat_response.choices[0].message.content
            # print(chat_response)
            # Print token usage statistics
            prompt_tokens = chat_response.usage.prompt_tokens
            completion_tokens = chat_response.usage.completion_tokens
            total_tokens = prompt_tokens + completion_tokens
            print(f"Token Usage - Prompt Tokens: {prompt_tokens}, "
                  f"Completion Tokens: {completion_tokens}, Total Tokens: {total_tokens}")
            # Accumulate usage across calls
            self.prompt_tokens += prompt_tokens
            self.completion_tokens += completion_tokens
            self.total_tokens += total_tokens
            return True, response_text
        except Exception as e:
            # Log the error and the request configuration, back off briefly, then report failure
            print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())), e)
            print(self.model_path)
            print(self.openai_api_key)
            print(self.openai_api_base)
            print(self.top_p)
            print(self.temperature)
            time.sleep(5)
            self.reset()
            return False, str(e)
        # print(f'response: {response_text}')
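

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes the `openai`
    # package is installed, an OpenAI-compatible endpoint is reachable at `base_url`,
    # and the model name, API key, and URL below are placeholders you replace.
    params = {
        "model": "gpt-4o-mini",                   # placeholder model name
        "api_key": "YOUR_API_KEY",                # placeholder credential
        "base_url": "https://api.openai.com/v1",  # or a self-hosted OpenAI-compatible server
        "temperature": 0.7,
        "top_p": 0.9,
        "max_new_tokens": 512,
    }
    provider = OpenAIProvider(params)
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello in one sentence."},
    ]
    ok, text = provider.create_completion(messages)
    print("success:", ok)
    print("response:", text)
    print("usage:", provider.get_tokens_usage())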