auto_avsr / pipelines /data /data_module.py
mpc001's picture
Update pipelines/data/data_module.py
b3d8e4f
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 Imperial College London (Pingchuan Ma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import torch
import whisper
import torchaudio
import torchvision
from .transforms import AudioTransform, VideoTransform
class AVSRDataLoader:
    """Load a media file into audio and/or video tensors for AVSR models.

    Depending on ``modality`` the loader reads the audio track, the video
    frames, or both. Video frames are passed through a detector-specific
    ``VideoProcess``. ``AudioTransform`` / ``VideoTransform`` are optionally
    applied to the results.
    """

    # Audio samples per video frame, used to trim both streams to a common
    # length in the audiovisual branch.
    # NOTE(review): presumably 16 kHz audio (whisper.load_audio default) divided
    # by 25 fps video = 640. Confirm the frame rate of the input videos.
    AUDIO_SAMPLES_PER_VIDEO_FRAME = 640

    def __init__(self, modality, speed_rate=1, transform=True,
                 detector="retinaface", convert_gray=True):
        """Build the transforms and the video processor the modality needs.

        Args:
            modality: one of ``"audio"``, ``"video"`` or ``"audiovisual"``.
            speed_rate: forwarded to ``VideoTransform(speed_rate=...)``.
            transform: default for whether ``load_data`` applies the
                audio/video transforms.
            detector: ``"mediapipe"`` or ``"retinaface"``. Selects which
                ``VideoProcess`` implementation is imported. Only consulted
                for the video modalities.
            convert_gray: forwarded to ``VideoProcess(convert_gray=...)``.

        Raises:
            ValueError: if a video modality is requested with an unknown
                ``detector``.
        """
        self.modality = modality
        self.transform = transform
        if self.modality in ["audio", "audiovisual"]:
            self.audio_transform = AudioTransform()
        if self.modality in ["video", "audiovisual"]:
            # Detector backends are imported lazily so that only the chosen
            # one (and its dependencies) needs to be installed.
            if detector == "mediapipe":
                from pipelines.detectors.mediapipe.video_process import VideoProcess
            elif detector == "retinaface":
                from pipelines.detectors.retinaface.video_process import VideoProcess
            else:
                # Previously this silently left ``self.video_process`` unset
                # and failed later with an AttributeError in ``load_data``.
                raise ValueError(
                    f"Unsupported detector: {detector!r} "
                    "(expected 'mediapipe' or 'retinaface')"
                )
            self.video_process = VideoProcess(convert_gray=convert_gray)
            self.video_transform = VideoTransform(speed_rate=speed_rate)

    def load_data(self, data_filename, landmarks=None, transform=True):
        """Load ``data_filename`` according to ``self.modality``.

        Args:
            data_filename: path of the media file to read.
            landmarks: forwarded unchanged to ``self.video_process``.
            transform: per-call switch. Transforms are applied only when both
                this and ``self.transform`` are truthy.

        Returns:
            ``audio`` for ``"audio"``, ``video`` for ``"video"``, and a
            ``(video, audio)`` tuple for ``"audiovisual"``.

        Raises:
            ValueError: if ``self.modality`` is not a supported modality.
        """
        apply_transform = self.transform and transform
        if self.modality == "audio":
            audio = self.load_audio(data_filename)
            return self.audio_transform(audio) if apply_transform else audio
        if self.modality == "video":
            video = self._load_processed_video(data_filename, landmarks)
            return self.video_transform(video) if apply_transform else video
        if self.modality == "audiovisual":
            ratio = self.AUDIO_SAMPLES_PER_VIDEO_FRAME
            audio = self.load_audio(data_filename)
            video = self._load_processed_video(data_filename, landmarks)
            # Trim both streams to the same number of whole video frames so
            # that audio length == frames * ratio samples.
            min_t = min(len(video), audio.size(1) // ratio)
            audio = audio[:, :min_t * ratio]
            video = video[:min_t]
            if apply_transform:
                audio = self.audio_transform(audio)
                video = self.video_transform(video)
            return video, audio
        raise ValueError(f"Unsupported modality: {self.modality!r}")

    def _load_processed_video(self, data_filename, landmarks):
        """Read frames, run them through ``self.video_process`` and tensorize."""
        video = self.load_video(data_filename)
        video = self.video_process(video, landmarks)
        # NOTE(review): if video_process can return None (e.g. no face found),
        # torch.tensor will raise here. Confirm the detector's contract.
        return torch.tensor(video)

    def load_audio(self, data_filename):
        """Decode the audio track with whisper and add a leading channel axis."""
        # rtype: [1, T]
        waveform = torch.tensor(whisper.load_audio(data_filename)).unsqueeze(0)
        return waveform

    def load_video(self, data_filename):
        """Decode the video frames as a numpy array (the audio track is dropped)."""
        # read_video returns (video_frames, audio_frames, info). Only the
        # frames are kept. Presumably (T, H, W, C) with torchvision's default
        # output_format; confirm.
        return torchvision.io.read_video(data_filename, pts_unit='sec')[0].numpy()