import spaces
import json
import math
import os
import traceback
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union
import re
import time
from threading import Thread
import uuid
import tempfile

import gradio as gr
import requests
import torch
from PIL import Image
import fitz
import numpy as np
import cv2
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    TextIteratorStreamer,
    AutoTokenizer,
)
from reportlab.lib.pagesizes import A4
from reportlab.lib.styles import getSampleStyleSheet
from reportlab.platypus import SimpleDocTemplate, Image as RLImage, Paragraph, Spacer
from reportlab.lib.units import inch

# --- Constants and Model Setup ---
MAX_INPUT_TOKEN_LENGTH = 4096
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("torch.__version__ =", torch.__version__)
print("torch.version.cuda =", torch.version.cuda)
print("cuda available:", torch.cuda.is_available())
print("cuda device count:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("current device:", torch.cuda.current_device())
    print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
print("Using device:", device)

# --- Model Loading ---
MODEL_ID_M = "Qwen/Qwen2.5-VL-7B-Instruct"
processor_m = AutoProcessor.from_pretrained(MODEL_ID_M, trust_remote_code=True)
model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_M, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()

MODEL_ID_X = "Qwen/Qwen2.5-VL-3B-Instruct"
processor_x = AutoProcessor.from_pretrained(MODEL_ID_X, trust_remote_code=True)
model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_X, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()

MODEL_ID_Q = "prithivMLmods/Qwen2.5-VL-7B-Abliterated-Caption-it"
processor_q = AutoProcessor.from_pretrained(MODEL_ID_Q, trust_remote_code=True)
model_q = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_Q, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()

MODEL_ID_D = "prithivMLmods/DeepCaption-VLA-7B"
processor_d = AutoProcessor.from_pretrained(MODEL_ID_D, trust_remote_code=True)
model_d = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID_D, trust_remote_code=True, torch_dtype=torch.float16
).to(device).eval()


# --- Video and PDF Utility Functions ---
def downsample_video(video_path):
    """
    Downsamples a video to at most 10 evenly spaced frames.
    Each frame is returned as a PIL image.
    """
    try:
        vidcap = cv2.VideoCapture(video_path)
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        frames = []
        # Ensure we don't try to sample more frames than exist
        num_frames_to_sample = min(10, total_frames)
        if num_frames_to_sample == 0:
            vidcap.release()
            return []
        frame_indices = np.linspace(0, total_frames - 1, num_frames_to_sample, dtype=int)
        for i in frame_indices:
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
            success, image = vidcap.read()
            if success:
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                pil_image = Image.fromarray(image)
                frames.append(pil_image)
        vidcap.release()
        return frames
    except Exception as e:
        print(f"Error processing video: {e}")
        return []
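
# Illustrative usage of downsample_video (not executed here; the path below is
# hypothetical). The function returns a list of up to 10 PIL.Image frames, or an
# empty list if the video cannot be read:
#
#   frames = downsample_video("/path/to/clip.mp4")
#   print(len(frames), [f.size for f in frames])
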
def generate_and_preview_pdf(media_input: Union[str, Image.Image], text_content: str, font_size: int,
                             line_spacing: float, alignment: str, image_size: str,
                             state_media_type: str, state_frames: list):
    """
    Generates a PDF from an image or video frames, saves it, and creates image previews.
    Returns the path to the PDF and a list of paths to the preview images.
    """
    if (media_input is None and not state_frames) or not text_content or not text_content.strip():
        raise gr.Error("Cannot generate PDF. Media input or text content is missing.")

    images_to_process = []
    if state_media_type == "video":
        # state_frames are assumed to be numpy arrays; convert back to PIL images
        images_to_process = [Image.fromarray(frame) for frame in state_frames]
    elif isinstance(media_input, Image.Image):
        images_to_process = [media_input]

    if not images_to_process:
        raise gr.Error("No images found to generate PDF.")

    # --- 1. Generate the PDF ---
    temp_dir = tempfile.gettempdir()
    pdf_filename = os.path.join(temp_dir, f"output_{uuid.uuid4()}.pdf")
    doc = SimpleDocTemplate(
        pdf_filename, pagesize=A4,
        rightMargin=inch, leftMargin=inch, topMargin=inch, bottomMargin=inch
    )
    styles = getSampleStyleSheet()
    style_normal = styles["Normal"]
    style_normal.fontSize = int(font_size)
    style_normal.leading = int(font_size) * line_spacing
    style_normal.alignment = {"Left": 0, "Center": 1, "Right": 2, "Justified": 4}[alignment]

    story = []
    page_width, _ = A4
    available_width = page_width - 2 * inch
    image_widths = {
        "Small": available_width * 0.3,
        "Medium": available_width * 0.6,
        "Large": available_width * 0.9,
    }
    img_width = image_widths[image_size]

    for image in images_to_process:
        img_buffer = BytesIO()
        image.save(img_buffer, format='PNG')
        img_buffer.seek(0)
        img = RLImage(img_buffer, width=img_width, height=image.height * (img_width / image.width))
        story.append(img)
        story.append(Spacer(1, 6))  # Add a smaller spacer between frames
    story.append(Spacer(1, 12))

    cleaned_text = re.sub(r'#+\s*', '', text_content).replace("*", "")
    text_paragraphs = cleaned_text.split('\n')
    for para in text_paragraphs:
        if para.strip():
            story.append(Paragraph(para, style_normal))

    doc.build(story)

    # --- 2. Render PDF pages as images for preview ---
    preview_images = []
    try:
        pdf_doc = fitz.open(pdf_filename)
        for page_num in range(len(pdf_doc)):
            page = pdf_doc.load_page(page_num)
            pix = page.get_pixmap(dpi=150)
            preview_img_path = os.path.join(temp_dir, f"preview_{uuid.uuid4()}_p{page_num}.png")
            pix.save(preview_img_path)
            preview_images.append(preview_img_path)
        pdf_doc.close()
    except Exception as e:
        print(f"Error generating PDF preview: {e}")

    return pdf_filename, preview_images
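
# Illustrative call shape for generate_and_preview_pdf (not executed here; all
# argument values are hypothetical). It returns the generated PDF path plus a
# list of per-page preview PNG paths:
#
#   pdf_path, previews = generate_and_preview_pdf(
#       media_input=Image.open("page.png"), text_content="Extracted text...",
#       font_size=12, line_spacing=1.5, alignment="Justified", image_size="Medium",
#       state_media_type="image", state_frames=[],
#   )
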
# --- Core Application Logic ---
@spaces.GPU
def process_document_stream(
    model_name: str,
    media_input: Union[str, Image.Image],
    prompt_input: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float
):
    """
    Main generator function that handles model inference for images or videos.
    Also returns the type of media and extracted frames for state management.
    """
    if media_input is None:
        yield "Please upload an image or video.", "", "none", []
        return
    if not prompt_input or not prompt_input.strip():
        yield "Please enter a prompt.", "", "none", []
        return

    # --- Model Selection ---
    if model_name == "Qwen2.5-VL-7B-Instruct":
        processor, model = processor_m, model_m
    elif model_name == "Qwen2.5-VL-3B-Instruct":
        processor, model = processor_x, model_x
    elif model_name == "Qwen2.5-VL-7B-Abliterated-Caption-it":
        processor, model = processor_q, model_q
    elif model_name == "DeepCaption-VLA-7B":
        processor, model = processor_d, model_d
    else:
        yield "Invalid model selected.", "", "none", []
        return

    media_type = "none"
    saved_frames = []

    # --- Input Processing (Image vs. Video) ---
    if isinstance(media_input, str):  # It's a video file path
        media_type = "video"
        frames = downsample_video(media_input)
        if not frames:
            yield "Could not process video file.", "", "none", []
            return
        # Convert PIL images to numpy arrays for state to avoid serialization issues
        saved_frames = [np.array(f) for f in frames]

        messages = [{"role": "user", "content": [{"type": "text", "text": prompt_input}]}]
        for frame in frames:
            messages[0]["content"].append({"type": "image", "image": frame})

        prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(
            text=[prompt_full], images=frames, return_tensors="pt",
            padding=True, truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH
        ).to(device)
    elif isinstance(media_input, Image.Image):  # It's an image
        media_type = "image"
        messages = [{"role": "user", "content": [
            {"type": "image", "image": media_input},
            {"type": "text", "text": prompt_input},
        ]}]
        prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = processor(
            text=[prompt_full], images=[media_input], return_tensors="pt",
            padding=True, truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH
        ).to(device)
    else:
        yield "Invalid input type.", "", "none", []
        return

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "do_sample": True if temperature > 0 else False,
    }
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        buffer = buffer.replace("<|im_end|>", "")
        time.sleep(0.01)
        yield buffer, buffer, media_type, saved_frames
    yield buffer, buffer, media_type, saved_frames
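
# Illustrative consumption of the streaming generator outside Gradio (not executed
# here; the image path and sampling parameters are hypothetical). Each yield carries
# the partial output twice (presumably for two output components), the media type,
# and any saved video frames:
#
#   for md_text, raw_text, media_type, frames in process_document_stream(
#       "Qwen2.5-VL-3B-Instruct", Image.open("sample.jpg"), "Describe the image.",
#       max_new_tokens=512, temperature=0.7, top_p=0.9, top_k=50, repetition_penalty=1.1,
#   ):
#       pass  # the last yielded md_text / raw_text hold the final result
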
# --- Gradio UI Definition ---
def create_gradio_interface():
    """Builds and returns the Gradio web interface."""
    css = """
    .main-container { max-width: 1400px; margin: 0 auto; }
    .process-button { border: none !important; color: white !important; font-weight: bold !important; background-color: blue !important;}
    .process-button:hover { background-color: darkblue !important; transform: translateY(-2px) !important; box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important; }
    #gallery { min-height: 400px; }
    """
    with gr.Blocks(theme="bethecloud/storj_theme", css=css) as demo:
        # Hidden state variables to store media type and frames
        state_media_type = gr.State("none")
        state_frames = gr.State([])

        gr.HTML("""
        Advanced Vision-Language Models for Image and Video Understanding