SSK-14 committed on
Commit
acb544e
·
1 Parent(s): 4cab2aa

Add LLM guard api

Browse files
Files changed (18) hide show
  1. Dockerfile +45 -0
  2. Dockerfile-cuda +53 -0
  3. Makefile +33 -0
  4. README.md +0 -2
  5. app/__init__.py +0 -0
  6. app/__main__.py +4 -0
  7. app/app.py +325 -0
  8. app/cache.py +145 -0
  9. app/config.py +93 -0
  10. app/otel.py +85 -0
  11. app/scanner.py +107 -0
  12. app/schemas.py +24 -0
  13. app/util.py +57 -0
  14. app/version.py +1 -0
  15. config/scanners.yml +162 -0
  16. docker-compose.yml +11 -0
  17. openapi.json +319 -0
  18. pyproject.toml +57 -0
Dockerfile ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Use the Python 3.11 slim image
FROM python:3.11-slim

LABEL org.opencontainers.image.source=https://github.com/protectai/llm-guard
LABEL org.opencontainers.image.description="LLM Guard API"
LABEL org.opencontainers.image.licenses=MIT

# Install system packages needed for building native wheels
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    && apt-get clean && rm -rf /var/lib/apt/lists/*

# Run as an unprivileged user; pip installs land in ~/.local (on PATH below)
RUN useradd -m -u 1000 user
USER user
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH

# ensures that the python output is sent straight to terminal (e.g. your container log)
# without being first buffered and that you can see the output of your application (e.g. django logs)
# in real time. Equivalent to python -u: https://docs.python.org/3/using/cmdline.html#cmdoption-u
# NOTE: key=value form; the legacy space-separated "ENV key value" form is deprecated.
ENV PYTHONUNBUFFERED=1

# https://docs.python.org/3/using/cmdline.html#envvar-PYTHONDONTWRITEBYTECODE
# Prevents Python from writing .pyc files to disk
ENV PYTHONDONTWRITEBYTECODE=1

# Set up a working directory
WORKDIR $HOME/app

# Copy pyproject.toml and other necessary files for installation
COPY --chown=user:user pyproject.toml ./
COPY --chown=user:user app ./app

# Install the project's dependencies (CPU-only torch wheel from the PyTorch index);
# --no-cache-dir on every pip call keeps the image layer small.
RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir torch==2.0.1 --index-url https://download.pytorch.org/whl/cpu && \
    pip install --no-cache-dir ".[cpu]"

RUN python -m spacy download en_core_web_sm

COPY --chown=user:user ./config/scanners.yml ./config/scanners.yml

EXPOSE 7860

CMD ["llm_guard_api", "/home/user/app/config/scanners.yml"]
Dockerfile-cuda ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Start from an NVIDIA CUDA base image with Python 3
FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04

LABEL org.opencontainers.image.source=https://github.com/protectai/llm-guard
LABEL org.opencontainers.image.description="LLM Guard API"
LABEL org.opencontainers.image.licenses=MIT

# Install Python and other necessary packages
RUN apt-get update && apt-get install -y \
    python3-pip \
    python3-dev \
    build-essential \
    && apt-get clean && rm -rf /var/lib/apt/lists/*

# Alias python3 to python (spacy invocation below uses `python`)
RUN ln -s /usr/bin/python3 /usr/bin/python

# Create a non-root user and set user environment variables
RUN useradd -m -u 1000 user
USER user
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH

# ensures that the python output is sent straight to terminal (e.g. your container log)
# without being first buffered and that you can see the output of your application (e.g. django logs)
# in real time. Equivalent to python -u: https://docs.python.org/3/using/cmdline.html#cmdoption-u
# NOTE: key=value form; the legacy space-separated "ENV key value" form is deprecated.
ENV PYTHONUNBUFFERED=1

# https://docs.python.org/3/using/cmdline.html#envvar-PYTHONDONTWRITEBYTECODE
# Prevents Python from writing .pyc files to disk
ENV PYTHONDONTWRITEBYTECODE=1

# Set up a working directory
WORKDIR $HOME/app

# Copy pyproject.toml and other necessary files for installation
COPY --chown=user:user pyproject.toml ./
COPY --chown=user:user app ./app

# Install the project's dependencies (CUDA 11.8 torch wheel)
RUN pip3 install --no-cache-dir --upgrade pip && \
    pip3 install --no-cache-dir torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/cu118 && \
    pip3 install --no-cache-dir ".[gpu]"

RUN python -m spacy download en_core_web_sm

COPY --chown=user:user ./config/scanners.yml ./config/scanners.yml

# Expose the port the app runs on
EXPOSE 7860

# Specify the default command
CMD ["llm_guard_api", "/home/user/app/config/scanners.yml"]
Makefile ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
### --------------------------------------------------------------------------------------------------------------------
### Variables
### --------------------------------------------------------------------------------------------------------------------

# Docker config (simple `:=` expansion: evaluated once at parse time, not on every use)
DOCKER_IMAGE_NAME := laiyer/llm-guard-api
VERSION := 0.3.10

# Terminal color escape codes for log output
NO_COLOR := \033[0m
OK_COLOR := \033[32;01m
ERROR_COLOR := \033[31;01m
WARN_COLOR := \033[33;01m

install:  ## Install the API locally with CPU extras
	@python -m pip install ".[cpu]"

build-docker-multi:  ## Build and push the multi-arch CPU image
	@docker buildx build --platform linux/amd64,linux/arm64 -t $(DOCKER_IMAGE_NAME):$(VERSION) -t $(DOCKER_IMAGE_NAME):latest . --push

build-docker-cuda-multi:  ## Build and push the CUDA image (amd64 only)
	@docker buildx build --platform linux/amd64 -t $(DOCKER_IMAGE_NAME):$(VERSION)-cuda -t $(DOCKER_IMAGE_NAME):latest-cuda -f Dockerfile-cuda . --push

run: install  ## Install, then run the API against the local config
	llm_guard_api ./config/scanners.yml

run-docker:  ## Run the CPU image with the local config mounted
	@docker run -p 7860:7860 -e DEBUG='true' -v ./config:/home/user/app/config $(DOCKER_IMAGE_NAME):$(VERSION)

run-docker-cuda:  ## Run the CUDA image with all GPUs
	@docker run --gpus all -p 7860:7860 -e DEBUG='true' -v ./config:/home/user/app/config $(DOCKER_IMAGE_NAME):$(VERSION)-cuda

.PHONY: install run build-docker-multi build-docker-cuda-multi run-docker run-docker-cuda
README.md CHANGED
@@ -7,5 +7,3 @@ sdk: docker
7
  pinned: false
8
  license: mit
9
  ---
10
-
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
7
  pinned: false
8
  license: mit
9
  ---
 
 
app/__init__.py ADDED
File without changes
app/__main__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
from app import run_app

# Allow `python -m app <config>` to start the API server; the config path is
# consumed by app.app's module-level argparse.
if __name__ == "__main__":
    run_app()
app/app.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import argparse
import asyncio
import concurrent.futures
import time
from typing import Annotated

import structlog
from fastapi import Depends, FastAPI, HTTPException, Response, status
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.security import (
    HTTPAuthorizationCredentials,
    HTTPBasic,
    HTTPBasicCredentials,
    HTTPBearer,
)
from opentelemetry import metrics
from prometheus_client import CONTENT_TYPE_LATEST, REGISTRY, generate_latest
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from slowapi.util import get_remote_address
from starlette.exceptions import HTTPException as StarletteHTTPException

from llm_guard import scan_output, scan_prompt
from llm_guard.vault import Vault

from .cache import InMemoryCache
from .config import AuthConfig, get_config
from .otel import configure_otel, instrument_app
from .scanner import get_input_scanners, get_output_scanners
from .schemas import (
    AnalyzeOutputRequest,
    AnalyzeOutputResponse,
    AnalyzePromptRequest,
    AnalyzePromptResponse,
)
from .util import configure_logger
from .version import __version__

# Shared vault: Anonymize (input) and Deanonymize (output) scanners exchange
# redacted values through it.
vault = Vault()

# The config path comes from the command line, parsed at import time.
# NOTE(review): argparse runs on module import, so importing this module without
# the positional argument exits the process — confirm this is intentional.
parser = argparse.ArgumentParser(description="LLM Guard API")
parser.add_argument("config", type=str, help="Path to the configuration file")
args = parser.parse_args()
scanners_config_file = args.config

config = get_config(scanners_config_file)

LOGGER = structlog.getLogger(__name__)
log_level = config.app.log_level
is_debug = log_level == "DEBUG"
configure_logger(log_level)

# Telemetry must be configured before scanners/routes create meters below.
configure_otel(config.app.name, config.tracing, config.metrics)

# Scanners are instantiated once at startup (model loading is expensive).
input_scanners = get_input_scanners(config.input_scanners, vault)
output_scanners = get_output_scanners(config.output_scanners, vault)


# Counter incremented once per scanner verdict on every analyze request.
meter = metrics.get_meter_provider().get_meter(__name__)
scanners_valid_counter = meter.create_counter(
    name="scanners.valid",
    unit="1",
    description="measures the number of valid scanners",
)
+
70
+
71
def create_app() -> FastAPI:
    """
    Build the FastAPI application: prompt cache, app metadata, and all routes.

    Returns:
        The fully configured FastAPI instance.
    """
    # Cache for /analyze/prompt responses, sized/expired from config.
    cache = InMemoryCache(
        max_size=config.cache.max_size,
        expiration_time=config.cache.ttl,
    )

    if config.app.scan_fail_fast:
        LOGGER.debug("Scan fail_fast mode is enabled")

    app = FastAPI(
        title=config.app.name,
        description="API to run LLM Guard scanners.",
        debug=is_debug,
        version=__version__,
        openapi_url="/openapi.json" if is_debug else None,  # hide docs in production
    )

    register_routes(app, cache, input_scanners, output_scanners)

    return app
+
92
+
93
def _check_auth_function(auth_config: AuthConfig) -> callable:
    """
    Build the FastAPI dependency enforcing the configured auth scheme.

    Args:
        auth_config: Auth settings; falsy disables authentication entirely.

    Returns:
        An async dependency that returns True on success and raises HTTP 401
        on invalid credentials.

    Raises:
        ValueError: If the configured auth type is not supported.
    """
    async def check_auth_noop() -> bool:
        # No auth configured: every request passes.
        return True

    if not auth_config:
        return check_auth_noop

    # The credential-extraction scheme is chosen once here so FastAPI can emit
    # the matching OpenAPI security requirement for the dependency.
    if auth_config.type == "http_bearer":
        credentials_type = Annotated[HTTPAuthorizationCredentials, Depends(HTTPBearer())]
    elif auth_config.type == "http_basic":
        credentials_type = Annotated[HTTPBasicCredentials, Depends(HTTPBasic())]
    else:
        raise ValueError(f"Invalid auth type: {auth_config.type}")

    async def check_auth(credentials: credentials_type) -> bool:
        if auth_config.type == "http_bearer":
            if credentials.credentials != auth_config.token:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
                )
        elif auth_config.type == "http_basic":
            # NOTE(review): plain != comparison is not constant-time; consider
            # secrets.compare_digest to avoid timing side channels.
            if (
                credentials.username != auth_config.username
                or credentials.password != auth_config.password
            ):
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Username or Password"
                )

        return True

    return check_auth
+
126
+
127
def register_routes(
    app: FastAPI, cache: InMemoryCache, input_scanners: list, output_scanners: list
):
    """
    Attach middleware, exception handlers and all HTTP routes to the app.

    Args:
        app: The FastAPI instance to configure.
        cache: Prompt-level response cache used by /analyze/prompt.
        input_scanners: Instantiated prompt scanners.
        output_scanners: Instantiated output scanners.
    """
    # NOTE(review): allow_origins=["*"] together with allow_credentials=True is
    # rejected by browsers per the CORS spec — confirm the intended origin policy.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["Authorization", "Content-Type"],
    )

    # Rate limiting keyed by client IP; the registered handler returns HTTP 429.
    limiter = Limiter(key_func=get_remote_address, default_limits=[config.rate_limit.limit])
    app.state.limiter = limiter
    app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
    if bool(config.rate_limit.enabled):
        app.add_middleware(SlowAPIMiddleware)

    # Auth dependency built once from config (no-op when auth is disabled).
    check_auth = _check_auth_function(config.auth)

    @app.get("/", tags=["Main"])
    @limiter.exempt
    async def read_root():
        # Unauthenticated landing route.
        return {"name": "LLM Guard API"}

    @app.get("/healthz", tags=["Health"])
    @limiter.exempt
    async def healthcheck():
        # Liveness probe.
        return JSONResponse({"status": "alive"})

    @app.get("/readyz", tags=["Health"])
    @limiter.exempt
    async def liveliness():
        # Readiness probe.
        return JSONResponse({"status": "ready"})

    @app.post(
        "/analyze/output",
        tags=["Analyze"],
        response_model=AnalyzeOutputResponse,
        status_code=status.HTTP_200_OK,
        description="Analyze an output and return the sanitized output and the results of the scanners",
    )
    async def analyze_output(
        request: AnalyzeOutputRequest, _: Annotated[bool, Depends(check_auth)]
    ) -> AnalyzeOutputResponse:
        LOGGER.debug("Received analyze output request", request=request)

        # Scanners are CPU/model-bound: run them in a worker thread so the
        # event loop stays responsive, and bound the total time.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            loop = asyncio.get_event_loop()
            try:
                start_time = time.time()
                sanitized_output, results_valid, results_score = await asyncio.wait_for(
                    loop.run_in_executor(
                        executor,
                        scan_output,
                        output_scanners,
                        request.prompt,
                        request.output,
                        config.app.scan_fail_fast,
                    ),
                    timeout=config.app.scan_output_timeout,
                )

                # One metric point per scanner verdict.
                for scanner, valid in results_valid.items():
                    scanners_valid_counter.add(
                        1, {"source": "output", "valid": valid, "scanner": scanner}
                    )

                response = AnalyzeOutputResponse(
                    sanitized_output=sanitized_output,
                    is_valid=all(results_valid.values()),
                    scanners=results_score,
                )
                elapsed_time = time.time() - start_time
                LOGGER.debug(
                    "Sanitized response",
                    scores=results_score,
                    elapsed_time_seconds=round(elapsed_time, 6),
                )
            except asyncio.TimeoutError:
                raise HTTPException(
                    status_code=status.HTTP_408_REQUEST_TIMEOUT, detail="Request timeout."
                )

        return response

    @app.post(
        "/analyze/prompt",
        tags=["Analyze"],
        response_model=AnalyzePromptResponse,
        status_code=status.HTTP_200_OK,
        description="Analyze a prompt and return the sanitized prompt and the results of the scanners",
    )
    async def analyze_prompt(
        request: AnalyzePromptRequest,
        _: Annotated[bool, Depends(check_auth)],
        response: Response,
    ) -> AnalyzePromptResponse:
        LOGGER.debug("Received analyze prompt request", request=request)

        # Identical prompts within the cache TTL are served from memory.
        cached_result = cache.get(request.prompt)
        if cached_result:
            LOGGER.debug("Response was found in cache")

            response.headers["X-Cache-Hit"] = "true"

            return AnalyzePromptResponse(**cached_result)

        response.headers["X-Cache-Hit"] = "false"

        # Same worker-thread + timeout pattern as analyze_output above.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            loop = asyncio.get_event_loop()
            try:
                start_time = time.time()
                sanitized_prompt, results_valid, results_score = await asyncio.wait_for(
                    loop.run_in_executor(
                        executor,
                        scan_prompt,
                        input_scanners,
                        request.prompt,
                        config.app.scan_fail_fast,
                    ),
                    timeout=config.app.scan_prompt_timeout,
                )

                for scanner, valid in results_valid.items():
                    scanners_valid_counter.add(
                        1, {"source": "input", "valid": valid, "scanner": scanner}
                    )

                # NOTE: `response` now shadows the Response parameter above;
                # the X-Cache-Hit header was already set on the real response.
                response = AnalyzePromptResponse(
                    sanitized_prompt=sanitized_prompt,
                    is_valid=all(results_valid.values()),
                    scanners=results_score,
                )
                cache.set(request.prompt, response.dict())

                elapsed_time = time.time() - start_time
                LOGGER.debug(
                    "Sanitized prompt response returned",
                    scores=results_score,
                    elapsed_time_seconds=round(elapsed_time, 6),
                )
            except asyncio.TimeoutError:
                raise HTTPException(
                    status_code=status.HTTP_408_REQUEST_TIMEOUT, detail="Request timeout."
                )

        return response

    # Prometheus scrape endpoint, only when the prometheus exporter is active.
    if config.metrics and config.metrics.exporter == "prometheus":

        @app.get("/metrics", tags=["Metrics"])
        @limiter.exempt
        async def metrics():
            return Response(
                content=generate_latest(REGISTRY), headers={"Content-Type": CONTENT_TYPE_LATEST}
            )

    @app.on_event("shutdown")
    async def shutdown_event():
        LOGGER.info("Shutting down app...")

    @app.exception_handler(StarletteHTTPException)
    async def http_exception_handler(request, exc):
        # Normalize HTTP errors into {"message", "details"} JSON bodies.
        LOGGER.warning(
            "HTTP exception", exception_status_code=exc.status_code, exception_detail=exc.detail
        )

        return JSONResponse(
            {"message": str(exc.detail), "details": None}, status_code=exc.status_code
        )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(request, exc):
        # Body/query validation failures become 422 with field-level details.
        LOGGER.warning("Invalid request", exception=str(exc))

        response = {"message": "Validation failed", "details": exc.errors()}
        return JSONResponse(
            jsonable_encoder(response), status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
        )
+
308
+
309
# The app is created at import time so ASGI servers can target `app.app:app`.
app = create_app()
instrument_app(app)


def run_app():
    """Run the API under uvicorn (entry point used by __main__ and the CLI script)."""
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",  # listen on all interfaces (container deployment)
        port=config.app.port,
        server_header=False,  # don't advertise the server software
        log_level=log_level.lower(),
        proxy_headers=True,  # honor X-Forwarded-* from a reverse proxy
        forwarded_allow_ips="*",
        timeout_keep_alive=2,
    )
app/cache.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import threading
import time
from collections import OrderedDict
from typing import Optional


class InMemoryCache:
    """
    A simple thread-safe in-memory cache using an OrderedDict.

    This cache supports setting a maximum size and expiration time for cached items.
    When the cache is full, it uses a Least Recently Used (LRU) eviction policy.

    Attributes:
        max_size (int, optional): Maximum number of items to store in the cache.
            None means unbounded.
        expiration_time (int, optional): Time in seconds after which a cached item
            expires. Default is 1 hour; None disables expiration.

    Example:

        cache = InMemoryCache(max_size=3, expiration_time=5)

        # setting cache values
        cache.set("a", 1)
        cache.set("b", 2)
        cache["c"] = 3

        # getting cache values
        a = cache.get("a")
        b = cache["b"]
    """

    def __init__(self, max_size: Optional[int] = None, expiration_time: Optional[int] = 60 * 60):
        """
        Initialize a new InMemoryCache instance.

        Args:
            max_size (int, optional): Maximum number of items to store in the cache.
            expiration_time (int, optional): Time in seconds after which a cached item expires. Default is 1 hour.
        """
        self._cache = OrderedDict()
        # RLock instead of Lock: get_or_set()/set() call other public methods
        # while already holding the lock; a non-reentrant Lock deadlocks on the
        # second acquire (previous behavior of get_or_set()).
        self._lock = threading.RLock()
        self.max_size = max_size
        self.expiration_time = expiration_time

    def get(self, key):
        """
        Retrieve an item from the cache.

        Args:
            key: The key of the item to retrieve.

        Returns:
            The value associated with the key, or None if the key is not found
            or the item has expired.
        """
        with self._lock:
            if key in self._cache:
                item = self._cache.pop(key)
                if (
                    self.expiration_time is None
                    or time.time() - item["time"] < self.expiration_time
                ):
                    # Re-insert at the end to mark the key recently used.
                    self._cache[key] = item
                    return item["value"]
                else:
                    # Expired: drop the entry.
                    self.delete(key)
            return None

    def set(self, key, value):
        """
        Add an item to the cache.

        If the cache is full, the least recently used item is evicted.

        Args:
            key: The key of the item.
            value: The value to cache.
        """
        with self._lock:
            if key in self._cache:
                # Remove existing key before re-inserting to update order.
                self.delete(key)
            elif self.max_size and len(self._cache) >= self.max_size:
                # Remove least recently used item.
                self._cache.popitem(last=False)
            self._cache[key] = {"value": value, "time": time.time()}

    def get_or_set(self, key, value):
        """
        Retrieve an item from the cache. If the item does not exist, set it
        with the provided value.

        Args:
            key: The key of the item.
            value: The value to cache if the item doesn't exist.

        Returns:
            The cached value associated with the key.
        """
        # Safe with the RLock; the nested get()/set() calls below deadlocked
        # under the previous non-reentrant Lock.
        with self._lock:
            if key in self._cache:
                return self.get(key)
            self.set(key, value)
            return value

    def delete(self, key):
        """
        Remove an item from the cache (no-op if the key is absent).

        Args:
            key: The key of the item to remove.
        """
        # Lock restored: safe now that callers holding the lock use an RLock.
        with self._lock:
            self._cache.pop(key, None)

    def clear(self):
        """
        Clear all items from the cache.
        """
        with self._lock:
            self._cache.clear()

    def __contains__(self, key):
        """Check if the key is in the cache (ignores expiration)."""
        return key in self._cache

    def __getitem__(self, key):
        """Retrieve an item from the cache using the square bracket notation."""
        return self.get(key)

    def __setitem__(self, key, value):
        """Add an item to the cache using the square bracket notation."""
        self.set(key, value)

    def __delitem__(self, key):
        """Remove an item from the cache using the square bracket notation."""
        self.delete(key)

    def __len__(self):
        """Return the number of items in the cache."""
        return len(self._cache)

    def __repr__(self):
        """Return a string representation of the InMemoryCache instance."""
        return f"InMemoryCache(max_size={self.max_size}, expiration_time={self.expiration_time})"
app/config.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import re
from typing import Any, Dict, List, Literal, Optional

import structlog
import yaml
from pydantic import BaseModel, Field

LOGGER = structlog.getLogger(__name__)

# Matches ${VAR} or ${VAR:default} placeholders inside a scalar value.
_var_matcher = re.compile(r"\${([^}^{]+)}")
# Matches any scalar containing at least one ${...} placeholder; used as the
# implicit-resolver trigger for the !envvar tag in load_yaml().
_tag_matcher = re.compile(r"[^$]*\${([^}^{]+)}.*")


class RateLimitConfig(BaseModel):
    """Request rate-limiting settings."""

    enabled: bool = Field(default=False)
    # slowapi limit string, e.g. "100/minute".
    limit: str = Field(default="100/minute")


class CacheConfig(BaseModel):
    """Prompt-response cache settings."""

    # Item lifetime in seconds.
    ttl: int = Field(default=60)
    # Maximum cached items; None means unbounded.
    max_size: Optional[int] = Field(default=None)


class AuthConfig(BaseModel):
    """API authentication settings."""

    type: Literal["http_bearer", "http_basic"] = Field()
    # Bearer token (http_bearer only).
    token: Optional[str] = Field(default=None)
    # Credentials (http_basic only).
    username: Optional[str] = Field(default=None)
    password: Optional[str] = Field(default=None)


class TracingConfig(BaseModel):
    """OpenTelemetry tracing settings."""

    exporter: Literal["otel_http", "console"] = Field(default="console")
    # Collector endpoint (otel_http only).
    endpoint: Optional[str] = Field(default=None)


class MetricsConfig(BaseModel):
    """OpenTelemetry metrics settings."""

    exporter: Literal["otel_http", "prometheus", "console"] = Field(default="console")
    # Collector endpoint (otel_http only).
    endpoint: Optional[str] = Field(default=None)


class AppConfig(BaseModel):
    """General application settings."""

    name: Optional[str] = Field(default="LLM Guard API")
    port: Optional[int] = Field(default=7860)
    log_level: Optional[str] = Field(default="INFO")
    # Stop at the first failing scanner instead of running all of them.
    scan_fail_fast: Optional[bool] = Field(default=False)
    # Per-request scan timeouts, in seconds.
    scan_prompt_timeout: Optional[int] = Field(default=10)
    scan_output_timeout: Optional[int] = Field(default=30)


class ScannerConfig(BaseModel):
    """A single scanner entry: its type name and constructor parameters."""

    type: str
    params: Optional[Dict] = Field(default_factory=dict)


class Config(BaseModel):
    """Top-level configuration loaded from the scanners YAML file."""

    input_scanners: List[ScannerConfig] = Field()
    output_scanners: List[ScannerConfig] = Field()
    rate_limit: RateLimitConfig = Field(default_factory=RateLimitConfig)
    cache: CacheConfig = Field(default_factory=CacheConfig)
    auth: Optional[AuthConfig] = Field(default=None)
    app: AppConfig = Field(default_factory=AppConfig)
    tracing: Optional[TracingConfig] = Field(default=None)
    metrics: Optional[MetricsConfig] = Field(default=None)
65
+
66
+
67
def _path_constructor(_loader: Any, node: Any):
    """Expand ${VAR} / ${VAR:default} placeholders in a YAML scalar from the environment."""

    def _expand(match):
        # "NAME:default" -> name, default; the appended ":" guarantees a
        # default slot (empty string) even when none was written.
        parts = f"{match.group(1)}:".split(":")
        return os.environ.get(parts[0], parts[1])

    return _var_matcher.sub(_expand, node.value)
73
+
74
+
75
def load_yaml(filename: str) -> dict:
    """Load a YAML file with ${ENV_VAR} expansion; returns an empty dict on failure."""
    # Route any scalar containing ${...} through _path_constructor via the
    # implicit !envvar tag on the SafeLoader.
    yaml.add_implicit_resolver("!envvar", _tag_matcher, None, yaml.SafeLoader)
    yaml.add_constructor("!envvar", _path_constructor, yaml.SafeLoader)
    try:
        with open(filename, "r") as stream:
            raw = stream.read()
        return yaml.safe_load(raw)
    except (FileNotFoundError, PermissionError, yaml.YAMLError) as exc:
        LOGGER.error("Error loading YAML file", exception=exc)
        return dict()
84
+
85
+
86
def get_config(file_name: str) -> Optional[Config]:
    """Parse the scanners YAML file into a Config model; None if it could not be read."""
    LOGGER.debug("Loading config file", file_name=file_name)

    raw = load_yaml(file_name)
    # load_yaml() signals failure with an empty dict.
    if raw == {}:
        return None

    return Config(**raw)
app/otel.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from opentelemetry import metrics, propagate, trace
3
+ from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
4
+ from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
5
+ from opentelemetry.exporter.prometheus import PrometheusMetricReader
6
+ from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
7
+ from opentelemetry.propagators.aws import AwsXRayPropagator
8
+ from opentelemetry.sdk.extension.aws.resource.ec2 import AwsEc2ResourceDetector
9
+ from opentelemetry.sdk.extension.aws.trace import AwsXRayIdGenerator
10
+ from opentelemetry.sdk.metrics import MeterProvider
11
+ from opentelemetry.sdk.metrics.export import ConsoleMetricExporter, PeriodicExportingMetricReader
12
+ from opentelemetry.sdk.resources import (
13
+ SERVICE_NAME,
14
+ SERVICE_VERSION,
15
+ Resource,
16
+ get_aggregated_resources,
17
+ )
18
+ from opentelemetry.sdk.trace import TracerProvider
19
+ from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter
20
+
21
+ from .config import MetricsConfig, TracingConfig
22
+ from .version import __version__
23
+
24
+
25
def _configure_tracing(tracing_config: TracingConfig, resource: Resource) -> None:
    """
    Configure the global OpenTelemetry tracer provider.

    Args:
        tracing_config: Tracing settings (exporter type and optional endpoint),
            or None to leave tracing unconfigured.
        resource: Base resource attributes (service name/version) for spans.

    Raises:
        ValueError: If the exporter name is not recognized.
    """
    if tracing_config is None:
        return

    if tracing_config.exporter == "xray":
        # AWS X-Ray: propagate trace context in X-Ray format and enrich the
        # resource with EC2 metadata.
        # NOTE(review): "xray" is not in TracingConfig's Literal exporters, so
        # this branch is unreachable via validated config — confirm intent.
        propagate.set_global_textmap(AwsXRayPropagator())
        resource = resource.merge(
            get_aggregated_resources(
                [AwsEc2ResourceDetector()],
            )
        )

    tracer_provider = TracerProvider(resource=resource)
    if tracing_config.exporter == "xray":
        tracer_provider.id_generator = AwsXRayIdGenerator()
        exporter = OTLPSpanExporter(endpoint=tracing_config.endpoint)
    elif tracing_config.exporter == "otel_http":
        exporter = OTLPSpanExporter(endpoint=tracing_config.endpoint)
    elif tracing_config.exporter == "console":
        exporter = ConsoleSpanExporter()
    else:
        # Fail loudly instead of hitting UnboundLocalError on `exporter` below.
        raise ValueError(f"Unknown tracing exporter: {tracing_config.exporter}")

    tracer_provider.add_span_processor(BatchSpanProcessor(exporter))
    trace.set_tracer_provider(tracer_provider)
+
49
+
50
def _configure_metrics(metrics_config: MetricsConfig, resource: Resource) -> None:
    """
    Configure the global OpenTelemetry meter provider.

    Args:
        metrics_config: Metrics settings (exporter type and optional endpoint),
            or None to leave metrics unconfigured.
        resource: Base resource attributes (service name/version) for metrics.

    Raises:
        ValueError: If the exporter name is not recognized.
    """
    if metrics_config is None:
        return

    if metrics_config.exporter == "console":
        reader = PeriodicExportingMetricReader(ConsoleMetricExporter())
    elif metrics_config.exporter == "otel_http":
        reader = PeriodicExportingMetricReader(OTLPMetricExporter(endpoint=metrics_config.endpoint))
    elif metrics_config.exporter == "prometheus":
        # Scraped via the /metrics route registered in app.py.
        reader = PrometheusMetricReader()
    else:
        # Fail loudly instead of hitting UnboundLocalError on `reader` below.
        raise ValueError(f"Unknown metrics exporter: {metrics_config.exporter}")

    meter_provider = MeterProvider(resource=resource, metric_readers=[reader])
    metrics.set_meter_provider(meter_provider)
+
64
+
65
def configure_otel(
    app_name: str, tracing_config: TracingConfig, metrics_config: MetricsConfig
) -> None:
    """
    Configure OpenTelemetry tracing and metrics for the service.

    Args:
        app_name: Service name attached to all telemetry.
        tracing_config: Tracing settings, or None to skip tracing setup.
        metrics_config: Metrics settings, or None to skip metrics setup.
    """
    # Shared resource identifying the service on every span and metric.
    resource = Resource(
        attributes={
            SERVICE_NAME: app_name,
            SERVICE_VERSION: __version__,
        }
    )

    _configure_tracing(tracing_config, resource)
    _configure_metrics(metrics_config, resource)
+
78
+
79
def instrument_app(app: FastAPI) -> None:
    """
    Auto-instrument the FastAPI app with OpenTelemetry request tracing/metrics.

    Health and metrics endpoints are excluded to keep telemetry signal clean.

    Args:
        app: The FastAPI instance to instrument.
    """
    FastAPIInstrumentor.instrument_app(
        app,
        excluded_urls="healthz,readyz,metrics",
        meter_provider=metrics.get_meter_provider(),
        tracer_provider=trace.get_tracer_provider(),
    )
app/scanner.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Optional
2
+
3
+ import structlog
4
+
5
+ from llm_guard import input_scanners, output_scanners
6
+ from llm_guard.input_scanners.base import Scanner as InputScanner
7
+ from llm_guard.output_scanners.base import Scanner as OutputScanner
8
+ from llm_guard.vault import Vault
9
+
10
+ from .config import ScannerConfig
11
+ from .util import get_resource_utilization
12
+
13
+ LOGGER = structlog.getLogger(__name__)
14
+
15
+
16
def get_input_scanners(scanners: List[ScannerConfig], vault: Vault) -> List[InputScanner]:
    """
    Instantiate every input scanner listed in the configuration file.
    """

    loaded: List[InputScanner] = []
    for scanner_conf in scanners:
        LOGGER.debug(
            "Loading input scanner", scanner=scanner_conf.type, **get_resource_utilization()
        )
        instance = _get_input_scanner(
            scanner_conf.type,
            scanner_conf.params,
            vault=vault,
        )
        loaded.append(instance)

    return loaded
33
+
34
+
35
def get_output_scanners(scanners: List[ScannerConfig], vault: Vault) -> List[OutputScanner]:
    """
    Instantiate every output scanner listed in the configuration file.
    """
    loaded: List[OutputScanner] = []
    for scanner_conf in scanners:
        LOGGER.debug(
            "Loading output scanner", scanner=scanner_conf.type, **get_resource_utilization()
        )
        instance = _get_output_scanner(
            scanner_conf.type,
            scanner_conf.params,
            vault=vault,
        )
        loaded.append(instance)

    return loaded
51
+
52
+
53
def _get_input_scanner(
    scanner_name: str,
    scanner_config: Optional[Dict],
    *,
    vault: Vault,
):
    """
    Instantiate a single input scanner by name.

    Args:
        scanner_name: Registered name of the input scanner (e.g. "Anonymize").
        scanner_config: Scanner parameters from the config file; may be None.
        vault: Shared vault, injected into scanners that need it.

    Returns:
        The instantiated input scanner.
    """
    # Copy so the caller's params dict (held inside the parsed config object)
    # is not mutated with injected keys like "vault"/"use_onnx".
    scanner_config = dict(scanner_config) if scanner_config else {}

    if scanner_name == "Anonymize":
        scanner_config["vault"] = vault

    # Model-backed scanners run through ONNX runtime for faster inference.
    if scanner_name in [
        "Anonymize",
        "BanTopics",
        "Code",
        "Gibberish",
        "Language",
        "PromptInjection",
        "Toxicity",
    ]:
        scanner_config["use_onnx"] = True

    return input_scanners.get_scanner_by_name(scanner_name, scanner_config)
77
+
78
+
79
def _get_output_scanner(
    scanner_name: str,
    scanner_config: Optional[Dict],
    *,
    vault: Vault,
):
    """
    Instantiate a single output scanner by name.

    Args:
        scanner_name: Registered name of the output scanner (e.g. "Deanonymize").
        scanner_config: Scanner parameters from the config file; may be None.
        vault: Shared vault, injected into scanners that need it.

    Returns:
        The instantiated output scanner.
    """
    # Copy so the caller's params dict (held inside the parsed config object)
    # is not mutated with injected keys like "vault"/"use_onnx".
    scanner_config = dict(scanner_config) if scanner_config else {}

    if scanner_name == "Deanonymize":
        scanner_config["vault"] = vault

    # Model-backed scanners run through ONNX runtime for faster inference.
    if scanner_name in [
        "BanTopics",
        "Bias",
        "Code",
        "FactualConsistency",
        "Gibberish",
        "Language",
        "LanguageSame",
        "MaliciousURLs",
        "NoRefusal",
        "Relevance",
        "Sensitive",
        "Toxicity",
    ]:
        scanner_config["use_onnx"] = True

    return output_scanners.get_scanner_by_name(scanner_name, scanner_config)
app/schemas.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
class AnalyzePromptRequest(BaseModel):
    """Request body for POST /analyze/prompt."""

    # Raw user prompt to run through the configured input scanners.
    prompt: str
8
+
9
+
10
class AnalyzePromptResponse(BaseModel):
    """Response body for POST /analyze/prompt."""

    # Prompt after sanitization (e.g. redaction/anonymization) by the scanners.
    sanitized_prompt: str
    # True when every scanner accepted the prompt.
    is_valid: bool
    # Per-scanner risk score, keyed by scanner name.
    scanners: Dict[str, float]
14
+
15
+
16
class AnalyzeOutputRequest(BaseModel):
    """Request body for POST /analyze/output."""

    # Original prompt that produced the model output (needed by context-aware
    # scanners such as relevance/consistency checks).
    prompt: str
    # Model output to run through the configured output scanners.
    output: str
19
+
20
+
21
class AnalyzeOutputResponse(BaseModel):
    """Response body for POST /analyze/output."""

    # Output after sanitization (e.g. redaction/deanonymization) by the scanners.
    sanitized_output: str
    # True when every scanner accepted the output.
    is_valid: bool
    # Per-scanner risk score, keyed by scanner name.
    scanners: Dict[str, float]
app/util.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import sys
3
+ from os import getpid
4
+ from typing import Dict, Literal
5
+
6
+ import psutil
7
+ import structlog
8
+
9
+ from llm_guard.util import configure_logger as configure_llm_guard_logger
10
+
11
+ LOG_LEVELS = Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
12
+ EXTERNAL_LOGGERS = {
13
+ "transformers",
14
+ }
15
+
16
+
17
def configure_logger(log_level: LOG_LEVELS = "INFO"):
    """
    Set up stdlib logging and structlog for the API process.

    Args:
        log_level: The log level to use for the logger. One of
            "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL".
    """
    logging.basicConfig(
        format="[%(asctime)s - %(name)s - %(levelname)s] %(message)s",
        level=log_level,
        stream=sys.stdout,
    )
    structlog.configure(logger_factory=structlog.stdlib.LoggerFactory())

    # Quieten chatty third-party libraries regardless of our own level.
    for noisy_logger in EXTERNAL_LOGGERS:
        logging.getLogger(noisy_logger).setLevel(logging.WARNING)

    # Keep the llm-guard library's own logger aligned with the API's level.
    configure_llm_guard_logger(log_level)
35
+
36
+
37
def get_resource_utilization() -> Dict:
    """
    Snapshot resource usage of the current process and host.

    Returns:
        A dictionary with the process CPU utilization (percent), the process
        memory utilization (percent), and the host's total physical memory
        (bytes).
    """
    current_process = psutil.Process(getpid())

    return {
        # CPU utilization of this process, as a percentage.
        "cpu_utilization_percent": current_process.cpu_percent(),
        # Memory utilization of this process, as a percentage of total RAM.
        "memory_utilization_percent": current_process.memory_percent(),
        # Total physical memory installed on the host, in bytes.
        "total_memory_available_bytes": psutil.virtual_memory().total,
    }
app/version.py ADDED
@@ -0,0 +1 @@
 
 
1
# Single source of truth for the package version; read by setuptools via
# [tool.setuptools.dynamic] in pyproject.toml.
__version__ = "0.0.6"
config/scanners.yml ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ app:
2
+ name: ${APP_NAME:LLM Guard API}
3
+ log_level: ${LOG_LEVEL:INFO}
4
+ scan_fail_fast: ${SCAN_FAIL_FAST:false}
5
+ scan_prompt_timeout: ${SCAN_PROMPT_TIMEOUT:10}
6
+ scan_output_timeout: ${SCAN_OUTPUT_TIMEOUT:30}
7
+ port: ${APP_PORT:7860}
8
+
9
+ cache:
10
+ ttl: ${CACHE_TTL:3600}
11
+ #max_size: ${CACHE_MAX_SIZE:1000}
12
+
13
+ rate_limit:
14
+ enabled: ${RATE_LIMIT_ENABLED:true}
15
+ limit: ${RATE_LIMIT_LIMIT:100/minute}
16
+
17
+ #auth:
18
+ # type: http_bearer
19
+ # token: ${AUTH_TOKEN:}
20
+
21
+ tracing:
22
+ exporter: ${TRACING_EXPORTER:console}
23
+ endpoint: ${TRACING_OTEL_ENDPOINT:} # Example: "<traces-endpoint>/v1/traces"
24
+
25
+ metrics:
26
+ exporter: ${METRICS_TYPE:prometheus}
27
+ endpoint: ${METRICS_ENDPOINT:} # Example: "<metrics-endpoint>/v1/metrics"
28
+
29
+ # Scanners are applied in the order they are listed here.
30
+ input_scanners:
31
+ - type: Anonymize
32
+ params:
33
+ # allowed_names: []
34
+ # hidden_names: []
35
+ # entity_types: []
36
+ # preamble: ""
37
+ use_faker: false
38
+ - type: BanCompetitors
39
+ params:
40
+ competitors: ["facebook"]
41
+ threshold: 0.5
42
+ - type: BanSubstrings
43
+ params:
44
+ substrings: ["test"]
45
+ match_type: "word"
46
+ case_sensitive: false
47
+ redact: false
48
+ contains_all: false
49
+ - type: BanTopics
50
+ params:
51
+ topics: ["violence"]
52
+ threshold: 0.6
53
+ - type: Code
54
+ params:
55
+ languages: ["Python"]
56
+ is_blocked: true
57
+ - type: Gibberish
58
+ params:
59
+ threshold: 0.9
60
+ - type: InvisibleText
61
+ params: {}
62
+ - type: Language
63
+ params:
64
+ valid_languages: ["en"]
65
+ - type: PromptInjection
66
+ params:
67
+ threshold: 0.9
68
+ - type: Regex
69
+ params:
70
+ patterns: ["Bearer [A-Za-z0-9-._~+/]+"]
71
+ is_blocked: true
72
+ match_type: search
73
+ redact: true
74
+ - type: Secrets
75
+ params:
76
+ redact_mode: "all"
77
+ - type: Sentiment
78
+ params:
79
+ # lexicon: "vader_lexicon"
80
+ threshold: -0.1
81
+ - type: TokenLimit
82
+ params:
83
+ limit: 4096
84
+ encoding_name: "cl100k_base"
85
+ - type: Toxicity
86
+ params:
87
+ threshold: 0.5
88
+
89
+ output_scanners:
90
+ - type: BanCompetitors
91
+ params:
92
+ competitors: ["facebook"]
93
+ threshold: 0.5
94
+ - type: BanSubstrings
95
+ params:
96
+ substrings: ["test"]
97
+ match_type: "word"
98
+ case_sensitive: false
99
+ redact: false
100
+ contains_all: false
101
+ - type: BanTopics
102
+ params:
103
+ topics: ["violence"]
104
+ threshold: 0.6
105
+ - type: Bias
106
+ params:
107
+ threshold: 0.75
108
+ - type: Code
109
+ params:
110
+ languages: ["Python"]
111
+ is_blocked: true
112
+ - type: Deanonymize
113
+ params:
114
+ matching_strategy: "exact"
115
+ - type: FactualConsistency
116
+ params:
117
+ minimum_score: 0.5
118
+ - type: Gibberish
119
+ params:
120
+ threshold: 0.9
121
+ - type: JSON
122
+ params:
123
+ required_elements: 0
124
+ repair: true
125
+ - type: Language
126
+ params:
127
+ valid_languages: ["en"]
128
+ - type: LanguageSame
129
+ params: {}
130
+ - type: MaliciousURLs
131
+ params:
132
+ threshold: 0.75
133
+ - type: NoRefusal
134
+ params:
135
+ threshold: 0.5
136
+ - type: ReadingTime
137
+ params:
138
+ max_time: 5
139
+ truncate: false
140
+ - type: Regex
141
+ params:
142
+ patterns: ["Bearer [A-Za-z0-9-._~+/]+"]
143
+ is_blocked: true
144
+ match_type: search
145
+ redact: true
146
+ - type: Relevance
147
+ params:
148
+ threshold: 0.5
149
+ - type: Sensitive
150
+ params:
151
+ # entity_types:
152
+ redact: false
153
+ threshold: 0.0
154
+ - type: Sentiment
155
+ params:
156
+ threshold: -0.1
157
+ # lexicon: "vader_lexicon"
158
+ - type: Toxicity
159
+ params:
160
+ threshold: 0.5
161
+ - type: URLReachability
162
+ params: {}
docker-compose.yml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: '3.8'
2
+
3
+ services:
4
+ llm_guard_api:
5
+ build:
6
+ context: .
7
+ dockerfile: Dockerfile
8
+ ports:
9
+ - "7860:7860"
10
+ volumes:
11
+ - ./config/scanners.yml:/home/user/app/config/scanners.yml
openapi.json ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "openapi": "3.1.0",
3
+ "info": {
4
+ "title": "LLM Guard API",
5
+ "description": "API to run LLM Guard scanners.",
6
+ "version": "0.0.6"
7
+ },
8
+ "paths": {
9
+ "/": {
10
+ "get": {
11
+ "tags": [
12
+ "Main"
13
+ ],
14
+ "summary": "Read Root",
15
+ "operationId": "read_root__get",
16
+ "responses": {
17
+ "200": {
18
+ "description": "Successful Response",
19
+ "content": {
20
+ "application/json": {
21
+ "schema": {}
22
+ }
23
+ }
24
+ }
25
+ }
26
+ }
27
+ },
28
+ "/healthz": {
29
+ "get": {
30
+ "tags": [
31
+ "Health"
32
+ ],
33
+ "summary": "Healthcheck",
34
+ "operationId": "healthcheck_healthz_get",
35
+ "responses": {
36
+ "200": {
37
+ "description": "Successful Response",
38
+ "content": {
39
+ "application/json": {
40
+ "schema": {}
41
+ }
42
+ }
43
+ }
44
+ }
45
+ }
46
+ },
47
+ "/readyz": {
48
+ "get": {
49
+ "tags": [
50
+ "Health"
51
+ ],
52
+ "summary": "Liveliness",
53
+ "operationId": "liveliness_readyz_get",
54
+ "responses": {
55
+ "200": {
56
+ "description": "Successful Response",
57
+ "content": {
58
+ "application/json": {
59
+ "schema": {}
60
+ }
61
+ }
62
+ }
63
+ }
64
+ }
65
+ },
66
+ "/analyze/output": {
67
+ "post": {
68
+ "tags": [
69
+ "Analyze"
70
+ ],
71
+ "summary": "Analyze Output",
72
+ "description": "Analyze an output and return the sanitized output and the results of the scanners",
73
+ "operationId": "analyze_output_analyze_output_post",
74
+ "requestBody": {
75
+ "content": {
76
+ "application/json": {
77
+ "schema": {
78
+ "$ref": "#/components/schemas/AnalyzeOutputRequest"
79
+ }
80
+ }
81
+ },
82
+ "required": true
83
+ },
84
+ "responses": {
85
+ "200": {
86
+ "description": "Successful Response",
87
+ "content": {
88
+ "application/json": {
89
+ "schema": {
90
+ "$ref": "#/components/schemas/AnalyzeOutputResponse"
91
+ }
92
+ }
93
+ }
94
+ },
95
+ "422": {
96
+ "description": "Validation Error",
97
+ "content": {
98
+ "application/json": {
99
+ "schema": {
100
+ "$ref": "#/components/schemas/HTTPValidationError"
101
+ }
102
+ }
103
+ }
104
+ }
105
+ },
106
+ "security": [
107
+ {
108
+ "HTTPBearer": []
109
+ }
110
+ ]
111
+ }
112
+ },
113
+ "/analyze/prompt": {
114
+ "post": {
115
+ "tags": [
116
+ "Analyze"
117
+ ],
118
+ "summary": "Analyze Prompt",
119
+ "description": "Analyze a prompt and return the sanitized prompt and the results of the scanners",
120
+ "operationId": "analyze_prompt_analyze_prompt_post",
121
+ "requestBody": {
122
+ "content": {
123
+ "application/json": {
124
+ "schema": {
125
+ "$ref": "#/components/schemas/AnalyzePromptRequest"
126
+ }
127
+ }
128
+ },
129
+ "required": true
130
+ },
131
+ "responses": {
132
+ "200": {
133
+ "description": "Successful Response",
134
+ "content": {
135
+ "application/json": {
136
+ "schema": {
137
+ "$ref": "#/components/schemas/AnalyzePromptResponse"
138
+ }
139
+ }
140
+ }
141
+ },
142
+ "422": {
143
+ "description": "Validation Error",
144
+ "content": {
145
+ "application/json": {
146
+ "schema": {
147
+ "$ref": "#/components/schemas/HTTPValidationError"
148
+ }
149
+ }
150
+ }
151
+ }
152
+ },
153
+ "security": [
154
+ {
155
+ "HTTPBearer": []
156
+ }
157
+ ]
158
+ }
159
+ },
160
+ "/metrics": {
161
+ "get": {
162
+ "tags": [
163
+ "Metrics"
164
+ ],
165
+ "summary": "Metrics",
166
+ "operationId": "metrics_metrics_get",
167
+ "responses": {
168
+ "200": {
169
+ "description": "Successful Response",
170
+ "content": {
171
+ "application/json": {
172
+ "schema": {}
173
+ }
174
+ }
175
+ }
176
+ }
177
+ }
178
+ }
179
+ },
180
+ "components": {
181
+ "schemas": {
182
+ "AnalyzeOutputRequest": {
183
+ "properties": {
184
+ "prompt": {
185
+ "type": "string",
186
+ "title": "Prompt"
187
+ },
188
+ "output": {
189
+ "type": "string",
190
+ "title": "Output"
191
+ }
192
+ },
193
+ "type": "object",
194
+ "required": [
195
+ "prompt",
196
+ "output"
197
+ ],
198
+ "title": "AnalyzeOutputRequest"
199
+ },
200
+ "AnalyzeOutputResponse": {
201
+ "properties": {
202
+ "sanitized_output": {
203
+ "type": "string",
204
+ "title": "Sanitized Output"
205
+ },
206
+ "is_valid": {
207
+ "type": "boolean",
208
+ "title": "Is Valid"
209
+ },
210
+ "scanners": {
211
+ "additionalProperties": {
212
+ "type": "number"
213
+ },
214
+ "type": "object",
215
+ "title": "Scanners"
216
+ }
217
+ },
218
+ "type": "object",
219
+ "required": [
220
+ "sanitized_output",
221
+ "is_valid",
222
+ "scanners"
223
+ ],
224
+ "title": "AnalyzeOutputResponse"
225
+ },
226
+ "AnalyzePromptRequest": {
227
+ "properties": {
228
+ "prompt": {
229
+ "type": "string",
230
+ "title": "Prompt"
231
+ }
232
+ },
233
+ "type": "object",
234
+ "required": [
235
+ "prompt"
236
+ ],
237
+ "title": "AnalyzePromptRequest"
238
+ },
239
+ "AnalyzePromptResponse": {
240
+ "properties": {
241
+ "sanitized_prompt": {
242
+ "type": "string",
243
+ "title": "Sanitized Prompt"
244
+ },
245
+ "is_valid": {
246
+ "type": "boolean",
247
+ "title": "Is Valid"
248
+ },
249
+ "scanners": {
250
+ "additionalProperties": {
251
+ "type": "number"
252
+ },
253
+ "type": "object",
254
+ "title": "Scanners"
255
+ }
256
+ },
257
+ "type": "object",
258
+ "required": [
259
+ "sanitized_prompt",
260
+ "is_valid",
261
+ "scanners"
262
+ ],
263
+ "title": "AnalyzePromptResponse"
264
+ },
265
+ "HTTPValidationError": {
266
+ "properties": {
267
+ "detail": {
268
+ "items": {
269
+ "$ref": "#/components/schemas/ValidationError"
270
+ },
271
+ "type": "array",
272
+ "title": "Detail"
273
+ }
274
+ },
275
+ "type": "object",
276
+ "title": "HTTPValidationError"
277
+ },
278
+ "ValidationError": {
279
+ "properties": {
280
+ "loc": {
281
+ "items": {
282
+ "anyOf": [
283
+ {
284
+ "type": "string"
285
+ },
286
+ {
287
+ "type": "integer"
288
+ }
289
+ ]
290
+ },
291
+ "type": "array",
292
+ "title": "Location"
293
+ },
294
+ "msg": {
295
+ "type": "string",
296
+ "title": "Message"
297
+ },
298
+ "type": {
299
+ "type": "string",
300
+ "title": "Error Type"
301
+ }
302
+ },
303
+ "type": "object",
304
+ "required": [
305
+ "loc",
306
+ "msg",
307
+ "type"
308
+ ],
309
+ "title": "ValidationError"
310
+ }
311
+ },
312
+ "securitySchemes": {
313
+ "HTTPBearer": {
314
+ "type": "http",
315
+ "scheme": "bearer"
316
+ }
317
+ }
318
+ }
319
+ }
pyproject.toml ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "llm-guard-api"
3
+ description = "LLM Guard API is a deployment of LLM Guard as an API."
4
+ authors = [
5
+ { name = "Protect AI", email = "[email protected]"}
6
+ ]
7
+ readme = "README.md"
8
+ dynamic = ["version"]
9
+ classifiers = [
10
+ "Development Status :: 4 - Beta",
11
+ "Intended Audience :: Developers",
12
+ "License :: OSI Approved :: MIT License",
13
+ "Programming Language :: Python :: 3",
14
+ "Programming Language :: Python :: 3.9",
15
+ "Programming Language :: Python :: 3.10",
16
+ "Programming Language :: Python :: 3.11",
17
+ ]
18
+ requires-python = ">=3.9"
19
+
20
+ dependencies = [
21
+ "asyncio==3.4.3", # NOTE(review): PyPI "asyncio" is an obsolete Python 3.3 backport; asyncio is stdlib for requires-python >= 3.9 — consider removing this pin
22
+ "fastapi==0.110.0",
23
+ "llm-guard==0.3.10",
24
+ "pydantic==1.10.14",
25
+ "pyyaml==6.0.1",
26
+ "uvicorn[standard]==0.29.0",
27
+ "structlog>=24",
28
+ "slowapi==0.1.9",
29
+ "opentelemetry-instrumentation-fastapi==0.44b0",
30
+ "opentelemetry-api==1.23.0",
31
+ "opentelemetry-sdk==1.23.0",
32
+ "opentelemetry-exporter-otlp-proto-http==1.23.0",
33
+ "opentelemetry-exporter-prometheus==0.44b0",
34
+ "opentelemetry-sdk-extension-aws==2.0.1",
35
+ "opentelemetry-propagator-aws-xray==1.0.1"
36
+ ]
37
+
38
+ [project.optional-dependencies]
39
+ cpu = [
40
+ "llm-guard[onnxruntime]==0.3.10",
41
+ ]
42
+ gpu = [
43
+ "llm-guard[onnxruntime-gpu]==0.3.10",
44
+ ]
45
+
46
+ [tool.setuptools]
47
+ packages = ["app"]
48
+
49
+ [tool.setuptools.dynamic]
50
+ version = {attr = "app.version.__version__"}
51
+
52
+ [build-system]
53
+ requires = ["setuptools", "wheel"]
54
+ build-backend = "setuptools.build_meta"
55
+
56
+ [project.scripts]
57
+ llm_guard_api = "app.app:run_app"