Vipitis committed
Commit 9361469 · 1 Parent(s): 6b811d9

add generation config, fixing warning output

Files changed (2):
  1. ShaderEval.py +24 -6
  2. app.py +6 -5
ShaderEval.py CHANGED
@@ -8,7 +8,7 @@ _CITATION = """\
 @InProceedings{huggingface:module,
 title = {A great new module},
 authors={huggingface, Inc.},
-year={2020}
+year={2023}
 }
 """
 
@@ -28,7 +28,7 @@ from evaluate.evaluation_suite import SubTask
 from datasets import Dataset
 from typing import Any, Callable, Dict, List, Optional, Union # used in .prepare_pipeline()
 import transformers
-from transformers import Pipeline, pipeline
+from transformers import Pipeline, pipeline, GenerationConfig #GenerationConfig to specify greedy and avoid error
 from datasets import load_dataset #used by Suite.run()
 
 # write a custom evaluator, inherent from: https://github.com/huggingface/evaluate/blob/v0.4.0/src/evaluate/evaluator/text_generation.py#L31
@@ -36,7 +36,12 @@ class ReturnGenerationEvaluator(evaluate.TextGenerationEvaluator):
     def __init__(self, task="text-generation", default_metric_name="exact_match", predictions_prefix: str = "generated"):
         super().__init__(task=task, default_metric_name=default_metric_name)
         self.predictions_prefix = predictions_prefix
-    PIPELINE_KWARGS = {"return_full_text":False, "do_sample":False} #these kwargs are for the pipeline call, not the pipeline init.
+
+    greedy_cfg = GenerationConfig(
+        do_sample = False, # default to ensure greedy
+        num_beams = 1, # same as above
+    )
+    PIPELINE_KWARGS = {"return_full_text": False, "generation_config":greedy_cfg} #these kwargs are for the pipeline call, not the pipeline init - but that seems to still work.
 
     # for the pipeline init we need to copy the whole function and add two lines. this still prints errors due to the pad_toke_id = eos_token_id change.
     # from: https://github.com/huggingface/evaluate/blob/v0.4.0/src/evaluate/evaluator/base.py#L375
@@ -98,12 +103,25 @@ class ReturnGenerationEvaluator(evaluate.TextGenerationEvaluator):
         # fixinging default for max_lenght
         pipe.model.config.max_length = self._resolve_context_lenght(pipe=pipe)
 
-        # specify eos tokens to be all of those that include a ; so we can stop early.
-        self.PIPELINE_KWARGS.update({"eos_token_id": [v for k,v in pipe.tokenizer.vocab.items() if ";" in k]}) #didn't see that this was passed all the way already.
-        # solution found here: https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig.eos_token_id but does it actually work?
+        # update the generation config with information from the pipe
+        self._update_generation_config(pipe)
 
         return pipe
 
+    def _update_generation_config(self, pipe):
+        """
+        Update the generation config with information from the pipe. Sets eos_token_id and pad_token_id.
+        Args:
+            pipe (:class:`~transformers.Pipeline`): we need to access the tokenizer.vocab
+        returns:
+            None
+        """
+        semicolon_token_ids = [v for k,v in pipe.tokenizer.vocab.items() if ";" in k] # this requires the tokenizer, which we only have once a pipe is made.
+        # GenerationConfig.update also exists, but it does only replace, not add kwargs.
+        self.greedy_cfg.eos_token_id = semicolon_token_ids # eos_token_id can be a list, so we give them all possible tokens.
+        self.greedy_cfg.pad_token_id = semicolon_token_ids[0] # pad_token_id has to be an int, so we just take the first one.
+        return None # doesn't do anything?
+
     def _resolve_context_lenght(self, model_or_pipeline=None, pipe=None): #TODO should really copy the typing hints here.
         if isinstance(model_or_pipeline, transformers.GPT2Model): # you are comparing a string here -.-
             return model_or_pipeline.config.n_ctx # how GPT2 models might handle is, seen with
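
The pattern introduced above can be tried outside the Suite. The following is a minimal sketch, not the Space's code: the model name and prompt are placeholders, and it calls `model.generate` directly instead of routing the config through the evaluator's `PIPELINE_KWARGS`, but the `GenerationConfig` handling mirrors the diff (greedy decoding, semicolon tokens as stop tokens, one of them reused as `pad_token_id` to avoid the warning).

```python
# Sketch only: "gpt2" and the prompt are placeholders, not ShaderEval's defaults.
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# force greedy decoding, as in the new class attribute
greedy_cfg = GenerationConfig(do_sample=False, num_beams=1)

# every vocabulary token containing ";" becomes a valid stop token;
# pad_token_id must be a single int, so the first one doubles as padding,
# which is what silences the "Setting pad_token_id to eos_token_id" warning
semicolon_token_ids = [v for k, v in tokenizer.vocab.items() if ";" in k]
greedy_cfg.eos_token_id = semicolon_token_ids
greedy_cfg.pad_token_id = semicolon_token_ids[0]

inputs = tokenizer("float lerp(float a, float b, float t) { return", return_tensors="pt")
out = model.generate(**inputs, generation_config=greedy_cfg, max_new_tokens=32)
print(tokenizer.decode(out[0][inputs["input_ids"].shape[1]:]))
```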
app.py CHANGED
@@ -38,17 +38,18 @@ text = """# Welcome to the ShaderEval Suite.
 - The results will be displayed in the **Output** box
 
 ## Todo (feel free to contribute in a Pull Request)
-- leaderboard
-- supporting batches to speed up inference
-- CER metric (via a custom metric perhaps?)
-- removing the pad_token warning
-- adding OpenVINO pipelines for inference, pending on OpenVINO release
+- [ ] leaderboard
+- [ ] supporting batches to speed up inference
+- [ ] CER metric (via a custom metric perhaps?)
+- [x] removing the pad_token warning
+- [ ] adding OpenVINO pipelines for inference, pending on OpenVINO release
 """
 
 
 def run_suite(model_cp, snippet):
     # print(model_cp, snippet)
     results = suite.run(model_cp, snippet)
+    print(results) # so they show up in the logs for me.
     return results[0]
 
 with gr.Blocks() as site:
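
For context, here is a minimal sketch of how `run_suite` might be hooked into the Blocks UI. The suite path, the component choices, and their labels are assumptions for illustration; the surrounding app.py layout is not part of this diff.

```python
# Sketch only: suite path, components, and labels are assumptions, not the Space's actual layout.
import evaluate
import gradio as gr

suite = evaluate.EvaluationSuite.load("Vipitis/ShaderEval")  # assumed suite location

def run_suite(model_cp, snippet):
    # snippet is whatever second argument the custom Suite.run expects (not shown in this diff)
    results = suite.run(model_cp, snippet)
    print(results)  # full results end up in the Space logs
    return results[0]

with gr.Blocks() as site:
    model_cp = gr.Textbox(value="gpt2", label="Model checkpoint")
    snippet = gr.Textbox(label="Snippet")
    output = gr.JSON(label="Output")
    gr.Button("Run").click(fn=run_suite, inputs=[model_cp, snippet], outputs=output)

site.launch()
```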