import os
import gradio as gr
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
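
# Rough dependency sketch (a minimal, unpinned list; adjust versions to your environment):
#   pip install torch transformers accelerate gradio pillow
# `accelerate` is assumed here because passing `device_map=` to from_pretrained requires it;
# `bitsandbytes` only matters if the optional 4-bit load below is expected to succeed.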

MODEL_ID = os.environ.get("MINICPM_MODEL_ID", "openbmb/MiniCPM-V-4_5")

# Best practice: set a deterministic seed for reproducibility
torch.manual_seed(100)

def load_model(precision_mode="int4"):
    """
    Load MiniCPM-V-4_5 model on CPU with chosen precision.
    - precision_mode: "int4" (default) quantized or "fp16" half precision emulation.
    Note: True FP16 tensors are not supported on CPU; we use bfloat16 or float32 fallback.
    """
    kwargs = dict(trust_remote_code=True, attn_implementation="sdpa")

    if precision_mode == "int4":
        # 4-bit loading via `load_in_4bit` relies on bitsandbytes, which primarily targets CUDA GPUs
        # and may be unavailable in a CPU-only environment, so we attempt the quantized load and
        # fall back to a plain bfloat16 CPU load if it raises.
        try:
            model = AutoModel.from_pretrained(
                MODEL_ID,
                load_in_4bit=True,
                device_map="cpu",
                **kwargs,
            )
            dtype_used = "int4"
        except Exception:
            # Fallback: unquantized bfloat16 load (bf16 runs on CPU; slower without native bf16 support)
            model = AutoModel.from_pretrained(
                MODEL_ID,
                torch_dtype=torch.bfloat16,
                device_map="cpu",
                **kwargs,
            )
            dtype_used = "fallback_bf16"
    else:
        # "fp16" requested: native fp16 inference is poorly supported on CPU, so load bfloat16 instead.
        # Recent Intel/AMD CPUs accelerate bf16; older ones still run it, just through slower fp32 math.
        model = AutoModel.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16,
            device_map="cpu",
            **kwargs,
        )
        dtype_used = "bf16_on_cpu_for_fp16_request"

    model = model.eval()
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    return model, tokenizer, dtype_used

# Global cache to avoid reloading each time
_state = {"model": None, "tokenizer": None, "mode": None, "dtype_used": None}

def ensure_model(mode):
    if _state["model"] is None or _state["mode"] != mode:
        _state["model"], _state["tokenizer"], _state["dtype_used"] = load_model(mode)
        _state["mode"] = mode

def chat_infer(image: Image.Image, message: str, history, mode: str, enable_thinking: bool):
    # Require at least an image or a non-empty text message for the current turn.
    if image is None and not (message and message.strip()):
        return history or [], "Please upload an image or enter a message."
    ensure_model(mode)
    model, tokenizer = _state["model"], _state["tokenizer"]

    # Build msgs from history and current inputs
    msgs = []
    # Convert history into msgs
    # Each item in history is (user, assistant)
    for user_msg, assistant_msg in history or []:
        if user_msg:
            # history may not contain images; only text
            msgs.append({"role": "user", "content": [user_msg]})
        if assistant_msg:
            msgs.append({"role": "assistant", "content": [assistant_msg]})

    # Add current user turn
    user_content = []
    if image is not None:
        # Ensure RGB
        if image.mode != "RGB":
            image = image.convert("RGB")
        user_content.append(image)
    if message and message.strip():
        user_content.append(message.strip())
    if not user_content:
        return history or [], "Please provide text or image."

    msgs.append({"role": "user", "content": user_content})

    try:
        answer = model.chat(
            msgs=msgs,
            tokenizer=tokenizer,
            enable_thinking=enable_thinking,
        )
    except Exception as e:
        return history or [], f"Inference error: {e}"

    # Update history for Gradio chat UI: append the latest pair
    history = (history or []) + [(message or "[Image]", answer)]
    sys_info = f"Mode: {mode} | Loaded dtype: {_state['dtype_used']} | Device: CPU"
    return history, sys_info

def clear_history():
    return [], ""

with gr.Blocks(title="MiniCPM-V-4_5 CPU (int4 default, fp16 optional)", fill_height=True) as demo:
    gr.Markdown("# MiniCPM-V-4_5 CPU Deployment\n- Modes: int4 (default) and fp16\n- Running on CPU")

    with gr.Row():
        with gr.Column(scale=2):
            chatbox = gr.Chatbot(height=420, label="Chat")
            with gr.Row():
                img = gr.Image(type="pil", label="Image (optional)")
            msg = gr.Textbox(placeholder="Ask a question about the image or general query...", lines=3)
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear")
        with gr.Column(scale=1):
            mode = gr.Radio(
                choices=["int4", "fp16"],
                value="int4",
                label="Precision Mode (CPU)",
                info="int4 as default. fp16 uses bf16/fp32 on CPU."
            )
            thinking = gr.Checkbox(label="Enable Thinking Mode", value=False)
            sys_out = gr.Markdown("")

    def on_send(message, image, history, mode, thinking):
        return chat_infer(image, message, history, mode, thinking)

    send_btn.click(
        fn=on_send,
        inputs=[msg, img, chatbox, mode, thinking],
        outputs=[chatbox, sys_out],
        show_progress=True,
    )

    # Submit on Enter
    msg.submit(
        fn=on_send,
        inputs=[msg, img, chatbox, mode, thinking],
        outputs=[chatbox, sys_out],
        show_progress=True,
    )

    clear_btn.click(fn=clear_history, outputs=[chatbox, sys_out])

if __name__ == "__main__":
    # On CPU hosts, capping PyTorch's thread count can reduce contention; override via TORCH_NUM_THREADS.
    torch.set_num_threads(int(os.environ.get("TORCH_NUM_THREADS", "4")))
    demo.launch()
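    # Deployment note (sketch): to serve beyond localhost, e.g. inside a container, one option is
    #   demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", "7860")))
    # where PORT is an assumed environment variable, not something this app defines elsewhere.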