import tensorflow as tf
from tensorflow.keras import layers, Model


def _double_conv(x, filters):
    """Apply two 3x3, 'same'-padded, ReLU Conv2D layers with `filters` output channels."""
    x = layers.Conv2D(filters, 3, activation='relu', padding='same')(x)
    return layers.Conv2D(filters, 3, activation='relu', padding='same')(x)


def create_unet(input_shape=(256, 256, 1), *, base_filters=64, depth=4,
                dropout_rate=0.5):
    """Creates a U-Net model for pneumothorax segmentation.

    With the default arguments this builds exactly the classic 4-level U-Net
    (64/128/256/512 encoder filters, 1024-filter bottleneck), creating the
    same layers in the same order as the original hand-unrolled version.

    Args:
        input_shape: ``(height, width, channels)`` of one input image
            (channels-last; skip connections are concatenated on axis 3).
            Height and width, when not ``None``, must be divisible by
            ``2 ** depth`` so every upsampled map matches its skip connection.
        base_filters: Number of filters at the first encoder level; doubled
            at every deeper level.
        depth: Number of 2x2 max-pooling down-sampling steps (>= 1).
        dropout_rate: Dropout rate applied to the deepest encoder level and
            to the bottleneck.

    Returns:
        A ``tf.keras.Model`` mapping the input to a 1-channel sigmoid map with
        the same height and width as the input.

    Raises:
        ValueError: If ``depth`` < 1 or a spatial dimension is not divisible
            by ``2 ** depth``.
    """
    if depth < 1:
        raise ValueError(f"depth must be >= 1, got {depth}")
    factor = 2 ** depth
    # Fail early with a readable message instead of a Concatenate shape
    # mismatch deep inside the decoder (odd sizes get floored by pooling but
    # exactly doubled by the transposed convolutions).
    for dim in input_shape[:2]:
        if dim is not None and dim % factor != 0:
            raise ValueError(
                f"Spatial dimensions of input_shape {tuple(input_shape)} must "
                f"be divisible by {factor} (2 ** depth)."
            )

    inputs = layers.Input(input_shape)

    # Encoder: conv-conv (+ dropout on the deepest level), keep the skip,
    # then halve the spatial resolution.
    skips = []
    x = inputs
    for level in range(depth):
        x = _double_conv(x, base_filters * 2 ** level)
        if level == depth - 1:
            x = layers.Dropout(dropout_rate)(x)
        skips.append(x)
        x = layers.MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck.
    x = _double_conv(x, base_filters * 2 ** depth)
    x = layers.Dropout(dropout_rate)(x)

    # Decoder: upsample 2x, concatenate with the matching skip, conv-conv.
    for level in reversed(range(depth)):
        filters = base_filters * 2 ** level
        up = layers.Conv2DTranspose(filters, 2, strides=(2, 2),
                                    padding='same')(x)
        x = layers.concatenate([skips[level], up], axis=3)
        x = _double_conv(x, filters)

    # NOTE(review): this 2-channel ReLU layer in front of the sigmoid head is
    # kept so existing trained weights still load. A 2-unit ReLU bottleneck
    # can output all zeros if both units die — consider removing it when
    # retraining from scratch.
    x = layers.Conv2D(2, 3, activation='relu', padding='same')(x)
    outputs = layers.Conv2D(1, 1, activation='sigmoid')(x)
    return Model(inputs, outputs)