# Spaces: Build error
# (Page-status text captured together with this file; it is not part of the program.)
import base64
import binascii
import io
import os
from typing import List

import jax
import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.openapi.docs import get_swagger_ui_html
from octo.model.octo_model import OctoModel
from PIL import Image, UnidentifiedImageError
from pydantic import BaseModel
# Force JAX onto the CPU backend (adjust if a GPU is needed).
# `jax` has already been imported by the time this code runs, so writing the
# environment variable alone may come too late to influence JAX's own
# configuration. The variable is kept for anything that reads it afterwards,
# and the same choice is applied through the config API as well.
os.environ['JAX_PLATFORMS'] = 'cpu'
jax.config.update("jax_platforms", "cpu")

# Load the model once at import time so every request reuses the same instance.
model = OctoModel.load_pretrained("hf://rail-berkeley/octo-base-1.5")
# Initialize the FastAPI application.
app = FastAPI(
    title="Octo Model Inference API",
    docs_url="/",  # Serve the built-in Swagger UI at the root path
)
# Request body model for the inference endpoint.
class InferenceRequest(BaseModel):
    """JSON payload accepted by the `predict` handler."""

    # Base64-encoded images in time order. The `predict` handler strips a
    # leading "data:image...," prefix from an entry when one is present.
    image_base64: List[str]
    # Natural-language instruction passed to the model as the task text.
    task: str = "pick up the fork"
    # Minimum number of images the request must carry; `predict` rejects
    # requests that supply fewer.
    window_size: int = 2
# Health check endpoint
# NOTE(review): no route decorator is visible on this handler in this file, so
# it is only reachable over HTTP if it is registered with `app` elsewhere —
# confirm.
async def health_check():
    """Report that the service process is up.

    Returns:
        A dict with the single constant entry ``{"status": "healthy"}``.
    """
    return {"status": "healthy"}
# Inference endpoint
# NOTE(review): no route decorator is visible on this handler in this file, so
# it is only reachable over HTTP if it is registered with `app` elsewhere —
# confirm.
def _decode_image(img_base64: str) -> np.ndarray:
    """Decode one base64 (optionally data-URI) image into a 256x256 RGB array.

    Raises:
        HTTPException: 400 when the string is not valid base64 or the decoded
            bytes are not an image Pillow can identify.
    """
    # Drop a "data:image...," prefix; partition() tolerates a missing comma
    # (the payload then becomes empty and is rejected below as a bad image).
    if img_base64.startswith("data:image"):
        _, _, img_base64 = img_base64.partition(",")
    try:
        img_data = base64.b64decode(img_base64)
        # The context manager releases Pillow's handle on the decoded file.
        with Image.open(io.BytesIO(img_data)) as img:
            # Force three channels so RGBA / grayscale / palette inputs all
            # stack into the same (256, 256, 3) shape.
            return np.array(img.convert("RGB").resize((256, 256)))
    except (binascii.Error, UnidentifiedImageError) as e:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid image data: {e}"
        ) from e


async def predict(request: InferenceRequest, dataset_name: str = "bridge_dataset"):
    """Sample actions from the model for a sequence of images and a task text.

    Args:
        request: Payload carrying the base64 images, the task text and the
            minimum number of images required (``window_size``).
        dataset_name: Key into ``model.dataset_statistics`` whose ``"action"``
            entry is passed as the un-normalization statistics.

    Returns:
        ``{"actions": [...]}`` where the value is the first batch element of
        the sampled actions converted to nested lists.

    Raises:
        HTTPException: 400 for invalid input (too few images, non-positive
            window size, undecodable image, unknown dataset name); 500 for any
            other failure while running the model.
    """
    # Validate input up front so client mistakes are reported as 400.
    if request.window_size < 1:
        raise HTTPException(
            status_code=400,
            detail="window_size must be at least 1"
        )
    if len(request.image_base64) < request.window_size:
        raise HTTPException(
            status_code=400,
            detail=f"At least {request.window_size} images required for the specified window size"
        )

    try:
        if dataset_name not in model.dataset_statistics:
            raise HTTPException(
                status_code=400,
                detail=f"Unknown dataset_name: {dataset_name!r}"
            )

        # NOTE(review): every submitted frame is forwarded to the model, not
        # only the last `window_size` — confirm the model accepts that.
        images = [_decode_image(encoded) for encoded in request.image_base64]

        # Stack along time and add a batch dimension: (1, T, 256, 256, 3).
        img_array = np.stack(images)[np.newaxis, ...]
        observation = {
            "image_primary": img_array,
            # No frame is padding, so the mask is all True: (1, T).
            "timestep_pad_mask": np.full((1, len(images)), True, dtype=bool)
        }

        # Create the task and sample actions. The PRNG key is a fixed
        # constant, so the same key is used for every request.
        task_obj = model.create_tasks(texts=[request.task])
        actions = model.sample_actions(
            observation,
            task_obj,
            unnormalization_statistics=model.dataset_statistics[dataset_name]["action"],
            rng=jax.random.PRNGKey(0)
        )
        # Remove the batch dimension and convert to plain lists for JSON.
        return {"actions": actions[0].tolist()}
    except HTTPException:
        # Deliberate HTTP errors raised above must keep their status code
        # instead of being rewrapped as a 500 by the handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}") from e
# Custom Swagger UI route (optional)
# NOTE(review): no route decorator is visible on this handler in this file, and
# `app` is already created with docs_url="/" (built-in Swagger UI at the
# root) — confirm whether this handler is registered anywhere and still needed.
async def custom_swagger_ui_html():
    """Render the Swagger UI page for this application's OpenAPI schema.

    Returns:
        The response produced by ``get_swagger_ui_html`` for ``app``'s OpenAPI
        URL, titled ``"<app title> - Swagger UI"``.
    """
    return get_swagger_ui_html(
        openapi_url=app.openapi_url,
        title=app.title + " - Swagger UI",
        oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
    )