from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load tokenizer and model from the repository path
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(path)

        # Create the text-generation pipeline the Inference Endpoint expects;
        # use GPU 0 when available, otherwise fall back to CPU
        self.pipeline = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            device=0 if torch.cuda.is_available() else -1,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        prompt_input = data.get("inputs", "")
        vibe = data.get("vibe", "Open to All Paths")

        # Wrap the user input in Vela's persona prompt template
        prompt = (
            f"#### Human (Vibe: {vibe}): {prompt_input.strip()}\n"
            f"#### Assistant (Vela - your Camino companion):"
        )

        # Merge caller-supplied parameters with sensible generation defaults;
        # copy the dict so the incoming payload is not mutated
        generation_args = dict(data.get("parameters", {}))
        generation_args.setdefault("max_new_tokens", 1024)
        generation_args.setdefault("temperature", 0.2)
        generation_args.setdefault("top_p", 0.95)
        generation_args.setdefault("do_sample", True)

        # The pipeline returns a list of {"generated_text": ...} dicts
        outputs = self.pipeline(prompt, **generation_args)
        return outputs
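

# --- Local smoke test (illustrative sketch, not part of the deployed handler) ---
# Hugging Face Inference Endpoints only require the EndpointHandler class
# above. This block shows one way to exercise it locally with the same
# payload shape the endpoint would receive. The model path "./vela-model"
# is a hypothetical placeholder; substitute your own checkpoint directory.
if __name__ == "__main__":
    handler = EndpointHandler(path="./vela-model")  # assumed local checkpoint
    payload = {
        "inputs": "What should I pack for the Camino Frances?",
        "vibe": "Open to All Paths",
        "parameters": {"max_new_tokens": 256},
    }
    result = handler(payload)
    # The pipeline output is a list of dicts keyed by "generated_text"
    print(result[0]["generated_text"])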