# Qwen/Qwen-7B-Chat - test_passkey_retrieval.py
# commit 7159c70, "feat(testcase): add testcase" (tpoisonooo)
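"""Passkey-retrieval testcase: hide a random pass key inside a long span of
filler sentences and check whether the chat model can return it, sweeping the
context length. The accuracy numbers noted in the comments below cover NTK+log
and ReRoPE variants of Qwen-7B; the prompt generator is adapted from the
LongLoRA passkey script referenced further down."""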
import argparse

from numpy import random
from transformers import AutoTokenizer, AutoModelForCausalLM


def parse_config():
    parser = argparse.ArgumentParser(description='arg parser')
    parser.add_argument('--max_tokens', type=int, default=20000, help='maximum token length for evaluation')
    parser.add_argument('--interval', type=int, default=1000, help='interval for evaluation')
    parser.add_argument('--num_tests', type=int, default=30, help='number of repeated tests for each length')
    args = parser.parse_args()
    return args
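
# Example invocation (assumes the model paths hard-coded in main() are available locally):
#   python test_passkey_retrieval.py --interval 1000 --num_tests 30
# Note: --max_tokens is parsed but not consumed below; the evaluated lengths are
# driven by the range(2000, 12000, args.interval) sweeps in main().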


# copied from https://github.com/dvlab-research/LongLoRA/blob/main/passkey_retrivial.py
def generate_prompt_landmark(n_garbage=60000, seed=666):
    """Generates a long filler text and inserts a passkey at a random position."""
    rnd_state = random.get_state()
    random.seed(seed)
    n_garbage_prefix = random.randint(0, n_garbage)
    n_garbage_suffix = n_garbage - n_garbage_prefix
    task_description = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."
    garbage = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
    garbage_inf = " ".join([garbage] * 5000)
    assert len(garbage_inf) >= n_garbage
    garbage_prefix = garbage_inf[:n_garbage_prefix]
    garbage_suffix = garbage_inf[:n_garbage_suffix]
    pass_key = random.randint(1, 50000)
    information_line = f"The pass key is {pass_key}. Remember it. {pass_key} is the pass key."
    final_question = "What is the pass key? The pass key is"
    # character offset of the information line, printed for debugging
    print('idx : {}'.format(len(task_description) + len(garbage_prefix)))
    lines = [
        task_description,
        garbage_prefix,
        information_line,
        garbage_suffix,
        final_question,
    ]
    random.set_state(rnd_state)
    return "\n".join(lines), str(pass_key)


# NTK+log on Qwen-7B, accuracy by average token length:
#   {'5801': 0.95, '7986': 0.9, '8805': 0.85, '9897': 0.8, '11809': 0.95, '12900': 0.78, '13993': 0.06, '14812': 0.0}
# ReRoPE on Qwen-7B
def main(args):
    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained('/models/Qwen-7B-Chat-ReRoPE', trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained('/models/Qwen-7B-Chat-ReRoPE', trust_remote_code=True).eval().cuda('cuda:3')
    # tokenizer = AutoTokenizer.from_pretrained('/models/Qwen-14B-Chat', trust_remote_code=True)
    # model = AutoModelForCausalLM.from_pretrained('/models/Qwen-14B-Chat', trust_remote_code=True).eval().cuda('cuda:3')

    # First sweep: greedy decoding (top_k=1), seeds 0..num_tests-1 for each length.
    all_accuracies = {}
    # This is a rough ratio to control the number of texts and tokens
    # (~3.75 characters of filler per target token, rounded down to a multiple of 1024).
    # for val in [8000, 9000, 10000, 11000, 13000, 14000, 15000, 16000, 17000]:
    for val in range(2000, 12000, args.interval):
        n_garbage = int(3.75 * val // 1024 * 1024)
        passed_tests = 0
        total_tokens = 0
        for j in range(args.num_tests):
            prompt, pass_key = generate_prompt_landmark(n_garbage=n_garbage, seed=j)
            response, _ = model.chat(tokenizer, prompt, history=[], top_k=1)
            print((response, pass_key))
            if pass_key in response:
                passed_tests += 1
            total_tokens += len(tokenizer(prompt).input_ids)
        avg_tokens = total_tokens // args.num_tests
        accuracy = passed_tests / args.num_tests
        print("accuracy on token length %d is %f" % (avg_tokens, accuracy))
        all_accuracies[str(avg_tokens)] = accuracy
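
    # Without an explicit print here, the first sweep's aggregate would be
    # silently discarded when all_accuracies is re-initialised below.
    print("accuracies over tokens (top_k=1)", all_accuracies)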

    # Second sweep: default sampling settings, seeds shifted by the target length
    # so a fresh set of passkey positions is evaluated.
    all_accuracies = {}
    # Same rough length-to-token ratio as above.
    # for val in [8000, 9000, 10000, 11000, 13000, 14000, 15000, 16000, 17000]:
    for val in range(2000, 12000, args.interval):
        n_garbage = int(3.75 * val // 1024 * 1024)
        passed_tests = 0
        total_tokens = 0
        for j in range(args.num_tests):
            prompt, pass_key = generate_prompt_landmark(n_garbage=n_garbage, seed=j + val)
            response, _ = model.chat(tokenizer, prompt, history=[])
            print((response, pass_key))
            if pass_key in response:
                passed_tests += 1
            total_tokens += len(tokenizer(prompt).input_ids)
        avg_tokens = total_tokens // args.num_tests
        accuracy = passed_tests / args.num_tests
        print("accuracy on token length %d is %f" % (avg_tokens, accuracy))
        all_accuracies[str(avg_tokens)] = accuracy
    print("accuracies over tokens", all_accuracies)


if __name__ == "__main__":
    args = parse_config()
    main(args)