import torch
import clip
from PIL import Image
import glob
import os
from random import choice

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-L/14@336px", device=device)

COCO = glob.glob(os.path.join(os.getcwd(), "images", "*"))
available_models = ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64',
                    'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']


def load_random_image():
    # Pick a random image file from the local ./images directory
    image_path = choice(COCO)
    image = Image.open(image_path)
    return image


def next_image():
    global image_org, image, image_features
    image_org = load_random_image()
    # load_random_image() already returns a PIL image, so it can be passed
    # to preprocess() directly (no Image.fromarray conversion needed)
    image = preprocess(image_org).unsqueeze(0).to(device)
    # Recompute the cached features so answer() scores against the new image
    with torch.no_grad():
        image_features = model.encode_image(image)


# def calculate_logits(image, text):
#     return model(image, text)[0]
def calculate_logits(image_features, text_features):
    # Cosine similarity between L2-normalized embeddings, scaled by the
    # model's learned temperature (logit_scale), matching CLIP's forward pass
    image_features = image_features / image_features.norm(dim=1, keepdim=True)
    text_features = text_features / text_features.norm(dim=1, keepdim=True)
    logit_scale = model.logit_scale.exp()
    return logit_scale * image_features @ text_features.t()


last = -1
best = -1
goal = 23

image_org = load_random_image()
image = preprocess(image_org).unsqueeze(0).to(device)
with torch.no_grad():
    image_features = model.encode_image(image)


def answer(message):
    global last, best
    text = clip.tokenize([message]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text)
    # logits_per_image, _ = model(image, text)
    logits = calculate_logits(image_features, text_features).cpu().numpy().flatten()[0]
    # logits = calculate_logits(image, text)

    # Check the goal first so it is not shadowed by the better/worse comparisons
    if logits > goal:
        is_better = 2    # goal reached
    elif last == -1:
        is_better = -1   # first guess, nothing to compare against
    elif last > logits:
        is_better = 0    # worse than the previous guess
    elif last < logits:
        is_better = 1    # better than the previous guess
    else:
        is_better = -1

    last = logits
    if logits > best:
        best = logits
        is_better = 3    # new best score so far
    return logits, is_better


def reset_everything():
    global last, best, goal, image, image_org, image_features
    last = -1
    best = -1
    goal = 23
    image_org = load_random_image()
    image = preprocess(image_org).unsqueeze(0).to(device)
    # Refresh the cached image features for the newly drawn image
    with torch.no_grad():
        image_features = model.encode_image(image)
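

# Illustrative usage sketch (an assumption, not part of the original script):
# a minimal console loop driving the guessing game, assuming this module is
# run directly and ./images contains at least one image file.
if __name__ == "__main__":
    print("Describe the hidden image ('next' = new image, 'quit' = stop).")
    while True:
        guess = input("> ").strip()
        if guess == "quit":
            break
        if guess == "next":
            reset_everything()  # draw a new image and clear the scores
            continue
        score, is_better = answer(guess)
        if is_better == 2:
            print(f"{score:.2f} - goal reached!")
        elif is_better == 3:
            print(f"{score:.2f} - new best so far")
        elif is_better == 1:
            print(f"{score:.2f} - warmer")
        elif is_better == 0:
            print(f"{score:.2f} - colder")
        else:
            print(f"{score:.2f}")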