Spaces:
Sleeping
Sleeping
Update pages/15_TransferLearning_HF.py
Browse files
pages/15_TransferLearning_HF.py
CHANGED
|
@@ -29,14 +29,20 @@ model_name = "google/vit-base-patch16-224-in21k"
|
|
| 29 |
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
|
| 30 |
base_model = TFAutoModelForImageClassification.from_pretrained(model_name, num_labels=2) # Cats vs Dogs has 2 classes
|
| 31 |
|
| 32 |
-
# Freeze the
|
| 33 |
base_model.trainable = False
|
| 34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
# Add custom layers on top
|
| 36 |
inputs = tf.keras.Input(shape=(224, 224, 3))
|
| 37 |
-
features =
|
| 38 |
-
|
| 39 |
-
x = tf.keras.layers.Flatten()(base_output)
|
| 40 |
x = tf.keras.layers.Dense(256, activation='relu')(x)
|
| 41 |
x = tf.keras.layers.Dropout(0.5)(x)
|
| 42 |
outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)
|
|
@@ -83,4 +89,3 @@ if st.button("Train Model"):
|
|
| 83 |
if st.button("Evaluate Model"):
|
| 84 |
test_loss, test_acc = model.evaluate(ds_val, verbose=2)
|
| 85 |
st.write(f"Validation accuracy: {test_acc}")
|
| 86 |
-
|
|
|
|
| 29 |
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
|
| 30 |
base_model = TFAutoModelForImageClassification.from_pretrained(model_name, num_labels=2) # Cats vs Dogs has 2 classes
|
| 31 |
|
| 32 |
+
# Freeze the base model
|
| 33 |
base_model.trainable = False
|
| 34 |
|
| 35 |
+
# Function to extract features using the feature extractor
|
| 36 |
+
def extract_features(images):
|
| 37 |
+
# Convert images to the expected format for the feature extractor
|
| 38 |
+
images = [tf.image.convert_image_dtype(image, tf.float32) for image in images]
|
| 39 |
+
inputs = feature_extractor(images, return_tensors="tf")
|
| 40 |
+
return inputs["pixel_values"]
|
| 41 |
+
|
| 42 |
# Add custom layers on top
|
| 43 |
inputs = tf.keras.Input(shape=(224, 224, 3))
|
| 44 |
+
features = extract_features([inputs])
|
| 45 |
+
x = base_model.vit(inputs).last_hidden_state[:, 0]
|
|
|
|
| 46 |
x = tf.keras.layers.Dense(256, activation='relu')(x)
|
| 47 |
x = tf.keras.layers.Dropout(0.5)(x)
|
| 48 |
outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)
|
|
|
|
| 89 |
if st.button("Evaluate Model"):
|
| 90 |
test_loss, test_acc = model.evaluate(ds_val, verbose=2)
|
| 91 |
st.write(f"Validation accuracy: {test_acc}")
|
|
|