make deepseekocr compatible with mps and latest transformers
#96
by
hebangwen
- opened
- modeling_deepseekocr.py +22 -14
- modeling_deepseekv2.py +6 -7
modeling_deepseekocr.py
CHANGED
|
@@ -24,6 +24,13 @@ import numpy as np
|
|
| 24 |
import time
|
| 25 |
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
def load_image(image_path):
|
| 28 |
|
| 29 |
try:
|
|
@@ -502,7 +509,7 @@ class DeepseekOCRModel(DeepseekV2Model):
|
|
| 502 |
images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
|
| 503 |
# exit()
|
| 504 |
|
| 505 |
-
inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch)
|
| 506 |
|
| 507 |
idx += 1
|
| 508 |
|
|
@@ -622,8 +629,8 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
|
|
| 622 |
if past_key_values is not None:
|
| 623 |
if isinstance(past_key_values, Cache):
|
| 624 |
cache_length = past_key_values.get_seq_length()
|
| 625 |
-
past_length = past_key_values.seen_tokens
|
| 626 |
-
max_cache_length = past_key_values.get_max_length()
|
| 627 |
else:
|
| 628 |
cache_length = past_length = past_key_values[0][0].shape[2]
|
| 629 |
max_cache_length = None
|
|
@@ -645,6 +652,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
|
|
| 645 |
max_cache_length is not None
|
| 646 |
and attention_mask is not None
|
| 647 |
and cache_length + input_ids.shape[1] > max_cache_length
|
|
|
|
| 648 |
):
|
| 649 |
attention_mask = attention_mask[:, -max_cache_length:]
|
| 650 |
|
|
@@ -911,12 +919,12 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
|
|
| 911 |
|
| 912 |
if not eval_mode:
|
| 913 |
streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
|
| 914 |
-
with torch.autocast("cuda", dtype=torch.bfloat16):
|
| 915 |
with torch.no_grad():
|
| 916 |
output_ids = self.generate(
|
| 917 |
-
input_ids.unsqueeze(0).cuda(),
|
| 918 |
-
images=[(images_crop.cuda(), images_ori.cuda())],
|
| 919 |
-
images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
|
| 920 |
images_spatial_crop = images_spatial_crop,
|
| 921 |
# do_sample=False,
|
| 922 |
# num_beams = 1,
|
|
@@ -929,12 +937,12 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
|
|
| 929 |
)
|
| 930 |
|
| 931 |
else:
|
| 932 |
-
with torch.autocast("cuda", dtype=torch.bfloat16):
|
| 933 |
with torch.no_grad():
|
| 934 |
output_ids = self.generate(
|
| 935 |
-
input_ids.unsqueeze(0).cuda(),
|
| 936 |
-
images=[(images_crop.cuda(), images_ori.cuda())],
|
| 937 |
-
images_seq_mask = images_seq_mask.unsqueeze(0).cuda(),
|
| 938 |
images_spatial_crop = images_spatial_crop,
|
| 939 |
# do_sample=False,
|
| 940 |
# num_beams = 1,
|
|
@@ -947,7 +955,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
|
|
| 947 |
|
| 948 |
|
| 949 |
if '<image>' in conversation[0]['content'] and eval_mode:
|
| 950 |
-
outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
|
| 951 |
stop_str = '<|end▁of▁sentence|>'
|
| 952 |
if outputs.endswith(stop_str):
|
| 953 |
outputs = outputs[:-len(stop_str)]
|
|
@@ -957,7 +965,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
|
|
| 957 |
return outputs
|
| 958 |
|
| 959 |
if '<image>' in conversation[0]['content'] and test_compress:
|
| 960 |
-
outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
|
| 961 |
pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False))
|
| 962 |
print('='*50)
|
| 963 |
print('image size: ', (w, h))
|
|
@@ -968,7 +976,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
|
|
| 968 |
|
| 969 |
|
| 970 |
if '<image>' in conversation[0]['content'] and save_results:
|
| 971 |
-
outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).cuda().shape[1]:])
|
| 972 |
stop_str = '<|end▁of▁sentence|>'
|
| 973 |
|
| 974 |
print('='*15 + 'save results:' + '='*15)
|
|
|
|
| 24 |
import time
|
| 25 |
|
| 26 |
|
| 27 |
+
DEVICE = "cpu"
|
| 28 |
+
if torch.mps.is_available():
|
| 29 |
+
DEVICE = "mps"
|
| 30 |
+
elif torch.cuda.is_available():
|
| 31 |
+
DEVICE = "cuda"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
def load_image(image_path):
|
| 35 |
|
| 36 |
try:
|
|
|
|
| 509 |
images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
|
| 510 |
# exit()
|
| 511 |
|
| 512 |
+
inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).to(DEVICE), images_in_this_batch)
|
| 513 |
|
| 514 |
idx += 1
|
| 515 |
|
|
|
|
| 629 |
if past_key_values is not None:
|
| 630 |
if isinstance(past_key_values, Cache):
|
| 631 |
cache_length = past_key_values.get_seq_length()
|
| 632 |
+
past_length = cache_length
|
| 633 |
+
max_cache_length = past_key_values.get_max_cache_shape()
|
| 634 |
else:
|
| 635 |
cache_length = past_length = past_key_values[0][0].shape[2]
|
| 636 |
max_cache_length = None
|
|
|
|
| 652 |
max_cache_length is not None
|
| 653 |
and attention_mask is not None
|
| 654 |
and cache_length + input_ids.shape[1] > max_cache_length
|
| 655 |
+
and max_cache_length > -1
|
| 656 |
):
|
| 657 |
attention_mask = attention_mask[:, -max_cache_length:]
|
| 658 |
|
|
|
|
| 919 |
|
| 920 |
if not eval_mode:
|
| 921 |
streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
|
| 922 |
+
with torch.autocast(DEVICE, dtype=torch.bfloat16):
|
| 923 |
with torch.no_grad():
|
| 924 |
output_ids = self.generate(
|
| 925 |
+
input_ids.unsqueeze(0).to(DEVICE),
|
| 926 |
+
images=[(images_crop.to(DEVICE), images_ori.to(DEVICE))],
|
| 927 |
+
images_seq_mask = images_seq_mask.unsqueeze(0).to(DEVICE),
|
| 928 |
images_spatial_crop = images_spatial_crop,
|
| 929 |
# do_sample=False,
|
| 930 |
# num_beams = 1,
|
|
|
|
| 937 |
)
|
| 938 |
|
| 939 |
else:
|
| 940 |
+
with torch.autocast(DEVICE, dtype=torch.bfloat16):
|
| 941 |
with torch.no_grad():
|
| 942 |
output_ids = self.generate(
|
| 943 |
+
input_ids.unsqueeze(0).to(DEVICE),
|
| 944 |
+
images=[(images_crop.to(DEVICE), images_ori.to(DEVICE))],
|
| 945 |
+
images_seq_mask = images_seq_mask.unsqueeze(0).to(DEVICE),
|
| 946 |
images_spatial_crop = images_spatial_crop,
|
| 947 |
# do_sample=False,
|
| 948 |
# num_beams = 1,
|
|
|
|
| 955 |
|
| 956 |
|
| 957 |
if '<image>' in conversation[0]['content'] and eval_mode:
|
| 958 |
+
outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).to(DEVICE).shape[1]:])
|
| 959 |
stop_str = '<|end▁of▁sentence|>'
|
| 960 |
if outputs.endswith(stop_str):
|
| 961 |
outputs = outputs[:-len(stop_str)]
|
|
|
|
| 965 |
return outputs
|
| 966 |
|
| 967 |
if '<image>' in conversation[0]['content'] and test_compress:
|
| 968 |
+
outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).to(DEVICE).shape[1]:])
|
| 969 |
pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False))
|
| 970 |
print('='*50)
|
| 971 |
print('image size: ', (w, h))
|
|
|
|
| 976 |
|
| 977 |
|
| 978 |
if '<image>' in conversation[0]['content'] and save_results:
|
| 979 |
+
outputs = tokenizer.decode(output_ids[0, input_ids.unsqueeze(0).to(DEVICE).shape[1]:])
|
| 980 |
stop_str = '<|end▁of▁sentence|>'
|
| 981 |
|
| 982 |
print('='*15 + 'save results:' + '='*15)
|
modeling_deepseekv2.py
CHANGED
|
@@ -36,7 +36,6 @@ from transformers.cache_utils import Cache, DynamicCache
|
|
| 36 |
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
| 37 |
from transformers.models.llama.modeling_llama import (
|
| 38 |
LlamaAttention,
|
| 39 |
-
LlamaFlashAttention2
|
| 40 |
)
|
| 41 |
from transformers.modeling_outputs import (
|
| 42 |
BaseModelOutputWithPast,
|
|
@@ -889,7 +888,7 @@ class DeepseekV2Attention(nn.Module):
|
|
| 889 |
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
| 890 |
"with a layer index."
|
| 891 |
)
|
| 892 |
-
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
| 893 |
|
| 894 |
cos, sin = self.rotary_emb(q_pe, seq_len=kv_seq_len)
|
| 895 |
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
|
|
@@ -1018,7 +1017,7 @@ class DeepseekV2FlashAttention2(DeepseekV2Attention):
|
|
| 1018 |
|
| 1019 |
kv_seq_len = value_states.shape[-2]
|
| 1020 |
if past_key_value is not None:
|
| 1021 |
-
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
| 1022 |
|
| 1023 |
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
| 1024 |
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
|
|
@@ -1235,7 +1234,7 @@ ATTENTION_CLASSES = {
|
|
| 1235 |
"mla_flash_attention_2": DeepseekV2FlashAttention2,
|
| 1236 |
|
| 1237 |
"mha_eager": LlamaAttention,
|
| 1238 |
-
"mha_flash_attention_2": LlamaFlashAttention2,
|
| 1239 |
}
|
| 1240 |
|
| 1241 |
|
|
@@ -1539,7 +1538,7 @@ class DeepseekV2Model(DeepseekV2PreTrainedModel):
|
|
| 1539 |
use_legacy_cache = not isinstance(past_key_values, Cache)
|
| 1540 |
if use_legacy_cache:
|
| 1541 |
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
| 1542 |
-
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
| 1543 |
|
| 1544 |
if position_ids is None:
|
| 1545 |
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
|
@@ -1779,8 +1778,8 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
|
|
| 1779 |
if past_key_values is not None:
|
| 1780 |
if isinstance(past_key_values, Cache):
|
| 1781 |
cache_length = past_key_values.get_seq_length()
|
| 1782 |
-
past_length = past_key_values.seen_tokens
|
| 1783 |
-
max_cache_length = past_key_values.get_max_length()
|
| 1784 |
else:
|
| 1785 |
cache_length = past_length = past_key_values[0][0].shape[2]
|
| 1786 |
max_cache_length = None
|
|
|
|
| 36 |
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
| 37 |
from transformers.models.llama.modeling_llama import (
|
| 38 |
LlamaAttention,
|
|
|
|
| 39 |
)
|
| 40 |
from transformers.modeling_outputs import (
|
| 41 |
BaseModelOutputWithPast,
|
|
|
|
| 888 |
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
| 889 |
"with a layer index."
|
| 890 |
)
|
| 891 |
+
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
|
| 892 |
|
| 893 |
cos, sin = self.rotary_emb(q_pe, seq_len=kv_seq_len)
|
| 894 |
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
|
|
|
|
| 1017 |
|
| 1018 |
kv_seq_len = value_states.shape[-2]
|
| 1019 |
if past_key_value is not None:
|
| 1020 |
+
kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
|
| 1021 |
|
| 1022 |
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
| 1023 |
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
|
|
|
|
| 1234 |
"mla_flash_attention_2": DeepseekV2FlashAttention2,
|
| 1235 |
|
| 1236 |
"mha_eager": LlamaAttention,
|
| 1237 |
+
"mha_flash_attention_2": None
|
| 1238 |
}
|
| 1239 |
|
| 1240 |
|
|
|
|
| 1538 |
use_legacy_cache = not isinstance(past_key_values, Cache)
|
| 1539 |
if use_legacy_cache:
|
| 1540 |
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
| 1541 |
+
past_key_values_length = past_key_values.get_seq_length()
|
| 1542 |
|
| 1543 |
if position_ids is None:
|
| 1544 |
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
|
|
|
| 1778 |
if past_key_values is not None:
|
| 1779 |
if isinstance(past_key_values, Cache):
|
| 1780 |
cache_length = past_key_values.get_seq_length()
|
| 1781 |
+
past_length = cache_length
|
| 1782 |
+
max_cache_length = past_key_values.get_max_cache_shape()
|
| 1783 |
else:
|
| 1784 |
cache_length = past_length = past_key_values[0][0].shape[2]
|
| 1785 |
max_cache_length = None
|