# RTE Build
# Deployment
# a099612
"""Template Demo for IBM Granite Hugging Face spaces."""
import os
import time
from pathlib import Path
import re
import gradio as gr
import spaces
import torch
from gradio_pdf import PDF
from sandbox.light_rag.light_rag import LightRAG
from themes.research_monochrome import theme
# Directory two levels above this file; every data path below is rooted here.
dir_ = Path(__file__).parent.parent
TITLE = "Multimodal RAG with Granite Vision 3.2"
DESCRIPTION = """
<p>This experimental demo highlights granite-vision-3.2-2b capabilities within a multimodal retrieval-augmented generation (RAG) pipeline, demonstrating Granite's document understanding in real-world applications. Explore the sample document excerpts and try the sample prompts or enter your own. Keep in mind that AI can occasionally make mistakes.
<span class="gr_docs_link">
<a href="https://www.ibm.com/granite/docs/models/vision/">View Documentation <i class="fa fa-external-link"></i></a>
</span>
</p>
"""
# Preferred torch backend: CUDA first, then Apple MPS, falling back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
# Layout of the pre-built demo data (PDFs, Milvus databases, extracted images,
# gallery preview thumbnails).
BASE_PATH = dir_.joinpath("data", "final_v2_mar04")
PDFS_PATH = BASE_PATH.joinpath("pdfs")
MILVUS_PATH = BASE_PATH.joinpath("milvus")
IMAGES_PATH = BASE_PATH.joinpath("images")
PREVIEWS_PATH = BASE_PATH.joinpath("preview")
# Documents offered in the gallery. Per entry:
#   preview_image - thumbnail shown in the gallery
#   prompts       - example questions offered in the chat panel
#   pdf / name / origin - local PDF excerpt, display name, and source URL
#   index / db    - Milvus collection name and database file searched at query time
#   image_paths   - how to map metadata["image_fullpath"] to a local file: keep
#                   its last `use_last` "/"-separated parts and prepend `prefix`
sample_data = [
    dict(
        # NOTE(review): filename spelling ("IBM-financial-2010") must match the file on disk.
        preview_image=str(PREVIEWS_PATH / "IBM-financial-2010.png"),
        prompts=[
            "Where geographically was the greatest growth in revenue in 2007?",
            "Which year had the highest income in billion?",
            "Did the net income decrease in 2007 compared to 2006?",
            "Net cash from operations on 2005?",
            "What does it mean to be Globally Integrated Enterprise?",
            "What are the segments for pretax income?",
        ],
        pdf=str(PDFS_PATH / "IBM_Annual_Report_2007_3-20.pdf"),
        index="ibm_report_2007_short_text_milvus_lite_2048_128_slate_278m_cosine",
        db=str(MILVUS_PATH / "milvus.db"),
        name="IBM annual report 2007",
        origin="https://www.ibm.com/investor/att/pdf/IBM_Annual_Report_2007.pdf",
        image_paths=dict(prefix=str(IMAGES_PATH / "ibm_report_2007") + "/", use_last=2),
    ),
    dict(
        # NOTE(review): "Wilhlborg" differs from the "Wihlborgs" spelling used
        # elsewhere; kept as-is because it must match the file on disk — verify.
        preview_image=str(PREVIEWS_PATH / "Wilhlborg-financial.png"),
        prompts=[
            "Where does Wihlborgs mainly operate?",
            "Which year had the second lowest Equity/assets ratio?",
            "Which year had the highest Project investments value?",
            "What is the trend of equity/assets ratio?",
            "What was the Growth percentage in income from property management in 2020?",
            "Has the company’s interest coverage ratio increased or decreased in recent years?",
        ],
        pdf=str(PDFS_PATH / "wihlborgs-2-13_16-18.pdf"),
        index="wihlborgs_short_text_milvus_lite_2048_128_slate_278m_cosine",
        db=str(MILVUS_PATH / "milvus.db"),
        name="Wihlborgs Report 2020",
        origin="https://www.wihlborgs.se/globalassets/investor-relations/rapporter/2021/20210401-wihlborgs-annual-report-and-sustainability-report-2020-c24a6b51-c124-44fc-a4af-4237a33a29fb.pdf",
        image_paths=dict(prefix=str(IMAGES_PATH / "wihlborgs") + "/", use_last=2),
    ),
]
# Settings handed to LightRAG: embedding / generation model ids and a Milvus
# collection + database. Per-sample collections and databases are passed
# explicitly to each search (see `retrieval`); these milvus_* values are
# presumably LightRAG's defaults — verify against LightRAG.
config = {
    "embedding_model_id": "ibm-granite/granite-embedding-278m-multilingual",
    "generation_model_id": "ibm-granite/granite-3.1-8b-instruct",
    "milvus_collection_name": "granite_vision_tech_report_text_milvus_lite_512_128_slate_125m_cosine",
    # MILVUS_PATH is already rooted at dir_/"data", so join onto it directly.
    # Re-prefixing with dir_/"data" only works when MILVUS_PATH is absolute
    # (pathlib then discards the left operand); with a relative __file__ it
    # would produce a bogus ".../data/data/..." path.
    "milvus_db_path": str(MILVUS_PATH / "milvus_text_sample.db"),
}
# Build the RAG pipeline once. Code under `gr.NO_RELOAD` is skipped when Gradio
# re-executes the module in hot-reload mode, so `light_rag` is not rebuilt.
if gr.NO_RELOAD:
    light_rag: LightRAG = LightRAG(config)
    # Warm every sample's Milvus collection at startup unless LAZY_LOADING=true.
    if os.environ.get("LAZY_LOADING") != "true":
        for sample in sample_data:
            light_rag.precache_milvus(sample["index"], sample["db"])
