import json
import re
from pathlib import Path
from typing import List, Optional

import json_repair
from omagent_core.models.llms.base import BaseLLMBackend
from omagent_core.models.llms.prompt import PromptTemplate
from omagent_core.tool_system.base import ArgSchema, BaseTool
from omagent_core.utils.logger import logging
from omagent_core.utils.registry import registry
from pydantic import Field
from scenedetect import FrameTimecode

from ...misc.scene import VideoScenes

CURRENT_PATH = Path(__file__).parents[0]

ARGSCHEMA = {
    "start_time": {
        "type": "number",
        "description": "Start time (in seconds) of the video to extract frames from.",
        "required": True,
    },
    "end_time": {
        "type": "number",
        "description": "End time (in seconds) of the video to extract frames from.",
        "required": True,
    },
"number": { | |
"type": "number", | |
"description": "Number of frames of extraction. More frames means more details but more cost. Do not exceed 10.", | |
"required": True, | |
}, | |
} | |


@registry.register_tool()
class Rewinder(BaseTool, BaseLLMBackend):
    """Tool that re-extracts frames from the already-loaded video for closer inspection."""

    args_schema: ArgSchema = ArgSchema(**ARGSCHEMA)
    description: str = (
        "Roll back and extract frames from the already-loaded video to get more specific details for further analysis."
    )
    prompts: List[PromptTemplate] = Field(
        default=[
            PromptTemplate.from_file(
                CURRENT_PATH.joinpath("rewinder_sys_prompt.prompt"),
                role="system",
            ),
            PromptTemplate.from_file(
                CURRENT_PATH.joinpath("rewinder_user_prompt.prompt"),
                role="user",
            ),
        ]
    )

    def _run(
        self, start_time: float = 0.0, end_time: Optional[float] = None, number: int = 1
    ) -> str:
        """Extract up to 10 frames between start_time and end_time and describe them with the LLM."""
        if self.stm(self.workflow_instance_id).get("video", None) is None:
            raise ValueError("No video is loaded.")
        video: VideoScenes = VideoScenes.from_serializable(
            self.stm(self.workflow_instance_id)["video"]
        )
        if number > 10:
            logging.warning("Number of frames exceeds 10. Will extract 10 frames.")
            number = 10
        start = FrameTimecode(timecode=start_time, fps=video.stream.frame_rate)
        if end_time is None:
            end = video.stream.duration
        else:
            end = FrameTimecode(timecode=end_time, fps=video.stream.frame_rate)
        if start_time == end_time:
            # Degenerate range: sample a single frame at the requested timestamp.
            frames, time_stamps = video.get_video_frames(
                (start, end + 1), video.stream.frame_rate
            )
        else:
            # Sample `number` frames evenly spaced across the requested range.
            interval = int((end.get_frames() - start.get_frames()) / number)
            frames, time_stamps = video.get_video_frames((start, end), interval)
        # self.stm.image_cache.clear()

        # Interleave timestamps and frames so the LLM can tie each image to its time.
        payload = []
        for frame, time_stamp in zip(frames, time_stamps):
            payload.append(f"timestamp_{time_stamp}")
            payload.append(frame)
        res = self.infer(input_list=[{"timestamp_with_images": payload}])[0]["choices"][
            0
        ]["message"]["content"]
        image_contents = json_repair.loads(res)
        self.stm(self.workflow_instance_id)["image_cache"] = {}
        return f"extracted_frames described as: {image_contents}."

    async def _arun(
        self, start_time: float = 0.0, end_time: Optional[float] = None, number: int = 1
    ) -> str:
        """Async entry point; delegates to the synchronous _run."""
        return self._run(start_time, end_time, number=number)
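

# Illustrative usage sketch (not part of the module logic): this assumes a prior
# video-loading step has stored a serialized VideoScenes object under the "video"
# key of the workflow's short-term memory and that the tool's LLM backend is
# configured. The workflow id below is hypothetical.
#
#   rewinder = Rewinder()
#   rewinder.workflow_instance_id = "demo_workflow"
#   summary = rewinder._run(start_time=12.0, end_time=20.0, number=5)
#   print(summary)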