k_1_context_model / handler.py
from typing import Any, Dict

import os
import tempfile

import transformers
from peft import AutoPeftModelForCausalLM
from PIL import Image
COORDINATE_PROMPT = 'In this UI screenshot, what is the position of the element corresponding to the command \"{command}\" (with point)?'
PARTITION_PROMPT = 'In this UI screenshot, what is the partition of the element corresponding to the command \"{command}\" (with quadrant number)?'
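
# Quadrant numbering assumed by PARTITION_PROMPT and by the crop logic in
# EndpointHandler.__call__ (mirrors Cartesian quadrants):
#   1 = top-right, 2 = top-left, 3 = bottom-left, 4 = bottom-right.
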
class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the fine-tuned PEFT adapter together with its base model.
        self.model = AutoPeftModelForCausalLM.from_pretrained(
            path,
            device_map="cuda",
            trust_remote_code=True,
            fp16=True,
        ).eval()
        # Keep the tokenizer on the handler; __call__ relies on self.tokenizer.
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            path,
            cache_dir=None,
            model_max_length=2048,
            padding_side="right",
            use_fast=False,
            trust_remote_code=True,
        )
        self.tokenizer.pad_token_id = self.tokenizer.eod_id
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            image (:obj:`PIL.Image`): UI screenshot to ground the command in
            task (:obj:`str`): natural-language command to locate
            k (:obj:`str`): number of zoom-in (partition) steps
            context (:obj:`str`): whether to keep earlier crops as context
            kwargs: any additional fields are ignored
        Return:
            A :obj:`dict` with the predicted ``x``/``y`` point; it will be
            serialized and returned.
        """
        # Work in a temporary directory so intermediate crops can be written to disk.
        with tempfile.TemporaryDirectory() as temp_dir:
            image = os.path.join(temp_dir, "image.jpg")
            data["image"].save(image)
            command = data["task"]
            K = int(data["k"])
            # Note: bool() of any non-empty string is True.
            keep_context = bool(data["context"])
            print(image)
            print(command)
            print(K)
            print(keep_context)
            images = [image]
            partitions = []
            # Defaults in case the model response cannot be parsed below.
            x = y = None
            try:
                # Iteratively zoom in: at each step, ask for the quadrant that
                # contains the target element and crop down to it.
                for k in range(K):
                    query = self.tokenizer.from_list_format(
                        ([{'image': context_image} for context_image in images] if keep_context else [{'image': image}])
                        + [{'text': PARTITION_PROMPT.format(command=command)}]
                    )
                    response, _ = self.model.chat(self.tokenizer, query=query, history=None)
                    # The model is expected to end its answer with the quadrant number.
                    partition = int(response.split(" ")[-1])
                    partitions.append(partition)
                    # Crop the predicted quadrant and use it as the next, zoomed-in image.
                    with Image.open(image) as img:
                        width, height = img.size
                        if partition == 1:      # top-right
                            img = img.crop((width // 2, 0, width, height // 2))
                        elif partition == 2:    # top-left
                            img = img.crop((0, 0, width // 2, height // 2))
                        elif partition == 3:    # bottom-left
                            img = img.crop((0, height // 2, width // 2, height))
                        elif partition == 4:    # bottom-right
                            img = img.crop((width // 2, height // 2, width, height))
                        new_path = os.path.join(temp_dir, f"partition{k}.png")
                        img.save(new_path)
                    image = new_path
                    images.append(image)
                # Ask for the exact point on the final (most zoomed-in) image,
                # optionally with all previous crops as context.
                query = self.tokenizer.from_list_format(
                    ([{'image': context_image} for context_image in images] if keep_context else [{'image': image}])
                    + [{'text': COORDINATE_PROMPT.format(command=command)}]
                )
                response, _ = self.model.chat(self.tokenizer, query=query, history=None)
                print("Coordinate Response:", response)
                # Parse a "(x,y)" point from the response.
                x = float(response.split(",")[0].split("(")[1])
                y = float(response.split(",")[1].split(")")[0])
                # Map the point back from the innermost crop to the original
                # screenshot: each step halves the extent, and the quadrant
                # determines the 0.5 offsets (coordinates assumed normalized to [0, 1]).
                for partition in partitions[::-1]:
                    if partition == 1:      # top-right
                        x = x / 2 + 0.5
                        y = y / 2
                    elif partition == 2:    # top-left
                        x = x / 2
                        y = y / 2
                    elif partition == 3:    # bottom-left
                        x = x / 2
                        y = y / 2 + 0.5
                    elif partition == 4:    # bottom-right
                        x = x / 2 + 0.5
                        y = y / 2 + 0.5
                print("rescaled point:", x, y)
            except Exception as exc:
                print("Invalid response:", exc)

            response = {}
            response['x'] = x
            response['y'] = y
            return response
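

# Minimal local usage sketch (a hedged example, not part of the Inference
# Endpoints contract): the adapter path and screenshot filename below are
# hypothetical placeholders.
if __name__ == "__main__":
    handler = EndpointHandler(path="path/to/peft-adapter")  # hypothetical path
    payload = {
        "image": Image.open("screenshot.png"),  # any UI screenshot (hypothetical file)
        "task": "click the search button",      # natural-language command
        "k": "2",                               # two zoom-in steps
        "context": "true",                      # keep earlier crops as context
    }
    print(handler(payload))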