|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import base64
|
|
from datetime import datetime
|
|
import io
|
|
import json
|
|
import os
|
|
import pickle
|
|
import socket
|
|
import time
|
|
import uuid
|
|
import requests
|
|
from enum import Enum, IntEnum
|
|
import importlib
|
|
from Cryptodome.PublicKey import RSA
|
|
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5
|
|
|
|
from filelock import FileLock
|
|
|
|
from . import file_utils
|
|
|
|
# File name of the main service configuration, looked up under "<project>/conf/".
SERVICE_CONF: str = "service_conf.yaml"
|
|
|
|
def conf_realpath(conf_name):
    """Return the path of ``conf_name`` inside the project's ``conf`` directory.

    The name is prefixed with ``conf/`` and joined onto the project base
    directory reported by ``file_utils.get_project_base_directory()``.
    """
    relative_path = f"conf/{conf_name}"
    base_dir = file_utils.get_project_base_directory()
    return os.path.join(base_dir, relative_path)
|
|
|
|
def get_base_config(key, default=None, conf_name=SERVICE_CONF):
    """Look up ``key`` in the service configuration, honouring local overrides.

    Resolution order:
      1. ``conf/local.<conf_name>``, if that file exists and contains ``key``;
      2. ``conf/<conf_name>`` updated with the local overrides;
      3. ``default``. When ``default`` is None and ``key`` is given, it falls
         back to the environment variable named ``key.upper()``.

    Args:
        key: Top-level config key, or None to get the whole merged config.
        default: Value returned when ``key`` is in neither file.
        conf_name: Base configuration file name under ``conf/``.

    Returns:
        The configured value for ``key`` (any YAML type), or the merged config
        dict when ``key`` is None.

    Raises:
        ValueError: if either YAML file does not load as a mapping.
    """
    local_config = {}
    local_path = conf_realpath(f'local.{conf_name}')
    # ``key`` may be None (whole-config lookup). Only a concrete key has an
    # environment-variable fallback, and ``None.upper()`` would raise.
    if default is None and key is not None:
        default = os.environ.get(key.upper())

    if os.path.exists(local_path):
        local_config = file_utils.load_yaml_conf(local_path)
        if not isinstance(local_config, dict):
            raise ValueError(f'Invalid config file: "{local_path}".')
        # A local override wins outright for a single-key lookup.
        if key is not None and key in local_config:
            return local_config[key]

    config_path = conf_realpath(conf_name)
    config = file_utils.load_yaml_conf(config_path)
    if not isinstance(config, dict):
        raise ValueError(f'Invalid config file: "{config_path}".')

    config.update(local_config)
    return config.get(key, default) if key is not None else config
|
|
|
|
|
|
# When true, deserialize_b64 routes payloads through restricted_loads instead of
# plain pickle.loads. Read once at import time; defaults to False.
use_deserialize_safe_module = get_base_config(key='use_deserialize_safe_module', default=False)
|
|
|
|
|
|
class CoordinationCommunicationProtocol:
    """String constants for the coordination communication protocols."""

    HTTP = "http"
    GRPC = "grpc"
|
|
|
|
|
|
class BaseType:
    """Mixin that exposes an object's attributes as (optionally type-tagged) dicts.

    Attribute names have leading underscores stripped, so a private ``_name``
    attribute is exported under the key ``name``.
    """

    def to_dict(self):
        """Return a shallow dict of instance attributes (leading ``_`` stripped)."""
        return {attr.lstrip("_"): value for attr, value in self.__dict__.items()}

    def to_dict_with_type(self):
        """Return a recursive, type-tagged representation of this object.

        Every value becomes ``{"type": <class name>, "data": ..., "module": ...}``.
        ``module`` is the defining module for ``BaseType`` subclasses (used by
        ``from_dict_hook`` to rebuild them) and None for everything else.
        Lists and tuples become lists of tagged items. Dicts keep their keys
        and have their values tagged. Any other value is stored unchanged in
        ``data``.
        """

        def _dict(obj):
            module = None
            if issubclass(obj.__class__, BaseType):
                data = {attr.lstrip("_"): _dict(value) for attr, value in obj.__dict__.items()}
                module = obj.__module__
            elif isinstance(obj, (list, tuple)):
                data = [_dict(item) for item in obj]
            elif isinstance(obj, dict):
                data = {k: _dict(v) for k, v in obj.items()}
            else:
                data = obj
            return {"type": obj.__class__.__name__, "data": data, "module": module}

        return _dict(self)
