Upload folder using huggingface_hub
Browse files- .gitattributes +2 -0
- LICENSE +26 -0
- README.md +183 -0
- __init__.py +26 -0
- added_tokens.json +29 -0
- chat_template.jinja +69 -0
- config.json +69 -0
- configuration_iquestloopcoder.py +132 -0
- generation_config.json +6 -0
- model-00001-of-00005.safetensors +3 -0
- model-00002-of-00005.safetensors +3 -0
- model-00003-of-00005.safetensors +3 -0
- model-00004-of-00005.safetensors +3 -0
- model-00005-of-00005.safetensors +3 -0
- model.safetensors.index.json +0 -0
- modeling_iquestloopcoder.py +1113 -0
- papers/iquest-coder-v1-logo.png +3 -0
- papers/results.png +3 -0
- recipe.yaml +35 -0
- special_tokens_map.json +48 -0
- tokenization_iquestcoder.py +552 -0
- tokenizer.model +3 -0
- tokenizer_config.json +242 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
papers/iquest-coder-v1-logo.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
papers/results.png filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Modified MIT License
|
| 2 |
+
|
| 3 |
+
Software Copyright© 2025 IQuest Research
|
| 4 |
+
|
| 5 |
+
Our only modification is that, if the Software (or any derivative works
|
| 6 |
+
thereof) is used for any of your commercial products or services, you shall
|
| 7 |
+
prominently display "IQuest Coder" on the user interface of such product or
|
| 8 |
+
service.
|
| 9 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 10 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 11 |
+
in the Software without restriction, including without limitation the rights
|
| 12 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 13 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 14 |
+
furnished to do so, subject to the following conditions:
|
| 15 |
+
|
| 16 |
+
The above copyright notice and this permission notice shall be included in all
|
| 17 |
+
copies or substantial portions of the Software.
|
| 18 |
+
|
| 19 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 20 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 21 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 22 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 23 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 24 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
| 25 |
+
|
| 26 |
+
|
README.md
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: other
|
| 3 |
+
license_name: iquestcoder
|
| 4 |
+
license_link: >-
|
| 5 |
+
https://huggingface.co/IQuestLab/IQuest-Coder-V1-40B-Instruct/blob/main/LICENSE
|
| 6 |
+
language:
|
| 7 |
+
- en
|
| 8 |
+
library_name: transformers
|
| 9 |
+
base_model: IQuestLab/IQuest-Coder-V1-40B-Loop-Instruct
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+

