|
import hmac
import os
import uuid

import torch
from detoxify import Detoxify
from flask import Flask, jsonify, render_template_string, request
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
|
|
# Flask application serving the test page and the moderation endpoint.
app = Flask(__name__)

# Both classifiers are created once at import time and shared by every
# request. from_pretrained may fetch weights from the Hugging Face Hub on
# first start.
detoxify_model = Detoxify('multilingual')

_KOALA_CHECKPOINT = "KoalaAI/Text-Moderation"
koala_tokenizer = AutoTokenizer.from_pretrained(_KOALA_CHECKPOINT)
koala_model = AutoModelForSequenceClassification.from_pretrained(_KOALA_CHECKPOINT)

# Shared secret clients must send as "Authorization: Bearer <key>".
# None when the API_KEY environment variable is not set.
API_KEY = os.environ.get('API_KEY')
|
|
|
|
|
# Self-contained browser page (Tailwind via CDN) for exercising the
# /v1/moderations endpoint. The index route renders it through Jinja, so it
# must not contain Jinja delimiters (double curly braces, curly-percent or
# curly-hash).
HTML_TEMPLATE = '''
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Modern Moderation API Test</title>
    <script src="https://cdn.tailwindcss.com"></script>
</head>
<body class="bg-gray-100 dark:bg-gray-900 text-gray-900 dark:text-gray-100">
    <div class="container mx-auto px-4 py-8">
        <h1 class="text-4xl font-bold mb-6 text-center">Modern Moderation API Test</h1>
        <form id="testForm" class="bg-white dark:bg-gray-800 shadow-md rounded px-8 pt-6 pb-8 mb-4">
            <div class="mb-4">
                <label class="block text-gray-700 dark:text-gray-300 text-sm font-bold mb-2" for="api_key">API Key:</label>
                <input type="password" id="api_key" name="api_key" required autocomplete="off" class="shadow appearance-none border rounded w-full py-2 px-3 text-gray-700 dark:text-gray-900 leading-tight focus:outline-none focus:shadow-outline">
            </div>
            <div class="mb-4">
                <label class="block text-gray-700 dark:text-gray-300 text-sm font-bold mb-2" for="model">Select Model:</label>
                <select id="model" name="model" class="shadow appearance-none border rounded w-full py-2 px-3 text-gray-700 dark:text-gray-900 leading-tight focus:outline-none focus:shadow-outline">
                    <option value="unitaryai/detoxify-multilingual" selected>unitaryai/detoxify-multilingual</option>
                    <option value="koalaai/text-moderation">koalaai/text-moderation</option>
                </select>
            </div>
            <div class="mb-4">
                <label class="block text-gray-700 dark:text-gray-300 text-sm font-bold mb-2" for="input">Text to Analyze:</label>
                <textarea id="input" name="input" rows="4" required class="shadow appearance-none border rounded w-full py-2 px-3 text-gray-700 dark:text-gray-900 leading-tight focus:outline-none focus:shadow-outline"></textarea>
            </div>
            <div class="flex items-center justify-between">
                <button type="submit" class="bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-4 rounded focus:outline-none focus:shadow-outline">
                    Analyze
                </button>
            </div>
        </form>
        <div id="results" class="mt-6"></div>
    </div>
    <script>
        // Escape text before inserting it as HTML so server- or network-supplied
        // strings cannot inject markup into the page.
        function escapeHtml(value) {
            return String(value)
                .replace(/&/g, '&amp;')
                .replace(/</g, '&lt;')
                .replace(/>/g, '&gt;')
                .replace(/"/g, '&quot;')
                .replace(/'/g, '&#39;');
        }

        function showError(message) {
            document.getElementById('results').innerHTML =
                `<p class="text-red-500 font-bold">Error: ${escapeHtml(message)}</p>`;
        }

        document.getElementById('testForm').addEventListener('submit', async function(event) {
            event.preventDefault();
            const apiKey = document.getElementById('api_key').value;
            const model = document.getElementById('model').value;
            const input = document.getElementById('input').value;
            try {
                const response = await fetch('/v1/moderations', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json',
                        'Authorization': 'Bearer ' + apiKey
                    },
                    body: JSON.stringify({ model: model, input: input })
                });
                const data = await response.json();
                if (data.error) {
                    showError(data.error);
                    return;
                }
                let html = '<h2 class="text-2xl font-bold mb-4">Results:</h2>';
                data.results.forEach(item => {
                    html += `<div class="mb-4 p-4 bg-gray-200 dark:bg-gray-700 rounded">
                        <p class="font-semibold">Flagged: ${item.flagged}</p>
                        <p class="font-semibold">Categories:</p>
                        <ul>`;
                    for (const [key, value] of Object.entries(item.categories)) {
                        html += `<li>${escapeHtml(key)}: ${value} (score: ${item.category_scores[key].toFixed(5)})</li>`;
                    }
                    html += `</ul>
                    </div>`;
                });
                document.getElementById('results').innerHTML = html;
            } catch (error) {
                // Network failure, or a non-JSON body such as an HTML error page:
                // surface it to the user instead of only logging it.
                console.error('Error:', error);
                showError(error.message || String(error));
            }
        });
    </script>
</body>
</html>
'''
|
|
|
# OpenAI moderation categories reported for every input.
_CATEGORY_KEYS = (
    "sexual", "sexual/minors", "harassment", "harassment/threatening",
    "hate", "hate/threatening", "illicit", "illicit/violent",
    "self-harm", "self-harm/intent", "self-harm/instructions",
    "violence", "violence/graphic",
)

