from flask import Flask, request, Response
import logging
import threading
from huggingface_hub import snapshot_download  #, Repository
import huggingface_hub
import gc
import os
import xml.etree.ElementTree as ET
from apscheduler.schedulers.background import BackgroundScheduler
from datetime import datetime, timedelta
from llm_backend import LlmBackend
import json
import sys
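
# Shared backend instance (llm_backend.py wraps the model runtime); used by all request handlers.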
llm = LlmBackend()
_lock = threading.Lock()
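
# Configuration is read from environment variables, with defaults targeting the saiga2 7B GGUF model.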
SYSTEM_PROMPT = os.environ.get('SYSTEM_PROMPT', default="Ты — русскоязычный автоматический ассистент. Ты максимально точно отвечаешь на запросы пользователя, используя русский язык.")
CONTEXT_SIZE = int(os.environ.get('CONTEXT_SIZE', default='500'))
HF_CACHE_DIR = os.environ.get('HF_CACHE_DIR', default='/home/user/app/.cache')
USE_SYSTEM_PROMPT = os.environ.get('USE_SYSTEM_PROMPT', default='False').lower() == 'true'
ENABLE_GPU = os.environ.get('ENABLE_GPU', default='False').lower() == 'true'
GPU_LAYERS = int(os.environ.get('GPU_LAYERS', default='0'))
CHAT_FORMAT = os.environ.get('CHAT_FORMAT', default='llama-2')
REPO_NAME = os.environ.get('REPO_NAME', default='IlyaGusev/saiga2_7b_gguf')
MODEL_NAME = os.environ.get('MODEL_NAME', default='model-q4_K.gguf')
DATASET_REPO_URL = os.environ.get('DATASET_REPO_URL', default="https://huggingface.co/datasets/muryshev/saiga-chat")
DATA_FILENAME = os.environ.get('DATA_FILENAME', default="data-saiga-cuda-release.xml")
HF_TOKEN = os.environ.get("HF_TOKEN")
APP_HOST = os.environ.get('APP_HOST', default='0.0.0.0')
APP_PORT = int(os.environ.get('APP_PORT', default='7860'))
FLASK_THREADED = os.environ.get('FLASK_THREADED', default='False').lower() == "true"

# Create a lock object
lock = threading.Lock()

app = Flask('llm_api')
app.logger.handlers.clear()
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
app.logger.addHandler(handler)
app.logger.setLevel(logging.DEBUG)

# Variable to store the last request time
last_request_time = datetime.now()

# Initialize the model when the application starts
#model_path = "../models/model-q4_K.gguf"  # Replace with the actual model path
#MODEL_NAME = "model/ggml-model-q4_K.gguf"
#REPO_NAME = "IlyaGusev/saiga2_13b_gguf"
#MODEL_NAME = "model-q4_K.gguf"
#REPO_NAME = "IlyaGusev/saiga2_70b_gguf"
#MODEL_NAME = "ggml-model-q4_1.gguf"
local_dir = '.'

if os.path.isdir('/data'):
    app.logger.info('Persistent storage enabled')

model = None

MODEL_PATH = snapshot_download(repo_id=REPO_NAME, allow_patterns=MODEL_NAME, cache_dir=HF_CACHE_DIR) + '/' + MODEL_NAME
app.logger.info('Model path: ' + MODEL_PATH)

DATA_FILE = os.path.join("dataset", DATA_FILENAME)

app.logger.info("hfh: " + huggingface_hub.__version__)
# repo = Repository(
#     local_dir="dataset", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN
# )

# def log(req: str = '', resp: str = ''):
#     if req or resp:
#         element = ET.Element("row", {"time": str(datetime.now())})
#         req_element = ET.SubElement(element, "request")
#         req_element.text = req
#         resp_element = ET.SubElement(element, "response")
#         resp_element.text = resp
#
#         with open(DATA_FILE, "ab+") as xml_file:
#             xml_file.write(ET.tostring(element, encoding="utf-8"))
#
#         commit_url = repo.push_to_hub()
#         app.logger.info(commit_url)
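
# Handler: stop any in-flight generation and reload the model with the context size
# taken from the 'size' query parameter. (Flask route registration is not shown here.)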
def handler_change_context_size():
    global stop_generation, model
    stop_generation = True

    new_size = int(request.args.get('size', CONTEXT_SIZE))
    init_model(context_size=new_size)

    return Response('Size changed', content_type='text/plain')
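
# Handler: ask the backend to stop the current generation.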
def handler_stop_generation():
    global stop_generation
    stop_generation = True
    return Response('Stopped', content_type='text/plain')
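
# Fallback handler for unrecognized requests; logs the JSON payload when one is present.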
def generate_unknown_response():
    app.logger.info('unknown method: ' + request.method)

    try:
        request_payload = request.get_json()
        app.logger.info('payload: ' + json.dumps(request_payload, ensure_ascii=False))
    except Exception as e:
        app.logger.info('payload empty')

    return Response('What do you want?', content_type='text/plain')
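
# Accumulates the streamed tokens so the full response can be logged once generation
# finishes; an empty token (b'') from the backend marks the end of the stream.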
response_tokens = bytearray()
def generate_and_log_tokens(user_request, generator):
    global response_tokens, last_request_time
    for token in llm.generate_tokens(generator):
        if token == b'':  # or (max_new_tokens is not None and i >= max_new_tokens):
            last_request_time = datetime.now()
            # log(json.dumps(user_request), response_tokens.decode("utf-8", errors="ignore"))
            response_tokens = bytearray()
            break
        response_tokens.extend(token)
        yield token
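
# Main chat handler: reads messages and sampling parameters from the JSON body, builds
# a Saiga chat generator via the backend, and streams generated tokens back as plain text.
#
# Example request body (illustrative only; the exact message schema is whatever
# llm_backend.create_chat_generator_for_saiga expects):
# {
#   "messages": [{"role": "user", "content": "Привет!"}],
#   "parameters": {"temperature": 0.2, "max_new_tokens": 256}
# }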
def generate_response():
    app.logger.info('generate_response called')
    data = request.get_json()
    app.logger.info(data)

    messages = data.get("messages", [])
    preprompt = data.get("preprompt", "")
    parameters = data.get("parameters", {})

    # Extract parameters from the request
    p = {
        'temperature': parameters.get("temperature", 0.01),
        'truncate': parameters.get("truncate", 1000),
        'max_new_tokens': parameters.get("max_new_tokens", 1024),
        'top_p': parameters.get("top_p", 0.85),
        'repetition_penalty': parameters.get("repetition_penalty", 1.2),
        'top_k': parameters.get("top_k", 30),
        'return_full_text': parameters.get("return_full_text", False)
    }

    generator = llm.create_chat_generator_for_saiga(messages=messages, parameters=p, use_system_prompt=USE_SYSTEM_PROMPT)
    app.logger.info('Generator created')

    # Use Response to stream tokens
    return Response(generate_and_log_tokens(user_request='1', generator=generator), content_type='text/plain', status=200, direct_passthrough=True)
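
# Load the model using the configured (or overridden) context size and GPU settings.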
def init_model(context_size=CONTEXT_SIZE, enable_gpu=ENABLE_GPU, gpu_layer_number=GPU_LAYERS):
    llm.load_model(model_path=MODEL_PATH, context_size=context_size, enable_gpu=enable_gpu, gpu_layer_number=gpu_layer_number)

# Unload the model if no requests were made in the last 5 minutes
def check_last_request_time():
    global last_request_time
    current_time = datetime.now()
    if (current_time - last_request_time).total_seconds() > 300:  # 5 minutes in seconds
        llm.unload_model()
        app.logger.info(f"Model unloaded at {current_time}")
    else:
        app.logger.info(f"No action needed at {current_time}")
if __name__ == "__main__": | |
init_model() | |
# scheduler = BackgroundScheduler() | |
# scheduler.add_job(check_last_request_time, trigger='interval', minutes=1) | |
# scheduler.start() | |
app.run(host=APP_HOST, port=APP_PORT, debug=False, threaded=FLASK_THREADED) | |