geekyrakshit's picture
updated zero-dce model
295bcab
import tensorflow as tf
from tensorflow.keras import layers, Input, Model
def build_dce_net(filters: int = 32, num_iterations: int = 8) -> Model:
    """Build the DCE-Net curve-estimation network.

    The network is a stack of seven 3x3 convolutions (stride 1, "same"
    padding), with symmetric skip concatenations along the channel axis:
    conv4 is joined with conv3, conv5 with conv2, and conv6 with conv1.
    The six hidden convolutions use ReLU. The final convolution uses tanh
    and emits ``3 * num_iterations`` channels.

    With the default arguments the model is identical to the original
    hard-coded version: 32 filters per hidden layer and 24 output channels.

    Args:
        filters: Number of filters in each hidden convolution.
        num_iterations: Number of 3-channel groups in the output. The output
            therefore has ``3 * num_iterations`` channels. Presumably the
            caller splits it into one RGB curve-parameter map per
            enhancement iteration; confirm against the code that consumes
            the output.

    Returns:
        An uncompiled ``Model`` mapping a ``(batch, H, W, 3)`` input of any
        spatial size to a ``(batch, H, W, 3 * num_iterations)`` output with
        values in [-1, 1] (tanh range).

    Raises:
        ValueError: If ``filters`` or ``num_iterations`` is not positive.
    """
    if filters < 1:
        raise ValueError(f"filters must be positive, got {filters}")
    if num_iterations < 1:
        raise ValueError(f"num_iterations must be positive, got {num_iterations}")

    def _conv(x, units: int = filters, activation: str = "relu"):
        # Every convolution in DCE-Net shares kernel size, stride and padding;
        # only the width and activation of the final layer differ.
        return layers.Conv2D(
            units, (3, 3), strides=(1, 1), activation=activation, padding="same"
        )(x)

    # Layers are created in the same order as the original implementation so
    # Keras' auto-generated layer names stay the same. This matters if
    # weights are restored by layer name or order.
    input_image = Input(shape=[None, None, 3])
    conv1 = _conv(input_image)
    conv2 = _conv(conv1)
    conv3 = _conv(conv2)
    conv4 = _conv(conv3)
    int_con1 = layers.Concatenate(axis=-1)([conv4, conv3])
    conv5 = _conv(int_con1)
    int_con2 = layers.Concatenate(axis=-1)([conv5, conv2])
    conv6 = _conv(int_con2)
    int_con3 = layers.Concatenate(axis=-1)([conv6, conv1])
    x_r = _conv(int_con3, 3 * num_iterations, "tanh")
    return Model(inputs=input_image, outputs=x_r)