Utiric committed
Commit 038cbb8 · verified · 1 Parent(s): fd627d4

Update app.py

Files changed (1)
  1. app.py +168 -164
app.py CHANGED
@@ -1,124 +1,52 @@
- from flask import Flask, request, jsonify, render_template_string
  import os
  import uuid
- import torch
  from detoxify import Detoxify
- from transformers import AutoModelForSequenceClassification, AutoTokenizer

- app = Flask(__name__)

- # Load the models
  detoxify_model = Detoxify('multilingual')
- koala_model = AutoModelForSequenceClassification.from_pretrained("KoalaAI/Text-Moderation")
- koala_tokenizer = AutoTokenizer.from_pretrained("KoalaAI/Text-Moderation")

- # API key from the environment variable
- API_KEY = os.getenv('API_KEY')

- # Modern, TailwindCSS-powered HTML interface (dark/light)
- HTML_TEMPLATE = '''
- <!DOCTYPE html>
- <html lang="en">
- <head>
- <meta charset="UTF-8">
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
- <title>Modern Moderation API Test</title>
- <script src="https://cdn.tailwindcss.com"></script>
- </head>
- <body class="bg-gray-100 dark:bg-gray-900 text-gray-900 dark:text-gray-100">
- <div class="container mx-auto px-4 py-8">
- <h1 class="text-4xl font-bold mb-6 text-center">Modern Moderation API Test</h1>
- <form id="testForm" class="bg-white dark:bg-gray-800 shadow-md rounded px-8 pt-6 pb-8 mb-4">
- <div class="mb-4">
- <label class="block text-gray-700 dark:text-gray-300 text-sm font-bold mb-2" for="api_key">API Key:</label>
- <input type="text" id="api_key" name="api_key" required class="shadow appearance-none border rounded w-full py-2 px-3 text-gray-700 dark:text-gray-900 leading-tight focus:outline-none focus:shadow-outline">
- </div>
- <div class="mb-4">
- <label class="block text-gray-700 dark:text-gray-300 text-sm font-bold mb-2" for="model">Select Model:</label>
- <select id="model" name="model" class="shadow appearance-none border rounded w-full py-2 px-3 text-gray-700 dark:text-gray-900 leading-tight focus:outline-none focus:shadow-outline">
- <option value="unitaryai/detoxify-multilingual" selected>unitaryai/detoxify-multilingual</option>
- <option value="koalaai/text-moderation">koalaai/text-moderation</option>
- </select>
- </div>
- <div class="mb-4">
- <label class="block text-gray-700 dark:text-gray-300 text-sm font-bold mb-2" for="input">Text to Analyze:</label>
- <textarea id="input" name="input" rows="4" required class="shadow appearance-none border rounded w-full py-2 px-3 text-gray-700 dark:text-gray-900 leading-tight focus:outline-none focus:shadow-outline"></textarea>
- </div>
- <div class="flex items-center justify-between">
- <button type="submit" class="bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-4 rounded focus:outline-none focus:shadow-outline">
- Analyze
- </button>
- </div>
- </form>
- <div id="results" class="mt-6"></div>
- </div>
- <script>
- document.getElementById('testForm').addEventListener('submit', async function(event) {
- event.preventDefault();
- const apiKey = document.getElementById('api_key').value;
- const model = document.getElementById('model').value;
- const input = document.getElementById('input').value;
- try {
- const response = await fetch('/v1/moderations', {
- method: 'POST',
- headers: {
- 'Content-Type': 'application/json',
- 'Authorization': 'Bearer ' + apiKey
- },
- body: JSON.stringify({ model: model, input: input })
- });
- const data = await response.json();
- const resultsDiv = document.getElementById('results');
- if (data.error) {
- resultsDiv.innerHTML = `<p class="text-red-500 font-bold">Error: ${data.error}</p>`;
- } else {
- let html = '<h2 class="text-2xl font-bold mb-4">Results:</h2>';
- data.results.forEach(item => {
- html += `<div class="mb-4 p-4 bg-gray-200 dark:bg-gray-700 rounded">
- <p class="font-semibold">Flagged: ${item.flagged}</p>
- <p class="font-semibold">Categories:</p>
- <ul>`;
- for (const [key, value] of Object.entries(item.categories)) {
- html += `<li>${key}: ${value} (score: ${item.category_scores[key].toFixed(5)})</li>`;
- }
- html += ` </ul>
- </div>`;
- });
- resultsDiv.innerHTML = html;
- }
- } catch (error) {
- console.error('Error:', error);
- }
- });
- </script>
- </body>
- </html>
- '''

- def transform_predictions(model_choice, prediction_dict):
      """
-     Transform the predictions into an OpenAI-style response format.
-     Transformed fields: flagged, categories, category_scores, category_applied_input_types
      """
      category_keys = [
-         "sexual", "sexual/minors", "harassment", "harassment/threatening",
-         "hate", "hate/threatening", "illicit", "illicit/violent",
-         "self-harm", "self-harm/intent", "self-harm/instructions",
-         "violence", "violence/graphic"
      ]
      scores = {}
-     if model_choice == "unitaryai/detoxify-multilingual":
-         scores["sexual"] = prediction_dict.get("sexual_explicit", 0.0)
-         scores["harassment"] = max(prediction_dict.get("identity_attack", 0.0), prediction_dict.get("insult", 0.0))
-         scores["hate"] = prediction_dict.get("toxicity", 0.0)
-         scores["violence"] = max(prediction_dict.get("severe_toxicity", 0.0), prediction_dict.get("threat", 0.0))
-         for key in category_keys:
-             if key not in scores:
-                 scores[key] = 0.0
-     else:
-         for key in category_keys:
-             scores[key] = prediction_dict.get(key, 0.0)

      threshold = 0.5
      bool_categories = {key: (scores[key] > threshold) for key in category_keys}
      cat_applied_input_types = {key: (["text"] if scores[key] > 0 else []) for key in category_keys}
@@ -126,80 +54,156 @@ def transform_predictions(model_choice, prediction_dict):

      return flagged, bool_categories, scores, cat_applied_input_types

  @app.route('/')
  def home():
-     return render_template_string(HTML_TEMPLATE)

  @app.route('/v1/moderations', methods=['POST'])
  def moderations():
-     auth_header = request.headers.get('Authorization')
-     if not auth_header or not auth_header.startswith("Bearer "):
-         return jsonify({"error": "Unauthorized"}), 401
-     provided_api_key = auth_header.split(" ")[1]
-     if provided_api_key != API_KEY:
-         return jsonify({"error": "Unauthorized"}), 401
-
-     data = request.get_json()
-     raw_input = data.get('input') or data.get('texts')
-     if raw_input is None:
-         return jsonify({"error": "Invalid input, expected 'input' or 'texts' field"}), 400
-
-     if isinstance(raw_input, str):
-         texts = [raw_input]
-     elif isinstance(raw_input, list):
-         texts = raw_input
-     else:
-         return jsonify({"error": "Invalid input format, expected string or list of strings"}), 400
-
-     if len(texts) > 10:
-         return jsonify({"error": "Too many input items. Maximum 10 allowed."}), 400
-
-     for text in texts:
-         if not isinstance(text, str) or len(text) > 100000:
-             return jsonify({"error": "Each input item must be a string with a maximum of 100k characters."}), 400
-
-     results = []
-     model_choice = data.get('model', 'unitaryai/detoxify-multilingual')
-
-     if model_choice == "koalaai/text-moderation":
          for text in texts:
-             inputs = koala_tokenizer(text, return_tensors="pt")
-             outputs = koala_model(**inputs)
-             logits = outputs.logits
-             probabilities = torch.softmax(logits, dim=-1).squeeze().tolist()
-             if isinstance(probabilities, float):
-                 probabilities = [probabilities]
-             labels = [koala_model.config.id2label[idx] for idx in range(len(probabilities))]
-             prediction = {label: prob for label, prob in zip(labels, probabilities)}
-             flagged, bool_categories, scores, cat_applied_input_types = transform_predictions(model_choice, prediction)
-             results.append({
-                 "flagged": flagged,
-                 "categories": bool_categories,
-                 "category_scores": scores,
-                 "category_applied_input_types": cat_applied_input_types
-             })
-         response_model = "koalaai/text-moderation"
-     else:
          for text in texts:
              pred = detoxify_model.predict([text])
              prediction = {k: v[0] for k, v in pred.items()}
-             flagged, bool_categories, scores, cat_applied_input_types = transform_predictions(model_choice, prediction)
              results.append({
                  "flagged": flagged,
                  "categories": bool_categories,
                  "category_scores": scores,
                  "category_applied_input_types": cat_applied_input_types
              })
-         response_model = "unitaryai/detoxify-multilingual"

-     response_data = {
-         "id": "modr-" + uuid.uuid4().hex[:24],
-         "model": response_model,
-         "results": results,
-         "object": "moderation"
-     }
-     return jsonify(response_data)

  if __name__ == '__main__':
      port = int(os.getenv('PORT', 7860))
-     app.run(host='0.0.0.0', port=port, debug=True)
+ from flask import Flask, request, jsonify, render_template, send_from_directory
  import os
  import uuid
+ import time
+ import threading
+ import tiktoken
+ from datetime import datetime, timedelta
+ from collections import defaultdict, deque
  from detoxify import Detoxify

+ app = Flask(__name__, static_folder='static', template_folder='templates')

+ # Load the detoxify model
  detoxify_model = Detoxify('multilingual')

+ # API key from environment variable
+ API_KEY = os.getenv('API_KEY', 'your-api-key-here')

+ # Performance metrics tracking
+ request_times = deque(maxlen=1000)  # Store last 1000 request times
+ daily_requests = defaultdict(int)
+ daily_tokens = defaultdict(int)
+ concurrent_requests = 0
+ concurrent_requests_lock = threading.Lock()
+
+ # Token encoding for counting tokens
+ encoding = tiktoken.get_encoding("cl100k_base")
+
+ def count_tokens(text):
+     """Count tokens in the given text using tiktoken."""
+     return len(encoding.encode(text))

+ def transform_predictions(prediction_dict):
      """
+     Transform predictions to OpenAI-style format.
+     Returns: flagged, categories, category_scores, category_applied_input_types
      """
+     # Define the categories we want to track
      category_keys = [
+         "toxicity", "severe_toxicity", "obscene", "threat",
+         "insult", "identity_attack", "sexual_explicit"
      ]
+
+     # Map detoxify output to our categories
      scores = {}
+     for key in category_keys:
+         scores[key] = prediction_dict.get(key, 0.0)

+     # Set threshold for flagging content
      threshold = 0.5
      bool_categories = {key: (scores[key] > threshold) for key in category_keys}
      cat_applied_input_types = {key: (["text"] if scores[key] > 0 else []) for key in category_keys}

      return flagged, bool_categories, scores, cat_applied_input_types

+ def track_request_metrics(start_time, tokens_count):
+     """Track performance metrics for requests."""
+     end_time = time.time()
+     request_time = end_time - start_time
+     request_times.append(request_time)
+
+     today = datetime.now().strftime("%Y-%m-%d")
+     daily_requests[today] += 1
+     daily_tokens[today] += tokens_count
+
+ def get_performance_metrics():
+     """Get current performance metrics."""
+     global concurrent_requests
+     with concurrent_requests_lock:
+         current_concurrent = concurrent_requests
+
+     # Calculate average request time
+     avg_request_time = sum(request_times) / len(request_times) if request_times else 0
+
+     # Get today's date
+     today = datetime.now().strftime("%Y-%m-%d")
+
+     # Calculate requests per second (based on last 100 requests)
+     recent_requests = list(request_times)[-100:] if len(request_times) >= 100 else list(request_times)
+     requests_per_second = len(recent_requests) / sum(recent_requests) if recent_requests and sum(recent_requests) > 0 else 0
+
+     # Get daily stats
+     today_requests = daily_requests.get(today, 0)
+     today_tokens = daily_tokens.get(today, 0)
+
+     # Get last 7 days stats
+     last_7_days = []
+     for i in range(7):
+         date = (datetime.now() - timedelta(days=i)).strftime("%Y-%m-%d")
+         last_7_days.append({
+             "date": date,
+             "requests": daily_requests.get(date, 0),
+             "tokens": daily_tokens.get(date, 0)
+         })
+
+     return {
+         "avg_request_time": avg_request_time,
+         "requests_per_second": requests_per_second,
+         "concurrent_requests": current_concurrent,
+         "today_requests": today_requests,
+         "today_tokens": today_tokens,
+         "last_7_days": last_7_days
+     }
+
  @app.route('/')
  def home():
+     return render_template('index.html')

  @app.route('/v1/moderations', methods=['POST'])
  def moderations():
+     global concurrent_requests
+
+     # Track concurrent requests
+     with concurrent_requests_lock:
+         concurrent_requests += 1
+
+     start_time = time.time()
+     total_tokens = 0
+
+     try:
+         # Check authorization
+         auth_header = request.headers.get('Authorization')
+         if not auth_header or not auth_header.startswith("Bearer "):
+             return jsonify({"error": "Unauthorized"}), 401
+
+         provided_api_key = auth_header.split(" ")[1]
+         if provided_api_key != API_KEY:
+             return jsonify({"error": "Unauthorized"}), 401
+
+         # Get input data
+         data = request.get_json()
+         raw_input = data.get('input') or data.get('texts')
+
+         if raw_input is None:
+             return jsonify({"error": "Invalid input, expected 'input' or 'texts' field"}), 400
+
+         # Handle both string and list inputs
+         if isinstance(raw_input, str):
+             texts = [raw_input]
+         elif isinstance(raw_input, list):
+             texts = raw_input
+         else:
+             return jsonify({"error": "Invalid input format, expected string or list of strings"}), 400
+
+         # Validate input size
+         if len(texts) > 10:
+             return jsonify({"error": "Too many input items. Maximum 10 allowed."}), 400
+
          for text in texts:
+             if not isinstance(text, str) or len(text) > 100000:
+                 return jsonify({"error": "Each input item must be a string with a maximum of 100k characters."}), 400
+             total_tokens += count_tokens(text)
+
+         # Process each text
+         results = []
          for text in texts:
              pred = detoxify_model.predict([text])
              prediction = {k: v[0] for k, v in pred.items()}
+             flagged, bool_categories, scores, cat_applied_input_types = transform_predictions(prediction)
+
              results.append({
                  "flagged": flagged,
                  "categories": bool_categories,
                  "category_scores": scores,
                  "category_applied_input_types": cat_applied_input_types
              })
+
+         # Track metrics
+         track_request_metrics(start_time, total_tokens)
+
+         # Prepare response
+         response_data = {
+             "id": "modr-" + uuid.uuid4().hex[:24],
+             "model": "unitaryai/detoxify-multilingual",
+             "results": results,
+             "object": "moderation",
+             "usage": {
+                 "total_tokens": total_tokens
+             }
+         }
+
+         return jsonify(response_data)
+
+     finally:
+         # Decrement concurrent requests counter
+         with concurrent_requests_lock:
+             concurrent_requests -= 1

+ @app.route('/v1/metrics', methods=['GET'])
+ def metrics():
+     """Endpoint to get performance metrics."""
+     auth_header = request.headers.get('Authorization')
+     if not auth_header or not auth_header.startswith("Bearer "):
+         return jsonify({"error": "Unauthorized"}), 401
+
+     provided_api_key = auth_header.split(" ")[1]
+     if provided_api_key != API_KEY:
+         return jsonify({"error": "Unauthorized"}), 401
+
+     return jsonify(get_performance_metrics())

  if __name__ == '__main__':
+     # Create directories if they don't exist
+     os.makedirs('templates', exist_ok=True)
+     os.makedirs('static', exist_ok=True)
+
      port = int(os.getenv('PORT', 7860))
+     app.run(host='0.0.0.0', port=port, debug=True)
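
For reference, a minimal client sketch against the two endpoints touched by this commit; the base URL and API key below are placeholders for illustration, not values from the repository:

# Example usage of the updated API (sketch, not part of the commit).
# Assumptions: the app is reachable at http://localhost:7860 and the
# API_KEY environment variable was set to "my-secret-key" before startup.
import requests

BASE_URL = "http://localhost:7860"                   # placeholder host/port
HEADERS = {"Authorization": "Bearer my-secret-key"}  # placeholder key

# POST /v1/moderations accepts a string or a list (max 10 items, 100k chars each)
resp = requests.post(
    f"{BASE_URL}/v1/moderations",
    headers=HEADERS,
    json={"input": ["hello there", "some text to screen"]},
)
data = resp.json()
for item in data["results"]:
    # Per-category booleans and scores come from transform_predictions()
    print(item["flagged"], item["category_scores"]["toxicity"])
print("tokens used:", data["usage"]["total_tokens"])

# GET /v1/metrics returns the new performance counters
print(requests.get(f"{BASE_URL}/v1/metrics", headers=HEADERS).json())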