Update compAnIonv1.py
Browse files- compAnIonv1.py +10 -10
compAnIonv1.py
CHANGED
@@ -78,16 +78,16 @@ def create_bert_classification_model(bert_model,
|
|
78 |
attention_mask = tf.keras.layers.Input(shape=(MAX_SEQUENCE_LENGTH,), dtype=tf.int64, name='attention_mask')
|
79 |
|
80 |
class CustomLayer(tf.keras.layers.Layer):
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
|
92 |
|
93 |
# Create instances of the custom layer for each input
|
|
|
78 |
attention_mask = tf.keras.layers.Input(shape=(MAX_SEQUENCE_LENGTH,), dtype=tf.int64, name='attention_mask')
|
79 |
|
80 |
class CustomLayer(tf.keras.layers.Layer):
|
81 |
+
def call(self, inputs):
|
82 |
+
if isinstance(inputs, tf.Tensor):
|
83 |
+
# If the input is a tensor, convert it to int64
|
84 |
+
return tf.constant(inputs, dtype=tf.int64)
|
85 |
+
elif isinstance(inputs, dict):
|
86 |
+
# If the input is a dictionary, process each value
|
87 |
+
return {key: tf.constant(value, dtype=tf.int64) for key, value in inputs.items()}
|
88 |
+
else:
|
89 |
+
# If the input type is unknown, raise an error
|
90 |
+
raise ValueError(f"Unsupported input type: {type(inputs)}")
|
91 |
|
92 |
|
93 |
# Create instances of the custom layer for each input
|