Make DeepSeek-OCR compatible with MPS and the latest transformers

#96
by hebangwen - opened
Files changed (2) hide show
  1. modeling_deepseekocr.py +22 -14
  2. modeling_deepseekv2.py +6 -7
modeling_deepseekocr.py CHANGED
@@ -24,6 +24,13 @@ import numpy as np
24
  import time
25
 
26
 
 
 
 
 
 
 
 
27
  def load_image(image_path):
28
 
29
  try:
@@ -502,7 +509,7 @@ class DeepseekOCRModel(DeepseekV2Model):
502
  images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
503
  # exit()
504
 
505
- inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch)
506
 
507
  idx += 1
508
 
@@ -622,8 +629,8 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
622
  if past_key_values is not None:
623
  if isinstance(past_key_values, Cache):
624
  cache_length = past_key_values.get_seq_length()
625
- past_length = past_key_values.seen_tokens
626
- max_cache_length = past_key_values.get_max_length()
627
  else:
628
  cache_length = past_length = past_key_values[0][0].shape[2]
629
  max_cache_length = None
@@ -645,6 +652,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
645
  max_cache_length is not None
646
  and attention_mask is not None
647
  and cache_length + input_ids.shape[1] > max_cache_length
 
648
  ):
649
  attention_mask = attention_mask[:, -max_cache_length:]
650
 
@@ -911,12 +919,12 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
911
 
912
  if not eval_mode:
913
  streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
914
- with torch.autocast("cuda", dtype=torch.bfloat16):
915
  with torch.no_grad():
916
  output_ids = self.generate(
917
- input_ids.unsqueeze(0).cuda(),
918
- images=[(images_crop.cuda(), images_ori.cuda())],
919
- images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
920
  images_spatial_crop = images_spatial_crop,
921
  # do_sample=False,
922
  # num_beams = 1,
@@ -929,12 +937,12 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
929
  )
930
 
931
  else:
932
- with torch.autocast("cuda", dtype=torch.bfloat16):
933
  with torch.no_grad():
934
  output_ids = self.generate(
935
- input_ids.unsqueeze(0).cuda(),
936
- images=[(images_crop.cuda(), images_ori.cuda())],
937
- images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
938
  images_spatial_crop = images_spatial_crop,
939
  # do_sample=False,
940
  # num_beams = 1,
@@ -947,7 +955,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
947
 
948
 
949
  if '<image>' in conversation[0]['content'] and eval_mode:
950
- outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
951
  stop_str = '<|end▁of▁sentence|>'
952
  if outputs.endswith(stop_str):
953
  outputs = outputs[:-len(stop_str)]
@@ -957,7 +965,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
957
  return outputs
958
 
959
  if '<image>' in conversation[0]['content'] and test_compress:
960
- outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
961
  pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False))
962
  print('='*50)
963
  print('image size: ', (w, h))
@@ -968,7 +976,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
968
 
969
 
970
  if '<image>' in conversation[0]['content'] and save_results:
971
- outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
972
  stop_str = '<|end▁of▁sentence|>'
973
 
974
  print('='*15 + 'save results:' + '='*15)
 
24
  import time
25
 
26
 
27
+ DEVICE = "cpu"
28
+ if torch.mps.is_available():
29
+ DEVICE = "mps"
30
+ elif torch.cuda.is_available():
31
+ DEVICE = "cuda"
32
+
33
+
34
  def load_image(image_path):
35
 
36
  try:
 
509
  images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
510
  # exit()
511
 
512
+ inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).to(DEVICE), images_in_this_batch)
513
 
514
  idx += 1
515
 
 
629
  if past_key_values is not None:
630
  if isinstance(past_key_values, Cache):
631
  cache_length = past_key_values.get_seq_length()
632
+ past_length = cache_length
633
+ max_cache_length = past_key_values.get_max_cache_shape()
634
  else:
635
  cache_length = past_length = past_key_values[0][0].shape[2]
636
  max_cache_length = None
 
652
  max_cache_length is not None
653
  and attention_mask is not None
654
  and cache_length + input_ids.shape[1] > max_cache_length
655
+ and max_cache_length > -1
656
  ):
657
  attention_mask = attention_mask[:, -max_cache_length:]
658
 
 
919
 
920
  if not eval_mode:
921
  streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
922
+ with torch.autocast(DEVICE, dtype=torch.bfloat16):
923
  with torch.no_grad():
924
  output_ids = self.generate(
925
+ input_ids.unsqueeze(0).to(DEVICE),
926
+ images=[(images_crop.to(DEVICE), images_ori.to(DEVICE))],
927
+ images_seq_mask = images_seq_mask.unsqueeze(0).to(DEVICE),
928
  images_spatial_crop = images_spatial_crop,
929
  # do_sample=False,
930
  # num_beams = 1,
 
937
  )
938
 
939
  else:
