|
import hmac
import os
import threading
import time
import uuid
from collections import defaultdict, deque
from datetime import datetime, timedelta

import tiktoken
from detoxify import Detoxify
from flask import Flask, request, jsonify, render_template, send_from_directory
|
|
|
app = Flask(__name__, static_folder='static', template_folder='templates')

# Detoxify multilingual classifier, loaded once at import time and shared by
# every request.
detoxify_model = Detoxify('multilingual')

# Bearer token required by the authenticated /v1/* endpoints. There is
# deliberately no built-in fallback: a placeholder key baked into the source
# would let anyone who has read the code authenticate. When API_KEY is unset,
# authenticated requests are rejected (fail closed).
API_KEY = os.getenv('API_KEY')
if not API_KEY:
    app.logger.warning("API_KEY is not set; all authenticated requests will be rejected.")

# In-memory performance metrics (process-local, lost on restart).
request_times = deque(maxlen=1000)  # durations (seconds) of the most recent 1000 requests
daily_requests = defaultdict(int)   # "YYYY-MM-DD" -> number of completed requests
daily_tokens = defaultdict(int)     # "YYYY-MM-DD" -> number of input tokens processed
concurrent_requests = 0             # requests currently being handled
concurrent_requests_lock = threading.Lock()

# Tokenizer used to count input tokens for usage reporting and metrics.
encoding = tiktoken.get_encoding("cl100k_base")
|
|
|
def count_tokens(text):
    """Return the number of tokens in ``text`` under the module-level tiktoken ``encoding``.

    Special-token markers such as "<|endoftext|>" are counted as ordinary text.
    tiktoken's default (``disallowed_special="all"``) raises ValueError when such
    a marker appears in the input, which would turn user-supplied text into a
    server error.
    """
    return len(encoding.encode(text, disallowed_special=()))
|
|
|
def transform_predictions(prediction_dict, threshold=0.5):
    """Transform one Detoxify prediction into OpenAI-style moderation fields.

    Args:
        prediction_dict: mapping of category name -> score for a single text.
            Categories absent from the mapping are scored 0.0; extra keys are
            ignored.
        threshold: a category is flagged when its score is strictly greater
            than this value (default 0.5).

    Returns:
        A tuple ``(flagged, categories, category_scores,
        category_applied_input_types)`` where ``categories`` maps each category
        to a bool, ``category_scores`` maps it to a float, and
        ``category_applied_input_types`` is ``["text"]`` for any category with a
        positive score and ``[]`` otherwise. ``flagged`` is True if any category
        is flagged.
    """
    category_keys = (
        "toxicity", "severe_toxicity", "obscene", "threat",
        "insult", "identity_attack", "sexual_explicit",
    )

    # float() converts numpy scalars to plain floats so the comparisons below
    # yield Python bools and the result stays JSON-serialisable.
    scores = {key: float(prediction_dict.get(key, 0.0)) for key in category_keys}
    bool_categories = {key: scores[key] > threshold for key in category_keys}
    cat_applied_input_types = {
        key: (["text"] if scores[key] > 0 else []) for key in category_keys
    }
    flagged = any(bool_categories.values())

    return flagged, bool_categories, scores, cat_applied_input_types
|
|
|
def track_request_metrics(start_time, tokens_count):
    """Record one completed request in the in-memory metrics.

    Appends the request's wall-clock duration (seconds) to ``request_times``
    and adds one request and ``tokens_count`` tokens to the counters for the
    current local date ("YYYY-MM-DD").

    Args:
        start_time: ``time.time()`` value captured when the request started.
        tokens_count: number of input tokens the request processed.
    """
    request_time = time.time() - start_time
    today = datetime.now().strftime("%Y-%m-%d")

    # `daily_requests[today] += 1` is a read-modify-write, so concurrent request
    # threads could otherwise lose updates. concurrent_requests_lock is the
    # module's only lock and is held here just for these three updates.
    with concurrent_requests_lock:
        request_times.append(request_time)
        daily_requests[today] += 1
        daily_tokens[today] += tokens_count
|
|
|
def get_performance_metrics():
    """Return a snapshot of the in-memory performance metrics.

    Keys of the returned dict:
        avg_request_time: mean duration (seconds) of the retained requests
            (up to the last 1000), or 0 if there are none.
        requests_per_second: over the last <=100 requests, count divided by
            summed duration, i.e. the reciprocal of mean latency. It does not
            account for requests that ran in parallel.
        concurrent_requests: requests in flight at snapshot time.
        today_requests / today_tokens: counters for the current local date.
        last_7_days: list of {"date", "requests", "tokens"}, newest first.
    """
    # Copy all shared state under the lock so the numbers are mutually
    # consistent and never read while a request thread is updating them.
    with concurrent_requests_lock:
        current_concurrent = concurrent_requests
        durations = list(request_times)
        requests_by_day = dict(daily_requests)
        tokens_by_day = dict(daily_tokens)

    avg_request_time = sum(durations) / len(durations) if durations else 0

    recent = durations[-100:]
    recent_total = sum(recent)
    requests_per_second = len(recent) / recent_total if recent_total > 0 else 0

    # Read the clock once so "today" and the 7-day window agree even when the
    # call straddles midnight.
    now = datetime.now()
    today = now.strftime("%Y-%m-%d")

    last_7_days = []
    for offset in range(7):
        date = (now - timedelta(days=offset)).strftime("%Y-%m-%d")
        last_7_days.append({
            "date": date,
            "requests": requests_by_day.get(date, 0),
            "tokens": tokens_by_day.get(date, 0),
        })

    return {
        "avg_request_time": avg_request_time,
        "requests_per_second": requests_per_second,
        "concurrent_requests": current_concurrent,
        "today_requests": requests_by_day.get(today, 0),
        "today_tokens": tokens_by_day.get(today, 0),
        "last_7_days": last_7_days,
    }
