|
import argparse
import asyncio
import concurrent.futures
import secrets
import time
from collections.abc import Awaitable, Callable
from typing import Annotated

import structlog
from fastapi import Depends, FastAPI, HTTPException, Response, status
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.security import (
    HTTPAuthorizationCredentials,
    HTTPBasic,
    HTTPBasicCredentials,
    HTTPBearer,
)
from opentelemetry import metrics
from prometheus_client import CONTENT_TYPE_LATEST, REGISTRY, generate_latest
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from slowapi.util import get_remote_address
from starlette.exceptions import HTTPException as StarletteHTTPException

from llm_guard import scan_output, scan_prompt
from llm_guard.vault import Vault

from .cache import InMemoryCache
from .config import AuthConfig, get_config
from .otel import configure_otel, instrument_app
from .scanner import get_input_scanners, get_output_scanners
from .schemas import (
    AnalyzeOutputRequest,
    AnalyzeOutputResponse,
    AnalyzePromptRequest,
    AnalyzePromptResponse,
)
from .util import configure_logger
from .version import __version__
|
|
|
# Single Vault instance handed to both the input and the output scanner factories below.
vault = Vault()

# The path to the configuration file is a required positional command-line argument.
# NOTE(review): parse_args() executes at import time, so merely importing this module
# consumes sys.argv and exits when the argument is missing — confirm this is intended.
parser = argparse.ArgumentParser(description="LLM Guard API")
parser.add_argument("config", type=str, help="Path to the configuration file")
args = parser.parse_args()
scanners_config_file = args.config

# Load the configuration from the file given on the command line.
config = get_config(scanners_config_file)

# Module-wide logger; debug mode is derived from the configured log level and is
# later used to toggle FastAPI debug mode and the OpenAPI schema endpoint.
LOGGER = structlog.getLogger(__name__)
log_level = config.app.log_level
is_debug = log_level == "DEBUG"
configure_logger(log_level)

# Set up OpenTelemetry from the tracing and metrics sections of the configuration.
configure_otel(config.app.name, config.tracing, config.metrics)

# Scanner lists are built once at import time and reused by every request handler.
input_scanners = get_input_scanners(config.input_scanners, vault)
output_scanners = get_output_scanners(config.output_scanners, vault)


# Counter incremented once per scanner result in the analyze endpoints, tagged with
# the source ("input"/"output"), the scanner name and whether the result was valid.
meter = metrics.get_meter_provider().get_meter(__name__)
scanners_valid_counter = meter.create_counter(
    name="scanners.valid",
    unit="1",
    description="measures the number of valid scanners",
)
|
|
|
|
|
def create_app() -> FastAPI:
    """Build the FastAPI application and register all routes on it.

    The response cache is sized from ``config.cache``; the OpenAPI schema is only
    exposed when the configured log level is ``DEBUG``.

    Returns:
        The FastAPI instance with middleware, routes and handlers registered.
    """
    cache_settings = config.cache
    response_cache = InMemoryCache(
        max_size=cache_settings.max_size,
        expiration_time=cache_settings.ttl,
    )

    if config.app.scan_fail_fast:
        LOGGER.debug("Scan fail_fast mode is enabled")

    # The schema endpoint is disabled (None) outside of debug mode.
    schema_url = "/openapi.json" if is_debug else None

    application = FastAPI(
        title=config.app.name,
        description="API to run LLM Guard scanners.",
        debug=is_debug,
        version=__version__,
        openapi_url=schema_url,
    )
    register_routes(application, response_cache, input_scanners, output_scanners)
    return application
|
|
|
|
|
def _constant_time_equals(provided, expected) -> bool:
    """Compare two credential strings without leaking where they differ.

    Non-string values (for example an unset secret) never match.
    """
    if not isinstance(provided, str) or not isinstance(expected, str):
        return False
    # Encode first: compare_digest only accepts str operands when they are pure ASCII.
    return secrets.compare_digest(provided.encode("utf-8"), expected.encode("utf-8"))


