Spaces:
Running
Running
| # Copyright (2025) [Seed-VL-Cookbook] Bytedance Seed | |
| import cv2 | |
| import json | |
| import time | |
| import math | |
| import base64 | |
| import requests | |
| import torch | |
| import decord | |
| import numpy as np | |
| from PIL import Image, ImageSequence | |
| from torchvision.io import read_image, encode_jpeg | |
| from torchvision.transforms.functional import resize, pil_to_tensor | |
| from torchvision.transforms import InterpolationMode | |
class ConversationModeI18N:
    """Display names of the conversation modes (English labels)."""

    # Regular answering mode.
    G = "General"
    # Mode in which the model's reasoning is requested as well.
    D = "Deep Thinking"
class ConversationModeCN:
    """Display names of the conversation modes (Chinese labels)."""

    # Regular answering mode.
    G = "常规"
    # Deep-thinking mode.
    D = "深度思考"
def round_by_factor(number: int, factor: int) -> int:
    """Return the multiple of ``factor`` nearest to ``number``.

    The quotient is rounded with the built-in ``round``, so exact halves go
    to the even multiple (e.g. ``0.5 -> 0`` and ``1.5 -> 2``).
    """
    multiples = round(number / factor)
    return factor * multiples
def ceil_by_factor(number: int, factor: int) -> int:
    """Return the smallest multiple of ``factor`` that is >= ``number``."""
    multiples = math.ceil(number / factor)
    return factor * multiples
def floor_by_factor(number: int, factor: int) -> int:
    """Return the largest multiple of ``factor`` that is <= ``number``."""
    multiples = math.floor(number / factor)
    return factor * multiples
def get_resized_hw_for_Navit(
    height: int,
    width: int,
    min_pixels: int,
    max_pixels: int,
    max_ratio: int = 200,
    factor: int = 28,
) -> tuple[int, int]:
    """Compute a target (height, width) whose sides are multiples of ``factor``.

    Both sides are first rounded to the nearest multiple of ``factor`` (never
    below ``factor``).  If the resulting area exceeds ``max_pixels`` the image
    is scaled down (sides floored to a multiple of ``factor``); if it is below
    ``min_pixels`` it is scaled up (sides ceiled to a multiple of ``factor``).
    The scale is derived from the original ``height * width`` so the aspect
    ratio is approximately preserved.

    Args:
        height: Original image height in pixels.
        width: Original image width in pixels.
        min_pixels: Lower bound on the target area.
        max_pixels: Upper bound on the target area.
        max_ratio: Largest accepted long-side / short-side ratio.
        factor: Every returned side is a multiple of this value.

    Returns:
        ``(resized_height, resized_width)`` as plain ints, each >= ``factor``.

    Raises:
        ValueError: If the aspect ratio exceeds ``max_ratio``.
    """
    aspect_ratio = max(height, width) / min(height, width)
    if aspect_ratio > max_ratio:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {max_ratio}, got {aspect_ratio}"
        )

    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))

    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        # Flooring a very thin side can yield 0, which is not a valid image
        # size; keep at least one `factor` per side even if that slightly
        # overshoots `max_pixels`.
        h_bar = max(factor, floor_by_factor(height / beta, factor))
        w_bar = max(factor, floor_by_factor(width / beta, factor))
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)

    return int(h_bar), int(w_bar)
