#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 Imperial College London (Pingchuan Ma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import torch
import whisper
import torchaudio
import torchvision
from .transforms import AudioTransform, VideoTransform
class AVSRDataLoader:
    """Load audio and/or video data for audio-visual speech recognition.

    Depending on ``modality`` ("audio", "video", or "audiovisual"), the loader
    sets up the matching transform(s) and, for video, a face pre-processor
    backed by either the "mediapipe" or "retinaface" detector.
    """

    def __init__(self, modality, speed_rate=1, transform=True, detector="retinaface", convert_gray=True):
        self.modality = modality
        # Default on/off switch for transforms; can be further restricted
        # per call via load_data(..., transform=...).
        self.transform = transform
        if self.modality in ["audio", "audiovisual"]:
            self.audio_transform = AudioTransform()
        if self.modality in ["video", "audiovisual"]:
            # Detector backends are imported lazily so that only the selected
            # one's dependencies are required at runtime.
            if detector == "mediapipe":
                from pipelines.detectors.mediapipe.video_process import VideoProcess
                self.video_process = VideoProcess(convert_gray=convert_gray)
            if detector == "retinaface":
                from pipelines.detectors.retinaface.video_process import VideoProcess
                self.video_process = VideoProcess(convert_gray=convert_gray)
            self.video_transform = VideoTransform(speed_rate=speed_rate)

    def load_data(self, data_filename, landmarks=None, transform=True):
        """Load the data for ``data_filename`` according to ``self.modality``.

        :param data_filename: path to the media file.
        :param landmarks: optional precomputed facial landmarks forwarded to
            the video pre-processor.
        :param transform: per-call transform switch. BUG FIX: this argument was
            previously ignored (only ``self.transform`` was consulted);
            transforms are now applied only when both ``self.transform`` and
            ``transform`` are true. The default (True) preserves the previous
            behavior for all existing callers.
        :return: audio tensor, video tensor, or ``(video, audio)`` tuple for
            the "audiovisual" modality.
        """
        apply_transform = self.transform and transform
        if self.modality == "audio":
            audio = self.load_audio(data_filename)
            return self.audio_transform(audio) if apply_transform else audio
        if self.modality == "video":
            video = self.load_video(data_filename)
            video = self.video_process(video, landmarks)
            video = torch.tensor(video)
            return self.video_transform(video) if apply_transform else video
        if self.modality == "audiovisual":
            # Audio samples per video frame; presumably 16 kHz audio at
            # 25 fps video (16000 / 25 = 640) -- confirm against the pipeline.
            rate_ratio = 640
            audio = self.load_audio(data_filename)
            video = self.load_video(data_filename)
            video = self.video_process(video, landmarks)
            video = torch.tensor(video)
            # Trim both streams to their common duration so they stay aligned.
            min_t = min(len(video), audio.size(1) // rate_ratio)
            audio = audio[:, : min_t * rate_ratio]
            video = video[:min_t]
            if apply_transform:
                audio = self.audio_transform(audio)
                video = self.video_transform(video)
            return video, audio

    def load_audio(self, data_filename):
        """Return the waveform as a [1, T] tensor.

        whisper.load_audio decodes and resamples to 16 kHz float32 (per the
        whisper API); unsqueeze adds the leading channel dimension.
        """
        waveform = torch.tensor(whisper.load_audio(data_filename)).unsqueeze(0)
        return waveform

    def load_video(self, data_filename):
        """Return the video frames as a (T, H, W, C) numpy array.

        torchvision.io.read_video returns (frames, audio, info); only the
        frames tensor is kept.
        """
        return torchvision.io.read_video(data_filename, pts_unit="sec")[0].numpy()