File size: 519 Bytes
a2e629f
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15

import tensorflow as tf

class LSTM_Classifier(tf.keras.Model):
    """Sequence classifier: LSTM encoder -> dropout -> softmax dense head.

    The LSTM is built with ``return_sequences=False``, so only its final
    output vector (size ``lstm_units``) is passed on to the head. The dense
    head uses a softmax activation, so the model outputs class probabilities
    rather than logits; pair it with a loss configured for probabilities
    (e.g. ``from_logits=False``).

    Args:
        lstm_units: Dimensionality of the LSTM output.
        num_classes: Number of output classes (width of the softmax layer).
        dropout_rate: Fraction of LSTM output features dropped while
            training. Defaults to 0.3, the previously hard-coded value.
        **kwargs: Forwarded unchanged to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, lstm_units=128, num_classes=4, dropout_rate=0.3,
                 **kwargs):
        super().__init__(**kwargs)
        self.lstm = tf.keras.layers.LSTM(lstm_units, return_sequences=False)
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
        self.classifier = tf.keras.layers.Dense(num_classes,
                                                activation='softmax')

    def call(self, inputs, training=None):
        """Classify a batch of sequences.

        Args:
            inputs: Sequence batch accepted by ``tf.keras.layers.LSTM``,
                i.e. shape ``(batch, timesteps, features)``.
            training: Whether dropout is active. ``None`` lets Keras choose
                (its training loops such as ``fit`` pass ``True``).

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding softmax
            probabilities.
        """
        # Forward the training flag explicitly instead of relying on Keras'
        # implicit call-context propagation, so dropout is only active when
        # the caller (or the training loop) asks for it.
        x = self.lstm(inputs, training=training)
        x = self.dropout(x, training=training)
        return self.classifier(x)