Fix image embedding logic to be mps-compatible
#45 opened by DefOs9
Addresses the assertion error raised on mps machines. Cf. https://huggingface.co/microsoft/Phi-4-multimodal-instruct/discussions/12
- MPS changes: use `.bool()` instead of `.type(torch.BoolTensor)`.
- Avoid `index_put` issues by adding an MPS-specific logic block.
- The `temp_len` variable in the assertion was never used anyway, so I removed the variable and the offending assertion.
- Various cleanups of comments and code.
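
To illustrate the first two points, here is a minimal sketch and not the code from this PR. The function names and the use of `masked_scatter` as the alternative path are assumptions made for illustration, and whether that path avoids the `index_put` problem on a given PyTorch/MPS version would need to be verified. The part that is certain is that `.bool()` changes only the dtype and keeps the tensor on its current device.

```python
import torch


def merge_indexed(hidden_states, mask, image_features):
    # Default path: boolean-mask assignment, which PyTorch implements via index_put_.
    out = hidden_states.clone()
    out[mask] = image_features.to(out.dtype)
    return out


def merge_scatter(hidden_states, mask, image_features):
    # Hypothetical alternative path: masked_scatter copies source elements into
    # the True positions of the (broadcast) mask in row-major order.
    source = image_features.to(hidden_states.dtype)
    return hidden_states.masked_scatter(mask.unsqueeze(-1), source)


def merge_image_features(hidden_states, image_mask, image_features):
    # hidden_states: (batch, seq, dim); image_mask: (batch, seq) with 0/1 entries;
    # image_features: (num_image_tokens, dim).
    mask = image_mask.bool()  # dtype cast only; stays on image_mask's device
    if hidden_states.device.type == "mps":
        return merge_scatter(hidden_states, mask, image_features)
    return merge_indexed(hidden_states, mask, image_features)


# Tiny usage example on whatever device the tensors live on (CPU here).
hidden = torch.zeros(1, 4, 2)
image_mask = torch.tensor([[0, 1, 1, 0]])
feats = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
merged = merge_image_features(hidden, image_mask, feats)
```

Both paths write the rows of `image_features` into the masked sequence positions in order, so they should agree on any device where both are supported.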
I do not recommend removing the `temp_len` variable: it is used to verify that the number of image token positions reserved in advance matches the actual number of image tokens. Also, please do not change the comments to use hard-coded dimension numbers, since users might change the input resolution or the image encoder themselves.
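
A length consistency check of the kind described here could look roughly like the sketch below. This is an assumption-laden illustration, not the model's actual code: `check_image_token_count`, `image_token_id`, and the placeholder id `99` are hypothetical names and values. Sizes are read from the tensors themselves, so nothing depends on a fixed resolution or encoder width.

```python
import torch


def check_image_token_count(input_ids, image_features, image_token_id):
    # Positions reserved for image tokens in the tokenized prompt.
    reserved = int((input_ids == image_token_id).sum().item())
    # Feature vectors actually produced by the image encoder; all leading
    # dimensions are flattened, and the feature width is read from the tensor.
    produced = image_features.reshape(-1, image_features.shape[-1]).shape[0]
    if reserved != produced:
        raise ValueError(
            f"reserved {reserved} image token positions, "
            f"but the encoder produced {produced} image tokens"
        )
    return reserved


# Hypothetical placeholder id 99 marks two image positions in this prompt.
ids = torch.tensor([[5, 99, 99, 7]])
count = check_image_token_count(ids, torch.zeros(2, 8), image_token_id=99)
```

Raising a `ValueError` rather than using `assert` keeps the check active when Python runs with optimizations enabled, which is a design choice of this sketch rather than something stated in the thread.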