# KoalaAI/Text-Moderation reports short label codes (per its model card)
# rather than OpenAI category names. Without this alias table every Koala score
# was read as 0.0. Categories with no Koala counterpart stay at 0.0.
_KOALA_LABEL_ALIASES = {
    "sexual": "S",
    "sexual/minors": "S3",
    "harassment": "HR",
    "hate": "H",
    "hate/threatening": "H2",
    "self-harm": "SH",
    "violence": "V",
    "violence/graphic": "V2",
}


def transform_predictions(model_choice, prediction_dict, threshold=0.5):
    """Convert one text's raw classifier scores into OpenAI-style moderation fields.

    Args:
        model_choice: ``"unitaryai/detoxify-multilingual"`` selects the
            Detoxify label mapping. Any other value treats *prediction_dict*
            as keyed by OpenAI category names or KoalaAI label codes.
        prediction_dict: Mapping of model label -> probability.
        threshold: A category is flagged when its score is strictly greater
            than this value.

    Returns:
        ``(flagged, categories, category_scores, category_applied_input_types)``:
        ``flagged`` is True if any category is flagged; ``categories`` maps
        each category to a bool; ``category_scores`` maps each to its score
        (0.0 when the model has no matching label);
        ``category_applied_input_types`` maps each to ``["text"]`` when the
        score is positive, else ``[]``.
    """
    def score(label):
        return prediction_dict.get(label, 0.0)

    if model_choice == "unitaryai/detoxify-multilingual":
        # Heuristic mapping: Detoxify's labels do not line up one-to-one with
        # OpenAI's categories (general toxicity is reported as "hate").
        mapped = {
            "sexual": score("sexual_explicit"),
            "harassment": max(score("identity_attack"), score("insult")),
            "hate": score("toxicity"),
            "violence": max(score("severe_toxicity"), score("threat")),
        }
        scores = {key: mapped.get(key, 0.0) for key in _CATEGORY_KEYS}
    else:
        scores = {}
        for key in _CATEGORY_KEYS:
            alias = _KOALA_LABEL_ALIASES.get(key)
            # Prefer an exact OpenAI category name; fall back to the Koala code.
            fallback = score(alias) if alias is not None else 0.0
            scores[key] = prediction_dict.get(key, fallback)

    bool_categories = {key: scores[key] > threshold for key in _CATEGORY_KEYS}
    cat_applied_input_types = {
        key: (["text"] if scores[key] > 0 else []) for key in _CATEGORY_KEYS
    }
    flagged = any(bool_categories.values())
    return flagged, bool_categories, scores, cat_applied_input_types
|
|
|
@app.route('/')
def home():
    """Serve the interactive HTML test page for the moderation endpoint."""
    page = render_template_string(HTML_TEMPLATE)
    return page
|
|
|
def _authorized(auth_header):
    """Return True when *auth_header* carries the configured bearer token.

    Fails closed when ``API_KEY`` is unset or empty, so an empty key can never
    be matched by a bare ``"Bearer "`` header. The comparison is constant-time
    so response timing does not leak the key.
    """
    if not API_KEY or not auth_header or not auth_header.startswith("Bearer "):
        return False
    provided_api_key = auth_header[len("Bearer "):].strip()
    return hmac.compare_digest(provided_api_key.encode("utf-8"), API_KEY.encode("utf-8"))


