Spaces:
Sleeping
Sleeping
import requests | |
import numpy as np | |
import tensorflow as tf | |
import tensorflow_hub as hub | |
_UINT8_MAX_F = float(np.iinfo(np.uint8).max) | |
def load_image(img_url: str, timeout=30.0):
    """Load an image from an HTTP(S) URL or a TensorFlow-readable path.

    Args:
        img_url: An ``http://`` or ``https://`` URL (matched case-insensitively),
            or any path accepted by ``tf.io.read_file``.
        timeout: Seconds passed to ``requests.get`` for remote images. Per the
            requests library it bounds the connect / between-bytes wait, not
            the total download time. ``None`` waits indefinitely. Ignored for
            local paths.

    Returns:
        An ``np.float32`` array of shape [height, width, 3] with pixel values
        rescaled to the [0..1] range. Animated GIFs yield their first frame.

    Raises:
        requests.HTTPError: if a remote fetch returns an error status.
        requests.RequestException: on other network failures or a timeout.
        TensorFlow errors (e.g. ``tf.errors.NotFoundError``) are propagated
        unchanged if a local file is missing or the bytes cannot be decoded.
    """
    # Match the full scheme so plain http:// URLs are fetched too, and a local
    # path that merely starts with "https" (e.g. "https_images/x.png") is not.
    if img_url.lower().startswith(("http://", "https://")):
        response = requests.get(img_url, timeout=timeout)
        # Fail fast: otherwise an HTML error page would reach the decoder and
        # surface as an unrelated-looking decode error.
        response.raise_for_status()
        image_data = response.content
    else:
        image_data = tf.io.read_file(img_url)
    # expand_animations=False keeps the result 3-D for every format (animated
    # GIFs are truncated to their first frame), so the documented
    # [height, width, 3] shape always holds.
    image = tf.io.decode_image(image_data, channels=3, expand_animations=False)
    image_numpy = tf.cast(image, dtype=tf.float32).numpy()
    return image_numpy / _UINT8_MAX_F
# Frame-interpolation model used by the interpolate_* functions in this
# module. It is loaded once, at import time, from the handle below.
# hub.load performs network/disk I/O here, and any failure propagates out
# of the import.
_FILM_MODEL_HANDLE = 'https://www.kaggle.com/models/google/film/TensorFlow2/film/1'
model = hub.load(_FILM_MODEL_HANDLE)
def interpolate_batch(img1, img2, batch_size=30):
    """Interpolate ``batch_size`` frames evenly spaced between two images.

    Time steps run from 0.0 to 1.0 inclusive (via ``np.linspace``) and are
    rounded to two decimal places. All frames are produced in one model call.

    Args:
        img1: URL or path of the time-0 image, as accepted by ``load_image``.
            Its height and width define the output size.
        img2: URL or path of the time-1 image. It is resized to ``img1``'s
            height and width when they differ.
        batch_size: Number of frames to generate, including both endpoints.

    Returns:
        The model's ``'image'`` output as a NumPy array (presumably shape
        [batch_size, H, W, 3]; confirm against the model), or ``None`` if the
        model is not loaded or the model call raised. Exceptions raised by
        ``load_image`` propagate to the caller.
    """
    if model is None:
        print("Model not loaded. Exiting interpolation.")
        return None

    # Evenly spaced time steps in [0, 1], rounded to two decimal places.
    time_values = np.round(np.linspace(0.0, 1.0, batch_size, dtype=np.float32), 2)
    # One time value per batch element: shape [batch_size, 1], the same layout
    # interpolate_single uses for its single [1, 1] time input. (The batch
    # axis must match x0/x1, not the image height.)
    time_values = time_values[:, np.newaxis]

    image1 = load_image(img1)
    image2 = load_image(img2)
    # Defensive only: load_image raises on failure rather than returning None.
    if image1 is None or image2 is None:
        print("One or both images failed to load. Exiting interpolation.")
        return None

    # x0 and x1 are stacked into same-shaped batches, so bring image2 to
    # image1's height and width (consistent with interpolate_single).
    height, width = image1.shape[0], image1.shape[1]
    if image2.shape[:2] != (height, width):
        image2 = tf.image.resize(image2, [height, width]).numpy()

    input_data = {
        'time': time_values,  # [batch_size, 1]
        'x0': np.repeat(np.expand_dims(image1, axis=0), batch_size, axis=0),  # [batch_size, H, W, 3]
        'x1': np.repeat(np.expand_dims(image2, axis=0), batch_size, axis=0),  # [batch_size, H, W, 3]
    }
    try:
        mid_frames = model(input_data)
        return mid_frames['image'].numpy()
    except Exception as e:
        print(f"Error during interpolation: {e}")
        return None
def interpolate_single(img1, img2):
    """Interpolate a single frame at the midpoint between two images (time=0.5).

    Args:
        img1: URL or path of the first image, as accepted by ``load_image``.
            Its height and width define the output size.
        img2: URL or path of the second image. It is resized to ``img1``'s
            height and width when they differ.

    Returns:
        Element 0 of the model's ``'image'`` output as a NumPy array
        (presumably [H, W, 3]; confirm against the model), or ``None`` if the
        model is not loaded or the model call raised. Exceptions raised by
        ``load_image`` propagate to the caller.
    """
    if model is None:
        print("Model not loaded. Exiting interpolation.")
        return None

    # Midpoint time value for a batch of one: shape [1, 1].
    time_value = np.array([[0.5]], dtype=np.float32)

    # Load images as float32 arrays in [0, 1].
    image1 = load_image(img1)
    image2 = load_image(img2)

    # Give both inputs image1's height and width. image1 already has that
    # size (resizing it to its own size would only return an equivalent
    # copy), so only image2 is resized, and only when its size differs.
    target_height, target_width = image1.shape[0], image1.shape[1]
    if image2.shape[:2] != (target_height, target_width):
        image2 = tf.image.resize(image2, [target_height, target_width]).numpy()

    # Add the leading batch dimension.
    input_data = {
        'time': time_value,  # shape [1, 1]
        'x0': np.expand_dims(image1, axis=0),  # shape [1, H, W, 3]
        'x1': np.expand_dims(image2, axis=0),  # shape [1, H, W, 3]
    }
    try:
        mid_frame = model(input_data)
        return mid_frame['image'][0].numpy()  # drop the batch dimension
    except Exception as e:
        print(f"Error during interpolation: {e}")
        return None