|
import json
import os
from collections import defaultdict

import datasets
|
|
|
# Module-level logger obtained through the `datasets` logging wrapper.
logger = datasets.logging.get_logger(name=__name__)
|
|
|
|
|
class Dataset(datasets.GeneratorBasedBuilder):
    """Builder for the ``empower-dev-staging/skyvern-v0`` dataset on the Hub.

    Each example holds a list of images, a conversation, and task/step/retry
    metadata derived from the name of the folder containing the example's
    first image.
    """

    def _info(self):
        """Declare the feature schema of every yielded example."""
        return datasets.DatasetInfo(
            features=datasets.Features({
                "images": datasets.Sequence(datasets.Image()),
                "length": datasets.Value(dtype="int32"),
                "conversations": datasets.Sequence(datasets.Features({
                    "from": datasets.Value("string"),
                    "value": datasets.Value("string"),
                })),
                "task_name": datasets.Value("string"),
                "step_name": datasets.Value("string"),
                "has_retry": datasets.Value("bool"),
                "retry_index": datasets.Value("int32"),
                "total_retries": datasets.Value("int32"),
                "task_num_steps": datasets.Value("int32"),
                "task_has_solve_captcha": datasets.Value("bool"),
            })
        )

    def _split_generators(self, dl_manager: datasets.DownloadManager):
        """Download the image archives and JSONL files and define the splits.

        Both splits share one mapping from the relative image paths used in
        the JSONL records (``images/<folder>/<file>``) to local extracted
        file paths.
        """
        # `token = True` makes the download manager use the locally stored Hub
        # token (presumably the repository is private/gated — confirm).
        dl_manager.download_config.token = True
        # Parallel download workers.
        dl_manager.download_config.num_proc = 10

        base_url = "https://huggingface.co/datasets/empower-dev-staging/skyvern-v0/resolve/main/data"
        # Images are sharded into data/images/1.tar.gz ... data/images/10.tar.gz.
        image_archives = dl_manager.download_and_extract(
            [f"{base_url}/images/{i + 1}.tar.gz" for i in range(10)])

        # Key each extracted file as "images/<parent folder>/<file name>" so it
        # matches the paths stored in the JSONL records. os.path is used rather
        # than splitting on "/" because these are local filesystem paths,
        # which use the platform separator (e.g. backslashes on Windows).
        image_file_to_full_path_mapping = {
            f"images/{os.path.basename(os.path.dirname(path))}/{os.path.basename(path)}": path
            for path in dl_manager.iter_files(image_archives)
        }

        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={
                    "filepath": dl_manager.download_and_extract(f"{base_url}/train.jsonl"),
                    "image_file_to_full_path_mapping": image_file_to_full_path_mapping,
                },
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                gen_kwargs={
                    "filepath": dl_manager.download_and_extract(f"{base_url}/test.jsonl"),
                    "image_file_to_full_path_mapping": image_file_to_full_path_mapping,
                },
            ),
        ]

    def _get_step_info(self, item):
        """Parse task/step/retry identifiers from an item's first image path.

        The folder containing the first image is expected to be named
        ``<task>-<step>_<retry>`` where ``<retry>`` is an integer.

        Returns:
            dict with ``task_name`` (``<task>``), ``step_name``
            (``<task>-<step>``) and ``retry_index`` (``int(<retry>)``).

        Raises:
            IndexError: if ``item["images"]`` is empty or the folder name has
                no ``-`` or no ``_`` after it.
            ValueError: if ``<retry>`` is not an integer.
        """
        first_image_path = item["images"][0]
        # Name of the image's parent folder (JSONL record paths use "/").
        folder = "/".join(first_image_path.split("/")[-2:-1])

        task = folder.split("-")[0]
        step = folder.split("-")[1].split("_")

        step_number = step[0]
        retry_index = int(step[1])

        return {
            "task_name": task,
            "step_name": f"{task}-{step_number}",
            "retry_index": retry_index,
        }

    def _generate_examples(self, filepath, image_file_to_full_path_mapping):
        """Yield ``(key, example)`` pairs for one split.

        Makes two passes. The first reads the JSONL file, drops records whose
        second conversation turn holds an empty ``actions`` list, and
        aggregates per-step retry indices and per-task statistics. The second
        emits each kept record annotated with those statistics. Dropped
        records do not count toward any statistic.
        """
        kept = []  # (item, step_info) pairs, so step info is parsed only once
        step_name_to_retry_indices = defaultdict(list)
        task_name_to_num_steps = defaultdict(int)
        task_name_to_having_solve_captcha = defaultdict(bool)

        # Explicit UTF-8: the platform default encoding (e.g. cp1252 on
        # Windows) can fail on or mis-decode non-ASCII text in the records.
        with open(filepath, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines instead of failing in json.loads.
                    continue
                item = json.loads(line)
                # The second turn's value is itself a JSON document.
                actions = json.loads(item["conversations"][1]["value"])["actions"]
                if not actions:
                    continue

                step_info = self._get_step_info(item)
                kept.append((item, step_info))
                step_name = step_info["step_name"]
                task_name = step_info["task_name"]

                # `|=` also registers the task with False on first sight.
                task_name_to_having_solve_captcha[task_name] |= any(
                    action["action_type"].lower() == "solve_captcha" for action in actions
                )

                # A task's step count is the number of distinct step names seen.
                if step_name not in step_name_to_retry_indices:
                    task_name_to_num_steps[task_name] += 1
                step_name_to_retry_indices[step_name].append(step_info["retry_index"])

        # Sort once so a retry's position in its list is its 0-based rank.
        sorted_retry_indices = {
            step_name: sorted(indices)
            for step_name, indices in step_name_to_retry_indices.items()
        }

        for key, (item, step_info) in enumerate(kept):
            task_name = step_info["task_name"]
            retry_indices = sorted_retry_indices[step_info["step_name"]]
            yield key, {
                "images": [
                    image_file_to_full_path_mapping[image] for image in item["images"]
                ],
                "conversations": item["conversations"],
                "length": item["length"],
                "task_name": task_name,
                "step_name": step_info["step_name"],
                "has_retry": len(retry_indices) > 1,
                # Rank among this step's attempts, not the raw retry number.
                "retry_index": retry_indices.index(step_info["retry_index"]),
                "total_retries": len(retry_indices),
                "task_num_steps": task_name_to_num_steps[task_name],
                "task_has_solve_captcha": task_name_to_having_solve_captcha[task_name],
            }
|
|