CodyBontecou committed on
Commit
4658eb5
·
1 Parent(s): f992a7f

handler from tut

Files changed (1)
  1. handler.py +18 -161
handler.py CHANGED
@@ -1,168 +1,25 @@
- from typing import Dict, List, Any, Optional, Union
- from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor
- import torch
- import base64
- from io import BytesIO
- from PIL import Image
- import requests
  
  
  class EndpointHandler:
-     def __init__(self, path=""):
-         # If path is empty, use the GSAI-ML/LLaDA-8B-Instruct model
-         if not path:
-             path = "GSAI-ML/LLaDA-8B-Instruct"
-
-         print(f"Loading model from {path}...")
-
-         # Load model with half precision to save memory
-         self.model = AutoModelForCausalLM.from_pretrained(
-             path, torch_dtype=torch.float16, device_map="auto"
          )
  
-         # Load tokenizer
-         self.tokenizer = AutoTokenizer.from_pretrained(path)
-
-         # Load processor for handling images
-         self.processor = AutoProcessor.from_pretrained(path)
-
-         # Ensure pad token is properly set
-         if self.tokenizer.pad_token_id is None:
-             if (
-                 hasattr(self.tokenizer, "eos_token_id")
-                 and self.tokenizer.eos_token_id is not None
-             ):
-                 self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
-                 self.tokenizer.pad_token = self.tokenizer.eos_token
-             else:
-                 # Fallback to a common pad token
-                 self.tokenizer.pad_token_id = 0
-                 self.tokenizer.pad_token = self.tokenizer.convert_ids_to_tokens(0)
-
-         print(f"Model loaded successfully. Pad token ID: {self.tokenizer.pad_token_id}")
-
-     def _load_image(self, image_data: Union[str, bytes]) -> Image.Image:
-         """Load image from URL or base64 encoded string"""
-         if isinstance(image_data, str):
-             if image_data.startswith("http"):
-                 # Load from URL
-                 response = requests.get(image_data, stream=True)
-                 response.raise_for_status()
-                 return Image.open(BytesIO(response.content))
-             elif image_data.startswith("data:image"):
-                 # Handle base64 encoded image
-                 base64_data = image_data.split(",")[1]
-                 image_bytes = base64.b64decode(base64_data)
-                 return Image.open(BytesIO(image_bytes))
-             else:
-                 # Assume it's a base64 string without the prefix
-                 try:
-                     image_bytes = base64.b64decode(image_data)
-                     return Image.open(BytesIO(image_bytes))
-                 except Exception as e:
-                     raise ValueError(f"Invalid image data format: {e}")
-         elif isinstance(image_data, bytes):
-             return Image.open(BytesIO(image_data))
-         else:
-             raise ValueError(f"Unsupported image data type: {type(image_data)}")
-
-     def _format_prompt(self, text: str, system_prompt: Optional[str] = None) -> str:
-         """Format the prompt according to LLaDA's expected format"""
-         # Default system prompt for LLaDA if none provided
-         if system_prompt is None:
-             system_prompt = (
-                 "You are a helpful AI assistant that can understand images and text."
-             )
-
-         # Format the prompt following LLaDA's expected structure
-         formatted_prompt = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant\n"
-         return formatted_prompt
-
-     def __call__(self, data: Dict[str, Any]) -> Dict[str, List[Any]]:
-         """Handle inference requests"""
-         # Extract inputs and parameters from request data
-         inputs = data.pop("inputs", data)
-         parameters = data.pop("parameters", {})
-
-         # Extract image data if present
-         image_data = parameters.get("image", None)
-         system_prompt = parameters.get("system_prompt", None)
-
-         # Extract generation parameters with sensible defaults
-         max_new_tokens = parameters.get("max_new_tokens", 256)
-         temperature = parameters.get("temperature", 0.7)
-         top_p = parameters.get("top_p", 0.95)
-         do_sample = parameters.get("do_sample", True)
-
-         # Convert single string input to list for consistent handling
-         if isinstance(inputs, str):
-             inputs = [inputs]
-
-         # Process each input
-         generated_texts = []
-         for input_text in inputs:
-             # Format the prompt according to LLaDA's expected format
-             formatted_prompt = self._format_prompt(input_text, system_prompt)
-
-             if image_data:
-                 try:
-                     # Process image if present
-                     image = self._load_image(image_data)
-                     inputs_processor = self.processor(
-                         text=formatted_prompt, images=image, return_tensors="pt"
-                     )
-
-                     # Move inputs to the same device as the model
-                     for k, v in inputs_processor.items():
-                         if isinstance(v, torch.Tensor):
-                             inputs_processor[k] = v.to(self.model.device)
-
-                     # Generate text with image context
-                     with torch.no_grad():
-                         outputs = self.model.generate(
-                             **inputs_processor,
-                             max_new_tokens=max_new_tokens,
-                             temperature=temperature,
-                             top_p=top_p,
-                             do_sample=do_sample,
-                             pad_token_id=self.tokenizer.pad_token_id,
-                         )
-
-                     # Decode generated text
-                     generated_text = self.tokenizer.decode(
-                         outputs[0], skip_special_tokens=True
-                     )
-                     generated_texts.append(generated_text)
-
-                 except Exception as e:
-                     # If image processing fails, fall back to text-only
-                     print(
-                         f"Error processing image: {e}. Falling back to text-only processing."
-                     )
-                     image_data = None
-
-             if not image_data:
-                 # Text-only processing
-                 input_tokens = self.tokenizer(formatted_prompt, return_tensors="pt").to(
-                     self.model.device
-                 )
-
-                 # Generate text
-                 with torch.no_grad():
-                     outputs = self.model.generate(
-                         **input_tokens,
-                         max_new_tokens=max_new_tokens,
-                         temperature=temperature,
-                         top_p=top_p,
-                         do_sample=do_sample,
-                         pad_token_id=self.tokenizer.pad_token_id,
-                     )
  
-                 # Decode generated text
-                 generated_text = self.tokenizer.decode(
-                     outputs[0], skip_special_tokens=True
-                 )
-                 generated_texts.append(generated_text)
  
-         # Return results in expected format
-         return {"generated_text": generated_texts}
 
 
+ from typing import Any, Dict
  
  
  class EndpointHandler:
+     def __init__(self, model_dir: str, **kwargs: Any) -> None:
+         self.model = AutoModel.from_pretrained(
+             model_dir,
+             torch_dtype=torch.bfloat16,
+             low_cpu_mem_usage=True,
+             use_flash_attn=False,
+             trust_remote_code=True,
+             device_map=split_model(),
+         ).eval()
+
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             model_dir, trust_remote_code=True, use_fast=False
          )
  
+     def __call__(self, data: Dict[str, Any]) -> Any:
+         logger.info(f"Received incoming request with {data=}")
  
  
+ if __name__ == "__main__":
+     handler = EndpointHandler(model_dir="GSAI-ML/LLaDA-8B-Instruct")
+     print(handler)
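
Note that, as committed, the new handler.py cannot run on its own: it uses AutoModel, AutoTokenizer, torch, logger, and a split_model() helper, yet imports only from typing. A minimal sketch of the missing definitions is below. The names split_model and logger come from the diff itself; their bodies here are illustrative guesses, not part of the commit. In particular, split_model() is assumed to produce a device_map for from_pretrained (the tutorial's version presumably spreads layers across GPUs; returning "auto" simply defers placement to Accelerate).

# Missing pieces assumed by the committed handler.py above.
# Bodies are illustrative guesses, not part of the commit.
import logging

import torch
from transformers import AutoModel, AutoTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def split_model():
    # Stand-in for the tutorial's device-map helper. A real version
    # would likely return a per-layer {module_name: gpu_id} dict;
    # "auto" lets Accelerate place the weights instead.
    return "auto"

With these definitions placed above the class, `python handler.py` loads the model and prints the handler instance; __call__ as committed only logs the incoming request and returns nothing.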