senthil3226w commited on
Commit
663af26
·
verified ·
1 Parent(s): 7ab5dbd

Create utility.py

Browse files
Files changed (1) hide show
  1. utility.py +90 -0
utility.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ import tensorflow_hub as hub
5
+
6
+ _UINT8_MAX_F = float(np.iinfo(np.uint8).max)
7
+
8
+ def load_image(img_url: str):
9
+ """Returns an image with shape [height, width, num_channels], with pixels in [0..1] range, and type np.float32."""
10
+
11
+ if (img_url.startswith("https")):
12
+ response = requests.get(img_url)
13
+ image_data = response.content
14
+ else:
15
+ image_data = tf.io.read_file(img_url)
16
+
17
+ image = tf.io.decode_image(image_data, channels=3)
18
+ image_numpy = tf.cast(image, dtype=tf.float32).numpy()
19
+ return image_numpy / _UINT8_MAX_F
20
+
21
+
22
+
23
+ model = hub.load('https://www.kaggle.com/models/google/film/TensorFlow2/film/1')
24
+
25
+ def interpolate_batch(img1, img2, batch_size=30):
26
+ if model is None:
27
+ print("Model not loaded. Exiting interpolation.")
28
+ return None
29
+
30
+ # Generate time steps from 0 to 1, rounded to two decimal places
31
+ time_values = np.linspace(0.0, 1.0, batch_size, dtype=np.float32)
32
+ time_values = np.round(time_values, 2) # Round time values to two decimal places
33
+ time_values = np.expand_dims(time_values, axis=0) # Add batch dimension
34
+
35
+ image1 = load_image(img1)
36
+ image2 = load_image(img2)
37
+
38
+ if image1 is None or image2 is None:
39
+ print("One or both images failed to load. Exiting interpolation.")
40
+ return None
41
+
42
+ # Create batch input for model
43
+ input_data = {
44
+ 'time': np.repeat(time_values, image1.shape[0], axis=0), # Expand time across batch
45
+ 'x0': np.repeat(np.expand_dims(image1, axis=0), batch_size, axis=0), # Repeat image1 for each batch
46
+ 'x1': np.repeat(np.expand_dims(image2, axis=0), batch_size, axis=0) # Repeat image2 for each batch
47
+ }
48
+
49
+ try:
50
+ mid_frames = model(input_data)
51
+ frames = mid_frames['image'].numpy() # Get interpolated frames
52
+ return frames
53
+ except Exception as e:
54
+ print(f"Error during interpolation: {e}")
55
+ return None
56
+
57
+
58
+ def interpolate_single(img1, img2):
59
+ """Interpolate a single frame at the midpoint between two images (time=0.5)."""
60
+
61
+ if model is None:
62
+ print("Model not loaded. Exiting interpolation.")
63
+ return None
64
+
65
+ # Midpoint time value with batch size of 1
66
+ time_value = np.array([[0.5]], dtype=np.float32) # shape [1, 1]
67
+
68
+ # Load and normalize images
69
+ image1 = load_image(img1)
70
+ image2 = load_image(img2)
71
+
72
+ # Ensure the images have the same height and width if needed
73
+ target_height, target_width = image1.shape[0], image1.shape[1]
74
+ image1_resized = tf.image.resize(image1, [target_height, target_width]).numpy()
75
+ image2_resized = tf.image.resize(image2, [target_height, target_width]).numpy()
76
+
77
+ # Expand dimensions to add batch dimension
78
+ input_data = {
79
+ 'time': time_value, # shape [1, 1]
80
+ 'x0': np.expand_dims(image1_resized, axis=0), # shape [1, H, W, 3]
81
+ 'x1': np.expand_dims(image2_resized, axis=0), # shape [1, H, W, 3]
82
+ }
83
+
84
+ try:
85
+ mid_frame = model(input_data)
86
+ frame = mid_frame['image'][0].numpy() # Extract the interpolated frame
87
+ return frame
88
+ except Exception as e:
89
+ print(f"Error during interpolation: {e}")
90
+ return None