import os

import gradio as gr
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

MODEL_ID = os.environ.get("MINICPM_MODEL_ID", "openbmb/MiniCPM-V-4_5")

# Best practice: set a deterministic seed for reproducibility
torch.manual_seed(100)

def load_model(precision_mode="int4"):
    """
    Load MiniCPM-V-4_5 on CPU with the chosen precision.
    - precision_mode: "int4" (default) for quantized loading, or "fp16" for half-precision emulation.
    Note: float16 inference is poorly supported on CPU, so the "fp16" mode actually runs in
    bfloat16 (ops fall back to fp32 math on CPUs without native bf16 support).
    """
    kwargs = dict(trust_remote_code=True, attn_implementation="sdpa")
    if precision_mode == "int4":
        # 4-bit loading via bitsandbytes is not consistently available in CPU-only environments,
        # but MiniCPM's trust_remote_code path may support it. Try load_in_4bit and fall back
        # to a plain bfloat16 load if 4-bit is unsupported here.
        try:
            model = AutoModel.from_pretrained(
                MODEL_ID,
                load_in_4bit=True,
                device_map="cpu",
                **kwargs,
            )
            dtype_used = "int4"
        except Exception:
            # Fallback: load in bfloat16 if 4-bit isn't supported in this environment.
            model = AutoModel.from_pretrained(
                MODEL_ID,
                torch_dtype=torch.bfloat16,
                device_map="cpu",
                **kwargs,
            )
            dtype_used = "fallback_bf16"
    else:
        # "fp16" requested: CPU cannot run native fp16 efficiently, so we use bfloat16 instead.
        # Many recent Intel/AMD CPUs accelerate bfloat16; if not, it still runs via fp32 math.
        model = AutoModel.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16,
            device_map="cpu",
            **kwargs,
        )
        dtype_used = "bf16_on_cpu_for_fp16_request"
    model = model.eval()
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    return model, tokenizer, dtype_used
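
# Example of calling load_model directly (for a local smoke test outside the Gradio UI;
# illustrative only, not part of the app flow):
#   model, tokenizer, dtype_used = load_model("int4")
#   print(dtype_used)  # "int4" if the 4-bit path worked, otherwise "fallback_bf16"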

# Global cache to avoid reloading each time
_state = {"model": None, "tokenizer": None, "mode": None, "dtype_used": None}


def ensure_model(mode):
    if _state["model"] is None or _state["mode"] != mode:
        _state["model"], _state["tokenizer"], _state["dtype_used"] = load_model(mode)
        _state["mode"] = mode

def chat_infer(image: Image.Image, message: str, history, mode: str, enable_thinking: bool):
    # Reject a completely empty first turn: no image, no text, and no prior history.
    if image is None and not (message and message.strip()) and not history:
        return history or [], "Please upload an image or enter a message."
    ensure_model(mode)
    model, tokenizer = _state["model"], _state["tokenizer"]
    # Build msgs from the chat history plus the current inputs.
    msgs = []
    # Each item in history is a (user, assistant) pair; history only carries text, not images.
    for user_msg, assistant_msg in history or []:
        if user_msg:
            msgs.append({"role": "user", "content": [user_msg]})
        if assistant_msg:
            msgs.append({"role": "assistant", "content": [assistant_msg]})
    # Add the current user turn.
    user_content = []
    if image is not None:
        # Ensure RGB
        if image.mode != "RGB":
            image = image.convert("RGB")
        user_content.append(image)
    if message and message.strip():
        user_content.append(message.strip())
    if not user_content:
        return history or [], "Please provide text or an image."
    msgs.append({"role": "user", "content": user_content})
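    # For reference, msgs now looks roughly like the structure below; the content lists mix
    # PIL images and strings, matching the chat call made next:
    #   [
    #     {"role": "user", "content": ["earlier question"]},
    #     {"role": "assistant", "content": ["earlier answer"]},
    #     {"role": "user", "content": [<PIL.Image>, "current question"]},
    #   ]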
    try:
        answer = model.chat(
            msgs=msgs,
            tokenizer=tokenizer,
            enable_thinking=enable_thinking,
        )
    except Exception as e:
        return history or [], f"Inference error: {e}"
    # Update history for the Gradio chat UI: append the latest (user, assistant) pair
    history = (history or []) + [(message or "[Image]", answer)]
    sys_info = f"Mode: {mode} | Loaded dtype: {_state['dtype_used']} | Device: CPU"
    return history, sys_info

def clear_history():
    return [], ""

with gr.Blocks(title="MiniCPM-V-4_5 CPU (int4 default, fp16 optional)", fill_height=True) as demo:
    gr.Markdown("# MiniCPM-V-4_5 CPU Deployment\n- Modes: int4 (default) and fp16\n- Running on CPU")
    with gr.Row():
        with gr.Column(scale=2):
            chatbox = gr.Chatbot(height=420, label="Chat")
            with gr.Row():
                img = gr.Image(type="pil", label="Image (optional)")
                msg = gr.Textbox(placeholder="Ask a question about the image or general query...", lines=3)
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear")
        with gr.Column(scale=1):
            mode = gr.Radio(
                choices=["int4", "fp16"],
                value="int4",
                label="Precision Mode (CPU)",
                info="int4 as default. fp16 uses bf16/fp32 on CPU.",
            )
            thinking = gr.Checkbox(label="Enable Thinking Mode", value=False)
            sys_out = gr.Markdown("")

    def on_send(message, image, history, mode, thinking):
        return chat_infer(image, message, history, mode, thinking)

    send_btn.click(
        fn=on_send,
        inputs=[msg, img, chatbox, mode, thinking],
        outputs=[chatbox, sys_out],
        show_progress=True,
    )
    # Submit on Enter
    msg.submit(
        fn=on_send,
        inputs=[msg, img, chatbox, mode, thinking],
        outputs=[chatbox, sys_out],
        show_progress=True,
    )
    clear_btn.click(fn=clear_history, outputs=[chatbox, sys_out])
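
# On a shared CPU Space it can help to serialize requests with Gradio's queue. Whether to
# enable it, and what max_size to use, is a deployment choice rather than part of the setup above:
# demo.queue(max_size=8)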

if __name__ == "__main__":
    # On CPU hosts with many cores, limiting intra-op threads can reduce contention.
    torch.set_num_threads(int(os.environ.get("TORCH_NUM_THREADS", "4")))
    demo.launch()
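
# A plausible requirements.txt for this Space (dependency list inferred from the imports and
# loading paths above; exact versions/pins are an assumption, not specified here):
#   gradio
#   torch
#   transformers
#   Pillow
#   accelerate     # needed for device_map-based loading
#   bitsandbytes   # only used if the 4-bit load path is available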