Spaces:
Runtime error
Runtime error
File size: 1,012 Bytes
ef1c94f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import torch
from moment_detr.model import build_transformer, build_position_encoding, MomentDETR
def build_inference_model(ckpt_path, **kwargs):
    """Rebuild a MomentDETR model from a saved checkpoint for inference.

    The checkpoint is read as a mapping holding the stored options under
    ``"opt"`` and the model weights under ``"model"``. The options are used to
    build the transformer, the positional encodings and the ``MomentDETR``
    wrapper, after which the saved weights are loaded into the model.

    Args:
        ckpt_path: Anything accepted by ``torch.load`` (a path or file-like
            object) pointing at the checkpoint.
        **kwargs: Option overrides applied to the stored options before the
            model is built (used to overwrite default args).

    Returns:
        The constructed ``MomentDETR`` model with the checkpoint weights
        loaded (tensors are mapped to CPU when the checkpoint is read).

    Raises:
        KeyError: if the checkpoint mapping lacks ``"opt"`` or ``"model"``.
    """
    # The checkpoint carries a pickled options object, which PyTorch >= 2.6
    # refuses to load under its new default ``weights_only=True``, so full
    # unpickling is requested explicitly.
    # NOTE(review): full unpickling can execute arbitrary code -- only load
    # checkpoints from trusted sources.
    try:
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    except TypeError:
        # PyTorch < 1.13 has no ``weights_only`` parameter.
        ckpt = torch.load(ckpt_path, map_location="cpu")
    args = ckpt["opt"]

    if kwargs:  # used to overwrite default args
        # The options are read via attribute access below. They may be a
        # dict-like object with attribute access, or a plain attribute
        # container with no ``update`` method (presumably an
        # ``argparse.Namespace`` as in upstream Moment-DETR -- confirm).
        # Handle both.
        if isinstance(args, dict):
            args.update(kwargs)
        else:
            for key, value in kwargs.items():
                setattr(args, key, value)

    transformer = build_transformer(args)
    position_embedding, txt_position_embedding = build_position_encoding(args)
    model = MomentDETR(
        transformer,
        position_embedding,
        txt_position_embedding,
        txt_dim=args.t_feat_dim,
        vid_dim=args.v_feat_dim,
        num_queries=args.num_queries,
        input_dropout=args.input_dropout,
        aux_loss=args.aux_loss,
        contrastive_align_loss=args.contrastive_align_loss,
        contrastive_hdim=args.contrastive_hdim,
        span_loss_type=args.span_loss_type,
        use_txt_pos=args.use_txt_pos,
        n_input_proj=args.n_input_proj,
    )
    model.load_state_dict(ckpt["model"])
    return model
|