940
+ with torch.autocast(DEVICE, dtype=torch.bfloat16):
941
  with torch.no_grad():
942
  output_ids = self.generate(
943
+ input_ids.unsqueeze(0).to(DEVICE),
944
+ images=[(images_crop.to(DEVICE), images_ori.to(DEVICE))],
945
+ images_seq_mask = images_seq_mask.unsqueeze(0).to(DEVICE),
946
  images_spatial_crop = images_spatial_crop,
947
  # do_sample=False,
948
  # num_beams = 1,
 
955
 
956
 
957
  if '<image>' in conversation[0]['content'] and eval_mode:
958
+ outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).to(DEVICE).shape[1]:])
959
  stop_str = '<|end▁of▁sentence|>'
960
  if outputs.endswith(stop_str):
961
  outputs = outputs[:-len(stop_str)]
 
965
  return outputs
966
 
967
  if '<image>' in conversation[0]['content'] and test_compress:
968
+ outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).to(DEVICE).shape[1]:])
969
  pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False))
970
  print('='*50)
971
  print('image size: ', (w, h))
 
976
 
977
 
978
  if '<image>' in conversation[0]['content'] and save_results:
979
+ outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).to(DEVICE).shape[1]:])
980
  stop_str = '<|end▁of▁sentence|>'
981
 
982
  print('='*15 + 'save results:' + '='*15)
modeling_deepseekv2.py CHANGED
@@ -36,7 +36,6 @@ from transformers.cache_utils import Cache, DynamicCache
36
  from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
37
  from transformers.models.llama.modeling_llama import (
38
  LlamaAttention,
39
- LlamaFlashAttention2
40
  )
41
  from transformers.modeling_outputs import (
42
  BaseModelOutputWithPast,
@@ -889,7 +888,7 @@ class DeepseekV2Attention(nn.Module):
889
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
890
  "with a layer index."
891
  )
892
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
893
 
894
  cos, sin = self.rotary_emb(q_pe, seq_len=kv_seq_len)
895
  q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
@@ -1018,7 +1017,7 @@ class DeepseekV2FlashAttention2(DeepseekV2Attention):
1018
 
1019
  kv_seq_len = value_states.shape[-2]
1020
  if past_key_value is not None:
1021
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
1022
 
1023
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
1024
  q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
@@ -1235,7 +1234,7 @@ ATTENTION_CLASSES = {
1235
  "mla_flash_attention_2": DeepseekV2FlashAttention2,
1236
 
1237
  "mha_eager": LlamaAttention,
1238
- "mha_flash_attention_2": LlamaFlashAttention2
1239
  }
1240
 
1241
 
@@ -1539,7 +1538,7 @@ class DeepseekV2Model(DeepseekV2PreTrainedModel):
1539
  use_legacy_cache = not isinstance(past_key_values, Cache)
1540
  if use_legacy_cache:
1541
  past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1542
- past_key_values_length = past_key_values.get_usable_length(seq_length)
1543
 
1544
  if position_ids is None:
1545
  device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -1779,8 +1778,8 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
1779
  if past_key_values is not None:
1780
  if isinstance(past_key_values, Cache):
1781
  cache_length = past_key_values.get_seq_length()
1782
- past_length = past_key_values.seen_tokens
1783
- max_cache_length = past_key_values.get_max_length()
1784
  else:
1785
  cache_length = past_length = past_key_values[0][0].shape[2]
1786
  max_cache_length = None
 
36
  from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
37
  from transformers.models.llama.modeling_llama import (
38
  LlamaAttention,
 
39
  )
40
  from transformers.modeling_outputs import (
41
  BaseModelOutputWithPast,
 
888
  "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
889
  "with a layer index."
890
  )
891
+ kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
892
 
893
  cos, sin = self.rotary_emb(q_pe, seq_len=kv_seq_len)
894
  q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
 
1017
 
1018
  kv_seq_len = value_states.shape[-2]
1019
  if past_key_value is not None:
1020
+ kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
1021
 
1022
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
1023
  q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
 
1234
  "mla_flash_attention_2": DeepseekV2FlashAttention2,
1235
 
1236
  "mha_eager": LlamaAttention,
1237
+ "mha_flash_attention_2": None
1238
  }
1239
 
1240
 
 
1538
  use_legacy_cache = not isinstance(past_key_values, Cache)
1539
  if use_legacy_cache:
1540
  past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1541
+ past_key_values_length = past_key_values.get_seq_length()
1542
 
1543
  if position_ids is None:
1544
  device = input_ids.device if input_ids is not None else inputs_embeds.device
 
1778
  if past_key_values is not None:
1779
  if isinstance(past_key_values, Cache):
1780
  cache_length = past_key_values.get_seq_length()
1781
+ past_length = cache_length
1782
+ max_cache_length = past_key_values.get_max_cache_shape()
1783
  else:
1784
  cache_length = past_length = past_key_values[0][0].shape[2]
1785
  max_cache_length = None