CGQN committed on
Commit
e7b1930
·
verified ·
1 Parent(s): e4be608

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +156 -0
app.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import torch
4
+ from PIL import Image
5
+ from transformers import AutoModel, AutoTokenizer
6
+
7
# Hub repository id of the checkpoint to load; the MINICPM_MODEL_ID
# environment variable overrides the default.
MODEL_ID = os.getenv("MINICPM_MODEL_ID", "openbmb/MiniCPM-V-4_5")

# Seed torch's global RNG with a fixed value for reproducibility.
torch.manual_seed(100)
11
+
12
def load_model(precision_mode="int4"):
    """Load the model and tokenizer named by ``MODEL_ID`` onto the CPU.

    Args:
        precision_mode: ``"int4"`` first attempts a 4-bit quantized load
            (``load_in_4bit=True``); if that raises, the failure is logged and
            an unquantized load is used instead. Any other value (the UI sends
            ``"fp16"``) goes straight to the unquantized load.

    Returns:
        A ``(model, tokenizer, dtype_used)`` tuple. ``model`` is in eval mode.
        ``dtype_used`` is a descriptive label, one of ``"int4"``,
        ``"fallback_bf16_or_fp32"`` or
        ``"bf16_or_fp32_on_cpu_for_fp16_request"``.

    Raises:
        Whatever ``from_pretrained`` raises for the unquantized load or the
        tokenizer load; only the 4-bit attempt is guarded.
    """
    import logging

    # Arguments shared by every load attempt; weights are always placed on CPU.
    common = dict(
        trust_remote_code=True,
        attn_implementation="sdpa",
        device_map="cpu",
    )

    def _load_unquantized():
        # Single code path for both the "fp16" request and the int4 fallback.
        # NOTE(review): the dtype is keyed on CUDA availability even though the
        # weights go to CPU, so a CPU-only host always gets float32 here.
        # Confirm whether CPU bf16 capability was the intended check.
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        return AutoModel.from_pretrained(MODEL_ID, torch_dtype=dtype, **common)

    if precision_mode == "int4":
        try:
            model = AutoModel.from_pretrained(
                MODEL_ID,
                load_in_4bit=True,
                **common,
            )
            dtype_used = "int4"
        except Exception:
            # Best-effort: 4-bit loading depends on the installed quantization
            # backend. Record why it failed instead of hiding it, then fall
            # back to the unquantized load.
            logging.getLogger(__name__).warning(
                "4-bit load of %s failed; falling back to unquantized weights",
                MODEL_ID,
                exc_info=True,
            )
            model = _load_unquantized()
            dtype_used = "fallback_bf16_or_fp32"
    else:
        model = _load_unquantized()
        dtype_used = "bf16_or_fp32_on_cpu_for_fp16_request"

    model = model.eval()
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    return model, tokenizer, dtype_used
54
+
55
# Process-wide cache so the model is loaded once and reused across requests.
# "mode" records which precision mode the cached entry was loaded for.
_state = {"model": None, "tokenizer": None, "mode": None, "dtype_used": None}


def ensure_model(mode):
    """Make sure ``_state`` holds a model loaded for ``mode``.

    Does nothing when a model for the same mode is already cached. When the
    mode differs, the cached model is released *before* the new one is loaded
    so two full copies of the weights are never resident at the same time.

    Args:
        mode: Precision mode passed through to ``load_model``.

    Raises:
        Whatever ``load_model`` raises. In that case the cache is left empty
        and the next call retries the load.
    """
    if _state["model"] is not None and _state["mode"] == mode:
        return

    if _state["model"] is not None:
        import gc

        # Drop our references to the old model and collect immediately so its
        # memory can be reclaimed before the replacement is loaded.
        _state.update(model=None, tokenizer=None, mode=None, dtype_used=None)
        gc.collect()

    model, tokenizer, dtype_used = load_model(mode)
    _state.update(model=model, tokenizer=tokenizer, mode=mode, dtype_used=dtype_used)
