pszemraj committed
Commit f578dba · 1 Parent(s): b07a789

✨ upgrade aggregation model


Signed-off-by: peter szemraj <[email protected]>

Files changed (2)
  1. aggregate.py +30 -79
  2. app.py +7 -5
aggregate.py CHANGED
@@ -1,12 +1,10 @@
 """
-aggregate.py - module for aggregating text from multiple sources/multiple parts of a single source.
-Primary usage is through the BatchAggregator class.
+aggregate.py - module for 'reducing' multiple 'summary chunks' into one
 
-How it works:
-1. We tell the language model to do it.
-2. The language model does it.
-3. Yaay!
+an overly complicated class for legacy compatibility reasons, for usage of the
+2024 map-reduce models see hf.co/pszemraj/bart-large-summary-map-reduce#usage
 """
+
 import logging
 import pprint as pp
 import time
@@ -14,8 +12,6 @@ import time
 import torch
 from transformers import GenerationConfig, pipeline
 
-from utils import compare_model_size
-
 # Setting up logging
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
@@ -27,42 +23,30 @@ class BatchAggregator:
     BatchAggregator is a class for aggregating text from multiple sources.
 
     Usage:
-    >>> from aggregate import BatchAggregator
-    >>> aggregator = BatchAggregator()
-    >>> agg = aggregator.infer_aggregate(["This is a test", "This is another test"])
-    >>> print(agg)
+    from aggregate import BatchAggregator
+    aggregator = BatchAggregator()
+    agg = aggregator.infer_aggregate(["This is a test", "This is another test"])
+    print(agg)
     """
 
     GENERIC_CONFIG = GenerationConfig(
-        num_beams=8,
+        max_new_tokens=512,
+        num_beams=4,
         early_stopping=True,
         do_sample=False,
-        min_new_tokens=32,
-        max_new_tokens=256,
-        repetition_penalty=1.1,
-        length_penalty=1.4,
-        no_repeat_ngram_size=4,
-        encoder_no_repeat_ngram_size=5,
+        truncation=True,
     )
-    CONFIGURED_MODELS = [
-        "pszemraj/bart-large-mnli-dolly_hhrlhf-v1",
-        "pszemraj/bart-base-instruct-dolly_hhrlhf",
-        "pszemraj/flan-t5-large-instruct-dolly_hhrlhf",
-        "pszemraj/flan-t5-base-instruct-dolly_hhrlhf",
-    ] # these have generation configs defined for this task in their model repos
-
-    DEFAULT_INSTRUCTION = "Write a comprehensive yet concise summary that pulls together the main points of the following text:"
 
     def __init__(
         self,
-        model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1",
+        model_name: str = "pszemraj/bart-large-summary-map-reduce",
        force_cpu: bool = False,
         **kwargs,
     ):
         """
         __init__ initializes the BatchAggregator class.
 
-        :param str model_name: model name to use, default: "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
+        :param str model_name: model name to use, default: "pszemraj/bart-large-summary-map-reduce"
         :param bool force_cpu: force the model to run on CPU, default: False
         """
         self.device = None
@@ -87,40 +71,29 @@ class BatchAggregator:
         self.model_name = model_name
         self.aggregator = self._create_pipeline(model_name)
         self._configure_model()
-        # update the generation config with the specific tokenizer
-        tokenizer_params = {
-            "decoder_start_token_id": 0
-            if "t5" in model_name.lower()
-            else self.aggregator.tokenizer.eos_token_id,
-            "eos_token_id": 1
-            if "t5" in model_name.lower()
-            else self.aggregator.tokenizer.eos_token_id,
-            "pad_token_id": 0
-            if "t5" in model_name.lower()
-            else self.aggregator.tokenizer.pad_token_id,
-        }
-        self.update_generation_config(**tokenizer_params)
 
     def _create_pipeline(
-        self, model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
+        self, model_name: str = "pszemraj/bart-large-summary-map-reduce"
     ) -> pipeline:
         """
         _create_pipeline creates a pipeline for the model.
 
-        :param str model_name: model name to use, default: "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
+        :param str model_name: model name to use
         :return pipeline: the pipeline for the model
 
         :raises Exception: if the pipeline cannot be created
         """
-        self.device = 0 if torch.cuda.is_available() and not self.force_cpu else -1
+        device_map = (
+            "auto" if torch.cuda.is_available() and not self.force_cpu else "cpu"
+        )
         try:
             self.logger.info(
-                f"Creating pipeline with model {model_name} on device {self.device}"
+                f"Creating pipeline with model {model_name} on device {device_map}"
             )
             return pipeline(
                 "text2text-generation",
                 model=model_name,
-                device=self.device,
+                device_map=device_map,
                 torch_dtype=torch.float32,
             )
         except Exception as e:
@@ -137,36 +110,16 @@ class BatchAggregator:
         except Exception as e:
             self.logger.warning(f"Could not compile model with Torch 2.0: {e}")
 
-        if self.model_name not in self.CONFIGURED_MODELS:
-            self.logger.info("Setting generation config to general defaults")
-            self._set_default_generation_config()
-        else:
-            try:
-                self.logger.info("Loading generation config from hub")
-                self.aggregator.model.generation_config = (
-                    GenerationConfig.from_pretrained(self.model_name)
-                )
-            except Exception as e:
-                self.logger.warning(
-                    f"Could not load generation config, using defaults: {e}"
-                )
-                self._set_default_generation_config()
-
+        self._set_default_generation_config()
         self.logger.info(self.aggregator.model.generation_config.to_json_string())
 
     def _set_default_generation_config(self):
         """
         Set the default generation configuration for the model.
         """
-        self.aggregator.model.generation_config = self.GENERIC_CONFIG
-
-        if (
-            "large"
-            or "xl" in self.model_name.lower()
-            or compare_model_size(self.model_name, 500)
-        ):
-            upd = {"num_beams": 4}
-            self.update_generation_config(**upd)
+        self.aggregator.model.generation_config.update(
+            **self.GENERIC_CONFIG.to_diff_dict()
+        )
 
     def update_generation_config(self, **kwargs):
         """
@@ -176,7 +129,6 @@ class BatchAggregator:
         **kwargs: The parameters to update in the generation configuration.
         """
         self.logger.info(f"Updating generation config with {pp.pformat(kwargs)}")
-
         self.aggregator.model.generation_config.update(**kwargs)
 
     def get_generation_config(self) -> dict:
@@ -200,33 +152,32 @@ class BatchAggregator:
     def infer_aggregate(
         self,
         text_list: list,
-        instruction: str = DEFAULT_INSTRUCTION,
+        instruction: str = None,  # Kept for backward compatibility but not used
         **kwargs,
     ) -> str:
-        f"""
+        """
         infer_aggregate - infers a consolidated summary from a list of texts.
 
         Args:
             text_list (list): The texts to summarize.
-            instruction (str): The instruction for the summary. Defaults to {self.DEFAULT_INSTRUCTION}.
+            instruction (str): Not used by this model, kept for compatibility.
             **kwargs: Additional parameters to update in the generation configuration.
 
         Returns:
             The generated summary.
         """
-        joined_text = "\n".join(text_list)
-        prompt = f"{instruction}\n\n{joined_text}\n"
+        joined_text = "\n\n".join(text_list)
         if kwargs:
             self.update_generation_config(**kwargs)
         st = time.perf_counter()
         self.logger.info(f"inference on {len(text_list)} texts ...")
         result = self.aggregator(
-            prompt,
+            joined_text,
             generation_config=self.aggregator.model.generation_config,
         )[0]["generated_text"]
         self.logger.info(f"Done. runtime:\t{round(time.perf_counter() - st, 2)}s")
         self.logger.info(
-            f"Input tokens:\t{self.count_tokens(prompt)}. Output tokens:\t{self.count_tokens(result)}"
+            f"Input tokens:\t{self.count_tokens(joined_text)}. Output tokens:\t{self.count_tokens(result)}"
        )
         self.logger.debug(f"Generated text:\n{result}")
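After this change, `infer_aggregate` no longer builds an instruction prompt: the summary chunks are joined with blank lines and passed straight to the text2text pipeline, since the map-reduce model consolidates raw chunks directly. A minimal usage sketch against the updated class (the example chunk texts and the `max_new_tokens` override are illustrative):

```python
from aggregate import BatchAggregator

# defaults to the new model, pszemraj/bart-large-summary-map-reduce
aggregator = BatchAggregator()

# summaries of consecutive chunks from an upstream "map" step (illustrative)
chunk_summaries = [
    "The narrator, restless on land, resolves to go to sea.",
    "In New Bedford he shares an inn room with the harpooneer Queequeg.",
]

# chunks are joined with "\n\n" internally; kwargs update the generation config
consolidated = aggregator.infer_aggregate(chunk_summaries, max_new_tokens=256)
print(consolidated)
```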
app.py CHANGED
@@ -14,6 +14,7 @@ Optional Environment Variables:
     APP_MAX_WORDS (int): the maximum number of words to use for summarization
     APP_OCR_MAX_PAGES (int): the maximum number of pages to use for OCR
 """
+
 import argparse
 import contextlib
 import gc
@@ -77,7 +78,7 @@ TOKEN_BATCH_OPTIONS = [
 ] # token batch sizes users can choose from
 
 SUMMARY_PLACEHOLDER = "<p><em>Output will appear below:</em></p>"
-AGGREGATE_MODEL = "MBZUAI/LaMini-Flan-T5-783M" # model to use for aggregation
+AGGREGATE_MODEL = "pszemraj/bart-large-summary-map-reduce" # map-reduce model
 
 # if duplicating space: uncomment this line to adjust the max words
 # os.environ["APP_MAX_WORDS"] = str(2048) # set the max words to 2048
@@ -488,7 +489,7 @@ if __name__ == "__main__":
     with demo:
         gr.Markdown(
             """# Document Summarization with Long-Document Transformers
-
+
             An example use case for fine-tuned long document transformers. Model(s) are trained on [book summaries](https://hf.co/datasets/kmfoda/booksum). Architectures [in this demo](https://hf.co/spaces/pszemraj/document-summarization) are [LongT5-base](https://hf.co/pszemraj/long-t5-tglobal-base-16384-book-summary) and [Pegasus-X-Large](https://hf.co/pszemraj/pegasus-x-large-book-summary).
 
             **Want more performance?** Run this demo from a free [Google Colab GPU](https://colab.research.google.com/gist/pszemraj/52f67cf7326e780155812a6a1f9bb724/document-summarization-on-gpu.ipynb)
@@ -497,7 +498,7 @@ if __name__ == "__main__":
         with gr.Column():
             gr.Markdown(
                 """## Load Inputs & Select Parameters
-
+
                 Enter/paste text below, or upload a file. Pick a model & adjust params (_optional_), and press **Summarize!**
 
                 See [the guide doc](https://gist.github.com/pszemraj/722a7ba443aa3a671b02d87038375519) for details.
@@ -596,8 +597,9 @@ if __name__ == "__main__":
             )
 
         with gr.Column():
-            gr.Markdown("""### Advanced Settings
-
+            gr.Markdown(
+                """### Advanced Settings
+
                 Refer to [the guide doc](https://gist.github.com/pszemraj/722a7ba443aa3a671b02d87038375519) for what these are, and how they impact _quality_ and _speed_.
                 """
             )
  )