import asyncio
import json
import logging
import math
import os
import secrets
import time
import uuid
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse

import psutil
from crawl4ai import AsyncWebCrawler, CacheMode, CrawlResult
from crawl4ai.config import MIN_WORD_THRESHOLD
from crawl4ai.extraction_strategy import (
    CosineStrategy,
    JsonCssExtractionStrategy,
    LLMExtractionStrategy,
)
from fastapi import (
    BackgroundTasks,
    Depends,
    FastAPI,
    HTTPException,
    Request,
    Security,
)
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel, Field, HttpUrl
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import FileResponse
# Absolute, symlink-resolved directory of this module. A relative ``__file__``
# is anchored at the current working directory before being resolved.
__location__ = os.path.realpath(
    os.path.join(os.getcwd(), os.path.dirname(__file__))
)

# Configure root logging at INFO and create the module-level logger used by
# the service and the endpoints below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TaskStatus(str, Enum):
    """Lifecycle states of a queued crawl task.

    Members are also ``str`` instances, so they compare equal to their raw
    values (``TaskStatus.PENDING == "pending"``).
    """

    # Registered and waiting in one of the priority queues.
    PENDING = "pending"
    # Picked up by the queue processor; crawl in flight.
    PROCESSING = "processing"
    # Crawl finished and a result is stored on the task.
    COMPLETED = "completed"
    # Crawl raised; the error message is stored on the task.
    FAILED = "failed"
class CrawlerType(str, Enum):
    """Extraction back-ends a client can select through ``ExtractionConfig``.

    Members are also ``str`` instances and compare equal to their raw values.
    """

    # No extraction strategy is built for this value.
    BASIC = "basic"
    # Mapped to ``LLMExtractionStrategy``.
    LLM = "llm"
    # Mapped to ``CosineStrategy``.
    COSINE = "cosine"
    # Mapped to ``JsonCssExtractionStrategy``.
    JSON_CSS = "json_css"
class ExtractionConfig(BaseModel):
    """Selects an extraction strategy and carries its constructor arguments."""

    # Which strategy family to build.
    type: CrawlerType
    # Keyword arguments splatted into the chosen strategy's constructor.
    params: Dict[str, Any] = {}
class ChunkingStrategy(BaseModel):
    """Request-side description of a chunking strategy.

    NOTE(review): the service code in this file accepts this model on
    ``CrawlRequest`` but never reads it -- confirm whether it is consumed
    elsewhere.
    """

    # Free-form strategy identifier.
    type: str
    # Strategy-specific options.
    params: Dict[str, Any] = {}
class ContentFilter(BaseModel):
    """Request-side description of a content filter.

    NOTE(review): the service code in this file accepts this model on
    ``CrawlRequest`` but never reads it -- confirm whether it is consumed
    elsewhere.
    """

    # Filter identifier; "bm25" unless the client overrides it.
    type: str = "bm25"
    # Filter-specific options.
    params: Dict[str, Any] = {}
class CrawlRequest(BaseModel):
    """Payload accepted by the crawl endpoints."""

    # --- what to crawl ------------------------------------------------------
    # A single URL or a list of URLs; a list selects the multi-URL crawl call.
    urls: Union[HttpUrl, List[HttpUrl]]

    # --- content handling ---------------------------------------------------
    # Word-count threshold; defaults to crawl4ai's MIN_WORD_THRESHOLD.
    word_count_threshold: int = MIN_WORD_THRESHOLD
    extraction_config: Optional[ExtractionConfig] = None
    chunking_strategy: Optional[ChunkingStrategy] = None
    content_filter: Optional[ContentFilter] = None

    # --- options handed to the crawler's arun / arun_many call --------------
    js_code: Optional[List[str]] = None
    wait_for: Optional[str] = None
    css_selector: Optional[str] = None
    screenshot: bool = False
    magic: bool = False
    # Additional keyword arguments splatted into the crawl call.
    extra: Optional[Dict[str, Any]] = {}
    session_id: Optional[str] = None
    cache_mode: Optional[CacheMode] = CacheMode.ENABLED

    # --- scheduling -----------------------------------------------------------
    # 1..10; values above 5 are placed on the high-priority queue.
    priority: int = Field(default=5, ge=1, le=10)
    # Seconds a finished task is retained before the cleanup loop may drop it.
    ttl: Optional[int] = 3600

    # Keyword arguments for constructing the AsyncWebCrawler itself.
    crawler_params: Dict[str, Any] = {}
