#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Ovis-U1-3B multimodal demo (CPU / GPU adaptive version)
Dependencies: Python 3.10+, torch 2.*, transformers 4.41.*, gradio 4.*
"""

# ───────────────────────────────────────────────
# ① Prepare the environment before any transformers / flash_attn import
# ───────────────────────────────────────────────
import os, sys, types, subprocess, random, numpy as np, torch
import importlib.util  # ★ New: used to build a ModuleSpec for the stub module

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32

# -------- CPU environment: stub out flash-attn --------
if DEVICE == "cpu":
    # Uninstall any flash-attn that may already be present
    subprocess.run("pip uninstall -y flash-attn", shell=True,
                   stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

    # Build stub modules so `import flash_attn...` still succeeds on CPU
    fake_flash_attn = types.ModuleType("flash_attn")
    fake_layers = types.ModuleType("flash_attn.layers")
    fake_rotary = types.ModuleType("flash_attn.layers.rotary")

    def _cpu_apply_rotary_emb(x, cos, sin):
        """Pure-CPU rotary position embedding (simple implementation)."""
        x1, x2 = x[..., ::2], x[..., 1::2]
        rot_x1 = x1 * cos - x2 * sin
        rot_x2 = x1 * sin + x2 * cos
        out = torch.empty_like(x)
        out[..., ::2] = rot_x1
        out[..., 1::2] = rot_x2
        return out

    fake_rotary.apply_rotary_emb = _cpu_apply_rotary_emb
    fake_layers.rotary = fake_rotary
    fake_flash_attn.layers = fake_layers

    # ★ New: give the stub module a valid __spec__
    fake_flash_attn.__spec__ = importlib.util.spec_from_loader("flash_attn", loader=None)

    sys.modules.update({
        "flash_attn": fake_flash_attn,
        "flash_attn.layers": fake_layers,
        "flash_attn.layers.rotary": fake_rotary,
    })
else:
    # GPU environment: try to install flash-attn
    try:
        subprocess.run(
            "pip install flash-attn==2.6.3 --no-build-isolation",
            # Keep the current environment and only add the skip-build flag;
            # replacing the whole env would drop PATH and break pip.
            env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
            shell=True, check=True)
    except subprocess.CalledProcessError:
        print("[WARN] flash-attn installation failed; GPU acceleration will be limited.")

# ───────────────────────────────────────────────
# ② Regular dependencies
# ───────────────────────────────────────────────
from PIL import Image
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM

from test_img_edit import pipe_img_edit
from test_img_to_txt import pipe_txt_gen
from test_txt_to_img import pipe_t2i

# ───────────────────────────────────────────────
# ③ Utility functions & constants
# ───────────────────────────────────────────────
MAX_SEED = 10_000

def set_global_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def randomize_seed_fn(seed: int, randomize: bool) -> int:
    return random.randint(0, MAX_SEED) if randomize else seed

# ───────────────────────────────────────────────
# ④ Load the model
# ───────────────────────────────────────────────
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_ID = "AIDC-AI/Ovis-U1-3B"

print(f"[INFO] Loading {MODEL_ID} on {DEVICE} …")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    low_cpu_mem_usage=True,
    device_map="auto",
    token=HF_TOKEN,
    trust_remote_code=True
).eval()
print("[INFO] Model ready!")

# ───────────────────────────────────────────────
# ⑤ Inference wrappers
# ───────────────────────────────────────────────
def process_txt_to_img(prompt, height, width, steps, seed, cfg,
                       progress=gr.Progress(track_tqdm=True)):
    set_global_seed(seed)
    return pipe_t2i(model, prompt, height, width, steps, cfg=cfg, seed=seed)

def process_img_to_txt(prompt, img, progress=gr.Progress(track_tqdm=True)):
    return pipe_txt_gen(model, img, prompt)

def process_img_txt_to_img(prompt, img, steps, seed, txt_cfg, img_cfg,
                           progress=gr.Progress(track_tqdm=True)):
    set_global_seed(seed)
    return pipe_img_edit(model, img, prompt, steps, txt_cfg, img_cfg, seed=seed)
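# Optional smoke test for the wrappers above. This is a minimal sketch and not part
# of the original demo: the OVIS_SMOKE_TEST env var and the sample prompt are
# hypothetical, and it assumes pipe_t2i returns a list of images, as the Gallery
# wiring in the UI below implies. It only runs when explicitly enabled.
if os.getenv("OVIS_SMOKE_TEST") == "1":
    _smoke_imgs = process_txt_to_img("a watercolor fox", 512, 512, 50, 42, 5.0)
    print(f"[INFO] Smoke test produced {len(_smoke_imgs)} image(s).")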
# ───────────────────────────────────────────────
# ⑥ Gradio UI (same as the previous version; change markers omitted here)
# ───────────────────────────────────────────────
with gr.Blocks(title="Ovis-U1-3B (CPU/GPU adaptive)") as demo:
    gr.Markdown("# Ovis-U1-3B\nMultimodal text-image demo (CPU/GPU adaptive version)")

    with gr.Row():
        with gr.Column():
            with gr.Tabs():
                # Tab 1: Image + Text → Image
                with gr.TabItem("Image + Text → Image"):
                    edit_image_input = gr.Image(label="Input Image", type="pil")
                    with gr.Row():
                        edit_prompt_input = gr.Textbox(show_label=False, placeholder="Describe the editing instruction…")
                        run_edit_image_btn = gr.Button("Run", scale=0)
                    with gr.Accordion("Advanced Settings", open=False):
                        with gr.Row():
                            edit_img_guidance = gr.Slider(label="Image Guidance", minimum=1, maximum=10, value=1.5, step=0.1)
                            edit_txt_guidance = gr.Slider(label="Text Guidance", minimum=1, maximum=30, value=6.0, step=0.5)
                        edit_steps = gr.Slider(label="Steps", minimum=40, maximum=100, value=50, step=1)
                        edit_seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, value=42, step=1)
                        edit_random = gr.Checkbox(label="Randomize seed", value=False)

                # Tab 2: Text → Image
                with gr.TabItem("Text → Image"):
                    prompt_gen = gr.Textbox(show_label=False, placeholder="Describe the image you want…")
                    run_gen_btn = gr.Button("Run", scale=0)
                    with gr.Accordion("Advanced Settings", open=False):
                        with gr.Row():
                            height_slider = gr.Slider(label="height", minimum=256, maximum=1536, value=1024, step=32)
                            width_slider = gr.Slider(label="width", minimum=256, maximum=1536, value=1024, step=32)
                        guidance_slider = gr.Slider(label="Guidance Scale", minimum=1, maximum=30, value=5, step=0.5)
                        steps_slider = gr.Slider(label="Steps", minimum=40, maximum=100, value=50, step=1)
                        seed_slider = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, value=42, step=1)
                        random_check = gr.Checkbox(label="Randomize seed", value=False)

                # Tab 3: Image → Text
                with gr.TabItem("Image → Text"):
                    understand_img = gr.Image(label="Input Image", type="pil")
                    understand_prompt = gr.Textbox(show_label=False, placeholder="Describe the question about image…")
                    run_understand = gr.Button("Run", scale=0)

            clear_btn = gr.Button("Clear All")

        with gr.Column():
            gallery = gr.Gallery(label="Generated Images", columns=2, visible=True)
            txt_out = gr.Textbox(label="Generated Text", visible=False, lines=5, interactive=False)

    # Event handlers (same as the previous version; repeated comments omitted)
    def run_tab1(prompt, img, steps, seed, txt_cfg, img_cfg, progress=gr.Progress(track_tqdm=True)):
        if img is None:
            return gr.update(value=[], visible=False), gr.update(value="Please upload an image.", visible=True)
        imgs = process_img_txt_to_img(prompt, img, steps, seed, txt_cfg, img_cfg, progress)
        return gr.update(value=imgs, visible=True), gr.update(value="", visible=False)

    def run_tab2(prompt, h, w, steps, seed, guidance, progress=gr.Progress(track_tqdm=True)):
        imgs = process_txt_to_img(prompt, h, w, steps, seed, guidance, progress)
        return gr.update(value=imgs, visible=True), gr.update(value="", visible=False)

    def run_tab3(img, prompt, progress=gr.Progress(track_tqdm=True)):
        if img is None:
            return gr.update(value=[], visible=False), gr.update(value="Please upload an image.", visible=True)
        text = process_img_to_txt(prompt, img, progress)
        return gr.update(value=[], visible=False), gr.update(value=text, visible=True)

    # Tab 1 bindings
    run_edit_image_btn.click(randomize_seed_fn, [edit_seed, edit_random], [edit_seed]).then(
        run_tab1,
        [edit_prompt_input, edit_image_input, edit_steps, edit_seed, edit_txt_guidance, edit_img_guidance],
        [gallery, txt_out])
    edit_prompt_input.submit(randomize_seed_fn, [edit_seed, edit_random], [edit_seed]).then(
        run_tab1,
        [edit_prompt_input, edit_image_input, edit_steps, edit_seed, edit_txt_guidance, edit_img_guidance],
        [gallery, txt_out])
    # Tab 2 bindings
    run_gen_btn.click(randomize_seed_fn, [seed_slider, random_check], [seed_slider]).then(
        run_tab2,
        [prompt_gen, height_slider, width_slider, steps_slider, seed_slider, guidance_slider],
        [gallery, txt_out])
    prompt_gen.submit(randomize_seed_fn, [seed_slider, random_check], [seed_slider]).then(
        run_tab2,
        [prompt_gen, height_slider, width_slider, steps_slider, seed_slider, guidance_slider],
        [gallery, txt_out])

    # Tab 3 bindings
    run_understand.click(run_tab3, [understand_img, understand_prompt], [gallery, txt_out])
    understand_prompt.submit(run_tab3, [understand_img, understand_prompt], [gallery, txt_out])

    # Clear everything back to the defaults
    def clear_all():
        return (
            gr.update(value=None), gr.update(value=""), gr.update(value=1.5), gr.update(value=6.0),
            gr.update(value=50), gr.update(value=42), gr.update(value=False),
            gr.update(value=""), gr.update(value=1024), gr.update(value=1024), gr.update(value=5),
            gr.update(value=50), gr.update(value=42), gr.update(value=False),
            gr.update(value=None), gr.update(value=""),
            gr.update(value=[], visible=True), gr.update(value="", visible=False)
        )

    clear_btn.click(clear_all, [], [
        edit_image_input, edit_prompt_input, edit_img_guidance, edit_txt_guidance,
        edit_steps, edit_seed, edit_random,
        prompt_gen, height_slider, width_slider, guidance_slider, steps_slider,
        seed_slider, random_check,
        understand_img, understand_prompt,
        gallery, txt_out
    ])

# ───────────────────────────────────────────────
# ⑦ Launch
# ───────────────────────────────────────────────
if __name__ == "__main__":
    demo.launch()
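# Usage note (a sketch under assumptions, not from the original script): after
# installing the dependencies listed in the module docstring and placing the
# test_img_edit / test_img_to_txt / test_txt_to_img helpers next to this file,
# run it directly, e.g.
#   python <this_script>.py
# Gradio then serves the demo locally (http://127.0.0.1:7860 by default).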