c4ai / tests /test_main.py
amaye15
test
03c0888
import asyncio
import aiohttp
import json
import time
import os
from typing import Optional, Dict, Any
from pydantic import BaseModel, HttpUrl
class NBCNewsAPITest:
    """Small async client for the crawl service exercised by the checks in this file.

    It talks to three endpoints: ``POST /crawl``, ``GET /task/{task_id}`` and
    ``GET /health``. Use it as an async context manager so the HTTP session is
    opened on entry and closed on exit::

        async with NBCNewsAPITest() as api:
            health = await api.check_health()
    """

    def __init__(self, base_url: str = "http://localhost:8000"):
        """Store the service root URL; the HTTP session is created in ``__aenter__``.

        Args:
            base_url: Root URL of the crawl service. A trailing slash is
                stripped so endpoint paths are not joined as ``...//crawl``.
        """
        self.base_url = base_url.rstrip("/")
        # aiohttp.ClientSession while inside ``async with``; None otherwise.
        self.session = None

    async def __aenter__(self):
        """Open the HTTP session and return this client."""
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the HTTP session if one was opened; exceptions are not suppressed."""
        if self.session is not None:
            await self.session.close()
            self.session = None

    @staticmethod
    async def _json_or_raise(response) -> Dict[str, Any]:
        """Return the decoded JSON body, or raise if the server answered 4xx/5xx.

        Raises:
            RuntimeError: on an HTTP error status. The message includes the
                method, URL, status and raw body text, so server-side error
                details are not lost.
        """
        if response.status >= 400:
            body = await response.text()
            raise RuntimeError(
                f"{response.method} {response.url} returned HTTP {response.status}: {body}"
            )
        return await response.json()

    async def submit_crawl(self, request_data: Dict[str, Any]) -> str:
        """POST ``request_data`` to ``/crawl`` and return the ``task_id`` from the reply.

        Raises:
            RuntimeError: if the server answers with an HTTP error status.
            KeyError: if the reply has no ``task_id`` field.
        """
        async with self.session.post(f"{self.base_url}/crawl", json=request_data) as response:
            result = await self._json_or_raise(response)
        return result["task_id"]

    async def get_task_status(self, task_id: str) -> Dict[str, Any]:
        """Fetch the current status payload of ``task_id`` from ``/task/{task_id}``.

        Raises:
            RuntimeError: if the server answers with an HTTP error status
                (for example an unknown task id).
        """
        async with self.session.get(f"{self.base_url}/task/{task_id}") as response:
            return await self._json_or_raise(response)

    async def wait_for_task(self, task_id: str, timeout: float = 300, poll_interval: float = 2) -> Dict[str, Any]:
        """Poll the task until its ``status`` is ``"completed"`` or ``"failed"``.

        Args:
            task_id: Id returned by :meth:`submit_crawl`.
            timeout: Maximum number of seconds to keep polling.
            poll_interval: Seconds to sleep between polls.

        Returns:
            The last status payload, whose ``status`` is terminal.

        Raises:
            TimeoutError: if no terminal status is seen within ``timeout`` seconds.
        """
        # monotonic() is immune to wall-clock jumps (NTP sync, manual changes)
        # that could cut a wait short or stretch it during a long poll.
        deadline = time.monotonic() + timeout
        while True:
            if time.monotonic() > deadline:
                raise TimeoutError(f"Task {task_id} did not complete within {timeout} seconds")
            status = await self.get_task_status(task_id)
            if status["status"] in ("completed", "failed"):
                return status
            await asyncio.sleep(poll_interval)

    async def check_health(self) -> Dict[str, Any]:
        """Return the JSON payload of ``GET /health``.

        Raises:
            RuntimeError: if the server answers with an HTTP error status.
        """
        async with self.session.get(f"{self.base_url}/health") as response:
            return await self._json_or_raise(response)
async def test_basic_crawl():
    """Crawl one page with default options and check that the crawl succeeded.

    Talks to a live service through ``NBCNewsAPITest`` with its default base URL.
    """
    print("\n=== Testing Basic Crawl ===")
    async with NBCNewsAPITest() as api:
        request = {
            "urls": "https://www.nbcnews.com/business",
            "priority": 10,
        }
        task_id = await api.submit_crawl(request)
        result = await api.wait_for_task(task_id)
        # Check the status before touching result["result"]: a failed task may
        # not carry a result payload, and a KeyError would hide the real error.
        assert result["status"] == "completed", f"Task {task_id} did not complete: {result}"
        assert "result" in result, f"Completed task {task_id} has no result: {result}"
        assert result["result"]["success"], f"Crawl reported failure: {result['result']}"
        print(f"Basic crawl result length: {len(result['result']['markdown'])}")
async def test_js_execution():
    """Crawl with injected JavaScript and a ``wait_for`` selector; check success.

    The script clicks a button whose text contains "Load More", if one exists.
    The request then asks the crawler to wait for a tenth ``article.tease-card``.
    """
    print("\n=== Testing JS Execution ===")
    async with NBCNewsAPITest() as api:
        request = {
            "urls": "https://www.nbcnews.com/business",
            "priority": 8,
            "js_code": [
                "const loadMoreButton = Array.from(document.querySelectorAll('button')).find(button => button.textContent.includes('Load More')); loadMoreButton && loadMoreButton.click();"
            ],
            "wait_for": "article.tease-card:nth-child(10)",
            "crawler_params": {
                "headless": True,
            },
        }
        task_id = await api.submit_crawl(request)
        result = await api.wait_for_task(task_id)
        # Assert before printing so a failed task reports its real status
        # instead of crashing on a missing result payload.
        assert result["status"] == "completed", f"Task {task_id} did not complete: {result}"
        assert result["result"]["success"], f"Crawl reported failure: {result['result']}"
        print(f"JS execution result length: {len(result['result']['markdown'])}")
async def test_css_selector():
    """Crawl with a CSS selector that restricts the captured content; check success."""
    print("\n=== Testing CSS Selector ===")
    async with NBCNewsAPITest() as api:
        request = {
            "urls": "https://www.nbcnews.com/business",
            "priority": 7,
            "css_selector": ".wide-tease-item__description",
        }
        task_id = await api.submit_crawl(request)
        result = await api.wait_for_task(task_id)
        # Assert before printing so a failed task reports its real status
        # instead of crashing on a missing result payload.
        assert result["status"] == "completed", f"Task {task_id} did not complete: {result}"
        assert result["result"]["success"], f"Crawl reported failure: {result['result']}"
        print(f"CSS selector result length: {len(result['result']['markdown'])}")
async def test_structured_extraction():
    """Run a ``json_css`` extraction over article cards; expect at least one item.

    The schema maps each ``article.tease-card`` to a title (``h2`` text), a
    description (``.tease-card__description`` text) and a link (``a`` href).
    """
    print("\n=== Testing Structured Extraction ===")
    async with NBCNewsAPITest() as api:
        schema = {
            "name": "NBC News Articles",
            "baseSelector": "article.tease-card",
            "fields": [
                {
                    "name": "title",
                    "selector": "h2",
                    "type": "text",
                },
                {
                    "name": "description",
                    "selector": ".tease-card__description",
                    "type": "text",
                },
                {
                    "name": "link",
                    "selector": "a",
                    "type": "attribute",
                    "attribute": "href",
                },
            ],
        }
        request = {
            "urls": "https://www.nbcnews.com/business",
            "priority": 9,
            "extraction_config": {
                "type": "json_css",
                "params": {
                    "schema": schema,
                },
            },
        }
        task_id = await api.submit_crawl(request)
        result = await api.wait_for_task(task_id)
        # Verify the task succeeded before decoding: on failure there may be
        # no extracted_content, and json.loads would mask the real error.
        assert result["status"] == "completed", f"Task {task_id} did not complete: {result}"
        assert result["result"]["success"], f"Crawl reported failure: {result['result']}"
        extracted = json.loads(result["result"]["extracted_content"])
        print(f"Extracted {len(extracted)} articles")
        assert len(extracted) > 0
async def test_batch_crawl():
    """Submit several URLs as one task and expect one result per URL."""
    print("\n=== Testing Batch Crawl ===")
    async with NBCNewsAPITest() as api:
        request = {
            "urls": [
                "https://www.nbcnews.com/business",
                "https://www.nbcnews.com/business/consumer",
                "https://www.nbcnews.com/business/economy",
            ],
            "priority": 6,
            "crawler_params": {
                "headless": True,
            },
        }
        task_id = await api.submit_crawl(request)
        result = await api.wait_for_task(task_id)
        # Check status and key presence before indexing result["results"].
        assert result["status"] == "completed", f"Task {task_id} did not complete: {result}"
        assert "results" in result, f"Completed task {task_id} has no results: {result}"
        print(f"Batch crawl completed, got {len(result['results'])} results")
        # Derive the expected count from the request so the two stay in sync.
        assert len(result["results"]) == len(request["urls"])
async def test_llm_extraction():
    """Run a schema-based LLM extraction on one page and print the decoded result.

    The API key is read from the ``OLLAMA_API_KEY`` environment variable.
    ``None`` is sent if it is unset.
    """
    print("\n=== Testing LLM Extraction with Ollama ===")
    async with NBCNewsAPITest() as api:
        schema = {
            "type": "object",
            "properties": {
                "article_title": {
                    "type": "string",
                    "description": "The main title of the news article",
                },
                "summary": {
                    "type": "string",
                    "description": "A brief summary of the article content",
                },
                "main_topics": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "Main topics or themes discussed in the article",
                },
            },
            "required": ["article_title", "summary", "main_topics"],
        }
        request = {
            "urls": "https://www.nbcnews.com/business",
            "priority": 8,
            "extraction_config": {
                "type": "llm",
                "params": {
                    # NOTE(review): the provider is OpenAI, but the key comes from
                    # OLLAMA_API_KEY and the banner says "Ollama" -- confirm which
                    # backend is intended.
                    "provider": "openai/gpt-4o-mini",
                    "api_key": os.getenv("OLLAMA_API_KEY"),
                    "schema": schema,
                    "extraction_type": "schema",
                    "instruction": (
                        "Extract the main article information including title, a brief summary, and main topics discussed.\n"
                        "Focus on the primary business news article on the page."
                    ),
                },
            },
            "crawler_params": {
                "headless": True,
                "word_count_threshold": 1,
            },
        }
        task_id = await api.submit_crawl(request)
        result = await api.wait_for_task(task_id)
        # Assert success first: decoding extracted_content of a failed crawl
        # would raise a less informative error than the status check.
        assert result["status"] == "completed", f"Task {task_id} did not complete: {result}"
        assert result["result"]["success"], f"Crawl reported failure: {result['result']}"
        extracted = json.loads(result["result"]["extracted_content"])
        print("Extracted article analysis:")
        print(json.dumps(extracted, indent=2))
async def test_screenshot():
    """Crawl with ``screenshot`` enabled and expect a screenshot in the result."""
    print("\n=== Testing Screenshot ===")
    async with NBCNewsAPITest() as api:
        request = {
            "urls": "https://www.nbcnews.com/business",
            "priority": 5,
            "screenshot": True,
            "crawler_params": {
                "headless": True,
            },
        }
        task_id = await api.submit_crawl(request)
        result = await api.wait_for_task(task_id)
        # Assert before printing so a failed task reports its real status
        # instead of crashing on a missing result payload.
        assert result["status"] == "completed", f"Task {task_id} did not complete: {result}"
        assert result["result"]["success"], f"Crawl reported failure: {result['result']}"
        print("Screenshot captured:", bool(result["result"]["screenshot"]))
        assert result["result"]["screenshot"] is not None
async def test_priority_handling():
    """Queue a low-priority crawl, then a high-priority one, and wait for both.

    Only completion of both tasks is asserted. The order in which the server
    actually ran them is not checked.
    """
    print("\n=== Testing Priority Handling ===")
    async with NBCNewsAPITest() as api:
        # Submission order matters: the low-priority job goes in first.
        submissions = (
            ("low", "https://www.nbcnews.com/business", 1),
            ("high", "https://www.nbcnews.com/business/consumer", 10),
        )
        task_ids = {}
        for label, url, level in submissions:
            task_ids[label] = await api.submit_crawl({
                "urls": url,
                "priority": level,
                "crawler_params": {"headless": True},
            })

        # Wait on the high-priority task first, then the low-priority one.
        high_outcome = await api.wait_for_task(task_ids["high"])
        low_outcome = await api.wait_for_task(task_ids["low"])
        print("Both tasks completed")
        assert high_outcome["status"] == "completed"
        assert low_outcome["status"] == "completed"
async def main():
    """Print the service health, then run the enabled checks in order.

    Any exception is reported on stdout and then re-raised to the caller.
    """
    try:
        async with NBCNewsAPITest() as api:
            print("Server health:", await api.check_health())

        # Currently disabled: test_basic_crawl, test_js_execution,
        # test_css_selector, test_structured_extraction, test_batch_crawl,
        # test_screenshot, test_priority_handling.
        enabled_checks = (test_llm_extraction,)
        for run_check in enabled_checks:
            await run_check()
    except Exception as exc:
        print(f"Test failed: {exc}")
        raise


if __name__ == "__main__":
    # Script entry point: drive the async checks on a fresh event loop.
    asyncio.run(main())