File size: 6,094 Bytes
8ebda9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
from dataclasses import dataclass, field
import os
import json
import logging
from argparse import Namespace
from typing import List, Literal, Optional, Union
from pydantic import AnyHttpUrl, BaseSettings, HttpUrl, validator, BaseModel
CURRENT_DIR_PATH = os.path.dirname(os.path.abspath(__file__))
# request body
# 使用pydantic对请求中的body数据进行验证
class RequestDataStructure(BaseModel):
input_text: List[str] = [""]
uuid: Optional[int]
# parameters for text2image model
input_image: Optional[str]
skip_steps: Optional[int]
clip_guidance_scale: Optional[int]
init_scale: Optional[int]
# API config
@dataclass
class APIConfig:
# server config
SERVER_HOST: AnyHttpUrl = "127.0.0.1"
SERVER_PORT: int = 8990
SERVER_NAME: str = ""
PROJECT_NAME: str = ""
API_PREFIX_STR: str = "/api"
# api config
API_method: Literal["POST","GET","PUT","OPTIONS","WEBSOCKET","PATCH","DELETE","TRACE","CONNECT"] = "POST"
API_path: str = "/TextClassification"
API_tags: List[str] = field(default_factory = lambda: [""])
# CORS config
BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = field(default_factory = lambda: ["*"])
allow_credentials: bool = True
allow_methods: List[str] = field(default_factory = lambda: ["*"])
allow_headers: List[str] = field(default_factory = lambda: ["*"])
# log config
log_file_path: str = ""
log_level: str = "INFO"
# pipeline config
pipeline_type: str = ""
model_name: str = ""
# model config
# device: int = -1
# texta_name: Optional[str] = "sentence"
# textb_name: Optional[str] = "sentence2"
# label_name: Optional[str] = "label"
# max_length: int = 512
# return_tensors: str = "pt"
# padding: str = "longest"
# truncation: bool = True
# skip_special_tokens: bool = True
# clean_up_tkenization_spaces: bool = True
# # parameters for text2image model
# skip_steps: Optional[int] = 0
# clip_guidance_scale: Optional[int] = 0
# init_scale: Optional[int] = 0
def setup_config(self, args:Namespace) -> None:
# load config file
with open(CURRENT_DIR_PATH + "/" + args.config_path, "r") as jsonfile:
config = json.load(jsonfile)
server_config = config["SERVER"]
logging_config = config["LOGGING"]
pipeline_config = config["PIPELINE"]
# server config
self.SERVER_HOST: AnyHttpUrl = server_config["SERVER_HOST"]
self.SERVER_PORT: int = server_config["SERVER_PORT"]
self.SERVER_NAME: str = server_config["SERVER_NAME"]
self.PROJECT_NAME: str = server_config["PROJECT_NAME"]
self.API_PREFIX_STR: str = server_config["API_PREFIX_STR"]
# api config
self.API_method: Literal["POST","GET","PUT","OPTIONS","WEBSOCKET","PATCH","DELETE","TRACE","CONNECT"] = server_config["API_method"]
self.API_path: str = server_config["API_path"]
self.API_tags: List[str] = server_config["API_tags"]
# CORS config
self.BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = server_config["BACKEND_CORS_ORIGINS"]
self.allow_credentials: bool = server_config["allow_credentials"]
self.allow_methods: List[str] = server_config["allow_methods"]
self.allow_headers: List[str] = server_config["allow_headers"]
# log config
self.log_file_path: str = logging_config["log_file_path"]
self.log_level: str = logging_config["log_level"]
# pipeline config
self.pipeline_type: str = pipeline_config["pipeline_type"]
self.model_name: str = pipeline_config["model_name"]
# general model config
self.model_settings: dict = pipeline_config["model_settings"]
# 由于pipeline本身会解析参数,后续参数可以不要
# 直接将model_settings字典转为Namespace后作为pipeline的args参数即可
# self.device: int = self.model_settings["device"]
# self.texta_name: Optional[str] = self.model_settings["texta_name"]
# self.textb_name: Optional[str] = self.model_settings["textb_name"]
# self.label_name: Optional[str] = self.model_settings["label_name"]
# self.max_length: int = self.model_settings["max_length"]
# self.return_tensors: str = self.model_settings["return_tensors"]
# self.padding: str = self.model_settings["padding"]
# self.truncation: bool = self.model_settings["truncation"]
# self.skip_special_tokens: bool = self.model_settings["skip_special_tokens"]
# self.clean_up_tkenization_spaces: bool = self.model_settings["clean_up_tkenization_spaces"]
# # specific parameters for text2image model
# self.skip_steps: Optional[int] = self.model_settings["skip_steps"]
# self.clip_guidance_scale: Optional[int] = self.model_settings["clip_guidance_scale"]
# self.init_scale: Optional[int] = self.model_settings["init_scale"]
def setup_logger(logger, user_config: APIConfig):
# default level: INFO
logger.setLevel(getattr(logging, user_config.log_level, "INFO"))
ch = logging.StreamHandler()
if(user_config.log_file_path == ""):
fh = logging.FileHandler(filename = CURRENT_DIR_PATH + "/" + user_config.SERVER_NAME + ".log")
elif(".log" not in user_config.log_file_path[-5:-1]):
fh = logging.FileHandler(filename = user_config.log_file_path + "/" + user_config.SERVER_NAME + ".log")
else:
fh = logging.FileHandler(filename = user_config.log_file_path)
formatter = logging.Formatter(
"%(asctime)s - %(module)s - %(funcName)s - line:%(lineno)d - %(levelname)s - %(message)s"
)
ch.setFormatter(formatter)
fh.setFormatter(formatter)
logger.addHandler(ch) # Exporting logs to the screen
logger.addHandler(fh) # Exporting logs to a file
return logger
user_config = APIConfig()
api_logger = logging.getLogger()
|