# chatgpt-oasis / test_client.py
# Uploaded by parthraninga ("Upload 10 files"), commit 95efa57 (verified)
#!/usr/bin/env python3
"""
Test client for the ChatGPT Oasis Model Inference API
"""
import requests
import base64
import json
from PIL import Image
import io
import os
# API base URL: every request made by the tests in this script is sent to
# f"{BASE_URL}/<endpoint>". Assumes the inference server is listening locally
# on port 8000 — edit this value to target a different host or port.
BASE_URL = "http://localhost:8000"
def test_health_check(timeout=10):
    """Test the health check endpoint.

    Sends ``GET {BASE_URL}/health``, prints the status code and the response
    body, and reports whether the server answered with HTTP 200.

    Args:
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` may block indefinitely on a server that
            accepts the connection but never answers.

    Returns:
        True if the endpoint returned HTTP 200, False otherwise (including
        when the request could not be completed at all).
    """
    print("Testing health check...")
    try:
        response = requests.get(f"{BASE_URL}/health", timeout=timeout)
    except requests.RequestException as e:
        # Connection refused, DNS failure, timeout, ...
        print(f"Error: {e}")
        return False
    print(f"Status: {response.status_code}")
    try:
        print(f"Response: {json.dumps(response.json(), indent=2)}")
    except ValueError:
        # Body is not JSON (e.g. an HTML error page from a proxy). Show it
        # verbatim instead of letting a decode error mask the real status.
        print(f"Response (non-JSON): {response.text}")
    return response.status_code == 200
def test_list_models(timeout=10):
    """Test the models list endpoint.

    Sends ``GET {BASE_URL}/models``, prints the status code and the response
    body, and reports whether the server answered with HTTP 200.

    Args:
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` may block indefinitely on a server that
            accepts the connection but never answers.

    Returns:
        True if the endpoint returned HTTP 200, False otherwise (including
        when the request could not be completed at all).
    """
    print("\nTesting models list...")
    try:
        response = requests.get(f"{BASE_URL}/models", timeout=timeout)
    except requests.RequestException as e:
        # Connection refused, DNS failure, timeout, ...
        print(f"Error: {e}")
        return False
    print(f"Status: {response.status_code}")
    try:
        print(f"Response: {json.dumps(response.json(), indent=2)}")
    except ValueError:
        # Body is not JSON. Show it verbatim instead of letting a decode
        # error mask the real status code.
        print(f"Response (non-JSON): {response.text}")
    return response.status_code == 200
def create_test_image():
    """Build a synthetic test image entirely in memory.

    Returns:
        The encoded bytes of a 224x224 solid-red RGB image in JPEG format,
        so the tests need no image file on disk.
    """
    with io.BytesIO() as jpeg_stream:
        Image.new('RGB', (224, 224), color='red').save(jpeg_stream, format='JPEG')
        # getvalue() returns the whole buffer regardless of the stream
        # position, so no rewind is needed.
        return jpeg_stream.getvalue()
def test_base64_inference(
    model_names=("oasis500m", "vit-l-20"),
    timeout=300,
):
    """Test the JSON ``/inference`` endpoint with a base64-encoded image.

    A synthetic image from ``create_test_image()`` is base64-encoded and
    posted once per model name. For each model, the status code and either
    the top prediction or the error are printed.

    Args:
        model_names: Model identifiers sent as the ``model_name`` field.
        timeout: Seconds to wait for each inference request. Kept generous
            because a server may load model weights on first use.

    Returns:
        True if every request returned HTTP 200 with a non-empty
        ``predictions`` list, False otherwise.
    """
    print("\nTesting base64 inference...")
    image_base64 = base64.b64encode(create_test_image()).decode()
    all_ok = True
    for model_name in model_names:
        print(f"\nTesting {model_name}...")
        try:
            # json= serializes the payload and already sets
            # "Content-Type: application/json".
            response = requests.post(
                f"{BASE_URL}/inference",
                json={"image": image_base64, "model_name": model_name},
                timeout=timeout,
            )
        except requests.RequestException as e:
            print(f"Error: {e}")
            all_ok = False
            continue
        print(f"Status: {response.status_code}")
        if response.status_code != 200:
            print(f"Error: {response.text}")
            all_ok = False
            continue
        try:
            result = response.json()
            print(f"Model used: {result['model_used']}")
            predictions = result['predictions']
            if not predictions:
                # Report explicitly instead of an opaque IndexError.
                print("Error: response contained no predictions")
                all_ok = False
                continue
            print(f"Top prediction: {predictions[0]}")
        except (ValueError, KeyError, TypeError) as e:
            # Non-JSON body or a JSON shape other than the one expected.
            print(f"Error: unexpected response format: {e!r}")
            all_ok = False
    return all_ok
