from flask import Flask, request, jsonify, render_template_string
from sentence_transformers import SentenceTransformer, util
import logging
import sys
import signal

# Initialize the Flask application
app = Flask(__name__)

# Configure root logging at INFO level
logging.basicConfig(level=logging.INFO)
# Replace the application's logger with the named "CodeSearchAPI" logger
app.logger = logging.getLogger("CodeSearchAPI")

# Predefined code snippets that make up the search corpus.
# Every entry must be followed by a comma: adjacent string literals without
# one are silently concatenated by Python into a single corpus entry.
CODE_SNIPPETS = [
    "def sort_list(x): return sorted(x)",
    """def count_above_threshold(elements, threshold=0):
    return sum(1 for e in elements if e > threshold)""",
    """def find_min_max(elements):
    return min(elements), max(elements)""",
    """def count_evens(nums):
    return len([n for n in nums if n % 2 == 0])""",
    """def reverse_string(s):
    return s[::-1]""",
    """def is_prime(n):
    if n < 2:
        return False
    for i in range(2, int(n**0.5)+1):
        if n % i == 0:
            return False
    return True""",
    """def factorial(n):
    result = 1
    for i in range(1, n+1):
        result *= i
    return result""",
    """def sum_of_squares(nums):
    return sum(map(lambda x: x**2, nums))""",
]

# Global service-readiness flag; starts False and is flipped to True by the
# model-initialization code once the snippet embeddings have been computed
service_ready = False

# Graceful-shutdown handling
def handle_shutdown(signum, frame):
    """Log that a termination signal arrived, then exit with status 0.

    Args:
        signum: Signal number supplied by the signal machinery (unused).
        frame: Stack frame supplied by the signal machinery (unused).

    Raises:
        SystemExit: Always, with exit code 0.
    """
    app.logger.info("收到终止信号,开始关闭...")
    raise SystemExit(0)


# Install the handler for SIGTERM first, then SIGINT.
for _shutdown_signal in (signal.SIGTERM, signal.SIGINT):
    signal.signal(_shutdown_signal, handle_shutdown)

# Load the embedding model and pre-compute the snippet embeddings.
# Any failure is logged and re-raised, so the module import itself fails.
try:
    app.logger.info("开始加载模型...")
    model = SentenceTransformer(
        "flax-sentence-embeddings/st-codesearch-distilroberta-base",
        cache_folder="/model-cache",
    )
    # Encode every snippet once up front (explicitly on the CPU)
    code_emb = model.encode(
        CODE_SNIPPETS,
        convert_to_tensor=True,
        device="cpu",
    )
    service_ready = True
    app.logger.info("服务初始化完成")
except Exception as init_error:
    app.logger.error("初始化失败: %s", str(init_error))
    raise

# Hugging Face health-check endpoint; it must respond on the root path
@app.route('/')
def hf_health_check():
    """Report service readiness as an HTML page or as JSON.

    Clients whose Accept header admits HTML receive a small status page with
    a hint on how to try the search endpoint; it is returned with the default
    status code regardless of readiness. All other clients receive JSON:
    ``{"status": "ready"}`` with 200 when the service is ready, otherwise
    ``{"status": "initializing"}`` with 503.
    """
    ready = service_ready
    state = "ready" if ready else "initializing"

    # Non-HTML clients get the machine-readable health check
    if not request.accept_mimetypes.accept_html:
        return jsonify({"status": state}), (200 if ready else 503)

    page = """
        <h2>CodeSearch API</h2>
        <p>服务状态:{{ status }}</p>
        <p>你可以在地址栏输入 /search?query=你的查询 来测试接口</p>
        """
    return render_template_string(page, status=state)

# Search API endpoint; supports both GET and POST requests
@app.route('/search', methods=['GET', 'POST'])
def handle_search():
    """Return the stored code snippet that best matches the query.

    The query is read from the ``query`` URL parameter for GET requests and
    from the ``query`` field of a JSON object body for POST requests.

    Returns:
        A Flask JSON response:

        * 200 with ``{"code": <snippet>, "score": <score rounded to 4 places>}``
          for the top-ranked snippet.
        * 400 with an ``error`` field when the query is missing, blank, or not
          a string, or when the POST body is not a JSON object.
        * 503 with an ``error`` field while the service is not ready.
        * 500 with an ``error`` field when encoding or searching fails.
    """
    if not service_ready:
        app.logger.info("服务未就绪")
        return jsonify({"error": "服务正在初始化"}), 503

    # Extract the raw query according to the request method
    if request.method == 'GET':
        raw_query = request.args.get('query', '')
    else:
        # silent=True makes a malformed body or a non-JSON Content-Type yield
        # None instead of raising, so client mistakes are answered with 400
        # below rather than being reported as a server error.
        payload = request.get_json(silent=True)
        raw_query = payload.get('query', '') if isinstance(payload, dict) else ''

    # A non-string query (null, number, list, ...) is treated like a missing one
    query = raw_query.strip() if isinstance(raw_query, str) else ''
    if not query:
        app.logger.info("收到空的查询请求")
        return jsonify({"error": "查询不能为空"}), 400

    # Log the received query
    app.logger.info("收到查询请求: %s", query)

    try:
        # Encode the query and run the semantic search against the corpus
        query_emb = model.encode(query, convert_to_tensor=True, device="cpu")
        hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
        best = hits[0]

        result = {
            "code": CODE_SNIPPETS[best['corpus_id']],
            "score": round(float(best['score']), 4)
        }

        # Log the result being returned
        app.logger.info("返回结果: %s", result)
        return jsonify(result)

    except Exception as exc:
        # Top-level boundary: keep the traceback in the log, hide it from the client
        app.logger.exception("请求处理失败: %s", exc)
        return jsonify({"error": "服务器内部错误"}), 500

if __name__ == "__main__":
    # For local testing; Hugging Face Spaces usually launches the app via gunicorn
    app.run(host='0.0.0.0', port=7860)