irridileepkumar commited on
Commit
0b410d0
·
verified ·
1 Parent(s): c4bfc74

Delete unsloth-main

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. unsloth-main/.github/FUNDING.yml +0 -13
  2. unsloth-main/CONTRIBUTING.md +0 -29
  3. unsloth-main/LICENSE +0 -201
  4. unsloth-main/README.md +0 -492
  5. unsloth-main/images/Assistant.png +0 -0
  6. unsloth-main/images/Colab.png +0 -0
  7. unsloth-main/images/Discord button.png +0 -0
  8. unsloth-main/images/Discord.png +0 -0
  9. unsloth-main/images/Documentation Button.png +0 -0
  10. unsloth-main/images/Free version button.png +0 -0
  11. unsloth-main/images/Kaggle.png +0 -0
  12. unsloth-main/images/Kofi button.png +0 -0
  13. unsloth-main/images/LAION 2GPU.png +0 -0
  14. unsloth-main/images/Merge.png +0 -0
  15. unsloth-main/images/Run.png +0 -0
  16. unsloth-main/images/Slim Orca 2GPUs.png +0 -0
  17. unsloth-main/images/Terminal_Type.png +0 -0
  18. unsloth-main/images/Where_Terminal.png +0 -0
  19. unsloth-main/images/buy me a coffee button.png +0 -0
  20. unsloth-main/images/documentation github button.png +0 -0
  21. unsloth-main/images/documentation green button.png +0 -0
  22. unsloth-main/images/documentation lighter.png +0 -0
  23. unsloth-main/images/documentation white button.png +0 -0
  24. unsloth-main/images/made with unsloth.png +0 -0
  25. unsloth-main/images/ollama.png +0 -0
  26. unsloth-main/images/peft x trl button.png +0 -0
  27. unsloth-main/images/start free finetune button.png +0 -0
  28. unsloth-main/images/unsloth end.png +0 -0
  29. unsloth-main/images/unsloth loading page render.png +0 -0
  30. unsloth-main/images/unsloth logo black text.png +0 -0
  31. unsloth-main/images/unsloth logo only.png +0 -0
  32. unsloth-main/images/unsloth logo white text.png +0 -0
  33. unsloth-main/images/unsloth made with love.png +0 -0
  34. unsloth-main/images/unsloth new logo.png +0 -0
  35. unsloth-main/pyproject.toml +0 -418
  36. unsloth-main/unsloth-cli.py +0 -221
  37. unsloth-main/unsloth/__init__.py +0 -181
  38. unsloth-main/unsloth/_auto_install.py +0 -31
  39. unsloth-main/unsloth/chat_templates.py +0 -2105
  40. unsloth-main/unsloth/kernels/__init__.py +0 -65
  41. unsloth-main/unsloth/kernels/cross_entropy_loss.py +0 -405
  42. unsloth-main/unsloth/kernels/fast_lora.py +0 -490
  43. unsloth-main/unsloth/kernels/flex_attention.py +0 -181
  44. unsloth-main/unsloth/kernels/geglu.py +0 -203
  45. unsloth-main/unsloth/kernels/layernorm.py +0 -213
  46. unsloth-main/unsloth/kernels/rms_layernorm.py +0 -297
  47. unsloth-main/unsloth/kernels/rope_embedding.py +0 -196
  48. unsloth-main/unsloth/kernels/swiglu.py +0 -99
  49. unsloth-main/unsloth/kernels/utils.py +0 -422
  50. unsloth-main/unsloth/models/__init__.py +0 -22
unsloth-main/.github/FUNDING.yml DELETED
@@ -1,13 +0,0 @@
1
- # These are supported funding model platforms
2
-
3
- github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
4
- patreon: # Replace with a single Patreon username
5
- open_collective: # Replace with a single Open Collective username
6
- ko_fi: unsloth
7
- tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
8
- community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
9
- liberapay: # Replace with a single Liberapay username
10
- issuehunt: # Replace with a single IssueHunt username
11
- otechie: # Replace with a single Otechie username
12
- lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
13
- custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
unsloth-main/CONTRIBUTING.md DELETED
@@ -1,29 +0,0 @@
1
- # 🦥 Contributing to Unsloth
2
-
3
- Thank you for not only using Unsloth but also for being interested in helping out! We value all contributions, whether they come in the form of code, ideas, support for others or just by simply spreading the word of Unsloth! 💕
4
-
5
- - **[Support the Community](https://github.com/unslothai/unsloth/issues)**: Answer questions, review pull requests, or assist others in discussions.
6
- - **Fix Bugs**: Identify and resolve issues with the existing codebase.
7
- - **Submit Ideas**: Request new features or share enhancements you'd like to see.
8
- - **Develop Features**: Implement new functionality or improve existing tools which can be done via PRs.
9
- - **[Improve Documentation](https://docs.unsloth.ai/)**: Help by creating guides, FAQs, or enhancing clarity.
10
-
11
- One of the best ways to support us is by spreading the word about Unsloth! Share how it’s powering your amazing projects in blog posts or social media, and inspire others to explore its potential. Even a simple star on our repo goes a long way in showing your support and helping the community grow. 🌟
12
-
13
- ## Submitting Issues
14
- If you find a bug or have a feature idea, we’d love to hear from you! Here’s how to make your submission stand out:
15
-
16
- ### Reporting Bugs
17
- 1. **Search First**: Check if the issue has already been reported using GitHub’s search bar under Issues.
18
- 2. **Details Matter**: Is this on Google Colab, Kaggle, or on another platform service? Are you using Unsloth's official notebook? Include your OS, Python version, and other relevant details. For bugs, a concise code snippet that reproduces the issue is incredibly helpful.
19
- 3. **Be Thorough**: Attach screenshots, traceback logs, or any additional information that might speed up resolution.
20
-
21
- ## Spread the Word
22
- Your support extends beyond code:
23
- - Spread the word by writing about Unsloth in blogs or social media.
24
- - Share how Unsloth powers your projects.
25
- - Star our repository to show your appreciation.
26
-
27
- Finally, please be mindful of our [Code of Conduct](https://github.com/unslothai/unsloth/tree/main/unsloth/CODE_OF_CONDUCT.md) to ensure a welcoming and inclusive environment for everyone.
28
-
29
- Thank you so much for reading and we hope you have lots of fun using Unsloth! 🦥
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
unsloth-main/LICENSE DELETED
@@ -1,201 +0,0 @@
1
- Apache License
2
- Version 2.0, January 2004
3
- http://www.apache.org/licenses/
4
-
5
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
-
7
- 1. Definitions.
8
-
9
- "License" shall mean the terms and conditions for use, reproduction,
10
- and distribution as defined by Sections 1 through 9 of this document.
11
-
12
- "Licensor" shall mean the copyright owner or entity authorized by
13
- the copyright owner that is granting the License.
14
-
15
- "Legal Entity" shall mean the union of the acting entity and all
16
- other entities that control, are controlled by, or are under common
17
- control with that entity. For the purposes of this definition,
18
- "control" means (i) the power, direct or indirect, to cause the
19
- direction or management of such entity, whether by contract or
20
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
- outstanding shares, or (iii) beneficial ownership of such entity.
22
-
23
- "You" (or "Your") shall mean an individual or Legal Entity
24
- exercising permissions granted by this License.
25
-
26
- "Source" form shall mean the preferred form for making modifications,
27
- including but not limited to software source code, documentation
28
- source, and configuration files.
29
-
30
- "Object" form shall mean any form resulting from mechanical
31
- transformation or translation of a Source form, including but
32
- not limited to compiled object code, generated documentation,
33
- and conversions to other media types.
34
-
35
- "Work" shall mean the work of authorship, whether in Source or
36
- Object form, made available under the License, as indicated by a
37
- copyright notice that is included in or attached to the work
38
- (an example is provided in the Appendix below).
39
-
40
- "Derivative Works" shall mean any work, whether in Source or Object
41
- form, that is based on (or derived from) the Work and for which the
42
- editorial revisions, annotations, elaborations, or other modifications
43
- represent, as a whole, an original work of authorship. For the purposes
44
- of this License, Derivative Works shall not include works that remain
45
- separable from, or merely link (or bind by name) to the interfaces of,
46
- the Work and Derivative Works thereof.
47
-
48
- "Contribution" shall mean any work of authorship, including
49
- the original version of the Work and any modifications or additions
50
- to that Work or Derivative Works thereof, that is intentionally
51
- submitted to Licensor for inclusion in the Work by the copyright owner
52
- or by an individual or Legal Entity authorized to submit on behalf of
53
- the copyright owner. For the purposes of this definition, "submitted"
54
- means any form of electronic, verbal, or written communication sent
55
- to the Licensor or its representatives, including but not limited to
56
- communication on electronic mailing lists, source code control systems,
57
- and issue tracking systems that are managed by, or on behalf of, the
58
- Licensor for the purpose of discussing and improving the Work, but
59
- excluding communication that is conspicuously marked or otherwise
60
- designated in writing by the copyright owner as "Not a Contribution."
61
-
62
- "Contributor" shall mean Licensor and any individual or Legal Entity
63
- on behalf of whom a Contribution has been received by Licensor and
64
- subsequently incorporated within the Work.
65
-
66
- 2. Grant of Copyright License. Subject to the terms and conditions of
67
- this License, each Contributor hereby grants to You a perpetual,
68
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
- copyright license to reproduce, prepare Derivative Works of,
70
- publicly display, publicly perform, sublicense, and distribute the
71
- Work and such Derivative Works in Source or Object form.
72
-
73
- 3. Grant of Patent License. Subject to the terms and conditions of
74
- this License, each Contributor hereby grants to You a perpetual,
75
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
- (except as stated in this section) patent license to make, have made,
77
- use, offer to sell, sell, import, and otherwise transfer the Work,
78
- where such license applies only to those patent claims licensable
79
- by such Contributor that are necessarily infringed by their
80
- Contribution(s) alone or by combination of their Contribution(s)
81
- with the Work to which such Contribution(s) was submitted. If You
82
- institute patent litigation against any entity (including a
83
- cross-claim or counterclaim in a lawsuit) alleging that the Work
84
- or a Contribution incorporated within the Work constitutes direct
85
- or contributory patent infringement, then any patent licenses
86
- granted to You under this License for that Work shall terminate
87
- as of the date such litigation is filed.
88
-
89
- 4. Redistribution. You may reproduce and distribute copies of the
90
- Work or Derivative Works thereof in any medium, with or without
91
- modifications, and in Source or Object form, provided that You
92
- meet the following conditions:
93
-
94
- (a) You must give any other recipients of the Work or
95
- Derivative Works a copy of this License; and
96
-
97
- (b) You must cause any modified files to carry prominent notices
98
- stating that You changed the files; and
99
-
100
- (c) You must retain, in the Source form of any Derivative Works
101
- that You distribute, all copyright, patent, trademark, and
102
- attribution notices from the Source form of the Work,
103
- excluding those notices that do not pertain to any part of
104
- the Derivative Works; and
105
-
106
- (d) If the Work includes a "NOTICE" text file as part of its
107
- distribution, then any Derivative Works that You distribute must
108
- include a readable copy of the attribution notices contained
109
- within such NOTICE file, excluding those notices that do not
110
- pertain to any part of the Derivative Works, in at least one
111
- of the following places: within a NOTICE text file distributed
112
- as part of the Derivative Works; within the Source form or
113
- documentation, if provided along with the Derivative Works; or,
114
- within a display generated by the Derivative Works, if and
115
- wherever such third-party notices normally appear. The contents
116
- of the NOTICE file are for informational purposes only and
117
- do not modify the License. You may add Your own attribution
118
- notices within Derivative Works that You distribute, alongside
119
- or as an addendum to the NOTICE text from the Work, provided
120
- that such additional attribution notices cannot be construed
121
- as modifying the License.
122
-
123
- You may add Your own copyright statement to Your modifications and
124
- may provide additional or different license terms and conditions
125
- for use, reproduction, or distribution of Your modifications, or
126
- for any such Derivative Works as a whole, provided Your use,
127
- reproduction, and distribution of the Work otherwise complies with
128
- the conditions stated in this License.
129
-
130
- 5. Submission of Contributions. Unless You explicitly state otherwise,
131
- any Contribution intentionally submitted for inclusion in the Work
132
- by You to the Licensor shall be under the terms and conditions of
133
- this License, without any additional terms or conditions.
134
- Notwithstanding the above, nothing herein shall supersede or modify
135
- the terms of any separate license agreement you may have executed
136
- with Licensor regarding such Contributions.
137
-
138
- 6. Trademarks. This License does not grant permission to use the trade
139
- names, trademarks, service marks, or product names of the Licensor,
140
- except as required for reasonable and customary use in describing the
141
- origin of the Work and reproducing the content of the NOTICE file.
142
-
143
- 7. Disclaimer of Warranty. Unless required by applicable law or
144
- agreed to in writing, Licensor provides the Work (and each
145
- Contributor provides its Contributions) on an "AS IS" BASIS,
146
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
- implied, including, without limitation, any warranties or conditions
148
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
- PARTICULAR PURPOSE. You are solely responsible for determining the
150
- appropriateness of using or redistributing the Work and assume any
151
- risks associated with Your exercise of permissions under this License.
152
-
153
- 8. Limitation of Liability. In no event and under no legal theory,
154
- whether in tort (including negligence), contract, or otherwise,
155
- unless required by applicable law (such as deliberate and grossly
156
- negligent acts) or agreed to in writing, shall any Contributor be
157
- liable to You for damages, including any direct, indirect, special,
158
- incidental, or consequential damages of any character arising as a
159
- result of this License or out of the use or inability to use the
160
- Work (including but not limited to damages for loss of goodwill,
161
- work stoppage, computer failure or malfunction, or any and all
162
- other commercial damages or losses), even if such Contributor
163
- has been advised of the possibility of such damages.
164
-
165
- 9. Accepting Warranty or Additional Liability. While redistributing
166
- the Work or Derivative Works thereof, You may choose to offer,
167
- and charge a fee for, acceptance of support, warranty, indemnity,
168
- or other liability obligations and/or rights consistent with this
169
- License. However, in accepting such obligations, You may act only
170
- on Your own behalf and on Your sole responsibility, not on behalf
171
- of any other Contributor, and only if You agree to indemnify,
172
- defend, and hold each Contributor harmless for any liability
173
- incurred by, or claims asserted against, such Contributor by reason
174
- of your accepting any such warranty or additional liability.
175
-
176
- END OF TERMS AND CONDITIONS
177
-
178
- APPENDIX: How to apply the Apache License to your work.
179
-
180
- To apply the Apache License to your work, attach the following
181
- boilerplate notice, with the fields enclosed by brackets "[]"
182
- replaced with your own identifying information. (Don't include
183
- the brackets!) The text should be enclosed in the appropriate
184
- comment syntax for the file format. We also recommend that a
185
- file or class name and description of purpose be included on the
186
- same "printed page" as the copyright notice for easier
187
- identification within third-party archives.
188
-
189
- Copyright [2024-] [Unsloth AI, Daniel Han-Chen & Michael Han-Chen]
190
-
191
- Licensed under the Apache License, Version 2.0 (the "License");
192
- you may not use this file except in compliance with the License.
193
- You may obtain a copy of the License at
194
-
195
- http://www.apache.org/licenses/LICENSE-2.0
196
-
197
- Unless required by applicable law or agreed to in writing, software
198
- distributed under the License is distributed on an "AS IS" BASIS,
199
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
- See the License for the specific language governing permissions and
201
- limitations under the License.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
unsloth-main/README.md DELETED
@@ -1,492 +0,0 @@
1
- <div align="center">
2
-
3
- <a href="https://unsloth.ai"><picture>
4
- <source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20white%20text.png">
5
- <source media="(prefers-color-scheme: light)" srcset="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20black%20text.png">
6
- <img alt="unsloth logo" src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20black%20text.png" height="110" style="max-width: 100%;">
7
- </picture></a>
8
-
9
- <a href="https://colab.research.google.com/drive/1Ys44kVvmeZtnICzWz0xgpRnrIOjZAuxp?usp=sharing"><img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/start free finetune button.png" height="48"></a>
10
- <a href="https://discord.gg/unsloth"><img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/Discord button.png" height="48"></a>
11
- <a href="https://docs.unsloth.ai"><img src="https://raw.githubusercontent.com/unslothai/unsloth/refs/heads/main/images/Documentation%20Button.png" height="48"></a>
12
-
13
- ### Finetune Llama 3.2, Mistral, Phi-3.5, Qwen 2.5 & Gemma 2-5x faster with 80% less memory!
14
-
15
- ![](https://i.ibb.co/sJ7RhGG/image-41.png)
16
-
17
- </div>
18
-
19
- ## ✨ Finetune for Free
20
-
21
- All notebooks are **beginner friendly**! Add your dataset, click "Run All", and you'll get a 2x faster finetuned model which can be exported to GGUF, Ollama, vLLM or uploaded to Hugging Face.
22
-
23
- | Unsloth supports | Free Notebooks | Performance | Memory use |
24
- |-----------|---------|--------|----------|
25
- | **Llama 3.2 (3B)** | [▶️ Start for free](https://colab.research.google.com/drive/1T5-zKWM_5OD21QHwXHiV9ixTRR7k3iB9?usp=sharing) | 2x faster | 60% less |
26
- | **Llama 3.2 Vision (11B)** | [▶️ Start for free](https://colab.research.google.com/drive/1j0N4XTY1zXXy7mPAhOC1_gMYZ2F2EBlk?usp=sharing) | 2x faster | 40% less |
27
- | **Llama 3.1 (8B)** | [▶️ Start for free](https://colab.research.google.com/drive/1Ys44kVvmeZtnICzWz0xgpRnrIOjZAuxp?usp=sharing) | 2x faster | 60% less |
28
- | **Phi-3.5 (mini)** | [▶️ Start for free](https://colab.research.google.com/drive/1lN6hPQveB_mHSnTOYifygFcrO8C1bxq4?usp=sharing) | 2x faster | 50% less |
29
- | **Gemma 2 (9B)** | [▶️ Start for free](https://colab.research.google.com/drive/1vIrqH5uYDQwsJ4-OO3DErvuv4pBgVwk4?usp=sharing) | 2x faster | 63% less |
30
- | **Qwen 2.5 (7B)** | [▶️ Start for free](https://colab.research.google.com/drive/1Kose-ucXO1IBaZq5BvbwWieuubP7hxvQ?usp=sharing) | 2x faster | 63% less |
31
- | **Mistral v0.3 (7B)** | [▶️ Start for free](https://colab.research.google.com/drive/1_yNCks4BTD5zOnjozppphh5GzMFaMKq_?usp=sharing) | 2.2x faster | 73% less |
32
- | **Ollama** | [▶️ Start for free](https://colab.research.google.com/drive/1WZDi7APtQ9VsvOrQSSC5DDtxq159j8iZ?usp=sharing) | 1.9x faster | 43% less |
33
- | **ORPO** | [▶️ Start for free](https://colab.research.google.com/drive/11t4njE3c4Lxl-07OD8lJSMKkfyJml3Tn?usp=sharing) | 1.9x faster | 43% less |
34
- | **DPO Zephyr** | [▶️ Start for free](https://colab.research.google.com/drive/15vttTpzzVXv_tJwEk-hIcQ0S9FcEWvwP?usp=sharing) | 1.9x faster | 43% less |
35
-
36
- - See [all our notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks) and [all our models](https://docs.unsloth.ai/get-started/all-our-models)
37
- - **Kaggle Notebooks** for [Llama 3.2 Kaggle notebook](https://www.kaggle.com/danielhanchen/kaggle-llama-3-2-1b-3b-unsloth-notebook), [Llama 3.1 (8B)](https://www.kaggle.com/danielhanchen/kaggle-llama-3-1-8b-unsloth-notebook), [Gemma 2 (9B)](https://www.kaggle.com/code/danielhanchen/kaggle-gemma-7b-unsloth-notebook/), [Mistral (7B)](https://www.kaggle.com/code/danielhanchen/kaggle-mistral-7b-unsloth-notebook)
38
- - Run notebooks for [Llama 3.2 conversational](https://colab.research.google.com/drive/1T5-zKWM_5OD21QHwXHiV9ixTRR7k3iB9?usp=sharing), [Llama 3.1 conversational](https://colab.research.google.com/drive/15OyFkGoCImV9dSsewU1wa2JuKB4-mDE_?usp=sharing) and [Mistral v0.3 ChatML](https://colab.research.google.com/drive/15F1xyn8497_dUbxZP4zWmPZ3PJx1Oymv?usp=sharing)
39
- - This [text completion notebook](https://colab.research.google.com/drive/1ef-tab5bhkvWmBOObepl1WgJvfvSzn5Q?usp=sharing) is for continued pretraining / raw text
40
- - This [continued pretraining notebook](https://colab.research.google.com/drive/1tEd1FrOXWMnCU9UIvdYhs61tkxdMuKZu?usp=sharing) is for learning another language
41
- - Click [here](https://docs.unsloth.ai/) for detailed documentation for Unsloth.
42
-
43
- ## 🦥 Unsloth.ai News
44
- - 📣 NEW! [Llama 3.3 (70B)](https://huggingface.co/collections/unsloth/llama-33-all-versions-67535d7d994794b9d7cf5e9f), Meta's latest model is now supported.
45
- - 📣 NEW! We worked with Apple to add [Cut Cross Entropy](https://arxiv.org/abs/2411.09009). Unsloth now supports 89K context for Meta's Llama 3.3 (70B) on a 80GB GPU - 13x longer than HF+FA2. For Llama 3.1 (8B), Unsloth enables 342K context, surpassing its native 128K support.
46
- - 📣 NEW! Introducing Unsloth [Dynamic 4-bit Quantization](https://unsloth.ai/blog/dynamic-4bit)! We dynamically opt not to quantize certain parameters and this greatly increases accuracy while only using <10% more VRAM than BnB 4-bit. See our collection on [Hugging Face here.](https://huggingface.co/collections/unsloth/unsloth-4-bit-dynamic-quants-67503bb873f89e15276c44e7)
47
- - 📣 NEW! [Vision models](https://unsloth.ai/blog/vision) now supported! [Llama 3.2 Vision (11B)](https://colab.research.google.com/drive/1j0N4XTY1zXXy7mPAhOC1_gMYZ2F2EBlk?usp=sharing), [Qwen 2.5 VL (7B)](https://colab.research.google.com/drive/1whHb54GNZMrNxIsi2wm2EY_-Pvo2QyKh?usp=sharing) and [Pixtral (12B) 2409](https://colab.research.google.com/drive/1K9ZrdwvZRE96qGkCq_e88FgV3MLnymQq?usp=sharing)
48
- - 📣 NEW! Qwen-2.5 including [Coder](https://colab.research.google.com/drive/18sN803sU23XuJV9Q8On2xgqHSer6-UZF?usp=sharing) models are now supported with bugfixes. 14b fits in a Colab GPU! [Qwen 2.5 conversational notebook](https://colab.research.google.com/drive/1qN1CEalC70EO1wGKhNxs1go1W9So61R5?usp=sharing)
49
- - 📣 NEW! We found and helped fix a [gradient accumulation bug](https://unsloth.ai/blog/gradient)! Please update Unsloth and transformers.
50
- <details>
51
- <summary>Click for more news</summary>
52
-
53
- - 📣 Try out [Chat interface](https://colab.research.google.com/drive/1i-8ESvtLRGNkkUQQr_-z_rcSAIo9c3lM?usp=sharing)!
54
- - 📣 NEW! [Mistral Small 22b notebook](https://colab.research.google.com/drive/1oCEHcED15DzL8xXGU1VTx5ZfOJM8WY01?usp=sharing) finetuning fits in under 16GB of VRAM!
55
- - 📣 NEW! [Llama 3.1 8b, 70b](https://colab.research.google.com/drive/1Ys44kVvmeZtnICzWz0xgpRnrIOjZAuxp?usp=sharing) & [Mistral Nemo-12b](https://colab.research.google.com/drive/17d3U-CAIwzmbDRqbZ9NnpHxCkmXB6LZ0?usp=sharing) both Base and Instruct are now supported
56
- - 📣 NEW! `pip install unsloth` now works! Head over to [pypi](https://pypi.org/project/unsloth/) to check it out! This allows non git pull installs. Use `pip install unsloth[colab-new]` for non dependency installs.
57
- - 📣 NEW! Continued Pretraining [notebook](https://colab.research.google.com/drive/1tEd1FrOXWMnCU9UIvdYhs61tkxdMuKZu?usp=sharing) for other languages like Korean!
58
- - 📣 [2x faster inference](https://colab.research.google.com/drive/1aqlNQi7MMJbynFDyOQteD2t0yVfjb9Zh?usp=sharing) added for all our models
59
- - 📣 We cut memory usage by a [further 30%](https://unsloth.ai/blog/long-context) and now support [4x longer context windows](https://unsloth.ai/blog/long-context)!
60
- </details>
61
-
62
- ## 🔗 Links and Resources
63
- | Type | Links |
64
- | ------------------------------- | --------------------------------------- |
65
- | 📚 **Documentation & Wiki** | [Read Our Docs](https://docs.unsloth.ai) |
66
- | <img height="14" src="https://upload.wikimedia.org/wikipedia/commons/6/6f/Logo_of_Twitter.svg" />&nbsp; **Twitter (aka X)** | [Follow us on X](https://twitter.com/unslothai)|
67
- | 💾 **Installation** | [unsloth/README.md](https://github.com/unslothai/unsloth/tree/main#-installation-instructions)|
68
- | 🥇 **Benchmarking** | [Performance Tables](https://github.com/unslothai/unsloth/tree/main#-performance-benchmarking)
69
- | 🌐 **Released Models** | [Unsloth Releases](https://docs.unsloth.ai/get-started/all-our-models)|
70
- | ✍️ **Blog** | [Read our Blogs](https://unsloth.ai/blog)|
71
- | <img height="14" src="https://redditinc.com/hs-fs/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png" />&nbsp; **Reddit** | [Join our Reddit page](https://reddit.com/r/unsloth)|
72
-
73
- ## ⭐ Key Features
74
- - All kernels written in [OpenAI's Triton](https://openai.com/research/triton) language. **Manual backprop engine**.
75
- - **0% loss in accuracy** - no approximation methods - all exact.
76
- - No change of hardware. Supports NVIDIA GPUs since 2018+. Minimum CUDA Capability 7.0 (V100, T4, Titan V, RTX 20, 30, 40x, A100, H100, L40 etc) [Check your GPU!](https://developer.nvidia.com/cuda-gpus) GTX 1070, 1080 works, but is slow.
77
- - Works on **Linux** and **Windows** via WSL.
78
- - Supports 4bit and 16bit QLoRA / LoRA finetuning via [bitsandbytes](https://github.com/TimDettmers/bitsandbytes).
79
- - Open source trains 5x faster - see [Unsloth Pro](https://unsloth.ai/) for up to **30x faster training**!
80
- - If you trained a model with 🦥Unsloth, you can use this cool sticker! &nbsp; <img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/made with unsloth.png" height="50" align="center" />
81
-
82
-
83
- ## 🥇 Performance Benchmarking
84
- - For the full list of **reproducible** benchmarking tables, [go to our website](https://unsloth.ai/blog/mistral-benchmark#Benchmark%20tables)
85
-
86
- | 1 A100 40GB | 🤗Hugging Face | Flash Attention | 🦥Unsloth Open Source | 🦥[Unsloth Pro](https://unsloth.ai/pricing) |
87
- |--------------|--------------|-----------------|---------------------|-----------------|
88
- | Alpaca | 1x | 1.04x | 1.98x | **15.64x** |
89
- | LAION Chip2 | 1x | 0.92x | 1.61x | **20.73x** |
90
- | OASST | 1x | 1.19x | 2.17x | **14.83x** |
91
- | Slim Orca | 1x | 1.18x | 2.22x | **14.82x** |
92
-
93
- - Benchmarking table below was conducted by [🤗Hugging Face](https://huggingface.co/blog/unsloth-trl).
94
-
95
- | Free Colab T4 | Dataset | 🤗Hugging Face | Pytorch 2.1.1 | 🦥Unsloth | 🦥 VRAM reduction |
96
- | --- | --- | --- | --- | --- | --- |
97
- | Llama-2 7b | OASST | 1x | 1.19x | 1.95x | -43.3% |
98
- | Mistral 7b | Alpaca | 1x | 1.07x | 1.56x | -13.7% |
99
- | Tiny Llama 1.1b | Alpaca | 1x | 2.06x | 3.87x | -73.8% |
100
- | DPO with Zephyr | Ultra Chat | 1x | 1.09x | 1.55x | -18.6% |
101
-
102
- ![](https://i.ibb.co/sJ7RhGG/image-41.png)
103
-
104
- ## 💾 Installation Instructions
105
-
106
- For stable releases, use `pip install unsloth`. We recommend `pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"` for most installations though.
107
-
108
- ### Conda Installation
109
- `⚠️Only use Conda if you have it. If not, use Pip`. Select either `pytorch-cuda=11.8,12.1` for CUDA 11.8 or CUDA 12.1. We support `python=3.10,3.11,3.12`.
110
- ```bash
111
- conda create --name unsloth_env \
112
- python=3.11 \
113
- pytorch-cuda=12.1 \
114
- pytorch cudatoolkit xformers -c pytorch -c nvidia -c xformers \
115
- -y
116
- conda activate unsloth_env
117
-
118
- pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
119
- pip install --no-deps trl peft accelerate bitsandbytes
120
- ```
121
-
122
- <details>
123
- <summary>If you're looking to install Conda in a Linux environment, <a href="https://docs.anaconda.com/miniconda/">read here</a>, or run the below 🔽</summary>
124
-
125
- ```bash
126
- mkdir -p ~/miniconda3
127
- wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
128
- bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
129
- rm -rf ~/miniconda3/miniconda.sh
130
- ~/miniconda3/bin/conda init bash
131
- ~/miniconda3/bin/conda init zsh
132
- ```
133
- </details>
134
-
135
- ### Pip Installation
136
- `⚠️Do **NOT** use this if you have Conda.` Pip is a bit more complex since there are dependency issues. The pip command is different for `torch 2.2,2.3,2.4,2.5` and CUDA versions.
137
-
138
- For other torch versions, we support `torch211`, `torch212`, `torch220`, `torch230`, `torch240` and for CUDA versions, we support `cu118` and `cu121` and `cu124`. For Ampere devices (A100, H100, RTX3090) and above, use `cu118-ampere` or `cu121-ampere` or `cu124-ampere`.
139
-
140
- For example, if you have `torch 2.4` and `CUDA 12.1`, use:
141
- ```bash
142
- pip install --upgrade pip
143
- pip install "unsloth[cu121-torch240] @ git+https://github.com/unslothai/unsloth.git"
144
- ```
145
-
146
- Another example, if you have `torch 2.5` and `CUDA 12.4`, use:
147
- ```bash
148
- pip install --upgrade pip
149
- pip install "unsloth[cu124-torch250] @ git+https://github.com/unslothai/unsloth.git"
150
- ```
151
-
152
- And other examples:
153
- ```bash
154
- pip install "unsloth[cu121-ampere-torch240] @ git+https://github.com/unslothai/unsloth.git"
155
- pip install "unsloth[cu118-ampere-torch240] @ git+https://github.com/unslothai/unsloth.git"
156
- pip install "unsloth[cu121-torch240] @ git+https://github.com/unslothai/unsloth.git"
157
- pip install "unsloth[cu118-torch240] @ git+https://github.com/unslothai/unsloth.git"
158
-
159
- pip install "unsloth[cu121-torch230] @ git+https://github.com/unslothai/unsloth.git"
160
- pip install "unsloth[cu121-ampere-torch230] @ git+https://github.com/unslothai/unsloth.git"
161
-
162
- pip install "unsloth[cu121-torch250] @ git+https://github.com/unslothai/unsloth.git"
163
- pip install "unsloth[cu124-ampere-torch250] @ git+https://github.com/unslothai/unsloth.git"
164
- ```
165
-
166
- Or, run the below in a terminal to get the **optimal** pip installation command:
167
- ```bash
168
- wget -qO- https://raw.githubusercontent.com/unslothai/unsloth/main/unsloth/_auto_install.py | python -
169
- ```
170
-
171
- Or, run the below manually in a Python REPL:
172
- ```python
173
- try: import torch
174
- except: raise ImportError('Install torch via `pip install torch`')
175
- from packaging.version import Version as V
176
- v = V(torch.__version__)
177
- cuda = str(torch.version.cuda)
178
- is_ampere = torch.cuda.get_device_capability()[0] >= 8
179
- if cuda != "12.1" and cuda != "11.8" and cuda != "12.4": raise RuntimeError(f"CUDA = {cuda} not supported!")
180
- if v <= V('2.1.0'): raise RuntimeError(f"Torch = {v} too old!")
181
- elif v <= V('2.1.1'): x = 'cu{}{}-torch211'
182
- elif v <= V('2.1.2'): x = 'cu{}{}-torch212'
183
- elif v < V('2.3.0'): x = 'cu{}{}-torch220'
184
- elif v < V('2.4.0'): x = 'cu{}{}-torch230'
185
- elif v < V('2.5.0'): x = 'cu{}{}-torch240'
186
- elif v < V('2.6.0'): x = 'cu{}{}-torch250'
187
- else: raise RuntimeError(f"Torch = {v} too new!")
188
- x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
189
- print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"')
190
- ```
191
-
192
- ### Windows Installation
193
-
194
- To run Unsloth directly on Windows:
195
- - Install Triton from this Windows fork and follow the instructions: https://github.com/woct0rdho/triton-windows
196
- - In the SFTTrainer, set `dataset_num_proc=1` to avoid a crashing issue:
197
- ```python
198
- trainer = SFTTrainer(
199
- dataset_num_proc=1,
200
- ...
201
- )
202
- ```
203
-
204
- For **advanced installation instructions** or if you see weird errors during installations:
205
-
206
- 1. Install `torch` and `triton`. Go to https://pytorch.org to install it. For example `pip install torch torchvision torchaudio triton`
207
- 2. Confirm if CUDA is installated correctly. Try `nvcc`. If that fails, you need to install `cudatoolkit` or CUDA drivers.
208
- 3. Install `xformers` manually. You can try installing `vllm` and seeing if `vllm` succeeds. Check if `xformers` succeeded with `python -m xformers.info` Go to https://github.com/facebookresearch/xformers. Another option is to install `flash-attn` for Ampere GPUs.
209
- 4. Finally, install `bitsandbytes` and check it with `python -m bitsandbytes`
210
-
211
- ## 📜 [Documentation](https://docs.unsloth.ai)
212
- - Go to our official [Documentation](https://docs.unsloth.ai) for saving to GGUF, checkpointing, evaluation and more!
213
- - We support Huggingface's TRL, Trainer, Seq2SeqTrainer or even Pytorch code!
214
- - We're in 🤗Hugging Face's official docs! Check out the [SFT docs](https://huggingface.co/docs/trl/main/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth) and [DPO docs](https://huggingface.co/docs/trl/main/en/dpo_trainer#accelerate-dpo-fine-tuning-using-unsloth)!
215
-
216
- ```python
217
- from unsloth import FastLanguageModel
218
- from unsloth import is_bfloat16_supported
219
- import torch
220
- from trl import SFTTrainer
221
- from transformers import TrainingArguments
222
- from datasets import load_dataset
223
- max_seq_length = 2048 # Supports RoPE Scaling interally, so choose any!
224
- # Get LAION dataset
225
- url = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
226
- dataset = load_dataset("json", data_files = {"train" : url}, split = "train")
227
-
228
- # 4bit pre quantized models we support for 4x faster downloading + no OOMs.
229
- fourbit_models = [
230
- "unsloth/mistral-7b-v0.3-bnb-4bit", # New Mistral v3 2x faster!
231
- "unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
232
- "unsloth/llama-3-8b-bnb-4bit", # Llama-3 15 trillion tokens model 2x faster!
233
- "unsloth/llama-3-8b-Instruct-bnb-4bit",
234
- "unsloth/llama-3-70b-bnb-4bit",
235
- "unsloth/Phi-3-mini-4k-instruct", # Phi-3 2x faster!
236
- "unsloth/Phi-3-medium-4k-instruct",
237
- "unsloth/mistral-7b-bnb-4bit",
238
- "unsloth/gemma-7b-bnb-4bit", # Gemma 2.2x faster!
239
- ] # More models at https://huggingface.co/unsloth
240
-
241
- model, tokenizer = FastLanguageModel.from_pretrained(
242
- model_name = "unsloth/llama-3-8b-bnb-4bit",
243
- max_seq_length = max_seq_length,
244
- dtype = None,
245
- load_in_4bit = True,
246
- )
247
-
248
- # Do model patching and add fast LoRA weights
249
- model = FastLanguageModel.get_peft_model(
250
- model,
251
- r = 16,
252
- target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
253
- "gate_proj", "up_proj", "down_proj",],
254
- lora_alpha = 16,
255
- lora_dropout = 0, # Supports any, but = 0 is optimized
256
- bias = "none", # Supports any, but = "none" is optimized
257
- # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
258
- use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
259
- random_state = 3407,
260
- max_seq_length = max_seq_length,
261
- use_rslora = False, # We support rank stabilized LoRA
262
- loftq_config = None, # And LoftQ
263
- )
264
-
265
- trainer = SFTTrainer(
266
- model = model,
267
- train_dataset = dataset,
268
- dataset_text_field = "text",
269
- max_seq_length = max_seq_length,
270
- tokenizer = tokenizer,
271
- args = TrainingArguments(
272
- per_device_train_batch_size = 2,
273
- gradient_accumulation_steps = 4,
274
- warmup_steps = 10,
275
- max_steps = 60,
276
- fp16 = not is_bfloat16_supported(),
277
- bf16 = is_bfloat16_supported(),
278
- logging_steps = 1,
279
- output_dir = "outputs",
280
- optim = "adamw_8bit",
281
- seed = 3407,
282
- ),
283
- )
284
- trainer.train()
285
-
286
- # Go to https://github.com/unslothai/unsloth/wiki for advanced tips like
287
- # (1) Saving to GGUF / merging to 16bit for vLLM
288
- # (2) Continued training from a saved LoRA adapter
289
- # (3) Adding an evaluation loop / OOMs
290
- # (4) Customized chat templates
291
- ```
292
-
293
- <a name="DPO"></a>
294
- ## DPO Support
295
- DPO (Direct Preference Optimization), PPO, Reward Modelling all seem to work as per 3rd party independent testing from [Llama-Factory](https://github.com/hiyouga/LLaMA-Factory). We have a preliminary Google Colab notebook for reproducing Zephyr on Tesla T4 here: [notebook](https://colab.research.google.com/drive/15vttTpzzVXv_tJwEk-hIcQ0S9FcEWvwP?usp=sharing).
296
-
297
- We're in 🤗Hugging Face's official docs! We're on the [SFT docs](https://huggingface.co/docs/trl/main/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth) and the [DPO docs](https://huggingface.co/docs/trl/main/en/dpo_trainer#accelerate-dpo-fine-tuning-using-unsloth)!
298
-
299
- ```python
300
- import os
301
- os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Optional set GPU device ID
302
-
303
- from unsloth import FastLanguageModel, PatchDPOTrainer
304
- from unsloth import is_bfloat16_supported
305
- PatchDPOTrainer()
306
- import torch
307
- from transformers import TrainingArguments
308
- from trl import DPOTrainer
309
-
310
- model, tokenizer = FastLanguageModel.from_pretrained(
311
- model_name = "unsloth/zephyr-sft-bnb-4bit",
312
- max_seq_length = max_seq_length,
313
- dtype = None,
314
- load_in_4bit = True,
315
- )
316
-
317
- # Do model patching and add fast LoRA weights
318
- model = FastLanguageModel.get_peft_model(
319
- model,
320
- r = 64,
321
- target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
322
- "gate_proj", "up_proj", "down_proj",],
323
- lora_alpha = 64,
324
- lora_dropout = 0, # Supports any, but = 0 is optimized
325
- bias = "none", # Supports any, but = "none" is optimized
326
- # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
327
- use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
328
- random_state = 3407,
329
- max_seq_length = max_seq_length,
330
- )
331
-
332
- dpo_trainer = DPOTrainer(
333
- model = model,
334
- ref_model = None,
335
- args = TrainingArguments(
336
- per_device_train_batch_size = 4,
337
- gradient_accumulation_steps = 8,
338
- warmup_ratio = 0.1,
339
- num_train_epochs = 3,
340
- fp16 = not is_bfloat16_supported(),
341
- bf16 = is_bfloat16_supported(),
342
- logging_steps = 1,
343
- optim = "adamw_8bit",
344
- seed = 42,
345
- output_dir = "outputs",
346
- ),
347
- beta = 0.1,
348
- train_dataset = YOUR_DATASET_HERE,
349
- # eval_dataset = YOUR_DATASET_HERE,
350
- tokenizer = tokenizer,
351
- max_length = 1024,
352
- max_prompt_length = 512,
353
- )
354
- dpo_trainer.train()
355
- ```
356
-
357
- ## 🥇 Detailed Benchmarking Tables
358
- - Click "Code" for fully reproducible examples
359
- - "Unsloth Equal" is a preview of our PRO version, with code stripped out. All settings and the loss curve remains identical.
360
- - For the full list of benchmarking tables, [go to our website](https://unsloth.ai/blog/mistral-benchmark#Benchmark%20tables)
361
-
362
- | 1 A100 40GB | 🤗Hugging Face | Flash Attention 2 | 🦥Unsloth Open | Unsloth Equal | Unsloth Pro | Unsloth Max |
363
- |--------------|-------------|-------------|-----------------|--------------|---------------|-------------|
364
- | Alpaca | 1x | 1.04x | 1.98x | 2.48x | 5.32x | **15.64x** |
365
- | code | [Code](https://colab.research.google.com/drive/1u4dBeM-0vGNVmmO6X7cScAut-Hyt4KDF?usp=sharing) | [Code](https://colab.research.google.com/drive/1fgTOxpMbVjloQBvZyz4lF4BacKSZOB2A?usp=sharing) | [Code](https://colab.research.google.com/drive/1YIPY_18xm-K0iJDgvNkRoJsgkPMPAO3G?usp=sharing) | [Code](https://colab.research.google.com/drive/1ANW8EFL3LVyTD7Gq4TkheC1Z7Rxw-rHp?usp=sharing) | | |
366
- | seconds| 1040 | 1001 | 525 | 419 | 196 | 67 |
367
- | memory MB| 18235 | 15365 | 9631 | 8525 | | |
368
- | % saved| | 15.74 | 47.18 | 53.25 | | | |
369
-
370
- ### Llama-Factory 3rd party benchmarking
371
- - [Link to performance table.](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-Comparison) TGS: tokens per GPU per second. Model: LLaMA2-7B. GPU: NVIDIA A100 * 1. Batch size: 4. Gradient accumulation: 2. LoRA rank: 8. Max length: 1024.
372
-
373
- | Method | Bits | TGS | GRAM | Speed |
374
- | --- | --- | --- | --- | --- |
375
- | HF | 16 | 2392 | 18GB | 100% |
376
- | HF+FA2 | 16 | 2954 | 17GB | 123% |
377
- | Unsloth+FA2 | 16 | 4007 | 16GB | **168%** |
378
- | HF | 4 | 2415 | 9GB | 101% |
379
- | Unsloth+FA2 | 4 | 3726 | 7GB | **160%** |
380
-
381
- ### Performance comparisons between popular models
382
- <details>
383
- <summary>Click for specific model benchmarking tables (Mistral 7b, CodeLlama 34b etc.)</summary>
384
-
385
- ### Mistral 7b
386
- | 1 A100 40GB | Hugging Face | Flash Attention 2 | Unsloth Open | Unsloth Equal | Unsloth Pro | Unsloth Max |
387
- |--------------|-------------|-------------|-----------------|--------------|---------------|-------------|
388
- | Mistral 7B Slim Orca | 1x | 1.15x | 2.15x | 2.53x | 4.61x | **13.69x** |
389
- | code | [Code](https://colab.research.google.com/drive/1mePk3KzwTD81hr5mcNcs_AX3Kbg_Ha0x?usp=sharing) | [Code](https://colab.research.google.com/drive/1dgHxjvTmX6hb0bPcLp26RXSE6_n9DKj7?usp=sharing) | [Code](https://colab.research.google.com/drive/1SKrKGV-BZoU4kv5q3g0jtE_OhRgPtrrQ?usp=sharing) | [Code](https://colab.research.google.com/drive/18yOiyX0T81mTwZqOALFSCX_tSAqju6aD?usp=sharing) | |
390
- | seconds | 1813 | 1571 | 842 | 718 | 393 | 132 |
391
- | memory MB | 32853 | 19385 | 12465 | 10271 | | |
392
- | % saved| | 40.99 | 62.06 | 68.74 | | |
393
-
394
- ### CodeLlama 34b
395
- | 1 A100 40GB | Hugging Face | Flash Attention 2 | Unsloth Open | Unsloth Equal | Unsloth Pro | Unsloth Max |
396
- |--------------|-------------|-------------|-----------------|--------------|---------------|-------------|
397
- | Code Llama 34B | OOM ❌ | 0.99x | 1.87x | 2.61x | 4.27x | 12.82x |
398
- | code | [▶️ Code](https://colab.research.google.com/drive/1ykfz3BqrtC_AUFegCzUQjjfUNlxp6Otc?usp=sharing) | [Code](https://colab.research.google.com/drive/12ZypxQh7OC6kBXvWZI-5d05I4m-B_hoR?usp=sharing) | [Code](https://colab.research.google.com/drive/1gdHyAx8XJsz2yNV-DHvbHjR1iCef5Qmh?usp=sharing) | [Code](https://colab.research.google.com/drive/1fm7wqx9MJ0kRrwKOfmLkK1Rmw-pySahB?usp=sharing) | |
399
- | seconds | 1953 | 1982 | 1043 | 748 | 458 | 152 |
400
- | memory MB | 40000 | 33217 | 27413 | 22161 | | |
401
- | % saved| | 16.96| 31.47 | 44.60 | | | |
402
-
403
- ### 1 Tesla T4
404
-
405
- | 1 T4 16GB | Hugging Face | Flash Attention | Unsloth Open | Unsloth Pro Equal | Unsloth Pro | Unsloth Max |
406
- |--------------|-------------|-----------------|-----------------|---------------|---------------|-------------|
407
- | Alpaca | 1x | 1.09x | 1.69x | 1.79x | 2.93x | **8.3x** |
408
- | code | [▶️ Code](https://colab.research.google.com/drive/1XpLIV4s8Bj5uryB-X2gqM88oRGHEGdaB?usp=sharing) | [Code](https://colab.research.google.com/drive/1LyXu6CjuymQg6ddHX8g1dpUvrMa1nn4L?usp=sharing) | [Code](https://colab.research.google.com/drive/1gsv4LpY7C32otl1rgRo5wXTk4HIitXoM?usp=sharing) | [Code](https://colab.research.google.com/drive/1VtULwRQwhEnVdNryjm27zXfdSM1tNfFK?usp=sharing) | | |
409
- | seconds | 1599 | 1468 | 942 | 894 | 545 | 193 |
410
- | memory MB | 7199 | 7059 | 6459 | 5443 | | |
411
- | % saved | | 1.94 | 10.28 | 24.39 | | |
412
-
413
- ### 2 Tesla T4s via DDP
414
-
415
- | 2 T4 DDP | Hugging Face | Flash Attention | Unsloth Open | Unsloth Equal | Unsloth Pro | Unsloth Max |
416
- |--------------|----------|-------------|-----------------|--------------|---------------|-------------|
417
- | Alpaca | 1x | 0.99x | 4.95x | 4.44x | 7.28x | **20.61x** |
418
- | code | [▶️ Code](https://www.kaggle.com/danielhanchen/hf-original-alpaca-t4-ddp) | [Code](https://www.kaggle.com/danielhanchen/hf-sdpa-alpaca-t4-ddp) | [Code](https://www.kaggle.com/danielhanchen/unsloth-alpaca-t4-ddp) | | |
419
- | seconds | 9882 | 9946 | 1996 | 2227 | 1357 | 480 |
420
- | memory MB| 9176 | 9128 | 6904 | 6782 | | |
421
- | % saved | | 0.52 | 24.76 | 26.09 | | | |
422
- </details>
423
-
424
- ### Performance comparisons on 1 Tesla T4 GPU:
425
- <details>
426
- <summary>Click for Time taken for 1 epoch</summary>
427
-
428
- One Tesla T4 on Google Colab
429
- `bsz = 2, ga = 4, max_grad_norm = 0.3, num_train_epochs = 1, seed = 3047, lr = 2e-4, wd = 0.01, optim = "adamw_8bit", schedule = "linear", schedule_steps = 10`
430
-
431
- | System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) |
432
- | --- | --- | --- | --- | --- | --- |
433
- | Huggingface | 1 T4 | 23h 15m | 56h 28m | 8h 38m | 391h 41m |
434
- | Unsloth Open | 1 T4 | 13h 7m (1.8x) | 31h 47m (1.8x) | 4h 27m (1.9x) | 240h 4m (1.6x) |
435
- | Unsloth Pro | 1 T4 | 3h 6m (7.5x) | 5h 17m (10.7x) | 1h 7m (7.7x) | 59h 53m (6.5x) |
436
- | Unsloth Max | 1 T4 | 2h 39m (8.8x) | 4h 31m (12.5x) | 0h 58m (8.9x) | 51h 30m (7.6x) |
437
-
438
- **Peak Memory Usage**
439
-
440
- | System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) |
441
- | --- | --- | --- | --- | --- | --- |
442
- | Huggingface | 1 T4 | 7.3GB | 5.9GB | 14.0GB | 13.3GB |
443
- | Unsloth Open | 1 T4 | 6.8GB | 5.7GB | 7.8GB | 7.7GB |
444
- | Unsloth Pro | 1 T4 | 6.4GB | 6.4GB | 6.4GB | 6.4GB |
445
- | Unsloth Max | 1 T4 | 11.4GB | 12.4GB | 11.9GB | 14.4GB |
446
- </details>
447
-
448
- <details>
449
- <summary>Click for Performance Comparisons on 2 Tesla T4 GPUs via DDP:</summary>
450
- **Time taken for 1 epoch**
451
-
452
- Two Tesla T4s on Kaggle
453
- `bsz = 2, ga = 4, max_grad_norm = 0.3, num_train_epochs = 1, seed = 3047, lr = 2e-4, wd = 0.01, optim = "adamw_8bit", schedule = "linear", schedule_steps = 10`
454
-
455
- | System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) * |
456
- | --- | --- | --- | --- | --- | --- |
457
- | Huggingface | 2 T4 | 84h 47m | 163h 48m | 30h 51m | 1301h 24m * |
458
- | Unsloth Pro | 2 T4 | 3h 20m (25.4x) | 5h 43m (28.7x) | 1h 12m (25.7x) | 71h 40m (18.1x) * |
459
- | Unsloth Max | 2 T4 | 3h 4m (27.6x) | 5h 14m (31.3x) | 1h 6m (28.1x) | 54h 20m (23.9x) * |
460
-
461
- **Peak Memory Usage on a Multi GPU System (2 GPUs)**
462
-
463
- | System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) * |
464
- | --- | --- | --- | --- | --- | --- |
465
- | Huggingface | 2 T4 | 8.4GB \| 6GB | 7.2GB \| 5.3GB | 14.3GB \| 6.6GB | 10.9GB \| 5.9GB * |
466
- | Unsloth Pro | 2 T4 | 7.7GB \| 4.9GB | 7.5GB \| 4.9GB | 8.5GB \| 4.9GB | 6.2GB \| 4.7GB * |
467
- | Unsloth Max | 2 T4 | 10.5GB \| 5GB | 10.6GB \| 5GB | 10.6GB \| 5GB | 10.5GB \| 5GB * |
468
-
469
- * Slim Orca `bsz=1` for all benchmarks since `bsz=2` OOMs. We can handle `bsz=2`, but we benchmark it with `bsz=1` for consistency.
470
- </details>
471
-
472
- ![](https://i.ibb.co/sJ7RhGG/image-41.png)
473
- <br>
474
-
475
- ### Citation
476
-
477
- You can cite the Unsloth repo as follows:
478
- ```bibtex
479
- @software{unsloth,
480
- author = {Daniel Han, Michael Han and Unsloth team},
481
- title = {Unsloth},
482
- url = {http://github.com/unslothai/unsloth},
483
- year = {2023}
484
- }
485
- ```
486
-
487
- ### Thank You to
488
- - [Erik](https://github.com/erikwijmans) for his help adding [Apple's ML Cross Entropy](https://github.com/apple/ml-cross-entropy) in Unsloth
489
- - [HuyNguyen-hust](https://github.com/HuyNguyen-hust) for making [RoPE Embeddings 28% faster](https://github.com/unslothai/unsloth/pull/238)
490
- - [RandomInternetPreson](https://github.com/RandomInternetPreson) for confirming WSL support
491
- - [152334H](https://github.com/152334H) for experimental DPO support
492
- - [atgctg](https://github.com/atgctg) for syntax highlighting
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
unsloth-main/images/Assistant.png DELETED
Binary file (82.6 kB)
 
unsloth-main/images/Colab.png DELETED
Binary file (11.6 kB)
 
unsloth-main/images/Discord button.png DELETED
Binary file (13.8 kB)
 
unsloth-main/images/Discord.png DELETED
Binary file (13.8 kB)
 
unsloth-main/images/Documentation Button.png DELETED
Binary file (11.8 kB)
 
unsloth-main/images/Free version button.png DELETED
Binary file (9.53 kB)
 
unsloth-main/images/Kaggle.png DELETED
Binary file (9.73 kB)
 
unsloth-main/images/Kofi button.png DELETED
Binary file (18.1 kB)
 
unsloth-main/images/LAION 2GPU.png DELETED
Binary file (51.4 kB)
 
unsloth-main/images/Merge.png DELETED
Binary file (31.4 kB)
 
unsloth-main/images/Run.png DELETED
Binary file (11.5 kB)
 
unsloth-main/images/Slim Orca 2GPUs.png DELETED
Binary file (43.1 kB)
 
unsloth-main/images/Terminal_Type.png DELETED
Binary file (69.4 kB)
 
unsloth-main/images/Where_Terminal.png DELETED
Binary file (179 kB)
 
unsloth-main/images/buy me a coffee button.png DELETED
Binary file (19 kB)
 
unsloth-main/images/documentation github button.png DELETED
Binary file (11.8 kB)
 
unsloth-main/images/documentation green button.png DELETED
Binary file (11.8 kB)
 
unsloth-main/images/documentation lighter.png DELETED
Binary file (11.8 kB)
 
unsloth-main/images/documentation white button.png DELETED
Binary file (11.2 kB)
 
unsloth-main/images/made with unsloth.png DELETED
Binary file (70.4 kB)
 
unsloth-main/images/ollama.png DELETED
Binary file (67.2 kB)
 
unsloth-main/images/peft x trl button.png DELETED
Binary file (36.9 kB)
 
unsloth-main/images/start free finetune button.png DELETED
Binary file (11.4 kB)
 
unsloth-main/images/unsloth end.png DELETED
Binary file (892 kB)
 
unsloth-main/images/unsloth loading page render.png DELETED
Binary file (790 kB)
 
unsloth-main/images/unsloth logo black text.png DELETED
Binary file (58 kB)
 
unsloth-main/images/unsloth logo only.png DELETED
Binary file (57.2 kB)
 
unsloth-main/images/unsloth logo white text.png DELETED
Binary file (59 kB)
 
unsloth-main/images/unsloth made with love.png DELETED
Binary file (63.5 kB)
 
unsloth-main/images/unsloth new logo.png DELETED
Binary file (60.1 kB)
 
unsloth-main/pyproject.toml DELETED
@@ -1,418 +0,0 @@
1
- [build-system]
2
- requires = ["setuptools", "setuptools-scm"]
3
- build-backend = "setuptools.build_meta"
4
-
5
- [project]
6
- name = "unsloth"
7
- dynamic = ["version"]
8
- description = "2-5X faster LLM finetuning"
9
- readme = "README.md"
10
- requires-python = ">=3.9"
11
- license = {file = "LICENSE"}
12
- keywords = ["ai", "llm",]
13
- authors = [
14
- {email = "[email protected]"},
15
- {name = "Unsloth AI team"},
16
- ]
17
- maintainers = [
18
- {name = "Daniel Han", email = "[email protected]"},
19
- {name = "Michael Han", email = "[email protected]"},
20
- ]
21
- classifiers = [
22
- "Programming Language :: Python",
23
- ]
24
-
25
- [tool.setuptools.dynamic]
26
- version = {attr = "unsloth.models._utils.__version__"}
27
-
28
- [tool.setuptools]
29
- include-package-data = false
30
-
31
- [tool.setuptools.packages.find]
32
- exclude = ["images*"]
33
-
34
- [project.optional-dependencies]
35
- triton = [
36
- "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'",
37
- "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'",
38
- "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
39
- "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
40
- ]
41
- huggingface = [
42
- "unsloth_zoo>=2024.12.7",
43
- "packaging",
44
- "tyro",
45
- "transformers>=4.46.1,!=4.47.0",
46
- "datasets>=2.16.0",
47
- "sentencepiece>=0.2.0",
48
- "tqdm",
49
- "psutil",
50
- "wheel>=0.42.0",
51
- "numpy",
52
- "accelerate>=0.34.1",
53
- "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3",
54
- "peft>=0.7.1,!=0.11.0",
55
- "protobuf<4.0.0",
56
- "huggingface_hub",
57
- "hf_transfer",
58
- "unsloth[triton]",
59
- ]
60
- cu118only = [
61
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
62
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
63
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
64
- ]
65
- cu121only = [
66
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
67
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
68
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
69
- ]
70
- cu118onlytorch211 = [
71
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
72
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
73
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
74
- ]
75
- cu121onlytorch211 = [
76
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
77
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
78
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
79
- ]
80
- cu118onlytorch212 = [
81
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
82
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
83
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
84
- ]
85
- cu121onlytorch212 = [
86
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23.post1-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
87
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23.post1-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
88
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23.post1-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
89
- ]
90
- cu118onlytorch220 = [
91
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
92
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
93
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
94
- ]
95
- cu121onlytorch220 = [
96
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
97
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
98
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
99
- ]
100
- cu118onlytorch230 = [
101
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
102
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
103
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
104
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp312-cp312-manylinux2014_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
105
- ]
106
- cu121onlytorch230 = [
107
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
108
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
109
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
110
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp312-cp312-manylinux2014_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
111
- ]
112
- cu118onlytorch240 = [
113
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
114
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
115
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
116
- "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp312-cp312-manylinux2014_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
117
- ]
118
- cu121onlytorch240 = [
119
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
120
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
121
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
122
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
123
- ]
124
- cu124onlytorch240 = [
125
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
126
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
127
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
128
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
129
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'",
130
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'",
131
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
132
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
133
- ]
134
- cu121onlytorch250 = [
135
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
136
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
137
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
138
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
139
- ]
140
- cu124onlytorch250 = [
141
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
142
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
143
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
144
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
145
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'",
146
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'",
147
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
148
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
149
- ]
150
- cu121onlytorch251 = [
151
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
152
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
153
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
154
- "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
155
- ]
156
- cu124onlytorch251 = [
157
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
158
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
159
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
160
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
161
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'",
162
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'",
163
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
164
- "xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
165
- ]
166
- cu118 = [
167
- "unsloth[huggingface]",
168
- "bitsandbytes>=0.43.3",
169
- "unsloth[cu118only]",
170
- ]
171
- cu121 = [
172
- "unsloth[huggingface]",
173
- "bitsandbytes>=0.43.3",
174
- "unsloth[cu121only]",
175
- ]
176
- cu118-torch211 = [
177
- "unsloth[huggingface]",
178
- "bitsandbytes>=0.43.3",
179
- "unsloth[cu118onlytorch211]",
180
- ]
181
- cu121-torch211 = [
182
- "unsloth[huggingface]",
183
- "bitsandbytes>=0.43.3",
184
- "unsloth[cu121onlytorch211]",
185
- ]
186
- cu118-torch212 = [
187
- "unsloth[huggingface]",
188
- "bitsandbytes>=0.43.3",
189
- "unsloth[cu118onlytorch212]",
190
- ]
191
- cu121-torch212 = [
192
- "unsloth[huggingface]",
193
- "bitsandbytes>=0.43.3",
194
- "unsloth[cu121onlytorch212]",
195
- ]
196
- cu118-torch220 = [
197
- "unsloth[huggingface]",
198
- "bitsandbytes>=0.43.3",
199
- "unsloth[cu118onlytorch220]",
200
- ]
201
- cu121-torch220 = [
202
- "unsloth[huggingface]",
203
- "bitsandbytes>=0.43.3",
204
- "unsloth[cu121onlytorch220]",
205
- ]
206
- cu118-torch230 = [
207
- "unsloth[huggingface]",
208
- "bitsandbytes>=0.43.3",
209
- "unsloth[cu118onlytorch230]",
210
- ]
211
- cu121-torch230 = [
212
- "unsloth[huggingface]",
213
- "bitsandbytes>=0.43.3",
214
- "unsloth[cu121onlytorch230]",
215
- ]
216
- cu118-torch240 = [
217
- "unsloth[huggingface]",
218
- "bitsandbytes>=0.43.3",
219
- "unsloth[cu118onlytorch240]",
220
- ]
221
- cu121-torch240 = [
222
- "unsloth[huggingface]",
223
- "bitsandbytes>=0.43.3",
224
- "unsloth[cu121onlytorch240]",
225
- ]
226
- cu121-torch250 = [
227
- "unsloth[huggingface]",
228
- "bitsandbytes>=0.43.3",
229
- "unsloth[cu121onlytorch250]",
230
- ]
231
- cu124-torch240 = [
232
- "unsloth[huggingface]",
233
- "bitsandbytes>=0.43.3",
234
- "unsloth[cu124onlytorch240]",
235
- ]
236
- cu124-torch250 = [
237
- "unsloth[huggingface]",
238
- "bitsandbytes>=0.43.3",
239
- "unsloth[cu124onlytorch250]",
240
- ]
241
- cu121-torch251 = [
242
- "unsloth[huggingface]",
243
- "bitsandbytes>=0.43.3",
244
- "unsloth[cu121onlytorch251]",
245
- ]
246
- cu124-torch251 = [
247
- "unsloth[huggingface]",
248
- "bitsandbytes>=0.43.3",
249
- "unsloth[cu124onlytorch251]",
250
- ]
251
- kaggle = [
252
- "unsloth[huggingface]",
253
- ]
254
- kaggle-new = [
255
- "unsloth[huggingface]",
256
- "bitsandbytes>=0.43.3",
257
- ]
258
- conda = [
259
- "unsloth[huggingface]",
260
- ]
261
- colab-torch211 = [
262
- "unsloth[huggingface]",
263
- "bitsandbytes>=0.43.3",
264
- "unsloth[cu121onlytorch211]",
265
- ]
266
- colab-ampere-torch211 = [
267
- "unsloth[huggingface]",
268
- "bitsandbytes>=0.43.3",
269
- "unsloth[cu121onlytorch211]",
270
- "packaging",
271
- "ninja",
272
- "flash-attn>=2.6.3",
273
- ]
274
- colab-torch220 = [
275
- "unsloth[huggingface]",
276
- "bitsandbytes>=0.43.3",
277
- "unsloth[cu121onlytorch220]",
278
- ]
279
- colab-ampere-torch220 = [
280
- "unsloth[huggingface]",
281
- "bitsandbytes>=0.43.3",
282
- "unsloth[cu121onlytorch220]",
283
- "packaging",
284
- "ninja",
285
- "flash-attn>=2.6.3",
286
- ]
287
- colab-new = [
288
- "unsloth_zoo>=2024.12.7",
289
- "packaging",
290
- "tyro",
291
- "transformers>=4.46.1,!=4.47.0",
292
- "datasets>=2.16.0",
293
- "sentencepiece>=0.2.0",
294
- "tqdm",
295
- "psutil",
296
- "wheel>=0.42.0",
297
- "numpy",
298
- "protobuf<4.0.0",
299
- "huggingface_hub",
300
- "hf_transfer",
301
- "bitsandbytes>=0.43.3",
302
- "unsloth[triton]",
303
- ]
304
- colab-no-deps = [
305
- "accelerate>=0.34.1",
306
- "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3",
307
- "peft>=0.7.1",
308
- "xformers",
309
- "bitsandbytes>=0.46.1",
310
- "protobuf<4.0.0",
311
- ]
312
- colab = [
313
- "unsloth[cu121]",
314
- ]
315
- flashattention = [
316
- "packaging ; platform_system == 'Linux'",
317
- "ninja ; platform_system == 'Linux'",
318
- "flash-attn>=2.6.3 ; platform_system == 'Linux'",
319
- ]
320
- colab-ampere = [
321
- "unsloth[colab-ampere-torch220]",
322
- "unsloth[flashattention]",
323
- ]
324
- cu118-ampere = [
325
- "unsloth[huggingface]",
326
- "bitsandbytes>=0.43.3",
327
- "unsloth[cu118only]",
328
- "unsloth[flashattention]",
329
- ]
330
- cu121-ampere = [
331
- "unsloth[huggingface]",
332
- "bitsandbytes>=0.43.3",
333
- "unsloth[cu121only]",
334
- "unsloth[flashattention]",
335
- ]
336
- cu118-ampere-torch211 = [
337
- "unsloth[huggingface]",
338
- "bitsandbytes>=0.43.3",
339
- "unsloth[cu118onlytorch211]",
340
- "unsloth[flashattention]",
341
- ]
342
- cu121-ampere-torch211 = [
343
- "unsloth[huggingface]",
344
- "bitsandbytes>=0.43.3",
345
- "unsloth[cu121onlytorch211]",
346
- "unsloth[flashattention]",
347
- ]
348
- cu118-ampere-torch220 = [
349
- "unsloth[huggingface]",
350
- "bitsandbytes>=0.43.3",
351
- "unsloth[cu118onlytorch220]",
352
- "unsloth[flashattention]",
353
- ]
354
- cu121-ampere-torch220 = [
355
- "unsloth[huggingface]",
356
- "bitsandbytes>=0.43.3",
357
- "unsloth[cu121onlytorch220]",
358
- "unsloth[flashattention]",
359
- ]
360
- cu118-ampere-torch230 = [
361
- "unsloth[huggingface]",
362
- "bitsandbytes>=0.43.3",
363
- "unsloth[cu118onlytorch230]",
364
- "unsloth[flashattention]",
365
- ]
366
- cu121-ampere-torch230 = [
367
- "unsloth[huggingface]",
368
- "bitsandbytes>=0.43.3",
369
- "unsloth[cu121onlytorch230]",
370
- "unsloth[flashattention]",
371
- ]
372
- cu118-ampere-torch240 = [
373
- "unsloth[huggingface]",
374
- "bitsandbytes>=0.43.3",
375
- "unsloth[cu118onlytorch240]",
376
- "unsloth[flashattention]",
377
- ]
378
- cu121-ampere-torch240 = [
379
- "unsloth[huggingface]",
380
- "bitsandbytes>=0.43.3",
381
- "unsloth[cu121onlytorch240]",
382
- "unsloth[flashattention]",
383
- ]
384
- cu121-ampere-torch250 = [
385
- "unsloth[huggingface]",
386
- "bitsandbytes>=0.43.3",
387
- "unsloth[cu121onlytorch250]",
388
- "unsloth[flashattention]",
389
- ]
390
- cu124-ampere-torch240 = [
391
- "unsloth[huggingface]",
392
- "bitsandbytes>=0.43.3",
393
- "unsloth[cu124onlytorch240]",
394
- "unsloth[flashattention]",
395
- ]
396
- cu124-ampere-torch250 = [
397
- "unsloth[huggingface]",
398
- "bitsandbytes>=0.43.3",
399
- "unsloth[cu124onlytorch250]",
400
- "unsloth[flashattention]",
401
- ]
402
- cu121-ampere-torch251 = [
403
- "unsloth[huggingface]",
404
- "bitsandbytes>=0.43.3",
405
- "unsloth[cu121onlytorch251]",
406
- "unsloth[flashattention]",
407
- ]
408
- cu124-ampere-torch251 = [
409
- "unsloth[huggingface]",
410
- "bitsandbytes>=0.43.3",
411
- "unsloth[cu124onlytorch251]",
412
- "unsloth[flashattention]",
413
- ]
414
-
415
- [project.urls]
416
- homepage = "http://www.unsloth.ai"
417
- documentation = "https://github.com/unslothai/unsloth"
418
- repository = "https://github.com/unslothai/unsloth"
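Editor's note: each pinned xformers wheel URL in the extras above is guarded by a PEP 508 environment marker, so pip keeps only the wheel matching the local Python version and operating system. A minimal sketch of how such a marker evaluates, using the packaging library (already a dependency of the extras above); the marker string is copied from one of the cu121onlytorch240 entries:

    from packaging.markers import Marker

    # Marker copied from one of the xformers entries above.
    marker = Marker("python_version=='3.10' and platform_system == 'Linux'")

    # evaluate() checks the marker against the running interpreter: this prints True
    # only on Linux + CPython 3.10, so pip keeps exactly one wheel URL per extra.
    print(marker.evaluate())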
unsloth-main/unsloth-cli.py DELETED
@@ -1,221 +0,0 @@
1
- #!/usr/bin/env python3
2
-
3
- """
4
- 🦥 Starter Script for Fine-Tuning FastLanguageModel with Unsloth
5
-
6
- This script is designed as a starting point for fine-tuning your models using unsloth.
7
- It includes configurable options for model loading, PEFT parameters, training arguments,
8
- and model saving/pushing functionalities.
9
-
10
- You will likely want to customize this script to suit your specific use case
11
- and requirements.
12
-
13
- Here are a few suggestions for customization:
14
- - Modify the dataset loading and preprocessing steps to match your data.
15
- - Customize the model saving and pushing configurations.
16
-
17
- Usage: (most of the options have valid default values; this is an extended example for demonstration purposes)
18
- python unsloth-cli.py --model_name "unsloth/llama-3-8b" --max_seq_length 8192 --dtype None --load_in_4bit \
19
- --r 64 --lora_alpha 32 --lora_dropout 0.1 --bias "none" --use_gradient_checkpointing "unsloth" \
20
- --random_state 3407 --use_rslora --per_device_train_batch_size 4 --gradient_accumulation_steps 8 \
21
- --warmup_steps 5 --max_steps 400 --learning_rate 2e-6 --logging_steps 1 --optim "adamw_8bit" \
22
- --weight_decay 0.005 --lr_scheduler_type "linear" --seed 3407 --output_dir "outputs" \
23
- --report_to "tensorboard" --save_model --save_path "model" --quantization_method "f16" \
24
- --push_model --hub_path "hf/model" --hub_token "your_hf_token"
25
-
26
- To see a full list of configurable options, use:
27
- python unsloth-cli.py --help
28
-
29
- Happy fine-tuning!
30
- """
31
-
32
- import argparse
33
-
34
- def run(args):
35
- import torch
36
- from unsloth import FastLanguageModel
37
- from datasets import load_dataset
38
- from trl import SFTTrainer
39
- from transformers import TrainingArguments
40
- from unsloth import is_bfloat16_supported
41
- import logging
42
- logging.getLogger('hf-to-gguf').setLevel(logging.WARNING)
43
-
44
- # Load model and tokenizer
45
- model, tokenizer = FastLanguageModel.from_pretrained(
46
- model_name=args.model_name,
47
- max_seq_length=args.max_seq_length,
48
- dtype=args.dtype,
49
- load_in_4bit=args.load_in_4bit,
50
- )
51
-
52
- # Configure PEFT model
53
- model = FastLanguageModel.get_peft_model(
54
- model,
55
- r=args.r,
56
- target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
57
- "gate_proj", "up_proj", "down_proj"],
58
- lora_alpha=args.lora_alpha,
59
- lora_dropout=args.lora_dropout,
60
- bias=args.bias,
61
- use_gradient_checkpointing=args.use_gradient_checkpointing,
62
- random_state=args.random_state,
63
- use_rslora=args.use_rslora,
64
- loftq_config=args.loftq_config,
65
- )
66
-
67
- alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
68
-
69
- ### Instruction:
70
- {}
71
-
72
- ### Input:
73
- {}
74
-
75
- ### Response:
76
- {}"""
77
-
78
- EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
79
- def formatting_prompts_func(examples):
80
- instructions = examples["instruction"]
81
- inputs = examples["input"]
82
- outputs = examples["output"]
83
- texts = []
84
- for instruction, input, output in zip(instructions, inputs, outputs):
85
- text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
86
- texts.append(text)
87
- return {"text": texts}
88
-
89
- # Load and format dataset
90
- dataset = load_dataset(args.dataset, split="train")
91
- dataset = dataset.map(formatting_prompts_func, batched=True)
92
- print("Data is formatted and ready!")
93
-
94
- # Configure training arguments
95
- training_args = TrainingArguments(
96
- per_device_train_batch_size=args.per_device_train_batch_size,
97
- gradient_accumulation_steps=args.gradient_accumulation_steps,
98
- warmup_steps=args.warmup_steps,
99
- max_steps=args.max_steps,
100
- learning_rate=args.learning_rate,
101
- fp16=not is_bfloat16_supported(),
102
- bf16=is_bfloat16_supported(),
103
- logging_steps=args.logging_steps,
104
- optim=args.optim,
105
- weight_decay=args.weight_decay,
106
- lr_scheduler_type=args.lr_scheduler_type,
107
- seed=args.seed,
108
- output_dir=args.output_dir,
109
- report_to=args.report_to,
110
- )
111
-
112
- # Initialize trainer
113
- trainer = SFTTrainer(
114
- model=model,
115
- tokenizer=tokenizer,
116
- train_dataset=dataset,
117
- dataset_text_field="text",
118
- max_seq_length=args.max_seq_length,
119
- dataset_num_proc=2,
120
- packing=False,
121
- args=training_args,
122
- )
123
-
124
- # Train model
125
- trainer_stats = trainer.train()
126
-
127
- # Save model
128
- if args.save_model:
129
- # If args.quantization is a list, save the model once per quantization method
130
- if args.save_gguf:
131
- if isinstance(args.quantization, list):
132
- for quantization_method in args.quantization:
133
- print(f"Saving model with quantization method: {quantization_method}")
134
- model.save_pretrained_gguf(
135
- args.save_path,
136
- tokenizer,
137
- quantization_method=quantization_method,
138
- )
139
- if args.push_model:
140
- model.push_to_hub_gguf(
141
- hub_path=args.hub_path,
142
- hub_token=args.hub_token,
143
- quantization_method=quantization_method,
144
- )
145
- else:
146
- print(f"Saving model with quantization method: {args.quantization}")
147
- model.save_pretrained_gguf(args.save_path, tokenizer, quantization_method=args.quantization)
148
- if args.push_model:
149
- model.push_to_hub_gguf(
150
- hub_path=args.hub_path,
151
- hub_token=args.hub_token,
152
- quantization_method=args.quantization,
153
- )
154
- else:
155
- model.save_pretrained_merged(args.save_path, tokenizer, args.save_method)
156
- if args.push_model:
157
- model.push_to_hub_merged(args.save_path, tokenizer, args.hub_token)
158
- else:
159
- print("Warning: The model is not saved!")
160
-
161
-
162
- if __name__ == "__main__":
163
-
164
- # Define argument parser
165
- parser = argparse.ArgumentParser(description="🦥 Fine-tune your llm faster using unsloth!")
166
-
167
- model_group = parser.add_argument_group("🤖 Model Options")
168
- model_group.add_argument('--model_name', type=str, default="unsloth/llama-3-8b", help="Model name to load")
169
- model_group.add_argument('--max_seq_length', type=int, default=2048, help="Maximum sequence length, default is 2048. We auto support RoPE Scaling internally!")
170
- model_group.add_argument('--dtype', type=str, default=None, help="Data type for model (None for auto detection)")
171
- model_group.add_argument('--load_in_4bit', action='store_true', help="Use 4bit quantization to reduce memory usage")
172
- model_group.add_argument('--dataset', type=str, default="yahma/alpaca-cleaned", help="Huggingface dataset to use for training")
173
-
174
- lora_group = parser.add_argument_group("🧠 LoRA Options", "These options are used to configure the LoRA model.")
175
- lora_group.add_argument('--r', type=int, default=16, help="Rank for Lora model, default is 16. (common values: 8, 16, 32, 64, 128)")
176
- lora_group.add_argument('--lora_alpha', type=int, default=16, help="LoRA alpha parameter, default is 16. (common values: 8, 16, 32, 64, 128)")
177
- lora_group.add_argument('--lora_dropout', type=float, default=0, help="LoRA dropout rate, default is 0.0 which is optimized.")
178
- lora_group.add_argument('--bias', type=str, default="none", help="Bias setting for LoRA")
179
- lora_group.add_argument('--use_gradient_checkpointing', type=str, default="unsloth", help="Use gradient checkpointing")
180
- lora_group.add_argument('--random_state', type=int, default=3407, help="Random state for reproducibility, default is 3407.")
181
- lora_group.add_argument('--use_rslora', action='store_true', help="Use rank stabilized LoRA")
182
- lora_group.add_argument('--loftq_config', type=str, default=None, help="Configuration for LoftQ")
183
-
184
-
185
- training_group = parser.add_argument_group("🎓 Training Options")
186
- training_group.add_argument('--per_device_train_batch_size', type=int, default=2, help="Batch size per device during training, default is 2.")
187
- training_group.add_argument('--gradient_accumulation_steps', type=int, default=4, help="Number of gradient accumulation steps, default is 4.")
188
- training_group.add_argument('--warmup_steps', type=int, default=5, help="Number of warmup steps, default is 5.")
189
- training_group.add_argument('--max_steps', type=int, default=400, help="Maximum number of training steps.")
190
- training_group.add_argument('--learning_rate', type=float, default=2e-4, help="Learning rate, default is 2e-4.")
191
- training_group.add_argument('--optim', type=str, default="adamw_8bit", help="Optimizer type.")
192
- training_group.add_argument('--weight_decay', type=float, default=0.01, help="Weight decay, default is 0.01.")
193
- training_group.add_argument('--lr_scheduler_type', type=str, default="linear", help="Learning rate scheduler type, default is 'linear'.")
194
- training_group.add_argument('--seed', type=int, default=3407, help="Seed for reproducibility, default is 3407.")
195
-
196
-
197
- # Report/Logging arguments
198
- report_group = parser.add_argument_group("📊 Report Options")
199
- report_group.add_argument('--report_to', type=str, default="tensorboard",
200
- choices=["azure_ml", "clearml", "codecarbon", "comet_ml", "dagshub", "dvclive", "flyte", "mlflow", "neptune", "tensorboard", "wandb", "all", "none"],
201
- help="The list of integrations to report the results and logs to. Supported platforms are: \n\t\t 'azure_ml', 'clearml', 'codecarbon', 'comet_ml', 'dagshub', 'dvclive', 'flyte', 'mlflow', 'neptune', 'tensorboard', and 'wandb'. Use 'all' to report to all integrations installed, 'none' for no integrations.")
202
- report_group.add_argument('--logging_steps', type=int, default=1, help="Logging steps, default is 1")
203
-
204
- # Saving and pushing arguments
205
- save_group = parser.add_argument_group('💾 Save Model Options')
206
- save_group.add_argument('--output_dir', type=str, default="outputs", help="Output directory")
207
- save_group.add_argument('--save_model', action='store_true', help="Save the model after training")
208
- save_group.add_argument('--save_method', type=str, default="merged_16bit", choices=["merged_16bit", "merged_4bit", "lora"], help="Save method for the model, default is 'merged_16bit'")
209
- save_group.add_argument('--save_gguf', action='store_true', help="Convert the model to GGUF after training")
210
- save_group.add_argument('--save_path', type=str, default="model", help="Path to save the model")
211
- save_group.add_argument('--quantization', type=str, default="q8_0", nargs="+",
212
- help="Quantization method for saving the model. Common values: 'f16', 'q4_k_m', 'q8_0'. See our wiki for all quantization methods: https://github.com/unslothai/unsloth/wiki#saving-to-gguf")
213
-
214
- push_group = parser.add_argument_group('🚀 Push Model Options')
215
- push_group.add_argument('--push_model', action='store_true', help="Push the model to Hugging Face hub after training")
216
- push_group.add_argument('--push_gguf', action='store_true', help="Push the model as GGUF to Hugging Face hub after training")
217
- push_group.add_argument('--hub_path', type=str, default="hf/model", help="Path on Hugging Face hub to push the model")
218
- push_group.add_argument('--hub_token', type=str, help="Token for pushing the model to Hugging Face hub")
219
-
220
- args = parser.parse_args()
221
- run(args)
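Editor's note: formatting_prompts_func above turns each dataset row into one Alpaca-style training string terminated by the EOS token. A minimal standalone sketch of the same transformation; the example row and the EOS token value are made up, whereas the real script reads tokenizer.eos_token:

    # Same three-field Alpaca prompt as in the script, built with explicit newlines.
    alpaca_prompt = (
        "Below is an instruction that describes a task, paired with an input that "
        "provides further context. Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{}\n\n### Input:\n{}\n\n### Response:\n{}"
    )
    EOS_TOKEN = "</s>"  # made up here; the real script uses tokenizer.eos_token

    # One hypothetical dataset row, in the batched column format datasets.map() passes in.
    batch = {
        "instruction": ["Translate to French."],
        "input": ["Good morning"],
        "output": ["Bonjour"],
    }

    texts = [
        alpaca_prompt.format(ins, inp, out) + EOS_TOKEN
        for ins, inp, out in zip(batch["instruction"], batch["input"], batch["output"])
    ]
    print(texts[0])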
unsloth-main/unsloth/__init__.py DELETED
@@ -1,181 +0,0 @@
1
- # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import warnings, importlib, sys
16
- from packaging.version import Version
17
- import os, re, subprocess, inspect
18
- import numpy as np
19
-
20
- # # Define a list of modules to check
21
- # MODULES_TO_CHECK = ["bitsandbytes"]
22
-
23
- # # Check if any of the modules in the list have been imported
24
- # for module in MODULES_TO_CHECK:
25
- # if module in sys.modules:
26
- # raise ImportError(f"Unsloth: Please import Unsloth before {module}.")
27
- # pass
28
- # pass
29
-
30
- # Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so
31
- # enabling it will require much more work, so we have to prioritize. Please understand!
32
- # We do have a beta version, which you can contact us about!
33
- # Thank you for your understanding and we appreciate it immensely!
34
-
35
- # Fixes https://github.com/unslothai/unsloth/issues/1266
36
- os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
37
-
38
- if "CUDA_VISIBLE_DEVICES" in os.environ:
39
- os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
40
- devices = os.environ["CUDA_VISIBLE_DEVICES"]
41
- # Check if there are multiple cuda devices set in env
42
- if not devices.isdigit():
43
- first_id = devices.split(",")[0]
44
- warnings.warn(
45
- f"Unsloth: 'CUDA_VISIBLE_DEVICES' is currently {devices} \n"\
46
- "Unsloth currently does not support multi GPU setups - but we are working on it!\n"\
47
- "Multiple CUDA devices detected but we require a single device.\n"\
48
- f"We will override CUDA_VISIBLE_DEVICES to first device: {first_id}."
49
- )
50
- os.environ["CUDA_VISIBLE_DEVICES"] = str(first_id)
51
- else:
52
- # warnings.warn("Unsloth: 'CUDA_VISIBLE_DEVICES' is not set. We shall set it ourselves.")
53
- os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
54
- os.environ["CUDA_VISIBLE_DEVICES"] = "0"
55
- pass
56
-
57
- # Reduce VRAM usage by reducing fragmentation
58
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,roundup_power2_divisions:[64:128,256:64,>:32]"
59
-
60
- # Hugging Face Hub faster downloads
61
- if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
62
- os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
63
- pass
64
-
65
- # Log Unsloth is being used
66
- os.environ["UNSLOTH_IS_PRESENT"] = "1"
67
-
68
- try:
69
- import torch
70
- except ModuleNotFoundError:
71
- raise ImportError(
72
- "Unsloth: Pytorch is not installed. Go to https://pytorch.org/.\n"\
73
- "We have some installation instructions on our Github page."
74
- )
75
- except Exception as exception:
76
- raise exception
77
- pass
78
-
79
- # We support Pytorch 2
80
- # Fixes https://github.com/unslothai/unsloth/issues/38
81
- torch_version = torch.__version__.split(".")
82
- major_torch, minor_torch = torch_version[0], torch_version[1]
83
- major_torch, minor_torch = int(major_torch), int(minor_torch)
84
- if (major_torch < 2):
85
- raise ImportError("Unsloth only supports Pytorch 2 for now. Please update your Pytorch to 2.1.\n"\
86
- "We have some installation instructions on our Github page.")
87
- elif (major_torch == 2) and (minor_torch < 2):
88
- # Disable expandable_segments
89
- del os.environ["PYTORCH_CUDA_ALLOC_CONF"]
90
- pass
91
-
92
- # Torch 2.4 has including_emulation
93
- major_version, minor_version = torch.cuda.get_device_capability()
94
- SUPPORTS_BFLOAT16 = (major_version >= 8)
95
-
96
- old_is_bf16_supported = torch.cuda.is_bf16_supported
97
- if "including_emulation" in str(inspect.signature(old_is_bf16_supported)):
98
- def is_bf16_supported(including_emulation = False):
99
- return old_is_bf16_supported(including_emulation)
100
- torch.cuda.is_bf16_supported = is_bf16_supported
101
- else:
102
- def is_bf16_supported(): return SUPPORTS_BFLOAT16
103
- torch.cuda.is_bf16_supported = is_bf16_supported
104
- pass
105
-
106
- # Try loading bitsandbytes and triton
107
- import bitsandbytes as bnb
108
-
109
- if "SPACE_AUTHOR_NAME" not in os.environ and "SPACE_REPO_NAME" not in os.environ:
110
-
111
- import triton
112
- libcuda_dirs = lambda: None
113
- if Version(triton.__version__) >= Version("3.0.0"):
114
- try: from triton.backends.nvidia.driver import libcuda_dirs
115
- except: pass
116
- else: from triton.common.build import libcuda_dirs
117
-
118
- try:
119
- cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
120
- libcuda_dirs()
121
- except:
122
- warnings.warn(
123
- "Unsloth: Running `ldconfig /usr/lib64-nvidia` to link CUDA."\
124
- )
125
-
126
- if os.path.exists("/usr/lib64-nvidia"):
127
- os.system("ldconfig /usr/lib64-nvidia")
128
- elif os.path.exists("/usr/local"):
129
- # Sometimes bitsandbytes cannot be linked properly in Runpod for example
130
- possible_cudas = subprocess.check_output(["ls", "-al", "/usr/local"]).decode("utf-8").split("\n")
131
- find_cuda = re.compile(r"[\s](cuda\-[\d\.]{2,})$")
132
- possible_cudas = [find_cuda.search(x) for x in possible_cudas]
133
- possible_cudas = [x.group(1) for x in possible_cudas if x is not None]
134
-
135
- # Try linking cuda folder, or everything in local
136
- if len(possible_cudas) == 0:
137
- os.system("ldconfig /usr/local/")
138
- else:
139
- find_number = re.compile(r"([\d\.]{2,})")
140
- latest_cuda = np.argsort([float(find_number.search(x).group(1)) for x in possible_cudas])[::-1][0]
141
- latest_cuda = possible_cudas[latest_cuda]
142
- os.system(f"ldconfig /usr/local/{latest_cuda}")
143
- pass
144
-
145
- importlib.reload(bnb)
146
- importlib.reload(triton)
147
- try:
148
- libcuda_dirs = lambda: None
149
- if Version(triton.__version__) >= Version("3.0.0"):
150
- try: from triton.backends.nvidia.driver import libcuda_dirs
151
- except: pass
152
- else: from triton.common.build import libcuda_dirs
153
- cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
154
- libcuda_dirs()
155
- except:
156
- warnings.warn(
157
- "Unsloth: CUDA is not linked properly.\n"\
158
- "Try running `python -m bitsandbytes` then `python -m xformers.info`\n"\
159
- "We tried running `ldconfig /usr/lib64-nvidia` ourselves, but it didn't work.\n"\
160
- "You need to run in your terminal `sudo ldconfig /usr/lib64-nvidia` yourself, then import Unsloth.\n"\
161
- "Also try `sudo ldconfig /usr/local/cuda-xx.x` - find the latest cuda version.\n"\
162
- "Unsloth will still run for now, but maybe it might crash - let's hope it works!"
163
- )
164
- pass
165
- pass
166
-
167
- # Check for unsloth_zoo
168
- try:
169
- import unsloth_zoo
170
- except:
171
- raise ImportError("Unsloth: Please install unsloth_zoo via `pip install unsloth-zoo`")
172
- pass
173
-
174
- from .models import *
175
- from .save import *
176
- from .chat_templates import *
177
- from .tokenizer_utils import *
178
- from .trainer import *
179
-
180
- # Patch TRL trainers for backwards compatibility
181
- _patch_trl_trainer()
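Editor's note: the device handling near the top of this file collapses a multi-GPU CUDA_VISIBLE_DEVICES setting down to the first listed device, since multi-GPU training is not supported here. A minimal standalone sketch of that behaviour (pure string handling, no GPU needed; the "2,3" value is made up):

    import os

    os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"   # hypothetical multi-GPU setting

    devices = os.environ["CUDA_VISIBLE_DEVICES"]
    if not devices.isdigit():
        # More than one device id (or a non-numeric one): keep only the first entry,
        # mirroring the single-GPU override performed in __init__.py above.
        os.environ["CUDA_VISIBLE_DEVICES"] = devices.split(",")[0]

    print(os.environ["CUDA_VISIBLE_DEVICES"])    # -> "2"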
unsloth-main/unsloth/_auto_install.py DELETED
@@ -1,31 +0,0 @@
1
- # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- try: import torch
16
- except: raise ImportError('Install torch via `pip install torch`')
17
- from packaging.version import Version as V
18
- v = V(torch.__version__)
19
- cuda = str(torch.version.cuda)
20
- is_ampere = torch.cuda.get_device_capability()[0] >= 8
21
- if cuda != "12.1" and cuda != "11.8" and cuda != "12.4": raise RuntimeError(f"CUDA = {cuda} not supported!")
22
- if v <= V('2.1.0'): raise RuntimeError(f"Torch = {v} too old!")
23
- elif v <= V('2.1.1'): x = 'cu{}{}-torch211'
24
- elif v <= V('2.1.2'): x = 'cu{}{}-torch212'
25
- elif v < V('2.3.0'): x = 'cu{}{}-torch220'
26
- elif v < V('2.4.0'): x = 'cu{}{}-torch230'
27
- elif v < V('2.5.0'): x = 'cu{}{}-torch240'
28
- elif v < V('2.6.0'): x = 'cu{}{}-torch250'
29
- else: raise RuntimeError(f"Torch = {v} too new!")
30
- x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
31
- print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"')
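Editor's note: the script above inspects the installed torch and CUDA versions and prints the pip command for the matching pyproject extra. A minimal sketch of that mapping for one hypothetical environment (torch 2.4.0, CUDA 12.1, Ampere GPU), collapsed to the single branch that applies and written so it runs without torch installed:

    from packaging.version import Version as V

    v, cuda, is_ampere = V("2.4.0"), "12.1", True   # hypothetical environment

    # In the full script this sits inside an if/elif chain over torch versions;
    # only the branch covering 2.4.x is shown here.
    if V("2.4.0") <= v < V("2.5.0"):
        x = "cu{}{}-torch240"
    x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")

    # -> unsloth[cu121-ampere-torch240], one of the extras defined in pyproject.toml above
    print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"')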
unsloth-main/unsloth/chat_templates.py DELETED
@@ -1,2105 +0,0 @@
1
- # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- __all__ = [
16
- "get_chat_template",
17
- "test_chat_templates",
18
- "test_hf_gguf_equivalence",
19
- "remove_special_tokens",
20
-
21
- "to_sharegpt",
22
- "standardize_sharegpt",
23
- "apply_chat_template",
24
- "train_on_responses_only",
25
-
26
- "test_construct_chat_template",
27
- ]
28
-
29
- from transformers import StoppingCriteria, StoppingCriteriaList
30
- from torch import LongTensor, FloatTensor
31
- from transformers.models.llama.modeling_llama import logger
32
- from .save import patch_saving_functions
33
- import os
34
- import shutil
35
- from .tokenizer_utils import *
36
- from .models._utils import patch_tokenizer
37
- import re
38
- from unsloth_zoo.dataset_utils import (
39
- train_on_responses_only,
40
- )
41
- CHAT_TEMPLATES = {}
42
- DEFAULT_SYSTEM_MESSAGE = {}
43
-
44
- # =========================================== Unsloth
45
- # Unsloth efficient template leverages from Zephyr
46
- unsloth_template = \
47
- "{{ bos_token }}"\
48
- "{% if messages[0]['role'] == 'system' %}"\
49
- "{{ messages[0]['content'] + '\n' }}"\
50
- "{% set loop_messages = messages[1:] %}"\
51
- "{% else %}"\
52
- "{{ '{system_message}' + '\n' }}"\
53
- "{% set loop_messages = messages %}"\
54
- "{% endif %}"\
55
- "{% for message in loop_messages %}"\
56
- "{% if message['role'] == 'user' %}"\
57
- "{{ '>>> User: ' + message['content'] + '\n' }}"\
58
- "{% elif message['role'] == 'assistant' %}"\
59
- "{{ '>>> Assistant: ' + message['content'] + eos_token + '\n' }}"\
60
- "{% else %}"\
61
- "{{ raise_exception('Only user and assistant roles are supported!') }}"\
62
- "{% endif %}"\
63
- "{% endfor %}"\
64
- "{% if add_generation_prompt %}"\
65
- "{{ '>>> Assistant: ' }}"\
66
- "{% endif %}"
67
- pass
68
-
69
- unsloth_ollama = \
70
- '''
71
- FROM {__FILE_LOCATION__}
72
- TEMPLATE """{{ if .System }}{{ .System }}
73
- {{ end }}{{ if .Prompt }}>>> User: {{ .Prompt }}
74
- {{ end }}>>> Assistant: {{ .Response }}{__EOS_TOKEN__}
75
- """
76
- PARAMETER stop "{__EOS_TOKEN__}"
77
- PARAMETER temperature 1.5
78
- PARAMETER min_p 0.1
79
- SYSTEM """You are a helpful assistant to the user"""
80
- '''
81
-
82
- unsloth_eos_token = "eos_token"
83
- CHAT_TEMPLATES["unsloth"] = (unsloth_template, unsloth_eos_token, False, unsloth_ollama,)
84
- DEFAULT_SYSTEM_MESSAGE["unsloth"] = "You are a helpful assistant to the user"
85
- pass
86
-
87
- # =========================================== Zephyr
88
- # Zephyr has no BOS!
89
- zephyr_template = \
90
- "{% for message in messages %}"\
91
- "{% if message['role'] == 'user' %}"\
92
- "{{ '<|user|>\n' + message['content'] + eos_token + '\n' }}"\
93
- "{% elif message['role'] == 'assistant' %}"\
94
- "{{ '<|assistant|>\n' + message['content'] + eos_token + '\n' }}"\
95
- "{% else %}"\
96
- "{{ '<|system|>\n' + message['content'] + eos_token + '\n' }}"\
97
- "{% endif %}"\
98
- "{% endfor %}"\
99
- "{% if add_generation_prompt %}"\
100
- "{{ '<|assistant|>\n' }}"\
101
- "{% endif %}"
102
- pass
103
-
104
- zephyr_ollama = \
105
- '''
106
- FROM {__FILE_LOCATION__}
107
- TEMPLATE """{{ if .System }}<|system|>
108
- {{ .System }}{__EOS_TOKEN__}
109
- {{ end }}{{ if .Prompt }}<|user|>
110
- {{ .Prompt }}{__EOS_TOKEN__}
111
- {{ end }}<|assistant|>
112
- {{ .Response }}{__EOS_TOKEN__}
113
- """
114
- PARAMETER stop "{__EOS_TOKEN__}"
115
- PARAMETER temperature 1.5
116
- PARAMETER min_p 0.1
117
- '''
118
-
119
- zephyr_eos_token = "eos_token"
120
- CHAT_TEMPLATES["zephyr"] = (zephyr_template, zephyr_eos_token, False, zephyr_ollama,)
121
- DEFAULT_SYSTEM_MESSAGE["zephyr"] = None # No system message in Zephyr
122
- pass
123
-
124
- # =========================================== ChatML
125
- # ChatML has no BOS and no EOS! Rather, <|im_start|> and <|im_end|> act as BOS / EOS.
126
- chatml_template = \
127
- "{% for message in messages %}"\
128
- "{% if message['role'] == 'user' %}"\
129
- "{{'<|im_start|>user\n' + message['content'] + '<|im_end|>\n'}}"\
130
- "{% elif message['role'] == 'assistant' %}"\
131
- "{{'<|im_start|>assistant\n' + message['content'] + '<|im_end|>\n' }}"\
132
- "{% else %}"\
133
- "{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>\n' }}"\
134
- "{% endif %}"\
135
- "{% endfor %}"\
136
- "{% if add_generation_prompt %}"\
137
- "{{ '<|im_start|>assistant\n' }}"\
138
- "{% endif %}"
139
- pass
140
-
141
- chatml_ollama = \
142
- '''
143
- FROM {__FILE_LOCATION__}
144
- TEMPLATE """{{ if .System }}<|im_start|>system
145
- {{ .System }}<|im_end|>
146
- {{ end }}{{ if .Prompt }}<|im_start|>user
147
- {{ .Prompt }}<|im_end|>
148
- {{ end }}<|im_start|>assistant
149
- {{ .Response }}<|im_end|>
150
- """
151
- PARAMETER stop "<|im_start|>"
152
- PARAMETER stop "<|im_end|>"
153
- PARAMETER temperature 1.5
154
- PARAMETER min_p 0.1
155
- '''
156
-
157
- chatml_eos_token = "<|im_end|>"
158
- CHAT_TEMPLATES["chatml"] = (chatml_template, chatml_eos_token, True, chatml_ollama,)
159
- DEFAULT_SYSTEM_MESSAGE["chatml"] = None # No system message in ChatML
160
- pass
161
-
162
- # =========================================== Mistral-1
163
- # Mistral Instruct doesn't allow system prompts, so we prepend the system message to the first user message.
164
- mistral_template = \
165
- "{{ bos_token }}"\
166
- "{% if messages[0]['role'] == 'system' %}"\
167
- "{% if messages[1]['role'] == 'user' %}"\
168
- "{{ '[INST] ' + messages[0]['content'] + ' ' + messages[1]['content'] + ' [/INST]' }}"\
169
- "{% set loop_messages = messages[2:] %}"\
170
- "{% else %}"\
171
- "{{ '[INST] ' + messages[0]['content'] + ' [/INST]' }}"\
172
- "{% set loop_messages = messages[1:] %}"\
173
- "{% endif %}"\
174
- "{% else %}"\
175
- "{% set loop_messages = messages %}"\
176
- "{% endif %}"\
177
- "{% for message in loop_messages %}"\
178
- "{% if message['role'] == 'user' %}"\
179
- "{{ '[INST] ' + message['content'] + ' [/INST]' }}"\
180
- "{% elif message['role'] == 'assistant' %}"\
181
- "{{ message['content'] + eos_token }}"\
182
- "{% else %}"\
183
- "{{ raise_exception('Only user and assistant roles are supported!') }}"\
184
- "{% endif %}"\
185
- "{% endfor %}"
186
- pass
187
-
188
- # Ollama from https://www.ollama.com/library/mistral
189
- mistral_ollama = \
190
- '''
191
- FROM {__FILE_LOCATION__}
192
- TEMPLATE """[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]"""
193
- PARAMETER stop "{__EOS_TOKEN__}"
194
- PARAMETER temperature 1.5
195
- PARAMETER min_p 0.1
196
- '''
197
-
198
- mistral_eos_token = "eos_token"
199
- CHAT_TEMPLATES["mistral"] = (mistral_template, mistral_eos_token, False, mistral_ollama,)
200
- DEFAULT_SYSTEM_MESSAGE["mistral"] = None # No system message in Mistral
201
- pass
202
-
203
- # =========================================== Llama-2
204
- # Adds BOS to every convo! And weird <<SYS>> system messages.
205
- llama_template = \
206
- "{% if messages[0]['role'] == 'system' %}"\
207
- "{% if messages[1]['role'] == 'user' %}"\
208
- "{{ bos_token + '[INST] <<SYS>>\n' + messages[0]['content'] + '\n<</SYS>>\n\n' + messages[1]['content'] + ' [/INST]' }}"\
209
- "{% set loop_messages = messages[2:] %}"\
210
- "{% else %}"\
211
- "{{ bos_token + '[INST] ' + messages[0]['content'] + ' [/INST]' }}"\
212
- "{% set loop_messages = messages[1:] %}"\
213
- "{% endif %}"\
214
- "{% else %}"\
215
- "{% set loop_messages = messages %}"\
216
- "{% endif %}"\
217
- "{% for message in loop_messages %}"\
218
- "{% if message['role'] == 'user' %}"\
219
- "{{ bos_token + '[INST] ' + message['content'].strip() + ' [/INST]' }}"\
220
- "{% elif message['role'] == 'assistant' %}"\
221
- "{{ ' ' + message['content'].strip() + ' ' + eos_token }}"\
222
- "{% else %}"\
223
- "{{ raise_exception('Only user and assistant roles are supported!') }}"\
224
- "{% endif %}"\
225
- "{% endfor %}"
226
- pass
227
-
228
- # Ollama from https://www.ollama.com/library/llama3
229
- llama_ollama = \
230
- '''
231
- FROM {__FILE_LOCATION__}
232
- TEMPLATE """[INST] <<SYS>>{{ .System }}<</SYS>>
233
-
234
- {{ .Prompt }} [/INST]"""
235
- PARAMETER stop "{__EOS_TOKEN__}"
236
- PARAMETER temperature 1.5
237
- PARAMETER min_p 0.1
238
- '''
239
-
240
- llama_eos_token = "eos_token"
241
- CHAT_TEMPLATES["llama"] = (llama_template, llama_eos_token, False, llama_ollama,)
242
- DEFAULT_SYSTEM_MESSAGE["llama"] = None # No system message in Llama
243
- pass
244
-
245
- # =========================================== Vicuna
246
- # https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
247
- vicuna_template = \
248
- "{{ bos_token }}"\
249
- "{% if messages[0]['role'] == 'system' %}"\
250
- "{{ messages[0]['content'] + ' ' }}"\
251
- "{% set loop_messages = messages[1:] %}"\
252
- "{% else %}"\
253
- "{{ '{system_message}' + ' ' }}"\
254
- "{% set loop_messages = messages %}"\
255
- "{% endif %}"\
256
- "{% for message in loop_messages %}"\
257
- "{% if message['role'] == 'user' %}"\
258
- "{{ 'USER: ' + message['content'] + ' ' }}"\
259
- "{% elif message['role'] == 'assistant' %}"\
260
- "{{ 'ASSISTANT: ' + message['content'] + eos_token }}"\
261
- "{% else %}"\
262
- "{{ raise_exception('Only user and assistant roles are supported!') }}"\
263
- "{% endif %}"\
264
- "{% endfor %}"\
265
- "{% if add_generation_prompt %}"\
266
- "{{ 'ASSISTANT:' }}"\
267
- "{% endif %}"
268
- pass
269
-
270
- # Ollama from https://www.ollama.com/library/vicuna
271
- vicuna_ollama = \
272
- '''
273
- FROM {__FILE_LOCATION__}
274
- TEMPLATE """{{ if .System }}{{ .System }} {{ end }}{{ if .Prompt }}USER: {{ .Prompt }} {{ end }}ASSISTANT: {{ .Response }} {__EOS_TOKEN__}"""
275
- PARAMETER stop "{__EOS_TOKEN__}"
276
- PARAMETER temperature 1.5
277
- PARAMETER min_p 0.1
278
- '''
279
-
280
- vicuna_eos_token = "eos_token"
281
- CHAT_TEMPLATES["vicuna"] = (vicuna_template, vicuna_eos_token, False, vicuna_ollama,)
282
- DEFAULT_SYSTEM_MESSAGE["vicuna"] = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
283
- pass
284
-
285
- # =========================================== Vicuna Old
286
- # https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
287
- vicuna_old_template = \
288
- "{{ bos_token }}"\
289
- "{% if messages[0]['role'] == 'system' %}"\
290
- "{{ messages[0]['content'] + '\n' }}"\
291
- "{% set loop_messages = messages[1:] %}"\
292
- "{% else %}"\
293
- "{{ '{system_message}' + '\n' }}"\
294
- "{% set loop_messages = messages %}"\
295
- "{% endif %}"\
296
- "{% for message in loop_messages %}"\
297
- "{% if message['role'] == 'user' %}"\
298
- "{{ '### Human: ' + message['content'] + '\n' }}"\
299
- "{% elif message['role'] == 'assistant' %}"\
300
- "{{ '### Assistant: ' + message['content'] + eos_token + '\n' }}"\
301
- "{% else %}"\
302
- "{{ raise_exception('Only user and assistant roles are supported!') }}"\
303
- "{% endif %}"\
304
- "{% endfor %}"\
305
- "{% if add_generation_prompt %}"\
306
- "{{ '### Assistant:' }}"\
307
- "{% endif %}"
308
- pass
309
-
310
- vicuna_old_ollama = \
311
- '''
312
- FROM {__FILE_LOCATION__}
313
- TEMPLATE """{{ if .System }}{{ .System }}
314
- {{ end }}{{ if .Prompt }}### Human: {{ .Prompt }}
315
- {{ end }}### Assistant: {{ .Response }}{__EOS_TOKEN__}
316
- """
317
- PARAMETER stop "{__EOS_TOKEN__}"
318
- PARAMETER temperature 1.5
319
- PARAMETER min_p 0.1
320
- SYSTEM """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."""
321
- '''
322
-
323
- vicuna_old_eos_token = "eos_token"
324
- CHAT_TEMPLATES["vicuna_old"] = (vicuna_old_template, vicuna_old_eos_token, False, vicuna_old_ollama,)
325
- DEFAULT_SYSTEM_MESSAGE["vicuna_old"] = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\\'s questions."
326
-
327
- CHAT_TEMPLATES["vicuna old"] = CHAT_TEMPLATES["vicuna_old"]
328
- DEFAULT_SYSTEM_MESSAGE["vicuna old"] = DEFAULT_SYSTEM_MESSAGE["vicuna_old"]
329
- pass
330
-
331
- # =========================================== Alpaca multi turn
332
- # https://github.com/tatsu-lab/stanford_alpaca Changed for multi-turn convos
333
- alpaca_template = \
334
- "{{ bos_token }}"\
335
- "{% if messages[0]['role'] == 'system' %}"\
336
- "{{ messages[0]['content'] + '\n\n' }}"\
337
- "{% set loop_messages = messages[1:] %}"\
338
- "{% else %}"\
339
- "{{ '{system_message}' + '\n\n' }}"\
340
- "{% set loop_messages = messages %}"\
341
- "{% endif %}"\
342
- "{% for message in loop_messages %}"\
343
- "{% if message['role'] == 'user' %}"\
344
- "{{ '### Instruction:\n' + message['content'] + '\n\n' }}"\
345
- "{% elif message['role'] == 'assistant' %}"\
346
- "{{ '### Response:\n' + message['content'] + eos_token + '\n\n' }}"\
347
- "{% else %}"\
348
- "{{ raise_exception('Only user and assistant roles are supported!') }}"\
349
- "{% endif %}"\
350
- "{% endfor %}"\
351
- "{% if add_generation_prompt %}"\
352
- "{{ '### Response:\n' }}"\
353
- "{% endif %}"
354
- pass
355
-
356
- alpaca_ollama = \
357
- '''
358
- FROM {__FILE_LOCATION__}
359
- TEMPLATE """{{ if .System }}{{ .System }}
360
-
361
- {{ end }}{{ if .Prompt }}### Instruction:
362
- {{ .Prompt }}{{ end }}
363
-
364
- ### Response:
365
- {{ .Response }}{__EOS_TOKEN__}
366
-
367
- """
368
- PARAMETER stop "{__EOS_TOKEN__}"
369
- PARAMETER temperature 1.5
370
- PARAMETER min_p 0.1
371
- SYSTEM """Below are some instructions that describe some tasks. Write responses that appropriately complete each request."""
372
- '''
373
-
374
- alpaca_eos_token = "eos_token"
375
- CHAT_TEMPLATES["alpaca"] = (alpaca_template, alpaca_eos_token, False, alpaca_ollama,)
376
- DEFAULT_SYSTEM_MESSAGE["alpaca"] = "Below are some instructions that describe some tasks. Write responses that appropriately complete each request."
377
- pass
378
-
379
- # =========================================== Gemma
380
- # https://huggingface.co/google/gemma-7b-it
381
- # Notice we must use |trim for lstrip and rstrip. <start_of_turn> maps to 106.
382
- # <end_of_turn> maps to 107. user and model are normal 1 word tokens.
383
- gemma_template = \
384
- "{{ bos_token }}"\
385
- "{% if messages[0]['role'] == 'system' %}"\
386
- "{{'<start_of_turn>user\n' + messages[0]['content'] | trim + ' ' + messages[1]['content'] | trim + '<end_of_turn>\n'}}"\
387
- "{% set messages = messages[2:] %}"\
388
- "{% endif %}"\
389
- "{% for message in messages %}"\
390
- "{% if message['role'] == 'user' %}"\
391
- "{{'<start_of_turn>user\n' + message['content'] | trim + '<end_of_turn>\n'}}"\
392
- "{% elif message['role'] == 'assistant' %}"\
393
- "{{'<start_of_turn>model\n' + message['content'] | trim + '<end_of_turn>\n' }}"\
394
- "{% else %}"\
395
- "{{ raise_exception('Only user and assistant roles are supported!') }}"\
396
- "{% endif %}"\
397
- "{% endfor %}"\
398
- "{% if add_generation_prompt %}"\
399
- "{{ '<start_of_turn>model\n' }}"\
400
- "{% endif %}"
401
- pass
402
-
403
- # Ollama from https://www.ollama.com/library/gemma
404
- gemma_ollama = \
405
- '''
406
- FROM {__FILE_LOCATION__}
407
- TEMPLATE """<start_of_turn>user
408
- {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}<end_of_turn>
409
- <start_of_turn>model
410
- {{ .Response }}<end_of_turn>
411
- """
412
- PARAMETER repeat_penalty 1
413
- PARAMETER stop "<start_of_turn>"
414
- PARAMETER stop "<end_of_turn>"
415
- PARAMETER penalize_newline false
416
- PARAMETER temperature 1.5
417
- PARAMETER min_p 0.1
418
- '''
419
-
420
- gemma_eos_token = "<end_of_turn>"
421
- CHAT_TEMPLATES["gemma"] = (gemma_template, gemma_eos_token, True, gemma_ollama,)
422
- DEFAULT_SYSTEM_MESSAGE["gemma"] = None # No system message in Gemma
423
- pass
424
-
425
- # =========================================== Gemma with ChatML instead
426
- # We find using <eos> is still more appropriate!
427
- gemma_chatml_template = "{{ bos_token }}" + chatml_template
428
- pass
429
-
430
- gemma_chatml_ollama = \
431
- '''
432
- FROM {__FILE_LOCATION__}
433
- TEMPLATE """{{ if .System }}<|im_start|>system
434
- {{ .System }}<|im_end|>
435
- {{ end }}{{ if .Prompt }}<|im_start|>user
436
- {{ .Prompt }}<|im_end|>
437
- {{ end }}<|im_start|>assistant
438
- {{ .Response }}<|im_end|>
439
- """
440
- PARAMETER repeat_penalty 1
441
- PARAMETER stop "<|im_start|>"
442
- PARAMETER stop "<|im_end|>"
443
- PARAMETER penalize_newline false
444
- PARAMETER temperature 1.5
445
- PARAMETER min_p 0.1
446
- '''
447
-
448
- gemma_chatml_eos_token = (
449
- {"<start_of_turn>" : "<|im_start|>", "<eos>" : "<|im_end|>"},
450
- "<|im_end|>",
451
- )
452
- CHAT_TEMPLATES["gemma_chatml"] = (gemma_chatml_template, gemma_chatml_eos_token, True, gemma_chatml_ollama,)
453
- DEFAULT_SYSTEM_MESSAGE["gemma_chatml"] = None # No system message in Gemma
454
- pass
455
-
456
- # =========================================== Gemma 2
457
- # Same as Gemma 1, but with sliding window attention!
458
- # https://ollama.com/library/gemma2/blobs/6522ca797f47
459
- gemma2_template = gemma_template
460
- gemma2_ollama = gemma_ollama + "PARAMETER num_ctx 4096\n"
461
- gemma2_eos_token = "<end_of_turn>"
462
- CHAT_TEMPLATES["gemma2"] = (gemma2_template, gemma2_eos_token, True, gemma2_ollama,)
463
- DEFAULT_SYSTEM_MESSAGE["gemma2"] = None # No system message in Gemma 2
464
-
465
- # =========================================== Gemma 2 with ChatML instead
466
- gemma2_chatml_template = gemma_chatml_template
467
- gemma2_chatml_ollama = gemma_chatml_ollama + "PARAMETER num_ctx 4096\n"
468
- gemma2_chatml_eos_token = gemma_chatml_eos_token
469
- CHAT_TEMPLATES["gemma2_chatml"] = (gemma2_chatml_template, gemma2_chatml_eos_token, True, gemma2_chatml_ollama,)
470
- DEFAULT_SYSTEM_MESSAGE["gemma2_chatml"] = None # No system message in Gemma 2
471
- pass
472
-
473
- # =========================================== Llama-3
474
- # Weirdly \n\n is needed?
475
- llama3_template = \
476
- "{{ bos_token }}"\
477
- "{% for message in messages %}"\
478
- "{% if message['role'] == 'user' %}"\
479
- "{{ '<|start_header_id|>user<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}"\
480
- "{% elif message['role'] == 'assistant' %}"\
481
- "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}"\
482
- "{% else %}"\
483
- "{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}"\
484
- "{% endif %}"\
485
- "{% endfor %}"\
486
- "{% if add_generation_prompt %}"\
487
- "{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"\
488
- "{% endif %}"
489
- pass
490
-
491
- # Ollama from https://www.ollama.com/library/llama3
492
- llama3_ollama = \
493
- '''
494
- FROM {__FILE_LOCATION__}
495
- TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
496
-
497
- {{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
498
-
499
- {{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
500
-
501
- {{ .Response }}<|eot_id|>"""
502
- PARAMETER stop "<|start_header_id|>"
503
- PARAMETER stop "<|end_header_id|>"
504
- PARAMETER stop "<|eot_id|>"
505
- PARAMETER temperature 1.5
506
- PARAMETER min_p 0.1
507
- '''
508
-
509
- llama3_template_eos_token = "eos_token"
510
-
511
- CHAT_TEMPLATES["llama-3"] = (llama3_template, llama3_template_eos_token, False, llama3_ollama,)
512
- DEFAULT_SYSTEM_MESSAGE["llama-3"] = None # No system message in Llama-3
513
-
514
- CHAT_TEMPLATES["llama3"] = (llama3_template, llama3_template_eos_token, False, llama3_ollama,)
515
- DEFAULT_SYSTEM_MESSAGE["llama3"] = None # No system message in Llama-3
516
- pass
517
-
518
-
519
- # =========================================== Phi-3
520
- # "{{ bos_token }}"\ # Phi-3.5 removes BOS?
521
- phi3_template = \
522
- "{% for message in messages %}"\
523
- "{% if message['role'] == 'user' %}"\
524
- "{{'<|user|>\n' + message['content'] + '<|end|>\n'}}"\
525
- "{% elif message['role'] == 'assistant' %}"\
526
- "{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}"\
527
- "{% else %}"\
528
- "{{'<|' + message['role'] + '|>\n' + message['content'] + '<|end|>\n'}}"\
529
- "{% endif %}"\
530
- "{% endfor %}"\
531
- "{% if add_generation_prompt %}"\
532
- "{{ '<|assistant|>\n' }}"\
533
- "{% endif %}"
534
- pass
535
-
536
- # Ollama from https://www.ollama.com/library/phi3
537
- phi3_ollama = \
538
- '''
539
- FROM {__FILE_LOCATION__}
540
- TEMPLATE """{{ if .System }}<|system|>
541
- {{ .System }}<|end|>
542
- {{ end }}{{ if .Prompt }}<|user|>
543
- {{ .Prompt }}<|end|>
544
- {{ end }}<|assistant|>
545
- {{ .Response }}<|end|>
546
- """
547
- PARAMETER stop "<|end|>"
548
- PARAMETER stop "<|user|>"
549
- PARAMETER stop "<|assistant|>"
550
- PARAMETER temperature 1.5
551
- PARAMETER min_p 0.1
552
- '''
553
-
554
- phi3_template_eos_token = "<|end|>"
555
- CHAT_TEMPLATES["phi-3"] = (phi3_template, phi3_template_eos_token, False, phi3_ollama,)
556
- DEFAULT_SYSTEM_MESSAGE["phi-3"] = None # No system message in Phi-3
557
-
558
- CHAT_TEMPLATES["phi-35"] = CHAT_TEMPLATES["phi-3"]
559
- DEFAULT_SYSTEM_MESSAGE["phi-35"] = None # No system message in Phi-3.5
560
-
561
- CHAT_TEMPLATES["phi-3.5"] = CHAT_TEMPLATES["phi-3"]
562
- DEFAULT_SYSTEM_MESSAGE["phi-3.5"] = None # No system message in Phi-3.5
563
- pass
564
-
565
- # =========================================== Llama-3.1
566
- """
567
- No trimming in Llama 3.1 Instruct!
568
- Also an extra newline for Cutting Knowledge Date
569
- See https://colab.research.google.com/drive/1Xpqq5xpIgO-B00MQ-UccYMwN2J8QFgBM?usp=sharing
570
-
571
- Also, the template should be applied with a date_string, for example:
572
-
573
- import datetime
574
- tokenizer.apply_chat_template(
575
- messages,
576
- add_generation_prompt = True,
577
- tokenize = False,
578
- date_string = datetime.date.today().strftime("%d %B %Y"),
579
- )
580
- """
581
-
582
- llama31_template = \
583
- """{{- bos_token }}
584
- {%- if custom_tools is defined %}
585
- {%- set tools = custom_tools %}
586
- {%- endif %}
587
- {%- if not tools_in_user_message is defined %}
588
- {%- set tools_in_user_message = true %}
589
- {%- endif %}
590
- {%- if not date_string is defined %}
591
- {%- set date_string = "26 July 2024" %}
592
- {%- endif %}
593
- {%- if not tools is defined %}
594
- {%- set tools = none %}
595
- {%- endif %}
596
-
597
- {#- This block extracts the system message, so we can slot it into the right place. #}
598
- {%- if messages[0]['role'] == 'system' %}
599
- {%- set system_message = messages[0]['content'] %}
600
- {%- set messages = messages[1:] %}
601
- {%- else %}
602
- {%- set system_message = "{system_message}" %}
603
- {%- endif %}
604
-
605
- {#- System message + builtin tools #}
606
- {{- "<|start_header_id|>system<|end_header_id|>\n\n" }}
607
- {%- if builtin_tools is defined or tools is not none %}
608
- {{- "Environment: ipython\n" }}
609
- {%- endif %}
610
- {%- if builtin_tools is defined %}
611
- {{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\n\n"}}
612
- {%- endif %}
613
- {{- "Cutting Knowledge Date: December 2023\n" }}
614
- {{- "Today Date: " + date_string + "\n\n" }}
615
- {%- if tools is not none and not tools_in_user_message %}
616
- {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}
617
- {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
618
- {{- "Do not use variables.\n\n" }}
619
- {%- for t in tools %}
620
- {{- t | tojson(indent=4) }}
621
- {{- "\n\n" }}
622
- {%- endfor %}
623
- {%- endif %}
624
- {{- system_message }}
625
- {{- "<|eot_id|>" }}
626
-
627
- {#- Custom tools are passed in a user message with some extra guidance #}
628
- {%- if tools_in_user_message and not tools is none %}
629
- {#- Extract the first user message so we can plug it in here #}
630
- {%- if messages | length != 0 %}
631
- {%- set first_user_message = messages[0]['content'] %}
632
- {%- set messages = messages[1:] %}
633
- {%- else %}
634
- {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
635
- {%- endif %}
636
- {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}}
637
- {{- "Given the following functions, please respond with a JSON for a function call " }}
638
- {{- "with its proper arguments that best answers the given prompt.\n\n" }}
639
- {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
640
- {{- "Do not use variables.\n\n" }}
641
- {%- for t in tools %}
642
- {{- t | tojson(indent=4) }}
643
- {{- "\n\n" }}
644
- {%- endfor %}
645
- {{- first_user_message + "<|eot_id|>"}}
646
- {%- endif %}
647
-
648
- {%- for message in messages %}
649
- {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
650
- {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] + '<|eot_id|>' }}
651
- {%- elif 'tool_calls' in message %}
652
- {%- if not message.tool_calls|length == 1 %}
653
- {{- raise_exception("This model only supports single tool-calls at once!") }}
654
- {%- endif %}
655
- {%- set tool_call = message.tool_calls[0].function %}
656
- {%- if builtin_tools is defined and tool_call.name in builtin_tools %}
657
- {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
658
- {{- "<|python_tag|>" + tool_call.name + ".call(" }}
659
- {%- for arg_name, arg_val in tool_call.arguments | items %}
660
- {{- arg_name + '="' + arg_val + '"' }}
661
- {%- if not loop.last %}
662
- {{- ", " }}
663
- {%- endif %}
664
- {%- endfor %}
665
- {{- ")" }}
666
- {%- else %}
667
- {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
668
- {{- '{"name": "' + tool_call.name + '", ' }}
669
- {{- '"parameters": ' }}
670
- {{- tool_call.arguments | tojson }}
671
- {{- "}" }}
672
- {%- endif %}
673
- {%- if builtin_tools is defined %}
674
- {#- This means we're in ipython mode #}
675
- {{- "<|eom_id|>" }}
676
- {%- else %}
677
- {{- "<|eot_id|>" }}
678
- {%- endif %}
679
- {%- elif message.role == "tool" or message.role == "ipython" %}
680
- {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }}
681
- {%- if message.content is mapping or message.content is iterable %}
682
- {{- message.content | tojson }}
683
- {%- else %}
684
- {{- message.content }}
685
- {%- endif %}
686
- {{- "<|eot_id|>" }}
687
- {%- endif %}
688
- {%- endfor %}
689
- {%- if add_generation_prompt %}
690
- {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
691
- {%- endif %}
692
- """
693
- pass
694
-
695
- # Ollama from https://ollama.com/library/llama3.1 (needs updating!)
696
- llama31_ollama = \
697
- '''
698
- FROM {__FILE_LOCATION__}
699
- TEMPLATE """{{ if .Messages }}
700
- {{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
701
- {{- if .System }}
702
-
703
- {{ .System }}
704
- {{- end }}
705
- {{- if .Tools }}
706
-
707
- You are a helpful assistant with tool calling capabilities. When you receive a tool call response, use the output to format an answer to the original user question.
708
- {{- end }}
709
- {{- end }}<|eot_id|>
710
- {{- range $i, $_ := .Messages }}
711
- {{- $last := eq (len (slice $.Messages $i)) 1 }}
712
- {{- if eq .Role "user" }}<|start_header_id|>user<|end_header_id|>
713
- {{- if and $.Tools $last }}
714
-
715
- Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.
716
-
717
- Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.
718
-
719
- {{ $.Tools }}
720
- {{- end }}
721
-
722
- {{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
723
-
724
- {{ end }}
725
- {{- else if eq .Role "assistant" }}<|start_header_id|>assistant<|end_header_id|>
726
- {{- if .ToolCalls }}
727
-
728
- {{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }}
729
- {{- else }}
730
-
731
- {{ .Content }}{{ if not $last }}<|eot_id|>{{ end }}
732
- {{- end }}
733
- {{- else if eq .Role "tool" }}<|start_header_id|>ipython<|end_header_id|>
734
-
735
- {{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
736
-
737
- {{ end }}
738
- {{- end }}
739
- {{- end }}
740
- {{- else }}
741
- {{- if .System }}<|start_header_id|>system<|end_header_id|>
742
-
743
- {{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
744
-
745
- {{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
746
-
747
- {{ end }}{{ .Response }}{{ if .Response }}<|eot_id|>{{ end }}"""
748
- PARAMETER stop "<|start_header_id|>"
749
- PARAMETER stop "<|end_header_id|>"
750
- PARAMETER stop "<|eot_id|>"
751
- PARAMETER stop "<|eom_id|>"
752
- PARAMETER temperature 1.5
753
- PARAMETER min_p 0.1
754
- '''
755
-
756
- llama31_template_eos_token = "eos_token"
757
- CHAT_TEMPLATES["llama-3.1"] = (llama31_template, llama31_template_eos_token, False, llama31_ollama,)
758
- DEFAULT_SYSTEM_MESSAGE["llama-3.1"] = "" # Llama3.1 default system message is empty + the dates
759
-
760
- CHAT_TEMPLATES["llama-31"] = (llama31_template, llama31_template_eos_token, False, llama31_ollama,)
761
- DEFAULT_SYSTEM_MESSAGE["llama-31"] = "" # Llama3.1 default system message is empty + the dates
762
- pass
763
-
764
-
765
- # =========================================== Qwen 2.5
766
- qwen25_template = \
767
- """{%- if tools %}
768
- {{- \'<|im_start|>system\\n\' }}
769
- {%- if messages[0][\'role\'] == \'system\' %}
770
- {{- messages[0][\'content\'] }}
771
- {%- else %}
772
- {{- \'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\' }}
773
- {%- endif %}
774
- {{- "\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>" }}
775
- {%- for tool in tools %}
776
- {{- "\\n" }}
777
- {{- tool | tojson }}
778
- {%- endfor %}
779
- {{- "\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\"name\\": <function-name>, \\"arguments\\": <args-json-object>}\\n</tool_call><|im_end|>\\n" }}\n{%- else %}
780
- {%- if messages[0][\'role\'] == \'system\' %}
781
- {{- \'<|im_start|>system\\n\' + messages[0][\'content\'] + \'<|im_end|>\\n\' }}
782
- {%- else %}
783
- {{- \'<|im_start|>system\\n{system_message}<|im_end|>\\n\' }}
784
- {%- endif %}\n{%- endif %}\n{%- for message in messages %}
785
- {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
786
- {{- \'<|im_start|>\' + message.role + \'\\n\' + message.content + \'<|im_end|>\' + \'\\n\' }}
787
- {%- elif message.role == "assistant" %}
788
- {{- \'<|im_start|>\' + message.role }}
789
- {%- if message.content %}
790
- {{- \'\\n\' + message.content }}
791
- {%- endif %}
792
- {%- for tool_call in message.tool_calls %}
793
- {%- if tool_call.function is defined %}
794
- {%- set tool_call = tool_call.function %}
795
- {%- endif %}
796
- {{- \'\\n<tool_call>\\n{"name": "\' }}
797
- {{- tool_call.name }}
798
- {{- \'", "arguments": \' }}
799
- {{- tool_call.arguments | tojson }}
800
- {{- \'}\\n</tool_call>\' }}
801
- {%- endfor %}
802
- {{- \'<|im_end|>\\n\' }}
803
- {%- elif message.role == "tool" %}
804
- {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %} {{- \'<|im_start|>user\' }}
805
- {%- endif %}
806
- {{- \'\\n<tool_response>\\n\' }}
807
- {{- message.content }}
808
- {{- \'\\n</tool_response>\' }}
809
- {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
810
- {{- \'<|im_end|>\\n\' }}
811
- {%- endif %}
812
- {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}
813
- {{- \'<|im_start|>assistant\\n\' }}
814
- {%- endif %}
815
- """
816
-
817
-
818
- # Ollama from https://ollama.com/library/qwen2.5/blobs/eb4402837c78
819
- qwen25_ollama = \
820
- '''
821
- FROM {__FILE_LOCATION__}
822
- TEMPLATE """{{- if .Messages }}
823
- {{- if or .System .Tools }}<|im_start|>system
824
- {{- if .System }}
825
- {{ .System }}
826
- {{- end }}
827
- {{- if .Tools }}
828
-
829
- # Tools
830
-
831
- You may call one or more functions to assist with the user query.
832
-
833
- You are provided with function signatures within <tools></tools> XML tags:
834
- <tools>
835
- {{- range .Tools }}
836
- {"type": "function", "function": {{ .Function }}}
837
- {{- end }}
838
- </tools>
839
-
840
- For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
841
- <tool_call>
842
- {"name": <function-name>, "arguments": <args-json-object>}
843
- </tool_call>
844
- {{- end }}<|im_end|>
845
- {{ end }}
846
- {{- range $i, $_ := .Messages }}
847
- {{- $last := eq (len (slice $.Messages $i)) 1 -}}
848
- {{- if eq .Role "user" }}<|im_start|>user
849
- {{ .Content }}<|im_end|>
850
- {{ else if eq .Role "assistant" }}<|im_start|>assistant
851
- {{ if .Content }}{{ .Content }}
852
- {{- else if .ToolCalls }}<tool_call>
853
- {{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
854
- {{ end }}</tool_call>
855
- {{- end }}{{ if not $last }}<|im_end|>
856
- {{ end }}
857
- {{- else if eq .Role "tool" }}<|im_start|>user
858
- <tool_response>
859
- {{ .Content }}
860
- </tool_response><|im_end|>
861
- {{ end }}
862
- {{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
863
- {{ end }}
864
- {{- end }}
865
- {{- else }}
866
- {{- if .System }}<|im_start|>system
867
- {{ .System }}<|im_end|>
868
- {{ end }}{{ if .Prompt }}<|im_start|>user
869
- {{ .Prompt }}<|im_end|>
870
- {{ end }}<|im_start|>assistant
871
- {{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}"""
872
- PARAMETER stop "<|im_end|>"
873
- PARAMETER stop "<|endoftext|>"
874
- PARAMETER temperature 1.5
875
- PARAMETER min_p 0.1
876
- '''
877
-
878
- qwen25_template_eos_token = "eos_token"
879
- qwen25_default_system_message = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
880
- CHAT_TEMPLATES["qwen-2.5"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
881
- DEFAULT_SYSTEM_MESSAGE["qwen-2.5"] = qwen25_default_system_message # No system message in Qwen 2.5
882
-
883
- CHAT_TEMPLATES["qwen-25"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
884
- DEFAULT_SYSTEM_MESSAGE["qwen-25"] = qwen25_default_system_message # No system message in Qwen 2.5
885
-
886
- CHAT_TEMPLATES["qwen25"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
887
- DEFAULT_SYSTEM_MESSAGE["qwen25"] = qwen25_default_system_message # No system message in Qwen 2.5
888
-
889
- CHAT_TEMPLATES["qwen2.5"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
890
- DEFAULT_SYSTEM_MESSAGE["qwen2.5"] = qwen25_default_system_message # No system message in Qwen 2.5
891
- pass
892
-
893
- def _change_system_message(template: str, type_chat_template: str, system_message: str = None):
894
- system_message_pattern = r"\{system_message\}"
895
-
896
- # For predefined templates, check if default system message exists
897
- default_system_message = DEFAULT_SYSTEM_MESSAGE.get(f"{type_chat_template}", None)
898
- if default_system_message is None:
899
- if system_message is not None:
900
- logger.warning_once(
901
- f"Unsloth: You tried to change the system message for {type_chat_template}, "
902
- "but it doesn't have a default system message. "
903
- "You need to manually add the system message in your data."
904
- )
905
- return template, system_message
906
- pass
907
-
908
- # For custom templates
909
- if type_chat_template is None:
910
- has_placeholder = re.search(system_message_pattern, template) is not None
911
-
912
- if has_placeholder:
913
- if system_message is None:
914
- raise ValueError("Unsloth: You need to provide a system message for custom templates.")
915
- new_template = re.sub(system_message_pattern, system_message, template)
916
- return new_template, system_message
917
-
918
- return template, system_message
919
- pass
920
-
921
- # For predefined templates with default system message
922
- message_to_use = system_message if system_message is not None else default_system_message
923
- new_template = re.sub(system_message_pattern, message_to_use, template)
924
-
925
- return new_template, message_to_use
926
- pass
927
-
928
-
929
- def get_chat_template(
930
- tokenizer,
931
- chat_template = "chatml",
932
- mapping = {"role" : "role", "content" : "content", "user" : "user", "assistant" : "assistant"},
933
- map_eos_token = True,
934
- system_message = None,
935
- ):
936
- assert(type(map_eos_token) is bool)
937
- old_tokenizer = tokenizer
938
-
939
- IS_GEMMA = False
940
- if tokenizer.__class__.__name__.startswith("Gemma"):
941
- if chat_template == "chatml": chat_template = "gemma_chatml"
942
- IS_GEMMA = True
943
- pass
944
-
945
- # We add a check for Llama-3
946
- # if chat_template == "llama-3":
947
- # tokenizer._using_llama3_template = True
948
- # else:
949
- # llama3_tokens = set(["<|end_header_id|>", "<|eot_id|>", "<|start_header_id|>"])
950
- # check_llama3_tokens = llama3_tokens & set(str(x) for x in tokenizer.added_tokens_decoder.values())
951
- # if len(check_llama3_tokens) == len(llama3_tokens):
952
- # tokenizer._using_llama3_template = True
953
- # pass
954
- # pass
955
-
956
- # We first check if the tokenizer is a fast one. If not, we cannot convert this!
957
- is_fast_tokenizer = getattr(tokenizer, "is_fast", False)
958
- old_padding_side = tokenizer.padding_side
959
-
960
- same_padding_token = False
961
- type_chat_template = None
962
-
963
- if type(chat_template) in (list, tuple,):
964
- # For changing system message later
965
- # Since it's not supported yet, we will raise an error first!
966
- type_chat_template = chat_template[0].lower()
967
- chat_template, stop_word = chat_template
968
- assert(type(chat_template) is str)
969
- assert(type(stop_word) is str)
970
- ollama_modelfile = None
971
-
972
- elif type(chat_template) is str:
973
- # For changing system message later
974
- type_chat_template = chat_template.lower()
975
-
976
- chat_template, stop_word, yes_map_eos_token, ollama_modelfile = CHAT_TEMPLATES[chat_template]
977
-
978
- # Check mapping to eos_token
979
- if not map_eos_token and yes_map_eos_token: map_eos_token = True
980
- if not yes_map_eos_token and map_eos_token: map_eos_token = False
981
-
982
- if type(stop_word) in (list, tuple,):
983
- token_mapping, stop_word = stop_word
984
- assert(type(token_mapping) is dict)
985
- else:
986
- token_mapping = None
987
-
988
- assert(type(stop_word) is str)
989
-
990
- # Check fast tokenizer
991
- if not is_fast_tokenizer:
992
- print(
993
- "Unsloth: Not a fast tokenizer, so can't process it as of yet :(\n"\
994
- "Please log a Github issue if you want this as a new feature!\n"\
995
- "Your chat template will still work, but it won't add or edit tokens."
996
- )
997
-
998
- elif token_mapping is not None:
999
- # token_mapping = {"<start_of_turn>" : "<|im_start|>", "<end_of_turn>" : "<|im_end|>"}
1000
- # For Gemma :)
1001
-
1002
- string_vocab = tokenizer._tokenizer.to_str()
1003
-
1004
- skipped = 0
1005
- for old_token, new_token in token_mapping.items():
1006
- old_count = string_vocab.count(f'"{old_token}"')
1007
- new_count = string_vocab.count(f'"{new_token}"')
1008
- if new_count != 0:
1009
- print(f"{new_token} is already a token. Skipping.")
1010
- skipped += 1
1011
- elif old_count == 0:
1012
- raise RuntimeError(f"{old_token} was not part of the tokenizer!")
1013
- else:
1014
- string_vocab = string_vocab.replace(f'"{old_token}"', f'"{new_token}"')
1015
- pass
1016
- pass
1017
-
1018
- if map_eos_token and (not stop_word in token_mapping.values()):
1019
- # Do not map both 107 (<end_of_turn>) and 1 (<eos>) to <|im_end|>, since that would reduce the vocab size by 1
1020
- logger.warning_once(f"Unsloth: Will map {stop_word} to EOS = {tokenizer.eos_token}.")
1021
- string_vocab = string_vocab.replace(tokenizer.eos_token, stop_word)
1022
- pass
1023
-
1024
- if skipped != len(token_mapping):
1025
- new_tokenizer = tokenizer._tokenizer.from_str(string_vocab)
1026
-
1027
- # Careful on pad_token
1028
- old_pad_token = tokenizer.pad_token
1029
- if old_pad_token == tokenizer.eos_token:
1030
- old_pad_token = stop_word
1031
- same_padding_token = True
1032
- pass
1033
-
1034
- if map_eos_token:
1035
- new_tokenizer = tokenizer.__class__(
1036
- tokenizer_object = new_tokenizer,
1037
- eos_token = stop_word,
1038
- pad_token = old_pad_token,
1039
- )
1040
- else:
1041
- new_tokenizer = tokenizer.__class__(
1042
- tokenizer_object = new_tokenizer,
1043
- pad_token = old_pad_token,
1044
- )
1045
- pass
1046
-
1047
- # Must fix the sentence piece tokenizer since there's no tokenizer.model file!
1048
- tokenizer = fix_sentencepiece_tokenizer(tokenizer, new_tokenizer, token_mapping,)
1049
- else:
1050
- pass
1051
-
1052
- elif map_eos_token and (stop_word != "eos_token"):
1053
- logger.warning_once(f"Unsloth: Will map {stop_word} to EOS = {tokenizer.eos_token}.")
1054
-
1055
- # Replaces the old EOS token with a new one.
1056
- # Useful for ChatML <|im_end|> for example.
1057
- # Usually we train 2 more tokens <|im_start|> and <|im_end|>
1058
- # But training the lm_head and embeddings is slow!
1059
- # This is a HACK!
1060
- # Idea from https://huggingface.co/cognitivecomputations/dolphin-2.6-mistral-7b-dpo-laser
1061
-
1062
- old_bos_token = getattr(tokenizer, "bos_token", None)
1063
- old_eos_token = getattr(tokenizer, "eos_token", None)
1064
- old_pad_token = getattr(tokenizer, "pad_token", None)
1065
- old_unk_token = getattr(tokenizer, "unk_token", None)
1066
-
1067
- string_vocab = tokenizer._tokenizer.to_str()
1068
- # First check if new stop_word is in the tokenizer
1069
- if stop_word in string_vocab:
1070
- # We shall swap them around
1071
- temporary_stop_token = "<|:__TEMP//STOP//TOKEN__:|>"
1072
- string_vocab = string_vocab.replace(old_eos_token, temporary_stop_token)
1073
- string_vocab = string_vocab.replace(stop_word, old_eos_token)
1074
- string_vocab = string_vocab.replace(temporary_stop_token, stop_word)
1075
- else:
1076
- string_vocab = string_vocab.replace(old_eos_token, stop_word)
1077
- pass
1078
- new_tokenizer = tokenizer._tokenizer.from_str(string_vocab)
1079
-
1080
- # Careful on pad_token
1081
- if old_pad_token == old_eos_token:
1082
- old_pad_token = stop_word
1083
- same_padding_token = True
1084
- pass
1085
-
1086
- new_tokenizer = tokenizer.__class__(
1087
- tokenizer_object = new_tokenizer,
1088
- bos_token = old_bos_token,
1089
- eos_token = stop_word,
1090
- unk_token = old_unk_token,
1091
- pad_token = old_pad_token,
1092
- )
1093
-
1094
- # Must fix the sentence piece tokenizer since there's no tokenizer.model file!
1095
- token_mapping = { old_eos_token : stop_word, }
1096
- tokenizer = fix_sentencepiece_tokenizer(tokenizer, new_tokenizer, token_mapping,)
1097
- pass
1098
-
1099
- else:
1100
- raise TypeError(
1101
- f"Unsloth: `chat_template` must be a tuple of (your_template, eos_token,) or one of\n"\
1102
- f"{CHAT_TEMPLATES.keys()}"
1103
- )
1104
- pass
1105
-
1106
- # Careful on Gemma
1107
- # bos_token is a must or else losses become too high
1108
- if IS_GEMMA and not chat_template.startswith(("{{ bos_token }}", "{{- bos_token }}")):
1109
- chat_template = "{{ bos_token }}" + chat_template
1110
- pass
1111
-
1112
- # For ShareGPT role -> from and content -> value
1113
- new_chat_template = chat_template\
1114
- .replace("'role'", "'" + mapping["role"] + "'")\
1115
- .replace("'content'", "'" + mapping["content"] + "'")\
1116
- .replace("'user'", "'" + mapping["user"] + "'")\
1117
- .replace("'assistant'", "'" + mapping["assistant"] + "'")
1118
-
1119
- _, tokenizer = patch_tokenizer(model = None, tokenizer = tokenizer)
1120
- tokenizer.padding_side = old_padding_side
1121
-
1122
- # If not normal HF, we add a check to make old templates work
1123
- if mapping != {"role" : "role", "content" : "content", "user" : "user", "assistant" : "assistant"}:
1124
- chat_template = \
1125
- "{% if 'role' in messages[0] %}" + \
1126
- chat_template + \
1127
- "{% else %}" + \
1128
- new_chat_template + \
1129
- "{% endif %}"
1130
- else:
1131
- chat_template = new_chat_template
1132
- pass
1133
-
1134
- chat_template, system_message = _change_system_message(chat_template, type_chat_template, system_message)
1135
-
1136
- tokenizer.chat_template = chat_template
1137
-
1138
- # Also fix up other tokens
1139
- old_pad_token = getattr(old_tokenizer, "pad_token", None)
1140
- old_bos_token = getattr(old_tokenizer, "bos_token", None)
1141
- old_unk_token = getattr(old_tokenizer, "unk_token", None)
1142
- new_pad_token = getattr(tokenizer, "pad_token", None)
1143
- new_bos_token = getattr(tokenizer, "bos_token", None)
1144
- new_unk_token = getattr(tokenizer, "unk_token", None)
1145
- if old_bos_token != new_bos_token: tokenizer.bos_token = old_bos_token
1146
- if old_unk_token != new_unk_token: tokenizer.unk_token = old_unk_token
1147
- if not same_padding_token:
1148
- if old_pad_token != new_pad_token: tokenizer.pad_token = old_pad_token
1149
- pass
1150
-
1151
- # stopping_criteria = create_stopping_criteria(tokenizer, stop_word)
1152
-
1153
- # Patch saving functions
1154
- tokenizer = patch_saving_functions(tokenizer)
1155
-
1156
- # Add Ollama
1157
- tokenizer._ollama_modelfile = ollama_modelfile
1158
- tokenizer._system_message = system_message
1159
- return tokenizer#, stopping_criteria
1160
- pass
1161
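A minimal usage sketch for get_chat_template; the model name is illustrative, and any key of CHAT_TEMPLATES (or a (template, eos_token) tuple) can be passed:

from transformers import AutoTokenizer
from unsloth.chat_templates import get_chat_template

tokenizer = AutoTokenizer.from_pretrained("unsloth/mistral-7b")  # illustrative model
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "chatml",
    map_eos_token = True,   # remap </s> -> <|im_end|> instead of training a new token
)
text = tokenizer.apply_chat_template(
    [{"role": "user", "content": "What is 2+2?"}],
    tokenize = False, add_generation_prompt = True,
)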
-
1162
-
1163
- def remove_special_tokens(tokenizer, prompt):
1164
- # Removes double BOS token
1165
- if prompt.startswith(tokenizer.bos_token):
1166
- prompt = prompt[len(tokenizer.bos_token):]
1167
- pass
1168
- return prompt
1169
- pass
1170
-
1171
-
1172
- def _parse_combined_prompt(combined_prompt, dataset):
1173
- # Find {...}
1174
- possible_columns = re.findall(r"\{(.+?)\}", combined_prompt)
1175
- dataset_columns = set(dataset.column_names)
1176
- for column in possible_columns:
1177
- if column not in dataset_columns:
1178
- raise KeyError(
1179
- f"Unsloth: Your prompt includes '{column}' but this does not exist in the dataset. "\
1180
- f"Only allowed columns are {list(dataset_columns)}"
1181
- )
1182
- pass
1183
- pass
1184
-
1185
- # Find [[...]]
1186
- optional_prompts = list(re.finditer(r"\[\[.+?\]\]", combined_prompt, flags = re.DOTALL | re.MULTILINE))
1187
- optional_prompts = [(x.span(), x.group(0)) for x in optional_prompts]
1188
-
1189
- final_optional_prompts = []
1190
- if len(optional_prompts) != 0:
1191
- # Add left
1192
- left = optional_prompts[0]
1193
- l = left[0][0]
1194
- if l != 0: final_optional_prompts.append(combined_prompt[:l])
1195
-
1196
- # Add in between
1197
- for left, right in zip(optional_prompts[:-1], optional_prompts[1:]):
1198
- l, r = left[0][-1], right[0][0]
1199
- final_optional_prompts.append(left)
1200
- if l != r: final_optional_prompts.append(combined_prompt[l : r])
1201
- pass
1202
- final_optional_prompts.append(optional_prompts[-1])
1203
-
1204
- # Add right
1205
- right = optional_prompts[-1]
1206
- r = right[0][1]
1207
- if r != len(combined_prompt): final_optional_prompts.append(combined_prompt[r:])
1208
- else:
1209
- # Just add in the entire string
1210
- final_optional_prompts.append(combined_prompt)
1211
- pass
1212
-
1213
- check_combined = "".join(x if type(x) is str else x[1] for x in final_optional_prompts)
1214
- assert(combined_prompt == check_combined)
1215
-
1216
- return possible_columns, final_optional_prompts
1217
- pass
1218
-
1219
-
1220
- def _create_formatter(possible_columns, final_optional_prompts, user_column_name):
1221
- # Start final prompt!
1222
- function = ["def __combined_prompt_processor__(examples):"]
1223
- columns = list(set(possible_columns))
1224
- for column in columns:
1225
- function.append(f"{' '*4}{column}__ = examples['{column}']")
1226
- function.append(f"{' '*4}texts = []")
1227
- function.append(f"{' '*4}for ({', '.join(columns)}) in zip({', '.join(f'{x}__' for x in columns)}):")
1228
-
1229
- # Add optional tags as well!
1230
- final_prompt = ""
1231
- formatter = []
1232
-
1233
- for j, optional_prompt in enumerate(final_optional_prompts):
1234
- if type(optional_prompt) is str:
1235
- columns = re.findall(r"\{(.+?)\}", optional_prompt)
1236
- formatter += columns
1237
- # Must escape \n \r
1238
- final_prompt += optional_prompt.encode("unicode-escape").decode("utf-8").replace("'", "\\'").replace('"', '\\"')
1239
- else:
1240
- where, prompt = optional_prompt
1241
- # Strip [[...]]
1242
- # Must escape \n \r
1243
- prompt = prompt[2:-2].encode("unicode-escape").decode("utf-8").replace("'", "\\'").replace('"', '\\"')
1244
- columns = re.findall(r"\{(.+?)\}", prompt)
1245
- x = f"__optional_{j}__"
1246
- prompt = f"{' '*8}{x} = '{prompt}'.format({', '.join(f'{x} = {x}' for x in columns)}) if {columns[0]} else ''"
1247
- function.append(prompt)
1248
- formatter.append(x)
1249
- final_prompt += "{" + x + "}"
1250
- pass
1251
- pass
1252
-
1253
- function.insert(1, f"{' '*4}__combined_prompt__ = '{final_prompt}'")
1254
- function.append(f"{' '*8}texts.append("\
1255
- f"__combined_prompt__.format({', '.join(f'{x} = {x}' for x in formatter)}))")
1256
- function.append(f"{' '*4}return " + "{ " + f"'{user_column_name}' : texts" + " }")
1257
- return "\n".join(function)
1258
- pass
1259
-
1260
-
1261
- def to_sharegpt(
1262
- dataset,
1263
- merged_prompt = "",
1264
- merged_column_name = "instruction",
1265
- output_column_name = "output",
1266
- remove_unused_columns = True,
1267
- conversation_extension = 1,
1268
- random_state = 3407,
1269
- ):
1270
- """
1271
- Converts a dataset to ShareGPT style.
1272
- ShareGPT requires only 1 input and 1 output field.
1273
- This means one has to merge multiple columns into 1 for 1 input field.
1274
- Use `conversation_extension` to increase the length of each conversation by randomly
1275
- selecting a few and packing them into 1.
1276
-
1277
- merged_prompt = "", Prompt to merge columns into 1 input
1278
- merged_column_name = "instruction", Final column name for the input field
1279
- output_column_name = "output", Final column name for the output field
1280
- remove_unused_columns = True,
1281
- conversation_extension = 1, Automatically combines `conversation_extension` convos into 1
1282
- random_state = 3407,
1283
- """
1284
- if "conversations" in dataset.column_names:
1285
- convo = dataset[0]["conversations"]
1286
- if type(convo) is list:
1287
- raise TypeError("Unsloth: Your dataset is probably already in ShareGPT format!")
1288
- pass
1289
- pass
1290
-
1291
- possible_columns, final_optional_prompts = _parse_combined_prompt(merged_prompt, dataset)
1292
- function = _create_formatter(possible_columns, final_optional_prompts, merged_column_name)
1293
- exec(function, globals())
1294
- dataset = dataset.map(__combined_prompt_processor__, batched = True, desc = "Merging columns")
1295
-
1296
- def __convert_to_sharegpt__(examples):
1297
- users = examples[merged_column_name]
1298
- assistants = examples[output_column_name]
1299
- texts = [
1300
- [
1301
- {"from" : "human", "value" : str(user) },
1302
- {"from" : "gpt", "value" : str(assistant)},
1303
- ] \
1304
- for user, assistant in zip(users, assistants)
1305
- ]
1306
- return { "conversations" : texts, }
1307
- pass
1308
-
1309
- dataset = dataset.map(
1310
- __convert_to_sharegpt__,
1311
- batched = True,
1312
- desc = "Converting to ShareGPT",
1313
- # Remove unused columns!
1314
- remove_columns = dataset.column_names if remove_unused_columns else None,
1315
- )
1316
-
1317
- # Randomly concatenate conversations to create a long stream!
1318
- from datasets import concatenate_datasets
1319
- n_extensions = max(conversation_extension-1, 0)
1320
- if n_extensions == 0: return dataset
1321
-
1322
- dataset = dataset.rename_columns({"conversations" : "conversations0"})
1323
- all_shuffled = [dataset]
1324
- for j in range(1, n_extensions+1):
1325
- shuffled = dataset.shuffle(seed = random_state+j).rename_columns({"conversations0" : f"conversations{j}"})
1326
- all_shuffled.append(shuffled)
1327
- pass
1328
- dataset = concatenate_datasets(all_shuffled, axis = 1)
1329
-
1330
- # Combine them into 1
1331
- function = "def __combine_conversations__(examples):\n"
1332
- n_extensions += 1
1333
- for j in range(n_extensions):
1334
- function += f"{' '*4}conversations{j}__ = examples['conversations{j}']\n"
1335
- function += f"{' '*4}convos = []\n"
1336
- function += f"{' '*4}for ({', '.join(f'conversations{j}' for j in range(n_extensions))}) "\
1337
- f"in zip({', '.join(f'conversations{j}__' for j in range(n_extensions))}):\n"
1338
- function += f"{' '*8}convos.append("\
1339
- f"{'+'.join(f'conversations{j}' for j in range(n_extensions))})\n"
1340
- function += f"{' '*4}return " + "{ " + "'conversations' : convos" + " }"
1341
-
1342
- # Map function
1343
- exec(function, globals())
1344
- dataset = dataset.map(
1345
- __combine_conversations__,
1346
- batched = True,
1347
- desc = "Extending conversations",
1348
- # Remove unused columns!
1349
- remove_columns = dataset.column_names if remove_unused_columns else None,
1350
- )
1351
- return dataset
1352
- pass
1353
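A sketch of merging an Alpaca-style dataset into single-turn ShareGPT conversations; the [[...]] section is dropped whenever its column is empty, and the dataset name is only illustrative:

from datasets import load_dataset
from unsloth.chat_templates import to_sharegpt

dataset = load_dataset("vicgalle/alpaca-gpt4", split = "train")  # illustrative dataset
dataset = to_sharegpt(
    dataset,
    merged_prompt = "{instruction}[[\nYour input is:\n{input}]]",
    output_column_name = "output",
    conversation_extension = 3,   # randomly packs 3 single-turn convos into one
)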
-
1354
-
1355
- def standardize_sharegpt(
1356
- dataset,
1357
- aliases_for_system = ["system",],
1358
- aliases_for_user = ["user", "human", "input",],
1359
- aliases_for_assistant = ["gpt", "assistant", "output",],
1360
- ):
1361
- """
1362
- Standardizes ShareGPT and other formats to user/assistant Hugging Face format.
1363
-
1364
- Get aliases for the system, user and assistant roles.
1365
- These shall map to "system", "user" and "assistant" respectively.
1366
-
1367
- aliases_for_system = ["system",],
1368
- aliases_for_user = ["user", "human", "input",],
1369
- aliases_for_assistant = ["gpt", "assistant", "output",],
1370
- """
1371
- import collections
1372
- import itertools
1373
-
1374
- convos = dataset[:10]["conversations"]
1375
- uniques = collections.defaultdict(list)
1376
- for convo in convos:
1377
- for message in convo:
1378
- for key, value in message.items():
1379
- uniques[key].append(value)
1380
- pass
1381
-
1382
- # Must be only 2 entries
1383
- assert(len(uniques.keys()) == 2)
1384
-
1385
- keys = list(uniques.keys())
1386
- length_first = len(set(uniques[keys[0]]))
1387
- length_second = len(set(uniques[keys[1]]))
1388
-
1389
- if length_first < length_second:
1390
- # Role is assigned to the first element
1391
- role_key = keys[0]
1392
- content_key = keys[1]
1393
- else:
1394
- role_key = keys[1]
1395
- content_key = keys[0]
1396
- pass
1397
-
1398
- # Check roles are in aliases
1399
- all_aliases = set(aliases_for_system + aliases_for_user + aliases_for_assistant)
1400
- roles = set(uniques[role_key])
1401
- leftover_aliases = (all_aliases | roles) - all_aliases
1402
- if len(leftover_aliases) != 0:
1403
- raise TypeError(
1404
- f"Unsloth: {list(leftover_aliases)} are not in aliases. Please update aliases."
1405
- )
1406
- pass
1407
-
1408
- # Mapping for aliases
1409
- aliases_mapping = {}
1410
- for x in aliases_for_system: aliases_mapping[x] = "system"
1411
- for x in aliases_for_user: aliases_mapping[x] = "user"
1412
- for x in aliases_for_assistant: aliases_mapping[x] = "assistant"
1413
-
1414
- def _standardize_dataset(examples):
1415
- convos = examples["conversations"]
1416
- all_convos = []
1417
- for convo in convos:
1418
- new_convo = [
1419
- { "role" : aliases_mapping[message[role_key]], "content" : message[content_key], }
1420
- for message in convo
1421
- ]
1422
- all_convos.append(new_convo)
1423
- pass
1424
- return { "conversations" : all_convos, }
1425
- pass
1426
-
1427
- return dataset.map(_standardize_dataset, batched = True, desc = "Standardizing format")
1428
- pass
1429
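Continuing the sketch above, normalising the role/content keys afterwards gives the plain Hugging Face schema:

from unsloth.chat_templates import standardize_sharegpt

dataset = standardize_sharegpt(dataset)
# Each conversation now looks like:
# [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]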
-
1430
-
1431
- def get_ollama_eos_tokens(tokenizer, extra_eos_tokens = []):
1432
- added_tokens_decoder = tokenizer.added_tokens_decoder.values()
1433
- added_tokens_decoder = [str(x) for x in added_tokens_decoder]
1434
-
1435
- # Remove added_tokens_decoder duplicates
1436
- added_tokens_decoder = list(set(added_tokens_decoder) - set(extra_eos_tokens))
1437
-
1438
- # Remove BOS
1439
- if getattr(tokenizer, "bos_token", None) is not None:
1440
- added_tokens_decoder = [x for x in added_tokens_decoder if x != tokenizer.bos_token]
1441
- pass
1442
-
1443
- repeated_tokens = []
1444
- # Join all vocab
1445
- joined_text = "\x01\x00".join(added_tokens_decoder)
1446
- for token in added_tokens_decoder:
1447
- n = len(token)
1448
- repeated_counts = joined_text.count(token[:n//2])
1449
- # Try to find a match longer than 1/2 of the token in the rest
1450
- # For eg <|reserved_special_token_0|>, <|reserved_special_token_1|>
1451
- if repeated_counts > 2:
1452
- for j in range(n//2+1, n):
1453
- if joined_text.count(token[:j]) < repeated_counts:
1454
- j -= 1
1455
- # Remove repeated tokens to reduce the search space
1456
- joined_text = joined_text.replace(token[:j], "")
1457
- repeated_tokens.append(token[:j])
1458
- break
1459
- pass
1460
- pass
1461
- pass
1462
-
1463
- # Remove duplicates
1464
- splitted = joined_text.split("\x01\x00")
1465
- final_eos_tokens = []
1466
- for old, new in zip(added_tokens_decoder, splitted):
1467
- if old == new: final_eos_tokens.append(old)
1468
- pass
1469
- final_eos_tokens += extra_eos_tokens
1470
- final_eos_tokens += repeated_tokens
1471
-
1472
- # Remove new lines, spaces and HTML tags
1473
- filtered_eos_tokens = []
1474
- for token in final_eos_tokens:
1475
- if token.count("\n") == len(token): continue
1476
- elif token.count("▁") == len(token): continue
1477
- elif token.startswith("<") and len(token) <= 2: continue
1478
- elif token.startswith("</") and len(token) == 3: continue
1479
- filtered_eos_tokens.append(token)
1480
- pass
1481
- return filtered_eos_tokens
1482
- pass
1483
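A small sketch of how the result is consumed further down: each surviving token becomes a PARAMETER stop line in the generated Ollama modelfile (tokenizer as in the get_chat_template sketch above):

from unsloth.chat_templates import get_ollama_eos_tokens

stop_tokens = get_ollama_eos_tokens(tokenizer, extra_eos_tokens = ["<|im_end|>"])
ollama_stop_lines = "\n".join(f'PARAMETER stop "{token}"' for token in stop_tokens)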
-
1484
-
1485
- def construct_chat_template( \
1486
-
1487
- tokenizer = None,
1488
-
1489
- chat_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
1490
-
1491
- {SYSTEM}<|eot_id|><|start_header_id|>user<|end_header_id|>
1492
-
1493
- {INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
1494
-
1495
- {OUTPUT}<|eot_id|><|start_header_id|>user<|end_header_id|>
1496
-
1497
- {INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
1498
-
1499
- {OUTPUT}<|eot_id|>""",
1500
-
1501
- default_system_message = \
1502
- "Below are some instructions that describe some tasks. Write responses that appropriately complete each request.",
1503
-
1504
- extra_eos_tokens = None,
1505
- ):
1506
- """
1507
- Creates an Ollama modelfile and a HF Jinja template from a custom
1508
- template. You must provide 2x examples of an input & output.
1509
- There is an optional system message as well.
1510
-
1511
- You must use {INPUT}, {OUTPUT} twice, and {SYSTEM} is optional.
1512
- """
1513
- # Strip only the left
1514
- chat_template = chat_template.lstrip()
1515
-
1516
- assert(tokenizer is not None)
1517
-
1518
- if extra_eos_tokens is None: extra_eos_tokens = []
1519
- elif type(extra_eos_tokens) is str: extra_eos_tokens = [extra_eos_tokens,]
1520
-
1521
- vocab = tokenizer.get_vocab()
1522
- for extra_eos in extra_eos_tokens:
1523
- assert(type(extra_eos) is str)
1524
- if extra_eos not in vocab:
1525
- raise ValueError(f"Unsloth: `{extra_eos}` is not a singular token in the tokenizer.")
1526
- pass
1527
- pass
1528
-
1529
- error_msg = \
1530
- "Unsloth: Your prompt template must have 2 examples showing the user input {INPUT} "\
1531
- "and the assistant output {OUTPUT}\n\n"\
1532
- "For example what is not allowed is just:\n"\
1533
- "### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n\n\n"\
1534
- "What is required is 2x of this:\n"\
1535
- "### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n"\
1536
- "### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n"
1537
-
1538
- # Check for EOS after {OUTPUT}
1539
- if tokenizer.eos_token is not None:
1540
- extra_eos_tokens.insert(0, tokenizer.eos_token)
1541
- if len(extra_eos_tokens) == 0:
1542
- raise RuntimeError(
1543
- "Unsloth: Your tokenizer does not have an EOS token? Please provide one via extra_eos_tokens!"
1544
- )
1545
- pass
1546
-
1547
- # Check tokenizer types
1548
- tokenizer_name = tokenizer.name_or_path.lower()
1549
- if tokenizer_name.startswith(("unsloth/llama-3-8b-instruct", "unsloth/llama-3-70b-instruct")):
1550
- # Add <|eot_id|>
1551
- extra_eos_tokens.append("<|eot_id|>")
1552
- elif ("<|eot_id|>" in extra_eos_tokens or "<|eot_id|>" in chat_template) and \
1553
- tokenizer_name.startswith(("unsloth/llama-3-8b", "unsloth/llama-3-70b")):
1554
- # Warn
1555
- logger.warning(
1556
- "Unsloth: Base llama-3 models did not train <|eot_id|>.\n"\
1557
- "Please use the instruct version or use <|end_of_text|>"
1558
- )
1559
- pass
1560
- extra_eos_tokens = list(set(extra_eos_tokens))
1561
-
1562
- count_eos = 0
1563
- for eos in extra_eos_tokens:
1564
- count_eos += len(re.findall(r"{OUTPUT}" + re.escape(eos), chat_template))
1565
- pass
1566
-
1567
- # This forces you to provide 2 input and outputs
1568
- final_combined_check = False
1569
-
1570
- try:
1571
- # O(N^2) search finding 2 repeated pieces of text
1572
- j = len(chat_template)-1
1573
- at_least_one = False
1574
- while j > 0:
1575
- found = chat_template.rfind(chat_template[j:], 0, j)
1576
- if found == -1: break
1577
- j -= 1
1578
- at_least_one = True
1579
- pass
1580
- if j > 0: j += 1
1581
- else: raise RuntimeError(error_msg)
1582
-
1583
- if not at_least_one: raise RuntimeError(error_msg)
1584
-
1585
- # Must be equivalent to left
1586
- final_combined_check = True
1587
-
1588
- # Repeated text
1589
- instruction_response = chat_template[j:]
1590
- if instruction_response.count("{INPUT}") != 1 or instruction_response.count("{OUTPUT}") != 1:
1591
- raise RuntimeError(error_msg)
1592
- pass
1593
-
1594
- # 1st System, Instruction, Output pair
1595
- left = chat_template[:j]
1596
- # 2nd Instruction, Output pair
1597
- right = chat_template[j:]
1598
-
1599
- final_combined_check = left if final_combined_check else chat_template
1600
-
1601
- # Isolate input
1602
- extra_eos_tokens_regex = "|".join(f"(?:{re.escape(x)})" for x in extra_eos_tokens)
1603
- if len(extra_eos_tokens_regex) != 0:
1604
- find_end = f"(?:{extra_eos_tokens_regex})?"
1605
- else:
1606
- find_end = ""
1607
- find_end = r"\{INPUT\}[\s\n]{0,}" + find_end
1608
- input_end = list(re.finditer(find_end, right))
1609
- assert(len(input_end) == 1)
1610
- input_end = input_end[0]
1611
- input_end = input_end.span(0)[1]
1612
- input_part = right[:input_end]
1613
-
1614
- # Isolate output
1615
- output_part = right[input_end:]
1616
-
1617
- # Isolate system
1618
- where_system = left.find(input_part)
1619
- system_part = left[:where_system if where_system != -1 else len(left)]
1620
-
1621
- # Check if the user provided a correct prompt
1622
- combined = system_part + input_part + output_part
1623
- if combined != final_combined_check:
1624
- combined_changed = combined .replace('\n', '\\n')
1625
- left_changed = final_combined_check.replace('\n', '\\n')
1626
- raise RuntimeError(
1627
- "Unsloth: The prompt template you provided isn't correct. You gave:\n"\
1628
- f"{combined_changed}\n\n"\
1629
- "But we require the following:\n"\
1630
- f"{left_changed}"
1631
- )
1632
- pass
1633
- except:
1634
- ending = chat_template[chat_template.find("{OUTPUT}") + len("{OUTPUT}"):]
1635
-
1636
- ending = re.escape(ending)
1637
- find_text = "{INPUT}" + ending + "(.+?{OUTPUT}" + ending + ")"
1638
- response_part = re.findall(find_text, chat_template, flags = re.DOTALL | re.MULTILINE)
1639
- response_part = response_part[0]
1640
-
1641
- for j in range(1, len(response_part)):
1642
- try_find = re.escape(response_part[:j])
1643
- try: found = next(re.finditer("(" + try_find + ").+?\{INPUT\}", chat_template, flags = re.DOTALL | re.MULTILINE))
1644
- except: break
1645
- pass
1646
- separator = found.group(1)
1647
-
1648
- response_start = chat_template.find(response_part)
1649
- start_instruction = chat_template[:response_start].rfind(separator)
1650
- if start_instruction == -1: start_instruction = 0
1651
- instruction_part = chat_template[start_instruction:response_start]
1652
-
1653
- combined = instruction_part + response_part
1654
- where = chat_template.find(combined)
1655
- system_part = chat_template[:where]
1656
-
1657
- system_part, input_part, output_part = system_part, instruction_part, response_part
1658
- pass
1659
-
1660
- if count_eos == 0:
1661
- logger.warning("Unsloth: We automatically added an EOS token to stop endless generations.")
1662
- eos = extra_eos_tokens[0]
1663
- output_part = output_part + eos
1664
- pass
1665
-
1666
- # Ollama modelfile parts
1667
-
1668
- # Check bos_token is in system prompt
1669
- ollama_system = system_part
1670
- has_bos_token = False
1671
- always_bos_token = False
1672
- if tokenizer("A").input_ids[0] == getattr(tokenizer, "bos_token_id", None):
1673
- always_bos_token = True
1674
- if ollama_system.startswith(tokenizer.bos_token):
1675
- has_bos_token = True
1676
- ollama_system = ollama_system[len(tokenizer.bos_token):]
1677
- pass
1678
- pass
1679
- # Check system
1680
- if "{SYSTEM}" in ollama_system:
1681
- system_modelfile = "{{ if .System }}" + ollama_system.replace("{SYSTEM}", "{{ .System }}") + "{{ end }}"
1682
- else:
1683
- system_modelfile = ollama_system
1684
- pass
1685
- input_modelfile = "{{ if .Prompt }}" + input_part .replace("{INPUT}", "{{ .Prompt }}") + "{{ end }}"
1686
- output_modelfile = output_part.replace("{OUTPUT}", "{{ .Response }}")
1687
-
1688
- # Ollama EOS
1689
- ollama_eos = get_ollama_eos_tokens(tokenizer, extra_eos_tokens)
1690
- ollama_eos = '\n'.join(f'PARAMETER stop "{eos}"' for eos in ollama_eos)
1691
-
1692
- # Add temperature and min_p to counteract gibberish
1693
- ollama_eos += "\nPARAMETER temperature 1.5\nPARAMETER min_p 0.1"
1694
-
1695
- # Ollama modelfile
1696
- part = '"""'
1697
- modelfile = 'FROM {__FILE_LOCATION__}\n\n'\
1698
- 'TEMPLATE ' + part + system_modelfile + input_modelfile + output_modelfile + \
1699
- part + '\n\n' + ollama_eos
1700
-
1701
- # HF Jinja Chat template
1702
- def process(part, which, content = "message['content']"):
1703
- if part.endswith(which):
1704
- part = "'" + part[:part.find(which)] + f"' + {content}"
1705
- elif part.startswith(which):
1706
- part = f"{content} + '" + part[part.find(which):] + "'"
1707
- else:
1708
- part = "'" + part.replace(which, f"' + {content} + '") + "'"
1709
- if part.startswith("'' + "): part = part[5:]
1710
- return part
1711
- pass
1712
- input_jinja = process(input_part, "{INPUT}")
1713
- output_jinja = process(output_part, "{OUTPUT}")
1714
- pass
1715
-
1716
- jinja_template = \
1717
- "{% for message in loop_messages %}"\
1718
- "{% if message['role'] == 'user' %}"\
1719
- "{{ " + input_jinja + " }}"\
1720
- "{% elif message['role'] == 'assistant' %}"\
1721
- "{{ " + output_jinja + " }}"\
1722
- "{% else %}"\
1723
- "{{ raise_exception('Only user and assistant roles are supported!') }}"\
1724
- "{% endif %}"\
1725
- "{% endfor %}"\
1726
- "{% if add_generation_prompt %}"\
1727
- "{{ '" + output_part[:output_part.find("{OUTPUT}")] + "' }}"\
1728
- "{% endif %}"
1729
- pass
1730
-
1731
- # Now add system prompt to jinja
1732
- if len(system_part) != 0:
1733
- partial_system = process(system_part, "{SYSTEM}", "messages[0]['content']")
1734
- partial_system = partial_system.replace("{SYSTEM}", "")
1735
-
1736
- if "{SYSTEM}" in partial_system:
1737
- if default_system_message is None:
1738
- raise RuntimeError("Unsloth: Please specify a default system message!")
1739
- pass
1740
-
1741
- # Separate the BOS
1742
- if has_bos_token:
1743
- partial_system = partial_system.replace(tokenizer.bos_token, "", 1)
1744
- system_part = system_part .replace(tokenizer.bos_token, "", 1)
1745
- pass
1746
-
1747
- partial_system = \
1748
- "{% if messages[0]['role'] == 'system' %}"\
1749
- "{{ " + partial_system + " }}"\
1750
- "{% set loop_messages = messages[1:] %}"
1751
- if default_system_message is not None:
1752
- full_system = system_part.replace("{SYSTEM}", default_system_message)
1753
- if "{SYSTEM}" in system_part:
1754
- modelfile += '\nSYSTEM "' + default_system_message + '"'
1755
- pass
1756
- partial_system += "{% else %}"\
1757
- "{{ '" + full_system + "' }}"\
1758
- "{% set loop_messages = messages %}"\
1759
- "{% endif %}"
1760
- else:
1761
- partial_system += "{% endif %}"
1762
- pass
1763
-
1764
- jinja_template = partial_system + jinja_template
1765
-
1766
- if has_bos_token:
1767
- jinja_template = "{{ bos_token }}" + jinja_template
1768
- pass
1769
-
1770
- # Fix missing loop_messages
1771
- if "{% set loop_messages = messages %}" not in jinja_template:
1772
- jinja_template = jinja_template.replace(
1773
- "{% for message in loop_messages %}",
1774
- "{% for message in messages %}",
1775
- 1, # Only replace the first one
1776
- )
1777
- pass
1778
-
1779
- # Check if system part is the same!
1780
- jinja_template = re.sub(
1781
- r"\{\% if messages\[0\]\['role'\] \=\= 'system' \%\}\{\{ '(.+?)' \}\}"\
1782
- r"\{\% set loop\_messages \= messages\[1\:\] \%\}"\
1783
- r"\{\% else \%\}\{\{ '\1' \}\}\{\% set loop\_messages \= messages \%\}\{\% endif \%\}"\
1784
- r"\{\% for message in loop\_messages \%\}",
1785
- r"{{ '\1' }}{% for message in messages %}",
1786
- jinja_template, flags = re.MULTILINE | re.DOTALL,
1787
- )
1788
-
1789
- # Check the Jinja template for bos_token
1790
- if always_bos_token:
1791
- if not jinja_template.startswith(("{{ bos_token }}", "{{- bos_token }}")):
1792
- jinja_template = "{{ bos_token }}" + jinja_template
1793
- pass
1794
-
1795
- # Get instruction and output parts for train_on_inputs = False
1796
- input_part = input_part [:input_part .find("{INPUT}")]
1797
- output_part = output_part[:output_part.find("{OUTPUT}")]
1798
- return modelfile, jinja_template, input_part, output_part
1799
- pass
1800
-
1801
-
1802
- def test_construct_chat_template():
1803
- token = "hf_"
1804
- from transformers import AutoTokenizer
1805
- tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", token = token)
1806
-
1807
- chat_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
1808
-
1809
- {SYSTEM}<|eot_id|><|start_header_id|>user<|end_header_id|>
1810
-
1811
- {INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
1812
-
1813
- {OUTPUT}<|eot_id|><|start_header_id|>user<|end_header_id|>
1814
-
1815
- {INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
1816
-
1817
- {OUTPUT}<|eot_id|>"""
1818
-
1819
- default_system_message = \
1820
- "Below are some instructions that describe some tasks. Write responses that appropriately complete each request."
1821
-
1822
- extra_eos_tokens = None
1823
-
1824
- modelfile, jinja_template, _, _ = construct_chat_template(
1825
- tokenizer = tokenizer,
1826
- chat_template = chat_template,
1827
- extra_eos_tokens = extra_eos_tokens,
1828
- )
1829
-
1830
- messages = [
1831
- {"role": "system", "content": "You are an assistant"},
1832
- {"role": "user", "content": "What is 2+2?"},
1833
- {"role": "assistant", "content": "It's 4."},
1834
- {"role": "user", "content": "Ok!"},
1835
- {"role": "assistant", "content": "Anything else?"},
1836
- {"role": "user", "content": "What's 2x2?"},
1837
- ]
1838
- correct_output = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
1839
-
1840
- tokenizer.chat_template = jinja_template
1841
- new_output = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
1842
- assert(correct_output == new_output)
1843
- pass
1844
- pass
1845
-
1846
-
1847
- def apply_chat_template( \
1848
-
1849
- dataset,
1850
- tokenizer = None,
1851
-
1852
- chat_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
1853
-
1854
- {SYSTEM}<|eot_id|><|start_header_id|>user<|end_header_id|>
1855
-
1856
- {INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
1857
-
1858
- {OUTPUT}<|eot_id|><|start_header_id|>user<|end_header_id|>
1859
-
1860
- {INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
1861
-
1862
- {OUTPUT}<|eot_id|>""",
1863
-
1864
- default_system_message = \
1865
- "Below are some instructions that describe some tasks. Write responses that appropriately complete each request.",
1866
-
1867
- extra_eos_tokens = None,
1868
-
1869
- ):
1870
- """
1871
- Creates a Ollama modelfile and a HF Jinja template from a custom
1872
- template. You must provide 2x examples of an input & output.
1873
- There is an optional system message as well.
1874
-
1875
- You must use {INPUT}, {OUTPUT} twice, and {SYSTEM} is optional.
1876
- """
1877
- modelfile, jinja_template, input_part, output_part = construct_chat_template(
1878
- tokenizer = tokenizer,
1879
- chat_template = chat_template,
1880
- default_system_message = default_system_message,
1881
- extra_eos_tokens = extra_eos_tokens,
1882
- )
1883
- def formatting_prompts_func(examples):
1884
- convos = examples["conversations"]
1885
- texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
1886
- return { "text" : texts, }
1887
- pass
1888
-
1889
- tokenizer.chat_template = jinja_template
1890
- tokenizer._ollama_modelfile = modelfile
1891
- tokenizer._unsloth_input_part = input_part
1892
- tokenizer._unsloth_output_part = output_part
1893
-
1894
- return dataset.map(formatting_prompts_func, batched = True,)
1895
- pass
1896
-
1897
-
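The {SYSTEM}/{INPUT}/{OUTPUT} placeholders above describe one full example turn plus one repeated turn, which the generated Jinja template then expands over any number of messages. A minimal, dependency-free sketch of that expansion for the Llama-3 style markers used in the default template (illustrative only, not the library's implementation):

def render_llama3_style(messages, default_system = "Below are some instructions that describe some tasks. Write responses that appropriately complete each request."):
    # Each turn is wrapped in the same header / eot markers as the template above.
    def block(role, content):
        return f"<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>"

    # {SYSTEM} comes from the first message if present, otherwise the default.
    if messages and messages[0]["role"] == "system":
        system, turns = messages[0]["content"], messages[1:]
    else:
        system, turns = default_system, messages

    out = "<|begin_of_text|>" + block("system", system)
    for m in turns:  # alternating {INPUT} / {OUTPUT} turns
        out += block(m["role"], m["content"])
    # add_generation_prompt = True: open an assistant header for the model to fill in.
    return out + "<|start_header_id|>assistant<|end_header_id|>\n\n"

print(render_llama3_style([
    {"role": "user",      "content": "What is 2+2?"},
    {"role": "assistant", "content": "It's 4."},
    {"role": "user",      "content": "What's 2x2?"},
]))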
1898
- def create_stopping_criteria(tokenizer, stop_word = "eos_token"):
1899
- class StoppingCriteriaSub(StoppingCriteria):
1900
- __slots__ = "stop_token", "single_match", "length",
1901
-
1902
- def __init__(self, stops = "eos_token", device = "cuda", encounters = 1):
1903
- super().__init__()
1904
- if stops == "eos_token":
1905
- self.stop_token = torch.tensor(tokenizer.eos_token_id, device = device)
1906
- self.length = 1
1907
- else:
1908
- self.stop_token = tokenizer(["\n" + stops], add_special_tokens = False, return_tensors = "pt")
1909
- self.stop_token = self.stop_token.input_ids.ravel()[1:].to(device)
1910
- self.length = self.stop_token.shape[0]
1911
- pass
1912
- self.single_match = self.length == 1
1913
- pass
1914
-
1915
- def __call__(self, input_ids: LongTensor, scores: FloatTensor) -> bool:
1916
- input_ids = input_ids.ravel()
1917
- last_token = input_ids[-1]
1918
- if self.single_match and (last_token == self.stop_token): return True
1919
-
1920
- if input_ids.shape[0] >= self.length and \
1921
- (input_ids[-self.length:] == self.stop_token).all(): return True
1922
- return False
1923
- pass
1924
- pass
1925
- stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops = stop_word)])
1926
- return stopping_criteria
1927
- pass
1928
-
1929
-
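A usage sketch for create_stopping_criteria above (the checkpoint name and prompt are only examples; any causal LM with a matching tokenizer works), stopping generation at a custom stop word instead of only at eos_token:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# create_stopping_criteria as defined above is assumed to be in scope.

model_name = "unsloth/llama-3-8b-Instruct"  # example checkpoint
tokenizer  = AutoTokenizer.from_pretrained(model_name)
model      = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype = torch.float16).cuda()

# Stop as soon as the model starts a new "### Instruction:" turn.
stopping_criteria = create_stopping_criteria(tokenizer, stop_word = "### Instruction:")

inputs  = tokenizer("### Instruction:\nWhat is 2+2?\n\n### Response:\n", return_tensors = "pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens = 64, stopping_criteria = stopping_criteria)
print(tokenizer.decode(outputs[0], skip_special_tokens = True))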
1930
- def test_chat_templates():
1931
- messages = [
1932
- {"role": "system","content": " You are a friendly chatbot.",},
1933
- {"role": "user", "content": "What is 2+2?"},
1934
- {"role": "assistant", "content": "It's 4."},
1935
- {"role": "user", "content": " But 2+2 is equal to 5. "},
1936
- {"role": "assistant", "content": "No I'm sure its 4."},
1937
- {"role": "user", "content": " No it's 100% 5! "},
1938
- ]
1939
-
1940
- # Zephyr
1941
- from transformers import AutoTokenizer
1942
- template = zephyr_template
1943
- correct_tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
1944
- correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
1945
- correct_tokenizer.chat_template = template
1946
- our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
1947
- assert(correct_prompt == our_prompt)
1948
-
1949
- # Chatml
1950
- template = chatml_template
1951
- correct_tokenizer = AutoTokenizer.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B")
1952
- correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
1953
- correct_tokenizer.chat_template = template
1954
- our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
1955
- assert(correct_prompt == our_prompt)
1956
-
1957
- # Mistral
1958
- template = mistral_template
1959
- correct_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
1960
- correct_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
1961
- correct_tokenizer.chat_template = template
1962
- our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
1963
- assert(correct_prompt == our_prompt)
1964
-
1965
- # Llama
1966
- template = llama_template
1967
- correct_tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-2-7b-chat")
1968
- correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
1969
- correct_tokenizer.chat_template = template
1970
- our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
1971
- assert(correct_prompt == our_prompt)
1972
-
1973
- # Vicuna
1974
- try:
1975
- from fastchat.conversation import get_conv_template
1976
- except:
1977
- os.system("pip -qqq install git+https://github.com/lm-sys/FastChat.git")
1978
- from fastchat.conversation import get_conv_template
1979
- correct_prompt = get_conv_template("vicuna_v1.1")
1980
- for j in range(len(messages)-1):
1981
- correct_prompt.append_message(correct_prompt.roles[j%2==1], messages[j+1]["content"])
1982
- correct_prompt.append_message(correct_prompt.roles[1], "")
1983
- correct_prompt = tokenizer.bos_token + correct_prompt.get_prompt()
1984
-
1985
- template = vicuna_template
1986
- correct_tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
1987
- correct_tokenizer.chat_template = template
1988
- our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
1989
- assert(correct_prompt == our_prompt)
1990
-
1991
- try:
1992
- from fastchat.conversation import get_conv_template
1993
- except:
1994
- os.system("pip -qqq install git+https://github.com/lm-sys/FastChat.git")
1995
- from fastchat.conversation import get_conv_template
1996
- correct_prompt = get_conv_template("zero_shot")
1997
- for j in range(len(messages)-1):
1998
- correct_prompt.append_message(correct_prompt.roles[j%2==1], messages[j+1]["content"])
1999
- correct_prompt.append_message(correct_prompt.roles[1], "")
2000
- correct_prompt = tokenizer.bos_token + correct_prompt.get_prompt()
2001
-
2002
- template = vicuna_old_template
2003
- correct_tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
2004
- correct_tokenizer.chat_template = template
2005
- our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
2006
- # We add </s> ourselves
2007
- assert(correct_prompt == our_prompt.replace("</s>", ""))
2008
-
2009
- # Gemma
2010
- correct_tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-7b-it")
2011
- correct_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
2012
- correct_tokenizer.chat_template = gemma_template
2013
- our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
2014
- assert(our_prompt == correct_prompt)
2015
-
2016
- # Llama-3
2017
- template = llama3_template
2018
- correct_tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-3-8b-Instruct")
2019
- correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
2020
- correct_tokenizer.chat_template = template
2021
- our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
2022
- assert(correct_prompt == our_prompt)
2023
-
2024
- # Phi-3
2025
- template = phi3_template
2026
- correct_tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
2027
- correct_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
2028
- correct_tokenizer.chat_template = template
2029
- our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
2030
- assert(correct_prompt == our_prompt)
2031
- pass
2032
-
2033
-
2034
- def test_hf_gguf_equivalence(tokenizer, gguf_model = "./model-unsloth.F16.gguf"):
2035
- """
2036
- Carefully checks the output of GGUF's tokenization and HF.
2037
- Can catch all tokenization bugs.
2038
- """
2039
- import subprocess
2040
- import re
2041
- messages = [
2042
- {"role": "user", "content": "What is 2+2?"},
2043
- {"role": "assistant", "content": "It's 4."},
2044
- {"role": "user", "content": " But 2+2 is equal to 5. "},
2045
- {"role": "assistant", "content": "No I'm sure its 4."},
2046
- {"role": "user", "content": " No it's 100% 5! "},
2047
- ]
2048
-
2049
- prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
2050
-
2051
- ### Instruction:
2052
- {}
2053
-
2054
- ### Input:
2055
- {}
2056
-
2057
- ### Response:
2058
- {}""".format(
2059
- "Describe the city given eloquently.", # instruction
2060
- "The lost city of Atlantis.", # input
2061
- "", # output - leave this blank for generation!
2062
- )
2063
- prompts = [ prompt, ]
2064
-
2065
- if tokenizer.chat_template is not None:
2066
- prompt = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
2067
- prompt = prompt.replace("'", "") # Strip single quotes: the prompt is wrapped in '...' for the shell command below
2068
- prompt = remove_special_tokens(tokenizer, prompt)
2069
- prompts.append(prompt)
2070
- pass
2071
-
2072
- for prompt in prompts:
2073
- command = f"./llama.cpp/llama-cli -m {gguf_model} -n 0 --temp 0.0 --verbose-prompt "\
2074
- f"--check-tensors -p '{prompt}'"
2075
-
2076
- datas = []
2077
- with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT, bufsize = 1) as sp:
2078
- for line in sp.stdout:
2079
- datas.append(line.decode("utf-8", errors = "replace"))
2080
- pass
2081
- gguf_tokens = "".join(datas)
2082
-
2083
- # Now extract GGUF tokenization attempt
2084
- gguf_tokenized = re.findall("([\d]{1,}) \-\> \'([^\']{1,})\'", gguf_tokens, flags = re.MULTILINE)
2085
- gguf_tokenized = [(int(x[0]), x[1],) for x in gguf_tokenized]
2086
- input_ids = tokenizer(prompt).input_ids
2087
-
2088
- tokens = tokenizer.batch_decode(input_ids)
2089
- hf_tokenized = list(zip(input_ids, tokens))
2090
-
2091
- # Compare to Huggingface
2092
- for j, (hf_token, gguf_token) in enumerate(zip(hf_tokenized, gguf_tokenized)):
2093
- if (hf_token[0] != gguf_token[0]):
2094
- print("Failed GGUF != HF at", j)
2095
- print("HF =", hf_token)
2096
- print("GGUF =", gguf_token)
2097
- print(hf_tokenized)
2098
- print()
2099
- print(gguf_tokenized)
2100
- print()
2101
- raise RuntimeError("Failed comparing GGUF to HF.")
2102
- pass
2103
- pass
2104
- return True
2105
- pass
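The GGUF-versus-HF check above works by scraping the `id -> 'piece'` lines that llama-cli prints with --verbose-prompt and lining them up against the Hugging Face tokenization. A self-contained sketch of just that parsing and comparison step (the sample log text is fabricated, to show the format only):

import re

sample_log = """
128000 -> '<|begin_of_text|>'
 15339 -> 'hello'
  1917 -> ' world'
"""

# Same regex as in test_hf_gguf_equivalence above.
gguf_tokenized = re.findall(r"([\d]{1,}) \-\> \'([^\']{1,})\'", sample_log, flags = re.MULTILINE)
gguf_tokenized = [(int(i), s) for i, s in gguf_tokenized]

# Stand-in for tokenizer(prompt).input_ids zipped with tokenizer.batch_decode(input_ids).
hf_tokenized = [(128000, "<|begin_of_text|>"), (15339, "hello"), (1917, " world")]

for j, (hf_token, gguf_token) in enumerate(zip(hf_tokenized, gguf_tokenized)):
    if hf_token[0] != gguf_token[0]:
        raise RuntimeError(f"Failed GGUF != HF at {j}: {hf_token} vs {gguf_token}")
print("Token ids match:", gguf_tokenized)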
unsloth-main/unsloth/kernels/__init__.py DELETED
@@ -1,65 +0,0 @@
1
- # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from .cross_entropy_loss import (
16
- fast_cross_entropy_loss,
17
- post_patch_loss_function,
18
- patch_loss_functions,
19
- )
20
- from .rms_layernorm import (
21
- fast_rms_layernorm,
22
- patch_rms_layernorm,
23
- unpatch_rms_layernorm,
24
- )
25
- from .layernorm import (
26
- fast_layernorm,
27
- patch_layernorm,
28
- )
29
- from .rope_embedding import fast_rope_embedding, inplace_rope_embedding
30
- from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
31
- from .geglu import (
32
- geglu_exact_forward_kernel,
33
- geglu_exact_backward_kernel,
34
- geglu_approx_forward_kernel,
35
- geglu_approx_backward_kernel,
36
- )
37
- from .fast_lora import (
38
- get_lora_parameters,
39
- get_lora_parameters_bias,
40
- apply_lora_mlp_swiglu,
41
- apply_lora_mlp_geglu_exact,
42
- apply_lora_mlp_geglu_approx,
43
- apply_lora_qkv,
44
- apply_lora_o,
45
- fast_lora_forward,
46
- )
47
- from .utils import fast_dequantize, fast_gemv, QUANT_STATE, fast_linear_forward, matmul_lora
48
-
49
- from .flex_attention import (
50
- HAS_FLEX_ATTENTION,
51
- slow_attention_softcapping,
52
- slow_inference_attention_softcapping,
53
- create_flex_attention_causal_mask,
54
- create_flex_attention_sliding_window_mask,
55
- )
56
-
57
- import os
58
- if "UNSLOTH_ZOO_IS_PRESENT" not in os.environ:
59
- try:
60
- print("🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.")
61
- except:
62
- print("Unsloth: Will patch your computer to enable 2x faster free finetuning.")
63
- pass
64
- pass
65
- del os
unsloth-main/unsloth/kernels/cross_entropy_loss.py DELETED
@@ -1,405 +0,0 @@
1
- # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import triton
16
- import triton.language as tl
17
- import torch
18
- from .utils import calculate_settings, MAX_FUSED_SIZE, triton_tanh, triton_cast
19
- from transformers.models.llama.modeling_llama import logger
20
- from packaging.version import Version
21
-
22
- from unsloth_zoo.loss_utils import (
23
- patch_loss_functions as _patch_loss_functions,
24
- post_patch_loss_function,
25
- )
26
-
27
-
28
- @triton.heuristics({
29
- "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]),
30
- "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
31
- })
32
- @triton.jit
33
- def _cross_entropy_forward(
34
- logits_ptr ,
35
- logits_row_stride ,
36
- loss_ptr ,
37
- logsumexp_ptr ,
38
- labels_ptr ,
39
- VOCAB_SIZE ,
40
- BLOCK_SIZE : tl.constexpr,
41
- DO_SOFTCAPPING ,
42
- SOFTCAP ,
43
- DO_LOGIT_SCALING ,
44
- LOGIT_SCALE ,
45
- ):
46
- """
47
- Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
48
- Pi = exp(xi) / sum(exp(xi))
49
- CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
50
- = -y [ x - log[sum(exp(x))] ]
51
- = y * (log[sum(exp(x))] - x)
52
- If y == 0: CE_i = 0
53
- If y == 1: CE_i = logsumexp - x
54
-
55
- logsumexp is also stable
56
- Take y = log[sum(exp(x))]
57
- exp(y) = sum(exp(x))
58
- exp(y) = sum(exp(x - c)*exp(c)) Since e^(x-c)*e^c = e^x
59
- exp(y) = exp(c)*sum(exp(x - c))
60
- y = log(exp(c)*sum(exp(x - c)))
61
- y = c + log[sum(exp(x - c))]
62
- This means we can set c = max(x) to make sure
63
- exp(x - c) always is exp(x - max(x)).
64
- This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1.
65
- """
66
- row_idx = tl.program_id(0)
67
- logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
68
- loss_ptr += row_idx
69
- logsumexp_ptr += row_idx
70
- labels_ptr += row_idx
71
-
72
- col_offsets = tl.arange(0, BLOCK_SIZE)
73
- mask = col_offsets < VOCAB_SIZE
74
-
75
- label_idx = tl.load(labels_ptr).to(tl.int32)
76
- logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
77
-
78
- # Go logit scaling for Cohere: t * x
79
- if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
80
- # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
81
- if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
82
-
83
- c = tl.max(logits, 0)
84
- logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
85
-
86
- if label_idx != -100:
87
- x = tl.load(logits_ptr + label_idx).to(tl.float32)
88
- # Go logit scaling for Cohere: t * x
89
- if DO_LOGIT_SCALING: x = LOGIT_SCALE * x
90
- # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
91
- if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
92
- loss = logsumexp - x
93
- else:
94
- loss = 0.0
95
- tl.store(logsumexp_ptr, logsumexp)
96
- tl.store(loss_ptr, loss)
97
- pass
98
-
99
-
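A plain-torch check (CPU, no Triton) of the two identities the docstring above derives: shifting by c = max(x) leaves logsumexp unchanged, and the per-row loss logsumexp - x[label] equals standard cross entropy:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 32000)               # (rows, vocab)
labels = torch.randint(0, 32000, (4,))

# Stable form: y = c + log(sum(exp(x - c))) with c = max(x).
c = logits.max(dim = -1, keepdim = True).values
logsumexp = c.squeeze(-1) + torch.log(torch.exp(logits - c).sum(dim = -1))
assert torch.allclose(logsumexp, torch.logsumexp(logits, dim = -1), atol = 1e-4)

# CE_i = logsumexp - x[label], exactly what the kernel stores per row.
losses = logsumexp - logits[torch.arange(4), labels]
assert torch.allclose(losses, F.cross_entropy(logits, labels, reduction = "none"), atol = 1e-4)
print("stable logsumexp and CE identity hold")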
100
- @triton.heuristics({
101
- "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]),
102
- "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
103
- })
104
- @triton.jit
105
- def _chunked_cross_entropy_forward(
106
- logits_ptr ,
107
- logits_row_stride ,
108
- loss_ptr ,
109
- logsumexp_ptr ,
110
- labels_ptr ,
111
- VOCAB_SIZE ,
112
- N_CHUNKS ,
113
- BLOCK_SIZE : tl.constexpr,
114
- DO_SOFTCAPPING ,
115
- SOFTCAP ,
116
- DO_LOGIT_SCALING ,
117
- LOGIT_SCALE ,
118
- ):
119
- """
120
- 256K vocab divided in 4 chunks
121
-
122
- |-65536-| |-65536-| |-65536-| |-65536-|
123
- |-------| |-------| |-------| |-------|
124
- |-------| |-------| |-------| |-------|
125
-
126
- If y == 0: CE_i = 0
127
- If y == 1: CE_i = logsumexp - x
128
-
129
- Notice we can do logsumexp for each chunk and then
130
- logsumexp[chunk_sum(logsumexp)] == logsumexp
131
-
132
- chunk_sum = log[chunk_sum(logsumexp)]
133
- = log[exp(logsumexp(a)) + ... + exp(logsumexp(z))]
134
- = log[exp(log[sum(exp(a))]) + ... + exp(log[sum(exp(z))])]
135
- = log[sum(exp(a)) + ... + sum(exp(z))]
136
- = logsumexp(x)
137
-
138
- This means we can perform a logsumexp for each chunk, then do a
139
- final logsumexp reduction!
140
-
141
- i.e. do: logsumexp(chunked_logsumexp) - x
142
- """
143
- row_idx = tl.program_id(0)
144
- chunk_idx = tl.program_id(1)
145
- logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
146
- loss_ptr += row_idx
147
- logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx
148
- labels_ptr += row_idx
149
-
150
- col_offsets = chunk_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
151
- mask = col_offsets < VOCAB_SIZE
152
-
153
- label_idx = tl.load(labels_ptr).to(tl.int32)
154
- logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
155
-
156
- # Go logit scaling for Cohere: t * x
157
- if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
158
- # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
159
- if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
160
-
161
- c = tl.max(logits, 0)
162
- logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
163
-
164
- if chunk_idx == 0:
165
- # logsumexp(chunked_logsumexp) - x
166
- # Do the -x separately
167
- if label_idx != -100:
168
- x = tl.load(logits_ptr + label_idx).to(tl.float32)
169
- # Go logit scaling for Cohere: t * x
170
- if DO_LOGIT_SCALING: x = LOGIT_SCALE * x
171
- # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
172
- if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
173
- loss = -1.0 * x
174
- else:
175
- loss = 0.0
176
- tl.store(loss_ptr, loss)
177
- pass
178
- tl.store(logsumexp_ptr, logsumexp)
179
- pass
180
-
181
-
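The chunked kernel above leans on logsumexp composing over chunks: the logsumexp of the per-chunk logsumexps equals the logsumexp of the whole row. A quick plain-torch check of that reduction, mirroring the 4 x 65536 split in the comment:

import torch

torch.manual_seed(0)
rows, n_chunks, chunk = 2, 4, 65536           # 262144-entry "vocab" split into 4 chunks
x = torch.randn(rows, n_chunks * chunk)

chunked = torch.logsumexp(x.view(rows, n_chunks, chunk), dim = -1)   # (rows, n_chunks)
assert torch.allclose(torch.logsumexp(chunked, dim = -1),
                      torch.logsumexp(x, dim = -1), atol = 1e-4)
print("logsumexp(chunked logsumexps) == logsumexp(row)")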
182
- @triton.heuristics({
183
- "DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]),
184
- "DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
185
- })
186
- @triton.jit
187
- def _cross_entropy_backward(
188
- logits_ptr ,
189
- logits_row_stride ,
190
- dloss_ptr ,
191
- dloss_row_stride ,
192
- logsumexp_ptr ,
193
- labels_ptr ,
194
- VOCAB_SIZE ,
195
- BLOCK_SIZE : tl.constexpr,
196
- DO_SOFTCAPPING ,
197
- SOFTCAP ,
198
- DO_LOGIT_SCALING ,
199
- LOGIT_SCALE ,
200
- ):
201
- """
202
- CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
203
- dC/dx = d/dx (y * log[sum(exp(x))] - x * y)
204
-
205
- From https://en.wikipedia.org/wiki/LogSumExp
206
- d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)
207
-
208
- dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
209
- dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
210
- dC/dx = y * exp[x - logsumexp] - d/dx (x * y)
211
-
212
- If y == 0: dC/dx = 0
213
- If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
214
- If y == 1 and x != label: dC/dx = exp[x - logsumexp]
215
- """
216
- row_idx = tl.program_id(0)
217
- block_idx = tl.program_id(1)
218
-
219
- logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
220
- dloss_ptr += row_idx * dloss_row_stride
221
- col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
222
- mask = col_offsets < VOCAB_SIZE
223
- label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)
224
-
225
- if label_idx != -100:
226
- dloss = tl.load(dloss_ptr)
227
- else:
228
- dloss = 0.0
229
-
230
- x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
231
-
232
- # Do logit scaling for Cohere
233
- if DO_LOGIT_SCALING:
234
- # d/dx [s * x] = s
235
- x = x * LOGIT_SCALE
236
- pass
237
-
238
- # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
239
- partial = x
240
- if DO_SOFTCAPPING:
241
- # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
242
- partial = triton_tanh(x / SOFTCAP)
243
- x = SOFTCAP * partial
244
- pass
245
-
246
- logsumexp = tl.load(logsumexp_ptr + row_idx)
247
- y = tl.exp(x - logsumexp)
248
- y = tl.where(
249
- col_offsets == label_idx,
250
- y - 1.0, # exp(x - logsumexp) - 1
251
- y, # exp(x - logsumexp)
252
- )
253
-
254
- if DO_LOGIT_SCALING:
255
- # d/dx [s * x] = s
256
- y = y * LOGIT_SCALE
257
- pass
258
-
259
- if DO_SOFTCAPPING:
260
- # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
261
- y = y * (1.0 - partial*partial)
262
- pass
263
-
264
- # If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.
265
- tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)
266
- pass
267
-
268
-
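The backward kernel's y = exp(x - logsumexp), with 1 subtracted at the label column, is just softmax(x) minus a one-hot vector. A short autograd check of that closed form (plain torch, ignoring the softcapping and logit-scaling branches):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(3, 1000, requires_grad = True)
labels = torch.randint(0, 1000, (3,))

F.cross_entropy(logits, labels, reduction = "sum").backward()

manual = torch.softmax(logits.detach(), dim = -1)    # exp(x - logsumexp)
manual[torch.arange(3), labels] -= 1.0               # minus one-hot at the label

assert torch.allclose(logits.grad, manual, atol = 1e-5)
print("dCE/dlogits == softmax(x) - onehot(label)")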
269
- MAX_FUSED_SIZE = 65536 # 2**16
270
-
271
- class Fast_CrossEntropyLoss(torch.autograd.Function):
272
- @staticmethod
273
- def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling : float = 0):
274
- n_rows : int
275
- vocab_size : int
276
- n_rows, vocab_size = logits.shape
277
-
278
- div, mod = divmod(vocab_size, MAX_FUSED_SIZE)
279
- n_chunks : int = div + (mod != 0)
280
- losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
281
-
282
- DO_SOFTCAPPING : bool = bool(logit_softcapping != 0)
283
- DO_LOGIT_SCALING : bool = bool(logit_scaling != 0)
284
-
285
- BLOCK_SIZE : int
286
- num_warps : int
287
- if n_chunks == 1:
288
- # For small vocabs <= 65536 like Llama, Mistral
289
- BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
290
- logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
291
-
292
- _cross_entropy_forward[(n_rows,)](
293
- logits, logits.stride(0),
294
- losses,
295
- logsumexp,
296
- labels,
297
- VOCAB_SIZE = vocab_size,
298
- BLOCK_SIZE = BLOCK_SIZE,
299
- DO_SOFTCAPPING = DO_SOFTCAPPING,
300
- SOFTCAP = logit_softcapping,
301
- DO_LOGIT_SCALING = DO_LOGIT_SCALING,
302
- LOGIT_SCALE = logit_scaling,
303
- num_warps = num_warps,
304
- )
305
- else:
306
- # For large vocabs > 65536 like Gemma 256K
307
- logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = "cuda:0")
308
-
309
- _chunked_cross_entropy_forward[(n_rows, n_chunks,)](
310
- logits, logits.stride(0),
311
- losses,
312
- logsumexp,
313
- labels,
314
- VOCAB_SIZE = vocab_size,
315
- N_CHUNKS = n_chunks,
316
- BLOCK_SIZE = MAX_FUSED_SIZE,
317
- DO_SOFTCAPPING = DO_SOFTCAPPING,
318
- SOFTCAP = logit_softcapping,
319
- DO_LOGIT_SCALING = DO_LOGIT_SCALING,
320
- LOGIT_SCALE = logit_scaling,
321
- num_warps = 32,
322
- )
323
- # logsumexp(chunked_logsumexp) - x
324
- # Do the -x separately
325
- logsumexp = torch.logsumexp(logsumexp, dim = 1) # Row sum
326
- losses += logsumexp
327
- losses.masked_fill_(labels == -100, 0) # Don't forget to mask padding out!
328
- pass
329
-
330
- ctx.save_for_backward(logits, logsumexp, labels)
331
- ctx.DO_SOFTCAPPING = DO_SOFTCAPPING
332
- ctx.logit_softcapping = logit_softcapping
333
- ctx.DO_LOGIT_SCALING = DO_LOGIT_SCALING
334
- ctx.logit_scaling = logit_scaling
335
- return losses
336
- pass
337
-
338
-
339
- @staticmethod
340
- def backward(ctx, dlosses):
341
- logits, logsumexp, labels = ctx.saved_tensors
342
- n_rows : int
343
- vocab_size : int
344
- n_rows, vocab_size = logits.shape
345
-
346
- BLOCK_SIZE : int = 4096
347
- div : int
348
- mod : int
349
- div, mod = divmod(vocab_size, BLOCK_SIZE)
350
- n_blocks : int = div + (mod != 0)
351
-
352
- _cross_entropy_backward[(n_rows, n_blocks,)](
353
- logits, logits.stride(0),
354
- dlosses, dlosses.stride(0),
355
- logsumexp,
356
- labels,
357
- VOCAB_SIZE = vocab_size,
358
- BLOCK_SIZE = BLOCK_SIZE,
359
- DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,
360
- SOFTCAP = ctx.logit_softcapping,
361
- DO_LOGIT_SCALING = ctx.DO_LOGIT_SCALING,
362
- LOGIT_SCALE = ctx.logit_scaling,
363
- num_warps = 8,
364
- )
365
- return logits, None, None, None,
366
- pass
367
- pass
368
-
369
-
370
- def fast_cross_entropy_loss(
371
- logits,
372
- labels,
373
- logit_softcapping = 0,
374
- logit_scaling = 0,
375
- n_items = None,
376
- ):
377
- """
378
- Arguments:
379
- logits: (batch, seq_len, vocab_size)
380
- labels: (batch, seq_len,)
381
- Returns:
382
- losses: float
383
- """
384
- batch, seq_len, d = logits.shape
385
- assert(labels.shape == (batch, seq_len))
386
-
387
- loss = Fast_CrossEntropyLoss.apply(
388
- logits.view(batch*seq_len, d),
389
- labels.view(-1),
390
- logit_softcapping,
391
- logit_scaling,
392
- )
393
- if n_items is None:
394
- n_items = torch.count_nonzero(labels != -100)
395
- return loss.sum() / n_items
396
- pass
397
- if (Version(torch.__version__) < Version("2.4.0")) and \
398
- not hasattr(fast_cross_entropy_loss, "__wrapped__"):
399
- fast_cross_entropy_loss = torch._disable_dynamo(fast_cross_entropy_loss)
400
- pass
401
-
402
- # Patch CE Losses in transformers
403
- def patch_loss_functions(torch_compile = True):
404
- _patch_loss_functions(fast_cross_entropy_loss, torch_compile = torch_compile)
405
- pass
unsloth-main/unsloth/kernels/fast_lora.py DELETED
@@ -1,490 +0,0 @@
1
- # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import torch
16
- from .utils import (
17
- fast_dequantize,
18
- QUANT_STATE,
19
- get_lora_parameters,
20
- get_lora_parameters_bias,
21
- matmul_lora,
22
- torch_amp_custom_fwd,
23
- torch_amp_custom_bwd,
24
- )
25
-
26
-
27
- class LoRA_MLP(torch.autograd.Function):
28
- """
29
- ### LoRA weights
30
- G = G + Ag @ Bg
31
- U = U + Au @ Bu
32
- W = W + Aw @ Bw
33
-
34
- ### SwiGLU(X)
35
- e = X @ G
36
- f = e * sigmoid(e)
37
- g = X @ U
38
- h = f * g
39
- i = h @ W
40
-
41
- ### Backpropagation chain rule
42
- See our blog post for more details
43
-
44
- df = sigmoid(e) * (1 - f) + f
45
- dC/dW = h.T @ dY
46
- dC/dU = X.T @ (D @ W.T * f)
47
- dC/dG = X.T @ (D @ W.T * df * g)
48
-
49
- ### Down projection LoRA weights
50
- dC/dAw = dC/dW @ B.T
51
- dC/dBw = A.T @ dC/dW
52
- dC/dAw = h.T @ dY @ B.T
53
- dC/dBw = A.T @ h.T @ dY
54
-
55
- ### Up projection LoRA weights
56
- dC/dAu = X.T @ (D @ W.T * f) @ B.T
57
- dC/dBu = A.T @ X.T @ (D @ W.T * f)
58
-
59
- ### Gate projection LoRA weights
60
- dC/dAg = X.T @ (D @ W.T * df * g) @ B.T
61
- dC/dBg = A.T @ X.T @ (D @ W.T * df * g)
62
-
63
- Don't forget to see our blog post for more details!
64
- """
65
- @staticmethod
66
- @torch_amp_custom_fwd
67
- def forward(ctx, X : torch.Tensor,
68
- gateW, gateW_quant, gateA, gateB, gateS,
69
- upW, upW_quant, upA, upB, upS,
70
- downW, downW_quant, downA, downB, downS,
71
- _forward_function, _backward_function,
72
- inplace = True,):
73
- dtype = X.dtype
74
-
75
- e = matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS)
76
- g = matmul_lora(X, upW, upW_quant, upA, upB, upS)
77
- h = _forward_function(e, g)
78
- i = matmul_lora(h, downW, downW_quant, downA, downB, downS)
79
-
80
- ctx.custom_saved_tensors = (
81
- gateW, gateW_quant, gateS,
82
- upW, upW_quant, upS,
83
- downW, downW_quant, downS,
84
- _backward_function,
85
- )
86
- ctx.save_for_backward(gateA, gateB, upA, upB, downA, downB,
87
- X, e, g)
88
- ctx.inplace = inplace
89
- return i
90
- pass
91
-
92
-
93
- @staticmethod
94
- @torch_amp_custom_bwd
95
- def backward(ctx, dY : torch.Tensor):
96
- gateW, gateW_quant, gateS, upW, upW_quant, upS, downW, downW_quant, downS, \
97
- _backward_function = ctx.custom_saved_tensors
98
- gateA, gateB, upA, upB, downA, downB, \
99
- X, e, g = ctx.saved_tensors
100
-
101
- gateA, gateB, upA, upB, downA, downB = \
102
- gateA.t(), gateB.t(), upA.t(), upB.t(), downA.t(), downB.t()
103
-
104
- batch, seq_len, hd = X.shape
105
- dY = dY.view(-1, dY.shape[-1])
106
- X = X .view(-1, X .shape[-1])
107
- e = e .view(-1, e .shape[-1])
108
- g = g .view(-1, g .shape[-1])
109
- dtype = X.dtype
110
-
111
- DW = matmul_lora(dY, downW.t(), downW_quant, downB, downA, downS)
112
- DW, e, g = _backward_function(DW, e, g)
113
- h, df, de = DW, e, g
114
-
115
- # Down projection LoRA weights
116
- d_downA = h.t() @ (dY @ downB.t())
117
- d_downB = (downA.t() @ h.t()) @ dY
118
- d_downA *= downS
119
- d_downB *= downS
120
-
121
- # Up projection LoRA weights
122
- d_upA = X.t() @ (df @ upB.t())
123
- d_upB = (upA.t() @ X.t()) @ df
124
- d_upA *= upS
125
- d_upB *= upS
126
-
127
- # Gate projection LoRA weights
128
- d_gateA = X.t() @ (de @ gateB.t())
129
- d_gateB = (gateA.t() @ X.t()) @ de
130
- d_gateA *= gateS
131
- d_gateB *= gateS
132
-
133
- # dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS)
134
- # dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS)
135
- upW = fast_dequantize(upW.t(), upW_quant)
136
- dX = torch.matmul(df, upW.t(), out = X if ctx.inplace else None)
137
- del upW
138
- dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())
139
-
140
- gateW = fast_dequantize(gateW.t(), gateW_quant)
141
- dX += de @ gateW.t()
142
- del gateW
143
- dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())
144
-
145
- # gateW, gateW_quant, gateA, gateB, gateS,
146
- # upW, upW_quant, upA, upB, upS,
147
- # downW, downW_quant, downA, downB, downS,
148
- return dX.view(batch, seq_len, hd), \
149
- None, None, d_gateA.t(), d_gateB.t(), None, \
150
- None, None, d_upA.t(), d_upB.t(), None, \
151
- None, None, d_downA.t(), d_downB.t(), None, \
152
- None, None, None, # _backward and _forward and inplace
153
- pass
154
- pass
155
-
156
-
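Every adapter gradient above follows one pattern: for Y = X @ W + s * (X @ A @ B), the LoRA gradients are dC/dA = X.T @ dY @ B.T * s and dC/dB = A.T @ X.T @ dY * s. A small autograd sanity check of those formulas for a single projection (plain torch in float64; shapes follow the docstring's A @ B convention, not the quantized storage layout):

import torch

torch.manual_seed(0)
n, d, r, out_dim, s = 8, 16, 4, 32, 0.5
X = torch.randn(n, d,       dtype = torch.float64)
W = torch.randn(d, out_dim, dtype = torch.float64)
A = torch.randn(d, r,       dtype = torch.float64, requires_grad = True)
B = torch.randn(r, out_dim, dtype = torch.float64, requires_grad = True)

Y = X @ W + s * (X @ A @ B)            # frozen weight plus scaled adapter path
Y.square().sum().backward()
dY = 2.0 * Y.detach()                  # dC/dY for this toy loss

# Manual chain rule, matching the d_*A / d_*B expressions in LoRA_MLP.backward.
d_A = X.t() @ (dY @ B.detach().t()) * s
d_B = (A.detach().t() @ X.t()) @ dY * s

assert torch.allclose(A.grad, d_A)
assert torch.allclose(B.grad, d_B)
print("manual LoRA adapter gradients match autograd")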
157
- from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
158
- def apply_lora_mlp_swiglu(self, X, inplace = True):
159
- gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
160
- upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
161
- downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
162
- out = LoRA_MLP.apply(X,
163
- gateW, gateW_quant, gateA, gateB, gateS,
164
- upW, upW_quant, upA, upB, upS,
165
- downW, downW_quant, downA, downB, downS,
166
- swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel,
167
- inplace,)
168
- return out
169
- pass
170
-
171
-
172
- from .geglu import geglu_exact_forward_kernel, geglu_exact_backward_kernel
173
- def apply_lora_mlp_geglu_exact(self, X, inplace = True):
174
- gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
175
- upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
176
- downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
177
- out = LoRA_MLP.apply(X,
178
- gateW, gateW_quant, gateA, gateB, gateS,
179
- upW, upW_quant, upA, upB, upS,
180
- downW, downW_quant, downA, downB, downS,
181
- geglu_exact_forward_kernel, geglu_exact_backward_kernel,
182
- inplace,)
183
- return out
184
- pass
185
-
186
-
187
- from .geglu import geglu_approx_forward_kernel, geglu_approx_backward_kernel
188
- def apply_lora_mlp_geglu_approx(self, X):
189
- gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
190
- upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
191
- downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
192
- out = LoRA_MLP.apply(X,
193
- gateW, gateW_quant, gateA, gateB, gateS,
194
- upW, upW_quant, upA, upB, upS,
195
- downW, downW_quant, downA, downB, downS,
196
- geglu_approx_forward_kernel, geglu_approx_backward_kernel,)
197
- return out
198
- pass
199
-
200
-
201
- class LoRA_QKV(torch.autograd.Function):
202
- """
203
- ### LoRA weights
204
- Wq = Wq + Aq @ Bq
205
- Wk = Wk + Ak @ Bk
206
- Wv = Wv + Av @ Bv
207
- Q = X @ Wq = X @ Wq + X @ Aq @ Bq
208
- K = X @ Wk = X @ Wk + X @ Ak @ Bk
209
- V = X @ Wv = X @ Wv + X @ Av @ Bv
210
-
211
- ### Backpropagation chain rule
212
- See our blogpost for more details.
213
-
214
- dC/dWq = X.T @ D(Wq)
215
- dC/dWk = X.T @ D(Wk)
216
- dC/dWv = X.T @ D(Wv)
217
- We then sum them all find dC/dX
218
-
219
- ### Q projection LoRA weights
220
- dC/dAq = X.T @ D(Wq) @ B.T
221
- dC/dBq = A.T @ X.T @ D(Wq)
222
-
223
- ### K projection LoRA weights
224
- dC/dAk = X.T @ D(Wk) @ B.T
225
- dC/dBk = A.T @ X.T @ D(Wk)
226
-
227
- ### V projection LoRA weights
228
- dC/dAv = X.T @ D(Wv) @ B.T
229
- dC/dBv = A.T @ X.T @ D(Wv)
230
- """
231
- @staticmethod
232
- @torch_amp_custom_fwd
233
- def forward(ctx, X : torch.Tensor,
234
- QW, QW_quant, QA, QB, QS,
235
- KW, KW_quant, KA, KB, KS,
236
- VW, VW_quant, VA, VB, VS,
237
- inplace = True):
238
- dtype = X.dtype
239
-
240
- Q = matmul_lora(X, QW, QW_quant, QA, QB, QS)
241
- K = matmul_lora(X, KW, KW_quant, KA, KB, KS)
242
- V = matmul_lora(X, VW, VW_quant, VA, VB, VS)
243
-
244
- ctx.custom_saved_tensors = (
245
- QW, QW_quant, QS,
246
- KW, KW_quant, KS,
247
- VW, VW_quant, VS,
248
- )
249
- ctx.save_for_backward(X, QA, QB, KA, KB, VA, VB,)
250
- ctx.inplace = inplace
251
- return Q, K, V
252
- pass
253
-
254
- @staticmethod
255
- @torch_amp_custom_bwd
256
- def backward(ctx, dQ, dK, dV):
257
- QW, QW_quant, QS, KW, KW_quant, KS, VW, VW_quant, VS = \
258
- ctx.custom_saved_tensors
259
- X, QA, QB, KA, KB, VA, VB, = ctx.saved_tensors
260
-
261
- QA, QB, KA, KB, VA, VB = \
262
- QA.t(), QB.t(), KA.t(), KB.t(), VA.t(), VB.t()
263
-
264
- batch, seq_len, hd = X.shape
265
- dQ = dQ.view(-1, dQ.shape[-1])
266
- dK = dK.reshape(-1, dK.shape[-1]) # view doesn't work on K.T
267
- dV = dV.view(-1, dV.shape[-1])
268
- X = X .view(-1, X .shape[-1])
269
- dtype = X.dtype
270
-
271
- ### Weight projection LoRA weights
272
- # See our blogpost for more details.
273
-
274
- # Q Projection
275
- d_QA = X.t() @ (dQ @ QB.t())
276
- d_QB = (QA.t() @ X.t()) @ dQ
277
- d_QA *= QS
278
- d_QB *= QS
279
-
280
- # K Projection
281
- d_KA = X.t() @ (dK @ KB.t())
282
- d_KB = (KA.t() @ X.t()) @ dK
283
- d_KA *= KS
284
- d_KB *= KS
285
-
286
- # V Projection
287
- d_VA = X.t() @ (dV @ VB.t())
288
- d_VB = (VA.t() @ X.t()) @ dV
289
- d_VA *= VS
290
- d_VB *= VS
291
-
292
- # Combine derivatives to find dX
293
- # dQ
294
- QW = fast_dequantize(QW.t(), QW_quant)
295
- dX = torch.matmul(dQ, QW.t(), out = X if ctx.inplace else None)
296
- del QW
297
- dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))
298
-
299
- # dK
300
- KW = fast_dequantize(KW.t(), KW_quant)
301
- dX += dK @ KW.t()
302
- del KW
303
- dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t())
304
-
305
- # dV
306
- VW = fast_dequantize(VW.t(), VW_quant)
307
- dX += dV @ VW.t()
308
- del VW
309
- dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())
310
-
311
- # QW, QW_quant, QA, QB, QS,
312
- # KW, KW_quant, KA, KB, KS,
313
- # VW, VW_quant, VA, VB, VS,
314
- return dX.view(batch, seq_len, hd), \
315
- None, None, d_QA.t(), d_QB.t(), None, \
316
- None, None, d_KA.t(), d_KB.t(), None, \
317
- None, None, d_VA.t(), d_VB.t(), None, \
318
- None,
319
- pass
320
- pass
321
-
322
-
323
- def apply_lora_qkv(self, X, inplace = True):
324
- QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
325
- KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
326
- VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
327
- Q, K, V = LoRA_QKV.apply(X,
328
- QW, QW_quant, QA, QB, QS,
329
- KW, KW_quant, KA, KB, KS,
330
- VW, VW_quant, VA, VB, VS,
331
- inplace,
332
- )
333
- return Q, K, V
334
- pass
335
-
336
-
337
- class LoRA_W(torch.autograd.Function):
338
- """
339
- ### LoRA weights
340
- Wq = Wq + Aq @ Bq
341
- Wk = Wk + Ak @ Bk
342
- Wv = Wv + Av @ Bv
343
- Q = X @ Wq = X @ Wq + X @ Aq @ Bq
344
- K = X @ Wk = X @ Wk + X @ Ak @ Bk
345
- V = X @ Wv = X @ Wv + X @ Av @ Bv
346
-
347
- ### Backpropagation chain rule
348
- dC/dWq = X.T @ D(Wq)
349
- dC/dWk = X.T @ D(Wk)
350
- dC/dWv = X.T @ D(Wv)
351
-
352
- ### Q projection LoRA weights
353
- dC/dAq = X.T @ D(Wq) @ B.T
354
- dC/dBq = A.T @ X.T @ D(Wq)
355
-
356
- ### K projection LoRA weights
357
- dC/dAk = X.T @ D(Wk) @ B.T
358
- dC/dBk = A.T @ X.T @ D(Wk)
359
-
360
- ### V projection LoRA weights
361
- dC/dAv = X.T @ D(Wv) @ B.T
362
- dC/dBv = A.T @ X.T @ D(Wv)
363
- """
364
- @staticmethod
365
- @torch_amp_custom_fwd
366
- def forward(ctx, X : torch.Tensor,
367
- W, W_quant, A, B, S):
368
- dtype = X.dtype
369
- XW = matmul_lora(X, W, W_quant, A, B, S)
370
- ctx.custom_saved_tensors = (W, W_quant, S,)
371
- ctx.save_for_backward(A, B, X)
372
- return XW
373
- pass
374
-
375
- @staticmethod
376
- @torch_amp_custom_bwd
377
- def backward(ctx, dY : torch.Tensor):
378
- W, W_quant, S = ctx.custom_saved_tensors
379
- A, B, X = ctx.saved_tensors
380
-
381
- A, B = A.t(), B.t()
382
-
383
- batch, seq_len, hd = X.shape
384
- dY = dY.reshape(-1, dY.shape[-1]) # Must be reshape
385
- X = X .reshape(-1, X .shape[-1]) # Must be reshape
386
- dtype = X.dtype
387
-
388
- ### Weight projection LoRA weights
389
- # Weight projection
390
- d_A = X.t() @ (dY @ B.t())
391
- d_B = (A.t() @ X.t()) @ dY
392
- d_A *= S
393
- d_B *= S
394
-
395
- # Get derivative for dX
396
- W = fast_dequantize(W.t(), W_quant)
397
- dX = dY @ W.t()
398
- del W
399
- dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t())
400
-
401
- # W, W_quant, A, B, S
402
- return dX.view(batch, seq_len, hd), \
403
- None, None, d_A.t(), d_B.t(), None
404
- pass
405
- pass
406
-
407
-
408
- def apply_lora_o(self, X):
409
- OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
410
- O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
411
- return O
412
- pass
413
-
414
-
415
- IDENTITY_DROPOUT = torch.nn.Identity
416
- @torch._disable_dynamo
417
- def fast_lora_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
418
- raise NotImplementedError(
419
- "Unsloth: Currently not supported yet - reshaping done incorrectly"
420
- )
421
- self._check_forward_args(x, *args, **kwargs)
422
- adapter_names = kwargs.pop("adapter_names", None)
423
-
424
- if self.disable_adapters:
425
- if self.merged:
426
- self.unmerge()
427
- result = self.base_layer(x, *args, **kwargs)
428
- elif adapter_names is not None:
429
- result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
430
- elif self.merged:
431
- result = self.base_layer(x, *args, **kwargs)
432
- else:
433
- # Fastpath
434
- if len(self.active_adapters) == 1:
435
- active_adapter = self.active_adapters[0]
436
- if active_adapter not in self.lora_A.keys(): return self.base_layer(x, *args, **kwargs)
437
-
438
- dropout = self.lora_dropout[active_adapter]
439
- if isinstance(dropout, IDENTITY_DROPOUT) and not self.use_dora[active_adapter]:
440
- lora_A = self.lora_A[active_adapter].weight
441
- lora_B = self.lora_B[active_adapter].weight
442
- scaling = self.scaling[active_adapter]
443
- W = self.base_layer.weight
444
- return LoRA_W.apply(x, W, QUANT_STATE(W), lora_A, lora_B, scaling)
445
- pass
446
- pass
447
-
448
- result = self.base_layer(x, *args, **kwargs)
449
- # As per Tim Dettmers, for 4bit, we need to defensively clone here.
450
- # The reason is that in some cases, an error can occur that backprop
451
- # does not work on a manipulated view. This issue may be solved with
452
- # newer PyTorch versions but this would need extensive testing to be
453
- # sure.
454
- result = result.clone()
455
-
456
- for active_adapter in self.active_adapters:
457
- if active_adapter not in self.lora_A.keys():
458
- continue
459
- lora_A = self.lora_A[active_adapter]
460
- lora_B = self.lora_B[active_adapter]
461
- dropout = self.lora_dropout[active_adapter]
462
- scaling = self.scaling[active_adapter]
463
-
464
- requires_conversion = not torch.is_autocast_enabled()
465
- if requires_conversion:
466
- expected_dtype = result.dtype
467
- x = x.to(lora_A.weight.dtype)
468
-
469
- if not self.use_dora[active_adapter]:
470
- result = result + lora_B(lora_A(dropout(x))) * scaling
471
- else:
472
- if isinstance(dropout, torch.nn.Identity) or not self.training:
473
- base_result = result
474
- else:
475
- x = dropout(x)
476
- base_result = None
477
-
478
- result = result + self.lora_magnitude_vector[active_adapter](
479
- x,
480
- lora_A=lora_A,
481
- lora_B=lora_B,
482
- scaling=scaling,
483
- base_layer=self.get_base_layer(),
484
- base_result=base_result,
485
- )
486
- if requires_conversion:
487
- result = result.to(expected_dtype)
488
-
489
- return result
490
- pass
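For completeness, the forward identity that the apply_lora_* helpers in this file rely on: running the frozen weight and the adapter separately equals running a merged weight W + s * (A @ B). A tiny float64 sketch (same A @ B convention as the docstrings above):

import torch

torch.manual_seed(0)
X = torch.randn(4, 16, dtype = torch.float64)
W = torch.randn(16, 32, dtype = torch.float64)
A = torch.randn(16, 4, dtype = torch.float64)
B = torch.randn(4, 32, dtype = torch.float64)
s = 2.0

unmerged = X @ W + s * ((X @ A) @ B)   # what matmul_lora computes, conceptually
merged   = X @ (W + s * (A @ B))       # the same projection with the adapter merged in

assert torch.allclose(unmerged, merged)
print("unmerged LoRA forward == merged forward")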
unsloth-main/unsloth/kernels/flex_attention.py DELETED
@@ -1,181 +0,0 @@
1
- # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import torch
16
- from functools import lru_cache
17
- from transformers.models.llama.modeling_llama import logger
18
- import os
19
-
20
- torch_compile_options = {
21
- "epilogue_fusion" : True,
22
- "max_autotune" : True,
23
- "shape_padding" : True,
24
- "trace.enabled" : os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1",
25
- "triton.cudagraphs" : False,
26
- }
27
-
28
- # Flex Attention supported from torch 2.5 onwards only
29
- try:
30
- from torch.nn.attention.flex_attention import (
31
- flex_attention as _flex_attention,
32
- create_block_mask as _create_block_mask,
33
- )
34
- _flex_attention = torch.compile(_flex_attention, dynamic = True, options = torch_compile_options)
35
- HAS_FLEX_ATTENTION = False
36
- except:
37
- HAS_FLEX_ATTENTION = False
38
- pass
39
-
40
-
41
- if not HAS_FLEX_ATTENTION:
42
-
43
- # Logit softcapping
44
- @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
45
- def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
46
- n_heads = self.num_heads
47
- head_dim = self.head_dim
48
- n_kv_heads = self.num_key_value_heads
49
- n_groups = self.num_key_value_groups
50
-
51
- # Grouped query attention
52
- K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
53
- V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
54
- K = K.reshape(bsz, n_heads, q_len, head_dim)
55
- V = V.reshape(bsz, n_heads, q_len, head_dim)
56
-
57
- # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
58
- # Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below
59
- # We default to using the config file itself
60
- # s = self.config.hidden_size // self.config.num_attention_heads
61
- s = self.config.query_pre_attn_scalar
62
- t = self.config.attn_logit_softcapping
63
-
64
- Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
65
- A = torch.matmul(Q, K.transpose(2, 3))
66
- A = t * torch.tanh(A / t) # Logit softcapping
67
- A += causal_mask[:q_len, :q_len]
68
- # Much slower in torch compile!
69
- # A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf"))
70
- A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
71
- A = torch.matmul(A, V)
72
- A = A.transpose(1, 2).contiguous()
73
- A = A.reshape(bsz, q_len, n_heads*head_dim)
74
- return A
75
- pass
76
-
77
- create_flex_attention_causal_mask = None
78
- create_flex_attention_sliding_window_mask = None
79
- else:
80
- # See https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
81
- # for more examples
82
- # BSD 3-Clause License Copyright (c) 2023, Driss Guessous, Horace He et al
83
- import functools, math
84
-
85
- def generate_tanh_softcap(t):
86
- def tanh_softcap(x, b, h, q_idx, kv_idx):
87
- return t * torch.tanh(x / t)
88
- return tanh_softcap
89
- pass
90
- def causal_masker(b, h, q_idx, kv_idx):
91
- return q_idx >= kv_idx
92
- pass
93
-
94
- @functools.lru_cache
95
- def sliding_window_masker(size = 4096):
96
- def sliding_window(b, h, q_idx, kv_idx):
97
- causal_mask = q_idx >= kv_idx
98
- window_mask = q_idx - kv_idx <= size
99
- return causal_mask & window_mask
100
- return sliding_window
101
- pass
102
-
103
- @functools.lru_cache
104
- def create_block_mask(mask, n = 128):
105
- return _create_block_mask(
106
- mask, 1, 1, n, n,
107
- BLOCK_SIZE = 128,
108
- _compile = True,
109
- )
110
- pass
111
-
112
- def create_flex_attention_causal_mask(max_seq_length = 8192):
113
- causal_mask = create_block_mask(causal_masker, max_seq_length)
114
- return causal_mask
115
- pass
116
-
117
- def create_flex_attention_sliding_window_mask(max_seq_length = 8192, sliding_window = 4096):
118
- sliding_masker = sliding_window_masker(sliding_window)
119
- causal_mask = create_block_mask(sliding_masker, max_seq_length)
120
- return causal_mask
121
- pass
122
-
123
- @functools.lru_cache
124
- def flex_attention(s, t):
125
- scale = 1.0 / math.sqrt(s)
126
- score_mod = generate_tanh_softcap(t)
127
- return functools.partial(
128
- _flex_attention, score_mod = score_mod, scale = scale, enable_gqa = True,
129
- )
130
- pass
131
-
132
- def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
133
- n_heads = self.num_heads
134
- head_dim = self.head_dim
135
- s = self.config.query_pre_attn_scalar
136
- t = self.config.attn_logit_softcapping
137
- fx = flex_attention(s, t)
138
- A = fx(query = Q, key = K, value = V, block_mask = causal_mask)
139
- A = A.transpose(1, 2).contiguous()
140
- A = A.reshape(bsz, q_len, n_heads*head_dim)
141
- return A
142
- pass
143
- pass
144
-
145
-
146
- torch_matmul = torch.matmul
147
- torch_tanh = torch.tanh
148
- torch_nn_functional_softmax = torch.nn.functional.softmax
149
- def slow_inference_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
150
- n_heads = self.num_heads
151
- head_dim = self.head_dim
152
- n_kv_heads = self.num_key_value_heads
153
- n_groups = self.num_key_value_groups
154
-
155
- # Grouped query attention
156
- K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
157
- V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
158
- K = K.reshape(bsz, n_heads, q_len, head_dim)
159
- V = V.reshape(bsz, n_heads, q_len, head_dim)
160
-
161
- # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
162
- # Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below
163
- # We default to using the config file itself
164
- # s = self.config.hidden_size // self.config.num_attention_heads
165
- s = self.config.query_pre_attn_scalar
166
- t = self.config.attn_logit_softcapping
167
-
168
- Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
169
- A = torch_matmul(Q, K.transpose(2, 3))
170
-
171
- # Logit softcapping
172
- A /= t; torch_tanh(A, out = A); A *= t;
173
- A += causal_mask[:q_len, :q_len]
174
- # Much slower in torch compile!
175
- # A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf"))
176
- A = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
177
- A = torch_matmul(A, V)
178
- A = A.transpose(1, 2).contiguous()
179
- A = A.reshape(bsz, q_len, n_heads*head_dim)
180
- return A
181
- pass
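A compact plain-torch restatement of what the eager path above computes: expand the KV heads for grouped-query attention, scale Q, softcap the logits with tanh, add the causal mask, softmax, then multiply by V. Tiny random tensors only, to make the data flow concrete (the scalars s and t are made-up stand-ins for query_pre_attn_scalar and attn_logit_softcapping):

import torch

torch.manual_seed(0)
bsz, q_len, head_dim = 1, 5, 8
n_kv_heads, n_groups = 2, 2
n_heads = n_kv_heads * n_groups
s, t = 256.0, 50.0

Q = torch.randn(bsz, n_heads,    q_len, head_dim)
K = torch.randn(bsz, n_kv_heads, q_len, head_dim)
V = torch.randn(bsz, n_kv_heads, q_len, head_dim)

# Grouped query attention: repeat each KV head across its group of query heads.
K = K[:, :, None].expand(bsz, n_kv_heads, n_groups, q_len, head_dim).reshape(bsz, n_heads, q_len, head_dim)
V = V[:, :, None].expand(bsz, n_kv_heads, n_groups, q_len, head_dim).reshape(bsz, n_heads, q_len, head_dim)

# Additive causal mask: 0 on and below the diagonal, -inf above it.
causal_mask = torch.triu(torch.full((q_len, q_len), -float("inf")), diagonal = 1)

A = (Q * s**-0.5) @ K.transpose(2, 3)   # scaled attention logits
A = t * torch.tanh(A / t)               # logit softcapping
A = A + causal_mask
A = torch.softmax(A, dim = -1)
out = (A @ V).transpose(1, 2).reshape(bsz, q_len, n_heads * head_dim)
print(out.shape)                        # torch.Size([1, 5, 32])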
unsloth-main/unsloth/kernels/geglu.py DELETED
@@ -1,203 +0,0 @@
1
- # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import triton
16
- import triton.language as tl
17
- import torch
18
- from .utils import calculate_settings, triton_tanh
19
-
20
-
21
- @triton.jit
22
- def _exact_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
23
- block_idx = tl.program_id(0)
24
- offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
25
- mask = offsets < n_elements
26
-
27
- # f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
28
- # h = f * up
29
- e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
30
- g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
31
-
32
- f_row = 0.5 * e_row * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)
33
- f_row = f_row.to(g_row.dtype) # Exact copy from HF
34
- h_row = f_row * g_row
35
-
36
- # Store h
37
- tl.store(h + offsets, h_row, mask = mask)
38
- pass
39
-
40
-
41
- def geglu_exact_forward_kernel(gate, up):
42
- batch, seq_len, hd = gate.shape
43
- n_elements = gate.numel()
44
- out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0")
45
- grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
46
- _exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
47
- return out
48
- pass
49
-
50
-
51
- @triton.jit
52
- def _exact_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
53
- """
54
- f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
55
- h = f * up
56
-
57
- df/de (with help of Wolfram :)
58
- df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2)
59
-
60
- Reuse via
61
- f = 1/2 * (1 + erf(1/sqrt(2) * e)) * e
62
- """
63
- block_idx = tl.program_id(0)
64
- offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
65
- mask = offsets < n_elements
66
-
67
- DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
68
- e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
69
- g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
70
-
71
- # Break e_row away for re-use
72
- # f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
73
- f_partial_row = 0.5 * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)
74
- f_row = f_partial_row * e_row
75
-
76
- f_row = f_row.to(DW_row.dtype)
77
- # h = f * g
78
- h_row = f_row * g_row
79
- # df = DW * f
80
- df_row = DW_row * f_row
81
- # dg = DW * g
82
- dg_row = DW_row * g_row
83
-
84
- # df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2)
85
- t = 0.3989422804014327 # 1/sqrt(2*pi)
86
- df_de = f_partial_row + t * e_row * tl.exp(-0.5 * e_row * e_row)
87
-
88
- de_row = dg_row.to(tl.float32) * df_de
89
- de_row = de_row.to(DW_row.dtype)
90
-
91
- # Store derivatives in buffers
92
- tl.store(DW + offsets, h_row, mask = mask) # h = f * g
93
- tl.store(e + offsets, df_row, mask = mask) # df = DW * f
94
- tl.store(g + offsets, de_row, mask = mask) # de
95
- pass
96
-
97
-
98
- def geglu_exact_backward_kernel(DW, e, g):
99
- batch_seq_len, hd = e.shape
100
- n_elements = e.numel()
101
- grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
102
- _exact_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
103
- return DW, e, g
104
- pass
105
-
106
-
107
- @triton.jit
108
- def _approx_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
109
- block_idx = tl.program_id(0)
110
- offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
111
- mask = offsets < n_elements
112
-
113
- # f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) ))
114
- # f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ))
115
- # h = f * up
116
- s = 0.7978845608028654 # math.sqrt(2 / math.pi)
117
-
118
- e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
119
- g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
120
-
121
- f_row = 0.5 * e_row * (
122
- triton_tanh(s * e_row * (1.0 + 0.044715 * e_row * e_row)) \
123
- + 1.0
124
- )
125
- f_row = f_row.to(g_row.dtype) # Exact copy from HF
126
- h_row = f_row * g_row
127
-
128
- # Store h
129
- tl.store(h + offsets, h_row, mask = mask)
130
- pass
131
-
132
-
133
- def geglu_approx_forward_kernel(gate, up):
134
- batch, seq_len, hd = gate.shape
135
- n_elements = gate.numel()
136
- out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0")
137
- grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
138
- _approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
139
- return out
140
- pass
141
-
142
-
143
- @triton.jit
144
- def _approx_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
145
- """
146
- f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ))
147
- h = f * up
148
-
149
- df/de (with help from https://arxiv.org/pdf/2305.12073.pdf :))
150
- df/de = 1/2 * [1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )] +
151
- 1/2 * sech^2 [ sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ] * \
152
- ( sqrt(2/pi) * x * (1 + 0.044715 * x^2 * 3 ) )
153
-
154
- Notice sech^2(x) = 1 - tanh^2(x)
155
- So reuse tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )
156
-
157
- See https://www.desmos.com/calculator/nqprfoni6x
158
- """
159
- block_idx = tl.program_id(0)
160
- offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
161
- mask = offsets < n_elements
162
-
163
- DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
164
- e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
165
- g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
166
-
167
- # See https://www.desmos.com/calculator/nqprfoni6x
168
- s = 0.7978845608028654 # math.sqrt(2 / math.pi)
169
- a = s * e_row # a = sqrt(2 / pi) * x
170
- b = a * 0.044715 * e_row * e_row # b = a * 0.044715 * x^2
171
- T = 1.0 + triton_tanh(a + b)
172
- T2 = 0.5 * T
173
- # Q = 0.5 * -T * (T - 2.0) * (a + 3.0 * b)
174
- Q2 = -T2 * (T - 2.0) * (a + 3.0 * b)
175
- df_de = T2 + Q2 # 1/2 * (T + Q)
176
-
177
- # f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) ))
178
- f_row = T2 * e_row
179
- f_row = f_row.to(DW_row.dtype)
180
- # h = f * g
181
- h_row = f_row * g_row
182
- # df = DW * f
183
- df_row = DW_row * f_row
184
- # dg = DW * g
185
- dg_row = DW_row * g_row
186
-
187
- de_row = dg_row.to(tl.float32) * df_de
188
- de_row = de_row.to(DW_row.dtype)
189
-
190
- # Store derivatives in buffers
191
- tl.store(DW + offsets, h_row, mask = mask) # h = f * g
192
- tl.store(e + offsets, df_row, mask = mask) # df = DW * f
193
- tl.store(g + offsets, de_row, mask = mask) # de
194
- pass
195
-
196
-
197
- def geglu_approx_backward_kernel(DW, e, g):
198
- batch_seq_len, hd = e.shape
199
- n_elements = e.numel()
200
- grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
201
- _approx_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
202
- return DW, e, g
203
- pass
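
As a numerical cross-check for the exact GEGLU kernels above, here is a pure-PyTorch sketch of the same forward and gradient math. Names follow the kernel's convention (`e` is the gate projection, `g` the up projection, `dh` the upstream gradient); this is a reference, not the shipped implementation.

```python
import math
import torch

def geglu_exact_forward_reference(e, g):
    # h = GELU(e) * g with the exact erf-based GELU, accumulated in float32.
    e32 = e.float()
    f = 0.5 * e32 * (1.0 + torch.erf(e32 / math.sqrt(2.0)))
    return f.to(g.dtype) * g

def geglu_exact_backward_reference(dh, e, g):
    # d GELU(e)/de = 1/2 * (1 + erf(e/sqrt(2))) + e * exp(-e^2/2) / sqrt(2*pi)
    e32 = e.float()
    f_partial = 0.5 * (1.0 + torch.erf(e32 / math.sqrt(2.0)))
    df_de = f_partial + e32 * torch.exp(-0.5 * e32 * e32) / math.sqrt(2.0 * math.pi)
    f = (f_partial * e32).to(dh.dtype)
    de = (dh.float() * g.float() * df_de).to(dh.dtype)  # gradient w.r.t. the gate projection
    dg = dh * f                                         # gradient w.r.t. the up projection
    return de, dg
```

The Triton version writes `h`, `dg` and `de` back into the `DW`, `e` and `g` buffers in place to avoid extra allocations, which is why the backward kernel returns the same three tensors it was given.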
 
 
unsloth-main/unsloth/kernels/layernorm.py DELETED
@@ -1,213 +0,0 @@
1
- # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
- # Copyright 2024-present Andrej Karpathy & the llm.c team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import triton
17
- import triton.language as tl
18
- import torch
19
- from .utils import calculate_settings
20
- from unsloth_zoo.patching_utils import (
21
- patch_layernorm,
22
- )
23
-
24
-
25
- @triton.jit
26
- def layernorm_forward(
27
- Y, Y_row_stride,
28
- X, X_row_stride,
29
- W,
30
- b,
31
- r,
32
- mu,
33
- n_cols, eps,
34
- BLOCK_SIZE : tl.constexpr
35
- ):
36
- row_idx = tl.program_id(0)
37
- col_offsets = tl.arange(0, BLOCK_SIZE)
38
- mask = col_offsets < n_cols
39
-
40
- Y += row_idx * Y_row_stride
41
- X += row_idx * X_row_stride
42
- r += row_idx
43
- mu += row_idx
44
-
45
- # According to https://pytorch.org/torchtune/stable/_modules/torchtune/modules/layer_norm.html#Fp32LayerNorm, all modules
46
- # are in float32!
47
- X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
48
- W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
49
- b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32)
50
-
51
- mean_X = tl.sum(X_row, axis = 0) / n_cols
52
- XX = X_row - mean_X
53
- row_var = tl.sum(XX * XX, axis = 0) / n_cols
54
- inv_var = tl.math.rsqrt(row_var + eps)
55
- tl.store (r, inv_var)
56
- tl.store (mu, mean_X)
57
- output = (XX * inv_var) * W_row + b_row
58
- tl.store(Y + col_offsets, output, mask = mask)
59
- pass
60
-
61
-
62
- @triton.jit
63
- def layernorm_backward(
64
- dY, dY_row_stride,
65
- X, X_row_stride,
66
- W,
67
- b,
68
- r,
69
- mu,
70
- n_cols, eps,
71
- BLOCK_SIZE : tl.constexpr
72
- ):
73
- # Approximately follows https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
74
- row_idx = tl.program_id(0)
75
- col_offsets = tl.arange(0, BLOCK_SIZE)
76
- mask = col_offsets < n_cols
77
-
78
- dY += row_idx * dY_row_stride
79
- X += row_idx * X_row_stride
80
- r += row_idx
81
- mu += row_idx
82
-
83
- # According to https://pytorch.org/torchtune/stable/_modules/torchtune/modules/layer_norm.html#Fp32LayerNorm, all modules
84
- # are in float32!
85
- dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
86
- X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
87
- W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
88
- b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32)
89
-
90
- inv_var = tl.load(r) .to(tl.float32)
91
- mean = tl.load(mu).to(tl.float32)
92
- normed = (X_row - mean) * inv_var
93
- dY_W = dY_row * W_row
94
- dX_row = dY_W - tl.sum(dY_W, axis = 0) / n_cols - normed * tl.sum(dY_W * normed, axis = 0) / n_cols
95
- dX_row = dX_row * inv_var
96
- tl.store(dY + col_offsets, dX_row, mask = mask)
97
- pass
98
-
99
-
100
- class Fast_Layernorm(torch.autograd.Function):
101
- @staticmethod
102
- def forward(ctx, X, W, b, eps):
103
- shape = X.shape
104
- dim = shape[-1]
105
- X = X.view(-1, dim)
106
- n_rows, n_cols = X.shape
107
- BLOCK_SIZE, num_warps = calculate_settings(n_cols)
108
-
109
- Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0")
110
- r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
111
- mu = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
112
-
113
- layernorm_forward[(n_rows,)](
114
- Y, Y.stride(0),
115
- X, X.stride(0),
116
- W,
117
- b,
118
- r,
119
- mu,
120
- n_cols, eps,
121
- BLOCK_SIZE = BLOCK_SIZE,
122
- num_warps = num_warps,
123
- )
124
- ctx.eps = eps
125
- ctx.BLOCK_SIZE = BLOCK_SIZE
126
- ctx.num_warps = num_warps
127
- ctx.save_for_backward(X, W, b, r, mu)
128
- return Y.view(*shape)
129
- pass
130
-
131
- @staticmethod
132
- def backward(ctx, dY):
133
- shape = dY.shape
134
- dim = shape[-1]
135
- dY = dY.view(-1, dim)
136
- X, W, b, r, mu = ctx.saved_tensors
137
- n_rows, n_cols = dY.shape
138
-
139
- layernorm_backward[(n_rows,)](
140
- dY, dY.stride(0),
141
- X, X .stride(0),
142
- W,
143
- b,
144
- r,
145
- mu,
146
- n_cols, ctx.eps,
147
- BLOCK_SIZE = ctx.BLOCK_SIZE,
148
- num_warps = ctx.num_warps,
149
- )
150
- dX = dY.view(*shape)
151
- return dX, None, None, None, None
152
- pass
153
- pass
154
-
155
-
156
- def fast_layernorm(layernorm, X):
157
- assert(layernorm.elementwise_affine is True)
158
- W = layernorm.weight
159
- bias = layernorm.bias
160
- eps = layernorm.variance_epsilon if \
161
- hasattr(layernorm, "variance_epsilon") \
162
- else layernorm.eps
163
- out = Fast_Layernorm.apply(X, W, bias, eps)
164
- return out
165
- pass
166
-
167
-
168
-
169
- def test_layernorm(
170
- dim = 1024, eps = 1e-5, dtype = torch.float16,
171
- bsz = 21, random_state = 3407, seqlen = 3341,
172
- ):
173
- from torch.nn import LayerNorm
174
- layernorm = LayerNorm((dim,), eps = eps, device = "cuda", dtype = dtype)
175
- torch.cuda.manual_seed(random_state)
176
- torch.manual_seed(random_state)
177
- torch.nn.init.uniform_(layernorm.weight)
178
- torch.nn.init.uniform_(layernorm.bias)
179
- X = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda")
180
- XX = X.clone()
181
- X .requires_grad_(True)
182
- XX.requires_grad_(True)
183
- Y = layernorm(X)
184
- YY = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda", requires_grad = True)
185
- Y.backward(YY)
186
- correct_grad = X.grad.clone()
187
- # from unsloth.kernels import fast_layernorm
188
- Y = fast_layernorm(layernorm, XX)
189
- Y.backward(YY)
190
- assert(torch.dist(correct_grad, XX.grad).item() <= 0.1)
191
- pass
192
-
193
-
194
- def testing_suite_layernorm():
195
- for dim in [512, 1024, 2048]:
196
- for dtype in [torch.float16, torch.bfloat16]:
197
- with torch.autocast(device_type = "cuda", dtype = dtype):
198
- for seqlen in [3341, 2048, 349]:
199
- for random_state in [3407, 42]:
200
- test_layernorm(
201
- dim = dim,
202
- eps = 1e-5,
203
- dtype = dtype,
204
- bsz = 21,
205
- random_state = random_state,
206
- seqlen = seqlen,
207
- )
208
- pass
209
- pass
210
- pass
211
- pass
212
- pass
213
- pass
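
The backward kernel above is the standard LayerNorm input gradient. A short PyTorch sketch of the same per-row formula (a reference only; it recomputes the mean and inverse standard deviation that the forward kernel caches in `mu` and `r`):

```python
import torch

def layernorm_dx_reference(dY, X, W, eps):
    Xf, dYf, Wf = X.float(), dY.float(), W.float()
    mean = Xf.mean(dim=-1, keepdim=True)
    var = ((Xf - mean) ** 2).mean(dim=-1, keepdim=True)
    inv_var = torch.rsqrt(var + eps)
    normed = (Xf - mean) * inv_var
    dY_W = dYf * Wf
    dX = (dY_W
          - dY_W.mean(dim=-1, keepdim=True)
          - normed * (dY_W * normed).mean(dim=-1, keepdim=True)) * inv_var
    return dX.to(dY.dtype)
```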
 
 
unsloth-main/unsloth/kernels/rms_layernorm.py DELETED
@@ -1,297 +0,0 @@
1
- # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import triton
16
- import triton.language as tl
17
- import torch
18
- from .utils import calculate_settings
19
-
20
-
21
- @triton.jit
22
- def _rms_layernorm_forward(
23
- Y, Y_row_stride,
24
- X, X_row_stride,
25
- W, W_row_stride,
26
- r, r_row_stride,
27
- n_cols, eps,
28
- BLOCK_SIZE : tl.constexpr
29
- ):
30
- """
31
- Fast RMS Layernorm kernel
32
- Inspiration from a Triton tutorial:
33
- https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
34
- """
35
- row_idx = tl.program_id(0)
36
- col_offsets = tl.arange(0, BLOCK_SIZE)
37
- mask = col_offsets < n_cols
38
-
39
- Y += row_idx * Y_row_stride
40
- X += row_idx * X_row_stride
41
- r += row_idx * r_row_stride
42
-
43
- X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
44
- W_row = tl.load(W + col_offsets, mask = mask, other = 0)#.to(tl.float32)
45
-
46
- row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
47
- inv_var = tl.math.rsqrt(row_var + eps)
48
- tl.store(r, inv_var)
49
- normed = X_row * inv_var
50
- normed = normed.to(W_row.dtype) # Exact copy from HF
51
- output = normed * W_row
52
- tl.store(Y + col_offsets, output, mask = mask)
53
- pass
54
-
55
-
56
- @triton.heuristics({"GEMMA": lambda args: bool(args["GEMMA"]),})
57
- @triton.jit
58
- def _rms_layernorm_backward(
59
- dY, dY_row_stride,
60
- dX, dX_row_stride,
61
- X, X_row_stride,
62
- W, W_row_stride,
63
- r, r_row_stride,
64
- # dW, dW_row_stride,
65
- n_cols, eps,
66
- GEMMA : tl.constexpr,
67
- BLOCK_SIZE : tl.constexpr,
68
- ):
69
- """
70
- Fast RMS Layernorm kernel for the backward pass
71
- Inspiration from a Triton tutorial:
72
- https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
73
- """
74
- row_idx = tl.program_id(0)
75
- col_offsets = tl.arange(0, BLOCK_SIZE)
76
- mask = col_offsets < n_cols
77
-
78
- dY += row_idx * dY_row_stride
79
- X += row_idx * X_row_stride
80
- r += row_idx * r_row_stride
81
-
82
- if GEMMA: dX += row_idx * dY_row_stride
83
- else: dX = dY
84
-
85
- dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
86
- X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
87
- W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
88
-
89
- # Get saved row variance
90
- inv_var = tl.load(r).to(tl.float32)
91
- normed = X_row * inv_var
92
-
93
- if GEMMA: dY_W = dY_row * (W_row + 1.0)
94
- else: dY_W = dY_row * W_row
95
-
96
- rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)
97
- output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)
98
- tl.store(dX + col_offsets, output, mask = mask)
99
- pass
100
-
101
-
102
- @triton.jit
103
- def _gemma_rms_layernorm_forward(
104
- Y, Y_row_stride,
105
- X, X_row_stride,
106
- W, W_row_stride,
107
- r, r_row_stride,
108
- n_cols, eps,
109
- BLOCK_SIZE : tl.constexpr,
110
- ):
111
- # Copies https://github.com/google-deepmind/gemma/blob/main/gemma/layers.py#L31
112
- # and https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L33
113
- # exactly. Essentially all in float32!
114
- row_idx = tl.program_id(0)
115
- col_offsets = tl.arange(0, BLOCK_SIZE)
116
- mask = col_offsets < n_cols
117
-
118
- Y += row_idx * Y_row_stride
119
- X += row_idx * X_row_stride
120
- r += row_idx * r_row_stride
121
-
122
- X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
123
- W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
124
-
125
- row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
126
- inv_var = tl.math.rsqrt(row_var + eps)
127
- tl.store(r, inv_var)
128
- normed = X_row * inv_var
129
- output = normed * (W_row + 1.0)
130
-
131
- tl.store(Y + col_offsets, output, mask = mask)
132
- pass
133
-
134
-
135
- class Fast_RMS_Layernorm(torch.autograd.Function):
136
- @staticmethod
137
- def forward(ctx, X : torch.Tensor, W : torch.Tensor, eps : float, gemma : bool = False):
138
- shape = X.shape
139
- dim : int = shape[-1]
140
- X = X.view(-1, dim)
141
- n_rows : int
142
- n_cols : int
143
- n_rows, n_cols = X.shape
144
- BLOCK_SIZE : int
145
- num_warps : int
146
- BLOCK_SIZE, num_warps = calculate_settings(n_cols)
147
-
148
- Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0")
149
- r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
150
-
151
- fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward
152
- fx[(n_rows,)](
153
- Y, Y.stride(0),
154
- X, X.stride(0),
155
- W, W.stride(0),
156
- r, r.stride(0),
157
- n_cols, eps,
158
- BLOCK_SIZE = BLOCK_SIZE,
159
- num_warps = num_warps,
160
- )
161
- ctx.eps = eps
162
- ctx.BLOCK_SIZE = BLOCK_SIZE
163
- ctx.num_warps = num_warps
164
- ctx.GEMMA = gemma
165
- ctx.save_for_backward(X, W, r)
166
- return Y.view(*shape)
167
- pass
168
-
169
- @staticmethod
170
- def backward(ctx, dY : torch.Tensor):
171
- shape = dY.shape
172
- dim : int = shape[-1]
173
- dY = dY.view(-1, dim)
174
- X, W, r = ctx.saved_tensors
175
- n_rows : int
176
- n_cols : int
177
- n_rows, n_cols = dY.shape
178
- # dW = X
179
- dX = torch.empty_like(dY, device = "cuda:0") if ctx.GEMMA else dY
180
-
181
- _rms_layernorm_backward[(n_rows,)](
182
- dY, dY.stride(0),
183
- dX, dX.stride(0),
184
- X, X .stride(0),
185
- W, W .stride(0),
186
- r, r .stride(0),
187
- # dW, dW.stride(0),
188
- n_cols, ctx.eps,
189
- GEMMA = ctx.GEMMA,
190
- BLOCK_SIZE = ctx.BLOCK_SIZE,
191
- num_warps = ctx.num_warps,
192
- )
193
- dX = dX.view(*shape)
194
- return dX, None, None, None
195
- pass
196
- pass
197
-
198
-
199
- # [TODO] Unsure why RMS Layernorm is not torch.compiling properly
200
- @torch.compiler.disable
201
- def fast_rms_layernorm(layernorm, X : torch.Tensor, gemma : bool = False):
202
- W : torch.Tensor = layernorm.weight
203
- eps : float = layernorm.variance_epsilon if \
204
- hasattr(layernorm, "variance_epsilon") \
205
- else layernorm.eps
206
- out = Fast_RMS_Layernorm.apply(X, W, eps, gemma)
207
- return out
208
- pass
209
-
210
-
211
- from transformers.models.llama.modeling_llama import LlamaRMSNorm
212
- class Unsloth_LlamaRMSNorm(LlamaRMSNorm):
213
- def forward(self, X):
214
- return fast_rms_layernorm(self, X, gemma = False)
215
- pass
216
- pass
217
-
218
- try:
219
- from transformers.models.mllama.modeling_mllama import MllamaTextRMSNorm
220
- class Unsloth_MllamaTextRMSNorm(MllamaTextRMSNorm):
221
- def forward(self, X):
222
- return fast_rms_layernorm(self, X, gemma = False)
223
- pass
224
- pass
225
- except:
226
- pass
227
- pass
228
-
229
- def patch_rms_layernorm():
230
- import transformers.models.llama.modeling_llama
231
- transformers.models.llama.modeling_llama.LlamaRMSNorm = Unsloth_LlamaRMSNorm
232
- try:
233
- import transformers.models.mllama.modeling_mllama
234
- transformers.models.mllama.modeling_mllama.MllamaTextRMSNorm = Unsloth_MllamaTextRMSNorm
235
- except:
236
- pass
237
- return
238
- pass
239
-
240
-
241
- def unpatch_rms_layernorm():
242
- import transformers.models.llama.modeling_llama
243
- transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
244
- try:
245
- import transformers.models.mllama.modeling_mllama
246
- transformers.models.mllama.modeling_mllama.MllamaTextRMSNorm = MllamaTextRMSNorm
247
- except:
248
- pass
249
- return
250
- return
251
- pass
252
-
253
-
254
- def test_rms_layernorm(
255
- dim = 1024, eps = 1e-5, dtype = torch.float16,
256
- bsz = 21, random_state = 3407, seqlen = 3341,
257
- ):
258
- from transformers.models.llama.modeling_llama import LlamaRMSNorm
259
- layernorm = LlamaRMSNorm((dim,), eps = eps).to("cuda")
260
- torch.cuda.manual_seed(random_state)
261
- torch.manual_seed(random_state)
262
- torch.nn.init.uniform_(layernorm.weight)
263
- X = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda")
264
- XX = X.clone()
265
- X .requires_grad_(True)
266
- XX.requires_grad_(True)
267
- Y = layernorm(X)
268
- YY = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda", requires_grad = True)
269
- Y.backward(YY)
270
- correct_grad = X.grad.clone()
271
- # from unsloth.kernels import fast_rms_layernorm
272
- Y = fast_rms_layernorm(layernorm, XX)
273
- Y.backward(YY)
274
- assert(torch.amax(correct_grad - XX.grad).item() <= 0.05)
275
- pass
276
-
277
-
278
- def testing_suite_layernorm():
279
- for dim in [512, 1024, 2048]:
280
- for dtype in [torch.float16, torch.bfloat16]:
281
- with torch.autocast(device_type = "cuda", dtype = dtype):
282
- for seqlen in [3341, 2048, 349]:
283
- for random_state in [3407, 42]:
284
- test_rms_layernorm(
285
- dim = dim,
286
- eps = 1e-5,
287
- dtype = dtype,
288
- bsz = 21,
289
- random_state = random_state,
290
- seqlen = seqlen,
291
- )
292
- pass
293
- pass
294
- pass
295
- pass
296
- pass
297
- pass
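
For reference, the forward math of both RMS variants above in plain PyTorch (a sketch; the `gemma` flag mirrors the `(W + 1.0)` scaling and the all-float32 path of `_gemma_rms_layernorm_forward`):

```python
import torch

def rms_layernorm_reference(X, W, eps, gemma=False):
    X32 = X.float()
    inv_var = torch.rsqrt(X32.pow(2).mean(dim=-1, keepdim=True) + eps)
    normed = X32 * inv_var
    if gemma:
        # Gemma: stay in float32 throughout and scale by (weight + 1).
        return (normed * (W.float() + 1.0)).to(X.dtype)
    # Llama-style: cast the normalised activations back before scaling by the weight.
    return normed.to(W.dtype) * W
```

Note that the backward kernel folds the `+ 1.0` into `dY_W` when `GEMMA` is set, and that path also writes into a separate `dX` buffer instead of overwriting `dY` in place.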
 
 
unsloth-main/unsloth/kernels/rope_embedding.py DELETED
@@ -1,196 +0,0 @@
1
- # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import triton
16
- import triton.language as tl
17
- import torch
18
- from .utils import calculate_settings
19
- ROPE_GROUP_SIZE : int = 4
20
-
21
- @triton.heuristics({"BACKWARD_PASS": lambda args: bool(args["BACKWARD_PASS"]),})
22
- @triton.jit
23
- def _rope_embedding(
24
- Q, Q_row_stride,
25
- cos, cos_row_stride,
26
- sin, sin_row_stride,
27
- seqlen,
28
- head_dim : tl.constexpr,
29
- n_heads : tl.constexpr,
30
- BACKWARD_PASS : tl.constexpr,
31
- BLOCK_SIZE : tl.constexpr,
32
- ):
33
- """
34
- Calculates the RoPE Embedding quickly
35
- RoPE is Q * cos + rotate_half(Q) * sin
36
- See our blog post for more info
37
- """
38
- ROPE_GROUP_SIZE = 4
39
- row_position = tl.program_id(0)
40
- group_head_position = tl.program_id(1)
41
- col_offsets = tl.arange(0, BLOCK_SIZE)
42
- half_head_dim = head_dim // 2
43
- mask = col_offsets < half_head_dim
44
-
45
- sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \
46
- half_head_dim*0 + col_offsets, mask = mask, other = 0)
47
- cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \
48
- half_head_dim*0 + col_offsets, mask = mask, other = 0)
49
-
50
- if BACKWARD_PASS:
51
- # See our blog post for more info.
52
- sin1 = -sin1
53
- pass
54
-
55
- # [TODO] Autotune ROPE_GROUP_SIZE to be 1, 2, 4, 8
56
- head_start = group_head_position * ROPE_GROUP_SIZE
57
- head_end = min((head_start + ROPE_GROUP_SIZE), n_heads)
58
-
59
- # 10% Faster kernel from [HuyNguyen-hust](https://github.com/unslothai/unsloth/pull/238)
60
- for k in range(head_start, head_end):
61
- offs_q1 = row_position * Q_row_stride + k * head_dim + col_offsets
62
- offs_q2 = row_position * Q_row_stride + k * head_dim + col_offsets + half_head_dim
63
-
64
- # For Gemma - sometimes RoPE must be done in float32 and not bfloat16
65
- Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)
66
- Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)
67
-
68
- tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask)
69
- tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)
70
- pass
71
- pass
72
-
73
-
74
- class Fast_RoPE_Embedding(torch.autograd.Function):
75
- @staticmethod
76
- def forward(ctx, Q, cos, sin):
77
- cos, sin = cos.squeeze(), sin.squeeze()
78
- batch : int
79
- seq_len : int
80
- n_heads : int
81
- head_dim : int
82
- batch, seq_len, n_heads, head_dim = Q.shape
83
- Q = Q.view(batch*seq_len, n_heads*head_dim)
84
- n_rows : int
85
- n_cols : int
86
- n_rows, n_cols = Q.shape
87
- assert(seq_len <= cos.shape[0])
88
-
89
- # [TODO] Changing blocksize to head_dim//2 seems to have
90
- # some concurrency / un-deterministic issues.
91
- BLOCK_SIZE, num_warps = calculate_settings(head_dim//2) # (head_dim//2)
92
-
93
- # group_size = 4 # 4 or 8, too large group_size can hurt performance.
94
- div : int
95
- mod : int
96
- div, mod = divmod(n_heads, ROPE_GROUP_SIZE)
97
- n_groups : int = div + (mod != 0)
98
-
99
- _rope_embedding[(n_rows, n_groups, )](
100
- Q, Q.stride(0),
101
- cos, cos.stride(0),
102
- sin, sin.stride(0),
103
- seq_len,
104
- head_dim, n_heads,
105
- BACKWARD_PASS = False,
106
- BLOCK_SIZE = BLOCK_SIZE,
107
- num_warps = num_warps,
108
- )
109
- ctx.BLOCK_SIZE = BLOCK_SIZE
110
- ctx.num_warps = num_warps
111
- ctx.n_groups = n_groups
112
- ctx.cos = cos
113
- ctx.sin = sin
114
- return Q.view(batch, seq_len, n_heads, head_dim)
115
- pass
116
-
117
- @staticmethod
118
- def backward(ctx, dY):
119
- batch : int
120
- seq_len : int
121
- n_heads : int
122
- head_dim : int
123
- batch, seq_len, n_heads, head_dim = dY.shape
124
- dY = dY.reshape(batch*seq_len, n_heads*head_dim)
125
- # Must be reshape not view
126
- n_rows : int
127
- n_cols : int
128
- n_rows, n_cols = dY.shape
129
-
130
- cos = ctx.cos
131
- sin = ctx.sin
132
-
133
- _rope_embedding[(n_rows, ctx.n_groups, )](
134
- dY, dY .stride(0),
135
- cos, cos.stride(0),
136
- sin, sin.stride(0),
137
- seq_len, head_dim, n_heads,
138
- BACKWARD_PASS = True,
139
- BLOCK_SIZE = ctx.BLOCK_SIZE,
140
- num_warps = ctx.num_warps,
141
- )
142
- dY = dY.view(batch, seq_len, n_heads, head_dim)
143
- return dY, None, None,
144
- pass
145
- pass
146
-
147
- # [TODO] Unsure why RoPE Embedding is not torch.compiling properly
148
- @torch.compiler.disable
149
- def fast_rope_embedding(Q, K, cos, sin):
150
- Q = Fast_RoPE_Embedding.apply(Q.transpose(1, 2), cos, sin).transpose(1, 2)
151
- K = Fast_RoPE_Embedding.apply(K.transpose(1, 2), cos, sin).transpose(1, 2)
152
- return Q, K
153
- pass
154
-
155
-
156
- class Slow_RoPE_Embedding(torch.autograd.Function):
157
- @staticmethod
158
- def forward(ctx, Q, cos, sin, position_ids):
159
- if position_ids is not None:
160
- # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
161
- cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
162
- sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
163
- cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
164
- sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
165
-
166
- # Q * cos + rotate_half(Q) * sin
167
- half = Q.shape[-1]//2
168
- RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)
169
- Q *= cos
170
- Q.addcmul_(RH_Q, sin)
171
- # RH_Q *= sin
172
- # Q += RH_Q
173
- ctx.save_for_backward(cos, sin)
174
- return Q
175
- pass
176
-
177
- @staticmethod
178
- def backward(ctx, dY):
179
- cos, sin = ctx.saved_tensors
180
- # Q * cos + rotate_half.T(Q) * sin
181
- half = dY.shape[-1]//2
182
- RH_dY = torch.cat((dY[..., half:], -dY[..., :half]), dim = -1)
183
- dY *= cos
184
- dY.addcmul_(RH_dY, sin)
185
- # RH_dY *= sin
186
- # dY += RH_dY
187
- return dY, None, None, None
188
- pass
189
- pass
190
-
191
-
192
- def inplace_rope_embedding(Q, K, cos, sin, position_ids):
193
- Q = Slow_RoPE_Embedding.apply(Q, cos, sin, position_ids)
194
- K = Slow_RoPE_Embedding.apply(K, cos, sin, position_ids)
195
- return Q, K
196
- pass
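
The kernel above rotates the two halves of each head in place; the equivalent `rotate_half` formulation in plain PyTorch is a handy reference (a sketch, assuming `cos` and `sin` have already been indexed so they broadcast against `Q` and `K`):

```python
import torch

def rotate_half(x):
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

def rope_reference(Q, K, cos, sin):
    # Q, K: (bsz, n_heads, seq_len, head_dim); cos/sin broadcast over batch and heads.
    return Q * cos + rotate_half(Q) * sin, K * cos + rotate_half(K) * sin
```

The backward pass reuses the same kernel with `sin` negated (`BACKWARD_PASS = True`), since rotating by the opposite angle undoes the rotation applied in the forward pass.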
 
 
unsloth-main/unsloth/kernels/swiglu.py DELETED
@@ -1,99 +0,0 @@
1
- # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import triton
16
- import triton.language as tl
17
- import torch
18
- from .utils import calculate_settings
19
-
20
-
21
- @triton.jit
22
- def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
23
- block_idx = tl.program_id(0)
24
- offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
25
- mask = offsets < n_elements
26
-
27
- e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
28
- g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
29
-
30
- # f = e * sigmoid(e)
31
- f_row = e_row * tl.sigmoid(e_row) # e_row / (1 + tl.exp(-e_row))
32
- f_row = f_row.to(g_row.dtype) # Exact copy from HF
33
- # h = f * g
34
- h_row = f_row * g_row
35
-
36
- # Store h
37
- tl.store(h + offsets, h_row, mask = mask)
38
- pass
39
-
40
-
41
- def swiglu_fg_kernel(e, g):
42
- batch, seq_len, hd = e.shape
43
- n_elements = e.numel()
44
- h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = "cuda:0")
45
- grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
46
- _fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,)
47
- return h
48
- pass
49
-
50
-
51
- @triton.jit
52
- def _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
53
- """
54
- e = e.float()
55
- se = 1.0 / (1.0 + torch.exp(-e))
56
- f = (se * e).to(dtype)
57
- h = f * g
58
- df = DW * f
59
- dg = DW * g
60
- de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
61
- """
62
- block_idx = tl.program_id(0)
63
- offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
64
- mask = offsets < n_elements
65
-
66
- DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
67
- e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
68
- g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
69
-
70
- # e = e.float()
71
- # se = 1.0 / (1.0 + torch.exp(-e))
72
- se_row = tl.sigmoid(e_row) # 1.0 / (1.0 + tl.exp(-e_row))
73
- # f = (se * e).to(dtype)
74
- f_row = se_row * e_row
75
- f_row = f_row.to(DW_row.dtype)
76
- # h = f * g
77
- h_row = f_row * g_row
78
- # df = DW * f
79
- df_row = DW_row * f_row
80
- # dg = DW * g
81
- dg_row = DW_row * g_row
82
- # de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
83
- de_row = dg_row.to(tl.float32) * se_row * (1.0 + e_row * (1.0 - se_row))
84
- de_row = de_row.to(DW_row.dtype)
85
-
86
- # Store derivatives in buffers
87
- tl.store(DW + offsets, h_row, mask = mask) # h = f * g
88
- tl.store(e + offsets, df_row, mask = mask) # df = DW * f
89
- tl.store(g + offsets, de_row, mask = mask) # de
90
- pass
91
-
92
-
93
- def swiglu_DWf_DW_dfg_kernel(DW, e, g):
94
- batch_seq_len, hd = e.shape
95
- n_elements = e.numel()
96
- grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
97
- _DWf_DW_dfg_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
98
- return DW, e, g
99
- pass
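
A matching PyTorch sketch of the SwiGLU forward and gradient math in the kernels above (reference only; `e` is the gate projection, `g` the up projection, `dh` the upstream gradient, exactly as in the kernel docstring):

```python
import torch

def swiglu_forward_reference(e, g):
    # h = (e * sigmoid(e)) * g, with the SiLU part computed in float32.
    e32 = e.float()
    f = (e32 * torch.sigmoid(e32)).to(g.dtype)
    return f * g

def swiglu_backward_reference(dh, e, g):
    e32 = e.float()
    se = torch.sigmoid(e32)
    f = (se * e32).to(dh.dtype)
    dg = dh * f  # gradient w.r.t. the up projection
    de = (dh.float() * g.float() * se * (1.0 + e32 * (1.0 - se))).to(dh.dtype)  # gradient w.r.t. the gate
    return de, dg
```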
 
 
unsloth-main/unsloth/kernels/utils.py DELETED
@@ -1,422 +0,0 @@
1
- # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import triton
16
- MAX_FUSED_SIZE : int = 65536
17
- next_power_of_2 = triton.next_power_of_2
18
-
19
- # torch.cuda.amp.custom_fwd is deprecated >= 2.4
20
- import torch
21
- from packaging.version import Version
22
- if Version(torch.__version__) < Version("2.4.0"):
23
- torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
24
- torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
25
- else:
26
- torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
27
- torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
28
- pass
29
-
30
-
31
- # tl.math.tanh now is libdevice.tanh
32
- from packaging.version import Version
33
- import triton
34
- import triton.language as tl
35
- if Version(triton.__version__) >= Version("3.0.0"):
36
- from triton.language.extra import libdevice
37
- triton_tanh = libdevice.tanh
38
- triton_cast = tl.cast
39
- else:
40
- triton_tanh = tl.math.tanh
41
- # No casting in old Triton versions
42
- @triton.jit
43
- def triton_cast(x, dtype):
44
- return x.to(dtype)
45
- pass
46
- pass
47
-
48
-
49
- def calculate_settings(n : int) -> (int, int,):
50
- BLOCK_SIZE : int = next_power_of_2(n)
51
- if BLOCK_SIZE > MAX_FUSED_SIZE:
52
- raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
53
- f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
54
- num_warps : int = 4
55
- if BLOCK_SIZE >= 32768: num_warps = 32
56
- elif BLOCK_SIZE >= 8192: num_warps = 16
57
- elif BLOCK_SIZE >= 2048: num_warps = 8
58
- return BLOCK_SIZE, num_warps
59
- pass
60
-
61
-
62
- import bitsandbytes as bnb
63
- # https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1330/files
64
- HAS_CUDA_STREAM = Version(bnb.__version__) > Version("0.43.3")
65
- global CUDA_STREAM
66
- CUDA_STREAM = None
67
- get_ptr = bnb.functional.get_ptr
68
- import ctypes
69
- cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
70
- cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
71
- cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
72
- cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16
73
- cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16
74
-
75
-
76
- def QUANT_STATE(W):
77
- return getattr(W, "quant_state", None)
78
- pass
79
-
80
-
81
- def get_lora_parameters(proj):
82
- # For DPO or disabled adapters
83
- base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj)
84
- W = base_layer.weight
85
-
86
- if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
87
- return W, QUANT_STATE(W), None, None, None
88
- pass
89
-
90
- active_adapter = proj.active_adapters[0] if \
91
- hasattr(proj, "active_adapters") else proj.active_adapter
92
- A = proj.lora_A [active_adapter].weight
93
- B = proj.lora_B [active_adapter].weight
94
- s = proj.scaling[active_adapter]
95
- return W, QUANT_STATE(W), A, B, s
96
- pass
97
-
98
-
99
- def get_lora_parameters_bias(proj):
100
- # For DPO or disabled adapters
101
- base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj)
102
- W = base_layer.weight
103
- bias = base_layer.bias
104
-
105
- if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
106
- return W, QUANT_STATE(W), None, None, None, bias
107
- pass
108
-
109
- active_adapter = proj.active_adapters[0] if \
110
- hasattr(proj, "active_adapters") else proj.active_adapter
111
- A = proj.lora_A [active_adapter].weight
112
- B = proj.lora_B [active_adapter].weight
113
- s = proj.scaling[active_adapter]
114
- return W, QUANT_STATE(W), A, B, s, bias
115
- pass
116
-
117
-
118
- if HAS_CUDA_STREAM:
119
- def fast_dequantize(W, quant_state = None, out = None):
120
- if quant_state is None: return W
121
- if type(quant_state) is not list:
122
- # New quant_state as a class
123
- # https://github.com/TimDettmers/bitsandbytes/pull/763/files
124
- absmax = quant_state.absmax
125
- shape = quant_state.shape
126
- dtype = quant_state.dtype
127
- blocksize = quant_state.blocksize
128
- offset = quant_state.offset
129
- state2 = quant_state.state2
130
- absmax2 = state2.absmax
131
- code2 = state2.code
132
- blocksize2 = state2.blocksize
133
- else:
134
- # Old quant_state as a list of lists
135
- absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
136
- offset, state2 = compressed_stats
137
- absmax2, code2, blocksize2, _, _, _, _ = state2
138
- pass
139
- global CUDA_STREAM
140
- if CUDA_STREAM is None: CUDA_STREAM = torch.cuda.current_stream("cuda:0")
141
-
142
- # Create weight matrix
143
- if out is None:
144
- out = torch.empty(shape, dtype = dtype, device = "cuda:0")
145
- else:
146
- assert(out.shape == shape)
147
- assert(out.dtype == dtype)
148
-
149
- # NF4 dequantization of statistics
150
- n_elements_absmax = absmax.numel()
151
- out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0")
152
-
153
- # Do dequantization
154
- ptr_out_absmax = get_ptr(out_absmax)
155
- cdequantize_blockwise_fp32(
156
- get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
157
- ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax), CUDA_STREAM,
158
- )
159
- out_absmax += offset
160
-
161
- fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \
162
- cdequantize_blockwise_bf16_nf4
163
- fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
164
- ctypes.c_int(blocksize), ctypes.c_int(out.numel()), CUDA_STREAM,)
165
-
166
- # Careful returning transposed data
167
- is_transposed = (True if W.shape[0] == 1 else False)
168
- return out.t() if is_transposed else out
169
- pass
170
- else:
171
- def fast_dequantize(W, quant_state = None, out = None):
172
- if quant_state is None: return W
173
- if type(quant_state) is not list:
174
- # New quant_state as a class
175
- # https://github.com/TimDettmers/bitsandbytes/pull/763/files
176
- absmax = quant_state.absmax
177
- shape = quant_state.shape
178
- dtype = quant_state.dtype
179
- blocksize = quant_state.blocksize
180
- offset = quant_state.offset
181
- state2 = quant_state.state2
182
- absmax2 = state2.absmax
183
- code2 = state2.code
184
- blocksize2 = state2.blocksize
185
- else:
186
- # Old quant_state as a list of lists
187
- absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
188
- offset, state2 = compressed_stats
189
- absmax2, code2, blocksize2, _, _, _, _ = state2
190
- pass
191
-
192
- # Create weight matrix
193
- if out is None:
194
- out = torch.empty(shape, dtype = dtype, device = "cuda:0")
195
- else:
196
- assert(out.shape == shape)
197
- assert(out.dtype == dtype)
198
-
199
- # NF4 dequantization of statistics
200
- n_elements_absmax = absmax.numel()
201
- out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0")
202
-
203
- # Do dequantization
204
- ptr_out_absmax = get_ptr(out_absmax)
205
- cdequantize_blockwise_fp32(
206
- get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
207
- ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax),
208
- )
209
- out_absmax += offset
210
-
211
- fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \
212
- cdequantize_blockwise_bf16_nf4
213
- fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
214
- ctypes.c_int(blocksize), ctypes.c_int(out.numel()),)
215
-
216
- # Careful returning transposed data
217
- is_transposed = (True if W.shape[0] == 1 else False)
218
- return out.t() if is_transposed else out
219
- pass
220
- pass
221
-
222
-
223
- if HAS_CUDA_STREAM:
224
- def fast_gemv(X, W, quant_state, out = None):
225
- if quant_state is None: return torch.matmul(X, W, out = out)
226
- # For fast X @ W where seq_len == 1
227
- # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
228
- _, q_len, hd = X.shape
229
- # assert(q_len == 1)
230
-
231
- if type(quant_state) is not list:
232
- # https://github.com/TimDettmers/bitsandbytes/pull/763/files
233
- absmax = quant_state.absmax
234
- shape = quant_state.shape
235
- dtype = quant_state.dtype
236
- blocksize = quant_state.blocksize
237
- stats = quant_state.code
238
- offset = quant_state.offset
239
- state2 = quant_state.state2
240
- absmax2 = state2.absmax
241
- code2 = state2.code
242
- blocksize2 = state2.blocksize
243
- else:
244
- absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state
245
- offset, state2 = compressed_stats
246
- absmax2, code2, blocksize2, _, _, _, _ = state2
247
- pass
248
- global CUDA_STREAM
249
- if CUDA_STREAM is None: CUDA_STREAM = torch.cuda.current_stream("cuda:0")
250
-
251
- # assert(dtype == X.dtype)
252
- bout = shape[0]
253
-
254
- if out is None:
255
- out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda:0")
256
- # else:
257
- # assert(out.shape == (1, 1, bout,))
258
- # pass
259
-
260
- n = 1
261
- m = shape[0]
262
- k = shape[1]
263
- lda = shape[0]
264
- ldc = shape[0]
265
- ldb = (hd+1)//2
266
- m = ctypes.c_int32(m)
267
- n = ctypes.c_int32(n)
268
- k = ctypes.c_int32(k)
269
- lda = ctypes.c_int32(lda)
270
- ldb = ctypes.c_int32(ldb)
271
- ldc = ctypes.c_int32(ldc)
272
-
273
- df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda:0")
274
- cdequantize_blockwise_fp32(
275
- get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
276
- ctypes.c_int(blocksize2), ctypes.c_int(df.numel()), CUDA_STREAM,
277
- )
278
- df += offset
279
- absmax = df
280
-
281
- fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \
282
- cgemm_4bit_inference_naive_bf16
283
-
284
- blocksize = ctypes.c_int32(blocksize)
285
- fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
286
- lda, ldb, ldc, blocksize, CUDA_STREAM,)
287
-
288
- return out
289
- pass
290
- else:
291
- def fast_gemv(X, W, quant_state, out = None):
292
- if quant_state is None: return torch.matmul(X, W, out = out)
293
- # For fast X @ W where seq_len == 1
294
- # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
295
- _, q_len, hd = X.shape
296
- # assert(q_len == 1)
297
-
298
- if type(quant_state) is not list:
299
- # https://github.com/TimDettmers/bitsandbytes/pull/763/files
300
- absmax = quant_state.absmax
301
- shape = quant_state.shape
302
- dtype = quant_state.dtype
303
- blocksize = quant_state.blocksize
304
- stats = quant_state.code
305
- offset = quant_state.offset
306
- state2 = quant_state.state2
307
- absmax2 = state2.absmax
308
- code2 = state2.code
309
- blocksize2 = state2.blocksize
310
- else:
311
- absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state
312
- offset, state2 = compressed_stats
313
- absmax2, code2, blocksize2, _, _, _, _ = state2
314
- pass
315
- # assert(dtype == X.dtype)
316
- bout = shape[0]
317
-
318
- if out is None:
319
- out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda:0")
320
- # else:
321
- # assert(out.shape == (1, 1, bout,))
322
- # pass
323
-
324
- n = 1
325
- m = shape[0]
326
- k = shape[1]
327
- lda = shape[0]
328
- ldc = shape[0]
329
- ldb = (hd+1)//2
330
- m = ctypes.c_int32(m)
331
- n = ctypes.c_int32(n)
332
- k = ctypes.c_int32(k)
333
- lda = ctypes.c_int32(lda)
334
- ldb = ctypes.c_int32(ldb)
335
- ldc = ctypes.c_int32(ldc)
336
-
337
- df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda:0")
338
- cdequantize_blockwise_fp32(
339
- get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
340
- ctypes.c_int(blocksize2), ctypes.c_int(df.numel()),
341
- )
342
- df += offset
343
- absmax = df
344
-
345
- fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \
346
- cgemm_4bit_inference_naive_bf16
347
-
348
- blocksize = ctypes.c_int32(blocksize)
349
- fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
350
- lda, ldb, ldc, blocksize,)
351
-
352
- return out
353
- pass
354
- pass
355
-
356
-
357
- def fast_linear_forward(proj, X, temp_lora = None, out = None):
358
-
359
- W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj)
360
- bsz, q_len, in_dim = X.shape
361
- if q_len != 1: return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)
362
-
363
- if W_quant is None:
364
- out = torch.matmul(X, W.t(), out = out)
365
- elif bsz == 1 and q_len == 1:
366
- out = fast_gemv(X, W, W_quant, out = out)
367
- else:
368
- W = fast_dequantize(W.t(), W_quant)
369
- out = torch.matmul(X, W, out = out)
370
- pass
371
-
372
- # Add in LoRA weights
373
- if lora_A is not None:
374
- out_dim = out.shape[2]
375
- dtype = X.dtype
376
-
377
- if not hasattr(lora_A, "_fast_lora"):
378
- lora_A._fast_lora = lora_A.to(dtype)
379
- lora_B._fast_lora = lora_B.to(dtype)
380
- pass
381
-
382
- if bsz == 1:
383
- out = out.view(out_dim)
384
- temp_lora = torch.mv(lora_A._fast_lora, X.ravel(), out = temp_lora)
385
- out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S)
386
- else:
387
- out = out.view(bsz, out_dim)
388
- temp_lora = torch.mm(X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora)
389
- out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha = lora_S)
390
- pass
391
- out = out.view(bsz, 1, out_dim)
392
- pass
393
-
394
- if bias is not None: out += bias
395
-
396
- return out
397
- pass
398
-
399
-
400
- def matmul_lora(X, W, W_quant, A, B, s, out = None):
401
- dtype = X.dtype
402
- W = fast_dequantize(W.t(), W_quant)
403
-
404
- if X.dim() == 3:
405
- batch, seq_len, d = X.shape
406
- X = X.view(-1, X.shape[-1])
407
- reshape = True
408
- else:
409
- reshape = False
410
- pass
411
-
412
- out = torch.matmul(X, W, out = out)
413
- if W_quant is not None: del W
414
-
415
- if A is not None:
416
- # LoRA is enabled
417
- A, B = A.t(), B.t()
418
- out += (X @ A.to(dtype)) @ (s * B.to(dtype))
419
- pass
420
-
421
- return out.view(batch, seq_len, -1) if reshape else out
422
- pass
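
The LoRA path in `matmul_lora` and `fast_linear_forward` boils down to a (dequantised) base matmul plus a scaled low-rank correction. A dense PyTorch sketch of that math (a reference; it ignores the NF4 dequantisation, CUDA streams and the single-token gemv fast path the real code dispatches on):

```python
import torch

def matmul_lora_reference(X, W, A, B, s):
    # X: (..., in_dim); W: (out_dim, in_dim); A: (r, in_dim); B: (out_dim, r); s = lora_alpha / r
    out = X @ W.t()
    if A is not None:
        out = out + s * ((X @ A.t()) @ B.t())  # low-rank LoRA correction
    return out
```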
 
 
unsloth-main/unsloth/models/__init__.py DELETED
@@ -1,22 +0,0 @@
1
- # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- from .granite import FastGraniteModel
17
- from .loader import FastLanguageModel, FastVisionModel
18
- from .llama import FastLlamaModel
19
- from .mistral import FastMistralModel
20
- from .qwen2 import FastQwen2Model
21
- from .dpo import PatchDPOTrainer, PatchKTOTrainer
22
- from ._utils import is_bfloat16_supported
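
These re-exports are the package's public entry points. A typical usage sketch follows; the model name and hyper-parameters are illustrative placeholders taken from common examples, not values prescribed by this file.

```python
from unsloth import FastLanguageModel

# Illustrative values only.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/llama-3-8b-bnb-4bit",
    max_seq_length = 2048,
    load_in_4bit = True,
)
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 16,
)
```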