File size: 7,191 Bytes
038cbb8
2ec7d5b
ebef9a1
038cbb8
 
 
 
 
c7133b5
2ec7d5b
038cbb8
2ec7d5b
038cbb8
176df74
2ec7d5b
038cbb8
 
2ec7d5b
038cbb8
 
 
 
 
 
 
 
 
 
 
 
 
2ec7d5b
038cbb8
ebef9a1
038cbb8
 
ebef9a1
038cbb8
ebef9a1
038cbb8
 
ebef9a1
038cbb8
 
ebef9a1
038cbb8
 
ebef9a1
038cbb8
fd627d4
ebef9a1
 
 
 
 
 
038cbb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7133b5
 
038cbb8
c7133b5
176df74
 
038cbb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176df74
038cbb8
 
 
 
 
 
176df74
 
 
038cbb8
 
ebef9a1
 
 
 
 
 
038cbb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176df74
038cbb8
 
 
 
 
 
 
 
 
 
 
 
176df74
2ec7d5b
038cbb8
 
 
 
17d94c7
038cbb8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import hmac
import os
import threading
import time
import uuid
from collections import defaultdict, deque
from datetime import datetime, timedelta

import tiktoken
from detoxify import Detoxify
from flask import Flask, request, jsonify, render_template, send_from_directory

app = Flask(__name__, static_folder='static', template_folder='templates')

# Load the Detoxify 'multilingual' model once at import time; this single
# instance is used by every /v1/moderations request handled by this process.
detoxify_model = Detoxify('multilingual')

# API key from environment variable.
# NOTE(review): when API_KEY is unset this silently falls back to the public
# placeholder 'your-api-key-here', so any client sending that value is
# authorized. Consider failing fast at startup when the variable is missing.
API_KEY = os.getenv('API_KEY', 'your-api-key-here')

# Performance metrics tracking (in-memory, per process; not persisted).
request_times = deque(maxlen=1000)  # Durations in seconds of the last 1000 successful moderation requests
daily_requests = defaultdict(int)  # "YYYY-MM-DD" (local date) -> successful request count
daily_tokens = defaultdict(int)  # "YYYY-MM-DD" (local date) -> total input tokens
concurrent_requests = 0  # Moderation requests currently in flight
concurrent_requests_lock = threading.Lock()  # Guards reads/writes of concurrent_requests

# Token encoding (tiktoken cl100k_base) used to report input token usage.
encoding = tiktoken.get_encoding("cl100k_base")

def count_tokens(text):
    """Return the number of tokens *text* encodes to with the module-level tiktoken ``encoding``."""
    token_ids = encoding.encode(text)
    return len(token_ids)

def transform_predictions(prediction_dict, *, threshold=0.5):
    """
    Transform one text's Detoxify scores into OpenAI-moderation-style fields.

    Args:
        prediction_dict: Mapping of category name -> score for a single text.
            Categories absent from the mapping are reported with score 0.0.
        threshold: A category is flagged when its score is strictly greater
            than this value. Defaults to 0.5.

    Returns:
        Tuple ``(flagged, categories, category_scores,
        category_applied_input_types)``:
          - ``flagged``: True if any category is flagged.
          - ``categories``: category -> bool (score > threshold).
          - ``category_scores``: category -> float score.
          - ``category_applied_input_types``: category -> ``["text"]`` when
            the score is positive, otherwise ``[]``.
    """
    # Categories reported in every response, in a fixed order.
    category_keys = (
        "toxicity", "severe_toxicity", "obscene", "threat",
        "insult", "identity_attack", "sexual_explicit",
    )

    # float() normalizes NumPy scalars (e.g. float32, which the json module
    # cannot serialize) into plain Python floats.
    scores = {key: float(prediction_dict.get(key, 0.0)) for key in category_keys}

    # Strictly greater than: a score exactly equal to the threshold is not flagged.
    bool_categories = {key: score > threshold for key, score in scores.items()}
    cat_applied_input_types = {
        key: (["text"] if score > 0 else []) for key, score in scores.items()
    }
    flagged = any(bool_categories.values())

    return flagged, bool_categories, scores, cat_applied_input_types

