Update modeling_rotary_indictrans.py
Browse files- modeling_rotary_indictrans.py +26 -25
modeling_rotary_indictrans.py
CHANGED
|
@@ -31,16 +31,22 @@ from transformers.generation import GenerationMixin
|
|
| 31 |
from transformers.modeling_utils import PreTrainedModel
|
| 32 |
from .configuration_rotary_indictrans import RotaryIndicTransConfig
|
| 33 |
|
| 34 |
-
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 35 |
-
from flash_attn.bert_padding import (
|
| 36 |
-
index_first_axis,
|
| 37 |
-
pad_input,
|
| 38 |
-
unpad_input,
|
| 39 |
-
)
|
| 40 |
-
|
| 41 |
logger = logging.get_logger(__name__)
|
| 42 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
| 46 |
def _get_unpad_data(attention_mask):
|
|
@@ -1401,8 +1407,6 @@ class RotaryIndicTransDecoder(RotaryIndicTransPreTrainedModel):
|
|
| 1401 |
|
| 1402 |
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100Model->RotaryIndicTrans
|
| 1403 |
class RotaryIndicTransModel(RotaryIndicTransPreTrainedModel):
|
| 1404 |
-
_tied_weights_keys = None
|
| 1405 |
-
|
| 1406 |
def __init__(self, config: RotaryIndicTransConfig):
|
| 1407 |
super().__init__(config)
|
| 1408 |
|
|
@@ -1497,10 +1501,11 @@ class RotaryIndicTransModel(RotaryIndicTransPreTrainedModel):
|
|
| 1497 |
|
| 1498 |
|
| 1499 |
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->RotaryIndicTrans
|
| 1500 |
-
class RotaryIndicTransForConditionalGeneration(
|
|
|
|
|
|
|
| 1501 |
base_model_prefix = "model"
|
| 1502 |
-
_tied_weights_keys =
|
| 1503 |
-
_label_smoothing = 0.0
|
| 1504 |
|
| 1505 |
def __init__(self, config: RotaryIndicTransConfig):
|
| 1506 |
super().__init__(config)
|
|
@@ -1509,19 +1514,16 @@ class RotaryIndicTransForConditionalGeneration(RotaryIndicTransPreTrainedModel,
|
|
| 1509 |
config.decoder_embed_dim, config.decoder_vocab_size, bias=False
|
| 1510 |
)
|
| 1511 |
|
| 1512 |
-
if config.share_decoder_input_output_embed:
|
| 1513 |
-
self.lm_head.weight = self.model.decoder.embed_tokens.weight
|
| 1514 |
-
|
| 1515 |
self.post_init()
|
| 1516 |
|
| 1517 |
-
def tie_weights(self):
|
| 1518 |
-
pass
|
| 1519 |
-
|
| 1520 |
def get_encoder(self):
|
| 1521 |
-
return self.model.
|
| 1522 |
|
| 1523 |
def get_decoder(self):
|
| 1524 |
-
return self.model.
|
|
|
|
|
|
|
|
|
|
| 1525 |
|
| 1526 |
def get_output_embeddings(self):
|
| 1527 |
return self.lm_head
|
|
@@ -1529,8 +1531,9 @@ class RotaryIndicTransForConditionalGeneration(RotaryIndicTransPreTrainedModel,
|
|
| 1529 |
def set_output_embeddings(self, new_embeddings):
|
| 1530 |
self.lm_head = new_embeddings
|
| 1531 |
|
| 1532 |
-
def
|
| 1533 |
-
self.
|
|
|
|
| 1534 |
|
| 1535 |
def forward(
|
| 1536 |
self,
|
|
@@ -1594,8 +1597,6 @@ class RotaryIndicTransForConditionalGeneration(RotaryIndicTransPreTrainedModel,
|
|
| 1594 |
masked_lm_loss = F.cross_entropy(
|
| 1595 |
input=lm_logits.view(-1, self.config.decoder_vocab_size),
|
| 1596 |
target=labels.view(-1),
|
| 1597 |
-
ignore_index=-100,
|
| 1598 |
-
label_smoothing=self._label_smoothing,
|
| 1599 |
)
|
| 1600 |
|
| 1601 |
if not return_dict:
|
|
@@ -1652,4 +1653,4 @@ class RotaryIndicTransForConditionalGeneration(RotaryIndicTransPreTrainedModel,
|
|
| 1652 |
past_state.index_select(0, beam_idx) for past_state in layer_past
|
| 1653 |
),
|
| 1654 |
)
|
| 1655 |
-
return reordered_past
|
|
|
|
| 31 |
from transformers.modeling_utils import PreTrainedModel
|
| 32 |
from .configuration_rotary_indictrans import RotaryIndicTransConfig
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
logger = logging.get_logger(__name__)
|
| 35 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 36 |
|
| 37 |
+
try:
|
| 38 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 39 |
+
from flash_attn.bert_padding import (
|
| 40 |
+
index_first_axis,
|
| 41 |
+
pad_input,
|
| 42 |
+
unpad_input,
|
| 43 |
+
)
|
| 44 |
+
except ImportError:
|
| 45 |
+
logger.warning(
|
| 46 |
+
"It is highly recommended to use `flash_attention_2` for better performance with RotaryIndicTrans."
|
| 47 |
+
"Falling back to the default `eager` implementation."
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
|
| 51 |
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
| 52 |
def _get_unpad_data(attention_mask):
|
|
|
|
| 1407 |
|
| 1408 |
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100Model->RotaryIndicTrans
|
| 1409 |
class RotaryIndicTransModel(RotaryIndicTransPreTrainedModel):
|
|
|
|
|
|
|
| 1410 |
def __init__(self, config: RotaryIndicTransConfig):
|
| 1411 |
super().__init__(config)
|
| 1412 |
|
|
|
|
| 1501 |
|
| 1502 |
|
| 1503 |
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->RotaryIndicTrans
|
| 1504 |
+
class RotaryIndicTransForConditionalGeneration(
|
| 1505 |
+
RotaryIndicTransPreTrainedModel, GenerationMixin
|
| 1506 |
+
):
|
| 1507 |
base_model_prefix = "model"
|
| 1508 |
+
_tied_weights_keys = ["decoder.embed_tokens.weight", "lm_head.weight"]
|
|
|
|
| 1509 |
|
| 1510 |
def __init__(self, config: RotaryIndicTransConfig):
|
| 1511 |
super().__init__(config)
|
|
|
|
| 1514 |
config.decoder_embed_dim, config.decoder_vocab_size, bias=False
|
| 1515 |
)
|
| 1516 |
|
|
|
|
|
|
|
|
|
|
| 1517 |
self.post_init()
|
| 1518 |
|
|
|
|
|
|
|
|
|
|
| 1519 |
def get_encoder(self):
|
| 1520 |
+
return self.model.encoder
|
| 1521 |
|
| 1522 |
def get_decoder(self):
|
| 1523 |
+
return self.model.decoder
|
| 1524 |
+
|
| 1525 |
+
def get_input_embeddings(self):
|
| 1526 |
+
return self.model.encoder.embed_tokens
|
| 1527 |
|
| 1528 |
def get_output_embeddings(self):
|
| 1529 |
return self.lm_head
|
|
|
|
| 1531 |
def set_output_embeddings(self, new_embeddings):
|
| 1532 |
self.lm_head = new_embeddings
|
| 1533 |
|
| 1534 |
+
def tie_weights(self):
|
| 1535 |
+
if self.config.share_decoder_input_output_embed:
|
| 1536 |
+
self._tie_or_clone_weights(self.model.decoder.embed_tokens, self.lm_head)
|
| 1537 |
|
| 1538 |
def forward(
|
| 1539 |
self,
|
|
|
|
| 1597 |
masked_lm_loss = F.cross_entropy(
|
| 1598 |
input=lm_logits.view(-1, self.config.decoder_vocab_size),
|
| 1599 |
target=labels.view(-1),
|
|
|
|
|
|
|
| 1600 |
)
|
| 1601 |
|
| 1602 |
if not return_dict:
|
|
|
|
| 1653 |
past_state.index_select(0, beam_idx) for past_state in layer_past
|
| 1654 |
),
|
| 1655 |
)
|
| 1656 |
+
return reordered_past
|