Boni98 commited on
Commit
c9e70e7
·
1 Parent(s): ef3db12

Upload 104 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: CLIP Guess The Image
3
+ emoji: 🦀
4
+ colorFrom: gray
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 3.28.3
8
+ app_file: app.py
9
+ pinned: false
10
+ license: unknown
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import clip_chat
3
+
4
+
5
def logit2sentence(logit, slider_value):
    """Translate a CLIP similarity logit into a hint for the player.

    Args:
        logit: similarity score of the player's guess against the image.
        slider_value: the win threshold (difficulty) chosen on the slider.

    Returns:
        A hint sentence; "YES!!" when the guess reaches the threshold.

    BUG FIX: the original used strict '<' on both ends of every chained
    comparison, so a logit exactly equal to a band boundary (e.g.
    slider_value / 2.5) matched no branch and returned "". Testing the
    win condition first and then only lower bounds closes the gaps while
    preserving the same bands for all interior values.
    """
    if logit >= slider_value:
        return "YES!!"
    if logit < slider_value / 2.5:
        return "Nope. Not at all."
    if logit < slider_value / 1.56:
        return "Not really..."
    if logit < slider_value / 1.36:
        return "Close but not there."
    if logit < slider_value / 1.14:
        return "That's quite close."
    return "Almost guessed."
20
+
21
+
22
def give_up():
    """Forfeit the round: reveal the hidden picture and post the losing message.

    Returns the original (unpreprocessed) image for the image component,
    None to clear the chatbot, and the label text.
    """
    losing_text = "You lost... (Press \"Reset\" to play again)"
    return clip_chat.image_org, None, losing_text
25
+
26
+
27
def update_difficulty(x):
    """Apply the slider value as the win threshold, but only before play starts.

    Once the first guess has been made (has_started is True) the difficulty is
    frozen for the round; the current goal is echoed back to the slider either
    way so the UI stays in sync.

    The original had `return clip_chat.goal` duplicated in both branches;
    behavior is identical with a single conditional assignment.
    """
    if not has_started:
        clip_chat.goal = x
    return clip_chat.goal
32
+
33
+
34
# Module-level game state shared by the Gradio callbacks below.
has_started = False  # True after the first guess; freezes the difficulty slider
best_guess = None    # payload for the gr.Label: best-guess dict, or the win message
36
+
37
+
38
def respond(message, chat_history, label_value=None, image_value=None):
    """Chat callback: score the guess and refresh chat, label, and image.

    Args:
        message: the player's guess from the textbox.
        chat_history: list of (user, bot) message pairs from the Chatbot.
        label_value: current Label payload (overwritten with best_guess).
        image_value: current Image payload (set to the real picture on a win).

    Returns:
        ("", chat_history, label_value, image_value) — empty string clears
        the textbox.

    BUG FIX: the UI wires this handler with only [msg, chatbot] as inputs,
    so the original four-required-parameter signature raised a TypeError on
    every submit. Defaulting label_value/image_value to None keeps the
    signature backward compatible while surviving the two-input call.
    """
    global has_started, best_guess

    if not has_started:
        has_started = True  # locks the difficulty (see update_difficulty)

    logits, is_better = clip_chat.answer(message)
    bot_message = logit2sentence(logits, clip_chat.goal)

    if is_better == 3:
        # New personal best: show it on the label with a goal-normalized score.
        best_guess = {f"Best Guess: \"{message}\"": float(logits) / clip_chat.goal}
        if float(logits) >= clip_chat.goal:
            bot_message = "YES!"
            best_guess = "YOU WIN! (Press \"Reset\" to play again)"
            image_value = clip_chat.image_org  # reveal the picture
    else:
        if is_better == -1:
            bot_message += ""  # first guess / tie: no comparison hint
        elif is_better == 0:
            bot_message += "You did worse than the last one."
        elif is_better == 1 or is_better == 3:
            bot_message += "You did better than the last one."

    label_value = best_guess

    chat_history.append((message, bot_message))
    return "", chat_history, label_value, image_value
65
+
66
+
67
def reset_everything():
    """Start a new round and restore every UI component to its initial state.

    Resets the clip_chat game state, clears the local flags, and returns
    (slider value, image, label text, chat) for the Reset button's outputs.
    """
    global has_started, best_guess

    clip_chat.reset_everything()
    has_started, best_guess = False, None

    intro = "This is a \"Guess the Image\" game. I'm thinking of a picture and you have to guess using the chat above."
    return clip_chat.goal, None, intro, None
73
+
74
+
75
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    label = gr.Label("This is a \"Guess the Image\" game. I'm thinking of a picture and you have to guess using the chat above.")
    # gr.outputs.Image / gr.inputs.Slider are the removed pre-3.0 API; use the
    # top-level components (inputs.Slider's `default=` becomes `value=`).
    image_output = gr.Image(type="pil")
    show_image_button = gr.Button("Give Up...")
    slider = gr.Slider(minimum=18, maximum=25, value=21, label="Difficulty (18 - Easy, 25 - Expert)")
    reset_button = gr.Button("Reset")

    # BUG FIX: respond() takes four inputs but only [msg, chatbot] were wired,
    # which raised a TypeError on every submit. Wire all four components.
    msg.submit(respond, [msg, chatbot, label, image_output], [msg, chatbot, label, image_output])
    slider.release(update_difficulty, inputs=[slider], outputs=[slider])
    show_image_button.click(give_up, outputs=[image_output, chatbot, label], queue=False)
    reset_button.click(reset_everything, outputs=[slider, image_output, label, chatbot], queue=False)

if __name__ == "__main__":
    demo.title = "CLIP Guess the Image"
    demo.launch()
clip_chat.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import clip
from PIL import Image
import glob
import os
from random import choice


