ctranslate2-4you committed on
Commit 6c3d846 · verified · 1 Parent(s): 45ad1fb

use customized code

Files changed (4):
  1. got_vision_b.py +0 -10
  2. modeling_GOT.py +74 -142
  3. render_tools.py +0 -25
  4. tokenization_qwen.py +4 -8
got_vision_b.py CHANGED
@@ -129,7 +129,6 @@ class ImageEncoderViT(nn.Module):
             LayerNorm2d(out_chans),
         )
 
-
         self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
         self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)
 
@@ -145,7 +144,6 @@ class ImageEncoderViT(nn.Module):
         x = self.net_2(x)
         x = self.net_3(x)
 
-
         return x
 
 
@@ -272,7 +270,6 @@ class Attention(nn.Module):
 
         return x
 
-
 def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
     """
     Partition into non-overlapping windows with padding if needed.
@@ -296,7 +293,6 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, T
     windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
     return windows, (Hp, Wp)
 
-
 def window_unpartition(
     windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
 ) -> torch.Tensor:
@@ -321,7 +317,6 @@ def window_unpartition(
     x = x[:, :H, :W, :].contiguous()
     return x
 
-
 def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
     """
     Get relative positional embeddings according to the relative positions of
@@ -354,7 +349,6 @@ def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor
 
     return rel_pos_resized[relative_coords.long()]
 
-
 def add_decomposed_rel_pos(
     attn: torch.Tensor,
     q: torch.Tensor,
@@ -425,8 +419,6 @@ class PatchEmbed(nn.Module):
         x = x.permute(0, 2, 3, 1)
         return x
 
-
-
 def build_GOT_vit_b(checkpoint=None):
     return _build_GOT_vision(
         encoder_embed_dim=768,
@@ -436,7 +428,6 @@ def build_GOT_vit_b(checkpoint=None):
         checkpoint=checkpoint,
     )
 
-
 def _build_GOT_vision(
     encoder_embed_dim,
     encoder_depth,
@@ -462,7 +453,6 @@ def _build_GOT_vision(
         window_size=14,
         out_chans=prompt_embed_dim,
     )
-
 
     return image_encoder
 
 
modeling_GOT.py CHANGED
@@ -12,7 +12,6 @@ from .got_vision_b import build_GOT_vit_b
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 import dataclasses
-###
 
 DEFAULT_IMAGE_TOKEN = "<image>"
 DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
@@ -20,6 +19,15 @@ DEFAULT_IM_START_TOKEN = '<img>'
 DEFAULT_IM_END_TOKEN = '</img>'
 
 from enum import auto, Enum
+
+def has_bfloat16_support():
+    if not torch.cuda.is_available():
+        return False
+    capability = torch.cuda.get_device_capability()
+    return capability >= (8, 0)
+
+SUPPORTED_DTYPE = torch.bfloat16 if has_bfloat16_support() else torch.float16
+
 class SeparatorStyle(Enum):
     """Different separator style."""
     SINGLE = auto()
@@ -79,7 +87,6 @@ class Conversation:
         else:
             raise ValueError(f"Invalid style: {self.sep_style}")
 
-
     def append_message(self, role, message):
         self.messages.append([role, message])
 
@@ -94,7 +101,6 @@ class Conversation:
             sep2=self.sep2)
 
 
-
 class KeywordsStoppingCriteria(StoppingCriteria):
     def __init__(self, keywords, tokenizer, input_ids):
         self.keywords = keywords
@@ -116,7 +122,7 @@ class KeywordsStoppingCriteria(StoppingCriteria):
             if keyword in outputs:
                 return True
         return False
-
+
 
 class GOTImageEvalProcessor:
     def __init__(self, image_size=384, mean=None, std=None):
@@ -140,7 +146,6 @@ class GOTImageEvalProcessor:
         return self.transform(item)
 
 
-
 class GOTConfig(Qwen2Config):
     model_type = "GOT"
 
@@ -155,7 +160,6 @@ class GOTQwenModel(Qwen2Model):
 
         self.mm_projector_vary = nn.Linear(1024, 1024)
 
-
     def initialize_vision_modules(
         self,
         vision_tower,
@@ -167,14 +171,12 @@ class GOTQwenModel(Qwen2Model):
         device="cuda"
     ):
 
-
         image_processor_high = GOTImageEvalProcessor(image_size=1024)
-
+
         self.vision_tower_high = self.vision_tower_high.to(dtype=dtype, device=device)
 
         self.mm_projector_vary = self.mm_projector_vary.to(dtype=dtype, device=device)
 
-
         image_token_len = 256
 
         self.config.vision_tower = vision_tower
@@ -184,13 +186,12 @@ class GOTQwenModel(Qwen2Model):
 
         self.config.vision_select_layer = vision_select_layer
         self.config.freeze_vision_tower = freeze_vision_tower
-
+
         return dict(
             image_processor_high=image_processor_high,
             image_token_len=image_token_len,
         )
-
-
+
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -205,7 +206,6 @@ class GOTQwenModel(Qwen2Model):
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
 
-        # HACK: replace back original embeddings for LLaVA pretraining
         orig_embeds_params = getattr(self, 'orig_embeds_params', None)
         if orig_embeds_params is not None:
             with torch.no_grad():
@@ -214,10 +214,8 @@ class GOTQwenModel(Qwen2Model):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-
         vision_tower_high = getattr(self, 'vision_tower_high', None)
 
-
         if vision_tower_high is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
             use_im_start_end = getattr(self.config, "use_im_start_end", -1)
 
@@ -232,9 +230,9 @@ class GOTQwenModel(Qwen2Model):
             im_start_token = 151857
 
             im_end_token = 151858
-
+
             image_features = []
-
+
             for image in images:
                 P, C, H, W = image.shape
                 if P == 1:
@@ -249,7 +247,7 @@ class GOTQwenModel(Qwen2Model):
                     image_patches_features = []
                     for image_patch in image_patches:
                         image_p = torch.stack([image_patch])
-
+
                         with torch.set_grad_enabled(False):
                             cnn_feature_p = vision_tower_high(image_p)
                             cnn_feature_p = cnn_feature_p.flatten(2).permute(0, 2, 1)
@@ -258,7 +256,6 @@ class GOTQwenModel(Qwen2Model):
                     image_feature = torch.cat(image_patches_features, dim=1)
                     image_features.append(image_feature)
 
-
             dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
             dummy_image_features = dummy_image_features_2
             use_im_start_end = True
@@ -272,7 +269,7 @@ class GOTQwenModel(Qwen2Model):
                 if use_im_start_end:
                     if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum():
                         raise ValueError("The number of image start tokens and image end tokens should be the same.")
-
+
                     image_start_tokens = torch.where(cur_input_ids == im_start_token)[0]
                     for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features):
                         per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device)
@@ -280,7 +277,7 @@ class GOTQwenModel(Qwen2Model):
 
                         if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token:
                             raise ValueError("The image end token should follow the image start token.")
-
+
                         cur_input_embeds = torch.cat(
                             (
                                 cur_input_embeds[:image_start_token_pos+1],
@@ -290,7 +287,6 @@ class GOTQwenModel(Qwen2Model):
                             dim=0
                         )
 
-
                         new_input_embeds.append(cur_input_embeds)
                 else:
                     raise NotImplementedError
@@ -305,10 +301,8 @@ class GOTQwenModel(Qwen2Model):
         )
 
 
-
 class GOTQwenForCausalLM(Qwen2ForCausalLM):
     config_class = GOTConfig
-    # supports_gradient_checkpointing = True
 
     def __init__(self, config):
         super(Qwen2ForCausalLM, self).__init__(config)
@@ -317,7 +311,6 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
-        # Initialize weights and apply final processing
         self.post_init()
 
     def get_model(self):
@@ -336,7 +329,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         output_hidden_states: Optional[bool] = None,
         images: Optional[torch.FloatTensor] = None,
         return_dict: Optional[bool] = None,
-
+
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -362,18 +355,13 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         logits = self.lm_head(hidden_states)
         logits = logits.float()
 
-        # logits
-
         loss = None
         if labels is not None:
-            # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
             loss_fct = CrossEntropyLoss()
             shift_logits = shift_logits.view(-1, self.config.vocab_size)
             shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
             shift_labels = shift_labels.to(shift_logits.device)
             loss = loss_fct(shift_logits, shift_labels)
 
@@ -389,63 +377,49 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
             attentions=outputs.attentions,
         )
 
-
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
-        # Omit tokens covered by past_key_values
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)
+
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
-                past_length = past_key_values.seen_tokens
-                max_cache_length = past_key_values.get_max_length()
+                current_length = cache_length
+                max_cache_shape = past_key_values.get_max_cache_shape()
+                max_cache_length = max_cache_shape[1] if max_cache_shape else None
             else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
+                cache_length = past_key_values[0][0].shape[2]
+                current_length = cache_length
                 max_cache_length = None
 
-            # Keep only the unprocessed tokens:
-            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-            # input)
-            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-            # input_ids based on the past_length.
-            elif past_length < input_ids.shape[1]:
-                input_ids = input_ids[:, past_length:]
-            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-
-            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
-            if (
-                max_cache_length is not None
-                and attention_mask is not None
-                and cache_length + input_ids.shape[1] > max_cache_length
-            ):
-                attention_mask = attention_mask[:, -max_cache_length:]
+            if attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - cache_length):]
+            elif cache_length < input_ids.shape[1]:
+                input_ids = input_ids[:, cache_length:]
+
+            if max_cache_length is not None and attention_mask is not None:
+                if cache_length + input_ids.shape[1] > max_cache_length:
+                    attention_mask = attention_mask[:, -max_cache_length:]
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
+                position_ids = position_ids[:, -input_ids.shape[1]:]
+
+        model_inputs = {
+            "input_ids": input_ids if inputs_embeds is None or past_key_values is not None else None,
+            "inputs_embeds": inputs_embeds if past_key_values is None else None,
+            "past_key_values": past_key_values,
+            "position_ids": position_ids,
+            "attention_mask": attention_mask,
+            "images": kwargs.get("images", None),
+            "use_cache": kwargs.get("use_cache", True)
+        }
 
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids}
-
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask,
-                "images": kwargs.get("images", None),
-            }
-        )
         return model_inputs
 
     def initialize_vision_tokenizer(
@@ -457,7 +431,6 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
     ):
         config = self.get_model().config
 
-
         self.resize_token_embeddings(len(tokenizer))
 
         config.im_patch_token = 151859
@@ -488,7 +461,6 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
 
         self.disable_torch_init()
 
-
         image_processor_high = GOTImageEvalProcessor(image_size=1024)
 
         use_im_start_end = True
@@ -501,7 +473,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         image = self.load_image(image_file)
 
         w, h = image.size
-
+
         if ocr_type == 'format':
             qs = 'OCR with format: '
         else:
@@ -533,10 +505,9 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         else:
             qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
 
-
         conv_mpt = Conversation(
             system="""<|im_start|>system
-        You should follow the instructions carefully and explain your answers in detail.""",
+        You should follow the instructions carefully and explain your answers in detail.""",
             # system = None,
             roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
             version="mpt",
@@ -566,7 +537,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
         if stream_flag:
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            with torch.autocast("cuda", dtype=SUPPORTED_DTYPE):
                 output_ids = self.generate(
                     input_ids,
                     images=[image_tensor_1.unsqueeze(0).half().cuda()],
@@ -578,7 +549,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                     stopping_criteria=[stopping_criteria]
                     )
         else:
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            with torch.autocast("cuda", dtype=SUPPORTED_DTYPE):
                 output_ids = self.generate(
                     input_ids,
                     images=[image_tensor_1.unsqueeze(0).half().cuda()],
@@ -589,9 +560,9 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                     max_new_tokens=4096,
                     stopping_criteria=[stopping_criteria]
                     )
-
+
         outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
-
+
         if outputs.endswith(stop_str):
             outputs = outputs[:-len(stop_str)]
             outputs = outputs.strip()
@@ -599,24 +570,13 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
 
         if render:
             print('==============rendering===============')
-            from .render_tools import svg_to_html, content_mmd_to_html, tik_html, translation_table
+            from .render_tools import content_mmd_to_html, tik_html, translation_table
 
             if '**kern' in outputs:
-                import verovio
-                tk = verovio.toolkit()
-                tk.loadData(outputs)
-                tk.setOptions({"pageWidth": 2100, "footer": 'none',
-                               'barLineWidth': 0.5, 'beamMaxSlope': 15,
-                               'staffLineWidth': 0.2, 'spacingStaff': 6})
-                tk.getPageCount()
-                svg = tk.renderToSVG()
-                svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")
-
-                svg_to_html(svg, save_render_file)
+                print("Musical notation detected but Verovio rendering is disabled")
 
             if ocr_type == 'format' and '**kern' not in outputs:
 
-
                 if '\\begin{tikzpicture}' not in outputs:
                     html_path_2 = save_render_file
                     right_num = outputs.count('\\right')
@@ -625,16 +585,14 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                     if right_num != left_num:
                         outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
 
-
                     outputs = outputs.replace('"', '``').replace('$', '')
 
                     outputs_list = outputs.split('\n')
                     gt= ''
                     for out in outputs_list:
                         gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
-
-                    gt = gt[:-2]
 
+                    gt = gt[:-2]
 
                     lines = content_mmd_to_html
                     lines = lines.split("const text =")
@@ -652,7 +610,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                         out = out[:-1]
                         if out is None:
                             break
-
+
                         if out:
                             if out[-1] != ';':
                                 gt += out[:-1] + ';\n'
@@ -661,7 +619,6 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                             else:
                                 gt += out + '\n'
 
-
                     lines = tik_html
                     lines = lines.split("const text =")
                     new_web = lines[0] + gt + lines[1]
@@ -671,7 +628,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         return response_str
 
     def dynamic_preprocess(self, image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):
-
+
        def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
            best_ratio_diff = float('inf')
            best_ratio = (1, 1)
@@ -685,30 +642,24 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                elif ratio_diff == best_ratio_diff:
                    if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                        best_ratio = ratio
-           # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
            return best_ratio
-
+
        orig_width, orig_height = image.size
        aspect_ratio = orig_width / orig_height
 
-       # calculate the existing image aspect ratio
        target_ratios = set(
            (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
            i * j <= max_num and i * j >= min_num)
-       # print(target_ratios)
+
        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
 
-       # find the closest aspect ratio to the target
        target_aspect_ratio = find_closest_aspect_ratio(
            aspect_ratio, target_ratios, orig_width, orig_height, image_size)
 
-       # print(target_aspect_ratio)
-       # calculate the target width and height
        target_width = image_size * target_aspect_ratio[0]
        target_height = image_size * target_aspect_ratio[1]
        blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
 
-       # resize the image
        resized_img = image.resize((target_width, target_height))
       processed_images = []
       for i in range(blocks):
@@ -718,7 +669,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                ((i % (target_width // image_size)) + 1) * image_size,
                ((i // (target_width // image_size)) + 1) * image_size
            )
-           # split the image
+
            split_img = resized_img.crop(box)
            processed_images.append(split_img)
        assert len(processed_images) == blocks
@@ -727,40 +678,26 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
            processed_images.append(thumbnail_img)
        return processed_images
 
-
-    def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
-        # Model
+    def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag=False):
        self.disable_torch_init()
        multi_page=False
 
-
        image_processor_high = GOTImageEvalProcessor(image_size=1024)
 
        use_im_start_end = True
 
-
        image_token_len = 256
 
        image_list = []
 
-       # if len(image_file_list)>1:
-       # multi_page = True
-
        if multi_page:
            qs = 'OCR with format across multi pages: '
-           # only for png files
-           # import glob
-           # from natsort import natsorted
-           # patches = glob.glob(image_file + '/*png')
            patches = image_file
-           # patches = natsorted(patches)
            sub_images = []
           for sub_image in patches:
               sub_images.append(self.load_image(sub_image))
 
           ll = len(patches)
-           # print(patches)
-           # print("len ll: ", ll)
 
        else:
            if ocr_type == 'format':
@@ -778,21 +715,16 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                image_tensor_1 = image_processor_high(image)
                image_list.append(image_tensor_1)
 
-
        image_list = torch.stack(image_list)
 
-       print('====new images batch size======: \n',image_list.shape)
-
-
        if use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len*ll + DEFAULT_IM_END_TOKEN + '\n' + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
 
-
        conv_mpt = Conversation(
            system="""<|im_start|>system
-        You should follow the instructions carefully and explain your answers in detail.""",
+        You should follow the instructions carefully and explain your answers in detail.""",
            # system = None,
            roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
            version="mpt",
@@ -811,8 +743,8 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
            print(prompt)
 
        inputs = tokenizer([prompt])
-
        input_ids = torch.as_tensor(inputs.input_ids).cuda()
+        attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)
 
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
@@ -820,32 +752,33 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
        if stream_flag:
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            with torch.autocast("cuda", dtype=SUPPORTED_DTYPE):
                output_ids = self.generate(
                    input_ids,
                    images=[image_list.half().cuda()],
+                    attention_mask=attention_mask,
                    do_sample=False,
-                    num_beams = 1,
-                    # no_repeat_ngram_size = 20,
                    streamer=streamer,
+                    num_beams=1,
                    max_new_tokens=4096,
                    stopping_criteria=[stopping_criteria]
-                    )
+                    )
+
        else:
-            with torch.autocast("cuda", dtype=torch.bfloat16):
+            with torch.autocast("cuda", dtype=SUPPORTED_DTYPE):
                output_ids = self.generate(
                    input_ids,
                    images=[image_list.half().cuda()],
+                    attention_mask=attention_mask,
                    do_sample=False,
-                    num_beams = 1,
-                    # no_repeat_ngram_size = 20,
                    # streamer=streamer,
+                    num_beams=1,
                    max_new_tokens=4096,
                    stopping_criteria=[stopping_criteria]
-                    )
+                    )
 
        outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
-
+
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
            outputs = outputs.strip()
@@ -861,14 +794,13 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                if right_num != left_num:
                    outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
 
-
                outputs = outputs.replace('"', '``').replace('$', '')
 
                outputs_list = outputs.split('\n')
                gt= ''
                for out in outputs_list:
                    gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
-
+
                gt = gt[:-2]
 
                lines = content_mmd_to_html
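
The customized modeling_GOT.py stops hard-coding bfloat16 and instead picks the autocast dtype from the GPU. A minimal standalone sketch of that selection logic, assuming only a working PyTorch install (the names mirror has_bfloat16_support and SUPPORTED_DTYPE added in the diff above):

import torch

def has_bfloat16_support() -> bool:
    # bfloat16 kernels need an NVIDIA GPU with compute capability >= 8.0 (Ampere or newer).
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() >= (8, 0)

# Fall back to float16 on older GPUs, as the diff does with SUPPORTED_DTYPE.
SUPPORTED_DTYPE = torch.bfloat16 if has_bfloat16_support() else torch.float16

# Generation is then wrapped in autocast with the detected dtype, e.g.:
#   with torch.autocast("cuda", dtype=SUPPORTED_DTYPE):
#       output_ids = model.generate(...)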
 
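
The rewritten prepare_inputs_for_generation in the diff trims the prompt to the tokens the KV cache has not processed yet before each decoding step. A hedged sketch of just that trimming step, with cache_length taken from Cache.get_seq_length() as in the diff; the helper name trim_to_unprocessed is illustrative only, and the max-cache-length cropping is omitted:

import torch

def trim_to_unprocessed(input_ids: torch.Tensor, attention_mask: torch.Tensor, cache_length: int) -> torch.Tensor:
    # Case 1: the attention mask is longer than input_ids, so part of the prompt was
    # passed as embeddings; keep only the trailing ids not yet covered by the cache.
    if attention_mask.shape[1] > input_ids.shape[1]:
        return input_ids[:, -(attention_mask.shape[1] - cache_length):]
    # Case 2: input_ids still holds the full prompt; drop the prefix already cached.
    if cache_length < input_ids.shape[1]:
        return input_ids[:, cache_length:]
    # Otherwise everything in input_ids is new.
    return input_ids

# Example: a 7-token prompt with 5 tokens already cached leaves 2 tokens to feed.
ids = torch.arange(7).unsqueeze(0)
mask = torch.ones_like(ids)
print(trim_to_unprocessed(ids, mask, cache_length=5).shape)  # torch.Size([1, 2])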
render_tools.py CHANGED
@@ -5,29 +5,6 @@ punctuation_dict = {
 
 }
 translation_table = str.maketrans(punctuation_dict)
-
-def svg_to_html(svg_content, output_filename):
-
-    html_content = f"""
-    <!DOCTYPE html>
-    <html lang="en">
-    <head>
-        <meta charset="UTF-8">
-        <meta name="viewport" content="width=device-width, initial-scale=1.0">
-        <title>SVG Embedded in HTML</title>
-    </head>
-    <body>
-        <svg width="2100" height="15000" xmlns="http://www.w3.org/2000/svg">
-        {svg_content}
-        </svg>
-    </body>
-    </html>
-    """
-
-    with open(output_filename, 'w') as file:
-        file.write(html_content)
-
-
 
 content_mmd_to_html = """<!DOCTYPE html>
 <html lang="en" data-lt-installed="true"><head>
@@ -71,7 +48,6 @@ content_mmd_to_html = """<!DOCTYPE html>
 """
 
 
-
 tik_html = """
 <!DOCTYPE html>
 
@@ -92,5 +68,4 @@ const text =
 </html>"""
 
 
-
 # print(tik_html)
 
tokenization_qwen.py CHANGED
@@ -23,9 +23,6 @@ PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s
 ENDOFTEXT = "<|endoftext|>"
 IMSTART = "<|im_start|>"
 IMEND = "<|im_end|>"
-# as the default behavior is changed to allow special tokens in
-# regular texts, the surface forms of special tokens need to be
-# as different as possible to minimize the impact
 EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
 SPECIAL_TOKENS = (
     ENDOFTEXT,
@@ -81,9 +78,9 @@ class QWenTokenizer(PreTrainedTokenizer):
             image_pad_tag
         )
 
-        self.errors = errors  # how to handle errors in decoding
+        self.errors = errors
 
-        self.mergeable_ranks = _load_tiktoken_bpe(vocab_file)  # type: dict[bytes, int]
+        self.mergeable_ranks = _load_tiktoken_bpe(vocab_file)
         self.special_tokens = {
             token: index
             for index, token in enumerate(
@@ -113,10 +110,10 @@ class QWenTokenizer(PreTrainedTokenizer):
 
         self.decoder = {
             v: k for k, v in self.mergeable_ranks.items()
-        }  # type: dict[int, bytes|str]
+        }
         self.decoder.update({v: k for k, v in self.special_tokens.items()})
 
-        self.tokenizer = enc  # type: tiktoken.Encoding
+        self.tokenizer = enc
 
         self.eod_id = self.tokenizer.eot_token
         self.im_start_id = self.special_tokens[IMSTART]
@@ -196,7 +193,6 @@ class QWenTokenizer(PreTrainedTokenizer):
         tokens = []
         text = unicodedata.normalize("NFC", text)
 
-        # this implementation takes a detour: text -> token id -> token surface forms
         for t in self.tokenizer.encode(
             text, allowed_special=allowed_special, disallowed_special=disallowed_special
         ):
 