ginipick committed
Commit 45b6f79 · verified · 1 Parent(s): 1fc115f

Update app.py

Files changed (1)
app.py +131 -52
app.py CHANGED
@@ -1,34 +1,35 @@
 import subprocess # 🥲
-
 subprocess.run(
     "pip install flash-attn --no-build-isolation",
     env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
     shell=True,
 )
+
 import spaces
 import gradio as gr
 import re
-
-from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
-from qwen_vl_utils import process_vision_info
 import torch
 import os
 import json
+import time
 from pydantic import BaseModel
 from typing import Tuple
+from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
+from qwen_vl_utils import process_vision_info
+from PIL import Image
 
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 
+# ----------------------- Load model and processor ----------------------- #
 model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
     "Qwen/Qwen2.5-VL-7B-Instruct",
     torch_dtype=torch.bfloat16,
     attn_implementation="flash_attention_2",
     device_map="auto",
 )
-processor = AutoProcessor.from_pretrained(
-    "Qwen/Qwen2.5-VL-7B-Instruct",
-)
+processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
 
+# ----------------------- Pydantic model definition ----------------------- #
 class GeneralRetrievalQuery(BaseModel):
     broad_topical_query: str
     broad_topical_explanation: str
@@ -38,21 +39,15 @@ class GeneralRetrievalQuery(BaseModel):
     visual_element_explanation: str
 
 def extract_json_with_regex(text):
-    # Pattern to match content between code backticks
     pattern = r'```(?:json)?\s*(.+?)\s*```'
-
-    # Find all matches (should typically be one)
     matches = re.findall(pattern, text, re.DOTALL)
-
     if matches:
-        # Return the first match
         return matches[0]
     return None
 
 def get_retrieval_prompt(prompt_name: str) -> Tuple[str, GeneralRetrievalQuery]:
     if prompt_name != "general":
         raise ValueError("Only 'general' prompt is available in this version")
-
     prompt = """You are an AI assistant specialized in document retrieval tasks. Given an image of a document page, your task is to generate retrieval queries that someone might use to find this document in a large corpus.
 
 Please generate 3 different types of retrieval queries:
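
For a quick sanity check, here is a minimal, self-contained sketch of how the `extract_json_with_regex` helper above behaves; the sample model reply is illustrative:

```python
import json
import re

def extract_json_with_regex(text):
    # Capture the contents of the first ```-fenced block (optionally tagged "json")
    pattern = r'```(?:json)?\s*(.+?)\s*```'
    matches = re.findall(pattern, text, re.DOTALL)
    if matches:
        return matches[0]
    return None

reply = 'Sure!\n```json\n{"broad_topical_query": "IPCC land use report"}\n```'
print(json.loads(extract_json_with_regex(reply)))
# -> {'broad_topical_query': 'IPCC land use report'}
```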
@@ -85,32 +80,25 @@ Here is the document image to analyze:
 <image>
 
 Generate the queries based on this image and provide the response in the specified JSON format."""
-
     return prompt, GeneralRetrievalQuery
 
-# defined like this so we can later add more prompting options
 prompt, pydantic_model = get_retrieval_prompt("general")
 
+# ----------------------- Input preprocessing ----------------------- #
 def _prep_data_for_input(image):
     messages = [
         {
             "role": "user",
             "content": [
-                {
-                    "type": "image",
-                    "image": image,
-                },
+                {"type": "image", "image": image},
                 {"type": "text", "text": prompt},
             ],
         }
     ]
-
     text = processor.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
    )
-
     image_inputs, video_inputs = process_vision_info(messages)
-
     return processor(
         text=[text],
         images=image_inputs,
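
For orientation, the "specified JSON format" the prompt refers to mirrors the `GeneralRetrievalQuery` schema. A sketch of the expected shape with invented values; the `specific_detail_*` field names are inferred from the three query types and are not visible in this diff:

```python
# Hypothetical example of a well-formed model response (values invented).
expected_shape = {
    "broad_topical_query": "climate report land use",
    "broad_topical_explanation": "The page discusses land-related climate findings.",
    "specific_detail_query": "greenhouse gas emissions percentage",  # field name inferred
    "specific_detail_explanation": "The page quotes a specific emissions figure.",  # field name inferred
    "visual_element_query": "emissions bar chart",
    "visual_element_explanation": "The page contains a chart of emissions by sector.",
}
```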
@@ -119,17 +107,40 @@ def _prep_data_for_input(image):
         return_tensors="pt",
     )
 
+# ----------------------- Output formatting helper ----------------------- #
+def format_output(data: dict, output_format: str) -> str:
+    """
+    data: the parsed JSON dictionary
+    output_format: one of "JSON", "Markdown", or "Table"
+    """
+    if output_format == "JSON":
+        return json.dumps(data, indent=2, ensure_ascii=False)
+    elif output_format == "Markdown":
+        # Render each field as a Markdown paragraph
+        md_lines = []
+        for key, value in data.items():
+            md_lines.append(f"**{key.replace('_', ' ').title()}:** {value}")
+        return "\n\n".join(md_lines)
+    elif output_format == "Table":
+        # Render as a simple Markdown table
+        headers = ["Field", "Content"]
+        separator = "|".join(["---"] * len(headers))
+        rows = [f"| {' | '.join(headers)} |", f"|{separator}|"]
+        for key, value in data.items():
+            rows.append(f"| {key.replace('_', ' ').title()} | {value} |")
+        return "\n".join(rows)
+    else:
+        return json.dumps(data, indent=2, ensure_ascii=False)
+
+# ----------------------- Response generation ----------------------- #
 @spaces.GPU
-def generate_response(image):
+def generate_response(image, output_format: str = "JSON"):
     inputs = _prep_data_for_input(image)
     inputs = inputs.to("cuda")
-
     generated_ids = model.generate(**inputs, max_new_tokens=200)
     generated_ids_trimmed = [
-        out_ids[len(in_ids) :]
-        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
     ]
-
     output_text = processor.batch_decode(
         generated_ids_trimmed,
         skip_special_tokens=True,
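
To make the new output options concrete, a short usage sketch of `format_output` (assumes the function above is in scope; values are illustrative):

```python
sample = {
    "broad_topical_query": "solar panel efficiency",
    "specific_detail_query": "22% module efficiency",
}

print(format_output(sample, "Markdown"))
# **Broad Topical Query:** solar panel efficiency
#
# **Specific Detail Query:** 22% module efficiency

print(format_output(sample, "Table"))
# | Field | Content |
# |---|---|
# | Broad Topical Query | solar panel efficiency |
# | Specific Detail Query | 22% module efficiency |
```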
@@ -137,32 +148,27 @@ def generate_response(image):
     )[0]
 
     try:
-        # Try to extract JSON from code block first
         json_str = extract_json_with_regex(output_text)
         if json_str:
             parsed = json.loads(json_str)
-            return json.dumps(parsed, indent=2)
-        # If no code block found, try direct JSON parsing
+            return format_output(parsed, output_format)
         parsed = json.loads(output_text)
-        return json.dumps(parsed, indent=2)
+        return format_output(parsed, output_format)
     except Exception:
         gr.Warning("Failed to parse JSON from output")
         return output_text
 
-title = "ColPali Query Generator using Qwen2.5-VL"
-description = """[ColPali](https://huggingface.co/papers/2407.01449) is a very exciting new approach to multimodal document retrieval which aims to replace existing document retrievers, which often rely on an OCR step, with an end-to-end multimodal approach.
-
-To train or fine-tune a ColPali model, we need a dataset of image-text pairs which represent the document images and the relevant text queries which those documents should match.
-To make the ColPali models work even better we might want a dataset of query/image document pairs related to our domain or task.
-
-One way in which we might go about generating such a dataset is to use a VLM to generate synthetic queries for us.
-This space uses the [Qwen/Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) VLM model to generate queries for a document, based on an input document image.
+# ----------------------- Interface title and description ----------------------- #
+title = "Elegant ColPali Query Generator using Qwen2.5-VL"
+description = """**ColPali** is a multimodal approach optimized for document retrieval.
+This interface uses the [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) model to generate relevant retrieval queries from a document image.
 
-**Note** there is a lot of scope for improving the prompts and the quality of the generated queries! If you have any suggestions for improvements please [open a Discussion](https://huggingface.co/spaces/davanstrien/ColPali-Query-Generator/discussions/new)!
+- **Broad Topical Query:** a query covering the document's main topic
+- **Specific Detail Query:** a query containing a specific fact or figure from the document
+- **Visual Element Query:** a query based on a visual element of the document (e.g. a chart or graph)
 
-This [blog post](https://danielvanstrien.xyz/posts/post-with-code/colpali/2024-09-23-generate_colpali_dataset.html) gives an overview of how you can use this kind of approach to generate a full dataset for fine-tuning ColPali models.
-
-If you want to convert a PDF(s) to a dataset of page images you can try out the [PDFs to Page Images Converter](https://huggingface.co/spaces/Dataset-Creation-Tools/pdf-to-page-images-dataset) Space.
+Try the examples below to generate queries suited to your document image.
+See this [blog post](https://danielvanstrien.xyz/posts/post-with-code/colpali/2024-09-23-generate_colpali_dataset.html) for more details.
 """
 
 examples = [
@@ -170,12 +176,85 @@ examples = [
     "examples/SRCCL_Technical-Summary.pdf_page_7.jpg",
 ]
 
-demo = gr.Interface(
-    fn=generate_response,
-    inputs=gr.Image(type="pil"),
-    outputs=gr.Text(),
-    title=title,
-    description=description,
-    examples=examples,
-)
-demo.launch()
+# ----------------------- Custom CSS ----------------------- #
+custom_css = """
+body {
+    background: #f7f9fb;
+    font-family: 'Segoe UI', sans-serif;
+    color: #333;
+}
+header {
+    text-align: center;
+    padding: 20px;
+    margin-bottom: 20px;
+}
+header h1 {
+    font-size: 3em;
+    color: #2c3e50;
+}
+.gradio-container {
+    padding: 20px;
+}
+.gr-button {
+    background-color: #3498db !important;
+    color: #fff !important;
+    border: none !important;
+    padding: 10px 20px !important;
+    border-radius: 5px !important;
+    font-size: 1em !important;
+}
+.gr-button:hover {
+    background-color: #2980b9 !important;
+}
+.gr-gallery-item {
+    border-radius: 10px;
+    overflow: hidden;
+    box-shadow: 0 2px 10px rgba(0,0,0,0.1);
+}
+footer {
+    text-align: center;
+    padding: 20px 0;
+    font-size: 0.9em;
+    color: #555;
+}
+"""
+
+# ----------------------- Gradio interface ----------------------- #
+with gr.Blocks(css=custom_css, title=title) as demo:
+    with gr.Column(variant="panel"):
+        gr.Markdown(f"<header><h1>{title}</h1></header>")
+        gr.Markdown(description)
+
+    with gr.Tabs():
+        with gr.TabItem("Query Generation"):
+            gr.Markdown("### Generate Retrieval Queries from a Document Image")
+            with gr.Row():
+                image_input = gr.Image(label="Upload Document Image", type="pil")
+            with gr.Row():
+                # Output format selection option
+                output_format = gr.Radio(
+                    choices=["JSON", "Markdown", "Table"],
+                    value="JSON",
+                    label="Output Format",
+                    info="Select the desired output format."
+                )
+                generate_button = gr.Button("Generate Query")
+            output_text = gr.Textbox(label="Generated Query", lines=10)
+            with gr.Accordion("Examples", open=False):
+                gr.Examples(
+                    label="Query Examples",
+                    examples=[
+                        "examples/Approche_no_13_1977.pdf_page_22.jpg",
+                        "examples/SRCCL_Technical-Summary.pdf_page_7.jpg",
+                    ],
+                    inputs=image_input,
+                )
+            generate_button.click(
+                fn=generate_response,
+                inputs=[image_input, output_format],
+                outputs=output_text
+            )
+
+    gr.Markdown("<footer>Join our community on <a href='https://discord.gg/openfreeai' target='_blank'>Discord</a></footer>")
+
+demo.launch()
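
As an end-to-end usage sketch, the updated `generate_response` can be smoke-tested without the UI (assumes a CUDA GPU, the model loaded above, and the example images bundled with the Space):

```python
from PIL import Image

page = Image.open("examples/SRCCL_Technical-Summary.pdf_page_7.jpg")
for fmt in ["JSON", "Markdown", "Table"]:
    # output_format is the new parameter wired to the gr.Radio control
    print(generate_response(page, output_format=fmt))
```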