def test_file_upload_inference(
    model_names=("oasis500m", "vit-l-20"),
    timeout=300,
):
    """Test the multipart ``/upload_inference`` endpoint with an in-memory image.

    A synthetic JPEG from ``create_test_image()`` is uploaded as the ``file``
    form part once per model name. For each model, the status code and either
    the top prediction or the error are printed.

    Args:
        model_names: Model identifiers sent as the ``model_name`` form field.
        timeout: Seconds to wait for each inference request. Kept generous
            because a server may load model weights on first use.

    Returns:
        True if every request returned HTTP 200 with a non-empty
        ``predictions`` list, False otherwise.
    """
    print("\nTesting file upload inference...")
    # Plain bytes (not a file object), so the same payload can be re-sent
    # for every model.
    image_data = create_test_image()
    all_ok = True
    for model_name in model_names:
        print(f"\nTesting {model_name} with file upload...")
        try:
            response = requests.post(
                f"{BASE_URL}/upload_inference",
                files={'file': ('test_image.jpg', image_data, 'image/jpeg')},
                data={'model_name': model_name},
                timeout=timeout,
            )
        except requests.RequestException as e:
            print(f"Error: {e}")
            all_ok = False
            continue
        print(f"Status: {response.status_code}")
        if response.status_code != 200:
            print(f"Error: {response.text}")
            all_ok = False
            continue
        try:
            result = response.json()
            print(f"Model used: {result['model_used']}")
            predictions = result['predictions']
            if not predictions:
                # Report explicitly instead of an opaque IndexError.
                print("Error: response contained no predictions")
                all_ok = False
                continue
            print(f"Top prediction: {predictions[0]}")
        except (ValueError, KeyError, TypeError) as e:
            # Non-JSON body or a JSON shape other than the one expected.
            print(f"Error: unexpected response format: {e!r}")
            all_ok = False
    return all_ok
def test_with_real_image(image_path, model_name='oasis500m', timeout=300):
    """Upload an image file from disk to ``/upload_inference`` and print results.

    Prints the status code, the model used, and the top 3 predictions
    (label and confidence) or the error returned by the server.

    Args:
        image_path: Path of the image file to upload.
        model_name: Model identifier sent as the ``model_name`` form field.
        timeout: Seconds to wait for the inference request.
    """
    import mimetypes

    if not os.path.isfile(image_path):
        print(f"Image file not found: {image_path}")
        return
    print(f"\nTesting with real image: {image_path}")
    # Derive the MIME type from the file extension instead of always
    # claiming JPEG. Fall back to 'image/jpeg' when it cannot be guessed.
    content_type = mimetypes.guess_type(image_path)[0] or 'image/jpeg'
    try:
        # The file must stay open until requests has streamed it out.
        with open(image_path, 'rb') as f:
            response = requests.post(
                f"{BASE_URL}/upload_inference",
                files={'file': (os.path.basename(image_path), f, content_type)},
                data={'model_name': model_name},
                timeout=timeout,
            )
    except (OSError, requests.RequestException) as e:
        # Unreadable file, or the request itself failed.
        print(f"Error: {e}")
        return
    print(f"Status: {response.status_code}")
    if response.status_code != 200:
        print(f"Error: {response.text}")
        return
    try:
        result = response.json()
        print(f"Model used: {result['model_used']}")
        print("Top 3 predictions:")
        for rank, pred in enumerate(result['predictions'][:3], start=1):
            print(f" {rank}. {pred['label']} ({pred['confidence']:.3f})")
    except (ValueError, KeyError, TypeError) as e:
        # Non-JSON body or a JSON shape other than the one expected.
        print(f"Error: unexpected response format: {e!r}")
def main():
    """Run all tests against the server at ``BASE_URL``.

    The health check gates the rest. If it fails, nothing else runs. The
    inference tests only print their results and do not affect the return
    value.

    Returns:
        Process exit status: 0 when the health check and the model listing
        both succeed, 1 otherwise.
    """
    print("ChatGPT Oasis Model Inference API - Test Client")
    print("=" * 50)
    # Test basic endpoints
    health_ok = test_health_check()
    models_ok = test_list_models()
    if not health_ok:
        print("Health check failed. Make sure the server is running!")
        return 1
    if not models_ok:
        # Not fatal: the inference tests can still run without the listing.
        print("Warning: listing models failed; continuing with inference tests.")
    # Test inference endpoints
    test_base64_inference()
    test_file_upload_inference()
    # Test with the first real image found in the working directory, if any
    test_images = ["test.jpg", "sample.jpg", "image.jpg"]
    real_image = next((img for img in test_images if os.path.exists(img)), None)
    if real_image is not None:
        test_with_real_image(real_image)
    print("\n" + "=" * 50)
    print("Test completed!")
    return 0 if models_ok else 1


if __name__ == "__main__":
    # Propagate main()'s status so shell/CI callers can detect a failure.
    raise SystemExit(main())