jrippert commited on
Commit
7916653
·
verified ·
1 Parent(s): f93b2d7

Update compAnIonv1.py

Browse files
Files changed (1) hide show
  1. compAnIonv1.py +10 -10
compAnIonv1.py CHANGED
@@ -78,16 +78,16 @@ def create_bert_classification_model(bert_model,
78
  attention_mask = tf.keras.layers.Input(shape=(MAX_SEQUENCE_LENGTH,), dtype=tf.int64, name='attention_mask')
79
 
80
  class CustomLayer(tf.keras.layers.Layer):
81
- def call(self, inputs):
82
- if isinstance(inputs, tf.Tensor):
83
- # If the input is a tensor, convert it to int64
84
- return tf.constant(inputs, dtype=tf.int64)
85
- elif isinstance(inputs, dict):
86
- # If the input is a dictionary, process each value
87
- return {key: tf.constant(value, dtype=tf.int64) for key, value in inputs.items()}
88
- else:
89
- # If the input type is unknown, raise an error
90
- raise ValueError(f"Unsupported input type: {type(inputs)}")
91
 
92
 
93
  # Create instances of the custom layer for each input
 
78
  attention_mask = tf.keras.layers.Input(shape=(MAX_SEQUENCE_LENGTH,), dtype=tf.int64, name='attention_mask')
79
 
80
  class CustomLayer(tf.keras.layers.Layer):
81
+ def call(self, inputs):
82
+ if isinstance(inputs, tf.Tensor):
83
+ # If the input is a tensor, convert it to int64
84
+ return tf.constant(inputs, dtype=tf.int64)
85
+ elif isinstance(inputs, dict):
86
+ # If the input is a dictionary, process each value
87
+ return {key: tf.constant(value, dtype=tf.int64) for key, value in inputs.items()}
88
+ else:
89
+ # If the input type is unknown, raise an error
90
+ raise ValueError(f"Unsupported input type: {type(inputs)}")
91
 
92
 
93
  # Create instances of the custom layer for each input