62
+
63
def chat_infer(image: Image.Image, message: str, history, mode: str, enable_thinking: bool):
    """Run one chat turn and return the updated history plus a status line.

    Args:
        image: Optional PIL image for the current turn; converted to RGB if it
            is in another mode.
        message: Text of the current turn; surrounding whitespace is ignored.
        history: Prior turns as ``(user_text, assistant_text)`` pairs, or a
            falsy value when there are none. Only text is replayed from
            history; images from earlier turns are not.
        mode: Precision mode forwarded to ``ensure_model``.
        enable_thinking: Forwarded to ``model.chat``.

    Returns:
        ``(history, status)``. On success ``history`` has the new pair appended
        and ``status`` describes the loaded model. If the turn has neither an
        image nor text, or ``model.chat`` raises, ``history`` is returned
        unchanged and ``status`` carries the explanation.
    """
    history = history or []
    text = (message or "").strip()

    # Build the current user turn first. A turn is valid if it has an image
    # OR text, so a text-only first message is accepted, and an empty submit
    # is rejected before any model load is triggered.
    user_content = []
    if image is not None:
        if image.mode != "RGB":
            image = image.convert("RGB")
        user_content.append(image)
    if text:
        user_content.append(text)
    if not user_content:
        return history, "Please provide text or image."

    ensure_model(mode)
    model, tokenizer = _state["model"], _state["tokenizer"]

    # Replay earlier turns as alternating user/assistant text messages.
    msgs = []
    for user_msg, assistant_msg in history:
        if user_msg:
            msgs.append({"role": "user", "content": [user_msg]})
        if assistant_msg:
            msgs.append({"role": "assistant", "content": [assistant_msg]})
    msgs.append({"role": "user", "content": user_content})

    try:
        answer = model.chat(
            msgs=msgs,
            tokenizer=tokenizer,
            enable_thinking=enable_thinking,
        )
    except Exception as e:
        # UI boundary: report the failure in the status area and keep the
        # conversation intact.
        return history, f"Inference error: {e}"

    # Image-only turns are shown with a placeholder in the chat transcript.
    history = history + [(text or "[Image]", answer)]
    sys_info = f"Mode: {mode} | Loaded dtype: {_state['dtype_used']} | Device: CPU"
    return history, sys_info
107
+
108
def clear_history():
    """Return an empty chat history and an empty status string."""
    empty_chat = []
    blank_status = ""
    return empty_chat, blank_status
110
+
111
with gr.Blocks(title="MiniCPM-V-4_5 CPU (int4 default, fp16 optional)", fill_height=True) as demo:
    gr.Markdown("# MiniCPM-V-4_5 CPU Deployment\n- Modes: int4 (default) and fp16\n- Running on CPU")

    with gr.Row():
        # Wide column: the conversation, the inputs and the action buttons.
        with gr.Column(scale=2):
            chatbox = gr.Chatbot(height=420, label="Chat")
            with gr.Row():
                img = gr.Image(type="pil", label="Image (optional)")
                msg = gr.Textbox(
                    placeholder="Ask a question about the image or general query...",
                    lines=3,
                )
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear")
        # Narrow column: settings and the status readout.
        with gr.Column(scale=1):
            mode = gr.Radio(
                choices=["int4", "fp16"],
                value="int4",
                label="Precision Mode (CPU)",
                info="int4 as default. fp16 uses bf16/fp32 on CPU.",
            )
            thinking = gr.Checkbox(label="Enable Thinking Mode", value=False)
            sys_out = gr.Markdown("")

    def on_send(message, image, history, mode, thinking):
        """Adapt the UI input order to the argument order of chat_infer."""
        return chat_infer(
            image=image,
            message=message,
            history=history,
            mode=mode,
            enable_thinking=thinking,
        )

    def _wire_send(register):
        # Attach on_send to one trigger; fresh lists for every listener.
        register(
            fn=on_send,
            inputs=[msg, img, chatbox, mode, thinking],
            outputs=[chatbox, sys_out],
            show_progress=True,
        )

    # The Send button and pressing Enter in the textbox do the same thing.
    _wire_send(send_btn.click)
    _wire_send(msg.submit)

    clear_btn.click(fn=clear_history, outputs=[chatbox, sys_out])
152
+
153
if __name__ == "__main__":
    # Cap torch's thread count to reduce contention on many-core CPU hosts;
    # TORCH_NUM_THREADS overrides the default of 4.
    thread_cap = int(os.getenv("TORCH_NUM_THREADS", "4"))
    torch.set_num_threads(thread_cap)
    demo.launch()