amitsinghchandel committed
Commit b69b215 · 1 Parent(s): fe70f68

Update space with new code
README.md CHANGED
@@ -1,12 +1,8 @@
- ---
- title: Multimodal RAG
- emoji: 💬
- colorFrom: yellow
- colorTo: purple
- sdk: gradio
- sdk_version: 5.0.1
- app_file: app.py
- pinned: false
- ---
-
- An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
+ MultiModal RAG with ColPali and Milvus
+ ===
+
+ Code for the blog post [https://saumitra.me/2024/2024-11-15-colpali-milvus-rag/](https://saumitra.me/2024/2024-11-15-colpali-milvus-rag/) on how to do multimodal RAG with [ColPali](https://arxiv.org/abs/2407.01449), [Milvus](https://milvus.io/), and a visual LLM (Gemini/GPT-4o).
+
+ Demo running at [https://huggingface.co/spaces/saumitras/colpali-milvus](https://huggingface.co/spaces/saumitras/colpali-milvus)
+
+ The application lets users upload a PDF and then run search or Q&A queries against both the text and the visual elements of the document. No text is extracted from the PDF; instead, each page is treated as an image and embedded with ColPali. The embeddings are indexed in Milvus, and a visual LLM (Gemini/GPT-4o) answers the Q&A queries.
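As a rough end-to-end sketch of that flow, using the `Middleware` and `Rag` classes added in this commit (the PDF name and query are placeholders; assumes `GEMINI_API_KEY` is set and the dependencies below are installed):

```python
from FAPS.middleware import Middleware
from FAPS.rag import Rag

user_id = "demo-user"

# Index the first 10 pages of a PDF as page images (no text extraction).
middleware = Middleware(user_id, create_collection=True)
middleware.index(pdf_path="report.pdf", id=user_id, max_pages=10)

# Embed the query with ColPali, retrieve the best-matching page from Milvus,
# then ask the visual LLM about that page image.
query = "What does the revenue chart show?"
score, page_idx = middleware.search([query])[0][0]
img_path = f"pages/{user_id}/page_{page_idx + 1}.png"
print(Rag().get_answer_from_gemini(query, [img_path]))
```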
__init__.py ADDED
File without changes
__pycache__/__init__.cpython-310.pyc ADDED
Binary file (118 Bytes).
__pycache__/colpali_manager.cpython-310.pyc ADDED
Binary file (3.78 kB).
__pycache__/middleware.cpython-310.pyc ADDED
Binary file (2.11 kB).
__pycache__/milvus_manager.cpython-310.pyc ADDED
Binary file (5.27 kB).
__pycache__/pdf_manager.cpython-310.pyc ADDED
Binary file (1.57 kB).
__pycache__/rag.cpython-310.pyc ADDED
Binary file (1.28 kB).
__pycache__/utils.cpython-310.pyc ADDED
Binary file (371 Bytes).
app.py CHANGED
@@ -1,64 +1,135 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
+ import tempfile
+ import os
+ import fitz  # PyMuPDF
+ import uuid
+
+ from FAPS.middleware import Middleware
+ from FAPS.rag import Rag
+
+ rag = Rag()
+
+ def generate_uuid(state):
+     # Reuse the session UUID if one is already set; otherwise generate one.
+     if state["user_uuid"] is None:
+         state["user_uuid"] = str(uuid.uuid4())
+
+     return state["user_uuid"]
+
+
+ class PDFSearchApp:
+     def __init__(self):
+         self.indexed_docs = {}
+         self.current_pdf = None
+
+     def upload_and_convert(self, state, file, max_pages):
+         id = generate_uuid(state)
+
+         if file is None:
+             return "No file uploaded"
+
+         print(f"Uploading file: {file.name}, id: {id}")
+
+         try:
+             self.current_pdf = file.name
+
+             middleware = Middleware(id, create_collection=True)
+
+             pages = middleware.index(pdf_path=file.name, id=id, max_pages=max_pages)
+
+             self.indexed_docs[id] = True
+
+             return f"Uploaded and extracted {len(pages)} pages"
+         except Exception as e:
+             return f"Error processing PDF: {str(e)}"
+
+     def search_documents(self, state, query, num_results=1):
+         print(f"Searching for query: {query}")
+         id = generate_uuid(state)
+
+         # .get() avoids a KeyError for sessions that never indexed a document.
+         if not self.indexed_docs.get(id):
+             print("Please index documents first")
+             return "Please index documents first", "--"
+         if not query:
+             print("Please enter a search query")
+             return "Please enter a search query", "--"
+
+         try:
+             middleware = Middleware(id, create_collection=False)
+
+             search_results = middleware.search([query])[0]
+
+             # search() returns (score, doc_id) pairs; doc_id is the 0-based page index.
+             page_num = search_results[0][1] + 1
+
+             print(f"Retrieved page number: {page_num}")
+
+             img_path = f"pages/{id}/page_{page_num}.png"
+
+             print(f"Retrieved image path: {img_path}")
+
+             rag_response = rag.get_answer_from_gemini(query, [img_path])
+
+             return img_path, rag_response
+
+         except Exception as e:
+             return f"Error during search: {str(e)}", "--"
+
+ def create_ui():
+     app = PDFSearchApp()
+
+     with gr.Blocks() as demo:
+         state = gr.State(value={"user_uuid": None})
+
+         gr.Markdown("# ColPali Milvus Multimodal RAG Demo")
+         gr.Markdown("This demo shows how to use [ColPali](https://github.com/illuin-tech/colpali) embeddings with [Milvus](https://milvus.io/) and a Gemini/OpenAI visual LLM for multimodal PDF search and Q&A.")
+
+         with gr.Tab("Upload PDF"):
+             with gr.Column():
+                 file_input = gr.File(label="Upload PDF")
+
+                 max_pages_input = gr.Slider(
+                     minimum=1,
+                     maximum=50,
+                     value=20,
+                     step=10,
+                     label="Max pages to extract and index"
+                 )
+
+                 status = gr.Textbox(label="Indexing Status", interactive=False)
+
+         with gr.Tab("Query"):
+             with gr.Column():
+                 query_input = gr.Textbox(label="Enter query")
+                 # num_results = gr.Slider(
+                 #     minimum=1,
+                 #     maximum=10,
+                 #     value=5,
+                 #     step=1,
+                 #     label="Number of results"
+                 # )
+                 search_btn = gr.Button("Query")
+                 llm_answer = gr.Textbox(label="RAG Response", interactive=False)
+                 images = gr.Image(label="Top page matching query")
+
+         # Event handlers
+         file_input.change(
+             fn=app.upload_and_convert,
+             inputs=[state, file_input, max_pages_input],
+             outputs=[status]
+         )
+
+         search_btn.click(
+             fn=app.search_documents,
+             inputs=[state, query_input],
+             outputs=[images, llm_answer]
+         )
+
+     return demo

  if __name__ == "__main__":
+     demo = create_ui()
      demo.launch()
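A hedged way to exercise the handlers without the UI, to see how the per-session `gr.State` dict drives `generate_uuid()` (the `FakeFile` stand-in and `sample.pdf` are illustrative; assumes the `FAPS` package imports resolve):

```python
from app import PDFSearchApp

class FakeFile:
    # Stand-in for the object gr.File hands to the callback; only .name is used.
    name = "sample.pdf"

state = {"user_uuid": None}  # what gr.State(value={...}) holds for one visitor
app = PDFSearchApp()

print(app.upload_and_convert(state, FakeFile(), max_pages=10))
print(app.search_documents(state, "What is the main conclusion?"))
# Both calls mutate the same state dict, so generate_uuid() returns one stable
# UUID per session and both handlers address the same per-user Milvus database.
```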
colpali_manager.py ADDED
@@ -0,0 +1,79 @@
+ from colpali_engine.models import ColPali
+ from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
+ from colpali_engine.utils.torch_utils import ListDataset
+ from torch.utils.data import DataLoader
+ import torch
+ import numpy as np
+ from typing import List, cast
+
+ from tqdm import tqdm
+ from PIL import Image
+
+
+ class ColpaliManager:
+     def __init__(self, device="cpu", model_name="vidore/colpali-v1.2"):
+         print(f"Initializing ColpaliManager with device {device} and model {model_name}")
+
+         # Load the model once, in float32 on CPU (no device map needed).
+         self.device = torch.device(device)
+         self.model = ColPali.from_pretrained(
+             model_name,
+             torch_dtype=torch.float32,
+             device_map=None,
+         ).to(self.device).eval()
+
+         self.processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
+
+     def get_images(self, paths: List[str]) -> List[Image.Image]:
+         return [Image.open(path) for path in paths]
+
+     def process_images(self, image_paths: List[str], batch_size=5) -> List[np.ndarray]:
+         print(f"Processing {len(image_paths)} image_paths")
+         images = self.get_images(image_paths)
+
+         dataloader = DataLoader(
+             dataset=ListDataset[Image.Image](images),
+             batch_size=batch_size,
+             shuffle=False,
+             collate_fn=lambda x: self.processor.process_images(x),
+         )
+
+         # One embedding matrix per page: a row per image patch/token.
+         embeddings = []
+         for batch_doc in tqdm(dataloader):
+             with torch.no_grad():
+                 batch_doc = {k: v.to(self.device) for k, v in batch_doc.items()}
+                 embeddings_batch = self.model(**batch_doc)
+             embeddings.extend(list(torch.unbind(embeddings_batch.to(self.device))))
+
+         return [embedding.float().cpu().numpy() for embedding in embeddings]
+
+     def process_text(self, texts: List[str], batch_size=1) -> List[np.ndarray]:
+         print(f"Processing {len(texts)} texts")
+
+         dataloader = DataLoader(
+             dataset=ListDataset[str](texts),
+             batch_size=batch_size,
+             shuffle=False,
+             collate_fn=lambda x: self.processor.process_queries(x),
+         )
+
+         # One embedding matrix per query: a row per query token.
+         embeddings = []
+         for batch_query in dataloader:
+             with torch.no_grad():
+                 batch_query = {k: v.to(self.device) for k, v in batch_query.items()}
+                 embeddings_batch = self.model(**batch_query)
+             embeddings.extend(list(torch.unbind(embeddings_batch.to(self.device))))
+
+         return [embedding.float().cpu().numpy() for embedding in embeddings]
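`ColpaliManager` returns one matrix per input rather than a single pooled vector: ColPali is a late-interaction (ColBERT-style) model, so every image patch and every query token keeps its own 128-dim embedding. A quick shape check (illustrative path and query; exact row counts depend on the processor):

```python
manager = ColpaliManager(device="cpu", model_name="vidore/colpali-v1.2")

page_embs = manager.process_images(["pages/demo/page_1.png"])
query_embs = manager.process_text(["where is the revenue table?"])

print(page_embs[0].shape)   # roughly (1030, 128): one vector per image patch/token
print(query_embs[0].shape)  # e.g. (22, 128): one vector per query token
```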
middleware.py ADDED
@@ -0,0 +1,56 @@
+ from FAPS.colpali_manager import ColpaliManager
+ from FAPS.milvus_manager import MilvusManager
+ from FAPS.pdf_manager import PdfManager
+ import hashlib
+
+
+ pdf_manager = PdfManager()
+ colpali_manager = ColpaliManager(device="cpu", model_name="vidore/colpali-v1.2")
+
+
+ class Middleware:
+     def __init__(self, id: str, create_collection=True):
+         # One Milvus Lite database file per session, keyed by a short hash of the id.
+         hashed_id = hashlib.md5(id.encode()).hexdigest()[:8]
+         milvus_db_name = f"milvus_{hashed_id}.db"
+         self.milvus_manager = MilvusManager(milvus_db_name, "colpali", create_collection)
+
+     def index(self, pdf_path: str, id: str, max_pages: int, pages: list[int] = None):
+         print(f"Indexing {pdf_path}, id: {id}, max_pages: {max_pages}")
+
+         image_paths = pdf_manager.save_images(id, pdf_path, max_pages)
+
+         print(f"Saved {len(image_paths)} images")
+
+         colbert_vecs = colpali_manager.process_images(image_paths)
+
+         images_data = [{
+             "colbert_vecs": colbert_vecs[i],
+             "filepath": image_paths[i]
+         } for i in range(len(image_paths))]
+
+         print(f"Inserting {len(images_data)} images data to Milvus")
+
+         self.milvus_manager.insert_images_data(images_data)
+
+         print("Indexing completed")
+
+         return image_paths
+
+     def search(self, search_queries: list[str]):
+         print(f"Searching for {len(search_queries)} queries")
+
+         final_res = []
+
+         for query in search_queries:
+             print(f"Searching for query: {query}")
+             query_vec = colpali_manager.process_text([query])[0]
+             search_res = self.milvus_manager.search(query_vec, topk=1)
+             print(f"Search result: {search_res} for query: {query}")
+             final_res.append(search_res)
+
+         return final_res
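Each session therefore gets its own database file, named from a short hash of the session id; passing a local `.db` path as the URI makes `MilvusClient` run embedded (Milvus Lite) rather than against a server. Illustrative values:

```python
import hashlib

user_id = "demo-user"  # in the app this is the session UUID
hashed_id = hashlib.md5(user_id.encode()).hexdigest()[:8]
print(f"milvus_{hashed_id}.db")  # one isolated Milvus Lite file per user
```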
milvus_manager.py ADDED
@@ -0,0 +1,162 @@
+ from pymilvus import MilvusClient, DataType
+ import numpy as np
+ import concurrent.futures
+
+
+ class MilvusManager:
+     def __init__(self, milvus_uri, collection_name, create_collection, dim=128):
+         self.client = MilvusClient(uri=milvus_uri)
+         self.collection_name = collection_name
+         if self.client.has_collection(collection_name=self.collection_name):
+             self.client.load_collection(collection_name)
+         self.dim = dim
+
+         if create_collection:
+             self.create_collection()
+             self.create_index()
+
+     def create_collection(self):
+         if self.client.has_collection(collection_name=self.collection_name):
+             self.client.drop_collection(collection_name=self.collection_name)
+         schema = self.client.create_schema(
+             auto_id=True,
+             enable_dynamic_fields=True,
+         )
+         schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True)
+         schema.add_field(
+             field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.dim
+         )
+         schema.add_field(field_name="seq_id", datatype=DataType.INT16)
+         schema.add_field(field_name="doc_id", datatype=DataType.INT64)
+         schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535)
+
+         self.client.create_collection(
+             collection_name=self.collection_name, schema=schema
+         )
+
+     def create_index(self):
+         self.client.release_collection(collection_name=self.collection_name)
+         self.client.drop_index(
+             collection_name=self.collection_name, index_name="vector"
+         )
+         index_params = self.client.prepare_index_params()
+         index_params.add_index(
+             field_name="vector",
+             index_name="vector_index",
+             index_type="HNSW",
+             metric_type="IP",
+             params={
+                 "M": 16,
+                 "efConstruction": 500,
+             },
+         )
+
+         self.client.create_index(
+             collection_name=self.collection_name, index_params=index_params, sync=True
+         )
+
+     def create_scalar_index(self):
+         self.client.release_collection(collection_name=self.collection_name)
+
+         index_params = self.client.prepare_index_params()
+         index_params.add_index(
+             field_name="doc_id",
+             index_name="int32_index",
+             index_type="INVERTED",
+         )
+
+         self.client.create_index(
+             collection_name=self.collection_name, index_params=index_params, sync=True
+         )
+
+     def search(self, data, topk):
+         # Stage 1: ANN search over individual token vectors to collect candidate pages.
+         search_params = {"metric_type": "IP", "params": {}}
+         results = self.client.search(
+             self.collection_name,
+             data,
+             limit=50,
+             output_fields=["vector", "seq_id", "doc_id"],
+             search_params=search_params,
+         )
+         doc_ids = set()
+         for r_id in range(len(results)):
+             for r in range(len(results[r_id])):
+                 doc_ids.add(results[r_id][r]["entity"]["doc_id"])
+
+         scores = []
+
+         # Stage 2: rerank each candidate page with the full late-interaction (MaxSim) score.
+         def rerank_single_doc(doc_id, data, client, collection_name):
+             doc_colbert_vecs = client.query(
+                 collection_name=collection_name,
+                 filter=f"doc_id in [{doc_id}, {doc_id + 1}]",
+                 output_fields=["seq_id", "vector", "doc"],
+                 limit=1000,
+             )
+             doc_vecs = np.vstack(
+                 [doc_colbert_vecs[i]["vector"] for i in range(len(doc_colbert_vecs))]
+             )
+             score = np.dot(data, doc_vecs.T).max(1).sum()
+             return (score, doc_id)
+
+         with concurrent.futures.ThreadPoolExecutor(max_workers=300) as executor:
+             futures = {
+                 executor.submit(
+                     rerank_single_doc, doc_id, data, self.client, self.collection_name
+                 ): doc_id
+                 for doc_id in doc_ids
+             }
+             for future in concurrent.futures.as_completed(futures):
+                 score, doc_id = future.result()
+                 scores.append((score, doc_id))
+
+         scores.sort(key=lambda x: x[0], reverse=True)
+         return scores[:topk]
+
+     def insert(self, data):
+         colbert_vecs = [vec for vec in data["colbert_vecs"]]
+         seq_length = len(colbert_vecs)
+         doc_ids = [data["doc_id"] for i in range(seq_length)]
+         seq_ids = list(range(seq_length))
+         docs = [""] * seq_length
+         docs[0] = data["filepath"]
+
+         # One row per token vector; the page's filepath is stored on the first row only.
+         self.client.insert(
+             self.collection_name,
+             [
+                 {
+                     "vector": colbert_vecs[i],
+                     "seq_id": seq_ids[i],
+                     "doc_id": doc_ids[i],
+                     "doc": docs[i],
+                 }
+                 for i in range(seq_length)
+             ],
+         )
+
+     def get_images_as_doc(self, images_with_vectors: list):
+         images_data = []
+
+         for i in range(len(images_with_vectors)):
+             data = {
+                 "colbert_vecs": images_with_vectors[i]["colbert_vecs"],
+                 "doc_id": i,
+                 "filepath": images_with_vectors[i]["filepath"],
+             }
+             images_data.append(data)
+
+         return images_data
+
+     def insert_images_data(self, image_data):
+         data = self.get_images_as_doc(image_data)
+
+         for i in range(len(data)):
+             self.insert(data[i])
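Retrieval here is two-stage: a standard ANN search over individual token vectors proposes candidate pages, then `rerank_single_doc()` scores each candidate with the MaxSim formula. A self-contained sketch of that scoring step on random data:

```python
import numpy as np

rng = np.random.default_rng(0)
query_vecs = rng.random((5, 128))    # 5 query-token vectors (illustrative)
page_vecs = rng.random((1030, 128))  # one page's patch vectors

# For each query token, take its best inner product over all page vectors,
# then sum the maxima: a page scores well if it covers every query token.
score = np.dot(query_vecs, page_vecs.T).max(1).sum()
print(score)
```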
packages.txt ADDED
@@ -0,0 +1 @@
+ poppler-utils
pdf_manager.py ADDED
@@ -0,0 +1,42 @@
+ from pdf2image import convert_from_path
+ import os
+ import shutil
+
+ class PdfManager:
+     def __init__(self):
+         pass
+
+     def clear_and_recreate_dir(self, output_folder):
+         print(f"Clearing output folder {output_folder}")
+
+         if os.path.exists(output_folder):
+             shutil.rmtree(output_folder)
+
+         os.makedirs(output_folder)
+
+     def save_images(self, id, pdf_path, max_pages, pages: list[int] = None) -> list[str]:
+         output_folder = f"pages/{id}"
+         images = convert_from_path(pdf_path)
+
+         print(f"Saving images from {pdf_path} to {output_folder}. Max pages: {max_pages}")
+
+         self.clear_and_recreate_dir(output_folder)
+
+         saved_paths = []
+
+         for i, image in enumerate(images):
+             if max_pages and len(saved_paths) >= max_pages:
+                 break
+
+             if pages and i not in pages:
+                 continue
+
+             full_save_path = f"{output_folder}/page_{i + 1}.png"
+
+             image.save(full_save_path, "PNG")
+
+             saved_paths.append(full_save_path)
+
+         # Return the paths that were actually written, which stays correct
+         # even when a `pages` filter skips page numbers.
+         return saved_paths
rag.py ADDED
@@ -0,0 +1,104 @@
+ import requests
+ import os
+ import google.generativeai as genai
+
+ from typing import List
+ from FAPS.utils import encode_image
+ from PIL import Image
+
+ class Rag:
+
+     def get_answer_from_gemini(self, query, imagePaths):
+
+         print(f"Querying Gemini for query={query}, imagePaths={imagePaths}")
+
+         try:
+             genai.configure(api_key=os.environ['GEMINI_API_KEY'])
+             model = genai.GenerativeModel('gemini-1.5-flash')
+
+             images = [Image.open(path) for path in imagePaths]
+
+             chat = model.start_chat()
+
+             response = chat.send_message([*images, query])
+
+             answer = response.text
+
+             print(answer)
+
+             return answer
+
+         except Exception as e:
+             print(f"An error occurred while querying Gemini: {e}")
+             return f"Error: {str(e)}"
+
+
+     # def get_answer_from_openai(self, query, imagesPaths):
+     #     print(f"Querying OpenAI for query={query}, imagesPaths={imagesPaths}")
+
+     #     try:
+     #         payload = self.__get_openai_api_payload(query, imagesPaths)
+
+     #         headers = {
+     #             "Content-Type": "application/json",
+     #             "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
+     #         }
+
+     #         response = requests.post(
+     #             url="https://api.openai.com/v1/chat/completions",
+     #             headers=headers,
+     #             json=payload
+     #         )
+     #         response.raise_for_status()  # Raise an HTTPError for bad responses
+
+     #         answer = response.json()["choices"][0]["message"]["content"]
+
+     #         print(answer)
+
+     #         return answer
+
+     #     except Exception as e:
+     #         print(f"An error occurred while querying OpenAI: {e}")
+     #         return None
+
+
+     # def __get_openai_api_payload(self, query: str, imagesPaths: List[str]):
+     #     image_payload = []
+
+     #     for imagePath in imagesPaths:
+     #         base64_image = encode_image(imagePath)
+     #         image_payload.append({
+     #             "type": "image_url",
+     #             "image_url": {
+     #                 "url": f"data:image/jpeg;base64,{base64_image}"
+     #             }
+     #         })
+
+     #     payload = {
+     #         "model": "gpt-4o",
+     #         "messages": [
+     #             {
+     #                 "role": "user",
+     #                 "content": [
+     #                     {
+     #                         "type": "text",
+     #                         "text": query
+     #                     },
+     #                     *image_payload
+     #                 ]
+     #             }
+     #         ],
+     #         "max_tokens": 1024
+     #     }
+
+     #     return payload
+
+
+ # if __name__ == "__main__":
+ #     rag = Rag()
+
+ #     query = "Based on attached images, how many new cases were reported during second wave peak"
+ #     imagesPaths = ["covid_slides_page_8.png", "covid_slides_page_8.png"]
+
+ #     rag.get_answer_from_gemini(query, imagesPaths)
requirements.txt CHANGED
@@ -1 +1,9 @@
- huggingface_hub==0.25.2
+ gradio==4.25.0
+ PyMuPDF==1.24.9
+ pdf2image==1.17.0
+ pymilvus==2.4.9
+ colpali_engine==0.3.4
+ tqdm==4.66.5
+ pillow==10.4.0
+ spaces==0.30.4
+ google-generativeai==0.8.3
utils.py ADDED
@@ -0,0 +1,5 @@
+ import base64
+
+ def encode_image(image_path):
+     with open(image_path, "rb") as image_file:
+         return base64.b64encode(image_file.read()).decode('utf-8')
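`encode_image` produces the base64 payload that the (currently commented-out) OpenAI path in rag.py embeds in a data URL. For example (illustrative path):

```python
from utils import encode_image  # FAPS.utils on the Space

b64 = encode_image("pages/demo/page_1.png")
data_url = f"data:image/png;base64,{b64}"
print(data_url[:60] + "...")  # truncated for display
```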