paultltc committed
Commit a924f05 · 1 Parent(s): 596b5f2

init commit

Files changed (4)
  1. Dockerfile +60 -0
  2. app.py +104 -0
  3. requirements.txt +13 -0
  4. tool.py +344 -0
Dockerfile ADDED
@@ -0,0 +1,60 @@
+ FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04
+ ENV DEBIAN_FRONTEND=noninteractive
+ RUN apt-get update && \
+     apt-get upgrade -y && \
+     apt-get install -y --no-install-recommends \
+     git \
+     git-lfs \
+     wget \
+     curl \
+     # python build dependencies \
+     build-essential \
+     libssl-dev \
+     zlib1g-dev \
+     libbz2-dev \
+     libreadline-dev \
+     libsqlite3-dev \
+     libncursesw5-dev \
+     xz-utils \
+     tk-dev \
+     libxml2-dev \
+     libxmlsec1-dev \
+     libffi-dev \
+     liblzma-dev \
+     # gradio dependencies \
+     ffmpeg \
+     poppler-utils \
+     && apt-get clean \
+     && rm -rf /var/lib/apt/lists/*
+
+
+ RUN useradd -m -u 1000 user
+ USER user
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:${PATH}
+ WORKDIR ${HOME}/app
+
+ RUN curl https://pyenv.run | bash
+ ENV PATH=${HOME}/.pyenv/shims:${HOME}/.pyenv/bin:${PATH}
+ ARG PYTHON_VERSION=3.10.12
+ RUN pyenv install ${PYTHON_VERSION} && \
+     pyenv global ${PYTHON_VERSION} && \
+     pyenv rehash && \
+     pip install --no-cache-dir -U pip setuptools wheel && \
+     pip install packaging ninja
+
+ COPY --chown=1000 ./requirements.txt /tmp/requirements.txt
+ RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt && \
+     pip install flash-attn --no-build-isolation
+
+
+
+ COPY --chown=1000 . ${HOME}/app
+ ENV PYTHONPATH=${HOME}/app \
+     PYTHONUNBUFFERED=1 \
+     GRADIO_ALLOW_FLAGGING=never \
+     GRADIO_NUM_PORTS=1 \
+     GRADIO_SERVER_NAME=0.0.0.0 \
+     GRADIO_THEME=huggingface \
+     SYSTEM=spaces
+ CMD ["python", "app.py"]
app.py ADDED
@@ -0,0 +1,104 @@
+ import gradio as gr
+
+ from tool import VisualRAGTool
+
+ tool = VisualRAGTool()
+
+ def search(query, k, api_key):
+     """Searches for the most relevant pages based on the query."""
+     print("=============== SEARCHING ===============")
+
+     context, answer = tool.search(query, k, api_key)
+
+     context_gallery = [(page.image, page.caption) for page in context]
+
+     print("========================================")
+
+     return gr.Gallery(value=context_gallery, label="Retrieved Documents", height=400, show_label=True, visible=True), answer
+
+ def index(files, contextualize_embeds, api_key):
+     """Indexes the uploaded files."""
+     print("=============== INDEXING ===============")
+
+     indexed_files_num = tool.index(
+         files=files,
+         contextualize=contextualize_embeds,
+         api_key=api_key,
+         overwrite_db=True
+     )
+
+     print("========================================")
+     return gr.Textbox(f"Uploaded and processed {indexed_files_num} pages!"),\
+         gr.Textbox(
+             lines=2,
+             label="Query",
+             show_label=False,
+             placeholder="Enter your prompt here and press Shift+Enter or press the button",
+             interactive=True,
+         )
+
+ def show_processing_status():
+     """Updates the upload status."""
+     return gr.Textbox(label="Processing Status", interactive=False, visible=True),\
+         gr.Checkbox(label="Contextualize Embeddings", visible=False),\
+         gr.Textbox(
+             lines=2,
+             label="Query",
+             show_label=False,
+             placeholder="Enter your prompt here and press Shift+Enter or press the button",
+             interactive=False,
+         )
+
+ with gr.Blocks(
+     theme=gr.themes.Ocean(),
+     title="ColPali Tool Demo",
+ ) as demo:
+     gr.Markdown("""\
+ # ColPali: Efficient Document Retrieval with Vision Language Models (ColQwen2) 📚
+ Demo to test the ColPali RAG tool powered by ColQwen2 on PDF documents.
+ ColPali is implemented from the [ColPali paper](https://arxiv.org/abs/2407.01449).
+
+ This tool allows you to upload PDF files and search for the most relevant pages based on your query.
+ Refresh the page if you change documents!
+
+ ⚠️ This demo uses a model trained exclusively on A4 PDFs in portrait mode, containing English text. Performance is expected to drop for other page formats and languages.
+ Other models will be released with better robustness towards different languages and document formats!
+ """)
+
+     api_key = gr.Textbox(placeholder="Enter your OpenAI KEY here (optional)", label="API key")
+
+     stored_messages = gr.State(value=[])
+
+     gr.Markdown("## 1️⃣ Upload PDFs")
+     gr.Markdown("Upload PDF files to index and search.")
+     with gr.Group():
+         contextualize_embeds = gr.Checkbox(label="Contextualize Embeddings", info="Add images' surrounding context as metadata. Generated using gpt-4o-mini. ⚠️ Indexing will take longer!", value=True)
+         upload_files = gr.File(file_types=[".pdf"], file_count="multiple", label="Upload files")
+         processing_status = gr.Textbox(label="Processing Status", interactive=False, visible=False)
+
+     gr.Markdown("## 2️⃣ Search")
+     gr.Markdown("Ask a question relevant to the documents you uploaded.")
+     with gr.Group():
+         chatbot = gr.Textbox(label="AI Assistant", placeholder="Generated response based on retrieved documents.", lines=6)
+         output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True, visible=False)
+         with gr.Row(equal_height=True):
+             with gr.Column(scale=4):
+                 text_input = gr.Textbox(
+                     lines=2,
+                     label="Query",
+                     show_label=False,
+                     placeholder="Enter your prompt here and press Shift+Enter or press the button",
+                 )
+             with gr.Column(scale=1):
+                 k = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Pages to retrieve")
+                 search_button = gr.Button("🔍 Search", variant="primary")
+
+     # Define the flow of the demo
+     # upload_files.change(index, inputs=[upload_files, api_key], outputs=[upload_status])
+     upload_files.change(show_processing_status, inputs=[], outputs=[processing_status, contextualize_embeds, text_input])\
+         .then(index, inputs=[upload_files, contextualize_embeds, api_key], outputs=[processing_status, text_input])
+     text_input.submit(search, inputs=[text_input, k, api_key], outputs=[output_gallery, chatbot])
+     search_button.click(search, inputs=[text_input, k, api_key], outputs=[output_gallery, chatbot])
+
+ if __name__ == "__main__":
+     demo.queue(max_size=5).launch(debug=True, server_name="0.0.0.0")
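
For reference, the same indexing and search flow can be driven without the Gradio UI by calling the tool's methods directly. A minimal sketch under assumptions: the PDF path and API key below are placeholders, and loading the model needs the same GPU / flash-attn environment the Dockerfile sets up.

from tool import VisualRAGTool

# Load ColQwen2 (vidore/colqwen2-v1.0) and its processor
tool = VisualRAGTool()

# Convert the PDF to page images and embed them; contextualize=False skips the gpt-4o-mini context pass
num_pages = tool.index(files=["my_doc.pdf"], contextualize=False, overwrite_db=True)  # "my_doc.pdf" is a placeholder
print(f"Indexed {num_pages} pages")

# Retrieve the 3 best-matching pages and generate an answer (requires a valid OpenAI key)
context, answer = tool.search("What is the main conclusion?", k=3, api_key="sk-...")
for page in context:
    print(page.caption)
print(answer)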
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ colpali-engine==0.3.8
+ pdf2image
+ GPUtil
+ accelerate==0.30.1
+ openai
+ gradio
+ gradio_client
+ tqdm
+ Pillow
+ pqdm
+ smolagents
+ pyyaml
+ python-dotenv
tool.py ADDED
@@ -0,0 +1,344 @@
+ from dataclasses import dataclass
+ from typing import List, Optional, Tuple
+
+ import torch
+ from torch.utils.data import DataLoader, Dataset
+
+ import base64
+ from io import BytesIO
+ from PIL import Image
+ from pdf2image import convert_from_path
+
+ from tqdm import tqdm
+ from pqdm.processes import pqdm
+
+ from colpali_engine.models import ColQwen2, ColQwen2Processor
+
+ from smolagents import Tool, ChatMessage
+
+ from dotenv import load_dotenv
+ load_dotenv()
+
+ def encode_image_to_base64(image):
+     """Encodes a PIL image to a base64 string."""
+     buffered = BytesIO()
+     image.save(buffered, format="JPEG")
+     return base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+ DEFAULT_SYSTEM_PROMPT = \
+ """You are a smart assistant designed to answer questions about a PDF document.
+ You are given relevant information in the form of PDF pages preceded by their metadata: document title, page identifier, surrounding context.
+ Use them to construct a short response to the question, and cite your sources in the following format: (document, page number).
+ If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.
+ Give detailed and extensive answers, only containing info in the pages you are given.
+ You can answer using information contained in plots and figures if necessary.
+ Answer in the same language as the query."""
+
+ def _build_query(query, pages):
+     messages = []
+     messages.append({"type": "text", "text": "PDF pages:\n"})
+     for page in pages:
+         capt = page.caption
+         if capt is not None:
+             messages.append({
+                 "type": "text",
+                 "text": capt
+             })
+         messages.append({
+             "type": "image_url",
+             "image_url": {
+                 "url": f"data:image/jpeg;base64,{encode_image_to_base64(page.image)}"
+             },
+         })
+     messages.append({"type": "text", "text": f"Query:\n{query}"})
+
+     return messages
+
+ def query_openai(query, pages, api_key=None, system_prompt=DEFAULT_SYSTEM_PROMPT, model="gpt-4o-mini") -> ChatMessage:
+     """Calls OpenAI's gpt-4o-mini with the query and image data."""
+     if api_key and api_key.startswith("sk"):
+         try:
+             from openai import OpenAI
+
+             client = OpenAI(api_key=api_key.strip())
+
+             response = client.chat.completions.create(
+                 model=model,
+                 messages=[
+                     {
+                         "role": "system",
+                         "content": system_prompt
+                     },
+                     {
+                         "role": "user",
+                         "content": _build_query(query, pages)
+                     }
+                 ],
+                 max_tokens=500,
+             )
+
+             message = ChatMessage.from_dict(
+                 response.choices[0].message.model_dump(include={"role", "content", "tool_calls"})
+             )
+             message.raw = response
+
+             return message
+
+         except Exception as e:
+             return "OpenAI API connection failure. Verify the provided key is correct (sk-***)."
+
+     return "Enter your OpenAI API key to get a custom response"
+
+ DEFAULT_CONTEXT_PROMPT = \
+ """You are a smart assistant designed to extract the context of PDF pages.
+ Give concise answers, only containing info in the pages you are given.
+ You can answer using information contained in plots and figures if necessary."""
+
+ RAG_SYSTEM_PROMPT = \
+ """You are a smart assistant designed to answer questions about a PDF document.
+
+ You are given relevant information in the form of PDF pages preceded by their metadata: document title, page identifier, surrounding context.
+ Use them to construct a response to the question, and cite your sources.
+ Use the following citation format:
+ "Some information from a first document [1, p.Page Number]. Some information from the same first document but at a different page [1, p.Page Number]. Some more information from another document [2, p.Page Number].
+ ...
+ Sources:
+ [1] Document Title
+ [2] Another Document Title"
+
+ You can answer using information contained in plots and figures if necessary.
+ If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.
+ Give detailed answers, only containing info in the pages you are given.
+ Answer in the same language as the query."""
+
+ @dataclass
+ class Metadata:
+     doc_title: str
+     page_id: int
+     context: Optional[str] = None
+
+     def __str__(self):
+         return f"Document: {self.doc_title}, Page ID: {self.page_id}, Context: {self.context}"
+
+ @dataclass
+ class Page:
+     image: Image.Image
+     metadata: Optional[Metadata] = None
+
+     @property
+     def caption(self):
+         if self.metadata is None:
+             return None
+         return f"Document: {self.metadata.doc_title}, Context: {self.metadata.context}"
+
+ class VisualRAGTool(Tool):
+     name = "visual_rag"
+     description = """Performs a RAG query on your internal PDF documents and returns the generated text response."""
+     inputs = {
+         "query": {
+             "type": "string",
+             "description": "The query to perform. This should be semantically close to your target documents.",
+         },
+         "k": {
+             "type": "number",
+             "description": "The number of documents to retrieve.",
+             "default": 1,
+             "nullable": True,
+         },
+         "api_key": {
+             "type": "string",
+             "description": "The OpenAI API key to use for the query. If not provided, the key will be taken from the OPENAI_KEY environment variable.",
+             "nullable": True,
+         }
+     }
+     output_type = "string"
+
+     def _init_models(self, model_name: str) -> None:
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.model = ColQwen2.from_pretrained(
+             model_name,
+             torch_dtype=torch.bfloat16,
+             device_map="auto",
+             attn_implementation="flash_attention_2"
+         ).eval()
+         self.processor = ColQwen2Processor.from_pretrained(model_name)
+
+     def __init__(self, model_name: str = "vidore/colqwen2-v1.0", api_key: str = None, files: List[str] = None, **kwargs):
+         super().__init__(**kwargs)
+         self.model_name = model_name
+         self.api_key = api_key
+
+         self.embds = []
+         self.pages = []
+
+         self.files = files
+
+         self._init_models(self.model_name)
+
+         self.is_initialized = False
+
+     def setup(self):
+         """
+         Overwrite this method for any operation that is expensive and needs to be executed before you start using
+         the tool, such as loading a big model.
+         """
+         if self.files:
+             _ = self.index(self.files, api_key=self.api_key)
+
+         self.is_initialized = True
+
+     def _extract_contexts(self, images, api_key, window=10) -> List[str]:
+         """Extracts context from images."""
+         try:
+             args = [
+                 {
+                     'query': "Give the general context about these pages. Give the context in the same language as the documents.",
+                     'pages': [Page(image=im) for im in images[max(i-window+1, 0):i+1]],
+                     'api_key': api_key,
+                     'system_prompt': DEFAULT_CONTEXT_PROMPT
+                 } for i in range(0, len(images), window)
+             ]
+             window_contexts = pqdm(args, query_openai, n_jobs=8, argument_type='kwargs')
+
+             # sequential version (for the moment) with tqdm
+             # query = "Give the general context about these pages. Give the context in the same language as the documents."
+             # window_contexts = [query_openai(query, [Page(image=im) for im in images[max(i-window+1, 0):i+1]], api_key, DEFAULT_CONTEXT_PROMPT)\
+             #                    for i in tqdm(range(0, len(images), window))]
+
+             contexts = []
+             for i in range(len(images)):
+                 context = window_contexts[i//window].content
+                 contexts.append(context)
+
+         except Exception as e:
+             print(f"Error extracting contexts: {e}")
+             contexts = [None for _ in range(len(images))]
+
+         # Ensure that the number of contexts is equal to the number of images
+         assert len(contexts) == len(images)
+
+         return contexts
+
+     def _process_file(self, file: str, contextualize: bool = True, api_key: str = None, window: int = 10) -> List[Page]:
+         """Converts a file to images and extracts metadata."""
+         title = file.split("/")[-1]
+         images = convert_from_path(file, thread_count=4)
+         if contextualize and api_key:
+             contexts = self._extract_contexts(images, api_key, window=window)
+         else:
+             contexts = [None for _ in range(len(images))]
+         metadatas = [Metadata(doc_title=title, page_id=i, context=contexts[i]) for i in range(len(images))]
+
+         return [Page(image=img, metadata=metadata) for img, metadata in zip(images, metadatas)]
+
+     def preprocess(self, files: List[str], contextualize: bool = True, api_key: str = None, window: int = 10) -> List[Page]:
+         """Preprocesses the files and extracts metadata."""
+         pages = [page for file in files for page in self._process_file(file, contextualize=contextualize, api_key=api_key, window=window)]
+
+         print(f"Example metadata:\n{pages[0].metadata.context}")
+
+         return pages
+
+     def _embed_images(self, pages: List[Page]) -> List[torch.Tensor]:
+         """Embeds the images using the model (ColQwen2 late-interaction embeddings)."""
+         # Run inference on the document pages in batches
+         dataloader = DataLoader(
+             pages,
+             batch_size=4,
+             shuffle=False,
+             collate_fn=lambda x: self.processor.process_images([p.image for p in x]).to(self.device),
+         )
+
+         embds = []
+
+         for batch_doc in tqdm(dataloader):
+             with torch.no_grad():
+                 batch_doc = {k: v.to(self.device) for k, v in batch_doc.items()}
+                 embeddings_doc = self.model(**batch_doc)
+             embds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
+
+         return embds
+
+     def index(self, files: List[str], contextualize: bool = True, api_key: str = None, overwrite_db: bool = False) -> int:
+         print("Converting files...")
+         # Convert files to images and extract metadata
+         pgs = self.preprocess(files, contextualize=contextualize, api_key=api_key or self.api_key)
+
+         # Embed the images
+         embds = self._embed_images(pgs)
+
+         # Overwrite the database if necessary
+         if overwrite_db:
+             self.pages = []
+             self.embds = []
+
+         # Extend the pages
+         self.pages.extend(pgs)
+
+         # Extend the embeddings
+         self.embds.extend(embds)
+
+         print(f"Extracted and indexed {len(pgs)} images from {len(files)} files.")
+
+         return len(embds)
+
+     def retrieve(self, query: str, k: int) -> List[Page]:
+         """Retrieves the top k documents based on the query."""
+         k = min(k, len(self.embds))
+
+         qs = []
+         with torch.no_grad():
+             batch_query = self.processor.process_queries([query]).to(self.model.device)
+             embeddings_query = self.model(**batch_query)
+             qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))
+
+         # Run scoring
+         scores = self.processor.score(qs, self.embds, device=self.device)[0]
+         top_k_idx = scores.topk(k).indices.tolist()
+
+         print("Top Scores:")
+         for idx in top_k_idx:
+             print(f'Page {self.pages[idx].metadata.page_id}: {scores[idx]}')
+
+         # Get the top k results
+         results = [self.pages[idx] for idx in top_k_idx]
+
+         return results
+
+     def generate_answer(self, query: str, docs: List[Page], api_key: str = None) -> ChatMessage:
+         result = query_openai(query, docs, api_key or self.api_key, system_prompt=RAG_SYSTEM_PROMPT)
+         return result
+
+     def search(self, query: str, k: int = 1, api_key: str = None) -> Tuple[list, str]:
+         print(f"Searching for query: {query}")
+
+         # Retrieve the top k documents
+         context = self.retrieve(query, k)
+
+         # Generate the response with gpt-4o-mini
+         rag_answer = self.generate_answer(
+             query=query,
+             docs=context,
+             api_key=api_key
+         )
+
+         return context, rag_answer.content
+
+     def forward(self, query: str, k: int = 1, api_key: str = None) -> str:
+         assert isinstance(query, str), "Your search query must be a string"
+
+         # Online indexing
+         # if files:
+         #     _ = self.index(files, api_key)
+
+         # Retrieve the top k documents and generate the response
+         _, rag_answer = self.search(
+             query=query,
+             k=k,
+             api_key=api_key
+         )
+
+         return rag_answer
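
Since VisualRAGTool subclasses smolagents' Tool (with name, description, inputs and output_type defined above), it can also be registered with an agent. A minimal sketch, assuming a smolagents release exposing CodeAgent and HfApiModel (model class names vary across versions) and a hypothetical report.pdf; files passed at construction are indexed by setup() before the tool is first used.

from smolagents import CodeAgent, HfApiModel

from tool import VisualRAGTool

# Documents given here are indexed in setup() when the tool is first used
rag_tool = VisualRAGTool(files=["report.pdf"], api_key="sk-...")  # path and key are placeholders

# The agent can decide when to call the "visual_rag" tool to answer document questions
agent = CodeAgent(tools=[rag_tool], model=HfApiModel())
print(agent.run("Summarize the key findings of the indexed report."))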