from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.cache_utils import Cache, HybridCache
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaModel,
    LlamaPreTrainedModel,
)
from transformers.utils import logging


logger = logging.get_logger(__name__)

class LlamaBidirectionalConfig(LlamaConfig):
    """Llama configuration extended with a ``pooling`` strategy and a
    ``temperature`` value for the bidirectional (encoder-style) model."""

    model_type = "llama_bidirec"

    def __init__(
        self,
        pooling="avg",
        temperature=1.0,
        **kwargs,
    ):
        self.pooling = pooling
        self.temperature = temperature
        super().__init__(**kwargs)


class LlamaBidirectionalModel(LlamaModel):
    """LlamaModel variant whose self-attention layers attend bidirectionally
    (no causal mask), e.g. for embedding or retrieval use."""

    config_class = LlamaBidirectionalConfig

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        # Disable causal attention in every decoder layer and force the eager
        # attention implementation, which consumes the additive 4D mask built
        # in `_update_causal_mask` below.
        for layer in self.layers:
            layer.self_attn.is_causal = False
        self.config._attn_implementation = "eager"

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        # Instead of building a causal (lower-triangular) mask, expand the 2D
        # padding mask to the additive 4D form so every non-padded token can
        # attend to every other non-padded token.
        causal_mask = _prepare_4d_attention_mask(attention_mask, input_tensor.dtype)
        return causal_mask
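

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes the classes above and a transformers version whose LlamaModel
# still routes mask construction through `_update_causal_mask`. A tiny,
# randomly initialized config is used so the snippet runs without pretrained
# weights.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = LlamaBidirectionalConfig(
        pooling="avg",
        temperature=1.0,
        vocab_size=128,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
    )
    model = LlamaBidirectionalModel(config).eval()

    input_ids = torch.randint(0, config.vocab_size, (2, 8))
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    # Every token attends to every other non-padded token, so the hidden
    # states can be pooled (e.g. mean-pooled) into sequence embeddings.
    print(outputs.last_hidden_state.shape)  # expected: torch.Size([2, 8, 64])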