@dataclass
class TaskInfo:
    """Book-keeping record for one submitted crawl task.

    Attributes:
        id: Task identifier handed back to the client.
        status: Current lifecycle state.
        result: Crawl result (or list of results) once the task completed.
        error: Error message once the task failed.
        created_at: ``time.time()`` timestamp taken when the record is created.
        ttl: Seconds after ``created_at`` at which a finished task may be
            purged by the cleanup loop.
        request: The originating crawl request; attached by the service after
            the record has been created.
    """

    id: str
    status: TaskStatus
    result: Optional[Union[CrawlResult, List[CrawlResult]]] = None
    error: Optional[str] = None
    # A default_factory stamps each instance at creation. A plain
    # ``= time.time()`` default is evaluated once, when the class body runs,
    # so every task would share the module-import timestamp and TTL expiry
    # would be measured from process start.
    created_at: float = field(default_factory=time.time)
    ttl: int = 3600
    # Declared explicitly because the service assigns ``task.request`` after
    # construction; kept last so positional construction stays compatible.
    request: Optional["CrawlRequest"] = None
class ResourceMonitor:
    """Estimate how many crawl slots the host can currently afford.

    The estimate scales ``max_concurrent_tasks`` by the tighter of the memory
    and CPU headroom factors and is cached for ``_check_interval`` seconds.
    """

    def __init__(self, max_concurrent_tasks: int = 10):
        """Store the slot ceiling and the fixed usage thresholds."""
        self.max_concurrent_tasks = max_concurrent_tasks
        # Usage fractions at (or above) which the matching factor drops to 0.
        self.memory_threshold = 0.85
        self.cpu_threshold = 0.90
        # Cache state: time of the last sample, refresh period, last answer.
        self._last_check = 0
        self._check_interval = 1
        self._last_available_slots = max_concurrent_tasks

    async def get_available_slots(self) -> int:
        """Return the slot estimate, re-sampling psutil when the cache is stale."""
        now = time.time()
        cache_is_fresh = now - self._last_check < self._check_interval
        if not cache_is_fresh:
            memory_fraction = psutil.virtual_memory().percent / 100
            cpu_fraction = psutil.cpu_percent() / 100

            # Each factor is the unused share of its threshold, floored at 0.
            memory_headroom = max(
                0, (self.memory_threshold - memory_fraction) / self.memory_threshold
            )
            cpu_headroom = max(
                0, (self.cpu_threshold - cpu_fraction) / self.cpu_threshold
            )

            self._last_available_slots = math.floor(
                self.max_concurrent_tasks * min(memory_headroom, cpu_headroom)
            )
            self._last_check = now
        return self._last_available_slots
