| """ | |
| This is a modified version which only extract text embedding in HF Space. | |
| See https://github.com/baaivision/Uni3D for source code. | |
| Or refer to https://github.com/yuanze1024/LD-T3D/blob/master/feature_extractors/uni3d_embedding_encoder.py for extracting all embeddings. | |
| """ | |
import os
import sys

import open_clip
import torch
from huggingface_hub import hf_hub_download

sys.path.append('')  # make the repository root importable for the local packages below
from feature_extractors import FeatureExtractor
from utils.tokenizer import SimpleTokenizer


class Uni3dEmbeddingEncoder(FeatureExtractor):
    def __init__(self, cache_dir, **kwargs) -> None:
        bpe_path = "utils/bpe_simple_vocab_16e6.txt.gz"
        clip_path = os.path.join(cache_dir, "Uni3D", "open_clip_pytorch_model.bin")
        # Download the EVA02-E-14-plus CLIP checkpoint on first use.
        if not os.path.exists(clip_path):
            hf_hub_download("timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k",
                            "open_clip_pytorch_model.bin",
                            cache_dir=cache_dir, local_dir=os.path.join(cache_dir, "Uni3D"))
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = SimpleTokenizer(bpe_path)
        self.clip_model, _, self.preprocess = open_clip.create_model_and_transforms(
            model_name="EVA02-E-14-plus", pretrained=clip_path)
        self.clip_model.to(self.device)

    def encode_3D(self, data):
        raise NotImplementedError("For extracting 3D features, see "
                                  "https://github.com/yuanze1024/LD-T3D/blob/master/feature_extractors/uni3d_embedding_encoder.py")

    def encode_text(self, input_text):
        # Tokenize, encode, and L2-normalize the text embeddings.
        texts = self.tokenizer(input_text).to(device=self.device, non_blocking=True)
        if len(texts.shape) < 2:
            texts = texts[None, ...]
        class_embeddings = self.clip_model.encode_text(texts)
        class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
        return class_embeddings.float()

    def encode_image(self, img_tensor_list):
        # Encode a batch of preprocessed image tensors and L2-normalize the features.
        image = img_tensor_list.to(device=self.device, non_blocking=True)
        image_features = self.clip_model.encode_image(image)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        return image_features.float()

    def encode_query(self, query_list):
        return self.encode_text(query_list)

    def get_img_transform(self):
        return self.preprocess