AnsenH's picture
feat: add our model
24615d9
raw
history blame
5.44 kB
import torch
import torch.nn as nn
import os
from infer import sample_clips, batch_transform_val, videofile_to_frames
from i2v import load_weight, i2v_transform
import numpy as np
import coremltools as ct
def image_to_emb_attn(image, mlmodel):
    """Run one image through a Core ML model and split up its outputs.

    The image is resized to 224x224 through its own ``resize`` method and fed
    to ``mlmodel.predict`` under the ``'input'`` key.

    Args:
        image: Object exposing ``resize((width, height))`` (a PIL image in the
            script below).
        mlmodel: Object exposing ``predict(dict)`` that returns a mapping with
            ``'vector'`` and ``'attention'`` entries.

    Returns:
        A ``(vector, attn_cls, attn_img)`` tuple: the ``'vector'`` output as
        returned by the model, element ``[0, 0, 0]`` of the ``'attention'``
        output, and elements ``[0, 0, 1:]`` of it reshaped to 7x7.
    """
    resized = image.resize((224, 224))
    prediction = mlmodel.predict({'input': resized})
    embedding = prediction['vector']
    # Only row 0 of the first attention map is used: entry 0 is returned on
    # its own (presumably the CLS-token weight -- confirm against the model)
    # and the remaining entries are laid out as a 7x7 grid.
    first_row = prediction['attention'][0, 0]
    return embedding, first_row[0], first_row[1:].reshape(7, 7)
class ScoringWrapper(nn.Module):
    """Standalone wrapper around the clip-scoring model.

    Loads the scoring model from a checkpoint and exposes a ``forward`` that
    takes per-frame embeddings and returns sigmoid scores, so the scoring
    stage can be run (and traced) separately from the image encoder.

    Args:
        ckpt_path: Path to a checkpoint whose ``'model'`` entry holds the
            scoring model object. Defaults to the checkpoint shipped under
            ``../weight``.
    """

    def __init__(self, ckpt_path=os.path.join('..', 'weight',
                                              'ckpt_epoch_59_loss_0.3066582295561343.ckpt')):
        super().__init__()
        # NOTE: the checkpoint stores a pickled model object, so torch.load
        # executes pickle code -- only load checkpoints from a trusted source.
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        self.scoring_model = checkpoint['model'].eval()

    def forward(self, img_vectors):
        """Score clips built from consecutive frame embeddings.

        Args:
            img_vectors: Tensor whose element count is a multiple of
                ``512 * frames_per_clip``; it is flattened to one row per clip.

        Returns:
            1-D tensor with one sigmoid score per clip.
        """
        # The op sequence is kept as-is on purpose: this module is traced for
        # Core ML conversion, and the traced graph depends on these calls.
        img_vectors = img_vectors.view(-1, 512 * self.scoring_model.frames_per_clip)
        scores = torch.sigmoid(self.scoring_model(img_vectors)).view(-1)
        return scores
class HighlightModel(nn.Module):
    """Full highlight pipeline: image encoder (i2v) followed by the clip scorer.

    Args:
        ckpt_path: Checkpoint whose ``'model'`` entry holds the scoring model.
        i2v_path: Weights for the image encoder; when falsy, the default file
            under ``../weight`` is used.
        i2v_transform: Callable applied to the stacked frame batch before the
            encoder.
        batch_transform_val: Callable that turns the input frames into a
            sequence of tensors to be stacked.
    """

    def __init__(self, ckpt_path = os.path.join('..', 'weight', 'ckpt_epoch_59_loss_0.3066582295561343.ckpt'),
                 i2v_path = None, i2v_transform = i2v_transform, batch_transform_val = batch_transform_val):
        super().__init__()
        encoder_weights = i2v_path or os.path.join('..', 'weight', 'heads24_attn_epoch30_loss0.22810565.pt')
        self.i2v = load_weight(encoder_weights)
        self.i2v_transform = i2v_transform
        # The checkpoint stores the scoring model object itself under 'model'.
        self.scoring_model = torch.load(ckpt_path, map_location='cpu')['model'].eval()
        self.batch_transform_val = batch_transform_val

    def forward(self, frames):
        """Encode every frame, then score each window of consecutive frames.

        Args:
            frames: Sequence of frames accepted by ``batch_transform_val``.

        Returns:
            ``(img_vectors, attn, scores)``: per-frame embeddings, per-frame
            7x7 attention maps, and one sigmoid score per clip.
        """
        clip_len = self.scoring_model.frames_per_clip
        # A sliding window of clip_len frames yields this many clips.
        n_clips = len(frames) - clip_len + 1

        batch = torch.stack(self.batch_transform_val(frames))  # (num_frames, 3, 224, 224)
        batch = self.i2v_transform(batch)

        # img_vectors: (num_frames, 512); attn: (num_frames, 50, 50)
        img_vectors, attn = self.i2v(batch)
        # Keep row 0 of each frame's attention, drop its first entry, and lay
        # the remaining 49 values out as a 7x7 grid.
        attn = torch.stack([frame_attn[0, 1:].view(7, 7) for frame_attn in attn])

        clips = sample_clips(img_vectors, clip_len).view(n_clips, -1)
        scores = torch.sigmoid(self.scoring_model(clips)).view(-1)
        return img_vectors, attn, scores