|
| 13 |
+
|
| 14 |
+
<p align="center">
|
| 15 |
+
📘 <a href="https://iquestlab.github.io">Blog</a >
|
| 16 |
+
•
|
| 17 |
+
📄 <a href="https://github.com/IQuestLab/IQuest-Coder-V1/blob/main/papers/IQuest_Coder_Technical_Report.pdf">Technical Report</a >
|
| 18 |
+
</p >
|
| 19 |
+
|
| 20 |
+
# IQuest-Coder-V1 Model Family
|
| 21 |
+
|
| 22 |
+
| Model | Link |
|
| 23 |
+
|-------|------|
|
| 24 |
+
| IQuest-Coder-V1-40B-Base-Stage1 | [🤗 Hugging Face](https://huggingface.co/IQuestLab/IQuest-Coder-V1-40B-Base-Stage1) |
|
| 25 |
+
| IQuest-Coder-V1-40B-Base | [🤗 Hugging Face](https://huggingface.co/IQuestLab/IQuest-Coder-V1-40B-Base) |
|
| 26 |
+
| IQuest-Coder-V1-40B-Instruct | [🤗 Hugging Face](https://huggingface.co/IQuestLab/IQuest-Coder-V1-40B-Instruct) |
|
| 27 |
+
| IQuest-Coder-V1-40B-Loop-Instruct | [🤗 Hugging Face](https://huggingface.co/IQuestLab/IQuest-Coder-V1-40B-Loop-Instruct) |
|
| 28 |
+
|
| 29 |
+
[Clarification: Regarding the Performance of IQuest-Coder-V1](https://github.com/IQuestLab/IQuest-Coder-V1/issues/14#issuecomment-3705756919)
|
| 30 |
+
|
| 31 |
+
## Sampling Parameters
|
| 32 |
+
For the IQuest-Coder-V1-Instruct: We suggest using Temperature=0.6, TopP=0.85, TopK=20.
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
## IQuest-Coder-V1 Highlights
|
| 36 |
+
|
| 37 |
+
IQuest-Coder-V1 is a new family of code large language models (LLMs) designed to advance autonomous software engineering and code intelligence. Built on the innovative code-flow multi-stage training paradigm, IQuest-Coder-V1 captures the dynamic evolution of software logic, delivering state-of-the-art performance across critical dimensions:
|
| 38 |
+
|
| 39 |
+
- **State-of-the-Art Performance**: Achieves leading results on SWE-Bench Verified (76.2%), BigCodeBench (49.9%), LiveCodeBench v6 (81.1%), and other major coding benchmarks, surpassing competitive models across agentic software engineering, competitive programming, and complex tool use.
|
| 40 |
+
- **Code-Flow Training Paradigm**: Moving beyond static code representations, our models learn from repository evolution patterns, commit transitions, and dynamic code transformations to understand real-world software development processes.
|
| 41 |
+
- **Dual Specialization Paths**: Bifurcated post-training delivers two specialized variants—Thinking models (utilizing reasoning-driven RL for complex problem-solving) and Instruct models (optimized for general coding assistance and instruction-following).
|
| 42 |
+
- **Efficient Architecture**: The IQuest-Coder-V1-Loop variant introduces a recurrent mechanism that optimizes the trade-off between model capacity and deployment footprint.
|
| 43 |
+
- **Native Long Context**: All models natively support up to 128K tokens without requiring additional scaling techniques.
|
| 44 |
+
|
| 45 |
+
## Model Overview
|
| 46 |
+
|
| 47 |
+
The IQuest-Coder-V1 series includes models ranging from 7B to 40B parameters, with both standard and Loop variants:
|
| 48 |
+
|
| 49 |
+
| Model | Parameters | Layers | Hidden Size | Attention Heads (Q/KV) | Context Length |
|
| 50 |
+
|-------|------------|--------|-------------|------------------------|----------------|
|
| 51 |
+
| IQuest-Coder-V1-7B-Instruct | 7B | 14 | 5120 | 40/8 | 128K |
|
| 52 |
+
| IQuest-Coder-V1-7B-Thinking | 7B | 14 | 5120 | 40/8 | 128K |
|
| 53 |
+
| IQuest-Coder-V1-14B-Instruct | 14B | 28 | 5120 | 40/8 | 128K |
|
| 54 |
+
| IQuest-Coder-V1-14B-Thinking | 14B | 28 | 5120 | 40/8 | 128K |
|
| 55 |
+
| IQuest-Coder-V1-40B-Instruct | 40B | 80 | 5120 | 40/8 | 128K |
|
| 56 |
+
| IQuest-Coder-V1-40B-Thinking | 40B | 80 | 5120 | 40/8 | 128K |
|
| 57 |
+
| IQuest-Coder-V1-40B-Loop-Instruct | 40B | 80 (2 iterations) | 5120 | 40/8 | 128K |
|
| 58 |
+
| IQuest-Coder-V1-40B-Loop-Thinking | 40B | 80 (2 iterations) | 5120 | 40/8 | 128K |
|
| 59 |
+
|
| 60 |
+
**Architecture Features:**
|
| 61 |
+
|
| 62 |
+
- Grouped Query Attention (GQA) for efficient inference
|
| 63 |
+
- Native 128K context length support
|
| 64 |
+
- Vocabulary size: 76,800 tokens
|
| 65 |
+
- Loop variants use recurrent transformer design with shared parameters across two iterations
|
| 66 |
+
|
| 67 |
+
For more details, please refer to our Technical Report, GitHub.
|
| 68 |
+
|
| 69 |
+
## Quickstart
|
| 70 |
+
|
| 71 |
+
IQuest-Coder-V1 uses custom modeling code via Hugging Face's auto_map feature. We recommend using transformers==4.56.0.
|
| 72 |
+
|
| 73 |
+
### Basic Usage with Transformers
|
| 74 |
+
|
| 75 |
+
```python
|
| 76 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 77 |
+
|
| 78 |
+
model_name = "IQuestLab/IQuest-Coder-V1-40B-Instruct"
|
| 79 |
+
|
| 80 |
+
# Load the tokenizer and model
|
| 81 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
| 82 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 83 |
+
model_name,
|
| 84 |
+
torch_dtype="auto",
|
| 85 |
+
device_map="cuda:0",
|
| 86 |
+
trust_remote_code=True,
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
# Prepare the input
|
| 90 |
+
prompt = "Write a Python function to calculate the Fibonacci sequence using dynamic programming."
|
| 91 |
+
messages = [
|
| 92 |
+
{"role": "user", "content": prompt}
|
| 93 |
+
]
|
| 94 |
+
text = tokenizer.apply_chat_template(
|
| 95 |
+
messages,
|
| 96 |
+
tokenize=False,
|
| 97 |
+
add_generation_prompt=True
|
| 98 |
+
)
|
| 99 |
+
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
| 100 |
+
|
| 101 |
+
# Generate response
|
| 102 |
+
generated_ids = model.generate(
|
| 103 |
+
**model_inputs,
|
| 104 |
+
max_new_tokens=8192
|
| 105 |
+
)
|
| 106 |
+
generated_ids = generated_ids[0][len(model_inputs.input_ids[0]):]
|
| 107 |
+
response = tokenizer.decode(generated_ids, skip_special_tokens=True)
|
| 108 |
+
|
| 109 |
+
print(response)
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
### Using Thinking Models
|
| 113 |
+
|
| 114 |
+
For complex reasoning tasks, use the Thinking variant:
|
| 115 |
+
|
| 116 |
+
```python
|
| 117 |
+
model_name = "IQuest/IQuest-Coder-V1-40B-Thinking"
|
| 118 |
+
|
| 119 |
+
# The Thinking model includes explicit reasoning traces
|
| 120 |
+
# Use similar code as above, but expect longer, more detailed responses
|
| 121 |
+
# with step-by-step problem decomposition
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
### Deployment with vLLM
|
| 125 |
+
|
| 126 |
+
For production deployment, you can use vLLM to create an OpenAI-compatible API endpoint. Please refer to the [vLLM PR](https://github.com/vllm-project/vllm/pull/31575/files) for implementation details.
|
| 127 |
+
|
| 128 |
+
```bash
|
| 129 |
+
vllm serve IQuestLab/IQuest-Coder-V1-40B-Instruct --tensor-parallel-size 8
|
| 130 |
+
```
|
| 131 |
+
|
| 132 |
+
For Thinking models with reasoning support:
|
| 133 |
+
|
| 134 |
+
```bash
|
| 135 |
+
vllm serve IQuestLab/IQuest-Coder-V1-40B-Thinking --reasoning-parser qwen3 --tensor-parallel-size 8
|
| 136 |
+
```
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
## Evaluation Results
|
| 140 |
+

|
| 141 |
+
|
| 142 |
+
## Limitations
|
| 143 |
+
|
| 144 |
+
- **Reasoning vs. Efficiency Trade-off**: Thinking models provide superior reasoning but generate longer responses; Instruct models are more efficient for straightforward tasks.
|
| 145 |
+
- **Code Execution**: Models generate code but do not execute it; always validate outputs in sandboxed environments.
|
| 146 |
+
- **Domain Specificity**: While trained on diverse codebases, performance may vary on highly specialized or proprietary frameworks.
|
| 147 |
+
- **Factuality**: Models may generate plausible but incorrect code; verify critical implementations thoroughly.
|
| 148 |
+
|
| 149 |
+
## Citation
|
| 150 |
+
|
| 151 |
+
If you find our work helpful, please cite:
|
| 152 |
+
|
| 153 |
+
```bibtex
|
| 154 |
+
@article{iquest-coder-v1-2025,
|
| 155 |
+
title={IQuest-Coder-V1 Technical Report},
|
| 156 |
+
author={IQuest Coder Team},
|
| 157 |
+
url={https://github.com/IQuestLab/IQuest-Coder-V1/blob/main/papers/IQuest_Coder_Technical_Report.pdf}
|
| 158 |
+
year={2025}
|
| 159 |
+
}
|
| 160 |
+
@article{codescaling,
|
| 161 |
+
title={Scaling Laws for Code: Every Programming Language Matters},
|
| 162 |
+
author={Yang, Jian and Guo, Shawn and Jing, Lin and Zhang, Wei and Liu, Aishan and Hao, Chuan and Li, Zhoujun and Zhao, Wayne Xin and Liu, Xianglong and Lv, Weifeng and others},
|
| 163 |
+
journal={arXiv preprint arXiv:2512.13472},
|
| 164 |
+
year={2025}
|
| 165 |
+
}
|
| 166 |
+
@article{close_the_loop,
|
| 167 |
+
title={Close the Loop: Synthesizing Infinite Tool-Use Data via Multi-Agent Role-Playing},
|
| 168 |
+
author={Yuwen Li, Wei Zhang, Zelong Huang, Mason Yang, Jiajun Wu, Shawn Guo, Huahao Hu, Lingyi Sun, Jian Yang, Mingjie Tang, Byran Dai},
|
| 169 |
+
journal={arXiv preprint arXiv:2512.23611},
|
| 170 |
+
year={2025}
|
| 171 |
+
}
|
| 172 |
+
@article{loopcoder,
|
| 173 |
+
title={LoopCoder: Scaling Code Intelligence via Looped Language Models},
|
| 174 |
+
author={Jian Yang, Wei Zhang, Shawn Guo, Yizhi Li, Lin Jing, Zhengmao Ye, Shark Liu, Yuyang Song, Jiajun Wu, Che Liu, T. Zheng, Siwei Wu, L. Liao, X. Ma, Chuan Hao, Ran Tao, Yan Xing, Jianzhou Wang, Mingjie Tang, Aishan Liu, Zhoujun Li, Xianglong Liu, Weifeng Lv1, Bryan Dai},
|
| 175 |
+
year={2025}
|
| 176 |
+
}
|
| 177 |
+
@article{swe_compress,
|
| 178 |
+
title={Context as a Tool: Context Management for Long-Horizon SWE-Agents},
|
| 179 |
+
author={hukai Liu, Jian Yang, Bo Jiang, Yizhi Li, Jinyang Guo, Xianglong Liu, Bryan Dai},
|
| 180 |
+
journal={arXiv preprint arXiv:2512.22087},
|
| 181 |
+
year={2025}
|
| 182 |
+
}
|
| 183 |
+
```
|
__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""IQuestLoopCoder model package."""
|
| 2 |
+
|
| 3 |
+
from .configuration_iquestloopcoder import IQuestLoopCoderConfig
|
| 4 |
+
from .modeling_iquestloopcoder import (
|
| 5 |
+
IQuestLoopCoderPreTrainedModel,
|
| 6 |
+
IQuestLoopCoderModel,
|
| 7 |
+
IQuestLoopCoderForCausalLM,
|
| 8 |
+
IQuestLoopCoderCache,
|
| 9 |
+
)
|
| 10 |
+
from .tokenization_iquestcoder import IQuestCoderTokenizer
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
from .tokenization_iquestcoder import IQuestCoderTokenizerFast
|
| 14 |
+
except ImportError:
|
| 15 |
+
IQuestCoderTokenizerFast = None
|
| 16 |
+
|
| 17 |
+
__all__ = [
|
| 18 |
+
"IQuestLoopCoderConfig",
|
| 19 |
+
"IQuestLoopCoderPreTrainedModel",
|
| 20 |
+
"IQuestLoopCoderModel",
|
| 21 |
+
"IQuestLoopCoderForCausalLM",
|
| 22 |
+
"IQuestLoopCoderCache",
|
| 23 |
+
"IQuestCoderTokenizer",
|
| 24 |
+
"IQuestCoderTokenizerFast",
|
| 25 |
+
]
|
| 26 |
+
|
added_tokens.json
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"</think>": 75873,
|
| 3 |
+
"</tool_call>": 75877,
|
| 4 |
+
"</tool_response>": 75879,
|
| 5 |
+
"</tools>": 75875,
|
| 6 |
+
"<CLS>": 75858,
|
| 7 |
+
"<EOD>": 75860,
|
| 8 |
+
"<MASK>": 75861,
|
| 9 |
+
"<PAD>": 75862,
|
| 10 |
+
"<SEP>": 75859,
|
| 11 |
+
"<think>": 75872,
|
| 12 |
+
"<tool_call>": 75876,
|
| 13 |
+
"<tool_response>": 75878,
|
| 14 |
+
"<tools>": 75874,
|
| 15 |
+
"<|CLS|>": 75880,
|
| 16 |
+
"<|EOD|>": 75882,
|
| 17 |
+
"<|MASK|>": 75883,
|
| 18 |
+
"<|PAD|>": 75884,
|
| 19 |
+
"<|SEP|>": 75881,
|
| 20 |
+
"<|endoftext|>": 75869,
|
| 21 |
+
"<|file_sep|>": 75871,
|
| 22 |
+
"<|fim_middle|>": 75866,
|
| 23 |
+
"<|fim_pad|>": 75868,
|
| 24 |
+
"<|fim_prefix|>": 75865,
|
| 25 |
+
"<|fim_suffix|>": 75867,
|
| 26 |
+
"<|im_end|>": 75864,
|
| 27 |
+
"<|im_start|>": 75863,
|
| 28 |
+
"<|repo_name|>": 75870
|
| 29 |
+
}
|
chat_template.jinja
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{%- if tools %}
|
| 2 |
+
{{- '<|im_start|>system\n' }}
|
| 3 |
+
{%- if messages[0].role == 'system' %}
|
| 4 |
+
{{- messages[0].content + '\n\n' }}
|
| 5 |
+
{%- else %}
|
| 6 |
+
{{- 'You are LoopCoder, a helpful assistant developed by IQuest.' }}
|
| 7 |
+
{%- endif %}
|
| 8 |
+
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
|
| 9 |
+
{%- for tool in tools %}
|
| 10 |
+
{{- "\n" }}
|
| 11 |
+
{{- tool | tojson }}
|
| 12 |
+
{%- endfor %}
|
| 13 |
+
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
|
| 14 |
+
{%- else %}
|
| 15 |
+
{%- if messages[0].role == 'system' %}
|
| 16 |
+
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
|
| 17 |
+
{%- else %}
|
| 18 |
+
{{- '<|im_start|>system\nYou are LoopCoder, a helpful assistant developed by IQuest.<|im_end|>\n' }}
|
| 19 |
+
{%- endif %}
|
| 20 |
+
{%- endif %}
|
| 21 |
+
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
|
| 22 |
+
{%- for message in messages[::-1] %}
|
| 23 |
+
{%- set index = (messages|length - 1) - loop.index0 %}
|
| 24 |
+
{%- if ns.multi_step_tool and message.role == "user" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
|
| 25 |
+
{%- set ns.multi_step_tool = false %}
|
| 26 |
+
{%- set ns.last_query_index = index %}
|
| 27 |
+
{%- endif %}
|
| 28 |
+
{%- endfor %}
|
| 29 |
+
{%- for message in messages %}
|
| 30 |
+
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
|
| 31 |
+
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
|
| 32 |
+
{%- elif message.role == "assistant" %}
|
| 33 |
+
{%- set content = message.content %}
|
| 34 |
+
{{- '<|im_start|>' + message.role + '\n' + content }}
|
| 35 |
+
{%- if message.tool_calls %}
|
| 36 |
+
{%- for tool_call in message.tool_calls %}
|
| 37 |
+
{%- if (loop.first and content) or (not loop.first) %}
|
| 38 |
+
{{- '\n' }}
|
| 39 |
+
{%- endif %}
|
| 40 |
+
{%- if tool_call.function %}
|
| 41 |
+
{%- set tool_call = tool_call.function %}
|
| 42 |
+
{%- endif %}
|
| 43 |
+
{{- '<tool_call>\n{"name": "' }}
|
| 44 |
+
{{- tool_call.name }}
|
| 45 |
+
{{- '", "arguments": ' }}
|
| 46 |
+
{%- if tool_call.arguments is string %}
|
| 47 |
+
{{- tool_call.arguments }}
|
| 48 |
+
{%- else %}
|
| 49 |
+
{{- tool_call.arguments | tojson }}
|
| 50 |
+
{%- endif %}
|
| 51 |
+
{{- '}\n</tool_call>' }}
|
| 52 |
+
{%- endfor %}
|
| 53 |
+
{%- endif %}
|
| 54 |
+
{{- '<|im_end|>\n' }}
|
| 55 |
+
{%- elif message.role == "tool" %}
|
| 56 |
+
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
|
| 57 |
+
{{- '<|im_start|>user' }}
|
| 58 |
+
{%- endif %}
|
| 59 |
+
{{- '\n<tool_response>\n' }}
|
| 60 |
+
{{- message.content }}
|
| 61 |
+
{{- '\n</tool_response>' }}
|
| 62 |
+
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
| 63 |
+
{{- '<|im_end|>\n' }}
|
| 64 |
+
{%- endif %}
|
| 65 |
+
{%- endif %}
|
| 66 |
+
{%- endfor %}
|
| 67 |
+
{%- if add_generation_prompt %}
|
| 68 |
+
{{- '<|im_start|>assistant\n' }}
|
| 69 |
+
{%- endif %}
|
config.json
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "iquestloopcoder",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"IQuestLoopCoderForCausalLM"
|
| 5 |
+
],
|
| 6 |
+
"model_type": "iquestloopcoder",
|
| 7 |
+
"vocab_size": 76800,
|
| 8 |
+
"hidden_size": 5120,
|
| 9 |
+
"intermediate_size": 27648,
|
| 10 |
+
"num_hidden_layers": 80,
|
| 11 |
+
"eos_token_id": [2, 75864, 75869],
|
| 12 |
+
"num_attention_heads": 40,
|
| 13 |
+
"num_key_value_heads": 8,
|
| 14 |
+
"quantization_config": {
|
| 15 |
+
"config_groups": {
|
| 16 |
+
"group_0": {
|
| 17 |
+
"format": "pack-quantized",
|
| 18 |
+
"input_activations": null,
|
| 19 |
+
"output_activations": null,
|
| 20 |
+
"targets": [
|
| 21 |
+
"Linear"
|
| 22 |
+
],
|
| 23 |
+
"weights": {
|
| 24 |
+
"actorder": null,
|
| 25 |
+
"block_structure": null,
|
| 26 |
+
"dynamic": false,
|
| 27 |
+
"group_size": 32,
|
| 28 |
+
"num_bits": 4,
|
| 29 |
+
"observer": "mse",
|
| 30 |
+
"observer_kwargs": {},
|
| 31 |
+
"strategy": "group",
|
| 32 |
+
"symmetric": true,
|
| 33 |
+
"type": "int"
|
| 34 |
+
}
|
| 35 |
+
}
|
| 36 |
+
},
|
| 37 |
+
"format": "pack-quantized",
|
| 38 |
+
"global_compression_ratio": null,
|
| 39 |
+
"ignore": [
|
| 40 |
+
"lm_head"
|
| 41 |
+
],
|
| 42 |
+
"kv_cache_scheme": null,
|
| 43 |
+
"quant_method": "compressed-tensors",
|
| 44 |
+
"quantization_status": "compressed",
|
| 45 |
+
"sparsity_config": {},
|
| 46 |
+
"transform_config": {},
|
| 47 |
+
"version": "0.12.3.a20251110"
|
| 48 |
+
},
|
| 49 |
+
"head_dim": 128,
|
| 50 |
+
"hidden_act": "silu",
|
| 51 |
+
"max_position_embeddings": 131072,
|
| 52 |
+
"initializer_range": 0.02,
|
| 53 |
+
"rms_norm_eps": 1e-05,
|
| 54 |
+
"use_cache": true,
|
| 55 |
+
"tie_word_embeddings": false,
|
| 56 |
+
"rope_theta": 500000,
|
| 57 |
+
"attention_bias": false,
|
| 58 |
+
"attention_dropout": 0.0,
|
| 59 |
+
"mlp_bias": false,
|
| 60 |
+
"loop_num": 2,
|
| 61 |
+
"loop_window_size": 64,
|
| 62 |
+
"torch_dtype": "bfloat16",
|
| 63 |
+
"transformers_version": "4.56.0",
|
| 64 |
+
"auto_map": {
|
| 65 |
+
"AutoConfig": "configuration_iquestloopcoder.IQuestLoopCoderConfig",
|
| 66 |
+
"AutoModel": "modeling_iquestloopcoder.IQuestLoopCoderModel",
|
| 67 |
+
"AutoModelForCausalLM": "modeling_iquestloopcoder.IQuestLoopCoderForCausalLM"
|
| 68 |
+
}
|
| 69 |
+
}
|
configuration_iquestloopcoder.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 IQuestLoopCoder Authors
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
"""IQuestLoopCoder model configuration"""
|
| 6 |
+
|
| 7 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 8 |
+
from transformers.utils import logging
|
| 9 |
+
|
| 10 |
+
logger = logging.get_logger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class IQuestLoopCoderConfig(PretrainedConfig):
|
| 14 |
+
r"""
|
| 15 |
+
Configuration class for IQuestLoopCoder model.
|
| 16 |
+
|
| 17 |
+
IQuestLoopCoder extends the standard LLaMA architecture with a loop mechanism:
|
| 18 |
+
- Loop 1: Standard attention, stores K1, V1
|
| 19 |
+
- Loop 2+: Mixed attention with gated combination of global (K1,V1) and local (K2,V2) KV
|
| 20 |
+
|
| 21 |
+
The gate is computed as: gate = sigmoid(W @ Q + bias)
|
| 22 |
+
Mixed output = gate * Attention(Q, K1, V1) + (1 - gate) * SlidingWindowAttention(Q, K2, V2)
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
vocab_size (`int`, *optional*, defaults to 76800):
|
| 26 |
+
Vocabulary size of the model.
|
| 27 |
+
hidden_size (`int`, *optional*, defaults to 5120):
|
| 28 |
+
Dimension of the hidden representations.
|
| 29 |
+
intermediate_size (`int`, *optional*, defaults to 27648):
|
| 30 |
+
Dimension of the MLP representations (FFN hidden size).
|
| 31 |
+
num_hidden_layers (`int`, *optional*, defaults to 80):
|
| 32 |
+
Number of hidden layers in the Transformer decoder.
|
| 33 |
+
num_attention_heads (`int`, *optional*, defaults to 40):
|
| 34 |
+
Number of attention heads for each attention layer.
|
| 35 |
+
num_key_value_heads (`int`, *optional*, defaults to 8):
|
| 36 |
+
Number of key-value heads (for GQA). If None, defaults to num_attention_heads.
|
| 37 |
+
head_dim (`int`, *optional*, defaults to 128):
|
| 38 |
+
Dimension of each attention head (hidden_size // num_attention_heads).
|
| 39 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
| 40 |
+
Activation function in the MLP.
|
| 41 |
+
max_position_embeddings (`int`, *optional*, defaults to 8192):
|
| 42 |
+
Maximum sequence length.
|
| 43 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 44 |
+
Standard deviation for weight initialization.
|
| 45 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-5):
|
| 46 |
+
Epsilon for RMS normalization layers.
|
| 47 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 48 |
+
Whether to use past key/values for generation.
|
| 49 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
| 50 |
+
Whether to tie input and output embeddings.
|
| 51 |
+
rope_theta (`float`, *optional*, defaults to 500000.0):
|
| 52 |
+
Base value for rotary position embeddings.
|
| 53 |
+
attention_bias (`bool`, *optional*, defaults to `False`):
|
| 54 |
+
Whether to use bias in attention layers.
|
| 55 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 56 |
+
Dropout ratio for attention weights.
|
| 57 |
+
mlp_bias (`bool`, *optional*, defaults to `False`):
|
| 58 |
+
Whether to use bias in MLP layers.
|
| 59 |
+
|
| 60 |
+
# Loop-specific parameters
|
| 61 |
+
loop_num (`int`, *optional*, defaults to 2):
|
| 62 |
+
Number of loops through the decoder.
|
| 63 |
+
loop_window_size (`int`, *optional*, defaults to 64):
|
| 64 |
+
Window size for sliding window attention in Loop 2+.
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
model_type = "iquestloopcoder"
|
| 68 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 69 |
+
|
| 70 |
+
def __init__(
|
| 71 |
+
self,
|
| 72 |
+
vocab_size=76800,
|
| 73 |
+
hidden_size=5120,
|
| 74 |
+
intermediate_size=27648,
|
| 75 |
+
num_hidden_layers=80,
|
| 76 |
+
num_attention_heads=40,
|
| 77 |
+
num_key_value_heads=8,
|
| 78 |
+
head_dim=128,
|
| 79 |
+
hidden_act="silu",
|
| 80 |
+
max_position_embeddings=8192,
|
| 81 |
+
initializer_range=0.02,
|
| 82 |
+
rms_norm_eps=1e-5,
|
| 83 |
+
use_cache=True,
|
| 84 |
+
pad_token_id=None,
|
| 85 |
+
bos_token_id=1,
|
| 86 |
+
eos_token_id=2,
|
| 87 |
+
tie_word_embeddings=False,
|
| 88 |
+
rope_theta=500000.0,
|
| 89 |
+
rope_scaling=None,
|
| 90 |
+
attention_bias=False,
|
| 91 |
+
attention_dropout=0.0,
|
| 92 |
+
mlp_bias=False,
|
| 93 |
+
# Loop-specific parameters
|
| 94 |
+
loop_num=2,
|
| 95 |
+
loop_window_size=64,
|
| 96 |
+
**kwargs,
|
| 97 |
+
):
|
| 98 |
+
self.vocab_size = vocab_size
|
| 99 |
+
self.max_position_embeddings = max_position_embeddings
|
| 100 |
+
self.hidden_size = hidden_size
|
| 101 |
+
self.intermediate_size = intermediate_size
|
| 102 |
+
self.num_hidden_layers = num_hidden_layers
|
| 103 |
+
self.num_attention_heads = num_attention_heads
|
| 104 |
+
self.head_dim = head_dim
|
| 105 |
+
|
| 106 |
+
# GQA support
|
| 107 |
+
if num_key_value_heads is None:
|
| 108 |
+
num_key_value_heads = num_attention_heads
|
| 109 |
+
self.num_key_value_heads = num_key_value_heads
|
| 110 |
+
|
| 111 |
+
self.hidden_act = hidden_act
|
| 112 |
+
self.initializer_range = initializer_range
|
| 113 |
+
self.rms_norm_eps = rms_norm_eps
|
| 114 |
+
self.use_cache = use_cache
|
| 115 |
+
self.rope_theta = rope_theta
|
| 116 |
+
self.rope_scaling = rope_scaling
|
| 117 |
+
self.attention_bias = attention_bias
|
| 118 |
+
self.attention_dropout = attention_dropout
|
| 119 |
+
self.mlp_bias = mlp_bias
|
| 120 |
+
|
| 121 |
+
# Loop-specific
|
| 122 |
+
self.loop_num = loop_num
|
| 123 |
+
self.loop_window_size = loop_window_size
|
| 124 |
+
|
| 125 |
+
super().__init__(
|
| 126 |
+
pad_token_id=pad_token_id,
|
| 127 |
+
bos_token_id=bos_token_id,
|
| 128 |
+
eos_token_id=eos_token_id,
|
| 129 |
+
tie_word_embeddings=tie_word_embeddings,
|
| 130 |
+
**kwargs,
|
| 131 |
+
)
|
| 132 |
+
|
generation_config.json
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"bos_token_id": 1,
|
| 4 |
+
"eos_token_id": [2, 75864, 75869],
|
| 5 |
+
"transformers_version": "4.56.0"
|
| 6 |
+
}
|
model-00001-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4ccdbde9f269783df00bc8f36aaccc992e739118f786ef57c7dee0d3c1b1a8ee
|
| 3 |
+
size 4936193728
|
model-00002-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d3335627ceed319afd837e9271c74ac4625c60bcbad33a5a69ac1d0379c9e25b
|
| 3 |
+
size 4937245048
|
model-00003-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:107b8fb093af3137329e3388257bbb49e23ab052edda7ee5445b982299735ece
|
| 3 |
+
size 4937245048
|
model-00004-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:791be88d98bd776782ea7a035833546f60ff2219e62ae0873d9cddca33b775c4
|
| 3 |
+
size 4937245048
|
model-00005-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:af9cbb3ddea39a45fbd91353c098068120dd8918e7e5c562edb008e6d52f0562
|
| 3 |
+
size 3769099040
|
model.safetensors.index.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
modeling_iquestloopcoder.py
ADDED
|
@@ -0,0 +1,1113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Modified MIT License
|
| 3 |
+
|
| 4 |
+
Software Copyright© 2025 IQuest Research
|
| 5 |
+
|
| 6 |
+
Our only modification is that, if the Software (or any derivative works
|
| 7 |
+
thereof) is used for any of your commercial products or services, you shall
|
| 8 |
+
prominently display "IQuest Coder" on the user interface of such product or
|
| 9 |
+
service.
|
| 10 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 11 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 12 |
+
in the Software without restriction, including without limitation the rights
|
| 13 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 14 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 15 |
+
furnished to do so, subject to the following conditions:
|
| 16 |
+
|
| 17 |
+
The above copyright notice and this permission notice shall be included in all
|
| 18 |
+
copies or substantial portions of the Software.
|
| 19 |
+
|
| 20 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 21 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 22 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 23 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 24 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 25 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
import logging
|
| 29 |
+
from typing import Any, Callable, Optional, Union, Tuple, List
|
| 30 |
+
|
| 31 |
+
import torch
|
| 32 |
+
from torch import nn
|
| 33 |
+
|
| 34 |
+
from transformers.activations import ACT2FN
|
| 35 |
+
from transformers.cache_utils import Cache
|
| 36 |
+
from transformers.generation import GenerationMixin
|
| 37 |
+
from transformers.integrations import use_kernel_forward_from_hub
|
| 38 |
+
from transformers.masking_utils import (
|
| 39 |
+
create_causal_mask,
|
| 40 |
+
create_sliding_window_causal_mask,
|
| 41 |
+
)
|
| 42 |
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 43 |
+
from transformers.modeling_layers import (
|
| 44 |
+
GenericForQuestionAnswering,
|
| 45 |
+
GenericForSequenceClassification,
|
| 46 |
+
GenericForTokenClassification,
|
| 47 |
+
GradientCheckpointingLayer,
|
| 48 |
+
)
|
| 49 |
+
from transformers.modeling_outputs import (
|
| 50 |
+
BaseModelOutputWithPast,
|
| 51 |
+
CausalLMOutputWithPast,
|
| 52 |
+
)
|
| 53 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 54 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 55 |
+
from transformers.processing_utils import Unpack
|
| 56 |
+
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
|
| 57 |
+
from transformers.utils.generic import check_model_inputs
|
| 58 |
+
from .configuration_iquestloopcoder import IQuestLoopCoderConfig
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
logger = logging.getLogger(__name__)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def needs_iquestloopcoder_cache(
|
| 65 |
+
cache: Optional[Cache]
|
| 66 |
+
) -> bool:
|
| 67 |
+
# need to test more conditions
|
| 68 |
+
if cache is None:
|
| 69 |
+
return True
|
| 70 |
+
if isinstance(cache, IQuestLoopCoderCache):
|
| 71 |
+
return False
|
| 72 |
+
return True
|
| 73 |
+
|
| 74 |
+
class IQuestLoopCoderMLP(nn.Module):
|
| 75 |
+
def __init__(self, config):
|
| 76 |
+
super().__init__()
|
| 77 |
+
self.config = config
|
| 78 |
+
self.hidden_size = config.hidden_size
|
| 79 |
+
self.intermediate_size = config.intermediate_size
|
| 80 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 81 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 82 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 83 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 84 |
+
|
| 85 |
+
def forward(self, x):
|
| 86 |
+
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
| 87 |
+
return down_proj
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def rotate_half(x):
|
| 91 |
+
"""Rotates half the hidden dims of the input."""
|
| 92 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 93 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 94 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
| 98 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
| 99 |
+
|
| 100 |
+
Args:
|
| 101 |
+
q (`torch.Tensor`): The query tensor.
|
| 102 |
+
k (`torch.Tensor`): The key tensor.
|
| 103 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 104 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 105 |
+
position_ids (`torch.Tensor`, *optional*):
|
| 106 |
+
Deprecated and unused.
|
| 107 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 108 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 109 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 110 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 111 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 112 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 113 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 114 |
+
Returns:
|
| 115 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 116 |
+
"""
|
| 117 |
+
cos = cos.unsqueeze(unsqueeze_dim)
|
| 118 |
+
sin = sin.unsqueeze(unsqueeze_dim)
|
| 119 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 120 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 121 |
+
return q_embed, k_embed
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 125 |
+
"""
|
| 126 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 127 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 128 |
+
"""
|
| 129 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 130 |
+
if n_rep == 1:
|
| 131 |
+
return hidden_states
|
| 132 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(
|
| 133 |
+
batch, num_key_value_heads, n_rep, slen, head_dim
|
| 134 |
+
)
|
| 135 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class IQuestLoopCoderCache(Cache):
|
| 139 |
+
"""Cache implementation for IQuestLoopCoder that manages shared and local KV caches.
|
| 140 |
+
|
| 141 |
+
- shared_key_cache/shared_value_cache: Stores KV from Loop 1 (global context)
|
| 142 |
+
- local_key_cache/local_value_cache: Stores KV from Loop 2+ (local window, only window_size tokens)
|
| 143 |
+
"""
|
| 144 |
+
|
| 145 |
+
def __init__(self, window_size: int, num_layers: int, loop_num: int=2):
|
| 146 |
+
# We intentionally don't call super().__init__ because the parent assumes static cache sizes.
|
| 147 |
+
self.window_size = window_size
|
| 148 |
+
self.num_layers = num_layers
|
| 149 |
+
self.loop_num = loop_num
|
| 150 |
+
|
| 151 |
+
# Shared cache: stores Loop 1 KV (global context)
|
| 152 |
+
self.shared_key_cache: List[Optional[torch.Tensor]] = [None] * self.num_layers
|
| 153 |
+
self.shared_value_cache: List[Optional[torch.Tensor]] = [None] * self.num_layers
|
| 154 |
+
|
| 155 |
+
# Local cache: stores Loop 2+ KV (sliding window, only window_size tokens)
|
| 156 |
+
self.local_key_cache: List[Optional[torch.Tensor]] = [None] * (self.loop_num-1) * self.num_layers
|
| 157 |
+
self.local_value_cache: List[Optional[torch.Tensor]] = [None] * (self.loop_num-1) * self.num_layers
|
| 158 |
+
|
| 159 |
+
self.layers: List[Any] = [] # attribute expected by HF Cache utilities
|
| 160 |
+
self._seen_tokens = 0
|
| 161 |
+
|
| 162 |
+
def update_shared(
|
| 163 |
+
self,
|
| 164 |
+
key_states: torch.Tensor,
|
| 165 |
+
value_states: torch.Tensor,
|
| 166 |
+
layer_idx: int,
|
| 167 |
+
cache_kwargs: Optional[dict] = None,
|
| 168 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 169 |
+
"""Update shared cache (Loop 1 KV)."""
|
| 170 |
+
# only store the first loop's kv cache
|
| 171 |
+
loop_idx = cache_kwargs.get("loop_idx", 0)
|
| 172 |
+
assert loop_idx == 0
|
| 173 |
+
if layer_idx < 0 or layer_idx >= self.num_layers:
|
| 174 |
+
raise ValueError(f"layer_idx must be in [0, {self.num_layers}), got {layer_idx}")
|
| 175 |
+
|
| 176 |
+
cached_key = self.shared_key_cache[layer_idx]
|
| 177 |
+
cached_value = self.shared_value_cache[layer_idx]
|
| 178 |
+
|
| 179 |
+
if cached_key is None:
|
| 180 |
+
self.shared_key_cache[layer_idx] = key_states
|
| 181 |
+
self.shared_value_cache[layer_idx] = value_states
|
| 182 |
+
else:
|
| 183 |
+
if (
|
| 184 |
+
key_states.shape[0] != cached_key.shape[0]
|
| 185 |
+
or key_states.shape[1] != cached_key.shape[1]
|
| 186 |
+
or key_states.shape[3] != cached_key.shape[3]
|
| 187 |
+
):
|
| 188 |
+
raise ValueError(
|
| 189 |
+
"Cached and incoming key/value tensors must match on batch, head, and head_dim dimensions."
|
| 190 |
+
)
|
| 191 |
+
assert key_states.shape[2] == 1
|
| 192 |
+
assert value_states.shape[2] == 1
|
| 193 |
+
self.shared_key_cache[layer_idx] = torch.cat([cached_key, key_states], dim=2)
|
| 194 |
+
self.shared_value_cache[layer_idx] = torch.cat([cached_value, value_states], dim=2)
|
| 195 |
+
|
| 196 |
+
result_key = self.shared_key_cache[layer_idx]
|
| 197 |
+
result_value = self.shared_value_cache[layer_idx]
|
| 198 |
+
assert result_key is not None and result_value is not None
|
| 199 |
+
|
| 200 |
+
# Track sequence length
|
| 201 |
+
self._seen_tokens = result_key.shape[2]
|
| 202 |
+
return result_key, result_value
|
| 203 |
+
|
| 204 |
+
def update_local(
|
| 205 |
+
self,
|
| 206 |
+
key_states: torch.Tensor,
|
| 207 |
+
value_states: torch.Tensor,
|
| 208 |
+
layer_idx: int,
|
| 209 |
+
cache_kwargs: Optional[dict] = None,
|
| 210 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 211 |
+
"""Update local cache (Loop 2+ KV) with sliding window management.
|
| 212 |
+
|
| 213 |
+
Ensures the local cache always contains at most window_size tokens.
|
| 214 |
+
Local cache only stores loop_idx > 0 (i.e., loop_idx = 1, 2, ...).
|
| 215 |
+
For loop_idx = 1, cache_idx = layer_idx + 0 * num_layers = layer_idx (0 to num_layers-1)
|
| 216 |
+
For loop_idx = 2, cache_idx = layer_idx + 1 * num_layers (num_layers to 2*num_layers-1)
|
| 217 |
+
"""
|
| 218 |
+
# only store the local kv cache for loop_idx > 0
|
| 219 |
+
loop_idx = cache_kwargs.get("loop_idx", 0)
|
| 220 |
+
assert loop_idx > 0, f"update_local should only be called for loop_idx > 0, got {loop_idx}"
|
| 221 |
+
if layer_idx < 0 or layer_idx >= self.num_layers:
|
| 222 |
+
raise ValueError(f"layer_idx must be in [0, {self.num_layers}), got {layer_idx}")
|
| 223 |
+
|
| 224 |
+
# Local cache size is (loop_num-1) * num_layers
|
| 225 |
+
# loop_idx = 1 maps to indices 0 to num_layers-1
|
| 226 |
+
# loop_idx = 2 maps to indices num_layers to 2*num_layers-1
|
| 227 |
+
# So offset = (loop_idx - 1) * num_layers
|
| 228 |
+
cache_idx = layer_idx + (loop_idx - 1) * self.num_layers
|
| 229 |
+
|
| 230 |
+
# Validate cache_idx is within bounds
|
| 231 |
+
max_cache_idx = (self.loop_num - 1) * self.num_layers
|
| 232 |
+
if cache_idx >= max_cache_idx:
|
| 233 |
+
raise IndexError(
|
| 234 |
+
f"cache_idx {cache_idx} out of range. "
|
| 235 |
+
f"loop_idx={loop_idx}, layer_idx={layer_idx}, "
|
| 236 |
+
f"max_cache_idx={max_cache_idx - 1}"
|
| 237 |
+
)
|
| 238 |
+
cached_key = self.local_key_cache[cache_idx]
|
| 239 |
+
cached_value = self.local_value_cache[cache_idx]
|
| 240 |
+
|
| 241 |
+
if cached_key is None:
|
| 242 |
+
# First token in local cache, for prefill
|
| 243 |
+
# If prefill sequence is longer than window_size, only keep the last window_size tokens
|
| 244 |
+
seq_len = key_states.shape[2]
|
| 245 |
+
if seq_len > self.window_size:
|
| 246 |
+
# Keep only the last window_size tokens
|
| 247 |
+
start_idx = seq_len - self.window_size
|
| 248 |
+
self.local_key_cache[cache_idx] = key_states[:, :, start_idx:, :]
|
| 249 |
+
self.local_value_cache[cache_idx] = value_states[:, :, start_idx:, :]
|
| 250 |
+
else:
|
| 251 |
+
self.local_key_cache[cache_idx] = key_states
|
| 252 |
+
self.local_value_cache[cache_idx] = value_states
|
| 253 |
+
else:
|
| 254 |
+
# store the local kv cache for decode
|
| 255 |
+
if (
|
| 256 |
+
key_states.shape[0] != cached_key.shape[0]
|
| 257 |
+
or key_states.shape[1] != cached_key.shape[1]
|
| 258 |
+
or key_states.shape[3] != cached_key.shape[3]
|
| 259 |
+
):
|
| 260 |
+
raise ValueError(
|
| 261 |
+
"Cached and incoming key/value tensors must match on batch, head, and head_dim dimensions."
|
| 262 |
+
)
|
| 263 |
+
assert cached_value is not None
|
| 264 |
+
assert key_states.shape[2] == 1
|
| 265 |
+
assert value_states.shape[2] == 1
|
| 266 |
+
# Concatenate new tokens
|
| 267 |
+
new_key = torch.cat([cached_key, key_states], dim=2)
|
| 268 |
+
new_value = torch.cat([cached_value, value_states], dim=2)
|
| 269 |
+
|
| 270 |
+
# Ensure the total length doesn't exceed window_size
|
| 271 |
+
total_len = new_key.shape[2]
|
| 272 |
+
if total_len > self.window_size:
|
| 273 |
+
# Keep only the last window_size tokens
|
| 274 |
+
self.local_key_cache[cache_idx] = new_key[:, :, -self.window_size:, :]
|
| 275 |
+
self.local_value_cache[cache_idx] = new_value[:, :, -self.window_size:, :]
|
| 276 |
+
else:
|
| 277 |
+
self.local_key_cache[cache_idx] = new_key
|
| 278 |
+
self.local_value_cache[cache_idx] = new_value
|
| 279 |
+
|
| 280 |
+
result_key = self.local_key_cache[cache_idx]
|
| 281 |
+
result_value = self.local_value_cache[cache_idx]
|
| 282 |
+
assert result_key is not None and result_value is not None
|
| 283 |
+
# Ensure the result is at most window_size (can be less during prefill when sequence is shorter)
|
| 284 |
+
assert result_key.shape[2] <= self.window_size, f"Local cache size {result_key.shape[2]} exceeds window_size {self.window_size}"
|
| 285 |
+
|
| 286 |
+
return result_key, result_value
|
| 287 |
+
|
| 288 |
+
def get_shared(self, layer_idx: int|List[int]) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
| 289 |
+
"""Get shared cache for some layer."""
|
| 290 |
+
if isinstance(layer_idx, list):
|
| 291 |
+
return [self.get_shared(layer_idx) for layer_idx in layer_idx]
|
| 292 |
+
if layer_idx < 0 or layer_idx >= self.num_layers:
|
| 293 |
+
raise ValueError(f"layer_idx must be in [0, {self.num_layers}), got {layer_idx}")
|
| 294 |
+
return self.shared_key_cache[layer_idx], self.shared_value_cache[layer_idx]
|
| 295 |
+
|
| 296 |
+
def get_local(self, layer_idx: int|List[int], loop_idx: int) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
| 297 |
+
"""Get local cache for a layer."""
|
| 298 |
+
assert loop_idx > 0, f"get_local should only be called for loop_idx > 0, got {loop_idx}"
|
| 299 |
+
if isinstance(layer_idx, list):
|
| 300 |
+
return [self.get_local(layer_idx, loop_idx) for layer_idx in layer_idx]
|
| 301 |
+
if layer_idx < 0 or layer_idx >= self.num_layers:
|
| 302 |
+
raise ValueError(f"layer_idx must be in [0, {self.num_layers}), got {layer_idx}")
|
| 303 |
+
|
| 304 |
+
# Local cache size is (loop_num-1) * num_layers
|
| 305 |
+
# loop_idx = 1 maps to indices 0 to num_layers-1
|
| 306 |
+
# loop_idx = 2 maps to indices num_layers to 2*num_layers-1
|
| 307 |
+
# So offset = (loop_idx - 1) * num_layers
|
| 308 |
+
cache_idx = layer_idx + (loop_idx - 1) * self.num_layers
|
| 309 |
+
|
| 310 |
+
# Validate cache_idx is within bounds
|
| 311 |
+
max_cache_idx = (self.loop_num - 1) * self.num_layers
|
| 312 |
+
if cache_idx >= max_cache_idx:
|
| 313 |
+
raise IndexError(
|
| 314 |
+
f"cache_idx {cache_idx} out of range. "
|
| 315 |
+
f"loop_idx={loop_idx}, layer_idx={layer_idx}, "
|
| 316 |
+
f"max_cache_idx={max_cache_idx - 1}"
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
return self.local_key_cache[cache_idx], self.local_value_cache[cache_idx]
|
| 320 |
+
|
| 321 |
+
def update(
|
| 322 |
+
self,
|
| 323 |
+
key_states: torch.Tensor,
|
| 324 |
+
value_states: torch.Tensor,
|
| 325 |
+
layer_idx: int,
|
| 326 |
+
cache_kwargs: Optional[dict] = None,
|
| 327 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 328 |
+
"""Default update method (for compatibility, updates shared cache)."""
|
| 329 |
+
loop_idx = cache_kwargs.get("loop_idx", 0)
|
| 330 |
+
assert loop_idx < self.loop_num
|
| 331 |
+
if loop_idx == 0:
|
| 332 |
+
return self.update_shared(key_states, value_states, layer_idx, cache_kwargs)
|
| 333 |
+
else:
|
| 334 |
+
return self.update_local(key_states, value_states, layer_idx, cache_kwargs)
|
| 335 |
+
|
| 336 |
+
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
| 337 |
+
"""Get sequence length from shared cache."""
|
| 338 |
+
if layer_idx is None:
|
| 339 |
+
layer_idx = 0
|
| 340 |
+
if layer_idx < 0 or layer_idx >= self.loop_num * self.num_layers:
|
| 341 |
+
return 0
|
| 342 |
+
cached_key = self.shared_key_cache[layer_idx]
|
| 343 |
+
if cached_key is None:
|
| 344 |
+
return 0
|
| 345 |
+
return cached_key.shape[2]
|
| 346 |
+
|
| 347 |
+
def get_max_length(self) -> Optional[int]:
|
| 348 |
+
return None
|
| 349 |
+
|
| 350 |
+
def get_usable_length(
|
| 351 |
+
self, new_seq_length: int, layer_idx: Optional[int] = 0
|
| 352 |
+
) -> int:
|
| 353 |
+
return self.get_seq_length(layer_idx)
|
| 354 |
+
|
| 355 |
+
def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
|
| 356 |
+
# pass
|
| 357 |
+
raise NotImplementedError("Reorder cache for beam search is not implemented")
|
| 358 |
+
"""Reorder cache for beam search.
|
| 359 |
+
|
| 360 |
+
Reorders both shared cache (Loop 1) and local cache (Loop 2+) according to beam_idx.
|
| 361 |
+
"""
|
| 362 |
+
# Reorder shared cache (Loop 1, loop_idx=0)
|
| 363 |
+
for layer_idx in range(self.num_layers):
|
| 364 |
+
if self.shared_key_cache[layer_idx] is not None:
|
| 365 |
+
device = self.shared_key_cache[layer_idx].device
|
| 366 |
+
self.shared_key_cache[layer_idx] = self.shared_key_cache[layer_idx].index_select(0, beam_idx.to(device))
|
| 367 |
+
self.shared_value_cache[layer_idx] = self.shared_value_cache[layer_idx].index_select(0, beam_idx.to(device))
|
| 368 |
+
|
| 369 |
+
# Reorder local cache (Loop 2+, loop_idx > 0)
|
| 370 |
+
# Local cache size is (loop_num-1) * num_layers
|
| 371 |
+
for cache_idx in range(len(self.local_key_cache)):
|
| 372 |
+
if self.local_key_cache[cache_idx] is not None:
|
| 373 |
+
device = self.local_key_cache[cache_idx].device
|
| 374 |
+
self.local_key_cache[cache_idx] = self.local_key_cache[cache_idx].index_select(0, beam_idx.to(device))
|
| 375 |
+
self.local_value_cache[cache_idx] = self.local_value_cache[cache_idx].index_select(0, beam_idx.to(device))
|
| 376 |
+
|
| 377 |
+
@property
|
| 378 |
+
def is_compileable(self) -> bool:
|
| 379 |
+
return False
|
| 380 |
+
|
| 381 |
+
def clear(self) -> None:
|
| 382 |
+
"""Clear all caches."""
|
| 383 |
+
logger.debug("Clearing IQuestLoopCoderCache")
|
| 384 |
+
self.shared_key_cache = [None] * self.num_layers
|
| 385 |
+
self.shared_value_cache = [None] * self.num_layers
|
| 386 |
+
self.local_key_cache = [None] * self.num_layers * (self.loop_num-1)
|
| 387 |
+
self.local_value_cache = [None] * self.num_layers * (self.loop_num-1)
|
| 388 |
+
self._seen_tokens = 0
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def eager_attention_forward(
|
| 392 |
+
module: nn.Module,
|
| 393 |
+
query: torch.Tensor,
|
| 394 |
+
key: torch.Tensor,
|
| 395 |
+
value: torch.Tensor,
|
| 396 |
+
attention_mask: Optional[torch.Tensor],
|
| 397 |
+
scaling: float,
|
| 398 |
+
dropout: float = 0.0,
|
| 399 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 400 |
+
):
|
| 401 |
+
key_states = repeat_kv(key, module.num_key_value_groups)
|
| 402 |
+
value_states = repeat_kv(value, module.num_key_value_groups)
|
| 403 |
+
|
| 404 |
+
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
|
| 405 |
+
if attention_mask is not None:
|
| 406 |
+
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
| 407 |
+
attn_weights = attn_weights + causal_mask
|
| 408 |
+
|
| 409 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
|
| 410 |
+
query.dtype
|
| 411 |
+
)
|
| 412 |
+
attn_weights = nn.functional.dropout(
|
| 413 |
+
attn_weights, p=dropout, training=module.training
|
| 414 |
+
)
|
| 415 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 416 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 417 |
+
|
| 418 |
+
return attn_output, attn_weights
|
| 419 |
+
|
| 420 |
+
class LoopGateProjection(nn.Module):
|
| 421 |
+
"""Gate projection for mixed attention in Loop 2+.
|
| 422 |
+
|
| 423 |
+
Computes: g = sigmoid(linear(Q)) for each head independently.
|
| 424 |
+
This gate determines how much to use Loop1's KV (global) vs current loop's KV (local).
|
| 425 |
+
"""
|
| 426 |
+
|
| 427 |
+
def __init__(self, num_heads: int, head_dim: int):
|
| 428 |
+
super().__init__()
|
| 429 |
+
self.num_heads = num_heads
|
| 430 |
+
self.head_dim = head_dim
|
| 431 |
+
# Each head has its own gate: Linear(head_dim -> 1) per head
|
| 432 |
+
# Implemented as [num_heads, head_dim] weight + [num_heads] bias
|
| 433 |
+
self.weight = nn.Parameter(torch.zeros(num_heads, head_dim))
|
| 434 |
+
self.bias = nn.Parameter(torch.zeros(num_heads))
|
| 435 |
+
|
| 436 |
+
def forward(self, query: torch.Tensor) -> torch.Tensor:
|
| 437 |
+
"""Compute gate values from query tensor.
|
| 438 |
+
|
| 439 |
+
Args:
|
| 440 |
+
query: [batch, num_heads, seq_len, head_dim]
|
| 441 |
+
|
| 442 |
+
Returns:
|
| 443 |
+
gate: [batch, num_heads, seq_len, 1]
|
| 444 |
+
"""
|
| 445 |
+
# query: [batch, num_heads, seq_len, head_dim]
|
| 446 |
+
# weight: [num_heads, head_dim]
|
| 447 |
+
# For each head h: gate_h = query[:, h, :, :] @ weight[h, :].T + bias[h]
|
| 448 |
+
# Using einsum: gate = einsum('bhsd,hd->bhs', query, weight) + bias
|
| 449 |
+
gate_logits = torch.einsum('bhsd,hd->bhs', query, self.weight) # [batch, num_heads, seq_len]
|
| 450 |
+
gate_logits = gate_logits + self.bias[None, :, None] # broadcast bias
|
| 451 |
+
gate = torch.sigmoid(gate_logits)
|
| 452 |
+
return gate.unsqueeze(-1) # [batch, num_heads, seq_len, 1]
|
| 453 |
+
|
| 454 |
+
class IQuestLoopCoderAttention(nn.Module):
|
| 455 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 456 |
+
|
| 457 |
+
def __init__(self, config: IQuestLoopCoderConfig, layer_idx: int):
|
| 458 |
+
super().__init__()
|
| 459 |
+
self.config = config
|
| 460 |
+
assert layer_idx >= 0 and layer_idx < config.num_hidden_layers
|
| 461 |
+
self.layer_idx = layer_idx
|
| 462 |
+
|
| 463 |
+
self.head_dim = getattr(
|
| 464 |
+
config, "head_dim", config.hidden_size // config.num_attention_heads
|
| 465 |
+
)
|
| 466 |
+
self.num_key_value_groups = (
|
| 467 |
+
config.num_attention_heads // config.num_key_value_heads
|
| 468 |
+
)
|
| 469 |
+
self.scaling = self.head_dim**-0.5
|
| 470 |
+
self.attention_dropout = config.attention_dropout
|
| 471 |
+
self.is_causal = True
|
| 472 |
+
self.q_proj = nn.Linear(
|
| 473 |
+
config.hidden_size, config.num_attention_heads * self.head_dim, bias=False
|
| 474 |
+
)
|
| 475 |
+
self.k_proj = nn.Linear(
|
| 476 |
+
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False
|
| 477 |
+
)
|
| 478 |
+
self.v_proj = nn.Linear(
|
| 479 |
+
config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False
|
| 480 |
+
)
|
| 481 |
+
self.o_proj = nn.Linear(
|
| 482 |
+
config.num_attention_heads * self.head_dim, config.hidden_size, bias=False
|
| 483 |
+
)
|
| 484 |
+
|
| 485 |
+
def forward(
|
| 486 |
+
self,
|
| 487 |
+
hidden_states: torch.Tensor,
|
| 488 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 489 |
+
attention_mask: Optional[torch.Tensor],
|
| 490 |
+
past_key_value: Optional[Cache] = None,
|
| 491 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 492 |
+
loop_idx: int = 0,
|
| 493 |
+
gate_proj: Optional[LoopGateProjection] = None,
|
| 494 |
+
**kwargs: Unpack[FlashAttentionKwargs],
|
| 495 |
+
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
| 496 |
+
if loop_idx == 0:
|
| 497 |
+
return self.forward_loop1(hidden_states, loop_idx, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs)
|
| 498 |
+
else:
|
| 499 |
+
return self.forward_loop2(hidden_states, loop_idx, position_embeddings, attention_mask, past_key_value, cache_position, gate_proj, **kwargs)
|
| 500 |
+
|
| 501 |
+
def forward_loop1(
|
| 502 |
+
self,
|
| 503 |
+
hidden_states: torch.Tensor,
|
| 504 |
+
loop_idx: int,
|
| 505 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 506 |
+
attention_mask: Optional[torch.Tensor],
|
| 507 |
+
past_key_value: Optional[IQuestLoopCoderCache] = None,
|
| 508 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 509 |
+
**kwargs: Unpack[FlashAttentionKwargs]) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
| 510 |
+
input_shape = hidden_states.shape[:-1]
|
| 511 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 512 |
+
|
| 513 |
+
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 514 |
+
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 515 |
+
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 516 |
+
|
| 517 |
+
cos, sin = position_embeddings
|
| 518 |
+
query_states, key_states = apply_rotary_pos_emb(
|
| 519 |
+
query_states, key_states, cos, sin
|
| 520 |
+
)
|
| 521 |
+
|
| 522 |
+
if past_key_value is not None:
|
| 523 |
+
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 524 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position, "loop_idx": loop_idx}
|
| 525 |
+
key_states, value_states = past_key_value.update(
|
| 526 |
+
key_states,
|
| 527 |
+
value_states,
|
| 528 |
+
self.layer_idx,
|
| 529 |
+
cache_kwargs,
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
attention_interface: Callable = eager_attention_forward
|
| 533 |
+
if self.config._attn_implementation != "eager":
|
| 534 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[
|
| 535 |
+
self.config._attn_implementation
|
| 536 |
+
]
|
| 537 |
+
|
| 538 |
+
attn_output, attn_weights = attention_interface(
|
| 539 |
+
self,
|
| 540 |
+
query_states,
|
| 541 |
+
key_states,
|
| 542 |
+
value_states,
|
| 543 |
+
attention_mask,
|
| 544 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 545 |
+
scaling=self.scaling,
|
| 546 |
+
**kwargs,
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 550 |
+
attn_output = self.o_proj(attn_output)
|
| 551 |
+
return attn_output, (attn_weights)
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
def forward_loop2(
|
| 555 |
+
self,
|
| 556 |
+
hidden_states: torch.Tensor,
|
| 557 |
+
loop_idx: int,
|
| 558 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 559 |
+
attention_mask: Optional[torch.Tensor],
|
| 560 |
+
past_key_value: Optional[IQuestLoopCoderCache] = None,
|
| 561 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 562 |
+
gate_proj: Optional[LoopGateProjection] = None,
|
| 563 |
+
**kwargs: Unpack[FlashAttentionKwargs]) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
|
| 564 |
+
|
| 565 |
+
input_shape = hidden_states.shape[:-1]
|
| 566 |
+
hidden_shape = (*input_shape, -1, self.head_dim)
|
| 567 |
+
|
| 568 |
+
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 569 |
+
key_states_local = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 570 |
+
value_states_local = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
|
| 571 |
+
|
| 572 |
+
cos, sin = position_embeddings
|
| 573 |
+
query_states, key_states_local = apply_rotary_pos_emb(
|
| 574 |
+
query_states, key_states_local, cos, sin
|
| 575 |
+
)
|
| 576 |
+
|
| 577 |
+
key_states_share, value_states_share = None, None
|
| 578 |
+
if past_key_value is not None:
|
| 579 |
+
# get key_share, value_share from past_key_value
|
| 580 |
+
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position, "loop_idx": loop_idx}
|
| 581 |
+
key_states_share, value_states_share = past_key_value.get_shared(self.layer_idx)
|
| 582 |
+
key_states_local, value_states_local = past_key_value.update(
|
| 583 |
+
key_states_local,
|
| 584 |
+
value_states_local,
|
| 585 |
+
self.layer_idx,
|
| 586 |
+
cache_kwargs,
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
attention_interface: Callable = eager_attention_forward
|
| 590 |
+
if self.config._attn_implementation != "eager":
|
| 591 |
+
attention_interface = ALL_ATTENTION_FUNCTIONS[
|
| 592 |
+
self.config._attn_implementation
|
| 593 |
+
]
|
| 594 |
+
|
| 595 |
+
# Create masks for global and local attention
|
| 596 |
+
# Global attention: full causal mask (can see all tokens in shared cache)
|
| 597 |
+
# Local attention: causal mask for local window (can only see window_size tokens in local cache)
|
| 598 |
+
attention_mask_global = attention_mask # Use full causal mask for global attention
|
| 599 |
+
|
| 600 |
+
# For local attention, create a mask that matches the local cache size
|
| 601 |
+
# The local cache already contains only the last window_size tokens,
|
| 602 |
+
# so we need a causal mask that allows attention within this window
|
| 603 |
+
attention_mask_local = None
|
| 604 |
+
if key_states_local is not None and value_states_local is not None:
|
| 605 |
+
# Local cache has shape [batch, num_heads, local_seq_len, head_dim]
|
| 606 |
+
# where local_seq_len <= window_size
|
| 607 |
+
local_seq_len = key_states_local.shape[2]
|
| 608 |
+
bsz = query_states.shape[0]
|
| 609 |
+
q_len = query_states.shape[2]
|
| 610 |
+
|
| 611 |
+
# Create a causal mask for local attention
|
| 612 |
+
# This allows each query position to attend to all positions up to and including itself
|
| 613 |
+
# within the local window (which is already the last window_size tokens)
|
| 614 |
+
device = query_states.device
|
| 615 |
+
dtype = query_states.dtype
|
| 616 |
+
|
| 617 |
+
if attention_mask is not None:
|
| 618 |
+
# If we have a global mask, we need to adapt it for local attention
|
| 619 |
+
# The global mask shape is [batch, 1, q_len, global_kv_len]
|
| 620 |
+
# For local attention, we only need the last local_seq_len positions
|
| 621 |
+
global_kv_len = attention_mask.shape[-1]
|
| 622 |
+
|
| 623 |
+
if global_kv_len >= local_seq_len:
|
| 624 |
+
# Extract the last local_seq_len columns from the global mask
|
| 625 |
+
# This represents attention to the last window_size tokens
|
| 626 |
+
attention_mask_local = attention_mask[..., -local_seq_len:]
|
| 627 |
+
else:
|
| 628 |
+
# If global mask is shorter than local_seq_len, create a simple causal mask
|
| 629 |
+
# This can happen during prefill when local cache is being built
|
| 630 |
+
attention_mask_local = torch.triu(
|
| 631 |
+
torch.ones((q_len, local_seq_len), device=device, dtype=dtype) * float("-inf"),
|
| 632 |
+
diagonal=1
|
| 633 |
+
).unsqueeze(0).expand(bsz, -1, -1, -1) # [batch, 1, q_len, local_seq_len]
|
| 634 |
+
else:
|
| 635 |
+
# No global mask provided, create a simple causal mask for local attention
|
| 636 |
+
# This allows full attention within the local window (causal)
|
| 637 |
+
attention_mask_local = torch.triu(
|
| 638 |
+
torch.ones((q_len, local_seq_len), device=device, dtype=dtype) * float("-inf"),
|
| 639 |
+
diagonal=1
|
| 640 |
+
).unsqueeze(0).expand(bsz, -1, -1, -1) # [batch, 1, q_len, local_seq_len]
|
| 641 |
+
|
| 642 |
+
# global attn: attend to all tokens in shared cache
|
| 643 |
+
attn_output_global, attn_weights_global = attention_interface(
|
| 644 |
+
self,
|
| 645 |
+
query_states,
|
| 646 |
+
key_states_share,
|
| 647 |
+
value_states_share,
|
| 648 |
+
attention_mask_global,
|
| 649 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 650 |
+
scaling=self.scaling,
|
| 651 |
+
**kwargs,
|
| 652 |
+
)
|
| 653 |
+
|
| 654 |
+
# local attn: attend only to tokens in local cache (window_size)
|
| 655 |
+
attn_output_local, attn_weights_local = attention_interface(
|
| 656 |
+
self,
|
| 657 |
+
query_states,
|
| 658 |
+
key_states_local,
|
| 659 |
+
value_states_local,
|
| 660 |
+
attention_mask_local,
|
| 661 |
+
dropout=0.0 if not self.training else self.attention_dropout,
|
| 662 |
+
scaling=self.scaling,
|
| 663 |
+
**kwargs,
|
| 664 |
+
)
|
| 665 |
+
|
| 666 |
+
# attention_interface returns [batch, seq_len, num_heads, head_dim] for eager_attention_forward
|
| 667 |
+
# but Flash Attention might return [batch, num_heads, seq_len, head_dim]
|
| 668 |
+
# We need [batch, num_heads, seq_len, head_dim] to match gate shape
|
| 669 |
+
q_len = query_states.shape[2] # Query sequence length
|
| 670 |
+
num_heads = query_states.shape[1]
|
| 671 |
+
|
| 672 |
+
# Normalize attn_output_global to [batch, num_heads, q_len, head_dim]
|
| 673 |
+
if attn_output_global.dim() == 4:
|
| 674 |
+
# Check if shape is [batch, seq_len, num_heads, head_dim] (eager) or [batch, num_heads, seq_len, head_dim] (flash)
|
| 675 |
+
if attn_output_global.shape[1] == q_len:
|
| 676 |
+
# Shape is [batch, seq_len, num_heads, head_dim], transpose to [batch, num_heads, seq_len, head_dim]
|
| 677 |
+
attn_output_global = attn_output_global.transpose(1, 2)
|
| 678 |
+
# Ensure sequence length matches query length (take first q_len tokens)
|
| 679 |
+
if attn_output_global.shape[2] > q_len:
|
| 680 |
+
attn_output_global = attn_output_global[:, :, :q_len, :]
|
| 681 |
+
elif attn_output_global.shape[2] < q_len:
|
| 682 |
+
# This shouldn't happen, but handle it gracefully
|
| 683 |
+
raise ValueError(f"attn_output_global seq_len {attn_output_global.shape[2]} < q_len {q_len}")
|
| 684 |
+
|
| 685 |
+
# Normalize attn_output_local to [batch, num_heads, q_len, head_dim]
|
| 686 |
+
if attn_output_local.dim() == 4:
|
| 687 |
+
# Check if shape is [batch, seq_len, num_heads, head_dim] (eager) or [batch, num_heads, seq_len, head_dim] (flash)
|
| 688 |
+
if attn_output_local.shape[1] == q_len:
|
| 689 |
+
# Shape is [batch, seq_len, num_heads, head_dim], transpose to [batch, num_heads, seq_len, head_dim]
|
| 690 |
+
attn_output_local = attn_output_local.transpose(1, 2)
|
| 691 |
+
# Ensure sequence length matches query length (take first q_len tokens)
|
| 692 |
+
if attn_output_local.shape[2] > q_len:
|
| 693 |
+
attn_output_local = attn_output_local[:, :, :q_len, :]
|
| 694 |
+
elif attn_output_local.shape[2] < q_len:
|
| 695 |
+
# This shouldn't happen, but handle it gracefully
|
| 696 |
+
raise ValueError(f"attn_output_local seq_len {attn_output_local.shape[2]} < q_len {q_len}")
|
| 697 |
+
|
| 698 |
+
assert gate_proj is not None
|
| 699 |
+
gate = gate_proj(query_states) # [batch, num_heads, seq_len, 1]
|
| 700 |
+
mixed_attn_output = attn_output_local * (1 - gate) + attn_output_global * gate
|
| 701 |
+
|
| 702 |
+
mixed_attn_output = mixed_attn_output.reshape(*input_shape, -1).contiguous()
|
| 703 |
+
mixed_attn_output = self.o_proj(mixed_attn_output)
|
| 704 |
+
return mixed_attn_output, (attn_weights_global, attn_weights_local, attn_output_global, attn_output_local, gate)
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
@use_kernel_forward_from_hub("RMSNorm")
|
| 708 |
+
class IQuestLoopCoderRMSNorm(nn.Module):
|
| 709 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 710 |
+
"""
|
| 711 |
+
IQuestLoopCoderRMSNorm is equivalent to T5LayerNorm
|
| 712 |
+
"""
|
| 713 |
+
super().__init__()
|
| 714 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 715 |
+
self.variance_epsilon = eps
|
| 716 |
+
|
| 717 |
+
def forward(self, hidden_states):
|
| 718 |
+
input_dtype = hidden_states.dtype
|
| 719 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 720 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 721 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 722 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 723 |
+
|
| 724 |
+
def extra_repr(self):
|
| 725 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
class IQuestLoopCoderDecoderLayer(GradientCheckpointingLayer):
|
| 729 |
+
def __init__(self, config: IQuestLoopCoderConfig, layer_idx: int):
|
| 730 |
+
super().__init__()
|
| 731 |
+
self.hidden_size = config.hidden_size
|
| 732 |
+
|
| 733 |
+
self.self_attn = IQuestLoopCoderAttention(config=config, layer_idx=layer_idx)
|
| 734 |
+
|
| 735 |
+
self.mlp = IQuestLoopCoderMLP(config)
|
| 736 |
+
self.input_layernorm = IQuestLoopCoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 737 |
+
self.post_attention_layernorm = IQuestLoopCoderRMSNorm(
|
| 738 |
+
config.hidden_size, eps=config.rms_norm_eps
|
| 739 |
+
)
|
| 740 |
+
self.layer_idx = layer_idx
|
| 741 |
+
|
| 742 |
+
def forward(
|
| 743 |
+
self,
|
| 744 |
+
hidden_states: torch.Tensor,
|
| 745 |
+
loop_idx: int = 0,
|
| 746 |
+
gate_proj: Optional[LoopGateProjection] = None,
|
| 747 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 748 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 749 |
+
past_key_value: Optional[Cache] = None,
|
| 750 |
+
use_cache: Optional[bool] = False,
|
| 751 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 752 |
+
position_embeddings: Optional[
|
| 753 |
+
tuple[torch.Tensor, torch.Tensor]
|
| 754 |
+
] = None, # necessary, but kept here for BC
|
| 755 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 756 |
+
) -> tuple[torch.Tensor]:
|
| 757 |
+
residual = hidden_states
|
| 758 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 759 |
+
# Self Attention
|
| 760 |
+
hidden_states, _ = self.self_attn(
|
| 761 |
+
hidden_states=hidden_states,
|
| 762 |
+
attention_mask=attention_mask,
|
| 763 |
+
position_ids=position_ids,
|
| 764 |
+
past_key_value=past_key_value,
|
| 765 |
+
use_cache=use_cache,
|
| 766 |
+
cache_position=cache_position,
|
| 767 |
+
loop_idx=loop_idx,
|
| 768 |
+
position_embeddings=position_embeddings,
|
| 769 |
+
gate_proj=gate_proj if loop_idx > 0 else None,
|
| 770 |
+
**kwargs,
|
| 771 |
+
)
|
| 772 |
+
|
| 773 |
+
hidden_states = residual + hidden_states
|
| 774 |
+
|
| 775 |
+
# Fully Connected
|
| 776 |
+
residual = hidden_states
|
| 777 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 778 |
+
hidden_states = self.mlp(hidden_states)
|
| 779 |
+
hidden_states = residual + hidden_states
|
| 780 |
+
return hidden_states
|
| 781 |
+
|
| 782 |
+
|
| 783 |
+
@auto_docstring
|
| 784 |
+
class IQuestLoopCoderPreTrainedModel(PreTrainedModel):
|
| 785 |
+
config: IQuestLoopCoderConfig
|
| 786 |
+
base_model_prefix = "model"
|
| 787 |
+
supports_gradient_checkpointing = True
|
| 788 |
+
_no_split_modules = ["IQuestLoopCoderDecoderLayer"]
|
| 789 |
+
_skip_keys_device_placement = ["past_key_values"]
|
| 790 |
+
_supports_flash_attn = True
|
| 791 |
+
_supports_sdpa = True
|
| 792 |
+
_supports_flex_attn = True
|
| 793 |
+
|
| 794 |
+
_can_compile_fullgraph = True
|
| 795 |
+
_supports_attention_backend = True
|
| 796 |
+
_can_record_outputs = {
|
| 797 |
+
"hidden_states": IQuestLoopCoderDecoderLayer,
|
| 798 |
+
"attentions": IQuestLoopCoderAttention,
|
| 799 |
+
}
|
| 800 |
+
|
| 801 |
+
# Important for inference with `device_map` / low_cpu_mem_usage:
|
| 802 |
+
# Avoid initializing parameters that are not present in the checkpoint.
|
| 803 |
+
# Those should keep their constructor-time initialization (e.g. zeros for LoopGateProjection),
|
| 804 |
+
# instead of being materialized from meta/empty tensors which can contain NaNs.
|
| 805 |
+
def _init_weights(self, module: nn.Module) -> None:
|
| 806 |
+
return
|
| 807 |
+
|
| 808 |
+
|
| 809 |
+
class IQuestLoopCoderRotaryEmbedding(nn.Module):
|
| 810 |
+
def __init__(self, config: IQuestLoopCoderConfig, device=None):
|
| 811 |
+
super().__init__()
|
| 812 |
+
# BC: "rope_type" was originally "type"
|
| 813 |
+
if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
|
| 814 |
+
self.rope_type = config.rope_scaling.get(
|
| 815 |
+
"rope_type", config.rope_scaling.get("type")
|
| 816 |
+
)
|
| 817 |
+
else:
|
| 818 |
+
self.rope_type = "default"
|
| 819 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 820 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 821 |
+
|
| 822 |
+
self.config = config
|
| 823 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 824 |
+
|
| 825 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
| 826 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 827 |
+
self.original_inv_freq = self.inv_freq
|
| 828 |
+
|
| 829 |
+
@torch.no_grad()
|
| 830 |
+
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
|
| 831 |
+
def forward(self, x, position_ids):
|
| 832 |
+
inv_freq_expanded = (
|
| 833 |
+
self.inv_freq[None, :, None]
|
| 834 |
+
.float()
|
| 835 |
+
.expand(position_ids.shape[0], -1, 1)
|
| 836 |
+
.to(x.device)
|
| 837 |
+
)
|
| 838 |
+
position_ids_expanded = position_ids[:, None, :].float()
|
| 839 |
+
|
| 840 |
+
device_type = (
|
| 841 |
+
x.device.type
|
| 842 |
+
if isinstance(x.device.type, str) and x.device.type != "mps"
|
| 843 |
+
else "cpu"
|
| 844 |
+
)
|
| 845 |
+
with torch.autocast(device_type=device_type, enabled=False): # Force float32
|
| 846 |
+
freqs = (
|
| 847 |
+
inv_freq_expanded.float() @ position_ids_expanded.float()
|
| 848 |
+
).transpose(1, 2)
|
| 849 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 850 |
+
cos = emb.cos() * self.attention_scaling
|
| 851 |
+
sin = emb.sin() * self.attention_scaling
|
| 852 |
+
|
| 853 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 854 |
+
|
| 855 |
+
|
| 856 |
+
@auto_docstring
|
| 857 |
+
class IQuestLoopCoderModel(IQuestLoopCoderPreTrainedModel):
|
| 858 |
+
def __init__(self, config: IQuestLoopCoderConfig):
|
| 859 |
+
super().__init__(config)
|
| 860 |
+
self.padding_idx = config.pad_token_id
|
| 861 |
+
self.vocab_size = config.vocab_size
|
| 862 |
+
|
| 863 |
+
self.embed_tokens = nn.Embedding(
|
| 864 |
+
config.vocab_size, config.hidden_size, self.padding_idx
|
| 865 |
+
)
|
| 866 |
+
self.layers = nn.ModuleList(
|
| 867 |
+
[
|
| 868 |
+
IQuestLoopCoderDecoderLayer(config, layer_idx)
|
| 869 |
+
for layer_idx in range(config.num_hidden_layers)
|
| 870 |
+
]
|
| 871 |
+
)
|
| 872 |
+
self.norm = IQuestLoopCoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 873 |
+
self.rotary_emb = IQuestLoopCoderRotaryEmbedding(config=config)
|
| 874 |
+
self.gradient_checkpointing = False
|
| 875 |
+
self.loop_num = getattr(self.config, "loop_num", 2)
|
| 876 |
+
self.loop_window_size = getattr(self.config, "loop_window_size", 64)
|
| 877 |
+
|
| 878 |
+
# Gate projections for Loop 2+ (one per layer)
|
| 879 |
+
self.gate_projections = nn.ModuleList([
|
| 880 |
+
LoopGateProjection(config.num_attention_heads, config.head_dim)
|
| 881 |
+
for _ in range(config.num_hidden_layers)
|
| 882 |
+
])
|
| 883 |
+
|
| 884 |
+
# Initialize weights and apply final processing
|
| 885 |
+
self.post_init()
|
| 886 |
+
|
| 887 |
+
@check_model_inputs
|
| 888 |
+
@auto_docstring
|
| 889 |
+
def forward(
|
| 890 |
+
self,
|
| 891 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 892 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 893 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 894 |
+
past_key_values: Optional[Cache] = None,
|
| 895 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 896 |
+
use_cache: Optional[bool] = None,
|
| 897 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 898 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 899 |
+
) -> BaseModelOutputWithPast:
|
| 900 |
+
|
| 901 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 902 |
+
raise ValueError(
|
| 903 |
+
"You must specify exactly one of input_ids or inputs_embeds"
|
| 904 |
+
)
|
| 905 |
+
|
| 906 |
+
if inputs_embeds is None:
|
| 907 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
| 908 |
+
|
| 909 |
+
if use_cache is None:
|
| 910 |
+
use_cache = self.config.use_cache
|
| 911 |
+
|
| 912 |
+
if use_cache:
|
| 913 |
+
if needs_iquestloopcoder_cache(past_key_values):
|
| 914 |
+
past_key_values = IQuestLoopCoderCache(self.loop_window_size, self.config.num_hidden_layers, self.loop_num)
|
| 915 |
+
|
| 916 |
+
if cache_position is None:
|
| 917 |
+
past_seen_tokens = (
|
| 918 |
+
past_key_values.get_seq_length() if past_key_values is not None else 0
|
| 919 |
+
)
|
| 920 |
+
cache_position = torch.arange(
|
| 921 |
+
past_seen_tokens,
|
| 922 |
+
past_seen_tokens + inputs_embeds.shape[1],
|
| 923 |
+
device=inputs_embeds.device,
|
| 924 |
+
)
|
| 925 |
+
|
| 926 |
+
if position_ids is None:
|
| 927 |
+
position_ids = cache_position.unsqueeze(0)
|
| 928 |
+
|
| 929 |
+
# It may already have been prepared by e.g. `generate`
|
| 930 |
+
if not isinstance(causal_mask_mapping := attention_mask, dict):
|
| 931 |
+
# Prepare mask arguments
|
| 932 |
+
mask_kwargs = {
|
| 933 |
+
"config": self.config,
|
| 934 |
+
"input_embeds": inputs_embeds,
|
| 935 |
+
"attention_mask": attention_mask,
|
| 936 |
+
"cache_position": cache_position,
|
| 937 |
+
"past_key_values": past_key_values,
|
| 938 |
+
"position_ids": position_ids,
|
| 939 |
+
}
|
| 940 |
+
# Create the full causal mask for all layers
|
| 941 |
+
# All layers use full_attention (no sliding window layers)
|
| 942 |
+
full_attention_mask = create_causal_mask(**mask_kwargs)
|
| 943 |
+
causal_mask_mapping = {
|
| 944 |
+
"full_attention": full_attention_mask,
|
| 945 |
+
}
|
| 946 |
+
|
| 947 |
+
hidden_states = inputs_embeds
|
| 948 |
+
|
| 949 |
+
# create position embeddings to be shared across the decoder layers
|
| 950 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 951 |
+
hidden_states_list = []
|
| 952 |
+
|
| 953 |
+
for loop_idx in range(self.loop_num):
|
| 954 |
+
# For each loop, use the full_attention mask
|
| 955 |
+
# Loop 1: uses full_attention mask directly
|
| 956 |
+
# Loop 2+: forward_loop2 will create local mask internally, but uses full_attention mask for global attention
|
| 957 |
+
loop_attention_mask = causal_mask_mapping["full_attention"]
|
| 958 |
+
|
| 959 |
+
for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
|
| 960 |
+
hidden_states = decoder_layer(
|
| 961 |
+
hidden_states,
|
| 962 |
+
loop_idx,
|
| 963 |
+
gate_proj=self.gate_projections[layer_idx] if loop_idx > 0 else None,
|
| 964 |
+
attention_mask=loop_attention_mask,
|
| 965 |
+
position_ids=position_ids,
|
| 966 |
+
past_key_value=past_key_values,
|
| 967 |
+
use_cache=use_cache,
|
| 968 |
+
cache_position=cache_position,
|
| 969 |
+
position_embeddings=position_embeddings,
|
| 970 |
+
**kwargs,
|
| 971 |
+
)
|
| 972 |
+
if loop_idx < self.loop_num - 1:
|
| 973 |
+
hidden_states_list.append(hidden_states)
|
| 974 |
+
|
| 975 |
+
hidden_states = self.norm(hidden_states)
|
| 976 |
+
hidden_states_list.append(hidden_states)
|
| 977 |
+
|
| 978 |
+
return (
|
| 979 |
+
BaseModelOutputWithPast(
|
| 980 |
+
last_hidden_state=hidden_states,
|
| 981 |
+
past_key_values=past_key_values if use_cache else None,
|
| 982 |
+
),
|
| 983 |
+
hidden_states_list,
|
| 984 |
+
)
|
| 985 |
+
|
| 986 |
+
|
| 987 |
+
@auto_docstring
|
| 988 |
+
class IQuestLoopCoderForCausalLM(IQuestLoopCoderPreTrainedModel, GenerationMixin):
|
| 989 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 990 |
+
_tp_plan = {"lm_head": "colwise_rep"}
|
| 991 |
+
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
|
| 992 |
+
|
| 993 |
+
def __init__(self, config):
|
| 994 |
+
super().__init__(config)
|
| 995 |
+
self.model = IQuestLoopCoderModel(config)
|
| 996 |
+
self.vocab_size = config.vocab_size
|
| 997 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 998 |
+
|
| 999 |
+
# 分块大小配置
|
| 1000 |
+
self.chunk_size = getattr(config, "chunk_size", 2) # 默认分块大小为2
|
| 1001 |
+
|
| 1002 |
+
self.post_init()
|
| 1003 |
+
|
| 1004 |
+
def get_input_embeddings(self):
|
| 1005 |
+
return self.model.embed_tokens
|
| 1006 |
+
|
| 1007 |
+
def set_input_embeddings(self, value):
|
| 1008 |
+
self.model.embed_tokens = value
|
| 1009 |
+
|
| 1010 |
+
def get_output_embeddings(self):
|
| 1011 |
+
return self.lm_head
|
| 1012 |
+
|
| 1013 |
+
def set_output_embeddings(self, new_embeddings):
|
| 1014 |
+
self.lm_head = new_embeddings
|
| 1015 |
+
|
| 1016 |
+
def set_decoder(self, decoder):
|
| 1017 |
+
self.model = decoder
|
| 1018 |
+
|
| 1019 |
+
def get_decoder(self):
|
| 1020 |
+
return self.model
|
| 1021 |
+
|
| 1022 |
+
@can_return_tuple
|
| 1023 |
+
@auto_docstring
|
| 1024 |
+
def forward(
|
| 1025 |
+
self,
|
| 1026 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 1027 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1028 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1029 |
+
past_key_values: Optional[Cache] = None,
|
| 1030 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1031 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1032 |
+
use_cache: Optional[bool] = None,
|
| 1033 |
+
cache_position: Optional[torch.LongTensor] = None,
|
| 1034 |
+
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 1035 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 1036 |
+
) -> CausalLMOutputWithPast:
|
| 1037 |
+
|
| 1038 |
+
outputs, hidden_states_list = self.model(
|
| 1039 |
+
input_ids=input_ids,
|
| 1040 |
+
attention_mask=attention_mask,
|
| 1041 |
+
position_ids=position_ids,
|
| 1042 |
+
past_key_values=past_key_values,
|
| 1043 |
+
inputs_embeds=inputs_embeds,
|
| 1044 |
+
use_cache=use_cache,
|
| 1045 |
+
cache_position=cache_position,
|
| 1046 |
+
**kwargs,
|
| 1047 |
+
)
|
| 1048 |
+
slice_indices = (
|
| 1049 |
+
slice(-logits_to_keep, None)
|
| 1050 |
+
if isinstance(logits_to_keep, int)
|
| 1051 |
+
else logits_to_keep
|
| 1052 |
+
)
|
| 1053 |
+
|
| 1054 |
+
def _select_token_positions(tensor: torch.Tensor) -> torch.Tensor:
|
| 1055 |
+
if isinstance(slice_indices, slice):
|
| 1056 |
+
return tensor[:, slice_indices, ...]
|
| 1057 |
+
if isinstance(slice_indices, torch.Tensor):
|
| 1058 |
+
return tensor.index_select(1, slice_indices.to(tensor.device))
|
| 1059 |
+
raise TypeError(
|
| 1060 |
+
f"Unsupported index type for logits_to_keep: {type(slice_indices)}"
|
| 1061 |
+
)
|
| 1062 |
+
|
| 1063 |
+
stacked_exit_pdf = None
|
| 1064 |
+
|
| 1065 |
+
expected_logits_cache: Optional[torch.Tensor] = None
|
| 1066 |
+
|
| 1067 |
+
def compute_expected_logits() -> Optional[torch.Tensor]:
|
| 1068 |
+
nonlocal expected_logits_cache
|
| 1069 |
+
if expected_logits_cache is not None:
|
| 1070 |
+
return expected_logits_cache
|
| 1071 |
+
if stacked_exit_pdf is None or not hidden_states_list:
|
| 1072 |
+
return None
|
| 1073 |
+
token_exit_pdf = _select_token_positions(stacked_exit_pdf)
|
| 1074 |
+
expected_logits = None
|
| 1075 |
+
for step_idx, hidden in enumerate(hidden_states_list):
|
| 1076 |
+
step_hidden = _select_token_positions(hidden)
|
| 1077 |
+
step_logits = self.lm_head(step_hidden)
|
| 1078 |
+
weight = (
|
| 1079 |
+
token_exit_pdf[..., step_idx].unsqueeze(-1).to(step_logits.dtype)
|
| 1080 |
+
)
|
| 1081 |
+
expected_logits = (
|
| 1082 |
+
step_logits * weight
|
| 1083 |
+
if expected_logits is None
|
| 1084 |
+
else expected_logits + step_logits * weight
|
| 1085 |
+
)
|
| 1086 |
+
expected_logits_cache = expected_logits
|
| 1087 |
+
return expected_logits_cache
|
| 1088 |
+
|
| 1089 |
+
logits: Optional[torch.Tensor] = None
|
| 1090 |
+
loss: Optional[torch.Tensor] = None
|
| 1091 |
+
|
| 1092 |
+
hidden_states = outputs.last_hidden_state
|
| 1093 |
+
logits = self.lm_head(hidden_states)
|
| 1094 |
+
logits = logits.float()
|
| 1095 |
+
|
| 1096 |
+
if labels is not None:
|
| 1097 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 1098 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 1099 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 1100 |
+
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
| 1101 |
+
shift_labels = shift_labels.view(-1)
|
| 1102 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
| 1103 |
+
loss = loss_fct(shift_logits, shift_labels)
|
| 1104 |
+
|
| 1105 |
+
result = CausalLMOutputWithPast(
|
| 1106 |
+
loss=loss,
|
| 1107 |
+
logits=logits,
|
| 1108 |
+
past_key_values=outputs.past_key_values,
|
| 1109 |
+
hidden_states=outputs.hidden_states,
|
| 1110 |
+
attentions=outputs.attentions,
|
| 1111 |
+
)
|
| 1112 |
+
|
| 1113 |
+
return result
|
papers/iquest-coder-v1-logo.png
ADDED
|
Git LFS Details
|
papers/results.png
ADDED
|
Git LFS Details
|
recipe.yaml
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
default_stage:
|
| 2 |
+
default_modifiers:
|
| 3 |
+
AWQModifier:
|
| 4 |
+
config_groups:
|
| 5 |
+
group_0:
|
| 6 |
+
targets: [Linear]
|
| 7 |
+
weights:
|
| 8 |
+
num_bits: 4
|
| 9 |
+
type: int
|
| 10 |
+
symmetric: true
|
| 11 |
+
group_size: 32
|
| 12 |
+
strategy: group
|
| 13 |
+
block_structure: null
|
| 14 |
+
dynamic: false
|
| 15 |
+
actorder: null
|
| 16 |
+
scale_dtype: null
|
| 17 |
+
zp_dtype: null
|
| 18 |
+
observer: mse
|
| 19 |
+
observer_kwargs: {}
|
| 20 |
+
input_activations: null
|
| 21 |
+
output_activations: null
|
| 22 |
+
format: null
|
| 23 |
+
targets: [Linear]
|
| 24 |
+
ignore: [model.embed_tokens, 're:.*gate_projections.*', model.norm, lm_head, 're:.*rotary_emb.*']
|
| 25 |
+
mappings:
|
| 26 |
+
- smooth_layer: re:.*input_layernorm$
|
| 27 |
+
balance_layers: ['re:.*q_proj$', 're:.*k_proj$', 're:.*v_proj$']
|
| 28 |
+
- smooth_layer: re:.*v_proj$
|
| 29 |
+
balance_layers: ['re:.*o_proj$']
|
| 30 |
+
- smooth_layer: re:.*post_attention_layernorm$
|
| 31 |
+
balance_layers: ['re:.*gate_proj$', 're:.*up_proj$']
|
| 32 |
+
- smooth_layer: re:.*up_proj$
|
| 33 |
+
balance_layers: ['re:.*down_proj$']
|
| 34 |
+
duo_scaling: true
|
| 35 |
+
n_grid: 20
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"additional_special_tokens": [
|
| 3 |
+
"<|CLS|>",
|
| 4 |
+
"<|SEP|>",
|
| 5 |
+
"<|EOD|>",
|
| 6 |
+
"<|MASK|>",
|
| 7 |
+
"<|PAD|>",
|
| 8 |
+
"<|fim_prefix|>",
|
| 9 |
+
"<|fim_middle|>",
|
| 10 |
+
"<|fim_suffix|>",
|
| 11 |
+
"<|im_start|>",
|
| 12 |
+
"<|im_end|>",
|
| 13 |
+
"<|fim_pad|>",
|
| 14 |
+
"<|endoftext|>",
|
| 15 |
+
"<|repo_name|>",
|
| 16 |
+
"<|file_sep|>",
|
| 17 |
+
"<think>",
|
| 18 |
+
"</think>"
|
| 19 |
+
],
|
| 20 |
+
"bos_token": {
|
| 21 |
+
"content": "<s>",
|
| 22 |
+
"lstrip": false,
|
| 23 |
+
"normalized": true,
|
| 24 |
+
"rstrip": false,
|
| 25 |
+
"single_word": false
|
| 26 |
+
},
|
| 27 |
+
"eos_token": {
|
| 28 |
+
"content": "<|im_end|>",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": false,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false
|
| 33 |
+
},
|
| 34 |
+
"pad_token": {
|
| 35 |
+
"content": "<|endoftext|>",
|
| 36 |
+
"lstrip": false,
|
| 37 |
+
"normalized": false,
|
| 38 |
+
"rstrip": false,
|
| 39 |
+
"single_word": false
|
| 40 |
+
},
|
| 41 |
+
"unk_token": {
|
| 42 |
+
"content": "<unk>",
|
| 43 |
+
"lstrip": false,
|
| 44 |
+
"normalized": true,
|
| 45 |
+
"rstrip": false,
|
| 46 |
+
"single_word": true
|
| 47 |
+
}
|
| 48 |
+
}
|
tokenization_iquestcoder.py
ADDED
|
@@ -0,0 +1,552 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tokenization classes for IQuestCoder."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from shutil import copyfile
|
| 5 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 6 |
+
|
| 7 |
+
import sentencepiece as spm
|
| 8 |
+
|
| 9 |
+
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
|
| 10 |
+
from transformers.utils import logging
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
logger = logging.get_logger(__name__)
|
| 14 |
+
|
| 15 |
+
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
|
| 16 |
+
|
| 17 |
+
PRETRAINED_VOCAB_FILES_MAP = {
|
| 18 |
+
"vocab_file": {},
|
| 19 |
+
"tokenizer_file": {},
|
| 20 |
+
}
|
| 21 |
+
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class IQuestCoderTokenizer(PreTrainedTokenizer):
|
| 26 |
+
|
| 27 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 28 |
+
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
| 29 |
+
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
| 30 |
+
model_input_names = ["input_ids", "attention_mask"]
|
| 31 |
+
|
| 32 |
+
def __init__(
|
| 33 |
+
self,
|
| 34 |
+
vocab_file,
|
| 35 |
+
unk_token="<unk>",
|
| 36 |
+
bos_token="<s>",
|
| 37 |
+
eos_token="</s>",
|
| 38 |
+
pad_token=None,
|
| 39 |
+
sp_model_kwargs: Optional[Dict[str, Any]] = None,
|
| 40 |
+
add_bos_token=True,
|
| 41 |
+
add_eos_token=False,
|
| 42 |
+
clean_up_tokenization_spaces=False,
|
| 43 |
+
add_prefix_space=False,
|
| 44 |
+
legacy=None,
|
| 45 |
+
use_default_system_prompt=False,
|
| 46 |
+
chat_template=None,
|
| 47 |
+
**kwargs,
|
| 48 |
+
):
|
| 49 |
+
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
|
| 50 |
+
bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
|
| 51 |
+
eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
|
| 52 |
+
unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token
|
| 53 |
+
pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
|
| 54 |
+
|
| 55 |
+
# Legacy behavior handling
|
| 56 |
+
if legacy is None:
|
| 57 |
+
logger.warning_once(
|
| 58 |
+
f"You are using the default legacy behaviour of the {self.__class__.__name__}. This is"
|
| 59 |
+
" expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you."
|
| 60 |
+
" If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it"
|
| 61 |
+
" means, and thoroughly read the reason why this was added as explained in"
|
| 62 |
+
" https://github.com/huggingface/transformers/pull/24565"
|
| 63 |
+
)
|
| 64 |
+
legacy = True
|
| 65 |
+
|
| 66 |
+
self.legacy = legacy
|
| 67 |
+
self.vocab_file = vocab_file
|
| 68 |
+
self.add_bos_token = add_bos_token
|
| 69 |
+
self.add_eos_token = add_eos_token
|
| 70 |
+
self.add_prefix_space = add_prefix_space
|
| 71 |
+
self.use_default_system_prompt = use_default_system_prompt
|
| 72 |
+
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
| 73 |
+
self.sp_model.Load(vocab_file)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
super().__init__(
|
| 78 |
+
bos_token=bos_token,
|
| 79 |
+
eos_token=eos_token,
|
| 80 |
+
unk_token=unk_token,
|
| 81 |
+
pad_token=pad_token,
|
| 82 |
+
add_bos_token=add_bos_token,
|
| 83 |
+
add_eos_token=add_eos_token,
|
| 84 |
+
sp_model_kwargs=self.sp_model_kwargs,
|
| 85 |
+
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
| 86 |
+
add_prefix_space=add_prefix_space,
|
| 87 |
+
legacy=legacy,
|
| 88 |
+
use_default_system_prompt=use_default_system_prompt,
|
| 89 |
+
chat_template=chat_template,
|
| 90 |
+
**kwargs,
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
def __getstate__(self):
|
| 94 |
+
state = self.__dict__.copy()
|
| 95 |
+
state["sp_model"] = None
|
| 96 |
+
return state
|
| 97 |
+
|
| 98 |
+
def __setstate__(self, d):
|
| 99 |
+
self.__dict__ = d
|
| 100 |
+
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
| 101 |
+
self.sp_model.Load(self.vocab_file)
|
| 102 |
+
|
| 103 |
+
@property
|
| 104 |
+
def vocab_size(self) -> int:
|
| 105 |
+
"""Returns the vocabulary size."""
|
| 106 |
+
return self.sp_model.get_piece_size()
|
| 107 |
+
|
| 108 |
+
def get_vocab(self) -> Dict[str, int]:
|
| 109 |
+
"""Returns the vocabulary as a dictionary of token to index."""
|
| 110 |
+
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
| 111 |
+
vocab.update(self.added_tokens_encoder)
|
| 112 |
+
return vocab
|
| 113 |
+
|
| 114 |
+
def _tokenize(self, text: str) -> List[str]:
|
| 115 |
+
"""
|
| 116 |
+
Tokenize a string.
|
| 117 |
+
|
| 118 |
+
Args:
|
| 119 |
+
text (`str`): The text to tokenize.
|
| 120 |
+
|
| 121 |
+
Returns:
|
| 122 |
+
`List[str]`: The list of tokens.
|
| 123 |
+
"""
|
| 124 |
+
if self.add_prefix_space:
|
| 125 |
+
text = " " + text
|
| 126 |
+
|
| 127 |
+
if self.legacy:
|
| 128 |
+
return self.sp_model.encode(text, out_type=str)
|
| 129 |
+
|
| 130 |
+
# Non-legacy behavior: handle special tokens properly
|
| 131 |
+
return self.sp_model.encode(text, out_type=str)
|
| 132 |
+
|
| 133 |
+
def _convert_token_to_id(self, token: str) -> int:
|
| 134 |
+
"""Converts a token (str) to an id using the vocab."""
|
| 135 |
+
return self.sp_model.piece_to_id(token)
|
| 136 |
+
|
| 137 |
+
def _convert_id_to_token(self, index: int) -> str:
|
| 138 |
+
"""Converts an index (integer) to a token (str) using the vocab."""
|
| 139 |
+
token = self.sp_model.IdToPiece(index)
|
| 140 |
+
return token
|
| 141 |
+
|
| 142 |
+
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
| 143 |
+
"""
|
| 144 |
+
Converts a sequence of tokens (strings) to a single string.
|
| 145 |
+
|
| 146 |
+
This method handles special tokens separately to ensure they are not
|
| 147 |
+
decoded using the SentencePiece model.
|
| 148 |
+
|
| 149 |
+
Args:
|
| 150 |
+
tokens (`List[str]`): The list of tokens to convert.
|
| 151 |
+
|
| 152 |
+
Returns:
|
| 153 |
+
`str`: The decoded string.
|
| 154 |
+
"""
|
| 155 |
+
current_sub_tokens = []
|
| 156 |
+
out_string = ""
|
| 157 |
+
prev_is_special = False
|
| 158 |
+
for i, token in enumerate(tokens):
|
| 159 |
+
# make sure that special tokens are not decoded using sentencepiece model
|
| 160 |
+
if token in self.all_special_tokens:
|
| 161 |
+
if not prev_is_special and i != 0:
|
| 162 |
+
out_string += " "
|
| 163 |
+
out_string += self.sp_model.decode(current_sub_tokens) + token
|
| 164 |
+
prev_is_special = True
|
| 165 |
+
current_sub_tokens = []
|
| 166 |
+
else:
|
| 167 |
+
current_sub_tokens.append(token)
|
| 168 |
+
prev_is_special = False
|
| 169 |
+
out_string += self.sp_model.decode(current_sub_tokens)
|
| 170 |
+
return out_string
|
| 171 |
+
|
| 172 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 173 |
+
"""
|
| 174 |
+
Save the vocabulary and special tokens file to a directory.
|
| 175 |
+
|
| 176 |
+
Args:
|
| 177 |
+
save_directory (`str`):
|
| 178 |
+
The directory in which to save the vocabulary.
|
| 179 |
+
filename_prefix (`str`, *optional*):
|
| 180 |
+
An optional prefix to add to the named of the saved files.
|
| 181 |
+
|
| 182 |
+
Returns:
|
| 183 |
+
`Tuple(str)`: Paths to the files saved.
|
| 184 |
+
"""
|
| 185 |
+
if not os.path.isdir(save_directory):
|
| 186 |
+
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
| 187 |
+
return
|
| 188 |
+
out_vocab_file = os.path.join(
|
| 189 |
+
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
|
| 193 |
+
copyfile(self.vocab_file, out_vocab_file)
|
| 194 |
+
elif not os.path.isfile(self.vocab_file):
|
| 195 |
+
with open(out_vocab_file, "wb") as fi:
|
| 196 |
+
content_spiece_model = self.sp_model.serialized_model_proto()
|
| 197 |
+
fi.write(content_spiece_model)
|
| 198 |
+
|
| 199 |
+
return (out_vocab_file,)
|
| 200 |
+
|
| 201 |
+
def build_inputs_with_special_tokens(
|
| 202 |
+
self,
|
| 203 |
+
token_ids_0: List[int],
|
| 204 |
+
token_ids_1: Optional[List[int]] = None
|
| 205 |
+
) -> List[int]:
|
| 206 |
+
"""
|
| 207 |
+
Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
|
| 208 |
+
and adding special tokens.
|
| 209 |
+
|
| 210 |
+
An IQuestCoder sequence has the following format:
|
| 211 |
+
|
| 212 |
+
- single sequence: `<s> X </s>` (if add_eos_token is True) or `<s> X` (default)
|
| 213 |
+
- pair of sequences: `<s> A </s> <s> B </s>` (if add_eos_token is True) or `<s> A <s> B` (default)
|
| 214 |
+
|
| 215 |
+
Args:
|
| 216 |
+
token_ids_0 (`List[int]`):
|
| 217 |
+
List of IDs to which the special tokens will be added.
|
| 218 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 219 |
+
Optional second list of IDs for sequence pairs.
|
| 220 |
+
|
| 221 |
+
Returns:
|
| 222 |
+
`List[int]`: List of input IDs with the appropriate special tokens.
|
| 223 |
+
"""
|
| 224 |
+
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
| 225 |
+
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
| 226 |
+
|
| 227 |
+
output = bos_token_id + token_ids_0 + eos_token_id
|
| 228 |
+
|
| 229 |
+
if token_ids_1 is not None:
|
| 230 |
+
output = output + bos_token_id + token_ids_1 + eos_token_id
|
| 231 |
+
|
| 232 |
+
return output
|
| 233 |
+
|
| 234 |
+
def get_special_tokens_mask(
|
| 235 |
+
self,
|
| 236 |
+
token_ids_0: List[int],
|
| 237 |
+
token_ids_1: Optional[List[int]] = None,
|
| 238 |
+
already_has_special_tokens: bool = False
|
| 239 |
+
) -> List[int]:
|
| 240 |
+
"""
|
| 241 |
+
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
| 242 |
+
special tokens using the tokenizer `prepare_for_model` method.
|
| 243 |
+
|
| 244 |
+
Args:
|
| 245 |
+
token_ids_0 (`List[int]`):
|
| 246 |
+
List of IDs.
|
| 247 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 248 |
+
Optional second list of IDs for sequence pairs.
|
| 249 |
+
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
| 250 |
+
Whether or not the token list is already formatted with special tokens for the model.
|
| 251 |
+
|
| 252 |
+
Returns:
|
| 253 |
+
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
| 254 |
+
"""
|
| 255 |
+
if already_has_special_tokens:
|
| 256 |
+
return super().get_special_tokens_mask(
|
| 257 |
+
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
bos_token_id = [1] if self.add_bos_token else []
|
| 261 |
+
eos_token_id = [1] if self.add_eos_token else []
|
| 262 |
+
|
| 263 |
+
if token_ids_1 is None:
|
| 264 |
+
return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
|
| 265 |
+
return (
|
| 266 |
+
bos_token_id
|
| 267 |
+
+ ([0] * len(token_ids_0))
|
| 268 |
+
+ eos_token_id
|
| 269 |
+
+ bos_token_id
|
| 270 |
+
+ ([0] * len(token_ids_1))
|
| 271 |
+
+ eos_token_id
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
def create_token_type_ids_from_sequences(
|
| 275 |
+
self,
|
| 276 |
+
token_ids_0: List[int],
|
| 277 |
+
token_ids_1: Optional[List[int]] = None
|
| 278 |
+
) -> List[int]:
|
| 279 |
+
"""
|
| 280 |
+
Create a mask from the two sequences passed to be used in a sequence-pair classification task.
|
| 281 |
+
|
| 282 |
+
An IQuestCoder sequence pair mask has the following format:
|
| 283 |
+
|
| 284 |
+
```
|
| 285 |
+
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
| 286 |
+
| first sequence | second sequence |
|
| 287 |
+
```
|
| 288 |
+
|
| 289 |
+
If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
|
| 290 |
+
|
| 291 |
+
Args:
|
| 292 |
+
token_ids_0 (`List[int]`):
|
| 293 |
+
List of IDs.
|
| 294 |
+
token_ids_1 (`List[int]`, *optional*):
|
| 295 |
+
Optional second list of IDs for sequence pairs.
|
| 296 |
+
|
| 297 |
+
Returns:
|
| 298 |
+
`List[int]`: List of token type IDs according to the given sequence(s).
|
| 299 |
+
"""
|
| 300 |
+
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
| 301 |
+
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
| 302 |
+
|
| 303 |
+
output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
|
| 304 |
+
|
| 305 |
+
if token_ids_1 is not None:
|
| 306 |
+
output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
|
| 307 |
+
|
| 308 |
+
return output
|
| 309 |
+
|
| 310 |
+
@property
|
| 311 |
+
def default_chat_template(self) -> str:
|
| 312 |
+
"""
|
| 313 |
+
Returns the default chat template for IQuestCoder.
|
| 314 |
+
|
| 315 |
+
This template formats conversations with system, user, and assistant roles.
|
| 316 |
+
"""
|
| 317 |
+
return DEFAULT_CHAT_TEMPLATE
|
| 318 |
+
|
| 319 |
+
def apply_chat_template(
|
| 320 |
+
self,
|
| 321 |
+
conversation: Union[List[Dict[str, str]], "Conversation"],
|
| 322 |
+
chat_template: Optional[str] = None,
|
| 323 |
+
add_generation_prompt: bool = False,
|
| 324 |
+
tokenize: bool = True,
|
| 325 |
+
padding: bool = False,
|
| 326 |
+
truncation: bool = False,
|
| 327 |
+
max_length: Optional[int] = None,
|
| 328 |
+
return_tensors: Optional[str] = None,
|
| 329 |
+
return_dict: bool = False,
|
| 330 |
+
**tokenizer_kwargs,
|
| 331 |
+
):
|
| 332 |
+
"""
|
| 333 |
+
Apply a chat template to format a conversation.
|
| 334 |
+
|
| 335 |
+
Args:
|
| 336 |
+
conversation (`List[Dict[str, str]]` or `Conversation`):
|
| 337 |
+
A list of dicts with "role" and "content" keys, representing the conversation history.
|
| 338 |
+
chat_template (`str`, *optional*):
|
| 339 |
+
A Jinja template to use for formatting. If not provided, the tokenizer's default will be used.
|
| 340 |
+
add_generation_prompt (`bool`, *optional*, defaults to `False`):
|
| 341 |
+
Whether to add a generation prompt at the end for the assistant to continue.
|
| 342 |
+
tokenize (`bool`, *optional*, defaults to `True`):
|
| 343 |
+
Whether to tokenize the output. If `False`, returns a string.
|
| 344 |
+
padding (`bool`, *optional*, defaults to `False`):
|
| 345 |
+
Whether to pad sequences.
|
| 346 |
+
truncation (`bool`, *optional*, defaults to `False`):
|
| 347 |
+
Whether to truncate sequences.
|
| 348 |
+
max_length (`int`, *optional*):
|
| 349 |
+
Maximum length of the output.
|
| 350 |
+
return_tensors (`str`, *optional*):
|
| 351 |
+
The type of tensors to return ("pt", "tf", "np", or None).
|
| 352 |
+
return_dict (`bool`, *optional*, defaults to `False`):
|
| 353 |
+
Whether to return a dictionary with additional information.
|
| 354 |
+
**tokenizer_kwargs:
|
| 355 |
+
Additional keyword arguments passed to the tokenizer.
|
| 356 |
+
|
| 357 |
+
Returns:
|
| 358 |
+
`Union[str, List[int], BatchEncoding]`: The formatted (and optionally tokenized) conversation.
|
| 359 |
+
|
| 360 |
+
Example:
|
| 361 |
+
```python
|
| 362 |
+
>>> tokenizer = IQuestCoderTokenizer.from_pretrained("path/to/model")
|
| 363 |
+
>>> conversation = [
|
| 364 |
+
... {"role": "system", "content": "You are a helpful assistant."},
|
| 365 |
+
... {"role": "user", "content": "Hello!"},
|
| 366 |
+
... {"role": "assistant", "content": "Hi there! How can I help you today?"},
|
| 367 |
+
... {"role": "user", "content": "What's the weather like?"},
|
| 368 |
+
... ]
|
| 369 |
+
>>> tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
|
| 370 |
+
'<|system|>\\nYou are a helpful assistant.\\n</|system|><|user|>\\nHello!\\n</|user|>...'
|
| 371 |
+
```
|
| 372 |
+
"""
|
| 373 |
+
# Use parent class implementation with our template
|
| 374 |
+
return super().apply_chat_template(
|
| 375 |
+
conversation,
|
| 376 |
+
chat_template=chat_template,
|
| 377 |
+
add_generation_prompt=add_generation_prompt,
|
| 378 |
+
tokenize=tokenize,
|
| 379 |
+
padding=padding,
|
| 380 |
+
truncation=truncation,
|
| 381 |
+
max_length=max_length,
|
| 382 |
+
return_tensors=return_tensors,
|
| 383 |
+
return_dict=return_dict,
|
| 384 |
+
**tokenizer_kwargs,
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
# Try to import and create Fast tokenizer version
|
| 389 |
+
try:
|
| 390 |
+
from transformers import PreTrainedTokenizerFast
|
| 391 |
+
from tokenizers import Tokenizer, decoders, models, normalizers, pre_tokenizers, processors
|
| 392 |
+
|
| 393 |
+
class IQuestCoderTokenizerFast(PreTrainedTokenizerFast):
|
| 394 |
+
"""
|
| 395 |
+
Construct a "fast" IQuestCoder tokenizer (backed by HuggingFace's *tokenizers* library).
|
| 396 |
+
|
| 397 |
+
This is a fast implementation of [`IQuestCoderTokenizer`] using the 🤗 Tokenizers library.
|
| 398 |
+
|
| 399 |
+
Args:
|
| 400 |
+
vocab_file (`str`, *optional*):
|
| 401 |
+
Path to the vocabulary file (SentencePiece model).
|
| 402 |
+
tokenizer_file (`str`, *optional*):
|
| 403 |
+
Path to a tokenizer JSON file.
|
| 404 |
+
unk_token (`str`, *optional*, defaults to `"<unk>"`):
|
| 405 |
+
The unknown token.
|
| 406 |
+
bos_token (`str`, *optional*, defaults to `"<s>"`):
|
| 407 |
+
The beginning of sequence token.
|
| 408 |
+
eos_token (`str`, *optional*, defaults to `"</s>"`):
|
| 409 |
+
The end of sequence token.
|
| 410 |
+
pad_token (`str`, *optional*):
|
| 411 |
+
The token used for padding.
|
| 412 |
+
add_bos_token (`bool`, *optional*, defaults to `True`):
|
| 413 |
+
Whether to add a BOS token at the start of sequences.
|
| 414 |
+
add_eos_token (`bool`, *optional*, defaults to `False`):
|
| 415 |
+
Whether to add an EOS token at the end of sequences.
|
| 416 |
+
add_prefix_space (`bool`, *optional*, defaults to `False`):
|
| 417 |
+
Whether to add an initial space to the input.
|
| 418 |
+
use_default_system_prompt (`bool`, *optional*, defaults to `False`):
|
| 419 |
+
Whether to use the default system prompt.
|
| 420 |
+
chat_template (`str`, *optional*):
|
| 421 |
+
A Jinja template for formatting conversations.
|
| 422 |
+
|
| 423 |
+
Example:
|
| 424 |
+
```python
|
| 425 |
+
>>> from tokenization_iquestcoder import IQuestCoderTokenizerFast
|
| 426 |
+
|
| 427 |
+
>>> tokenizer = IQuestCoderTokenizerFast.from_pretrained("path/to/model")
|
| 428 |
+
>>> tokenizer.encode("Hello, world!")
|
| 429 |
+
[1, 15043, 29892, 3186, 29991]
|
| 430 |
+
```
|
| 431 |
+
"""
|
| 432 |
+
|
| 433 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
| 434 |
+
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
| 435 |
+
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
| 436 |
+
model_input_names = ["input_ids", "attention_mask"]
|
| 437 |
+
slow_tokenizer_class = IQuestCoderTokenizer
|
| 438 |
+
|
| 439 |
+
def __init__(
|
| 440 |
+
self,
|
| 441 |
+
vocab_file=None,
|
| 442 |
+
tokenizer_file=None,
|
| 443 |
+
unk_token="<unk>",
|
| 444 |
+
bos_token="<s>",
|
| 445 |
+
eos_token="</s>",
|
| 446 |
+
pad_token=None,
|
| 447 |
+
add_bos_token=True,
|
| 448 |
+
add_eos_token=False,
|
| 449 |
+
add_prefix_space=False,
|
| 450 |
+
use_default_system_prompt=False,
|
| 451 |
+
chat_template=None,
|
| 452 |
+
**kwargs,
|
| 453 |
+
):
|
| 454 |
+
self.add_bos_token = add_bos_token
|
| 455 |
+
self.add_eos_token = add_eos_token
|
| 456 |
+
self.add_prefix_space = add_prefix_space
|
| 457 |
+
self.use_default_system_prompt = use_default_system_prompt
|
| 458 |
+
|
| 459 |
+
if chat_template is None:
|
| 460 |
+
chat_template = DEFAULT_CHAT_TEMPLATE
|
| 461 |
+
|
| 462 |
+
super().__init__(
|
| 463 |
+
vocab_file=vocab_file,
|
| 464 |
+
tokenizer_file=tokenizer_file,
|
| 465 |
+
unk_token=unk_token,
|
| 466 |
+
bos_token=bos_token,
|
| 467 |
+
eos_token=eos_token,
|
| 468 |
+
pad_token=pad_token,
|
| 469 |
+
add_bos_token=add_bos_token,
|
| 470 |
+
add_eos_token=add_eos_token,
|
| 471 |
+
add_prefix_space=add_prefix_space,
|
| 472 |
+
use_default_system_prompt=use_default_system_prompt,
|
| 473 |
+
chat_template=chat_template,
|
| 474 |
+
**kwargs,
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
@property
|
| 478 |
+
def can_save_slow_tokenizer(self) -> bool:
|
| 479 |
+
return os.path.isfile(self.vocab_file) if self.vocab_file else False
|
| 480 |
+
|
| 481 |
+
@property
|
| 482 |
+
def default_chat_template(self) -> str:
|
| 483 |
+
"""Returns the default chat template."""
|
| 484 |
+
return DEFAULT_CHAT_TEMPLATE
|
| 485 |
+
|
| 486 |
+
def build_inputs_with_special_tokens(
|
| 487 |
+
self,
|
| 488 |
+
token_ids_0: List[int],
|
| 489 |
+
token_ids_1: Optional[List[int]] = None
|
| 490 |
+
) -> List[int]:
|
| 491 |
+
"""Build model inputs with special tokens."""
|
| 492 |
+
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
| 493 |
+
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
| 494 |
+
|
| 495 |
+
output = bos_token_id + token_ids_0 + eos_token_id
|
| 496 |
+
|
| 497 |
+
if token_ids_1 is not None:
|
| 498 |
+
output = output + bos_token_id + token_ids_1 + eos_token_id
|
| 499 |
+
|
| 500 |
+
return output
|
| 501 |
+
|
| 502 |
+
def get_special_tokens_mask(
|
| 503 |
+
self,
|
| 504 |
+
token_ids_0: List[int],
|
| 505 |
+
token_ids_1: Optional[List[int]] = None,
|
| 506 |
+
already_has_special_tokens: bool = False
|
| 507 |
+
) -> List[int]:
|
| 508 |
+
"""Retrieve special tokens mask."""
|
| 509 |
+
if already_has_special_tokens:
|
| 510 |
+
return super().get_special_tokens_mask(
|
| 511 |
+
token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
|
| 512 |
+
)
|
| 513 |
+
|
| 514 |
+
bos_token_id = [1] if self.add_bos_token else []
|
| 515 |
+
eos_token_id = [1] if self.add_eos_token else []
|
| 516 |
+
|
| 517 |
+
if token_ids_1 is None:
|
| 518 |
+
return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
|
| 519 |
+
return (
|
| 520 |
+
bos_token_id
|
| 521 |
+
+ ([0] * len(token_ids_0))
|
| 522 |
+
+ eos_token_id
|
| 523 |
+
+ bos_token_id
|
| 524 |
+
+ ([0] * len(token_ids_1))
|
| 525 |
+
+ eos_token_id
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
+
def create_token_type_ids_from_sequences(
|
| 529 |
+
self,
|
| 530 |
+
token_ids_0: List[int],
|
| 531 |
+
token_ids_1: Optional[List[int]] = None
|
| 532 |
+
) -> List[int]:
|
| 533 |
+
"""Create token type IDs from sequences."""
|
| 534 |
+
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
| 535 |
+
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
| 536 |
+
|
| 537 |
+
output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
|
| 538 |
+
|
| 539 |
+
if token_ids_1 is not None:
|
| 540 |
+
output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
|
| 541 |
+
|
| 542 |
+
return output
|
| 543 |
+
|
| 544 |
+
except ImportError:
|
| 545 |
+
# tokenizers library not available, Fast tokenizer not supported
|
| 546 |
+
IQuestCoderTokenizerFast = None
|
| 547 |
+
logger.info(
|
| 548 |
+
"The `tokenizers` library is not installed. "
|
| 549 |
+
"IQuestCoderTokenizerFast will not be available. "
|
| 550 |
+
"Install it with `pip install tokenizers`."
|
| 551 |
+
)
|
| 552 |
+
|
tokenizer.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7d3be68e090a927f31e0e378d7599b15c206dd47e4a73933775a746cc9c1cd91
|
| 3 |
+
size 1345108
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": false,
|
| 3 |
+
"add_eos_token": false,
|
| 4 |
+
"added_tokens_decoder": {
|
| 5 |
+
"0": {
|
| 6 |
+
"content": "<unk>",
|
| 7 |
+
"lstrip": false,
|
| 8 |
+
"normalized": true,
|
| 9 |
+
"rstrip": false,
|
| 10 |
+
"single_word": true,
|
| 11 |
+
"special": true
|
| 12 |
+
},
|
| 13 |
+
"1": {
|
| 14 |
+
"content": "<s>",
|
| 15 |
+
"lstrip": false,
|
| 16 |
+
"normalized": true,
|
| 17 |
+
"rstrip": false,
|
| 18 |
+
"single_word": false,
|
| 19 |
+
"special": true
|
| 20 |
+
},
|
| 21 |
+
"2": {
|
| 22 |
+
"content": "</s>",
|
| 23 |
+
"lstrip": false,
|
| 24 |
+
"normalized": true,
|
| 25 |
+
"rstrip": false,
|
| 26 |
+
"single_word": true,
|
| 27 |
+
"special": true
|
| 28 |
+
},
|
| 29 |
+
"75858": {
|
| 30 |
+
"content": "<CLS>",
|
| 31 |
+
"lstrip": false,
|
| 32 |
+
"normalized": false,
|
| 33 |
+
"rstrip": false,
|
| 34 |
+
"single_word": false,
|
| 35 |
+
"special": true
|
| 36 |
+
},
|
| 37 |
+
"75859": {
|
| 38 |
+
"content": "<SEP>",
|
| 39 |
+
"lstrip": false,
|
| 40 |
+
"normalized": false,
|
| 41 |
+
"rstrip": false,
|
| 42 |
+
"single_word": false,
|
| 43 |
+
"special": true
|
| 44 |
+
},
|
| 45 |
+
"75860": {
|
| 46 |
+
"content": "<EOD>",
|
| 47 |
+
"lstrip": false,
|
| 48 |
+
"normalized": false,
|
| 49 |
+
"rstrip": false,
|
| 50 |
+
"single_word": false,
|
| 51 |
+
"special": true
|
| 52 |
+
},
|
| 53 |
+
"75861": {
|
| 54 |
+
"content": "<MASK>",
|
| 55 |
+
"lstrip": false,
|
| 56 |
+
"normalized": false,
|
| 57 |
+
"rstrip": false,
|
| 58 |
+
"single_word": false,
|
| 59 |
+
"special": true
|
| 60 |
+
},
|
| 61 |
+
"75862": {
|
| 62 |
+
"content": "<PAD>",
|
| 63 |
+
"lstrip": false,
|
| 64 |
+
"normalized": false,
|
| 65 |
+
"rstrip": false,
|
| 66 |
+
"single_word": false,
|
| 67 |
+
"special": true
|
| 68 |
+
},
|
| 69 |
+
"75863": {
|
| 70 |
+
"content": "<|im_start|>",
|
| 71 |
+
"lstrip": false,
|
| 72 |
+
"normalized": false,
|
| 73 |
+
"rstrip": false,
|
| 74 |
+
"single_word": false,
|
| 75 |
+
"special": true
|
| 76 |
+
},
|
| 77 |
+
"75864": {
|
| 78 |
+
"content": "<|im_end|>",
|
| 79 |
+
"lstrip": false,
|
| 80 |
+
"normalized": false,
|
| 81 |
+
"rstrip": false,
|
| 82 |
+
"single_word": false,
|
| 83 |
+
"special": true
|
| 84 |
+
},
|
| 85 |
+
"75865": {
|
| 86 |
+
"content": "<|fim_prefix|>",
|
| 87 |
+
"lstrip": false,
|
| 88 |
+
"normalized": false,
|
| 89 |
+
"rstrip": false,
|
| 90 |
+
"single_word": false,
|
| 91 |
+
"special": true
|
| 92 |
+
},
|
| 93 |
+
"75866": {
|
| 94 |
+
"content": "<|fim_middle|>",
|
| 95 |
+
"lstrip": false,
|
| 96 |
+
"normalized": false,
|
| 97 |
+
"rstrip": false,
|
| 98 |
+
"single_word": false,
|
| 99 |
+
"special": true
|
| 100 |
+
},
|
| 101 |
+
"75867": {
|
| 102 |
+
"content": "<|fim_suffix|>",
|
| 103 |
+
"lstrip": false,
|
| 104 |
+
"normalized": false,
|
| 105 |
+
"rstrip": false,
|
| 106 |
+
"single_word": false,
|
| 107 |
+
"special": true
|
| 108 |
+
},
|
| 109 |
+
"75868": {
|
| 110 |
+
"content": "<|fim_pad|>",
|
| 111 |
+
"lstrip": false,
|
| 112 |
+
"normalized": false,
|
| 113 |
+
"rstrip": false,
|
| 114 |
+
"single_word": false,
|
| 115 |
+
"special": true
|
| 116 |
+
},
|
| 117 |
+
"75869": {
|
| 118 |
+
"content": "<|endoftext|>",
|
| 119 |
+
"lstrip": false,
|
| 120 |
+
"normalized": false,
|
| 121 |
+
"rstrip": false,
|
| 122 |
+
"single_word": false,
|
| 123 |
+
"special": true
|
| 124 |
+
},
|
| 125 |
+
"75870": {
|
| 126 |
+
"content": "<|repo_name|>",
|
| 127 |
+
"lstrip": false,
|
| 128 |
+
"normalized": false,
|
| 129 |
+
"rstrip": false,
|
| 130 |
+
"single_word": false,
|
| 131 |
+
"special": true
|
| 132 |
+
},
|
| 133 |
+
"75871": {
|
| 134 |
+
"content": "<|file_sep|>",
|
| 135 |
+
"lstrip": false,
|
| 136 |
+
"normalized": false,
|
| 137 |
+
"rstrip": false,
|
| 138 |
+
"single_word": false,
|
| 139 |
+
"special": true
|
| 140 |
+
},
|
| 141 |
+
"75872": {
|
| 142 |
+
"content": "<think>",
|
| 143 |
+
"lstrip": false,
|
| 144 |
+
"normalized": false,
|
| 145 |
+
"rstrip": false,
|
| 146 |
+
"single_word": false,
|
| 147 |
+
"special": false
|
| 148 |
+
},
|
| 149 |
+
"75873": {
|
| 150 |
+
"content": "</think>",
|
| 151 |
+
"lstrip": false,
|
| 152 |
+
"normalized": false,
|
| 153 |
+
"rstrip": false,
|
| 154 |
+
"single_word": false,
|
| 155 |
+
"special": false
|
| 156 |
+
},
|
| 157 |
+
"75874": {
|
| 158 |
+
"content": "<tools>",
|
| 159 |
+
"lstrip": false,
|
| 160 |
+
"normalized": false,
|
| 161 |
+
"rstrip": false,
|
| 162 |
+
"single_word": false,
|
| 163 |
+
"special": false
|
| 164 |
+
},
|
| 165 |
+
"75875": {
|
| 166 |
+
"content": "</tools>",
|
| 167 |
+
"lstrip": false,
|
| 168 |
+
"normalized": false,
|
| 169 |
+
"rstrip": false,
|
| 170 |
+
"single_word": false,
|
| 171 |
+
"special": false
|
| 172 |
+
},
|
| 173 |
+
"75876": {
|
| 174 |
+
"content": "<tool_call>",
|
| 175 |
+
"lstrip": false,
|
| 176 |
+
"normalized": false,
|
| 177 |
+
"rstrip": false,
|
| 178 |
+
"single_word": false,
|
| 179 |
+
"special": false
|
| 180 |
+
},
|
| 181 |
+
"75877": {
|
| 182 |
+
"content": "</tool_call>",
|
| 183 |
+
"lstrip": false,
|
| 184 |
+
"normalized": false,
|
| 185 |
+
"rstrip": false,
|
| 186 |
+
"single_word": false,
|
| 187 |
+
"special": false
|
| 188 |
+
},
|
| 189 |
+
"75878": {
|
| 190 |
+
"content": "<tool_response>",
|
| 191 |
+
"lstrip": false,
|
| 192 |
+
"normalized": false,
|
| 193 |
+
"rstrip": false,
|
| 194 |
+
"single_word": false,
|
| 195 |
+
"special": false
|
| 196 |
+
},
|
| 197 |
+
"75879": {
|
| 198 |
+
"content": "</tool_response>",
|
| 199 |
+
"lstrip": false,
|
| 200 |
+
"normalized": false,
|
| 201 |
+
"rstrip": false,
|
| 202 |
+
"single_word": false,
|
| 203 |
+
"special": false
|
| 204 |
+
}
|
| 205 |
+
},
|
| 206 |
+
"additional_special_tokens": [
|
| 207 |
+
"<|CLS|>",
|
| 208 |
+
"<|SEP|>",
|
| 209 |
+
"<|EOD|>",
|
| 210 |
+
"<|MASK|>",
|
| 211 |
+
"<|PAD|>",
|
| 212 |
+
"<|fim_prefix|>",
|
| 213 |
+
"<|fim_middle|>",
|
| 214 |
+
"<|fim_suffix|>",
|
| 215 |
+
"<|im_start|>",
|
| 216 |
+
"<|im_end|>",
|
| 217 |
+
"<|fim_pad|>",
|
| 218 |
+
"<|endoftext|>",
|
| 219 |
+
"<|repo_name|>",
|
| 220 |
+
"<|file_sep|>",
|
| 221 |
+
"<think>",
|
| 222 |
+
"</think>"
|
| 223 |
+
],
|
| 224 |
+
"auto_map": {
|
| 225 |
+
"AutoTokenizer": [
|
| 226 |
+
"tokenization_iquestcoder.IQuestCoderTokenizer",
|
| 227 |
+
null
|
| 228 |
+
]
|
| 229 |
+
},
|
| 230 |
+
"bos_token": "<s>",
|
| 231 |
+
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- else %}\n {{- 'You are LoopCoder, a helpful assistant developed by IQuest.' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are LoopCoder, a helpful assistant developed by IQuest.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set content = message.content %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}",
|
| 232 |
+
"clean_up_tokenization_spaces": false,
|
| 233 |
+
"eos_token": "<|im_end|>",
|
| 234 |
+
"model_max_length": 131072,
|
| 235 |
+
"pad_token": "<|endoftext|>",
|
| 236 |
+
"padding_side": "right",
|
| 237 |
+
"sp_model_kwargs": {},
|
| 238 |
+
"split_special_tokens": false,
|
| 239 |
+
"tokenizer_class": "IQuestCoderTokenizer",
|
| 240 |
+
"unk_token": "<unk>",
|
| 241 |
+
"use_fast": false
|
| 242 |
+
}
|