"""Template Demo for IBM Granite Hugging Face spaces.""" | |
import os | |
import time | |
from pathlib import Path | |
import re | |
import gradio as gr | |
import spaces | |
import torch | |
from gradio_pdf import PDF | |
from sandbox.light_rag.light_rag import LightRAG | |
from themes.research_monochrome import theme | |
dir_ = Path(__file__).parent.parent  # two levels up from this file
TITLE = "Multimodal RAG with Granite Vision 3.2" | |
DESCRIPTION = """ | |
<p>This experimental demo highlights granite-vision-3.2-2b capabilities within a multimodal retrieval-augmented generation (RAG) pipeline, demonstrating Granite's document understanding in real-world applications. Explore the sample document excerpts and try the sample prompts or enter your own. Keep in mind that AI can occasionally make mistakes. | |
<span class="gr_docs_link"> | |
<a href="https://www.ibm.com/granite/docs/models/vision/">View Documentation <i class="fa fa-external-link"></i></a> | |
</span> | |
</p> | |
""" | |
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
BASE_PATH = dir_ / "data" / "final_v2_mar04"
PDFS_PATH = BASE_PATH / "pdfs"
MILVUS_PATH = BASE_PATH / "milvus"
IMAGES_PATH = BASE_PATH / "images"
PREVIEWS_PATH = BASE_PATH / "preview"
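# Each sample bundles a preview thumbnail, example prompts, a PDF excerpt, and the
# Milvus collection/database holding its pre-built text and image-description index.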
sample_data = [
    {
        "preview_image": str(PREVIEWS_PATH / "IBM-financial-2010.png"),
        "prompts": """Where geographically was the greatest growth in revenue in 2007?
Which year had the highest income in billions?
Did the net income decrease in 2007 compared to 2006?
Net cash from operations in 2005?
What does it mean to be a Globally Integrated Enterprise?
What are the segments for pretax income?""".split("\n"),
        "pdf": str(PDFS_PATH / "IBM_Annual_Report_2007_3-20.pdf"),
        "index": "ibm_report_2007_short_text_milvus_lite_2048_128_slate_278m_cosine",
        "db": str(MILVUS_PATH / "milvus.db"),
        "name": "IBM annual report 2007",
        "origin": "https://www.ibm.com/investor/att/pdf/IBM_Annual_Report_2007.pdf",
        "image_paths": {"prefix": str(IMAGES_PATH / "ibm_report_2007") + "/", "use_last": 2},
    },
    {
        "preview_image": str(PREVIEWS_PATH / "Wilhlborg-financial.png"),
        "prompts": """Where does Wihlborgs mainly operate?
Which year had the second lowest Equity/assets ratio?
Which year had the highest Project investments value?
What is the trend of the equity/assets ratio?
What was the Growth percentage in income from property management in 2020?
Has the company’s interest coverage ratio increased or decreased in recent years?""".split("\n"),
        "pdf": str(PDFS_PATH / "wihlborgs-2-13_16-18.pdf"),
        "index": "wihlborgs_short_text_milvus_lite_2048_128_slate_278m_cosine",
        "db": str(MILVUS_PATH / "milvus.db"),
        "name": "Wihlborgs Report 2020",
        "origin": "https://www.wihlborgs.se/globalassets/investor-relations/rapporter/2021/20210401-wihlborgs-annual-report-and-sustainability-report-2020-c24a6b51-c124-44fc-a4af-4237a33a29fb.pdf",
        "image_paths": {"prefix": str(IMAGES_PATH / "wihlborgs") + "/", "use_last": 2},
    },
]
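# LightRAG wiring: a Granite embedding model for retrieval plus a Granite instruct
# model for answer generation, backed by a Milvus Lite database.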
config = {
    "embedding_model_id": "ibm-granite/granite-embedding-278m-multilingual",
    "generation_model_id": "ibm-granite/granite-3.1-8b-instruct",
    "milvus_collection_name": "granite_vision_tech_report_text_milvus_lite_512_128_slate_125m_cosine",
    "milvus_db_path": str(MILVUS_PATH / "milvus_text_sample.db"),  # MILVUS_PATH is already rooted under dir_ / "data"
}
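# gr.NO_RELOAD keeps this heavyweight initialization from re-running on every
# auto-reload while developing with `gradio` reload mode.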
if gr.NO_RELOAD:
    light_rag: LightRAG = LightRAG(config)
    if os.environ.get("LAZY_LOADING") != "true":
        # Warm the per-sample Milvus collections up front so first queries are fast.
        for sample in sample_data:
            light_rag.precache_milvus(sample["index"], sample["db"])
def lower_md_headers(md: str) -> str:
    """Demote level-1 and level-2 markdown headers to level 3 so document headers do not outrank the chat's own headings."""
    return re.sub(r"(?:^|\n)##?\s(.+)", lambda m: "\n### " + m.group(1), md)
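# For example: lower_md_headers("# Title\n## Subtitle") -> "\n### Title\n### Subtitle"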
# Parser for retrieval results
def format_retrieval_result(i, d, cb, selected_sample):
    """Append the i-th retrieved document to the chat; return True when the chat should be re-rendered."""
    image_paths = sample_data[selected_sample]["image_paths"]
    if d.metadata["type"] == "text":
        context_string = f"---\n## Context {i + 1}\n#### (text extracted from document)\n{lower_md_headers(d.page_content)}\n"
        cb.append(gr.ChatMessage(role="assistant", content=context_string))
        return True
    elif d.metadata["type"] == "image_description":
        context_string = f"---\n## Context {i + 1}\n#### (image description generated by Granite Vision)"
        cb.append(gr.ChatMessage(role="assistant", content=context_string))
        # The stored image path is absolute on the indexing machine; keep only the
        # last `use_last` components and re-root them under this sample's prefix.
        image_path_parts = d.metadata["image_fullpath"].split("/")
        image_path = image_paths["prefix"] + "/".join(image_path_parts[-image_paths["use_last"]:])
        cb.append(gr.ChatMessage(role="assistant", content=gr.Image(image_path)))
        cb.append(gr.ChatMessage(role="assistant", content=f"\n{lower_md_headers(d.metadata['image_description'])}\n"))
        return True
chatbot = gr.Chatbot(
    examples=[{"text": x} for x in sample_data[0]["prompts"]],
    type="messages",
    label=f"Q&A about {sample_data[0]['name']}",
    height=685,
    group_consecutive_messages=True,
    autoscroll=False,
    elem_classes=["chatbot_view"],
)
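# Setting NO_LLM short-circuits generation with a canned answer so the UI flow
# can be exercised without loading a model.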
def generate_with_llm(query, context):
    if os.environ.get("NO_LLM"):
        time.sleep(2)  # simulate generation latency
        return "Now answer, just a string", query
    return light_rag.generate(query=query, context=context)
# TODO: maybe add GPU back?
def retrieval(collection, db, q):
    return light_rag.search(q, top_n=3, collection=collection, db=db)
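# light_rag.search returns document objects exposing .page_content and .metadata,
# which format_retrieval_result above consumes.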
# ################
# User Interface
# ################
css_file_path = Path(__file__).parent / "app.css"
head_file_path = Path(__file__).parent / "app_head.html"
with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_path, theme=theme, title=TITLE) as demo:
    is_in_edit_mode = gr.State(True)  # created inside the Blocks context so it stays reactive
    selected_doc = gr.State(0)
    current_question = gr.State("")
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        # LEFT COLUMN: sample selection, download, and PDF viewer
        with gr.Column():
            # Show preview images
            images_only = [sd["preview_image"] for sd in sample_data]
            document_gallery = gr.Gallery(
                images_only,
                label="Select a document",
                rows=1,
                columns=3,
                height="125px",
                allow_preview=False,
                selected_index=0,
                elem_classes=["preview_im_element"],
            )
            with gr.Group():
                pdf_display = PDF(
                    sample_data[0]["pdf"],
                    label=f"Preview for {sample_data[0]['name']}",
                    height=460,
                    interactive=False,
                    elem_classes=["pdf_viewer"],
                )
                dl_btn = gr.DownloadButton(
                    label=f"Download PDF ({sample_data[0]['name']})", value=sample_data[0]["pdf"], visible=True
                )
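            # Swap the chat examples, PDF preview, and download target to match the
            # document picked in the gallery.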
            def sample_image_selected(d: gr.SelectData):
                dx = sample_data[d.index]
                return (
                    gr.update(examples=[{"text": x} for x in dx["prompts"]], label=f"Q&A about {dx['name']}"),
                    gr.update(value=dx["pdf"], label=f"Preview for {dx['name']}"),
                    gr.DownloadButton(value=dx["pdf"], label=f"Download PDF ({dx['name']})"),
                    d.index,
                )

            # The first handler clears the chat; the second swaps in the selected document.
            document_gallery.select(lambda: [], outputs=[chatbot])
            document_gallery.select(
                sample_image_selected, inputs=[], outputs=[chatbot, pdf_display, dl_btn, selected_doc]
            )
        # RIGHT COLUMN: chat interface
        with gr.Column():
            chatbot.render()

            # Extract the question text from a selected example prompt
            def question_from_selection(x: gr.SelectData):
                return x.value["text"]
            def _decorate_yield_result(cb, fb_status=False, gallery_status=False):
                return (
                    cb,
                    gr.Button(interactive=fb_status),
                    gr.Gallery(
                        elem_classes=["preview_im_element"]
                        if gallery_status
                        else ["preview_im_element", "inactive_div"]
                    ),
                )

            def send_generate(msg, cb, selected_sample):
                collection = sample_data[selected_sample]["index"]
                db = sample_data[selected_sample]["db"]
                original_msg = gr.ChatMessage(role="user", content=msg)
                cb.append(original_msg)
                waiting_for_retrieval_msg = gr.ChatMessage(
                    role="assistant",
                    content='## Answer\n*Querying Index*<span class="jumping-dots"><span class="dot-1">.</span> <span class="dot-2">.</span> <span class="dot-3">.</span></span>',
                )
                cb.append(waiting_for_retrieval_msg)
                yield _decorate_yield_result(cb)
                q = msg.strip()
                results = retrieval(collection, db, q)
                for i, d in enumerate(results):
                    if format_retrieval_result(i, d, cb, selected_sample):
                        yield _decorate_yield_result(cb)
                # The placeholder at index 1 sits above the context messages; swap it
                # for the waiting indicator and then for the final answer.
                waiting_for_llm_msg = gr.ChatMessage(
                    role="assistant",
                    content='## Answer\n *Waiting for LLM* <span class="jumping-dots"><span class="dot-1">.</span> <span class="dot-2">.</span> <span class="dot-3">.</span></span> ',
                )
                cb[1] = waiting_for_llm_msg
                yield _decorate_yield_result(cb)
                answer, _prompt = generate_with_llm(q, results)
                cb[1] = gr.ChatMessage(role="assistant", content=f"## Answer\n<b>{answer.strip()}</b>")
                yield _decorate_yield_result(cb, fb_status=True, gallery_status=True)
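            # Free-text entry plus a reset button; only one is visible at a time,
            # toggled through the is_in_edit_mode state below.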
            # Create user chat textbox and reset button
            tbb = gr.Textbox(submit_btn=True, show_label=False, placeholder="Type a message...")
            fb = gr.Button("Ask new question", visible=False)
            fb.click(lambda: [], outputs=[chatbot])
            chatbot.example_select(lambda: False, outputs=is_in_edit_mode)
            chatbot.example_select(question_from_selection, inputs=[], outputs=[current_question]).then(
                send_generate, inputs=[current_question, chatbot, selected_doc], outputs=[chatbot, fb, document_gallery]
            )
            def textbox_switch(e_mode):
                """Toggle between the textbox (edit mode) and the reset button."""
                if e_mode:
                    return [gr.update(visible=True), gr.update(visible=False)]
                return [gr.update(visible=False), gr.update(visible=True)]

            tbb.submit(lambda: False, outputs=[is_in_edit_mode])
            fb.click(lambda: True, outputs=[is_in_edit_mode])
            is_in_edit_mode.change(textbox_switch, inputs=[is_in_edit_mode], outputs=[tbb, fb])
            # Submit a user-typed question
            tbb.submit(lambda x: x, inputs=[tbb], outputs=[current_question]).then(
                send_generate, inputs=[current_question, chatbot, selected_doc], outputs=[chatbot, fb, document_gallery]
            )
if __name__ == "__main__":
    # For heavier traffic, enable request queuing: demo.queue(max_size=20).launch()
    demo.launch()