* add AWS comprehend
- README.md +2 -1
- app.py +38 -3
- requirements.txt +7 -6
README.md
CHANGED
@@ -4,7 +4,7 @@ emoji: 📝
 colorFrom: yellow
 colorTo: gray
 sdk: gradio
-sdk_version: 4.
+sdk_version: 4.19.1
 pinned: true
 license: apache-2.0
 ---
@@ -35,3 +35,4 @@ gradio app.py
 - [Rebuff](https://rebuff.ai/)
 - [Azure Content Safety AI](https://learn.microsoft.com/en-us/azure/ai-services/content-safety/studio-quickstart)
 - [AWS Bedrock Guardrails](https://aws.amazon.com/bedrock/guardrails/) (coming soon)
+- [AWS Comprehend](https://docs.aws.amazon.com/comprehend/latest/dg/trust-safety.html)
app.py
CHANGED
@@ -11,6 +11,7 @@ from functools import lru_cache
 from typing import List, Union

 import aegis
+import boto3
 import gradio as gr
 import requests
 from huggingface_hub import HfApi
@@ -29,6 +30,7 @@ automorphic_api_key = os.getenv("AUTOMORPHIC_API_KEY")
 rebuff_api_key = os.getenv("REBUFF_API_KEY")
 azure_content_safety_endpoint = os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT")
 azure_content_safety_key = os.getenv("AZURE_CONTENT_SAFETY_KEY")
+aws_comprehend_client = boto3.client(service_name="comprehend", region_name="us-east-1")


 @lru_cache(maxsize=2)
@@ -61,7 +63,9 @@ def convert_elapsed_time(diff_time) -> float:
 deepset_classifier = init_prompt_injection_model(
     "ProtectAI/deberta-v3-base-injection-onnx"
 ) # ONNX version of deepset/deberta-v3-base-injection
-protectai_classifier = init_prompt_injection_model(
+protectai_classifier = init_prompt_injection_model(
+    "ProtectAI/deberta-v3-base-prompt-injection", "onnx"
+)
 fmops_classifier = init_prompt_injection_model(
     "ProtectAI/fmops-distilbert-prompt-injection-onnx"
 ) # ONNX version of fmops/distilbert-prompt-injection
@@ -155,6 +159,36 @@ def detect_azure(prompt: str) -> (bool, bool):
     return False, False


+def detect_aws_comprehend(prompt: str) -> (bool, bool):
+    response = aws_comprehend_client.classify_document(
+        EndpointArn="arn:aws:comprehend:us-east-1:aws:document-classifier-endpoint/prompt-safety",
+        Text=prompt,
+    )
+    response = {
+        "Classes": [
+            {"Name": "SAFE_PROMPT", "Score": 0.9010000228881836},
+            {"Name": "UNSAFE_PROMPT", "Score": 0.0989999994635582},
+        ],
+        "ResponseMetadata": {
+            "RequestId": "e8900fe1-3346-45c0-bad3-007b2840865a",
+            "HTTPStatusCode": 200,
+            "HTTPHeaders": {
+                "x-amzn-requestid": "e8900fe1-3346-45c0-bad3-007b2840865a",
+                "content-type": "application/x-amz-json-1.1",
+                "content-length": "115",
+                "date": "Mon, 19 Feb 2024 08:34:43 GMT",
+            },
+            "RetryAttempts": 0,
+        },
+    }
+    logger.info(f"Prompt injection result from AWS Comprehend: {response}")
+    if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
+        logger.error(f"Failed to call AWS Comprehend API: {response}")
+        return False, False
+
+    return True, response["Classes"][0] == "UNSAFE_PROMPT"
+
+
 detection_providers = {
     "ProtectAI (HF model)": detect_hf_protectai,
     "Deepset (HF model)": detect_hf_deepset,
@@ -163,6 +197,7 @@ detection_providers = {
     "Automorphic Aegis": detect_automorphic,
     # "Rebuff": detect_rebuff,
     "Azure Content Safety": detect_azure,
+    "AWS Comprehend": detect_aws_comprehend,
 }


@@ -235,8 +270,8 @@ if __name__ == "__main__":
 "The results are <strong>stored in the private dataset</strong> for further analysis and improvements. This interface is for research purposes only."
 "<br /><br />"
 "HuggingFace (HF) models are hosted on Spaces while other providers are called as APIs.<br /><br />"
-
-
+'<a href="https://join.slack.com/t/laiyerai/shared_invite/zt-28jv3ci39-sVxXrLs3rQdaN3mIl9IT~w">Join our Slack community to discuss LLM Security</a><br />'
+'<a href="https://github.com/protectai/llm-guard">Secure your LLM interactions with LLM Guard</a>',
 examples=[
 [
     example,
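Note on the committed `detect_aws_comprehend`: the live `classify_document` result is immediately overwritten by a hard-coded sample response (apparently leftover test data), and the final check compares a class dict to the string `"UNSAFE_PROMPT"`, so it can never flag an injection. The sketch below is a possible corrected version, not the committed code: it keeps the commit's client, endpoint ARN, and `(request_succeeded, is_injection)` return convention, while the 0.5 score threshold is an assumption.

```python
import logging

import boto3

logger = logging.getLogger(__name__)

# Assumed setup, mirroring the commit: Comprehend client in us-east-1 and the
# built-in prompt-safety classifier endpoint.
aws_comprehend_client = boto3.client(service_name="comprehend", region_name="us-east-1")
PROMPT_SAFETY_ENDPOINT_ARN = (
    "arn:aws:comprehend:us-east-1:aws:document-classifier-endpoint/prompt-safety"
)


def detect_aws_comprehend(prompt: str) -> (bool, bool):
    try:
        # Use the live response instead of a hard-coded sample.
        response = aws_comprehend_client.classify_document(
            EndpointArn=PROMPT_SAFETY_ENDPOINT_ARN,
            Text=prompt,
        )
    except Exception:
        logger.exception("Failed to call AWS Comprehend API")
        return False, False

    logger.info(f"Prompt injection result from AWS Comprehend: {response}")
    if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
        logger.error(f"Failed to call AWS Comprehend API: {response}")
        return False, False

    # Each entry in "Classes" is a dict like {"Name": "UNSAFE_PROMPT", "Score": 0.09},
    # so look up the UNSAFE_PROMPT score instead of comparing a dict to a string.
    unsafe_score = next(
        (c["Score"] for c in response["Classes"] if c["Name"] == "UNSAFE_PROMPT"), 0.0
    )
    return True, unsafe_score > 0.5  # 0.5 threshold is an assumption
```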
requirements.txt
CHANGED
@@ -1,8 +1,9 @@
+boto3==1.34.44
 git+https://github.com/automorphic-ai/aegis.git
-gradio==4.
-huggingface_hub==0.
-onnxruntime==1.
-optimum[onnxruntime]==1.
-rebuff==0.
+gradio==4.19.1
+huggingface_hub==0.20.3
+onnxruntime==1.17.0
+optimum[onnxruntime]==1.17.1
+rebuff==0.1.1
 requests==2.31.0
-transformers==4.
+transformers==4.37.2
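The new `boto3` dependency also implies AWS credentials at runtime: the committed client is created with only a region, so it relies on boto3's default credential chain (on a Space, typically `AWS_ACCESS_KEY_ID` / `AWS_SECRET_ACCESS_KEY` set as secrets). A minimal sketch of an explicit variant, using the standard AWS environment variable names rather than anything defined in this repo:

```python
import os

import boto3

# Sketch: build the Comprehend client from standard AWS environment variables
# (e.g. configured as Space secrets); with unset values, boto3 should fall back
# to its default credential chain.
aws_comprehend_client = boto3.client(
    service_name="comprehend",
    region_name=os.getenv("AWS_REGION", "us-east-1"),
    aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
    aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
)
```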