from flask import Flask, request, jsonify, render_template, send_from_directory
import os
import uuid
import time
import threading
import tiktoken
from datetime import datetime, timedelta
from collections import defaultdict, deque
from detoxify import Detoxify
app = Flask(__name__, static_folder='static', template_folder='templates')
# Load the detoxify model
detoxify_model = Detoxify('multilingual')
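# Note: loading the 'multilingual' Detoxify model typically fetches the pretrained
# checkpoint on first use, so the first startup may be slow and needs network access.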
# API key from environment variable
API_KEY = os.getenv('API_KEY', 'your-api-key-here')
# Performance metrics tracking
request_times = deque(maxlen=1000) # Store last 1000 request times
daily_requests = defaultdict(int)
daily_tokens = defaultdict(int)
concurrent_requests = 0
concurrent_requests_lock = threading.Lock()
# Token encoding for counting tokens
encoding = tiktoken.get_encoding("cl100k_base")
def count_tokens(text):
    """Count tokens in the given text using tiktoken."""
    return len(encoding.encode(text))
def transform_predictions(prediction_dict):
    """
    Transform predictions to OpenAI-style format.
    Returns: flagged, categories, category_scores, category_applied_input_types
    """
    # Define the categories we want to track
    category_keys = [
        "toxicity", "severe_toxicity", "obscene", "threat",
        "insult", "identity_attack", "sexual_explicit"
    ]
    # Map detoxify output to our categories
    scores = {}
    for key in category_keys:
        scores[key] = prediction_dict.get(key, 0.0)
    # Set threshold for flagging content
    threshold = 0.5
    bool_categories = {key: (scores[key] > threshold) for key in category_keys}
    cat_applied_input_types = {key: (["text"] if scores[key] > 0 else []) for key in category_keys}
    flagged = any(bool_categories.values())
    return flagged, bool_categories, scores, cat_applied_input_types
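# Illustrative sketch with made-up scores: for a prediction like
# {"toxicity": 0.92, "insult": 0.12, "threat": 0.01, ...}, transform_predictions()
# returns flagged=True, categories={"toxicity": True, "insult": False, ...}
# (only scores above the 0.5 threshold flag the text), and
# category_applied_input_types maps every nonzero-scoring category to ["text"].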
def track_request_metrics(start_time, tokens_count):
    """Track performance metrics for requests."""
    end_time = time.time()
    request_time = end_time - start_time
    request_times.append(request_time)
    today = datetime.now().strftime("%Y-%m-%d")
    daily_requests[today] += 1
    daily_tokens[today] += tokens_count
def get_performance_metrics():
    """Get current performance metrics."""
    global concurrent_requests
    with concurrent_requests_lock:
        current_concurrent = concurrent_requests
    # Calculate average request time
    avg_request_time = sum(request_times) / len(request_times) if request_times else 0
    # Get today's date
    today = datetime.now().strftime("%Y-%m-%d")
    # Calculate requests per second (based on the last 100 requests)
    recent_requests = list(request_times)[-100:] if len(request_times) >= 100 else list(request_times)
    requests_per_second = len(recent_requests) / sum(recent_requests) if recent_requests and sum(recent_requests) > 0 else 0
    # Get daily stats
    today_requests = daily_requests.get(today, 0)
    today_tokens = daily_tokens.get(today, 0)
    # Get last 7 days stats
    last_7_days = []
    for i in range(7):
        date = (datetime.now() - timedelta(days=i)).strftime("%Y-%m-%d")
        last_7_days.append({
            "date": date,
            "requests": daily_requests.get(date, 0),
            "tokens": daily_tokens.get(date, 0)
        })
    return {
        "avg_request_time": avg_request_time,
        "requests_per_second": requests_per_second,
        "concurrent_requests": current_concurrent,
        "today_requests": today_requests,
        "today_tokens": today_tokens,
        "last_7_days": last_7_days
    }
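# Example metrics payload (illustrative values only, not real measurements):
# {
#   "avg_request_time": 0.21, "requests_per_second": 4.7, "concurrent_requests": 1,
#   "today_requests": 132, "today_tokens": 5891,
#   "last_7_days": [{"date": "2024-01-01", "requests": 132, "tokens": 5891}, ...]
# }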
@app.route('/')
def home():
    return render_template('index.html')
@app.route('/v1/moderations', methods=['POST'])
def moderations():
    global concurrent_requests
    # Track concurrent requests
    with concurrent_requests_lock:
        concurrent_requests += 1
    start_time = time.time()
    total_tokens = 0
    try:
        # Check authorization
        auth_header = request.headers.get('Authorization')
        if not auth_header or not auth_header.startswith("Bearer "):
            return jsonify({"error": "Unauthorized"}), 401
        provided_api_key = auth_header.split(" ")[1]
        if provided_api_key != API_KEY:
            return jsonify({"error": "Unauthorized"}), 401
        # Get input data (silent=True returns None instead of raising on a non-JSON body)
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "Invalid input, expected a JSON object body"}), 400
        raw_input = data.get('input') or data.get('texts')
        if raw_input is None:
            return jsonify({"error": "Invalid input, expected 'input' or 'texts' field"}), 400
        # Handle both string and list inputs
        if isinstance(raw_input, str):
            texts = [raw_input]
        elif isinstance(raw_input, list):
            texts = raw_input
        else:
            return jsonify({"error": "Invalid input format, expected string or list of strings"}), 400
        # Validate input size
        if len(texts) > 10:
            return jsonify({"error": "Too many input items. Maximum 10 allowed."}), 400
        for text in texts:
            if not isinstance(text, str) or len(text) > 100000:
                return jsonify({"error": "Each input item must be a string with a maximum of 100k characters."}), 400
            total_tokens += count_tokens(text)
        # Process each text
        results = []
        for text in texts:
            pred = detoxify_model.predict([text])
            prediction = {k: v[0] for k, v in pred.items()}
            flagged, bool_categories, scores, cat_applied_input_types = transform_predictions(prediction)
            results.append({
                "flagged": flagged,
                "categories": bool_categories,
                "category_scores": scores,
                "category_applied_input_types": cat_applied_input_types
            })
        # Track metrics
        track_request_metrics(start_time, total_tokens)
        # Prepare response
        response_data = {
            "id": "modr-" + uuid.uuid4().hex[:24],
            "model": "unitaryai/detoxify-multilingual",
            "results": results,
            "object": "moderation",
            "usage": {
                "total_tokens": total_tokens
            }
        }
        return jsonify(response_data)
    finally:
        # Decrement concurrent requests counter
        with concurrent_requests_lock:
            concurrent_requests -= 1
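# Example request against this endpoint (sketch; substitute your own host, port and API key):
#   curl -X POST http://localhost:7860/v1/moderations \
#     -H "Authorization: Bearer your-api-key-here" \
#     -H "Content-Type: application/json" \
#     -d '{"input": "some text to moderate"}'
# The JSON response follows the OpenAI-style moderation shape built above: "id",
# "model", "object", a per-text "results" list, and "usage.total_tokens".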
@app.route('/v1/metrics', methods=['GET'])
def metrics():
    """Endpoint to get performance metrics."""
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith("Bearer "):
        return jsonify({"error": "Unauthorized"}), 401
    provided_api_key = auth_header.split(" ")[1]
    if provided_api_key != API_KEY:
        return jsonify({"error": "Unauthorized"}), 401
    return jsonify(get_performance_metrics())
if __name__ == '__main__':
    # Create directories if they don't exist
    os.makedirs('templates', exist_ok=True)
    os.makedirs('static', exist_ok=True)
    port = int(os.getenv('PORT', 7860))
    app.run(host='0.0.0.0', port=port, debug=True)
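# To run locally (sketch): export an API key, then start the server, e.g.
#   API_KEY=my-secret PORT=7860 python app.py
# and call /v1/moderations and /v1/metrics with "Authorization: Bearer my-secret".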