def _check_auth_function(auth_config: AuthConfig) -> Callable[..., Awaitable[bool]]:
    """Build the FastAPI dependency that authenticates a request.

    Args:
        auth_config: Authentication settings. A falsy value disables authentication.
            Otherwise ``type`` must be ``"http_bearer"`` (compared against ``token``)
            or ``"http_basic"`` (compared against ``username`` and ``password``).

    Returns:
        An async callable suitable for ``Depends``. It returns ``True`` when the
        request is accepted and raises ``HTTPException`` (401) otherwise.

    Raises:
        ValueError: If ``auth_config.type`` is not one of the supported types.
    """

    async def check_auth_noop() -> bool:
        return True

    if not auth_config:
        return check_auth_noop

    if auth_config.type == "http_bearer":
        # The alias is a local evaluated when the inner function is defined, which
        # is how FastAPI gets to see the Annotated dependency.
        bearer_credentials = Annotated[HTTPAuthorizationCredentials, Depends(HTTPBearer())]

        async def check_bearer_auth(credentials: bearer_credentials) -> bool:
            if not _constant_time_equals(credentials.credentials, auth_config.token):
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Invalid API key",
                    headers={"WWW-Authenticate": "Bearer"},
                )
            return True

        return check_bearer_auth

    if auth_config.type == "http_basic":
        basic_credentials = Annotated[HTTPBasicCredentials, Depends(HTTPBasic())]

        async def check_basic_auth(credentials: basic_credentials) -> bool:
            # Evaluate both comparisons unconditionally so the response time does
            # not reveal whether the username alone was correct.
            username_ok = _constant_time_equals(credentials.username, auth_config.username)
            password_ok = _constant_time_equals(credentials.password, auth_config.password)
            if not (username_ok and password_ok):
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Invalid Username or Password",
                    headers={"WWW-Authenticate": "Basic"},
                )
            return True

        return check_basic_auth

    raise ValueError(f"Invalid auth type: {auth_config.type}")
|
|
|
|
|
async def _run_scan_with_timeout(scan_func, timeout, *scan_args):
    """Run a blocking scan off the event loop, waiting at most ``timeout`` seconds.

    The scan runs on the loop's default executor. A per-request
    ``with ThreadPoolExecutor()`` block is deliberately avoided: leaving such a
    block waits for the worker thread, which would block the event loop until the
    scan finished and defeat the timeout.

    Args:
        scan_func: Blocking callable to execute (``scan_prompt`` or ``scan_output``).
        timeout: Maximum number of seconds to wait for the result.
        *scan_args: Positional arguments forwarded to ``scan_func``.

    Returns:
        Whatever ``scan_func`` returns.

    Raises:
        HTTPException: 408 when the scan does not finish within ``timeout``. The
            worker thread cannot be interrupted; only the wait is abandoned.
    """
    loop = asyncio.get_running_loop()
    try:
        return await asyncio.wait_for(
            loop.run_in_executor(None, scan_func, *scan_args),
            timeout=timeout,
        )
    except asyncio.TimeoutError as exc:
        raise HTTPException(
            status_code=status.HTTP_408_REQUEST_TIMEOUT, detail="Request timeout."
        ) from exc


def _record_scanner_validity(source, results_valid):
    """Increment the ``scanners.valid`` counter once per scanner result."""
    for scanner, valid in results_valid.items():
        scanners_valid_counter.add(1, {"source": source, "valid": valid, "scanner": scanner})


