Spaces: Runtime error
Update app.py
app.py CHANGED
Before (lines this commit removes are marked `-`; unmarked lines are unchanged context):

```diff
@@ -1,34 +1,35 @@
 import subprocess  # 🥲
-
 subprocess.run(
     "pip install flash-attn --no-build-isolation",
     env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
     shell=True,
 )
 import spaces
 import gradio as gr
 import re
-
-from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
-from qwen_vl_utils import process_vision_info
 import torch
 import os
 import json
 from pydantic import BaseModel
 from typing import Tuple

 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

 model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
     "Qwen/Qwen2.5-VL-7B-Instruct",
     torch_dtype=torch.bfloat16,
     attn_implementation="flash_attention_2",
     device_map="auto",
 )
-processor = AutoProcessor.from_pretrained(
-    "Qwen/Qwen2.5-VL-7B-Instruct",
-)

 class GeneralRetrievalQuery(BaseModel):
     broad_topical_query: str
     broad_topical_explanation: str
@@ -38,21 +39,15 @@ class GeneralRetrievalQuery(BaseModel):
     visual_element_explanation: str

 def extract_json_with_regex(text):
-    # Pattern to match content between code backticks
     pattern = r'```(?:json)?\s*(.+?)\s*```'
-
-    # Find all matches (should typically be one)
     matches = re.findall(pattern, text, re.DOTALL)
-
     if matches:
-        # Return the first match
         return matches[0]
     return None

 def get_retrieval_prompt(prompt_name: str) -> Tuple[str, GeneralRetrievalQuery]:
     if prompt_name != "general":
         raise ValueError("Only 'general' prompt is available in this version")
-
     prompt = """You are an AI assistant specialized in document retrieval tasks. Given an image of a document page, your task is to generate retrieval queries that someone might use to find this document in a large corpus.

 Please generate 3 different types of retrieval queries:
@@ -85,32 +80,25 @@ Here is the document image to analyze:
 <image>

 Generate the queries based on this image and provide the response in the specified JSON format."""
-
     return prompt, GeneralRetrievalQuery

-# defined like this so we can later add more prompting options
 prompt, pydantic_model = get_retrieval_prompt("general")

 def _prep_data_for_input(image):
     messages = [
         {
             "role": "user",
             "content": [
-                {
-                    "type": "image",
-                    "image": image,
-                },
                 {"type": "text", "text": prompt},
             ],
         }
     ]
-
     text = processor.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
     )
-
     image_inputs, video_inputs = process_vision_info(messages)
-
     return processor(
         text=[text],
         images=image_inputs,
@@ -119,17 +107,40 @@ def _prep_data_for_input(image):
         return_tensors="pt",
     )

 @spaces.GPU
-def generate_response(image):
     inputs = _prep_data_for_input(image)
     inputs = inputs.to("cuda")
-
     generated_ids = model.generate(**inputs, max_new_tokens=200)
     generated_ids_trimmed = [
-        out_ids[len(in_ids):]
-        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
     ]
-
     output_text = processor.batch_decode(
         generated_ids_trimmed,
         skip_special_tokens=True,
@@ -137,32 +148,27 @@ def generate_response(image):
     )[0]

     try:
-        # Try to extract JSON from code block first
         json_str = extract_json_with_regex(output_text)
         if json_str:
             parsed = json.loads(json_str)
-            return parsed
-        # If no code block found, try direct JSON parsing
         parsed = json.loads(output_text)
-        return parsed
     except Exception:
         gr.Warning("Failed to parse JSON from output")
         return output_text

-
-
-
-
-To make the ColPali models work even better we might want a dataset of query/image document pairs related to our domain or task.
-
-One way in which we might go about generating such a dataset is to use a VLM to generate synthetic queries for us.
-This space uses the [Qwen/Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) VLM model to generate queries for a document, based on an input document image.

-**

-
-
-If you want to convert a PDF(s) to a dataset of page images you can try out the [PDFs to Page Images Converter](https://huggingface.co/spaces/Dataset-Creation-Tools/pdf-to-page-images-dataset) Space.
 """

 examples = [
@@ -170,12 +176,85 @@ examples = [
     "examples/SRCCL_Technical-Summary.pdf_page_7.jpg",
 ]

-
-
-
-
-
-
-
-
-
```
After (lines this commit adds are marked `+`):

```diff
@@ -1,34 +1,35 @@
 import subprocess  # 🥲
 subprocess.run(
     "pip install flash-attn --no-build-isolation",
     env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
     shell=True,
 )
+
 import spaces
 import gradio as gr
 import re
 import torch
 import os
 import json
+import time
 from pydantic import BaseModel
 from typing import Tuple
+from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
+from qwen_vl_utils import process_vision_info
+from PIL import Image

 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

+# ----------------------- Load the model and processor ----------------------- #
 model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
     "Qwen/Qwen2.5-VL-7B-Instruct",
     torch_dtype=torch.bfloat16,
     attn_implementation="flash_attention_2",
     device_map="auto",
 )
+processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

+# ----------------------- Define the Pydantic model ----------------------- #
 class GeneralRetrievalQuery(BaseModel):
     broad_topical_query: str
     broad_topical_explanation: str
```
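The startup `pip install flash-attn` exists so that `attn_implementation="flash_attention_2"` can load on the Space's GPU. If that wheel were unavailable, one possible fallback (an assumption, not part of this commit) is PyTorch's built-in SDPA attention:

```python
# Hypothetical fallback, not part of this commit: load the model without
# flash-attn by using PyTorch's scaled-dot-product attention instead.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    device_map="auto",
)
```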
```diff
@@ -38,21 +39,15 @@ class GeneralRetrievalQuery(BaseModel):
     visual_element_explanation: str

```
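The fields between `broad_topical_explanation` and `visual_element_explanation` fall outside the diff context. Judging from the prompt's three query types, the model is asked to return a JSON object along these lines (a sketch; the `specific_detail_*` and `visual_element_query` field names and all values are assumptions):

```python
# Illustrative only: a response shaped like GeneralRetrievalQuery.
# The specific_detail_* and visual_element_query names are assumptions
# inferred from the visible fields and the prompt's three query types.
example_response = {
    "broad_topical_query": "climate change land use technical summary",
    "broad_topical_explanation": "The page summarizes land-climate findings",
    "specific_detail_query": "greenhouse gas share of the global food system",
    "specific_detail_explanation": "The page cites a specific emissions figure",
    "visual_element_query": "emissions by sector bar chart",
    "visual_element_explanation": "The page contains a sector emissions chart",
}
```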
```diff
 def extract_json_with_regex(text):
     pattern = r'```(?:json)?\s*(.+?)\s*```'
     matches = re.findall(pattern, text, re.DOTALL)
     if matches:
         return matches[0]
     return None
```
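The pattern captures whatever sits between a pair of triple-backtick fences, with an optional `json` language tag, which is how the model typically wraps its answer. A quick illustration with a fabricated model reply:

```python
# Quick illustration with a fabricated model reply; the fence is built
# programmatically so the backticks don't clash with this page's formatting.
fence = "`" * 3
reply = f'Sure!\n{fence}json\n{{"broad_topical_query": "solar power"}}\n{fence}'
print(extract_json_with_regex(reply))
# -> {"broad_topical_query": "solar power"}
```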
```diff

 def get_retrieval_prompt(prompt_name: str) -> Tuple[str, GeneralRetrievalQuery]:
     if prompt_name != "general":
         raise ValueError("Only 'general' prompt is available in this version")
     prompt = """You are an AI assistant specialized in document retrieval tasks. Given an image of a document page, your task is to generate retrieval queries that someone might use to find this document in a large corpus.

 Please generate 3 different types of retrieval queries:
@@ -85,32 +80,25 @@ Here is the document image to analyze:
 <image>

 Generate the queries based on this image and provide the response in the specified JSON format."""
     return prompt, GeneralRetrievalQuery

 prompt, pydantic_model = get_retrieval_prompt("general")

+# ----------------------- Preprocess the input data ----------------------- #
 def _prep_data_for_input(image):
     messages = [
         {
             "role": "user",
             "content": [
+                {"type": "image", "image": image},
                 {"type": "text", "text": prompt},
             ],
         }
     ]
     text = processor.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
     )
     image_inputs, video_inputs = process_vision_info(messages)
     return processor(
         text=[text],
         images=image_inputs,
@@ -119,17 +107,40 @@ def _prep_data_for_input(image):
         return_tensors="pt",
     )
```
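`_prep_data_for_input` wraps the image and the fixed prompt into Qwen's chat format, renders the chat template, and returns a tensor batch. A hypothetical smoke test using one of the bundled example images:

```python
# Hypothetical smoke test; the path points at one of the Space's bundled examples.
from PIL import Image

img = Image.open("examples/SRCCL_Technical-Summary.pdf_page_7.jpg")
batch = _prep_data_for_input(img)
print(batch["input_ids"].shape)  # (1, text tokens plus image placeholder tokens)
```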
```diff

+# ----------------------- Output formatting helper ----------------------- #
+def format_output(data: dict, output_format: str) -> str:
+    """
+    data: the parsed JSON dictionary
+    output_format: one of "JSON", "Markdown", or "Table"
+    """
+    if output_format == "JSON":
+        return json.dumps(data, indent=2, ensure_ascii=False)
+    elif output_format == "Markdown":
+        # Render each item as a Markdown paragraph
+        md_lines = []
+        for key, value in data.items():
+            md_lines.append(f"**{key.replace('_', ' ').title()}:** {value}")
+        return "\n\n".join(md_lines)
+    elif output_format == "Table":
+        # Convert to a simple Markdown table
+        headers = ["Field", "Content"]
+        separator = "|".join(["---"] * len(headers))
+        rows = [f"| {' | '.join(headers)} |", f"|{separator}|"]
+        for key, value in data.items():
+            rows.append(f"| {key.replace('_', ' ').title()} | {value} |")
+        return "\n".join(rows)
+    else:
+        return json.dumps(data, indent=2, ensure_ascii=False)
+
```
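What the helper produces for a made-up single-field result:

```python
# Made-up input to show the Markdown and Table output formats.
parsed = {"broad_topical_query": "solar power adoption report"}
print(format_output(parsed, "Markdown"))
# **Broad Topical Query:** solar power adoption report
print(format_output(parsed, "Table"))
# | Field | Content |
# |---|---|
# | Broad Topical Query | solar power adoption report |
```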
```diff
+# ----------------------- Response generation ----------------------- #
 @spaces.GPU
+def generate_response(image, output_format: str = "JSON"):
     inputs = _prep_data_for_input(image)
     inputs = inputs.to("cuda")
     generated_ids = model.generate(**inputs, max_new_tokens=200)
     generated_ids_trimmed = [
+        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
     ]
     output_text = processor.batch_decode(
         generated_ids_trimmed,
         skip_special_tokens=True,
@@ -137,32 +148,27 @@ def generate_response(image):
     )[0]

     try:
         json_str = extract_json_with_regex(output_text)
         if json_str:
             parsed = json.loads(json_str)
+            return format_output(parsed, output_format)
         parsed = json.loads(output_text)
+        return format_output(parsed, output_format)
     except Exception:
         gr.Warning("Failed to parse JSON from output")
         return output_text
```
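`model.generate` returns the prompt tokens followed by the newly generated ones, so the slice `out_ids[len(in_ids):]` drops the echoed prompt before decoding. A toy illustration with fabricated token ids:

```python
# Toy illustration of the trimming slice; token ids are fabricated.
in_ids = [101, 7592, 2088]           # prompt tokens fed to the model
out_ids = [101, 7592, 2088, 42, 7]   # generate() echoes the prompt first
print(out_ids[len(in_ids):])         # -> [42, 7], only the new tokens
```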
```diff

+# ----------------------- Interface title and description ----------------------- #
+title = "Elegant ColPali Query Generator using Qwen2.5-VL"
+description = """**ColPali** is a multimodal approach optimized for document retrieval.
+This interface uses the [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) model to generate relevant retrieval queries from a document image.

+- **Broad Topical Query:** a query covering the document's main topic
+- **Specific Detail Query:** a query targeting a specific fact or figure in the document
+- **Visual Element Query:** a query based on the document's visual elements (e.g. charts, graphs)

+Use the examples below to generate queries suited to your document image.
+For more details, see the [blog post](https://danielvanstrien.xyz/posts/post-with-code/colpali/2024-09-23-generate_colpali_dataset.html).
 """

 examples = [
@@ -170,12 +176,85 @@ examples = [
     "examples/SRCCL_Technical-Summary.pdf_page_7.jpg",
 ]

+# ----------------------- Custom CSS ----------------------- #
+custom_css = """
+body {
+    background: #f7f9fb;
+    font-family: 'Segoe UI', sans-serif;
+    color: #333;
+}
+header {
+    text-align: center;
+    padding: 20px;
+    margin-bottom: 20px;
+}
+header h1 {
+    font-size: 3em;
+    color: #2c3e50;
+}
+.gradio-container {
+    padding: 20px;
+}
+.gr-button {
+    background-color: #3498db !important;
+    color: #fff !important;
+    border: none !important;
+    padding: 10px 20px !important;
+    border-radius: 5px !important;
+    font-size: 1em !important;
+}
+.gr-button:hover {
+    background-color: #2980b9 !important;
+}
+.gr-gallery-item {
+    border-radius: 10px;
+    overflow: hidden;
+    box-shadow: 0 2px 10px rgba(0,0,0,0.1);
+}
+footer {
+    text-align: center;
+    padding: 20px 0;
+    font-size: 0.9em;
+    color: #555;
+}
+"""
+
+# ----------------------- Build the Gradio interface ----------------------- #
+with gr.Blocks(css=custom_css, title=title) as demo:
+    with gr.Column(variant="panel"):
+        gr.Markdown(f"<header><h1>{title}</h1></header>")
+        gr.Markdown(description)
+
+        with gr.Tabs():
+            with gr.TabItem("Query Generation"):
+                gr.Markdown("### Generate Retrieval Queries from a Document Image")
+                with gr.Row():
+                    image_input = gr.Image(label="Upload Document Image", type="pil")
+                with gr.Row():
+                    # Output format selector
+                    output_format = gr.Radio(
+                        choices=["JSON", "Markdown", "Table"],
+                        value="JSON",
+                        label="Output Format",
+                        info="Select the desired output format."
+                    )
+                generate_button = gr.Button("Generate Query")
+                output_text = gr.Textbox(label="Generated Query", lines=10)
+                with gr.Accordion("Examples", open=False):
+                    gr.Examples(
+                        label="Query Examples",
+                        examples=[
+                            "examples/Approche_no_13_1977.pdf_page_22.jpg",
+                            "examples/SRCCL_Technical-Summary.pdf_page_7.jpg",
+                        ],
+                        inputs=image_input,
+                    )
+                generate_button.click(
+                    fn=generate_response,
+                    inputs=[image_input, output_format],
+                    outputs=output_text
+                )
+
+        gr.Markdown("<footer>Join our community on <a href='https://discord.gg/openfreeai' target='_blank'>Discord</a></footer>")
+
+demo.launch()
```