class TaskManager:
    """In-memory task registry with two priority queues and TTL cleanup."""

    def __init__(self, cleanup_interval: int = 300):
        """Create empty queues; ``cleanup_interval`` is the purge period in seconds."""
        self.tasks: Dict[str, TaskInfo] = {}
        self.high_priority = asyncio.PriorityQueue()
        self.low_priority = asyncio.PriorityQueue()
        self.cleanup_interval = cleanup_interval
        self.cleanup_task = None

    async def start(self):
        """Launch the background cleanup loop."""
        self.cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def stop(self):
        """Cancel the cleanup loop, if running, and wait for it to finish."""
        if not self.cleanup_task:
            return
        self.cleanup_task.cancel()
        try:
            await self.cleanup_task
        except asyncio.CancelledError:
            pass

    async def add_task(self, task_id: str, priority: int, ttl: int) -> None:
        """Register a PENDING task and enqueue it according to ``priority``."""
        self.tasks[task_id] = TaskInfo(
            id=task_id, status=TaskStatus.PENDING, ttl=ttl
        )
        if priority > 5:
            target = self.high_priority
        else:
            target = self.low_priority
        # Negated so that a larger priority sorts first in the min-queue.
        await target.put((-priority, task_id))

    async def get_next_task(self) -> Optional[str]:
        """Return the next task id, preferring the high-priority queue.

        Each queue is polled for up to 0.1 s; ``None`` means both were empty.
        """
        for source in (self.high_priority, self.low_priority):
            try:
                _, task_id = await asyncio.wait_for(source.get(), timeout=0.1)
            except asyncio.TimeoutError:
                continue
            return task_id
        return None

    def update_task(
        self, task_id: str, status: TaskStatus, result: Any = None, error: str = None
    ):
        """Overwrite status, result and error of a known task; ignore unknown ids."""
        try:
            record = self.tasks[task_id]
        except KeyError:
            return
        record.status = status
        record.result = result
        record.error = error

    def get_task(self, task_id: str) -> Optional[TaskInfo]:
        """Return the task record, or ``None`` when the id is unknown."""
        return self.tasks.get(task_id)

    async def _cleanup_loop(self):
        """Periodically drop finished tasks whose TTL has elapsed."""
        while True:
            try:
                await asyncio.sleep(self.cleanup_interval)
                now = time.time()
                # Collect ids first so the dict is not mutated while iterating.
                stale_ids = []
                for task_id, task in self.tasks.items():
                    if now - task.created_at > task.ttl and task.status in [
                        TaskStatus.COMPLETED,
                        TaskStatus.FAILED,
                    ]:
                        stale_ids.append(task_id)
                for task_id in stale_ids:
                    del self.tasks[task_id]
            except Exception as e:
                logger.error(f"Error in cleanup loop: {e}")
class CrawlerPool:
    """Bounded pool of entered ``AsyncWebCrawler`` instances keyed by last use."""

    def __init__(self, max_size: int = 10):
        """Create an empty pool that holds at most ``max_size`` crawlers."""
        self.max_size = max_size
        # crawler -> time.time() of its most recent acquire/release
        self.active_crawlers: Dict[AsyncWebCrawler, float] = {}
        self._lock = asyncio.Lock()

    async def acquire(self, **kwargs) -> AsyncWebCrawler:
        """Return a crawler, creating one while the pool has room.

        Crawlers untouched for more than 600 seconds are closed first. When
        the pool is full, the least recently used crawler is handed out again
        and ``kwargs`` are not applied to it.
        """
        async with self._lock:
            now = time.time()

            # Evict crawlers that have been idle for longer than ten minutes.
            stale = [
                candidate
                for candidate, last_used in self.active_crawlers.items()
                if now - last_used > 600
            ]
            for candidate in stale:
                await candidate.__aexit__(None, None, None)
                del self.active_crawlers[candidate]

            if len(self.active_crawlers) >= self.max_size:
                # Pool exhausted: reuse the entry with the oldest timestamp.
                crawler = min(self.active_crawlers, key=self.active_crawlers.get)
            else:
                crawler = AsyncWebCrawler(**kwargs)
                await crawler.__aenter__()

            self.active_crawlers[crawler] = now
            return crawler

    async def release(self, crawler: AsyncWebCrawler):
        """Refresh the last-used timestamp of a pooled crawler."""
        async with self._lock:
            if crawler not in self.active_crawlers:
                return
            self.active_crawlers[crawler] = time.time()

    async def cleanup(self):
        """Close every pooled crawler and empty the pool."""
        async with self._lock:
            for crawler in list(self.active_crawlers):
                await crawler.__aexit__(None, None, None)
            self.active_crawlers.clear()
