|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import base64 |
|
import datetime |
|
import io |
|
import json |
|
import os |
|
import pickle |
|
import socket |
|
import time |
|
import uuid |
|
import requests |
|
import logging |
|
from enum import Enum, IntEnum |
|
import importlib |
|
from Cryptodome.PublicKey import RSA |
|
from Cryptodome.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5 |
|
from filelock import FileLock |
|
from api.constants import SERVICE_CONF |
|
|
|
from . import file_utils |
|
|
|
|
|
def conf_realpath(conf_name): |
|
conf_path = f"conf/{conf_name}" |
|
return os.path.join(file_utils.get_project_base_directory(), conf_path) |
|
|
|
|
|
def read_config(conf_name=SERVICE_CONF): |
|
local_config = {} |
|
local_path = conf_realpath(f'local.{conf_name}') |
|
|
|
|
|
if os.path.exists(local_path): |
|
local_config = file_utils.load_yaml_conf(local_path) |
|
if not isinstance(local_config, dict): |
|
raise ValueError(f'Invalid config file: "{local_path}".') |
|
|
|
global_config_path = conf_realpath(conf_name) |
|
global_config = file_utils.load_yaml_conf(global_config_path) |
|
|
|
if not isinstance(global_config, dict): |
|
raise ValueError(f'Invalid config file: "{global_config_path}".') |
|
|
|
global_config.update(local_config) |
|
return global_config |
|
|
|
|
|
CONFIGS = read_config() |
|
|
|
|
|
def show_configs(): |
|
msg = f"Current configs, from {conf_realpath(SERVICE_CONF)}:" |
|
for k, v in CONFIGS.items(): |
|
msg += f"\n\t{k}: {v}" |
|
logging.info(msg) |
|
|
|
|
|
def get_base_config(key, default=None): |
|
if key is None: |
|
return None |
|
if default is None: |
|
default = os.environ.get(key.upper()) |
|
return CONFIGS.get(key, default) |
|
|
|
|
|
use_deserialize_safe_module = get_base_config( |
|
'use_deserialize_safe_module', False) |
|
|
|
|
|
class BaseType: |
|
def to_dict(self): |
|
return dict([(k.lstrip("_"), v) for k, v in self.__dict__.items()]) |
|
|
|
def to_dict_with_type(self): |
|
def _dict(obj): |
|
module = None |
|
if issubclass(obj.__class__, BaseType): |
|
data = {} |
|
for attr, v in obj.__dict__.items(): |
|
k = attr.lstrip("_") |
|
data[k] = _dict(v) |
|
module = obj.__module__ |
|
elif isinstance(obj, (list, tuple)): |
|
data = [] |
|
for i, vv in enumerate(obj): |
|
data.append(_dict(vv)) |
|
elif isinstance(obj, dict): |
|
data = {} |
|
for _k, vv in obj.items(): |
|
data[_k] = _dict(vv) |
|
else: |
|
data = obj |
|
return {"type": obj.__class__.__name__, |
|
"data": data, "module": module} |
|
|
|
return _dict(self) |
|
|
|
|
|
class CustomJSONEncoder(json.JSONEncoder): |
|
def __init__(self, **kwargs): |
|
self._with_type = kwargs.pop("with_type", False) |
|
super().__init__(**kwargs) |
|
|
|
def default(self, obj): |
|
if isinstance(obj, datetime.datetime): |
|
return obj.strftime('%Y-%m-%d %H:%M:%S') |
|
elif isinstance(obj, datetime.date): |
|
return obj.strftime('%Y-%m-%d') |
|
elif isinstance(obj, datetime.timedelta): |
|
return str(obj) |
|
elif issubclass(type(obj), Enum) or issubclass(type(obj), IntEnum): |
|
return obj.value |
|
elif isinstance(obj, set): |
|
return list(obj) |
|
elif issubclass(type(obj), BaseType): |
|
if not self._with_type: |
|
return obj.to_dict() |
|
else: |
|
return obj.to_dict_with_type() |
|
elif isinstance(obj, type): |
|
return obj.__name__ |
|
else: |
|
return json.JSONEncoder.default(self, obj) |
|
|
|
|
|
def rag_uuid(): |
|
return uuid.uuid1().hex |
|
|
|
|
|
def string_to_bytes(string): |
|
return string if isinstance( |
|
string, bytes) else string.encode(encoding="utf-8") |
|
|
|
|
|
def bytes_to_string(byte): |
|
return byte.decode(encoding="utf-8") |
|
|
|
|
|
def json_dumps(src, byte=False, indent=None, with_type=False): |
|
dest = json.dumps( |
|
src, |
|
indent=indent, |
|
cls=CustomJSONEncoder, |
|
with_type=with_type) |
|
if byte: |
|
dest = string_to_bytes(dest) |
|
return dest |
|
|
|
|
|
def json_loads(src, object_hook=None, object_pairs_hook=None): |
|
if isinstance(src, bytes): |
|
src = bytes_to_string(src) |
|
return json.loads(src, object_hook=object_hook, |
|
object_pairs_hook=object_pairs_hook) |
|
|
|
|
|
def current_timestamp(): |
|
return int(time.time() * 1000) |
|
|
|
|
|
def timestamp_to_date(timestamp, format_string="%Y-%m-%d %H:%M:%S"): |
|
if not timestamp: |
|
timestamp = time.time() |
|
timestamp = int(timestamp) / 1000 |
|
time_array = time.localtime(timestamp) |
|
str_date = time.strftime(format_string, time_array) |
|
return str_date |
|
|
|
|
|
def date_string_to_timestamp(time_str, format_string="%Y-%m-%d %H:%M:%S"): |
|
time_array = time.strptime(time_str, format_string) |
|
time_stamp = int(time.mktime(time_array) * 1000) |
|
return time_stamp |
|
|
|
|
|
def serialize_b64(src, to_str=False): |
|
dest = base64.b64encode(pickle.dumps(src)) |
|
if not to_str: |
|
return dest |
|
else: |
|
return bytes_to_string(dest) |
|
|
|
|
|
def deserialize_b64(src): |
|
src = base64.b64decode( |
|
string_to_bytes(src) if isinstance( |
|
src, str) else src) |
|
if use_deserialize_safe_module: |
|
return restricted_loads(src) |
|
return pickle.loads(src) |
|
|
|
|
|
safe_module = { |
|
'numpy', |
|
'rag_flow' |
|
} |
|
|
|
|
|
class RestrictedUnpickler(pickle.Unpickler): |
|
def find_class(self, module, name): |
|
import importlib |
|
if module.split('.')[0] in safe_module: |
|
_module = importlib.import_module(module) |
|
return getattr(_module, name) |
|
|
|
raise pickle.UnpicklingError("global '%s.%s' is forbidden" % |
|
(module, name)) |
|
|
|
|
|
def restricted_loads(src): |
|
"""Helper function analogous to pickle.loads().""" |
|
return RestrictedUnpickler(io.BytesIO(src)).load() |
|
|
|
|
|
def get_lan_ip(): |
|
if os.name != "nt": |
|
import fcntl |
|
import struct |
|
|
|
def get_interface_ip(ifname): |
|
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) |
|
return socket.inet_ntoa( |
|
fcntl.ioctl(s.fileno(), 0x8915, struct.pack('256s', string_to_bytes(ifname[:15])))[20:24]) |
|
|
|
ip = socket.gethostbyname(socket.getfqdn()) |
|
if ip.startswith("127.") and os.name != "nt": |
|
interfaces = [ |
|
"bond1", |
|
"eth0", |
|
"eth1", |
|
"eth2", |
|
"wlan0", |
|
"wlan1", |
|
"wifi0", |
|
"ath0", |
|
"ath1", |
|
"ppp0", |
|
] |
|
for ifname in interfaces: |
|
try: |
|
ip = get_interface_ip(ifname) |
|
break |
|
except IOError: |
|
pass |
|
return ip or '' |
|
|
|
|
|
def from_dict_hook(in_dict: dict): |
|
if "type" in in_dict and "data" in in_dict: |
|
if in_dict["module"] is None: |
|
return in_dict["data"] |
|
else: |
|
return getattr(importlib.import_module( |
|
in_dict["module"]), in_dict["type"])(**in_dict["data"]) |
|
else: |
|
return in_dict |
|
|
|
|
|
def decrypt_database_password(password): |
|
encrypt_password = get_base_config("encrypt_password", False) |
|
encrypt_module = get_base_config("encrypt_module", False) |
|
private_key = get_base_config("private_key", None) |
|
|
|
if not password or not encrypt_password: |
|
return password |
|
|
|
if not private_key: |
|
raise ValueError("No private key") |
|
|
|
module_fun = encrypt_module.split("#") |
|
pwdecrypt_fun = getattr( |
|
importlib.import_module( |
|
module_fun[0]), |
|
module_fun[1]) |
|
|
|
return pwdecrypt_fun(private_key, password) |
|
|
|
|
|
def decrypt_database_config( |
|
database=None, passwd_key="password", name="database"): |
|
if not database: |
|
database = get_base_config(name, {}) |
|
|
|
database[passwd_key] = decrypt_database_password(database[passwd_key]) |
|
return database |
|
|
|
|
|
def update_config(key, value, conf_name=SERVICE_CONF): |
|
conf_path = conf_realpath(conf_name=conf_name) |
|
if not os.path.isabs(conf_path): |
|
conf_path = os.path.join( |
|
file_utils.get_project_base_directory(), conf_path) |
|
|
|
with FileLock(os.path.join(os.path.dirname(conf_path), ".lock")): |
|
config = file_utils.load_yaml_conf(conf_path=conf_path) or {} |
|
config[key] = value |
|
file_utils.rewrite_yaml_conf(conf_path=conf_path, config=config) |
|
|
|
|
|
def get_uuid(): |
|
return uuid.uuid1().hex |
|
|
|
|
|
def datetime_format(date_time: datetime.datetime) -> datetime.datetime: |
|
return datetime.datetime(date_time.year, date_time.month, date_time.day, |
|
date_time.hour, date_time.minute, date_time.second) |
|
|
|
|
|
def get_format_time() -> datetime.datetime: |
|
return datetime_format(datetime.datetime.now()) |
|
|
|
|
|
def str2date(date_time: str): |
|
return datetime.datetime.strptime(date_time, '%Y-%m-%d') |
|
|
|
|
|
def elapsed2time(elapsed): |
|
seconds = elapsed / 1000 |
|
minuter, second = divmod(seconds, 60) |
|
hour, minuter = divmod(minuter, 60) |
|
return '%02d:%02d:%02d' % (hour, minuter, second) |
|
|
|
|
|
def decrypt(line): |
|
file_path = os.path.join( |
|
file_utils.get_project_base_directory(), |
|
"conf", |
|
"private.pem") |
|
rsa_key = RSA.importKey(open(file_path).read(), "Welcome") |
|
cipher = Cipher_pkcs1_v1_5.new(rsa_key) |
|
return cipher.decrypt(base64.b64decode( |
|
line), "Fail to decrypt password!").decode('utf-8') |
|
|
|
|
|
def download_img(url): |
|
if not url: |
|
return "" |
|
response = requests.get(url) |
|
return "data:" + \ |
|
response.headers.get('Content-Type', 'image/jpg') + ";" + \ |
|
"base64," + base64.b64encode(response.content).decode("utf-8") |
|
|
|
|
|
def delta_seconds(date_string: str): |
|
dt = datetime.datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S") |
|
return (datetime.datetime.now() - dt).total_seconds() |
|
|