Fix image embedding logic to be mps-compatible

#45
by DefOs9 - opened

Addresses the assertion error raised on mps machines. Cf. https://huggingface.co/microsoft/Phi-4-multimodal-instruct/discussions/12

  • MPS changes:
    • .bool() instead of .type(torch.BoolTensor)
    • Avoid index_put issues by having an mps-specific logical block.
  • The temp_len variable in the assertion was never used anyway, so I removed the variable and the offending assertion.
  • Various cleanups of comments and code.
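For readers who have not seen the diff, the two MPS changes listed above might look roughly like the sketch below. This is an illustration under assumptions, not the code in this PR: the function name `merge_image_features`, its parameters, and the `use_index_put` switch are hypothetical, and the fallback (a cumulative-sum lookup combined with `torch.where`) is just one way to avoid `index_put`.

```python
import torch


def merge_image_features(input_ids, hidden_states, image_features,
                         image_token_id, use_index_put=None):
    """Write image feature rows into hidden_states at image-token positions.

    Hypothetical shapes: input_ids (B, T), hidden_states (B, T, D),
    image_features (N, D), where N is the number of image-token positions.
    """
    # .bool() converts the dtype and keeps the tensor on its current device.
    # .type(torch.BoolTensor) instead targets the CPU tensor type.
    mask = (input_ids == image_token_id).bool()
    image_features = image_features.to(hidden_states.dtype)

    if image_features.shape[0] == 0:
        return hidden_states  # nothing to merge

    if use_index_put is None:
        # Assumption: only the MPS backend needs the alternative path.
        use_index_put = hidden_states.device.type != "mps"

    if use_index_put:
        positions = torch.nonzero(mask, as_tuple=True)
        return hidden_states.index_put(positions, image_features, accumulate=False)

    # MPS-specific block: for every position, look up which feature row
    # would land there, then select with torch.where instead of index_put.
    flat_mask = mask.reshape(-1)
    src_index = (flat_mask.cumsum(0) - 1).clamp(min=0)
    gathered = image_features[src_index].reshape(hidden_states.shape)
    return torch.where(mask.unsqueeze(-1), gathered, hidden_states)
```

Passing `use_index_put=False` on a CPU tensor exercises the fallback path, which makes it possible to compare both branches without an MPS device.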

I do not recommend removing the `temp_len` variable; it is used to verify that the number of pre-reserved image tokens matches the actual number of image tokens. Also, please do not change the comments to use hard-coded dimension numbers, since users might change the input resolution or the image encoder themselves.
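As a rough illustration of the kind of length check described here, the sketch below compares the number of reserved image-token positions with the number of feature rows actually produced. The helper name `check_image_token_count` and its arguments are hypothetical, and `expected_len` only plays the role attributed to `temp_len`; how the model file actually computes that value is not shown on this page.

```python
import torch


def check_image_token_count(input_ids, image_features, image_token_id):
    """Raise if reserved image-token slots and produced features disagree.

    Hypothetical shapes: input_ids (B, T), image_features (N, D).
    """
    # Number of positions reserved for image tokens in the prompt.
    expected_len = int((input_ids == image_token_id).sum().item())
    # Number of feature rows the image encoder actually produced.
    actual_len = image_features.shape[0]
    if expected_len != actual_len:
        raise ValueError(
            f"image token mismatch: {expected_len} reserved positions, "
            f"{actual_len} image feature rows"
        )
    return expected_len
```

Raising an explicit error rather than using `assert` keeps the check active when Python runs with `-O`, and the counts in the message do not depend on a fixed resolution or encoder.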

Ready to merge
This branch is ready to get merged automatically.