Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| from clip_transform import CLIPTransform | |
| from PIL import Image | |
| from torch.nn import functional as F | |
class Prototypes:
    """Nearest-prototype classifier over CLIP image embeddings.

    Two prototype vectors, labelled "person" and "no_person", are built as the
    unit-normalized mean embedding of the images in the ``prototypes`` folder
    whose filenames start with ``person-`` / ``no_person-``. The result is
    cached in ``prototypes.pt`` and reused on later runs.

    Attributes:
        prototype_keys: Labels, in the same order as the rows of ``prototypes``.
        prototypes: Tensor with one row per label.
    """

    # Kept in one place so the build path and the load path cannot drift apart.
    _PROTOTYPES_FOLDER = 'prototypes'
    _PROTOTYPES_FILE = 'prototypes.pt'
    _PROTOTYPE_KEYS = ("person", "no_person")
    _SUPPORTED_FILETYPES = ('.jpg', '.png', '.jpeg')

    def __init__(self):
        """Create the CLIP transform and load (or build) the prototypes."""
        self._clip_transform = CLIPTransform()
        self._load_prototypes()

    @staticmethod
    def _mean_unit_embedding(image_embeddings, prefix):
        """Return the unit-normalized mean of all embeddings whose key starts with *prefix*.

        Raises:
            ValueError: If no key starts with *prefix*.
        """
        keys = sorted(key for key in image_embeddings if key.startswith(prefix))
        if not keys:
            # Without this check torch.cat fails with an opaque empty-list error.
            raise ValueError(f"no prototype images found with prefix '{prefix}'")
        # Assumes each embedding is (1, D) so that cat yields (N, D) -- TODO confirm
        # against CLIPTransform.pil_image_to_embeddings.
        stacked = torch.cat([image_embeddings[key] for key in keys])
        mean = stacked.mean(dim=0)
        return mean / mean.norm(dim=-1, keepdim=True)

    def _prepare_prototypes(self):
        """Build the prototypes from the image folder and save them to disk.

        Raises:
            ValueError: If the folder yields no embeddings, or no image exists
                for one of the labels.
        """
        image_embeddings = self.load_images_from_folder(self._PROTOTYPES_FOLDER)
        # Explicit raise instead of assert: asserts are stripped under ``python -O``.
        if not image_embeddings:
            raise ValueError("no image embeddings found")
        self.prototype_keys = list(self._PROTOTYPE_KEYS)
        # Files are grouped by the "<label>-" filename prefix.
        self.prototypes = torch.stack([
            self._mean_unit_embedding(image_embeddings, f"{key}-")
            for key in self.prototype_keys
        ])
        # save prototypes to file
        torch.save(self.prototypes, self._PROTOTYPES_FILE)

    def _load_prototypes(self):
        """Load the cached prototypes, building the cache first if it is missing."""
        if not os.path.exists(self._PROTOTYPES_FILE):
            self._prepare_prototypes()
        # NOTE(review): torch.load unpickles the file; this is only safe because
        # the file is one this class wrote itself. Do not point it at untrusted data.
        self.prototypes = torch.load(self._PROTOTYPES_FILE)
        self.prototype_keys = list(self._PROTOTYPE_KEYS)

    def load_images_from_folder(self, folder):
        """Embed every supported image in *folder*.

        Args:
            folder: Directory to scan (not recursive).

        Returns:
            Dict mapping each image filename to the embedding returned by the
            CLIP transform. Files with other extensions are skipped.
        """
        image_embeddings = {}
        # Sorted for a deterministic order; os.listdir order is arbitrary.
        for filename in sorted(os.listdir(folder)):
            # Case-insensitive so that e.g. "photo.JPG" is not silently dropped.
            if not filename.lower().endswith(self._SUPPORTED_FILETYPES):
                continue
            # Context manager closes the underlying file handle after embedding.
            with Image.open(os.path.join(folder, filename)) as image:
                image_embeddings[filename] = self._clip_transform.pil_image_to_embeddings(image)
        return image_embeddings

    def get_distances(self, embeddings):
        """Score one query embedding against every prototype.

        The score is a dot product, which equals cosine similarity only when
        *embeddings* is unit-normalized (the prototypes are). Despite the
        name, a higher value means a closer match.

        Args:
            embeddings: A single embedding, either 1-D ``(D,)`` or 2-D ``(1, D)``.

        Returns:
            Tuple ``(distances, closest_item_key, debug_str)``: the raw result
            of ``embeddings @ prototypes.T``, the label with the highest
            score, and a "key: score, " summary string.

        Raises:
            ValueError: If *embeddings* holds more than one vector.
        """
        distances = embeddings @ self.prototypes.T
        # Flatten so a (1, D) query behaves like a (D,) one; iterating the 2-D
        # result row-wise made ``.item()`` fail on a multi-element row.
        per_prototype = distances.reshape(-1)
        if per_prototype.numel() != len(self.prototype_keys):
            raise ValueError(
                "get_distances expects a single embedding, got result of shape "
                f"{tuple(distances.shape)}"
            )
        closest_item_index = per_prototype.argmax().item()
        closest_item_key = self.prototype_keys[closest_item_index]
        debug_str = "".join(
            f"{key}: {value:.2f}, "
            for key, value in zip(self.prototype_keys, per_prototype.tolist())
        )
        return distances, closest_item_key, debug_str
if __name__ == "__main__":
    # Smoke test: load (or build) the prototypes, list them, then classify the
    # first prototype against the full set.
    store = Prototypes()

    print("prototypes:")
    for label, vector in zip(store.prototype_keys, store.prototypes):
        print(f"{label}: {len(vector)}")

    query = store.prototypes[0]
    _scores, winner, summary = store.get_distances(query)
    print(f"closest_item_key: {winner}")
    print(f"distances: {summary}")
    print("done")