# Spaces: Running Running
import argparse
import sys
import textwrap

from promptsource.templates import TemplateCollection, INCLUDED_USERS
from promptsource.utils import get_dataset
# Column at which prompt and target text is wrapped for display.
WRAP_WIDTH = 80
# Number of non-blank examples printed per template and per split.
MAX_EXAMPLES_PER_SPLIT = 10


def parse_template_path(dataset_path):
    """Extract the name and subset components from a templates path.

    The path is split on "/". The third component is returned as the name
    (a dataset name, or a user name for community datasets). The fourth
    component is returned as the subset only when the path has exactly five
    components; otherwise the subset is the empty string.

    Args:
        dataset_path: Slash-separated path to a ``templates.yaml`` file.

    Returns:
        A ``(name, subset_name)`` tuple of strings.

    Raises:
        ValueError: If the path has fewer than three components, so there is
            no third component to read.
    """
    parts = dataset_path.split("/")
    if len(parts) < 3:
        raise ValueError(
            f"cannot find a dataset name in {dataset_path!r}: "
            "expected at least three '/'-separated components"
        )
    subset_name = parts[3] if len(parts) == 5 else ""
    return parts[2], subset_name


def split_prompt_and_target(output):
    """Return ``(prompt, target)`` for an applied template, or ``None`` if blank.

    Args:
        output: Sequence of strings produced by applying a template to an
            example; the first item is the prompt and the second, when
            present, is the target.

    Returns:
        ``None`` when the prompt is blank, or when a target is present but
        blank. Otherwise a ``(prompt, target)`` tuple. A missing target is
        reported as the empty string instead of failing on unpacking, and
        any parts beyond the second are ignored.
    """
    if not output or output[0].strip() == "":
        return None
    if len(output) == 1:
        # Prompt without a target part: show the prompt with an empty target.
        return output[0], ""
    if output[1].strip() == "":
        return None
    return output[0], output[1]


def wrap_for_display(text, width=WRAP_WIDTH):
    """Wrap ``text`` to ``width`` columns for tab-indented printing.

    Existing newlines are kept (``replace_whitespace=False``) and each one is
    followed by a tab so continuation lines stay aligned under the leading
    tab that the caller prints.

    Returns:
        A list of wrapped lines; empty when ``text`` is empty.
    """
    wrapped = textwrap.wrap(text, width=width, replace_whitespace=False)
    return [line.replace("\n", "\n\t") for line in wrapped]


def show_template_definitions(dataset_templates, template_names):
    """Print the name, original-task flag and Jinja source of each template."""
    for template_name in template_names:
        template = dataset_templates[template_name]
        print("TEMPLATE")
        print("NAME:", template_name)
        print("Is Original Task: ", template.metadata.original_task)
        print(template.jinja)
        print()


def show_template_examples(dataset, dataset_templates, template_names):
    """Print each template applied to examples from every split of ``dataset``.

    For every template and split, examples are rendered until
    ``MAX_EXAMPLES_PER_SPLIT`` non-blank results have been printed or the
    split is exhausted. Blank results are reported but do not count toward
    that limit.
    """
    split_names = list(dataset.keys())
    for template_name in template_names:
        template = dataset_templates[template_name]
        print()
        print("TEMPLATE")
        print("NAME:", template_name)
        print("REFERENCE:", template.reference)
        print("--------")
        print()
        print(template.jinja)
        print()

        for split_name in split_names:
            shown = 0
            for example in dataset[split_name]:
                print("\t--------")
                print("\tSplit ", split_name)
                print("\tExample ", example)
                print("\t--------")

                rendered = split_prompt_and_target(template.apply(example))
                if rendered is None:
                    print("\t Blank result")
                    continue
                prompt, target = rendered

                print()
                print("\tPrompt | X")
                for line in wrap_for_display(prompt):
                    print("\t", line)
                print()
                print("\tY")
                for line in wrap_for_display(target):
                    print("\t", line)

                shown += 1
                if shown >= MAX_EXAMPLES_PER_SPLIT:
                    break


def main(argv=None):
    """Print the templates of one dataset, then examples of each applied.

    Args:
        argv: Command-line arguments excluding the program name; ``None``
            makes argparse read ``sys.argv``.

    Returns:
        ``0`` on completion, including when the path is not a
        ``templates.yaml`` path or names a community dataset. A path that is
        too short to contain a dataset name is reported through
        ``parser.error``, which exits the process.
    """
    parser = argparse.ArgumentParser(
        description="Show the prompt templates of a dataset and examples of their output."
    )
    parser.add_argument("dataset_path", type=str, help="path to dataset name")
    args = parser.parse_args(argv)

    # Only template files are of interest; anything else is silently ignored.
    if "templates.yaml" not in args.dataset_path:
        return 0

    try:
        dataset_name, subset_name = parse_template_path(args.dataset_path)
    except ValueError as err:
        parser.error(str(err))

    if dataset_name in INCLUDED_USERS:
        print("Skipping showing templates for community dataset.")
        return 0

    template_collection = TemplateCollection()
    dataset = get_dataset(dataset_name, subset_name)
    dataset_templates = template_collection.get_dataset(dataset_name, subset_name)
    template_names = dataset_templates.all_template_names

    print("DATASET ", args.dataset_path)
    # First show all the templates, then examples of each one applied.
    show_template_definitions(dataset_templates, template_names)
    show_template_examples(dataset, dataset_templates, template_names)
    return 0


if __name__ == "__main__":
    sys.exit(main())