add example script to evaluate a model and generate code

example_script.py  ADDED  (+132 -0)

@@ -0,0 +1,132 @@
"""This is an example script to evaluate a code generation model on APPS; you can also use the APPS solutions as code generations.

>>> python example_script.py --model_ckpt MODEL_NAME --num_tasks 10 --difficulty introductory --n_samples 1
>>> python example_script.py --use_solutions True --num_tasks 10 --difficulty introductory --n_samples 1"""

import json
import pprint

from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, set_seed

from tools.utils import compute_metrics  # evaluation helper expected alongside this script


def generate_prompt(sample):
    """Build the APPS-style prompt for a single dataset sample."""
    starter_code = None if len(sample["starter_code"]) == 0 else sample["starter_code"]
    try:
        input_output = json.loads(sample["input_output"])
        fn_name = None if not input_output.get("fn_name") else input_output["fn_name"]
    except ValueError:
        fn_name = None
    _input = "\nQUESTION:\n"
    _input += sample["question"]
    if starter_code:
        _input += starter_code
    # A sample that defines fn_name expects a callable solution; the others read
    # from standard input, matching the original APPS prompt convention.
    if fn_name:
        _input += "\nUse Call-Based format"
    else:
        _input += "\nUse Standard Input format"
    _input += "\nANSWER:\n"
    return _input
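
# For orientation, the prompt assembled above looks roughly like this
# (illustrative sketch, not actual dataset content):
#
#     QUESTION:
#     <problem statement>
#     <starter code, if any>
#     Use Call-Based format
#     ANSWER: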


def complete_code(pipe, prompt, num_completions=1, max_length=256, **gen_kwargs):
    """Complete prompt with a text generation pipeline and return num_completions."""
    # Prepend the EOS token so generation starts from a fresh context.
    prompt = pipe.tokenizer.eos_token + prompt
    try:
        code_gens = pipe(prompt, num_return_sequences=num_completions, max_length=max_length, **gen_kwargs)
        # Return only the newly generated text, with the prompt stripped off.
        return [code_gen["generated_text"][len(prompt):] for code_gen in code_gens]
    except IndexError:
        print("prompt is longer than the context size of the model, generation skipped")
        return [""]
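
# Example (illustrative): with a GPT-2-style pipeline, complete_code(pipe, prompt,
# num_completions=2) returns a list of two strings holding only the generated
# continuations; on a context-length overflow it degrades to [""] instead of raising.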


def make_generations(dataset, args, model, tokenizer):
    set_seed(args.seed)
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=args.device_int)

    # Generation settings
    gen_kwargs = {
        "do_sample": args.do_sample,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "top_k": args.top_k,
    }

    # Generate completions for the evaluation set
    n_tasks = args.num_tasks if args.num_tasks is not None else len(dataset)
    print(f"n_tasks is {n_tasks}")
    generations = []
    for task in tqdm(range(n_tasks)):
        task_generations = []
        prompt = generate_prompt(dataset[task]).strip()
        task_generations.extend(complete_code(pipe, prompt, num_completions=args.n_samples, max_length=args.max_length, **gen_kwargs))
        # Strip the EOS token from each completion.
        generations.append([gen.replace(args.eos, "") for gen in task_generations])
    return generations
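
# Note: `generations` holds one list per task, each with n_samples candidate
# programs; this nested shape is what compute_metrics consumes in main() below.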


def main(args):
    DATA_PATH = "codeparrot/apps"
    argsdict = vars(args)
    print(pprint.pformat(argsdict))

    # Setup
    print("Loading evaluation dataset...")
    dataset = load_dataset(DATA_PATH, split="test", difficulties=[args.difficulty])
    if args.use_solutions:
        print("Using dataset solutions as code generations")
        model = None
        tokenizer = None
        generations = []
        # Collect up to n_solutions reference solutions for each evaluated task.
        for index in range(args.num_tasks):
            try:
                sol = json.loads(dataset[index]["solutions"])
                generations.append(sol[:args.n_solutions])
            except ValueError:
                print(f"No solutions for task {index} or not enough to have {args.n_solutions} solutions")
                break
    else:
        print("Loading tokenizer and model...")
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
        model = AutoModelForCausalLM.from_pretrained(args.model_ckpt)
        generations = make_generations(dataset, args, model, tokenizer)

    metrics = compute_metrics(generations, level=args.difficulty, k_list=args.k_list, count_errors=args.count_errors, debug=args.debug)
    print(metrics)
    with open(args.output_file, "w") as fp:
        json.dump(metrics, fp)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Testing a language model on the APPS Python code dataset")
    # Model and tokenizer arguments
    parser.add_argument("--model_ckpt", default="loubnabnl/apps-1.5B-model", type=str, help="path to model checkpoint")
    parser.add_argument("--tokenizer", default="gpt2", type=str, help="tokenizer to use")
    parser.add_argument("--eos", default="<|endoftext|>", type=str, help="end-of-sequence token")
    # Generation arguments
    # Note: argparse's type=bool treats any non-empty string (even "False") as True,
    # so the boolean flags below can only be switched on from the command line.
    parser.add_argument("--do_sample", default=True, type=bool, help="sample during generation")
    parser.add_argument("--temperature", default=0.2, type=float, help="temperature for sampling")
    parser.add_argument("--top_p", default=0.95, type=float, help="top-p for sampling")
    parser.add_argument("--top_k", default=0, type=int, help="top-k for sampling")
    parser.add_argument("--max_length", default=1024, type=int, help="max length of generated code")
    # Evaluation arguments
    parser.add_argument("--difficulty", default="all", type=str, help="difficulty level to select in the dataset from: 'all', 'introductory', 'interview' and 'competition'")
    parser.add_argument("--num_tasks", default=6, type=int, help="number of tasks to evaluate")
    parser.add_argument("--use_solutions", default=False, type=bool, help="use dataset solutions instead of generating new code")
    parser.add_argument("--n_samples", default=1, type=int, help="number of samples to generate per task")
    parser.add_argument("--n_solutions", default=1, type=int, help="number of solutions to use per task")
    parser.add_argument("--k_list", default=[1, 2, 3], nargs="+", type=int, help="k values at which to evaluate pass@k")
    parser.add_argument("--count_errors", default=False, type=bool, help="count compilation and runtime errors for single generations")
    # Configuration
    parser.add_argument("--seed", default=0, type=int, help="generation seed")
    parser.add_argument("--device_int", default=-1, type=int, help="device for generation: -1 for CPU, a non-negative index for the corresponding GPU")
    parser.add_argument("--debug", default=False, type=bool, help="debug mode")
    # Save
    parser.add_argument("--output_file", default="apps_metrics.json", type=str, help="output file to save the results")

    args = parser.parse_args()
    main(args)
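
For reference, a minimal sketch of driving the script programmatically instead of through the CLI; it assumes example_script.py is importable and simply mirrors the argparse defaults above:

import argparse
from example_script import main

# Mirrors: python example_script.py --use_solutions True --num_tasks 10 --difficulty introductory --n_samples 1
args = argparse.Namespace(
    model_ckpt="loubnabnl/apps-1.5B-model", tokenizer="gpt2", eos="<|endoftext|>",
    do_sample=True, temperature=0.2, top_p=0.95, top_k=0, max_length=1024,
    difficulty="introductory", num_tasks=10, use_solutions=True, n_samples=1,
    n_solutions=1, k_list=[1, 2, 3], count_errors=False, seed=0, device_int=-1,
    debug=False, output_file="apps_metrics.json",
)
main(args)  # writes the pass@k metrics to apps_metrics.json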