# Run CLIP on GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
# model, preprocess = clip.load("ViT-L/14@336px", device=device)
# NOTE(review): picks the third-from-last entry of clip.available_models();
# this is fragile if the upstream model list changes — confirm which
# checkpoint is actually intended.
model, preprocess = clip.load(clip.available_models()[-3], device=device)
# Pool of candidate pictures for the game, shipped in ./images/
# (filenames suggest COCO samples).
COCO = glob.glob(os.path.join(os.getcwd(), "images", "*"))
13
+
14
+
15
def load_random_image():
    """Open a randomly chosen file from the COCO pool as a PIL image."""
    return Image.open(choice(COCO))
19
+
20
+
21
def next_image():
    """Advance to a fresh random image and refresh the cached CLIP features.

    BUG FIX: the original called preprocess(Image.fromarray(image_org)), but
    image_org is already a PIL Image (not a numpy array), so Image.fromarray
    raised at runtime. Preprocess the PIL image directly, exactly as the
    module-level init and reset_everything() do. The original also never
    re-encoded image_features, leaving answer() scoring guesses against the
    previous image; re-encode here so the new image takes effect.
    """
    global image_org, image, image_features
    image_org = load_random_image()
    image = preprocess(image_org).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image)
25
+
26
+
27
def calculate_logits(image_features, text_features):
    """Return CLIP's scaled cosine-similarity logits (image x text)."""
    # L2-normalize both embeddings so the dot product is a cosine similarity.
    img = image_features / image_features.norm(dim=1, keepdim=True)
    txt = text_features / text_features.norm(dim=1, keepdim=True)
    # CLIP's learned temperature, applied as a multiplicative scale.
    scale = model.logit_scale.exp()
    return scale * img @ txt.t()
33
+
34
+
35
# --- Round state ---
last = -1  # logit of the previous guess; -1 means no guess yet
best = -1  # best logit seen this round

goal = 21  # win threshold; overwritten by the difficulty slider in app.py

# Current round's picture: the original PIL image (shown on win/give-up)
# and its preprocessed tensor for the model.
image_org = load_random_image()
image = preprocess(image_org).unsqueeze(0).to(device)
# Encode once per round; answer() reuses these cached features.
with torch.no_grad():
    image_features = model.encode_image(image)
44
+
45
+
46
def answer(message):
    """Score one guess against the current image.

    Args:
        message: the player's free-text guess.

    Returns:
        (logits, is_better) where logits is the similarity score (float) and
        is_better encodes the comparison with earlier guesses:
          -1 first guess (or tie below goal), 0 worse than last, 1 better
          than last, 2 tie above goal, 3 new best so far.

    PERFORMANCE FIX: the original also ran `model(image, text)` — a full
    extra forward pass whose result was discarded. The logits are already
    computed from the cached image_features, so that call is removed;
    returned values are unchanged.
    """
    global last, best

    text = clip.tokenize([message]).to(device)

    with torch.no_grad():
        text_features = model.encode_text(text)
        logits = calculate_logits(image_features, text_features).cpu().numpy().flatten()[0]

    if last == -1:
        is_better = -1          # nothing to compare against yet
    elif last > logits:
        is_better = 0
    elif last < logits:
        is_better = 1
    elif logits > goal:
        is_better = 2           # NOTE(review): only reachable when last == logits exactly
    else:
        is_better = -1

    last = logits
    if logits > best:
        best = logits
        is_better = 3           # a new personal best overrides the comparison code

    return logits, is_better
73
+
74
+
75
def reset_everything():
    """Reset the round: clear scores, restore default difficulty, draw a new image.

    BUG FIX: the original replaced image/image_org but never re-encoded
    image_features, so after a reset answer() kept scoring guesses against
    the PREVIOUS round's image. Re-encode here, matching module init.
    """
    global last, best, goal, image, image_org, image_features
    last = -1
    best = -1
    goal = 21
    image_org = load_random_image()
    image = preprocess(image_org).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image)
images/000000000034.jpg ADDED
images/000000000089.jpg ADDED
images/000000000247.jpg ADDED
images/000000000283.jpg ADDED
images/000000000315.jpg ADDED
images/000000000321.jpg ADDED
images/000000000338.jpg ADDED
images/000000000400.jpg ADDED
images/000000000450.jpg ADDED
images/000000000532.jpg ADDED
images/000000000575.jpg ADDED
images/000000000599.jpg ADDED
images/000000000641.jpg ADDED
images/000000000659.jpg ADDED
images/000000000671.jpg ADDED
images/000000000761.jpg ADDED
images/000000000853.jpg ADDED
images/000000001063.jpg ADDED
images/000000001064.jpg ADDED
images/000000001072.jpg ADDED
images/000000001122.jpg ADDED
images/000000001155.jpg ADDED
images/000000001180.jpg ADDED
images/000000001319.jpg ADDED
images/000000001360.jpg ADDED
images/000000001393.jpg ADDED
images/000000001407.jpg ADDED
images/000000001558.jpg ADDED
images/000000001596.jpg ADDED
images/000000001647.jpg ADDED
images/000000001668.jpg ADDED
images/000000001737.jpg ADDED
images/000000001764.jpg ADDED
images/000000001781.jpg ADDED
images/000000001811.jpg ADDED
images/000000001841.jpg ADDED
images/000000001915.jpg ADDED
images/000000001924.jpg ADDED
images/000000001999.jpg ADDED
images/000000002007.jpg ADDED
images/000000002056.jpg ADDED
images/000000002066.jpg ADDED
images/000000002270.jpg ADDED
images/000000002281.jpg ADDED
images/000000002283.jpg ADDED
images/000000002295.jpg ADDED
images/000000002296.jpg ADDED