from flask import Flask, request, jsonify, render_template
import os
import uuid
import time
import threading
import tiktoken
from datetime import datetime, timedelta
from collections import defaultdict, deque
from detoxify import Detoxify

app = Flask(__name__, static_folder='static', template_folder='templates')

# Load the detoxify model
detoxify_model = Detoxify('multilingual')

# API key from environment variable
API_KEY = os.getenv('API_KEY', 'your-api-key-here')

# Performance metrics tracking
request_times = deque(maxlen=1000)  # Store last 1000 request times
daily_requests = defaultdict(int)
daily_tokens = defaultdict(int)
concurrent_requests = 0
concurrent_requests_lock = threading.Lock()

# Token encoding for counting tokens
encoding = tiktoken.get_encoding("cl100k_base")


def count_tokens(text):
    """Count tokens in the given text using tiktoken."""
    return len(encoding.encode(text))


def transform_predictions(prediction_dict):
    """
    Transform predictions to OpenAI-style format.

    Returns: flagged, categories, category_scores, category_applied_input_types
    """
    # Define the categories we want to track
    category_keys = [
        "toxicity", "severe_toxicity", "obscene", "threat",
        "insult", "identity_attack", "sexual_explicit"
    ]

    # Map detoxify output to our categories; missing keys default to 0.0
    scores = {key: prediction_dict.get(key, 0.0) for key in category_keys}

    # Set threshold for flagging content
    threshold = 0.5
    bool_categories = {key: (scores[key] > threshold) for key in category_keys}
    cat_applied_input_types = {key: (["text"] if scores[key] > 0 else []) for key in category_keys}
    flagged = any(bool_categories.values())
    return flagged, bool_categories, scores, cat_applied_input_types
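

# A quick sketch of what transform_predictions returns for a hypothetical
# Detoxify score dict (the numbers below are made up for illustration;
# missing categories default to 0.0, and any score above the 0.5 threshold
# flags the text):
#
#   transform_predictions({"toxicity": 0.91, "insult": 0.63, "threat": 0.02})
#   -> (True,
#       {"toxicity": True, "insult": True, "threat": False, ...},
#       {"toxicity": 0.91, "insult": 0.63, "threat": 0.02, "obscene": 0.0, ...},
#       {"toxicity": ["text"], "insult": ["text"], "threat": ["text"], ...})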


def track_request_metrics(start_time, tokens_count):
    """Track performance metrics for requests."""
    end_time = time.time()
    request_time = end_time - start_time
    request_times.append(request_time)
    today = datetime.now().strftime("%Y-%m-%d")
    daily_requests[today] += 1
    daily_tokens[today] += tokens_count


def get_performance_metrics():
    """Get current performance metrics."""
    global concurrent_requests
    with concurrent_requests_lock:
        current_concurrent = concurrent_requests

    # Calculate average request time
    avg_request_time = sum(request_times) / len(request_times) if request_times else 0

    # Get today's date
    today = datetime.now().strftime("%Y-%m-%d")

    # Approximate throughput from the processing time of the last 100 requests
    recent_requests = list(request_times)[-100:]
    total_recent_time = sum(recent_requests)
    requests_per_second = len(recent_requests) / total_recent_time if total_recent_time > 0 else 0

    # Get daily stats
    today_requests = daily_requests.get(today, 0)
    today_tokens = daily_tokens.get(today, 0)

    # Get last 7 days stats
    last_7_days = []
    for i in range(7):
        date = (datetime.now() - timedelta(days=i)).strftime("%Y-%m-%d")
        last_7_days.append({
            "date": date,
            "requests": daily_requests.get(date, 0),
            "tokens": daily_tokens.get(date, 0)
        })

    return {
        "avg_request_time": avg_request_time,
        "requests_per_second": requests_per_second,
        "concurrent_requests": current_concurrent,
        "today_requests": today_requests,
        "today_tokens": today_tokens,
        "last_7_days": last_7_days
    }


@app.route('/')
def home():
    return render_template('index.html')


@app.route('/v1/moderations', methods=['POST'])
def moderations():
    global concurrent_requests

    # Track concurrent requests
    with concurrent_requests_lock:
        concurrent_requests += 1

    start_time = time.time()
    total_tokens = 0

    try:
        # Check authorization
        auth_header = request.headers.get('Authorization')
        if not auth_header or not auth_header.startswith("Bearer "):
            return jsonify({"error": "Unauthorized"}), 401
        provided_api_key = auth_header.split(" ")[1]
        if provided_api_key != API_KEY:
            return jsonify({"error": "Unauthorized"}), 401

        # Get input data (silent=True returns None instead of raising on a bad body)
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "Request body must be a JSON object"}), 400
        raw_input = data.get('input')
        if raw_input is None:
            raw_input = data.get('texts')
        if raw_input is None:
            return jsonify({"error": "Invalid input, expected 'input' or 'texts' field"}), 400

        # Handle both string and list inputs
        if isinstance(raw_input, str):
            texts = [raw_input]
        elif isinstance(raw_input, list):
            texts = raw_input
        else:
            return jsonify({"error": "Invalid input format, expected string or list of strings"}), 400

        # Validate input size
        if len(texts) > 10:
            return jsonify({"error": "Too many input items. Maximum 10 allowed."}), 400
        for text in texts:
            if not isinstance(text, str) or len(text) > 100000:
                return jsonify({"error": "Each input item must be a string with a maximum of 100k characters."}), 400
            total_tokens += count_tokens(text)

        # Process each text
        results = []
        for text in texts:
            pred = detoxify_model.predict([text])
            prediction = {k: v[0] for k, v in pred.items()}
            flagged, bool_categories, scores, cat_applied_input_types = transform_predictions(prediction)
            results.append({
                "flagged": flagged,
                "categories": bool_categories,
                "category_scores": scores,
                "category_applied_input_types": cat_applied_input_types
            })

        # Track metrics
        track_request_metrics(start_time, total_tokens)

        # Prepare response
        response_data = {
            "id": "modr-" + uuid.uuid4().hex[:24],
            "model": "unitaryai/detoxify-multilingual",
            "results": results,
            "object": "moderation",
            "usage": {
                "total_tokens": total_tokens
            }
        }
        return jsonify(response_data)
    finally:
        # Decrement concurrent requests counter
        with concurrent_requests_lock:
            concurrent_requests -= 1


@app.route('/v1/metrics', methods=['GET'])
def metrics():
    """Endpoint to get performance metrics."""
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith("Bearer "):
        return jsonify({"error": "Unauthorized"}), 401
    provided_api_key = auth_header.split(" ")[1]
    if provided_api_key != API_KEY:
        return jsonify({"error": "Unauthorized"}), 401
    return jsonify(get_performance_metrics())


if __name__ == '__main__':
    # Create directories if they don't exist
    os.makedirs('templates', exist_ok=True)
    os.makedirs('static', exist_ok=True)
    port = int(os.getenv('PORT', 7860))
    app.run(host='0.0.0.0', port=port, debug=True)
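

# Example usage (a sketch; host, port, and key depend on your deployment):
#
#   curl -s http://localhost:7860/v1/moderations \
#     -H "Authorization: Bearer $API_KEY" \
#     -H "Content-Type: application/json" \
#     -d '{"input": "text to classify"}'
#
# which returns an OpenAI-style moderation object:
#
#   {"id": "modr-...", "object": "moderation",
#    "model": "unitaryai/detoxify-multilingual",
#    "results": [{"flagged": false, "categories": {...},
#                 "category_scores": {...},
#                 "category_applied_input_types": {...}}],
#    "usage": {"total_tokens": ...}}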