thisnick committed
Commit 383b031 · verified · Parent: ed84920

Upload full model folder with custom handler

Files changed (2):
  1. README.md +0 -1
  2. handler.py +34 -25
README.md CHANGED
@@ -2,7 +2,6 @@
 base_model:
 - meta-llama/Llama-3.1-8B-Instruct
 - google/siglip-so400m-patch14-384
-- fancyfeast/llama-joycaption-alpha-two-hf-llava
 tags:
 - captioning
 ---
handler.py CHANGED
@@ -16,14 +16,21 @@ class EndpointHandler():
         self.model.eval()
 
     def __call__(self, data):
-        # Expecting data with a "prompt" (text) and an "image" (base64 string)
-        prompt = data.get("prompt", "Generate a caption for this image.")
-        image_b64 = data.get("image")
-        if image_b64 is None:
-            return {"error": "No image provided in the payload."}
+        inputs = data.get("inputs", {})
+        prompt = inputs.get("prompt", "Generate a caption for this image.")
+        images_b64 = inputs.get("images")
+
+        # Handle both single image and list of images
+        if isinstance(images_b64, str):
+            images_b64 = [images_b64]
+        if not images_b64:
+            return {"error": "No images provided in the payload."}
+
         try:
-            image_bytes = base64.b64decode(image_b64)
-            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+            images = [
+                Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
+                for img_b64 in images_b64
+            ]
         except Exception as e:
             return {"error": f"Failed to decode image: {str(e)}"}
 
@@ -41,32 +48,34 @@ class EndpointHandler():
         if not isinstance(convo_string, str):
             return {"error": "Failed to create conversation string."}
 
-        # Prepare the inputs for the model
-        inputs = self.processor(
+        # Prepare the inputs for the model - process all images at once
+        model_inputs = self.processor(
             text=[convo_string],
-            images=[image],
+            images=images,
             return_tensors="pt"
         )
-        inputs = {k: v.to(self.device) for k, v in inputs.items()}
-        if "pixel_values" in inputs:
-            inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
+        model_inputs = {k: v.to(self.device) for k, v in model_inputs.items()}
+        if "pixel_values" in model_inputs:
+            model_inputs["pixel_values"] = model_inputs["pixel_values"].to(torch.bfloat16)
 
-        # Generate caption tokens
+        # Generate caption tokens for all images at once
         generate_ids = self.model.generate(
-            **inputs,
+            **model_inputs,
             max_new_tokens=300,
             do_sample=True,
             temperature=0.6,
             top_p=0.9
-        )[0]
-
-        # Optionally, trim off the prompt tokens
-        generate_ids = generate_ids[inputs["input_ids"].shape[1]:]
+        )
 
-        caption = self.processor.tokenizer.decode(
-            generate_ids,
-            skip_special_tokens=True,
-            clean_up_tokenization_spaces=False
-        ).strip()
+        # Trim off the prompt tokens and decode all captions
+        generate_ids = generate_ids[:, model_inputs["input_ids"].shape[1]:]
+        captions = [
+            self.processor.tokenizer.decode(
+                ids,
+                skip_special_tokens=True,
+                clean_up_tokenization_spaces=False
+            ).strip()
+            for ids in generate_ids
+        ]
 
-        return {"caption": caption}
+        return {"captions": captions if len(captions) > 1 else captions[0]}