Create app.py
app.py
ADDED
@@ -0,0 +1,176 @@
import torch
from transformers import AutoProcessor, AutoModelForCausalLM, GenerationConfig, TextIteratorStreamer
from PIL import Image
import gradio as gr
from threading import Thread
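
# Dependency note (assumed from the imports above, not a pinned list): running this
# app needs at least gradio, torch, transformers, and pillow; device_map="auto"
# below also relies on the accelerate package, and the model's remote code may pull
# in further packages such as einops.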

# --- 1. Model and Processor Setup ---
# This part is loaded only once when the script starts.

try:
    model_id = "bharatgenai/patram-7b-instruct"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load processor and model
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,  # Use float16 for less memory usage on GPU
        device_map="auto",          # Automatically uses available GPUs
        trust_remote_code=True
    )
    print("Model and processor loaded successfully.")

    # --- Define and apply the chat template ---
    # This is crucial for multi-turn conversation
    chat_template = """{% for message in messages -%}
    {%- if (loop.index % 2 == 1 and message['role'] != 'user') or
           (loop.index % 2 == 0 and message['role'].lower() != 'assistant') -%}
        {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
    {%- endif -%}
    {{ message['role'].capitalize() + ': ' + message['content'] }}
    {%- if not loop.last -%}
        {{ ' ' }}
    {%- endif %}
{%- endfor -%}
{%- if add_generation_prompt -%}
    {{ ' Assistant:' }}
{%- endif %}"""
    processor.tokenizer.chat_template = chat_template

except Exception as e:
    print(f"Error during model loading: {e}")
    # Exit if model can't be loaded, as the app is unusable.
    exit()
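
# A minimal illustration (not executed by the app) of what the template above renders
# for a short history; the messages here are made-up examples:
#
#   example_messages = [
#       {"role": "user", "content": "What type of document is this?"},
#       {"role": "assistant", "content": "It looks like a scanned invoice."},
#       {"role": "user", "content": "Who issued it?"},
#   ]
#   print(processor.tokenizer.apply_chat_template(
#       example_messages, tokenize=False, add_generation_prompt=True))
#
# Expected output (roughly):
#   User: What type of document is this? Assistant: It looks like a scanned invoice. User: Who issued it? Assistant: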

# --- 2. Gradio Chatbot Logic with Streaming ---

def process_chat_streaming(user_message, chatbot_display, messages_list, image_pil):
    """
    This generator function handles the chat logic with streaming.
    It yields the updated chatbot display at each step.
    """
    # Check if an image has been uploaded
    if image_pil is None:
        chatbot_display.append((user_message, "Please upload an image first to start the conversation."))
        yield chatbot_display, messages_list
        return  # Stop the generator

    # Append the user's message to the conversation history and display
    messages_list.append({"role": "user", "content": user_message})
    chatbot_display.append((user_message, ""))  # Add an empty slot for the streaming response

    try:
        # Use the processor to apply the chat template
        prompt = processor.tokenizer.apply_chat_template(
            messages_list,
            tokenize=False,
            add_generation_prompt=True
        )

        # Preprocess the image and the entire formatted prompt
        inputs = processor.process(images=[image_pil], text=prompt)
        inputs = {k: v.to(device).unsqueeze(0) for k, v in inputs.items()}

        # Set up the streamer
        streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)

        # Define the generation configuration
        generation_config = GenerationConfig(
            max_new_tokens=512,
            do_sample=True,
            top_p=0.9,
            temperature=0.6,
            stop_strings=["<|endoftext|>", "User:"]  # Stop strings to prevent over-generation
        )

        # Run generation in a separate thread so the streamer can be consumed below.
        # The processed batch is passed positionally, and the tokenizer is supplied
        # because stop_strings requires it (argument handling assumed to match the
        # model card's generate_from_batch usage).
        generation_kwargs = dict(
            generation_config=generation_config,
            streamer=streamer,
            tokenizer=processor.tokenizer
        )
        thread = Thread(target=model.generate_from_batch, args=(inputs,), kwargs=generation_kwargs)
        thread.start()

        # Yield updates to the Gradio UI as tokens arrive
        full_response = ""
        for new_text in streamer:
            full_response += new_text
            chatbot_display[-1] = (user_message, full_response)
            yield chatbot_display, messages_list

        # After the loop, the generation is complete.
        # Add the final full response to the messages list for context.
        messages_list.append({"role": "assistant", "content": full_response})
        yield chatbot_display, messages_list  # Yield the final state

    except Exception as e:
        print(f"Error during streaming inference: {e}")
        error_message = f"Sorry, an error occurred: {e}"
        chatbot_display[-1] = (user_message, error_message)
        yield chatbot_display, messages_list
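
# A quick, optional sanity check of the generator above outside Gradio (a sketch,
# assuming the model loaded and a local test image at "sample.jpg", a hypothetical path):
#
#   test_image = Image.open("sample.jpg")
#   for display, history in process_chat_streaming("Describe this document.", [], [], test_image):
#       pass
#   print(history[-1]["content"])  # final assistant reply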

# --- 3. Gradio Interface Definition ---

with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="neutral")) as demo:
    gr.Markdown("# 🤖 Patram-7B-Instruct Streaming Chatbot")
    gr.Markdown("Upload an image and ask questions about it. The response will stream in real-time.")

    # State variable to hold the conversation history
    messages_list = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Image")
            clear_btn = gr.Button("🗑️ Clear Chat and Image")

        with gr.Column(scale=2):
            chatbot_display = gr.Chatbot(
                label="Conversation",
                bubble_full_width=False,
                height=500,
                avatar_images=(None, "https://cdn-avatars.huggingface.co/v1/production/uploads/67b462a1f4f414c2b3e2bc2f/EnVeNWEIeZ6yF6ueZ7E3Y.jpeg")
            )
            with gr.Row():
                user_textbox = gr.Textbox(
                    placeholder="Type your question here...",
                    show_label=False,
                    scale=4,
                    container=False
                )
            # There is no separate submit button; pressing Enter in the textbox submits the message.

    # --- Event Listeners ---

    # Define the action for submitting a message (via the Enter key)
    submit_action = user_textbox.submit(
        fn=process_chat_streaming,
        inputs=[user_textbox, chatbot_display, messages_list, image_input],
        outputs=[chatbot_display, messages_list],
        # queue is left at its default (enabled): generator-based streaming relies on the queue
    )

    # Chain the action to also clear the textbox after submission
    submit_action.then(
        fn=lambda: gr.update(value=""),
        inputs=None,
        outputs=[user_textbox],
        queue=False
    )

    # Define the action for the clear button
    clear_btn.click(
        fn=lambda: ([], [], None, ""),  # Return empty/default values for each output
        inputs=[],
        outputs=[chatbot_display, messages_list, image_input, user_textbox],
        queue=False
    )


if __name__ == "__main__":
    demo.queue().launch(debug=True, share=True)
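
Usage note: on a Gradio Space this file runs automatically as the app entry point; locally, running python app.py starts the interface, and because share=True is set, Gradio also prints a temporary public link alongside the local URL.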