| """Taken and modified from vllm: https://github.com/vllm-project/vllm/blob/93b38bea5dd03e1b140ca997dfaadef86f8f1855/benchmarks/benchmark_serving.py | |
| Filter dataset to: | |
| 1. Remove entries that have too long prompts or completions | |
| 2. Only keep first human prompt for each conversation | |
| """ | |

import json
import random
from typing import List, Tuple

from transformers import AutoTokenizer, PreTrainedTokenizerBase


def filter_dataset_to_size(
    dataset_path: str,
    size: int,
) -> List[dict]:
    """Randomly sample `size` entries from a JSON dataset file."""
    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Randomly sample `size` entries (without replacement).
    return random.sample(dataset, size)
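
# Example usage (mirrors the call in main() below):
#   sampled = filter_dataset_to_size("ShareGPT_V3_filtered.json", 500)
# returns a list of 500 randomly chosen entries from the filtered file.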


def filter_dataset(
    dataset_path: str,
    tokenizer: PreTrainedTokenizerBase,
) -> List[dict]:
    """Filter a ShareGPT-style dataset down to single-turn human prompts."""
    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with fewer than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [
        (
            data["id"],
            data["conversations"][0]["value"],
            data["conversations"][1]["value"],
        )
        for data in dataset
    ]

    # Tokenize the prompts and completions.
    conversation_ids = [conv_id for conv_id, _, _ in dataset]
    prompts = [prompt for _, prompt, _ in dataset]
    prompt_token_ids = tokenizer(prompts).input_ids
    completions = [completion for _, _, completion in dataset]
    completion_token_ids = tokenizer(completions).input_ids
    tokenized_dataset = []
    for i in range(len(dataset)):
        output_len = len(completion_token_ids[i])
        tokenized_dataset.append(
            (conversation_ids[i], prompts[i], prompt_token_ids[i], output_len)
        )

    # Filter out sequences that are too short or too long.
    filtered_dataset_json = []
    for conv_id, prompt, prompt_token_ids, output_len in tokenized_dataset:
        prompt_len = len(prompt_token_ids)
        if prompt_len < 4 or output_len < 4:
            # Prune too-short sequences: TGI errors out when the input or
            # output length is too short.
            continue
        # Cap at 800 rather than the original 1024 to leave headroom for the
        # additional tokens introduced by the chat-completion wrapper.
        if prompt_len > 800 or output_len > 800:
            # Prune too-long sequences.
            continue
        filtered_dataset_json.append(
            {
                "id": conv_id,
                "conversations": [
                    {
                        "from": "human",
                        "value": prompt,
                    }
                ],
            }
        )
    return filtered_dataset_json
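
# The returned list keeps only the human side of each surviving conversation,
# e.g. (illustrative values):
#   [{"id": "abc123",
#     "conversations": [{"from": "human", "value": "What is the capital of France?"}]},
#    ...]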


def main():
    tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
    # Download: https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
    filtered_dataset = filter_dataset(
        "ShareGPT_V3_unfiltered_cleaned_split.json", tokenizer
    )
    with open("ShareGPT_V3_filtered.json", "w") as f:
        json.dump(filtered_dataset, f)

    sampled_dataset = filter_dataset_to_size("ShareGPT_V3_filtered.json", 500)
    with open("ShareGPT_V3_filtered_500.json", "w") as f:
        json.dump(sampled_dataset, f)


if __name__ == "__main__":
    main()
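
# How to run (a sketch; assumes this file is saved as filter_dataset.py and
# that `transformers` is installed -- the LLaMA tokenizer may additionally
# require `sentencepiece`):
#   1. Download ShareGPT_V3_unfiltered_cleaned_split.json from the URL in
#      main() into the working directory.
#   2. python filter_dataset.py
# This writes ShareGPT_V3_filtered.json (all entries passing the length
# filters) and ShareGPT_V3_filtered_500.json (a random 500-entry sample).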