def register_routes(
    app: FastAPI, cache: InMemoryCache, input_scanners: list, output_scanners: list
):
    """Attach middleware, routes and exception handlers to ``app``.

    Args:
        app: Application to configure.
        cache: Cache of prompt analysis results, keyed by the raw prompt text.
        input_scanners: Scanners passed to ``scan_prompt`` for ``/analyze/prompt``.
        output_scanners: Scanners passed to ``scan_output`` for ``/analyze/output``.
    """
    # NOTE(review): wildcard origins combined with allow_credentials=True is very
    # permissive — confirm this is intended for the deployment.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["Authorization", "Content-Type"],
    )

    # The limiter is always created (the exempt decorators below need it); the
    # middleware that enforces it is only installed when rate limiting is enabled.
    limiter = Limiter(key_func=get_remote_address, default_limits=[config.rate_limit.limit])
    app.state.limiter = limiter
    app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
    if config.rate_limit.enabled:
        app.add_middleware(SlowAPIMiddleware)

    check_auth = _check_auth_function(config.auth)

    # Route handlers are documented with comments rather than docstrings so the
    # generated OpenAPI descriptions stay exactly as they were.

    @app.get("/", tags=["Main"])
    @limiter.exempt
    async def read_root():
        return {"name": "LLM Guard API"}

    @app.get("/healthz", tags=["Health"])
    @limiter.exempt
    async def healthcheck():
        return JSONResponse({"status": "alive"})

    @app.get("/readyz", tags=["Health"])
    @limiter.exempt
    async def liveliness():
        return JSONResponse({"status": "ready"})

    @app.post(
        "/analyze/output",
        tags=["Analyze"],
        response_model=AnalyzeOutputResponse,
        status_code=status.HTTP_200_OK,
        description="Analyze an output and return the sanitized output and the results of the scanners",
    )
    async def analyze_output(
        request: AnalyzeOutputRequest, _: Annotated[bool, Depends(check_auth)]
    ) -> AnalyzeOutputResponse:
        LOGGER.debug("Received analyze output request", request=request)

        start_time = time.perf_counter()
        sanitized_output, results_valid, results_score = await _run_scan_with_timeout(
            scan_output,
            config.app.scan_output_timeout,
            output_scanners,
            request.prompt,
            request.output,
            config.app.scan_fail_fast,
        )
        _record_scanner_validity("output", results_valid)

        result = AnalyzeOutputResponse(
            sanitized_output=sanitized_output,
            is_valid=all(results_valid.values()),
            scanners=results_score,
        )
        elapsed_time = time.perf_counter() - start_time
        LOGGER.debug(
            "Sanitized response",
            scores=results_score,
            elapsed_time_seconds=round(elapsed_time, 6),
        )
        return result

    @app.post(
        "/analyze/prompt",
        tags=["Analyze"],
        response_model=AnalyzePromptResponse,
        status_code=status.HTTP_200_OK,
        description="Analyze a prompt and return the sanitized prompt and the results of the scanners",
    )
    async def analyze_prompt(
        request: AnalyzePromptRequest,
        _: Annotated[bool, Depends(check_auth)],
        response: Response,
    ) -> AnalyzePromptResponse:
        LOGGER.debug("Received analyze prompt request", request=request)

        # `response` is the injected Response and is only used to set headers;
        # the payload lives in `result` so the parameter is never rebound.
        cached_result = cache.get(request.prompt)
        if cached_result:
            LOGGER.debug("Response was found in cache")
            response.headers["X-Cache-Hit"] = "true"
            return AnalyzePromptResponse(**cached_result)

        response.headers["X-Cache-Hit"] = "false"

        start_time = time.perf_counter()
        sanitized_prompt, results_valid, results_score = await _run_scan_with_timeout(
            scan_prompt,
            config.app.scan_prompt_timeout,
            input_scanners,
            request.prompt,
            config.app.scan_fail_fast,
        )
        _record_scanner_validity("input", results_valid)

        result = AnalyzePromptResponse(
            sanitized_prompt=sanitized_prompt,
            is_valid=all(results_valid.values()),
            scanners=results_score,
        )
        cache.set(request.prompt, result.dict())

        elapsed_time = time.perf_counter() - start_time
        LOGGER.debug(
            "Sanitized prompt response returned",
            scores=results_score,
            elapsed_time_seconds=round(elapsed_time, 6),
        )
        return result

    if config.metrics and config.metrics.exporter == "prometheus":

        # NOTE: inside this function the handler name shadows the module-level
        # `metrics` import; the name is kept so the route's operation id is unchanged.
        @app.get("/metrics", tags=["Metrics"])
        @limiter.exempt
        async def metrics():
            return Response(
                content=generate_latest(REGISTRY), headers={"Content-Type": CONTENT_TYPE_LATEST}
            )

    @app.on_event("shutdown")
    async def shutdown_event():
        LOGGER.info("Shutting down app...")

    @app.exception_handler(StarletteHTTPException)
    async def http_exception_handler(request, exc):
        LOGGER.warning(
            "HTTP exception", exception_status_code=exc.status_code, exception_detail=exc.detail
        )

        # Forward any headers carried by the exception (e.g. WWW-Authenticate on a
        # 401) instead of silently dropping them.
        return JSONResponse(
            {"message": str(exc.detail), "details": None},
            status_code=exc.status_code,
            headers=getattr(exc, "headers", None),
        )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(request, exc):
        LOGGER.warning("Invalid request", exception=str(exc))

        payload = {"message": "Validation failed", "details": exc.errors()}
        return JSONResponse(
            jsonable_encoder(payload), status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
        )
|
|
|
|
|
# Build the application at import time; `run_app` below hands this module-level
# `app` object to uvicorn.
app = create_app()
# Instrumentation hook from the local `.otel` module (presumably attaches
# OpenTelemetry instrumentation to the app — confirm in .otel).
instrument_app(app)
|
|
|
|
|
def run_app():
    """Serve the module-level ``app`` with uvicorn.

    The port and log level come from the loaded configuration; the remaining
    server options are fixed here.
    """
    # uvicorn is imported inside the function, so it is only loaded when
    # run_app() is actually called.
    from uvicorn import run as serve

    server_options = {
        "host": "0.0.0.0",
        "port": config.app.port,
        "server_header": False,
        "log_level": log_level.lower(),
        "proxy_headers": True,
        "forwarded_allow_ips": "*",
        "timeout_keep_alive": 2,
    }
    serve(app, **server_options)
|
|