|
|
|
|
|
|
class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder that also handles dates/times, enums, sets, classes and BaseType.

    Pass ``with_type=True`` (e.g. ``json.dumps(obj, cls=CustomJSONEncoder,
    with_type=True)``) to serialize BaseType objects with ``to_dict_with_type``
    instead of ``to_dict``.
    """

    def __init__(self, **kwargs):
        # Consume our private option before JSONEncoder sees (and rejects) it.
        self._with_type = kwargs.pop("with_type", False)
        super().__init__(**kwargs)

    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj``.

        Conversions:
          - ``datetime`` -> ``"%Y-%m-%d %H:%M:%S"``, ``date`` -> ``"%Y-%m-%d"``
          - ``timedelta`` -> ``str(obj)``
          - ``Enum`` member -> its ``value``
          - ``set`` -> list
          - ``BaseType`` -> ``to_dict()`` or ``to_dict_with_type()``
          - a class -> its ``__name__``

        Raises:
            TypeError: for any other object (via ``JSONEncoder.default``).
        """
        # The module-level name ``datetime`` is bound to the *class*
        # (``from datetime import datetime``), so ``datetime.datetime`` would
        # raise AttributeError. Use the module explicitly.
        import datetime as dt

        # datetime subclasses date, so it must be checked first.
        if isinstance(obj, dt.datetime):
            return obj.strftime('%Y-%m-%d %H:%M:%S')
        if isinstance(obj, dt.date):
            return obj.strftime('%Y-%m-%d')
        if isinstance(obj, dt.timedelta):
            return str(obj)
        # IntEnum derives from Enum, so a single check covers both.
        if isinstance(obj, Enum):
            return obj.value
        if isinstance(obj, set):
            return list(obj)
        if isinstance(obj, BaseType):
            return obj.to_dict_with_type() if self._with_type else obj.to_dict()
        if isinstance(obj, type):
            return obj.__name__
        return super().default(obj)
|
|
|
|
|
|
def rag_uuid():
    """Return a new time-based (UUID1) identifier as 32 lowercase hex characters."""
    new_id = uuid.uuid1()
    return new_id.hex
|
|
|
|
|
|
def string_to_bytes(string):
    """Encode ``string`` as UTF-8; ``bytes`` input is returned unchanged."""
    if isinstance(string, bytes):
        return string
    return string.encode(encoding="utf-8")
|
|
|
|
|
|
def bytes_to_string(byte):
    """Decode the UTF-8 encoded ``byte`` sequence into ``str``."""
    text = byte.decode(encoding="utf-8")
    return text
|
|
|
|
|
|
def json_dumps(src, byte=False, indent=None, with_type=False):
    """Serialize ``src`` to JSON with ``CustomJSONEncoder``.

    Args:
        src: Object to serialize.
        byte: When true, return UTF-8 ``bytes`` instead of ``str``.
        indent: Passed through to ``json.dumps``.
        with_type: Forwarded to the encoder; tags BaseType objects with types.
    """
    encoded = json.dumps(src, indent=indent, cls=CustomJSONEncoder, with_type=with_type)
    return string_to_bytes(encoded) if byte else encoded
|
|
|
|
|
|
def json_loads(src, object_hook=None, object_pairs_hook=None):
    """Parse JSON from ``str`` or UTF-8 ``bytes``, forwarding the optional hooks."""
    text = src.decode(encoding="utf-8") if isinstance(src, bytes) else src
    return json.loads(text, object_hook=object_hook, object_pairs_hook=object_pairs_hook)