def _parse_texts(data):
    """Validate the request payload and return ``(texts, error_message)``.

    Exactly one element of the pair is None: *texts* is a list of at most 10
    strings (each at most 100,000 characters) on success.
    """
    raw_input = data.get('input') or data.get('texts')
    if raw_input is None:
        return None, "Invalid input, expected 'input' or 'texts' field"
    if isinstance(raw_input, str):
        texts = [raw_input]
    elif isinstance(raw_input, list):
        texts = raw_input
    else:
        return None, "Invalid input format, expected string or list of strings"
    if len(texts) > 10:
        return None, "Too many input items. Maximum 10 allowed."
    if any(not isinstance(text, str) or len(text) > 100000 for text in texts):
        return None, "Each input item must be a string with a maximum of 100k characters."
    return texts, None


def _koala_predict(text):
    """Return ``{label: probability}`` from the KoalaAI classifier for one text."""
    # Inputs may be up to 100k characters. Truncate to the tokenizer's limit
    # (capped at 512 tokens, an assumed training length; confirm against the
    # checkpoint config) so long texts cannot exhaust memory or crash the encoder.
    max_length = min(koala_tokenizer.model_max_length, 512)
    inputs = koala_tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
    with torch.no_grad():  # inference only: don't build an autograd graph
        logits = koala_model(**inputs).logits
    # Batch of one: take row 0 so the result is always a list, even for one label.
    probabilities = torch.softmax(logits, dim=-1)[0].tolist()
    id2label = koala_model.config.id2label
    return {id2label[idx]: prob for idx, prob in enumerate(probabilities)}


def _detoxify_predict(text):
    """Return ``{label: score}`` from the Detoxify classifier for one text."""
    pred = detoxify_model.predict([text])
    return {label: values[0] for label, values in pred.items()}


@app.route('/v1/moderations', methods=['POST'])
def moderations():
    """Serve the OpenAI-compatible moderation endpoint.

    Expects ``Authorization: Bearer <API_KEY>`` and a JSON object body with
    ``input`` (or ``texts``). That field holds a string or a list of at most
    10 strings, each at most 100,000 characters. ``model`` equal to
    ``"koalaai/text-moderation"`` selects KoalaAI; any other value, or none,
    uses Detoxify.

    Returns:
        JSON ``{"id", "model", "results", "object"}`` with one result per
        input text, or ``{"error": ...}`` with status 401 or 400.
    """
    if not _authorized(request.headers.get('Authorization')):
        return jsonify({"error": "Unauthorized"}), 401

    # silent=True: a missing or malformed JSON body yields None instead of an
    # exception. A non-object body (e.g. a JSON list) is rejected below too.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Invalid request body, expected a JSON object"}), 400

    texts, error = _parse_texts(data)
    if error is not None:
        return jsonify({"error": error}), 400

    if data.get('model') == "koalaai/text-moderation":
        response_model, predict = "koalaai/text-moderation", _koala_predict
    else:
        # Any other id (e.g. an OpenAI client's default model name) falls back
        # to Detoxify. Transform using the model that actually produced the
        # scores, not the requested id, otherwise every score would read 0.0.
        response_model, predict = "unitaryai/detoxify-multilingual", _detoxify_predict

    results = []
    for text in texts:
        flagged, bool_categories, scores, cat_applied_input_types = transform_predictions(
            response_model, predict(text)
        )
        results.append({
            "flagged": flagged,
            "categories": bool_categories,
            "category_scores": scores,
            "category_applied_input_types": cat_applied_input_types,
        })

    response_data = {
        "id": "modr-" + uuid.uuid4().hex[:24],
        "model": response_model,
        "results": results,
        "object": "moderation",
    }
    return jsonify(response_data)
|
|
|
if __name__ == '__main__':
    port = int(os.getenv('PORT', 7860))
    # Debug mode enables Werkzeug's interactive debugger, which allows arbitrary
    # code execution. Since the server binds to 0.0.0.0, only enable it when
    # explicitly requested via FLASK_DEBUG.
    debug = os.getenv('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    app.run(host='0.0.0.0', port=port, debug=debug)
|
|