def lower_md_headers(md: str) -> str:
    """Demote Markdown level-1 and level-2 headings to level 3.

    Any ``# Title`` or ``## Title`` at the start of ``md`` or right after a
    newline becomes ``\\n### Title``, so it nests under the ``## Context N``
    headings emitted around retrieved snippets. Headings of level 3 and deeper
    are left alone. A heading at the very start of ``md`` gains a leading
    newline.
    """
    def _as_level_three(match):
        return "\n### " + match.group(1)

    return re.sub(r'(?:^|\n)##?\s(.+)', _as_level_three, md)
# Parser for retrieval results
def format_retrieval_result(i, d, cb, selected_sample):
    """Append chat messages that render one retrieved context chunk.

    Args:
        i: zero-based rank of the result; displayed as "Context {i + 1}".
        d: retrieved document. ``d.metadata["type"]`` selects the rendering:
            ``"text"`` shows ``d.page_content``; ``"image_description"`` shows the
            image at ``d.metadata["image_fullpath"]`` followed by
            ``d.metadata["image_description"]``.
        cb: chat history (list of ``gr.ChatMessage``), appended to in place.
        selected_sample: index into ``sample_data`` of the active document; its
            ``image_paths`` entry maps stored image paths to local files.

    Returns:
        True when messages were appended (the caller then streams a UI update),
        False for an unrecognised result type.
    """
    image_paths = sample_data[selected_sample]["image_paths"]
    result_type = d.metadata["type"]

    if result_type == "text":
        context_string = f"---\n## Context {i + 1}\n#### (text extracted from document)\n{lower_md_headers(d.page_content)}\n"
        cb.append(gr.ChatMessage(role="assistant", content=context_string))
        return True

    if result_type == "image_description":
        context_string = f"---\n## Context {i + 1}\n#### (image description generated by Granite Vision)"
        cb.append(gr.ChatMessage(role="assistant", content=context_string))
        # image_fullpath appears to be a path from the indexing environment, e.g.
        # /dccstor/.../IBM_Annual_Report_2007/images/IBM_Annual_Report_2007_im_image_7_1.png
        # Keep only its last `use_last` components and re-root them under the local prefix.
        image_path_parts = d.metadata["image_fullpath"].split("/")
        image_path = image_paths["prefix"] + "/".join(image_path_parts[-image_paths["use_last"]:])
        cb.append(gr.ChatMessage(role="assistant", content=gr.Image(image_path)))
        cb.append(gr.ChatMessage(role="assistant", content=f"\n{lower_md_headers(d.metadata['image_description'])}\n"))
        # Report success here as well so image results are streamed immediately
        # instead of waiting for the next text result or the LLM step.
        return True

    return False