|
|
|
|
|
|
def current_timestamp():
    """Return the current Unix time as an integer number of milliseconds."""
    now_seconds = time.time()
    return int(now_seconds * 1000)
|
|
|
|
|
|
def timestamp_to_date(timestamp, format_string="%Y-%m-%d %H:%M:%S"):
    """Format a millisecond Unix timestamp as a local-time string.

    Args:
        timestamp: Milliseconds since the epoch (anything ``int()`` accepts).
            A falsy value (None, 0, "") means "now".
        format_string: ``time.strftime`` format to apply.

    Returns:
        The formatted local time.
    """
    if not timestamp:
        # time.time() is in seconds; scale it to milliseconds so the shared
        # ms -> s conversion below is correct for the "now" fallback too.
        timestamp = time.time() * 1000
    timestamp = int(timestamp) / 1000
    time_array = time.localtime(timestamp)
    return time.strftime(format_string, time_array)
|
|
|
|
|
|
def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"):
    """Parse a local-time string and return it as integer Unix milliseconds."""
    seconds = time.mktime(time.strptime(time_str, format_string))
    return int(seconds * 1000)
|
|
|
|
|
|
def serialize_b64(src, to_str=False):
    """Pickle ``src`` and base64-encode the result.

    Returns:
        The base64 ``bytes``, or the same data as a ``str`` when ``to_str`` is true.
    """
    encoded = base64.b64encode(pickle.dumps(src))
    return encoded.decode(encoding="utf-8") if to_str else encoded
|
|
|
|
|
|
def deserialize_b64(src):
    """Base64-decode ``src`` (``str`` or ``bytes``) and unpickle the payload.

    When the ``use_deserialize_safe_module`` setting is on, the payload goes
    through ``restricted_loads`` (allow-listed globals only). Otherwise it uses
    plain ``pickle.loads``.
    """
    raw = src.encode(encoding="utf-8") if isinstance(src, str) else src
    payload = base64.b64decode(raw)
    if not use_deserialize_safe_module:
        # SECURITY: plain pickle.loads can execute arbitrary code embedded in the
        # payload. Only reach this path with trusted input.
        return pickle.loads(payload)
    return restricted_loads(payload)
