Delete unsloth-main
This view is limited to 50 files because it contains too many changes.
- unsloth-main/.github/FUNDING.yml +0 -13
- unsloth-main/CONTRIBUTING.md +0 -29
- unsloth-main/LICENSE +0 -201
- unsloth-main/README.md +0 -492
- unsloth-main/images/Assistant.png +0 -0
- unsloth-main/images/Colab.png +0 -0
- unsloth-main/images/Discord button.png +0 -0
- unsloth-main/images/Discord.png +0 -0
- unsloth-main/images/Documentation Button.png +0 -0
- unsloth-main/images/Free version button.png +0 -0
- unsloth-main/images/Kaggle.png +0 -0
- unsloth-main/images/Kofi button.png +0 -0
- unsloth-main/images/LAION 2GPU.png +0 -0
- unsloth-main/images/Merge.png +0 -0
- unsloth-main/images/Run.png +0 -0
- unsloth-main/images/Slim Orca 2GPUs.png +0 -0
- unsloth-main/images/Terminal_Type.png +0 -0
- unsloth-main/images/Where_Terminal.png +0 -0
- unsloth-main/images/buy me a coffee button.png +0 -0
- unsloth-main/images/documentation github button.png +0 -0
- unsloth-main/images/documentation green button.png +0 -0
- unsloth-main/images/documentation lighter.png +0 -0
- unsloth-main/images/documentation white button.png +0 -0
- unsloth-main/images/made with unsloth.png +0 -0
- unsloth-main/images/ollama.png +0 -0
- unsloth-main/images/peft x trl button.png +0 -0
- unsloth-main/images/start free finetune button.png +0 -0
- unsloth-main/images/unsloth end.png +0 -0
- unsloth-main/images/unsloth loading page render.png +0 -0
- unsloth-main/images/unsloth logo black text.png +0 -0
- unsloth-main/images/unsloth logo only.png +0 -0
- unsloth-main/images/unsloth logo white text.png +0 -0
- unsloth-main/images/unsloth made with love.png +0 -0
- unsloth-main/images/unsloth new logo.png +0 -0
- unsloth-main/pyproject.toml +0 -418
- unsloth-main/unsloth-cli.py +0 -221
- unsloth-main/unsloth/__init__.py +0 -181
- unsloth-main/unsloth/_auto_install.py +0 -31
- unsloth-main/unsloth/chat_templates.py +0 -2105
- unsloth-main/unsloth/kernels/__init__.py +0 -65
- unsloth-main/unsloth/kernels/cross_entropy_loss.py +0 -405
- unsloth-main/unsloth/kernels/fast_lora.py +0 -490
- unsloth-main/unsloth/kernels/flex_attention.py +0 -181
- unsloth-main/unsloth/kernels/geglu.py +0 -203
- unsloth-main/unsloth/kernels/layernorm.py +0 -213
- unsloth-main/unsloth/kernels/rms_layernorm.py +0 -297
- unsloth-main/unsloth/kernels/rope_embedding.py +0 -196
- unsloth-main/unsloth/kernels/swiglu.py +0 -99
- unsloth-main/unsloth/kernels/utils.py +0 -422
- unsloth-main/unsloth/models/__init__.py +0 -22
unsloth-main/.github/FUNDING.yml
DELETED
@@ -1,13 +0,0 @@
# These are supported funding model platforms

github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: unsloth
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
unsloth-main/CONTRIBUTING.md
DELETED
@@ -1,29 +0,0 @@
# 🦥 Contributing to Unsloth

Thank you for not only using Unsloth but also for being interested in helping out! We value all contributions, whether they come in the form of code, ideas, support for others or just by simply spreading the word of Unsloth! 💕

- **[Support the Community](https://github.com/unslothai/unsloth/issues)**: Answer questions, review pull requests, or assist others in discussions.
- **Fix Bugs**: Identify and resolve issues with the existing codebase.
- **Submit Ideas**: Request new features or share enhancements you'd like to see.
- **Develop Features**: Implement new functionality or improve existing tools which can be done via PRs.
- **[Improve Documentation](https://docs.unsloth.ai/)**: Help by creating guides, FAQs, or enhancing clarity.

One of the best ways to support us is by spreading the word about Unsloth! Share how it’s powering your amazing projects in blog posts or social media, and inspire others to explore its potential. Even a simple star on our repo goes a long way in showing your support and helping the community grow. 🌟

## Submitting Issues
If you find a bug or have a feature idea, we’d love to hear from you! Here’s how to make your submission stand out:

### Reporting Bugs
1. **Search First**: Check if the issue has already been reported using GitHub’s search bar under Issues.
2. **Details Matter**: Is this on Google Colab, Kaggle, or on another platform service? Are you using Unsloth's official notebook? Include your OS, Python version, and other relevant details. For bugs, a concise code snippet that reproduces the issue is incredibly helpful.
3. **Be Thorough**: Attach screenshots, traceback logs, or any additional information that might speed up resolution.
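
A minimal, self-contained reproduction sketch along the lines below usually speeds up triage; the model name and settings here are only illustrative assumptions, not a required template:

```python
# Hypothetical minimal reproduction -- swap in whatever model / settings actually trigger your issue.
import torch, transformers
from unsloth import FastLanguageModel

# Report these versions alongside the traceback.
print("torch:", torch.__version__, "| transformers:", transformers.__version__)

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/llama-3-8b-bnb-4bit",  # illustrative model choice
    max_seq_length = 2048,
    load_in_4bit = True,
)
# ... the smallest possible code that reproduces the error goes here ...
```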

## Spread the Word
Your support extends beyond code:
- Spread the word by writing about Unsloth in blogs or social media.
- Share how Unsloth powers your projects.
- Star our repository to show your appreciation.

Finally, please be mindful of our [Code of Conduct](https://github.com/unslothai/unsloth/tree/main/unsloth/CODE_OF_CONDUCT.md) to ensure a welcoming and inclusive environment for everyone.

Thank you so much for reading and we hope you have lots of fun using Unsloth! 🦥
unsloth-main/LICENSE
DELETED
@@ -1,201 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [2024-] [Unsloth AI, Daniel Han-Chen & Michael Han-Chen]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
unsloth-main/README.md
DELETED
@@ -1,492 +0,0 @@
<div align="center">

<a href="https://unsloth.ai"><picture>
  <source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20white%20text.png">
  <source media="(prefers-color-scheme: light)" srcset="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20black%20text.png">
  <img alt="unsloth logo" src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20black%20text.png" height="110" style="max-width: 100%;">
</picture></a>

<a href="https://colab.research.google.com/drive/1Ys44kVvmeZtnICzWz0xgpRnrIOjZAuxp?usp=sharing"><img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/start free finetune button.png" height="48"></a>
<a href="https://discord.gg/unsloth"><img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/Discord button.png" height="48"></a>
<a href="https://docs.unsloth.ai"><img src="https://raw.githubusercontent.com/unslothai/unsloth/refs/heads/main/images/Documentation%20Button.png" height="48"></a>

### Finetune Llama 3.2, Mistral, Phi-3.5, Qwen 2.5 & Gemma 2-5x faster with 80% less memory!



</div>

## ✨ Finetune for Free

All notebooks are **beginner friendly**! Add your dataset, click "Run All", and you'll get a 2x faster finetuned model which can be exported to GGUF, Ollama, vLLM or uploaded to Hugging Face.

| Unsloth supports | Free Notebooks | Performance | Memory use |
|-----------|---------|--------|----------|
| **Llama 3.2 (3B)** | [▶️ Start for free](https://colab.research.google.com/drive/1T5-zKWM_5OD21QHwXHiV9ixTRR7k3iB9?usp=sharing) | 2x faster | 60% less |
| **Llama 3.2 Vision (11B)** | [▶️ Start for free](https://colab.research.google.com/drive/1j0N4XTY1zXXy7mPAhOC1_gMYZ2F2EBlk?usp=sharing) | 2x faster | 40% less |
| **Llama 3.1 (8B)** | [▶️ Start for free](https://colab.research.google.com/drive/1Ys44kVvmeZtnICzWz0xgpRnrIOjZAuxp?usp=sharing) | 2x faster | 60% less |
| **Phi-3.5 (mini)** | [▶️ Start for free](https://colab.research.google.com/drive/1lN6hPQveB_mHSnTOYifygFcrO8C1bxq4?usp=sharing) | 2x faster | 50% less |
| **Gemma 2 (9B)** | [▶️ Start for free](https://colab.research.google.com/drive/1vIrqH5uYDQwsJ4-OO3DErvuv4pBgVwk4?usp=sharing) | 2x faster | 63% less |
| **Qwen 2.5 (7B)** | [▶️ Start for free](https://colab.research.google.com/drive/1Kose-ucXO1IBaZq5BvbwWieuubP7hxvQ?usp=sharing) | 2x faster | 63% less |
| **Mistral v0.3 (7B)** | [▶️ Start for free](https://colab.research.google.com/drive/1_yNCks4BTD5zOnjozppphh5GzMFaMKq_?usp=sharing) | 2.2x faster | 73% less |
| **Ollama** | [▶️ Start for free](https://colab.research.google.com/drive/1WZDi7APtQ9VsvOrQSSC5DDtxq159j8iZ?usp=sharing) | 1.9x faster | 43% less |
| **ORPO** | [▶️ Start for free](https://colab.research.google.com/drive/11t4njE3c4Lxl-07OD8lJSMKkfyJml3Tn?usp=sharing) | 1.9x faster | 43% less |
| **DPO Zephyr** | [▶️ Start for free](https://colab.research.google.com/drive/15vttTpzzVXv_tJwEk-hIcQ0S9FcEWvwP?usp=sharing) | 1.9x faster | 43% less |

- See [all our notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks) and [all our models](https://docs.unsloth.ai/get-started/all-our-models)
- **Kaggle Notebooks** for [Llama 3.2 Kaggle notebook](https://www.kaggle.com/danielhanchen/kaggle-llama-3-2-1b-3b-unsloth-notebook), [Llama 3.1 (8B)](https://www.kaggle.com/danielhanchen/kaggle-llama-3-1-8b-unsloth-notebook), [Gemma 2 (9B)](https://www.kaggle.com/code/danielhanchen/kaggle-gemma-7b-unsloth-notebook/), [Mistral (7B)](https://www.kaggle.com/code/danielhanchen/kaggle-mistral-7b-unsloth-notebook)
- Run notebooks for [Llama 3.2 conversational](https://colab.research.google.com/drive/1T5-zKWM_5OD21QHwXHiV9ixTRR7k3iB9?usp=sharing), [Llama 3.1 conversational](https://colab.research.google.com/drive/15OyFkGoCImV9dSsewU1wa2JuKB4-mDE_?usp=sharing) and [Mistral v0.3 ChatML](https://colab.research.google.com/drive/15F1xyn8497_dUbxZP4zWmPZ3PJx1Oymv?usp=sharing)
- This [text completion notebook](https://colab.research.google.com/drive/1ef-tab5bhkvWmBOObepl1WgJvfvSzn5Q?usp=sharing) is for continued pretraining / raw text
- This [continued pretraining notebook](https://colab.research.google.com/drive/1tEd1FrOXWMnCU9UIvdYhs61tkxdMuKZu?usp=sharing) is for learning another language
- Click [here](https://docs.unsloth.ai/) for detailed documentation for Unsloth.

## 🦥 Unsloth.ai News
- 📣 NEW! [Llama 3.3 (70B)](https://huggingface.co/collections/unsloth/llama-33-all-versions-67535d7d994794b9d7cf5e9f), Meta's latest model is now supported.
- 📣 NEW! We worked with Apple to add [Cut Cross Entropy](https://arxiv.org/abs/2411.09009). Unsloth now supports 89K context for Meta's Llama 3.3 (70B) on a 80GB GPU - 13x longer than HF+FA2. For Llama 3.1 (8B), Unsloth enables 342K context, surpassing its native 128K support.
- 📣 NEW! Introducing Unsloth [Dynamic 4-bit Quantization](https://unsloth.ai/blog/dynamic-4bit)! We dynamically opt not to quantize certain parameters and this greatly increases accuracy while only using <10% more VRAM than BnB 4-bit. See our collection on [Hugging Face here.](https://huggingface.co/collections/unsloth/unsloth-4-bit-dynamic-quants-67503bb873f89e15276c44e7)
- 📣 NEW! [Vision models](https://unsloth.ai/blog/vision) now supported! [Llama 3.2 Vision (11B)](https://colab.research.google.com/drive/1j0N4XTY1zXXy7mPAhOC1_gMYZ2F2EBlk?usp=sharing), [Qwen 2.5 VL (7B)](https://colab.research.google.com/drive/1whHb54GNZMrNxIsi2wm2EY_-Pvo2QyKh?usp=sharing) and [Pixtral (12B) 2409](https://colab.research.google.com/drive/1K9ZrdwvZRE96qGkCq_e88FgV3MLnymQq?usp=sharing)
- 📣 NEW! Qwen-2.5 including [Coder](https://colab.research.google.com/drive/18sN803sU23XuJV9Q8On2xgqHSer6-UZF?usp=sharing) models are now supported with bugfixes. 14b fits in a Colab GPU! [Qwen 2.5 conversational notebook](https://colab.research.google.com/drive/1qN1CEalC70EO1wGKhNxs1go1W9So61R5?usp=sharing)
- 📣 NEW! We found and helped fix a [gradient accumulation bug](https://unsloth.ai/blog/gradient)! Please update Unsloth and transformers.
<details>
  <summary>Click for more news</summary>

- 📣 Try out [Chat interface](https://colab.research.google.com/drive/1i-8ESvtLRGNkkUQQr_-z_rcSAIo9c3lM?usp=sharing)!
- 📣 NEW! [Mistral Small 22b notebook](https://colab.research.google.com/drive/1oCEHcED15DzL8xXGU1VTx5ZfOJM8WY01?usp=sharing) finetuning fits in under 16GB of VRAM!
- 📣 NEW! [Llama 3.1 8b, 70b](https://colab.research.google.com/drive/1Ys44kVvmeZtnICzWz0xgpRnrIOjZAuxp?usp=sharing) & [Mistral Nemo-12b](https://colab.research.google.com/drive/17d3U-CAIwzmbDRqbZ9NnpHxCkmXB6LZ0?usp=sharing) both Base and Instruct are now supported
- 📣 NEW! `pip install unsloth` now works! Head over to [pypi](https://pypi.org/project/unsloth/) to check it out! This allows non git pull installs. Use `pip install unsloth[colab-new]` for non dependency installs.
- 📣 NEW! Continued Pretraining [notebook](https://colab.research.google.com/drive/1tEd1FrOXWMnCU9UIvdYhs61tkxdMuKZu?usp=sharing) for other languages like Korean!
- 📣 [2x faster inference](https://colab.research.google.com/drive/1aqlNQi7MMJbynFDyOQteD2t0yVfjb9Zh?usp=sharing) added for all our models
- 📣 We cut memory usage by a [further 30%](https://unsloth.ai/blog/long-context) and now support [4x longer context windows](https://unsloth.ai/blog/long-context)!
</details>

## 🔗 Links and Resources
| Type | Links |
| ------------------------------- | --------------------------------------- |
| 📚 **Documentation & Wiki** | [Read Our Docs](https://docs.unsloth.ai) |
| <img height="14" src="https://upload.wikimedia.org/wikipedia/commons/6/6f/Logo_of_Twitter.svg" /> **Twitter (aka X)** | [Follow us on X](https://twitter.com/unslothai)|
| 💾 **Installation** | [unsloth/README.md](https://github.com/unslothai/unsloth/tree/main#-installation-instructions)|
| 🥇 **Benchmarking** | [Performance Tables](https://github.com/unslothai/unsloth/tree/main#-performance-benchmarking)
| 🌐 **Released Models** | [Unsloth Releases](https://docs.unsloth.ai/get-started/all-our-models)|
| ✍️ **Blog** | [Read our Blogs](https://unsloth.ai/blog)|
| <img height="14" src="https://redditinc.com/hs-fs/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png" /> **Reddit** | [Join our Reddit page](https://reddit.com/r/unsloth)|

## ⭐ Key Features
- All kernels written in [OpenAI's Triton](https://openai.com/research/triton) language. **Manual backprop engine**.
- **0% loss in accuracy** - no approximation methods - all exact.
- No change of hardware. Supports NVIDIA GPUs since 2018+. Minimum CUDA Capability 7.0 (V100, T4, Titan V, RTX 20, 30, 40x, A100, H100, L40 etc) [Check your GPU!](https://developer.nvidia.com/cuda-gpus) GTX 1070, 1080 works, but is slow.
- Works on **Linux** and **Windows** via WSL.
- Supports 4bit and 16bit QLoRA / LoRA finetuning via [bitsandbytes](https://github.com/TimDettmers/bitsandbytes).
- Open source trains 5x faster - see [Unsloth Pro](https://unsloth.ai/) for up to **30x faster training**!
- If you trained a model with 🦥Unsloth, you can use this cool sticker! <img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/made with unsloth.png" height="50" align="center" />


## 🥇 Performance Benchmarking
- For the full list of **reproducible** benchmarking tables, [go to our website](https://unsloth.ai/blog/mistral-benchmark#Benchmark%20tables)

| 1 A100 40GB | 🤗Hugging Face | Flash Attention | 🦥Unsloth Open Source | 🦥[Unsloth Pro](https://unsloth.ai/pricing) |
|--------------|--------------|-----------------|---------------------|-----------------|
| Alpaca | 1x | 1.04x | 1.98x | **15.64x** |
| LAION Chip2 | 1x | 0.92x | 1.61x | **20.73x** |
| OASST | 1x | 1.19x | 2.17x | **14.83x** |
| Slim Orca | 1x | 1.18x | 2.22x | **14.82x** |

- Benchmarking table below was conducted by [🤗Hugging Face](https://huggingface.co/blog/unsloth-trl).

| Free Colab T4 | Dataset | 🤗Hugging Face | Pytorch 2.1.1 | 🦥Unsloth | 🦥 VRAM reduction |
| --- | --- | --- | --- | --- | --- |
| Llama-2 7b | OASST | 1x | 1.19x | 1.95x | -43.3% |
| Mistral 7b | Alpaca | 1x | 1.07x | 1.56x | -13.7% |
| Tiny Llama 1.1b | Alpaca | 1x | 2.06x | 3.87x | -73.8% |
| DPO with Zephyr | Ultra Chat | 1x | 1.09x | 1.55x | -18.6% |



## 💾 Installation Instructions

For stable releases, use `pip install unsloth`. We recommend `pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"` for most installations though.

### Conda Installation
`⚠️Only use Conda if you have it. If not, use Pip`. Select either `pytorch-cuda=11.8,12.1` for CUDA 11.8 or CUDA 12.1. We support `python=3.10,3.11,3.12`.
```bash
conda create --name unsloth_env \
    python=3.11 \
    pytorch-cuda=12.1 \
    pytorch cudatoolkit xformers -c pytorch -c nvidia -c xformers \
    -y
conda activate unsloth_env

pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
pip install --no-deps trl peft accelerate bitsandbytes
```

<details>
  <summary>If you're looking to install Conda in a Linux environment, <a href="https://docs.anaconda.com/miniconda/">read here</a>, or run the below 🔽</summary>

```bash
mkdir -p ~/miniconda3
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm -rf ~/miniconda3/miniconda.sh
~/miniconda3/bin/conda init bash
~/miniconda3/bin/conda init zsh
```
</details>

### Pip Installation
`⚠️Do **NOT** use this if you have Conda.` Pip is a bit more complex since there are dependency issues. The pip command is different for `torch 2.2,2.3,2.4,2.5` and CUDA versions.

For other torch versions, we support `torch211`, `torch212`, `torch220`, `torch230`, `torch240` and for CUDA versions, we support `cu118` and `cu121` and `cu124`. For Ampere devices (A100, H100, RTX3090) and above, use `cu118-ampere` or `cu121-ampere` or `cu124-ampere`.

For example, if you have `torch 2.4` and `CUDA 12.1`, use:
```bash
pip install --upgrade pip
pip install "unsloth[cu121-torch240] @ git+https://github.com/unslothai/unsloth.git"
```

Another example, if you have `torch 2.5` and `CUDA 12.4`, use:
```bash
pip install --upgrade pip
pip install "unsloth[cu124-torch250] @ git+https://github.com/unslothai/unsloth.git"
```

And other examples:
```bash
pip install "unsloth[cu121-ampere-torch240] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu118-ampere-torch240] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121-torch240] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu118-torch240] @ git+https://github.com/unslothai/unsloth.git"

pip install "unsloth[cu121-torch230] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121-ampere-torch230] @ git+https://github.com/unslothai/unsloth.git"

pip install "unsloth[cu121-torch250] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu124-ampere-torch250] @ git+https://github.com/unslothai/unsloth.git"
```

Or, run the below in a terminal to get the **optimal** pip installation command:
```bash
wget -qO- https://raw.githubusercontent.com/unslothai/unsloth/main/unsloth/_auto_install.py | python -
```

Or, run the below manually in a Python REPL:
```python
try: import torch
except: raise ImportError('Install torch via `pip install torch`')
from packaging.version import Version as V
v = V(torch.__version__)
cuda = str(torch.version.cuda)
is_ampere = torch.cuda.get_device_capability()[0] >= 8
if cuda != "12.1" and cuda != "11.8" and cuda != "12.4": raise RuntimeError(f"CUDA = {cuda} not supported!")
if   v <= V('2.1.0'): raise RuntimeError(f"Torch = {v} too old!")
elif v <= V('2.1.1'): x = 'cu{}{}-torch211'
elif v <= V('2.1.2'): x = 'cu{}{}-torch212'
elif v  < V('2.3.0'): x = 'cu{}{}-torch220'
elif v  < V('2.4.0'): x = 'cu{}{}-torch230'
elif v  < V('2.5.0'): x = 'cu{}{}-torch240'
elif v  < V('2.6.0'): x = 'cu{}{}-torch250'
else: raise RuntimeError(f"Torch = {v} too new!")
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"')
```

### Windows Installation

To run Unsloth directly on Windows:
- Install Triton from this Windows fork and follow the instructions: https://github.com/woct0rdho/triton-windows
- In the SFTTrainer, set `dataset_num_proc=1` to avoid a crashing issue:
```python
trainer = SFTTrainer(
    dataset_num_proc=1,
    ...
)
```

For **advanced installation instructions** or if you see weird errors during installations:

1. Install `torch` and `triton`. Go to https://pytorch.org to install it. For example `pip install torch torchvision torchaudio triton`
2. Confirm if CUDA is installed correctly. Try `nvcc`. If that fails, you need to install `cudatoolkit` or CUDA drivers.
3. Install `xformers` manually. You can try installing `vllm` and seeing if `vllm` succeeds. Check if `xformers` succeeded with `python -m xformers.info` Go to https://github.com/facebookresearch/xformers. Another option is to install `flash-attn` for Ampere GPUs.
4. Finally, install `bitsandbytes` and check it with `python -m bitsandbytes`
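
The sketch below simply strings the four checks above together in one script; it relies only on the commands already named in this section (`nvcc`, `python -m xformers.info`, `python -m bitsandbytes`) and is a convenience sanity check, not an official Unsloth tool:

```python
# Quick environment sanity check mirroring steps 1-4 above.
import subprocess, sys

import torch  # step 1: torch (and triton) import correctly
print("torch", torch.__version__, "| CUDA", torch.version.cuda,
      "| GPU", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "none")

subprocess.run(["nvcc", "--version"])                      # step 2: CUDA toolkit visible
subprocess.run([sys.executable, "-m", "xformers.info"])    # step 3: xformers built correctly
subprocess.run([sys.executable, "-m", "bitsandbytes"])     # step 4: bitsandbytes self-test
```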
## 📜 [Documentation](https://docs.unsloth.ai)
- Go to our official [Documentation](https://docs.unsloth.ai) for saving to GGUF, checkpointing, evaluation and more!
- We support Huggingface's TRL, Trainer, Seq2SeqTrainer or even Pytorch code!
- We're in 🤗Hugging Face's official docs! Check out the [SFT docs](https://huggingface.co/docs/trl/main/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth) and [DPO docs](https://huggingface.co/docs/trl/main/en/dpo_trainer#accelerate-dpo-fine-tuning-using-unsloth)!

```python
from unsloth import FastLanguageModel
from unsloth import is_bfloat16_supported
import torch
from trl import SFTTrainer
from transformers import TrainingArguments
from datasets import load_dataset
max_seq_length = 2048 # Supports RoPE Scaling internally, so choose any!
# Get LAION dataset
url = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
dataset = load_dataset("json", data_files = {"train" : url}, split = "train")

# 4bit pre quantized models we support for 4x faster downloading + no OOMs.
fourbit_models = [
    "unsloth/mistral-7b-v0.3-bnb-4bit",      # New Mistral v3 2x faster!
    "unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
    "unsloth/llama-3-8b-bnb-4bit",           # Llama-3 15 trillion tokens model 2x faster!
    "unsloth/llama-3-8b-Instruct-bnb-4bit",
    "unsloth/llama-3-70b-bnb-4bit",
    "unsloth/Phi-3-mini-4k-instruct",        # Phi-3 2x faster!
    "unsloth/Phi-3-medium-4k-instruct",
    "unsloth/mistral-7b-bnb-4bit",
    "unsloth/gemma-7b-bnb-4bit",             # Gemma 2.2x faster!
] # More models at https://huggingface.co/unsloth

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/llama-3-8b-bnb-4bit",
    max_seq_length = max_seq_length,
    dtype = None,
    load_in_4bit = True,
)

# Do model patching and add fast LoRA weights
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    max_seq_length = max_seq_length,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

trainer = SFTTrainer(
    model = model,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    tokenizer = tokenizer,
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 10,
        max_steps = 60,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        output_dir = "outputs",
        optim = "adamw_8bit",
        seed = 3407,
    ),
)
trainer.train()

# Go to https://github.com/unslothai/unsloth/wiki for advanced tips like
# (1) Saving to GGUF / merging to 16bit for vLLM
# (2) Continued training from a saved LoRA adapter
# (3) Adding an evaluation loop / OOMs
# (4) Customized chat templates
```
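
As a follow-up to tip (1) above, a minimal saving sketch might look like the following; `save_pretrained_merged` and `save_pretrained_gguf` are the Unsloth saving helpers as described in the docs, so check https://docs.unsloth.ai for the exact arguments before relying on this:

```python
# Sketch only -- see the official docs for the authoritative saving / export guide.
model.save_pretrained("lora_model")        # LoRA adapters only
tokenizer.save_pretrained("lora_model")

# Merge LoRA into 16-bit weights for vLLM / Hugging Face (assumed Unsloth helper).
model.save_pretrained_merged("merged_16bit", tokenizer, save_method = "merged_16bit")

# Export to GGUF for llama.cpp / Ollama (assumed helper; the quantization method is illustrative).
model.save_pretrained_gguf("gguf_model", tokenizer, quantization_method = "q4_k_m")
```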
<a name="DPO"></a>
## DPO Support
DPO (Direct Preference Optimization), PPO, Reward Modelling all seem to work as per 3rd party independent testing from [Llama-Factory](https://github.com/hiyouga/LLaMA-Factory). We have a preliminary Google Colab notebook for reproducing Zephyr on Tesla T4 here: [notebook](https://colab.research.google.com/drive/15vttTpzzVXv_tJwEk-hIcQ0S9FcEWvwP?usp=sharing).

We're in 🤗Hugging Face's official docs! We're on the [SFT docs](https://huggingface.co/docs/trl/main/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth) and the [DPO docs](https://huggingface.co/docs/trl/main/en/dpo_trainer#accelerate-dpo-fine-tuning-using-unsloth)!

```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # Optional set GPU device ID

from unsloth import FastLanguageModel, PatchDPOTrainer
from unsloth import is_bfloat16_supported
PatchDPOTrainer()
import torch
from transformers import TrainingArguments
from trl import DPOTrainer

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/zephyr-sft-bnb-4bit",
    max_seq_length = max_seq_length,
    dtype = None,
    load_in_4bit = True,
)

# Do model patching and add fast LoRA weights
model = FastLanguageModel.get_peft_model(
    model,
    r = 64,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 64,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    max_seq_length = max_seq_length,
)

dpo_trainer = DPOTrainer(
    model = model,
    ref_model = None,
    args = TrainingArguments(
        per_device_train_batch_size = 4,
        gradient_accumulation_steps = 8,
        warmup_ratio = 0.1,
        num_train_epochs = 3,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        seed = 42,
        output_dir = "outputs",
    ),
    beta = 0.1,
    train_dataset = YOUR_DATASET_HERE,
    # eval_dataset = YOUR_DATASET_HERE,
    tokenizer = tokenizer,
    max_length = 1024,
    max_prompt_length = 512,
)
dpo_trainer.train()
```
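
`YOUR_DATASET_HERE` above is a placeholder: TRL's `DPOTrainer` expects preference data with `prompt`, `chosen` and `rejected` columns. A minimal sketch for wiring one up could look like this (the UltraFeedback dataset id and its chat-message layout are illustrative assumptions):

```python
# Sketch: build a preference dataset with the prompt / chosen / rejected columns DPOTrainer expects.
from datasets import load_dataset

raw = load_dataset("HuggingFaceH4/ultrafeedback_binarized", split = "train_prefs")  # assumed dataset id

def to_dpo_format(row):
    return {
        "prompt":   row["prompt"],
        "chosen":   row["chosen"][-1]["content"],    # last assistant turn as plain text
        "rejected": row["rejected"][-1]["content"],
    }

train_dataset = raw.map(to_dpo_format, remove_columns = raw.column_names)
# then pass train_dataset to DPOTrainer(..., train_dataset = train_dataset, ...)
```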
## 🥇 Detailed Benchmarking Tables
- Click "Code" for fully reproducible examples
- "Unsloth Equal" is a preview of our PRO version, with code stripped out. All settings and the loss curve remain identical.
- For the full list of benchmarking tables, [go to our website](https://unsloth.ai/blog/mistral-benchmark#Benchmark%20tables)

| 1 A100 40GB | 🤗Hugging Face | Flash Attention 2 | 🦥Unsloth Open | Unsloth Equal | Unsloth Pro | Unsloth Max |
|--------------|-------------|-------------|-----------------|--------------|---------------|-------------|
| Alpaca | 1x | 1.04x | 1.98x | 2.48x | 5.32x | **15.64x** |
| code | [Code](https://colab.research.google.com/drive/1u4dBeM-0vGNVmmO6X7cScAut-Hyt4KDF?usp=sharing) | [Code](https://colab.research.google.com/drive/1fgTOxpMbVjloQBvZyz4lF4BacKSZOB2A?usp=sharing) | [Code](https://colab.research.google.com/drive/1YIPY_18xm-K0iJDgvNkRoJsgkPMPAO3G?usp=sharing) | [Code](https://colab.research.google.com/drive/1ANW8EFL3LVyTD7Gq4TkheC1Z7Rxw-rHp?usp=sharing) | | |
| seconds| 1040 | 1001 | 525 | 419 | 196 | 67 |
| memory MB| 18235 | 15365 | 9631 | 8525 | | |
| % saved| | 15.74 | 47.18 | 53.25 | | | |

### Llama-Factory 3rd party benchmarking
- [Link to performance table.](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-Comparison) TGS: tokens per GPU per second. Model: LLaMA2-7B. GPU: NVIDIA A100 * 1. Batch size: 4. Gradient accumulation: 2. LoRA rank: 8. Max length: 1024.

| Method | Bits | TGS | GRAM | Speed |
| --- | --- | --- | --- | --- |
| HF | 16 | 2392 | 18GB | 100% |
| HF+FA2 | 16 | 2954 | 17GB | 123% |
| Unsloth+FA2 | 16 | 4007 | 16GB | **168%** |
| HF | 4 | 2415 | 9GB | 101% |
| Unsloth+FA2 | 4 | 3726 | 7GB | **160%** |

### Performance comparisons between popular models
<details>
  <summary>Click for specific model benchmarking tables (Mistral 7b, CodeLlama 34b etc.)</summary>

### Mistral 7b
| 1 A100 40GB | Hugging Face | Flash Attention 2 | Unsloth Open | Unsloth Equal | Unsloth Pro | Unsloth Max |
|--------------|-------------|-------------|-----------------|--------------|---------------|-------------|
| Mistral 7B Slim Orca | 1x | 1.15x | 2.15x | 2.53x | 4.61x | **13.69x** |
| code | [Code](https://colab.research.google.com/drive/1mePk3KzwTD81hr5mcNcs_AX3Kbg_Ha0x?usp=sharing) | [Code](https://colab.research.google.com/drive/1dgHxjvTmX6hb0bPcLp26RXSE6_n9DKj7?usp=sharing) | [Code](https://colab.research.google.com/drive/1SKrKGV-BZoU4kv5q3g0jtE_OhRgPtrrQ?usp=sharing) | [Code](https://colab.research.google.com/drive/18yOiyX0T81mTwZqOALFSCX_tSAqju6aD?usp=sharing) | |
| seconds | 1813 | 1571 | 842 | 718 | 393 | 132 |
| memory MB | 32853 | 19385 | 12465 | 10271 | | |
| % saved| | 40.99 | 62.06 | 68.74 | | |

### CodeLlama 34b
| 1 A100 40GB | Hugging Face | Flash Attention 2 | Unsloth Open | Unsloth Equal | Unsloth Pro | Unsloth Max |
|--------------|-------------|-------------|-----------------|--------------|---------------|-------------|
| Code Llama 34B | OOM ❌ | 0.99x | 1.87x | 2.61x | 4.27x | 12.82x |
| code | [▶️ Code](https://colab.research.google.com/drive/1ykfz3BqrtC_AUFegCzUQjjfUNlxp6Otc?usp=sharing) | [Code](https://colab.research.google.com/drive/12ZypxQh7OC6kBXvWZI-5d05I4m-B_hoR?usp=sharing) | [Code](https://colab.research.google.com/drive/1gdHyAx8XJsz2yNV-DHvbHjR1iCef5Qmh?usp=sharing) | [Code](https://colab.research.google.com/drive/1fm7wqx9MJ0kRrwKOfmLkK1Rmw-pySahB?usp=sharing) | |
| seconds | 1953 | 1982 | 1043 | 748 | 458 | 152 |
| memory MB | 40000 | 33217 | 27413 | 22161 | | |
| % saved| | 16.96| 31.47 | 44.60 | | | |

### 1 Tesla T4

| 1 T4 16GB | Hugging Face | Flash Attention | Unsloth Open | Unsloth Pro Equal | Unsloth Pro | Unsloth Max |
|--------------|-------------|-----------------|-----------------|---------------|---------------|-------------|
| Alpaca | 1x | 1.09x | 1.69x | 1.79x | 2.93x | **8.3x** |
| code | [▶️ Code](https://colab.research.google.com/drive/1XpLIV4s8Bj5uryB-X2gqM88oRGHEGdaB?usp=sharing) | [Code](https://colab.research.google.com/drive/1LyXu6CjuymQg6ddHX8g1dpUvrMa1nn4L?usp=sharing) | [Code](https://colab.research.google.com/drive/1gsv4LpY7C32otl1rgRo5wXTk4HIitXoM?usp=sharing) | [Code](https://colab.research.google.com/drive/1VtULwRQwhEnVdNryjm27zXfdSM1tNfFK?usp=sharing) | | |
| seconds | 1599 | 1468 | 942 | 894 | 545 | 193 |
| memory MB | 7199 | 7059 | 6459 | 5443 | | |
| % saved | | 1.94 | 10.28 | 24.39 | | |

### 2 Tesla T4s via DDP

| 2 T4 DDP | Hugging Face | Flash Attention | Unsloth Open | Unsloth Equal | Unsloth Pro | Unsloth Max |
|--------------|----------|-------------|-----------------|--------------|---------------|-------------|
| Alpaca | 1x | 0.99x | 4.95x | 4.44x | 7.28x | **20.61x** |
| code | [▶️ Code](https://www.kaggle.com/danielhanchen/hf-original-alpaca-t4-ddp) | [Code](https://www.kaggle.com/danielhanchen/hf-sdpa-alpaca-t4-ddp) | [Code](https://www.kaggle.com/danielhanchen/unsloth-alpaca-t4-ddp) | | |
| seconds | 9882 | 9946 | 1996 | 2227 | 1357 | 480 |
| memory MB| 9176 | 9128 | 6904 | 6782 | | |
| % saved | | 0.52 | 24.76 | 26.09 | | | |
</details>

### Performance comparisons on 1 Tesla T4 GPU:
<details>
  <summary>Click for Time taken for 1 epoch</summary>

One Tesla T4 on Google Colab
`bsz = 2, ga = 4, max_grad_norm = 0.3, num_train_epochs = 1, seed = 3047, lr = 2e-4, wd = 0.01, optim = "adamw_8bit", schedule = "linear", schedule_steps = 10`

| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) |
| --- | --- | --- | --- | --- | --- |
| Huggingface | 1 T4 | 23h 15m | 56h 28m | 8h 38m | 391h 41m |
| Unsloth Open | 1 T4 | 13h 7m (1.8x) | 31h 47m (1.8x) | 4h 27m (1.9x) | 240h 4m (1.6x) |
| Unsloth Pro | 1 T4 | 3h 6m (7.5x) | 5h 17m (10.7x) | 1h 7m (7.7x) | 59h 53m (6.5x) |
| Unsloth Max | 1 T4 | 2h 39m (8.8x) | 4h 31m (12.5x) | 0h 58m (8.9x) | 51h 30m (7.6x) |

**Peak Memory Usage**

| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) |
| --- | --- | --- | --- | --- | --- |
| Huggingface | 1 T4 | 7.3GB | 5.9GB | 14.0GB | 13.3GB |
| Unsloth Open | 1 T4 | 6.8GB | 5.7GB | 7.8GB | 7.7GB |
| Unsloth Pro | 1 T4 | 6.4GB | 6.4GB | 6.4GB | 6.4GB |
| Unsloth Max | 1 T4 | 11.4GB | 12.4GB | 11.9GB | 14.4GB |
</details>

<details>
  <summary>Click for Performance Comparisons on 2 Tesla T4 GPUs via DDP:</summary>
**Time taken for 1 epoch**

Two Tesla T4s on Kaggle
`bsz = 2, ga = 4, max_grad_norm = 0.3, num_train_epochs = 1, seed = 3047, lr = 2e-4, wd = 0.01, optim = "adamw_8bit", schedule = "linear", schedule_steps = 10`

| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) * |
| --- | --- | --- | --- | --- | --- |
| Huggingface | 2 T4 | 84h 47m | 163h 48m | 30h 51m | 1301h 24m * |
| Unsloth Pro | 2 T4 | 3h 20m (25.4x) | 5h 43m (28.7x) | 1h 12m (25.7x) | 71h 40m (18.1x) * |
| Unsloth Max | 2 T4 | 3h 4m (27.6x) | 5h 14m (31.3x) | 1h 6m (28.1x) | 54h 20m (23.9x) * |

**Peak Memory Usage on a Multi GPU System (2 GPUs)**

| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) * |
| --- | --- | --- | --- | --- | --- |
| Huggingface | 2 T4 | 8.4GB \| 6GB | 7.2GB \| 5.3GB | 14.3GB \| 6.6GB | 10.9GB \| 5.9GB * |
| Unsloth Pro | 2 T4 | 7.7GB \| 4.9GB | 7.5GB \| 4.9GB | 8.5GB \| 4.9GB | 6.2GB \| 4.7GB * |
| Unsloth Max | 2 T4 | 10.5GB \| 5GB | 10.6GB \| 5GB | 10.6GB \| 5GB | 10.5GB \| 5GB * |

* Slim Orca `bsz=1` for all benchmarks since `bsz=2` OOMs. We can handle `bsz=2`, but we benchmark it with `bsz=1` for consistency.
</details>


<br>

### Citation

You can cite the Unsloth repo as follows:
```bibtex
@software{unsloth,
  author = {Daniel Han, Michael Han and Unsloth team},
  title = {Unsloth},
  url = {http://github.com/unslothai/unsloth},
  year = {2023}
}
```

### Thank You to
- [Erik](https://github.com/erikwijmans) for his help adding [Apple's ML Cross Entropy](https://github.com/apple/ml-cross-entropy) in Unsloth
- [HuyNguyen-hust](https://github.com/HuyNguyen-hust) for making [RoPE Embeddings 28% faster](https://github.com/unslothai/unsloth/pull/238)
- [RandomInternetPreson](https://github.com/RandomInternetPreson) for confirming WSL support
- [152334H](https://github.com/152334H) for experimental DPO support
- [atgctg](https://github.com/atgctg) for syntax highlighting
unsloth-main/images/Assistant.png · DELETED · Binary file (82.6 kB)
unsloth-main/images/Colab.png · DELETED · Binary file (11.6 kB)
unsloth-main/images/Discord button.png · DELETED · Binary file (13.8 kB)
unsloth-main/images/Discord.png · DELETED · Binary file (13.8 kB)
unsloth-main/images/Documentation Button.png · DELETED · Binary file (11.8 kB)
unsloth-main/images/Free version button.png · DELETED · Binary file (9.53 kB)
unsloth-main/images/Kaggle.png · DELETED · Binary file (9.73 kB)
unsloth-main/images/Kofi button.png · DELETED · Binary file (18.1 kB)
unsloth-main/images/LAION 2GPU.png · DELETED · Binary file (51.4 kB)
unsloth-main/images/Merge.png · DELETED · Binary file (31.4 kB)
unsloth-main/images/Run.png · DELETED · Binary file (11.5 kB)
unsloth-main/images/Slim Orca 2GPUs.png · DELETED · Binary file (43.1 kB)
unsloth-main/images/Terminal_Type.png · DELETED · Binary file (69.4 kB)
unsloth-main/images/Where_Terminal.png · DELETED · Binary file (179 kB)
unsloth-main/images/buy me a coffee button.png · DELETED · Binary file (19 kB)
unsloth-main/images/documentation github button.png · DELETED · Binary file (11.8 kB)
unsloth-main/images/documentation green button.png · DELETED · Binary file (11.8 kB)
unsloth-main/images/documentation lighter.png · DELETED · Binary file (11.8 kB)
unsloth-main/images/documentation white button.png · DELETED · Binary file (11.2 kB)
unsloth-main/images/made with unsloth.png · DELETED · Binary file (70.4 kB)
unsloth-main/images/ollama.png · DELETED · Binary file (67.2 kB)
unsloth-main/images/peft x trl button.png · DELETED · Binary file (36.9 kB)
unsloth-main/images/start free finetune button.png · DELETED · Binary file (11.4 kB)
unsloth-main/images/unsloth end.png · DELETED · Binary file (892 kB)
unsloth-main/images/unsloth loading page render.png · DELETED · Binary file (790 kB)
unsloth-main/images/unsloth logo black text.png · DELETED · Binary file (58 kB)
unsloth-main/images/unsloth logo only.png · DELETED · Binary file (57.2 kB)
unsloth-main/images/unsloth logo white text.png · DELETED · Binary file (59 kB)
unsloth-main/images/unsloth made with love.png · DELETED · Binary file (63.5 kB)
unsloth-main/images/unsloth new logo.png · DELETED · Binary file (60.1 kB)
unsloth-main/pyproject.toml
DELETED
@@ -1,418 +0,0 @@
[build-system]
requires = ["setuptools", "setuptools-scm"]
build-backend = "setuptools.build_meta"

[project]
name = "unsloth"
dynamic = ["version"]
description = "2-5X faster LLM finetuning"
readme = "README.md"
requires-python = ">=3.9"
license = {file = "LICENSE"}
keywords = ["ai", "llm",]
authors = [
    {email = "[email protected]"},
    {name = "Unsloth AI team"},
]
maintainers = [
    {name = "Daniel Han", email = "[email protected]"},
    {name = "Michael Han", email = "[email protected]"},
]
classifiers = [
    "Programming Language :: Python",
]

[tool.setuptools.dynamic]
version = {attr = "unsloth.models._utils.__version__"}

[tool.setuptools]
include-package-data = false

[tool.setuptools.packages.find]
exclude = ["images*"]

[project.optional-dependencies]
triton = [
    "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'",
    "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'",
    "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
    "triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
]
huggingface = [
    "unsloth_zoo>=2024.12.7",
    "packaging",
    "tyro",
    "transformers>=4.46.1,!=4.47.0",
    "datasets>=2.16.0",
    "sentencepiece>=0.2.0",
    "tqdm",
    "psutil",
    "wheel>=0.42.0",
    "numpy",
    "accelerate>=0.34.1",
    "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3",
    "peft>=0.7.1,!=0.11.0",
    "protobuf<4.0.0",
    "huggingface_hub",
    "hf_transfer",
    "unsloth[triton]",
]
cu118only = [
    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
]
cu121only = [
    "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
    "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
    "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
]
cu118onlytorch211 = [
    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
|
72 |
-
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
|
73 |
-
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
|
74 |
-
]
|
75 |
-
cu121onlytorch211 = [
|
76 |
-
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
|
77 |
-
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
|
78 |
-
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
|
79 |
-
]
|
80 |
-
cu118onlytorch212 = [
|
81 |
-
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
|
82 |
-
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
|
83 |
-
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
|
84 |
-
]
|
85 |
-
cu121onlytorch212 = [
|
86 |
-
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23.post1-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
|
87 |
-
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23.post1-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
|
88 |
-
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23.post1-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
|
89 |
-
]
|
90 |
-
cu118onlytorch220 = [
|
91 |
-
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
|
92 |
-
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
|
93 |
-
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
|
94 |
-
]
|
95 |
-
cu121onlytorch220 = [
|
96 |
-
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
|
97 |
-
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
|
98 |
-
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
|
99 |
-
]
|
100 |
-
cu118onlytorch230 = [
|
101 |
-
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
|
102 |
-
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
|
103 |
-
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
|
104 |
-
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp312-cp312-manylinux2014_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
|
105 |
-
]
|
106 |
-
cu121onlytorch230 = [
|
107 |
-
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
|
108 |
-
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
|
109 |
-
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
|
110 |
-
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp312-cp312-manylinux2014_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
|
111 |
-
]
|
112 |
-
cu118onlytorch240 = [
|
113 |
-
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
|
114 |
-
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
|
115 |
-
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
|
116 |
-
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp312-cp312-manylinux2014_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
|
117 |
-
]
|
118 |
-
cu121onlytorch240 = [
|
119 |
-
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
|
120 |
-
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
|
121 |
-
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
|
122 |
-
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
|
123 |
-
]
|
124 |
-
cu124onlytorch240 = [
|
125 |
-
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
|
126 |
-
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
|
127 |
-
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
|
128 |
-
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
|
129 |
-
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'",
|
130 |
-
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'",
|
131 |
-
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
|
132 |
-
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post1-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
|
133 |
-
]
|
134 |
-
cu121onlytorch250 = [
|
135 |
-
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
|
136 |
-
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
|
137 |
-
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
|
138 |
-
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
|
139 |
-
]
|
140 |
-
cu124onlytorch250 = [
|
141 |
-
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
|
142 |
-
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
|
143 |
-
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
|
144 |
-
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
|
145 |
-
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'",
|
146 |
-
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'",
|
147 |
-
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
|
148 |
-
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post2-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
|
149 |
-
]
|
150 |
-
cu121onlytorch251 = [
|
151 |
-
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
|
152 |
-
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
|
153 |
-
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
|
154 |
-
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.28.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
|
155 |
-
]
|
156 |
-
cu124onlytorch251 = [
|
157 |
-
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp39-cp39-manylinux_2_28_x86_64.whl ; python_version=='3.9' and platform_system == 'Linux'",
|
158 |
-
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version=='3.10' and platform_system == 'Linux'",
|
159 |
-
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp311-cp311-manylinux_2_28_x86_64.whl ; python_version=='3.11' and platform_system == 'Linux'",
|
160 |
-
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp312-cp312-manylinux_2_28_x86_64.whl ; python_version=='3.12' and platform_system == 'Linux'",
|
161 |
-
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp39-cp39-win_amd64.whl ; python_version=='3.9' and platform_system == 'Windows'",
|
162 |
-
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp310-cp310-win_amd64.whl ; python_version=='3.10' and platform_system == 'Windows'",
|
163 |
-
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp311-cp311-win_amd64.whl ; python_version=='3.11' and platform_system == 'Windows'",
|
164 |
-
"xformers @ https://download.pytorch.org/whl/cu124/xformers-0.0.28.post3-cp312-cp312-win_amd64.whl ; python_version=='3.12' and platform_system == 'Windows'",
|
165 |
-
]
|
166 |
-
cu118 = [
|
167 |
-
"unsloth[huggingface]",
|
168 |
-
"bitsandbytes>=0.43.3",
|
169 |
-
"unsloth[cu118only]",
|
170 |
-
]
|
171 |
-
cu121 = [
|
172 |
-
"unsloth[huggingface]",
|
173 |
-
"bitsandbytes>=0.43.3",
|
174 |
-
"unsloth[cu121only]",
|
175 |
-
]
|
176 |
-
cu118-torch211 = [
|
177 |
-
"unsloth[huggingface]",
|
178 |
-
"bitsandbytes>=0.43.3",
|
179 |
-
"unsloth[cu118onlytorch211]",
|
180 |
-
]
|
181 |
-
cu121-torch211 = [
|
182 |
-
"unsloth[huggingface]",
|
183 |
-
"bitsandbytes>=0.43.3",
|
184 |
-
"unsloth[cu121onlytorch211]",
|
185 |
-
]
|
186 |
-
cu118-torch212 = [
|
187 |
-
"unsloth[huggingface]",
|
188 |
-
"bitsandbytes>=0.43.3",
|
189 |
-
"unsloth[cu118onlytorch212]",
|
190 |
-
]
|
191 |
-
cu121-torch212 = [
|
192 |
-
"unsloth[huggingface]",
|
193 |
-
"bitsandbytes>=0.43.3",
|
194 |
-
"unsloth[cu121onlytorch212]",
|
195 |
-
]
|
196 |
-
cu118-torch220 = [
|
197 |
-
"unsloth[huggingface]",
|
198 |
-
"bitsandbytes>=0.43.3",
|
199 |
-
"unsloth[cu118onlytorch220]",
|
200 |
-
]
|
201 |
-
cu121-torch220 = [
|
202 |
-
"unsloth[huggingface]",
|
203 |
-
"bitsandbytes>=0.43.3",
|
204 |
-
"unsloth[cu121onlytorch220]",
|
205 |
-
]
|
206 |
-
cu118-torch230 = [
|
207 |
-
"unsloth[huggingface]",
|
208 |
-
"bitsandbytes>=0.43.3",
|
209 |
-
"unsloth[cu118onlytorch230]",
|
210 |
-
]
|
211 |
-
cu121-torch230 = [
|
212 |
-
"unsloth[huggingface]",
|
213 |
-
"bitsandbytes>=0.43.3",
|
214 |
-
"unsloth[cu121onlytorch230]",
|
215 |
-
]
|
216 |
-
cu118-torch240 = [
|
217 |
-
"unsloth[huggingface]",
|
218 |
-
"bitsandbytes>=0.43.3",
|
219 |
-
"unsloth[cu118onlytorch240]",
|
220 |
-
]
|
221 |
-
cu121-torch240 = [
|
222 |
-
"unsloth[huggingface]",
|
223 |
-
"bitsandbytes>=0.43.3",
|
224 |
-
"unsloth[cu121onlytorch240]",
|
225 |
-
]
|
226 |
-
cu121-torch250 = [
|
227 |
-
"unsloth[huggingface]",
|
228 |
-
"bitsandbytes>=0.43.3",
|
229 |
-
"unsloth[cu121onlytorch250]",
|
230 |
-
]
|
231 |
-
cu124-torch240 = [
|
232 |
-
"unsloth[huggingface]",
|
233 |
-
"bitsandbytes>=0.43.3",
|
234 |
-
"unsloth[cu124onlytorch240]",
|
235 |
-
]
|
236 |
-
cu124-torch250 = [
|
237 |
-
"unsloth[huggingface]",
|
238 |
-
"bitsandbytes>=0.43.3",
|
239 |
-
"unsloth[cu124onlytorch250]",
|
240 |
-
]
|
241 |
-
cu121-torch251 = [
|
242 |
-
"unsloth[huggingface]",
|
243 |
-
"bitsandbytes>=0.43.3",
|
244 |
-
"unsloth[cu121onlytorch251]",
|
245 |
-
]
|
246 |
-
cu124-torch251 = [
|
247 |
-
"unsloth[huggingface]",
|
248 |
-
"bitsandbytes>=0.43.3",
|
249 |
-
"unsloth[cu124onlytorch251]",
|
250 |
-
]
|
251 |
-
kaggle = [
|
252 |
-
"unsloth[huggingface]",
|
253 |
-
]
|
254 |
-
kaggle-new = [
|
255 |
-
"unsloth[huggingface]",
|
256 |
-
"bitsandbytes>=0.43.3",
|
257 |
-
]
|
258 |
-
conda = [
|
259 |
-
"unsloth[huggingface]",
|
260 |
-
]
|
261 |
-
colab-torch211 = [
|
262 |
-
"unsloth[huggingface]",
|
263 |
-
"bitsandbytes>=0.43.3",
|
264 |
-
"unsloth[cu121onlytorch211]",
|
265 |
-
]
|
266 |
-
colab-ampere-torch211 = [
|
267 |
-
"unsloth[huggingface]",
|
268 |
-
"bitsandbytes>=0.43.3",
|
269 |
-
"unsloth[cu121onlytorch211]",
|
270 |
-
"packaging",
|
271 |
-
"ninja",
|
272 |
-
"flash-attn>=2.6.3",
|
273 |
-
]
|
274 |
-
colab-torch220 = [
|
275 |
-
"unsloth[huggingface]",
|
276 |
-
"bitsandbytes>=0.43.3",
|
277 |
-
"unsloth[cu121onlytorch220]",
|
278 |
-
]
|
279 |
-
colab-ampere-torch220 = [
|
280 |
-
"unsloth[huggingface]",
|
281 |
-
"bitsandbytes>=0.43.3",
|
282 |
-
"unsloth[cu121onlytorch220]",
|
283 |
-
"packaging",
|
284 |
-
"ninja",
|
285 |
-
"flash-attn>=2.6.3",
|
286 |
-
]
|
287 |
-
colab-new = [
|
288 |
-
"unsloth_zoo>=2024.12.7",
|
289 |
-
"packaging",
|
290 |
-
"tyro",
|
291 |
-
"transformers>=4.46.1,!=4.47.0",
|
292 |
-
"datasets>=2.16.0",
|
293 |
-
"sentencepiece>=0.2.0",
|
294 |
-
"tqdm",
|
295 |
-
"psutil",
|
296 |
-
"wheel>=0.42.0",
|
297 |
-
"numpy",
|
298 |
-
"protobuf<4.0.0",
|
299 |
-
"huggingface_hub",
|
300 |
-
"hf_transfer",
|
301 |
-
"bitsandbytes>=0.43.3",
|
302 |
-
"unsloth[triton]",
|
303 |
-
]
|
304 |
-
colab-no-deps = [
|
305 |
-
"accelerate>=0.34.1",
|
306 |
-
"trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3",
|
307 |
-
"peft>=0.7.1",
|
308 |
-
"xformers",
|
309 |
-
"bitsandbytes>=0.46.1",
|
310 |
-
"protobuf<4.0.0",
|
311 |
-
]
|
312 |
-
colab = [
|
313 |
-
"unsloth[cu121]",
|
314 |
-
]
|
315 |
-
flashattention = [
|
316 |
-
"packaging ; platform_system == 'Linux'",
|
317 |
-
"ninja ; platform_system == 'Linux'",
|
318 |
-
"flash-attn>=2.6.3 ; platform_system == 'Linux'",
|
319 |
-
]
|
320 |
-
colab-ampere = [
|
321 |
-
"unsloth[colab-ampere-torch220]",
|
322 |
-
"unsloth[flashattention]",
|
323 |
-
]
|
324 |
-
cu118-ampere = [
|
325 |
-
"unsloth[huggingface]",
|
326 |
-
"bitsandbytes>=0.43.3",
|
327 |
-
"unsloth[cu118only]",
|
328 |
-
"unsloth[flashattention]",
|
329 |
-
]
|
330 |
-
cu121-ampere = [
|
331 |
-
"unsloth[huggingface]",
|
332 |
-
"bitsandbytes>=0.43.3",
|
333 |
-
"unsloth[cu121only]",
|
334 |
-
"unsloth[flashattention]",
|
335 |
-
]
|
336 |
-
cu118-ampere-torch211 = [
|
337 |
-
"unsloth[huggingface]",
|
338 |
-
"bitsandbytes>=0.43.3",
|
339 |
-
"unsloth[cu118onlytorch211]",
|
340 |
-
"unsloth[flashattention]",
|
341 |
-
]
|
342 |
-
cu121-ampere-torch211 = [
|
343 |
-
"unsloth[huggingface]",
|
344 |
-
"bitsandbytes>=0.43.3",
|
345 |
-
"unsloth[cu121onlytorch211]",
|
346 |
-
"unsloth[flashattention]",
|
347 |
-
]
|
348 |
-
cu118-ampere-torch220 = [
|
349 |
-
"unsloth[huggingface]",
|
350 |
-
"bitsandbytes>=0.43.3",
|
351 |
-
"unsloth[cu118onlytorch220]",
|
352 |
-
"unsloth[flashattention]",
|
353 |
-
]
|
354 |
-
cu121-ampere-torch220 = [
|
355 |
-
"unsloth[huggingface]",
|
356 |
-
"bitsandbytes>=0.43.3",
|
357 |
-
"unsloth[cu121onlytorch220]",
|
358 |
-
"unsloth[flashattention]",
|
359 |
-
]
|
360 |
-
cu118-ampere-torch230 = [
|
361 |
-
"unsloth[huggingface]",
|
362 |
-
"bitsandbytes>=0.43.3",
|
363 |
-
"unsloth[cu118onlytorch230]",
|
364 |
-
"unsloth[flashattention]",
|
365 |
-
]
|
366 |
-
cu121-ampere-torch230 = [
|
367 |
-
"unsloth[huggingface]",
|
368 |
-
"bitsandbytes>=0.43.3",
|
369 |
-
"unsloth[cu121onlytorch230]",
|
370 |
-
"unsloth[flashattention]",
|
371 |
-
]
|
372 |
-
cu118-ampere-torch240 = [
|
373 |
-
"unsloth[huggingface]",
|
374 |
-
"bitsandbytes>=0.43.3",
|
375 |
-
"unsloth[cu118onlytorch240]",
|
376 |
-
"unsloth[flashattention]",
|
377 |
-
]
|
378 |
-
cu121-ampere-torch240 = [
|
379 |
-
"unsloth[huggingface]",
|
380 |
-
"bitsandbytes>=0.43.3",
|
381 |
-
"unsloth[cu121onlytorch240]",
|
382 |
-
"unsloth[flashattention]",
|
383 |
-
]
|
384 |
-
cu121-ampere-torch250 = [
|
385 |
-
"unsloth[huggingface]",
|
386 |
-
"bitsandbytes>=0.43.3",
|
387 |
-
"unsloth[cu121onlytorch250]",
|
388 |
-
"unsloth[flashattention]",
|
389 |
-
]
|
390 |
-
cu124-ampere-torch240 = [
|
391 |
-
"unsloth[huggingface]",
|
392 |
-
"bitsandbytes>=0.43.3",
|
393 |
-
"unsloth[cu124onlytorch240]",
|
394 |
-
"unsloth[flashattention]",
|
395 |
-
]
|
396 |
-
cu124-ampere-torch250 = [
|
397 |
-
"unsloth[huggingface]",
|
398 |
-
"bitsandbytes>=0.43.3",
|
399 |
-
"unsloth[cu124onlytorch250]",
|
400 |
-
"unsloth[flashattention]",
|
401 |
-
]
|
402 |
-
cu121-ampere-torch251 = [
|
403 |
-
"unsloth[huggingface]",
|
404 |
-
"bitsandbytes>=0.43.3",
|
405 |
-
"unsloth[cu121onlytorch251]",
|
406 |
-
"unsloth[flashattention]",
|
407 |
-
]
|
408 |
-
cu124-ampere-torch251 = [
|
409 |
-
"unsloth[huggingface]",
|
410 |
-
"bitsandbytes>=0.43.3",
|
411 |
-
"unsloth[cu124onlytorch251]",
|
412 |
-
"unsloth[flashattention]",
|
413 |
-
]
|
414 |
-
|
415 |
-
[project.urls]
|
416 |
-
homepage = "http://www.unsloth.ai"
|
417 |
-
documentation = "https://github.com/unslothai/unsloth"
|
418 |
-
repository = "https://github.com/unslothai/unsloth"
|
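Note: the [project.optional-dependencies] table above composes extras out of each other (for example, cu121-ampere-torch240 expands to unsloth[huggingface], bitsandbytes>=0.43.3, the matching unsloth[cu121onlytorch240] xformers wheels and unsloth[flashattention]). A minimal sketch of exercising one of these extras from a local checkout of the repository, assuming the tag matches your CUDA version, torch version and GPU generation (the particular tag shown is illustrative):

    pip install --upgrade pip
    pip install -e ".[cu121-ampere-torch240]"

The _auto_install.py helper further down in this diff picks the right tag automatically.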
|
|
unsloth-main/unsloth-cli.py
DELETED
@@ -1,221 +0,0 @@
|
|
1 |
-
#!/usr/bin/env python3
|
2 |
-
|
3 |
-
"""
|
4 |
-
🦥 Starter Script for Fine-Tuning FastLanguageModel with Unsloth
|
5 |
-
|
6 |
-
This script is designed as a starting point for fine-tuning your models using unsloth.
|
7 |
-
It includes configurable options for model loading, PEFT parameters, training arguments,
|
8 |
-
and model saving/pushing functionalities.
|
9 |
-
|
10 |
-
You will likely want to customize this script to suit your specific use case
|
11 |
-
and requirements.
|
12 |
-
|
13 |
-
Here are a few suggestions for customization:
|
14 |
-
- Modify the dataset loading and preprocessing steps to match your data.
|
15 |
-
- Customize the model saving and pushing configurations.
|
16 |
-
|
17 |
-
Usage: (most of the options have valid default values this is an extended example for demonstration purposes)
|
18 |
-
python unsloth-cli.py --model_name "unsloth/llama-3-8b" --max_seq_length 8192 --dtype None --load_in_4bit \
|
19 |
-
--r 64 --lora_alpha 32 --lora_dropout 0.1 --bias "none" --use_gradient_checkpointing "unsloth" \
|
20 |
-
--random_state 3407 --use_rslora --per_device_train_batch_size 4 --gradient_accumulation_steps 8 \
|
21 |
-
--warmup_steps 5 --max_steps 400 --learning_rate 2e-6 --logging_steps 1 --optim "adamw_8bit" \
|
22 |
-
--weight_decay 0.005 --lr_scheduler_type "linear" --seed 3407 --output_dir "outputs" \
|
23 |
-
--report_to "tensorboard" --save_model --save_path "model" --quantization_method "f16" \
|
24 |
-
--push_model --hub_path "hf/model" --hub_token "your_hf_token"
|
25 |
-
|
26 |
-
To see a full list of configurable options, use:
|
27 |
-
python unsloth-cli.py --help
|
28 |
-
|
29 |
-
Happy fine-tuning!
|
30 |
-
"""
|
31 |
-
|
32 |
-
import argparse
|
33 |
-
|
34 |
-
def run(args):
|
35 |
-
import torch
|
36 |
-
from unsloth import FastLanguageModel
|
37 |
-
from datasets import load_dataset
|
38 |
-
from trl import SFTTrainer
|
39 |
-
from transformers import TrainingArguments
|
40 |
-
from unsloth import is_bfloat16_supported
|
41 |
-
import logging
|
42 |
-
logging.getLogger('hf-to-gguf').setLevel(logging.WARNING)
|
43 |
-
|
44 |
-
# Load model and tokenizer
|
45 |
-
model, tokenizer = FastLanguageModel.from_pretrained(
|
46 |
-
model_name=args.model_name,
|
47 |
-
max_seq_length=args.max_seq_length,
|
48 |
-
dtype=args.dtype,
|
49 |
-
load_in_4bit=args.load_in_4bit,
|
50 |
-
)
|
51 |
-
|
52 |
-
# Configure PEFT model
|
53 |
-
model = FastLanguageModel.get_peft_model(
|
54 |
-
model,
|
55 |
-
r=args.r,
|
56 |
-
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
57 |
-
"gate_proj", "up_proj", "down_proj"],
|
58 |
-
lora_alpha=args.lora_alpha,
|
59 |
-
lora_dropout=args.lora_dropout,
|
60 |
-
bias=args.bias,
|
61 |
-
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
62 |
-
random_state=args.random_state,
|
63 |
-
use_rslora=args.use_rslora,
|
64 |
-
loftq_config=args.loftq_config,
|
65 |
-
)
|
66 |
-
|
67 |
-
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
|
68 |
-
|
69 |
-
### Instruction:
|
70 |
-
{}
|
71 |
-
|
72 |
-
### Input:
|
73 |
-
{}
|
74 |
-
|
75 |
-
### Response:
|
76 |
-
{}"""
|
77 |
-
|
78 |
-
EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
|
79 |
-
def formatting_prompts_func(examples):
|
80 |
-
instructions = examples["instruction"]
|
81 |
-
inputs = examples["input"]
|
82 |
-
outputs = examples["output"]
|
83 |
-
texts = []
|
84 |
-
for instruction, input, output in zip(instructions, inputs, outputs):
|
85 |
-
text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
|
86 |
-
texts.append(text)
|
87 |
-
return {"text": texts}
|
88 |
-
|
89 |
-
# Load and format dataset
|
90 |
-
dataset = load_dataset(args.dataset, split="train")
|
91 |
-
dataset = dataset.map(formatting_prompts_func, batched=True)
|
92 |
-
print("Data is formatted and ready!")
|
93 |
-
|
94 |
-
# Configure training arguments
|
95 |
-
training_args = TrainingArguments(
|
96 |
-
per_device_train_batch_size=args.per_device_train_batch_size,
|
97 |
-
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
98 |
-
warmup_steps=args.warmup_steps,
|
99 |
-
max_steps=args.max_steps,
|
100 |
-
learning_rate=args.learning_rate,
|
101 |
-
fp16=not is_bfloat16_supported(),
|
102 |
-
bf16=is_bfloat16_supported(),
|
103 |
-
logging_steps=args.logging_steps,
|
104 |
-
optim=args.optim,
|
105 |
-
weight_decay=args.weight_decay,
|
106 |
-
lr_scheduler_type=args.lr_scheduler_type,
|
107 |
-
seed=args.seed,
|
108 |
-
output_dir=args.output_dir,
|
109 |
-
report_to=args.report_to,
|
110 |
-
)
|
111 |
-
|
112 |
-
# Initialize trainer
|
113 |
-
trainer = SFTTrainer(
|
114 |
-
model=model,
|
115 |
-
tokenizer=tokenizer,
|
116 |
-
train_dataset=dataset,
|
117 |
-
dataset_text_field="text",
|
118 |
-
max_seq_length=args.max_seq_length,
|
119 |
-
dataset_num_proc=2,
|
120 |
-
packing=False,
|
121 |
-
args=training_args,
|
122 |
-
)
|
123 |
-
|
124 |
-
# Train model
|
125 |
-
trainer_stats = trainer.train()
|
126 |
-
|
127 |
-
# Save model
|
128 |
-
if args.save_model:
|
129 |
-
# if args.quantization_method is a list, we will save the model for each quantization method
|
130 |
-
if args.save_gguf:
|
131 |
-
if isinstance(args.quantization, list):
|
132 |
-
for quantization_method in args.quantization:
|
133 |
-
print(f"Saving model with quantization method: {quantization_method}")
|
134 |
-
model.save_pretrained_gguf(
|
135 |
-
args.save_path,
|
136 |
-
tokenizer,
|
137 |
-
quantization_method=quantization_method,
|
138 |
-
)
|
139 |
-
if args.push_model:
|
140 |
-
model.push_to_hub_gguf(
|
141 |
-
hub_path=args.hub_path,
|
142 |
-
hub_token=args.hub_token,
|
143 |
-
quantization_method=quantization_method,
|
144 |
-
)
|
145 |
-
else:
|
146 |
-
print(f"Saving model with quantization method: {args.quantization}")
|
147 |
-
model.save_pretrained_gguf(args.save_path, tokenizer, quantization_method=args.quantization)
|
148 |
-
if args.push_model:
|
149 |
-
model.push_to_hub_gguf(
|
150 |
-
hub_path=args.hub_path,
|
151 |
-
hub_token=args.hub_token,
|
152 |
-
quantization_method=quantization_method,
|
153 |
-
)
|
154 |
-
else:
|
155 |
-
model.save_pretrained_merged(args.save_path, tokenizer, args.save_method)
|
156 |
-
if args.push_model:
|
157 |
-
model.push_to_hub_merged(args.save_path, tokenizer, args.hub_token)
|
158 |
-
else:
|
159 |
-
print("Warning: The model is not saved!")
|
160 |
-
|
161 |
-
|
162 |
-
if __name__ == "__main__":
|
163 |
-
|
164 |
-
# Define argument parser
|
165 |
-
parser = argparse.ArgumentParser(description="🦥 Fine-tune your llm faster using unsloth!")
|
166 |
-
|
167 |
-
model_group = parser.add_argument_group("🤖 Model Options")
|
168 |
-
model_group.add_argument('--model_name', type=str, default="unsloth/llama-3-8b", help="Model name to load")
|
169 |
-
model_group.add_argument('--max_seq_length', type=int, default=2048, help="Maximum sequence length, default is 2048. We auto support RoPE Scaling internally!")
|
170 |
-
model_group.add_argument('--dtype', type=str, default=None, help="Data type for model (None for auto detection)")
|
171 |
-
model_group.add_argument('--load_in_4bit', action='store_true', help="Use 4bit quantization to reduce memory usage")
|
172 |
-
model_group.add_argument('--dataset', type=str, default="yahma/alpaca-cleaned", help="Huggingface dataset to use for training")
|
173 |
-
|
174 |
-
lora_group = parser.add_argument_group("🧠 LoRA Options", "These options are used to configure the LoRA model.")
|
175 |
-
lora_group.add_argument('--r', type=int, default=16, help="Rank for Lora model, default is 16. (common values: 8, 16, 32, 64, 128)")
|
176 |
-
lora_group.add_argument('--lora_alpha', type=int, default=16, help="LoRA alpha parameter, default is 16. (common values: 8, 16, 32, 64, 128)")
|
177 |
-
lora_group.add_argument('--lora_dropout', type=float, default=0, help="LoRA dropout rate, default is 0.0 which is optimized.")
|
178 |
-
lora_group.add_argument('--bias', type=str, default="none", help="Bias setting for LoRA")
|
179 |
-
lora_group.add_argument('--use_gradient_checkpointing', type=str, default="unsloth", help="Use gradient checkpointing")
|
180 |
-
lora_group.add_argument('--random_state', type=int, default=3407, help="Random state for reproducibility, default is 3407.")
|
181 |
-
lora_group.add_argument('--use_rslora', action='store_true', help="Use rank stabilized LoRA")
|
182 |
-
lora_group.add_argument('--loftq_config', type=str, default=None, help="Configuration for LoftQ")
|
183 |
-
|
184 |
-
|
185 |
-
training_group = parser.add_argument_group("🎓 Training Options")
|
186 |
-
training_group.add_argument('--per_device_train_batch_size', type=int, default=2, help="Batch size per device during training, default is 2.")
|
187 |
-
training_group.add_argument('--gradient_accumulation_steps', type=int, default=4, help="Number of gradient accumulation steps, default is 4.")
|
188 |
-
training_group.add_argument('--warmup_steps', type=int, default=5, help="Number of warmup steps, default is 5.")
|
189 |
-
training_group.add_argument('--max_steps', type=int, default=400, help="Maximum number of training steps.")
|
190 |
-
training_group.add_argument('--learning_rate', type=float, default=2e-4, help="Learning rate, default is 2e-4.")
|
191 |
-
training_group.add_argument('--optim', type=str, default="adamw_8bit", help="Optimizer type.")
|
192 |
-
training_group.add_argument('--weight_decay', type=float, default=0.01, help="Weight decay, default is 0.01.")
|
193 |
-
training_group.add_argument('--lr_scheduler_type', type=str, default="linear", help="Learning rate scheduler type, default is 'linear'.")
|
194 |
-
training_group.add_argument('--seed', type=int, default=3407, help="Seed for reproducibility, default is 3407.")
|
195 |
-
|
196 |
-
|
197 |
-
# Report/Logging arguments
|
198 |
-
report_group = parser.add_argument_group("📊 Report Options")
|
199 |
-
report_group.add_argument('--report_to', type=str, default="tensorboard",
|
200 |
-
choices=["azure_ml", "clearml", "codecarbon", "comet_ml", "dagshub", "dvclive", "flyte", "mlflow", "neptune", "tensorboard", "wandb", "all", "none"],
|
201 |
-
help="The list of integrations to report the results and logs to. Supported platforms are: \n\t\t 'azure_ml', 'clearml', 'codecarbon', 'comet_ml', 'dagshub', 'dvclive', 'flyte', 'mlflow', 'neptune', 'tensorboard', and 'wandb'. Use 'all' to report to all integrations installed, 'none' for no integrations.")
|
202 |
-
report_group.add_argument('--logging_steps', type=int, default=1, help="Logging steps, default is 1")
|
203 |
-
|
204 |
-
# Saving and pushing arguments
|
205 |
-
save_group = parser.add_argument_group('💾 Save Model Options')
|
206 |
-
save_group.add_argument('--output_dir', type=str, default="outputs", help="Output directory")
|
207 |
-
save_group.add_argument('--save_model', action='store_true', help="Save the model after training")
|
208 |
-
save_group.add_argument('--save_method', type=str, default="merged_16bit", choices=["merged_16bit", "merged_4bit", "lora"], help="Save method for the model, default is 'merged_16bit'")
|
209 |
-
save_group.add_argument('--save_gguf', action='store_true', help="Convert the model to GGUF after training")
|
210 |
-
save_group.add_argument('--save_path', type=str, default="model", help="Path to save the model")
|
211 |
-
save_group.add_argument('--quantization', type=str, default="q8_0", nargs="+",
|
212 |
-
help="Quantization method for saving the model. common values ('f16', 'q4_k_m', 'q8_0'), Check our wiki for all quantization methods https://github.com/unslothai/unsloth/wiki#saving-to-gguf ")
|
213 |
-
|
214 |
-
push_group = parser.add_argument_group('🚀 Push Model Options')
|
215 |
-
push_group.add_argument('--push_model', action='store_true', help="Push the model to Hugging Face hub after training")
|
216 |
-
push_group.add_argument('--push_gguf', action='store_true', help="Push the model as GGUF to Hugging Face hub after training")
|
217 |
-
push_group.add_argument('--hub_path', type=str, default="hf/model", help="Path on Hugging Face hub to push the model")
|
218 |
-
push_group.add_argument('--hub_token', type=str, help="Token for pushing the model to Hugging Face hub")
|
219 |
-
|
220 |
-
args = parser.parse_args()
|
221 |
-
run(args)
|
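For reference, a minimal invocation of the starter script deleted above, using only flags defined in its argument parser (the model, dataset and save-path values shown are the parser defaults and are illustrative, not prescriptive):

    python unsloth-cli.py --model_name "unsloth/llama-3-8b" --dataset "yahma/alpaca-cleaned" \
        --load_in_4bit --max_steps 400 --save_model --save_path "model"

All remaining options fall back to the defaults listed in the argument groups above.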
|
|
unsloth-main/unsloth/__init__.py
DELETED
@@ -1,181 +0,0 @@
|
|
1 |
-
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
|
5 |
-
# You may obtain a copy of the License at
|
6 |
-
#
|
7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
-
#
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
|
15 |
-
import warnings, importlib, sys
|
16 |
-
from packaging.version import Version
|
17 |
-
import os, re, subprocess, inspect
|
18 |
-
import numpy as np
|
19 |
-
|
20 |
-
# # Define a list of modules to check
|
21 |
-
# MODULES_TO_CHECK = ["bitsandbytes"]
|
22 |
-
|
23 |
-
# # Check if any of the modules in the list have been imported
|
24 |
-
# for module in MODULES_TO_CHECK:
|
25 |
-
# if module in sys.modules:
|
26 |
-
# raise ImportError(f"Unsloth: Please import Unsloth before {module}.")
|
27 |
-
# pass
|
28 |
-
# pass
|
29 |
-
|
30 |
-
# Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so
|
31 |
-
# enabling it will require much more work, so we have to prioritize. Please understand!
|
32 |
-
# We do have a beta version, which you can contact us about!
|
33 |
-
# Thank you for your understanding and we appreciate it immensely!
|
34 |
-
|
35 |
-
# Fixes https://github.com/unslothai/unsloth/issues/1266
|
36 |
-
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
37 |
-
|
38 |
-
if "CUDA_VISIBLE_DEVICES" in os.environ:
|
39 |
-
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
40 |
-
devices = os.environ["CUDA_VISIBLE_DEVICES"]
|
41 |
-
# Check if there are multiple cuda devices set in env
|
42 |
-
if not devices.isdigit():
|
43 |
-
first_id = devices.split(",")[0]
|
44 |
-
warnings.warn(
|
45 |
-
f"Unsloth: 'CUDA_VISIBLE_DEVICES' is currently {devices} \n"\
|
46 |
-
"Unsloth currently does not support multi GPU setups - but we are working on it!\n"\
|
47 |
-
"Multiple CUDA devices detected but we require a single device.\n"\
|
48 |
-
f"We will override CUDA_VISIBLE_DEVICES to first device: {first_id}."
|
49 |
-
)
|
50 |
-
os.environ["CUDA_VISIBLE_DEVICES"] = str(first_id)
|
51 |
-
else:
|
52 |
-
# warnings.warn("Unsloth: 'CUDA_VISIBLE_DEVICES' is not set. We shall set it ourselves.")
|
53 |
-
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
54 |
-
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
55 |
-
pass
|
56 |
-
|
57 |
-
# Reduce VRAM usage by reducing fragmentation
|
58 |
-
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,roundup_power2_divisions:[64:128,256:64,>:32]"
|
59 |
-
|
60 |
-
# Hugging Face Hub faster downloads
|
61 |
-
if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
|
62 |
-
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
63 |
-
pass
|
64 |
-
|
65 |
-
# Log Unsloth is being used
|
66 |
-
os.environ["UNSLOTH_IS_PRESENT"] = "1"
|
67 |
-
|
68 |
-
try:
|
69 |
-
import torch
|
70 |
-
except ModuleNotFoundError:
|
71 |
-
raise ImportError(
|
72 |
-
"Unsloth: Pytorch is not installed. Go to https://pytorch.org/.\n"\
|
73 |
-
"We have some installation instructions on our Github page."
|
74 |
-
)
|
75 |
-
except Exception as exception:
|
76 |
-
raise exception
|
77 |
-
pass
|
78 |
-
|
79 |
-
# We support Pytorch 2
|
80 |
-
# Fixes https://github.com/unslothai/unsloth/issues/38
|
81 |
-
torch_version = torch.__version__.split(".")
|
82 |
-
major_torch, minor_torch = torch_version[0], torch_version[1]
|
83 |
-
major_torch, minor_torch = int(major_torch), int(minor_torch)
|
84 |
-
if (major_torch < 2):
|
85 |
-
raise ImportError("Unsloth only supports Pytorch 2 for now. Please update your Pytorch to 2.1.\n"\
|
86 |
-
"We have some installation instructions on our Github page.")
|
87 |
-
elif (major_torch == 2) and (minor_torch < 2):
|
88 |
-
# Disable expandable_segments
|
89 |
-
del os.environ["PYTORCH_CUDA_ALLOC_CONF"]
|
90 |
-
pass
|
91 |
-
|
92 |
-
# Torch 2.4 has including_emulation
|
93 |
-
major_version, minor_version = torch.cuda.get_device_capability()
|
94 |
-
SUPPORTS_BFLOAT16 = (major_version >= 8)
|
95 |
-
|
96 |
-
old_is_bf16_supported = torch.cuda.is_bf16_supported
|
97 |
-
if "including_emulation" in str(inspect.signature(old_is_bf16_supported)):
|
98 |
-
def is_bf16_supported(including_emulation = False):
|
99 |
-
return old_is_bf16_supported(including_emulation)
|
100 |
-
torch.cuda.is_bf16_supported = is_bf16_supported
|
101 |
-
else:
|
102 |
-
def is_bf16_supported(): return SUPPORTS_BFLOAT16
|
103 |
-
torch.cuda.is_bf16_supported = is_bf16_supported
|
104 |
-
pass
|
105 |
-
|
106 |
-
# Try loading bitsandbytes and triton
|
107 |
-
import bitsandbytes as bnb
|
108 |
-
|
109 |
-
if "SPACE_AUTHOR_NAME" not in os.environ and "SPACE_REPO_NAME" not in os.environ:
|
110 |
-
|
111 |
-
import triton
|
112 |
-
libcuda_dirs = lambda: None
|
113 |
-
if Version(triton.__version__) >= Version("3.0.0"):
|
114 |
-
try: from triton.backends.nvidia.driver import libcuda_dirs
|
115 |
-
except: pass
|
116 |
-
else: from triton.common.build import libcuda_dirs
|
117 |
-
|
118 |
-
try:
|
119 |
-
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
|
120 |
-
libcuda_dirs()
|
121 |
-
except:
|
122 |
-
warnings.warn(
|
123 |
-
"Unsloth: Running `ldconfig /usr/lib64-nvidia` to link CUDA."\
|
124 |
-
)
|
125 |
-
|
126 |
-
if os.path.exists("/usr/lib64-nvidia"):
|
127 |
-
os.system("ldconfig /usr/lib64-nvidia")
|
128 |
-
elif os.path.exists("/usr/local"):
|
129 |
-
# Sometimes bitsandbytes cannot be linked properly in Runpod for example
|
130 |
-
possible_cudas = subprocess.check_output(["ls", "-al", "/usr/local"]).decode("utf-8").split("\n")
|
131 |
-
find_cuda = re.compile(r"[\s](cuda\-[\d\.]{2,})$")
|
132 |
-
possible_cudas = [find_cuda.search(x) for x in possible_cudas]
|
133 |
-
possible_cudas = [x.group(1) for x in possible_cudas if x is not None]
|
134 |
-
|
135 |
-
# Try linking cuda folder, or everything in local
|
136 |
-
if len(possible_cudas) == 0:
|
137 |
-
os.system("ldconfig /usr/local/")
|
138 |
-
else:
|
139 |
-
find_number = re.compile(r"([\d\.]{2,})")
|
140 |
-
latest_cuda = np.argsort([float(find_number.search(x).group(1)) for x in possible_cudas])[::-1][0]
|
141 |
-
latest_cuda = possible_cudas[latest_cuda]
|
142 |
-
os.system(f"ldconfig /usr/local/{latest_cuda}")
|
143 |
-
pass
|
144 |
-
|
145 |
-
importlib.reload(bnb)
|
146 |
-
importlib.reload(triton)
|
147 |
-
try:
|
148 |
-
libcuda_dirs = lambda: None
|
149 |
-
if Version(triton.__version__) >= Version("3.0.0"):
|
150 |
-
try: from triton.backends.nvidia.driver import libcuda_dirs
|
151 |
-
except: pass
|
152 |
-
else: from triton.common.build import libcuda_dirs
|
153 |
-
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
|
154 |
-
libcuda_dirs()
|
155 |
-
except:
|
156 |
-
warnings.warn(
|
157 |
-
"Unsloth: CUDA is not linked properly.\n"\
|
158 |
-
"Try running `python -m bitsandbytes` then `python -m xformers.info`\n"\
|
159 |
-
"We tried running `ldconfig /usr/lib64-nvidia` ourselves, but it didn't work.\n"\
|
160 |
-
"You need to run in your terminal `sudo ldconfig /usr/lib64-nvidia` yourself, then import Unsloth.\n"\
|
161 |
-
"Also try `sudo ldconfig /usr/local/cuda-xx.x` - find the latest cuda version.\n"\
|
162 |
-
"Unsloth will still run for now, but maybe it might crash - let's hope it works!"
|
163 |
-
)
|
164 |
-
pass
|
165 |
-
pass
|
166 |
-
|
167 |
-
# Check for unsloth_zoo
|
168 |
-
try:
|
169 |
-
import unsloth_zoo
|
170 |
-
except:
|
171 |
-
raise ImportError("Unsloth: Please install unsloth_zoo via `pip install unsloth-zoo`")
|
172 |
-
pass
|
173 |
-
|
174 |
-
from .models import *
|
175 |
-
from .save import *
|
176 |
-
from .chat_templates import *
|
177 |
-
from .tokenizer_utils import *
|
178 |
-
from .trainer import *
|
179 |
-
|
180 |
-
# Patch TRL trainers for backwards compatibility
|
181 |
-
_patch_trl_trainer()
|
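The CUDA_VISIBLE_DEVICES handling in the deleted __init__.py above warns and keeps only the first listed device when several are visible, and defaults to device 0 when the variable is unset. A sketch of how a caller would pin the intended single GPU before the package is imported (the device index 0 and the script name are illustrative):

    CUDA_VISIBLE_DEVICES=0 python unsloth-cli.py --model_name "unsloth/llama-3-8b"

Because the override happens at import time, the variable has to be set before any "import unsloth" runs.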
|
|
unsloth-main/unsloth/_auto_install.py
DELETED
@@ -1,31 +0,0 @@
|
|
1 |
-
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
|
5 |
-
# You may obtain a copy of the License at
|
6 |
-
#
|
7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
-
#
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
|
15 |
-
try: import torch
|
16 |
-
except: raise ImportError('Install torch via `pip install torch`')
|
17 |
-
from packaging.version import Version as V
|
18 |
-
v = V(torch.__version__)
|
19 |
-
cuda = str(torch.version.cuda)
|
20 |
-
is_ampere = torch.cuda.get_device_capability()[0] >= 8
|
21 |
-
if cuda != "12.1" and cuda != "11.8" and cuda != "12.4": raise RuntimeError(f"CUDA = {cuda} not supported!")
|
22 |
-
if v <= V('2.1.0'): raise RuntimeError(f"Torch = {v} too old!")
|
23 |
-
elif v <= V('2.1.1'): x = 'cu{}{}-torch211'
|
24 |
-
elif v <= V('2.1.2'): x = 'cu{}{}-torch212'
|
25 |
-
elif v < V('2.3.0'): x = 'cu{}{}-torch220'
|
26 |
-
elif v < V('2.4.0'): x = 'cu{}{}-torch230'
|
27 |
-
elif v < V('2.5.0'): x = 'cu{}{}-torch240'
|
28 |
-
elif v < V('2.6.0'): x = 'cu{}{}-torch250'
|
29 |
-
else: raise RuntimeError(f"Torch = {v} too new!")
|
30 |
-
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
|
31 |
-
print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"')
|
|
unsloth-main/unsloth/chat_templates.py
DELETED
@@ -1,2105 +0,0 @@
|
|
1 |
-
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
|
5 |
-
# You may obtain a copy of the License at
|
6 |
-
#
|
7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
-
#
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
|
15 |
-
__all__ = [
|
16 |
-
"get_chat_template",
|
17 |
-
"test_chat_templates",
|
18 |
-
"test_hf_gguf_equivalence",
|
19 |
-
"remove_special_tokens",
|
20 |
-
|
21 |
-
"to_sharegpt",
|
22 |
-
"standardize_sharegpt",
|
23 |
-
"apply_chat_template",
|
24 |
-
"train_on_responses_only",
|
25 |
-
|
26 |
-
"test_construct_chat_template",
|
27 |
-
]
|
28 |
-
|
29 |
-
from transformers import StoppingCriteria, StoppingCriteriaList
|
30 |
-
from torch import LongTensor, FloatTensor
|
31 |
-
from transformers.models.llama.modeling_llama import logger
|
32 |
-
from .save import patch_saving_functions
|
33 |
-
import os
|
34 |
-
import shutil
|
35 |
-
from .tokenizer_utils import *
|
36 |
-
from .models._utils import patch_tokenizer
|
37 |
-
import re
|
38 |
-
from unsloth_zoo.dataset_utils import (
|
39 |
-
train_on_responses_only,
|
40 |
-
)
|
41 |
-
CHAT_TEMPLATES = {}
|
42 |
-
DEFAULT_SYSTEM_MESSAGE = {}
|
43 |
-
|
44 |
-
# =========================================== Unsloth
|
45 |
-
# Unsloth efficient template leverages from Zephyr
|
46 |
-
unsloth_template = \
|
47 |
-
"{{ bos_token }}"\
|
48 |
-
"{% if messages[0]['role'] == 'system' %}"\
|
49 |
-
"{{ messages[0]['content'] + '\n' }}"\
|
50 |
-
"{% set loop_messages = messages[1:] %}"\
|
51 |
-
"{% else %}"\
|
52 |
-
"{{ '{system_message}' + '\n' }}"\
|
53 |
-
"{% set loop_messages = messages %}"\
|
54 |
-
"{% endif %}"\
|
55 |
-
"{% for message in loop_messages %}"\
|
56 |
-
"{% if message['role'] == 'user' %}"\
|
57 |
-
"{{ '>>> User: ' + message['content'] + '\n' }}"\
|
58 |
-
"{% elif message['role'] == 'assistant' %}"\
|
59 |
-
"{{ '>>> Assistant: ' + message['content'] + eos_token + '\n' }}"\
|
60 |
-
"{% else %}"\
|
61 |
-
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
|
62 |
-
"{% endif %}"\
|
63 |
-
"{% endfor %}"\
|
64 |
-
"{% if add_generation_prompt %}"\
|
65 |
-
"{{ '>>> Assistant: ' }}"\
|
66 |
-
"{% endif %}"
|
67 |
-
pass
|
68 |
-
|
69 |
-
unsloth_ollama = \
|
70 |
-
'''
|
71 |
-
FROM {__FILE_LOCATION__}
|
72 |
-
TEMPLATE """{{ if .System }}{{ .System }}
|
73 |
-
{{ end }}{{ if .Prompt }}>>> User: {{ .Prompt }}
|
74 |
-
{{ end }}>>> Assistant: {{ .Response }}{__EOS_TOKEN__}
|
75 |
-
"""
|
76 |
-
PARAMETER stop "{__EOS_TOKEN__}"
|
77 |
-
PARAMETER temperature 1.5
|
78 |
-
PARAMETER min_p 0.1
|
79 |
-
SYSTEM """You are a helpful assistant to the user"""
|
80 |
-
'''
|
81 |
-
|
82 |
-
unsloth_eos_token = "eos_token"
|
83 |
-
CHAT_TEMPLATES["unsloth"] = (unsloth_template, unsloth_eos_token, False, unsloth_ollama,)
|
84 |
-
DEFAULT_SYSTEM_MESSAGE["unsloth"] = "You are a helpful assistant to the user"
|
85 |
-
pass
|
86 |
-
|
87 |
-
# =========================================== Zephyr
|
88 |
-
# Zephyr has no BOS!
|
89 |
-
zephyr_template = \
|
90 |
-
"{% for message in messages %}"\
|
91 |
-
"{% if message['role'] == 'user' %}"\
|
92 |
-
"{{ '<|user|>\n' + message['content'] + eos_token + '\n' }}"\
|
93 |
-
"{% elif message['role'] == 'assistant' %}"\
|
94 |
-
"{{ '<|assistant|>\n' + message['content'] + eos_token + '\n' }}"\
|
95 |
-
"{% else %}"\
|
96 |
-
"{{ '<|system|>\n' + message['content'] + eos_token + '\n' }}"\
|
97 |
-
"{% endif %}"\
|
98 |
-
"{% endfor %}"\
|
99 |
-
"{% if add_generation_prompt %}"\
|
100 |
-
"{{ '<|assistant|>\n' }}"\
|
101 |
-
"{% endif %}"
|
102 |
-
pass
|
103 |
-
|
104 |
-
zephyr_ollama = \
|
105 |
-
'''
|
106 |
-
FROM {__FILE_LOCATION__}
|
107 |
-
TEMPLATE """{{ if .System }}<|system|>
|
108 |
-
{{ .System }}{__EOS_TOKEN__}
|
109 |
-
{{ end }}{{ if .Prompt }}<|user|>
|
110 |
-
{{ .Prompt }}{__EOS_TOKEN__}
|
111 |
-
{{ end }}<|assistant|>
|
112 |
-
{{ .Response }}{__EOS_TOKEN__}
|
113 |
-
"""
|
114 |
-
PARAMETER stop "{__EOS_TOKEN__}"
|
115 |
-
PARAMETER temperature 1.5
|
116 |
-
PARAMETER min_p 0.1
|
117 |
-
'''
|
118 |
-
|
119 |
-
zephyr_eos_token = "eos_token"
|
120 |
-
CHAT_TEMPLATES["zephyr"] = (zephyr_template, zephyr_eos_token, False, zephyr_ollama,)
|
121 |
-
DEFAULT_SYSTEM_MESSAGE["zephyr"] = None # No system message in Zephyr
|
122 |
-
pass
|
123 |
-
|
124 |
-
# =========================================== ChatML
|
125 |
-
# ChatML has no BOS and not EOS! Rather <|im_start|> and <|im_end|> acts as BOS / EOS.
|
126 |
-
chatml_template = \
|
127 |
-
"{% for message in messages %}"\
|
128 |
-
"{% if message['role'] == 'user' %}"\
|
129 |
-
"{{'<|im_start|>user\n' + message['content'] + '<|im_end|>\n'}}"\
|
130 |
-
"{% elif message['role'] == 'assistant' %}"\
|
131 |
-
"{{'<|im_start|>assistant\n' + message['content'] + '<|im_end|>\n' }}"\
|
132 |
-
"{% else %}"\
|
133 |
-
"{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>\n' }}"\
|
134 |
-
"{% endif %}"\
|
135 |
-
"{% endfor %}"\
|
136 |
-
"{% if add_generation_prompt %}"\
|
137 |
-
"{{ '<|im_start|>assistant\n' }}"\
|
138 |
-
"{% endif %}"
|
139 |
-
pass
|
140 |
-
|
141 |
-
chatml_ollama = \
|
142 |
-
'''
|
143 |
-
FROM {__FILE_LOCATION__}
|
144 |
-
TEMPLATE """{{ if .System }}<|im_start|>system
|
145 |
-
{{ .System }}<|im_end|>
|
146 |
-
{{ end }}{{ if .Prompt }}<|im_start|>user
|
147 |
-
{{ .Prompt }}<|im_end|>
|
148 |
-
{{ end }}<|im_start|>assistant
|
149 |
-
{{ .Response }}<|im_end|>
|
150 |
-
"""
|
151 |
-
PARAMETER stop "<|im_start|>"
|
152 |
-
PARAMETER stop "<|im_end|>"
|
153 |
-
PARAMETER temperature 1.5
|
154 |
-
PARAMETER min_p 0.1
|
155 |
-
'''
|
156 |
-
|
157 |
-
chatml_eos_token = "<|im_end|>"
|
158 |
-
CHAT_TEMPLATES["chatml"] = (chatml_template, chatml_eos_token, True, chatml_ollama,)
|
159 |
-
DEFAULT_SYSTEM_MESSAGE["chatml"] = None # No system message in ChatML
|
160 |
-
pass
|
161 |
-
|
162 |
-
# =========================================== Mistral-1
|
163 |
-
# Mistral Instruct doesn't allow system prompts, so we append it to the user message.
|
164 |
-
mistral_template = \
|
165 |
-
"{{ bos_token }}"\
|
166 |
-
"{% if messages[0]['role'] == 'system' %}"\
|
167 |
-
"{% if messages[1]['role'] == 'user' %}"\
|
168 |
-
"{{ '[INST] ' + messages[0]['content'] + ' ' + messages[1]['content'] + ' [/INST]' }}"\
|
169 |
-
"{% set loop_messages = messages[2:] %}"\
|
170 |
-
"{% else %}"\
|
171 |
-
"{{ '[INST] ' + messages[0]['content'] + ' [/INST]' }}"\
|
172 |
-
"{% set loop_messages = messages[1:] %}"\
|
173 |
-
"{% endif %}"\
|
174 |
-
"{% else %}"\
|
175 |
-
"{% set loop_messages = messages %}"\
|
176 |
-
"{% endif %}"\
|
177 |
-
"{% for message in loop_messages %}"\
|
178 |
-
"{% if message['role'] == 'user' %}"\
|
179 |
-
"{{ '[INST] ' + message['content'] + ' [/INST]' }}"\
|
180 |
-
"{% elif message['role'] == 'assistant' %}"\
|
181 |
-
"{{ message['content'] + eos_token }}"\
|
182 |
-
"{% else %}"\
|
183 |
-
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
|
184 |
-
"{% endif %}"\
|
185 |
-
"{% endfor %}"
|
186 |
-
pass
|
187 |
-
|
188 |
-
# Ollama from https://www.ollama.com/library/mistral
|
189 |
-
mistral_ollama = \
|
190 |
-
'''
|
191 |
-
FROM {__FILE_LOCATION__}
|
192 |
-
TEMPLATE """[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]"""
|
193 |
-
PARAMETER stop "{__EOS_TOKEN__}"
|
194 |
-
PARAMETER temperature 1.5
|
195 |
-
PARAMETER min_p 0.1
|
196 |
-
'''
|
197 |
-
|
198 |
-
mistral_eos_token = "eos_token"
|
199 |
-
CHAT_TEMPLATES["mistral"] = (mistral_template, mistral_eos_token, False, mistral_ollama,)
|
200 |
-
DEFAULT_SYSTEM_MESSAGE["mistral"] = None # No system message in Mistral
|
201 |
-
pass
|
202 |
-
|
203 |
-
# =========================================== Llama-2
|
204 |
-
# Adds BOS before every user turn! And uses <<SYS>> markers for system messages.
|
205 |
-
llama_template = \
|
206 |
-
"{% if messages[0]['role'] == 'system' %}"\
|
207 |
-
"{% if messages[1]['role'] == 'user' %}"\
|
208 |
-
"{{ bos_token + '[INST] <<SYS>>\n' + messages[0]['content'] + '\n<</SYS>>\n\n' + messages[1]['content'] + ' [/INST]' }}"\
|
209 |
-
"{% set loop_messages = messages[2:] %}"\
|
210 |
-
"{% else %}"\
|
211 |
-
"{{ bos_token + '[INST] ' + messages[0]['content'] + ' [/INST]' }}"\
|
212 |
-
"{% set loop_messages = messages[1:] %}"\
|
213 |
-
"{% endif %}"\
|
214 |
-
"{% else %}"\
|
215 |
-
"{% set loop_messages = messages %}"\
|
216 |
-
"{% endif %}"\
|
217 |
-
"{% for message in loop_messages %}"\
|
218 |
-
"{% if message['role'] == 'user' %}"\
|
219 |
-
"{{ bos_token + '[INST] ' + message['content'].strip() + ' [/INST]' }}"\
|
220 |
-
"{% elif message['role'] == 'assistant' %}"\
|
221 |
-
"{{ ' ' + message['content'].strip() + ' ' + eos_token }}"\
|
222 |
-
"{% else %}"\
|
223 |
-
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
|
224 |
-
"{% endif %}"\
|
225 |
-
"{% endfor %}"
|
226 |
-
pass
|
227 |
-
|
228 |
-
# Ollama from https://www.ollama.com/library/llama2
|
229 |
-
llama_ollama = \
|
230 |
-
'''
|
231 |
-
FROM {__FILE_LOCATION__}
|
232 |
-
TEMPLATE """[INST] <<SYS>>{{ .System }}<</SYS>>
|
233 |
-
|
234 |
-
{{ .Prompt }} [/INST]"""
|
235 |
-
PARAMETER stop "{__EOS_TOKEN__}"
|
236 |
-
PARAMETER temperature 1.5
|
237 |
-
PARAMETER min_p 0.1
|
238 |
-
'''
|
239 |
-
|
240 |
-
llama_eos_token = "eos_token"
|
241 |
-
CHAT_TEMPLATES["llama"] = (llama_template, llama_eos_token, False, llama_ollama,)
|
242 |
-
DEFAULT_SYSTEM_MESSAGE["llama"] = None # No system message in Llama
|
243 |
-
pass
|
244 |
-
|
245 |
-
# =========================================== Vicuna
|
246 |
-
# https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
|
247 |
-
vicuna_template = \
|
248 |
-
"{{ bos_token }}"\
|
249 |
-
"{% if messages[0]['role'] == 'system' %}"\
|
250 |
-
"{{ messages[0]['content'] + ' ' }}"\
|
251 |
-
"{% set loop_messages = messages[1:] %}"\
|
252 |
-
"{% else %}"\
|
253 |
-
"{{ '{system_message}' + ' ' }}"\
|
254 |
-
"{% set loop_messages = messages %}"\
|
255 |
-
"{% endif %}"\
|
256 |
-
"{% for message in loop_messages %}"\
|
257 |
-
"{% if message['role'] == 'user' %}"\
|
258 |
-
"{{ 'USER: ' + message['content'] + ' ' }}"\
|
259 |
-
"{% elif message['role'] == 'assistant' %}"\
|
260 |
-
"{{ 'ASSISTANT: ' + message['content'] + eos_token }}"\
|
261 |
-
"{% else %}"\
|
262 |
-
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
|
263 |
-
"{% endif %}"\
|
264 |
-
"{% endfor %}"\
|
265 |
-
"{% if add_generation_prompt %}"\
|
266 |
-
"{{ 'ASSISTANT:' }}"\
|
267 |
-
"{% endif %}"
|
268 |
-
pass
|
269 |
-
|
270 |
-
# Ollama from https://www.ollama.com/library/vicuna
|
271 |
-
vicuna_ollama = \
|
272 |
-
'''
|
273 |
-
FROM {__FILE_LOCATION__}
|
274 |
-
TEMPLATE """{{ if .System }}{{ .System }} {{ end }}{{ if .Prompt }}USER: {{ .Prompt }} {{ end }}ASSISTANT: {{ .Response }} {__EOS_TOKEN__}"""
|
275 |
-
PARAMETER stop "{__EOS_TOKEN__}"
|
276 |
-
PARAMETER temperature 1.5
|
277 |
-
PARAMETER min_p 0.1
|
278 |
-
'''
|
279 |
-
|
280 |
-
vicuna_eos_token = "eos_token"
|
281 |
-
CHAT_TEMPLATES["vicuna"] = (vicuna_template, vicuna_eos_token, False, vicuna_ollama,)
|
282 |
-
DEFAULT_SYSTEM_MESSAGE["vicuna"] = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
|
283 |
-
pass
|
284 |
-
|
285 |
-
# =========================================== Vicuna Old
|
286 |
-
# https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
|
287 |
-
vicuna_old_template = \
|
288 |
-
"{{ bos_token }}"\
|
289 |
-
"{% if messages[0]['role'] == 'system' %}"\
|
290 |
-
"{{ messages[0]['content'] + '\n' }}"\
|
291 |
-
"{% set loop_messages = messages[1:] %}"\
|
292 |
-
"{% else %}"\
|
293 |
-
"{{ '{system_message}' + '\n' }}"\
|
294 |
-
"{% set loop_messages = messages %}"\
|
295 |
-
"{% endif %}"\
|
296 |
-
"{% for message in loop_messages %}"\
|
297 |
-
"{% if message['role'] == 'user' %}"\
|
298 |
-
"{{ '### Human: ' + message['content'] + '\n' }}"\
|
299 |
-
"{% elif message['role'] == 'assistant' %}"\
|
300 |
-
"{{ '### Assistant: ' + message['content'] + eos_token + '\n' }}"\
|
301 |
-
"{% else %}"\
|
302 |
-
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
|
303 |
-
"{% endif %}"\
|
304 |
-
"{% endfor %}"\
|
305 |
-
"{% if add_generation_prompt %}"\
|
306 |
-
"{{ '### Assistant:' }}"\
|
307 |
-
"{% endif %}"
|
308 |
-
pass
|
309 |
-
|
310 |
-
vicuna_old_ollama = \
|
311 |
-
'''
|
312 |
-
FROM {__FILE_LOCATION__}
|
313 |
-
TEMPLATE """{{ if .System }}{{ .System }}
|
314 |
-
{{ end }}{{ if .Prompt }}### Human: {{ .Prompt }}
|
315 |
-
{{ end }}### Assistant: {{ .Response }}{__EOS_TOKEN__}
|
316 |
-
"""
|
317 |
-
PARAMETER stop "{__EOS_TOKEN__}"
|
318 |
-
PARAMETER temperature 1.5
|
319 |
-
PARAMETER min_p 0.1
|
320 |
-
SYSTEM """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."""
|
321 |
-
'''
|
322 |
-
|
323 |
-
vicuna_old_eos_token = "eos_token"
|
324 |
-
CHAT_TEMPLATES["vicuna_old"] = (vicuna_old_template, vicuna_old_eos_token, False, vicuna_old_ollama,)
|
325 |
-
DEFAULT_SYSTEM_MESSAGE["vicuna_old"] = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\\'s questions."
|
326 |
-
|
327 |
-
CHAT_TEMPLATES["vicuna old"] = CHAT_TEMPLATES["vicuna_old"]
|
328 |
-
DEFAULT_SYSTEM_MESSAGE["vicuna old"] = DEFAULT_SYSTEM_MESSAGE["vicuna_old"]
|
329 |
-
pass
|
330 |
-
|
331 |
-
# =========================================== Alpaca multi turn
|
332 |
-
# https://github.com/tatsu-lab/stanford_alpaca Changed for multi-turn convos
|
333 |
-
alpaca_template = \
|
334 |
-
"{{ bos_token }}"\
|
335 |
-
"{% if messages[0]['role'] == 'system' %}"\
|
336 |
-
"{{ messages[0]['content'] + '\n\n' }}"\
|
337 |
-
"{% set loop_messages = messages[1:] %}"\
|
338 |
-
"{% else %}"\
|
339 |
-
"{{ '{system_message}' + '\n\n' }}"\
|
340 |
-
"{% set loop_messages = messages %}"\
|
341 |
-
"{% endif %}"\
|
342 |
-
"{% for message in loop_messages %}"\
|
343 |
-
"{% if message['role'] == 'user' %}"\
|
344 |
-
"{{ '### Instruction:\n' + message['content'] + '\n\n' }}"\
|
345 |
-
"{% elif message['role'] == 'assistant' %}"\
|
346 |
-
"{{ '### Response:\n' + message['content'] + eos_token + '\n\n' }}"\
|
347 |
-
"{% else %}"\
|
348 |
-
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
|
349 |
-
"{% endif %}"\
|
350 |
-
"{% endfor %}"\
|
351 |
-
"{% if add_generation_prompt %}"\
|
352 |
-
"{{ '### Response:\n' }}"\
|
353 |
-
"{% endif %}"
|
354 |
-
pass
|
355 |
-
|
356 |
-
alpaca_ollama = \
|
357 |
-
'''
|
358 |
-
FROM {__FILE_LOCATION__}
|
359 |
-
TEMPLATE """{{ if .System }}{{ .System }}
|
360 |
-
|
361 |
-
{{ end }}{{ if .Prompt }}### Instruction:
|
362 |
-
{{ .Prompt }}{{ end }}
|
363 |
-
|
364 |
-
### Response:
|
365 |
-
{{ .Response }}{__EOS_TOKEN__}
|
366 |
-
|
367 |
-
"""
|
368 |
-
PARAMETER stop "{__EOS_TOKEN__}"
|
369 |
-
PARAMETER temperature 1.5
|
370 |
-
PARAMETER min_p 0.1
|
371 |
-
SYSTEM """Below are some instructions that describe some tasks. Write responses that appropriately complete each request."""
|
372 |
-
'''
|
373 |
-
|
374 |
-
alpaca_eos_token = "eos_token"
|
375 |
-
CHAT_TEMPLATES["alpaca"] = (alpaca_template, alpaca_eos_token, False, alpaca_ollama,)
|
376 |
-
DEFAULT_SYSTEM_MESSAGE["alpaca"] = "Below are some instructions that describe some tasks. Write responses that appropriately complete each request."
|
377 |
-
pass
|
378 |
-
|
379 |
-
# =========================================== Gemma
|
380 |
-
# https://huggingface.co/google/gemma-7b-it
|
381 |
-
# Notice we must use |trim for lstrip and rstrip. <start_of_turn> maps to 106.
|
382 |
-
# <end_of_turn> maps to 107. user and model are normal 1 word tokens.
|
383 |
-
gemma_template = \
|
384 |
-
"{{ bos_token }}"\
|
385 |
-
"{% if messages[0]['role'] == 'system' %}"\
|
386 |
-
"{{'<start_of_turn>user\n' + messages[0]['content'] | trim + ' ' + messages[1]['content'] | trim + '<end_of_turn>\n'}}"\
|
387 |
-
"{% set messages = messages[2:] %}"\
|
388 |
-
"{% endif %}"\
|
389 |
-
"{% for message in messages %}"\
|
390 |
-
"{% if message['role'] == 'user' %}"\
|
391 |
-
"{{'<start_of_turn>user\n' + message['content'] | trim + '<end_of_turn>\n'}}"\
|
392 |
-
"{% elif message['role'] == 'assistant' %}"\
|
393 |
-
"{{'<start_of_turn>model\n' + message['content'] | trim + '<end_of_turn>\n' }}"\
|
394 |
-
"{% else %}"\
|
395 |
-
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
|
396 |
-
"{% endif %}"\
|
397 |
-
"{% endfor %}"\
|
398 |
-
"{% if add_generation_prompt %}"\
|
399 |
-
"{{ '<start_of_turn>model\n' }}"\
|
400 |
-
"{% endif %}"
|
401 |
-
pass
|
402 |
-
|
403 |
-
# Ollama from https://www.ollama.com/library/gemma
|
404 |
-
gemma_ollama = \
|
405 |
-
'''
|
406 |
-
FROM {__FILE_LOCATION__}
|
407 |
-
TEMPLATE """<start_of_turn>user
|
408 |
-
{{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}<end_of_turn>
|
409 |
-
<start_of_turn>model
|
410 |
-
{{ .Response }}<end_of_turn>
|
411 |
-
"""
|
412 |
-
PARAMETER repeat_penalty 1
|
413 |
-
PARAMETER stop "<start_of_turn>"
|
414 |
-
PARAMETER stop "<end_of_turn>"
|
415 |
-
PARAMETER penalize_newline false
|
416 |
-
PARAMETER temperature 1.5
|
417 |
-
PARAMETER min_p 0.1
|
418 |
-
'''
|
419 |
-
|
420 |
-
gemma_eos_token = "<end_of_turn>"
|
421 |
-
CHAT_TEMPLATES["gemma"] = (gemma_template, gemma_eos_token, True, gemma_ollama,)
|
422 |
-
DEFAULT_SYSTEM_MESSAGE["gemma"] = None # No system message in Gemma
|
423 |
-
pass
|
424 |
-
|
425 |
-
# =========================================== Gemma with ChatML instead
|
426 |
-
# We find using <eos> is still more appropriate!
|
427 |
-
gemma_chatml_template = "{{ bos_token }}" + chatml_template
|
428 |
-
pass
|
429 |
-
|
430 |
-
gemma_chatml_ollama = \
|
431 |
-
'''
|
432 |
-
FROM {__FILE_LOCATION__}
|
433 |
-
TEMPLATE """{{ if .System }}<|im_start|>system
|
434 |
-
{{ .System }}<|im_end|>
|
435 |
-
{{ end }}{{ if .Prompt }}<|im_start|>user
|
436 |
-
{{ .Prompt }}<|im_end|>
|
437 |
-
{{ end }}<|im_start|>assistant
|
438 |
-
{{ .Response }}<|im_end|>
|
439 |
-
"""
|
440 |
-
PARAMETER repeat_penalty 1
|
441 |
-
PARAMETER stop "<|im_start|>"
|
442 |
-
PARAMETER stop "<|im_end|>"
|
443 |
-
PARAMETER penalize_newline false
|
444 |
-
PARAMETER temperature 1.5
|
445 |
-
PARAMETER min_p 0.1
|
446 |
-
'''
|
447 |
-
|
448 |
-
gemma_chatml_eos_token = (
|
449 |
-
{"<start_of_turn>" : "<|im_start|>", "<eos>" : "<|im_end|>"},
|
450 |
-
"<|im_end|>",
|
451 |
-
)
|
452 |
-
CHAT_TEMPLATES["gemma_chatml"] = (gemma_chatml_template, gemma_chatml_eos_token, True, gemma_chatml_ollama,)
|
453 |
-
DEFAULT_SYSTEM_MESSAGE["gemma_chatml"] = None # No system message in Gemma
|
454 |
-
pass
|
455 |
-
|
456 |
-
# =========================================== Gemma 2
|
457 |
-
# Same as Gemma 1, but with sliding window attention!
|
458 |
-
# https://ollama.com/library/gemma2/blobs/6522ca797f47
|
459 |
-
gemma2_template = gemma_template
|
460 |
-
gemma2_ollama = gemma_ollama + "PARAMETER num_ctx 4096\n"
|
461 |
-
gemma2_eos_token = "<end_of_turn>"
|
462 |
-
CHAT_TEMPLATES["gemma2"] = (gemma2_template, gemma2_eos_token, True, gemma2_ollama,)
|
463 |
-
DEFAULT_SYSTEM_MESSAGE["gemma2"] = None # No system message in Gemma 2
|
464 |
-
|
465 |
-
# =========================================== Gemma 2 with ChatML instead
|
466 |
-
gemma2_chatml_template = gemma_chatml_template
|
467 |
-
gemma2_chatml_ollama = gemma_chatml_ollama + "PARAMETER num_ctx 4096\n"
|
468 |
-
gemma2_chatml_eos_token = gemma_chatml_eos_token
|
469 |
-
CHAT_TEMPLATES["gemma2_chatml"] = (gemma2_chatml_template, gemma2_chatml_eos_token, True, gemma2_chatml_ollama,)
|
470 |
-
DEFAULT_SYSTEM_MESSAGE["gemma2_chatml"] = None # No system message in Gemma 2
|
471 |
-
pass
|
472 |
-
|
473 |
-
# =========================================== Llama-3
|
474 |
-
# Weirdly \n\n is needed?
|
475 |
-
llama3_template = \
|
476 |
-
"{{ bos_token }}"\
|
477 |
-
"{% for message in messages %}"\
|
478 |
-
"{% if message['role'] == 'user' %}"\
|
479 |
-
"{{ '<|start_header_id|>user<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}"\
|
480 |
-
"{% elif message['role'] == 'assistant' %}"\
|
481 |
-
"{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}"\
|
482 |
-
"{% else %}"\
|
483 |
-
"{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}"\
|
484 |
-
"{% endif %}"\
|
485 |
-
"{% endfor %}"\
|
486 |
-
"{% if add_generation_prompt %}"\
|
487 |
-
"{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"\
|
488 |
-
"{% endif %}"
|
489 |
-
pass
|
490 |
-
|
491 |
-
# Ollama from https://www.ollama.com/library/llama3
|
492 |
-
llama3_ollama = \
|
493 |
-
'''
|
494 |
-
FROM {__FILE_LOCATION__}
|
495 |
-
TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
|
496 |
-
|
497 |
-
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
|
498 |
-
|
499 |
-
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
|
500 |
-
|
501 |
-
{{ .Response }}<|eot_id|>"""
|
502 |
-
PARAMETER stop "<|start_header_id|>"
|
503 |
-
PARAMETER stop "<|end_header_id|>"
|
504 |
-
PARAMETER stop "<|eot_id|>"
|
505 |
-
PARAMETER temperature 1.5
|
506 |
-
PARAMETER min_p 0.1
|
507 |
-
'''
|
508 |
-
|
509 |
-
llama3_template_eos_token = "eos_token"
|
510 |
-
|
511 |
-
CHAT_TEMPLATES["llama-3"] = (llama3_template, llama3_template_eos_token, False, llama3_ollama,)
|
512 |
-
DEFAULT_SYSTEM_MESSAGE["llama-3"] = None # No system message in Llama-3
|
513 |
-
|
514 |
-
CHAT_TEMPLATES["llama3"] = (llama3_template, llama3_template_eos_token, False, llama3_ollama,)
|
515 |
-
DEFAULT_SYSTEM_MESSAGE["llama3"] = None # No system message in Llama-3
|
516 |
-
pass
|
517 |
-
|
518 |
-
|
519 |
-
# =========================================== Phi-3
|
520 |
-
# "{{ bos_token }}"\ # Phi-3.5 removes BOS?
|
521 |
-
phi3_template = \
|
522 |
-
"{% for message in messages %}"\
|
523 |
-
"{% if message['role'] == 'user' %}"\
|
524 |
-
"{{'<|user|>\n' + message['content'] + '<|end|>\n'}}"\
|
525 |
-
"{% elif message['role'] == 'assistant' %}"\
|
526 |
-
"{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}"\
|
527 |
-
"{% else %}"\
|
528 |
-
"{{'<|' + message['role'] + '|>\n' + message['content'] + '<|end|>\n'}}"\
|
529 |
-
"{% endif %}"\
|
530 |
-
"{% endfor %}"\
|
531 |
-
"{% if add_generation_prompt %}"\
|
532 |
-
"{{ '<|assistant|>\n' }}"\
|
533 |
-
"{% endif %}"
|
534 |
-
pass
|
535 |
-
|
536 |
-
# Ollama from https://www.ollama.com/library/phi3
|
537 |
-
phi3_ollama = \
|
538 |
-
'''
|
539 |
-
FROM {__FILE_LOCATION__}
|
540 |
-
TEMPLATE """{{ if .System }}<|system|>
|
541 |
-
{{ .System }}<|end|>
|
542 |
-
{{ end }}{{ if .Prompt }}<|user|>
|
543 |
-
{{ .Prompt }}<|end|>
|
544 |
-
{{ end }}<|assistant|>
|
545 |
-
{{ .Response }}<|end|>
|
546 |
-
"""
|
547 |
-
PARAMETER stop "<|end|>"
|
548 |
-
PARAMETER stop "<|user|>"
|
549 |
-
PARAMETER stop "<|assistant|>"
|
550 |
-
PARAMETER temperature 1.5
|
551 |
-
PARAMETER min_p 0.1
|
552 |
-
'''
|
553 |
-
|
554 |
-
phi3_template_eos_token = "<|end|>"
|
555 |
-
CHAT_TEMPLATES["phi-3"] = (phi3_template, phi3_template_eos_token, False, phi3_ollama,)
|
556 |
-
DEFAULT_SYSTEM_MESSAGE["phi-3"] = None # No system message in Phi-3
|
557 |
-
|
558 |
-
CHAT_TEMPLATES["phi-35"] = CHAT_TEMPLATES["phi-3"]
|
559 |
-
DEFAULT_SYSTEM_MESSAGE["phi-35"] = None # No system message in Phi-3.5
|
560 |
-
|
561 |
-
CHAT_TEMPLATES["phi-3.5"] = CHAT_TEMPLATES["phi-3"]
|
562 |
-
DEFAULT_SYSTEM_MESSAGE["phi-3.5"] = None # No system message in Phi-3.5
|
563 |
-
pass
|
564 |
-
|
565 |
-
# =========================================== Llama-3.1
|
566 |
-
"""
|
567 |
-
No trimming in Llama 3.1 Instruct!
|
568 |
-
Also an extra newline for Cutting Knowledge Date
|
569 |
-
See https://colab.research.google.com/drive/1Xpqq5xpIgO-B00MQ-UccYMwN2J8QFgBM?usp=sharing
|
570 |
-
|
571 |
-
Also should be
|
572 |
-
|
573 |
-
import datetime
|
574 |
-
tokenizer.apply_chat_template(
|
575 |
-
messages,
|
576 |
-
add_generation_prompt = True,
|
577 |
-
tokenize = False,
|
578 |
-
date_string = datetime.date.today().strftime("%d %B %Y"),
|
579 |
-
)
|
580 |
-
"""
|
581 |
-
|
582 |
-
llama31_template = \
|
583 |
-
"""{{- bos_token }}
|
584 |
-
{%- if custom_tools is defined %}
|
585 |
-
{%- set tools = custom_tools %}
|
586 |
-
{%- endif %}
|
587 |
-
{%- if not tools_in_user_message is defined %}
|
588 |
-
{%- set tools_in_user_message = true %}
|
589 |
-
{%- endif %}
|
590 |
-
{%- if not date_string is defined %}
|
591 |
-
{%- set date_string = "26 July 2024" %}
|
592 |
-
{%- endif %}
|
593 |
-
{%- if not tools is defined %}
|
594 |
-
{%- set tools = none %}
|
595 |
-
{%- endif %}
|
596 |
-
|
597 |
-
{#- This block extracts the system message, so we can slot it into the right place. #}
|
598 |
-
{%- if messages[0]['role'] == 'system' %}
|
599 |
-
{%- set system_message = messages[0]['content'] %}
|
600 |
-
{%- set messages = messages[1:] %}
|
601 |
-
{%- else %}
|
602 |
-
{%- set system_message = "{system_message}" %}
|
603 |
-
{%- endif %}
|
604 |
-
|
605 |
-
{#- System message + builtin tools #}
|
606 |
-
{{- "<|start_header_id|>system<|end_header_id|>\n\n" }}
|
607 |
-
{%- if builtin_tools is defined or tools is not none %}
|
608 |
-
{{- "Environment: ipython\n" }}
|
609 |
-
{%- endif %}
|
610 |
-
{%- if builtin_tools is defined %}
|
611 |
-
{{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\n\n"}}
|
612 |
-
{%- endif %}
|
613 |
-
{{- "Cutting Knowledge Date: December 2023\n" }}
|
614 |
-
{{- "Today Date: " + date_string + "\n\n" }}
|
615 |
-
{%- if tools is not none and not tools_in_user_message %}
|
616 |
-
{{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}
|
617 |
-
{{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
|
618 |
-
{{- "Do not use variables.\n\n" }}
|
619 |
-
{%- for t in tools %}
|
620 |
-
{{- t | tojson(indent=4) }}
|
621 |
-
{{- "\n\n" }}
|
622 |
-
{%- endfor %}
|
623 |
-
{%- endif %}
|
624 |
-
{{- system_message }}
|
625 |
-
{{- "<|eot_id|>" }}
|
626 |
-
|
627 |
-
{#- Custom tools are passed in a user message with some extra guidance #}
|
628 |
-
{%- if tools_in_user_message and not tools is none %}
|
629 |
-
{#- Extract the first user message so we can plug it in here #}
|
630 |
-
{%- if messages | length != 0 %}
|
631 |
-
{%- set first_user_message = messages[0]['content'] %}
|
632 |
-
{%- set messages = messages[1:] %}
|
633 |
-
{%- else %}
|
634 |
-
{{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
|
635 |
-
{%- endif %}
|
636 |
-
{{- '<|start_header_id|>user<|end_header_id|>\n\n' -}}
|
637 |
-
{{- "Given the following functions, please respond with a JSON for a function call " }}
|
638 |
-
{{- "with its proper arguments that best answers the given prompt.\n\n" }}
|
639 |
-
{{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
|
640 |
-
{{- "Do not use variables.\n\n" }}
|
641 |
-
{%- for t in tools %}
|
642 |
-
{{- t | tojson(indent=4) }}
|
643 |
-
{{- "\n\n" }}
|
644 |
-
{%- endfor %}
|
645 |
-
{{- first_user_message + "<|eot_id|>"}}
|
646 |
-
{%- endif %}
|
647 |
-
|
648 |
-
{%- for message in messages %}
|
649 |
-
{%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
|
650 |
-
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] + '<|eot_id|>' }}
|
651 |
-
{%- elif 'tool_calls' in message %}
|
652 |
-
{%- if not message.tool_calls|length == 1 %}
|
653 |
-
{{- raise_exception("This model only supports single tool-calls at once!") }}
|
654 |
-
{%- endif %}
|
655 |
-
{%- set tool_call = message.tool_calls[0].function %}
|
656 |
-
{%- if builtin_tools is defined and tool_call.name in builtin_tools %}
|
657 |
-
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
|
658 |
-
{{- "<|python_tag|>" + tool_call.name + ".call(" }}
|
659 |
-
{%- for arg_name, arg_val in tool_call.arguments | items %}
|
660 |
-
{{- arg_name + '="' + arg_val + '"' }}
|
661 |
-
{%- if not loop.last %}
|
662 |
-
{{- ", " }}
|
663 |
-
{%- endif %}
|
664 |
-
{%- endfor %}
|
665 |
-
{{- ")" }}
|
666 |
-
{%- else %}
|
667 |
-
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
|
668 |
-
{{- '{"name": "' + tool_call.name + '", ' }}
|
669 |
-
{{- '"parameters": ' }}
|
670 |
-
{{- tool_call.arguments | tojson }}
|
671 |
-
{{- "}" }}
|
672 |
-
{%- endif %}
|
673 |
-
{%- if builtin_tools is defined %}
|
674 |
-
{#- This means we're in ipython mode #}
|
675 |
-
{{- "<|eom_id|>" }}
|
676 |
-
{%- else %}
|
677 |
-
{{- "<|eot_id|>" }}
|
678 |
-
{%- endif %}
|
679 |
-
{%- elif message.role == "tool" or message.role == "ipython" %}
|
680 |
-
{{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }}
|
681 |
-
{%- if message.content is mapping or message.content is iterable %}
|
682 |
-
{{- message.content | tojson }}
|
683 |
-
{%- else %}
|
684 |
-
{{- message.content }}
|
685 |
-
{%- endif %}
|
686 |
-
{{- "<|eot_id|>" }}
|
687 |
-
{%- endif %}
|
688 |
-
{%- endfor %}
|
689 |
-
{%- if add_generation_prompt %}
|
690 |
-
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
|
691 |
-
{%- endif %}
|
692 |
-
"""
|
693 |
-
pass
|
694 |
-
|
695 |
-
# Ollama from https://ollama.com/library/llama3.1 (needs updating!)
|
696 |
-
llama31_ollama = \
|
697 |
-
'''
|
698 |
-
FROM {__FILE_LOCATION__}
|
699 |
-
TEMPLATE """{{ if .Messages }}
|
700 |
-
{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
|
701 |
-
{{- if .System }}
|
702 |
-
|
703 |
-
{{ .System }}
|
704 |
-
{{- end }}
|
705 |
-
{{- if .Tools }}
|
706 |
-
|
707 |
-
You are a helpful assistant with tool calling capabilities. When you receive a tool call response, use the output to format an answer to the original user question.
|
708 |
-
{{- end }}
|
709 |
-
{{- end }}<|eot_id|>
|
710 |
-
{{- range $i, $_ := .Messages }}
|
711 |
-
{{- $last := eq (len (slice $.Messages $i)) 1 }}
|
712 |
-
{{- if eq .Role "user" }}<|start_header_id|>user<|end_header_id|>
|
713 |
-
{{- if and $.Tools $last }}
|
714 |
-
|
715 |
-
Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.
|
716 |
-
|
717 |
-
Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.
|
718 |
-
|
719 |
-
{{ $.Tools }}
|
720 |
-
{{- end }}
|
721 |
-
|
722 |
-
{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
|
723 |
-
|
724 |
-
{{ end }}
|
725 |
-
{{- else if eq .Role "assistant" }}<|start_header_id|>assistant<|end_header_id|>
|
726 |
-
{{- if .ToolCalls }}
|
727 |
-
|
728 |
-
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }}
|
729 |
-
{{- else }}
|
730 |
-
|
731 |
-
{{ .Content }}{{ if not $last }}<|eot_id|>{{ end }}
|
732 |
-
{{- end }}
|
733 |
-
{{- else if eq .Role "tool" }}<|start_header_id|>ipython<|end_header_id|>
|
734 |
-
|
735 |
-
{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
|
736 |
-
|
737 |
-
{{ end }}
|
738 |
-
{{- end }}
|
739 |
-
{{- end }}
|
740 |
-
{{- else }}
|
741 |
-
{{- if .System }}<|start_header_id|>system<|end_header_id|>
|
742 |
-
|
743 |
-
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
|
744 |
-
|
745 |
-
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
|
746 |
-
|
747 |
-
{{ end }}{{ .Response }}{{ if .Response }}<|eot_id|>{{ end }}"""
|
748 |
-
PARAMETER stop "<|start_header_id|>"
|
749 |
-
PARAMETER stop "<|end_header_id|>"
|
750 |
-
PARAMETER stop "<|eot_id|>"
|
751 |
-
PARAMETER stop "<|eom_id|>"
|
752 |
-
PARAMETER temperature 1.5
|
753 |
-
PARAMETER min_p 0.1
|
754 |
-
'''
|
755 |
-
|
756 |
-
llama31_template_eos_token = "eos_token"
|
757 |
-
CHAT_TEMPLATES["llama-3.1"] = (llama31_template, llama31_template_eos_token, False, llama31_ollama,)
|
758 |
-
DEFAULT_SYSTEM_MESSAGE["llama-3.1"] = "" # Llama3.1 default system message is empty + the dates
|
759 |
-
|
760 |
-
CHAT_TEMPLATES["llama-31"] = (llama31_template, llama31_template_eos_token, False, llama31_ollama,)
|
761 |
-
DEFAULT_SYSTEM_MESSAGE["llama-31"] = "" # Llama3.1 default system message is empty + the dates
|
762 |
-
pass
|
763 |
-
|
764 |
-
|
765 |
-
# =========================================== Qwen 2.5
|
766 |
-
qwen25_template = \
|
767 |
-
"""{%- if tools %}
|
768 |
-
{{- \'<|im_start|>system\\n\' }}
|
769 |
-
{%- if messages[0][\'role\'] == \'system\' %}
|
770 |
-
{{- messages[0][\'content\'] }}
|
771 |
-
{%- else %}
|
772 |
-
{{- \'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\' }}
|
773 |
-
{%- endif %}
|
774 |
-
{{- "\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>" }}
|
775 |
-
{%- for tool in tools %}
|
776 |
-
{{- "\\n" }}
|
777 |
-
{{- tool | tojson }}
|
778 |
-
{%- endfor %}
|
779 |
-
{{- "\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\"name\\": <function-name>, \\"arguments\\": <args-json-object>}\\n</tool_call><|im_end|>\\n" }}\n{%- else %}
|
780 |
-
{%- if messages[0][\'role\'] == \'system\' %}
|
781 |
-
{{- \'<|im_start|>system\\n\' + messages[0][\'content\'] + \'<|im_end|>\\n\' }}
|
782 |
-
{%- else %}
|
783 |
-
{{- \'<|im_start|>system\\n{system_message}<|im_end|>\\n\' }}
|
784 |
-
{%- endif %}\n{%- endif %}\n{%- for message in messages %}
|
785 |
-
{%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
|
786 |
-
{{- \'<|im_start|>\' + message.role + \'\\n\' + message.content + \'<|im_end|>\' + \'\\n\' }}
|
787 |
-
{%- elif message.role == "assistant" %}
|
788 |
-
{{- \'<|im_start|>\' + message.role }}
|
789 |
-
{%- if message.content %}
|
790 |
-
{{- \'\\n\' + message.content }}
|
791 |
-
{%- endif %}
|
792 |
-
{%- for tool_call in message.tool_calls %}
|
793 |
-
{%- if tool_call.function is defined %}
|
794 |
-
{%- set tool_call = tool_call.function %}
|
795 |
-
{%- endif %}
|
796 |
-
{{- \'\\n<tool_call>\\n{"name": "\' }}
|
797 |
-
{{- tool_call.name }}
|
798 |
-
{{- \'", "arguments": \' }}
|
799 |
-
{{- tool_call.arguments | tojson }}
|
800 |
-
{{- \'}\\n</tool_call>\' }}
|
801 |
-
{%- endfor %}
|
802 |
-
{{- \'<|im_end|>\\n\' }}
|
803 |
-
{%- elif message.role == "tool" %}
|
804 |
-
{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %} {{- \'<|im_start|>user\' }}
|
805 |
-
{%- endif %}
|
806 |
-
{{- \'\\n<tool_response>\\n\' }}
|
807 |
-
{{- message.content }}
|
808 |
-
{{- \'\\n</tool_response>\' }}
|
809 |
-
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
810 |
-
{{- \'<|im_end|>\\n\' }}
|
811 |
-
{%- endif %}
|
812 |
-
{%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}
|
813 |
-
{{- \'<|im_start|>assistant\\n\' }}
|
814 |
-
{%- endif %}
|
815 |
-
"""
|
816 |
-
|
817 |
-
|
818 |
-
# Ollama from https://ollama.com/library/qwen2.5/blobs/eb4402837c78
|
819 |
-
qwen25_ollama = \
|
820 |
-
'''
|
821 |
-
FROM {__FILE_LOCATION__}
|
822 |
-
TEMPLATE """{{- if .Messages }}
|
823 |
-
{{- if or .System .Tools }}<|im_start|>system
|
824 |
-
{{- if .System }}
|
825 |
-
{{ .System }}
|
826 |
-
{{- end }}
|
827 |
-
{{- if .Tools }}
|
828 |
-
|
829 |
-
# Tools
|
830 |
-
|
831 |
-
You may call one or more functions to assist with the user query.
|
832 |
-
|
833 |
-
You are provided with function signatures within <tools></tools> XML tags:
|
834 |
-
<tools>
|
835 |
-
{{- range .Tools }}
|
836 |
-
{"type": "function", "function": {{ .Function }}}
|
837 |
-
{{- end }}
|
838 |
-
</tools>
|
839 |
-
|
840 |
-
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
841 |
-
<tool_call>
|
842 |
-
{"name": <function-name>, "arguments": <args-json-object>}
|
843 |
-
</tool_call>
|
844 |
-
{{- end }}<|im_end|>
|
845 |
-
{{ end }}
|
846 |
-
{{- range $i, $_ := .Messages }}
|
847 |
-
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
848 |
-
{{- if eq .Role "user" }}<|im_start|>user
|
849 |
-
{{ .Content }}<|im_end|>
|
850 |
-
{{ else if eq .Role "assistant" }}<|im_start|>assistant
|
851 |
-
{{ if .Content }}{{ .Content }}
|
852 |
-
{{- else if .ToolCalls }}<tool_call>
|
853 |
-
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
854 |
-
{{ end }}</tool_call>
|
855 |
-
{{- end }}{{ if not $last }}<|im_end|>
|
856 |
-
{{ end }}
|
857 |
-
{{- else if eq .Role "tool" }}<|im_start|>user
|
858 |
-
<tool_response>
|
859 |
-
{{ .Content }}
|
860 |
-
</tool_response><|im_end|>
|
861 |
-
{{ end }}
|
862 |
-
{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
|
863 |
-
{{ end }}
|
864 |
-
{{- end }}
|
865 |
-
{{- else }}
|
866 |
-
{{- if .System }}<|im_start|>system
|
867 |
-
{{ .System }}<|im_end|>
|
868 |
-
{{ end }}{{ if .Prompt }}<|im_start|>user
|
869 |
-
{{ .Prompt }}<|im_end|>
|
870 |
-
{{ end }}<|im_start|>assistant
|
871 |
-
{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}"""
|
872 |
-
PARAMETER stop "<|im_end|>"
|
873 |
-
PARAMETER stop "<|endoftext|>"
|
874 |
-
PARAMETER temperature 1.5
|
875 |
-
PARAMETER min_p 0.1
|
876 |
-
'''
|
877 |
-
|
878 |
-
qwen25_template_eos_token = "eos_token"
|
879 |
-
qwen25_default_system_message = "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."
|
880 |
-
CHAT_TEMPLATES["qwen-2.5"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
|
881 |
-
DEFAULT_SYSTEM_MESSAGE["qwen-2.5"] = qwen25_default_system_message # No system message in Qwen 2.5
|
882 |
-
|
883 |
-
CHAT_TEMPLATES["qwen-25"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
|
884 |
-
DEFAULT_SYSTEM_MESSAGE["qwen-25"] = qwen25_default_system_message # No system message in Qwen 2.5
|
885 |
-
|
886 |
-
CHAT_TEMPLATES["qwen25"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
|
887 |
-
DEFAULT_SYSTEM_MESSAGE["qwen25"] = qwen25_default_system_message # No system message in Qwen 2.5
|
888 |
-
|
889 |
-
CHAT_TEMPLATES["qwen2.5"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
|
890 |
-
DEFAULT_SYSTEM_MESSAGE["qwen2.5"] = qwen25_default_system_message # No system message in Qwen 2.5
|
891 |
-
pass
|
892 |
-
|
893 |
-
def _change_system_message(template: str, type_chat_template: str, system_message: str = None):
|
894 |
-
system_message_pattern = r"\{system_message\}"
|
895 |
-
|
896 |
-
# For predefined templates, check if default system message exists
|
897 |
-
default_system_message = DEFAULT_SYSTEM_MESSAGE.get(f"{type_chat_template}", None)
|
898 |
-
if default_system_message is None:
|
899 |
-
if system_message is not None:
|
900 |
-
logger.warning_once(
|
901 |
-
f"Unsloth: You tried to change the system message for {type_chat_template}, "
|
902 |
-
"but it doesn't have a default system message. "
|
903 |
-
"You need to manually add the system message in your data."
|
904 |
-
)
|
905 |
-
return template, system_message
|
906 |
-
pass
|
907 |
-
|
908 |
-
# For custom templates
|
909 |
-
if type_chat_template is None:
|
910 |
-
has_placeholder = re.search(system_message_pattern, template) is not None
|
911 |
-
|
912 |
-
if has_placeholder:
|
913 |
-
if system_message is None:
|
914 |
-
raise ValueError("Unsloth: You need to provide a system message for custom templates.")
|
915 |
-
new_template = re.sub(system_message_pattern, system_message, template)
|
916 |
-
return new_template, system_message
|
917 |
-
|
918 |
-
return template, system_message
|
919 |
-
pass
|
920 |
-
|
921 |
-
# For predefined templates with default system message
|
922 |
-
message_to_use = system_message if system_message is not None else default_system_message
|
923 |
-
new_template = re.sub(system_message_pattern, message_to_use, template)
|
924 |
-
|
925 |
-
return new_template, message_to_use
|
926 |
-
pass
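A hedged illustration (not part of the original file) of what this helper does for a predefined template that has a default system message:

new_template, message = _change_system_message(alpaca_template, "alpaca")
# message == DEFAULT_SYSTEM_MESSAGE["alpaca"], and the "{system_message}" placeholder
# inside alpaca_template has been replaced by that text. Passing a third argument,
# e.g. system_message = "Answer concisely.", substitutes that custom message instead.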
|
927 |
-
|
928 |
-
|
929 |
-
def get_chat_template(
|
930 |
-
tokenizer,
|
931 |
-
chat_template = "chatml",
|
932 |
-
mapping = {"role" : "role", "content" : "content", "user" : "user", "assistant" : "assistant"},
|
933 |
-
map_eos_token = True,
|
934 |
-
system_message = None,
|
935 |
-
):
|
936 |
-
assert(type(map_eos_token) is bool)
|
937 |
-
old_tokenizer = tokenizer
|
938 |
-
|
939 |
-
IS_GEMMA = False
|
940 |
-
if tokenizer.__class__.__name__.startswith("Gemma"):
|
941 |
-
if chat_template == "chatml": chat_template = "gemma_chatml"
|
942 |
-
IS_GEMMA = True
|
943 |
-
pass
|
944 |
-
|
945 |
-
# We add a check for Llama-3
|
946 |
-
# if chat_template == "llama-3":
|
947 |
-
# tokenizer._using_llama3_template = True
|
948 |
-
# else:
|
949 |
-
# llama3_tokens = set(["<|end_header_id|>", "<|eot_id|>", "<|start_header_id|>"])
|
950 |
-
# check_llama3_tokens = llama3_tokens & set(str(x) for x in tokenizer.added_tokens_decoder.values())
|
951 |
-
# if len(check_llama3_tokens) == len(llama3_tokens):
|
952 |
-
# tokenizer._using_llama3_template = True
|
953 |
-
# pass
|
954 |
-
# pass
|
955 |
-
|
956 |
-
# We first check if the tokenizer is a fast one. If not, we cannot convert this!
|
957 |
-
is_fast_tokenizer = getattr(tokenizer, "is_fast", False)
|
958 |
-
old_padding_side = tokenizer.padding_side
|
959 |
-
|
960 |
-
same_padding_token = False
|
961 |
-
type_chat_template = None
|
962 |
-
|
963 |
-
if type(chat_template) in (list, tuple,):
|
964 |
-
# For changing system message later
|
965 |
-
# Since it's not supported yet, we will raise an error first!
|
966 |
-
type_chat_template = chat_template[0].lower()
|
967 |
-
chat_template, stop_word = chat_template
|
968 |
-
assert(type(chat_template) is str)
|
969 |
-
assert(type(stop_word) is str)
|
970 |
-
ollama_modelfile = None
|
971 |
-
|
972 |
-
elif type(chat_template) is str:
|
973 |
-
# For changing system message later
|
974 |
-
type_chat_template = chat_template.lower()
|
975 |
-
|
976 |
-
chat_template, stop_word, yes_map_eos_token, ollama_modelfile = CHAT_TEMPLATES[chat_template]
|
977 |
-
|
978 |
-
# Check mapping to eos_token
|
979 |
-
if not map_eos_token and yes_map_eos_token: map_eos_token = True
|
980 |
-
if not yes_map_eos_token and map_eos_token: map_eos_token = False
|
981 |
-
|
982 |
-
if type(stop_word) in (list, tuple,):
|
983 |
-
token_mapping, stop_word = stop_word
|
984 |
-
assert(type(token_mapping) is dict)
|
985 |
-
else:
|
986 |
-
token_mapping = None
|
987 |
-
|
988 |
-
assert(type(stop_word) is str)
|
989 |
-
|
990 |
-
# Check fast tokenizer
|
991 |
-
if not is_fast_tokenizer:
|
992 |
-
print(
|
993 |
-
"Unsloth: Not a fast tokenizer, so can't process it as of yet :(\n"\
|
994 |
-
"Please log a Github issue if you want this as a new feature!\n"\
|
995 |
-
"Your chat template will still work, but it won't add or edit tokens."
|
996 |
-
)
|
997 |
-
|
998 |
-
elif token_mapping is not None:
|
999 |
-
# token_mapping = {"<start_of_turn>" : "<|im_start|>", "<end_of_turn>" : "<|im_end|>"}
|
1000 |
-
# For Gemma :)
|
1001 |
-
|
1002 |
-
string_vocab = tokenizer._tokenizer.to_str()
|
1003 |
-
|
1004 |
-
skipped = 0
|
1005 |
-
for old_token, new_token in token_mapping.items():
|
1006 |
-
old_count = string_vocab.count(f'"{old_token}"')
|
1007 |
-
new_count = string_vocab.count(f'"{new_token}"')
|
1008 |
-
if new_count != 0:
|
1009 |
-
print(f"{new_token} is already a token. Skipping.")
|
1010 |
-
skipped += 1
|
1011 |
-
elif old_count == 0:
|
1012 |
-
raise RuntimeError(f"{old_token} was not part of the tokenizer!")
|
1013 |
-
else:
|
1014 |
-
string_vocab = string_vocab.replace(f'"{old_token}"', f'"{new_token}"')
|
1015 |
-
pass
|
1016 |
-
pass
|
1017 |
-
|
1018 |
-
if map_eos_token and (not stop_word in token_mapping.values()):
|
1019 |
-
# Do not map 107 = <|im_end|> and 1 = <|im_end|>. This will reduce the vocab size by 1
|
1020 |
-
logger.warning_once(f"Unsloth: Will map {stop_word} to EOS = {tokenizer.eos_token}.")
|
1021 |
-
string_vocab = string_vocab.replace(tokenizer.eos_token, stop_word)
|
1022 |
-
pass
|
1023 |
-
|
1024 |
-
if skipped != len(token_mapping):
|
1025 |
-
new_tokenizer = tokenizer._tokenizer.from_str(string_vocab)
|
1026 |
-
|
1027 |
-
# Careful on pad_token
|
1028 |
-
old_pad_token = tokenizer.pad_token
|
1029 |
-
if old_pad_token == tokenizer.eos_token:
|
1030 |
-
old_pad_token = stop_word
|
1031 |
-
same_padding_token = True
|
1032 |
-
pass
|
1033 |
-
|
1034 |
-
if map_eos_token:
|
1035 |
-
new_tokenizer = tokenizer.__class__(
|
1036 |
-
tokenizer_object = new_tokenizer,
|
1037 |
-
eos_token = stop_word,
|
1038 |
-
pad_token = old_pad_token,
|
1039 |
-
)
|
1040 |
-
else:
|
1041 |
-
new_tokenizer = tokenizer.__class__(
|
1042 |
-
tokenizer_object = new_tokenizer,
|
1043 |
-
pad_token = old_pad_token,
|
1044 |
-
)
|
1045 |
-
pass
|
1046 |
-
|
1047 |
-
# Must fix the sentence piece tokenizer since there's no tokenizer.model file!
|
1048 |
-
tokenizer = fix_sentencepiece_tokenizer(tokenizer, new_tokenizer, token_mapping,)
|
1049 |
-
else:
|
1050 |
-
pass
|
1051 |
-
|
1052 |
-
elif map_eos_token and (stop_word != "eos_token"):
|
1053 |
-
logger.warning_once(f"Unsloth: Will map {stop_word} to EOS = {tokenizer.eos_token}.")
|
1054 |
-
|
1055 |
-
# Replaces the old EOS token with a new one.
|
1056 |
-
# Useful for ChatML <|im_end|> for example.
|
1057 |
-
# Usually we train 2 more tokens <|im_start|> and <|im_end|>
|
1058 |
-
# But training the lm_head and embeddings is slow!
|
1059 |
-
# This is a HACK!
|
1060 |
-
# Idea from https://huggingface.co/cognitivecomputations/dolphin-2.6-mistral-7b-dpo-laser
|
1061 |
-
|
1062 |
-
old_bos_token = getattr(tokenizer, "bos_token", None)
|
1063 |
-
old_eos_token = getattr(tokenizer, "eos_token", None)
|
1064 |
-
old_pad_token = getattr(tokenizer, "pad_token", None)
|
1065 |
-
old_unk_token = getattr(tokenizer, "unk_token", None)
|
1066 |
-
|
1067 |
-
string_vocab = tokenizer._tokenizer.to_str()
|
1068 |
-
# First check if new stop_word is in the tokenizer
|
1069 |
-
if stop_word in string_vocab:
|
1070 |
-
# We shall swap them around
|
1071 |
-
temporary_stop_token = "<|:__TEMP//STOP//TOKEN__:|>"
|
1072 |
-
string_vocab = string_vocab.replace(old_eos_token, temporary_stop_token)
|
1073 |
-
string_vocab = string_vocab.replace(stop_word, old_eos_token)
|
1074 |
-
string_vocab = string_vocab.replace(temporary_stop_token, stop_word)
|
1075 |
-
else:
|
1076 |
-
string_vocab = string_vocab.replace(old_eos_token, stop_word)
|
1077 |
-
pass
|
1078 |
-
new_tokenizer = tokenizer._tokenizer.from_str(string_vocab)
|
1079 |
-
|
1080 |
-
# Careful on pad_token
|
1081 |
-
if old_pad_token == old_eos_token:
|
1082 |
-
old_pad_token = stop_word
|
1083 |
-
same_padding_token = True
|
1084 |
-
pass
|
1085 |
-
|
1086 |
-
new_tokenizer = tokenizer.__class__(
|
1087 |
-
tokenizer_object = new_tokenizer,
|
1088 |
-
bos_token = old_bos_token,
|
1089 |
-
eos_token = stop_word,
|
1090 |
-
unk_token = old_unk_token,
|
1091 |
-
pad_token = old_pad_token,
|
1092 |
-
)
|
1093 |
-
|
1094 |
-
# Must fix the sentence piece tokenizer since there's no tokenizer.model file!
|
1095 |
-
token_mapping = { old_eos_token : stop_word, }
|
1096 |
-
tokenizer = fix_sentencepiece_tokenizer(tokenizer, new_tokenizer, token_mapping,)
|
1097 |
-
pass
|
1098 |
-
|
1099 |
-
else:
|
1100 |
-
raise TypeError(
|
1101 |
-
f"Unsloth: `chat_template` must be a tuple of (your_template, eos_token,) or one of\n"\
|
1102 |
-
f"{CHAT_TEMPLATES.keys()}"
|
1103 |
-
)
|
1104 |
-
pass
|
1105 |
-
|
1106 |
-
# Careful on Gemma
|
1107 |
-
# bos_token is a must or else losses become too high
|
1108 |
-
if IS_GEMMA and not chat_template.startswith(("{{ bos_token }}", "{{- bos_token }}")):
|
1109 |
-
chat_template = "{{ bos_token }}" + chat_template
|
1110 |
-
pass
|
1111 |
-
|
1112 |
-
# For ShareGPT role -> from and content -> value
|
1113 |
-
new_chat_template = chat_template\
|
1114 |
-
.replace("'role'", "'" + mapping["role"] + "'")\
|
1115 |
-
.replace("'content'", "'" + mapping["content"] + "'")\
|
1116 |
-
.replace("'user'", "'" + mapping["user"] + "'")\
|
1117 |
-
.replace("'assistant'", "'" + mapping["assistant"] + "'")
|
1118 |
-
|
1119 |
-
_, tokenizer = patch_tokenizer(model = None, tokenizer = tokenizer)
|
1120 |
-
tokenizer.padding_side = old_padding_side
|
1121 |
-
|
1122 |
-
# If not normal HF, we add a check to make old templates work
|
1123 |
-
if mapping != {"role" : "role", "content" : "content", "user" : "user", "assistant" : "assistant"}:
|
1124 |
-
chat_template = \
|
1125 |
-
"{% if 'role' in messages[0] %}" + \
|
1126 |
-
chat_template + \
|
1127 |
-
"{% else %}" + \
|
1128 |
-
new_chat_template + \
|
1129 |
-
"{% endif %}"
|
1130 |
-
else:
|
1131 |
-
chat_template = new_chat_template
|
1132 |
-
pass
|
1133 |
-
|
1134 |
-
chat_template, system_message = _change_system_message(chat_template, type_chat_template, system_message)
|
1135 |
-
|
1136 |
-
tokenizer.chat_template = chat_template
|
1137 |
-
|
1138 |
-
# Also fix up other tokens
|
1139 |
-
old_pad_token = getattr(old_tokenizer, "pad_token", None)
|
1140 |
-
old_bos_token = getattr(old_tokenizer, "bos_token", None)
|
1141 |
-
old_unk_token = getattr(old_tokenizer, "unk_token", None)
|
1142 |
-
new_pad_token = getattr(tokenizer, "pad_token", None)
|
1143 |
-
new_bos_token = getattr(tokenizer, "bos_token", None)
|
1144 |
-
new_unk_token = getattr(tokenizer, "unk_token", None)
|
1145 |
-
if old_bos_token != new_bos_token: tokenizer.bos_token = old_bos_token
|
1146 |
-
if old_unk_token != new_unk_token: tokenizer.unk_token = old_unk_token
|
1147 |
-
if not same_padding_token:
|
1148 |
-
if old_pad_token != new_pad_token: tokenizer.pad_token = old_pad_token
|
1149 |
-
pass
|
1150 |
-
|
1151 |
-
# stopping_criteria = create_stopping_criteria(tokenizer, stop_word)
|
1152 |
-
|
1153 |
-
# Patch saving functions
|
1154 |
-
tokenizer = patch_saving_functions(tokenizer)
|
1155 |
-
|
1156 |
-
# Add Ollama
|
1157 |
-
tokenizer._ollama_modelfile = ollama_modelfile
|
1158 |
-
tokenizer._system_message = system_message
|
1159 |
-
return tokenizer#, stopping_criteria
|
1160 |
-
pass
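A hedged usage sketch (not part of the original file); the model name and the ShareGPT-style mapping are illustrative, not prescribed by this function.

from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/llama-3-8b-bnb-4bit", max_seq_length = 2048, load_in_4bit = True,
)
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "chatml",   # any key in CHAT_TEMPLATES, or a (template, eos_token) tuple
    mapping = {"role" : "from", "content" : "value", "user" : "human", "assistant" : "gpt"},
    map_eos_token = True,       # maps <|im_end|> onto the tokenizer's EOS token
)
text = tokenizer.apply_chat_template(
    [{"from" : "human", "value" : "Hello!"}],
    tokenize = False, add_generation_prompt = True,
)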
|
1161 |
-
|
1162 |
-
|
1163 |
-
def remove_special_tokens(tokenizer, prompt):
|
1164 |
-
# Removes a leading BOS token so the chat template does not produce a double BOS
|
1165 |
-
if prompt.startswith(tokenizer.bos_token):
|
1166 |
-
prompt = prompt[len(tokenizer.bos_token):]
|
1167 |
-
pass
|
1168 |
-
return prompt
|
1169 |
-
pass
|
1170 |
-
|
1171 |
-
|
1172 |
-
def _parse_combined_prompt(combined_prompt, dataset):
|
1173 |
-
# Find {...}
|
1174 |
-
possible_columns = re.findall(r"\{(.+?)\}", combined_prompt)
|
1175 |
-
dataset_columns = set(dataset.column_names)
|
1176 |
-
for column in possible_columns:
|
1177 |
-
if column not in dataset_columns:
|
1178 |
-
raise KeyError(
|
1179 |
-
f"Unsloth: Your prompt includes '{column}' but this does not exist in the dataset. "\
|
1180 |
-
f"Only allowed columns are {list(dataset_columns)}"
|
1181 |
-
)
|
1182 |
-
pass
|
1183 |
-
pass
|
1184 |
-
|
1185 |
-
# Find [[...]]
|
1186 |
-
optional_prompts = list(re.finditer(r"\[\[.+?\]\]", combined_prompt, flags = re.DOTALL | re.MULTILINE))
|
1187 |
-
optional_prompts = [(x.span(), x.group(0)) for x in optional_prompts]
|
1188 |
-
|
1189 |
-
final_optional_prompts = []
|
1190 |
-
if len(optional_prompts) != 0:
|
1191 |
-
# Add left
|
1192 |
-
left = optional_prompts[0]
|
1193 |
-
l = left[0][0]
|
1194 |
-
if l != 0: final_optional_prompts.append(combined_prompt[:l])
|
1195 |
-
|
1196 |
-
# Add in between
|
1197 |
-
for left, right in zip(optional_prompts[:-1], optional_prompts[1:]):
|
1198 |
-
l, r = left[0][-1], right[0][0]
|
1199 |
-
final_optional_prompts.append(left)
|
1200 |
-
if l != r: final_optional_prompts.append(combined_prompt[l : r])
|
1201 |
-
pass
|
1202 |
-
final_optional_prompts.append(optional_prompts[-1])
|
1203 |
-
|
1204 |
-
# Add right
|
1205 |
-
right = optional_prompts[-1]
|
1206 |
-
r = right[0][1]
|
1207 |
-
if r != len(combined_prompt): final_optional_prompts.append(combined_prompt[r:])
|
1208 |
-
else:
|
1209 |
-
# Just add in the entire string
|
1210 |
-
final_optional_prompts.append(combined_prompt)
|
1211 |
-
pass
|
1212 |
-
|
1213 |
-
check_combined = "".join(x if type(x) is str else x[1] for x in final_optional_prompts)
|
1214 |
-
assert(combined_prompt == check_combined)
|
1215 |
-
|
1216 |
-
return possible_columns, final_optional_prompts
|
1217 |
-
pass
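A hedged example (not part of the original file) of the merged_prompt syntax this parser handles; the column names are illustrative. {column} pulls a dataset column, and [[...]] marks an optional section.

merged_prompt = "{instruction}[[\nYour input is:\n{input}]]"
columns, pieces = _parse_combined_prompt(merged_prompt, dataset)
# columns == ["instruction", "input"]; pieces splits the prompt into the mandatory
# "{instruction}" part and the optional "[[...]]" part, which _create_formatter below
# turns into a section that is dropped for rows whose "input" column is empty.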
|
1218 |
-
|
1219 |
-
|
1220 |
-
def _create_formatter(possible_columns, final_optional_prompts, user_column_name):
|
1221 |
-
# Start final prompt!
|
1222 |
-
function = ["def __combined_prompt_processor__(examples):"]
|
1223 |
-
columns = list(set(possible_columns))
|
1224 |
-
for column in columns:
|
1225 |
-
function.append(f"{' '*4}{column}__ = examples['{column}']")
|
1226 |
-
function.append(f"{' '*4}texts = []")
|
1227 |
-
function.append(f"{' '*4}for ({', '.join(columns)}) in zip({', '.join(f'{x}__' for x in columns)}):")
|
1228 |
-
|
1229 |
-
# Add optional tags as well!
|
1230 |
-
final_prompt = ""
|
1231 |
-
formatter = []
|
1232 |
-
|
1233 |
-
for j, optional_prompt in enumerate(final_optional_prompts):
|
1234 |
-
if type(optional_prompt) is str:
|
1235 |
-
columns = re.findall(r"\{(.+?)\}", optional_prompt)
|
1236 |
-
formatter += columns
|
1237 |
-
# Must escape \n \r
|
1238 |
-
final_prompt += optional_prompt.encode("unicode-escape").decode("utf-8").replace("'", "\\'").replace('"', '\\"')
|
1239 |
-
else:
|
1240 |
-
where, prompt = optional_prompt
|
1241 |
-
# Strip [[...]]
|
1242 |
-
# Must escape \n \r
|
1243 |
-
prompt = prompt[2:-2].encode("unicode-escape").decode("utf-8").replace("'", "\\'").replace('"', '\\"')
|
1244 |
-
columns = re.findall(r"\{(.+?)\}", prompt)
|
1245 |
-
x = f"__optional_{j}__"
|
1246 |
-
prompt = f"{' '*8}{x} = '{prompt}'.format({', '.join(f'{x} = {x}' for x in columns)}) if {columns[0]} else ''"
|
1247 |
-
function.append(prompt)
|
1248 |
-
formatter.append(x)
|
1249 |
-
final_prompt += "{" + x + "}"
|
1250 |
-
pass
|
1251 |
-
pass
|
1252 |
-
|
1253 |
-
function.insert(1, f"{' '*4}__combined_prompt__ = '{final_prompt}'")
|
1254 |
-
function.append(f"{' '*8}texts.append("\
|
1255 |
-
f"__combined_prompt__.format({', '.join(f'{x} = {x}' for x in formatter)}))")
|
1256 |
-
function.append(f"{' '*4}return " + "{ " + f"'{user_column_name}' : texts" + " }")
|
1257 |
-
return "\n".join(function)
|
1258 |
-
pass
|
1259 |
-
|
1260 |
-
|
1261 |
-
def to_sharegpt(
|
1262 |
-
dataset,
|
1263 |
-
merged_prompt = "",
|
1264 |
-
merged_column_name = "instruction",
|
1265 |
-
output_column_name = "output",
|
1266 |
-
remove_unused_columns = True,
|
1267 |
-
conversation_extension = 1,
|
1268 |
-
random_state = 3407,
|
1269 |
-
):
|
1270 |
-
"""
|
1271 |
-
Converts a dataset to ShareGPT style.
|
1272 |
-
ShareGPT requires only 1 input and 1 output field.
|
1273 |
-
This means multiple columns have to be merged into one input field.
|
1274 |
-
Use `conversation_extension` to increase the length of each conversation by randomly
|
1275 |
-
selecting a few and packing them into 1.
|
1276 |
-
|
1277 |
-
merged_prompt = "", Prompt to merge columns into 1 input
|
1278 |
-
merged_column_name = "instruction", Final column name for the input field
|
1279 |
-
output_column_name = "output", Final column name for the output field
|
1280 |
-
remove_unused_columns = True,
|
1281 |
-
conversation_extension = 1, Automatically combines `conversation_extension` convos into 1
|
1282 |
-
random_state = 3407,
|
1283 |
-
"""
|
1284 |
-
if "conversations" in dataset.column_names:
|
1285 |
-
convo = dataset[0]["conversations"]
|
1286 |
-
if type(convo) is list:
|
1287 |
-
raise TypeError("Unsloth: Your dataset is probably already in ShareGPT format!")
|
1288 |
-
pass
|
1289 |
-
pass
|
1290 |
-
|
1291 |
-
possible_columns, final_optional_prompts = _parse_combined_prompt(merged_prompt, dataset)
|
1292 |
-
function = _create_formatter(possible_columns, final_optional_prompts, merged_column_name)
|
1293 |
-
exec(function, globals())
|
1294 |
-
dataset = dataset.map(__combined_prompt_processor__, batched = True, desc = "Merging columns")
|
1295 |
-
|
1296 |
-
def __convert_to_sharegpt__(examples):
|
1297 |
-
users = examples[merged_column_name]
|
1298 |
-
assistants = examples[output_column_name]
|
1299 |
-
texts = [
|
1300 |
-
[
|
1301 |
-
{"from" : "human", "value" : str(user) },
|
1302 |
-
{"from" : "gpt", "value" : str(assistant)},
|
1303 |
-
] \
|
1304 |
-
for user, assistant in zip(users, assistants)
|
1305 |
-
]
|
1306 |
-
return { "conversations" : texts, }
|
1307 |
-
pass
|
1308 |
-
|
1309 |
-
dataset = dataset.map(
|
1310 |
-
__convert_to_sharegpt__,
|
1311 |
-
batched = True,
|
1312 |
-
desc = "Converting to ShareGPT",
|
1313 |
-
# Remove unused columns!
|
1314 |
-
remove_columns = dataset.column_names if remove_unused_columns else None,
|
1315 |
-
)
|
1316 |
-
|
1317 |
-
# Randomly concatenate conversations to create a long stream!
|
1318 |
-
from datasets import concatenate_datasets
|
1319 |
-
n_extensions = max(conversation_extension-1, 0)
|
1320 |
-
if n_extensions == 0: return dataset
|
1321 |
-
|
1322 |
-
dataset = dataset.rename_columns({"conversations" : "conversations0"})
|
1323 |
-
all_shuffled = [dataset]
|
1324 |
-
for j in range(1, n_extensions+1):
|
1325 |
-
shuffled = dataset.shuffle(seed = random_state+j).rename_columns({"conversations0" : f"conversations{j}"})
|
1326 |
-
all_shuffled.append(shuffled)
|
1327 |
-
pass
|
1328 |
-
dataset = concatenate_datasets(all_shuffled, axis = 1)
|
1329 |
-
|
1330 |
-
# Combine them into 1
|
1331 |
-
function = "def __combine_conversations__(examples):\n"
|
1332 |
-
n_extensions += 1
|
1333 |
-
for j in range(n_extensions):
|
1334 |
-
function += f"{' '*4}conversations{j}__ = examples['conversations{j}']\n"
|
1335 |
-
function += f"{' '*4}convos = []\n"
|
1336 |
-
function += f"{' '*4}for ({', '.join(f'conversations{j}' for j in range(n_extensions))}) "\
|
1337 |
-
f"in zip({', '.join(f'conversations{j}__' for j in range(n_extensions))}):\n"
|
1338 |
-
function += f"{' '*8}convos.append("\
|
1339 |
-
f"{'+'.join(f'conversations{j}' for j in range(n_extensions))})\n"
|
1340 |
-
function += f"{' '*4}return " + "{ " + "'conversations' : convos" + " }"
|
1341 |
-
|
1342 |
-
# Map function
|
1343 |
-
exec(function, globals())
|
1344 |
-
dataset = dataset.map(
|
1345 |
-
__combine_conversations__,
|
1346 |
-
batched = True,
|
1347 |
-
desc = "Extending conversations",
|
1348 |
-
# Remove unused columns!
|
1349 |
-
remove_columns = dataset.column_names if remove_unused_columns else None,
|
1350 |
-
)
|
1351 |
-
return dataset
|
1352 |
-
pass
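A hedged end-to-end usage sketch (not part of the original file); the dataset and its "instruction" / "input" / "output" columns are illustrative.

from datasets import load_dataset

dataset = load_dataset("yahma/alpaca-cleaned", split = "train")
dataset = to_sharegpt(
    dataset,
    merged_prompt = "{instruction}[[\nYour input is:\n{input}]]",
    output_column_name = "output",
    conversation_extension = 3,   # packs 3 randomly chosen single-turn rows into one conversation
)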
|
1353 |
-
|
1354 |
-
|
1355 |
-
def standardize_sharegpt(
|
1356 |
-
dataset,
|
1357 |
-
aliases_for_system = ["system",],
|
1358 |
-
aliases_for_user = ["user", "human", "input",],
|
1359 |
-
aliases_for_assistant = ["gpt", "assistant", "output",],
|
1360 |
-
):
|
1361 |
-
"""
|
1362 |
-
Standardizes ShareGPT and other formats to user/assistant Hugging Face format.
|
1363 |
-
|
1364 |
-
Get aliases for the system, user and assistant roles.
|
1365 |
-
These shall map to "system", "user" and "assistant" respectively.
|
1366 |
-
|
1367 |
-
aliases_for_system = ["system",],
|
1368 |
-
aliases_for_user = ["user", "human", "input",],
|
1369 |
-
aliases_for_assistant = ["gpt", "assistant", "output",],
|
1370 |
-
"""
|
1371 |
-
import collections
|
1372 |
-
import itertools
|
1373 |
-
|
1374 |
-
convos = dataset[:10]["conversations"]
|
1375 |
-
uniques = collections.defaultdict(list)
|
1376 |
-
for convo in convos:
|
1377 |
-
for message in convo:
|
1378 |
-
for key, value in message.items():
|
1379 |
-
uniques[key].append(value)
|
1380 |
-
pass
|
1381 |
-
|
1382 |
-
# Must be only 2 entries
|
1383 |
-
assert(len(uniques.keys()) == 2)
|
1384 |
-
|
1385 |
-
keys = list(uniques.keys())
|
1386 |
-
length_first = len(set(uniques[keys[0]]))
|
1387 |
-
length_second = len(set(uniques[keys[1]]))
|
1388 |
-
|
1389 |
-
if length_first < length_second:
|
1390 |
-
# Role is assigned to the first element
|
1391 |
-
role_key = keys[0]
|
1392 |
-
content_key = keys[1]
|
1393 |
-
else:
|
1394 |
-
role_key = keys[1]
|
1395 |
-
content_key = keys[0]
|
1396 |
-
pass
|
1397 |
-
|
1398 |
-
# Check roles are in aliases
|
1399 |
-
all_aliases = set(aliases_for_system + aliases_for_user + aliases_for_assistant)
|
1400 |
-
roles = set(uniques[role_key])
|
1401 |
-
leftover_aliases = (all_aliases | roles) - all_aliases
|
1402 |
-
if len(leftover_aliases) != 0:
|
1403 |
-
raise TypeError(
|
1404 |
-
f"Unsloth: {list(leftover_aliases)} are not in aliases. Please update aliases."
|
1405 |
-
)
|
1406 |
-
pass
|
1407 |
-
|
1408 |
-
# Mapping for aliases
|
1409 |
-
aliases_mapping = {}
|
1410 |
-
for x in aliases_for_system: aliases_mapping[x] = "system"
|
1411 |
-
for x in aliases_for_user: aliases_mapping[x] = "user"
|
1412 |
-
for x in aliases_for_assistant: aliases_mapping[x] = "assistant"
|
1413 |
-
|
1414 |
-
def _standardize_dataset(examples):
|
1415 |
-
convos = examples["conversations"]
|
1416 |
-
all_convos = []
|
1417 |
-
for convo in convos:
|
1418 |
-
new_convo = [
|
1419 |
-
{ "role" : aliases_mapping[message[role_key]], "content" : message[content_key], }
|
1420 |
-
for message in convo
|
1421 |
-
]
|
1422 |
-
all_convos.append(new_convo)
|
1423 |
-
pass
|
1424 |
-
return { "conversations" : all_convos, }
|
1425 |
-
pass
|
1426 |
-
|
1427 |
-
return dataset.map(_standardize_dataset, batched = True, desc = "Standardizing format")
|
1428 |
-
pass
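A hedged illustration (not part of the original file) of the normalization performed above:

dataset = standardize_sharegpt(dataset)
# {"from" : "human", "value" : "Hello!"}    ->  {"role" : "user",      "content" : "Hello!"}
# {"from" : "gpt",   "value" : "Hi there"}  ->  {"role" : "assistant", "content" : "Hi there"}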
|
1429 |
-
|
1430 |
-
|
1431 |
-
def get_ollama_eos_tokens(tokenizer, extra_eos_tokens = []):
|
1432 |
-
added_tokens_decoder = tokenizer.added_tokens_decoder.values()
|
1433 |
-
added_tokens_decoder = [str(x) for x in added_tokens_decoder]
|
1434 |
-
|
1435 |
-
# Remove added_tokens_decoder duplicates
|
1436 |
-
added_tokens_decoder = list(set(added_tokens_decoder) - set(extra_eos_tokens))
|
1437 |
-
|
1438 |
-
# Remove BOS
|
1439 |
-
if getattr(tokenizer, "bos_token", None) is not None:
|
1440 |
-
added_tokens_decoder = [x for x in added_tokens_decoder if x != tokenizer.bos_token]
|
1441 |
-
pass
|
1442 |
-
|
1443 |
-
repeatted_tokens = []
|
1444 |
-
# Join all vocab
|
1445 |
-
joined_text = "\x01\x00".join(added_tokens_decoder)
|
1446 |
-
for token in added_tokens_decoder:
|
1447 |
-
n = len(token)
|
1448 |
-
repeatted_counts = joined_text.count(token[:n//2])
|
1449 |
-
# Try finding longer than 1/2 of the token in the rest
|
1450 |
-
# For eg <|reserved_special_token_0|>, <|reserved_special_token_1|>
|
1451 |
-
if repeatted_counts > 2:
|
1452 |
-
for j in range(n//2+1, n):
|
1453 |
-
if joined_text.count(token[:j]) < repeatted_counts:
|
1454 |
-
j -= 1
|
1455 |
-
# Remove repeated tokens to reduce the search space
|
1456 |
-
joined_text = joined_text.replace(token[:j], "")
|
1457 |
-
repeatted_tokens.append(token[:j])
|
1458 |
-
break
|
1459 |
-
pass
|
1460 |
-
pass
|
1461 |
-
pass
|
1462 |
-
|
1463 |
-
# Remove duplicates
|
1464 |
-
splitted = joined_text.split("\x01\x00")
|
1465 |
-
final_eos_tokens = []
|
1466 |
-
for old, new in zip(added_tokens_decoder, splitted):
|
1467 |
-
if old == new: final_eos_tokens.append(old)
|
1468 |
-
pass
|
1469 |
-
final_eos_tokens += extra_eos_tokens
|
1470 |
-
final_eos_tokens += repeatted_tokens
|
1471 |
-
|
1472 |
-
# Remove new lines, spaces and HTML tags
|
1473 |
-
filtered_eos_tokens = []
|
1474 |
-
for token in final_eos_tokens:
|
1475 |
-
if token.count("\n") == len(token): continue
|
1476 |
-
elif token.count("▁") == len(token): continue
|
1477 |
-
elif token.startswith("<") and len(token) <= 2: continue
|
1478 |
-
elif token.startswith("</") and len(token) == 3: continue
|
1479 |
-
filtered_eos_tokens.append(token)
|
1480 |
-
pass
|
1481 |
-
return filtered_eos_tokens
|
1482 |
-
pass
|
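The list returned here is what later becomes the stop-token section of the Ollama modelfile. For intuition, a hypothetical return value and the PARAMETER lines built from it (the token names are placeholders; the join pattern mirrors the one used in construct_chat_template below):

# Illustrative only: turning a hypothetical EOS-token list into Ollama stop parameters.
final_eos_tokens = ["<|eot_id|>", "<|end_of_text|>"]   # hypothetical return value
ollama_eos = "\n".join(f'PARAMETER stop "{eos}"' for eos in final_eos_tokens)
print(ollama_eos)
# PARAMETER stop "<|eot_id|>"
# PARAMETER stop "<|end_of_text|>"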
1483 |
-
|
1484 |
-
|
1485 |
-
def construct_chat_template( \
|
1486 |
-
|
1487 |
-
tokenizer = None,
|
1488 |
-
|
1489 |
-
chat_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
1490 |
-
|
1491 |
-
{SYSTEM}<|eot_id|><|start_header_id|>user<|end_header_id|>
|
1492 |
-
|
1493 |
-
{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
1494 |
-
|
1495 |
-
{OUTPUT}<|eot_id|><|start_header_id|>user<|end_header_id|>
|
1496 |
-
|
1497 |
-
{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
1498 |
-
|
1499 |
-
{OUTPUT}<|eot_id|>""",
|
1500 |
-
|
1501 |
-
default_system_message = \
|
1502 |
-
"Below are some instructions that describe some tasks. Write responses that appropriately complete each request.",
|
1503 |
-
|
1504 |
-
extra_eos_tokens = None,
|
1505 |
-
):
|
1506 |
-
"""
|
1507 |
-
Creates an Ollama modelfile and a HF Jinja template from a custom
|
1508 |
-
template. You must provide 2x examples of an input & output.
|
1509 |
-
There is an optional system message as well.
|
1510 |
-
|
1511 |
-
You must use {INPUT}, {OUTPUT} twice, and {SYSTEM} is optional.
|
1512 |
-
"""
|
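As a concrete illustration of the required shape, here is a hypothetical Alpaca-style template with the mandatory two {INPUT}/{OUTPUT} pairs and an optional {SYSTEM} slot. The wording below is an assumption for illustration only; the llama-3 template in test_construct_chat_template further down shows a real one.

# Hypothetical template of the shape this function expects (illustrative only):
# {SYSTEM} appears once, and the {INPUT}/{OUTPUT} pair is written out twice so the
# repeating instruction/response block can be isolated.
example_template = """{SYSTEM}
### Instruction:
{INPUT}

### Response:
{OUTPUT}

### Instruction:
{INPUT}

### Response:
{OUTPUT}"""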
1513 |
-
# Strip only the left
|
1514 |
-
chat_template = chat_template.lstrip()
|
1515 |
-
|
1516 |
-
assert(tokenizer is not None)
|
1517 |
-
|
1518 |
-
if extra_eos_tokens is None: extra_eos_tokens = []
|
1519 |
-
elif type(extra_eos_tokens) is str: extra_eos_tokens = [extra_eos_tokens,]
|
1520 |
-
|
1521 |
-
vocab = tokenizer.get_vocab()
|
1522 |
-
for extra_eos in extra_eos_tokens:
|
1523 |
-
assert(type(extra_eos) is str)
|
1524 |
-
if extra_eos not in vocab:
|
1525 |
-
raise ValueError(f"Unsloth: `{extra_eos}` is not a singular token in the tokenizer.")
|
1526 |
-
pass
|
1527 |
-
pass
|
1528 |
-
|
1529 |
-
error_msg = \
|
1530 |
-
"Unsloth: Your prompt template must have 2 examples showing the user input {INPUT} "\
|
1531 |
-
"and the assistant output {OUTPUT}\n\n"\
|
1532 |
-
"For example what is not allowed is just:\n"\
|
1533 |
-
"### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n\n\n"\
|
1534 |
-
"What is required is 2x of this:\n"\
|
1535 |
-
"### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n"\
|
1536 |
-
"### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n"
|
1537 |
-
|
1538 |
-
# Check for EOS after {OUTPUT}
|
1539 |
-
if tokenizer.eos_token is not None:
|
1540 |
-
extra_eos_tokens.insert(0, tokenizer.eos_token)
|
1541 |
-
if len(extra_eos_tokens) == 0:
|
1542 |
-
raise RuntimeError(
|
1543 |
-
"Unsloth: Your tokenizer does not have an EOS token? Please provide one via extra_eos_tokens!"
|
1544 |
-
)
|
1545 |
-
pass
|
1546 |
-
|
1547 |
-
# Check tokenizer types
|
1548 |
-
tokenizer_name = tokenizer.name_or_path.lower()
|
1549 |
-
if tokenizer_name.startswith(("unsloth/llama-3-8b-instruct", "unsloth/llama-3-70b-instruct")):
|
1550 |
-
# Add <|eot_id|>
|
1551 |
-
extra_eos_tokens.append("<|eot_id|>")
|
1552 |
-
elif ("<|eot_id|>" in extra_eos_tokens or "<|eot_id|>" in chat_template) and \
|
1553 |
-
tokenizer_name.startswith(("unsloth/llama-3-8b", "unsloth/llama-3-70b")):
|
1554 |
-
# Warn
|
1555 |
-
logger.warning(
|
1556 |
-
"Unsloth: Base llama-3 models did not train <|eot_id|>.\n"\
|
1557 |
-
"Please use the instruct version or use <|end_of_text|>"
|
1558 |
-
)
|
1559 |
-
pass
|
1560 |
-
extra_eos_tokens = list(set(extra_eos_tokens))
|
1561 |
-
|
1562 |
-
count_eos = 0
|
1563 |
-
for eos in extra_eos_tokens:
|
1564 |
-
count_eos += len(re.findall(r"{OUTPUT}" + re.escape(eos), chat_template))
|
1565 |
-
pass
|
1566 |
-
|
1567 |
-
# This forces you to provide 2 input and outputs
|
1568 |
-
final_combined_check = False
|
1569 |
-
|
1570 |
-
try:
|
1571 |
-
# O(N^2) search for 2 repeated pieces of text
|
1572 |
-
j = len(chat_template)-1
|
1573 |
-
at_least_one = False
|
1574 |
-
while j > 0:
|
1575 |
-
found = chat_template.rfind(chat_template[j:], 0, j)
|
1576 |
-
if found == -1: break
|
1577 |
-
j -= 1
|
1578 |
-
at_least_one = True
|
1579 |
-
pass
|
1580 |
-
if j > 0: j += 1
|
1581 |
-
else: raise RuntimeError(error_msg)
|
1582 |
-
|
1583 |
-
if not at_least_one: raise RuntimeError(error_msg)
|
1584 |
-
|
1585 |
-
# Must be equivalent to left
|
1586 |
-
final_combined_check = True
|
1587 |
-
|
1588 |
-
# Repeated text
|
1589 |
-
instruction_response = chat_template[j:]
|
1590 |
-
if instruction_response.count("{INPUT}") != 1 or instruction_response.count("{OUTPUT}") != 1:
|
1591 |
-
raise RuntimeError(error_msg)
|
1592 |
-
pass
|
1593 |
-
|
1594 |
-
# 1st System, Instruction, Output pair
|
1595 |
-
left = chat_template[:j]
|
1596 |
-
# 2nd Instruction, Output pair
|
1597 |
-
right = chat_template[j:]
|
1598 |
-
|
1599 |
-
final_combined_check = left if final_combined_check else chat_template
|
1600 |
-
|
1601 |
-
# Isolate input
|
1602 |
-
extra_eos_tokens_regex = "|".join(f"(?:{re.escape(x)})" for x in extra_eos_tokens)
|
1603 |
-
if len(extra_eos_tokens_regex) != 0:
|
1604 |
-
find_end = f"(?:{extra_eos_tokens_regex})?"
|
1605 |
-
else:
|
1606 |
-
find_end = ""
|
1607 |
-
find_end = r"\{INPUT\}[\s\n]{0,}" + find_end
|
1608 |
-
input_end = list(re.finditer(find_end, right))
|
1609 |
-
assert(len(input_end) == 1)
|
1610 |
-
input_end = input_end[0]
|
1611 |
-
input_end = input_end.span(0)[1]
|
1612 |
-
input_part = right[:input_end]
|
1613 |
-
|
1614 |
-
# Isolate output
|
1615 |
-
output_part = right[input_end:]
|
1616 |
-
|
1617 |
-
# Isolate system
|
1618 |
-
where_system = left.find(input_part)
|
1619 |
-
system_part = left[:where_system if where_system != -1 else len(left)]
|
1620 |
-
|
1621 |
-
# Check if the user provided a correct prompt
|
1622 |
-
combined = system_part + input_part + output_part
|
1623 |
-
if combined != final_combined_check:
|
1624 |
-
combined_changed = combined .replace('\n', '\\n')
|
1625 |
-
left_changed = final_combined_check.replace('\n', '\\n')
|
1626 |
-
raise RuntimeError(
|
1627 |
-
"Unsloth: The prompt template you provided isn't correct. You gave:\n"\
|
1628 |
-
f"{combined_changed}\n\n"\
|
1629 |
-
"But we require the following:\n"\
|
1630 |
-
f"{left_changed}"
|
1631 |
-
)
|
1632 |
-
pass
|
1633 |
-
except:
|
1634 |
-
ending = chat_template[chat_template.find("{OUTPUT}") + len("{OUTPUT}"):]
|
1635 |
-
|
1636 |
-
ending = re.escape(ending)
|
1637 |
-
find_text = "{INPUT}" + ending + "(.+?{OUTPUT}" + ending + ")"
|
1638 |
-
response_part = re.findall(find_text, chat_template, flags = re.DOTALL | re.MULTILINE)
|
1639 |
-
response_part = response_part[0]
|
1640 |
-
|
1641 |
-
for j in range(1, len(response_part)):
|
1642 |
-
try_find = re.escape(response_part[:j])
|
1643 |
-
try: found = next(re.finditer("(" + try_find + ").+?\{INPUT\}", chat_template, flags = re.DOTALL | re.MULTILINE))
|
1644 |
-
except: break
|
1645 |
-
pass
|
1646 |
-
separator = found.group(1)
|
1647 |
-
|
1648 |
-
response_start = chat_template.find(response_part)
|
1649 |
-
start_instruction = chat_template[:response_start].rfind(separator)
|
1650 |
-
if start_instruction == -1: start_instruction = 0
|
1651 |
-
instruction_part = chat_template[start_instruction:response_start]
|
1652 |
-
|
1653 |
-
combined = instruction_part + response_part
|
1654 |
-
where = chat_template.find(combined)
|
1655 |
-
system_part = chat_template[:where]
|
1656 |
-
|
1657 |
-
system_part, input_part, output_part = system_part, instruction_part, response_part
|
1658 |
-
pass
|
1659 |
-
|
1660 |
-
if count_eos == 0:
|
1661 |
-
logger.warning("Unsloth: We automatically added an EOS token to stop endless generations.")
|
1662 |
-
eos = extra_eos_tokens[0]
|
1663 |
-
output_part = output_part + eos
|
1664 |
-
pass
|
1665 |
-
|
1666 |
-
# Ollama modelfile parts
|
1667 |
-
|
1668 |
-
# Check bos_token is in system prompt
|
1669 |
-
ollama_system = system_part
|
1670 |
-
has_bos_token = False
|
1671 |
-
always_bos_token = False
|
1672 |
-
if tokenizer("A").input_ids[0] == getattr(tokenizer, "bos_token_id", None):
|
1673 |
-
always_bos_token = True
|
1674 |
-
if ollama_system.startswith(tokenizer.bos_token):
|
1675 |
-
has_bos_token = True
|
1676 |
-
ollama_system = ollama_system[len(tokenizer.bos_token):]
|
1677 |
-
pass
|
1678 |
-
pass
|
1679 |
-
# Check system
|
1680 |
-
if "{SYSTEM}" in ollama_system:
|
1681 |
-
system_modelfile = "{{ if .System }}" + ollama_system.replace("{SYSTEM}", "{{ .System }}") + "{{ end }}"
|
1682 |
-
else:
|
1683 |
-
system_modelfile = ollama_system
|
1684 |
-
pass
|
1685 |
-
input_modelfile = "{{ if .Prompt }}" + input_part .replace("{INPUT}", "{{ .Prompt }}") + "{{ end }}"
|
1686 |
-
output_modelfile = output_part.replace("{OUTPUT}", "{{ .Response }}")
|
1687 |
-
|
1688 |
-
# Ollama EOS
|
1689 |
-
ollama_eos = get_ollama_eos_tokens(tokenizer, extra_eos_tokens)
|
1690 |
-
ollama_eos = '\n'.join(f'PARAMETER stop "{eos}"' for eos in ollama_eos)
|
1691 |
-
|
1692 |
-
# Add temperature and min_p to counteract gibberish
|
1693 |
-
ollama_eos += "\nPARAMETER temperature 1.5\nPARAMETER min_p 0.1"
|
1694 |
-
|
1695 |
-
# Ollama modelfile
|
1696 |
-
part = '"""'
|
1697 |
-
modelfile = 'FROM {__FILE_LOCATION__}\n\n'\
|
1698 |
-
'TEMPLATE ' + part + system_modelfile + input_modelfile + output_modelfile + \
|
1699 |
-
part + '\n\n' + ollama_eos
|
1700 |
-
|
1701 |
-
# HF Jinja Chat template
|
1702 |
-
def process(part, which, content = "message['content']"):
|
1703 |
-
if part.endswith(which):
|
1704 |
-
part = "'" + part[:part.find(which)] + f"' + {content}"
|
1705 |
-
elif part.startswith(which):
|
1706 |
-
part = f"{content} + '" + part[part.find(which):] + "'"
|
1707 |
-
else:
|
1708 |
-
part = "'" + part.replace(which, f"' + {content} + '") + "'"
|
1709 |
-
if part.startswith("'' + "): part = part[5:]
|
1710 |
-
return part
|
1711 |
-
pass
|
1712 |
-
input_jinja = process(input_part, "{INPUT}")
|
1713 |
-
output_jinja = process(output_part, "{OUTPUT}")
|
1714 |
-
pass
|
1715 |
-
|
1716 |
-
jinja_template = \
|
1717 |
-
"{% for message in loop_messages %}"\
|
1718 |
-
"{% if message['role'] == 'user' %}"\
|
1719 |
-
"{{ " + input_jinja + " }}"\
|
1720 |
-
"{% elif message['role'] == 'assistant' %}"\
|
1721 |
-
"{{ " + output_jinja + " }}"\
|
1722 |
-
"{% else %}"\
|
1723 |
-
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
|
1724 |
-
"{% endif %}"\
|
1725 |
-
"{% endfor %}"\
|
1726 |
-
"{% if add_generation_prompt %}"\
|
1727 |
-
"{{ '" + output_part[:output_part.find("{OUTPUT}")] + "' }}"\
|
1728 |
-
"{% endif %}"
|
1729 |
-
pass
|
1730 |
-
|
1731 |
-
# Now add system prompt to jinja
|
1732 |
-
if len(system_part) != 0:
|
1733 |
-
partial_system = process(system_part, "{SYSTEM}", "messages[0]['content']")
|
1734 |
-
partial_system = partial_system.replace("{SYSTEM}", "")
|
1735 |
-
|
1736 |
-
if "{SYSTEM}" in partial_system:
|
1737 |
-
if default_system_message is None:
|
1738 |
-
raise RuntimeError("Unsloth: Please specify a default system message!")
|
1739 |
-
pass
|
1740 |
-
|
1741 |
-
# Separate the BOS
|
1742 |
-
if has_bos_token:
|
1743 |
-
partial_system = partial_system.replace(tokenizer.bos_token, "", 1)
|
1744 |
-
system_part = system_part .replace(tokenizer.bos_token, "", 1)
|
1745 |
-
pass
|
1746 |
-
|
1747 |
-
partial_system = \
|
1748 |
-
"{% if messages[0]['role'] == 'system' %}"\
|
1749 |
-
"{{ " + partial_system + " }}"\
|
1750 |
-
"{% set loop_messages = messages[1:] %}"
|
1751 |
-
if default_system_message is not None:
|
1752 |
-
full_system = system_part.replace("{SYSTEM}", default_system_message)
|
1753 |
-
if "{SYSTEM}" in system_part:
|
1754 |
-
modelfile += '\nSYSTEM "' + default_system_message + '"'
|
1755 |
-
pass
|
1756 |
-
partial_system += "{% else %}"\
|
1757 |
-
"{{ '" + full_system + "' }}"\
|
1758 |
-
"{% set loop_messages = messages %}"\
|
1759 |
-
"{% endif %}"
|
1760 |
-
else:
|
1761 |
-
partial_system += "{% endif %}"
|
1762 |
-
pass
|
1763 |
-
|
1764 |
-
jinja_template = partial_system + jinja_template
|
1765 |
-
|
1766 |
-
if has_bos_token:
|
1767 |
-
jinja_template = "{{ bos_token }}" + jinja_template
|
1768 |
-
pass
|
1769 |
-
|
1770 |
-
# Fix missing loop_messages
|
1771 |
-
if "{% set loop_messages = messages %}" not in jinja_template:
|
1772 |
-
jinja_template = jinja_template.replace(
|
1773 |
-
"{% for message in loop_messages %}",
|
1774 |
-
"{% for message in messages %}",
|
1775 |
-
1, # Only replace the first one
|
1776 |
-
)
|
1777 |
-
pass
|
1778 |
-
|
1779 |
-
# Check if system part is the same!
|
1780 |
-
jinja_template = re.sub(
|
1781 |
-
r"\{\% if messages\[0\]\['role'\] \=\= 'system' \%\}\{\{ '(.+?)' \}\}"\
|
1782 |
-
r"\{\% set loop\_messages \= messages\[1\:\] \%\}"\
|
1783 |
-
r"\{\% else \%\}\{\{ '\1' \}\}\{\% set loop\_messages \= messages \%\}\{\% endif \%\}"\
|
1784 |
-
r"\{\% for message in loop\_messages \%\}",
|
1785 |
-
r"{{ '\1' }}{% for message in messages %}",
|
1786 |
-
jinja_template, flags = re.MULTILINE | re.DOTALL,
|
1787 |
-
)
|
1788 |
-
|
1789 |
-
# Check jinja template for bos
|
1790 |
-
if always_bos_token:
|
1791 |
-
if not jinja_template.startswith(("{{ bos_token }}", "{{- bos_token }}")):
|
1792 |
-
jinja_template = "{{ bos_token }}" + jinja_template
|
1793 |
-
pass
|
1794 |
-
|
1795 |
-
# Get instruction and output parts for train_on_inputs = False
|
1796 |
-
input_part = input_part [:input_part .find("{INPUT}")]
|
1797 |
-
output_part = output_part[:output_part.find("{OUTPUT}")]
|
1798 |
-
return modelfile, jinja_template, input_part, output_part
|
1799 |
-
pass
|
1800 |
-
|
1801 |
-
|
1802 |
-
def test_construct_chat_template():
|
1803 |
-
token = "hf_"
|
1804 |
-
from transformers import AutoTokenizer
|
1805 |
-
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", token = token)
|
1806 |
-
|
1807 |
-
chat_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
1808 |
-
|
1809 |
-
{SYSTEM}<|eot_id|><|start_header_id|>user<|end_header_id|>
|
1810 |
-
|
1811 |
-
{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
1812 |
-
|
1813 |
-
{OUTPUT}<|eot_id|><|start_header_id|>user<|end_header_id|>
|
1814 |
-
|
1815 |
-
{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
1816 |
-
|
1817 |
-
{OUTPUT}<|eot_id|>"""
|
1818 |
-
|
1819 |
-
default_system_message = \
|
1820 |
-
"Below are some instructions that describe some tasks. Write responses that appropriately complete each request."
|
1821 |
-
|
1822 |
-
extra_eos_tokens = None
|
1823 |
-
|
1824 |
-
modelfile, jinja_template, _, _ = construct_chat_template(
|
1825 |
-
tokenizer = tokenizer,
|
1826 |
-
chat_template = chat_template,
|
1827 |
-
extra_eos_tokens = extra_eos_tokens,
|
1828 |
-
)
|
1829 |
-
|
1830 |
-
messages = [
|
1831 |
-
{"role": "system", "content": "You are an assistant"},
|
1832 |
-
{"role": "user", "content": "What is 2+2?"},
|
1833 |
-
{"role": "assistant", "content": "It's 4."},
|
1834 |
-
{"role": "user", "content": "Ok!"},
|
1835 |
-
{"role": "assistant", "content": "Anything else?"},
|
1836 |
-
{"role": "user", "content": "What's 2x2?"},
|
1837 |
-
]
|
1838 |
-
correct_output = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
|
1839 |
-
|
1840 |
-
tokenizer.chat_template = jinja_template
|
1841 |
-
new_output = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
|
1842 |
-
assert(correct_output == new_output)
|
1843 |
-
pass
|
1844 |
-
pass
|
1845 |
-
|
1846 |
-
|
1847 |
-
def apply_chat_template( \
|
1848 |
-
|
1849 |
-
dataset,
|
1850 |
-
tokenizer = None,
|
1851 |
-
|
1852 |
-
chat_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
1853 |
-
|
1854 |
-
{SYSTEM}<|eot_id|><|start_header_id|>user<|end_header_id|>
|
1855 |
-
|
1856 |
-
{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
1857 |
-
|
1858 |
-
{OUTPUT}<|eot_id|><|start_header_id|>user<|end_header_id|>
|
1859 |
-
|
1860 |
-
{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
1861 |
-
|
1862 |
-
{OUTPUT}<|eot_id|>""",
|
1863 |
-
|
1864 |
-
default_system_message = \
|
1865 |
-
"Below are some instructions that describe some tasks. Write responses that appropriately complete each request.",
|
1866 |
-
|
1867 |
-
extra_eos_tokens = None,
|
1868 |
-
|
1869 |
-
):
|
1870 |
-
"""
|
1871 |
-
Applies a chat template to a dataset. Creates an Ollama modelfile and a HF Jinja template from a custom
|
1872 |
-
template. You must provide 2x examples of an input & output.
|
1873 |
-
There is an optional system message as well.
|
1874 |
-
|
1875 |
-
You must use {INPUT}, {OUTPUT} twice, and {SYSTEM} is optional.
|
1876 |
-
"""
|
1877 |
-
modelfile, jinja_template, input_part, output_part = construct_chat_template(
|
1878 |
-
tokenizer = tokenizer,
|
1879 |
-
chat_template = chat_template,
|
1880 |
-
default_system_message = default_system_message,
|
1881 |
-
extra_eos_tokens = extra_eos_tokens,
|
1882 |
-
)
|
1883 |
-
def formatting_prompts_func(examples):
|
1884 |
-
convos = examples["conversations"]
|
1885 |
-
texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
|
1886 |
-
return { "text" : texts, }
|
1887 |
-
pass
|
1888 |
-
|
1889 |
-
tokenizer.chat_template = jinja_template
|
1890 |
-
tokenizer._ollama_modelfile = modelfile
|
1891 |
-
tokenizer._unsloth_input_part = input_part
|
1892 |
-
tokenizer._unsloth_output_part = output_part
|
1893 |
-
|
1894 |
-
return dataset.map(formatting_prompts_func, batched = True,)
|
1895 |
-
pass
|
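A typical call sequence, sketched under the assumption that the dataset already has a "conversations" column of role/content messages (the dataset name below is a placeholder, not a real dataset):

# Hedged usage sketch for apply_chat_template (dataset name is a placeholder).
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-3-8b-Instruct")
dataset = load_dataset("some_user/role_content_chat_dataset", split = "train")  # placeholder

# Uses the llama-3 default chat_template shown in the signature above.
dataset = apply_chat_template(dataset, tokenizer = tokenizer)

print(dataset[0]["text"])                 # one fully formatted training example
print(tokenizer._ollama_modelfile[:200])  # the generated Ollama modelfile text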
1896 |
-
|
1897 |
-
|
1898 |
-
def create_stopping_criteria(tokenizer, stop_word = "eos_token"):
|
1899 |
-
class StoppingCriteriaSub(StoppingCriteria):
|
1900 |
-
__slots__ = "stop_token", "single_match", "length",
|
1901 |
-
|
1902 |
-
def __init__(self, stops = "eos_token", device = "cuda", encounters = 1):
|
1903 |
-
super().__init__()
|
1904 |
-
if stops == "eos_token":
|
1905 |
-
self.stop_token = torch.tensor(tokenizer.eos_token_id, device = "cuda")
|
1906 |
-
self.length = 1
|
1907 |
-
else:
|
1908 |
-
self.stop_token = tokenizer(["\n" + stops], add_special_tokens = False, return_tensors = "pt")
|
1909 |
-
self.stop_token = self.stop_token.input_ids.ravel()[1:].to("cuda")
|
1910 |
-
self.length = self.stop_token.shape[0]
|
1911 |
-
pass
|
1912 |
-
self.single_match = self.length == 1
|
1913 |
-
pass
|
1914 |
-
|
1915 |
-
def __call__(self, input_ids: LongTensor, scores: FloatTensor) -> bool:
|
1916 |
-
input_ids = input_ids.ravel()
|
1917 |
-
last_token = input_ids[-1]
|
1918 |
-
if self.single_match and (last_token == self.stop_token): return True
|
1919 |
-
|
1920 |
-
if input_ids.shape[0] >= self.length and \
|
1921 |
-
(input_ids[-self.length:] == self.stop_token).all(): return True
|
1922 |
-
return False
|
1923 |
-
pass
|
1924 |
-
pass
|
1925 |
-
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops = stop_word)])
|
1926 |
-
return stopping_criteria
|
1927 |
-
pass
|
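A usage sketch for the stopping criteria, assuming a CUDA-capable machine and any causal LM; the model name and stop word are illustrative choices, not requirements of the function.

# Hedged usage sketch: stop generation at a custom word (assumes a CUDA model/tokenizer).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "unsloth/llama-3-8b-Instruct"   # any causal LM works here
tokenizer  = AutoTokenizer.from_pretrained(model_name)
model      = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype = torch.float16).cuda()

stopping_criteria = create_stopping_criteria(tokenizer, stop_word = "### Instruction:")

inputs  = tokenizer("### Response:\n", return_tensors = "pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens = 64, stopping_criteria = stopping_criteria)
print(tokenizer.decode(outputs[0]))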
1928 |
-
|
1929 |
-
|
1930 |
-
def test_chat_templates():
|
1931 |
-
messages = [
|
1932 |
-
{"role": "system","content": " You are a friendly chatbot.",},
|
1933 |
-
{"role": "user", "content": "What is 2+2?"},
|
1934 |
-
{"role": "assistant", "content": "It's 4."},
|
1935 |
-
{"role": "user", "content": " But 2+2 is equal to 5. "},
|
1936 |
-
{"role": "assistant", "content": "No I'm sure its 4."},
|
1937 |
-
{"role": "user", "content": " No it's 100% 5! "},
|
1938 |
-
]
|
1939 |
-
|
1940 |
-
# Zephyr
|
1941 |
-
from transformers import AutoTokenizer
|
1942 |
-
template = zephyr_template
|
1943 |
-
correct_tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
|
1944 |
-
correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
|
1945 |
-
correct_tokenizer.chat_template = template
|
1946 |
-
our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
|
1947 |
-
assert(correct_prompt == our_prompt)
|
1948 |
-
|
1949 |
-
# Chatml
|
1950 |
-
template = chatml_template
|
1951 |
-
correct_tokenizer = AutoTokenizer.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B")
|
1952 |
-
correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
|
1953 |
-
correct_tokenizer.chat_template = template
|
1954 |
-
our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
|
1955 |
-
assert(correct_prompt == our_prompt)
|
1956 |
-
|
1957 |
-
# Mistral
|
1958 |
-
template = mistral_template
|
1959 |
-
correct_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
|
1960 |
-
correct_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
|
1961 |
-
correct_tokenizer.chat_template = template
|
1962 |
-
our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
|
1963 |
-
assert(correct_prompt == our_prompt)
|
1964 |
-
|
1965 |
-
# Llama
|
1966 |
-
template = llama_template
|
1967 |
-
correct_tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-2-7b-chat")
|
1968 |
-
correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
|
1969 |
-
correct_tokenizer.chat_template = template
|
1970 |
-
our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
|
1971 |
-
assert(correct_prompt == our_prompt)
|
1972 |
-
|
1973 |
-
# Vicuna
|
1974 |
-
try:
|
1975 |
-
from fastchat.conversation import get_conv_template
|
1976 |
-
except:
|
1977 |
-
os.system("pip -qqq install git+https://github.com/lm-sys/FastChat.git")
|
1978 |
-
from fastchat.conversation import get_conv_template
|
1979 |
-
correct_prompt = get_conv_template("vicuna_v1.1")
|
1980 |
-
for j in range(len(messages)-1):
|
1981 |
-
correct_prompt.append_message(correct_prompt.roles[j%2==1], messages[j+1]["content"])
|
1982 |
-
correct_prompt.append_message(correct_prompt.roles[1], "")
|
1983 |
-
correct_prompt = correct_tokenizer.bos_token + correct_prompt.get_prompt()
|
1984 |
-
|
1985 |
-
template = vicuna_template
|
1986 |
-
correct_tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
|
1987 |
-
correct_tokenizer.chat_template = template
|
1988 |
-
our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
|
1989 |
-
assert(correct_prompt == our_prompt)
|
1990 |
-
|
1991 |
-
try:
|
1992 |
-
from fastchat.conversation import get_conv_template
|
1993 |
-
except:
|
1994 |
-
os.system("pip -qqq install git+https://github.com/lm-sys/FastChat.git")
|
1995 |
-
from fastchat.conversation import get_conv_template
|
1996 |
-
correct_prompt = get_conv_template("zero_shot")
|
1997 |
-
for j in range(len(messages)-1):
|
1998 |
-
correct_prompt.append_message(correct_prompt.roles[j%2==1], messages[j+1]["content"])
|
1999 |
-
correct_prompt.append_message(correct_prompt.roles[1], "")
|
2000 |
-
correct_prompt = correct_tokenizer.bos_token + correct_prompt.get_prompt()
|
2001 |
-
|
2002 |
-
template = vicuna_old_template
|
2003 |
-
correct_tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
|
2004 |
-
correct_tokenizer.chat_template = template
|
2005 |
-
our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
|
2006 |
-
# We add </s> ourselves
|
2007 |
-
assert(correct_prompt == our_prompt.replace("</s>", ""))
|
2008 |
-
|
2009 |
-
# Gemma
|
2010 |
-
correct_tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-7b-it")
|
2011 |
-
correct_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
|
2012 |
-
correct_tokenizer.chat_template = gemma_template
|
2013 |
-
our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
|
2014 |
-
assert(our_prompt == correct_prompt)
|
2015 |
-
|
2016 |
-
# Llama-3
|
2017 |
-
template = llama3_template
|
2018 |
-
correct_tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-3-8b-Instruct")
|
2019 |
-
correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
|
2020 |
-
correct_tokenizer.chat_template = template
|
2021 |
-
our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
|
2022 |
-
assert(correct_prompt == our_prompt)
|
2023 |
-
|
2024 |
-
# Phi-3
|
2025 |
-
template = phi3_template
|
2026 |
-
correct_tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
|
2027 |
-
correct_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
|
2028 |
-
correct_tokenizer.chat_template = template
|
2029 |
-
our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
|
2030 |
-
assert(correct_prompt == our_prompt)
|
2031 |
-
pass
|
2032 |
-
|
2033 |
-
|
2034 |
-
def test_hf_gguf_equivalence(tokenizer, gguf_model = "./model-unsloth.F16.gguf"):
|
2035 |
-
"""
|
2036 |
-
Carefully checks the output of GGUF's tokenization and HF.
|
2037 |
-
Catches most tokenization mismatches between the two.
|
2038 |
-
"""
|
2039 |
-
import subprocess
|
2040 |
-
import re
|
2041 |
-
messages = [
|
2042 |
-
{"role": "user", "content": "What is 2+2?"},
|
2043 |
-
{"role": "assistant", "content": "It's 4."},
|
2044 |
-
{"role": "user", "content": " But 2+2 is equal to 5. "},
|
2045 |
-
{"role": "assistant", "content": "No I'm sure its 4."},
|
2046 |
-
{"role": "user", "content": " No it's 100% 5! "},
|
2047 |
-
]
|
2048 |
-
|
2049 |
-
prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
|
2050 |
-
|
2051 |
-
### Instruction:
|
2052 |
-
{}
|
2053 |
-
|
2054 |
-
### Input:
|
2055 |
-
{}
|
2056 |
-
|
2057 |
-
### Response:
|
2058 |
-
{}""".format(
|
2059 |
-
"Describe the city given eloquently.", # instruction
|
2060 |
-
"The lost city of Atlantis.", # input
|
2061 |
-
"", # output - leave this blank for generation!
|
2062 |
-
)
|
2063 |
-
prompts = [ prompt, ]
|
2064 |
-
|
2065 |
-
if tokenizer.chat_template is not None:
|
2066 |
-
prompt = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
|
2067 |
-
prompt = prompt.replace("'", "") # Subprocess does not like ''
|
2068 |
-
prompt = remove_special_tokens(tokenizer, prompt)
|
2069 |
-
prompts.append(prompt)
|
2070 |
-
pass
|
2071 |
-
|
2072 |
-
for prompt in prompts:
|
2073 |
-
command = f"./llama.cpp/llama-cli -m {gguf_model} -n 0 --temp 0.0 --verbose-prompt "\
|
2074 |
-
f"--check-tensors -p '{prompt}'"
|
2075 |
-
|
2076 |
-
datas = []
|
2077 |
-
with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT, bufsize = 1) as sp:
|
2078 |
-
for line in sp.stdout:
|
2079 |
-
datas.append(line.decode("utf-8", errors = "replace"))
|
2080 |
-
pass
|
2081 |
-
gguf_tokens = "".join(datas)
|
2082 |
-
|
2083 |
-
# Now extract GGUF tokenization attempt
|
2084 |
-
gguf_tokenized = re.findall("([\d]{1,}) \-\> \'([^\']{1,})\'", gguf_tokens, flags = re.MULTILINE)
|
2085 |
-
gguf_tokenized = [(int(x[0]), x[1],) for x in gguf_tokenized]
|
2086 |
-
input_ids = tokenizer(prompt).input_ids
|
2087 |
-
|
2088 |
-
tokens = tokenizer.batch_decode(input_ids)
|
2089 |
-
hf_tokenized = list(zip(input_ids, tokens))
|
2090 |
-
|
2091 |
-
# Compare to Huggingface
|
2092 |
-
for j, (hf_token, gguf_token) in enumerate(zip(hf_tokenized, gguf_tokenized)):
|
2093 |
-
if (hf_token[0] != gguf_token[0]):
|
2094 |
-
print("Failed GGUF != HF at", j)
|
2095 |
-
print("HF =", hf_token)
|
2096 |
-
print("GGUF =", gguf_token)
|
2097 |
-
print(hf_tokenized)
|
2098 |
-
print()
|
2099 |
-
print(gguf_tokenized)
|
2100 |
-
print()
|
2101 |
-
raise RuntimeError("Failed comparing GGUF to HF.")
|
2102 |
-
pass
|
2103 |
-
pass
|
2104 |
-
return True
|
2105 |
-
pass
unsloth-main/unsloth/kernels/__init__.py
DELETED
@@ -1,65 +0,0 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .cross_entropy_loss import (
    fast_cross_entropy_loss,
    post_patch_loss_function,
    patch_loss_functions,
)
from .rms_layernorm import (
    fast_rms_layernorm,
    patch_rms_layernorm,
    unpatch_rms_layernorm,
)
from .layernorm import (
    fast_layernorm,
    patch_layernorm,
)
from .rope_embedding import fast_rope_embedding, inplace_rope_embedding
from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
from .geglu import (
    geglu_exact_forward_kernel,
    geglu_exact_backward_kernel,
    geglu_approx_forward_kernel,
    geglu_approx_backward_kernel,
)
from .fast_lora import (
    get_lora_parameters,
    get_lora_parameters_bias,
    apply_lora_mlp_swiglu,
    apply_lora_mlp_geglu_exact,
    apply_lora_mlp_geglu_approx,
    apply_lora_qkv,
    apply_lora_o,
    fast_lora_forward,
)
from .utils import fast_dequantize, fast_gemv, QUANT_STATE, fast_linear_forward, matmul_lora

from .flex_attention import (
    HAS_FLEX_ATTENTION,
    slow_attention_softcapping,
    slow_inference_attention_softcapping,
    create_flex_attention_causal_mask,
    create_flex_attention_sliding_window_mask,
)

import os
if "UNSLOTH_ZOO_IS_PRESENT" not in os.environ:
    try:
        print("🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.")
    except:
        print("Unsloth: Will patch your computer to enable 2x faster free finetuning.")
    pass
pass
del os
unsloth-main/unsloth/kernels/cross_entropy_loss.py
DELETED
@@ -1,405 +0,0 @@
|
|
1 |
-
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
|
5 |
-
# You may obtain a copy of the License at
|
6 |
-
#
|
7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
-
#
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
|
15 |
-
import triton
|
16 |
-
import triton.language as tl
|
17 |
-
import torch
|
18 |
-
from .utils import calculate_settings, MAX_FUSED_SIZE, triton_tanh, triton_cast
|
19 |
-
from transformers.models.llama.modeling_llama import logger
|
20 |
-
from packaging.version import Version
|
21 |
-
|
22 |
-
from unsloth_zoo.loss_utils import (
|
23 |
-
patch_loss_functions as _patch_loss_functions,
|
24 |
-
post_patch_loss_function,
|
25 |
-
)
|
26 |
-
|
27 |
-
|
28 |
-
@triton.heuristics({
|
29 |
-
"DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]),
|
30 |
-
"DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
|
31 |
-
})
|
32 |
-
@triton.jit
|
33 |
-
def _cross_entropy_forward(
|
34 |
-
logits_ptr ,
|
35 |
-
logits_row_stride ,
|
36 |
-
loss_ptr ,
|
37 |
-
logsumexp_ptr ,
|
38 |
-
labels_ptr ,
|
39 |
-
VOCAB_SIZE ,
|
40 |
-
BLOCK_SIZE : tl.constexpr,
|
41 |
-
DO_SOFTCAPPING ,
|
42 |
-
SOFTCAP ,
|
43 |
-
DO_LOGIT_SCALING ,
|
44 |
-
LOGIT_SCALE ,
|
45 |
-
):
|
46 |
-
"""
|
47 |
-
Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
|
48 |
-
Pi = exp(xi) / sum(exp(xi))
|
49 |
-
CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
|
50 |
-
= -y [ x - log[sum(exp(x))] ]
|
51 |
-
= y * (log[sum(exp(x))] - x)
|
52 |
-
If y == 0: CE_i = 0
|
53 |
-
If y == 1: CE_i = logsumexp - x
|
54 |
-
|
55 |
-
logsumexp is also stable
|
56 |
-
Take y = log[sum(exp(x))]
|
57 |
-
exp(y) = sum(exp(x))
|
58 |
-
exp(y) = sum(exp(x - c)*exp(c)) Since e^(x-c)*e^c = e^x
|
59 |
-
exp(y) = exp(c)*sum(exp(x - c))
|
60 |
-
y = log(exp(c)*sum(exp(x - c)))
|
61 |
-
y = c + log[sum(exp(x - c))]
|
62 |
-
This means we can set c = max(x) to make sure
|
63 |
-
exp(x - c) always is exp(x - max(x)).
|
64 |
-
This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1.
|
65 |
-
"""
|
66 |
-
row_idx = tl.program_id(0)
|
67 |
-
logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
|
68 |
-
loss_ptr += row_idx
|
69 |
-
logsumexp_ptr += row_idx
|
70 |
-
labels_ptr += row_idx
|
71 |
-
|
72 |
-
col_offsets = tl.arange(0, BLOCK_SIZE)
|
73 |
-
mask = col_offsets < VOCAB_SIZE
|
74 |
-
|
75 |
-
label_idx = tl.load(labels_ptr).to(tl.int32)
|
76 |
-
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
|
77 |
-
|
78 |
-
# Go logit scaling for Cohere: t * x
|
79 |
-
if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
|
80 |
-
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
|
81 |
-
if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
|
82 |
-
|
83 |
-
c = tl.max(logits, 0)
|
84 |
-
logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
|
85 |
-
|
86 |
-
if label_idx != -100:
|
87 |
-
x = tl.load(logits_ptr + label_idx).to(tl.float32)
|
88 |
-
# Go logit scaling for Cohere: t * x
|
89 |
-
if DO_LOGIT_SCALING: x = LOGIT_SCALE * x
|
90 |
-
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
|
91 |
-
if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
|
92 |
-
loss = logsumexp - x
|
93 |
-
else:
|
94 |
-
loss = 0.0
|
95 |
-
tl.store(logsumexp_ptr, logsumexp)
|
96 |
-
tl.store(loss_ptr, loss)
|
97 |
-
pass
|
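The arithmetic in the docstring can be checked with a few lines of plain PyTorch. This is a CPU reference for what one row of the kernel computes, for intuition only; it does not invoke the Triton kernel, and the vocab size and label are arbitrary.

# Plain-PyTorch reference for one row of the forward kernel (CPU sketch).
import torch
import torch.nn.functional as F

logits = torch.randn(32000)          # one row of logits
label  = 123                         # target index (-100 would mean "masked")

c         = logits.max()
logsumexp = c + torch.log(torch.exp(logits - c).sum())   # the stable trick from the docstring
loss      = logsumexp - logits[label]                     # CE_i = logsumexp - x_label

assert torch.allclose(loss, F.cross_entropy(logits[None], torch.tensor([label])))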
98 |
-
|
99 |
-
|
100 |
-
@triton.heuristics({
|
101 |
-
"DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]),
|
102 |
-
"DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
|
103 |
-
})
|
104 |
-
@triton.jit
|
105 |
-
def _chunked_cross_entropy_forward(
|
106 |
-
logits_ptr ,
|
107 |
-
logits_row_stride ,
|
108 |
-
loss_ptr ,
|
109 |
-
logsumexp_ptr ,
|
110 |
-
labels_ptr ,
|
111 |
-
VOCAB_SIZE ,
|
112 |
-
N_CHUNKS ,
|
113 |
-
BLOCK_SIZE : tl.constexpr,
|
114 |
-
DO_SOFTCAPPING ,
|
115 |
-
SOFTCAP ,
|
116 |
-
DO_LOGIT_SCALING ,
|
117 |
-
LOGIT_SCALE ,
|
118 |
-
):
|
119 |
-
"""
|
120 |
-
256K vocab divided in 4 chunks
|
121 |
-
|
122 |
-
|-65536-| |-65536-| |-65536-| |-65536-|
|
123 |
-
|-------| |-------| |-------| |-------|
|
124 |
-
|-------| |-------| |-------| |-------|
|
125 |
-
|
126 |
-
If y == 0: CE_i = 0
|
127 |
-
If y == 1: CE_i = logsumexp - x
|
128 |
-
|
129 |
-
Notice we can do logsumexp for each chunk and then
|
130 |
-
logsumexp[chunk_sum(logsumexp)] == logsumexp
|
131 |
-
|
132 |
-
chunk_sum = log[chunk_sum(logsumexp)]
|
133 |
-
= log[exp(logsumexp(a)) + ... + exp(logsumexp(z))]
|
134 |
-
= log[exp(log[sum(exp(a))]) + ... + exp(log[sum(exp(z))])]
|
135 |
-
= log[sum(exp(a)) + ... + sum(exp(z))]
|
136 |
-
= logsumexp(x)
|
137 |
-
|
138 |
-
This means we can perform a logsumexp for each chunk, then do a
|
139 |
-
final logsumexp reduction!
|
140 |
-
|
141 |
-
Ie do: logsumexp(chunked_logsumexp) - x
|
142 |
-
"""
|
143 |
-
row_idx = tl.program_id(0)
|
144 |
-
chunk_idx = tl.program_id(1)
|
145 |
-
logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
|
146 |
-
loss_ptr += row_idx
|
147 |
-
logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx
|
148 |
-
labels_ptr += row_idx
|
149 |
-
|
150 |
-
col_offsets = chunk_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
151 |
-
mask = col_offsets < VOCAB_SIZE
|
152 |
-
|
153 |
-
label_idx = tl.load(labels_ptr).to(tl.int32)
|
154 |
-
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
|
155 |
-
|
156 |
-
# Go logit scaling for Cohere: t * x
|
157 |
-
if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
|
158 |
-
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
|
159 |
-
if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
|
160 |
-
|
161 |
-
c = tl.max(logits, 0)
|
162 |
-
logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
|
163 |
-
|
164 |
-
if chunk_idx == 0:
|
165 |
-
# logsumexp(chunked_logsumexp) - x
|
166 |
-
# Do the -x separately
|
167 |
-
if label_idx != -100:
|
168 |
-
x = tl.load(logits_ptr + label_idx).to(tl.float32)
|
169 |
-
# Go logit scaling for Cohere: t * x
|
170 |
-
if DO_LOGIT_SCALING: x = LOGIT_SCALE * x
|
171 |
-
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
|
172 |
-
if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
|
173 |
-
loss = -1.0 * x
|
174 |
-
else:
|
175 |
-
loss = 0.0
|
176 |
-
tl.store(loss_ptr, loss)
|
177 |
-
pass
|
178 |
-
tl.store(logsumexp_ptr, logsumexp)
|
179 |
-
pass
|
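The chunking identity this kernel relies on, that a logsumexp over the chunk-wise logsumexps equals the global logsumexp, can be verified directly; the vocab and chunk sizes below are just the ones mentioned in the docstring.

# Numerical check of the chunked-logsumexp identity (CPU sketch).
import torch

vocab, chunk = 262144, 65536
logits = torch.randn(vocab)

chunked    = torch.stack([torch.logsumexp(c, dim = 0) for c in logits.split(chunk)])
recombined = torch.logsumexp(chunked, dim = 0)   # logsumexp of the per-chunk logsumexps

assert torch.allclose(recombined, torch.logsumexp(logits, dim = 0))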
180 |
-
|
181 |
-
|
182 |
-
@triton.heuristics({
|
183 |
-
"DO_SOFTCAPPING": lambda args: bool(args["DO_SOFTCAPPING" ]),
|
184 |
-
"DO_LOGIT_SCALING": lambda args: bool(args["DO_LOGIT_SCALING"]),
|
185 |
-
})
|
186 |
-
@triton.jit
|
187 |
-
def _cross_entropy_backward(
|
188 |
-
logits_ptr ,
|
189 |
-
logits_row_stride ,
|
190 |
-
dloss_ptr ,
|
191 |
-
dloss_row_stride ,
|
192 |
-
logsumexp_ptr ,
|
193 |
-
labels_ptr ,
|
194 |
-
VOCAB_SIZE ,
|
195 |
-
BLOCK_SIZE : tl.constexpr,
|
196 |
-
DO_SOFTCAPPING ,
|
197 |
-
SOFTCAP ,
|
198 |
-
DO_LOGIT_SCALING ,
|
199 |
-
LOGIT_SCALE ,
|
200 |
-
):
|
201 |
-
"""
|
202 |
-
CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
|
203 |
-
dC/dx = d/dx (y * log[sum(exp(x))] - x * y)
|
204 |
-
|
205 |
-
From https://en.wikipedia.org/wiki/LogSumExp
|
206 |
-
d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)
|
207 |
-
|
208 |
-
dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
|
209 |
-
dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
|
210 |
-
dC/dx = y * exp[x - logsumexp] - d/dx (x * y)
|
211 |
-
|
212 |
-
If y == 0: dC/dx = 0
|
213 |
-
If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
|
214 |
-
If y == 1 and x != label: dC/dx = exp[x - logsumexp]
|
215 |
-
"""
|
216 |
-
row_idx = tl.program_id(0)
|
217 |
-
block_idx = tl.program_id(1)
|
218 |
-
|
219 |
-
logits_ptr += row_idx * triton_cast(logits_row_stride, tl.int64)
|
220 |
-
dloss_ptr += row_idx * dloss_row_stride
|
221 |
-
col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
222 |
-
mask = col_offsets < VOCAB_SIZE
|
223 |
-
label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)
|
224 |
-
|
225 |
-
if label_idx != -100:
|
226 |
-
dloss = tl.load(dloss_ptr)
|
227 |
-
else:
|
228 |
-
dloss = 0.0
|
229 |
-
|
230 |
-
x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
|
231 |
-
|
232 |
-
# Do logit scaling for Cohere
|
233 |
-
if DO_LOGIT_SCALING:
|
234 |
-
# d/dx [s * x] = s
|
235 |
-
x = x * LOGIT_SCALE
|
236 |
-
pass
|
237 |
-
|
238 |
-
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
|
239 |
-
partial = x
|
240 |
-
if DO_SOFTCAPPING:
|
241 |
-
# d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
|
242 |
-
partial = triton_tanh(x / SOFTCAP)
|
243 |
-
x = SOFTCAP * partial
|
244 |
-
pass
|
245 |
-
|
246 |
-
logsumexp = tl.load(logsumexp_ptr + row_idx)
|
247 |
-
y = tl.exp(x - logsumexp)
|
248 |
-
y = tl.where(
|
249 |
-
col_offsets == label_idx,
|
250 |
-
y - 1.0, # exp(x - logsumexp) - 1
|
251 |
-
y, # exp(x - logsumexp)
|
252 |
-
)
|
253 |
-
|
254 |
-
if DO_LOGIT_SCALING:
|
255 |
-
# d/dx [s * x] = s
|
256 |
-
y = y * LOGIT_SCALE
|
257 |
-
pass
|
258 |
-
|
259 |
-
if DO_SOFTCAPPING:
|
260 |
-
# d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
|
261 |
-
y = y * (1.0 - partial*partial)
|
262 |
-
pass
|
263 |
-
|
264 |
-
# If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.
|
265 |
-
tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)
|
266 |
-
pass
|
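The derivative stated above, softmax(x) with 1 subtracted at the label, matches what autograd produces for a plain cross entropy. A small sketch, ignoring softcapping and logit scaling:

# Autograd check of dC/dx = softmax(x) - one_hot(label) (CPU sketch).
import torch
import torch.nn.functional as F

logits = torch.randn(10, requires_grad = True)
label  = torch.tensor(3)

F.cross_entropy(logits[None], label[None]).backward()

expected = torch.softmax(logits.detach(), dim = 0)
expected[label] -= 1.0
assert torch.allclose(logits.grad, expected)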
267 |
-
|
268 |
-
|
269 |
-
MAX_FUSED_SIZE = 65536 # 2**16
|
270 |
-
|
271 |
-
class Fast_CrossEntropyLoss(torch.autograd.Function):
|
272 |
-
@staticmethod
|
273 |
-
def forward(ctx, logits, labels, logit_softcapping : float = 0, logit_scaling : float = 0):
|
274 |
-
n_rows : int
|
275 |
-
vocab_size : int
|
276 |
-
n_rows, vocab_size = logits.shape
|
277 |
-
|
278 |
-
div, mod = divmod(vocab_size, MAX_FUSED_SIZE)
|
279 |
-
n_chunks : int = div + (mod != 0)
|
280 |
-
losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
|
281 |
-
|
282 |
-
DO_SOFTCAPPING : bool = bool(logit_softcapping != 0)
|
283 |
-
DO_LOGIT_SCALING : bool = bool(logit_scaling != 0)
|
284 |
-
|
285 |
-
BLOCK_SIZE : int
|
286 |
-
num_warps : int
|
287 |
-
if n_chunks == 1:
|
288 |
-
# For small vocabs <= 65536 like Llama, Mistral
|
289 |
-
BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
|
290 |
-
logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
|
291 |
-
|
292 |
-
_cross_entropy_forward[(n_rows,)](
|
293 |
-
logits, logits.stride(0),
|
294 |
-
losses,
|
295 |
-
logsumexp,
|
296 |
-
labels,
|
297 |
-
VOCAB_SIZE = vocab_size,
|
298 |
-
BLOCK_SIZE = BLOCK_SIZE,
|
299 |
-
DO_SOFTCAPPING = DO_SOFTCAPPING,
|
300 |
-
SOFTCAP = logit_softcapping,
|
301 |
-
DO_LOGIT_SCALING = DO_LOGIT_SCALING,
|
302 |
-
LOGIT_SCALE = logit_scaling,
|
303 |
-
num_warps = num_warps,
|
304 |
-
)
|
305 |
-
else:
|
306 |
-
# For large vocabs > 65536 like Gemma 256K
|
307 |
-
logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = "cuda:0")
|
308 |
-
|
309 |
-
_chunked_cross_entropy_forward[(n_rows, n_chunks,)](
|
310 |
-
logits, logits.stride(0),
|
311 |
-
losses,
|
312 |
-
logsumexp,
|
313 |
-
labels,
|
314 |
-
VOCAB_SIZE = vocab_size,
|
315 |
-
N_CHUNKS = n_chunks,
|
316 |
-
BLOCK_SIZE = MAX_FUSED_SIZE,
|
317 |
-
DO_SOFTCAPPING = DO_SOFTCAPPING,
|
318 |
-
SOFTCAP = logit_softcapping,
|
319 |
-
DO_LOGIT_SCALING = DO_LOGIT_SCALING,
|
320 |
-
LOGIT_SCALE = logit_scaling,
|
321 |
-
num_warps = 32,
|
322 |
-
)
|
323 |
-
# logsumexp(chunked_logsumexp) - x
|
324 |
-
# Do the -x separately
|
325 |
-
logsumexp = torch.logsumexp(logsumexp, dim = 1) # Row sum
|
326 |
-
losses += logsumexp
|
327 |
-
losses.masked_fill_(labels == -100, 0) # Don't forget to mask padding out!
|
328 |
-
pass
|
329 |
-
|
330 |
-
ctx.save_for_backward(logits, logsumexp, labels)
|
331 |
-
ctx.DO_SOFTCAPPING = DO_SOFTCAPPING
|
332 |
-
ctx.logit_softcapping = logit_softcapping
|
333 |
-
ctx.DO_LOGIT_SCALING = DO_LOGIT_SCALING
|
334 |
-
ctx.logit_scaling = logit_scaling
|
335 |
-
return losses
|
336 |
-
pass
|
337 |
-
|
338 |
-
|
339 |
-
@staticmethod
|
340 |
-
def backward(ctx, dlosses):
|
341 |
-
logits, logsumexp, labels = ctx.saved_tensors
|
342 |
-
n_rows : int
|
343 |
-
vocab_size : int
|
344 |
-
n_rows, vocab_size = logits.shape
|
345 |
-
|
346 |
-
BLOCK_SIZE : int = 4096
|
347 |
-
div : int
|
348 |
-
mod : int
|
349 |
-
div, mod = divmod(vocab_size, BLOCK_SIZE)
|
350 |
-
n_blocks : int = div + (mod != 0)
|
351 |
-
|
352 |
-
_cross_entropy_backward[(n_rows, n_blocks,)](
|
353 |
-
logits, logits.stride(0),
|
354 |
-
dlosses, dlosses.stride(0),
|
355 |
-
logsumexp,
|
356 |
-
labels,
|
357 |
-
VOCAB_SIZE = vocab_size,
|
358 |
-
BLOCK_SIZE = BLOCK_SIZE,
|
359 |
-
DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,
|
360 |
-
SOFTCAP = ctx.logit_softcapping,
|
361 |
-
DO_LOGIT_SCALING = ctx.DO_LOGIT_SCALING,
|
362 |
-
LOGIT_SCALE = ctx.logit_scaling,
|
363 |
-
num_warps = 8,
|
364 |
-
)
|
365 |
-
return logits, None, None, None,
|
366 |
-
pass
|
367 |
-
pass
|
368 |
-
|
369 |
-
|
370 |
-
def fast_cross_entropy_loss(
|
371 |
-
logits,
|
372 |
-
labels,
|
373 |
-
logit_softcapping = 0,
|
374 |
-
logit_scaling = 0,
|
375 |
-
n_items = None,
|
376 |
-
):
|
377 |
-
"""
|
378 |
-
Arguments:
|
379 |
-
logits: (batch, seq_len, vocab_size)
|
380 |
-
labels: (batch, seq_len,)
|
381 |
-
Returns:
|
382 |
-
losses: float
|
383 |
-
"""
|
384 |
-
batch, seq_len, d = logits.shape
|
385 |
-
assert(labels.shape == (batch, seq_len))
|
386 |
-
|
387 |
-
loss = Fast_CrossEntropyLoss.apply(
|
388 |
-
logits.view(batch*seq_len, d),
|
389 |
-
labels.view(-1),
|
390 |
-
logit_softcapping,
|
391 |
-
logit_scaling,
|
392 |
-
)
|
393 |
-
if n_items is None:
|
394 |
-
n_items = torch.count_nonzero(labels != -100)
|
395 |
-
return loss.sum() / n_items
|
396 |
-
pass
|
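For reference, the semantics the fused loss is meant to reproduce are those of a token-mean cross entropy over the non-masked labels. The sketch below shows only that CPU reference; running the fused path itself requires CUDA tensors, and the shapes are arbitrary.

# Reference semantics the fused loss is meant to match (CPU sketch).
import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 8, 128
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))
labels[:, :3] = -100                          # mask out prompt tokens

reference = F.cross_entropy(
    logits.view(-1, vocab), labels.view(-1),
    ignore_index = -100, reduction = "mean",  # mean over the non-masked tokens
)
# On CUDA, fast_cross_entropy_loss(logits.cuda(), labels.cuda()) should match `reference`.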
397 |
-
if (Version(torch.__version__) < Version("2.4.0")) and \
|
398 |
-
not hasattr(fast_cross_entropy_loss, "__wrapped__"):
|
399 |
-
fast_cross_entropy_loss = torch._disable_dynamo(fast_cross_entropy_loss)
|
400 |
-
pass
|
401 |
-
|
402 |
-
# Patch CE Losses in transformers
|
403 |
-
def patch_loss_functions(torch_compile = True):
|
404 |
-
_patch_loss_functions(fast_cross_entropy_loss, torch_compile = torch_compile)
|
405 |
-
pass
unsloth-main/unsloth/kernels/fast_lora.py
DELETED
@@ -1,490 +0,0 @@
|
|
1 |
-
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
|
5 |
-
# You may obtain a copy of the License at
|
6 |
-
#
|
7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
-
#
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
|
15 |
-
import torch
|
16 |
-
from .utils import (
|
17 |
-
fast_dequantize,
|
18 |
-
QUANT_STATE,
|
19 |
-
get_lora_parameters,
|
20 |
-
get_lora_parameters_bias,
|
21 |
-
matmul_lora,
|
22 |
-
torch_amp_custom_fwd,
|
23 |
-
torch_amp_custom_bwd,
|
24 |
-
)
|
25 |
-
|
26 |
-
|
27 |
-
class LoRA_MLP(torch.autograd.Function):
|
28 |
-
"""
|
29 |
-
### LoRA weights
|
30 |
-
G = G + Ag @ Bg
|
31 |
-
U = U + Au @ Bu
|
32 |
-
W = W + Aw @ Bw
|
33 |
-
|
34 |
-
### SwiGLU(X)
|
35 |
-
e = X @ G
|
36 |
-
f = e * sigmoid(e)
|
37 |
-
g = X @ U
|
38 |
-
h = f * g
|
39 |
-
i = h @ W
|
40 |
-
|
41 |
-
### Backpropagation chain rule
|
42 |
-
See our blog post for more details
|
43 |
-
|
44 |
-
df = sigmoid(e) * (1 - f) + f
|
45 |
-
dC/dW = h.T @ dY
|
46 |
-
dC/dU = X.T @ (D @ W.T * f)
|
47 |
-
dC/dG = X.T @ (D @ W.T * df * g)
|
48 |
-
|
49 |
-
### Down projection LoRA weights
|
50 |
-
dC/dAw = dC/dW @ B.T
|
51 |
-
dC/dBw = A.T @ dC/dW
|
52 |
-
dC/dAw = h.T @ dY @ B.T
|
53 |
-
dC/dBw = A.T @ h.T @ dY
|
54 |
-
|
55 |
-
### Up projection LoRA weights
|
56 |
-
dC/dAu = X.T @ (D @ W.T * f) @ B.T
|
57 |
-
dC/dBu = A.T @ X.T @ (D @ W.T * f)
|
58 |
-
|
59 |
-
### Gate projection LoRA weights
|
60 |
-
dC/dAg = X.T @ (D @ W.T * df * g) @ B.T
|
61 |
-
dC/dBg = A.T @ X.T @ (D @ W.T * df * g)
|
62 |
-
|
63 |
-
Don't forget to see our blog post for more details!
|
64 |
-
"""
|
65 |
-
@staticmethod
|
66 |
-
@torch_amp_custom_fwd
|
67 |
-
def forward(ctx, X : torch.Tensor,
|
68 |
-
gateW, gateW_quant, gateA, gateB, gateS,
|
69 |
-
upW, upW_quant, upA, upB, upS,
|
70 |
-
downW, downW_quant, downA, downB, downS,
|
71 |
-
_forward_function, _backward_function,
|
72 |
-
inplace = True,):
|
73 |
-
dtype = X.dtype
|
74 |
-
|
75 |
-
e = matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS)
|
76 |
-
g = matmul_lora(X, upW, upW_quant, upA, upB, upS)
|
77 |
-
h = _forward_function(e, g)
|
78 |
-
i = matmul_lora(h, downW, downW_quant, downA, downB, downS)
|
79 |
-
|
80 |
-
ctx.custom_saved_tensors = (
|
81 |
-
gateW, gateW_quant, gateS,
|
82 |
-
upW, upW_quant, upS,
|
83 |
-
downW, downW_quant, downS,
|
84 |
-
_backward_function,
|
85 |
-
)
|
86 |
-
ctx.save_for_backward(gateA, gateB, upA, upB, downA, downB,
|
87 |
-
X, e, g)
|
88 |
-
ctx.inplace = inplace
|
89 |
-
return i
|
90 |
-
pass
|
91 |
-
|
92 |
-
|
93 |
-
@staticmethod
|
94 |
-
@torch_amp_custom_bwd
|
95 |
-
def backward(ctx, dY : torch.Tensor):
|
96 |
-
gateW, gateW_quant, gateS, upW, upW_quant, upS, downW, downW_quant, downS, \
|
97 |
-
_backward_function = ctx.custom_saved_tensors
|
98 |
-
gateA, gateB, upA, upB, downA, downB, \
|
99 |
-
X, e, g = ctx.saved_tensors
|
100 |
-
|
101 |
-
gateA, gateB, upA, upB, downA, downB = \
|
102 |
-
gateA.t(), gateB.t(), upA.t(), upB.t(), downA.t(), downB.t()
|
103 |
-
|
104 |
-
batch, seq_len, hd = X.shape
|
105 |
-
dY = dY.view(-1, dY.shape[-1])
|
106 |
-
X = X .view(-1, X .shape[-1])
|
107 |
-
e = e .view(-1, e .shape[-1])
|
108 |
-
g = g .view(-1, g .shape[-1])
|
109 |
-
dtype = X.dtype
|
110 |
-
|
111 |
-
DW = matmul_lora(dY, downW.t(), downW_quant, downB, downA, downS)
|
112 |
-
DW, e, g = _backward_function(DW, e, g)
|
113 |
-
h, df, de = DW, e, g
|
114 |
-
|
115 |
-
# Down projection LoRA weights
|
116 |
-
d_downA = h.t() @ (dY @ downB.t())
|
117 |
-
d_downB = (downA.t() @ h.t()) @ dY
|
118 |
-
d_downA *= downS
|
119 |
-
d_downB *= downS
|
120 |
-
|
121 |
-
# Up projection LoRA weights
|
122 |
-
d_upA = X.t() @ (df @ upB.t())
|
123 |
-
d_upB = (upA.t() @ X.t()) @ df
|
124 |
-
d_upA *= upS
|
125 |
-
d_upB *= upS
|
126 |
-
|
127 |
-
# Gate projection LoRA weights
|
128 |
-
d_gateA = X.t() @ (de @ gateB.t())
|
129 |
-
d_gateB = (gateA.t() @ X.t()) @ de
|
130 |
-
d_gateA *= gateS
|
131 |
-
d_gateB *= gateS
|
132 |
-
|
133 |
-
# dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS)
|
134 |
-
# dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS)
|
135 |
-
upW = fast_dequantize(upW.t(), upW_quant)
|
136 |
-
dX = torch.matmul(df, upW.t(), out = X if ctx.inplace else None)
|
137 |
-
del upW
|
138 |
-
dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())
|
139 |
-
|
140 |
-
gateW = fast_dequantize(gateW.t(), gateW_quant)
|
141 |
-
dX += de @ gateW.t()
|
142 |
-
del gateW
|
143 |
-
dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())
|
144 |
-
|
145 |
-
# gateW, gateW_quant, gateA, gateB, gateS,
|
146 |
-
# upW, upW_quant, upA, upB, upS,
|
147 |
-
# downW, downW_quant, downA, downB, downS,
|
148 |
-
return dX.view(batch, seq_len, hd), \
|
149 |
-
None, None, d_gateA.t(), d_gateB.t(), None, \
|
150 |
-
None, None, d_upA.t(), d_upB.t(), None, \
|
151 |
-
None, None, d_downA.t(), d_downB.t(), None, \
|
152 |
-
None, None, None, # _backward and _forward and inplace
|
153 |
-
pass
|
154 |
-
pass
|
155 |
-
|
156 |
-
|
157 |
-
from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
|
158 |
-
def apply_lora_mlp_swiglu(self, X, inplace = True):
|
159 |
-
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
160 |
-
upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
|
161 |
-
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
|
162 |
-
out = LoRA_MLP.apply(X,
|
163 |
-
gateW, gateW_quant, gateA, gateB, gateS,
|
164 |
-
upW, upW_quant, upA, upB, upS,
|
165 |
-
downW, downW_quant, downA, downB, downS,
|
166 |
-
swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel,
|
167 |
-
inplace,)
|
168 |
-
return out
|
169 |
-
pass
|
170 |
-
|
171 |
-
|
172 |
-
from .geglu import geglu_exact_forward_kernel, geglu_exact_backward_kernel
|
173 |
-
def apply_lora_mlp_geglu_exact(self, X, inplace = True):
|
174 |
-
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
175 |
-
upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
|
176 |
-
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
|
177 |
-
out = LoRA_MLP.apply(X,
|
178 |
-
gateW, gateW_quant, gateA, gateB, gateS,
|
179 |
-
upW, upW_quant, upA, upB, upS,
|
180 |
-
downW, downW_quant, downA, downB, downS,
|
181 |
-
geglu_exact_forward_kernel, geglu_exact_backward_kernel,
|
182 |
-
inplace,)
|
183 |
-
return out
|
184 |
-
pass
|
185 |
-
|
186 |
-
|
187 |
-
from .geglu import geglu_approx_forward_kernel, geglu_approx_backward_kernel
|
188 |
-
def apply_lora_mlp_geglu_approx(self, X):
|
189 |
-
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
190 |
-
upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
|
191 |
-
downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
|
192 |
-
out = LoRA_MLP.apply(X,
|
193 |
-
gateW, gateW_quant, gateA, gateB, gateS,
|
194 |
-
upW, upW_quant, upA, upB, upS,
|
195 |
-
downW, downW_quant, downA, downB, downS,
|
196 |
-
geglu_approx_forward_kernel, geglu_approx_backward_kernel,)
|
197 |
-
return out
|
198 |
-
pass
|
199 |
-
|
200 |
-
|
201 |
-
class LoRA_QKV(torch.autograd.Function):
|
202 |
-
"""
|
203 |
-
### LoRA weights
|
204 |
-
Wq = Wq + Aq @ Bq
|
205 |
-
Wk = Wk + Ak @ Bk
|
206 |
-
Wv = Wv + Av @ Bv
|
207 |
-
Q = X @ Wq = X @ Wq + X @ Aq @ Bq
|
208 |
-
K = X @ Wk = X @ Wk + X @ Ak @ Bk
|
209 |
-
V = X @ Wv = X @ Wv + X @ Av @ Bv
|
210 |
-
|
211 |
-
### Backpropagation chain rule
|
212 |
-
See our blogpost for more details.
|
213 |
-
|
214 |
-
dC/dWq = X.T @ D(Wq)
|
215 |
-
dC/dWk = X.T @ D(Wk)
|
216 |
-
dC/dWv = X.T @ D(Wv)
|
217 |
-
We then sum them all find dC/dX
|
218 |
-
|
219 |
-
### Q projection LoRA weights
|
220 |
-
dC/dAq = X.T @ D(Wq) @ B.T
|
221 |
-
dC/dBq = A.T @ X.T @ D(Wq)
|
222 |
-
|
223 |
-
### K projection LoRA weights
|
224 |
-
dC/dAk = X.T @ D(Wk) @ B.T
|
225 |
-
dC/dBk = A.T @ X.T @ D(Wk)
|
226 |
-
|
227 |
-
### V projection LoRA weights
|
228 |
-
dC/dAv = X.T @ D(Wv) @ B.T
|
229 |
-
dC/dBv = A.T @ X.T @ D(Wv)
|
230 |
-
"""
|
231 |
-
@staticmethod
|
232 |
-
@torch_amp_custom_fwd
|
233 |
-
def forward(ctx, X : torch.Tensor,
|
234 |
-
QW, QW_quant, QA, QB, QS,
|
235 |
-
KW, KW_quant, KA, KB, KS,
|
236 |
-
VW, VW_quant, VA, VB, VS,
|
237 |
-
inplace = True):
|
238 |
-
dtype = X.dtype
|
239 |
-
|
240 |
-
Q = matmul_lora(X, QW, QW_quant, QA, QB, QS)
|
241 |
-
K = matmul_lora(X, KW, KW_quant, KA, KB, KS)
|
242 |
-
V = matmul_lora(X, VW, VW_quant, VA, VB, VS)
|
243 |
-
|
244 |
-
ctx.custom_saved_tensors = (
|
245 |
-
QW, QW_quant, QS,
|
246 |
-
KW, KW_quant, KS,
|
247 |
-
VW, VW_quant, VS,
|
248 |
-
)
|
249 |
-
ctx.save_for_backward(X, QA, QB, KA, KB, VA, VB,)
|
250 |
-
ctx.inplace = inplace
|
251 |
-
return Q, K, V
|
252 |
-
pass
|
253 |
-
|
254 |
-
@staticmethod
|
255 |
-
@torch_amp_custom_bwd
|
256 |
-
def backward(ctx, dQ, dK, dV):
|
257 |
-
QW, QW_quant, QS, KW, KW_quant, KS, VW, VW_quant, VS = \
|
258 |
-
ctx.custom_saved_tensors
|
259 |
-
X, QA, QB, KA, KB, VA, VB, = ctx.saved_tensors
|
260 |
-
|
261 |
-
QA, QB, KA, KB, VA, VB = \
|
262 |
-
QA.t(), QB.t(), KA.t(), KB.t(), VA.t(), VB.t()
|
263 |
-
|
264 |
-
batch, seq_len, hd = X.shape
|
265 |
-
dQ = dQ.view(-1, dQ.shape[-1])
|
266 |
-
dK = dK.reshape(-1, dK.shape[-1]) # view doesn't work on K.T
|
267 |
-
dV = dV.view(-1, dV.shape[-1])
|
268 |
-
X = X .view(-1, X .shape[-1])
|
269 |
-
dtype = X.dtype
|
270 |
-
|
271 |
-
### Weight projection LoRA weights
|
272 |
-
# See our blogpost for more details.
|
273 |
-
|
274 |
-
# Q Projection
|
275 |
-
d_QA = X.t() @ (dQ @ QB.t())
|
276 |
-
d_QB = (QA.t() @ X.t()) @ dQ
|
277 |
-
d_QA *= QS
|
278 |
-
d_QB *= QS
|
279 |
-
|
280 |
-
# K Projection
|
281 |
-
d_KA = X.t() @ (dK @ KB.t())
|
282 |
-
d_KB = (KA.t() @ X.t()) @ dK
|
283 |
-
d_KA *= KS
|
284 |
-
d_KB *= KS
|
285 |
-
|
286 |
-
# V Projection
|
287 |
-
d_VA = X.t() @ (dV @ VB.t())
|
288 |
-
d_VB = (VA.t() @ X.t()) @ dV
|
289 |
-
d_VA *= VS
|
290 |
-
d_VB *= VS
|
291 |
-
|
292 |
-
# Combine derivatives to find dX
|
293 |
-
# dQ
|
294 |
-
QW = fast_dequantize(QW.t(), QW_quant)
|
295 |
-
dX = torch.matmul(dQ, QW.t(), out = X if ctx.inplace else None)
|
296 |
-
del QW
|
297 |
-
dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))
|
298 |
-
|
299 |
-
# dK
|
300 |
-
KW = fast_dequantize(KW.t(), KW_quant)
|
301 |
-
dX += dK @ KW.t()
|
302 |
-
del KW
|
303 |
-
dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t())
|
304 |
-
|
305 |
-
# dV
|
306 |
-
VW = fast_dequantize(VW.t(), VW_quant)
|
307 |
-
dX += dV @ VW.t()
|
308 |
-
del VW
|
309 |
-
dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())
|
310 |
-
|
311 |
-
# QW, QW_quant, QA, QB, QS,
|
312 |
-
# KW, KW_quant, KA, KB, KS,
|
313 |
-
# VW, VW_quant, VA, VB, VS,
|
314 |
-
return dX.view(batch, seq_len, hd), \
|
315 |
-
None, None, d_QA.t(), d_QB.t(), None, \
|
316 |
-
None, None, d_KA.t(), d_KB.t(), None, \
|
317 |
-
None, None, d_VA.t(), d_VB.t(), None, \
|
318 |
-
None,
|
319 |
-
pass
|
320 |
-
pass
|
321 |
-
|
322 |
-
|
323 |
-
def apply_lora_qkv(self, X, inplace = True):
|
324 |
-
QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
|
325 |
-
KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
|
326 |
-
VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
|
327 |
-
Q, K, V = LoRA_QKV.apply(X,
|
328 |
-
QW, QW_quant, QA, QB, QS,
|
329 |
-
KW, KW_quant, KA, KB, KS,
|
330 |
-
VW, VW_quant, VA, VB, VS,
|
331 |
-
inplace,
|
332 |
-
)
|
333 |
-
return Q, K, V
|
334 |
-
pass
|
335 |
-
|
336 |
-
|
337 |
-
class LoRA_W(torch.autograd.Function):
|
338 |
-
"""
|
339 |
-
### LoRA weights
|
340 |
-
Wq = Wq + Aq @ Bq
|
341 |
-
Wk = Wk + Ak @ Bk
|
342 |
-
Wv = Wv + Av @ Bv
|
343 |
-
Q = X @ Wq = X @ Wq + X @ Aq @ Bq
|
344 |
-
K = X @ Wk = X @ Wk + X @ Ak @ Bk
|
345 |
-
V = X @ Wv = X @ Wv + X @ Av @ Bv
|
346 |
-
|
347 |
-
### Backpropagation chain rule
|
348 |
-
dC/dWq = X.T @ D(Wq)
|
349 |
-
dC/dWk = X.T @ D(Wk)
|
350 |
-
dC/dWv = X.T @ D(Wv)
|
351 |
-
|
352 |
-
### Q projection LoRA weights
|
353 |
-
dC/dAq = X.T @ D(Wq) @ B.T
|
354 |
-
dC/dBq = A.T @ X.T @ D(Wq)
|
355 |
-
|
356 |
-
### K projection LoRA weights
|
357 |
-
dC/dAk = X.T @ D(Wk) @ B.T
|
358 |
-
dC/dBk = A.T @ X.T @ D(Wk)
|
359 |
-
|
360 |
-
### V projection LoRA weights
|
361 |
-
dC/dAv = X.T @ D(Wv) @ B.T
|
362 |
-
dC/dBv = A.T @ X.T @ D(Wv)
|
363 |
-
"""
|
364 |
-
@staticmethod
|
365 |
-
@torch_amp_custom_fwd
|
366 |
-
def forward(ctx, X : torch.Tensor,
|
367 |
-
W, W_quant, A, B, S):
|
368 |
-
dtype = X.dtype
|
369 |
-
XW = matmul_lora(X, W, W_quant, A, B, S)
|
370 |
-
ctx.custom_saved_tensors = (W, W_quant, S,)
|
371 |
-
ctx.save_for_backward(A, B, X)
|
372 |
-
return XW
|
373 |
-
pass
|
374 |
-
|
375 |
-
@staticmethod
|
376 |
-
@torch_amp_custom_bwd
|
377 |
-
def backward(ctx, dY : torch.Tensor):
|
378 |
-
W, W_quant, S = ctx.custom_saved_tensors
|
379 |
-
A, B, X = ctx.saved_tensors
|
380 |
-
|
381 |
-
A, B = A.t(), B.t()
|
382 |
-
|
383 |
-
batch, seq_len, hd = X.shape
|
384 |
-
dY = dY.reshape(-1, dY.shape[-1]) # Must be reshape
|
385 |
-
X = X .reshape(-1, X .shape[-1]) # Must be reshape
|
386 |
-
dtype = X.dtype
|
387 |
-
|
388 |
-
### Weight projection LoRA weights
|
389 |
-
# Weight projection
|
390 |
-
d_A = X.t() @ (dY @ B.t())
|
391 |
-
d_B = (A.t() @ X.t()) @ dY
|
392 |
-
d_A *= S
|
393 |
-
d_B *= S
|
394 |
-
|
395 |
-
# Get derivative for dX
|
396 |
-
W = fast_dequantize(W.t(), W_quant)
|
397 |
-
dX = dY @ W.t()
|
398 |
-
del W
|
399 |
-
dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t())
|
400 |
-
|
401 |
-
# W, W_quant, A, B, S
|
402 |
-
return dX.view(batch, seq_len, hd), \
|
403 |
-
None, None, d_A.t(), d_B.t(), None
|
404 |
-
pass
|
405 |
-
pass
|
406 |
-
|
407 |
-
|
408 |
-
def apply_lora_o(self, X):
|
409 |
-
OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
|
410 |
-
O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
|
411 |
-
return O
|
412 |
-
pass
|
413 |
-
|
414 |
-
|
415 |
-
IDENTITY_DROPOUT = torch.nn.Identity
|
416 |
-
@torch._disable_dynamo
|
417 |
-
def fast_lora_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
418 |
-
raise NotImplementedError(
|
419 |
-
"Unsloth: Currently not supported yet - reshaping done incorrectly"
|
420 |
-
)
|
421 |
-
self._check_forward_args(x, *args, **kwargs)
|
422 |
-
adapter_names = kwargs.pop("adapter_names", None)
|
423 |
-
|
424 |
-
if self.disable_adapters:
|
425 |
-
if self.merged:
|
426 |
-
self.unmerge()
|
427 |
-
result = self.base_layer(x, *args, **kwargs)
|
428 |
-
elif adapter_names is not None:
|
429 |
-
result = self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
|
430 |
-
elif self.merged:
|
431 |
-
result = self.base_layer(x, *args, **kwargs)
|
432 |
-
else:
|
433 |
-
# Fastpath
|
434 |
-
if len(self.active_adapters) == 1:
|
435 |
-
active_adapter = self.active_adapters[0]
|
436 |
-
if active_adapter not in self.lora_A.keys(): return self.base_layer(x, *args, **kwargs)
|
437 |
-
|
438 |
-
dropout = self.lora_dropout[active_adapter]
|
439 |
-
if isinstance(dropout, IDENTITY_DROPOUT) and not self.use_dora[active_adapter]:
|
440 |
-
lora_A = self.lora_A[active_adapter].weight
|
441 |
-
lora_B = self.lora_B[active_adapter].weight
|
442 |
-
scaling = self.scaling[active_adapter]
|
443 |
-
W = self.base_layer.weight
|
444 |
-
return LoRA_W.apply(x, W, QUANT_STATE(W), lora_A, lora_B, scaling)
|
445 |
-
pass
|
446 |
-
pass
|
447 |
-
|
448 |
-
result = self.base_layer(x, *args, **kwargs)
|
449 |
-
# As per Tim Dettmers, for 4bit, we need to defensively clone here.
|
450 |
-
# The reason is that in some cases, an error can occur that backprop
|
451 |
-
# does not work on a manipulated view. This issue may be solved with
|
452 |
-
# newer PyTorch versions but this would need extensive testing to be
|
453 |
-
# sure.
|
454 |
-
result = result.clone()
|
455 |
-
|
456 |
-
for active_adapter in self.active_adapters:
|
457 |
-
if active_adapter not in self.lora_A.keys():
|
458 |
-
continue
|
459 |
-
lora_A = self.lora_A[active_adapter]
|
460 |
-
lora_B = self.lora_B[active_adapter]
|
461 |
-
dropout = self.lora_dropout[active_adapter]
|
462 |
-
scaling = self.scaling[active_adapter]
|
463 |
-
|
464 |
-
requires_conversion = not torch.is_autocast_enabled()
|
465 |
-
if requires_conversion:
|
466 |
-
expected_dtype = result.dtype
|
467 |
-
x = x.to(lora_A.weight.dtype)
|
468 |
-
|
469 |
-
if not self.use_dora[active_adapter]:
|
470 |
-
result = result + lora_B(lora_A(dropout(x))) * scaling
|
471 |
-
else:
|
472 |
-
if isinstance(dropout, torch.nn.Identity) or not self.training:
|
473 |
-
base_result = result
|
474 |
-
else:
|
475 |
-
x = dropout(x)
|
476 |
-
base_result = None
|
477 |
-
|
478 |
-
result = result + self.lora_magnitude_vector[active_adapter](
|
479 |
-
x,
|
480 |
-
lora_A=lora_A,
|
481 |
-
lora_B=lora_B,
|
482 |
-
scaling=scaling,
|
483 |
-
base_layer=self.get_base_layer(),
|
484 |
-
base_result=base_result,
|
485 |
-
)
|
486 |
-
if requires_conversion:
|
487 |
-
result = result.to(expected_dtype)
|
488 |
-
|
489 |
-
return result
|
490 |
-
pass
|
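Note: LoRA_MLP, LoRA_QKV and LoRA_W in the deleted fast_lora.py all use the same hand-derived gradients, d_A = S * X^T (dY B^T), d_B = S * (X A)^T dY and dX = dY W^T + S * dY B^T A^T. A minimal sketch checking those identities against autograd follows; the toy shapes and the scaling S are assumptions for illustration only.

import torch

# Toy shapes and scaling; assumptions for this sketch only
torch.manual_seed(0)
n, d_in, d_out, r, S = 4, 16, 8, 2, 0.5

X = torch.randn(n, d_in, requires_grad = True)
W = torch.randn(d_in, d_out)                 # frozen base weight
A = torch.randn(d_in, r, requires_grad = True)
B = torch.randn(r, d_out, requires_grad = True)

Y = X @ W + S * (X @ A @ B)                  # LoRA forward
dY = torch.randn_like(Y)
Y.backward(dY)

# Hand-derived gradients, same pattern as the deleted backward passes
d_A = S * X.detach().t() @ (dY @ B.detach().t())
d_B = S * (X.detach() @ A.detach()).t() @ dY
d_X = dY @ W.t() + S * dY @ B.detach().t() @ A.detach().t()

assert torch.allclose(A.grad, d_A, atol = 1e-5)
assert torch.allclose(B.grad, d_B, atol = 1e-5)
assert torch.allclose(X.grad, d_X, atol = 1e-5)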
unsloth-main/unsloth/kernels/flex_attention.py
DELETED
@@ -1,181 +0,0 @@
|
|
1 |
-
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
|
5 |
-
# You may obtain a copy of the License at
|
6 |
-
#
|
7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
-
#
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
|
15 |
-
import torch
|
16 |
-
from functools import lru_cache
|
17 |
-
from transformers.models.llama.modeling_llama import logger
|
18 |
-
import os
|
19 |
-
|
20 |
-
torch_compile_options = {
|
21 |
-
"epilogue_fusion" : True,
|
22 |
-
"max_autotune" : True,
|
23 |
-
"shape_padding" : True,
|
24 |
-
"trace.enabled" : os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1",
|
25 |
-
"triton.cudagraphs" : False,
|
26 |
-
}
|
27 |
-
|
28 |
-
# Flex Attention supported from torch 2.5 onwards only
|
29 |
-
try:
|
30 |
-
from torch.nn.attention.flex_attention import (
|
31 |
-
flex_attention as _flex_attention,
|
32 |
-
create_block_mask as _create_block_mask,
|
33 |
-
)
|
34 |
-
_flex_attention = torch.compile(_flex_attention, dynamic = True, options = torch_compile_options)
|
35 |
-
HAS_FLEX_ATTENTION = False
|
36 |
-
except:
|
37 |
-
HAS_FLEX_ATTENTION = False
|
38 |
-
pass
|
39 |
-
|
40 |
-
|
41 |
-
if not HAS_FLEX_ATTENTION:
|
42 |
-
|
43 |
-
# Logit softcapping
|
44 |
-
@torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
|
45 |
-
def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
|
46 |
-
n_heads = self.num_heads
|
47 |
-
head_dim = self.head_dim
|
48 |
-
n_kv_heads = self.num_key_value_heads
|
49 |
-
n_groups = self.num_key_value_groups
|
50 |
-
|
51 |
-
# Grouped query attention
|
52 |
-
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
|
53 |
-
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
|
54 |
-
K = K.reshape(bsz, n_heads, q_len, head_dim)
|
55 |
-
V = V.reshape(bsz, n_heads, q_len, head_dim)
|
56 |
-
|
57 |
-
# See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
|
58 |
-
# Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below
|
59 |
-
# We default to using the config file itself
|
60 |
-
# s = self.config.hidden_size // self.config.num_attention_heads
|
61 |
-
s = self.config.query_pre_attn_scalar
|
62 |
-
t = self.config.attn_logit_softcapping
|
63 |
-
|
64 |
-
Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
|
65 |
-
A = torch.matmul(Q, K.transpose(2, 3))
|
66 |
-
A = t * torch.tanh(A / t) # Logit softcapping
|
67 |
-
A += causal_mask[:q_len, :q_len]
|
68 |
-
# Much slower in torch compile!
|
69 |
-
# A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf"))
|
70 |
-
A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
|
71 |
-
A = torch.matmul(A, V)
|
72 |
-
A = A.transpose(1, 2).contiguous()
|
73 |
-
A = A.reshape(bsz, q_len, n_heads*head_dim)
|
74 |
-
return A
|
75 |
-
pass
|
76 |
-
|
77 |
-
create_flex_attention_causal_mask = None
|
78 |
-
create_flex_attention_sliding_window_mask = None
|
79 |
-
else:
|
80 |
-
# See https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
|
81 |
-
# for more examples
|
82 |
-
# BSD 3-Clause License Copyright (c) 2023, Driss Guessous, Horace He et al
|
83 |
-
import functools, math
|
84 |
-
|
85 |
-
def generate_tanh_softcap(t):
|
86 |
-
def tanh_softcap(x, b, h, q_idx, kv_idx):
|
87 |
-
return t * torch.tanh(x / t)
|
88 |
-
return tanh_softcap
|
89 |
-
pass
|
90 |
-
def causal_masker(b, h, q_idx, kv_idx):
|
91 |
-
return q_idx >= kv_idx
|
92 |
-
pass
|
93 |
-
|
94 |
-
@functools.lru_cache
|
95 |
-
def sliding_window_masker(size = 4096):
|
96 |
-
def sliding_window(b, h, q_idx, kv_idx):
|
97 |
-
causal_mask = q_idx >= kv_idx
|
98 |
-
window_mask = q_idx - kv_idx <= size
|
99 |
-
return causal_mask & window_mask
|
100 |
-
return sliding_window
|
101 |
-
pass
|
102 |
-
|
103 |
-
@functools.lru_cache
|
104 |
-
def create_block_mask(mask, n = 128):
|
105 |
-
return _create_block_mask(
|
106 |
-
mask, 1, 1, n, n,
|
107 |
-
BLOCK_SIZE = 128,
|
108 |
-
_compile = True,
|
109 |
-
)
|
110 |
-
pass
|
111 |
-
|
112 |
-
def create_flex_attention_causal_mask(max_seq_length = 8192):
|
113 |
-
causal_mask = create_block_mask(causal_masker, max_seq_length)
|
114 |
-
return causal_mask
|
115 |
-
pass
|
116 |
-
|
117 |
-
def create_flex_attention_sliding_window_mask(max_seq_length = 8192, sliding_window = 4096):
|
118 |
-
sliding_masker = sliding_window_masker(sliding_window)
|
119 |
-
causal_mask = create_block_mask(sliding_masker, max_seq_length)
|
120 |
-
return causal_mask
|
121 |
-
pass
|
122 |
-
|
123 |
-
@functools.lru_cache
|
124 |
-
def flex_attention(s, t):
|
125 |
-
scale = 1.0 / math.sqrt(s)
|
126 |
-
score_mod = generate_tanh_softcap(t)
|
127 |
-
return functools.partial(
|
128 |
-
_flex_attention, score_mod = score_mod, scale = scale, enable_gqa = True,
|
129 |
-
)
|
130 |
-
pass
|
131 |
-
|
132 |
-
def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
|
133 |
-
n_heads = self.num_heads
|
134 |
-
head_dim = self.head_dim
|
135 |
-
s = self.config.query_pre_attn_scalar
|
136 |
-
t = self.config.attn_logit_softcapping
|
137 |
-
fx = flex_attention(s, t)
|
138 |
-
A = fx(query = Q, key = K, value = V, block_mask = causal_mask)
|
139 |
-
A = A.transpose(1, 2).contiguous()
|
140 |
-
A = A.reshape(bsz, q_len, n_heads*head_dim)
|
141 |
-
return A
|
142 |
-
pass
|
143 |
-
pass
|
144 |
-
|
145 |
-
|
146 |
-
torch_matmul = torch.matmul
|
147 |
-
torch_tanh = torch.tanh
|
148 |
-
torch_nn_functional_softmax = torch.nn.functional.softmax
|
149 |
-
def slow_inference_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
|
150 |
-
n_heads = self.num_heads
|
151 |
-
head_dim = self.head_dim
|
152 |
-
n_kv_heads = self.num_key_value_heads
|
153 |
-
n_groups = self.num_key_value_groups
|
154 |
-
|
155 |
-
# Grouped query attention
|
156 |
-
K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
|
157 |
-
V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
|
158 |
-
K = K.reshape(bsz, n_heads, q_len, head_dim)
|
159 |
-
V = V.reshape(bsz, n_heads, q_len, head_dim)
|
160 |
-
|
161 |
-
# See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
|
162 |
-
# Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below
|
163 |
-
# We default to using the config file itself
|
164 |
-
# s = self.config.hidden_size // self.config.num_attention_heads
|
165 |
-
s = self.config.query_pre_attn_scalar
|
166 |
-
t = self.config.attn_logit_softcapping
|
167 |
-
|
168 |
-
Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
|
169 |
-
A = torch_matmul(Q, K.transpose(2, 3))
|
170 |
-
|
171 |
-
# Logit softcapping
|
172 |
-
A /= t; torch_tanh(A, out = A); A *= t;
|
173 |
-
A += causal_mask[:q_len, :q_len]
|
174 |
-
# Much slower in torch compile!
|
175 |
-
# A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf"))
|
176 |
-
A = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
|
177 |
-
A = torch_matmul(A, V)
|
178 |
-
A = A.transpose(1, 2).contiguous()
|
179 |
-
A = A.reshape(bsz, q_len, n_heads*head_dim)
|
180 |
-
return A
|
181 |
-
pass
|
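Note: when Flex Attention is unavailable, the deleted flex_attention.py falls back to eager attention with tanh logit softcapping. A minimal plain-PyTorch sketch of that softcapping step follows; the softcap value and the 1/sqrt(head_dim) scale here are illustrative assumptions, whereas the deleted code reads both from the model config (query_pre_attn_scalar, attn_logit_softcapping).

import torch
import torch.nn.functional as F

# Assumed sizes, softcap and scale for this sketch only
torch.manual_seed(0)
bsz, n_heads, q_len, head_dim = 1, 4, 16, 32
softcap = 50.0
scale = head_dim ** -0.5

Q = torch.randn(bsz, n_heads, q_len, head_dim)
K = torch.randn(bsz, n_heads, q_len, head_dim)
V = torch.randn(bsz, n_heads, q_len, head_dim)

scores = (Q * scale) @ K.transpose(-2, -1)
scores = softcap * torch.tanh(scores / softcap)   # bound logits to (-softcap, softcap)
causal = torch.triu(torch.full((q_len, q_len), float("-inf")), diagonal = 1)
scores = scores + causal                          # additive causal mask
probs = F.softmax(scores, dim = -1, dtype = torch.float32).to(Q.dtype)
out = probs @ V                                   # (bsz, n_heads, q_len, head_dim)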
unsloth-main/unsloth/kernels/geglu.py
DELETED
@@ -1,203 +0,0 @@
|
|
1 |
-
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
|
5 |
-
# You may obtain a copy of the License at
|
6 |
-
#
|
7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
-
#
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
|
15 |
-
import triton
|
16 |
-
import triton.language as tl
|
17 |
-
import torch
|
18 |
-
from .utils import calculate_settings, triton_tanh
|
19 |
-
|
20 |
-
|
21 |
-
@triton.jit
|
22 |
-
def _exact_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
|
23 |
-
block_idx = tl.program_id(0)
|
24 |
-
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
25 |
-
mask = offsets < n_elements
|
26 |
-
|
27 |
-
# f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
|
28 |
-
# h = f * up
|
29 |
-
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
|
30 |
-
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
|
31 |
-
|
32 |
-
f_row = 0.5 * e_row * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)
|
33 |
-
f_row = f_row.to(g_row.dtype) # Exact copy from HF
|
34 |
-
h_row = f_row * g_row
|
35 |
-
|
36 |
-
# Store h
|
37 |
-
tl.store(h + offsets, h_row, mask = mask)
|
38 |
-
pass
|
39 |
-
|
40 |
-
|
41 |
-
def geglu_exact_forward_kernel(gate, up):
|
42 |
-
batch, seq_len, hd = gate.shape
|
43 |
-
n_elements = gate.numel()
|
44 |
-
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0")
|
45 |
-
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
46 |
-
_exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
|
47 |
-
return out
|
48 |
-
pass
|
49 |
-
|
50 |
-
|
51 |
-
@triton.jit
|
52 |
-
def _exact_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
|
53 |
-
"""
|
54 |
-
f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
|
55 |
-
h = f * up
|
56 |
-
|
57 |
-
df/de (with help of Wolfram :)
|
58 |
-
df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2)
|
59 |
-
|
60 |
-
Reuse via
|
61 |
-
f = 1/2 * (1 + erf(1/sqrt(2) * e)) * e
|
62 |
-
"""
|
63 |
-
block_idx = tl.program_id(0)
|
64 |
-
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
65 |
-
mask = offsets < n_elements
|
66 |
-
|
67 |
-
DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
|
68 |
-
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
|
69 |
-
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
|
70 |
-
|
71 |
-
# Break e_row away for re-use
|
72 |
-
# f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
|
73 |
-
f_partial_row = 0.5 * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)
|
74 |
-
f_row = f_partial_row * e_row
|
75 |
-
|
76 |
-
f_row = f_row.to(DW_row.dtype)
|
77 |
-
# h = f * g
|
78 |
-
h_row = f_row * g_row
|
79 |
-
# df = DW * f
|
80 |
-
df_row = DW_row * f_row
|
81 |
-
# dg = DW * g
|
82 |
-
dg_row = DW_row * g_row
|
83 |
-
|
84 |
-
# df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2)
|
85 |
-
t = 0.3989422804014327 # 1/sqrt(2*pi)
|
86 |
-
df_de = f_partial_row + t * e_row * tl.exp(-0.5 * e_row * e_row)
|
87 |
-
|
88 |
-
de_row = dg_row.to(tl.float32) * df_de
|
89 |
-
de_row = de_row.to(DW_row.dtype)
|
90 |
-
|
91 |
-
# Store derivatives in buffers
|
92 |
-
tl.store(DW + offsets, h_row, mask = mask) # h = f * g
|
93 |
-
tl.store(e + offsets, df_row, mask = mask) # df = DW * f
|
94 |
-
tl.store(g + offsets, de_row, mask = mask) # de
|
95 |
-
pass
|
96 |
-
|
97 |
-
|
98 |
-
def geglu_exact_backward_kernel(DW, e, g):
|
99 |
-
batch_seq_len, hd = e.shape
|
100 |
-
n_elements = e.numel()
|
101 |
-
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
102 |
-
_exact_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
|
103 |
-
return DW, e, g
|
104 |
-
pass
|
105 |
-
|
106 |
-
|
107 |
-
@triton.jit
|
108 |
-
def _approx_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
|
109 |
-
block_idx = tl.program_id(0)
|
110 |
-
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
111 |
-
mask = offsets < n_elements
|
112 |
-
|
113 |
-
# f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) ))
|
114 |
-
# f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ))
|
115 |
-
# h = f * up
|
116 |
-
s = 0.7978845608028654 # math.sqrt(2 / math.pi)
|
117 |
-
|
118 |
-
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
|
119 |
-
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
|
120 |
-
|
121 |
-
f_row = 0.5 * e_row * (
|
122 |
-
triton_tanh(s * e_row * (1.0 + 0.044715 * e_row * e_row)) \
|
123 |
-
+ 1.0
|
124 |
-
)
|
125 |
-
f_row = f_row.to(g_row.dtype) # Exact copy from HF
|
126 |
-
h_row = f_row * g_row
|
127 |
-
|
128 |
-
# Store h
|
129 |
-
tl.store(h + offsets, h_row, mask = mask)
|
130 |
-
pass
|
131 |
-
|
132 |
-
|
133 |
-
def geglu_approx_forward_kernel(gate, up):
|
134 |
-
batch, seq_len, hd = gate.shape
|
135 |
-
n_elements = gate.numel()
|
136 |
-
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0")
|
137 |
-
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
138 |
-
_approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
|
139 |
-
return out
|
140 |
-
pass
|
141 |
-
|
142 |
-
|
143 |
-
@triton.jit
|
144 |
-
def _approx_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
|
145 |
-
"""
|
146 |
-
f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ))
|
147 |
-
h = f * up
|
148 |
-
|
149 |
-
df/de (with help from https://arxiv.org/pdf/2305.12073.pdf :))
|
150 |
-
df/de = 1/2 * [1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )] +
|
151 |
-
1/2 * sech^2 [ sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ] * \
|
152 |
-
( sqrt(2/pi) * x * (1 + 0.044715 * x^2 * 3 ) )
|
153 |
-
|
154 |
-
Notice sech^2(x) = 1 - tanh^2(x)
|
155 |
-
So reuse tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )
|
156 |
-
|
157 |
-
See https://www.desmos.com/calculator/nqprfoni6x
|
158 |
-
"""
|
159 |
-
block_idx = tl.program_id(0)
|
160 |
-
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
161 |
-
mask = offsets < n_elements
|
162 |
-
|
163 |
-
DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
|
164 |
-
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
|
165 |
-
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
|
166 |
-
|
167 |
-
# See https://www.desmos.com/calculator/nqprfoni6x
|
168 |
-
s = 0.7978845608028654 # math.sqrt(2 / math.pi)
|
169 |
-
a = s * e_row # a = sqrt(2 / pi) * x
|
170 |
-
b = a * 0.044715 * e_row * e_row # b = a * 0.044715 * x^2
|
171 |
-
T = 1.0 + triton_tanh(a + b)
|
172 |
-
T2 = 0.5 * T
|
173 |
-
# Q = 0.5 * -T * (T - 2.0) * (a + 3.0 * b)
|
174 |
-
Q2 = -T2 * (T - 2.0) * (a + 3.0 * b)
|
175 |
-
df_de = T2 + Q2 # 1/2 * (T + Q)
|
176 |
-
|
177 |
-
# f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) ))
|
178 |
-
f_row = T2 * e_row
|
179 |
-
f_row = f_row.to(DW_row.dtype)
|
180 |
-
# h = f * g
|
181 |
-
h_row = f_row * g_row
|
182 |
-
# df = DW * f
|
183 |
-
df_row = DW_row * f_row
|
184 |
-
# dg = DW * g
|
185 |
-
dg_row = DW_row * g_row
|
186 |
-
|
187 |
-
de_row = dg_row.to(tl.float32) * df_de
|
188 |
-
de_row = de_row.to(DW_row.dtype)
|
189 |
-
|
190 |
-
# Store derivatives in buffers
|
191 |
-
tl.store(DW + offsets, h_row, mask = mask) # h = f * g
|
192 |
-
tl.store(e + offsets, df_row, mask = mask) # df = DW * f
|
193 |
-
tl.store(g + offsets, de_row, mask = mask) # de
|
194 |
-
pass
|
195 |
-
|
196 |
-
|
197 |
-
def geglu_approx_backward_kernel(DW, e, g):
|
198 |
-
batch_seq_len, hd = e.shape
|
199 |
-
n_elements = e.numel()
|
200 |
-
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
201 |
-
_approx_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
|
202 |
-
return DW, e, g
|
203 |
-
pass
|
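Note: _exact_backward_kernel above reuses f(e) = 0.5 * e * (1 + erf(e / sqrt(2))) and its derivative df/de = 0.5 * (1 + erf(e / sqrt(2))) + e * exp(-e^2 / 2) / sqrt(2 * pi). A small plain-PyTorch check of that derivative against autograd, independent of the Triton kernels:

import math
import torch

# f(e) = 0.5 * e * (1 + erf(e / sqrt(2)))
# df/de = 0.5 * (1 + erf(e / sqrt(2))) + e * exp(-e^2 / 2) / sqrt(2 * pi)
e = torch.linspace(-4, 4, steps = 101, requires_grad = True)
f = 0.5 * e * (1.0 + torch.erf(e * (2.0 ** -0.5)))
f.sum().backward()

df_de = 0.5 * (1.0 + torch.erf(e * (2.0 ** -0.5))) \
      + e * torch.exp(-0.5 * e * e) / math.sqrt(2.0 * math.pi)
assert torch.allclose(e.grad, df_de.detach(), atol = 1e-5)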
unsloth-main/unsloth/kernels/layernorm.py
DELETED
@@ -1,213 +0,0 @@
|
|
1 |
-
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
2 |
-
# Copyright 2024-present Andrej Karpathy & the llm.c team. All rights reserved.
|
3 |
-
#
|
4 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
-
# you may not use this file except in compliance with the License.
|
6 |
-
# You may obtain a copy of the License at
|
7 |
-
#
|
8 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
-
#
|
10 |
-
# Unless required by applicable law or agreed to in writing, software
|
11 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
-
# See the License for the specific language governing permissions and
|
14 |
-
# limitations under the License.
|
15 |
-
|
16 |
-
import triton
|
17 |
-
import triton.language as tl
|
18 |
-
import torch
|
19 |
-
from .utils import calculate_settings
|
20 |
-
from unsloth_zoo.patching_utils import (
|
21 |
-
patch_layernorm,
|
22 |
-
)
|
23 |
-
|
24 |
-
|
25 |
-
@triton.jit
|
26 |
-
def layernorm_forward(
|
27 |
-
Y, Y_row_stride,
|
28 |
-
X, X_row_stride,
|
29 |
-
W,
|
30 |
-
b,
|
31 |
-
r,
|
32 |
-
mu,
|
33 |
-
n_cols, eps,
|
34 |
-
BLOCK_SIZE : tl.constexpr
|
35 |
-
):
|
36 |
-
row_idx = tl.program_id(0)
|
37 |
-
col_offsets = tl.arange(0, BLOCK_SIZE)
|
38 |
-
mask = col_offsets < n_cols
|
39 |
-
|
40 |
-
Y += row_idx * Y_row_stride
|
41 |
-
X += row_idx * X_row_stride
|
42 |
-
r += row_idx
|
43 |
-
mu += row_idx
|
44 |
-
|
45 |
-
# According to https://pytorch.org/torchtune/stable/_modules/torchtune/modules/layer_norm.html#Fp32LayerNorm, all modules
|
46 |
-
# are in float32!
|
47 |
-
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
|
48 |
-
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
|
49 |
-
b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32)
|
50 |
-
|
51 |
-
mean_X = tl.sum(X_row, axis = 0) / n_cols
|
52 |
-
XX = X_row - mean_X
|
53 |
-
row_var = tl.sum(XX * XX, axis = 0) / n_cols
|
54 |
-
inv_var = tl.math.rsqrt(row_var + eps)
|
55 |
-
tl.store (r, inv_var)
|
56 |
-
tl.store (mu, mean_X)
|
57 |
-
output = (XX * inv_var) * W_row + b_row
|
58 |
-
tl.store(Y + col_offsets, output, mask = mask)
|
59 |
-
pass
|
60 |
-
|
61 |
-
|
62 |
-
@triton.jit
|
63 |
-
def layernorm_backward(
|
64 |
-
dY, dY_row_stride,
|
65 |
-
X, X_row_stride,
|
66 |
-
W,
|
67 |
-
b,
|
68 |
-
r,
|
69 |
-
mu,
|
70 |
-
n_cols, eps,
|
71 |
-
BLOCK_SIZE : tl.constexpr
|
72 |
-
):
|
73 |
-
# Approximately follows https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
|
74 |
-
row_idx = tl.program_id(0)
|
75 |
-
col_offsets = tl.arange(0, BLOCK_SIZE)
|
76 |
-
mask = col_offsets < n_cols
|
77 |
-
|
78 |
-
dY += row_idx * dY_row_stride
|
79 |
-
X += row_idx * X_row_stride
|
80 |
-
r += row_idx
|
81 |
-
mu += row_idx
|
82 |
-
|
83 |
-
# According to https://pytorch.org/torchtune/stable/_modules/torchtune/modules/layer_norm.html#Fp32LayerNorm, all modules
|
84 |
-
# are in float32!
|
85 |
-
dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
|
86 |
-
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
|
87 |
-
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
|
88 |
-
b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32)
|
89 |
-
|
90 |
-
inv_var = tl.load(r) .to(tl.float32)
|
91 |
-
mean = tl.load(mu).to(tl.float32)
|
92 |
-
normed = (X_row - mean) * inv_var
|
93 |
-
dY_W = dY_row * W_row
|
94 |
-
dX_row = dY_W - tl.sum(dY_W, axis = 0) / n_cols - normed * tl.sum(dY_W * normed, axis = 0) / n_cols
|
95 |
-
dX_row = dX_row * inv_var
|
96 |
-
tl.store(dY + col_offsets, dX_row, mask = mask)
|
97 |
-
pass
|
98 |
-
|
99 |
-
|
100 |
-
class Fast_Layernorm(torch.autograd.Function):
|
101 |
-
@staticmethod
|
102 |
-
def forward(ctx, X, W, b, eps):
|
103 |
-
shape = X.shape
|
104 |
-
dim = shape[-1]
|
105 |
-
X = X.view(-1, dim)
|
106 |
-
n_rows, n_cols = X.shape
|
107 |
-
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
108 |
-
|
109 |
-
Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0")
|
110 |
-
r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
|
111 |
-
mu = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
|
112 |
-
|
113 |
-
layernorm_forward[(n_rows,)](
|
114 |
-
Y, Y.stride(0),
|
115 |
-
X, X.stride(0),
|
116 |
-
W,
|
117 |
-
b,
|
118 |
-
r,
|
119 |
-
mu,
|
120 |
-
n_cols, eps,
|
121 |
-
BLOCK_SIZE = BLOCK_SIZE,
|
122 |
-
num_warps = num_warps,
|
123 |
-
)
|
124 |
-
ctx.eps = eps
|
125 |
-
ctx.BLOCK_SIZE = BLOCK_SIZE
|
126 |
-
ctx.num_warps = num_warps
|
127 |
-
ctx.save_for_backward(X, W, b, r, mu)
|
128 |
-
return Y.view(*shape)
|
129 |
-
pass
|
130 |
-
|
131 |
-
@staticmethod
|
132 |
-
def backward(ctx, dY):
|
133 |
-
shape = dY.shape
|
134 |
-
dim = shape[-1]
|
135 |
-
dY = dY.view(-1, dim)
|
136 |
-
X, W, b, r, mu = ctx.saved_tensors
|
137 |
-
n_rows, n_cols = dY.shape
|
138 |
-
|
139 |
-
layernorm_backward[(n_rows,)](
|
140 |
-
dY, dY.stride(0),
|
141 |
-
X, X .stride(0),
|
142 |
-
W,
|
143 |
-
b,
|
144 |
-
r,
|
145 |
-
mu,
|
146 |
-
n_cols, ctx.eps,
|
147 |
-
BLOCK_SIZE = ctx.BLOCK_SIZE,
|
148 |
-
num_warps = ctx.num_warps,
|
149 |
-
)
|
150 |
-
dX = dY.view(*shape)
|
151 |
-
return dX, None, None, None, None
|
152 |
-
pass
|
153 |
-
pass
|
154 |
-
|
155 |
-
|
156 |
-
def fast_layernorm(layernorm, X):
|
157 |
-
assert(layernorm.elementwise_affine is True)
|
158 |
-
W = layernorm.weight
|
159 |
-
bias = layernorm.bias
|
160 |
-
eps = layernorm.variance_epsilon if \
|
161 |
-
hasattr(layernorm, "variance_epsilon") \
|
162 |
-
else layernorm.eps
|
163 |
-
out = Fast_Layernorm.apply(X, W, bias, eps)
|
164 |
-
return out
|
165 |
-
pass
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
def test_layernorm(
|
170 |
-
dim = 1024, eps = 1e-5, dtype = torch.float16,
|
171 |
-
bsz = 21, random_state = 3407, seqlen = 3341,
|
172 |
-
):
|
173 |
-
from torch.nn import LayerNorm
|
174 |
-
layernorm = LayerNorm((dim,), eps = eps, device = "cuda", dtype = dtype)
|
175 |
-
torch.cuda.manual_seed(random_state)
|
176 |
-
torch.manual_seed(random_state)
|
177 |
-
torch.nn.init.uniform_(layernorm.weight)
|
178 |
-
torch.nn.init.uniform_(layernorm.bias)
|
179 |
-
X = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda")
|
180 |
-
XX = X.clone()
|
181 |
-
X .requires_grad_(True)
|
182 |
-
XX.requires_grad_(True)
|
183 |
-
Y = layernorm(X)
|
184 |
-
YY = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda", requires_grad = True)
|
185 |
-
Y.backward(YY)
|
186 |
-
correct_grad = X.grad.clone()
|
187 |
-
# from unsloth.kernels import fast_layernorm
|
188 |
-
Y = fast_layernorm(layernorm, XX)
|
189 |
-
Y.backward(YY)
|
190 |
-
assert(torch.dist(correct_grad, XX.grad).item() <= 0.1)
|
191 |
-
pass
|
192 |
-
|
193 |
-
|
194 |
-
def testing_suite_layernorm():
|
195 |
-
for dim in [512, 1024, 2048]:
|
196 |
-
for dtype in [torch.float16, torch.bfloat16]:
|
197 |
-
with torch.autocast(device_type = "cuda", dtype = dtype):
|
198 |
-
for seqlen in [3341, 2048, 349]:
|
199 |
-
for random_state in [3407, 42]:
|
200 |
-
test_layernorm(
|
201 |
-
dim = dim,
|
202 |
-
eps = 1e-5,
|
203 |
-
dtype = dtype,
|
204 |
-
bsz = 21,
|
205 |
-
random_state = random_state,
|
206 |
-
seqlen = seqlen,
|
207 |
-
)
|
208 |
-
pass
|
209 |
-
pass
|
210 |
-
pass
|
211 |
-
pass
|
212 |
-
pass
|
213 |
-
pass
|
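Note: layernorm_backward above applies the usual closed-form LayerNorm gradient, dX = inv_var * (dY*W - mean(dY*W) - normed * mean(dY*W * normed)). A minimal autograd check of that formula in plain PyTorch follows; the sizes are assumptions for illustration only.

import torch

# Assumed sizes; float32 throughout for the check
torch.manual_seed(0)
n_rows, n_cols, eps = 4, 64, 1e-5
X = torch.randn(n_rows, n_cols, requires_grad = True)
W = torch.randn(n_cols)
b = torch.randn(n_cols)

mean = X.mean(dim = -1, keepdim = True)
var = X.var(dim = -1, unbiased = False, keepdim = True)
inv_var = torch.rsqrt(var + eps)
normed = (X - mean) * inv_var
Y = normed * W + b

dY = torch.randn_like(Y)
Y.backward(dY)

# Row-wise closed form, as in layernorm_backward
dY_W = dY * W
dX = (dY_W
      - dY_W.mean(dim = -1, keepdim = True)
      - normed.detach() * (dY_W * normed.detach()).mean(dim = -1, keepdim = True)
     ) * inv_var.detach()
assert torch.allclose(X.grad, dX, atol = 1e-5)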
unsloth-main/unsloth/kernels/rms_layernorm.py
DELETED
@@ -1,297 +0,0 @@
|
|
1 |
-
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
2 |
-
#
|
3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
-
# you may not use this file except in compliance with the License.
|
5 |
-
# You may obtain a copy of the License at
|
6 |
-
#
|
7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
-
#
|
9 |
-
# Unless required by applicable law or agreed to in writing, software
|
10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
-
# See the License for the specific language governing permissions and
|
13 |
-
# limitations under the License.
|
14 |
-
|
15 |
-
import triton
|
16 |
-
import triton.language as tl
|
17 |
-
import torch
|
18 |
-
from .utils import calculate_settings
|
19 |
-
|
20 |
-
|
21 |
-
@triton.jit
|
22 |
-
def _rms_layernorm_forward(
|
23 |
-
Y, Y_row_stride,
|
24 |
-
X, X_row_stride,
|
25 |
-
W, W_row_stride,
|
26 |
-
r, r_row_stride,
|
27 |
-
n_cols, eps,
|
28 |
-
BLOCK_SIZE : tl.constexpr
|
29 |
-
):
|
30 |
-
"""
|
31 |
-
Fast RMS Layernorm kernel
|
32 |
-
Inspiration from a Triton tutorial:
|
33 |
-
https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
34 |
-
"""
|
35 |
-
row_idx = tl.program_id(0)
|
36 |
-
col_offsets = tl.arange(0, BLOCK_SIZE)
|
37 |
-
mask = col_offsets < n_cols
|
38 |
-
|
39 |
-
Y += row_idx * Y_row_stride
|
40 |
-
X += row_idx * X_row_stride
|
41 |
-
r += row_idx * r_row_stride
|
42 |
-
|
43 |
-
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
|
44 |
-
W_row = tl.load(W + col_offsets, mask = mask, other = 0)#.to(tl.float32)
|
45 |
-
|
46 |
-
row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
|
47 |
-
inv_var = tl.math.rsqrt(row_var + eps)
|
48 |
-
tl.store(r, inv_var)
|
49 |
-
normed = X_row * inv_var
|
50 |
-
normed = normed.to(W_row.dtype) # Exact copy from HF
|
51 |
-
output = normed * W_row
|
52 |
-
tl.store(Y + col_offsets, output, mask = mask)
|
53 |
-
pass
|
54 |
-
|
55 |
-
|
56 |
-
@triton.heuristics({"GEMMA": lambda args: bool(args["GEMMA"]),})
|
57 |
-
@triton.jit
|
58 |
-
def _rms_layernorm_backward(
|
59 |
-
dY, dY_row_stride,
|
60 |
-
dX, dX_row_stride,
|
61 |
-
X, X_row_stride,
|
62 |
-
W, W_row_stride,
|
63 |
-
r, r_row_stride,
|
64 |
-
# dW, dW_row_stride,
|
65 |
-
n_cols, eps,
|
66 |
-
GEMMA : tl.constexpr,
|
67 |
-
BLOCK_SIZE : tl.constexpr,
|
68 |
-
):
|
69 |
-
"""
|
70 |
-
Fast RMS Layernorm kernel for the backward pass
|
71 |
-
Inspiration from a Triton tutorial:
|
72 |
-
https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
73 |
-
"""
|
74 |
-
row_idx = tl.program_id(0)
|
75 |
-
col_offsets = tl.arange(0, BLOCK_SIZE)
|
76 |
-
mask = col_offsets < n_cols
|
77 |
-
|
78 |
-
dY += row_idx * dY_row_stride
|
79 |
-
X += row_idx * X_row_stride
|
80 |
-
r += row_idx * r_row_stride
|
81 |
-
|
82 |
-
if GEMMA: dX += row_idx * dY_row_stride
|
83 |
-
else: dX = dY
|
84 |
-
|
85 |
-
dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
|
86 |
-
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
|
87 |
-
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
|
88 |
-
|
89 |
-
# Get saved row variance
|
90 |
-
inv_var = tl.load(r).to(tl.float32)
|
91 |
-
normed = X_row * inv_var
|
92 |
-
|
93 |
-
if GEMMA: dY_W = dY_row * (W_row + 1.0)
|
94 |
-
else: dY_W = dY_row * W_row
|
95 |
-
|
96 |
-
rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)
|
97 |
-
output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)
|
98 |
-
tl.store(dX + col_offsets, output, mask = mask)
|
99 |
-
pass
|
100 |
-
|
101 |
-
|
102 |
-
@triton.jit
|
103 |
-
def _gemma_rms_layernorm_forward(
|
104 |
-
Y, Y_row_stride,
|
105 |
-
X, X_row_stride,
|
106 |
-
W, W_row_stride,
|
107 |
-
r, r_row_stride,
|
108 |
-
n_cols, eps,
|
109 |
-
BLOCK_SIZE : tl.constexpr,
|
110 |
-
):
|
111 |
-
# Copies https://github.com/google-deepmind/gemma/blob/main/gemma/layers.py#L31
|
112 |
-
# and https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L33
|
113 |
-
# exactly. Essentially all in float32!
|
114 |
-
row_idx = tl.program_id(0)
|
115 |
-
col_offsets = tl.arange(0, BLOCK_SIZE)
|
116 |
-
mask = col_offsets < n_cols
|
117 |
-
|
118 |
-
Y += row_idx * Y_row_stride
|
119 |
-
X += row_idx * X_row_stride
|
120 |
-
r += row_idx * r_row_stride
|
121 |
-
|
122 |
-
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
|
123 |
-
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
|
124 |
-
|
125 |
-
row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
|
126 |
-
inv_var = tl.math.rsqrt(row_var + eps)
|
127 |
-
tl.store(r, inv_var)
|
128 |
-
normed = X_row * inv_var
|
129 |
-
output = normed * (W_row + 1.0)
|
130 |
-
|
131 |
-
tl.store(Y + col_offsets, output, mask = mask)
|
132 |
-
pass
|
133 |
-
|
134 |
-
|
135 |
-
class Fast_RMS_Layernorm(torch.autograd.Function):
|
136 |
-
@staticmethod
|
137 |
-
def forward(ctx, X : torch.Tensor, W : torch.Tensor, eps : float, gemma : bool = False):
|
138 |
-
shape = X.shape
|
139 |
-
dim : int = shape[-1]
|
140 |
-
X = X.view(-1, dim)
|
141 |
-
n_rows : int
|
142 |
-
n_cols : int
|
143 |
-
n_rows, n_cols = X.shape
|
144 |
-
BLOCK_SIZE : int
|
145 |
-
num_warps : int
|
146 |
-
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
147 |
-
|
148 |
-
Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0")
|
149 |
-
r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
|
150 |
-
|
151 |
-
fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward
|
152 |
-
fx[(n_rows,)](
|
153 |
-
Y, Y.stride(0),
|
154 |
-
X, X.stride(0),
|
155 |
-
W, W.stride(0),
|
156 |
-
r, r.stride(0),
|
157 |
-
n_cols, eps,
|
158 |
-
BLOCK_SIZE = BLOCK_SIZE,
|
159 |
-
num_warps = num_warps,
|
160 |
-
)
|
161 |
-
ctx.eps = eps
|
162 |
-
ctx.BLOCK_SIZE = BLOCK_SIZE
|
163 |
-
ctx.num_warps = num_warps
|
164 |
-
ctx.GEMMA = gemma
|
165 |
-
ctx.save_for_backward(X, W, r)
|
166 |
-
return Y.view(*shape)
|
167 |
-
pass
|
168 |
-
|
169 |
-
@staticmethod
|
170 |
-
def backward(ctx, dY : torch.Tensor):
|
171 |
-
shape = dY.shape
|
172 |
-
dim : int = shape[-1]
|
173 |
-
dY = dY.view(-1, dim)
|
174 |
-
X, W, r = ctx.saved_tensors
|
175 |
-
n_rows : int
|
176 |
-
n_cols : int
|
177 |
-
n_rows, n_cols = dY.shape
|
178 |
-
# dW = X
|
179 |
-
dX = torch.empty_like(dY, device = "cuda:0") if ctx.GEMMA else dY
|
180 |
-
|
181 |
-
_rms_layernorm_backward[(n_rows,)](
|
182 |
-
dY, dY.stride(0),
|
183 |
-
dX, dX.stride(0),
|
184 |
-
X, X .stride(0),
|
185 |
-
W, W .stride(0),
|
186 |
-
r, r .stride(0),
|
187 |
-
# dW, dW.stride(0),
|
188 |
-
n_cols, ctx.eps,
|
189 |
-
GEMMA = ctx.GEMMA,
|
190 |
-
BLOCK_SIZE = ctx.BLOCK_SIZE,
|
191 |
-
num_warps = ctx.num_warps,
|
192 |
-
)
|
193 |
-
dX = dX.view(*shape)
|
194 |
-
return dX, None, None, None
|
195 |
-
pass
|
196 |
-
pass
|
197 |
-
|
198 |
-
|
199 |
-
# [TODO] Unsure why RMS Layernorm is not torch.compiling properly
|
200 |
-
@torch.compiler.disable
|
201 |
-
def fast_rms_layernorm(layernorm, X : torch.Tensor, gemma : bool = False):
|
202 |
-
W : torch.Tensor = layernorm.weight
|
203 |
-
eps : float = layernorm.variance_epsilon if \
|
204 |
-
hasattr(layernorm, "variance_epsilon") \
|
205 |
-
else layernorm.eps
|
206 |
-
out = Fast_RMS_Layernorm.apply(X, W, eps, gemma)
|
207 |
-
return out
|
208 |
-
pass
|
209 |
-
|
210 |
-
|
211 |
-
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
212 |
-
class Unsloth_LlamaRMSNorm(LlamaRMSNorm):
|
213 |
-
def forward(self, X):
|
214 |
-
return fast_rms_layernorm(self, X, gemma = False)
|
215 |
-
pass
|
216 |
-
pass
|
217 |
-
|
218 |
-
try:
|
219 |
-
from transformers.models.mllama.modeling_mllama import MllamaTextRMSNorm
|
220 |
-
class Unsloth_MllamaTextRMSNorm(MllamaTextRMSNorm):
|
221 |
-
def forward(self, X):
|
222 |
-
return fast_rms_layernorm(self, X, gemma = False)
|
223 |
-
pass
|
224 |
-
pass
|
225 |
-
except:
|
226 |
-
pass
|
227 |
-
pass
|
228 |
-
|
229 |
-
def patch_rms_layernorm():
|
230 |
-
import transformers.models.llama.modeling_llama
|
231 |
-
transformers.models.llama.modeling_llama.LlamaRMSNorm = Unsloth_LlamaRMSNorm
|
232 |
-
try:
|
233 |
-
import transformers.models.mllama.modeling_mllama
|
234 |
-
transformers.models.mllama.modeling_mllama.MllamaTextRMSNorm = Unsloth_MllamaTextRMSNorm
|
235 |
-
except:
|
236 |
-
pass
|
237 |
-
return
|
238 |
-
pass
|
239 |
-
|
240 |
-
|
241 |
-
def unpatch_rms_layernorm():
|
242 |
-
import transformers.models.llama.modeling_llama
|
243 |
-
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
244 |
-
try:
|
245 |
-
import transformers.models.mllama.modeling_mllama
|
246 |
-
transformers.models.mllama.modeling_mllama.MllamaTextRMSNorm = MllamaTextRMSNorm
|
247 |
-
except:
|
248 |
-
pass
|
249 |
-
return
|
250 |
-
return
|
251 |
-
pass
|
252 |
-
|
253 |
-
|
254 |
-
def test_rms_layernorm(
|
255 |
-
dim = 1024, eps = 1e-5, dtype = torch.float16,
|
256 |
-
bsz = 21, random_state = 3407, seqlen = 3341,
|
257 |
-
):
|
258 |
-
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
259 |
-
layernorm = LlamaRMSNorm((dim,), eps = eps).to("cuda")
|
260 |
-
torch.cuda.manual_seed(random_state)
|
261 |
-
torch.manual_seed(random_state)
|
262 |
-
torch.nn.init.uniform_(layernorm.weight)
|
263 |
-
X = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda")
|
264 |
-
XX = X.clone()
|
265 |
-
X .requires_grad_(True)
|
266 |
-
XX.requires_grad_(True)
|
267 |
-
Y = layernorm(X)
|
268 |
-
YY = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda", requires_grad = True)
|
269 |
-
Y.backward(YY)
|
270 |
-
correct_grad = X.grad.clone()
|
271 |
-
# from unsloth.kernels import fast_rms_layernorm
|
272 |
-
Y = fast_rms_layernorm(layernorm, XX)
|
273 |
-
Y.backward(YY)
|
274 |
-
assert(torch.amax(correct_grad - XX.grad).item() <= 0.05)
|
275 |
-
pass
|
276 |
-
|
277 |
-
|
278 |
-
def testing_suite_layernorm():
|
279 |
-
for dim in [512, 1024, 2048]:
|
280 |
-
for dtype in [torch.float16, torch.bfloat16]:
|
281 |
-
with torch.autocast(device_type = "cuda", dtype = dtype):
|
282 |
-
for seqlen in [3341, 2048, 349]:
|
283 |
-
for random_state in [3407, 42]:
|
284 |
-
test_rms_layernorm(
|
285 |
-
dim = dim,
|
286 |
-
eps = 1e-5,
|
287 |
-
dtype = dtype,
|
288 |
-
bsz = 21,
|
289 |
-
random_state = random_state,
|
290 |
-
seqlen = seqlen,
|
291 |
-
)
|
292 |
-
pass
|
293 |
-
pass
|
294 |
-
pass
|
295 |
-
pass
|
296 |
-
pass
|
297 |
-
pass
|
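Note: a plain-PyTorch reference of what _rms_layernorm_forward and the Gemma variant above compute (the Llama-style path casts the normalized activations back to the input dtype before scaling by W; the Gemma path stays in float32 and scales by W + 1). The shapes below are illustrative assumptions, for comparison purposes only.

import torch

def rms_layernorm_ref(X, W, eps = 1e-5, gemma = False):
    X32 = X.float()
    inv_var = torch.rsqrt(X32.pow(2).mean(dim = -1, keepdim = True) + eps)
    normed = X32 * inv_var
    if gemma:
        # Gemma variant: stay in float32 and scale by (W + 1)
        return (normed * (W.float() + 1.0)).to(X.dtype)
    # Llama-style: cast the normalized activations back before scaling by W
    return normed.to(X.dtype) * W

X = torch.randn(2, 16, 64).half()
W = torch.ones(64, dtype = torch.float16)
Y_llama = rms_layernorm_ref(X, W)
Y_gemma = rms_layernorm_ref(X, W, gemma = True)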
unsloth-main/unsloth/kernels/rope_embedding.py
DELETED
@@ -1,196 +0,0 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import triton
import triton.language as tl
import torch
from .utils import calculate_settings
ROPE_GROUP_SIZE : int = 4

@triton.heuristics({"BACKWARD_PASS": lambda args: bool(args["BACKWARD_PASS"]),})
@triton.jit
def _rope_embedding(
    Q, Q_row_stride,
    cos, cos_row_stride,
    sin, sin_row_stride,
    seqlen,
    head_dim      : tl.constexpr,
    n_heads       : tl.constexpr,
    BACKWARD_PASS : tl.constexpr,
    BLOCK_SIZE    : tl.constexpr,
):
    """
        Calculates the RoPE Embedding quickly
        RoPE is Q * cos + rotate_half(Q) * sin
        See our blog post for more info
    """
    ROPE_GROUP_SIZE = 4
    row_position = tl.program_id(0)
    group_head_position = tl.program_id(1)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    half_head_dim = head_dim // 2
    mask = col_offsets < half_head_dim

    sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \
                   half_head_dim*0 + col_offsets, mask = mask, other = 0)
    cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \
                   half_head_dim*0 + col_offsets, mask = mask, other = 0)

    if BACKWARD_PASS:
        # See our blog post for more info.
        sin1 = -sin1
    pass

    # [TODO] Autotune ROPE_GROUP_SIZE to be 1, 2, 4, 8
    head_start = group_head_position * ROPE_GROUP_SIZE
    head_end = min((head_start + ROPE_GROUP_SIZE), n_heads)

    # 10% Faster kernel from [HuyNguyen-hust](https://github.com/unslothai/unsloth/pull/238)
    for k in range(head_start, head_end):
        offs_q1 = row_position * Q_row_stride + k * head_dim + col_offsets
        offs_q2 = row_position * Q_row_stride + k * head_dim + col_offsets + half_head_dim

        # For Gemma - sometimes RoPE must be done in float32 and not bfloat16
        Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)
        Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)

        tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask)
        tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)
    pass
pass


class Fast_RoPE_Embedding(torch.autograd.Function):
    @staticmethod
    def forward(ctx, Q, cos, sin):
        cos, sin = cos.squeeze(), sin.squeeze()
        batch    : int
        seq_len  : int
        n_heads  : int
        head_dim : int
        batch, seq_len, n_heads, head_dim = Q.shape
        Q = Q.view(batch*seq_len, n_heads*head_dim)
        n_rows : int
        n_cols : int
        n_rows, n_cols = Q.shape
        assert(seq_len <= cos.shape[0])

        # [TODO] Changing blocksize to head_dim//2 seems to have
        # some concurrency / un-deterministic issues.
        BLOCK_SIZE, num_warps = calculate_settings(head_dim//2) # (head_dim//2)

        # group_size = 4 # 4 or 8, too large group_size can hurt performance.
        div : int
        mod : int
        div, mod = divmod(n_heads, ROPE_GROUP_SIZE)
        n_groups : int = div + (mod != 0)

        _rope_embedding[(n_rows, n_groups, )](
            Q,   Q.stride(0),
            cos, cos.stride(0),
            sin, sin.stride(0),
            seq_len,
            head_dim, n_heads,
            BACKWARD_PASS = False,
            BLOCK_SIZE = BLOCK_SIZE,
            num_warps = num_warps,
        )
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.n_groups = n_groups
        ctx.cos = cos
        ctx.sin = sin
        return Q.view(batch, seq_len, n_heads, head_dim)
    pass

    @staticmethod
    def backward(ctx, dY):
        batch    : int
        seq_len  : int
        n_heads  : int
        head_dim : int
        batch, seq_len, n_heads, head_dim = dY.shape
        dY = dY.reshape(batch*seq_len, n_heads*head_dim)
        # Must be reshape not view
        n_rows : int
        n_cols : int
        n_rows, n_cols = dY.shape

        cos = ctx.cos
        sin = ctx.sin

        _rope_embedding[(n_rows, ctx.n_groups, )](
            dY,  dY .stride(0),
            cos, cos.stride(0),
            sin, sin.stride(0),
            seq_len, head_dim, n_heads,
            BACKWARD_PASS = True,
            BLOCK_SIZE = ctx.BLOCK_SIZE,
            num_warps = ctx.num_warps,
        )
        dY = dY.view(batch, seq_len, n_heads, head_dim)
        return dY, None, None,
    pass
pass

# [TODO] Unsure why RoPE Embedding is not torch.compiling properly
@torch.compiler.disable
def fast_rope_embedding(Q, K, cos, sin):
    Q = Fast_RoPE_Embedding.apply(Q.transpose(1, 2), cos, sin).transpose(1, 2)
    K = Fast_RoPE_Embedding.apply(K.transpose(1, 2), cos, sin).transpose(1, 2)
    return Q, K
pass


class Slow_RoPE_Embedding(torch.autograd.Function):
    @staticmethod
    def forward(ctx, Q, cos, sin, position_ids):
        if position_ids is not None:
            # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
            cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
            sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
            cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
            sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]

        # Q * cos + rotate_half(Q) * sin
        half = Q.shape[-1]//2
        RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)
        Q *= cos
        Q.addcmul_(RH_Q, sin)
        # RH_Q *= sin
        # Q += RH_Q
        ctx.save_for_backward(cos, sin)
        return Q
    pass

    @staticmethod
    def backward(ctx, dY):
        cos, sin = ctx.saved_tensors
        # Q * cos + rotate_half.T(Q) * sin
        half = dY.shape[-1]//2
        RH_dY = torch.cat((dY[..., half:], -dY[..., :half]), dim = -1)
        dY *= cos
        dY.addcmul_(RH_dY, sin)
        # RH_dY *= sin
        # dY += RH_dY
        return dY, None, None, None
    pass
pass


def inplace_rope_embedding(Q, K, cos, sin, position_ids):
    Q = Slow_RoPE_Embedding.apply(Q, cos, sin, position_ids)
    K = Slow_RoPE_Embedding.apply(K, cos, sin, position_ids)
    return Q, K
pass
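Not in the original file: a hedged usage sketch for fast_rope_embedding, assuming the module is importable as unsloth.kernels.rope_embedding on a CUDA machine. The cos/sin tables are built by hand here purely for illustration; inside Unsloth they come from the model's rotary embedding module.

# Hedged sketch; import path, shapes and the 10000 base are illustrative assumptions.
import torch
from unsloth.kernels.rope_embedding import fast_rope_embedding  # assumed import path

bsz, n_heads, seq_len, head_dim = 2, 8, 128, 64
device, dtype = "cuda", torch.float16

# Standard RoPE frequencies: theta_i = 10000^(-2i/d), angles = position * theta_i
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2, device = device).float() / head_dim))
angles   = torch.arange(seq_len, device = device).float()[:, None] * inv_freq[None, :]
emb      = torch.cat((angles, angles), dim = -1)           # [seq_len, head_dim]
cos, sin = emb.cos().to(dtype), emb.sin().to(dtype)

# fast_rope_embedding takes Q, K as [bsz, n_heads, seq_len, head_dim]; it transposes
# to [bsz, seq_len, n_heads, head_dim] internally before launching the Triton kernel.
Q = torch.randn(bsz, n_heads, seq_len, head_dim, device = device, dtype = dtype)
K = torch.randn(bsz, n_heads, seq_len, head_dim, device = device, dtype = dtype)
Q_rot, K_rot = fast_rope_embedding(Q, K, cos, sin)
print(Q_rot.shape, K_rot.shape)  # both torch.Size([2, 8, 128, 64])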
unsloth-main/unsloth/kernels/swiglu.py
DELETED
@@ -1,99 +0,0 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import triton
import triton.language as tl
import torch
from .utils import calculate_settings


@triton.jit
def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
    block_idx = tl.program_id(0)
    offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
    g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)

    # f = e * sigmoid(e)
    f_row = e_row * tl.sigmoid(e_row) # e_row / (1 + tl.exp(-e_row))
    f_row = f_row.to(g_row.dtype) # Exact copy from HF
    # h = f * g
    h_row = f_row * g_row

    # Store h
    tl.store(h + offsets, h_row, mask = mask)
pass


def swiglu_fg_kernel(e, g):
    batch, seq_len, hd = e.shape
    n_elements = e.numel()
    h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = "cuda:0")
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    _fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,)
    return h
pass


@triton.jit
def _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
    """
    e = e.float()
    se = 1.0 / (1.0 + torch.exp(-e))
    f = (se * e).to(dtype)
    h = f * g
    df = DW * f
    dg = DW * g
    de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
    """
    block_idx = tl.program_id(0)
    offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
    e_row  = tl.load(e  + offsets, mask = mask, other = 0).to(tl.float32)
    g_row  = tl.load(g  + offsets, mask = mask, other = 0)#.to(tl.float32)

    # e = e.float()
    # se = 1.0 / (1.0 + torch.exp(-e))
    se_row = tl.sigmoid(e_row) # 1.0 / (1.0 + tl.exp(-e_row))
    # f = (se * e).to(dtype)
    f_row = se_row * e_row
    f_row = f_row.to(DW_row.dtype)
    # h = f * g
    h_row = f_row * g_row
    # df = DW * f
    df_row = DW_row * f_row
    # dg = DW * g
    dg_row = DW_row * g_row
    # de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
    de_row = dg_row.to(tl.float32) * se_row * (1.0 + e_row * (1.0 - se_row))
    de_row = de_row.to(DW_row.dtype)

    # Store derivatives in buffers
    tl.store(DW + offsets, h_row,  mask = mask) # h  = f * g
    tl.store(e  + offsets, df_row, mask = mask) # df = DW * f
    tl.store(g  + offsets, de_row, mask = mask) # de
pass


def swiglu_DWf_DW_dfg_kernel(DW, e, g):
    batch_seq_len, hd = e.shape
    n_elements = e.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    _DWf_DW_dfg_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
    return DW, e, g
pass
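Not part of the diff: a hedged sketch of where the forward kernel above sits in a SwiGLU MLP, assuming the module is importable as unsloth.kernels.swiglu and a CUDA device. The weight names (gate_proj / up_proj / down_proj) follow the Llama convention and are assumptions for illustration.

# Hedged sketch; shapes, scaling and the import path are illustrative assumptions.
import torch
from unsloth.kernels.swiglu import swiglu_fg_kernel  # assumed import path

bsz, seq_len, hidden, intermediate = 2, 64, 1024, 2816
dtype, device = torch.float16, "cuda"
X         = torch.randn(bsz, seq_len, hidden, dtype = dtype, device = device)
gate_proj = torch.randn(intermediate, hidden, dtype = dtype, device = device) * 0.02
up_proj   = torch.randn(intermediate, hidden, dtype = dtype, device = device) * 0.02
down_proj = torch.randn(hidden, intermediate, dtype = dtype, device = device) * 0.02

e   = X @ gate_proj.t()        # gate branch
g   = X @ up_proj.t()          # up branch
h   = swiglu_fg_kernel(e, g)   # fused h = (e * sigmoid(e)) * g, i.e. SiLU(e) * g
out = h @ down_proj.t()
print(out.shape)               # torch.Size([2, 64, 1024])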
unsloth-main/unsloth/kernels/utils.py
DELETED
@@ -1,422 +0,0 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import triton
MAX_FUSED_SIZE : int = 65536
next_power_of_2 = triton.next_power_of_2

# torch.cuda.amp.custom_fwd is deprecated >= 2.4
import torch
from packaging.version import Version
if Version(torch.__version__) < Version("2.4.0"):
    torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
    torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
else:
    torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
    torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
pass


# tl.math.tanh now is libdevice.tanh
from packaging.version import Version
import triton
import triton.language as tl
if Version(triton.__version__) >= Version("3.0.0"):
    from triton.language.extra import libdevice
    triton_tanh = libdevice.tanh
    triton_cast = tl.cast
else:
    triton_tanh = tl.math.tanh
    # No casting in old Triton versions
    @triton.jit
    def triton_cast(x, dtype):
        return x.to(dtype)
    pass
pass


def calculate_settings(n : int) -> (int, int,):
    BLOCK_SIZE : int = next_power_of_2(n)
    if BLOCK_SIZE > MAX_FUSED_SIZE:
        raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
                           f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
    num_warps : int = 4
    if   BLOCK_SIZE >= 32768: num_warps = 32
    elif BLOCK_SIZE >=  8192: num_warps = 16
    elif BLOCK_SIZE >=  2048: num_warps = 8
    return BLOCK_SIZE, num_warps
pass


import bitsandbytes as bnb
# https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1330/files
HAS_CUDA_STREAM = Version(bnb.__version__) > Version("0.43.3")
global CUDA_STREAM
CUDA_STREAM = None
get_ptr = bnb.functional.get_ptr
import ctypes
cdequantize_blockwise_fp32      = bnb.functional.lib.cdequantize_blockwise_fp32
cdequantize_blockwise_fp16_nf4  = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
cdequantize_blockwise_bf16_nf4  = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16
cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16


def QUANT_STATE(W):
    return getattr(W, "quant_state", None)
pass


def get_lora_parameters(proj):
    # For DPO or disabled adapters
    base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj)
    W = base_layer.weight

    if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
        return W, QUANT_STATE(W), None, None, None
    pass

    active_adapter = proj.active_adapters[0] if \
        hasattr(proj, "active_adapters") else proj.active_adapter
    A = proj.lora_A [active_adapter].weight
    B = proj.lora_B [active_adapter].weight
    s = proj.scaling[active_adapter]
    return W, QUANT_STATE(W), A, B, s
pass


def get_lora_parameters_bias(proj):
    # For DPO or disabled adapters
    base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj)
    W = base_layer.weight
    bias = base_layer.bias

    if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
        return W, QUANT_STATE(W), None, None, None, bias
    pass

    active_adapter = proj.active_adapters[0] if \
        hasattr(proj, "active_adapters") else proj.active_adapter
    A = proj.lora_A [active_adapter].weight
    B = proj.lora_B [active_adapter].weight
    s = proj.scaling[active_adapter]
    return W, QUANT_STATE(W), A, B, s, bias
pass


if HAS_CUDA_STREAM:
    def fast_dequantize(W, quant_state = None, out = None):
        if quant_state is None: return W
        if type(quant_state) is not list:
            # New quant_state as a class
            # https://github.com/TimDettmers/bitsandbytes/pull/763/files
            absmax     = quant_state.absmax
            shape      = quant_state.shape
            dtype      = quant_state.dtype
            blocksize  = quant_state.blocksize
            offset     = quant_state.offset
            state2     = quant_state.state2
            absmax2    = state2.absmax
            code2      = state2.code
            blocksize2 = state2.blocksize
        else:
            # Old quant_state as a list of lists
            absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
            offset, state2 = compressed_stats
            absmax2, code2, blocksize2, _, _, _, _ = state2
        pass
        global CUDA_STREAM
        if CUDA_STREAM is None: CUDA_STREAM = torch.cuda.current_stream("cuda:0")

        # Create weight matrix
        if out is None:
            out = torch.empty(shape, dtype = dtype, device = "cuda:0")
        else:
            assert(out.shape == shape)
            assert(out.dtype == dtype)

        # NF4 dequantization of statistics
        n_elements_absmax = absmax.numel()
        out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0")

        # Do dequantization
        ptr_out_absmax = get_ptr(out_absmax)
        cdequantize_blockwise_fp32(
            get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
            ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax), CUDA_STREAM,
        )
        out_absmax += offset

        fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \
            cdequantize_blockwise_bf16_nf4
        fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
           ctypes.c_int(blocksize), ctypes.c_int(out.numel()), CUDA_STREAM,)

        # Careful returning transposed data
        is_transposed = (True if W.shape[0] == 1 else False)
        return out.t() if is_transposed else out
    pass
else:
    def fast_dequantize(W, quant_state = None, out = None):
        if quant_state is None: return W
        if type(quant_state) is not list:
            # New quant_state as a class
            # https://github.com/TimDettmers/bitsandbytes/pull/763/files
            absmax     = quant_state.absmax
            shape      = quant_state.shape
            dtype      = quant_state.dtype
            blocksize  = quant_state.blocksize
            offset     = quant_state.offset
            state2     = quant_state.state2
            absmax2    = state2.absmax
            code2      = state2.code
            blocksize2 = state2.blocksize
        else:
            # Old quant_state as a list of lists
            absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
            offset, state2 = compressed_stats
            absmax2, code2, blocksize2, _, _, _, _ = state2
        pass

        # Create weight matrix
        if out is None:
            out = torch.empty(shape, dtype = dtype, device = "cuda:0")
        else:
            assert(out.shape == shape)
            assert(out.dtype == dtype)

        # NF4 dequantization of statistics
        n_elements_absmax = absmax.numel()
        out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0")

        # Do dequantization
        ptr_out_absmax = get_ptr(out_absmax)
        cdequantize_blockwise_fp32(
            get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
            ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax),
        )
        out_absmax += offset

        fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \
            cdequantize_blockwise_bf16_nf4
        fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
           ctypes.c_int(blocksize), ctypes.c_int(out.numel()),)

        # Careful returning transposed data
        is_transposed = (True if W.shape[0] == 1 else False)
        return out.t() if is_transposed else out
    pass
pass


if HAS_CUDA_STREAM:
    def fast_gemv(X, W, quant_state, out = None):
        if quant_state is None: return torch.matmul(X, W, out = out)
        # For fast X @ W where seq_len == 1
        # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
        _, q_len, hd = X.shape
        # assert(q_len == 1)

        if type(quant_state) is not list:
            # https://github.com/TimDettmers/bitsandbytes/pull/763/files
            absmax     = quant_state.absmax
            shape      = quant_state.shape
            dtype      = quant_state.dtype
            blocksize  = quant_state.blocksize
            stats      = quant_state.code
            offset     = quant_state.offset
            state2     = quant_state.state2
            absmax2    = state2.absmax
            code2      = state2.code
            blocksize2 = state2.blocksize
        else:
            absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state
            offset, state2 = compressed_stats
            absmax2, code2, blocksize2, _, _, _, _ = state2
        pass
        global CUDA_STREAM
        if CUDA_STREAM is None: CUDA_STREAM = torch.cuda.current_stream("cuda:0")

        # assert(dtype == X.dtype)
        bout = shape[0]

        if out is None:
            out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda:0")
        # else:
        #     assert(out.shape == (1, 1, bout,))
        # pass

        n = 1
        m = shape[0]
        k = shape[1]
        lda = shape[0]
        ldc = shape[0]
        ldb = (hd+1)//2
        m = ctypes.c_int32(m)
        n = ctypes.c_int32(n)
        k = ctypes.c_int32(k)
        lda = ctypes.c_int32(lda)
        ldb = ctypes.c_int32(ldb)
        ldc = ctypes.c_int32(ldc)

        df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda:0")
        cdequantize_blockwise_fp32(
            get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
            ctypes.c_int(blocksize2), ctypes.c_int(df.numel()), CUDA_STREAM,
        )
        df += offset
        absmax = df

        fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \
            cgemm_4bit_inference_naive_bf16

        blocksize = ctypes.c_int32(blocksize)
        fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
           lda, ldb, ldc, blocksize, CUDA_STREAM,)

        return out
    pass
else:
    def fast_gemv(X, W, quant_state, out = None):
        if quant_state is None: return torch.matmul(X, W, out = out)
        # For fast X @ W where seq_len == 1
        # From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
        _, q_len, hd = X.shape
        # assert(q_len == 1)

        if type(quant_state) is not list:
            # https://github.com/TimDettmers/bitsandbytes/pull/763/files
            absmax     = quant_state.absmax
            shape      = quant_state.shape
            dtype      = quant_state.dtype
            blocksize  = quant_state.blocksize
            stats      = quant_state.code
            offset     = quant_state.offset
            state2     = quant_state.state2
            absmax2    = state2.absmax
            code2      = state2.code
            blocksize2 = state2.blocksize
        else:
            absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state
            offset, state2 = compressed_stats
            absmax2, code2, blocksize2, _, _, _, _ = state2
        pass
        # assert(dtype == X.dtype)
        bout = shape[0]

        if out is None:
            out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda:0")
        # else:
        #     assert(out.shape == (1, 1, bout,))
        # pass

        n = 1
        m = shape[0]
        k = shape[1]
        lda = shape[0]
        ldc = shape[0]
        ldb = (hd+1)//2
        m = ctypes.c_int32(m)
        n = ctypes.c_int32(n)
        k = ctypes.c_int32(k)
        lda = ctypes.c_int32(lda)
        ldb = ctypes.c_int32(ldb)
        ldc = ctypes.c_int32(ldc)

        df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda:0")
        cdequantize_blockwise_fp32(
            get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
            ctypes.c_int(blocksize2), ctypes.c_int(df.numel()),
        )
        df += offset
        absmax = df

        fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \
            cgemm_4bit_inference_naive_bf16

        blocksize = ctypes.c_int32(blocksize)
        fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
           lda, ldb, ldc, blocksize,)

        return out
    pass
pass


def fast_linear_forward(proj, X, temp_lora = None, out = None):

    W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj)
    bsz, q_len, in_dim = X.shape
    if q_len != 1: return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)

    if W_quant is None:
        out = torch.matmul(X, W.t(), out = out)
    elif bsz == 1 and q_len == 1:
        out = fast_gemv(X, W, W_quant, out = out)
    else:
        W = fast_dequantize(W.t(), W_quant)
        out = torch.matmul(X, W, out = out)
    pass

    # Add in LoRA weights
    if lora_A is not None:
        out_dim = out.shape[2]
        dtype = X.dtype

        if not hasattr(lora_A, "_fast_lora"):
            lora_A._fast_lora = lora_A.to(dtype)
            lora_B._fast_lora = lora_B.to(dtype)
        pass

        if bsz == 1:
            out = out.view(out_dim)
            temp_lora = torch.mv(lora_A._fast_lora, X.ravel(), out = temp_lora)
            out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S)
        else:
            out = out.view(bsz, out_dim)
            temp_lora = torch.mm(X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora)
            out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha = lora_S)
        pass
        out = out.view(bsz, 1, out_dim)
    pass

    if bias is not None: out += bias

    return out
pass


def matmul_lora(X, W, W_quant, A, B, s, out = None):
    dtype = X.dtype
    W = fast_dequantize(W.t(), W_quant)

    if X.dim() == 3:
        batch, seq_len, d = X.shape
        X = X.view(-1, X.shape[-1])
        reshape = True
    else:
        reshape = False
    pass

    out = torch.matmul(X, W, out = out)
    if W_quant is not None: del W

    if A is not None:
        # LoRA is enabled
        A, B = A.t(), B.t()
        out += (X @ A.to(dtype)) @ (s * B.to(dtype))
    pass

    return out.view(batch, seq_len, -1) if reshape else out
pass
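Not part of the diff: a hedged sketch of matmul_lora called with an unquantised weight (W_quant = None), assuming the module is importable as unsloth.kernels.utils. With a bitsandbytes NF4 layer, W would instead be the packed 4-bit weight and W_quant its quant_state, so fast_dequantize would be exercised on the same path; the tensor layouts below mirror nn.Linear and PEFT LoRA conventions and are assumptions for illustration.

# Hedged sketch; import path, shapes and the scaling value are illustrative.
import torch
from unsloth.kernels.utils import matmul_lora  # assumed import path

bsz, seq_len, in_dim, out_dim, rank = 2, 16, 512, 1024, 8
dtype, device = torch.float16, "cuda"
X = torch.randn(bsz, seq_len, in_dim, dtype = dtype, device = device)
W = torch.randn(out_dim, in_dim, dtype = dtype, device = device) * 0.02  # nn.Linear layout
A = torch.randn(rank, in_dim,    dtype = dtype, device = device) * 0.02  # lora_A.weight
B = torch.zeros(out_dim, rank,   dtype = dtype, device = device)         # lora_B.weight
s = 2.0                                                                   # lora_alpha / r

out = matmul_lora(X, W, None, A, B, s)  # X @ W.T + s * (X @ A.T) @ B.T
print(out.shape)                        # torch.Size([2, 16, 1024])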
unsloth-main/unsloth/models/__init__.py
DELETED
@@ -1,22 +0,0 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from .granite import FastGraniteModel
from .loader import FastLanguageModel, FastVisionModel
from .llama import FastLlamaModel
from .mistral import FastMistralModel
from .qwen2 import FastQwen2Model
from .dpo import PatchDPOTrainer, PatchKTOTrainer
from ._utils import is_bfloat16_supported
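Not part of the diff: a brief sketch of the public entry point these re-exports provide, following the usage pattern shown in the repository README; the model name and hyper-parameter values here are illustrative only.

# Hedged sketch of the FastLanguageModel entry point re-exported above.
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name     = "unsloth/llama-3-8b-bnb-4bit",  # illustrative checkpoint
    max_seq_length = 2048,
    load_in_4bit   = True,
)
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 16,
)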