# Chat panel. It is created before the layout and placed later via
# `chatbot.render()` in the right-hand column, so event listeners wired up in
# the left column can already target it. Starts with the first sample's prompts.
chatbot = gr.Chatbot(
    type="messages",
    label=f"Q&A about {sample_data[0]['name']}",
    examples=[dict(text=prompt) for prompt in sample_data[0]["prompts"]],
    height=685,
    autoscroll=False,
    group_consecutive_messages=True,
    elem_classes=["chatbot_view"],
)
@spaces.GPU()
def generate_with_llm(query, context):
    """Generate an answer for ``query`` grounded in the retrieved ``context``.

    Delegates to ``light_rag.generate``; callers unpack the result as
    ``(answer, prompt)``. When the ``NO_LLM`` environment variable is set to
    any non-empty value, skips the model, sleeps 2 s to mimic latency, and
    returns a placeholder answer paired with ``query``.
    """
    if not os.environ.get("NO_LLM"):
        return light_rag.generate(query=query, context=context)
    # Debug path without a model.
    time.sleep(2)
    return "Now answer, just a string", query
# TODO: consider decorating retrieval with @spaces.GPU() again.
def retrieval(collection, db, q):
    """Search Milvus collection ``collection`` in database ``db`` for ``q`` (``top_n=3``)."""
    hits = light_rag.search(q, collection=collection, db=db, top_n=3)
    return hits
