File size: 4,584 Bytes
6644e70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from typing import Dict, List, Any
from peft import AutoPeftModelForCausalLM
import transformers
import os
import tempfile
from PIL import Image, ImageDraw

# Prompt asking the model for the point of the UI element matching {command}.
COORDINATE_PROMPT = 'In this UI screenshot, what is the position of the element corresponding to the command \"{command}\" (with point)?'

# Prompt asking the model which quadrant contains the element matching {command}.
PARTITION_PROMPT = 'In this UI screenshot, what is the partition of the element corresponding to the command \"{command}\" (with quadrant number)?'

class EndpointHandler():
    """Locate a UI element in a screenshot by iterative quadrant zooming.

    For each of ``k`` rounds the model is asked which quadrant of the current
    image contains the element described by the command; the image is cropped
    to that quadrant and the process repeats. Finally the model is asked for
    the point of the element in the last crop, and that point is mapped back
    into the coordinate frame of the original screenshot.
    """

    # Image modes Pillow's JPEG writer can store; anything else (e.g. RGBA or
    # palette images) must be converted before saving as ``image.jpg``.
    _JPEG_MODES = frozenset({"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"})

    # Quadrant number -> (x offset, y offset) of the quadrant's top-left
    # corner, as a fraction of the parent image size:
    # 1 = top-right, 2 = top-left, 3 = bottom-left, 4 = bottom-right.
    _QUADRANT_OFFSETS = {
        1: (0.5, 0.0),
        2: (0.0, 0.0),
        3: (0.0, 0.5),
        4: (0.5, 0.5),
    }

    def __init__(self, path=""):
        """Load the PEFT model and its tokenizer from ``path``."""
        # NOTE(review): ``fp16`` is forwarded as-is to ``from_pretrained``;
        # presumably it is consumed by the remote model code -- confirm.
        self.model = AutoPeftModelForCausalLM.from_pretrained(
            path,
            device_map="cuda",
            trust_remote_code=True,
            fp16=True,
        ).eval()
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            path,
            cache_dir=None,
            model_max_length=2048,
            padding_side="right",
            use_fast=False,
            trust_remote_code=True,
        )
        tokenizer.pad_token_id = tokenizer.eod_id
        # __call__ reads ``self.tokenizer``, so it must be stored on the instance.
        self.tokenizer = tokenizer

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Predict the point of the element described by ``data["task"]``.

        data args:
            image (:obj:`PIL.Image`): the screenshot to search.
            task (:obj:`str`): the command describing the target element.
            k (:obj:`str` | `int`): number of quadrant-zoom rounds.
            context (:obj:`str` | `bool`): when true, every image produced so
                far (original plus crops) is sent with each query instead of
                only the most recent crop.
        Return:
            A :obj:`dict` with keys ``x`` and ``y``: the predicted point,
            expressed as fractions of the original image's width and height.
            Both are ``None`` when the model fails or its reply cannot be
            parsed.
        """
        command = data["task"]
        num_rounds = int(data["k"])
        keep_context = self._as_bool(data["context"])

        # Defaults returned when inference or parsing fails.
        x = y = None

        with tempfile.TemporaryDirectory() as temp_dir:
            image = os.path.join(temp_dir, "image.jpg")
            source = data["image"]
            if source.mode not in self._JPEG_MODES:
                # JPEG cannot store alpha/palette data; saving would raise.
                source = source.convert("RGB")
            source.save(image)

            print(image)
            print(command)
            print(num_rounds)
            print(keep_context)

            images = [image]
            partitions = []

            try:
                for round_index in range(num_rounds):
                    response = self._ask(
                        images if keep_context else [image],
                        PARTITION_PROMPT,
                        command,
                    )
                    partition = self._parse_partition(response)
                    partitions.append(partition)

                    # Zoom into the chosen quadrant for the next round.
                    image = self._crop_to_quadrant(
                        image,
                        partition,
                        os.path.join(temp_dir, f"partition{round_index}.png"),
                    )
                    images.append(image)

                response = self._ask(
                    images if keep_context else [image],
                    COORDINATE_PROMPT,
                    command,
                )
                print("Coordinate Response:", response)

                local_x, local_y = self._parse_point(response)
                x, y = self._rescale_point(local_x, local_y, partitions)
                print("rescaled point:", x, y)
            except Exception as exc:
                # Best effort: report the failure and return an empty point
                # rather than crashing the endpoint.
                x = y = None
                print("Invalid response")
                print(repr(exc))

        return {"x": x, "y": y}

    def _ask(self, image_paths, prompt, command):
        """Send ``image_paths`` plus the formatted ``prompt`` to the model and return its reply."""
        query = self.tokenizer.from_list_format(
            [{"image": path} for path in image_paths]
            + [{"text": prompt.format(command=command)}]
        )
        response, _ = self.model.chat(self.tokenizer, query=query, history=None)
        return response

    @staticmethod
    def _as_bool(value) -> bool:
        """Interpret a flag that may arrive as a bool or as text such as ``"false"``."""
        if isinstance(value, str):
            return value.strip().lower() not in ("", "0", "false", "no", "off", "none")
        return bool(value)

    @staticmethod
    def _parse_partition(response: str) -> int:
        """Return the quadrant number given as the last whitespace-separated token."""
        return int(response.split()[-1])

    @staticmethod
    def _parse_point(response: str) -> tuple:
        """Return ``(x, y)`` parsed from the first ``(x, y)`` group in ``response``.

        Raises ``ValueError`` when no such group exists or it is not numeric.
        """
        start = response.index("(") + 1
        end = response.index(")", start)
        x_text, y_text = response[start:end].split(",")[:2]
        return float(x_text), float(y_text)

    @classmethod
    def _quadrant_box(cls, partition: int, width: int, height: int):
        """Return the ``(left, top, right, bottom)`` crop box of a quadrant.

        Returns ``None`` for an unknown quadrant number, in which case the
        image is left uncropped.
        """
        offsets = cls._QUADRANT_OFFSETS.get(partition)
        if offsets is None:
            return None
        x_offset, y_offset = offsets
        half_width, half_height = width // 2, height // 2
        left, right = (half_width, width) if x_offset else (0, half_width)
        top, bottom = (half_height, height) if y_offset else (0, half_height)
        return (left, top, right, bottom)

    @classmethod
    def _crop_to_quadrant(cls, source_path: str, partition: int, target_path: str) -> str:
        """Save the chosen quadrant of ``source_path`` to ``target_path`` and return that path."""
        with Image.open(source_path) as img:
            width, height = img.size
            box = cls._quadrant_box(partition, width, height)
            cropped = img if box is None else img.crop(box)
            cropped.save(target_path)
        return target_path

    @classmethod
    def _rescale_point(cls, x: float, y: float, partitions) -> tuple:
        """Map a point in the final crop back to the original image's frame.

        ``partitions`` lists the quadrants in the order they were chosen; the
        zooms are undone from the innermost crop outwards. Unknown quadrant
        numbers are skipped because no crop was applied for them.
        """
        for partition in reversed(partitions):
            offsets = cls._QUADRANT_OFFSETS.get(partition)
            if offsets is None:
                continue
            x_offset, y_offset = offsets
            x = x / 2 + x_offset
            y = y / 2 + y_offset
        return x, y