SkillForge45 committed on
Commit
74de93b
·
verified ·
1 Parent(s): a7bbbd5

Create tokenizer.py

Browse files
Files changed (1) hide show
  1. de_en/tokenizer.py +60 -0
de_en/tokenizer.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+ from PIL import Image
4
+ import numpy as np
5
+ import cv2
6
+
7
class VideoTokenizer:
    """Turn images and videos into normalized tensors, and tensors back into video files.

    Every image or frame is resized to ``resolution x resolution``, converted
    to a tensor and normalized with mean 0.5 / std 0.5 per channel (three
    channels are assumed by the normalization constants).
    """

    def __init__(self, resolution=128):
        """Build the per-frame preprocessing pipeline.

        Args:
            resolution: Side length, in pixels, that every frame is resized to.
        """
        self.resolution = resolution
        self.transform = transforms.Compose([
            transforms.Resize((resolution, resolution)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def encode_image(self, image):
        """Apply the preprocessing pipeline to a single image.

        Args:
            image: Either a filesystem path (``str``), which is opened and
                converted to RGB, or an object ``self.transform`` accepts
                directly (it is passed through unchanged to the transform).

        Returns:
            The result of ``self.transform`` applied to the image.
        """
        if isinstance(image, str):
            # ``convert`` loads the pixel data into a new image, so the
            # underlying file can be closed deterministically right away.
            with Image.open(image) as handle:
                image = handle.convert('RGB')
        return self.transform(image)

    def encode_video(self, video_path, max_frames=24):
        """Read up to ``max_frames`` frames from a video and stack them.

        Args:
            video_path: Path of the video to decode. If a ``torch.Tensor`` is
                given instead (the dataset may already hold decoded clips),
                it is returned unchanged.
            max_frames: Number of frames in the returned stack. Longer videos
                are truncated; shorter ones are padded with all-zero frames.

        Returns:
            A tensor of ``max_frames`` transformed frames stacked along a new
            leading dimension, or ``video_path`` itself when it is a tensor.

        Raises:
            ValueError: If no frame could be read (file missing, unreadable,
                empty, or ``max_frames`` is not positive). Previously this
                surfaced as an opaque ``IndexError`` from the padding step.
        """
        if isinstance(video_path, torch.Tensor):
            return video_path

        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            while cap.isOpened() and len(frames) < max_frames:
                ret, frame = cap.read()
                if not ret:
                    break
                # OpenCV decodes to BGR; the transform pipeline expects RGB.
                rgb = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                frames.append(self.transform(rgb))
        finally:
            # Release the capture even if decoding or the transform raises.
            cap.release()

        if not frames:
            raise ValueError(
                f"Could not read any frames from video: {video_path!r}"
            )

        return self._pad_frames(frames, max_frames)

    @staticmethod
    def _pad_frames(frames, max_frames):
        """Stack ``frames``, appending zero frames until there are ``max_frames``."""
        missing = max_frames - len(frames)
        if missing > 0:
            # ``torch.stack`` copies its inputs, so one shared blank is enough.
            blank = torch.zeros_like(frames[0])
            frames = list(frames) + [blank] * missing
        return torch.stack(frames)

    @staticmethod
    def _frames_to_uint8(frames):
        """Convert a (N, C, H, W) tensor in [-1, 1] to a (N, H, W, C) uint8 array."""
        if frames.dim() != 4:
            raise ValueError(
                f"Expected a 4-D (N, C, H, W) tensor, got {frames.dim()} dimensions"
            )
        # detach/cpu so tensors that require grad or live on an accelerator
        # can be converted; ``.numpy()`` alone raises for both.
        unit = (frames.detach().cpu().clamp(-1, 1) + 1) / 2  # [-1,1] to [0,1]
        return (unit.permute(0, 2, 3, 1).numpy() * 255).astype('uint8')

    def save_video(self, frames, output_path, fps=24):
        """Write a stack of normalized RGB frames to a video file.

        Args:
            frames: Tensor of shape (N, 3, H, W); values are clamped to
                [-1, 1] before being mapped to 8-bit pixels.
            output_path: Destination file; encoded with the ``mp4v`` codec.
            fps: Frame rate of the written video.

        Raises:
            ValueError: If ``frames`` is not 4-dimensional.
            OSError: If the video writer could not be opened.
        """
        pixels = self._frames_to_uint8(frames)
        # Size the writer from the frames themselves: OpenCV silently drops
        # frames whose size differs from the size the writer was opened with.
        height, width = pixels.shape[1:3]

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        try:
            if not out.isOpened():
                raise OSError(f"Could not open video writer for: {output_path!r}")
            for frame in pixels:
                # The writer expects BGR channel order.
                out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        finally:
            out.release()