from typing import List, Dict, Any

from logits import LogitsPredictor


class EndpointHandler:
    """Wrap a ``LogitsPredictor`` behind a callable request handler.

    ``__init__`` builds the predictor and runs its ``setup`` with ``path``;
    ``__call__`` unpacks a request payload and forwards its fields to
    ``LogitsPredictor.predict``.
    """

    def __init__(self, path=""):
        """Create the predictor and run its setup step.

        Args:
            path: Passed unchanged to ``LogitsPredictor.setup`` (presumably a
                model/weights location — confirm against the ``logits``
                module).
        """
        self.predictor = LogitsPredictor()
        self.predictor.setup(path)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run one prediction for a request payload.

        Args:
            data: Request payload. ``data["inputs"]`` is the target text
                (default ``""``). ``data["parameters"]`` is an optional
                mapping of overrides; a missing or null value means "use all
                defaults":

                - ``prefix_text`` (default ``""``)
                - ``context_length`` (default ``1024``)
                - ``stride`` (default ``512``)
                - ``topk`` (default ``-1``)
                - ``perf_metadata`` (default ``False``)

                ``data`` itself is not modified.

        Returns:
            The value returned by ``LogitsPredictor.predict`` (annotated here
            as a dict; the concrete shape is defined by the ``logits`` module).
        """
        # Read instead of pop so the caller's payload is left untouched.
        trg_text = data.get("inputs", "")
        # `or {}` also covers an explicit JSON null for "parameters", which
        # would otherwise make the `.get` calls below raise AttributeError.
        parameters = data.get("parameters") or {}

        prefix_text = parameters.get("prefix_text", "")
        context_length = parameters.get("context_length", 1024)
        stride = parameters.get("stride", 512)
        topk = parameters.get("topk", -1)
        perf_metadata = parameters.get("perf_metadata", False)

        # Positional order must match LogitsPredictor.predict's signature.
        return self.predictor.predict(
            trg_text, prefix_text, context_length, stride, topk, perf_metadata
        )