class CrawlerService:
    """Queue-backed crawl service: accepts requests and runs them in the background.

    Tasks are taken from the ``TaskManager`` queues one at a time by a single
    background coroutine, executed on a crawler from the ``CrawlerPool``, and
    their outcome is written back to the task record.
    """

    def __init__(
        self, max_concurrent_tasks: int = 10, enforce_resource_limits: bool = False
    ):
        """Wire up the monitor, task manager and crawler pool.

        Args:
            max_concurrent_tasks: Slot ceiling for the resource monitor and
                maximum size of the crawler pool.
            enforce_resource_limits: When true, the queue processor pauses
                while the resource monitor reports no available slots. Off by
                default, matching the previously hard-disabled check.
        """
        self.resource_monitor = ResourceMonitor(max_concurrent_tasks)
        self.task_manager = TaskManager()
        self.crawler_pool = CrawlerPool(max_concurrent_tasks)
        self.enforce_resource_limits = enforce_resource_limits
        self._processing_task = None

    async def start(self):
        """Start task cleanup and the background queue processor."""
        await self.task_manager.start()
        self._processing_task = asyncio.create_task(self._process_queue())

    async def stop(self):
        """Cancel the queue processor, then stop cleanup and close all crawlers."""
        if self._processing_task:
            self._processing_task.cancel()
            try:
                await self._processing_task
            except asyncio.CancelledError:
                pass
        await self.task_manager.stop()
        await self.crawler_pool.cleanup()

    def _create_extraction_strategy(self, config: ExtractionConfig):
        """Build the extraction strategy named by ``config``.

        Returns ``None`` when no config is given or its type has no strategy
        class (``CrawlerType.BASIC``).
        """
        if not config:
            return None
        strategy_classes = {
            CrawlerType.LLM: LLMExtractionStrategy,
            CrawlerType.COSINE: CosineStrategy,
            CrawlerType.JSON_CSS: JsonCssExtractionStrategy,
        }
        strategy_cls = strategy_classes.get(config.type)
        if strategy_cls is None:
            return None
        return strategy_cls(**config.params)

    async def submit_task(self, request: CrawlRequest) -> str:
        """Register ``request`` as a new task and return its id."""
        task_id = str(uuid.uuid4())
        await self.task_manager.add_task(task_id, request.priority, request.ttl or 3600)
        # The task record is created by add_task; attach the request so the
        # queue processor can find it.
        self.task_manager.tasks[task_id].request = request
        return task_id

    async def _run_crawl(self, crawler: AsyncWebCrawler, request: CrawlRequest):
        """Execute ``request`` on ``crawler`` and return its result(s).

        A list of URLs goes through ``arun_many`` and yields a list of
        results; a single URL goes through ``arun`` and yields one result.
        """
        crawl_kwargs = dict(
            # Honour the caller's threshold on both code paths.
            word_count_threshold=request.word_count_threshold,
            extraction_strategy=self._create_extraction_strategy(
                request.extraction_config
            ),
            js_code=request.js_code,
            wait_for=request.wait_for,
            css_selector=request.css_selector,
            screenshot=request.screenshot,
            magic=request.magic,
            session_id=request.session_id,
            cache_mode=request.cache_mode,
            # ``extra`` is Optional on the request model; tolerate None.
            **(request.extra or {}),
        )
        if isinstance(request.urls, list):
            return await crawler.arun_many(
                urls=[str(url) for url in request.urls], **crawl_kwargs
            )
        return await crawler.arun(url=str(request.urls), **crawl_kwargs)

    async def _process_task(self, task_id: str, task_info: TaskInfo) -> None:
        """Run one dequeued task and record COMPLETED or FAILED on it."""
        self.task_manager.update_task(task_id, TaskStatus.PROCESSING)
        try:
            request = getattr(task_info, "request", None)
            if request is None:
                # Already dequeued, so it must be failed explicitly or it
                # would stay unfinished forever.
                raise RuntimeError(f"Task {task_id} has no request attached")
            crawler = await self.crawler_pool.acquire(**request.crawler_params)
            try:
                results = await self._run_crawl(crawler, request)
            finally:
                # Release on failure too, so the pool's timestamp stays current.
                await self.crawler_pool.release(crawler)
            self.task_manager.update_task(task_id, TaskStatus.COMPLETED, results)
        except Exception as e:
            logger.error(f"Error processing task {task_id}: {str(e)}")
            self.task_manager.update_task(task_id, TaskStatus.FAILED, error=str(e))

    async def _process_queue(self):
        """Background loop: pull tasks from the queues and run them in turn."""
        while True:
            try:
                if self.enforce_resource_limits:
                    available_slots = await self.resource_monitor.get_available_slots()
                    if available_slots <= 0:
                        await asyncio.sleep(1)
                        continue

                task_id = await self.task_manager.get_next_task()
                if not task_id:
                    await asyncio.sleep(1)
                    continue

                task_info = self.task_manager.get_task(task_id)
                if not task_info:
                    # Record vanished between enqueue and dequeue.
                    continue

                await self._process_task(task_id, task_info)
            except Exception as e:
                logger.error(f"Error in queue processing: {str(e)}")
                await asyncio.sleep(1)
