import base64
import json
import math
import time

import decord
import numpy as np
import requests
import torch
from PIL import Image, ImageSequence
from torchvision.io import read_image, encode_jpeg
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
|
|
|
|
|
class ConversationModeI18N:
    G = "General"
    D = "Deep Thinking"


class ConversationModeCN:
    G = "常规"  # "General"
    D = "深度思考"  # "Deep Thinking"
|
|
|
|
|
def round_by_factor(number: int, factor: int) -> int:
    """Returns the closest integer to 'number' that is divisible by 'factor'."""
    return round(number / factor) * factor


def ceil_by_factor(number: int, factor: int) -> int:
    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
    return math.ceil(number / factor) * factor


def floor_by_factor(number: int, factor: int) -> int:
    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
    return math.floor(number / factor) * factor
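# Quick examples with factor = 28: round_by_factor(57, 28) -> 56,
# ceil_by_factor(57, 28) -> 84, floor_by_factor(57, 28) -> 56.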
|
|
|
|
|
def get_resized_hw_for_Navit(
    height: int,
    width: int,
    min_pixels: int,
    max_pixels: int,
    max_ratio: int = 200,
    factor: int = 28,
):
    if max(height, width) / min(height, width) > max_ratio:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {max_ratio}, got {max(height, width) / min(height, width)}"
        )
    # Snap both sides to multiples of `factor`, then scale the area into
    # [min_pixels, max_pixels] while preserving the aspect ratio.
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return int(h_bar), int(w_bar)
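# Worked example: a 1080x1920 frame with max_pixels = 1280 * 28 * 28 first
# snaps to 1092x1932, which exceeds the budget, so beta ~= 1.44 shrinks it
# to (728, 1316) -- both multiples of 28 and within max_pixels.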
|
|
|
|
|
class SeedVLInfer:

    def __init__(
        self,
        model_id: str,
        api_key: str,
        base_url: str = 'https://ark.cn-beijing.volces.com/api/v3/chat/completions',
        min_pixels: int = 4 * 28 * 28,
        max_pixels: int = 5120 * 28 * 28,
        video_sampling_strategy: dict = None,
    ):
        # Use None instead of a mutable default argument; the .get() fallbacks
        # below carry the default sampling strategy.
        video_sampling_strategy = video_sampling_strategy or {}
        self.base_url = base_url
        self.api_key = api_key
        self.model_id = model_id
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.sampling_fps = video_sampling_strategy.get('sampling_fps', 1)
        self.min_n_frames = video_sampling_strategy.get('min_n_frames', 16)
        self.max_video_length = video_sampling_strategy.get(
            'max_video_length', 81920)
        self.max_pixels_choices = video_sampling_strategy.get(
            'max_pixels_choices', [
                640 * 28 * 28, 512 * 28 * 28, 384 * 28 * 28, 256 * 28 * 28,
                160 * 28 * 28, 128 * 28 * 28
            ])
        self.use_timestamp = video_sampling_strategy.get('use_timestamp', True)
|
|
|
    def preprocess_video(self, video_path: str):
        try:
            video_reader = decord.VideoReader(video_path, num_threads=2)
            fps = video_reader.get_avg_fps()
        except decord._ffi.base.DECORDError:
            # Fall back to PIL for containers decord cannot open (e.g. GIF).
            video_reader = [
                frame.convert('RGB')
                for frame in ImageSequence.Iterator(Image.open(video_path))
            ]
            fps = 1

        length = len(video_reader)
        n_frames = min(
            max(math.ceil(length / fps * self.sampling_fps),
                self.min_n_frames), length)
        frame_indices = np.linspace(0, length - 1,
                                    n_frames).round().astype(int).tolist()
        # Walk down the per-frame pixel budgets until the clip fits within
        # max_video_length tokens; on the last budget, drop frames instead.
        max_pixels = self.max_pixels
        for round_idx, max_pixels in enumerate(self.max_pixels_choices):
            is_last_round = round_idx == len(self.max_pixels_choices) - 1
            if len(frame_indices) * max_pixels / 28 / 28 > self.max_video_length:
                if is_last_round:
                    max_frame_num = int(self.max_video_length / max_pixels *
                                        28 * 28)
                    select_ids = np.linspace(
                        0,
                        len(frame_indices) - 1,
                        max_frame_num).round().astype(int).tolist()
                    frame_indices = [
                        frame_indices[select_id] for select_id in select_ids
                    ]
            else:
                break

        if hasattr(video_reader, "get_batch"):
            video_clip = torch.from_numpy(
                video_reader.get_batch(frame_indices).asnumpy()).permute(
                    0, 3, 1, 2)
        else:
            # PIL frames: stack into a numpy array first, then convert;
            # torch.stack cannot consume numpy arrays directly.
            video_clip_array = np.stack(
                [np.array(video_reader[i]) for i in frame_indices], axis=0)
            video_clip = torch.from_numpy(video_clip_array).permute(0, 3, 1, 2)

        height, width = video_clip.shape[-2:]
        resized_height, resized_width = get_resized_hw_for_Navit(
            height,
            width,
            min_pixels=self.min_pixels,
            max_pixels=max_pixels,
        )
        resized_video_clip = resize(video_clip,
                                    (resized_height, resized_width),
                                    interpolation=InterpolationMode.BICUBIC,
                                    antialias=True)
        if self.use_timestamp:
            resized_video_clip = [
                (round(i / fps, 1), f)
                for i, f in zip(frame_indices, resized_video_clip)
            ]
        return resized_video_clip
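    # Example budget check for preprocess_video: a 10 s, 30 fps clip samples
    # max(ceil(10 * 1), 16) = 16 frames; 16 frames * 640 tokens each = 10,240
    # <= 81,920, so the first (largest) entry in max_pixels_choices is kept.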
|
|
|
    def preprocess_streaming_frame(self, frame: torch.Tensor):
        height, width = frame.shape[-2:]
        resized_height, resized_width = get_resized_hw_for_Navit(
            height,
            width,
            min_pixels=self.min_pixels,
            max_pixels=self.max_pixels_choices[0],
        )
        resized_frame = resize(frame[None], (resized_height, resized_width),
                               interpolation=InterpolationMode.BICUBIC,
                               antialias=True)[0]
        return resized_frame
|
|
|
    def encode_image(self, image: torch.Tensor) -> str:
        # Drop the alpha channel if present; encode_jpeg expects 1 or 3
        # channels.
        if image.shape[0] == 4:
            image = image[:3]
        encoded = encode_jpeg(image)
        return base64.b64encode(encoded.numpy()).decode('utf-8')
|
|
|
    def construct_messages(self,
                           inputs: dict,
                           streaming_timestamp: int = None,
                           online: bool = False) -> list[dict]:
        content = []
        for i, path in enumerate(inputs.get('files', [])):
            if path.endswith('.mp4'):
                video = self.preprocess_video(video_path=path)
                for frame in video:
                    if self.use_timestamp:
                        timestamp, frame = frame
                        content.append({
                            "type": "text",
                            "text": f'[{timestamp} second]',
                        })
                    content.append({
                        "type": "image_url",
                        "image_url": {
                            "url":
                            f"data:image/jpeg;base64,{self.encode_image(frame)}",
                            "detail": "high"
                        },
                    })
            else:
                image = read_image(path)
                if online and path.endswith('.webp'):
                    streaming_timestamp = i
                if streaming_timestamp is not None:
                    image = self.preprocess_streaming_frame(frame=image)
                content.append({
                    "type": "image_url",
                    "image_url": {
                        "url":
                        f"data:image/jpeg;base64,{self.encode_image(image)}",
                        "detail": "high"
                    },
                })
                if streaming_timestamp is not None:
                    content.insert(-1, {
                        "type": "text",
                        "text": f'[{streaming_timestamp} second]',
                    })
        query = inputs.get('text', '')
        if query:
            content.append({
                "type": "text",
                "text": query,
            })
        messages = [{
            "role": "user",
            "content": content,
        }]
        return messages
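    # Example: construct_messages({'files': ['cat.png'], 'text': 'Describe it.'})
    # (hypothetical inputs) returns a single user turn of the form
    #   [{'role': 'user', 'content': [
    #       {'type': 'image_url', 'image_url': {'url': 'data:image/jpeg;base64,...',
    #                                           'detail': 'high'}},
    #       {'type': 'text', 'text': 'Describe it.'}]}]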
|
|
|
    def request(self,
                messages,
                thinking: bool = True,
                temperature: float = 1.0):
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        payload = {
            "model": self.model_id,
            "messages": messages,
            "stream": True,
            "thinking": {
                "type": "enabled" if thinking else "disabled",
            },
            "temperature": temperature,
        }
        requested = None
        for _ in range(10):
            try:
                requested = requests.post(self.base_url,
                                          headers=headers,
                                          json=payload,
                                          stream=True,
                                          timeout=600)
                break
            except Exception as e:
                time.sleep(0.1)
                print(e)
        if requested is None:
            raise RuntimeError('request failed after 10 retries')
        content, reasoning_content = '', ''
        for line in requested.iter_lines():
            if not line:
                continue
            if line.startswith(b'data:'):
                data = line[len(b'data:'):].strip()
                if data == b"[DONE]":
                    yield content, reasoning_content, True
                    break
                delta = json.loads(data)['choices'][0]['delta']
                # The delta may carry only content or only reasoning_content.
                content += delta.get('content') or ''
                reasoning_content += delta.get('reasoning_content') or ''
                yield content, reasoning_content, False
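    # request() parses OpenAI-style server-sent events; each payload line
    # looks like
    #   data: {"choices": [{"delta": {"content": "...", "reasoning_content": "..."}}]}
    # and the stream terminates with `data: [DONE]`.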
|
|
|
    def __call__(self,
                 inputs: dict,
                 history: list[dict] = None,
                 mode: str = ConversationModeI18N.D,
                 temperature: float = 1.0,
                 online: bool = False):
        history = history or []
        messages = self.construct_messages(inputs=inputs, online=online)
        updated_history = history + messages
        for response, reasoning, finished in self.request(
                messages=updated_history,
                thinking=mode == ConversationModeI18N.D,
                temperature=temperature):
            if mode == ConversationModeI18N.D:
                response = '<think>' + reasoning + '</think>' + response
            yield response, updated_history + [{
                'role': 'assistant',
                'content': response
            }], finished
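

# Minimal usage sketch (hypothetical endpoint id, API key, and file path;
# real values come from the Volcano Engine Ark console):
if __name__ == '__main__':
    infer = SeedVLInfer(model_id='<your-endpoint-id>',
                        api_key='<your-api-key>')
    inputs = {'files': ['example.mp4'], 'text': 'Describe the video.'}
    for response, history, finished in infer(inputs,
                                             mode=ConversationModeI18N.G):
        if finished:
            print(response)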