# Author: 韩宇 (init, commit 1b7e88c)
import json
import re
from pathlib import Path
from typing import List
import json_repair
from omagent_core.models.llms.base import BaseLLMBackend
from omagent_core.models.llms.prompt import PromptTemplate
from omagent_core.tool_system.base import ArgSchema, BaseTool
from omagent_core.utils.logger import logging
from omagent_core.utils.registry import registry
from pydantic import Field
from scenedetect import FrameTimecode
from ...misc.scene import VideoScenes
# Directory containing this module; used to locate the bundled .prompt files.
CURRENT_PATH = Path(__file__).parents[0]

# JSON-schema-style argument description consumed by ArgSchema below.
# All three arguments are required by the tool's calling convention.
ARGSCHEMA = {
    "start_time": {
        "type": "number",
        "description": "Start time (in seconds) of the video to extract frames from.",
        "required": True,
    },
    "end_time": {
        "type": "number",
        "description": "End time (in seconds) of the video to extract frames from.",
        "required": True,
    },
    "number": {
        "type": "number",
        # _run() enforces this cap at runtime (clamps to 10 with a warning).
        "description": "Number of frames of extraction. More frames means more details but more cost. Do not exceed 10.",
        "required": True,
    },
}
@registry.register_tool()
class Rewinder(BaseTool, BaseLLMBackend):
    """Tool that re-extracts frames from the already-loaded video.

    Pulls a window of frames out of the ``VideoScenes`` object stored in
    short-term memory (``stm``), sends them (interleaved with their
    timestamps) to the LLM backend configured by ``prompts``, and returns
    the model's textual description of the frames.
    """

    args_schema: ArgSchema = ArgSchema(**ARGSCHEMA)
    description: str = (
        "Rollback and extract frames from video which is already loaded to get more specific details for further analysis."
    )
    # System/user prompt pair loaded from files next to this module.
    prompts: List[PromptTemplate] = Field(
        default=[
            PromptTemplate.from_file(
                CURRENT_PATH.joinpath("rewinder_sys_prompt.prompt"),
                role="system",
            ),
            PromptTemplate.from_file(
                CURRENT_PATH.joinpath("rewinder_user_prompt.prompt"),
                role="user",
            ),
        ]
    )

    def _run(
        self, start_time: float = 0.0, end_time: float = None, number: int = 1
    ) -> str:
        """Extract up to 10 frames from [start_time, end_time] and describe them.

        Args:
            start_time: Start of the window, in seconds.
            end_time: End of the window, in seconds; ``None`` means the end
                of the video.
            number: How many frames to sample; values above 10 are clamped.

        Returns:
            A string embedding the LLM's (JSON-repaired) description of the
            extracted frames.

        Raises:
            ValueError: If no video has been loaded into short-term memory.
        """
        if self.stm(self.workflow_instance_id).get("video", None) is None:
            raise ValueError("No video is loaded.")
        video: VideoScenes = VideoScenes.from_serializable(
            self.stm(self.workflow_instance_id)["video"]
        )
        if number > 10:
            logging.warning("Number of frames exceeds 10. Will extract 10 frames.")
            number = 10
        start = FrameTimecode(timecode=start_time, fps=video.stream.frame_rate)
        if end_time is None:
            end = video.stream.duration
        else:
            end = FrameTimecode(timecode=end_time, fps=video.stream.frame_rate)
        if start_time == end_time:
            # Zero-length request: widen the window by one frame and sample
            # at one-second steps. NOTE(review): `number` is ignored on this
            # path — presumably intentional, but confirm against callers.
            frames, time_stamps = video.get_video_frames(
                (start, end + 1), video.stream.frame_rate
            )
        else:
            # Evenly space `number` samples across the window. Clamp the
            # step to at least 1 frame: integer division yields 0 when the
            # window holds fewer frames than `number`, which would pass a
            # zero stride to get_video_frames.
            interval = max(
                1, int((end.get_frames() - start.get_frames()) / number)
            )
            frames, time_stamps = video.get_video_frames((start, end), interval)
        # Interleave "timestamp_<t>" markers with the frame images so the
        # model can associate each image with its time position.
        payload = []
        for frame, time_stamp in zip(frames, time_stamps):
            payload.append(f"timestamp_{time_stamp}")
            payload.append(frame)
        res = self.infer(input_list=[{"timestamp_with_images": payload}])[0]["choices"][
            0
        ]["message"]["content"]
        # Model output is expected to be JSON-ish; json_repair tolerates
        # minor formatting errors in the LLM response.
        image_contents = json_repair.loads(res)
        # Reset the image cache so stale frames are not reused downstream.
        self.stm(self.workflow_instance_id)["image_cache"] = {}
        return f"extracted_frames described as: {image_contents}."

    async def _arun(
        self, start_time: float = 0.0, end_time: float = None, number: int = 1
    ) -> str:
        """Async entry point; delegates to the synchronous ``_run``."""
        return self._run(start_time, end_time, number=number)