app = FastAPI(title="Crawl4AI API")

# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; consider restricting `origins` before exposing this publicly.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Bearer-token scheme used by the protected endpoints.
security = HTTPBearer()

# Shared API token, read once from the environment. The rest of the module
# references this name at import time (route decorators), so it must be
# defined here. When it is unset or empty the endpoints are registered
# without the authentication dependency.
CRAWL4AI_API_TOKEN = os.getenv("CRAWL4AI_API_TOKEN")
async def verify_token(credentials: HTTPAuthorizationCredentials = Security(security)):
    """Validate the bearer token against ``CRAWL4AI_API_TOKEN``.

    Returns the credentials unchanged when no API token is configured or when
    the supplied token matches.

    Raises:
        HTTPException: 401 when a token is configured and the supplied one
            does not match.
    """
    if not CRAWL4AI_API_TOKEN:
        # Authentication is disabled when no token is configured.
        return credentials
    # Compare as bytes in constant time: avoids leaking the match position
    # through timing and accepts non-ASCII input.
    supplied = credentials.credentials.encode("utf-8")
    expected = CRAWL4AI_API_TOKEN.encode("utf-8")
    if not secrets.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Invalid token")
    return credentials
def secure_endpoint():
    """Return ``Depends(verify_token)`` when an API token is configured, else ``None``."""
    if CRAWL4AI_API_TOKEN:
        return Depends(verify_token)
    return None
# Serve the pre-built documentation site when it ships next to this module.
# The mount uses the same absolute path as the existence check; a bare "site"
# would be resolved against the current working directory instead and could
# point at a different (or missing) directory.
if os.path.exists(__location__ + "/site"):
    app.mount(
        "/mkdocs",
        StaticFiles(directory=__location__ + "/site", html=True),
        name="mkdocs",
    )

# Template loader rooted at the same site directory.
site_templates = Jinja2Templates(directory=__location__ + "/site")

# Single service instance shared by every endpoint below.
crawler_service = CrawlerService()
@app.on_event("startup")
async def startup_event():
    """Start the shared crawler service when the application boots."""
    await crawler_service.start()
@app.on_event("shutdown")
async def shutdown_event():
    """Stop the shared crawler service when the application shuts down."""
    await crawler_service.stop()
@app.get("/")
async def root():
    # The bare root has no content of its own; send visitors to "/docs".
    docs_redirect = RedirectResponse(url="/docs")
    return docs_redirect