class SeedVLInfer:
    """Client that turns text/image/video inputs into chat messages and
    streams completions from an HTTP chat-completions endpoint.

    Images and video frames are resized with ``get_resized_hw_for_Navit``,
    JPEG-encoded, and embedded in the request as base64 data URLs.
    """

    # Per-frame pixel budgets tried in order when sampling a video.
    _DEFAULT_MAX_PIXELS_CHOICES = (
        640 * 28 * 28, 512 * 28 * 28, 384 * 28 * 28, 256 * 28 * 28,
        160 * 28 * 28, 128 * 28 * 28,
    )
    # Number of connection attempts made by `request` before giving up.
    _MAX_REQUEST_ATTEMPTS = 10

    def __init__(
        self,
        model_id: str,
        api_key: str,
        base_url: str = 'https://ark.cn-beijing.volces.com/api/v3/chat/completions',
        min_pixels: int = 4 * 28 * 28,
        max_pixels: int = 5120 * 28 * 28,
        video_sampling_strategy: dict = None,
    ):
        """Store endpoint credentials and preprocessing limits.

        Args:
            model_id: Value sent as ``model`` in the request payload.
            api_key: Bearer token sent in the ``Authorization`` header.
            base_url: URL that requests are POSTed to.
            min_pixels: Lower area bound passed to ``get_resized_hw_for_Navit``.
            max_pixels: Pixel budget used for video frames when
                ``max_pixels_choices`` is empty.
            video_sampling_strategy: Optional overrides for the keys
                ``sampling_fps`` (default 1), ``min_n_frames`` (16),
                ``max_video_length`` (81920), ``max_pixels_choices`` and
                ``use_timestamp`` (True).  ``None`` selects all defaults.
        """
        self.base_url = base_url
        self.api_key = api_key
        self.model_id = model_id
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels

        # `None` instead of a dict literal avoids a shared mutable default.
        strategy = video_sampling_strategy or {}
        self.sampling_fps = strategy.get('sampling_fps', 1)
        self.min_n_frames = strategy.get('min_n_frames', 16)
        self.max_video_length = strategy.get('max_video_length', 81920)
        self.max_pixels_choices = strategy.get(
            'max_pixels_choices', list(self._DEFAULT_MAX_PIXELS_CHOICES))
        self.use_timestamp = strategy.get('use_timestamp', True)

    def preprocess_video(self, video_path: str):
        """Sample frames from a video and resize them.

        The file is opened with decord; if decord raises ``DECORDError`` the
        file is read as a multi-frame PIL image instead, with fps taken as 1.

        Args:
            video_path: Path of the video (or animated image) file.

        Returns:
            If ``use_timestamp`` is true, a list of ``(seconds, frame)``
            tuples where ``seconds`` is ``frame_index / fps`` rounded to one
            decimal; otherwise the resized frame tensor itself.

        Raises:
            ValueError: If the file contains no frames.
        """
        try:
            video_reader = decord.VideoReader(video_path, num_threads=2)
            fps = video_reader.get_avg_fps()
        except decord._ffi.base.DECORDError:
            with Image.open(video_path) as animation:
                video_reader = [
                    frame.convert('RGB')
                    for frame in ImageSequence.Iterator(animation)
                ]
            fps = 1

        length = len(video_reader)
        if length == 0:
            raise ValueError(f"no frames could be read from {video_path!r}")

        # Sample at `sampling_fps`, but take at least `min_n_frames` and never
        # more frames than the file has.
        n_frames = min(
            max(math.ceil(length / fps * self.sampling_fps),
                self.min_n_frames), length)
        frame_indices = np.linspace(0, length - 1,
                                    n_frames).round().astype(int).tolist()

        # Pick the first per-frame pixel budget for which
        # n_frames * max_pixels / 28 / 28 fits into `max_video_length`
        # (presumably a token estimate -- TODO confirm).  If even the last
        # budget does not fit, keep it and drop frames uniformly instead.
        max_pixels = self.max_pixels
        last_round = len(self.max_pixels_choices) - 1
        for round_idx, max_pixels in enumerate(self.max_pixels_choices):
            if len(frame_indices) * max_pixels / 28 / 28 <= self.max_video_length:
                break
            if round_idx == last_round:
                max_frame_num = int(self.max_video_length / max_pixels *
                                    28 * 28)
                select_ids = np.linspace(
                    0, len(frame_indices) - 1,
                    max_frame_num).round().astype(int).tolist()
                frame_indices = [
                    frame_indices[select_id] for select_id in select_ids
                ]

        if hasattr(video_reader, "get_batch"):
            frames_nhwc = video_reader.get_batch(frame_indices).asnumpy()
        else:
            # The fallback holds PIL images; stack them with numpy
            # (torch.stack does not accept numpy arrays).
            frames_nhwc = np.stack(
                [np.array(video_reader[i]) for i in frame_indices], axis=0)
        # (N, H, W, C) -> (N, C, H, W)
        video_clip = torch.from_numpy(frames_nhwc).permute(0, 3, 1, 2)

        height, width = video_clip.shape[-2:]
        resized_height, resized_width = get_resized_hw_for_Navit(
            height,
            width,
            min_pixels=self.min_pixels,
            max_pixels=max_pixels,
        )
        resized_video_clip = resize(video_clip,
                                    (resized_height, resized_width),
                                    interpolation=InterpolationMode.BICUBIC,
                                    antialias=True)
        if self.use_timestamp:
            resized_video_clip = [
                (round(i / fps, 1), f)
                for i, f in zip(frame_indices, resized_video_clip)
            ]
        return resized_video_clip

    def preprocess_streaming_frame(self, frame: torch.Tensor):
        """Resize a single frame using the first entry of ``max_pixels_choices``.

        Args:
            frame: Image tensor whose last two dimensions are height and width.

        Returns:
            The resized frame, without the temporary leading batch dimension.
        """
        height, width = frame.shape[-2:]
        resized_height, resized_width = get_resized_hw_for_Navit(
            height,
            width,
            min_pixels=self.min_pixels,
            max_pixels=self.max_pixels_choices[0],
        )
        batched = resize(frame[None], (resized_height, resized_width),
                         interpolation=InterpolationMode.BICUBIC,
                         antialias=True)
        return batched[0]

    def encode_image(self, image: torch.Tensor) -> str:
        """JPEG-encode an image tensor and return it as a base64 string.

        A tensor with 4 entries along dimension 0 keeps only its first three
        (the fourth is presumably an alpha channel -- TODO confirm).
        """
        if image.shape[0] == 4:
            image = image[:3]
        encoded = encode_jpeg(image)
        return base64.b64encode(encoded.numpy()).decode('utf-8')

    @staticmethod
    def _load_image(path: str) -> torch.Tensor:
        """Read an image as a channel-first RGB tensor, trying torchvision,
        then PIL, then OpenCV."""
        try:
            return read_image(path, "RGB")
        except Exception:
            # torchvision could not decode the file; fall through to PIL.
            try:
                with Image.open(path) as pil_image:
                    return pil_to_tensor(pil_image.convert('RGB'))
            except Exception:
                # Last resort: OpenCV (BGR, HWC) converted to RGB, CHW.
                rgb = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
                return torch.from_numpy(rgb).permute(2, 0, 1)

    def _image_part(self, image: torch.Tensor) -> dict:
        """Build an ``image_url`` content entry holding ``image`` as a JPEG data URL."""
        return {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/jpeg;base64,{self.encode_image(image)}",
                "detail": "high"
            },
        }

    def construct_messages(self,
                           inputs: dict,
                           streaming_timestamp: int = None,
                           online: bool = False) -> list[dict]:
        """Build a single user message from files and an optional text query.

        Paths ending in ``.mp4`` are expanded into sampled frames (each
        preceded by a ``[<t> second]`` text entry when ``use_timestamp`` is
        set); every other path is loaded as one image.

        Args:
            inputs: Dict with optional keys ``files`` (list of paths) and
                ``text`` (query string).
            streaming_timestamp: If not ``None``, images are resized as
                streaming frames and preceded by a ``[<t> second]`` entry.
            online: When true, each ``.webp`` file sets
                ``streaming_timestamp`` to its index in ``files``; the value
                then also applies to the files that follow it.

        Returns:
            A one-element list ``[{"role": "user", "content": [...]}]``.
        """
        content = []
        for i, path in enumerate(inputs.get('files', [])):
            if path.endswith('.mp4'):
                for frame in self.preprocess_video(video_path=path):
                    if self.use_timestamp:
                        timestamp, frame = frame
                        content.append({
                            "type": "text",
                            "text": f'[{timestamp} second]',
                        })
                    content.append(self._image_part(frame))
                continue

            image = self._load_image(path)
            if online and path.endswith('.webp'):
                streaming_timestamp = i
            if streaming_timestamp is not None:
                image = self.preprocess_streaming_frame(frame=image)
                # The timestamp label goes immediately before its image.
                content.append({
                    "type": "text",
                    "text": f'[{streaming_timestamp} second]',
                })
            content.append(self._image_part(image))

        query = inputs.get('text', '')
        if query:
            content.append({
                "type": "text",
                "text": query,
            })
        return [{
            "role": "user",
            "content": content,
        }]

    def request(self,
                messages,
                thinking: bool = True,
                temperature: float = 1.0):
        """POST ``messages`` to the endpoint and stream the accumulated reply.

        Args:
            messages: Chat messages sent as the ``messages`` payload field.
            thinking: Sends ``thinking.type`` as ``enabled`` or ``disabled``.
            temperature: Sampling temperature sent in the payload.

        Yields:
            ``(content, reasoning_content, finished)`` where both strings are
            the text accumulated so far; ``finished`` is true only for the
            final item, emitted when ``[DONE]`` is received.

        Raises:
            RuntimeError: If no attempt could reach the endpoint.
            requests.exceptions.HTTPError: If the endpoint answers with an
                error status.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        payload = {
            "model": self.model_id,
            "messages": messages,
            "stream": True,
            "thinking": {
                "type": "enabled" if thinking else "disabled",
            },
            "temperature": temperature,
        }

        response = None
        last_error = None
        for _ in range(self._MAX_REQUEST_ATTEMPTS):
            try:
                response = requests.post(self.base_url,
                                         headers=headers,
                                         json=payload,
                                         stream=True,
                                         timeout=600)
                break
            except requests.exceptions.RequestException as e:
                last_error = e
                time.sleep(0.1)
                print(e)
        if response is None:
            raise RuntimeError(
                f"request to {self.base_url} failed after "
                f"{self._MAX_REQUEST_ATTEMPTS} attempts") from last_error

        prefix = b'data:'
        content, reasoning_content = '', ''
        # The context manager releases the streamed connection even when the
        # consumer stops iterating early or parsing fails.
        with response:
            response.raise_for_status()
            for line in response.iter_lines():
                if not line or not line.startswith(prefix):
                    continue
                data = line[len(prefix):].strip()
                if data == b"[DONE]":
                    yield content, reasoning_content, True
                    break
                choices = json.loads(data).get('choices') or []
                if not choices:
                    # Chunk without any choice: nothing to append.
                    continue
                delta = choices[0].get('delta') or {}
                # A field may be absent or null in a chunk; treat it as empty.
                content += delta.get('content') or ''
                reasoning_content += delta.get('reasoning_content') or ''
                yield content, reasoning_content, False

    def __call__(self,
                 inputs: dict,
                 history: list[dict] = None,
                 mode: str = ConversationModeI18N.D,
                 temperature: float = 1.0,
                 online: bool = False):
        """Run one conversation turn and stream the growing assistant reply.

        Args:
            inputs: Passed to ``construct_messages`` (``files`` / ``text``).
            history: Earlier messages; not modified.  ``None`` means empty.
            mode: Thinking is enabled only when this equals
                ``ConversationModeI18N.D``.
            temperature: Sampling temperature forwarded to ``request``.
            online: Forwarded to ``construct_messages``.

        Yields:
            ``(response, history_with_reply, finished)``.  In deep-thinking
            mode ``response`` is prefixed with the reasoning wrapped in
            ``<think>...</think>``.
        """
        messages = self.construct_messages(inputs=inputs, online=online)
        updated_history = (history or []) + messages
        # NOTE(review): only the I18N label is recognised here; the value of
        # ConversationModeCN.D does not enable thinking -- confirm intended.
        deep_thinking = mode == ConversationModeI18N.D
        for response, reasoning, finished in self.request(
                messages=updated_history,
                thinking=deep_thinking,
                temperature=temperature):
            if deep_thinking:
                response = '<think>' + reasoning + '</think>' + response
            yield response, updated_history + [{'role': 'assistant', 'content': response}], finished