chaima01 committed
Commit c9576c8 · verified · 1 Parent(s): 17f7dee

fixed the no attribute 'pipeline' bug
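For context on the bug: in the old handler below, __call__ generates with self.pipeline, but __init__ only ever assigns self.tokenizer and self.model, so every request failed with an AttributeError. A minimal standalone sketch of that failure mode (the class and attribute values here are illustrative, not part of the commit):

class Broken:
    def __init__(self):
        self.model = "loaded"          # model/tokenizer get assigned...

    def __call__(self, prompt):
        return self.pipeline(prompt)   # ...but self.pipeline never is

Broken()("hola")  # AttributeError: 'Broken' object has no attribute 'pipeline'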

Files changed (1): handler.py (+26, -16)
handler.py CHANGED
@@ -1,32 +1,42 @@
-from typing import Dict, Any, List
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextGenerationPipeline
+from typing import Dict, Any
 import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 
 class EndpointHandler:
     def __init__(self, path=""):
-        # Load model and tokenizer
+        # Load tokenizer and model
         self.tokenizer = AutoTokenizer.from_pretrained(path)
-        self.model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", torch_dtype=torch.float16)
+        self.model = AutoModelForCausalLM.from_pretrained(path)
+
+        # Create a pipeline that the inference API expects
+        self.pipeline = pipeline(
+            "text-generation",
+            model=self.model,
+            tokenizer=self.tokenizer,
+            device=0 if torch.cuda.is_available() else -1
+        )
 
-    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
         prompt_input = data.get("inputs", "")
-        vibe = data.get("vibe", "Open to All Paths")  # Default fallback
-
-        # Construct Camino-aware prompt
-        full_prompt = (
+        vibe = data.get("vibe", "Open to All Paths")
+
+        # Prepare prompt with Vela's persona
+        prompt = (
             f"#### Human (Vibe: {vibe}): {prompt_input.strip()}\n"
             f"#### Assistant (Vela - your Camino companion):"
         )
-
+
         # Default generation params
         generation_args = data.get("parameters", {})
         generation_args.setdefault("max_new_tokens", 1024)
         generation_args.setdefault("temperature", 0.2)
         generation_args.setdefault("top_p", 0.95)
         generation_args.setdefault("do_sample", True)
-
-        # Generate response
-        outputs = self.pipeline(full_prompt, **generation_args)
-
-        # Return in correct format
-        return [{"generated_text": outputs[0]["generated_text"]}]
+
+        # Use pipeline for generation
+        outputs = self.pipeline(
+            prompt,
+            **generation_args
+        )
+
+        return outputs
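After the fix, the handler can be smoke-tested outside the endpoint. A minimal sketch, assuming a local checkpoint path and an example payload (both illustrative); the "vibe" and "parameters" keys match what __call__ reads:

from handler import EndpointHandler

handler = EndpointHandler(path="./model")   # checkpoint path is an assumption

payload = {
    "inputs": "What should I pack for the Camino Francés?",  # example prompt
    "vibe": "Open to All Paths",            # optional; this is the handler's default
    "parameters": {"max_new_tokens": 256},  # setdefault fills in any keys not given here
}

# The text-generation pipeline returns a list like [{"generated_text": "..."}],
# which __call__ now passes through unchanged.
print(handler(payload))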