# ################
# User Interface
# ################
# Page assets stored next to this file: the stylesheet and extra <head> markup.
css_file_path = Path(__file__).with_name("app.css")
head_file_path = Path(__file__).with_name("app_head.html")
with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_path, theme=theme, title=TITLE) as demo:
    # Session state is created inside the Blocks context so listeners can use it.
    is_in_edit_mode = gr.State(True)  # True: textbox shown; False: "Ask new question" button shown
    selected_doc = gr.State(0)  # index into sample_data of the active document
    current_question = gr.State("")  # question currently being answered
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        # LEFT COLUMN: sample selection, PDF preview and download.
        with gr.Column():
            images_only = [sd["preview_image"] for sd in sample_data]
            document_gallery = gr.Gallery(
                images_only,
                label="Select a document",
                rows=1,
                columns=3,
                height="125px",
                allow_preview=False,
                selected_index=0,
                elem_classes=["preview_im_element"],
            )
            with gr.Group():
                pdf_display = PDF(
                    sample_data[0]["pdf"],
                    label=f"Preview for {sample_data[0]['name']}",
                    height=460,
                    interactive=False,
                    elem_classes=["pdf_viewer"],
                )
                dl_btn = gr.DownloadButton(
                    label=f"Download PDF ({sample_data[0]['name']})", value=sample_data[0]["pdf"], visible=True
                )

            def sample_image_selected(d: gr.SelectData):
                """Switch the app to the sample document picked in the gallery.

                Returns updates for (chatbot, pdf_display, dl_btn, selected_doc):
                the chat history is cleared and its examples/label replaced, the
                PDF preview and download button point at the new document, and
                the selected index is stored in session state.
                """
                dx = sample_data[d.index]
                return (
                    gr.update(value=[], examples=[{"text": x} for x in dx["prompts"]], label=f"Q&A about {dx['name']}"),
                    gr.update(value=dx["pdf"], label=f"Preview for {dx['name']}"),
                    gr.DownloadButton(value=dx["pdf"], label=f"Download PDF ({dx['name']})"),
                    d.index,
                )

            # One listener clears the chat and swaps examples/label together,
            # instead of two listeners writing to the chatbot concurrently.
            document_gallery.select(sample_image_selected, inputs=[],
                                    outputs=[chatbot, pdf_display, dl_btn, selected_doc])

        # RIGHT COLUMN: chat interface.
        with gr.Column():
            chatbot.render()

            def update_user_chat_x(x: gr.SelectData):
                """Return a history holding only the clicked example as a user message (not wired to any event)."""
                return [gr.ChatMessage(role="user", content=x.value["text"])]

            def question_from_selection(x: gr.SelectData):
                """Return the text of the example prompt the user clicked."""
                return x.value["text"]

            def _decorate_yield_result(cb, fb_status=False, gallery_status=False):
                """Build one streamed update for outputs (chatbot, fb, document_gallery).

                ``fb_status`` sets whether the "Ask new question" button is
                interactive; when ``gallery_status`` is False the gallery gets the
                extra ``inactive_div`` CSS class.
                """
                return cb, gr.Button(interactive=fb_status), gr.Gallery(
                    elem_classes=["preview_im_element"] if gallery_status else ["preview_im_element", "inactive_div"])

            def send_generate(msg, cb, selected_sample):
                """Answer ``msg`` about ``sample_data[selected_sample]``, streaming progress.

                Appends the user message and an answer placeholder to the chat
                history ``cb``, retrieves context from the sample's Milvus
                collection, appends each retrieved chunk, then replaces the
                placeholder with the LLM answer. Yields ``_decorate_yield_result``
                tuples. The button and gallery are re-enabled only on the final yield.
                """
                collection = sample_data[selected_sample]["index"]
                db = sample_data[selected_sample]["db"]
                cb.append(gr.ChatMessage(role="user", content=msg))
                waiting_for_retrieval_msg = gr.ChatMessage(role="assistant",
                    content='## Answer\n*Querying Index*<span class="jumping-dots"><span class="dot-1">.</span> <span class="dot-2">.</span> <span class="dot-3">.</span></span>')
                cb.append(waiting_for_retrieval_msg)
                # Remember where the answer placeholder sits so later status updates
                # and the final answer replace it in place, rather than assuming the
                # history was empty on entry (index 1).
                answer_idx = len(cb) - 1
                yield _decorate_yield_result(cb)

                q = msg.strip()
                results = retrieval(collection, db, q)
                for i, d in enumerate(results):
                    if format_retrieval_result(i, d, cb, selected_sample):
                        yield _decorate_yield_result(cb)

                waiting_for_llm_msg = gr.ChatMessage(role="assistant",
                    content='## Answer\n *Waiting for LLM* <span class="jumping-dots"><span class="dot-1">.</span> <span class="dot-2">.</span> <span class="dot-3">.</span></span> ')
                cb[answer_idx] = waiting_for_llm_msg
                yield _decorate_yield_result(cb)

                answer, _prompt = generate_with_llm(q, results)
                cb[answer_idx] = gr.ChatMessage(role="assistant", content=f"## Answer\n<b>{answer.strip()}</b>")
                yield _decorate_yield_result(cb, fb_status=True, gallery_status=True)

            # User input: a textbox in edit mode, a reset button once an answer is shown.
            tbb = gr.Textbox(submit_btn=True, show_label=False, placeholder="Type a message...")
            fb = gr.Button("Ask new question", visible=False)
            fb.click(lambda: [], outputs=[chatbot])

            # Clicking an example leaves edit mode, stores the question, then answers it.
            chatbot.example_select(lambda: False, outputs=is_in_edit_mode)
            chatbot.example_select(question_from_selection, inputs=[], outputs=[current_question]
                                   ).then(send_generate, inputs=[current_question, chatbot, selected_doc],
                                          outputs=[chatbot, fb, document_gallery])

            def textbox_switch(e_mode):
                """Show the textbox in edit mode, otherwise the "Ask new question" button."""
                show_textbox = bool(e_mode)
                return [gr.update(visible=show_textbox), gr.update(visible=not show_textbox)]

            tbb.submit(lambda: False, outputs=[is_in_edit_mode])
            fb.click(lambda: True, outputs=[is_in_edit_mode])
            is_in_edit_mode.change(textbox_switch, inputs=[is_in_edit_mode], outputs=[tbb, fb])

            # Submitting typed text stores the question, then streams retrieval and the answer.
            tbb.submit(lambda x: x, inputs=[tbb], outputs=[current_question]
                       ).then(send_generate,
                              inputs=[current_question, chatbot, selected_doc],
                              outputs=[chatbot, fb, document_gallery])
if __name__ == "__main__":
    # Launch with Gradio's default settings. To bound pending requests, call
    # e.g. `demo.queue(max_size=20)` before launching.
    demo.launch()