import gradio as gr
import fitz  # PyMuPDF
import torch
import cv2
import os
import tempfile
import shutil
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
import faiss

# Load Qwen-VL-Chat
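# trust_remote_code is required because Qwen-VL ships custom modelling code on
# the Hub; bfloat16 assumes an Ampere-or-newer GPU (use float16 on older cards).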
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL-Chat", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-VL-Chat",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
).eval()

# Embedding model
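# all-MiniLM-L6-v2 maps text to 384-dim vectors: small, fast, and good enough
# for retrieving chunks out of a single document.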
embed_model = SentenceTransformer('all-MiniLM-L6-v2')

# Global state for FAISS
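# (Module-level globals are shared across Gradio sessions: fine for a
# single-user demo, but concurrent users would overwrite each other's index.)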
chunks = []
index = None

# PDF processing
def extract_chunks_from_pdf(pdf_path, chunk_size=1000, overlap=200):
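    # Concatenate the text of every page, then slice it into overlapping
    # windows; stepping by chunk_size - overlap leaves `overlap` characters
    # shared between neighbouring chunks so context isn't cut cold at the edges.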
    doc = fitz.open(pdf_path)
    text = ""
    for page in doc:
        text += page.get_text()
    doc.close()
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size - overlap)]

def build_faiss_index(chunks):
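    # IndexFlatL2 performs exact (brute-force) nearest-neighbour search, which
    # is plenty for the handful of chunks a single PDF produces.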
    embeddings = embed_model.encode(chunks, convert_to_numpy=True)
    dim = embeddings.shape[1]
    idx = faiss.IndexFlatL2(dim)
    idx.add(embeddings)
    return idx

def rag_query(query, chunks, index, top_k=3):
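    # Embed the query, retrieve the top_k closest chunks by L2 distance, and
    # join them into one context string for the prompt.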
    q_emb = embed_model.encode([query], convert_to_numpy=True)
    D, I = index.search(q_emb, top_k)
    return "\n\n".join([chunks[i] for i in I[0]])

# Vision/Text chat
def chat_with_qwen(text=None, image=None):
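    # from_list_format turns a list of {"image": path} / {"text": str} dicts
    # into Qwen-VL's tagged multimodal prompt format.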
    elements = []
    if image:
        elements.append({"image": image})
    if text:
        elements.append({"text": text})
    if not elements:
        return "Please upload or type something."
    query = tokenizer.from_list_format(elements)
    response, _ = model.chat(tokenizer, query, history=None)
    return response

# Video frame extraction
def extract_video_frames(video_path, max_frames=3):
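    # Sample up to max_frames frames, seeking ahead by 30 frames (~1 s at
    # 30 fps) after each read so the samples are spread across the clip.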
    cap = cv2.VideoCapture(video_path)
    frames, count = [], 0
    while len(frames) < max_frames:
        success, frame = cap.read()
        if not success:
            break
        frames.append(frame)
        count += 1
        cap.set(cv2.CAP_PROP_POS_FRAMES, count * 30)
    cap.release()
    return frames

# Main chatbot logic
def multimodal_chat(message, history, image=None, video=None, pdf=None):
    global chunks, index

    # PDF-based RAG
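    # Rebuilding the chunks and index on every turn is simple but wasteful;
    # caching by pdf.name would avoid re-embedding the same document.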
    if pdf:
        chunks = extract_chunks_from_pdf(pdf.name)
        index = build_faiss_index(chunks)
        context = rag_query(message, chunks, index)
        final_prompt = f"Context:\n{context}\n\nQuestion: {message}"
        response = chat_with_qwen(final_prompt)
        return response

    # Image
    if image:
        response = chat_with_qwen(message, image)
        return response

    # Video (extract frames and send all in one call)
    if video:
        temp_dir = tempfile.mkdtemp()
        video_path = os.path.join(temp_dir, "vid.mp4")
        shutil.copy(video, video_path)
        frames = extract_video_frames(video_path)

        # Save and collect image paths
        images = []
        for i, frame in enumerate(frames):
            temp_img_path = os.path.join(temp_dir, f"frame_{i}.jpg")
            # cv2.imwrite expects BGR (OpenCV's native channel order), so write
            # the frame as-is; converting to RGB first would swap the saved colours.
            cv2.imwrite(temp_img_path, frame)
            images.append(temp_img_path)

        # Combine all frames and text into one query
        elements = [{"image": img} for img in images]
        if message:
            elements.append({"text": message})

        query = tokenizer.from_list_format(elements)
        response, _ = model.chat(tokenizer, query, history=None)
        # The frames are read during model.chat, so the temp dir can go now.
        shutil.rmtree(temp_dir, ignore_errors=True)
        return response

    # Text only
    if message:
        return chat_with_qwen(message)

    return "Please input a message, image, video, or PDF."

# ---- Gradio UI ---- #
with gr.Blocks(css="""
body {
background-color: #f3f6fc;
}
.gradio-container {
font-family: 'Segoe UI', sans-serif;
}
h1 {
background: linear-gradient(to right, #667eea, #764ba2);
color: white !important;
padding: 1rem;
border-radius: 12px;
margin-bottom: 0.5rem;
}
p {
font-size: 1rem;
color: #333;
}
.gr-box {
background-color: white;
border-radius: 12px;
box-shadow: 0 0 10px rgba(0,0,0,0.05);
padding: 16px;
}
footer {display: none !important;}
""") as demo:
    gr.Markdown(
        "<h1 style='text-align: center;'>Multimodal Chatbot powered by LLAVACMVRL and QWEN-VL</h1>"
        "<p style='text-align: center;'>Ask questions with text, images, videos, or PDFs in a smart and multimodal way.</p>"
    )

    chatbot = gr.Chatbot(show_label=False, height=450)
    state = gr.State([])

    with gr.Row():
        txt = gr.Textbox(show_label=False, placeholder="Type a message...", scale=5)
        send_btn = gr.Button("🚀 Send", scale=1)

    with gr.Row():
        image_input = gr.Image(type="filepath", label="Upload Image")
        video_input = gr.Video(label="Upload Video")
        pdf_input = gr.File(file_types=[".pdf"], label="Upload PDF")

    def user_send(message, history, image, video, pdf):
        response = multimodal_chat(message, history, image, video, pdf)
        history.append((message, response))
        # Return history to both the Chatbot and the State so the transcript
        # persists across turns instead of relying on in-place mutation.
        return "", history, history

    send_btn.click(user_send, [txt, state, image_input, video_input, pdf_input], [txt, chatbot, state])
    txt.submit(user_send, [txt, state, image_input, video_input, pdf_input], [txt, chatbot, state])

# Launch the app
demo.launch()