import spaces, ffmpeg, os, sys, torch, time import gradio as gr from transformers import ( Qwen2_5_VLForConditionalGeneration, AutoModelForImageTextToText, Gemma3nForConditionalGeneration, AutoProcessor, BitsAndBytesConfig, ) from qwen_vl_utils import process_vision_info from loguru import logger logger.remove() logger.add( sys.stderr, format="{time:YYYY-MM-DD ddd HH:mm:ss} | {level} | {message}", ) # --- Installing Flash Attention for ZeroGPU is special --- # import subprocess subprocess.run( "pip install flash-attn --no-build-isolation", env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}, shell=True, ) # --- now we got Flash Attention ---# # Set target DEVICE and DTYPE # For maximum memory efficiency, use bfloat16 if your GPU supports it, otherwise float16. DTYPE = ( torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16 ) # DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # Use "auto" to let accelerate handle device placement (GPU, CPU, disk) DEVICE = "auto" logger.info(f"Device: {DEVICE}, dtype: {DTYPE}") def get_fps_ffmpeg(video_path: str): probe = ffmpeg.probe(video_path) # Find the first video stream video_stream = next( (stream for stream in probe["streams"] if stream["codec_type"] == "video"), None ) if video_stream is None: raise ValueError("No video stream found") # Frame rate is given as a string fraction, e.g., '30000/1001' r_frame_rate = video_stream["r_frame_rate"] num, denom = map(int, r_frame_rate.split("/")) return num / denom def load_model( model_name: str = "chancharikm/qwen2.5-vl-7b-cam-motion-preview", use_flash_attention: bool = True, apply_quantization: bool = True, ): # We recommend enabling flash_attention_2 for better acceleration and memory saving, # especially in multi-image and video scenarios. bnb_config = BitsAndBytesConfig( load_in_4bit=True, # Load model weights in 4-bit bnb_4bit_quant_type="nf4", # Use NF4 quantization (or "fp4") bnb_4bit_compute_dtype=DTYPE, # Perform computations in bfloat16/float16 bnb_4bit_use_double_quant=True, # Optional: further quantization for slightly more memory saving ) # Determine model family from model name model_family = model_name.split("/")[-1].split("-")[ 0 ] # Extract model family from name # Common model loading arguments common_args = { "torch_dtype": DTYPE, "device_map": DEVICE, "low_cpu_mem_usage": True, "quantization_config": bnb_config if apply_quantization else None, } # Add flash attention if supported and requested if use_flash_attention: common_args["attn_implementation"] = "flash_attention_2" # Load model based on family match model_family: case "qwen2.5" | "Qwen2.5": model = Qwen2_5_VLForConditionalGeneration.from_pretrained( model_name, **common_args ) case "InternVL3": model = AutoModelForImageTextToText.from_pretrained( model_name, **common_args ) case "gemma": model = Gemma3nForConditionalGeneration.from_pretrained( model_name, **common_args ) case _: raise ValueError(f"Unsupported model family: {model_family}") # Set model to evaluation mode for inference (disables dropout, etc.) return model.eval() def load_processor(model_name="Qwen/Qwen2.5-VL-7B-Instruct"): return AutoProcessor.from_pretrained( model_name, device_map=DEVICE, use_fast=True, torch_dtype=DTYPE, ) logger.debug("Loading Models and Processors...") MODEL_ZOO = { "qwen2.5-vl-7b-cam-motion-preview": load_model( model_name="chancharikm/qwen2.5-vl-7b-cam-motion-preview", use_flash_attention=False, apply_quantization=False, ), "qwen2.5-vl-7b-instruct": load_model( model_name="Qwen/Qwen2.5-VL-7B-Instruct", use_flash_attention=False, apply_quantization=False, ), "qwen2.5-vl-3b-instruct": load_model( model_name="Qwen/Qwen2.5-VL-3B-Instruct", use_flash_attention=False, apply_quantization=False, ), "InternVL3-1B-hf": load_model( model_name="OpenGVLab/InternVL3-1B-hf", use_flash_attention=False, apply_quantization=False, ), "InternVL3-2B-hf": load_model( model_name="OpenGVLab/InternVL3-2B-hf", use_flash_attention=False, apply_quantization=False, ), "InternVL3-8B-hf": load_model( model_name="OpenGVLab/InternVL3-8B-hf", use_flash_attention=False, apply_quantization=True, ), "gemma-3n-e4b-it": load_model( model_name="google/gemma-3n-e4b-it", use_flash_attention=False, apply_quantization=True, ), } PROCESSORS = { "qwen2.5-vl-7b-cam-motion-preview": load_processor("Qwen/Qwen2.5-VL-7B-Instruct"), "qwen2.5-vl-7b-instruct": load_processor("Qwen/Qwen2.5-VL-7B-Instruct"), "qwen2.5-vl-3b-instruct": load_processor("Qwen/Qwen2.5-VL-3B-Instruct"), "InternVL3-1B-hf": load_processor("OpenGVLab/InternVL3-1B-hf"), "InternVL3-2B-hf": load_processor("OpenGVLab/InternVL3-2B-hf"), "InternVL3-8B-hf": load_processor("OpenGVLab/InternVL3-8B-hf"), "gemma-3n-e4b-it": load_processor("google/gemma-3n-e4b-it"), } logger.debug("Models and Processors Loaded!") @spaces.GPU(duration=120) def inference( video_path: str, prompt: str = "Describe the camera motion in this video.", model_name: str = "qwen2.5-vl-7b-instruct", custom_fps: int = 8, max_tokens: int = 256, temperature: float = 0.0, ): s_time = time.time() # default processor # processor, model = PROCESSOR, MODEL # processor = load_processor() # model = load_model( # use_flash_attention=use_flash_attention, apply_quantization=apply_quantization # ) model = MODEL_ZOO[model_name] processor = PROCESSORS[model_name] # The model is trained on 8.0 FPS which we recommend for optimal inference fps = custom_fps if custom_fps else get_fps_ffmpeg(video_path) logger.info(f"{os.path.basename(video_path)} FPS: {fps}") messages = [ { "role": "user", "content": [ { "type": "video", "video": video_path, "fps": fps, }, {"type": "text", "text": prompt}, ], } ] # text = processor.apply_chat_template( # messages, tokenize=False, add_generation_prompt=True # ) # image_inputs, video_inputs, video_kwargs = process_vision_info( # messages, return_video_kwargs=True # ) # This prevents PyTorch from building the computation graph for gradients, # saving a significant amount of memory for intermediate activations. with torch.no_grad(): model_family = model_name.split("-")[0] match model_family: case "qwen2.5": text = processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) image_inputs, video_inputs, video_kwargs = process_vision_info( messages, return_video_kwargs=True ) inputs = processor( text=[text], images=image_inputs, videos=video_inputs, # fps=fps, padding=True, return_tensors="pt", **video_kwargs, ) inputs = inputs.to("cuda") # Inference generated_ids = model.generate( **inputs, max_new_tokens=max_tokens, temperature=float(temperature), do_sample=temperature > 0.0, ) generated_ids_trimmed = [ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) ] output_text = processor.batch_decode( generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False, )[0] case "InternVL3" | "gemma": inputs = processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", fps=fps, # num_frames = 8 ).to("cuda", dtype=DTYPE) output = model.generate( **inputs, max_new_tokens=max_tokens, temperature=float(temperature), do_sample=temperature > 0.0, ) output_text = processor.decode( output[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True ) case _: raise ValueError(f"{model_name} is not currently supported") return { "output_text": output_text, "fps": fps, "inference_time": time.time() - s_time, } demo = gr.Interface( fn=inference, inputs=[ gr.Video(label="Input Video"), gr.Textbox( label="Prompt", lines=3, info="Some models like [cam motion](https://huggingface.co/chancharikm/qwen2.5-vl-7b-cam-motion-preview) are trained specific prompts", value="Describe the camera motion in this video.", ), gr.Dropdown(label="Model", choices=list(MODEL_ZOO.keys())), gr.Number( label="FPS", info="inference sampling rate (Qwen2.5VL is trained on videos with 8 fps); a value of 0 means the FPS of the input video will be used", value=8, minimum=0, step=1, ), gr.Slider( label="Max Tokens", info="maximum number of tokens to generate", value=128, minimum=32, maximum=512, step=32, ), gr.Slider( label="Temperature", value=0.0, minimum=0.0, maximum=1.0, step=0.1, ), # gr.Checkbox(label="Use Flash Attention", value=False), # gr.Checkbox(label="Apply Quantization", value=True), ], outputs=gr.JSON(label="Output JSON"), title="Video Captioning with VLM", description='comparing various "small" VLMs on the task of video captioning', api_name="video_inference", ) demo.launch( mcp_server=True, app_kwargs={"docs_url": "/docs"} # add FastAPI Swagger API Docs )