inference_endpoint (#2)
Browse files — added endpoint scripts (1259b94c50e3fa4f31bce132de4f3ec7596aaa5b)
Co-authored-by: Ivan Moshkov <[email protected]>
- Dockerfile +26 -0
- entrypoint.sh +32 -0
- handler.py +139 -0
- server.py +77 -0
    	
        Dockerfile
    ADDED
    
    | @@ -0,0 +1,26 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            FROM igitman/nemo-skills-vllm:0.6.0 as base
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            # Install NeMo-Skills and dependencies
         | 
| 4 | 
            +
            RUN git clone https://github.com/NVIDIA/NeMo-Skills \
         | 
| 5 | 
            +
                && cd NeMo-Skills \
         | 
| 6 | 
            +
                && pip install --ignore-installed blinker \
         | 
| 7 | 
            +
                && pip install -e . \
         | 
| 8 | 
            +
                && pip install -r requirements/code_execution.txt
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            # Ensure python is available
         | 
| 11 | 
            +
            RUN ln -s /usr/bin/python3 /usr/bin/python
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            # Copy our custom files
         | 
| 14 | 
            +
            COPY handler.py server.py /usr/local/endpoint/
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            # Expose port 80
         | 
| 17 | 
            +
            EXPOSE 80
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            # Copy and set up entrypoint script
         | 
| 20 | 
            +
            COPY entrypoint.sh /usr/local/endpoint/
         | 
| 21 | 
            +
            RUN chmod +x /usr/local/endpoint/entrypoint.sh
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            # Set working directory
         | 
| 24 | 
            +
            WORKDIR /usr/local/endpoint
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            ENTRYPOINT ["/usr/local/endpoint/entrypoint.sh"]
         | 
    	
        entrypoint.sh
    ADDED
    
    | @@ -0,0 +1,32 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #!/bin/bash
         | 
| 2 | 
            +
            set -e
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            # Default environment variables
         | 
| 5 | 
            +
            export MODEL_PATH=${MODEL_PATH:-"/repository"}
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            echo "Starting NeMo Skills inference endpoint..."
         | 
| 8 | 
            +
            echo "Model path: $MODEL_PATH"
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            # Function to handle cleanup on exit
         | 
| 11 | 
            +
            cleanup() {
         | 
| 12 | 
            +
                echo "Cleaning up processes..."
         | 
| 13 | 
            +
                kill $(jobs -p) 2>/dev/null || true
         | 
| 14 | 
            +
                wait
         | 
| 15 | 
            +
            }
         | 
| 16 | 
            +
            trap cleanup EXIT
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            # Start the model server in the background
         | 
| 19 | 
            +
            echo "Starting model server..."
         | 
| 20 | 
            +
            ns start_server \
         | 
| 21 | 
            +
               --model="$MODEL_PATH" \
         | 
| 22 | 
            +
               --server_gpus=2 \
         | 
| 23 | 
            +
               --server_type=vllm \
         | 
| 24 | 
            +
               --with_sandbox &
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            # Start the HTTP endpoint
         | 
| 27 | 
            +
            echo "Starting HTTP endpoint on port 80..."
         | 
| 28 | 
            +
            python /usr/local/endpoint/server.py &
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            # Wait for both processes
         | 
| 31 | 
            +
            echo "Both servers started. Waiting..."
         | 
| 32 | 
            +
            wait
         | 
    	
        handler.py
    ADDED
    
    | @@ -0,0 +1,139 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import logging
         | 
| 3 | 
            +
            import traceback
         | 
| 4 | 
            +
            from typing import Dict, List, Any
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from nemo_skills.inference.server.code_execution_model import get_code_execution_model
         | 
| 7 | 
            +
            from nemo_skills.code_execution.sandbox import get_sandbox
         | 
| 8 | 
            +
            from nemo_skills.prompt.utils import get_prompt
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            # Configure logging
         | 
| 11 | 
            +
            logging.basicConfig(level=logging.INFO)
         | 
| 12 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 13 | 
            +
             | 
| 14 | 
            +
             | 
| 15 | 
            +
            class EndpointHandler:
         | 
| 16 | 
            +
                """Custom endpoint handler for NeMo Skills code execution inference."""
         | 
| 17 | 
            +
                
         | 
| 18 | 
            +
                def __init__(self):
         | 
| 19 | 
            +
                    """
         | 
| 20 | 
            +
                    Initialize the handler with the model and prompt configurations.
         | 
| 21 | 
            +
                    """
         | 
| 22 | 
            +
                    self.model = None
         | 
| 23 | 
            +
                    self.prompt = None
         | 
| 24 | 
            +
                    self.initialized = False
         | 
| 25 | 
            +
                    
         | 
| 26 | 
            +
                    # Configuration
         | 
| 27 | 
            +
                    self.prompt_config_path = os.getenv("PROMPT_CONFIG_PATH", "generic/math")
         | 
| 28 | 
            +
                    self.prompt_template_path = os.getenv("PROMPT_TEMPLATE_PATH", "openmath-instruct")
         | 
| 29 | 
            +
                    
         | 
| 30 | 
            +
                def _initialize_components(self):
         | 
| 31 | 
            +
                    """Initialize the model, sandbox, and prompt components lazily."""
         | 
| 32 | 
            +
                    if self.initialized:
         | 
| 33 | 
            +
                        return
         | 
| 34 | 
            +
                        
         | 
| 35 | 
            +
                    try:
         | 
| 36 | 
            +
                        logger.info("Initializing sandbox...")
         | 
| 37 | 
            +
                        sandbox = get_sandbox(sandbox_type="local")
         | 
| 38 | 
            +
                        
         | 
| 39 | 
            +
                        logger.info("Initializing code execution model...")
         | 
| 40 | 
            +
                        self.model = get_code_execution_model(
         | 
| 41 | 
            +
                            server_type="vllm",
         | 
| 42 | 
            +
                            sandbox=sandbox,
         | 
| 43 | 
            +
                            host="127.0.0.1",
         | 
| 44 | 
            +
                            port=5000
         | 
| 45 | 
            +
                        )
         | 
| 46 | 
            +
                        
         | 
| 47 | 
            +
                        logger.info("Initializing prompt...")
         | 
| 48 | 
            +
                        if self.prompt_config_path:
         | 
| 49 | 
            +
                            self.prompt = get_prompt(
         | 
| 50 | 
            +
                                prompt_config=self.prompt_config_path,
         | 
| 51 | 
            +
                                prompt_template=self.prompt_template_path
         | 
| 52 | 
            +
                            )
         | 
| 53 | 
            +
                        
         | 
| 54 | 
            +
                        self.initialized = True
         | 
| 55 | 
            +
                        logger.info("All components initialized successfully")
         | 
| 56 | 
            +
                        
         | 
| 57 | 
            +
                    except Exception as e:
         | 
| 58 | 
            +
                        logger.warning(f"Failed to initialize the model")
         | 
| 59 | 
            +
                
         | 
| 60 | 
            +
                def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         | 
| 61 | 
            +
                    """
         | 
| 62 | 
            +
                    Process inference requests.
         | 
| 63 | 
            +
                    
         | 
| 64 | 
            +
                    Args:
         | 
| 65 | 
            +
                        data: Dictionary containing the request data
         | 
| 66 | 
            +
                            Expected keys:
         | 
| 67 | 
            +
                            - inputs: str or list of str - the input prompts/problems
         | 
| 68 | 
            +
                            - parameters: dict (optional) - generation parameters
         | 
| 69 | 
            +
                            
         | 
| 70 | 
            +
                    Returns:
         | 
| 71 | 
            +
                        List of dictionaries containing the generated responses
         | 
| 72 | 
            +
                    """
         | 
| 73 | 
            +
                    try:
         | 
| 74 | 
            +
                        # Initialize components if not already done
         | 
| 75 | 
            +
                        self._initialize_components()
         | 
| 76 | 
            +
                        
         | 
| 77 | 
            +
                        # Extract inputs and parameters
         | 
| 78 | 
            +
                        inputs = data.get("inputs", "")
         | 
| 79 | 
            +
                        parameters = data.get("parameters", {})
         | 
| 80 | 
            +
                        
         | 
| 81 | 
            +
                        # Handle both single string and list of strings
         | 
| 82 | 
            +
                        if isinstance(inputs, str):
         | 
| 83 | 
            +
                            prompts = [inputs]
         | 
| 84 | 
            +
                        elif isinstance(inputs, list):
         | 
| 85 | 
            +
                            prompts = inputs
         | 
| 86 | 
            +
                        else:
         | 
| 87 | 
            +
                            raise ValueError("inputs must be a string or list of strings")
         | 
| 88 | 
            +
                        
         | 
| 89 | 
            +
                        # If we have a prompt template configured, format the inputs
         | 
| 90 | 
            +
                        if self.prompt is not None:
         | 
| 91 | 
            +
                            formatted_prompts = []
         | 
| 92 | 
            +
                            for prompt_text in prompts:
         | 
| 93 | 
            +
                                formatted_prompt = self.prompt.fill({"problem": prompt_text, "total_code_executions": 8})
         | 
| 94 | 
            +
                                formatted_prompts.append(formatted_prompt)
         | 
| 95 | 
            +
                            prompts = formatted_prompts
         | 
| 96 | 
            +
                        
         | 
| 97 | 
            +
                        # Get code execution arguments from prompt if available
         | 
| 98 | 
            +
                        extra_generate_params = {}
         | 
| 99 | 
            +
                        if self.prompt is not None:
         | 
| 100 | 
            +
                            extra_generate_params = self.prompt.get_code_execution_args()
         | 
| 101 | 
            +
                        
         | 
| 102 | 
            +
                        # Set default generation parameters
         | 
| 103 | 
            +
                        generation_params = {
         | 
| 104 | 
            +
                            "tokens_to_generate": 12000,
         | 
| 105 | 
            +
                            "temperature": 0.0,
         | 
| 106 | 
            +
                            "top_p": 0.95,
         | 
| 107 | 
            +
                            "top_k": 0,
         | 
| 108 | 
            +
                            "repetition_penalty": 1.0,
         | 
| 109 | 
            +
                            "random_seed": 0,
         | 
| 110 | 
            +
                        }
         | 
| 111 | 
            +
                        
         | 
| 112 | 
            +
                        # Update with provided parameters
         | 
| 113 | 
            +
                        generation_params.update(parameters)
         | 
| 114 | 
            +
                        generation_params.update(extra_generate_params)
         | 
| 115 | 
            +
                        
         | 
| 116 | 
            +
                        logger.info(f"Processing {len(prompts)} prompt(s)")
         | 
| 117 | 
            +
                        
         | 
| 118 | 
            +
                        # Generate responses
         | 
| 119 | 
            +
                        outputs = self.model.generate(
         | 
| 120 | 
            +
                            prompts=prompts,
         | 
| 121 | 
            +
                            **generation_params
         | 
| 122 | 
            +
                        )
         | 
| 123 | 
            +
                        
         | 
| 124 | 
            +
                        # Format outputs
         | 
| 125 | 
            +
                        results = []
         | 
| 126 | 
            +
                        for output in outputs:
         | 
| 127 | 
            +
                            result = {
         | 
| 128 | 
            +
                                "generated_text": output.get("generation", ""),
         | 
| 129 | 
            +
                                "code_rounds_executed": output.get("code_rounds_executed", 0),
         | 
| 130 | 
            +
                            }
         | 
| 131 | 
            +
                            results.append(result)
         | 
| 132 | 
            +
                        
         | 
| 133 | 
            +
                        logger.info(f"Successfully processed {len(results)} request(s)")
         | 
| 134 | 
            +
                        return results
         | 
| 135 | 
            +
                        
         | 
| 136 | 
            +
                    except Exception as e:
         | 
| 137 | 
            +
                        logger.error(f"Error processing request: {str(e)}")
         | 
| 138 | 
            +
                        logger.error(traceback.format_exc())
         | 
| 139 | 
            +
                        return [{"error": str(e), "generated_text": ""}]
         | 
    	
        server.py
    ADDED
    
    | @@ -0,0 +1,77 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import json
         | 
| 2 | 
            +
            import logging
         | 
| 3 | 
            +
            from http.server import HTTPServer, BaseHTTPRequestHandler
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            from handler import EndpointHandler
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            # Configure logging
         | 
| 8 | 
            +
            logging.basicConfig(level=logging.INFO)
         | 
| 9 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            # Initialize the handler
         | 
| 12 | 
            +
            handler = EndpointHandler()
         | 
| 13 | 
            +
             | 
| 14 | 
            +
             | 
| 15 | 
            +
            class RequestHandler(BaseHTTPRequestHandler):
         | 
| 16 | 
            +
                def do_POST(self):
         | 
| 17 | 
            +
                    try:
         | 
| 18 | 
            +
                        content_length = int(self.headers['Content-Length'])
         | 
| 19 | 
            +
                        post_data = self.rfile.read(content_length)
         | 
| 20 | 
            +
                        data = json.loads(post_data.decode('utf-8'))
         | 
| 21 | 
            +
                        
         | 
| 22 | 
            +
                        logger.info(f'Received request with {len(data.get("inputs", []))} inputs')
         | 
| 23 | 
            +
                        
         | 
| 24 | 
            +
                        # Process the request
         | 
| 25 | 
            +
                        result = handler(data)
         | 
| 26 | 
            +
                        
         | 
| 27 | 
            +
                        # Send response
         | 
| 28 | 
            +
                        self.send_response(200)
         | 
| 29 | 
            +
                        self.send_header('Content-Type', 'application/json')
         | 
| 30 | 
            +
                        self.end_headers()
         | 
| 31 | 
            +
                        self.wfile.write(json.dumps(result).encode('utf-8'))
         | 
| 32 | 
            +
                        
         | 
| 33 | 
            +
                    except Exception as e:
         | 
| 34 | 
            +
                        logger.error(f'Error processing request: {str(e)}')
         | 
| 35 | 
            +
                        self.send_response(500)
         | 
| 36 | 
            +
                        self.send_header('Content-Type', 'application/json')
         | 
| 37 | 
            +
                        self.end_headers()
         | 
| 38 | 
            +
                        error_response = [{'error': str(e), 'generated_text': ''}]
         | 
| 39 | 
            +
                        self.wfile.write(json.dumps(error_response).encode('utf-8'))
         | 
| 40 | 
            +
                
         | 
| 41 | 
            +
                def do_GET(self):
         | 
| 42 | 
            +
                    if self.path == '/health':
         | 
| 43 | 
            +
                        # Trigger initialisation if needed but don't block.
         | 
| 44 | 
            +
                        if not handler.initialized:
         | 
| 45 | 
            +
                            try:
         | 
| 46 | 
            +
                                handler._initialize_components()
         | 
| 47 | 
            +
                            except Exception as e:
         | 
| 48 | 
            +
                                logger.error(f'Initialization failed during health check: {str(e)}')
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                        is_ready = handler.initialized
         | 
| 51 | 
            +
                        health_response = {
         | 
| 52 | 
            +
                            'status': 'healthy' if is_ready else 'unhealthy',
         | 
| 53 | 
            +
                            'model_ready': is_ready
         | 
| 54 | 
            +
                        }
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                        try:
         | 
| 57 | 
            +
                            self.send_response(200 if is_ready else 503)
         | 
| 58 | 
            +
                            self.send_header('Content-Type', 'application/json')
         | 
| 59 | 
            +
                            self.end_headers()
         | 
| 60 | 
            +
                            self.wfile.write(json.dumps(health_response).encode('utf-8'))
         | 
| 61 | 
            +
                        except BrokenPipeError:
         | 
| 62 | 
            +
                            # Client disconnected before we replied – safe to ignore.
         | 
| 63 | 
            +
                            pass
         | 
| 64 | 
            +
                        return
         | 
| 65 | 
            +
                    else:
         | 
| 66 | 
            +
                        self.send_response(404)
         | 
| 67 | 
            +
                        self.end_headers()
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                def log_message(self, format, *args):
         | 
| 70 | 
            +
                    # Suppress default HTTP server logs to keep output clean
         | 
| 71 | 
            +
                    pass
         | 
| 72 | 
            +
             | 
| 73 | 
            +
             | 
| 74 | 
            +
if __name__ == "__main__":
    # Bind on all interfaces, port 80 — the port exposed by the Dockerfile
    # and announced by entrypoint.sh.
    server = HTTPServer(('0.0.0.0', 80), RequestHandler)
    logger.info('HTTP server started on port 80')
    # Serve until the process is killed (entrypoint.sh runs this in the
    # background and reaps it via its EXIT trap).
    server.serve_forever() 

 
		