import tensorflow as tf | |
class LSTM_Classifier(tf.keras.Model):
    """Sequence classifier: one LSTM layer, dropout, then a softmax head.

    The LSTM is built with ``return_sequences=False``, so only its final output
    is passed on; each input sequence is reduced to a single vector before
    classification.

    Args:
        lstm_units: Number of units in the LSTM layer.
        num_classes: Number of output classes (width of the softmax layer).
        dropout_rate: Fraction of LSTM outputs dropped while training.
            Defaults to 0.3, the value this class previously hard-coded.
        **kwargs: Forwarded unchanged to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, lstm_units=128, num_classes=4, dropout_rate=0.3, **kwargs):
        super().__init__(**kwargs)
        # Attribute names are kept as-is so existing checkpoints still match.
        self.lstm = tf.keras.layers.LSTM(lstm_units, return_sequences=False)
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
        self.classifier = tf.keras.layers.Dense(num_classes, activation='softmax')

    def call(self, inputs, training=None):
        """Run the forward pass.

        Args:
            inputs: Batch fed directly to the LSTM; presumably shaped
                (batch, timesteps, features) — TODO confirm against callers.
            training: Whether this call is in training mode. It is forwarded
                explicitly so dropout is active only while training; ``None``
                defers to the ambient Keras call context (as before).

        Returns:
            Output of the softmax ``Dense`` layer (class probabilities).
        """
        x = self.lstm(inputs, training=training)
        x = self.dropout(x, training=training)
        return self.classifier(x)