"""`transformers`-compatible modeling code for xLSTM language models."""

from typing import Optional, Sequence, Tuple, Union

import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast
from xlstm.components.init import small_init_init_
from xlstm.utils import WeightDecayOptimGroupMixin
from xlstm.xlstm_block_stack import xLSTMBlockStack as _xLSTMBlockStack

from .configuration_xlstm import xLSTMConfig

class xLSTMPreTrainedModel(PreTrainedModel):
    """Base class for all xLSTM models; hooks `xLSTMConfig` into `transformers`."""

    config_class = xLSTMConfig


class xLSTMBlockStack(_xLSTMBlockStack):
    """Thin wrapper around the xLSTM block stack that also returns per-block hidden states."""

    def forward(
        self, x: torch.Tensor, **kwargs
    ) -> Tuple[torch.Tensor, Sequence[torch.Tensor]]:
        hidden_states = ()
        for block in self.blocks:
            x = block(x, **kwargs)
            hidden_states += (x,)

        x = self.post_blocks_norm(x)

        return x, hidden_states


class xLSTMModel(xLSTMPreTrainedModel):
    def __init__(self, config: xLSTMConfig):
        super().__init__(config)
        self.config = config

        self.token_embedding = nn.Embedding(
            num_embeddings=config.vocab_size, embedding_dim=config.embedding_dim
        )
        _config = config.to_xlstm_config()

        self.emb_dropout = (
            nn.Dropout(_config.dropout)
            if _config.add_embedding_dropout
            else nn.Identity()
        )

        self.xlstm_block_stack = xLSTMBlockStack(config=_config)

    def forward(
        self,
        input_ids: torch.LongTensor,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        token_embedding = self.token_embedding(input_ids)
        x = self.emb_dropout(token_embedding)
        x, hidden_states = self.xlstm_block_stack(x)

        if output_hidden_states:
            # Follow the transformers convention of including the embedding output.
            hidden_states = (token_embedding,) + hidden_states

        if not return_dict:
            return x, hidden_states

        return BaseModelOutput(
            last_hidden_state=x,
            hidden_states=hidden_states if output_hidden_states else None,
        )


class xLSTMForCausalLM(xLSTMPreTrainedModel, WeightDecayOptimGroupMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: xLSTMConfig, **kwargs):
        super().__init__(config)
        self.config = config
        self.vocab_size = config.vocab_size

        self.model = xLSTMModel(config)

        self.lm_head = nn.Linear(
            in_features=config.embedding_dim,
            out_features=config.vocab_size,
            bias=False,
        )

        self.post_init()

    def get_input_embeddings(self):
        return self.model.token_embedding

    def set_input_embeddings(self, value: nn.Module):
        self.model.token_embedding = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, value):
        self.lm_head = value

    def reset_parameters(self):
        self.model.xlstm_block_stack.reset_parameters()

        small_init_init_(
            self.get_input_embeddings().weight, dim=self.config.embedding_dim
        )

        if not self.config.tie_word_embeddings:
            small_init_init_(
                self.get_output_embeddings().weight, dim=self.config.embedding_dim
            )

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Always request the structured output from the backbone so that the
        # `output.hidden_states` attribute access below is valid.
        output = self.model(
            input_ids,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        hidden_state = output[0]

        logits = self.lm_head(hidden_state)
        logits = logits.float()

        loss = None

        if labels is not None:
            # Shift so that tokens at positions < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)

            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + output[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            hidden_states=output.hidden_states,
        )

    def step(
        self,
        idx: torch.Tensor,
        state: Optional[dict[str, dict[str, tuple[torch.Tensor, ...]]]] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, dict[str, dict[str, tuple[torch.Tensor, ...]]]]:
        # Single-token recurrent step. The embedding, dropout and block stack
        # live on the wrapped backbone (`self.model`), not on this class.
        x = self.model.token_embedding(idx)
        x = self.model.emb_dropout(x)
        x, state = self.model.xlstm_block_stack.step(x, state=state, **kwargs)
        logits = self.lm_head(x)
        return logits, state

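    # Sketch of how `step` can drive greedy decoding; `model`, `prompt_ids` and the
    # loop bounds are illustrative placeholders, not part of this module:
    #
    #     state = None
    #     for t in range(prompt_ids.shape[1]):
    #         logits, state = model.step(prompt_ids[:, t:t + 1], state=state)
    #     next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
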
    def _create_weight_decay_optim_groups(
        self, **kwargs
    ) -> tuple[Sequence[nn.Parameter], Sequence[nn.Parameter]]:
        weight_decay, no_weight_decay = super()._create_weight_decay_optim_groups(
            **kwargs
        )

        # Remove the token embedding from the default weight-decay group so its
        # treatment can be decided explicitly below.
        weight_decay = tuple(
            param
            for param in weight_decay
            if param is not self.get_input_embeddings().weight
        )

        # The original gate here was hard-coded to `True`; keep that behaviour as
        # the default, but allow an optional `weight_decay_on_embedding` config
        # flag (if present) to exclude the embedding from weight decay instead.
        if getattr(self.config, "weight_decay_on_embedding", True):
            weight_decay += (self.get_input_embeddings().weight,)
        else:
            no_weight_decay += (self.get_input_embeddings().weight,)

        return weight_decay, no_weight_decay

    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
        # NOTE: this replaces the embedding table without copying the old rows;
        # the embedding lives on the wrapped backbone (`self.model`).
        new_embeddings = nn.Embedding(
            new_num_tokens, self.model.token_embedding.embedding_dim
        )
        self.model.token_embedding = new_embeddings.to(self.device)
        self.config.vocab_size = new_num_tokens
        self.vocab_size = new_num_tokens
        return new_embeddings

    def tie_weights(self):
        self.get_output_embeddings().weight = self.get_input_embeddings().weight

    def prepare_inputs_for_generation(
        self,
        input_ids,
        **kwargs,
    ):
        model_inputs = {
            "input_ids": input_ids.to(self.device),
        }
        return model_inputs
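

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the public API. It assumes that
    # `xLSTMConfig` can be constructed with just `vocab_size` and `embedding_dim`
    # (both fields are read elsewhere in this module) and that the remaining
    # fields consumed by `config.to_xlstm_config()` have sensible defaults.
    config = xLSTMConfig(vocab_size=128, embedding_dim=64)
    model = xLSTMForCausalLM(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    out = model(input_ids, labels=input_ids, return_dict=True)

    print(out.logits.shape)  # expected: torch.Size([2, 16, 128])
    print(out.loss)          # scalar next-token cross-entropy loss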