# fastapi_dummy/utility.py
# (Hugging Face page residue converted to a comment: uploaded by
# senthil3226w, commit 663af26 "Create utility.py".)
import requests
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
# 255.0 — used to normalize uint8 pixel values into the [0..1] float range.
_UINT8_MAX_F = float(np.iinfo(np.uint8).max)
def load_image(img_url: str):
    """Load an image from a URL or a local file path.

    Args:
        img_url: An http(s) URL or a local filesystem path.

    Returns:
        np.ndarray of shape [height, width, 3], dtype float32, pixels in [0..1].

    Raises:
        requests.HTTPError: if an HTTP fetch returns an error status.
        tf.errors.OpError: if the file cannot be read or decoded.
    """
    # Accept both http:// and https:// URLs; anything else is treated as a
    # local path.  (The original only matched the "https" prefix, so plain
    # http URLs were mistakenly passed to tf.io.read_file.)
    if img_url.startswith(("http://", "https://")):
        response = requests.get(img_url)
        # Fail fast on HTTP errors instead of trying to decode an error page.
        response.raise_for_status()
        image_data = response.content
    else:
        image_data = tf.io.read_file(img_url)
    image = tf.io.decode_image(image_data, channels=3)
    image_numpy = tf.cast(image, dtype=tf.float32).numpy()
    # Normalize uint8 [0..255] values into the [0..1] float range.
    return image_numpy / _UINT8_MAX_F
# Module-level side effect: downloads/loads the FILM frame-interpolation
# model at import time (may take a while on first run, and requires network
# access unless the TF Hub cache is already populated).
model = hub.load('https://www.kaggle.com/models/google/film/TensorFlow2/film/1')
def interpolate_batch(img1, img2, batch_size=30):
    """Interpolate `batch_size` frames between two images with FILM.

    Args:
        img1: URL or path of the starting frame.
        img2: URL or path of the ending frame (assumed same H x W as img1 —
            no resizing is done here; TODO confirm callers guarantee this).
        batch_size: Number of time steps sampled uniformly in [0, 1].

    Returns:
        np.ndarray of interpolated frames, one per time step, or None if the
        model is unavailable, an image failed to load, or inference failed.
    """
    if model is None:
        print("Model not loaded. Exiting interpolation.")
        return None
    # Time steps from 0 to 1, rounded to two decimal places, shaped as a
    # column vector [batch_size, 1] so each time value lines up with one
    # entry of the x0/x1 batch.  The original repeated a [1, batch_size]
    # row vector image-height times, producing a [H, batch_size] array
    # that did not match the batch dimension of the images.
    time_values = np.linspace(0.0, 1.0, batch_size, dtype=np.float32)
    time_values = np.round(time_values, 2).reshape(batch_size, 1)
    image1 = load_image(img1)
    image2 = load_image(img2)
    if image1 is None or image2 is None:
        print("One or both images failed to load. Exiting interpolation.")
        return None
    # Batch input: each image is repeated once per time step.
    input_data = {
        'time': time_values,                                        # [B, 1]
        'x0': np.repeat(np.expand_dims(image1, axis=0), batch_size, axis=0),
        'x1': np.repeat(np.expand_dims(image2, axis=0), batch_size, axis=0),
    }
    try:
        mid_frames = model(input_data)
        return mid_frames['image'].numpy()  # interpolated frames
    except Exception as e:
        # Best-effort: report and return None rather than crash the caller.
        print(f"Error during interpolation: {e}")
        return None
def interpolate_single(img1, img2):
    """Interpolate a single frame at the midpoint (time=0.5) of two images.

    Args:
        img1: URL or path of the starting frame; its dimensions define the
            output size.
        img2: URL or path of the ending frame; resized to match img1.

    Returns:
        np.ndarray of shape [H, W, 3] — the interpolated middle frame — or
        None if the model is unavailable or inference failed.
    """
    if model is None:
        print("Model not loaded. Exiting interpolation.")
        return None
    # Midpoint time value with a batch size of 1.
    time_value = np.array([[0.5]], dtype=np.float32)  # shape [1, 1]
    # Load and normalize both images.
    image1 = load_image(img1)
    image2 = load_image(img2)
    # Resize image2 to image1's spatial dimensions so the pair is a valid
    # model input.  (The original also "resized" image1 to its own size —
    # an identity resample that has been dropped.)
    target_height, target_width = image1.shape[0], image1.shape[1]
    image2_resized = tf.image.resize(image2, [target_height, target_width]).numpy()
    input_data = {
        'time': time_value,                           # shape [1, 1]
        'x0': np.expand_dims(image1, axis=0),         # shape [1, H, W, 3]
        'x1': np.expand_dims(image2_resized, axis=0), # shape [1, H, W, 3]
    }
    try:
        mid_frame = model(input_data)
        frame = mid_frame['image'][0].numpy()  # drop the batch dimension
        return frame
    except Exception as e:
        # Best-effort: report and return None rather than crash the caller.
        print(f"Error during interpolation: {e}")
        return None