File size: 3,188 Bytes
663af26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import requests
import numpy as np
import tensorflow as tf 
import tensorflow_hub as hub

# Largest uint8 value (255) as a float; decoded pixels are divided by it to land in [0, 1].
_UINT8_MAX_F = float(np.iinfo(np.uint8).max)


def load_image(img_url: str, timeout: float = 30.0) -> np.ndarray:
    """Return an image with shape [height, width, 3], pixels in [0..1], dtype np.float32.

    Args:
      img_url: An ``http://`` or ``https://`` URL, or any path accepted by
        ``tf.io.read_file``.
      timeout: Seconds to wait for the HTTP request; only used for URLs.

    Raises:
      requests.HTTPError: if the server responds with an error status (4xx/5xx).
    """
    if img_url.startswith(("http://", "https://")):
        response = requests.get(img_url, timeout=timeout)
        # Fail loudly on an error status instead of handing an HTML error page to the decoder.
        response.raise_for_status()
        image_data = response.content
    else:
        image_data = tf.io.read_file(img_url)

    # channels=3 forces RGB; expand_animations=False keeps animated GIFs 3-D (first frame)
    # so the documented [height, width, 3] shape always holds.
    image = tf.io.decode_image(image_data, channels=3, expand_animations=False)
    image_numpy = tf.cast(image, dtype=tf.float32).numpy()
    return image_numpy / _UINT8_MAX_F



# Handle of Google's FILM frame-interpolation model; it is loaded once, at import time,
# and shared by the interpolation functions in this module.
_FILM_MODEL_HANDLE = 'https://www.kaggle.com/models/google/film/TensorFlow2/film/1'
# NOTE(review): hub.load raises on failure rather than returning None, so `model is None`
# checks only trigger if `model` is reassigned elsewhere — confirm this is intended.
model = hub.load(_FILM_MODEL_HANDLE)

def interpolate_batch(img1, img2, batch_size=30):
    """Interpolate ``batch_size`` frames between two images in a single model call.

    Time steps are evenly spaced over [0, 1] (both endpoints included) and rounded to
    two decimal places; frame ``i`` of the output corresponds to the ``i``-th time step.

    Args:
        img1: URL or path of the start image (anything ``load_image`` accepts).
        img2: URL or path of the end image. Resized to ``img1``'s height/width when the
            sizes differ, because the model's ``x0`` and ``x1`` inputs must share a shape.
        batch_size: Number of frames to generate; must be >= 1.

    Returns:
        The model's ``'image'`` output as a NumPy array with one frame per time step
        (presumably [batch_size, H, W, 3] in [0, 1] — confirm against the model), or
        None if the model is not loaded or the model call fails.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if model is None:
        print("Model not loaded. Exiting interpolation.")
        return None

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # One time value per frame, rounded to two decimals. The model takes one time per
    # batch element, i.e. shape [batch_size, 1] (interpolate_single uses [1, 1] for a
    # single frame) — NOT [1, batch_size] repeated along the image height.
    time_values = np.round(np.linspace(0.0, 1.0, batch_size, dtype=np.float32), 2)
    time_values = time_values.reshape(batch_size, 1)

    image1 = load_image(img1)
    image2 = load_image(img2)

    if image1 is None or image2 is None:
        print("One or both images failed to load. Exiting interpolation.")
        return None

    # x0 and x1 must have identical shapes; bring image2 to image1's height/width.
    if image2.shape[:2] != image1.shape[:2]:
        image2 = tf.image.resize(image2, image1.shape[:2]).numpy()

    # Every batch element pairs the same two endpoint images with a different time.
    input_data = {
        'time': time_values,                                   # [batch_size, 1]
        'x0': np.repeat(image1[np.newaxis], batch_size, axis=0),  # [batch_size, H, W, 3]
        'x1': np.repeat(image2[np.newaxis], batch_size, axis=0),  # [batch_size, H, W, 3]
    }

    try:
        mid_frames = model(input_data)
        frames = mid_frames['image'].numpy()  # Get interpolated frames
        return frames
    except Exception as e:
        print(f"Error during interpolation: {e}")
        return None


def interpolate_single(img1, img2, t=0.5):
    """Interpolate one frame between two images, by default at the midpoint (t=0.5).

    Args:
        img1: URL or path of the start image (anything ``load_image`` accepts).
        img2: URL or path of the end image. Resized to ``img1``'s height/width when the
            sizes differ, because the model's ``x0`` and ``x1`` inputs must share a shape.
        t: Interpolation time passed to the model. The default 0.5 reproduces the
            original midpoint behaviour; presumably only values in [0, 1] are meaningful,
            matching the time steps used by ``interpolate_batch`` — confirm.

    Returns:
        The interpolated frame with the batch dimension removed (presumably [H, W, 3]),
        or None if the model is not loaded or the model call fails.
    """
    if model is None:
        print("Model not loaded. Exiting interpolation.")
        return None

    # One time value for a batch of one: shape [1, 1].
    time_value = np.array([[t]], dtype=np.float32)

    # Load and normalize images
    image1 = load_image(img1)
    image2 = load_image(img2)

    # image1 already defines the target size, so resizing it (or an equally sized image2)
    # is effectively a no-op; only resize image2 when its height/width actually differ.
    if image2.shape[:2] != image1.shape[:2]:
        image2 = tf.image.resize(image2, image1.shape[:2]).numpy()

    input_data = {
        'time': time_value,         # shape [1, 1]
        'x0': image1[np.newaxis],   # shape [1, H, W, 3]
        'x1': image2[np.newaxis],   # shape [1, H, W, 3]
    }

    try:
        mid_frame = model(input_data)
        frame = mid_frame['image'][0].numpy()  # Strip the batch dimension
        return frame
    except Exception as e:
        print(f"Error during interpolation: {e}")
        return None