|
|
|
@app.route('/')
def home():
    """Serve the landing page rendered from the 'index.html' template."""
    page = 'index.html'
    return render_template(page)
|
|
|
def _is_authorized():
    """Return True if the current request carries ``Authorization: Bearer <API_KEY>``.

    The key is compared with hmac.compare_digest so response timing does not
    reveal how much of a guessed key matched. If API_KEY is empty or unset,
    every request is rejected.
    """
    auth_header = request.headers.get('Authorization') or ''
    prefix = "Bearer "
    if not API_KEY or not auth_header.startswith(prefix):
        return False
    provided_api_key = auth_header[len(prefix):]
    # Compare bytes: compare_digest rejects non-ASCII str arguments.
    return hmac.compare_digest(provided_api_key.encode('utf-8'), API_KEY.encode('utf-8'))


@app.route('/v1/moderations', methods=['POST'])
def moderations():
    """OpenAI-compatible moderation endpoint backed by Detoxify.

    Expects a JSON object whose ``input`` field (or, as a fallback, ``texts``)
    is a string or a list of at most 10 strings of up to 100k characters each.
    Responds with one result per input string plus a token-usage count.

    Returns 401 for a missing or invalid bearer token and 400 for a malformed
    body or input.
    """
    global concurrent_requests

    with concurrent_requests_lock:
        concurrent_requests += 1

    start_time = time.time()
    total_tokens = 0

    try:
        if not _is_authorized():
            return jsonify({"error": "Unauthorized"}), 401

        # silent=True returns None for a missing or invalid JSON body instead of
        # raising, so it is answered with a 400 here. Without it, a non-object
        # body would crash with AttributeError on data.get().
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "Invalid request body, expected a JSON object"}), 400

        # Fall back to 'texts' only when 'input' is absent or null. A truthiness
        # test would also discard a present-but-empty 'input'.
        raw_input = data.get('input')
        if raw_input is None:
            raw_input = data.get('texts')
        if raw_input is None:
            return jsonify({"error": "Invalid input, expected 'input' or 'texts' field"}), 400

        if isinstance(raw_input, str):
            texts = [raw_input]
        elif isinstance(raw_input, list):
            texts = raw_input
        else:
            return jsonify({"error": "Invalid input format, expected string or list of strings"}), 400

        if len(texts) > 10:
            return jsonify({"error": "Too many input items. Maximum 10 allowed."}), 400

        # Validate every item (and count tokens) before running the model.
        for text in texts:
            if not isinstance(text, str) or len(text) > 100000:
                return jsonify({"error": "Each input item must be a string with a maximum of 100k characters."}), 400
            total_tokens += count_tokens(text)

        results = []
        for text in texts:
            pred = detoxify_model.predict([text])
            # A one-element batch was passed, so [0] selects this text's score.
            prediction = {k: v[0] for k, v in pred.items()}
            flagged, bool_categories, scores, cat_applied_input_types = transform_predictions(prediction)

            results.append({
                "flagged": flagged,
                "categories": bool_categories,
                "category_scores": scores,
                "category_applied_input_types": cat_applied_input_types,
            })

        track_request_metrics(start_time, total_tokens)

        response_data = {
            "id": "modr-" + uuid.uuid4().hex[:24],
            "model": "unitaryai/detoxify-multilingual",
            "results": results,
            "object": "moderation",
            "usage": {
                "total_tokens": total_tokens,
            },
        }

        return jsonify(response_data)

    finally:
        # Runs on every exit path, including early returns and exceptions.
        with concurrent_requests_lock:
            concurrent_requests -= 1


@app.route('/v1/metrics', methods=['GET'])
def metrics():
    """Return performance metrics as JSON (same bearer-token auth as /v1/moderations)."""
    if not _is_authorized():
        return jsonify({"error": "Unauthorized"}), 401

    return jsonify(get_performance_metrics())
|
|
|
if __name__ == '__main__':
    # Make sure the folders the app was configured with exist before serving.
    os.makedirs('templates', exist_ok=True)
    os.makedirs('static', exist_ok=True)

    port = int(os.getenv('PORT', 7860))

    # Debug mode enables the Werkzeug interactive debugger. Bound to 0.0.0.0,
    # that lets anyone who can reach the port execute arbitrary Python, so it
    # is opt-in only (FLASK_DEBUG=1/true/yes), intended for local development.
    debug = os.getenv('FLASK_DEBUG', '').lower() in ('1', 'true', 'yes')
    app.run(host='0.0.0.0', port=port, debug=debug)