Amite5h commited on
Commit
00a9dd7
·
1 Parent(s): c439b4b

Upload 3 files

Browse files
Files changed (3) hide show
  1. model.h5 +3 -0
  2. model.py +340 -0
  3. vocab_coco.file +0 -0
model.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:697c158f2ac305a5c5b83585f217b720e9bf330b379e0c651b6db1a7d0ab01d8
3
+ size 221999608
model.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import os
3
+ import json
4
+ import pandas as pd
5
+ import re
6
+ import numpy as np
7
+ import time
8
+ import matplotlib.pyplot as plt
9
+ import collections
10
+ import random
11
+ import pickle
12
+
13
+ import requests
14
+ import json
15
+ from math import sqrt
16
+ from PIL import Image
17
+ from tqdm.auto import tqdm
18
+
19
+ MAX_LENGTH = 40
20
+ VOCABULARY_SIZE = 17000
21
+ BATCH_SIZE = 64
22
+ BUFFER_SIZE = 1000
23
+ EMBEDDING_DIM = 512
24
+ UNITS = 512
25
+ EPOCHS = 8
26
+
27
+ vocab = pickle.load(open('vocab_coco.file', 'rb'))
28
+
29
+ tokenizer = tf.keras.layers.TextVectorization(
30
+ # max_tokens=VOCABULARY_SIZE,
31
+ standardize=None,
32
+ output_sequence_length=MAX_LENGTH,
33
+ vocabulary=vocab
34
+ )
35
+
36
+ idx2word = tf.keras.layers.StringLookup(
37
+ mask_token="",
38
+ vocabulary=tokenizer.get_vocabulary(),
39
+ invert=True
40
+ )
41
+
42
+ def CNN_Encoder():
43
+ inception_v3 = tf.keras.applications.InceptionV3(
44
+ include_top=False, #we are not doing classification on image net so we have to drop last dense layers
45
+ weights='imagenet'
46
+ )
47
+
48
+ output = inception_v3.output
49
+ print(output.shape)
50
+ output = tf.keras.layers.Reshape(
51
+ (-1, output.shape[-1]))(output)
52
+ print(output.shape)
53
+ cnn_model = tf.keras.models.Model(inception_v3.input, output)
54
+ return cnn_model
55
+
56
+ class TransformerEncoderLayer(tf.keras.layers.Layer):
57
+
58
+ def __init__(self, embed_dim, num_heads):
59
+ super().__init__()
60
+ self.layer_norm_1 = tf.keras.layers.LayerNormalization()
61
+ self.layer_norm_2 = tf.keras.layers.LayerNormalization()
62
+ self.attention = tf.keras.layers.MultiHeadAttention(
63
+ num_heads=num_heads, key_dim=embed_dim)
64
+ self.dense = tf.keras.layers.Dense(embed_dim, activation="relu")
65
+
66
+
67
+ def call(self, x, training):
68
+ x = self.layer_norm_1(x)
69
+ x = self.dense(x)
70
+
71
+ attn_output = self.attention(
72
+ query=x,
73
+ value=x,
74
+ key=x,
75
+ attention_mask=None,
76
+ training=training
77
+ )
78
+
79
+ x = self.layer_norm_2(x + attn_output) #skip connection
80
+ return x
81
+
82
+ # combines token embeddings and position embeddings
83
+
84
+ class Embeddings(tf.keras.layers.Layer):
85
+
86
+ def __init__(self, vocab_size, embed_dim, max_len):
87
+ super().__init__()
88
+ self.token_embeddings = tf.keras.layers.Embedding(
89
+ vocab_size, embed_dim)
90
+ self.position_embeddings = tf.keras.layers.Embedding(
91
+ max_len, embed_dim, input_shape=(None, max_len))
92
+
93
+
94
+ def call(self, input_ids):
95
+ length = tf.shape(input_ids)[-1]
96
+ #calculate the total length of input sequence so that it would be used to calculate positional id
97
+ position_ids = tf.range(start=0, limit=length, delta=1)
98
+ #give id positional id to each word in caption
99
+ position_ids = tf.expand_dims(position_ids, axis=0)
100
+ #This is done to match the shape of the input tensor when performing element-wise addition in the next step.
101
+ token_embeddings = self.token_embeddings(input_ids)
102
+ #so we are creating token embedding for input ids
103
+ position_embeddings = self.position_embeddings(position_ids)
104
+ #but we are creating postion embedding with position_ids only
105
+ return token_embeddings + position_embeddings
106
+
107
+ class TransformerDecoderLayer(tf.keras.layers.Layer):
108
+
109
+ def __init__(self, embed_dim, units, num_heads):
110
+ super().__init__()
111
+ self.embedding = Embeddings(
112
+ tokenizer.vocabulary_size(), embed_dim, MAX_LENGTH)
113
+ #embedding from
114
+ self.attention_1 = tf.keras.layers.MultiHeadAttention(
115
+ num_heads=num_heads, key_dim=embed_dim, dropout=0.1
116
+ )
117
+ self.attention_2 = tf.keras.layers.MultiHeadAttention(
118
+ num_heads=num_heads, key_dim=embed_dim, dropout=0.1
119
+ )
120
+
121
+ self.layernorm_1 = tf.keras.layers.LayerNormalization()
122
+ self.layernorm_2 = tf.keras.layers.LayerNormalization()
123
+ self.layernorm_3 = tf.keras.layers.LayerNormalization()
124
+
125
+ self.ffn_layer_1 = tf.keras.layers.Dense(units, activation="relu")
126
+ self.ffn_layer_2 = tf.keras.layers.Dense(embed_dim)
127
+
128
+ self.out = tf.keras.layers.Dense(tokenizer.vocabulary_size(), activation="softmax")
129
+
130
+ self.dropout_1 = tf.keras.layers.Dropout(0.3)
131
+ self.dropout_2 = tf.keras.layers.Dropout(0.5)
132
+
133
+
134
+ def call(self, input_ids, encoder_output, training, mask=None):
135
+ embeddings = self.embedding(input_ids)
136
+
137
+ combined_mask = None
138
+ padding_mask = None
139
+
140
+ if mask is not None:
141
+ causal_mask = self.get_causal_attention_mask(embeddings)
142
+ padding_mask = tf.cast(mask[:, :, tf.newaxis], dtype=tf.int32)
143
+ combined_mask = tf.cast(mask[:, tf.newaxis, :], dtype=tf.int32)
144
+ combined_mask = tf.minimum(combined_mask, causal_mask)
145
+
146
+
147
+ #this layer contain masked self attention layer
148
+ attn_output_1 = self.attention_1(
149
+ query=embeddings,
150
+ value=embeddings,
151
+ key=embeddings,
152
+ attention_mask=combined_mask,
153
+ training=training
154
+ )
155
+ #this layer contain cross attention
156
+ #which is taking query vector from the previous masked attention and key and value vector from encoder so that some information of input is there
157
+ #Expalin:
158
+ out_1 = self.layernorm_1(embeddings + attn_output_1)
159
+
160
+ attn_output_2 = self.attention_2(
161
+ query=out_1, #query vector from deocder
162
+ value=encoder_output, #key and value vector from encoder
163
+ key=encoder_output,
164
+ attention_mask=padding_mask, #no masking is there
165
+ training=training
166
+ )
167
+
168
+ out_2 = self.layernorm_2(out_1 + attn_output_2) #skip connection
169
+
170
+ ffn_out = self.ffn_layer_1(out_2)
171
+ ffn_out = self.dropout_1(ffn_out, training=training)
172
+ ffn_out = self.ffn_layer_2(ffn_out)
173
+
174
+ ffn_out = self.layernorm_3(ffn_out + out_2)
175
+ ffn_out = self.dropout_2(ffn_out, training=training)
176
+ preds = self.out(ffn_out)
177
+ return preds
178
+
179
+
180
+ def get_causal_attention_mask(self, inputs):
181
+ input_shape = tf.shape(inputs)
182
+ batch_size, sequence_length = input_shape[0], input_shape[1]
183
+ i = tf.range(sequence_length)[:, tf.newaxis]
184
+ j = tf.range(sequence_length)
185
+ mask = tf.cast(i >= j, dtype="int32")
186
+ mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
187
+ mult = tf.concat(
188
+ [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
189
+ axis=0
190
+ )
191
+ return tf.tile(mask, mult)
192
+
193
+
194
+ class ImageCaptioningModel(tf.keras.Model):
195
+
196
+ def __init__(self, cnn_model, encoder, decoder, image_aug=None):
197
+ super().__init__()
198
+ self.cnn_model = cnn_model
199
+ self.encoder = encoder
200
+ self.decoder = decoder
201
+ self.image_aug = image_aug
202
+ self.loss_tracker = tf.keras.metrics.Mean(name="loss")
203
+ self.acc_tracker = tf.keras.metrics.Mean(name="accuracy")
204
+
205
+
206
+ def calculate_loss(self, y_true, y_pred, mask):
207
+ loss = self.loss(y_true, y_pred)
208
+ mask = tf.cast(mask, dtype=loss.dtype)
209
+ loss *= mask
210
+ return tf.reduce_sum(loss) / tf.reduce_sum(mask)
211
+
212
+
213
+ def calculate_accuracy(self, y_true, y_pred, mask):
214
+ accuracy = tf.equal(y_true, tf.argmax(y_pred, axis=2))
215
+ accuracy = tf.math.logical_and(mask, accuracy)
216
+ accuracy = tf.cast(accuracy, dtype=tf.float32)
217
+ mask = tf.cast(mask, dtype=tf.float32)
218
+ return tf.reduce_sum(accuracy) / tf.reduce_sum(mask)
219
+
220
+
221
+ def compute_loss_and_acc(self, img_embed, captions, training=True):
222
+ encoder_output = self.encoder(img_embed, training=True)
223
+ y_input = captions[:, :-1]
224
+ y_true = captions[:, 1:]
225
+ mask = (y_true != 0)
226
+ y_pred = self.decoder(
227
+ y_input, encoder_output, training=True, mask=mask
228
+ )
229
+ loss = self.calculate_loss(y_true, y_pred, mask)
230
+ acc = self.calculate_accuracy(y_true, y_pred, mask)
231
+ return loss, acc
232
+
233
+
234
+ def train_step(self, batch):
235
+ imgs, captions = batch
236
+
237
+ if self.image_aug:
238
+ imgs = self.image_aug(imgs)
239
+
240
+ img_embed = self.cnn_model(imgs)
241
+
242
+ with tf.GradientTape() as tape:
243
+ loss, acc = self.compute_loss_and_acc(
244
+ img_embed, captions
245
+ )
246
+
247
+ train_vars = (
248
+ self.encoder.trainable_variables + self.decoder.trainable_variables
249
+ )
250
+ grads = tape.gradient(loss, train_vars)
251
+ self.optimizer.apply_gradients(zip(grads, train_vars))
252
+ self.loss_tracker.update_state(loss)
253
+ self.acc_tracker.update_state(acc)
254
+
255
+ return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}
256
+
257
+
258
+ def test_step(self, batch):
259
+ imgs, captions = batch
260
+
261
+ img_embed = self.cnn_model(imgs)
262
+
263
+ loss, acc = self.compute_loss_and_acc(
264
+ img_embed, captions, training=False
265
+ )
266
+
267
+ self.loss_tracker.update_state(loss)
268
+ self.acc_tracker.update_state(acc)
269
+
270
+ return {"loss": self.loss_tracker.result(), "acc": self.acc_tracker.result()}
271
+
272
+ @property
273
+ def metrics(self):
274
+ return [self.loss_tracker, self.acc_tracker]
275
+
276
+
277
+ def load_image_from_path(img_path):
278
+ img = tf.io.read_file(img_path)
279
+ img = tf.io.decode_jpeg(img, channels=3)
280
+ img = tf.keras.layers.Resizing(299, 299)(img)
281
+ img = tf.keras.applications.inception_v3.preprocess_input(img)
282
+ return img
283
+
284
+
285
+ def generate_caption(img_path, add_noise=False):
286
+ img = load_image_from_path(img_path)
287
+
288
+ if add_noise:
289
+ noise = tf.random.normal(img.shape)*0.1
290
+ img = img + noise
291
+ img = (img - tf.reduce_min(img))/(tf.reduce_max(img) - tf.reduce_min(img))
292
+
293
+ img = tf.expand_dims(img, axis=0)
294
+ img_embed = model.cnn_model(img)
295
+ img_encoded = model.encoder(img_embed, training=False)
296
+
297
+ y_inp = '[start]'
298
+ for i in range(MAX_LENGTH-1):
299
+ tokenized = tokenizer([y_inp])[:, :-1]
300
+ mask = tf.cast(tokenized != 0, tf.int32)
301
+ pred = model.decoder(
302
+ tokenized, img_encoded, training=False, mask=mask)
303
+
304
+ pred_idx = np.argmax(pred[0, i, :])
305
+ pred_idx = tf.convert_to_tensor(pred_idx)
306
+ pred_word = idx2word(pred_idx).numpy().decode('utf-8')
307
+ if pred_word == '[end]':
308
+ break
309
+
310
+ y_inp += ' ' + pred_word
311
+
312
+ y_inp = y_inp.replace('[start] ', '')
313
+ return y_inp
314
+
315
+ def get_caption_model():
316
+ encoder = TransformerEncoderLayer(EMBEDDING_DIM, 1)
317
+ decoder = TransformerDecoderLayer(EMBEDDING_DIM, UNITS, 8)
318
+
319
+ cnn_model = CNN_Encoder()
320
+
321
+ caption_mode = ImageCaptioningModel(
322
+ cnn_model=cnn_model, encoder=encoder, decoder=decoder, image_aug=None,
323
+ )
324
+
325
+ def call_fn(batch, training):
326
+ return batch
327
+
328
+ caption_mode.call = call_fn
329
+ sample_x, sample_y = tf.random.normal((1, 299, 299, 3)), tf.zeros((1, 40))
330
+
331
+ caption_mode((sample_x, sample_y))
332
+
333
+ sample_img_embed = caption_mode.cnn_model(sample_x)
334
+ sample_enc_out = caption_mode.encoder(sample_img_embed, training=False)
335
+ caption_mode.decoder(sample_y, sample_enc_out, training=False)
336
+
337
+ caption_mode.load_weights('model.h5')
338
+
339
+ return caption_mode
340
+
vocab_coco.file ADDED
Binary file (918 kB). View file