Upload 58 files
This view is limited to 50 files because it contains too many changes.
- unsloth-main/unsloth-main/.github/FUNDING.yml +13 -0
- unsloth-main/unsloth-main/LICENSE +201 -0
- unsloth-main/unsloth-main/README.md +455 -0
- unsloth-main/unsloth-main/images/Assistant.png +0 -0
- unsloth-main/unsloth-main/images/Colab.png +0 -0
- unsloth-main/unsloth-main/images/Discord button.png +0 -0
- unsloth-main/unsloth-main/images/Discord.png +0 -0
- unsloth-main/unsloth-main/images/Free version button.png +0 -0
- unsloth-main/unsloth-main/images/Kaggle.png +0 -0
- unsloth-main/unsloth-main/images/Kofi button.png +0 -0
- unsloth-main/unsloth-main/images/LAION 2GPU.png +0 -0
- unsloth-main/unsloth-main/images/Merge.png +0 -0
- unsloth-main/unsloth-main/images/Run.png +0 -0
- unsloth-main/unsloth-main/images/Slim Orca 2GPUs.png +0 -0
- unsloth-main/unsloth-main/images/Terminal_Type.png +0 -0
- unsloth-main/unsloth-main/images/Where_Terminal.png +0 -0
- unsloth-main/unsloth-main/images/buy me a coffee button.png +0 -0
- unsloth-main/unsloth-main/images/made with unsloth.png +0 -0
- unsloth-main/unsloth-main/images/ollama.png +0 -0
- unsloth-main/unsloth-main/images/peft x trl button.png +0 -0
- unsloth-main/unsloth-main/images/start free finetune button.png +0 -0
- unsloth-main/unsloth-main/images/unsloth end.png +0 -0
- unsloth-main/unsloth-main/images/unsloth loading page render.png +0 -0
- unsloth-main/unsloth-main/images/unsloth logo black text.png +0 -0
- unsloth-main/unsloth-main/images/unsloth logo only.png +0 -0
- unsloth-main/unsloth-main/images/unsloth logo white text.png +0 -0
- unsloth-main/unsloth-main/images/unsloth made with love.png +0 -0
- unsloth-main/unsloth-main/images/unsloth new logo.png +0 -0
- unsloth-main/unsloth-main/pyproject.toml +327 -0
- unsloth-main/unsloth-main/unsloth-cli.py +221 -0
- unsloth-main/unsloth-main/unsloth/__init__.py +161 -0
- unsloth-main/unsloth-main/unsloth/_auto_install.py +30 -0
- unsloth-main/unsloth-main/unsloth/chat_templates.py +2210 -0
- unsloth-main/unsloth-main/unsloth/kernels/__init__.py +61 -0
- unsloth-main/unsloth-main/unsloth/kernels/cross_entropy_loss.py +461 -0
- unsloth-main/unsloth-main/unsloth/kernels/fast_lora.py +412 -0
- unsloth-main/unsloth-main/unsloth/kernels/flex_attention.py +180 -0
- unsloth-main/unsloth-main/unsloth/kernels/geglu.py +203 -0
- unsloth-main/unsloth-main/unsloth/kernels/layernorm.py +231 -0
- unsloth-main/unsloth-main/unsloth/kernels/rms_layernorm.py +283 -0
- unsloth-main/unsloth-main/unsloth/kernels/rope_embedding.py +181 -0
- unsloth-main/unsloth-main/unsloth/kernels/swiglu.py +99 -0
- unsloth-main/unsloth-main/unsloth/kernels/utils.py +416 -0
- unsloth-main/unsloth-main/unsloth/models/__init__.py +20 -0
- unsloth-main/unsloth-main/unsloth/models/_utils.py +1140 -0
- unsloth-main/unsloth-main/unsloth/models/cohere.py +473 -0
- unsloth-main/unsloth-main/unsloth/models/dpo.py +130 -0
- unsloth-main/unsloth-main/unsloth/models/gemma.py +430 -0
- unsloth-main/unsloth-main/unsloth/models/gemma2.py +581 -0
- unsloth-main/unsloth-main/unsloth/models/llama.py +0 -0
unsloth-main/unsloth-main/.github/FUNDING.yml
ADDED
@@ -0,0 +1,13 @@
# These are supported funding model platforms

github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: unsloth
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
unsloth-main/unsloth-main/LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

   (a) You must give any other recipients of the Work or
       Derivative Works a copy of this License; and

   (b) You must cause any modified files to carry prominent notices
       stating that You changed the files; and

   (c) You must retain, in the Source form of any Derivative Works
       that You distribute, all copyright, patent, trademark, and
       attribution notices from the Source form of the Work,
       excluding those notices that do not pertain to any part of
       the Derivative Works; and

   (d) If the Work includes a "NOTICE" text file as part of its
       distribution, then any Derivative Works that You distribute must
       include a readable copy of the attribution notices contained
       within such NOTICE file, excluding those notices that do not
       pertain to any part of the Derivative Works, in at least one
       of the following places: within a NOTICE text file distributed
       as part of the Derivative Works; within the Source form or
       documentation, if provided along with the Derivative Works; or,
       within a display generated by the Derivative Works, if and
       wherever such third-party notices normally appear. The contents
       of the NOTICE file are for informational purposes only and
       do not modify the License. You may add Your own attribution
       notices within Derivative Works that You distribute, alongside
       or as an addendum to the NOTICE text from the Work, provided
       that such additional attribution notices cannot be construed
       as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [2024-] [Unsloth AI, Daniel Han-Chen & Michael Han-Chen]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
unsloth-main/unsloth-main/README.md
ADDED
@@ -0,0 +1,455 @@
<div align="center">

<a href="https://unsloth.ai"><picture>
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20white%20text.png">
<source media="(prefers-color-scheme: light)" srcset="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20black%20text.png">
<img alt="unsloth logo" src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/unsloth%20logo%20black%20text.png" height="110" style="max-width: 100%;">
</picture></a>

<a href="https://colab.research.google.com/drive/1Ys44kVvmeZtnICzWz0xgpRnrIOjZAuxp?usp=sharing"><img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/start free finetune button.png" height="48"></a>
<a href="https://discord.gg/unsloth"><img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/Discord button.png" height="48"></a>
<a href="https://ko-fi.com/unsloth"><img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/buy me a coffee button.png" height="48"></a>

### Finetune Llama 3.2, Mistral, Phi-3.5 & Gemma 2-5x faster with 80% less memory!

![](https://i.ibb.co/sJ7RhGG/image-41.png)

</div>

## ✨ Finetune for Free

All notebooks are **beginner friendly**! Add your dataset, click "Run All", and you'll get a 2x faster finetuned model which can be exported to GGUF, Ollama, vLLM or uploaded to Hugging Face.

| Unsloth supports | Free Notebooks | Performance | Memory use |
|-----------|---------|--------|----------|
| **Llama 3.2 (3B)** | [▶️ Start for free](https://colab.research.google.com/drive/1T5-zKWM_5OD21QHwXHiV9ixTRR7k3iB9?usp=sharing) | 2x faster | 60% less |
| **Llama 3.1 (8B)** | [▶️ Start for free](https://colab.research.google.com/drive/1Ys44kVvmeZtnICzWz0xgpRnrIOjZAuxp?usp=sharing) | 2x faster | 60% less |
| **Phi-3.5 (mini)** | [▶️ Start for free](https://colab.research.google.com/drive/1lN6hPQveB_mHSnTOYifygFcrO8C1bxq4?usp=sharing) | 2x faster | 50% less |
| **Gemma 2 (9B)** | [▶️ Start for free](https://colab.research.google.com/drive/1vIrqH5uYDQwsJ4-OO3DErvuv4pBgVwk4?usp=sharing) | 2x faster | 63% less |
| **Mistral Small (22B)** | [▶️ Start for free](https://colab.research.google.com/drive/1oCEHcED15DzL8xXGU1VTx5ZfOJM8WY01?usp=sharing) | 2x faster | 60% less |
| **Ollama** | [▶️ Start for free](https://colab.research.google.com/drive/1WZDi7APtQ9VsvOrQSSC5DDtxq159j8iZ?usp=sharing) | 1.9x faster | 43% less |
| **Mistral v0.3 (7B)** | [▶️ Start for free](https://colab.research.google.com/drive/1_yNCks4BTD5zOnjozppphh5GzMFaMKq_?usp=sharing) | 2.2x faster | 73% less |
| **ORPO** | [▶️ Start for free](https://colab.research.google.com/drive/11t4njE3c4Lxl-07OD8lJSMKkfyJml3Tn?usp=sharing) | 1.9x faster | 43% less |
| **DPO Zephyr** | [▶️ Start for free](https://colab.research.google.com/drive/15vttTpzzVXv_tJwEk-hIcQ0S9FcEWvwP?usp=sharing) | 1.9x faster | 43% less |

- **Kaggle Notebooks** for [Llama 3.1 (8B)](https://www.kaggle.com/danielhanchen/kaggle-llama-3-1-8b-unsloth-notebook), [Gemma 2 (9B)](https://www.kaggle.com/code/danielhanchen/kaggle-gemma-7b-unsloth-notebook/), [Mistral (7B)](https://www.kaggle.com/code/danielhanchen/kaggle-mistral-7b-unsloth-notebook)
- Run the [Llama 3.2 1B 3B notebook](https://colab.research.google.com/drive/1hoHFpf7ROqk_oZHzxQdfPW9yvTxnvItq?usp=sharing) and the [Llama 3.2 conversational notebook](https://colab.research.google.com/drive/1T5-zKWM_5OD21QHwXHiV9ixTRR7k3iB9?usp=sharing)
- Run the [Llama 3.1 conversational notebook](https://colab.research.google.com/drive/15OyFkGoCImV9dSsewU1wa2JuKB4-mDE_?usp=sharing) and [Mistral v0.3 ChatML](https://colab.research.google.com/drive/15F1xyn8497_dUbxZP4zWmPZ3PJx1Oymv?usp=sharing)
- This [text completion notebook](https://colab.research.google.com/drive/1ef-tab5bhkvWmBOObepl1WgJvfvSzn5Q?usp=sharing) is for continued pretraining on raw text
- This [continued pretraining notebook](https://colab.research.google.com/drive/1tEd1FrOXWMnCU9UIvdYhs61tkxdMuKZu?usp=sharing) is for learning another language
- Click [here](https://github.com/unslothai/unsloth/wiki) for detailed documentation for Unsloth.

## 🦥 Unsloth.ai News
- 📣 NEW! The [Llama 3.2 Conversational notebook](https://colab.research.google.com/drive/1T5-zKWM_5OD21QHwXHiV9ixTRR7k3iB9?usp=sharing) includes training only on completions / outputs (which increases accuracy), ShareGPT standardization and more!
- 📣 NEW! [Llama 3.2 Kaggle notebook](https://www.kaggle.com/danielhanchen/kaggle-llama-3-2-1b-3b-unsloth-notebook) and [Llama 3.2 Kaggle conversational notebook](https://www.kaggle.com/code/danielhanchen/kaggle-llama-3-2-1b-3b-conversational-unsloth/notebook)
- 📣 NEW! [Qwen 2.5 7b notebook](https://colab.research.google.com/drive/1Kose-ucXO1IBaZq5BvbwWieuubP7hxvQ?usp=sharing) finetuning is supported! Qwen 2.5 comes in multiple sizes - check our [4bit uploads](https://huggingface.co/unsloth) for 4x faster downloads! The 14b fits in a Colab GPU. See the [Qwen 2.5 conversational notebook](https://colab.research.google.com/drive/1qN1CEalC70EO1wGKhNxs1go1W9So61R5?usp=sharing)
- 📣 NEW! [Mistral Small 22b notebook](https://colab.research.google.com/drive/1oCEHcED15DzL8xXGU1VTx5ZfOJM8WY01?usp=sharing) finetuning fits in under 16GB of VRAM!
- 📣 NEW! [Phi-3.5 (mini)](https://colab.research.google.com/drive/1lN6hPQveB_mHSnTOYifygFcrO8C1bxq4?usp=sharing) now supported
- 📣 NEW! [Gemma-2-2b](https://colab.research.google.com/drive/1weTpKOjBZxZJ5PQ-Ql8i6ptAY2x-FWVA?usp=sharing) now supported! Try out the [Chat interface](https://colab.research.google.com/drive/1i-8ESvtLRGNkkUQQr_-z_rcSAIo9c3lM?usp=sharing)!
- 📣 NEW! [Llama 3.1 8b, 70b](https://colab.research.google.com/drive/1Ys44kVvmeZtnICzWz0xgpRnrIOjZAuxp?usp=sharing) & [Mistral Nemo-12b](https://colab.research.google.com/drive/17d3U-CAIwzmbDRqbZ9NnpHxCkmXB6LZ0?usp=sharing), both Base and Instruct, are now supported
<details>
<summary>Click for more news</summary>

- 📣 NEW! `pip install unsloth` now works! Head over to [pypi](https://pypi.org/project/unsloth/) to check it out! This allows installs without a git pull. Use `pip install unsloth[colab-new]` for installs without dependencies.
- 📣 NEW! [Gemma-2-9b](https://colab.research.google.com/drive/1vIrqH5uYDQwsJ4-OO3DErvuv4pBgVwk4?usp=sharing) and Gemma-2-27b now supported
- 📣 UPDATE! [Phi-3 mini](https://colab.research.google.com/drive/1hhdhBa1j_hsymiW9m-WzxQtgqTH_NHqi?usp=sharing) model updated. [Phi-3 Medium](https://colab.research.google.com/drive/1hhdhBa1j_hsymiW9m-WzxQtgqTH_NHqi?usp=sharing) 2x faster finetuning.
- 📣 NEW! Continued Pretraining [notebook](https://colab.research.google.com/drive/1tEd1FrOXWMnCU9UIvdYhs61tkxdMuKZu?usp=sharing) for other languages like Korean!
- 📣 NEW! Qwen2 now works
- 📣 [Mistral v0.3 Base](https://colab.research.google.com/drive/1_yNCks4BTD5zOnjozppphh5GzMFaMKq_?usp=sharing) and [Mistral v0.3 Instruct]
- 📣 [ORPO support](https://colab.research.google.com/drive/11t4njE3c4Lxl-07OD8lJSMKkfyJml3Tn?usp=sharing) is here + [2x faster inference](https://colab.research.google.com/drive/1aqlNQi7MMJbynFDyOQteD2t0yVfjb9Zh?usp=sharing) added for all our models
- 📣 We cut memory usage by a [further 30%](https://unsloth.ai/blog/long-context) and now support [4x longer context windows](https://unsloth.ai/blog/long-context)!
</details>

## 🔗 Links and Resources
| Type | Links |
| ------------------------------- | --------------------------------------- |
| 📚 **Documentation & Wiki** | [Read Our Docs](https://docs.unsloth.ai) |
| <img height="14" src="https://upload.wikimedia.org/wikipedia/commons/6/6f/Logo_of_Twitter.svg" /> **Twitter (aka X)** | [Follow us on X](https://twitter.com/unslothai)|
| 💾 **Installation** | [unsloth/README.md](https://github.com/unslothai/unsloth/tree/main#-installation-instructions)|
| 🥇 **Benchmarking** | [Performance Tables](https://github.com/unslothai/unsloth/tree/main#-performance-benchmarking)|
| 🌐 **Released Models** | [Unsloth Releases](https://huggingface.co/unsloth)|
| ✍️ **Blog** | [Read our Blogs](https://unsloth.ai/blog)|

## ⭐ Key Features
- All kernels written in [OpenAI's Triton](https://openai.com/research/triton) language. **Manual backprop engine**. (A generic Triton sketch follows this list.)
- **0% loss in accuracy** - no approximation methods - all exact.
- No change of hardware needed. Supports NVIDIA GPUs since 2018+. Minimum CUDA Capability 7.0 (V100, T4, Titan V, RTX 20, 30, 40x, A100, H100, L40 etc). [Check your GPU!](https://developer.nvidia.com/cuda-gpus) GTX 1070 and 1080 work, but are slow.
- Works on **Linux** and **Windows** via WSL.
- Supports 4bit and 16bit QLoRA / LoRA finetuning via [bitsandbytes](https://github.com/TimDettmers/bitsandbytes).
- Open source trains 5x faster - see [Unsloth Pro](https://unsloth.ai/) for up to **30x faster training**!
- If you trained a model with 🦥Unsloth, you can use this cool sticker! <img src="https://raw.githubusercontent.com/unslothai/unsloth/main/images/made with unsloth.png" height="50" align="center" />

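For orientation, those Triton kernels live under `unsloth/kernels/` (cross entropy loss, RoPE embeddings, RMS LayerNorm, SwiGLU and friends, per the file list above). The block below is only a generic sketch of what a Triton kernel looks like, following the standard vector-add tutorial pattern; it is not one of Unsloth's kernels.

```python
# Generic Triton sketch (standard vector-add pattern) -- not an Unsloth kernel.
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis = 0)                         # one program instance per block
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements                           # guard the ragged tail
    x = tl.load(x_ptr + offsets, mask = mask)
    y = tl.load(y_ptr + offsets, mask = mask)
    tl.store(out_ptr + offsets, x + y, mask = mask)

def add(x, y):
    out = torch.empty_like(x)
    n = out.numel()
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
    add_kernel[grid](x, y, out, n, BLOCK_SIZE = 1024)
    return out

# Usage (needs a CUDA GPU):
# c = add(torch.randn(4096, device = "cuda"), torch.randn(4096, device = "cuda"))
```
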
## 🥇 Performance Benchmarking
- For the full list of **reproducible** benchmarking tables, [go to our website](https://unsloth.ai/blog/mistral-benchmark#Benchmark%20tables)

| 1 A100 40GB | 🤗Hugging Face | Flash Attention | 🦥Unsloth Open Source | 🦥[Unsloth Pro](https://unsloth.ai/pricing) |
|--------------|--------------|-----------------|---------------------|-----------------|
| Alpaca | 1x | 1.04x | 1.98x | **15.64x** |
| LAION Chip2 | 1x | 0.92x | 1.61x | **20.73x** |
| OASST | 1x | 1.19x | 2.17x | **14.83x** |
| Slim Orca | 1x | 1.18x | 2.22x | **14.82x** |

- The benchmarking table below was produced by [🤗Hugging Face](https://huggingface.co/blog/unsloth-trl).

| Free Colab T4 | Dataset | 🤗Hugging Face | Pytorch 2.1.1 | 🦥Unsloth | 🦥 VRAM reduction |
| --- | --- | --- | --- | --- | --- |
| Llama-2 7b | OASST | 1x | 1.19x | 1.95x | -43.3% |
| Mistral 7b | Alpaca | 1x | 1.07x | 1.56x | -13.7% |
| Tiny Llama 1.1b | Alpaca | 1x | 2.06x | 3.87x | -73.8% |
| DPO with Zephyr | Ultra Chat | 1x | 1.09x | 1.55x | -18.6% |

![](https://i.ibb.co/sJ7RhGG/image-41.png)

## 💾 Installation Instructions

For stable releases, use `pip install unsloth`. For most installations, though, we recommend `pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"`.

### Conda Installation
`⚠️Only use Conda if you have it. If not, use Pip`. Select `pytorch-cuda=11.8` for CUDA 11.8 or `pytorch-cuda=12.1` for CUDA 12.1. We support `python=3.10,3.11,3.12`.
```bash
conda create --name unsloth_env \
    python=3.11 \
    pytorch-cuda=12.1 \
    pytorch cudatoolkit xformers -c pytorch -c nvidia -c xformers \
    -y
conda activate unsloth_env

pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
pip install --no-deps trl peft accelerate bitsandbytes
```

<details>
<summary>If you're looking to install Conda in a Linux environment, <a href="https://docs.anaconda.com/miniconda/">read here</a>, or run the below 🔽</summary>

```bash
mkdir -p ~/miniconda3
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm -rf ~/miniconda3/miniconda.sh
~/miniconda3/bin/conda init bash
~/miniconda3/bin/conda init zsh
```
</details>

### Pip Installation
`⚠️Do **NOT** use this if you have Conda.` Pip installation is a bit more complex because of dependency issues: the pip command differs for `torch 2.2,2.3,2.4,2.5` and for different CUDA versions.

For torch versions, we support `torch211`, `torch212`, `torch220`, `torch230` and `torch240`; for CUDA versions, we support `cu118` and `cu121`. For Ampere devices (A100, H100, RTX3090) and above, use `cu118-ampere` or `cu121-ampere`.

For example, if you have `torch 2.4` and `CUDA 12.1`, use:
```bash
pip install --upgrade pip
pip install "unsloth[cu121-torch240] @ git+https://github.com/unslothai/unsloth.git"
```

And other examples:
```bash
pip install "unsloth[cu121-ampere-torch240] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu118-ampere-torch240] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121-torch240] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu118-torch240] @ git+https://github.com/unslothai/unsloth.git"

pip install "unsloth[cu121-torch230] @ git+https://github.com/unslothai/unsloth.git"
pip install "unsloth[cu121-ampere-torch230] @ git+https://github.com/unslothai/unsloth.git"
```

Or, run the below in a terminal to get the **optimal** pip installation command:
```bash
wget -qO- https://raw.githubusercontent.com/unslothai/unsloth/main/unsloth/_auto_install.py | python -
```

Or, run the below manually in a Python REPL:
```python
try: import torch
except: raise ImportError("Install torch via `pip install torch`")
from packaging.version import Version as V
v = V(torch.__version__)
cuda = str(torch.version.cuda)
is_ampere = torch.cuda.get_device_capability()[0] >= 8
if cuda != "12.1" and cuda != "11.8": raise RuntimeError(f"CUDA = {cuda} not supported!")
if   v <= V('2.1.0'): raise RuntimeError(f"Torch = {v} too old!")
elif v <= V('2.1.1'): x = 'cu{}{}-torch211'
elif v <= V('2.1.2'): x = 'cu{}{}-torch212'
elif v  < V('2.3.0'): x = 'cu{}{}-torch220'
elif v  < V('2.4.0'): x = 'cu{}{}-torch230'
elif v  < V('2.5.0'): x = 'cu{}{}-torch240'
else: raise RuntimeError(f"Torch = {v} too new!")
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
print(f'pip install --upgrade pip && pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"')
```

For **advanced installation instructions**, or if you see weird errors during installation:

1. Install `torch` and `triton`. Go to https://pytorch.org to install them. For example `pip install torch torchvision torchaudio triton`
2. Confirm that CUDA is installed correctly. Try `nvcc`. If that fails, you need to install `cudatoolkit` or CUDA drivers.
3. Install `xformers` manually. You can try installing `vllm` and seeing if `vllm` succeeds. Check whether `xformers` succeeded with `python -m xformers.info`. Go to https://github.com/facebookresearch/xformers. Another option is to install `flash-attn` for Ampere GPUs.
4. Finally, install `bitsandbytes` and check it with `python -m bitsandbytes`. A combined sanity-check sketch follows below.

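If any of the steps above fail silently, the hedged sketch below (not an official Unsloth tool) checks the same pieces in one go: torch and its CUDA build, then whether `triton`, `xformers` and `bitsandbytes` are importable, then the two module self-checks the list already mentions.

```python
# Hedged sanity-check sketch for the four steps above -- adjust to your environment.
import importlib.util
import subprocess

try:
    import torch
    print("torch", torch.__version__, "| CUDA", torch.version.cuda,
          "| GPU available:", torch.cuda.is_available())            # steps 1 and 2
except ImportError:
    print("torch MISSING -- see https://pytorch.org")

for module in ("triton", "xformers", "bitsandbytes"):               # steps 1, 3, 4
    found = importlib.util.find_spec(module) is not None
    print(f"{module}: {'found' if found else 'MISSING'}")

# The self-checks named in steps 3 and 4, run as subprocesses:
subprocess.run(["python", "-m", "xformers.info"], check = False)
subprocess.run(["python", "-m", "bitsandbytes"], check = False)
```
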
## 📜 [Documentation](https://docs.unsloth.ai)
- Go to our official [Documentation](https://docs.unsloth.ai) for saving to GGUF, checkpointing, evaluation and more!
- We support Huggingface's TRL, Trainer, Seq2SeqTrainer or even plain Pytorch code!
- We're in 🤗Hugging Face's official docs! Check out the [SFT docs](https://huggingface.co/docs/trl/main/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth) and [DPO docs](https://huggingface.co/docs/trl/main/en/dpo_trainer#accelerate-dpo-fine-tuning-using-unsloth)!

```python
from unsloth import FastLanguageModel
from unsloth import is_bfloat16_supported
import torch
from trl import SFTTrainer
from transformers import TrainingArguments
from datasets import load_dataset
max_seq_length = 2048 # Supports RoPE Scaling internally, so choose any!
# Get LAION dataset
url = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
dataset = load_dataset("json", data_files = {"train" : url}, split = "train")

# 4bit pre quantized models we support for 4x faster downloading + no OOMs.
fourbit_models = [
    "unsloth/mistral-7b-v0.3-bnb-4bit",      # New Mistral v3 2x faster!
    "unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
    "unsloth/llama-3-8b-bnb-4bit",           # Llama-3 15 trillion tokens model 2x faster!
    "unsloth/llama-3-8b-Instruct-bnb-4bit",
    "unsloth/llama-3-70b-bnb-4bit",
    "unsloth/Phi-3-mini-4k-instruct",        # Phi-3 2x faster!
    "unsloth/Phi-3-medium-4k-instruct",
    "unsloth/mistral-7b-bnb-4bit",
    "unsloth/gemma-7b-bnb-4bit",             # Gemma 2.2x faster!
] # More models at https://huggingface.co/unsloth

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/llama-3-8b-bnb-4bit",
    max_seq_length = max_seq_length,
    dtype = None,
    load_in_4bit = True,
)

# Do model patching and add fast LoRA weights
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    max_seq_length = max_seq_length,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

trainer = SFTTrainer(
    model = model,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    tokenizer = tokenizer,
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 10,
        max_steps = 60,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        output_dir = "outputs",
        optim = "adamw_8bit",
        seed = 3407,
    ),
)
trainer.train()

# Go to https://github.com/unslothai/unsloth/wiki for advanced tips like
# (1) Saving to GGUF / merging to 16bit for vLLM
# (2) Continued training from a saved LoRA adapter
# (3) Adding an evaluation loop / OOMs
# (4) Customized chat templates
```

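The wiki above covers GGUF export and 16bit merging; as a minimal sketch using the standard 🤗 `save_pretrained` API (the directory name `"lora_model"` is just an example), the trained LoRA adapter from the snippet above can be saved and later reloaded like this:

```python
# Minimal sketch: persist the LoRA adapter and tokenizer after trainer.train().
model.save_pretrained("lora_model")          # saves only the small LoRA adapter weights
tokenizer.save_pretrained("lora_model")

# Later, reload the adapter for inference (assumes the same Unsloth environment):
from unsloth import FastLanguageModel
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "lora_model",
    max_seq_length = 2048,
    dtype = None,
    load_in_4bit = True,
)
FastLanguageModel.for_inference(model)       # switch on Unsloth's faster inference path
```
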
<a name="DPO"></a>
## DPO Support
DPO (Direct Preference Optimization), PPO and reward modelling all seem to work, as per independent third-party testing from [Llama-Factory](https://github.com/hiyouga/LLaMA-Factory). We have a preliminary Google Colab notebook for reproducing Zephyr on a Tesla T4 here: [notebook](https://colab.research.google.com/drive/15vttTpzzVXv_tJwEk-hIcQ0S9FcEWvwP?usp=sharing).

We're in 🤗Hugging Face's official docs! We're on the [SFT docs](https://huggingface.co/docs/trl/main/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth) and the [DPO docs](https://huggingface.co/docs/trl/main/en/dpo_trainer#accelerate-dpo-fine-tuning-using-unsloth)!

```python
from unsloth import FastLanguageModel, PatchDPOTrainer
from unsloth import is_bfloat16_supported
PatchDPOTrainer()
import torch
from transformers import TrainingArguments
from trl import DPOTrainer

max_seq_length = 2048 # Supports RoPE Scaling internally, so choose any!

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/zephyr-sft-bnb-4bit",
    max_seq_length = max_seq_length,
    dtype = None,
    load_in_4bit = True,
)

# Do model patching and add fast LoRA weights
model = FastLanguageModel.get_peft_model(
    model,
    r = 64,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 64,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    max_seq_length = max_seq_length,
)

dpo_trainer = DPOTrainer(
    model = model,
    ref_model = None,
    args = TrainingArguments(
        per_device_train_batch_size = 4,
        gradient_accumulation_steps = 8,
        warmup_ratio = 0.1,
        num_train_epochs = 3,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        seed = 42,
        output_dir = "outputs",
    ),
    beta = 0.1,
    train_dataset = YOUR_DATASET_HERE,
    # eval_dataset = YOUR_DATASET_HERE,
    tokenizer = tokenizer,
    max_length = 1024,
    max_prompt_length = 512,
)
dpo_trainer.train()
```

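`YOUR_DATASET_HERE` above is a placeholder: TRL's `DPOTrainer` expects a dataset with `prompt`, `chosen` and `rejected` columns. A minimal hand-rolled example (the texts are made up) looks like this:

```python
# Minimal sketch of a preference dataset in the format DPOTrainer expects.
from datasets import Dataset

preference_data = Dataset.from_dict({
    "prompt"   : ["What is 2 + 2?"],
    "chosen"   : ["2 + 2 equals 4."],        # preferred response
    "rejected" : ["2 + 2 equals 5."],        # dispreferred response
})
# Pass it as `train_dataset = preference_data` in the DPOTrainer call above.
```
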
## 🥇 Detailed Benchmarking Tables
- Click "Code" for fully reproducible examples
- "Unsloth Equal" is a preview of our PRO version, with code stripped out. All settings and the loss curve remain identical.
- For the full list of benchmarking tables, [go to our website](https://unsloth.ai/blog/mistral-benchmark#Benchmark%20tables)

| 1 A100 40GB | 🤗Hugging Face | Flash Attention 2 | 🦥Unsloth Open | Unsloth Equal | Unsloth Pro | Unsloth Max |
|--------------|-------------|-------------|-----------------|--------------|---------------|-------------|
| Alpaca | 1x | 1.04x | 1.98x | 2.48x | 5.32x | **15.64x** |
| code | [Code](https://colab.research.google.com/drive/1u4dBeM-0vGNVmmO6X7cScAut-Hyt4KDF?usp=sharing) | [Code](https://colab.research.google.com/drive/1fgTOxpMbVjloQBvZyz4lF4BacKSZOB2A?usp=sharing) | [Code](https://colab.research.google.com/drive/1YIPY_18xm-K0iJDgvNkRoJsgkPMPAO3G?usp=sharing) | [Code](https://colab.research.google.com/drive/1ANW8EFL3LVyTD7Gq4TkheC1Z7Rxw-rHp?usp=sharing) | | |
| seconds | 1040 | 1001 | 525 | 419 | 196 | 67 |
| memory MB | 18235 | 15365 | 9631 | 8525 | | |
| % saved | | 15.74 | 47.18 | 53.25 | | |

### Llama-Factory 3rd party benchmarking
- [Link to performance table.](https://github.com/hiyouga/LLaMA-Factory/wiki/Performance-Comparison) TGS: tokens per GPU per second. Model: LLaMA2-7B. GPU: NVIDIA A100 * 1. Batch size: 4. Gradient accumulation: 2. LoRA rank: 8. Max length: 1024.

| Method | Bits | TGS | GRAM | Speed |
| --- | --- | --- | --- | --- |
| HF | 16 | 2392 | 18GB | 100% |
| HF+FA2 | 16 | 2954 | 17GB | 123% |
| Unsloth+FA2 | 16 | 4007 | 16GB | **168%** |
| HF | 4 | 2415 | 9GB | 101% |
| Unsloth+FA2 | 4 | 3726 | 7GB | **160%** |

### Performance comparisons between popular models
<details>
<summary>Click for specific model benchmarking tables (Mistral 7b, CodeLlama 34b etc.)</summary>

### Mistral 7b
| 1 A100 40GB | Hugging Face | Flash Attention 2 | Unsloth Open | Unsloth Equal | Unsloth Pro | Unsloth Max |
|--------------|-------------|-------------|-----------------|--------------|---------------|-------------|
| Mistral 7B Slim Orca | 1x | 1.15x | 2.15x | 2.53x | 4.61x | **13.69x** |
| code | [Code](https://colab.research.google.com/drive/1mePk3KzwTD81hr5mcNcs_AX3Kbg_Ha0x?usp=sharing) | [Code](https://colab.research.google.com/drive/1dgHxjvTmX6hb0bPcLp26RXSE6_n9DKj7?usp=sharing) | [Code](https://colab.research.google.com/drive/1SKrKGV-BZoU4kv5q3g0jtE_OhRgPtrrQ?usp=sharing) | [Code](https://colab.research.google.com/drive/18yOiyX0T81mTwZqOALFSCX_tSAqju6aD?usp=sharing) | | |
| seconds | 1813 | 1571 | 842 | 718 | 393 | 132 |
| memory MB | 32853 | 19385 | 12465 | 10271 | | |
| % saved | | 40.99 | 62.06 | 68.74 | | |

### CodeLlama 34b
| 1 A100 40GB | Hugging Face | Flash Attention 2 | Unsloth Open | Unsloth Equal | Unsloth Pro | Unsloth Max |
|--------------|-------------|-------------|-----------------|--------------|---------------|-------------|
| Code Llama 34B | OOM ❌ | 0.99x | 1.87x | 2.61x | 4.27x | 12.82x |
| code | [▶️ Code](https://colab.research.google.com/drive/1ykfz3BqrtC_AUFegCzUQjjfUNlxp6Otc?usp=sharing) | [Code](https://colab.research.google.com/drive/12ZypxQh7OC6kBXvWZI-5d05I4m-B_hoR?usp=sharing) | [Code](https://colab.research.google.com/drive/1gdHyAx8XJsz2yNV-DHvbHjR1iCef5Qmh?usp=sharing) | [Code](https://colab.research.google.com/drive/1fm7wqx9MJ0kRrwKOfmLkK1Rmw-pySahB?usp=sharing) | | |
| seconds | 1953 | 1982 | 1043 | 748 | 458 | 152 |
| memory MB | 40000 | 33217 | 27413 | 22161 | | |
| % saved | | 16.96 | 31.47 | 44.60 | | |

### 1 Tesla T4

| 1 T4 16GB | Hugging Face | Flash Attention | Unsloth Open | Unsloth Pro Equal | Unsloth Pro | Unsloth Max |
|--------------|-------------|-----------------|-----------------|---------------|---------------|-------------|
| Alpaca | 1x | 1.09x | 1.69x | 1.79x | 2.93x | **8.3x** |
| code | [▶️ Code](https://colab.research.google.com/drive/1XpLIV4s8Bj5uryB-X2gqM88oRGHEGdaB?usp=sharing) | [Code](https://colab.research.google.com/drive/1LyXu6CjuymQg6ddHX8g1dpUvrMa1nn4L?usp=sharing) | [Code](https://colab.research.google.com/drive/1gsv4LpY7C32otl1rgRo5wXTk4HIitXoM?usp=sharing) | [Code](https://colab.research.google.com/drive/1VtULwRQwhEnVdNryjm27zXfdSM1tNfFK?usp=sharing) | | |
| seconds | 1599 | 1468 | 942 | 894 | 545 | 193 |
| memory MB | 7199 | 7059 | 6459 | 5443 | | |
| % saved | | 1.94 | 10.28 | 24.39 | | |

### 2 Tesla T4s via DDP

| 2 T4 DDP | Hugging Face | Flash Attention | Unsloth Open | Unsloth Equal | Unsloth Pro | Unsloth Max |
|--------------|----------|-------------|-----------------|--------------|---------------|-------------|
| Alpaca | 1x | 0.99x | 4.95x | 4.44x | 7.28x | **20.61x** |
| code | [▶️ Code](https://www.kaggle.com/danielhanchen/hf-original-alpaca-t4-ddp) | [Code](https://www.kaggle.com/danielhanchen/hf-sdpa-alpaca-t4-ddp) | [Code](https://www.kaggle.com/danielhanchen/unsloth-alpaca-t4-ddp) | | | |
| seconds | 9882 | 9946 | 1996 | 2227 | 1357 | 480 |
| memory MB | 9176 | 9128 | 6904 | 6782 | | |
| % saved | | 0.52 | 24.76 | 26.09 | | |
</details>

### Performance comparisons on 1 Tesla T4 GPU:
<details>
<summary>Click for Time taken for 1 epoch</summary>

One Tesla T4 on Google Colab
`bsz = 2, ga = 4, max_grad_norm = 0.3, num_train_epochs = 1, seed = 3047, lr = 2e-4, wd = 0.01, optim = "adamw_8bit", schedule = "linear", schedule_steps = 10`

| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) |
| --- | --- | --- | --- | --- | --- |
| Huggingface | 1 T4 | 23h 15m | 56h 28m | 8h 38m | 391h 41m |
| Unsloth Open | 1 T4 | 13h 7m (1.8x) | 31h 47m (1.8x) | 4h 27m (1.9x) | 240h 4m (1.6x) |
| Unsloth Pro | 1 T4 | 3h 6m (7.5x) | 5h 17m (10.7x) | 1h 7m (7.7x) | 59h 53m (6.5x) |
| Unsloth Max | 1 T4 | 2h 39m (8.8x) | 4h 31m (12.5x) | 0h 58m (8.9x) | 51h 30m (7.6x) |

**Peak Memory Usage**

| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) |
| --- | --- | --- | --- | --- | --- |
| Huggingface | 1 T4 | 7.3GB | 5.9GB | 14.0GB | 13.3GB |
| Unsloth Open | 1 T4 | 6.8GB | 5.7GB | 7.8GB | 7.7GB |
| Unsloth Pro | 1 T4 | 6.4GB | 6.4GB | 6.4GB | 6.4GB |
| Unsloth Max | 1 T4 | 11.4GB | 12.4GB | 11.9GB | 14.4GB |
</details>

<details>
<summary>Click for Performance Comparisons on 2 Tesla T4 GPUs via DDP:</summary>
**Time taken for 1 epoch**

Two Tesla T4s on Kaggle
`bsz = 2, ga = 4, max_grad_norm = 0.3, num_train_epochs = 1, seed = 3047, lr = 2e-4, wd = 0.01, optim = "adamw_8bit", schedule = "linear", schedule_steps = 10`

| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) * |
| --- | --- | --- | --- | --- | --- |
| Huggingface | 2 T4 | 84h 47m | 163h 48m | 30h 51m | 1301h 24m * |
| Unsloth Pro | 2 T4 | 3h 20m (25.4x) | 5h 43m (28.7x) | 1h 12m (25.7x) | 71h 40m (18.1x) * |
| Unsloth Max | 2 T4 | 3h 4m (27.6x) | 5h 14m (31.3x) | 1h 6m (28.1x) | 54h 20m (23.9x) * |

**Peak Memory Usage on a Multi GPU System (2 GPUs)**

| System | GPU | Alpaca (52K) | LAION OIG (210K) | Open Assistant (10K) | SlimOrca (518K) * |
| --- | --- | --- | --- | --- | --- |
| Huggingface | 2 T4 | 8.4GB \| 6GB | 7.2GB \| 5.3GB | 14.3GB \| 6.6GB | 10.9GB \| 5.9GB * |
| Unsloth Pro | 2 T4 | 7.7GB \| 4.9GB | 7.5GB \| 4.9GB | 8.5GB \| 4.9GB | 6.2GB \| 4.7GB * |
| Unsloth Max | 2 T4 | 10.5GB \| 5GB | 10.6GB \| 5GB | 10.6GB \| 5GB | 10.5GB \| 5GB * |

* Slim Orca uses `bsz=1` for all benchmarks since `bsz=2` OOMs. We can handle `bsz=2`, but we benchmark it with `bsz=1` for consistency.
</details>

![](https://i.ibb.co/sJ7RhGG/image-41.png)
<br>

### Thank You to
- [HuyNguyen-hust](https://github.com/HuyNguyen-hust) for making [RoPE Embeddings 28% faster](https://github.com/unslothai/unsloth/pull/238)
- [RandomInternetPreson](https://github.com/RandomInternetPreson) for confirming WSL support
- [152334H](https://github.com/152334H) for experimental DPO support
- [atgctg](https://github.com/atgctg) for syntax highlighting
unsloth-main/unsloth-main/images/Assistant.png
ADDED
unsloth-main/unsloth-main/images/Colab.png
ADDED
unsloth-main/unsloth-main/images/Discord button.png
ADDED
unsloth-main/unsloth-main/images/Discord.png
ADDED
unsloth-main/unsloth-main/images/Free version button.png
ADDED
unsloth-main/unsloth-main/images/Kaggle.png
ADDED
unsloth-main/unsloth-main/images/Kofi button.png
ADDED
unsloth-main/unsloth-main/images/LAION 2GPU.png
ADDED
unsloth-main/unsloth-main/images/Merge.png
ADDED
unsloth-main/unsloth-main/images/Run.png
ADDED
unsloth-main/unsloth-main/images/Slim Orca 2GPUs.png
ADDED
unsloth-main/unsloth-main/images/Terminal_Type.png
ADDED
unsloth-main/unsloth-main/images/Where_Terminal.png
ADDED
unsloth-main/unsloth-main/images/buy me a coffee button.png
ADDED
unsloth-main/unsloth-main/images/made with unsloth.png
ADDED
unsloth-main/unsloth-main/images/ollama.png
ADDED
unsloth-main/unsloth-main/images/peft x trl button.png
ADDED
unsloth-main/unsloth-main/images/start free finetune button.png
ADDED
unsloth-main/unsloth-main/images/unsloth end.png
ADDED
unsloth-main/unsloth-main/images/unsloth loading page render.png
ADDED
unsloth-main/unsloth-main/images/unsloth logo black text.png
ADDED
unsloth-main/unsloth-main/images/unsloth logo only.png
ADDED
unsloth-main/unsloth-main/images/unsloth logo white text.png
ADDED
unsloth-main/unsloth-main/images/unsloth made with love.png
ADDED
unsloth-main/unsloth-main/images/unsloth new logo.png
ADDED
unsloth-main/unsloth-main/pyproject.toml
ADDED
@@ -0,0 +1,327 @@
[build-system]
requires = ["setuptools", "setuptools-scm"]
build-backend = "setuptools.build_meta"

[project]
name = "unsloth"
dynamic = ["version"]
description = "2-5X faster LLM finetuning"
readme = "README.md"
requires-python = ">=3.9"
license = {file = "LICENSE"}
keywords = ["ai", "llm",]
authors = [
    {email = "[email protected]"},
    {name = "Unsloth AI team"},
]
maintainers = [
    {name = "Daniel Han", email = "[email protected]"},
    {name = "Michael Han", email = "[email protected]"},
]
classifiers = [
    "Programming Language :: Python",
]

[tool.setuptools.dynamic]
version = {attr = "unsloth.models._utils.__version__"}

[tool.setuptools]
include-package-data = false

[tool.setuptools.packages.find]
exclude = ["images*"]

[project.optional-dependencies]
huggingface = [
    "packaging",
    "tyro",
    "transformers>=4.44.2",
    "datasets>=2.16.0",
    "sentencepiece>=0.2.0",
    "tqdm",
    "psutil",
    "wheel>=0.42.0",
    "numpy",
    "accelerate>=0.34.1",
    "trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,<=0.11.1",
    "peft>=0.7.1,!=0.11.0",
    "protobuf<4.0.0",
    "huggingface_hub",
    "hf_transfer",
]
cu118only = [
    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.22.post7%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
]
cu121only = [
    "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
    "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
    "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
]
cu118onlytorch211 = [
    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
]
cu121onlytorch211 = [
    "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
    "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
    "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
]
cu118onlytorch212 = [
    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.23.post1%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
]
cu121onlytorch212 = [
    "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23.post1-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
    "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23.post1-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
    "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.23.post1-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
]
cu118onlytorch220 = [
    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.24%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
]
cu121onlytorch220 = [
    "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
    "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
    "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.24-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
]
cu118onlytorch230 = [
    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27%2Bcu118-cp312-cp312-manylinux2014_x86_64.whl ; python_version=='3.12'",
]
cu121onlytorch230 = [
    "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
    "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
    "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
    "xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27-cp312-cp312-manylinux2014_x86_64.whl ; python_version=='3.12'",
]
cu118onlytorch240 = [
    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
    "xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
|
108 |
+
"xformers @ https://download.pytorch.org/whl/cu118/xformers-0.0.27.post2%2Bcu118-cp312-cp312-manylinux2014_x86_64.whl ; python_version=='3.12'",
|
109 |
+
]
|
110 |
+
cu121onlytorch240 = [
|
111 |
+
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27.post2-cp39-cp39-manylinux2014_x86_64.whl ; python_version=='3.9'",
|
112 |
+
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27.post2-cp310-cp310-manylinux2014_x86_64.whl ; python_version=='3.10'",
|
113 |
+
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27.post2-cp311-cp311-manylinux2014_x86_64.whl ; python_version=='3.11'",
|
114 |
+
"xformers @ https://download.pytorch.org/whl/cu121/xformers-0.0.27.post2-cp312-cp312-manylinux2014_x86_64.whl ; python_version=='3.12'",
|
115 |
+
]
|
116 |
+
cu118 = [
|
117 |
+
"unsloth[huggingface]",
|
118 |
+
"bitsandbytes>=0.43.3",
|
119 |
+
"unsloth[cu118only]",
|
120 |
+
]
|
121 |
+
cu121 = [
|
122 |
+
"unsloth[huggingface]",
|
123 |
+
"bitsandbytes>=0.43.3",
|
124 |
+
"unsloth[cu121only]",
|
125 |
+
]
|
126 |
+
cu118-torch211 = [
|
127 |
+
"unsloth[huggingface]",
|
128 |
+
"bitsandbytes>=0.43.3",
|
129 |
+
"unsloth[cu118onlytorch211]",
|
130 |
+
]
|
131 |
+
cu121-torch211 = [
|
132 |
+
"unsloth[huggingface]",
|
133 |
+
"bitsandbytes>=0.43.3",
|
134 |
+
"unsloth[cu121onlytorch211]",
|
135 |
+
]
|
136 |
+
cu118-torch212 = [
|
137 |
+
"unsloth[huggingface]",
|
138 |
+
"bitsandbytes>=0.43.3",
|
139 |
+
"unsloth[cu118onlytorch212]",
|
140 |
+
]
|
141 |
+
cu121-torch212 = [
|
142 |
+
"unsloth[huggingface]",
|
143 |
+
"bitsandbytes>=0.43.3",
|
144 |
+
"unsloth[cu121onlytorch212]",
|
145 |
+
]
|
146 |
+
cu118-torch220 = [
|
147 |
+
"unsloth[huggingface]",
|
148 |
+
"bitsandbytes>=0.43.3",
|
149 |
+
"unsloth[cu118onlytorch220]",
|
150 |
+
]
|
151 |
+
cu121-torch220 = [
|
152 |
+
"unsloth[huggingface]",
|
153 |
+
"bitsandbytes>=0.43.3",
|
154 |
+
"unsloth[cu121onlytorch220]",
|
155 |
+
]
|
156 |
+
cu118-torch230 = [
|
157 |
+
"unsloth[huggingface]",
|
158 |
+
"bitsandbytes>=0.43.3",
|
159 |
+
"unsloth[cu118onlytorch230]",
|
160 |
+
]
|
161 |
+
cu121-torch230 = [
|
162 |
+
"unsloth[huggingface]",
|
163 |
+
"bitsandbytes>=0.43.3",
|
164 |
+
"unsloth[cu121onlytorch230]",
|
165 |
+
]
|
166 |
+
cu118-torch240 = [
|
167 |
+
"unsloth[huggingface]",
|
168 |
+
"bitsandbytes>=0.43.3",
|
169 |
+
"unsloth[cu118onlytorch240]",
|
170 |
+
]
|
171 |
+
cu121-torch240 = [
|
172 |
+
"unsloth[huggingface]",
|
173 |
+
"bitsandbytes>=0.43.3",
|
174 |
+
"unsloth[cu121onlytorch240]",
|
175 |
+
]
|
176 |
+
kaggle = [
|
177 |
+
"unsloth[huggingface]",
|
178 |
+
]
|
179 |
+
kaggle-new = [
|
180 |
+
"unsloth[huggingface]",
|
181 |
+
"bitsandbytes>=0.43.3",
|
182 |
+
]
|
183 |
+
conda = [
|
184 |
+
"unsloth[huggingface]",
|
185 |
+
]
|
186 |
+
colab-torch211 = [
|
187 |
+
"unsloth[huggingface]",
|
188 |
+
"bitsandbytes>=0.43.3",
|
189 |
+
"unsloth[cu121onlytorch211]",
|
190 |
+
]
|
191 |
+
colab-ampere-torch211 = [
|
192 |
+
"unsloth[huggingface]",
|
193 |
+
"bitsandbytes>=0.43.3",
|
194 |
+
"unsloth[cu121onlytorch211]",
|
195 |
+
"packaging",
|
196 |
+
"ninja",
|
197 |
+
"flash-attn>=2.6.3",
|
198 |
+
]
|
199 |
+
colab-torch220 = [
|
200 |
+
"unsloth[huggingface]",
|
201 |
+
"bitsandbytes>=0.43.3",
|
202 |
+
"unsloth[cu121onlytorch220]",
|
203 |
+
]
|
204 |
+
colab-ampere-torch220 = [
|
205 |
+
"unsloth[huggingface]",
|
206 |
+
"bitsandbytes>=0.43.3",
|
207 |
+
"unsloth[cu121onlytorch220]",
|
208 |
+
"packaging",
|
209 |
+
"ninja",
|
210 |
+
"flash-attn>=2.6.3",
|
211 |
+
]
|
212 |
+
colab-new = [
|
213 |
+
"packaging",
|
214 |
+
"tyro",
|
215 |
+
"transformers>=4.44.2",
|
216 |
+
"datasets>=2.16.0",
|
217 |
+
"sentencepiece>=0.2.0",
|
218 |
+
"tqdm",
|
219 |
+
"psutil",
|
220 |
+
"wheel>=0.42.0",
|
221 |
+
"numpy",
|
222 |
+
"protobuf<4.0.0",
|
223 |
+
"huggingface_hub",
|
224 |
+
"hf_transfer",
|
225 |
+
]
|
226 |
+
colab-no-deps = [
|
227 |
+
"accelerate>=0.34.1",
|
228 |
+
"trl>=0.7.9,!=0.9.0,!=0.9.1,!=0.9.2,!=0.9.3,<=0.11.1",
|
229 |
+
"peft>=0.7.1",
|
230 |
+
"xformers<0.0.27",
|
231 |
+
"bitsandbytes>=0.43.3",
|
232 |
+
"protobuf<4.0.0",
|
233 |
+
]
|
234 |
+
colab = [
|
235 |
+
"unsloth[cu121]",
|
236 |
+
]
|
237 |
+
colab-ampere = [
|
238 |
+
"unsloth[colab-ampere-torch220]",
|
239 |
+
"packaging",
|
240 |
+
"ninja",
|
241 |
+
"flash-attn>=2.6.3",
|
242 |
+
]
|
243 |
+
cu118-ampere = [
|
244 |
+
"unsloth[huggingface]",
|
245 |
+
"bitsandbytes>=0.43.3",
|
246 |
+
"unsloth[cu118only]",
|
247 |
+
"packaging",
|
248 |
+
"ninja",
|
249 |
+
"flash-attn>=2.6.3",
|
250 |
+
]
|
251 |
+
cu121-ampere = [
|
252 |
+
"unsloth[huggingface]",
|
253 |
+
"bitsandbytes>=0.43.3",
|
254 |
+
"unsloth[cu121only]",
|
255 |
+
"packaging",
|
256 |
+
"ninja",
|
257 |
+
"flash-attn>=2.6.3",
|
258 |
+
]
|
259 |
+
cu118-ampere-torch211 = [
|
260 |
+
"unsloth[huggingface]",
|
261 |
+
"bitsandbytes>=0.43.3",
|
262 |
+
"unsloth[cu118onlytorch211]",
|
263 |
+
"packaging",
|
264 |
+
"ninja",
|
265 |
+
"flash-attn>=2.6.3",
|
266 |
+
]
|
267 |
+
cu121-ampere-torch211 = [
|
268 |
+
"unsloth[huggingface]",
|
269 |
+
"bitsandbytes>=0.43.3",
|
270 |
+
"unsloth[cu121onlytorch211]",
|
271 |
+
"packaging",
|
272 |
+
"ninja",
|
273 |
+
"flash-attn>=2.6.3",
|
274 |
+
]
|
275 |
+
cu118-ampere-torch220 = [
|
276 |
+
"unsloth[huggingface]",
|
277 |
+
"bitsandbytes>=0.43.3",
|
278 |
+
"unsloth[cu118onlytorch220]",
|
279 |
+
"packaging",
|
280 |
+
"ninja",
|
281 |
+
"flash-attn>=2.6.3",
|
282 |
+
]
|
283 |
+
cu121-ampere-torch220 = [
|
284 |
+
"unsloth[huggingface]",
|
285 |
+
"bitsandbytes>=0.43.3",
|
286 |
+
"unsloth[cu121onlytorch220]",
|
287 |
+
"packaging",
|
288 |
+
"ninja",
|
289 |
+
"flash-attn>=2.6.3",
|
290 |
+
]
|
291 |
+
cu118-ampere-torch230 = [
|
292 |
+
"unsloth[huggingface]",
|
293 |
+
"bitsandbytes>=0.43.3",
|
294 |
+
"unsloth[cu118onlytorch230]",
|
295 |
+
"packaging",
|
296 |
+
"ninja",
|
297 |
+
"flash-attn>=2.6.3",
|
298 |
+
]
|
299 |
+
cu121-ampere-torch230 = [
|
300 |
+
"unsloth[huggingface]",
|
301 |
+
"bitsandbytes>=0.43.3",
|
302 |
+
"unsloth[cu121onlytorch230]",
|
303 |
+
"packaging",
|
304 |
+
"ninja",
|
305 |
+
"flash-attn>=2.6.3",
|
306 |
+
]
|
307 |
+
cu118-ampere-torch240 = [
|
308 |
+
"unsloth[huggingface]",
|
309 |
+
"bitsandbytes>=0.43.3",
|
310 |
+
"unsloth[cu118onlytorch240]",
|
311 |
+
"packaging",
|
312 |
+
"ninja",
|
313 |
+
"flash-attn>=2.6.3",
|
314 |
+
]
|
315 |
+
cu121-ampere-torch240 = [
|
316 |
+
"unsloth[huggingface]",
|
317 |
+
"bitsandbytes>=0.43.3",
|
318 |
+
"unsloth[cu121onlytorch240]",
|
319 |
+
"packaging",
|
320 |
+
"ninja",
|
321 |
+
"flash-attn>=2.6.3",
|
322 |
+
]
|
323 |
+
|
324 |
+
[project.urls]
|
325 |
+
homepage = "http://www.unsloth.ai"
|
326 |
+
documentation = "https://github.com/unslothai/unsloth"
|
327 |
+
repository = "https://github.com/unslothai/unsloth"
|
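Note on the extras above: each `cuXXX-torchYYY` group simply composes `unsloth[huggingface]`, `bitsandbytes`, and the pinned `xformers` wheel for that CUDA/torch pair, and the `-ampere` variants add `packaging`, `ninja` and `flash-attn`. As a hedged illustration (the chosen tag is only an example), this is what one of them expands to at the first level, together with the install-command format printed by `unsloth/_auto_install.py` later in this commit:

```python
# The tag below is just an example; any of the extras defined above compose the same way.
cu121_ampere_torch240 = [
    "unsloth[huggingface]",         # transformers, datasets, trl, peft, ...
    "bitsandbytes>=0.43.3",         # 4-bit quantization backend
    "unsloth[cu121onlytorch240]",   # pinned xformers wheel for CUDA 12.1 / torch 2.4
    "packaging", "ninja",           # build requirements for flash-attn
    "flash-attn>=2.6.3",            # Ampere-and-newer attention kernels
]
# Install command format printed by unsloth/_auto_install.py for this environment:
#   pip install "unsloth[cu121-ampere-torch240] @ git+https://github.com/unslothai/unsloth.git"
```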
unsloth-main/unsloth-main/unsloth-cli.py
ADDED
@@ -0,0 +1,221 @@
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
"""
|
4 |
+
🦥 Starter Script for Fine-Tuning FastLanguageModel with Unsloth
|
5 |
+
|
6 |
+
This script is designed as a starting point for fine-tuning your models using unsloth.
|
7 |
+
It includes configurable options for model loading, PEFT parameters, training arguments,
|
8 |
+
and model saving/pushing functionalities.
|
9 |
+
|
10 |
+
You will likely want to customize this script to suit your specific use case
|
11 |
+
and requirements.
|
12 |
+
|
13 |
+
Here are a few suggestions for customization:
|
14 |
+
- Modify the dataset loading and preprocessing steps to match your data.
|
15 |
+
- Customize the model saving and pushing configurations.
|
16 |
+
|
17 |
+
Usage: (most of the options have valid default values; this is an extended example for demonstration purposes)
|
18 |
+
python unsloth-cli.py --model_name "unsloth/llama-3-8b" --max_seq_length 8192 --dtype None --load_in_4bit \
|
19 |
+
--r 64 --lora_alpha 32 --lora_dropout 0.1 --bias "none" --use_gradient_checkpointing "unsloth" \
|
20 |
+
--random_state 3407 --use_rslora --per_device_train_batch_size 4 --gradient_accumulation_steps 8 \
|
21 |
+
--warmup_steps 5 --max_steps 400 --learning_rate 2e-6 --logging_steps 1 --optim "adamw_8bit" \
|
22 |
+
--weight_decay 0.005 --lr_scheduler_type "linear" --seed 3407 --output_dir "outputs" \
|
23 |
+
--report_to "tensorboard" --save_model --save_path "model" --quantization_method "f16" \
|
24 |
+
--push_model --hub_path "hf/model" --hub_token "your_hf_token"
|
25 |
+
|
26 |
+
To see a full list of configurable options, use:
|
27 |
+
python unsloth-cli.py --help
|
28 |
+
|
29 |
+
Happy fine-tuning!
|
30 |
+
"""
|
31 |
+
|
32 |
+
import argparse
|
33 |
+
|
34 |
+
def run(args):
|
35 |
+
import torch
|
36 |
+
from unsloth import FastLanguageModel
|
37 |
+
from datasets import load_dataset
|
38 |
+
from trl import SFTTrainer
|
39 |
+
from transformers import TrainingArguments
|
40 |
+
from unsloth import is_bfloat16_supported
|
41 |
+
import logging
|
42 |
+
logging.getLogger('hf-to-gguf').setLevel(logging.WARNING)
|
43 |
+
|
44 |
+
# Load model and tokenizer
|
45 |
+
model, tokenizer = FastLanguageModel.from_pretrained(
|
46 |
+
model_name=args.model_name,
|
47 |
+
max_seq_length=args.max_seq_length,
|
48 |
+
dtype=args.dtype,
|
49 |
+
load_in_4bit=args.load_in_4bit,
|
50 |
+
)
|
51 |
+
|
52 |
+
# Configure PEFT model
|
53 |
+
model = FastLanguageModel.get_peft_model(
|
54 |
+
model,
|
55 |
+
r=args.r,
|
56 |
+
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
57 |
+
"gate_proj", "up_proj", "down_proj"],
|
58 |
+
lora_alpha=args.lora_alpha,
|
59 |
+
lora_dropout=args.lora_dropout,
|
60 |
+
bias=args.bias,
|
61 |
+
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
62 |
+
random_state=args.random_state,
|
63 |
+
use_rslora=args.use_rslora,
|
64 |
+
loftq_config=args.loftq_config,
|
65 |
+
)
|
66 |
+
|
67 |
+
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
|
68 |
+
|
69 |
+
### Instruction:
|
70 |
+
{}
|
71 |
+
|
72 |
+
### Input:
|
73 |
+
{}
|
74 |
+
|
75 |
+
### Response:
|
76 |
+
{}"""
|
77 |
+
|
78 |
+
EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
|
79 |
+
def formatting_prompts_func(examples):
|
80 |
+
instructions = examples["instruction"]
|
81 |
+
inputs = examples["input"]
|
82 |
+
outputs = examples["output"]
|
83 |
+
texts = []
|
84 |
+
for instruction, input, output in zip(instructions, inputs, outputs):
|
85 |
+
text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
|
86 |
+
texts.append(text)
|
87 |
+
return {"text": texts}
|
88 |
+
|
89 |
+
# Load and format dataset
|
90 |
+
dataset = load_dataset(args.dataset, split="train")
|
91 |
+
dataset = dataset.map(formatting_prompts_func, batched=True)
|
92 |
+
print("Data is formatted and ready!")
|
93 |
+
|
94 |
+
# Configure training arguments
|
95 |
+
training_args = TrainingArguments(
|
96 |
+
per_device_train_batch_size=args.per_device_train_batch_size,
|
97 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
98 |
+
warmup_steps=args.warmup_steps,
|
99 |
+
max_steps=args.max_steps,
|
100 |
+
learning_rate=args.learning_rate,
|
101 |
+
fp16=not is_bfloat16_supported(),
|
102 |
+
bf16=is_bfloat16_supported(),
|
103 |
+
logging_steps=args.logging_steps,
|
104 |
+
optim=args.optim,
|
105 |
+
weight_decay=args.weight_decay,
|
106 |
+
lr_scheduler_type=args.lr_scheduler_type,
|
107 |
+
seed=args.seed,
|
108 |
+
output_dir=args.output_dir,
|
109 |
+
report_to=args.report_to,
|
110 |
+
)
|
111 |
+
|
112 |
+
# Initialize trainer
|
113 |
+
trainer = SFTTrainer(
|
114 |
+
model=model,
|
115 |
+
tokenizer=tokenizer,
|
116 |
+
train_dataset=dataset,
|
117 |
+
dataset_text_field="text",
|
118 |
+
max_seq_length=args.max_seq_length,
|
119 |
+
dataset_num_proc=2,
|
120 |
+
packing=False,
|
121 |
+
args=training_args,
|
122 |
+
)
|
123 |
+
|
124 |
+
# Train model
|
125 |
+
trainer_stats = trainer.train()
|
126 |
+
|
127 |
+
# Save model
|
128 |
+
if args.save_model:
|
129 |
+
# if args.quantization is a list, we will save the model for each quantization method
|
130 |
+
if args.save_gguf:
|
131 |
+
if isinstance(args.quantization, list):
|
132 |
+
for quantization_method in args.quantization:
|
133 |
+
print(f"Saving model with quantization method: {quantization_method}")
|
134 |
+
model.save_pretrained_gguf(
|
135 |
+
args.save_path,
|
136 |
+
tokenizer,
|
137 |
+
quantization_method=quantization_method,
|
138 |
+
)
|
139 |
+
if args.push_model:
|
140 |
+
model.push_to_hub_gguf(
|
141 |
+
hub_path=args.hub_path,
|
142 |
+
hub_token=args.hub_token,
|
143 |
+
quantization_method=quantization_method,
|
144 |
+
)
|
145 |
+
else:
|
146 |
+
print(f"Saving model with quantization method: {args.quantization}")
|
147 |
+
model.save_pretrained_gguf(args.save_path, tokenizer, quantization_method=args.quantization)
|
148 |
+
if args.push_model:
|
149 |
+
model.push_to_hub_gguf(
|
150 |
+
hub_path=args.hub_path,
|
151 |
+
hub_token=args.hub_token,
|
152 |
+
quantization_method=args.quantization,
|
153 |
+
)
|
154 |
+
else:
|
155 |
+
model.save_pretrained_merged(args.save_path, tokenizer, args.save_method)
|
156 |
+
if args.push_model:
|
157 |
+
model.push_to_hub_merged(args.save_path, tokenizer, args.hub_token)
|
158 |
+
else:
|
159 |
+
print("Warning: The model is not saved!")
|
160 |
+
|
161 |
+
|
162 |
+
if __name__ == "__main__":
|
163 |
+
|
164 |
+
# Define argument parser
|
165 |
+
parser = argparse.ArgumentParser(description="🦥 Fine-tune your llm faster using unsloth!")
|
166 |
+
|
167 |
+
model_group = parser.add_argument_group("🤖 Model Options")
|
168 |
+
model_group.add_argument('--model_name', type=str, default="unsloth/llama-3-8b", help="Model name to load")
|
169 |
+
model_group.add_argument('--max_seq_length', type=int, default=2048, help="Maximum sequence length, default is 2048. We auto support RoPE Scaling internally!")
|
170 |
+
model_group.add_argument('--dtype', type=str, default=None, help="Data type for model (None for auto detection)")
|
171 |
+
model_group.add_argument('--load_in_4bit', action='store_true', help="Use 4bit quantization to reduce memory usage")
|
172 |
+
model_group.add_argument('--dataset', type=str, default="yahma/alpaca-cleaned", help="Huggingface dataset to use for training")
|
173 |
+
|
174 |
+
lora_group = parser.add_argument_group("🧠 LoRA Options", "These options are used to configure the LoRA model.")
|
175 |
+
lora_group.add_argument('--r', type=int, default=16, help="Rank for Lora model, default is 16. (common values: 8, 16, 32, 64, 128)")
|
176 |
+
lora_group.add_argument('--lora_alpha', type=int, default=16, help="LoRA alpha parameter, default is 16. (common values: 8, 16, 32, 64, 128)")
|
177 |
+
lora_group.add_argument('--lora_dropout', type=float, default=0, help="LoRA dropout rate, default is 0.0 which is optimized.")
|
178 |
+
lora_group.add_argument('--bias', type=str, default="none", help="Bias setting for LoRA")
|
179 |
+
lora_group.add_argument('--use_gradient_checkpointing', type=str, default="unsloth", help="Use gradient checkpointing")
|
180 |
+
lora_group.add_argument('--random_state', type=int, default=3407, help="Random state for reproducibility, default is 3407.")
|
181 |
+
lora_group.add_argument('--use_rslora', action='store_true', help="Use rank stabilized LoRA")
|
182 |
+
lora_group.add_argument('--loftq_config', type=str, default=None, help="Configuration for LoftQ")
|
183 |
+
|
184 |
+
|
185 |
+
training_group = parser.add_argument_group("🎓 Training Options")
|
186 |
+
training_group.add_argument('--per_device_train_batch_size', type=int, default=2, help="Batch size per device during training, default is 2.")
|
187 |
+
training_group.add_argument('--gradient_accumulation_steps', type=int, default=4, help="Number of gradient accumulation steps, default is 4.")
|
188 |
+
training_group.add_argument('--warmup_steps', type=int, default=5, help="Number of warmup steps, default is 5.")
|
189 |
+
training_group.add_argument('--max_steps', type=int, default=400, help="Maximum number of training steps.")
|
190 |
+
training_group.add_argument('--learning_rate', type=float, default=2e-4, help="Learning rate, default is 2e-4.")
|
191 |
+
training_group.add_argument('--optim', type=str, default="adamw_8bit", help="Optimizer type.")
|
192 |
+
training_group.add_argument('--weight_decay', type=float, default=0.01, help="Weight decay, default is 0.01.")
|
193 |
+
training_group.add_argument('--lr_scheduler_type', type=str, default="linear", help="Learning rate scheduler type, default is 'linear'.")
|
194 |
+
training_group.add_argument('--seed', type=int, default=3407, help="Seed for reproducibility, default is 3407.")
|
195 |
+
|
196 |
+
|
197 |
+
# Report/Logging arguments
|
198 |
+
report_group = parser.add_argument_group("📊 Report Options")
|
199 |
+
report_group.add_argument('--report_to', type=str, default="tensorboard",
|
200 |
+
choices=["azure_ml", "clearml", "codecarbon", "comet_ml", "dagshub", "dvclive", "flyte", "mlflow", "neptune", "tensorboard", "wandb", "all", "none"],
|
201 |
+
help="The list of integrations to report the results and logs to. Supported platforms are: \n\t\t 'azure_ml', 'clearml', 'codecarbon', 'comet_ml', 'dagshub', 'dvclive', 'flyte', 'mlflow', 'neptune', 'tensorboard', and 'wandb'. Use 'all' to report to all integrations installed, 'none' for no integrations.")
|
202 |
+
report_group.add_argument('--logging_steps', type=int, default=1, help="Logging steps, default is 1")
|
203 |
+
|
204 |
+
# Saving and pushing arguments
|
205 |
+
save_group = parser.add_argument_group('💾 Save Model Options')
|
206 |
+
save_group.add_argument('--output_dir', type=str, default="outputs", help="Output directory")
|
207 |
+
save_group.add_argument('--save_model', action='store_true', help="Save the model after training")
|
208 |
+
save_group.add_argument('--save_method', type=str, default="merged_16bit", choices=["merged_16bit", "merged_4bit", "lora"], help="Save method for the model, default is 'merged_16bit'")
|
209 |
+
save_group.add_argument('--save_gguf', action='store_true', help="Convert the model to GGUF after training")
|
210 |
+
save_group.add_argument('--save_path', type=str, default="model", help="Path to save the model")
|
211 |
+
save_group.add_argument('--quantization', type=str, default="q8_0", nargs="+",
|
212 |
+
help="Quantization method for saving the model. common values ('f16', 'q4_k_m', 'q8_0'), Check our wiki for all quantization methods https://github.com/unslothai/unsloth/wiki#saving-to-gguf ")
|
213 |
+
|
214 |
+
push_group = parser.add_argument_group('🚀 Push Model Options')
|
215 |
+
push_group.add_argument('--push_model', action='store_true', help="Push the model to Hugging Face hub after training")
|
216 |
+
push_group.add_argument('--push_gguf', action='store_true', help="Push the model as GGUF to Hugging Face hub after training")
|
217 |
+
push_group.add_argument('--hub_path', type=str, default="hf/model", help="Path on Hugging Face hub to push the model")
|
218 |
+
push_group.add_argument('--hub_token', type=str, help="Token for pushing the model to Hugging Face hub")
|
219 |
+
|
220 |
+
args = parser.parse_args()
|
221 |
+
run(args)
|
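The CLI above is a thin wrapper around a handful of library calls. Below is a rough programmatic sketch of the same flow; the model name, dataset and hyperparameters are simply the script's defaults, and a single-GPU CUDA environment with the optional dependencies installed is assumed:

```python
from unsloth import FastLanguageModel, is_bfloat16_supported
from datasets import load_dataset
from transformers import TrainingArguments
from trl import SFTTrainer

# Load the base model in 4-bit (mirrors --model_name / --max_seq_length / --load_in_4bit).
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name     = "unsloth/llama-3-8b",
    max_seq_length = 2048,
    dtype          = None,
    load_in_4bit   = True,
)

# Attach LoRA adapters with the script's default hyperparameters.
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 16,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
)

# Format Alpaca-style rows into a single "text" field, as the CLI's formatting function does.
dataset = load_dataset("yahma/alpaca-cleaned", split = "train")
dataset = dataset.map(lambda row: {"text":
    f"### Instruction:\n{row['instruction']}\n\n### Input:\n{row['input']}\n\n"
    f"### Response:\n{row['output']}{tokenizer.eos_token}"})

trainer = SFTTrainer(
    model              = model,
    tokenizer          = tokenizer,
    train_dataset      = dataset,
    dataset_text_field = "text",
    max_seq_length     = 2048,
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps  = 5,
        max_steps     = 400,
        learning_rate = 2e-4,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        optim = "adamw_8bit",
        seed  = 3407,
        output_dir = "outputs",
    ),
)
trainer.train()
```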
unsloth-main/unsloth-main/unsloth/__init__.py
ADDED
@@ -0,0 +1,161 @@
1 |
+
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import warnings, importlib, sys
|
16 |
+
from packaging.version import Version
|
17 |
+
import os, re, subprocess, inspect
|
18 |
+
import numpy as np
|
19 |
+
|
20 |
+
# # Define a list of modules to check
|
21 |
+
# MODULES_TO_CHECK = ["bitsandbytes"]
|
22 |
+
|
23 |
+
# # Check if any of the modules in the list have been imported
|
24 |
+
# for module in MODULES_TO_CHECK:
|
25 |
+
# if module in sys.modules:
|
26 |
+
# raise ImportError(f"Unsloth: Please import Unsloth before {module}.")
|
27 |
+
# pass
|
28 |
+
# pass
|
29 |
+
|
30 |
+
# Unsloth currently does not work on multi GPU setups - sadly we are a 2 brother team so
|
31 |
+
# enabling it will require much more work, so we have to prioritize. Please understand!
|
32 |
+
# We do have a beta version, which you can contact us about!
|
33 |
+
# Thank you for your understanding and we appreciate it immensely!
|
34 |
+
if "CUDA_VISIBLE_DEVICES" in os.environ:
|
35 |
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
36 |
+
devices = os.environ["CUDA_VISIBLE_DEVICES"]
|
37 |
+
# Check if there are multiple cuda devices set in env
|
38 |
+
if not devices.isdigit():
|
39 |
+
first_id = devices.split(",")[0]
|
40 |
+
warnings.warn(
|
41 |
+
f"Unsloth: 'CUDA_VISIBLE_DEVICES' is currently {devices} \n"\
|
42 |
+
"Unsloth currently does not support multi GPU setups - but we are working on it!\n"\
|
43 |
+
"Multiple CUDA devices detected but we require a single device.\n"\
|
44 |
+
f"We will override CUDA_VISIBLE_DEVICES to first device: {first_id}."
|
45 |
+
)
|
46 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = str(first_id)
|
47 |
+
else:
|
48 |
+
# warnings.warn("Unsloth: 'CUDA_VISIBLE_DEVICES' is not set. We shall set it ourselves.")
|
49 |
+
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
50 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
51 |
+
pass
|
52 |
+
|
53 |
+
# Reduce VRAM usage by reducing fragmentation
|
54 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
55 |
+
|
56 |
+
try:
|
57 |
+
import torch
|
58 |
+
except:
|
59 |
+
raise ImportError("Pytorch is not installed. Go to https://pytorch.org/.\n"\
|
60 |
+
"We have some installation instructions on our Github page.")
|
61 |
+
pass
|
62 |
+
|
63 |
+
# Hugging Face Hub faster downloads (only enable during Colab and Kaggle sessions)
|
64 |
+
keynames = "\n" + "\n".join(os.environ.keys())
|
65 |
+
if "\nCOLAB_" in keynames or "\nKAGGLE_" in keynames:
|
66 |
+
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
67 |
+
pass
|
68 |
+
|
69 |
+
# We support Pytorch 2
|
70 |
+
# Fixes https://github.com/unslothai/unsloth/issues/38
|
71 |
+
torch_version = torch.__version__.split(".")
|
72 |
+
major_torch, minor_torch = torch_version[0], torch_version[1]
|
73 |
+
major_torch, minor_torch = int(major_torch), int(minor_torch)
|
74 |
+
if (major_torch < 2):
|
75 |
+
raise ImportError("Unsloth only supports Pytorch 2 for now. Please update your Pytorch to 2.1.\n"\
|
76 |
+
"We have some installation instructions on our Github page.")
|
77 |
+
elif (major_torch == 2) and (minor_torch < 2):
|
78 |
+
# Disable expandable_segments
|
79 |
+
del os.environ["PYTORCH_CUDA_ALLOC_CONF"]
|
80 |
+
pass
|
81 |
+
|
82 |
+
# Torch 2.4 has including_emulation
|
83 |
+
major_version, minor_version = torch.cuda.get_device_capability()
|
84 |
+
SUPPORTS_BFLOAT16 = (major_version >= 8)
|
85 |
+
|
86 |
+
old_is_bf16_supported = torch.cuda.is_bf16_supported
|
87 |
+
if "including_emulation" in str(inspect.signature(old_is_bf16_supported)):
|
88 |
+
def is_bf16_supported(including_emulation = False):
|
89 |
+
return old_is_bf16_supported(including_emulation)
|
90 |
+
torch.cuda.is_bf16_supported = is_bf16_supported
|
91 |
+
else:
|
92 |
+
def is_bf16_supported(): return SUPPORTS_BFLOAT16
|
93 |
+
torch.cuda.is_bf16_supported = is_bf16_supported
|
94 |
+
pass
|
95 |
+
|
96 |
+
# Try loading bitsandbytes and triton
|
97 |
+
import bitsandbytes as bnb
|
98 |
+
|
99 |
+
if "SPACE_AUTHOR_NAME" not in os.environ and "SPACE_REPO_NAME" not in os.environ:
|
100 |
+
|
101 |
+
import triton
|
102 |
+
libcuda_dirs = lambda: None
|
103 |
+
if Version(triton.__version__) >= Version("3.0.0"):
|
104 |
+
try: from triton.backends.nvidia.driver import libcuda_dirs
|
105 |
+
except: pass
|
106 |
+
else: from triton.common.build import libcuda_dirs
|
107 |
+
|
108 |
+
try:
|
109 |
+
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
|
110 |
+
libcuda_dirs()
|
111 |
+
except:
|
112 |
+
warnings.warn(
|
113 |
+
"Unsloth: Running `ldconfig /usr/lib64-nvidia` to link CUDA."\
|
114 |
+
)
|
115 |
+
|
116 |
+
if os.path.exists("/usr/lib64-nvidia"):
|
117 |
+
os.system("ldconfig /usr/lib64-nvidia")
|
118 |
+
elif os.path.exists("/usr/local"):
|
119 |
+
# Sometimes bitsandbytes cannot be linked properly in Runpod for example
|
120 |
+
possible_cudas = subprocess.check_output(["ls", "-al", "/usr/local"]).decode("utf-8").split("\n")
|
121 |
+
find_cuda = re.compile(r"[\s](cuda\-[\d\.]{2,})$")
|
122 |
+
possible_cudas = [find_cuda.search(x) for x in possible_cudas]
|
123 |
+
possible_cudas = [x.group(1) for x in possible_cudas if x is not None]
|
124 |
+
|
125 |
+
# Try linking cuda folder, or everything in local
|
126 |
+
if len(possible_cudas) == 0:
|
127 |
+
os.system(f"ldconfig /usr/local/")
|
128 |
+
else:
|
129 |
+
find_number = re.compile(r"([\d\.]{2,})")
|
130 |
+
latest_cuda = np.argsort([float(find_number.search(x).group(1)) for x in possible_cudas])[::-1][0]
|
131 |
+
latest_cuda = possible_cudas[latest_cuda]
|
132 |
+
os.system(f"ldconfig /usr/local/{latest_cuda}")
|
133 |
+
pass
|
134 |
+
|
135 |
+
importlib.reload(bnb)
|
136 |
+
importlib.reload(triton)
|
137 |
+
try:
|
138 |
+
libcuda_dirs = lambda: None
|
139 |
+
if Version(triton.__version__) >= Version("3.0.0"):
|
140 |
+
try: from triton.backends.nvidia.driver import libcuda_dirs
|
141 |
+
except: pass
|
142 |
+
else: from triton.common.build import libcuda_dirs
|
143 |
+
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
|
144 |
+
libcuda_dirs()
|
145 |
+
except:
|
146 |
+
warnings.warn(
|
147 |
+
"Unsloth: CUDA is not linked properly.\n"\
|
148 |
+
"Try running `python -m bitsandbytes` then `python -m xformers.info`\n"\
|
149 |
+
"We tried running `ldconfig /usr/lib64-nvidia` ourselves, but it didn't work.\n"\
|
150 |
+
"You need to run in your terminal `sudo ldconfig /usr/lib64-nvidia` yourself, then import Unsloth.\n"\
|
151 |
+
"Also try `sudo ldconfig /usr/local/cuda-xx.x` - find the latest cuda version.\n"\
|
152 |
+
"Unsloth will still run for now, but maybe it might crash - let's hope it works!"
|
153 |
+
)
|
154 |
+
pass
|
155 |
+
pass
|
156 |
+
|
157 |
+
from .models import *
|
158 |
+
from .save import *
|
159 |
+
from .chat_templates import *
|
160 |
+
from .tokenizer_utils import *
|
161 |
+
from .trainer import *
|
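Because everything above runs at module import time, importing `unsloth` reshapes the process environment before any model is loaded. A small hedged check of the observable effects on a single-GPU CUDA machine (exact values depend on the host):

```python
import os
import unsloth   # intended to be imported before the other training libraries (see the module check above)
import torch

# A single CUDA device is pinned, even if several were visible:
print(os.environ["CUDA_DEVICE_ORDER"])            # "PCI_BUS_ID"
print(os.environ["CUDA_VISIBLE_DEVICES"])         # e.g. "0", or the first id from your own setting

# Allocator fragmentation fix (only kept for torch >= 2.2 per the version check above):
print(os.environ.get("PYTORCH_CUDA_ALLOC_CONF"))  # "expandable_segments:True" or None

# torch.cuda.is_bf16_supported is patched to reflect compute capability >= 8:
print(torch.cuda.is_bf16_supported())
```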
unsloth-main/unsloth-main/unsloth/_auto_install.py
ADDED
@@ -0,0 +1,30 @@
+# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+try: import torch
+except: raise ImportError('Install torch via `pip install torch`')
+from packaging.version import Version as V
+v = V(torch.__version__)
+cuda = str(torch.version.cuda)
+is_ampere = torch.cuda.get_device_capability()[0] >= 8
+if cuda != "12.1" and cuda != "11.8": raise RuntimeError(f"CUDA = {cuda} not supported!")
+if v <= V('2.1.0'): raise RuntimeError(f"Torch = {v} too old!")
+elif v <= V('2.1.1'): x = 'cu{}{}-torch211'
+elif v <= V('2.1.2'): x = 'cu{}{}-torch212'
+elif v < V('2.3.0'): x = 'cu{}{}-torch220'
+elif v < V('2.4.0'): x = 'cu{}{}-torch230'
+elif v < V('2.5.0'): x = 'cu{}{}-torch240'
+else: raise RuntimeError(f"Torch = {v} too new!")
+x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
+print(f'pip install "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"')
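A minimal sketch (not part of the repository) that mirrors the branch ladder above, with a couple of assumed environments to show which pyproject extra each one resolves to:

```python
from packaging.version import Version as V

def pick_extra(torch_version: str, cuda: str, is_ampere: bool) -> str:
    """Mirror of the selection in _auto_install.py above; inputs are assumed, not detected."""
    v = V(torch_version)
    if cuda != "12.1" and cuda != "11.8":
        raise RuntimeError(f"CUDA = {cuda} not supported!")
    if   v <= V("2.1.0"): raise RuntimeError(f"Torch = {v} too old!")
    elif v <= V("2.1.1"): x = "cu{}{}-torch211"
    elif v <= V("2.1.2"): x = "cu{}{}-torch212"
    elif v <  V("2.3.0"): x = "cu{}{}-torch220"
    elif v <  V("2.4.0"): x = "cu{}{}-torch230"
    elif v <  V("2.5.0"): x = "cu{}{}-torch240"
    else: raise RuntimeError(f"Torch = {v} too new!")
    return x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")

print(pick_extra("2.2.2", "11.8", False))  # cu118-torch220
print(pick_extra("2.4.1", "12.1", True))   # cu121-ampere-torch240
```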
unsloth-main/unsloth-main/unsloth/chat_templates.py
ADDED
@@ -0,0 +1,2210 @@
1 |
+
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
__all__ = [
|
16 |
+
"get_chat_template",
|
17 |
+
"test_chat_templates",
|
18 |
+
"test_hf_gguf_equivalence",
|
19 |
+
"remove_special_tokens",
|
20 |
+
|
21 |
+
"to_sharegpt",
|
22 |
+
"standardize_sharegpt",
|
23 |
+
"apply_chat_template",
|
24 |
+
"train_on_responses_only",
|
25 |
+
|
26 |
+
"test_construct_chat_template",
|
27 |
+
]
|
28 |
+
|
29 |
+
from transformers import StoppingCriteria, StoppingCriteriaList
|
30 |
+
from torch import LongTensor, FloatTensor
|
31 |
+
from transformers.models.llama.modeling_llama import logger
|
32 |
+
from .save import patch_saving_functions
|
33 |
+
import os
|
34 |
+
import shutil
|
35 |
+
from .tokenizer_utils import *
|
36 |
+
from .models._utils import patch_tokenizer
|
37 |
+
import re
|
38 |
+
|
39 |
+
CHAT_TEMPLATES = {}
|
40 |
+
|
41 |
+
# =========================================== Unsloth
|
42 |
+
# Unsloth's efficient template is adapted from Zephyr
|
43 |
+
unsloth_template = \
|
44 |
+
"{{ bos_token }}"\
|
45 |
+
"{% if messages[0]['role'] == 'system' %}"\
|
46 |
+
"{{ messages[0]['content'] + '\n' }}"\
|
47 |
+
"{% set loop_messages = messages[1:] %}"\
|
48 |
+
"{% else %}"\
|
49 |
+
"{{ 'You are a helpful assistant to the user\n' }}"\
|
50 |
+
"{% set loop_messages = messages %}"\
|
51 |
+
"{% endif %}"\
|
52 |
+
"{% for message in loop_messages %}"\
|
53 |
+
"{% if message['role'] == 'user' %}"\
|
54 |
+
"{{ '>>> User: ' + message['content'] + '\n' }}"\
|
55 |
+
"{% elif message['role'] == 'assistant' %}"\
|
56 |
+
"{{ '>>> Assistant: ' + message['content'] + eos_token + '\n' }}"\
|
57 |
+
"{% else %}"\
|
58 |
+
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
|
59 |
+
"{% endif %}"\
|
60 |
+
"{% endfor %}"\
|
61 |
+
"{% if add_generation_prompt %}"\
|
62 |
+
"{{ '>>> Assistant: ' }}"\
|
63 |
+
"{% endif %}"
|
64 |
+
pass
|
65 |
+
|
66 |
+
unsloth_ollama = \
|
67 |
+
'''
|
68 |
+
FROM {__FILE_LOCATION__}
|
69 |
+
TEMPLATE """{{ if .System }}{{ .System }}
|
70 |
+
{{ end }}{{ if .Prompt }}>>> User: {{ .Prompt }}
|
71 |
+
{{ end }}>>> Assistant: {{ .Response }}{__EOS_TOKEN__}
|
72 |
+
"""
|
73 |
+
PARAMETER stop "{__EOS_TOKEN__}"
|
74 |
+
PARAMETER temperature 1.5
|
75 |
+
PARAMETER min_p 0.1
|
76 |
+
SYSTEM """You are a helpful assistant to the user"""
|
77 |
+
'''
|
78 |
+
|
79 |
+
unsloth_eos_token = "eos_token"
|
80 |
+
CHAT_TEMPLATES["unsloth"] = (unsloth_template, unsloth_eos_token, False, unsloth_ollama,)
|
81 |
+
pass
|
82 |
+
|
83 |
+
# =========================================== Zephyr
|
84 |
+
# Zephyr has no BOS!
|
85 |
+
zephyr_template = \
|
86 |
+
"{% for message in messages %}"\
|
87 |
+
"{% if message['role'] == 'user' %}"\
|
88 |
+
"{{ '<|user|>\n' + message['content'] + eos_token + '\n' }}"\
|
89 |
+
"{% elif message['role'] == 'assistant' %}"\
|
90 |
+
"{{ '<|assistant|>\n' + message['content'] + eos_token + '\n' }}"\
|
91 |
+
"{% else %}"\
|
92 |
+
"{{ '<|system|>\n' + message['content'] + eos_token + '\n' }}"\
|
93 |
+
"{% endif %}"\
|
94 |
+
"{% endfor %}"\
|
95 |
+
"{% if add_generation_prompt %}"\
|
96 |
+
"{{ '<|assistant|>\n' }}"\
|
97 |
+
"{% endif %}"
|
98 |
+
pass
|
99 |
+
|
100 |
+
zephyr_ollama = \
|
101 |
+
'''
|
102 |
+
FROM {__FILE_LOCATION__}
|
103 |
+
TEMPLATE """{{ if .System }}<|system|>
|
104 |
+
{{ .System }}{__EOS_TOKEN__}
|
105 |
+
{{ end }}{{ if .Prompt }}<|user|>
|
106 |
+
{{ .Prompt }}{__EOS_TOKEN__}
|
107 |
+
{{ end }}<|assistant|>
|
108 |
+
{{ .Response }}{__EOS_TOKEN__}
|
109 |
+
"""
|
110 |
+
PARAMETER stop "{__EOS_TOKEN__}"
|
111 |
+
PARAMETER temperature 1.5
|
112 |
+
PARAMETER min_p 0.1
|
113 |
+
'''
|
114 |
+
|
115 |
+
zephyr_eos_token = "eos_token"
|
116 |
+
CHAT_TEMPLATES["zephyr"] = (zephyr_template, zephyr_eos_token, False, zephyr_ollama,)
|
117 |
+
pass
|
118 |
+
|
119 |
+
# =========================================== ChatML
|
120 |
+
# ChatML has no BOS and no EOS! Rather, <|im_start|> and <|im_end|> act as BOS / EOS.
|
121 |
+
chatml_template = \
|
122 |
+
"{% for message in messages %}"\
|
123 |
+
"{% if message['role'] == 'user' %}"\
|
124 |
+
"{{'<|im_start|>user\n' + message['content'] + '<|im_end|>\n'}}"\
|
125 |
+
"{% elif message['role'] == 'assistant' %}"\
|
126 |
+
"{{'<|im_start|>assistant\n' + message['content'] + '<|im_end|>\n' }}"\
|
127 |
+
"{% else %}"\
|
128 |
+
"{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>\n' }}"\
|
129 |
+
"{% endif %}"\
|
130 |
+
"{% endfor %}"\
|
131 |
+
"{% if add_generation_prompt %}"\
|
132 |
+
"{{ '<|im_start|>assistant\n' }}"\
|
133 |
+
"{% endif %}"
|
134 |
+
pass
|
135 |
+
|
136 |
+
chatml_ollama = \
|
137 |
+
'''
|
138 |
+
FROM {__FILE_LOCATION__}
|
139 |
+
TEMPLATE """{{ if .System }}<|im_start|>system
|
140 |
+
{{ .System }}<|im_end|>
|
141 |
+
{{ end }}{{ if .Prompt }}<|im_start|>user
|
142 |
+
{{ .Prompt }}<|im_end|>
|
143 |
+
{{ end }}<|im_start|>assistant
|
144 |
+
{{ .Response }}<|im_end|>
|
145 |
+
"""
|
146 |
+
PARAMETER stop "<|im_start|>"
|
147 |
+
PARAMETER stop "<|im_end|>"
|
148 |
+
PARAMETER temperature 1.5
|
149 |
+
PARAMETER min_p 0.1
|
150 |
+
'''
|
151 |
+
|
152 |
+
chatml_eos_token = "<|im_end|>"
|
153 |
+
CHAT_TEMPLATES["chatml"] = (chatml_template, chatml_eos_token, True, chatml_ollama,)
|
154 |
+
pass
|
155 |
+
|
156 |
+
# =========================================== Mistral-1
|
157 |
+
# Mistral Instruct doesn't allow system prompts, so we append it to the user message.
|
158 |
+
mistral_template = \
|
159 |
+
"{{ bos_token }}"\
|
160 |
+
"{% if messages[0]['role'] == 'system' %}"\
|
161 |
+
"{% if messages[1]['role'] == 'user' %}"\
|
162 |
+
"{{ '[INST] ' + messages[0]['content'] + ' ' + messages[1]['content'] + ' [/INST]' }}"\
|
163 |
+
"{% set loop_messages = messages[2:] %}"\
|
164 |
+
"{% else %}"\
|
165 |
+
"{{ '[INST] ' + messages[0]['content'] + ' [/INST]' }}"\
|
166 |
+
"{% set loop_messages = messages[1:] %}"\
|
167 |
+
"{% endif %}"\
|
168 |
+
"{% else %}"\
|
169 |
+
"{% set loop_messages = messages %}"\
|
170 |
+
"{% endif %}"\
|
171 |
+
"{% for message in loop_messages %}"\
|
172 |
+
"{% if message['role'] == 'user' %}"\
|
173 |
+
"{{ '[INST] ' + message['content'] + ' [/INST]' }}"\
|
174 |
+
"{% elif message['role'] == 'assistant' %}"\
|
175 |
+
"{{ message['content'] + eos_token }}"\
|
176 |
+
"{% else %}"\
|
177 |
+
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
|
178 |
+
"{% endif %}"\
|
179 |
+
"{% endfor %}"
|
180 |
+
pass
|
181 |
+
|
182 |
+
# Ollama from https://www.ollama.com/library/mistral
|
183 |
+
mistral_ollama = \
|
184 |
+
'''
|
185 |
+
FROM {__FILE_LOCATION__}
|
186 |
+
TEMPLATE """[INST] {{ if .System }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]"""
|
187 |
+
PARAMETER stop "{__EOS_TOKEN__}"
|
188 |
+
PARAMETER temperature 1.5
|
189 |
+
PARAMETER min_p 0.1
|
190 |
+
'''
|
191 |
+
|
192 |
+
mistral_eos_token = "eos_token"
|
193 |
+
CHAT_TEMPLATES["mistral"] = (mistral_template, mistral_eos_token, False, mistral_ollama,)
|
194 |
+
pass
|
195 |
+
|
196 |
+
# =========================================== Llama-2
|
197 |
+
# Adds BOS to every convo! And weird <<SYS>> system messages.
|
198 |
+
llama_template = \
|
199 |
+
"{% if messages[0]['role'] == 'system' %}"\
|
200 |
+
"{% if messages[1]['role'] == 'user' %}"\
|
201 |
+
"{{ bos_token + '[INST] <<SYS>>\n' + messages[0]['content'] + '\n<</SYS>>\n\n' + messages[1]['content'] + ' [/INST]' }}"\
|
202 |
+
"{% set loop_messages = messages[2:] %}"\
|
203 |
+
"{% else %}"\
|
204 |
+
"{{ bos_token + '[INST] ' + messages[0]['content'] + ' [/INST]' }}"\
|
205 |
+
"{% set loop_messages = messages[1:] %}"\
|
206 |
+
"{% endif %}"\
|
207 |
+
"{% else %}"\
|
208 |
+
"{% set loop_messages = messages %}"\
|
209 |
+
"{% endif %}"\
|
210 |
+
"{% for message in loop_messages %}"\
|
211 |
+
"{% if message['role'] == 'user' %}"\
|
212 |
+
"{{ bos_token + '[INST] ' + message['content'].strip() + ' [/INST]' }}"\
|
213 |
+
"{% elif message['role'] == 'assistant' %}"\
|
214 |
+
"{{ ' ' + message['content'].strip() + ' ' + eos_token }}"\
|
215 |
+
"{% else %}"\
|
216 |
+
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
|
217 |
+
"{% endif %}"\
|
218 |
+
"{% endfor %}"
|
219 |
+
pass
|
220 |
+
|
221 |
+
# Ollama from https://www.ollama.com/library/llama3
|
222 |
+
llama_ollama = \
|
223 |
+
'''
|
224 |
+
FROM {__FILE_LOCATION__}
|
225 |
+
TEMPLATE """[INST] <<SYS>>{{ .System }}<</SYS>>
|
226 |
+
|
227 |
+
{{ .Prompt }} [/INST]"""
|
228 |
+
PARAMETER stop "{__EOS_TOKEN__}"
|
229 |
+
PARAMETER temperature 1.5
|
230 |
+
PARAMETER min_p 0.1
|
231 |
+
'''
|
232 |
+
|
233 |
+
llama_eos_token = "eos_token"
|
234 |
+
CHAT_TEMPLATES["llama"] = (llama_template, llama_eos_token, False, llama_ollama,)
|
235 |
+
pass
|
236 |
+
|
237 |
+
# =========================================== Vicuna
|
238 |
+
# https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
|
239 |
+
vicuna_template = \
|
240 |
+
"{{ bos_token }}"\
|
241 |
+
"{% if messages[0]['role'] == 'system' %}"\
|
242 |
+
"{{ messages[0]['content'] + ' ' }}"\
|
243 |
+
"{% set loop_messages = messages[1:] %}"\
|
244 |
+
"{% else %}"\
|
245 |
+
"{{ 'A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\\'s questions.' + ' ' }}"\
|
246 |
+
"{% set loop_messages = messages %}"\
|
247 |
+
"{% endif %}"\
|
248 |
+
"{% for message in loop_messages %}"\
|
249 |
+
"{% if message['role'] == 'user' %}"\
|
250 |
+
"{{ 'USER: ' + message['content'] + ' ' }}"\
|
251 |
+
"{% elif message['role'] == 'assistant' %}"\
|
252 |
+
"{{ 'ASSISTANT: ' + message['content'] + eos_token }}"\
|
253 |
+
"{% else %}"\
|
254 |
+
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
|
255 |
+
"{% endif %}"\
|
256 |
+
"{% endfor %}"\
|
257 |
+
"{% if add_generation_prompt %}"\
|
258 |
+
"{{ 'ASSISTANT:' }}"\
|
259 |
+
"{% endif %}"
|
260 |
+
pass
|
261 |
+
|
262 |
+
# Ollama from https://www.ollama.com/library/vicuna
|
263 |
+
vicuna_ollama = \
|
264 |
+
'''
|
265 |
+
FROM {__FILE_LOCATION__}
|
266 |
+
TEMPLATE """{{ if .System }}{{ .System }} {{ end }}{{ if .Prompt }}USER: {{ .Prompt }} {{ end }}ASSISTANT: {{ .Response }} {__EOS_TOKEN__}"""
|
267 |
+
PARAMETER stop "{__EOS_TOKEN__}"
|
268 |
+
PARAMETER temperature 1.5
|
269 |
+
PARAMETER min_p 0.1
|
270 |
+
'''
|
271 |
+
|
272 |
+
vicuna_eos_token = "eos_token"
|
273 |
+
CHAT_TEMPLATES["vicuna"] = (vicuna_template, vicuna_eos_token, False, vicuna_ollama,)
|
274 |
+
pass
|
275 |
+
|
276 |
+
# =========================================== Vicuna Old
|
277 |
+
# https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
|
278 |
+
vicuna_old_template = \
|
279 |
+
"{{ bos_token }}"\
|
280 |
+
"{% if messages[0]['role'] == 'system' %}"\
|
281 |
+
"{{ messages[0]['content'] + '\n' }}"\
|
282 |
+
"{% set loop_messages = messages[1:] %}"\
|
283 |
+
"{% else %}"\
|
284 |
+
"{{ 'A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\\'s questions.' + '\n' }}"\
|
285 |
+
"{% set loop_messages = messages %}"\
|
286 |
+
"{% endif %}"\
|
287 |
+
"{% for message in loop_messages %}"\
|
288 |
+
"{% if message['role'] == 'user' %}"\
|
289 |
+
"{{ '### Human: ' + message['content'] + '\n' }}"\
|
290 |
+
"{% elif message['role'] == 'assistant' %}"\
|
291 |
+
"{{ '### Assistant: ' + message['content'] + eos_token + '\n' }}"\
|
292 |
+
"{% else %}"\
|
293 |
+
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
|
294 |
+
"{% endif %}"\
|
295 |
+
"{% endfor %}"\
|
296 |
+
"{% if add_generation_prompt %}"\
|
297 |
+
"{{ '### Assistant:' }}"\
|
298 |
+
"{% endif %}"
|
299 |
+
pass
|
300 |
+
|
301 |
+
vicuna_old_ollama = \
|
302 |
+
'''
|
303 |
+
FROM {__FILE_LOCATION__}
|
304 |
+
TEMPLATE """{{ if .System }}{{ .System }}
|
305 |
+
{{ end }}{{ if .Prompt }}### Human: {{ .Prompt }}
|
306 |
+
{{ end }}### Assistant: {{ .Response }}{__EOS_TOKEN__}
|
307 |
+
"""
|
308 |
+
PARAMETER stop "{__EOS_TOKEN__}"
|
309 |
+
PARAMETER temperature 1.5
|
310 |
+
PARAMETER min_p 0.1
|
311 |
+
SYSTEM """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."""
|
312 |
+
'''
|
313 |
+
|
314 |
+
vicuna_old_eos_token = "eos_token"
|
315 |
+
CHAT_TEMPLATES["vicuna_old"] = (vicuna_old_template, vicuna_old_eos_token, False, vicuna_old_ollama,)
|
316 |
+
pass
|
317 |
+
|
318 |
+
# =========================================== Alpaca multi turn
|
319 |
+
# https://github.com/tatsu-lab/stanford_alpaca Changed for multi-turn convos
|
320 |
+
alpaca_template = \
|
321 |
+
"{{ bos_token }}"\
|
322 |
+
"{% if messages[0]['role'] == 'system' %}"\
|
323 |
+
"{{ messages[0]['content'] + '\n\n' }}"\
|
324 |
+
"{% set loop_messages = messages[1:] %}"\
|
325 |
+
"{% else %}"\
|
326 |
+
"{{ 'Below are some instructions that describe some tasks. Write responses that appropriately complete each request.\n\n' }}"\
|
327 |
+
"{% set loop_messages = messages %}"\
|
328 |
+
"{% endif %}"\
|
329 |
+
"{% for message in loop_messages %}"\
|
330 |
+
"{% if message['role'] == 'user' %}"\
|
331 |
+
"{{ '### Instruction:\n' + message['content'] + '\n\n' }}"\
|
332 |
+
"{% elif message['role'] == 'assistant' %}"\
|
333 |
+
"{{ '### Response:\n' + message['content'] + eos_token + '\n\n' }}"\
|
334 |
+
"{% else %}"\
|
335 |
+
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
|
336 |
+
"{% endif %}"\
|
337 |
+
"{% endfor %}"\
|
338 |
+
"{% if add_generation_prompt %}"\
|
339 |
+
"{{ '### Response:\n' }}"\
|
340 |
+
"{% endif %}"
|
341 |
+
pass
|
342 |
+
|
343 |
+
alpaca_ollama = \
|
344 |
+
'''
|
345 |
+
FROM {__FILE_LOCATION__}
|
346 |
+
TEMPLATE """{{ if .System }}{{ .System }}
|
347 |
+
|
348 |
+
{{ end }}{{ if .Prompt }}### Instruction:
|
349 |
+
{{ .Prompt }}{{ end }}
|
350 |
+
|
351 |
+
### Response:
|
352 |
+
{{ .Response }}{__EOS_TOKEN__}
|
353 |
+
|
354 |
+
"""
|
355 |
+
PARAMETER stop "{__EOS_TOKEN__}"
|
356 |
+
PARAMETER temperature 1.5
|
357 |
+
PARAMETER min_p 0.1
|
358 |
+
SYSTEM """Below are some instructions that describe some tasks. Write responses that appropriately complete each request."""
|
359 |
+
'''
|
360 |
+
|
361 |
+
alpaca_eos_token = "eos_token"
|
362 |
+
CHAT_TEMPLATES["alpaca"] = (alpaca_template, alpaca_eos_token, False, alpaca_ollama,)
|
363 |
+
pass
|
364 |
+
|
365 |
+
# =========================================== Gemma
|
366 |
+
# https://huggingface.co/google/gemma-7b-it
|
367 |
+
# Notice we must use |trim for lstrip and rstrip. <start_of_turn> maps to 106.
|
368 |
+
# <end_of_turn> maps to 107. user and model are normal 1 word tokens.
|
369 |
+
gemma_template = \
|
370 |
+
"{{ bos_token }}"\
|
371 |
+
"{% if messages[0]['role'] == 'system' %}"\
|
372 |
+
"{{'<start_of_turn>user\n' + messages[0]['content'] | trim + ' ' + messages[1]['content'] | trim + '<end_of_turn>\n'}}"\
|
373 |
+
"{% set loop_messages = messages[2:] %}"\
|
374 |
+
"{% endif %}"\
|
375 |
+
"{% for message in messages %}"\
|
376 |
+
"{% if message['role'] == 'user' %}"\
|
377 |
+
"{{'<start_of_turn>user\n' + message['content'] | trim + '<end_of_turn>\n'}}"\
|
378 |
+
"{% elif message['role'] == 'assistant' %}"\
|
379 |
+
"{{'<start_of_turn>model\n' + message['content'] | trim + '<end_of_turn>\n' }}"\
|
380 |
+
"{% else %}"\
|
381 |
+
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
|
382 |
+
"{% endif %}"\
|
383 |
+
"{% endfor %}"\
|
384 |
+
"{% if add_generation_prompt %}"\
|
385 |
+
"{{ '<start_of_turn>model\n' }}"\
|
386 |
+
"{% endif %}"
|
387 |
+
pass
|
388 |
+
|
389 |
+
# Ollama from https://www.ollama.com/library/gemma
|
390 |
+
gemma_ollama = \
|
391 |
+
'''
|
392 |
+
FROM {__FILE_LOCATION__}
|
393 |
+
TEMPLATE """<start_of_turn>user
|
394 |
+
{{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}<end_of_turn>
|
395 |
+
<start_of_turn>model
|
396 |
+
{{ .Response }}<end_of_turn>
|
397 |
+
"""
|
398 |
+
PARAMETER repeat_penalty 1
|
399 |
+
PARAMETER stop "<start_of_turn>"
|
400 |
+
PARAMETER stop "<end_of_turn>"
|
401 |
+
PARAMETER penalize_newline false
|
402 |
+
PARAMETER temperature 1.5
|
403 |
+
PARAMETER min_p 0.1
|
404 |
+
'''
|
405 |
+
|
406 |
+
gemma_eos_token = "<end_of_turn>"
|
407 |
+
CHAT_TEMPLATES["gemma"] = (gemma_template, gemma_eos_token, True, gemma_ollama,)
|
408 |
+
pass
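# For reference, a two-turn conversation pushed through gemma_template above renders
# roughly as follows (content is made up; whitespace is governed by the |trim filters):
#
# <bos><start_of_turn>user
# Hello!<end_of_turn>
# <start_of_turn>model
# Hi there!<end_of_turn>
# <start_of_turn>model
#
# where the final "<start_of_turn>model" line only appears when
# add_generation_prompt = True.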
|
409 |
+
|
410 |
+
# =========================================== Gemma with ChatML instead
|
411 |
+
# We find using <eos> is still more appropriate!
|
412 |
+
gemma_chatml_template = "{{ bos_token }}" + chatml_template
|
413 |
+
pass
|
414 |
+
|
415 |
+
gemma_chatml_ollama = \
|
416 |
+
'''
|
417 |
+
FROM {__FILE_LOCATION__}
|
418 |
+
TEMPLATE """{{ if .System }}<|im_start|>system
|
419 |
+
{{ .System }}<|im_end|>
|
420 |
+
{{ end }}{{ if .Prompt }}<|im_start|>user
|
421 |
+
{{ .Prompt }}<|im_end|>
|
422 |
+
{{ end }}<|im_start|>assistant
|
423 |
+
{{ .Response }}<|im_end|>
|
424 |
+
"""
|
425 |
+
PARAMETER repeat_penalty 1
|
426 |
+
PARAMETER stop "<|im_start|>"
|
427 |
+
PARAMETER stop "<|im_end|>"
|
428 |
+
PARAMETER penalize_newline false
|
429 |
+
PARAMETER temperature 1.5
|
430 |
+
PARAMETER min_p 0.1
|
431 |
+
'''
|
432 |
+
|
433 |
+
gemma_chatml_eos_token = (
|
434 |
+
{"<start_of_turn>" : "<|im_start|>", "<eos>" : "<|im_end|>"},
|
435 |
+
"<|im_end|>",
|
436 |
+
)
|
437 |
+
CHAT_TEMPLATES["gemma_chatml"] = (gemma_chatml_template, gemma_chatml_eos_token, True, gemma_chatml_ollama,)
|
438 |
+
pass
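# Note the EOS slot here is a (token_mapping, stop_word) tuple rather than a plain
# string: get_chat_template() below detects the tuple form, renames the mapped
# tokens (<start_of_turn> -> <|im_start|>, <eos> -> <|im_end|>) inside the tokenizer
# vocabulary, and then uses <|im_end|> as the stop word.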
|
439 |
+
|
440 |
+
# =========================================== Gemma 2
|
441 |
+
# Same as Gemma 1, but with sliding window attention!
|
442 |
+
# https://ollama.com/library/gemma2/blobs/6522ca797f47
|
443 |
+
gemma2_template = gemma_template
|
444 |
+
gemma2_ollama = gemma_ollama + "PARAMETER num_ctx 4096\n"
|
445 |
+
gemma2_eos_token = "<end_of_turn>"
|
446 |
+
CHAT_TEMPLATES["gemma2"] = (gemma2_template, gemma2_eos_token, True, gemma2_ollama,)
|
447 |
+
|
448 |
+
# =========================================== Gemma 2 with ChatML instead
|
449 |
+
gemma2_chatml_template = gemma_chatml_template
|
450 |
+
gemma2_chatml_ollama = gemma_chatml_ollama + "PARAMETER num_ctx 4096\n"
|
451 |
+
gemma2_chatml_eos_token = gemma_chatml_eos_token
|
452 |
+
CHAT_TEMPLATES["gemma2_chatml"] = (gemma2_chatml_template, gemma2_chatml_eos_token, True, gemma2_chatml_ollama,)
|
453 |
+
pass
|
454 |
+
|
455 |
+
# =========================================== Llama-3
|
456 |
+
# Weirdly \n\n is needed?
|
457 |
+
llama3_template = \
|
458 |
+
"{{ bos_token }}"\
|
459 |
+
"{% for message in messages %}"\
|
460 |
+
"{% if message['role'] == 'user' %}"\
|
461 |
+
"{{ '<|start_header_id|>user<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}"\
|
462 |
+
"{% elif message['role'] == 'assistant' %}"\
|
463 |
+
"{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}"\
|
464 |
+
"{% else %}"\
|
465 |
+
"{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}"\
|
466 |
+
"{% endif %}"\
|
467 |
+
"{% endfor %}"\
|
468 |
+
"{% if add_generation_prompt %}"\
|
469 |
+
"{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"\
|
470 |
+
"{% endif %}"
|
471 |
+
pass
|
472 |
+
|
473 |
+
# Ollama from https://www.ollama.com/library/llama3
|
474 |
+
llama3_ollama = \
|
475 |
+
'''
|
476 |
+
FROM {__FILE_LOCATION__}
|
477 |
+
TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
|
478 |
+
|
479 |
+
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
|
480 |
+
|
481 |
+
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
|
482 |
+
|
483 |
+
{{ .Response }}<|eot_id|>"""
|
484 |
+
PARAMETER stop "<|start_header_id|>"
|
485 |
+
PARAMETER stop "<|end_header_id|>"
|
486 |
+
PARAMETER stop "<|eot_id|>"
|
487 |
+
PARAMETER temperature 1.5
|
488 |
+
PARAMETER min_p 0.1
|
489 |
+
'''
|
490 |
+
|
491 |
+
llama3_template_eos_token = "eos_token"
|
492 |
+
CHAT_TEMPLATES["llama-3"] = (llama3_template, llama3_template_eos_token, False, llama3_ollama,)
|
493 |
+
pass
|
494 |
+
|
495 |
+
|
496 |
+
# =========================================== Phi-3
|
497 |
+
# "{{ bos_token }}"\ # Phi-3.5 removes BOS?
|
498 |
+
phi3_template = \
|
499 |
+
"{% for message in messages %}"\
|
500 |
+
"{% if message['role'] == 'user' %}"\
|
501 |
+
"{{'<|user|>\n' + message['content'] + '<|end|>\n'}}"\
|
502 |
+
"{% elif message['role'] == 'assistant' %}"\
|
503 |
+
"{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}"\
|
504 |
+
"{% else %}"\
|
505 |
+
"{{'<|' + message['role'] + '|>\n' + message['content'] + '<|end|>\n'}}"\
|
506 |
+
"{% endif %}"\
|
507 |
+
"{% endfor %}"\
|
508 |
+
"{% if add_generation_prompt %}"\
|
509 |
+
"{{ '<|assistant|>\n' }}"\
|
510 |
+
"{% endif %}"
|
511 |
+
pass
|
512 |
+
|
513 |
+
# Ollama from https://www.ollama.com/library/phi3
|
514 |
+
phi3_ollama = \
|
515 |
+
'''
|
516 |
+
FROM {__FILE_LOCATION__}
|
517 |
+
TEMPLATE """{{ if .System }}<|system|>
|
518 |
+
{{ .System }}<|end|>
|
519 |
+
{{ end }}{{ if .Prompt }}<|user|>
|
520 |
+
{{ .Prompt }}<|end|>
|
521 |
+
{{ end }}<|assistant|>
|
522 |
+
{{ .Response }}<|end|>
|
523 |
+
"""
|
524 |
+
PARAMETER stop "<|end|>"
|
525 |
+
PARAMETER stop "<|user|>"
|
526 |
+
PARAMETER stop "<|assistant|>"
|
527 |
+
PARAMETER temperature 1.5
|
528 |
+
PARAMETER min_p 0.1
|
529 |
+
'''
|
530 |
+
|
531 |
+
phi3_template_eos_token = "<|end|>"
|
532 |
+
CHAT_TEMPLATES["phi-3"] = (phi3_template, phi3_template_eos_token, False, phi3_ollama,)
|
533 |
+
CHAT_TEMPLATES["phi-35"] = CHAT_TEMPLATES["phi-3"]
|
534 |
+
CHAT_TEMPLATES["phi-3.5"] = CHAT_TEMPLATES["phi-3"]
|
535 |
+
pass
|
536 |
+
|
537 |
+
# =========================================== Llama-3.1
|
538 |
+
"""
|
539 |
+
No trimming in Llama 3.1 Instruct!
|
540 |
+
Also an extra newline for Cutting Knowledge Date
|
541 |
+
See https://colab.research.google.com/drive/1Xpqq5xpIgO-B00MQ-UccYMwN2J8QFgBM?usp=sharing
|
542 |
+
|
543 |
+
The template should also be called with a date_string, for example:
|
544 |
+
|
545 |
+
import datetime
|
546 |
+
tokenizer.apply_chat_template(
|
547 |
+
messages,
|
548 |
+
add_generation_prompt = True,
|
549 |
+
tokenize = False,
|
550 |
+
date_string = datetime.datetime.today().strftime("%d %B %Y"),
|
551 |
+
)
|
552 |
+
"""
|
553 |
+
|
554 |
+
llama31_template = \
|
555 |
+
"""{{- bos_token }}
|
556 |
+
{%- if custom_tools is defined %}
|
557 |
+
{%- set tools = custom_tools %}
|
558 |
+
{%- endif %}
|
559 |
+
{%- if not tools_in_user_message is defined %}
|
560 |
+
{%- set tools_in_user_message = true %}
|
561 |
+
{%- endif %}
|
562 |
+
{%- if not date_string is defined %}
|
563 |
+
{%- set date_string = "26 July 2024" %}
|
564 |
+
{%- endif %}
|
565 |
+
{%- if not tools is defined %}
|
566 |
+
{%- set tools = none %}
|
567 |
+
{%- endif %}
|
568 |
+
|
569 |
+
{#- This block extracts the system message, so we can slot it into the right place. #}
|
570 |
+
{%- if messages[0]['role'] == 'system' %}
|
571 |
+
{%- set system_message = messages[0]['content'] %}
|
572 |
+
{%- set messages = messages[1:] %}
|
573 |
+
{%- else %}
|
574 |
+
{%- set system_message = "" %}
|
575 |
+
{%- endif %}
|
576 |
+
|
577 |
+
{#- System message + builtin tools #}
|
578 |
+
{{- "<|start_header_id|>system<|end_header_id|>\n\n" }}
|
579 |
+
{%- if builtin_tools is defined or tools is not none %}
|
580 |
+
{{- "Environment: ipython\n" }}
|
581 |
+
{%- endif %}
|
582 |
+
{%- if builtin_tools is defined %}
|
583 |
+
{{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\n\n"}}
|
584 |
+
{%- endif %}
|
585 |
+
{{- "Cutting Knowledge Date: December 2023\n" }}
|
586 |
+
{{- "Today Date: " + date_string + "\n\n" }}
|
587 |
+
{%- if tools is not none and not tools_in_user_message %}
|
588 |
+
{{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}
|
589 |
+
{{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
|
590 |
+
{{- "Do not use variables.\n\n" }}
|
591 |
+
{%- for t in tools %}
|
592 |
+
{{- t | tojson(indent=4) }}
|
593 |
+
{{- "\n\n" }}
|
594 |
+
{%- endfor %}
|
595 |
+
{%- endif %}
|
596 |
+
{{- system_message }}
|
597 |
+
{{- "<|eot_id|>" }}
|
598 |
+
|
599 |
+
{#- Custom tools are passed in a user message with some extra guidance #}
|
600 |
+
{%- if tools_in_user_message and not tools is none %}
|
601 |
+
{#- Extract the first user message so we can plug it in here #}
|
602 |
+
{%- if messages | length != 0 %}
|
603 |
+
{%- set first_user_message = messages[0]['content'] %}
|
604 |
+
{%- set messages = messages[1:] %}
|
605 |
+
{%- else %}
|
606 |
+
{{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
|
607 |
+
{%- endif %}
|
608 |
+
{{- '<|start_header_id|>user<|end_header_id|>\n\n' -}}
|
609 |
+
{{- "Given the following functions, please respond with a JSON for a function call " }}
|
610 |
+
{{- "with its proper arguments that best answers the given prompt.\n\n" }}
|
611 |
+
{{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
|
612 |
+
{{- "Do not use variables.\n\n" }}
|
613 |
+
{%- for t in tools %}
|
614 |
+
{{- t | tojson(indent=4) }}
|
615 |
+
{{- "\n\n" }}
|
616 |
+
{%- endfor %}
|
617 |
+
{{- first_user_message + "<|eot_id|>"}}
|
618 |
+
{%- endif %}
|
619 |
+
|
620 |
+
{%- for message in messages %}
|
621 |
+
{%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
|
622 |
+
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] + '<|eot_id|>' }}
|
623 |
+
{%- elif 'tool_calls' in message %}
|
624 |
+
{%- if not message.tool_calls|length == 1 %}
|
625 |
+
{{- raise_exception("This model only supports single tool-calls at once!") }}
|
626 |
+
{%- endif %}
|
627 |
+
{%- set tool_call = message.tool_calls[0].function %}
|
628 |
+
{%- if builtin_tools is defined and tool_call.name in builtin_tools %}
|
629 |
+
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
|
630 |
+
{{- "<|python_tag|>" + tool_call.name + ".call(" }}
|
631 |
+
{%- for arg_name, arg_val in tool_call.arguments | items %}
|
632 |
+
{{- arg_name + '="' + arg_val + '"' }}
|
633 |
+
{%- if not loop.last %}
|
634 |
+
{{- ", " }}
|
635 |
+
{%- endif %}
|
636 |
+
{%- endfor %}
|
637 |
+
{{- ")" }}
|
638 |
+
{%- else %}
|
639 |
+
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
|
640 |
+
{{- '{"name": "' + tool_call.name + '", ' }}
|
641 |
+
{{- '"parameters": ' }}
|
642 |
+
{{- tool_call.arguments | tojson }}
|
643 |
+
{{- "}" }}
|
644 |
+
{%- endif %}
|
645 |
+
{%- if builtin_tools is defined %}
|
646 |
+
{#- This means we're in ipython mode #}
|
647 |
+
{{- "<|eom_id|>" }}
|
648 |
+
{%- else %}
|
649 |
+
{{- "<|eot_id|>" }}
|
650 |
+
{%- endif %}
|
651 |
+
{%- elif message.role == "tool" or message.role == "ipython" %}
|
652 |
+
{{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }}
|
653 |
+
{%- if message.content is mapping or message.content is iterable %}
|
654 |
+
{{- message.content | tojson }}
|
655 |
+
{%- else %}
|
656 |
+
{{- message.content }}
|
657 |
+
{%- endif %}
|
658 |
+
{{- "<|eot_id|>" }}
|
659 |
+
{%- endif %}
|
660 |
+
{%- endfor %}
|
661 |
+
{%- if add_generation_prompt %}
|
662 |
+
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
|
663 |
+
{%- endif %}
|
664 |
+
"""
|
665 |
+
pass
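# A hedged sketch of exercising the tool-calling branch of llama31_template above
# (the tool spec and the tokenizer variable are illustrative assumptions):
#
# tools = [{"type": "function", "function": {
#     "name": "get_current_weather",
#     "description": "Get the current weather for a city",
#     "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
# }}]
# text = tokenizer.apply_chat_template(
#     [{"role": "user", "content": "What's the weather in Paris?"}],
#     tools = tools,
#     add_generation_prompt = True,
#     tokenize = False,
# )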
|
666 |
+
|
667 |
+
# Ollama from https://ollama.com/library/llama3.1 (needs updating!)
|
668 |
+
llama31_ollama = \
|
669 |
+
'''
|
670 |
+
FROM {__FILE_LOCATION__}
|
671 |
+
TEMPLATE """{{ if .Messages }}
|
672 |
+
{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
|
673 |
+
{{- if .System }}
|
674 |
+
|
675 |
+
{{ .System }}
|
676 |
+
{{- end }}
|
677 |
+
{{- if .Tools }}
|
678 |
+
|
679 |
+
You are a helpful assistant with tool calling capabilities. When you receive a tool call response, use the output to format an answer to the original user question.
|
680 |
+
{{- end }}
|
681 |
+
{{- end }}<|eot_id|>
|
682 |
+
{{- range $i, $_ := .Messages }}
|
683 |
+
{{- $last := eq (len (slice $.Messages $i)) 1 }}
|
684 |
+
{{- if eq .Role "user" }}<|start_header_id|>user<|end_header_id|>
|
685 |
+
{{- if and $.Tools $last }}
|
686 |
+
|
687 |
+
Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.
|
688 |
+
|
689 |
+
Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.
|
690 |
+
|
691 |
+
{{ $.Tools }}
|
692 |
+
{{- end }}
|
693 |
+
|
694 |
+
{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
|
695 |
+
|
696 |
+
{{ end }}
|
697 |
+
{{- else if eq .Role "assistant" }}<|start_header_id|>assistant<|end_header_id|>
|
698 |
+
{{- if .ToolCalls }}
|
699 |
+
|
700 |
+
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }}
|
701 |
+
{{- else }}
|
702 |
+
|
703 |
+
{{ .Content }}{{ if not $last }}<|eot_id|>{{ end }}
|
704 |
+
{{- end }}
|
705 |
+
{{- else if eq .Role "tool" }}<|start_header_id|>ipython<|end_header_id|>
|
706 |
+
|
707 |
+
{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
|
708 |
+
|
709 |
+
{{ end }}
|
710 |
+
{{- end }}
|
711 |
+
{{- end }}
|
712 |
+
{{- else }}
|
713 |
+
{{- if .System }}<|start_header_id|>system<|end_header_id|>
|
714 |
+
|
715 |
+
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
|
716 |
+
|
717 |
+
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
|
718 |
+
|
719 |
+
{{ end }}{{ .Response }}{{ if .Response }}<|eot_id|>{{ end }}"""
|
720 |
+
PARAMETER stop "<|start_header_id|>"
|
721 |
+
PARAMETER stop "<|end_header_id|>"
|
722 |
+
PARAMETER stop "<|eot_id|>"
|
723 |
+
PARAMETER stop "<|eom_id|>"
|
724 |
+
PARAMETER temperature 1.5
|
725 |
+
PARAMETER min_p 0.1
|
726 |
+
'''
|
727 |
+
|
728 |
+
llama31_template_eos_token = "eos_token"
|
729 |
+
CHAT_TEMPLATES["llama-3.1"] = (llama31_template, llama31_template_eos_token, False, llama31_ollama,)
|
730 |
+
CHAT_TEMPLATES["llama-31"] = (llama31_template, llama31_template_eos_token, False, llama31_ollama,)
|
731 |
+
pass
|
732 |
+
|
733 |
+
|
734 |
+
# =========================================== Qwen 2.5
|
735 |
+
qwen25_template = \
|
736 |
+
"""{%- if tools %}
|
737 |
+
{{- \'<|im_start|>system\\n\' }}
|
738 |
+
{%- if messages[0][\'role\'] == \'system\' %}
|
739 |
+
{{- messages[0][\'content\'] }}
|
740 |
+
{%- else %}
|
741 |
+
{{- \'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\' }}
|
742 |
+
{%- endif %}
|
743 |
+
{{- "\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>" }}
|
744 |
+
{%- for tool in tools %}
|
745 |
+
{{- "\\n" }}
|
746 |
+
{{- tool | tojson }}
|
747 |
+
{%- endfor %}
|
748 |
+
{{- "\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\"name\\": <function-name>, \\"arguments\\": <args-json-object>}\\n</tool_call><|im_end|>\\n" }}\n{%- else %}
|
749 |
+
{%- if messages[0][\'role\'] == \'system\' %}
|
750 |
+
{{- \'<|im_start|>system\\n\' + messages[0][\'content\'] + \'<|im_end|>\\n\' }}
|
751 |
+
{%- else %}
|
752 |
+
{{- \'<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n\' }}
|
753 |
+
{%- endif %}\n{%- endif %}\n{%- for message in messages %}
|
754 |
+
{%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
|
755 |
+
{{- \'<|im_start|>\' + message.role + \'\\n\' + message.content + \'<|im_end|>\' + \'\\n\' }}
|
756 |
+
{%- elif message.role == "assistant" %}
|
757 |
+
{{- \'<|im_start|>\' + message.role }}
|
758 |
+
{%- if message.content %}
|
759 |
+
{{- \'\\n\' + message.content }}
|
760 |
+
{%- endif %}
|
761 |
+
{%- for tool_call in message.tool_calls %}
|
762 |
+
{%- if tool_call.function is defined %}
|
763 |
+
{%- set tool_call = tool_call.function %}
|
764 |
+
{%- endif %}
|
765 |
+
{{- \'\\n<tool_call>\\n{"name": "\' }}
|
766 |
+
{{- tool_call.name }}
|
767 |
+
{{- \'", "arguments": \' }}
|
768 |
+
{{- tool_call.arguments | tojson }}
|
769 |
+
{{- \'}\\n</tool_call>\' }}
|
770 |
+
{%- endfor %}
|
771 |
+
{{- \'<|im_end|>\\n\' }}
|
772 |
+
{%- elif message.role == "tool" %}
|
773 |
+
{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %} {{- \'<|im_start|>user\' }}
|
774 |
+
{%- endif %}
|
775 |
+
{{- \'\\n<tool_response>\\n\' }}
|
776 |
+
{{- message.content }}
|
777 |
+
{{- \'\\n</tool_response>\' }}
|
778 |
+
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
779 |
+
{{- \'<|im_end|>\\n\' }}
|
780 |
+
{%- endif %}
|
781 |
+
{%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}
|
782 |
+
{{- \'<|im_start|>assistant\\n\' }}
|
783 |
+
{%- endif %}
|
784 |
+
"""
|
785 |
+
|
786 |
+
|
787 |
+
# Ollama from https://ollama.com/library/qwen2.5/blobs/eb4402837c78
|
788 |
+
qwen25_ollama = \
|
789 |
+
'''
|
790 |
+
FROM {__FILE_LOCATION__}
|
791 |
+
TEMPLATE """{{- if .Messages }}
|
792 |
+
{{- if or .System .Tools }}<|im_start|>system
|
793 |
+
{{- if .System }}
|
794 |
+
{{ .System }}
|
795 |
+
{{- end }}
|
796 |
+
{{- if .Tools }}
|
797 |
+
|
798 |
+
# Tools
|
799 |
+
|
800 |
+
You may call one or more functions to assist with the user query.
|
801 |
+
|
802 |
+
You are provided with function signatures within <tools></tools> XML tags:
|
803 |
+
<tools>
|
804 |
+
{{- range .Tools }}
|
805 |
+
{"type": "function", "function": {{ .Function }}}
|
806 |
+
{{- end }}
|
807 |
+
</tools>
|
808 |
+
|
809 |
+
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
810 |
+
<tool_call>
|
811 |
+
{"name": <function-name>, "arguments": <args-json-object>}
|
812 |
+
</tool_call>
|
813 |
+
{{- end }}<|im_end|>
|
814 |
+
{{ end }}
|
815 |
+
{{- range $i, $_ := .Messages }}
|
816 |
+
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
|
817 |
+
{{- if eq .Role "user" }}<|im_start|>user
|
818 |
+
{{ .Content }}<|im_end|>
|
819 |
+
{{ else if eq .Role "assistant" }}<|im_start|>assistant
|
820 |
+
{{ if .Content }}{{ .Content }}
|
821 |
+
{{- else if .ToolCalls }}<tool_call>
|
822 |
+
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
823 |
+
{{ end }}</tool_call>
|
824 |
+
{{- end }}{{ if not $last }}<|im_end|>
|
825 |
+
{{ end }}
|
826 |
+
{{- else if eq .Role "tool" }}<|im_start|>user
|
827 |
+
<tool_response>
|
828 |
+
{{ .Content }}
|
829 |
+
</tool_response><|im_end|>
|
830 |
+
{{ end }}
|
831 |
+
{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
|
832 |
+
{{ end }}
|
833 |
+
{{- end }}
|
834 |
+
{{- else }}
|
835 |
+
{{- if .System }}<|im_start|>system
|
836 |
+
{{ .System }}<|im_end|>
|
837 |
+
{{ end }}{{ if .Prompt }}<|im_start|>user
|
838 |
+
{{ .Prompt }}<|im_end|>
|
839 |
+
{{ end }}<|im_start|>assistant
|
840 |
+
{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}"""
|
841 |
+
PARAMETER stop "<|im_end|>"
|
842 |
+
PARAMETER stop "<|endoftext|>"
|
843 |
+
PARAMETER temperature 1.5
|
844 |
+
PARAMETER min_p 0.1
|
845 |
+
'''
|
846 |
+
|
847 |
+
qwen25_template_eos_token = "eos_token"
|
848 |
+
CHAT_TEMPLATES["qwen-2.5"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
|
849 |
+
CHAT_TEMPLATES["qwen-25"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
|
850 |
+
CHAT_TEMPLATES["qwen25"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
|
851 |
+
CHAT_TEMPLATES["qwen2.5"] = (qwen25_template, qwen25_template_eos_token, False, qwen25_ollama,)
|
852 |
+
pass
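# The qwen-2.5 template expects assistant tool calls in the usual Hugging Face
# message format, e.g. (a sketch with made-up values):
#
# {"role" : "assistant", "content" : "", "tool_calls" : [
#     {"function" : {"name" : "get_weather", "arguments" : {"city" : "Paris"}}},
# ]}
#
# which renders as <tool_call>\n{"name": "get_weather", "arguments": {...}}\n</tool_call>
# inside the assistant turn.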
|
853 |
+
|
854 |
+
|
855 |
+
def get_chat_template(
|
856 |
+
tokenizer,
|
857 |
+
chat_template = "chatml",
|
858 |
+
mapping = {"role" : "role", "content" : "content", "user" : "user", "assistant" : "assistant"},
|
859 |
+
map_eos_token = True,
|
860 |
+
system_message = None,
|
861 |
+
):
|
862 |
+
assert(type(map_eos_token) is bool)
|
863 |
+
old_tokenizer = tokenizer
|
864 |
+
|
865 |
+
IS_GEMMA = False
|
866 |
+
if tokenizer.__class__.__name__.startswith("Gemma"):
|
867 |
+
if chat_template == "chatml": chat_template = "gemma_chatml"
|
868 |
+
IS_GEMMA = True
|
869 |
+
pass
|
870 |
+
|
871 |
+
# We add a check for Llama-3
|
872 |
+
# if chat_template == "llama-3":
|
873 |
+
# tokenizer._using_llama3_template = True
|
874 |
+
# else:
|
875 |
+
# llama3_tokens = set(["<|end_header_id|>", "<|eot_id|>", "<|start_header_id|>"])
|
876 |
+
# check_llama3_tokens = llama3_tokens & set(str(x) for x in tokenizer.added_tokens_decoder.values())
|
877 |
+
# if len(check_llama3_tokens) == len(llama3_tokens):
|
878 |
+
# tokenizer._using_llama3_template = True
|
879 |
+
# pass
|
880 |
+
# pass
|
881 |
+
|
882 |
+
# We first check if the tokenizer is a fast one. If not, we cannot convert this!
|
883 |
+
is_fast_tokenizer = getattr(tokenizer, "is_fast", False)
|
884 |
+
old_padding_side = tokenizer.padding_side
|
885 |
+
|
886 |
+
same_padding_token = False
|
887 |
+
|
888 |
+
if type(chat_template) in (list, tuple,):
|
889 |
+
chat_template, stop_word = chat_template
|
890 |
+
assert(type(chat_template) is str)
|
891 |
+
assert(type(stop_word) is str)
|
892 |
+
ollama_modelfile = None
|
893 |
+
|
894 |
+
elif type(chat_template) is str:
|
895 |
+
|
896 |
+
chat_template, stop_word, yes_map_eos_token, ollama_modelfile = CHAT_TEMPLATES[chat_template]
|
897 |
+
|
898 |
+
# Check mapping to eos_token
|
899 |
+
if not map_eos_token and yes_map_eos_token: map_eos_token = True
|
900 |
+
if not yes_map_eos_token and map_eos_token: map_eos_token = False
|
901 |
+
|
902 |
+
if type(stop_word) in (list, tuple,):
|
903 |
+
token_mapping, stop_word = stop_word
|
904 |
+
assert(type(token_mapping) is dict)
|
905 |
+
else:
|
906 |
+
token_mapping = None
|
907 |
+
|
908 |
+
assert(type(stop_word) is str)
|
909 |
+
|
910 |
+
# Check fast tokenizer
|
911 |
+
if not is_fast_tokenizer:
|
912 |
+
print(
|
913 |
+
f"Unsloth: Not a fast tokenizer, so can't process it as of yet :(\n"\
|
914 |
+
"Please log a Github issue if you want this as a new feature!\n"\
|
915 |
+
"Your chat template will still work, but it won't add or edit tokens."
|
916 |
+
)
|
917 |
+
|
918 |
+
elif token_mapping is not None:
|
919 |
+
# token_mapping = {"<start_of_turn>" : "<|im_start|>", "<end_of_turn>" : "<|im_end|>"}
|
920 |
+
# For Gemma :)
|
921 |
+
|
922 |
+
string_vocab = tokenizer._tokenizer.to_str()
|
923 |
+
|
924 |
+
skipped = 0
|
925 |
+
for old_token, new_token in token_mapping.items():
|
926 |
+
old_count = string_vocab.count(f'"{old_token}"')
|
927 |
+
new_count = string_vocab.count(f'"{new_token}"')
|
928 |
+
if new_count != 0:
|
929 |
+
print(f"{new_token} is already a token. Skipping.")
|
930 |
+
skipped += 1
|
931 |
+
elif old_count == 0:
|
932 |
+
raise RuntimeError(f"{old_token} was not part of the tokenizer!")
|
933 |
+
else:
|
934 |
+
string_vocab = string_vocab.replace(f'"{old_token}"', f'"{new_token}"')
|
935 |
+
pass
|
936 |
+
pass
|
937 |
+
|
938 |
+
if map_eos_token and (not stop_word in token_mapping.values()):
|
939 |
+
# Do not map 107 = <|im_end|> and 1 = <|im_end|>. This will reduce the vocab size by 1
|
940 |
+
logger.warning_once(f"Unsloth: Will map {stop_word} to EOS = {tokenizer.eos_token}.")
|
941 |
+
string_vocab = string_vocab.replace(tokenizer.eos_token, stop_word)
|
942 |
+
pass
|
943 |
+
|
944 |
+
if skipped != len(token_mapping):
|
945 |
+
new_tokenizer = tokenizer._tokenizer.from_str(string_vocab)
|
946 |
+
|
947 |
+
# Careful on pad_token
|
948 |
+
old_pad_token = tokenizer.pad_token
|
949 |
+
if old_pad_token == tokenizer.eos_token:
|
950 |
+
old_pad_token = stop_word
|
951 |
+
same_padding_token = True
|
952 |
+
pass
|
953 |
+
|
954 |
+
if map_eos_token:
|
955 |
+
new_tokenizer = tokenizer.__class__(
|
956 |
+
tokenizer_object = new_tokenizer,
|
957 |
+
eos_token = stop_word,
|
958 |
+
pad_token = old_pad_token,
|
959 |
+
)
|
960 |
+
else:
|
961 |
+
new_tokenizer = tokenizer.__class__(
|
962 |
+
tokenizer_object = new_tokenizer,
|
963 |
+
pad_token = old_pad_token,
|
964 |
+
)
|
965 |
+
pass
|
966 |
+
|
967 |
+
# Must fix the sentence piece tokenizer since there's no tokenizer.model file!
|
968 |
+
tokenizer = fix_sentencepiece_tokenizer(tokenizer, new_tokenizer, token_mapping,)
|
969 |
+
else:
|
970 |
+
pass
|
971 |
+
|
972 |
+
elif map_eos_token and (stop_word != "eos_token"):
|
973 |
+
logger.warning_once(f"Unsloth: Will map {stop_word} to EOS = {tokenizer.eos_token}.")
|
974 |
+
|
975 |
+
# Replaces the old EOS token with a new one.
|
976 |
+
# Useful for ChatML <|im_end|> for example.
|
977 |
+
# Usually we train 2 more tokens <|im_start|> and <|im_end|>
|
978 |
+
# But training the lm_head and embeddings are slow!
|
979 |
+
# This is a HACK!
|
980 |
+
# Idea from https://huggingface.co/cognitivecomputations/dolphin-2.6-mistral-7b-dpo-laser
|
981 |
+
|
982 |
+
old_bos_token = getattr(tokenizer, "bos_token", None)
|
983 |
+
old_eos_token = getattr(tokenizer, "eos_token", None)
|
984 |
+
old_pad_token = getattr(tokenizer, "pad_token", None)
|
985 |
+
old_unk_token = getattr(tokenizer, "unk_token", None)
|
986 |
+
|
987 |
+
string_vocab = tokenizer._tokenizer.to_str()
|
988 |
+
# First check if new stop_word is in the tokenizer
|
989 |
+
if stop_word in string_vocab:
|
990 |
+
# We shall swap them around
|
991 |
+
temporary_stop_token = "<|:__TEMP//STOP//TOKEN__:|>"
|
992 |
+
string_vocab = string_vocab.replace(old_eos_token, temporary_stop_token)
|
993 |
+
string_vocab = string_vocab.replace(stop_word, old_eos_token)
|
994 |
+
string_vocab = string_vocab.replace(temporary_stop_token, stop_word)
|
995 |
+
else:
|
996 |
+
string_vocab = string_vocab.replace(old_eos_token, stop_word)
|
997 |
+
pass
|
998 |
+
new_tokenizer = tokenizer._tokenizer.from_str(string_vocab)
|
999 |
+
|
1000 |
+
# Careful on pad_token
|
1001 |
+
if old_pad_token == old_eos_token:
|
1002 |
+
old_pad_token = stop_word
|
1003 |
+
same_padding_token = True
|
1004 |
+
pass
|
1005 |
+
|
1006 |
+
new_tokenizer = tokenizer.__class__(
|
1007 |
+
tokenizer_object = new_tokenizer,
|
1008 |
+
bos_token = old_bos_token,
|
1009 |
+
eos_token = stop_word,
|
1010 |
+
unk_token = old_unk_token,
|
1011 |
+
pad_token = old_pad_token,
|
1012 |
+
)
|
1013 |
+
|
1014 |
+
# Must fix the sentence piece tokenizer since there's no tokenizer.model file!
|
1015 |
+
token_mapping = { old_eos_token : stop_word, }
|
1016 |
+
tokenizer = fix_sentencepiece_tokenizer(tokenizer, new_tokenizer, token_mapping,)
|
1017 |
+
pass
|
1018 |
+
|
1019 |
+
else:
|
1020 |
+
raise TypeError(
|
1021 |
+
f"Unsloth: `chat_template` must be a tuple of (your_template, eos_token,) or one of\n"\
|
1022 |
+
f"{CHAT_TEMPLATES.keys()}"
|
1023 |
+
)
|
1024 |
+
pass
|
1025 |
+
|
1026 |
+
# Careful on Gemma
|
1027 |
+
# bos_token is a must or else losses become too high
|
1028 |
+
if IS_GEMMA and not chat_template.startswith(("{{ bos_token }}", "{{- bos_token }}")):
|
1029 |
+
chat_template = "{{ bos_token }}" + chat_template
|
1030 |
+
pass
|
1031 |
+
|
1032 |
+
# For ShareGPT role -> from and content -> value
|
1033 |
+
new_chat_template = chat_template\
|
1034 |
+
.replace("'role'", "'" + mapping["role"] + "'")\
|
1035 |
+
.replace("'content'", "'" + mapping["content"] + "'")\
|
1036 |
+
.replace("'user'", "'" + mapping["user"] + "'")\
|
1037 |
+
.replace("'assistant'", "'" + mapping["assistant"] + "'")
|
1038 |
+
|
1039 |
+
_, tokenizer = patch_tokenizer(model = None, tokenizer = tokenizer)
|
1040 |
+
tokenizer.padding_side = old_padding_side
|
1041 |
+
|
1042 |
+
# If not normal HF, we add a check to make old templates work
|
1043 |
+
if mapping != {"role" : "role", "content" : "content", "user" : "user", "assistant" : "assistant"}:
|
1044 |
+
chat_template = \
|
1045 |
+
"{% if 'role' in messages[0] %}" + \
|
1046 |
+
chat_template + \
|
1047 |
+
"{% else %}" + \
|
1048 |
+
new_chat_template + \
|
1049 |
+
"{% endif %}"
|
1050 |
+
else:
|
1051 |
+
chat_template = new_chat_template
|
1052 |
+
pass
|
1053 |
+
tokenizer.chat_template = chat_template
|
1054 |
+
|
1055 |
+
# Also fix up other tokens
|
1056 |
+
old_pad_token = getattr(old_tokenizer, "pad_token", None)
|
1057 |
+
old_bos_token = getattr(old_tokenizer, "bos_token", None)
|
1058 |
+
old_unk_token = getattr(old_tokenizer, "unk_token", None)
|
1059 |
+
new_pad_token = getattr(tokenizer, "pad_token", None)
|
1060 |
+
new_bos_token = getattr(tokenizer, "bos_token", None)
|
1061 |
+
new_unk_token = getattr(tokenizer, "unk_token", None)
|
1062 |
+
if old_bos_token != new_bos_token: tokenizer.bos_token = old_bos_token
|
1063 |
+
if old_unk_token != new_unk_token: tokenizer.unk_token = old_unk_token
|
1064 |
+
if not same_padding_token:
|
1065 |
+
if old_pad_token != new_pad_token: tokenizer.pad_token = old_pad_token
|
1066 |
+
pass
|
1067 |
+
|
1068 |
+
# stopping_criteria = create_stopping_criteria(tokenizer, stop_word)
|
1069 |
+
|
1070 |
+
# Patch saving functions
|
1071 |
+
tokenizer = patch_saving_functions(tokenizer)
|
1072 |
+
|
1073 |
+
# Add Ollama
|
1074 |
+
tokenizer._ollama_modelfile = ollama_modelfile
|
1075 |
+
tokenizer._system_message = system_message
|
1076 |
+
return tokenizer#, stopping_criteria
|
1077 |
+
pass
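# A minimal usage sketch for get_chat_template() (the model name and the
# ShareGPT-style mapping are placeholders, not requirements):
#
# from unsloth import FastLanguageModel
# model, tokenizer = FastLanguageModel.from_pretrained("unsloth/llama-3-8b-Instruct")
# tokenizer = get_chat_template(
#     tokenizer,
#     chat_template = "llama-3",
#     mapping = {"role" : "from", "content" : "value", "user" : "human", "assistant" : "gpt"},
# )
# text = tokenizer.apply_chat_template(
#     [{"from" : "human", "value" : "Hello!"}],
#     tokenize = False,
#     add_generation_prompt = True,
# )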
|
1078 |
+
|
1079 |
+
|
1080 |
+
def remove_special_tokens(tokenizer, prompt):
|
1081 |
+
# Removes double BOS token
|
1082 |
+
if prompt.startswith(tokenizer.bos_token):
|
1083 |
+
prompt = prompt[len(tokenizer.bos_token):]
|
1084 |
+
pass
|
1085 |
+
return prompt
|
1086 |
+
pass
|
1087 |
+
|
1088 |
+
|
1089 |
+
def _parse_combined_prompt(combined_prompt, dataset):
|
1090 |
+
# Find {...}
|
1091 |
+
possible_columns = re.findall(r"\{(.+?)\}", combined_prompt)
|
1092 |
+
dataset_columns = set(dataset.column_names)
|
1093 |
+
for column in possible_columns:
|
1094 |
+
if column not in dataset_columns:
|
1095 |
+
raise KeyError(
|
1096 |
+
f"Unsloth: Your prompt includes '{column}' but this does not exist in the dataset. "\
|
1097 |
+
f"Only allowed columns are {list(dataset_columns)}"
|
1098 |
+
)
|
1099 |
+
pass
|
1100 |
+
pass
|
1101 |
+
|
1102 |
+
# Find [[...]]
|
1103 |
+
optional_prompts = list(re.finditer(r"\[\[.+?\]\]", combined_prompt, flags = re.DOTALL | re.MULTILINE))
|
1104 |
+
optional_prompts = [(x.span(), x.group(0)) for x in optional_prompts]
|
1105 |
+
|
1106 |
+
final_optional_prompts = []
|
1107 |
+
if len(optional_prompts) != 0:
|
1108 |
+
# Add left
|
1109 |
+
left = optional_prompts[0]
|
1110 |
+
l = left[0][0]
|
1111 |
+
if l != 0: final_optional_prompts.append(combined_prompt[:l])
|
1112 |
+
|
1113 |
+
# Add in between
|
1114 |
+
for left, right in zip(optional_prompts[:-1], optional_prompts[1:]):
|
1115 |
+
l, r = left[0][-1], right[0][0]
|
1116 |
+
final_optional_prompts.append(left)
|
1117 |
+
if l != r: final_optional_prompts.append(combined_prompt[l : r])
|
1118 |
+
pass
|
1119 |
+
final_optional_prompts.append(optional_prompts[-1])
|
1120 |
+
|
1121 |
+
# Add right
|
1122 |
+
right = optional_prompts[-1]
|
1123 |
+
r = right[0][1]
|
1124 |
+
if r != len(combined_prompt): final_optional_prompts.append(combined_prompt[r:])
|
1125 |
+
else:
|
1126 |
+
# Just add in the entire string
|
1127 |
+
final_optional_prompts.append(combined_prompt)
|
1128 |
+
pass
|
1129 |
+
|
1130 |
+
check_combined = "".join(x if type(x) is str else x[1] for x in final_optional_prompts)
|
1131 |
+
assert(combined_prompt == check_combined)
|
1132 |
+
|
1133 |
+
return possible_columns, final_optional_prompts
|
1134 |
+
pass
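# Illustrative example of what _parse_combined_prompt() returns (column names are
# made up): for a dataset with "instruction" and "input" columns,
#     merged_prompt = "{instruction}[[\nYour input is:\n{input}]]"
# yields possible_columns = ["instruction", "input"] and final_optional_prompts
# containing the plain string "{instruction}" followed by a (span, text) tuple for
# the optional "[[...]]" piece, so the optional part can later be skipped whenever
# its first column is empty.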
|
1135 |
+
|
1136 |
+
|
1137 |
+
def _create_formatter(possible_columns, final_optional_prompts, user_column_name):
|
1138 |
+
# Start final prompt!
|
1139 |
+
function = ["def __combined_prompt_processor__(examples):"]
|
1140 |
+
columns = list(set(possible_columns))
|
1141 |
+
for column in columns:
|
1142 |
+
function.append(f"{' '*4}{column}__ = examples['{column}']")
|
1143 |
+
function.append(f"{' '*4}texts = []")
|
1144 |
+
function.append(f"{' '*4}for ({', '.join(columns)}) in zip({', '.join(f'{x}__' for x in columns)}):")
|
1145 |
+
|
1146 |
+
# Add optional tags as well!
|
1147 |
+
final_prompt = ""
|
1148 |
+
formatter = []
|
1149 |
+
|
1150 |
+
for j, optional_prompt in enumerate(final_optional_prompts):
|
1151 |
+
if type(optional_prompt) is str:
|
1152 |
+
columns = re.findall(r"\{(.+?)\}", optional_prompt)
|
1153 |
+
formatter += columns
|
1154 |
+
# Must escape \n \r
|
1155 |
+
final_prompt += optional_prompt.encode("unicode-escape").decode("utf-8").replace("'", "\\'").replace('"', '\\"')
|
1156 |
+
else:
|
1157 |
+
where, prompt = optional_prompt
|
1158 |
+
# Strip [[...]]
|
1159 |
+
# Must escape \n \r
|
1160 |
+
prompt = prompt[2:-2].encode("unicode-escape").decode("utf-8").replace("'", "\\'").replace('"', '\\"')
|
1161 |
+
columns = re.findall(r"\{(.+?)\}", prompt)
|
1162 |
+
x = f"__optional_{j}__"
|
1163 |
+
prompt = f"{' '*8}{x} = '{prompt}'.format({', '.join(f'{x} = {x}' for x in columns)}) if {columns[0]} else ''"
|
1164 |
+
function.append(prompt)
|
1165 |
+
formatter.append(x)
|
1166 |
+
final_prompt += "{" + x + "}"
|
1167 |
+
pass
|
1168 |
+
pass
|
1169 |
+
|
1170 |
+
function.insert(1, f"{' '*4}__combined_prompt__ = '{final_prompt}'")
|
1171 |
+
function.append(f"{' '*8}texts.append("\
|
1172 |
+
f"__combined_prompt__.format({', '.join(f'{x} = {x}' for x in formatter)}))")
|
1173 |
+
function.append(f"{' '*4}return " + "{ " + f"'{user_column_name}' : texts" + " }")
|
1174 |
+
return "\n".join(function)
|
1175 |
+
pass
|
1176 |
+
|
1177 |
+
|
1178 |
+
def to_sharegpt(
|
1179 |
+
dataset,
|
1180 |
+
merged_prompt = "",
|
1181 |
+
merged_column_name = "instruction",
|
1182 |
+
output_column_name = "output",
|
1183 |
+
remove_unused_columns = True,
|
1184 |
+
conversation_extension = 1,
|
1185 |
+
random_state = 3407,
|
1186 |
+
):
|
1187 |
+
"""
|
1188 |
+
Converts a dataset to ShareGPT style.
|
1189 |
+
ShareGPT requires only 1 input and 1 output field.
|
1190 |
+
This means multiple dataset columns must be merged into a single input column.
|
1191 |
+
Use `conversation_extension` to increase the length of each conversation by randomly
|
1192 |
+
selecting a few and packing them into 1.
|
1193 |
+
|
1194 |
+
merged_prompt = "", Prompt to merge columns into 1 input
|
1195 |
+
merged_column_name = "instruction", Final column name for the input field
|
1196 |
+
output_column_name = "output", Final column name for the output field
|
1197 |
+
remove_unused_columns = True,
|
1198 |
+
conversation_extension = 1, Automatically combines `conversation_extension` convos into 1
|
1199 |
+
random_state = 3407,
|
1200 |
+
"""
|
1201 |
+
if "conversations" in dataset.column_names:
|
1202 |
+
convo = dataset[0]["conversations"]
|
1203 |
+
if type(convo) is list:
|
1204 |
+
raise TypeError("Unsloth: Your dataset is probably already in ShareGPT format!")
|
1205 |
+
pass
|
1206 |
+
pass
|
1207 |
+
|
1208 |
+
possible_columns, final_optional_prompts = _parse_combined_prompt(merged_prompt, dataset)
|
1209 |
+
function = _create_formatter(possible_columns, final_optional_prompts, merged_column_name)
|
1210 |
+
exec(function, globals())
|
1211 |
+
dataset = dataset.map(__combined_prompt_processor__, batched = True, desc = "Merging columns")
|
1212 |
+
|
1213 |
+
def __convert_to_sharegpt__(examples):
|
1214 |
+
users = examples[merged_column_name]
|
1215 |
+
assistants = examples[output_column_name]
|
1216 |
+
texts = [
|
1217 |
+
[
|
1218 |
+
{"from" : "human", "value" : str(user) },
|
1219 |
+
{"from" : "gpt", "value" : str(assistant)},
|
1220 |
+
] \
|
1221 |
+
for user, assistant in zip(users, assistants)
|
1222 |
+
]
|
1223 |
+
return { "conversations" : texts, }
|
1224 |
+
pass
|
1225 |
+
|
1226 |
+
dataset = dataset.map(
|
1227 |
+
__convert_to_sharegpt__,
|
1228 |
+
batched = True,
|
1229 |
+
desc = "Converting to ShareGPT",
|
1230 |
+
# Remove unused columns!
|
1231 |
+
remove_columns = dataset.column_names if remove_unused_columns else None,
|
1232 |
+
)
|
1233 |
+
|
1234 |
+
# Randomly concatenate conversations to create a long stream!
|
1235 |
+
from datasets import concatenate_datasets
|
1236 |
+
n_extensions = max(conversation_extension-1, 0)
|
1237 |
+
if n_extensions == 0: return dataset
|
1238 |
+
|
1239 |
+
dataset = dataset.rename_columns({"conversations" : f"conversations0"})
|
1240 |
+
all_shuffled = [dataset]
|
1241 |
+
for j in range(1, n_extensions+1):
|
1242 |
+
shuffled = dataset.shuffle(seed = random_state+j).rename_columns({"conversations0" : f"conversations{j}"})
|
1243 |
+
all_shuffled.append(shuffled)
|
1244 |
+
pass
|
1245 |
+
dataset = concatenate_datasets(all_shuffled, axis = 1)
|
1246 |
+
|
1247 |
+
# Combine them into 1
|
1248 |
+
function = "def __combine_conversations__(examples):\n"
|
1249 |
+
n_extensions += 1
|
1250 |
+
for j in range(n_extensions):
|
1251 |
+
function += f"{' '*4}conversations{j}__ = examples['conversations{j}']\n"
|
1252 |
+
function += f"{' '*4}convos = []\n"
|
1253 |
+
function += f"{' '*4}for ({', '.join(f'conversations{j}' for j in range(n_extensions))}) "\
|
1254 |
+
f"in zip({', '.join(f'conversations{j}__' for j in range(n_extensions))}):\n"
|
1255 |
+
function += f"{' '*8}convos.append("\
|
1256 |
+
f"{'+'.join(f'conversations{j}' for j in range(n_extensions))})\n"
|
1257 |
+
function += f"{' '*4}return " + "{ " + f"'conversations' : convos" + " }"
|
1258 |
+
|
1259 |
+
# Map function
|
1260 |
+
exec(function, globals())
|
1261 |
+
dataset = dataset.map(
|
1262 |
+
__combine_conversations__,
|
1263 |
+
batched = True,
|
1264 |
+
desc = "Extending conversations",
|
1265 |
+
# Remove unused columns!
|
1266 |
+
remove_columns = dataset.column_names if remove_unused_columns else None,
|
1267 |
+
)
|
1268 |
+
return dataset
|
1269 |
+
pass
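# A minimal to_sharegpt() usage sketch (the dataset and its columns are placeholders):
#
# from datasets import load_dataset
# dataset = load_dataset("vicgalle/alpaca-gpt4", split = "train")
# dataset = to_sharegpt(
#     dataset,
#     merged_prompt = "{instruction}[[\nYour input is:\n{input}]]",
#     output_column_name = "output",
#     conversation_extension = 3,  # randomly packs 3 single-turn rows into one conversation
# )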
|
1270 |
+
|
1271 |
+
|
1272 |
+
def standardize_sharegpt(
|
1273 |
+
dataset,
|
1274 |
+
aliases_for_system = ["system",],
|
1275 |
+
aliases_for_user = ["user", "human", "input",],
|
1276 |
+
aliases_for_assistant = ["gpt", "assistant", "output",],
|
1277 |
+
):
|
1278 |
+
"""
|
1279 |
+
Standardizes ShareGPT and other formats to user/assistant Hugging Face format.
|
1280 |
+
|
1281 |
+
Get aliases for the system, user and assistant roles.
|
1282 |
+
These shall map to "system", "user" and "assistant" respectively.
|
1283 |
+
|
1284 |
+
aliases_for_system = ["system",],
|
1285 |
+
aliases_for_user = ["user", "human", "input",],
|
1286 |
+
aliases_for_assistant = ["gpt", "assistant", "output",],
|
1287 |
+
"""
|
1288 |
+
import collections
|
1289 |
+
import itertools
|
1290 |
+
|
1291 |
+
convos = dataset[:10]["conversations"]
|
1292 |
+
uniques = collections.defaultdict(list)
|
1293 |
+
for convo in convos:
|
1294 |
+
for message in convo:
|
1295 |
+
for key, value in message.items():
|
1296 |
+
uniques[key].append(value)
|
1297 |
+
pass
|
1298 |
+
|
1299 |
+
# Must be only 2 entries
|
1300 |
+
assert(len(uniques.keys()) == 2)
|
1301 |
+
|
1302 |
+
keys = list(uniques.keys())
|
1303 |
+
length_first = len(set(uniques[keys[0]]))
|
1304 |
+
length_second = len(set(uniques[keys[1]]))
|
1305 |
+
|
1306 |
+
if length_first < length_second:
|
1307 |
+
# Role is assigned to the first element
|
1308 |
+
role_key = keys[0]
|
1309 |
+
content_key = keys[1]
|
1310 |
+
else:
|
1311 |
+
role_key = keys[1]
|
1312 |
+
content_key = keys[0]
|
1313 |
+
pass
|
1314 |
+
|
1315 |
+
# Check roles are in aliases
|
1316 |
+
all_aliases = set(aliases_for_system + aliases_for_user + aliases_for_assistant)
|
1317 |
+
roles = set(uniques[role_key])
|
1318 |
+
leftover_aliases = (all_aliases | roles) - all_aliases
|
1319 |
+
if len(leftover_aliases) != 0:
|
1320 |
+
raise TypeError(
|
1321 |
+
f"Unsloth: {list(leftover_aliases)} are not in aliases. Please update aliases."
|
1322 |
+
)
|
1323 |
+
pass
|
1324 |
+
|
1325 |
+
# Mapping for aliases
|
1326 |
+
aliases_mapping = {}
|
1327 |
+
for x in aliases_for_system: aliases_mapping[x] = "system"
|
1328 |
+
for x in aliases_for_user: aliases_mapping[x] = "user"
|
1329 |
+
for x in aliases_for_assistant: aliases_mapping[x] = "assistant"
|
1330 |
+
|
1331 |
+
def _standardize_dataset(examples):
|
1332 |
+
convos = examples["conversations"]
|
1333 |
+
all_convos = []
|
1334 |
+
for convo in convos:
|
1335 |
+
new_convo = [
|
1336 |
+
{ "role" : aliases_mapping[message[role_key]], "content" : message[content_key], }
|
1337 |
+
for message in convo
|
1338 |
+
]
|
1339 |
+
all_convos.append(new_convo)
|
1340 |
+
pass
|
1341 |
+
return { "conversations" : all_convos, }
|
1342 |
+
pass
|
1343 |
+
|
1344 |
+
return dataset.map(_standardize_dataset, batched = True, desc = "Standardizing format")
|
1345 |
+
pass
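# standardize_sharegpt() is typically chained right after to_sharegpt(), e.g.:
#
# dataset = standardize_sharegpt(dataset)
#
# After this, {"from": "human", "value": ...} / {"from": "gpt", "value": ...} messages
# become {"role": "user", "content": ...} / {"role": "assistant", "content": ...}.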
|
1346 |
+
|
1347 |
+
|
1348 |
+
def get_ollama_eos_tokens(tokenizer, extra_eos_tokens = []):
|
1349 |
+
added_tokens_decoder = tokenizer.added_tokens_decoder.values()
|
1350 |
+
added_tokens_decoder = [str(x) for x in added_tokens_decoder]
|
1351 |
+
|
1352 |
+
# Remove added_tokens_decoder duplicates
|
1353 |
+
added_tokens_decoder = list(set(added_tokens_decoder) - set(extra_eos_tokens))
|
1354 |
+
|
1355 |
+
# Remove BOS
|
1356 |
+
if getattr(tokenizer, "bos_token", None) is not None:
|
1357 |
+
added_tokens_decoder = [x for x in added_tokens_decoder if x != tokenizer.bos_token]
|
1358 |
+
pass
|
1359 |
+
|
1360 |
+
repeatted_tokens = []
|
1361 |
+
# Join all vocab
|
1362 |
+
joined_text = "\x01\x00".join(added_tokens_decoder)
|
1363 |
+
for token in added_tokens_decoder:
|
1364 |
+
n = len(token)
|
1365 |
+
repeatted_counts = joined_text.count(token[:n//2])
|
1366 |
+
# Try finding longer than 1/2 of the token in the rest
|
1367 |
+
# For eg <|reserved_special_token_0|>, <|reserved_special_token_1|>
|
1368 |
+
if repeatted_counts > 2:
|
1369 |
+
for j in range(n//2+1, n):
|
1370 |
+
if joined_text.count(token[:j]) < repeatted_counts:
|
1371 |
+
j -= 1
|
1372 |
+
# Remove repeated tokens to reduce the search space
|
1373 |
+
joined_text = joined_text.replace(token[:j], "")
|
1374 |
+
repeatted_tokens.append(token[:j])
|
1375 |
+
break
|
1376 |
+
pass
|
1377 |
+
pass
|
1378 |
+
pass
|
1379 |
+
|
1380 |
+
# Remove duplicates
|
1381 |
+
splitted = joined_text.split("\x01\x00")
|
1382 |
+
final_eos_tokens = []
|
1383 |
+
for old, new in zip(added_tokens_decoder, splitted):
|
1384 |
+
if old == new: final_eos_tokens.append(old)
|
1385 |
+
pass
|
1386 |
+
final_eos_tokens += extra_eos_tokens
|
1387 |
+
final_eos_tokens += repeatted_tokens
|
1388 |
+
|
1389 |
+
# Remove new lines, spaces and HTML tags
|
1390 |
+
filtered_eos_tokens = []
|
1391 |
+
for token in final_eos_tokens:
|
1392 |
+
if token.count("\n") == len(token): continue
|
1393 |
+
elif token.count("▁") == len(token): continue
|
1394 |
+
elif token.startswith("<") and len(token) <= 2: continue
|
1395 |
+
elif token.startswith("</") and len(token) == 3: continue
|
1396 |
+
filtered_eos_tokens.append(token)
|
1397 |
+
pass
|
1398 |
+
return filtered_eos_tokens
|
1399 |
+
pass
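# Sketch of the intent of get_ollama_eos_tokens(): given added tokens such as
# <|eot_id|>, <|end_of_text|> plus a long run of <|reserved_special_token_N|>
# entries, keep the distinct stop tokens but collapse the repeated reserved family
# into one shared prefix, and drop whitespace-only or HTML-like tokens, so the
# generated Modelfile does not end up with hundreds of PARAMETER stop lines.
# (The token names above are the usual Llama-3 specials, used purely as an example.)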
|
1400 |
+
|
1401 |
+
|
1402 |
+
def construct_chat_template( \
|
1403 |
+
|
1404 |
+
tokenizer = None,
|
1405 |
+
|
1406 |
+
chat_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
1407 |
+
|
1408 |
+
{SYSTEM}<|eot_id|><|start_header_id|>user<|end_header_id|>
|
1409 |
+
|
1410 |
+
{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
1411 |
+
|
1412 |
+
{OUTPUT}<|eot_id|><|start_header_id|>user<|end_header_id|>
|
1413 |
+
|
1414 |
+
{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
1415 |
+
|
1416 |
+
{OUTPUT}<|eot_id|>""",
|
1417 |
+
|
1418 |
+
default_system_message = \
|
1419 |
+
"Below are some instructions that describe some tasks. Write responses that appropriately complete each request.",
|
1420 |
+
|
1421 |
+
extra_eos_tokens = None,
|
1422 |
+
):
|
1423 |
+
"""
|
1424 |
+
Creates an Ollama modelfile and a HF Jinja template from a custom
|
1425 |
+
template. You must provide 2x examples of an input & output.
|
1426 |
+
There is an optional system message as well.
|
1427 |
+
|
1428 |
+
You must use {INPUT}, {OUTPUT} twice, and {SYSTEM} is optional.
|
1429 |
+
"""
|
1430 |
+
# Strip only the left
|
1431 |
+
chat_template = chat_template.lstrip()
|
1432 |
+
|
1433 |
+
assert(tokenizer is not None)
|
1434 |
+
|
1435 |
+
if extra_eos_tokens is None: extra_eos_tokens = []
|
1436 |
+
elif type(extra_eos_tokens) is str: extra_eos_tokens = [extra_eos_tokens,]
|
1437 |
+
|
1438 |
+
vocab = tokenizer.get_vocab()
|
1439 |
+
for extra_eos in extra_eos_tokens:
|
1440 |
+
assert(type(extra_eos) is str)
|
1441 |
+
if extra_eos not in vocab:
|
1442 |
+
raise ValueError(f"Unsloth: `{extra_eos}` is not a singular token in the tokenizer.")
|
1443 |
+
pass
|
1444 |
+
pass
|
1445 |
+
|
1446 |
+
error_msg = \
|
1447 |
+
"Unsloth: Your prompt template must have 2 examples showing the user input {INPUT} "\
|
1448 |
+
"and the assistant output {OUTPUT}\n\n"\
|
1449 |
+
"For example what is not allowed is just:\n"\
|
1450 |
+
"### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n\n\n"\
|
1451 |
+
"What is required is 2x of this:\n"\
|
1452 |
+
"### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n"\
|
1453 |
+
"### Input:\\n{INPUT}\\n\\n### Response:\\n{OUTPUT}\\n"
|
1454 |
+
|
1455 |
+
# Check for EOS after {OUTPUT}
|
1456 |
+
if tokenizer.eos_token is not None:
|
1457 |
+
extra_eos_tokens.insert(0, tokenizer.eos_token)
|
1458 |
+
if len(extra_eos_tokens) == 0:
|
1459 |
+
raise RuntimeError(
|
1460 |
+
"Unsloth: Your tokenizer does not have an EOS token? Please provide one via extra_eos_tokens!"
|
1461 |
+
)
|
1462 |
+
pass
|
1463 |
+
|
1464 |
+
# Check tokenizer types
|
1465 |
+
tokenizer_name = tokenizer.name_or_path.lower()
|
1466 |
+
if tokenizer_name.startswith(("unsloth/llama-3-8b-instruct", "unsloth/llama-3-70b-instruct")):
|
1467 |
+
# Add <|eot_id|>
|
1468 |
+
extra_eos_tokens.append("<|eot_id|>")
|
1469 |
+
elif ("<|eot_id|>" in extra_eos_tokens or "<|eot_id|>" in chat_template) and \
|
1470 |
+
tokenizer_name.startswith(("unsloth/llama-3-8b", "unsloth/llama-3-70b")):
|
1471 |
+
# Warn
|
1472 |
+
logger.warning(
|
1473 |
+
"Unsloth: Base llama-3 models did not train <|eot_id|>.\n"\
|
1474 |
+
"Please use the instruct version or use <|end_of_text|>"
|
1475 |
+
)
|
1476 |
+
pass
|
1477 |
+
extra_eos_tokens = list(set(extra_eos_tokens))
|
1478 |
+
|
1479 |
+
count_eos = 0
|
1480 |
+
for eos in extra_eos_tokens:
|
1481 |
+
count_eos += len(re.findall(r"{OUTPUT}" + re.escape(eos), chat_template))
|
1482 |
+
pass
|
1483 |
+
|
1484 |
+
# This forces you to provide 2 input and outputs
|
1485 |
+
final_combined_check = False
|
1486 |
+
|
1487 |
+
try:
|
1488 |
+
# O(N^2) search finding 2 repeated pieces of text
|
1489 |
+
j = len(chat_template)-1
|
1490 |
+
at_least_one = False
|
1491 |
+
while j > 0:
|
1492 |
+
found = chat_template.rfind(chat_template[j:], 0, j)
|
1493 |
+
if found == -1: break
|
1494 |
+
j -= 1
|
1495 |
+
at_least_one = True
|
1496 |
+
pass
|
1497 |
+
if j > 0: j += 1
|
1498 |
+
else: raise RuntimeError(error_msg)
|
1499 |
+
|
1500 |
+
if not at_least_one: raise RuntimeError(error_msg)
|
1501 |
+
|
1502 |
+
# Must be equivalent to left
|
1503 |
+
final_combined_check = True
|
1504 |
+
|
1505 |
+
# Repeated text
|
1506 |
+
instruction_response = chat_template[j:]
|
1507 |
+
if instruction_response.count("{INPUT}") != 1 or instruction_response.count("{OUTPUT}") != 1:
|
1508 |
+
raise RuntimeError(error_msg)
|
1509 |
+
pass
|
1510 |
+
|
1511 |
+
# 1st System, Instruction, Output pair
|
1512 |
+
left = chat_template[:j]
|
1513 |
+
# 2nd Instruction, Output pair
|
1514 |
+
right = chat_template[j:]
|
1515 |
+
|
1516 |
+
final_combined_check = left if final_combined_check else chat_template
|
1517 |
+
|
1518 |
+
# Isolate input
|
1519 |
+
extra_eos_tokens_regex = "|".join(f"(?:{re.escape(x)})" for x in extra_eos_tokens)
|
1520 |
+
if len(extra_eos_tokens_regex) != 0:
|
1521 |
+
find_end = f"(?:{extra_eos_tokens_regex})?"
|
1522 |
+
else:
|
1523 |
+
find_end = ""
|
1524 |
+
find_end = r"\{INPUT\}[\s\n]{0,}" + find_end
|
1525 |
+
input_end = list(re.finditer(find_end, right))
|
1526 |
+
assert(len(input_end) == 1)
|
1527 |
+
input_end = input_end[0]
|
1528 |
+
input_end = input_end.span(0)[1]
|
1529 |
+
input_part = right[:input_end]
|
1530 |
+
|
1531 |
+
# Isolate output
|
1532 |
+
output_part = right[input_end:]
|
1533 |
+
|
1534 |
+
# Isolate system
|
1535 |
+
where_system = left.find(input_part)
|
1536 |
+
system_part = left[:where_system if where_system != -1 else len(left)]
|
1537 |
+
|
1538 |
+
# Check if the user provided a correct prompt
|
1539 |
+
combined = system_part + input_part + output_part
|
1540 |
+
if combined != final_combined_check:
|
1541 |
+
combined_changed = combined .replace('\n', '\\n')
|
1542 |
+
left_changed = final_combined_check.replace('\n', '\\n')
|
1543 |
+
raise RuntimeError(
|
1544 |
+
"Unsloth: The prompt template you provided isn't correct. You gave:\n"\
|
1545 |
+
f"{combined_changed}\n\n"\
|
1546 |
+
"But we require the following:\n"\
|
1547 |
+
f"{left_changed}"
|
1548 |
+
)
|
1549 |
+
pass
|
1550 |
+
except:
|
1551 |
+
ending = chat_template[chat_template.find("{OUTPUT}") + len("{OUTPUT}"):]
|
1552 |
+
|
1553 |
+
ending = re.escape(ending)
|
1554 |
+
find_text = "{INPUT}" + ending + "(.+?{OUTPUT}" + ending + ")"
|
1555 |
+
response_part = re.findall(find_text, chat_template, flags = re.DOTALL | re.MULTILINE)
|
1556 |
+
response_part = response_part[0]
|
1557 |
+
|
1558 |
+
for j in range(1, len(response_part)):
|
1559 |
+
try_find = re.escape(response_part[:j])
|
1560 |
+
try: found = next(re.finditer("(" + try_find + ").+?\{INPUT\}", chat_template, flags = re.DOTALL | re.MULTILINE))
|
1561 |
+
except: break
|
1562 |
+
pass
|
1563 |
+
separator = found.group(1)
|
1564 |
+
|
1565 |
+
response_start = chat_template.find(response_part)
|
1566 |
+
start_instruction = chat_template[:response_start].rfind(separator)
|
1567 |
+
if start_instruction == -1: start_instruction = 0
|
1568 |
+
instruction_part = chat_template[start_instruction:response_start]
|
1569 |
+
|
1570 |
+
combined = instruction_part + response_part
|
1571 |
+
where = chat_template.find(combined)
|
1572 |
+
system_part = chat_template[:where]
|
1573 |
+
|
1574 |
+
system_part, input_part, output_part = system_part, instruction_part, response_part
|
1575 |
+
pass
|
1576 |
+
|
1577 |
+
if count_eos == 0:
|
1578 |
+
logger.warning("Unsloth: We automatically added an EOS token to stop endless generations.")
|
1579 |
+
eos = extra_eos_tokens[0]
|
1580 |
+
output_part = output_part + eos
|
1581 |
+
pass
|
1582 |
+
|
1583 |
+
# Ollama modelfile parts
|
1584 |
+
|
1585 |
+
# Check bos_token is in system prompt
|
1586 |
+
ollama_system = system_part
|
1587 |
+
has_bos_token = False
|
1588 |
+
always_bos_token = False
|
1589 |
+
if tokenizer("A").input_ids[0] == getattr(tokenizer, "bos_token_id", None):
|
1590 |
+
always_bos_token = True
|
1591 |
+
if ollama_system.startswith(tokenizer.bos_token):
|
1592 |
+
has_bos_token = True
|
1593 |
+
ollama_system = ollama_system[len(tokenizer.bos_token):]
|
1594 |
+
pass
|
1595 |
+
pass
|
1596 |
+
# Check system
|
1597 |
+
if "{SYSTEM}" in ollama_system:
|
1598 |
+
system_modelfile = "{{ if .System }}" + ollama_system.replace("{SYSTEM}", "{{ .System }}") + "{{ end }}"
|
1599 |
+
else:
|
1600 |
+
system_modelfile = ollama_system
|
1601 |
+
pass
|
1602 |
+
input_modelfile = "{{ if .Prompt }}" + input_part .replace("{INPUT}", "{{ .Prompt }}") + "{{ end }}"
|
1603 |
+
output_modelfile = output_part.replace("{OUTPUT}", "{{ .Response }}")
|
1604 |
+
|
1605 |
+
# Ollama EOS
|
1606 |
+
ollama_eos = get_ollama_eos_tokens(tokenizer, extra_eos_tokens)
|
1607 |
+
ollama_eos = '\n'.join(f'PARAMETER stop "{eos}"' for eos in ollama_eos)
|
1608 |
+
|
1609 |
+
# Add temperature and min_p to counteract gibberish
|
1610 |
+
ollama_eos += "\nPARAMETER temperature 1.5\nPARAMETER min_p 0.1"
|
1611 |
+
|
1612 |
+
# Ollama modelfile
|
1613 |
+
part = '"""'
|
1614 |
+
modelfile = 'FROM {__FILE_LOCATION__}\n\n'\
|
1615 |
+
'TEMPLATE ' + part + system_modelfile + input_modelfile + output_modelfile + \
|
1616 |
+
part + '\n\n' + ollama_eos
|
1617 |
+
|
1618 |
+
# HF Jinja Chat template
|
1619 |
+
def process(part, which, content = "message['content']"):
|
1620 |
+
if part.endswith(which):
|
1621 |
+
part = "'" + part[:part.find(which)] + f"' + {content}"
|
1622 |
+
elif part.startswith(which):
|
1623 |
+
part = f"{content} + '" + part[part.find(which):] + "'"
|
1624 |
+
else:
|
1625 |
+
part = "'" + part.replace(which, f"' + {content} + '") + "'"
|
1626 |
+
if part.startswith("'' + "): part = part[5:]
|
1627 |
+
return part
|
1628 |
+
pass
|
1629 |
+
input_jinja = process(input_part, "{INPUT}")
|
1630 |
+
output_jinja = process(output_part, "{OUTPUT}")
|
1631 |
+
pass
|
1632 |
+
|
1633 |
+
jinja_template = \
|
1634 |
+
"{% for message in loop_messages %}"\
|
1635 |
+
"{% if message['role'] == 'user' %}"\
|
1636 |
+
"{{ " + input_jinja + " }}"\
|
1637 |
+
"{% elif message['role'] == 'assistant' %}"\
|
1638 |
+
"{{ " + output_jinja + " }}"\
|
1639 |
+
"{% else %}"\
|
1640 |
+
"{{ raise_exception('Only user and assistant roles are supported!') }}"\
|
1641 |
+
"{% endif %}"\
|
1642 |
+
"{% endfor %}"\
|
1643 |
+
"{% if add_generation_prompt %}"\
|
1644 |
+
"{{ '" + output_part[:output_part.find("{OUTPUT}")] + "' }}"\
|
1645 |
+
"{% endif %}"
|
1646 |
+
pass
|
1647 |
+
|
1648 |
+
# Now add system prompt to jinja
|
1649 |
+
if len(system_part) != 0:
|
1650 |
+
partial_system = process(system_part, "{SYSTEM}", "messages[0]['content']")
|
1651 |
+
partial_system = partial_system.replace("{SYSTEM}", "")
|
1652 |
+
|
1653 |
+
if "{SYSTEM}" in partial_system:
|
1654 |
+
if default_system_message is None:
|
1655 |
+
raise RuntimeError("Unsloth: Please specify a default system message!")
|
1656 |
+
pass
|
1657 |
+
|
1658 |
+
# Separate the BOS
|
1659 |
+
if has_bos_token:
|
1660 |
+
partial_system = partial_system.replace(tokenizer.bos_token, "", 1)
|
1661 |
+
system_part = system_part .replace(tokenizer.bos_token, "", 1)
|
1662 |
+
pass
|
1663 |
+
|
1664 |
+
partial_system = \
|
1665 |
+
"{% if messages[0]['role'] == 'system' %}"\
|
1666 |
+
"{{ " + partial_system + " }}"\
|
1667 |
+
"{% set loop_messages = messages[1:] %}"
|
1668 |
+
if default_system_message is not None:
|
1669 |
+
full_system = system_part.replace("{SYSTEM}", default_system_message)
|
1670 |
+
if "{SYSTEM}" in system_part:
|
1671 |
+
modelfile += '\nSYSTEM "' + default_system_message + '"'
|
1672 |
+
pass
|
1673 |
+
partial_system += "{% else %}"\
|
1674 |
+
"{{ '" + full_system + "' }}"\
|
1675 |
+
"{% set loop_messages = messages %}"\
|
1676 |
+
"{% endif %}"
|
1677 |
+
else:
|
1678 |
+
partial_system += "{% endif %}"
|
1679 |
+
pass
|
1680 |
+
|
1681 |
+
jinja_template = partial_system + jinja_template
|
1682 |
+
|
1683 |
+
if has_bos_token:
|
1684 |
+
jinja_template = "{{ bos_token }}" + jinja_template
|
1685 |
+
pass
|
1686 |
+
|
1687 |
+
# Fix missing loop_messages
|
1688 |
+
if "{% set loop_messages = messages %}" not in jinja_template:
|
1689 |
+
jinja_template = jinja_template.replace(
|
1690 |
+
"{% for message in loop_messages %}",
|
1691 |
+
"{% for message in messages %}",
|
1692 |
+
1, # Only replace the first one
|
1693 |
+
)
|
1694 |
+
pass
|
1695 |
+
|
1696 |
+
# Check if system part is the same!
|
1697 |
+
jinja_template = re.sub(
|
1698 |
+
r"\{\% if messages\[0\]\['role'\] \=\= 'system' \%\}\{\{ '(.+?)' \}\}"\
|
1699 |
+
r"\{\% set loop\_messages \= messages\[1\:\] \%\}"\
|
1700 |
+
r"\{\% else \%\}\{\{ '\1' \}\}\{\% set loop\_messages \= messages \%\}\{\% endif \%\}"\
|
1701 |
+
r"\{\% for message in loop\_messages \%\}",
|
1702 |
+
r"{{ '\1' }}{% for message in messages %}",
|
1703 |
+
jinja_template, flags = re.MULTILINE | re.DOTALL,
|
1704 |
+
)
|
1705 |
+
|
1706 |
+
# Check jinja template for bos
|
1707 |
+
if always_bos_token:
|
1708 |
+
if not jinja_template.startswith(("{{ bos_token }}", "{{- bos_token }}")):
|
1709 |
+
jinja_template = "{{ bos_token }}" + jinja_template
|
1710 |
+
pass
|
1711 |
+
|
1712 |
+
# Get instruction and output parts for train_on_inputs = False
|
1713 |
+
input_part = input_part [:input_part .find("{INPUT}")]
|
1714 |
+
output_part = output_part[:output_part.find("{OUTPUT}")]
|
1715 |
+
return modelfile, jinja_template, input_part, output_part
|
1716 |
+
pass
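For orientation, the Modelfile assembled above has roughly the following shape. This is an illustrative sketch only: the `<...>` markers stand for whatever prefixes, suffixes and EOS strings the supplied chat template defines, not real output of construct_chat_template.

# Sketch of the generated Ollama Modelfile (placeholder markers, illustrative only)
example_modelfile = (
    "FROM {__FILE_LOCATION__}\n\n"
    'TEMPLATE """'
    "{{ if .System }}<system prefix>{{ .System }}<system suffix>{{ end }}"
    "{{ if .Prompt }}<user prefix>{{ .Prompt }}<user suffix>{{ end }}"
    "<assistant prefix>{{ .Response }}<eos>"
    '"""\n\n'
    'PARAMETER stop "<eos>"\n'
    "PARAMETER temperature 1.5\n"
    "PARAMETER min_p 0.1"
)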
|
1717 |
+
|
1718 |
+
|
1719 |
+
def test_construct_chat_template():
|
1720 |
+
token = "hf_"
|
1721 |
+
from transformers import AutoTokenizer
|
1722 |
+
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", token = token)
|
1723 |
+
|
1724 |
+
chat_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
1725 |
+
|
1726 |
+
{SYSTEM}<|eot_id|><|start_header_id|>user<|end_header_id|>
|
1727 |
+
|
1728 |
+
{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
1729 |
+
|
1730 |
+
{OUTPUT}<|eot_id|><|start_header_id|>user<|end_header_id|>
|
1731 |
+
|
1732 |
+
{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
1733 |
+
|
1734 |
+
{OUTPUT}<|eot_id|>"""
|
1735 |
+
|
1736 |
+
default_system_message = \
|
1737 |
+
"Below are some instructions that describe some tasks. Write responses that appropriately complete each request."
|
1738 |
+
|
1739 |
+
extra_eos_tokens = None
|
1740 |
+
|
1741 |
+
modelfile, jinja_template, _, _ = construct_chat_template(
|
1742 |
+
tokenizer = tokenizer,
|
1743 |
+
chat_template = chat_template,
|
1744 |
+
extra_eos_tokens = extra_eos_tokens,
|
1745 |
+
)
|
1746 |
+
|
1747 |
+
messages = [
|
1748 |
+
{"role": "system", "content": "You are an assistant"},
|
1749 |
+
{"role": "user", "content": "What is 2+2?"},
|
1750 |
+
{"role": "assistant", "content": "It's 4."},
|
1751 |
+
{"role": "user", "content": "Ok!"},
|
1752 |
+
{"role": "assistant", "content": "Anything else?"},
|
1753 |
+
{"role": "user", "content": "What's 2x2?"},
|
1754 |
+
]
|
1755 |
+
correct_output = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
|
1756 |
+
|
1757 |
+
tokenizer.chat_template = jinja_template
|
1758 |
+
new_output = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
|
1759 |
+
assert(correct_output == new_output)
|
1760 |
+
pass
|
1761 |
+
pass
|
1762 |
+
|
1763 |
+
|
1764 |
+
def apply_chat_template( \
|
1765 |
+
|
1766 |
+
dataset,
|
1767 |
+
tokenizer = None,
|
1768 |
+
|
1769 |
+
chat_template = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
1770 |
+
|
1771 |
+
{SYSTEM}<|eot_id|><|start_header_id|>user<|end_header_id|>
|
1772 |
+
|
1773 |
+
{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
1774 |
+
|
1775 |
+
{OUTPUT}<|eot_id|><|start_header_id|>user<|end_header_id|>
|
1776 |
+
|
1777 |
+
{INPUT}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
1778 |
+
|
1779 |
+
{OUTPUT}<|eot_id|>""",
|
1780 |
+
|
1781 |
+
default_system_message = \
|
1782 |
+
"Below are some instructions that describe some tasks. Write responses that appropriately complete each request.",
|
1783 |
+
|
1784 |
+
extra_eos_tokens = None,
|
1785 |
+
|
1786 |
+
):
|
1787 |
+
"""
|
1788 |
+
Creates an Ollama modelfile and a HF Jinja template from a custom
|
1789 |
+
template. You must provide 2x examples of an input & output.
|
1790 |
+
There is an optional system message as well.
|
1791 |
+
|
1792 |
+
You must use {INPUT}, {OUTPUT} twice, and {SYSTEM} is optional.
|
1793 |
+
"""
|
1794 |
+
modelfile, jinja_template, input_part, output_part = construct_chat_template(
|
1795 |
+
tokenizer = tokenizer,
|
1796 |
+
chat_template = chat_template,
|
1797 |
+
default_system_message = default_system_message,
|
1798 |
+
extra_eos_tokens = extra_eos_tokens,
|
1799 |
+
)
|
1800 |
+
def formatting_prompts_func(examples):
|
1801 |
+
convos = examples["conversations"]
|
1802 |
+
texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]
|
1803 |
+
return { "text" : texts, }
|
1804 |
+
pass
|
1805 |
+
|
1806 |
+
tokenizer.chat_template = jinja_template
|
1807 |
+
tokenizer._ollama_modelfile = modelfile
|
1808 |
+
tokenizer._unsloth_input_part = input_part
|
1809 |
+
tokenizer._unsloth_output_part = output_part
|
1810 |
+
|
1811 |
+
return dataset.map(formatting_prompts_func, batched = True,)
|
1812 |
+
pass
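A hedged usage sketch: `dataset`, `tokenizer` and `chat_template` are assumed to exist already, and the dataset is assumed to carry a "conversations" column of role/content message lists, since that is what formatting_prompts_func reads.

dataset = apply_chat_template(
    dataset,                        # hypothetical HF dataset with a "conversations" column
    tokenizer = tokenizer,
    chat_template = chat_template,  # a {SYSTEM}/{INPUT}/{OUTPUT} template as documented above
    default_system_message = "You are a helpful assistant",
)
print(dataset[0]["text"][:200])     # each row now has a rendered "text" field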
|
1813 |
+
|
1814 |
+
|
1815 |
+
# From https://www.geeksforgeeks.org/longest-common-substring-array-strings/
# Longest Common Substring in an Array of Strings
def _longest_common_substring(arr):
    n = len(arr)
    s = arr[0]
    l = len(s)
    res = ""
    for i in range(l):
        for j in range(i + 1, l + 1):
            stem = s[i:j]
            k = 1
            for k in range(1, n):
                if stem not in arr[k]:
                    break
            if (k + 1 == n and len(res) < len(stem)):
                res = stem
    return res
pass
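A quick, self-contained check of the helper above; the inputs are made-up marker strings, not values taken from any tokenizer.

_examples = [
    "\n### User:\n\n",
    "\n\n### User:\n\n",
    "  \n### User:\n\n",
]
# The longest substring shared by every string is the bare "\n### User:\n\n" marker.
assert _longest_common_substring(_examples) == "\n### User:\n\n"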
|
1833 |
+
|
1834 |
+
|
1835 |
+
def _find_common_token_ids(component, tokenizer):
    """
    \n### User:\n\n
    \n\n### User:\n\n
    etc
    We need to find the middle-most repeated part.
    Tokenizers can tokenize newlines or spaces as 1 token!
    """
    right_text = ""
    if   component.endswith (" "):  right_text = " "
    elif component.endswith("\n"): right_text = "\n"
    left_text = ""
    if   component.startswith (" "):  left_text = " "
    elif component.startswith("\n"): left_text = "\n"
    stripped = component.strip()

    # Add current pieces and also newlines
    all_input_ids = []
    for left in range(3):
        for right in range(3):
            x = left*left_text + stripped + right*right_text
            x = tokenizer(x, add_special_tokens = False).input_ids
            all_input_ids.append(x)

            x = left*"\n" + stripped + right*"\n"
            x = tokenizer(x, add_special_tokens = False).input_ids
            all_input_ids.append(x)
        pass
    pass
    substring = _longest_common_substring([str(x + [0]) for x in all_input_ids])
    substring = substring.split(", ")[:-1]
    substring = [int(x) for x in substring]

    # Also get rest of tokenized string
    original = tokenizer(component, add_special_tokens = False).input_ids
    # Get optional left and right
    for j in range(len(original)):
        if original[j : j + len(substring)] == substring: break
    optional_left  = original[:j]
    optional_right = original[j+len(substring):]
    return substring, optional_left, optional_right
pass
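To make the must/optional split concrete, here is a sketch with a hypothetical character-level tokenizer (not a real HF tokenizer); real tokenizers merge characters into subwords, but the decomposition idea is the same.

class _CharTokenizer:
    """Toy stand-in: one token id per character, purely for illustration."""
    class _Encoding:
        def __init__(self, ids): self.input_ids = ids
    def __call__(self, text, add_special_tokens = False):
        return self._Encoding([ord(c) for c in text])

_must, _left, _right = _find_common_token_ids("\n### User:\n\n", _CharTokenizer())
# _must  -> ids of "### User:"  (the part that is always required)
# _left  -> [10]                (one optional leading newline)
# _right -> [10, 10]            (two optional trailing newlines)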
|
1877 |
+
|
1878 |
+
|
1879 |
+
def train_on_responses_only(
|
1880 |
+
trainer,
|
1881 |
+
instruction_part = None,
|
1882 |
+
response_part = None,
|
1883 |
+
):
|
1884 |
+
"""
|
1885 |
+
Trains only on responses and not on the instruction by masking out
|
1886 |
+
the labels with -100 for the instruction part.
|
1887 |
+
"""
|
1888 |
+
tokenizer = trainer.tokenizer
|
1889 |
+
|
1890 |
+
if not hasattr(tokenizer, "_unsloth_input_part") or \
|
1891 |
+
not hasattr(tokenizer, "_unsloth_output_part"):
|
1892 |
+
|
1893 |
+
if instruction_part is None or response_part is None:
|
1894 |
+
raise ValueError("Unsloth: instruction_part and response_part must be given!")
|
1895 |
+
pass
|
1896 |
+
elif (instruction_part is not None or response_part is not None) and \
|
1897 |
+
(hasattr(tokenizer, "_unsloth_input_part") or hasattr(tokenizer, "_unsloth_output_part")):
|
1898 |
+
|
1899 |
+
raise ValueError("Unsloth: Your tokenizer already has instruction and response parts set - do not give custom ones!")
|
1900 |
+
else:
|
1901 |
+
instruction_part = tokenizer._unsloth_input_part
|
1902 |
+
response_part = tokenizer._unsloth_output_part
|
1903 |
+
pass
|
1904 |
+
|
1905 |
+
# Get most common tokens since tokenizers can tokenize stuff differently!
|
1906 |
+
Q_must, Q_left, Q_right = _find_common_token_ids(instruction_part, tokenizer)
|
1907 |
+
A_must, A_left, A_right = _find_common_token_ids(response_part, tokenizer)
|
1908 |
+
|
1909 |
+
# Store some temporary stuff
|
1910 |
+
A_first = A_must[0]
|
1911 |
+
len_A_must = len(A_must)
|
1912 |
+
A_left_reversed = A_left[::-1]
|
1913 |
+
A_right_forward = A_right
|
1914 |
+
|
1915 |
+
Q_first = Q_must[0]
|
1916 |
+
len_Q_must = len(Q_must)
|
1917 |
+
Q_left_reversed = Q_left[::-1]
|
1918 |
+
Q_right_forward = Q_right
|
1919 |
+
|
1920 |
+
def _train_on_responses_only(examples):
|
1921 |
+
input_ids_ = examples["input_ids"]
|
1922 |
+
all_labels = []
|
1923 |
+
|
1924 |
+
for input_ids in input_ids_:
|
1925 |
+
n = len(input_ids)
|
1926 |
+
labels = [-100] * n
|
1927 |
+
n_minus_1 = n - 1
|
1928 |
+
j = 0
|
1929 |
+
while j < n:
|
1930 |
+
# Find <assistant>
|
1931 |
+
if (input_ids[j] == A_first) and \
|
1932 |
+
(input_ids[j : (k := j + len_A_must)] == A_must):
|
1933 |
+
|
1934 |
+
# Now backtrack to get previous optional tokens
|
1935 |
+
for optional_left in A_left_reversed:
|
1936 |
+
if j < 1: break
|
1937 |
+
if optional_left == input_ids[j-1]: j -= 1
|
1938 |
+
else: break
|
1939 |
+
pass
|
1940 |
+
# And forwards look as well
|
1941 |
+
for optional_right in A_right_forward:
|
1942 |
+
if k >= n_minus_1: break
|
1943 |
+
if optional_right == input_ids[k+1]: k += 1
|
1944 |
+
else: break
|
1945 |
+
pass
|
1946 |
+
# assistant_j = j
|
1947 |
+
assistant_k = k
|
1948 |
+
|
1949 |
+
j = assistant_k
|
1950 |
+
# Given <assistant>, now find next user
|
1951 |
+
while j < n:
|
1952 |
+
# Find <user>
|
1953 |
+
# Also accept last final item if assistant is the last turn
|
1954 |
+
if (j == n_minus_1) or \
|
1955 |
+
((input_ids[j] == Q_first) and \
|
1956 |
+
(input_ids[j : (k := j + len_Q_must)] == Q_must)):
|
1957 |
+
|
1958 |
+
# Now backtrack to get previous optional tokens
|
1959 |
+
for optional_left in Q_left_reversed:
|
1960 |
+
if j < 1: break
|
1961 |
+
if optional_left == input_ids[j-1]: j -= 1
|
1962 |
+
else: break
|
1963 |
+
pass
|
1964 |
+
# And forwards look as well
|
1965 |
+
for optional_right in Q_right_forward:
|
1966 |
+
if k >= n_minus_1: break
|
1967 |
+
if optional_right == input_ids[k+1]: k += 1
|
1968 |
+
else: break
|
1969 |
+
pass
|
1970 |
+
user_j = j
|
1971 |
+
# Account for last item
|
1972 |
+
if user_j != n_minus_1:
|
1973 |
+
# user_k = k
|
1974 |
+
# j = user_k
|
1975 |
+
j = k
|
1976 |
+
else:
|
1977 |
+
user_j = n
|
1978 |
+
k = n
|
1979 |
+
pass
|
1980 |
+
# Now copy input_ids to labels
|
1981 |
+
labels[assistant_k : user_j] = input_ids[assistant_k : user_j]
|
1982 |
+
# print(assistant_j, assistant_k, user_j, user_k)
|
1983 |
+
break
|
1984 |
+
pass
|
1985 |
+
j += 1
|
1986 |
+
pass
|
1987 |
+
pass
|
1988 |
+
j += 1
|
1989 |
+
pass
|
1990 |
+
all_labels.append(labels)
|
1991 |
+
pass
|
1992 |
+
return { "labels" : all_labels }
|
1993 |
+
pass
|
1994 |
+
|
1995 |
+
if hasattr(trainer, "train_dataset") and trainer.train_dataset is not None:
|
1996 |
+
trainer.train_dataset = trainer.train_dataset.map(_train_on_responses_only, batched = True)
|
1997 |
+
if hasattr(trainer, "eval_dataset") and trainer.eval_dataset is not None:
|
1998 |
+
trainer.eval_dataset = trainer.eval_dataset.map(_train_on_responses_only, batched = True)
|
1999 |
+
return trainer
|
2000 |
+
pass
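A minimal usage sketch, assuming `trainer` is an SFT-style trainer whose tokenized datasets already carry an "input_ids" column; the two marker strings are Llama-3-style examples for illustration, not values this function requires.

trainer = train_on_responses_only(
    trainer,
    instruction_part = "<|start_header_id|>user<|end_header_id|>\n\n",
    response_part    = "<|start_header_id|>assistant<|end_header_id|>\n\n",
)
# Each example now has "labels" where instruction tokens are -100 (ignored by the
# loss) and only the assistant's tokens contribute to training.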
|
2001 |
+
|
2002 |
+
|
2003 |
+
def create_stopping_criteria(tokenizer, stop_word = "eos_token"):
    class StoppingCriteriaSub(StoppingCriteria):
        __slots__ = "stop_token", "single_match", "length",

        def __init__(self, stops = "eos_token", device = "cuda", encounters = 1):
            super().__init__()
            if stops == "eos_token":
                self.stop_token = torch.tensor(tokenizer.eos_token_id, device = "cuda")
                self.length = 1
            else:
                self.stop_token = tokenizer(["\n" + stops], add_special_tokens = False, return_tensors = "pt")
                self.stop_token = self.stop_token.input_ids.ravel()[1:].to("cuda")
                self.length = self.stop_token.shape[0]
            pass
            self.single_match = self.length == 1
        pass

        def __call__(self, input_ids: LongTensor, scores: FloatTensor) -> bool:
            input_ids = input_ids.ravel()
            last_token = input_ids[-1]
            if self.single_match and (last_token == self.stop_token): return True

            if input_ids.shape[0] >= self.length and \
               (input_ids[-self.length:] == self.stop_token).all(): return True
            return False
        pass
    pass
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops = stop_word)])
    return stopping_criteria
pass
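A short usage sketch; `model` and `tokenizer` are assumed to be an already-loaded CUDA model and its tokenizer, and passing `stopping_criteria` to `generate` is standard `transformers` behaviour.

stop = create_stopping_criteria(tokenizer, stop_word = "eos_token")
inputs = tokenizer("Continue this story:", return_tensors = "pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens = 64, stopping_criteria = stop)
print(tokenizer.decode(outputs[0]))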
|
2033 |
+
|
2034 |
+
|
2035 |
+
def test_chat_templates():
|
2036 |
+
messages = [
|
2037 |
+
{"role": "system","content": " You are a friendly chatbot.",},
|
2038 |
+
{"role": "user", "content": "What is 2+2?"},
|
2039 |
+
{"role": "assistant", "content": "It's 4."},
|
2040 |
+
{"role": "user", "content": " But 2+2 is equal to 5. "},
|
2041 |
+
{"role": "assistant", "content": "No I'm sure its 4."},
|
2042 |
+
{"role": "user", "content": " No it's 100% 5! "},
|
2043 |
+
]
|
2044 |
+
|
2045 |
+
# Zephyr
|
2046 |
+
from transformers import AutoTokenizer
|
2047 |
+
template = zephyr_template
|
2048 |
+
correct_tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
|
2049 |
+
correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
|
2050 |
+
correct_tokenizer.chat_template = template
|
2051 |
+
our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
|
2052 |
+
assert(correct_prompt == our_prompt)
|
2053 |
+
|
2054 |
+
# Chatml
|
2055 |
+
template = chatml_template
|
2056 |
+
correct_tokenizer = AutoTokenizer.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B")
|
2057 |
+
correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
|
2058 |
+
correct_tokenizer.chat_template = template
|
2059 |
+
our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
|
2060 |
+
assert(correct_prompt == our_prompt)
|
2061 |
+
|
2062 |
+
# Mistral
|
2063 |
+
template = mistral_template
|
2064 |
+
correct_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
|
2065 |
+
correct_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
|
2066 |
+
correct_tokenizer.chat_template = template
|
2067 |
+
our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
|
2068 |
+
assert(correct_prompt == our_prompt)
|
2069 |
+
|
2070 |
+
# Llama
|
2071 |
+
template = llama_template
|
2072 |
+
correct_tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-2-7b-chat")
|
2073 |
+
correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
|
2074 |
+
correct_tokenizer.chat_template = template
|
2075 |
+
our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
|
2076 |
+
assert(correct_prompt == our_prompt)
|
2077 |
+
|
2078 |
+
# Vicuna
|
2079 |
+
try:
|
2080 |
+
from fastchat.conversation import get_conv_template
|
2081 |
+
except:
|
2082 |
+
os.system("pip -qqq install git+https://github.com/lm-sys/FastChat.git")
|
2083 |
+
from fastchat.conversation import get_conv_template
|
2084 |
+
correct_prompt = get_conv_template("vicuna_v1.1")
|
2085 |
+
for j in range(len(messages)-1):
|
2086 |
+
correct_prompt.append_message(correct_prompt.roles[j%2==1], messages[j+1]["content"])
|
2087 |
+
correct_prompt.append_message(correct_prompt.roles[1], "")
|
2088 |
+
correct_prompt = tokenizer.bos_token + correct_prompt.get_prompt()
|
2089 |
+
|
2090 |
+
template = vicuna_template
|
2091 |
+
correct_tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
|
2092 |
+
correct_tokenizer.chat_template = template
|
2093 |
+
our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
|
2094 |
+
assert(correct_prompt == our_prompt)
|
2095 |
+
|
2096 |
+
try:
|
2097 |
+
from fastchat.conversation import get_conv_template
|
2098 |
+
except:
|
2099 |
+
os.system("pip -qqq install git+https://github.com/lm-sys/FastChat.git")
|
2100 |
+
from fastchat.conversation import get_conv_template
|
2101 |
+
correct_prompt = get_conv_template("zero_shot")
|
2102 |
+
for j in range(len(messages)-1):
|
2103 |
+
correct_prompt.append_message(correct_prompt.roles[j%2==1], messages[j+1]["content"])
|
2104 |
+
correct_prompt.append_message(correct_prompt.roles[1], "")
|
2105 |
+
correct_prompt = tokenizer.bos_token + correct_prompt.get_prompt()
|
2106 |
+
|
2107 |
+
template = vicuna_old_template
|
2108 |
+
correct_tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
|
2109 |
+
correct_tokenizer.chat_template = template
|
2110 |
+
our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
|
2111 |
+
# We add </s> ourselves
|
2112 |
+
assert(correct_prompt == our_prompt.replace("</s>", ""))
|
2113 |
+
|
2114 |
+
# Gemma
|
2115 |
+
correct_tokenizer = AutoTokenizer.from_pretrained("unsloth/gemma-7b-it")
|
2116 |
+
correct_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
|
2117 |
+
correct_tokenizer.chat_template = gemma_template
|
2118 |
+
our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
|
2119 |
+
assert(our_prompt == correct_prompt)
|
2120 |
+
|
2121 |
+
# Llama-3
|
2122 |
+
template = llama3_template
|
2123 |
+
correct_tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-3-8b-Instruct")
|
2124 |
+
correct_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
|
2125 |
+
correct_tokenizer.chat_template = template
|
2126 |
+
our_prompt = correct_tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
|
2127 |
+
assert(correct_prompt == our_prompt)
|
2128 |
+
|
2129 |
+
# Phi-3
|
2130 |
+
template = phi3_template
|
2131 |
+
correct_tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
|
2132 |
+
correct_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
|
2133 |
+
correct_tokenizer.chat_template = template
|
2134 |
+
our_prompt = correct_tokenizer.apply_chat_template(messages[1:], tokenize = False, add_generation_prompt = True)
|
2135 |
+
assert(correct_prompt == our_prompt)
|
2136 |
+
pass
|
2137 |
+
|
2138 |
+
|
2139 |
+
def test_hf_gguf_equivalence(tokenizer, gguf_model = "./model-unsloth.F16.gguf"):
|
2140 |
+
"""
|
2141 |
+
Carefully checks the output of GGUF's tokenization and HF.
|
2142 |
+
Can catch all tokenization bugs.
|
2143 |
+
"""
|
2144 |
+
import subprocess
|
2145 |
+
import re
|
2146 |
+
messages = [
|
2147 |
+
{"role": "user", "content": "What is 2+2?"},
|
2148 |
+
{"role": "assistant", "content": "It's 4."},
|
2149 |
+
{"role": "user", "content": " But 2+2 is equal to 5. "},
|
2150 |
+
{"role": "assistant", "content": "No I'm sure its 4."},
|
2151 |
+
{"role": "user", "content": " No it's 100% 5! "},
|
2152 |
+
]
|
2153 |
+
|
2154 |
+
prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
|
2155 |
+
|
2156 |
+
### Instruction:
|
2157 |
+
{}
|
2158 |
+
|
2159 |
+
### Input:
|
2160 |
+
{}
|
2161 |
+
|
2162 |
+
### Response:
|
2163 |
+
{}""".format(
|
2164 |
+
"Describe the city given eloquently.", # instruction
|
2165 |
+
"The lost city of Atlantis.", # input
|
2166 |
+
"", # output - leave this blank for generation!
|
2167 |
+
)
|
2168 |
+
prompts = [ prompt, ]
|
2169 |
+
|
2170 |
+
if tokenizer.chat_template is not None:
|
2171 |
+
prompt = tokenizer.apply_chat_template(messages, tokenize = False, add_generation_prompt = True)
|
2172 |
+
prompt = prompt.replace("'", "") # Subprocess does not like ''
|
2173 |
+
prompt = remove_special_tokens(tokenizer, prompt)
|
2174 |
+
prompts.append(prompt)
|
2175 |
+
pass
|
2176 |
+
|
2177 |
+
for prompt in prompts:
|
2178 |
+
command = f"./llama.cpp/llama-cli -m {gguf_model} -n 0 --temp 0.0 --verbose-prompt "\
|
2179 |
+
f"--check-tensors -p '{prompt}'"
|
2180 |
+
|
2181 |
+
datas = []
|
2182 |
+
with subprocess.Popen(command, shell = True, stdout = subprocess.PIPE, stderr = subprocess.STDOUT, bufsize = 1) as sp:
|
2183 |
+
for line in sp.stdout:
|
2184 |
+
datas.append(line.decode("utf-8", errors = "replace"))
|
2185 |
+
pass
|
2186 |
+
gguf_tokens = "".join(datas)
|
2187 |
+
|
2188 |
+
# Now extract GGUF tokenization attempt
|
2189 |
+
gguf_tokenized = re.findall("([\d]{1,}) \-\> \'([^\']{1,})\'", gguf_tokens, flags = re.MULTILINE)
|
2190 |
+
gguf_tokenized = [(int(x[0]), x[1],) for x in gguf_tokenized]
|
2191 |
+
input_ids = tokenizer(prompt).input_ids
|
2192 |
+
|
2193 |
+
tokens = tokenizer.batch_decode(input_ids)
|
2194 |
+
hf_tokenized = list(zip(input_ids, tokens))
|
2195 |
+
|
2196 |
+
# Compare to Huggingface
|
2197 |
+
for j, (hf_token, gguf_token) in enumerate(zip(hf_tokenized, gguf_tokenized)):
|
2198 |
+
if (hf_token[0] != gguf_token[0]):
|
2199 |
+
print("Failed GGUF != HF at", j)
|
2200 |
+
print("HF =", hf_token)
|
2201 |
+
print("GGUF =", gguf_token)
|
2202 |
+
print(hf_tokenized)
|
2203 |
+
print()
|
2204 |
+
print(gguf_tokenized)
|
2205 |
+
print()
|
2206 |
+
raise RuntimeError("Failed comparing GGUF to HF.")
|
2207 |
+
pass
|
2208 |
+
pass
|
2209 |
+
return True
|
2210 |
+
pass
|
unsloth-main/unsloth-main/unsloth/kernels/__init__.py
ADDED
@@ -0,0 +1,61 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .cross_entropy_loss import (
    fast_cross_entropy_loss,
    patch_llama_for_causal_lm,
    unpatch_llama_for_causal_lm,
)
from .rms_layernorm import (
    fast_rms_layernorm,
    patch_rms_layernorm,
    unpatch_rms_layernorm,
)
from .layernorm import (
    fast_layernorm,
    patch_layernorm,
    unpatch_layernorm,
)
from .rope_embedding import fast_rope_embedding, inplace_rope_embedding
from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
from .geglu import (
    geglu_exact_forward_kernel,
    geglu_exact_backward_kernel,
    geglu_approx_forward_kernel,
    geglu_approx_backward_kernel,
)
from .fast_lora import (
    get_lora_parameters,
    get_lora_parameters_bias,
    apply_lora_mlp_swiglu,
    apply_lora_mlp_geglu_exact,
    apply_lora_mlp_geglu_approx,
    apply_lora_qkv,
    apply_lora_o,
)
from .utils import fast_dequantize, fast_gemv, QUANT_STATE, fast_linear_forward, matmul_lora

from .flex_attention import (
    HAS_FLEX_ATTENTION,
    slow_attention_softcapping,
    slow_inference_attention_softcapping,
    create_flex_attention_causal_mask,
    create_flex_attention_sliding_window_mask,
)

try:
    print("🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.")
except:
    print("Unsloth: Will patch your computer to enable 2x faster free finetuning.")
pass
unsloth-main/unsloth-main/unsloth/kernels/cross_entropy_loss.py
ADDED
@@ -0,0 +1,461 @@
1 |
+
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import triton
|
16 |
+
import triton.language as tl
|
17 |
+
import torch
|
18 |
+
from .utils import calculate_settings, MAX_FUSED_SIZE, triton_tanh
|
19 |
+
from transformers.models.llama.modeling_llama import logger
|
20 |
+
|
21 |
+
|
22 |
+
@triton.heuristics({
|
23 |
+
"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ],
|
24 |
+
"DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"],
|
25 |
+
})
|
26 |
+
@triton.jit
|
27 |
+
def _cross_entropy_forward(
|
28 |
+
logits_ptr, logits_row_stride,
|
29 |
+
loss_ptr,
|
30 |
+
logsumexp_ptr,
|
31 |
+
labels_ptr,
|
32 |
+
VOCAB_SIZE : tl.constexpr,
|
33 |
+
BLOCK_SIZE : tl.constexpr,
|
34 |
+
DO_SOFTCAPPING : tl.constexpr,
|
35 |
+
SOFTCAP : tl.constexpr,
|
36 |
+
DO_LOGIT_SCALING: tl.constexpr,
|
37 |
+
LOGIT_SCALE : tl.constexpr,
|
38 |
+
):
|
39 |
+
"""
|
40 |
+
Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
|
41 |
+
Pi = exp(xi) / sum(exp(xi))
|
42 |
+
CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
|
43 |
+
= -y [ x - log[sum(exp(x))] ]
|
44 |
+
= y * (log[sum(exp(x))] - x)
|
45 |
+
If y == 0: CE_i = 0
|
46 |
+
If y == 1: CE_i = logsumexp - x
|
47 |
+
|
48 |
+
logsumexp is also stable
|
49 |
+
Take y = log[sum(exp(x))]
|
50 |
+
exp(y) = sum(exp(x))
|
51 |
+
exp(y) = sum(exp(x - c)*exp(c)) Since e^(x-c)*e^c = e^x
|
52 |
+
exp(y) = exp(c)*sum(exp(x - c))
|
53 |
+
y = log(exp(c)*sum(exp(x - c)))
|
54 |
+
y = c + log[sum(exp(x - c))]
|
55 |
+
This means we can set c = max(x) to make sure
|
56 |
+
exp(x - c) always is exp(x - max(x)).
|
57 |
+
This ensures exp(x - max(x))'s maximum is 1 as exp(0) = 1.
|
58 |
+
"""
|
59 |
+
row_idx = tl.program_id(0)
|
60 |
+
logits_ptr += row_idx * logits_row_stride.to(tl.int64)
|
61 |
+
loss_ptr += row_idx
|
62 |
+
logsumexp_ptr += row_idx
|
63 |
+
labels_ptr += row_idx
|
64 |
+
|
65 |
+
col_offsets = tl.arange(0, BLOCK_SIZE)
|
66 |
+
mask = col_offsets < VOCAB_SIZE
|
67 |
+
|
68 |
+
label_idx = tl.load(labels_ptr).to(tl.int32)
|
69 |
+
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
|
70 |
+
|
71 |
+
# Go logit scaling for Cohere: t * x
|
72 |
+
if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
|
73 |
+
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
|
74 |
+
if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
|
75 |
+
|
76 |
+
logits = logits.to(tl.float32)
|
77 |
+
c = tl.max(logits, 0)
|
78 |
+
logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
|
79 |
+
|
80 |
+
if label_idx != -100:
|
81 |
+
x = tl.load(logits_ptr + label_idx)
|
82 |
+
# Go logit scaling for Cohere: t * x
|
83 |
+
if DO_LOGIT_SCALING: x = LOGIT_SCALE * x
|
84 |
+
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
|
85 |
+
if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
|
86 |
+
loss = logsumexp - x.to(tl.float32)
|
87 |
+
else:
|
88 |
+
loss = 0.0
|
89 |
+
tl.store(logsumexp_ptr, logsumexp)
|
90 |
+
tl.store(loss_ptr, loss)
|
91 |
+
pass
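As a plain-PyTorch sanity check of the max-shift identity the docstring derives (nothing Triton-specific, just the math y = c + log[sum(exp(x - c))] with c = max(x)):

import torch

x = torch.randn(8, 32000)                               # 8 rows of logits
c = x.max(dim = -1, keepdim = True).values              # c = max(x) per row
stable = (c + (x - c).exp().sum(-1, keepdim = True).log()).squeeze(-1)
assert torch.allclose(stable, torch.logsumexp(x, dim = -1), atol = 1e-5)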
|
92 |
+
|
93 |
+
|
94 |
+
@triton.heuristics({
|
95 |
+
"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ],
|
96 |
+
"DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"],
|
97 |
+
})
|
98 |
+
@triton.jit
|
99 |
+
def _chunked_cross_entropy_forward(
|
100 |
+
logits_ptr, logits_row_stride,
|
101 |
+
loss_ptr,
|
102 |
+
logsumexp_ptr,
|
103 |
+
labels_ptr,
|
104 |
+
VOCAB_SIZE : tl.constexpr,
|
105 |
+
N_CHUNKS : tl.constexpr,
|
106 |
+
BLOCK_SIZE : tl.constexpr,
|
107 |
+
DO_SOFTCAPPING : tl.constexpr,
|
108 |
+
SOFTCAP : tl.constexpr,
|
109 |
+
DO_LOGIT_SCALING: tl.constexpr,
|
110 |
+
LOGIT_SCALE : tl.constexpr,
|
111 |
+
):
|
112 |
+
"""
|
113 |
+
256K vocab divided in 4 chunks
|
114 |
+
|
115 |
+
|-65536-| |-65536-| |-65536-| |-65536-|
|
116 |
+
|-------| |-------| |-------| |-------|
|
117 |
+
|-------| |-------| |-------| |-------|
|
118 |
+
|
119 |
+
If y == 0: CE_i = 0
|
120 |
+
If y == 1: CE_i = logsumexp - x
|
121 |
+
|
122 |
+
Notice we can do logsumexp for each chunk and then
|
123 |
+
logsumexp[chunk_sum(logsumexp)] == logsumexp
|
124 |
+
|
125 |
+
chunk_sum = log[chunk_sum(logsumexp)]
|
126 |
+
= log[exp(logsumexp(a)) + ... + exp(logsumexp(z))]
|
127 |
+
= log[exp(log[sum(exp(a))]) + ... + exp(log[sum(exp(z))])]
|
128 |
+
= log[sum(exp(a)) + ... + sum(exp(z))]
|
129 |
+
= logsumexp(x)
|
130 |
+
|
131 |
+
This means we can perform a logsumexp for each chunk, then do a
|
132 |
+
final logsumexp reduction!
|
133 |
+
|
134 |
+
Ie do: logsumexp(chunked_logsumexp) - x
|
135 |
+
"""
|
136 |
+
row_idx = tl.program_id(0)
|
137 |
+
chunk_idx = tl.program_id(1)
|
138 |
+
logits_ptr += row_idx * logits_row_stride.to(tl.int64)
|
139 |
+
loss_ptr += row_idx
|
140 |
+
logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx
|
141 |
+
labels_ptr += row_idx
|
142 |
+
|
143 |
+
col_offsets = chunk_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
144 |
+
mask = col_offsets < VOCAB_SIZE
|
145 |
+
|
146 |
+
label_idx = tl.load(labels_ptr).to(tl.int32)
|
147 |
+
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
|
148 |
+
|
149 |
+
# Go logit scaling for Cohere: t * x
|
150 |
+
if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits
|
151 |
+
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
|
152 |
+
if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)
|
153 |
+
|
154 |
+
logits = logits.to(tl.float32)
|
155 |
+
c = tl.max(logits, 0)
|
156 |
+
logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))
|
157 |
+
|
158 |
+
if chunk_idx == 0:
|
159 |
+
# logsumexp(chunked_logsumexp) - x
|
160 |
+
# Do the -x separately
|
161 |
+
if label_idx != -100:
|
162 |
+
x = tl.load(logits_ptr + label_idx).to(tl.float32)
|
163 |
+
# Go logit scaling for Cohere: t * x
|
164 |
+
if DO_LOGIT_SCALING: x = LOGIT_SCALE * x
|
165 |
+
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
|
166 |
+
if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)
|
167 |
+
loss = -1.0 * x.to(tl.float32)
|
168 |
+
else:
|
169 |
+
loss = 0.0
|
170 |
+
tl.store(loss_ptr, loss)
|
171 |
+
pass
|
172 |
+
tl.store(logsumexp_ptr, logsumexp)
|
173 |
+
pass
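The chunked reduction described in the docstring can be checked directly in PyTorch; the 256K vocab split into 65536-wide chunks mirrors the sketch above and is purely illustrative.

import torch

x = torch.randn(4, 262144)                                    # a 256K-vocab logit row per sample
chunk_lse = torch.stack([c.logsumexp(-1) for c in x.split(65536, dim = -1)], dim = -1)
assert torch.allclose(chunk_lse.logsumexp(-1), x.logsumexp(-1), atol = 1e-5)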
|
174 |
+
|
175 |
+
|
176 |
+
@triton.heuristics({
|
177 |
+
"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING" ],
|
178 |
+
"DO_LOGIT_SCALING": lambda args: args["DO_LOGIT_SCALING"],
|
179 |
+
})
|
180 |
+
@triton.jit
|
181 |
+
def _cross_entropy_backward(
|
182 |
+
logits_ptr, logits_row_stride,
|
183 |
+
dloss_ptr, dloss_row_stride,
|
184 |
+
logsumexp_ptr,
|
185 |
+
labels_ptr,
|
186 |
+
VOCAB_SIZE : tl.constexpr,
|
187 |
+
BLOCK_SIZE : tl.constexpr,
|
188 |
+
DO_SOFTCAPPING : tl.constexpr,
|
189 |
+
SOFTCAP : tl.constexpr,
|
190 |
+
DO_LOGIT_SCALING: tl.constexpr,
|
191 |
+
LOGIT_SCALE : tl.constexpr,
|
192 |
+
):
|
193 |
+
"""
|
194 |
+
CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
|
195 |
+
dC/dx = d/dx (y * log[sum(exp(x))] - x * y)
|
196 |
+
|
197 |
+
From https://en.wikipedia.org/wiki/LogSumExp
|
198 |
+
d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)
|
199 |
+
|
200 |
+
dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
|
201 |
+
dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
|
202 |
+
dC/dx = y * exp[x - logsumexp] - d/dx (x * y)
|
203 |
+
|
204 |
+
If y == 0: dC/dx = 0
|
205 |
+
If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
|
206 |
+
If y == 1 and x != label: dC/dx = exp[x - logsumexp]
|
207 |
+
"""
|
208 |
+
row_idx = tl.program_id(0)
|
209 |
+
block_idx = tl.program_id(1)
|
210 |
+
|
211 |
+
logits_ptr += row_idx * logits_row_stride.to(tl.int64)
|
212 |
+
dloss_ptr += row_idx * dloss_row_stride
|
213 |
+
col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
214 |
+
mask = col_offsets < VOCAB_SIZE
|
215 |
+
label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)
|
216 |
+
|
217 |
+
if label_idx != -100:
|
218 |
+
dloss = tl.load(dloss_ptr)
|
219 |
+
else:
|
220 |
+
dloss = 0.0
|
221 |
+
|
222 |
+
x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
|
223 |
+
|
224 |
+
# Do logit scaling for Cohere
|
225 |
+
if DO_LOGIT_SCALING:
|
226 |
+
# d/dx [s * x] = s
|
227 |
+
x = x * LOGIT_SCALE
|
228 |
+
pass
|
229 |
+
|
230 |
+
# Do logit softcapping for Gemma 2: t * tanh(1/t * x)
|
231 |
+
if DO_SOFTCAPPING:
|
232 |
+
# d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
|
233 |
+
partial = triton_tanh(x / SOFTCAP)
|
234 |
+
x = SOFTCAP * partial
|
235 |
+
pass
|
236 |
+
|
237 |
+
logsumexp = tl.load(logsumexp_ptr + row_idx)
|
238 |
+
y = tl.exp(x.to(tl.float32) - logsumexp)
|
239 |
+
y = tl.where(
|
240 |
+
col_offsets == label_idx,
|
241 |
+
y - 1.0, # exp(x - logsumexp) - 1
|
242 |
+
y, # exp(x - logsumexp)
|
243 |
+
)
|
244 |
+
|
245 |
+
if DO_LOGIT_SCALING:
|
246 |
+
# d/dx [s * x] = s
|
247 |
+
y = y * LOGIT_SCALE
|
248 |
+
pass
|
249 |
+
|
250 |
+
if DO_SOFTCAPPING:
|
251 |
+
# d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
|
252 |
+
y = y * (1.0 - partial*partial)
|
253 |
+
pass
|
254 |
+
|
255 |
+
# If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.
|
256 |
+
tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)
|
257 |
+
pass
|
258 |
+
|
259 |
+
|
260 |
+
MAX_FUSED_SIZE = 65536 # 2**16
|
261 |
+
|
262 |
+
class Fast_CrossEntropyLoss(torch.autograd.Function):
|
263 |
+
@staticmethod
|
264 |
+
def forward(ctx, logits, labels, logit_softcapping = 0, logit_scaling = 0):
|
265 |
+
n_rows, vocab_size = logits.shape
|
266 |
+
|
267 |
+
div, mod = divmod(vocab_size, MAX_FUSED_SIZE)
|
268 |
+
n_chunks = div + (mod != 0)
|
269 |
+
losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
|
270 |
+
|
271 |
+
DO_SOFTCAPPING = (logit_softcapping != 0)
|
272 |
+
DO_LOGIT_SCALING = (logit_scaling != 0)
|
273 |
+
|
274 |
+
if n_chunks == 1:
|
275 |
+
# For small vocabs <= 65536 like Llama, Mistral
|
276 |
+
BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
|
277 |
+
logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
|
278 |
+
|
279 |
+
_cross_entropy_forward[(n_rows,)](
|
280 |
+
logits, logits.stride(0),
|
281 |
+
losses,
|
282 |
+
logsumexp,
|
283 |
+
labels,
|
284 |
+
VOCAB_SIZE = vocab_size,
|
285 |
+
BLOCK_SIZE = BLOCK_SIZE,
|
286 |
+
DO_SOFTCAPPING = DO_SOFTCAPPING,
|
287 |
+
SOFTCAP = logit_softcapping,
|
288 |
+
DO_LOGIT_SCALING = DO_LOGIT_SCALING,
|
289 |
+
LOGIT_SCALE = logit_scaling,
|
290 |
+
num_warps = num_warps,
|
291 |
+
)
|
292 |
+
else:
|
293 |
+
# For large vocabs > 65536 like Gemma 256K
|
294 |
+
logsumexp = torch.empty((n_rows, n_chunks,), dtype = torch.float32, device = "cuda:0")
|
295 |
+
|
296 |
+
_chunked_cross_entropy_forward[(n_rows, n_chunks,)](
|
297 |
+
logits, logits.stride(0),
|
298 |
+
losses,
|
299 |
+
logsumexp,
|
300 |
+
labels,
|
301 |
+
VOCAB_SIZE = vocab_size,
|
302 |
+
N_CHUNKS = n_chunks,
|
303 |
+
BLOCK_SIZE = MAX_FUSED_SIZE,
|
304 |
+
DO_SOFTCAPPING = DO_SOFTCAPPING,
|
305 |
+
SOFTCAP = logit_softcapping,
|
306 |
+
DO_LOGIT_SCALING = DO_LOGIT_SCALING,
|
307 |
+
LOGIT_SCALE = logit_scaling,
|
308 |
+
num_warps = 32,
|
309 |
+
)
|
310 |
+
# logsumexp(chunked_logsumexp) - x
|
311 |
+
# Do the -x separately
|
312 |
+
logsumexp = torch.logsumexp(logsumexp, dim = 1) # Row sum
|
313 |
+
losses += logsumexp
|
314 |
+
losses.masked_fill_(labels == -100, 0) # Don't forget to mask padding out!
|
315 |
+
pass
|
316 |
+
|
317 |
+
ctx.save_for_backward(logits, logsumexp, labels)
|
318 |
+
ctx.DO_SOFTCAPPING = DO_SOFTCAPPING
|
319 |
+
ctx.logit_softcapping = logit_softcapping
|
320 |
+
ctx.DO_LOGIT_SCALING = DO_LOGIT_SCALING
|
321 |
+
ctx.logit_scaling = logit_scaling
|
322 |
+
return losses
|
323 |
+
pass
|
324 |
+
|
325 |
+
@staticmethod
|
326 |
+
def backward(ctx, dlosses):
|
327 |
+
logits, logsumexp, labels = ctx.saved_tensors
|
328 |
+
n_rows, vocab_size = logits.shape
|
329 |
+
|
330 |
+
BLOCK_SIZE = 4096
|
331 |
+
div, mod = divmod(vocab_size, BLOCK_SIZE)
|
332 |
+
n_blocks = div + (mod != 0)
|
333 |
+
|
334 |
+
_cross_entropy_backward[(n_rows, n_blocks,)](
|
335 |
+
logits, logits.stride(0),
|
336 |
+
dlosses, dlosses.stride(0),
|
337 |
+
logsumexp,
|
338 |
+
labels,
|
339 |
+
VOCAB_SIZE = vocab_size,
|
340 |
+
BLOCK_SIZE = BLOCK_SIZE,
|
341 |
+
DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,
|
342 |
+
SOFTCAP = ctx.logit_softcapping,
|
343 |
+
DO_LOGIT_SCALING = ctx.DO_LOGIT_SCALING,
|
344 |
+
LOGIT_SCALE = ctx.logit_scaling,
|
345 |
+
num_warps = 8,
|
346 |
+
)
|
347 |
+
return logits, None, None, None,
|
348 |
+
pass
|
349 |
+
pass
|
350 |
+
|
351 |
+
|
352 |
+
@torch._disable_dynamo
def fast_cross_entropy_loss(
    logits,
    labels,
    logit_softcapping = 0,
    logit_scaling = 0,
):
    """
    Arguments:
        logits: (batch, seq_len, vocab_size)
        labels: (batch, seq_len,)
    Returns:
        losses: float
    """
    batch, seq_len, d = logits.shape
    assert(labels.shape == (batch, seq_len))

    loss = Fast_CrossEntropyLoss.apply(
        logits.view(batch*seq_len, d),
        labels.view(-1),
        logit_softcapping,
        logit_scaling,
    )
    n_items = torch.count_nonzero(labels != -100)
    return loss.sum() / n_items
pass
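A shape-level usage sketch; it assumes a CUDA device with Triton available, since the kernels allocate their buffers on "cuda:0".

logits = torch.randn(2, 128, 32000, device = "cuda", requires_grad = True)
labels = torch.randint(0, 32000, (2, 128), device = "cuda")
labels[:, :8] = -100                       # masked / padding positions are ignored
loss = fast_cross_entropy_loss(logits, labels)
loss.backward()                            # gradients flow back into logits.grad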
|
378 |
+
|
379 |
+
|
380 |
+
from transformers.models.llama.modeling_llama import (
|
381 |
+
LlamaForCausalLM,
|
382 |
+
CausalLMOutputWithPast,
|
383 |
+
Optional,
|
384 |
+
Union,
|
385 |
+
Cache,
|
386 |
+
List,
|
387 |
+
Tuple,
|
388 |
+
)
|
389 |
+
import inspect, re
|
390 |
+
function = inspect.getsource(LlamaForCausalLM.forward)
|
391 |
+
function = function.split("\n")
|
392 |
+
i = re.match(r"[ ]{1,}", function[0]).span(0)[1]
|
393 |
+
function = [x[i:] for x in function]
|
394 |
+
function = "\n".join(function)
|
395 |
+
function = function[function.find("def forward"):]
|
396 |
+
replacement = """ loss = None
|
397 |
+
logit_softcapping = getattr(self.config, "final_logit_softcapping", 0)
|
398 |
+
logit_scaling = getattr(self.config, "logit_scale", 0)
|
399 |
+
if labels is not None:
|
400 |
+
shift_logits = logits
|
401 |
+
if not hasattr(self, "extra_ignored_labels"):
|
402 |
+
# Fixes https://github.com/unslothai/unsloth/issues/10
|
403 |
+
self.extra_ignored_labels = torch.full((self.max_seq_length, 1), -100, device = "cuda:0")
|
404 |
+
pass
|
405 |
+
|
406 |
+
shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
|
407 |
+
loss = fast_cross_entropy_loss(
|
408 |
+
logits = shift_logits,
|
409 |
+
labels = shift_labels,
|
410 |
+
logit_softcapping = logit_softcapping,
|
411 |
+
logit_scaling = logit_scaling,
|
412 |
+
)
|
413 |
+
else:
|
414 |
+
if logit_scaling != 0:
|
415 |
+
if logits.requires_grad:
|
416 |
+
logits = logit_scaling * logits
|
417 |
+
else:
|
418 |
+
logits *= logit_scaling
|
419 |
+
pass
|
420 |
+
pass
|
421 |
+
if logit_softcapping != 0:
|
422 |
+
if logits.requires_grad:
|
423 |
+
logits = (1.0 / logit_softcapping) * logits
|
424 |
+
logits = torch.tanh(logits)
|
425 |
+
logits = logit_softcapping * logits
|
426 |
+
else:
|
427 |
+
logits *= (1.0 / logit_softcapping)
|
428 |
+
torch.tanh(logits, out = logits)
|
429 |
+
logits *= logit_softcapping
|
430 |
+
pass
|
431 |
+
pass
|
432 |
+
pass
|
433 |
+
"""
|
434 |
+
function = \
|
435 |
+
function[:function.find(" loss = None")] + \
|
436 |
+
replacement + \
|
437 |
+
function[ function.find(" if not return_dict"):]
|
438 |
+
function = function.replace("logits = logits.float()", "\n")
|
439 |
+
# Missed spaces
|
440 |
+
function = function.split("\n")
|
441 |
+
# Not the first one though!
|
442 |
+
function = [function[0]] + [" "*4 + x for x in function[1:]]
|
443 |
+
function = "\n".join(function)
|
444 |
+
function = f"class Unsloth_LlamaForCausalLM(LlamaForCausalLM):\n"\
|
445 |
+
f" {function}\n"
|
446 |
+
exec(function, globals())
|
447 |
+
del function, replacement, inspect, re
|
448 |
+
|
449 |
+
|
450 |
+
def patch_llama_for_causal_lm():
|
451 |
+
import transformers.models.llama.modeling_llama
|
452 |
+
transformers.models.llama.modeling_llama.LlamaForCausalLM = Unsloth_LlamaForCausalLM
|
453 |
+
return
|
454 |
+
pass
|
455 |
+
|
456 |
+
|
457 |
+
def unpatch_llama_for_causal_lm():
|
458 |
+
import transformers.models.llama.modeling_llama
|
459 |
+
transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
|
460 |
+
return
|
461 |
+
pass
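A hedged sketch of the intended patch lifecycle; any Llama model constructed through `transformers` after patching picks up the fused-loss forward, and unpatching restores the stock class.

patch_llama_for_causal_lm()
# ... build / fine-tune a LlamaForCausalLM here: its forward now computes the loss
#     with fast_cross_entropy_loss (including optional softcapping / logit scaling) ...
unpatch_llama_for_causal_lm()   # restore the original transformers implementation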
|
unsloth-main/unsloth-main/unsloth/kernels/fast_lora.py
ADDED
@@ -0,0 +1,412 @@
1 |
+
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
from .utils import (
|
17 |
+
fast_dequantize,
|
18 |
+
QUANT_STATE,
|
19 |
+
get_lora_parameters,
|
20 |
+
get_lora_parameters_bias,
|
21 |
+
matmul_lora,
|
22 |
+
torch_amp_custom_fwd,
|
23 |
+
torch_amp_custom_bwd,
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
class LoRA_MLP(torch.autograd.Function):
|
28 |
+
"""
|
29 |
+
### LoRA weights
|
30 |
+
G = G + Ag @ Bg
|
31 |
+
U = U + Au @ Bu
|
32 |
+
W = W + Aw @ Bw
|
33 |
+
|
34 |
+
### SwiGLU(X)
|
35 |
+
e = X @ G
|
36 |
+
f = e * sigmoid(e)
|
37 |
+
g = X @ U
|
38 |
+
h = f * g
|
39 |
+
i = h @ W
|
40 |
+
|
41 |
+
### Backpropagation chain rule
|
42 |
+
See our blog post for more details
|
43 |
+
|
44 |
+
df = sigmoid(e) * (1 - f) + f
|
45 |
+
dC/dW = h.T @ dY
|
46 |
+
dC/dU = X.T @ (D @ W.T * f)
|
47 |
+
dC/dG = X.T @ (D @ W.T * df * g)
|
48 |
+
|
49 |
+
### Down projection LoRA weights
|
50 |
+
dC/dAw = dC/dW @ B.T
|
51 |
+
dC/dBw = A.T @ dC/dW
|
52 |
+
dC/dAw = h.T @ dY @ B.T
|
53 |
+
dC/dBw = A.T @ h.T @ dY
|
54 |
+
|
55 |
+
### Up projection LoRA weights
|
56 |
+
dC/dAu = X.T @ (D @ W.T * f) @ B.T
|
57 |
+
dC/dBu = A.T @ X.T @ (D @ W.T * f)
|
58 |
+
|
59 |
+
### Gate projection LoRA weights
|
60 |
+
dC/dAg = X.T @ (D @ W.T * df * g) @ B.T
|
61 |
+
dC/dBg = A.T @ X.T @ (D @ W.T * df * g)
|
62 |
+
|
63 |
+
Don't forget to see our blog post for more details!
|
64 |
+
"""
|
65 |
+
@staticmethod
|
66 |
+
@torch_amp_custom_fwd
|
67 |
+
def forward(ctx, X : torch.Tensor,
|
68 |
+
gateW, gateW_quant, gateA, gateB, gateS,
|
69 |
+
upW, upW_quant, upA, upB, upS,
|
70 |
+
downW, downW_quant, downA, downB, downS,
|
71 |
+
_forward_function, _backward_function,
|
72 |
+
inplace = True,):
|
73 |
+
dtype = X.dtype
|
74 |
+
|
75 |
+
e = matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS)
|
76 |
+
g = matmul_lora(X, upW, upW_quant, upA, upB, upS)
|
77 |
+
h = _forward_function(e, g)
|
78 |
+
i = matmul_lora(h, downW, downW_quant, downA, downB, downS)
|
79 |
+
|
80 |
+
ctx.custom_saved_tensors = (
|
81 |
+
gateW, gateW_quant, gateS,
|
82 |
+
upW, upW_quant, upS,
|
83 |
+
downW, downW_quant, downS,
|
84 |
+
_backward_function,
|
85 |
+
)
|
86 |
+
ctx.save_for_backward(gateA, gateB, upA, upB, downA, downB,
|
87 |
+
X, e, g)
|
88 |
+
ctx.inplace = inplace
|
89 |
+
return i
|
90 |
+
pass
|
91 |
+
|
92 |
+
|
93 |
+
@staticmethod
|
94 |
+
@torch_amp_custom_bwd
|
95 |
+
def backward(ctx, dY : torch.Tensor):
|
96 |
+
gateW, gateW_quant, gateS, upW, upW_quant, upS, downW, downW_quant, downS, \
|
97 |
+
_backward_function = ctx.custom_saved_tensors
|
98 |
+
gateA, gateB, upA, upB, downA, downB, \
|
99 |
+
X, e, g = ctx.saved_tensors
|
100 |
+
|
101 |
+
gateA, gateB, upA, upB, downA, downB = \
|
102 |
+
gateA.t(), gateB.t(), upA.t(), upB.t(), downA.t(), downB.t()
|
103 |
+
|
104 |
+
batch, seq_len, hd = X.shape
|
105 |
+
dY = dY.view(-1, dY.shape[-1])
|
106 |
+
X = X .view(-1, X .shape[-1])
|
107 |
+
e = e .view(-1, e .shape[-1])
|
108 |
+
g = g .view(-1, g .shape[-1])
|
109 |
+
dtype = X.dtype
|
110 |
+
|
111 |
+
DW = matmul_lora(dY, downW.t(), downW_quant, downB, downA, downS)
|
112 |
+
DW, e, g = _backward_function(DW, e, g)
|
113 |
+
h, df, de = DW, e, g
|
114 |
+
|
115 |
+
# Down projection LoRA weights
|
116 |
+
d_downA = h.t() @ (dY @ downB.t())
|
117 |
+
d_downB = (downA.t() @ h.t()) @ dY
|
118 |
+
d_downA *= downS
|
119 |
+
d_downB *= downS
|
120 |
+
|
121 |
+
# Up projection LoRA weights
|
122 |
+
d_upA = X.t() @ (df @ upB.t())
|
123 |
+
d_upB = (upA.t() @ X.t()) @ df
|
124 |
+
d_upA *= upS
|
125 |
+
d_upB *= upS
|
126 |
+
|
127 |
+
# Gate projection LoRA weights
|
128 |
+
d_gateA = X.t() @ (de @ gateB.t())
|
129 |
+
d_gateB = (gateA.t() @ X.t()) @ de
|
130 |
+
d_gateA *= gateS
|
131 |
+
d_gateB *= gateS
|
132 |
+
|
133 |
+
# dX = matmul_lora(df, upW.t(), upW_quant, upB, upA, upS)
|
134 |
+
# dX += matmul_lora(de, gateW.t(), gateW_quant, gateB, gateA, gateS)
|
135 |
+
upW = fast_dequantize(upW.t(), upW_quant)
|
136 |
+
dX = torch.matmul(df, upW.t(), out = X if ctx.inplace else None)
|
137 |
+
del upW
|
138 |
+
dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())
|
139 |
+
|
140 |
+
gateW = fast_dequantize(gateW.t(), gateW_quant)
|
141 |
+
dX += de @ gateW.t()
|
142 |
+
del gateW
|
143 |
+
dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())
|
144 |
+
|
145 |
+
# gateW, gateW_quant, gateA, gateB, gateS,
|
146 |
+
# upW, upW_quant, upA, upB, upS,
|
147 |
+
# downW, downW_quant, downA, downB, downS,
|
148 |
+
return dX.view(batch, seq_len, hd), \
|
149 |
+
None, None, d_gateA.t(), d_gateB.t(), None, \
|
150 |
+
None, None, d_upA.t(), d_upB.t(), None, \
|
151 |
+
None, None, d_downA.t(), d_downB.t(), None, \
|
152 |
+
None, None, None, # _backward and _forward and inplace
|
153 |
+
pass
|
154 |
+
pass
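For reference, a dense (unquantized, non-Triton) sketch of the computation LoRA_MLP fuses, using the docstring's convention that each effective weight is the frozen matrix plus its scaled Ag @ Bg update; the (in, out) weight shapes and the argument tuples are assumptions made to keep the sketch short, not the layout the real kernels use.

import torch
import torch.nn.functional as F

def dense_lora(W, A, B, s):
    # W: (in, out) frozen weight; A: (in, r), B: (r, out) LoRA factors; s: scaling
    return W + s * (A @ B)

def reference_swiglu_mlp(X, gate, up, down):
    G, U, W = dense_lora(*gate), dense_lora(*up), dense_lora(*down)
    e = X @ G                  # gate projection
    g = X @ U                  # up projection
    h = F.silu(e) * g          # f = e * sigmoid(e); h = f * g
    return h @ W               # down projection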
|
155 |
+
|
156 |
+
|
157 |
+
from .swiglu import swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel
|
158 |
+
def apply_lora_mlp_swiglu(self, X, inplace = True):
|
159 |
+
gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
|
160 |
+
upW, upW_quant, upA, upB, upS = get_lora_parameters(self. up_proj)
|
161 |
+
    downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
    out = LoRA_MLP.apply(X,
                         gateW, gateW_quant, gateA, gateB, gateS,
                         upW, upW_quant, upA, upB, upS,
                         downW, downW_quant, downA, downB, downS,
                         swiglu_fg_kernel, swiglu_DWf_DW_dfg_kernel,
                         inplace,)
    return out
pass


from .geglu import geglu_exact_forward_kernel, geglu_exact_backward_kernel
def apply_lora_mlp_geglu_exact(self, X, inplace = True):
    gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
    upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
    downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
    out = LoRA_MLP.apply(X,
                         gateW, gateW_quant, gateA, gateB, gateS,
                         upW, upW_quant, upA, upB, upS,
                         downW, downW_quant, downA, downB, downS,
                         geglu_exact_forward_kernel, geglu_exact_backward_kernel,
                         inplace,)
    return out
pass


from .geglu import geglu_approx_forward_kernel, geglu_approx_backward_kernel
def apply_lora_mlp_geglu_approx(self, X):
    gateW, gateW_quant, gateA, gateB, gateS = get_lora_parameters(self.gate_proj)
    upW, upW_quant, upA, upB, upS = get_lora_parameters(self.up_proj)
    downW, downW_quant, downA, downB, downS = get_lora_parameters(self.down_proj)
    out = LoRA_MLP.apply(X,
                         gateW, gateW_quant, gateA, gateB, gateS,
                         upW, upW_quant, upA, upB, upS,
                         downW, downW_quant, downA, downB, downS,
                         geglu_approx_forward_kernel, geglu_approx_backward_kernel,)
    return out
pass


class LoRA_QKV(torch.autograd.Function):
    """
    ### LoRA weights
    Wq = Wq + Aq @ Bq
    Wk = Wk + Ak @ Bk
    Wv = Wv + Av @ Bv
    Q = X @ Wq = X @ Wq + X @ Aq @ Bq
    K = X @ Wk = X @ Wk + X @ Ak @ Bk
    V = X @ Wv = X @ Wv + X @ Av @ Bv

    ### Backpropagation chain rule
    See our blogpost for more details.

    dC/dWq = X.T @ D(Wq)
    dC/dWk = X.T @ D(Wk)
    dC/dWv = X.T @ D(Wv)
    We then sum them all to find dC/dX.

    ### Q projection LoRA weights
    dC/dAq = X.T @ D(Wq) @ B.T
    dC/dBq = A.T @ X.T @ D(Wq)

    ### K projection LoRA weights
    dC/dAk = X.T @ D(Wk) @ B.T
    dC/dBk = A.T @ X.T @ D(Wk)

    ### V projection LoRA weights
    dC/dAv = X.T @ D(Wv) @ B.T
    dC/dBv = A.T @ X.T @ D(Wv)
    """
    @staticmethod
    @torch_amp_custom_fwd
    def forward(ctx, X : torch.Tensor,
                QW, QW_quant, QA, QB, QS,
                KW, KW_quant, KA, KB, KS,
                VW, VW_quant, VA, VB, VS,
                inplace = True):
        dtype = X.dtype

        Q = matmul_lora(X, QW, QW_quant, QA, QB, QS)
        K = matmul_lora(X, KW, KW_quant, KA, KB, KS)
        V = matmul_lora(X, VW, VW_quant, VA, VB, VS)

        ctx.custom_saved_tensors = (
            QW, QW_quant, QS,
            KW, KW_quant, KS,
            VW, VW_quant, VS,
        )
        ctx.save_for_backward(X, QA, QB, KA, KB, VA, VB,)
        ctx.inplace = inplace
        return Q, K, V
    pass

    @staticmethod
    @torch_amp_custom_bwd
    def backward(ctx, dQ, dK, dV):
        QW, QW_quant, QS, KW, KW_quant, KS, VW, VW_quant, VS = \
            ctx.custom_saved_tensors
        X, QA, QB, KA, KB, VA, VB, = ctx.saved_tensors

        QA, QB, KA, KB, VA, VB = \
            QA.t(), QB.t(), KA.t(), KB.t(), VA.t(), VB.t()

        batch, seq_len, hd = X.shape
        dQ = dQ.view(-1, dQ.shape[-1])
        dK = dK.reshape(-1, dK.shape[-1]) # view doesn't work on K.T
        dV = dV.view(-1, dV.shape[-1])
        X  = X .view(-1, X .shape[-1])
        dtype = X.dtype

        ### Weight projection LoRA weights
        # See our blogpost for more details.

        # Q Projection
        d_QA = X.t() @ (dQ @ QB.t())
        d_QB = (QA.t() @ X.t()) @ dQ
        d_QA *= QS
        d_QB *= QS

        # K Projection
        d_KA = X.t() @ (dK @ KB.t())
        d_KB = (KA.t() @ X.t()) @ dK
        d_KA *= KS
        d_KB *= KS

        # V Projection
        d_VA = X.t() @ (dV @ VB.t())
        d_VB = (VA.t() @ X.t()) @ dV
        d_VA *= VS
        d_VB *= VS

        # Combine derivatives to find dX
        # dQ
        QW = fast_dequantize(QW.t(), QW_quant)
        dX = torch.matmul(dQ, QW.t(), out = X if ctx.inplace else None)
        del QW
        dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))

        # dK
        KW = fast_dequantize(KW.t(), KW_quant)
        dX += dK @ KW.t()
        del KW
        dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t())

        # dV
        VW = fast_dequantize(VW.t(), VW_quant)
        dX += dV @ VW.t()
        del VW
        dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())

        # QW, QW_quant, QA, QB, QS,
        # KW, KW_quant, KA, KB, KS,
        # VW, VW_quant, VA, VB, VS,
        return dX.view(batch, seq_len, hd), \
            None, None, d_QA.t(), d_QB.t(), None, \
            None, None, d_KA.t(), d_KB.t(), None, \
            None, None, d_VA.t(), d_VB.t(), None, \
            None,
    pass
pass


def apply_lora_qkv(self, X, inplace = True):
    QW, QW_quant, QA, QB, QS = get_lora_parameters(self.q_proj)
    KW, KW_quant, KA, KB, KS = get_lora_parameters(self.k_proj)
    VW, VW_quant, VA, VB, VS = get_lora_parameters(self.v_proj)
    Q, K, V = LoRA_QKV.apply(X,
        QW, QW_quant, QA, QB, QS,
        KW, KW_quant, KA, KB, KS,
        VW, VW_quant, VA, VB, VS,
        inplace,
    )
    return Q, K, V
pass


class LoRA_W(torch.autograd.Function):
    """
    ### LoRA weights
    Wq = Wq + Aq @ Bq
    Wk = Wk + Ak @ Bk
    Wv = Wv + Av @ Bv
    Q = X @ Wq = X @ Wq + X @ Aq @ Bq
    K = X @ Wk = X @ Wk + X @ Ak @ Bk
    V = X @ Wv = X @ Wv + X @ Av @ Bv

    ### Backpropagation chain rule
    dC/dWq = X.T @ D(Wq)
    dC/dWk = X.T @ D(Wk)
    dC/dWv = X.T @ D(Wv)

    ### Q projection LoRA weights
    dC/dAq = X.T @ D(Wq) @ B.T
    dC/dBq = A.T @ X.T @ D(Wq)

    ### K projection LoRA weights
    dC/dAk = X.T @ D(Wk) @ B.T
    dC/dBk = A.T @ X.T @ D(Wk)

    ### V projection LoRA weights
    dC/dAv = X.T @ D(Wv) @ B.T
    dC/dBv = A.T @ X.T @ D(Wv)
    """
    @staticmethod
    @torch_amp_custom_fwd
    def forward(ctx, X : torch.Tensor,
                W, W_quant, A, B, S):
        dtype = X.dtype
        XW = matmul_lora(X, W, W_quant, A, B, S)
        ctx.custom_saved_tensors = (W, W_quant, S,)
        ctx.save_for_backward(A, B, X)
        return XW
    pass

    @staticmethod
    @torch_amp_custom_bwd
    def backward(ctx, dY : torch.Tensor):
        W, W_quant, S = ctx.custom_saved_tensors
        A, B, X = ctx.saved_tensors

        A, B = A.t(), B.t()

        batch, seq_len, hd = X.shape
        dY = dY.reshape(-1, dY.shape[-1]) # Must be reshape
        X  = X .reshape(-1, X .shape[-1]) # Must be reshape
        dtype = X.dtype

        ### Weight projection LoRA weights
        # Weight projection
        d_A = X.t() @ (dY @ B.t())
        d_B = (A.t() @ X.t()) @ dY
        d_A *= S
        d_B *= S

        # Get derivative for dX
        W = fast_dequantize(W.t(), W_quant)
        dX = dY @ W.t()
        del W
        dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t())

        # W, W_quant, A, B, S
        return dX.view(batch, seq_len, hd), \
            None, None, d_A.t(), d_B.t(), None
    pass
pass


def apply_lora_o(self, X):
    OW, OW_quant, OA, OB, OS = get_lora_parameters(self.o_proj)
    O = LoRA_W.apply(X, OW, OW_quant, OA, OB, OS)
    return O
pass
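Illustrative note (not part of the uploaded file): the apply_lora_* helpers above take `self` as their first argument because they are meant to be bound onto a Hugging Face attention or MLP module whose projections are PEFT LoRA layers. A minimal sketch, where `attn` and the method names `apply_qkv` / `apply_o` are assumptions for demonstration only:

import types
# Bind the fused LoRA projections onto a hypothetical attention module `attn`
# whose q_proj / k_proj / v_proj / o_proj are PEFT LoRA (or plain) Linear layers.
attn.apply_qkv = types.MethodType(apply_lora_qkv, attn)
attn.apply_o   = types.MethodType(apply_lora_o,   attn)
Q, K, V = attn.apply_qkv(hidden_states)    # fused LoRA QKV projection
O       = attn.apply_o(attention_output)   # fused LoRA output projection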
unsloth-main/unsloth-main/unsloth/kernels/flex_attention.py
ADDED
@@ -0,0 +1,180 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from functools import lru_cache
from transformers.models.llama.modeling_llama import logger

torch_compile_options = {
    "epilogue_fusion"   : True,
    "max_autotune"      : True,
    "shape_padding"     : True,
    "trace.enabled"     : False, # Output Triton kernel outputs!
    "triton.cudagraphs" : False,
}

# Flex Attention supported from torch 2.5 onwards only
try:
    from torch.nn.attention.flex_attention import (
        flex_attention as _flex_attention,
        create_block_mask as _create_block_mask,
    )
    _flex_attention = torch.compile(_flex_attention, dynamic = True, options = torch_compile_options)
    HAS_FLEX_ATTENTION = True
except:
    HAS_FLEX_ATTENTION = False
pass


if not HAS_FLEX_ATTENTION:

    # Logit softcapping
    @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
    def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
        n_heads    = self.num_heads
        head_dim   = self.head_dim
        n_kv_heads = self.num_key_value_heads
        n_groups   = self.num_key_value_groups

        # Grouped query attention
        K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
        V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
        K = K.reshape(bsz, n_heads, q_len, head_dim)
        V = V.reshape(bsz, n_heads, q_len, head_dim)

        # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
        # Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below
        # We default to using the config file itself
        # s = self.config.hidden_size // self.config.num_attention_heads
        s = self.config.query_pre_attn_scalar
        t = self.config.attn_logit_softcapping

        Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
        A = torch.matmul(Q, K.transpose(2, 3))
        A = t * torch.tanh(A / t) # Logit softcapping
        A += causal_mask[:q_len, :q_len]
        # Much slower in torch compile!
        # A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf"))
        A = torch.nn.functional.softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
        A = torch.matmul(A, V)
        A = A.transpose(1, 2).contiguous()
        A = A.reshape(bsz, q_len, n_heads*head_dim)
        return A
    pass

    create_flex_attention_causal_mask = None
    create_flex_attention_sliding_window_mask = None
else:
    # See https://github.com/pytorch-labs/attention-gym/blob/main/examples/flex_attn.ipynb
    # for more examples
    # BSD 3-Clause License Copyright (c) 2023, Driss Guessous, Horace He et al
    import functools, math

    def generate_tanh_softcap(t):
        def tanh_softcap(x, b, h, q_idx, kv_idx):
            return t * torch.tanh(x / t)
        return tanh_softcap
    pass
    def causal_masker(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx
    pass

    @functools.lru_cache
    def sliding_window_masker(size = 4096):
        def sliding_window(b, h, q_idx, kv_idx):
            causal_mask = q_idx >= kv_idx
            window_mask = q_idx - kv_idx <= size
            return causal_mask & window_mask
        return sliding_window
    pass

    @functools.lru_cache
    def create_block_mask(mask, n = 128):
        return _create_block_mask(
            mask, 1, 1, n, n,
            BLOCK_SIZE = 128,
            _compile = True,
        )
    pass

    def create_flex_attention_causal_mask(max_seq_length = 8192):
        causal_mask = create_block_mask(causal_masker, max_seq_length)
        return causal_mask
    pass

    def create_flex_attention_sliding_window_mask(max_seq_length = 8192, sliding_window = 4096):
        sliding_masker = sliding_window_masker(sliding_window)
        causal_mask = create_block_mask(sliding_masker, max_seq_length)
        return causal_mask
    pass

    @functools.lru_cache
    def flex_attention(s, t):
        scale = 1.0 / math.sqrt(s)
        score_mod = generate_tanh_softcap(t)
        return functools.partial(
            _flex_attention, score_mod = score_mod, scale = scale, enable_gqa = True,
        )
    pass

    def slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
        n_heads    = self.num_heads
        head_dim   = self.head_dim
        s = self.config.query_pre_attn_scalar
        t = self.config.attn_logit_softcapping
        fx = flex_attention(s, t)
        A = fx(query = Q, key = K, value = V, block_mask = causal_mask)
        A = A.transpose(1, 2).contiguous()
        A = A.reshape(bsz, q_len, n_heads*head_dim)
        return A
    pass
pass


torch_matmul = torch.matmul
torch_tanh   = torch.tanh
torch_nn_functional_softmax = torch.nn.functional.softmax
def slow_inference_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len):
    n_heads    = self.num_heads
    head_dim   = self.head_dim
    n_kv_heads = self.num_key_value_heads
    n_groups   = self.num_key_value_groups

    # Grouped query attention
    K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
    V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, q_len, head_dim)
    K = K.reshape(bsz, n_heads, q_len, head_dim)
    V = V.reshape(bsz, n_heads, q_len, head_dim)

    # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
    # Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below
    # We default to using the config file itself
    # s = self.config.hidden_size // self.config.num_attention_heads
    s = self.config.query_pre_attn_scalar
    t = self.config.attn_logit_softcapping

    Q = Q * torch.tensor(s**-0.5, dtype = Q.dtype) # Follow Keras exactly
    A = torch_matmul(Q, K.transpose(2, 3))

    # Logit softcapping
    A /= t; torch_tanh(A, out = A); A *= t;
    A += causal_mask[:q_len, :q_len]
    # Much slower in torch compile!
    # A.masked_fill_(causal_mask[:q_len, :q_len], -float("inf"))
    A = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32).to(Q.dtype)
    A = torch_matmul(A, V)
    A = A.transpose(1, 2).contiguous()
    A = A.reshape(bsz, q_len, n_heads*head_dim)
    return A
pass
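A brief usage sketch (an assumption for illustration, not part of the uploaded file): when HAS_FLEX_ATTENTION is True, the block masks above are built once per sequence length and then passed into slow_attention_softcapping as `causal_mask`:

# Sketch: build a causal block mask once and reuse it for softcapped attention.
causal_mask = create_flex_attention_causal_mask(max_seq_length = 8192)
# Inside an attention module whose config defines query_pre_attn_scalar and
# attn_logit_softcapping (e.g. Gemma-2), one would then call:
# A = slow_attention_softcapping(Q, K, V, causal_mask, self, bsz, q_len)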
unsloth-main/unsloth-main/unsloth/kernels/geglu.py
ADDED
@@ -0,0 +1,203 @@
1 |
+
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import triton
|
16 |
+
import triton.language as tl
|
17 |
+
import torch
|
18 |
+
from .utils import calculate_settings, triton_tanh
|
19 |
+
|
20 |
+
|
21 |
+
@triton.jit
|
22 |
+
def _exact_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
|
23 |
+
block_idx = tl.program_id(0)
|
24 |
+
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
25 |
+
mask = offsets < n_elements
|
26 |
+
|
27 |
+
# f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
|
28 |
+
# h = f * up
|
29 |
+
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
|
30 |
+
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
|
31 |
+
|
32 |
+
f_row = 0.5 * e_row * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)
|
33 |
+
f_row = f_row.to(g_row.dtype) # Exact copy from HF
|
34 |
+
h_row = f_row * g_row
|
35 |
+
|
36 |
+
# Store h
|
37 |
+
tl.store(h + offsets, h_row, mask = mask)
|
38 |
+
pass
|
39 |
+
|
40 |
+
|
41 |
+
def geglu_exact_forward_kernel(gate, up):
|
42 |
+
batch, seq_len, hd = gate.shape
|
43 |
+
n_elements = gate.numel()
|
44 |
+
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0")
|
45 |
+
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
46 |
+
_exact_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
|
47 |
+
return out
|
48 |
+
pass
|
49 |
+
|
50 |
+
|
51 |
+
@triton.jit
|
52 |
+
def _exact_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
|
53 |
+
"""
|
54 |
+
f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
|
55 |
+
h = f * up
|
56 |
+
|
57 |
+
df/de (with help of Wolfram :)
|
58 |
+
df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2)
|
59 |
+
|
60 |
+
Reuse via
|
61 |
+
f = 1/2 * (1 + erf(1/sqrt(2) * e)) * e
|
62 |
+
"""
|
63 |
+
block_idx = tl.program_id(0)
|
64 |
+
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
65 |
+
mask = offsets < n_elements
|
66 |
+
|
67 |
+
DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
|
68 |
+
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
|
69 |
+
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
|
70 |
+
|
71 |
+
# Break e_row away for re-use
|
72 |
+
# f = 1/2 * e * (1 + erf(1/sqrt(2) * e))
|
73 |
+
f_partial_row = 0.5 * (tl.math.erf(tl.math.rsqrt(2.0) * e_row) + 1.0)
|
74 |
+
f_row = f_partial_row * e_row
|
75 |
+
|
76 |
+
f_row = f_row.to(DW_row.dtype)
|
77 |
+
# h = f * g
|
78 |
+
h_row = f_row * g_row
|
79 |
+
# df = DW * f
|
80 |
+
df_row = DW_row * f_row
|
81 |
+
# dg = DW * g
|
82 |
+
dg_row = DW_row * g_row
|
83 |
+
|
84 |
+
# df/de = 1/2 * (1 + erf(1/sqrt(2) * e)) + 1/sqrt(2*pi) * e * exp(-1/2 * e^2)
|
85 |
+
t = 0.3989422804014327 # 1/sqrt(2*pi)
|
86 |
+
df_de = f_partial_row + t * e_row * tl.exp(-0.5 * e_row * e_row)
|
87 |
+
|
88 |
+
de_row = dg_row.to(tl.float32) * df_de
|
89 |
+
de_row = de_row.to(DW_row.dtype)
|
90 |
+
|
91 |
+
# Store derivatives in buffers
|
92 |
+
tl.store(DW + offsets, h_row, mask = mask) # h = f * g
|
93 |
+
tl.store(e + offsets, df_row, mask = mask) # df = DW * f
|
94 |
+
tl.store(g + offsets, de_row, mask = mask) # de
|
95 |
+
pass
|
96 |
+
|
97 |
+
|
98 |
+
def geglu_exact_backward_kernel(DW, e, g):
|
99 |
+
batch_seq_len, hd = e.shape
|
100 |
+
n_elements = e.numel()
|
101 |
+
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
102 |
+
_exact_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
|
103 |
+
return DW, e, g
|
104 |
+
pass
|
105 |
+
|
106 |
+
|
107 |
+
@triton.jit
|
108 |
+
def _approx_forward_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
|
109 |
+
block_idx = tl.program_id(0)
|
110 |
+
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
111 |
+
mask = offsets < n_elements
|
112 |
+
|
113 |
+
# f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) ))
|
114 |
+
# f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ))
|
115 |
+
# h = f * up
|
116 |
+
s = 0.7978845608028654 # math.sqrt(2 / math.pi)
|
117 |
+
|
118 |
+
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
|
119 |
+
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
|
120 |
+
|
121 |
+
f_row = 0.5 * e_row * (
|
122 |
+
triton_tanh(s * e_row * (1.0 + 0.044715 * e_row * e_row)) \
|
123 |
+
+ 1.0
|
124 |
+
)
|
125 |
+
f_row = f_row.to(g_row.dtype) # Exact copy from HF
|
126 |
+
h_row = f_row * g_row
|
127 |
+
|
128 |
+
# Store h
|
129 |
+
tl.store(h + offsets, h_row, mask = mask)
|
130 |
+
pass
|
131 |
+
|
132 |
+
|
133 |
+
def geglu_approx_forward_kernel(gate, up):
|
134 |
+
batch, seq_len, hd = gate.shape
|
135 |
+
n_elements = gate.numel()
|
136 |
+
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = "cuda:0")
|
137 |
+
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
138 |
+
_approx_forward_kernel[grid](gate, up, out, n_elements, BLOCK_SIZE = 1024,)
|
139 |
+
return out
|
140 |
+
pass
|
141 |
+
|
142 |
+
|
143 |
+
@triton.jit
|
144 |
+
def _approx_backward_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
|
145 |
+
"""
|
146 |
+
f = 1/2 * e * (1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ))
|
147 |
+
h = f * up
|
148 |
+
|
149 |
+
df/de (with help from https://arxiv.org/pdf/2305.12073.pdf :))
|
150 |
+
df/de = 1/2 * [1 + tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )] +
|
151 |
+
1/2 * sech^2 [ sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) ] * \
|
152 |
+
( sqrt(2/pi) * x * (1 + 0.044715 * x^2 * 3 ) )
|
153 |
+
|
154 |
+
Notice sech^2(x) = 1 - tanh^2(x)
|
155 |
+
So reuse tanh( sqrt(2/pi) * x * (1 + 0.044715 * x^2 ) )
|
156 |
+
|
157 |
+
See https://www.desmos.com/calculator/nqprfoni6x
|
158 |
+
"""
|
159 |
+
block_idx = tl.program_id(0)
|
160 |
+
offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
161 |
+
mask = offsets < n_elements
|
162 |
+
|
163 |
+
DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
|
164 |
+
e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
|
165 |
+
g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)
|
166 |
+
|
167 |
+
# See https://www.desmos.com/calculator/nqprfoni6x
|
168 |
+
s = 0.7978845608028654 # math.sqrt(2 / math.pi)
|
169 |
+
a = s * e_row # a = sqrt(2 / pi) * x
|
170 |
+
b = a * 0.044715 * e_row * e_row # b = a * 0.044715 * x^2
|
171 |
+
T = 1.0 + triton_tanh(a + b)
|
172 |
+
T2 = 0.5 * T
|
173 |
+
# Q = 0.5 * -T * (T - 2.0) * (a + 3.0 * b)
|
174 |
+
Q2 = -T2 * (T - 2.0) * (a + 3.0 * b)
|
175 |
+
df_de = T2 + Q2 # 1/2 * (T + Q)
|
176 |
+
|
177 |
+
# f = 1/2 * e * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3 ) ))
|
178 |
+
f_row = T2 * e_row
|
179 |
+
f_row = f_row.to(DW_row.dtype)
|
180 |
+
# h = f * g
|
181 |
+
h_row = f_row * g_row
|
182 |
+
# df = DW * f
|
183 |
+
df_row = DW_row * f_row
|
184 |
+
# dg = DW * g
|
185 |
+
dg_row = DW_row * g_row
|
186 |
+
|
187 |
+
de_row = dg_row.to(tl.float32) * df_de
|
188 |
+
de_row = de_row.to(DW_row.dtype)
|
189 |
+
|
190 |
+
# Store derivatives in buffers
|
191 |
+
tl.store(DW + offsets, h_row, mask = mask) # h = f * g
|
192 |
+
tl.store(e + offsets, df_row, mask = mask) # df = DW * f
|
193 |
+
tl.store(g + offsets, de_row, mask = mask) # de
|
194 |
+
pass
|
195 |
+
|
196 |
+
|
197 |
+
def geglu_approx_backward_kernel(DW, e, g):
|
198 |
+
batch_seq_len, hd = e.shape
|
199 |
+
n_elements = e.numel()
|
200 |
+
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
201 |
+
_approx_backward_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
|
202 |
+
return DW, e, g
|
203 |
+
pass
|
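For reference, a minimal eager PyTorch sketch of what the exact GEGLU forward kernel in geglu.py above computes (an illustration under stated assumptions, not part of the uploaded file):

import torch
def geglu_exact_reference(gate, up):
    # f = 0.5 * gate * (1 + erf(gate / sqrt(2))), computed in float32,
    # cast back to the dtype of `up`, then h = f * up (matches _exact_forward_kernel).
    f = 0.5 * gate.float() * (torch.erf(gate.float() * 0.7071067811865476) + 1.0)
    return f.to(up.dtype) * up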
unsloth-main/unsloth-main/unsloth/kernels/layernorm.py
ADDED
@@ -0,0 +1,231 @@
1 |
+
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
2 |
+
# Copyright 2024-present Andrej Karpathy & the llm.c team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import triton
|
17 |
+
import triton.language as tl
|
18 |
+
import torch
|
19 |
+
from .utils import calculate_settings
|
20 |
+
|
21 |
+
|
22 |
+
@triton.jit
|
23 |
+
def layernorm_forward(
|
24 |
+
Y, Y_row_stride,
|
25 |
+
X, X_row_stride,
|
26 |
+
W,
|
27 |
+
b,
|
28 |
+
r,
|
29 |
+
mu,
|
30 |
+
n_cols, eps,
|
31 |
+
BLOCK_SIZE : tl.constexpr
|
32 |
+
):
|
33 |
+
row_idx = tl.program_id(0)
|
34 |
+
col_offsets = tl.arange(0, BLOCK_SIZE)
|
35 |
+
mask = col_offsets < n_cols
|
36 |
+
|
37 |
+
Y += row_idx * Y_row_stride
|
38 |
+
X += row_idx * X_row_stride
|
39 |
+
r += row_idx
|
40 |
+
mu += row_idx
|
41 |
+
|
42 |
+
# According to https://pytorch.org/torchtune/stable/_modules/torchtune/modules/layer_norm.html#Fp32LayerNorm, all modules
|
43 |
+
# are in float32!
|
44 |
+
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
|
45 |
+
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
|
46 |
+
b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32)
|
47 |
+
|
48 |
+
mean_X = tl.sum(X_row, axis = 0) / n_cols
|
49 |
+
XX = X_row - mean_X
|
50 |
+
row_var = tl.sum(XX * XX, axis = 0) / n_cols
|
51 |
+
inv_var = tl.math.rsqrt(row_var + eps)
|
52 |
+
tl.store (r, inv_var)
|
53 |
+
tl.store (mu, mean_X)
|
54 |
+
output = (XX * inv_var) * W_row + b_row
|
55 |
+
tl.store(Y + col_offsets, output, mask = mask)
|
56 |
+
pass
|
57 |
+
|
58 |
+
|
59 |
+
@triton.jit
|
60 |
+
def layernorm_backward(
|
61 |
+
dY, dY_row_stride,
|
62 |
+
X, X_row_stride,
|
63 |
+
W,
|
64 |
+
b,
|
65 |
+
r,
|
66 |
+
mu,
|
67 |
+
n_cols, eps,
|
68 |
+
BLOCK_SIZE : tl.constexpr
|
69 |
+
):
|
70 |
+
# Approximately follows https://github.com/karpathy/llm.c/blob/master/doc/layernorm/layernorm.md
|
71 |
+
row_idx = tl.program_id(0)
|
72 |
+
col_offsets = tl.arange(0, BLOCK_SIZE)
|
73 |
+
mask = col_offsets < n_cols
|
74 |
+
|
75 |
+
dY += row_idx * dY_row_stride
|
76 |
+
X += row_idx * X_row_stride
|
77 |
+
r += row_idx
|
78 |
+
mu += row_idx
|
79 |
+
|
80 |
+
# According to https://pytorch.org/torchtune/stable/_modules/torchtune/modules/layer_norm.html#Fp32LayerNorm, all modules
|
81 |
+
# are in float32!
|
82 |
+
dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
|
83 |
+
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
|
84 |
+
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
|
85 |
+
b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32)
|
86 |
+
|
87 |
+
inv_var = tl.load(r) .to(tl.float32)
|
88 |
+
mean = tl.load(mu).to(tl.float32)
|
89 |
+
normed = (X_row - mean) * inv_var
|
90 |
+
dY_W = dY_row * W_row
|
91 |
+
dX_row = dY_W - tl.sum(dY_W, axis = 0) / n_cols - normed * tl.sum(dY_W * normed, axis = 0) / n_cols
|
92 |
+
dX_row = dX_row * inv_var
|
93 |
+
tl.store(dY + col_offsets, dX_row, mask = mask)
|
94 |
+
pass
|
95 |
+
|
96 |
+
|
97 |
+
class Fast_Layernorm(torch.autograd.Function):
|
98 |
+
@staticmethod
|
99 |
+
def forward(ctx, X, W, b, eps):
|
100 |
+
shape = X.shape
|
101 |
+
dim = shape[-1]
|
102 |
+
X = X.view(-1, dim)
|
103 |
+
n_rows, n_cols = X.shape
|
104 |
+
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
105 |
+
|
106 |
+
Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0")
|
107 |
+
r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
|
108 |
+
mu = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
|
109 |
+
|
110 |
+
layernorm_forward[(n_rows,)](
|
111 |
+
Y, Y.stride(0),
|
112 |
+
X, X.stride(0),
|
113 |
+
W,
|
114 |
+
b,
|
115 |
+
r,
|
116 |
+
mu,
|
117 |
+
n_cols, eps,
|
118 |
+
BLOCK_SIZE = BLOCK_SIZE,
|
119 |
+
num_warps = num_warps,
|
120 |
+
)
|
121 |
+
ctx.eps = eps
|
122 |
+
ctx.BLOCK_SIZE = BLOCK_SIZE
|
123 |
+
ctx.num_warps = num_warps
|
124 |
+
ctx.save_for_backward(X, W, b, r, mu)
|
125 |
+
return Y.view(*shape)
|
126 |
+
pass
|
127 |
+
|
128 |
+
@staticmethod
|
129 |
+
def backward(ctx, dY):
|
130 |
+
shape = dY.shape
|
131 |
+
dim = shape[-1]
|
132 |
+
dY = dY.view(-1, dim)
|
133 |
+
X, W, b, r, mu = ctx.saved_tensors
|
134 |
+
n_rows, n_cols = dY.shape
|
135 |
+
|
136 |
+
layernorm_backward[(n_rows,)](
|
137 |
+
dY, dY.stride(0),
|
138 |
+
X, X .stride(0),
|
139 |
+
W,
|
140 |
+
b,
|
141 |
+
r,
|
142 |
+
mu,
|
143 |
+
n_cols, ctx.eps,
|
144 |
+
BLOCK_SIZE = ctx.BLOCK_SIZE,
|
145 |
+
num_warps = ctx.num_warps,
|
146 |
+
)
|
147 |
+
dX = dY.view(*shape)
|
148 |
+
return dX, None, None, None, None
|
149 |
+
pass
|
150 |
+
pass
|
151 |
+
|
152 |
+
|
153 |
+
def fast_layernorm(layernorm, X):
|
154 |
+
assert(layernorm.elementwise_affine is True)
|
155 |
+
W = layernorm.weight
|
156 |
+
bias = layernorm.bias
|
157 |
+
eps = layernorm.variance_epsilon if \
|
158 |
+
hasattr(layernorm, "variance_epsilon") \
|
159 |
+
else layernorm.eps
|
160 |
+
out = Fast_Layernorm.apply(X, W, bias, eps)
|
161 |
+
return out
|
162 |
+
pass
|
163 |
+
|
164 |
+
|
165 |
+
from torch.nn import LayerNorm
|
166 |
+
class Unsloth_LayerNorm(LayerNorm):
|
167 |
+
def forward(self, X):
|
168 |
+
return fast_layernorm(self, X)
|
169 |
+
pass
|
170 |
+
pass
|
171 |
+
|
172 |
+
|
173 |
+
def patch_layernorm():
|
174 |
+
import torch.nn
|
175 |
+
torch.nn.LayerNorm = Unsloth_LayerNorm
|
176 |
+
return
|
177 |
+
pass
|
178 |
+
|
179 |
+
|
180 |
+
def unpatch_layernorm():
|
181 |
+
import torch.nn
|
182 |
+
torch.nn.LayerNorm = LayerNorm
|
183 |
+
return
|
184 |
+
pass
|
185 |
+
|
186 |
+
|
187 |
+
def test_layernorm(
|
188 |
+
dim = 1024, eps = 1e-5, dtype = torch.float16,
|
189 |
+
bsz = 21, random_state = 3407, seqlen = 3341,
|
190 |
+
):
|
191 |
+
from torch.nn import LayerNorm
|
192 |
+
layernorm = LayerNorm((dim,), eps = eps, device = "cuda", dtype = dtype)
|
193 |
+
torch.cuda.manual_seed(random_state)
|
194 |
+
torch.manual_seed(random_state)
|
195 |
+
torch.nn.init.uniform_(layernorm.weight)
|
196 |
+
torch.nn.init.uniform_(layernorm.bias)
|
197 |
+
X = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda")
|
198 |
+
XX = X.clone()
|
199 |
+
X .requires_grad_(True)
|
200 |
+
XX.requires_grad_(True)
|
201 |
+
Y = layernorm(X)
|
202 |
+
YY = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda", requires_grad = True)
|
203 |
+
Y.backward(YY)
|
204 |
+
correct_grad = X.grad.clone()
|
205 |
+
# from unsloth.kernels import fast_layernorm
|
206 |
+
Y = fast_layernorm(layernorm, XX)
|
207 |
+
Y.backward(YY)
|
208 |
+
assert(torch.dist(correct_grad, XX.grad).item() <= 0.1)
|
209 |
+
pass
|
210 |
+
|
211 |
+
|
212 |
+
def testing_suite_layernorm():
|
213 |
+
for dim in [512, 1024, 2048]:
|
214 |
+
for dtype in [torch.float16, torch.bfloat16]:
|
215 |
+
with torch.autocast(device_type = "cuda", dtype = dtype):
|
216 |
+
for seqlen in [3341, 2048, 349]:
|
217 |
+
for random_state in [3407, 42]:
|
218 |
+
test_layernorm(
|
219 |
+
dim = dim,
|
220 |
+
eps = 1e-5,
|
221 |
+
dtype = dtype,
|
222 |
+
bsz = 21,
|
223 |
+
random_state = random_state,
|
224 |
+
seqlen = seqlen,
|
225 |
+
)
|
226 |
+
pass
|
227 |
+
pass
|
228 |
+
pass
|
229 |
+
pass
|
230 |
+
pass
|
231 |
+
pass
|
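A quick equivalence sketch for the layernorm.py kernels above (illustrative only; assumes a CUDA device and Triton installed): the fast path should agree with torch.nn.functional.layer_norm up to rounding error.

import torch
X  = torch.randn(2, 128, 1024, dtype = torch.float16, device = "cuda")
ln = torch.nn.LayerNorm(1024, eps = 1e-5).to("cuda", torch.float16)
ref  = torch.nn.functional.layer_norm(X, (1024,), ln.weight, ln.bias, 1e-5)
fast = fast_layernorm(ln, X)          # Triton-backed Fast_Layernorm
print(torch.dist(ref.float(), fast.float()))  # expected to be small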
unsloth-main/unsloth-main/unsloth/kernels/rms_layernorm.py
ADDED
@@ -0,0 +1,283 @@
1 |
+
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import triton
|
16 |
+
import triton.language as tl
|
17 |
+
import torch
|
18 |
+
from .utils import calculate_settings
|
19 |
+
|
20 |
+
|
21 |
+
@triton.jit
|
22 |
+
def _rms_layernorm_forward(
|
23 |
+
Y, Y_row_stride,
|
24 |
+
X, X_row_stride,
|
25 |
+
W, W_row_stride,
|
26 |
+
r, r_row_stride,
|
27 |
+
n_cols, eps,
|
28 |
+
BLOCK_SIZE : tl.constexpr
|
29 |
+
):
|
30 |
+
"""
|
31 |
+
Fast RMS Layernorm kernel
|
32 |
+
Inspiration from a Triton tutorial:
|
33 |
+
https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
34 |
+
"""
|
35 |
+
row_idx = tl.program_id(0)
|
36 |
+
col_offsets = tl.arange(0, BLOCK_SIZE)
|
37 |
+
mask = col_offsets < n_cols
|
38 |
+
|
39 |
+
Y += row_idx * Y_row_stride
|
40 |
+
X += row_idx * X_row_stride
|
41 |
+
r += row_idx * r_row_stride
|
42 |
+
|
43 |
+
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
|
44 |
+
W_row = tl.load(W + col_offsets, mask = mask, other = 0)#.to(tl.float32)
|
45 |
+
|
46 |
+
row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
|
47 |
+
inv_var = tl.math.rsqrt(row_var + eps)
|
48 |
+
tl.store(r, inv_var)
|
49 |
+
normed = X_row * inv_var
|
50 |
+
normed = normed.to(W_row.dtype) # Exact copy from HF
|
51 |
+
output = normed * W_row
|
52 |
+
tl.store(Y + col_offsets, output, mask = mask)
|
53 |
+
pass
|
54 |
+
|
55 |
+
|
56 |
+
@triton.heuristics({"GEMMA": lambda args: args["GEMMA"],})
|
57 |
+
@triton.jit
|
58 |
+
def _rms_layernorm_backward(
|
59 |
+
dY, dY_row_stride,
|
60 |
+
X, X_row_stride,
|
61 |
+
W, W_row_stride,
|
62 |
+
r, r_row_stride,
|
63 |
+
dW, dW_row_stride,
|
64 |
+
n_cols, eps,
|
65 |
+
GEMMA : tl.constexpr,
|
66 |
+
BLOCK_SIZE : tl.constexpr,
|
67 |
+
):
|
68 |
+
"""
|
69 |
+
Fast RMS Layernorm kernel for the backward pass
|
70 |
+
Inspiration from a Triton tutorial:
|
71 |
+
https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
72 |
+
"""
|
73 |
+
row_idx = tl.program_id(0)
|
74 |
+
col_offsets = tl.arange(0, BLOCK_SIZE)
|
75 |
+
mask = col_offsets < n_cols
|
76 |
+
|
77 |
+
dY += row_idx * dY_row_stride
|
78 |
+
X += row_idx * X_row_stride
|
79 |
+
r += row_idx * r_row_stride
|
80 |
+
|
81 |
+
dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
|
82 |
+
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
|
83 |
+
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
|
84 |
+
|
85 |
+
# Get saved row variance
|
86 |
+
inv_var = tl.load(r).to(tl.float32)
|
87 |
+
normed = X_row * inv_var
|
88 |
+
|
89 |
+
if GEMMA: dY_W = dY_row * (W_row + 1.0)
|
90 |
+
else: dY_W = dY_row * W_row
|
91 |
+
|
92 |
+
rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)
|
93 |
+
output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)
|
94 |
+
tl.store(dY + col_offsets, output, mask = mask)
|
95 |
+
pass
|
96 |
+
|
97 |
+
|
98 |
+
@triton.jit
|
99 |
+
def _gemma_rms_layernorm_forward(
|
100 |
+
Y, Y_row_stride,
|
101 |
+
X, X_row_stride,
|
102 |
+
W, W_row_stride,
|
103 |
+
r, r_row_stride,
|
104 |
+
n_cols, eps,
|
105 |
+
BLOCK_SIZE : tl.constexpr,
|
106 |
+
):
|
107 |
+
# Copies https://github.com/google-deepmind/gemma/blob/main/gemma/layers.py#L31
|
108 |
+
# and https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L33
|
109 |
+
# exactly. Essentially all in float32!
|
110 |
+
row_idx = tl.program_id(0)
|
111 |
+
col_offsets = tl.arange(0, BLOCK_SIZE)
|
112 |
+
mask = col_offsets < n_cols
|
113 |
+
|
114 |
+
Y += row_idx * Y_row_stride
|
115 |
+
X += row_idx * X_row_stride
|
116 |
+
r += row_idx * r_row_stride
|
117 |
+
|
118 |
+
X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
|
119 |
+
W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
|
120 |
+
|
121 |
+
row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
|
122 |
+
inv_var = tl.math.rsqrt(row_var + eps)
|
123 |
+
tl.store(r, inv_var)
|
124 |
+
normed = X_row * inv_var
|
125 |
+
output = normed * (W_row + 1.0)
|
126 |
+
|
127 |
+
tl.store(Y + col_offsets, output, mask = mask)
|
128 |
+
pass
|
129 |
+
|
130 |
+
|
131 |
+
class Fast_RMS_Layernorm(torch.autograd.Function):
|
132 |
+
@staticmethod
|
133 |
+
def forward(ctx, X, W, eps, gemma = False):
|
134 |
+
shape = X.shape
|
135 |
+
dim = shape[-1]
|
136 |
+
X = X.view(-1, dim)
|
137 |
+
n_rows, n_cols = X.shape
|
138 |
+
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
139 |
+
|
140 |
+
Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = "cuda:0")
|
141 |
+
r = torch.empty(n_rows, dtype = torch.float32, device = "cuda:0")
|
142 |
+
|
143 |
+
fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward
|
144 |
+
fx[(n_rows,)](
|
145 |
+
Y, Y.stride(0),
|
146 |
+
X, X.stride(0),
|
147 |
+
W, W.stride(0),
|
148 |
+
r, r.stride(0),
|
149 |
+
n_cols, eps,
|
150 |
+
BLOCK_SIZE = BLOCK_SIZE,
|
151 |
+
num_warps = num_warps,
|
152 |
+
)
|
153 |
+
ctx.eps = eps
|
154 |
+
ctx.BLOCK_SIZE = BLOCK_SIZE
|
155 |
+
ctx.num_warps = num_warps
|
156 |
+
ctx.GEMMA = gemma
|
157 |
+
ctx.save_for_backward(X, W, r)
|
158 |
+
return Y.view(*shape)
|
159 |
+
pass
|
160 |
+
|
161 |
+
@staticmethod
|
162 |
+
def backward(ctx, dY):
|
163 |
+
shape = dY.shape
|
164 |
+
dim = shape[-1]
|
165 |
+
dY = dY.view(-1, dim)
|
166 |
+
X, W, r = ctx.saved_tensors
|
167 |
+
n_rows, n_cols = dY.shape
|
168 |
+
dW = X
|
169 |
+
|
170 |
+
_rms_layernorm_backward[(n_rows,)](
|
171 |
+
dY, dY.stride(0),
|
172 |
+
X, X .stride(0),
|
173 |
+
W, W .stride(0),
|
174 |
+
r, r .stride(0),
|
175 |
+
dW, dW.stride(0),
|
176 |
+
n_cols, ctx.eps,
|
177 |
+
GEMMA = ctx.GEMMA,
|
178 |
+
BLOCK_SIZE = ctx.BLOCK_SIZE,
|
179 |
+
num_warps = ctx.num_warps,
|
180 |
+
)
|
181 |
+
dX = dY.view(*shape)
|
182 |
+
return dX, None, None, None
|
183 |
+
pass
|
184 |
+
pass
|
185 |
+
|
186 |
+
|
187 |
+
def fast_rms_layernorm(layernorm, X, gemma = False):
|
188 |
+
W = layernorm.weight
|
189 |
+
eps = layernorm.variance_epsilon if \
|
190 |
+
hasattr(layernorm, "variance_epsilon") \
|
191 |
+
else layernorm.eps
|
192 |
+
out = Fast_RMS_Layernorm.apply(X, W, eps, gemma)
|
193 |
+
return out
|
194 |
+
pass
|
195 |
+
|
196 |
+
|
197 |
+
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
198 |
+
class Unsloth_LlamaRMSNorm(LlamaRMSNorm):
|
199 |
+
def forward(self, X):
|
200 |
+
return fast_rms_layernorm(self, X, gemma = False)
|
201 |
+
pass
|
202 |
+
pass
|
203 |
+
|
204 |
+
try:
|
205 |
+
from transformers.models.mllama.modeling_mllama import MllamaTextRMSNorm
|
206 |
+
class Unsloth_MllamaTextRMSNorm(MllamaTextRMSNorm):
|
207 |
+
def forward(self, X):
|
208 |
+
return fast_rms_layernorm(self, X, gemma = False)
|
209 |
+
pass
|
210 |
+
pass
|
211 |
+
except:
|
212 |
+
pass
|
213 |
+
pass
|
214 |
+
|
215 |
+
def patch_rms_layernorm():
|
216 |
+
import transformers.models.llama.modeling_llama
|
217 |
+
transformers.models.llama.modeling_llama.LlamaRMSNorm = Unsloth_LlamaRMSNorm
|
218 |
+
try:
|
219 |
+
import transformers.models.mllama.modeling_mllama
|
220 |
+
transformers.models.mllama.modeling_mllama.MllamaTextRMSNorm = Unsloth_MllamaTextRMSNorm
|
221 |
+
except:
|
222 |
+
pass
|
223 |
+
return
|
224 |
+
pass
|
225 |
+
|
226 |
+
|
227 |
+
def unpatch_rms_layernorm():
|
228 |
+
import transformers.models.llama.modeling_llama
|
229 |
+
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
230 |
+
try:
|
231 |
+
import transformers.models.mllama.modeling_mllama
|
232 |
+
transformers.models.mllama.modeling_mllama.MllamaTextRMSNorm = MllamaTextRMSNorm
|
233 |
+
except:
|
234 |
+
pass
|
235 |
+
return
|
236 |
+
return
|
237 |
+
pass
|
238 |
+
|
239 |
+
|
240 |
+
def test_rms_layernorm(
|
241 |
+
dim = 1024, eps = 1e-5, dtype = torch.float16,
|
242 |
+
bsz = 21, random_state = 3407, seqlen = 3341,
|
243 |
+
):
|
244 |
+
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
245 |
+
layernorm = LlamaRMSNorm((dim,), eps = eps).to("cuda")
|
246 |
+
torch.cuda.manual_seed(random_state)
|
247 |
+
torch.manual_seed(random_state)
|
248 |
+
torch.nn.init.uniform_(layernorm.weight)
|
249 |
+
X = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda")
|
250 |
+
XX = X.clone()
|
251 |
+
X .requires_grad_(True)
|
252 |
+
XX.requires_grad_(True)
|
253 |
+
Y = layernorm(X)
|
254 |
+
YY = torch.randn((bsz, seqlen, dim), dtype = dtype, device = "cuda", requires_grad = True)
|
255 |
+
Y.backward(YY)
|
256 |
+
correct_grad = X.grad.clone()
|
257 |
+
# from unsloth.kernels import fast_rms_layernorm
|
258 |
+
Y = fast_rms_layernorm(layernorm, XX)
|
259 |
+
Y.backward(YY)
|
260 |
+
assert(torch.amax(correct_grad - XX.grad).item() <= 0.05)
|
261 |
+
pass
|
262 |
+
|
263 |
+
|
264 |
+
def testing_suite_layernorm():
|
265 |
+
for dim in [512, 1024, 2048]:
|
266 |
+
for dtype in [torch.float16, torch.bfloat16]:
|
267 |
+
with torch.autocast(device_type = "cuda", dtype = dtype):
|
268 |
+
for seqlen in [3341, 2048, 349]:
|
269 |
+
for random_state in [3407, 42]:
|
270 |
+
test_rms_layernorm(
|
271 |
+
dim = dim,
|
272 |
+
eps = 1e-5,
|
273 |
+
dtype = dtype,
|
274 |
+
bsz = 21,
|
275 |
+
random_state = random_state,
|
276 |
+
seqlen = seqlen,
|
277 |
+
)
|
278 |
+
pass
|
279 |
+
pass
|
280 |
+
pass
|
281 |
+
pass
|
282 |
+
pass
|
283 |
+
pass
|
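For reference, an eager PyTorch sketch of what the rms_layernorm.py forward kernels above compute, including the Gemma variant that adds 1.0 to the weight (an illustration, not part of the uploaded file):

import torch
def rms_norm_reference(X, weight, eps, gemma = False):
    # Row-wise RMS normalisation in float32, matching _rms_layernorm_forward
    # and _gemma_rms_layernorm_forward.
    X32 = X.float()
    inv_var = torch.rsqrt(X32.pow(2).mean(-1, keepdim = True) + eps)
    normed = X32 * inv_var
    if gemma:
        return (normed * (weight.float() + 1.0)).to(X.dtype)
    return (normed.to(weight.dtype) * weight).to(X.dtype)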
unsloth-main/unsloth-main/unsloth/kernels/rope_embedding.py
ADDED
@@ -0,0 +1,181 @@
1 |
+
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import triton
|
16 |
+
import triton.language as tl
|
17 |
+
import torch
|
18 |
+
from .utils import calculate_settings
|
19 |
+
ROPE_GROUP_SIZE = 4
|
20 |
+
|
21 |
+
@triton.heuristics({"BACKWARD_PASS": lambda args: args["BACKWARD_PASS"],})
|
22 |
+
@triton.jit
|
23 |
+
def _rope_embedding(
|
24 |
+
Q, Q_row_stride,
|
25 |
+
cos, cos_row_stride,
|
26 |
+
sin, sin_row_stride,
|
27 |
+
seqlen,
|
28 |
+
head_dim : tl.constexpr,
|
29 |
+
n_heads : tl.constexpr,
|
30 |
+
BACKWARD_PASS : tl.constexpr,
|
31 |
+
BLOCK_SIZE : tl.constexpr,
|
32 |
+
):
|
33 |
+
"""
|
34 |
+
Calculates the RoPE Embedding quickly
|
35 |
+
RoPE is Q * cos + rotate_half(Q) * sin
|
36 |
+
See our blog post for more info
|
37 |
+
"""
|
38 |
+
ROPE_GROUP_SIZE = 4
|
39 |
+
row_position = tl.program_id(0)
|
40 |
+
group_head_position = tl.program_id(1)
|
41 |
+
col_offsets = tl.arange(0, BLOCK_SIZE)
|
42 |
+
half_head_dim = head_dim // 2
|
43 |
+
mask = col_offsets < half_head_dim
|
44 |
+
|
45 |
+
sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \
|
46 |
+
half_head_dim*0 + col_offsets, mask = mask, other = 0)
|
47 |
+
cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \
|
48 |
+
half_head_dim*0 + col_offsets, mask = mask, other = 0)
|
49 |
+
|
50 |
+
if BACKWARD_PASS:
|
51 |
+
# See our blog post for more info.
|
52 |
+
sin1 = -sin1
|
53 |
+
pass
|
54 |
+
|
55 |
+
# [TODO] Autotune ROPE_GROUP_SIZE to be 1, 2, 4, 8
|
56 |
+
head_start = group_head_position * ROPE_GROUP_SIZE
|
57 |
+
head_end = min((head_start + ROPE_GROUP_SIZE), n_heads)
|
58 |
+
|
59 |
+
# 10% Faster kernel from [HuyNguyen-hust](https://github.com/unslothai/unsloth/pull/238)
|
60 |
+
for k in range(head_start, head_end):
|
61 |
+
offs_q1 = row_position * Q_row_stride + k * head_dim + col_offsets
|
62 |
+
offs_q2 = row_position * Q_row_stride + k * head_dim + col_offsets + half_head_dim
|
63 |
+
|
64 |
+
# For Gemma - sometimes RoPE must be done in float32 and not bfloat16
|
65 |
+
Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)
|
66 |
+
Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)
|
67 |
+
|
68 |
+
tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask)
|
69 |
+
tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)
|
70 |
+
pass
|
71 |
+
pass
|
72 |
+
|
73 |
+
|
74 |
+
class Fast_RoPE_Embedding(torch.autograd.Function):
|
75 |
+
@staticmethod
|
76 |
+
def forward(ctx, Q, cos, sin):
|
77 |
+
cos, sin = cos.squeeze(), sin.squeeze()
|
78 |
+
batch, seq_len, n_heads, head_dim = Q.shape
|
79 |
+
Q = Q.view(batch*seq_len, n_heads*head_dim)
|
80 |
+
n_rows, n_cols = Q.shape
|
81 |
+
assert(seq_len <= cos.shape[0])
|
82 |
+
|
83 |
+
# [TODO] Changing blocksize to head_dim//2 seems to have
|
84 |
+
# some concurrency / un-deterministic issues.
|
85 |
+
BLOCK_SIZE, num_warps = calculate_settings(head_dim//2) # (head_dim//2)
|
86 |
+
|
87 |
+
# group_size = 4 # 4 or 8, too large group_size can hurt performance.
|
88 |
+
div, mod = divmod(n_heads, ROPE_GROUP_SIZE)
|
89 |
+
n_groups = div + (mod != 0)
|
90 |
+
|
91 |
+
_rope_embedding[(n_rows, n_groups, )](
|
92 |
+
Q, Q.stride(0),
|
93 |
+
cos, cos.stride(0),
|
94 |
+
sin, sin.stride(0),
|
95 |
+
seq_len,
|
96 |
+
head_dim, n_heads,
|
97 |
+
BACKWARD_PASS = False,
|
98 |
+
BLOCK_SIZE = BLOCK_SIZE,
|
99 |
+
num_warps = num_warps,
|
100 |
+
)
|
101 |
+
ctx.BLOCK_SIZE = BLOCK_SIZE
|
102 |
+
ctx.num_warps = num_warps
|
103 |
+
ctx.n_groups = n_groups
|
104 |
+
ctx.cos = cos
|
105 |
+
ctx.sin = sin
|
106 |
+
return Q.view(batch, seq_len, n_heads, head_dim)
|
107 |
+
pass
|
108 |
+
|
109 |
+
@staticmethod
|
110 |
+
def backward(ctx, dY):
|
111 |
+
batch, seq_len, n_heads, head_dim = dY.shape
|
112 |
+
dY = dY.reshape(batch*seq_len, n_heads*head_dim)
|
113 |
+
# Must be reshape not view
|
114 |
+
n_rows, n_cols = dY.shape
|
115 |
+
|
116 |
+
cos = ctx.cos
|
117 |
+
sin = ctx.sin
|
118 |
+
|
119 |
+
_rope_embedding[(n_rows, ctx.n_groups, )](
|
120 |
+
dY, dY .stride(0),
|
121 |
+
cos, cos.stride(0),
|
122 |
+
sin, sin.stride(0),
|
123 |
+
seq_len, head_dim, n_heads,
|
124 |
+
BACKWARD_PASS = True,
|
125 |
+
BLOCK_SIZE = ctx.BLOCK_SIZE,
|
126 |
+
num_warps = ctx.num_warps,
|
127 |
+
)
|
128 |
+
dY = dY.view(batch, seq_len, n_heads, head_dim)
|
129 |
+
return dY, None, None,
|
130 |
+
pass
|
131 |
+
pass
|
132 |
+
|
133 |
+
|
134 |
+
def fast_rope_embedding(Q, K, cos, sin):
|
135 |
+
Q = Fast_RoPE_Embedding.apply(Q.transpose(1, 2), cos, sin).transpose(1, 2)
|
136 |
+
K = Fast_RoPE_Embedding.apply(K.transpose(1, 2), cos, sin).transpose(1, 2)
|
137 |
+
return Q, K
|
138 |
+
pass
|
139 |
+
|
140 |
+
|
141 |
+
class Slow_RoPE_Embedding(torch.autograd.Function):
|
142 |
+
@staticmethod
|
143 |
+
def forward(ctx, Q, cos, sin, position_ids):
|
144 |
+
if position_ids is not None:
|
145 |
+
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
146 |
+
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
147 |
+
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
148 |
+
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
149 |
+
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
150 |
+
|
151 |
+
# Q * cos + rotate_half(Q) * sin
|
152 |
+
half = Q.shape[-1]//2
|
153 |
+
RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)
|
154 |
+
Q *= cos
|
155 |
+
Q.addcmul_(RH_Q, sin)
|
156 |
+
# RH_Q *= sin
|
157 |
+
# Q += RH_Q
|
158 |
+
ctx.save_for_backward(cos, sin)
|
159 |
+
return Q
|
160 |
+
pass
|
161 |
+
|
162 |
+
@staticmethod
|
163 |
+
def backward(ctx, dY):
|
164 |
+
cos, sin = ctx.saved_tensors
|
165 |
+
# Q * cos + rotate_half.T(Q) * sin
|
166 |
+
half = dY.shape[-1]//2
|
167 |
+
RH_dY = torch.cat((dY[..., half:], -dY[..., :half]), dim = -1)
|
168 |
+
dY *= cos
|
169 |
+
dY.addcmul_(RH_dY, sin)
|
170 |
+
# RH_dY *= sin
|
171 |
+
# dY += RH_dY
|
172 |
+
return dY, None, None, None
|
173 |
+
pass
|
174 |
+
pass
|
175 |
+
|
176 |
+
|
177 |
+
def inplace_rope_embedding(Q, K, cos, sin, position_ids):
|
178 |
+
Q = Slow_RoPE_Embedding.apply(Q, cos, sin, position_ids)
|
179 |
+
K = Slow_RoPE_Embedding.apply(K, cos, sin, position_ids)
|
180 |
+
return Q, K
|
181 |
+
pass
|
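For reference, the identity the rope_embedding.py kernels above implement, written as a short eager sketch (illustrative only, not part of the uploaded file):

import torch
def rope_reference(Q, cos, sin):
    # RoPE is Q * cos + rotate_half(Q) * sin, as in Slow_RoPE_Embedding.forward.
    half = Q.shape[-1] // 2
    RH_Q = torch.cat((-Q[..., half:], Q[..., :half]), dim = -1)
    return Q * cos + RH_Q * sin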
unsloth-main/unsloth-main/unsloth/kernels/swiglu.py
ADDED
@@ -0,0 +1,99 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import triton
import triton.language as tl
import torch
from .utils import calculate_settings


@triton.jit
def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE : tl.constexpr,):
    block_idx = tl.program_id(0)
    offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    e_row = tl.load(e + offsets, mask = mask, other = 0).to(tl.float32)
    g_row = tl.load(g + offsets, mask = mask, other = 0)#.to(tl.float32)

    # f = e * sigmoid(e)
    f_row = e_row * tl.sigmoid(e_row) # e_row / (1 + tl.exp(-e_row))
    f_row = f_row.to(g_row.dtype) # Exact copy from HF
    # h = f * g
    h_row = f_row * g_row

    # Store h
    tl.store(h + offsets, h_row, mask = mask)
pass


def swiglu_fg_kernel(e, g):
    batch, seq_len, hd = e.shape
    n_elements = e.numel()
    h = torch.empty((batch, seq_len, hd), dtype = e.dtype, device = "cuda:0")
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    _fg_kernel[grid](e, g, h, n_elements, BLOCK_SIZE = 1024,)
    return h
pass


@triton.jit
def _DWf_DW_dfg_kernel(DW, e, g, n_elements, BLOCK_SIZE : tl.constexpr,):
    """
    e = e.float()
    se = 1.0 / (1.0 + torch.exp(-e))
    f = (se * e).to(dtype)
    h = f * g
    df = DW * f
    dg = DW * g
    de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
    """
    block_idx = tl.program_id(0)
    offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    DW_row = tl.load(DW + offsets, mask = mask, other = 0)#.to(tl.float32)
    e_row  = tl.load(e  + offsets, mask = mask, other = 0).to(tl.float32)
    g_row  = tl.load(g  + offsets, mask = mask, other = 0)#.to(tl.float32)

    # e = e.float()
    # se = 1.0 / (1.0 + torch.exp(-e))
    se_row = tl.sigmoid(e_row) # 1.0 / (1.0 + tl.exp(-e_row))
    # f = (se * e).to(dtype)
    f_row = se_row * e_row
    f_row = f_row.to(DW_row.dtype)
    # h = f * g
    h_row = f_row * g_row
    # df = DW * f
    df_row = DW_row * f_row
    # dg = DW * g
    dg_row = DW_row * g_row
    # de = (dg.float() * se * (1.0 + e * (1.0 - se))).to(dtype)
    de_row = dg_row.to(tl.float32) * se_row * (1.0 + e_row * (1.0 - se_row))
    de_row = de_row.to(DW_row.dtype)

    # Store derivatives in buffers
    tl.store(DW + offsets, h_row,  mask = mask) # h  = f * g
    tl.store(e  + offsets, df_row, mask = mask) # df = DW * f
    tl.store(g  + offsets, de_row, mask = mask) # de
pass


def swiglu_DWf_DW_dfg_kernel(DW, e, g):
    batch_seq_len, hd = e.shape
    n_elements = e.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    _DWf_DW_dfg_kernel[grid](DW, e, g, n_elements, BLOCK_SIZE = 1024,)
    return DW, e, g
pass
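For reference, the SwiGLU forward above reduces to the following eager sketch (an illustration, not part of the uploaded file):

import torch
def swiglu_reference(e, g):
    # h = (e * sigmoid(e)) * g, with the gate computed in float32 and cast back,
    # matching _fg_kernel / swiglu_fg_kernel.
    f = (e.float() * torch.sigmoid(e.float())).to(g.dtype)
    return f * g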
unsloth-main/unsloth-main/unsloth/kernels/utils.py
ADDED
@@ -0,0 +1,416 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import triton
MAX_FUSED_SIZE = 65536
next_power_of_2 = triton.next_power_of_2

# torch.cuda.amp.custom_fwd is deprecated >= 2.4
import torch
from packaging.version import Version
if Version(torch.__version__) < Version("2.4.0"):
    torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
    torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
else:
    torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
    torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
pass


# tl.math.tanh now is libdevice.tanh
from packaging.version import Version
import triton
if Version(triton.__version__) >= Version("3.0.0"):
    from triton.language.extra import libdevice
    triton_tanh = libdevice.tanh
else:
    import triton.language as tl
    triton_tanh = tl.math.tanh
pass


def calculate_settings(n):
    BLOCK_SIZE = next_power_of_2(n)
    if BLOCK_SIZE > MAX_FUSED_SIZE:
        raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
                           f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
    num_warps = 4
    if   BLOCK_SIZE >= 32768: num_warps = 32
    elif BLOCK_SIZE >=  8192: num_warps = 16
    elif BLOCK_SIZE >=  2048: num_warps = 8
    return BLOCK_SIZE, num_warps
pass


import bitsandbytes as bnb
# https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1330/files
HAS_CUDA_STREAM = Version(bnb.__version__) > Version("0.43.3")
global CUDA_STREAM
CUDA_STREAM = None
get_ptr = bnb.functional.get_ptr
import ctypes
cdequantize_blockwise_fp32      = bnb.functional.lib.cdequantize_blockwise_fp32
cdequantize_blockwise_fp16_nf4  = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
cdequantize_blockwise_bf16_nf4  = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16
cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16


def QUANT_STATE(W):
    return getattr(W, "quant_state", None)
pass


def get_lora_parameters(proj):
    # For DPO or disabled adapters
    base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj)
    W = base_layer.weight

    if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
        return W, QUANT_STATE(W), None, None, None
    pass

    active_adapter = proj.active_adapters[0] if \
        hasattr(proj, "active_adapters") else proj.active_adapter
    A = proj.lora_A [active_adapter].weight
    B = proj.lora_B [active_adapter].weight
    s = proj.scaling[active_adapter]
    return W, QUANT_STATE(W), A, B, s
return W, QUANT_STATE(W), A, B, s
|
90 |
+
pass
|
91 |
+
|
92 |
+
|
93 |
+
def get_lora_parameters_bias(proj):
|
94 |
+
# For DPO or disabled adapters
|
95 |
+
base_layer = (proj.base_layer if hasattr(proj, "base_layer") else proj)
|
96 |
+
W = base_layer.weight
|
97 |
+
bias = base_layer.bias
|
98 |
+
|
99 |
+
if not hasattr(proj, "disable_adapters") or proj.disable_adapters or proj.merged:
|
100 |
+
return W, QUANT_STATE(W), None, None, None, bias
|
101 |
+
pass
|
102 |
+
|
103 |
+
active_adapter = proj.active_adapters[0] if \
|
104 |
+
hasattr(proj, "active_adapters") else proj.active_adapter
|
105 |
+
A = proj.lora_A [active_adapter].weight
|
106 |
+
B = proj.lora_B [active_adapter].weight
|
107 |
+
s = proj.scaling[active_adapter]
|
108 |
+
return W, QUANT_STATE(W), A, B, s, bias
|
109 |
+
pass
|
110 |
+
|
111 |
+
|
112 |
+
if HAS_CUDA_STREAM:
|
113 |
+
def fast_dequantize(W, quant_state = None, out = None):
|
114 |
+
if quant_state is None: return W
|
115 |
+
if type(quant_state) is not list:
|
116 |
+
# New quant_state as a class
|
117 |
+
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
|
118 |
+
absmax = quant_state.absmax
|
119 |
+
shape = quant_state.shape
|
120 |
+
dtype = quant_state.dtype
|
121 |
+
blocksize = quant_state.blocksize
|
122 |
+
offset = quant_state.offset
|
123 |
+
state2 = quant_state.state2
|
124 |
+
absmax2 = state2.absmax
|
125 |
+
code2 = state2.code
|
126 |
+
blocksize2 = state2.blocksize
|
127 |
+
else:
|
128 |
+
# Old quant_state as a list of lists
|
129 |
+
absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
|
130 |
+
offset, state2 = compressed_stats
|
131 |
+
absmax2, code2, blocksize2, _, _, _, _ = state2
|
132 |
+
pass
|
133 |
+
global CUDA_STREAM
|
134 |
+
if CUDA_STREAM is None: CUDA_STREAM = torch.cuda.current_stream("cuda:0")
|
135 |
+
|
136 |
+
# Create weight matrix
|
137 |
+
if out is None:
|
138 |
+
out = torch.empty(shape, dtype = dtype, device = "cuda:0")
|
139 |
+
else:
|
140 |
+
assert(out.shape == shape)
|
141 |
+
assert(out.dtype == dtype)
|
142 |
+
|
143 |
+
# NF4 dequantization of statistics
|
144 |
+
n_elements_absmax = absmax.numel()
|
145 |
+
out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0")
|
146 |
+
|
147 |
+
# Do dequantization
|
148 |
+
ptr_out_absmax = get_ptr(out_absmax)
|
149 |
+
cdequantize_blockwise_fp32(
|
150 |
+
get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
|
151 |
+
ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax), CUDA_STREAM,
|
152 |
+
)
|
153 |
+
out_absmax += offset
|
154 |
+
|
155 |
+
fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \
|
156 |
+
cdequantize_blockwise_bf16_nf4
|
157 |
+
fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
|
158 |
+
ctypes.c_int(blocksize), ctypes.c_int(out.numel()), CUDA_STREAM,)
|
159 |
+
|
160 |
+
# Careful returning transposed data
|
161 |
+
is_transposed = (True if W.shape[0] == 1 else False)
|
162 |
+
return out.t() if is_transposed else out
|
163 |
+
pass
|
164 |
+
else:
|
165 |
+
def fast_dequantize(W, quant_state = None, out = None):
|
166 |
+
if quant_state is None: return W
|
167 |
+
if type(quant_state) is not list:
|
168 |
+
# New quant_state as a class
|
169 |
+
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
|
170 |
+
absmax = quant_state.absmax
|
171 |
+
shape = quant_state.shape
|
172 |
+
dtype = quant_state.dtype
|
173 |
+
blocksize = quant_state.blocksize
|
174 |
+
offset = quant_state.offset
|
175 |
+
state2 = quant_state.state2
|
176 |
+
absmax2 = state2.absmax
|
177 |
+
code2 = state2.code
|
178 |
+
blocksize2 = state2.blocksize
|
179 |
+
else:
|
180 |
+
# Old quant_state as a list of lists
|
181 |
+
absmax, shape, dtype, blocksize, compressed_stats, _, _ = quant_state
|
182 |
+
offset, state2 = compressed_stats
|
183 |
+
absmax2, code2, blocksize2, _, _, _, _ = state2
|
184 |
+
pass
|
185 |
+
|
186 |
+
# Create weight matrix
|
187 |
+
if out is None:
|
188 |
+
out = torch.empty(shape, dtype = dtype, device = "cuda:0")
|
189 |
+
else:
|
190 |
+
assert(out.shape == shape)
|
191 |
+
assert(out.dtype == dtype)
|
192 |
+
|
193 |
+
# NF4 dequantization of statistics
|
194 |
+
n_elements_absmax = absmax.numel()
|
195 |
+
out_absmax = torch.empty(n_elements_absmax, dtype = torch.float32, device = "cuda:0")
|
196 |
+
|
197 |
+
# Do dequantization
|
198 |
+
ptr_out_absmax = get_ptr(out_absmax)
|
199 |
+
cdequantize_blockwise_fp32(
|
200 |
+
get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), ptr_out_absmax,
|
201 |
+
ctypes.c_int(blocksize2), ctypes.c_int(n_elements_absmax),
|
202 |
+
)
|
203 |
+
out_absmax += offset
|
204 |
+
|
205 |
+
fx = cdequantize_blockwise_fp16_nf4 if dtype == torch.float16 else \
|
206 |
+
cdequantize_blockwise_bf16_nf4
|
207 |
+
fx(get_ptr(None), get_ptr(W), ptr_out_absmax, get_ptr(out),
|
208 |
+
ctypes.c_int(blocksize), ctypes.c_int(out.numel()),)
|
209 |
+
|
210 |
+
# Careful returning transposed data
|
211 |
+
is_transposed = (True if W.shape[0] == 1 else False)
|
212 |
+
return out.t() if is_transposed else out
|
213 |
+
pass
|
214 |
+
pass
|
215 |
+
|
216 |
+
|
217 |
+
if HAS_CUDA_STREAM:
|
218 |
+
def fast_gemv(X, W, quant_state, out = None):
|
219 |
+
if quant_state is None: return torch.matmul(X, W, out = out)
|
220 |
+
# For fast X @ W where seq_len == 1
|
221 |
+
# From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
|
222 |
+
_, q_len, hd = X.shape
|
223 |
+
# assert(q_len == 1)
|
224 |
+
|
225 |
+
if type(quant_state) is not list:
|
226 |
+
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
|
227 |
+
absmax = quant_state.absmax
|
228 |
+
shape = quant_state.shape
|
229 |
+
dtype = quant_state.dtype
|
230 |
+
blocksize = quant_state.blocksize
|
231 |
+
stats = quant_state.code
|
232 |
+
offset = quant_state.offset
|
233 |
+
state2 = quant_state.state2
|
234 |
+
absmax2 = state2.absmax
|
235 |
+
code2 = state2.code
|
236 |
+
blocksize2 = state2.blocksize
|
237 |
+
else:
|
238 |
+
absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state
|
239 |
+
offset, state2 = compressed_stats
|
240 |
+
absmax2, code2, blocksize2, _, _, _, _ = state2
|
241 |
+
pass
|
242 |
+
global CUDA_STREAM
|
243 |
+
if CUDA_STREAM is None: CUDA_STREAM = torch.cuda.current_stream("cuda:0")
|
244 |
+
|
245 |
+
# assert(dtype == X.dtype)
|
246 |
+
bout = shape[0]
|
247 |
+
|
248 |
+
if out is None:
|
249 |
+
out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda:0")
|
250 |
+
# else:
|
251 |
+
# assert(out.shape == (1, 1, bout,))
|
252 |
+
# pass
|
253 |
+
|
254 |
+
n = 1
|
255 |
+
m = shape[0]
|
256 |
+
k = shape[1]
|
257 |
+
lda = shape[0]
|
258 |
+
ldc = shape[0]
|
259 |
+
ldb = (hd+1)//2
|
260 |
+
m = ctypes.c_int32(m)
|
261 |
+
n = ctypes.c_int32(n)
|
262 |
+
k = ctypes.c_int32(k)
|
263 |
+
lda = ctypes.c_int32(lda)
|
264 |
+
ldb = ctypes.c_int32(ldb)
|
265 |
+
ldc = ctypes.c_int32(ldc)
|
266 |
+
|
267 |
+
df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda:0")
|
268 |
+
cdequantize_blockwise_fp32(
|
269 |
+
get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
|
270 |
+
ctypes.c_int(blocksize2), ctypes.c_int(df.numel()), CUDA_STREAM,
|
271 |
+
)
|
272 |
+
df += offset
|
273 |
+
absmax = df
|
274 |
+
|
275 |
+
fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \
|
276 |
+
cgemm_4bit_inference_naive_bf16
|
277 |
+
|
278 |
+
blocksize = ctypes.c_int32(blocksize)
|
279 |
+
fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
|
280 |
+
lda, ldb, ldc, blocksize, CUDA_STREAM,)
|
281 |
+
|
282 |
+
return out
|
283 |
+
pass
|
284 |
+
else:
|
285 |
+
def fast_gemv(X, W, quant_state, out = None):
|
286 |
+
if quant_state is None: return torch.matmul(X, W, out = out)
|
287 |
+
# For fast X @ W where seq_len == 1
|
288 |
+
# From https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/functional.py#L1469
|
289 |
+
_, q_len, hd = X.shape
|
290 |
+
# assert(q_len == 1)
|
291 |
+
|
292 |
+
if type(quant_state) is not list:
|
293 |
+
# https://github.com/TimDettmers/bitsandbytes/pull/763/files
|
294 |
+
absmax = quant_state.absmax
|
295 |
+
shape = quant_state.shape
|
296 |
+
dtype = quant_state.dtype
|
297 |
+
blocksize = quant_state.blocksize
|
298 |
+
stats = quant_state.code
|
299 |
+
offset = quant_state.offset
|
300 |
+
state2 = quant_state.state2
|
301 |
+
absmax2 = state2.absmax
|
302 |
+
code2 = state2.code
|
303 |
+
blocksize2 = state2.blocksize
|
304 |
+
else:
|
305 |
+
absmax, shape, dtype, blocksize, compressed_stats, quant_type, stats = quant_state
|
306 |
+
offset, state2 = compressed_stats
|
307 |
+
absmax2, code2, blocksize2, _, _, _, _ = state2
|
308 |
+
pass
|
309 |
+
# assert(dtype == X.dtype)
|
310 |
+
bout = shape[0]
|
311 |
+
|
312 |
+
if out is None:
|
313 |
+
out = torch.empty((1, 1, bout,), dtype = dtype, device = "cuda:0")
|
314 |
+
# else:
|
315 |
+
# assert(out.shape == (1, 1, bout,))
|
316 |
+
# pass
|
317 |
+
|
318 |
+
n = 1
|
319 |
+
m = shape[0]
|
320 |
+
k = shape[1]
|
321 |
+
lda = shape[0]
|
322 |
+
ldc = shape[0]
|
323 |
+
ldb = (hd+1)//2
|
324 |
+
m = ctypes.c_int32(m)
|
325 |
+
n = ctypes.c_int32(n)
|
326 |
+
k = ctypes.c_int32(k)
|
327 |
+
lda = ctypes.c_int32(lda)
|
328 |
+
ldb = ctypes.c_int32(ldb)
|
329 |
+
ldc = ctypes.c_int32(ldc)
|
330 |
+
|
331 |
+
df = torch.empty(absmax.shape, dtype = torch.float32, device = "cuda:0")
|
332 |
+
cdequantize_blockwise_fp32(
|
333 |
+
get_ptr(code2), get_ptr(absmax), get_ptr(absmax2), get_ptr(df),
|
334 |
+
ctypes.c_int(blocksize2), ctypes.c_int(df.numel()),
|
335 |
+
)
|
336 |
+
df += offset
|
337 |
+
absmax = df
|
338 |
+
|
339 |
+
fx = cgemm_4bit_inference_naive_fp16 if dtype == torch.float16 else \
|
340 |
+
cgemm_4bit_inference_naive_bf16
|
341 |
+
|
342 |
+
blocksize = ctypes.c_int32(blocksize)
|
343 |
+
fx(m, n, k, get_ptr(X), get_ptr(W), get_ptr(absmax), get_ptr(stats), get_ptr(out),
|
344 |
+
lda, ldb, ldc, blocksize,)
|
345 |
+
|
346 |
+
return out
|
347 |
+
pass
|
348 |
+
pass
|
349 |
+
|
350 |
+
|
351 |
+
def fast_linear_forward(proj, X, temp_lora = None, out = None):
|
352 |
+
|
353 |
+
W, W_quant, lora_A, lora_B, lora_S, bias = get_lora_parameters_bias(proj)
|
354 |
+
bsz, q_len, in_dim = X.shape
|
355 |
+
if q_len != 1: return matmul_lora(X, W, W_quant, lora_A, lora_B, lora_S)
|
356 |
+
|
357 |
+
if W_quant is None:
|
358 |
+
out = torch.matmul(X, W.t(), out = out)
|
359 |
+
elif bsz == 1 and q_len == 1:
|
360 |
+
out = fast_gemv(X, W, W_quant, out = out)
|
361 |
+
else:
|
362 |
+
W = fast_dequantize(W.t(), W_quant)
|
363 |
+
out = torch.matmul(X, W, out = out)
|
364 |
+
pass
|
365 |
+
|
366 |
+
# Add in LoRA weights
|
367 |
+
if lora_A is not None:
|
368 |
+
out_dim = out.shape[2]
|
369 |
+
dtype = X.dtype
|
370 |
+
|
371 |
+
if not hasattr(lora_A, "_fast_lora"):
|
372 |
+
lora_A._fast_lora = lora_A.to(dtype)
|
373 |
+
lora_B._fast_lora = lora_B.to(dtype)
|
374 |
+
pass
|
375 |
+
|
376 |
+
if bsz == 1:
|
377 |
+
out = out.view(out_dim)
|
378 |
+
temp_lora = torch.mv(lora_A._fast_lora, X.ravel(), out = temp_lora)
|
379 |
+
out.addmv_(lora_B._fast_lora, temp_lora, alpha = lora_S)
|
380 |
+
else:
|
381 |
+
out = out.view(bsz, out_dim)
|
382 |
+
temp_lora = torch.mm(X.view(bsz, in_dim), lora_A._fast_lora.t(), out = temp_lora)
|
383 |
+
out.addmm_(temp_lora, lora_B._fast_lora.t(), alpha = lora_S)
|
384 |
+
pass
|
385 |
+
out = out.view(bsz, 1, out_dim)
|
386 |
+
pass
|
387 |
+
|
388 |
+
if bias is not None: out += bias
|
389 |
+
|
390 |
+
return out
|
391 |
+
pass
|
392 |
+
|
393 |
+
|
394 |
+
def matmul_lora(X, W, W_quant, A, B, s, out = None):
|
395 |
+
dtype = X.dtype
|
396 |
+
W = fast_dequantize(W.t(), W_quant)
|
397 |
+
|
398 |
+
if X.dim() == 3:
|
399 |
+
batch, seq_len, d = X.shape
|
400 |
+
X = X.view(-1, X.shape[-1])
|
401 |
+
reshape = True
|
402 |
+
else:
|
403 |
+
reshape = False
|
404 |
+
pass
|
405 |
+
|
406 |
+
out = torch.matmul(X, W, out = out)
|
407 |
+
if W_quant is not None: del W
|
408 |
+
|
409 |
+
if A is not None:
|
410 |
+
# LoRA is enabled
|
411 |
+
A, B = A.t(), B.t()
|
412 |
+
out += (X @ A.to(dtype)) @ (s * B.to(dtype))
|
413 |
+
pass
|
414 |
+
|
415 |
+
return out.view(batch, seq_len, -1) if reshape else out
|
416 |
+
pass
|
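For reference, a dense PyTorch sketch of what matmul_lora above computes once the 4-bit base weight has been dequantized: the base projection X @ W^T plus the scaled low-rank correction (X @ A^T) @ (s * B^T). All shapes and numbers here are assumptions for illustration, not taken from this diff.

# Hedged reference sketch in plain PyTorch (no quantization, no Triton kernels).
import torch

def matmul_lora_reference(X, W, A, B, s):
    out = X @ W.t()                             # base projection
    if A is not None:
        out = out + (X @ A.t()) @ (s * B.t())   # LoRA low-rank update, scaled by s
    return out

X = torch.randn(2, 16, 4096)   # hypothetical (batch, seq_len, in_dim)
W = torch.randn(1024, 4096)    # (out_dim, in_dim)
A = torch.randn(8, 4096)       # (rank, in_dim)
B = torch.randn(1024, 8)       # (out_dim, rank)
Y = matmul_lora_reference(X, W, A, B, s = 2.0)  # shape (2, 16, 1024)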
unsloth-main/unsloth-main/unsloth/models/__init__.py
ADDED
@@ -0,0 +1,20 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .loader  import FastLanguageModel
from .llama   import FastLlamaModel
from .mistral import FastMistralModel
from .qwen2   import FastQwen2Model
from .dpo     import PatchDPOTrainer
from ._utils  import is_bfloat16_supported
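These re-exports are what downstream code typically imports. A small, hedged usage sketch follows; the checkpoint name and keyword arguments are assumptions for illustration only, not taken from this diff.

from unsloth import FastLanguageModel, is_bfloat16_supported

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name     = "unsloth/llama-3-8b-bnb-4bit",  # illustrative model name
    max_seq_length = 2048,
    load_in_4bit   = True,
)
print(is_bfloat16_supported())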
unsloth-main/unsloth-main/unsloth/models/_utils.py
ADDED
@@ -0,0 +1,1140 @@
1 |
+
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
__version__ = "2024.9.post4"
|
16 |
+
|
17 |
+
__all__ = [
|
18 |
+
"prepare_model_for_kbit_training",
|
19 |
+
"xformers",
|
20 |
+
"xformers_attention",
|
21 |
+
"xformers_version",
|
22 |
+
"__version__",
|
23 |
+
"HAS_FLASH_ATTENTION",
|
24 |
+
"HAS_FLASH_ATTENTION_SOFTCAPPING",
|
25 |
+
"PRE_CHECK",
|
26 |
+
"platform_system",
|
27 |
+
"patch_tokenizer",
|
28 |
+
"get_statistics",
|
29 |
+
"Unsloth_Offloaded_Gradient_Checkpointer",
|
30 |
+
"offload_to_disk",
|
31 |
+
"offload_input_embeddings",
|
32 |
+
"offload_output_embeddings",
|
33 |
+
"is_bfloat16_supported",
|
34 |
+
"unsloth_offloaded_gradient_checkpoint",
|
35 |
+
"torch_compile_options",
|
36 |
+
"patch_linear_scaling",
|
37 |
+
"patch_llama_rope_scaling",
|
38 |
+
"check_nvidia",
|
39 |
+
"create_boolean_mask",
|
40 |
+
"torch_amp_custom_fwd",
|
41 |
+
"torch_amp_custom_bwd",
|
42 |
+
"accelerate_old_send_to_device",
|
43 |
+
"accelerate_new_send_to_device",
|
44 |
+
"patch_gradient_checkpointing",
|
45 |
+
"unpatch_gradient_checkpointing",
|
46 |
+
]
|
47 |
+
|
48 |
+
import torch
|
49 |
+
from typing import Union, Optional, List, Any, Callable, Tuple
|
50 |
+
from platform import system as platform_system
|
51 |
+
platform_system = platform_system()
|
52 |
+
import numpy as np
|
53 |
+
import warnings, subprocess, re, inspect, psutil, os, math
|
54 |
+
from packaging.version import Version
|
55 |
+
|
56 |
+
# =============================================
|
57 |
+
# Disable some warnings which can get annoying
|
58 |
+
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch")
|
59 |
+
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "huggingface_hub")
|
60 |
+
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "trl")
|
61 |
+
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "huggingface_hub")
|
62 |
+
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "xformers")
|
63 |
+
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "subprocess")
|
64 |
+
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "transformers")
|
65 |
+
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "accelerate")
|
66 |
+
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocessing")
|
67 |
+
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocess")
|
68 |
+
|
69 |
+
# Stop "Special tokens have been added in the vocabulary, ..."
|
70 |
+
import logging
|
71 |
+
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.CRITICAL+1)
|
72 |
+
# =============================================
|
73 |
+
|
74 |
+
# =============================================
|
75 |
+
# Edits all Config files to enable RoPE Scaling for all models
|
76 |
+
|
77 |
+
# Transformers had to update for Mistral Nemo 12b since Attention is (5120, 4096) now.
|
78 |
+
def patch_mistral_nemo_config(config):
|
79 |
+
if "head_dim (" not in config:
|
80 |
+
add_head_dim = "If it is not specified, will default to `8`.\n"\
|
81 |
+
" head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):\n"\
|
82 |
+
" The attention head dimension."
|
83 |
+
config = config.replace("If it is not specified, will default to `8`.", add_head_dim)
|
84 |
+
|
85 |
+
add_head_dim = "num_key_value_heads=8,\n head_dim=None,"
|
86 |
+
config = config.replace("num_key_value_heads=8,", add_head_dim)
|
87 |
+
|
88 |
+
add_head_dim = "self.sliding_window = sliding_window\n self.head_dim = head_dim or hidden_size // num_attention_heads\n"
|
89 |
+
config = config.replace("self.sliding_window = sliding_window", add_head_dim)
|
90 |
+
pass
|
91 |
+
return config
|
92 |
+
pass
|
93 |
+
|
94 |
+
from transformers import __version__ as transformers_version
|
95 |
+
from transformers import PretrainedConfig
|
96 |
+
model_architectures = ["llama", "mistral", "gemma", "gemma2", "qwen2",]
|
97 |
+
|
98 |
+
for model_name in model_architectures:
|
99 |
+
config_filepath = f"transformers.models.{model_name}.configuration_{model_name}"
|
100 |
+
model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
|
101 |
+
config_filename = f"{model_name.title()}Config"
|
102 |
+
exec(f"from {config_filepath} import {config_filename}", globals())
|
103 |
+
|
104 |
+
try:
|
105 |
+
config = inspect.getsource(eval(config_filename))
|
106 |
+
except:
|
107 |
+
continue
|
108 |
+
if "rope_scaling" in config: continue
|
109 |
+
config = re.sub(
|
110 |
+
r"(\*\*kwargs)[\s]{0,}\,[\s]{0,}\)[\s]{0,}\:",
|
111 |
+
r"rope_scaling=None,"\
|
112 |
+
r"\n **kwargs):\n"\
|
113 |
+
r"\n self.rope_scaling = rope_scaling\n",
|
114 |
+
config,
|
115 |
+
)
|
116 |
+
|
117 |
+
# Just for Mistral Nemo
|
118 |
+
if model_name == "mistral":
|
119 |
+
if Version(transformers_version) <= Version("4.42.4"):
|
120 |
+
config = patch_mistral_nemo_config(config)
|
121 |
+
pass
|
122 |
+
|
123 |
+
exec(config, globals())
|
124 |
+
exec(f"import {config_filepath}", globals())
|
125 |
+
exec(f"{config_filepath}.{config_filename} = {config_filename}", globals())
|
126 |
+
pass
|
127 |
+
# =============================================
|
128 |
+
|
129 |
+
# =============================================
|
130 |
+
# torch.cuda.amp.custom_fwd is deprecated >= 2.4
|
131 |
+
import torch
|
132 |
+
torch_version = torch.__version__
|
133 |
+
if Version(torch_version) < Version("2.4.0"):
|
134 |
+
torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
|
135 |
+
torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
|
136 |
+
else:
|
137 |
+
torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
|
138 |
+
torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
|
139 |
+
pass
|
140 |
+
# =============================================
|
141 |
+
|
142 |
+
# =============================================
|
143 |
+
# Fix KeyError: 'Cache only has 0 layers, attempted to access layer with index 0'
|
144 |
+
import transformers.cache_utils
|
145 |
+
if hasattr(transformers.cache_utils, "DynamicCache") and \
|
146 |
+
transformers.cache_utils.DynamicCache.__getitem__.__name__ != "__cache_utils_getitem__":
|
147 |
+
|
148 |
+
source = inspect.getsource(transformers.cache_utils.DynamicCache.__getitem__)
|
149 |
+
start = source.find("def")
|
150 |
+
spaces = start*" "
|
151 |
+
source = source.split("\n")
|
152 |
+
source = "\n".join(x[start:] for x in source)
|
153 |
+
where = source.find("raise KeyError")
|
154 |
+
source = source[:where] + \
|
155 |
+
f"if len(self) == 0:\n{spaces}{spaces}"\
|
156 |
+
" raise RuntimeError('Unsloth: You must call `FastLanguageModel.for_inference(model)` before doing inference for Unsloth models.')\n" + \
|
157 |
+
f"{spaces}{spaces}else:\n{spaces}{spaces}{spaces}" + source[where:]
|
158 |
+
source = source.replace("__getitem__", "__cache_utils_getitem__", 1)
|
159 |
+
exec(source)
|
160 |
+
transformers.cache_utils.DynamicCache.__getitem__ = __cache_utils_getitem__
|
161 |
+
pass
|
162 |
+
# =============================================
|
163 |
+
|
164 |
+
# =============================================
|
165 |
+
# Get Flash Attention v2 if Ampere (RTX 30xx, A100)
|
166 |
+
import bitsandbytes as bnb
|
167 |
+
from transformers import AutoTokenizer
|
168 |
+
from transformers.utils.import_utils import _is_package_available
|
169 |
+
|
170 |
+
major_version, minor_version = torch.cuda.get_device_capability()
|
171 |
+
SUPPORTS_BFLOAT16 = False
|
172 |
+
HAS_FLASH_ATTENTION = False
|
173 |
+
HAS_FLASH_ATTENTION_SOFTCAPPING = False
|
174 |
+
|
175 |
+
if major_version >= 8:
|
176 |
+
SUPPORTS_BFLOAT16 = True
|
177 |
+
if _is_package_available("flash_attn"):
|
178 |
+
# Check for CUDA linking errors "undefined symbol: _ZNK3c106SymIntltEl"
|
179 |
+
try:
|
180 |
+
from flash_attn.flash_attn_interface import flash_attn_cuda
|
181 |
+
HAS_FLASH_ATTENTION = True
|
182 |
+
|
183 |
+
# Also check for softcapping
|
184 |
+
from flash_attn import __version__ as flash_attn_version
|
185 |
+
HAS_FLASH_ATTENTION_SOFTCAPPING = Version(flash_attn_version) >= Version("2.6.3")
|
186 |
+
if not HAS_FLASH_ATTENTION_SOFTCAPPING:
|
187 |
+
print(
|
188 |
+
"Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"\
|
189 |
+
"Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\n"\
|
190 |
+
"To update flash-attn, do the below:\n"\
|
191 |
+
'\npip install --no-deps --upgrade "flash-attn>=2.6.3"'
|
192 |
+
)
|
193 |
+
except:
|
194 |
+
print(
|
195 |
+
"Unsloth: Your Flash Attention 2 installation seems to be broken?\n"\
|
196 |
+
"A possible explanation is you have a new CUDA version which isn't\n"\
|
197 |
+
"yet compatible with FA2? Please file a ticket to Unsloth or FA2.\n"\
|
198 |
+
"We shall now use Xformers instead, which does not have any performance hits!\n"\
|
199 |
+
"We found this negligible impact by benchmarking on 1x A100."
|
200 |
+
)
|
201 |
+
|
202 |
+
# Stop Flash Attention from importing!
|
203 |
+
import transformers.utils.import_utils
|
204 |
+
transformers.utils.import_utils.is_flash_attn_2_available = lambda *args, **kwargs: False
|
205 |
+
import transformers.utils
|
206 |
+
transformers.utils.is_flash_attn_2_available = lambda *args, **kwargs: False
|
207 |
+
|
208 |
+
HAS_FLASH_ATTENTION = False
|
209 |
+
pass
|
210 |
+
else:
|
211 |
+
HAS_FLASH_ATTENTION = False
|
212 |
+
else:
|
213 |
+
# Tri Dao's benchmark shows xformers is faster for now.
|
214 |
+
HAS_FLASH_ATTENTION = False
|
215 |
+
pass
|
216 |
+
|
217 |
+
from transformers.models.llama.modeling_llama import logger
|
218 |
+
|
219 |
+
# =============================================
|
220 |
+
# Get Xformers
|
221 |
+
from xformers import __version__ as xformers_version
|
222 |
+
# Temporarily disable 0.0.27 and higher - inference issues
|
223 |
+
if False: #Version(xformers_version) >= Version("0.0.27"):
|
224 |
+
raise ImportError(
|
225 |
+
"Unsloth: If you are in Colab, we updated the top cell install instructions - please change it to below "\
|
226 |
+
"then press Disconnect Runtime and then Restart it.\n"\
|
227 |
+
"\n"\
|
228 |
+
"%%capture\n"
|
229 |
+
"# Installs Unsloth, Xformers (Flash Attention) and all other packages!\n"
|
230 |
+
'!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n'
|
231 |
+
'!pip install --no-deps "xformers<=0.0.27" trl peft accelerate bitsandbytes\n'\
|
232 |
+
'\n'\
|
233 |
+
f"Otherwise in local machines, your xformers version of {xformers_version} is too new.\n"\
|
234 |
+
'Please downgrade xformers via `pip install --force-reinstall "xformers<=0.0.27"'
|
235 |
+
)
|
236 |
+
pass
|
237 |
+
|
238 |
+
if Version(torch_version) < Version("2.2.0") and Version(xformers_version) >= Version("0.0.24"):
|
239 |
+
raise ImportError(
|
240 |
+
f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"\
|
241 |
+
f"Please install xformers < 0.0.24 for torch = {torch_version}."
|
242 |
+
)
|
243 |
+
elif Version(torch_version) < Version("2.3.0") and Version(xformers_version) >= Version("0.0.26"):
|
244 |
+
raise ImportError(
|
245 |
+
f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"\
|
246 |
+
f"Please install xformers < 0.0.26 for torch = {torch_version}."
|
247 |
+
)
|
248 |
+
elif Version(torch_version) < Version("2.4.0") and Version(xformers_version) > Version("0.0.27"):
|
249 |
+
raise ImportError(
|
250 |
+
f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"\
|
251 |
+
f"Please install xformers <= 0.0.27 for torch = {torch_version}."
|
252 |
+
)
|
253 |
+
pass
|
254 |
+
|
255 |
+
from xformers._cpp_lib import _register_extensions
|
256 |
+
try:
|
257 |
+
_register_extensions() # Check if C++ modules are loaded correctly
|
258 |
+
except Exception as error:
|
259 |
+
raise ImportError(
|
260 |
+
"Unsloth: Xformers was not installed correctly.\n"\
|
261 |
+
"Please install xformers separately first.\n"\
|
262 |
+
"Then confirm if it's correctly installed by running:\n"\
|
263 |
+
"python -m xformers.info\n\n"
|
264 |
+
"Longer error message:\n" + str(error)
|
265 |
+
)
|
266 |
+
pass
|
267 |
+
import xformers.ops.fmha as xformers
|
268 |
+
xformers_attention = xformers.memory_efficient_attention
|
269 |
+
|
270 |
+
# Check TRL version
|
271 |
+
from trl import __version__ as trl_version
|
272 |
+
# Unsloth now supports all TRL versions!
|
273 |
+
if False:#Version(trl_version) >= Version("0.9.0"):
|
274 |
+
raise ImportError(
|
275 |
+
"Unsloth: If you are in Colab, we updated the top cell install instructions - please change it to below "\
|
276 |
+
"then press Disconnect Runtime and then Restart it.\n"\
|
277 |
+
"\n"\
|
278 |
+
"%%capture\n"
|
279 |
+
"# Installs Unsloth, Xformers (Flash Attention) and all other packages!\n"
|
280 |
+
'!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n'
|
281 |
+
'!pip install --no-deps "xformers<=0.0.27" trl peft accelerate bitsandbytes\n'\
|
282 |
+
'\n'\
|
283 |
+
f"Otherwise in local machines, your TRL version of {trl_version} is too new.\n"\
|
284 |
+
'Please downgrade TRL via `pip install --force-reinstall trl'
|
285 |
+
)
|
286 |
+
pass
|
287 |
+
|
288 |
+
# =============================================
|
289 |
+
# Fix new Xformers versions TypeError: Multiple dispatch failed for 'torch._ops.aten.to.dtype_layout'
|
290 |
+
accelerate_old_send_to_device = None
|
291 |
+
accelerate_new_send_to_device = None
|
292 |
+
if Version(xformers_version) >= Version("0.0.27"):
|
293 |
+
import accelerate.utils.operations
|
294 |
+
if hasattr(accelerate.utils.operations, "send_to_device") and \
|
295 |
+
accelerate.utils.operations.send_to_device.__name__ != "_fixed_send_to_device":
|
296 |
+
accelerate_old_send_to_device = accelerate.utils.operations.send_to_device
|
297 |
+
from accelerate.utils.operations import *
|
298 |
+
send_to_device = inspect.getsource(accelerate.utils.operations.send_to_device)
|
299 |
+
send_to_device = re.sub(
|
300 |
+
r"([ ]{4,})return tensor\.to\(device\)",
|
301 |
+
r"\1try: return tensor.to(device)\n\1except: return tensor",
|
302 |
+
send_to_device,
|
303 |
+
).replace("def send_to_device", "def _fixed_send_to_device")
|
304 |
+
exec(send_to_device)
|
305 |
+
# accelerate.utils.operations.send_to_device = _fixed_send_to_device
|
306 |
+
accelerate_new_send_to_device = _fixed_send_to_device
|
307 |
+
pass
|
308 |
+
pass
|
309 |
+
|
310 |
+
# Transformers 4.46 breaks dynamic caching. This is a hack
|
311 |
+
import transformers.generation.configuration_utils
|
312 |
+
if hasattr(transformers.generation.configuration_utils, "ALL_CACHE_IMPLEMENTATIONS"):
|
313 |
+
if type(transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS) is list:
|
314 |
+
transformers.generation.configuration_utils.ALL_CACHE_IMPLEMENTATIONS.append("dynamic")
|
315 |
+
pass
|
316 |
+
pass
|
317 |
+
# =============================================
|
318 |
+
|
319 |
+
# =============================================
|
320 |
+
# Torch compile settings
|
321 |
+
|
322 |
+
# Just remove max_autotune_gemm warning
|
323 |
+
import functools
|
324 |
+
@functools.lru_cache(None)
|
325 |
+
def is_big_gpu(index):
|
326 |
+
sms = torch.cuda.get_device_properties(index).multi_processor_count
|
327 |
+
if sms < 80: # V100
|
328 |
+
# log.warning("not enough SMs to use max_autotune_gemm mode")
|
329 |
+
return False
|
330 |
+
return True
|
331 |
+
import torch._inductor.utils
|
332 |
+
torch._inductor.utils.is_big_gpu = is_big_gpu
|
333 |
+
|
334 |
+
|
335 |
+
# Torch compile arguments
|
336 |
+
torch_compile_arguments = [
|
337 |
+
"config.dce = True",
|
338 |
+
"config.memory_planning = True",
|
339 |
+
"config.memory_pool = 'combined'",
|
340 |
+
"config.coordinate_descent_tuning = True",
|
341 |
+
"config.max_autotune_gemm = False", # GEMM is unnecessary
|
342 |
+
"config.autotune_multi_device = False",
|
343 |
+
"config.max_autotune_gemm_backends = 'TRITON,ATEN,CPP'", # Not much faster
|
344 |
+
"config.aggressive_fusion = False", # Careful changes results!
|
345 |
+
"config.cuda.enable_cuda_lto = True",
|
346 |
+
"config.cuda.use_fast_math = True",
|
347 |
+
"config.cuda.compile_opt_level = '-O2'",
|
348 |
+
]
|
349 |
+
# Torch dynamo arguments
|
350 |
+
torch_dynamo_arguments = [
|
351 |
+
"config.accumulated_cache_size_limit = 1024", # Bump up a bit from 256
|
352 |
+
"config.suppress_errors = True", # Supress errors for now
|
353 |
+
"config.do_not_emit_runtime_asserts = True",
|
354 |
+
"config.cache_size_limit = 1024", # Flex Attention
|
355 |
+
]
|
356 |
+
import torch._inductor.config as config
|
357 |
+
for _try_compile_argument in torch_compile_arguments:
|
358 |
+
try: exec(_try_compile_argument)
|
359 |
+
except: pass
|
360 |
+
pass
|
361 |
+
import torch._dynamo.config as config
|
362 |
+
for _try_dynamo_argument in torch_dynamo_arguments:
|
363 |
+
try: exec(_try_dynamo_argument)
|
364 |
+
except: pass
|
365 |
+
pass
|
366 |
+
torch_compile_options = {
|
367 |
+
"epilogue_fusion" : True,
|
368 |
+
"max_autotune" : True,
|
369 |
+
"shape_padding" : True,
|
370 |
+
"trace.enabled" : False, # Output Triton kernel outputs!
|
371 |
+
"triton.cudagraphs" : False,
|
372 |
+
}
|
373 |
+
# =============================================
|
374 |
+
|
375 |
+
def prepare_model_for_kbit_training(
|
376 |
+
model : Any,
|
377 |
+
use_gradient_checkpointing : Optional = True,
|
378 |
+
use_reentrant : Optional[bool] = True,
|
379 |
+
) -> Any:
|
380 |
+
"""
|
381 |
+
Calculates where to place the gradient checkpoints given n_layers.
|
382 |
+
We also freeze all other layers's gradients
|
383 |
+
|
384 |
+
Args:
|
385 |
+
model: Any LlamaModel with layers.
|
386 |
+
use_gradient_checkpointing (`bool`, *optional*):
|
387 |
+
Default enabled. Provides memory savings by not saving all activations,
|
388 |
+
but only some.
|
389 |
+
use_reentrant (`bool`, *optional*):
|
390 |
+
https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py#L354
|
391 |
+
Optimal gradient checkpointing algorithm which will be the default in
|
392 |
+
future Pytorch versions.
|
393 |
+
"""
|
394 |
+
|
395 |
+
# Freeze all parameters except LoRA
|
396 |
+
with torch.no_grad():
|
397 |
+
for name, param in model.named_parameters():
|
398 |
+
if ".lora_A." in name or ".lora_B." in name or ".lora_magnitude_vector" in name:
|
399 |
+
param.requires_grad_(True)
|
400 |
+
# Also must be in float32!
|
401 |
+
if param.dtype != torch.float32:
|
402 |
+
name = name.replace("base_model", "model", 1)
|
403 |
+
layer_number = re.search(r"\.[\d]{1,}\.", name).group(0)
|
404 |
+
name = name.replace(layer_number, f"[{layer_number[1:-1]}].")
|
405 |
+
name = name.replace(".weight", "", 1)
|
406 |
+
exec(f"{name}.to(torch.float32)")
|
407 |
+
pass
|
408 |
+
else:
|
409 |
+
param.requires_grad_(False)
|
410 |
+
pass
|
411 |
+
pass
|
412 |
+
|
413 |
+
# Gradient checkpointing!
|
414 |
+
if use_gradient_checkpointing == "unsloth":
|
415 |
+
|
416 |
+
# Saves VRAM!
|
417 |
+
original_model = model
|
418 |
+
while hasattr(original_model, "model"):
|
419 |
+
original_model._offloaded_gradient_checkpointing = True
|
420 |
+
original_model = original_model.model
|
421 |
+
pass
|
422 |
+
original_model._offloaded_gradient_checkpointing = True
|
423 |
+
|
424 |
+
model.gradient_checkpointing_enable()
|
425 |
+
|
426 |
+
elif use_gradient_checkpointing == True:
|
427 |
+
model.gradient_checkpointing_enable()
|
428 |
+
pass
|
429 |
+
|
430 |
+
# If use_reentrant = True which is the Pytorch default, we just make the input requires_grad.
|
431 |
+
if use_reentrant:
|
432 |
+
if hasattr(model, "enable_input_require_grads"):
|
433 |
+
model.enable_input_require_grads()
|
434 |
+
else:
|
435 |
+
def make_inputs_require_grad(module, input, output):
|
436 |
+
output.requires_grad_(True)
|
437 |
+
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
438 |
+
|
439 |
+
return model
|
440 |
+
pass
|
441 |
+
|
442 |
+
|
443 |
+
def patch_tokenizer(model, tokenizer):
|
444 |
+
"""
|
445 |
+
Phi3's pad_token isn't set. We set it to <|placeholder...
|
446 |
+
Llama-3 is <|reserved...
|
447 |
+
Llama-2 is <unk>
|
448 |
+
Check if pad_token is not the same as eos_token otherwise the loss will ignore it!!
|
449 |
+
Fixes https://github.com/unslothai/unsloth/issues/5
|
450 |
+
"""
|
451 |
+
possible_reserved_tokens = (
|
452 |
+
"<|finetune_right_pad_id|>", # Llama-3.1
|
453 |
+
"<pad>", # Mistral Nemo
|
454 |
+
"<|reserved", # Llama-3
|
455 |
+
"<|placeholder", # Phi-3
|
456 |
+
"[control", # Mistral type models
|
457 |
+
)
|
458 |
+
joiner = "\1\0=+=\0\1"
|
459 |
+
number_repetitions = 3 - 1 # Number of reserved tokens needed
|
460 |
+
|
461 |
+
if model is not None:
|
462 |
+
model.config.update({"unsloth_version" : __version__})
|
463 |
+
|
464 |
+
bad_pad_token = False
|
465 |
+
if hasattr(tokenizer, "pad_token") and tokenizer.pad_token is not None:
|
466 |
+
# Check if pad_token is not the same as eos_token otherwise the loss will ignore it!!
|
467 |
+
bad_pad_token = tokenizer.eos_token == tokenizer.pad_token
|
468 |
+
elif hasattr(tokenizer, "pad_token") and tokenizer.pad_token is None:
|
469 |
+
bad_pad_token = True
|
470 |
+
else:
|
471 |
+
bad_pad_token = False
|
472 |
+
pass
|
473 |
+
|
474 |
+
if bad_pad_token:
|
475 |
+
# Find a better pad token
|
476 |
+
added_tokens = [str(x) for x in tokenizer.added_tokens_decoder.values()]
|
477 |
+
all_added_tokens = joiner.join(added_tokens[::-1])
|
478 |
+
all_added_tokens += joiner
|
479 |
+
|
480 |
+
final_pad_token = None
|
481 |
+
final_good_match = False
|
482 |
+
|
483 |
+
for possible_reserved_token in possible_reserved_tokens:
|
484 |
+
possible_reserved_token = re.escape(possible_reserved_token)
|
485 |
+
found = re.finditer(f"{possible_reserved_token}", all_added_tokens)
|
486 |
+
first_match = None
|
487 |
+
good_match = False
|
488 |
+
for j, x in enumerate(found):
|
489 |
+
if j == 0: first_match = x
|
490 |
+
if j >= number_repetitions:
|
491 |
+
good_match = True
|
492 |
+
break
|
493 |
+
pass
|
494 |
+
pass
|
495 |
+
|
496 |
+
if first_match is None: continue
|
497 |
+
|
498 |
+
# If it ends with |> or > etc, then set it as a good pad token!
|
499 |
+
start = first_match.span(0)[0]
|
500 |
+
possible_pad_token = first_match.group(0)
|
501 |
+
end = all_added_tokens.find(joiner, start)
|
502 |
+
first_match = all_added_tokens[start:end]
|
503 |
+
|
504 |
+
if first_match is not None:
|
505 |
+
good_match = possible_pad_token.endswith((">", "|>", "]", ")"))
|
506 |
+
pass
|
507 |
+
possible_pad_token = first_match
|
508 |
+
|
509 |
+
# Replace current pad token if another exact match is found
|
510 |
+
if not final_good_match and good_match:
|
511 |
+
final_good_match = True
|
512 |
+
final_pad_token = possible_pad_token
|
513 |
+
break
|
514 |
+
else:
|
515 |
+
final_good_match = False
|
516 |
+
final_pad_token = possible_pad_token
|
517 |
+
pass
|
518 |
+
pass
|
519 |
+
possible_pad_token = final_pad_token
|
520 |
+
|
521 |
+
# Try unk_token
|
522 |
+
if possible_pad_token is None and hasattr(tokenizer, "unk_token"):
|
523 |
+
possible_pad_token = tokenizer.unk_token
|
524 |
+
pass
|
525 |
+
|
526 |
+
# Check pad token's id must be less than vocab size
|
527 |
+
if possible_pad_token is not None:
|
528 |
+
check_pad_token = tokenizer(possible_pad_token, add_special_tokens = False).input_ids
|
529 |
+
if len(check_pad_token) != 1:
|
530 |
+
possible_pad_token = None
|
531 |
+
if model is not None and check_pad_token[0] >= model.config.vocab_size:
|
532 |
+
possible_pad_token = None
|
533 |
+
pass
|
534 |
+
|
535 |
+
if possible_pad_token is None:
|
536 |
+
# Failure to find a good replacement!! We shall manually add one!
|
537 |
+
new_pad_token = "<|PAD_TOKEN|>"
|
538 |
+
while new_pad_token in tokenizer.get_vocab():
|
539 |
+
new_pad_token = f"<{new_pad_token}>"
|
540 |
+
pass
|
541 |
+
possible_pad_token = new_pad_token
|
542 |
+
pass
|
543 |
+
|
544 |
+
name = model.config._name_or_path if model is not None else "Model"
|
545 |
+
logger.warning_once(
|
546 |
+
f"{name} does not have a padding token! Will use pad_token = {possible_pad_token}."
|
547 |
+
)
|
548 |
+
|
549 |
+
# Edit pad_token
|
550 |
+
tokenizer.add_special_tokens({"pad_token" : possible_pad_token})
|
551 |
+
tokenizer.pad_token = possible_pad_token
|
552 |
+
if model is not None:
|
553 |
+
model.config.update({"pad_token_id" : tokenizer.pad_token_id})
|
554 |
+
if getattr(model, "generation_config") is not None:
|
555 |
+
model.generation_config.update(pad_token_id = tokenizer.pad_token_id)
|
556 |
+
else:
|
557 |
+
if model is not None:
|
558 |
+
if model.config.pad_token_id is None:
|
559 |
+
model.config.update({"pad_token_id" : tokenizer.pad_token_id})
|
560 |
+
if getattr(model, "generation_config") is not None:
|
561 |
+
model.generation_config.update(pad_token_id = tokenizer.pad_token_id)
|
562 |
+
pass
|
563 |
+
pass
|
564 |
+
|
565 |
+
if model is not None:
|
566 |
+
if getattr(model, "generation_config") is not None:
|
567 |
+
model.generation_config.update(max_length = model.config.max_position_embeddings)
|
568 |
+
|
569 |
+
return model, tokenizer
|
570 |
+
pass
|
571 |
+
|
572 |
+
|
573 |
+
# =============================================
|
574 |
+
# Weirdly LoraLayer.update_layer downcasts PEFT layers to float16??
|
575 |
+
# For mixed precision, we need it to be in float32 not float16.
|
576 |
+
from peft import __version__ as peft_version
|
577 |
+
if Version(peft_version) < Version("0.12.0"):
|
578 |
+
from peft.tuners.lora.layer import LoraLayer
|
579 |
+
try:
|
580 |
+
source = inspect.getsource(LoraLayer.update_layer)
|
581 |
+
text = "if weight is not None:\n"
|
582 |
+
start = source.find(text) + len(text)
|
583 |
+
end = source.find("self.to(weight.device)", start)
|
584 |
+
spaces = re.findall(r"^([ ]{1,})break", source, flags = re.MULTILINE)[0]
|
585 |
+
source = source.replace(source[start : end], spaces)
|
586 |
+
spaces = len(re.match(r"[\s]{1,}", source).group(0))
|
587 |
+
lines = source.split("\n")
|
588 |
+
source = "\n".join(x[spaces:] for x in lines)
|
589 |
+
source = re.sub("([^\.])nn\.", r"\1torch.nn.", source)
|
590 |
+
source = source.replace("def update_layer", "def LoraLayer_update_layer")
|
591 |
+
exec(source, globals())
|
592 |
+
|
593 |
+
# Fix up incorrect downcasting of LoRA weights
|
594 |
+
from peft.tuners.lora.layer import LoraLayer
|
595 |
+
LoraLayer.update_layer = LoraLayer_update_layer
|
596 |
+
from peft.tuners.lora import LoraLayer
|
597 |
+
LoraLayer.update_layer = LoraLayer_update_layer
|
598 |
+
except:
|
599 |
+
logger.warning_once(
|
600 |
+
"Unsloth unsuccessfully patched LoraLayer.update_layer. Please file a bug report.\n"\
|
601 |
+
"Luckily, your training run will still work in the meantime!"
|
602 |
+
)
|
603 |
+
pass
|
604 |
+
pass
|
605 |
+
# =============================================
|
606 |
+
|
607 |
+
import psutil
|
608 |
+
def _get_statistics(statistics = None, force_download = True):
|
609 |
+
# We log some basic stats about which environment is being used.
|
610 |
+
# We simply download a README.md file from HF - all data is made public.
|
611 |
+
# This is simply so we can check if some envs are broken or not.
|
612 |
+
# You can disable this by commenting the below out
|
613 |
+
try:
|
614 |
+
n_cpus = psutil.cpu_count(logical = False)
|
615 |
+
keynames = "\n" + "\n".join(os.environ.keys())
|
616 |
+
if statistics is not None: pass
|
617 |
+
elif "\nCOLAB_" in keynames and n_cpus == 1: statistics = "colab"
|
618 |
+
elif "\nCOLAB_" in keynames: statistics = "colabpro"
|
619 |
+
elif "\nKAGGLE_" in keynames: statistics = "kaggle"
|
620 |
+
elif "\nRUNPOD_" in keynames: statistics = "runpod"
|
621 |
+
elif "\nAWS_" in keynames: statistics = "aws"
|
622 |
+
elif "\nAZURE_" in keynames: statistics = "azure"
|
623 |
+
# elif "\nK_" in keynames or "\nFUNCTION_" in keynames: statistics = "gcp"
|
624 |
+
elif "\nINVOCATION_ID" in keynames: statistics = "lambda"
|
625 |
+
# else: statistics = "other"
|
626 |
+
else:
|
627 |
+
def try_vllm_check():
|
628 |
+
vendor_files = (
|
629 |
+
"/sys/class/dmi/id/product_version",
|
630 |
+
"/sys/class/dmi/id/bios_vendor",
|
631 |
+
"/sys/class/dmi/id/product_name",
|
632 |
+
"/sys/class/dmi/id/chassis_asset_tag",
|
633 |
+
"/sys/class/dmi/id/sys_vendor",
|
634 |
+
)
|
635 |
+
from pathlib import Path
|
636 |
+
for vendor_file in vendor_files:
|
637 |
+
path = Path(vendor_file)
|
638 |
+
if path.is_file():
|
639 |
+
file_content = path.read_text().lower()
|
640 |
+
if "amazon" in file_content: return "aws"
|
641 |
+
elif "microsoft corporation" in file_content: return "azure"
|
642 |
+
elif "google" in file_content: return "gcp"
|
643 |
+
return "other"
|
644 |
+
pass
|
645 |
+
try: statistics = try_vllm_check()
|
646 |
+
except: statistics = "other"
|
647 |
+
pass
|
648 |
+
if statistics is not None:
|
649 |
+
from transformers import AutoModelForCausalLM
|
650 |
+
stats_model = AutoModelForCausalLM.from_pretrained(
|
651 |
+
f"unslothai/{statistics}",
|
652 |
+
force_download = force_download,
|
653 |
+
)
|
654 |
+
del stats_model
|
655 |
+
pass
|
656 |
+
except:
|
657 |
+
pass
|
658 |
+
pass
|
659 |
+
|
660 |
+
|
661 |
+
def get_statistics():
|
662 |
+
# We log some basic stats about which environment is being used.
|
663 |
+
# We simply download a README.md file from HF - all data is made public.
|
664 |
+
# This is simply so we can check if some envs are broken or not.
|
665 |
+
# You can disable this by commenting the below out
|
666 |
+
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars, are_progress_bars_disabled
|
667 |
+
disabled = False
|
668 |
+
if not are_progress_bars_disabled():
|
669 |
+
disable_progress_bars()
|
670 |
+
disabled = True
|
671 |
+
pass
|
672 |
+
_get_statistics(None)
|
673 |
+
_get_statistics("repeat", force_download = False)
|
674 |
+
try:
|
675 |
+
vram = torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 / 1024
|
676 |
+
if vram <= 8 : vram = 8
|
677 |
+
elif vram <= 16: vram = 16
|
678 |
+
elif vram <= 20: vram = 20
|
679 |
+
elif vram <= 24: vram = 24
|
680 |
+
elif vram <= 40: vram = 40
|
681 |
+
elif vram <= 48: vram = 48
|
682 |
+
elif vram <= 80: vram = 80
|
683 |
+
else: vram = 96
|
684 |
+
_get_statistics(f"vram-{vram}")
|
685 |
+
except:
|
686 |
+
pass
|
687 |
+
pass
|
688 |
+
try:
|
689 |
+
devices = torch.cuda.device_count()
|
690 |
+
_get_statistics(f"{devices if devices <= 8 else 9}")
|
691 |
+
except:
|
692 |
+
pass
|
693 |
+
if disabled: enable_progress_bars()
|
694 |
+
pass
|
695 |
+
|
696 |
+
|
697 |
+
def _calculate_n_gradient_checkpoints(
|
698 |
+
n_layers : int,
|
699 |
+
method : Optional[Union[str, int]] = "sqrt",
|
700 |
+
) -> List[int]:
|
701 |
+
assert(type(n_layers) is int and n_layers > 0)
|
702 |
+
|
703 |
+
if method is None: method = "sqrt"
|
704 |
+
|
705 |
+
if method == "sqrt":
|
706 |
+
n_checkpoints = int(n_layers**0.5)
|
707 |
+
elif type(method) is int and method > 0:
|
708 |
+
n_checkpoints = int(np.ceil(n_layers / method))
|
709 |
+
else:
|
710 |
+
raise ValueError("method must be 'sqrt' or an int >0 and <= n_layers.")
|
711 |
+
|
712 |
+
size = n_layers // n_checkpoints
|
713 |
+
sizes = np.full(n_checkpoints, size, dtype = int)
|
714 |
+
leftovers = n_layers % n_checkpoints
|
715 |
+
# We append leftovers from the right
|
716 |
+
for k in range(leftovers):
|
717 |
+
sizes[n_checkpoints-1-k] += 1
|
718 |
+
boundaries = np.hstack((0, np.cumsum(sizes)))
|
719 |
+
boundaries = boundaries.tolist()
|
720 |
+
return boundaries
|
721 |
+
pass
|
722 |
+
|
723 |
+
|
724 |
+
def calculate_n_gradient_checkpoints(
|
725 |
+
n_layers : int,
|
726 |
+
layers_per_checkpoint : Optional[Union[str, int]] = "sqrt",
|
727 |
+
) -> List[int]:
|
728 |
+
assert(type(n_layers) is int and n_layers > 0)
|
729 |
+
|
730 |
+
if layers_per_checkpoint is None or layers_per_checkpoint == 1:
|
731 |
+
return None
|
732 |
+
|
733 |
+
boundaries = _calculate_n_gradient_checkpoints(n_layers, layers_per_checkpoint)
|
734 |
+
|
735 |
+
assert(boundaries[0] == 0 and boundaries[-1] == n_layers)
|
736 |
+
assert(min(boundaries) == 0 and max(boundaries) == n_layers)
|
737 |
+
assert(np.diff(boundaries).min() >= 0)
|
738 |
+
return boundaries
|
739 |
+
pass
|
740 |
+
|
741 |
+
|
742 |
+
def prepare_n_gradient_checkpoints(
|
743 |
+
model : Any,
|
744 |
+
layers_per_checkpoint : Optional[Union[str, int]] = "sqrt",
|
745 |
+
use_reentrant : Optional[bool] = True,
|
746 |
+
) -> None:
|
747 |
+
"""
|
748 |
+
Calculates where to place the gradient checkpoints given n_layers.
|
749 |
+
|
750 |
+
Args:
|
751 |
+
model: Any LlamaModel with layers.
|
752 |
+
layers_per_checkpoint (`Union[str, int]`, *optional*):
|
753 |
+
Can either be `sqrt` or an integer for how many layers per checkpoint you want.
|
754 |
+
The more, the less memory usage, but can be slower. Default is `sqrt`.
|
755 |
+
Choose 1 for Pytorch gradient checkpointing. 2 to wrap 2 layers in 1 module etc.
|
756 |
+
use_reentrant (`bool`, *optional*):
|
757 |
+
https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py#L354
|
758 |
+
Optimal gradient checkpointing algorithm `use_reentrant=False` which will
|
759 |
+
be the default in future Pytorch versions doesn't seem to work??
|
760 |
+
"""
|
761 |
+
_model = None
|
762 |
+
if hasattr(model, "layers"):
|
763 |
+
_model = model
|
764 |
+
elif hasattr(model, "model"):
|
765 |
+
if hasattr(model.model, "layers"):
|
766 |
+
_model = model.model
|
767 |
+
if _model is None:
|
768 |
+
raise TypeError("`model` or `model.model` does not have attribute `layers`. Are you sure this is a model?")
|
769 |
+
pass
|
770 |
+
|
771 |
+
if use_reentrant is False:
|
772 |
+
use_reentrant = True
|
773 |
+
pass
|
774 |
+
|
775 |
+
n_layers = len(_model.layers)
|
776 |
+
boundaries = calculate_n_gradient_checkpoints(n_layers, layers_per_checkpoint)
|
777 |
+
_model._gradient_checkpointing_boundaries = boundaries
|
778 |
+
_model._gradient_checkpointing_use_reentrant = use_reentrant
|
779 |
+
pass
|
780 |
+
|
781 |
+
|
782 |
+
class Unsloth_Offloaded_Gradient_Checkpointer(torch.autograd.Function):
    """
    Saves VRAM by smartly offloading to RAM.
    Tiny hit to performance, since we mask the movement via non blocking calls.
    """
    @staticmethod
    @torch_amp_custom_fwd
    def forward(ctx, forward_function, hidden_states, *args):
        saved_hidden_states = hidden_states.to("cpu", non_blocking = True)
        with torch.no_grad():
            output = forward_function(hidden_states, *args)
        ctx.save_for_backward(saved_hidden_states)
        ctx.forward_function = forward_function
        ctx.args = args
        return output
    pass

    @staticmethod
    @torch_amp_custom_bwd
    def backward(ctx, dY):
        (hidden_states,) = ctx.saved_tensors
        hidden_states = hidden_states.to("cuda:0", non_blocking = True).detach()
        hidden_states.requires_grad_(True)
        with torch.enable_grad():
            (output,) = ctx.forward_function(hidden_states, *ctx.args)
        torch.autograd.backward(output, dY)
        return (None, hidden_states.grad,) + (None,)*len(ctx.args)
    pass
pass


@torch._disable_dynamo
def unsloth_offloaded_gradient_checkpoint(function, *args, use_reentrant = None, **kwargs):
    return Unsloth_Offloaded_Gradient_Checkpointer.apply(function, *args)
pass


import torch.utils
old_checkpoint = torch.utils.checkpoint
def patch_gradient_checkpointing():
    torch.utils.checkpoint = unsloth_offloaded_gradient_checkpoint
pass

def unpatch_gradient_checkpointing():
    torch.utils.checkpoint = old_checkpoint
pass

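# Hypothetical sketch of how the offloaded checkpointer above is driven (assumes
# a CUDA device). The wrapped forward must return a tuple, because backward
# unpacks `(output,) = ctx.forward_function(...)` before replaying it.
def _example_offloaded_checkpoint():
    layer = torch.nn.Linear(16, 16, device = "cuda:0")

    def layer_forward(hidden_states):
        return (layer(hidden_states),)  # 1-tuple so backward can unpack it

    X = torch.randn(2, 8, 16, device = "cuda:0", requires_grad = True)
    # Forward pass: the input activation is copied to CPU via a non blocking transfer
    (Y,) = Unsloth_Offloaded_Gradient_Checkpointer.apply(layer_forward, X)
    # Backward pass: the saved input is moved back to the GPU and the layer is re-run
    Y.sum().backward()
    return X.grad
pass
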
# =============================================
# Fixes Bitsandbytes to remove missing warnings
from transformers.utils.quantization_config import BitsAndBytesConfig, QuantizationMethod
from inspect import getsource
from accelerate.utils.dataclasses import DistributedType
BitsAndBytesConfig__init__ = getsource(BitsAndBytesConfig.__init__)
BitsAndBytesConfig__init__ = re.sub(
    r"if[\s]{1,}kwargs\:[\s]{1,}.+?\n",
    "",
    BitsAndBytesConfig__init__,
    flags = re.MULTILINE,
)
BitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.split("\n")
length_spaces = len(re.match(r"[\s]{1,}", BitsAndBytesConfig__init__[0]).group(0))
BitsAndBytesConfig__init__ = "\n".join(x[length_spaces:] for x in BitsAndBytesConfig__init__)
BitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.replace(
    "__init__",
    "_BitsAndBytesConfig__init__",
)

def _prepare_backend(
    self, cpu: bool = False, sagemaker_dp = False, backend: str = None,
) -> tuple[str, DistributedType]:
    return None, DistributedType.NO
pass
import accelerate.state
accelerate.state.PartialState._prepare_backend = _prepare_backend

import accelerate.accelerator
prepare = inspect.getsource(accelerate.accelerator.Accelerator.prepare)
prepare = prepare.split("\n")
spaces = prepare[0].find("def")
prepare = "\n".join(x[spaces:] for x in prepare)
x = "for obj in args:"
s = " "*spaces
prepare = prepare.replace(x, f'self.state.distributed_type = DistributedType.NO\n{s}{x}', 1)
exec(prepare, globals())
accelerate.accelerator.Accelerator.prepare = prepare

exec(BitsAndBytesConfig__init__, globals())

import transformers.utils.quantization_config
transformers.utils.quantization_config.BitsAndBytesConfig.__init__ = _BitsAndBytesConfig__init__
# =============================================

# Offloading to disk for modules (lm_head, embed_tokens)
import pickle

def offload_to_disk(W, model, name, temporary_location : str = "_unsloth_temporary_saved_buffers"):
    file_location = os.path.join(temporary_location, model.config._name_or_path)
    if not os.path.exists(file_location):
        os.makedirs(file_location)
    pass

    filename = os.path.join(file_location, f"{name}.pt")
    W = W.weight if hasattr(W, "weight") else W
    torch.save(W, filename, pickle_module = pickle, pickle_protocol = pickle.HIGHEST_PROTOCOL,)
    offloaded_W = torch.load(filename, map_location = "cpu", mmap = True)
    offloaded_W._offloaded_file_location = filename
    return offloaded_W
pass


def offload_input_embeddings(model, temporary_location : str = "_unsloth_temporary_saved_buffers"):
    offloaded_W = offload_to_disk(model.get_input_embeddings(), model, "input_embeddings", temporary_location)
    new_input_embeddings = torch.nn.Embedding.from_pretrained(offloaded_W)
    new_input_embeddings._offloaded_file_location = offloaded_W._offloaded_file_location
    model.set_input_embeddings(new_input_embeddings)
    return
pass


def offload_output_embeddings(model, temporary_location : str = "_unsloth_temporary_saved_buffers"):
    offloaded_W = offload_to_disk(model.get_output_embeddings(), model, "output_embeddings", temporary_location)

    new_output_embeddings = torch.nn.Linear(1, 1, bias = None)
    del new_output_embeddings.weight
    new_output_embeddings.weight = offloaded_W
    new_output_embeddings.in_features  = offloaded_W.shape[1]
    new_output_embeddings.out_features = offloaded_W.shape[0]

    new_output_embeddings._offloaded_file_location = offloaded_W._offloaded_file_location
    model.set_output_embeddings(new_output_embeddings)
    return
pass

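# Hypothetical sketch of the disk offloading pattern used above: save the tensor
# once, then re-open it with `mmap = True` so it stays on disk and pages into
# CPU RAM only when rows are actually read. The file name is just an example.
def _example_disk_offload(filename = "_unsloth_example_buffer.pt"):
    W = torch.randn(1000, 64)
    torch.save(W, filename, pickle_module = pickle, pickle_protocol = pickle.HIGHEST_PROTOCOL)
    W_mmap = torch.load(filename, map_location = "cpu", mmap = True)
    # Indexing only touches the pages backing those rows
    assert torch.equal(W_mmap[123], W[123])
    return W_mmap
pass
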
# Fixes a weird Torch 2.3 bug which says T4s have bfloat16
def is_bfloat16_supported():
    return SUPPORTS_BFLOAT16
pass


# Patches models to add RoPE Scaling
def patch_linear_scaling(
    model_name = "gemma2",
    rope_module = None,
    scaled_rope_module = None,
    attention_module = None,
):
    assert(rope_module is not None and scaled_rope_module is not None)
    assert(attention_module is not None)

    rope_name = rope_module.__name__
    scaled_rope_name = scaled_rope_module.__name__
    model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
    exec_code = \
        f"import torch.nn as nn\n"\
        f"from typing import Union, Optional, List, Any, Callable, Tuple\n"\
        f"from {model_filepath} import logger, "\
        f"{model_name.title()}Attention, {model_name.title()}Config"

    try:
        function = inspect.getsource(attention_module.__init__)
    except:
        # Most likely already patched!
        return None, None
    where = function.find("def")
    function = function.split("\n")
    function = "\n".join(x[where:] for x in function)
    init_name = f"{model_name.title()}Attention__init__"
    function = function.replace("def __init__", f"def {init_name}")
    function = function.replace(
        "super().__init__()",
        f"super({model_name.title()}Attention, self).__init__()",
    )
    fix_rope_function = """
    if getattr(self.config, "rope_scaling", None) is None:
        self.rotary_emb = {rope_function}(
            dim = self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )
    else:
        scaling_type = self.config.rope_scaling["type"]
        scaling_factor = self.config.rope_scaling["factor"]
        if scaling_type == "linear":
            self.rotary_emb = {scaled_rope_function}(
                dim = self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                scaling_factor=scaling_factor,
                base=self.rope_theta,
            )
        else:
            raise ValueError(f"Unknown RoPE scaling type {{scaling_type}}")
    pass
    """
    fix_rope_function = fix_rope_function.format(
        rope_function        = rope_module.__name__,
        scaled_rope_function = scaled_rope_module.__name__,
    )
    rotary_emb = re.findall(
        "self.rotary_emb = .+?\)", function,
        flags = re.DOTALL | re.MULTILINE,
    )
    if len(rotary_emb) == 0: return None, function
    rotary_emb = rotary_emb[0]
    function = function.replace(rotary_emb, fix_rope_function, 1)
    function = exec_code + "\n\n" + function
    return init_name, function
pass

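# Hypothetical toy showing the same "getsource -> string edit -> exec" patching
# pattern that patch_linear_scaling applies to HuggingFace attention __init__s.
class _ExampleToPatch:
    def __init__(self):
        self.mode = "old"
pass

def _example_source_patch():
    source = inspect.getsource(_ExampleToPatch.__init__)
    where  = source.find("def")
    source = "\n".join(line[where:] for line in source.split("\n"))
    source = source.replace("def __init__", "def _ExampleToPatch__init__")
    source = source.replace('self.mode = "old"', 'self.mode = "new"')
    exec(source, globals())
    _ExampleToPatch.__init__ = globals()["_ExampleToPatch__init__"]
    assert _ExampleToPatch().mode == "new"
pass
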
# Patches for Llama-3 LlamaExtendedRotaryEmbedding
def patch_llama_rope_scaling(
    model_name = "llama",
    rope_module = None,
    scaled_rope_module = None,
    extended_rope_module = None,
    attention_module = None,
    longrope_module = None,
):
    assert(\
        rope_module is not None and \
        scaled_rope_module is not None and \
        extended_rope_module is not None
    )
    assert(attention_module is not None)

    rope_name = rope_module.__name__
    scaled_rope_name = scaled_rope_module.__name__
    model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
    exec_code = \
        f"import torch.nn as nn\n"\
        f"from typing import Union, Optional, List, Any, Callable, Tuple\n"\
        f"from {model_filepath} import logger, "\
        f"{model_name.title()}Attention, {model_name.title()}Config"

    try:
        function = inspect.getsource(attention_module.__init__)
    except:
        # Most likely already patched!
        return None, None
    where = function.find("def")
    function = function.split("\n")
    function = "\n".join(x[where:] for x in function)
    init_name = f"{model_name.title()}Attention__init__"
    function = function.replace("def __init__", f"def {init_name}")
    function = function.replace(
        "super().__init__()",
        f"super({model_name.title()}Attention, self).__init__()",
    )
    fix_rope_function = """
    if getattr(self.config, "rope_scaling", None) is None:
        self.rotary_emb = {rope_function}(
            dim = self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )
    else:
        scaling_type1 = self.config.rope_scaling.get("type", None)
        scaling_type2 = self.config.rope_scaling.get("rope_type", None)
        scaling_type = scaling_type1 if scaling_type1 is not None else scaling_type2
        scaling_factor = self.config.rope_scaling.get("factor")

        if scaling_type == "linear":
            self.rotary_emb = {scaled_rope_function}(
                dim = self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                scaling_factor=scaling_factor,
                base=self.rope_theta,
            )
        elif scaling_type == "llama3":
            self.rotary_emb = {extended_rope_function}(
                dim = self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        elif scaling_type == "longrope":
            self.rotary_emb = {longrope_rope_function}(
                dim = self.head_dim,
                max_position_embeddings = self.max_position_embeddings,
                original_max_position_embeddings = self.config.original_max_position_embeddings,
                base = self.rope_theta,
                short_factor = self.config.rope_scaling['short_factor'],
                long_factor  = self.config.rope_scaling['long_factor' ],
            )
        else:
            raise ValueError(f"Unknown RoPE scaling type {{scaling_type}}")
    pass
    """

    fix_rope_function = fix_rope_function.format(
        rope_function          = rope_module.__name__,
        scaled_rope_function   = scaled_rope_module.__name__,
        extended_rope_function = extended_rope_module.__name__,
        longrope_rope_function = \
            (longrope_module if longrope_module is not None else rope_module).__name__
    )
    rotary_emb = re.findall(
        "self.rotary_emb = .+?\)", function,
        flags = re.DOTALL | re.MULTILINE,
    )
    if len(rotary_emb) == 0: return None, function
    rotary_emb = rotary_emb[0]
    function = function.replace(rotary_emb, fix_rope_function, 1)
    function = exec_code + "\n\n" + function
    return init_name, function
pass


def check_nvidia():
    # Unsloth doesn't work yet on AMD devices - we're working on it!
    output = np.array([0,])
    try:
        output = subprocess.check_output("nvidia-smi --query-gpu=memory.used --format=csv", shell = True)
        output = re.findall(rb'([\d]{1,})[\s]{1,}M', output)
        output = np.array([int(x.decode('utf-8'))/1024 for x in output])
    except:
        if not torch.cuda.is_available():
            raise RuntimeError("Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!")
    return output
pass
PRE_CHECK = check_nvidia()


def create_boolean_mask(n = 4096, sliding_window = 2048):
    # Creates a boolean mask for attention
    mask = torch.ones(n, n, dtype = torch.bool)
    if sliding_window == 0:
        return torch.triu(mask, diagonal = 1, out = mask)
    pass
    torch.triu(mask, diagonal = 0, out = mask)
    torch.triu(mask.T, diagonal = -sliding_window, out = mask.T)
    mask = mask.T
    torch.logical_not(mask, out = mask)
    return mask
pass


def test_mask_creation():
    from transformers.modeling_attn_mask_utils import AttentionMaskConverter
    for n in range(2, 23):
        for s in range(1, 23):
            correct_mask = AttentionMaskConverter(
                is_causal = True,
                sliding_window = s,
            ).to_causal_4d(1, n, n, dtype = torch.float16,).squeeze(0).squeeze(0)
            correct_mask = (correct_mask == correct_mask.min())
            our_mask = create_boolean_mask(n = n, sliding_window = s)
            assert(torch.all(correct_mask == our_mask))
        pass
        correct_mask = AttentionMaskConverter(
            is_causal = True,
            sliding_window = None,
        ).to_causal_4d(1, n, n, dtype = torch.float16,).squeeze(0).squeeze(0)
        correct_mask = (correct_mask == correct_mask.min())
        our_mask = create_boolean_mask(n = n, sliding_window = 0)
        assert(torch.all(correct_mask == our_mask))
    pass
pass
unsloth-main/unsloth-main/unsloth/models/cohere.py
ADDED
@@ -0,0 +1,473 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .llama import *
from ._utils import __version__
try:
    from transformers.models.cohere.modeling_cohere import (
        CohereAttention,
        CohereDecoderLayer,
        CohereModel,
        CohereForCausalLM,
        CohereRotaryEmbedding,
        apply_rotary_pos_emb,
        repeat_kv,
    )
except:
    from packaging.version import Version
    transformers_version = Version(transformers_version)
    if not transformers_version >= Version("4.42"):
        raise ImportError(
            f"Unsloth: Your transformers version of {transformers_version} does not support Cohere.\n"\
            f"The minimum required version is 4.42.3.\n"\
            f'Try `pip install --upgrade "transformers>=4.42.3"`\n'\
            f"to obtain the latest transformers build, then restart this session."\
        )
    pass
pass

from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask_for_sdpa,
)
# For Pytorch 2.1.1
try:
    from transformers.models.cohere.modeling_cohere import (
        CohereSdpaAttention,
        CohereFlashAttention2,
    )
except:
    CohereSdpaAttention   = CohereAttention
    CohereFlashAttention2 = CohereAttention
pass


def fast_layernorm_inference(self, X, out_weight = None):
    XX = X.to(torch.float32, copy = True)
    XX -= X.mean(-1, keepdim = True)
    variance = XX.square().mean(-1, keepdim = True)
    variance += self.variance_epsilon
    XX *= variance.rsqrt_()
    out_weight[:] = self.weight
    XX *= out_weight
    return XX.to(X.dtype)
pass

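# Hypothetical sanity check for fast_layernorm_inference above: it should match a
# bias-free LayerNorm computed in float32 (assumes a CUDA device). SimpleNamespace
# stands in for the Cohere norm module, which only needs `.weight` and
# `.variance_epsilon`.
def _example_check_fast_layernorm(dim = 64):
    from types import SimpleNamespace
    fake = SimpleNamespace(
        weight = torch.randn(dim, dtype = torch.float32, device = "cuda:0"),
        variance_epsilon = 1e-5,
    )
    X = torch.randn(2, 5, dim, dtype = torch.float16, device = "cuda:0")
    out_weight = torch.empty(dim, dtype = torch.float32, device = "cuda:0")
    fast = fast_layernorm_inference(fake, X, out_weight)
    ref = torch.nn.functional.layer_norm(
        X.float(), (dim,), weight = fake.weight, bias = None, eps = fake.variance_epsilon,
    ).to(X.dtype)
    assert torch.allclose(fast.float(), ref.float(), atol = 1e-2, rtol = 1e-2)
pass
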
# QK norm in Cohere
def CohereAttention_fast_forward(
    self,
    hidden_states:       torch.Tensor,
    causal_mask:         Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
    attention_mask:      Optional[torch.Tensor] = None,
    position_ids:        Optional[torch.LongTensor] = None,
    past_key_value:      Optional[Tuple[torch.Tensor]] = None,
    output_attentions:   bool = False,
    use_cache:           bool = False,
    padding_mask:        Optional[torch.LongTensor] = None,
    *args, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

    # Clear inference
    if hasattr(self, "paged_attention"):
        del self.paged_attention_K
        del self.paged_attention_V
        del self.paged_attention
        del self.temp_QA
        del self.temp_KV
        del self.RH_Q
        del self.attention
        del self.q_norm_out_weight
        del self.k_norm_out_weight
    pass

    bsz, q_len, _ = hidden_states.size()

    n_heads    = self.num_heads
    n_groups   = self.num_key_value_groups
    n_kv_heads = self.num_key_value_heads
    head_dim   = self.head_dim
    assert(n_kv_heads * n_groups == n_heads)

    Q, K, V = self.apply_qkv(self, hidden_states)
    Q = Q.view(bsz, q_len, n_heads,    head_dim).transpose(1, 2)
    K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
    V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
    if self.use_qk_norm:
        Q = fast_layernorm_compiled(self.q_norm, Q)
        K = fast_layernorm_compiled(self.k_norm, K)
    pass

    kv_seq_len = K.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    if position_ids is None:
        cos = self.rotary_emb.cos_cached
        sin = self.rotary_emb.sin_cached
        Q, K = fast_rope_embedding(Q, K, cos, sin)
    else:
        cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
        Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
    pass

    if past_key_value is not None:
        K = torch.cat([past_key_value[0], K], dim = 2)
        V = torch.cat([past_key_value[1], V], dim = 2)
    pass
    past_key_value = (K, V) if use_cache else None

    # Attention module
    if (not HAS_FLASH_ATTENTION and attention_mask is None):
        # Xformers memory efficient attention
        # Also has Flash Attention v2 dispatching
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        # Group query attention
        if n_groups != 1:
            K = K  .view(bsz, kv_seq_len, n_kv_heads,        1, head_dim)
            V = V  .view(bsz, kv_seq_len, n_kv_heads,        1, head_dim)
            K = K.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
            V = V.expand(bsz, kv_seq_len, n_kv_heads, n_groups, head_dim)
            if hidden_states.requires_grad:
                K = K.reshape(bsz, kv_seq_len, n_heads, head_dim)
                V = V.reshape(bsz, kv_seq_len, n_heads, head_dim)
            else:
                Q = Q.view(bsz, q_len, n_kv_heads, n_groups, head_dim)
        pass
        A = xformers_attention(Q, K, V, attn_bias = causal_mask)
        A = A.view(bsz, q_len, n_heads, head_dim)

    elif HAS_FLASH_ATTENTION and attention_mask is None:
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)
        A = flash_attn_func(Q, K, V, causal = True)
    else:
        # Grouped query attention
        if n_groups != 1:
            K = K[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
            V = V[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, kv_seq_len, head_dim)
            K = K.reshape(bsz, n_heads, kv_seq_len, head_dim)
            V = V.reshape(bsz, n_heads, kv_seq_len, head_dim)
        pass
        # Must be contiguous or else results are False!
        # https://github.com/pytorch/pytorch/issues/112577
        Q, K, V = Q.contiguous(), K.contiguous(), V.contiguous()
        # Needs (batch_size, n_heads, seq_len, head_dim)
        # is_causal and attention_mask must not be both set!
        A = scaled_dot_product_attention(Q, K, V, attn_mask = attention_mask, is_causal = False)
        # Go back to (batch_size, seq_len, n_heads, head_dim)
        A = A.transpose(1, 2).contiguous()
    pass
    attn_output = A.reshape(bsz, q_len, n_heads*head_dim)
    attn_output = self.apply_o(self, attn_output)
    attn_weights = None
    return attn_output, attn_weights, past_key_value
pass


# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
def CohereDecoderLayer_fast_forward(
    self,
    hidden_states:        torch.Tensor,
    causal_mask:          Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
    attention_mask:       Optional[torch.Tensor] = None,
    position_ids:         Optional[torch.LongTensor] = None,
    past_key_value:       Optional[Tuple[torch.Tensor]] = None,
    output_attentions:    Optional[bool] = False,
    use_cache:            Optional[bool] = False,
    padding_mask:         Optional[torch.LongTensor] = None,
    *args, **kwargs,
):
    if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None:
        out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0")

        # Self Attention
        residual = hidden_states
        hidden_states = fast_layernorm_inference(self.input_layernorm, hidden_states, out_weight)
        hidden_states_attention, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            causal_mask=causal_mask,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            padding_mask=padding_mask,
        )

        # Fully Connected
        hidden_states_mlp = fast_swiglu_inference(self.mlp, hidden_states)
        residual += hidden_states_attention
        residual += hidden_states_mlp
        hidden_states = residual
    else:
        residual = hidden_states
        hidden_states = fast_layernorm_compiled(self.input_layernorm, hidden_states)
        hidden_states_attention, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            causal_mask=causal_mask,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            padding_mask=padding_mask,
        )

        # Fully Connected
        hidden_states_mlp = self.mlp(hidden_states)
        hidden_states = residual + hidden_states_attention + hidden_states_mlp
    pass

    outputs = (hidden_states,)
    if output_attentions: outputs += (self_attn_weights,)
    if use_cache: outputs += (present_key_value,)
    return outputs
pass


from math import sqrt as math_sqrt
KV_CACHE_INCREMENT = 256 # KV Cache update size
torch_nn_functional_softmax = torch.nn.functional.softmax
torch_matmul = torch.matmul

def CohereAttention_fast_forward_inference(
    self,
    hidden_states:  torch.Tensor,
    past_key_value: Optional[Tuple[torch.Tensor]],
    position_ids,
    do_prefill = False,
    attention_mask = None,
):
    Xn = hidden_states
    bsz, _, hd = hidden_states.size()
    K1, V1 = past_key_value
    dtype = Xn.dtype

    n_heads    = self.num_heads
    n_groups   = self.num_key_value_groups
    n_kv_heads = self.num_key_value_heads
    head_dim   = self.head_dim
    attention_size = n_heads*head_dim
    # assert(n_kv_heads * n_groups == n_heads)
    seq_len = K1.shape[-2]
    kv_seq_len = seq_len + 1

    # Prefill phase
    # if not hasattr(self, "paged_attention"):
    if do_prefill:
        self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = "cuda:0")
        self.paged_attention_K = self.paged_attention[:,0]
        self.paged_attention_V = self.paged_attention[:,1]
        self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
        self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
        self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0")
        self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0")
        self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")

        # Mistral Nemo 12b has weird dimensions
        if attention_size != self.hidden_size:
            self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0")
        else:
            self.temp_O = self.temp_QA[1][:,:,:self.hidden_size]
        pass

        self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")
        self.scalar = 1.0 / math_sqrt(self.head_dim)
        self.half_head_dim = head_dim // 2
        # Cohere has QK layernorms
        if self.use_qk_norm:
            self.q_norm_out_weight = torch.empty(self.q_norm.weight.shape, dtype = torch.float32, device = "cuda:0")
            self.k_norm_out_weight = torch.empty(self.k_norm.weight.shape, dtype = torch.float32, device = "cuda:0")
        else:
            self.q_norm_out_weight = None
            self.k_norm_out_weight = None
        pass
    elif kv_seq_len >= self.paged_attention.shape[0]:
        self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim))
        self.paged_attention_K = self.paged_attention[:,0]
        self.paged_attention_V = self.paged_attention[:,1]
        self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT))
    pass

    Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
    Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
    Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
    Qn = Qn.view(bsz, 1, n_heads,    head_dim).transpose(1, 2)
    Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
    Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
    if self.use_qk_norm:
        Qn = fast_layernorm_inference(self.q_norm, Qn, self.q_norm_out_weight)
        Kn = fast_layernorm_inference(self.k_norm, Kn, self.k_norm_out_weight)
    pass

    # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
    # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
    cos, sin = self.rotary_emb.get_cached(kv_seq_len)
    cos = cos[position_ids].unsqueeze(1)
    sin = sin[position_ids].unsqueeze(1)
    h = self.half_head_dim

    RH_Q = self.RH_Q
    RH_Q[:,:,:,:h] = Qn[:,:,:,h:]
    RH_Q[:,:,:,h:] = Qn[:,:,:,:h]
    torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
    Qn *= cos
    Qn.addcmul_(RH_Q, sin)

    RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
    RH_K[:,:,:,:h] = Kn[:,:,:,h:]
    RH_K[:,:,:,h:] = Kn[:,:,:,:h]
    torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
    Kn *= cos
    Kn.addcmul_(RH_K, sin)

    # New KV cache
    # Kn = torch.cat([K1, Kn], dim = 2)
    # Vn = torch.cat([V1, Vn], dim = 2)
    self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3)
    self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3)
    Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3)
    Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3)

    # Handle sliding windows
    sliding_window = getattr(self.config, "sliding_window", None)
    if sliding_window is not None and kv_seq_len > sliding_window:
        # From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193
        slicing_tokens = 1 - sliding_window
        Knn = Kn[:, :, slicing_tokens:, :]#.contiguous()
        Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous()
    else:
        Knn, Vnn = Kn, Vn
    pass

    # Grouped query attention
    _, _, cached_len, _ = Knn.shape
    if n_groups != 1:
        Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
        Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
        Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
        Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
    pass
    # else:
    #     Knn, Vnn = Knn, Vnn
    # pass

    # Attention
    if bsz == 1:
        Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
        # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
        A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
        # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched
        A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
        A = torch_matmul(A, Vnn, out = Qn)
    else:
        A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
    pass
    A = A.transpose(1, 2)
    A = A.reshape(bsz, 1, attention_size)
    A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
    return A, (Kn, Vn)
pass

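# Hypothetical sketch of the KV cache growth strategy used above: preallocate
# KV_CACHE_INCREMENT spare slots, write one token per decode step, and only call
# resize_ (a storage reallocation) when the buffer is exhausted (assumes CUDA).
def _example_kv_cache_growth(n_steps = 600, bsz = 1, n_kv_heads = 8, head_dim = 64):
    cache = torch.empty((KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim), dtype = torch.float16, device = "cuda:0")
    resizes = 0
    for length in range(n_steps):
        if length >= cache.shape[0]:
            cache.resize_((cache.shape[0] + KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim))
            resizes += 1
        cache[length] = torch.randn(2, bsz, n_kv_heads, head_dim, dtype = torch.float16, device = "cuda:0")
    return resizes  # 600 steps with increments of 256 -> 2 reallocations
pass
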
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
# @torch.inference_mode
def CohereModel_fast_forward_inference(
    self,
    input_ids,
    past_key_values,
    position_ids,
    attention_mask = None,
):
    out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0")
    input_ids = input_ids[:,:self.max_seq_length]
    hidden_states = self.model.embed_tokens(input_ids)
    hidden_states = hidden_states.to(self.config.torch_dtype)
    bsz, q_len, hd = hidden_states.shape
    seq_len = past_key_values[0][0].shape[-2]
    if bsz != 1:
        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
            attention_mask,
            (bsz, q_len),
            hidden_states,
            seq_len,
            sliding_window = getattr(self.config, "sliding_window", None),
        )
    else:
        attention_mask = None
    pass

    next_decoder_cache = []
    for idx, decoder_layer in enumerate(self.model.layers):
        residual = hidden_states
        hidden_states = fast_layernorm_inference(decoder_layer.input_layernorm, hidden_states, out_weight)
        hidden_states_attention, present_key_value = CohereAttention_fast_forward_inference(
            decoder_layer.self_attn,
            hidden_states = hidden_states,
            past_key_value = past_key_values[idx],
            position_ids = position_ids,
            attention_mask = attention_mask,
            do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
        )

        hidden_states_mlp = fast_swiglu_inference(decoder_layer.mlp, hidden_states)
        residual += hidden_states_attention
        residual += hidden_states_mlp
        hidden_states = residual

        next_decoder_cache.append(present_key_value)
    pass
    hidden_states = fast_layernorm_inference(self.model.norm, hidden_states, out_weight)

    return BaseModelOutputWithPast(
        last_hidden_state = hidden_states,
        past_key_values = next_decoder_cache,
        hidden_states = [],
        attentions = [],
    )
pass


class FastCohereModel(FastLlamaModel):

    @staticmethod
    def pre_patch():
        init_name, function = patch_linear_scaling(
            model_name         = "cohere",
            rope_module        = LlamaRotaryEmbedding,
            scaled_rope_module = LlamaLinearScalingRotaryEmbedding,
            attention_module   = CohereAttention,
        )
        if init_name is not None:
            exec(function, globals())
            CohereAttention.__init__ = eval(init_name)
        pass
        CohereAttention      .forward = CohereAttention_fast_forward
        CohereSdpaAttention  .forward = CohereAttention_fast_forward
        CohereFlashAttention2.forward = CohereAttention_fast_forward
        CohereDecoderLayer   .forward = CohereDecoderLayer_fast_forward
        CohereModel          .forward = LlamaModel_fast_forward
        CohereForCausalLM    .forward = CausalLM_fast_forward(CohereModel_fast_forward_inference)
        PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward
        fix_prepare_inputs_for_generation(CohereForCausalLM)

        import transformers.models.cohere.modeling_cohere
        transformers.models.cohere.modeling_cohere.CohereRotaryEmbedding = LlamaRotaryEmbedding
        return
    pass
pass
unsloth-main/unsloth-main/unsloth/models/dpo.py
ADDED
@@ -0,0 +1,130 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = [
    "PatchDPOTrainer",
]

try:
    from transformers.utils.notebook import (
        IntervalStrategy,
        NotebookTrainingTracker,
        NotebookProgressCallback,
    )
    HAS_NOTEBOOK = True
except:
    HAS_NOTEBOOK = False
pass
import torch
from ._utils import torch_compile_options
import inspect
import torch.nn as nn
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union


DPOTrainer_metrics = [
    "rewards/chosen",
    "rewards/rejected",
    "rewards/accuracies",
    "rewards/margins",
    "logps/rejected",
    "logps/chosen",
    "logits/rejected",
    "logits/chosen",
]
set_DPOTrainer_metrics = frozenset(DPOTrainer_metrics)


def NotebookProgressCallback_on_train_begin(self, args, state, control, **kwargs):
    self.first_column = "Epoch" if args.eval_strategy == IntervalStrategy.EPOCH else "Step"
    self.training_loss = 0
    self.last_log = 0
    column_names = [self.first_column] + ["Training Loss"]
    if args.eval_strategy != IntervalStrategy.NO:
        column_names.append("Validation Loss")
    column_names += [x.replace("/", " / ") for x in DPOTrainer_metrics]
    self.training_tracker = NotebookTrainingTracker(state.max_steps, column_names)
pass


def NotebookProgressCallback_on_log(self, args, state, control, logs=None, **kwargs):
    # Only for when there is no evaluation
    if args.eval_strategy == IntervalStrategy.NO and "loss" in logs:
        values = {"Training Loss": logs["loss"]}
        for metric in DPOTrainer_metrics:
            values[metric.replace("/", " / ")] = logs[metric]
        pass
        # First column is necessarily Step since we're not in epoch eval strategy
        values["Step"] = state.global_step
        self.training_tracker.write_line(values)
    pass
pass


def NotebookTrainingTracker_write_line(self, values):
    """
    Write the values in the inner table.

    Args:
        values (`Dict[str, float]`): The values to display.
    """
    if self.inner_table is None:
        self.inner_table = [list(values.keys()), list(values.values())]
    else:
        columns = self.inner_table[0]
        new_values = {}
        for key, value in values.items():
            lowered = key.lower()
            if lowered in set_DPOTrainer_metrics:
                new_values[lowered.replace("/", " / ")] = value
            else:
                new_values[key] = value
        pass
        values = new_values

        self.inner_table[0] = columns
        if len(self.inner_table) > 1:
            last_values = self.inner_table[-1]
            first_column = self.inner_table[0][0]
            if last_values[0] != values[first_column]:
                # write new line
                self.inner_table.append([values[c] if c in values else "No Log" for c in columns])
            else:
                # update last line
                new_values = values
                for c in columns:
                    if c not in new_values.keys():
                        new_values[c] = last_values[columns.index(c)]
                self.inner_table[-1] = [new_values[c] for c in columns]
        else:
            # Edit for evaluation purposes
            self.inner_table.append([values[c] if c in values else 0 for c in columns])
        pass
    pass
pass


def PatchDPOTrainer():
    if HAS_NOTEBOOK:
        from transformers.trainer import is_in_notebook
        if is_in_notebook():
            # Patch DPO notebook printing
            NotebookTrainingTracker.write_line = NotebookTrainingTracker_write_line
            from transformers.trainer import DEFAULT_PROGRESS_CALLBACK
            DEFAULT_PROGRESS_CALLBACK.on_train_begin = NotebookProgressCallback_on_train_begin
            DEFAULT_PROGRESS_CALLBACK.on_log         = NotebookProgressCallback_on_log
        pass
    pass
pass
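# Hypothetical usage sketch: call PatchDPOTrainer() once inside a notebook before
# constructing a TRL DPOTrainer, so the DPO reward / logp metrics above appear in
# the notebook progress table:
#
#   from unsloth import PatchDPOTrainer
#   PatchDPOTrainer()
#   from trl import DPOTrainer
#   trainer = DPOTrainer(model = model, ref_model = None, args = training_args, ...)
#   trainer.train()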
unsloth-main/unsloth-main/unsloth/models/gemma.py
ADDED
@@ -0,0 +1,430 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .llama import *
from ._utils import __version__
import math

try:
    from transformers.models.gemma.modeling_gemma import (
        GemmaAttention,
        GemmaDecoderLayer,
        GemmaModel,
        GemmaForCausalLM,
        GemmaRotaryEmbedding,
        apply_rotary_pos_emb,
        repeat_kv,
    )
except:
    from packaging.version import Version
    transformers_version = Version(transformers_version)
    if not transformers_version >= Version("4.38"):
        raise ImportError(
            f"Unsloth: Your transformers version of {transformers_version} does not support Gemma.\n"\
            f"The minimum required version is 4.38.\n"\
            f'Try `pip install --upgrade "transformers>=4.38"`\n'\
            f"to obtain the latest transformers build, then restart this session."\
        )
    pass
pass

from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask_for_sdpa,
)
# For Pytorch 2.1.1
try:
    from transformers.models.gemma.modeling_gemma import (
        GemmaSdpaAttention,
        GemmaFlashAttention2,
    )
except:
    GemmaSdpaAttention   = GemmaAttention
    GemmaFlashAttention2 = GemmaAttention
pass


torch_nn_functional_gelu = torch.nn.functional.gelu
def fast_geglu_inference(self, X):
    # gate = self.gate_proj(X)
    # up   = self.up_proj(X)
    bsz, _, hd = X.shape
    # mlp_size = self.config.intermediate_size
    # temp = torch.empty((2, bsz, 1, mlp_size), dtype = X.dtype, device = "cuda:0")

    gate = fast_linear_forward(self.gate_proj, X)#, out = temp[0])
    up   = fast_linear_forward(self.  up_proj, X)#, out = temp[1])
    gate = torch_nn_functional_gelu(gate, approximate = "tanh")
    gate *= up

    # X = self.down_proj(gate)
    down = fast_linear_forward(self.down_proj, gate, out = up[:,:,:hd])
    return down
pass


# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
def GemmaDecoderLayer_fast_forward(
    self,
    hidden_states:        torch.Tensor,
    causal_mask:          Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
    attention_mask:       Optional[torch.Tensor] = None,
    position_ids:         Optional[torch.LongTensor] = None,
    past_key_value:       Optional[Tuple[torch.Tensor]] = None,
    output_attentions:    Optional[bool] = False,
    use_cache:            Optional[bool] = False,
    padding_mask:         Optional[torch.LongTensor] = None,
    *args, **kwargs,
):
    if use_cache and hasattr(self, "_flag_for_generation"): #past_key_value is not None:
        out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0")

        # Self Attention
        residual = hidden_states
        hidden_states = fast_rms_layernorm_inference_gemma(self.input_layernorm, hidden_states, out_weight)
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            causal_mask=causal_mask,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            padding_mask=padding_mask,
        )
        hidden_states += residual

        # Fully Connected
        residual = hidden_states
        hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight)
        hidden_states = fast_geglu_inference(self.mlp, hidden_states)
        hidden_states += residual
    else:
        residual = hidden_states
        hidden_states = fast_rms_layernorm(self.input_layernorm, hidden_states, gemma = True)
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            causal_mask=causal_mask,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            padding_mask=padding_mask,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = fast_rms_layernorm(self.post_attention_layernorm, hidden_states, gemma = True)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
    pass

    outputs = (hidden_states,)
    if output_attentions: outputs += (self_attn_weights,)
    if use_cache: outputs += (present_key_value,)
    return outputs
pass


from math import sqrt as math_sqrt

# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
# @torch.inference_mode
def GemmaModel_fast_forward_inference(
    self,
    input_ids,
    past_key_values,
    position_ids,
    attention_mask = None,
):
    out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0")
    input_ids = input_ids[:,:self.max_seq_length]
    hidden_states = self.model.embed_tokens(input_ids)
    hidden_states = hidden_states.to(self.config.torch_dtype)
    # 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
    # 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
    hidden_states *= torch.tensor(math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype)

    bsz, q_len, hd = hidden_states.shape
    seq_len = past_key_values[0][0].shape[-2]
    if bsz != 1:
        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
            attention_mask,
            (bsz, q_len),
            hidden_states,
            seq_len,
        )
    pass

    next_decoder_cache = []
    for idx, decoder_layer in enumerate(self.model.layers):
        residual = hidden_states
        hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight)
        hidden_states, present_key_value = LlamaAttention_fast_forward_inference(
            decoder_layer.self_attn,
            hidden_states = hidden_states,
            past_key_value = past_key_values[idx],
            position_ids = position_ids,
            attention_mask = attention_mask,
            do_prefill = not hasattr(decoder_layer.self_attn, "paged_attention"),
        )
        hidden_states += residual

        residual = hidden_states
        hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight)
        hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states)
        hidden_states += residual

        next_decoder_cache.append(present_key_value)
    pass
    hidden_states = fast_rms_layernorm_inference_gemma(self.model.norm, hidden_states, out_weight)

    return BaseModelOutputWithPast(
        last_hidden_state = hidden_states,
        past_key_values = next_decoder_cache,
        hidden_states = [],
        attentions = [],
    )
pass


# Follows line by line https://github.com/google-deepmind/gemma/blob/main/gemma/positional_embeddings.py#L45
# Formulates cos and sin differently from Llama!
class GemmaFixedRotaryEmbedding(torch.nn.Module):
    # Fixes https://github.com/huggingface/transformers/pull/28837
    # https://github.com/microsoft/DeepSpeed/issues/4932
    # The precision of RoPE buffers is not correct, so we cast to int64.
    def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None,
        config = None, # [TODO] Hack to pass in config - need to remove later
    ):
        super().__init__()
        if config is not None: return # [TODO] Hack to pass in config - need to remove later
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Dynamic RoPE we first set it to a max of 4 * 8192 tokens then we iteratively grow this
        self.current_rope_size = min(4 * 8192, self.max_position_embeddings)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(seq_len=self.current_rope_size, device=device, dtype=torch.get_default_dtype())
    pass

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        # Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
        # in FP32. They are applied (multiplied) in FP32 as well.
        self.current_rope_size = seq_len

        # The difference is we do division explicitly instead of t * (1/x) ie we do t/x.
        freq_exponents = (2.0 / self.dim) * (
            torch.arange(self.dim // 2, dtype = torch.int64, device = "cpu").float()
        )
        timescale = self.base**freq_exponents
        positions = torch.arange(self.current_rope_size, device = "cpu", dtype = torch.int64).float()
        radians_new = positions[..., None] / timescale[None, None, :]
        radians_new = radians_new.squeeze(0)

        emb = torch.cat((radians_new, radians_new), dim = -1)
        # We must do RoPE in float32!
        cos = emb.cos().to(device = "cuda:0", non_blocking = True)#, dtype = dtype)
        sin = emb.sin().to(device = "cuda:0", non_blocking = True)#, dtype = dtype)
        self.register_buffer("cos_cached", cos, persistent = False)
        self.register_buffer("sin_cached", sin, persistent = False)
    pass

    def forward(self, x, position_ids=None, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.current_rope_size:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
    pass

    def get_cached(self, seq_len = None):
        return self.cos_cached, self.sin_cached
    pass

    def extend_rope_embedding(self, x, seq_len):
        if seq_len <= self.current_rope_size: return
        # Iteratively grow by increments of 8192
        self.current_rope_size = math.ceil(seq_len / 8192) * 8192
        self._set_cos_sin_cache(self.current_rope_size, device = "cuda:0", dtype = x.dtype)
    pass
pass

class GemmaFixedLinearScalingRotaryEmbedding(GemmaFixedRotaryEmbedding):
|
271 |
+
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
272 |
+
# Fixes https://github.com/huggingface/transformers/pull/28837
|
273 |
+
# https://github.com/microsoft/DeepSpeed/issues/4932
|
274 |
+
# The precision of RoPE buffers is not correct, so we cast to int64.
|
275 |
+
def __init__(self, dim = None, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0,
|
276 |
+
config = None, # [TODO] Hack to pass in config - need to remove later
|
277 |
+
):
|
278 |
+
self.scaling_factor = scaling_factor
|
279 |
+
super().__init__(dim = dim, max_position_embeddings = max_position_embeddings, base = base, device = device, config = config)
|
280 |
+
pass
|
281 |
+
|
282 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
283 |
+
# Note: on the original Llama codebase, these tensors are created on the target device (and not on CPU) and
|
284 |
+
# in FP32. They are applied (multiplied) in FP32 as well.
|
285 |
+
self.current_rope_size = seq_len
|
286 |
+
|
287 |
+
# The difference is we do division explicity instead of t * (1/x) ie we do t/x.
|
288 |
+
freq_exponents = (2.0 / self.dim) * (
|
289 |
+
torch.arange(self.dim // 2, dtype = torch.int64, device = "cpu").float()
|
290 |
+
)
|
291 |
+
timescale = self.base**freq_exponents
|
292 |
+
positions = torch.arange(self.current_rope_size, device = "cpu", dtype = torch.int64).float()
|
293 |
+
positions = positions / self.scaling_factor
|
294 |
+
radians_new = positions[..., None] / timescale[None, None, :]
|
295 |
+
radians_new = radians_new.squeeze(0)
|
296 |
+
|
297 |
+
emb = torch.cat((radians_new, radians_new), dim = -1)
|
298 |
+
# We must do RoPE in float32!
|
299 |
+
cos = emb.cos().to(device = "cuda:0", non_blocking = True)#, dtype = dtype)
|
300 |
+
sin = emb.sin().to(device = "cuda:0", non_blocking = True)#, dtype = dtype)
|
301 |
+
self.register_buffer("cos_cached", cos, persistent = False)
|
302 |
+
self.register_buffer("sin_cached", sin, persistent = False)
|
303 |
+
pass
|
304 |
+
pass
|
305 |
+
|
306 |
+
|
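The two rotary classes above compute the RoPE angles as positions / timescale (explicit division), whereas Llama-style implementations multiply the positions by an inverse frequency. The short, self-contained check below is an editor's sketch, not part of the diff; both helper names are made up. It verifies on CPU that the two formulations agree numerically:

import torch

def angles_by_division(dim, seq_len, base = 10000.0):
    # Gemma style: t / timescale, with timescale = base ** (2i / dim)
    exponents = (2.0 / dim) * torch.arange(dim // 2, dtype = torch.int64).float()
    timescale = base ** exponents
    positions = torch.arange(seq_len, dtype = torch.int64).float()
    return positions[:, None] / timescale[None, :]

def angles_by_inv_freq(dim, seq_len, base = 10000.0):
    # Llama style: t * inv_freq, with inv_freq = base ** (-2i / dim)
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(seq_len).float()
    return torch.outer(positions, inv_freq)

print(torch.allclose(angles_by_division(64, 128), angles_by_inv_freq(64, 128)))  # expected: True (up to float rounding)
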
class FastGemmaModel(FastLlamaModel):

    @staticmethod
    def pre_patch():
        init_name, function = patch_linear_scaling(
            model_name         = "gemma",
            rope_module        = GemmaFixedRotaryEmbedding,
            scaled_rope_module = GemmaFixedLinearScalingRotaryEmbedding,
            attention_module   = GemmaAttention,
        )
        if init_name is not None:
            exec(function, globals())
            GemmaAttention.__init__ = eval(init_name)
        pass
        GemmaAttention      .forward = LlamaAttention_fast_forward
        GemmaSdpaAttention  .forward = LlamaAttention_fast_forward
        GemmaFlashAttention2.forward = LlamaAttention_fast_forward
        GemmaDecoderLayer   .forward = GemmaDecoderLayer_fast_forward
        GemmaModel          .forward = LlamaModel_fast_forward
        GemmaForCausalLM    .forward = CausalLM_fast_forward(GemmaModel_fast_forward_inference)
        PeftModelForCausalLM.forward = PeftModelForCausalLM_fast_forward
        fix_prepare_inputs_for_generation(GemmaForCausalLM)

        # Solves https://github.com/unslothai/unsloth/issues/168
        # Static KV Cache was introduced in 4.38.0, causing training to be much slower.
        # Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
        # https://github.com/huggingface/transformers/pull/27931
        # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
        import transformers.models.gemma.modeling_gemma
        transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding = GemmaFixedRotaryEmbedding
        return
    pass


    @staticmethod
    def post_patch(model):
        # Patch model for Gemma
        layers = model.model.layers

        # Torch.compile fails on the embedding matrix??
        # Workaround randomly fixes it for torch versions < 2.2
        model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight)
        model.config.update({"unsloth_version" : __version__})

        # We also do this for the lm_head
        lm_head = torch.nn.Linear(1, 1, bias = None)
        del lm_head.weight
        lm_head.weight       = model.lm_head.weight
        lm_head.in_features  = lm_head.weight.shape[1]
        lm_head.out_features = lm_head.weight.shape[0]
        model.lm_head = lm_head

        # Gemma has tied weights! This means lm_head == embed_tokens
        if model.model.embed_tokens.weight.data_ptr() != model.lm_head.weight.data_ptr():
            lm_head = torch.nn.Linear(1, 1, bias = None)
            del lm_head.weight
            lm_head.weight       = model.model.embed_tokens.weight
            lm_head.in_features  = lm_head.weight.shape[1]
            lm_head.out_features = lm_head.weight.shape[0]
            model.lm_head = lm_head
        pass

        # Also patch all dtypes - BnB seems to not allocate the correct type?
        # BnB default dtype seems to be float16!
        correct_dtype = lm_head.weight.dtype

        for name, module in model.named_modules():
            if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)):
                weight = module.weight
                quant_state = weight.quant_state

                if type(quant_state) is list:
                    # BnB seems to have float16 as default!
                    module.weight.quant_state[2] = correct_dtype # Cast to correct dtype
                else:
                    # https://github.com/TimDettmers/bitsandbytes/pull/763/files
                    quant_state.dtype = correct_dtype
                pass
            pass
            # Downcast RoPE embedding to correct data type
            # RoPE must be done in float32 for Gemma
            # if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")) \
            #     and (module.cos_cached.dtype != correct_dtype):

            #     module.cos_cached = module.cos_cached.to(correct_dtype)
            #     module.sin_cached = module.sin_cached.to(correct_dtype)
            #     pass
            # pass
        pass

        # Add 1 to weight
        # return output * (1 + self.weight)
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L89
        from transformers.models.gemma.modeling_gemma import GemmaRMSNorm

        # Freeze all parameters except LoRA
        # We do this first since += 1 seems to not be liked by requires_grad = True
        for name, param in model.named_parameters():
            if ".lora_A." in name or ".lora_B." in name:
                param.requires_grad_(True)
            else:
                param.requires_grad_(False)
        pass

        # Patch RMS Layernorm
        for name, module in model.named_modules():
            if isinstance(module, GemmaRMSNorm):
                # Must be in float32
                # https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L36
                # module = module.to(torch.float32)
                # Leave + 1 to the Triton kernel itself
                # module.weight += 1.0 # return output * (1 + self.weight)
                if not hasattr(module, "variance_epsilon"):
                    module.variance_epsilon = module.eps # Gemma doesn't use variance_epsilon
        pass

        # Clear deleted GPU items
        import gc
        for _ in range(3):
            gc.collect()
            torch.cuda.empty_cache()
        return model
    pass
pass
unsloth-main/unsloth-main/unsloth/models/gemma2.py
ADDED
@@ -0,0 +1,581 @@
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .llama import *
from ._utils import __version__
from .gemma import (
    GemmaFixedRotaryEmbedding,
    GemmaFixedLinearScalingRotaryEmbedding,
    fast_geglu_inference,
)
try:
    from transformers.models.gemma2.modeling_gemma2 import (
        Gemma2Attention,
        Gemma2DecoderLayer,
        Gemma2Model,
        Gemma2ForCausalLM,
        Gemma2RotaryEmbedding,
        apply_rotary_pos_emb,
        repeat_kv,
    )
except:
    from packaging.version import Version
    transformers_version = Version(transformers_version)
    if not transformers_version >= Version("4.42"):
        raise ImportError(
            f"Unsloth: Your transformers version of {transformers_version} does not support Gemma2.\n"\
            f"The minimum required version is 4.42.3.\n"\
            f'Try `pip install --upgrade "transformers>=4.42.3"`\n'\
            f"to obtain the latest transformers build, then restart this session."\
        )
    pass
pass

from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask_for_sdpa,
)
# For Pytorch 2.1.1
try:
    from transformers.models.gemma2.modeling_gemma2 import (
        Gemma2SdpaAttention,
        Gemma2FlashAttention2,
    )
except:
    Gemma2SdpaAttention   = Gemma2Attention
    Gemma2FlashAttention2 = Gemma2Attention
pass

if HAS_FLASH_ATTENTION_SOFTCAPPING:
    from flash_attn import flash_attn_func

# [TODO] We must randomly use torch.compile?
# I checked the gradients and formulas and I'm sure it's correct.
# I'm stumped :(
@torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
def fast_rms_layernorm_gemma2_compiled(layernorm, X, gemma = True):
    old_dtype = X.dtype
    X = X.float()
    X = X * torch.rsqrt(X.square().mean(-1, keepdim = True) + layernorm.eps) * \
        (1.0 + layernorm.weight.float())
    return X.to(old_dtype)
pass

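The compiled helper above is Gemma's RMSNorm with the characteristic (1 + weight) scaling, computed in float32 and cast back to the input dtype at the end. The plain-PyTorch reference below is an editor's sketch for comparison; gemma_rmsnorm_reference is an illustrative name, not an Unsloth function:

import torch

def gemma_rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    old_dtype = x.dtype
    x = x.float()                                                   # normalise in float32
    x = x * torch.rsqrt(x.square().mean(-1, keepdim = True) + eps)  # RMS normalisation
    return (x * (1.0 + weight.float())).to(old_dtype)               # Gemma's (1 + weight) scale

x = torch.randn(2, 5, 16, dtype = torch.bfloat16)
w = torch.zeros(16, dtype = torch.bfloat16)       # a freshly initialised Gemma RMSNorm weight is 0
print(gemma_rmsnorm_reference(x, w).dtype)        # torch.bfloat16
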
# Logit softcapping
def Gemma2Attention_fast_forward(
    self,
    hidden_states: torch.Tensor,
    causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    padding_mask: Optional[torch.LongTensor] = None,
    *args, **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

    # Clear inference
    if hasattr(self, "paged_attention"):
        del self.paged_attention_K
        del self.paged_attention_V
        del self.paged_attention
        del self.temp_QA
        del self.temp_KV
        del self.RH_Q
        del self.attention
    pass

    bsz, q_len, _ = hidden_states.size()

    n_heads    = self.num_heads
    n_groups   = self.num_key_value_groups
    n_kv_heads = self.num_key_value_heads
    head_dim   = self.head_dim
    assert(n_kv_heads * n_groups == n_heads)

    Q, K, V = self.apply_qkv(self, hidden_states)
    Q = Q.view(bsz, q_len, n_heads,    head_dim).transpose(1, 2)
    K = K.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
    V = V.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)

    kv_seq_len = K.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    if position_ids is None:
        cos = self.rotary_emb.cos_cached
        sin = self.rotary_emb.sin_cached
        Q, K = fast_rope_embedding(Q, K, cos, sin)
    else:
        cos, sin = self.rotary_emb(V, seq_len = kv_seq_len)
        Q, K = inplace_rope_embedding(Q, K, cos, sin, position_ids)
    pass

    if past_key_value is not None:
        K = torch.cat([past_key_value[0], K], dim = 2)
        V = torch.cat([past_key_value[1], V], dim = 2)
    pass
    past_key_value = (K, V) if use_cache else None

    # Only enable if the attention_mask is True
    has_sliding_window = type(causal_mask) is bool and causal_mask is True
    if HAS_FLASH_ATTENTION_SOFTCAPPING and attention_mask is None:
        window = (-1, -1)
        if has_sliding_window:
            sw = getattr(self.config, "sliding_window", None)
            sw = kv_seq_len if (sw is None or sw == "null") else sw
            window = (-1, -1) if (kv_seq_len <= sw) else (sw, sw)
        pass

        # FA uses 1 / sqrt for softmax_scale!
        if not hasattr(self, "_flash_attention_softmax_scale"):
            self._flash_attention_softmax_scale = 1.0 / (self.config.query_pre_attn_scalar**0.5)
        pass

        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)
        A = flash_attn_func(
            Q, K, V,
            causal = True,
            softcap = self.config.attn_logit_softcapping,
            softmax_scale = self._flash_attention_softmax_scale,
            window_size = window,
        )
        A = A.reshape(bsz, q_len, n_heads*head_dim)
    else:
        fx = slow_inference_attention_softcapping \
            if "_flag_for_generation" in kwargs else \
            slow_attention_softcapping
        A = fx(Q, K, V, causal_mask, self, bsz, kv_seq_len)
    pass
    A = self.apply_o(self, A)
    return A, None, past_key_value
pass

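Both the Flash-Attention path (via softcap=) and the slow fallback apply Gemma-2's attention logit softcapping. As a standalone illustration, here is an editor's sketch of the transform itself (softcap_logits is a made-up name, not the Unsloth kernel): the raw scores are squashed through tanh so their magnitude never exceeds the cap.

import torch

def softcap_logits(scores: torch.Tensor, softcap: float = 50.0) -> torch.Tensor:
    # scores -> softcap * tanh(scores / softcap); output is bounded in (-softcap, softcap)
    return softcap * torch.tanh(scores / softcap)

scores = 1000.0 * torch.randn(2, 4, 8, 8)   # deliberately huge logits
capped = softcap_logits(scores)
print(bool(capped.abs().max() <= 50.0))     # True
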
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L590
def Gemma2DecoderLayer_fast_forward(
    self,
    hidden_states: torch.Tensor,
    causal_mask: Optional[xformers.attn_bias.BlockDiagonalCausalMask] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    padding_mask: Optional[torch.LongTensor] = None,
    *args, **kwargs,
):
    if use_cache and hasattr(self, "_flag_for_generation"): # past_key_value is not None:
        out_weight = torch.empty(self.input_layernorm.weight.shape, dtype = torch.float32, device = "cuda:0")

        # Self Attention
        residual = hidden_states
        hidden_states = fast_rms_layernorm_inference_gemma(self.input_layernorm, hidden_states, out_weight)
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states        = hidden_states,
            causal_mask          = causal_mask,
            attention_mask       = attention_mask,
            position_ids         = position_ids,
            past_key_value       = past_key_value,
            output_attentions    = output_attentions,
            use_cache            = use_cache,
            padding_mask         = padding_mask,
            _flag_for_generation = True,
        )
        hidden_states = fast_rms_layernorm_inference_gemma(self.post_attention_layernorm, hidden_states, out_weight)
        hidden_states += residual

        # Fully Connected
        residual = hidden_states
        hidden_states = fast_rms_layernorm_inference_gemma(self.pre_feedforward_layernorm, hidden_states, out_weight)
        hidden_states = fast_geglu_inference(self.mlp, hidden_states)
        hidden_states = fast_rms_layernorm_inference_gemma(self.post_feedforward_layernorm, hidden_states, out_weight)
        hidden_states += residual
    else:
        residual = hidden_states
        hidden_states = fast_rms_layernorm_gemma2_compiled(self.input_layernorm, hidden_states, gemma = True)
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states     = hidden_states,
            causal_mask       = causal_mask,
            attention_mask    = attention_mask,
            position_ids      = position_ids,
            past_key_value    = past_key_value,
            output_attentions = output_attentions,
            use_cache         = use_cache,
            padding_mask      = padding_mask,
        )
        hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_attention_layernorm, hidden_states, gemma = True)
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = fast_rms_layernorm_gemma2_compiled(self.pre_feedforward_layernorm, hidden_states, gemma = True)
        hidden_states = self.mlp(hidden_states)
        hidden_states = fast_rms_layernorm_gemma2_compiled(self.post_feedforward_layernorm, hidden_states, gemma = True)
        hidden_states = residual + hidden_states
    pass

    outputs = (hidden_states,)
    if output_attentions: outputs += (self_attn_weights,)
    if use_cache: outputs += (present_key_value,)
    return outputs
pass


from math import sqrt as math_sqrt
KV_CACHE_INCREMENT = 256 # KV Cache update size
torch_nn_functional_softmax = torch.nn.functional.softmax
torch_matmul = torch.matmul
torch_tanh   = torch.tanh

def Gemma2Attention_fast_forward_inference(
    self,
    hidden_states: torch.Tensor,
    past_key_value: Optional[Tuple[torch.Tensor]],
    position_ids,
    do_prefill = False,
    attention_mask = None,
    use_sliding_window = False,
):
    Xn = hidden_states
    bsz, _, hd = hidden_states.size()
    K1, V1 = past_key_value
    dtype = Xn.dtype

    n_heads    = self.num_heads
    n_groups   = self.num_key_value_groups
    n_kv_heads = self.num_key_value_heads
    head_dim   = self.head_dim
    attention_size = n_heads*head_dim
    # assert(n_kv_heads * n_groups == n_heads)
    seq_len = K1.shape[-2]
    kv_seq_len = seq_len + 1

    # Prefill phase
    # if not hasattr(self, "paged_attention"):
    if do_prefill:
        self.paged_attention = torch.empty((KV_CACHE_INCREMENT+seq_len+1, 2, bsz, n_kv_heads, head_dim), dtype = dtype, device = "cuda:0")
        self.paged_attention_K = self.paged_attention[:,0]
        self.paged_attention_V = self.paged_attention[:,1]
        self.paged_attention_K[:seq_len] = K1.permute(2, 0, 1, 3)
        self.paged_attention_V[:seq_len] = V1.permute(2, 0, 1, 3)
        self.temp_QA = torch.empty((2, bsz, 1, attention_size), dtype = dtype, device = "cuda:0")
        self.temp_KV = torch.empty((2, bsz, 1, n_kv_heads*head_dim), dtype = dtype, device = "cuda:0")
        self.RH_Q = torch.empty((bsz, n_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
        # Only for Gemma2
        self.temp_O = torch.empty((1, bsz, self.hidden_size), dtype = dtype, device = "cuda:0")
        self.attention = torch.empty((bsz, n_heads, 1, KV_CACHE_INCREMENT+seq_len), dtype = dtype, device = "cuda:0")

        # See https://github.com/google/gemma_pytorch/commit/03e657582d17cb5a8617ebf333c1c16f3694670e
        # Gemma 9b should use 256 and not 224 (hs / nah). 27b uses the below
        # We default to using the config file itself
        # s = self.config.hidden_size // self.config.num_attention_heads
        self.scalar = 1.0 / math_sqrt(self.config.query_pre_attn_scalar)
        # self.scalar = 1.0 / math_sqrt(self.config.hidden_size // self.config.num_attention_heads)
        self.half_head_dim = head_dim // 2
        self.t            = self.config.attn_logit_softcapping
        self.reciprocal_t = 1.0 / self.config.attn_logit_softcapping
    elif kv_seq_len >= self.paged_attention.shape[0]:
        self.paged_attention.resize_((self.paged_attention.shape[0]+KV_CACHE_INCREMENT, 2, bsz, n_kv_heads, head_dim))
        self.paged_attention_K = self.paged_attention[:,0]
        self.paged_attention_V = self.paged_attention[:,1]
        self.attention.resize_((bsz, n_heads, 1, self.attention.shape[-1]+KV_CACHE_INCREMENT))
    pass

    Qn = fast_linear_forward(self.q_proj, Xn, out = self.temp_QA[0])
    Kn = fast_linear_forward(self.k_proj, Xn, out = self.temp_KV[0])
    Vn = fast_linear_forward(self.v_proj, Xn, out = self.temp_KV[1])
    Qn = Qn.view(bsz, 1, n_heads,    head_dim).transpose(1, 2)
    Kn = Kn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)
    Vn = Vn.view(bsz, 1, n_kv_heads, head_dim).transpose(1, 2)

    # cos, sin = self.rotary_emb(Vn, seq_len = kv_seq_len)
    # Qn, Kn = inplace_rope_embedding(Qn, Kn, cos, sin, position_ids)
    cos = self.rotary_emb.cos_cached[position_ids].unsqueeze(1)
    sin = self.rotary_emb.sin_cached[position_ids].unsqueeze(1)
    h = self.half_head_dim

    RH_Q = self.RH_Q
    RH_Q[:,:,:,:h] = Qn[:,:,:,h:]
    RH_Q[:,:,:,h:] = Qn[:,:,:,:h]
    torch.neg(RH_Q[:,:,:,:h], out = RH_Q[:,:,:,:h])
    Qn *= cos
    Qn.addcmul_(RH_Q, sin)

    RH_K = RH_Q[:,:n_kv_heads,:,:] # torch.empty((n_kv_heads, 1, head_dim), dtype = dtype, device = "cuda:0")
    RH_K[:,:,:,:h] = Kn[:,:,:,h:]
    RH_K[:,:,:,h:] = Kn[:,:,:,:h]
    torch.neg(RH_K[:,:,:,:h], out = RH_K[:,:,:,:h])
    Kn *= cos
    Kn.addcmul_(RH_K, sin)

    # New KV cache
    # Kn = torch.cat([K1, Kn], dim = 2)
    # Vn = torch.cat([V1, Vn], dim = 2)
    self.paged_attention_K[seq_len] = Kn.permute(2, 0, 1, 3)
    self.paged_attention_V[seq_len] = Vn.permute(2, 0, 1, 3)
    Kn = self.paged_attention_K[:kv_seq_len].permute(1, 2, 0, 3)
    Vn = self.paged_attention_V[:kv_seq_len].permute(1, 2, 0, 3)

    # Handle sliding windows
    sliding_window = self.config.sliding_window
    if use_sliding_window and kv_seq_len > sliding_window:
        # From https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L193
        slicing_tokens = 1 - sliding_window
        Knn = Kn[:, :, slicing_tokens:, :]#.contiguous()
        Vnn = Vn[:, :, slicing_tokens:, :]#.contiguous()
    else:
        Knn, Vnn = Kn, Vn
    pass

    # Grouped query attention
    _, _, cached_len, _ = Knn.shape
    if n_groups != 1:
        Knn = Knn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
        Vnn = Vnn[:, :, None, :, :].expand(bsz, n_kv_heads, n_groups, cached_len, head_dim)
        Knn = Knn.reshape(bsz, n_heads, cached_len, head_dim)
        Vnn = Vnn.reshape(bsz, n_heads, cached_len, head_dim)
    pass
    # else:
    #     Knn, Vnn = Knn, Vnn
    # pass

    # Attention
    # if bsz == 1:
    Qn *= self.scalar # See https://github.com/ggerganov/llama.cpp/issues/7805#issuecomment-2153349963
    # It seems like doing (Q * scalar) @ K is better than (Q @ K) * scalar to stop overflows
    A = torch_matmul(Qn, Knn.transpose(2, 3), out = self.attention[:,:,:,:cached_len])
    # if attention_mask is not None: A += attention_mask # Must add attention_mask for batched

    A *= self.reciprocal_t; torch_tanh(A, out = A); A *= self.t; # Logit softcapping

    A[:] = torch_nn_functional_softmax(A, dim = -1, dtype = torch.float32)#.to(A.dtype)
    A = torch_matmul(A, Vnn, out = Qn)
    # else:
    #     A = scaled_dot_product_attention(Qn, Knn, Vnn, attn_mask = attention_mask, is_causal = False)
    # pass
    A = A.transpose(1, 2)
    A = A.reshape(bsz, 1, attention_size)
    A = fast_linear_forward(self.o_proj, A, out = self.temp_O)
    return A, (Kn, Vn)
pass

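The cached-decode path above applies RoPE via the "rotate half" trick: the scratch buffer RH_Q holds [-q2, q1], so the rotation is fused into Qn*cos + RH_Q*sin with addcmul_. A non-in-place reference of the same rotation for a single new token is sketched below (editor's illustration; rope_rotate_half is a made-up name):

import torch

def rope_rotate_half(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # q: [bsz, heads, 1, head_dim]; cos/sin broadcastable to q
    h = q.shape[-1] // 2
    rotated = torch.cat((-q[..., h:], q[..., :h]), dim = -1)   # [-q2, q1]
    return q * cos + rotated * sin                             # same result as the in-place Qn update above

q   = torch.randn(1, 8, 1, 64)
ang = torch.rand(1, 1, 1, 64)
print(rope_rotate_half(q, ang.cos(), ang.sin()).shape)   # torch.Size([1, 8, 1, 64])
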
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L825
# @torch.inference_mode
def Gemma2Model_fast_forward_inference(
    self,
    input_ids,
    past_key_values,
    position_ids,
    attention_mask = None,
):
    out_weight = torch.empty_like(self.model.layers[0].input_layernorm.weight, dtype = torch.float32, device = "cuda:0")
    input_ids = input_ids[:,:self.max_seq_length]
    hidden_states = self.model.embed_tokens(input_ids)
    hidden_states = hidden_states.to(self.config.torch_dtype)
    # 3072**0.5 = 55.5000 in bfloat16, whilst 55.4256 in float32
    # 2048**0.5 = 45.2500 in bfloat16, whilst 45.2548 in float32
    hidden_states *= torch.tensor(math_sqrt(self.config.hidden_size), dtype = hidden_states.dtype)

    bsz, q_len, hd = hidden_states.shape
    seq_len = past_key_values[0][0].shape[-2]
    if bsz != 1:
        if HAS_FLASH_ATTENTION_SOFTCAPPING:
            SWA = True
            GA  = False
        else:
            SWA = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask,
                (bsz, q_len),
                hidden_states,
                seq_len,
                sliding_window = self.config.sliding_window,
            )
            GA = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask,
                (bsz, q_len),
                hidden_states,
                seq_len,
            )
        pass
    else:
        SWA = attention_mask
        GA  = attention_mask
    pass
    next_decoder_cache = []
    for idx, decoder_layer in enumerate(self.model.layers):

        use_sliding_window = idx % 2 == 0

        residual = hidden_states
        hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.input_layernorm, hidden_states, out_weight)
        hidden_states, present_key_value = Gemma2Attention_fast_forward_inference(
            decoder_layer.self_attn,
            hidden_states      = hidden_states,
            past_key_value     = past_key_values[idx],
            position_ids       = position_ids,
            attention_mask     = SWA if use_sliding_window else GA,
            do_prefill         = not hasattr(decoder_layer.self_attn, "paged_attention"),
            use_sliding_window = use_sliding_window,
        )
        hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_attention_layernorm, hidden_states, out_weight)
        hidden_states += residual

        residual = hidden_states
        hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.pre_feedforward_layernorm, hidden_states, out_weight)
        hidden_states = fast_geglu_inference(decoder_layer.mlp, hidden_states)
        hidden_states = fast_rms_layernorm_inference_gemma(decoder_layer.post_feedforward_layernorm, hidden_states, out_weight)
        hidden_states += residual

        next_decoder_cache.append(present_key_value)
    pass
    hidden_states = fast_rms_layernorm_inference_gemma(self.model.norm, hidden_states, out_weight)

    return BaseModelOutputWithPast(
        last_hidden_state = hidden_states,
        past_key_values = next_decoder_cache,
        hidden_states = [],
        attentions = [],
    )
pass

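Gemma-2 interleaves local (sliding-window) and global attention layers, which is why the loop above selects the sliding-window mask (SWA) for even layer indices and the global mask (GA) otherwise. A tiny standalone illustration of that selection rule (editor's sketch):

# Pick the sliding-window (local) mask on even layers and the global mask on odd layers
num_layers = 8
mask_per_layer = ["SWA" if idx % 2 == 0 else "GA" for idx in range(num_layers)]
print(mask_per_layer)   # ['SWA', 'GA', 'SWA', 'GA', 'SWA', 'GA', 'SWA', 'GA']
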
class FastGemma2Model(FastLlamaModel):

    @staticmethod
    def pre_patch():
        init_name, function = patch_linear_scaling(
            model_name         = "gemma2",
            rope_module        = GemmaFixedRotaryEmbedding,
            scaled_rope_module = GemmaFixedLinearScalingRotaryEmbedding,
            attention_module   = Gemma2Attention,
        )
        if init_name is not None:
            exec(function, globals())
            Gemma2Attention.__init__ = eval(init_name)
        pass
        Gemma2Attention      .forward = Gemma2Attention_fast_forward
        Gemma2SdpaAttention  .forward = Gemma2Attention_fast_forward
        Gemma2FlashAttention2.forward = Gemma2Attention_fast_forward
        Gemma2DecoderLayer   .forward = Gemma2DecoderLayer_fast_forward
        Gemma2Model          .forward = LlamaModel_fast_forward
        Gemma2ForCausalLM    .forward = CausalLM_fast_forward(Gemma2Model_fast_forward_inference)
        PeftModelForCausalLM .forward = PeftModelForCausalLM_fast_forward
        fix_prepare_inputs_for_generation(Gemma2ForCausalLM)

        # Solves https://github.com/unslothai/unsloth/issues/168
        # Static KV Cache was introduced in 4.38.0, causing training to be much slower.
        # Inference can now be CUDAGraphed, but we shall retain the old rotary embeddings.
        # https://github.com/huggingface/transformers/pull/27931
        # https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/models/llama/modeling_llama.py
        import transformers.models.gemma2.modeling_gemma2
        transformers.models.gemma2.modeling_gemma2.Gemma2RotaryEmbedding = GemmaFixedRotaryEmbedding
        return
    pass


    @staticmethod
    def post_patch(model):
        # Patch model for Gemma
        layers = model.model.layers

        # Torch.compile fails on the embedding matrix??
        # Workaround randomly fixes it for torch versions < 2.2
        model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight)
        model.config.update({"unsloth_version" : __version__})

        # We also do this for the lm_head
        lm_head = torch.nn.Linear(1, 1, bias = None)
        del lm_head.weight
        lm_head.weight       = model.lm_head.weight
        lm_head.in_features  = lm_head.weight.shape[1]
        lm_head.out_features = lm_head.weight.shape[0]
        model.lm_head = lm_head

        # Gemma has tied weights! This means lm_head == embed_tokens
        if model.model.embed_tokens.weight.data_ptr() != model.lm_head.weight.data_ptr():
            lm_head = torch.nn.Linear(1, 1, bias = None)
            del lm_head.weight
            lm_head.weight       = model.model.embed_tokens.weight
            lm_head.in_features  = lm_head.weight.shape[1]
            lm_head.out_features = lm_head.weight.shape[0]
            model.lm_head = lm_head
        pass

        # Also patch all dtypes - BnB seems to not allocate the correct type?
        # BnB default dtype seems to be float16!
        correct_dtype = lm_head.weight.dtype

        for name, module in model.named_modules():
            if isinstance(module, (Bnb_Linear4bit, Peft_Linear4bit)):
                weight = module.weight
                quant_state = weight.quant_state

                if type(quant_state) is list:
                    # BnB seems to have float16 as default!
                    module.weight.quant_state[2] = correct_dtype # Cast to correct dtype
                else:
                    # https://github.com/TimDettmers/bitsandbytes/pull/763/files
                    quant_state.dtype = correct_dtype
                pass
            pass
            # Downcast RoPE embedding to correct data type
            # RoPE must be done in float32 for Gemma
            # if (name.endswith("rotary_emb") or hasattr(module, "cos_cached")) \
            #     and (module.cos_cached.dtype != correct_dtype):

            #     module.cos_cached = module.cos_cached.to(correct_dtype)
            #     module.sin_cached = module.sin_cached.to(correct_dtype)
            #     pass
            # pass
        pass

        # Add 1 to weight
        # return output * (1 + self.weight)
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L89
        from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm

        # Freeze all parameters except LoRA
        # We do this first since += 1 seems to not be liked by requires_grad = True
        for name, param in model.named_parameters():
            if ".lora_A." in name or ".lora_B." in name:
                param.requires_grad_(True)
            else:
                param.requires_grad_(False)
        pass

        # Patch RMS Layernorm
        for name, module in model.named_modules():
            if isinstance(module, Gemma2RMSNorm):
                # Must be in float32
                # https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L36
                # module = module.to(torch.float32)
                # Leave + 1 to the Triton kernel itself
                # module.weight += 1.0 # return output * (1 + self.weight)
                if not hasattr(module, "variance_epsilon"):
                    module.variance_epsilon = module.eps # Gemma doesn't use variance_epsilon
        pass

        # Clear deleted GPU items
        import gc
        for _ in range(3):
            gc.collect()
            torch.cuda.empty_cache()
        return model
    pass
pass
unsloth-main/unsloth-main/unsloth/models/llama.py
ADDED
The diff for this file is too large to render.
See raw diff