@app.post("/crawl", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
async def crawl(request: CrawlRequest) -> Dict[str, str]:
    # Queue the crawl and return immediately; the id can be polled through
    # GET /task/{task_id}.
    return {"task_id": await crawler_service.submit_task(request)}
@app.get(
    "/task/{task_id}", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []
)
async def get_task_status(task_id: str):
    # Report a task's status plus its result(s) or error once it has finished.
    # Unknown ids yield a 404.
    task_info = crawler_service.task_manager.get_task(task_id)
    if not task_info:
        raise HTTPException(status_code=404, detail="Task not found")

    response = {"status": task_info.status, "created_at": task_info.created_at}

    if task_info.status == TaskStatus.FAILED:
        response["error"] = task_info.error
    elif task_info.status == TaskStatus.COMPLETED:
        outcome = task_info.result
        # Multi-URL crawls store a list; single-URL crawls store one result.
        if isinstance(outcome, list):
            response["results"] = [item.dict() for item in outcome]
        else:
            response["result"] = outcome.dict()

    return response
@app.post("/crawl_sync", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else [])
async def crawl_sync(request: CrawlRequest) -> Dict[str, Any]:
    # Submit the crawl, then poll once per second (at most 60 times) until the
    # task finishes. Failure maps to 500, running out of polls to 408.
    task_id = await crawler_service.submit_task(request)

    for _attempt in range(60):
        task_info = crawler_service.task_manager.get_task(task_id)
        if not task_info:
            raise HTTPException(status_code=404, detail="Task not found")

        status = task_info.status
        if status == TaskStatus.COMPLETED:
            outcome = task_info.result
            if isinstance(outcome, list):
                serialized = [item.dict() for item in outcome]
                return {"status": status, "results": serialized}
            return {"status": status, "result": outcome.dict()}
        if status == TaskStatus.FAILED:
            raise HTTPException(status_code=500, detail=task_info.error)

        await asyncio.sleep(1)

    raise HTTPException(status_code=408, detail="Task timed out")
@app.post(
    "/crawl_direct", dependencies=[secure_endpoint()] if CRAWL4AI_API_TOKEN else []
)
async def crawl_direct(request: CrawlRequest) -> Dict[str, Any]:
    # Run the crawl inline, bypassing the task queue, and return the
    # serialized result(s). Any failure is reported as a 500 whose detail is
    # the exception text.
    logger.info("Received request to crawl directly.")
    try:
        logger.debug("Acquiring crawler from the crawler pool.")
        crawler = await crawler_service.crawler_pool.acquire(**request.crawler_params)
        logger.debug("Crawler acquired successfully.")
        try:
            # Built inside the try so that a failing strategy constructor
            # still releases the crawler.
            logger.debug(
                "Creating extraction strategy based on the request configuration."
            )
            extraction_strategy = crawler_service._create_extraction_strategy(
                request.extraction_config
            )
            logger.debug("Extraction strategy created successfully.")

            crawl_kwargs = dict(
                # Forward the caller's threshold instead of dropping it.
                word_count_threshold=request.word_count_threshold,
                extraction_strategy=extraction_strategy,
                js_code=request.js_code,
                wait_for=request.wait_for,
                css_selector=request.css_selector,
                screenshot=request.screenshot,
                magic=request.magic,
                cache_mode=request.cache_mode,
                session_id=request.session_id,
                # ``extra`` is Optional on the request model; tolerate None.
                **(request.extra or {}),
            )

            if isinstance(request.urls, list):
                logger.info("Processing multiple URLs.")
                results = await crawler.arun_many(
                    urls=[str(url) for url in request.urls], **crawl_kwargs
                )
                logger.info("Crawling completed for multiple URLs.")
                return {"results": [result.dict() for result in results]}

            logger.info("Processing a single URL.")
            result = await crawler.arun(url=str(request.urls), **crawl_kwargs)
            logger.info("Crawling completed for a single URL.")
            return {"result": result.dict()}
        finally:
            logger.debug("Releasing crawler back to the pool.")
            await crawler_service.crawler_pool.release(crawler)
            logger.debug("Crawler released successfully.")
    except Exception as e:
        logger.error(f"Error in direct crawl: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/health")
async def health_check():
    # Liveness probe: reports the monitor's slot estimate alongside the raw
    # memory and CPU percentages from psutil.
    slots = await crawler_service.resource_monitor.get_available_slots()
    memory_percent = psutil.virtual_memory().percent
    cpu_percent = psutil.cpu_percent()
    return {
        "status": "healthy",
        "available_slots": slots,
        "memory_usage": memory_percent,
        "cpu_usage": cpu_percent,
    }
if __name__ == "__main__":
    import uvicorn

    # Spell out the bind address: an empty host string leaves the choice of
    # interface implicit. "0.0.0.0" listens on every IPv4 interface.
    uvicorn.run(app, host="0.0.0.0", port=11235)