if __name__ == '__main__':
    # End-to-end comparison script:
    #   1. score a test video with the pure PyTorch pipeline,
    #   2. recompute the embeddings with the half-precision Core ML encoder,
    #   3. convert the scoring model to Core ML and compare its scores too.

    # Load testing video
    frames = videofile_to_frames('../get_highlight_example/sports_day_smile.MOV')

    # Initialize Pytorch model (i2v + scoring_model)
    highlight_model = HighlightModel()
    _ = highlight_model.eval()

    # Note:
    #     BatchResize is different from PIL.Image resize.
    #     If we purely use BatchResize as preprocessing, the output of pytorch
    #     model will be different from coreml
    #     => resize image by PIL.Image.resize() first
    pytorch_img_vector, pytorch_attn, pytorch_scores = highlight_model(
        [frame.resize((224, 224)) for frame in frames])
    # Only values are compared below, so drop the autograd graph once here.
    pytorch_img_vector = pytorch_img_vector.detach()
    pytorch_attn = pytorch_attn.detach()
    pytorch_scores = pytorch_scores.detach()
    print("Pytorch scores:", pytorch_scores)
    print("===" * 30)

    # Load baby clip half precision mlmodel
    i2v_mlmodel_filename = os.path.join('..', 'weight', 'half_heads24_attn_epoch30_loss0.22810565.pt.mlmodel')
    i2v_mlmodel = ct.models.MLModel(i2v_mlmodel_filename)

    vectors = []
    attns = []
    for frame in frames:
        vector, attn_cls, attn_img = image_to_emb_attn(frame, i2v_mlmodel)
        vectors.append(vector.squeeze())
        attns.append(attn_img)

    # Stack once into single arrays instead of converting a Python list of
    # arrays on every use.
    coreml_vectors = np.stack(vectors)
    coreml_attns = np.stack(attns)
    coreml_vectors_t = torch.as_tensor(coreml_vectors, dtype=torch.float32)
    coreml_attns_t = torch.as_tensor(coreml_attns, dtype=torch.float32)

    print('MAE of image vectors:', (pytorch_img_vector - coreml_vectors_t).abs().mean())
    print('MAE of attn:', (pytorch_attn - coreml_attns_t).abs().mean())
    print("===" * 30)

    # Load Pytorch scoring model
    scoring_model = ScoringWrapper().eval()
    # Use the model's own window length everywhere; the slices previously
    # hard-coded 3, which only matched the loop range when
    # frames_per_clip == 3.
    frames_per_clip = scoring_model.scoring_model.frames_per_clip
    n_clips = len(vectors) - frames_per_clip + 1

    scores = [scoring_model(coreml_vectors_t[i:i + frames_per_clip]).detach().item()
              for i in range(n_clips)]
    print('I2v_mlmodel + pytorch_scoring:', scores)
    print('MAE of scores (using i2v_mlmodel):', (pytorch_scores - torch.Tensor(scores)).abs().mean())
    print("===" * 30)

    # mlmodel conversion
    x = torch.rand(frames_per_clip, 512)
    traced = torch.jit.trace(scoring_model, x)
    model_input = ct.TensorType(name='input', shape=x.shape)
    scoring_mlmodel = ct.convert(source='pytorch', model=traced, inputs=[model_input])
    spec = scoring_mlmodel.get_spec()
    scoring_mlmodel_filename = os.path.join('..', 'weight', 'score_epoch_59_loss_0.3066.mlmodel')
    # The converter auto-generates the output name (it was 'var_14' when this
    # script was written); read it from the spec rather than hard-coding it,
    # since it depends on the traced graph.
    auto_output_name = spec.description.output[0].name
    ct.models.utils.rename_feature(spec, auto_output_name, 'score', rename_outputs=True)
    ct.models.utils.save_spec(spec, scoring_mlmodel_filename)
    scoring_mlmodel = ct.models.MLModel(scoring_mlmodel_filename)

    # scoring mlmodel prediction
    mlmodel_scores = [scoring_mlmodel.predict({'input': coreml_vectors[i:i + frames_per_clip]})['score'][0]
                      for i in range(n_clips)]
    print("I2v_mlmodel + scoring_mlmodel:", mlmodel_scores)
    print('MAE of scores (using i2v_mlmodel and scoring_mlmodel):',
          np.abs(pytorch_scores.numpy() - np.array(mlmodel_scores)).mean())

    # Exact-match checks are left disabled, as in the original script
    # (presumably because the half-precision encoder does not reproduce the
    # PyTorch outputs within allclose tolerance -- confirm before enabling):
    # assert torch.allclose(pytorch_img_vector.detach(), torch.Tensor(vectors))
    # assert torch.allclose(pytorch_attn.detach(), torch.Tensor(attns))
    # assert np.allclose(pytorch_scores.detach().numpy(), mlmodel_scores)