|
from typing import Any, Dict |
|
|
|
|
|
class EndpointHandler:
    """Callable request handler that holds a loaded model and its tokenizer.

    The constructor loads both objects from ``model_dir``; calling the
    instance logs the incoming request payload.
    """

    def __init__(self, model_dir: str, **kwargs: Any) -> None:
        """Load the model and tokenizer found at ``model_dir``.

        Args:
            model_dir: Location handed unchanged to both ``from_pretrained``
                calls.
            **kwargs: Accepted for interface compatibility; not read by this
                constructor.
        """
        # Load in bfloat16 with the device placement computed by
        # ``split_model()``. ``use_flash_attn`` is forwarded as-is --
        # presumably consumed by the checkpoint's remote code (TODO confirm).
        loaded_model = AutoModel.from_pretrained(
            model_dir,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            use_flash_attn=False,
            trust_remote_code=True,
            device_map=split_model(),
        )
        # Keep whatever ``.eval()`` returns, exactly as the chained call did.
        self.model = loaded_model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir,
            trust_remote_code=True,
            use_fast=False,
        )

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Log the incoming request payload.

        Args:
            data: Request payload; it is only interpolated into the log line.

        Returns:
            ``None`` -- the visible body performs no inference.
            NOTE(review): confirm whether generation logic is still to be
            added here.
        """
        logger.info(f"Received incoming request with {data=}")
        return None
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke check when run as a script: build a handler for the
    # checkpoint identifier below and print the resulting object.
    checkpoint = "GSAI-ML/LLaDA-8B-Instruct"
    print(EndpointHandler(model_dir=checkpoint))
|
|