CGQN committed on
Commit 1d84803 · verified · 1 Parent(s): 028fd64

Upload app.py with huggingface_hub

Files changed (1)
  1. app.py +233 -0
app.py ADDED
@@ -0,0 +1,233 @@
+ import os
+ import gradio as gr
+ import torch
+ from PIL import Image
+ from transformers import AutoModel, AutoTokenizer
+
+ # Notes:
+ # - This demo runs on CPU for broader compatibility. It may be slow compared to GPU.
+ # - If you have a GPU, you can set device="cuda" and possibly use torch_dtype=torch.bfloat16.
+ # - MiniCPM-V-4_5 uses trust_remote_code; ensure you trust the source.
+ # - The model expects multi-modal messages in a chat-like format: [{'role': 'user', 'content': [image, text]}]
+ # - For multi-turn chat, we persist history in Gradio state and pass it back to model.chat.
+
+ MODEL_ID = os.environ.get("MINICPM_MODEL_ID", "openbmb/MiniCPM-V-4_5")
+ DEVICE = "cpu"  # This demo targets CPU-only environments
+ DTYPE = torch.float32  # CPU-friendly dtype
+
+ # Lazy global variables (loaded on first launch)
+ _tokenizer = None
+ _model = None
+
+ def load_model():
+     global _tokenizer, _model
+     if _model is None or _tokenizer is None:
+         # Adjust loading options (e.g. local_files_only=True) if your platform requires it.
+         _model = AutoModel.from_pretrained(
+             MODEL_ID,
+             trust_remote_code=True,
+             attn_implementation="sdpa",  # sdpa works on CPU
+             torch_dtype=DTYPE
+         )
+         _model = _model.eval().to(DEVICE)
+         _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+     return _model, _tokenizer
+
+ def format_history(history):
+     """
+     Convert Gradio-style chat history into the model's expected message format.
+     history: list of (user_text, assistant_text) tuples.
+     Not used directly: the raw MiniCPM message structure is kept in Gradio state instead,
+     so images are retained explicitly rather than parsed back out of display text.
+     """
+     return history
+
+ def predict(image, user_message, history_state, enable_thinking=False, stream=False):
+     """
+     image: PIL.Image or None
+     user_message: str
+     history_state: list of dicts in MiniCPM format [{'role': 'user'|'assistant', 'content': [...]}]
+     """
+     model, tokenizer = load_model()
+
+     # Initialize history if empty
+     msgs = history_state if isinstance(history_state, list) else []
+
+     # Build the current user content payload
+     # The model expects a list mixing image(s) and text; include only provided items.
+     content = []
+     if image is not None:
+         if image.mode != "RGB":
+             image = image.convert("RGB")
+         content.append(image)
+     if user_message and user_message.strip():
+         content.append(user_message.strip())
+
+     if len(content) == 0:
+         # predict is a generator, so the validation message must be yielded, not returned
+         yield gr.update(), msgs, "Please provide an image and/or a message."
+         return
+
+     msgs = msgs + [{'role': 'user', 'content': content}]
+
+     # Run generation
+     try:
+         # model.chat returns an iterator of text chunks when stream=True, otherwise a string
+         answer = model.chat(
+             msgs=msgs,
+             tokenizer=tokenizer,
+             enable_thinking=bool(enable_thinking),
+             stream=bool(stream)
+         )
+
+         if stream:
+             # Accumulate streamed text and yield the partial answer after each chunk
+             generated = []
+             for chunk in answer:
+                 generated.append(chunk)
+                 yield "".join(generated), msgs, None
+             final_text = "".join(generated)
+         else:
+             final_text = answer
+
+         # Append the assistant message back into msgs so multi-turn context is preserved
+         msgs = msgs + [{"role": "assistant", "content": [final_text]}]
+
+         # Final yield with the completed answer and the updated history
+         yield final_text, msgs, None
+
+     except Exception as e:
+         yield gr.update(), msgs, f"Error: {e}"
+
+ def clear_state():
+     return None, [], None
+
+ with gr.Blocks(title="MiniCPM-V-4_5 CPU Gradio Demo") as demo:
+     gr.Markdown("# MiniCPM-V-4_5 (CPU) Demo")
+     gr.Markdown("Upload an image (optional) and ask a question. Multi-turn chat is supported. Running on CPU may be slow.")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             image_in = gr.Image(type="pil", label="Image (optional)")
+             user_in = gr.Textbox(label="Your Message", placeholder="Ask a question about the image, or a general question...", lines=3)
+             with gr.Row():
+                 think_chk = gr.Checkbox(label="Enable Thinking Mode", value=False)
+                 stream_chk = gr.Checkbox(label="Stream Output", value=False)
+             with gr.Row():
+                 submit_btn = gr.Button("Send", variant="primary")
+                 clear_btn = gr.Button("Clear")
+
+         with gr.Column(scale=2):
+             # The submit handler below yields (user, assistant) pairs, so use the tuples chat format
+             chat_out = gr.Chatbot(label="Chat", type="tuples", height=450, avatar_images=(None, None))
+             status_box = gr.Markdown("", visible=True)
+
+     # Hidden state: we store the raw MiniCPM messages, not just text pairs
+     state_msgs = gr.State([])
+
+     def on_submit(image, message, enable_thinking, stream, msgs):
+         # Drive the predict generator and rebuild the Chatbot display from the stored
+         # MiniCPM message history on every yield.
+         def format_for_chatbot(msgs_local):
+             chat_pairs = []
+             # Collect (user, assistant) pairs by scanning msgs in order
+             user_tmp = None
+             for m in msgs_local:
+                 if m["role"] == "user":
+                     # Convert content to a displayable string for the Chatbot
+                     parts = []
+                     for c in m["content"]:
+                         if isinstance(c, Image.Image):
+                             parts.append("[Image]")
+                         else:
+                             parts.append(str(c))
+                     user_tmp = " ".join(parts).strip() or "[Image]"
+                 elif m["role"] == "assistant":
+                     assistant_text = " ".join([str(x) for x in m["content"]]) if m["content"] else ""
+                     if user_tmp is None:
+                         chat_pairs.append((None, assistant_text))
+                     else:
+                         chat_pairs.append((user_tmp, assistant_text))
+                     user_tmp = None
+             return chat_pairs
+
+         gen = predict(image, message, msgs, enable_thinking, stream)
+         if stream:
+             for partial_text, updated_msgs, err in gen:
+                 # Build the display history from updated_msgs plus the current partial response
+                 chat_history = format_for_chatbot(updated_msgs)
+                 # While the assistant reply is still pending, updated_msgs ends with the user
+                 # message, so show the partial text as a live (user, partial) pair. The final
+                 # yield already contains the assistant message, so nothing gets duplicated.
+                 if isinstance(partial_text, str) and partial_text and updated_msgs and updated_msgs[-1]["role"] == "user":
+                     last_user = None
+                     for m in reversed(updated_msgs):
+                         if m["role"] == "user":
+                             parts = []
+                             for c in m["content"]:
+                                 if isinstance(c, Image.Image):
+                                     parts.append("[Image]")
+                                 else:
+                                     parts.append(str(c))
+                             last_user = " ".join(parts).strip() or "[Image]"
+                             break
+                     chat_history.append((last_user, partial_text))
+                 status = "" if not err else f"{err}"
+                 yield chat_history, updated_msgs, status, gr.update(value=None), gr.update(value=None)
+         else:
+             for final_text, updated_msgs, err in gen:
+                 # Non-streaming: the final yield's updated_msgs already includes the assistant reply
+                 chat_history = format_for_chatbot(updated_msgs)
+                 status = "" if not err else f"{err}"
+                 yield chat_history, updated_msgs, status, gr.update(value=None), gr.update(value=None)
+
+     submit_btn.click(
+         on_submit,
+         inputs=[image_in, user_in, think_chk, stream_chk, state_msgs],
+         outputs=[chat_out, state_msgs, status_box, user_in, image_in]
+     )
+
+     clear_btn.click(
+         fn=clear_state,
+         inputs=[],
+         outputs=[user_in, state_msgs, status_box]
+     ).then(
+         lambda: [],
+         inputs=None,
+         outputs=chat_out
+     )
+
+     # Preload model on app start (optional; keeps UI responsive on the first query)
+     demo.load(lambda: "Model loading on CPU... Please wait a moment.", outputs=status_box).then(
+         lambda: (load_model() or True) and "Model loaded. Ready!",
+         outputs=status_box
+     )
+
+ if __name__ == "__main__":
+     # Set server_name="0.0.0.0" to expose externally if needed.
+     # Recent Gradio releases use default_concurrency_limit in place of concurrency_count.
+     demo.queue(max_size=8, default_concurrency_limit=1).launch()
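For readers who want to exercise the model outside Gradio, here is a minimal sketch that mirrors the message format predict() builds and the model.chat(msgs=..., tokenizer=...) call used above. It is not part of the committed file; the image path and question are placeholders, and it assumes the same CPU settings as app.py. The imports above also imply the runtime dependencies: gradio, torch, transformers, and Pillow.

    import torch
    from PIL import Image
    from transformers import AutoModel, AutoTokenizer

    # Load MiniCPM-V-4_5 on CPU, as in app.py
    model = AutoModel.from_pretrained(
        "openbmb/MiniCPM-V-4_5",
        trust_remote_code=True,
        attn_implementation="sdpa",
        torch_dtype=torch.float32,
    ).eval().to("cpu")
    tokenizer = AutoTokenizer.from_pretrained("openbmb/MiniCPM-V-4_5", trust_remote_code=True)

    # One user turn: an image plus a text question, in the
    # [{'role': 'user', 'content': [image, text]}] format the app builds
    image = Image.open("example.jpg").convert("RGB")  # placeholder path
    msgs = [{"role": "user", "content": [image, "What is in this image?"]}]

    answer = model.chat(msgs=msgs, tokenizer=tokenizer)
    print(answer)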