|
|
|
|
|
|
# Top-level package names whose globals RestrictedUnpickler may resolve.
safe_module = {'numpy', 'fate_flow'}
|
|
|
|
|
|
class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that resolves globals only from packages listed in ``safe_module``."""

    def find_class(self, module, name):
        """Return ``module.name`` if the module's root package is allow-listed.

        Raises:
            pickle.UnpicklingError: for any global outside the allow-list.
        """
        root_package = module.partition('.')[0]
        if root_package not in safe_module:
            raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")
        return getattr(importlib.import_module(module), name)
|
|
|
|
|
|
def restricted_loads(src):
    """Unpickle the ``src`` bytes like ``pickle.loads``, but via RestrictedUnpickler."""
    buffer = io.BytesIO(src)
    unpickler = RestrictedUnpickler(buffer)
    return unpickler.load()
|
|
|
|
|
|
def get_lan_ip():
    """Best-effort guess at this host's LAN IPv4 address.

    Resolves the host's FQDN first. On non-Windows systems, if the result is a
    loopback address (``127.*``) or the name does not resolve, it then probes a
    fixed list of common interface names with an ioctl. The first interface
    that answers wins.

    Returns:
        The IPv4 address as a string, or ``""`` if nothing could be determined.
    """
    if os.name != "nt":
        import fcntl
        import struct

    def get_interface_ip(ifname):
        # 0x8915 is Linux's SIOCGIFADDR request. The packed IPv4 address sits
        # at bytes 20:24 of the returned buffer. The context manager makes sure
        # the probe socket is closed even when the ioctl fails.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            ifreq = struct.pack('256s', ifname[:15].encode("utf-8"))
            return socket.inet_ntoa(fcntl.ioctl(sock.fileno(), 0x8915, ifreq)[20:24])

    try:
        ip = socket.gethostbyname(socket.getfqdn())
    except OSError:
        # Hostname does not resolve; fall back to probing interfaces below.
        ip = ""

    if (not ip or ip.startswith("127.")) and os.name != "nt":
        interfaces = [
            "bond1",
            "eth0",
            "eth1",
            "eth2",
            "wlan0",
            "wlan1",
            "wifi0",
            "ath0",
            "ath1",
            "ppp0",
        ]
        for ifname in interfaces:
            try:
                ip = get_interface_ip(ifname)
                break
            except OSError:
                # Interface absent or has no IPv4 address; try the next one.
                continue
    return ip or ''
|
|
|
|
def from_dict_hook(in_dict: dict):
    """``json.loads`` object hook that unwraps dicts tagged by ``to_dict_with_type``.

    A dict with all of ``"type"``, ``"data"`` and ``"module"`` keys is treated
    as a type tag:
      - ``module`` None -> return ``data`` as-is;
      - otherwise -> import ``module`` and call its ``type`` attribute with
        ``data`` as keyword arguments.

    Any other dict is returned unchanged. That includes ordinary payloads that
    happen to contain ``"type"`` and ``"data"`` keys.

    SECURITY: a tag names an arbitrary module and callable that is imported
    and invoked. Only use this hook on trusted JSON.
    """
    if not ("type" in in_dict and "data" in in_dict and "module" in in_dict):
        return in_dict
    if in_dict["module"] is None:
        return in_dict["data"]
    target = getattr(importlib.import_module(in_dict["module"]), in_dict["type"])
    return target(**in_dict["data"])
|
|
|
|
|
|
def decrypt_database_password(password):
    """Decrypt a database password when password encryption is enabled.

    Reads ``encrypt_password``, ``private_key`` and ``encrypt_module`` from the
    service configuration. ``encrypt_module`` must look like
    ``"package.module#function"``. The named function is called as
    ``function(private_key, password)``.

    Returns:
        ``password`` unchanged if it is empty or encryption is disabled.
        Otherwise, the decrypting function's result.

    Raises:
        ValueError: if encryption is enabled but no private key is configured,
            or if ``encrypt_module`` is missing or not of the form ``module#function``.
    """
    encrypt_password = get_base_config("encrypt_password", False)
    if not password or not encrypt_password:
        return password

    private_key = get_base_config("private_key", None)
    if not private_key:
        raise ValueError("No private key")

    encrypt_module = get_base_config("encrypt_module", False)
    # Fail with a clear ValueError instead of AttributeError/IndexError when the
    # setting is absent (False) or lacks the "#function" part.
    if not isinstance(encrypt_module, str) or "#" not in encrypt_module:
        raise ValueError(f'Invalid encrypt_module "{encrypt_module}": expected "module#function".')
    module_name, func_name = encrypt_module.split("#")[:2]
    pwdecrypt_fun = getattr(importlib.import_module(module_name), func_name)

    return pwdecrypt_fun(private_key, password)
|
|
|
|
|
|
def decrypt_database_config(database=None, passwd_key="passwd", name="database"):
    """Return a database config dict with its password decrypted.

    Args:
        database: Config dict to process. When falsy, the ``name`` section of
            the service configuration is loaded instead. A dict passed in is
            modified in place.
        passwd_key: Key holding the (possibly encrypted) password.
        name: Config section to load when ``database`` is not given.

    Returns:
        The same dict with ``passwd_key`` replaced by its decrypted value. If
        the dict has no ``passwd_key`` entry, it is returned unchanged.
    """
    if not database:
        database = get_base_config(name, {})

    # Tolerate configs without a password entry instead of raising KeyError.
    if passwd_key in database:
        database[passwd_key] = decrypt_database_password(database[passwd_key])
    return database
|
|
|
|
|
|
def update_config(key, value, conf_name=SERVICE_CONF):
    """Set ``key`` to ``value`` in ``conf/<conf_name>`` and rewrite the file.

    The read-modify-write runs while holding a ``FileLock`` on ``.lock`` in
    the same directory as the config file. This serialises it with other
    writers that take the same lock.
    """
    conf_path = conf_realpath(conf_name=conf_name)
    # conf_realpath already prefixes the project base directory. Joining the
    # base again would duplicate it whenever that base is relative, so just
    # make the path absolute.
    if not os.path.isabs(conf_path):
        conf_path = os.path.abspath(conf_path)

    with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")):
        config = file_utils.load_yaml_conf(conf_path=conf_path) or {}
        config[key] = value
        file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config)
|
|
|
|
|
|
def get_uuid():
    """Return a fresh UUID1 as a 32-character lowercase hex string without dashes."""
    return str(uuid.uuid1()).replace("-", "")
|
|
|
|
|
|
def datetime_format(date_time: datetime) -> datetime:
    """Return a copy of ``date_time`` truncated to whole seconds.

    Microseconds are dropped, and so is any ``tzinfo``. The result keeps the
    same wall-clock fields but is always naive.
    """
    fields = ("year", "month", "day", "hour", "minute", "second")
    return datetime(**{field: getattr(date_time, field) for field in fields})
|
|
|
|
|
|
def get_format_time() -> datetime:
    """Return the current local time as a naive datetime truncated to whole seconds."""
    now = datetime.now()
    return datetime(now.year, now.month, now.day, now.hour, now.minute, now.second)
|
|
|
|
|
|
def str2date(date_time: str):
    """Parse a ``YYYY-MM-DD`` string into a naive ``datetime`` at midnight."""
    date_pattern = '%Y-%m-%d'
    return datetime.strptime(date_time, date_pattern)
|
|
|
|
|
|
def elapsed2time(elapsed):
    """Format a duration in milliseconds as ``HH:MM:SS``.

    Each field is truncated to an integer. Hours are not wrapped at 24.
    """
    total_seconds = elapsed / 1000
    total_minutes, secs = divmod(total_seconds, 60)
    hours, mins = divmod(total_minutes, 60)
    parts = (hours, mins, secs)
    return '%02d:%02d:%02d' % parts
|
|
|
|
|
|
def decrypt(line):
    """Decrypt a base64-encoded RSA PKCS#1 v1.5 ciphertext using ``conf/private.pem``.

    Args:
        line: Base64 text (or bytes) of the ciphertext.

    Returns:
        The UTF-8 decoded plaintext. If decryption fails, the string
        ``"Fail to decrypt password!"`` is returned instead.
    """
    file_path = os.path.join(file_utils.get_project_base_directory(), "conf", "private.pem")
    # Read the key inside a context manager so the file handle is always closed.
    with open(file_path) as key_file:
        # NOTE(review): the key passphrase is hard-coded; consider moving it to config.
        rsa_key = RSA.importKey(key_file.read(), "Welcome")
    cipher = Cipher_pkcs1_v1_5.new(rsa_key)
    # PKCS1_v1_5.decrypt returns the sentinel instead of raising on failure.
    # The sentinel must be bytes so the ``.decode`` below works on both paths.
    return cipher.decrypt(base64.b64decode(line), b"Fail to decrypt password!").decode('utf-8')
|
|
|
|
|
|
def download_img(url, timeout=30):
    """Download ``url`` and return it as a base64 ``data:`` URI.

    Args:
        url: Image URL. A falsy value returns ``""`` without making a request.
        timeout: Seconds passed to ``requests.get`` so a stalled server cannot
            block the caller forever. ``None`` disables the timeout.

    Returns:
        ``"data:<Content-Type>;base64,<payload>"``. The content type falls back
        to ``image/jpg`` when the response does not send one.
    """
    if not url:
        return ""
    response = requests.get(url, timeout=timeout)
    content_type = response.headers.get('Content-Type', 'image/jpg')
    payload = base64.b64encode(response.content).decode("utf-8")
    return f"data:{content_type};base64,{payload}"
|
|
|