def track_request_metrics(start_time, tokens_count):
    """
    Record one completed request in the in-memory metrics stores.

    Appends the elapsed seconds since ``start_time`` (a ``time.time()``
    timestamp) to ``request_times``, then increments today's request counter
    by one and today's token counter by ``tokens_count``. Days are keyed by
    the local date formatted as ``YYYY-MM-DD``.
    """
    elapsed = time.time() - start_time
    request_times.append(elapsed)

    # NOTE(review): these read-modify-write updates are not lock-protected;
    # if requests are served on multiple threads, increments could be lost.
    day_key = datetime.now().strftime("%Y-%m-%d")
    daily_requests[day_key] += 1
    daily_tokens[day_key] += tokens_count

def get_performance_metrics():
    """
    Build a snapshot of the in-memory performance metrics.

    Returns:
        dict with keys:
          - ``avg_request_time``: mean duration in seconds over the retained
            ``request_times`` samples, or 0 when there are none.
          - ``requests_per_second``: number of the most recent (up to 100)
            samples divided by their summed duration, i.e. the inverse of
            their mean latency, or 0 when unavailable. This is not measured
            throughput: it ignores requests that overlap in time.
          - ``concurrent_requests``: current in-flight request count.
          - ``today_requests`` / ``today_tokens``: counters for today's local date.
          - ``last_7_days``: list of ``{"date", "requests", "tokens"}`` dicts
            for today and the previous six days, newest first.
    """
    with concurrent_requests_lock:
        current_concurrent = concurrent_requests

    # Snapshot the deque once so every latency statistic below is computed
    # from the same samples, even if other requests append meanwhile.
    samples = list(request_times)
    avg_request_time = sum(samples) / len(samples) if samples else 0

    # Slicing already handles fewer than 100 samples.
    recent = samples[-100:]
    recent_total = sum(recent)
    requests_per_second = len(recent) / recent_total if recent_total > 0 else 0

    # One clock reading, so "today" and the 7-day window agree even if this
    # call straddles midnight.
    now = datetime.now()
    today = now.strftime("%Y-%m-%d")

    dates = [(now - timedelta(days=days_ago)).strftime("%Y-%m-%d") for days_ago in range(7)]
    last_7_days = [
        {
            "date": date,
            "requests": daily_requests.get(date, 0),
            "tokens": daily_tokens.get(date, 0),
        }
        for date in dates
    ]

    return {
        "avg_request_time": avg_request_time,
        "requests_per_second": requests_per_second,
        "concurrent_requests": current_concurrent,
        "today_requests": daily_requests.get(today, 0),
        "today_tokens": daily_tokens.get(today, 0),
        "last_7_days": last_7_days,
    }

@app.route('/')
def home():
    """Serve the landing page rendered from the ``index.html`` template."""
    page = 'index.html'
    return render_template(page)

