ctranslate2-4you committed
Commit 6c59ef9 · verified · 1 Parent(s): 979938b

Fix code to comport with newer Transformers library


This fixes the multiple warnings and notices that the current source code produces because it hasn't been updated for newer Transformers versions. I realize you've released a new model that is "HF compatible," but unless and until it's supported in Transformers 4.49, this change should resolve all of the issues. I'd also like to test the new model before switching to it: when comparing it against this original model I noticed some OCR errors, so here's hoping you work out those bugs.
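For reference, the noise comes from two places: the deprecated Cache accessors that the old prepare_inputs_for_generation relied on, and generate() being called without an attention mask. Both are easy to reproduce in isolation; here is a minimal sketch, assuming a recent Transformers build that exposes DynamicCache.get_max_cache_shape() (the tensors and values are illustrative only):

```python
import torch
from transformers import DynamicCache

# (a) Deprecated Cache accessors vs. the replacements this patch switches to.
cache = DynamicCache()
cache.update(torch.zeros(1, 2, 5, 8), torch.zeros(1, 2, 5, 8), layer_idx=0)
# old: cache.seen_tokens / cache.get_max_length()        -> deprecation warnings
# new: cache.get_seq_length() / cache.get_max_cache_shape()
print(cache.get_seq_length())       # 5 tokens already tracked by the cache
print(cache.get_max_cache_shape())  # None for a DynamicCache (no fixed limit)

# (b) generate() warns when input_ids arrive without an attention mask; the patch
# builds an explicit all-ones mask for the single, unpadded prompt.
input_ids = torch.tensor([[101, 102, 103]])
attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)
# model.generate(input_ids, attention_mask=attention_mask, ...)
```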

Files changed (1):
  modeling_GOT.py  +39 -54

modeling_GOT.py  CHANGED
@@ -393,59 +393,46 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
-        # Omit tokens covered by past_key_values
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)
+
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
-                past_length = past_key_values.seen_tokens
-                max_cache_length = past_key_values.get_max_length()
+                current_length = cache_length
+                max_cache_shape = past_key_values.get_max_cache_shape()
+                max_cache_length = max_cache_shape[1] if max_cache_shape else None
             else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
+                cache_length = past_key_values[0][0].shape[2]
+                current_length = cache_length
                 max_cache_length = None
 
-            # Keep only the unprocessed tokens:
-            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-            # input)
-            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-            # input_ids based on the past_length.
-            elif past_length < input_ids.shape[1]:
-                input_ids = input_ids[:, past_length:]
-            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-
-            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
-            if (
-                max_cache_length is not None
-                and attention_mask is not None
-                and cache_length + input_ids.shape[1] > max_cache_length
-            ):
-                attention_mask = attention_mask[:, -max_cache_length:]
+            if attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - cache_length):]
+            elif cache_length < input_ids.shape[1]:
+                input_ids = input_ids[:, cache_length:]
+
+            if max_cache_length is not None and attention_mask is not None:
+                if cache_length + input_ids.shape[1] > max_cache_length:
+                    attention_mask = attention_mask[:, -max_cache_length:]
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
+                position_ids = position_ids[:, -input_ids.shape[1]:]
+
+        model_inputs = {
+            "input_ids": input_ids if inputs_embeds is None or past_key_values is not None else None,
+            "inputs_embeds": inputs_embeds if past_key_values is None else None,
+            "past_key_values": past_key_values,
+            "position_ids": position_ids,
+            "attention_mask": attention_mask,
+            "images": kwargs.get("images", None),
+            "use_cache": kwargs.get("use_cache", True)
+        }
 
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            model_inputs = {"input_ids": input_ids}
-
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask,
-                "images": kwargs.get("images", None),
-            }
-        )
         return model_inputs
 
     def initialize_vision_tokenizer(
@@ -536,7 +523,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
 
         conv_mpt = Conversation(
             system="""<|im_start|>system
-You should follow the instructions carefully and explain your answers in detail.""",
+You should follow the instructions carefully and explain your answers in detail.""",
             # system = None,
             roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
             version="mpt",
@@ -728,7 +715,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
         return processed_images
 
 
-    def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
+    def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag=False):
         # Model
         self.disable_torch_init()
         multi_page=False
@@ -778,21 +765,18 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
             image_tensor_1 = image_processor_high(image)
             image_list.append(image_tensor_1)
 
-
         image_list = torch.stack(image_list)
 
-        print('====new images batch size======: \n',image_list.shape)
-
+        # print('====new images batch size======: \n',image_list.shape)
 
         if use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len*ll + DEFAULT_IM_END_TOKEN + '\n' + qs
         else:
             qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
 
-
         conv_mpt = Conversation(
             system="""<|im_start|>system
-You should follow the instructions carefully and explain your answers in detail.""",
+You should follow the instructions carefully and explain your answers in detail.""",
             # system = None,
             roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
             version="mpt",
@@ -811,8 +795,8 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
            print(prompt)
 
         inputs = tokenizer([prompt])
-
         input_ids = torch.as_tensor(inputs.input_ids).cuda()
+        attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=input_ids.device)
 
         stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
         keywords = [stop_str]
@@ -824,25 +808,26 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
                 output_ids = self.generate(
                     input_ids,
                     images=[image_list.half().cuda()],
+                    attention_mask=attention_mask,
                     do_sample=False,
-                    num_beams = 1,
-                    # no_repeat_ngram_size = 20,
                     streamer=streamer,
+                    num_beams=1,
                     max_new_tokens=4096,
                     stopping_criteria=[stopping_criteria]
-                    )
+                )
+
         else:
             with torch.autocast("cuda", dtype=torch.bfloat16):
                 output_ids = self.generate(
                     input_ids,
                     images=[image_list.half().cuda()],
+                    attention_mask=attention_mask,
                     do_sample=False,
-                    num_beams = 1,
-                    # no_repeat_ngram_size = 20,
                     # streamer=streamer,
+                    num_beams=1,
                     max_new_tokens=4096,
                     stopping_criteria=[stopping_criteria]
-                    )
+                )
 
         outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
 
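For anyone who wants to confirm the warnings are gone after this commit, a usage sketch (untested as written): the repo id and image path are placeholders, and "format" is the ocr_type used in the upstream GOT-OCR2.0 examples.

```python
from transformers import AutoModel, AutoTokenizer

repo = "your-namespace/GOT-OCR2_0"  # placeholder: point at the repo containing this patched modeling_GOT.py
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModel.from_pretrained(
    repo,
    trust_remote_code=True,            # required so the custom modeling_GOT.py is loaded
    low_cpu_mem_usage=True,
    use_safetensors=True,
    device_map="cuda",
    pad_token_id=tokenizer.eos_token_id,
)
model = model.eval()

# chat_crop() is the multi-crop OCR entry point touched by this commit.
result = model.chat_crop(tokenizer, "page.png", ocr_type="format")
print(result)
```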