# encoding: utf-8
import os

import cv2
import editdistance
import numpy as np
import torch
from torch.utils.data import Dataset

from cvtransforms import *


class MyDataset(Dataset):
    # Space plus the 26 uppercase letters; labels are offset by `start` in
    # txt2arr, leaving index 0 free for the CTC blank.
    letters = [
        " ", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M",
        "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z",
    ]

    def __init__(self, video_path, anno_path, file_list, vid_pad, txt_pad, phase):
        self.anno_path = anno_path
        self.vid_pad = vid_pad
        self.txt_pad = txt_pad
        self.phase = phase

        with open(file_list, "r") as f:
            self.videos = [
                os.path.join(video_path, line.strip()) for line in f.readlines()
            ]

        self.data = []
        for vid in self.videos:
            items = vid.split("/")  # or os.path.sep for portability
            # (frame directory, speaker id, utterance name)
            self.data.append((vid, items[-4], items[-1]))

    def __getitem__(self, idx):
        (vid, spk, name) = self.data[idx]
        vid = self._load_vid(vid)
        anno = self._load_anno(
            os.path.join(self.anno_path, spk, "align", name + ".align")
        )

        # Horizontal flipping is a training-time augmentation only.
        if self.phase == "train":
            vid = HorizontalFlip(vid)
        vid = ColorNormalize(vid)

        vid_len = vid.shape[0]
        anno_len = anno.shape[0]
        vid = self._padding(vid, self.vid_pad)
        anno = self._padding(anno, self.txt_pad)

        return {
            "vid": torch.FloatTensor(vid.transpose(3, 0, 1, 2)),  # (C, T, H, W)
            "txt": torch.LongTensor(anno),
            "txt_len": anno_len,
            "vid_len": vid_len,
        }

    def __len__(self):
        return len(self.data)

    def _load_vid(self, p):
        # Read every numbered .jpg frame in the utterance directory, in temporal order.
        files = os.listdir(p)
        files = [f for f in files if ".jpg" in f]
        files = sorted(files, key=lambda file: int(os.path.splitext(file)[0]))
        array = [cv2.imread(os.path.join(p, file)) for file in files]
        array = [im for im in array if im is not None]
        array = [
            cv2.resize(im, (128, 64), interpolation=cv2.INTER_LANCZOS4) for im in array
        ]
        return np.stack(array, axis=0).astype(np.float32)

    def _load_anno(self, name):
        with open(name, "r") as f:
            lines = [line.strip().split(" ") for line in f.readlines()]
        # Each align line is "start end word"; drop silence/short-pause tokens.
        txt = [line[2] for line in lines]
        txt = [t for t in txt if t.upper() not in ("SIL", "SP")]
        return MyDataset.txt2arr(" ".join(txt).upper(), 1)

    def _padding(self, array, length):
        # Zero-pad along the first (time) axis up to `length` entries.
        array = [array[_] for _ in range(array.shape[0])]
        size = array[0].shape
        for i in range(length - len(array)):
            array.append(np.zeros(size))
        return np.stack(array, axis=0)

    @staticmethod
    def txt2arr(txt, start):
        # Map each character to its index in `letters`, offset by `start`.
        arr = []
        for c in list(txt):
            arr.append(MyDataset.letters.index(c) + start)
        return np.array(arr)

    @staticmethod
    def arr2txt(arr, start):
        # Inverse of txt2arr: labels below `start` (e.g. padding or the CTC blank) are skipped.
        txt = []
        for n in arr:
            if n >= start:
                txt.append(MyDataset.letters[n - start])
        return "".join(txt).strip()

    @staticmethod
    def ctc_arr2txt(arr, start):
        # Greedy CTC decoding: collapse repeated labels, drop blanks, and avoid
        # emitting consecutive spaces.
        pre = -1
        txt = []
        for n in arr:
            if pre != n and n >= start:
                ch = MyDataset.letters[n - start]
                if not (len(txt) > 0 and txt[-1] == " " and ch == " "):
                    txt.append(ch)
            pre = n
        return "".join(txt).strip()

    @staticmethod
    def wer(predict, truth):
        # Per-pair word error rate: word-level edit distance / reference word count.
        word_pairs = [(p[0].split(" "), p[1].split(" ")) for p in zip(predict, truth)]
        wer = [1.0 * editdistance.eval(p[0], p[1]) / len(p[1]) for p in word_pairs]
        return wer

    @staticmethod
    def cer(predict, truth):
        # Per-pair character error rate: edit distance / reference string length.
        cer = [
            1.0 * editdistance.eval(p[0], p[1]) / len(p[1]) for p in zip(predict, truth)
        ]
        return cer
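

# A minimal usage sketch, assuming a GRID-style data layout. The paths below and
# the values vid_pad=75 / txt_pad=200 are placeholders for illustration only,
# not values taken from this repository's configuration.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = MyDataset(
        video_path="data/lip",              # hypothetical directory of frame folders
        anno_path="data/GRID_align_txt",    # hypothetical directory of .align files
        file_list="data/train_list.txt",    # hypothetical list of relative video dirs
        vid_pad=75,
        txt_pad=200,
        phase="train",
    )
    loader = DataLoader(dataset, batch_size=2, shuffle=True)
    batch = next(iter(loader))
    print(batch["vid"].shape)  # (B, C, vid_pad, 64, 128)
    print(batch["txt"].shape)  # (B, txt_pad)
    print(MyDataset.arr2txt(batch["txt"][0].numpy(), start=1))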