@app.route('/v1/moderations', methods=['POST'])
def moderations():
    """
    OpenAI-compatible moderation endpoint backed by Detoxify.

    Expects an ``Authorization: Bearer <API_KEY>`` header and a JSON object
    body whose ``input`` (or, if absent, ``texts``) field is a string or a
    non-empty list of at most 10 strings, each at most 100k characters.

    Returns:
        200 with a moderation object (``id``, ``model``, ``results``,
        ``object``, ``usage``) on success; 401 for a missing or wrong API key;
        400 for a malformed body or invalid input.
    """
    global concurrent_requests

    # Count this request as in flight; the decrement in ``finally`` runs on
    # every exit path, including early returns and exceptions.
    with concurrent_requests_lock:
        concurrent_requests += 1

    start_time = time.time()
    total_tokens = 0

    try:
        # Check authorization.
        auth_header = request.headers.get('Authorization')
        if not auth_header or not auth_header.startswith("Bearer "):
            return jsonify({"error": "Unauthorized"}), 401

        provided_api_key = auth_header.split(" ")[1]
        # Constant-time comparison so response timing does not reveal how
        # much of the key matched. Bytes avoid compare_digest's ASCII-only
        # restriction on str arguments.
        if not hmac.compare_digest(provided_api_key.encode("utf-8"),
                                   API_KEY.encode("utf-8")):
            return jsonify({"error": "Unauthorized"}), 401

        # silent=True: a missing or unparsable JSON body yields None instead
        # of raising, so it gets a JSON 400 like other bad input; non-object
        # JSON (list, string, null) would otherwise crash on .get().
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "Invalid request body, expected a JSON object"}), 400

        # Fall back to 'texts' only when 'input' is absent/null, so a valid
        # empty-string input is still moderated.
        raw_input = data.get('input')
        if raw_input is None:
            raw_input = data.get('texts')

        if raw_input is None:
            return jsonify({"error": "Invalid input, expected 'input' or 'texts' field"}), 400

        # Handle both string and list inputs
        if isinstance(raw_input, str):
            texts = [raw_input]
        elif isinstance(raw_input, list):
            texts = raw_input
        else:
            return jsonify({"error": "Invalid input format, expected string or list of strings"}), 400

        # Validate input size
        if not texts:
            return jsonify({"error": "Input list must contain at least one item."}), 400
        if len(texts) > 10:
            return jsonify({"error": "Too many input items. Maximum 10 allowed."}), 400

        for text in texts:
            if not isinstance(text, str) or len(text) > 100000:
                return jsonify({"error": "Each input item must be a string with a maximum of 100k characters."}), 400
            total_tokens += count_tokens(text)

        # Score all texts in one model call. For list input the result maps
        # each category to a sequence holding one score per text, in order.
        batch_scores = detoxify_model.predict(texts)
        class_names = list(batch_scores)

        results = []
        for per_text_scores in zip(*batch_scores.values()):
            prediction = dict(zip(class_names, per_text_scores))
            flagged, bool_categories, scores, cat_applied_input_types = transform_predictions(prediction)

            results.append({
                "flagged": flagged,
                "categories": bool_categories,
                "category_scores": scores,
                "category_applied_input_types": cat_applied_input_types
            })

        # Track metrics (successful requests only)
        track_request_metrics(start_time, total_tokens)

        response_data = {
            "id": "modr-" + uuid.uuid4().hex[:24],
            "model": "unitaryai/detoxify-multilingual",
            "results": results,
            "object": "moderation",
            "usage": {
                "total_tokens": total_tokens
            }
        }

        return jsonify(response_data)

    finally:
        with concurrent_requests_lock:
            concurrent_requests -= 1

@app.route('/v1/metrics', methods=['GET'])
def metrics():
    """
    Return the current performance metrics as JSON.

    Requires an ``Authorization: Bearer <API_KEY>`` header; responds with
    401 and an ``{"error": "Unauthorized"}`` body otherwise.
    """
    auth_header = request.headers.get('Authorization')
    if not auth_header or not auth_header.startswith("Bearer "):
        return jsonify({"error": "Unauthorized"}), 401

    provided_api_key = auth_header.split(" ")[1]
    # Constant-time comparison so response timing does not reveal how much
    # of the key matched. Bytes avoid compare_digest's ASCII-only restriction
    # on str arguments.
    if not hmac.compare_digest(provided_api_key.encode("utf-8"),
                               API_KEY.encode("utf-8")):
        return jsonify({"error": "Unauthorized"}), 401

    return jsonify(get_performance_metrics())

if __name__ == '__main__':
    # Create the template/static directories the app is configured with, if
    # they don't already exist.
    os.makedirs('templates', exist_ok=True)
    os.makedirs('static', exist_ok=True)

    port = int(os.getenv('PORT', 7860))

    # The server binds to every interface, so the interactive Werkzeug
    # debugger (which can execute arbitrary Python code) must not be on by
    # default. Opt in for local development with FLASK_DEBUG; values are
    # interpreted like Flask's own flag ("0", "false", "no" mean off).
    debug_env = os.getenv('FLASK_DEBUG', '')
    debug = bool(debug_env) and debug_env.lower() not in ('0', 'false', 'no')

    app.run(host='0.0.0.0', port=port, debug=debug)