from flask import Flask, request, jsonify, render_template
import os
import uuid
import time
import threading
import tiktoken
from datetime import datetime, timedelta
from collections import defaultdict, deque
from detoxify import Detoxify

app = Flask(__name__, static_folder='static', template_folder='templates')

# Load the detoxify model once at startup
detoxify_model = Detoxify('multilingual')

# API key from environment variable
API_KEY = os.getenv('API_KEY', 'your-api-key-here')

# Performance metrics tracking
request_times = deque(maxlen=1000)  # Store last 1000 request durations (seconds)
daily_requests = defaultdict(int)
daily_tokens = defaultdict(int)
concurrent_requests = 0
concurrent_requests_lock = threading.Lock()

# Token encoding for counting tokens
encoding = tiktoken.get_encoding("cl100k_base")
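# Note: cl100k_base is the tokenizer used by OpenAI's gpt-3.5/gpt-4 family,
# not by detoxify itself, so the "usage" numbers reported below approximate
# what an OpenAI-compatible client would expect rather than the model's own
# token accounting.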


def count_tokens(text):
    """Count tokens in the given text using tiktoken."""
    return len(encoding.encode(text))
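
# Illustrative (exact counts depend on the tiktoken version):
#   count_tokens("hello world")  ->  2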


def transform_predictions(prediction_dict):
    """
    Transform detoxify predictions to an OpenAI-style moderation result.

    Returns: flagged, categories, category_scores, category_applied_input_types
    """
    # Define the categories we want to track
    category_keys = [
        "toxicity", "severe_toxicity", "obscene", "threat",
        "insult", "identity_attack", "sexual_explicit",
    ]

    # Map detoxify output to our categories, defaulting to 0.0 when a key
    # is missing from the model output
    scores = {key: prediction_dict.get(key, 0.0) for key in category_keys}

    # Flag a category when its score crosses the threshold
    threshold = 0.5
    bool_categories = {key: (scores[key] > threshold) for key in category_keys}

    # This API only accepts text, so any nonzero score applies to "text"
    cat_applied_input_types = {key: (["text"] if scores[key] > 0 else []) for key in category_keys}

    flagged = any(bool_categories.values())
    return flagged, bool_categories, scores, cat_applied_input_types
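
# Illustrative example (made-up scores, not real model output):
#   transform_predictions({"toxicity": 0.91, "insult": 0.08})
# returns flagged=True, categories["toxicity"]=True, every other category
# False (missing keys are scored 0.0), and applied input types of ["text"]
# for the two nonzero categories.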


def track_request_metrics(start_time, tokens_count):
    """Track performance metrics for a completed request."""
    end_time = time.time()
    request_time = end_time - start_time
    request_times.append(request_time)

    # The daily counters are updated without a lock; under CPython this can
    # very occasionally drop an increment, which is acceptable for rough
    # monitoring
    today = datetime.now().strftime("%Y-%m-%d")
    daily_requests[today] += 1
    daily_tokens[today] += tokens_count


def get_performance_metrics():
    """Get current performance metrics."""
    with concurrent_requests_lock:
        current_concurrent = concurrent_requests

    # Average request duration across the sliding window
    avg_request_time = sum(request_times) / len(request_times) if request_times else 0

    # Approximate throughput from the last 100 request durations; this
    # measures processing rate, not wall-clock arrival rate
    recent_requests = list(request_times)[-100:]
    total_recent = sum(recent_requests)
    requests_per_second = len(recent_requests) / total_recent if total_recent > 0 else 0

    # Today's stats (use .get to avoid inserting keys into the defaultdicts)
    today = datetime.now().strftime("%Y-%m-%d")
    today_requests = daily_requests.get(today, 0)
    today_tokens = daily_tokens.get(today, 0)

    # Stats for the last 7 days, most recent first
    last_7_days = []
    for i in range(7):
        date = (datetime.now() - timedelta(days=i)).strftime("%Y-%m-%d")
        last_7_days.append({
            "date": date,
            "requests": daily_requests.get(date, 0),
            "tokens": daily_tokens.get(date, 0),
        })

    return {
        "avg_request_time": avg_request_time,
        "requests_per_second": requests_per_second,
        "concurrent_requests": current_concurrent,
        "today_requests": today_requests,
        "today_tokens": today_tokens,
        "last_7_days": last_7_days,
    }
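
# Illustrative shape of the metrics payload (all values made up):
#   {"avg_request_time": 0.042, "requests_per_second": 23.8,
#    "concurrent_requests": 1, "today_requests": 310, "today_tokens": 5120,
#    "last_7_days": [{"date": "2024-01-07", "requests": 310, "tokens": 5120}, ...]}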


@app.route('/')
def home():
    return render_template('index.html')


@app.route('/v1/moderations', methods=['POST'])
def moderations():
    global concurrent_requests

    # Track concurrent requests
    with concurrent_requests_lock:
        concurrent_requests += 1
    start_time = time.time()
    total_tokens = 0

    try:
        # Check authorization
        auth_header = request.headers.get('Authorization')
        if not auth_header or not auth_header.startswith("Bearer "):
            return jsonify({"error": "Unauthorized"}), 401
        provided_api_key = auth_header.split(" ", 1)[1]
        if provided_api_key != API_KEY:
            return jsonify({"error": "Unauthorized"}), 401

        # Get input data; silent=True returns None instead of raising on a
        # malformed or missing JSON body, so we can answer with a clean 400
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "Invalid input, expected a JSON object body"}), 400

        # Check 'input' and 'texts' explicitly against None so that falsy
        # values such as an empty string are still accepted as input
        raw_input = data.get('input')
        if raw_input is None:
            raw_input = data.get('texts')
        if raw_input is None:
            return jsonify({"error": "Invalid input, expected 'input' or 'texts' field"}), 400

        # Handle both string and list inputs
        if isinstance(raw_input, str):
            texts = [raw_input]
        elif isinstance(raw_input, list):
            texts = raw_input
        else:
            return jsonify({"error": "Invalid input format, expected string or list of strings"}), 400

        # Validate input size
        if len(texts) > 10:
            return jsonify({"error": "Too many input items. Maximum 10 allowed."}), 400
        for text in texts:
            if not isinstance(text, str) or len(text) > 100000:
                return jsonify({"error": "Each input item must be a string with a maximum of 100k characters."}), 400
            total_tokens += count_tokens(text)

        # Score each text; detoxify returns a dict of per-category score
        # lists, one entry per input, so take element 0 and cast to float
        # in case the backend hands back numpy scalars that jsonify cannot
        # serialize
        results = []
        for text in texts:
            pred = detoxify_model.predict([text])
            prediction = {k: float(v[0]) for k, v in pred.items()}
            flagged, bool_categories, scores, cat_applied_input_types = transform_predictions(prediction)
            results.append({
                "flagged": flagged,
                "categories": bool_categories,
                "category_scores": scores,
                "category_applied_input_types": cat_applied_input_types,
            })

        # Track metrics
        track_request_metrics(start_time, total_tokens)

        # Prepare an OpenAI-compatible response
        response_data = {
            "id": "modr-" + uuid.uuid4().hex[:24],
            "model": "unitaryai/detoxify-multilingual",
            "results": results,
            "object": "moderation",
            "usage": {
                "total_tokens": total_tokens
            }
        }
        return jsonify(response_data)
    finally:
        # Decrement the concurrent requests counter even on error paths
        with concurrent_requests_lock:
            concurrent_requests -= 1
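
# Example call (illustrative; assumes the server runs locally on the default
# port 7860 and API_KEY is set to "secret"):
#
#   curl http://localhost:7860/v1/moderations \
#     -H "Authorization: Bearer secret" \
#     -H "Content-Type: application/json" \
#     -d '{"input": "some text to classify"}'
#
# The response mirrors OpenAI's moderation schema: an "id", the model name,
# and one entry in "results" per input item.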


@app.route('/v1/metrics', methods=['GET'])
def metrics():
    """Endpoint to get performance metrics."""
    # Same bearer-token check as the moderation endpoint
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith("Bearer "):
        return jsonify({"error": "Unauthorized"}), 401
    provided_api_key = auth_header.split(" ", 1)[1]
    if provided_api_key != API_KEY:
        return jsonify({"error": "Unauthorized"}), 401
    return jsonify(get_performance_metrics())
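
# Example call (illustrative, same assumptions as above):
#   curl http://localhost:7860/v1/metrics -H "Authorization: Bearer secret"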


if __name__ == '__main__':
    # Create directories if they don't exist
    os.makedirs('templates', exist_ok=True)
    os.makedirs('static', exist_ok=True)
    port = int(os.getenv('PORT', 7860))
    # debug=True is convenient for local development but should be disabled
    # when running behind a production WSGI server
    app.run(host='0.0.0.0', port=port, debug=True)