adrish committed
Commit a5e6882 · 1 Parent(s): 1d493ce

added custom handler.py

Files changed (1):
  handler.py +109 -0
handler.py ADDED
@@ -0,0 +1,109 @@
+ import base64
+ import io
+ import os
+ from typing import Any, Dict, List
+
+ import torch
+ from PIL import Image
+ from transformers import AutoModelForImageTextToText, AutoProcessor
+
+
+ class EndpointHandler:
+     def __init__(self, model_path: str = None):
+         """
+         Initialize the endpoint handler by loading the ColPali model for
+         image-to-text generation. If no model path is provided, the model is
+         loaded from the directory containing this file, i.e. the repository
+         the endpoint is deployed from.
+         """
+         if model_path is None:
+             model_path = os.path.dirname(os.path.realpath(__file__))
+         try:
+             # Select GPU if available, otherwise fall back to CPU.
+             use_cuda = torch.cuda.is_available()
+             self.device = torch.device("cuda" if use_cuda else "cpu")
+             # Load the model with the generic image-text-to-text interface.
+             # FlashAttention 2 is requested only on GPU, since it is
+             # unsupported on CPU and requires the flash-attn package.
+             self.model = AutoModelForImageTextToText.from_pretrained(
+                 model_path,
+                 trust_remote_code=True,
+                 torch_dtype=torch.float16 if use_cuda else torch.float32,
+                 attn_implementation="flash_attention_2" if use_cuda else "eager",
+             ).to(self.device)
+             # Load the processor, which handles both image preprocessing and
+             # text tokenization.
+             self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
+         except Exception as e:
+             raise RuntimeError(f"Error loading model or processor: {e}")
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Process the input data for image-to-text generation.
+         Expects a dictionary with an "inputs" key containing a list of
+         dictionaries. Each dictionary should have:
+           - "image": a base64-encoded image string.
+           - "prompt": an optional text prompt (a default is used if missing).
+         """
+         try:
+             inputs_list = data.get("inputs", [])
+             config = data.get("config", {})
+
+             if not inputs_list or not isinstance(inputs_list, list):
+                 return {"error": "Inputs should be a list of dictionaries with 'image' and optionally 'prompt' keys."}
+
+             images: List[Image.Image] = []
+             texts: List[str] = []
+
+             for item in inputs_list:
+                 image_b64 = item.get("image")
+                 if not image_b64:
+                     return {"error": "One of the input items is missing 'image' data."}
+                 try:
+                     # Decode the base64 image and convert it to RGB.
+                     image = Image.open(io.BytesIO(base64.b64decode(image_b64))).convert("RGB")
+                     images.append(image)
+                 except Exception as e:
+                     return {"error": f"Failed to decode one of the images: {e}"}
+                 # Use the provided prompt or fall back to a default prompt.
+                 prompt = item.get("prompt", "Describe the image content in detail.")
+                 texts.append(prompt)
+
+             # Process both the text and image inputs via the processor.
+             model_inputs = self.processor(
+                 text=texts,
+                 images=images,
+                 padding=True,
+                 return_tensors="pt",
+             ).to(self.device)
+
+             # Generation configuration (can be overridden by the request).
+             max_new_tokens = config.get("max_new_tokens", 1000)
+             temperature = config.get("temperature", 0.8)
+             num_return_sequences = config.get("num_return_sequences", 1)
+             do_sample = bool(config.get("do_sample", True))
+
+             # Generate outputs with the model. Temperature is only passed
+             # when sampling is enabled, since it has no effect otherwise.
+             generate_kwargs = {
+                 "max_new_tokens": max_new_tokens,
+                 "num_return_sequences": num_return_sequences,
+                 "do_sample": do_sample,
+             }
+             if do_sample:
+                 generate_kwargs["temperature"] = temperature
+             outputs = self.model.generate(**model_inputs, **generate_kwargs)
+
+             # Decode the generated tokens into human-readable text.
+             text_output = self.processor.tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+             return {"responses": text_output}
+
+         except Exception as e:
+             return {"error": f"Unexpected error: {e}"}
+
+
+ # Instantiate the endpoint handler once at import time.
+ _service = EndpointHandler()
+
+
+ def handle(data, context):
+     """
+     Entry point for the Hugging Face dedicated inference endpoint.
+     It processes the input data and returns the model's generated responses.
+     """
+     try:
+         if data is None:
+             return {"error": "No input data received"}
+         return _service(data)
+     except Exception as e:
+         return {"error": f"Exception in handler: {e}"}
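
For reference, a minimal client sketch for exercising the deployed handler. This is an illustration rather than part of the commit: the endpoint URL, token, and image filename are placeholders, and the payload simply mirrors the "inputs"/"config" contract documented in EndpointHandler.__call__ above.

  import base64
  import requests

  # Placeholders: substitute your endpoint URL, access token, and image file.
  ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"
  HF_TOKEN = "hf_..."

  # Base64-encode a local image, matching the "image" field the handler expects.
  with open("document_page.png", "rb") as f:
      image_b64 = base64.b64encode(f.read()).decode("utf-8")

  payload = {
      "inputs": [
          {"image": image_b64, "prompt": "Describe the image content in detail."}
      ],
      "config": {"max_new_tokens": 256, "temperature": 0.7, "do_sample": True},
  }

  response = requests.post(
      ENDPOINT_URL,
      headers={
          "Authorization": f"Bearer {HF_TOKEN}",
          "Content-Type": "application/json",
      },
      json=payload,
  )
  # Expect {"responses": [...]} on success or {"error": "..."} on failure.
  print(response.json())

The same payload dictionary can also be passed to handle(payload, None) in a local Python session to smoke-test the handler before deploying it.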