#!/usr/bin/env python3
"""
Test client for the ChatGPT Oasis Model Inference API
"""
import base64
import io
import json
import mimetypes
import os

import requests
from PIL import Image
# Base URL of the locally running inference API server that all tests target.
BASE_URL = "http://localhost:8000"
def test_health_check():
    """Probe GET /health and print the server's reply.

    Returns:
        True when the endpoint answers with HTTP 200; False on a non-200
        status or any error (including connection failures).
    """
    print("Testing health check...")
    try:
        resp = requests.get(f"{BASE_URL}/health")
        print(f"Status: {resp.status_code}")
        print(f"Response: {json.dumps(resp.json(), indent=2)}")
        return resp.status_code == 200
    except Exception as exc:
        # Best-effort probe: report the failure, never raise to the caller.
        print(f"Error: {exc}")
        return False
def test_list_models():
    """Probe GET /models and print the available-models listing.

    Returns:
        True when the endpoint answers with HTTP 200; False on a non-200
        status or any error (including connection failures).
    """
    print("\nTesting models list...")
    try:
        resp = requests.get(f"{BASE_URL}/models")
        print(f"Status: {resp.status_code}")
        print(f"Response: {json.dumps(resp.json(), indent=2)}")
        return resp.status_code == 200
    except Exception as exc:
        # Best-effort probe: report the failure, never raise to the caller.
        print(f"Error: {exc}")
        return False
def create_test_image():
    """Return JPEG-encoded bytes for a solid red 224x224 RGB image.

    Used as a deterministic synthetic payload for the inference tests.
    """
    canvas = Image.new('RGB', (224, 224), color='red')
    out = io.BytesIO()
    canvas.save(out, format='JPEG')
    # getvalue() reads the whole buffer regardless of stream position.
    return out.getvalue()
def test_base64_inference():
    """POST a base64-encoded synthetic image to /inference for each model.

    Prints the status and, on success, the model used and its top
    prediction; errors are printed rather than raised.
    """
    print("\nTesting base64 inference...")
    encoded = base64.b64encode(create_test_image()).decode()
    for model_name in ("oasis500m", "vit-l-20"):
        print(f"\nTesting {model_name}...")
        try:
            resp = requests.post(
                f"{BASE_URL}/inference",
                json={"image": encoded, "model_name": model_name},
                headers={"Content-Type": "application/json"},
            )
            print(f"Status: {resp.status_code}")
            if resp.status_code != 200:
                print(f"Error: {resp.text}")
                continue
            body = resp.json()
            print(f"Model used: {body['model_used']}")
            print(f"Top prediction: {body['predictions'][0]}")
        except Exception as exc:
            print(f"Error: {exc}")
def test_file_upload_inference():
    """POST a synthetic image as a multipart upload to /upload_inference.

    Exercises both models; prints the status and, on success, the model
    used and its top prediction. Errors are printed rather than raised.
    """
    print("\nTesting file upload inference...")
    jpeg_bytes = create_test_image()
    for model_name in ("oasis500m", "vit-l-20"):
        print(f"\nTesting {model_name} with file upload...")
        try:
            resp = requests.post(
                f"{BASE_URL}/upload_inference",
                files={'file': ('test_image.jpg', jpeg_bytes, 'image/jpeg')},
                data={'model_name': model_name},
            )
            print(f"Status: {resp.status_code}")
            if resp.status_code != 200:
                print(f"Error: {resp.text}")
                continue
            body = resp.json()
            print(f"Model used: {body['model_used']}")
            print(f"Top prediction: {body['predictions'][0]}")
        except Exception as exc:
            print(f"Error: {exc}")
def test_with_real_image(image_path, model_name='oasis500m'):
    """Upload a real image file to /upload_inference and print the top-3 predictions.

    Args:
        image_path: Path to an image file on disk; if it does not exist the
            test is skipped with a message.
        model_name: Model to request; defaults to 'oasis500m', matching the
            previous hard-coded behavior.
    """
    if not os.path.exists(image_path):
        print(f"Image file not found: {image_path}")
        return
    print(f"\nTesting with real image: {image_path}")
    # Fix: guess the MIME type from the file extension instead of always
    # claiming 'image/jpeg' (wrong for .png etc.); fall back to the old
    # value when the extension is unrecognized.
    mime_type = mimetypes.guess_type(image_path)[0] or 'image/jpeg'
    try:
        with open(image_path, 'rb') as f:
            response = requests.post(
                f"{BASE_URL}/upload_inference",
                files={'file': (os.path.basename(image_path), f, mime_type)},
                data={'model_name': model_name},
            )
        print(f"Status: {response.status_code}")
        if response.status_code == 200:
            result = response.json()
            print(f"Model used: {result['model_used']}")
            print("Top 3 predictions:")
            for rank, pred in enumerate(result['predictions'][:3], start=1):
                print(f"  {rank}. {pred['label']} ({pred['confidence']:.3f})")
        else:
            print(f"Error: {response.text}")
    except Exception as e:
        print(f"Error: {e}")
def main():
    """Run the full test-client suite against the local API server.

    Order: health check (hard requirement), models list, synthetic-image
    inference (base64 + upload), then the first real image found on disk.
    """
    print("ChatGPT Oasis Model Inference API - Test Client")
    print("=" * 50)

    # Basic endpoints first.
    health_ok = test_health_check()
    models_ok = test_list_models()
    if not health_ok:
        print("Health check failed. Make sure the server is running!")
        return
    if not models_ok:
        # Fix: this result was previously computed and silently discarded;
        # surface a broken /models endpoint while still running inference.
        print("Warning: models list endpoint failed; continuing anyway.")

    # Inference endpoints with a synthetic image.
    test_base64_inference()
    test_file_upload_inference()

    # Test with the first real image found on disk, if any.
    for candidate in ("test.jpg", "sample.jpg", "image.jpg"):
        if os.path.exists(candidate):
            test_with_real_image(candidate)
            break

    print("\n" + "=" * 50)
    print("Test completed!")


if __name__ == "__main__":
    main()