Boni98 commited on
Commit
0ab67e5
·
1 Parent(s): 3af8d8e

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +94 -0
  2. clip_chat.py +82 -0
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import clip_chat
3
+
4
+
5
+ def logit2sentence(logit, slider_value):
6
+ sentence = ""
7
+ if logit < slider_value / 2.5:
8
+ sentence = "Nope. Not at all."
9
+ elif slider_value / 2.5 < logit < slider_value / 1.56:
10
+ sentence = "Not really..."
11
+ elif slider_value / 1.56 < logit < slider_value / 1.36:
12
+ sentence = "Close but not there."
13
+ elif slider_value / 1.36 < logit < slider_value / 1.14:
14
+ sentence = "That's quite close."
15
+ elif slider_value / 1.14 < logit < slider_value:
16
+ sentence = "Almost guessed."
17
+ elif logit >= slider_value:
18
+ sentence = "YES!!"
19
+ return sentence
20
+
21
+
22
+ def give_up():
23
+ image = clip_chat.image_org
24
+ return image, None, "You lost... (Press \"Reset\" to play again)"
25
+
26
+
27
+ def update_difficulty(x):
28
+ if not has_started:
29
+ clip_chat.goal = x
30
+ return clip_chat.goal
31
+ return clip_chat.goal
32
+
33
+
34
+ has_started = False
35
+ best_guess = None
36
+
37
+
38
+ def respond(message, chat_history, label_value, image_value):
39
+ global has_started, best_guess
40
+
41
+ if not has_started:
42
+ has_started = True
43
+
44
+ logits, is_better = clip_chat.answer(message)
45
+ bot_message = logit2sentence(logits, clip_chat.goal)
46
+
47
+ if is_better == 3:
48
+ best_guess = {f"Best Guess: \"{message}\"": float(logits) / clip_chat.goal}
49
+ if float(logits) >= clip_chat.goal:
50
+ bot_message = "YES!"
51
+ best_guess = "YOU WIN! (Press \"Reset\" to play again)"
52
+ image_value = clip_chat.image_org
53
+ else:
54
+ if is_better == -1:
55
+ bot_message += ""
56
+ elif is_better == 0:
57
+ bot_message += "You did worse than the last one."
58
+ elif is_better == 1 or is_better == 3:
59
+ bot_message += "You did better than the last one."
60
+
61
+ label_value = best_guess
62
+
63
+ chat_history.append((message, bot_message))
64
+ return "", chat_history, label_value, image_value
65
+
66
+
67
+ def reset_everything():
68
+ global has_started, best_guess
69
+ clip_chat.reset_everything()
70
+ has_started = False
71
+ best_guess = None
72
+ return clip_chat.goal, None, "This is a \"Guess the Image\" game. I'm thinking of a picture and you have to guess using the chat above.", None
73
+
74
+
75
+ with gr.Blocks() as demo:
76
+ with gr.Row():
77
+ with gr.Column():
78
+ chatbot = gr.Chatbot()
79
+ msg = gr.Textbox()
80
+ slider = gr.inputs.Slider(minimum=18, maximum=27, default=21, label="Difficulty (18 - Easy, 27 - Expert)")
81
+ with gr.Column():
82
+ label = gr.Label("This is a \"Guess the Image\" game. I'm thinking of a picture and you have to guess using the chat above.")
83
+ image_output = gr.outputs.Image(type="pil")
84
+ show_image_button = gr.Button("Give Up...")
85
+ reset_button = gr.Button("Reset")
86
+
87
+ msg.submit(respond, [msg, chatbot], [msg, chatbot, label, image_output])
88
+ slider.release(update_difficulty, inputs=[slider], outputs=[slider])
89
+ show_image_button.click(give_up, outputs=[image_output, chatbot, label], queue=False)
90
+ reset_button.click(reset_everything, outputs=[slider, image_output, label, chatbot], queue=False)
91
+
92
+ if __name__ == "__main__":
93
+ demo.title = "CLIP Guess the Image"
94
+ demo.launch(share=False)
clip_chat.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import clip
3
+ from PIL import Image
4
+ import glob
5
+ import os
6
+ from random import choice
7
+
8
+
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+ model, preprocess = clip.load("ViT-L/14@336px", device=device)
11
+
12
+ COCO = glob.glob(os.path.join(os.getcwd(), "images", "*"))
13
+ available_models = ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']
14
+
15
+
16
+ def load_random_image():
17
+ image_path = choice(COCO)
18
+ image = Image.open(image_path)
19
+ return image
20
+
21
+
22
+ def next_image():
23
+ global image_org, image
24
+ image_org = load_random_image()
25
+ image = preprocess(Image.fromarray(image_org)).unsqueeze(0).to(device)
26
+
27
+
28
+ def calculate_logits(image_features, text_features):
29
+ image_features = image_features / image_features.norm(dim=1, keepdim=True)
30
+ text_features = text_features / text_features.norm(dim=1, keepdim=True)
31
+
32
+ logit_scale = model.logit_scale.exp()
33
+ return logit_scale * image_features @ text_features.t()
34
+
35
+
36
+ last = -1
37
+ best = -1
38
+
39
+ goal = 21
40
+
41
+ image_org = load_random_image()
42
+ image = preprocess(image_org).unsqueeze(0).to(device)
43
+ with torch.no_grad():
44
+ image_features = model.encode_image(image)
45
+
46
+
47
+ def answer(message):
48
+ global last, best
49
+
50
+ text = clip.tokenize([message]).to(device)
51
+
52
+ with torch.no_grad():
53
+ text_features = model.encode_text(text)
54
+ logits_per_image, _ = model(image, text)
55
+ logits = calculate_logits(image_features, text_features).cpu().numpy().flatten()[0]
56
+
57
+ if last == -1:
58
+ is_better = -1
59
+ elif last > logits:
60
+ is_better = 0
61
+ elif last < logits:
62
+ is_better = 1
63
+ elif logits > goal:
64
+ is_better = 2
65
+ else:
66
+ is_better = -1
67
+
68
+ last = logits
69
+ if logits > best:
70
+ best = logits
71
+ is_better = 3
72
+
73
+ return logits, is_better
74
+
75
+
76
+ def reset_everything():
77
+ global last, best, goal, image, image_org
78
+ last = -1
79
+ best = -1
80
+ goal = 21
81
+ image_org = load_random_image()
82
+ image = preprocess(image_org).unsqueeze(0).to(device)