AmaadMartin committed on
Commit
6644e70
·
verified ·
1 Parent(s): a5986eb

Upload 2 files

Browse files
Files changed (2) hide show
  1. handler.py +115 -0
  2. requirements.txt +3 -0
handler.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from peft import AutoPeftModelForCausalLM
3
+ import transformers
4
+ import os
5
+ import tempfile
6
+ from PIL import Image, ImageDraw
7
+
8
# Prompt asking the model for the target element's position as a point.
COORDINATE_PROMPT = (
    'In this UI screenshot, what is the position of the element '
    'corresponding to the command "{command}" (with point)?'
)

# Prompt asking the model which quadrant of the screenshot holds the element.
PARTITION_PROMPT = (
    'In this UI screenshot, what is the partition of the element '
    'corresponding to the command "{command}" (with quadrant number)?'
)


class EndpointHandler():
    """Locate a UI element in a screenshot by iterative quadrant zooming.

    For ``k`` rounds the model is asked which quadrant contains the element
    described by the command, and the working image is cropped to that
    quadrant. The model is then asked for a point inside the final crop, and
    that point is mapped back into the coordinate frame of the full image.
    """

    # Quadrant number -> (x, y) offset of the quadrant's top-left corner as a
    # fraction of the parent image: 1 = top-right, 2 = top-left,
    # 3 = bottom-left, 4 = bottom-right.
    _QUADRANT_OFFSETS = {
        1: (0.5, 0.0),
        2: (0.0, 0.0),
        3: (0.0, 0.5),
        4: (0.5, 0.5),
    }

    def __init__(self, path=""):
        """Load the PEFT model and its tokenizer from ``path``.

        Args:
            path: Directory (or hub id) of the checkpoint to load.
        """
        self.model = AutoPeftModelForCausalLM.from_pretrained(
            path,
            device_map="cuda",
            trust_remote_code=True,
            fp16=True,
        ).eval()
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            path,
            cache_dir=None,
            model_max_length=2048,
            padding_side="right",
            use_fast=False,
            trust_remote_code=True,
        )
        # ``eod_id`` comes from the checkpoint's remote tokenizer code
        # (presumably Qwen-VL; confirm against the model repository).
        tokenizer.pad_token_id = tokenizer.eod_id
        # Stored on the instance because ``__call__`` reads ``self.tokenizer``.
        self.tokenizer = tokenizer

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Run the zoom-and-locate pipeline for one request.

        Args:
            data: Request payload with the keys
                ``image`` (a ``PIL.Image``),
                ``task`` (the command describing the target element),
                ``k`` (number of quadrant rounds; anything ``int()`` accepts),
                ``context`` (truthy to show the model every image produced so
                far instead of only the latest crop; the strings ``"false"``,
                ``"0"``, ``"no"``, ``"off"``, ``"none"`` and ``""`` count as
                false).

        Returns:
            ``{"x": x, "y": y}`` with the point rescaled to the full image, or
            ``{"x": None, "y": None}`` when the model could not be queried or
            its answer could not be parsed.
        """
        command = data["task"]
        rounds = int(data["k"])
        keep_context = self._as_bool(data["context"])

        image = data["image"]
        # JPEG cannot store alpha or palette images; normalise before saving.
        if image.mode != "RGB":
            image = image.convert("RGB")

        with tempfile.TemporaryDirectory() as temp_dir:
            # The tokenizer is handed file paths, so every stage is written
            # to disk; the directory is removed when the request finishes.
            image_path = os.path.join(temp_dir, "image.jpg")
            image.save(image_path)

            print(image_path)
            print(command)
            print(rounds)
            print(keep_context)

            try:
                x, y, partitions = self._locate(
                    command, image, image_path, temp_dir, rounds, keep_context
                )
            except Exception as exc:
                # Endpoint boundary: report the failure instead of crashing.
                print("Invalid response:", repr(exc))
                print()
                return {"x": None, "y": None}

        x, y = self._to_full_image(x, y, partitions)
        print("rescaled point:", x, y)
        return {"x": x, "y": y}

    def _locate(self, command, image, image_path, temp_dir, rounds, keep_context):
        """Zoom ``rounds`` times, then return ``(x, y, partitions)`` for the final crop."""
        image_paths = [image_path]
        partitions = []

        for step in range(rounds):
            response = self._query(PARTITION_PROMPT, command, image_paths, keep_context)
            # The quadrant number is expected to be the last token.
            partition = int(response.split()[-1])
            partitions.append(partition)

            box = self._quadrant_box(partition, *image.size)
            if box is not None:
                # Unknown quadrant numbers leave the image uncropped.
                image = image.crop(box)

            new_path = os.path.join(temp_dir, f"partition{step}.png")
            image.save(new_path)
            image_paths.append(new_path)

        response = self._query(COORDINATE_PROMPT, command, image_paths, keep_context)
        print("Coordinate Response:", response)
        x, y = self._parse_point(response)
        return x, y, partitions

    def _query(self, prompt, command, image_paths, keep_context):
        """Ask the model ``prompt`` about the latest image (or all of them)."""
        shown = image_paths if keep_context else image_paths[-1:]
        items = [{'image': shown_path} for shown_path in shown]
        items.append({'text': prompt.format(command=command)})
        query = self.tokenizer.from_list_format(items)
        response, _ = self.model.chat(self.tokenizer, query=query, history=None)
        return response

    @staticmethod
    def _as_bool(value):
        """Interpret a request flag, treating common 'false' strings as False."""
        if isinstance(value, str):
            return value.strip().lower() not in ("", "0", "false", "no", "off", "none")
        return bool(value)

    @staticmethod
    def _parse_point(response):
        """Extract ``(x, y)`` floats from a response containing ``(x,y)``."""
        fields = response.split(",")
        x = float(fields[0].split("(")[1])
        y = float(fields[1].split(")")[0])
        return x, y

    @classmethod
    def _quadrant_box(cls, partition, width, height):
        """Return the crop box of ``partition``, or None if it is not 1-4."""
        offset = cls._QUADRANT_OFFSETS.get(partition)
        if offset is None:
            return None
        half_width, half_height = width // 2, height // 2
        in_right_half, in_bottom_half = bool(offset[0]), bool(offset[1])
        left = half_width if in_right_half else 0
        right = width if in_right_half else half_width
        top = half_height if in_bottom_half else 0
        bottom = height if in_bottom_half else half_height
        return (left, top, right, bottom)

    @classmethod
    def _to_full_image(cls, x, y, partitions):
        """Map a point in the innermost crop back to the full image.

        Assumes ``x`` and ``y`` are fractions of the crop's width and height
        (the half-image offsets only make sense in that case) — TODO confirm
        against the model's output format.
        """
        # Undo the zoom from the innermost crop outwards.
        for partition in reversed(partitions):
            offset = cls._QUADRANT_OFFSETS.get(partition)
            if offset is None:
                continue
            x = x / 2 + offset[0]
            y = y / 2 + offset[1]
        return x, y
115
+
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ peft
2
+ transformers
3
+ Pillow