import gradio as gr
import fitz  # PyMuPDF
import torch
import cv2
import os
import tempfile
import shutil
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
import faiss
# Load Qwen-VL-Chat
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL-Chat", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-VL-Chat",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
).eval()
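# Note (assumption about the deployment hardware): torch.bfloat16 needs a GPU
# with bf16 support (Ampere or newer); on older cards torch.float16 is the
# usual substitute.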
# Embedding model
embed_model = SentenceTransformer('all-MiniLM-L6-v2')

# Global state for FAISS
chunks = []
index = None
# PDF processing
def extract_chunks_from_pdf(pdf_path, chunk_size=1000, overlap=200):
    doc = fitz.open(pdf_path)
    text = ""
    for page in doc:
        text += page.get_text()
    doc.close()
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size - overlap)]
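# Chunking note: windows advance by chunk_size - overlap characters, so with
# the defaults chunks start at 0, 800, 1600, ... and neighbouring chunks share
# 200 characters of context.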
def build_faiss_index(chunks):
    embeddings = embed_model.encode(chunks, convert_to_numpy=True)
    dim = embeddings.shape[1]
    idx = faiss.IndexFlatL2(dim)
    idx.add(embeddings)
    return idx
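# IndexFlatL2 performs exact L2 search over all vectors. For cosine
# similarity, a common alternative is to L2-normalise the embeddings and use
# faiss.IndexFlatIP instead.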
def rag_query(query, chunks, index, top_k=3):
    q_emb = embed_model.encode([query], convert_to_numpy=True)
    D, I = index.search(q_emb, top_k)
    return "\n\n".join([chunks[i] for i in I[0]])
# Vision/Text chat
def chat_with_qwen(text=None, image=None):
    elements = []
    if image:
        elements.append({"image": image})
    if text:
        elements.append({"text": text})
    if not elements:
        return "Please upload or type something."
    query = tokenizer.from_list_format(elements)
    response, _ = model.chat(tokenizer, query, history=None)
    return response
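# from_list_format builds Qwen-VL's tagged prompt string; image entries become
# <img>path</img> references that the model resolves from disk during chat.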
# Video frame extraction
def extract_video_frames(video_path, max_frames=3):
    cap = cv2.VideoCapture(video_path)
    frames, count = [], 0
    while len(frames) < max_frames:
        success, frame = cap.read()
        if not success:
            break
        frames.append(frame)
        count += 1
        cap.set(cv2.CAP_PROP_POS_FRAMES, count * 30)
    cap.release()
    return frames
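# Assumption: the fixed 30-frame stride samples roughly one frame per second
# only on ~30 fps footage; for other frame rates, cap.get(cv2.CAP_PROP_FPS)
# gives the actual rate to stride by.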
# Main chatbot logic
def multimodal_chat(message, history, image=None, video=None, pdf=None):
    global chunks, index

    # PDF-based RAG
    if pdf:
        # gr.File may hand back a tempfile wrapper (with .name) or a plain
        # path string, depending on the Gradio version
        pdf_path = pdf.name if hasattr(pdf, "name") else pdf
        chunks = extract_chunks_from_pdf(pdf_path)
        index = build_faiss_index(chunks)
        context = rag_query(message, chunks, index)
        final_prompt = f"Context:\n{context}\n\nQuestion: {message}"
        response = chat_with_qwen(final_prompt)
        return response
    # Image
    if image:
        response = chat_with_qwen(message, image)
        return response
    # Video (extract frames and send all in one call)
    if video:
        temp_dir = tempfile.mkdtemp()
        video_path = os.path.join(temp_dir, "vid.mp4")
        shutil.copy(video, video_path)
        frames = extract_video_frames(video_path)

        # Save and collect image paths. cv2.imwrite expects BGR, which is what
        # cv2.VideoCapture already returns, so no colour conversion is needed
        # (converting to RGB first would write channel-swapped JPEGs).
        images = []
        for i, frame in enumerate(frames):
            temp_img_path = os.path.join(temp_dir, f"frame_{i}.jpg")
            cv2.imwrite(temp_img_path, frame)
            images.append(temp_img_path)

        # Combine all frames and text into one query
        elements = [{"image": img} for img in images]
        if message:
            elements.append({"text": message})
        query = tokenizer.from_list_format(elements)
        response, _ = model.chat(tokenizer, query, history=None)
        shutil.rmtree(temp_dir, ignore_errors=True)
        return response
    # Text only
    if message:
        return chat_with_qwen(message)

    return "Please input a message, image, video, or PDF."
# ---- Gradio UI ---- #
with gr.Blocks(css="""
body {
    background-color: #f3f6fc;
}
.gradio-container {
    font-family: 'Segoe UI', sans-serif;
}
h1 {
    background: linear-gradient(to right, #667eea, #764ba2);
    color: white !important;
    padding: 1rem;
    border-radius: 12px;
    margin-bottom: 0.5rem;
}
p {
    font-size: 1rem;
    color: #444; /* was white, which is unreadable on the light page background */
}
.gr-box {
    background-color: white;
    border-radius: 12px;
    box-shadow: 0 0 10px rgba(0,0,0,0.05);
    padding: 16px;
}
footer {display: none !important;}
""") as demo:
    gr.Markdown(
        "<h1 style='text-align: center;'>Multimodal Chatbot powered by LLAVACMVRL and Qwen-VL</h1>"
        "<p style='text-align: center;'>Ask questions with text, images, videos, or PDFs.</p>"
    )
    chatbot = gr.Chatbot(show_label=False, height=450)
    state = gr.State([])

    with gr.Row():
        txt = gr.Textbox(show_label=False, placeholder="Type a message...", scale=5)
        send_btn = gr.Button("🚀 Send", scale=1)

    with gr.Row():
        image_input = gr.Image(type="filepath", label="Upload Image")
        video_input = gr.Video(label="Upload Video")
        pdf_input = gr.File(file_types=[".pdf"], label="Upload PDF")
    def user_send(message, history, image, video, pdf):
        response = multimodal_chat(message, history, image, video, pdf)
        # The gr.State list is mutated in place, so it stays current even
        # though only the textbox and chatbot are declared as outputs.
        history.append((message, response))
        return "", history

    send_btn.click(user_send, [txt, state, image_input, video_input, pdf_input], [txt, chatbot])
    txt.submit(user_send, [txt, state, image_input, video_input, pdf_input], [txt, chatbot])
# Launch the app
demo.launch()